Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
62de37a3
Unverified
Commit
62de37a3
authored
Dec 12, 2024
by
youkaichao
Committed by
GitHub
Dec 12, 2024
Browse files
[core][distributed] initialization from StatelessProcessGroup (#10986)
Signed-off-by:
youkaichao
<
youkaichao@gmail.com
>
parent
81958242
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
153 additions
and
69 deletions
+153
-69
.buildkite/test-pipeline.yaml
.buildkite/test-pipeline.yaml
+3
-3
tests/distributed/test_same_node.py
tests/distributed/test_same_node.py
+25
-4
tests/distributed/test_shm_broadcast.py
tests/distributed/test_shm_broadcast.py
+56
-28
vllm/distributed/device_communicators/shm_broadcast.py
vllm/distributed/device_communicators/shm_broadcast.py
+26
-13
vllm/distributed/parallel_state.py
vllm/distributed/parallel_state.py
+43
-21
No files found.
.buildkite/test-pipeline.yaml
View file @
62de37a3
...
@@ -432,11 +432,11 @@ steps:
...
@@ -432,11 +432,11 @@ steps:
-
tests/distributed/
-
tests/distributed/
commands
:
commands
:
-
# the following commands are for the first node, with ip 192.168.10.10 (ray environment already set up)
-
# the following commands are for the first node, with ip 192.168.10.10 (ray environment already set up)
-
VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep
-q
'Same node test passed'
-
VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep 'Same node test passed'
-
VLLM_MULTI_NODE=1 pytest -v -s distributed/test_multi_node_assignment.py
-
VLLM_MULTI_NODE=1 pytest -v -s distributed/test_multi_node_assignment.py
-
VLLM_MULTI_NODE=1 pytest -v -s distributed/test_pipeline_parallel.py
-
VLLM_MULTI_NODE=1 pytest -v -s distributed/test_pipeline_parallel.py
-
# the following commands are for the second node, with ip 192.168.10.11 (ray environment already set up)
-
# the following commands are for the second node, with ip 192.168.10.11 (ray environment already set up)
-
VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep
-q
'Same node test passed'
-
VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep 'Same node test passed'
-
label
:
Distributed Tests (2 GPUs)
# 40min
-
label
:
Distributed Tests (2 GPUs)
# 40min
#mirror_hardwares: [amd]
#mirror_hardwares: [amd]
...
@@ -455,7 +455,7 @@ steps:
...
@@ -455,7 +455,7 @@ steps:
commands
:
commands
:
-
pytest -v -s ./compile/test_basic_correctness.py
-
pytest -v -s ./compile/test_basic_correctness.py
-
pytest -v -s ./compile/test_wrapper.py
-
pytest -v -s ./compile/test_wrapper.py
-
VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep
-q
'Same node test passed'
-
VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep 'Same node test passed'
-
TARGET_TEST_SUITE=L4 pytest basic_correctness/ -v -s -m 'distributed(num_gpus=2)'
-
TARGET_TEST_SUITE=L4 pytest basic_correctness/ -v -s -m 'distributed(num_gpus=2)'
# Avoid importing model tests that cause CUDA reinitialization error
# Avoid importing model tests that cause CUDA reinitialization error
-
pytest models/encoder_decoder/language/test_bart.py -v -s -m 'distributed(num_gpus=2)'
-
pytest models/encoder_decoder/language/test_bart.py -v -s -m 'distributed(num_gpus=2)'
...
...
tests/distributed/test_same_node.py
View file @
62de37a3
...
@@ -3,11 +3,32 @@ import os
...
@@ -3,11 +3,32 @@ import os
import
torch.distributed
as
dist
import
torch.distributed
as
dist
from
vllm.distributed.parallel_state
import
in_the_same_node_as
from
vllm.distributed.parallel_state
import
in_the_same_node_as
from
vllm.distributed.utils
import
StatelessProcessGroup
from
vllm.utils
import
get_ip
,
get_open_port
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
dist
.
init_process_group
(
backend
=
"gloo"
)
dist
.
init_process_group
(
backend
=
"gloo"
)
test_result
=
all
(
in_the_same_node_as
(
dist
.
group
.
WORLD
,
source_rank
=
0
))
rank
=
dist
.
get_rank
()
if
rank
==
0
:
port
=
get_open_port
()
ip
=
get_ip
()
dist
.
broadcast_object_list
([
ip
,
port
],
src
=
0
)
else
:
recv
=
[
None
,
None
]
dist
.
broadcast_object_list
(
recv
,
src
=
0
)
ip
,
port
=
recv
stateless_pg
=
StatelessProcessGroup
.
create
(
ip
,
port
,
rank
,
dist
.
get_world_size
())
for
pg
in
[
dist
.
group
.
WORLD
,
stateless_pg
]:
test_result
=
all
(
in_the_same_node_as
(
pg
,
source_rank
=
0
))
expected
=
os
.
environ
.
get
(
"VLLM_TEST_SAME_HOST"
,
"1"
)
==
"1"
expected
=
os
.
environ
.
get
(
"VLLM_TEST_SAME_HOST"
,
"1"
)
==
"1"
assert
test_result
==
expected
,
f
"Expected
{
expected
}
, got
{
test_result
}
"
assert
test_result
==
expected
,
\
print
(
"Same node test passed!"
)
f
"Expected
{
expected
}
, got
{
test_result
}
"
if
pg
==
dist
.
group
.
WORLD
:
print
(
"Same node test passed! when using torch distributed!"
)
else
:
print
(
"Same node test passed! when using StatelessProcessGroup!"
)
tests/distributed/test_shm_broadcast.py
View file @
62de37a3
...
@@ -7,7 +7,8 @@ import numpy as np
...
@@ -7,7 +7,8 @@ import numpy as np
import
torch.distributed
as
dist
import
torch.distributed
as
dist
from
vllm.distributed.device_communicators.shm_broadcast
import
MessageQueue
from
vllm.distributed.device_communicators.shm_broadcast
import
MessageQueue
from
vllm.utils
import
update_environment_variables
from
vllm.distributed.utils
import
StatelessProcessGroup
from
vllm.utils
import
get_ip
,
get_open_port
,
update_environment_variables
def
get_arrays
(
n
:
int
,
seed
:
int
=
0
)
->
List
[
np
.
ndarray
]:
def
get_arrays
(
n
:
int
,
seed
:
int
=
0
)
->
List
[
np
.
ndarray
]:
...
@@ -54,23 +55,44 @@ def worker_fn_wrapper(fn):
...
@@ -54,23 +55,44 @@ def worker_fn_wrapper(fn):
@
worker_fn_wrapper
@
worker_fn_wrapper
def
worker_fn
():
def
worker_fn
():
rank
=
dist
.
get_rank
()
if
rank
==
0
:
port
=
get_open_port
()
ip
=
get_ip
()
dist
.
broadcast_object_list
([
ip
,
port
],
src
=
0
)
else
:
recv
=
[
None
,
None
]
dist
.
broadcast_object_list
(
recv
,
src
=
0
)
ip
,
port
=
recv
stateless_pg
=
StatelessProcessGroup
.
create
(
ip
,
port
,
rank
,
dist
.
get_world_size
())
for
pg
in
[
dist
.
group
.
WORLD
,
stateless_pg
]:
writer_rank
=
2
writer_rank
=
2
broadcaster
=
MessageQueue
.
create_from_process_group
(
broadcaster
=
MessageQueue
.
create_from_process_group
(
dist
.
group
.
WORLD
,
40
*
1024
,
2
,
writer_rank
)
pg
,
40
*
1024
,
2
,
writer_rank
)
if
dist
.
get_
rank
()
==
writer_rank
:
if
rank
==
writer_rank
:
seed
=
random
.
randint
(
0
,
1000
)
seed
=
random
.
randint
(
0
,
1000
)
dist
.
broadcast_object_list
([
seed
],
writer_rank
)
dist
.
broadcast_object_list
([
seed
],
writer_rank
)
else
:
else
:
recv
=
[
None
]
recv
=
[
None
]
dist
.
broadcast_object_list
(
recv
,
writer_rank
)
dist
.
broadcast_object_list
(
recv
,
writer_rank
)
seed
=
recv
[
0
]
# type: ignore
seed
=
recv
[
0
]
# type: ignore
if
pg
==
dist
.
group
.
WORLD
:
dist
.
barrier
()
dist
.
barrier
()
else
:
pg
.
barrier
()
# in case we find a race condition
# in case we find a race condition
# print the seed so that we can reproduce the error
# print the seed so that we can reproduce the error
print
(
f
"Rank
{
dist
.
get_
rank
()
}
got seed
{
seed
}
"
)
print
(
f
"Rank
{
rank
}
got seed
{
seed
}
"
)
# test broadcasting with about 400MB of data
# test broadcasting with about 400MB of data
N
=
10_000
N
=
10_000
if
dist
.
get_
rank
()
==
writer_rank
:
if
rank
==
writer_rank
:
arrs
=
get_arrays
(
N
,
seed
)
arrs
=
get_arrays
(
N
,
seed
)
for
x
in
arrs
:
for
x
in
arrs
:
broadcaster
.
broadcast_object
(
x
)
broadcaster
.
broadcast_object
(
x
)
...
@@ -81,7 +103,13 @@ def worker_fn():
...
@@ -81,7 +103,13 @@ def worker_fn():
y
=
broadcaster
.
broadcast_object
(
None
)
y
=
broadcaster
.
broadcast_object
(
None
)
assert
np
.
array_equal
(
x
,
y
)
assert
np
.
array_equal
(
x
,
y
)
time
.
sleep
(
random
.
random
()
/
1000
)
time
.
sleep
(
random
.
random
()
/
1000
)
if
pg
==
dist
.
group
.
WORLD
:
dist
.
barrier
()
dist
.
barrier
()
print
(
"torch distributed passed the test!"
)
else
:
pg
.
barrier
()
print
(
"StatelessProcessGroup passed the test!"
)
def
test_shm_broadcast
():
def
test_shm_broadcast
():
...
...
vllm/distributed/device_communicators/shm_broadcast.py
View file @
62de37a3
...
@@ -5,7 +5,7 @@ import time
...
@@ -5,7 +5,7 @@ import time
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
from
dataclasses
import
dataclass
,
field
from
dataclasses
import
dataclass
,
field
from
multiprocessing
import
shared_memory
from
multiprocessing
import
shared_memory
from
typing
import
List
,
Optional
,
Tuple
from
typing
import
List
,
Optional
,
Tuple
,
Union
from
unittest.mock
import
patch
from
unittest.mock
import
patch
import
torch
import
torch
...
@@ -15,6 +15,7 @@ from zmq import IPV6 # type: ignore
...
@@ -15,6 +15,7 @@ from zmq import IPV6 # type: ignore
from
zmq
import
SUB
,
SUBSCRIBE
,
XPUB
,
XPUB_VERBOSE
,
Context
# type: ignore
from
zmq
import
SUB
,
SUBSCRIBE
,
XPUB
,
XPUB_VERBOSE
,
Context
# type: ignore
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.distributed.utils
import
StatelessProcessGroup
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.utils
import
get_ip
,
get_open_port
,
is_valid_ipv6_address
from
vllm.utils
import
get_ip
,
get_open_port
,
is_valid_ipv6_address
...
@@ -476,13 +477,19 @@ class MessageQueue:
...
@@ -476,13 +477,19 @@ class MessageQueue:
return
self
.
dequeue
()
return
self
.
dequeue
()
@
staticmethod
@
staticmethod
def
create_from_process_group
(
pg
:
ProcessGroup
,
def
create_from_process_group
(
pg
:
Union
[
ProcessGroup
,
StatelessProcessGroup
],
max_chunk_bytes
,
max_chunk_bytes
,
max_chunks
,
max_chunks
,
writer_rank
=
0
)
->
"MessageQueue"
:
writer_rank
=
0
)
->
"MessageQueue"
:
if
isinstance
(
pg
,
ProcessGroup
):
group_rank
=
dist
.
get_rank
(
pg
)
group_rank
=
dist
.
get_rank
(
pg
)
group_world_size
=
dist
.
get_world_size
(
pg
)
group_world_size
=
dist
.
get_world_size
(
pg
)
global_ranks
=
dist
.
get_process_group_ranks
(
pg
)
global_ranks
=
dist
.
get_process_group_ranks
(
pg
)
else
:
group_rank
=
pg
.
rank
group_world_size
=
pg
.
world_size
global_ranks
=
list
(
range
(
pg
.
world_size
))
from
vllm.distributed.parallel_state
import
in_the_same_node_as
from
vllm.distributed.parallel_state
import
in_the_same_node_as
status
=
in_the_same_node_as
(
pg
,
source_rank
=
writer_rank
)
status
=
in_the_same_node_as
(
pg
,
source_rank
=
writer_rank
)
...
@@ -500,15 +507,21 @@ class MessageQueue:
...
@@ -500,15 +507,21 @@ class MessageQueue:
max_chunks
=
max_chunks
,
max_chunks
=
max_chunks
,
)
)
handle
=
buffer_io
.
export_handle
()
handle
=
buffer_io
.
export_handle
()
if
isinstance
(
pg
,
ProcessGroup
):
dist
.
broadcast_object_list
([
handle
],
dist
.
broadcast_object_list
([
handle
],
src
=
global_ranks
[
writer_rank
],
src
=
global_ranks
[
writer_rank
],
group
=
pg
)
group
=
pg
)
else
:
else
:
pg
.
broadcast_obj
(
handle
,
writer_rank
)
else
:
if
isinstance
(
pg
,
ProcessGroup
):
recv
=
[
None
]
recv
=
[
None
]
dist
.
broadcast_object_list
(
recv
,
dist
.
broadcast_object_list
(
recv
,
src
=
global_ranks
[
writer_rank
],
src
=
global_ranks
[
writer_rank
],
group
=
pg
)
group
=
pg
)
handle
=
recv
[
0
]
# type: ignore
handle
=
recv
[
0
]
# type: ignore
else
:
handle
=
pg
.
broadcast_obj
(
None
,
writer_rank
)
buffer_io
=
MessageQueue
.
create_from_handle
(
handle
,
group_rank
)
buffer_io
=
MessageQueue
.
create_from_handle
(
handle
,
group_rank
)
buffer_io
.
wait_until_ready
()
buffer_io
.
wait_until_ready
()
return
buffer_io
return
buffer_io
vllm/distributed/parallel_state.py
View file @
62de37a3
...
@@ -37,6 +37,7 @@ from torch.distributed import Backend, ProcessGroup
...
@@ -37,6 +37,7 @@ from torch.distributed import Backend, ProcessGroup
import
vllm.distributed.kv_transfer.kv_transfer_agent
as
kv_transfer
import
vllm.distributed.kv_transfer.kv_transfer_agent
as
kv_transfer
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.distributed.utils
import
StatelessProcessGroup
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.utils
import
direct_register_custom_op
,
supports_custom_op
from
vllm.utils
import
direct_register_custom_op
,
supports_custom_op
...
@@ -1191,12 +1192,14 @@ def cleanup_dist_env_and_memory(shutdown_ray: bool = False):
...
@@ -1191,12 +1192,14 @@ def cleanup_dist_env_and_memory(shutdown_ray: bool = False):
torch
.
cuda
.
empty_cache
()
torch
.
cuda
.
empty_cache
()
def
in_the_same_node_as
(
pg
:
ProcessGroup
,
source_rank
:
int
=
0
)
->
List
[
bool
]:
def
in_the_same_node_as
(
pg
:
Union
[
ProcessGroup
,
StatelessProcessGroup
],
source_rank
:
int
=
0
)
->
List
[
bool
]:
"""
"""
This is a collective operation that returns if each rank is in the same node
This is a collective operation that returns if each rank is in the same node
as the source rank. It tests if processes are attached to the same
as the source rank. It tests if processes are attached to the same
memory system (shared access to shared memory).
memory system (shared access to shared memory).
"""
"""
if
isinstance
(
pg
,
ProcessGroup
):
assert
torch
.
distributed
.
get_backend
(
assert
torch
.
distributed
.
get_backend
(
pg
)
!=
torch
.
distributed
.
Backend
.
NCCL
,
(
pg
)
!=
torch
.
distributed
.
Backend
.
NCCL
,
(
"in_the_same_node_as should be tested with a non-NCCL group."
)
"in_the_same_node_as should be tested with a non-NCCL group."
)
...
@@ -1204,11 +1207,15 @@ def in_the_same_node_as(pg: ProcessGroup, source_rank: int = 0) -> List[bool]:
...
@@ -1204,11 +1207,15 @@ def in_the_same_node_as(pg: ProcessGroup, source_rank: int = 0) -> List[bool]:
rank
=
torch
.
distributed
.
get_rank
(
group
=
pg
)
rank
=
torch
.
distributed
.
get_rank
(
group
=
pg
)
world_size
=
torch
.
distributed
.
get_world_size
(
group
=
pg
)
world_size
=
torch
.
distributed
.
get_world_size
(
group
=
pg
)
# local tensor in each process to store the result
is_in_the_same_node
=
torch
.
tensor
([
0
]
*
world_size
,
dtype
=
torch
.
int32
)
# global ranks of the processes in the group
# global ranks of the processes in the group
ranks
=
torch
.
distributed
.
get_process_group_ranks
(
pg
)
ranks
=
torch
.
distributed
.
get_process_group_ranks
(
pg
)
else
:
rank
=
pg
.
rank
world_size
=
pg
.
world_size
ranks
=
list
(
range
(
world_size
))
# local tensor in each process to store the result
is_in_the_same_node
=
torch
.
tensor
([
0
]
*
world_size
,
dtype
=
torch
.
int32
)
magic_message
=
b
"magic_message"
magic_message
=
b
"magic_message"
shm
=
None
shm
=
None
...
@@ -1219,17 +1226,21 @@ def in_the_same_node_as(pg: ProcessGroup, source_rank: int = 0) -> List[bool]:
...
@@ -1219,17 +1226,21 @@ def in_the_same_node_as(pg: ProcessGroup, source_rank: int = 0) -> List[bool]:
# create a shared memory segment
# create a shared memory segment
shm
=
shared_memory
.
SharedMemory
(
create
=
True
,
size
=
128
)
shm
=
shared_memory
.
SharedMemory
(
create
=
True
,
size
=
128
)
shm
.
buf
[:
len
(
magic_message
)]
=
magic_message
shm
.
buf
[:
len
(
magic_message
)]
=
magic_message
torch
.
distributed
.
broadcast_object_list
([
shm
.
name
],
if
isinstance
(
pg
,
ProcessGroup
):
src
=
ranks
[
source_rank
],
torch
.
distributed
.
broadcast_object_list
(
group
=
pg
)
[
shm
.
name
],
src
=
ranks
[
source_rank
],
group
=
pg
)
else
:
pg
.
broadcast_obj
(
shm
.
name
,
src
=
source_rank
)
is_in_the_same_node
[
rank
]
=
1
is_in_the_same_node
[
rank
]
=
1
else
:
else
:
# try to open the shared memory segment
# try to open the shared memory segment
if
isinstance
(
pg
,
ProcessGroup
):
recv
=
[
None
]
recv
=
[
None
]
torch
.
distributed
.
broadcast_object_list
(
recv
,
torch
.
distributed
.
broadcast_object_list
(
src
=
ranks
[
source_rank
],
recv
,
src
=
ranks
[
source_rank
],
group
=
pg
)
group
=
pg
)
name
=
recv
[
0
]
name
=
recv
[
0
]
else
:
name
=
pg
.
broadcast_obj
(
None
,
src
=
source_rank
)
# fix to https://stackoverflow.com/q/62748654/9191338
# fix to https://stackoverflow.com/q/62748654/9191338
# Python incorrectly tracks shared memory even if it is not
# Python incorrectly tracks shared memory even if it is not
# created by the process. The following patch is a workaround.
# created by the process. The following patch is a workaround.
...
@@ -1244,12 +1255,23 @@ def in_the_same_node_as(pg: ProcessGroup, source_rank: int = 0) -> List[bool]:
...
@@ -1244,12 +1255,23 @@ def in_the_same_node_as(pg: ProcessGroup, source_rank: int = 0) -> List[bool]:
if
shm
:
if
shm
:
shm
.
close
()
shm
.
close
()
if
isinstance
(
pg
,
ProcessGroup
):
torch
.
distributed
.
barrier
(
group
=
pg
)
torch
.
distributed
.
barrier
(
group
=
pg
)
else
:
pg
.
barrier
()
# clean up the shared memory segment
# clean up the shared memory segment
with
contextlib
.
suppress
(
OSError
):
with
contextlib
.
suppress
(
OSError
):
if
rank
==
source_rank
and
shm
:
if
rank
==
source_rank
and
shm
:
shm
.
unlink
()
shm
.
unlink
()
if
isinstance
(
pg
,
ProcessGroup
):
torch
.
distributed
.
all_reduce
(
is_in_the_same_node
,
group
=
pg
)
torch
.
distributed
.
all_reduce
(
is_in_the_same_node
,
group
=
pg
)
aggregated_data
=
is_in_the_same_node
else
:
aggregated_data
=
torch
.
zeros_like
(
is_in_the_same_node
)
for
i
in
range
(
world_size
):
rank_data
=
pg
.
broadcast_obj
(
is_in_the_same_node
,
src
=
i
)
aggregated_data
+=
rank_data
return
[
x
==
1
for
x
in
is_in_the_same_node
.
tolist
()]
return
[
x
==
1
for
x
in
aggregated_data
.
tolist
()]
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment