Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
2148441f
"vscode:/vscode.git/clone" did not exist on "2b30afa4420cbada6dd9084de3ee7eb19142b7ff"
Unverified
Commit
2148441f
authored
Aug 30, 2024
by
Richard Liu
Committed by
GitHub
Aug 30, 2024
Browse files
[TPU] Support single and multi-host TPUs on GKE (#7613)
parent
dc13e993
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
74 additions
and
4 deletions
+74
-4
requirements-tpu.txt
requirements-tpu.txt
+1
-1
vllm/attention/backends/pallas.py
vllm/attention/backends/pallas.py
+4
-1
vllm/distributed/device_communicators/tpu_communicator.py
vllm/distributed/device_communicators/tpu_communicator.py
+25
-2
vllm/executor/ray_tpu_executor.py
vllm/executor/ray_tpu_executor.py
+15
-0
vllm/executor/ray_utils.py
vllm/executor/ray_utils.py
+29
-0
No files found.
requirements-tpu.txt
View file @
2148441f
...
...
@@ -4,4 +4,4 @@
# Dependencies for TPU
# Currently, the TPU backend uses a nightly version of PyTorch XLA.
# You can install the dependencies in Dockerfile.tpu.
ray
ray
[default]
vllm/attention/backends/pallas.py
View file @
2148441f
...
...
@@ -123,7 +123,10 @@ class PallasAttentionBackendImpl(AttentionImpl):
raise
NotImplementedError
(
"TPU version must be 4 or higher."
)
self
.
megacore_mode
=
None
tpu_type
=
torch_xla
.
tpu
.
get_tpu_env
()[
"TYPE"
].
lower
()
tpu_env
=
torch_xla
.
tpu
.
get_tpu_env
()
tpu_type
=
tpu_env
.
get
(
"TYPE"
)
or
tpu_env
.
get
(
"ACCELERATOR_TYPE"
)
tpu_type
=
tpu_type
.
lower
()
if
"lite"
not
in
tpu_type
:
if
self
.
num_kv_heads
%
2
==
0
:
self
.
megacore_mode
=
"kv_head"
...
...
vllm/distributed/device_communicators/tpu_communicator.py
View file @
2148441f
import
os
import
torch
import
torch.distributed
as
dist
from
torch.distributed
import
ProcessGroup
...
...
@@ -5,11 +7,12 @@ from torch.distributed import ProcessGroup
from
vllm.platforms
import
current_platform
if
current_platform
.
is_tpu
():
import
ray
import
torch_xla.core.xla_model
as
xm
import
torch_xla.runtime
as
xr
from
torch_xla._internal
import
pjrt
from
vllm.executor
import
ray_utils
class
TpuCommunicator
:
...
...
@@ -24,9 +27,29 @@ class TpuCommunicator:
# be simply calculated as follows.
global_rank
=
dist
.
get_rank
(
group
)
global_world_size
=
dist
.
get_world_size
(
group
)
num_nodes
=
len
(
ray
.
nodes
())
# Calculate how many TPU nodes are in the current deployment. When
# running inside a Ray placement group, this is the number of nodes
# in that placement group; otherwise it defaults to the number of TPU
# nodes in the Ray cluster. The number of TPU nodes is computed as the
# total number of TPUs divided by the number of TPU accelerators per
# node, to account for clusters with both CPUs and TPUs.
num_nodes
=
ray_utils
.
get_num_tpu_nodes
()
num_nodes_in_pg
=
ray_utils
.
get_num_nodes_in_placement_group
()
if
num_nodes_in_pg
>
0
:
num_nodes
=
num_nodes_in_pg
local_world_size
=
global_world_size
//
num_nodes
local_rank
=
global_rank
%
local_world_size
# Ensure environment variables are set for multihost deployments.
# On GKE, this is needed for libtpu and TPU driver to know which TPU
# chip is actually visible. Otherwise the TPU driver will fail to
# initialize because the number of devices would be different from
# the number of visible worker addresses.
os
.
environ
[
"CLOUD_TPU_TASK_ID"
]
=
str
(
global_rank
)
os
.
environ
[
"TPU_VISIBLE_CHIPS"
]
=
str
(
local_rank
)
pjrt
.
initialize_multiprocess
(
local_rank
,
local_world_size
)
xr
.
_init_world_size_ordinal
()
...
...
vllm/executor/ray_tpu_executor.py
View file @
2148441f
...
...
@@ -71,6 +71,19 @@ class RayTPUExecutor(TPUExecutor):
worker_module_name
=
"vllm.worker.tpu_worker"
worker_class_name
=
"TPUWorker"
# GKE does not fetch environment information from metadata server
# and instead sets these from within the Ray process. Therefore we
# need to override the Ray environment variables manually.
override_env
=
{}
if
"TPU_CHIPS_PER_HOST_BOUNDS"
in
os
.
environ
:
override_env
.
update
({
"TPU_CHIPS_PER_HOST_BOUNDS"
:
os
.
environ
[
"TPU_CHIPS_PER_HOST_BOUNDS"
]
})
if
"TPU_HOST_BOUNDS"
in
os
.
environ
:
override_env
.
update
(
{
"TPU_HOST_BOUNDS"
:
os
.
environ
[
"TPU_HOST_BOUNDS"
]})
worker
=
ray
.
remote
(
num_cpus
=
0
,
resources
=
{
"TPU"
:
1
},
...
...
@@ -81,6 +94,8 @@ class RayTPUExecutor(TPUExecutor):
worker_class_name
=
worker_class_name
,
trust_remote_code
=
self
.
model_config
.
trust_remote_code
,
)
if
override_env
:
worker
.
override_env_vars
.
remote
(
override_env
)
worker_ip
=
ray
.
get
(
worker
.
get_node_ip
.
remote
())
if
worker_ip
==
driver_ip
and
self
.
driver_dummy_worker
is
None
:
...
...
vllm/executor/ray_utils.py
View file @
2148441f
import
os
import
time
from
collections
import
defaultdict
from
typing
import
Dict
,
List
,
Optional
,
Tuple
,
Union
...
...
@@ -84,6 +85,9 @@ try:
return
output
def override_env_vars(self, vars: Dict[str, str]):
    """Apply the given environment variables to this process.

    Args:
        vars: Mapping of environment variable names to the values they
            should take in ``os.environ``; existing entries with the
            same name are overwritten.
    """
    # Assign entry by entry; this is what ``os.environ.update`` does.
    for env_name, env_value in dict(vars).items():
        os.environ[env_name] = env_value
ray_import_err
=
None
except
ImportError
as
e
:
...
...
@@ -291,3 +295,28 @@ def initialize_ray_cluster(
_verify_bundles
(
current_placement_group
,
parallel_config
,
device_str
)
# Set the placement group in the parallel config
parallel_config
.
placement_group
=
current_placement_group
def get_num_tpu_nodes() -> int:
    """Return the number of TPU nodes in the Ray cluster.

    The count is the cluster-wide ``"TPU"`` resource total divided by the
    accelerator count that ``TPUAcceleratorManager`` reports for the
    current node, so nodes that contribute no TPU resources are not
    counted.

    Returns:
        Total TPUs in the cluster divided by TPUs per node.
    """
    from ray._private.accelerators import TPUAcceleratorManager

    total_tpus = int(ray.cluster_resources()["TPU"])
    tpus_per_node = TPUAcceleratorManager.get_current_node_num_accelerators()
    # The total is expected to be an exact multiple of the per-node count.
    assert not total_tpus % tpus_per_node
    return total_tpus // tpus_per_node
def get_num_nodes_in_placement_group() -> int:
    """Return how many distinct nodes host the current placement group.

    Returns:
        The number of unique node IDs found in the ``bundles_to_node_id``
        mapping of the current placement group's table entry, or 0 when
        there is no current placement group or the table has no entry
        for it.
    """
    current_pg = ray.util.get_current_placement_group()
    if not current_pg:
        # Not running inside a placement group: skip fetching the
        # placement group table altogether.
        return 0

    # The table is keyed by the placement group ID in hex form, so look
    # the entry up directly instead of scanning every placement group.
    pg_info = ray.util.placement_group_table().get(current_pg.id.hex())
    if pg_info is None:
        return 0

    # Several bundles may land on the same node; count each node once.
    return len(set(pg_info["bundles_to_node_id"].values()))
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment