sglang · Commits · commit 84b006b2 (Unverified)

Authored Aug 15, 2025 by Cheng Wan; committed via GitHub on Aug 15, 2025.
Parent: 8ca07bd9

Cleanup MoE Refactor (#9223)

Showing 3 changed files with 18 additions and 16 deletions:

python/sglang/srt/layers/quantization/mxfp4.py  +8 -3
python/sglang/srt/models/deepseek_v2.py         +6 -7
python/sglang/srt/models/glm4_moe.py            +4 -6
python/sglang/srt/layers/quantization/mxfp4.py (view file @ 84b006b2)

@@ -573,6 +573,9 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         topk_output: TopKOutput,
         moe_runner_config: MoeRunnerConfig,
     ) -> torch.Tensor:
+        from sglang.srt.layers.moe.topk import TopKOutputChecker
+
         if self.use_flashinfer:
             # Based on profiling results, we need to quantize x to mxfp8 here to achieve better performance
             x_quant, x_scale = mxfp8_quantize(
@@ -580,8 +583,10 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
             )  # to mxfp8
             x_scale = x_scale.view(torch.float8_e4m3fn).reshape(-1)
             assert x_quant.shape[-1] == self.hidden_size
-            top_k, router_logits = topk_output
+            assert TopKOutputChecker.format_is_bypassed(topk_output)
+            top_k = topk_output.topk_config.top_k
+            router_logits = topk_output.router_logits
             trtllm_gen_output = trtllm_fp4_block_scale_moe(
                 router_logits.to(torch.bfloat16),
@@ -602,8 +607,8 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                 None,  # output2_scale_scalar
                 layer.num_experts,
                 top_k,
-                None,  # n_group
-                None,  # topk_group
+                None,  # n_group  # TODO: support n_group
+                None,  # topk_group  # TODO: support topk_group
                 self.intermediate_size,  # padded to multiple of 256
                 layer.moe_ep_rank * layer.num_local_experts,  # local_expert_offset
                 layer.num_local_experts,  # local num experts
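For orientation, here is a minimal, self-contained sketch of the access pattern the mxfp4 change switches to: instead of tuple-unpacking `topk_output`, the flashinfer path now asserts the bypassed format and reads `topk_config.top_k` and `router_logits` as attributes. The `_TopKConfig` and `_BypassedTopKOutput` classes below are hypothetical stand-ins for illustration, not sglang's actual types.

from dataclasses import dataclass

import torch


@dataclass
class _TopKConfig:          # hypothetical stand-in, not sglang's TopKConfig
    top_k: int


@dataclass
class _BypassedTopKOutput:  # hypothetical stand-in for a bypassed TopKOutput
    topk_config: _TopKConfig
    router_logits: torch.Tensor


topk_output = _BypassedTopKOutput(
    topk_config=_TopKConfig(top_k=8),
    router_logits=torch.randn(4, 64),
)

# Old call site: top_k, router_logits = topk_output
# New call site, mirroring the diff above:
top_k = topk_output.topk_config.top_k
router_logits = topk_output.router_logits
print(top_k, router_logits.shape)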
python/sglang/srt/models/deepseek_v2.py (view file @ 84b006b2)

@@ -459,15 +459,15 @@ class DeepseekV2MoE(nn.Module):
         with torch.cuda.stream(self.alt_stream):
             # router_logits: (num_tokens, n_experts)
             router_logits = self.gate(hidden_states)
-            kwargs = {"hidden_states": hidden_states}
-            kwargs["topk_output"] = self.topk(hidden_states, router_logits)
-            final_hidden_states = self.experts(**kwargs)
+            topk_output = self.topk(hidden_states, router_logits)
+            final_hidden_states = self.experts(hidden_states, topk_output)
             if not _is_cuda:
                 final_hidden_states *= self.routed_scaling_factor
         current_stream.wait_stream(self.alt_stream)
         with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
             final_hidden_states_out = torch.empty_like(final_hidden_states)
             torch.add(final_hidden_states, shared_output, out=final_hidden_states_out)
             final_hidden_states = final_hidden_states_out
             sm.tag(final_hidden_states)
@@ -489,10 +489,9 @@ class DeepseekV2MoE(nn.Module):
         shared_output = self._forward_shared_experts(hidden_states)
         # router_logits: (num_tokens, n_experts)
         router_logits = self.gate(hidden_states)
-        kwargs = {"hidden_states": hidden_states}
-        kwargs["topk_output"] = self.topk(hidden_states, router_logits)
-        final_hidden_states = self.experts(**kwargs)
+        topk_output = self.topk(hidden_states, router_logits)
+        final_hidden_states = self.experts(hidden_states, topk_output)
         if not _is_cuda and not _use_aiter:
             # fused in biased_grouped_topk so we can skip here
             final_hidden_states *= self.routed_scaling_factor
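The DeepseekV2MoE change above is a pure call-site cleanup: the top-k result is bound to a local `topk_output` and passed positionally to `self.experts` instead of being accumulated in a `kwargs` dict. A toy sketch of the same before/after shape, using placeholder gate, topk, and expert implementations that are not sglang's:

import torch
import torch.nn as nn


class TinyMoEBlock(nn.Module):
    def __init__(self, hidden: int = 16, n_experts: int = 4):
        super().__init__()
        self.gate = nn.Linear(hidden, n_experts, bias=False)

    def topk(self, hidden_states, router_logits, k: int = 2):
        # stand-in for sglang's TopK operator: returns (weights, expert ids)
        return torch.topk(torch.softmax(router_logits, dim=-1), k=k, dim=-1)

    def experts(self, hidden_states, topk_output):
        # stand-in expert computation: weight hidden states by the top-k mass
        topk_weights, _ = topk_output
        return hidden_states * topk_weights.sum(dim=-1, keepdim=True)

    def forward(self, hidden_states):
        router_logits = self.gate(hidden_states)
        # Old: kwargs = {"hidden_states": hidden_states};
        #      kwargs["topk_output"] = self.topk(...); self.experts(**kwargs)
        # New (as in the diff): bind locally and pass positionally.
        topk_output = self.topk(hidden_states, router_logits)
        return self.experts(hidden_states, topk_output)


print(TinyMoEBlock()(torch.randn(3, 16)).shape)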
python/sglang/srt/models/glm4_moe.py (view file @ 84b006b2)

@@ -509,9 +509,8 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
         with torch.cuda.stream(self.alt_stream):
             # router_logits: (num_tokens, n_experts)
             router_logits = self.gate(hidden_states)
-            kwargs = {"hidden_states": hidden_states}
-            kwargs["topk_output"] = self.topk(hidden_states, router_logits)
-            final_hidden_states = self.experts(**kwargs)
+            topk_output = self.topk(hidden_states, router_logits)
+            final_hidden_states = self.experts(hidden_states, topk_output)
             if not _is_cuda:
                 final_hidden_states *= self.routed_scaling_factor
         current_stream.wait_stream(self.alt_stream)
@@ -552,9 +551,8 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
         shared_output = self._forward_shared_experts(hidden_states)
         # router_logits: (num_tokens, n_experts)
         router_logits = self.gate(hidden_states)
-        kwargs = {"hidden_states": hidden_states}
-        kwargs["topk_output"] = self.topk(hidden_states, router_logits)
-        final_hidden_states = self.experts(**kwargs)
+        topk_output = self.topk(hidden_states, router_logits)
+        final_hidden_states = self.experts(hidden_states, topk_output)
         if not _is_cuda and not _use_aiter:
             # fused in biased_grouped_topk so we can skip here
             final_hidden_states *= self.routed_scaling_factor
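Both DeepseekV2MoE and Glm4MoeSparseMoeBlock wrap the cleaned-up gate/topk/experts sequence in the same dual-stream pattern seen in the hunks above (`torch.cuda.stream(self.alt_stream)` followed by `current_stream.wait_stream(self.alt_stream)`). A generic sketch of that overlap, with placeholder functions rather than sglang's modules:

import torch


def dual_stream_moe(hidden_states, routed_fn, shared_fn, alt_stream):
    current_stream = torch.cuda.current_stream()
    alt_stream.wait_stream(current_stream)      # alt stream sees up-to-date inputs
    with torch.cuda.stream(alt_stream):
        routed_out = routed_fn(hidden_states)   # gate -> topk -> experts
    shared_out = shared_fn(hidden_states)       # overlaps with routed_fn
    current_stream.wait_stream(alt_stream)      # join before combining results
    return routed_out + shared_out


if torch.cuda.is_available():
    x = torch.randn(4, 16, device="cuda")
    out = dual_stream_moe(x, lambda h: h * 2, lambda h: h + 1, torch.cuda.Stream())
    print(out.shape)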