Unverified Commit 761e63e5 authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Frontend] Always pass `supported_tasks` to validation (#35186)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent d12d2014
......@@ -194,7 +194,7 @@ class OpenAISpeechToText(OpenAIServing):
def _warmup_input_processor(self) -> None:
"""Warm up input processor with dummy audio to avoid first-request latency.
The first call to input_processor.process_inputs() with multimodal audio
The first call to renderer.render_cmpl() with multimodal audio
triggers multimodal processing initialization which can take ~2.5s.
This method processes a dummy audio request to warm up the pipeline.
"""
......
......@@ -356,13 +356,13 @@ class AsyncLLM(EngineClient):
request_id,
prompt,
params,
supported_tasks=await self.get_supported_tasks(),
arrival_time=arrival_time,
lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
trace_headers=trace_headers,
priority=priority,
data_parallel_rank=data_parallel_rank,
supported_tasks=await self.get_supported_tasks(),
)
prompt_text, _, _ = extract_prompt_components(self.model_config, prompt)
......@@ -433,6 +433,7 @@ class AsyncLLM(EngineClient):
self._validate_streaming_input_sampling_params(sampling_params)
inputs = dict(
supported_tasks=await self.get_supported_tasks(),
arrival_time=arrival_time,
lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
......
......@@ -26,7 +26,7 @@ from vllm.multimodal.utils import argsort_mm_positions
from vllm.pooling_params import PoolingParams
from vllm.renderers import BaseRenderer, renderer_from_config
from vllm.sampling_params import SamplingParams
from vllm.tasks import POOLING_TASKS, SupportedTask
from vllm.tasks import GENERATION_TASKS, POOLING_TASKS, SupportedTask
from vllm.tokenizers import TokenizerLike
from vllm.utils import length_from_prompt_token_ids_or_embeds, random_uuid
from vllm.utils.func_utils import supports_kw
......@@ -111,10 +111,8 @@ class InputProcessor:
def _validate_params(
self,
params: SamplingParams | PoolingParams,
# TODO: Validate generation tasks as well once `supported_tasks`
# is passed to all `process_inputs` calls
supported_tasks: tuple[SupportedTask, ...] | None,
):
supported_tasks: tuple[SupportedTask, ...],
) -> None:
"""Raise `ValueError` if SamplingParams or PoolingParams is not valid."""
if params.truncate_prompt_tokens is not None:
params_type = type(params).__name__
......@@ -127,6 +125,12 @@ class InputProcessor:
)
if isinstance(params, SamplingParams):
supported_generation_tasks = [
task for task in supported_tasks if task in GENERATION_TASKS
]
if not supported_generation_tasks:
raise ValueError("This model does not support generation")
params.verify(
self.model_config,
self.speculative_config,
......@@ -134,17 +138,13 @@ class InputProcessor:
self.tokenizer,
)
elif isinstance(params, PoolingParams):
if supported_tasks is None:
raise RuntimeError("`supported_tasks` must be passed for pooling")
supported_pooling_tasks = [
task for task in supported_tasks if task in POOLING_TASKS
]
if params.task is None:
if not supported_pooling_tasks:
raise ValueError("Pooling tasks are not supported")
raise ValueError("This model does not support pooling")
if params.task is None:
if "token_embed" in supported_pooling_tasks:
params.task = "token_embed"
elif "token_classify" in supported_pooling_tasks:
......@@ -227,17 +227,17 @@ class InputProcessor:
request_id: str,
prompt: PromptType | ProcessorInputs,
params: SamplingParams | PoolingParams,
supported_tasks: tuple[SupportedTask, ...],
arrival_time: float | None = None,
lora_request: LoRARequest | None = None,
tokenization_kwargs: dict[str, Any] | None = None,
trace_headers: Mapping[str, str] | None = None,
priority: int = 0,
data_parallel_rank: int | None = None,
supported_tasks: tuple[SupportedTask, ...] | None = None,
resumable: bool = False,
) -> EngineCoreRequest:
self._validate_lora(lora_request)
self._validate_params(params, supported_tasks)
self._validate_lora(lora_request)
parallel_config = self.vllm_config.parallel_config
dp_size = parallel_config.data_parallel_size
......
......@@ -248,12 +248,12 @@ class LLMEngine:
request_id,
prompt,
params,
arrival_time,
lora_request,
tokenization_kwargs,
trace_headers,
priority,
supported_tasks=self.get_supported_tasks(),
arrival_time=arrival_time,
lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
trace_headers=trace_headers,
priority=priority,
)
prompt_text, _, _ = extract_prompt_components(self.model_config, prompt)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment