Unverified Commit 2148441f authored by Richard Liu's avatar Richard Liu Committed by GitHub
Browse files

[TPU] Support single and multi-host TPUs on GKE (#7613)

parent dc13e993
...@@ -4,4 +4,4 @@ ...@@ -4,4 +4,4 @@
# Dependencies for TPU # Dependencies for TPU
# Currently, the TPU backend uses a nightly version of PyTorch XLA. # Currently, the TPU backend uses a nightly version of PyTorch XLA.
# You can install the dependencies in Dockerfile.tpu. # You can install the dependencies in Dockerfile.tpu.
ray ray[default]
...@@ -123,7 +123,10 @@ class PallasAttentionBackendImpl(AttentionImpl): ...@@ -123,7 +123,10 @@ class PallasAttentionBackendImpl(AttentionImpl):
raise NotImplementedError("TPU version must be 4 or higher.") raise NotImplementedError("TPU version must be 4 or higher.")
self.megacore_mode = None self.megacore_mode = None
tpu_type = torch_xla.tpu.get_tpu_env()["TYPE"].lower() tpu_env = torch_xla.tpu.get_tpu_env()
tpu_type = tpu_env.get("TYPE") or tpu_env.get("ACCELERATOR_TYPE")
tpu_type = tpu_type.lower()
if "lite" not in tpu_type: if "lite" not in tpu_type:
if self.num_kv_heads % 2 == 0: if self.num_kv_heads % 2 == 0:
self.megacore_mode = "kv_head" self.megacore_mode = "kv_head"
......
import os
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from torch.distributed import ProcessGroup from torch.distributed import ProcessGroup
...@@ -5,11 +7,12 @@ from torch.distributed import ProcessGroup ...@@ -5,11 +7,12 @@ from torch.distributed import ProcessGroup
from vllm.platforms import current_platform from vllm.platforms import current_platform
if current_platform.is_tpu(): if current_platform.is_tpu():
import ray
import torch_xla.core.xla_model as xm import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr import torch_xla.runtime as xr
from torch_xla._internal import pjrt from torch_xla._internal import pjrt
from vllm.executor import ray_utils
class TpuCommunicator: class TpuCommunicator:
...@@ -24,9 +27,29 @@ class TpuCommunicator: ...@@ -24,9 +27,29 @@ class TpuCommunicator:
# be simply calculated as follows. # be simply calculated as follows.
global_rank = dist.get_rank(group) global_rank = dist.get_rank(group)
global_world_size = dist.get_world_size(group) global_world_size = dist.get_world_size(group)
num_nodes = len(ray.nodes())
# Calculate how many TPU nodes are in the current deployment. This
# is the Ray placement group if it is deployed with Ray. Default
# to the number of TPU nodes in the Ray cluster. The number of TPU
# nodes is computed by the total number of TPUs divided by the
# number of TPU accelerators per node, to account for clusters
# with both CPUs and TPUs.
num_nodes = ray_utils.get_num_tpu_nodes()
num_nodes_in_pg = ray_utils.get_num_nodes_in_placement_group()
if num_nodes_in_pg > 0:
num_nodes = num_nodes_in_pg
local_world_size = global_world_size // num_nodes local_world_size = global_world_size // num_nodes
local_rank = global_rank % local_world_size local_rank = global_rank % local_world_size
# Ensure environment variables are set for multihost deployments.
# On GKE, this is needed for libtpu and TPU driver to know which TPU
# chip is actually visible. Otherwise the TPU driver will fail to
# initialize because the number of devices would be different from
# the number of visible worker addresses.
os.environ["CLOUD_TPU_TASK_ID"] = str(global_rank)
os.environ["TPU_VISIBLE_CHIPS"] = str(local_rank)
pjrt.initialize_multiprocess(local_rank, local_world_size) pjrt.initialize_multiprocess(local_rank, local_world_size)
xr._init_world_size_ordinal() xr._init_world_size_ordinal()
......
...@@ -71,6 +71,19 @@ class RayTPUExecutor(TPUExecutor): ...@@ -71,6 +71,19 @@ class RayTPUExecutor(TPUExecutor):
worker_module_name = "vllm.worker.tpu_worker" worker_module_name = "vllm.worker.tpu_worker"
worker_class_name = "TPUWorker" worker_class_name = "TPUWorker"
# GKE does not fetch environment information from metadata server
# and instead sets these from within the Ray process. Therefore we
# need to override the Ray environment variables manually.
override_env = {}
if "TPU_CHIPS_PER_HOST_BOUNDS" in os.environ:
override_env.update({
"TPU_CHIPS_PER_HOST_BOUNDS":
os.environ["TPU_CHIPS_PER_HOST_BOUNDS"]
})
if "TPU_HOST_BOUNDS" in os.environ:
override_env.update(
{"TPU_HOST_BOUNDS": os.environ["TPU_HOST_BOUNDS"]})
worker = ray.remote( worker = ray.remote(
num_cpus=0, num_cpus=0,
resources={"TPU": 1}, resources={"TPU": 1},
...@@ -81,6 +94,8 @@ class RayTPUExecutor(TPUExecutor): ...@@ -81,6 +94,8 @@ class RayTPUExecutor(TPUExecutor):
worker_class_name=worker_class_name, worker_class_name=worker_class_name,
trust_remote_code=self.model_config.trust_remote_code, trust_remote_code=self.model_config.trust_remote_code,
) )
if override_env:
worker.override_env_vars.remote(override_env)
worker_ip = ray.get(worker.get_node_ip.remote()) worker_ip = ray.get(worker.get_node_ip.remote())
if worker_ip == driver_ip and self.driver_dummy_worker is None: if worker_ip == driver_ip and self.driver_dummy_worker is None:
......
import os
import time import time
from collections import defaultdict from collections import defaultdict
from typing import Dict, List, Optional, Tuple, Union from typing import Dict, List, Optional, Tuple, Union
...@@ -84,6 +85,9 @@ try: ...@@ -84,6 +85,9 @@ try:
return output return output
def override_env_vars(self, vars: Dict[str, str]):
os.environ.update(vars)
ray_import_err = None ray_import_err = None
except ImportError as e: except ImportError as e:
...@@ -291,3 +295,28 @@ def initialize_ray_cluster( ...@@ -291,3 +295,28 @@ def initialize_ray_cluster(
_verify_bundles(current_placement_group, parallel_config, device_str) _verify_bundles(current_placement_group, parallel_config, device_str)
# Set the placement group in the parallel config # Set the placement group in the parallel config
parallel_config.placement_group = current_placement_group parallel_config.placement_group = current_placement_group
def get_num_tpu_nodes() -> int:
from ray._private.accelerators import TPUAcceleratorManager
cluster_resources = ray.cluster_resources()
total_tpus = int(cluster_resources["TPU"])
tpus_per_node = TPUAcceleratorManager.get_current_node_num_accelerators()
assert total_tpus % tpus_per_node == 0
return total_tpus // tpus_per_node
def get_num_nodes_in_placement_group() -> int:
pg_table = ray.util.placement_group_table()
current_pg = ray.util.get_current_placement_group()
num_nodes = 0
if current_pg:
nodes_in_pg = set()
for pg_key, pg in pg_table.items():
if pg_key == current_pg.id.hex():
for _, node in pg["bundles_to_node_id"].items():
nodes_in_pg.add(node)
num_nodes = len(nodes_in_pg)
return num_nodes
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment