Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
18da2c96
"vscode:/vscode.git/clone" did not exist on "9731e023255559b1c7537aad60b0d534738add85"
Unverified
Commit
18da2c96
authored
Aug 21, 2025
by
Kaixi Hou
Committed by
GitHub
Aug 21, 2025
Browse files
[NVIDIA] Fix trtllm fp4 moe backend when used in MTP (#9384)
parent
9b5f0f64
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
12 additions
and
3 deletions
+12
-3
python/sglang/srt/layers/moe/ep_moe/layer.py
python/sglang/srt/layers/moe/ep_moe/layer.py
+5
-1
python/sglang/srt/layers/moe/fused_moe_triton/layer.py
python/sglang/srt/layers/moe/fused_moe_triton/layer.py
+2
-0
python/sglang/srt/layers/moe/topk.py
python/sglang/srt/layers/moe/topk.py
+3
-1
python/sglang/srt/models/deepseek_v2.py
python/sglang/srt/models/deepseek_v2.py
+2
-1
No files found.
python/sglang/srt/layers/moe/ep_moe/layer.py
View file @
18da2c96
...
...
@@ -783,13 +783,17 @@ class DeepEPMoE(EPMoE):
return
hidden_states
def
get_moe_impl_class
():
def
get_moe_impl_class
(
quant_config
:
Optional
[
QuantizationConfig
]
=
None
):
if
get_moe_a2a_backend
().
is_deepep
():
return
DeepEPMoE
# NEW: Direct FP4 detection (bypasses EP requirements)
# Check for FP4 quantization with TRTLLM flag, regardless of EP
if
get_moe_runner_backend
().
is_flashinfer_trtllm
():
# FlashInferFP4MoE must be paired with ModelOptNvFp4FusedMoEMethod.
# If UnquantizedFusedMoEMethod is detected, fall back to FusedMoE instead.
if
quant_config
is
None
:
return
FusedMoE
try
:
# Check the quantization argument directly
quantization
=
global_server_args_dict
.
get
(
"quantization"
)
...
...
python/sglang/srt/layers/moe/fused_moe_triton/layer.py
View file @
18da2c96
...
...
@@ -1008,6 +1008,8 @@ class FlashInferFP4MoE(FusedMoE):
hidden_states: Input tensor
topk_output: TopKOutput object with Bypassed format
"""
assert
isinstance
(
self
.
quant_method
,
ModelOptNvFp4FusedMoEMethod
)
assert
TopKOutputChecker
.
format_is_bypassed
(
topk_output
)
router_logits
=
topk_output
.
router_logits
...
...
python/sglang/srt/layers/moe/topk.py
View file @
18da2c96
...
...
@@ -198,6 +198,7 @@ class TopK(CustomOp):
correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
routed_scaling_factor
:
Optional
[
float
]
=
None
,
apply_routed_scaling_factor_on_output
:
Optional
[
bool
]
=
False
,
force_topk
:
bool
=
False
,
):
# NOTE: scoring_func is not used for now, but we keep it for future use
# see https://github.com/sgl-project/sglang/pull/4505 for more details
...
...
@@ -220,6 +221,7 @@ class TopK(CustomOp):
)
self
.
use_triton_kernels
=
get_moe_runner_backend
().
is_triton_kernel
()
self
.
force_topk
=
force_topk
def
forward_native
(
self
,
...
...
@@ -254,7 +256,7 @@ class TopK(CustomOp):
sm_first
=
not
self
.
topk_config
.
renormalize
,
)
return
TritonKernelTopKOutput
(
routing_data
,
gather_idx
,
scatter_idx
)
elif
(
elif
not
self
.
force_topk
and
(
should_use_flashinfer_trtllm_moe
()
or
get_moe_runner_backend
().
is_flashinfer_mxfp4
()
):
...
...
python/sglang/srt/models/deepseek_v2.py
View file @
18da2c96
...
...
@@ -319,7 +319,7 @@ class DeepseekV2MoE(nn.Module):
config
=
config
,
prefix
=
add_prefix
(
"gate"
,
prefix
),
is_nextn
=
is_nextn
)
self
.
experts
=
get_moe_impl_class
()(
self
.
experts
=
get_moe_impl_class
(
quant_config
)(
num_experts
=
config
.
n_routed_experts
+
self
.
num_fused_shared_experts
+
global_server_args_dict
[
"ep_num_redundant_experts"
],
...
...
@@ -343,6 +343,7 @@ class DeepseekV2MoE(nn.Module):
correction_bias
=
self
.
gate
.
e_score_correction_bias
,
routed_scaling_factor
=
self
.
routed_scaling_factor
,
apply_routed_scaling_factor_on_output
=
self
.
experts
.
should_fuse_routed_scaling_factor_in_topk
(),
force_topk
=
quant_config
is
None
,
)
self
.
shared_experts_is_int8
=
False
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment