Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
c40692bf
Unverified
Commit
c40692bf
authored
Jun 25, 2025
by
Nick Hill
Committed by
GitHub
Jun 25, 2025
Browse files
[Misc] Add parallel state `node_count` function (#20045)
Signed-off-by:
Nick Hill
<
nhill@redhat.com
>
parent
4734704b
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
98 additions
and
2 deletions
+98
-2
.buildkite/test-pipeline.yaml
.buildkite/test-pipeline.yaml
+2
-0
tests/distributed/test_node_count.py
tests/distributed/test_node_count.py
+43
-0
vllm/distributed/parallel_state.py
vllm/distributed/parallel_state.py
+53
-2
No files found.
.buildkite/test-pipeline.yaml
View file @
c40692bf
...
@@ -619,11 +619,13 @@ steps:
...
@@ -619,11 +619,13 @@ steps:
commands
:
commands
:
-
# the following commands are for the first node, with ip 192.168.10.10 (ray environment already set up)
-
# the following commands are for the first node, with ip 192.168.10.10 (ray environment already set up)
-
VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep 'Same node test passed'
-
VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep 'Same node test passed'
-
NUM_NODES=2 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_node_count.py | grep 'Node count test passed'
-
python3 ../examples/offline_inference/data_parallel.py --dp-size=2 --tp-size=1 --node-size=2 --node-rank=0 --master-addr=192.168.10.10 --master-port=12345 --enforce-eager --trust-remote-code
-
python3 ../examples/offline_inference/data_parallel.py --dp-size=2 --tp-size=1 --node-size=2 --node-rank=0 --master-addr=192.168.10.10 --master-port=12345 --enforce-eager --trust-remote-code
-
VLLM_MULTI_NODE=1 pytest -v -s distributed/test_multi_node_assignment.py
-
VLLM_MULTI_NODE=1 pytest -v -s distributed/test_multi_node_assignment.py
-
VLLM_MULTI_NODE=1 pytest -v -s distributed/test_pipeline_parallel.py
-
VLLM_MULTI_NODE=1 pytest -v -s distributed/test_pipeline_parallel.py
-
# the following commands are for the second node, with ip 192.168.10.11 (ray environment already set up)
-
# the following commands are for the second node, with ip 192.168.10.11 (ray environment already set up)
-
VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep 'Same node test passed'
-
VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep 'Same node test passed'
-
NUM_NODES=2 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_node_count.py | grep 'Node count test passed'
-
python3 ../examples/offline_inference/data_parallel.py --dp-size=2 --tp-size=1 --node-size=2 --node-rank=1 --master-addr=192.168.10.10 --master-port=12345 --enforce-eager --trust-remote-code
-
python3 ../examples/offline_inference/data_parallel.py --dp-size=2 --tp-size=1 --node-size=2 --node-rank=1 --master-addr=192.168.10.10 --master-port=12345 --enforce-eager --trust-remote-code
-
label
:
Distributed Tests (2 GPUs)
# 40min
-
label
:
Distributed Tests (2 GPUs)
# 40min
...
...
tests/distributed/test_node_count.py
0 → 100644
View file @
c40692bf
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
os
import
torch.distributed
as
dist
from
vllm.distributed.parallel_state
import
_node_count
from
vllm.distributed.utils
import
StatelessProcessGroup
from
vllm.utils
import
get_ip
,
get_open_port
if __name__ == "__main__":
    # Launched via torchrun: rendezvous details come from the environment.
    dist.init_process_group(backend="gloo")
    rank = dist.get_rank()
    world_size = dist.get_world_size()

    # Rank 0 picks the endpoint for the stateless group; every rank takes
    # part in the same broadcast so they all end up with identical values.
    if rank == 0:
        port = get_open_port()
        endpoint = [get_ip(), port]
    else:
        endpoint = [None, None]
    dist.broadcast_object_list(endpoint, src=0)
    ip, port = endpoint

    stateless_pg = StatelessProcessGroup.create(ip, port, rank, world_size)

    # Run the same check against both process group flavours.
    for pg in (dist.group.WORLD, stateless_pg):
        test_result = _node_count(pg)

        # The expected node count is supplied through the NUM_NODES
        # environment variable (defaults to a single node).
        expected = int(os.environ.get("NUM_NODES", "1"))
        assert test_result == expected, \
            f"Expected {expected} nodes, got {test_result}"

        # The CI pipeline greps stdout for 'Node count test passed'.
        used_torch_group = pg == dist.group.WORLD
        if used_torch_group:
            print(f"Node count test passed! Got {test_result} nodes "
                  f"when using torch distributed!")
        else:
            print(f"Node count test passed! Got {test_result} nodes "
                  f"when using StatelessProcessGroup!")
vllm/distributed/parallel_state.py
View file @
c40692bf
...
@@ -802,6 +802,7 @@ class GroupCoordinator:
...
@@ -802,6 +802,7 @@ class GroupCoordinator:
_WORLD
:
Optional
[
GroupCoordinator
]
=
None
_WORLD
:
Optional
[
GroupCoordinator
]
=
None
_NODE_COUNT
:
Optional
[
int
]
=
None
def
get_world_group
()
->
GroupCoordinator
:
def
get_world_group
()
->
GroupCoordinator
:
...
@@ -961,10 +962,13 @@ def init_distributed_environment(
...
@@ -961,10 +962,13 @@ def init_distributed_environment(
local_rank
=
envs
.
LOCAL_RANK
local_rank
=
envs
.
LOCAL_RANK
else
:
else
:
local_rank
=
rank
local_rank
=
rank
global
_WORLD
global
_WORLD
,
_NODE_COUNT
if
_WORLD
is
None
:
if
_WORLD
is
None
:
ranks
=
list
(
range
(
torch
.
distributed
.
get_world_size
()))
ranks
=
list
(
range
(
torch
.
distributed
.
get_world_size
()))
_WORLD
=
init_world_group
(
ranks
,
local_rank
,
backend
)
_WORLD
=
init_world_group
(
ranks
,
local_rank
,
backend
)
_NODE_COUNT
=
_node_count
(
_WORLD
.
cpu_group
)
logger
.
debug
(
"Detected %d nodes in the distributed environment"
,
_NODE_COUNT
)
else
:
else
:
assert
_WORLD
.
world_size
==
torch
.
distributed
.
get_world_size
(),
(
assert
_WORLD
.
world_size
==
torch
.
distributed
.
get_world_size
(),
(
"world group already initialized with a different world size"
)
"world group already initialized with a different world size"
)
...
@@ -1164,6 +1168,13 @@ def get_tensor_model_parallel_rank():
...
@@ -1164,6 +1168,13 @@ def get_tensor_model_parallel_rank():
return
get_tp_group
().
rank_in_group
return
get_tp_group
().
rank_in_group
def get_node_count() -> int:
    """Return the number of nodes recorded for the distributed environment.

    Returns:
        int: The value held in the module-level ``_NODE_COUNT``.

    Raises:
        AssertionError: If ``_NODE_COUNT`` is still ``None``.
    """
    node_count = _NODE_COUNT
    assert node_count is not None, (
        "distributed environment is not initialized")
    return node_count
def
destroy_model_parallel
():
def
destroy_model_parallel
():
"""Set the groups to none and destroy them."""
"""Set the groups to none and destroy them."""
global
_TP
global
_TP
...
@@ -1189,10 +1200,11 @@ def destroy_model_parallel():
...
@@ -1189,10 +1200,11 @@ def destroy_model_parallel():
def
destroy_distributed_environment
():
def
destroy_distributed_environment
():
global
_WORLD
global
_WORLD
,
_NODE_COUNT
if
_WORLD
:
if
_WORLD
:
_WORLD
.
destroy
()
_WORLD
.
destroy
()
_WORLD
=
None
_WORLD
=
None
_NODE_COUNT
=
None
if
torch
.
distributed
.
is_initialized
():
if
torch
.
distributed
.
is_initialized
():
torch
.
distributed
.
destroy_process_group
()
torch
.
distributed
.
destroy_process_group
()
...
@@ -1301,3 +1313,42 @@ def in_the_same_node_as(pg: Union[ProcessGroup, StatelessProcessGroup],
...
@@ -1301,3 +1313,42 @@ def in_the_same_node_as(pg: Union[ProcessGroup, StatelessProcessGroup],
aggregated_data
+=
rank_data
aggregated_data
+=
rank_data
return
[
x
==
1
for
x
in
aggregated_data
.
tolist
()]
return
[
x
==
1
for
x
in
aggregated_data
.
tolist
()]
def _node_count(pg: Union[ProcessGroup, StatelessProcessGroup]) -> int:
    """
    Count the distinct nodes that host the ranks of a process group.

    Ranks are visited in ascending order. Each rank that has not yet been
    placed on a node opens a new node, and ``in_the_same_node_as`` is then
    used to place every rank that shares that node with it.

    Args:
        pg: The process group to analyze

    Returns:
        int: The total number of nodes
    """
    group_size = (torch.distributed.get_world_size(group=pg) if isinstance(
        pg, ProcessGroup) else pg.world_size)

    # A single-rank group trivially lives on one node; no probing needed.
    if group_size == 1:
        return 1

    # placed[r] becomes True once rank r has been attributed to some node.
    placed = [False] * group_size
    nodes_found = 0

    for probe_rank in range(group_size):
        if placed[probe_rank]:
            # Already attributed to a node opened by a lower rank.
            continue

        # This rank has not been seen on any known node: open a new one.
        nodes_found += 1
        placed[probe_rank] = True

        # Attribute every rank co-located with probe_rank to that node.
        colocated_flags = in_the_same_node_as(pg, probe_rank)
        for peer_rank, is_colocated in enumerate(colocated_flags):
            if is_colocated:
                placed[peer_rank] = True

    return nodes_found
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment