Commit 5eb36575 authored by khluu
Browse files

Revert "[Frontend] Remove frontend pooling multi task support. (#37861)"

This reverts commit d2e2e856.
parent 4d51588e
...@@ -10,7 +10,7 @@ from vllm.entrypoints.chat_utils import ChatTemplateConfig ...@@ -10,7 +10,7 @@ from vllm.entrypoints.chat_utils import ChatTemplateConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.plugins.io_processors import has_io_processor from vllm.plugins.io_processors import has_io_processor
from vllm.renderers import BaseRenderer from vllm.renderers import BaseRenderer
from vllm.tasks import POOLING_TASKS, SCORE_TYPE_MAP, SupportedTask from vllm.tasks import POOLING_TASKS, SupportedTask
from .base.io_processor import PoolingIOProcessor from .base.io_processor import PoolingIOProcessor
from .utils import enable_scoring_api from .utils import enable_scoring_api
...@@ -43,24 +43,23 @@ def init_pooling_io_processors( ...@@ -43,24 +43,23 @@ def init_pooling_io_processors(
) -> dict[str, PoolingIOProcessor]: ) -> dict[str, PoolingIOProcessor]:
model_config = vllm_config.model_config model_config = vllm_config.model_config
processors: dict[str, type[PoolingIOProcessor]] = {} processors: dict[str, type[PoolingIOProcessor]] = {}
pooling_task = model_config.get_pooling_task(supported_tasks)
if pooling_task == "classify": if "classify" in supported_tasks:
from .classify.io_processor import ClassifyIOProcessor from .classify.io_processor import ClassifyIOProcessor
processors["classify"] = ClassifyIOProcessor processors["classify"] = ClassifyIOProcessor
if pooling_task == "token_classify": if "token_classify" in supported_tasks:
from .classify.io_processor import TokenClassifyIOProcessor from .classify.io_processor import TokenClassifyIOProcessor
processors["token_classify"] = TokenClassifyIOProcessor processors["token_classify"] = TokenClassifyIOProcessor
if pooling_task == "embed": if "embed" in supported_tasks:
from .embed.io_processor import EmbedIOProcessor from .embed.io_processor import EmbedIOProcessor
processors["embed"] = EmbedIOProcessor processors["embed"] = EmbedIOProcessor
if pooling_task == "token_embed": if "token_embed" in supported_tasks:
from .embed.io_processor import TokenEmbedIOProcessor from .embed.io_processor import TokenEmbedIOProcessor
processors["token_embed"] = TokenEmbedIOProcessor processors["token_embed"] = TokenEmbedIOProcessor
...@@ -72,15 +71,15 @@ def init_pooling_io_processors( ...@@ -72,15 +71,15 @@ def init_pooling_io_processors(
from .pooling.io_processor import PluginWithIOProcessorPlugins from .pooling.io_processor import PluginWithIOProcessorPlugins
processors["plugin"] = PluginWithIOProcessorPlugins processors["plugin"] = PluginWithIOProcessorPlugins
elif pooling_task == "plugin": elif "plugin" in supported_tasks:
from .pooling.io_processor import PluginWithoutIOProcessorPlugins from .pooling.io_processor import PluginWithoutIOProcessorPlugins
processors["plugin"] = PluginWithoutIOProcessorPlugins processors["plugin"] = PluginWithoutIOProcessorPlugins
if enable_scoring_api(supported_tasks, model_config): if enable_scoring_api(supported_tasks, model_config):
score_type = model_config.score_type
from .scoring.io_processor import ScoringIOProcessors from .scoring.io_processor import ScoringIOProcessors
score_type: str | None = SCORE_TYPE_MAP.get(pooling_task, None) # type: ignore[arg-type]
if score_type is not None and score_type in ScoringIOProcessors: if score_type is not None and score_type in ScoringIOProcessors:
processors[score_type] = ScoringIOProcessors[score_type] processors[score_type] = ScoringIOProcessors[score_type]
...@@ -141,10 +140,6 @@ def init_pooling_state( ...@@ -141,10 +140,6 @@ def init_pooling_state(
request_logger: RequestLogger | None, request_logger: RequestLogger | None,
supported_tasks: tuple["SupportedTask", ...], supported_tasks: tuple["SupportedTask", ...],
): ):
model_config = engine_client.model_config
if model_config is None:
return
from vllm.entrypoints.chat_utils import load_chat_template from vllm.entrypoints.chat_utils import load_chat_template
from vllm.tasks import POOLING_TASKS from vllm.tasks import POOLING_TASKS
...@@ -153,14 +148,8 @@ def init_pooling_state( ...@@ -153,14 +148,8 @@ def init_pooling_state(
from .pooling.serving import ServingPooling from .pooling.serving import ServingPooling
from .scoring.serving import ServingScores from .scoring.serving import ServingScores
model_config = engine_client.model_config
resolved_chat_template = load_chat_template(args.chat_template) resolved_chat_template = load_chat_template(args.chat_template)
pooling_task = model_config.get_pooling_task(supported_tasks)
chat_template_config = ChatTemplateConfig(
chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format,
trust_request_chat_template=args.trust_request_chat_template,
)
state.serving_pooling = ( state.serving_pooling = (
( (
...@@ -169,7 +158,9 @@ def init_pooling_state( ...@@ -169,7 +158,9 @@ def init_pooling_state(
state.openai_serving_models, state.openai_serving_models,
supported_tasks=supported_tasks, supported_tasks=supported_tasks,
request_logger=request_logger, request_logger=request_logger,
chat_template_config=chat_template_config, chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format,
trust_request_chat_template=args.trust_request_chat_template,
) )
) )
if any(t in supported_tasks for t in POOLING_TASKS) if any(t in supported_tasks for t in POOLING_TASKS)
...@@ -180,9 +171,11 @@ def init_pooling_state( ...@@ -180,9 +171,11 @@ def init_pooling_state(
engine_client, engine_client,
state.openai_serving_models, state.openai_serving_models,
request_logger=request_logger, request_logger=request_logger,
chat_template_config=chat_template_config, chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format,
trust_request_chat_template=args.trust_request_chat_template,
) )
if pooling_task == "embed" if "embed" in supported_tasks
else None else None
) )
state.serving_classification = ( state.serving_classification = (
...@@ -190,18 +183,21 @@ def init_pooling_state( ...@@ -190,18 +183,21 @@ def init_pooling_state(
engine_client, engine_client,
state.openai_serving_models, state.openai_serving_models,
request_logger=request_logger, request_logger=request_logger,
chat_template_config=chat_template_config, chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format,
trust_request_chat_template=args.trust_request_chat_template,
) )
if pooling_task == "classify" if "classify" in supported_tasks
else None else None
) )
state.serving_scores = ( state.serving_scores = (
ServingScores( ServingScores(
engine_client, engine_client,
state.openai_serving_models, state.openai_serving_models,
supported_tasks=supported_tasks,
request_logger=request_logger, request_logger=request_logger,
chat_template_config=chat_template_config, chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format,
trust_request_chat_template=args.trust_request_chat_template,
enable_flash_late_interaction=getattr( enable_flash_late_interaction=getattr(
args, "enable_flash_late_interaction", True args, "enable_flash_late_interaction", True
), ),
...@@ -218,12 +214,7 @@ def get_pooling_invocation_types( ...@@ -218,12 +214,7 @@ def get_pooling_invocation_types(
# NOTE: Items defined earlier take higher priority # NOTE: Items defined earlier take higher priority
invocation_types: list[tuple[RequestType, tuple[GetHandlerFn, EndpointFn]]] = [] invocation_types: list[tuple[RequestType, tuple[GetHandlerFn, EndpointFn]]] = []
if model_config is None: if "embed" in supported_tasks:
return invocation_types
pooling_task = model_config.get_pooling_task(supported_tasks)
if pooling_task == "embed":
from .embed.api_router import create_embedding, embedding from .embed.api_router import create_embedding, embedding
from .embed.protocol import EmbeddingRequest from .embed.protocol import EmbeddingRequest
...@@ -231,7 +222,7 @@ def get_pooling_invocation_types( ...@@ -231,7 +222,7 @@ def get_pooling_invocation_types(
(EmbeddingRequest, (embedding, create_embedding)), (EmbeddingRequest, (embedding, create_embedding)),
] ]
if pooling_task == "classify": if "classify" in supported_tasks:
from .classify.api_router import classify, create_classify from .classify.api_router import classify, create_classify
from .classify.protocol import ClassificationRequest from .classify.protocol import ClassificationRequest
......
...@@ -78,15 +78,17 @@ class ServingPooling(PoolingServingBase): ...@@ -78,15 +78,17 @@ class ServingPooling(PoolingServingBase):
# plugin task uses io_processor.parse_request to verify inputs # plugin task uses io_processor.parse_request to verify inputs
if pooling_task != "plugin" and pooling_task != self.pooling_task: if pooling_task != "plugin" and pooling_task != self.pooling_task:
if pooling_task not in self.supported_tasks: if pooling_task not in self.io_processors:
raise ValueError( raise ValueError(
f"Unsupported task: {pooling_task!r} " f"Unsupported task: {pooling_task!r} "
f"Supported tasks: {self.supported_tasks}" f"Supported tasks: {self.supported_tasks}"
) )
else: else:
raise ValueError( logger.warning_once(
"Try switching the model's pooling_task " "Pooling multitask support is deprecated and will be removed "
f"via --pooler-config.task {request.task}." "in v0.20. When the default pooling task is not what you want, you "
"need to manually specify it via --pooler-config.task %s. ",
pooling_task,
) )
if pooling_task == "plugin" and "plugin" not in self.io_processors: if pooling_task == "plugin" and "plugin" not in self.io_processors:
......
...@@ -8,7 +8,6 @@ from vllm.engine.protocol import EngineClient ...@@ -8,7 +8,6 @@ from vllm.engine.protocol import EngineClient
from vllm.entrypoints.openai.engine.protocol import UsageInfo from vllm.entrypoints.openai.engine.protocol import UsageInfo
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.outputs import PoolingRequestOutput, ScoringRequestOutput from vllm.outputs import PoolingRequestOutput, ScoringRequestOutput
from vllm.tasks import SCORE_TYPE_MAP, SupportedTask
from vllm.v1.pool.late_interaction import ( from vllm.v1.pool.late_interaction import (
build_late_interaction_doc_params, build_late_interaction_doc_params,
build_late_interaction_query_params, build_late_interaction_query_params,
...@@ -39,15 +38,10 @@ class ServingScores(PoolingServing): ...@@ -39,15 +38,10 @@ class ServingScores(PoolingServing):
self, self,
engine_client: EngineClient, engine_client: EngineClient,
*args, *args,
supported_tasks: tuple[SupportedTask, ...],
enable_flash_late_interaction: bool = True, enable_flash_late_interaction: bool = True,
**kwargs, **kwargs,
): ):
pooling_task = engine_client.model_config.get_pooling_task(supported_tasks) self.io_processor_name: str = engine_client.model_config.score_type
score_type = SCORE_TYPE_MAP.get(pooling_task, None) # type: ignore[arg-type]
assert score_type is not None
self.io_processor_name: str = score_type
self.enable_flash_late_interaction = ( self.enable_flash_late_interaction = (
self.io_processor_name == "late-interaction" self.io_processor_name == "late-interaction"
and enable_flash_late_interaction and enable_flash_late_interaction
......
...@@ -141,14 +141,10 @@ def enable_scoring_api( ...@@ -141,14 +141,10 @@ def enable_scoring_api(
supported_tasks: tuple["SupportedTask", ...], supported_tasks: tuple["SupportedTask", ...],
model_config: ModelConfig | None = None, model_config: ModelConfig | None = None,
) -> bool: ) -> bool:
if model_config is None: if any(t in supported_tasks for t in ("embed", "token_embed")):
return False
pooling_task = model_config.get_pooling_task(supported_tasks)
if pooling_task in ("embed", "token_embed"):
return True return True
if pooling_task == "classify": if model_config is not None and "classify" in supported_tasks:
num_labels = getattr(model_config.hf_config, "num_labels", 0) num_labels = getattr(model_config.hf_config, "num_labels", 0)
if num_labels != 1: if num_labels != 1:
logger.debug_once("Scoring API is only enabled for num_labels == 1.") logger.debug_once("Scoring API is only enabled for num_labels == 1.")
......
...@@ -87,6 +87,13 @@ class PoolingParams( ...@@ -87,6 +87,13 @@ class PoolingParams(
return deepcopy(self) return deepcopy(self)
def verify(self, model_config: ModelConfig) -> None: def verify(self, model_config: ModelConfig) -> None:
if self.task == "score":
logger.warning_once(
"`score` task is deprecated and will be removed in v0.20. "
"Please use `classify` instead."
)
self.task = "classify"
# plugin task uses io_processor.parse_request to verify inputs, # plugin task uses io_processor.parse_request to verify inputs,
# skipping PoolingParams verify # skipping PoolingParams verify
if self.task == "plugin": if self.task == "plugin":
......
...@@ -16,11 +16,6 @@ PoolingTask = Literal[ ...@@ -16,11 +16,6 @@ PoolingTask = Literal[
POOLING_TASKS: tuple[PoolingTask, ...] = get_args(PoolingTask) POOLING_TASKS: tuple[PoolingTask, ...] = get_args(PoolingTask)
ScoreType = Literal["bi-encoder", "cross-encoder", "late-interaction"] ScoreType = Literal["bi-encoder", "cross-encoder", "late-interaction"]
SCORE_TYPE_MAP: dict[PoolingTask, ScoreType] = {
"embed": "bi-encoder",
"classify": "cross-encoder",
"token_embed": "late-interaction",
}
FrontendTask = Literal["render"] FrontendTask = Literal["render"]
FRONTEND_TASKS: tuple[FrontendTask, ...] = get_args(FrontendTask) FRONTEND_TASKS: tuple[FrontendTask, ...] = get_args(FrontendTask)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment