Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
978b3974
Unverified
Commit
978b3974
authored
Nov 22, 2024
by
Tyler Michael Smith
Committed by
GitHub
Nov 22, 2024
Browse files
[Misc] Add pynccl wrappers for all_gather and reduce_scatter (#9432)
parent
ebda5196
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
155 additions
and
0 deletions
+155
-0
tests/distributed/test_pynccl.py
tests/distributed/test_pynccl.py
+69
-0
vllm/distributed/device_communicators/pynccl.py
vllm/distributed/device_communicators/pynccl.py
+42
-0
vllm/distributed/device_communicators/pynccl_wrapper.py
vllm/distributed/device_communicators/pynccl_wrapper.py
+44
-0
No files found.
tests/distributed/test_pynccl.py
View file @
978b3974
...
...
@@ -150,6 +150,75 @@ def worker_fn_with_cudagraph():
assert
a
.
mean
().
cpu
().
item
()
==
pynccl_comm
.
world_size
**
1
@
worker_fn_wrapper
def
all_gather_worker_fn
():
pynccl_comm
=
PyNcclCommunicator
(
get_world_group
().
cpu_group
,
device
=
get_world_group
().
device
)
rank
=
pynccl_comm
.
rank
world_size
=
pynccl_comm
.
world_size
device
=
f
'cuda:
{
pynccl_comm
.
rank
}
'
num_elems
=
1000
tensor
=
torch
.
arange
(
num_elems
,
dtype
=
torch
.
float32
,
device
=
device
)
+
rank
*
num_elems
result
=
torch
.
zeros
(
num_elems
*
world_size
,
dtype
=
torch
.
float32
,
device
=
device
)
expected
=
torch
.
cat
([
torch
.
arange
(
num_elems
,
dtype
=
torch
.
float32
)
+
r
*
num_elems
for
r
in
range
(
world_size
)
]).
to
(
device
)
with
pynccl_comm
.
change_state
(
enable
=
True
):
pynccl_comm
.
all_gather
(
result
,
tensor
)
torch
.
testing
.
assert_close
(
result
,
expected
,
rtol
=
1e-5
,
atol
=
1e-8
)
@
pytest
.
mark
.
skipif
(
torch
.
cuda
.
device_count
()
<
2
,
reason
=
"Need at least 2 GPUs to run the test."
)
def
test_pynccl_all_gather
():
distributed_run
(
all_gather_worker_fn
,
2
)
@
worker_fn_wrapper
def
reduce_scatter_worker_fn
():
pynccl_comm
=
PyNcclCommunicator
(
get_world_group
().
cpu_group
,
device
=
get_world_group
().
device
)
rank
=
pynccl_comm
.
rank
world_size
=
pynccl_comm
.
world_size
device
=
f
'cuda:
{
pynccl_comm
.
rank
}
'
num_elems
=
1000
tensor
=
torch
.
arange
(
num_elems
,
dtype
=
torch
.
float32
,
device
=
device
)
+
rank
*
num_elems
assert
(
num_elems
%
world_size
==
0
)
result
=
torch
.
zeros
(
num_elems
//
world_size
,
dtype
=
torch
.
float32
,
device
=
device
)
# Calculate expected result for this rank's chunk
scattered_size
=
num_elems
//
world_size
all_tensors
=
[
torch
.
arange
(
num_elems
,
dtype
=
torch
.
float32
)
+
r
*
num_elems
for
r
in
range
(
world_size
)
]
expected
=
sum
(
tensor
[
rank
*
scattered_size
:(
rank
+
1
)
*
scattered_size
]
for
tensor
in
all_tensors
).
to
(
device
)
with
pynccl_comm
.
change_state
(
enable
=
True
):
pynccl_comm
.
reduce_scatter
(
result
,
tensor
)
torch
.
testing
.
assert_close
(
result
,
expected
,
rtol
=
1e-5
,
atol
=
1e-8
)
@
pytest
.
mark
.
skipif
(
torch
.
cuda
.
device_count
()
<
2
,
reason
=
"Need at least 2 GPUs to run the test."
)
def
test_pynccl_reduce_scatter
():
distributed_run
(
reduce_scatter_worker_fn
,
2
)
@
pytest
.
mark
.
skipif
(
torch
.
cuda
.
device_count
()
<
2
,
reason
=
"Need at least 2 GPUs to run the test."
)
def
test_pynccl_with_cudagraph
():
...
...
vllm/distributed/device_communicators/pynccl.py
View file @
978b3974
...
...
@@ -131,6 +131,48 @@ class PyNcclCommunicator:
ncclRedOpTypeEnum
.
from_torch
(
op
),
self
.
comm
,
cudaStream_t
(
stream
.
cuda_stream
))
def
all_gather
(
self
,
output_tensor
:
torch
.
Tensor
,
input_tensor
:
torch
.
Tensor
,
stream
=
None
):
if
self
.
disabled
:
return
# nccl communicator created on a specific device
# will only work on tensors on the same device
# otherwise it will cause "illegal memory access"
assert
input_tensor
.
device
==
self
.
device
,
(
f
"this nccl communicator is created to work on
{
self
.
device
}
, "
f
"but the input tensor is on
{
input_tensor
.
device
}
"
)
if
stream
is
None
:
stream
=
self
.
stream
self
.
nccl
.
ncclAllGather
(
buffer_type
(
input_tensor
.
data_ptr
()),
buffer_type
(
output_tensor
.
data_ptr
()),
input_tensor
.
numel
(),
ncclDataTypeEnum
.
from_torch
(
input_tensor
.
dtype
),
self
.
comm
,
cudaStream_t
(
stream
.
cuda_stream
))
def
reduce_scatter
(
self
,
output_tensor
:
torch
.
Tensor
,
input_tensor
:
torch
.
Tensor
,
op
:
ReduceOp
=
ReduceOp
.
SUM
,
stream
=
None
):
if
self
.
disabled
:
return
# nccl communicator created on a specific device
# will only work on tensors on the same device
# otherwise it will cause "illegal memory access"
assert
input_tensor
.
device
==
self
.
device
,
(
f
"this nccl communicator is created to work on
{
self
.
device
}
, "
f
"but the input tensor is on
{
input_tensor
.
device
}
"
)
if
stream
is
None
:
stream
=
self
.
stream
self
.
nccl
.
ncclReduceScatter
(
buffer_type
(
input_tensor
.
data_ptr
()),
buffer_type
(
output_tensor
.
data_ptr
()),
output_tensor
.
numel
(),
ncclDataTypeEnum
.
from_torch
(
input_tensor
.
dtype
),
ncclRedOpTypeEnum
.
from_torch
(
op
),
self
.
comm
,
cudaStream_t
(
stream
.
cuda_stream
))
def
send
(
self
,
tensor
:
torch
.
Tensor
,
dst
:
int
,
stream
=
None
):
if
self
.
disabled
:
return
...
...
vllm/distributed/device_communicators/pynccl_wrapper.py
View file @
978b3974
...
...
@@ -151,6 +151,28 @@ class NCCLLibrary:
ncclRedOp_t
,
ncclComm_t
,
cudaStream_t
]),
# ncclResult_t ncclAllGather(
# const void* sendbuff, void* recvbuff, size_t count,
# ncclDataType_t datatype, ncclComm_t comm,
# cudaStream_t stream);
# note that cudaStream_t is a pointer type, so the last argument
# is a pointer
Function
(
"ncclAllGather"
,
ncclResult_t
,
[
buffer_type
,
buffer_type
,
ctypes
.
c_size_t
,
ncclDataType_t
,
ncclComm_t
,
cudaStream_t
]),
# ncclResult_t ncclReduceScatter(
# const void* sendbuff, void* recvbuff, size_t count,
# ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm,
# cudaStream_t stream);
# note that cudaStream_t is a pointer type, so the last argument
# is a pointer
Function
(
"ncclReduceScatter"
,
ncclResult_t
,
[
buffer_type
,
buffer_type
,
ctypes
.
c_size_t
,
ncclDataType_t
,
ncclRedOp_t
,
ncclComm_t
,
cudaStream_t
]),
# ncclResult_t ncclSend(
# const void* sendbuff, size_t count, ncclDataType_t datatype,
# int dest, ncclComm_t comm, cudaStream_t stream);
...
...
@@ -258,6 +280,28 @@ class NCCLLibrary:
datatype
,
op
,
comm
,
stream
))
def
ncclReduceScatter
(
self
,
sendbuff
:
buffer_type
,
recvbuff
:
buffer_type
,
count
:
int
,
datatype
:
int
,
op
:
int
,
comm
:
ncclComm_t
,
stream
:
cudaStream_t
)
->
None
:
# `datatype` actually should be `ncclDataType_t`
# and `op` should be `ncclRedOp_t`
# both are aliases of `ctypes.c_int`
# when we pass int to a function, it will be converted to `ctypes.c_int`
# by ctypes automatically
self
.
NCCL_CHECK
(
self
.
_funcs
[
"ncclReduceScatter"
](
sendbuff
,
recvbuff
,
count
,
datatype
,
op
,
comm
,
stream
))
def
ncclAllGather
(
self
,
sendbuff
:
buffer_type
,
recvbuff
:
buffer_type
,
count
:
int
,
datatype
:
int
,
comm
:
ncclComm_t
,
stream
:
cudaStream_t
)
->
None
:
# `datatype` actually should be `ncclDataType_t`
# which is an aliases of `ctypes.c_int`
# when we pass int to a function, it will be converted to `ctypes.c_int`
# by ctypes automatically
self
.
NCCL_CHECK
(
self
.
_funcs
[
"ncclAllGather"
](
sendbuff
,
recvbuff
,
count
,
datatype
,
comm
,
stream
))
def
ncclSend
(
self
,
sendbuff
:
buffer_type
,
count
:
int
,
datatype
:
int
,
dest
:
int
,
comm
:
ncclComm_t
,
stream
:
cudaStream_t
)
->
None
:
self
.
NCCL_CHECK
(
self
.
_funcs
[
"ncclSend"
](
sendbuff
,
count
,
datatype
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment