Unverified Commit d95269f9 authored by AniZpZ, committed by GitHub

[2/3] fix dsv3 awq issue (#4625)


Co-authored-by: 晟海 <huangtingwei.htw@antgroup.com>
Co-authored-by: laixinn <xielx@shanghaitech.edu.cn>
parent e53bf190
...@@ -178,10 +178,11 @@ python3 -m sglang.bench_one_batch_server --model None --base-url http://10.0.0.1
### Example: Serving with 8 A100/A800 with AWQ Quantization
Add the `--quantization moe_wna16` flag to enable the moe wna16 kernel for better performance.
One example is as follows:
```bash
python3 -m sglang.launch_server --model cognitivecomputations/DeepSeek-R1-AWQ --tp 8 --trust-remote-code --quantization moe_wna16
```
......
...@@ -258,6 +258,7 @@ class ModelConfig:
"experts_int8",
"w8a8_int8",
"w8a8_fp8",
"moe_wna16",
]
compatible_quantization_methods = {
"w8a8_int8": ["compressed-tensors", "compressed_tensors"],
......
...@@ -52,6 +52,257 @@ if _is_cuda or _is_hip:
from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
@triton.jit
def write_zeros_to_output(
c_ptr,
stride_cm,
stride_cn,
pid_n,
N,
offs_token,
token_mask,
BLOCK_SIZE_M,
BLOCK_SIZE_N,
compute_type,
):
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=compute_type)
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
tl.store(c_ptrs, accumulator, mask=c_mask)
@triton.jit
def fused_moe_kernel_gptq_awq(
# Pointers to matrices
a_ptr,
b_ptr,
c_ptr,
b_scale_ptr,
b_zp_ptr,
topk_weights_ptr,
sorted_token_ids_ptr,
expert_ids_ptr,
num_tokens_post_padded_ptr,
# Matrix dimensions
N: tl.constexpr,
K: tl.constexpr,
EM,
num_valid_tokens,
# The stride variables represent how much to increase the ptr by when
# moving by 1 element in a particular dimension. E.g. `stride_am` is
# how much to increase `a_ptr` by to get the element one row down
# (A has M rows).
stride_am,
stride_ak,
stride_be,
stride_bk,
stride_bn,
stride_cm,
stride_cn,
stride_bse,
stride_bsk,
stride_bsn,
stride_bze,
stride_bzk,
stride_bzn,
group_size: tl.constexpr,
# Meta-parameters
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
MUL_ROUTED_WEIGHT: tl.constexpr,
top_k: tl.constexpr,
compute_type: tl.constexpr,
has_zp: tl.constexpr,
use_int4_w4a16: tl.constexpr,
use_int8_w8a16: tl.constexpr,
even_Ks: tl.constexpr,
):
"""
Implements the fused computation for a Mixture of Experts (MOE) using
token and expert matrices.
Key Parameters:
- A: The input tensor representing tokens with shape (*, K), where '*' can
be any shape representing batches and K is the feature dimension of
each token.
- B: The stacked MOE weight tensor with shape (E, N, K), where E is
the number of experts, K is the input feature dimension, and N is
the output feature dimension.
- C: The output cache tensor with shape (M, topk, N), where M is the
total number of tokens post padding, topk is the number of times
each token is repeated, and N is the output feature dimension.
- sorted_token_ids: A tensor containing the sorted indices of tokens,
repeated topk times and arranged by the expert index they are
assigned to.
- expert_ids: A tensor containing the indices of the expert for each
block. It determines which expert matrix from B should be used for
each block in A.
This kernel performs the multiplication of a token by its corresponding
expert matrix as determined by `expert_ids`. The sorting of
`sorted_token_ids` by expert index and padding ensures divisibility by
BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix
multiplication across different blocks processed by the same expert.
"""
# -----------------------------------------------------------
# Map program ids `pid` to the block of C it should compute.
# This is done in a grouped ordering to promote L2 data reuse.
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m
# ----------------------------------------------------------
# Create pointers for the first blocks of A and B.
# We will advance this pointer as we move in the K direction
# and accumulate
# `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
# `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
return
offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
token_mask = offs_token < num_valid_tokens
off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
if off_experts == -1:
# -----------------------------------------------------------
# Write back zeros to the output when the expert is not
# in the current expert parallel rank.
write_zeros_to_output(
c_ptr,
stride_cm,
stride_cn,
pid_n,
N,
offs_token,
token_mask,
BLOCK_SIZE_M,
BLOCK_SIZE_N,
compute_type,
)
return
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + (
offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
)
if use_int4_w4a16:
b_ptrs = (
b_ptr
+ off_experts * stride_be
+ (offs_k[:, None] // 2) * stride_bk
+ offs_bn[None, :] * stride_bn
)
b_shifter = (offs_k[:, None] % 2) * 4
elif use_int8_w8a16:
b_ptrs = (
b_ptr
+ off_experts * stride_be
+ offs_k[:, None] * stride_bk
+ offs_bn[None, :] * stride_bn
)
if not has_zp and use_int4_w4a16:
b_zp_num = 8
if not has_zp and use_int8_w8a16:
b_zp_num = 128
elif has_zp and use_int4_w4a16:
b_zp_shifter = (offs_bn[None, :] % 2) * 4
# -----------------------------------------------------------
# Iterate to compute a block of the C matrix.
# We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
# of fp32 values for higher accuracy.
# `accumulator` will be converted back to fp16 after the loop.
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
# Load the next block of A and B, generate a mask by checking the
# K dimension.
if not even_Ks:
k_mask = offs_k[:, None] < K - k * BLOCK_SIZE_K
k_other = 0.0
else:
k_mask = None
k_other = None
a = tl.load(
a_ptrs,
mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
other=0.0,
)
b = tl.load(b_ptrs)
if use_int4_w4a16:
b = (b >> b_shifter) & 0xF
b_scale_ptrs = (
b_scale_ptr
+ off_experts * stride_bse
+ offs_bn[None, :] * stride_bsn
+ ((offs_k[:, None] + BLOCK_SIZE_K * k) // group_size) * stride_bsk
)
b_scale = tl.load(b_scale_ptrs, mask=k_mask, other=k_other)
b_scale = b_scale.to(tl.float32)
if has_zp and use_int4_w4a16:
offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size
b_zp_ptrs = (
b_zp_ptr
+ off_experts * stride_bze
+ (offs_bn[None, :] // 2) * stride_bzn
+ offs_k_true * stride_bzk
)
b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other)
b_zp = (b_zp >> b_zp_shifter) & 0xF
b_zp = b_zp.to(tl.float32)
elif has_zp and use_int8_w8a16:
offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size
b_zp_ptrs = (
b_zp_ptr
+ off_experts * stride_bze
+ offs_bn[None, :] * stride_bzn
+ offs_k_true * stride_bzk
)
b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other)
b_zp = b_zp.to(tl.float32)
# We accumulate along the K dimension.
if has_zp:
b = ((b.to(tl.float32) - b_zp) * b_scale).to(compute_type)
else:
b = ((b.to(tl.float32) - b_zp_num) * b_scale).to(compute_type)
accumulator = tl.dot(a, b, acc=accumulator)
# Advance the ptrs to the next K block.
a_ptrs += BLOCK_SIZE_K * stride_ak
if use_int4_w4a16:
b_ptrs += (BLOCK_SIZE_K // 2) * stride_bk
else:
b_ptrs += BLOCK_SIZE_K * stride_bk
if MUL_ROUTED_WEIGHT:
moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
accumulator = accumulator * moe_weight[:, None]
accumulator = accumulator.to(compute_type)
# -----------------------------------------------------------
# Write back the block of the output
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
tl.store(c_ptrs, accumulator, mask=c_mask)
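For reference, the following is a minimal PyTorch sketch (not part of the diff) of the dequantization that the int4 path of `fused_moe_kernel_gptq_awq` performs per element: two 4-bit values are unpacked from each byte along K via `(b >> b_shifter) & 0xF`, zero points are either unpacked the same way along N or default to 8, and the result is `(b - zp) * scale`. The `(N, K // 2)` qweight and `(N // 2, K // group_size)` qzeros layouts assumed here match the ones produced by the weight loader and test later in this commit.
```python
from typing import Optional

import torch


def dequant_int4_reference(
    qweight: torch.Tensor,           # (N, K // 2) uint8, even k in low nibble, odd k in high nibble
    scales: torch.Tensor,            # (N, K // group_size)
    qzeros: Optional[torch.Tensor],  # (N // 2, K // group_size) uint8, packed along N, or None
    group_size: int,
) -> torch.Tensor:
    n, k_half = qweight.shape
    w = torch.empty(n, k_half * 2, dtype=torch.float32)
    w[:, 0::2] = (qweight & 0xF).float()  # even k -> low nibble, as b_shifter = (offs_k % 2) * 4
    w[:, 1::2] = (qweight >> 4).float()   # odd k -> high nibble
    # Each group scale covers group_size consecutive positions along K.
    s = scales.float().repeat_interleave(group_size, dim=1)
    if qzeros is None:
        zp = 8.0  # same default as b_zp_num for int4 weights when has_zp is False
    else:
        z = torch.empty(n, qzeros.shape[1], dtype=torch.float32)
        z[0::2] = (qzeros & 0xF).float()  # even n -> low nibble, as b_zp_shifter = (offs_bn % 2) * 4
        z[1::2] = (qzeros >> 4).float()   # odd n -> high nibble
        zp = z.repeat_interleave(group_size, dim=1)
    return (w - zp) * s                   # the kernel's (b - zp) * b_scale step
```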
@triton.jit
def fused_moe_kernel(
# Pointers to matrices
...@@ -496,6 +747,7 @@ def invoke_fused_moe_kernel(
C: torch.Tensor,
A_scale: Optional[torch.Tensor],
B_scale: Optional[torch.Tensor],
B_zp: Optional[torch.Tensor],
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
sorted_token_ids: torch.Tensor,
...@@ -508,6 +760,7 @@ def invoke_fused_moe_kernel(
use_fp8_w8a8: bool,
use_int8_w8a8: bool,
use_int8_w8a16: bool,
use_int4_w4a16: bool,
block_shape: Optional[List[int]] = None,
no_combine: bool = False,
) -> None:
...@@ -548,8 +801,9 @@ def invoke_fused_moe_kernel(
assert triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1]
assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2]
assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1]
elif use_int8_w8a16 or use_int4_w4a16:
assert B_scale is not None
assert block_shape is None or block_shape[0] == 0
else:
assert A_scale is None
assert B_scale is None
...@@ -565,6 +819,53 @@ def invoke_fused_moe_kernel(
else:
even_Ks = False
if (
(use_int8_w8a16 or use_int4_w4a16)
and block_shape is not None
and block_shape[1] > 0
):
assert B_scale is not None and B_scale.ndim == 3
assert B_zp is None or B_zp.ndim == 3
fused_moe_kernel_gptq_awq[grid](
A,
B,
C,
B_scale,
B_zp,
topk_weights,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
B.shape[1],
A.shape[1],
sorted_token_ids.shape[0],
topk_ids.numel(),
A.stride(0),
A.stride(1),
B.stride(0),
B.stride(2),
B.stride(1),
C.stride(1),
C.stride(2),
B_scale.stride(0),
B_scale.stride(2),
B_scale.stride(1),
B_zp.stride(0) if B_zp is not None else 0,
B_zp.stride(2) if B_zp is not None else 0,
B_zp.stride(1) if B_zp is not None else 0,
group_size=block_shape[1],
MUL_ROUTED_WEIGHT=mul_routed_weight,
top_k=top_k,
compute_type=compute_type,
has_zp=B_zp is not None,
use_int4_w4a16=use_int4_w4a16,
use_int8_w8a16=use_int8_w8a16,
even_Ks=even_Ks,
**config,
)
else:
fused_moe_kernel[grid](
A,
B,
...@@ -750,6 +1051,7 @@ def try_get_optimal_moe_config(
def get_config_dtype_str(
dtype: torch.dtype,
use_int8_w8a16: Optional[bool] = False,
use_int4_w4a16: Optional[bool] = False,
use_fp8_w8a8: Optional[bool] = False,
use_int8_w8a8: Optional[bool] = False,
):
...@@ -757,6 +1059,8 @@ def get_config_dtype_str(
return "fp8_w8a8"
elif use_int8_w8a8:
return "int8_w8a8"
elif use_int4_w4a16:
return "int4_w4a16"
elif use_int8_w8a16:
return "int8_w8a16"
elif dtype == torch.float:
...@@ -776,8 +1080,11 @@ def inplace_fused_experts(
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
w1_zp: Optional[torch.Tensor] = None,
w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None,
...@@ -793,8 +1100,11 @@ def inplace_fused_experts(
use_fp8_w8a8,
use_int8_w8a8,
use_int8_w8a16,
use_int4_w4a16,
w1_scale,
w2_scale,
w1_zp,
w2_zp,
a1_scale,
a2_scale,
block_shape,
...@@ -811,8 +1121,11 @@ def inplace_fused_experts_fake(
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
w1_zp: Optional[torch.Tensor] = None,
w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None,
...@@ -838,8 +1151,11 @@ def outplace_fused_experts(
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
w1_zp: Optional[torch.Tensor] = None,
w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None,
...@@ -856,8 +1172,11 @@ def outplace_fused_experts(
use_fp8_w8a8,
use_int8_w8a8,
use_int8_w8a16,
use_int4_w4a16,
w1_scale,
w2_scale,
w1_zp,
w2_zp,
a1_scale,
a2_scale,
block_shape,
...@@ -875,8 +1194,11 @@ def outplace_fused_experts_fake(
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
w1_zp: Optional[torch.Tensor] = None,
w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None,
...@@ -904,8 +1226,11 @@ def fused_experts(
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
w1_zp: Optional[torch.Tensor] = None,
w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None,
...@@ -923,8 +1248,11 @@ def fused_experts(
use_fp8_w8a8,
use_int8_w8a8,
use_int8_w8a16,
use_int4_w4a16,
w1_scale,
w2_scale,
w1_zp,
w2_zp,
a1_scale,
a2_scale,
block_shape,
...@@ -941,8 +1269,11 @@ def fused_experts(
use_fp8_w8a8,
use_int8_w8a8,
use_int8_w8a16,
use_int4_w4a16,
w1_scale,
w2_scale,
w1_zp,
w2_zp,
a1_scale,
a2_scale,
block_shape,
...@@ -961,8 +1292,11 @@ def fused_experts_impl(
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
w1_zp: Optional[torch.Tensor] = None,
w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None,
...@@ -977,7 +1311,12 @@ def fused_experts_impl(
padded_size = 0
# Check constraints.
if use_int4_w4a16:
assert hidden_states.shape[1] // 2 == w1.shape[2], "Hidden size mismatch"
else:
assert (
hidden_states.shape[1] == w1.shape[2] - padded_size
), "Hidden size mismatch"
assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
assert w1.is_contiguous(), "Expert weights1 must be contiguous"
...@@ -994,6 +1333,7 @@ def fused_experts_impl(
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
dtype=hidden_states.dtype,
)
...@@ -1075,6 +1415,7 @@ def fused_experts_impl(
intermediate_cache1,
a1_scale,
w1_scale,
w1_zp,
curr_topk_weights,
curr_topk_ids,
sorted_token_ids,
...@@ -1087,6 +1428,7 @@ def fused_experts_impl(
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
block_shape=block_shape,
)
if activation == "silu":
...@@ -1116,6 +1458,7 @@ def fused_experts_impl(
),
a2_scale,
w2_scale,
w2_zp,
curr_topk_weights,
curr_topk_ids,
sorted_token_ids,
...@@ -1128,6 +1471,7 @@ def fused_experts_impl(
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
block_shape=block_shape,
)
...@@ -1173,8 +1517,11 @@ def fused_moe(
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
w1_zp: Optional[torch.Tensor] = None,
w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None,
...@@ -1204,6 +1551,9 @@ def fused_moe(
products for w1 and w2. Defaults to False.
- use_int8_w8a16 (bool): If True, use matmul of int8 weight and bf16/fp16
activation to compute the inner products for w1 and w2.
Defaults to False.
- use_int4_w4a16 (bool): If True, use matmul of int4 weight and bf16/fp16
activation to compute the inner products for w1 and w2.
Defaults to False.
- w1_scale (Optional[torch.Tensor]): Optional scale to be used for
w1.
- w2_scale (Optional[torch.Tensor]): Optional scale to be used for
...@@ -1243,8 +1593,11 @@ def fused_moe(
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
w1_scale=w1_scale,
w2_scale=w2_scale,
w1_zp=w1_zp,
w2_zp=w2_zp,
a1_scale=a1_scale,
a2_scale=a2_scale,
block_shape=block_shape,
......
...@@ -61,6 +61,7 @@ from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import
from sglang.srt.layers.quantization.fp8 import Fp8Config
from sglang.srt.layers.quantization.gptq import GPTQConfig, GPTQMarlinConfig
from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp8Config
from sglang.srt.layers.quantization.moe_wna16 import MoeWNA16Config
from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config
from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config
from sglang.srt.layers.vocab_parallel_embedding import (
...@@ -75,6 +76,7 @@ BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
"modelopt": ModelOptFp8Config,
"w8a8_int8": W8A8Int8Config,
"w8a8_fp8": W8A8Fp8Config,
"moe_wna16": MoeWNA16Config,
"compressed-tensors": CompressedTensorsConfig,
}
......
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/moe_wna16.py
import logging
from typing import Any, Callable, Dict, List, Optional
import torch
from sglang.srt.distributed import get_tensor_model_parallel_rank
from sglang.srt.distributed.parallel_state import get_tp_group
from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
from sglang.srt.layers.quantization.awq import AWQConfig
from sglang.srt.layers.quantization.base_config import (
QuantizationConfig,
QuantizeMethodBase,
)
from sglang.srt.layers.quantization.gptq import GPTQConfig, GPTQMarlinConfig
from sglang.srt.utils import get_device_capability, set_weight_attrs
logger = logging.getLogger(__name__)
class MoeWNA16Config(QuantizationConfig):
"""Config class for MOE WNA16 (W8A16/W4A16) quantization."""
def __init__(
self,
linear_quant_method: str,
weight_bits: int,
group_size: int,
has_zp: bool,
lm_head_quantized: bool,
modules_to_not_convert: Optional[List[str]],
full_config: Dict[str, Any],
) -> None:
super().__init__()
self.weight_bits = weight_bits
self.group_size = group_size
self.has_zp = has_zp
self.bit8_pack_factor = 8 // self.weight_bits
self.lm_head_quantized = lm_head_quantized
self.linear_quant_method = linear_quant_method
self.full_config = full_config
self.use_marlin = False
# Avoid circular import
if self.linear_quant_method == "gptq":
self.use_marlin = GPTQMarlinConfig.is_gptq_marlin_compatible(full_config)
elif self.linear_quant_method == "awq":
capability_tuple = get_device_capability()
device_capability = (
-1
if capability_tuple is None
else capability_tuple[0] * 10 + capability_tuple[1]
)
awq_min_capability = AWQConfig.get_min_capability()
if device_capability < awq_min_capability:
raise ValueError(
"The quantization method moe_wna16 + awq is not supported "
"for the current GPU. "
f"Minimum capability: {awq_min_capability}. "
f"Current capability: {device_capability}."
)
else:
raise ValueError("moe_wna16 only support gptq and awq.")
if modules_to_not_convert is None:
self.modules_to_not_convert = []
else:
self.modules_to_not_convert = modules_to_not_convert
@classmethod
def get_name(cls) -> str:
return "moe_wna16"
@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
return [torch.bfloat16, torch.half]
@classmethod
def get_min_capability(cls) -> int:
return 70
@classmethod
def get_config_filenames(cls) -> List[str]:
return ["quantize_config.json"]
def get_scaled_act_names(self) -> List[str]:
raise NotImplementedError
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "MoeWNA16Config":
quant_method = cls.get_from_keys(config, ["quant_method"])
weight_bits = cls.get_from_keys(config, ["bits"])
group_size = cls.get_from_keys(config, ["group_size"])
lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], default=False)
if quant_method == "gptq":
has_zp = not cls.get_from_keys(config, ["sym"])
modules_to_not_convert = []
elif quant_method == "awq":
has_zp = cls.get_from_keys(config, ["zero_point"])
modules_to_not_convert = cls.get_from_keys_or(
config, ["modules_to_not_convert"], None
)
else:
raise ValueError("moe_wna16 only support gptq and awq.")
return cls(
quant_method,
weight_bits,
group_size,
has_zp,
lm_head_quantized,
modules_to_not_convert,
config,
)
@classmethod
def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]:
can_convert = cls.is_moe_wna16_compatible(hf_quant_cfg)
if can_convert and user_quant == "moe_wna16":
return cls.get_name()
return None
@classmethod
def is_moe_wna16_compatible(cls, quant_config: Dict[str, Any]):
# Extract data from quant config.
quant_method = quant_config.get("quant_method", "").lower()
num_bits = quant_config.get("bits")
desc_act = quant_config.get("desc_act")
capability_tuple = get_device_capability()
device_capability = (
-1
if capability_tuple is None
else capability_tuple[0] * 10 + capability_tuple[1]
)
# Avoid circular import
awq_min_capability = AWQConfig.get_min_capability()
gptq_compatible = quant_method == "gptq" and not desc_act and num_bits in [4, 8]
awq_compatible = (
quant_method == "awq"
and num_bits == 4
and device_capability >= awq_min_capability
)
return gptq_compatible or awq_compatible
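To make the compatibility rules above concrete, here are a few illustrative HF quantization configs (hypothetical values, not taken from the diff) and how `is_moe_wna16_compatible` would classify them on a GPU that meets AWQ's minimum capability:
```python
# Hypothetical quantization configs, for illustration only.
awq_cfg = {"quant_method": "awq", "bits": 4, "zero_point": True, "group_size": 128}
gptq_cfg = {"quant_method": "gptq", "bits": 8, "sym": True, "desc_act": False, "group_size": 128}
gptq_act_order_cfg = {"quant_method": "gptq", "bits": 4, "sym": True, "desc_act": True, "group_size": 128}

# Assuming the device capability check passes:
# MoeWNA16Config.is_moe_wna16_compatible(awq_cfg)            -> True  (4-bit AWQ)
# MoeWNA16Config.is_moe_wna16_compatible(gptq_cfg)           -> True  (4/8-bit GPTQ without act-order)
# MoeWNA16Config.is_moe_wna16_compatible(gptq_act_order_cfg) -> False (desc_act rules it out)
```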
def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> Optional["QuantizeMethodBase"]:
# avoid circular import
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
if is_layer_skipped_quant(prefix, self.modules_to_not_convert):
return UnquantizedLinearMethod()
elif isinstance(layer, LinearBase):
if self.linear_quant_method == "gptq":
if self.use_marlin:
return GPTQMarlinConfig.from_config(
self.full_config
).get_quant_method(layer, prefix)
else:
return GPTQConfig.from_config(self.full_config).get_quant_method(
layer, prefix
)
elif self.linear_quant_method == "awq":
return AWQConfig.from_config(self.full_config).get_quant_method(
layer, prefix
)
else:
raise ValueError("moe_wna16 only support gptq and awq.")
elif isinstance(layer, FusedMoE):
return MoeWNA16Method(self)
return None
def is_layer_skipped_quant(prefix: str, modules_to_not_convert: List[str]):
return any(module_name in prefix for module_name in modules_to_not_convert)
class MoeWNA16Method:
"""Linear method for MOE WNA16 (W8A16/W4A16) quantization.
Args:
quant_config: The MOE WNA16 (W8A16/W4A16) quantization config.
"""
def __new__(cls, *args, **kwargs):
# avoid circular import
from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase
if not hasattr(cls, "_initialized"):
original_init = cls.__init__
new_cls = type(
cls.__name__,
(FusedMoEMethodBase,),
{
"__init__": original_init,
**{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
},
)
obj = super(new_cls, new_cls).__new__(new_cls)
obj.__init__(*args, **kwargs)
return obj
return super().__new__(cls)
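The `__new__` above rebases the class onto `FusedMoEMethodBase` at instantiation time so the import happens lazily, avoiding a circular import. A minimal standalone sketch of the same pattern, with a placeholder `Base` standing in for `FusedMoEMethodBase`:
```python
class Base:  # stands in for FusedMoEMethodBase
    def hello(self) -> str:
        return "from Base"


class Method:
    def __new__(cls, *args, **kwargs):
        # Re-create the class with Base injected as a parent, then build and
        # initialize an instance of that new class by hand.
        new_cls = type(
            cls.__name__,
            (Base,),
            {k: v for k, v in cls.__dict__.items() if k != "__dict__"},
        )
        obj = super(new_cls, new_cls).__new__(new_cls)
        obj.__init__(*args, **kwargs)
        return obj

    def __init__(self, value: int):
        self.value = value


m = Method(42)
assert isinstance(m, Base) and m.value == 42 and m.hello() == "from Base"
```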
def __init__(self, quant_config: MoeWNA16Config):
self.quant_config = quant_config
def create_weights(
self,
layer: torch.nn.Module,
num_experts: int,
hidden_size: int,
intermediate_size_per_partition: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
layer.quant_config = self.quant_config
bit8_pack_factor = self.quant_config.bit8_pack_factor
group_size = self.quant_config.group_size
group_size_div_factor = 1
# Reduce the group size until intermediate_size and hidden_size are
# divisible by it; the loaded scales/qzeros are repeated later to
# compensate for the smaller group size.
while intermediate_size_per_partition % group_size or hidden_size % group_size:
group_size = group_size // 2
group_size_div_factor *= 2
assert group_size >= 32
layer.group_size = group_size
layer.group_size_div_factor = group_size_div_factor
strategy = FusedMoeWeightScaleSupported.GROUP.value
extra_weight_attrs.update({"quant_method": strategy, "is_transposed": False})
assert "weight_loader" in extra_weight_attrs
weight_loader = extra_weight_attrs["weight_loader"]
wrapped_weight_loader = MoeWNA16Method.get_weight_loader(layer, weight_loader)
extra_weight_attrs["weight_loader"] = wrapped_weight_loader
# Fused gate_up_proj (column parallel)
w13_qweight = torch.nn.Parameter(
torch.empty(
num_experts,
2 * intermediate_size_per_partition,
hidden_size // bit8_pack_factor,
dtype=torch.uint8,
),
requires_grad=False,
)
layer.register_parameter("w13_qweight", w13_qweight)
set_weight_attrs(w13_qweight, extra_weight_attrs)
# down_proj (row parallel)
w2_qweight = torch.nn.Parameter(
torch.empty(
num_experts,
hidden_size,
intermediate_size_per_partition // bit8_pack_factor,
dtype=torch.uint8,
),
requires_grad=False,
)
layer.register_parameter("w2_qweight", w2_qweight)
set_weight_attrs(w2_qweight, extra_weight_attrs)
w13_scales = torch.nn.Parameter(
torch.zeros(
num_experts,
2 * intermediate_size_per_partition,
hidden_size // group_size,
dtype=params_dtype,
),
requires_grad=False,
)
layer.register_parameter("w13_scales", w13_scales)
set_weight_attrs(w13_scales, extra_weight_attrs)
w2_scales = torch.nn.Parameter(
torch.zeros(
num_experts,
hidden_size,
intermediate_size_per_partition // group_size,
dtype=params_dtype,
),
requires_grad=False,
)
layer.register_parameter("w2_scales", w2_scales)
set_weight_attrs(w2_scales, extra_weight_attrs)
if self.quant_config.has_zp:
w13_qzeros = torch.nn.Parameter(
torch.zeros(
num_experts,
2 * intermediate_size_per_partition // bit8_pack_factor,
hidden_size // group_size,
dtype=torch.uint8,
),
requires_grad=False,
)
layer.register_parameter("w13_qzeros", w13_qzeros)
set_weight_attrs(w13_qzeros, extra_weight_attrs)
w2_qzeros = torch.nn.Parameter(
torch.zeros(
num_experts,
hidden_size // bit8_pack_factor,
intermediate_size_per_partition // group_size,
dtype=torch.uint8,
),
requires_grad=False,
)
layer.register_parameter("w2_qzeros", w2_qzeros)
set_weight_attrs(w2_qzeros, extra_weight_attrs)
if self.quant_config.linear_quant_method == "gptq":
# Some params are unused, but we need to initialize them in order to
# load weights.
invalid_param_keys = ["w13_g_idx", "w2_g_idx"]
if not self.quant_config.has_zp:
invalid_param_keys += ["w13_qzeros", "w2_qzeros"]
for key in invalid_param_keys:
param = torch.nn.Parameter(
torch.empty((0,), dtype=torch.int32), requires_grad=False
)
layer.register_parameter(key, param)
set_weight_attrs(param, extra_weight_attrs)
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
inplace: bool = True,
no_combine: bool = False,
) -> torch.Tensor:
# avoid circular import
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
from sglang.srt.layers.moe.topk import select_experts
assert activation == "silu", "Only SiLU activation is supported."
topk_weights, topk_ids = select_experts(
hidden_states=x,
router_logits=router_logits,
top_k=top_k,
use_grouped_topk=use_grouped_topk,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
correction_bias=correction_bias,
)
weight_bits = self.quant_config.weight_bits
has_zp = self.quant_config.has_zp
return fused_experts(
x,
layer.w13_qweight,
layer.w2_qweight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=inplace,
use_int4_w4a16=weight_bits == 4,
use_int8_w8a16=weight_bits == 8,
w1_scale=layer.w13_scales,
w2_scale=layer.w2_scales,
w1_zp=layer.w13_qzeros if has_zp else None,
w2_zp=layer.w2_qzeros if has_zp else None,
block_shape=[0, layer.group_size],
no_combine=no_combine,
)
@staticmethod
def get_weight_loader(layer, weight_loader):
def convert_awq_tensor(tensor, tensor_type):
# convert awq qweight/qzeros to a standard format (assume int4)
# qweight: (k, n // pack_factor_bit32) -> (n, k // pack_factor_bit8)
# qzeros: (k // group_size, n // pack_factor_bit32) ->
# (n // pack_factor_bit8, k // group_size)
# pack_factor_bit32 = 32 // weight_bits
# pack_factor_bit8 = 8 // weight_bits
# 0. suppose origin shape (a, b), dtype int32
# 1. convert to uint8, shape (a, b) -> (a, 4 * b)
size0 = tensor.size(0)
tensor = tensor.view(torch.uint8)
# 2. unpack to uint4 (only when weight_bits == 4)
# shape (a, 4 * b) -> (a, 4 * b, 2)
shifter = torch.tensor([0, 4], dtype=torch.uint8, device=tensor.device)
tensor = (tensor[:, :, None] >> shifter) & 0xF
# 3. change order, see
# https://github.com/casper-hansen/AutoAWQ/blob/v0.2.8/awq/utils/quant_utils.py
# shape -> (a, 4 * b * pack_factor_bit8)
reverse_awq_pack_order = [0, 4, 1, 5, 2, 6, 3, 7]
tensor = tensor.view(-1, 8)[:, reverse_awq_pack_order]
tensor = tensor.view(size0, -1)
# 4. transpose, shape -> (4 * b * pack_factor_bit8, a)
tensor = tensor.T.contiguous()
# 5. repack (only when weight_bits == 4)
# qweight shape -> (4 * b * pack_factor_bit8, a // pack_factor_bit8)
# qzeros shape -> (4 * b, a)
if tensor_type == "qweight":
tensor = tensor[:, 1::2] * 16 + tensor[:, ::2]
elif tensor_type == "qzeros":
tensor = tensor[1::2, :] * 16 + tensor[::2, :]
return tensor
def convert_gptq_int4_qzeros(tensor):
tensor = tensor.view(torch.uint8)
shifter = torch.tensor([0, 4], dtype=torch.uint8, device=tensor.device)
tensor = (tensor[:, :, None] >> shifter) & 0xF
tensor = tensor + 1
tensor = tensor[:, :, 0] + tensor[:, :, 1] * 16
return tensor
def moe_wna16_weight_loader(
param: torch.nn.Parameter,
loaded_weight: torch.Tensor,
weight_name: str,
shard_id: str,
expert_id: int,
):
if "g_idx" in weight_name:
return
if not layer.quant_config.has_zp and "qzeros" in weight_name:
return
device = get_tp_group().device
tp_rank = get_tensor_model_parallel_rank()
loaded_weight = loaded_weight.to(device)
shard_size = layer.intermediate_size_per_partition
# convert gptq and awq weight to a standard format
if layer.quant_config.linear_quant_method == "awq":
assert layer.quant_config.weight_bits == 4
if "weight" in weight_name:
loaded_weight = convert_awq_tensor(loaded_weight, "qweight")
elif "zeros" in weight_name:
loaded_weight = convert_awq_tensor(loaded_weight, "qzeros")
else:
loaded_weight = loaded_weight.T
elif layer.quant_config.linear_quant_method == "gptq":
assert layer.quant_config.weight_bits in [4, 8]
if "weight" in weight_name:
loaded_weight = loaded_weight.T.contiguous().view(torch.uint8)
elif "zeros" in weight_name:
# add 1 to gptq qzeros to align with awq
loaded_weight = loaded_weight.view(torch.uint8)
if layer.quant_config.weight_bits == 4:
loaded_weight = convert_gptq_int4_qzeros(loaded_weight).T
else:
loaded_weight = loaded_weight.T + 1
else:
loaded_weight = loaded_weight.T
# repeat the qzeros/scales to fit the new group size
if layer.group_size_div_factor > 1 and (
"qzeros" in weight_name or "scales" in weight_name
):
loaded_weight = loaded_weight.repeat_interleave(
layer.group_size_div_factor, 1
)
if "w13_qzeros" in weight_name:
tensor = loaded_weight.view(layer.tp_size, -1, loaded_weight.size(1))[
tp_rank
]
if shard_id == "w1":
param.data[expert_id, : shard_size // 2] = tensor
else:
param.data[expert_id, shard_size // 2 :] = tensor
elif "w2_qzeros" in weight_name:
param.data[expert_id] = loaded_weight.view(
loaded_weight.size(0), layer.tp_size, -1
)[:, tp_rank]
else:
weight_loader(param, loaded_weight, weight_name, shard_id, expert_id)
return moe_wna16_weight_loader
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/quant_utils.py
from types import MappingProxyType
from typing import List, Mapping, Optional, Tuple, Union
import torch
......
...@@ -496,6 +496,7 @@ class ServerArgs:
"modelopt",
"w8a8_int8",
"w8a8_fp8",
"moe_wna16",
],
help="The quantization method.",
)
......
from typing import Optional
import pytest
import torch
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
NUM_EXPERTS = [8, 64]
TOP_KS = [2, 6]
def quantize_weights(
w: torch.Tensor,
quant_type: str,
group_size: Optional[int],
zero_points: bool = False,
ref_zero_points_after_scales: bool = False,
):
assert quant_type in ["w4a16", "w4a16b8", "w8a16", "w8a16b128"]
assert not zero_points or group_size is not None, (
"to have group zero points, group_size must be provided "
"(-1 group_size is channelwise)"
)
orig_device = w.device
orig_type = w.dtype
size_k, size_n = w.shape
assert w.is_floating_point(), "w must be float"
if group_size == -1:
group_size = size_k
# Reshape to [groupsize, -1]
if group_size is not None and group_size < size_k:
w = w.reshape((-1, group_size, size_n))
w = w.permute(1, 0, 2)
w = w.reshape((group_size, -1))
# Compute scale for each group
max_val = torch.max(w, 0, keepdim=True).values
min_val = torch.min(w, 0, keepdim=True).values
if quant_type == "w4a16":
max_q_val = 15
min_q_val = 0
elif quant_type == "w4a16b8":
max_q_val = 7
min_q_val = -1
elif quant_type == "w8a16":
max_q_val = 255
min_q_val = 0
elif quant_type == "w8a16b128":
max_q_val = 127
min_q_val = -128
w_s = torch.Tensor([1.0]).to(w.device) # unscaled case
maybe_w_zp = None
if group_size is not None:
if zero_points:
w_s = (max_val - min_val).clamp(min=1e-5) / max_q_val
maybe_w_zp = (
torch.round(torch.abs(min_val / w_s)).clamp(min_q_val, max_q_val).int()
)
else:
# If the bias is such that there are no possible negative/positive
# values, set the max value to inf to avoid divide by 0
w_s = torch.max(
abs(max_val / (max_q_val if max_q_val != 0 else torch.inf)),
abs(min_val / (min_q_val if min_q_val != 0 else torch.inf)),
)
# Quantize
w_q = torch.round(w / w_s).int() + (maybe_w_zp if zero_points else 0)
w_q = torch.clamp(w_q, min_q_val, max_q_val)
# Compute ref (dequantized)
# For some kernels (namely Machete) the zero-points are applied after the
# scales are applied, for this case computing the reference in similar way
# allows us to use tighter error tolerances in our unit tests.
if ref_zero_points_after_scales and maybe_w_zp is not None:
w_ref = w_q.to(orig_type) * w_s - maybe_w_zp.to(orig_type) * w_s
else:
w_ref = (w_q - (maybe_w_zp if zero_points else 0)).to(orig_type) * w_s
if quant_type == "w4a16b8":
w_q += 8
elif quant_type == "w8a16b128":
w_q += 128
# Restore original shapes
if group_size is not None and group_size < size_k:
def reshape_w(w):
w = w.reshape((group_size, -1, size_n))
w = w.permute(1, 0, 2)
w = w.reshape((size_k, size_n)).contiguous()
return w
w_q = reshape_w(w_q)
w_ref = reshape_w(w_ref)
w_s = w_s.reshape((-1, size_n)).contiguous()
if maybe_w_zp is not None:
maybe_w_zp = maybe_w_zp.reshape((-1, size_n)).contiguous()
maybe_w_zp = maybe_w_zp.to(device=orig_device)
return (
w_ref.to(device=orig_device),
w_q.to(device=orig_device),
w_s if group_size is not None else None,
maybe_w_zp,
)
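For reference, a minimal illustrative use of `quantize_weights` (not part of the test itself): group-wise 8-bit quantization of a random matrix with zero points, checking the returned shapes and that the dequantized reference stays close to the input.
```python
# Illustrative only; the sizes and tolerance are assumptions for this sketch.
w = torch.randn(256, 512, dtype=torch.float32) / 10
w_ref, w_q, w_s, w_zp = quantize_weights(w, "w8a16", group_size=64, zero_points=True)

assert w_q.shape == w.shape                       # quantized values, one per weight
assert w_s.shape == (256 // 64, 512)              # one scale per (group, column)
assert w_zp is not None and w_zp.shape == (256 // 64, 512)
assert torch.allclose(w_ref, w, atol=5e-2)        # dequantized reference tracks the input
```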
def torch_moe(a, w1, w2, score, topk):
B, D = a.shape
a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
score = torch.softmax(score, dim=-1, dtype=torch.float32)
topk_weight, topk_ids = torch.topk(score, topk)
topk_weight = topk_weight.view(-1)
topk_ids = topk_ids.view(-1)
for i in range(w1.shape[0]):
mask = topk_ids == i
if mask.sum():
out[mask] = SiluAndMul()(a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(
0, 1
)
return (
out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype)
).sum(dim=1)
# Forked from https://github.com/vllm-project/vllm/blob/main/tests/kernels/test_moe.py
@pytest.mark.parametrize("m", [1, 32, 222])
@pytest.mark.parametrize("n", [128, 1024, 2048])
@pytest.mark.parametrize("k", [128, 1024])
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("group_size", [64, 128])
@pytest.mark.parametrize("has_zp", [True, False])
@pytest.mark.parametrize("weight_bits", [8]) # [4, 8])
def test_fused_moe_wn16(
m: int,
n: int,
k: int,
e: int,
topk: int,
dtype: torch.dtype,
group_size: int,
has_zp: bool,
weight_bits: int,
):
print(m, n, k, e, topk, dtype, group_size, has_zp, weight_bits)
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
score = torch.randn((m, e), device="cuda", dtype=dtype)
if weight_bits == 4:
pack_factor = 2
quant_type = "w4a16" if has_zp else "w4a16b8"
elif weight_bits == 8:
pack_factor = 1
quant_type = "w8a16" if has_zp else "w8a16b128"
w1_ref = w1.clone()
w2_ref = w2.clone()
w1_qweight = torch.empty(
(e, 2 * n, k // pack_factor), device="cuda", dtype=torch.uint8
)
w2_qweight = torch.empty((e, k, n // pack_factor), device="cuda", dtype=torch.uint8)
w1_scales = torch.empty((e, 2 * n, k // group_size), device="cuda", dtype=dtype)
w2_scales = torch.empty((e, k, n // group_size), device="cuda", dtype=dtype)
w1_qzeros = torch.empty(
(e, 2 * n // pack_factor, k // group_size), device="cuda", dtype=torch.uint8
)
w2_qzeros = torch.empty(
(e, k // pack_factor, n // group_size), device="cuda", dtype=torch.uint8
)
for i in range(e * 2):
expert_id = i % e
if i // e == 0:
w, w_ref, w_qweight, w_scales, w_qzeros = (
w1,
w1_ref,
w1_qweight,
w1_scales,
w1_qzeros,
)
else:
w, w_ref, w_qweight, w_scales, w_qzeros = (
w2,
w2_ref,
w2_qweight,
w2_scales,
w2_qzeros,
)
weight, qweight, scales, qzeros = quantize_weights(
w[expert_id].T, quant_type, group_size, has_zp, False
)
weight = weight.T
qweight = qweight.T.contiguous().to(torch.uint8)
scales = scales.T
if has_zp:
qzeros = qzeros.T.contiguous().to(torch.uint8)
if weight_bits == 4:
qweight = qweight[:, 1::2] * 16 + qweight[:, ::2]
if has_zp:
qzeros = qzeros[1::2, :] * 16 + qzeros[::2, :]
w_ref[expert_id] = weight
w_qweight[expert_id] = qweight
w_scales[expert_id] = scales
if has_zp:
w_qzeros[expert_id] = qzeros
triton_output = fused_moe(
a,
w1_qweight,
w2_qweight,
score,
topk,
renormalize=False,
use_int4_w4a16=weight_bits == 4,
use_int8_w8a16=weight_bits == 8,
w1_scale=w1_scales,
w2_scale=w2_scales,
w1_zp=w1_qzeros if has_zp else None,
w2_zp=w2_qzeros if has_zp else None,
block_shape=[0, group_size],
)
torch_output = torch_moe(a, w1_ref, w2_ref, score, topk)
torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)