Unverified Commit 4634c872 authored by Woosuk Kwon's avatar Woosuk Kwon Committed by GitHub
Browse files

[TPU] Refactor TPU worker & model runner (#6506)

parent c8a7d51c
This diff is collapsed.
...@@ -13,15 +13,16 @@ from vllm.distributed import (ensure_model_parallel_initialized, ...@@ -13,15 +13,16 @@ from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment) init_distributed_environment)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor import set_random_seed from vllm.model_executor import set_random_seed
from vllm.sequence import ExecuteModelRequest, SamplerOutput from vllm.sequence import ExecuteModelRequest
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, get_dtype_size from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, get_dtype_size
from vllm.worker.tpu_model_runner import TPUModelRunner from vllm.worker.tpu_model_runner import TPUModelRunner
from vllm.worker.worker_base import LoraNotSupportedWorkerBase from vllm.worker.worker_base import (LocalOrDistributedWorkerBase,
LoraNotSupportedWorkerBase, WorkerInput)
logger = init_logger(__name__) logger = init_logger(__name__)
class TPUWorker(LoraNotSupportedWorkerBase): class TPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
def __init__( def __init__(
self, self,
...@@ -57,14 +58,15 @@ class TPUWorker(LoraNotSupportedWorkerBase): ...@@ -57,14 +58,15 @@ class TPUWorker(LoraNotSupportedWorkerBase):
self.cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[ self.cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
self.cache_config.cache_dtype] self.cache_config.cache_dtype]
self.model_runner = TPUModelRunner(model_config, self.model_runner: TPUModelRunner = TPUModelRunner(
parallel_config, model_config,
scheduler_config, parallel_config,
device_config, scheduler_config,
cache_config, device_config,
load_config, cache_config,
multimodal_config, load_config,
is_driver_worker=is_driver_worker) multimodal_config,
is_driver_worker=is_driver_worker)
def init_device(self) -> None: def init_device(self) -> None:
os.environ["PJRT_DEVICE"] = "TPU" os.environ["PJRT_DEVICE"] = "TPU"
...@@ -196,69 +198,70 @@ class TPUWorker(LoraNotSupportedWorkerBase): ...@@ -196,69 +198,70 @@ class TPUWorker(LoraNotSupportedWorkerBase):
dtype_size = get_dtype_size(self.cache_dtype) dtype_size = get_dtype_size(self.cache_dtype)
return dtype_size * total return dtype_size * total
def execute_model( @property
def do_metadata_broadcast(self) -> bool:
# TODO(woosuk): Support TP.
return False
@property
def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
# NOTE(woosuk): This assumes virtual_engine == 0, i.e., no pipeline
# parallelism.
return [self.tpu_cache]
def prepare_worker_input(
self, self,
execute_model_req: Optional[ExecuteModelRequest] = None, execute_model_req: ExecuteModelRequest,
) -> List[SamplerOutput]: ) -> WorkerInput:
if not self.is_driver_worker: virtual_engine = execute_model_req.virtual_engine
self._execute_model_non_driver() num_seq_groups = len(execute_model_req.seq_group_metadata_list)
return [] blocks_to_swap_in = _make_src_to_dst(
assert execute_model_req is not None execute_model_req.blocks_to_swap_in, "cpu", self.device)
# Issue cache operations. blocks_to_swap_out = _make_src_to_dst(
self.cache_swap( execute_model_req.blocks_to_swap_out, self.device, "cpu")
execute_model_req.blocks_to_swap_in, blocks_to_copy = _make_src_to_dst(execute_model_req.blocks_to_copy,
execute_model_req.blocks_to_swap_out, self.device, self.device)
execute_model_req.blocks_to_copy, return WorkerInput(
num_seq_groups=num_seq_groups,
blocks_to_swap_in=blocks_to_swap_in,
blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy,
virtual_engine=virtual_engine,
) )
# Run the model.
seq_group_metadata_list = execute_model_req.seq_group_metadata_list def execute_worker(self, worker_input: WorkerInput) -> None:
assert len(seq_group_metadata_list) > 0 virtual_engine = worker_input.virtual_engine
output = self.model_runner.execute_model(seq_group_metadata_list, assert virtual_engine == 0
self.tpu_cache)
return output
def cache_swap(
self,
blocks_to_swap_in: List[Tuple[int, int]],
blocks_to_swap_out: List[Tuple[int, int]],
blocks_to_copy: List[Tuple[int, int]],
) -> None:
attn_backend = self.model_runner.attn_backend attn_backend = self.model_runner.attn_backend
num_layers = self.model_config.get_num_layers(self.parallel_config) num_layers = self.model_config.get_num_layers(self.parallel_config)
if blocks_to_swap_in: # Issue cache operations.
# Swap from CPU to TPU. if worker_input.blocks_to_swap_in is not None:
src_indices, dst_indices = _make_src_to_dst( src_indices, dst_indices = worker_input.blocks_to_swap_in
blocks_to_swap_in, "cpu", self.device) if src_indices.numel() > 0:
for i in range(num_layers): # Swap from CPU to TPU.
tpu_k_cache, tpu_v_cache = self.tpu_cache[i] for i in range(num_layers):
cpu_k_cache, cpu_v_cache = self.cpu_cache[i] tpu_k_cache, tpu_v_cache = self.tpu_cache[i]
k = cpu_k_cache[:, src_indices].to(self.device) cpu_k_cache, cpu_v_cache = self.cpu_cache[i]
v = cpu_v_cache[:, src_indices].to(self.device) k = cpu_k_cache[:, src_indices].to(self.device)
_insert_kv(k, v, dst_indices, tpu_k_cache, tpu_v_cache) v = cpu_v_cache[:, src_indices].to(self.device)
_insert_kv(k, v, dst_indices, tpu_k_cache, tpu_v_cache)
if blocks_to_swap_out:
# Swap from TPU to CPU. if worker_input.blocks_to_swap_out is not None:
src_indices, dst_indices = _make_src_to_dst( src_indices, dst_indices = worker_input.blocks_to_swap_out
blocks_to_swap_out, self.device, "cpu") if src_indices.numel() > 0:
for i in range(num_layers): # Swap from TPU to CPU.
tpu_k_cache, tpu_v_cache = self.tpu_cache[i] for i in range(num_layers):
cpu_k_cache, cpu_v_cache = self.cpu_cache[i] tpu_k_cache, tpu_v_cache = self.tpu_cache[i]
cpu_k_cache[:, dst_indices] = tpu_k_cache[:, src_indices].cpu() cpu_k_cache, cpu_v_cache = self.cpu_cache[i]
cpu_v_cache[:, dst_indices] = tpu_v_cache[:, src_indices].cpu() cpu_k_cache[:, dst_indices] = tpu_k_cache[:, src_indices]
cpu_v_cache[:, dst_indices] = tpu_v_cache[:, src_indices]
if blocks_to_copy:
src_to_dst = _make_src_to_dst(blocks_to_copy, self.device, if worker_input.blocks_to_copy is not None:
self.device) src_indices, dst_indices = worker_input.blocks_to_copy
attn_backend.copy_blocks(self.tpu_cache, src_to_dst) if src_indices.numel() > 0:
attn_backend.copy_blocks(self.tpu_cache,
def start_worker_execution_loop(self) -> None: (src_indices, dst_indices))
while self._execute_model_non_driver():
pass
def _execute_model_non_driver(self) -> bool:
self.model_runner.execute_model(None, self.tpu_cache)
return True
def _make_src_to_dst( def _make_src_to_dst(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment