Unverified Commit e1b00483 authored by Joe Runde's avatar Joe Runde Committed by GitHub
Browse files

[Hardware] Add processor inputs to platform validation (#16680)


Signed-off-by: default avatarJoe Runde <Joseph.Runde@ibm.com>
parent ee378f3d
...@@ -8,7 +8,7 @@ from typing import TYPE_CHECKING, NamedTuple, Optional, Tuple, Union ...@@ -8,7 +8,7 @@ from typing import TYPE_CHECKING, NamedTuple, Optional, Tuple, Union
import numpy as np import numpy as np
import torch import torch
from vllm.inputs import PromptType from vllm.inputs import ProcessorInputs, PromptType
from vllm.logger import init_logger from vllm.logger import init_logger
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -400,6 +400,7 @@ class Platform: ...@@ -400,6 +400,7 @@ class Platform:
cls, cls,
prompt: PromptType, prompt: PromptType,
params: Union[SamplingParams, PoolingParams], params: Union[SamplingParams, PoolingParams],
processed_inputs: ProcessorInputs,
) -> None: ) -> None:
"""Raises if this request is unsupported on this platform""" """Raises if this request is unsupported on this platform"""
......
...@@ -5,7 +5,7 @@ from typing import TYPE_CHECKING, Optional, Union ...@@ -5,7 +5,7 @@ from typing import TYPE_CHECKING, Optional, Union
import torch import torch
import vllm.envs as envs import vllm.envs as envs
from vllm.inputs import PromptType from vllm.inputs import ProcessorInputs, PromptType
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.sampling_params import SamplingParams, SamplingType from vllm.sampling_params import SamplingParams, SamplingType
...@@ -150,6 +150,7 @@ class TpuPlatform(Platform): ...@@ -150,6 +150,7 @@ class TpuPlatform(Platform):
cls, cls,
prompt: PromptType, prompt: PromptType,
params: Union[SamplingParams, PoolingParams], params: Union[SamplingParams, PoolingParams],
processed_inputs: ProcessorInputs,
) -> None: ) -> None:
"""Raises if this request is unsupported on this platform""" """Raises if this request is unsupported on this platform"""
if isinstance(params, SamplingParams): if isinstance(params, SamplingParams):
......
...@@ -202,12 +202,6 @@ class Processor: ...@@ -202,12 +202,6 @@ class Processor:
# TODO(woosuk): Support pooling models. # TODO(woosuk): Support pooling models.
# TODO(woosuk): Support encoder-decoder models. # TODO(woosuk): Support encoder-decoder models.
from vllm.platforms import current_platform
current_platform.validate_request(
prompt=prompt,
params=params,
)
self._validate_lora(lora_request) self._validate_lora(lora_request)
self._validate_params(params) self._validate_params(params)
if priority != 0: if priority != 0:
...@@ -231,6 +225,12 @@ class Processor: ...@@ -231,6 +225,12 @@ class Processor:
prompt_adapter_request=prompt_adapter_request, prompt_adapter_request=prompt_adapter_request,
return_mm_hashes=self.use_hash, return_mm_hashes=self.use_hash,
) )
from vllm.platforms import current_platform
current_platform.validate_request(
prompt=prompt,
params=params,
processed_inputs=processed_inputs,
)
eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request) eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request)
self._validate_model_inputs(processed_inputs, lora_request) self._validate_model_inputs(processed_inputs, lora_request)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment