Commit a3ab768a (unverified)
Clean up custom allreduce (#4029)

Authored Mar 03, 2025 by Lianmin Zheng; committed via GitHub on Mar 03, 2025.
Parent commit: 66301e12
Showing 2 changed files, with 24 additions and 77 deletions:

  python/sglang/srt/_custom_ops.py                                           +6  -62
  python/sglang/srt/distributed/device_communicators/custom_all_reduce.py   +18  -15
File: python/sglang/srt/_custom_ops.py (+6, -62)
 # Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/_custom_ops.py
-import contextlib
-import functools
-import importlib
 import logging
 import os
-from typing import TYPE_CHECKING, List, Optional, Tuple, Union
+from typing import List, Tuple
 
 import torch
 import torch.library
@@ -13,8 +10,9 @@ from sglang.srt.utils import is_hip, is_hpu
 logger = logging.getLogger(__name__)
 
 use_vllm_custom_allreduce = os.environ.get("USE_VLLM_CUSTOM_ALLREDUCE", default=True)
 
 if not is_hpu():
-    # Remove vllm dependency for custom allreduce on ROCm
+    # ROCm does not use vllm custom allreduce
     if use_vllm_custom_allreduce and not is_hip():
         try:
             import vllm._C
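Note on the USE_VLLM_CUSTOM_ALLREDUCE flag kept by this hunk: os.environ.get(..., default=True) returns the bool True when the variable is unset, but a raw string when it is set, so any non-empty value (including "0" or "false") is truthy in a bare `if`. A minimal sketch of a stricter parse, purely illustrative and not part of this commit (env_flag is a hypothetical helper):

import os

def env_flag(name: str, default: bool = True) -> bool:
    # Hypothetical helper, not in this diff: treat common "false" spellings as False.
    value = os.environ.get(name)
    if value is None:
        return default
    return value.strip().lower() not in ("0", "false", "no", "off", "")

use_vllm_custom_allreduce = env_flag("USE_VLLM_CUSTOM_ALLREDUCE", default=True)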
@@ -27,37 +25,8 @@ if not is_hpu():
             logger.warning("Failed to import from custom_ar with %r", e)
 
-
-def hint_on_error(fn):
-    @functools.wraps(fn)
-    def wrapper(*args, **kwargs):
-        try:
-            return fn(*args, **kwargs)
-        except NotImplementedError as e:
-            msg = (
-                "Error in calling custom op %s: %s\n"
-                "Not implemented or built, mostly likely because the current current device "
-                "does not support this kernel (less likely TORCH_CUDA_ARCH_LIST was set "
-                "incorrectly while building)"
-            )
-            logger.error(msg, fn.__name__, e)
-            raise NotImplementedError(msg % (fn.__name__, e)) from e
-        except AttributeError as e:
-            msg = (
-                "Error in calling custom op %s: %s\n"
-                "Possibly you have built or installed an obsolete version of vllm.\n"
-                "Please try a clean build and install of vllm,"
-                "or remove old built files such as vllm/*cpython*.so and build/ ."
-            )
-            logger.error(msg, fn.__name__, e)
-            raise e
-
-    return wrapper
 
 if use_vllm_custom_allreduce and not is_hip():
-    # custom ar
+    # vLLM custom allreduce
     def init_custom_ar(
         ipc_tensors: List[torch.Tensor],
         rank_data: torch.Tensor,
@@ -96,6 +65,7 @@ if use_vllm_custom_allreduce and not is_hip():
 else:
     if is_hip():
+        # ROCM custom allreduce
         def init_custom_ar(
             meta: torch.Tensor,
@@ -143,7 +113,7 @@ else:
             return sgl_kernel.ops.get_meta_buffer_ipc_handle(inp)
 
     else:
-        # custom ar
+        # TRTLLM custom allreduce
         def init_custom_ar(
             rank_id: int,
             world_size: int,
@@ -176,29 +146,3 @@ else:
             fa: int, handles: List[List[int]], offsets: List[List[int]]
         ) -> None:
             sgl_kernel.ops.register_graph_buffers(fa, handles, offsets)
-
-
-# temporary fix for https://github.com/vllm-project/vllm/issues/5456
-# TODO: remove this in v0.6.0
-names_and_values = globals()
-names_and_values_to_update = {}
-# prepare variables to avoid dict size change during iteration
-k, v, arg = None, None, None
-fn_type = type(lambda x: x)
-for k, v in names_and_values.items():
-    # find functions that are defined in this file and have torch.Tensor
-    # in their annotations. `arg == "torch.Tensor"` is used to handle
-    # the case when users use `import __annotations__` to turn type
-    # hints into strings.
-    if (
-        isinstance(v, fn_type)
-        and v.__code__.co_filename == __file__
-        and any(
-            arg is torch.Tensor or arg == "torch.Tensor"
-            for arg in v.__annotations__.values()
-        )
-    ):
-        names_and_values_to_update[k] = hint_on_error(v)
-
-names_and_values.update(names_and_values_to_update)
-del names_and_values_to_update, names_and_values, v, k, fn_type
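The two blocks this file drops (the hint_on_error decorator and the globals() pass at the bottom) worked as a pair: every module-level function whose annotations mention torch.Tensor was rewrapped so that a missing or unbuilt kernel raises a readable error naming the op. A condensed, self-contained sketch of that pattern, with a hypothetical dummy op (my_op) added for illustration:

import functools
import logging

import torch

logger = logging.getLogger(__name__)


def hint_on_error(fn):
    # Re-raise NotImplementedError from a missing custom op with a friendlier message.
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        try:
            return fn(*args, **kwargs)
        except NotImplementedError as e:
            msg = "Error in calling custom op %s: %s (kernel not built for this device?)"
            logger.error(msg, fn.__name__, e)
            raise NotImplementedError(msg % (fn.__name__, e)) from e

    return wrapper


def my_op(x: torch.Tensor) -> torch.Tensor:  # hypothetical op, not in the diff
    raise NotImplementedError("kernel was not compiled")


# Wrap every function in this module whose annotations mention torch.Tensor.
for name, value in list(globals().items()):
    if callable(value) and getattr(value, "__annotations__", None):
        if any(a is torch.Tensor or a == "torch.Tensor" for a in value.__annotations__.values()):
            globals()[name] = hint_on_error(value)

# my_op(torch.empty(1)) would now raise NotImplementedError with the op name in the message.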
File: python/sglang/srt/distributed/device_communicators/custom_all_reduce.py (+18, -15)
@@ -22,17 +22,18 @@ from sglang.srt.utils import cuda_device_count_stateless, is_cuda, is_hip
 logger = logging.getLogger(__name__)
 
+is_hip_ = is_hip()
+
 if is_cuda():
     try:
         import pynvml
     except ImportError as e:
         logger.warning("Failed to import pynvml with %r", e)
 
-if is_hip():
+if is_hip_:
     try:
         from amdsmi import (
             AmdSmiException,
-            amdsmi_get_gpu_board_info,
             amdsmi_get_processor_handles,
             amdsmi_init,
             amdsmi_shut_down,
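The rest of this file's hunks all follow from the single line added above: is_hip() is evaluated once at module import and the cached is_hip_ flag replaces the per-branch calls. A minimal sketch of the pattern (is_hip is assumed here to check torch.version.hip, in line with the torch.version.hip branch replaced later in this diff; pick_backend is a hypothetical consumer):

import torch


def is_hip() -> bool:
    # Stand-in for sglang.srt.utils.is_hip; assumed to probe torch.version.hip.
    return torch.version.hip is not None


is_hip_ = is_hip()  # evaluated once at import; reused by every later branch


def pick_backend() -> str:
    # Hypothetical consumer showing the intended usage pattern.
    return "sgl-kernel custom allreduce (ROCm)" if is_hip_ else "vLLM / TRT-LLM custom allreduce (CUDA)"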
@@ -42,9 +43,11 @@ if is_hip():
         logger.warning("Failed to import amdsmi with %r", e)
 
 try:
-    if ops.use_vllm_custom_allreduce and not is_hip():
+    if ops.use_vllm_custom_allreduce and not is_hip_:
+        # Use vLLM custom allreduce
         ops.meta_size()
     else:
+        # Use custom allreduce from sgl kernel (ROCM and TRT-LLM)
         import sgl_kernel
     custom_ar = True
 except Exception:
@@ -60,7 +63,7 @@ _R = TypeVar("_R")
 
 def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]:
     @wraps(fn)
     def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
-        if torch.version.hip:
+        if is_hip_:
             try:
                 amdsmi_init()
                 return fn(*args, **kwargs)
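with_nvml_context wraps a function so the GPU management library is brought up before the call; the hunk only swaps torch.version.hip for the cached is_hip_ flag, and the teardown side of the wrapper is cut off by the diff. A self-contained sketch of the general init/teardown decorator idea, under the assumption that the library is shut down in a finally block (with_gpu_mgmt_context and the is_hip_ stand-in below are illustrative, not the file's actual code):

import functools

is_hip_ = False  # stand-in for the module-level flag added by this commit


def with_gpu_mgmt_context(fn):
    # Hypothetical analogue of with_nvml_context: initialize the GPU management
    # library before the wrapped call and shut it down afterwards.
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        if is_hip_:
            from amdsmi import amdsmi_init, amdsmi_shut_down
            amdsmi_init()
            try:
                return fn(*args, **kwargs)
            finally:
                amdsmi_shut_down()
        import pynvml
        pynvml.nvmlInit()
        try:
            return fn(*args, **kwargs)
        finally:
            pynvml.nvmlShutdown()

    return wrapper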
@@ -78,7 +81,7 @@ def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]:
 
 @with_nvml_context
 def is_full_nvlink(physical_device_ids: List[int], world_size: int) -> bool:
-    if is_hip():
+    if is_hip_:
         """
         query if the set of gpus are fully connected by xgmi (1 hop)
         """
@@ -142,7 +145,7 @@ def is_weak_contiguous(inp: torch.Tensor):
 
 class CustomAllreduce:
     _SUPPORTED_WORLD_SIZES = [2, 4, 6, 8]
     _MAX_CAR_SIZE = 8192 * 1024
-    if is_hip():
+    if is_hip_:
         # crossover is at 16MB buffer size for ROCm
         _MAX_CAR_SIZE = 2 * 8192 * 1024
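For scale: 8192 * 1024 bytes is 8 MiB, so the ROCm override of 2 * 8192 * 1024 raises the custom-allreduce cutoff to 16 MiB, which is what the "crossover is at 16MB buffer size for ROCm" comment refers to.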
@@ -226,7 +229,7 @@ class CustomAllreduce:
         # test nvlink first, this will filter out most of the cases
         # where custom allreduce is not supported
         # this checks hardware and driver support for NVLink
-        if is_cuda() or is_hip():
+        if is_cuda() or is_hip_:
             full_nvlink = is_full_nvlink(physical_device_ids, world_size)
         if world_size > 2 and not full_nvlink:
@@ -240,7 +243,7 @@ class CustomAllreduce:
         # this is expensive to compute at the first time
         # then we cache the result
         # On AMD GPU, p2p is always enabled between XGMI connected GPUs
-        if not is_hip() and not _can_p2p(rank, world_size):
+        if not is_hip_ and not _can_p2p(rank, world_size):
             logger.warning(
                 "Custom allreduce is disabled because your platform lacks "
                 "GPU P2P capability or P2P test failed. To silence this "
@@ -253,7 +256,7 @@ class CustomAllreduce:
         self.world_size = world_size
         self.full_nvlink = full_nvlink
-        if ops.use_vllm_custom_allreduce and not is_hip():
+        if ops.use_vllm_custom_allreduce and not is_hip_:
             # Buffers memory are owned by this Python class and passed to C++.
             # Meta data composes of two parts: meta data for synchronization and a
             # temporary buffer for storing intermediate allreduce results.
@@ -276,7 +279,7 @@ class CustomAllreduce:
             )
             ops.register_buffer(self._ptr, self.buffer_ptrs)
         else:
-            if is_hip():
+            if is_hip_:
                 # meta data buffers need to be "uncached" for signal on MI200
                 self.meta = ops.allocate_meta_buffer(ops.meta_size() + max_size)
                 self.buffer = torch.empty(
@@ -415,7 +418,7 @@ class CustomAllreduce:
         ops.register_buffer(self._ptr, inp, handles, offsets)
 
     def register_graph_buffers(self):
-        if is_hip():
+        if is_hip_:
             handle, offset = ops.get_graph_buffer_ipc_meta(self._ptr)
             handles, offsets = self._gather_ipc_meta((bytes(handle), offset))
             logger.info("Registering %d cuda graph addresses", len(offset))
@@ -451,12 +454,12 @@ class CustomAllreduce:
             return False
         # for 4 or more non NVLink-capable GPUs, custom allreduce provides
        # little performance improvement over NCCL.
-        if ops.use_vllm_custom_allreduce and not is_hip():
+        if ops.use_vllm_custom_allreduce and not is_hip_:
             if self.world_size == 2 or self.full_nvlink:
                 return inp_size < self.max_size
             return False
-        if is_hip():
+        if is_hip_:
             if self.full_nvlink:
                 if self.world_size == 8:
                     if self.MSCCL:
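Condensing the non-HIP branch of this hunk: the custom kernel is only chosen for tensors smaller than the registered buffer, and only when the topology makes it profitable (2 ranks, or full NVLink); otherwise NCCL is left to handle the allreduce. A standalone sketch of that predicate, with names borrowed from the diff (the HIP branch, gated further by full_nvlink, world_size, and MSCCL, is not reproduced here):

def should_use_custom_allreduce(
    inp_size: int, world_size: int, full_nvlink: bool, max_size: int
) -> bool:
    # Mirrors the vLLM-backend branch shown above.
    if world_size == 2 or full_nvlink:
        return inp_size < max_size
    # For 4+ GPUs without full NVLink, custom allreduce gains little over NCCL.
    return False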
@@ -529,7 +532,7 @@ class CustomAllreduce:
             return None
         if self._IS_CAPTURING:
             if torch.cuda.is_current_stream_capturing():
-                if is_hip():
+                if is_hip_:
                     return self.all_reduce_reg(input)
                 else:
                     return self.all_reduce(input, registered=True)
@@ -538,7 +541,7 @@ class CustomAllreduce:
                 # allreduce is out-of-place.
                 return torch.empty_like(input)
         else:
-            if is_hip():
+            if is_hip_:
                 # note: outside of cuda graph context,
                 # custom allreduce incurs a cost of cudaMemcpy, which should
                 # be small(<=1% of overall latency) compared to the performance