change / sglang · Commit 4c22897a

Feature: support qwen and llama4 reducescatter for dp attention padding (#9101)

Unverified commit 4c22897a, authored Aug 14, 2025 by wxzhoucs and committed via GitHub on Aug 13, 2025. Parent: 1bc183c6
Showing 5 changed files, with 68 additions and 16 deletions (+68 / -16):
python/sglang/srt/lora/layers.py (+6, -2)
python/sglang/srt/models/llama.py (+10, -2)
python/sglang/srt/models/llama4.py (+16, -3)
python/sglang/srt/models/qwen2_moe.py (+18, -4)
python/sglang/srt/models/qwen3_moe.py (+18, -5)
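
Why reduce-scatter can stand in for all-reduce here: with DP attention padding, every rank in the group holds a partial sum over the same padded token count, and after the collective each rank only needs its own contiguous slice of the summed result. A reduce-scatter therefore yields what an all-reduce followed by a scatter would, at roughly half the communication volume of a ring all-reduce. A minimal sketch of that equivalence in plain torch.distributed (illustrative function and variable names, not sglang code):

import torch
import torch.distributed as dist


def combine_partial_sums_with_reduce_scatter(partial_out: torch.Tensor, group=None) -> torch.Tensor:
    """Sketch: each rank holds a partial sum over the full padded batch.

    An all-reduce would give every rank the whole summed batch, which is then
    scattered so each rank keeps only its DP slice. When padding makes the
    token count identical and evenly divisible across ranks, a single
    reduce-scatter produces that slice directly.
    """
    world_size = dist.get_world_size(group)
    tokens_per_rank = partial_out.shape[0] // world_size  # assumes exact divisibility
    shard = torch.empty(
        (tokens_per_rank, *partial_out.shape[1:]),
        dtype=partial_out.dtype,
        device=partial_out.device,
    )
    # Sum across ranks and keep only this rank's chunk of the result.
    dist.reduce_scatter_tensor(shard, partial_out, op=dist.ReduceOp.SUM, group=group)
    return shard

In the actual change, the MLP/MoE layers skip their internal all-reduce (skip_all_reduce / use_reduce_scatter) so that the layer communicator can perform this cheaper collective instead.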
python/sglang/srt/lora/layers.py

@@ -253,7 +253,7 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
         )
         return lora_output
 
-    def forward(self, input_: torch.Tensor):
+    def forward(self, input_: torch.Tensor, skip_all_reduce=False):
         # duplicate the logic in RowParallelLinear
         if self.base_layer.input_is_parallel:
             input_parallel = input_
@@ -270,7 +270,11 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
         if self.set_lora:
             output_parallel = self.apply_lora(output_parallel, input_parallel)
 
-        if self.base_layer.reduce_results and self.base_layer.tp_size > 1:
+        if (
+            self.base_layer.reduce_results
+            and self.base_layer.tp_size > 1
+            and not skip_all_reduce
+        ):
             output_ = tensor_model_parallel_all_reduce(output_parallel)
         else:
             output_ = output_parallel
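
The LoRA wrapper now mirrors the contract of the underlying row-parallel linear: when skip_all_reduce is set, the layer returns its un-reduced partial output and the caller decides which collective to run. A hedged toy version of that contract (ToyRowParallelLinear and its fields are illustrative stand-ins, not the sglang classes):

import torch
import torch.distributed as dist
from torch import nn


class ToyRowParallelLinear(nn.Module):
    """Illustrative stand-in for a row-parallel linear layer."""

    def __init__(self, in_features_per_rank: int, out_features: int, tp_group=None):
        super().__init__()
        self.linear = nn.Linear(in_features_per_rank, out_features, bias=False)
        self.tp_group = tp_group

    def forward(self, x: torch.Tensor, skip_all_reduce: bool = False) -> torch.Tensor:
        # Each rank multiplies against its own shard of the weight, so the
        # result is only a partial sum of the full matmul.
        partial = self.linear(x)
        if (
            not skip_all_reduce
            and dist.is_initialized()
            and dist.get_world_size(self.tp_group) > 1
        ):
            # Default path: every rank ends up with the fully summed output.
            dist.all_reduce(partial, group=self.tp_group)
        # With skip_all_reduce=True the caller owns the reduction, e.g. the
        # layer communicator's later reduce-scatter.
        return partial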
python/sglang/srt/models/llama.py

@@ -91,10 +91,18 @@ class LlamaMLP(nn.Module):
         )
         self.act_fn = SiluAndMul()
 
-    def forward(self, x, forward_batch=None):
+    def forward(
+        self,
+        x,
+        forward_batch=None,
+        use_reduce_scatter: bool = False,
+    ):
         gate_up, _ = self.gate_up_proj(x)
         x = self.act_fn(gate_up)
-        x, _ = self.down_proj(x)
+        x, _ = self.down_proj(
+            x,
+            skip_all_reduce=use_reduce_scatter,
+        )
         return x
python/sglang/srt/models/llama4.py

@@ -131,14 +131,19 @@ class Llama4MoE(nn.Module):
             reduce_results=False,  # We need to do scatter before reduce
         )
 
-    def forward(self, hidden_states, forward_batch: ForwardBatch):
+    def forward(
+        self,
+        hidden_states,
+        forward_batch: ForwardBatch,
+        use_reduce_scatter: bool = False,
+    ):
         shared_out, routed_out = self._forward_core(
             hidden_states, forward_batch.forward_mode
         )
         out_aD = routed_out + shared_out
 
-        if self.tp_size > 1:
+        if self.tp_size > 1 and not use_reduce_scatter:
             out_aD = tensor_model_parallel_all_reduce(out_aD)
 
         return out_aD
@@ -412,6 +417,7 @@ class Llama4DecoderLayer(nn.Module):
             layer_scatter_modes=self.layer_scatter_modes,
             input_layernorm=self.input_layernorm,
             post_attention_layernorm=self.post_attention_layernorm,
+            allow_reduce_scatter=True,
         )
 
     def _is_moe_layer(self, layer_id: int) -> bool:
@@ -441,8 +447,15 @@ class Llama4DecoderLayer(nn.Module):
             hidden_states, residual, forward_batch
         )
 
+        # For DP with padding, reduce scatter can be used instead of all-reduce.
+        use_reduce_scatter = self.layer_communicator.should_use_reduce_scatter(
+            forward_batch
+        )
+
         # Fully Connected
-        hidden_states = self.feed_forward(hidden_states, forward_batch)
+        hidden_states = self.feed_forward(
+            hidden_states, forward_batch, use_reduce_scatter
+        )
         hidden_states, residual = self.layer_communicator.postprocess_layer(
             hidden_states, residual, forward_batch
         )
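
A rough sense of what the optimization saves per MLP/MoE output, assuming ring-based collectives where an all-reduce is internally a reduce-scatter plus an all-gather: skipping the all-gather (because each rank only keeps its own DP slice anyway) roughly halves the bytes moved. A small illustrative cost model with made-up example numbers:

# Back-of-envelope communication volume for a ring implementation.
# Ring all-reduce = reduce-scatter + all-gather, so dropping the all-gather
# roughly halves the per-rank traffic for this output tensor.
def ring_bytes(payload_bytes: int, world_size: int) -> dict:
    per_rank_reduce_scatter = (world_size - 1) / world_size * payload_bytes
    per_rank_all_reduce = 2 * (world_size - 1) / world_size * payload_bytes
    return {
        "reduce_scatter": per_rank_reduce_scatter,
        "all_reduce": per_rank_all_reduce,
    }


# Example: 8-way TP, 4096 padded tokens, hidden size 8192, bf16 activations.
payload = 4096 * 8192 * 2  # bytes
print(ring_bytes(payload, world_size=8))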
python/sglang/srt/models/qwen2_moe.py

@@ -108,10 +108,14 @@ class Qwen2MoeMLP(nn.Module):
         )
         self.act_fn = SiluAndMul()
 
-    def forward(self, x):
+    def forward(
+        self,
+        x,
+        use_reduce_scatter: bool = False,
+    ):
         gate_up, _ = self.gate_up_proj(x)
         x = self.act_fn(gate_up)
-        x, _ = self.down_proj(x)
+        x, _ = self.down_proj(x, skip_all_reduce=use_reduce_scatter)
         return x
@@ -176,7 +180,10 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
         self.shared_expert_gate = torch.nn.Linear(config.hidden_size, 1, bias=False)
 
     def forward(
-        self, hidden_states: torch.Tensor, forward_batch: Optional[ForwardBatch] = None
+        self,
+        hidden_states: torch.Tensor,
+        forward_batch: Optional[ForwardBatch] = None,
+        use_reduce_scatter: bool = False,
     ) -> torch.Tensor:
         num_tokens, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)
@@ -194,6 +201,7 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
         final_hidden_states = self.experts(hidden_states, topk_output)
         if shared_output is not None:
             final_hidden_states = final_hidden_states + shared_output
-        if self.tp_size > 1:
+        if self.tp_size > 1 and not use_reduce_scatter:
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
 
         return final_hidden_states.view(num_tokens, hidden_dim)
@@ -368,6 +376,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
             layer_scatter_modes=self.layer_scatter_modes,
             input_layernorm=self.input_layernorm,
             post_attention_layernorm=self.post_attention_layernorm,
+            allow_reduce_scatter=True,
         )
 
     def forward(
@@ -393,7 +402,12 @@ class Qwen2MoeDecoderLayer(nn.Module):
             hidden_states, residual, forward_batch
         )
 
-        hidden_states = self.mlp(hidden_states, forward_batch)
+        # For DP with padding, reduce scatter can be used instead of all-reduce.
+        use_reduce_scatter = self.layer_communicator.should_use_reduce_scatter(
+            forward_batch
+        )
+        hidden_states = self.mlp(hidden_states, forward_batch, use_reduce_scatter)
         hidden_states, residual = self.layer_communicator.postprocess_layer(
             hidden_states, residual, forward_batch
python/sglang/srt/models/qwen3_moe.py

@@ -144,11 +144,14 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
         self.top_k = config.num_experts_per_tok
 
     def forward(
-        self, hidden_states: torch.Tensor, forward_batch: Optional[ForwardBatch] = None
+        self,
+        hidden_states: torch.Tensor,
+        forward_batch: Optional[ForwardBatch] = None,
+        use_reduce_scatter: bool = False,
     ) -> torch.Tensor:
         if not global_server_args_dict["moe_a2a_backend"].is_deepep():
-            return self.forward_normal(hidden_states)
+            return self.forward_normal(hidden_states, use_reduce_scatter)
         else:
             return self.forward_deepep(hidden_states, forward_batch)
@@ -159,7 +162,11 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
             if name not in ["correction_bias"]
         ]
 
-    def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
+    def forward_normal(
+        self,
+        hidden_states: torch.Tensor,
+        use_reduce_scatter: bool = False,
+    ) -> torch.Tensor:
         num_tokens, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)
@@ -167,7 +174,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
         router_logits, _ = self.gate(hidden_states)
         topk_output = self.topk(hidden_states, router_logits)
         final_hidden_states = self.experts(hidden_states, topk_output)
-        if self.tp_size > 1:
+        if self.tp_size > 1 and not use_reduce_scatter:
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
         return final_hidden_states.view(num_tokens, hidden_dim)
@@ -521,6 +528,7 @@ class Qwen3MoeDecoderLayer(nn.Module):
             layer_scatter_modes=self.layer_scatter_modes,
             input_layernorm=self.input_layernorm,
             post_attention_layernorm=self.post_attention_layernorm,
+            allow_reduce_scatter=True,
         )
 
     def forward(
@@ -546,7 +554,12 @@ class Qwen3MoeDecoderLayer(nn.Module):
             hidden_states, residual, forward_batch
         )
 
-        hidden_states = self.mlp(hidden_states, forward_batch)
+        # For DP with padding, reduce scatter can be used instead of all-reduce.
+        use_reduce_scatter = self.layer_communicator.should_use_reduce_scatter(
+            forward_batch
+        )
+        hidden_states = self.mlp(hidden_states, forward_batch, use_reduce_scatter)
         hidden_states, residual = self.layer_communicator.postprocess_layer(
             hidden_states, residual, forward_batch
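
A single-process sanity check of the algebra the Qwen and Llama4 changes rely on, using plain torch and no process group: reduce-scattering the per-rank partial sums gives each rank the same tensor as all-reducing and then taking that rank's slice. The simulation below is illustrative only:

import torch

# Simulate tp_size ranks, each holding a partial MLP/MoE output over the
# same padded token count.
tp_size, padded_tokens, hidden = 4, 8, 16
partials = [torch.randn(padded_tokens, hidden) for _ in range(tp_size)]

full_sum = torch.stack(partials).sum(dim=0)      # what all-reduce would give every rank
slices = full_sum.chunk(tp_size, dim=0)          # what the scatter step hands each rank

# Reduce-scatter semantics: rank r receives the sum over ranks of chunk r of each input.
for rank in range(tp_size):
    rs_result = torch.stack([p.chunk(tp_size, dim=0)[rank] for p in partials]).sum(dim=0)
    assert torch.allclose(rs_result, slices[rank], atol=1e-6)
print("reduce-scatter of partial sums matches all-reduce followed by scatter")

This is why the guard becomes `if self.tp_size > 1 and not use_reduce_scatter`: when the flag is set, the partial sums are left for the later reduce-scatter rather than being fully all-reduced here.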