Unverified commit 1f76fc87, authored by Hongbo Xu, committed by GitHub

[3/n] chore: decouple AWQ implementation from vLLM dependency (#8113)


Co-authored-by: AniZpZ <zhuangsen.zp@antgroup.com>
parent 6737671c
@@ -178,6 +178,8 @@ python3 -m sglang.bench_one_batch_server --model None --base-url http://10.0.0.1
### Example: Serving with 8 A100/A800 with AWQ Quantization
**Recommended Usage**
Add the `--quantization moe_wna16` flag to enable the moe_wna16 kernel for better performance.
One example is as follows:
@@ -185,6 +187,13 @@ One example is as follows:
python3 -m sglang.launch_server --model cognitivecomputations/DeepSeek-R1-AWQ --tp 8 --trust-remote-code --quantization moe_wna16
```
Alternatively, you can use `--quantization awq_marlin` as follows:
```bash
python3 -m sglang.launch_server --model cognitivecomputations/DeepSeek-R1-AWQ --tp 8 --trust-remote-code --quantization awq_marlin --dtype float16
```
Note that `awq_marlin` currently supports only `float16`, which may lead to some precision loss.
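After the server is up, a quick sanity check can be run against the OpenAI-compatible endpoint. The snippet below is a minimal sketch; it assumes the default `--port 30000` and the model path used in the examples above:
```python
# Minimal sanity check for a locally launched SGLang server.
# Assumes the default --port 30000 and the model path from the examples above.
import requests

resp = requests.post(
    "http://localhost:30000/v1/chat/completions",
    json={
        "model": "cognitivecomputations/DeepSeek-R1-AWQ",
        "messages": [{"role": "user", "content": "Hello"}],
        "max_tokens": 32,
    },
)
print(resp.json()["choices"][0]["message"]["content"])
```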
### Example: Serving with 16 A100/A800 with int8 Quantization
...
@@ -7,10 +7,6 @@ import torch
try:
from vllm.model_executor.layers.quantization.aqlm import AQLMConfig
from vllm.model_executor.layers.quantization.awq_marlin import (
AWQMarlinConfig,
AWQMoEMethod,
)
from vllm.model_executor.layers.quantization.bitsandbytes import BitsAndBytesConfig
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import (
CompressedTensorsW8A8Fp8MoEMethod,
@@ -36,14 +32,14 @@ except ImportError:
def override_quantization_method(self, *args, **kwargs):
return None
AQLMConfig = BitsAndBytesConfig = CompressedTensorsConfig = DeepSpeedFPConfig = (
ExpertsInt8Config
) = FBGEMMFp8Config = GGUFConfig = GPTQMarlin24Config = MarlinConfig = QQQConfig = (
Int8TpuConfig
) = DummyConfig
from sglang.srt.layers.quantization.awq import AWQConfig, AWQMarlinConfig
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.quantization.blockwise_int8 import BlockInt8Config
from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import (
@@ -63,10 +59,7 @@ from sglang.srt.layers.quantization.modelopt_quant import (
)
from sglang.srt.layers.quantization.moe_wna16 import MoeWNA16Config
from sglang.srt.layers.quantization.qoq import QoQConfig
from sglang.srt.layers.quantization.utils import (
get_dynamic_override,
get_linear_quant_method,
)
from sglang.srt.layers.quantization.utils import get_linear_quant_method
from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config
from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config
from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config
@@ -237,7 +230,6 @@ def monkey_patch_quant_configs():
setattr(GPTQMarlinConfig, "get_quant_method", gptq_get_quant_method)
setattr(GPTQConfig, "get_quant_method", gptq_get_quant_method)
monkey_patch_moe_apply(AWQMoEMethod)
monkey_patch_moe_apply(GPTQMarlinMoEMethod)
monkey_patch_moe_apply(CompressedTensorsW8A8Fp8MoEMethod)
monkey_patch_moe_apply(CompressedTensorsWNA16MoEMethod)
...
@@ -2,21 +2,52 @@
from __future__ import annotations
import logging
import warnings
from typing import Any, Callable, Dict, List, Optional
import torch
from sglang.srt.layers.linear import LinearBase, set_weight_attrs
from sglang.srt.layers.parameter import GroupQuantScaleParameter, PackedvLLMParameter
from sglang.srt.layers.quantization.base_config import (
FusedMoEMethodBase,
LinearMethodBase,
QuantizationConfig,
QuantizeMethodBase,
)
from sglang.srt.layers.quantization.marlin_utils import (
apply_awq_marlin_linear,
awq_to_marlin_zero_points,
check_marlin_supported,
check_marlin_supports_layer,
check_moe_marlin_supports_layer,
marlin_make_empty_g_idx,
marlin_make_workspace,
marlin_moe_permute_scales,
marlin_permute_scales,
moe_awq_to_marlin_zero_points,
verify_marlin_supported,
verify_marlin_supports_shape,
)
from sglang.srt.layers.quantization.scalar_type import scalar_types
from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
from sglang.srt.layers.quantization.utils import replace_parameter
try:
from vllm import _custom_ops as ops
warnings.warn(
"Using kernels directly from vllm. This might lead to performance degradation or "
"missing functionalities, as certain kernels may not be optimized."
)
except ImportError:
ops = None
from sglang.srt.utils import is_cuda
_is_cuda = is_cuda()
if _is_cuda:
from sgl_kernel import awq_dequantize, fused_marlin_moe
logger = logging.getLogger(__name__)
@@ -103,6 +134,176 @@ class AWQConfig(QuantizationConfig):
return None
class AWQMarlinConfig(QuantizationConfig):
"""Config class for AWQ Marlin"""
# num_bits -> type
TYPE_MAP = {
4: scalar_types.uint4,
8: scalar_types.uint8,
}
def __init__(
self,
weight_bits: int,
group_size: int,
zero_point: bool,
lm_head_quantized: bool,
modules_to_not_convert: Optional[list[str]],
full_config: dict[str, Any],
) -> None:
super().__init__()
self.pack_factor = 32 // weight_bits # packed into int32
self.group_size = group_size
self.zero_point = zero_point
self.lm_head_quantized = lm_head_quantized
self.weight_bits = weight_bits
self.modules_to_not_convert = modules_to_not_convert or []
self.full_config = full_config
if self.weight_bits not in self.TYPE_MAP:
raise ValueError(
f"Unsupported num_bits = {self.weight_bits}. "
f"Supported num_bits = {self.TYPE_MAP.keys()}"
)
self.quant_type = self.TYPE_MAP[self.weight_bits]
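# e.g. weight_bits = 4 gives pack_factor = 8 (eight 4-bit values per int32)
# and quant_type = scalar_types.uint4.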
verify_marlin_supported(
self.quant_type, group_size=self.group_size, has_zp=self.zero_point
)
def __repr__(self) -> str:
return (
f"AWQMarlinConfig(quant_type={self.quant_type}, "
f"group_size={self.group_size}, "
f"zero_point={self.zero_point}, "
f"lm_head_quantized={self.lm_head_quantized}, "
f"modules_to_not_convert={self.modules_to_not_convert})"
)
def get_scaled_act_names(self) -> List[str]:
return []
@classmethod
def get_name(cls) -> str:
return "awq_marlin"
@classmethod
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
return [torch.half, torch.bfloat16]
@classmethod
def get_min_capability(cls) -> int:
return 80
@classmethod
def get_config_filenames(cls) -> list[str]:
return ["quantize_config.json"]
@classmethod
def from_config(cls, config: dict[str, Any]) -> AWQMarlinConfig:
weight_bits = cls.get_from_keys(config, ["bits"])
group_size = cls.get_from_keys(config, ["group_size"])
zero_point = cls.get_from_keys(config, ["zero_point"])
lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], default=False)
modules_to_not_convert = cls.get_from_keys_or(
config, ["modules_to_not_convert"], None
)
return cls(
weight_bits,
group_size,
zero_point,
lm_head_quantized,
modules_to_not_convert,
config,
)
@classmethod
def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]:
can_convert = cls.is_awq_marlin_compatible(hf_quant_cfg)
is_valid_user_quant = (
user_quant is None or user_quant == "marlin" or user_quant == "awq_marlin"
)
if can_convert and is_valid_user_quant:
msg = (
"The model is convertible to {} during runtime."
" Using {} kernel.".format(cls.get_name(), cls.get_name())
)
logger.info(msg)
return cls.get_name()
if can_convert and user_quant == "awq":
logger.info(
"Detected that the model can run with awq_marlin"
", however you specified quantization=awq explicitly,"
" so forcing awq. Use quantization=awq_marlin for"
" faster inference"
)
return None
def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> Optional[QuantizeMethodBase]:
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
if isinstance(layer, LinearBase) or (
isinstance(layer, ParallelLMHead) and self.lm_head_quantized
):
if is_layer_skipped_awq(prefix, self.modules_to_not_convert):
return UnquantizedLinearMethod()
# Check if the layer is supported by AWQMarlin.
if not check_marlin_supports_layer(layer, self.group_size):
logger.warning_once(
"Layer '%s' is not supported by AWQMarlin. Falling back to unoptimized AWQ kernels.", # noqa: E501
prefix,
)
return AWQConfig.from_config(self.full_config).get_quant_method(
layer, prefix
)
return AWQMarlinLinearMethod(self)
elif isinstance(layer, FusedMoE):
from sglang.srt.layers.quantization.moe_wna16 import MoeWNA16Config
if not check_moe_marlin_supports_layer(layer, self.group_size):
logger.warning_once(
f"Layer '{prefix}' is not supported by AWQMoeMarlin. "
"Falling back to Moe WNA16 kernels."
)
return MoeWNA16Config.from_config(self.full_config).get_quant_method(
layer, prefix
)
return AWQMoEMethod(self)
return None
@classmethod
def is_awq_marlin_compatible(cls, quant_config: dict[str, Any]):
# Extract data from quant config.
quant_method = quant_config.get("quant_method", "").lower()
num_bits = quant_config.get("bits")
group_size = quant_config.get("group_size")
zero_point = quant_config.get("zero_point")
if not _is_cuda:
return False
if quant_method != "awq":
return False
# If we cannot find the info needed in the config, cannot convert.
if num_bits is None or group_size is None or zero_point is None:
return False
if num_bits not in cls.TYPE_MAP:
return False
return check_marlin_supported(
quant_type=cls.TYPE_MAP[num_bits], group_size=group_size, has_zp=zero_point
)
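# Illustrative usage sketch (comments only, not executed at runtime): given a
# hypothetical AutoAWQ-style quantization config, `is_awq_marlin_compatible`
# decides whether the Marlin path can be used and `from_config` builds the
# config object.
#
#     hf_quant_cfg = {"quant_method": "awq", "bits": 4, "group_size": 128,
#                     "zero_point": True}
#     if AWQMarlinConfig.is_awq_marlin_compatible(hf_quant_cfg):
#         quant_config = AWQMarlinConfig.from_config(hf_quant_cfg)
#     else:
#         quant_config = AWQConfig.from_config(hf_quant_cfg)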
class AWQLinearMethod(LinearMethodBase):
"""Linear method for AWQ.
@@ -204,3 +405,382 @@ class AWQLinearMethod(LinearMethodBase):
if bias is not None:
out.add_(bias)
return out.reshape(out_shape)
class AWQMarlinLinearMethod(LinearMethodBase):
"""Linear method for AWQ Marlin.
Args:
quant_config: The AWQ Marlin quantization config.
"""
def __init__(self, quant_config: AWQMarlinConfig) -> None:
self.quant_config = quant_config
def create_weights(
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: list[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
) -> None:
del output_size
output_size_per_partition = sum(output_partition_sizes)
weight_loader = extra_weight_attrs.get("weight_loader")
# Normalize group_size
if self.quant_config.group_size != -1:
group_size = self.quant_config.group_size
else:
group_size = input_size
verify_marlin_supports_shape(
output_size_per_partition=output_size_per_partition,
input_size_per_partition=input_size_per_partition,
input_size=input_size,
group_size=group_size,
)
qweight = PackedvLLMParameter(
data=torch.empty(
input_size_per_partition,
output_size_per_partition // self.quant_config.pack_factor,
dtype=torch.int32,
),
input_dim=0,
output_dim=1,
packed_dim=1,
packed_factor=self.quant_config.pack_factor,
weight_loader=weight_loader,
)
num_groups = input_size_per_partition // group_size
qzeros = PackedvLLMParameter(
data=torch.empty(
num_groups,
output_size_per_partition // self.quant_config.pack_factor,
dtype=torch.int32,
),
input_dim=0,
output_dim=1,
packed_dim=1,
packed_factor=self.quant_config.pack_factor,
weight_loader=weight_loader,
)
scales = GroupQuantScaleParameter(
data=torch.empty(
num_groups,
output_size_per_partition,
dtype=params_dtype,
),
input_dim=0,
output_dim=1,
weight_loader=weight_loader,
)
layer.register_parameter("qweight", qweight)
layer.register_parameter("qzeros", qzeros)
layer.register_parameter("scales", scales)
layer.input_size_per_partition = input_size_per_partition
layer.output_size_per_partition = output_size_per_partition
layer.num_groups = num_groups
# TODO: update these docs.
# Checkpoints are serialized in AutoAWQ format, which is different from the
# Marlin format. This function is called after the weights are loaded.
# Here, we handle the repacking.
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
device = layer.qweight.device
layer.qweight = torch.nn.Parameter(layer.qweight.data, requires_grad=False)
layer.qzeros = torch.nn.Parameter(layer.qzeros.data, requires_grad=False)
layer.scales = torch.nn.Parameter(layer.scales.data, requires_grad=False)
# Allocate marlin workspace
layer.workspace = marlin_make_workspace(device)
# Repack weights from AWQ format to marlin format.
marlin_qweight = ops.awq_marlin_repack(
layer.qweight,
size_k=layer.input_size_per_partition,
size_n=layer.output_size_per_partition,
num_bits=self.quant_config.quant_type.size_bits,
)
replace_parameter(layer, "qweight", marlin_qweight)
# Permute scales from AWQ format to marlin format.
marlin_scales = marlin_permute_scales(
layer.scales,
size_k=layer.input_size_per_partition,
size_n=layer.output_size_per_partition,
group_size=self.quant_config.group_size,
)
replace_parameter(layer, "scales", marlin_scales)
# Permute zero-points from AWQ format to marlin format.
marlin_zp = awq_to_marlin_zero_points(
layer.qzeros,
size_k=layer.num_groups,
size_n=layer.output_size_per_partition,
num_bits=self.quant_config.quant_type.size_bits,
)
replace_parameter(layer, "qzeros", marlin_zp)
# Not used; placeholder tensors required by the Marlin kernel interface.
layer.g_idx = marlin_make_empty_g_idx(device)
layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return apply_awq_marlin_linear(
input=x,
weight=layer.qweight,
weight_scale=layer.scales,
weight_zp=layer.qzeros,
g_idx=layer.g_idx,
g_idx_sort_indices=layer.g_idx_sort_indices,
workspace=layer.workspace,
quant_type=self.quant_config.quant_type,
output_size_per_partition=layer.output_size_per_partition,
input_size_per_partition=layer.input_size_per_partition,
bias=bias,
)
class AWQMoEMethod(FusedMoEMethodBase):
def __init__(self, quant_config: AWQMarlinConfig):
self.quant_config = quant_config
if self.quant_config.weight_bits != 4:
raise ValueError("AWQMoEMethod only supports 4bit now.")
self.quant_type = scalar_types.uint4
def create_weights(
self,
layer: torch.nn.Module,
num_experts: int,
hidden_size: int,
intermediate_size_per_partition: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
# Delay the import to avoid circular dependency
from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
extra_weight_attrs.update(
{
"is_transposed": True,
"quant_method": FusedMoeWeightScaleSupported.GROUP.value,
}
)
w13_qweight = torch.nn.Parameter(
torch.empty(
num_experts,
hidden_size,
2 * intermediate_size_per_partition // self.quant_config.pack_factor,
dtype=torch.int32,
),
requires_grad=False,
)
layer.register_parameter("w13_qweight", w13_qweight)
set_weight_attrs(w13_qweight, extra_weight_attrs)
w2_qweight = torch.nn.Parameter(
torch.empty(
num_experts,
intermediate_size_per_partition,
hidden_size // self.quant_config.pack_factor,
dtype=torch.int32,
),
requires_grad=False,
)
layer.register_parameter("w2_qweight", w2_qweight)
set_weight_attrs(w2_qweight, extra_weight_attrs)
num_groups_w13 = hidden_size // self.quant_config.group_size
num_groups_w2 = intermediate_size_per_partition // self.quant_config.group_size
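# Example with hypothetical sizes: hidden_size = 4096 and group_size = 128
# give num_groups_w13 = 32; intermediate_size_per_partition = 1024 gives
# num_groups_w2 = 8.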
# WEIGHT_SCALES
# Allocate 2 scales for w1 and w3 respectively.
w13_scales = torch.nn.Parameter(
torch.empty(
num_experts,
num_groups_w13,
intermediate_size_per_partition * 2,
dtype=params_dtype,
),
requires_grad=False,
)
layer.register_parameter("w13_scales", w13_scales)
set_weight_attrs(w13_scales, extra_weight_attrs)
w2_scales = torch.nn.Parameter(
torch.empty(num_experts, num_groups_w2, hidden_size, dtype=params_dtype),
requires_grad=False,
)
layer.register_parameter("w2_scales", w2_scales)
set_weight_attrs(w2_scales, extra_weight_attrs)
# WEIGHT_ZERO_POINT
# Allocate 2 zero points for w1 and w3 respectively.
w13_qzeros = torch.nn.Parameter(
torch.empty(
num_experts,
num_groups_w13,
2 * intermediate_size_per_partition // self.quant_config.pack_factor,
dtype=torch.int32,
),
requires_grad=False,
)
layer.register_parameter("w13_qzeros", w13_qzeros)
set_weight_attrs(w13_qzeros, extra_weight_attrs)
w2_qzeros = torch.nn.Parameter(
torch.empty(
num_experts,
num_groups_w2,
hidden_size // self.quant_config.pack_factor,
dtype=torch.int32,
),
requires_grad=False,
)
layer.register_parameter("w2_qzeros", w2_qzeros)
set_weight_attrs(w2_qzeros, extra_weight_attrs)
device = layer.w13_qweight.device
layer.workspace = marlin_make_workspace(device, 4)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
num_experts = layer.w13_qweight.shape[0]
device = layer.w13_qweight.device
layer.w13_g_idx_sort_indices = torch.nn.Parameter(
torch.empty((num_experts, 0), dtype=torch.int32, device=device),
requires_grad=False,
)
layer.w2_g_idx_sort_indices = torch.nn.Parameter(
torch.empty((num_experts, 0), dtype=torch.int32, device=device),
requires_grad=False,
)
marlin_w13_qweight = ops.awq_marlin_moe_repack(
layer.w13_qweight,
layer.w13_g_idx_sort_indices,
size_k=layer.w13_qweight.shape[1],
size_n=layer.w13_qweight.shape[2] * self.quant_config.pack_factor,
num_bits=self.quant_config.weight_bits,
)
replace_parameter(layer, "w13_qweight", marlin_w13_qweight)
marlin_w2_qweight = ops.awq_marlin_moe_repack(
layer.w2_qweight,
layer.w2_g_idx_sort_indices,
size_k=layer.w2_qweight.shape[1],
size_n=layer.w2_qweight.shape[2] * self.quant_config.pack_factor,
num_bits=self.quant_config.weight_bits,
)
replace_parameter(layer, "w2_qweight", marlin_w2_qweight)
# hidden_size->intermediate_size
marlin_w13_scales = marlin_moe_permute_scales(
s=layer.w13_scales,
size_k=layer.intermediate_size_per_partition,
size_n=layer.w13_scales.shape[2],
group_size=self.quant_config.group_size,
)
replace_parameter(layer, "w13_scales", marlin_w13_scales)
marlin_w2_scales = marlin_moe_permute_scales(
s=layer.w2_scales,
size_k=layer.intermediate_size_per_partition,
size_n=layer.w2_scales.shape[2],
group_size=self.quant_config.group_size,
)
replace_parameter(layer, "w2_scales", marlin_w2_scales)
marlin_w13_zp = moe_awq_to_marlin_zero_points(
layer.w13_qzeros,
size_k=layer.w13_qzeros.shape[1],
size_n=layer.w13_qzeros.shape[2] * self.quant_config.pack_factor,
num_bits=self.quant_config.weight_bits,
)
replace_parameter(layer, "w13_qzeros", marlin_w13_zp)
marlin_w2_zp = moe_awq_to_marlin_zero_points(
layer.w2_qzeros,
size_k=layer.w2_qzeros.shape[1],
size_n=layer.w2_qzeros.shape[2] * self.quant_config.pack_factor,
num_bits=self.quant_config.weight_bits,
)
replace_parameter(layer, "w2_qzeros", marlin_w2_zp)
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
num_fused_shared_experts: int = 0,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
routed_scaling_factor: Optional[float] = None,
) -> torch.Tensor:
# Delay the import to avoid circular dependency
from sglang.srt.layers.moe.topk import select_experts
assert activation == "silu", "Only SiLU activation is supported."
assert (
scoring_func == "softmax"
), "Only softmax score func is supported for now."
# The input must currently be float16
orig_dtype = x.dtype
x = x.half()
topk_weights, topk_ids = select_experts(
hidden_states=x,
router_logits=router_logits,
top_k=top_k,
use_grouped_topk=use_grouped_topk,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
num_fused_shared_experts=num_fused_shared_experts,
custom_routing_function=custom_routing_function,
correction_bias=correction_bias,
routed_scaling_factor=routed_scaling_factor,
)
return fused_marlin_moe(
x,
layer.w13_qweight,
layer.w2_qweight,
layer.w13_scales,
layer.w2_scales,
router_logits,
topk_weights,
topk_ids,
sort_indices1=layer.w13_g_idx_sort_indices,
sort_indices2=layer.w2_g_idx_sort_indices,
w1_zeros=layer.w13_qzeros,
w2_zeros=layer.w2_qzeros,
num_bits=self.quant_config.weight_bits,
).to(orig_dtype)
@@ -11,7 +11,7 @@ import numpy
import torch
from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
from sglang.srt.layers.quantization.scalar_type import ScalarType, scalar_types
from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_npu
if TYPE_CHECKING:
@@ -247,6 +247,36 @@ def get_pack_factor(num_bits):
return 32 // num_bits
def permute_rows(
q_w: torch.Tensor,
w_ref: torch.Tensor,
group_size: int,
test_perm: Optional[torch.Tensor] = None,
):
assert q_w.shape == w_ref.shape
orig_device = q_w.device
k_size, _ = q_w.shape
g_idx = torch.zeros((k_size,), dtype=torch.int32)
for i in range(k_size):
g_idx[i] = i // group_size
# Simulate act_order by doing a random permutation on K
rand_perm = test_perm if test_perm is not None else torch.randperm(k_size)
g_idx = g_idx[rand_perm].contiguous()
q_w = q_w[rand_perm, :].contiguous()
w_ref = w_ref[rand_perm, :].contiguous()
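# Example: with k_size = 4 and group_size = 2, g_idx starts as [0, 0, 1, 1];
# a permutation such as [2, 0, 3, 1] reorders the rows of q_w and w_ref and
# turns g_idx into [1, 0, 1, 0], emulating act_order.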
return (
w_ref.to(device=orig_device),
q_w.to(device=orig_device),
g_idx.to(device=orig_device),
rand_perm.to(device=orig_device),
)
def pack_cols(
q_w: torch.Tensor,
num_bits: int,
@@ -399,3 +429,56 @@ def quantize_weights(
w_s if group_size is not None else None,
maybe_w_zp,
)
SUPPORTED_GPTQ_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128]
SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
def gptq_quantize_weights(
w: torch.Tensor,
quant_type: ScalarType,
group_size: int,
act_order: bool,
test_perm: Optional[torch.Tensor] = None,
):
size_k, _ = w.shape
assert w.is_floating_point(), "w must be float"
assert (
quant_type in SUPPORTED_GPTQ_QUANT_TYPES
), f"Unsupported gptq type = {quant_type}"
assert group_size in SUPPORTED_GROUP_SIZES + [
size_k
], f"Unsupported groupsize = {group_size}"
w_ref, w_q, w_s, _ = quantize_weights(w, quant_type, group_size)
# Apply act_order
g_idx = torch.empty(0, dtype=torch.int, device=w.device)
rand_perm = torch.empty(0, dtype=torch.int, device=w.device)
if act_order:
assert (
group_size < size_k
), "For act_order, groupsize = {} must be less than size_k = {}".format(
group_size, size_k
)
w_ref, w_q, g_idx, rand_perm = permute_rows(w_q, w_ref, group_size, test_perm)
return w_ref, w_q, w_s, g_idx, rand_perm
def sort_weights(q_w: torch.Tensor, g_idx: torch.Tensor):
orig_device = q_w.device
sort_indices = torch.argsort(g_idx).to(dtype=torch.int32) # Sort based on g_idx
g_idx = g_idx[sort_indices].contiguous()
q_w = q_w[sort_indices, :].contiguous()
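# Example: g_idx = [1, 0, 1, 0] sorts to [0, 0, 1, 1] with
# sort_indices = [1, 3, 0, 2].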
return (
q_w.to(device=orig_device),
g_idx.to(device=orig_device),
sort_indices.to(device=orig_device),
)
@@ -355,6 +355,7 @@ class DeepseekV2MoE(nn.Module):
self.shared_experts.gate_up_proj.quant_method, "quant_config"
) and self.shared_experts.gate_up_proj.quant_method.quant_config.get_name() in {
"awq",
"awq_marlin",
"moe_wna16",
}
self.shared_experts_is_int8 = (
@@ -929,7 +930,7 @@ class DeepseekV2AttentionMLA(nn.Module):
has_fused_proj
and hasattr(self.fused_qkv_a_proj_with_mqa.quant_method, "quant_config")
and self.fused_qkv_a_proj_with_mqa.quant_method.quant_config.get_name()
in {"awq", "awq_marlin", "moe_wna16"}
)
self.use_min_latency_fused_a_gemm = (
has_fused_proj
@@ -2551,6 +2552,7 @@ class DeepseekV2ForCausalLM(nn.Module):
cat_dim = 0
if self.quant_config is not None and (
self.quant_config.get_name() == "awq"
or self.quant_config.get_name() == "awq_marlin"
or self.quant_config.get_name() == "moe_wna16"
):
cat_dim = 1
...
import types
from typing import Optional
import pytest
import torch
from sgl_kernel import fused_marlin_moe
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.quantization.scalar_type import ScalarType, scalar_types
from sglang.test.test_marlin_utils import awq_marlin_quantize, marlin_quantize
def stack_and_dev(tensors: list[torch.Tensor]):
dev = tensors[0].device
return torch.stack(tensors, dim=0).to(dev)
def torch_experts(
a: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weight: torch.Tensor,
topk_ids: torch.Tensor,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
quant_dtype: Optional[torch.dtype] = None,
apply_router_weights_on_input: bool = False,
) -> torch.Tensor:
assert (
global_num_experts == -1
or (global_num_experts == w1.shape[0] and expert_map is None)
or (expert_map is not None and global_num_experts == expert_map.shape[0])
)
M, K = a.shape
topk = topk_ids.shape[1]
print("quant_dtype", quant_dtype)
# exit(0)
if apply_router_weights_on_input:
assert topk == 1
a = a * topk_weight.to(a.dtype)
a = a.view(M, -1, K).repeat(1, topk, 1).reshape(-1, K)
out = torch.zeros(M * topk, w2.shape[1], dtype=a.dtype, device=a.device)
num_experts = w1.shape[0]
topk_ids = topk_ids.view(-1)
if expert_map is not None:
topk_ids = expert_map[topk_ids]
f32 = torch.float32
for i in range(num_experts):
mask = topk_ids == i
if mask.sum():
if quant_dtype is None:
tmp1 = a[mask] @ w1[i].transpose(0, 1)
tmp2 = SiluAndMul()(tmp1)
out[mask] = tmp2 @ w2[i].transpose(0, 1)
if apply_router_weights_on_input:
return out
else:
return (
(out.view(M, -1, w2.shape[1]).to(f32) * topk_weight.view(M, -1, 1))
.sum(dim=1)
.to(out.dtype)
)
def torch_moe(
a: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
score: torch.Tensor,
topk: int,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
) -> torch.Tensor:
score = torch.softmax(score, dim=-1, dtype=torch.float32)
topk_weight, topk_ids = torch.topk(score, topk)
return torch_experts(
a, w1, w2, topk_weight, topk_ids, global_num_experts, expert_map
)
def marlin_moe_generate_valid_test_cases():
import itertools
m_list = [1, 123, 666]
n_list = [128, 1024]
k_list = [256, 2048]
e_list = [4, 12]
topk_list = [2, 3]
dtype_list = [torch.half, torch.bfloat16]
group_size_list = [128]
act_order_list = [True, False]
quant_type_list = [
scalar_types.uint4,
scalar_types.uint4b8,
]
is_k_full_list = [True, False]
all_combinations = itertools.product(
m_list,
n_list,
k_list,
e_list,
topk_list,
dtype_list,
group_size_list,
act_order_list,
quant_type_list,
is_k_full_list,
)
def is_valid(
m, n, k, e, topk, dtype, group_size, act_order, quant_type, is_k_full
):
# Filter act_order
if act_order:
if group_size in (-1, k, n):
return False
if quant_type not in [scalar_types.uint4b8]:
return False
elif not is_k_full:
return False
return True
cases = []
for case in all_combinations:
if is_valid(*case):
cases.append(case)
return cases
@pytest.mark.flaky(reruns=2)
@pytest.mark.parametrize(
("m, n, k, e, topk, dtype, group_size," "act_order, quant_type, is_k_full"),
marlin_moe_generate_valid_test_cases(),
)
def test_fused_marlin_moe(
m: int,
n: int,
k: int,
e: int,
topk: int,
dtype: torch.dtype,
group_size: int,
act_order: bool,
quant_type: ScalarType,
is_k_full: bool,
):
if not torch.cuda.is_available():
pytest.skip("CUDA device not available")
torch.manual_seed(0)
has_zp = quant_type in [scalar_types.uint4, scalar_types.uint8]
# Filter act_order
if act_order:
if group_size == -1:
return
if group_size in (k, n):
return
if has_zp:
return
else:
if not is_k_full:
return
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 20
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 20
e_map = None
w_ref1_l = []
qweight1_l = []
scales1_l = []
zeros1_l = []
g_idx1_l = []
sort_indices1_l = []
for i in range(w1.shape[0]):
if has_zp:
w_ref1, qweight1, scales1, zeros1 = awq_marlin_quantize(
w1[i].transpose(1, 0), quant_type, group_size
)
w_ref1_l.append(w_ref1.T)
qweight1_l.append(qweight1)
scales1_l.append(scales1)
zeros1_l.append(zeros1)
else:
test_perm = torch.randperm(k)
w_ref1, qweight1, scales1, g_idx1, sort_indices1, _ = marlin_quantize(
w1[i].transpose(1, 0), quant_type, group_size, act_order, test_perm
)
w_ref1_l.append(w_ref1.T)
qweight1_l.append(qweight1)
scales1_l.append(scales1)
g_idx1_l.append(g_idx1)
sort_indices1_l.append(sort_indices1)
w_ref1 = stack_and_dev(w_ref1_l)
qweight1 = stack_and_dev(qweight1_l).contiguous()
scales1 = stack_and_dev(scales1_l)
g_idx1 = stack_and_dev(g_idx1_l) if g_idx1_l else None
zeros1 = stack_and_dev(zeros1_l) if zeros1_l else None
sort_indices1 = stack_and_dev(sort_indices1_l) if sort_indices1_l else None
w_ref2_l = []
qweight2_l = []
scales2_l = []
zeros2_l = []
g_idx2_l = []
sort_indices2_l = []
for i in range(w2.shape[0]):
if has_zp:
w_ref2, qweight2, scales2, zeros2 = awq_marlin_quantize(
w2[i].transpose(1, 0), quant_type, group_size
)
w_ref2_l.append(w_ref2.T)
qweight2_l.append(qweight2)
scales2_l.append(scales2)
zeros2_l.append(zeros2)
else:
test_perm = torch.randperm(n)
w_ref2, qweight2, scales2, g_idx2, sort_indices2, _ = marlin_quantize(
w2[i].transpose(1, 0), quant_type, group_size, act_order, test_perm
)
w_ref2_l.append(w_ref2.T)
qweight2_l.append(qweight2)
scales2_l.append(scales2)
g_idx2_l.append(g_idx2)
sort_indices2_l.append(sort_indices2)
w_ref2 = stack_and_dev(w_ref2_l)
qweight2 = stack_and_dev(qweight2_l).contiguous()
scales2 = stack_and_dev(scales2_l)
g_idx2 = stack_and_dev(g_idx2_l) if g_idx2_l else None
zeros2 = stack_and_dev(zeros2_l) if zeros2_l else None
sort_indices2 = stack_and_dev(sort_indices2_l) if sort_indices2_l else None
score = torch.randn((m, e), device="cuda", dtype=dtype)
from sglang.srt.layers.moe.topk import fused_topk_torch_native
topk_weights, topk_ids = fused_topk_torch_native(a, score, topk, False)
torch_output = torch_moe(a, w_ref1, w_ref2, score, topk, expert_map=e_map)
marlin_output = fused_marlin_moe(
a,
qweight1,
qweight2,
scales1,
scales2,
score,
topk_weights,
topk_ids,
g_idx1=g_idx1,
g_idx2=g_idx2,
sort_indices1=sort_indices1,
sort_indices2=sort_indices2,
w1_zeros=zeros1,
w2_zeros=zeros2,
num_bits=4,
is_k_full=is_k_full,
)
torch.testing.assert_close(marlin_output, torch_output, atol=5e-2, rtol=0)
if __name__ == "__main__":
# Run the specific test function directly
pytest.main([__file__])
"""
Adapted from
https://github.com/vllm-project/vllm/blob/020f58abcdea65302225663130d08fd8f4dd755a/vllm/model_executor/layers/quantization/utils/marlin_utils_test.py
"""
# SPDX-License-Identifier: Apache-2.0
"""Utility functions used for tests and benchmarks"""
from typing import Optional
import numpy as np
import torch
from sglang.srt.layers.quantization.marlin_utils import (
GPTQ_MARLIN_TILE,
marlin_permute_scales,
marlin_zero_points,
)
from sglang.srt.layers.quantization.scalar_type import ScalarType
from sglang.srt.layers.quantization.utils import (
get_pack_factor,
gptq_quantize_weights,
quantize_weights,
sort_weights,
)
class MarlinWorkspace:
def __init__(self, out_features, min_thread_n, max_parallel):
assert (
out_features % min_thread_n == 0
), "out_features = {} is undivisible by min_thread_n = {}".format(
out_features, min_thread_n
)
max_workspace_size = (out_features // min_thread_n) * max_parallel
self.scratch = torch.zeros(max_workspace_size, dtype=torch.int, device="cuda")
def marlin_permute_weights(q_w, size_k, size_n, perm, tile=GPTQ_MARLIN_TILE):
assert q_w.shape == (size_k, size_n)
assert size_k % tile == 0, f"size_k = {size_k}, tile = {tile}"
assert size_n % tile == 0, f"size_n = {size_n}, tile = {tile}"
# Permute weights to 16x64 marlin tiles
q_w = q_w.reshape((size_k // tile, tile, size_n // tile, tile))
q_w = q_w.permute((0, 2, 1, 3))
q_w = q_w.reshape((size_k // tile, size_n * tile))
q_w = q_w.reshape((-1, perm.numel()))[:, perm].reshape(q_w.shape)
return q_w
def marlin_weights(q_w, size_k, size_n, num_bits, perm):
# Permute
q_w = marlin_permute_weights(q_w, size_k, size_n, perm)
# Pack
pack_factor = get_pack_factor(num_bits)
orig_device = q_w.device
q_w = q_w.cpu().numpy().astype(np.uint32)
q_packed = np.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), dtype=np.uint32)
for i in range(pack_factor):
q_packed |= q_w[:, i::pack_factor] << num_bits * i
q_packed = torch.from_numpy(q_packed.astype(np.int32)).to(orig_device)
return q_packed
def get_weight_perm(num_bits: int):
perm_list: list[int] = []
for i in range(32):
perm1: list[int] = []
col = i // 4
for block in [0, 1]:
for row in [
2 * (i % 4),
2 * (i % 4) + 1,
2 * (i % 4 + 4),
2 * (i % 4 + 4) + 1,
]:
perm1.append(16 * row + col + 8 * block)
for j in range(4):
perm_list.extend([p + 256 * j for p in perm1])
perm = np.array(perm_list)
if num_bits == 4:
interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7])
elif num_bits == 8:
interleave = np.array([0, 2, 1, 3])
else:
raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))
perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel()
perm = torch.from_numpy(perm)
return perm
def marlin_quantize(
w: torch.Tensor,
quant_type: ScalarType,
group_size: int,
act_order: bool,
test_perm: Optional[torch.Tensor] = None,
):
size_k, size_n = w.shape
num_bits = quant_type.size_bits
# Normalize group_size
if group_size == -1:
group_size = size_k
assert group_size <= size_k
# Quantize (and apply act_order if provided)
w_ref, q_w, s, g_idx, rand_perm = gptq_quantize_weights(
w, quant_type, group_size, act_order, test_perm
)
# For act_order, sort the "weights" and "g_idx" so that group ids are
# increasing
sort_indices = torch.empty(0, dtype=torch.int, device=w.device)
if act_order:
q_w, g_idx, sort_indices = sort_weights(q_w, g_idx)
# Reformat to marlin
weight_perm = get_weight_perm(num_bits)
marlin_q_w = marlin_weights(q_w, size_k, size_n, num_bits, weight_perm)
marlin_s = marlin_permute_scales(s, size_k, size_n, group_size)
# Create result
res_list = [w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, rand_perm]
for i in range(len(res_list)):
res_list[i] = res_list[i].to(w.device)
return res_list
def awq_marlin_quantize(w: torch.Tensor, quant_type: ScalarType, group_size: int):
size_k, size_n = w.shape
# Normalize group_size
if group_size == -1:
group_size = size_k
assert group_size <= size_k
# Detect num groups
assert size_k % group_size == 0
num_groups = size_k // group_size
# Quantize with zp
w_ref, q_w, s, zp = quantize_weights(w, quant_type, group_size, zero_points=True)
# Reformat to marlin
weight_perm = get_weight_perm(quant_type.size_bits)
marlin_q_w = marlin_weights(q_w, size_k, size_n, quant_type.size_bits, weight_perm)
marlin_s = marlin_permute_scales(s, size_k, size_n, group_size)
marlin_zp = marlin_zero_points(zp, num_groups, size_n, quant_type.size_bits)
# Create result
res_list = [w_ref, marlin_q_w, marlin_s, marlin_zp]
for i in range(len(res_list)):
res_list[i] = res_list[i].to(w.device)
return res_list
@@ -24,7 +24,7 @@ def check_quant_method(model_path: str, use_marlin_kernel: bool):
set_custom_all_reduce,
)
from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state
from sglang.srt.layers.quantization.utils import get_dynamic_override
from sglang.srt.model_loader import get_model
from sglang.srt.server_args import PortArgs, ServerArgs
...