Unverified Commit c28ad199 authored by Peng Zhang, committed by GitHub

[1/n] chore: decouple quantization implementation from vLLM dependency (#7992)

parent 570d3343
from contextlib import contextmanager
from typing import Any, Dict, Optional

import sglang.srt.layers.moe.fused_moe_triton.fused_moe  # noqa
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
    fused_experts,
    get_config_file_name,
    moe_align_block_size,
    try_get_optimal_moe_config,
)
from sglang.srt.layers.moe.fused_moe_triton.layer import (
    FusedMoE,
@@ -37,4 +38,6 @@ __all__ = [
    "fused_moe",
    "fused_experts",
    "get_config_file_name",
    "moe_align_block_size",
    "try_get_optimal_moe_config",
]
@@ -22,10 +22,6 @@ try:
    from vllm.model_executor.layers.quantization.experts_int8 import ExpertsInt8Config
    from vllm.model_executor.layers.quantization.fbgemm_fp8 import FBGEMMFp8Config
    from vllm.model_executor.layers.quantization.gguf import GGUFConfig
    # (Removed in this commit: the vLLM imports of GPTQLinearMethod and
    # GPTQMarlinLinearMethod from this try block.)
    from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
        GPTQMarlin24Config,
    )
@@ -59,7 +55,9 @@ from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import
from sglang.srt.layers.quantization.fp8 import Fp8Config
from sglang.srt.layers.quantization.gptq import (
    GPTQConfig,
    GPTQLinearMethod,
    GPTQMarlinConfig,
    GPTQMarlinLinearMethod,
    GPTQMarlinMoEMethod,
)
from sglang.srt.layers.quantization.modelopt_quant import (
...
import logging
from dataclasses import dataclass
from fractions import Fraction
from typing import Any, Callable, Dict, List, Optional, Union

import torch

from sglang.srt.layers.linear import LinearBase, LinearMethodBase, set_weight_attrs
from sglang.srt.layers.parameter import (
    BasevLLMParameter,
    ChannelQuantScaleParameter,
    GroupQuantScaleParameter,
    PackedColumnParameter,
    PackedvLLMParameter,
    RowvLLMParameter,
    permute_param_layout_,
)
from sglang.srt.layers.quantization.base_config import (
    QuantizationConfig,
    QuantizeMethodBase,
)
from sglang.srt.layers.quantization.marlin_utils import (
    apply_gptq_marlin_linear,
    check_marlin_supported,
    check_marlin_supports_shape,
    marlin_is_k_full,
    marlin_make_empty_g_idx,
    marlin_make_workspace,
    marlin_moe_permute_scales,
    marlin_permute_scales,
    marlin_repeat_scales_on_all_ranks,
    marlin_sort_g_idx,
    marlin_zero_points,
    verify_marlin_supported,
)
from sglang.srt.layers.quantization.scalar_type import ScalarType, scalar_types
from sglang.srt.layers.quantization.utils import replace_parameter, unpack_cols

# (Removed in this commit: the vLLM imports of GPTQLinearMethod, GPTQMarlinLinearMethod,
# MarlinLinearMethod, the FusedMoE* helpers and scalar_types, together with the
# VLLM_AVAILABLE flag and its ImportError fallbacks.)
try:
    from vllm import _custom_ops as ops
except ImportError:
    ops = None

from sglang.srt.utils import is_cuda

_is_cuda = is_cuda()
if _is_cuda:
    from sgl_kernel import fused_marlin_moe

FusedMoEMethodBase = QuantizeMethodBase

logger = logging.getLogger(__name__)

@@ -54,6 +62,38 @@ def check_marlin_format(hf_quant_cfg: Dict[str, Any]) -> bool:
    )
def gptq_marlin_moe_repack(
b_q_weight: torch.Tensor,
perm: torch.Tensor,
size_k: int,
size_n: int,
num_bits: int,
) -> torch.Tensor:
num_experts = b_q_weight.shape[0]
assert size_k % 16 == 0
output = torch.empty(
(num_experts, size_k // 16, size_n * (num_bits // 2)),
device=b_q_weight.device,
dtype=b_q_weight.dtype,
)
for e in range(num_experts):
output[e] = torch.ops.sgl_kernel.gptq_marlin_repack(
b_q_weight[e], perm[e], size_k, size_n, num_bits
)
return output
@dataclass
class MarlinLinearLayerConfig:
full_weight_shape: tuple[int, int] # [in, out]
partition_weight_shape: tuple[int, int]
weight_type: ScalarType
act_type: torch.dtype
group_size: int
zero_points: bool
has_g_idx: bool
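# Illustration (not part of this commit): how a TP-sharded 4-bit GPTQ linear layer
# might be described with the dataclass above. The shapes are made up;
# `scalar_types.uint4b8` is the GPTQ-style 4-bit type imported at the top of this file.
_EXAMPLE_MARLIN_CFG = MarlinLinearLayerConfig(
    full_weight_shape=(4096, 11008),  # full [in, out] before tensor-parallel sharding
    partition_weight_shape=(2048, 11008),  # this rank's shard (TP=2 on the input dim)
    weight_type=scalar_types.uint4b8,
    act_type=torch.float16,
    group_size=128,
    zero_points=False,  # GPTQ is symmetric, no runtime zero points
    has_g_idx=True,  # desc_act checkpoints carry an activation-order index
)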
class GPTQConfig(QuantizationConfig):
    """Config class for GPTQ.
@@ -151,11 +191,16 @@ class GPTQConfig(QuantizationConfig):
    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["LinearMethodBase"]:
        # Delay the import to avoid circular dependency
        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
        from sglang.srt.layers.quantization import get_linear_quant_method

        if isinstance(layer, LinearBase):
            return get_linear_quant_method(self, layer, prefix, GPTQLinearMethod)
        elif isinstance(layer, FusedMoE):
            raise TypeError("GPTQ Method does not support MoE, please use gptq_marlin")
        return None
class GPTQMarlinConfig(QuantizationConfig):
@@ -313,14 +358,6 @@ class GPTQMarlinConfig(QuantizationConfig):
        if isinstance(layer, FusedMoE):
            return GPTQMarlinMoEMethod(self)
        return get_linear_quant_method(self, layer, prefix, GPTQMarlinLinearMethod)

    @classmethod
@@ -344,112 +381,439 @@ class GPTQMarlinConfig(QuantizationConfig):
        if (num_bits, sym) not in cls.TYPE_MAP:
            return False
        return check_marlin_supported(
            quant_type=cls.TYPE_MAP[(num_bits, sym)], group_size=group_size
        )
# NOTE: the MarlinConfig class that was previously defined here has been removed from
# this file in this commit; its updated definition now appears with the marlin
# utilities later in this diff.


class GPTQLinearMethod(LinearMethodBase):
    """Linear method for GPTQ.

    Args:
        quant_config: The GPTQ quantization config.
    """

    def __init__(self, quant_config: GPTQConfig):
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        del output_size  # Unused.
        weight_loader = extra_weight_attrs.get("weight_loader")
        if input_size_per_partition % self.quant_config.group_size != 0:
            raise ValueError(
                "The input size is not aligned with the quantized "
                "weight shape. This can be caused by too large "
                "tensor parallel size."
            )
        output_size_per_partition = sum(output_partition_sizes)
        if output_size_per_partition % self.quant_config.pack_factor.numerator != 0:
            raise ValueError(
                "The output size is not aligned with the quantized "
                "weight shape. This can be caused by too large "
                "tensor parallel size."
            )

        if self.quant_config.group_size != -1:
            group_size = self.quant_config.group_size
        else:
            group_size = input_size

        self.use_shuffle = True
        scale_and_zero_size = input_size // group_size
        scale_and_zero_input_dim = None
        if (
            input_size != input_size_per_partition
            and self.quant_config.group_size != -1
        ):
            if self.quant_config.desc_act:
                self.use_shuffle = False
            else:
                # we need to partition qzeros and scales for exllama kernel
                scale_and_zero_size = input_size_per_partition // group_size
                scale_and_zero_input_dim = 0

        qweight = PackedvLLMParameter(
            data=torch.empty(
                input_size_per_partition // self.quant_config.pack_factor,
                output_size_per_partition,
                dtype=torch.int32,
            ),
            input_dim=0,
            output_dim=1,
            packed_dim=0,
            packed_factor=self.quant_config.pack_factor,
            weight_loader=weight_loader,
        )
        g_idx = RowvLLMParameter(
            data=torch.tensor(
                [
                    i // self.quant_config.group_size
                    for i in range(input_size_per_partition)
                ],
                dtype=torch.int32,
            ),
            input_dim=0,
            weight_loader=weight_loader,
        )
        qzeros_args = {
            "data": torch.empty(
                scale_and_zero_size,
                output_size_per_partition // self.quant_config.pack_factor,
                dtype=torch.int32,
            ),
            "weight_loader": weight_loader,
        }
        weight_scale_args = {
            "data": torch.empty(
                scale_and_zero_size,
                output_size_per_partition,
                dtype=params_dtype,
            ),
            "weight_loader": weight_loader,
        }
        if scale_and_zero_input_dim is None:
            scales = ChannelQuantScaleParameter(output_dim=1, **weight_scale_args)
            qzeros = PackedColumnParameter(
                output_dim=1,
                packed_dim=1,
                packed_factor=self.quant_config.pack_factor,
                **qzeros_args,
            )
        else:
            scales = GroupQuantScaleParameter(
                output_dim=1, input_dim=0, **weight_scale_args
            )
            qzeros = PackedvLLMParameter(
                input_dim=0,
                output_dim=1,
                packed_dim=1,
                packed_factor=self.quant_config.pack_factor,
                **qzeros_args,
            )

        layer.register_parameter("qweight", qweight)
        layer.register_parameter("g_idx", g_idx)
        layer.register_parameter("qzeros", qzeros)
        layer.register_parameter("scales", scales)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        # for torch.compile
        layer.qzeros = torch.nn.Parameter(layer.qzeros.data, requires_grad=False)
        layer.qweight = torch.nn.Parameter(layer.qweight.data, requires_grad=False)
        layer.g_idx = torch.nn.Parameter(layer.g_idx.data, requires_grad=False)
        layer.scales = torch.nn.Parameter(layer.scales.data, requires_grad=False)

        # exllama needs to shuffle the weight after the weight is loaded
        # here we do the shuffle on first forward pass
        if self.use_shuffle:
            if self.quant_config.desc_act:
                layer.g_idx.data = torch.argsort(layer.g_idx).to(torch.int)
            else:
                layer.g_idx.data = torch.empty(
                    (0,), dtype=torch.int, device=layer.g_idx.device
                )
            ops.gptq_shuffle(layer.qweight, layer.g_idx, self.quant_config.weight_bits)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        out_shape = x.shape[:-1] + (layer.qweight.shape[-1],)
        reshaped_x = x.reshape(-1, x.shape[-1])

        output = ops.gptq_gemm(
            reshaped_x,
            layer.qweight,
            layer.qzeros,
            layer.scales,
            layer.g_idx,
            self.use_shuffle,
            self.quant_config.weight_bits,
        )
        if bias is not None:
            output.add_(bias)
        return output.reshape(out_shape)


class GPTQMarlinLinearMethod(LinearMethodBase):
    """Linear method for GPTQ Marlin.

    Args:
        quant_config: The GPTQ Marlin quantization config.
    """

    _kernel_backends_being_used: set[str] = set()

    def __init__(self, quant_config: GPTQMarlinConfig) -> None:
        self.quant_config = quant_config

        # Verify supported on platform.
        verify_marlin_supported(
            quant_type=self.quant_config.quant_type,
            group_size=self.quant_config.group_size,
        )

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ) -> None:
        output_size_per_partition = sum(output_partition_sizes)
        is_row_parallel = input_size != input_size_per_partition
        weight_loader = extra_weight_attrs.get("weight_loader")

        self.kernel_config = MarlinLinearLayerConfig(
            full_weight_shape=(input_size, output_size),
            partition_weight_shape=(
                input_size_per_partition,
                output_size_per_partition,
            ),
            weight_type=self.quant_config.quant_type,
            act_type=params_dtype,
            group_size=self.quant_config.group_size,
            zero_points=False,
            has_g_idx=self.quant_config.desc_act,
        )

        # Normalize group_size
        if self.quant_config.group_size != -1:
            group_size = self.quant_config.group_size
        else:
            group_size = input_size

        # Determine sharding
        if marlin_repeat_scales_on_all_ranks(
            self.quant_config.desc_act, self.quant_config.group_size, is_row_parallel
        ):
            # By setting scale_dim == None, weight_loader will
            # repeat the scales on each GPU in TP>1 case.
            scales_and_zp_input_dim = None
            scales_and_zp_size = input_size // group_size
        else:
            # By setting scale_dim == 0, weight_loader will
            # shard the scales in TP>1 case.
            scales_and_zp_input_dim = 0
            scales_and_zp_size = input_size_per_partition // group_size

        # Quantized weights
        qweight = PackedvLLMParameter(
            data=torch.empty(
                input_size_per_partition // self.quant_config.pack_factor,
                output_size_per_partition,
                dtype=torch.int32,
            ),
            input_dim=0,
            output_dim=1,
            packed_dim=0,
            packed_factor=self.quant_config.pack_factor,
            weight_loader=weight_loader,
        )

        # Activation order
        g_idx = RowvLLMParameter(
            data=torch.empty(
                input_size_per_partition,
                dtype=torch.int32,
            ),
            input_dim=0,
            weight_loader=weight_loader,
        )

        qzeros_args = {
            "data": torch.empty(
                scales_and_zp_size,
                output_size_per_partition // self.quant_config.pack_factor,
                dtype=torch.int32,
            ),
            "weight_loader": weight_loader,
        }
        weight_scale_args = {
            "data": torch.empty(
                scales_and_zp_size,
                output_size_per_partition,
                dtype=params_dtype,
            ),
            "weight_loader": weight_loader,
        }

        if scales_and_zp_input_dim is None:
            scales = ChannelQuantScaleParameter(output_dim=1, **weight_scale_args)
            qzeros = PackedColumnParameter(
                output_dim=1,
                packed_dim=1,
                packed_factor=self.quant_config.pack_factor,
                **qzeros_args,
            )
        else:
            scales = GroupQuantScaleParameter(
                output_dim=1, input_dim=0, **weight_scale_args
            )
            qzeros = PackedvLLMParameter(
                input_dim=0,
                output_dim=1,
                packed_dim=1,
                packed_factor=self.quant_config.pack_factor,
                **qzeros_args,
            )

        layer.register_parameter("qweight", qweight)
        layer.register_parameter("g_idx", g_idx)
        layer.register_parameter("scales", scales)
        layer.register_parameter("qzeros", qzeros)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        device = getattr(layer, "qweight").device
        c = self.kernel_config

        check_marlin_supports_shape(
            c.partition_weight_shape[1],  # out_features
            c.partition_weight_shape[0],  # in_features
            c.full_weight_shape[0],  # in_features
            c.group_size,
        )

        row_parallel = c.partition_weight_shape[0] != c.full_weight_shape[0]
        self.is_k_full = marlin_is_k_full(c.has_g_idx, row_parallel)

        # Allocate marlin workspace.
        self.workspace = marlin_make_workspace(device)

        # Default names since marlin requires empty parameters for these,
        # TODO: remove this requirement from marlin (allow optional tensors)
        self.w_q_name = "qweight"
        self.w_s_name = "scales"
        self.w_zp_name = "qzeros"
        self.w_gidx_name = "g_idx"

        def _transform_param(
            layer: torch.nn.Module, name: Optional[str], fn: Callable
        ) -> None:
            if name is not None and getattr(layer, name, None) is not None:
                old_param = getattr(layer, name)
                new_param = fn(old_param)
                # replace the parameter with torch.nn.Parameter for TorchDynamo
                # compatibility
                replace_parameter(
                    layer, name, torch.nn.Parameter(new_param.data, requires_grad=False)
                )

        def transform_w_q(x):
            assert isinstance(x, BasevLLMParameter)
            permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
            x.data = torch.ops.sgl_kernel.gptq_marlin_repack(
                x.data.contiguous(),
                perm=layer.g_idx_sort_indices,
                size_k=c.partition_weight_shape[0],
                size_n=c.partition_weight_shape[1],
                num_bits=c.weight_type.size_bits,
            )
            return x

        def transform_w_s(x):
            assert isinstance(x, BasevLLMParameter)
            permute_param_layout_(x, input_dim=0, output_dim=1)
            x.data = marlin_permute_scales(
                x.data.contiguous(),
                size_k=c.partition_weight_shape[0],
                size_n=c.partition_weight_shape[1],
                group_size=c.group_size,
            )
            return x

        if c.has_g_idx:
            g_idx, g_idx_sort_indices = marlin_sort_g_idx(
                getattr(layer, self.w_gidx_name)
            )
            _transform_param(layer, self.w_gidx_name, lambda _: g_idx)
            layer.g_idx_sort_indices = g_idx_sort_indices
        else:
            setattr(layer, self.w_gidx_name, marlin_make_empty_g_idx(device))
            layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)

        if c.zero_points:
            grouped_k = (
                c.partition_weight_shape[0] // c.group_size if c.group_size != -1 else 1
            )
            _transform_param(
                layer,
                self.w_zp_name,
                lambda x: marlin_zero_points(
                    unpack_cols(
                        x.t(),
                        c.weight_type.size_bits,
                        grouped_k,
                        c.partition_weight_shape[1],
                    ),
                    size_k=grouped_k,
                    size_n=c.partition_weight_shape[1],
                    num_bits=c.weight_type.size_bits,
                ),
            )
        else:
            setattr(layer, self.w_zp_name, marlin_make_empty_g_idx(device))
        _transform_param(layer, self.w_q_name, transform_w_q)
        _transform_param(layer, self.w_s_name, transform_w_s)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        c = self.kernel_config

        def _get_weight_params(
            layer: torch.nn.Module,
        ) -> tuple[
            torch.Tensor,  # w_q
            torch.Tensor,  # w_s
            Optional[torch.Tensor],  # w_zp,
            Optional[torch.Tensor],  # w_gidx
        ]:
            return (
                getattr(layer, self.w_q_name),
                getattr(layer, self.w_s_name),
                getattr(layer, self.w_zp_name or "", None),
                getattr(layer, self.w_gidx_name or "", None),
            )

        w_q, w_s, w_zp, w_gidx = _get_weight_params(layer)

        # `process_weights_after_loading` will ensure w_zp and w_gidx are not
        # None for marlin
        return apply_gptq_marlin_linear(
            input=x,
            weight=w_q,
            weight_scale=w_s,
            weight_zp=w_zp,  # type: ignore
            g_idx=w_gidx,  # type: ignore
            g_idx_sort_indices=layer.g_idx_sort_indices,
            workspace=self.workspace,
            wtype=c.weight_type,
            input_size_per_partition=c.partition_weight_shape[0],
            output_size_per_partition=c.partition_weight_shape[1],
            is_k_full=self.is_k_full,
            bias=bias,
        )
class GPTQMarlinMoEMethod(FusedMoEMethodBase):
@@ -467,6 +831,9 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        # Delay the import to avoid circular dependency
        from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported

        intermediate_size = extra_weight_attrs.pop("intermediate_size")

        self.is_k_full = (not self.quant_config.desc_act) or (
@@ -644,20 +1011,20 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
            requires_grad=False,
        )
        # Repack weights
        marlin_w13_qweight = gptq_marlin_moe_repack(
            layer.w13_qweight,
            layer.w13_g_idx_sort_indices,
            layer.w13_qweight.shape[1] * self.quant_config.pack_factor,
            layer.w13_qweight.shape[2],
            self.quant_config.weight_bits,
        )
        replace_parameter(layer, "w13_qweight", marlin_w13_qweight)
        marlin_w2_qweight = gptq_marlin_moe_repack(
            layer.w2_qweight,
            layer.w2_g_idx_sort_indices,
            layer.w2_qweight.shape[1] * self.quant_config.pack_factor,
            layer.w2_qweight.shape[2],
            self.quant_config.weight_bits,
        )
        replace_parameter(layer, "w2_qweight", marlin_w2_qweight)
        # Repack scales
@@ -698,13 +1065,19 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
        e_score_correction_bias: Optional[torch.Tensor] = None,
        activation: str = "silu",
    ) -> torch.Tensor:
        # Delay the import to avoid circular dependency
        from sglang.srt.layers.moe.topk import select_experts

        assert activation == "silu", "Only SiLU activation is supported."
        assert (
            scoring_func == "softmax"
        ), "Only softmax score func is supported for now."

        # The input must currently be float16
        orig_dtype = x.dtype
        x = x.half()

        topk_weights, topk_ids = select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
@@ -713,11 +1086,10 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            correction_bias=e_score_correction_bias,
        )

        return fused_marlin_moe(
            x,
            layer.w13_qweight,
            layer.w2_qweight,
@@ -730,6 +1102,6 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
            g_idx2=layer.w2_g_idx,
            sort_indices1=layer.w13_g_idx_sort_indices,
            sort_indices2=layer.w2_g_idx_sort_indices,
            num_bits=self.quant_config.weight_bits,
            is_k_full=self.is_k_full,
        ).to(orig_dtype)
# SPDX-License-Identifier: Apache-2.0
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/marlin_utils.py
import logging
from typing import Any, Optional
import numpy
import torch
from sglang.srt.layers.linear import LinearBase, LinearMethodBase
from sglang.srt.layers.parameter import (
BasevLLMParameter,
ChannelQuantScaleParameter,
GroupQuantScaleParameter,
PackedvLLMParameter,
)
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.quantization.scalar_type import ScalarType, scalar_types
from sglang.srt.layers.quantization.utils import pack_cols, unpack_cols
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
from sglang.srt.utils import get_device_capability
try:
from vllm import _custom_ops as ops
except ImportError:
ops = None
logger = logging.getLogger(__name__)
GPTQ_MARLIN_TILE = 16
GPTQ_MARLIN_MIN_THREAD_N = 64
GPTQ_MARLIN_MIN_THREAD_K = 128
GPTQ_MARLIN_MAX_PARALLEL = 16
MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
# In case there is a performance issue with Marlin, the variable below can be
# changed to False, which allows Marlin to perform global reductions in fp16
# precision (instead of fp32), and therefore, save on some memory movements.
USE_FP32_REDUCE_DEFAULT = True
# For binary size and compile time, we don't support the same types for with and
# without runtime zero-point. We support common cases, i.e. AWQ and GPTQ.
# TODO: we may want to move this into the C++ so its closer to the actual impl
def query_marlin_supported_quant_types(
has_zp: Optional[bool] = None,
include_fp_type: bool = True,
device_capability: Optional[int] = None,
):
if device_capability is None:
major, minor = get_device_capability()
capability = major * 10 + minor
device_capability = -1 if capability is None else capability
if device_capability < 80:
return []
# - has_zp is True: return quant_types that has zero points
# - has_zp is False: return quant_types that has not zero points
# - has_zp is None: both
if has_zp is None:
types0 = query_marlin_supported_quant_types(
False, include_fp_type, device_capability
)
types1 = query_marlin_supported_quant_types(
True, include_fp_type, device_capability
)
return types0 + types1
if has_zp:
# AWQ style, unsigned + runtime zero-point
return [scalar_types.uint4]
else:
# GPTQ style, unsigned + symmetric bias
res = [scalar_types.uint4b8, scalar_types.uint8b128]
if include_fp_type:
res += [scalar_types.float8_e4m3fn, scalar_types.float4_e2m1f]
return res
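# Illustration (not part of this commit): what the helper above reports for a given
# compute capability. GPTQ-style (symmetric) checkpoints map to the *b* types, AWQ-style
# (runtime zero point) checkpoints only to uint4, and pre-SM80 devices get an empty list.
def _demo_query_marlin_supported_quant_types() -> None:
    gptq_types = query_marlin_supported_quant_types(has_zp=False, device_capability=90)
    assert scalar_types.uint4b8 in gptq_types and scalar_types.uint8b128 in gptq_types
    assert query_marlin_supported_quant_types(has_zp=True, device_capability=90) == [
        scalar_types.uint4
    ]
    assert query_marlin_supported_quant_types(device_capability=75) == []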
def _check_marlin_supported(
quant_type: ScalarType,
group_size: Optional[int],
has_zp: bool,
device_capability: Optional[int] = None,
) -> tuple[bool, Optional[str]]:
if device_capability is None:
major, minor = get_device_capability()
capability = major * 10 + minor
device_capability = -1 if capability is None else capability
supported_types = query_marlin_supported_quant_types(
has_zp, True, device_capability
)
if quant_type not in supported_types:
return (
False,
f"Marlin does not support weight_bits = {quant_type}. "
f"Only types = {supported_types} "
f"are supported (for group_size = {group_size}, "
f"device_capability = {device_capability}, zp = {has_zp}).",
)
if group_size is None or group_size not in MARLIN_SUPPORTED_GROUP_SIZES:
return (
False,
f"Marlin does not support group_size = {group_size}. "
f"Only group_sizes = {MARLIN_SUPPORTED_GROUP_SIZES} "
"are supported.",
)
return True, None
def check_marlin_supported(
quant_type: ScalarType,
group_size: int,
has_zp: bool = False,
device_capability: Optional[int] = None,
) -> bool:
cond, _ = _check_marlin_supported(quant_type, group_size, has_zp, device_capability)
return cond
def verify_marlin_supported(
quant_type: ScalarType, group_size: int, has_zp: bool = False
) -> None:
cond, err_msg = _check_marlin_supported(quant_type, group_size, has_zp)
if not cond:
assert err_msg is not None
raise ValueError(err_msg)
def verify_marlin_supports_shape(
output_size_per_partition: int,
input_size_per_partition: int,
input_size: int,
group_size: int,
) -> None:
# Validate output_size_per_partition
if output_size_per_partition % GPTQ_MARLIN_MIN_THREAD_N != 0:
raise ValueError(
f"Weight output_size_per_partition = "
f"{output_size_per_partition} is not divisible by "
f" min_thread_n = {GPTQ_MARLIN_MIN_THREAD_N}. "
"Consider reducing tensor_parallel_size or running "
"with --quantization gptq."
)
# Validate input_size_per_partition
if input_size_per_partition % GPTQ_MARLIN_MIN_THREAD_K != 0:
raise ValueError(
f"Weight input_size_per_partition = "
f"{input_size_per_partition} is not divisible "
f"by min_thread_k = {GPTQ_MARLIN_MIN_THREAD_K}. "
"Consider reducing tensor_parallel_size or running "
"with --quantization gptq."
)
if group_size < input_size and input_size_per_partition % group_size != 0:
raise ValueError(
f"Weight input_size_per_partition = {input_size_per_partition}"
f" is not divisible by group_size = {group_size}. "
"Consider reducing tensor_parallel_size or running "
"with --quantization gptq."
)
def check_marlin_supports_shape(
output_size_per_partition: int,
input_size_per_partition: int,
input_size: int,
group_size: int,
) -> tuple[bool, Optional[str]]:
try:
verify_marlin_supports_shape(
output_size_per_partition, input_size_per_partition, input_size, group_size
)
except ValueError as e:
return False, e.__str__()
return True, None
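# Illustration (not part of this commit): the shape rules above on a concrete
# tensor-parallel shard. An 11008 x 2048 shard satisfies n % 64 == 0 and k % 128 == 0,
# while an output width that is not a multiple of min_thread_n is rejected with a reason.
def _demo_check_marlin_supports_shape() -> None:
    ok, _ = check_marlin_supports_shape(
        output_size_per_partition=11008,
        input_size_per_partition=2048,
        input_size=4096,
        group_size=128,
    )
    assert ok
    ok, reason = check_marlin_supports_shape(
        output_size_per_partition=11000,  # not divisible by GPTQ_MARLIN_MIN_THREAD_N
        input_size_per_partition=2048,
        input_size=4096,
        group_size=128,
    )
    assert not ok and "min_thread_n" in reason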
def check_marlin_supports_layer(layer: LinearBase, group_size: int) -> bool:
output_size_per_partition = (
getattr(layer, "output_size_per_partition", None) or layer.output_size
)
input_size_per_partition = (
getattr(layer, "input_size_per_partition", None) or layer.input_size
)
return check_marlin_supports_shape(
output_size_per_partition=output_size_per_partition,
input_size_per_partition=input_size_per_partition,
input_size=layer.input_size,
group_size=group_size,
)[0]
def check_moe_marlin_supports_layer(layer: LinearBase, group_size: int) -> bool:
hidden_size = layer.hidden_size
intermediate_size_per_partition = layer.intermediate_size_per_partition
# apply_router_weight_on_input is not supported for moe marlin
supports_router_weight = not layer.apply_router_weight_on_input
# moe marlin requires the activation to be silu
supports_activation = layer.activation == "silu"
# gate-up: (n, k) = (intermediate_size_per_partition * 2, hidden_size)
# down: (n, k) = (hidden_size, intermediate_size_per_partition)
# moe marlin requires n % 128 == 0 and k % 64 == 0
supports_shape = (
hidden_size % 128 == 0
and intermediate_size_per_partition % max(64, group_size) == 0
)
supports_group_size = group_size in [-1, 32, 64, 128]
return (
supports_shape
and supports_group_size
and supports_router_weight
and supports_activation
)
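# Illustration (not part of this commit): the layer attributes read by the MoE check
# above, mimicked with a plain namespace. A hypothetical shard with hidden_size 4096
# and intermediate_size_per_partition 1408 is supported at group_size 128 but not 256.
def _demo_check_moe_marlin_supports_layer() -> None:
    from types import SimpleNamespace

    fake_moe = SimpleNamespace(
        hidden_size=4096,
        intermediate_size_per_partition=1408,
        apply_router_weight_on_input=False,
        activation="silu",
    )
    assert check_moe_marlin_supports_layer(fake_moe, group_size=128)
    assert not check_moe_marlin_supports_layer(fake_moe, group_size=256)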
def marlin_make_workspace(
device: torch.device, max_blocks_per_sm: int = 1
) -> torch.Tensor:
# In the new marlin kernel, we use the num of threadblocks as workspace
    # size. The num of threadblocks is sms_count * max_blocks_per_sm.
sms = torch.cuda.get_device_properties(device).multi_processor_count
return torch.zeros(
sms * max_blocks_per_sm, dtype=torch.int, device=device, requires_grad=False
)
def marlin_is_k_full(act_order: bool, is_row_parallel: bool) -> bool:
return (not act_order) or (act_order and not is_row_parallel)
def marlin_repeat_scales_on_all_ranks(
act_order: bool, group_size: int, is_row_parallel: bool
) -> bool:
# Need to repeat scales on every rank if act_ordering or
# channelwise and RowParallelLinear
is_channelwise = group_size == -1
return act_order or (is_channelwise and is_row_parallel)
def marlin_make_empty_g_idx(device: torch.device) -> torch.Tensor:
return torch.nn.Parameter(
torch.empty(0, dtype=torch.int, device=device), requires_grad=False
)
def marlin_make_empty_zp(device: torch.device) -> torch.Tensor:
return torch.nn.Parameter(
torch.empty(0, dtype=torch.int, device=device), requires_grad=False
)
def marlin_sort_g_idx(g_idx: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
g_idx_sort_indices = torch.argsort(g_idx).to(torch.int)
return g_idx[g_idx_sort_indices], g_idx_sort_indices
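# Illustration (not part of this commit): sorting a toy activation-order index. The
# sorted g_idx groups rows of the same quantization group together; the returned sort
# indices are what the repack step later uses as `perm`.
def _demo_marlin_sort_g_idx() -> None:
    g_idx = torch.tensor([1, 0, 1, 0], dtype=torch.int32)
    sorted_g_idx, sort_indices = marlin_sort_g_idx(g_idx)
    assert sorted_g_idx.tolist() == [0, 0, 1, 1]
    assert g_idx[sort_indices].tolist() == [0, 0, 1, 1]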
def get_scale_perms():
scale_perm: list[int] = []
for i in range(8):
scale_perm.extend([i + 8 * j for j in range(8)])
scale_perm_single: list[int] = []
for i in range(4):
scale_perm_single.extend([2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
return scale_perm, scale_perm_single
def marlin_permute_scales(
s: torch.Tensor, size_k: int, size_n: int, group_size: int
) -> torch.Tensor:
scale_perm, scale_perm_single = get_scale_perms()
if group_size < size_k and group_size != -1:
s = s.reshape((-1, len(scale_perm)))[:, scale_perm]
else:
s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single]
s = s.reshape((-1, size_n)).contiguous()
return s
def marlin_moe_permute_scales(
s: torch.Tensor,
size_k: int,
size_n: int,
group_size: int,
):
num_experts = s.shape[0]
output = torch.empty(
(num_experts, s.shape[1], s.shape[2]),
device=s.device,
dtype=s.dtype,
)
for e in range(num_experts):
output[e] = marlin_permute_scales(s[e], size_k, size_n, group_size)
return output
def marlin_zero_points(
zp: torch.Tensor, size_k: int, size_n: int, num_bits: int
) -> torch.Tensor:
# Permute zero-points in a similar way to scales, but do not use the
# "single" permutation, since zero-points are applied on every MMA
scale_perm, _ = get_scale_perms()
zp = zp.reshape((-1, len(scale_perm)))[:, scale_perm]
# Interleave column dim (for the dequantize code) and pack it to int32
if num_bits == 4:
interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
elif num_bits == 8:
interleave = numpy.array([0, 2, 1, 3])
else:
raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))
zp = zp.reshape((-1, len(interleave)))[:, interleave].ravel()
zp = zp.reshape((-1, size_n)).contiguous()
zp = pack_cols(zp, num_bits, size_k, size_n)
return zp
def awq_to_marlin_zero_points(
q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int
) -> torch.Tensor:
# AWQ zero-points are quantized and packed on the column dim.
# In addition, the values are permuted based on dequantizer.
# Here we undo both of these, and then apply marlin permutation
# and pack it back.
q_zp = unpack_cols(q_zp_packed, num_bits, size_k, size_n)
# Undo interleaving (use argsort(..) to get inverse perm)
if num_bits == 4:
undo_interleave = numpy.argsort(numpy.array([0, 2, 4, 6, 1, 3, 5, 7]))
elif num_bits == 8:
undo_interleave = numpy.argsort(numpy.array([0, 2, 1, 3]))
else:
raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))
q_zp = q_zp.reshape((-1, len(undo_interleave)))[:, undo_interleave].ravel()
q_zp = q_zp.reshape((-1, size_n)).contiguous()
marlin_zp = marlin_zero_points(q_zp, size_k, size_n, num_bits)
return marlin_zp
def moe_awq_to_marlin_zero_points(
q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int
):
num_experts = q_zp_packed.shape[0]
output = torch.empty(
(num_experts, q_zp_packed.shape[1], q_zp_packed.shape[2]),
device=q_zp_packed.device,
dtype=q_zp_packed.dtype,
)
for e in range(num_experts):
output[e] = awq_to_marlin_zero_points(q_zp_packed[e], size_k, size_n, num_bits)
return output
def maybe_warn_marlin_atomic_add(device, dtype):
if torch.compiler.is_dynamo_compiling():
return
device_capability = torch.cuda.get_device_capability(device)
if device_capability[0] < 9 and dtype == torch.bfloat16:
logger.info_once(
"You are running Marlin kernel with bf16 on GPUs before SM90. "
"You can consider change to fp16 to achieve better performance "
"if possible."
)
def maybe_warn_marlin_atomic_add_env():
if torch.compiler.is_dynamo_compiling():
return
# TODO(yiyun): Need to add sglang's MARLIN_USE_ATOMIC_ADD: bool = False
if True:
return
# if envs.VLLM_MARLIN_USE_ATOMIC_ADD:
# return
logger.info_once(
"Marlin kernel can achieve better performance for small size_n "
"with experimental use_atomic_add feature. "
"You can consider set environment variable "
"VLLM_MARLIN_USE_ATOMIC_ADD to 1 if possible."
)
def should_use_atomic_add_reduce(
m: int, n: int, k: int, device: torch.device, dtype: torch.dtype
) -> bool:
# the performance of atomicAdd is better than global reduce
# only when m*n is small and k is large
if n >= 2048 or k < 2048 or device.type != "cuda":
return False
# disable atomicAdd reduce by default,
# one can enable it with VLLM_MARLIN_USE_ATOMIC_ADD=1
# TODO: Need to add sglang's MARLIN_USE_ATOMIC_ADD: bool = False
if not True:
maybe_warn_marlin_atomic_add_env()
return False
# sm8x doesn't support atomicAdd + bfloat16 natively
device_capability = torch.cuda.get_device_capability(device)
if device_capability[0] < 9 and dtype == torch.bfloat16:
maybe_warn_marlin_atomic_add(device, dtype)
return False
return True
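# Illustration (not part of this commit): the heuristic above never selects atomic-add
# reduction for wide outputs (n >= 2048) or short reductions (k < 2048), regardless of
# device, so these checks run fine on CPU.
def _demo_should_use_atomic_add_reduce() -> None:
    cpu = torch.device("cpu")
    assert not should_use_atomic_add_reduce(
        m=16, n=4096, k=4096, device=cpu, dtype=torch.float16
    )
    assert not should_use_atomic_add_reduce(
        m=16, n=128, k=512, device=cpu, dtype=torch.float16
    )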
def apply_gptq_marlin_linear(
input: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
weight_zp: torch.Tensor,
g_idx: torch.Tensor,
g_idx_sort_indices: torch.Tensor,
workspace: torch.Tensor,
wtype: ScalarType,
output_size_per_partition: int,
input_size_per_partition: int,
is_k_full: bool,
bias: Optional[torch.Tensor] = None,
use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT,
) -> torch.Tensor:
reshaped_x = input.reshape(-1, input.shape[-1])
out_shape = input.shape[:-1] + (output_size_per_partition,)
use_atomic_add = should_use_atomic_add_reduce(
m=reshaped_x.size(0),
n=output_size_per_partition,
k=reshaped_x.size(1),
device=input.device,
dtype=input.dtype,
)
output = ops.gptq_marlin_gemm(
reshaped_x,
None,
weight,
weight_scale,
None,
weight_zp,
g_idx,
g_idx_sort_indices,
workspace,
wtype,
size_m=reshaped_x.shape[0],
size_n=output_size_per_partition,
size_k=input_size_per_partition,
is_k_full=is_k_full,
use_atomic_add=use_atomic_add,
use_fp32_reduce=use_fp32_reduce,
is_zp_float=False,
)
if bias is not None:
output.add_(bias) # In-place add
return output.reshape(out_shape)
def apply_awq_marlin_linear(
input: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
weight_zp: torch.Tensor,
g_idx: torch.Tensor,
g_idx_sort_indices: torch.Tensor,
workspace: torch.Tensor,
quant_type: ScalarType,
output_size_per_partition: int,
input_size_per_partition: int,
bias: Optional[torch.Tensor] = None,
use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT,
) -> torch.Tensor:
reshaped_x = input.reshape(-1, input.shape[-1])
out_shape = input.shape[:-1] + (output_size_per_partition,)
use_atomic_add = should_use_atomic_add_reduce(
m=reshaped_x.size(0),
n=output_size_per_partition,
k=reshaped_x.size(1),
device=input.device,
dtype=input.dtype,
)
output = ops.gptq_marlin_gemm(
reshaped_x,
None,
weight,
weight_scale,
None,
weight_zp,
g_idx,
g_idx_sort_indices,
workspace,
quant_type,
size_m=reshaped_x.shape[0],
size_n=output_size_per_partition,
size_k=input_size_per_partition,
use_atomic_add=use_atomic_add,
use_fp32_reduce=use_fp32_reduce,
is_zp_float=False,
)
if bias is not None:
output.add_(bias) # In-place add
return output.reshape(out_shape)
class MarlinConfig(QuantizationConfig):
"""Config class for Marlin.
Reference: https://github.com/IST-DASLab/marlin/tree/master
"""
def __init__(
self,
group_size: int,
lm_head_quantized: bool,
) -> None:
super().__init__()
# Group size for the quantization.
self.group_size = group_size
self.lm_head_quantized = lm_head_quantized
if self.group_size != 128 and self.group_size != -1:
raise ValueError(
"Currently, only group size 128 and -1 (channelwise) "
"is supported for Marlin, but got group_size of "
f"{self.group_size}"
)
# 4 Bits packed into 32 bit datatype.
self.pack_factor = 32 // 4
# Tile size used by marlin kernels.
self.tile_size = 16
# Min out_features dim
self.min_n_threads = 64
# Min in_features dim
self.min_k_threads = 128
# Max parallel problems to solve at once (improves large
# batch performance)
self.max_parallel = 16
# Permutation length used by the marlin kernels.
self.perm_len = 1024
def __repr__(self) -> str:
return (
f"MarlinConfig(group_size={self.group_size}, "
f"lm_head_quantized={self.lm_head_quantized})"
)
@classmethod
def get_name(cls) -> str:
return "marlin"
@classmethod
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
return [torch.half]
@classmethod
# Need to figure it out
def get_min_capability(cls) -> int:
return 80
@classmethod
def get_config_filenames(cls) -> list[str]:
return ["quantize_config.json"]
@classmethod
def from_config(cls, config: dict[str, Any]) -> "MarlinConfig":
group_size = cls.get_from_keys(config, ["group_size"])
lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], default=False)
return cls(group_size, lm_head_quantized)
@classmethod
def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]:
# compat: autogptq >=0.8.0 use checkpoint_format: str
# compat: autogptq <=0.7.1 is_marlin_format: bool
is_marlin_format = hf_quant_cfg.get(
"checkpoint_format"
) == "marlin" or hf_quant_cfg.get("is_marlin_format", False)
is_valid_user_quant = (
user_quant is None or user_quant == "gptq" or user_quant == "marlin"
)
if is_marlin_format and is_valid_user_quant:
msg = "The model is serialized in {} format. Using {} kernel.".format(
cls.get_name(), cls.get_name()
)
logger.info(msg)
return cls.get_name()
return None
def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> Optional["MarlinLinearMethod"]:
if isinstance(layer, LinearBase) or (
isinstance(layer, ParallelLMHead) and self.lm_head_quantized
):
return MarlinLinearMethod(self)
return None
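# Illustration (not part of this commit): how the override above behaves. A checkpoint
# that advertises the marlin serialization format is routed to this config even when
# the user asked for plain "gptq", while an explicitly different method is left alone.
def _demo_marlin_override_quantization_method() -> None:
    assert (
        MarlinConfig.override_quantization_method(
            {"checkpoint_format": "marlin"}, user_quant="gptq"
        )
        == "marlin"
    )
    assert (
        MarlinConfig.override_quantization_method(
            {"is_marlin_format": True}, user_quant="awq"
        )
        is None
    )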
class MarlinLinearMethod(LinearMethodBase):
"""Linear method for Marlin.
Args:
quant_config: The Marlin quantization config.
"""
def __init__(self, quant_config: MarlinConfig):
self.quant_config = quant_config
def create_weights(
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: list[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
del output_size # Unused.
weight_loader = extra_weight_attrs["weight_loader"]
if params_dtype != torch.float16:
raise ValueError(
f"The params dtype must be float16, but got {params_dtype}"
)
# Validate output_size_per_partition
output_size_per_partition = sum(output_partition_sizes)
if output_size_per_partition % self.quant_config.min_n_threads != 0:
raise ValueError(
f"Weight output_size_per_partition = "
f"{output_size_per_partition} is not divisible by "
f"min_n_threads = {self.quant_config.min_n_threads}."
)
if output_size_per_partition % self.quant_config.pack_factor != 0:
raise ValueError(
f"Weight output_size_per_partition = "
f"{output_size_per_partition} is not divisible by "
f"pack_factor = {self.quant_config.pack_factor}."
)
# Validate input_size_per_partition
if input_size_per_partition % self.quant_config.min_k_threads != 0:
raise ValueError(
f"Weight input_size_per_partition = "
f"{input_size_per_partition} is not divisible by "
f"min_k_threads = {self.quant_config.min_k_threads}."
)
if (
self.quant_config.group_size != -1
and input_size_per_partition % self.quant_config.group_size != 0
):
raise ValueError(
f"Weight input_size_per_partition = "
f"{input_size_per_partition} is not divisible by "
f"group_size = {self.quant_config.group_size}."
)
# Check that we have at least 4 tiles horizontally in the shard
num_tiles_per_perm = self.quant_config.perm_len // (
self.quant_config.tile_size**2
)
if output_size_per_partition % num_tiles_per_perm != 0:
raise ValueError("Each permutation group must reside on the same gpu")
# Quantized 4Bit weights packed into Int32.
qweight = PackedvLLMParameter(
data=torch.empty(
input_size_per_partition // self.quant_config.tile_size,
output_size_per_partition
* self.quant_config.tile_size
// self.quant_config.pack_factor,
device="cuda",
dtype=torch.int32,
),
input_dim=0,
output_dim=1,
packed_dim=1,
packed_factor=self.quant_config.pack_factor,
marlin_tile_size=self.quant_config.tile_size,
weight_loader=weight_loader,
)
# Determine if channelwise or not
input_groups = (
1
if self.quant_config.group_size == -1
else input_size_per_partition // self.quant_config.group_size
)
weight_scale_args = {
"data": torch.empty(
input_groups,
output_size_per_partition,
device="cuda",
dtype=params_dtype,
),
"weight_loader": weight_loader,
}
if input_groups == 1:
scales = ChannelQuantScaleParameter(output_dim=1, **weight_scale_args)
else:
scales = GroupQuantScaleParameter(
output_dim=1, input_dim=0, **weight_scale_args
)
# Allocate workspace (Used for internal locking mechanism)
max_workspace_size = (
output_size_per_partition // self.quant_config.min_n_threads
) * self.quant_config.max_parallel
workspace = BasevLLMParameter(
data=torch.zeros(max_workspace_size, device="cuda", dtype=torch.int),
weight_loader=weight_loader,
)
layer.register_parameter("B", qweight)
layer.register_parameter("s", scales)
layer.register_parameter("workspace", workspace)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# required by torch.compile
layer.B = torch.nn.Parameter(layer.B.data, requires_grad=False)
layer.s = torch.nn.Parameter(layer.s.data, requires_grad=False)
layer.workspace = torch.nn.Parameter(layer.workspace.data, requires_grad=False)
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
qweight = layer.B
scales = layer.s
workspace = layer.workspace
x_2d = x.view(-1, x.shape[-1])
size_m = x_2d.shape[0]
size_k = x_2d.shape[1]
size_n = scales.shape[1]
output_2d = ops.marlin_gemm(
x_2d, qweight, scales, workspace, size_m, size_n, size_k
)
output = output_2d.view(x.shape[:-1] + (output_2d.shape[1],))
if bias is not None:
output.add_(bias) # In-place add
return output
@@ -19,6 +19,36 @@ from sglang.srt.utils import get_device_capability, set_weight_attrs
logger = logging.getLogger(__name__)
def get_weight_perm(num_bits: int):
perm_list: List[int] = []
for i in range(32):
perm1: List[int] = []
col = i // 4
for block in [0, 1]:
for row in [
2 * (i % 4),
2 * (i % 4) + 1,
2 * (i % 4 + 4),
2 * (i % 4 + 4) + 1,
]:
perm1.append(16 * row + col + 8 * block)
for j in range(4):
perm_list.extend([p + 256 * j for p in perm1])
perm = np.array(perm_list)
if num_bits == 4:
interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7])
elif num_bits == 8:
interleave = np.array([0, 2, 1, 3])
else:
raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))
perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel()
perm = torch.from_numpy(perm)
return perm
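# Illustration (not part of this commit): for both supported bit-widths the helper
# above returns a permutation of the integers 0..1023 (marlin's permutation length),
# just with a different interleaving.
def _demo_get_weight_perm() -> None:
    for bits in (4, 8):
        perm = get_weight_perm(bits)
        assert perm.shape == (1024,)
        assert sorted(perm.tolist()) == list(range(1024))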
class MoeWNA16Config(QuantizationConfig):
    """Config class for MOE WNA16 (W8A16/W4A16) quantization."""
...
# SPDX-License-Identifier: Apache-2.0
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/quant_utils.py
from typing import Optional
import numpy
import torch
from sgl_kernel.scalar_type import ScalarType
def get_pack_factor(num_bits):
assert 32 % num_bits == 0, f"Unsupported num_bits = {num_bits}"
return 32 // num_bits
def pack_cols(
q_w: torch.Tensor,
num_bits: int,
size_k: int,
size_n: int,
):
assert q_w.shape == (size_k, size_n)
pack_factor = get_pack_factor(num_bits)
assert size_n % pack_factor == 0
orig_device = q_w.device
q_w = q_w.cpu().numpy().astype(numpy.uint32)
q_res = numpy.zeros((size_k, size_n // pack_factor), dtype=numpy.uint32)
for i in range(pack_factor):
q_res |= q_w[:, i::pack_factor] << num_bits * i
q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
q_res = q_res.contiguous()
return q_res
def unpack_cols(
packed_q_w: torch.Tensor,
num_bits: int,
size_k: int,
size_n: int,
):
pack_factor = get_pack_factor(num_bits)
assert size_n % pack_factor == 0
assert packed_q_w.shape == (
size_k,
size_n // pack_factor,
), "packed_q_w.shape = {} size_k = {}, size_n = {} pack_Factor = {}".format(
packed_q_w.shape, size_k, size_n, pack_factor
)
orig_device = packed_q_w.device
packed_q_w_cpu = packed_q_w.cpu().numpy().astype(numpy.uint32)
q_res = numpy.zeros((size_k, size_n), dtype=numpy.uint32)
mask = (1 << num_bits) - 1
for i in range(pack_factor):
vals = packed_q_w_cpu & mask
packed_q_w_cpu >>= num_bits
q_res[:, i::pack_factor] = vals
q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
q_res = q_res.contiguous()
return q_res
def quantize_weights(
w: torch.Tensor,
quant_type: ScalarType,
group_size: Optional[int],
zero_points: bool = False,
ref_zero_points_after_scales: bool = False,
):
assert (
quant_type.is_integer()
), "Floating point quantization may work but has not been tested"
assert not zero_points or group_size is not None, (
"to have group zero points, group_size must be provided "
"(-1 group_size is channelwise)"
)
orig_device = w.device
orig_type = w.dtype
size_k, size_n = w.shape
assert w.is_floating_point(), "w must be float"
if group_size == -1:
group_size = size_k
# Reshape to [groupsize, -1]
if group_size is not None and group_size < size_k:
w = w.reshape((-1, group_size, size_n))
w = w.permute(1, 0, 2)
w = w.reshape((group_size, -1))
# Compute scale for each group
max_val = torch.max(w, 0, keepdim=True).values
min_val = torch.min(w, 0, keepdim=True).values
max_q_val = quant_type.max()
min_q_val = quant_type.min()
w_s = torch.Tensor([1.0]).to(w.device) # unscaled case
maybe_w_zp = None
if group_size is not None:
if zero_points:
assert not quant_type.is_signed() and quant_type.max() > 0
w_s = (max_val - min_val).clamp(min=1e-5) / quant_type.max()
maybe_w_zp = (
torch.round(torch.abs(min_val / w_s)).clamp(min_q_val, max_q_val).int()
)
else:
# If the bias is such that there are no possible negative/positive
# values, set the max value to inf to avoid divide by 0
w_s = torch.max(
abs(max_val / (max_q_val if max_q_val != 0 else torch.inf)),
abs(min_val / (min_q_val if min_q_val != 0 else torch.inf)),
)
# Quantize
w_q = torch.round(w / w_s).int() + (maybe_w_zp if zero_points else 0)
w_q = torch.clamp(w_q, min_q_val, max_q_val)
# Compute ref (dequantized)
# For some kernels (namely Machete) the zero-points are applied after the
# scales are applied, for this case computing the reference in similar way
# allows us to use tighter error tolerances in our unit tests.
if ref_zero_points_after_scales and maybe_w_zp is not None:
w_ref = w_q.to(orig_type) * w_s - maybe_w_zp.to(orig_type) * w_s
else:
w_ref = (w_q - (maybe_w_zp if zero_points else 0)).to(orig_type) * w_s
if quant_type.has_bias():
w_q += quant_type.bias
# Restore original shapes
if group_size is not None and group_size < size_k:
def reshape_w(w):
w = w.reshape((group_size, -1, size_n))
w = w.permute(1, 0, 2)
w = w.reshape((size_k, size_n)).contiguous()
return w
w_q = reshape_w(w_q)
w_ref = reshape_w(w_ref)
w_s = w_s.reshape((-1, size_n)).contiguous()
if maybe_w_zp is not None:
maybe_w_zp = maybe_w_zp.reshape((-1, size_n)).contiguous()
maybe_w_zp = maybe_w_zp.to(device=orig_device)
return (
w_ref.to(device=orig_device),
w_q.to(device=orig_device),
w_s if group_size is not None else None,
maybe_w_zp,
)
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/quant_utils.py
from types import MappingProxyType
from typing import List, Mapping, Optional, Tuple, Union

import numpy
import torch

from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
from sglang.srt.layers.quantization.scalar_type import ScalarType
from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_npu

_is_cuda = is_cuda()
@@ -143,3 +145,162 @@ def replace_parameter(
    if not isinstance(new, torch.nn.Parameter):
        new = torch.nn.Parameter(new, requires_grad=False)
    mod.register_parameter(name, torch.nn.Parameter(new, requires_grad=False))
def get_pack_factor(num_bits):
assert 32 % num_bits == 0, f"Unsupported num_bits = {num_bits}"
return 32 // num_bits
def pack_cols(
q_w: torch.Tensor,
num_bits: int,
size_k: int,
size_n: int,
):
assert q_w.shape == (size_k, size_n)
pack_factor = get_pack_factor(num_bits)
assert size_n % pack_factor == 0
orig_device = q_w.device
q_w = q_w.cpu().numpy().astype(numpy.uint32)
q_res = numpy.zeros((size_k, size_n // pack_factor), dtype=numpy.uint32)
for i in range(pack_factor):
q_res |= q_w[:, i::pack_factor] << num_bits * i
q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
q_res = q_res.contiguous()
return q_res
def unpack_cols(
packed_q_w: torch.Tensor,
num_bits: int,
size_k: int,
size_n: int,
):
pack_factor = get_pack_factor(num_bits)
assert size_n % pack_factor == 0
assert packed_q_w.shape == (
size_k,
size_n // pack_factor,
), "packed_q_w.shape = {} size_k = {}, size_n = {} pack_Factor = {}".format(
packed_q_w.shape, size_k, size_n, pack_factor
)
orig_device = packed_q_w.device
packed_q_w_cpu = packed_q_w.cpu().numpy().astype(numpy.uint32)
q_res = numpy.zeros((size_k, size_n), dtype=numpy.uint32)
mask = (1 << num_bits) - 1
for i in range(pack_factor):
vals = packed_q_w_cpu & mask
packed_q_w_cpu >>= num_bits
q_res[:, i::pack_factor] = vals
q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
q_res = q_res.contiguous()
return q_res
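# Illustration (not part of this commit): pack_cols / unpack_cols round-trip a small
# 4-bit weight matrix, shrinking the column dimension by the pack factor (32 // 4 == 8).
def _demo_pack_unpack_cols() -> None:
    size_k, size_n, num_bits = 4, 16, 4
    q_w = torch.randint(0, 2**num_bits, (size_k, size_n), dtype=torch.int32)
    packed = pack_cols(q_w, num_bits, size_k, size_n)
    assert packed.shape == (size_k, size_n // get_pack_factor(num_bits))
    assert torch.equal(unpack_cols(packed, num_bits, size_k, size_n), q_w)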
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/quant_utils.py
def quantize_weights(
w: torch.Tensor,
quant_type: ScalarType,
group_size: Optional[int],
zero_points: bool = False,
ref_zero_points_after_scales: bool = False,
):
assert (
quant_type.is_integer()
), "Floating point quantization may work but has not been tested"
assert not zero_points or group_size is not None, (
"to have group zero points, group_size must be provided "
"(-1 group_size is channelwise)"
)
orig_device = w.device
orig_type = w.dtype
size_k, size_n = w.shape
assert w.is_floating_point(), "w must be float"
if group_size == -1:
group_size = size_k
# Reshape to [groupsize, -1]
if group_size is not None and group_size < size_k:
w = w.reshape((-1, group_size, size_n))
w = w.permute(1, 0, 2)
w = w.reshape((group_size, -1))
# Compute scale for each group
max_val = torch.max(w, 0, keepdim=True).values
min_val = torch.min(w, 0, keepdim=True).values
max_q_val = quant_type.max()
min_q_val = quant_type.min()
w_s = torch.Tensor([1.0]).to(w.device) # unscaled case
maybe_w_zp = None
if group_size is not None:
if zero_points:
assert not quant_type.is_signed() and quant_type.max() > 0
w_s = (max_val - min_val).clamp(min=1e-5) / quant_type.max()
maybe_w_zp = (
torch.round(torch.abs(min_val / w_s)).clamp(min_q_val, max_q_val).int()
)
else:
# If the bias is such that there are no possible negative/positive
# values, set the max value to inf to avoid divide by 0
w_s = torch.max(
abs(max_val / (max_q_val if max_q_val != 0 else torch.inf)),
abs(min_val / (min_q_val if min_q_val != 0 else torch.inf)),
)
# Quantize
w_q = torch.round(w / w_s).int() + (maybe_w_zp if zero_points else 0)
w_q = torch.clamp(w_q, min_q_val, max_q_val)
# Compute ref (dequantized)
# For some kernels (namely Machete) the zero-points are applied after the
# scales are applied, for this case computing the reference in similar way
# allows us to use tighter error tolerances in our unit tests.
if ref_zero_points_after_scales and maybe_w_zp is not None:
w_ref = w_q.to(orig_type) * w_s - maybe_w_zp.to(orig_type) * w_s
else:
w_ref = (w_q - (maybe_w_zp if zero_points else 0)).to(orig_type) * w_s
if quant_type.has_bias():
w_q += quant_type.bias
# Restore original shapes
if group_size is not None and group_size < size_k:
def reshape_w(w):
w = w.reshape((group_size, -1, size_n))
w = w.permute(1, 0, 2)
w = w.reshape((size_k, size_n)).contiguous()
return w
w_q = reshape_w(w_q)
w_ref = reshape_w(w_ref)
w_s = w_s.reshape((-1, size_n)).contiguous()
if maybe_w_zp is not None:
maybe_w_zp = maybe_w_zp.reshape((-1, size_n)).contiguous()
maybe_w_zp = maybe_w_zp.to(device=orig_device)
return (
w_ref.to(device=orig_device),
w_q.to(device=orig_device),
w_s if group_size is not None else None,
maybe_w_zp,
)
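# Usage sketch (illustrative only, not part of the original change): symmetric
# 4-bit groupwise quantization of an fp16 weight using the uint4b8 type from
# scalar_type.py. The helper name below is hypothetical; quantize_weights and
# scalar_types are the objects defined/imported in this package.
def _quantize_weights_example():
    from sglang.srt.layers.quantization.scalar_type import scalar_types

    w = torch.randn(256, 128, dtype=torch.half)
    # No zero points requested, so the returned zero-point tensor is None.
    w_ref, w_q, w_s, w_zp = quantize_weights(w, scalar_types.uint4b8, group_size=128)
    # w_q holds the biased 4-bit codes, w_s has one scale per (group, column).
    assert w_q.shape == (256, 128) and w_s.shape == (2, 128) and w_zp is None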
@@ -2,10 +2,11 @@ import functools
from typing import Optional
import torch
from sgl_kernel.scalar_type import scalar_types
def get_scalar_type(num_bits: int, has_zp: bool):
from sglang.srt.layers.quantization.scalar_type import scalar_types
if has_zp:
assert num_bits == 4
return scalar_types.uint4
...
import math
import numpy as np
import pytest
import torch
from sgl_kernel import awq_marlin_repack
from sgl_kernel.scalar_type import scalar_types
from sglang.srt.layers.quantization.quant_utils import ( from sglang.srt.layers.quantization.scalar_type import scalar_types
from sglang.srt.layers.quantization.utils import (
get_pack_factor,
pack_cols,
quantize_weights,
...
@@ -51,13 +51,12 @@ def check_quant_method(model_path: str, use_marlin_kernel: bool):
model_config=model_config, load_config=load_config, device_config=device_config
)
from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod from sglang.srt.layers.linear import UnquantizedLinearMethod
from vllm.model_executor.layers.quantization.gptq_marlin import ( from sglang.srt.layers.quantization.gptq import (
GPTQLinearMethod,
GPTQMarlinLinearMethod,
)
from sglang.srt.layers.linear import UnquantizedLinearMethod
linear_method_cls = (
GPTQMarlinLinearMethod if use_marlin_kernel else (GPTQLinearMethod)
)
@@ -162,7 +161,7 @@ class TestGPTQModelDynamicWithMarlin(CustomTestCase):
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=["--dtype", "float16"], other_args=["--dtype", "bfloat16"],
)
@classmethod
...
import itertools
import unittest

import torch
from sgl_kernel import fused_marlin_moe
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
    marlin_quantize,
)
from vllm.scalar_type import scalar_types

# per_token_quant_int8 is needed by the w8a8 torch reference below; the import path
# is assumed to be sglang's int8 kernel helpers.
from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
def stack_and_dev(tensors: list[torch.Tensor]):
dev = tensors[0].device
return torch.stack(tensors, dim=0).to(dev)
def torch_moe(a, w1, w2, score, topk, expert_map):
B, D = a.shape
a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
score = torch.softmax(score, dim=-1, dtype=torch.float32)
topk_weight, topk_ids = torch.topk(score, topk)
topk_weight = topk_weight.view(-1)
topk_ids = topk_ids.view(-1)
if expert_map is not None:
topk_ids = expert_map[topk_ids]
for i in range(w1.shape[0]):
mask = topk_ids == i
if mask.sum():
out[mask] = SiluAndMul()(a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(
0, 1
)
return (
out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype)
).sum(dim=1)
def native_w8a8_per_token_matmul(A, B, As, Bs, output_dtype=torch.float16):
"""Matrix multiplication function that supports per-token input quantization and per-column weight quantization"""
A = A.to(torch.float32)
B = B.to(torch.float32)
assert A.shape[-1] == B.shape[-1], "Dimension mismatch"
assert B.ndim == 2 and B.is_contiguous(), "B must be a 2D contiguous tensor"
# Reshape input
M = A.numel() // A.shape[-1]
B = B.t() # Transpose weight matrix
N, K = B.shape
origin_C_shape = A.shape[:-1] + (K,)
A = A.reshape(M, N)
# As is per-token [M, 1], Bs is per-column [1, K]
C = torch.matmul(A, B) # [M, K]
C = As * C * Bs.view(1, -1) # Broadcast per-column scale
return C.reshape(origin_C_shape).to(output_dtype)
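# Self-check sketch (illustrative, not part of the original change): the reference
# matmul above should match an explicit dequantize-then-matmul in float32. The
# helper name is hypothetical and uses only functions defined in this file plus torch.
def _native_w8a8_matmul_selfcheck():
    A = torch.randint(-128, 127, (4, 8), dtype=torch.int8)
    B = torch.randint(-128, 127, (16, 8), dtype=torch.int8)
    As = torch.rand(4, 1) * 0.01  # per-token activation scales
    Bs = torch.rand(16) * 0.01  # per-column weight scales
    ref = (A.float() * As) @ (B.float() * Bs.view(-1, 1)).t()
    out = native_w8a8_per_token_matmul(A, B, As, Bs, output_dtype=torch.float32)
    torch.testing.assert_close(out, ref)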
def torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk):
"""This function performs fused moe with per-column int8 quantization using native torch."""
B, D = a.shape
# Perform per-token quantization
a_q, a_s = per_token_quant_int8(a)
# Repeat tokens to match topk
a_q = a_q.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
# Also repeat the scale
a_s = a_s.view(B, -1, 1).repeat(1, topk, 1).reshape(-1, 1) # [B*topk, 1]
out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
# Calculate routing
score = torch.softmax(score, dim=-1, dtype=torch.float32)
topk_weight, topk_ids = torch.topk(score, topk)
topk_weight = topk_weight.view(-1)
topk_ids = topk_ids.view(-1)
# Process each expert
for i in range(w1.shape[0]):
mask = topk_ids == i
if mask.sum():
# First MLP layer: note that a_s is now per-token
inter_out = native_w8a8_per_token_matmul(
a_q[mask], w1[i], a_s[mask], w1_s[i], output_dtype=a.dtype
)
# Activation function
act_out = SiluAndMul().forward_native(inter_out)
# Quantize activation output with per-token
act_out_q, act_out_s = per_token_quant_int8(act_out)
# Second MLP layer
out[mask] = native_w8a8_per_token_matmul(
act_out_q, w2[i], act_out_s, w2_s[i], output_dtype=a.dtype
)
# Apply routing weights and sum
return (
out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype)
).sum(dim=1)
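# Note: torch_w8a8_per_column_moe is an int8 reference path; the test class below
# exercises the 4-bit marlin path via marlin_fused_moe and does not call it.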
def marlin_fused_moe(
N, E, K, a, w1, w2, num_bits, group_size, act_order, score, topk, ep_size
):
quant_type = scalar_types.uint4b8 if num_bits == 4 else scalar_types.uint8b128
if ep_size > 1:
local_e = E // ep_size
e_ids = torch.randperm(E, device="cuda", dtype=torch.int32)[:local_e]
e_map = torch.full((E,), -1, device="cuda", dtype=torch.int32)
e_map[e_ids] = torch.arange(local_e, device="cuda", dtype=torch.int32)
w1 = w1[e_ids]
w2 = w2[e_ids]
else:
e_map = None
w_ref1_l = []
qweight1_l = []
scales1_l = []
zeros1_l = []
g_idx1_l = []
sort_indices1_l = []
for i in range(w1.shape[0]):
test_perm = torch.randperm(n=K)
quant_res = marlin_quantize(
w1[i].transpose(1, 0), quant_type, group_size, act_order, test_perm
)
w_ref1, qweight1, scales1, g_idx1, sort_indices1, _ = quant_res
w_ref1_l.append(w_ref1.T)
qweight1_l.append(qweight1)
scales1_l.append(scales1)
g_idx1_l.append(g_idx1)
sort_indices1_l.append(sort_indices1)
w_ref1 = stack_and_dev(w_ref1_l)
qweight1 = stack_and_dev(qweight1_l).contiguous()
scales1 = stack_and_dev(scales1_l)
g_idx1 = stack_and_dev(g_idx1_l) if g_idx1_l else None
zeros1 = stack_and_dev(zeros1_l) if zeros1_l else None
sort_indices1 = stack_and_dev(sort_indices1_l) if sort_indices1_l else None
w_ref2_l = []
qweight2_l = []
scales2_l = []
zeros2_l = []
g_idx2_l = []
sort_indices2_l = []
for i in range(w2.shape[0]):
test_perm = torch.randperm(n=N)
quant_res = marlin_quantize(
w2[i].transpose(1, 0), quant_type, group_size, act_order, test_perm
)
w_ref2, qweight2, scales2, g_idx2, sort_indices2, _ = quant_res
w_ref2_l.append(w_ref2.T)
qweight2_l.append(qweight2)
scales2_l.append(scales2)
g_idx2_l.append(g_idx2)
sort_indices2_l.append(sort_indices2)
w_ref2 = stack_and_dev(w_ref2_l)
qweight2 = stack_and_dev(qweight2_l).contiguous()
scales2 = stack_and_dev(scales2_l)
g_idx2 = stack_and_dev(g_idx2_l) if g_idx2_l else None
zeros2 = stack_and_dev(zeros2_l) if zeros2_l else None
sort_indices2 = stack_and_dev(sort_indices2_l) if sort_indices2_l else None
topk_weights, topk_ids = fused_topk(a, score, topk, False)
# topk_weights, topk_ids = FusedMoE.select_experts(
# hidden_states=a,
# router_logits=score,
# top_k=topk,
# num_expert_group=E,
# use_grouped_topk=False,
# renormalize=False,
# topk_group=None,
# )
torch_output = torch_moe(a, w_ref1, w_ref2, score, topk, e_map)
marlin_output = fused_marlin_moe(
a,
qweight1,
qweight2,
scales1,
scales2,
score,
topk_weights,
topk_ids,
global_num_experts=E,
expert_map=e_map,
g_idx1=g_idx1,
g_idx2=g_idx2,
sort_indices1=sort_indices1,
sort_indices2=sort_indices2,
w1_zeros=zeros1,
w2_zeros=zeros2,
num_bits=num_bits,
is_k_full=True,
)
return marlin_output, torch_output
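# marlin_fused_moe returns the fused marlin kernel output together with the torch_moe
# reference output so the caller can compare them directly.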
class TestW8A8Int8FusedMoE(unittest.TestCase):
DTYPES = [torch.float16]
M = [1, 16]
N = [128]
K = [256]
E = [4, 10]
TOP_KS = [2, 4]
BLOCK_SIZE = [[128, 128]]
SEEDS = [0]
NUM_BITS = [4]
EP_SIZE = [1, 4]
@classmethod
def setUpClass(cls):
if not torch.cuda.is_available():
raise unittest.SkipTest("CUDA is not available")
torch.set_default_device("cuda")
def _w4a8_int8_fused_moe(
self, M, N, K, E, topk, block_size, dtype, seed, num_bits, ep_size
):
torch.manual_seed(seed)
a = torch.randn((M, K), dtype=dtype) / 10
# Generate fp16 weights; marlin_quantize compresses them to num_bits inside marlin_fused_moe
w1_fp16 = (torch.rand((E, 2 * N, K), dtype=dtype) - 0.5) * 2
w2_fp16 = (torch.rand((E, K, N), dtype=dtype) - 0.5) * 2
score = torch.randn((M, E), dtype=dtype)
with torch.inference_mode():
marlin_out, ref_out = marlin_fused_moe(
N=N,
E=E,
K=K,
a=a,
w1=w1_fp16,
w2=w2_fp16,
num_bits=num_bits,
group_size=-1,
act_order=False,
score=score,
topk=topk,
ep_size=ep_size,
)
# Check results
if (
torch.mean(
torch.abs(marlin_out.to(torch.float32) - ref_out.to(torch.float32))
)
/ torch.mean(torch.abs(ref_out.to(torch.float32)))
> 0.1
):
print(f"marlin_out: {marlin_out}")
print(f"ref_out: {ref_out}")
print(
torch.mean(
torch.abs(marlin_out.to(torch.float32) - ref_out.to(torch.float32))
)
/ torch.mean(torch.abs(ref_out.to(torch.float32)))
)
torch.testing.assert_close(marlin_out, ref_out, atol=2e-2, rtol=0)
def test_w4a8_int8_fused_moe(self):
for params in itertools.product(
self.M,
self.N,
self.K,
self.E,
self.TOP_KS,
self.BLOCK_SIZE,
self.DTYPES,
self.SEEDS,
self.NUM_BITS,
self.EP_SIZE,
):
with self.subTest(
M=params[0],
N=params[1],
K=params[2],
E=params[3],
topk=params[4],
block_size=params[5],
dtype=params[6],
seed=params[7],
num_bits=params[8],
ep_size=params[9],
):
self._w4a8_int8_fused_moe(*params)
if __name__ == "__main__":
unittest.main(verbosity=2)
import sgl_kernel
import torch
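# Smoke test: invokes the raw sgl_kernel gemm_forward_cuda custom op with placeholder
# random tensors to check that the op is registered and callable; the shapes and dtypes
# here are illustrative rather than a real quantized GEMM configuration.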
x = torch.randn(10, 10, device="cuda")
qweight = torch.randn(10, 10, device="cuda")
s1_scales = torch.randn(10, device="cuda")
input_scales = torch.randn(10, device="cuda")
s1_szeros = torch.randn(10, device="cuda")
input_sum = torch.randn(10, device="cuda")
output_buffer = torch.randn(10, device="cuda")
torch.ops.sgl_kernel.gemm_forward_cuda.default(
x, qweight, s1_scales, input_scales, s1_szeros, input_sum, output_buffer
)