Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
aa486779
Unverified
Commit
aa486779
authored
Jul 26, 2024
by
Woosuk Kwon
Committed by
GitHub
Jul 26, 2024
Browse files
[Misc][TPU] Support TPU in initialize_ray_cluster (#6812)
parent
71734f1b
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
21 additions
and
15 deletions
+21
-15
vllm/executor/ray_utils.py
vllm/executor/ray_utils.py
+21
-15
No files found.
vllm/executor/ray_utils.py
View file @
aa486779
...
...
@@ -3,7 +3,7 @@ from typing import List, Optional, Tuple
from
vllm.config
import
ParallelConfig
from
vllm.logger
import
init_logger
from
vllm.sequence
import
ExecuteModelRequest
from
vllm.utils
import
get_ip
,
is_hip
,
is_xpu
from
vllm.utils
import
get_ip
,
is_hip
,
is_tpu
,
is_xpu
from
vllm.worker.worker_base
import
WorkerWrapperBase
logger
=
init_logger
(
__name__
)
...
...
@@ -93,32 +93,38 @@ def initialize_ray_cluster(
# Placement group is already set.
return
device_str
=
"GPU"
if
not
is_tpu
()
else
"TPU"
# Create placement group for worker processes
current_placement_group
=
ray
.
util
.
get_current_placement_group
()
if
current_placement_group
:
# We are in a placement group
bundles
=
current_placement_group
.
bundle_specs
# Verify that we can use the placement group.
gpu
_bundles
=
0
device
_bundles
=
0
for
bundle
in
bundles
:
bundle_
gpu
s
=
bundle
.
get
(
"GPU"
,
0
)
if
bundle_
gpu
s
>
1
:
bundle_
device
s
=
bundle
.
get
(
device_str
,
0
)
if
bundle_
device
s
>
1
:
raise
ValueError
(
"Placement group bundle cannot have more than 1 GPU."
)
if
bundle_gpus
:
gpu_bundles
+=
1
if
parallel_config
.
world_size
>
gpu_bundles
:
"Placement group bundle cannot have more than 1 "
f
"
{
device_str
}
."
)
if
bundle_devices
:
device_bundles
+=
1
if
parallel_config
.
world_size
>
device_bundles
:
raise
ValueError
(
"The number of required GPUs exceeds the total number of "
"available GPUs in the placement group."
)
f
"The number of required
{
device_str
}
s exceeds the total "
f
"number of available
{
device_str
}
s in the placement group."
f
"Required number of devices:
{
parallel_config
.
world_size
}
. "
f
"Total number of devices:
{
device_bundles
}
."
)
else
:
num_
gpu
s_in_cluster
=
ray
.
cluster_resources
().
get
(
"GPU"
,
0
)
if
parallel_config
.
world_size
>
num_
gpu
s_in_cluster
:
num_
device
s_in_cluster
=
ray
.
cluster_resources
().
get
(
device_str
,
0
)
if
parallel_config
.
world_size
>
num_
device
s_in_cluster
:
raise
ValueError
(
"The number of required
GPU
s exceeds the total
number of
"
"available GPUs in the cluster
."
)
f
"The number of required
{
device_str
}
s exceeds the total "
f
"number of available
{
device_str
}
s in the placement group
."
)
# Create a new placement group
placement_group_specs
=
([{
"GPU"
:
1
}]
*
parallel_config
.
world_size
)
placement_group_specs
=
([{
device_str
:
1
}]
*
parallel_config
.
world_size
)
current_placement_group
=
ray
.
util
.
placement_group
(
placement_group_specs
)
# Wait until PG is ready - this will block until all
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment