change / sglang · Commits · Commit f96413c4 (unverified)

Refactor allreduce add rmsnorm pattern (#9278)

Authored Aug 20, 2025 by Xiaoyu Zhang; committed by GitHub, Aug 20, 2025.
Parent: 08ebdf79

Showing 3 changed files, with 52 additions and 78 deletions:

  python/sglang/srt/layers/communicator.py   +41 / -0
  python/sglang/srt/models/deepseek_v2.py    +5  / -40
  python/sglang/srt/models/gpt_oss.py        +6  / -38
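For context, the "allreduce add rmsnorm" pattern in the title is the three-step sequence that follows each tensor-parallel MLP: all-reduce the partial MLP outputs across TP ranks, add the residual, then apply the RMSNorm that feeds the next layer (the "residual_rmsnorm" named in the deleted docstrings below). The refactored path asks flashinfer to perform all three steps in one fused kernel launch. A plain-PyTorch sketch of the unfused sequence; it emulates a single rank (with TP > 1 the first step would be a dist.all_reduce), and the function names are illustrative, not sglang APIs:

    # Sketch only: single-rank emulation of the unfused sequence.
    import torch

    def rmsnorm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
        variance = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(variance + eps) * weight

    def allreduce_add_rmsnorm(partial_out, residual, norm_weight):
        full_out = partial_out        # 1) all-reduce across TP ranks (identity for one rank)
        hidden = full_out + residual  # 2) residual add
        return rmsnorm(hidden, norm_weight), hidden  # 3) RMSNorm for the next layer's input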
python/sglang/srt/layers/communicator.py

@@ -34,6 +34,7 @@ from sglang.srt.layers.dp_attention import (
     get_attention_tp_size,
     get_global_dp_buffer,
     get_local_dp_buffer,
+    is_dp_attention_enabled,
 )
 from sglang.srt.layers.moe import (
     get_moe_a2a_backend,
@@ -47,6 +48,8 @@ from sglang.srt.utils import is_cuda, is_flashinfer_available
 _is_flashinfer_available = is_flashinfer_available()
 _is_sm100_supported = is_cuda() and is_sm100_supported()
 
+FUSE_ALLREDUCE_MAX_BATCH_SIZE = 2048
+
 class ScatterMode(Enum):
     """
@@ -162,11 +165,13 @@ class LayerCommunicator:
         post_attention_layernorm: torch.nn.Module,
         # Reduce scatter requires skipping all-reduce in model code after MoE/MLP, so only enable for models which have that implemented. Remove flag once done for all models that use LayerCommunicator.
         allow_reduce_scatter: bool = False,
+        is_last_layer: bool = False,
     ):
         self.layer_scatter_modes = layer_scatter_modes
         self.input_layernorm = input_layernorm
         self.post_attention_layernorm = post_attention_layernorm
         self.allow_reduce_scatter = allow_reduce_scatter
+        self.is_last_layer = is_last_layer
 
         self._context = CommunicateContext.init_new()
         self._communicate_simple_fn = CommunicateSimpleFn.get_fn(
@@ -264,6 +269,42 @@ class LayerCommunicator:
             and forward_batch.dp_padding_mode.is_max_len()
         )
 
+    def should_fuse_mlp_allreduce_with_next_layer(
+        self, forward_batch: ForwardBatch
+    ) -> bool:
+        speculative_algo = global_server_args_dict.get("speculative_algorithm", None)
+        if (
+            is_dp_attention_enabled()
+            and speculative_algo is not None
+            and speculative_algo.is_eagle()
+        ):
+            return False
+
+        batch_size = (
+            forward_batch.input_ids.shape[0]
+            if hasattr(forward_batch, "input_ids")
+            else 0
+        )
+        if batch_size > FUSE_ALLREDUCE_MAX_BATCH_SIZE:
+            return False
+
+        static_conditions_met = (
+            (not self.is_last_layer)
+            and (self._context.tp_size > 1)
+            and global_server_args_dict.get("enable_flashinfer_allreduce_fusion", False)
+            and _is_sm100_supported
+            and _is_flashinfer_available
+        )
+        if not static_conditions_met:
+            return False
+
+        return (
+            batch_size > 0
+            and batch_size <= FUSE_ALLREDUCE_MAX_BATCH_SIZE
+            and (not self.is_last_layer)
+        )
 
 @dataclass
 class CommunicateContext:
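The new method centralizes gating logic that the DeepSeek-V2 and GPT-OSS layers previously duplicated: fusion is skipped for EAGLE speculative decoding under DP attention, for the last layer (there is no next-layer RMSNorm to fuse into), and outside the supported hardware/batch-size envelope. A minimal standalone sketch of the same decision shape, with plain parameters standing in for sglang's global_server_args_dict and CommunicateContext lookups:

    # Sketch of the fusion gate; the parameters are illustrative stand-ins.
    FUSE_ALLREDUCE_MAX_BATCH_SIZE = 2048

    def should_fuse(
        batch_size: int,
        tp_size: int,
        is_last_layer: bool,
        fusion_enabled: bool,       # enable_flashinfer_allreduce_fusion server arg
        sm100_supported: bool,      # CUDA device is SM100-class
        flashinfer_available: bool,
    ) -> bool:
        # Static conditions: the fused kernel needs TP > 1, flashinfer on
        # SM100, and a following layer to absorb the norm.
        if is_last_layer or tp_size <= 1:
            return False
        if not (fusion_enabled and sm100_supported and flashinfer_available):
            return False
        # Dynamic condition: fuse only for non-empty batches up to the cap,
        # where one fused launch beats separate allreduce/add/norm kernels.
        return 0 < batch_size <= FUSE_ALLREDUCE_MAX_BATCH_SIZE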
python/sglang/srt/models/deepseek_v2.py

@@ -1852,10 +1852,11 @@ class DeepseekV2DecoderLayer(nn.Module):
             input_layernorm=self.input_layernorm,
             post_attention_layernorm=self.post_attention_layernorm,
             allow_reduce_scatter=True,
+            is_last_layer=(
+                is_nextn or (self.layer_id == self.config.num_hidden_layers - 1)
+            ),
         )
-        self._fuse_allreduce_lookup_table = self._build_fuse_allreduce_lookup_table()
 
     def _is_layer_sparse(self, layer_id: int, is_nextn: bool) -> bool:
         return is_nextn or (
             self.config.n_routed_experts is not None
@@ -1863,20 +1864,6 @@ class DeepseekV2DecoderLayer(nn.Module):
             and layer_id % self.config.moe_layer_freq == 0
         )
 
-    def _should_fuse_mlp_allreduce_with_next_layer(self, forward_batch) -> bool:
-        """Check if MLP allreduce can be fused with next layer's residual_rmsnorm"""
-        batch_size = (
-            forward_batch.input_ids.shape[0]
-            if hasattr(forward_batch, "input_ids")
-            else 0
-        )
-        if batch_size > 128:
-            return False
-
-        return self._fuse_allreduce_lookup_table.get(batch_size, False)
-
     def forward(
         self,
         positions: torch.Tensor,
@@ -1902,11 +1889,9 @@ class DeepseekV2DecoderLayer(nn.Module):
         )
 
         should_allreduce_fusion = (
-            self._should_fuse_mlp_allreduce_with_next_layer(forward_batch)
-            and not (
-                is_dp_attention_enabled() and self.speculative_algorithm.is_eagle()
-            )
+            self.layer_communicator.should_fuse_mlp_allreduce_with_next_layer(
+                forward_batch
+            )
             and not self.is_nextn
         )
 
         # For DP with padding, reduce scatter can be used instead of all-reduce.
@@ -1997,26 +1982,6 @@ class DeepseekV2DecoderLayer(nn.Module):
         )
         return output
 
-    def _build_fuse_allreduce_lookup_table(self):
-        static_conditions_met = (
-            self.layer_id != self.config.num_hidden_layers - 1
-            and get_tensor_model_parallel_world_size() > 1
-            and global_server_args_dict.get("enable_flashinfer_allreduce_fusion", False)
-            and _is_sm100_supported
-            and _is_flashinfer_available
-        )
-        if not static_conditions_met:
-            return {}
-
-        lookup_table = {}
-        for batch_size in range(129):  # 0 to 128
-            is_last_layer = self.layer_id == self.config.num_hidden_layers - 1
-            should_fuse = batch_size > 0 and batch_size <= 128 and not is_last_layer
-            lookup_table[batch_size] = should_fuse
-
-        return lookup_table
-
 class DeepseekV2Model(nn.Module):
     fall_back_to_pt_during_load = False
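The deleted _build_fuse_allreduce_lookup_table precomputed, for every batch size in 0..128, a boolean that reduces to 0 < batch_size <= 128 and not is_last_layer, so the dict lookup carried no information beyond evaluating that predicate directly; the communicator method now does exactly that, with the cap raised to 2048. A small sketch of the equivalence (plain booleans, not the model attributes):

    # The per-batch-size table and the direct range check agree everywhere,
    # including batch sizes above 128, where .get() falls back to False.
    def via_table(bs: int, is_last_layer: bool) -> bool:
        table = {b: (0 < b <= 128 and not is_last_layer) for b in range(129)}
        return table.get(bs, False)

    def direct(bs: int, is_last_layer: bool) -> bool:
        return 0 < bs <= 128 and not is_last_layer

    assert all(
        via_table(bs, last) == direct(bs, last)
        for bs in range(256)
        for last in (False, True)
    )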
python/sglang/srt/models/gpt_oss.py

@@ -453,44 +453,11 @@ class GptOssDecoderLayer(nn.Module):
             layer_scatter_modes=self.layer_scatter_modes,
             input_layernorm=self.input_layernorm,
             post_attention_layernorm=self.post_attention_layernorm,
+            is_last_layer=(
+                self.is_nextn or (self.layer_id == self.config.num_hidden_layers - 1)
+            ),
         )
-        self._fuse_allreduce_lookup_table = self._build_fuse_allreduce_lookup_table()
-
-    def _should_fuse_mlp_allreduce_with_next_layer(self, forward_batch) -> bool:
-        """Check if MLP allreduce can be fused with next layer's residual_rmsnorm"""
-        batch_size = (
-            forward_batch.input_ids.shape[0]
-            if hasattr(forward_batch, "input_ids")
-            else 0
-        )
-        if batch_size > 128:
-            return False
-
-        return self._fuse_allreduce_lookup_table.get(batch_size, False)
-
-    def _build_fuse_allreduce_lookup_table(self):
-        static_conditions_met = (
-            self.layer_id != self.config.num_hidden_layers - 1
-            and get_tensor_model_parallel_world_size() > 1
-            and global_server_args_dict.get("enable_flashinfer_allreduce_fusion", False)
-            and _is_sm100_supported
-            and _is_flashinfer_available
-        )
-        if not static_conditions_met:
-            return {}
-
-        lookup_table = {}
-        for batch_size in range(129):  # 0 to 128
-            is_last_layer = self.layer_id == self.config.num_hidden_layers - 1
-            should_fuse = batch_size > 0 and batch_size <= 128 and not is_last_layer
-            lookup_table[batch_size] = should_fuse
-
-        return lookup_table
-
     def forward(
         self,
         positions: torch.Tensor,
@@ -514,8 +481,9 @@ class GptOssDecoderLayer(nn.Module):
         )
 
         should_allreduce_fusion = (
-            self._should_fuse_mlp_allreduce_with_next_layer(forward_batch)
-            and not self.is_nextn
+            self.layer_communicator.should_fuse_mlp_allreduce_with_next_layer(
+                forward_batch
+            )
         )
 
         hidden_states = self.mlp(hidden_states, forward_batch, should_allreduce_fusion)