Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
c01a6cb2
Unverified
Commit
c01a6cb2
authored
Aug 22, 2024
by
SangBin Cho
Committed by
GitHub
Aug 22, 2024
Browse files
[Ray backend] Better error when pg topology is bad. (#7584)
Co-authored-by:
youkaichao
<
youkaichao@126.com
>
parent
b903e1ba
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
197 additions
and
9 deletions
+197
-9
.buildkite/test-pipeline.yaml
.buildkite/test-pipeline.yaml
+1
-0
tests/distributed/test_multi_node_assignment.py
tests/distributed/test_multi_node_assignment.py
+64
-0
vllm/executor/ray_utils.py
vllm/executor/ray_utils.py
+132
-9
No files found.
.buildkite/test-pipeline.yaml
View file @
c01a6cb2
...
@@ -293,6 +293,7 @@ steps:
...
@@ -293,6 +293,7 @@ steps:
commands
:
commands
:
-
# the following commands are for the first node, with ip 192.168.10.10 (ray environment already set up)
-
# the following commands are for the first node, with ip 192.168.10.10 (ray environment already set up)
-
VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py
-
VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py
-
VLLM_MULTI_NODE=1 pytest -v -s distributed/test_multi_node_assignment.py
-
VLLM_MULTI_NODE=1 pytest -v -s distributed/test_pipeline_parallel.py
-
VLLM_MULTI_NODE=1 pytest -v -s distributed/test_pipeline_parallel.py
-
# the following commands are for the second node, with ip 192.168.10.11 (ray environment already set up)
-
# the following commands are for the second node, with ip 192.168.10.11 (ray environment already set up)
-
VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py
-
VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py
...
...
tests/distributed/test_multi_node_assignment.py
0 → 100644
View file @
c01a6cb2
"""Make sure ray assigns GPU workers to the correct node.
Run:
```sh
cd $VLLM_PATH/tests
pytest distributed/test_multi_node_assignment.py
```
"""
import
os
import
pytest
import
ray
from
ray.util.scheduling_strategies
import
PlacementGroupSchedulingStrategy
from
vllm
import
initialize_ray_cluster
from
vllm.config
import
ParallelConfig
from
vllm.executor.ray_utils
import
_wait_until_pg_removed
from
vllm.utils
import
get_ip
# Opt-in switch: the multi-node test below only runs when the environment
# variable VLLM_MULTI_NODE is set to exactly "1".
VLLM_MULTI_NODE = os.environ.get("VLLM_MULTI_NODE", "0") == "1"
@pytest.mark.skipif(not VLLM_MULTI_NODE,
                    reason="Need at least 2 nodes to run the test.")
def test_multi_node_assignment() -> None:
    """Check that GPU workers of the placement group run on the driver node.

    Repeats the following 10 times: initialize the Ray cluster for a
    ``ParallelConfig(1, 2)``, start one actor per GPU bundle of the resulting
    placement group, and assert that each actor reports the same IP address
    as the current process.

    Actors and the placement group are released in a ``finally`` clause so a
    failing iteration does not leave GPUs reserved.
    """

    # NOTE: important to keep this class definition here
    # to let ray use cloudpickle to serialize it.
    class Actor:

        def get_ip(self):
            return get_ip()

    for _ in range(10):
        config = ParallelConfig(1, 2)
        initialize_ray_cluster(config)

        current_ip = get_ip()
        workers = []
        try:
            for bundle_id, bundle in enumerate(
                    config.placement_group.bundle_specs):
                # Only bundles that reserve a GPU host a worker.
                if not bundle.get("GPU", 0):
                    continue
                scheduling_strategy = PlacementGroupSchedulingStrategy(
                    placement_group=config.placement_group,
                    placement_group_capture_child_tasks=True,
                    placement_group_bundle_index=bundle_id,
                )

                worker = ray.remote(
                    num_cpus=0,
                    num_gpus=1,
                    scheduling_strategy=scheduling_strategy,
                )(Actor).remote()
                # Register the actor before querying it so that it is still
                # killed if the remote call or the assertion below fails.
                workers.append(worker)
                worker_ip = ray.get(worker.get_ip.remote())
                assert worker_ip == current_ip
        finally:
            for worker in workers:
                ray.kill(worker)

            _wait_until_pg_removed(config.placement_group)
vllm/executor/ray_utils.py
View file @
c01a6cb2
from
typing
import
List
,
Optional
,
Tuple
,
Union
import
time
from
collections
import
defaultdict
from
typing
import
Dict
,
List
,
Optional
,
Tuple
,
Union
import
msgspec
import
msgspec
...
@@ -11,9 +13,13 @@ from vllm.utils import get_ip, is_hip, is_xpu
...
@@ -11,9 +13,13 @@ from vllm.utils import get_ip, is_hip, is_xpu
from
vllm.worker.worker_base
import
WorkerWrapperBase
from
vllm.worker.worker_base
import
WorkerWrapperBase
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
# Upper bound, in seconds, used by the placement-group wait helpers below.
PG_WAIT_TIMEOUT = 30 * 60
try
:
try
:
import
ray
import
ray
from
ray._private.state
import
available_resources_per_node
from
ray.util
import
placement_group_table
from
ray.util.placement_group
import
PlacementGroup
class
RayWorkerWrapper
(
WorkerWrapperBase
):
class
RayWorkerWrapper
(
WorkerWrapperBase
):
"""Ray wrapper for vllm.worker.Worker, allowing Worker to be
"""Ray wrapper for vllm.worker.Worker, allowing Worker to be
...
@@ -98,6 +104,106 @@ def assert_ray_available():
...
@@ -98,6 +104,106 @@ def assert_ray_available():
"`pip install ray`."
)
from
ray_import_err
"`pip install ray`."
)
from
ray_import_err
def _verify_bundles(placement_group: "PlacementGroup",
                    parallel_config: ParallelConfig, device_str: str):
    """Check on which nodes the bundles of a placement group were placed.

    Two checks are performed:

    - The driver's node must host at least one bundle of the placement
      group; otherwise a ``RuntimeError`` is raised.
    - For every node hosting bundles, a warning is logged when that node
      hosts fewer bundles than ``parallel_config.tensor_parallel_size``.

    Args:
        placement_group: Placement group whose bundle placement is inspected.
        parallel_config: Provides ``tensor_parallel_size``.
        device_str: Device name interpolated into the warning message.

    Raises:
        AssertionError: If Ray is not initialized.
        RuntimeError: If no bundle was placed on the driver's node.
    """
    assert ray.is_initialized(), (
        "Ray is not initialized although distributed-executor-backend is ray.")

    table = placement_group_table(placement_group)
    # bundle_idx -> node_id
    placement = table["bundles_to_node_id"]
    # bundle_idx -> bundle (e.g., {"GPU": 1})
    bundle_specs = table["bundles"]

    # node_id -> bundles placed on that node
    node_id_to_bundle: Dict[str, List[Dict[str, float]]] = defaultdict(list)
    for idx, hosting_node in placement.items():
        node_id_to_bundle[hosting_node].append(bundle_specs[idx])

    driver_node_id = ray.get_runtime_context().get_node_id()
    driver_has_bundle = driver_node_id in node_id_to_bundle
    if not driver_has_bundle:
        raise RuntimeError(
            f"driver node id {driver_node_id} is not included in a placement "
            f"group {placement_group.id}. Node id -> bundles "
            f"{node_id_to_bundle}. "
            "You don't have enough GPUs available in a current node. Check "
            "`ray status` to see if you have available GPUs in a node "
            f"{driver_node_id} before starting an vLLM engine.")

    tp_size = parallel_config.tensor_parallel_size
    for node_id, node_bundles in node_id_to_bundle.items():
        reserved = len(node_bundles)
        if reserved >= tp_size:
            continue
        logger.warning(
            "tensor_parallel_size=%d "
            "is bigger than a reserved number of %ss (%d "
            "%ss) in a node %s. Tensor parallel workers can be "
            "spread out to 2+ nodes which can degrade the performance "
            "unless you have fast interconnect across nodes, like "
            "Infiniband. To resolve this issue, make sure you have more "
            "than %d GPUs available at each node.", tp_size, device_str,
            reserved, device_str, node_id, tp_size)
def _wait_until_pg_ready(current_placement_group: "PlacementGroup"):
    """Wait until a placement group is ready.

    It prints the informative log messages if the placement group is
    not created within time.

    Args:
        current_placement_group: The placement group to wait for.

    Raises:
        ValueError: If the placement group is still not ready once
            ``PG_WAIT_TIMEOUT`` seconds have elapsed.
    """
    placement_group_specs = current_placement_group.bundle_specs

    start = time.time()
    pg_ready_ref = current_placement_group.ready()
    wait_interval = 10
    while True:
        remaining = PG_WAIT_TIMEOUT - (time.time() - start)
        if remaining <= 0:
            break
        # Never block longer than the remaining budget; otherwise the
        # doubling interval lets the total wait overshoot PG_WAIT_TIMEOUT.
        ready, _ = ray.wait([pg_ready_ref],
                            timeout=min(wait_interval, remaining))
        if ready:
            break

        # Exponential backoff for warning print.
        wait_interval *= 2
        logger.info(
            "Waiting for creating a placement group of specs for "
            "%d seconds. specs=%s. Check "
            "`ray status` to see if you have enough resources.",
            int(time.time() - start), placement_group_specs)

    try:
        # Non-blocking check: surfaces either readiness or the timeout.
        ray.get(pg_ready_ref, timeout=0)
    except ray.exceptions.GetTimeoutError:
        raise ValueError(
            "Cannot provide a placement group of "
            f"{placement_group_specs=} within {PG_WAIT_TIMEOUT} seconds. See "
            "`ray status` to make sure the cluster has enough resources."
        ) from None
def _wait_until_pg_removed(current_placement_group: "PlacementGroup"):
    """Remove a placement group and wait until Ray reports it as removed.

    The removal is requested first, then the state of this specific placement
    group is polled for up to ``PG_WAIT_TIMEOUT`` seconds. On timeout a
    warning is logged and the function returns without raising.

    Args:
        current_placement_group: The placement group to remove.
    """
    ray.util.remove_placement_group(current_placement_group)
    start = time.time()
    wait_interval = 0.1
    max_wait_interval = 10
    while True:
        # Query the state of *this* placement group.
        # `ray.util.get_current_placement_group()` reports the group the
        # calling worker belongs to (None in a driver), so it cannot be used
        # to observe the removal.
        pg_info = ray.util.placement_group_table(current_placement_group)
        # NOTE(review): an entry without "state" is treated as already gone;
        # confirm this matches the Ray version in use.
        if pg_info.get("state", "REMOVED") == "REMOVED":
            return

        elapsed = time.time() - start
        remaining = PG_WAIT_TIMEOUT - elapsed
        if remaining <= 0:
            break

        logger.info(
            "Waiting for removing a placement group of specs for "
            "%d seconds.", int(elapsed))
        time.sleep(min(wait_interval, remaining))
        # Exponential backoff, bounded so polling stays responsive.
        wait_interval = min(wait_interval * 2, max_wait_interval)

    logger.warning("Placement group %s was not removed within %d seconds.",
                   current_placement_group.id, PG_WAIT_TIMEOUT)
def
initialize_ray_cluster
(
def
initialize_ray_cluster
(
parallel_config
:
ParallelConfig
,
parallel_config
:
ParallelConfig
,
ray_address
:
Optional
[
str
]
=
None
,
ray_address
:
Optional
[
str
]
=
None
,
...
@@ -156,15 +262,32 @@ def initialize_ray_cluster(
...
@@ -156,15 +262,32 @@ def initialize_ray_cluster(
f
"The number of required
{
device_str
}
s exceeds the total "
f
"The number of required
{
device_str
}
s exceeds the total "
f
"number of available
{
device_str
}
s in the placement group."
)
f
"number of available
{
device_str
}
s in the placement group."
)
# Create a new placement group
# Create a new placement group
placement_group_specs
=
([{
placement_group_specs
:
List
[
Dict
[
str
,
float
]]
=
([{
device_str
:
1
device_str
:
1.0
}]
*
parallel_config
.
world_size
)
}
for
_
in
range
(
parallel_config
.
world_size
)])
# vLLM engine is also a worker to execute model with an accelerator,
# so it requires to have the device in a current node. Check if
# the current node has at least one device.
current_ip
=
get_ip
()
current_node_id
=
ray
.
get_runtime_context
().
get_node_id
()
current_node_resource
=
available_resources_per_node
()[
current_node_id
]
if
current_node_resource
.
get
(
device_str
,
0
)
<
1
:
raise
ValueError
(
f
"Current node has no
{
device_str
}
available. "
f
"
{
current_node_resource
=
}
. vLLM engine cannot start without "
f
"
{
device_str
}
. Make sure you have at least 1
{
device_str
}
"
f
"available in a node
{
current_node_id
=
}
{
current_ip
=
}
."
)
# This way, at least one bundle is required to be created on the current
# node.
placement_group_specs
[
0
][
f
"node:
{
current_ip
}
"
]
=
0.001
# By default, Ray packs resources as much as possible.
current_placement_group
=
ray
.
util
.
placement_group
(
current_placement_group
=
ray
.
util
.
placement_group
(
placement_group_specs
)
placement_group_specs
,
strategy
=
"PACK"
)
# Wait until PG is ready - this will block until all
_wait_until_pg_ready
(
current_placement_group
)
# requested resources are available, and will timeout
# if they cannot be provisioned.
ray
.
get
(
current_placement_group
.
ready
(),
timeout
=
1800
)
assert
current_placement_group
is
not
None
_verify_bundles
(
current_placement_group
,
parallel_config
,
device_str
)
# Set the placement group in the parallel config
# Set the placement group in the parallel config
parallel_config
.
placement_group
=
current_placement_group
parallel_config
.
placement_group
=
current_placement_group
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment