sglang / Commits / aba5ca15

Commit aba5ca15 (Unverified)
Authored Apr 06, 2025 by Yi Zhang; committed via GitHub on Apr 05, 2025
Parent: 496dde84

python transfer custom allreduce from trt kernel to vllm kernel (#5080)

Showing 4 changed files with 86 additions and 171 deletions (+86 −171)
python/pyproject.toml (+1 −1)
python/sglang/srt/_custom_ops.py (+59 −92)
python/sglang/srt/distributed/device_communicators/custom_all_reduce.py (+25 −77)
scripts/ci_install_dependency.sh (+1 −1)
python/pyproject.toml

@@ -47,7 +47,7 @@ runtime_common = [
 srt = [
     "sglang[runtime_common]",
-    "sgl-kernel==0.0.7",
+    "sgl-kernel==0.0.8",
     "flashinfer_python==0.2.3",
     "torch==2.5.1",
     "cuda-python",
python/sglang/srt/_custom_ops.py

@@ -27,17 +27,20 @@ if not is_hpu():
         logger.warning("Failed to import from custom_ar with %r", e)

-if use_vllm_custom_allreduce and not is_hip():
-    # vLLM custom allreduce
+if not is_hip():
+    if use_vllm_custom_allreduce:
+        custom_op = torch.ops._C_custom_ar
+    else:
+        custom_op = sgl_kernel.allreduce
+
+    # custom allreduce
     def init_custom_ar(
         ipc_tensors: List[torch.Tensor],
         rank_data: torch.Tensor,
         rank: int,
         full_nvlink: bool,
     ) -> int:
-        return torch.ops._C_custom_ar.init_custom_ar(
-            ipc_tensors, rank_data, rank, full_nvlink
-        )
+        return custom_op.init_custom_ar(ipc_tensors, rank_data, rank, full_nvlink)

     def all_reduce(
         fa: int,
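The key change in this hunk is that both CUDA backends are now reached through one module-level handle, custom_op, chosen once at import time, so each wrapper body is written exactly once. A minimal, runnable sketch of that pattern (the _backend_* namespaces are hypothetical stand-ins for torch.ops._C_custom_ar and sgl_kernel.allreduce):

import types

# Hypothetical stand-ins; the real objects are torch.ops._C_custom_ar
# (vLLM kernel) and sgl_kernel.allreduce (sgl-kernel fallback).
_backend_vllm = types.SimpleNamespace(meta_size=lambda: 1024)
_backend_sgl = types.SimpleNamespace(meta_size=lambda: 2048)

use_vllm_custom_allreduce = True

# Selected once at import time, exactly like the hunk above.
custom_op = _backend_vllm if use_vllm_custom_allreduce else _backend_sgl

def meta_size() -> int:
    # Thin wrapper: callers never see which backend was picked.
    return custom_op.meta_size()

print(meta_size())  # 1024 when the vLLM-style backend is selected

This only works because both backends expose the same function surface; the wrappers merely pin down argument and return types.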
@@ -46,105 +49,69 @@ if use_vllm_custom_allreduce and not is_hip():
         reg_buffer: int,
         reg_buffer_sz_bytes: int,
     ) -> None:
-        torch.ops._C_custom_ar.all_reduce(fa, inp, out, reg_buffer, reg_buffer_sz_bytes)
+        custom_op.all_reduce(fa, inp, out, reg_buffer, reg_buffer_sz_bytes)

     def dispose(fa: int) -> None:
-        torch.ops._C_custom_ar.dispose(fa)
+        custom_op.dispose(fa)

     def meta_size() -> int:
-        return torch.ops._C_custom_ar.meta_size()
+        return custom_op.meta_size()

     def register_buffer(fa: int, ipc_tensors: List[int]) -> None:
-        return torch.ops._C_custom_ar.register_buffer(fa, ipc_tensors)
+        return custom_op.register_buffer(fa, ipc_tensors)

     def get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[int], List[int]]:
-        return torch.ops._C_custom_ar.get_graph_buffer_ipc_meta(fa)
+        return custom_op.get_graph_buffer_ipc_meta(fa)

     def register_graph_buffers(
         fa: int, handles: List[List[int]], offsets: List[List[int]]
     ) -> None:
-        torch.ops._C_custom_ar.register_graph_buffers(fa, handles, offsets)
+        custom_op.register_graph_buffers(fa, handles, offsets)

 else:
-    if is_hip():
-        # ROCM custom allreduce
-        def init_custom_ar(
-            meta: torch.Tensor,
-            rank_data: torch.Tensor,
-            handles: List[str],
-            offsets: List[int],
-            rank: int,
-            full_nvlink: bool,
-        ) -> int:
-            return sgl_kernel.allreduce.init_custom_ar(
-                meta, rank_data, handles, offsets, rank, full_nvlink
-            )
-
-        def all_reduce_reg(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None:
-            sgl_kernel.allreduce.all_reduce_reg(fa, inp, out)
-
-        def all_reduce_unreg(
-            fa: int, inp: torch.Tensor, reg_buffer: torch.Tensor, out: torch.Tensor
-        ) -> None:
-            sgl_kernel.allreduce.all_reduce_unreg(fa, inp, reg_buffer, out)
-
-        def dispose(fa: int) -> None:
-            sgl_kernel.allreduce.dispose(fa)
-
-        def meta_size() -> int:
-            return sgl_kernel.allreduce.meta_size()
-
-        def register_buffer(
-            fa: int, t: torch.Tensor, handles: List[str], offsets: List[int]
-        ) -> None:
-            return sgl_kernel.allreduce.register_buffer(fa, t, handles, offsets)
-
-        def get_graph_buffer_ipc_meta(fa: int) -> Tuple[torch.Tensor, List[int]]:
-            return sgl_kernel.allreduce.get_graph_buffer_ipc_meta(fa)
-
-        def register_graph_buffers(
-            fa: int, handles: List[str], offsets: List[List[int]]
-        ) -> None:
-            sgl_kernel.allreduce.register_graph_buffers(fa, handles, offsets)
-
-        def allocate_meta_buffer(size: int) -> torch.Tensor:
-            return sgl_kernel.allreduce.allocate_meta_buffer(size)
-
-        def get_meta_buffer_ipc_handle(inp: torch.Tensor) -> torch.Tensor:
-            return sgl_kernel.allreduce.get_meta_buffer_ipc_handle(inp)
-    else:
-        # TRTLLM custom allreduce
-        def init_custom_ar(
-            rank_id: int,
-            world_size: int,
-            rank_data_base: torch.Tensor,
-            buffers: List[int],
-            tmp_result_buffers: List[int],
-            barrier_in: List[int],
-            barrier_out: List[int],
-        ) -> int:
-            return sgl_kernel.init_custom_reduce(
-                rank_id,
-                world_size,
-                rank_data_base,
-                buffers,
-                tmp_result_buffers,
-                barrier_in,
-                barrier_out,
-            )
-
-        def all_reduce(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None:
-            sgl_kernel.custom_reduce(fa, inp, out)
-
-        def dispose(fa: int) -> None:
-            sgl_kernel.custom_dispose(fa)
-
-        def get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[int], List[int]]:
-            return sgl_kernel.get_graph_buffer_ipc_meta(fa)
-
-        def register_graph_buffers(
-            fa: int, handles: List[List[int]], offsets: List[List[int]]
-        ) -> None:
-            sgl_kernel.register_graph_buffers(fa, handles, offsets)
+    # ROCM custom allreduce
+    def init_custom_ar(
+        meta: torch.Tensor,
+        rank_data: torch.Tensor,
+        handles: List[str],
+        offsets: List[int],
+        rank: int,
+        full_nvlink: bool,
+    ) -> int:
+        return sgl_kernel.allreduce.init_custom_ar(
+            meta, rank_data, handles, offsets, rank, full_nvlink
+        )
+
+    def all_reduce_reg(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None:
+        sgl_kernel.allreduce.all_reduce_reg(fa, inp, out)
+
+    def all_reduce_unreg(
+        fa: int, inp: torch.Tensor, reg_buffer: torch.Tensor, out: torch.Tensor
+    ) -> None:
+        sgl_kernel.allreduce.all_reduce_unreg(fa, inp, reg_buffer, out)
+
+    def dispose(fa: int) -> None:
+        sgl_kernel.allreduce.dispose(fa)
+
+    def meta_size() -> int:
+        return sgl_kernel.allreduce.meta_size()
+
+    def register_buffer(
+        fa: int, t: torch.Tensor, handles: List[str], offsets: List[int]
+    ) -> None:
+        return sgl_kernel.allreduce.register_buffer(fa, t, handles, offsets)
+
+    def get_graph_buffer_ipc_meta(fa: int) -> Tuple[torch.Tensor, List[int]]:
+        return sgl_kernel.allreduce.get_graph_buffer_ipc_meta(fa)
+
+    def register_graph_buffers(
+        fa: int, handles: List[str], offsets: List[List[int]]
+    ) -> None:
+        sgl_kernel.allreduce.register_graph_buffers(fa, handles, offsets)
+
+    def allocate_meta_buffer(size: int) -> torch.Tensor:
+        return sgl_kernel.allreduce.allocate_meta_buffer(size)
+
+    def get_meta_buffer_ipc_handle(inp: torch.Tensor) -> torch.Tensor:
+        return sgl_kernel.allreduce.get_meta_buffer_ipc_handle(inp)
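Read together, the surviving CUDA wrappers define a single lifecycle regardless of backend: init_custom_ar, register_buffer, all_reduce, dispose. The sketch below walks through that order with a hypothetical in-process stand-in (_FakeBackend is invented for illustration; real handles come from the C++ extensions):

from typing import List
import torch

class _FakeBackend:
    """Hypothetical stand-in that mimics the wrapper surface above."""
    def init_custom_ar(self, ipc_tensors, rank_data, rank, full_nvlink) -> int:
        return 1  # opaque context handle, called `fa` in the wrappers
    def register_buffer(self, fa: int, ipc_tensors: List[int]) -> None:
        pass  # real backend maps peers' IPC buffers into this process
    def all_reduce(self, fa, inp, out, reg_buffer, reg_buffer_sz_bytes) -> None:
        out.copy_(inp)  # single "rank": the reduction is the identity
    def dispose(self, fa: int) -> None:
        pass

custom_op = _FakeBackend()

# Lifecycle order used by CustomAllreduce: init -> register -> reduce -> dispose.
rank_data = torch.empty(8 * 1024 * 1024, dtype=torch.uint8)
fa = custom_op.init_custom_ar([], rank_data, rank=0, full_nvlink=True)
custom_op.register_buffer(fa, [])
inp, out = torch.ones(4), torch.empty(4)
custom_op.all_reduce(fa, inp, out, 0, 0)  # (0, 0) == input pre-registered
custom_op.dispose(fa)
print(out)  # tensor([1., 1., 1., 1.])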
python/sglang/srt/distributed/device_communicators/custom_all_reduce.py

@@ -257,7 +257,7 @@ class CustomAllreduce:
         self.world_size = world_size
         self.full_nvlink = full_nvlink

-        if ops.use_vllm_custom_allreduce and not _is_hip:
+        if not _is_hip:
             # Buffers memory are owned by this Python class and passed to C++.
             # Meta data composes of two parts: meta data for synchronization and a
             # temporary buffer for storing intermediate allreduce results.
@@ -280,56 +280,24 @@ class CustomAllreduce:
             )
             ops.register_buffer(self._ptr, self.buffer_ptrs)
         else:
-            if _is_hip:
-                # meta data buffers need to be "uncached" for signal on MI200
-                self.meta = ops.allocate_meta_buffer(ops.meta_size() + max_size)
-                self.buffer = torch.empty(
-                    max_size, dtype=torch.uint8, device=self.device
-                )
-                handle = ops.get_meta_buffer_ipc_handle(self.meta)
-                shard_data = (
-                    bytes(handle),  # ipc handle to base ptr
-                    0,  # offset of base ptr
-                )
-                handles, offsets = self._gather_ipc_meta(shard_data)
-                self.rank_data = torch.empty(
-                    8 * 1024 * 1024, dtype=torch.uint8, device=self.device
-                )
-                self._ptr = ops.init_custom_ar(
-                    self.meta, self.rank_data, handles, offsets, rank, self.full_nvlink
-                )
-                self.register_buffer(self.buffer)
-                self.MSCCL = os.getenv("RCCL_MSCCL_ENABLE", "1") == "1"
-            else:
-                # From TensorRT-LLM getMaxRequiredWorkspaceSize
-                self.max_required_workspace_size = [16 * 1024 * 1024, 8 * 1024 * 1024]
-                # sizeof(uint32_t) * (MAX_ALL_REDUCE_BLOCKS + 2) * MAX_RANKS_PER_NODE;
-                self.barrier_max_size = 8 * (36 + 2) * 8
-                self.buffer_ptrs = self.create_shared_buffer(max_size, group=group)
-                self.tmp_result_buffer_ptrs = self.create_shared_buffer(
-                    max_size, group=group
-                )
-                self.rank_data_base = torch.empty(
-                    8 * 1024 * 1024, dtype=torch.uint8, device=self.device
-                )
-                self.barrier_in_ptrs = self.create_shared_buffer(
-                    self.barrier_max_size, group=group
-                )
-                self.barrier_out_ptrs = self.create_shared_buffer(
-                    self.barrier_max_size, group=group
-                )
-                self._ptr = ops.init_custom_ar(
-                    rank,
-                    world_size,
-                    self.rank_data_base,
-                    self.buffer_ptrs,
-                    self.tmp_result_buffer_ptrs,
-                    self.barrier_in_ptrs,
-                    self.barrier_out_ptrs,
-                )
+            # meta data buffers need to be "uncached" for signal on MI200
+            self.meta = ops.allocate_meta_buffer(ops.meta_size() + max_size)
+            self.buffer = torch.empty(max_size, dtype=torch.uint8, device=self.device)
+            handle = ops.get_meta_buffer_ipc_handle(self.meta)
+            shard_data = (
+                bytes(handle),  # ipc handle to base ptr
+                0,  # offset of base ptr
+            )
+            handles, offsets = self._gather_ipc_meta(shard_data)
+            self.rank_data = torch.empty(
+                8 * 1024 * 1024, dtype=torch.uint8, device=self.device
+            )
+            self._ptr = ops.init_custom_ar(
+                self.meta, self.rank_data, handles, offsets, rank, self.full_nvlink
+            )
+            self.register_buffer(self.buffer)
+            self.MSCCL = os.getenv("RCCL_MSCCL_ENABLE", "1") == "1"
         self.disabled = False

     @staticmethod
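On the surviving ROCm path, each rank must learn every peer's IPC handle before ops.init_custom_ar can map the shared metadata buffers; self._gather_ipc_meta(shard_data) performs that exchange. A single-process sketch of the all-gather-and-unzip shape such an exchange plausibly has (the helper name and tuple layout mirror the code above; the real implementation may differ):

import torch.distributed as dist

def gather_ipc_meta(shard_data, group=None):
    # Collect one (ipc_handle_bytes, base_ptr_offset) tuple from every rank,
    # then split the pairs into parallel handle/offset lists.
    world_size = dist.get_world_size(group=group)
    all_data = [None] * world_size
    dist.all_gather_object(all_data, shard_data, group=group)
    handles = [d[0] for d in all_data]
    offsets = [d[1] for d in all_data]
    return handles, offsets

if __name__ == "__main__":
    # world_size=1 so the example runs standalone on CPU (gloo backend).
    dist.init_process_group(
        "gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1
    )
    print(gather_ipc_meta((b"fake-ipc-handle", 0)))  # ([b'fake-ipc-handle'], [0])
    dist.destroy_process_group()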
@@ -455,7 +423,7 @@ class CustomAllreduce:
             return False
         # for 4 or more non NVLink-capable GPUs, custom allreduce provides
         # little performance improvement over NCCL.
-        if ops.use_vllm_custom_allreduce and not _is_hip:
+        if not _is_hip:
             if self.world_size == 2 or self.full_nvlink:
                 return inp_size < self.max_size
             return False
@@ -471,18 +439,6 @@ class CustomAllreduce:
                 return inp_size < self.max_size
             return False
-        if self.world_size == 2:
-            return (
-                inp_size < self.max_size
-                and inp_size < self.max_required_workspace_size[0]
-            )
-
-        if self.full_nvlink:
-            return (
-                inp_size < self.max_size
-                and inp_size < self.max_required_workspace_size[1]
-            )
-
-        return False

     # all reduce, assuming inp tensor is IPC registered with register_buffer,
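For reference, the deleted branch gated the TRT kernel on two thresholds taken from TensorRT-LLM's getMaxRequiredWorkspaceSize: 16 MiB for the 2-rank case and 8 MiB for the full-NVLink case, on top of the shared-buffer capacity check. A standalone restatement of that removed logic:

# Thresholds copied from the deleted code above:
# index 0 -> world_size == 2, index 1 -> full NVLink.
MAX_REQUIRED_WORKSPACE_SIZE = [16 * 1024 * 1024, 8 * 1024 * 1024]

def trt_should_custom_ar(
    inp_size: int, max_size: int, world_size: int, full_nvlink: bool
) -> bool:
    # The message had to fit both the shared buffer and the kernel workspace.
    if world_size == 2:
        return inp_size < max_size and inp_size < MAX_REQUIRED_WORKSPACE_SIZE[0]
    if full_nvlink:
        return inp_size < max_size and inp_size < MAX_REQUIRED_WORKSPACE_SIZE[1]
    return False

# 4 MiB message, 8 MiB shared buffer, 2 ranks: passes both checks.
print(trt_should_custom_ar(4 * 1024 * 1024, 8 * 1024 * 1024, 2, True))  # True

The vLLM path keeps only the inp_size < self.max_size check, which is why the extra workspace bookkeeping can be deleted.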
@@ -515,15 +471,12 @@ class CustomAllreduce:
         """
         if out is None:
             out = torch.empty_like(inp)
-        if ops.use_vllm_custom_allreduce:
-            if registered:
-                ops.all_reduce(self._ptr, inp, out, 0, 0)
-            else:
-                ops.all_reduce(
-                    self._ptr, inp, out, self.buffer_ptrs[self.rank], self.max_size
-                )
+        if registered:
+            ops.all_reduce(self._ptr, inp, out, 0, 0)
         else:
-            ops.all_reduce(self._ptr, inp, out)
+            ops.all_reduce(
+                self._ptr, inp, out, self.buffer_ptrs[self.rank], self.max_size
+            )
         return out

     def custom_all_reduce(self, input: torch.Tensor) -> Optional[torch.Tensor]:
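In the surviving path, `registered` distinguishes inputs that were IPC-registered up front, where the kernel reads peers' memory directly and the staging arguments are passed as (0, 0), from ad-hoc inputs, which are first copied through the rank-local shared buffer self.buffer_ptrs[self.rank] of capacity self.max_size. The TRT-only three-argument call shape disappears along with its kernel.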
@@ -554,14 +507,9 @@ class CustomAllreduce:
     def close(self):
         if not self.disabled and self._ptr:
             ops.dispose(self._ptr)
-            if ops.use_vllm_custom_allreduce:
-                if _is_cuda:
-                    self.free_shared_buffer(self.meta_ptrs)
-                    self.free_shared_buffer(self.buffer_ptrs)
-            elif _is_cuda:
-                self.free_shared_buffer(self.buffer_ptrs)
-                self.free_shared_buffer(self.tmp_result_buffer_ptrs)
-                self.free_shared_buffer(self.barrier_in_ptrs)
-                self.free_shared_buffer(self.barrier_out_ptrs)
+            if _is_cuda:
+                self.free_shared_buffer(self.meta_ptrs)
+                self.free_shared_buffer(self.buffer_ptrs)
             self._ptr = 0

     def __del__(self):
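Teardown becomes uniform as well: dispose the kernel context first, then free the two CUDA IPC regions (meta_ptrs, buffer_ptrs). The TRT-specific temp-result and barrier buffers no longer exist, so their free_shared_buffer calls go away with the branch.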
scripts/ci_install_dependency.sh

@@ -20,7 +20,7 @@ pip install --upgrade pip
 # Install flashinfer and sgl-kernel
 pip install flashinfer_python==0.2.3 --find-links ${FLASHINFER_REPO} --no-cache-dir
-pip install sgl-kernel==0.0.7 --no-cache-dir
+pip install sgl-kernel==0.0.8 --no-cache-dir

 # Install the main package
 pip install -e "python[all]" --find-links ${FLASHINFER_REPO}