Commit 767c9dec (unverified), authored Jan 16, 2025 by yizhang2077, committed by GitHub on Jan 16, 2025

adapt custom allreduce for tensorrt llm (#2511)

Parent: a53454c5

Showing 5 changed files with 241 additions and 67 deletions (+241, -67):

  python/pyproject.toml                                                      +1    -1
  python/sglang/srt/_custom_ops.py                                           +22   -27
  python/sglang/srt/distributed/device_communicators/custom_all_reduce.py    +53   -39
  test/srt/run_suite.py                                                      +1    -0
  test/srt/test_custom_allreduce.py                                          +164  -0
python/pyproject.toml

@@ -27,7 +27,7 @@ runtime_common = [
 ]
 srt = [
   "sglang[runtime_common]", "cuda-python",
-  "sgl-kernel>=0.0.2.post12", "torch", "vllm>=0.6.3.post1,<=0.6.4.post1",
+  "sgl-kernel>=0.0.2.post14", "torch", "vllm>=0.6.3.post1,<=0.6.4.post1",
   "flashinfer==0.1.6"
 ]
 ...
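The packaging change only raises the sgl-kernel floor from 0.0.2.post12 to 0.0.2.post14, presumably so the new custom-reduce entry points used below are available. As a minimal illustration (not part of the commit), PEP 440 post-releases compare numerically, so the new bound strictly supersedes the old one:

# Illustrative only: confirm the new lower bound supersedes the old one (PEP 440).
from packaging.version import Version

old_bound = Version("0.0.2.post12")
new_bound = Version("0.0.2.post14")
print(new_bound > old_bound)  # True: post14 orders after post12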
python/sglang/srt/_custom_ops.py

-# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/_custom_ops.py
+# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/_custom_ops.py
 import contextlib
 import functools
 import importlib
 ...

@@ -14,7 +14,7 @@ logger = logging.getLogger(__name__)
 if not is_hpu():
     try:
-        import custom_ar
+        import sgl_kernel
     except ImportError as e:
         logger.warning("Failed to import from custom_ar with %r", e)
 ...

@@ -50,46 +50,41 @@ def hint_on_error(fn):
 # custom ar
-def init_custom_ar(
-    ipc_tensors: List[torch.Tensor],
-    rank_data: torch.Tensor,
-    rank: int,
-    full_nvlink: bool,
-) -> int:
-    return torch.ops._C_vllm_ar.init_custom_ar(
-        ipc_tensors, rank_data, rank, full_nvlink
-    )
+def init_custom_ar(
+    rank_id: int,
+    world_size: int,
+    rank_data_base: torch.Tensor,
+    buffers: List[int],
+    tmp_result_buffers: List[int],
+    barrier_in: List[int],
+    barrier_out: List[int],
+) -> int:
+    return sgl_kernel.ops.init_custom_reduce(
+        rank_id,
+        world_size,
+        rank_data_base,
+        buffers,
+        tmp_result_buffers,
+        barrier_in,
+        barrier_out,
+    )


-def all_reduce(
-    fa: int,
-    inp: torch.Tensor,
-    out: torch.Tensor,
-    reg_buffer: int,
-    reg_buffer_sz_bytes: int,
-) -> None:
-    torch.ops._C_vllm_ar.all_reduce(fa, inp, out, reg_buffer, reg_buffer_sz_bytes)
+def all_reduce(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None:
+    sgl_kernel.ops.custom_reduce(fa, inp, out)


 def dispose(fa: int) -> None:
-    torch.ops._C_vllm_ar.dispose(fa)
-
-
-def meta_size() -> int:
-    return torch.ops._C_vllm_ar.meta_size()
-
-
-def register_buffer(fa: int, ipc_tensors: List[int]) -> None:
-    return torch.ops._C_vllm_ar.register_buffer(fa, ipc_tensors)
+    sgl_kernel.ops.custom_dispose(fa)


 def get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[int], List[int]]:
-    return torch.ops._C_vllm_ar.get_graph_buffer_ipc_meta(fa)
+    return sgl_kernel.ops.get_graph_buffer_ipc_meta(fa)


 def register_graph_buffers(
     fa: int, handles: List[List[int]], offsets: List[List[int]]
 ) -> None:
-    torch.ops._C_vllm_ar.register_graph_buffers(fa, handles, offsets)
+    sgl_kernel.ops.register_graph_buffers(fa, handles, offsets)


 # temporary fix for https://github.com/vllm-project/vllm/issues/5456
 ...
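The wrappers above are thin pass-throughs to sgl_kernel that replace the old torch.ops._C_vllm_ar bindings; meta_size() and register_buffer() have no sgl_kernel counterpart and are simply dropped. Below is a runnable sketch of the call order the new wrappers expose (init_custom_reduce, then custom_reduce, then custom_dispose). sgl_kernel is mocked so the snippet runs without a GPU, and the buffer/barrier arguments are placeholder integers standing in for the CUDA IPC pointers that CustomAllreduce normally supplies.

# Sketch of the new wrapper call order; sgl_kernel is mocked, pointers are placeholders.
import sys
from unittest import mock

import torch

sys.modules.setdefault("sgl_kernel", mock.MagicMock())  # stand-in for the real extension
import sgl_kernel

rank, world_size = 0, 2
rank_data_base = torch.empty(8 * 1024 * 1024, dtype=torch.uint8)  # CPU stand-in
buffers = tmp_result_buffers = barrier_in = barrier_out = [0] * world_size  # fake IPC ptrs

fa = sgl_kernel.ops.init_custom_reduce(  # opaque handle (an int on real hardware)
    rank, world_size, rank_data_base, buffers, tmp_result_buffers, barrier_in, barrier_out
)
inp, out = torch.ones(16), torch.empty(16)
sgl_kernel.ops.custom_reduce(fa, inp, out)  # out-of-place all-reduce into `out`
sgl_kernel.ops.custom_dispose(fa)           # release the communicator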
python/sglang/srt/distributed/device_communicators/custom_all_reduce.py

@@ -21,7 +21,8 @@ from sglang.srt.distributed.parallel_state import in_the_same_node_as
 from sglang.srt.utils import cuda_device_count_stateless, is_cuda

 try:
-    ops.meta_size()
+    import sgl_kernel
+
     custom_ar = True
 except Exception:
     # For AMD GPUs and CPUs
 ...

@@ -29,7 +30,6 @@ except Exception:
 logger = logging.getLogger(__name__)

 _P = ParamSpec("_P")
 _R = TypeVar("_R")
 ...

@@ -47,7 +47,7 @@ def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]:
     @with_nvml_context
-    def is_full_nvlink(cls, physical_device_ids: List[int]) -> bool:
+    def is_full_nvlink(physical_device_ids: List[int]) -> bool:
         """
         query if the set of gpus are fully connected by nvlink (1 hop)
         """
 ...

@@ -196,32 +196,39 @@ class CustomAllreduce:
             )
             return

-        self.disabled = False
-        # Buffers memory are owned by this Python class and passed to C++.
-        # Meta data composes of two parts: meta data for synchronization and a
-        # temporary buffer for storing intermediate allreduce results.
-        self.meta_ptrs = self.create_shared_buffer(
-            ops.meta_size() + max_size, group=group
-        )
-        # This is a pre-registered IPC buffer. In eager mode, input tensors
-        # are first copied into this buffer before allreduce is performed
-        self.buffer_ptrs = self.create_shared_buffer(max_size, group=group)
-        # This is a buffer for storing the tuples of pointers pointing to
-        # IPC buffers from all ranks. Each registered tuple has size of
-        # 8*world_size bytes where world_size is at most 8. Allocating 8MB
-        # is enough for 131072 such tuples. The largest model I've seen only
-        # needs less than 10000 of registered tuples.
-        self.rank_data = torch.empty(
-            8 * 1024 * 1024, dtype=torch.uint8, device=self.device
-        )
         self.max_size = max_size
         self.rank = rank
         self.world_size = world_size
         self.full_nvlink = full_nvlink
-        self._ptr = ops.init_custom_ar(
-            self.meta_ptrs, self.rank_data, rank, self.full_nvlink
-        )
-        ops.register_buffer(self._ptr, self.buffer_ptrs)
+
+        # From TensorRT-LLM getMaxRequiredWorkspaceSize
+        self.max_required_workspace_size = [16 * 1024 * 1024, 8 * 1024 * 1024]
+
+        # sizeof(uint32_t) * (MAX_ALL_REDUCE_BLOCKS + 2) * MAX_RANKS_PER_NODE;
+        self.barrier_max_size = 8 * (36 + 2) * 8
+
+        self.buffer_ptrs = self.create_shared_buffer(max_size, group=group)
+        self.tmp_result_buffer_ptrs = self.create_shared_buffer(max_size, group=group)
+        self.rank_data_base = torch.empty(
+            8 * 1024 * 1024, dtype=torch.uint8, device=self.device
+        )
+        self.barrier_in_ptrs = self.create_shared_buffer(
+            self.barrier_max_size, group=group
+        )
+        self.barrier_out_ptrs = self.create_shared_buffer(
+            self.barrier_max_size, group=group
+        )
+
+        self._ptr = ops.init_custom_ar(
+            rank,
+            world_size,
+            self.rank_data_base,
+            self.buffer_ptrs,
+            self.tmp_result_buffer_ptrs,
+            self.barrier_in_ptrs,
+            self.barrier_out_ptrs,
+        )
+        self.disabled = False

     @staticmethod
     def create_shared_buffer(
 ...

@@ -300,12 +307,25 @@ class CustomAllreduce:
             return False
         # for 4 or more non NVLink-capable GPUs, custom allreduce provides
         # little performance improvement over NCCL.
-        if self.world_size == 2 or self.full_nvlink:
-            return inp_size < self.max_size
+        if self.world_size == 2:
+            return (
+                inp_size < self.max_size
+                and inp_size < self.max_required_workspace_size[0]
+            )
+
+        if self.full_nvlink:
+            return (
+                inp_size < self.max_size
+                and inp_size < self.max_required_workspace_size[1]
+            )
+
         return False

     def all_reduce(
-        self, inp: torch.Tensor, *, out: torch.Tensor = None, registered: bool = False
+        self,
+        inp: torch.Tensor,
+        *,
+        out: torch.Tensor = None,
     ):
         """Performs an out-of-place all reduce.
 ...

@@ -315,12 +335,7 @@ class CustomAllreduce:
         """
         if out is None:
             out = torch.empty_like(inp)
-        if registered:
-            ops.all_reduce(self._ptr, inp, out, 0, 0)
-        else:
-            ops.all_reduce(
-                self._ptr, inp, out, self.buffer_ptrs[self.rank], self.max_size
-            )
+        ops.all_reduce(self._ptr, inp, out)
         return out

     def custom_all_reduce(self, input: torch.Tensor) -> Optional[torch.Tensor]:
 ...

@@ -330,23 +345,22 @@ class CustomAllreduce:
             return None
         if self._IS_CAPTURING:
             if torch.cuda.is_current_stream_capturing():
-                return self.all_reduce(input, registered=True)
+                return self.all_reduce(input)
             else:
                 # If warm up, mimic the allocation pattern since custom
                 # allreduce is out-of-place.
                 return torch.empty_like(input)
         else:
-            # Note: outside of cuda graph context, custom allreduce incurs a
-            # cost of cudaMemcpy, which should be small (<=1% of overall
-            # latency) compared to the performance gain of using custom kernels
-            return self.all_reduce(input, registered=False)
+            return self.all_reduce(input)

     def close(self):
         if not self.disabled and self._ptr:
             ops.dispose(self._ptr)
-            self._ptr = 0
-            self.free_shared_buffer(self.meta_ptrs)
             self.free_shared_buffer(self.buffer_ptrs)
+            self.free_shared_buffer(self.tmp_result_buffer_ptrs)
+            self.free_shared_buffer(self.barrier_in_ptrs)
+            self.free_shared_buffer(self.barrier_out_ptrs)
+            self._ptr = 0

     def __del__(self):
         self.close()
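The functional change in CustomAllreduce is the size gating: a two-rank group is allowed inputs up to the 16 MiB TensorRT-LLM workspace limit, a fully NVLinked group up to 8 MiB, and both remain capped by max_size. A standalone sketch of just that decision follows (max_size is defaulted to 8 MiB here purely as an assumption; in the class it is a constructor argument, and the real method performs additional checks before reaching this point):

# Standalone mirror of the new should_custom_ar size gating (illustrative only).
MiB = 1024 * 1024
# From TensorRT-LLM getMaxRequiredWorkspaceSize: [two-rank limit, full-NVLink limit]
MAX_REQUIRED_WORKSPACE_SIZE = [16 * MiB, 8 * MiB]

def should_use_custom_allreduce(
    inp_size: int, world_size: int, full_nvlink: bool, max_size: int = 8 * MiB
) -> bool:
    if world_size == 2:
        return inp_size < max_size and inp_size < MAX_REQUIRED_WORKSPACE_SIZE[0]
    if full_nvlink:
        return inp_size < max_size and inp_size < MAX_REQUIRED_WORKSPACE_SIZE[1]
    return False

print(should_use_custom_allreduce(4 * MiB, 8, True))   # True: under both caps
print(should_use_custom_allreduce(12 * MiB, 8, True))  # False: exceeds the 8 MiB caps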
test/srt/run_suite.py

@@ -12,6 +12,7 @@ suites = {
         "sampling/penaltylib",
         "test_abort.py",
         "test_chunked_prefill.py",
+        "test_custom_allreduce.py",
         "test_double_sparsity.py",
         "test_eagle_infer.py",
         "test_embedding_openai_server.py",
 ...
test/srt/test_custom_allreduce.py (new file, mode 100644)

import os
import random
import socket
import unittest
from typing import Any

import ray
import torch
import torch.distributed as dist

from sglang.srt.distributed import init_distributed_environment
from sglang.srt.distributed.communication_op import (  # noqa
    tensor_model_parallel_all_reduce,
)
from sglang.srt.distributed.parallel_state import (
    get_tensor_model_parallel_group,
    graph_capture,
    initialize_model_parallel,
)


def get_open_port() -> int:
    # try ipv4
    try:
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            s.bind(("", 0))
            return s.getsockname()[1]
    except OSError:
        # try ipv6
        with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s:
            s.bind(("", 0))
            return s.getsockname()[1]


def multi_process_parallel(
    world_size: int,
    cls: Any,
    test_target: Any,
) -> None:
    # Using ray helps debugging the error when it failed
    # as compared to multiprocessing.
    # NOTE: We need to set working_dir for distributed tests,
    # otherwise we may get import errors on ray workers
    ray.init(log_to_driver=False)

    distributed_init_port = get_open_port()
    refs = []
    for rank in range(world_size):
        refs.append(test_target.remote(cls, world_size, rank, distributed_init_port))
    ray.get(refs)

    ray.shutdown()


class TestCustomAllReduce(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        random.seed(42)
        # 512B to 32MB
        cls.test_sizes = [512, 4096, 32768, 262144, 2097152, 16777216, 33554432]
        cls.world_sizes = [2, 4, 6, 8]
        cls.test_loop = 10

    def test_graph_allreduce(self):
        for world_size in self.world_sizes:
            if world_size > torch.cuda.device_count():
                continue
            multi_process_parallel(world_size, self, self.graph_allreduce)

    def test_eager_allreduce(self):
        for world_size in self.world_sizes:
            if world_size > torch.cuda.device_count():
                continue
            multi_process_parallel(world_size, self, self.eager_allreduce)

    @ray.remote(num_gpus=1, max_calls=1)
    def graph_allreduce(self, world_size, rank, distributed_init_port):
        del os.environ["CUDA_VISIBLE_DEVICES"]
        device = torch.device(f"cuda:{rank}")
        torch.cuda.set_device(device)
        distributed_init_method = f"tcp://localhost:{distributed_init_port}"
        init_distributed_environment(
            world_size=world_size,
            rank=rank,
            distributed_init_method=distributed_init_method,
            local_rank=rank,
        )
        initialize_model_parallel(tensor_model_parallel_size=world_size)
        group = get_tensor_model_parallel_group().device_group

        # A small all_reduce for warmup.
        # this is needed because device communicators might be created lazily
        # (e.g. NCCL). This will ensure that the communicator is initialized
        # before any communication happens, so that this group can be used for
        # graph capture immediately.
        data = torch.zeros(1)
        data = data.to(device=device)
        torch.distributed.all_reduce(data, group=group)
        torch.cuda.synchronize()
        del data

        for sz in self.test_sizes:
            for dtype in [torch.float32, torch.float16, torch.bfloat16]:
                for _ in range(self.test_loop):
                    with graph_capture() as graph_capture_context:
                        # use integers so result matches NCCL exactly
                        inp1 = torch.randint(
                            1,
                            16,
                            (sz,),
                            dtype=dtype,
                            device=torch.cuda.current_device(),
                        )
                        inp2 = torch.randint(
                            1,
                            16,
                            (sz,),
                            dtype=dtype,
                            device=torch.cuda.current_device(),
                        )
                        torch.cuda.synchronize()
                        graph = torch.cuda.CUDAGraph()
                        with torch.cuda.graph(
                            graph, stream=graph_capture_context.stream
                        ):
                            out1 = tensor_model_parallel_all_reduce(inp1)
                            # the input buffer is immediately modified to test
                            # synchronization
                            dist.all_reduce(inp1, group=group)
                            out2 = tensor_model_parallel_all_reduce(inp2)
                            dist.all_reduce(inp2, group=group)
                    graph.replay()
                    torch.testing.assert_close(out1, inp1)
                    torch.testing.assert_close(out2, inp2)

    @ray.remote(num_gpus=1, max_calls=1)
    def eager_allreduce(self, world_size, rank, distributed_init_port):
        del os.environ["CUDA_VISIBLE_DEVICES"]
        device = torch.device(f"cuda:{rank}")
        torch.cuda.set_device(device)
        distributed_init_method = f"tcp://localhost:{distributed_init_port}"
        init_distributed_environment(
            world_size=world_size,
            rank=rank,
            distributed_init_method=distributed_init_method,
            local_rank=rank,
        )
        initialize_model_parallel(tensor_model_parallel_size=world_size)
        group = get_tensor_model_parallel_group().device_group

        for sz in self.test_sizes:
            for dtype in [torch.float32, torch.float16, torch.bfloat16]:
                for _ in range(self.test_loop):
                    inp1 = torch.randint(
                        1, 16, (sz,), dtype=dtype, device=torch.cuda.current_device()
                    )
                    out1 = tensor_model_parallel_all_reduce(inp1)
                    dist.all_reduce(inp1, group=group)
                    torch.testing.assert_close(out1, inp1)


if __name__ == "__main__":
    unittest.main()
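The new test is registered in run_suite.py above and can also be run directly with unittest. A minimal sketch for running just the eager case (assumes ray is installed, at least two CUDA devices are visible, and the working directory is test/srt/):

# Sketch: run only the eager-mode case (assumes ray + >=2 GPUs, cwd = test/srt/).
import unittest

from test_custom_allreduce import TestCustomAllReduce

suite = unittest.TestSuite()
suite.addTest(TestCustomAllReduce("test_eager_allreduce"))
unittest.TextTestRunner(verbosity=2).run(suite)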