Unverified Commit cfbca8a2 authored by Alexander Matveev's avatar Alexander Matveev Committed by GitHub
Browse files

[V1] TPU - Tensor parallel MP support (#15059)

parent 0fe56098
......@@ -1473,7 +1473,7 @@ class ParallelConfig:
os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"
logger.info("Disabling V1 multiprocessing for external launcher.")
ray_only_devices = ["tpu"]
ray_only_devices: list[str] = []
from vllm.platforms import current_platform
if (current_platform.device_type in ray_only_devices
and self.world_size > 1):
......
......@@ -6,15 +6,24 @@ from typing import Optional
import torch
from torch.distributed import ProcessGroup
from vllm.config import get_current_vllm_config
from vllm.logger import init_logger
from vllm.platforms import current_platform
from .base_device_communicator import DeviceCommunicatorBase
USE_RAY = parallel_config = get_current_vllm_config(
).parallel_config.distributed_executor_backend == "ray"
logger = init_logger(__name__)
if current_platform.is_tpu():
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
from torch_xla._internal import pjrt
if USE_RAY:
from vllm.executor import ray_utils
......@@ -33,6 +42,8 @@ class TpuCommunicator(DeviceCommunicatorBase):
global_rank = self.global_rank
global_world_size = self.global_world_size
if USE_RAY:
logger.info("TpuCommunicator initialized with RAY")
# Calculate how many TPU nodes are in the current deployment. This
# is the Ray placement group if it is deployed with Ray. Default
# to the number of TPU nodes in the Ray cluster. The number of TPU
......@@ -46,6 +57,17 @@ class TpuCommunicator(DeviceCommunicatorBase):
local_world_size = global_world_size // num_nodes
local_rank = global_rank % local_world_size
else:
logger.info("TpuCommunicator initialized with MP")
# Sanity: Verify we run on a single host
num_hosts = torch_xla.tpu.num_tpu_workers()
assert num_hosts == 1
# Get the current number of TPUs (we have locally)
local_world_size = torch_xla.tpu.num_available_chips()
# Get current rank
local_rank = global_rank % local_world_size
# Ensure environment variables are set for multihost deployments.
# On GKE, this is needed for libtpu and TPU driver to know which TPU
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment