"vscode:/vscode.git/clone" did not exist on "69f30ae059e7623918637336457f1be42c8652c0"
Unverified Commit 038914b7 authored by Cyrus Leung, committed by GitHub
Browse files

[Refactor] Move `task` outside of `PoolingParams.verify` (#33796)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: wang.yuqi <yuqi.wang@daocloud.io>
Co-authored-by: wang.yuqi <yuqi.wang@daocloud.io>
parent d2f4a71c
......@@ -269,7 +269,11 @@ class AsyncLLM(EngineClient):
cancel_task_threadsafe(handler)
async def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
return await self.engine_core.get_supported_tasks_async()
if not hasattr(self, "_supported_tasks"):
# Cache the result
self._supported_tasks = await self.engine_core.get_supported_tasks_async()
return self._supported_tasks
async def add_request(
self,
......@@ -355,6 +359,7 @@ class AsyncLLM(EngineClient):
trace_headers=trace_headers,
priority=priority,
data_parallel_rank=data_parallel_rank,
supported_tasks=await self.get_supported_tasks(),
)
prompt_text = get_prompt_text(prompt)
......
......@@ -31,6 +31,7 @@ from vllm.multimodal.utils import argsort_mm_positions
from vllm.pooling_params import PoolingParams
from vllm.renderers import BaseRenderer
from vllm.sampling_params import _SAMPLING_EPS, SamplingParams
from vllm.tasks import POOLING_TASKS, SupportedTask
from vllm.tokenizers import TokenizerLike
from vllm.tokenizers.mistral import MistralTokenizer
from vllm.utils import length_from_prompt_token_ids_or_embeds, random_uuid
......@@ -196,13 +197,41 @@ class InputProcessor:
def _validate_params(
self,
params: SamplingParams | PoolingParams,
# TODO: Validate generation tasks as well once `supported_tasks`
# is passed to all `process_inputs` calls
supported_tasks: tuple[SupportedTask, ...] | None,
):
"""
Validate supported SamplingParam.
Should raise ValueError if unsupported for API Server.
"""
if isinstance(params, PoolingParams):
if supported_tasks is None:
raise RuntimeError("`supported_tasks` must be passed for pooling")
supported_pooling_tasks = [
task for task in supported_tasks if task in POOLING_TASKS
]
if params.task is None:
if not supported_pooling_tasks:
raise ValueError("Pooling tasks are not supported")
if "token_embed" in supported_pooling_tasks:
params.task = "token_embed"
elif "token_classify" in supported_pooling_tasks:
params.task = "token_classify"
elif "plugin" in supported_pooling_tasks:
params.task = "plugin"
if params.task not in supported_pooling_tasks:
raise ValueError(
f"Unsupported task: {params.task!r} "
f"Supported tasks: {supported_pooling_tasks}"
)
params.verify(self.model_config)
return
self._validate_logprobs(params)
......@@ -498,10 +527,11 @@ class InputProcessor:
trace_headers: Mapping[str, str] | None = None,
priority: int = 0,
data_parallel_rank: int | None = None,
supported_tasks: tuple[SupportedTask, ...] | None = None,
resumable: bool = False,
) -> EngineCoreRequest:
self._validate_lora(lora_request)
self._validate_params(params)
self._validate_params(params, supported_tasks)
parallel_config = self.vllm_config.parallel_config
dp_size = parallel_config.data_parallel_size
......
......@@ -201,7 +201,11 @@ class LLMEngine:
return outputs
def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
return self.engine_core.get_supported_tasks()
if not hasattr(self, "_supported_tasks"):
# Cache the result
self._supported_tasks = self.engine_core.get_supported_tasks()
return self._supported_tasks
def abort_request(self, request_ids: list[str], internal: bool = False) -> None:
"""Remove request_ids from EngineCore and Detokenizer."""
......@@ -245,6 +249,7 @@ class LLMEngine:
tokenization_kwargs,
trace_headers,
priority,
supported_tasks=self.get_supported_tasks(),
)
prompt_text = get_prompt_text(prompt)
......
......@@ -5037,7 +5037,7 @@ class GPUModelRunner(
model = cast(VllmModelForPooling, self.get_model())
dummy_pooling_params = PoolingParams(task=task)
dummy_pooling_params.verify(task=task, model_config=self.model_config)
dummy_pooling_params.verify(self.model_config)
to_update = model.pooler.get_pooling_updates(task)
to_update.apply(dummy_pooling_params)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment