sglang / Commits / a3ab768a

Commit a3ab768a (unverified), authored Mar 03, 2025 by Lianmin Zheng, committed by GitHub on Mar 03, 2025

Clean up custom allreduce (#4029)

Parent: 66301e12
Showing 2 changed files with 24 additions and 77 deletions (+24, -77):
python/sglang/srt/_custom_ops.py  (+6, -62)
python/sglang/srt/distributed/device_communicators/custom_all_reduce.py  (+18, -15)
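
Editorial note (not part of the commit): the sketch below summarizes the backend dispatch this cleanup leaves in place, reconstructed only from the branch comments added in the diffs that follow (vLLM custom allreduce, ROCM custom allreduce, TRTLLM custom allreduce). The function name, parameters, and return strings are illustrative assumptions, not sglang APIs.

    # Illustration only -- not the actual sglang module layout.
    import os

    def select_custom_allreduce_backend(hip: bool, hpu: bool) -> str:
        # Mirrors the branch structure visible in _custom_ops.py below.
        if hpu:
            return "none"  # custom allreduce is skipped entirely on HPU
        use_vllm = os.environ.get("USE_VLLM_CUSTOM_ALLREDUCE", default=True)
        if use_vllm and not hip:
            return "vLLM custom allreduce"
        if hip:
            return "ROCM custom allreduce (sgl_kernel)"
        return "TRTLLM custom allreduce (sgl_kernel)"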
python/sglang/srt/_custom_ops.py (view file @ a3ab768a)
 # Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/_custom_ops.py
-import contextlib
-import functools
-import importlib
 import logging
 import os
-from typing import TYPE_CHECKING, List, Optional, Tuple, Union
+from typing import List, Tuple

 import torch
 import torch.library
...
@@ -13,8 +10,9 @@ from sglang.srt.utils import is_hip, is_hpu
 logger = logging.getLogger(__name__)

 use_vllm_custom_allreduce = os.environ.get("USE_VLLM_CUSTOM_ALLREDUCE", default=True)

 if not is_hpu():
-    # Remove vllm dependency for custom allreduce on ROCm
+    # ROCm does not use vllm custom allreduce
     if use_vllm_custom_allreduce and not is_hip():
         try:
             import vllm._C
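
Editorial note on the environment flag kept above: os.environ.get returns the raw string whenever the variable is set, so under this exact check any non-empty value, including "0" or "false", is still truthy; the Python default True applies only when the variable is unset, and only an empty string makes the check falsy. A minimal illustration (not part of the diff):

    import os

    # Unset             -> default True     -> truthy, vLLM path eligible
    # Set to "1" or "0" -> non-empty string -> truthy, vLLM path still eligible
    # Set to ""         -> empty string     -> falsy, vLLM path skipped
    use_vllm_custom_allreduce = os.environ.get("USE_VLLM_CUSTOM_ALLREDUCE", default=True)
    print(bool(use_vllm_custom_allreduce))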
...
@@ -27,37 +25,8 @@ if not is_hpu():
             logger.warning("Failed to import from custom_ar with %r", e)

-def hint_on_error(fn):
-    @functools.wraps(fn)
-    def wrapper(*args, **kwargs):
-        try:
-            return fn(*args, **kwargs)
-
-        except NotImplementedError as e:
-            msg = (
-                "Error in calling custom op %s: %s\n"
-                "Not implemented or built, mostly likely because the current current device "
-                "does not support this kernel (less likely TORCH_CUDA_ARCH_LIST was set "
-                "incorrectly while building)"
-            )
-            logger.error(msg, fn.__name__, e)
-            raise NotImplementedError(msg % (fn.__name__, e)) from e
-        except AttributeError as e:
-            msg = (
-                "Error in calling custom op %s: %s\n"
-                "Possibly you have built or installed an obsolete version of vllm.\n"
-                "Please try a clean build and install of vllm,"
-                "or remove old built files such as vllm/*cpython*.so and build/ ."
-            )
-            logger.error(msg, fn.__name__, e)
-            raise e
-
-    return wrapper
 if use_vllm_custom_allreduce and not is_hip():
-    # custom ar
+    # vLLM custom allreduce
     def init_custom_ar(
         ipc_tensors: List[torch.Tensor],
         rank_data: torch.Tensor,
...
@@ -96,6 +65,7 @@ if use_vllm_custom_allreduce and not is_hip():
 else:
     if is_hip():
+        # ROCM custom allreduce
         def init_custom_ar(
             meta: torch.Tensor,
...
@@ -143,7 +113,7 @@ else:
             return sgl_kernel.ops.get_meta_buffer_ipc_handle(inp)

     else:
-        # custom ar
+        # TRTLLM custom allreduce
         def init_custom_ar(
             rank_id: int,
             world_size: int,
...
@@ -176,29 +146,3 @@ else:
             fa: int, handles: List[List[int]], offsets: List[List[int]]
         ) -> None:
             sgl_kernel.ops.register_graph_buffers(fa, handles, offsets)
-# temporary fix for https://github.com/vllm-project/vllm/issues/5456
-# TODO: remove this in v0.6.0
-names_and_values = globals()
-names_and_values_to_update = {}
-# prepare variables to avoid dict size change during iteration
-k, v, arg = None, None, None
-fn_type = type(lambda x: x)
-for k, v in names_and_values.items():
-    # find functions that are defined in this file and have torch.Tensor
-    # in their annotations. `arg == "torch.Tensor"` is used to handle
-    # the case when users use `import __annotations__` to turn type
-    # hints into strings.
-    if (
-        isinstance(v, fn_type)
-        and v.__code__.co_filename == __file__
-        and any(
-            arg is torch.Tensor or arg == "torch.Tensor"
-            for arg in v.__annotations__.values()
-        )
-    ):
-        names_and_values_to_update[k] = hint_on_error(v)
-
-names_and_values.update(names_and_values_to_update)
-del names_and_values_to_update, names_and_values, v, k, fn_type
python/sglang/srt/distributed/device_communicators/custom_all_reduce.py (view file @ a3ab768a)
...
@@ -22,17 +22,18 @@ from sglang.srt.utils import cuda_device_count_stateless, is_cuda, is_hip
 logger = logging.getLogger(__name__)

+is_hip_ = is_hip()
+
 if is_cuda():
     try:
         import pynvml
     except ImportError as e:
         logger.warning("Failed to import pynvml with %r", e)

-if is_hip():
+if is_hip_:
     try:
         from amdsmi import (
             AmdSmiException,
-            amdsmi_get_gpu_board_info,
             amdsmi_get_processor_handles,
             amdsmi_init,
             amdsmi_shut_down,
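
Editorial note: the hunk above caches the platform probe in a module-level boolean (is_hip_) so later code paths reuse it instead of calling is_hip() each time. A standalone sketch of that pattern follows; the is_hip body shown here is an assumption for illustration, not copied from sglang.srt.utils.

    import torch

    def is_hip() -> bool:
        # Assumed probe: HIP builds of PyTorch report a non-None torch.version.hip.
        return torch.version.hip is not None

    # Evaluated once at import time; later checks reuse the cached boolean.
    is_hip_ = is_hip()

    def max_car_size() -> int:
        # Example consumer, using the same constants as CustomAllreduce below.
        return 2 * 8192 * 1024 if is_hip_ else 8192 * 1024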
...
@@ -42,9 +43,11 @@ if is_hip():
         logger.warning("Failed to import amdsmi with %r", e)

 try:
-    if ops.use_vllm_custom_allreduce and not is_hip():
+    if ops.use_vllm_custom_allreduce and not is_hip_:
+        # Use vLLM custom allreduce
         ops.meta_size()
     else:
+        # Use custom allreduce from sgl kernel (ROCM and TRT-LLM)
         import sgl_kernel
     custom_ar = True
 except Exception:
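
Editorial note: the try/except above is an availability probe; if the selected kernel package loads, custom_ar flips to True, otherwise it stays False and the custom allreduce path is simply skipped. A self-contained sketch of that probe-and-flag pattern (stand-in probe; the real code calls ops.meta_size() or imports sgl_kernel):

    import importlib

    custom_ar = False
    try:
        # Stand-in probe: succeeds only if the kernel package is importable.
        importlib.import_module("sgl_kernel")
        custom_ar = True
    except Exception:
        # Leave custom_ar False; callers skip the custom allreduce path.
        pass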
...
@@ -60,7 +63,7 @@ _R = TypeVar("_R")
 def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]:
     @wraps(fn)
     def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
-        if torch.version.hip:
+        if is_hip_:
             try:
                 amdsmi_init()
                 return fn(*args, **kwargs)
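
Editorial note: with_nvml_context (partially shown above) brackets the wrapped call with an SMI init/shutdown pair so topology queries such as is_full_nvlink run inside an initialized session. A minimal standalone sketch of that decorator shape, with stub functions standing in for the amdsmi/pynvml calls:

    from functools import wraps
    from typing import Callable, TypeVar

    _R = TypeVar("_R")

    def smi_init() -> None:  # stand-in for amdsmi_init() / pynvml.nvmlInit()
        pass

    def smi_shutdown() -> None:  # stand-in for amdsmi_shut_down() / pynvml.nvmlShutdown()
        pass

    def with_smi_context(fn: Callable[..., _R]) -> Callable[..., _R]:
        @wraps(fn)
        def wrapper(*args, **kwargs) -> _R:
            smi_init()
            try:
                return fn(*args, **kwargs)
            finally:
                smi_shutdown()
        return wrapper

    @with_smi_context
    def query_topology() -> str:
        return "runs inside an initialized SMI session"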
...
@@ -78,7 +81,7 @@ def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]:
 @with_nvml_context
 def is_full_nvlink(physical_device_ids: List[int], world_size: int) -> bool:
-    if is_hip():
+    if is_hip_:
         """
         query if the set of gpus are fully connected by xgmi (1 hop)
         """
...
@@ -142,7 +145,7 @@ def is_weak_contiguous(inp: torch.Tensor):
 class CustomAllreduce:
     _SUPPORTED_WORLD_SIZES = [2, 4, 6, 8]
     _MAX_CAR_SIZE = 8192 * 1024
-    if is_hip():
+    if is_hip_:
         # crossover is at 16MB buffer size for ROCm
         _MAX_CAR_SIZE = 2 * 8192 * 1024
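
Editorial note: the constants in the hunk above work out as follows, matching the "crossover is at 16MB" comment:

    _MAX_CAR_SIZE_CUDA = 8192 * 1024      # 8,388,608 bytes  = 8 MiB
    _MAX_CAR_SIZE_ROCM = 2 * 8192 * 1024  # 16,777,216 bytes = 16 MiB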
...
@@ -226,7 +229,7 @@ class CustomAllreduce:
         # test nvlink first, this will filter out most of the cases
         # where custom allreduce is not supported
         # this checks hardware and driver support for NVLink
-        if is_cuda() or is_hip():
+        if is_cuda() or is_hip_:
             full_nvlink = is_full_nvlink(physical_device_ids, world_size)
         if world_size > 2 and not full_nvlink:
...
@@ -240,7 +243,7 @@ class CustomAllreduce:
         # this is expensive to compute at the first time
         # then we cache the result
         # On AMD GPU, p2p is always enabled between XGMI connected GPUs
-        if not is_hip() and not _can_p2p(rank, world_size):
+        if not is_hip_ and not _can_p2p(rank, world_size):
             logger.warning(
                 "Custom allreduce is disabled because your platform lacks "
                 "GPU P2P capability or P2P test failed. To silence this "
...
@@ -253,7 +256,7 @@ class CustomAllreduce:
         self.world_size = world_size
         self.full_nvlink = full_nvlink
-        if ops.use_vllm_custom_allreduce and not is_hip():
+        if ops.use_vllm_custom_allreduce and not is_hip_:
             # Buffers memory are owned by this Python class and passed to C++.
             # Meta data composes of two parts: meta data for synchronization and a
             # temporary buffer for storing intermediate allreduce results.
...
@@ -276,7 +279,7 @@ class CustomAllreduce:
             )
             ops.register_buffer(self._ptr, self.buffer_ptrs)
         else:
-            if is_hip():
+            if is_hip_:
                 # meta data buffers need to be "uncached" for signal on MI200
                 self.meta = ops.allocate_meta_buffer(ops.meta_size() + max_size)
                 self.buffer = torch.empty(
...
@@ -415,7 +418,7 @@ class CustomAllreduce:
         ops.register_buffer(self._ptr, inp, handles, offsets)

     def register_graph_buffers(self):
-        if is_hip():
+        if is_hip_:
             handle, offset = ops.get_graph_buffer_ipc_meta(self._ptr)
             handles, offsets = self._gather_ipc_meta((bytes(handle), offset))
             logger.info("Registering %d cuda graph addresses", len(offset))
...
@@ -451,12 +454,12 @@ class CustomAllreduce:
             return False
         # for 4 or more non NVLink-capable GPUs, custom allreduce provides
         # little performance improvement over NCCL.
-        if ops.use_vllm_custom_allreduce and not is_hip():
+        if ops.use_vllm_custom_allreduce and not is_hip_:
             if self.world_size == 2 or self.full_nvlink:
                 return inp_size < self.max_size
             return False

-        if is_hip():
+        if is_hip_:
             if self.full_nvlink:
                 if self.world_size == 8:
                     if self.MSCCL:
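
Editorial note: the size/topology gate visible in the hunk above reduces to the following standalone sketch (simplified; the attribute names follow the diff, the function name is illustrative):

    def should_use_custom_allreduce(
        inp_size: int, world_size: int, full_nvlink: bool, max_size: int
    ) -> bool:
        # Custom allreduce pays off for 2 ranks, or when every GPU pair is
        # directly connected (NVLink on CUDA, xGMI on ROCm); otherwise, per the
        # comment in the diff, it offers little over NCCL.
        if world_size == 2 or full_nvlink:
            return inp_size < max_size
        return False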
...
@@ -529,7 +532,7 @@ class CustomAllreduce:
             return None
         if self._IS_CAPTURING:
             if torch.cuda.is_current_stream_capturing():
-                if is_hip():
+                if is_hip_:
                     return self.all_reduce_reg(input)
                 else:
                     return self.all_reduce(input, registered=True)
...
@@ -538,7 +541,7 @@ class CustomAllreduce:
             # allreduce is out-of-place.
             return torch.empty_like(input)
         else:
-            if is_hip():
+            if is_hip_:
                 # note: outside of cuda graph context,
                 # custom allreduce incurs a cost of cudaMemcpy, which should
                 # be small(<=1% of overall latency) compared to the performance
...