Commit ec15c836 (unverified) · authored Sep 04, 2025 by Yuan Luo · committed by GitHub, Sep 04, 2025

Optimize Qwen3-moe model by using flashinfer fused allreduce (#9973)

Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com>

Parent: 106c2b31
Showing 3 changed files with 52 additions and 12 deletions.
python/sglang/srt/layers/communicator.py   +9 −3
python/sglang/srt/models/qwen2_moe.py      +4 −1
python/sglang/srt/models/qwen3_moe.py      +39 −8
python/sglang/srt/layers/communicator.py  (view file @ ec15c836)
...
@@ -42,9 +42,15 @@ from sglang.srt.layers.moe import (
 )
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.utils import is_cuda, is_flashinfer_available, is_sm100_supported
+from sglang.srt.utils import (
+    is_cuda,
+    is_flashinfer_available,
+    is_sm90_supported,
+    is_sm100_supported,
+)

 _is_flashinfer_available = is_flashinfer_available()
+_is_sm90_supported = is_cuda() and is_sm90_supported()
 _is_sm100_supported = is_cuda() and is_sm100_supported()

 FUSE_ALLREDUCE_MAX_BATCH_SIZE = 2048
...
@@ -484,11 +490,11 @@ class CommunicateWithAllReduceAndLayerNormFn:
         # According to the discussion in https://github.com/flashinfer-ai/flashinfer/issues/1223#issuecomment-3047256465
         # We set the max token num to 128 for allreduce fusion with min-latency case(use_oneshot=True).
         if (
-            _is_sm100_supported
+            (_is_sm100_supported or _is_sm90_supported)
             and _is_flashinfer_available
             and hasattr(layernorm, "forward_with_allreduce_fusion")
             and global_server_args_dict["enable_flashinfer_allreduce_fusion"]
-            and hidden_states.shape[0] <= 2048
+            and hidden_states.shape[0] <= 4096
         ):
             hidden_states, residual = layernorm.forward_with_allreduce_fusion(
                 hidden_states, residual
...
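The widened condition above is the core of the communicator change: the fused allreduce+RMSNorm path is now allowed on SM90 as well as SM100, and the per-call token cap is raised from 2048 to 4096. A minimal sketch of the same eligibility check, with every probe passed in explicitly; the function and parameter names below are illustrative stand-ins, not sglang's API:

# Minimal sketch of the eligibility check, assuming the caller already ran the
# hardware/library probes; all names here are illustrative.
def can_fuse_allreduce_into_rmsnorm(
    num_tokens: int,
    sm90_supported: bool,            # stand-in for _is_sm90_supported
    sm100_supported: bool,           # stand-in for _is_sm100_supported
    flashinfer_available: bool,      # stand-in for _is_flashinfer_available
    layernorm_has_fused_path: bool,  # stand-in for hasattr(layernorm, "forward_with_allreduce_fusion")
    fusion_flag_enabled: bool,       # stand-in for enable_flashinfer_allreduce_fusion
    max_tokens: int = 4096,          # the new cap in this commit (was 2048)
) -> bool:
    return (
        (sm100_supported or sm90_supported)
        and flashinfer_available
        and layernorm_has_fused_path
        and fusion_flag_enabled
        and num_tokens <= max_tokens
    )

# Example: a 2048-token batch on an SM90 GPU with flashinfer present takes the fused path.
assert can_fuse_allreduce_into_rmsnorm(2048, True, False, True, True, True)
# A batch over the cap falls back to the separate all-reduce followed by the layernorm.
assert not can_fuse_allreduce_into_rmsnorm(8192, True, False, True, True, True)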
python/sglang/srt/models/qwen2_moe.py  (view file @ ec15c836)
...
@@ -105,11 +105,14 @@ class Qwen2MoeMLP(nn.Module):
     def forward(
         self,
         x,
+        should_allreduce_fusion: bool = False,
         use_reduce_scatter: bool = False,
     ):
         gate_up, _ = self.gate_up_proj(x)
         x = self.act_fn(gate_up)
-        x, _ = self.down_proj(x, skip_all_reduce=use_reduce_scatter)
+        x, _ = self.down_proj(
+            x, skip_all_reduce=should_allreduce_fusion or use_reduce_scatter
+        )
         return x
...
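The `skip_all_reduce=should_allreduce_fusion or use_reduce_scatter` change keeps the tensor-parallel reduction from running twice: when the all-reduce is about to be fused into the next layer's RMSNorm, or replaced by a reduce-scatter, the row-parallel down projection must hand back its partial sum untouched. A toy illustration of that control flow, using plain lists in place of tensors and a hypothetical `all_reduce` callable (not sglang's RowParallelLinear):

# Toy illustration: the down projection returns a partial, per-rank sum whenever
# a later step will perform the reduction exactly once.
from typing import Callable, List

def down_proj_output(
    partial: List[float],           # this TP rank's partial result
    should_allreduce_fusion: bool,  # reduction will be fused into the next RMSNorm
    use_reduce_scatter: bool,       # reduction will be done as a reduce-scatter
    all_reduce: Callable[[List[float]], List[float]],
) -> List[float]:
    if should_allreduce_fusion or use_reduce_scatter:
        return partial          # leave it partial; the later kernel reduces it
    return all_reduce(partial)  # default path: reduce here

# Two ranks' partial outputs for the same tokens:
rank_partials = [[1.0, 2.0], [3.0, 4.0]]
fake_all_reduce = lambda _partial: [sum(vals) for vals in zip(*rank_partials)]

print(down_proj_output(rank_partials[0], False, False, fake_all_reduce))  # [4.0, 6.0] (reduced here)
print(down_proj_output(rank_partials[0], True, False, fake_all_reduce))   # [1.0, 2.0] (still partial)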
python/sglang/srt/models/qwen3_moe.py  (view file @ ec15c836)
...
@@ -42,7 +42,10 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.moe import get_moe_a2a_backend
+from sglang.srt.layers.moe import (
+    get_moe_a2a_backend,
+    should_use_flashinfer_cutlass_moe_fp4_allgather,
+)
 from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
 from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
 from sglang.srt.layers.moe.topk import TopK
...
@@ -57,10 +60,17 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTe
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2_moe import Qwen2MoeMLP as Qwen3MoeMLP
 from sglang.srt.models.qwen2_moe import Qwen2MoeModel
-from sglang.srt.utils import add_prefix, is_cuda, is_non_idle_and_non_empty
+from sglang.srt.utils import (
+    add_prefix,
+    is_cuda,
+    is_flashinfer_available,
+    is_non_idle_and_non_empty,
+)

 Qwen3MoeConfig = None

+_is_flashinfer_available = is_flashinfer_available()
+
 logger = logging.getLogger(__name__)
 _is_cuda = is_cuda()
...
@@ -119,11 +129,14 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
         self,
         hidden_states: torch.Tensor,
         forward_batch: Optional[ForwardBatch] = None,
+        should_allreduce_fusion: bool = False,
         use_reduce_scatter: bool = False,
     ) -> torch.Tensor:
         if not get_moe_a2a_backend().is_deepep():
-            return self.forward_normal(hidden_states, use_reduce_scatter)
+            return self.forward_normal(
+                hidden_states, should_allreduce_fusion, use_reduce_scatter
+            )
         else:
             return self.forward_deepep(hidden_states, forward_batch)
...
@@ -137,6 +150,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
     def forward_normal(
         self,
         hidden_states: torch.Tensor,
+        should_allreduce_fusion: bool = False,
         use_reduce_scatter: bool = False,
     ) -> torch.Tensor:
         num_tokens, hidden_dim = hidden_states.shape
...
@@ -146,7 +160,12 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
         router_logits, _ = self.gate(hidden_states)
         topk_output = self.topk(hidden_states, router_logits)
         final_hidden_states = self.experts(hidden_states, topk_output)
-        if self.tp_size > 1 and not use_reduce_scatter:
+        if (
+            self.tp_size > 1
+            and not should_allreduce_fusion
+            and not use_reduce_scatter
+            and not should_use_flashinfer_cutlass_moe_fp4_allgather()
+        ):
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
         return final_hidden_states.view(num_tokens, hidden_dim)
...
@@ -500,6 +519,7 @@ class Qwen3MoeDecoderLayer(nn.Module):
             input_layernorm=self.input_layernorm,
             post_attention_layernorm=self.post_attention_layernorm,
             allow_reduce_scatter=True,
+            is_last_layer=(self.layer_id == self.config.num_hidden_layers - 1),
         )

     def forward(
...
@@ -525,13 +545,24 @@ class Qwen3MoeDecoderLayer(nn.Module):
             hidden_states, residual, forward_batch
         )

+        should_allreduce_fusion = (
+            self.layer_communicator.should_fuse_mlp_allreduce_with_next_layer(
+                forward_batch
+            )
+        )
+
         # For DP with padding, reduce scatter can be used instead of all-reduce.
         use_reduce_scatter = self.layer_communicator.should_use_reduce_scatter(
             forward_batch
         )
-        hidden_states = self.mlp(hidden_states, forward_batch, use_reduce_scatter)
+        hidden_states = self.mlp(
+            hidden_states, forward_batch, should_allreduce_fusion, use_reduce_scatter
+        )

-        hidden_states, residual = self.layer_communicator.postprocess_layer(
-            hidden_states, residual, forward_batch
-        )
+        if should_allreduce_fusion:
+            hidden_states._sglang_needs_allreduce_fusion = True
+        else:
+            hidden_states, residual = self.layer_communicator.postprocess_layer(
+                hidden_states, residual, forward_batch
+            )
...
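Putting the qwen3_moe.py hunks together: each decoder layer first asks the layer communicator whether its MLP all-reduce can be fused with the next layer's input RMSNorm, threads that decision through the MoE block (which then skips its own tensor_model_parallel_all_reduce), and finally either tags the output tensor with `_sglang_needs_allreduce_fusion` or runs the usual postprocess step. Below is a condensed, hypothetical sketch of that flow; the communicator decision shown (flag on, batch under the cap, not the last layer) is an assumption inferred from the symbols this commit touches (`enable_flashinfer_allreduce_fusion`, `FUSE_ALLREDUCE_MAX_BATCH_SIZE`, `is_last_layer`), not a copy of sglang's implementation:

# Condensed sketch of the per-layer control flow added in this commit. Every class
# and helper here is an illustrative stand-in, not sglang's real object.
from typing import Callable, List

FUSE_ALLREDUCE_MAX_BATCH_SIZE = 2048  # cap defined in communicator.py

class ToyLayerCommunicator:
    """Stands in for sglang's layer communicator; the decision below is assumed."""

    def __init__(self, fusion_flag_enabled: bool, is_last_layer: bool):
        self.fusion_flag_enabled = fusion_flag_enabled
        self.is_last_layer = is_last_layer

    def should_fuse_mlp_allreduce_with_next_layer(self, num_tokens: int) -> bool:
        # Assumption: fuse only if the server flag is on, the batch fits under the
        # cap, and a next layer exists whose input RMSNorm can absorb the all-reduce.
        return (
            self.fusion_flag_enabled
            and num_tokens <= FUSE_ALLREDUCE_MAX_BATCH_SIZE
            and not self.is_last_layer
        )

def decoder_layer_forward(
    hidden_states: List[float],
    comm: ToyLayerCommunicator,
    mlp: Callable[[List[float], bool], List[float]],
    postprocess: Callable[[List[float]], List[float]],
) -> List[float]:
    should_allreduce_fusion = comm.should_fuse_mlp_allreduce_with_next_layer(
        len(hidden_states)
    )
    # The flag travels into the MoE MLP, which skips its own all-reduce when it is
    # set (see the qwen2_moe.py and qwen3_moe.py hunks above).
    out = mlp(hidden_states, should_allreduce_fusion)
    if should_allreduce_fusion:
        # In sglang the tensor is tagged (_sglang_needs_allreduce_fusion = True) so
        # the next layer's fused flashinfer allreduce+RMSNorm completes the reduction.
        return out
    # Otherwise the usual postprocess (all-reduce / scatter) runs right here.
    return postprocess(out)

# Tiny demo with identity callables standing in for the MLP and postprocess:
comm = ToyLayerCommunicator(fusion_flag_enabled=True, is_last_layer=False)
print(decoder_layer_forward([0.1] * 4, comm, lambda x, fuse: x, lambda x: x))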