change / sglang · Commits

Commit c0352f4a: adapt w4a8 marlin deepep dp ep
Authored Oct 22, 2025 by lizhigong; committed by maxiao1 on Oct 25, 2025
Parent: 32b1ccaf

Showing 2 changed files with 383 additions and 32 deletions:
    python/sglang/srt/layers/moe/ep_moe/layer.py             +357 -27
    python/sglang/srt/layers/moe/token_dispatcher/deepep.py   +26 -5
python/sglang/srt/layers/moe/ep_moe/layer.py (mode 100644 → 100755)
...
@@ -3,6 +3,7 @@ from __future__ import annotations

import logging
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

from sglang.srt.layers.quantization.slimquant_w4a8_marlin import SlimQuantW4A8Int8MarlinConfig

import torch

from sglang.srt import single_batch_overlap
...
@@ -54,7 +55,286 @@ if _use_aiter:

logger = logging.getLogger(__name__)

class DeepEPMoE(FusedMoE):

# TODO(kaixih@nvidia): ideally we should merge this logic into
# `fill_gateup_input_triton_kernel` to directly generate e8m0 scale.
@torch.compile
def _cast_to_e8m0_with_rounding_up(x: torch.Tensor) -> torch.Tensor:
    temp = x.to(torch.float32).view(torch.int32)
    exp = torch.bitwise_right_shift(temp, 23)
    mant = torch.bitwise_and(temp, 0x7FFFFF)
    is_ru = torch.logical_and(
        torch.logical_and((mant > 0), (exp != 0xFE)),
        ~torch.logical_and((exp == 0), (mant <= 0x400000)),
    )
    exp = torch.where(is_ru, exp + 1, exp)
    new_x = exp.to(torch.uint8).view(torch.int)
    return new_x.transpose(1, 2).contiguous().transpose(1, 2)
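For intuition, here is a minimal reference (not part of the commit; the helper name below is hypothetical) of what the rounding-up UE8M0 cast above does to a single positive scale: only the exponent is kept, and any nonzero mantissa bumps it up, so the stored scale is the next power of two at or above the input and dequantization never under-scales.

import math

def _e8m0_round_up_reference(scale: float) -> float:
    # Keep only a power-of-two scale, rounding up when the input is not
    # already a power of two (mirrors the mantissa-nonzero check above).
    return 2.0 ** math.ceil(math.log2(scale))

# 0.3 is not a power of two, so it rounds up to 0.5;
# 0.25 already is one, so it is unchanged.
assert _e8m0_round_up_reference(0.30) == 0.5
assert _e8m0_round_up_reference(0.25) == 0.25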
class EPMoE(FusedMoE):
    """
    MoE Expert Parallel Impl
    """

    def __init__(
        self,
        num_experts: int,
        top_k: int,
        hidden_size: int,
        intermediate_size: int,
        layer_id: int,
        num_fused_shared_experts: int = 0,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        activation: str = "silu",
        routed_scaling_factor: Optional[float] = None,
        gemm1_alpha: Optional[float] = None,
        gemm1_clamp_limit: Optional[float] = None,
        with_bias: bool = False,
    ):
        super().__init__(
            num_experts=num_experts,
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            num_fused_shared_experts=num_fused_shared_experts,
            layer_id=layer_id,
            top_k=top_k,
            params_dtype=params_dtype,
            quant_config=quant_config,
            prefix=prefix,
            activation=activation,
            # apply_router_weight_on_input=apply_router_weight_on_input,
            routed_scaling_factor=routed_scaling_factor,
            gemm1_alpha=gemm1_alpha,
            gemm1_clamp_limit=gemm1_clamp_limit,
            with_bias=with_bias,
        )

        self.intermediate_size = intermediate_size

        if isinstance(quant_config, Fp8Config):
            self.use_block_quant = getattr(self.quant_method, "block_quant", False)
            self.block_shape = (
                self.quant_method.quant_config.weight_block_size
                if self.use_block_quant
                else None
            )
            self.use_fp8_w8a8 = True
            self.fp8_dtype = torch.float8_e4m3fn
            self.activation_scheme = quant_config.activation_scheme
            self.use_w4a8_marlin = False
        elif isinstance(quant_config, SlimQuantW4A8Int8MarlinConfig):
            self.use_block_quant = getattr(self.quant_method, "block_quant", False)
            self.block_shape = (
                self.quant_method.quant_config.weight_block_size
                if self.use_block_quant
                else None
            )
            self.use_fp8_w8a8 = False
            self.activation_scheme = None
            self.use_w4a8_marlin = True
        else:
            self.use_fp8_w8a8 = False
            self.use_block_quant = False
            self.block_shape = None
            self.activation_scheme = None
            self.use_w4a8_marlin = False
    def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
        if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8:
            return self.forward_deepgemm(hidden_states, topk_output)
        else:
            return super().forward(hidden_states, topk_output)

    def forward_deepgemm(
        self,
        hidden_states: torch.Tensor,
        topk_output: TopKOutput,
    ):
        self.w13_weight_fp8 = (
            self.w13_weight,
            (
                self.w13_weight_scale_inv
                if self.use_block_quant
                else self.w13_weight_scale
            ),
        )
        self.w2_weight_fp8 = (
            self.w2_weight,
            self.w2_weight_scale_inv
            if self.use_block_quant
            else self.w2_weight_scale,
        )

        assert self.quant_method is not None
        assert self.moe_runner_config.activation == "silu"
        hidden_states_shape = hidden_states.shape
        hidden_states_dtype = hidden_states.dtype
        hidden_states_device = hidden_states.device
        topk_weights, topk_ids, _ = topk_output

        if not self.use_block_quant:
            # Convert per-tensor quant to per-block quant by repeating scales for forward_deepgemm
            scale_block_size = 128
            w13_weight_scale_n = 2 * (
                (self.intermediate_size + scale_block_size - 1) // scale_block_size
            )
            w13_weight_scale_k = (
                hidden_states_shape[-1] + scale_block_size - 1
            ) // scale_block_size
            w13_weight_scale = (
                self.w13_weight_scale.unsqueeze(1)
                .repeat_interleave(w13_weight_scale_n, dim=1)
                .unsqueeze(2)
                .repeat_interleave(w13_weight_scale_k, dim=2)
            )
            self.w13_weight_fp8 = (
                self.w13_weight,
                w13_weight_scale,
            )
            w2_weight_scale_n = (
                hidden_states_shape[-1] + scale_block_size - 1
            ) // scale_block_size
            w2_weight_scale_k = (
                self.intermediate_size + scale_block_size - 1
            ) // scale_block_size
            w2_weight_scale = (
                self.w2_weight_scale.unsqueeze(1)
                .repeat_interleave(w2_weight_scale_n, dim=1)
                .unsqueeze(2)
                .repeat_interleave(w2_weight_scale_k, dim=2)
            )
            self.w2_weight_fp8 = (
                self.w2_weight,
                w2_weight_scale,
            )
        # PreReorder
        m_max, masked_m, expected_m, src2dst, gateup_input, gateup_input_scale = (
            moe_ep_deepgemm_preprocess(
                topk_ids,
                self.num_experts,
                hidden_states,
                self.top_k,
                self.start_expert_id,
                self.end_expert_id,
                self.block_shape,
            )
        )

        dispose_tensor(hidden_states)

        if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
            b, s_mn, s_k = gateup_input_scale.shape
            assert (
                s_mn % 4 == 0 and s_k % 4 == 0
            ), f"scales must be aligned to 4, but got ({b}, {s_mn}, {s_k})"

        # GroupGemm-0
        gateup_input_fp8 = (
            gateup_input,
            (
                _cast_to_e8m0_with_rounding_up(gateup_input_scale)
                if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
                else deep_gemm_wrapper.get_mn_major_tma_aligned_tensor(
                    gateup_input_scale
                )
            ),
        )
        num_groups, m, k = gateup_input_fp8[0].size()
        n = self.w13_weight.size(1)
        gateup_output = torch.empty(
            (num_groups, m, n), device=hidden_states_device, dtype=torch.bfloat16
        )
        deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
            gateup_input_fp8,
            self.w13_weight_fp8,
            gateup_output,
            masked_m,
            expected_m,
        )
        del gateup_input
        del gateup_input_fp8

        # Act
        down_input = torch.empty(
            (
                gateup_output.shape[0],
                gateup_output.shape[1],
                gateup_output.shape[2] // 2,
            ),
            device=hidden_states_device,
            dtype=self.fp8_dtype,
        )
        scale_block_size = 128
        down_input_scale = torch.empty(
            (
                gateup_output.shape[0],
                gateup_output.shape[1],
                gateup_output.shape[2] // 2 // scale_block_size,
            ),
            device=hidden_states_device,
            dtype=torch.float32,
        )
        silu_and_mul_masked_post_quant_fwd(
            gateup_output,
            down_input,
            down_input_scale,
            scale_block_size,
            masked_m,
            scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
        )
        del gateup_output

        # GroupGemm-1
        n = self.w2_weight.size(1)
        down_input_fp8 = (
            down_input,
            (
                down_input_scale
                if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
                else deep_gemm_wrapper.get_mn_major_tma_aligned_tensor(
                    down_input_scale
                )
            ),
        )
        down_output = torch.empty(
            (num_groups, m, n), device=hidden_states_device, dtype=torch.bfloat16
        )
        deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
            down_input_fp8,
            self.w2_weight_fp8,
            down_output,
            masked_m,
            expected_m,
        )
        del down_input
        del down_input_fp8

        # PostReorder
        output = torch.empty(
            hidden_states_shape, dtype=hidden_states_dtype, device=hidden_states_device
        )
        post_reorder_triton_kernel[(hidden_states_shape[0],)](
            down_output,
            output,
            src2dst,
            topk_ids,
            topk_weights,
            self.start_expert_id,
            self.end_expert_id,
            self.top_k,
            hidden_states_shape[1],
            m_max * self.start_expert_id,
            BLOCK_SIZE=512,
        )
        if self.moe_runner_config.routed_scaling_factor is not None:
            output *= self.moe_runner_config.routed_scaling_factor
        return output
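The per-tensor-to-per-block conversion at the top of forward_deepgemm simply repeats each expert's single scale over the (n, k) block grid so the block-wise DeepGEMM kernel can consume it unchanged. Below is a standalone sketch of that shape transformation with toy sizes, assuming a per-tensor scale of shape (num_experts,); it is illustrative only and not part of the commit.

import torch

num_experts, intermediate_size, hidden_size = 4, 512, 1024
scale_block_size = 128

# Per-tensor quantization: one scale per expert, shape (num_experts,).
w13_weight_scale = torch.rand(num_experts)

# Block grid of the fused gate+up projection: 2 * ceil(I / B) by ceil(H / B).
w13_weight_scale_n = 2 * ((intermediate_size + scale_block_size - 1) // scale_block_size)
w13_weight_scale_k = (hidden_size + scale_block_size - 1) // scale_block_size

w13_block_scale = (
    w13_weight_scale.unsqueeze(1)
    .repeat_interleave(w13_weight_scale_n, dim=1)
    .unsqueeze(2)
    .repeat_interleave(w13_weight_scale_k, dim=2)
)

# Every block of a given expert now carries that expert's per-tensor scale.
assert w13_block_scale.shape == (num_experts, w13_weight_scale_n, w13_weight_scale_k)
assert torch.all(w13_block_scale[0] == w13_weight_scale[0])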
class DeepEPMoE(EPMoE):
    """
    MoE Expert Parallel Impl based on DeepEP (https://github.com/deepseek-ai/DeepEP/tree/main)
    Mooncake EP shares the same class, as they expose the same interface.
    """
...
@@ -106,11 +386,28 @@ class DeepEPMoE(FusedMoE):
        self.deepep_mode = get_deepep_mode()

        if self.deepep_mode.enable_low_latency() and not _is_npu:
            # NPU supports low_latency deepep without deepgemm
            assert (
                deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
            ), f"DeepEP {self.deepep_mode} mode requires deep_gemm"

        # TODO: move to the beginning of the file
        from sglang.srt.distributed.parallel_state import get_tp_group
        from sglang.srt.two_batch_overlap import MaybeTboDeepEPDispatcher

        self.deepep_dispatcher = MaybeTboDeepEPDispatcher(
            group=get_tp_group().device_group,
            router_topk=self.top_k,
            permute_fusion=True,
            num_experts=self.num_experts,
            num_local_experts=self.num_local_experts,
            hidden_size=hidden_size,
            params_dtype=params_dtype,
            deepep_mode=self.deepep_mode,
            async_finish=True,  # TODO
            return_recv_hook=True,
        )

        # if self.deepep_mode.enable_low_latency() and not _is_npu:
        #     # NPU supports low_latency deepep without deepgemm
        #     assert (
        #         deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
        #     ), f"DeepEP {self.deepep_mode} mode requires deep_gemm"

        if _use_aiter:
            # expert_mask is of size (self.num_local_experts + 1),
            # the extra 1 is for invalid rank_id (in original deepep, the invalid rank_id is -1, but aiter does not allow -1, we use a mask to make those ids invalid)
...
@@ -124,23 +421,23 @@ class DeepEPMoE(FusedMoE):
            )
            # the last one is invalid rank_id
            self.expert_mask[:-1] = 1
        elif not _is_npu:
            self.w13_weight_fp8 = (
                self.w13_weight,
                (
                    self.w13_weight_scale_inv
                    if self.use_block_quant or self.use_w4afp8
                    else self.w13_weight_scale
                ),
            )
            self.w2_weight_fp8 = (
                self.w2_weight,
                (
                    self.w2_weight_scale_inv
                    if self.use_block_quant or self.use_w4afp8
                    else self.w2_weight_scale
                ),
            )
        # elif not _is_npu:
        #     self.w13_weight_fp8 = (
        #         self.w13_weight,
        #         (
        #             self.w13_weight_scale_inv
        #             if self.use_block_quant
        #             else self.w13_weight_scale
        #         ),
        #     )
        #     self.w2_weight_fp8 = (
        #         self.w2_weight,
        #         (
        #             self.w2_weight_scale_inv
        #             if self.use_block_quant
        #             else self.w2_weight_scale
        #         ),
        #     )
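The _use_aiter branch above only hints at how the extra mask slot is used, because the tensor allocation itself is collapsed in this hunk. A minimal standalone sketch of the idea, with assumed sizes and dtype (not taken from the commit):

import torch

num_local_experts = 8

# One extra slot: real experts get mask 1, the trailing dummy expert stays 0.
expert_mask = torch.zeros(num_local_experts + 1, dtype=torch.int32)
expert_mask[:-1] = 1

# deepep marks unrouted slots with expert id -1; aiter cannot take -1, so those
# ids are remapped onto the masked dummy slot instead.
topk_ids = torch.tensor([0, 3, -1, 7, -1])
topk_ids = topk_ids.masked_fill(topk_ids < 0, num_local_experts)
assert expert_mask[topk_ids].tolist() == [1, 1, 0, 1, 0]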
    def forward(
        self,
...
@@ -187,10 +484,15 @@ class DeepEPMoE(FusedMoE):
            assert DispatchOutputChecker.format_is_deepep(dispatch_output)
            return self.forward_npu(dispatch_output)
        if DispatchOutputChecker.format_is_deepep_normal(dispatch_output):
            if self.use_w4afp8:
                return self.forward_cutlass_w4afp8(dispatch_output)
            assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
            return self.forward_deepgemm_contiguous(dispatch_output)
            #assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
            if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8:
                return self.forward_deepgemm_contiguous(dispatch_output)
            elif self.use_w4a8_marlin:
                return self.forward_deepgemm_w4a8_marlin_contiguous(dispatch_output)
            else:
                raise ValueError(f"Dispatch output is not supported")
        elif DispatchOutputChecker.format_is_deepep_ll(dispatch_output):
            if (
                get_moe_runner_backend().is_flashinfer_cutedsl()
...
@@ -255,6 +557,34 @@ class DeepEPMoE(FusedMoE):
            expert_mask=self.expert_mask,
        )

    def forward_deepgemm_w4a8_marlin_contiguous(
        self,
        dispatch_output: DeepEPNormalOutput,
    ):
        hidden_states_int8, topk_idx, topk_weights, num_recv_tokens_per_expert = (
            dispatch_output
        )
        assert self.quant_method is not None
        assert self.moe_runner_config.activation == "silu"
        # if num_recv_tokens_per_expert is None:
        return hidden_states_int8.bfloat16()
        # expert_output = self.quant_method.apply_ep(
        #     layer=self,
        #     x=dispatch_output,
        #     topk_weights=topk_weights,
        #     topk_ids=topk_idx,
        #     global_num_experts=self.global_num_experts,
        #     expert_map=self.expert_map,
        #     activation=self.activation,
        #     apply_router_weight_on_input=self.apply_router_weight_on_input,
        #     use_nn_moe=self.use_nn_moe,
        #     num_local_tokens=dispatch_recv_num_token,
        #     config_select_bs=hidden_states.shape[0],
        #     scales=dispatch_scales if self.use_int8_dispatch else None
        #     # routed_scaling_factor=self.routed_scaling_factor,
        # )
        # return expert_output

    def forward_deepgemm_contiguous(
        self,
        dispatch_output: DeepEPNormalOutput,
...
python/sglang/srt/layers/moe/token_dispatcher/deepep.py (mode 100644 → 100755)
...
@@ -460,11 +460,32 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
        overlap_args: Optional["CombineOverlapArgs"],
    ):
        if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM or _use_aiter or _is_npu:
            output = hidden_states
        else:
            raise NotImplementedError()
        # triton runner was supported but it's temporarily disabled
        #if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM or _use_aiter or _is_npu:
        output = hidden_states
        # else:
        #     if hidden_states.shape[0] > 0:
        #         num_tokens = self.src2dst.shape[0] // self.router_topk
        #         output = torch.empty(
        #             (num_tokens, hidden_states.shape[1]),
        #             device=hidden_states.device,
        #             dtype=hidden_states.dtype,
        #         )
        #         deepep_post_reorder_triton_kernel[(num_tokens,)](
        #             hidden_states,
        #             output,
        #             self.src2dst,
        #             topk_idx,
        #             topk_weights,
        #             self.router_topk,
        #             hidden_states.shape[1],
        #             BLOCK_SIZE=512,
        #         )
        #     else:
        #         output = torch.zeros(
        #             (0, hidden_states.shape[1]),
        #             device=hidden_states.device,
        #             dtype=hidden_states.dtype,
        #         )
        previous_event = Buffer.capture() if self.async_finish else None
        return output, previous_event
...
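For context on the path this hunk disables, here is a hypothetical pure-PyTorch reference of what the commented-out deepep_post_reorder_triton_kernel call appears to compute, judging from its arguments: each original token gathers its router_topk dispatched rows through src2dst and sums them weighted by topk_weights, skipping empty slots. The token-major src2dst layout and the -1 convention for empty slots are assumptions, not taken from the commit.

import torch

def post_reorder_reference(
    hidden_states: torch.Tensor,  # (num_dispatched_rows, hidden)
    src2dst: torch.Tensor,        # (num_tokens * router_topk,), assumed token-major
    topk_idx: torch.Tensor,       # (num_tokens, router_topk); -1 assumed to mark an empty slot
    topk_weights: torch.Tensor,   # (num_tokens, router_topk)
    router_topk: int,
) -> torch.Tensor:
    num_tokens = src2dst.shape[0] // router_topk
    output = torch.zeros(
        (num_tokens, hidden_states.shape[1]),
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )
    src2dst = src2dst.view(num_tokens, router_topk)
    for t in range(num_tokens):
        for k in range(router_topk):
            if topk_idx[t, k] >= 0:
                # Weighted sum of this token's dispatched copies.
                weight = topk_weights[t, k].to(hidden_states.dtype)
                output[t] += weight * hidden_states[src2dst[t, k]]
    return output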