Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
8a7fe47d
Unverified
Commit
8a7fe47d
authored
Nov 11, 2024
by
youkaichao
Committed by
GitHub
Nov 11, 2024
Browse files
[misc][distributed] auto port selection and disable tests (#10226)
Signed-off-by:
youkaichao
<
youkaichao@gmail.com
>
parent
4800339c
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
23 additions
and
16 deletions
+23
-16
tests/distributed/test_utils.py
tests/distributed/test_utils.py
+23
-16
No files found.
tests/distributed/test_utils.py
View file @
8a7fe47d
import
socket
import
pytest
import
ray
import
torch
...
...
@@ -5,7 +7,7 @@ import torch
import
vllm.envs
as
envs
from
vllm.distributed.device_communicators.pynccl
import
PyNcclCommunicator
from
vllm.distributed.utils
import
StatelessProcessGroup
from
vllm.utils
import
(
cuda_device_count_stateless
,
from
vllm.utils
import
(
cuda_device_count_stateless
,
get_open_port
,
update_environment_variables
)
from
..utils
import
multi_gpu_test
...
...
@@ -40,14 +42,13 @@ def test_cuda_device_count_stateless():
assert
ray
.
get
(
actor
.
get_count
.
remote
())
==
0
def
cpu_worker
(
rank
,
WORLD_SIZE
):
pg1
=
StatelessProcessGroup
.
create
(
init_method
=
"tcp://127.0.0.1:
29500
"
,
def
cpu_worker
(
rank
,
WORLD_SIZE
,
port1
,
port2
):
pg1
=
StatelessProcessGroup
.
create
(
init_method
=
f
"tcp://127.0.0.1:
{
port1
}
"
,
rank
=
rank
,
world_size
=
WORLD_SIZE
)
if
rank
<=
2
:
pg2
=
StatelessProcessGroup
.
create
(
init_method
=
"tcp://127.0.0.1:29501"
,
rank
=
rank
,
world_size
=
3
)
pg2
=
StatelessProcessGroup
.
create
(
init_method
=
f
"tcp://127.0.0.1:
{
port2
}
"
,
rank
=
rank
,
world_size
=
3
)
data
=
torch
.
tensor
([
rank
])
data
=
pg1
.
broadcast_obj
(
data
,
src
=
2
)
assert
data
.
item
()
==
2
...
...
@@ -59,17 +60,16 @@ def cpu_worker(rank, WORLD_SIZE):
pg1
.
barrier
()
def
gpu_worker
(
rank
,
WORLD_SIZE
):
def
gpu_worker
(
rank
,
WORLD_SIZE
,
port1
,
port2
):
torch
.
cuda
.
set_device
(
rank
)
pg1
=
StatelessProcessGroup
.
create
(
init_method
=
"tcp://127.0.0.1:
29502
"
,
pg1
=
StatelessProcessGroup
.
create
(
init_method
=
f
"tcp://127.0.0.1:
{
port1
}
"
,
rank
=
rank
,
world_size
=
WORLD_SIZE
)
pynccl1
=
PyNcclCommunicator
(
pg1
,
device
=
rank
)
pynccl1
.
disabled
=
False
if
rank
<=
2
:
pg2
=
StatelessProcessGroup
.
create
(
init_method
=
"tcp://127.0.0.1:29503"
,
rank
=
rank
,
world_size
=
3
)
pg2
=
StatelessProcessGroup
.
create
(
init_method
=
f
"tcp://127.0.0.1:
{
port2
}
"
,
rank
=
rank
,
world_size
=
3
)
pynccl2
=
PyNcclCommunicator
(
pg2
,
device
=
rank
)
pynccl2
.
disabled
=
False
data
=
torch
.
tensor
([
rank
]).
cuda
()
...
...
@@ -88,8 +88,8 @@ def gpu_worker(rank, WORLD_SIZE):
assert
item
==
18
def
broadcast_worker
(
rank
,
WORLD_SIZE
):
pg1
=
StatelessProcessGroup
.
create
(
init_method
=
"tcp://127.0.0.1:
29504
"
,
def
broadcast_worker
(
rank
,
WORLD_SIZE
,
port1
,
port2
):
pg1
=
StatelessProcessGroup
.
create
(
init_method
=
f
"tcp://127.0.0.1:
{
port1
}
"
,
rank
=
rank
,
world_size
=
WORLD_SIZE
)
if
rank
==
2
:
...
...
@@ -100,8 +100,8 @@ def broadcast_worker(rank, WORLD_SIZE):
pg1
.
barrier
()
def
allgather_worker
(
rank
,
WORLD_SIZE
):
pg1
=
StatelessProcessGroup
.
create
(
init_method
=
"tcp://127.0.0.1:
29505
"
,
def
allgather_worker
(
rank
,
WORLD_SIZE
,
port1
,
port2
):
pg1
=
StatelessProcessGroup
.
create
(
init_method
=
f
"tcp://127.0.0.1:
{
port1
}
"
,
rank
=
rank
,
world_size
=
WORLD_SIZE
)
data
=
pg1
.
all_gather_obj
(
rank
)
...
...
@@ -109,17 +109,24 @@ def allgather_worker(rank, WORLD_SIZE):
pg1
.
barrier
()
# TODO: investigate why this test is flaky. It hangs during initialization.
@
pytest
.
mark
.
skip
(
"Skip the test because it is flaky."
)
@
multi_gpu_test
(
num_gpus
=
4
)
@
pytest
.
mark
.
parametrize
(
"worker"
,
[
cpu_worker
,
gpu_worker
,
broadcast_worker
,
allgather_worker
])
def
test_stateless_process_group
(
worker
):
port1
=
get_open_port
()
with
socket
.
socket
(
socket
.
AF_INET
,
socket
.
SOCK_STREAM
)
as
s
:
s
.
bind
((
""
,
port1
))
port2
=
get_open_port
()
WORLD_SIZE
=
4
from
multiprocessing
import
get_context
ctx
=
get_context
(
"fork"
)
processes
=
[]
for
i
in
range
(
WORLD_SIZE
):
rank
=
i
processes
.
append
(
ctx
.
Process
(
target
=
worker
,
args
=
(
rank
,
WORLD_SIZE
)))
processes
.
append
(
ctx
.
Process
(
target
=
worker
,
args
=
(
rank
,
WORLD_SIZE
,
port1
,
port2
)))
for
p
in
processes
:
p
.
start
()
for
p
in
processes
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment