change / sglang / Commits / 24cafe31

Unverified commit 24cafe31, authored Jan 19, 2025 by yizhang2077, committed by GitHub on Jan 19, 2025

add config to switch from vllm custom allreduce to sgl_kernel custom allreduce (#2981)

parent 5a176c92
Showing 2 changed files with 160 additions and 72 deletions (+160 -72)

python/sglang/srt/_custom_ops.py  (+79 -36)
python/sglang/srt/distributed/device_communicators/custom_all_reduce.py  (+81 -36)
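The "config" added by this commit is the USE_VLLM_CUSTOM_ALLREDUCE environment variable, read once at import time in python/sglang/srt/_custom_ops.py. A minimal sketch of opting back into the vLLM kernels follows; setting the variable from inside Python and the `ops` import alias are illustrative assumptions, and the variable must be set before sglang is imported for it to take effect.

import os

# Assumption for illustration: any non-empty string selects the vLLM path,
# since the flag is read with os.environ.get(...) and used as a truthy check.
os.environ["USE_VLLM_CUSTOM_ALLREDUCE"] = "1"

from sglang.srt import _custom_ops as ops  # the flag is captured at this import

print(ops.use_vllm_custom_allreduce)  # "1" -> truthy, vLLM custom allreduce kernels are used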
python/sglang/srt/_custom_ops.py  (view file @ 24cafe31)
...
@@ -3,6 +3,7 @@ import contextlib
 import functools
 import importlib
 import logging
+import os
 from typing import TYPE_CHECKING, List, Optional, Tuple, Union

 import torch
...
@@ -11,12 +12,19 @@ import torch.library
 from sglang.srt.utils import is_hpu

 logger = logging.getLogger(__name__)

+use_vllm_custom_allreduce = os.environ.get("USE_VLLM_CUSTOM_ALLREDUCE", default=False)
+
 if not is_hpu():
-    try:
-        import sgl_kernel
-    except ImportError as e:
-        logger.warning("Failed to import from custom_ar with %r", e)
+    if use_vllm_custom_allreduce:
+        try:
+            import vllm._C
+        except ImportError as e:
+            logger.warning("Failed to import from vllm._C with %r", e)
+    else:
+        try:
+            import sgl_kernel
+        except ImportError as e:
+            logger.warning("Failed to import from custom_ar with %r", e)


 def hint_on_error(fn):
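One subtlety in the new flag: os.environ.get returns the raw string whenever the variable is set, so the vLLM branch is taken for any non-empty value, including "0" or "false". A standalone sketch of that truthiness behaviour (not part of the diff):

import os

# Mirrors the read above: the default is the bool False, but a set variable
# comes back as a string, so even "0" or "false" is truthy.
for value in (None, "", "0", "false", "1"):
    if value is None:
        os.environ.pop("USE_VLLM_CUSTOM_ALLREDUCE", None)
    else:
        os.environ["USE_VLLM_CUSTOM_ALLREDUCE"] = value
    flag = os.environ.get("USE_VLLM_CUSTOM_ALLREDUCE", default=False)
    print(repr(value), "->", bool(flag))
# unset and "" are falsy; "0", "false" and "1" all select the vLLM branch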
...
@@ -48,43 +56,78 @@ def hint_on_error(fn):
     return wrapper


-# custom ar
-def init_custom_ar(
-    rank_id: int,
-    world_size: int,
-    rank_data_base: torch.Tensor,
-    buffers: List[int],
-    tmp_result_buffers: List[int],
-    barrier_in: List[int],
-    barrier_out: List[int],
-) -> int:
-    return sgl_kernel.ops.init_custom_reduce(
-        rank_id,
-        world_size,
-        rank_data_base,
-        buffers,
-        tmp_result_buffers,
-        barrier_in,
-        barrier_out,
-    )
-
-
-def all_reduce(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None:
-    sgl_kernel.ops.custom_reduce(fa, inp, out)
-
-
-def dispose(fa: int) -> None:
-    sgl_kernel.ops.custom_dispose(fa)
-
-
-def get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[int], List[int]]:
-    return sgl_kernel.ops.get_graph_buffer_ipc_meta(fa)
-
-
-def register_graph_buffers(
-    fa: int, handles: List[List[int]], offsets: List[List[int]]
-) -> None:
-    sgl_kernel.ops.register_graph_buffers(fa, handles, offsets)
+if use_vllm_custom_allreduce:
+    # custom ar
+    def init_custom_ar(
+        ipc_tensors: List[torch.Tensor],
+        rank_data: torch.Tensor,
+        rank: int,
+        full_nvlink: bool,
+    ) -> int:
+        return torch.ops._C_custom_ar.init_custom_ar(
+            ipc_tensors, rank_data, rank, full_nvlink
+        )
+
+    def all_reduce(
+        fa: int,
+        inp: torch.Tensor,
+        out: torch.Tensor,
+        reg_buffer: int,
+        reg_buffer_sz_bytes: int,
+    ) -> None:
+        torch.ops._C_custom_ar.all_reduce(fa, inp, out, reg_buffer, reg_buffer_sz_bytes)
+
+    def dispose(fa: int) -> None:
+        torch.ops._C_custom_ar.dispose(fa)
+
+    def meta_size() -> int:
+        return torch.ops._C_custom_ar.meta_size()
+
+    def register_buffer(fa: int, ipc_tensors: List[int]) -> None:
+        return torch.ops._C_custom_ar.register_buffer(fa, ipc_tensors)
+
+    def get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[int], List[int]]:
+        return torch.ops._C_custom_ar.get_graph_buffer_ipc_meta(fa)
+
+    def register_graph_buffers(
+        fa: int, handles: List[List[int]], offsets: List[List[int]]
+    ) -> None:
+        torch.ops._C_custom_ar.register_graph_buffers(fa, handles, offsets)
+
+else:
+    # custom ar
+    def init_custom_ar(
+        rank_id: int,
+        world_size: int,
+        rank_data_base: torch.Tensor,
+        buffers: List[int],
+        tmp_result_buffers: List[int],
+        barrier_in: List[int],
+        barrier_out: List[int],
+    ) -> int:
+        return sgl_kernel.ops.init_custom_reduce(
+            rank_id,
+            world_size,
+            rank_data_base,
+            buffers,
+            tmp_result_buffers,
+            barrier_in,
+            barrier_out,
+        )
+
+    def all_reduce(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None:
+        sgl_kernel.ops.custom_reduce(fa, inp, out)
+
+    def dispose(fa: int) -> None:
+        sgl_kernel.ops.custom_dispose(fa)
+
+    def get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[int], List[int]]:
+        return sgl_kernel.ops.get_graph_buffer_ipc_meta(fa)
+
+    def register_graph_buffers(
+        fa: int, handles: List[List[int]], offsets: List[List[int]]
+    ) -> None:
+        sgl_kernel.ops.register_graph_buffers(fa, handles, offsets)


 # temporary fix for https://github.com/vllm-project/vllm/issues/5456
...
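Both branches above publish the same module-level names, but all_reduce takes two extra arguments on the vLLM path, so call sites have to branch on ops.use_vllm_custom_allreduce (as custom_all_reduce.py does in the second file). A rough, self-contained sketch of that dispatch, with plain-Python stubs standing in for torch.ops._C_custom_ar and sgl_kernel.ops:

from typing import Any

use_vllm_custom_allreduce = False  # stands in for ops.use_vllm_custom_allreduce


def vllm_all_reduce(fa: int, inp: Any, out: Any, reg_buffer: int, reg_buffer_sz_bytes: int) -> None:
    # stub for torch.ops._C_custom_ar.all_reduce
    print("vLLM path:", fa, reg_buffer, reg_buffer_sz_bytes)


def sgl_all_reduce(fa: int, inp: Any, out: Any) -> None:
    # stub for sgl_kernel.ops.custom_reduce
    print("sgl_kernel path:", fa)


def all_reduce(fa: int, inp: Any, out: Any, reg_buffer: int = 0, reg_buffer_sz_bytes: int = 0) -> None:
    if use_vllm_custom_allreduce:
        vllm_all_reduce(fa, inp, out, reg_buffer, reg_buffer_sz_bytes)
    else:
        sgl_all_reduce(fa, inp, out)


all_reduce(0, None, None)  # prints "sgl_kernel path: 0"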
python/sglang/srt/distributed/device_communicators/custom_all_reduce.py  (view file @ 24cafe31)
...
@@ -21,8 +21,10 @@ from sglang.srt.distributed.parallel_state import in_the_same_node_as
 from sglang.srt.utils import cuda_device_count_stateless, is_cuda

 try:
-    import sgl_kernel
+    if ops.use_vllm_custom_allreduce:
+        ops.meta_size()
+    else:
+        import sgl_kernel
     custom_ar = True
 except Exception:
     # For AMD GPUs and CPUs
...
@@ -201,33 +203,58 @@ class CustomAllreduce:
         self.world_size = world_size
         self.full_nvlink = full_nvlink

-        # From TensorRT-LLM getMaxRequiredWorkspaceSize
-        self.max_required_workspace_size = [16 * 1024 * 1024, 8 * 1024 * 1024]
-
-        # sizeof(uint32_t) * (MAX_ALL_REDUCE_BLOCKS + 2) * MAX_RANKS_PER_NODE;
-        self.barrier_max_size = 8 * (36 + 2) * 8
-
-        self.buffer_ptrs = self.create_shared_buffer(max_size, group=group)
-        self.tmp_result_buffer_ptrs = self.create_shared_buffer(max_size, group=group)
-        self.rank_data_base = torch.empty(
-            8 * 1024 * 1024, dtype=torch.uint8, device=self.device
-        )
-        self.barrier_in_ptrs = self.create_shared_buffer(
-            self.barrier_max_size, group=group
-        )
-        self.barrier_out_ptrs = self.create_shared_buffer(
-            self.barrier_max_size, group=group
-        )
-
-        self._ptr = ops.init_custom_ar(
-            rank,
-            world_size,
-            self.rank_data_base,
-            self.buffer_ptrs,
-            self.tmp_result_buffer_ptrs,
-            self.barrier_in_ptrs,
-            self.barrier_out_ptrs,
-        )
+        if ops.use_vllm_custom_allreduce:
+            # Buffers memory are owned by this Python class and passed to C++.
+            # Meta data composes of two parts: meta data for synchronization and a
+            # temporary buffer for storing intermediate allreduce results.
+            self.meta_ptrs = self.create_shared_buffer(
+                ops.meta_size() + max_size, group=group
+            )
+            # This is a pre-registered IPC buffer. In eager mode, input tensors
+            # are first copied into this buffer before allreduce is performed
+            self.buffer_ptrs = self.create_shared_buffer(max_size, group=group)
+            # This is a buffer for storing the tuples of pointers pointing to
+            # IPC buffers from all ranks. Each registered tuple has size of
+            # 8*world_size bytes where world_size is at most 8. Allocating 8MB
+            # is enough for 131072 such tuples. The largest model I've seen only
+            # needs less than 10000 of registered tuples.
+            self.rank_data = torch.empty(
+                8 * 1024 * 1024, dtype=torch.uint8, device=self.device
+            )
+            self._ptr = ops.init_custom_ar(
+                self.meta_ptrs, self.rank_data, rank, self.full_nvlink
+            )
+            ops.register_buffer(self._ptr, self.buffer_ptrs)
+        else:
+            # From TensorRT-LLM getMaxRequiredWorkspaceSize
+            self.max_required_workspace_size = [16 * 1024 * 1024, 8 * 1024 * 1024]
+
+            # sizeof(uint32_t) * (MAX_ALL_REDUCE_BLOCKS + 2) * MAX_RANKS_PER_NODE;
+            self.barrier_max_size = 8 * (36 + 2) * 8
+
+            self.buffer_ptrs = self.create_shared_buffer(max_size, group=group)
+            self.tmp_result_buffer_ptrs = self.create_shared_buffer(
+                max_size, group=group
+            )
+            self.rank_data_base = torch.empty(
+                8 * 1024 * 1024, dtype=torch.uint8, device=self.device
+            )
+            self.barrier_in_ptrs = self.create_shared_buffer(
+                self.barrier_max_size, group=group
+            )
+            self.barrier_out_ptrs = self.create_shared_buffer(
+                self.barrier_max_size, group=group
+            )
+            self._ptr = ops.init_custom_ar(
+                rank,
+                world_size,
+                self.rank_data_base,
+                self.buffer_ptrs,
+                self.tmp_result_buffer_ptrs,
+                self.barrier_in_ptrs,
+                self.barrier_out_ptrs,
+            )
         self.disabled = False

     @staticmethod
...
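For reference, the constants in the sgl_kernel branch of the constructor come out to small, fixed allocations. The arithmetic below only restates what the code above computes; the names are copied from the attributes and the totals are derived here:

# sizeof(uint32_t) * (MAX_ALL_REDUCE_BLOCKS + 2) * MAX_RANKS_PER_NODE, with the
# hard-coded values used above
barrier_max_size = 8 * (36 + 2) * 8
print(barrier_max_size)  # 2432 bytes for each of the barrier_in / barrier_out buffers

rank_data_base_bytes = 8 * 1024 * 1024
print(rank_data_base_bytes)  # 8388608, i.e. an 8 MiB uint8 scratch tensor per rank

# workspace thresholds taken from TensorRT-LLM getMaxRequiredWorkspaceSize
max_required_workspace_size = [16 * 1024 * 1024, 8 * 1024 * 1024]
print(max_required_workspace_size)  # [16777216, 8388608] bytes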
@@ -307,6 +334,11 @@ class CustomAllreduce:
             return False
         # for 4 or more non NVLink-capable GPUs, custom allreduce provides
         # little performance improvement over NCCL.
+        if ops.use_vllm_custom_allreduce:
+            if self.world_size == 2 or self.full_nvlink:
+                return inp_size < self.max_size
+            return False
+
         if self.world_size == 2:
             return (
                 inp_size < self.max_size
...
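The gate added here can be read in isolation. A hedged standalone restatement follows; the parameter names mirror the attributes used above, and the non-vLLM branch is simplified because its remaining conditions are collapsed in this view:

def should_custom_ar(inp_size: int, world_size: int, full_nvlink: bool,
                     max_size: int, use_vllm: bool) -> bool:
    if use_vllm:
        # vLLM kernels: world size 2 or a fully NVLinked node, and the input
        # must fit in the pre-registered buffer.
        if world_size == 2 or full_nvlink:
            return inp_size < max_size
        return False
    # sgl_kernel path: only the world_size == 2 case is restated here; the
    # original method has further checks not shown in this hunk.
    return world_size == 2 and inp_size < max_size


print(should_custom_ar(1 << 20, 8, True, 8 * 1024 * 1024, use_vllm=True))   # True
print(should_custom_ar(1 << 24, 8, True, 8 * 1024 * 1024, use_vllm=True))   # False: input too large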
@@ -326,6 +358,7 @@ class CustomAllreduce:
         inp: torch.Tensor,
         *,
         out: torch.Tensor = None,
+        registered: bool = False,
     ):
         """Performs an out-of-place all reduce.
...
@@ -335,7 +368,15 @@ class CustomAllreduce:
         """
         if out is None:
             out = torch.empty_like(inp)
-        ops.all_reduce(self._ptr, inp, out)
+        if ops.use_vllm_custom_allreduce:
+            if registered:
+                ops.all_reduce(self._ptr, inp, out, 0, 0)
+            else:
+                ops.all_reduce(
+                    self._ptr, inp, out, self.buffer_ptrs[self.rank], self.max_size
+                )
+        else:
+            ops.all_reduce(self._ptr, inp, out)
         return out

     def custom_all_reduce(self, input: torch.Tensor) -> Optional[torch.Tensor]:
...
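On the vLLM path the new registered flag decides whether the input already sits in an IPC-registered buffer (the CUDA graph capture case) or has to be staged through this rank's shared buffer first (eager mode). A small illustration of the argument selection, with the kernel call itself left out:

def vllm_all_reduce_args(registered: bool, reg_buffer_ptr: int, max_size: int):
    # registered=True: no staging copy is needed, so 0/0 are passed through,
    # matching ops.all_reduce(self._ptr, inp, out, 0, 0) above.
    if registered:
        return (0, 0)
    # registered=False: stage through self.buffer_ptrs[self.rank], bounded by max_size.
    return (reg_buffer_ptr, max_size)


print(vllm_all_reduce_args(True, 0x7F00, 8 * 1024 * 1024))   # (0, 0)
print(vllm_all_reduce_args(False, 0x7F00, 8 * 1024 * 1024))  # (32512, 8388608)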
@@ -345,21 +386,25 @@ class CustomAllreduce:
             return None
         if self._IS_CAPTURING:
             if torch.cuda.is_current_stream_capturing():
-                return self.all_reduce(input)
+                return self.all_reduce(input, registered=True)
             else:
                 # If warm up, mimic the allocation pattern since custom
                 # allreduce is out-of-place.
                 return torch.empty_like(input)
         else:
-            return self.all_reduce(input)
+            return self.all_reduce(input, registered=False)

     def close(self):
         if not self.disabled and self._ptr:
             ops.dispose(self._ptr)
-            self.free_shared_buffer(self.buffer_ptrs)
-            self.free_shared_buffer(self.tmp_result_buffer_ptrs)
-            self.free_shared_buffer(self.barrier_in_ptrs)
-            self.free_shared_buffer(self.barrier_out_ptrs)
+            if ops.use_vllm_custom_allreduce:
+                self.free_shared_buffer(self.meta_ptrs)
+                self.free_shared_buffer(self.buffer_ptrs)
+            else:
+                self.free_shared_buffer(self.buffer_ptrs)
+                self.free_shared_buffer(self.tmp_result_buffer_ptrs)
+                self.free_shared_buffer(self.barrier_in_ptrs)
+                self.free_shared_buffer(self.barrier_out_ptrs)
             self._ptr = 0

     def __del__(self):
...