"torchvision/models/vscode:/vscode.git/clone" did not exist on "5906d5908032ff8aa905d94fb97eb36482e51731"
Unverified commit 1f76fc87 authored by Hongbo Xu, committed by GitHub

[3/n] chore: decouple AWQ implementation from vLLM dependency (#8113)


Co-authored-by: AniZpZ <zhuangsen.zp@antgroup.com>
parent 6737671c
......@@ -178,6 +178,8 @@ python3 -m sglang.bench_one_batch_server --model None --base-url http://10.0.0.1
### Example: Serving with 8 A100/A800 with AWQ Quantization
**Recommended Usage**
Add the `--quantization moe_wna16` flag to enable the MoE WNA16 kernel for better performance.
One example is as follows:
......@@ -185,6 +187,13 @@ One example is as follows:
python3 -m sglang.launch_server --model cognitivecomputations/DeepSeek-R1-AWQ --tp 8 --trust-remote-code --quantization moe_wna16
```
Alternatively, you can use `--quantization awq_marlin` as follows:
```bash
python3 -m sglang.launch_server --model cognitivecomputations/DeepSeek-R1-AWQ --tp 8 --trust-remote-code --quantization awq_marlin --dtype float16
```
Note that `awq_marlin` currently only supports `float16`, which may lead to some precision loss.
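Once the server is up, you can issue a quick request to verify that it responds. The snippet below is a minimal sketch; the default port `30000` and the exact payload fields are assumptions based on the standard `/generate` endpoint:
```bash
curl http://localhost:30000/generate \
  -H "Content-Type: application/json" \
  -d '{"text": "The capital of France is", "sampling_params": {"max_new_tokens": 16, "temperature": 0}}'
```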
### Example: Serving with 16 A100/A800 with int8 Quantization
......
......@@ -7,10 +7,6 @@ import torch
try:
from vllm.model_executor.layers.quantization.aqlm import AQLMConfig
from vllm.model_executor.layers.quantization.awq_marlin import (
AWQMarlinConfig,
AWQMoEMethod,
)
from vllm.model_executor.layers.quantization.bitsandbytes import BitsAndBytesConfig
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import (
CompressedTensorsW8A8Fp8MoEMethod,
......@@ -36,14 +32,14 @@ except ImportError:
def override_quantization_method(self, *args, **kwargs):
return None
AQLMConfig = AWQMarlinConfig = BitsAndBytesConfig = CompressedTensorsConfig = (
DeepSpeedFPConfig
) = ExpertsInt8Config = FBGEMMFp8Config = GGUFConfig = GPTQMarlin24Config = (
MarlinConfig
) = QQQConfig = Int8TpuConfig = DummyConfig
AQLMConfig = BitsAndBytesConfig = CompressedTensorsConfig = DeepSpeedFPConfig = (
ExpertsInt8Config
) = FBGEMMFp8Config = GGUFConfig = GPTQMarlin24Config = MarlinConfig = QQQConfig = (
Int8TpuConfig
) = DummyConfig
from sglang.srt.layers.quantization.awq import AWQConfig
from sglang.srt.layers.quantization.awq import AWQConfig, AWQMarlinConfig
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.quantization.blockwise_int8 import BlockInt8Config
from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import (
......@@ -63,10 +59,7 @@ from sglang.srt.layers.quantization.modelopt_quant import (
)
from sglang.srt.layers.quantization.moe_wna16 import MoeWNA16Config
from sglang.srt.layers.quantization.qoq import QoQConfig
from sglang.srt.layers.quantization.utils import (
get_dynamic_override,
get_linear_quant_method,
)
from sglang.srt.layers.quantization.utils import get_linear_quant_method
from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config
from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config
from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config
......@@ -237,7 +230,6 @@ def monkey_patch_quant_configs():
setattr(GPTQMarlinConfig, "get_quant_method", gptq_get_quant_method)
setattr(GPTQConfig, "get_quant_method", gptq_get_quant_method)
monkey_patch_moe_apply(AWQMoEMethod)
monkey_patch_moe_apply(GPTQMarlinMoEMethod)
monkey_patch_moe_apply(CompressedTensorsW8A8Fp8MoEMethod)
monkey_patch_moe_apply(CompressedTensorsWNA16MoEMethod)
......
......@@ -2,21 +2,52 @@
from __future__ import annotations
import logging
from typing import Any, Dict, List, Optional
import warnings
from typing import Any, Callable, Dict, List, Optional
import torch
from sglang.srt.layers.linear import LinearBase, set_weight_attrs
from sglang.srt.layers.parameter import GroupQuantScaleParameter, PackedvLLMParameter
from sglang.srt.layers.quantization.base_config import (
FusedMoEMethodBase,
LinearMethodBase,
QuantizationConfig,
QuantizeMethodBase,
)
from sglang.srt.layers.quantization.marlin_utils import (
apply_awq_marlin_linear,
awq_to_marlin_zero_points,
check_marlin_supported,
check_marlin_supports_layer,
check_moe_marlin_supports_layer,
marlin_make_empty_g_idx,
marlin_make_workspace,
marlin_moe_permute_scales,
marlin_permute_scales,
moe_awq_to_marlin_zero_points,
verify_marlin_supported,
verify_marlin_supports_shape,
)
from sglang.srt.layers.quantization.scalar_type import scalar_types
from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
from sglang.srt.layers.quantization.utils import replace_parameter
try:
from vllm import _custom_ops as ops
warnings.warn(
"Using kernels directly from vLLM. This might lead to performance degradation or "
"missing functionality, as certain kernels may not be optimized."
)
except ImportError:
ops = None
from sglang.srt.utils import is_cuda
_is_cuda = is_cuda()
if _is_cuda:
from sgl_kernel import awq_dequantize
from sgl_kernel import awq_dequantize, fused_marlin_moe
logger = logging.getLogger(__name__)
......@@ -103,6 +134,176 @@ class AWQConfig(QuantizationConfig):
return None
class AWQMarlinConfig(QuantizationConfig):
"""Config class for AWQ Marlin"""
# num_bits -> type
TYPE_MAP = {
4: scalar_types.uint4,
8: scalar_types.uint8,
}
def __init__(
self,
weight_bits: int,
group_size: int,
zero_point: bool,
lm_head_quantized: bool,
modules_to_not_convert: Optional[list[str]],
full_config: dict[str, Any],
) -> None:
super().__init__()
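# e.g. weight_bits=4 gives pack_factor=8, i.e. eight 4-bit weights per int32.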
self.pack_factor = 32 // weight_bits # packed into int32
self.group_size = group_size
self.zero_point = zero_point
self.lm_head_quantized = lm_head_quantized
self.weight_bits = weight_bits
self.modules_to_not_convert = modules_to_not_convert or []
self.full_config = full_config
if self.weight_bits not in self.TYPE_MAP:
raise ValueError(
f"Unsupported num_bits = {self.weight_bits}. "
f"Supported num_bits = {self.TYPE_MAP.keys()}"
)
self.quant_type = self.TYPE_MAP[self.weight_bits]
verify_marlin_supported(
self.quant_type, group_size=self.group_size, has_zp=self.zero_point
)
def __repr__(self) -> str:
return (
f"AWQMarlinConfig(quant_type={self.quant_type}, "
f"group_size={self.group_size}, "
f"zero_point={self.zero_point}, "
f"lm_head_quantized={self.lm_head_quantized}, "
f"modules_to_not_convert={self.modules_to_not_convert})"
)
def get_scaled_act_names(self) -> List[str]:
return []
@classmethod
def get_name(cls) -> str:
return "awq_marlin"
@classmethod
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
return [torch.half, torch.bfloat16]
@classmethod
def get_min_capability(cls) -> int:
return 80
@classmethod
def get_config_filenames(cls) -> list[str]:
return ["quantize_config.json"]
@classmethod
def from_config(cls, config: dict[str, Any]) -> AWQMarlinConfig:
weight_bits = cls.get_from_keys(config, ["bits"])
group_size = cls.get_from_keys(config, ["group_size"])
zero_point = cls.get_from_keys(config, ["zero_point"])
lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], default=False)
modules_to_not_convert = cls.get_from_keys_or(
config, ["modules_to_not_convert"], None
)
return cls(
weight_bits,
group_size,
zero_point,
lm_head_quantized,
modules_to_not_convert,
config,
)
@classmethod
def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]:
can_convert = cls.is_awq_marlin_compatible(hf_quant_cfg)
is_valid_user_quant = (
user_quant is None or user_quant == "marlin" or user_quant == "awq_marlin"
)
if can_convert and is_valid_user_quant:
msg = (
"The model is convertible to {} during runtime."
" Using {} kernel.".format(cls.get_name(), cls.get_name())
)
logger.info(msg)
return cls.get_name()
if can_convert and user_quant == "awq":
logger.info(
"Detected that the model can run with awq_marlin; "
"however, you explicitly specified quantization=awq, so forcing awq. "
"Use quantization=awq_marlin for faster inference."
)
return None
def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> Optional[QuantizeMethodBase]:
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
if isinstance(layer, LinearBase) or (
isinstance(layer, ParallelLMHead) and self.lm_head_quantized
):
if is_layer_skipped_awq(prefix, self.modules_to_not_convert):
return UnquantizedLinearMethod()
# Check if the layer is supported by AWQMarlin.
if not check_marlin_supports_layer(layer, self.group_size):
logger.warning_once(
"Layer '%s' is not supported by AWQMarlin. Falling back to unoptimized AWQ kernels.", # noqa: E501
prefix,
)
return AWQConfig.from_config(self.full_config).get_quant_method(
layer, prefix
)
return AWQMarlinLinearMethod(self)
elif isinstance(layer, FusedMoE):
from sglang.srt.layers.quantization.moe_wna16 import MoeWNA16Config
if not check_moe_marlin_supports_layer(layer, self.group_size):
logger.warning_once(
f"Layer '{prefix}' is not supported by AWQMoeMarlin. "
"Falling back to Moe WNA16 kernels."
)
return MoeWNA16Config.from_config(self.full_config).get_quant_method(
layer, prefix
)
return AWQMoEMethod(self)
return None
@classmethod
def is_awq_marlin_compatible(cls, quant_config: dict[str, Any]):
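# A typical AutoAWQ quantization config looks roughly like (values illustrative):
#   {"quant_method": "awq", "bits": 4, "group_size": 128, "zero_point": true}
# On CUDA, such a config passes this check and is upgraded to the awq_marlin path.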
# Extract data from quant config.
quant_method = quant_config.get("quant_method", "").lower()
num_bits = quant_config.get("bits")
group_size = quant_config.get("group_size")
zero_point = quant_config.get("zero_point")
if not _is_cuda:
return False
if quant_method != "awq":
return False
# If we cannot find the info needed in the config, cannot convert.
if num_bits is None or group_size is None or zero_point is None:
return False
if num_bits not in cls.TYPE_MAP:
return False
return check_marlin_supported(
quant_type=cls.TYPE_MAP[num_bits], group_size=group_size, has_zp=zero_point
)
class AWQLinearMethod(LinearMethodBase):
"""Linear method for AWQ.
......@@ -204,3 +405,382 @@ class AWQLinearMethod(LinearMethodBase):
if bias is not None:
out.add_(bias)
return out.reshape(out_shape)
class AWQMarlinLinearMethod(LinearMethodBase):
"""Linear method for AWQ Marlin.
Args:
quant_config: The AWQ Marlin quantization config.
"""
def __init__(self, quant_config: AWQMarlinConfig) -> None:
self.quant_config = quant_config
def create_weights(
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: list[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
) -> None:
del output_size
output_size_per_partition = sum(output_partition_sizes)
weight_loader = extra_weight_attrs.get("weight_loader")
# Normalize group_size
if self.quant_config.group_size != -1:
group_size = self.quant_config.group_size
else:
group_size = input_size
verify_marlin_supports_shape(
output_size_per_partition=output_size_per_partition,
input_size_per_partition=input_size_per_partition,
input_size=input_size,
group_size=group_size,
)
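# qweight follows the AutoAWQ layout: int32 storage with pack_factor quantized values packed along the output (N) dimension.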
qweight = PackedvLLMParameter(
data=torch.empty(
input_size_per_partition,
output_size_per_partition // self.quant_config.pack_factor,
dtype=torch.int32,
),
input_dim=0,
output_dim=1,
packed_dim=1,
packed_factor=self.quant_config.pack_factor,
weight_loader=weight_loader,
)
num_groups = input_size_per_partition // group_size
qzeros = PackedvLLMParameter(
data=torch.empty(
num_groups,
output_size_per_partition // self.quant_config.pack_factor,
dtype=torch.int32,
),
input_dim=0,
output_dim=1,
packed_dim=1,
packed_factor=self.quant_config.pack_factor,
weight_loader=weight_loader,
)
scales = GroupQuantScaleParameter(
data=torch.empty(
num_groups,
output_size_per_partition,
dtype=params_dtype,
),
input_dim=0,
output_dim=1,
weight_loader=weight_loader,
)
layer.register_parameter("qweight", qweight)
layer.register_parameter("qzeros", qzeros)
layer.register_parameter("scales", scales)
layer.input_size_per_partition = input_size_per_partition
layer.output_size_per_partition = output_size_per_partition
layer.num_groups = num_groups
# TODO: Update these docs.
# Checkpoints are serialized in AutoAWQ format, which differs from the
# Marlin format. This function is called after the weights are loaded;
# here we handle the repacking.
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
device = layer.qweight.device
layer.qweight = torch.nn.Parameter(layer.qweight.data, requires_grad=False)
layer.qzeros = torch.nn.Parameter(layer.qzeros.data, requires_grad=False)
layer.scales = torch.nn.Parameter(layer.scales.data, requires_grad=False)
# Allocate marlin workspace
layer.workspace = marlin_make_workspace(device)
# Repack weights from AWQ format to marlin format.
marlin_qweight = ops.awq_marlin_repack(
layer.qweight,
size_k=layer.input_size_per_partition,
size_n=layer.output_size_per_partition,
num_bits=self.quant_config.quant_type.size_bits,
)
replace_parameter(layer, "qweight", marlin_qweight)
# Permute scales from AWQ format to marlin format.
marlin_scales = marlin_permute_scales(
layer.scales,
size_k=layer.input_size_per_partition,
size_n=layer.output_size_per_partition,
group_size=self.quant_config.group_size,
)
replace_parameter(layer, "scales", marlin_scales)
# Permute zero-points from AWQ format to marlin format.
marlin_zp = awq_to_marlin_zero_points(
layer.qzeros,
size_k=layer.num_groups,
size_n=layer.output_size_per_partition,
num_bits=self.quant_config.quant_type.size_bits,
)
replace_parameter(layer, "qzeros", marlin_zp)
# Not used (AWQ has no act_order); allocate empty placeholders for the Marlin kernel.
layer.g_idx = marlin_make_empty_g_idx(device)
layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return apply_awq_marlin_linear(
input=x,
weight=layer.qweight,
weight_scale=layer.scales,
weight_zp=layer.qzeros,
g_idx=layer.g_idx,
g_idx_sort_indices=layer.g_idx_sort_indices,
workspace=layer.workspace,
quant_type=self.quant_config.quant_type,
output_size_per_partition=layer.output_size_per_partition,
input_size_per_partition=layer.input_size_per_partition,
bias=bias,
)
class AWQMoEMethod(FusedMoEMethodBase):
def __init__(self, quant_config: AWQMarlinConfig):
self.quant_config = quant_config
if self.quant_config.weight_bits != 4:
raise ValueError("AWQMoEMethod only supports 4bit now.")
self.quant_type = scalar_types.uint4
def create_weights(
self,
layer: torch.nn.Module,
num_experts: int,
hidden_size: int,
intermediate_size_per_partition: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
# Delay the import to avoid circular dependency
from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
extra_weight_attrs.update(
{
"is_transposed": True,
"quant_method": FusedMoeWeightScaleSupported.GROUP.value,
}
)
w13_qweight = torch.nn.Parameter(
torch.empty(
num_experts,
hidden_size,
2 * intermediate_size_per_partition // self.quant_config.pack_factor,
dtype=torch.int32,
),
requires_grad=False,
)
layer.register_parameter("w13_qweight", w13_qweight)
set_weight_attrs(w13_qweight, extra_weight_attrs)
w2_qweight = torch.nn.Parameter(
torch.empty(
num_experts,
intermediate_size_per_partition,
hidden_size // self.quant_config.pack_factor,
dtype=torch.int32,
),
requires_grad=False,
)
layer.register_parameter("w2_qweight", w2_qweight)
set_weight_attrs(w2_qweight, extra_weight_attrs)
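# Group quantization runs along the K dimension: hidden_size for w13, the intermediate size for w2.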
num_groups_w13 = hidden_size // self.quant_config.group_size
num_groups_w2 = intermediate_size_per_partition // self.quant_config.group_size
# WEIGHT_SCALES
# Allocate 2 scales for w1 and w3 respectively.
w13_scales = torch.nn.Parameter(
torch.empty(
num_experts,
num_groups_w13,
intermediate_size_per_partition * 2,
dtype=params_dtype,
),
requires_grad=False,
)
layer.register_parameter("w13_scales", w13_scales)
set_weight_attrs(w13_scales, extra_weight_attrs)
w2_scales = torch.nn.Parameter(
torch.empty(num_experts, num_groups_w2, hidden_size, dtype=params_dtype),
requires_grad=False,
)
layer.register_parameter("w2_scales", w2_scales)
set_weight_attrs(w2_scales, extra_weight_attrs)
# WEIGHT_ZERO_POINT
# Allocate 2 zero points for w1 and w3 respectively.
w13_qzeros = torch.nn.Parameter(
torch.empty(
num_experts,
num_groups_w13,
2 * intermediate_size_per_partition // self.quant_config.pack_factor,
dtype=torch.int32,
),
requires_grad=False,
)
layer.register_parameter("w13_qzeros", w13_qzeros)
set_weight_attrs(w13_qzeros, extra_weight_attrs)
w2_qzeros = torch.nn.Parameter(
torch.empty(
num_experts,
num_groups_w2,
hidden_size // self.quant_config.pack_factor,
dtype=torch.int32,
),
requires_grad=False,
)
layer.register_parameter("w2_qzeros", w2_qzeros)
set_weight_attrs(w2_qzeros, extra_weight_attrs)
device = layer.w13_qweight.device
layer.workspace = marlin_make_workspace(device, 4)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
num_experts = layer.w13_qweight.shape[0]
device = layer.w13_qweight.device
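# AWQ does not use act_order, so the g_idx sort indices are empty placeholders expected by the Marlin MoE repack and kernel APIs.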
layer.w13_g_idx_sort_indices = torch.nn.Parameter(
torch.empty((num_experts, 0), dtype=torch.int32, device=device),
requires_grad=False,
)
layer.w2_g_idx_sort_indices = torch.nn.Parameter(
torch.empty((num_experts, 0), dtype=torch.int32, device=device),
requires_grad=False,
)
marlin_w13_qweight = ops.awq_marlin_moe_repack(
layer.w13_qweight,
layer.w13_g_idx_sort_indices,
size_k=layer.w13_qweight.shape[1],
size_n=layer.w13_qweight.shape[2] * self.quant_config.pack_factor,
num_bits=self.quant_config.weight_bits,
)
replace_parameter(layer, "w13_qweight", marlin_w13_qweight)
marlin_w2_qweight = ops.awq_marlin_moe_repack(
layer.w2_qweight,
layer.w2_g_idx_sort_indices,
size_k=layer.w2_qweight.shape[1],
size_n=layer.w2_qweight.shape[2] * self.quant_config.pack_factor,
num_bits=self.quant_config.weight_bits,
)
replace_parameter(layer, "w2_qweight", marlin_w2_qweight)
# hidden_size->intermediate_size
marlin_w13_scales = marlin_moe_permute_scales(
s=layer.w13_scales,
size_k=layer.intermediate_size_per_partition,
size_n=layer.w13_scales.shape[2],
group_size=self.quant_config.group_size,
)
replace_parameter(layer, "w13_scales", marlin_w13_scales)
marlin_w2_scales = marlin_moe_permute_scales(
s=layer.w2_scales,
size_k=layer.intermediate_size_per_partition,
size_n=layer.w2_scales.shape[2],
group_size=self.quant_config.group_size,
)
replace_parameter(layer, "w2_scales", marlin_w2_scales)
marlin_w13_zp = moe_awq_to_marlin_zero_points(
layer.w13_qzeros,
size_k=layer.w13_qzeros.shape[1],
size_n=layer.w13_qzeros.shape[2] * self.quant_config.pack_factor,
num_bits=self.quant_config.weight_bits,
)
replace_parameter(layer, "w13_qzeros", marlin_w13_zp)
marlin_w2_zp = moe_awq_to_marlin_zero_points(
layer.w2_qzeros,
size_k=layer.w2_qzeros.shape[1],
size_n=layer.w2_qzeros.shape[2] * self.quant_config.pack_factor,
num_bits=self.quant_config.weight_bits,
)
replace_parameter(layer, "w2_qzeros", marlin_w2_zp)
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
num_fused_shared_experts: int = 0,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
routed_scaling_factor: Optional[float] = None,
) -> torch.Tensor:
# Delay the import to avoid circular dependency
from sglang.srt.layers.moe.topk import select_experts
assert activation == "silu", "Only SiLU activation is supported."
assert (
scoring_func == "softmax"
), "Only softmax score func is supported for now."
# The input must currently be float16
orig_dtype = x.dtype
x = x.half()
topk_weights, topk_ids = select_experts(
hidden_states=x,
router_logits=router_logits,
top_k=top_k,
use_grouped_topk=use_grouped_topk,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
num_fused_shared_experts=num_fused_shared_experts,
custom_routing_function=custom_routing_function,
correction_bias=correction_bias,
routed_scaling_factor=routed_scaling_factor,
)
return fused_marlin_moe(
x,
layer.w13_qweight,
layer.w2_qweight,
layer.w13_scales,
layer.w2_scales,
router_logits,
topk_weights,
topk_ids,
sort_indices1=layer.w13_g_idx_sort_indices,
sort_indices2=layer.w2_g_idx_sort_indices,
w1_zeros=layer.w13_qzeros,
w2_zeros=layer.w2_qzeros,
num_bits=self.quant_config.weight_bits,
).to(orig_dtype)
......@@ -11,7 +11,7 @@ import numpy
import torch
from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
from sglang.srt.layers.quantization.scalar_type import ScalarType
from sglang.srt.layers.quantization.scalar_type import ScalarType, scalar_types
from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_npu
if TYPE_CHECKING:
......@@ -247,6 +247,36 @@ def get_pack_factor(num_bits):
return 32 // num_bits
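# Randomly permute the rows (K dimension) of a quantized weight and its reference to emulate act_order, returning the permuted tensors, per-row group indices, and the permutation used.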
def permute_rows(
q_w: torch.Tensor,
w_ref: torch.Tensor,
group_size: int,
test_perm: Optional[torch.Tensor] = None,
):
assert q_w.shape == w_ref.shape
orig_device = q_w.device
k_size, _ = q_w.shape
g_idx = torch.zeros((k_size,), dtype=torch.int32)
for i in range(k_size):
g_idx[i] = i // group_size
# Simulate act_order by doing a random permutation on K
rand_perm = test_perm if test_perm is not None else torch.randperm(k_size)
g_idx = g_idx[rand_perm].contiguous()
q_w = q_w[rand_perm, :].contiguous()
w_ref = w_ref[rand_perm, :].contiguous()
return (
w_ref.to(device=orig_device),
q_w.to(device=orig_device),
g_idx.to(device=orig_device),
rand_perm.to(device=orig_device),
)
def pack_cols(
q_w: torch.Tensor,
num_bits: int,
......@@ -399,3 +429,56 @@ def quantize_weights(
w_s if group_size is not None else None,
maybe_w_zp,
)
SUPPORTED_GPTQ_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128]
SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
def gptq_quantize_weights(
w: torch.Tensor,
quant_type: ScalarType,
group_size: int,
act_order: bool,
test_perm: Optional[torch.Tensor] = None,
):
size_k, _ = w.shape
assert w.is_floating_point(), "w must be float"
assert (
quant_type in SUPPORTED_GPTQ_QUANT_TYPES
), f"Unsupported gptq type = {quant_type}"
assert group_size in SUPPORTED_GROUP_SIZES + [
size_k
], f"Unsupported groupsize = {group_size}"
w_ref, w_q, w_s, _ = quantize_weights(w, quant_type, group_size)
# Apply act_order
g_idx = torch.empty(0, dtype=torch.int, device=w.device)
rand_perm = torch.empty(0, dtype=torch.int, device=w.device)
if act_order:
assert (
group_size < size_k
), "For act_order, groupsize = {} must be less than size_k = {}".format(
group_size, size_k
)
w_ref, w_q, g_idx, rand_perm = permute_rows(w_q, w_ref, group_size, test_perm)
return w_ref, w_q, w_s, g_idx, rand_perm
def sort_weights(q_w: torch.Tensor, g_idx: torch.Tensor):
orig_device = q_w.device
sort_indices = torch.argsort(g_idx).to(dtype=torch.int32) # Sort based on g_idx
g_idx = g_idx[sort_indices].contiguous()
q_w = q_w[sort_indices, :].contiguous()
return (
q_w.to(device=orig_device),
g_idx.to(device=orig_device),
sort_indices.to(device=orig_device),
)
......@@ -355,6 +355,7 @@ class DeepseekV2MoE(nn.Module):
self.shared_experts.gate_up_proj.quant_method, "quant_config"
) and self.shared_experts.gate_up_proj.quant_method.quant_config.get_name() in {
"awq",
"awq_marlin",
"moe_wna16",
}
self.shared_experts_is_int8 = (
......@@ -929,7 +930,7 @@ class DeepseekV2AttentionMLA(nn.Module):
has_fused_proj
and hasattr(self.fused_qkv_a_proj_with_mqa.quant_method, "quant_config")
and self.fused_qkv_a_proj_with_mqa.quant_method.quant_config.get_name()
in {"awq", "moe_wna16"}
in {"awq", "awq_marlin", "moe_wna16"}
)
self.use_min_latency_fused_a_gemm = (
has_fused_proj
......@@ -2551,6 +2552,7 @@ class DeepseekV2ForCausalLM(nn.Module):
cat_dim = 0
if self.quant_config is not None and (
self.quant_config.get_name() == "awq"
or self.quant_config.get_name() == "awq_marlin"
or self.quant_config.get_name() == "moe_wna16"
):
cat_dim = 1
......
import types
from typing import Optional
import pytest
import torch
from sgl_kernel import fused_marlin_moe
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.quantization.scalar_type import ScalarType, scalar_types
from sglang.test.test_marlin_utils import awq_marlin_quantize, marlin_quantize
def stack_and_dev(tensors: list[torch.Tensor]):
dev = tensors[0].device
return torch.stack(tensors, dim=0).to(dev)
def torch_experts(
a: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weight: torch.Tensor,
topk_ids: torch.Tensor,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
quant_dtype: Optional[torch.dtype] = None,
apply_router_weights_on_input: bool = False,
) -> torch.Tensor:
assert (
global_num_experts == -1
or (global_num_experts == w1.shape[0] and expert_map is None)
or (expert_map is not None and global_num_experts == expert_map.shape[0])
)
M, K = a.shape
topk = topk_ids.shape[1]
print("quant_dtype", quant_dtype)
# exit(0)
if apply_router_weights_on_input:
assert topk == 1
a = a * topk_weight.to(a.dtype)
a = a.view(M, -1, K).repeat(1, topk, 1).reshape(-1, K)
out = torch.zeros(M * topk, w2.shape[1], dtype=a.dtype, device=a.device)
num_experts = w1.shape[0]
topk_ids = topk_ids.view(-1)
if expert_map is not None:
topk_ids = expert_map[topk_ids]
f32 = torch.float32
for i in range(num_experts):
mask = topk_ids == i
if mask.sum():
if quant_dtype is None:
tmp1 = a[mask] @ w1[i].transpose(0, 1)
tmp2 = SiluAndMul()(tmp1)
out[mask] = tmp2 @ w2[i].transpose(0, 1)
if apply_router_weights_on_input:
return out
else:
return (
(out.view(M, -1, w2.shape[1]).to(f32) * topk_weight.view(M, -1, 1))
.sum(dim=1)
.to(out.dtype)
)
def torch_moe(
a: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
score: torch.Tensor,
topk: int,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
) -> torch.Tensor:
score = torch.softmax(score, dim=-1, dtype=torch.float32)
topk_weight, topk_ids = torch.topk(score, topk)
return torch_experts(
a, w1, w2, topk_weight, topk_ids, global_num_experts, expert_map
)
def marlin_moe_generate_valid_test_cases():
import itertools
m_list = [1, 123, 666]
n_list = [128, 1024]
k_list = [256, 2048]
e_list = [4, 12]
topk_list = [2, 3]
dtype_list = [torch.half, torch.bfloat16]
group_size_list = [128]
act_order_list = [True, False]
quant_type_list = [
scalar_types.uint4,
scalar_types.uint4b8,
]
is_k_full_list = [True, False]
all_combinations = itertools.product(
m_list,
n_list,
k_list,
e_list,
topk_list,
dtype_list,
group_size_list,
act_order_list,
quant_type_list,
is_k_full_list,
)
def is_valid(
m, n, k, e, topk, dtype, group_size, act_order, quant_type, is_k_full
):
# Filter act_order
if act_order:
if group_size in (-1, k, n):
return False
if quant_type not in [scalar_types.uint4b8]:
return False
elif not is_k_full:
return False
return True
cases = []
for case in all_combinations:
if is_valid(*case):
cases.append(case)
return cases
@pytest.mark.flaky(reruns=2)
@pytest.mark.parametrize(
("m, n, k, e, topk, dtype, group_size," "act_order, quant_type, is_k_full"),
marlin_moe_generate_valid_test_cases(),
)
def test_fused_marlin_moe(
m: int,
n: int,
k: int,
e: int,
topk: int,
dtype: torch.dtype,
group_size: int,
act_order: bool,
quant_type: ScalarType,
is_k_full: bool,
):
if not torch.cuda.is_available():
pytest.skip("CUDA device not available")
torch.manual_seed(0)
has_zp = quant_type in [scalar_types.uint4, scalar_types.uint8]
# Filter act_order
if act_order:
if group_size == -1:
return
if group_size in (k, n):
return
if has_zp:
return
else:
if not is_k_full:
return
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 20
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 20
e_map = None
w_ref1_l = []
qweight1_l = []
scales1_l = []
zeros1_l = []
g_idx1_l = []
sort_indices1_l = []
for i in range(w1.shape[0]):
if has_zp:
w_ref1, qweight1, scales1, zeros1 = awq_marlin_quantize(
w1[i].transpose(1, 0), quant_type, group_size
)
w_ref1_l.append(w_ref1.T)
qweight1_l.append(qweight1)
scales1_l.append(scales1)
zeros1_l.append(zeros1)
else:
test_perm = torch.randperm(k)
w_ref1, qweight1, scales1, g_idx1, sort_indices1, _ = marlin_quantize(
w1[i].transpose(1, 0), quant_type, group_size, act_order, test_perm
)
w_ref1_l.append(w_ref1.T)
qweight1_l.append(qweight1)
scales1_l.append(scales1)
g_idx1_l.append(g_idx1)
sort_indices1_l.append(sort_indices1)
w_ref1 = stack_and_dev(w_ref1_l)
qweight1 = stack_and_dev(qweight1_l).contiguous()
scales1 = stack_and_dev(scales1_l)
g_idx1 = stack_and_dev(g_idx1_l) if g_idx1_l else None
zeros1 = stack_and_dev(zeros1_l) if zeros1_l else None
sort_indices1 = stack_and_dev(sort_indices1_l) if sort_indices1_l else None
w_ref2_l = []
qweight2_l = []
scales2_l = []
zeros2_l = []
g_idx2_l = []
sort_indices2_l = []
for i in range(w2.shape[0]):
if has_zp:
w_ref2, qweight2, scales2, zeros2 = awq_marlin_quantize(
w2[i].transpose(1, 0), quant_type, group_size
)
w_ref2_l.append(w_ref2.T)
qweight2_l.append(qweight2)
scales2_l.append(scales2)
zeros2_l.append(zeros2)
else:
test_perm = torch.randperm(n)
w_ref2, qweight2, scales2, g_idx2, sort_indices2, _ = marlin_quantize(
w2[i].transpose(1, 0), quant_type, group_size, act_order, test_perm
)
w_ref2_l.append(w_ref2.T)
qweight2_l.append(qweight2)
scales2_l.append(scales2)
g_idx2_l.append(g_idx2)
sort_indices2_l.append(sort_indices2)
w_ref2 = stack_and_dev(w_ref2_l)
qweight2 = stack_and_dev(qweight2_l).contiguous()
scales2 = stack_and_dev(scales2_l)
g_idx2 = stack_and_dev(g_idx2_l) if g_idx2_l else None
zeros2 = stack_and_dev(zeros2_l) if zeros2_l else None
sort_indices2 = stack_and_dev(sort_indices2_l) if sort_indices2_l else None
score = torch.randn((m, e), device="cuda", dtype=dtype)
from sglang.srt.layers.moe.topk import fused_topk_torch_native
topk_weights, topk_ids = fused_topk_torch_native(a, score, topk, False)
torch_output = torch_moe(a, w_ref1, w_ref2, score, topk, expert_map=e_map)
marlin_output = fused_marlin_moe(
a,
qweight1,
qweight2,
scales1,
scales2,
score,
topk_weights,
topk_ids,
g_idx1=g_idx1,
g_idx2=g_idx2,
sort_indices1=sort_indices1,
sort_indices2=sort_indices2,
w1_zeros=zeros1,
w2_zeros=zeros2,
num_bits=4,
is_k_full=is_k_full,
)
torch.testing.assert_close(marlin_output, torch_output, atol=5e-2, rtol=0)
if __name__ == "__main__":
# Run the tests in this file directly
pytest.main([__file__])
"""
Adapted from
https://github.com/vllm-project/vllm/blob/020f58abcdea65302225663130d08fd8f4dd755a/vllm/model_executor/layers/quantization/utils/marlin_utils_test.py
"""
# SPDX-License-Identifier: Apache-2.0
"""Utility functions used for tests and benchmarks"""
from typing import Optional
import numpy as np
import torch
from sglang.srt.layers.quantization.marlin_utils import (
GPTQ_MARLIN_TILE,
marlin_permute_scales,
marlin_zero_points,
)
from sglang.srt.layers.quantization.scalar_type import ScalarType
from sglang.srt.layers.quantization.utils import (
get_pack_factor,
gptq_quantize_weights,
quantize_weights,
sort_weights,
)
class MarlinWorkspace:
def __init__(self, out_features, min_thread_n, max_parallel):
assert (
out_features % min_thread_n == 0
), "out_features = {} is undivisible by min_thread_n = {}".format(
out_features, min_thread_n
)
max_workspace_size = (out_features // min_thread_n) * max_parallel
self.scratch = torch.zeros(max_workspace_size, dtype=torch.int, device="cuda")
def marlin_permute_weights(q_w, size_k, size_n, perm, tile=GPTQ_MARLIN_TILE):
assert q_w.shape == (size_k, size_n)
assert size_k % tile == 0, f"size_k = {size_k}, tile = {tile}"
assert size_n % tile == 0, f"size_k = {size_n}, tile = {tile}"
# Permute weights to 16x64 marlin tiles
q_w = q_w.reshape((size_k // tile, tile, size_n // tile, tile))
q_w = q_w.permute((0, 2, 1, 3))
q_w = q_w.reshape((size_k // tile, size_n * tile))
q_w = q_w.reshape((-1, perm.numel()))[:, perm].reshape(q_w.shape)
return q_w
def marlin_weights(q_w, size_k, size_n, num_bits, perm):
# Permute
q_w = marlin_permute_weights(q_w, size_k, size_n, perm)
# Pack
pack_factor = get_pack_factor(num_bits)
orig_device = q_w.device
q_w = q_w.cpu().numpy().astype(np.uint32)
q_packed = np.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), dtype=np.uint32)
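# Each packed uint32 holds pack_factor consecutive quantized columns; value i occupies bits [i * num_bits, (i + 1) * num_bits).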
for i in range(pack_factor):
q_packed |= q_w[:, i::pack_factor] << num_bits * i
q_packed = torch.from_numpy(q_packed.astype(np.int32)).to(orig_device)
return q_packed
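# Build the column permutation that maps packed weights into Marlin's tiled layout; the final interleave pattern depends on num_bits (4 or 8).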
def get_weight_perm(num_bits: int):
perm_list: list[int] = []
for i in range(32):
perm1: list[int] = []
col = i // 4
for block in [0, 1]:
for row in [
2 * (i % 4),
2 * (i % 4) + 1,
2 * (i % 4 + 4),
2 * (i % 4 + 4) + 1,
]:
perm1.append(16 * row + col + 8 * block)
for j in range(4):
perm_list.extend([p + 256 * j for p in perm1])
perm = np.array(perm_list)
if num_bits == 4:
interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7])
elif num_bits == 8:
interleave = np.array([0, 2, 1, 3])
else:
raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))
perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel()
perm = torch.from_numpy(perm)
return perm
def marlin_quantize(
w: torch.Tensor,
quant_type: ScalarType,
group_size: int,
act_order: bool,
test_perm: Optional[torch.Tensor] = None,
):
size_k, size_n = w.shape
num_bits = quant_type.size_bits
# Normalize group_size
if group_size == -1:
group_size = size_k
assert group_size <= size_k
# Quantize (and apply act_order if provided)
w_ref, q_w, s, g_idx, rand_perm = gptq_quantize_weights(
w, quant_type, group_size, act_order, test_perm
)
# For act_order, sort the "weights" and "g_idx" so that group ids are
# increasing
sort_indices = torch.empty(0, dtype=torch.int, device=w.device)
if act_order:
q_w, g_idx, sort_indices = sort_weights(q_w, g_idx)
# Reformat to marlin
weight_perm = get_weight_perm(num_bits)
marlin_q_w = marlin_weights(q_w, size_k, size_n, num_bits, weight_perm)
marlin_s = marlin_permute_scales(s, size_k, size_n, group_size)
# Create result
res_list = [w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, rand_perm]
for i in range(len(res_list)):
res_list[i] = res_list[i].to(w.device)
return res_list
def awq_marlin_quantize(w: torch.Tensor, quant_type: ScalarType, group_size: int):
size_k, size_n = w.shape
# Normalize group_size
if group_size == -1:
group_size = size_k
assert group_size <= size_k
# Detect num groups
assert size_k % group_size == 0
num_groups = size_k // group_size
# Quantize with zp
w_ref, q_w, s, zp = quantize_weights(w, quant_type, group_size, zero_points=True)
# Reformat to marlin
weight_perm = get_weight_perm(quant_type.size_bits)
marlin_q_w = marlin_weights(q_w, size_k, size_n, quant_type.size_bits, weight_perm)
marlin_s = marlin_permute_scales(s, size_k, size_n, group_size)
marlin_zp = marlin_zero_points(zp, num_groups, size_n, quant_type.size_bits)
# Create result
res_list = [w_ref, marlin_q_w, marlin_s, marlin_zp]
for i in range(len(res_list)):
res_list[i] = res_list[i].to(w.device)
return res_list
......@@ -24,7 +24,7 @@ def check_quant_method(model_path: str, use_marlin_kernel: bool):
set_custom_all_reduce,
)
from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state
from sglang.srt.layers.quantization import get_dynamic_override
from sglang.srt.layers.quantization.utils import get_dynamic_override
from sglang.srt.model_loader import get_model
from sglang.srt.server_args import PortArgs, ServerArgs
......