Unverified Commit c7f25446 authored by HandH1998's avatar HandH1998 Committed by GitHub
Browse files

[Feature] DeepSeek V3/R1 INT8 Quantization (channel-wise) (#3888)


Co-authored-by: default avataryych0745 <1398089567@qq.com>
Co-authored-by: default avatarsleepcoo <sleepcoo@gmail.com>
Co-authored-by: default avatarb0urnee <2769086541@qq.com>
parent 63ee26d1
...@@ -15,7 +15,10 @@ from vllm import _custom_ops as ops ...@@ -15,7 +15,10 @@ from vllm import _custom_ops as ops
from sglang.srt.layers.moe.topk import select_experts from sglang.srt.layers.moe.topk import select_experts
from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8 from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
from sglang.srt.layers.quantization.int8_kernel import per_token_group_quant_int8 from sglang.srt.layers.quantization.int8_kernel import (
per_token_group_quant_int8,
per_token_quant_int8,
)
from sglang.srt.utils import ( from sglang.srt.utils import (
direct_register_custom_op, direct_register_custom_op,
get_bool_env_var, get_bool_env_var,
...@@ -117,6 +120,7 @@ def fused_moe_kernel( ...@@ -117,6 +120,7 @@ def fused_moe_kernel(
- expert_ids: A tensor containing the indices of the expert for each - expert_ids: A tensor containing the indices of the expert for each
block. It determines which expert matrix from B should be used for block. It determines which expert matrix from B should be used for
each block in A. each block in A.
This kernel performs the multiplication of a token by its corresponding This kernel performs the multiplication of a token by its corresponding
expert matrix as determined by `expert_ids`. The sorting of expert matrix as determined by `expert_ids`. The sorting of
`sorted_token_ids` by expert index and padding ensures divisibility by `sorted_token_ids` by expert index and padding ensures divisibility by
...@@ -167,17 +171,38 @@ def fused_moe_kernel( ...@@ -167,17 +171,38 @@ def fused_moe_kernel(
) )
b_scale = tl.load(b_scale_ptrs) b_scale = tl.load(b_scale_ptrs)
if use_fp8_w8a8 or use_int8_w8a8: if use_fp8_w8a8:
# block-wise
if group_k > 0 and group_n > 0: if group_k > 0 and group_n > 0:
a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
offs_bsn = offs_bn // group_n offs_bsn = offs_bn // group_n
b_scale_ptrs = ( b_scale_ptrs = (
b_scale_ptr + off_experts * stride_bse + offs_bsn * stride_bsn b_scale_ptr + off_experts * stride_bse + offs_bsn * stride_bsn
) )
# tensor-wise
else: else:
a_scale = tl.load(a_scale_ptr) a_scale = tl.load(a_scale_ptr)
b_scale = tl.load(b_scale_ptr + off_experts) b_scale = tl.load(b_scale_ptr + off_experts)
if use_int8_w8a8:
# block-wise
if group_k > 0 and group_n > 0:
a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
offs_bsn = offs_bn // group_n
b_scale_ptrs = (
b_scale_ptr + off_experts * stride_bse + offs_bsn * stride_bsn
)
# channel-wise
else:
# Load per-column scale for weights
b_scale_ptrs = (
b_scale_ptr + off_experts * stride_bse + offs_bn[None, :] * stride_bsn
)
b_scale = tl.load(b_scale_ptrs)
# Load per-token scale for activations
a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
a_scale = tl.load(a_scale_ptrs, mask=token_mask, other=0.0)[:, None]
# ----------------------------------------------------------- # -----------------------------------------------------------
# Iterate to compute a block of the C matrix. # Iterate to compute a block of the C matrix.
# We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
...@@ -217,9 +242,13 @@ def fused_moe_kernel( ...@@ -217,9 +242,13 @@ def fused_moe_kernel(
accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :] accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :]
else: else:
# fix out of shared memory issue
if use_fp8_w8a8:
accumulator = tl.dot(a, b, acc=accumulator) accumulator = tl.dot(a, b, acc=accumulator)
else: else:
accumulator += tl.dot(a, b) accumulator += tl.dot(a, b)
else:
accumulator += tl.dot(a, b)
# Advance the ptrs to the next K block. # Advance the ptrs to the next K block.
a_ptrs += BLOCK_SIZE_K * stride_ak a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk b_ptrs += BLOCK_SIZE_K * stride_bk
...@@ -497,9 +526,11 @@ def invoke_fused_moe_kernel( ...@@ -497,9 +526,11 @@ def invoke_fused_moe_kernel(
if use_fp8_w8a8: if use_fp8_w8a8:
assert B_scale is not None assert B_scale is not None
if block_shape is None: if block_shape is None:
# activation tensor-wise fp8 quantization, dynamic or static
padded_size = padding_size padded_size = padding_size
A, A_scale = ops.scaled_fp8_quant(A, A_scale) A, A_scale = ops.scaled_fp8_quant(A, A_scale)
else: else:
# activation block-wise fp8 quantization
assert len(block_shape) == 2 assert len(block_shape) == 2
block_n, block_k = block_shape[0], block_shape[1] block_n, block_k = block_shape[0], block_shape[1]
if _is_cuda: if _is_cuda:
...@@ -512,9 +543,10 @@ def invoke_fused_moe_kernel( ...@@ -512,9 +543,10 @@ def invoke_fused_moe_kernel(
elif use_int8_w8a8: elif use_int8_w8a8:
assert B_scale is not None assert B_scale is not None
if block_shape is None: if block_shape is None:
padded_size = padding_size # activation channel-wise int8 quantization
A, A_scale = ops.scaled_int8_quant(A, A_scale) A, A_scale = per_token_quant_int8(A)
else: else:
# activation block-wise int8 quantization
assert len(block_shape) == 2 assert len(block_shape) == 2
block_n, block_k = block_shape[0], block_shape[1] block_n, block_k = block_shape[0], block_shape[1]
A, A_scale = per_token_group_quant_int8(A, block_k) A, A_scale = per_token_group_quant_int8(A, block_k)
...@@ -1060,7 +1092,6 @@ def fused_experts_impl( ...@@ -1060,7 +1092,6 @@ def fused_experts_impl(
use_int8_w8a16=use_int8_w8a16, use_int8_w8a16=use_int8_w8a16,
block_shape=block_shape, block_shape=block_shape,
) )
if activation == "silu": if activation == "silu":
if _is_cuda: if _is_cuda:
silu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2) silu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2)
......
from typing import Any, Dict, List, Optional from typing import Any, Callable, Dict, List, Optional
import torch import torch
from sglang.srt.utils import is_cuda_available from sglang.srt.utils import is_cuda_available, set_weight_attrs
is_cuda = is_cuda_available() is_cuda = is_cuda_available()
if is_cuda: if is_cuda:
...@@ -10,6 +10,7 @@ if is_cuda: ...@@ -10,6 +10,7 @@ if is_cuda:
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
from sglang.srt.distributed import get_tensor_model_parallel_world_size
from sglang.srt.layers.linear import LinearMethodBase from sglang.srt.layers.linear import LinearMethodBase
from sglang.srt.layers.parameter import ChannelQuantScaleParameter, ModelWeightParameter from sglang.srt.layers.parameter import ChannelQuantScaleParameter, ModelWeightParameter
from sglang.srt.layers.quantization.base_config import ( from sglang.srt.layers.quantization.base_config import (
...@@ -55,9 +56,12 @@ class W8A8Int8Config(QuantizationConfig): ...@@ -55,9 +56,12 @@ class W8A8Int8Config(QuantizationConfig):
prefix: str, prefix: str,
) -> Optional["QuantizeMethodBase"]: ) -> Optional["QuantizeMethodBase"]:
from sglang.srt.layers.linear import LinearBase from sglang.srt.layers.linear import LinearBase
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
if isinstance(layer, LinearBase): if isinstance(layer, LinearBase):
return W8A8Int8LinearMethod(self) return W8A8Int8LinearMethod(self)
elif isinstance(layer, FusedMoE):
return W8A8Int8MoEMethod(self)
return None return None
def get_scaled_act_names(self) -> List[str]: def get_scaled_act_names(self) -> List[str]:
...@@ -81,7 +85,7 @@ class W8A8Int8LinearMethod(LinearMethodBase): ...@@ -81,7 +85,7 @@ class W8A8Int8LinearMethod(LinearMethodBase):
input_size: int, input_size: int,
output_size: int, output_size: int,
params_dtype: torch.dtype, params_dtype: torch.dtype,
**extra_weight_attrs **extra_weight_attrs,
): ):
weight_loader = extra_weight_attrs.get("weight_loader") weight_loader = extra_weight_attrs.get("weight_loader")
...@@ -115,3 +119,148 @@ class W8A8Int8LinearMethod(LinearMethodBase): ...@@ -115,3 +119,148 @@ class W8A8Int8LinearMethod(LinearMethodBase):
return int8_scaled_mm( return int8_scaled_mm(
x_q, layer.weight, x_scale, layer.weight_scale, out_dtype=x.dtype, bias=bias x_q, layer.weight, x_scale, layer.weight_scale, out_dtype=x.dtype, bias=bias
) )
class W8A8Int8MoEMethod:
"""MoE method for INT8.
Supports loading INT8 checkpoints with static weight scale and
dynamic/static activation scale.
Also supports loading quantized FP16/BF16 model checkpoints with dynamic
activation scaling. The weight scaling factor will be initialized after
the model weights are loaded.
Args:
quant_config: The quantization config.
"""
def __new__(cls, *args, **kwargs):
from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase
if not hasattr(cls, "_initialized"):
original_init = cls.__init__
new_cls = type(
cls.__name__,
(FusedMoEMethodBase,),
{
"__init__": original_init,
**{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
},
)
obj = super(new_cls, new_cls).__new__(new_cls)
obj.__init__(*args, **kwargs)
return obj
return super().__new__(cls)
def __init__(self, quant_config):
self.quant_config = quant_config
def create_weights(
self,
layer: torch.nn.Module,
num_experts: int,
hidden_size: int,
intermediate_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
tp_size = get_tensor_model_parallel_world_size()
# WEIGHTS
w13_weight = torch.nn.Parameter(
torch.empty(
num_experts, 2 * intermediate_size, hidden_size, dtype=torch.int8
),
requires_grad=False,
)
layer.register_parameter("w13_weight", w13_weight)
set_weight_attrs(w13_weight, extra_weight_attrs)
w2_weight = torch.nn.Parameter(
torch.empty(num_experts, hidden_size, intermediate_size, dtype=torch.int8),
requires_grad=False,
)
layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)
w13_weight_scale = torch.nn.Parameter(
torch.ones(num_experts, 2 * intermediate_size, 1, dtype=torch.float32),
requires_grad=False,
)
w2_weight_scale = torch.nn.Parameter(
torch.ones(num_experts, hidden_size, 1, dtype=torch.float32),
requires_grad=False,
)
layer.register_parameter("w13_weight_scale", w13_weight_scale)
layer.register_parameter("w2_weight_scale", w2_weight_scale)
extra_weight_attrs.update(
{"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}
)
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
w13_input_scale = None
layer.register_parameter("w13_input_scale", w13_input_scale)
w2_input_scale = None
layer.register_parameter("w2_input_scale", w2_input_scale)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
layer.w13_weight = Parameter(layer.w13_weight, requires_grad=False)
layer.w2_weight = Parameter(layer.w2_weight, requires_grad=False)
layer.w13_weight_scale = Parameter(
layer.w13_weight_scale.data, requires_grad=False
)
layer.w2_weight_scale = Parameter(
layer.w2_weight_scale.data, requires_grad=False
)
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
inplace: bool = True,
no_combine: bool = False,
) -> torch.Tensor:
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
from sglang.srt.layers.moe.topk import select_experts
# Expert selection
topk_weights, topk_ids = select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
correction_bias=correction_bias,
)
return fused_experts(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=inplace,
activation=activation,
use_int8_w8a8=True,
w1_scale=(layer.w13_weight_scale),
w2_scale=(layer.w2_weight_scale),
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
no_combine=no_combine,
)
...@@ -1202,10 +1202,9 @@ class DeepseekV2ForCausalLM(nn.Module): ...@@ -1202,10 +1202,9 @@ class DeepseekV2ForCausalLM(nn.Module):
weight, weight_scale, weight_block_size weight, weight_scale, weight_block_size
) )
self_attn.w_scale = scale self_attn.w_scale = scale
if ( if w.dtype == torch.int8:
hasattr(self.quant_config, "weight_block_size") if hasattr(self.quant_config, "weight_block_size"):
and w.dtype == torch.int8 # block-wise int8 need it
):
weight_block_size = self.quant_config.weight_block_size weight_block_size = self.quant_config.weight_block_size
if weight_block_size is not None: if weight_block_size is not None:
assert hasattr(self_attn.kv_b_proj, "weight_scale_inv") assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
...@@ -1214,6 +1213,11 @@ class DeepseekV2ForCausalLM(nn.Module): ...@@ -1214,6 +1213,11 @@ class DeepseekV2ForCausalLM(nn.Module):
w = int8_block_dequant( w = int8_block_dequant(
weight, weight_scale, weight_block_size weight, weight_scale, weight_block_size
).to(torch.bfloat16) ).to(torch.bfloat16)
else:
# channel-wise int8 need it
w = w.to(torch.bfloat16) * self_attn.kv_b_proj.weight_scale.to(
torch.bfloat16
)
w_kc, w_vc = w.unflatten( w_kc, w_vc = w.unflatten(
0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim) 0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1) ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
......
...@@ -61,6 +61,7 @@ suites = { ...@@ -61,6 +61,7 @@ suites = {
"test_w8a8_quantization.py", "test_w8a8_quantization.py",
"test_fp8_kernel.py", "test_fp8_kernel.py",
"test_block_int8.py", "test_block_int8.py",
"test_int8_kernel.py",
"test_reasoning_content.py", "test_reasoning_content.py",
], ],
"nightly": [ "nightly": [
......
import itertools
import unittest
import torch
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
def native_w8a8_per_token_matmul(A, B, As, Bs, output_dtype=torch.float16):
"""Matrix multiplication function that supports per-token input quantization and per-column weight quantization"""
A = A.to(torch.float32)
B = B.to(torch.float32)
assert A.shape[-1] == B.shape[-1], "Dimension mismatch"
assert B.ndim == 2 and B.is_contiguous(), "B must be a 2D contiguous tensor"
# Reshape input
M = A.numel() // A.shape[-1]
B = B.t() # Transpose weight matrix
N, K = B.shape
origin_C_shape = A.shape[:-1] + (K,)
A = A.reshape(M, N)
# As is per-token [M, 1], Bs is per-column [1, K]
C = torch.matmul(A, B) # [M, K]
C = As * C * Bs.view(1, -1) # Broadcast per-column scale
return C.reshape(origin_C_shape).to(output_dtype)
def torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk):
"""This function performs fused moe with per-column int8 quantization using native torch."""
B, D = a.shape
# Perform per-token quantization
a_q, a_s = per_token_quant_int8(a)
# Repeat tokens to match topk
a_q = a_q.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
# Also repeat the scale
a_s = a_s.view(B, -1, 1).repeat(1, topk, 1).reshape(-1, 1) # [B*topk, 1]
out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
# Calculate routing
score = torch.softmax(score, dim=-1, dtype=torch.float32)
topk_weight, topk_ids = torch.topk(score, topk)
topk_weight = topk_weight.view(-1)
topk_ids = topk_ids.view(-1)
# Process each expert
for i in range(w1.shape[0]):
mask = topk_ids == i
if mask.sum():
# First MLP layer: note that a_s is now per-token
inter_out = native_w8a8_per_token_matmul(
a_q[mask], w1[i], a_s[mask], w1_s[i], output_dtype=a.dtype
)
# Activation function
act_out = SiluAndMul().forward_native(inter_out)
# Quantize activation output with per-token
act_out_q, act_out_s = per_token_quant_int8(act_out)
# Second MLP layer
out[mask] = native_w8a8_per_token_matmul(
act_out_q, w2[i], act_out_s, w2_s[i], output_dtype=a.dtype
)
# Apply routing weights and sum
return (
out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype)
).sum(dim=1)
class TestW8A8Int8FusedMoE(unittest.TestCase):
DTYPES = [torch.half, torch.bfloat16]
M = [1, 33]
N = [128, 1024]
K = [256, 4096]
E = [8]
TOP_KS = [2, 6]
BLOCK_SIZE = [[64, 64], [64, 128], [128, 64], [128, 128]]
BLOCK_SIZE = [[128, 128]]
SEEDS = [0]
@classmethod
def setUpClass(cls):
if not torch.cuda.is_available():
raise unittest.SkipTest("CUDA is not available")
torch.set_default_device("cuda")
def _w8a8_int8_fused_moe(self, M, N, K, E, topk, block_size, dtype, seed):
torch.manual_seed(seed)
# Initialize int8 quantization parameters
factor_for_scale = 1e-2
int8_max = 127
int8_min = -128
# Input tensor
# M * K
a = torch.randn((M, K), dtype=dtype) / 10
# Generate int8 weights
w1_fp32 = (torch.rand((E, 2 * N, K), dtype=torch.float32) - 0.5) * 2
w1 = (w1_fp32 * int8_max).clamp(min=int8_min, max=int8_max).to(torch.int8)
w2_fp32 = (torch.rand((E, K, N), dtype=torch.float32) - 0.5) * 2
w2 = (w2_fp32 * int8_max).clamp(min=int8_min, max=int8_max).to(torch.int8)
# Generate scale for each column (per-column quantization)
w1_s = torch.rand(E, 2 * N, device=w1_fp32.device) * factor_for_scale
w2_s = torch.rand(E, K, device=w2_fp32.device) * factor_for_scale
score = torch.randn((M, E), dtype=dtype)
with torch.inference_mode():
ref_out = torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk)
out = fused_moe(
a,
w1,
w2,
score,
topk,
renormalize=False,
use_fp8_w8a8=False, # Not using fp8
use_int8_w8a16=False, # Not using int8-w8a16
use_int8_w8a8=True, # Using int8-w8a8
w1_scale=w1_s,
w2_scale=w2_s,
block_shape=None, # Not using block quantization
)
# Check results
self.assertTrue(
torch.mean(torch.abs(out.to(torch.float32) - ref_out.to(torch.float32)))
/ torch.mean(torch.abs(ref_out.to(torch.float32)))
< 0.05
)
def test_w8a8_int8_fused_moe(self):
for params in itertools.product(
self.M,
self.N,
self.K,
self.E,
self.TOP_KS,
self.BLOCK_SIZE,
self.DTYPES,
self.SEEDS,
):
with self.subTest(
M=params[0],
N=params[1],
K=params[2],
E=params[3],
topk=params[4],
block_size=params[5],
dtype=params[6],
seed=params[7],
):
self._w8a8_int8_fused_moe(*params)
if __name__ == "__main__":
unittest.main(verbosity=2)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment