sglang / Commits

Commit 6d0646da (unverified)
Authored Aug 04, 2025 by Kaixi Hou; committed by GitHub on Aug 04, 2025
[NVIDIA] Fix breakage of using trtllm-gen fp8 moe (#8773)
Parent: 02bc1c7d

Showing 2 changed files with 18 additions and 63 deletions:

python/sglang/srt/layers/moe/ep_moe/layer.py (+4, -62)
python/sglang/srt/layers/moe/fused_moe_triton/layer.py (+14, -1)
python/sglang/srt/layers/moe/ep_moe/layer.py (view file @ 6d0646da)

@@ -673,66 +673,6 @@ class DeepEPMoE(EPMoE):
         return down_output
 
 
-class FlashInferEPMoE(EPMoE):
-    def __init__(self, *args, **kwargs):
-        renormalize = kwargs.pop("renormalize", True)
-        num_fused_shared_experts = kwargs.pop("num_fused_shared_experts", 0)
-        use_grouped_topk = kwargs.pop("use_grouped_topk", False)
-        num_expert_group = kwargs.pop("num_expert_group", None)
-        topk_group = kwargs.pop("topk_group", None)
-        correction_bias = kwargs.pop("correction_bias", None)
-        super().__init__(*args, **kwargs)
-        self.renormalize = renormalize
-        self.num_fused_shared_experts = num_fused_shared_experts
-        self.use_grouped_topk = use_grouped_topk
-        if self.use_grouped_topk:
-            assert num_expert_group is not None and topk_group is not None
-        self.num_expert_group = num_expert_group
-        self.topk_group = topk_group
-        self.correction_bias = correction_bias
-        self.use_flashinfer_trtllm_moe = should_use_flashinfer_trtllm_moe()
-
-    def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
-        assert self.use_flashinfer_trtllm_moe
-        assert (
-            self.activation == "silu"
-        ), "Only silu is supported for flashinfer blockscale fp8 moe"
-        assert (
-            self.renormalize
-        ), "Renormalize is required for flashinfer blockscale fp8 moe"
-        assert (
-            self.num_fused_shared_experts == 0
-        ), "Fused shared experts are not supported for flashinfer blockscale fp8 moe"
-        a_q, a_sf = sglang_per_token_group_quant_fp8(hidden_states, self.block_shape[1])
-        # NOTE: scales of hidden states have to be transposed!
-        a_sf_t = a_sf.t().contiguous()
-        from flashinfer.fused_moe import trtllm_fp8_block_scale_moe
-
-        return trtllm_fp8_block_scale_moe(
-            routing_logits=router_logits.to(torch.float32),
-            routing_bias=self.correction_bias.to(hidden_states.dtype),
-            hidden_states=a_q,
-            hidden_states_scale=a_sf_t,
-            gemm1_weights=self.w13_weight,
-            gemm1_weights_scale=self.w13_weight_scale_inv,
-            gemm2_weights=self.w2_weight,
-            gemm2_weights_scale=self.w2_weight_scale_inv,
-            num_experts=self.num_experts,
-            top_k=self.top_k,
-            n_group=self.num_expert_group,
-            topk_group=self.topk_group,
-            intermediate_size=self.w2_weight.shape[2],
-            local_expert_offset=self.start_expert_id,
-            local_num_experts=self.num_local_experts,
-            routed_scaling_factor=self.routed_scaling_factor,
-            tile_tokens_dim=get_tile_tokens_dim(
-                hidden_states.shape[0], self.top_k, self.num_experts
-            ),
-            routing_method_type=2,  # DeepSeek-styled routing method
-            use_shuffled_weight=False,
-        )
-
-
 def get_moe_impl_class():
     if global_server_args_dict["moe_a2a_backend"].is_deepep():
         return DeepEPMoE
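Aside (not part of the commit): the NOTE about transposing the hidden-state scales in the deleted FlashInferEPMoE.forward above can be illustrated with a small, self-contained sketch. The helper below is a plain-PyTorch stand-in, not sglang's sglang_per_token_group_quant_fp8; the group size of 128 and the [num_tokens, num_groups] scale layout are assumptions made for this example only.

    # Illustrative stand-in for per-token-group FP8 quantization, showing the
    # scale layout and the transpose performed before the kernel call above.
    import torch

    def per_token_group_quant_fp8_sketch(x: torch.Tensor, group_size: int = 128):
        num_tokens, hidden_size = x.shape
        assert hidden_size % group_size == 0  # assumed for this sketch
        num_groups = hidden_size // group_size
        x_grouped = x.float().view(num_tokens, num_groups, group_size)
        # One scale per (token, group): scale each group so its max hits the fp8 e4m3 range.
        fp8_max = torch.finfo(torch.float8_e4m3fn).max
        scales = x_grouped.abs().amax(dim=-1).clamp(min=1e-12) / fp8_max  # [num_tokens, num_groups]
        x_q = (x_grouped / scales.unsqueeze(-1)).to(torch.float8_e4m3fn)
        return x_q.view(num_tokens, hidden_size), scales

    hidden_states = torch.randn(4, 256)
    a_q, a_sf = per_token_group_quant_fp8_sketch(hidden_states, group_size=128)
    print(a_sf.shape)               # torch.Size([4, 2])  -- [num_tokens, num_groups]
    a_sf_t = a_sf.t().contiguous()  # [num_groups, num_tokens], per the NOTE above
    print(a_sf_t.shape)             # torch.Size([2, 4])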
@@ -752,8 +692,10 @@ def get_moe_impl_class():
     except:
         pass
 
+    if should_use_flashinfer_trtllm_moe():
+        return FlashInferFusedMoE
     if global_server_args_dict["enable_flashinfer_cutlass_moe"]:
         return FusedMoE
     if get_moe_expert_parallel_world_size() > 1:
-        return FlashInferEPMoE if should_use_flashinfer_trtllm_moe() else EPMoE
-    return FlashInferFusedMoE if should_use_flashinfer_trtllm_moe() else FusedMoE
+        return EPMoE
+    return FusedMoE
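The hunk above reworks the backend selection in get_moe_impl_class. Below is a condensed, self-contained sketch of that style of dispatch; the boolean flags and placeholder classes stand in for the real global_server_args_dict lookups and sglang layer classes, and the ordering is illustrative rather than a restatement of the exact post-commit function.

    # Condensed sketch of the MoE-implementation dispatch pattern.
    # The classes and flags are placeholders, not sglang's real objects.
    class FusedMoE: ...
    class EPMoE: ...
    class FlashInferFusedMoE(FusedMoE): ...
    class DeepEPMoE(EPMoE): ...

    def pick_moe_impl(use_deepep: bool, use_flashinfer_trtllm: bool, ep_size: int):
        if use_deepep:                 # a2a backend is DeepEP
            return DeepEPMoE
        if use_flashinfer_trtllm:      # trtllm-gen fp8 moe path
            return FlashInferFusedMoE
        if ep_size > 1:                # plain expert parallelism
            return EPMoE
        return FusedMoE                # default Triton fused MoE

    print(pick_moe_impl(use_deepep=False, use_flashinfer_trtllm=True, ep_size=1).__name__)
    # -> FlashInferFusedMoE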
python/sglang/srt/layers/moe/fused_moe_triton/layer.py (view file @ 6d0646da)

@@ -763,8 +763,13 @@ class FlashInferFusedMoE(FusedMoE):
         self.num_expert_group = num_expert_group
         self.topk_group = topk_group
         self.correction_bias = correction_bias
         self.use_flashinfer_trtllm_moe = should_use_flashinfer_trtllm_moe()
 
-    def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
+    def forward(self, hidden_states: torch.Tensor, topk_output: tuple):
+        assert self.use_flashinfer_trtllm_moe
+        assert (
+            self.activation == "silu"
+        ), "Only silu is supported for flashinfer blockscale fp8 moe"
+        assert self.quant_method is not None
         assert (
             self.renormalize

@@ -772,6 +777,14 @@ class FlashInferFusedMoE(FusedMoE):
         assert (
             self.num_fused_shared_experts == 0
         ), "Fused shared experts are not supported for flashinfer blockscale fp8 moe"
+
+        # TRTLLM mode expects (TopK_config, router_logits) tuple
+        if not isinstance(topk_output, tuple) or len(topk_output) != 2:
+            raise ValueError(
+                f"FlashInferFusedMoE expects (TopK_config, router_logits) tuple, got {type(topk_output)}"
+            )
+        _, router_logits = topk_output
+
         # Matrix multiply.
         final_hidden_states = self.quant_method.apply_with_router_logits(
             layer=self,
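The second hunk changes FlashInferFusedMoE.forward to accept a (TopK_config, router_logits) tuple instead of a bare router_logits tensor and to validate that argument before unpacking it. The snippet below isolates just that check as a runnable sketch; TopKConfig is a throwaway placeholder for whatever config object the caller actually passes, not an sglang class.

    # Stand-alone sketch of the tuple check/unpack added to FlashInferFusedMoE.forward.
    import torch

    class TopKConfig:  # placeholder type for this example only
        pass

    def unpack_topk_output(topk_output):
        # TRTLLM mode expects a (TopK_config, router_logits) tuple.
        if not isinstance(topk_output, tuple) or len(topk_output) != 2:
            raise ValueError(
                f"FlashInferFusedMoE expects (TopK_config, router_logits) tuple, got {type(topk_output)}"
            )
        _, router_logits = topk_output
        return router_logits

    router_logits = unpack_topk_output((TopKConfig(), torch.randn(4, 8)))
    print(router_logits.shape)  # torch.Size([4, 8])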