sglang / Commits / a3ab768a

Clean up custom allreduce (#4029)

Unverified commit a3ab768a, authored Mar 03, 2025 by Lianmin Zheng, committed via GitHub on Mar 03, 2025.
Parent: 66301e12

Showing 2 changed files with 24 additions and 77 deletions (+24 / -77):

    python/sglang/srt/_custom_ops.py                                           +6  -62
    python/sglang/srt/distributed/device_communicators/custom_all_reduce.py   +18  -15
python/sglang/srt/_custom_ops.py  (view file @ a3ab768a)
 # Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/_custom_ops.py
-import contextlib
-import functools
-import importlib
 import logging
 import os
-from typing import TYPE_CHECKING, List, Optional, Tuple, Union
+from typing import List, Tuple

 import torch
 import torch.library
...
@@ -13,8 +10,9 @@ from sglang.srt.utils import is_hip, is_hpu

 logger = logging.getLogger(__name__)

 use_vllm_custom_allreduce = os.environ.get("USE_VLLM_CUSTOM_ALLREDUCE", default=True)

 if not is_hpu():
-    # Remove vllm dependency for custom allreduce on ROCm
+    # ROCm does not use vllm custom allreduce
     if use_vllm_custom_allreduce and not is_hip():
         try:
             import vllm._C
...
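A note on the switch above: because the value is read with os.environ.get, a set variable arrives as a raw string and is then used directly in an if, so any non-empty value is truthy. A minimal illustrative sketch of how the flag resolves as this line is written (the helper name below is not sglang API):

    import os

    def resolve_use_vllm_custom_allreduce() -> bool:
        # Mirrors the lookup above: os.environ.get returns the raw string when the
        # variable is set, or the Python default True when it is not.
        return bool(os.environ.get("USE_VLLM_CUSTOM_ALLREDUCE", default=True))

    os.environ.pop("USE_VLLM_CUSTOM_ALLREDUCE", None)
    assert resolve_use_vllm_custom_allreduce() is True   # unset: vLLM path stays on

    os.environ["USE_VLLM_CUSTOM_ALLREDUCE"] = "0"
    assert resolve_use_vllm_custom_allreduce() is True   # any non-empty string is truthy

    os.environ["USE_VLLM_CUSTOM_ALLREDUCE"] = ""
    assert resolve_use_vllm_custom_allreduce() is False  # only an empty value is falsy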
@@ -27,37 +25,8 @@ if not is_hpu():
             logger.warning("Failed to import from custom_ar with %r", e)


-def hint_on_error(fn):
-    @functools.wraps(fn)
-    def wrapper(*args, **kwargs):
-        try:
-            return fn(*args, **kwargs)
-
-        except NotImplementedError as e:
-            msg = (
-                "Error in calling custom op %s: %s\n"
-                "Not implemented or built, mostly likely because the current current device "
-                "does not support this kernel (less likely TORCH_CUDA_ARCH_LIST was set "
-                "incorrectly while building)"
-            )
-            logger.error(msg, fn.__name__, e)
-            raise NotImplementedError(msg % (fn.__name__, e)) from e
-        except AttributeError as e:
-            msg = (
-                "Error in calling custom op %s: %s\n"
-                "Possibly you have built or installed an obsolete version of vllm.\n"
-                "Please try a clean build and install of vllm,"
-                "or remove old built files such as vllm/*cpython*.so and build/ ."
-            )
-            logger.error(msg, fn.__name__, e)
-            raise e
-
-    return wrapper


 if use_vllm_custom_allreduce and not is_hip():
-    # custom ar
+    # vLLM custom allreduce
     def init_custom_ar(
         ipc_tensors: List[torch.Tensor],
         rank_data: torch.Tensor,
...
@@ -96,6 +65,7 @@ if use_vllm_custom_allreduce and not is_hip():

 else:
     if is_hip():
+        # ROCM custom allreduce
         def init_custom_ar(
             meta: torch.Tensor,
...
@@ -143,7 +113,7 @@ else:
             return sgl_kernel.ops.get_meta_buffer_ipc_handle(inp)

     else:
-        # custom ar
+        # TRTLLM custom allreduce
         def init_custom_ar(
             rank_id: int,
             world_size: int,
...
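Taken together, the comments added in this file label a three-way backend split: vLLM kernels when the environment switch is on and the platform is not ROCm, the ROCm kernels from sgl_kernel on HIP, and the TRT-LLM style kernels from sgl_kernel otherwise. A condensed, illustrative-only sketch of that selection (this helper does not exist in sglang; the real branches define init_custom_ar and related wrappers instead of returning a label):

    def select_custom_allreduce_backend(use_vllm_custom_allreduce: bool, on_rocm: bool) -> str:
        # Mirrors the top-level if/else structure of _custom_ops.py.
        if use_vllm_custom_allreduce and not on_rocm:
            return "vLLM custom allreduce"
        if on_rocm:
            return "ROCM custom allreduce"
        return "TRTLLM custom allreduce"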
@@ -176,29 +146,3 @@ else:
             fa: int, handles: List[List[int]], offsets: List[List[int]]
         ) -> None:
             sgl_kernel.ops.register_graph_buffers(fa, handles, offsets)
-
-
-# temporary fix for https://github.com/vllm-project/vllm/issues/5456
-# TODO: remove this in v0.6.0
-names_and_values = globals()
-names_and_values_to_update = {}
-# prepare variables to avoid dict size change during iteration
-k, v, arg = None, None, None
-fn_type = type(lambda x: x)
-for k, v in names_and_values.items():
-    # find functions that are defined in this file and have torch.Tensor
-    # in their annotations. `arg == "torch.Tensor"` is used to handle
-    # the case when users use `import __annotations__` to turn type
-    # hints into strings.
-    if (
-        isinstance(v, fn_type)
-        and v.__code__.co_filename == __file__
-        and any(
-            arg is torch.Tensor or arg == "torch.Tensor"
-            for arg in v.__annotations__.values()
-        )
-    ):
-        names_and_values_to_update[k] = hint_on_error(v)
-
-names_and_values.update(names_and_values_to_update)
-del names_and_values_to_update, names_and_values, v, k, fn_type
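The block removed above is a generic "decorate every matching module-level function" pass: it walks globals(), selects plain functions defined in this file whose annotations mention torch.Tensor, and rebinds them to hint_on_error-wrapped versions. A standalone sketch of the same technique with made-up names unrelated to sglang:

    import types

    def shout(fn):
        # Stand-in for hint_on_error: wrap a function and log its name on each call.
        def wrapper(*args, **kwargs):
            print(f"calling {fn.__name__}")
            return fn(*args, **kwargs)
        return wrapper

    def add(x: int, y: int) -> int:
        return x + y

    # Rebind every plain function in this module whose annotations mention int,
    # mirroring the removed loop (which keyed on torch.Tensor and this file's path).
    _updates = {
        name: shout(value)
        for name, value in globals().items()
        if isinstance(value, types.FunctionType)
        and any(ann is int for ann in value.__annotations__.values())
    }
    globals().update(_updates)

    add(1, 2)  # prints "calling add", then returns 3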
python/sglang/srt/distributed/device_communicators/custom_all_reduce.py  (view file @ a3ab768a)
...
@@ -22,17 +22,18 @@ from sglang.srt.utils import cuda_device_count_stateless, is_cuda, is_hip

 logger = logging.getLogger(__name__)

+is_hip_ = is_hip()

 if is_cuda():
     try:
         import pynvml
     except ImportError as e:
         logger.warning("Failed to import pynvml with %r", e)

-if is_hip():
+if is_hip_:
     try:
         from amdsmi import (
             AmdSmiException,
             amdsmi_get_gpu_board_info,
             amdsmi_get_processor_handles,
             amdsmi_init,
             amdsmi_shut_down,
...
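The new is_hip_ name caches the platform probe once at import time, so the call sites below read a module-level boolean instead of re-invoking is_hip(). A self-contained sketch of the pattern; the body of is_hip() here is an assumption for illustration, the real helper lives in sglang.srt.utils:

    import functools
    import torch

    @functools.lru_cache(maxsize=1)
    def is_hip() -> bool:
        # Assumed body for the sketch: torch.version.hip is a string on ROCm builds.
        return torch.version.hip is not None

    # What the diff does: evaluate once at import time and reuse the constant everywhere.
    is_hip_ = is_hip()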
@@ -42,9 +43,11 @@ if is_hip():
         logger.warning("Failed to import amdsmi with %r", e)

 try:
-    if ops.use_vllm_custom_allreduce and not is_hip():
+    if ops.use_vllm_custom_allreduce and not is_hip_:
+        # Use vLLM custom allreduce
         ops.meta_size()
     else:
+        # Use custom allreduce from sgl kernel (ROCM and TRT-LLM)
         import sgl_kernel
     custom_ar = True
 except Exception:
...
@@ -60,7 +63,7 @@ _R = TypeVar("_R")

 def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]:
     @wraps(fn)
     def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
-        if torch.version.hip:
+        if is_hip_:
             try:
                 amdsmi_init()
                 return fn(*args, **kwargs)
...
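with_nvml_context brackets the wrapped call with vendor management-library setup: the shown branch calls amdsmi_init() before dispatching, and the matching shutdown sits in lines elided from this hunk. A generic sketch of the same init-around-call decorator pattern, with purely illustrative names:

    from functools import wraps

    def with_session(init, shutdown):
        # Decorator factory: run init() before each call to fn and shutdown() after,
        # the way with_nvml_context brackets calls with amdsmi/pynvml setup.
        def decorate(fn):
            @wraps(fn)
            def wrapper(*args, **kwargs):
                init()
                try:
                    return fn(*args, **kwargs)
                finally:
                    shutdown()
            return wrapper
        return decorate

    @with_session(lambda: print("init"), lambda: print("shutdown"))
    def probe() -> str:
        return "ok"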
@@ -78,7 +81,7 @@ def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]:
 @with_nvml_context
 def is_full_nvlink(physical_device_ids: List[int], world_size: int) -> bool:
-    if is_hip():
+    if is_hip_:
         """
         query if the set of gpus are fully connected by xgmi (1 hop)
         """
...
@@ -142,7 +145,7 @@ def is_weak_contiguous(inp: torch.Tensor):
 class CustomAllreduce:

     _SUPPORTED_WORLD_SIZES = [2, 4, 6, 8]
     _MAX_CAR_SIZE = 8192 * 1024
-    if is_hip():
+    if is_hip_:
         # crossover is at 16MB buffer size for ROCm
         _MAX_CAR_SIZE = 2 * 8192 * 1024
...
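For scale, the cap above is in bytes: 8192 * 1024 is 8 MiB, and the ROCm branch doubles it to 16 MiB, which is what the "crossover is at 16MB" comment refers to.

    assert 8192 * 1024 == 8 * 2**20       # default _MAX_CAR_SIZE: 8 MiB
    assert 2 * 8192 * 1024 == 16 * 2**20  # ROCm _MAX_CAR_SIZE: 16 MiB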
@@ -226,7 +229,7 @@ class CustomAllreduce:
         # test nvlink first, this will filter out most of the cases
         # where custom allreduce is not supported
         # this checks hardware and driver support for NVLink
-        if is_cuda() or is_hip():
+        if is_cuda() or is_hip_:
             full_nvlink = is_full_nvlink(physical_device_ids, world_size)
         if world_size > 2 and not full_nvlink:
...
@@ -240,7 +243,7 @@ class CustomAllreduce:
         # this is expensive to compute at the first time
         # then we cache the result
         # On AMD GPU, p2p is always enabled between XGMI connected GPUs
-        if not is_hip() and not _can_p2p(rank, world_size):
+        if not is_hip_ and not _can_p2p(rank, world_size):
             logger.warning(
                 "Custom allreduce is disabled because your platform lacks "
                 "GPU P2P capability or P2P test failed. To silence this "
...
@@ -253,7 +256,7 @@ class CustomAllreduce:
         self.world_size = world_size
         self.full_nvlink = full_nvlink

-        if ops.use_vllm_custom_allreduce and not is_hip():
+        if ops.use_vllm_custom_allreduce and not is_hip_:
             # Buffers memory are owned by this Python class and passed to C++.
             # Meta data composes of two parts: meta data for synchronization and a
             # temporary buffer for storing intermediate allreduce results.
...
@@ -276,7 +279,7 @@ class CustomAllreduce:
             )
             ops.register_buffer(self._ptr, self.buffer_ptrs)
         else:
-            if is_hip():
+            if is_hip_:
                 # meta data buffers need to be "uncached" for signal on MI200
                 self.meta = ops.allocate_meta_buffer(ops.meta_size() + max_size)
                 self.buffer = torch.empty(
...
@@ -415,7 +418,7 @@ class CustomAllreduce:
         ops.register_buffer(self._ptr, inp, handles, offsets)

     def register_graph_buffers(self):
-        if is_hip():
+        if is_hip_:
             handle, offset = ops.get_graph_buffer_ipc_meta(self._ptr)
             handles, offsets = self._gather_ipc_meta((bytes(handle), offset))
             logger.info("Registering %d cuda graph addresses", len(offset))
...
@@ -451,12 +454,12 @@ class CustomAllreduce:
             return False
         # for 4 or more non NVLink-capable GPUs, custom allreduce provides
         # little performance improvement over NCCL.
-        if ops.use_vllm_custom_allreduce and not is_hip():
+        if ops.use_vllm_custom_allreduce and not is_hip_:
             if self.world_size == 2 or self.full_nvlink:
                 return inp_size < self.max_size
             return False

-        if is_hip():
+        if is_hip_:
             if self.full_nvlink:
                 if self.world_size == 8:
                     if self.MSCCL:
...
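The first branch above is the size/topology gate for the vLLM-backend path: only world_size == 2 or a fully NVLink-connected group qualifies, and the input must fit in the registered buffer. A condensed, illustrative sketch of just that branch (the surrounding method and the ROCm/MSCCL branch are elided; this free function is not sglang API):

    def custom_allreduce_applies(
        inp_size: int, world_size: int, full_nvlink: bool, max_size: int
    ) -> bool:
        # vLLM-backend branch from the hunk above: 2 ranks or full NVLink, and the
        # tensor must be smaller than the pre-registered buffer.
        if world_size == 2 or full_nvlink:
            return inp_size < max_size
        return False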
@@ -529,7 +532,7 @@ class CustomAllreduce:
             return None
         if self._IS_CAPTURING:
             if torch.cuda.is_current_stream_capturing():
-                if is_hip():
+                if is_hip_:
                     return self.all_reduce_reg(input)
                 else:
                     return self.all_reduce(input, registered=True)
...
@@ -538,7 +541,7 @@ class CustomAllreduce:
                 # allreduce is out-of-place.
                 return torch.empty_like(input)
         else:
-            if is_hip():
+            if is_hip_:
                 # note: outside of cuda graph context,
                 # custom allreduce incurs a cost of cudaMemcpy, which should
                 # be small(<=1% of overall latency) compared to the performance
...