"git@developer.sourcefind.cn:change/sglang.git" did not exist on "37d83c6e6d8a45fc6e015f7bf828863bf322d547"
Unverified commit 825432fc authored by Jinwu, committed by GitHub

[1/N]Support DeepSeek-R1 w4a8 normal deepep (#8247)


Co-authored-by: Hank Han <hanhan7630@outlook.com>
parent a40229f6
# SPDX-License-Identifier: Apache-2.0
"""Cutlass W4A8 MoE kernel."""
import logging
from typing import Optional
import torch
@@ -11,6 +12,9 @@ from sgl_kernel import (
)
from sglang.srt.layers.moe.ep_moe.kernels import (
deepep_permute_triton_kernel,
deepep_post_reorder_triton_kernel,
deepep_run_moe_deep_preprocess,
post_reorder_triton_kernel_for_cutlass_moe,
pre_reorder_triton_kernel_for_cutlass_moe,
run_moe_ep_preproess,
@@ -201,3 +205,195 @@ def cutlass_w4a8_moe(
BLOCK_SIZE=512,
)
return output
def cutlass_w4a8_moe_deepep_normal(
a: torch.Tensor,
w1_q: torch.Tensor,
w2_q: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids_: torch.Tensor,
a_strides1: torch.Tensor,
b_strides1: torch.Tensor,
c_strides1: torch.Tensor,
a_strides2: torch.Tensor,
b_strides2: torch.Tensor,
c_strides2: torch.Tensor,
s_strides13: torch.Tensor,
s_strides2: torch.Tensor,
expert_offsets: torch.Tensor,
problem_sizes1: torch.Tensor,
problem_sizes2: torch.Tensor,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
This function computes a w4a8-quantized Mixture of Experts (MoE) layer
for DeepEP normal-mode dispatch, using two sets of quantized weights,
w1_q and w2_q, and a top-k gating mechanism. The matrix multiplications
are implemented with CUTLASS grouped gemm.
Parameters:
- a (torch.Tensor): The input tensor to the MoE layer.
Shape: [M, K]
- w1_q (torch.Tensor): The first set of int4-quantized expert weights.
Shape: [num_experts, N * 2, K // 2]
(the weights are passed transposed and int4-packed)
- w2_q (torch.Tensor): The second set of int4-quantized expert weights.
Shape: [num_experts, K, N // 2]
(the weights are passed transposed and int4-packed)
- w1_scale (torch.Tensor): The fp32 scale to dequantize w1_q.
Shape: [num_experts, K // 512, N * 8]
- w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q.
Shape: [num_experts, N // 512, K * 4]
- topk_weights (torch.Tensor): The weights of each token->expert mapping.
- topk_ids_ (torch.Tensor): The expert indices of each token->expert mapping;
entries equal to -1 mark slots that are not routed to a local expert.
- a_strides1 (torch.Tensor): The input strides of the first grouped gemm.
- b_strides1 (torch.Tensor): The weights strides of the first grouped gemm.
- c_strides1 (torch.Tensor): The output strides of the first grouped gemm.
- a_strides2 (torch.Tensor): The input strides of the second grouped gemm.
- b_strides2 (torch.Tensor): The weights strides of the second grouped gemm.
- c_strides2 (torch.Tensor): The output strides of the second grouped gemm.
- s_strides13 (torch.Tensor): The input and scale strides of the first grouped gemm.
- s_strides2 (torch.Tensor): The scale strides of the second grouped gemm.
- expert_offsets (torch.Tensor): Per-expert token offsets, filled by
get_cutlass_w4a8_moe_mm_data.
- problem_sizes1 / problem_sizes2 (torch.Tensor): Grouped-gemm problem shapes
for the first and second gemm, filled by get_cutlass_w4a8_moe_mm_data.
- a1_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize a.
Shape: scalar or [1, K]
- a2_scale (Optional[torch.Tensor]): The optional fp32 scale to
quantize the intermediate result between the gemms.
Shape: scalar or [1, N]
Returns:
- torch.Tensor: The bf16 output tensor after applying the MoE layer.
"""
assert topk_weights.shape == topk_ids_.shape, "topk shape mismatch"
assert w1_q.dtype == torch.int8
assert w2_q.dtype == torch.int8
assert a.shape[1] // 2 == w1_q.shape[2], "Hidden size mismatch w1"
assert w1_q.shape[2] * 2 == w2_q.shape[1], "Hidden size mismatch w2"
assert w1_q.shape[0] == w2_q.shape[0], "Expert number mismatch"
assert w1_q.shape[0] == w1_scale.shape[0], "w1 scales expert number mismatch"
assert w1_q.shape[0] == w2_scale.shape[0], "w2 scales expert number mismatch"
assert a_strides1.shape[0] == w1_q.shape[0], "A Strides 1 expert number mismatch"
assert b_strides1.shape[0] == w1_q.shape[0], "B Strides 1 expert number mismatch"
assert a_strides2.shape[0] == w2_q.shape[0], "A Strides 2 expert number mismatch"
assert b_strides2.shape[0] == w2_q.shape[0], "B Strides 2 expert number mismatch"
assert a1_scale is not None and a2_scale is not None, "w4a8 requires input scales"
num_experts = w1_q.size(0)
m = a.size(0)
k = w1_q.size(2) * 2  # w1_q is transposed and packed
n = w2_q.size(2) * 2  # w2_q is transposed and packed
topk = topk_ids_.size(1)
device = a.device
reorder_topk_ids, src2dst, _ = deepep_run_moe_deep_preprocess(
topk_ids_, num_experts
)
num_total_tokens = reorder_topk_ids.numel()
gateup_input_pre_reorder = torch.empty(
(int(num_total_tokens), a.shape[1]),
device=device,
dtype=a.dtype,
)
deepep_permute_triton_kernel[(a.shape[0],)](
a,
gateup_input_pre_reorder,
src2dst,
topk_ids_.to(torch.int64),
None,
topk,
a.shape[1],
BLOCK_SIZE=512,
)
gateup_input = torch.empty(
gateup_input_pre_reorder.shape, dtype=torch.float8_e4m3fn, device=device
)
sgl_per_tensor_quant_fp8(
gateup_input_pre_reorder, gateup_input, a1_scale.float(), True
)
del gateup_input_pre_reorder
local_topk_ids = torch.where(topk_ids_ == -1, num_experts, topk_ids_).to(torch.int32).contiguous()
a_map = torch.empty((local_topk_ids.numel()), dtype=torch.int32, device=device)
c_map = torch.empty((local_topk_ids.numel()), dtype=torch.int32, device=device)
get_cutlass_w4a8_moe_mm_data(
local_topk_ids,
expert_offsets,
problem_sizes1,
problem_sizes2,
a_map,
c_map,
num_experts,
n,
k,
)
c1 = torch.empty((m * topk, n * 2), device=device, dtype=torch.bfloat16)
c2 = torch.zeros((m * topk, k), device=device, dtype=torch.bfloat16)
cutlass_w4a8_moe_mm(
c1,
gateup_input,
w1_q,
a1_scale.float(),
w1_scale,
expert_offsets[:-1],
problem_sizes1,
a_strides1,
b_strides1,
c_strides1,
s_strides13,
128,
topk,
)
intermediate = torch.empty((m * topk, n), device=device, dtype=torch.bfloat16)
silu_and_mul(c1, intermediate)
intermediate_q = torch.empty(
intermediate.shape, dtype=torch.float8_e4m3fn, device=device
)
sgl_per_tensor_quant_fp8(intermediate, intermediate_q, a2_scale.float(), True)
cutlass_w4a8_moe_mm(
c2,
intermediate_q,
w2_q,
a2_scale.float(),
w2_scale,
expert_offsets[:-1],
problem_sizes2,
a_strides2,
b_strides2,
c_strides2,
s_strides2,
128,
topk,
)
num_tokens = src2dst.shape[0] // topk
output = torch.empty(
(num_tokens, c2.shape[1]),
device=c2.device,
dtype=torch.bfloat16,
)
deepep_post_reorder_triton_kernel[(num_tokens,)](
c2,
output,
src2dst,
topk_ids_,
topk_weights,
topk,
c2.shape[1],
BLOCK_SIZE=512,
)
return output
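# A minimal shape sketch for cutlass_w4a8_moe_deepep_normal (illustration only, not part
# of the kernel path). The sizes below are hypothetical and chosen purely to satisfy the
# shape relationships documented above; real weights, scales, and strides come from the
# W4AFp8 checkpoint loading path.
def _w4a8_shape_sketch():
    import torch

    E, M, K, N = 8, 4, 1024, 512  # experts, tokens, hidden size, intermediate size
    a = torch.randn(M, K, dtype=torch.bfloat16)  # activations
    w1_q = torch.zeros(E, 2 * N, K // 2, dtype=torch.int8)  # packed int4 gate/up proj
    w2_q = torch.zeros(E, K, N // 2, dtype=torch.int8)  # packed int4 down proj
    w1_scale = torch.ones(E, K // 512, 8 * N, dtype=torch.float32)  # group scales for w1_q
    w2_scale = torch.ones(E, N // 512, 4 * K, dtype=torch.float32)  # group scales for w2_q
    # The same relationships the kernel asserts:
    assert a.shape[1] // 2 == w1_q.shape[2]
    assert w1_q.shape[2] * 2 == w2_q.shape[1]
    assert w1_q.shape[0] == w1_scale.shape[0] == w2_scale.shape[0]
    return w1_q, w2_q, w1_scale, w2_scale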
@@ -29,6 +29,7 @@ from sglang.srt.layers.quantization.modelopt_quant import (
CUTEDSL_MOE_NVFP4_DISPATCH,
ModelOptNvFp4FusedMoEMethod,
)
from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config, W4AFp8MoEMethod
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.single_batch_overlap import DownGemmOverlapArgs
from sglang.srt.utils import ceil_div, dispose_tensor, get_bool_env_var, is_hip, is_npu
@@ -96,6 +97,11 @@ class DeepEPMoE(FusedMoE):
self.use_block_quant = getattr(self.quant_method, "block_quant", False)
self.use_fp8_w8a8 = True
self.fp8_dtype = torch.float8_e4m3fn
self.use_w4afp8 = False
elif isinstance(quant_config, W4AFp8Config):
self.use_w4afp8 = True
self.use_fp8_w8a8 = False
self.use_block_quant = False
else:
self.use_fp8_w8a8 = False
self.use_block_quant = False
@@ -142,7 +148,7 @@ class DeepEPMoE(FusedMoE):
self.w13_weight,
(
self.w13_weight_scale_inv
if self.use_block_quant
if self.use_block_quant or self.use_w4afp8
else self.w13_weight_scale
),
)
@@ -150,7 +156,7 @@ class DeepEPMoE(FusedMoE):
self.w2_weight,
(
self.w2_weight_scale_inv
if self.use_block_quant
if self.use_block_quant or self.use_w4afp8
else self.w2_weight_scale
),
)
@@ -210,6 +216,8 @@ class DeepEPMoE(FusedMoE):
assert DispatchOutputChecker.format_is_deepep(dispatch_output)
return self.forward_npu(dispatch_output)
if DispatchOutputChecker.format_is_deepep_normal(dispatch_output):
if self.use_w4afp8:
return self.forward_cutlass_w4afp8(dispatch_output)
assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
return self.forward_deepgemm_contiguous(dispatch_output)
elif DispatchOutputChecker.format_is_deepep_ll(dispatch_output):
@@ -438,6 +446,17 @@ class DeepEPMoE(FusedMoE):
)
return output
def forward_cutlass_w4afp8(
self,
dispatch_output: DeepEPNormalOutput,
):
assert self.moe_runner_config.activation == "silu"
assert isinstance(self.quant_method, W4AFp8MoEMethod)
return self.quant_method.apply_deepep_normal(
layer=self,
dispatch_output=dispatch_output,
)
def forward_deepgemm_masked(
self,
dispatch_output: DeepEPLLOutput,
......
@@ -14,7 +14,12 @@ from sglang.srt.layers.moe.token_dispatcher.base import (
DispatchOutput,
DispatchOutputFormat,
)
from sglang.srt.layers.moe.utils import DeepEPMode, get_deepep_config, is_tbo_enabled
from sglang.srt.layers.moe.utils import (
DeepEPMode,
get_deepep_config,
get_moe_runner_backend,
is_tbo_enabled,
)
from sglang.srt.layers.quantization import deep_gemm_wrapper
from sglang.srt.utils import (
get_bool_env_var,
@@ -340,7 +345,10 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
topk_weights: torch.Tensor,
):
topk_idx = topk_idx.to(torch.int64)
if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
if (
deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
and not get_moe_runner_backend().is_cutlass()
):
# TODO: hard-coded 128 block quant; use fp8 communication
hidden_states = sglang_per_token_group_quant_fp8(
hidden_states,
@@ -386,7 +394,6 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
async_finish=self.async_finish,
allocate_on_comm_stream=previous_event is not None,
)
# FIXME: `handle` should be transmitted with tokens from dispatch to combine.
# However, doing this would incur an unknown synchronization error, but keeping
# `handle` as a member variable works.
@@ -412,7 +419,6 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
expert_alignment=128 if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM else 1,
config=DeepEPConfig.get_instance().normal_dispatch_config,
)
get_global_expert_distribution_recorder().on_deepep_dispatch_normal(
num_recv_tokens_per_expert,
num_tokens_per_rank=num_tokens_per_rank,
......
@@ -55,6 +55,7 @@ class MoeRunnerBackend(Enum):
FLASHINFER_CUTLASS = "flashinfer_cutlass"
FLASHINFER_MXFP4 = "flashinfer_mxfp4"
FLASHINFER_CUTEDSL = "flashinfer_cutedsl"
CUTLASS = "cutlass"
def is_auto(self):
return self == MoeRunnerBackend.AUTO
@@ -80,6 +81,9 @@ class MoeRunnerBackend(Enum):
def is_flashinfer_mxfp4(self):
return self == MoeRunnerBackend.FLASHINFER_MXFP4
def is_cutlass(self):
return self == MoeRunnerBackend.CUTLASS
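# Usage sketch (not part of this change): with --moe-runner-backend cutlass, the value
# parses to MoeRunnerBackend.CUTLASS, and the normal DeepEP dispatcher uses is_cutlass()
# to skip per-token-group fp8 quantization and dispatch bf16 tokens; the w4a8 kernel then
# quantizes per-tensor after permutation. The helper below is illustrative only.
def _cutlass_backend_gate_sketch(enable_jit_deepgemm: bool = True) -> bool:
    backend = MoeRunnerBackend("cutlass")  # same string-to-member mapping used for the CLI flag
    # True means "quantize to fp8 before dispatch"; False for the cutlass w4a8 path.
    return enable_jit_deepgemm and not backend.is_cutlass()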
class DeepEPMode(Enum):
......
from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple
import torch
from torch.nn import Module
@@ -21,8 +21,10 @@ from sglang.srt.utils import is_npu, set_weight_attrs
if TYPE_CHECKING:
from sglang.srt.layers.moe import MoeRunnerConfig
from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, EPMoE
from sglang.srt.layers.moe.token_dispatcher import (
CombineInput,
DeepEPNormalOutput,
StandardDispatchOutput,
)
@@ -326,3 +328,47 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
if self.moe_runner_config.routed_scaling_factor is not None:
output *= self.moe_runner_config.routed_scaling_factor
return StandardCombineInput(hidden_states=output)
def apply_deepep_normal(
self,
layer: DeepEPMoE,
dispatch_output: DeepEPNormalOutput,
) -> torch.Tensor:
from sglang.srt.layers.moe.cutlass_w4a8_moe import (
cutlass_w4a8_moe_deepep_normal,
)
hidden_states, topk_idx, topk_weights = (
dispatch_output.hidden_states,
dispatch_output.topk_idx,
dispatch_output.topk_weights,
)
if isinstance(hidden_states, tuple):
hidden_states = hidden_states[0]
num_tokens = hidden_states.shape[0]
if num_tokens > 0:
return cutlass_w4a8_moe_deepep_normal(
hidden_states,
layer.w13_weight,
layer.w2_weight,
layer.w13_weight_scale_inv,
layer.w2_weight_scale_inv,
topk_weights,
topk_idx,
self.a_strides1,
self.b_strides1,
self.c_strides1,
self.a_strides2,
self.b_strides2,
self.c_strides2,
self.s_strides13,
self.s_strides2,
self.expert_offsets,
self.problem_sizes1,
self.problem_sizes2,
layer.w13_input_scale,
layer.w2_input_scale,
)
else:
return hidden_states
@@ -137,6 +137,7 @@ MOE_RUNNER_BACKEND_CHOICES = [
"flashinfer_cutlass",
"flashinfer_mxfp4",
"flashinfer_cutedsl",
"cutlass",
]
......
@@ -118,5 +118,60 @@ class TestDeepseekV3W4Afp8Mtp(CustomTestCase):
self.assertGreater(avg_spec_accept_length, 2.9)
class TestDeepseekV3W4Afp8DeepepNormal(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = try_cached_model(DEFAULT_DEEPSEEK_W4AFP8_MODEL_FOR_TEST)
cls.base_url = DEFAULT_URL_FOR_TEST
other_args = [
"--tp",
"8",
"--trust-remote-code",
"--ep-size",
"8",
"--cuda-graph-bs",
"256",
"--disable-radix-cache",
"--moe-a2a-backend",
"deepep",
"--deepep-mode",
"normal",
"--dp",
"8",
"--enable-dp-attention",
"--moe-runner-backend",
"cutlass",
]
if not is_in_amd_ci():
other_args += ["--mem-frac", "0.7"]
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=other_args,
)
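# For reference (model path is a placeholder), the server configuration above corresponds
# roughly to this standalone launch command:
#   python3 -m sglang.launch_server --model-path <w4afp8-deepseek-r1-checkpoint> \
#       --tp 8 --ep-size 8 --trust-remote-code \
#       --moe-a2a-backend deepep --deepep-mode normal --moe-runner-backend cutlass \
#       --dp 8 --enable-dp-attention --cuda-graph-bs 256 --disable-radix-cache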
@classmethod
def tearDownClass(cls):
kill_process_tree(cls.process.pid)
def test_gsm8k(
self,
):
args = SimpleNamespace(
num_shots=5,
data_path=None,
num_questions=200,
max_new_tokens=512,
parallel=128,
host="http://127.0.0.1",
port=int(self.base_url.split(":")[-1]),
)
metrics = run_eval_few_shot_gsm8k(args)
print(f"Eval accuracy of GSM8K: {metrics=}")
self.assertGreater(metrics["accuracy"], 0.92)
if __name__ == "__main__":
unittest.main()