Unverified Commit 761e63e5 authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Frontend] Always pass `supported_tasks` to validation (#35186)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent d12d2014
...@@ -194,7 +194,7 @@ class OpenAISpeechToText(OpenAIServing): ...@@ -194,7 +194,7 @@ class OpenAISpeechToText(OpenAIServing):
def _warmup_input_processor(self) -> None: def _warmup_input_processor(self) -> None:
"""Warm up input processor with dummy audio to avoid first-request latency. """Warm up input processor with dummy audio to avoid first-request latency.
The first call to input_processor.process_inputs() with multimodal audio The first call to renderer.render_cmpl() with multimodal audio
triggers multimodal processing initialization which can take ~2.5s. triggers multimodal processing initialization which can take ~2.5s.
This method processes a dummy audio request to warm up the pipeline. This method processes a dummy audio request to warm up the pipeline.
""" """
......
...@@ -356,13 +356,13 @@ class AsyncLLM(EngineClient): ...@@ -356,13 +356,13 @@ class AsyncLLM(EngineClient):
request_id, request_id,
prompt, prompt,
params, params,
supported_tasks=await self.get_supported_tasks(),
arrival_time=arrival_time, arrival_time=arrival_time,
lora_request=lora_request, lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
trace_headers=trace_headers, trace_headers=trace_headers,
priority=priority, priority=priority,
data_parallel_rank=data_parallel_rank, data_parallel_rank=data_parallel_rank,
supported_tasks=await self.get_supported_tasks(),
) )
prompt_text, _, _ = extract_prompt_components(self.model_config, prompt) prompt_text, _, _ = extract_prompt_components(self.model_config, prompt)
...@@ -433,6 +433,7 @@ class AsyncLLM(EngineClient): ...@@ -433,6 +433,7 @@ class AsyncLLM(EngineClient):
self._validate_streaming_input_sampling_params(sampling_params) self._validate_streaming_input_sampling_params(sampling_params)
inputs = dict( inputs = dict(
supported_tasks=await self.get_supported_tasks(),
arrival_time=arrival_time, arrival_time=arrival_time,
lora_request=lora_request, lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
......
...@@ -26,7 +26,7 @@ from vllm.multimodal.utils import argsort_mm_positions ...@@ -26,7 +26,7 @@ from vllm.multimodal.utils import argsort_mm_positions
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.renderers import BaseRenderer, renderer_from_config from vllm.renderers import BaseRenderer, renderer_from_config
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.tasks import POOLING_TASKS, SupportedTask from vllm.tasks import GENERATION_TASKS, POOLING_TASKS, SupportedTask
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
from vllm.utils import length_from_prompt_token_ids_or_embeds, random_uuid from vllm.utils import length_from_prompt_token_ids_or_embeds, random_uuid
from vllm.utils.func_utils import supports_kw from vllm.utils.func_utils import supports_kw
...@@ -111,10 +111,8 @@ class InputProcessor: ...@@ -111,10 +111,8 @@ class InputProcessor:
def _validate_params( def _validate_params(
self, self,
params: SamplingParams | PoolingParams, params: SamplingParams | PoolingParams,
# TODO: Validate generation tasks as well once `supported_tasks` supported_tasks: tuple[SupportedTask, ...],
# is passed to all `process_inputs` calls ) -> None:
supported_tasks: tuple[SupportedTask, ...] | None,
):
"""Raise `ValueError` if SamplingParams or PoolingParams is not valid.""" """Raise `ValueError` if SamplingParams or PoolingParams is not valid."""
if params.truncate_prompt_tokens is not None: if params.truncate_prompt_tokens is not None:
params_type = type(params).__name__ params_type = type(params).__name__
...@@ -127,6 +125,12 @@ class InputProcessor: ...@@ -127,6 +125,12 @@ class InputProcessor:
) )
if isinstance(params, SamplingParams): if isinstance(params, SamplingParams):
supported_generation_tasks = [
task for task in supported_tasks if task in GENERATION_TASKS
]
if not supported_generation_tasks:
raise ValueError("This model does not support generation")
params.verify( params.verify(
self.model_config, self.model_config,
self.speculative_config, self.speculative_config,
...@@ -134,17 +138,13 @@ class InputProcessor: ...@@ -134,17 +138,13 @@ class InputProcessor:
self.tokenizer, self.tokenizer,
) )
elif isinstance(params, PoolingParams): elif isinstance(params, PoolingParams):
if supported_tasks is None:
raise RuntimeError("`supported_tasks` must be passed for pooling")
supported_pooling_tasks = [ supported_pooling_tasks = [
task for task in supported_tasks if task in POOLING_TASKS task for task in supported_tasks if task in POOLING_TASKS
] ]
if params.task is None:
if not supported_pooling_tasks: if not supported_pooling_tasks:
raise ValueError("Pooling tasks are not supported") raise ValueError("This model does not support pooling")
if params.task is None:
if "token_embed" in supported_pooling_tasks: if "token_embed" in supported_pooling_tasks:
params.task = "token_embed" params.task = "token_embed"
elif "token_classify" in supported_pooling_tasks: elif "token_classify" in supported_pooling_tasks:
...@@ -227,17 +227,17 @@ class InputProcessor: ...@@ -227,17 +227,17 @@ class InputProcessor:
request_id: str, request_id: str,
prompt: PromptType | ProcessorInputs, prompt: PromptType | ProcessorInputs,
params: SamplingParams | PoolingParams, params: SamplingParams | PoolingParams,
supported_tasks: tuple[SupportedTask, ...],
arrival_time: float | None = None, arrival_time: float | None = None,
lora_request: LoRARequest | None = None, lora_request: LoRARequest | None = None,
tokenization_kwargs: dict[str, Any] | None = None, tokenization_kwargs: dict[str, Any] | None = None,
trace_headers: Mapping[str, str] | None = None, trace_headers: Mapping[str, str] | None = None,
priority: int = 0, priority: int = 0,
data_parallel_rank: int | None = None, data_parallel_rank: int | None = None,
supported_tasks: tuple[SupportedTask, ...] | None = None,
resumable: bool = False, resumable: bool = False,
) -> EngineCoreRequest: ) -> EngineCoreRequest:
self._validate_lora(lora_request)
self._validate_params(params, supported_tasks) self._validate_params(params, supported_tasks)
self._validate_lora(lora_request)
parallel_config = self.vllm_config.parallel_config parallel_config = self.vllm_config.parallel_config
dp_size = parallel_config.data_parallel_size dp_size = parallel_config.data_parallel_size
......
...@@ -248,12 +248,12 @@ class LLMEngine: ...@@ -248,12 +248,12 @@ class LLMEngine:
request_id, request_id,
prompt, prompt,
params, params,
arrival_time,
lora_request,
tokenization_kwargs,
trace_headers,
priority,
supported_tasks=self.get_supported_tasks(), supported_tasks=self.get_supported_tasks(),
arrival_time=arrival_time,
lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
trace_headers=trace_headers,
priority=priority,
) )
prompt_text, _, _ = extract_prompt_components(self.model_config, prompt) prompt_text, _, _ = extract_prompt_components(self.model_config, prompt)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment