sglang · commit 0b6f535f (unverified)

[Reland] perf: optimize qwen-vl with symm mem allreduce (#11457)

Authored Oct 13, 2025 by Yuan Luo; committed via GitHub on Oct 13, 2025
Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com>
Parent: c5fe3c0b
Showing 3 changed files with 33 additions and 13 deletions (+33, −13):

python/sglang/srt/distributed/device_communicators/all_reduce_utils.py   +4 −4
python/sglang/srt/distributed/parallel_state.py                          +3 −0
python/sglang/srt/layers/rotary_embedding.py                             +26 −9
python/sglang/srt/distributed/device_communicators/all_reduce_utils.py

```diff
@@ -3,13 +3,13 @@ MiB = 1024 * 1024
 SYMM_MEM_ALL_REDUCE_MAX_SIZES = {
     9: {
         2: 64 * MiB,  # 64 MB
-        4: 32 * MiB,  # 32 MB
-        6: 64 * MiB,  # 64 MB
-        8: 64 * MiB,  # 64 MB
+        4: 64 * MiB,  # 64 MB
+        6: 128 * MiB,  # 128 MB
+        8: 128 * MiB,  # 128 MB
     },
     10: {
         2: 64 * MiB,  # 64 MB
-        4: 32 * MiB,  # 32 MB
+        4: 64 * MiB,  # 64 MB
         6: 128 * MiB,  # 128 MB
         8: 128 * MiB,  # 128 MB
     },
```
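The outer keys look like CUDA compute-capability majors (9 and 10) and the inner keys tensor-parallel world sizes, so the change raises the per-tensor size cap under which the symmetric-memory all-reduce path is taken. A minimal sketch of how such a table is typically consulted; the helper names here are hypothetical, not part of this commit:

```python
import torch

MiB = 1024 * 1024
SYMM_MEM_ALL_REDUCE_MAX_SIZES = {
    9: {2: 64 * MiB, 4: 64 * MiB, 6: 128 * MiB, 8: 128 * MiB},
    10: {2: 64 * MiB, 4: 64 * MiB, 6: 128 * MiB, 8: 128 * MiB},
}

def symm_mem_max_bytes(world_size: int) -> int:
    """Size cap for symm-mem all-reduce on this GPU, 0 if unsupported."""
    major, _minor = torch.cuda.get_device_capability()
    return SYMM_MEM_ALL_REDUCE_MAX_SIZES.get(major, {}).get(world_size, 0)

def should_use_symm_mem(inp: torch.Tensor, world_size: int) -> bool:
    # Route a tensor to the symm-mem path only when it fits under the cap.
    return inp.numel() * inp.element_size() <= symm_mem_max_bytes(world_size)
```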
python/sglang/srt/distributed/parallel_state.py

```diff
@@ -615,8 +615,11 @@ class GroupCoordinator:
     def _all_reduce_in_place(self, input_: torch.Tensor) -> None:
         pynccl_comm = self.pynccl_comm
+        symm_mem_comm = self.symm_mem_comm
         if pynccl_comm is not None and not pynccl_comm.disabled:
             pynccl_comm.all_reduce(input_)
+        elif symm_mem_comm is not None and not symm_mem_comm.disabled:
+            symm_mem_comm.all_reduce(input_)
         else:
             torch.distributed.all_reduce(input_, group=self.device_group)
```
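The new branch slots the symmetric-memory communicator between the pynccl fast path and the generic torch.distributed fallback. A standalone sketch of the same priority-dispatch pattern, with hypothetical names rather than sglang's actual classes:

```python
from typing import Optional, Protocol

import torch
import torch.distributed as dist


class Comm(Protocol):
    disabled: bool

    def all_reduce(self, t: torch.Tensor) -> None: ...


def all_reduce_in_place(
    t: torch.Tensor,
    pynccl: Optional[Comm],
    symm_mem: Optional[Comm],
    group: dist.ProcessGroup,
) -> None:
    # Prefer the custom communicators; fall through to the generic collective.
    if pynccl is not None and not pynccl.disabled:
        pynccl.all_reduce(t)
    elif symm_mem is not None and not symm_mem.disabled:
        symm_mem.all_reduce(t)
    else:
        dist.all_reduce(t, group=group)
```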
python/sglang/srt/layers/rotary_embedding.py

```diff
@@ -1008,6 +1008,17 @@ class DynamicNTKAlphaRotaryEmbedding(RotaryEmbedding):
         return cache
 
 
+def apply_interleaved_rope(x: torch.Tensor, mrope_section: list[int]) -> torch.Tensor:
+    """Apply interleaved MRoPE to 3D rotary embeddings.
+
+    Reorganizes frequency layout from chunked [TTT...HHH...WWW] to
+    interleaved [THTHWHTHW...TT], preserving frequency continuity.
+    """
+    x_t = x[0].clone()
+    x_t[..., 1 : mrope_section[1] * 3 : 3] = x[1, ..., 1 : mrope_section[1] * 3 : 3]
+    x_t[..., 2 : mrope_section[2] * 3 : 3] = x[2, ..., 2 : mrope_section[2] * 3 : 3]
+    return x_t
+
+
 class MRotaryEmbedding(RotaryEmbedding):
     """Rotary Embedding with Multimodal Sections."""
```
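To see the layout the docstring describes, tag each of the three position axes (temporal, height, width) with a constant and run it through the new function; with a toy mrope_section of [4, 2, 2] the output reads T H W T H W T T:

```python
import torch

def apply_interleaved_rope(x: torch.Tensor, mrope_section: list[int]) -> torch.Tensor:
    # Copied from the hunk above so this demo is self-contained.
    x_t = x[0].clone()
    x_t[..., 1 : mrope_section[1] * 3 : 3] = x[1, ..., 1 : mrope_section[1] * 3 : 3]
    x_t[..., 2 : mrope_section[2] * 3 : 3] = x[2, ..., 2 : mrope_section[2] * 3 : 3]
    return x_t

mrope_section = [4, 2, 2]                          # toy sections, sum = 8
x = torch.arange(3, dtype=torch.float32)           # 0 = T, 1 = H, 2 = W
x = x.view(3, 1, 1).expand(3, 1, 8)                # shape [3, tokens, dim]

print(apply_interleaved_rope(x, mrope_section))
# tensor([[0., 1., 2., 0., 1., 2., 0., 0.]])  ->  T H W T H W T T
```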
```diff
@@ -1020,12 +1031,14 @@ class MRotaryEmbedding(RotaryEmbedding):
         is_neox_style: bool,
         dtype: torch.dtype,
         mrope_section: Optional[List[int]] = None,
+        mrope_interleaved: bool = False,
     ) -> None:
         super().__init__(
             head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
         )
         self.mrope_section = mrope_section
+        self.mrope_interleaved = mrope_interleaved
         if self.mrope_section:
             expected_sum = rotary_dim // 2
             actual_sum = sum(self.mrope_section)
```
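The constructor keeps its existing invariant that sum(mrope_section) equals rotary_dim // 2; the new flag only selects the frequency layout. A hedged instantiation sketch, with Qwen2-VL-style values assumed purely for illustration:

```python
import torch

# Illustrative values, not taken from this commit: 16 + 24 + 24 == 128 // 2.
rope = MRotaryEmbedding(
    head_size=128,
    rotary_dim=128,
    max_position_embeddings=32768,
    base=1000000,
    is_neox_style=True,
    dtype=torch.bfloat16,
    mrope_section=[16, 24, 24],
    mrope_interleaved=True,   # new in this commit
)
```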
```diff
@@ -1086,15 +1099,18 @@ class MRotaryEmbedding(RotaryEmbedding):
         cos, sin = cos_sin.chunk(2, dim=-1)
         if positions.ndim == 2:
             assert self.mrope_section
-
-            cos = torch.cat(
-                [m[i] for i, m in enumerate(cos.split(self.mrope_section, dim=-1))],
-                dim=-1,
-            )
-            sin = torch.cat(
-                [m[i] for i, m in enumerate(sin.split(self.mrope_section, dim=-1))],
-                dim=-1,
-            )
+            if self.mrope_interleaved:
+                cos = apply_interleaved_rope(cos, self.mrope_section)
+                sin = apply_interleaved_rope(sin, self.mrope_section)
+            else:
+                cos = torch.cat(
+                    [m[i] for i, m in enumerate(cos.split(self.mrope_section, dim=-1))],
+                    dim=-1,
+                )
+                sin = torch.cat(
+                    [m[i] for i, m in enumerate(sin.split(self.mrope_section, dim=-1))],
+                    dim=-1,
+                )
 
         query_shape = query.shape
         query = query.view(num_tokens, -1, self.head_size)
```
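For contrast, the retained else branch picks section i of the last dimension from position axis i and concatenates, producing the chunked [TTT...HHH...WWW] layout. Running the same tagged toy input from the earlier snippet through it:

```python
import torch

mrope_section = [4, 2, 2]                          # same toy sections as above
cos = torch.arange(3, dtype=torch.float32).view(3, 1, 1).expand(3, 1, 8)

# Section i of the last dim comes from position axis i (0=T, 1=H, 2=W).
chunked = torch.cat(
    [m[i] for i, m in enumerate(cos.split(mrope_section, dim=-1))],
    dim=-1,
)
print(chunked)
# tensor([[0., 0., 0., 0., 1., 1., 2., 2.]])  ->  T T T T H H W W
```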
```diff
@@ -1768,6 +1784,7 @@ def get_rope(
             is_neox_style,
             dtype,
             mrope_section=rope_scaling["mrope_section"],
+            mrope_interleaved=rope_scaling.get("mrope_interleaved", False),
         )
     else:
         rotary_emb = RotaryEmbedding(
```
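get_rope threads the flag through from the model config's rope_scaling dict, defaulting to False via .get() so existing checkpoints without the key are unaffected. A hedged example of a config that would enable the interleaved path (values illustrative; real ones come from the model's config.json):

```python
rope_scaling = {
    "mrope_section": [16, 24, 24],
    "mrope_interleaved": True,   # absent -> .get(...) falls back to False
}
```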