Commit a9e0e668
authored Nov 04, 2025 by maxiao1

Merge branch 'v0.5.4_dev_maxiao' into 'v0.5.4_dev'

enable custom_allreduce

See merge request OpenDAS/sglang!6

Parents: 785e5e90, 0c5532b0
Showing 3 changed files with 369 additions and 15 deletions.
python/sglang/srt/_custom_ops.py                                          +82   -7
python/sglang/srt/distributed/device_communicators/custom_all_reduce.py   +274  -2
python/sglang/srt/distributed/parallel_state.py                           +13   -6
python/sglang/srt/_custom_ops.py

@@ -19,14 +19,15 @@ logger = logging.getLogger(__name__)
 use_vllm_custom_allreduce = get_bool_env_var(
     "USE_VLLM_CUSTOM_ALLREDUCE", default="false"
 )
+use_dcu_custom_allreduce = get_bool_env_var(
+    "USE_DCU_CUSTOM_ALLREDUCE", default="false"
+)
 
 if not is_hpu():
     # ROCm does not use vllm custom allreduce
-    if use_vllm_custom_allreduce and not is_hip():
+    # if use_vllm_custom_allreduce and not is_hip():
+    if use_vllm_custom_allreduce:
         try:
             import vllm._C  # noqa: F401
+            print(
+                "[DEBUG] ✅ Using vLLM custom allreduce (vllm._C successfully imported)"
+            )
         except ImportError as e:
             logger.warning("Failed to import from vllm._C with %r", e)
     else:
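Note on the two flags above: both are evaluated once, at import time of sglang.srt._custom_ops, so they must be set before sglang is imported. A minimal sketch of toggling the new DCU path (illustrative, not part of the commit; it assumes an sglang build where the vllm._C extension is importable and that get_bool_env_var treats "true" as truthy):

import os

# Set the flags before importing sglang; _custom_ops reads them at module import.
os.environ["USE_DCU_CUSTOM_ALLREDUCE"] = "true"    # route allreduce through torch.ops._C_custom_ar (vllm._C)
os.environ["USE_VLLM_CUSTOM_ALLREDUCE"] = "false"  # keep the vLLM NVIDIA path off

from sglang.srt import _custom_ops as ops

print(ops.use_vllm_custom_allreduce, ops.use_dcu_custom_allreduce)  # expected: False True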
@@ -35,12 +36,15 @@ if not is_hpu():
         except ImportError as e:
             logger.warning("Failed to import from custom_ar with %r", e)
+    if use_dcu_custom_allreduce:
+        try:
+            import vllm._C
+        except ImportError as e:
+            logger.warning("Failed to import from vllm._C with %r", e)
 
-if not is_hip() and not is_npu():
+# if not is_hip() and not is_npu():
+if not is_npu():
     if use_vllm_custom_allreduce:
         custom_op = torch.ops._C_custom_ar
+        print("[DEBUG] ✅ custom_op = torch.ops._C_custom_ar (vLLM path active)")
     else:
         custom_op = sgl_kernel.allreduce
@@ -79,8 +83,79 @@ if not is_npu():
     ) -> None:
         custom_op.register_graph_buffers(fa, handles, offsets)
 
+elif is_hip and use_dcu_custom_allreduce:
+    # custom ar
+    def init_custom_ar(
+        ipc_tensors: list[torch.Tensor],
+        rank_data: torch.Tensor,
+        rank: int,
+        fully_connected: bool,
+    ) -> int:
+        return torch.ops._C_custom_ar.init_custom_ar(
+            ipc_tensors, rank_data, rank, fully_connected
+        )
+
+    def all_reduce(
+        fa: int,
+        inp: torch.Tensor,
+        out: torch.Tensor,
+        reg_buffer: int,
+        reg_buffer_sz_bytes: int,
+    ) -> None:
+        torch.ops._C_custom_ar.all_reduce(
+            fa, inp, out, reg_buffer, reg_buffer_sz_bytes
+        )
+
+    def dispose(fa: int) -> None:
+        torch.ops._C_custom_ar.dispose(fa)
+
+    def meta_size() -> int:
+        return torch.ops._C_custom_ar.meta_size()
+
+    def register_buffer(fa: int, ipc_tensors: list[int]) -> None:
+        return torch.ops._C_custom_ar.register_buffer(fa, ipc_tensors)
+
+    def get_graph_buffer_ipc_meta(fa: int) -> tuple[list[int], list[int]]:
+        return torch.ops._C_custom_ar.get_graph_buffer_ipc_meta(fa)
+
+    def register_graph_buffers(
+        fa: int, handles: list[list[int]], offsets: list[list[int]]
+    ) -> None:
+        torch.ops._C_custom_ar.register_graph_buffers(fa, handles, offsets)
+
+    def allocate_shared_buffer_and_handle(size: int) -> tuple[int, torch.Tensor]:
+        return torch.ops._C_custom_ar.allocate_shared_buffer_and_handle(size)
+
+    def open_mem_handle(mem_handle: torch.Tensor):
+        return torch.ops._C_custom_ar.open_mem_handle(mem_handle)
+
+    def free_shared_buffer(ptr: int) -> None:
+        torch.ops._C_custom_ar.free_shared_buffer(ptr)
+
+    def read_cache(
+        keys: torch.Tensor,
+        values: torch.Tensor,
+        key_caches: list[torch.Tensor],
+        value_caches: list[torch.Tensor],
+        slot_mapping: torch.Tensor,
+        kv_cache_dtype: str,
+    ) -> None:
+        torch.ops._C_cache_ops.read_cache(
+            keys, values, key_caches, value_caches, slot_mapping, kv_cache_dtype
+        )
+
+    def write_cache_multi_layers(
+        keys: torch.Tensor,
+        values: torch.Tensor,
+        key_caches: list[torch.Tensor],
+        value_caches: list[torch.Tensor],
+        slot_mapping: torch.Tensor,
+        kv_cache_dtype: str,
+    ) -> None:
+        torch.ops._C_cache_ops.write_cache_multi_layers(
+            keys, values, key_caches, value_caches, slot_mapping, kv_cache_dtype
+        )
+
 else:
-    # ROCM custom allreduce
+    # sgl_kernel ROCM custom allreduce
     def init_custom_ar(
         meta: torch.Tensor,
...
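The DCU branch above does not go through the sgl_kernel custom_op wrapper; each helper is bound directly to torch.ops._C_custom_ar (registered by vllm._C), plus two _C_cache_ops helpers. A rough driver showing the order in which DCUCustomAllreduce (added in the next file) uses these ops (illustrative only, not from the commit; eager_dcu_allreduce_once is a hypothetical helper, and it assumes the extension is loaded and that meta_ptrs/buffer_ptrs are per-rank IPC pointers already exchanged across ranks):

import torch
from sglang.srt import _custom_ops as ops

def eager_dcu_allreduce_once(meta_ptrs, buffer_ptrs, rank_data, rank,
                             fully_connected, inp, max_size=8192 * 512):
    # init_custom_ar returns an opaque handle; meta_ptrs must cover
    # ops.meta_size() bytes of sync metadata plus max_size bytes of scratch.
    fa = ops.init_custom_ar(meta_ptrs, rank_data, rank, fully_connected)
    try:
        ops.register_buffer(fa, buffer_ptrs)
        out = torch.empty_like(inp)
        # Eager mode: inp is staged through this rank's pre-registered IPC buffer.
        ops.all_reduce(fa, inp, out, buffer_ptrs[rank], max_size)
        return out
    finally:
        ops.dispose(fa)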
python/sglang/srt/distributed/device_communicators/custom_all_reduce.py

@@ -27,10 +27,11 @@ _is_hip = is_hip()
 try:
-    if ops.use_vllm_custom_allreduce and not _is_hip:
+    # if ops.use_vllm_custom_allreduce and not _is_hip:
+    if ops.use_vllm_custom_allreduce:
         # Use vLLM custom allreduce
         ops.meta_size()
+    elif ops.use_dcu_custom_allreduce:
+        ops.meta_size()
     else:
         # Use custom allreduce from sgl kernel (ROCM and TRT-LLM)
         import sgl_kernel  # noqa: F401
@@ -420,3 +421,274 @@ class CustomAllreduce:
 
     def __del__(self):
         self.close()
+
+
+class DCUCustomAllreduce:
+    _SUPPORTED_WORLD_SIZES = [2, 4, 6, 8, 16]
+
+    # max_size: max supported allreduce size
+    def __init__(
+        self,
+        group: ProcessGroup,
+        device: Union[int, str, torch.device],
+        max_size=8192 * 512,
+    ) -> None:
+        """
+        Args:
+            group: the process group to work on. If None, it will use the
+                default process group.
+            device: the device to bind the CustomAllreduce to. If None,
+                it will be bind to f"cuda:{local_rank}".
+        It is the caller's responsibility to make sure each communicator
+        is bind to a unique device, and all communicators in this group
+        are in the same node.
+        """
+        self._IS_CAPTURING = False
+        self.disabled = True
+
+        if not custom_ar:
+            # disable because of missing custom allreduce library
+            # e.g. in a non-GPU environment
+            logger.info(
+                "Custom allreduce is disabled because "
+                "of missing custom allreduce library"
+            )
+            return
+
+        self.group = group
+
+        assert dist.get_backend(group) != dist.Backend.NCCL, (
+            "CustomAllreduce should be attached to a non-NCCL group."
+        )
+
+        if not all(in_the_same_node_as(group, source_rank=0)):
+            # No need to initialize custom allreduce for multi-node case.
+            logger.warning(
+                "Custom allreduce is disabled because this process group"
+                " spans across nodes."
+            )
+            return
+
+        rank = dist.get_rank(group=self.group)
+        self.rank = rank
+        world_size = dist.get_world_size(group=self.group)
+        # if world_size > envs.VLLM_CUSTOM_ALLREDUCE_SUPPORTED_WORLDSIZE_MAX:
+        if world_size > 16:
+            return
+        if world_size == 1:
+            # No need to initialize custom allreduce for single GPU case.
+            return
+
+        if world_size not in CustomAllreduce._SUPPORTED_WORLD_SIZES:
+            logger.warning(
+                "Custom allreduce is disabled due to an unsupported world"
+                " size: %d. Supported world sizes: %s. To silence this "
+                "warning, specify disable_custom_all_reduce=True explicitly.",
+                world_size,
+                str(CustomAllreduce._SUPPORTED_WORLD_SIZES),
+            )
+            return
+
+        if isinstance(device, int):
+            device = torch.device(f"cuda:{device}")
+        elif isinstance(device, str):
+            device = torch.device(device)
+        # now `device` is a `torch.device` object
+        assert isinstance(device, torch.device)
+        self.device = device
+
+        cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
+        if cuda_visible_devices:
+            device_ids = list(map(int, cuda_visible_devices.split(",")))
+        else:
+            device_ids = list(range(torch.cuda.device_count()))
+
+        physical_device_id = device_ids[device.index]
+        tensor = torch.tensor([physical_device_id], dtype=torch.int, device="cpu")
+        gather_list = [
+            torch.tensor([0], dtype=torch.int, device="cpu")
+            for _ in range(world_size)
+        ]
+        dist.all_gather(gather_list, tensor, group=self.group)
+        physical_device_ids = [t.item() for t in gather_list]
+
+        # test nvlink first, this will filter out most of the cases
+        # where custom allreduce is not supported
+        # this checks hardware and driver support for NVLink
+        # assert current_platform.is_cuda_alike()
+        # fully_connected = current_platform.is_fully_connected(
+        #     physical_device_ids)
+        if _is_cuda or _is_hip:
+            fully_connected = is_full_nvlink(physical_device_ids, world_size)
+
+        # if world_size > 2 and not fully_connected:
+        if not fully_connected:
+            max_size = 32 * 8192 * 2
+            # if not envs.VLLM_PCIE_USE_CUSTOM_ALLREDUCE:
+            #     logger.warning(
+            #         "Custom allreduce is disabled because it's not supported on"
+            #         " more than two PCIe-only GPUs. To silence this warning, "
+            #         "specify disable_custom_all_reduce=True explicitly.")
+            #     return
+            logger.warning(
+                "We are using PCIe's custom allreduce."
+                "If the performance is poor, we can add "
+                "--disable-custom-all-reduce in the instruction."
+            )
+        # test P2P capability, this checks software/cudaruntime support
+        # this is expensive to compute at the first time
+        # then we cache the result
+        # On AMD GPU, p2p is always enabled between XGMI connected GPUs
+        if not _is_hip and not _can_p2p(rank, world_size):
+            logger.warning(
+                "Custom allreduce is disabled because your platform lacks "
+                "GPU P2P capability or P2P test failed. To silence this "
+                "warning, specify disable_custom_all_reduce=True explicitly."
+            )
+            return
+
+        self.disabled = False
+        # Buffers memory are owned by this Python class and passed to C++.
+        # Meta data composes of two parts: meta data for synchronization and a
+        # temporary buffer for storing intermediate allreduce results.
+        self.meta_ptrs = self.create_shared_buffer(
+            ops.meta_size() + max_size, group=group, uncached=True
+        )
+        # This is a pre-registered IPC buffer. In eager mode, input tensors
+        # are first copied into this buffer before allreduce is performed
+        self.buffer_ptrs = self.create_shared_buffer(max_size, group=group)
+        # This is a buffer for storing the tuples of pointers pointing to
+        # IPC buffers from all ranks. Each registered tuple has size of
+        # 8*world_size bytes where world_size is at most 8. Allocating 8MB
+        # is enough for 131072 such tuples. The largest model I've seen only
+        # needs less than 10000 of registered tuples.
+        self.rank_data = torch.empty(
+            8 * 1024 * 1024, dtype=torch.uint8, device=self.device
+        )
+        self.max_size = max_size
+        self.rank = rank
+        self.world_size = world_size
+        self.fully_connected = fully_connected
+        self._ptr = ops.init_custom_ar(
+            self.meta_ptrs, self.rank_data, rank, self.fully_connected
+        )
+        ops.register_buffer(self._ptr, self.buffer_ptrs)
+
+    @contextmanager
+    def capture(self):
+        """
+        The main responsibility of this context manager is the
+        `register_graph_buffers` call at the end of the context.
+        It records all the buffer addresses used in the CUDA graph.
+        """
+        try:
+            self._IS_CAPTURING = True
+            yield
+        finally:
+            self._IS_CAPTURING = False
+            if not self.disabled:
+                self.register_graph_buffers()
+
+    def register_graph_buffers(self):
+        handle, offset = ops.get_graph_buffer_ipc_meta(self._ptr)
+        logger.info("Registering %d cuda graph addresses", len(offset))
+        # We cannot directly use `dist.all_gather_object` here
+        # because it is incompatible with `gloo` backend under inference mode.
+        # see https://github.com/pytorch/pytorch/issues/126032 for details.
+        all_data = [
+            [None, None] for _ in range(dist.get_world_size(group=self.group))
+        ]
+        all_data[self.rank] = [handle, offset]
+        ranks = sorted(dist.get_process_group_ranks(group=self.group))
+        for i, rank in enumerate(ranks):
+            dist.broadcast_object_list(
+                all_data[i], src=rank, group=self.group, device="cpu"
+            )
+        # Unpack list of tuples to tuple of lists.
+        handles = [d[0] for d in all_data]  # type: ignore
+        offsets = [d[1] for d in all_data]  # type: ignore
+        ops.register_graph_buffers(self._ptr, handles, offsets)
+
+    def should_custom_ar(self, inp: torch.Tensor):
+        if self.disabled:
+            return False
+        inp_size = inp.numel() * inp.element_size()
+        # custom allreduce requires input byte size to be multiples of 16
+        if inp_size % 16 != 0:
+            return False
+        if not is_weak_contiguous(inp):
+            return False
+        # for 4 or more non NVLink-capable GPUs, custom allreduce provides
+        # little performance improvement over NCCL.
+        return inp_size <= self.max_size
+
+    def all_reduce(
+        self,
+        inp: torch.Tensor,
+        *,
+        out: torch.Tensor = None,
+        registered: bool = False,
+    ):
+        """Performs an out-of-place all reduce.
+
+        If registered is True, this assumes inp's pointer is already
+        IPC-registered. Otherwise, inp is first copied into a pre-registered
+        buffer.
+        """
+        if out is None:
+            out = torch.empty_like(inp)
+        if registered:
+            ops.all_reduce(self._ptr, inp, out, 0, 0)
+        else:
+            ops.all_reduce(
+                self._ptr, inp, out, self.buffer_ptrs[self.rank], self.max_size
+            )
+        return out
+
+    def custom_all_reduce(self, input: torch.Tensor) -> Optional[torch.Tensor]:
+        """The main allreduce API that provides support for cuda graph."""
+        # When custom allreduce is disabled, this will be None.
+        if self.disabled or not self.should_custom_ar(input):
+            return None
+        if self._IS_CAPTURING:
+            if torch.cuda.is_current_stream_capturing():
+                return self.all_reduce(input, registered=False)
+            else:
+                # If warm up, mimic the allocation pattern since custom
+                # allreduce is out-of-place.
+                return torch.empty_like(input)
+        else:
+            # Note: outside of cuda graph context, custom allreduce incurs a
+            # cost of cudaMemcpy, which should be small (<=1% of overall
+            # latency) compared to the performance gain of using custom kernels
+            return self.all_reduce(input, registered=False)
+
+    def close(self):
+        if not self.disabled and self._ptr:
+            if ops is not None:
+                ops.dispose(self._ptr)
+            self._ptr = 0
+            self.free_shared_buffer(self.meta_ptrs, rank=self.rank)
+            self.free_shared_buffer(self.buffer_ptrs, rank=self.rank)
+
+    def __del__(self):
+        self.close()
+
+    @staticmethod
+    def create_shared_buffer(
+        size_in_bytes: int,
+        group: Optional[ProcessGroup] = None,
+        uncached: Optional[bool] = False,
+    ) -> list[int]:
+        pointer, handle = ops.allocate_shared_buffer_and_handle(size_in_bytes)
+        world_size = dist.get_world_size(group=group)
+        rank = dist.get_rank(group=group)
+        handles = [None] * world_size
+        dist.all_gather_object(handles, handle, group=group)
+        pointers: list[int] = []
+        for i, h in enumerate(handles):
+            if i == rank:
+                pointers.append(pointer)  # type: ignore
+            else:
+                pointers.append(ops.open_mem_handle(h))
+        return pointers
+
+    @staticmethod
+    def free_shared_buffer(
+        pointers: list[int],
+        group: Optional[ProcessGroup] = None,
+        rank: Optional[int] = 0,
+    ) -> None:
+        if rank is None:
+            rank = dist.get_rank(group=group)
+        if ops is not None:
+            ops.free_shared_buffer(pointers[rank])
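DCUCustomAllreduce mirrors the existing CustomAllreduce surface (should_custom_ar, all_reduce, custom_all_reduce, capture, close), so callers can treat the two interchangeably. A hedged usage sketch of the fallback pattern the coordinator relies on (illustrative, not from the commit; allreduce_with_fallback is a hypothetical helper, and it assumes torch.distributed is initialized with a non-NCCL group on a single node of DCU/HIP GPUs with the _C_custom_ar extension available):

import torch
import torch.distributed as dist
from sglang.srt.distributed.device_communicators.custom_all_reduce import DCUCustomAllreduce

def allreduce_with_fallback(ca: DCUCustomAllreduce, t: torch.Tensor, device_group) -> torch.Tensor:
    # custom_all_reduce returns None when the custom path is disabled or the
    # tensor is unsuitable (byte size not a multiple of 16, not weakly
    # contiguous, or larger than max_size), so the caller falls back to the
    # stock all_reduce on the device group.
    out = ca.custom_all_reduce(t)
    if out is not None:
        return out
    dist.all_reduce(t, group=device_group)
    return t

One sizing detail worth noting: when the GPUs are not fully linked, __init__ clamps max_size to 32 * 8192 * 2 = 524288 bytes, so larger tensors silently take the fallback path and only emit the PCIe warning once.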
python/sglang/srt/distributed/parallel_state.py

@@ -53,6 +53,7 @@ from sglang.srt.utils import (
     is_xpu,
     supports_custom_op,
 )
+from sglang.srt import _custom_ops as ops
 
 _is_npu = is_npu()
 _is_cpu = is_cpu()

@@ -303,7 +304,7 @@ class GroupCoordinator:
            # Lazy import to avoid documentation build error
            from sglang.srt.distributed.device_communicators.custom_all_reduce import (
-                CustomAllreduce,
+                CustomAllreduce, DCUCustomAllreduce
            )
            from sglang.srt.distributed.device_communicators.pymscclpp import (
                PyMscclppCommunicator,

@@ -347,11 +348,17 @@ class GroupCoordinator:
            else:
                ca_max_size = 8 * 1024 * 1024
            try:
-                self.ca_comm = CustomAllreduce(
-                    group=self.cpu_group,
-                    device=self.device,
-                    max_size=ca_max_size,
-                )
+                if is_hip() and ops.use_dcu_custom_allreduce:
+                    self.ca_comm = DCUCustomAllreduce(
+                        group=self.cpu_group,
+                        device=self.device,
+                    )
+                else:
+                    self.ca_comm = CustomAllreduce(
+                        group=self.cpu_group,
+                        device=self.device,
+                        max_size=ca_max_size,
+                    )
            except Exception as e:
                logger.warning(
                    f"Setup Custom allreduce failed with {e}. To silence this "
...
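Putting the three files together: on HIP/DCU hardware with USE_DCU_CUSTOM_ALLREDUCE set, GroupCoordinator now builds a DCUCustomAllreduce as ca_comm; per the hunk above it is constructed without an explicit max_size, so its default of 8192 * 512 bytes (4 MiB) applies rather than the computed ca_max_size. Everywhere else behaviour is unchanged. A hedged launch sketch (illustrative only; the model path and --tp value are placeholders, and the --disable-custom-all-reduce opt-out referenced in the PCIe warning remains available if the custom path performs poorly):

import os
import subprocess

# Export the flag in the server's environment; _custom_ops reads it at import time.
env = dict(os.environ, USE_DCU_CUSTOM_ALLREDUCE="true")
subprocess.run(
    ["python", "-m", "sglang.launch_server",
     "--model-path", "/path/to/model",   # placeholder
     "--tp", "8"],
    env=env,
    check=True,
)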