change / sglang · Commits

Unverified commit 3c2c9f6c, authored Aug 18, 2025 by Jiaqi Gu, committed by GitHub on Aug 18, 2025

[Bug] Fix input arguments of flashinfer_trtllm_moe (#9317)

Parent: a31ea448
Showing 3 changed files with 30 additions and 21 deletions:

    python/sglang/srt/layers/moe/fused_moe_triton/layer.py  (+2, -2)
    python/sglang/srt/layers/moe/topk.py                    (+14, -14)
    python/sglang/srt/layers/quantization/fp8.py            (+14, -5)
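In summary: the flashinfer TRT-LLM MoE path previously read its routing arguments from attributes on the layer object; after this commit they come from the config objects that actually carry them (`moe_runner_config` for the activation, the `TopKConfig` attached to the top-k output for `renormalize`, `top_k`, `correction_bias`, and the grouped-routing parameters), and the grouped top-k signatures switch their `topk_group`/`num_expert_group` defaults from `0` to `Optional[int] = None`.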
python/sglang/srt/layers/moe/fused_moe_triton/layer.py

@@ -932,11 +932,11 @@ class FlashInferFusedMoE(FusedMoE):
     def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
         assert self.use_flashinfer_trtllm_moe
         assert (
-            self.activation == "silu"
+            self.moe_runner_config.activation == "silu"
         ), "Only silu is supported for flashinfer blockscale fp8 moe"
         assert self.quant_method is not None
         assert (
-            self.renormalize
+            topk_output.topk_config.renormalize
         ), "Renormalize is required for flashinfer blockscale fp8 moe"
         assert (
             self.num_fused_shared_experts == 0
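Both assertion fixes redirect where a value is read from rather than what is checked: the activation now lives on `moe_runner_config`, and the renormalize flag on the `TopKConfig` carried by the top-k output. A minimal sketch of the pattern, using simplified stand-ins for the real sglang classes (these dataclasses are illustrative, not the actual definitions):

    from dataclasses import dataclass

    @dataclass
    class MoeRunnerConfig:
        activation: str = "silu"

    @dataclass
    class TopKConfig:
        renormalize: bool = True

    @dataclass
    class TopKOutput:
        topk_config: TopKConfig

    class FlashInferFusedMoE:
        def __init__(self, moe_runner_config: MoeRunnerConfig):
            # After the fix, the layer no longer caches `activation` or
            # `renormalize` on itself; it keeps only the config object.
            self.moe_runner_config = moe_runner_config

        def forward_checks(self, topk_output: TopKOutput) -> None:
            # Read the activation from the runner config ...
            assert (
                self.moe_runner_config.activation == "silu"
            ), "Only silu is supported for flashinfer blockscale fp8 moe"
            # ... and renormalize from the top-k config on the output.
            assert (
                topk_output.topk_config.renormalize
            ), "Renormalize is required for flashinfer blockscale fp8 moe"

    layer = FlashInferFusedMoE(MoeRunnerConfig())
    layer.forward_checks(TopKOutput(TopKConfig()))  # passes with the defaults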
python/sglang/srt/layers/moe/topk.py

@@ -85,8 +85,8 @@ if _is_npu:
 class TopKConfig:
     top_k: int
     use_grouped_topk: bool = False
-    topk_group: int = 0
-    num_expert_group: int = 0
+    topk_group: Optional[int] = None
+    num_expert_group: Optional[int] = None
     renormalize: bool = True
     num_fused_shared_experts: int = 0
     custom_routing_function: Optional[Callable] = None
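Switching these defaults from `0` to `Optional[int] = None` lets downstream code distinguish "grouped-routing parameters never supplied" from an explicit value, which the new assertion in fp8.py below depends on. A minimal sketch with a hypothetical, simplified stand-in for the config:

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class GroupedTopKArgs:
        # Hypothetical stand-in for TopKConfig: None now means the grouped
        # routing parameters were never supplied, whereas the old default
        # of 0 was indistinguishable from an explicit 0.
        topk_group: Optional[int] = None
        num_expert_group: Optional[int] = None

    def wants_grouped_routing(args: GroupedTopKArgs) -> bool:
        # Callers can now test for presence explicitly.
        return args.topk_group is not None and args.num_expert_group is not None

    assert not wants_grouped_routing(GroupedTopKArgs())
    assert wants_grouped_routing(GroupedTopKArgs(topk_group=4, num_expert_group=8))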
@@ -189,8 +189,8 @@ class TopK(CustomOp):
         top_k: int,
         *,
         use_grouped_topk: bool = False,
-        topk_group: int = 0,
-        num_expert_group: int = 0,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
         renormalize: bool = True,
         num_fused_shared_experts: int = 0,
         custom_routing_function: Optional[Callable] = None,
@@ -427,8 +427,8 @@ def grouped_topk_gpu(
     gating_output: torch.Tensor,
     topk: int,
     renormalize: bool,
-    num_expert_group: int = 0,
-    topk_group: int = 0,
+    num_expert_group: Optional[int] = None,
+    topk_group: Optional[int] = None,
     num_fused_shared_experts: int = 0,
     routed_scaling_factor: Optional[float] = None,
     num_token_non_padded: Optional[torch.Tensor] = None,
@@ -492,8 +492,8 @@ def grouped_topk_cpu(
     gating_output: torch.Tensor,
     topk: int,
     renormalize: bool,
-    num_expert_group: int = 0,
-    topk_group: int = 0,
+    num_expert_group: Optional[int] = None,
+    topk_group: Optional[int] = None,
     num_fused_shared_experts: int = 0,
     routed_scaling_factor: Optional[float] = None,
     num_token_non_padded: Optional[torch.Tensor] = None,
@@ -522,8 +522,8 @@ def biased_grouped_topk_impl(
     correction_bias: torch.Tensor,
     topk: int,
     renormalize: bool,
-    num_expert_group: int = 0,
-    topk_group: int = 0,
+    num_expert_group: Optional[int] = None,
+    topk_group: Optional[int] = None,
     num_fused_shared_experts: int = 0,
     routed_scaling_factor: Optional[float] = None,
     num_token_non_padded: Optional[torch.Tensor] = None,
@@ -615,8 +615,8 @@ def biased_grouped_topk_gpu(
     correction_bias: torch.Tensor,
     topk: int,
     renormalize: bool,
-    num_expert_group: int = 0,
-    topk_group: int = 0,
+    num_expert_group: Optional[int] = None,
+    topk_group: Optional[int] = None,
     num_fused_shared_experts: int = 0,
     routed_scaling_factor: Optional[float] = None,
     num_token_non_padded: Optional[torch.Tensor] = None,
@@ -690,8 +690,8 @@ def biased_grouped_topk_cpu(
     correction_bias: torch.Tensor,
     topk: int,
     renormalize: bool,
-    num_expert_group: int = 0,
-    topk_group: int = 0,
+    num_expert_group: Optional[int] = None,
+    topk_group: Optional[int] = None,
     compiled: bool = True,
     num_fused_shared_experts: int = 0,
     routed_scaling_factor: Optional[float] = None,
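All of the signatures above change in lockstep: the `TopKConfig` dataclass, the `TopK` constructor, and the grouped/biased top-k functions now share the same `Optional[int] = None` defaults for `topk_group` and `num_expert_group`, so a config can be forwarded to any of these entry points without translating a `0` sentinel back and forth.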
python/sglang/srt/layers/quantization/fp8.py

@@ -445,7 +445,6 @@ class Fp8LinearMethod(LinearMethodBase):
         x: torch.Tensor,
         bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-
        if self.use_marlin:
            return apply_fp8_marlin_linear(
                input=x,
@@ -1087,7 +1086,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         topk_output: TopKOutput,
         moe_runner_config: MoeRunnerConfig,
     ) -> torch.Tensor:
-
         activation = moe_runner_config.activation
         routed_scaling_factor = moe_runner_config.routed_scaling_factor
@@ -1105,9 +1103,18 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             # NOTE: scales of hidden states have to be transposed!
             a_sf_t = a_sf.t().contiguous()
+            assert (
+                topk_config.num_expert_group is not None
+                and topk_config.topk_group is not None
+            ), "Current trtllm_fp8_block_scale_moe kernel does not support these two arguments as None"
+            if topk_config.correction_bias is not None:
+                correction_bias = topk_config.correction_bias.to(x.dtype)
+            else:
+                correction_bias = None
             return trtllm_fp8_block_scale_moe(
                 routing_logits=router_logits.to(torch.float32),
-                routing_bias=layer.correction_bias.to(x.dtype),
+                routing_bias=correction_bias,
                 hidden_states=a_q,
                 hidden_states_scale=a_sf_t,
                 gemm1_weights=layer.w13_weight,
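The new guard exists because a model may have no correction bias at all: the old code unconditionally evaluated `layer.correction_bias.to(x.dtype)`, which fails when the bias is absent, and it also read the bias from the layer rather than from the `TopKConfig` that now owns it. A minimal sketch of the guard pattern (the helper name `resolve_routing_bias` is hypothetical, not part of sglang):

    from typing import Optional
    import torch

    def resolve_routing_bias(
        correction_bias: Optional[torch.Tensor], x: torch.Tensor
    ) -> Optional[torch.Tensor]:
        # Only convert the bias dtype when a bias tensor was actually
        # provided; passing None through signals "no routing bias".
        if correction_bias is not None:
            return correction_bias.to(x.dtype)
        return None

    x = torch.randn(4, 8, dtype=torch.bfloat16)
    assert resolve_routing_bias(None, x) is None
    assert resolve_routing_bias(torch.zeros(8), x).dtype == torch.bfloat16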
@@ -1121,9 +1128,11 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                 intermediate_size=layer.w2_weight.shape[2],
                 local_expert_offset=layer.moe_ep_rank * layer.num_local_experts,
                 local_num_experts=layer.num_local_experts,
-                routed_scaling_factor=routed_scaling_factor,
+                routed_scaling_factor=(
+                    routed_scaling_factor if routed_scaling_factor is not None else 1.0
+                ),
                 tile_tokens_dim=get_tile_tokens_dim(
-                    x.shape[0], layer.top_k, layer.num_experts
+                    x.shape[0], topk_config.top_k, layer.num_experts
                 ),
                 routing_method_type=2,  # DeepSeek-styled routing method
                 use_shuffled_weight=False,
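This final hunk fixes two arguments: `routed_scaling_factor` now falls back to `1.0` (no scaling) when the runner config leaves it `None`, and `get_tile_tokens_dim` reads `top_k` from `topk_config` instead of from the layer. A minimal sketch of the fallback, with a hypothetical helper name:

    from typing import Optional

    def effective_scaling_factor(routed_scaling_factor: Optional[float]) -> float:
        # None means "unset"; 1.0 leaves routed expert outputs unscaled.
        return routed_scaling_factor if routed_scaling_factor is not None else 1.0

    assert effective_scaling_factor(None) == 1.0
    assert effective_scaling_factor(2.5) == 2.5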