Unverified Commit bc34937d authored by Woosuk Kwon's avatar Woosuk Kwon Committed by GitHub
Browse files

[Hardware][TPU] Refactor TPU backend (#5831)

parent dd248f76
from typing import List, Set, Tuple from typing import Any, Dict, List, Optional, Set, Tuple
import torch import torch
...@@ -26,29 +26,45 @@ class TPUExecutor(ExecutorBase): ...@@ -26,29 +26,45 @@ class TPUExecutor(ExecutorBase):
self.model_config.dtype = torch.bfloat16 self.model_config.dtype = torch.bfloat16
# Instantiate the worker and load the model to the device. # Instantiate the worker and load the model to the device.
self._init_worker() self.driver_worker = self._create_worker()
self.driver_worker.init_device()
def _init_worker(self): self.driver_worker.load_model()
from vllm.worker.tpu_worker import TPUWorker
assert self.parallel_config.world_size == 1, ( def _get_worker_kwargs(
"TPUExecutor currently only supports a single TPU chip.") self,
local_rank: int = 0,
rank: int = 0,
distributed_init_method: Optional[str] = None,
) -> Dict[str, Any]:
"""Return worker init args for a given rank."""
if distributed_init_method is None:
distributed_init_method = get_distributed_init_method( distributed_init_method = get_distributed_init_method(
get_ip(), get_open_port()) get_ip(), get_open_port())
self.driver_worker = TPUWorker( return dict(
self.model_config, model_config=self.model_config,
self.parallel_config, parallel_config=self.parallel_config,
self.scheduler_config, scheduler_config=self.scheduler_config,
self.device_config, device_config=self.device_config,
self.cache_config, cache_config=self.cache_config,
self.load_config, load_config=self.load_config,
self.vision_language_config, local_rank=local_rank,
local_rank=0, rank=rank,
rank=0,
distributed_init_method=distributed_init_method, distributed_init_method=distributed_init_method,
vision_language_config=self.vision_language_config,
is_driver_worker=rank == 0,
) )
self.driver_worker.init_device()
self.driver_worker.load_model() def _create_worker(
self,
local_rank: int = 0,
rank: int = 0,
distributed_init_method: Optional[str] = None,
):
from vllm.worker.tpu_worker import TPUWorker
worker = TPUWorker(**self._get_worker_kwargs(local_rank, rank,
distributed_init_method))
return worker
def initialize_cache( def initialize_cache(
self, self,
......
...@@ -33,6 +33,7 @@ class TPUModelRunner: ...@@ -33,6 +33,7 @@ class TPUModelRunner:
cache_config: CacheConfig, cache_config: CacheConfig,
load_config: LoadConfig, load_config: LoadConfig,
vision_language_config: Optional[VisionLanguageConfig] = None, vision_language_config: Optional[VisionLanguageConfig] = None,
is_driver_worker: bool = False,
): ):
self.model_config = model_config self.model_config = model_config
self.parallel_config = parallel_config self.parallel_config = parallel_config
...@@ -41,6 +42,7 @@ class TPUModelRunner: ...@@ -41,6 +42,7 @@ class TPUModelRunner:
self.cache_config = cache_config self.cache_config = cache_config
self.load_config = load_config self.load_config = load_config
self.vision_language_config = vision_language_config self.vision_language_config = vision_language_config
self.is_driver_worker = is_driver_worker
self.block_size = self.cache_config.block_size self.block_size = self.cache_config.block_size
self.max_num_blocks_per_seq = (self.model_config.max_model_len // self.max_num_blocks_per_seq = (self.model_config.max_model_len //
...@@ -373,6 +375,8 @@ class TPUModelRunner: ...@@ -373,6 +375,8 @@ class TPUModelRunner:
inputs = self.prepare_inputs(seq_group_metadata_list) inputs = self.prepare_inputs(seq_group_metadata_list)
next_token_ids = self.model(inputs[0], inputs[1], kv_caches, next_token_ids = self.model(inputs[0], inputs[1], kv_caches,
*inputs[2:]) *inputs[2:])
if not self.is_driver_worker:
return []
next_token_ids = next_token_ids.cpu().tolist() next_token_ids = next_token_ids.cpu().tolist()
i = 0 i = 0
......
...@@ -34,6 +34,7 @@ class TPUWorker(LoraNotSupportedWorkerBase): ...@@ -34,6 +34,7 @@ class TPUWorker(LoraNotSupportedWorkerBase):
local_rank: int, local_rank: int,
rank: int, rank: int,
distributed_init_method: str, distributed_init_method: str,
is_driver_worker: bool,
) -> None: ) -> None:
self.model_config = model_config self.model_config = model_config
self.parallel_config = parallel_config self.parallel_config = parallel_config
...@@ -45,6 +46,7 @@ class TPUWorker(LoraNotSupportedWorkerBase): ...@@ -45,6 +46,7 @@ class TPUWorker(LoraNotSupportedWorkerBase):
self.local_rank = local_rank self.local_rank = local_rank
self.rank = rank self.rank = rank
self.distributed_init_method = distributed_init_method self.distributed_init_method = distributed_init_method
self.is_driver_worker = is_driver_worker
assert self.device_config.device_type == "tpu" assert self.device_config.device_type == "tpu"
if self.cache_config.cache_dtype == "auto": if self.cache_config.cache_dtype == "auto":
...@@ -53,10 +55,14 @@ class TPUWorker(LoraNotSupportedWorkerBase): ...@@ -53,10 +55,14 @@ class TPUWorker(LoraNotSupportedWorkerBase):
self.cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[ self.cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
self.cache_config.cache_dtype] self.cache_config.cache_dtype]
self.model_runner = TPUModelRunner(model_config, parallel_config, self.model_runner = TPUModelRunner(model_config,
scheduler_config, device_config, parallel_config,
cache_config, load_config, scheduler_config,
vision_language_config) device_config,
cache_config,
load_config,
vision_language_config,
is_driver_worker=is_driver_worker)
def init_device(self) -> None: def init_device(self) -> None:
os.environ["PJRT_DEVICE"] = "TPU" os.environ["PJRT_DEVICE"] = "TPU"
...@@ -175,16 +181,13 @@ class TPUWorker(LoraNotSupportedWorkerBase): ...@@ -175,16 +181,13 @@ class TPUWorker(LoraNotSupportedWorkerBase):
def execute_model( def execute_model(
self, self,
execute_model_req: Optional[ExecuteModelRequest] = None execute_model_req: Optional[ExecuteModelRequest] = None,
) -> List[SamplerOutput]: ) -> List[SamplerOutput]:
if execute_model_req is None: if not self.is_driver_worker:
return [] self._execute_model_non_driver()
seq_group_metadata_list = execute_model_req.seq_group_metadata_list
num_seq_groups = len(seq_group_metadata_list)
if num_seq_groups == 0:
return [] return []
assert execute_model_req is not None
# Currently, TPUWorker does not support swapping. # Currently, TPUWorker does not support swapping.
# TODO(woosuk): Support block copying. # TODO(woosuk): Support block copying.
assert len(execute_model_req.blocks_to_swap_in) == 0, ( assert len(execute_model_req.blocks_to_swap_in) == 0, (
...@@ -193,6 +196,16 @@ class TPUWorker(LoraNotSupportedWorkerBase): ...@@ -193,6 +196,16 @@ class TPUWorker(LoraNotSupportedWorkerBase):
"Swapping is not supported for the TPU backend.") "Swapping is not supported for the TPU backend.")
assert len(execute_model_req.blocks_to_copy) == 0 assert len(execute_model_req.blocks_to_copy) == 0
seq_group_metadata_list = execute_model_req.seq_group_metadata_list
assert len(seq_group_metadata_list) > 0
output = self.model_runner.execute_model(seq_group_metadata_list, output = self.model_runner.execute_model(seq_group_metadata_list,
self.tpu_cache) self.tpu_cache)
return [output] return [output]
def start_worker_execution_loop(self) -> None:
while self._execute_model_non_driver():
pass
def _execute_model_non_driver(self) -> bool:
self.model_runner.execute_model(None, self.tpu_cache)
return True
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment