Commit 5eb36575 authored by khluu's avatar khluu
Browse files

Revert "[Frontend] Remove frontend pooling multi task support. (#37861)"

This reverts commit d2e2e856.
parent 4d51588e
......@@ -10,7 +10,7 @@ from vllm.entrypoints.chat_utils import ChatTemplateConfig
from vllm.logger import init_logger
from vllm.plugins.io_processors import has_io_processor
from vllm.renderers import BaseRenderer
from vllm.tasks import POOLING_TASKS, SCORE_TYPE_MAP, SupportedTask
from vllm.tasks import POOLING_TASKS, SupportedTask
from .base.io_processor import PoolingIOProcessor
from .utils import enable_scoring_api
......@@ -43,24 +43,23 @@ def init_pooling_io_processors(
) -> dict[str, PoolingIOProcessor]:
model_config = vllm_config.model_config
processors: dict[str, type[PoolingIOProcessor]] = {}
pooling_task = model_config.get_pooling_task(supported_tasks)
if pooling_task == "classify":
if "classify" in supported_tasks:
from .classify.io_processor import ClassifyIOProcessor
processors["classify"] = ClassifyIOProcessor
if pooling_task == "token_classify":
if "token_classify" in supported_tasks:
from .classify.io_processor import TokenClassifyIOProcessor
processors["token_classify"] = TokenClassifyIOProcessor
if pooling_task == "embed":
if "embed" in supported_tasks:
from .embed.io_processor import EmbedIOProcessor
processors["embed"] = EmbedIOProcessor
if pooling_task == "token_embed":
if "token_embed" in supported_tasks:
from .embed.io_processor import TokenEmbedIOProcessor
processors["token_embed"] = TokenEmbedIOProcessor
......@@ -72,15 +71,15 @@ def init_pooling_io_processors(
from .pooling.io_processor import PluginWithIOProcessorPlugins
processors["plugin"] = PluginWithIOProcessorPlugins
elif pooling_task == "plugin":
elif "plugin" in supported_tasks:
from .pooling.io_processor import PluginWithoutIOProcessorPlugins
processors["plugin"] = PluginWithoutIOProcessorPlugins
if enable_scoring_api(supported_tasks, model_config):
score_type = model_config.score_type
from .scoring.io_processor import ScoringIOProcessors
score_type: str | None = SCORE_TYPE_MAP.get(pooling_task, None) # type: ignore[arg-type]
if score_type is not None and score_type in ScoringIOProcessors:
processors[score_type] = ScoringIOProcessors[score_type]
......@@ -141,10 +140,6 @@ def init_pooling_state(
request_logger: RequestLogger | None,
supported_tasks: tuple["SupportedTask", ...],
):
model_config = engine_client.model_config
if model_config is None:
return
from vllm.entrypoints.chat_utils import load_chat_template
from vllm.tasks import POOLING_TASKS
......@@ -153,14 +148,8 @@ def init_pooling_state(
from .pooling.serving import ServingPooling
from .scoring.serving import ServingScores
model_config = engine_client.model_config
resolved_chat_template = load_chat_template(args.chat_template)
pooling_task = model_config.get_pooling_task(supported_tasks)
chat_template_config = ChatTemplateConfig(
chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format,
trust_request_chat_template=args.trust_request_chat_template,
)
state.serving_pooling = (
(
......@@ -169,7 +158,9 @@ def init_pooling_state(
state.openai_serving_models,
supported_tasks=supported_tasks,
request_logger=request_logger,
chat_template_config=chat_template_config,
chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format,
trust_request_chat_template=args.trust_request_chat_template,
)
)
if any(t in supported_tasks for t in POOLING_TASKS)
......@@ -180,9 +171,11 @@ def init_pooling_state(
engine_client,
state.openai_serving_models,
request_logger=request_logger,
chat_template_config=chat_template_config,
chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format,
trust_request_chat_template=args.trust_request_chat_template,
)
if pooling_task == "embed"
if "embed" in supported_tasks
else None
)
state.serving_classification = (
......@@ -190,18 +183,21 @@ def init_pooling_state(
engine_client,
state.openai_serving_models,
request_logger=request_logger,
chat_template_config=chat_template_config,
chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format,
trust_request_chat_template=args.trust_request_chat_template,
)
if pooling_task == "classify"
if "classify" in supported_tasks
else None
)
state.serving_scores = (
ServingScores(
engine_client,
state.openai_serving_models,
supported_tasks=supported_tasks,
request_logger=request_logger,
chat_template_config=chat_template_config,
chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format,
trust_request_chat_template=args.trust_request_chat_template,
enable_flash_late_interaction=getattr(
args, "enable_flash_late_interaction", True
),
......@@ -218,12 +214,7 @@ def get_pooling_invocation_types(
# NOTE: Items defined earlier take higher priority
invocation_types: list[tuple[RequestType, tuple[GetHandlerFn, EndpointFn]]] = []
if model_config is None:
return invocation_types
pooling_task = model_config.get_pooling_task(supported_tasks)
if pooling_task == "embed":
if "embed" in supported_tasks:
from .embed.api_router import create_embedding, embedding
from .embed.protocol import EmbeddingRequest
......@@ -231,7 +222,7 @@ def get_pooling_invocation_types(
(EmbeddingRequest, (embedding, create_embedding)),
]
if pooling_task == "classify":
if "classify" in supported_tasks:
from .classify.api_router import classify, create_classify
from .classify.protocol import ClassificationRequest
......
......@@ -78,15 +78,17 @@ class ServingPooling(PoolingServingBase):
# plugin task uses io_processor.parse_request to verify inputs
if pooling_task != "plugin" and pooling_task != self.pooling_task:
if pooling_task not in self.supported_tasks:
if pooling_task not in self.io_processors:
raise ValueError(
f"Unsupported task: {pooling_task!r} "
f"Supported tasks: {self.supported_tasks}"
)
else:
raise ValueError(
"Try switching the model's pooling_task "
f"via --pooler-config.task {request.task}."
logger.warning_once(
"Pooling multitask support is deprecated and will be removed "
"in v0.20. When the default pooling task is not what you want, you "
"need to manually specify it via --pooler-config.task %s. ",
pooling_task,
)
if pooling_task == "plugin" and "plugin" not in self.io_processors:
......
......@@ -8,7 +8,6 @@ from vllm.engine.protocol import EngineClient
from vllm.entrypoints.openai.engine.protocol import UsageInfo
from vllm.logger import init_logger
from vllm.outputs import PoolingRequestOutput, ScoringRequestOutput
from vllm.tasks import SCORE_TYPE_MAP, SupportedTask
from vllm.v1.pool.late_interaction import (
build_late_interaction_doc_params,
build_late_interaction_query_params,
......@@ -39,15 +38,10 @@ class ServingScores(PoolingServing):
self,
engine_client: EngineClient,
*args,
supported_tasks: tuple[SupportedTask, ...],
enable_flash_late_interaction: bool = True,
**kwargs,
):
pooling_task = engine_client.model_config.get_pooling_task(supported_tasks)
score_type = SCORE_TYPE_MAP.get(pooling_task, None) # type: ignore[arg-type]
assert score_type is not None
self.io_processor_name: str = score_type
self.io_processor_name: str = engine_client.model_config.score_type
self.enable_flash_late_interaction = (
self.io_processor_name == "late-interaction"
and enable_flash_late_interaction
......
......@@ -141,14 +141,10 @@ def enable_scoring_api(
supported_tasks: tuple["SupportedTask", ...],
model_config: ModelConfig | None = None,
) -> bool:
if model_config is None:
return False
pooling_task = model_config.get_pooling_task(supported_tasks)
if pooling_task in ("embed", "token_embed"):
if any(t in supported_tasks for t in ("embed", "token_embed")):
return True
if pooling_task == "classify":
if model_config is not None and "classify" in supported_tasks:
num_labels = getattr(model_config.hf_config, "num_labels", 0)
if num_labels != 1:
logger.debug_once("Scoring API is only enabled for num_labels == 1.")
......
......@@ -87,6 +87,13 @@ class PoolingParams(
return deepcopy(self)
def verify(self, model_config: ModelConfig) -> None:
if self.task == "score":
logger.warning_once(
"`score` task is deprecated and will be removed in v0.20. "
"Please use `classify` instead."
)
self.task = "classify"
# plugin task uses io_processor.parse_request to verify inputs,
# skipping PoolingParams verify
if self.task == "plugin":
......
......@@ -16,11 +16,6 @@ PoolingTask = Literal[
POOLING_TASKS: tuple[PoolingTask, ...] = get_args(PoolingTask)
ScoreType = Literal["bi-encoder", "cross-encoder", "late-interaction"]
SCORE_TYPE_MAP: dict[PoolingTask, ScoreType] = {
"embed": "bi-encoder",
"classify": "cross-encoder",
"token_embed": "late-interaction",
}
FrontendTask = Literal["render"]
FRONTEND_TASKS: tuple[FrontendTask, ...] = get_args(FrontendTask)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment