change / sglang · Commits · 825432fc

Unverified commit 825432fc, authored Oct 15, 2025 by Jinwu, committed by GitHub Oct 14, 2025

[1/N] Support DeepSeek-R1 w4a8 normal deepep (#8247)

Co-authored-by: Hank Han <hanhan7630@outlook.com>

Parent: a40229f6

Changes: 7 changed files with 334 additions and 7 deletions (+334 / -7)
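For context, the test added in this commit exercises the new path by launching the server with DeepEP in normal mode and the new "cutlass" MoE runner backend. Below is a hedged sketch (not part of this commit) of an equivalent standalone launch; the flags mirror TestDeepseekV3W4Afp8DeepepNormal further down, and the model path is a hypothetical placeholder.

    # Hedged sketch: standalone launch with the same flags the new test passes to the server.
    # The checkpoint path is a hypothetical placeholder, not something from this commit.
    import subprocess

    launch_cmd = [
        "python", "-m", "sglang.launch_server",
        "--model-path", "/path/to/DeepSeek-R1-w4afp8",  # hypothetical checkpoint path
        "--tp", "8", "--ep-size", "8", "--trust-remote-code",
        "--moe-a2a-backend", "deepep", "--deepep-mode", "normal",
        "--moe-runner-backend", "cutlass",  # new choice added to MOE_RUNNER_BACKEND_CHOICES
        "--dp", "8", "--enable-dp-attention",
    ]
    subprocess.run(launch_cmd, check=True)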
python/sglang/srt/layers/moe/cutlass_w4a8_moe.py          +196  -0
python/sglang/srt/layers/moe/ep_moe/layer.py               +21  -2
python/sglang/srt/layers/moe/token_dispatcher/deepep.py    +10  -4
python/sglang/srt/layers/moe/utils.py                       +4  -0
python/sglang/srt/layers/quantization/w4afp8.py            +47  -1
python/sglang/srt/server_args.py                             +1  -0
test/srt/quant/test_w4a8_deepseek_v3.py                    +55  -0
python/sglang/srt/layers/moe/cutlass_w4a8_moe.py @ 825432fc
# SPDX-License-Identifier: Apache-2.0
"""Cutlass W4A8 MoE kernel."""
import logging
from typing import Optional

import torch
...
@@ -11,6 +12,9 @@ from sgl_kernel import (
)
from sglang.srt.layers.moe.ep_moe.kernels import (
    deepep_permute_triton_kernel,
    deepep_post_reorder_triton_kernel,
    deepep_run_moe_deep_preprocess,
    post_reorder_triton_kernel_for_cutlass_moe,
    pre_reorder_triton_kernel_for_cutlass_moe,
    run_moe_ep_preproess,
...
@@ -201,3 +205,195 @@ def cutlass_w4a8_moe(
        BLOCK_SIZE=512,
    )
    return output
def cutlass_w4a8_moe_deepep_normal(
    a: torch.Tensor,
    w1_q: torch.Tensor,
    w2_q: torch.Tensor,
    w1_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids_: torch.Tensor,
    a_strides1: torch.Tensor,
    b_strides1: torch.Tensor,
    c_strides1: torch.Tensor,
    a_strides2: torch.Tensor,
    b_strides2: torch.Tensor,
    c_strides2: torch.Tensor,
    s_strides13: torch.Tensor,
    s_strides2: torch.Tensor,
    expert_offsets: torch.Tensor,
    problem_sizes1: torch.Tensor,
    problem_sizes2: torch.Tensor,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    This function computes a w4a8-quantized Mixture of Experts (MoE) layer
    using two sets of quantized weights, w1_q and w2_q, and a top-k gating
    mechanism. The matrix multiplications are implemented with CUTLASS
    grouped gemm.

    Parameters:
    - a (torch.Tensor): The input tensor to the MoE layer.
        Shape: [M, K]
    - w1_q (torch.Tensor): The first set of int4-quantized expert weights.
        Shape: [num_experts, N * 2, K // 2]
        (the weights are passed transposed and int4-packed)
    - w2_q (torch.Tensor): The second set of int4-quantized expert weights.
        Shape: [num_experts, K, N // 2]
        (the weights are passed transposed and int4-packed)
    - w1_scale (torch.Tensor): The fp32 scale to dequantize w1_q.
        Shape: [num_experts, K // 512, N * 8]
    - w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q.
        Shape: [num_experts, N // 512, K * 4]
    - topk_weights (torch.Tensor): The weights of each token->expert mapping.
    - topk_ids_ (torch.Tensor): The expert indices of each token->expert mapping.
    - a_strides1 (torch.Tensor): The input strides of the first grouped gemm.
    - b_strides1 (torch.Tensor): The weights strides of the first grouped gemm.
    - c_strides1 (torch.Tensor): The output strides of the first grouped gemm.
    - a_strides2 (torch.Tensor): The input strides of the second grouped gemm.
    - b_strides2 (torch.Tensor): The weights strides of the second grouped gemm.
    - c_strides2 (torch.Tensor): The output strides of the second grouped gemm.
    - s_strides13 (torch.Tensor): The input and scale strides of the first grouped gemm.
    - s_strides2 (torch.Tensor): The scale strides of the second grouped gemm.
    - a1_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize a.
        Shape: scalar or [1, K]
    - a2_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize
        the intermediate result between the gemms.
        Shape: scalar or [1, N]

    Returns:
    - torch.Tensor: The bf16 output tensor after applying the MoE layer.
    """
    assert topk_weights.shape == topk_ids_.shape, "topk shape mismatch"
    assert w1_q.dtype == torch.int8
    assert w2_q.dtype == torch.int8
    assert a.shape[1] // 2 == w1_q.shape[2], "Hidden size mismatch w1"
    assert w1_q.shape[2] * 2 == w2_q.shape[1], "Hidden size mismatch w2"
    assert w1_q.shape[0] == w2_q.shape[0], "Expert number mismatch"
    assert w1_q.shape[0] == w1_scale.shape[0], "w1 scales expert number mismatch"
    assert w1_q.shape[0] == w2_scale.shape[0], "w2 scales expert number mismatch"
    assert a_strides1.shape[0] == w1_q.shape[0], "A Strides 1 expert number mismatch"
    assert b_strides1.shape[0] == w1_q.shape[0], "B Strides 1 expert number mismatch"
    assert a_strides2.shape[0] == w2_q.shape[0], "A Strides 2 expert number mismatch"
    assert b_strides2.shape[0] == w2_q.shape[0], "B Strides 2 expert number mismatch"

    num_experts = w1_q.size(0)
    m = a.size(0)
    k = w1_q.size(2) * 2  # w1_q is transposed and packed
    n = w2_q.size(2) * 2  # w2_q is transposed and packed
    topk = topk_ids_.size(1)
    device = a.device

    reorder_topk_ids, src2dst, _ = deepep_run_moe_deep_preprocess(topk_ids_, num_experts)
    num_total_tokens = reorder_topk_ids.numel()
    gateup_input_pre_reorder = torch.empty(
        (int(num_total_tokens), a.shape[1]),
        device=device,
        dtype=a.dtype,
    )
    deepep_permute_triton_kernel[(a.shape[0],)](
        a,
        gateup_input_pre_reorder,
        src2dst,
        topk_ids_.to(torch.int64),
        None,
        topk,
        a.shape[1],
        BLOCK_SIZE=512,
    )
    gateup_input = torch.empty(
        gateup_input_pre_reorder.shape, dtype=torch.float8_e4m3fn, device=device
    )
    sgl_per_tensor_quant_fp8(gateup_input_pre_reorder, gateup_input, a1_scale.float(), True)
    del gateup_input_pre_reorder

    local_topk_ids = topk_ids_
    local_topk_ids = (
        torch.where(local_topk_ids == -1, num_experts, topk_ids_).to(torch.int32)
    ).contiguous()

    a_map = torch.empty((local_topk_ids.numel()), dtype=torch.int32, device=device)
    c_map = torch.empty((local_topk_ids.numel()), dtype=torch.int32, device=device)
    get_cutlass_w4a8_moe_mm_data(
        local_topk_ids,
        expert_offsets,
        problem_sizes1,
        problem_sizes2,
        a_map,
        c_map,
        num_experts,
        n,
        k,
    )
    c1 = torch.empty((m * topk, n * 2), device=device, dtype=torch.bfloat16)
    c2 = torch.zeros((m * topk, k), device=device, dtype=torch.bfloat16)

    cutlass_w4a8_moe_mm(
        c1,
        gateup_input,
        w1_q,
        a1_scale.float(),
        w1_scale,
        expert_offsets[:-1],
        problem_sizes1,
        a_strides1,
        b_strides1,
        c_strides1,
        s_strides13,
        128,
        topk,
    )

    intermediate = torch.empty((m * topk, n), device=device, dtype=torch.bfloat16)
    silu_and_mul(c1, intermediate)

    intermediate_q = torch.empty(
        intermediate.shape, dtype=torch.float8_e4m3fn, device=device
    )
    sgl_per_tensor_quant_fp8(intermediate, intermediate_q, a2_scale.float(), True)

    cutlass_w4a8_moe_mm(
        c2,
        intermediate_q,
        w2_q,
        a2_scale.float(),
        w2_scale,
        expert_offsets[:-1],
        problem_sizes2,
        a_strides2,
        b_strides2,
        c_strides2,
        s_strides2,
        128,
        topk,
    )

    num_tokens = src2dst.shape[0] // topk
    output = torch.empty(
        (num_tokens, c2.shape[1]),
        device=c2.device,
        dtype=torch.bfloat16,
    )
    deepep_post_reorder_triton_kernel[(num_tokens,)](
        c2,
        output,
        src2dst,
        topk_ids_,
        topk_weights,
        topk,
        c2.shape[1],
        BLOCK_SIZE=512,
    )
    return output
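To make the documented shapes concrete, here is a hedged, CPU-only sketch (not part of the commit) that builds dummy tensors with the shapes the docstring describes and checks the same invariants the function asserts up front. All dimension values below are made up for illustration.

    import torch

    # Illustrative sizes only (hypothetical); K and N must respect the packing relations below.
    num_experts, M, K, N, topk = 8, 4, 7168, 2048, 8

    a = torch.empty(M, K, dtype=torch.bfloat16)                        # MoE input [M, K]
    w1_q = torch.empty(num_experts, N * 2, K // 2, dtype=torch.int8)   # int4-packed gate/up weights
    w2_q = torch.empty(num_experts, K, N // 2, dtype=torch.int8)       # int4-packed down weights
    w1_scale = torch.empty(num_experts, K // 512, N * 8, dtype=torch.float32)
    w2_scale = torch.empty(num_experts, N // 512, K * 4, dtype=torch.float32)
    topk_ids = torch.randint(0, num_experts, (M, topk), dtype=torch.int32)
    topk_weights = torch.rand(M, topk, dtype=torch.float32)

    # The same consistency checks cutlass_w4a8_moe_deepep_normal performs before launching
    # the grouped GEMMs:
    assert topk_weights.shape == topk_ids.shape
    assert a.shape[1] // 2 == w1_q.shape[2]        # hidden size matches packed w1
    assert w1_q.shape[2] * 2 == w2_q.shape[1]      # w1 and w2 agree on the hidden size
    assert w1_q.shape[0] == w2_q.shape[0] == w1_scale.shape[0] == w2_scale.shape[0]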
python/sglang/srt/layers/moe/ep_moe/layer.py @ 825432fc
...
@@ -29,6 +29,7 @@ from sglang.srt.layers.quantization.modelopt_quant import (
    CUTEDSL_MOE_NVFP4_DISPATCH,
    ModelOptNvFp4FusedMoEMethod,
)
from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config, W4AFp8MoEMethod
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.single_batch_overlap import DownGemmOverlapArgs
from sglang.srt.utils import ceil_div, dispose_tensor, get_bool_env_var, is_hip, is_npu
...
@@ -96,6 +97,11 @@ class DeepEPMoE(FusedMoE):
            self.use_block_quant = getattr(self.quant_method, "block_quant", False)
            self.use_fp8_w8a8 = True
            self.fp8_dtype = torch.float8_e4m3fn
            self.use_w4afp8 = False
        elif isinstance(quant_config, W4AFp8Config):
            self.use_w4afp8 = True
            self.use_fp8_w8a8 = False
            self.use_block_quant = False
        else:
            self.use_fp8_w8a8 = False
            self.use_block_quant = False
...
@@ -142,7 +148,7 @@ class DeepEPMoE(FusedMoE):
            self.w13_weight,
            (
                self.w13_weight_scale_inv
-               if self.use_block_quant
+               if self.use_block_quant or self.use_w4afp8
                else self.w13_weight_scale
            ),
        )
...
@@ -150,7 +156,7 @@ class DeepEPMoE(FusedMoE):
            self.w2_weight,
            (
                self.w2_weight_scale_inv
-               if self.use_block_quant
+               if self.use_block_quant or self.use_w4afp8
                else self.w2_weight_scale
            ),
        )
...
@@ -210,6 +216,8 @@ class DeepEPMoE(FusedMoE):
            assert DispatchOutputChecker.format_is_deepep(dispatch_output)
            return self.forward_npu(dispatch_output)
        if DispatchOutputChecker.format_is_deepep_normal(dispatch_output):
            if self.use_w4afp8:
                return self.forward_cutlass_w4afp8(dispatch_output)
            assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
            return self.forward_deepgemm_contiguous(dispatch_output)
        elif DispatchOutputChecker.format_is_deepep_ll(dispatch_output):
...
@@ -438,6 +446,17 @@ class DeepEPMoE(FusedMoE):
        )
        return output

    def forward_cutlass_w4afp8(
        self,
        dispatch_output: DeepEPNormalOutput,
    ):
        assert self.moe_runner_config.activation == "silu"
        assert isinstance(self.quant_method, W4AFp8MoEMethod)
        return self.quant_method.apply_deepep_normal(
            layer=self,
            dispatch_output=dispatch_output,
        )

    def forward_deepgemm_masked(
        self,
        dispatch_output: DeepEPLLOutput,
...
python/sglang/srt/layers/moe/token_dispatcher/deepep.py @ 825432fc
...
@@ -14,7 +14,12 @@ from sglang.srt.layers.moe.token_dispatcher.base import (
    DispatchOutput,
    DispatchOutputFormat,
)
-from sglang.srt.layers.moe.utils import DeepEPMode, get_deepep_config, is_tbo_enabled
+from sglang.srt.layers.moe.utils import (
+    DeepEPMode,
+    get_deepep_config,
+    get_moe_runner_backend,
+    is_tbo_enabled,
+)
from sglang.srt.layers.quantization import deep_gemm_wrapper
from sglang.srt.utils import (
    get_bool_env_var,
...
@@ -340,7 +345,10 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
        topk_weights: torch.Tensor,
    ):
        topk_idx = topk_idx.to(torch.int64)
-       if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
+       if (
+           deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
+           and not get_moe_runner_backend().is_cutlass()
+       ):
            # TODO: hard-coded 128 block quant, use fp8 communication
            hidden_states = sglang_per_token_group_quant_fp8(
                hidden_states,
...
@@ -386,7 +394,6 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
            async_finish=self.async_finish,
            allocate_on_comm_stream=previous_event is not None,
        )
        # FIXME: `handle` should be transmitted with tokens from dispatch to combine.
        # However, doing this would incur an unknown synchronization error, but keeping
        # `handle` as a member variable works.
...
@@ -412,7 +419,6 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
            expert_alignment=128 if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM else 1,
            config=DeepEPConfig.get_instance().normal_dispatch_config,
        )
        get_global_expert_distribution_recorder().on_deepep_dispatch_normal(
            num_recv_tokens_per_expert,
            num_tokens_per_rank=num_tokens_per_rank,
...
python/sglang/srt/layers/moe/utils.py @ 825432fc
...
@@ -55,6 +55,7 @@ class MoeRunnerBackend(Enum):
    FLASHINFER_CUTLASS = "flashinfer_cutlass"
    FLASHINFER_MXFP4 = "flashinfer_mxfp4"
    FLASHINFER_CUTEDSL = "flashinfer_cutedsl"
    CUTLASS = "cutlass"

    def is_auto(self):
        return self == MoeRunnerBackend.AUTO
...
@@ -80,6 +81,9 @@ class MoeRunnerBackend(Enum):
    def is_flashinfer_mxfp4(self):
        return self == MoeRunnerBackend.FLASHINFER_MXFP4

    def is_cutlass(self):
        return self == MoeRunnerBackend.CUTLASS


class DeepEPMode(Enum):
...
python/sglang/srt/layers/quantization/w4afp8.py @ 825432fc
from __future__ import annotations

import logging
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple

import torch
from torch.nn import Module
...
@@ -21,8 +21,10 @@ from sglang.srt.utils import is_npu, set_weight_attrs
if TYPE_CHECKING:
    from sglang.srt.layers.moe import MoeRunnerConfig
    from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, EPMoE
    from sglang.srt.layers.moe.token_dispatcher import (
        CombineInput,
        DeepEPNormalOutput,
        StandardDispatchOutput,
    )
...
@@ -326,3 +328,47 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
        if self.moe_runner_config.routed_scaling_factor is not None:
            output *= self.moe_runner_config.routed_scaling_factor
        return StandardCombineInput(hidden_states=output)

    def apply_deepep_normal(
        self,
        layer: DeepEPMoE,
        dispatch_output: DeepEPNormalOutput,
    ) -> torch.Tensor:
        from sglang.srt.layers.moe.cutlass_w4a8_moe import (
            cutlass_w4a8_moe_deepep_normal,
        )

        hidden_states, topk_idx, topk_weights = (
            dispatch_output.hidden_states,
            dispatch_output.topk_idx,
            dispatch_output.topk_weights,
        )
        if isinstance(hidden_states, tuple):
            hidden_states = hidden_states[0]
        num_tokens = hidden_states.shape[0]
        if num_tokens > 0:
            return cutlass_w4a8_moe_deepep_normal(
                hidden_states,
                layer.w13_weight,
                layer.w2_weight,
                layer.w13_weight_scale_inv,
                layer.w2_weight_scale_inv,
                topk_weights,
                topk_idx,
                self.a_strides1,
                self.b_strides1,
                self.c_strides1,
                self.a_strides2,
                self.b_strides2,
                self.c_strides2,
                self.s_strides13,
                self.s_strides2,
                self.expert_offsets,
                self.problem_sizes1,
                self.problem_sizes2,
                layer.w13_input_scale,
                layer.w2_input_scale,
            )
        else:
            return hidden_states
python/sglang/srt/server_args.py @ 825432fc
...
@@ -137,6 +137,7 @@ MOE_RUNNER_BACKEND_CHOICES = [
    "flashinfer_cutlass",
    "flashinfer_mxfp4",
    "flashinfer_cutedsl",
    "cutlass",
]
...
test/srt/quant/test_w4a8_deepseek_v3.py @ 825432fc
...
@@ -118,5 +118,60 @@ class TestDeepseekV3W4Afp8Mtp(CustomTestCase):
        self.assertGreater(avg_spec_accept_length, 2.9)


class TestDeepseekV3W4Afp8DeepepNormal(CustomTestCase):
    @classmethod
    def setUpClass(cls):
        cls.model = try_cached_model(DEFAULT_DEEPSEEK_W4AFP8_MODEL_FOR_TEST)
        cls.base_url = DEFAULT_URL_FOR_TEST
        other_args = [
            "--tp",
            "8",
            "--trust-remote-code",
            "--ep-size",
            "8",
            "--cuda-graph-bs",
            "256",
            "--disable-radix-cache",
            "--moe-a2a-backend",
            "deepep",
            "--deepep-mode",
            "normal",
            "--dp",
            "8",
            "--enable-dp-attention",
            "--moe-runner-backend",
            "cutlass",
        ]
        if not is_in_amd_ci():
            other_args += ["--mem-frac", "0.7"]
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=other_args,
        )

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)

    def test_gsm8k(self):
        args = SimpleNamespace(
            num_shots=5,
            data_path=None,
            num_questions=200,
            max_new_tokens=512,
            parallel=128,
            host="http://127.0.0.1",
            port=int(self.base_url.split(":")[-1]),
        )
        metrics = run_eval_few_shot_gsm8k(args)
        print(f"Eval accuracy of GSM8K: {metrics=}")
        self.assertGreater(metrics["accuracy"], 0.92)


if __name__ == "__main__":
    unittest.main()