zhaoyu6 / sglang · Commits · a91e90d9

[2/2] Fuse routed scaling factor into select_experts (#8690)

Unverified commit a91e90d9, authored Aug 20, 2025 by Trevor Morris, committed by GitHub on Aug 20, 2025.
Parent: f96413c4
Showing 6 changed files with 55 additions and 25 deletions (+55 -25):

  python/sglang/srt/layers/moe/fused_moe_triton/layer.py   +7  -0
  python/sglang/srt/layers/moe/topk.py                     +20 -0
  python/sglang/srt/layers/quantization/fp8.py             +8  -9
  python/sglang/srt/layers/quantization/modelopt_quant.py  +2  -4
  python/sglang/srt/models/deepseek_v2.py                  +12 -11
  sgl-kernel/tests/test_moe_fused_gate.py                  +6  -1
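This commit moves the routed scaling factor out of the per-backend MoE output path and folds it into the top-k routing weights produced by select_experts. Because expert outputs are combined as a weighted sum, scaling the weights before the combine gives the same result as scaling the combined output afterwards. A minimal sketch of that equivalence in plain PyTorch (illustrative shapes and names only, not the fused kernels touched by this commit):

import torch

def combine(expert_outputs: torch.Tensor, topk_weights: torch.Tensor) -> torch.Tensor:
    # expert_outputs: [num_tokens, topk, hidden]; topk_weights: [num_tokens, topk]
    return (expert_outputs * topk_weights.unsqueeze(-1)).sum(dim=1)

routed_scaling_factor = 2.5
expert_outputs = torch.randn(4, 8, 16)
topk_weights = torch.softmax(torch.randn(4, 8), dim=-1)

# Old path: scale the combined MoE output.
out_scaled_after = combine(expert_outputs, topk_weights) * routed_scaling_factor
# New path: scale the top-k weights during routing (what select_experts now fuses).
out_fused = combine(expert_outputs, topk_weights * routed_scaling_factor)

assert torch.allclose(out_scaled_after, out_fused, atol=1e-6)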
python/sglang/srt/layers/moe/fused_moe_triton/layer.py

@@ -28,6 +28,7 @@ from sglang.srt.layers.quantization.base_config import (
    QuantizationConfig,
    QuantizeMethodBase,
)
from sglang.srt.layers.quantization.fp8 import Fp8MoEMethod
from sglang.srt.layers.quantization.modelopt_quant import ModelOptNvFp4FusedMoEMethod
from sglang.srt.layers.quantization.unquant import UnquantizedFusedMoEMethod
from sglang.srt.managers.schedule_batch import global_server_args_dict

@@ -923,6 +924,12 @@ class FusedMoE(torch.nn.Module):
            for shard_id in ["w1", "w2", "w3"]
        ]

+    def should_fuse_routed_scaling_factor_in_topk(self):
+        return isinstance(self.quant_method, ModelOptNvFp4FusedMoEMethod) or (
+            isinstance(self.quant_method, Fp8MoEMethod)
+            and self.quant_method.use_cutlass_fused_experts_fp8
+        )
+

class FlashInferFusedMoE(FusedMoE):
    def __init__(self, *args, **kwargs):
python/sglang/srt/layers/moe/topk.py

@@ -197,6 +197,7 @@ class TopK(CustomOp):
        scoring_func: str = "softmax",
        correction_bias: Optional[torch.Tensor] = None,
        routed_scaling_factor: Optional[float] = None,
+        apply_routed_scaling_factor_on_output: Optional[bool] = False,
    ):
        # NOTE: scoring_func is not used for now, but we keep it for future use
        # see https://github.com/sgl-project/sglang/pull/4505 for more details

@@ -215,6 +216,7 @@ class TopK(CustomOp):
            custom_routing_function=custom_routing_function,
            correction_bias=correction_bias,
            routed_scaling_factor=routed_scaling_factor,
+            apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output,
        )
        self.use_triton_kernels = get_moe_runner_backend().is_triton_kernel()

@@ -433,6 +435,7 @@ def grouped_topk_gpu(
    routed_scaling_factor: Optional[float] = None,
    num_token_non_padded: Optional[torch.Tensor] = None,
    expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
+    apply_routed_scaling_factor_on_output: Optional[bool] = False,
):
    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"

@@ -480,6 +483,8 @@ def grouped_topk_gpu(
        else topk_weights[:, :-1].sum(dim=-1, keepdim=True)
    )
    topk_weights = topk_weights / topk_weights_sum
+    if apply_routed_scaling_factor_on_output:
+        topk_weights *= routed_scaling_factor
    topk_weights, topk_ids = topk_weights.to(torch.float32), topk_ids.to(torch.int32)
    topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info)

@@ -528,6 +533,7 @@ def biased_grouped_topk_impl(
    routed_scaling_factor: Optional[float] = None,
    num_token_non_padded: Optional[torch.Tensor] = None,
    expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
+    apply_routed_scaling_factor_on_output: Optional[bool] = False,
):
    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"

@@ -579,6 +585,8 @@ def biased_grouped_topk_impl(
        else topk_weights[:, :-1].sum(dim=-1, keepdim=True)
    )
    topk_weights = topk_weights / topk_weights_sum
+    if apply_routed_scaling_factor_on_output:
+        topk_weights *= routed_scaling_factor
    topk_weights, topk_ids = topk_weights.to(torch.float32), topk_ids.to(torch.int32)
    topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info)

@@ -621,6 +629,7 @@ def biased_grouped_topk_gpu(
    routed_scaling_factor: Optional[float] = None,
    num_token_non_padded: Optional[torch.Tensor] = None,
    expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
+    apply_routed_scaling_factor_on_output: Optional[bool] = False,
):
    assert (
        routed_scaling_factor is not None

@@ -640,6 +649,7 @@ def biased_grouped_topk_gpu(
            topk,
            num_fused_shared_experts,
            routed_scaling_factor,
+            apply_routed_scaling_factor_on_output,
        )
        # TODO merge into kernel
        if (expert_location_dispatch_info is not None) or (

@@ -650,6 +660,7 @@ def biased_grouped_topk_gpu(
        )
        return topk_weights, topk_ids
    elif _use_aiter:
+        assert not apply_routed_scaling_factor_on_output, "Not implemented"
        token = gating_output.shape[0]
        device = gating_output.device
        assert (

@@ -681,6 +692,7 @@ def biased_grouped_topk_gpu(
            routed_scaling_factor=routed_scaling_factor,
            num_token_non_padded=num_token_non_padded,
            expert_location_dispatch_info=expert_location_dispatch_info,
+            apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output,
        )

@@ -743,6 +755,9 @@ def select_experts(
    correction_bias = topk_config.correction_bias
    torch_native = topk_config.torch_native
    routed_scaling_factor = topk_config.routed_scaling_factor
+    apply_routed_scaling_factor_on_output = (
+        topk_config.apply_routed_scaling_factor_on_output
+    )
    router_logits, correction_bias = (
        expert_location_dispatch.transform_select_experts_inputs(

@@ -768,6 +783,7 @@ def select_experts(
            routed_scaling_factor=routed_scaling_factor,
            num_token_non_padded=num_token_non_padded,
            expert_location_dispatch_info=expert_location_dispatch_info,
+            apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output,
        )
    else:
        topk_weights, topk_ids = biased_grouped_topk(

@@ -782,12 +798,14 @@ def select_experts(
            routed_scaling_factor=routed_scaling_factor,
            num_token_non_padded=num_token_non_padded,
            expert_location_dispatch_info=expert_location_dispatch_info,
+            apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output,
        )
    elif torch_native and custom_routing_function is None:
        assert (
            num_token_non_padded is None
        ), "num_token_non_padded is not yet supported in fused_topk_native"
        assert expert_location_dispatch_info is None
+        assert not apply_routed_scaling_factor_on_output, "Not implemented"
        topk_weights, topk_ids = fused_topk_native(
            hidden_states=hidden_states,
            gating_output=router_logits,

@@ -795,6 +813,7 @@ def select_experts(
            renormalize=renormalize,
        )
    elif custom_routing_function is None:
+        assert not apply_routed_scaling_factor_on_output, "Not implemented"
        # Qwen3MOE uses fused_topk
        topk_weights, topk_ids = fused_topk(
            hidden_states=hidden_states,

@@ -809,6 +828,7 @@ def select_experts(
            num_token_non_padded is None
        ), "num_token_non_padded is not yet supported in custom_routing_function"
        assert expert_location_dispatch_info is None
+        assert not apply_routed_scaling_factor_on_output, "Not implemented"
        topk_weights, topk_ids = custom_routing_function(
            hidden_states=hidden_states,
            gating_output=router_logits,
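For reference, a simplified eager-PyTorch sketch of what the new apply_routed_scaling_factor_on_output flag does inside the top-k path: renormalize the selected weights, then optionally fold the scaling factor into the returned weights so downstream quant methods no longer multiply the output. The real grouped/biased variants above also handle expert groups, correction bias, shared experts, and expert-location dispatch, all omitted here:

import torch
from typing import Optional, Tuple

def simple_topk(
    gating_output: torch.Tensor,                        # [num_tokens, num_experts] router logits
    topk: int,
    renormalize: bool = True,
    routed_scaling_factor: Optional[float] = None,
    apply_routed_scaling_factor_on_output: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    scores = torch.softmax(gating_output, dim=-1)
    topk_weights, topk_ids = torch.topk(scores, k=topk, dim=-1)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    if apply_routed_scaling_factor_on_output:
        # Fused path: the caller must not multiply the MoE output by the factor again.
        topk_weights = topk_weights * routed_scaling_factor
    return topk_weights.to(torch.float32), topk_ids.to(torch.int32)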
python/sglang/srt/layers/quantization/fp8.py

@@ -514,6 +514,12 @@ class Fp8MoEMethod(FusedMoEMethodBase):
        self.quant_config = quant_config
        self.block_quant = self.quant_config.weight_block_size is not None
        self.cutlass_fp8_supported = cutlass_fp8_supported()
+        self.use_cutlass_fused_experts_fp8 = (
+            get_bool_env_var("SGLANG_CUTLASS_MOE")
+            and self.cutlass_fp8_supported
+            and self.block_quant
+            and (is_sm100_supported() or is_sm90_supported())
+        )

    def create_weights(
        self,

@@ -1021,12 +1027,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
        if ret is not None:
            return ret

-        if (
-            get_bool_env_var("SGLANG_CUTLASS_MOE")
-            and self.cutlass_fp8_supported
-            and self.block_quant
-            and (is_sm100_supported() or is_sm90_supported())
-        ):
+        if self.use_cutlass_fused_experts_fp8:
            from sglang.srt.layers.moe.cutlass_moe import cutlass_fused_experts_fp8

            topk_weights, topk_ids, _ = topk_output

@@ -1053,9 +1054,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                self.problem_sizes2,
                use_fp8_blockscale=True,
            )
-            # TODO: Fuse into select_experts
-            if moe_runner_config.routed_scaling_factor is not None:
-                output *= moe_runner_config.routed_scaling_factor
+            # Scale by routed_scaling_factor is fused into select_experts.
            return output

        # Expert fusion with FP8 quantization
        return fused_experts(
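Beyond removing the output multiply, the fp8.py change caches the CUTLASS fused-experts decision once in __init__ (use_cutlass_fused_experts_fp8) instead of re-checking the environment and GPU architecture inside apply, and the new should_fuse_routed_scaling_factor_in_topk hook in layer.py reads that attribute. A sketch of the decide-once pattern under stated assumptions (the helper below is a hypothetical stand-in for sglang's get_bool_env_var, and the real class is constructed from a quant_config rather than booleans):

import os

def _env_flag(name: str) -> bool:
    # Hypothetical stand-in for get_bool_env_var.
    return os.environ.get(name, "0").lower() in ("1", "true", "yes")

class Fp8MoEMethodSketch:
    def __init__(self, cutlass_fp8_supported: bool, block_quant: bool, sm90_or_newer: bool):
        # Decide once at construction; apply() and the TopK wiring can branch on a
        # plain attribute instead of re-reading the environment on every forward.
        self.use_cutlass_fused_experts_fp8 = (
            _env_flag("SGLANG_CUTLASS_MOE")
            and cutlass_fp8_supported
            and block_quant
            and sm90_or_newer
        )

method = Fp8MoEMethodSketch(cutlass_fp8_supported=True, block_quant=True, sm90_or_newer=True)
print(method.use_cutlass_fused_experts_fp8)  # True only when SGLANG_CUTLASS_MOE is set truthy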
python/sglang/srt/layers/quantization/modelopt_quant.py

@@ -1305,8 +1305,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
                tp_rank=layer.moe_tp_rank,
                tune_max_num_tokens=next_power_of_2(x.shape[0]),
            )[0]
-            if moe_runner_config.routed_scaling_factor is not None:
-                output *= moe_runner_config.routed_scaling_factor
+            # Scale by routed_scaling_factor is fused into select_experts.
            if should_use_flashinfer_cutlass_moe_fp4_allgather():
                output, global_output = get_local_dp_buffer(), output
                get_tp_group().reduce_scatterv(

@@ -1332,6 +1331,5 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
            params=layer.cutlass_moe_params,
            apply_router_weight_on_input=moe_runner_config.apply_router_weight_on_input,
        ).to(x.dtype)
-        if moe_runner_config.routed_scaling_factor is not None:
-            output *= moe_runner_config.routed_scaling_factor
+        # Scale by routed_scaling_factor is fused into select_experts.
        return output
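Both removed multiplies guard against double scaling: once select_experts folds the factor into the top-k weights, multiplying the combined output again would scale results by the factor squared. A small illustrative check with toy tensors (not the NVFP4 path):

import torch

s = 2.5
weights = torch.tensor([0.6, 0.4])                        # renormalized top-k weights for one token
outputs = torch.tensor([[1.0, 2.0], [3.0, 4.0]])          # per-expert outputs

unscaled = (outputs * weights.unsqueeze(-1)).sum(0)       # weighted combine, no factor
fused = (outputs * (weights * s).unsqueeze(-1)).sum(0)    # factor folded into the weights
assert torch.allclose(fused, unscaled * s)                # fused output already carries s
assert torch.allclose(fused * s, unscaled * s * s)        # keeping the old multiply applies s twice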
python/sglang/srt/models/deepseek_v2.py

@@ -319,17 +319,6 @@ class DeepseekV2MoE(nn.Module):
            config=config,
            prefix=add_prefix("gate", prefix),
            is_nextn=is_nextn
        )
-        self.topk = TopK(
-            top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
-            renormalize=config.norm_topk_prob,
-            use_grouped_topk=True,
-            num_expert_group=config.n_group,
-            num_fused_shared_experts=self.num_fused_shared_experts,
-            topk_group=config.topk_group,
-            correction_bias=self.gate.e_score_correction_bias,
-            routed_scaling_factor=self.routed_scaling_factor,
-        )
-
        self.experts = get_moe_impl_class()(
            num_experts=config.n_routed_experts + self.num_fused_shared_experts

@@ -344,6 +333,18 @@ class DeepseekV2MoE(nn.Module):
            prefix=add_prefix("experts", prefix),
        )

+        self.topk = TopK(
+            top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
+            renormalize=config.norm_topk_prob,
+            use_grouped_topk=True,
+            num_expert_group=config.n_group,
+            num_fused_shared_experts=self.num_fused_shared_experts,
+            topk_group=config.topk_group,
+            correction_bias=self.gate.e_score_correction_bias,
+            routed_scaling_factor=self.routed_scaling_factor,
+            apply_routed_scaling_factor_on_output=self.experts.should_fuse_routed_scaling_factor_in_topk(),
+        )
+
        self.shared_experts_is_int8 = False
        self.shared_experts_is_fp8 = False
        self.shared_experts_weight_block_size = None
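The TopK construction moves below self.experts because its new argument queries the experts' quant method via should_fuse_routed_scaling_factor_in_topk(); building the router first would reference a module that does not exist yet. A stripped-down sketch of that ordering dependency (stub classes, not the real DeepseekV2 modules):

class ExpertsStub:
    def should_fuse_routed_scaling_factor_in_topk(self) -> bool:
        return True  # e.g. NVFP4 or CUTLASS FP8 backends

class RouterStub:
    def __init__(self, apply_routed_scaling_factor_on_output: bool = False):
        self.apply_routed_scaling_factor_on_output = apply_routed_scaling_factor_on_output

experts = ExpertsStub()  # must be constructed first
router = RouterStub(
    apply_routed_scaling_factor_on_output=experts.should_fuse_routed_scaling_factor_in_topk()
)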
sgl-kernel/tests/test_moe_fused_gate.py

@@ -19,7 +19,10 @@ from sglang.srt.layers.moe.topk import biased_grouped_topk
    ],
)
@pytest.mark.parametrize("num_fused_shared_experts", [0, 1, 2])
-def test_moe_fused_gate_combined(seq_length, params, num_fused_shared_experts):
+@pytest.mark.parametrize("apply_routed_scaling_factor_on_output", [False, True])
+def test_moe_fused_gate_combined(
+    seq_length, params, num_fused_shared_experts, apply_routed_scaling_factor_on_output
+):
    num_experts, num_expert_group, topk_group, topk = params
    dtype = torch.float32

@@ -37,6 +40,7 @@ def test_moe_fused_gate_combined(seq_length, params, num_fused_shared_experts):
        topk=topk,
        num_fused_shared_experts=num_fused_shared_experts,
        routed_scaling_factor=2.5,
+        apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output,
    )
    ref_output, ref_indices = biased_grouped_topk(
        scores,

@@ -48,6 +52,7 @@ def test_moe_fused_gate_combined(seq_length, params, num_fused_shared_experts):
        topk_group=topk_group,
        num_fused_shared_experts=num_fused_shared_experts,
        routed_scaling_factor=2.5,
+        apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output,
    )
    # When num_fused_shared_experts > 0, ignore the comparison of the last topk dimension
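The new parametrization runs every existing case with the fused scaling both off and on, passing the same flag to the sgl-kernel op and to the biased_grouped_topk reference. Outside the kernel, the property it exercises reduces to: toggling the flag scales the returned weights by routed_scaling_factor (2.5 in the test) without changing which experts are selected. A self-contained illustration of that property with a toy router (not the sgl-kernel moe_fused_gate op):

import torch

def toy_topk_weights(logits: torch.Tensor, k: int, scale: float, fuse: bool):
    # Minimal stand-in for the routing under test.
    w, ids = torch.topk(torch.sigmoid(logits), k=k, dim=-1)
    w = w / w.sum(dim=-1, keepdim=True)
    return (w * scale if fuse else w), ids

logits = torch.randn(16, 64)
w_ref, ids_ref = toy_topk_weights(logits, k=8, scale=2.5, fuse=False)
w_fused, ids_fused = toy_topk_weights(logits, k=8, scale=2.5, fuse=True)

assert torch.equal(ids_ref, ids_fused)        # expert selection is unaffected
assert torch.allclose(w_fused, w_ref * 2.5)   # only the weights carry the factor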