Unverified commit d2e2e856, authored by wang.yuqi and committed by GitHub
Browse files

[Frontend] Remove frontend pooling multi task support. (#37861)


Signed-off-by: wang.yuqi <yuqi.wang@daocloud.io>
Signed-off-by: wang.yuqi <noooop@126.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
parent 766cb65d
...@@ -10,7 +10,7 @@ from vllm.entrypoints.chat_utils import ChatTemplateConfig ...@@ -10,7 +10,7 @@ from vllm.entrypoints.chat_utils import ChatTemplateConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.plugins.io_processors import has_io_processor from vllm.plugins.io_processors import has_io_processor
from vllm.renderers import BaseRenderer from vllm.renderers import BaseRenderer
from vllm.tasks import POOLING_TASKS, SupportedTask from vllm.tasks import POOLING_TASKS, SCORE_TYPE_MAP, SupportedTask
from .base.io_processor import PoolingIOProcessor from .base.io_processor import PoolingIOProcessor
from .utils import enable_scoring_api from .utils import enable_scoring_api
...@@ -43,23 +43,24 @@ def init_pooling_io_processors( ...@@ -43,23 +43,24 @@ def init_pooling_io_processors(
) -> dict[str, PoolingIOProcessor]: ) -> dict[str, PoolingIOProcessor]:
model_config = vllm_config.model_config model_config = vllm_config.model_config
processors: dict[str, type[PoolingIOProcessor]] = {} processors: dict[str, type[PoolingIOProcessor]] = {}
pooling_task = model_config.get_pooling_task(supported_tasks)
if "classify" in supported_tasks: if pooling_task == "classify":
from .classify.io_processor import ClassifyIOProcessor from .classify.io_processor import ClassifyIOProcessor
processors["classify"] = ClassifyIOProcessor processors["classify"] = ClassifyIOProcessor
if "token_classify" in supported_tasks: if pooling_task == "token_classify":
from .classify.io_processor import TokenClassifyIOProcessor from .classify.io_processor import TokenClassifyIOProcessor
processors["token_classify"] = TokenClassifyIOProcessor processors["token_classify"] = TokenClassifyIOProcessor
if "embed" in supported_tasks: if pooling_task == "embed":
from .embed.io_processor import EmbedIOProcessor from .embed.io_processor import EmbedIOProcessor
processors["embed"] = EmbedIOProcessor processors["embed"] = EmbedIOProcessor
if "token_embed" in supported_tasks: if pooling_task == "token_embed":
from .embed.io_processor import TokenEmbedIOProcessor from .embed.io_processor import TokenEmbedIOProcessor
processors["token_embed"] = TokenEmbedIOProcessor processors["token_embed"] = TokenEmbedIOProcessor
...@@ -71,15 +72,15 @@ def init_pooling_io_processors( ...@@ -71,15 +72,15 @@ def init_pooling_io_processors(
from .pooling.io_processor import PluginWithIOProcessorPlugins from .pooling.io_processor import PluginWithIOProcessorPlugins
processors["plugin"] = PluginWithIOProcessorPlugins processors["plugin"] = PluginWithIOProcessorPlugins
elif "plugin" in supported_tasks: elif pooling_task == "plugin":
from .pooling.io_processor import PluginWithoutIOProcessorPlugins from .pooling.io_processor import PluginWithoutIOProcessorPlugins
processors["plugin"] = PluginWithoutIOProcessorPlugins processors["plugin"] = PluginWithoutIOProcessorPlugins
if enable_scoring_api(supported_tasks, model_config): if enable_scoring_api(supported_tasks, model_config):
score_type = model_config.score_type
from .scoring.io_processor import ScoringIOProcessors from .scoring.io_processor import ScoringIOProcessors
score_type: str | None = SCORE_TYPE_MAP.get(pooling_task, None) # type: ignore[arg-type]
if score_type is not None and score_type in ScoringIOProcessors: if score_type is not None and score_type in ScoringIOProcessors:
processors[score_type] = ScoringIOProcessors[score_type] processors[score_type] = ScoringIOProcessors[score_type]
...@@ -140,6 +141,10 @@ def init_pooling_state( ...@@ -140,6 +141,10 @@ def init_pooling_state(
request_logger: RequestLogger | None, request_logger: RequestLogger | None,
supported_tasks: tuple["SupportedTask", ...], supported_tasks: tuple["SupportedTask", ...],
): ):
model_config = engine_client.model_config
if model_config is None:
return
from vllm.entrypoints.chat_utils import load_chat_template from vllm.entrypoints.chat_utils import load_chat_template
from vllm.tasks import POOLING_TASKS from vllm.tasks import POOLING_TASKS
...@@ -148,8 +153,14 @@ def init_pooling_state( ...@@ -148,8 +153,14 @@ def init_pooling_state(
from .pooling.serving import ServingPooling from .pooling.serving import ServingPooling
from .scoring.serving import ServingScores from .scoring.serving import ServingScores
model_config = engine_client.model_config
resolved_chat_template = load_chat_template(args.chat_template) resolved_chat_template = load_chat_template(args.chat_template)
pooling_task = model_config.get_pooling_task(supported_tasks)
chat_template_config = ChatTemplateConfig(
chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format,
trust_request_chat_template=args.trust_request_chat_template,
)
state.serving_pooling = ( state.serving_pooling = (
( (
...@@ -158,9 +169,7 @@ def init_pooling_state( ...@@ -158,9 +169,7 @@ def init_pooling_state(
state.openai_serving_models, state.openai_serving_models,
supported_tasks=supported_tasks, supported_tasks=supported_tasks,
request_logger=request_logger, request_logger=request_logger,
chat_template=resolved_chat_template, chat_template_config=chat_template_config,
chat_template_content_format=args.chat_template_content_format,
trust_request_chat_template=args.trust_request_chat_template,
) )
) )
if any(t in supported_tasks for t in POOLING_TASKS) if any(t in supported_tasks for t in POOLING_TASKS)
...@@ -171,11 +180,9 @@ def init_pooling_state( ...@@ -171,11 +180,9 @@ def init_pooling_state(
engine_client, engine_client,
state.openai_serving_models, state.openai_serving_models,
request_logger=request_logger, request_logger=request_logger,
chat_template=resolved_chat_template, chat_template_config=chat_template_config,
chat_template_content_format=args.chat_template_content_format,
trust_request_chat_template=args.trust_request_chat_template,
) )
if "embed" in supported_tasks if pooling_task == "embed"
else None else None
) )
state.serving_classification = ( state.serving_classification = (
...@@ -183,21 +190,18 @@ def init_pooling_state( ...@@ -183,21 +190,18 @@ def init_pooling_state(
engine_client, engine_client,
state.openai_serving_models, state.openai_serving_models,
request_logger=request_logger, request_logger=request_logger,
chat_template=resolved_chat_template, chat_template_config=chat_template_config,
chat_template_content_format=args.chat_template_content_format,
trust_request_chat_template=args.trust_request_chat_template,
) )
if "classify" in supported_tasks if pooling_task == "classify"
else None else None
) )
state.serving_scores = ( state.serving_scores = (
ServingScores( ServingScores(
engine_client, engine_client,
state.openai_serving_models, state.openai_serving_models,
supported_tasks=supported_tasks,
request_logger=request_logger, request_logger=request_logger,
chat_template=resolved_chat_template, chat_template_config=chat_template_config,
chat_template_content_format=args.chat_template_content_format,
trust_request_chat_template=args.trust_request_chat_template,
enable_flash_late_interaction=getattr( enable_flash_late_interaction=getattr(
args, "enable_flash_late_interaction", True args, "enable_flash_late_interaction", True
), ),
...@@ -214,7 +218,12 @@ def get_pooling_invocation_types( ...@@ -214,7 +218,12 @@ def get_pooling_invocation_types(
# NOTE: Items defined earlier take higher priority # NOTE: Items defined earlier take higher priority
invocation_types: list[tuple[RequestType, tuple[GetHandlerFn, EndpointFn]]] = [] invocation_types: list[tuple[RequestType, tuple[GetHandlerFn, EndpointFn]]] = []
if "embed" in supported_tasks: if model_config is None:
return invocation_types
pooling_task = model_config.get_pooling_task(supported_tasks)
if pooling_task == "embed":
from .embed.api_router import create_embedding, embedding from .embed.api_router import create_embedding, embedding
from .embed.protocol import EmbeddingRequest from .embed.protocol import EmbeddingRequest
...@@ -222,7 +231,7 @@ def get_pooling_invocation_types( ...@@ -222,7 +231,7 @@ def get_pooling_invocation_types(
(EmbeddingRequest, (embedding, create_embedding)), (EmbeddingRequest, (embedding, create_embedding)),
] ]
if "classify" in supported_tasks: if pooling_task == "classify":
from .classify.api_router import classify, create_classify from .classify.api_router import classify, create_classify
from .classify.protocol import ClassificationRequest from .classify.protocol import ClassificationRequest
......
...@@ -78,17 +78,15 @@ class ServingPooling(PoolingServingBase): ...@@ -78,17 +78,15 @@ class ServingPooling(PoolingServingBase):
# plugin task uses io_processor.parse_request to verify inputs # plugin task uses io_processor.parse_request to verify inputs
if pooling_task != "plugin" and pooling_task != self.pooling_task: if pooling_task != "plugin" and pooling_task != self.pooling_task:
if pooling_task not in self.io_processors: if pooling_task not in self.supported_tasks:
raise ValueError( raise ValueError(
f"Unsupported task: {pooling_task!r} " f"Unsupported task: {pooling_task!r} "
f"Supported tasks: {self.supported_tasks}" f"Supported tasks: {self.supported_tasks}"
) )
else: else:
logger.warning_once( raise ValueError(
"Pooling multitask support is deprecated and will be removed " "Try switching the model's pooling_task "
"in v0.20. When the default pooling task is not what you want, you " f"via --pooler-config.task {request.task}."
"need to manually specify it via --pooler-config.task %s. ",
pooling_task,
) )
if pooling_task == "plugin" and "plugin" not in self.io_processors: if pooling_task == "plugin" and "plugin" not in self.io_processors:
......
...@@ -8,6 +8,7 @@ from vllm.engine.protocol import EngineClient ...@@ -8,6 +8,7 @@ from vllm.engine.protocol import EngineClient
from vllm.entrypoints.openai.engine.protocol import UsageInfo from vllm.entrypoints.openai.engine.protocol import UsageInfo
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.outputs import PoolingRequestOutput, ScoringRequestOutput from vllm.outputs import PoolingRequestOutput, ScoringRequestOutput
from vllm.tasks import SCORE_TYPE_MAP, SupportedTask
from vllm.v1.pool.late_interaction import ( from vllm.v1.pool.late_interaction import (
build_late_interaction_doc_params, build_late_interaction_doc_params,
build_late_interaction_query_params, build_late_interaction_query_params,
...@@ -38,10 +39,15 @@ class ServingScores(PoolingServing): ...@@ -38,10 +39,15 @@ class ServingScores(PoolingServing):
self, self,
engine_client: EngineClient, engine_client: EngineClient,
*args, *args,
supported_tasks: tuple[SupportedTask, ...],
enable_flash_late_interaction: bool = True, enable_flash_late_interaction: bool = True,
**kwargs, **kwargs,
): ):
self.io_processor_name: str = engine_client.model_config.score_type pooling_task = engine_client.model_config.get_pooling_task(supported_tasks)
score_type = SCORE_TYPE_MAP.get(pooling_task, None) # type: ignore[arg-type]
assert score_type is not None
self.io_processor_name: str = score_type
self.enable_flash_late_interaction = ( self.enable_flash_late_interaction = (
self.io_processor_name == "late-interaction" self.io_processor_name == "late-interaction"
and enable_flash_late_interaction and enable_flash_late_interaction
......
...@@ -141,10 +141,14 @@ def enable_scoring_api( ...@@ -141,10 +141,14 @@ def enable_scoring_api(
supported_tasks: tuple["SupportedTask", ...], supported_tasks: tuple["SupportedTask", ...],
model_config: ModelConfig | None = None, model_config: ModelConfig | None = None,
) -> bool: ) -> bool:
if any(t in supported_tasks for t in ("embed", "token_embed")): if model_config is None:
return False
pooling_task = model_config.get_pooling_task(supported_tasks)
if pooling_task in ("embed", "token_embed"):
return True return True
if model_config is not None and "classify" in supported_tasks: if pooling_task == "classify":
num_labels = getattr(model_config.hf_config, "num_labels", 0) num_labels = getattr(model_config.hf_config, "num_labels", 0)
if num_labels != 1: if num_labels != 1:
logger.debug_once("Scoring API is only enabled for num_labels == 1.") logger.debug_once("Scoring API is only enabled for num_labels == 1.")
......
...@@ -87,13 +87,6 @@ class PoolingParams( ...@@ -87,13 +87,6 @@ class PoolingParams(
return deepcopy(self) return deepcopy(self)
def verify(self, model_config: ModelConfig) -> None: def verify(self, model_config: ModelConfig) -> None:
if self.task == "score":
logger.warning_once(
"`score` task is deprecated and will be removed in v0.20. "
"Please use `classify` instead."
)
self.task = "classify"
# plugin task uses io_processor.parse_request to verify inputs, # plugin task uses io_processor.parse_request to verify inputs,
# skipping PoolingParams verify # skipping PoolingParams verify
if self.task == "plugin": if self.task == "plugin":
......
...@@ -16,6 +16,11 @@ PoolingTask = Literal[ ...@@ -16,6 +16,11 @@ PoolingTask = Literal[
POOLING_TASKS: tuple[PoolingTask, ...] = get_args(PoolingTask) POOLING_TASKS: tuple[PoolingTask, ...] = get_args(PoolingTask)
ScoreType = Literal["bi-encoder", "cross-encoder", "late-interaction"] ScoreType = Literal["bi-encoder", "cross-encoder", "late-interaction"]
SCORE_TYPE_MAP: dict[PoolingTask, ScoreType] = {
"embed": "bi-encoder",
"classify": "cross-encoder",
"token_embed": "late-interaction",
}
FrontendTask = Literal["render"] FrontendTask = Literal["render"]
FRONTEND_TASKS: tuple[FrontendTask, ...] = get_args(FrontendTask) FRONTEND_TASKS: tuple[FrontendTask, ...] = get_args(FrontendTask)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment