Unverified commit 038914b7, authored by Cyrus Leung, committed by GitHub
Browse files

[Refactor] Move `task` outside of `PoolingParams.verify` (#33796)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: wang.yuqi <yuqi.wang@daocloud.io>
Co-authored-by: wang.yuqi <yuqi.wang@daocloud.io>
parent d2f4a71c
...@@ -269,7 +269,11 @@ class AsyncLLM(EngineClient): ...@@ -269,7 +269,11 @@ class AsyncLLM(EngineClient):
cancel_task_threadsafe(handler) cancel_task_threadsafe(handler)
async def get_supported_tasks(self) -> tuple[SupportedTask, ...]: async def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
return await self.engine_core.get_supported_tasks_async() if not hasattr(self, "_supported_tasks"):
# Cache the result
self._supported_tasks = await self.engine_core.get_supported_tasks_async()
return self._supported_tasks
async def add_request( async def add_request(
self, self,
...@@ -355,6 +359,7 @@ class AsyncLLM(EngineClient): ...@@ -355,6 +359,7 @@ class AsyncLLM(EngineClient):
trace_headers=trace_headers, trace_headers=trace_headers,
priority=priority, priority=priority,
data_parallel_rank=data_parallel_rank, data_parallel_rank=data_parallel_rank,
supported_tasks=await self.get_supported_tasks(),
) )
prompt_text = get_prompt_text(prompt) prompt_text = get_prompt_text(prompt)
......
...@@ -31,6 +31,7 @@ from vllm.multimodal.utils import argsort_mm_positions ...@@ -31,6 +31,7 @@ from vllm.multimodal.utils import argsort_mm_positions
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.renderers import BaseRenderer from vllm.renderers import BaseRenderer
from vllm.sampling_params import _SAMPLING_EPS, SamplingParams from vllm.sampling_params import _SAMPLING_EPS, SamplingParams
from vllm.tasks import POOLING_TASKS, SupportedTask
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
from vllm.tokenizers.mistral import MistralTokenizer from vllm.tokenizers.mistral import MistralTokenizer
from vllm.utils import length_from_prompt_token_ids_or_embeds, random_uuid from vllm.utils import length_from_prompt_token_ids_or_embeds, random_uuid
...@@ -196,13 +197,41 @@ class InputProcessor: ...@@ -196,13 +197,41 @@ class InputProcessor:
def _validate_params( def _validate_params(
self, self,
params: SamplingParams | PoolingParams, params: SamplingParams | PoolingParams,
# TODO: Validate generation tasks as well once `supported_tasks`
# is passed to all `process_inputs` calls
supported_tasks: tuple[SupportedTask, ...] | None,
): ):
""" """
Validate supported SamplingParam. Validate supported SamplingParam.
Should raise ValueError if unsupported for API Server. Should raise ValueError if unsupported for API Server.
""" """
if isinstance(params, PoolingParams): if isinstance(params, PoolingParams):
if supported_tasks is None:
raise RuntimeError("`supported_tasks` must be passed for pooling")
supported_pooling_tasks = [
task for task in supported_tasks if task in POOLING_TASKS
]
if params.task is None:
if not supported_pooling_tasks:
raise ValueError("Pooling tasks are not supported")
if "token_embed" in supported_pooling_tasks:
params.task = "token_embed"
elif "token_classify" in supported_pooling_tasks:
params.task = "token_classify"
elif "plugin" in supported_pooling_tasks:
params.task = "plugin"
if params.task not in supported_pooling_tasks:
raise ValueError(
f"Unsupported task: {params.task!r} "
f"Supported tasks: {supported_pooling_tasks}"
)
params.verify(self.model_config)
return return
self._validate_logprobs(params) self._validate_logprobs(params)
...@@ -498,10 +527,11 @@ class InputProcessor: ...@@ -498,10 +527,11 @@ class InputProcessor:
trace_headers: Mapping[str, str] | None = None, trace_headers: Mapping[str, str] | None = None,
priority: int = 0, priority: int = 0,
data_parallel_rank: int | None = None, data_parallel_rank: int | None = None,
supported_tasks: tuple[SupportedTask, ...] | None = None,
resumable: bool = False, resumable: bool = False,
) -> EngineCoreRequest: ) -> EngineCoreRequest:
self._validate_lora(lora_request) self._validate_lora(lora_request)
self._validate_params(params) self._validate_params(params, supported_tasks)
parallel_config = self.vllm_config.parallel_config parallel_config = self.vllm_config.parallel_config
dp_size = parallel_config.data_parallel_size dp_size = parallel_config.data_parallel_size
......
...@@ -201,7 +201,11 @@ class LLMEngine: ...@@ -201,7 +201,11 @@ class LLMEngine:
return outputs return outputs
def get_supported_tasks(self) -> tuple[SupportedTask, ...]: def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
return self.engine_core.get_supported_tasks() if not hasattr(self, "_supported_tasks"):
# Cache the result
self._supported_tasks = self.engine_core.get_supported_tasks()
return self._supported_tasks
def abort_request(self, request_ids: list[str], internal: bool = False) -> None: def abort_request(self, request_ids: list[str], internal: bool = False) -> None:
"""Remove request_ids from EngineCore and Detokenizer.""" """Remove request_ids from EngineCore and Detokenizer."""
...@@ -245,6 +249,7 @@ class LLMEngine: ...@@ -245,6 +249,7 @@ class LLMEngine:
tokenization_kwargs, tokenization_kwargs,
trace_headers, trace_headers,
priority, priority,
supported_tasks=self.get_supported_tasks(),
) )
prompt_text = get_prompt_text(prompt) prompt_text = get_prompt_text(prompt)
......
...@@ -5037,7 +5037,7 @@ class GPUModelRunner( ...@@ -5037,7 +5037,7 @@ class GPUModelRunner(
model = cast(VllmModelForPooling, self.get_model()) model = cast(VllmModelForPooling, self.get_model())
dummy_pooling_params = PoolingParams(task=task) dummy_pooling_params = PoolingParams(task=task)
dummy_pooling_params.verify(task=task, model_config=self.model_config) dummy_pooling_params.verify(self.model_config)
to_update = model.pooler.get_pooling_updates(task) to_update = model.pooler.get_pooling_updates(task)
to_update.apply(dummy_pooling_params) to_update.apply(dummy_pooling_params)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment