change / sglang · Commits · f96413c4

Unverified commit f96413c4, authored Aug 20, 2025 by Xiaoyu Zhang; committed by GitHub on Aug 20, 2025
Refactor allreduce add rmsnorm pattern (#9278)
Parent: 08ebdf79
Showing 3 changed files with 52 additions and 78 deletions:

  python/sglang/srt/layers/communicator.py   +41  -0
  python/sglang/srt/models/deepseek_v2.py     +5  -40
  python/sglang/srt/models/gpt_oss.py         +6  -38
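For orientation (our summary of the diff below, not part of the commit message): the commit replaces the per-model `_should_fuse_mlp_allreduce_with_next_layer` / `_build_fuse_allreduce_lookup_table` helpers with a single `LayerCommunicator.should_fuse_mlp_allreduce_with_next_layer` check, and raises the fusion batch-size cap from 128 to `FUSE_ALLREDUCE_MAX_BATCH_SIZE = 2048`. The "allreduce add rmsnorm" pattern named in the title is the tensor-parallel all-reduce followed by a residual add and RMSNorm at the layer boundary; a minimal unfused sketch of those semantics (illustrative PyTorch, not sglang code):

import torch

def allreduce_add_rmsnorm(x, residual, weight, eps=1e-6):
    # Under tensor parallelism, torch.distributed.all_reduce(x) would run
    # here; the fused flashinfer path is expected to fold that collective
    # together with the residual add and RMSNorm below into one kernel.
    h = x + residual                                   # residual add
    rms = torch.rsqrt(h.pow(2).mean(-1, keepdim=True) + eps)
    return h * rms * weight, h                         # normed output, new residual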
python/sglang/srt/layers/communicator.py (view file @ f96413c4)

@@ -34,6 +34,7 @@ from sglang.srt.layers.dp_attention import (
     get_attention_tp_size,
     get_global_dp_buffer,
     get_local_dp_buffer,
+    is_dp_attention_enabled,
 )
 from sglang.srt.layers.moe import (
     get_moe_a2a_backend,
@@ -47,6 +48,8 @@ from sglang.srt.utils import is_cuda, is_flashinfer_available
 _is_flashinfer_available = is_flashinfer_available()
 _is_sm100_supported = is_cuda() and is_sm100_supported()
 
+FUSE_ALLREDUCE_MAX_BATCH_SIZE = 2048
+
 
 class ScatterMode(Enum):
     """
@@ -162,11 +165,13 @@ class LayerCommunicator:
         post_attention_layernorm: torch.nn.Module,
         # Reduce scatter requires skipping all-reduce in model code after MoE/MLP, so only enable for models which have that implemented. Remove flag once done for all models that use LayerCommunicator.
         allow_reduce_scatter: bool = False,
+        is_last_layer: bool = False,
     ):
         self.layer_scatter_modes = layer_scatter_modes
         self.input_layernorm = input_layernorm
         self.post_attention_layernorm = post_attention_layernorm
         self.allow_reduce_scatter = allow_reduce_scatter
+        self.is_last_layer = is_last_layer
 
         self._context = CommunicateContext.init_new()
         self._communicate_simple_fn = CommunicateSimpleFn.get_fn(
@@ -264,6 +269,42 @@ class LayerCommunicator:
             and forward_batch.dp_padding_mode.is_max_len()
         )
 
+    def should_fuse_mlp_allreduce_with_next_layer(
+        self, forward_batch: ForwardBatch
+    ) -> bool:
+        speculative_algo = global_server_args_dict.get("speculative_algorithm", None)
+        if (
+            is_dp_attention_enabled()
+            and speculative_algo is not None
+            and speculative_algo.is_eagle()
+        ):
+            return False
+
+        batch_size = (
+            forward_batch.input_ids.shape[0]
+            if hasattr(forward_batch, "input_ids")
+            else 0
+        )
+        if batch_size > FUSE_ALLREDUCE_MAX_BATCH_SIZE:
+            return False
+
+        static_conditions_met = (
+            (not self.is_last_layer)
+            and (self._context.tp_size > 1)
+            and global_server_args_dict.get(
+                "enable_flashinfer_allreduce_fusion", False
+            )
+            and _is_sm100_supported
+            and _is_flashinfer_available
+        )
+        if not static_conditions_met:
+            return False
+
+        return (
+            batch_size > 0
+            and batch_size <= FUSE_ALLREDUCE_MAX_BATCH_SIZE
+            and (not self.is_last_layer)
+        )
+
 
 @dataclass
 class CommunicateContext:
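Read together, the new method gates fusion on one dynamic condition (0 < batch size ≤ 2048) and several static ones (not the last layer, TP world size > 1, the `enable_flashinfer_allreduce_fusion` server arg, an SM100 GPU, and flashinfer being importable), with an early opt-out when DP attention runs under EAGLE speculative decoding. A standalone sketch of that decision with the environment probes lifted into parameters (hypothetical helper for illustration, not an sglang API):

FUSE_ALLREDUCE_MAX_BATCH_SIZE = 2048

def should_fuse(
    batch_size: int,
    is_last_layer: bool,
    tp_size: int,
    flashinfer_fusion_enabled: bool,
    sm100: bool,
    flashinfer_available: bool,
    dp_attention_with_eagle: bool,
) -> bool:
    # EAGLE speculative decoding with DP attention disables fusion outright.
    if dp_attention_with_eagle:
        return False
    # Static conditions: never fuse on the last layer, and only with TP > 1
    # on SM100 GPUs when the flashinfer fusion flag is set.
    if is_last_layer or tp_size <= 1:
        return False
    if not (flashinfer_fusion_enabled and sm100 and flashinfer_available):
        return False
    # Dynamic condition: a non-empty batch within the fusion cap.
    return 0 < batch_size <= FUSE_ALLREDUCE_MAX_BATCH_SIZE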
python/sglang/srt/models/deepseek_v2.py (view file @ f96413c4)

@@ -1852,10 +1852,11 @@ class DeepseekV2DecoderLayer(nn.Module):
             input_layernorm=self.input_layernorm,
             post_attention_layernorm=self.post_attention_layernorm,
             allow_reduce_scatter=True,
+            is_last_layer=(
+                is_nextn or (self.layer_id == self.config.num_hidden_layers - 1)
+            ),
         )
-        self._fuse_allreduce_lookup_table = self._build_fuse_allreduce_lookup_table()
 
     def _is_layer_sparse(self, layer_id: int, is_nextn: bool) -> bool:
         return is_nextn or (
             self.config.n_routed_experts is not None
@@ -1863,20 +1864,6 @@ class DeepseekV2DecoderLayer(nn.Module):
             and layer_id % self.config.moe_layer_freq == 0
         )
 
-    def _should_fuse_mlp_allreduce_with_next_layer(self, forward_batch) -> bool:
-        """Check if MLP allreduce can be fused with next layer's residual_rmsnorm"""
-        batch_size = (
-            forward_batch.input_ids.shape[0]
-            if hasattr(forward_batch, "input_ids")
-            else 0
-        )
-        if batch_size > 128:
-            return False
-
-        return self._fuse_allreduce_lookup_table.get(batch_size, False)
-
     def forward(
         self,
         positions: torch.Tensor,
@@ -1902,11 +1889,9 @@ class DeepseekV2DecoderLayer(nn.Module):
         )
 
         should_allreduce_fusion = (
-            self._should_fuse_mlp_allreduce_with_next_layer(forward_batch)
-            and not (
-                is_dp_attention_enabled() and self.speculative_algorithm.is_eagle()
-            )
+            self.layer_communicator.should_fuse_mlp_allreduce_with_next_layer(
+                forward_batch
+            )
             and not self.is_nextn
         )
 
         # For DP with padding, reduce scatter can be used instead of all-reduce.
@@ -1997,26 +1982,6 @@ class DeepseekV2DecoderLayer(nn.Module):
         )
 
         return output
 
-    def _build_fuse_allreduce_lookup_table(self):
-        static_conditions_met = (
-            self.layer_id != self.config.num_hidden_layers - 1
-            and get_tensor_model_parallel_world_size() > 1
-            and global_server_args_dict.get("enable_flashinfer_allreduce_fusion", False)
-            and _is_sm100_supported
-            and _is_flashinfer_available
-        )
-        if not static_conditions_met:
-            return {}
-
-        lookup_table = {}
-        for batch_size in range(129):  # 0 to 128
-            is_last_layer = self.layer_id == self.config.num_hidden_layers - 1
-            should_fuse = batch_size > 0 and batch_size <= 128 and not is_last_layer
-            lookup_table[batch_size] = should_fuse
-
-        return lookup_table
-
 
 class DeepseekV2Model(nn.Module):
     fall_back_to_pt_during_load = False
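The removed lookup table precomputed the same predicate for every batch size from 0 to 128; the centralized check now computes it directly and raises the cap to 2048. A quick illustrative equivalence check at the old cap (our code, not from the commit):

CAP_OLD = 128

def old_lookup(is_last_layer: bool) -> dict:
    # Mirrors the removed _build_fuse_allreduce_lookup_table body.
    return {
        bs: bs > 0 and bs <= CAP_OLD and not is_last_layer
        for bs in range(CAP_OLD + 1)
    }

def new_check(bs: int, is_last_layer: bool, cap: int) -> bool:
    # Mirrors the dynamic tail of should_fuse_mlp_allreduce_with_next_layer.
    return 0 < bs <= cap and not is_last_layer

table = old_lookup(is_last_layer=False)
assert all(table[bs] == new_check(bs, False, CAP_OLD) for bs in range(CAP_OLD + 1))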
python/sglang/srt/models/gpt_oss.py (view file @ f96413c4)

@@ -453,44 +453,11 @@ class GptOssDecoderLayer(nn.Module):
             layer_scatter_modes=self.layer_scatter_modes,
             input_layernorm=self.input_layernorm,
             post_attention_layernorm=self.post_attention_layernorm,
+            is_last_layer=(
+                self.is_nextn or (self.layer_id == self.config.num_hidden_layers - 1)
+            ),
         )
-        self._fuse_allreduce_lookup_table = self._build_fuse_allreduce_lookup_table()
-
-    def _should_fuse_mlp_allreduce_with_next_layer(self, forward_batch) -> bool:
-        """Check if MLP allreduce can be fused with next layer's residual_rmsnorm"""
-        batch_size = (
-            forward_batch.input_ids.shape[0]
-            if hasattr(forward_batch, "input_ids")
-            else 0
-        )
-        if batch_size > 128:
-            return False
-
-        return self._fuse_allreduce_lookup_table.get(batch_size, False)
-
-    def _build_fuse_allreduce_lookup_table(self):
-        static_conditions_met = (
-            self.layer_id != self.config.num_hidden_layers - 1
-            and get_tensor_model_parallel_world_size() > 1
-            and global_server_args_dict.get("enable_flashinfer_allreduce_fusion", False)
-            and _is_sm100_supported
-            and _is_flashinfer_available
-        )
-        if not static_conditions_met:
-            return {}
-
-        lookup_table = {}
-        for batch_size in range(129):  # 0 to 128
-            is_last_layer = self.layer_id == self.config.num_hidden_layers - 1
-            should_fuse = batch_size > 0 and batch_size <= 128 and not is_last_layer
-            lookup_table[batch_size] = should_fuse
-
-        return lookup_table
-
     def forward(
         self,
         positions: torch.Tensor,
@@ -514,8 +481,9 @@
         )
 
         should_allreduce_fusion = (
-            self._should_fuse_mlp_allreduce_with_next_layer(forward_batch)
-            and not self.is_nextn
+            self.layer_communicator.should_fuse_mlp_allreduce_with_next_layer(
+                forward_batch
+            )
         )
 
         hidden_states = self.mlp(hidden_states, forward_batch, should_allreduce_fusion)
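One asymmetry worth noting (our reading of the diff, not author commentary): GptOss drops the explicit `and not self.is_nextn` guard that DeepseekV2 keeps. Since both models fold `is_nextn` into the `is_last_layer` flag when constructing the LayerCommunicator, and the shared gate returns False whenever that flag is set, the guard is already implied, as this toy check illustrates:

# is_nextn is folded into is_last_layer at construction, and
# should_fuse_mlp_allreduce_with_next_layer() requires not is_last_layer.
layer_id, num_hidden_layers = 3, 32
for is_nextn in (False, True):
    is_last_layer = is_nextn or (layer_id == num_hidden_layers - 1)
    assert not (is_nextn and not is_last_layer)  # nextn implies the last-layer flag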