Unverified Commit c28ad199 authored by Peng Zhang, committed by GitHub

[1/n] chore: decouple quantization implementation from vLLM dependency (#7992)

parent 570d3343
from contextlib import contextmanager
from typing import Any, Dict, Optional
import sglang.srt.layers.moe.fused_moe_triton.fused_moe # noqa
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
fused_experts,
get_config_file_name,
moe_align_block_size,
try_get_optimal_moe_config,
)
from sglang.srt.layers.moe.fused_moe_triton.layer import (
FusedMoE,
......@@ -37,4 +38,6 @@ __all__ = [
"fused_moe",
"fused_experts",
"get_config_file_name",
"moe_align_block_size",
"try_get_optimal_moe_config",
]
......@@ -22,10 +22,6 @@ try:
from vllm.model_executor.layers.quantization.experts_int8 import ExpertsInt8Config
from vllm.model_executor.layers.quantization.fbgemm_fp8 import FBGEMMFp8Config
from vllm.model_executor.layers.quantization.gguf import GGUFConfig
from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
from vllm.model_executor.layers.quantization.gptq_marlin import (
GPTQMarlinLinearMethod,
)
from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
GPTQMarlin24Config,
)
......@@ -59,7 +55,9 @@ from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import
from sglang.srt.layers.quantization.fp8 import Fp8Config
from sglang.srt.layers.quantization.gptq import (
GPTQConfig,
GPTQLinearMethod,
GPTQMarlinConfig,
GPTQMarlinLinearMethod,
GPTQMarlinMoEMethod,
)
from sglang.srt.layers.quantization.modelopt_quant import (
......
import logging
from dataclasses import dataclass
from fractions import Fraction
from typing import Any, Callable, Dict, List, Optional, Union
import torch
from sglang.srt.layers.linear import LinearBase, set_weight_attrs
from sglang.srt.layers.linear import LinearBase, LinearMethodBase, set_weight_attrs
from sglang.srt.layers.parameter import (
BasevLLMParameter,
ChannelQuantScaleParameter,
GroupQuantScaleParameter,
PackedColumnParameter,
PackedvLLMParameter,
RowvLLMParameter,
permute_param_layout_,
)
from sglang.srt.layers.quantization.base_config import (
QuantizationConfig,
QuantizeMethodBase,
)
from sglang.srt.layers.quantization.utils import replace_parameter
from sglang.srt.utils import is_cuda
_is_cuda = is_cuda()
from sglang.srt.layers.quantization.marlin_utils import (
apply_gptq_marlin_linear,
check_marlin_supported,
check_marlin_supports_shape,
marlin_is_k_full,
marlin_make_empty_g_idx,
marlin_make_workspace,
marlin_moe_permute_scales,
marlin_permute_scales,
marlin_repeat_scales_on_all_ranks,
marlin_sort_g_idx,
marlin_zero_points,
verify_marlin_supported,
)
from sglang.srt.layers.quantization.scalar_type import ScalarType, scalar_types
from sglang.srt.layers.quantization.utils import replace_parameter, unpack_cols
try:
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
from vllm.model_executor.layers.quantization.gptq_marlin import (
FusedMoE,
FusedMoEMethodBase,
FusedMoeWeightScaleSupported,
GPTQMarlinLinearMethod,
marlin_moe_permute_scales,
)
from vllm.model_executor.layers.quantization.marlin import MarlinLinearMethod
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
check_marlin_supported,
)
from vllm.scalar_type import scalar_types
VLLM_AVAILABLE = True
except ImportError:
VLLM_AVAILABLE = False
ops = None
GPTQLinearMethod = MarlinLinearMethod = Any
from sglang.srt.utils import is_cuda
FusedMoEMethodBase = QuantizeMethodBase
_is_cuda = is_cuda()
class scalar_types:
uint4b8 = "uint4b8"
uint8b128 = "uint8b128"
if _is_cuda:
from sgl_kernel import fused_marlin_moe
FusedMoEMethodBase = QuantizeMethodBase
logger = logging.getLogger(__name__)
......@@ -54,6 +62,38 @@ def check_marlin_format(hf_quant_cfg: Dict[str, Any]) -> bool:
)
def gptq_marlin_moe_repack(
b_q_weight: torch.Tensor,
perm: torch.Tensor,
size_k: int,
size_n: int,
num_bits: int,
) -> torch.Tensor:
num_experts = b_q_weight.shape[0]
assert size_k % 16 == 0
output = torch.empty(
(num_experts, size_k // 16, size_n * (num_bits // 2)),
device=b_q_weight.device,
dtype=b_q_weight.dtype,
)
for e in range(num_experts):
output[e] = torch.ops.sgl_kernel.gptq_marlin_repack(
b_q_weight[e], perm[e], size_k, size_n, num_bits
)
return output
@dataclass
class MarlinLinearLayerConfig:
full_weight_shape: tuple[int, int] # [in, out]
partition_weight_shape: tuple[int, int]
weight_type: ScalarType
act_type: torch.dtype
group_size: int
zero_points: bool
has_g_idx: bool
class GPTQConfig(QuantizationConfig):
"""Config class for GPTQ.
......@@ -151,11 +191,16 @@ class GPTQConfig(QuantizationConfig):
def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> Optional[GPTQLinearMethod]:
) -> Optional["LinearMethodBase"]:
# Delay the import to avoid circular dependency
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.quantization import get_linear_quant_method
return get_linear_quant_method(self, layer, prefix, GPTQLinearMethod)
if isinstance(layer, LinearBase):
return get_linear_quant_method(self, layer, prefix, GPTQLinearMethod)
elif isinstance(layer, FusedMoE):
raise TypeError("GPTQ Method does not support MoE, please use gptq_marlin")
return None
class GPTQMarlinConfig(QuantizationConfig):
......@@ -313,14 +358,6 @@ class GPTQMarlinConfig(QuantizationConfig):
if isinstance(layer, FusedMoE):
return GPTQMarlinMoEMethod(self)
# TODO: re-enable after SGLang syncs with vllm >= 0.7.3
# if layer.num_experts > 32:
# # For MoEs with many experts the moe_wna16 kernel is faster
# return MoeWNA16Config.from_config(self.full_config).get_quant_method(
# layer, prefix
# )
# else:
# return GPTQMarlinMoEMethod(self)
return get_linear_quant_method(self, layer, prefix, GPTQMarlinLinearMethod)
@classmethod
......@@ -344,112 +381,439 @@ class GPTQMarlinConfig(QuantizationConfig):
if (num_bits, sym) not in cls.TYPE_MAP:
return False
assert (
VLLM_AVAILABLE
), "vllm is not installed, to use gptq_marlin, please install vllm"
return check_marlin_supported(
quant_type=cls.TYPE_MAP[(num_bits, sym)], group_size=group_size
)
class MarlinConfig(QuantizationConfig):
"""Config class for Marlin.
class GPTQLinearMethod(LinearMethodBase):
"""Linear method for GPTQ.
Reference: https://github.com/IST-DASLab/marlin/tree/master
Args:
quant_config: The GPTQ quantization config.
"""
def __init__(
def __init__(self, quant_config: GPTQConfig):
self.quant_config = quant_config
def create_weights(
self,
group_size: int,
lm_head_quantized: bool,
) -> None:
# Group size for the quantization.
self.group_size = group_size
self.lm_head_quantized = lm_head_quantized
if self.group_size != 128 and self.group_size != -1:
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: list[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
del output_size # Unused.
weight_loader = extra_weight_attrs.get("weight_loader")
if input_size_per_partition % self.quant_config.group_size != 0:
raise ValueError(
"Currently, only group size 128 and -1 (channelwise) "
"is supported for Marlin, but got group_size of "
f"{self.group_size}"
"The input size is not aligned with the quantized "
"weight shape. This can be caused by too large "
"tensor parallel size."
)
output_size_per_partition = sum(output_partition_sizes)
if output_size_per_partition % self.quant_config.pack_factor.numerator != 0:
raise ValueError(
"The output size is not aligned with the quantized "
"weight shape. This can be caused by too large "
"tensor parallel size."
)
# 4 Bits packed into 32 bit datatype.
self.pack_factor = 32 // 4
if self.quant_config.group_size != -1:
group_size = self.quant_config.group_size
else:
group_size = input_size
self.use_shuffle = True
scale_and_zero_size = input_size // group_size
scale_and_zero_input_dim = None
if (
input_size != input_size_per_partition
and self.quant_config.group_size != -1
):
if self.quant_config.desc_act:
self.use_shuffle = False
else:
# we need to partition qzeros and scales for exllama kernel
scale_and_zero_size = input_size_per_partition // group_size
scale_and_zero_input_dim = 0
qweight = PackedvLLMParameter(
data=torch.empty(
input_size_per_partition // self.quant_config.pack_factor,
output_size_per_partition,
dtype=torch.int32,
),
input_dim=0,
output_dim=1,
packed_dim=0,
packed_factor=self.quant_config.pack_factor,
weight_loader=weight_loader,
)
# Tile size used by marlin kernels.
self.tile_size = 16
g_idx = RowvLLMParameter(
data=torch.tensor(
[
i // self.quant_config.group_size
for i in range(input_size_per_partition)
],
dtype=torch.int32,
),
input_dim=0,
weight_loader=weight_loader,
)
qzeros_args = {
"data": torch.empty(
scale_and_zero_size,
output_size_per_partition // self.quant_config.pack_factor,
dtype=torch.int32,
),
"weight_loader": weight_loader,
}
weight_scale_args = {
"data": torch.empty(
scale_and_zero_size,
output_size_per_partition,
dtype=params_dtype,
),
"weight_loader": weight_loader,
}
if scale_and_zero_input_dim is None:
scales = ChannelQuantScaleParameter(output_dim=1, **weight_scale_args)
qzeros = PackedColumnParameter(
output_dim=1,
packed_dim=1,
packed_factor=self.quant_config.pack_factor,
**qzeros_args,
)
# Min out_features dim
self.min_n_threads = 64
else:
scales = GroupQuantScaleParameter(
output_dim=1, input_dim=0, **weight_scale_args
)
qzeros = PackedvLLMParameter(
input_dim=0,
output_dim=1,
packed_dim=1,
packed_factor=self.quant_config.pack_factor,
**qzeros_args,
)
# Min in_features dim
self.min_k_threads = 128
layer.register_parameter("qweight", qweight)
layer.register_parameter("g_idx", g_idx)
layer.register_parameter("qzeros", qzeros)
layer.register_parameter("scales", scales)
# Max parallel problems to solve at once (improves large
# batch performance)
self.max_parallel = 16
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# for torch.compile
layer.qzeros = torch.nn.Parameter(layer.qzeros.data, requires_grad=False)
layer.qweight = torch.nn.Parameter(layer.qweight.data, requires_grad=False)
layer.g_idx = torch.nn.Parameter(layer.g_idx.data, requires_grad=False)
layer.scales = torch.nn.Parameter(layer.scales.data, requires_grad=False)
# exllama needs to shuffle the weight after the weight is loaded
# here we do the shuffle on first forward pass
if self.use_shuffle:
if self.quant_config.desc_act:
layer.g_idx.data = torch.argsort(layer.g_idx).to(torch.int)
else:
layer.g_idx.data = torch.empty(
(0,), dtype=torch.int, device=layer.g_idx.device
)
ops.gptq_shuffle(layer.qweight, layer.g_idx, self.quant_config.weight_bits)
# Permutation length used by the marlin kernels.
self.perm_len = 1024
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
out_shape = x.shape[:-1] + (layer.qweight.shape[-1],)
reshaped_x = x.reshape(-1, x.shape[-1])
output = ops.gptq_gemm(
reshaped_x,
layer.qweight,
layer.qzeros,
layer.scales,
layer.g_idx,
self.use_shuffle,
self.quant_config.weight_bits,
)
if bias is not None:
output.add_(bias)
return output.reshape(out_shape)
def __repr__(self) -> str:
return (
f"MarlinConfig(group_size={self.group_size}, "
f"lm_head_quantized={self.lm_head_quantized})"
class GPTQMarlinLinearMethod(LinearMethodBase):
"""Linear method for GPTQ Marlin.
Args:
quant_config: The GPTQ Marlin quantization config.
"""
_kernel_backends_being_used: set[str] = set()
def __init__(self, quant_config: GPTQMarlinConfig) -> None:
self.quant_config = quant_config
# Verify supported on platform.
verify_marlin_supported(
quant_type=self.quant_config.quant_type,
group_size=self.quant_config.group_size,
)
@classmethod
def get_name(cls) -> str:
return "marlin"
def create_weights(
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: list[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
) -> None:
output_size_per_partition = sum(output_partition_sizes)
is_row_parallel = input_size != input_size_per_partition
weight_loader = extra_weight_attrs.get("weight_loader")
self.kernel_config = MarlinLinearLayerConfig(
full_weight_shape=(input_size, output_size),
partition_weight_shape=(
input_size_per_partition,
output_size_per_partition,
),
weight_type=self.quant_config.quant_type,
act_type=params_dtype,
group_size=self.quant_config.group_size,
zero_points=False,
has_g_idx=self.quant_config.desc_act,
)
# Normalize group_size
if self.quant_config.group_size != -1:
group_size = self.quant_config.group_size
else:
group_size = input_size
@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
return [torch.half]
# Determine sharding
if marlin_repeat_scales_on_all_ranks(
self.quant_config.desc_act, self.quant_config.group_size, is_row_parallel
):
# By setting scale_dim == None, weight_loader will
# repeat the scales on each GPU in TP>1 case.
scales_and_zp_input_dim = None
scales_and_zp_size = input_size // group_size
else:
# By setting scale_dim == 0, weight_loader will
# shard the scales in TP>1 case.
scales_and_zp_input_dim = 0
scales_and_zp_size = input_size_per_partition // group_size
# Quantized weights
qweight = PackedvLLMParameter(
data=torch.empty(
input_size_per_partition // self.quant_config.pack_factor,
output_size_per_partition,
dtype=torch.int32,
),
input_dim=0,
output_dim=1,
packed_dim=0,
packed_factor=self.quant_config.pack_factor,
weight_loader=weight_loader,
)
@classmethod
# Need to figure it out
def get_min_capability(cls) -> int:
return 80
# Activation order
g_idx = RowvLLMParameter(
data=torch.empty(
input_size_per_partition,
dtype=torch.int32,
),
input_dim=0,
weight_loader=weight_loader,
)
@classmethod
def get_config_filenames(cls) -> List[str]:
return ["quantize_config.json"]
qzeros_args = {
"data": torch.empty(
scales_and_zp_size,
output_size_per_partition // self.quant_config.pack_factor,
dtype=torch.int32,
),
"weight_loader": weight_loader,
}
weight_scale_args = {
"data": torch.empty(
scales_and_zp_size,
output_size_per_partition,
dtype=params_dtype,
),
"weight_loader": weight_loader,
}
if scales_and_zp_input_dim is None:
scales = ChannelQuantScaleParameter(output_dim=1, **weight_scale_args)
qzeros = PackedColumnParameter(
output_dim=1,
packed_dim=1,
packed_factor=self.quant_config.pack_factor,
**qzeros_args,
)
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "MarlinConfig":
group_size = cls.get_from_keys(config, ["group_size"])
lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], default=False)
return cls(group_size, lm_head_quantized)
else:
scales = GroupQuantScaleParameter(
output_dim=1, input_dim=0, **weight_scale_args
)
qzeros = PackedvLLMParameter(
input_dim=0,
output_dim=1,
packed_dim=1,
packed_factor=self.quant_config.pack_factor,
**qzeros_args,
)
@classmethod
def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]:
is_marlin_format = check_marlin_format(hf_quant_cfg)
layer.register_parameter("qweight", qweight)
layer.register_parameter("g_idx", g_idx)
layer.register_parameter("scales", scales)
layer.register_parameter("qzeros", qzeros)
is_valid_user_quant = (
user_quant is None or user_quant == "gptq" or user_quant == "marlin"
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
device = getattr(layer, "qweight").device
c = self.kernel_config
check_marlin_supports_shape(
c.partition_weight_shape[1], # out_features
c.partition_weight_shape[0], # in_features
c.full_weight_shape[0], # in_features
c.group_size,
)
if is_marlin_format and is_valid_user_quant:
msg = "The model is serialized in {} format. Using {} kernel.".format(
cls.get_name(), cls.get_name()
row_parallel = c.partition_weight_shape[0] != c.full_weight_shape[0]
self.is_k_full = marlin_is_k_full(c.has_g_idx, row_parallel)
# Allocate marlin workspace.
self.workspace = marlin_make_workspace(device)
# Default names since marlin requires empty parameters for these,
# TODO: remove this requirement from marlin (allow optional tensors)
self.w_q_name = "qweight"
self.w_s_name = "scales"
self.w_zp_name = "qzeros"
self.w_gidx_name = "g_idx"
def _transform_param(
layer: torch.nn.Module, name: Optional[str], fn: Callable
) -> None:
if name is not None and getattr(layer, name, None) is not None:
old_param = getattr(layer, name)
new_param = fn(old_param)
# replace the parameter with torch.nn.Parameter for TorchDynamo
# compatibility
replace_parameter(
layer, name, torch.nn.Parameter(new_param.data, requires_grad=False)
)
def transform_w_q(x):
assert isinstance(x, BasevLLMParameter)
permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
x.data = torch.ops.sgl_kernel.gptq_marlin_repack(
x.data.contiguous(),
perm=layer.g_idx_sort_indices,
size_k=c.partition_weight_shape[0],
size_n=c.partition_weight_shape[1],
num_bits=c.weight_type.size_bits,
)
logger.info(msg)
return cls.get_name()
return x
def transform_w_s(x):
assert isinstance(x, BasevLLMParameter)
permute_param_layout_(x, input_dim=0, output_dim=1)
x.data = marlin_permute_scales(
x.data.contiguous(),
size_k=c.partition_weight_shape[0],
size_n=c.partition_weight_shape[1],
group_size=c.group_size,
)
return x
return None
if c.has_g_idx:
g_idx, g_idx_sort_indices = marlin_sort_g_idx(
getattr(layer, self.w_gidx_name)
)
_transform_param(layer, self.w_gidx_name, lambda _: g_idx)
layer.g_idx_sort_indices = g_idx_sort_indices
else:
setattr(layer, self.w_gidx_name, marlin_make_empty_g_idx(device))
layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)
def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> Optional[MarlinLinearMethod]:
# Delay the import to avoid circular dependency
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
if c.zero_points:
grouped_k = (
c.partition_weight_shape[0] // c.group_size if c.group_size != -1 else 1
)
_transform_param(
layer,
self.w_zp_name,
lambda x: marlin_zero_points(
unpack_cols(
x.t(),
c.weight_type.size_bits,
grouped_k,
c.partition_weight_shape[1],
),
size_k=grouped_k,
size_n=c.partition_weight_shape[1],
num_bits=c.weight_type.size_bits,
),
)
else:
setattr(layer, self.w_zp_name, marlin_make_empty_g_idx(device))
_transform_param(layer, self.w_q_name, transform_w_q)
_transform_param(layer, self.w_s_name, transform_w_s)
if isinstance(layer, LinearBase) or (
isinstance(layer, ParallelLMHead) and self.lm_head_quantized
):
return MarlinLinearMethod(self)
return None
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
c = self.kernel_config
def _get_weight_params(
layer: torch.nn.Module,
) -> tuple[
torch.Tensor, # w_q
torch.Tensor, # w_s
Optional[torch.Tensor], # w_zp,
Optional[torch.Tensor], # w_gidx
]:
return (
getattr(layer, self.w_q_name),
getattr(layer, self.w_s_name),
getattr(layer, self.w_zp_name or "", None),
getattr(layer, self.w_gidx_name or "", None),
)
w_q, w_s, w_zp, w_gidx = _get_weight_params(layer)
# `process_weights_after_loading` will ensure w_zp and w_gidx are not
# None for marlin
return apply_gptq_marlin_linear(
input=x,
weight=w_q,
weight_scale=w_s,
weight_zp=w_zp, # type: ignore
g_idx=w_gidx, # type: ignore
g_idx_sort_indices=layer.g_idx_sort_indices,
workspace=self.workspace,
wtype=c.weight_type,
input_size_per_partition=c.partition_weight_shape[0],
output_size_per_partition=c.partition_weight_shape[1],
is_k_full=self.is_k_full,
bias=bias,
)
class GPTQMarlinMoEMethod(FusedMoEMethodBase):
......@@ -467,6 +831,9 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
params_dtype: torch.dtype,
**extra_weight_attrs,
):
# Delay the import to avoid circular dependency
from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
intermediate_size = extra_weight_attrs.pop("intermediate_size")
self.is_k_full = (not self.quant_config.desc_act) or (
......@@ -644,20 +1011,20 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
requires_grad=False,
)
# Repack weights
marlin_w13_qweight = ops.gptq_marlin_moe_repack(
marlin_w13_qweight = gptq_marlin_moe_repack(
layer.w13_qweight,
layer.w13_g_idx_sort_indices,
layer.w13_qweight.shape[1] * self.quant_config.pack_factor,
layer.w13_qweight.shape[2],
self.quant_config.quant_type.size_bits,
self.quant_config.weight_bits,
)
replace_parameter(layer, "w13_qweight", marlin_w13_qweight)
marlin_w2_qweight = ops.gptq_marlin_moe_repack(
marlin_w2_qweight = gptq_marlin_moe_repack(
layer.w2_qweight,
layer.w2_g_idx_sort_indices,
layer.w2_qweight.shape[1] * self.quant_config.pack_factor,
layer.w2_qweight.shape[2],
self.quant_config.quant_type.size_bits,
self.quant_config.weight_bits,
)
replace_parameter(layer, "w2_qweight", marlin_w2_qweight)
# Repack scales
......@@ -698,13 +1065,19 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
e_score_correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
) -> torch.Tensor:
# Delay the import to avoid circular dependency
from sglang.srt.layers.moe.topk import select_experts
assert activation == "silu", "Only SiLU activation is supported."
assert (
scoring_func == "softmax"
), "Only softmax score func is supported for now."
# The input must currently be float16
orig_dtype = x.dtype
x = x.half()
topk_weights, topk_ids = FusedMoE.select_experts(
topk_weights, topk_ids = select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
......@@ -713,11 +1086,10 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
correction_bias=e_score_correction_bias,
)
return torch.ops.vllm.fused_marlin_moe(
return fused_marlin_moe(
x,
layer.w13_qweight,
layer.w2_qweight,
......@@ -730,6 +1102,6 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
g_idx2=layer.w2_g_idx,
sort_indices1=layer.w13_g_idx_sort_indices,
sort_indices2=layer.w2_g_idx_sort_indices,
quant_type_id=self.quant_config.quant_type.id,
num_bits=self.quant_config.weight_bits,
is_k_full=self.is_k_full,
).to(orig_dtype)
# SPDX-License-Identifier: Apache-2.0
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/marlin_utils.py
import logging
from typing import Any, Optional
import numpy
import torch
from sglang.srt.layers.linear import LinearBase, LinearMethodBase
from sglang.srt.layers.parameter import (
BasevLLMParameter,
ChannelQuantScaleParameter,
GroupQuantScaleParameter,
PackedvLLMParameter,
)
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.quantization.scalar_type import ScalarType, scalar_types
from sglang.srt.layers.quantization.utils import pack_cols, unpack_cols
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
from sglang.srt.utils import get_device_capability
try:
from vllm import _custom_ops as ops
except ImportError:
ops = None
logger = logging.getLogger(__name__)
GPTQ_MARLIN_TILE = 16
GPTQ_MARLIN_MIN_THREAD_N = 64
GPTQ_MARLIN_MIN_THREAD_K = 128
GPTQ_MARLIN_MAX_PARALLEL = 16
MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
# In case there is a performance issue with Marlin, the variable below can be
# changed to False, which allows Marlin to perform global reductions in fp16
# precision (instead of fp32), and therefore, save on some memory movements.
USE_FP32_REDUCE_DEFAULT = True
# For binary size and compile time, we don't support the same types for with and
# without runtime zero-point. We support common cases, i.e. AWQ and GPTQ.
# TODO: we may want to move this into the C++ so it's closer to the actual impl
def query_marlin_supported_quant_types(
has_zp: Optional[bool] = None,
include_fp_type: bool = True,
device_capability: Optional[int] = None,
):
if device_capability is None:
major, minor = get_device_capability()
capability = major * 10 + minor
device_capability = -1 if capability is None else capability
if device_capability < 80:
return []
# - has_zp is True: return quant_types that have zero points
# - has_zp is False: return quant_types that do not have zero points
# - has_zp is None: return both
if has_zp is None:
types0 = query_marlin_supported_quant_types(
False, include_fp_type, device_capability
)
types1 = query_marlin_supported_quant_types(
True, include_fp_type, device_capability
)
return types0 + types1
if has_zp:
# AWQ style, unsigned + runtime zero-point
return [scalar_types.uint4]
else:
# GPTQ style, unsigned + symmetric bias
res = [scalar_types.uint4b8, scalar_types.uint8b128]
if include_fp_type:
res += [scalar_types.float8_e4m3fn, scalar_types.float4_e2m1f]
return res
def _check_marlin_supported(
quant_type: ScalarType,
group_size: Optional[int],
has_zp: bool,
device_capability: Optional[int] = None,
) -> tuple[bool, Optional[str]]:
if device_capability is None:
major, minor = get_device_capability()
capability = major * 10 + minor
device_capability = -1 if capability is None else capability
supported_types = query_marlin_supported_quant_types(
has_zp, True, device_capability
)
if quant_type not in supported_types:
return (
False,
f"Marlin does not support weight_bits = {quant_type}. "
f"Only types = {supported_types} "
f"are supported (for group_size = {group_size}, "
f"device_capability = {device_capability}, zp = {has_zp}).",
)
if group_size is None or group_size not in MARLIN_SUPPORTED_GROUP_SIZES:
return (
False,
f"Marlin does not support group_size = {group_size}. "
f"Only group_sizes = {MARLIN_SUPPORTED_GROUP_SIZES} "
"are supported.",
)
return True, None
def check_marlin_supported(
quant_type: ScalarType,
group_size: int,
has_zp: bool = False,
device_capability: Optional[int] = None,
) -> bool:
cond, _ = _check_marlin_supported(quant_type, group_size, has_zp, device_capability)
return cond
def verify_marlin_supported(
quant_type: ScalarType, group_size: int, has_zp: bool = False
) -> None:
cond, err_msg = _check_marlin_supported(quant_type, group_size, has_zp)
if not cond:
assert err_msg is not None
raise ValueError(err_msg)
def verify_marlin_supports_shape(
output_size_per_partition: int,
input_size_per_partition: int,
input_size: int,
group_size: int,
) -> None:
# Validate output_size_per_partition
if output_size_per_partition % GPTQ_MARLIN_MIN_THREAD_N != 0:
raise ValueError(
f"Weight output_size_per_partition = "
f"{output_size_per_partition} is not divisible by "
f" min_thread_n = {GPTQ_MARLIN_MIN_THREAD_N}. "
"Consider reducing tensor_parallel_size or running "
"with --quantization gptq."
)
# Validate input_size_per_partition
if input_size_per_partition % GPTQ_MARLIN_MIN_THREAD_K != 0:
raise ValueError(
f"Weight input_size_per_partition = "
f"{input_size_per_partition} is not divisible "
f"by min_thread_k = {GPTQ_MARLIN_MIN_THREAD_K}. "
"Consider reducing tensor_parallel_size or running "
"with --quantization gptq."
)
if group_size < input_size and input_size_per_partition % group_size != 0:
raise ValueError(
f"Weight input_size_per_partition = {input_size_per_partition}"
f" is not divisible by group_size = {group_size}. "
"Consider reducing tensor_parallel_size or running "
"with --quantization gptq."
)
def check_marlin_supports_shape(
output_size_per_partition: int,
input_size_per_partition: int,
input_size: int,
group_size: int,
) -> tuple[bool, Optional[str]]:
try:
verify_marlin_supports_shape(
output_size_per_partition, input_size_per_partition, input_size, group_size
)
except ValueError as e:
return False, e.__str__()
return True, None
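# Illustrative sketch (hypothetical shapes): dimensions aligned to the Marlin
# thread tiles (n % 64 == 0, k % 128 == 0) pass the shape check, while a
# misaligned input partition is rejected with a readable reason.
def _example_check_marlin_supports_shape():
    ok, _ = check_marlin_supports_shape(
        output_size_per_partition=4096,
        input_size_per_partition=4096,
        input_size=4096,
        group_size=128,
    )
    assert ok
    ok, reason = check_marlin_supports_shape(
        output_size_per_partition=4096,
        input_size_per_partition=100,  # not a multiple of GPTQ_MARLIN_MIN_THREAD_K
        input_size=4096,
        group_size=128,
    )
    assert not ok and reason is not None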
def check_marlin_supports_layer(layer: LinearBase, group_size: int) -> bool:
output_size_per_partition = (
getattr(layer, "output_size_per_partition", None) or layer.output_size
)
input_size_per_partition = (
getattr(layer, "input_size_per_partition", None) or layer.input_size
)
return check_marlin_supports_shape(
output_size_per_partition=output_size_per_partition,
input_size_per_partition=input_size_per_partition,
input_size=layer.input_size,
group_size=group_size,
)[0]
def check_moe_marlin_supports_layer(layer: LinearBase, group_size: int) -> bool:
hidden_size = layer.hidden_size
intermediate_size_per_partition = layer.intermediate_size_per_partition
# apply_router_weight_on_input is not supported for moe marlin
supports_router_weight = not layer.apply_router_weight_on_input
# moe marlin requires the activation to be silu
supports_activation = layer.activation == "silu"
# gate-up: (n, k) = (intermediate_size_per_partition * 2, hidden_size)
# down: (n, k) = (hidden_size, intermediate_size_per_partition)
# moe marlin requires n % 128 == 0 and k % 64 == 0
supports_shape = (
hidden_size % 128 == 0
and intermediate_size_per_partition % max(64, group_size) == 0
)
supports_group_size = group_size in [-1, 32, 64, 128]
return (
supports_shape
and supports_group_size
and supports_router_weight
and supports_activation
)
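# Illustrative sketch: a stand-in for a FusedMoE layer (SimpleNamespace is a
# hypothetical substitute for the real layer object) showing the attributes the
# MoE Marlin check reads; the chosen sizes satisfy the n % 128 and k % group_size
# divisibility requirements described above.
def _example_check_moe_marlin_supports_layer():
    from types import SimpleNamespace

    layer = SimpleNamespace(
        hidden_size=4096,
        intermediate_size_per_partition=1408,
        apply_router_weight_on_input=False,
        activation="silu",
    )
    assert check_moe_marlin_supports_layer(layer, group_size=128)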
def marlin_make_workspace(
device: torch.device, max_blocks_per_sm: int = 1
) -> torch.Tensor:
# In the new marlin kernel, we use the number of threadblocks as the workspace
# size. The number of threadblocks is sms_count * max_blocks_per_sm.
sms = torch.cuda.get_device_properties(device).multi_processor_count
return torch.zeros(
sms * max_blocks_per_sm, dtype=torch.int, device=device, requires_grad=False
)
def marlin_is_k_full(act_order: bool, is_row_parallel: bool) -> bool:
return (not act_order) or (act_order and not is_row_parallel)
def marlin_repeat_scales_on_all_ranks(
act_order: bool, group_size: int, is_row_parallel: bool
) -> bool:
# Scales need to be repeated on every rank if act_order is enabled, or if the
# quantization is channelwise and the layer is row-parallel.
is_channelwise = group_size == -1
return act_order or (is_channelwise and is_row_parallel)
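# Illustrative truth-table sketch for the two helpers above: with act_order
# enabled, k is only "full" on non row-parallel shards, and channelwise
# (group_size == -1) row-parallel layers repeat their scales on every rank.
def _example_k_full_and_scale_repeat():
    assert marlin_is_k_full(act_order=False, is_row_parallel=True)
    assert not marlin_is_k_full(act_order=True, is_row_parallel=True)
    assert marlin_repeat_scales_on_all_ranks(
        act_order=False, group_size=-1, is_row_parallel=True
    )
    assert not marlin_repeat_scales_on_all_ranks(
        act_order=False, group_size=128, is_row_parallel=True
    )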
def marlin_make_empty_g_idx(device: torch.device) -> torch.Tensor:
return torch.nn.Parameter(
torch.empty(0, dtype=torch.int, device=device), requires_grad=False
)
def marlin_make_empty_zp(device: torch.device) -> torch.Tensor:
return torch.nn.Parameter(
torch.empty(0, dtype=torch.int, device=device), requires_grad=False
)
def marlin_sort_g_idx(g_idx: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
g_idx_sort_indices = torch.argsort(g_idx).to(torch.int)
return g_idx[g_idx_sort_indices], g_idx_sort_indices
def get_scale_perms():
scale_perm: list[int] = []
for i in range(8):
scale_perm.extend([i + 8 * j for j in range(8)])
scale_perm_single: list[int] = []
for i in range(4):
scale_perm_single.extend([2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
return scale_perm, scale_perm_single
def marlin_permute_scales(
s: torch.Tensor, size_k: int, size_n: int, group_size: int
) -> torch.Tensor:
scale_perm, scale_perm_single = get_scale_perms()
if group_size < size_k and group_size != -1:
s = s.reshape((-1, len(scale_perm)))[:, scale_perm]
else:
s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single]
s = s.reshape((-1, size_n)).contiguous()
return s
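# Illustrative sketch (hypothetical sizes): the Marlin scale permutation only
# reorders entries within each row, so the [groups, size_n] shape and per-row
# contents are preserved.
def _example_marlin_permute_scales():
    s = torch.arange(4 * 64, dtype=torch.float16).reshape(4, 64)  # 4 groups, n = 64
    permuted = marlin_permute_scales(s, size_k=128, size_n=64, group_size=32)
    assert permuted.shape == s.shape
    assert torch.equal(permuted.sort(dim=-1).values, s.sort(dim=-1).values)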
def marlin_moe_permute_scales(
s: torch.Tensor,
size_k: int,
size_n: int,
group_size: int,
):
num_experts = s.shape[0]
output = torch.empty(
(num_experts, s.shape[1], s.shape[2]),
device=s.device,
dtype=s.dtype,
)
for e in range(num_experts):
output[e] = marlin_permute_scales(s[e], size_k, size_n, group_size)
return output
def marlin_zero_points(
zp: torch.Tensor, size_k: int, size_n: int, num_bits: int
) -> torch.Tensor:
# Permute zero-points in a similar way to scales, but do not use the
# "single" permutation, since zero-points are applied on every MMA
scale_perm, _ = get_scale_perms()
zp = zp.reshape((-1, len(scale_perm)))[:, scale_perm]
# Interleave column dim (for the dequantize code) and pack it to int32
if num_bits == 4:
interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
elif num_bits == 8:
interleave = numpy.array([0, 2, 1, 3])
else:
raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))
zp = zp.reshape((-1, len(interleave)))[:, interleave].ravel()
zp = zp.reshape((-1, size_n)).contiguous()
zp = pack_cols(zp, num_bits, size_k, size_n)
return zp
def awq_to_marlin_zero_points(
q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int
) -> torch.Tensor:
# AWQ zero-points are quantized and packed on the column dim.
# In addition, the values are permuted based on dequantizer.
# Here we undo both of these, and then apply marlin permutation
# and pack it back.
q_zp = unpack_cols(q_zp_packed, num_bits, size_k, size_n)
# Undo interleaving (use argsort(..) to get inverse perm)
if num_bits == 4:
undo_interleave = numpy.argsort(numpy.array([0, 2, 4, 6, 1, 3, 5, 7]))
elif num_bits == 8:
undo_interleave = numpy.argsort(numpy.array([0, 2, 1, 3]))
else:
raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))
q_zp = q_zp.reshape((-1, len(undo_interleave)))[:, undo_interleave].ravel()
q_zp = q_zp.reshape((-1, size_n)).contiguous()
marlin_zp = marlin_zero_points(q_zp, size_k, size_n, num_bits)
return marlin_zp
def moe_awq_to_marlin_zero_points(
q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int
):
num_experts = q_zp_packed.shape[0]
output = torch.empty(
(num_experts, q_zp_packed.shape[1], q_zp_packed.shape[2]),
device=q_zp_packed.device,
dtype=q_zp_packed.dtype,
)
for e in range(num_experts):
output[e] = awq_to_marlin_zero_points(q_zp_packed[e], size_k, size_n, num_bits)
return output
def maybe_warn_marlin_atomic_add(device, dtype):
if torch.compiler.is_dynamo_compiling():
return
device_capability = torch.cuda.get_device_capability(device)
if device_capability[0] < 9 and dtype == torch.bfloat16:
logger.info_once(
"You are running Marlin kernel with bf16 on GPUs before SM90. "
"You can consider change to fp16 to achieve better performance "
"if possible."
)
def maybe_warn_marlin_atomic_add_env():
if torch.compiler.is_dynamo_compiling():
return
# TODO(yiyun): Need to add sglang's MARLIN_USE_ATOMIC_ADD: bool = False
if True:
return
# if envs.VLLM_MARLIN_USE_ATOMIC_ADD:
# return
logger.info_once(
"Marlin kernel can achieve better performance for small size_n "
"with experimental use_atomic_add feature. "
"You can consider set environment variable "
"VLLM_MARLIN_USE_ATOMIC_ADD to 1 if possible."
)
def should_use_atomic_add_reduce(
m: int, n: int, k: int, device: torch.device, dtype: torch.dtype
) -> bool:
# the performance of atomicAdd is better than global reduce
# only when m*n is small and k is large
if n >= 2048 or k < 2048 or device.type != "cuda":
return False
# disable atomicAdd reduce by default,
# one can enable it with VLLM_MARLIN_USE_ATOMIC_ADD=1
# TODO: Need to add sglang's MARLIN_USE_ATOMIC_ADD: bool = False
if not True:
maybe_warn_marlin_atomic_add_env()
return False
# sm8x doesn't support atomicAdd + bfloat16 natively
device_capability = torch.cuda.get_device_capability(device)
if device_capability[0] < 9 and dtype == torch.bfloat16:
maybe_warn_marlin_atomic_add(device, dtype)
return False
return True
def apply_gptq_marlin_linear(
input: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
weight_zp: torch.Tensor,
g_idx: torch.Tensor,
g_idx_sort_indices: torch.Tensor,
workspace: torch.Tensor,
wtype: ScalarType,
output_size_per_partition: int,
input_size_per_partition: int,
is_k_full: bool,
bias: Optional[torch.Tensor] = None,
use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT,
) -> torch.Tensor:
reshaped_x = input.reshape(-1, input.shape[-1])
out_shape = input.shape[:-1] + (output_size_per_partition,)
use_atomic_add = should_use_atomic_add_reduce(
m=reshaped_x.size(0),
n=output_size_per_partition,
k=reshaped_x.size(1),
device=input.device,
dtype=input.dtype,
)
output = ops.gptq_marlin_gemm(
reshaped_x,
None,
weight,
weight_scale,
None,
weight_zp,
g_idx,
g_idx_sort_indices,
workspace,
wtype,
size_m=reshaped_x.shape[0],
size_n=output_size_per_partition,
size_k=input_size_per_partition,
is_k_full=is_k_full,
use_atomic_add=use_atomic_add,
use_fp32_reduce=use_fp32_reduce,
is_zp_float=False,
)
if bias is not None:
output.add_(bias) # In-place add
return output.reshape(out_shape)
def apply_awq_marlin_linear(
input: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
weight_zp: torch.Tensor,
g_idx: torch.Tensor,
g_idx_sort_indices: torch.Tensor,
workspace: torch.Tensor,
quant_type: ScalarType,
output_size_per_partition: int,
input_size_per_partition: int,
bias: Optional[torch.Tensor] = None,
use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT,
) -> torch.Tensor:
reshaped_x = input.reshape(-1, input.shape[-1])
out_shape = input.shape[:-1] + (output_size_per_partition,)
use_atomic_add = should_use_atomic_add_reduce(
m=reshaped_x.size(0),
n=output_size_per_partition,
k=reshaped_x.size(1),
device=input.device,
dtype=input.dtype,
)
output = ops.gptq_marlin_gemm(
reshaped_x,
None,
weight,
weight_scale,
None,
weight_zp,
g_idx,
g_idx_sort_indices,
workspace,
quant_type,
size_m=reshaped_x.shape[0],
size_n=output_size_per_partition,
size_k=input_size_per_partition,
use_atomic_add=use_atomic_add,
use_fp32_reduce=use_fp32_reduce,
is_zp_float=False,
)
if bias is not None:
output.add_(bias) # In-place add
return output.reshape(out_shape)
class MarlinConfig(QuantizationConfig):
"""Config class for Marlin.
Reference: https://github.com/IST-DASLab/marlin/tree/master
"""
def __init__(
self,
group_size: int,
lm_head_quantized: bool,
) -> None:
super().__init__()
# Group size for the quantization.
self.group_size = group_size
self.lm_head_quantized = lm_head_quantized
if self.group_size != 128 and self.group_size != -1:
raise ValueError(
"Currently, only group size 128 and -1 (channelwise) "
"is supported for Marlin, but got group_size of "
f"{self.group_size}"
)
# 4 Bits packed into 32 bit datatype.
self.pack_factor = 32 // 4
# Tile size used by marlin kernels.
self.tile_size = 16
# Min out_features dim
self.min_n_threads = 64
# Min in_features dim
self.min_k_threads = 128
# Max parallel problems to solve at once (improves large
# batch performance)
self.max_parallel = 16
# Permutation length used by the marlin kernels.
self.perm_len = 1024
def __repr__(self) -> str:
return (
f"MarlinConfig(group_size={self.group_size}, "
f"lm_head_quantized={self.lm_head_quantized})"
)
@classmethod
def get_name(cls) -> str:
return "marlin"
@classmethod
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
return [torch.half]
@classmethod
# Need to figure it out
def get_min_capability(cls) -> int:
return 80
@classmethod
def get_config_filenames(cls) -> list[str]:
return ["quantize_config.json"]
@classmethod
def from_config(cls, config: dict[str, Any]) -> "MarlinConfig":
group_size = cls.get_from_keys(config, ["group_size"])
lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], default=False)
return cls(group_size, lm_head_quantized)
@classmethod
def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]:
# compat: autogptq >=0.8.0 use checkpoint_format: str
# compat: autogptq <=0.7.1 is_marlin_format: bool
is_marlin_format = hf_quant_cfg.get(
"checkpoint_format"
) == "marlin" or hf_quant_cfg.get("is_marlin_format", False)
is_valid_user_quant = (
user_quant is None or user_quant == "gptq" or user_quant == "marlin"
)
if is_marlin_format and is_valid_user_quant:
msg = "The model is serialized in {} format. Using {} kernel.".format(
cls.get_name(), cls.get_name()
)
logger.info(msg)
return cls.get_name()
return None
def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> Optional["MarlinLinearMethod"]:
if isinstance(layer, LinearBase) or (
isinstance(layer, ParallelLMHead) and self.lm_head_quantized
):
return MarlinLinearMethod(self)
return None
class MarlinLinearMethod(LinearMethodBase):
"""Linear method for Marlin.
Args:
quant_config: The Marlin quantization config.
"""
def __init__(self, quant_config: MarlinConfig):
self.quant_config = quant_config
def create_weights(
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: list[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
del output_size # Unused.
weight_loader = extra_weight_attrs["weight_loader"]
if params_dtype != torch.float16:
raise ValueError(
f"The params dtype must be float16, but got {params_dtype}"
)
# Validate output_size_per_partition
output_size_per_partition = sum(output_partition_sizes)
if output_size_per_partition % self.quant_config.min_n_threads != 0:
raise ValueError(
f"Weight output_size_per_partition = "
f"{output_size_per_partition} is not divisible by "
f"min_n_threads = {self.quant_config.min_n_threads}."
)
if output_size_per_partition % self.quant_config.pack_factor != 0:
raise ValueError(
f"Weight output_size_per_partition = "
f"{output_size_per_partition} is not divisible by "
f"pack_factor = {self.quant_config.pack_factor}."
)
# Validate input_size_per_partition
if input_size_per_partition % self.quant_config.min_k_threads != 0:
raise ValueError(
f"Weight input_size_per_partition = "
f"{input_size_per_partition} is not divisible by "
f"min_k_threads = {self.quant_config.min_k_threads}."
)
if (
self.quant_config.group_size != -1
and input_size_per_partition % self.quant_config.group_size != 0
):
raise ValueError(
f"Weight input_size_per_partition = "
f"{input_size_per_partition} is not divisible by "
f"group_size = {self.quant_config.group_size}."
)
# Check that we have at least 4 tiles horizontally in the shard
num_tiles_per_perm = self.quant_config.perm_len // (
self.quant_config.tile_size**2
)
if output_size_per_partition % num_tiles_per_perm != 0:
raise ValueError("Each permutation group must reside on the same gpu")
# Quantized 4Bit weights packed into Int32.
qweight = PackedvLLMParameter(
data=torch.empty(
input_size_per_partition // self.quant_config.tile_size,
output_size_per_partition
* self.quant_config.tile_size
// self.quant_config.pack_factor,
device="cuda",
dtype=torch.int32,
),
input_dim=0,
output_dim=1,
packed_dim=1,
packed_factor=self.quant_config.pack_factor,
marlin_tile_size=self.quant_config.tile_size,
weight_loader=weight_loader,
)
# Determine if channelwise or not
input_groups = (
1
if self.quant_config.group_size == -1
else input_size_per_partition // self.quant_config.group_size
)
weight_scale_args = {
"data": torch.empty(
input_groups,
output_size_per_partition,
device="cuda",
dtype=params_dtype,
),
"weight_loader": weight_loader,
}
if input_groups == 1:
scales = ChannelQuantScaleParameter(output_dim=1, **weight_scale_args)
else:
scales = GroupQuantScaleParameter(
output_dim=1, input_dim=0, **weight_scale_args
)
# Allocate workspace (Used for internal locking mechanism)
max_workspace_size = (
output_size_per_partition // self.quant_config.min_n_threads
) * self.quant_config.max_parallel
workspace = BasevLLMParameter(
data=torch.zeros(max_workspace_size, device="cuda", dtype=torch.int),
weight_loader=weight_loader,
)
layer.register_parameter("B", qweight)
layer.register_parameter("s", scales)
layer.register_parameter("workspace", workspace)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# required by torch.compile
layer.B = torch.nn.Parameter(layer.B.data, requires_grad=False)
layer.s = torch.nn.Parameter(layer.s.data, requires_grad=False)
layer.workspace = torch.nn.Parameter(layer.workspace.data, requires_grad=False)
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
qweight = layer.B
scales = layer.s
workspace = layer.workspace
x_2d = x.view(-1, x.shape[-1])
size_m = x_2d.shape[0]
size_k = x_2d.shape[1]
size_n = scales.shape[1]
output_2d = ops.marlin_gemm(
x_2d, qweight, scales, workspace, size_m, size_n, size_k
)
output = output_2d.view(x.shape[:-1] + (output_2d.shape[1],))
if bias is not None:
output.add_(bias) # In-place add
return output
......@@ -19,6 +19,36 @@ from sglang.srt.utils import get_device_capability, set_weight_attrs
logger = logging.getLogger(__name__)
def get_weight_perm(num_bits: int):
perm_list: List[int] = []
for i in range(32):
perm1: List[int] = []
col = i // 4
for block in [0, 1]:
for row in [
2 * (i % 4),
2 * (i % 4) + 1,
2 * (i % 4 + 4),
2 * (i % 4 + 4) + 1,
]:
perm1.append(16 * row + col + 8 * block)
for j in range(4):
perm_list.extend([p + 256 * j for p in perm1])
perm = np.array(perm_list)
if num_bits == 4:
interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7])
elif num_bits == 8:
interleave = np.array([0, 2, 1, 3])
else:
raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))
perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel()
perm = torch.from_numpy(perm)
return perm
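# Illustrative sketch: the generated 4-bit weight permutation is expected to be
# a permutation of 0..1023, i.e. one Marlin tile worth of packed weight
# positions; the check below is an example, not part of the quantization path.
def _example_get_weight_perm():
    perm = get_weight_perm(num_bits=4)
    assert perm.shape == (1024,)
    assert torch.equal(perm.sort().values, torch.arange(1024))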
class MoeWNA16Config(QuantizationConfig):
"""Config class for MOE WNA16 (W8A16/W4A16) quantization."""
......
# SPDX-License-Identifier: Apache-2.0
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/quant_utils.py
from typing import Optional
import numpy
import torch
from sgl_kernel.scalar_type import ScalarType
def get_pack_factor(num_bits):
assert 32 % num_bits == 0, f"Unsupported num_bits = {num_bits}"
return 32 // num_bits
def pack_cols(
q_w: torch.Tensor,
num_bits: int,
size_k: int,
size_n: int,
):
assert q_w.shape == (size_k, size_n)
pack_factor = get_pack_factor(num_bits)
assert size_n % pack_factor == 0
orig_device = q_w.device
q_w = q_w.cpu().numpy().astype(numpy.uint32)
q_res = numpy.zeros((size_k, size_n // pack_factor), dtype=numpy.uint32)
for i in range(pack_factor):
q_res |= q_w[:, i::pack_factor] << num_bits * i
q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
q_res = q_res.contiguous()
return q_res
def unpack_cols(
packed_q_w: torch.Tensor,
num_bits: int,
size_k: int,
size_n: int,
):
pack_factor = get_pack_factor(num_bits)
assert size_n % pack_factor == 0
assert packed_q_w.shape == (
size_k,
size_n // pack_factor,
), "packed_q_w.shape = {} size_k = {}, size_n = {} pack_Factor = {}".format(
packed_q_w.shape, size_k, size_n, pack_factor
)
orig_device = packed_q_w.device
packed_q_w_cpu = packed_q_w.cpu().numpy().astype(numpy.uint32)
q_res = numpy.zeros((size_k, size_n), dtype=numpy.uint32)
mask = (1 << num_bits) - 1
for i in range(pack_factor):
vals = packed_q_w_cpu & mask
packed_q_w_cpu >>= num_bits
q_res[:, i::pack_factor] = vals
q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
q_res = q_res.contiguous()
return q_res
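# Illustrative sketch (hypothetical sizes): packing 4-bit values eight per int32
# along the column dimension and unpacking them reproduces the original tensor.
def _example_pack_unpack_roundtrip():
    size_k, size_n, num_bits = 16, 32, 4
    q_w = torch.randint(0, 2**num_bits, (size_k, size_n), dtype=torch.int32)
    packed = pack_cols(q_w, num_bits, size_k, size_n)
    assert packed.shape == (size_k, size_n // get_pack_factor(num_bits))
    assert torch.equal(unpack_cols(packed, num_bits, size_k, size_n), q_w)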
def quantize_weights(
w: torch.Tensor,
quant_type: ScalarType,
group_size: Optional[int],
zero_points: bool = False,
ref_zero_points_after_scales: bool = False,
):
assert (
quant_type.is_integer()
), "Floating point quantization may work but has not been tested"
assert not zero_points or group_size is not None, (
"to have group zero points, group_size must be provided "
"(-1 group_size is channelwise)"
)
orig_device = w.device
orig_type = w.dtype
size_k, size_n = w.shape
assert w.is_floating_point(), "w must be float"
if group_size == -1:
group_size = size_k
# Reshape to [groupsize, -1]
if group_size is not None and group_size < size_k:
w = w.reshape((-1, group_size, size_n))
w = w.permute(1, 0, 2)
w = w.reshape((group_size, -1))
# Compute scale for each group
max_val = torch.max(w, 0, keepdim=True).values
min_val = torch.min(w, 0, keepdim=True).values
max_q_val = quant_type.max()
min_q_val = quant_type.min()
w_s = torch.Tensor([1.0]).to(w.device) # unscaled case
maybe_w_zp = None
if group_size is not None:
if zero_points:
assert not quant_type.is_signed() and quant_type.max() > 0
w_s = (max_val - min_val).clamp(min=1e-5) / quant_type.max()
maybe_w_zp = (
torch.round(torch.abs(min_val / w_s)).clamp(min_q_val, max_q_val).int()
)
else:
# If the bias is such that there are no possible negative/positive
# values, set the max value to inf to avoid divide by 0
w_s = torch.max(
abs(max_val / (max_q_val if max_q_val != 0 else torch.inf)),
abs(min_val / (min_q_val if min_q_val != 0 else torch.inf)),
)
# Quantize
w_q = torch.round(w / w_s).int() + (maybe_w_zp if zero_points else 0)
w_q = torch.clamp(w_q, min_q_val, max_q_val)
# Compute ref (dequantized)
# For some kernels (namely Machete) the zero-points are applied after the
# scales are applied, for this case computing the reference in similar way
# allows us to use tighter error tolerances in our unit tests.
if ref_zero_points_after_scales and maybe_w_zp is not None:
w_ref = w_q.to(orig_type) * w_s - maybe_w_zp.to(orig_type) * w_s
else:
w_ref = (w_q - (maybe_w_zp if zero_points else 0)).to(orig_type) * w_s
if quant_type.has_bias():
w_q += quant_type.bias
# Restore original shapes
if group_size is not None and group_size < size_k:
def reshape_w(w):
w = w.reshape((group_size, -1, size_n))
w = w.permute(1, 0, 2)
w = w.reshape((size_k, size_n)).contiguous()
return w
w_q = reshape_w(w_q)
w_ref = reshape_w(w_ref)
w_s = w_s.reshape((-1, size_n)).contiguous()
if maybe_w_zp is not None:
maybe_w_zp = maybe_w_zp.reshape((-1, size_n)).contiguous()
maybe_w_zp = maybe_w_zp.to(device=orig_device)
return (
w_ref.to(device=orig_device),
w_q.to(device=orig_device),
w_s if group_size is not None else None,
maybe_w_zp,
)
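# Illustrative sketch, assuming sgl_kernel exposes the standard GPTQ scalar
# types: group-quantize a small fp16 weight to uint4b8 and check that the
# dequantized reference stays close to the original (loose tolerance).
def _example_quantize_weights():
    from sgl_kernel.scalar_type import scalar_types  # assumed available

    w = torch.randn(128, 64, dtype=torch.float16)
    w_ref, w_q, w_s, w_zp = quantize_weights(w, scalar_types.uint4b8, group_size=32)
    assert w_q.shape == w.shape and w_s.shape == (128 // 32, 64) and w_zp is None
    assert torch.mean(torch.abs(w_ref.float() - w.float())) < 0.2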
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/quant_utils.py
from types import MappingProxyType
from typing import List, Mapping, Tuple, Union
from typing import List, Mapping, Optional, Tuple, Union
import numpy
import torch
from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
from sglang.srt.layers.quantization.scalar_type import ScalarType
from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_npu
_is_cuda = is_cuda()
......@@ -143,3 +145,162 @@ def replace_parameter(
if not isinstance(new, torch.nn.Parameter):
new = torch.nn.Parameter(new, requires_grad=False)
mod.register_parameter(name, torch.nn.Parameter(new, requires_grad=False))
def get_pack_factor(num_bits):
assert 32 % num_bits == 0, f"Unsupported num_bits = {num_bits}"
return 32 // num_bits
def pack_cols(
q_w: torch.Tensor,
num_bits: int,
size_k: int,
size_n: int,
):
assert q_w.shape == (size_k, size_n)
pack_factor = get_pack_factor(num_bits)
assert size_n % pack_factor == 0
orig_device = q_w.device
q_w = q_w.cpu().numpy().astype(numpy.uint32)
q_res = numpy.zeros((size_k, size_n // pack_factor), dtype=numpy.uint32)
for i in range(pack_factor):
q_res |= q_w[:, i::pack_factor] << num_bits * i
q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
q_res = q_res.contiguous()
return q_res
def unpack_cols(
packed_q_w: torch.Tensor,
num_bits: int,
size_k: int,
size_n: int,
):
pack_factor = get_pack_factor(num_bits)
assert size_n % pack_factor == 0
assert packed_q_w.shape == (
size_k,
size_n // pack_factor,
), "packed_q_w.shape = {} size_k = {}, size_n = {} pack_Factor = {}".format(
packed_q_w.shape, size_k, size_n, pack_factor
)
orig_device = packed_q_w.device
packed_q_w_cpu = packed_q_w.cpu().numpy().astype(numpy.uint32)
q_res = numpy.zeros((size_k, size_n), dtype=numpy.uint32)
mask = (1 << num_bits) - 1
for i in range(pack_factor):
vals = packed_q_w_cpu & mask
packed_q_w_cpu >>= num_bits
q_res[:, i::pack_factor] = vals
q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
q_res = q_res.contiguous()
return q_res
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/quant_utils.py
def quantize_weights(
w: torch.Tensor,
quant_type: ScalarType,
group_size: Optional[int],
zero_points: bool = False,
ref_zero_points_after_scales: bool = False,
):
assert (
quant_type.is_integer()
), "Floating point quantization may work but has not been tested"
assert not zero_points or group_size is not None, (
"to have group zero points, group_size must be provided "
"(-1 group_size is channelwise)"
)
orig_device = w.device
orig_type = w.dtype
size_k, size_n = w.shape
assert w.is_floating_point(), "w must be float"
if group_size == -1:
group_size = size_k
# Reshape to [groupsize, -1]
if group_size is not None and group_size < size_k:
w = w.reshape((-1, group_size, size_n))
w = w.permute(1, 0, 2)
w = w.reshape((group_size, -1))
# Compute scale for each group
max_val = torch.max(w, 0, keepdim=True).values
min_val = torch.min(w, 0, keepdim=True).values
max_q_val = quant_type.max()
min_q_val = quant_type.min()
w_s = torch.Tensor([1.0]).to(w.device) # unscaled case
maybe_w_zp = None
if group_size is not None:
if zero_points:
assert not quant_type.is_signed() and quant_type.max() > 0
w_s = (max_val - min_val).clamp(min=1e-5) / quant_type.max()
maybe_w_zp = (
torch.round(torch.abs(min_val / w_s)).clamp(min_q_val, max_q_val).int()
)
else:
# If the bias is such that there are no possible negative/positive
# values, set the max value to inf to avoid divide by 0
w_s = torch.max(
abs(max_val / (max_q_val if max_q_val != 0 else torch.inf)),
abs(min_val / (min_q_val if min_q_val != 0 else torch.inf)),
)
# Quantize
w_q = torch.round(w / w_s).int() + (maybe_w_zp if zero_points else 0)
w_q = torch.clamp(w_q, min_q_val, max_q_val)
# Compute ref (dequantized)
# For some kernels (namely Machete) the zero-points are applied after the
# scales are applied, for this case computing the reference in similar way
# allows us to use tighter error tolerances in our unit tests.
if ref_zero_points_after_scales and maybe_w_zp is not None:
w_ref = w_q.to(orig_type) * w_s - maybe_w_zp.to(orig_type) * w_s
else:
w_ref = (w_q - (maybe_w_zp if zero_points else 0)).to(orig_type) * w_s
if quant_type.has_bias():
w_q += quant_type.bias
# Restore original shapes
if group_size is not None and group_size < size_k:
def reshape_w(w):
w = w.reshape((group_size, -1, size_n))
w = w.permute(1, 0, 2)
w = w.reshape((size_k, size_n)).contiguous()
return w
w_q = reshape_w(w_q)
w_ref = reshape_w(w_ref)
w_s = w_s.reshape((-1, size_n)).contiguous()
if maybe_w_zp is not None:
maybe_w_zp = maybe_w_zp.reshape((-1, size_n)).contiguous()
maybe_w_zp = maybe_w_zp.to(device=orig_device)
return (
w_ref.to(device=orig_device),
w_q.to(device=orig_device),
w_s if group_size is not None else None,
maybe_w_zp,
)
......@@ -2,10 +2,11 @@ import functools
from typing import Optional
import torch
from sgl_kernel.scalar_type import scalar_types
def get_scalar_type(num_bits: int, has_zp: bool):
from sglang.srt.layers.quantization.scalar_type import scalar_types
if has_zp:
assert num_bits == 4
return scalar_types.uint4
......
import math
import numpy as np
import pytest
import torch
from sgl_kernel import awq_marlin_repack
from sgl_kernel.scalar_type import scalar_types
from sglang.srt.layers.quantization.quant_utils import (
from sglang.srt.layers.quantization.scalar_type import scalar_types
from sglang.srt.layers.quantization.utils import (
get_pack_factor,
pack_cols,
quantize_weights,
......
......@@ -51,13 +51,12 @@ def check_quant_method(model_path: str, use_marlin_kernel: bool):
model_config=model_config, load_config=load_config, device_config=device_config
)
from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
from vllm.model_executor.layers.quantization.gptq_marlin import (
from sglang.srt.layers.linear import UnquantizedLinearMethod
from sglang.srt.layers.quantization.gptq import (
GPTQLinearMethod,
GPTQMarlinLinearMethod,
)
from sglang.srt.layers.linear import UnquantizedLinearMethod
linear_method_cls = (
GPTQMarlinLinearMethod if use_marlin_kernel else (GPTQLinearMethod)
)
......@@ -162,7 +161,7 @@ class TestGPTQModelDynamicWithMarlin(CustomTestCase):
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=["--dtype", "float16"],
other_args=["--dtype", "bfloat16"],
)
@classmethod
......
import itertools
import sys
import unittest
import torch
sys.path.insert(0, "/home/hadoop-hmart-waimai-rank/vllm")
# from sglang.srt.layers.moe.topk import select_experts
from sgl_kernel import fused_marlin_moe
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
# from vllm.model_executor.layers. import select_experts
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
marlin_quantize,
)
from vllm.scalar_type import scalar_types
def stack_and_dev(tensors: list[torch.Tensor]):
dev = tensors[0].device
return torch.stack(tensors, dim=0).to(dev)
def torch_moe(a, w1, w2, score, topk, expert_map):
B, D = a.shape
a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
score = torch.softmax(score, dim=-1, dtype=torch.float32)
topk_weight, topk_ids = torch.topk(score, topk)
topk_weight = topk_weight.view(-1)
topk_ids = topk_ids.view(-1)
if expert_map is not None:
topk_ids = expert_map[topk_ids]
for i in range(w1.shape[0]):
mask = topk_ids == i
if mask.sum():
out[mask] = SiluAndMul()(a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(
0, 1
)
return (
out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype)
).sum(dim=1)
def native_w8a8_per_token_matmul(A, B, As, Bs, output_dtype=torch.float16):
"""Matrix multiplication function that supports per-token input quantization and per-column weight quantization"""
A = A.to(torch.float32)
B = B.to(torch.float32)
assert A.shape[-1] == B.shape[-1], "Dimension mismatch"
assert B.ndim == 2 and B.is_contiguous(), "B must be a 2D contiguous tensor"
# Reshape input
M = A.numel() // A.shape[-1]
B = B.t() # Transpose weight matrix
N, K = B.shape
origin_C_shape = A.shape[:-1] + (K,)
A = A.reshape(M, N)
# As is per-token [M, 1], Bs is per-column [1, K]
C = torch.matmul(A, B) # [M, K]
C = As * C * Bs.view(1, -1) # Broadcast per-column scale
return C.reshape(origin_C_shape).to(output_dtype)
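# Illustrative sketch: with all scales set to one, the per-token w8a8 reference
# matmul above should reduce to a plain float matmul (hypothetical sizes).
def _example_native_w8a8_identity_scales():
    A = torch.randint(-128, 127, (4, 8), dtype=torch.int8)
    B = torch.randint(-128, 127, (16, 8), dtype=torch.int8)  # [out_features, K]
    As = torch.ones(4, 1)  # per-token scales
    Bs = torch.ones(1, 16)  # per-column scales
    out = native_w8a8_per_token_matmul(A, B, As, Bs, output_dtype=torch.float32)
    assert torch.allclose(out, A.float() @ B.float().t())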
def torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk):
"""This function performs fused moe with per-column int8 quantization using native torch."""
B, D = a.shape
# Perform per-token quantization
a_q, a_s = per_token_quant_int8(a)
# Repeat tokens to match topk
a_q = a_q.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
# Also repeat the scale
a_s = a_s.view(B, -1, 1).repeat(1, topk, 1).reshape(-1, 1) # [B*topk, 1]
out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
# Calculate routing
score = torch.softmax(score, dim=-1, dtype=torch.float32)
topk_weight, topk_ids = torch.topk(score, topk)
topk_weight = topk_weight.view(-1)
topk_ids = topk_ids.view(-1)
# Process each expert
for i in range(w1.shape[0]):
mask = topk_ids == i
if mask.sum():
# First MLP layer: note that a_s is now per-token
inter_out = native_w8a8_per_token_matmul(
a_q[mask], w1[i], a_s[mask], w1_s[i], output_dtype=a.dtype
)
# Activation function
act_out = SiluAndMul().forward_native(inter_out)
# Quantize activation output with per-token
act_out_q, act_out_s = per_token_quant_int8(act_out)
# Second MLP layer
out[mask] = native_w8a8_per_token_matmul(
act_out_q, w2[i], act_out_s, w2_s[i], output_dtype=a.dtype
)
# Apply routing weights and sum
return (
out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype)
).sum(dim=1)
def marlin_fused_moe(
N, E, K, a, w1, w2, num_bits, group_size, act_order, score, topk, ep_size
):
quant_type = scalar_types.uint4b8 if num_bits == 4 else scalar_types.uint8b128
if ep_size > 1:
local_e = E // ep_size
e_ids = torch.randperm(E, device="cuda", dtype=torch.int32)[:local_e]
e_map = torch.full((E,), -1, device="cuda", dtype=torch.int32)
e_map[e_ids] = torch.arange(local_e, device="cuda", dtype=torch.int32)
w1 = w1[e_ids]
w2 = w2[e_ids]
else:
e_map = None
w_ref1_l = []
qweight1_l = []
scales1_l = []
zeros1_l = []
g_idx1_l = []
sort_indices1_l = []
s1_l = []
for i in range(w1.shape[0]):
test_perm = torch.randperm(n=K)
quant_res = marlin_quantize(
w1[i].transpose(1, 0), quant_type, group_size, act_order, test_perm
)
w_ref1, qweight1, scales1, g_idx1, sort_indices1, _ = quant_res
w_ref1_l.append(w_ref1.T)
qweight1_l.append(qweight1)
scales1_l.append(scales1)
g_idx1_l.append(g_idx1)
sort_indices1_l.append(sort_indices1)
w_ref1 = stack_and_dev(w_ref1_l)
qweight1 = stack_and_dev(qweight1_l).contiguous()
scales1 = stack_and_dev(scales1_l)
g_idx1 = stack_and_dev(g_idx1_l) if g_idx1_l else None
zeros1 = stack_and_dev(zeros1_l) if zeros1_l else None
sort_indices1 = stack_and_dev(sort_indices1_l) if sort_indices1_l else None
w_ref2_l = []
qweight2_l = []
scales2_l = []
zeros2_l = []
g_idx2_l = []
sort_indices2_l = []
for i in range(w2.shape[0]):
test_perm = torch.randperm(n=N)
quant_res = marlin_quantize(
w2[i].transpose(1, 0), quant_type, group_size, act_order, test_perm
)
w_ref2, qweight2, scales2, g_idx2, sort_indices2, _ = quant_res
w_ref2_l.append(w_ref2.T)
qweight2_l.append(qweight2)
scales2_l.append(scales2)
g_idx2_l.append(g_idx2)
sort_indices2_l.append(sort_indices2)
w_ref2 = stack_and_dev(w_ref2_l)
qweight2 = stack_and_dev(qweight2_l).contiguous()
scales2 = stack_and_dev(scales2_l)
g_idx2 = stack_and_dev(g_idx2_l) if g_idx2_l else None
zeros2 = stack_and_dev(zeros2_l) if zeros2_l else None
sort_indices2 = stack_and_dev(sort_indices2_l) if sort_indices2_l else None
topk_weights, topk_ids = fused_topk(a, score, topk, False)
# topk_weights, topk_ids = FusedMoE.select_experts(
# hidden_states=a,
# router_logits=score,
# top_k=topk,
# num_expert_group=E,
# use_grouped_topk=False,
# renormalize=False,
# topk_group=None,
# )
torch_output = torch_moe(a, w_ref1, w_ref2, score, topk, e_map)
marlin_output = fused_marlin_moe(
a,
qweight1,
qweight2,
scales1,
scales2,
score,
topk_weights,
topk_ids,
global_num_experts=E,
expert_map=e_map,
g_idx1=g_idx1,
g_idx2=g_idx2,
sort_indices1=sort_indices1,
sort_indices2=sort_indices2,
w1_zeros=zeros1,
w2_zeros=zeros2,
num_bits=num_bits,
is_k_full=True,
)
return marlin_output, torch_output
class TestW8A8Int8FusedMoE(unittest.TestCase):
DTYPES = [torch.float16]
M = [1, 16]
N = [128]
K = [256]
E = [4, 10]
TOP_KS = [2, 4]
BLOCK_SIZE = [[128, 128]]
SEEDS = [0]
NUM_BITS = [4]
EP_SIZE = [1, 4]
@classmethod
def setUpClass(cls):
if not torch.cuda.is_available():
raise unittest.SkipTest("CUDA is not available")
torch.set_default_device("cuda")
def _w4a8_int8_fused_moe(
self, M, N, K, E, topk, block_size, dtype, seed, num_bits, ep_size
):
torch.manual_seed(seed)
a = torch.randn((M, K), dtype=dtype) / 10
# Generate fp16 weights to be quantized
w1_fp16 = (torch.rand((E, 2 * N, K), dtype=dtype) - 0.5) * 2
w2_fp16 = (torch.rand((E, K, N), dtype=dtype) - 0.5) * 2
score = torch.randn((M, E), dtype=dtype)
with torch.inference_mode():
marlin_out, ref_out = marlin_fused_moe(
N=N,
E=E,
K=K,
a=a,
w1=w1_fp16,
w2=w2_fp16,
num_bits=num_bits,
group_size=-1,
act_order=False,
score=score,
topk=topk,
ep_size=ep_size,
)
# Check results
if (
torch.mean(
torch.abs(marlin_out.to(torch.float32) - ref_out.to(torch.float32))
)
/ torch.mean(torch.abs(ref_out.to(torch.float32)))
> 0.1
):
print(f"marlin_out: {marlin_out}")
print(f"ref_out: {ref_out}")
print(
torch.mean(
torch.abs(marlin_out.to(torch.float32) - ref_out.to(torch.float32))
)
/ torch.mean(torch.abs(ref_out.to(torch.float32)))
)
torch.testing.assert_close(marlin_out, ref_out, atol=2e-2, rtol=0)
def test_w4a8_int8_fused_moe(self):
for params in itertools.product(
self.M,
self.N,
self.K,
self.E,
self.TOP_KS,
self.BLOCK_SIZE,
self.DTYPES,
self.SEEDS,
self.NUM_BITS,
self.EP_SIZE,
):
with self.subTest(
M=params[0],
N=params[1],
K=params[2],
E=params[3],
topk=params[4],
block_size=params[5],
dtype=params[6],
seed=params[7],
num_bits=params[8],
ep_size=params[9],
):
self._w4a8_int8_fused_moe(*params)
if __name__ == "__main__":
unittest.main(verbosity=2)
import sgl_kernel
import torch
x = torch.randn(10, 10, device="cuda")
qweight = torch.randn(10, 10, device="cuda")
s1_scales = torch.randn(10, device="cuda")
input_scales = torch.randn(10, device="cuda")
s1_szeros = torch.randn(10, device="cuda")
input_sum = torch.randn(10, device="cuda")
output_buffer = torch.randn(10, device="cuda")
torch.ops.sgl_kernel.gemm_forward_cuda.default(
x, qweight, s1_scales, input_scales, s1_szeros, input_sum, output_buffer
)