change / sglang · Commit 819fc591 (unverified)

Add prefix for torch symm mem (#12506)

Authored Nov 03, 2025 by Yuan Luo; committed by GitHub Nov 02, 2025
Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com>
Parent: 7efd8b3d
Showing 6 changed files with 87 additions and 69 deletions:

  +31 −14  benchmark/kernels/all_reduce/benchmark_torch_symm_mem.py
  +1  −1   python/sglang/srt/distributed/device_communicators/all_reduce_utils.py
  +21 −22  python/sglang/srt/distributed/device_communicators/torch_symm_mem.py
  +31 −29  python/sglang/srt/distributed/parallel_state.py
  +1  −1   python/sglang/srt/layers/dp_attention.py
  +2  −2   python/sglang/srt/model_executor/model_runner.py
benchmark/kernels/all_reduce/benchmark_symm_mem.py → benchmark/kernels/all_reduce/benchmark_torch_symm_mem.py

```diff
-"""For Now, SYMM_MEM is only supported on TP8 case
-export WORLD_SIZE=1
+"""For Now, TORCH_SYMM_MEM is only supported on following limited tp case
+SM90: {
+    2: 64 * MiB,  # 64 MB
+    4: 64 * MiB,  # 64 MB
+    6: 128 * MiB,  # 128 MB
+    8: 128 * MiB,  # 128 MB
+},
+SM100: {
+    2: 64 * MiB,  # 64 MB
+    4: 64 * MiB,  # 64 MB
+    6: 128 * MiB,  # 128 MB
+    8: 128 * MiB,  # 128 MB
+}
+export WORLD_SIZE=8
 export RANK=0
 export MASTER_ADDR=127.0.0.1
 export MASTER_PORT=12345
...
@@ -9,7 +22,7 @@ torchrun --nproc_per_node gpu \
     --nnodes $WORLD_SIZE \
     --node_rank $RANK \
     --master_addr $MASTER_ADDR \
-    --master_port $MASTER_PORT ./benchmark/kernels/all_reduce/benchmark_symm_mem.py
+    --master_port $MASTER_PORT ./benchmark/kernels/all_reduce/benchmark_torch_symm_mem.py
 """
 import os
...
@@ -22,12 +35,14 @@ from torch.distributed import ProcessGroup
 from sglang.srt.distributed import init_distributed_environment
 from sglang.srt.distributed.device_communicators.pynccl import PyNcclCommunicator
-from sglang.srt.distributed.device_communicators.symm_mem import SymmMemCommunicator
+from sglang.srt.distributed.device_communicators.torch_symm_mem import (
+    TorchSymmMemCommunicator,
+)
 from sglang.srt.distributed.parallel_state import (
     get_tensor_model_parallel_group,
     graph_capture,
     initialize_model_parallel,
-    set_symm_mem_all_reduce,
+    set_torch_symm_mem_all_reduce,
 )

 # CI environment detection
...
@@ -42,10 +57,10 @@ def torch_allreduce(torch_input: torch.Tensor, group: ProcessGroup) -> torch.Tensor:
     return torch_input


-def symm_mem_allreduce(
-    symm_mem_input: torch.Tensor, symm_mem_comm: SymmMemCommunicator
+def torch_symm_mem_allreduce(
+    torch_symm_mem_input: torch.Tensor, torch_symm_mem_comm: TorchSymmMemCommunicator
 ) -> torch.Tensor:
-    return symm_mem_comm.all_reduce(symm_mem_input)
+    return torch_symm_mem_comm.all_reduce(torch_symm_mem_input)


 def pynccl_allreduce(
...
@@ -170,7 +185,7 @@ if __name__ == "__main__":
     rank = dist.get_rank()
     torch.cuda.set_device(rank % 8)
     device = torch.cuda.current_device()
-    set_symm_mem_all_reduce(True)
+    set_torch_symm_mem_all_reduce(True)
     init_distributed_environment(
         world_size=world_size,
         rank=rank,
...
@@ -180,7 +195,7 @@ if __name__ == "__main__":
     group = get_tensor_model_parallel_group().device_group
     cpu_group = get_tensor_model_parallel_group().cpu_group
     pynccl_comm = get_tensor_model_parallel_group().pynccl_comm
-    symm_mem_comm = get_tensor_model_parallel_group().symm_mem_comm
+    torch_symm_mem_comm = get_tensor_model_parallel_group().torch_symm_mem_comm
     dist.barrier()

     profile = False
     dtype = torch.bfloat16
...
@@ -204,10 +219,12 @@ if __name__ == "__main__":
             lambda inp: torch_allreduce(inp, group), inp_randn
         )
         symm_mem_eager_output, symm_mem_eager_time = _bench_eager_time(
-            lambda inp: symm_mem_allreduce(inp, symm_mem_comm), inp_randn
+            lambda inp: torch_symm_mem_allreduce(inp, torch_symm_mem_comm),
+            inp_randn,
         )
         symm_mem_graph_output, symm_mem_graph_time = _bench_graph_time(
-            lambda inp: symm_mem_allreduce(inp, symm_mem_comm), inp_randn
+            lambda inp: torch_symm_mem_allreduce(inp, torch_symm_mem_comm),
+            inp_randn,
         )
         # since pynccl is inplace op, this return result is not correct if graph loop > 1
         _, pynccl_graph_time = _bench_graph_time(
...
@@ -229,6 +246,6 @@ if __name__ == "__main__":
     if rank == 0:
         print_markdown_table(result)
     if profile:
-        prof_dir = f"prof/symm_mem"
+        prof_dir = f"prof/torch_symm_mem"
         os.makedirs(prof_dir, exist_ok=True)
         ctx.export_chrome_trace(f"{prof_dir}/trace_rank{dist.get_rank()}.json.gz")
```
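The benchmark times each all-reduce variant both in eager mode and under CUDA graph capture. Its `_bench_eager_time` helper is collapsed out of the hunks above; a minimal sketch of what such a CUDA-event timing helper could look like follows. The function name matches the call sites above, but the warmup count, iteration count, and return convention here are assumptions, not the file's actual implementation:

```python
import torch


def _bench_eager_time(fn, inp, warmup: int = 5, iters: int = 20):
    """Hypothetical sketch: time fn(inp) in eager mode with CUDA events."""
    for _ in range(warmup):  # warm up kernels, allocator, and comm channels
        out = fn(inp)
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        out = fn(inp)
    end.record()
    torch.cuda.synchronize()  # drain the stream before reading the timer
    return out, start.elapsed_time(end) / iters  # mean latency in milliseconds
```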
python/sglang/srt/distributed/device_communicators/all_reduce_utils.py

```diff
 MiB = 1024 * 1024
-SYMM_MEM_ALL_REDUCE_MAX_SIZES = {
+TORCH_SYMM_MEM_ALL_REDUCE_MAX_SIZES = {
     9: {
         2: 64 * MiB,  # 64 MB
         4: 64 * MiB,  # 64 MB
...
```
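The table is keyed first by device capability major version (9 for SM90; the benchmark docstring also lists an SM100 entry) and then by tensor-parallel world size, yielding the largest message size eligible for the symmetric-memory fast path. A small sketch of the lookup; the `10:` row below is carried over from the benchmark docstring and should be treated as illustrative:

```python
from typing import Optional

MiB = 1024 * 1024

# The capability-9 (SM90) row appears in the diff; the capability-10 (SM100)
# row mirrors the benchmark docstring and is included for illustration only.
TORCH_SYMM_MEM_ALL_REDUCE_MAX_SIZES = {
    9: {2: 64 * MiB, 4: 64 * MiB, 6: 128 * MiB, 8: 128 * MiB},
    10: {2: 64 * MiB, 4: 64 * MiB, 6: 128 * MiB, 8: 128 * MiB},
}


def max_symm_mem_bytes(capability_major: int, world_size: int) -> Optional[int]:
    """Return the fast-path size cap, or None when the combination is unsupported."""
    return TORCH_SYMM_MEM_ALL_REDUCE_MAX_SIZES.get(capability_major, {}).get(world_size)


assert max_symm_mem_bytes(9, 8) == 128 * MiB
assert max_symm_mem_bytes(9, 3) is None  # world sizes outside the table are rejected
```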
python/sglang/srt/distributed/device_communicators/symm_mem.py → python/sglang/srt/distributed/device_communicators/torch_symm_mem.py

```diff
@@ -7,33 +7,29 @@ import torch.distributed as dist
 from torch.distributed import ProcessGroup

 from sglang.srt.distributed.device_communicators.all_reduce_utils import (
-    SYMM_MEM_ALL_REDUCE_MAX_SIZES,
+    TORCH_SYMM_MEM_ALL_REDUCE_MAX_SIZES,
 )
 from sglang.srt.utils import is_cuda, is_hip

 try:
     import torch.distributed._symmetric_memory as torch_symm_mem

-    symm_mem_available = True
+    _is_cuda = is_cuda()
+    _is_hip = is_hip()
+
+    torch_symm_mem_available = False
+    if _is_cuda:
+        torch_symm_mem_available = True
 except ImportError:
-    symm_mem_available = False
+    torch_symm_mem_available = False

 logger = logging.getLogger(__name__)

-_is_cuda = is_cuda()
-_is_hip = is_hip()
-
-symm_mem_is_available = False
-if _is_hip:
-    symm_mem_is_available = False
-if _is_cuda:
-    symm_mem_is_available = True
-

-class SymmMemCommunicator:
+class TorchSymmMemCommunicator:
     """
-    Thin wrapper around symmetric-memory collectives.
+    Thin wrapper around torch-symmetric-memory collectives.

     This communicator:
     - Validates device capability and world size.
...
@@ -62,7 +58,7 @@ class SymmMemCommunicator:
         self.disabled = True

-        if not symm_mem_available:
+        if not torch_symm_mem_available:
             return
         if isinstance(device, int):
...
@@ -77,19 +73,22 @@ class SymmMemCommunicator:
         self.device_capability = torch.cuda.get_device_capability(device)[0]
         if self.device_capability < 9:
             logger.warning(
-                "SymmMemCommunicator: Device capability %s not supported, "
+                "TorchSymmMemCommunicator: Device capability %s not supported, "
                 "communicator is not available.",
                 self.device_capability,
             )
             return
-        if self.world_size not in SYMM_MEM_ALL_REDUCE_MAX_SIZES[self.device_capability]:
+        if (
+            self.world_size
+            not in TORCH_SYMM_MEM_ALL_REDUCE_MAX_SIZES[self.device_capability]
+        ):
             logger.warning(
-                "SymmMemCommunicator: World size %d not supported, "
+                "TorchSymmMemCommunicator: World size %d not supported, "
                 "communicator is not available.",
                 self.world_size,
             )
             return
-        self.max_size = SYMM_MEM_ALL_REDUCE_MAX_SIZES[self.device_capability][
+        self.max_size = TORCH_SYMM_MEM_ALL_REDUCE_MAX_SIZES[self.device_capability][
             self.world_size
         ]
         self.buffer = torch_symm_mem.empty(
...
@@ -100,7 +99,7 @@ class SymmMemCommunicator:
         handle = torch_symm_mem.rendezvous(self.buffer, self.group.group_name)
         if handle.multicast_ptr == 0:
             logger.warning(
-                "SymmMemCommunicator: symmetric memory "
+                "TorchSymmMemCommunicator: torch symmetric memory "
                 "multicast operations are not supported."
             )
             self.buffer = None
...
@@ -108,7 +107,7 @@ class SymmMemCommunicator:
             return
         self.disabled = False

-    def should_symm_mem_allreduce(self, inp: torch.Tensor):
+    def should_torch_symm_mem_allreduce(self, inp: torch.Tensor):
         """
         Fast-path eligibility check for a given tensor.
...
```
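The body of `should_torch_symm_mem_allreduce` is collapsed above. Based on the surrounding code (the `disabled` flag, the `max_size` cap, and the bfloat16 note in the `all_reduce` docstring below), a plausible reconstruction of what such a check verifies is sketched here; treat the exact conditions as assumptions rather than the file's verbatim logic:

```python
import torch


def should_torch_symm_mem_allreduce_sketch(comm, inp: torch.Tensor) -> bool:
    """Hypothetical eligibility check mirroring the collapsed method."""
    if comm.disabled:  # communicator failed capability/world-size validation
        return False
    if inp.dtype != torch.bfloat16:  # the all_reduce docstring expects bfloat16
        return False
    # The message must fit in the pre-registered symmetric-memory buffer.
    return inp.numel() * inp.element_size() <= comm.max_size
```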
```diff
@@ -135,7 +134,7 @@ class SymmMemCommunicator:
         self, inp: torch.Tensor, *, out: Optional[torch.Tensor] = None
     ) -> Optional[torch.Tensor]:
         """
-        Perform an in-place sum all-reduce via symmetric memory.
+        Perform an in-place sum all-reduce via torch symmetric memory.

         Args:
             inp: Input tensor on the target CUDA device (bfloat16).
...
```
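For readers unfamiliar with the underlying PyTorch API: the communicator builds on `torch.distributed._symmetric_memory`, allocating a shared buffer with `torch_symm_mem.empty(...)` and exchanging handles with `torch_symm_mem.rendezvous(...)`, as the hunks above show. A standalone sketch of that flow follows; the `torch.ops.symm_mem.one_shot_all_reduce` op is part of recent PyTorch builds, but its availability varies, so treat this as a sketch under that assumption rather than the communicator's actual kernel choice:

```python
# Hedged sketch of the raw torch symmetric-memory flow the communicator wraps.
# Assumes a recent PyTorch with torch.ops.symm_mem.one_shot_all_reduce.
# Launch with: torchrun --nproc_per_node=<num_gpus> this_script.py
import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as torch_symm_mem

dist.init_process_group(backend="nccl")
rank = dist.get_rank()
torch.cuda.set_device(rank)

group_name = dist.group.WORLD.group_name
buf = torch_symm_mem.empty(4096, dtype=torch.bfloat16, device=f"cuda:{rank}")
torch_symm_mem.rendezvous(buf, group_name)  # exchange buffer handles across ranks

buf.fill_(rank + 1.0)
# One-shot all-reduce over the symmetric buffer; every rank receives the sum.
out = torch.ops.symm_mem.one_shot_all_reduce(buf, "sum", group_name)
torch.cuda.synchronize()
if rank == 0:
    print(out[:4])

dist.destroy_process_group()
```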
python/sglang/srt/distributed/parallel_state.py

```diff
...
@@ -217,14 +217,16 @@ class GroupCoordinator:
     use_pynccl: bool  # a hint of whether to use PyNccl
     use_pymscclpp: bool  # a hint of whether to use PyMsccl
     use_custom_allreduce: bool  # a hint of whether to use CustomAllreduce
-    use_torch_symm_mem: bool  # a hint of whether to use SymmMemAllReduce
+    use_torch_symm_mem_all_reduce: (
+        bool  # a hint of whether to use TorchSymmMemAllReduce
+    )
     use_message_queue_broadcaster: (
         bool  # a hint of whether to use message queue broadcaster
     )

     # communicators are only created for world size > 1
     pynccl_comm: Optional[Any]  # PyNccl communicator
     ca_comm: Optional[Any]  # Custom allreduce communicator
-    symm_mem_comm: Optional[Any]  # Symm mem communicator
+    torch_symm_mem_comm: Optional[Any]  # Torch symm mem communicator
     mq_broadcaster: Optional[Any]  # shared memory broadcaster

     def __init__(
...
@@ -235,7 +237,7 @@ class GroupCoordinator:
         use_pynccl: bool,
         use_pymscclpp: bool,
         use_custom_allreduce: bool,
-        use_torch_symm_mem: bool,
+        use_torch_symm_mem_all_reduce: bool,
         use_hpu_communicator: bool,
         use_xpu_communicator: bool,
         use_npu_communicator: bool,
...
@@ -295,7 +297,7 @@ class GroupCoordinator:
         self.pynccl_use_current_stream = pynccl_use_current_stream
         self.use_pymscclpp = use_pymscclpp
         self.use_custom_allreduce = use_custom_allreduce
-        self.use_torch_symm_mem = use_torch_symm_mem
+        self.use_torch_symm_mem_all_reduce = use_torch_symm_mem_all_reduce
         self.use_hpu_communicator = use_hpu_communicator
         self.use_xpu_communicator = use_xpu_communicator
         self.use_npu_communicator = use_npu_communicator
...
@@ -311,8 +313,8 @@ class GroupCoordinator:
             from sglang.srt.distributed.device_communicators.pynccl import (
                 PyNcclCommunicator,
             )
-            from sglang.srt.distributed.device_communicators.symm_mem import (
-                SymmMemCommunicator,
+            from sglang.srt.distributed.device_communicators.torch_symm_mem import (
+                TorchSymmMemCommunicator,
             )

             if is_hip():
...
@@ -363,9 +365,9 @@ class GroupCoordinator:
             except Exception as e:
                 logger.warning(f"Failed to initialize QuickAllReduce: {e}")

-        self.symm_mem_comm: Optional[SymmMemCommunicator] = None
-        if self.use_torch_symm_mem and self.world_size > 1:
-            self.symm_mem_comm = SymmMemCommunicator(
+        self.torch_symm_mem_comm: Optional[TorchSymmMemCommunicator] = None
+        if self.use_torch_symm_mem_all_reduce and self.world_size > 1:
+            self.torch_symm_mem_comm = TorchSymmMemCommunicator(
                 group=self.cpu_group,
                 device=self.device,
             )
...
@@ -580,11 +582,11 @@ class GroupCoordinator:
             ):
                 outplace_all_reduce_method = "pymscclpp"
             elif (
-                self.symm_mem_comm is not None
-                and not self.symm_mem_comm.disabled
-                and self.symm_mem_comm.should_symm_mem_allreduce(input_)
+                self.torch_symm_mem_comm is not None
+                and not self.torch_symm_mem_comm.disabled
+                and self.torch_symm_mem_comm.should_torch_symm_mem_allreduce(input_)
             ):
-                outplace_all_reduce_method = "symm_mem"
+                outplace_all_reduce_method = "torch_symm_mem"
             if outplace_all_reduce_method is not None:
                 return torch.ops.sglang.outplace_all_reduce(
                     input_,
...
@@ -601,7 +603,7 @@ class GroupCoordinator:
         ca_comm = self.ca_comm
         qr_comm = self.qr_comm
         pymscclpp_comm = self.pymscclpp_comm
-        symm_mem_comm = self.symm_mem_comm
+        torch_symm_mem_comm = self.torch_symm_mem_comm
         assert any([qr_comm, ca_comm, pymscclpp_comm])
         if outplace_all_reduce_method == "ca":
             assert not ca_comm.disabled
...
@@ -609,9 +611,9 @@ class GroupCoordinator:
         elif outplace_all_reduce_method == "qr":
             assert not qr_comm.disabled
             out = qr_comm.quick_all_reduce(input_)
-        elif outplace_all_reduce_method == "symm_mem":
-            assert not symm_mem_comm.disabled
-            out = symm_mem_comm.all_reduce(input_)
+        elif outplace_all_reduce_method == "torch_symm_mem":
+            assert not torch_symm_mem_comm.disabled
+            out = torch_symm_mem_comm.all_reduce(input_)
         else:
             assert not pymscclpp_comm.disabled
             out = pymscclpp_comm.all_reduce(input_)
...
@@ -620,11 +622,11 @@ class GroupCoordinator:
     def _all_reduce_in_place(self, input_: torch.Tensor) -> None:
         pynccl_comm = self.pynccl_comm
-        symm_mem_comm = self.symm_mem_comm
+        torch_symm_mem_comm = self.torch_symm_mem_comm
         if pynccl_comm is not None and not pynccl_comm.disabled:
             pynccl_comm.all_reduce(input_)
-        elif symm_mem_comm is not None and not symm_mem_comm.disabled:
-            symm_mem_comm.all_reduce(input_)
+        elif torch_symm_mem_comm is not None and not torch_symm_mem_comm.disabled:
+            torch_symm_mem_comm.all_reduce(input_)
         else:
             torch.distributed.all_reduce(input_, group=self.device_group)
...
@@ -1267,7 +1269,7 @@ def init_world_group(
         use_pynccl=False,
         use_pymscclpp=False,
         use_custom_allreduce=False,
-        use_torch_symm_mem=False,
+        use_torch_symm_mem_all_reduce=False,
         use_hpu_communicator=False,
         use_xpu_communicator=False,
         use_npu_communicator=False,
...
@@ -1284,15 +1286,15 @@ def init_model_parallel_group(
     group_name: Optional[str] = None,
     use_mscclpp_allreduce: Optional[bool] = None,
     pynccl_use_current_stream: bool = True,
-    use_symm_mem_allreduce: Optional[bool] = None,
+    use_torch_symm_mem_allreduce: Optional[bool] = None,
     torch_compile: Optional[bool] = None,
 ) -> GroupCoordinator:
     if use_custom_allreduce is None:
         use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE
     if use_mscclpp_allreduce is None:
         use_mscclpp_allreduce = _ENABLE_MSCCLPP_ALL_REDUCE
-    if use_symm_mem_allreduce is None:
-        use_symm_mem_allreduce = _ENABLE_SYMM_MEM_ALL_REDUCE
+    if use_torch_symm_mem_allreduce is None:
+        use_torch_symm_mem_allreduce = _ENABLE_TORCH_SYMM_MEM_ALL_REDUCE
     return GroupCoordinator(
         group_ranks=group_ranks,
         local_rank=local_rank,
...
@@ -1300,7 +1302,7 @@ def init_model_parallel_group(
         use_pynccl=not (_is_npu or _is_xpu),
         use_pymscclpp=use_mscclpp_allreduce,
         use_custom_allreduce=use_custom_allreduce,
-        use_torch_symm_mem=use_symm_mem_allreduce,
+        use_torch_symm_mem_all_reduce=use_torch_symm_mem_allreduce,
         use_hpu_communicator=True,
         use_xpu_communicator=True,
         use_npu_communicator=True,
...
@@ -1388,7 +1390,7 @@ logger = logging.getLogger(__name__)
 _ENABLE_CUSTOM_ALL_REDUCE = True
 _ENABLE_MSCCLPP_ALL_REDUCE = False
-_ENABLE_SYMM_MEM_ALL_REDUCE = False
+_ENABLE_TORCH_SYMM_MEM_ALL_REDUCE = False


 def set_custom_all_reduce(enable: bool):
...
@@ -1401,9 +1403,9 @@ def set_mscclpp_all_reduce(enable: bool):
     _ENABLE_MSCCLPP_ALL_REDUCE = enable


-def set_symm_mem_all_reduce(enable: bool):
-    global _ENABLE_SYMM_MEM_ALL_REDUCE
-    _ENABLE_SYMM_MEM_ALL_REDUCE = enable
+def set_torch_symm_mem_all_reduce(enable: bool):
+    global _ENABLE_TORCH_SYMM_MEM_ALL_REDUCE
+    _ENABLE_TORCH_SYMM_MEM_ALL_REDUCE = enable


 def init_distributed_environment(
...
```
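Putting the pieces together: callers opt in through the module-level flag before any group is created, exactly as the renamed benchmark does. A condensed sketch of that flow, mirroring the benchmark's setup; the tensor-parallel size and the env-driven values are illustrative, and `init_distributed_environment`'s remaining arguments are assumed to default to env:// initialization:

```python
# Condensed from the benchmark above; run under torchrun so RANK/WORLD_SIZE
# and MASTER_ADDR/MASTER_PORT are populated.
import os

import torch

from sglang.srt.distributed import init_distributed_environment
from sglang.srt.distributed.parallel_state import (
    get_tensor_model_parallel_group,
    initialize_model_parallel,
    set_torch_symm_mem_all_reduce,
)

world_size = int(os.environ["WORLD_SIZE"])
rank = int(os.environ["RANK"])
torch.cuda.set_device(rank % 8)

set_torch_symm_mem_all_reduce(True)  # flip the flag before any group is built
init_distributed_environment(world_size=world_size, rank=rank)
initialize_model_parallel(tensor_model_parallel_size=world_size)

comm = get_tensor_model_parallel_group().torch_symm_mem_comm
print(comm is None or comm.disabled)  # False when the fast path is usable
```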
python/sglang/srt/layers/dp_attention.py

```diff
...
@@ -280,7 +280,7 @@ def initialize_dp_attention(
         use_pynccl=SYNC_TOKEN_IDS_ACROSS_TP,
         use_pymscclpp=False,
         use_custom_allreduce=False,
-        use_torch_symm_mem=False,
+        use_torch_symm_mem_all_reduce=False,
         use_hpu_communicator=False,
         use_xpu_communicator=False,
         use_npu_communicator=False,
...
```
python/sglang/srt/model_executor/model_runner.py

```diff
...
@@ -56,7 +56,7 @@ from sglang.srt.distributed import (
     initialize_model_parallel,
     set_custom_all_reduce,
     set_mscclpp_all_reduce,
-    set_symm_mem_all_reduce,
+    set_torch_symm_mem_all_reduce,
 )
 from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state
 from sglang.srt.elastic_ep.elastic_ep import ElasticEPStateManager
...
@@ -608,7 +608,7 @@ class ModelRunner:
         dist_init_method = f"tcp://127.0.0.1:{self.dist_port}"
         set_custom_all_reduce(not self.server_args.disable_custom_all_reduce)
         set_mscclpp_all_reduce(self.server_args.enable_mscclpp)
-        set_symm_mem_all_reduce(self.server_args.enable_torch_symm_mem)
+        set_torch_symm_mem_all_reduce(self.server_args.enable_torch_symm_mem)
         if not self.is_draft_worker:
             if self.device == "cpu":
...
```
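End to end, the feature remains driven by the existing `enable_torch_symm_mem` server argument; this hunk only forwards it to the renamed setter. A minimal sketch of that wiring in isolation; `_MiniServerArgs` is a hypothetical stand-in for the real `ServerArgs`, which carries many more fields:

```python
# Hypothetical, simplified stand-in for ServerArgs, showing only the flag this
# commit routes into set_torch_symm_mem_all_reduce().
from dataclasses import dataclass

from sglang.srt.distributed.parallel_state import set_torch_symm_mem_all_reduce


@dataclass
class _MiniServerArgs:
    enable_torch_symm_mem: bool = False  # real flag name, per the diff above


args = _MiniServerArgs(enable_torch_symm_mem=True)
set_torch_symm_mem_all_reduce(args.enable_torch_symm_mem)
```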