Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
5d4d9053
Unverified
Commit
5d4d9053
authored
Jun 23, 2024
by
Murali Andoorveedu
Committed by
GitHub
Jun 23, 2024
Browse files
[Distributed] Add send and recv helpers (#5719)
parent
6c916ac8
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
278 additions
and
24 deletions
+278
-24
tests/distributed/test_comm_ops.py
tests/distributed/test_comm_ops.py
+74
-4
tests/distributed/test_custom_all_reduce.py
tests/distributed/test_custom_all_reduce.py
+2
-3
tests/distributed/test_pynccl.py
tests/distributed/test_pynccl.py
+12
-4
tests/utils.py
tests/utils.py
+1
-1
vllm/distributed/device_communicators/pynccl.py
vllm/distributed/device_communicators/pynccl.py
+2
-12
vllm/distributed/parallel_state.py
vllm/distributed/parallel_state.py
+187
-0
No files found.
tests/distributed/test_comm_ops.py
View file @
5d4d9053
...
...
@@ -8,12 +8,11 @@ import pytest
import
ray
import
torch
from
vllm.distributed
import
(
broadcast_tensor_dict
,
from
vllm.distributed
import
(
broadcast_tensor_dict
,
get_pp_group
,
tensor_model_parallel_all_gather
,
tensor_model_parallel_all_reduce
)
from
..utils
import
(
init_test_distributed_environment
,
multi_process_tensor_parallel
)
from
..utils
import
init_test_distributed_environment
,
multi_process_parallel
@
ray
.
remote
(
num_gpus
=
1
,
max_calls
=
1
)
...
...
@@ -105,6 +104,68 @@ def broadcast_tensor_dict_test_worker(tp_size: int, pp_size: int, rank: int,
assert
torch
.
allclose
(
recv_dict
[
"f"
],
test_dict
[
"f"
])
@
ray
.
remote
(
num_gpus
=
1
,
max_calls
=
1
)
def
send_recv_tensor_dict_test_worker
(
tp_size
:
int
,
pp_size
:
int
,
rank
:
int
,
distributed_init_port
:
str
):
del
os
.
environ
[
"CUDA_VISIBLE_DEVICES"
]
device
=
torch
.
device
(
f
"cuda:
{
rank
}
"
)
torch
.
cuda
.
set_device
(
device
)
init_test_distributed_environment
(
tp_size
,
pp_size
,
rank
,
distributed_init_port
)
test_dict
=
{
# device tensor
"a"
:
torch
.
arange
(
8
,
dtype
=
torch
.
float32
,
device
=
"cuda"
),
# CPU tensor
"b"
:
torch
.
arange
(
16
,
dtype
=
torch
.
int8
,
device
=
"cpu"
),
"c"
:
"test"
,
"d"
:
[
1
,
2
,
3
],
"e"
:
{
"a"
:
1
,
"b"
:
2
},
# empty tensor
"f"
:
torch
.
tensor
([],
dtype
=
torch
.
float32
,
device
=
"cuda"
),
}
if
not
get_pp_group
().
is_first_rank
:
recv_dict
=
get_pp_group
().
recv_tensor_dict
()
if
not
get_pp_group
().
is_last_rank
:
get_pp_group
().
send_tensor_dict
(
test_dict
)
if
not
get_pp_group
().
is_first_rank
:
assert
len
(
recv_dict
)
==
len
(
test_dict
)
assert
torch
.
allclose
(
recv_dict
[
"a"
],
test_dict
[
"a"
])
assert
torch
.
allclose
(
recv_dict
[
"b"
],
test_dict
[
"b"
])
assert
recv_dict
[
"c"
]
==
test_dict
[
"c"
]
assert
recv_dict
[
"d"
]
==
test_dict
[
"d"
]
assert
recv_dict
[
"e"
]
==
test_dict
[
"e"
]
assert
torch
.
allclose
(
recv_dict
[
"f"
],
test_dict
[
"f"
])
@
ray
.
remote
(
num_gpus
=
1
,
max_calls
=
1
)
def
send_recv_test_worker
(
tp_size
:
int
,
pp_size
:
int
,
rank
:
int
,
distributed_init_port
:
str
):
del
os
.
environ
[
"CUDA_VISIBLE_DEVICES"
]
device
=
torch
.
device
(
f
"cuda:
{
rank
}
"
)
torch
.
cuda
.
set_device
(
device
)
init_test_distributed_environment
(
tp_size
,
pp_size
,
rank
,
distributed_init_port
)
size
=
64
test_tensor
=
torch
.
arange
(
64
,
dtype
=
torch
.
float32
,
device
=
"cuda"
)
if
not
get_pp_group
().
is_first_rank
:
recv_tensor
=
get_pp_group
().
recv
(
size
,
dtype
=
torch
.
float32
)
if
not
get_pp_group
().
is_last_rank
:
get_pp_group
().
send
(
test_tensor
)
if
not
get_pp_group
().
is_first_rank
:
assert
torch
.
allclose
(
test_tensor
,
recv_tensor
)
@
pytest
.
mark
.
skipif
(
torch
.
cuda
.
device_count
()
<
2
,
reason
=
"Need at least 2 GPUs to run the test."
)
@
pytest
.
mark
.
parametrize
(
"tp_size"
,
[
2
])
...
...
@@ -113,4 +174,13 @@ def broadcast_tensor_dict_test_worker(tp_size: int, pp_size: int, rank: int,
broadcast_tensor_dict_test_worker
])
def
test_multi_process_tensor_parallel
(
tp_size
,
test_target
):
multi_process_tensor_parallel
(
tp_size
,
1
,
test_target
)
multi_process_parallel
(
tp_size
,
1
,
test_target
)
@
pytest
.
mark
.
skipif
(
torch
.
cuda
.
device_count
()
<
2
,
reason
=
"Need at least 2 GPUs to run the test."
)
@
pytest
.
mark
.
parametrize
(
"pp_size"
,
[
2
])
@
pytest
.
mark
.
parametrize
(
"test_target"
,
[
send_recv_test_worker
,
send_recv_tensor_dict_test_worker
])
def
test_multi_process_pipeline_parallel
(
pp_size
,
test_target
):
multi_process_parallel
(
1
,
pp_size
,
test_target
)
tests/distributed/test_custom_all_reduce.py
View file @
5d4d9053
...
...
@@ -12,8 +12,7 @@ from vllm.distributed.parallel_state import (get_tensor_model_parallel_group,
get_tp_group
,
graph_capture
)
from
..utils
import
(
ensure_model_parallel_initialized
,
init_test_distributed_environment
,
multi_process_tensor_parallel
)
init_test_distributed_environment
,
multi_process_parallel
)
random
.
seed
(
42
)
test_sizes
=
[
random
.
randint
(
1024
,
2048
*
1024
)
for
_
in
range
(
8
)]
...
...
@@ -113,4 +112,4 @@ def test_custom_allreduce(tp_size, pipeline_parallel_size, test_target):
world_size
=
tp_size
*
pipeline_parallel_size
if
world_size
>
torch
.
cuda
.
device_count
():
pytest
.
skip
(
"Not enough GPUs to run the test."
)
multi_process_
tensor_
parallel
(
tp_size
,
pipeline_parallel_size
,
test_target
)
multi_process_parallel
(
tp_size
,
pipeline_parallel_size
,
test_target
)
tests/distributed/test_pynccl.py
View file @
5d4d9053
...
...
@@ -168,9 +168,13 @@ def send_recv_worker_fn():
dtype
=
torch
.
float32
).
cuda
(
pynccl_comm
.
rank
)
with
pynccl_comm
.
change_state
(
enable
=
True
):
if
pynccl_comm
.
rank
==
0
:
pynccl_comm
.
send
(
tensor
)
pynccl_comm
.
send
(
tensor
,
dst
=
(
pynccl_comm
.
rank
+
1
)
%
pynccl_comm
.
world_size
)
else
:
pynccl_comm
.
recv
(
tensor
)
pynccl_comm
.
recv
(
tensor
,
src
=
(
pynccl_comm
.
rank
-
1
)
%
pynccl_comm
.
world_size
)
result
=
tensor
.
mean
().
cpu
().
item
()
assert
result
==
1
...
...
@@ -203,9 +207,13 @@ def multiple_send_recv_worker_fn():
device
=
device
)
with
pynccl_comm
.
change_state
(
enable
=
True
):
if
torch
.
distributed
.
get_rank
()
in
[
0
,
1
]:
pynccl_comm
.
send
(
tensor
)
pynccl_comm
.
send
(
tensor
,
dst
=
(
pynccl_comm
.
rank
+
1
)
%
pynccl_comm
.
world_size
)
else
:
pynccl_comm
.
recv
(
tensor
)
pynccl_comm
.
recv
(
tensor
,
src
=
(
pynccl_comm
.
rank
-
1
)
%
pynccl_comm
.
world_size
)
result
=
tensor
.
mean
().
cpu
().
item
()
if
torch
.
distributed
.
get_rank
()
in
[
0
,
2
]:
assert
result
==
1
...
...
tests/utils.py
View file @
5d4d9053
...
...
@@ -129,7 +129,7 @@ def init_test_distributed_environment(
ensure_model_parallel_initialized
(
tp_size
,
pp_size
)
def
multi_process_
tensor_
parallel
(
def
multi_process_parallel
(
tp_size
:
int
,
pp_size
:
int
,
test_target
,
...
...
vllm/distributed/device_communicators/pynccl.py
View file @
5d4d9053
...
...
@@ -121,10 +121,7 @@ class PyNcclCommunicator:
ncclRedOpTypeEnum
.
from_torch
(
op
),
self
.
comm
,
cudaStream_t
(
stream
.
cuda_stream
))
def
send
(
self
,
tensor
:
torch
.
Tensor
,
dst
:
Optional
[
int
]
=
None
,
stream
=
None
):
def
send
(
self
,
tensor
:
torch
.
Tensor
,
dst
:
int
,
stream
=
None
):
if
self
.
disabled
:
return
assert
tensor
.
device
==
self
.
device
,
(
...
...
@@ -132,16 +129,11 @@ class PyNcclCommunicator:
f
"but the input tensor is on
{
tensor
.
device
}
"
)
if
stream
is
None
:
stream
=
self
.
stream
if
dst
is
None
:
dst
=
(
self
.
rank
+
1
)
%
self
.
world_size
self
.
nccl
.
ncclSend
(
buffer_type
(
tensor
.
data_ptr
()),
tensor
.
numel
(),
ncclDataTypeEnum
.
from_torch
(
tensor
.
dtype
),
dst
,
self
.
comm
,
cudaStream_t
(
stream
.
cuda_stream
))
def
recv
(
self
,
tensor
:
torch
.
Tensor
,
src
:
Optional
[
int
]
=
None
,
stream
=
None
):
def
recv
(
self
,
tensor
:
torch
.
Tensor
,
src
:
int
,
stream
=
None
):
if
self
.
disabled
:
return
assert
tensor
.
device
==
self
.
device
,
(
...
...
@@ -149,8 +141,6 @@ class PyNcclCommunicator:
f
"but the input tensor is on
{
tensor
.
device
}
"
)
if
stream
is
None
:
stream
=
self
.
stream
if
src
is
None
:
src
=
(
self
.
rank
-
1
)
%
self
.
world_size
self
.
nccl
.
ncclRecv
(
buffer_type
(
tensor
.
data_ptr
()),
tensor
.
numel
(),
ncclDataTypeEnum
.
from_torch
(
tensor
.
dtype
),
src
,
self
.
comm
,
cudaStream_t
(
stream
.
cuda_stream
))
...
...
vllm/distributed/parallel_state.py
View file @
5d4d9053
...
...
@@ -20,6 +20,7 @@ If you only need to use the distributed environment without model/pipeline
steps.
"""
import
contextlib
import
pickle
from
collections
import
namedtuple
from
contextlib
import
contextmanager
,
nullcontext
from
dataclasses
import
dataclass
...
...
@@ -28,6 +29,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
from
unittest.mock
import
patch
import
torch
import
torch.distributed
from
torch.distributed
import
Backend
,
ProcessGroup
import
vllm.envs
as
envs
...
...
@@ -180,6 +182,16 @@ class GroupCoordinator:
"""Return the global rank of the last process in the group"""
return
self
.
ranks
[
-
1
]
@
property
def
is_first_rank
(
self
):
"""Return whether the caller is the first process in the group"""
return
self
.
rank
==
self
.
first_rank
@
property
def
is_last_rank
(
self
):
"""Return whether the caller is the last process in the group"""
return
self
.
rank
==
self
.
last_rank
@
property
def
next_rank
(
self
):
"""Return the global rank of the process that follows the caller"""
...
...
@@ -374,6 +386,70 @@ class GroupCoordinator:
group
=
self
.
device_group
)
return
obj_list
def
send_object
(
self
,
obj
:
Any
,
dst
:
int
)
->
None
:
"""Send the input object list to the destination rank."""
"""NOTE: `dst` is the local rank of the destination rank."""
assert
dst
<
self
.
world_size
,
f
"Invalid dst rank (
{
dst
}
)"
assert
dst
!=
self
.
rank
,
(
"Invalid destination rank. Destination rank is the same "
"as the current rank."
)
# Serialize object to tensor and get the size as well
object_tensor
=
torch
.
frombuffer
(
pickle
.
dumps
(
obj
),
dtype
=
torch
.
uint8
)
size_tensor
=
torch
.
tensor
([
object_tensor
.
numel
()],
dtype
=
torch
.
long
,
device
=
"cpu"
)
# Send object size
torch
.
distributed
.
send
(
size_tensor
,
dst
=
self
.
ranks
[
dst
],
group
=
self
.
cpu_group
)
# Send object
torch
.
distributed
.
send
(
object_tensor
,
dst
=
self
.
ranks
[
dst
],
group
=
self
.
cpu_group
)
return
None
def
recv_object
(
self
,
src
:
int
)
->
Any
:
"""Receive the input object list from the source rank."""
"""NOTE: `src` is the local rank of the source rank."""
assert
src
<
self
.
world_size
,
f
"Invalid src rank (
{
src
}
)"
assert
src
!=
self
.
rank
,
(
"Invalid source rank. Source rank is the same as the current rank."
)
size_tensor
=
torch
.
empty
(
1
,
dtype
=
torch
.
long
,
device
=
"cpu"
)
# Receive object size
rank_size
=
torch
.
distributed
.
recv
(
size_tensor
,
src
=
src
,
group
=
self
.
cpu_group
)
# Tensor to receive serialized objects into.
object_tensor
=
torch
.
empty
(
# type: ignore[call-overload]
size_tensor
.
item
(),
# type: ignore[arg-type]
dtype
=
torch
.
uint8
,
device
=
"cpu"
)
rank_object
=
torch
.
distributed
.
recv
(
object_tensor
,
src
=
src
,
group
=
self
.
cpu_group
)
assert
rank_object
==
rank_size
,
(
"Received object sender rank does not match the size sender rank."
)
obj
=
pickle
.
loads
(
object_tensor
.
numpy
().
tobytes
())
return
obj
def
broadcast_tensor_dict
(
self
,
tensor_dict
:
Optional
[
Dict
[
Any
,
Union
[
torch
.
Tensor
,
Any
]]]
=
None
,
...
...
@@ -459,6 +535,88 @@ class GroupCoordinator:
async_handle
.
wait
()
return
tensor_dict
def
send_tensor_dict
(
self
,
tensor_dict
:
Dict
[
Any
,
Union
[
torch
.
Tensor
,
Any
]],
dst
:
Optional
[
int
]
=
None
)
->
Optional
[
Dict
[
Any
,
Union
[
torch
.
Tensor
,
Any
]]]:
"""Send the input tensor dictionary.
NOTE: `dst` is the local rank of the source rank.
"""
# Bypass the function if we are using only 1 GPU.
if
not
torch
.
distributed
.
is_initialized
()
or
self
.
world_size
==
1
:
return
tensor_dict
group
=
self
.
device_group
metadata_group
=
self
.
cpu_group
if
dst
is
None
:
dst
=
self
.
next_rank
assert
dst
<
self
.
world_size
,
f
"Invalid dst rank (
{
dst
}
)"
metadata_list
:
List
[
Tuple
[
Any
,
Any
]]
=
[]
assert
isinstance
(
tensor_dict
,
dict
),
f
"Expecting a dictionary, got
{
type
(
tensor_dict
)
}
"
metadata_list
,
tensor_list
=
_split_tensor_dict
(
tensor_dict
)
# `metadata_list` lives in CPU memory.
# `send_object_list` has serialization & deserialization,
# all happening on CPU. Therefore, we can use the CPU group.
self
.
send_object
(
metadata_list
,
dst
=
dst
)
for
tensor
in
tensor_list
:
if
tensor
.
numel
()
==
0
:
# Skip sending empty tensors.
continue
if
tensor
.
is_cpu
:
# use metadata_group for CPU tensors
torch
.
distributed
.
send
(
tensor
,
dst
=
dst
,
group
=
metadata_group
)
else
:
# use group for GPU tensors
torch
.
distributed
.
send
(
tensor
,
dst
=
dst
,
group
=
group
)
return
None
def
recv_tensor_dict
(
self
,
src
:
Optional
[
int
]
=
None
)
->
Optional
[
Dict
[
Any
,
Union
[
torch
.
Tensor
,
Any
]]]:
"""Recv the input tensor dictionary.
NOTE: `src` is the local rank of the source rank.
"""
# Bypass the function if we are using only 1 GPU.
if
not
torch
.
distributed
.
is_initialized
()
or
self
.
world_size
==
1
:
return
None
group
=
self
.
device_group
metadata_group
=
self
.
cpu_group
if
src
is
None
:
src
=
self
.
prev_rank
assert
src
<
self
.
world_size
,
f
"Invalid src rank (
{
src
}
)"
recv_metadata_list
=
self
.
recv_object
(
src
=
src
)
tensor_dict
=
{}
for
key
,
value
in
recv_metadata_list
:
if
isinstance
(
value
,
TensorMetadata
):
tensor
=
torch
.
empty
(
value
.
size
,
dtype
=
value
.
dtype
,
device
=
value
.
device
)
if
tensor
.
numel
()
==
0
:
# Skip broadcasting empty tensors.
tensor_dict
[
key
]
=
tensor
continue
if
tensor
.
is_cpu
:
# use metadata_group for CPU tensors
torch
.
distributed
.
recv
(
tensor
,
src
=
src
,
group
=
metadata_group
)
else
:
# use group for GPU tensors
torch
.
distributed
.
recv
(
tensor
,
src
=
src
,
group
=
group
)
tensor_dict
[
key
]
=
tensor
else
:
tensor_dict
[
key
]
=
value
return
tensor_dict
def
barrier
(
self
):
"""Barrier synchronization among the group.
NOTE: don't use `device_group` here! `barrier` in NCCL is
...
...
@@ -468,6 +626,35 @@ class GroupCoordinator:
"""
torch
.
distributed
.
barrier
(
group
=
self
.
cpu_group
)
def
send
(
self
,
tensor
:
torch
.
Tensor
,
dst
:
Optional
[
int
]
=
None
)
->
None
:
"""Sends a tensor to the destination rank in a non-blocking way"""
"""NOTE: `dst` is the local rank of the destination rank."""
if
dst
is
None
:
dst
=
self
.
next_rank
pynccl_comm
=
self
.
pynccl_comm
if
pynccl_comm
is
not
None
and
not
pynccl_comm
.
disabled
:
pynccl_comm
.
send
(
tensor
,
dst
)
else
:
torch
.
distributed
.
send
(
tensor
,
self
.
ranks
[
dst
],
self
.
device_group
)
def
recv
(
self
,
size
:
torch
.
Size
,
dtype
:
torch
.
dtype
,
src
:
Optional
[
int
]
=
None
)
->
torch
.
Tensor
:
"""Receives a tensor from the src rank."""
"""NOTE: `src` is the local rank of the destination rank."""
if
src
is
None
:
src
=
self
.
prev_rank
tensor
=
torch
.
empty
(
size
,
dtype
=
dtype
,
device
=
self
.
device
)
pynccl_comm
=
self
.
pynccl_comm
if
pynccl_comm
is
not
None
and
not
pynccl_comm
.
disabled
:
pynccl_comm
.
recv
(
tensor
,
src
)
else
:
torch
.
distributed
.
recv
(
tensor
,
self
.
ranks
[
src
],
self
.
device_group
)
return
tensor
def
destroy
(
self
):
if
self
.
device_group
is
not
None
:
torch
.
distributed
.
destroy_process_group
(
self
.
device_group
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment