change / sglang · Commits · 767c9dec

Unverified commit 767c9dec, authored Jan 16, 2025 by yizhang2077, committed by GitHub on Jan 16, 2025.
Parent: a53454c5

    adapt custom allreduce for tensorrt llm (#2511)

Showing 5 changed files with 241 additions and 67 deletions (+241, -67):

    python/pyproject.toml                                                       +1   -1
    python/sglang/srt/_custom_ops.py                                            +22  -27
    python/sglang/srt/distributed/device_communicators/custom_all_reduce.py    +53  -39
    test/srt/run_suite.py                                                       +1   -0
    test/srt/test_custom_allreduce.py                                           +164 -0
python/pyproject.toml

@@ -27,7 +27,7 @@ runtime_common = [
 ]
 srt = [
   "sglang[runtime_common]", "cuda-python",
-  "sgl-kernel>=0.0.2.post12", "torch", "vllm>=0.6.3.post1,<=0.6.4.post1",
+  "sgl-kernel>=0.0.2.post14", "torch", "vllm>=0.6.3.post1,<=0.6.4.post1",
   "flashinfer==0.1.6"
 ]
 ...
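The only functional change in this file is raising the sgl-kernel floor from 0.0.2.post12 to 0.0.2.post14, presumably so that the custom-reduce kernels used in the files below are available. A minimal, hypothetical check (not part of the commit; the helper name and the use of the packaging library are assumptions) that an environment satisfies the new floor:

    # check_sgl_kernel_version.py -- illustrative only
    from importlib.metadata import PackageNotFoundError, version

    from packaging.version import Version


    def sgl_kernel_is_new_enough(minimum: str = "0.0.2.post14") -> bool:
        """Return True if the installed sgl-kernel satisfies the pyproject floor."""
        try:
            return Version(version("sgl-kernel")) >= Version(minimum)
        except PackageNotFoundError:
            return False


    if __name__ == "__main__":
        print("sgl-kernel >= 0.0.2.post14:", sgl_kernel_is_new_enough())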
python/sglang/srt/_custom_ops.py

-# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/_custom_ops.py
+# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/_custom_ops.py
 import contextlib
 import functools
 import importlib
 ...
@@ -14,7 +14,7 @@ logger = logging.getLogger(__name__)
 if not is_hpu():
     try:
-        import custom_ar
+        import sgl_kernel
     except ImportError as e:
         logger.warning("Failed to import from custom_ar with %r", e)
 ...
@@ -50,46 +50,41 @@ def hint_on_error(fn):
 # custom ar
-def init_custom_ar(
-    ipc_tensors: List[torch.Tensor],
-    rank_data: torch.Tensor,
-    rank: int,
-    full_nvlink: bool,
-) -> int:
-    return torch.ops._C_vllm_ar.init_custom_ar(
-        ipc_tensors, rank_data, rank, full_nvlink
-    )
+def init_custom_ar(
+    rank_id: int,
+    world_size: int,
+    rank_data_base: torch.Tensor,
+    buffers: List[int],
+    tmp_result_buffers: List[int],
+    barrier_in: List[int],
+    barrier_out: List[int],
+) -> int:
+    return sgl_kernel.ops.init_custom_reduce(
+        rank_id,
+        world_size,
+        rank_data_base,
+        buffers,
+        tmp_result_buffers,
+        barrier_in,
+        barrier_out,
+    )


-def all_reduce(
-    fa: int,
-    inp: torch.Tensor,
-    out: torch.Tensor,
-    reg_buffer: int,
-    reg_buffer_sz_bytes: int,
-) -> None:
-    torch.ops._C_vllm_ar.all_reduce(fa, inp, out, reg_buffer, reg_buffer_sz_bytes)
+def all_reduce(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None:
+    sgl_kernel.ops.custom_reduce(fa, inp, out)


 def dispose(fa: int) -> None:
-    torch.ops._C_vllm_ar.dispose(fa)
-
-
-def meta_size() -> int:
-    return torch.ops._C_vllm_ar.meta_size()
-
-
-def register_buffer(fa: int, ipc_tensors: List[int]) -> None:
-    return torch.ops._C_vllm_ar.register_buffer(fa, ipc_tensors)
+    sgl_kernel.ops.custom_dispose(fa)


 def get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[int], List[int]]:
-    return torch.ops._C_vllm_ar.get_graph_buffer_ipc_meta(fa)
+    return sgl_kernel.ops.get_graph_buffer_ipc_meta(fa)


 def register_graph_buffers(
     fa: int, handles: List[List[int]], offsets: List[List[int]]
 ) -> None:
-    torch.ops._C_vllm_ar.register_graph_buffers(fa, handles, offsets)
+    sgl_kernel.ops.register_graph_buffers(fa, handles, offsets)


 # temporary fix for https://github.com/vllm-project/vllm/issues/5456
 ...
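Every wrapper in this file is now a thin pass-through to sgl_kernel.ops instead of the vllm _C_vllm_ar extension; the old meta_size() and register_buffer() helpers disappear because the sgl-kernel interface takes the IPC buffer pointers directly at init time. A small, hypothetical self-check (not from the commit; the helper and dict are mine, the name mapping is taken from the diff above) that a local sgl-kernel build exposes the entry points the wrappers now call:

    # sgl_kernel_ops_check.py -- illustrative only
    OLD_TO_NEW = {
        "torch.ops._C_vllm_ar.init_custom_ar": "init_custom_reduce",
        "torch.ops._C_vllm_ar.all_reduce": "custom_reduce",
        "torch.ops._C_vllm_ar.dispose": "custom_dispose",
        "torch.ops._C_vllm_ar.get_graph_buffer_ipc_meta": "get_graph_buffer_ipc_meta",
        "torch.ops._C_vllm_ar.register_graph_buffers": "register_graph_buffers",
    }


    def sgl_kernel_ops_available() -> bool:
        """True if sgl_kernel imports and exposes the expected custom-allreduce ops."""
        try:
            import sgl_kernel  # requires a CUDA build of sgl-kernel
        except ImportError:
            return False
        return all(hasattr(sgl_kernel.ops, name) for name in OLD_TO_NEW.values())


    if __name__ == "__main__":
        print("custom allreduce entry points available:", sgl_kernel_ops_available())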
python/sglang/srt/distributed/device_communicators/custom_all_reduce.py

@@ -21,7 +21,8 @@ from sglang.srt.distributed.parallel_state import in_the_same_node_as
 from sglang.srt.utils import cuda_device_count_stateless, is_cuda

 try:
-    ops.meta_size()
+    import sgl_kernel
+
     custom_ar = True
 except Exception:
     # For AMD GPUs and CPUs
 ...
@@ -29,7 +30,6 @@ except Exception:
 logger = logging.getLogger(__name__)

 _P = ParamSpec("_P")
 _R = TypeVar("_R")
 ...
@@ -47,7 +47,7 @@ def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]:
 @with_nvml_context
-def is_full_nvlink(cls, physical_device_ids: List[int]) -> bool:
+def is_full_nvlink(physical_device_ids: List[int]) -> bool:
     """
     query if the set of gpus are fully connected by nvlink (1 hop)
     """
 ...
@@ -196,32 +196,39 @@ class CustomAllreduce:
             )
             return

-        self.disabled = False
-        # Buffers memory are owned by this Python class and passed to C++.
-        # Meta data composes of two parts: meta data for synchronization and a
-        # temporary buffer for storing intermediate allreduce results.
-        self.meta_ptrs = self.create_shared_buffer(ops.meta_size() + max_size, group=group)
-        # This is a pre-registered IPC buffer. In eager mode, input tensors
-        # are first copied into this buffer before allreduce is performed
-        self.buffer_ptrs = self.create_shared_buffer(max_size, group=group)
-        # This is a buffer for storing the tuples of pointers pointing to
-        # IPC buffers from all ranks. Each registered tuple has size of
-        # 8*world_size bytes where world_size is at most 8. Allocating 8MB
-        # is enough for 131072 such tuples. The largest model I've seen only
-        # needs less than 10000 of registered tuples.
-        self.rank_data = torch.empty(
-            8 * 1024 * 1024, dtype=torch.uint8, device=self.device
-        )
         self.max_size = max_size
         self.rank = rank
         self.world_size = world_size
         self.full_nvlink = full_nvlink
+
+        # From TensorRT-LLM getMaxRequiredWorkspaceSize
+        self.max_required_workspace_size = [16 * 1024 * 1024, 8 * 1024 * 1024]
+
+        # sizeof(uint32_t) * (MAX_ALL_REDUCE_BLOCKS + 2) * MAX_RANKS_PER_NODE;
+        self.barrier_max_size = 8 * (36 + 2) * 8
+
+        self.buffer_ptrs = self.create_shared_buffer(max_size, group=group)
+        self.tmp_result_buffer_ptrs = self.create_shared_buffer(max_size, group=group)
+        self.rank_data_base = torch.empty(
+            8 * 1024 * 1024, dtype=torch.uint8, device=self.device
+        )
+        self.barrier_in_ptrs = self.create_shared_buffer(self.barrier_max_size, group=group)
+        self.barrier_out_ptrs = self.create_shared_buffer(self.barrier_max_size, group=group)
-        self._ptr = ops.init_custom_ar(
-            self.meta_ptrs, self.rank_data, rank, self.full_nvlink
-        )
-        ops.register_buffer(self._ptr, self.buffer_ptrs)
+        self._ptr = ops.init_custom_ar(
+            rank,
+            world_size,
+            self.rank_data_base,
+            self.buffer_ptrs,
+            self.tmp_result_buffer_ptrs,
+            self.barrier_in_ptrs,
+            self.barrier_out_ptrs,
+        )
+        self.disabled = False

     @staticmethod
     def create_shared_buffer(
 ...
@@ -300,12 +307,25 @@ class CustomAllreduce:
             return False
         # for 4 or more non NVLink-capable GPUs, custom allreduce provides
         # little performance improvement over NCCL.
-        if self.world_size == 2 or self.full_nvlink:
-            return inp_size < self.max_size
+        if self.world_size == 2:
+            return (
+                inp_size < self.max_size
+                and inp_size < self.max_required_workspace_size[0]
+            )
+
+        if self.full_nvlink:
+            return (
+                inp_size < self.max_size
+                and inp_size < self.max_required_workspace_size[1]
+            )
+
         return False
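For reference, the new gating reads: a 2-rank group may use custom allreduce for inputs below the 16 MiB TensorRT-LLM workspace bound, a larger fully NVLink-connected group only below 8 MiB, and everything else falls back to NCCL. A standalone sketch of that decision (illustrative, not code from the commit; the function name and the 8 MiB max_size default are assumptions):

    # custom_allreduce_gating_sketch.py -- illustrative only
    MAX_REQUIRED_WORKSPACE_SIZE = (16 * 1024 * 1024, 8 * 1024 * 1024)  # values from the diff


    def should_use_custom_allreduce(
        inp_size: int,
        world_size: int,
        full_nvlink: bool,
        max_size: int = 8 * 1024 * 1024,  # assumed registration-buffer size
    ) -> bool:
        """Mirror of the gating above: index 0 for 2-rank groups, index 1 for full NVLink."""
        if world_size == 2:
            return inp_size < max_size and inp_size < MAX_REQUIRED_WORKSPACE_SIZE[0]
        if full_nvlink:
            return inp_size < max_size and inp_size < MAX_REQUIRED_WORKSPACE_SIZE[1]
        return False


    if __name__ == "__main__":
        # e.g. a 4 MiB message on an 8-GPU NVLink node is eligible, a 12 MiB one is not
        print(should_use_custom_allreduce(4 * 1024 * 1024, world_size=8, full_nvlink=True))   # True
        print(should_use_custom_allreduce(12 * 1024 * 1024, world_size=8, full_nvlink=True))  # False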
-    def all_reduce(
-        self,
-        inp: torch.Tensor,
-        *,
-        out: torch.Tensor = None,
-        registered: bool = False,
-    ):
+    def all_reduce(
+        self,
+        inp: torch.Tensor,
+        *,
+        out: torch.Tensor = None,
+    ):
         """Performs an out-of-place all reduce.
 ...
@@ -315,12 +335,7 @@ class CustomAllreduce:
         """
         if out is None:
             out = torch.empty_like(inp)
-        if registered:
-            ops.all_reduce(self._ptr, inp, out, 0, 0)
-        else:
-            ops.all_reduce(
-                self._ptr, inp, out, self.buffer_ptrs[self.rank], self.max_size
-            )
+        ops.all_reduce(self._ptr, inp, out)
         return out

     def custom_all_reduce(self, input: torch.Tensor) -> Optional[torch.Tensor]:
 ...
@@ -330,23 +345,22 @@ class CustomAllreduce:
             return None
         if self._IS_CAPTURING:
             if torch.cuda.is_current_stream_capturing():
-                return self.all_reduce(input, registered=True)
+                return self.all_reduce(input)
             else:
                 # If warm up, mimic the allocation pattern since custom
                 # allreduce is out-of-place.
                 return torch.empty_like(input)
         else:
             # Note: outside of cuda graph context, custom allreduce incurs a
             # cost of cudaMemcpy, which should be small (<=1% of overall
             # latency) compared to the performance gain of using custom kernels
-            return self.all_reduce(input, registered=False)
+            return self.all_reduce(input)

     def close(self):
         if not self.disabled and self._ptr:
             ops.dispose(self._ptr)
-            self._ptr = 0
-            self.free_shared_buffer(self.meta_ptrs)
             self.free_shared_buffer(self.buffer_ptrs)
+            self.free_shared_buffer(self.tmp_result_buffer_ptrs)
+            self.free_shared_buffer(self.barrier_in_ptrs)
+            self.free_shared_buffer(self.barrier_out_ptrs)
+            self._ptr = 0

     def __del__(self):
         self.close()
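The constants introduced in __init__ above work out to small, fixed allocations per rank. A worked sketch of the arithmetic (values copied from the diff; the MAX_* names are only implied by the comment in the code, not defined in this file):

    # barrier_sizing_sketch.py -- illustrative only
    MAX_ALL_REDUCE_BLOCKS = 36  # implied by the (36 + 2) term in the diff
    MAX_RANKS_PER_NODE = 8

    # Matches: self.barrier_max_size = 8 * (36 + 2) * 8
    barrier_max_size = 8 * (MAX_ALL_REDUCE_BLOCKS + 2) * MAX_RANKS_PER_NODE
    assert barrier_max_size == 2432  # bytes per barrier-in / barrier-out buffer

    # Matches: self.max_required_workspace_size = [16 * 1024 * 1024, 8 * 1024 * 1024]
    max_required_workspace_size = [16 * 1024 * 1024, 8 * 1024 * 1024]
    assert max_required_workspace_size == [16777216, 8388608]  # 16 MiB and 8 MiB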
test/srt/run_suite.py

@@ -12,6 +12,7 @@ suites = {
         "sampling/penaltylib",
         "test_abort.py",
         "test_chunked_prefill.py",
+        "test_custom_allreduce.py",
         "test_double_sparsity.py",
         "test_eagle_infer.py",
         "test_embedding_openai_server.py",
 ...
test/srt/test_custom_allreduce.py (new file, mode 100644)

import os
import random
import socket
import unittest
from typing import Any

import ray
import torch
import torch.distributed as dist

from sglang.srt.distributed import init_distributed_environment
from sglang.srt.distributed.communication_op import (  # noqa
    tensor_model_parallel_all_reduce,
)
from sglang.srt.distributed.parallel_state import (
    get_tensor_model_parallel_group,
    graph_capture,
    initialize_model_parallel,
)


def get_open_port() -> int:
    # try ipv4
    try:
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            s.bind(("", 0))
            return s.getsockname()[1]
    except OSError:
        # try ipv6
        with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s:
            s.bind(("", 0))
            return s.getsockname()[1]


def multi_process_parallel(
    world_size: int,
    cls: Any,
    test_target: Any,
) -> None:
    # Using ray helps debugging the error when it failed
    # as compared to multiprocessing.
    # NOTE: We need to set working_dir for distributed tests,
    # otherwise we may get import errors on ray workers
    ray.init(log_to_driver=False)

    distributed_init_port = get_open_port()
    refs = []
    for rank in range(world_size):
        refs.append(test_target.remote(cls, world_size, rank, distributed_init_port))
    ray.get(refs)

    ray.shutdown()


class TestCustomAllReduce(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        random.seed(42)
        # 512B to 32MB
        cls.test_sizes = [512, 4096, 32768, 262144, 2097152, 16777216, 33554432]
        cls.world_sizes = [2, 4, 6, 8]
        cls.test_loop = 10

    def test_graph_allreduce(self):
        for world_size in self.world_sizes:
            if world_size > torch.cuda.device_count():
                continue
            multi_process_parallel(world_size, self, self.graph_allreduce)

    def test_eager_allreduce(self):
        for world_size in self.world_sizes:
            if world_size > torch.cuda.device_count():
                continue
            multi_process_parallel(world_size, self, self.eager_allreduce)

    @ray.remote(num_gpus=1, max_calls=1)
    def graph_allreduce(self, world_size, rank, distributed_init_port):
        del os.environ["CUDA_VISIBLE_DEVICES"]
        device = torch.device(f"cuda:{rank}")
        torch.cuda.set_device(device)
        distributed_init_method = f"tcp://localhost:{distributed_init_port}"
        init_distributed_environment(
            world_size=world_size,
            rank=rank,
            distributed_init_method=distributed_init_method,
            local_rank=rank,
        )
        initialize_model_parallel(tensor_model_parallel_size=world_size)
        group = get_tensor_model_parallel_group().device_group

        # A small all_reduce for warmup.
        # this is needed because device communicators might be created lazily
        # (e.g. NCCL). This will ensure that the communicator is initialized
        # before any communication happens, so that this group can be used for
        # graph capture immediately.
        data = torch.zeros(1)
        data = data.to(device=device)
        torch.distributed.all_reduce(data, group=group)
        torch.cuda.synchronize()
        del data

        for sz in self.test_sizes:
            for dtype in [torch.float32, torch.float16, torch.bfloat16]:
                for _ in range(self.test_loop):
                    with graph_capture() as graph_capture_context:
                        # use integers so result matches NCCL exactly
                        inp1 = torch.randint(
                            1,
                            16,
                            (sz,),
                            dtype=dtype,
                            device=torch.cuda.current_device(),
                        )
                        inp2 = torch.randint(
                            1,
                            16,
                            (sz,),
                            dtype=dtype,
                            device=torch.cuda.current_device(),
                        )
                        torch.cuda.synchronize()
                        graph = torch.cuda.CUDAGraph()
                        with torch.cuda.graph(graph, stream=graph_capture_context.stream):
                            out1 = tensor_model_parallel_all_reduce(inp1)
                            # the input buffer is immediately modified to test
                            # synchronization
                            dist.all_reduce(inp1, group=group)
                            out2 = tensor_model_parallel_all_reduce(inp2)
                            dist.all_reduce(inp2, group=group)
                    graph.replay()
                    torch.testing.assert_close(out1, inp1)
                    torch.testing.assert_close(out2, inp2)

    @ray.remote(num_gpus=1, max_calls=1)
    def eager_allreduce(self, world_size, rank, distributed_init_port):
        del os.environ["CUDA_VISIBLE_DEVICES"]
        device = torch.device(f"cuda:{rank}")
        torch.cuda.set_device(device)
        distributed_init_method = f"tcp://localhost:{distributed_init_port}"
        init_distributed_environment(
            world_size=world_size,
            rank=rank,
            distributed_init_method=distributed_init_method,
            local_rank=rank,
        )
        initialize_model_parallel(tensor_model_parallel_size=world_size)
        group = get_tensor_model_parallel_group().device_group

        for sz in self.test_sizes:
            for dtype in [torch.float32, torch.float16, torch.bfloat16]:
                for _ in range(self.test_loop):
                    inp1 = torch.randint(
                        1, 16, (sz,), dtype=dtype, device=torch.cuda.current_device()
                    )
                    out1 = tensor_model_parallel_all_reduce(inp1)
                    dist.all_reduce(inp1, group=group)
                    torch.testing.assert_close(out1, inp1)


if __name__ == "__main__":
    unittest.main()
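To exercise the new path locally, the file can be run directly with python3 (it needs at least two CUDA GPUs plus ray installed; world sizes larger than the available GPU count are simply skipped). A small illustrative runner, assuming it is launched from the repository root:

    # run_custom_allreduce_tests.py -- illustrative only
    import unittest

    suite = unittest.defaultTestLoader.discover(
        "test/srt", pattern="test_custom_allreduce.py"
    )
    unittest.TextTestRunner(verbosity=2).run(suite)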