change / sglang · Commits · a0fb70e9

Commit a0fb70e9
Authored Oct 22, 2025 by lizhigong

    adapt w4a8 marlin deepep dp ep
Parent: 848c5b82
Showing 2 changed files with 98 additions and 51 deletions (+98 / -51):

  python/sglang/srt/layers/moe/ep_moe/layer.py                +72 / -25
  python/sglang/srt/layers/moe/token_dispatcher/deepep.py     +26 / -26
python/sglang/srt/layers/moe/ep_moe/layer.py
@@ -3,6 +3,7 @@ from __future__ import annotations
 import logging
 from typing import TYPE_CHECKING, List, Optional, Union
 
+from sglang.srt.layers.quantization.slimquant_w4a8_marlin import SlimQuantW4A8Int8MarlinConfig
 import torch
 import triton
 import triton.language as tl
@@ -124,7 +125,6 @@ class EPMoE(FusedMoE):
         )
         self.intermediate_size = intermediate_size
 
         if isinstance(quant_config, Fp8Config):
             self.use_block_quant = getattr(self.quant_method, "block_quant", False)
             self.block_shape = (
@@ -135,11 +135,23 @@ class EPMoE(FusedMoE):
             self.use_fp8_w8a8 = True
             self.fp8_dtype = torch.float8_e4m3fn
             self.activation_scheme = quant_config.activation_scheme
+            self.use_w4a8_marlin = False
+        elif isinstance(quant_config, SlimQuantW4A8Int8MarlinConfig):
+            self.use_block_quant = getattr(self.quant_method, "block_quant", False)
+            self.block_shape = (
+                self.quant_method.quant_config.weight_block_size
+                if self.use_block_quant
+                else None
+            )
+            self.use_fp8_w8a8 = False
+            self.activation_scheme = None
+            self.use_w4a8_marlin = True
         else:
             self.use_fp8_w8a8 = False
             self.use_block_quant = False
             self.block_shape = None
             self.activation_scheme = None
+            self.use_w4a8_marlin = False
 
     def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
         if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8:
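Reviewer note: the branches above only record capability flags (use_fp8_w8a8, use_w4a8_marlin, block_shape); the kernel choice itself happens later in the forward path. A minimal sketch of that kind of flag-driven backend selection, using illustrative stand-in names rather than sglang's actual API:

from dataclasses import dataclass
from typing import Optional, Tuple

@dataclass
class QuantFlags:
    # Illustrative stand-ins for the attributes set in __init__ above.
    use_fp8_w8a8: bool = False
    use_w4a8_marlin: bool = False
    use_block_quant: bool = False
    block_shape: Optional[Tuple[int, int]] = None

def pick_expert_gemm_backend(flags: QuantFlags, jit_deepgemm: bool) -> str:
    # FP8 weights/activations prefer the JIT DeepGEMM path when it is available.
    if flags.use_fp8_w8a8 and jit_deepgemm:
        return "deepgemm_fp8"
    # INT4-weight / INT8-activation checkpoints go through the Marlin kernels.
    if flags.use_w4a8_marlin:
        return "w4a8_marlin"
    # Everything else falls back to the default (unquantized / triton) path.
    return "default"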
@@ -386,11 +398,11 @@ class DeepEPMoE(EPMoE):
             return_recv_hook=True,
         )
-        if self.deepep_mode.enable_low_latency() and not _is_npu:
-            # NPU supports low_latency deepep without deepgemm
-            assert (
-                deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
-            ), f"DeepEP {self.deepep_mode} mode requires deep_gemm"
+        # if self.deepep_mode.enable_low_latency() and not _is_npu:
+        #     # NPU supports low_latency deepep without deepgemm
+        #     assert (
+        #         deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
+        #     ), f"DeepEP {self.deepep_mode} mode requires deep_gemm"
         if _use_aiter:
             # expert_mask is of size (self.num_local_experts + 1),
             # the extra 1 is for invalid rank_id (in original deepep, the invalid rank_id is -1, but aiter does not allow -1, we use a mask to make those ids invalid)
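The blanket assert is commented out here, presumably because low-latency DeepEP no longer implies deep_gemm once the w4a8 marlin path exists. A hedged sketch of a narrower guard that would keep the requirement only for the FP8 path (an assumption, not part of this commit):

def check_low_latency_requirements(
    low_latency: bool, is_npu: bool, use_fp8_w8a8: bool, jit_deepgemm: bool
) -> None:
    # NPU supports low-latency DeepEP without deepgemm, and the w4a8 marlin
    # path does not need it either, so only the FP8 path keeps the hard check.
    if low_latency and not is_npu and use_fp8_w8a8:
        assert jit_deepgemm, "low-latency DeepEP with FP8 requires deep_gemm"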
@@ -404,23 +416,23 @@ class DeepEPMoE(EPMoE):
             )
             # the last one is invalid rank_id
             self.expert_mask[:-1] = 1
-        elif not _is_npu:
-            self.w13_weight_fp8 = (
-                self.w13_weight,
-                (
-                    self.w13_weight_scale_inv
-                    if self.use_block_quant
-                    else self.w13_weight_scale
-                ),
-            )
-            self.w2_weight_fp8 = (
-                self.w2_weight,
-                (
-                    self.w2_weight_scale_inv
-                    if self.use_block_quant
-                    else self.w2_weight_scale
-                ),
-            )
+        # elif not _is_npu:
+        #     self.w13_weight_fp8 = (
+        #         self.w13_weight,
+        #         (
+        #             self.w13_weight_scale_inv
+        #             if self.use_block_quant
+        #             else self.w13_weight_scale
+        #         ),
+        #     )
+        #     self.w2_weight_fp8 = (
+        #         self.w2_weight,
+        #         (
+        #             self.w2_weight_scale_inv
+        #             if self.use_block_quant
+        #             else self.w2_weight_scale
+        #         ),
+        #     )
 
     def forward(
         self,
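For reference, the expert_mask workaround described in the earlier comment (original DeepEP marks invalid rank_id with -1, aiter cannot accept -1) looks roughly like this as a standalone snippet with made-up sizes; it is not the sglang code:

import torch

num_local_experts = 4
# One extra slot: index `num_local_experts` acts as the "invalid" expert.
expert_mask = torch.zeros(num_local_experts + 1, dtype=torch.int32)
expert_mask[:-1] = 1  # real experts are valid, the last slot is not

topk_ids = torch.tensor([0, 3, -1, 2])  # -1 = invalid rank_id from deepep
topk_ids = torch.where(
    topk_ids < 0, torch.full_like(topk_ids, num_local_experts), topk_ids
)
# topk_ids now contains only valid indices; expert_mask[topk_ids] flags which
# routed entries should actually contribute.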
@@ -466,8 +478,15 @@ class DeepEPMoE(EPMoE):
             assert DispatchOutputChecker.format_is_deepep(dispatch_output)
             return self.forward_npu(dispatch_output)
         if DispatchOutputChecker.format_is_deepep_normal(dispatch_output):
-            assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
-            return self.forward_deepgemm_contiguous(dispatch_output)
+            #assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
+            if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8:
+                return self.forward_deepgemm_contiguous(dispatch_output)
+            elif self.use_w4a8_marlin:
+                return self.forward_deepgemm_w4a8_marlin_contiguous(dispatch_output)
+            else:
+                raise ValueError(f"Dispatch output is not supported")
         elif DispatchOutputChecker.format_is_deepep_ll(dispatch_output):
             if get_moe_runner_backend().is_flashinfer_cutedsl():
                 return self.forward_flashinfer_cutedsl(dispatch_output)
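With this hunk, normal-format dispatch output no longer hard-requires deep_gemm: the FP8 + DeepGEMM path keeps priority, w4a8 marlin becomes the fallback, and anything else raises. The same decision, restated as a small standalone function for clarity (illustrative only, not part of the commit):

def route_normal_dispatch(
    jit_deepgemm: bool, use_fp8_w8a8: bool, use_w4a8_marlin: bool
) -> str:
    # Mirrors the if/elif/else added above, returning the handler name.
    if jit_deepgemm and use_fp8_w8a8:
        return "forward_deepgemm_contiguous"
    if use_w4a8_marlin:
        return "forward_deepgemm_w4a8_marlin_contiguous"
    raise ValueError("Dispatch output is not supported")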
@@ -526,6 +545,34 @@ class DeepEPMoE(EPMoE):
             expert_mask=self.expert_mask,
         )
 
+    def forward_deepgemm_w4a8_marlin_contiguous(
+        self,
+        dispatch_output: DeepEPNormalOutput,
+    ):
+        hidden_states_int8, topk_idx, topk_weights, num_recv_tokens_per_expert = (
+            dispatch_output
+        )
+        assert self.quant_method is not None
+        assert self.moe_runner_config.activation == "silu"
+        # if num_recv_tokens_per_expert is None:
+        return hidden_states_int8.bfloat16()
+        # expert_output = self.quant_method.apply_ep(
+        #     layer=self,
+        #     x=dispatch_output,
+        #     topk_weights=topk_weights,
+        #     topk_ids=topk_idx,
+        #     global_num_experts=self.global_num_experts,
+        #     expert_map=self.expert_map,
+        #     activation=self.activation,
+        #     apply_router_weight_on_input=self.apply_router_weight_on_input,
+        #     use_nn_moe=self.use_nn_moe,
+        #     num_local_tokens=dispatch_recv_num_token,
+        #     config_select_bs=hidden_states.shape[0],
+        #     scales=dispatch_scales if self.use_int8_dispatch else None
+        #     # routed_scaling_factor=self.routed_scaling_factor,
+        # )
+        # return expert_output
+
     def forward_deepgemm_contiguous(
         self,
         dispatch_output: DeepEPNormalOutput,
python/sglang/srt/layers/moe/token_dispatcher/deepep.py
@@ -431,32 +431,32 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
             deepep_post_reorder_triton_kernel,
         )
 
-        if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM or _use_aiter or _is_npu:
-            output = hidden_states
-        else:
-            if hidden_states.shape[0] > 0:
-                num_tokens = self.src2dst.shape[0] // self.router_topk
-                output = torch.empty(
-                    (num_tokens, hidden_states.shape[1]),
-                    device=hidden_states.device,
-                    dtype=hidden_states.dtype,
-                )
-                deepep_post_reorder_triton_kernel[(num_tokens,)](
-                    hidden_states,
-                    output,
-                    self.src2dst,
-                    topk_idx,
-                    topk_weights,
-                    self.router_topk,
-                    hidden_states.shape[1],
-                    BLOCK_SIZE=512,
-                )
-            else:
-                output = torch.zeros(
-                    (0, hidden_states.shape[1]),
-                    device=hidden_states.device,
-                    dtype=hidden_states.dtype,
-                )
+        # if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM or _use_aiter or _is_npu:
+        output = hidden_states
+        # else:
+        #     if hidden_states.shape[0] > 0:
+        #         num_tokens = self.src2dst.shape[0] // self.router_topk
+        #         output = torch.empty(
+        #             (num_tokens, hidden_states.shape[1]),
+        #             device=hidden_states.device,
+        #             dtype=hidden_states.dtype,
+        #         )
+        #         deepep_post_reorder_triton_kernel[(num_tokens,)](
+        #             hidden_states,
+        #             output,
+        #             self.src2dst,
+        #             topk_idx,
+        #             topk_weights,
+        #             self.router_topk,
+        #             hidden_states.shape[1],
+        #             BLOCK_SIZE=512,
+        #         )
+        #     else:
+        #         output = torch.zeros(
+        #             (0, hidden_states.shape[1]),
+        #             device=hidden_states.device,
+        #             dtype=hidden_states.dtype,
+        #         )
         previous_event = Buffer.capture() if self.async_finish else None
         return output, previous_event
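Note that this hunk makes `output = hidden_states` unconditional for the normal dispatcher, so the post-reorder kernel call is now commented out rather than gated. For context, a rough, untested PyTorch restatement of what deepep_post_reorder_triton_kernel computes, inferred from the call signature above and the usual MoE post-reorder pattern (an assumption, not the kernel source):

import torch

def post_reorder_reference(hidden_states, src2dst, topk_idx, topk_weights, router_topk):
    # Each output token accumulates its router_topk expert outputs, located via
    # the src2dst mapping, scaled by the router weights; entries with
    # topk_idx < 0 (invalid routes) are skipped.
    topk_idx = topk_idx.reshape(-1)
    topk_weights = topk_weights.reshape(-1)
    num_tokens = src2dst.shape[0] // router_topk
    hidden = hidden_states.shape[1]
    out = torch.zeros(num_tokens, hidden, device=hidden_states.device, dtype=torch.float32)
    for i in range(num_tokens):
        for k in range(router_topk):
            flat = i * router_topk + k
            if topk_idx[flat] >= 0:
                out[i] += topk_weights[flat].float() * hidden_states[src2dst[flat]].float()
    return out.to(hidden_states.dtype)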