change / sglang · Commits · 49a5915f

Commit 49a5915f (unverified), authored Jul 11, 2025 by Xiaoyu Zhang, committed by GitHub on Jul 10, 2025

[ready b200] fuse allreduce+add_rmsnorm in prepare_attention + mlp module (#7775)
parent 766392c6

Showing 3 changed files with 85 additions and 20 deletions (+85 -20)
python/sglang/srt/layers/communicator.py    +17  -4
python/sglang/srt/layers/linear.py           +2  -2
python/sglang/srt/models/deepseek_v2.py     +66  -14
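In outline, the commit threads a new can_fuse_mlp_allreduce flag from DeepseekV2DecoderLayer.forward down through the MoE/MLP modules and RowParallelLinear, so that an eligible layer skips the tensor-parallel all-reduce after its MLP, tags the activation tensor instead, and lets the next layer's LayerCommunicator call input_layernorm.forward_with_allreduce_fusion, which combines the all-reduce with the residual add and RMSNorm. The snippet below is a minimal, unfused reference sketch of the semantics that fused path has to reproduce; weight and eps stand in for the next layer's input_layernorm parameters, and the actual fused kernel (flashinfer on SM100/B200) is not shown.

import torch
import torch.distributed as dist


def allreduce_add_rmsnorm_reference(hidden_states, residual, weight, eps=1e-6):
    # Each TP rank holds a partial sum of the MLP/MoE output; sum across ranks first.
    # (Requires an initialized process group; skipped when running single-process.)
    if dist.is_initialized() and dist.get_world_size() > 1:
        dist.all_reduce(hidden_states)
    # Residual add, then RMSNorm of the summed activation. The (normed, new_residual)
    # return contract matches input_layernorm(hidden_states, residual) in the diff.
    residual = residual + hidden_states
    variance = residual.pow(2).mean(-1, keepdim=True)
    normed = residual * torch.rsqrt(variance + eps) * weight
    return normed, residual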
python/sglang/srt/layers/communicator.py

@@ -187,11 +187,24 @@ class LayerCommunicator:
         if hidden_states.shape[0] == 0:
             residual = hidden_states
         else:
-            if residual is None:
-                residual = hidden_states
-                hidden_states = self.input_layernorm(hidden_states)
+            if (
+                residual is not None
+                and hasattr(hidden_states, "_sglang_needs_allreduce_fusion")
+                and hidden_states._sglang_needs_allreduce_fusion
+            ):
+                hidden_states, residual = (
+                    self.input_layernorm.forward_with_allreduce_fusion(
+                        hidden_states, residual
+                    )
+                )
             else:
-                hidden_states, residual = self.input_layernorm(hidden_states, residual)
+                if residual is None:
+                    residual = hidden_states
+                    hidden_states = self.input_layernorm(hidden_states)
+                else:
+                    hidden_states, residual = self.input_layernorm(
+                        hidden_states, residual
+                    )

         hidden_states = self._communicate_simple_fn(
             hidden_states=hidden_states,
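The fusion request travels between layers as a plain Python attribute set on the activation tensor, so no inter-layer call signatures need to change; the communicator simply checks for the marker before choosing the normalization path. A tiny standalone illustration of that mechanism (not sglang code):

import torch

x = torch.randn(4, 8)
x._sglang_needs_allreduce_fusion = True  # producer side: DeepseekV2DecoderLayer.forward
# consumer side: LayerCommunicator checks the marker before picking the fused path
assert getattr(x, "_sglang_needs_allreduce_fusion", False)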
python/sglang/srt/layers/linear.py

@@ -1367,7 +1367,7 @@ class RowParallelLinear(LinearBase):
         # It does not support additional parameters.
         param.load_row_parallel_weight(loaded_weight)

-    def forward(self, input_):
+    def forward(self, input_, can_fuse_mlp_allreduce=False):
         if self.input_is_parallel:
             input_parallel = input_
         else:

@@ -1382,7 +1382,7 @@ class RowParallelLinear(LinearBase):
         # bias will not get added more than once in TP>1 case)
         bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
         output_parallel = self.quant_method.apply(self, input_parallel, bias=bias_)
-        if self.reduce_results and self.tp_size > 1:
+        if self.reduce_results and self.tp_size > 1 and not can_fuse_mlp_allreduce:
             output = tensor_model_parallel_all_reduce(output_parallel)
         else:
             output = output_parallel
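Skipping the all-reduce in RowParallelLinear.forward is only safe because each rank's output is a partial sum over its slice of the contraction dimension, and the deferred reduction, now folded into the next layer's add+RMSNorm, performs the same sum later. A toy single-process check of that invariant, with made-up shapes:

import torch

torch.manual_seed(0)
tp = 2
x = torch.randn(4, 8)   # (tokens, hidden_in)
w = torch.randn(8, 6)   # full down_proj weight, row-parallel over the input dim

# Each "rank" multiplies its input slice by its weight slice -> a partial sum.
partials = [xs @ ws for xs, ws in zip(x.chunk(tp, dim=1), w.chunk(tp, dim=0))]

# The all-reduce is just the elementwise sum of those partials; deferring it is
# correct only if some later op (here the fused kernel) performs the same sum.
assert torch.allclose(sum(partials), x @ w, atol=1e-4)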
python/sglang/srt/models/deepseek_v2.py

@@ -77,6 +77,7 @@ from sglang.srt.layers.quantization.int8_utils import (
 )
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope, get_rope_wrapper
+from sglang.srt.layers.utils import is_sm100_supported
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
@@ -100,6 +101,7 @@ from sglang.srt.utils import (
     get_int_env_var,
     is_cpu,
     is_cuda,
+    is_flashinfer_available,
     is_hip,
     is_non_idle_and_non_empty,
     log_info_on_rank0,
@@ -132,6 +134,9 @@ if _is_hip:
         decode_attention_fwd_grouped_rope,
     )

+_is_flashinfer_available = is_flashinfer_available()
+_is_sm100_supported = is_cuda() and is_sm100_supported()
+
 logger = logging.getLogger(__name__)
@@ -195,13 +200,13 @@ class DeepseekV2MLP(nn.Module):
         )
         self.act_fn = SiluAndMul()

-    def forward(self, x, forward_batch=None):
+    def forward(self, x, forward_batch=None, can_fuse_mlp_allreduce=False):
         if (self.tp_size == 1) and x.shape[0] == 0:
             return x

         gate_up, _ = self.gate_up_proj(x)
         x = self.act_fn(gate_up)
-        x, _ = self.down_proj(x)
+        x, _ = self.down_proj(x, can_fuse_mlp_allreduce=can_fuse_mlp_allreduce)
         return x
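For context, this is the data path the flag rides on inside DeepseekV2MLP: gate_up_proj is column-parallel, act_fn is SiluAndMul, and down_proj is the row-parallel projection whose reduction is being deferred. A schematic single-rank version with toy weight shards (hypothetical helper, no TP runtime):

import torch
import torch.nn.functional as F


def mlp_partial_output(x, w_gate_up_shard, w_down_shard):
    # Column-parallel gate_up_proj: this rank's slice of [gate; up].
    gate, up = (x @ w_gate_up_shard).chunk(2, dim=-1)
    # SiluAndMul activation.
    act = F.silu(gate) * up
    # Row-parallel down_proj: a partial sum over this rank's intermediate slice.
    # With can_fuse_mlp_allreduce=True this partial result is returned unreduced.
    return act @ w_down_shard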
@@ -409,7 +414,10 @@ class DeepseekV2MoE(nn.Module):
         ]

     def forward(
-        self, hidden_states: torch.Tensor, forward_batch: Optional[ForwardBatch] = None
+        self,
+        hidden_states: torch.Tensor,
+        forward_batch: Optional[ForwardBatch] = None,
+        can_fuse_mlp_allreduce: bool = False,
     ) -> torch.Tensor:
         if not self._enable_deepep_moe:
             DUAL_STREAM_TOKEN_THRESHOLD = 1024
@@ -418,13 +426,17 @@ class DeepseekV2MoE(nn.Module):
                 and self.num_fused_shared_experts == 0
                 and hidden_states.shape[0] <= DUAL_STREAM_TOKEN_THRESHOLD
             ):
-                return self.forward_normal_dual_stream(hidden_states)
+                return self.forward_normal_dual_stream(
+                    hidden_states, can_fuse_mlp_allreduce
+                )
             else:
-                return self.forward_normal(hidden_states)
+                return self.forward_normal(hidden_states, can_fuse_mlp_allreduce)
         else:
             return self.forward_deepep(hidden_states, forward_batch)

-    def forward_normal_dual_stream(self, hidden_states: torch.Tensor) -> torch.Tensor:
+    def forward_normal_dual_stream(
+        self, hidden_states: torch.Tensor, can_fuse_mlp_allreduce: bool = False
+    ) -> torch.Tensor:
         # router_logits: (num_tokens, n_experts)
         router_logits = self.gate(hidden_states)
@@ -440,11 +452,13 @@ class DeepseekV2MoE(nn.Module):
             final_hidden_states *= self.routed_scaling_factor
         current_stream.wait_stream(self.alt_stream)
         final_hidden_states = final_hidden_states + shared_output
-        if self.tp_size > 1:
+        if self.tp_size > 1 and not can_fuse_mlp_allreduce:
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
         return final_hidden_states

-    def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
+    def forward_normal(
+        self, hidden_states: torch.Tensor, can_fuse_mlp_allreduce: bool = False
+    ) -> torch.Tensor:
         if hasattr(self, "shared_experts") and use_intel_amx_backend(
             self.shared_experts.gate_up_proj
         ):
@@ -461,7 +475,7 @@ class DeepseekV2MoE(nn.Module):
             final_hidden_states *= self.routed_scaling_factor
         if shared_output is not None:
             final_hidden_states = final_hidden_states + shared_output
-        if self.tp_size > 1:
+        if self.tp_size > 1 and not can_fuse_mlp_allreduce:
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
         return final_hidden_states
@@ -514,7 +528,7 @@ class DeepseekV2MoE(nn.Module):
             None,  # a2_scale
             True,  # is_vnni
         )
-        if self.tp_size > 1:
+        if self.tp_size > 1 and not self.can_fuse_mlp_allreduce:
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
         return final_hidden_states
@@ -1818,6 +1832,29 @@ class DeepseekV2DecoderLayer(nn.Module):
             and layer_id % self.config.moe_layer_freq == 0
         )

+    def _should_fuse_mlp_allreduce_with_next_layer(self, forward_batch) -> bool:
+        """Check if MLP allreduce can be fused with next layer's add_rmsnorm"""
+        if (
+            self.layer_id == self.config.num_hidden_layers - 1
+            or get_tensor_model_parallel_world_size() <= 1
+        ):
+            return False
+
+        if not global_server_args_dict.get(
+            "enable_flashinfer_allreduce_fusion", False
+        ):
+            return False
+
+        if not _is_sm100_supported or not _is_flashinfer_available:
+            return False
+
+        if hasattr(forward_batch, "input_ids") and (
+            forward_batch.input_ids.shape[0] == 0
+            or forward_batch.input_ids.shape[0] > 128
+        ):
+            return False
+
+        return True
+
     def forward(
         self,
         positions: torch.Tensor,
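Condensed, the gate above only allows fusion for non-final layers running with tensor parallelism, with the enable_flashinfer_allreduce_fusion option set, on SM100-class (B200) hardware with flashinfer available, and for batches of 1 to 128 tokens. A pure-function restatement for illustration (parameter names are hypothetical, not sglang APIs):

def should_fuse_mlp_allreduce(
    layer_id: int,
    num_hidden_layers: int,
    tp_world_size: int,
    fusion_flag_enabled: bool,
    is_sm100: bool,
    has_flashinfer: bool,
    num_tokens: int,
) -> bool:
    return (
        layer_id != num_hidden_layers - 1   # next layer must exist to consume the fusion
        and tp_world_size > 1               # otherwise there is no all-reduce to skip
        and fusion_flag_enabled             # enable_flashinfer_allreduce_fusion
        and is_sm100
        and has_flashinfer
        and 0 < num_tokens <= 128           # small-batch regime targeted by this commit
    )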
@@ -1842,12 +1879,27 @@ class DeepseekV2DecoderLayer(nn.Module):
             hidden_states, residual, forward_batch
         )

-        hidden_states = self.mlp(hidden_states, forward_batch)
-
-        hidden_states, residual = self.layer_communicator.postprocess_layer(
-            hidden_states, residual, forward_batch
-        )
+        can_fuse_mlp_allreduce = (
+            self._should_fuse_mlp_allreduce_with_next_layer(forward_batch)
+            and not (self.enable_dp_attention and self.speculative_algorithm.is_eagle())
+            and not self.is_nextn
+        )
+
+        hidden_states = self.mlp(hidden_states, forward_batch, can_fuse_mlp_allreduce)
+
+        if can_fuse_mlp_allreduce:
+            hidden_states._sglang_needs_allreduce_fusion = True
+
+        if not can_fuse_mlp_allreduce:
+            hidden_states, residual = self.layer_communicator.postprocess_layer(
+                hidden_states, residual, forward_batch
+            )
+
+        if self.enable_dp_attention and self.speculative_algorithm.is_eagle():
+            # NOTE: this line resolves the degradation of MTP reception rate for non-zero DP ranks.
+            # See discussion here (https://github.com/sgl-project/sglang/pull/6081#discussion_r2147452251).
+            hidden_states = hidden_states.clone()

         return hidden_states, residual

     def op_comm_prepare_attn(
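Taken together: a decoder layer that passes this gate returns unreduced partial sums from its MLP, skips postprocess_layer, and tags the tensor; the next layer's LayerCommunicator detects the tag during attention preparation and performs the deferred all-reduce inside forward_with_allreduce_fusion together with the residual add and RMSNorm, which is the allreduce+add_rmsnorm fusion named in the commit title.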