Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
e6de9784
Unverified
Commit
e6de9784
authored
Nov 11, 2024
by
youkaichao
Committed by
GitHub
Nov 11, 2024
Browse files
[core][distributed] add stateless process group (#10216)
Signed-off-by:
youkaichao
<
youkaichao@gmail.com
>
parent
36fc439d
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
217 additions
and
112 deletions
+217
-112
tests/distributed/test_utils.py
tests/distributed/test_utils.py
+52
-27
vllm/distributed/device_communicators/pynccl.py
vllm/distributed/device_communicators/pynccl.py
+24
-14
vllm/distributed/utils.py
vllm/distributed/utils.py
+141
-71
No files found.
tests/distributed/test_utils.py
View file @
e6de9784
import
pytest
import
pytest
import
ray
import
ray
import
torch
import
torch
import
torch.distributed
as
dist
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.distributed.utils
import
stateless_init_process_group
from
vllm.distributed.device_communicators.pynccl
import
PyNcclCommunicator
from
vllm.distributed.utils
import
StatelessProcessGroup
from
vllm.utils
import
(
cuda_device_count_stateless
,
from
vllm.utils
import
(
cuda_device_count_stateless
,
update_environment_variables
)
update_environment_variables
)
...
@@ -41,42 +41,45 @@ def test_cuda_device_count_stateless():
...
@@ -41,42 +41,45 @@ def test_cuda_device_count_stateless():
def
cpu_worker
(
rank
,
WORLD_SIZE
):
def
cpu_worker
(
rank
,
WORLD_SIZE
):
pg1
=
s
tateless
_init_p
rocess
_g
roup
(
init_method
=
"tcp://127.0.0.1:29500"
,
pg1
=
S
tateless
P
rocess
G
roup
.
create
(
init_method
=
"tcp://127.0.0.1:29500"
,
rank
=
rank
,
rank
=
rank
,
world_size
=
WORLD_SIZE
,
world_size
=
WORLD_SIZE
)
backend
=
"gloo"
)
if
rank
<=
2
:
if
rank
<=
2
:
pg2
=
s
tateless
_init_p
rocess
_g
roup
(
init_method
=
"tcp://127.0.0.1:29501"
,
pg2
=
S
tateless
P
rocess
G
roup
.
create
(
init_method
=
"tcp://127.0.0.1:29501"
,
rank
=
rank
,
rank
=
rank
,
world_size
=
3
,
world_size
=
3
)
backend
=
"gloo"
)
data
=
torch
.
tensor
([
rank
])
data
=
torch
.
tensor
([
rank
])
dist
.
all_reduce
(
data
,
op
=
dist
.
ReduceOp
.
SUM
,
group
=
pg1
)
data
=
pg1
.
broadcast_obj
(
data
,
src
=
2
)
assert
data
.
item
()
==
2
if
rank
<=
2
:
if
rank
<=
2
:
dist
.
all_reduce
(
data
,
op
=
dist
.
ReduceOp
.
SUM
,
group
=
pg2
)
data
=
torch
.
tensor
([
rank
+
1
])
item
=
data
[
0
].
item
()
data
=
pg2
.
broadcast_obj
(
data
,
src
=
2
)
print
(
f
"rank:
{
rank
}
, item:
{
item
}
"
)
assert
data
.
item
()
==
3
if
rank
==
3
:
pg2
.
barrier
()
assert
item
==
6
pg1
.
barrier
()
else
:
assert
item
==
18
def
gpu_worker
(
rank
,
WORLD_SIZE
):
def
gpu_worker
(
rank
,
WORLD_SIZE
):
pg1
=
stateless_init_process_group
(
init_method
=
"tcp://127.0.0.1:29502"
,
torch
.
cuda
.
set_device
(
rank
)
pg1
=
StatelessProcessGroup
.
create
(
init_method
=
"tcp://127.0.0.1:29502"
,
rank
=
rank
,
rank
=
rank
,
world_size
=
WORLD_SIZE
,
world_size
=
WORLD_SIZE
)
backend
=
"nccl"
)
pynccl1
=
PyNcclCommunicator
(
pg1
,
device
=
rank
)
pynccl1
.
disabled
=
False
if
rank
<=
2
:
if
rank
<=
2
:
pg2
=
s
tateless
_init_p
rocess
_g
roup
(
init_method
=
"tcp://127.0.0.1:29503"
,
pg2
=
S
tateless
P
rocess
G
roup
.
create
(
init_method
=
"tcp://127.0.0.1:29503"
,
rank
=
rank
,
rank
=
rank
,
world_size
=
3
,
world_size
=
3
)
backend
=
"nccl"
)
pynccl2
=
PyNcclCommunicator
(
pg2
,
device
=
rank
)
torch
.
cuda
.
set_device
(
rank
)
pynccl2
.
disabled
=
False
data
=
torch
.
tensor
([
rank
]).
cuda
()
data
=
torch
.
tensor
([
rank
]).
cuda
()
dist
.
all_reduce
(
data
,
op
=
dist
.
ReduceOp
.
SUM
,
group
=
pg1
)
pynccl1
.
all_reduce
(
data
)
pg1
.
barrier
()
torch
.
cuda
.
synchronize
()
if
rank
<=
2
:
if
rank
<=
2
:
dist
.
all_reduce
(
data
,
op
=
dist
.
ReduceOp
.
SUM
,
group
=
pg2
)
pynccl2
.
all_reduce
(
data
)
pg2
.
barrier
()
torch
.
cuda
.
synchronize
()
item
=
data
[
0
].
item
()
item
=
data
[
0
].
item
()
print
(
f
"rank:
{
rank
}
, item:
{
item
}
"
)
print
(
f
"rank:
{
rank
}
, item:
{
item
}
"
)
if
rank
==
3
:
if
rank
==
3
:
...
@@ -85,9 +88,31 @@ def gpu_worker(rank, WORLD_SIZE):
...
@@ -85,9 +88,31 @@ def gpu_worker(rank, WORLD_SIZE):
assert
item
==
18
assert
item
==
18
def broadcast_worker(rank, WORLD_SIZE):
    """Check that rank 2 can broadcast a Python object to the other ranks.

    Every process joins a ``StatelessProcessGroup`` created with the
    ``tcp://127.0.0.1:29504`` init method. Rank 2 publishes the string
    ``"secret"``; every other rank requests the broadcast from rank 2 and
    asserts on the payload. All ranks then call ``barrier()``.
    """
    group = StatelessProcessGroup.create(
        init_method="tcp://127.0.0.1:29504",
        rank=rank,
        world_size=WORLD_SIZE,
    )
    if rank != 2:
        # Receivers pass None; the payload is supplied by the source rank.
        received = group.broadcast_obj(None, src=2)
        assert received == "secret"
    else:
        group.broadcast_obj("secret", src=2)
    group.barrier()
def allgather_worker(rank, WORLD_SIZE):
    """Check that all-gathering each rank's id yields ``[0, ..., WORLD_SIZE - 1]``.

    Every process joins a ``StatelessProcessGroup`` created with the
    ``tcp://127.0.0.1:29505`` init method, contributes its own rank to
    ``all_gather_obj`` and compares the gathered list against the expected
    rank-ordered sequence before calling ``barrier()``.
    """
    group = StatelessProcessGroup.create(
        init_method="tcp://127.0.0.1:29505",
        rank=rank,
        world_size=WORLD_SIZE,
    )
    gathered = group.all_gather_obj(rank)
    expected = list(range(WORLD_SIZE))
    assert gathered == expected
    group.barrier()
@
multi_gpu_test
(
num_gpus
=
4
)
@
multi_gpu_test
(
num_gpus
=
4
)
@
pytest
.
mark
.
parametrize
(
"worker"
,
[
cpu_worker
,
gpu_worker
])
@
pytest
.
mark
.
parametrize
(
def
test_stateless_init_process_group
(
worker
):
"worker"
,
[
cpu_worker
,
gpu_worker
,
broadcast_worker
,
allgather_worker
])
def
test_stateless_process_group
(
worker
):
WORLD_SIZE
=
4
WORLD_SIZE
=
4
from
multiprocessing
import
get_context
from
multiprocessing
import
get_context
ctx
=
get_context
(
"fork"
)
ctx
=
get_context
(
"fork"
)
...
...
vllm/distributed/device_communicators/pynccl.py
View file @
e6de9784
...
@@ -9,6 +9,7 @@ from torch.distributed import ProcessGroup, ReduceOp
...
@@ -9,6 +9,7 @@ from torch.distributed import ProcessGroup, ReduceOp
from
vllm.distributed.device_communicators.pynccl_wrapper
import
(
from
vllm.distributed.device_communicators.pynccl_wrapper
import
(
NCCLLibrary
,
buffer_type
,
cudaStream_t
,
ncclComm_t
,
ncclDataTypeEnum
,
NCCLLibrary
,
buffer_type
,
cudaStream_t
,
ncclComm_t
,
ncclDataTypeEnum
,
ncclRedOpTypeEnum
,
ncclUniqueId
)
ncclRedOpTypeEnum
,
ncclUniqueId
)
from
vllm.distributed.utils
import
StatelessProcessGroup
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -18,7 +19,7 @@ class PyNcclCommunicator:
...
@@ -18,7 +19,7 @@ class PyNcclCommunicator:
def
__init__
(
def
__init__
(
self
,
self
,
group
:
ProcessGroup
,
group
:
Union
[
ProcessGroup
,
StatelessProcessGroup
],
device
:
Union
[
int
,
str
,
torch
.
device
],
device
:
Union
[
int
,
str
,
torch
.
device
],
library_path
:
Optional
[
str
]
=
None
,
library_path
:
Optional
[
str
]
=
None
,
):
):
...
@@ -33,13 +34,18 @@ class PyNcclCommunicator:
...
@@ -33,13 +34,18 @@ class PyNcclCommunicator:
It is the caller's responsibility to make sure each communicator
It is the caller's responsibility to make sure each communicator
is bind to a unique device.
is bind to a unique device.
"""
"""
assert
dist
.
is_initialized
()
if
not
isinstance
(
group
,
StatelessProcessGroup
):
assert
dist
.
get_backend
(
group
)
!=
dist
.
Backend
.
NCCL
,
(
assert
dist
.
is_initialized
()
"PyNcclCommunicator should be attached to a non-NCCL group."
)
assert
dist
.
get_backend
(
group
)
!=
dist
.
Backend
.
NCCL
,
(
"PyNcclCommunicator should be attached to a non-NCCL group."
)
# note: this rank is the rank in the group
self
.
rank
=
dist
.
get_rank
(
group
)
self
.
world_size
=
dist
.
get_world_size
(
group
)
else
:
self
.
rank
=
group
.
rank
self
.
world_size
=
group
.
world_size
self
.
group
=
group
self
.
group
=
group
# note: this rank is the rank in the group
self
.
rank
=
dist
.
get_rank
(
group
)
self
.
world_size
=
dist
.
get_world_size
(
group
)
# if world_size == 1, no need to create communicator
# if world_size == 1, no need to create communicator
if
self
.
world_size
==
1
:
if
self
.
world_size
==
1
:
...
@@ -68,13 +74,17 @@ class PyNcclCommunicator:
...
@@ -68,13 +74,17 @@ class PyNcclCommunicator:
else
:
else
:
# construct an empty unique id
# construct an empty unique id
self
.
unique_id
=
ncclUniqueId
()
self
.
unique_id
=
ncclUniqueId
()
tensor
=
torch
.
ByteTensor
(
list
(
self
.
unique_id
.
internal
))
ranks
=
dist
.
get_process_group_ranks
(
group
)
if
not
isinstance
(
group
,
StatelessProcessGroup
):
# arg `src` in `broadcast` is the global rank
tensor
=
torch
.
ByteTensor
(
list
(
self
.
unique_id
.
internal
))
dist
.
broadcast
(
tensor
,
src
=
ranks
[
0
],
group
=
group
)
ranks
=
dist
.
get_process_group_ranks
(
group
)
byte_list
=
tensor
.
tolist
()
# arg `src` in `broadcast` is the global rank
for
i
,
byte
in
enumerate
(
byte_list
):
dist
.
broadcast
(
tensor
,
src
=
ranks
[
0
],
group
=
group
)
self
.
unique_id
.
internal
[
i
]
=
byte
byte_list
=
tensor
.
tolist
()
for
i
,
byte
in
enumerate
(
byte_list
):
self
.
unique_id
.
internal
[
i
]
=
byte
else
:
self
.
unique_id
=
group
.
broadcast_obj
(
self
.
unique_id
,
src
=
0
)
if
isinstance
(
device
,
int
):
if
isinstance
(
device
,
int
):
device
=
torch
.
device
(
f
"cuda:
{
device
}
"
)
device
=
torch
.
device
(
f
"cuda:
{
device
}
"
)
elif
isinstance
(
device
,
str
):
elif
isinstance
(
device
,
str
):
...
...
vllm/distributed/utils.py
View file @
e6de9784
...
@@ -2,13 +2,13 @@
...
@@ -2,13 +2,13 @@
# Adapted from
# Adapted from
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
from
typing
import
Sequence
,
Tuple
import
dataclasses
import
pickle
import
time
from
collections
import
deque
from
typing
import
Any
,
Deque
,
Dict
,
Optional
,
Sequence
,
Tuple
import
torch
import
torch
from
torch.distributed
import
ProcessGroup
from
torch.distributed.distributed_c10d
import
(
Backend
,
PrefixStore
,
_get_default_timeout
,
is_nccl_available
)
from
torch.distributed.rendezvous
import
rendezvous
from
torch.distributed.rendezvous
import
rendezvous
import
vllm.envs
as
envs
import
vllm.envs
as
envs
...
@@ -91,69 +91,139 @@ def get_pp_indices(num_hidden_layers: int, pp_rank: int,
...
@@ -91,69 +91,139 @@ def get_pp_indices(num_hidden_layers: int, pp_rank: int,
return
(
start_layer
,
end_layer
)
return
(
start_layer
,
end_layer
)
def
stateless_init_process_group
(
init_method
:
str
,
rank
:
int
,
world_size
:
int
,
@
dataclasses
.
dataclass
backend
:
str
)
->
ProcessGroup
:
class
StatelessProcessGroup
:
"""A replacement for `torch.distributed.init_process_group` that does not
"""A dataclass to hold a metadata store, and the rank, world_size of the
pollute the global state.
group. Only use it to communicate metadata between processes.
For data-plane communication, create NCCL-related objects.
If we have process A and process B called `torch.distributed.init_process_group`
"""
to form a group, and then we want to form another group with process A, B, C,
prefix
:
str
D, it is not possible in PyTorch, because process A and process B have already
rank
:
int
formed a group, and process C and process D cannot join that group. This
world_size
:
int
function is a workaround for this issue.
store
:
torch
.
_C
.
_distributed_c10d
.
Store
data_expiration_seconds
:
int
=
3600
# 1 hour
`torch.distributed.init_process_group` is a global call, while this function
is a stateless call. It will return a `ProcessGroup` object that can be used
# dst rank -> counter
for collective communication. With this function, process A and process B
send_dst_counter
:
Dict
[
int
,
int
]
=
dataclasses
.
field
(
default_factory
=
dict
)
can call `stateless_init_process_group` to form a group, and then process A, B,
# src rank -> counter
C, and D can call `stateless_init_process_group` to form another group.
recv_src_counter
:
Dict
[
int
,
int
]
=
dataclasses
.
field
(
default_factory
=
dict
)
"""
# noqa
broadcast_send_counter
:
int
=
0
broadcast_recv_src_counter
:
Dict
[
int
,
int
]
=
dataclasses
.
field
(
backend
=
Backend
(
backend
)
# it is basically string
default_factory
=
dict
)
timeout
=
_get_default_timeout
(
backend
)
# A deque to store the data entries, with key and timestamp.
store
,
rank
,
world_size
=
next
(
entries
:
Deque
[
Tuple
[
str
,
rendezvous
(
init_method
,
rank
,
world_size
,
timeout
=
timeout
))
float
]]
=
dataclasses
.
field
(
default_factory
=
deque
)
store
.
set_timeout
(
timeout
)
def
__post_init__
(
self
):
group_rank
=
rank
assert
self
.
rank
<
self
.
world_size
group_size
=
world_size
self
.
send_dst_counter
=
{
i
:
0
for
i
in
range
(
self
.
world_size
)}
self
.
recv_src_counter
=
{
i
:
0
for
i
in
range
(
self
.
world_size
)}
# Use a PrefixStore to avoid accidental overrides of keys used by
self
.
broadcast_recv_src_counter
=
{
# different systems (e.g. RPC) in case the store is multi-tenant.
i
:
0
prefix_store
=
PrefixStore
(
init_method
,
store
)
for
i
in
range
(
self
.
world_size
)
}
pg_options
=
ProcessGroup
.
Options
(
backend
=
backend
,
timeout
=
timeout
)
def
send_obj
(
self
,
obj
:
Any
,
dst
:
int
):
pg
:
ProcessGroup
=
ProcessGroup
(
"""Send an object to a destination rank."""
prefix_store
,
self
.
expire_data
()
group_rank
,
key
=
f
"
{
self
.
prefix
}
/send_to/
{
dst
}
/
{
self
.
send_dst_counter
[
dst
]
}
"
group_size
,
self
.
store
.
set
(
key
,
pickle
.
dumps
(
obj
))
pg_options
,
self
.
send_dst_counter
[
dst
]
+=
1
)
self
.
entries
.
append
((
key
,
time
.
time
()))
if
backend
==
"gloo"
:
def
expire_data
(
self
):
from
torch.distributed.distributed_c10d
import
ProcessGroupGloo
"""Expire data that is older than `data_expiration_seconds` seconds."""
backend_class
=
ProcessGroupGloo
(
prefix_store
,
while
self
.
entries
:
group_rank
,
# check the oldest entry
group_size
,
key
,
timestamp
=
self
.
entries
[
0
]
timeout
=
timeout
)
if
time
.
time
()
-
timestamp
>
self
.
data_expiration_seconds
:
backend_type
=
ProcessGroup
.
BackendType
.
GLOO
self
.
store
.
delete_key
(
key
)
device
=
torch
.
device
(
"cpu"
)
self
.
entries
.
popleft
()
elif
backend
==
"nccl"
:
else
:
assert
is_nccl_available
()
break
from
torch.distributed.distributed_c10d
import
ProcessGroupNCCL
def
recv_obj
(
self
,
src
:
int
)
->
Any
:
backend_options
=
ProcessGroupNCCL
.
Options
()
"""Receive an object from a source rank."""
backend_options
.
_timeout
=
timeout
obj
=
pickle
.
loads
(
self
.
store
.
get
(
backend_class
=
ProcessGroupNCCL
(
prefix_store
,
group_rank
,
group_size
,
f
"
{
self
.
prefix
}
/send_to/
{
self
.
rank
}
/
{
self
.
recv_src_counter
[
src
]
}
"
backend_options
)
))
backend_type
=
ProcessGroup
.
BackendType
.
NCCL
self
.
recv_src_counter
[
src
]
+=
1
device
=
torch
.
device
(
"cuda"
)
return
obj
backend_class
.
_set_sequence_number_for_group
()
def
broadcast_obj
(
self
,
obj
:
Optional
[
Any
],
src
:
int
)
->
Any
:
"""Broadcast an object from a source rank to all other ranks.
pg
.
_register_backend
(
device
,
backend_type
,
backend_class
)
It does not clean up after all ranks have received the object.
Use it for limited times, e.g., for initialization.
return
pg
"""
if
self
.
rank
==
src
:
self
.
expire_data
()
key
=
(
f
"
{
self
.
prefix
}
/broadcast_from/
{
src
}
/"
f
"
{
self
.
broadcast_send_counter
}
"
)
self
.
store
.
set
(
key
,
pickle
.
dumps
(
obj
))
self
.
broadcast_send_counter
+=
1
self
.
entries
.
append
((
key
,
time
.
time
()))
return
obj
else
:
key
=
(
f
"
{
self
.
prefix
}
/broadcast_from/
{
src
}
/"
f
"
{
self
.
broadcast_recv_src_counter
[
src
]
}
"
)
recv_obj
=
pickle
.
loads
(
self
.
store
.
get
(
key
))
self
.
broadcast_recv_src_counter
[
src
]
+=
1
return
recv_obj
def all_gather_obj(self, obj: Any) -> list[Any]:
    """Gather one object from every rank into a list ordered by rank.

    Implemented as ``world_size`` consecutive ``broadcast_obj`` calls, one
    per source rank: this rank publishes ``obj`` when its turn comes and
    receives the other ranks' objects otherwise.

    Args:
        obj: The object contributed by the calling rank.

    Returns:
        A list of length ``world_size`` whose ``i``-th entry is the object
        contributed by rank ``i``.
    """
    results = []
    for src_rank in range(self.world_size):
        if src_rank != self.rank:
            # Not our turn: receive whatever rank `src_rank` published.
            results.append(self.broadcast_obj(None, src=src_rank))
            continue
        # Our turn: record our own object, then publish it to the others.
        results.append(obj)
        self.broadcast_obj(obj, src=self.rank)
    return results
def barrier(self):
    """Synchronize all ranks by running one ``None`` broadcast per rank.

    For each rank in turn, ``broadcast_obj(None, src=<that rank>)`` is
    called: this rank acts as the source on its own turn and as a receiver
    on every other turn.
    """
    for src_rank in range(self.world_size):
        # On our own turn src_rank == self.rank, so a single call covers
        # both the publishing and the receiving case.
        self.broadcast_obj(None, src=src_rank)
@staticmethod
def create(
    init_method: str,
    rank: int,
    world_size: int,
    data_expiration_seconds: int = 3600,
) -> "StatelessProcessGroup":
    """Build a process group without touching torch's global state.

    A replacement for `torch.distributed.init_process_group` that does
    not pollute the global state.

    If process A and process B have called
    `torch.distributed.init_process_group` to form a group, PyTorch does
    not allow forming another group with processes A, B, C and D: A and B
    have already formed a group, and C and D cannot join it. This function
    is a workaround for that limitation.

    `torch.distributed.init_process_group` is a global call, while this
    function is a stateless call. It returns a `StatelessProcessGroup`
    object that can be used for exchanging metadata. With this function,
    processes A and B can call `StatelessProcessGroup.create` to form a
    group, and then processes A, B, C and D can call
    `StatelessProcessGroup.create` to form another group.

    Args:
        init_method: Rendezvous URL; also used as the group's key prefix.
        rank: Rank of the calling process within the new group.
        world_size: Number of processes in the new group.
        data_expiration_seconds: Forwarded unchanged to the
            `StatelessProcessGroup` constructor.

    Returns:
        A `StatelessProcessGroup` holding the store, rank and world size
        returned by the rendezvous.
    """
    from torch._C._distributed_c10d import _DEFAULT_PG_TIMEOUT

    # The same default timeout is used for the rendezvous and the store.
    rendezvous_iter = rendezvous(init_method,
                                 rank,
                                 world_size,
                                 timeout=_DEFAULT_PG_TIMEOUT)
    store, group_rank, group_size = next(rendezvous_iter)
    store.set_timeout(_DEFAULT_PG_TIMEOUT)

    return StatelessProcessGroup(
        prefix=init_method,
        rank=group_rank,
        world_size=group_size,
        store=store,
        data_expiration_seconds=data_expiration_seconds,
    )
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment