sglang / Commits / 84b006b2
"torchvision/git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "742fd13c17f3e61390fd28924c71f92f18c7efe3"
Unverified commit 84b006b2, authored Aug 15, 2025 by Cheng Wan, committed by GitHub on Aug 15, 2025

Cleanup MoE Refactor (#9223)

parent 8ca07bd9
Showing 3 changed files with 18 additions (+18) and 16 deletions (-16)
python/sglang/srt/layers/quantization/mxfp4.py   +8  -3
python/sglang/srt/models/deepseek_v2.py          +6  -7
python/sglang/srt/models/glm4_moe.py             +4  -6
python/sglang/srt/layers/quantization/mxfp4.py

@@ -573,6 +573,9 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         topk_output: TopKOutput,
         moe_runner_config: MoeRunnerConfig,
     ) -> torch.Tensor:
+        from sglang.srt.layers.moe.topk import TopKOutputChecker
         if self.use_flashinfer:
             # Based on profiling results, we need to quantize x to mxfp8 here to achieve better performance
             x_quant, x_scale = mxfp8_quantize(
@@ -580,8 +583,10 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
             )  # to mxfp8
             x_scale = x_scale.view(torch.float8_e4m3fn).reshape(-1)
             assert x_quant.shape[-1] == self.hidden_size
+            assert TopKOutputChecker.format_is_bypassed(topk_output)
-            top_k, router_logits = topk_output
+            top_k = topk_output.topk_config.top_k
+            router_logits = topk_output.router_logits
             trtllm_gen_output = trtllm_fp4_block_scale_moe(
                 router_logits.to(torch.bfloat16),
@@ -602,8 +607,8 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                 None,  # output2_scale_scalar
                 layer.num_experts,
                 top_k,
-                None,  # n_group
+                None,  # n_group  # TODO: support n_group
-                None,  # topk_group
+                None,  # topk_group  # TODO: support topk_group
                 self.intermediate_size,  # padded to multiple of 256
                 layer.moe_ep_rank * layer.num_local_experts,  # local_expert_offset
                 layer.num_local_experts,  # local num experts
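The mxfp4 change stops unpacking topk_output as a (top_k, router_logits) tuple; after asserting the bypassed format it reads topk_output.topk_config.top_k and topk_output.router_logits instead. The stand-in dataclasses below are only a sketch of what that access pattern implies, written as assumptions for illustration; they are not sglang's actual TopKOutput or TopKOutputChecker definitions.

# Hedged sketch only: stand-in structures mirroring the attribute paths used in
# the diff (topk_output.topk_config.top_k, topk_output.router_logits). The real
# sglang definitions may differ.
from dataclasses import dataclass

import torch


@dataclass
class TopKConfig:          # assumption: holds static top-k settings
    top_k: int


@dataclass
class BypassedTopKOutput:  # assumption: "bypassed" format defers top-k selection
    topk_config: TopKConfig
    router_logits: torch.Tensor


def read_topk_output(topk_output: BypassedTopKOutput) -> tuple[int, torch.Tensor]:
    # New style from the diff: attribute access instead of tuple unpacking,
    # so the object can carry extra fields without breaking callers.
    top_k = topk_output.topk_config.top_k
    router_logits = topk_output.router_logits
    return top_k, router_logits


out = BypassedTopKOutput(TopKConfig(top_k=8), torch.randn(4, 64))
print(read_topk_output(out)[0])  # -> 8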
python/sglang/srt/models/deepseek_v2.py

@@ -459,15 +459,15 @@ class DeepseekV2MoE(nn.Module):
         with torch.cuda.stream(self.alt_stream):
             # router_logits: (num_tokens, n_experts)
             router_logits = self.gate(hidden_states)
-            kwargs = {"hidden_states": hidden_states}
-            kwargs["topk_output"] = self.topk(hidden_states, router_logits)
-            final_hidden_states = self.experts(**kwargs)
+            topk_output = self.topk(hidden_states, router_logits)
+            final_hidden_states = self.experts(hidden_states, topk_output)
             if not _is_cuda:
                 final_hidden_states *= self.routed_scaling_factor
         current_stream.wait_stream(self.alt_stream)
         with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
             final_hidden_states_out = torch.empty_like(final_hidden_states)
             torch.add(final_hidden_states, shared_output, out=final_hidden_states_out)
             final_hidden_states = final_hidden_states_out
             sm.tag(final_hidden_states)
@@ -489,10 +489,9 @@ class DeepseekV2MoE(nn.Module):
         shared_output = self._forward_shared_experts(hidden_states)
         # router_logits: (num_tokens, n_experts)
         router_logits = self.gate(hidden_states)
-        kwargs = {"hidden_states": hidden_states}
-        kwargs["topk_output"] = self.topk(hidden_states, router_logits)
-        final_hidden_states = self.experts(**kwargs)
+        topk_output = self.topk(hidden_states, router_logits)
+        final_hidden_states = self.experts(hidden_states, topk_output)
         if not _is_cuda and not _use_aiter:
             # fused in biased_grouped_topk so we can skip here
             final_hidden_states *= self.routed_scaling_factor
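In both DeepseekV2MoE paths shown above, the cleanup drops the incrementally built kwargs dict and calls self.experts(hidden_states, topk_output) with a named intermediate instead. The sketch below illustrates the equivalence of the two call styles; gate, topk, and experts here are toy stand-ins assumed for illustration, not the sglang modules.

# Hedged sketch: toy stand-ins only; the point is the call-style change from the diff.
import torch


def gate(hidden_states):            # assumption: router producing logits
    return hidden_states @ torch.randn(hidden_states.shape[-1], 16)


def topk(hidden_states, router_logits):
    return torch.topk(router_logits, k=2, dim=-1)


def experts(hidden_states, topk_output):
    return hidden_states            # placeholder for the fused expert compute


hidden_states = torch.randn(4, 32)
router_logits = gate(hidden_states)

# Old style (removed): build a kwargs dict piece by piece, then unpack it.
kwargs = {"hidden_states": hidden_states}
kwargs["topk_output"] = topk(hidden_states, router_logits)
out_old = experts(**kwargs)

# New style (added): name the intermediate and pass arguments positionally.
topk_output = topk(hidden_states, router_logits)
out_new = experts(hidden_states, topk_output)

assert torch.equal(out_old, out_new)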
python/sglang/srt/models/glm4_moe.py

@@ -509,9 +509,8 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
         with torch.cuda.stream(self.alt_stream):
             # router_logits: (num_tokens, n_experts)
             router_logits = self.gate(hidden_states)
-            kwargs = {"hidden_states": hidden_states}
-            kwargs["topk_output"] = self.topk(hidden_states, router_logits)
-            final_hidden_states = self.experts(**kwargs)
+            topk_output = self.topk(hidden_states, router_logits)
+            final_hidden_states = self.experts(hidden_states, topk_output)
             if not _is_cuda:
                 final_hidden_states *= self.routed_scaling_factor
         current_stream.wait_stream(self.alt_stream)
@@ -552,9 +551,8 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
         shared_output = self._forward_shared_experts(hidden_states)
         # router_logits: (num_tokens, n_experts)
         router_logits = self.gate(hidden_states)
-        kwargs = {"hidden_states": hidden_states}
-        kwargs["topk_output"] = self.topk(hidden_states, router_logits)
-        final_hidden_states = self.experts(**kwargs)
+        topk_output = self.topk(hidden_states, router_logits)
+        final_hidden_states = self.experts(hidden_states, topk_output)
         if not _is_cuda and not _use_aiter:
             # fused in biased_grouped_topk so we can skip here
             final_hidden_states *= self.routed_scaling_factor
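Both changed blocks in deepseek_v2.py and glm4_moe.py sit inside the dual-stream pattern visible in the unchanged context: routing work runs under torch.cuda.stream(self.alt_stream), and the default stream later calls wait_stream on it before consuming the result. Below is a minimal, hedged sketch of that overlap pattern using generic tensor ops rather than the sglang gate/topk/experts modules; it only executes when CUDA is available.

# Hedged sketch of the alt-stream overlap pattern from the diff context:
# launch routing work on a side stream, do other work on the default stream,
# then synchronize with wait_stream before combining results.
import torch

if torch.cuda.is_available():
    device = torch.device("cuda")
    alt_stream = torch.cuda.Stream()
    current_stream = torch.cuda.current_stream()

    hidden_states = torch.randn(1024, 1024, device=device)
    weight = torch.randn(1024, 1024, device=device)

    # Make sure alt_stream sees work already queued on the current stream.
    alt_stream.wait_stream(current_stream)

    with torch.cuda.stream(alt_stream):
        routed = hidden_states @ weight      # stands in for gate/topk/experts

    shared = torch.relu(hidden_states)       # overlapping work on the default stream

    # The default stream must not read `routed` until alt_stream has finished it.
    current_stream.wait_stream(alt_stream)
    combined = routed + shared
    print(combined.shape)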