change / sglang / Commits / 52fcbbb8

Commit 52fcbbb8 (Unverified)
Authored Oct 10, 2025 by Cheng Wan; committed by GitHub on Oct 10, 2025
Revert "perf: optimize qwen-vl with symm mem allreduce" (#11436)
Parent: af96ca11
Showing 5 changed files with 17 additions and 82 deletions.
python/sglang/srt/distributed/device_communicators/all_reduce_utils.py   +4 -4
python/sglang/srt/distributed/parallel_state.py                          +0 -3
python/sglang/srt/layers/rotary_embedding.py                             +9 -26
python/sglang/srt/managers/schedule_batch.py                             +1 -5
python/sglang/srt/models/qwen2.py                                        +3 -44
python/sglang/srt/distributed/device_communicators/all_reduce_utils.py
@@ -3,13 +3,13 @@ MiB = 1024 * 1024
 SYMM_MEM_ALL_REDUCE_MAX_SIZES = {
     9: {
         2: 64 * MiB,  # 64 MB
-        4: 64 * MiB,  # 64 MB
-        6: 128 * MiB,  # 128 MB
-        8: 128 * MiB,  # 128 MB
+        4: 32 * MiB,  # 32 MB
+        6: 64 * MiB,  # 64 MB
+        8: 64 * MiB,  # 64 MB
     },
     10: {
         2: 64 * MiB,  # 64 MB
-        4: 64 * MiB,  # 64 MB
+        4: 32 * MiB,  # 32 MB
         6: 128 * MiB,  # 128 MB
         8: 128 * MiB,  # 128 MB
     },
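For context: this table caps the message size for which the symmetric-memory all-reduce path is used. Below is a minimal, hypothetical sketch of how such a table might be consulted, reusing entries shown above; it assumes the outer key is the CUDA compute-capability major version and the inner key the tensor-parallel world size, and can_use_symm_mem is an illustrative helper, not sglang's real API.

MiB = 1024 * 1024

SYMM_MEM_ALL_REDUCE_MAX_SIZES = {
    9: {2: 64 * MiB, 4: 32 * MiB, 6: 64 * MiB, 8: 64 * MiB},
    10: {2: 64 * MiB, 4: 32 * MiB, 6: 128 * MiB, 8: 128 * MiB},
}


def can_use_symm_mem(nbytes: int, cc_major: int, world_size: int) -> bool:
    # Hypothetical helper (not sglang's API): is the payload under the configured cap?
    max_size = SYMM_MEM_ALL_REDUCE_MAX_SIZES.get(cc_major, {}).get(world_size)
    return max_size is not None and nbytes <= max_size


if __name__ == "__main__":
    # A 48 MiB all-reduce on an 8-way group with compute capability 9.x fits under 64 MiB.
    print(can_use_symm_mem(48 * MiB, cc_major=9, world_size=8))  # True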
python/sglang/srt/distributed/parallel_state.py
@@ -603,11 +603,8 @@ class GroupCoordinator:
     def _all_reduce_in_place(self, input_: torch.Tensor) -> None:
         pynccl_comm = self.pynccl_comm
-        symm_mem_comm = self.symm_mem_comm
         if pynccl_comm is not None and not pynccl_comm.disabled:
             pynccl_comm.all_reduce(input_)
-        elif symm_mem_comm is not None and not symm_mem_comm.disabled:
-            symm_mem_comm.all_reduce(input_)
         else:
             torch.distributed.all_reduce(input_, group=self.device_group)
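For context, a minimal sketch of the dispatch pattern kept here (illustrative names, not sglang's real classes): prefer a specialised communicator when it is present and enabled, otherwise fall back to the generic torch.distributed all-reduce.

from typing import Optional, Protocol

import torch
import torch.distributed as dist


class _Comm(Protocol):
    # Illustrative communicator interface: a `disabled` flag plus an in-place all_reduce.
    disabled: bool

    def all_reduce(self, t: torch.Tensor) -> None: ...


def all_reduce_in_place(
    t: torch.Tensor,
    fast_comm: Optional[_Comm],
    group: Optional[dist.ProcessGroup] = None,
) -> None:
    # Use the specialised communicator when available and enabled.
    if fast_comm is not None and not fast_comm.disabled:
        fast_comm.all_reduce(t)
    else:
        # Generic fallback through torch.distributed (requires an initialised group).
        dist.all_reduce(t, group=group)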
python/sglang/srt/layers/rotary_embedding.py
@@ -1008,17 +1008,6 @@ class DynamicNTKAlphaRotaryEmbedding(RotaryEmbedding):
         return cache
 
 
-def apply_interleaved_rope(x: torch.Tensor, mrope_section: list[int]) -> torch.Tensor:
-    """Apply interleaved MRoPE to 3D rotary embeddings.
-    Reorganizes frequency layout from chunked [TTT...HHH...WWW] to
-    interleaved [THTHWHTHW...TT], preserving frequency continuity.
-    """
-    x_t = x[0].clone()
-    x_t[..., 1 : mrope_section[1] * 3 : 3] = x[1, ..., 1 : mrope_section[1] * 3 : 3]
-    x_t[..., 2 : mrope_section[2] * 3 : 3] = x[2, ..., 2 : mrope_section[2] * 3 : 3]
-    return x_t
-
-
 class MRotaryEmbedding(RotaryEmbedding):
     """Rotary Embedding with Multimodal Sections."""
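For context: the removed helper overlays the temporal (T), height (H) and width (W) rotary tables with a stride-3 pattern instead of keeping them in contiguous chunks. A small self-contained illustration of that slicing (hypothetical names, same indexing as the removed code):

import torch


def interleave_thw(x: torch.Tensor, mrope_section: list[int]) -> torch.Tensor:
    # Start from the temporal table, then fill offset-1 positions from H and
    # offset-2 positions from W, with stride 3, up to 3 * mrope_section[i].
    out = x[0].clone()
    out[..., 1 : mrope_section[1] * 3 : 3] = x[1, ..., 1 : mrope_section[1] * 3 : 3]
    out[..., 2 : mrope_section[2] * 3 : 3] = x[2, ..., 2 : mrope_section[2] * 3 : 3]
    return out


if __name__ == "__main__":
    dim = 12
    # Channel 0 = T, 1 = H, 2 = W, broadcast over a dummy token axis.
    x = torch.stack([torch.full((1, dim), float(i)) for i in range(3)])
    print(interleave_thw(x, mrope_section=[4, 4, 4])[0].tolist())
    # -> [0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2]  (T/H/W interleaved per triple)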
@@ -1031,14 +1020,12 @@ class MRotaryEmbedding(RotaryEmbedding):
         is_neox_style: bool,
         dtype: torch.dtype,
         mrope_section: Optional[List[int]] = None,
-        mrope_interleaved: bool = False,
     ) -> None:
         super().__init__(
             head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
         )
         self.mrope_section = mrope_section
-        self.mrope_interleaved = mrope_interleaved
         if self.mrope_section:
             expected_sum = rotary_dim // 2
             actual_sum = sum(self.mrope_section)
@@ -1099,18 +1086,15 @@ class MRotaryEmbedding(RotaryEmbedding):
         cos, sin = cos_sin.chunk(2, dim=-1)
         if positions.ndim == 2:
             assert self.mrope_section
-            if self.mrope_interleaved:
-                cos = apply_interleaved_rope(cos, self.mrope_section)
-                sin = apply_interleaved_rope(sin, self.mrope_section)
-            else:
-                cos = torch.cat(
-                    [m[i] for i, m in enumerate(cos.split(self.mrope_section, dim=-1))],
-                    dim=-1,
-                )
-                sin = torch.cat(
-                    [m[i] for i, m in enumerate(sin.split(self.mrope_section, dim=-1))],
-                    dim=-1,
-                )
+            cos = torch.cat(
+                [m[i] for i, m in enumerate(cos.split(self.mrope_section, dim=-1))],
+                dim=-1,
+            )
+            sin = torch.cat(
+                [m[i] for i, m in enumerate(sin.split(self.mrope_section, dim=-1))],
+                dim=-1,
+            )
 
         query_shape = query.shape
         query = query.view(num_tokens, -1, self.head_size)
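For context, the retained (non-interleaved) path splits the last dimension by mrope_section and takes each chunk from its matching T/H/W position channel. A minimal standalone sketch with illustrative shapes:

import torch


def select_by_section(cos: torch.Tensor, mrope_section: list[int]) -> torch.Tensor:
    # cos has a leading axis of size 3 (T/H/W channels); each section of the last
    # dimension is taken from the channel with the same index.
    return torch.cat(
        [m[i] for i, m in enumerate(cos.split(mrope_section, dim=-1))],
        dim=-1,
    )


if __name__ == "__main__":
    num_tokens, rot_half = 2, 8
    mrope_section = [3, 3, 2]  # must sum to rot_half
    cos = torch.stack(
        [torch.full((num_tokens, rot_half), float(i)) for i in range(3)]
    )  # shape (3, num_tokens, 8): channel 0 = T, 1 = H, 2 = W
    print(select_by_section(cos, mrope_section)[0].tolist())
    # -> [0, 0, 0, 1, 1, 1, 2, 2]  (first 3 dims from T, next 3 from H, last 2 from W)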
@@ -1789,7 +1773,6 @@ def get_rope(
             is_neox_style,
             dtype,
             mrope_section=rope_scaling["mrope_section"],
-            mrope_interleaved=rope_scaling.get("mrope_interleaved", False),
         )
     else:
         rotary_emb = RotaryEmbedding(
python/sglang/srt/managers/schedule_batch.py
@@ -1766,11 +1766,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         self.seq_lens_cpu = self.seq_lens_cpu[keep_indices]
         self.orig_seq_lens = self.orig_seq_lens[keep_indices_device]
         self.out_cache_loc = None
-        if isinstance(self.seq_lens_cpu, torch.Tensor):
-            # CPU tensor
-            self.seq_lens_sum = int(self.seq_lens_cpu.sum().item())
-        else:
-            self.seq_lens_sum = int(np.asarray(self.seq_lens_cpu).sum())
+        self.seq_lens_sum = self.seq_lens.sum().item()
         self.output_ids = self.output_ids[keep_indices_device]
         self.return_logprob = any(req.return_logprob for req in self.reqs)
         if self.return_logprob:
python/sglang/srt/models/qwen2.py
@@ -27,7 +27,6 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_world_size,
 )
 from sglang.srt.layers.activation import SiluAndMul
-from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
 from sglang.srt.layers.dp_attention import is_dp_attention_enabled
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
@@ -89,17 +88,10 @@ class Qwen2MLP(nn.Module):
         )
         self.act_fn = SiluAndMul()
 
-    def forward(
-        self,
-        x,
-        should_allreduce_fusion: bool = False,
-    ):
+    def forward(self, x):
         gate_up, _ = self.gate_up_proj(x)
         x = self.act_fn(gate_up)
-        x, _ = self.down_proj(
-            x,
-            skip_all_reduce=should_allreduce_fusion,
-        )
+        x, _ = self.down_proj(x)
         return x
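For context, the forward pass restored above is a standard gated (SwiGLU-style) MLP: a packed gate/up projection, SiLU-and-mul, then the down projection. A minimal sketch using plain torch modules in place of sglang's tensor-parallel linear layers (illustrative only):

import torch
import torch.nn as nn
import torch.nn.functional as F


class GatedMLP(nn.Module):
    def __init__(self, hidden_size: int, intermediate_size: int) -> None:
        super().__init__()
        # gate_up_proj packs the gate and up projections into a single matmul.
        self.gate_up_proj = nn.Linear(hidden_size, 2 * intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate, up = self.gate_up_proj(x).chunk(2, dim=-1)
        return self.down_proj(F.silu(gate) * up)  # SiLU-and-mul, then down projection


if __name__ == "__main__":
    mlp = GatedMLP(hidden_size=16, intermediate_size=32)
    print(mlp(torch.randn(2, 16)).shape)  # torch.Size([2, 16])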
@@ -117,11 +109,9 @@ class Qwen2Attention(nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
         dual_chunk_attention_config: Optional[dict[str, Any]] = None,
         prefix: str = "",
         alt_stream: Optional[torch.cuda.Stream] = None,
     ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
-        tp_rank = get_tensor_model_parallel_rank()
         tp_size = get_tensor_model_parallel_world_size()
         self.total_num_heads = num_heads
         assert self.total_num_heads % tp_size == 0
@@ -153,8 +143,6 @@ class Qwen2Attention(nn.Module):
             self.total_num_kv_heads,
             bias=True,
             quant_config=quant_config,
-            tp_rank=tp_rank,
-            tp_size=tp_size,
             prefix=add_prefix("qkv_proj", prefix),
         )
         self.o_proj = RowParallelLinear(
@@ -162,8 +150,6 @@ class Qwen2Attention(nn.Module):
             hidden_size,
             bias=False,
             quant_config=quant_config,
-            tp_rank=tp_rank,
-            tp_size=tp_size,
             prefix=add_prefix("o_proj", prefix),
         )
@@ -209,7 +195,6 @@ class Qwen2DecoderLayer(nn.Module):
         alt_stream: Optional[torch.cuda.Stream] = None,
     ) -> None:
         super().__init__()
         self.config = config
         self.hidden_size = config.hidden_size
         rope_theta = getattr(config, "rope_theta", 1000000)
         rope_scaling = getattr(config, "rope_scaling", None)
@@ -231,18 +216,6 @@ class Qwen2DecoderLayer(nn.Module):
             dual_chunk_attention_config=dual_chunk_attention_config,
             prefix=add_prefix("self_attn", prefix),
         )
-        self.layer_id = layer_id
-
-        self.is_layer_sparse = False
-        is_previous_layer_sparse = False
-
-        self.layer_scatter_modes = LayerScatterModes.init_new(
-            layer_id=layer_id,
-            num_layers=config.num_hidden_layers,
-            is_layer_sparse=self.is_layer_sparse,
-            is_previous_layer_sparse=is_previous_layer_sparse,
-        )
         self.mlp = Qwen2MLP(
             hidden_size=self.hidden_size,
             intermediate_size=config.intermediate_size,
@@ -255,14 +228,6 @@ class Qwen2DecoderLayer(nn.Module):
             config.hidden_size, eps=config.rms_norm_eps
         )
-        self.layer_communicator = LayerCommunicator(
-            layer_scatter_modes=self.layer_scatter_modes,
-            input_layernorm=self.input_layernorm,
-            post_attention_layernorm=self.post_attention_layernorm,
-            allow_reduce_scatter=True,
-            is_last_layer=(self.layer_id == self.config.num_hidden_layers - 1),
-        )
-
 
     def forward(
         self,
         positions: torch.Tensor,
@@ -284,13 +249,7 @@ class Qwen2DecoderLayer(nn.Module):
         # Fully Connected
         hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
-        should_allreduce_fusion = (
-            self.layer_communicator.should_fuse_mlp_allreduce_with_next_layer(
-                forward_batch
-            )
-        )
-
-        hidden_states = self.mlp(hidden_states, should_allreduce_fusion)
+        hidden_states = self.mlp(hidden_states)
         return hidden_states, residual