Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
21fe7b48
Unverified
Commit
21fe7b48
authored
Dec 02, 2024
by
youkaichao
Committed by
GitHub
Dec 03, 2024
Browse files
[core][distributed] add pynccl broadcast (#10843)
Signed-off-by:
youkaichao
<
youkaichao@gmail.com
>
parent
a4cf2561
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
78 additions
and
2 deletions
+78
-2
tests/distributed/test_pynccl.py
tests/distributed/test_pynccl.py
+43
-2
vllm/distributed/device_communicators/pynccl.py
vllm/distributed/device_communicators/pynccl.py
+19
-0
vllm/distributed/device_communicators/pynccl_wrapper.py
vllm/distributed/device_communicators/pynccl_wrapper.py
+16
-0
No files found.
tests/distributed/test_pynccl.py
View file @
21fe7b48
...
...
@@ -61,6 +61,7 @@ def worker_fn():
dtype
=
torch
.
float32
).
cuda
(
pynccl_comm
.
rank
)
with
pynccl_comm
.
change_state
(
enable
=
True
):
tensor
=
pynccl_comm
.
all_reduce
(
tensor
)
torch
.
cuda
.
synchronize
()
result
=
tensor
.
mean
().
cpu
().
item
()
assert
result
==
pynccl_comm
.
world_size
...
...
@@ -86,10 +87,12 @@ def multiple_allreduce_worker_fn():
if
torch
.
distributed
.
get_rank
()
in
[
0
,
1
]:
tensor
=
pynccl_comm
.
all_reduce
(
tensor
)
tensor
=
pynccl_comm
.
all_reduce
(
tensor
)
torch
.
cuda
.
synchronize
()
result
=
tensor
.
mean
().
cpu
().
item
()
assert
result
==
4
else
:
tensor
=
pynccl_comm
.
all_reduce
(
tensor
)
torch
.
cuda
.
synchronize
()
result
=
tensor
.
mean
().
cpu
().
item
()
assert
result
==
2
...
...
@@ -112,10 +115,12 @@ def multiple_allreduce_with_vllm_worker_fn():
if
torch
.
distributed
.
get_rank
()
in
[
0
,
1
]:
tensor
=
tensor_model_parallel_all_reduce
(
tensor
)
tensor
=
tensor_model_parallel_all_reduce
(
tensor
)
torch
.
cuda
.
synchronize
()
result
=
tensor
.
mean
().
cpu
().
item
()
assert
result
==
4
else
:
tensor
=
tensor_model_parallel_all_reduce
(
tensor
)
torch
.
cuda
.
synchronize
()
result
=
tensor
.
mean
().
cpu
().
item
()
assert
result
==
2
...
...
@@ -141,9 +146,9 @@ def worker_fn_with_cudagraph():
graph
,
stream
=
pynccl_comm
.
stream
),
pynccl_comm
.
change_state
(
enable
=
True
):
a_out
=
pynccl_comm
.
all_reduce
(
a
)
pynccl_comm
.
stream
.
synchronize
()
torch
.
cuda
.
synchronize
()
graph
.
replay
()
pynccl_comm
.
stream
.
synchronize
()
torch
.
cuda
.
synchronize
()
assert
a_out
.
mean
().
cpu
().
item
()
==
pynccl_comm
.
world_size
**
1
...
...
@@ -170,6 +175,7 @@ def all_gather_worker_fn():
with
pynccl_comm
.
change_state
(
enable
=
True
):
pynccl_comm
.
all_gather
(
result
,
tensor
)
torch
.
cuda
.
synchronize
()
torch
.
testing
.
assert_close
(
result
,
expected
,
rtol
=
1e-5
,
atol
=
1e-8
)
...
...
@@ -207,6 +213,7 @@ def reduce_scatter_worker_fn():
with
pynccl_comm
.
change_state
(
enable
=
True
):
pynccl_comm
.
reduce_scatter
(
result
,
tensor
)
torch
.
cuda
.
synchronize
()
torch
.
testing
.
assert_close
(
result
,
expected
,
rtol
=
1e-5
,
atol
=
1e-8
)
...
...
@@ -241,6 +248,7 @@ def send_recv_worker_fn():
pynccl_comm
.
recv
(
tensor
,
src
=
(
pynccl_comm
.
rank
-
1
)
%
pynccl_comm
.
world_size
)
torch
.
cuda
.
synchronize
()
result
=
tensor
.
mean
().
cpu
().
item
()
assert
result
==
1
...
...
@@ -280,6 +288,7 @@ def multiple_send_recv_worker_fn():
pynccl_comm
.
recv
(
tensor
,
src
=
(
pynccl_comm
.
rank
-
1
)
%
pynccl_comm
.
world_size
)
torch
.
cuda
.
synchronize
()
result
=
tensor
.
mean
().
cpu
().
item
()
if
torch
.
distributed
.
get_rank
()
in
[
0
,
2
]:
assert
result
==
1
...
...
@@ -293,6 +302,38 @@ def test_pynccl_multiple_send_recv():
distributed_run
(
multiple_send_recv_worker_fn
,
4
)
@
pytest
.
mark
.
skipif
(
torch
.
cuda
.
device_count
()
<
4
,
reason
=
"Need at least 4 GPUs to run the test."
)
def
test_pynccl_broadcast
():
distributed_run
(
broadcast_worker_fn
,
4
)
@
worker_fn_wrapper
def
broadcast_worker_fn
():
# Test broadcast for every root rank.
# Essentially this is an all-gather operation.
pynccl_comm
=
PyNcclCommunicator
(
get_world_group
().
cpu_group
,
device
=
get_world_group
().
device
)
recv_tensors
=
[
torch
.
empty
(
16
,
1024
,
1024
,
dtype
=
torch
.
float32
,
device
=
pynccl_comm
.
device
)
for
i
in
range
(
pynccl_comm
.
world_size
)
]
recv_tensors
[
pynccl_comm
.
rank
]
=
torch
.
ones
(
16
,
1024
,
1024
,
dtype
=
torch
.
float32
,
device
=
pynccl_comm
.
device
)
*
pynccl_comm
.
rank
for
i
in
range
(
pynccl_comm
.
world_size
):
pynccl_comm
.
broadcast
(
recv_tensors
[
i
],
src
=
i
)
# the broadcast op might be launched in a different stream
# need to synchronize to make sure the tensor is ready
torch
.
cuda
.
synchronize
()
assert
torch
.
all
(
recv_tensors
[
i
]
==
i
).
cpu
().
item
()
def
test_ncclGetUniqueId
():
lib
=
NCCLLibrary
()
unique_id
=
lib
.
ncclGetUniqueId
()
...
...
vllm/distributed/device_communicators/pynccl.py
View file @
21fe7b48
...
...
@@ -197,6 +197,25 @@ class PyNcclCommunicator:
ncclDataTypeEnum
.
from_torch
(
tensor
.
dtype
),
src
,
self
.
comm
,
cudaStream_t
(
stream
.
cuda_stream
))
def
broadcast
(
self
,
tensor
:
torch
.
Tensor
,
src
:
int
,
stream
=
None
):
if
self
.
disabled
:
return
assert
tensor
.
device
==
self
.
device
,
(
f
"this nccl communicator is created to work on
{
self
.
device
}
, "
f
"but the input tensor is on
{
tensor
.
device
}
"
)
if
stream
is
None
:
stream
=
self
.
stream
if
src
==
self
.
rank
:
sendbuff
=
buffer_type
(
tensor
.
data_ptr
())
# NCCL requires the sender also to have a receive buffer
recvbuff
=
buffer_type
(
tensor
.
data_ptr
())
else
:
sendbuff
=
buffer_type
()
recvbuff
=
buffer_type
(
tensor
.
data_ptr
())
self
.
nccl
.
ncclBroadcast
(
sendbuff
,
recvbuff
,
tensor
.
numel
(),
ncclDataTypeEnum
.
from_torch
(
tensor
.
dtype
),
src
,
self
.
comm
,
cudaStream_t
(
stream
.
cuda_stream
))
@
contextmanager
def
change_state
(
self
,
enable
:
Optional
[
bool
]
=
None
,
...
...
vllm/distributed/device_communicators/pynccl_wrapper.py
View file @
21fe7b48
...
...
@@ -189,6 +189,15 @@ class NCCLLibrary:
ncclComm_t
,
cudaStream_t
]),
# ncclResult_t ncclBroadcast(
# const void* sendbuff, void* recvbuff, size_t count,
# ncclDataType_t datatype, int root, ncclComm_t comm,
# cudaStream_t stream);
Function
(
"ncclBroadcast"
,
ncclResult_t
,
[
buffer_type
,
buffer_type
,
ctypes
.
c_size_t
,
ncclDataType_t
,
ctypes
.
c_int
,
ncclComm_t
,
cudaStream_t
]),
# be cautious! this is a collective call, it will block until all
# processes in the communicator have called this function.
# because Python object destruction can happen in random order,
...
...
@@ -312,6 +321,13 @@ class NCCLLibrary:
self
.
NCCL_CHECK
(
self
.
_funcs
[
"ncclRecv"
](
recvbuff
,
count
,
datatype
,
src
,
comm
,
stream
))
def
ncclBroadcast
(
self
,
sendbuff
:
buffer_type
,
recvbuff
:
buffer_type
,
count
:
int
,
datatype
:
int
,
root
:
int
,
comm
:
ncclComm_t
,
stream
:
cudaStream_t
)
->
None
:
self
.
NCCL_CHECK
(
self
.
_funcs
[
"ncclBroadcast"
](
sendbuff
,
recvbuff
,
count
,
datatype
,
root
,
comm
,
stream
))
def
ncclCommDestroy
(
self
,
comm
:
ncclComm_t
)
->
None
:
self
.
NCCL_CHECK
(
self
.
_funcs
[
"ncclCommDestroy"
](
comm
))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment