Unverified commit a08df832, authored by Woosuk Kwon and committed by GitHub
Browse files

[TPU] Support multi-host inference (#7457)

parent 16422ea7
...@@ -8,7 +8,7 @@ vLLM supports Google Cloud TPUs using PyTorch XLA. ...@@ -8,7 +8,7 @@ vLLM supports Google Cloud TPUs using PyTorch XLA.
Requirements Requirements
------------ ------------
* Google Cloud TPU VM (single host) * Google Cloud TPU VM (single & multi host)
* TPU versions: v5e, v5p, v4 * TPU versions: v5e, v5p, v4
* Python: 3.10 * Python: 3.10
......
import ray
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from torch.distributed import ProcessGroup from torch.distributed import ProcessGroup
...@@ -18,9 +19,15 @@ class TpuCommunicator: ...@@ -18,9 +19,15 @@ class TpuCommunicator:
return return
self.disabled = False self.disabled = False
local_rank = dist.get_rank(group) # NOTE(woosuk): When using TP > 1 on TPUs, every TPU on the same node
world_size = dist.get_world_size(group) # must be used together. Therefore, the local rank and world size can
pjrt.initialize_multiprocess(local_rank, world_size) # be simply calculated as follows.
global_rank = dist.get_rank(group)
global_world_size = dist.get_world_size(group)
num_nodes = len(ray.nodes())
local_world_size = global_world_size // num_nodes
local_rank = global_rank % local_world_size
pjrt.initialize_multiprocess(local_rank, local_world_size)
xr._init_world_size_ordinal() xr._init_world_size_ordinal()
def all_reduce(self, x: torch.Tensor) -> torch.Tensor: def all_reduce(self, x: torch.Tensor) -> torch.Tensor:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment