change / sglang · Commits · 24cafe31

Commit 24cafe31 (Unverified), authored Jan 19, 2025 by yizhang2077, committed via GitHub on Jan 19, 2025.

add config to switch from vllm custom allreduce to sgl_kernel custom allreduce (#2981)
Parent: 5a176c92

Showing 2 changed files with 160 additions and 72 deletions (+160 -72):

    python/sglang/srt/_custom_ops.py                                            +79 -36
    python/sglang/srt/distributed/device_communicators/custom_all_reduce.py     +81 -36
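The switch itself is the USE_VLLM_CUSTOM_ALLREDUCE environment variable read in python/sglang/srt/_custom_ops.py (first hunk below). A minimal usage sketch, assuming sglang is installed: because the flag is read with os.environ.get(..., default=False), any non-empty string value enables the vLLM path, and it must be set before the module is imported.

    import os

    # Opt in to the vLLM custom allreduce backend (any non-empty value works).
    # Must run before sglang.srt._custom_ops is imported, since the flag is read
    # once at import time.
    os.environ["USE_VLLM_CUSTOM_ALLREDUCE"] = "1"

    from sglang.srt import _custom_ops as ops  # now binds the vllm._C wrappers

    # Leaving the variable unset (the default) keeps the sgl_kernel custom allreduce path.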
python/sglang/srt/_custom_ops.py  (view file @ 24cafe31)
@@ -3,6 +3,7 @@ import contextlib
 import functools
 import importlib
 import logging
+import os
 from typing import TYPE_CHECKING, List, Optional, Tuple, Union

 import torch
@@ -11,12 +12,19 @@ import torch.library
 from sglang.srt.utils import is_hpu

 logger = logging.getLogger(__name__)

+use_vllm_custom_allreduce = os.environ.get("USE_VLLM_CUSTOM_ALLREDUCE", default=False)
+
 if not is_hpu():
-    try:
-        import sgl_kernel
-    except ImportError as e:
-        logger.warning("Failed to import from custom_ar with %r", e)
+    if use_vllm_custom_allreduce:
+        try:
+            import vllm._C
+        except ImportError as e:
+            logger.warning("Failed to import from vllm._C with %r", e)
+    else:
+        try:
+            import sgl_kernel
+        except ImportError as e:
+            logger.warning("Failed to import from custom_ar with %r", e)


 def hint_on_error(fn):
@@ -48,43 +56,78 @@ def hint_on_error(fn):
     return wrapper


-# custom ar
-def init_custom_ar(
-    rank_id: int,
-    world_size: int,
-    rank_data_base: torch.Tensor,
-    buffers: List[int],
-    tmp_result_buffers: List[int],
-    barrier_in: List[int],
-    barrier_out: List[int],
-) -> int:
-    return sgl_kernel.ops.init_custom_reduce(
-        rank_id,
-        world_size,
-        rank_data_base,
-        buffers,
-        tmp_result_buffers,
-        barrier_in,
-        barrier_out,
-    )
-
-
-def all_reduce(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None:
-    sgl_kernel.ops.custom_reduce(fa, inp, out)
-
-
-def dispose(fa: int) -> None:
-    sgl_kernel.ops.custom_dispose(fa)
-
-
-def get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[int], List[int]]:
-    return sgl_kernel.ops.get_graph_buffer_ipc_meta(fa)
-
-
-def register_graph_buffers(
-    fa: int, handles: List[List[int]], offsets: List[List[int]]
-) -> None:
-    sgl_kernel.ops.register_graph_buffers(fa, handles, offsets)
+if use_vllm_custom_allreduce:
+    # custom ar
+    def init_custom_ar(
+        ipc_tensors: List[torch.Tensor],
+        rank_data: torch.Tensor,
+        rank: int,
+        full_nvlink: bool,
+    ) -> int:
+        return torch.ops._C_custom_ar.init_custom_ar(
+            ipc_tensors, rank_data, rank, full_nvlink
+        )
+
+    def all_reduce(
+        fa: int,
+        inp: torch.Tensor,
+        out: torch.Tensor,
+        reg_buffer: int,
+        reg_buffer_sz_bytes: int,
+    ) -> None:
+        torch.ops._C_custom_ar.all_reduce(fa, inp, out, reg_buffer, reg_buffer_sz_bytes)
+
+    def dispose(fa: int) -> None:
+        torch.ops._C_custom_ar.dispose(fa)
+
+    def meta_size() -> int:
+        return torch.ops._C_custom_ar.meta_size()
+
+    def register_buffer(fa: int, ipc_tensors: List[int]) -> None:
+        return torch.ops._C_custom_ar.register_buffer(fa, ipc_tensors)
+
+    def get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[int], List[int]]:
+        return torch.ops._C_custom_ar.get_graph_buffer_ipc_meta(fa)
+
+    def register_graph_buffers(
+        fa: int, handles: List[List[int]], offsets: List[List[int]]
+    ) -> None:
+        torch.ops._C_custom_ar.register_graph_buffers(fa, handles, offsets)
+
+else:
+    # custom ar
+    def init_custom_ar(
+        rank_id: int,
+        world_size: int,
+        rank_data_base: torch.Tensor,
+        buffers: List[int],
+        tmp_result_buffers: List[int],
+        barrier_in: List[int],
+        barrier_out: List[int],
+    ) -> int:
+        return sgl_kernel.ops.init_custom_reduce(
+            rank_id,
+            world_size,
+            rank_data_base,
+            buffers,
+            tmp_result_buffers,
+            barrier_in,
+            barrier_out,
+        )
+
+    def all_reduce(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None:
+        sgl_kernel.ops.custom_reduce(fa, inp, out)
+
+    def dispose(fa: int) -> None:
+        sgl_kernel.ops.custom_dispose(fa)
+
+    def get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[int], List[int]]:
+        return sgl_kernel.ops.get_graph_buffer_ipc_meta(fa)
+
+    def register_graph_buffers(
+        fa: int, handles: List[List[int]], offsets: List[List[int]]
+    ) -> None:
+        sgl_kernel.ops.register_graph_buffers(fa, handles, offsets)


 # temporary fix for https://github.com/vllm-project/vllm/issues/5456
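Taken together, the hunk above is an import-time dispatch: both branches expose the same public names (init_custom_ar, all_reduce, dispose, get_graph_buffer_ipc_meta, register_graph_buffers), but bound to different backends and, for init_custom_ar and all_reduce, with different signatures, so callers still branch on ops.use_vllm_custom_allreduce (as the second file below does). A stripped-down, runnable sketch of the pattern, using hypothetical names rather than the real kernels:

    import os

    # Hypothetical flag mirroring use_vllm_custom_allreduce: read once, at import time.
    _USE_ALT_BACKEND = bool(os.environ.get("USE_ALT_BACKEND", ""))

    if _USE_ALT_BACKEND:
        def all_reduce_demo(values):
            # stand-in for the vllm._C-backed wrapper
            return ("alt-backend", sum(values))
    else:
        def all_reduce_demo(values):
            # stand-in for the sgl_kernel-backed wrapper
            return ("default-backend", sum(values))

    if __name__ == "__main__":
        print(all_reduce_demo([1, 2, 3]))  # the backend was chosen when the module loaded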
python/sglang/srt/distributed/device_communicators/custom_all_reduce.py  (view file @ 24cafe31)
@@ -21,8 +21,10 @@ from sglang.srt.distributed.parallel_state import in_the_same_node_as
 from sglang.srt.utils import cuda_device_count_stateless, is_cuda

 try:
-    import sgl_kernel
+    if ops.use_vllm_custom_allreduce:
+        ops.meta_size()
+    else:
+        import sgl_kernel

     custom_ar = True
 except Exception:
     # For AMD GPUs and CPUs
@@ -201,33 +203,58 @@ class CustomAllreduce:
         self.world_size = world_size
         self.full_nvlink = full_nvlink

-        # From TensorRT-LLM getMaxRequiredWorkspaceSize
-        self.max_required_workspace_size = [16 * 1024 * 1024, 8 * 1024 * 1024]
-        # sizeof(uint32_t) * (MAX_ALL_REDUCE_BLOCKS + 2) * MAX_RANKS_PER_NODE;
-        self.barrier_max_size = 8 * (36 + 2) * 8
-
-        self.buffer_ptrs = self.create_shared_buffer(max_size, group=group)
-        self.tmp_result_buffer_ptrs = self.create_shared_buffer(max_size, group=group)
-        self.rank_data_base = torch.empty(
-            8 * 1024 * 1024, dtype=torch.uint8, device=self.device
-        )
-        self.barrier_in_ptrs = self.create_shared_buffer(
-            self.barrier_max_size, group=group
-        )
-        self.barrier_out_ptrs = self.create_shared_buffer(
-            self.barrier_max_size, group=group
-        )
-
-        self._ptr = ops.init_custom_ar(
-            rank,
-            world_size,
-            self.rank_data_base,
-            self.buffer_ptrs,
-            self.tmp_result_buffer_ptrs,
-            self.barrier_in_ptrs,
-            self.barrier_out_ptrs,
-        )
+        if ops.use_vllm_custom_allreduce:
+            # Buffers memory are owned by this Python class and passed to C++.
+            # Meta data composes of two parts: meta data for synchronization and a
+            # temporary buffer for storing intermediate allreduce results.
+            self.meta_ptrs = self.create_shared_buffer(
+                ops.meta_size() + max_size, group=group
+            )
+            # This is a pre-registered IPC buffer. In eager mode, input tensors
+            # are first copied into this buffer before allreduce is performed
+            self.buffer_ptrs = self.create_shared_buffer(max_size, group=group)
+            # This is a buffer for storing the tuples of pointers pointing to
+            # IPC buffers from all ranks. Each registered tuple has size of
+            # 8*world_size bytes where world_size is at most 8. Allocating 8MB
+            # is enough for 131072 such tuples. The largest model I've seen only
+            # needs less than 10000 of registered tuples.
+            self.rank_data = torch.empty(
+                8 * 1024 * 1024, dtype=torch.uint8, device=self.device
+            )
+            self._ptr = ops.init_custom_ar(
+                self.meta_ptrs, self.rank_data, rank, self.full_nvlink
+            )
+            ops.register_buffer(self._ptr, self.buffer_ptrs)
+        else:
+            # From TensorRT-LLM getMaxRequiredWorkspaceSize
+            self.max_required_workspace_size = [16 * 1024 * 1024, 8 * 1024 * 1024]
+            # sizeof(uint32_t) * (MAX_ALL_REDUCE_BLOCKS + 2) * MAX_RANKS_PER_NODE;
+            self.barrier_max_size = 8 * (36 + 2) * 8
+
+            self.buffer_ptrs = self.create_shared_buffer(max_size, group=group)
+            self.tmp_result_buffer_ptrs = self.create_shared_buffer(max_size, group=group)
+            self.rank_data_base = torch.empty(
+                8 * 1024 * 1024, dtype=torch.uint8, device=self.device
+            )
+            self.barrier_in_ptrs = self.create_shared_buffer(
+                self.barrier_max_size, group=group
+            )
+            self.barrier_out_ptrs = self.create_shared_buffer(
+                self.barrier_max_size, group=group
+            )
+            self._ptr = ops.init_custom_ar(
+                rank,
+                world_size,
+                self.rank_data_base,
+                self.buffer_ptrs,
+                self.tmp_result_buffer_ptrs,
+                self.barrier_in_ptrs,
+                self.barrier_out_ptrs,
+            )

         self.disabled = False

     @staticmethod
@@ -307,6 +334,11 @@ class CustomAllreduce:
             return False
         # for 4 or more non NVLink-capable GPUs, custom allreduce provides
         # little performance improvement over NCCL.
+        if ops.use_vllm_custom_allreduce:
+            if self.world_size == 2 or self.full_nvlink:
+                return inp_size < self.max_size
+            return False
+
         if self.world_size == 2:
             return (
                 inp_size < self.max_size
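For the vLLM path, the check added above reduces to a small predicate: custom allreduce is used only with 2 ranks or a fully NVLink-connected node, and only when the input fits in the pre-allocated IPC buffer. A standalone restatement as a hypothetical helper (not part of the diff):

    def vllm_should_custom_ar(world_size: int, full_nvlink: bool,
                              inp_size: int, max_size: int) -> bool:
        # Mirrors the ops.use_vllm_custom_allreduce branch added above.
        if world_size == 2 or full_nvlink:
            return inp_size < max_size
        return False

    # e.g. 8 fully NVLinked ranks, a 1 MiB input, and an 8 MiB buffer -> eligible
    assert vllm_should_custom_ar(8, True, 1 << 20, 8 << 20)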
@@ -326,6 +358,7 @@ class CustomAllreduce:
         inp: torch.Tensor,
         *,
         out: torch.Tensor = None,
+        registered: bool = False,
     ):
         """Performs an out-of-place all reduce.
@@ -335,7 +368,15 @@ class CustomAllreduce:
         """
         if out is None:
             out = torch.empty_like(inp)
-        ops.all_reduce(self._ptr, inp, out)
+        if ops.use_vllm_custom_allreduce:
+            if registered:
+                ops.all_reduce(self._ptr, inp, out, 0, 0)
+            else:
+                ops.all_reduce(
+                    self._ptr, inp, out, self.buffer_ptrs[self.rank], self.max_size
+                )
+        else:
+            ops.all_reduce(self._ptr, inp, out)
         return out

     def custom_all_reduce(self, input: torch.Tensor) -> Optional[torch.Tensor]:
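In the vLLM branch above, the two extra arguments select the staging buffer: during graph capture the input already sits in a registered buffer, so (0, 0) tells the kernel to skip staging, while in eager mode the rank-local IPC buffer pointer and its size are passed so the input can be copied in first. A small runnable sketch of that argument selection, with a print stub standing in for the real kernel call:

    def _vllm_all_reduce_stub(ptr, inp, out, reg_buffer, reg_buffer_sz_bytes):
        # Hypothetical stand-in for ops.all_reduce on the vLLM path.
        print(f"all_reduce(reg_buffer={reg_buffer:#x}, size={reg_buffer_sz_bytes})")

    def run_all_reduce(ptr, inp, out, registered, rank_buffer_ptr, max_size):
        if registered:
            _vllm_all_reduce_stub(ptr, inp, out, 0, 0)  # input already registered
        else:
            # stage the input through the pre-registered IPC buffer first
            _vllm_all_reduce_stub(ptr, inp, out, rank_buffer_ptr, max_size)

    run_all_reduce(ptr=0, inp=None, out=None, registered=False,
                   rank_buffer_ptr=0x7F00, max_size=8 * 1024 * 1024)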
@@ -345,21 +386,25 @@ class CustomAllreduce:
             return None
         if self._IS_CAPTURING:
             if torch.cuda.is_current_stream_capturing():
-                return self.all_reduce(input)
+                return self.all_reduce(input, registered=True)
             else:
                 # If warm up, mimic the allocation pattern since custom
                 # allreduce is out-of-place.
                 return torch.empty_like(input)
         else:
-            return self.all_reduce(input)
+            return self.all_reduce(input, registered=False)

     def close(self):
         if not self.disabled and self._ptr:
             ops.dispose(self._ptr)
-            self.free_shared_buffer(self.buffer_ptrs)
-            self.free_shared_buffer(self.tmp_result_buffer_ptrs)
-            self.free_shared_buffer(self.barrier_in_ptrs)
-            self.free_shared_buffer(self.barrier_out_ptrs)
+            if ops.use_vllm_custom_allreduce:
+                self.free_shared_buffer(self.meta_ptrs)
+                self.free_shared_buffer(self.buffer_ptrs)
+            else:
+                self.free_shared_buffer(self.buffer_ptrs)
+                self.free_shared_buffer(self.tmp_result_buffer_ptrs)
+                self.free_shared_buffer(self.barrier_in_ptrs)
+                self.free_shared_buffer(self.barrier_out_ptrs)
             self._ptr = 0

     def __del__(self):