python transfer custom allreduce from trt kernel to vllm kernel (#5080)

sglang · commit aba5ca15 (unverified), parent 496dde84
Authored by Yi Zhang, Apr 06, 2025; committed by GitHub, Apr 05, 2025
Showing 4 changed files with 86 additions and 171 deletions (+86 −171):
python/pyproject.toml (+1 −1)
python/sglang/srt/_custom_ops.py (+59 −92)
python/sglang/srt/distributed/device_communicators/custom_all_reduce.py (+25 −77)
scripts/ci_install_dependency.sh (+1 −1)
python/pyproject.toml (view file @ aba5ca15)

```diff
@@ -47,7 +47,7 @@ runtime_common = [
 srt = [
     "sglang[runtime_common]",
-    "sgl-kernel==0.0.7",
+    "sgl-kernel==0.0.8",
     "flashinfer_python==0.2.3",
     "torch==2.5.1",
     "cuda-python",
```
python/sglang/srt/_custom_ops.py (view file @ aba5ca15)

```diff
@@ -27,17 +27,20 @@ if not is_hpu():
         logger.warning("Failed to import from custom_ar with %r", e)


-if use_vllm_custom_allreduce and not is_hip():
-    # vLLM custom allreduce
+if not is_hip():
+    if use_vllm_custom_allreduce:
+        custom_op = torch.ops._C_custom_ar
+    else:
+        custom_op = sgl_kernel.allreduce
+
+    # custom allreduce
     def init_custom_ar(
         ipc_tensors: List[torch.Tensor],
         rank_data: torch.Tensor,
         rank: int,
         full_nvlink: bool,
     ) -> int:
-        return torch.ops._C_custom_ar.init_custom_ar(
+        return custom_op.init_custom_ar(
             ipc_tensors, rank_data, rank, full_nvlink
         )

     def all_reduce(
         fa: int,
```
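The key move in this hunk: instead of hard-coding `torch.ops._C_custom_ar`, the module binds a single `custom_op` object once at import time and routes every wrapper through it. A minimal, runnable sketch of that dispatch pattern, using hypothetical stub objects in place of `torch.ops._C_custom_ar` and `sgl_kernel.allreduce`:

```python
# Sketch of the backend-dispatch pattern introduced above. The two stubs are
# hypothetical stand-ins for torch.ops._C_custom_ar and sgl_kernel.allreduce;
# only the select-once-then-delegate shape mirrors the real module.
from types import SimpleNamespace

use_vllm_custom_allreduce = True  # in sglang this flag is set elsewhere in the module

vllm_backend = SimpleNamespace(meta_size=lambda: 16)  # stub backend
sgl_backend = SimpleNamespace(meta_size=lambda: 32)   # stub backend

# Bind the backend a single time at import...
custom_op = vllm_backend if use_vllm_custom_allreduce else sgl_backend

# ...then every public wrapper delegates to it, so call sites never branch.
def meta_size() -> int:
    return custom_op.meta_size()

print(meta_size())  # 16 with the vLLM stub selected
```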
```diff
@@ -46,27 +49,26 @@ if use_vllm_custom_allreduce and not is_hip():
         reg_buffer: int,
         reg_buffer_sz_bytes: int,
     ) -> None:
-        torch.ops._C_custom_ar.all_reduce(fa, inp, out, reg_buffer, reg_buffer_sz_bytes)
+        custom_op.all_reduce(fa, inp, out, reg_buffer, reg_buffer_sz_bytes)

     def dispose(fa: int) -> None:
-        torch.ops._C_custom_ar.dispose(fa)
+        custom_op.dispose(fa)

     def meta_size() -> int:
-        return torch.ops._C_custom_ar.meta_size()
+        return custom_op.meta_size()

     def register_buffer(fa: int, ipc_tensors: List[int]) -> None:
-        return torch.ops._C_custom_ar.register_buffer(fa, ipc_tensors)
+        return custom_op.register_buffer(fa, ipc_tensors)

     def get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[int], List[int]]:
-        return torch.ops._C_custom_ar.get_graph_buffer_ipc_meta(fa)
+        return custom_op.get_graph_buffer_ipc_meta(fa)

     def register_graph_buffers(
         fa: int, handles: List[List[int]], offsets: List[List[int]]
     ) -> None:
-        torch.ops._C_custom_ar.register_graph_buffers(fa, handles, offsets)
+        custom_op.register_graph_buffers(fa, handles, offsets)

 else:
-    if is_hip():
     # ROCM custom allreduce
     def init_custom_ar(
```
```diff
@@ -113,38 +115,3 @@ else:
     def get_meta_buffer_ipc_handle(inp: torch.Tensor) -> torch.Tensor:
         return sgl_kernel.allreduce.get_meta_buffer_ipc_handle(inp)
-
-else:
-    # TRTLLM custom allreduce
-    def init_custom_ar(
-        rank_id: int,
-        world_size: int,
-        rank_data_base: torch.Tensor,
-        buffers: List[int],
-        tmp_result_buffers: List[int],
-        barrier_in: List[int],
-        barrier_out: List[int],
-    ) -> int:
-        return sgl_kernel.init_custom_reduce(
-            rank_id,
-            world_size,
-            rank_data_base,
-            buffers,
-            tmp_result_buffers,
-            barrier_in,
-            barrier_out,
-        )
-
-    def all_reduce(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None:
-        sgl_kernel.custom_reduce(fa, inp, out)
-
-    def dispose(fa: int) -> None:
-        sgl_kernel.custom_dispose(fa)
-
-    def get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[int], List[int]]:
-        return sgl_kernel.get_graph_buffer_ipc_meta(fa)
-
-    def register_graph_buffers(
-        fa: int, handles: List[List[int]], offsets: List[List[int]]
-    ) -> None:
-        sgl_kernel.register_graph_buffers(fa, handles, offsets)
```
python/sglang/srt/distributed/device_communicators/custom_all_reduce.py (view file @ aba5ca15)

```diff
@@ -257,7 +257,7 @@ class CustomAllreduce:
         self.world_size = world_size
         self.full_nvlink = full_nvlink
-        if ops.use_vllm_custom_allreduce and not _is_hip:
+        if not _is_hip:
             # Buffers memory are owned by this Python class and passed to C++.
             # Meta data composes of two parts: meta data for synchronization and a
             # temporary buffer for storing intermediate allreduce results.
```
```diff
@@ -280,12 +280,9 @@ class CustomAllreduce:
             )
             ops.register_buffer(self._ptr, self.buffer_ptrs)
         else:
-            if _is_hip:
             # meta data buffers need to be "uncached" for signal on MI200
             self.meta = ops.allocate_meta_buffer(ops.meta_size() + max_size)
             self.buffer = torch.empty(
                 max_size, dtype=torch.uint8, device=self.device
             )
             handle = ops.get_meta_buffer_ipc_handle(self.meta)
             shard_data = (
                 bytes(handle),  # ipc handle to base ptr
```
```diff
@@ -300,36 +297,7 @@ class CustomAllreduce:
             )
             self.register_buffer(self.buffer)
             self.MSCCL = os.getenv("RCCL_MSCCL_ENABLE", "1") == "1"
-            else:
-                # From TensorRT-LLM getMaxRequiredWorkspaceSize
-                self.max_required_workspace_size = [16 * 1024 * 1024, 8 * 1024 * 1024]
-                # sizeof(uint32_t) * (MAX_ALL_REDUCE_BLOCKS + 2) * MAX_RANKS_PER_NODE;
-                self.barrier_max_size = 8 * (36 + 2) * 8
-                self.buffer_ptrs = self.create_shared_buffer(max_size, group=group)
-                self.tmp_result_buffer_ptrs = self.create_shared_buffer(
-                    max_size, group=group
-                )
-                self.rank_data_base = torch.empty(
-                    8 * 1024 * 1024, dtype=torch.uint8, device=self.device
-                )
-                self.barrier_in_ptrs = self.create_shared_buffer(
-                    self.barrier_max_size, group=group
-                )
-                self.barrier_out_ptrs = self.create_shared_buffer(
-                    self.barrier_max_size, group=group
-                )
-                self._ptr = ops.init_custom_ar(
-                    rank,
-                    world_size,
-                    self.rank_data_base,
-                    self.buffer_ptrs,
-                    self.tmp_result_buffer_ptrs,
-                    self.barrier_in_ptrs,
-                    self.barrier_out_ptrs,
-                )
         self.disabled = False

     @staticmethod
```
```diff
@@ -455,7 +423,7 @@ class CustomAllreduce:
             return False
         # for 4 or more non NVLink-capable GPUs, custom allreduce provides
         # little performance improvement over NCCL.
-        if ops.use_vllm_custom_allreduce and not _is_hip:
+        if not _is_hip:
             if self.world_size == 2 or self.full_nvlink:
                 return inp_size < self.max_size
             return False
```
```diff
@@ -471,18 +439,6 @@ class CustomAllreduce:
                 return inp_size < self.max_size
             return False

-        if self.world_size == 2:
-            return (
-                inp_size < self.max_size
-                and inp_size < self.max_required_workspace_size[0]
-            )
-
-        if self.full_nvlink:
-            return (
-                inp_size < self.max_size
-                and inp_size < self.max_required_workspace_size[1]
-            )
-
         return False

     # all reduce, assuming inp tensor is IPC registered with register_buffer,
```
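With the TRT-LLM workspace thresholds deleted, eligibility on the non-HIP path reduces to world size, NVLink topology, and the preallocated buffer size. A self-contained sketch of the simplified predicate (written as a free function with explicit arguments; the real method reads these from `self`, and `max_size` here is illustrative):

```python
# Hypothetical, standalone version of the simplified should_custom_ar logic.
def should_custom_ar(
    inp_size: int, world_size: int, full_nvlink: bool, max_size: int = 8 * 1024 * 1024
) -> bool:
    # custom allreduce requires the input byte size to be a multiple of 16
    if inp_size % 16 != 0:
        return False
    # for 4 or more non NVLink-capable GPUs, custom allreduce provides
    # little performance improvement over NCCL
    if world_size == 2 or full_nvlink:
        return inp_size < max_size
    return False

print(should_custom_ar(4096, world_size=2, full_nvlink=False))  # True
print(should_custom_ar(4096, world_size=8, full_nvlink=False))  # False
print(should_custom_ar(4096, world_size=8, full_nvlink=True))   # True
```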
```diff
@@ -515,15 +471,12 @@ class CustomAllreduce:
         """
         if out is None:
             out = torch.empty_like(inp)
-        if ops.use_vllm_custom_allreduce:
-            if registered:
-                ops.all_reduce(self._ptr, inp, out, 0, 0)
-            else:
-                ops.all_reduce(
-                    self._ptr, inp, out, self.buffer_ptrs[self.rank], self.max_size
-                )
+        if registered:
+            ops.all_reduce(self._ptr, inp, out, 0, 0)
         else:
-            ops.all_reduce(self._ptr, inp, out)
+            ops.all_reduce(
+                self._ptr, inp, out, self.buffer_ptrs[self.rank], self.max_size
+            )
         return out
```
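After this change, `ops.all_reduce` always uses the five-argument vLLM-style signature: a registered input passes `(0, 0)` so the kernel skips the staging copy, while an unregistered input is staged through the rank's shared buffer. A toy stand-in that makes the two call shapes concrete (hypothetical values; the real `ops.all_reduce` is a compiled kernel):

```python
# Toy stand-in for the compiled ops.all_reduce, showing the two call shapes
# left after this commit. reg_buffer == 0 means "input already IPC-registered".
def all_reduce(fa: int, inp, out, reg_buffer: int, reg_buffer_sz_bytes: int):
    if reg_buffer == 0:
        print(f"handle {fa}: reduce in place, no staging copy")
    else:
        print(f"handle {fa}: stage through buffer 0x{reg_buffer:x} "
              f"({reg_buffer_sz_bytes} bytes), then reduce")

fa, rank_buffer_ptr, max_size = 1, 0x7F00DEAD0000, 8 * 1024 * 1024  # hypothetical
all_reduce(fa, None, None, 0, 0)                       # registered path
all_reduce(fa, None, None, rank_buffer_ptr, max_size)  # unregistered path
```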
```diff
@@ -554,14 +507,9 @@ class CustomAllreduce:
     def close(self):
         if not self.disabled and self._ptr:
             ops.dispose(self._ptr)
-            if ops.use_vllm_custom_allreduce:
+            if _is_cuda:
                 self.free_shared_buffer(self.meta_ptrs)
                 self.free_shared_buffer(self.buffer_ptrs)
-            elif _is_cuda:
-                self.free_shared_buffer(self.buffer_ptrs)
-                self.free_shared_buffer(self.tmp_result_buffer_ptrs)
-                self.free_shared_buffer(self.barrier_in_ptrs)
-                self.free_shared_buffer(self.barrier_out_ptrs)
             self._ptr = 0

     def __del__(self):
```
scripts/ci_install_dependency.sh (view file @ aba5ca15)

```diff
@@ -20,7 +20,7 @@ pip install --upgrade pip
 # Install flashinfer and sgl-kernel
 pip install flashinfer_python==0.2.3 --find-links ${FLASHINFER_REPO} --no-cache-dir
-pip install sgl-kernel==0.0.7 --no-cache-dir
+pip install sgl-kernel==0.0.8 --no-cache-dir

 # Install the main package
 pip install -e "python[all]" --find-links ${FLASHINFER_REPO}
```
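A quick, hypothetical spot-check (not part of the CI script) that an environment actually resolved the new pin:

```python
# Hypothetical verification snippet: importlib.metadata reads the installed
# distribution's version, so this fails loudly if the old pin is still present.
from importlib.metadata import version

assert version("sgl-kernel") == "0.0.8", version("sgl-kernel")
print("sgl-kernel==0.0.8 installed")
```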