Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
c4bd03c7
Unverified
Commit
c4bd03c7
authored
Jun 11, 2024
by
youkaichao
Committed by
GitHub
Jun 11, 2024
Browse files
[Core][Distributed] add same-node detection (#5369)
parent
dcbf4286
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
87 additions
and
1 deletion
+87
-1
.buildkite/test-pipeline.yaml
.buildkite/test-pipeline.yaml
+1
-0
tests/distributed/test_same_node.py
tests/distributed/test_same_node.py
+11
-0
vllm/distributed/device_communicators/custom_all_reduce.py
vllm/distributed/device_communicators/custom_all_reduce.py
+8
-1
vllm/distributed/parallel_state.py
vllm/distributed/parallel_state.py
+67
-0
No files found.
.buildkite/test-pipeline.yaml
View file @
c4bd03c7
...
...
@@ -37,6 +37,7 @@ steps:
working_dir
:
"
/vllm-workspace/tests"
num_gpus
:
2
commands
:
-
VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py
-
TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
-
TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
-
TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_chunked_prefill_distributed.py
...
...
tests/distributed/test_same_node.py
0 → 100644
View file @
c4bd03c7
import
os
import
torch
from
vllm.distributed.parallel_state
import
is_in_the_same_node
# Stand-alone script (launched once per process, e.g. by torchrun) that checks
# the result of is_in_the_same_node against what the launcher says to expect.

# gloo is used because is_in_the_same_node asserts that it is given a
# non-NCCL group.
torch.distributed.init_process_group(backend="gloo")

# VLLM_TEST_SAME_HOST defaults to "1", i.e. unless the environment says
# otherwise all ranks are expected to share one host.
same_host_flag = os.environ.get("VLLM_TEST_SAME_HOST", "1")
expected = same_host_flag == "1"

# Collective call over the default (world) group.
world_group = torch.distributed.group.WORLD
test_result = is_in_the_same_node(world_group)

assert test_result == expected, f"Expected {expected}, got {test_result}"
vllm/distributed/device_communicators/custom_all_reduce.py
View file @
c4bd03c7
...
...
@@ -10,7 +10,7 @@ from vllm import _custom_ops as ops
from
vllm.distributed.device_communicators.custom_all_reduce_utils
import
(
gpu_p2p_access_check
)
from
vllm.distributed.parallel_state
import
(
get_local_rank
,
get_tensor_model_parallel_cpu_group
)
get_local_rank
,
get_tensor_model_parallel_cpu_group
,
is_in_the_same_node
)
from
vllm.logger
import
init_logger
try
:
...
...
@@ -113,6 +113,13 @@ class CustomAllreduce:
assert
dist
.
get_backend
(
group
)
!=
dist
.
Backend
.
NCCL
,
(
"CustomAllreduce should be attached to a non-NCCL group."
)
if
not
is_in_the_same_node
(
group
):
# No need to initialize custom allreduce for multi-node case.
logger
.
warning
(
"Custom allreduce is disabled because this process group"
" spans across nodes."
)
return
rank
=
dist
.
get_rank
(
group
=
self
.
group
)
world_size
=
dist
.
get_world_size
(
group
=
self
.
group
)
if
world_size
==
1
:
...
...
vllm/distributed/parallel_state.py
View file @
c4bd03c7
...
...
@@ -3,6 +3,8 @@
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""Tensor and pipeline parallel groups."""
import
contextlib
from
multiprocessing
import
resource_tracker
,
shared_memory
from
typing
import
List
,
Optional
import
torch
...
...
@@ -376,3 +378,68 @@ def destroy_model_parallel():
_PP_DEVICE_GROUP
=
None
global
_PP_GLOBAL_RANKS
_PP_GLOBAL_RANKS
=
None
def is_in_the_same_node(pg: ProcessGroup) -> bool:
    """
    Check whether all processes in `pg` are attached to the same memory
    system, i.e. whether they can all open the same shared memory segment.

    This is a collective operation: every rank of `pg` must call it, because
    it issues a broadcast, a barrier and an all-reduce on the group.

    Group rank 0 creates a shared memory segment, writes a magic message into
    it and broadcasts the segment name. Every other rank tries to open the
    segment by name and read the magic message back. Each rank records its
    own success in a per-rank flag tensor, which is then all-reduced.

    Args:
        pg: the process group to test. Must not use the NCCL backend.

    Returns:
        True if every rank of the group set its flag, False otherwise.

    Raises:
        AssertionError: if `pg` uses the NCCL backend.
    """
    assert torch.distributed.get_backend(
        pg) != torch.distributed.Backend.NCCL, (
            "is_in_the_same_node should be tested with a non-NCCL group.")

    # rank of this process inside the group, and the group size
    rank = torch.distributed.get_rank(group=pg)
    world_size = torch.distributed.get_world_size(group=pg)

    # one slot per group rank; slot i becomes 1 when rank i succeeded.
    # (named differently from the function to avoid shadowing it)
    same_node_flags = torch.zeros(world_size, dtype=torch.int32)

    # global ranks of the processes in the group; broadcast `src` is global
    ranks = torch.distributed.get_process_group_ranks(pg)

    magic_message = b"magic_message"
    shm = None

    try:
        if rank == 0:
            segment_name = None
            # Only the creation is best-effort. The broadcast below must
            # run even if creation fails, otherwise the other ranks would
            # wait in their matching broadcast for a message that never
            # comes.
            with contextlib.suppress(OSError):
                shm = shared_memory.SharedMemory(create=True, size=128)
                shm.buf[:len(magic_message)] = magic_message
                segment_name = shm.name
            # `None` tells the receivers that no segment is available.
            torch.distributed.broadcast_object_list([segment_name],
                                                    src=ranks[0],
                                                    group=pg)
            if segment_name is not None:
                same_node_flags[0] = 1
        else:
            recv = [None]
            torch.distributed.broadcast_object_list(recv,
                                                    src=ranks[0],
                                                    group=pg)
            name = recv[0]
            if name is not None:
                # Failing to open the segment (OSError) simply means this
                # rank cannot see rank 0's shared memory.
                with contextlib.suppress(OSError):
                    shm = shared_memory.SharedMemory(name=name)
                    if bytes(shm.buf[:len(magic_message)]) == magic_message:
                        same_node_flags[rank] = 1
    except Exception as e:
        logger.error("Error ignored in is_in_the_same_node: %s", e)
    finally:
        if shm is not None:
            shm.close()

    # make sure every rank is done with the segment before it is removed
    torch.distributed.barrier(group=pg)

    # clean up the shared memory segment
    with contextlib.suppress(OSError):
        if shm is not None:
            if rank == 0:
                # the creator removes the segment
                shm.unlink()
            else:
                # fix to https://stackoverflow.com/q/62748654/9191338
                # This rank only attached to the segment, so it drops it
                # from its own resource tracker instead of unlinking it.
                resource_tracker.unregister(
                    shm._name, "shared_memory")  # type: ignore[attr-defined]

    # sum the per-rank flags across the group
    torch.distributed.all_reduce(same_node_flags, group=pg)

    return same_node_flags.sum().item() == world_size
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment