Unverified Commit d9d21eb8 authored by wang.yuqi's avatar wang.yuqi Committed by GitHub
Browse files

[Frontend][3/n] Improve pooling entrypoints | scoring. (#28631)


Signed-off-by: default avatarwang.yuqi <yuqi.wang@daocloud.io>
parent f09daea2
...@@ -11,9 +11,7 @@ from typing import Any, ClassVar, Generic, Protocol, TypeAlias, TypeVar ...@@ -11,9 +11,7 @@ from typing import Any, ClassVar, Generic, Protocol, TypeAlias, TypeVar
import numpy as np import numpy as np
from fastapi import Request from fastapi import Request
from openai.types.responses import ( from openai.types.responses import ToolChoiceFunction
ToolChoiceFunction,
)
from pydantic import ConfigDict, TypeAdapter, ValidationError from pydantic import ConfigDict, TypeAdapter, ValidationError
from starlette.datastructures import Headers from starlette.datastructures import Headers
...@@ -21,9 +19,7 @@ import vllm.envs as envs ...@@ -21,9 +19,7 @@ import vllm.envs as envs
from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import ( from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
ChatTemplateContentFormatOption,
)
from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.chat_completion.protocol import ( from vllm.entrypoints.openai.chat_completion.protocol import (
BatchChatCompletionRequest, BatchChatCompletionRequest,
...@@ -42,9 +38,7 @@ from vllm.entrypoints.openai.engine.protocol import ( ...@@ -42,9 +38,7 @@ from vllm.entrypoints.openai.engine.protocol import (
GenerationError, GenerationError,
) )
from vllm.entrypoints.openai.models.serving import OpenAIServingModels from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.entrypoints.openai.responses.protocol import ( from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
ResponsesRequest,
)
from vllm.entrypoints.openai.speech_to_text.protocol import ( from vllm.entrypoints.openai.speech_to_text.protocol import (
TranscriptionRequest, TranscriptionRequest,
TranscriptionResponse, TranscriptionResponse,
...@@ -56,14 +50,6 @@ from vllm.entrypoints.pooling.pooling.protocol import ( ...@@ -56,14 +50,6 @@ from vllm.entrypoints.pooling.pooling.protocol import (
PoolingCompletionRequest, PoolingCompletionRequest,
PoolingResponse, PoolingResponse,
) )
from vllm.entrypoints.pooling.score.protocol import (
RerankRequest,
ScoreDataRequest,
ScoreQueriesDocumentsRequest,
ScoreRequest,
ScoreResponse,
ScoreTextRequest,
)
from vllm.entrypoints.serve.disagg.protocol import GenerateRequest, GenerateResponse from vllm.entrypoints.serve.disagg.protocol import GenerateRequest, GenerateResponse
from vllm.entrypoints.serve.tokenize.protocol import ( from vllm.entrypoints.serve.tokenize.protocol import (
DetokenizeRequest, DetokenizeRequest,
...@@ -72,8 +58,7 @@ from vllm.entrypoints.serve.tokenize.protocol import ( ...@@ -72,8 +58,7 @@ from vllm.entrypoints.serve.tokenize.protocol import (
TokenizeResponse, TokenizeResponse,
) )
from vllm.entrypoints.utils import create_error_response from vllm.entrypoints.utils import create_error_response
from vllm.exceptions import VLLMValidationError from vllm.inputs import EngineInput, PromptType
from vllm.inputs import EngineInput, PromptType, TokensPrompt
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.logprobs import Logprob, PromptLogprobs from vllm.logprobs import Logprob, PromptLogprobs
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
...@@ -119,8 +104,6 @@ CompletionLikeRequest: TypeAlias = ( ...@@ -119,8 +104,6 @@ CompletionLikeRequest: TypeAlias = (
CompletionRequest CompletionRequest
| TokenizeCompletionRequest | TokenizeCompletionRequest
| DetokenizeRequest | DetokenizeRequest
| RerankRequest
| ScoreRequest
| PoolingCompletionRequest | PoolingCompletionRequest
) )
...@@ -148,7 +131,6 @@ AnyResponse: TypeAlias = ( ...@@ -148,7 +131,6 @@ AnyResponse: TypeAlias = (
| TranscriptionResponse | TranscriptionResponse
| TokenizeResponse | TokenizeResponse
| PoolingResponse | PoolingResponse
| ScoreResponse
| GenerateResponse | GenerateResponse
) )
...@@ -692,88 +674,6 @@ class OpenAIServing: ...@@ -692,88 +674,6 @@ class OpenAIServing:
message_types.add(content_dict["type"].split("_")[0]) message_types.add(content_dict["type"].split("_")[0])
return message_types return message_types
def _validate_input(
self,
request: object,
input_ids: list[int],
input_text: str,
) -> TokensPrompt:
token_num = len(input_ids)
max_model_len = self.model_config.max_model_len
# Note: ScoreRequest doesn't have max_tokens
if isinstance(
request,
(
ScoreDataRequest,
ScoreTextRequest,
ScoreQueriesDocumentsRequest,
RerankRequest,
),
):
# Note: input length can be up to the entire model context length
# since these requests don't generate tokens.
if token_num > max_model_len:
operations: dict[type[AnyRequest], str] = {
ScoreDataRequest: "score",
ScoreTextRequest: "score",
ScoreQueriesDocumentsRequest: "score",
}
operation = operations.get(type(request), "embedding generation")
raise VLLMValidationError(
f"This model's maximum context length is "
f"{max_model_len} tokens. However, you requested "
f"{token_num} tokens in the input for {operation}. "
f"Please reduce the length of the input prompt.",
parameter="input_tokens",
value=token_num,
)
return TokensPrompt(prompt=input_text, prompt_token_ids=input_ids)
# Note: TokenizeRequest and DetokenizeRequest doesn't have max_tokens
# and does not require model context length validation
if isinstance(
request,
(TokenizeCompletionRequest, TokenizeChatRequest, DetokenizeRequest),
):
return TokensPrompt(prompt=input_text, prompt_token_ids=input_ids)
# chat completion endpoint supports max_completion_tokens
if isinstance(request, ChatCompletionRequest):
# TODO(#9845): remove max_tokens when field dropped from OpenAI API
max_tokens = request.max_completion_tokens or request.max_tokens
else:
max_tokens = getattr(request, "max_tokens", None)
# Note: input length can be up to model context length - 1 for
# completion-like requests.
if token_num >= max_model_len:
raise VLLMValidationError(
f"This model's maximum context length is "
f"{max_model_len} tokens. However, your request has "
f"{token_num} input tokens. Please reduce the length of "
"the input messages.",
parameter="input_tokens",
value=token_num,
)
if max_tokens is not None and token_num + max_tokens > max_model_len:
raise VLLMValidationError(
f"This model's maximum context length is "
f"{max_model_len} tokens. However, you requested "
f"{max_tokens} output tokens and your prompt contains "
f"{token_num} input tokens, for a total of "
f"{token_num + max_tokens} tokens "
f"({token_num} + {max_tokens} = "
f"{token_num + max_tokens} > {max_model_len}). "
f"Please reduce the length of the input prompt or the "
f"number of requested output tokens.",
parameter="max_tokens",
value=max_tokens,
)
return TokensPrompt(prompt=input_text, prompt_token_ids=input_ids)
def _validate_chat_template( def _validate_chat_template(
self, self,
request_chat_template: str | None, request_chat_template: str | None,
......
...@@ -2,6 +2,8 @@ ...@@ -2,6 +2,8 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio import asyncio
import contextlib
import json
import sys import sys
import tempfile import tempfile
from argparse import Namespace from argparse import Namespace
...@@ -13,12 +15,14 @@ from urllib.parse import urlparse ...@@ -13,12 +15,14 @@ from urllib.parse import urlparse
import aiohttp import aiohttp
import pybase64 as base64 import pybase64 as base64
import pydantic
import torch import torch
from fastapi import UploadFile from fastapi import UploadFile
from prometheus_client import start_http_server from prometheus_client import start_http_server
from pydantic import Field, TypeAdapter, field_validator, model_validator from pydantic import Field, TypeAdapter, field_validator, model_validator
from pydantic_core.core_schema import ValidationInfo from pydantic_core.core_schema import ValidationInfo
from starlette.datastructures import State from starlette.datastructures import State
from starlette.responses import JSONResponse
from tqdm import tqdm from tqdm import tqdm
from urllib3.util import parse_url from urllib3.util import parse_url
...@@ -49,7 +53,7 @@ from vllm.entrypoints.pooling.embed.protocol import ( ...@@ -49,7 +53,7 @@ from vllm.entrypoints.pooling.embed.protocol import (
EmbeddingRequest, EmbeddingRequest,
EmbeddingResponse, EmbeddingResponse,
) )
from vllm.entrypoints.pooling.score.protocol import ( from vllm.entrypoints.pooling.scoring.protocol import (
RerankRequest, RerankRequest,
RerankResponse, RerankResponse,
ScoreRequest, ScoreRequest,
...@@ -180,6 +184,18 @@ class BatchRequestInput(OpenAIBaseModel): ...@@ -180,6 +184,18 @@ class BatchRequestInput(OpenAIBaseModel):
return TypeAdapter(BatchRequestInputBody).validate_python(value) return TypeAdapter(BatchRequestInputBody).validate_python(value)
AllResponse: TypeAlias = (
ChatCompletionResponse
| EmbeddingResponse
| ScoreResponse
| RerankResponse
| TranscriptionResponse
| TranscriptionResponseVerbose
| TranslationResponse
| TranslationResponseVerbose
)
class BatchResponseData(OpenAIBaseModel): class BatchResponseData(OpenAIBaseModel):
# HTTP status code of the response. # HTTP status code of the response.
status_code: int = 200 status_code: int = 200
...@@ -188,17 +204,7 @@ class BatchResponseData(OpenAIBaseModel): ...@@ -188,17 +204,7 @@ class BatchResponseData(OpenAIBaseModel):
request_id: str request_id: str
# The body of the response. # The body of the response.
body: ( body: AllResponse | None = None
ChatCompletionResponse
| EmbeddingResponse
| ScoreResponse
| RerankResponse
| TranscriptionResponse
| TranscriptionResponseVerbose
| TranslationResponse
| TranslationResponseVerbose
| None
) = None
class BatchRequestOutput(OpenAIBaseModel): class BatchRequestOutput(OpenAIBaseModel):
...@@ -536,19 +542,13 @@ async def run_request( ...@@ -536,19 +542,13 @@ async def run_request(
except Exception as e: except Exception as e:
response = create_error_response(e) response = create_error_response(e)
if isinstance( if isinstance(response, JSONResponse):
response, with contextlib.suppress(pydantic.ValidationError):
( response = TypeAdapter(AllResponse | ErrorResponse).validate_python(
ChatCompletionResponse, json.loads(response.body)
EmbeddingResponse, )
ScoreResponse,
RerankResponse, if isinstance(response, AllResponse):
TranscriptionResponse,
TranscriptionResponseVerbose,
TranslationResponse,
TranslationResponseVerbose,
),
):
batch_output = BatchRequestOutput( batch_output = BatchRequestOutput(
id=f"vllm-{random_uuid()}", id=f"vllm-{random_uuid()}",
custom_id=request.custom_id, custom_id=request.custom_id,
...@@ -745,14 +745,14 @@ async def build_endpoint_registry( ...@@ -745,14 +745,14 @@ async def build_endpoint_registry(
"score": { "score": {
"url_matcher": lambda url: url.endswith("/score"), "url_matcher": lambda url: url.endswith("/score"),
"handler_getter": lambda: ( "handler_getter": lambda: (
serving_scores.create_score if serving_scores is not None else None serving_scores if serving_scores is not None else None
), ),
"wrapper_fn": None, "wrapper_fn": None,
}, },
"rerank": { "rerank": {
"url_matcher": lambda url: url.endswith("/rerank"), "url_matcher": lambda url: url.endswith("/rerank"),
"handler_getter": lambda: ( "handler_getter": lambda: (
serving_scores.do_rerank if serving_scores is not None else None serving_scores if serving_scores is not None else None
), ),
"wrapper_fn": None, "wrapper_fn": None,
}, },
......
...@@ -6,6 +6,7 @@ from typing import TYPE_CHECKING ...@@ -6,6 +6,7 @@ from typing import TYPE_CHECKING
from fastapi import FastAPI from fastapi import FastAPI
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.entrypoints.pooling.utils import enable_scoring_api
from vllm.logger import init_logger from vllm.logger import init_logger
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -23,23 +24,6 @@ else: ...@@ -23,23 +24,6 @@ else:
logger = init_logger(__name__) logger = init_logger(__name__)
def enable_scoring_api(
supported_tasks: tuple["SupportedTask", ...],
model_config: ModelConfig | None = None,
) -> bool:
if any(t in supported_tasks for t in ("embed", "token_embed")):
return True
if model_config is not None and "classify" in supported_tasks:
num_labels = getattr(model_config.hf_config, "num_labels", 0)
if num_labels != 1:
logger.debug_once("Score API is only enabled for num_labels == 1.")
return False
return True
return False
def register_pooling_api_routers( def register_pooling_api_routers(
app: FastAPI, app: FastAPI,
supported_tasks: tuple["SupportedTask", ...], supported_tasks: tuple["SupportedTask", ...],
...@@ -68,7 +52,7 @@ def register_pooling_api_routers( ...@@ -68,7 +52,7 @@ def register_pooling_api_routers(
app.include_router(embed_router) app.include_router(embed_router)
if enable_scoring_api(supported_tasks, model_config): if enable_scoring_api(supported_tasks, model_config):
from vllm.entrypoints.pooling.score.api_router import router as score_router from vllm.entrypoints.pooling.scoring.api_router import router as score_router
app.include_router(score_router) app.include_router(score_router)
...@@ -84,7 +68,7 @@ def init_pooling_state( ...@@ -84,7 +68,7 @@ def init_pooling_state(
from vllm.entrypoints.pooling.classify.serving import ServingClassification from vllm.entrypoints.pooling.classify.serving import ServingClassification
from vllm.entrypoints.pooling.embed.serving import ServingEmbedding from vllm.entrypoints.pooling.embed.serving import ServingEmbedding
from vllm.entrypoints.pooling.pooling.serving import OpenAIServingPooling from vllm.entrypoints.pooling.pooling.serving import OpenAIServingPooling
from vllm.entrypoints.pooling.score.serving import ServingScores from vllm.entrypoints.pooling.scoring.serving import ServingScores
from vllm.tasks import POOLING_TASKS from vllm.tasks import POOLING_TASKS
model_config = engine_client.model_config model_config = engine_client.model_config
...@@ -136,8 +120,9 @@ def init_pooling_state( ...@@ -136,8 +120,9 @@ def init_pooling_state(
engine_client, engine_client,
state.openai_serving_models, state.openai_serving_models,
request_logger=request_logger, request_logger=request_logger,
score_template=resolved_chat_template, chat_template=resolved_chat_template,
log_error_stack=args.log_error_stack, chat_template_content_format=args.chat_template_content_format,
trust_request_chat_template=args.trust_request_chat_template,
) )
if enable_scoring_api(supported_tasks, model_config) if enable_scoring_api(supported_tasks, model_config)
else None else None
......
...@@ -13,13 +13,16 @@ from vllm.entrypoints.chat_utils import ( ...@@ -13,13 +13,16 @@ from vllm.entrypoints.chat_utils import (
ConversationMessage, ConversationMessage,
) )
from vllm.entrypoints.openai.engine.serving import RendererChatRequest, RendererRequest from vllm.entrypoints.openai.engine.serving import RendererChatRequest, RendererRequest
from vllm.entrypoints.pooling.scoring.typing import ScoringData
from vllm.entrypoints.pooling.typing import ( from vllm.entrypoints.pooling.typing import (
OfflineInputsContext,
OfflineOutputsContext,
PoolingChatLikeRequest, PoolingChatLikeRequest,
PoolingCompletionLikeRequest, PoolingCompletionLikeRequest,
PoolingServeContext, PoolingServeContext,
) )
from vllm.inputs import EngineInput, SingletonPrompt from vllm.inputs import EngineInput, SingletonPrompt
from vllm.renderers import BaseRenderer, merge_kwargs from vllm.renderers import BaseRenderer, TokenizeParams, merge_kwargs
from vllm.renderers.inputs.preprocess import parse_model_prompt, prompt_to_seq from vllm.renderers.inputs.preprocess import parse_model_prompt, prompt_to_seq
from vllm.tool_parsers import ToolParser from vllm.tool_parsers import ToolParser
from vllm.utils.mistral import is_mistral_tokenizer from vllm.utils.mistral import is_mistral_tokenizer
...@@ -96,29 +99,29 @@ class PoolingIOProcessor: ...@@ -96,29 +99,29 @@ class PoolingIOProcessor:
####################################### #######################################
# offline APIs # offline APIs
def pre_process_offline( def pre_process_offline(self, ctx: OfflineInputsContext) -> Sequence[EngineInput]:
self, assert not isinstance(ctx.prompts, ScoringData)
prompts: PromptType | Sequence[PromptType], tok_params = self.renderer.default_cmpl_tok_params.with_kwargs(
tokenization_kwargs: dict[str, Any] | None = None, **(ctx.tokenization_kwargs or {})
) -> Sequence[EngineInput]: )
return self._preprocess_completion_offline( return self._preprocess_completion_offline(
prompts=prompts, tokenization_kwargs=tokenization_kwargs prompts=ctx.prompts, tok_params=tok_params
) )
async def pre_process_offline_async(self, *args, **kwargs): async def pre_process_offline_async(self, ctx: OfflineInputsContext):
return self.pre_process_offline(*args, **kwargs) return self.pre_process_offline(ctx)
def post_process_offline( def post_process_offline(
self, self,
outputs: list[PoolingRequestOutput], ctx: OfflineOutputsContext,
) -> list[PoolingRequestOutput]: ) -> list[PoolingRequestOutput]:
return outputs return ctx.outputs
async def post_process_offline_async( async def post_process_offline_async(
self, self,
outputs: list[PoolingRequestOutput], ctx: OfflineOutputsContext,
) -> list[PoolingRequestOutput]: ) -> list[PoolingRequestOutput]:
return self.post_process_offline(outputs) return self.post_process_offline(ctx)
####################################### #######################################
# helpers # helpers
...@@ -204,28 +207,21 @@ class PoolingIOProcessor: ...@@ -204,28 +207,21 @@ class PoolingIOProcessor:
def _preprocess_completion_offline( def _preprocess_completion_offline(
self, self,
prompts: PromptType | Sequence[PromptType], prompts: PromptType | Sequence[PromptType],
tokenization_kwargs: dict[str, Any] | None = None, tok_params: TokenizeParams,
prompt_extras: dict[str, Any] | None = None,
) -> Sequence[EngineInput]: ) -> Sequence[EngineInput]:
renderer = self.renderer
model_config = self.model_config
prompts = prompt_to_seq(prompts) prompts = prompt_to_seq(prompts)
parsed_prompts = [ parsed_prompts = [
( (
prompt prompt
if isinstance(prompt, bytes) if isinstance(prompt, bytes)
else parse_model_prompt(model_config, prompt) else parse_model_prompt(self.model_config, prompt)
) )
for prompt in prompts for prompt in prompts
] ]
tok_params = renderer.default_cmpl_tok_params.with_kwargs(
**(tokenization_kwargs or {})
)
return renderer.render_cmpl( return self.renderer.render_cmpl(
parsed_prompts, parsed_prompts, tok_params, prompt_extras=prompt_extras
tok_params,
) )
def _validate_chat_template( def _validate_chat_template(
......
...@@ -117,8 +117,16 @@ class PoolingServing: ...@@ -117,8 +117,16 @@ class PoolingServing:
else await self._get_trace_headers(ctx.raw_request.headers) else await self._get_trace_headers(ctx.raw_request.headers)
) )
pooling_params = self.io_processor.create_pooling_params(ctx.request) if ctx.pooling_params is None:
pooling_params.verify(self.model_config) pooling_params = self.io_processor.create_pooling_params(ctx.request)
else:
pooling_params = ctx.pooling_params
if isinstance(pooling_params, list):
for params in pooling_params:
params.verify(self.model_config)
else:
pooling_params.verify(self.model_config)
for i, engine_input in enumerate(ctx.engine_inputs): for i, engine_input in enumerate(ctx.engine_inputs):
prompt_request_id = ( prompt_request_id = (
...@@ -127,16 +135,22 @@ class PoolingServing: ...@@ -127,16 +135,22 @@ class PoolingServing:
else ctx.prompt_request_ids[i] else ctx.prompt_request_ids[i]
) )
params = (
pooling_params[i]
if isinstance(pooling_params, list)
else pooling_params
)
self._log_inputs( self._log_inputs(
prompt_request_id, prompt_request_id,
engine_input, engine_input,
params=pooling_params, params=params,
lora_request=ctx.lora_request, lora_request=ctx.lora_request,
) )
generator = self.engine_client.encode( generator = self.engine_client.encode(
engine_input, engine_input,
pooling_params, params,
prompt_request_id, prompt_request_id,
lora_request=ctx.lora_request, lora_request=ctx.lora_request,
trace_headers=trace_headers, trace_headers=trace_headers,
......
...@@ -5,6 +5,8 @@ ...@@ -5,6 +5,8 @@
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.entrypoints.chat_utils import ChatTemplateConfig from vllm.entrypoints.chat_utils import ChatTemplateConfig
from vllm.entrypoints.pooling.base.io_processor import PoolingIOProcessor from vllm.entrypoints.pooling.base.io_processor import PoolingIOProcessor
from vllm.entrypoints.pooling.scoring.io_processor import ScoringIOProcessors
from vllm.entrypoints.pooling.utils import enable_scoring_api
from vllm.renderers import BaseRenderer from vllm.renderers import BaseRenderer
from vllm.tasks import SupportedTask from vllm.tasks import SupportedTask
...@@ -25,6 +27,11 @@ def init_pooling_io_processors( ...@@ -25,6 +27,11 @@ def init_pooling_io_processors(
processors.append(("embed", EmbedIOProcessor)) processors.append(("embed", EmbedIOProcessor))
if enable_scoring_api(supported_tasks, model_config):
score_type = model_config.score_type
if score_type is not None and score_type in ScoringIOProcessors:
processors.append((score_type, ScoringIOProcessors[score_type]))
return { return {
task: processor_cls( task: processor_cls(
model_config=model_config, model_config=model_config,
......
This diff is collapsed.
...@@ -3,21 +3,15 @@ ...@@ -3,21 +3,15 @@
from http import HTTPStatus from http import HTTPStatus
from fastapi import APIRouter, Depends, Request from fastapi import APIRouter, Depends, Request
from fastapi.responses import JSONResponse
from typing_extensions import assert_never
from vllm.entrypoints.openai.engine.protocol import ErrorResponse from vllm.entrypoints.openai.engine.protocol import ErrorResponse
from vllm.entrypoints.openai.utils import validate_json_request from vllm.entrypoints.openai.utils import validate_json_request
from vllm.entrypoints.pooling.score.protocol import (
RerankRequest,
RerankResponse,
ScoreRequest,
ScoreResponse,
)
from vllm.entrypoints.pooling.score.serving import ServingScores
from vllm.entrypoints.utils import load_aware_call, with_cancellation from vllm.entrypoints.utils import load_aware_call, with_cancellation
from vllm.logger import init_logger from vllm.logger import init_logger
from .protocol import RerankRequest, ScoreRequest
from .serving import ServingScores
router = APIRouter() router = APIRouter()
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -46,16 +40,7 @@ async def create_score(request: ScoreRequest, raw_request: Request): ...@@ -46,16 +40,7 @@ async def create_score(request: ScoreRequest, raw_request: Request):
if handler is None: if handler is None:
raise NotImplementedError("The model does not support Score API") raise NotImplementedError("The model does not support Score API")
generator = await handler.create_score(request, raw_request) return await handler(request, raw_request)
if isinstance(generator, ErrorResponse):
return JSONResponse(
content=generator.model_dump(), status_code=generator.error.code
)
elif isinstance(generator, ScoreResponse):
return JSONResponse(content=generator.model_dump())
assert_never(generator)
@router.post( @router.post(
...@@ -92,16 +77,7 @@ async def do_rerank(request: RerankRequest, raw_request: Request): ...@@ -92,16 +77,7 @@ async def do_rerank(request: RerankRequest, raw_request: Request):
if handler is None: if handler is None:
raise NotImplementedError("The model does not support Rerank (Score) API") raise NotImplementedError("The model does not support Rerank (Score) API")
generator = await handler.do_rerank(request, raw_request) return await handler(request, raw_request)
if isinstance(generator, ErrorResponse):
return JSONResponse(
content=generator.model_dump(), status_code=generator.error.code
)
elif isinstance(generator, RerankResponse):
return JSONResponse(content=generator.model_dump())
assert_never(generator)
@router.post( @router.post(
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import time
from collections.abc import Sequence
from typing import Any, TypeAlias, cast
import torch.nn.functional as F
from vllm import PoolingParams, PoolingRequestOutput, TokensPrompt
from vllm.entrypoints.pooling.base.io_processor import PoolingIOProcessor
from vllm.entrypoints.pooling.typing import (
OfflineInputsContext,
OfflineOutputsContext,
PoolingServeContext,
)
from vllm.inputs import EngineInput
from vllm.renderers import TokenizeParams
from vllm.renderers.hf import safe_apply_chat_template
from vllm.tasks import PoolingTask, ScoreType
from vllm.utils.mistral import is_mistral_tokenizer
from ...chat_utils import ChatTemplateResolutionError
from .protocol import RerankRequest, ScoreRequest, ScoringRequest
from .typing import ScoreData, ScoreInput, ScoringData
from .utils import (
compress_token_type_ids,
compute_maxsim_score,
parse_score_data,
score_data_to_prompts,
validate_score_input,
)
ScoringServeContext: TypeAlias = PoolingServeContext[ScoringRequest]
class ScoringIOProcessor(PoolingIOProcessor):
name: ScoreType
pooling_task: PoolingTask
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.tokenizer = self.renderer.get_tokenizer()
self.architecture = self.model_config.architecture
self.is_multimodal_model = self.model_config.is_multimodal_model
self.pad_token_id = self.tokenizer.pad_token_id
def create_pooling_params(self, request):
return request.to_pooling_params(self.pooling_task)
def valid_inputs(
self,
data_1: ScoreInput | list[ScoreInput],
data_2: ScoreInput | list[ScoreInput],
) -> ScoringData:
scoring_data = validate_score_input(
data_1,
data_2,
is_multimodal_model=self.is_multimodal_model,
architecture=self.architecture,
)
return scoring_data
class BiEncoderIOProcessor(ScoringIOProcessor):
name: ScoreType = "bi-encoder"
pooling_task: PoolingTask = "embed"
#######################################
# online APIs
def pre_process_online(self, ctx: ScoringServeContext):
request = ctx.request
if isinstance(request, ScoreRequest):
data_1 = request.data_1
data_2 = request.data_2
elif isinstance(request, RerankRequest):
data_1 = request.query
data_2 = request.documents
else:
raise ValueError(f"Invalid {self.name} request type")
scoring_data = self.valid_inputs(data_1, data_2)
tok_params = request.build_tok_params(self.model_config)
engine_inputs = self._pre_process(
scoring_data,
tok_params,
prompt_extras={
k: v
for k in ("mm_processor_kwargs", "cache_salt")
if (v := getattr(request, k, None)) is not None
},
)
ctx.engine_inputs = engine_inputs
ctx.intermediates = len(scoring_data.data_1)
def post_process_online(
self,
ctx: ScoringServeContext,
):
if ctx.final_res_batch is None:
raise ValueError("Final response batch not available")
if ctx.intermediates is None:
raise ValueError("data_1 len not available")
ctx.final_res_batch = self._post_process(
outputs=ctx.final_res_batch, offset=cast(int, ctx.intermediates)
)
#######################################
# offline APIs
def pre_process_offline(self, ctx: OfflineInputsContext) -> Sequence[EngineInput]:
assert isinstance(ctx.prompts, ScoringData)
tok_params = self.renderer.default_cmpl_tok_params.with_kwargs(
**(ctx.tokenization_kwargs or {})
)
return self._pre_process(ctx.prompts, tok_params)
def post_process_offline(
self,
ctx: OfflineOutputsContext,
) -> list[PoolingRequestOutput]:
assert ctx.offset is not None
return self._post_process(outputs=ctx.outputs, offset=ctx.offset)
#######################################
# helpers
def _pre_process(
self,
scoring_data: ScoringData,
tok_params: TokenizeParams,
prompt_extras: dict[str, Any] | None = None,
) -> Sequence[EngineInput]:
data_1 = score_data_to_prompts(scoring_data.data_1, "query", self.model_config)
data_2 = score_data_to_prompts(
scoring_data.data_2, "document", self.model_config
)
return self._preprocess_completion_offline(
prompts=data_1 + data_2, tok_params=tok_params, prompt_extras=prompt_extras
)
def _post_process(self, outputs: list[PoolingRequestOutput], offset: int):
emb_data_1 = outputs[:offset]
emb_data_2 = outputs[offset:]
if len(emb_data_1) == 1:
emb_data_1 = emb_data_1 * len(emb_data_2)
final_res_batch: list[PoolingRequestOutput] = []
for emb_1, emb_2 in zip(emb_data_1, emb_data_2):
pair_score = F.cosine_similarity(
emb_1.outputs.data.float(), emb_2.outputs.data.float(), dim=0
)
padding: list[int] = []
if self.pad_token_id is not None:
padding = [self.pad_token_id]
tokens = emb_1.prompt_token_ids + padding + emb_2.prompt_token_ids
final_res_batch.append(
PoolingRequestOutput(
request_id=f"{emb_1.request_id}_{emb_2.request_id}",
outputs=pair_score,
prompt_token_ids=tokens,
num_cached_tokens=emb_1.num_cached_tokens + emb_2.num_cached_tokens,
finished=True,
)
)
return final_res_batch
class LateInteractionIOProcessor(BiEncoderIOProcessor):
name: ScoreType = "late-interaction"
pooling_task: PoolingTask = "token_embed"
def _post_process(self, outputs: list[PoolingRequestOutput], offset: int):
# Split into query and document embeddings
emb_data_1 = outputs[:offset]
emb_data_2 = outputs[offset:]
# Expand queries if 1:N scoring
if len(emb_data_1) == 1:
emb_data_1 = emb_data_1 * len(emb_data_2)
final_res_batch: list[PoolingRequestOutput] = []
padding: list[int] = []
if (pad_token_id := self.pad_token_id) is not None:
padding = [pad_token_id]
# Compute MaxSim scores
for emb_1, emb_2 in zip(emb_data_1, emb_data_2):
# emb_1.outputs.data: [query_len, dim]
# emb_2.outputs.data: [doc_len, dim]
q_emb = emb_1.outputs.data
d_emb = emb_2.outputs.data
maxsim_score = compute_maxsim_score(q_emb, d_emb)
tokens = emb_1.prompt_token_ids + padding + emb_2.prompt_token_ids
final_res_batch.append(
PoolingRequestOutput(
request_id=f"{emb_1.request_id}_{emb_2.request_id}",
outputs=maxsim_score,
prompt_token_ids=tokens,
num_cached_tokens=emb_1.num_cached_tokens + emb_2.num_cached_tokens,
finished=True,
)
)
return final_res_batch
class CrossEncoderIOProcessor(ScoringIOProcessor):
name: ScoreType = "cross-encoder"
pooling_task: PoolingTask = "classify"
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if is_mistral_tokenizer(self.tokenizer):
raise ValueError("MistralTokenizer not supported for cross-encoding")
from vllm.model_executor.model_loader import get_model_cls
from vllm.model_executor.models.interfaces import supports_score_template
model = get_model_cls(self.model_config)
self.supports_score_template = supports_score_template(model)
self.model = model if self.supports_score_template else None
self.use_sep_token = self.model_config.use_sep_token
#######################################
# online APIs
def pre_process_online(self, ctx: ScoringServeContext):
request = ctx.request
if isinstance(request, ScoreRequest):
data_1 = request.data_1
data_2 = request.data_2
elif isinstance(request, RerankRequest):
data_1 = request.query
data_2 = request.documents
else:
raise ValueError(f"Invalid {self.name} request type")
scoring_data = self.valid_inputs(data_1, data_2)
tok_params = request.build_tok_params(self.model_config)
pooling_params = self.create_pooling_params(request)
engine_inputs, pooling_params_list = self._pre_process(
scoring_data,
tok_params,
pooling_params,
chat_template=self.chat_template,
prompt_extras={
k: v
for k in ("mm_processor_kwargs", "cache_salt")
if (v := getattr(request, k, None)) is not None
},
)
ctx.engine_inputs = engine_inputs
ctx.pooling_params = pooling_params_list
#######################################
# offline APIs
def pre_process_offline(self, ctx: OfflineInputsContext) -> Sequence[EngineInput]:
assert isinstance(ctx.prompts, ScoringData)
assert not isinstance(ctx.pooling_params, list)
tok_params = self.renderer.default_cmpl_tok_params.with_kwargs(
**(ctx.tokenization_kwargs or {})
)
engine_inputs, pooling_params_list = self._pre_process(
ctx.prompts, tok_params, ctx.pooling_params, ctx.chat_template
)
ctx.pooling_params = pooling_params_list
return engine_inputs
#######################################
# helpers
def _pre_process(
self,
scoring_data: ScoringData,
tok_params: TokenizeParams,
pooling_params: PoolingParams | None,
chat_template: str | None = None,
prompt_extras: dict[str, Any] | None = None,
) -> tuple[Sequence[EngineInput], list[PoolingParams]]:
# todo: support prompt_extras
arrival_time = time.time()
data_1 = scoring_data.data_1
data_2 = scoring_data.data_2
if len(data_1) == 1:
data_1 = data_1 * len(data_2)
if pooling_params is None:
pooling_params = PoolingParams(task="classify")
pooling_params_list = list[PoolingParams]()
engine_inputs = list[EngineInput]()
for q, d in zip(data_1, data_2):
_, engine_prompt = self.get_score_prompt(
data_1=q,
data_2=d,
encode_kwargs=tok_params.get_encode_kwargs(),
chat_template=chat_template,
)
if token_type_ids := engine_prompt.pop("token_type_ids", None):
params = pooling_params.clone()
compressed = compress_token_type_ids(token_type_ids)
params.extra_kwargs = {"compressed_token_type_ids": compressed}
pooling_params_list.append(params)
else:
pooling_params_list.append(pooling_params)
tok_params.apply_post_tokenization(self.tokenizer, engine_prompt)
engine_inputs.append(
self.renderer.process_for_engine(engine_prompt, arrival_time)
)
return engine_inputs, pooling_params_list
def get_score_prompt(
self,
data_1: ScoreData,
data_2: ScoreData,
encode_kwargs: dict[str, Any],
chat_template: str | None = None,
):
model_config = self.model_config
tokenizer = self.tokenizer
prompt_1, prompt_2, mm_data, mm_uuids = parse_score_data(
data_1,
data_2,
model_config,
)
def default_tokenizer_encode():
if self.supports_score_template:
assert self.model is not None
full_prompt = self.model.get_score_template(prompt_1, prompt_2)
if full_prompt is None:
raise ValueError("Get empty score template from model")
prompt_inputs = tokenizer(full_prompt, **encode_kwargs)
else:
if self.use_sep_token:
# cross_encoder models defaults to using separating token.
prompt_inputs = tokenizer(
text=prompt_1, text_pair=prompt_2, **encode_kwargs
)
full_prompt = tokenizer.decode(prompt_inputs["input_ids"])
else:
# `llm as reranker` defaults to not using separating token.
full_prompt = prompt_1 + prompt_2
prompt_inputs = tokenizer(text=full_prompt, **encode_kwargs)
return full_prompt, prompt_inputs
# FIXME: For now, we only apply a template when one is explicitly provided.
# We cannot rely on the tokenizer's chat template because many models
# inherit junk templates from their base LLM, which breaks both the models
# and the tests that use them.
if chat_template is None:
full_prompt, prompt_inputs = default_tokenizer_encode()
else:
# FIXME:
# Try applying a score template from the CLI arg or tokenizer_config.json
# If that fails because there is no such template,
# fall back to the default implementation.
try:
full_prompt = safe_apply_chat_template(
model_config,
tokenizer,
[
{"role": "query", "content": prompt_1},
{"role": "document", "content": prompt_2},
],
chat_template=chat_template,
tools=None,
tokenize=False,
)
prompt_inputs = tokenizer(full_prompt, **encode_kwargs)
except ChatTemplateResolutionError:
full_prompt, prompt_inputs = default_tokenizer_encode()
engine_prompt = TokensPrompt(prompt_token_ids=prompt_inputs["input_ids"])
if (token_type_ids := prompt_inputs.get("token_type_ids")) is not None:
engine_prompt["token_type_ids"] = token_type_ids
if self.model is not None:
self.model.post_process_tokens(engine_prompt)
if mm_data is not None:
engine_prompt["multi_modal_data"] = mm_data
if mm_uuids is not None:
engine_prompt["multi_modal_uuids"] = mm_uuids
return full_prompt, engine_prompt
ScoringIOProcessors: dict[ScoreType, type[ScoringIOProcessor]] = {
"bi-encoder": BiEncoderIOProcessor,
"late-interaction": LateInteractionIOProcessor,
"cross-encoder": CrossEncoderIOProcessor,
}
...@@ -12,15 +12,12 @@ from vllm.entrypoints.pooling.base.protocol import ( ...@@ -12,15 +12,12 @@ from vllm.entrypoints.pooling.base.protocol import (
ClassifyRequestMixin, ClassifyRequestMixin,
PoolingBasicRequestMixin, PoolingBasicRequestMixin,
) )
from vllm.entrypoints.pooling.score.utils import (
ScoreContentPartParam,
ScoreInput,
ScoreInputs,
)
from vllm.renderers import TokenizeParams from vllm.renderers import TokenizeParams
from vllm.tasks import PoolingTask from vllm.tasks import PoolingTask
from vllm.utils import random_uuid from vllm.utils import random_uuid
from .typing import ScoreContentPartParam, ScoreInput
class ScoreRequestMixin(PoolingBasicRequestMixin, ClassifyRequestMixin): class ScoreRequestMixin(PoolingBasicRequestMixin, ClassifyRequestMixin):
def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams: def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
...@@ -43,13 +40,13 @@ class ScoreRequestMixin(PoolingBasicRequestMixin, ClassifyRequestMixin): ...@@ -43,13 +40,13 @@ class ScoreRequestMixin(PoolingBasicRequestMixin, ClassifyRequestMixin):
class ScoreDataRequest(ScoreRequestMixin): class ScoreDataRequest(ScoreRequestMixin):
data_1: ScoreInputs data_1: ScoreInput | list[ScoreInput]
data_2: ScoreInputs data_2: ScoreInput | list[ScoreInput]
class ScoreQueriesDocumentsRequest(ScoreRequestMixin): class ScoreQueriesDocumentsRequest(ScoreRequestMixin):
queries: ScoreInputs queries: ScoreInput | list[ScoreInput]
documents: ScoreInputs documents: ScoreInput | list[ScoreInput]
@property @property
def data_1(self): def data_1(self):
...@@ -61,8 +58,8 @@ class ScoreQueriesDocumentsRequest(ScoreRequestMixin): ...@@ -61,8 +58,8 @@ class ScoreQueriesDocumentsRequest(ScoreRequestMixin):
class ScoreQueriesItemsRequest(ScoreRequestMixin): class ScoreQueriesItemsRequest(ScoreRequestMixin):
queries: ScoreInputs queries: ScoreInput | list[ScoreInput]
items: ScoreInputs items: ScoreInput | list[ScoreInput]
@property @property
def data_1(self): def data_1(self):
...@@ -74,8 +71,8 @@ class ScoreQueriesItemsRequest(ScoreRequestMixin): ...@@ -74,8 +71,8 @@ class ScoreQueriesItemsRequest(ScoreRequestMixin):
class ScoreTextRequest(ScoreRequestMixin): class ScoreTextRequest(ScoreRequestMixin):
text_1: ScoreInputs text_1: ScoreInput | list[ScoreInput]
text_2: ScoreInputs text_2: ScoreInput | list[ScoreInput]
@property @property
def data_1(self): def data_1(self):
...@@ -96,7 +93,7 @@ ScoreRequest: TypeAlias = ( ...@@ -96,7 +93,7 @@ ScoreRequest: TypeAlias = (
class RerankRequest(PoolingBasicRequestMixin, ClassifyRequestMixin): class RerankRequest(PoolingBasicRequestMixin, ClassifyRequestMixin):
query: ScoreInput query: ScoreInput
documents: ScoreInputs documents: ScoreInput | list[ScoreInput]
top_n: int = Field(default_factory=lambda: 0) top_n: int = Field(default_factory=lambda: 0)
def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams: def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
...@@ -118,6 +115,9 @@ class RerankRequest(PoolingBasicRequestMixin, ClassifyRequestMixin): ...@@ -118,6 +115,9 @@ class RerankRequest(PoolingBasicRequestMixin, ClassifyRequestMixin):
) )
ScoringRequest: TypeAlias = ScoreRequest | RerankRequest
class RerankDocument(BaseModel): class RerankDocument(BaseModel):
text: str | None = None text: str | None = None
multi_modal: list[ScoreContentPartParam] | None = None multi_modal: list[ScoreContentPartParam] | None = None
...@@ -154,3 +154,6 @@ class ScoreResponse(OpenAIBaseModel): ...@@ -154,3 +154,6 @@ class ScoreResponse(OpenAIBaseModel):
model: str model: str
data: list[ScoreResponseData] data: list[ScoreResponseData]
usage: UsageInfo usage: UsageInfo
ScoringResponse: TypeAlias = RerankResponse | ScoreResponse
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from fastapi.responses import JSONResponse
from vllm.config import ModelConfig
from vllm.entrypoints.chat_utils import ChatTemplateConfig
from vllm.entrypoints.openai.engine.protocol import UsageInfo
from vllm.entrypoints.pooling.base.io_processor import PoolingIOProcessor
from vllm.entrypoints.pooling.base.serving import PoolingServing
from vllm.logger import init_logger
from vllm.outputs import PoolingRequestOutput, ScoringRequestOutput
from vllm.renderers import BaseRenderer
from .io_processor import ScoringIOProcessors, ScoringServeContext
from .protocol import (
RerankDocument,
RerankRequest,
RerankResponse,
RerankResult,
RerankUsage,
ScoreRequest,
ScoreResponse,
ScoreResponseData,
)
from .typing import ScoreInput
logger = init_logger(__name__)
class ServingScores(PoolingServing):
request_id_prefix = "score"
def init_io_processor(
self,
model_config: ModelConfig,
renderer: BaseRenderer,
chat_template_config: ChatTemplateConfig,
) -> PoolingIOProcessor:
score_type = model_config.score_type
assert score_type in ScoringIOProcessors
processor_cls = ScoringIOProcessors[score_type]
return processor_cls(
model_config=model_config,
renderer=renderer,
chat_template_config=chat_template_config,
)
async def _build_response(
self,
ctx: ScoringServeContext,
) -> JSONResponse:
final_res_batch = ctx.final_res_batch
request_id = ctx.request_id
created_time = ctx.created_time
model_name = self.models.model_name()
if isinstance(ctx.request, ScoreRequest):
return self._request_output_to_score_response(
final_res_batch,
request_id,
created_time,
model_name,
)
elif isinstance(ctx.request, RerankRequest):
return self._request_output_to_rerank_response(
final_res_batch,
request_id,
model_name,
ctx.request.documents,
ctx.request.top_n if ctx.request.top_n > 0 else len(final_res_batch),
)
else:
raise NotImplementedError("")
def _request_output_to_score_response(
self,
final_res_batch: list[PoolingRequestOutput],
request_id: str,
created_time: int,
model_name: str,
) -> JSONResponse:
items: list[ScoreResponseData] = []
num_prompt_tokens = 0
for idx, final_res in enumerate(final_res_batch):
classify_res = ScoringRequestOutput.from_base(final_res)
item = ScoreResponseData(
index=idx,
score=classify_res.outputs.score,
)
prompt_token_ids = final_res.prompt_token_ids
items.append(item)
num_prompt_tokens += len(prompt_token_ids)
usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
total_tokens=num_prompt_tokens,
)
response = ScoreResponse(
id=request_id,
created=created_time,
model=model_name,
data=items,
usage=usage,
)
return JSONResponse(content=response.model_dump())
def _request_output_to_rerank_response(
self,
final_res_batch: list[PoolingRequestOutput],
request_id: str,
model_name: str,
documents: ScoreInput | list[ScoreInput],
top_n: int,
) -> JSONResponse:
if not isinstance(documents, list):
documents = [documents]
results: list[RerankResult] = []
num_prompt_tokens = 0
for idx, final_res in enumerate(final_res_batch):
classify_res = ScoringRequestOutput.from_base(final_res)
document = documents[idx]
if isinstance(document, str):
rerank_document = RerankDocument(text=document)
else:
rerank_document = RerankDocument(
multi_modal=document.get("content", [])
)
result = RerankResult(
index=idx,
document=rerank_document,
relevance_score=classify_res.outputs.score,
)
results.append(result)
prompt_token_ids = final_res.prompt_token_ids
num_prompt_tokens += len(prompt_token_ids)
# sort by relevance, then return the top n if set
results.sort(key=lambda x: x.relevance_score, reverse=True)
if top_n < len(documents):
results = results[:top_n]
response = RerankResponse(
id=request_id,
model=model_name,
results=results,
usage=RerankUsage(
total_tokens=num_prompt_tokens, prompt_tokens=num_prompt_tokens
),
)
return JSONResponse(content=response.model_dump())
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import TypeAlias
from typing_extensions import Required, TypedDict
from vllm.entrypoints.chat_utils import (
ChatCompletionContentPartImageEmbedsParam,
ChatCompletionContentPartImageParam,
ChatCompletionContentPartTextParam,
ChatCompletionContentPartVideoParam,
)
ScoreContentPartParam: TypeAlias = (
ChatCompletionContentPartImageParam
| ChatCompletionContentPartImageEmbedsParam
| ChatCompletionContentPartTextParam
| ChatCompletionContentPartVideoParam
)
class ScoreMultiModalParam(TypedDict, total=False):
"""
A specialized parameter type for scoring multimodal content
The reasons why don't reuse `CustomChatCompletionMessageParam` directly:
1. Score tasks don't need the 'role' field (user/assistant/system) that's required in chat completions
2. Including chat-specific fields would confuse users about their purpose in scoring
3. This is a more focused interface that only exposes what's needed for scoring
""" # noqa: E501
content: Required[list[ScoreContentPartParam]]
"""The multimodal contents"""
# Raw input data with content key in ScoreMultiModalParam.
ScoreInput = str | ScoreMultiModalParam
# Score data without content key.
ScoreData = str | list[ScoreContentPartParam]
@dataclass
class ScoringData:
data_1: list[ScoreData]
data_2: list[ScoreData]
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable from collections.abc import Iterable
from typing import Any, TypeAlias, cast from typing import cast
import torch import torch
from torch.nn import CosineSimilarity
from typing_extensions import Required, TypedDict
from vllm import PromptType, TextPrompt
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.entrypoints.chat_utils import ( from vllm.entrypoints.chat_utils import (
BaseMultiModalItemTracker, BaseMultiModalItemTracker,
ChatCompletionContentPartImageEmbedsParam,
ChatCompletionContentPartImageParam,
ChatCompletionContentPartParam, ChatCompletionContentPartParam,
ChatCompletionContentPartTextParam, ChatCompletionContentPartTextParam,
ChatCompletionContentPartVideoParam,
ChatTemplateResolutionError,
ConversationMessage, ConversationMessage,
MultiModalItemTracker, MultiModalItemTracker,
_parse_chat_message_content_parts, _parse_chat_message_content_parts,
) )
from vllm.inputs import ( from vllm.inputs import MultiModalDataDict, MultiModalUUIDDict
MultiModalDataDict,
MultiModalUUIDDict, from .typing import (
PromptType, ScoreContentPartParam,
TextPrompt, ScoreData,
TokensPrompt, ScoreInput,
) ScoringData,
from vllm.model_executor.models.interfaces import supports_score_template
from vllm.outputs import PoolingRequestOutput
from vllm.renderers.hf import safe_apply_chat_template
from vllm.tokenizers import TokenizerLike
ScoreContentPartParam: TypeAlias = (
ChatCompletionContentPartImageParam
| ChatCompletionContentPartImageEmbedsParam
| ChatCompletionContentPartTextParam
| ChatCompletionContentPartVideoParam
) )
...@@ -57,72 +42,6 @@ def compute_maxsim_score(q_emb: torch.Tensor, d_emb: torch.Tensor) -> torch.Tens ...@@ -57,72 +42,6 @@ def compute_maxsim_score(q_emb: torch.Tensor, d_emb: torch.Tensor) -> torch.Tens
return token_scores.amax(dim=-1).sum() return token_scores.amax(dim=-1).sum()
class ScoreMultiModalParam(TypedDict, total=False):
"""
A specialized parameter type for scoring multimodal content
The reasons why don't reuse `CustomChatCompletionMessageParam` directly:
1. Score tasks don't need the 'role' field (user/assistant/system) that's required in chat completions
2. Including chat-specific fields would confuse users about their purpose in scoring
3. This is a more focused interface that only exposes what's needed for scoring
""" # noqa: E501
content: Required[list[ScoreContentPartParam]]
"""The multimodal contents"""
# Raw input data with content key in ScoreMultiModalParam.
ScoreInput = str | ScoreMultiModalParam
ScoreInputs = ScoreInput | list[ScoreInput]
# Score data without content key.
ScoreData = str | list[ScoreContentPartParam]
def _cosine_similarity(
tokenizer: TokenizerLike,
embed_1: list[PoolingRequestOutput],
embed_2: list[PoolingRequestOutput],
) -> list[PoolingRequestOutput]:
scorer = CosineSimilarity(0)
scores: list[PoolingRequestOutput] = []
for emb_1, emb_2 in zip(embed_1, embed_2):
pair_score = scorer(emb_1.outputs.data, emb_2.outputs.data)
padding: list[int] = []
if (pad_token_id := tokenizer.pad_token_id) is not None:
padding = [pad_token_id]
tokens = emb_1.prompt_token_ids + padding + emb_2.prompt_token_ids
scores.append(
PoolingRequestOutput(
request_id=f"{emb_1.request_id}_{emb_2.request_id}",
outputs=pair_score,
prompt_token_ids=tokens,
num_cached_tokens=emb_1.num_cached_tokens + emb_2.num_cached_tokens,
finished=True,
)
)
return scores
def _validate_score_input_lens(
data_1: list[ScoreData],
data_2: list[ScoreData],
):
len_1 = len(data_1)
len_2 = len(data_2)
if len_1 > 1 and len_1 != len_2:
raise ValueError("Input lengths must be either 1:1, 1:N or N:N")
if len_1 == 0:
raise ValueError("At least one text element must be given")
if len_2 == 0:
raise ValueError("At least one text_pair element must be given")
def _validate_mm_score_input( def _validate_mm_score_input(
data: list[ScoreInput], data: list[ScoreInput],
is_multimodal_model: bool, is_multimodal_model: bool,
...@@ -140,12 +59,27 @@ def _validate_mm_score_input( ...@@ -140,12 +59,27 @@ def _validate_mm_score_input(
return out return out
def _validate_score_input_lens(
data_1: list[ScoreData],
data_2: list[ScoreData],
):
len_1 = len(data_1)
len_2 = len(data_2)
if len_1 > 1 and len_1 != len_2:
raise ValueError("Input lengths must be either 1:1, 1:N or N:N")
if len_1 == 0:
raise ValueError("At least one text element must be given")
if len_2 == 0:
raise ValueError("At least one text_pair element must be given")
def validate_score_input( def validate_score_input(
data_1: ScoreInputs, data_1: ScoreInput | list[ScoreInput],
data_2: ScoreInputs, data_2: ScoreInput | list[ScoreInput],
is_multimodal_model: bool, is_multimodal_model: bool,
architecture: str, architecture: str,
) -> tuple[list[ScoreData], list[ScoreData]]: ) -> ScoringData:
if not isinstance(data_1, list): if not isinstance(data_1, list):
data_1 = [data_1] data_1 = [data_1]
...@@ -155,62 +89,7 @@ def validate_score_input( ...@@ -155,62 +89,7 @@ def validate_score_input(
score_input_1 = _validate_mm_score_input(data_1, is_multimodal_model, architecture) score_input_1 = _validate_mm_score_input(data_1, is_multimodal_model, architecture)
score_input_2 = _validate_mm_score_input(data_2, is_multimodal_model, architecture) score_input_2 = _validate_mm_score_input(data_2, is_multimodal_model, architecture)
_validate_score_input_lens(score_input_1, score_input_2) _validate_score_input_lens(score_input_1, score_input_2)
return score_input_1, score_input_2 return ScoringData(data_1=score_input_1, data_2=score_input_2)
def _ensure_str(content: list[ConversationMessage]) -> str:
"""Extract a single string prompt from parsed conversation content."""
assert len(content) == 1
prompt = content[0]["content"]
if prompt is not None and isinstance(prompt, str):
return cast(str, prompt)
raise ValueError(f"Only string content is supported, but got {content}.")
def parse_score_data(
data_1: ScoreData,
data_2: ScoreData,
model_config: ModelConfig,
) -> tuple[str, str, MultiModalDataDict | None, MultiModalUUIDDict | None]:
"""Parse a query-document pair into text prompts and shared multi-modal
data.
Uses a **single** :class:`MultiModalItemTracker` so that multi-modal
items from both inputs are merged into one ``mm_data`` dict. This is
the correct behaviour for cross-encoder scoring, where query and
document are concatenated into a single model prompt.
"""
mm_tracker = MultiModalItemTracker(model_config)
content_1 = _parse_score_content("query", data_1, mm_tracker)
content_2 = _parse_score_content("document", data_2, mm_tracker)
prompt_1 = _ensure_str(content_1)
prompt_2 = _ensure_str(content_2)
mm_items, mm_uuids = mm_tracker.resolve_items()
return prompt_1, prompt_2, mm_items, mm_uuids
def parse_score_data_single(
data: ScoreData,
role: str,
model_config: ModelConfig,
) -> tuple[str, MultiModalDataDict | None, MultiModalUUIDDict | None]:
"""Parse **one** ScoreData into a text prompt and its own multi-modal
data.
Unlike :func:`parse_score_data`, each call creates an **independent**
:class:`MultiModalItemTracker` so multi-modal items are kept separate.
This is the correct behaviour for late-interaction scoring, where
query and document are encoded independently.
"""
mm_tracker = MultiModalItemTracker(model_config)
content = _parse_score_content(role, data, mm_tracker)
prompt = _ensure_str(content)
mm_items, mm_uuids = mm_tracker.resolve_items()
return prompt, mm_items, mm_uuids
def score_data_to_prompts( def score_data_to_prompts(
...@@ -243,6 +122,15 @@ def score_data_to_prompts( ...@@ -243,6 +122,15 @@ def score_data_to_prompts(
return prompts return prompts
def _ensure_str(content: list[ConversationMessage]) -> str:
"""Extract a single string prompt from parsed conversation content."""
assert len(content) == 1
prompt = content[0]["content"]
if prompt is not None and isinstance(prompt, str):
return cast(str, prompt)
raise ValueError(f"Only string content is supported, but got {content}.")
def _parse_score_content( def _parse_score_content(
role: str, role: str,
data: ScoreData, data: ScoreData,
...@@ -278,113 +166,50 @@ def _parse_score_content( ...@@ -278,113 +166,50 @@ def _parse_score_content(
return next(iter(mm_placeholder_storage.values()))[0] return next(iter(mm_placeholder_storage.values()))[0]
def _apply_model_score_template( def parse_score_data_single(
model_config: ModelConfig, prompt_1: str, prompt_2: str data: ScoreData,
) -> str: role: str,
# NOTE(Simon): lazy import to avoid bring in all dependencies (e.g. gguf)
from vllm.model_executor.model_loader import get_model_cls
model = get_model_cls(model_config)
if supports_score_template(model):
full_prompt = model.get_score_template(prompt_1, prompt_2)
if full_prompt is None:
raise ValueError("Get empty score template from model")
return full_prompt
raise ValueError(f"Unsupported model architecture: {model_config.architecture}")
def post_process_tokens(
model_config: ModelConfig, model_config: ModelConfig,
prompt: TokensPrompt, ) -> tuple[str, MultiModalDataDict | None, MultiModalUUIDDict | None]:
) -> None: """Parse **one** ScoreData into a text prompt and its own multi-modal
""" data.
Perform architecture-specific manipulations on the input tokens.
Note: Unlike :func:`parse_score_data`, each call creates an **independent**
This is an in-place operation. :class:`MultiModalItemTracker` so multi-modal items are kept separate.
This is the correct behaviour for late-interaction scoring, where
query and document are encoded independently.
""" """
# NOTE(Simon): lazy import to avoid bring in all dependencies (e.g. gguf) mm_tracker = MultiModalItemTracker(model_config)
from vllm.model_executor.model_loader import get_model_cls content = _parse_score_content(role, data, mm_tracker)
model = get_model_cls(model_config) prompt = _ensure_str(content)
if supports_score_template(model): mm_items, mm_uuids = mm_tracker.resolve_items()
model.post_process_tokens(prompt) return prompt, mm_items, mm_uuids
def get_score_prompt( def parse_score_data(
model_config: ModelConfig,
tokenizer: TokenizerLike,
tokenization_kwargs: dict[str, Any],
data_1: ScoreData, data_1: ScoreData,
data_2: ScoreData, data_2: ScoreData,
score_template: str | None = None, model_config: ModelConfig,
) -> tuple[str, TokensPrompt]: ) -> tuple[str, str, MultiModalDataDict | None, MultiModalUUIDDict | None]:
prompt_1, prompt_2, mm_data, mm_uuids = parse_score_data( """Parse a query-document pair into text prompts and shared multi-modal
data_1, data.
data_2,
model_config,
)
from vllm.model_executor.model_loader import get_model_cls
model = get_model_cls(model_config) Uses a **single** :class:`MultiModalItemTracker` so that multi-modal
items from both inputs are merged into one ``mm_data`` dict. This is
the correct behaviour for cross-encoder scoring, where query and
document are concatenated into a single model prompt.
"""
mm_tracker = MultiModalItemTracker(model_config)
def default_tokenizer_encode(): content_1 = _parse_score_content("query", data_1, mm_tracker)
if supports_score_template(model): content_2 = _parse_score_content("document", data_2, mm_tracker)
full_prompt = _apply_model_score_template(model_config, prompt_1, prompt_2)
prompt_inputs = tokenizer(full_prompt, **tokenization_kwargs) prompt_1 = _ensure_str(content_1)
else: prompt_2 = _ensure_str(content_2)
if model_config.use_sep_token: mm_items, mm_uuids = mm_tracker.resolve_items()
# cross_encoder models defaults to using separating token.
prompt_inputs = tokenizer( return prompt_1, prompt_2, mm_items, mm_uuids
text=prompt_1, text_pair=prompt_2, **tokenization_kwargs
)
full_prompt = tokenizer.decode(prompt_inputs["input_ids"])
else:
# `llm as reranker` defaults to not using separating token.
full_prompt = prompt_1 + prompt_2
prompt_inputs = tokenizer(text=full_prompt, **tokenization_kwargs)
return full_prompt, prompt_inputs
# FIXME: For now, we only apply a template when one is explicitly provided.
# We cannot rely on the tokenizer's chat template because many models
# inherit junk templates from their base LLM, which breaks both the models
# and the tests that use them.
if score_template is None:
full_prompt, prompt_inputs = default_tokenizer_encode()
else:
# FIXME: Try applying a score template from the CLI arg or tokenizer_config.json
# If that fails because there is no such template,
# fall back to the default implementation.
try:
full_prompt = safe_apply_chat_template(
model_config,
tokenizer,
[
{"role": "query", "content": prompt_1},
{"role": "document", "content": prompt_2},
],
chat_template=score_template,
tools=None,
tokenize=False,
)
prompt_inputs = tokenizer(full_prompt, **tokenization_kwargs)
except ChatTemplateResolutionError:
full_prompt, prompt_inputs = default_tokenizer_encode()
engine_prompt = TokensPrompt(prompt_token_ids=prompt_inputs["input_ids"])
if (token_type_ids := prompt_inputs.get("token_type_ids")) is not None:
engine_prompt["token_type_ids"] = token_type_ids
post_process_tokens(model_config, engine_prompt)
if mm_data is not None:
engine_prompt["multi_modal_data"] = mm_data
if mm_uuids is not None:
engine_prompt["multi_modal_uuids"] = mm_uuids
return full_prompt, engine_prompt
def compress_token_type_ids(token_type_ids: list[int]) -> int: def compress_token_type_ids(token_type_ids: list[int]) -> int:
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import time import time
from collections.abc import AsyncGenerator from collections.abc import AsyncGenerator, Sequence
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any, Generic, TypeAlias, TypeVar from typing import Any, Generic, TypeAlias, TypeVar
from fastapi import Request from fastapi import Request
from pydantic import ConfigDict from pydantic import ConfigDict
from vllm import PoolingRequestOutput from vllm import PoolingParams, PoolingRequestOutput, PromptType
from vllm.entrypoints.pooling.classify.protocol import ( from vllm.entrypoints.pooling.classify.protocol import (
ClassificationChatRequest, ClassificationChatRequest,
ClassificationCompletionRequest, ClassificationCompletionRequest,
...@@ -23,15 +23,13 @@ from vllm.entrypoints.pooling.embed.protocol import ( ...@@ -23,15 +23,13 @@ from vllm.entrypoints.pooling.embed.protocol import (
) )
from vllm.entrypoints.pooling.pooling.protocol import ( from vllm.entrypoints.pooling.pooling.protocol import (
IOProcessorRequest, IOProcessorRequest,
PoolingBytesResponse,
PoolingChatRequest, PoolingChatRequest,
PoolingCompletionRequest, PoolingCompletionRequest,
PoolingResponse, PoolingResponse,
) )
from vllm.entrypoints.pooling.score.protocol import ( from vllm.entrypoints.pooling.scoring.protocol import ScoringRequest, ScoringResponse
RerankRequest, from vllm.entrypoints.pooling.scoring.typing import ScoringData
ScoreRequest,
ScoreResponse,
)
from vllm.inputs import EngineInput from vllm.inputs import EngineInput
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
...@@ -49,8 +47,7 @@ AnyPoolingRequest: TypeAlias = ( ...@@ -49,8 +47,7 @@ AnyPoolingRequest: TypeAlias = (
PoolingCompletionLikeRequest PoolingCompletionLikeRequest
| PoolingChatLikeRequest | PoolingChatLikeRequest
| IOProcessorRequest | IOProcessorRequest
| RerankRequest | ScoringRequest
| ScoreRequest
| CohereEmbedRequest | CohereEmbedRequest
) )
...@@ -59,7 +56,8 @@ AnyPoolingResponse: TypeAlias = ( ...@@ -59,7 +56,8 @@ AnyPoolingResponse: TypeAlias = (
| EmbeddingResponse | EmbeddingResponse
| EmbeddingBytesResponse | EmbeddingBytesResponse
| PoolingResponse | PoolingResponse
| ScoreResponse | PoolingBytesResponse
| ScoringResponse
) )
PoolingRequestT = TypeVar("PoolingRequestT", bound=AnyPoolingRequest) PoolingRequestT = TypeVar("PoolingRequestT", bound=AnyPoolingRequest)
...@@ -73,8 +71,8 @@ class PoolingServeContext(Generic[PoolingRequestT]): ...@@ -73,8 +71,8 @@ class PoolingServeContext(Generic[PoolingRequestT]):
request_id: str request_id: str
created_time: int = field(default_factory=lambda: int(time.time())) created_time: int = field(default_factory=lambda: int(time.time()))
lora_request: LoRARequest | None = None lora_request: LoRARequest | None = None
pooling_params: PoolingParams | list[PoolingParams] | None = None
engine_inputs: list[EngineInput] | None = None engine_inputs: Sequence[EngineInput] | None = None
prompt_request_ids: list[str] | None = None prompt_request_ids: list[str] | None = None
intermediates: Any | None = None intermediates: Any | None = None
...@@ -84,3 +82,22 @@ class PoolingServeContext(Generic[PoolingRequestT]): ...@@ -84,3 +82,22 @@ class PoolingServeContext(Generic[PoolingRequestT]):
final_res_batch: list[PoolingRequestOutput] = field(default_factory=list) final_res_batch: list[PoolingRequestOutput] = field(default_factory=list)
model_config = ConfigDict(arbitrary_types_allowed=True) model_config = ConfigDict(arbitrary_types_allowed=True)
@dataclass
class OfflineInputsContext:
prompts: PromptType | Sequence[PromptType] | ScoringData
pooling_params: PoolingParams | list[PoolingParams] | None = None
tokenization_kwargs: dict[str, Any] | None = None
chat_template: str | None = None
## for bi-encoder & late-interaction
offset: int | None = None
@dataclass
class OfflineOutputsContext:
outputs: list[PoolingRequestOutput]
## for bi-encoder & late-interaction
offset: int | None = None
...@@ -11,8 +11,10 @@ import pybase64 ...@@ -11,8 +11,10 @@ import pybase64
import torch import torch
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
from vllm.config import ModelConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.outputs import PoolingRequestOutput from vllm.outputs import PoolingRequestOutput
from vllm.tasks import SupportedTask
from vllm.utils.serial_utils import ( from vllm.utils.serial_utils import (
EMBED_DTYPES, EMBED_DTYPES,
EmbedDType, EmbedDType,
...@@ -133,3 +135,20 @@ def get_json_response_cls() -> type[JSONResponse]: ...@@ -133,3 +135,20 @@ def get_json_response_cls() -> type[JSONResponse]:
"To make v1/embeddings API fast, please install orjson by `pip install orjson`" "To make v1/embeddings API fast, please install orjson by `pip install orjson`"
) )
return JSONResponse return JSONResponse
def enable_scoring_api(
supported_tasks: tuple["SupportedTask", ...],
model_config: ModelConfig | None = None,
) -> bool:
if any(t in supported_tasks for t in ("embed", "token_embed")):
return True
if model_config is not None and "classify" in supported_tasks:
num_labels = getattr(model_config.hf_config, "num_labels", 0)
if num_labels != 1:
logger.debug_once("Scoring API is only enabled for num_labels == 1.")
return False
return True
return False
...@@ -14,8 +14,8 @@ from vllm.config import ModelConfig ...@@ -14,8 +14,8 @@ from vllm.config import ModelConfig
from vllm.entrypoints.openai.engine.protocol import ErrorResponse from vllm.entrypoints.openai.engine.protocol import ErrorResponse
from vllm.entrypoints.openai.engine.serving import OpenAIServing from vllm.entrypoints.openai.engine.serving import OpenAIServing
from vllm.entrypoints.openai.utils import validate_json_request from vllm.entrypoints.openai.utils import validate_json_request
from vllm.entrypoints.pooling import enable_scoring_api
from vllm.entrypoints.pooling.base.serving import PoolingServing from vllm.entrypoints.pooling.base.serving import PoolingServing
from vllm.entrypoints.pooling.utils import enable_scoring_api
from vllm.entrypoints.serve.instrumentator.basic import base from vllm.entrypoints.serve.instrumentator.basic import base
from vllm.entrypoints.serve.instrumentator.health import health from vllm.entrypoints.serve.instrumentator.health import health
from vllm.tasks import POOLING_TASKS, SupportedTask from vllm.tasks import POOLING_TASKS, SupportedTask
...@@ -76,15 +76,15 @@ def get_invocation_types( ...@@ -76,15 +76,15 @@ def get_invocation_types(
] ]
if enable_scoring_api(supported_tasks, model_config): if enable_scoring_api(supported_tasks, model_config):
from vllm.entrypoints.pooling.score.api_router import do_rerank, rerank from vllm.entrypoints.pooling.scoring.api_router import do_rerank, rerank
from vllm.entrypoints.pooling.score.protocol import RerankRequest from vllm.entrypoints.pooling.scoring.protocol import RerankRequest
INVOCATION_TYPES += [ INVOCATION_TYPES += [
(RerankRequest, (rerank, do_rerank)), (RerankRequest, (rerank, do_rerank)),
] ]
from vllm.entrypoints.pooling.score.api_router import create_score, score from vllm.entrypoints.pooling.scoring.api_router import create_score, score
from vllm.entrypoints.pooling.score.protocol import ScoreRequest from vllm.entrypoints.pooling.scoring.protocol import ScoreRequest
INVOCATION_TYPES += [ INVOCATION_TYPES += [
(ScoreRequest, (score, create_score)), (ScoreRequest, (score, create_score)),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment