Unverified Commit 80c7b089 authored by Woosuk Kwon's avatar Woosuk Kwon Committed by GitHub
Browse files

[TPU] Async output processing for TPU (#8011)

parent 428dd144
...@@ -347,10 +347,10 @@ class ModelConfig: ...@@ -347,10 +347,10 @@ class ModelConfig:
self.use_async_output_proc = False self.use_async_output_proc = False
return return
if device_config.device_type != "cuda": if device_config.device_type not in ("cuda", "tpu"):
logger.warning( logger.warning(
"Async output processing is only supported for CUDA." "Async output processing is only supported for CUDA or TPU. "
" Disabling it for other platforms.") "Disabling it for other platforms.")
self.use_async_output_proc = False self.use_async_output_proc = False
return return
......
import time import time
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple,
Type, Union)
from unittest.mock import patch from unittest.mock import patch
import numpy as np import numpy as np
...@@ -51,6 +52,7 @@ class ModelInputForTPU(ModelRunnerInputBase): ...@@ -51,6 +52,7 @@ class ModelInputForTPU(ModelRunnerInputBase):
best_of: List[int] best_of: List[int]
seq_groups: List[List[int]] seq_groups: List[List[int]]
virtual_engine: int = 0 virtual_engine: int = 0
async_callback: Optional[Callable] = None
def as_broadcastable_tensor_dict( def as_broadcastable_tensor_dict(
self) -> Dict[str, Union[int, torch.Tensor]]: self) -> Dict[str, Union[int, torch.Tensor]]:
...@@ -562,6 +564,8 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]): ...@@ -562,6 +564,8 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
model_input.attn_metadata, model_input.input_lens[i:i + 1], model_input.attn_metadata, model_input.input_lens[i:i + 1],
model_input.t[i:i + 1], model_input.p[i:i + 1], model_input.t[i:i + 1], model_input.p[i:i + 1],
model_input.num_samples, kv_caches) model_input.num_samples, kv_caches)
if i == 0 and model_input.async_callback is not None:
model_input.async_callback()
# Retrieve the outputs to CPU. # Retrieve the outputs to CPU.
next_token_ids += output_token_ids.cpu().tolist() next_token_ids += output_token_ids.cpu().tolist()
start_idx = end_idx start_idx = end_idx
...@@ -572,6 +576,8 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]): ...@@ -572,6 +576,8 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
model_input.attn_metadata, model_input.input_lens, model_input.attn_metadata, model_input.input_lens,
model_input.t, model_input.p, model_input.num_samples, model_input.t, model_input.p, model_input.num_samples,
kv_caches) kv_caches)
if model_input.async_callback is not None:
model_input.async_callback()
# Retrieve the outputs to CPU. # Retrieve the outputs to CPU.
next_token_ids = output_token_ids.cpu().tolist() next_token_ids = output_token_ids.cpu().tolist()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment