Unverified Commit 4141608c authored by Kunshang Ji's avatar Kunshang Ji Committed by GitHub
Browse files

[Hardware][intel GPU] add async output process for xpu (#8897)

parent dfe43a20
......@@ -361,9 +361,9 @@ class ModelConfig:
# Reminder: Please update docs/source/serving/compatibility_matrix.rst
# If the feature combo become valid
if device_config.device_type not in ("cuda", "tpu"):
if device_config.device_type not in ("cuda", "tpu", "xpu"):
logger.warning(
"Async output processing is only supported for CUDA or TPU. "
"Async output processing is only supported for CUDA, TPU, XPU. "
"Disabling it for other platforms.")
self.use_async_output_proc = False
return
......
......@@ -2,8 +2,8 @@ import dataclasses
import time
import weakref
from dataclasses import dataclass
from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type,
TypeVar)
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple,
Type, TypeVar)
import torch
import torch.nn as nn
......@@ -57,6 +57,7 @@ class ModelInputForXPU(ModelRunnerInputBase):
virtual_engine: Optional[int] = None
seq_lens: Optional[List[int]] = None
query_lens: Optional[List[int]] = None
async_callback: Optional[Callable] = None
def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
tensor_dict = {
......@@ -582,6 +583,9 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
if not self.is_driver_worker:
return []
if model_input.async_callback is not None:
model_input.async_callback()
# Sample the next token.
output: SamplerOutput = self.model.sample(
logits=logits,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment