Unverified commit d9d21eb8, authored by wang.yuqi, committed by GitHub
Browse files

[Frontend][3/n] Improve pooling entrypoints | scoring. (#28631)


Signed-off-by: wang.yuqi <yuqi.wang@daocloud.io>
parent f09daea2
......@@ -11,9 +11,7 @@ from typing import Any, ClassVar, Generic, Protocol, TypeAlias, TypeVar
import numpy as np
from fastapi import Request
from openai.types.responses import (
ToolChoiceFunction,
)
from openai.types.responses import ToolChoiceFunction
from pydantic import ConfigDict, TypeAdapter, ValidationError
from starlette.datastructures import Headers
......@@ -21,9 +19,7 @@ import vllm.envs as envs
from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function
from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import (
ChatTemplateContentFormatOption,
)
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.chat_completion.protocol import (
BatchChatCompletionRequest,
......@@ -42,9 +38,7 @@ from vllm.entrypoints.openai.engine.protocol import (
GenerationError,
)
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.entrypoints.openai.responses.protocol import (
ResponsesRequest,
)
from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
from vllm.entrypoints.openai.speech_to_text.protocol import (
TranscriptionRequest,
TranscriptionResponse,
......@@ -56,14 +50,6 @@ from vllm.entrypoints.pooling.pooling.protocol import (
PoolingCompletionRequest,
PoolingResponse,
)
from vllm.entrypoints.pooling.score.protocol import (
RerankRequest,
ScoreDataRequest,
ScoreQueriesDocumentsRequest,
ScoreRequest,
ScoreResponse,
ScoreTextRequest,
)
from vllm.entrypoints.serve.disagg.protocol import GenerateRequest, GenerateResponse
from vllm.entrypoints.serve.tokenize.protocol import (
DetokenizeRequest,
......@@ -72,8 +58,7 @@ from vllm.entrypoints.serve.tokenize.protocol import (
TokenizeResponse,
)
from vllm.entrypoints.utils import create_error_response
from vllm.exceptions import VLLMValidationError
from vllm.inputs import EngineInput, PromptType, TokensPrompt
from vllm.inputs import EngineInput, PromptType
from vllm.logger import init_logger
from vllm.logprobs import Logprob, PromptLogprobs
from vllm.lora.request import LoRARequest
......@@ -119,8 +104,6 @@ CompletionLikeRequest: TypeAlias = (
CompletionRequest
| TokenizeCompletionRequest
| DetokenizeRequest
| RerankRequest
| ScoreRequest
| PoolingCompletionRequest
)
......@@ -148,7 +131,6 @@ AnyResponse: TypeAlias = (
| TranscriptionResponse
| TokenizeResponse
| PoolingResponse
| ScoreResponse
| GenerateResponse
)
......@@ -692,88 +674,6 @@ class OpenAIServing:
message_types.add(content_dict["type"].split("_")[0])
return message_types
def _validate_input(
    self,
    request: object,
    input_ids: list[int],
    input_text: str,
) -> TokensPrompt:
    """Check a tokenized prompt against the model's context length.

    Three families of request are handled differently:

    * score / rerank requests: the input may fill the whole context
      (``token_num <= max_model_len``) because nothing is generated;
    * tokenize / detokenize requests: no length validation at all;
    * everything else: the input must leave room for at least one output
      token, and ``input + max_tokens`` must fit in the context.

    Args:
        request: The incoming request object; only its type and, for
            generation requests, its ``max_tokens`` /
            ``max_completion_tokens`` attributes are inspected.
        input_ids: Token ids of the prompt.
        input_text: Text of the prompt, stored alongside the token ids.

    Returns:
        A ``TokensPrompt`` built from ``input_text`` and ``input_ids``.

    Raises:
        VLLMValidationError: If the prompt (plus requested output tokens,
            where applicable) does not fit in ``max_model_len``.
    """
    token_num = len(input_ids)
    max_model_len = self.model_config.max_model_len

    # Note: ScoreRequest doesn't have max_tokens
    if isinstance(
        request,
        (
            ScoreDataRequest,
            ScoreTextRequest,
            ScoreQueriesDocumentsRequest,
            RerankRequest,
        ),
    ):
        # Note: input length can be up to the entire model context length
        # since these requests don't generate tokens.
        if token_num > max_model_len:
            # Name the operation from the request family. An exact-type
            # table lookup would mislabel rerank requests (and subclasses
            # of the score requests) as "embedding generation".
            operation = "rerank" if isinstance(request, RerankRequest) else "score"
            raise VLLMValidationError(
                f"This model's maximum context length is "
                f"{max_model_len} tokens. However, you requested "
                f"{token_num} tokens in the input for {operation}. "
                f"Please reduce the length of the input prompt.",
                parameter="input_tokens",
                value=token_num,
            )

    # Note: TokenizeRequest and DetokenizeRequest doesn't have max_tokens
    # and does not require model context length validation
    elif not isinstance(
        request,
        (TokenizeCompletionRequest, TokenizeChatRequest, DetokenizeRequest),
    ):
        # chat completion endpoint supports max_completion_tokens
        if isinstance(request, ChatCompletionRequest):
            # TODO(#9845): remove max_tokens when field dropped from OpenAI API
            max_tokens = request.max_completion_tokens or request.max_tokens
        else:
            max_tokens = getattr(request, "max_tokens", None)

        # Note: input length can be up to model context length - 1 for
        # completion-like requests.
        if token_num >= max_model_len:
            raise VLLMValidationError(
                f"This model's maximum context length is "
                f"{max_model_len} tokens. However, your request has "
                f"{token_num} input tokens. Please reduce the length of "
                "the input messages.",
                parameter="input_tokens",
                value=token_num,
            )

        if max_tokens is not None and token_num + max_tokens > max_model_len:
            raise VLLMValidationError(
                f"This model's maximum context length is "
                f"{max_model_len} tokens. However, you requested "
                f"{max_tokens} output tokens and your prompt contains "
                f"{token_num} input tokens, for a total of "
                f"{token_num + max_tokens} tokens "
                f"({token_num} + {max_tokens} = "
                f"{token_num + max_tokens} > {max_model_len}). "
                f"Please reduce the length of the input prompt or the "
                f"number of requested output tokens.",
                parameter="max_tokens",
                value=max_tokens,
            )

    # Every branch that passes validation yields the same prompt object.
    return TokensPrompt(prompt=input_text, prompt_token_ids=input_ids)
def _validate_chat_template(
self,
request_chat_template: str | None,
......
......@@ -2,6 +2,8 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import contextlib
import json
import sys
import tempfile
from argparse import Namespace
......@@ -13,12 +15,14 @@ from urllib.parse import urlparse
import aiohttp
import pybase64 as base64
import pydantic
import torch
from fastapi import UploadFile
from prometheus_client import start_http_server
from pydantic import Field, TypeAdapter, field_validator, model_validator
from pydantic_core.core_schema import ValidationInfo
from starlette.datastructures import State
from starlette.responses import JSONResponse
from tqdm import tqdm
from urllib3.util import parse_url
......@@ -49,7 +53,7 @@ from vllm.entrypoints.pooling.embed.protocol import (
EmbeddingRequest,
EmbeddingResponse,
)
from vllm.entrypoints.pooling.score.protocol import (
from vllm.entrypoints.pooling.scoring.protocol import (
RerankRequest,
RerankResponse,
ScoreRequest,
......@@ -180,6 +184,18 @@ class BatchRequestInput(OpenAIBaseModel):
return TypeAdapter(BatchRequestInputBody).validate_python(value)
# Union of the response models that a batch request can carry as its body.
# In this file it is used as the type of `BatchResponseData.body` and as the
# target of the `isinstance` / `TypeAdapter` checks in `run_request`.
AllResponse: TypeAlias = (
ChatCompletionResponse
| EmbeddingResponse
| ScoreResponse
| RerankResponse
| TranscriptionResponse
| TranscriptionResponseVerbose
| TranslationResponse
| TranslationResponseVerbose
)
class BatchResponseData(OpenAIBaseModel):
# HTTP status code of the response.
status_code: int = 200
......@@ -188,17 +204,7 @@ class BatchResponseData(OpenAIBaseModel):
request_id: str
# The body of the response.
body: (
ChatCompletionResponse
| EmbeddingResponse
| ScoreResponse
| RerankResponse
| TranscriptionResponse
| TranscriptionResponseVerbose
| TranslationResponse
| TranslationResponseVerbose
| None
) = None
body: AllResponse | None = None
class BatchRequestOutput(OpenAIBaseModel):
......@@ -536,19 +542,13 @@ async def run_request(
except Exception as e:
response = create_error_response(e)
if isinstance(
response,
(
ChatCompletionResponse,
EmbeddingResponse,
ScoreResponse,
RerankResponse,
TranscriptionResponse,
TranscriptionResponseVerbose,
TranslationResponse,
TranslationResponseVerbose,
),
):
if isinstance(response, JSONResponse):
with contextlib.suppress(pydantic.ValidationError):
response = TypeAdapter(AllResponse | ErrorResponse).validate_python(
json.loads(response.body)
)
if isinstance(response, AllResponse):
batch_output = BatchRequestOutput(
id=f"vllm-{random_uuid()}",
custom_id=request.custom_id,
......@@ -745,14 +745,14 @@ async def build_endpoint_registry(
"score": {
"url_matcher": lambda url: url.endswith("/score"),
"handler_getter": lambda: (
serving_scores.create_score if serving_scores is not None else None
serving_scores if serving_scores is not None else None
),
"wrapper_fn": None,
},
"rerank": {
"url_matcher": lambda url: url.endswith("/rerank"),
"handler_getter": lambda: (
serving_scores.do_rerank if serving_scores is not None else None
serving_scores if serving_scores is not None else None
),
"wrapper_fn": None,
},
......
......@@ -6,6 +6,7 @@ from typing import TYPE_CHECKING
from fastapi import FastAPI
from vllm.config import ModelConfig
from vllm.entrypoints.pooling.utils import enable_scoring_api
from vllm.logger import init_logger
if TYPE_CHECKING:
......@@ -23,23 +24,6 @@ else:
logger = init_logger(__name__)
def enable_scoring_api(
    supported_tasks: tuple["SupportedTask", ...],
    model_config: ModelConfig | None = None,
) -> bool:
    """Decide whether the score/rerank routes should be enabled.

    Args:
        supported_tasks: Task names the loaded model supports.
        model_config: Optional model configuration; only consulted for
            "classify" models, where ``hf_config.num_labels`` is read.

    Returns:
        ``True`` when the model supports "embed" or "token_embed", or when
        it supports "classify", a model config is given, and its
        ``num_labels`` equals 1. ``False`` otherwise.
    """
    if "embed" in supported_tasks or "token_embed" in supported_tasks:
        return True

    # Without a config there is no way to inspect num_labels, so a
    # classify-only model is rejected just like an unsupported one.
    if model_config is None or "classify" not in supported_tasks:
        return False

    if getattr(model_config.hf_config, "num_labels", 0) == 1:
        return True

    logger.debug_once("Score API is only enabled for num_labels == 1.")
    return False
def register_pooling_api_routers(
app: FastAPI,
supported_tasks: tuple["SupportedTask", ...],
......@@ -68,7 +52,7 @@ def register_pooling_api_routers(
app.include_router(embed_router)
if enable_scoring_api(supported_tasks, model_config):
from vllm.entrypoints.pooling.score.api_router import router as score_router
from vllm.entrypoints.pooling.scoring.api_router import router as score_router
app.include_router(score_router)
......@@ -84,7 +68,7 @@ def init_pooling_state(
from vllm.entrypoints.pooling.classify.serving import ServingClassification
from vllm.entrypoints.pooling.embed.serving import ServingEmbedding
from vllm.entrypoints.pooling.pooling.serving import OpenAIServingPooling
from vllm.entrypoints.pooling.score.serving import ServingScores
from vllm.entrypoints.pooling.scoring.serving import ServingScores
from vllm.tasks import POOLING_TASKS
model_config = engine_client.model_config
......@@ -136,8 +120,9 @@ def init_pooling_state(
engine_client,
state.openai_serving_models,
request_logger=request_logger,
score_template=resolved_chat_template,
log_error_stack=args.log_error_stack,
chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format,
trust_request_chat_template=args.trust_request_chat_template,
)
if enable_scoring_api(supported_tasks, model_config)
else None
......
......@@ -13,13 +13,16 @@ from vllm.entrypoints.chat_utils import (
ConversationMessage,
)
from vllm.entrypoints.openai.engine.serving import RendererChatRequest, RendererRequest
from vllm.entrypoints.pooling.scoring.typing import ScoringData
from vllm.entrypoints.pooling.typing import (
OfflineInputsContext,
OfflineOutputsContext,
PoolingChatLikeRequest,
PoolingCompletionLikeRequest,
PoolingServeContext,
)
from vllm.inputs import EngineInput, SingletonPrompt
from vllm.renderers import BaseRenderer, merge_kwargs
from vllm.renderers import BaseRenderer, TokenizeParams, merge_kwargs
from vllm.renderers.inputs.preprocess import parse_model_prompt, prompt_to_seq
from vllm.tool_parsers import ToolParser
from vllm.utils.mistral import is_mistral_tokenizer
......@@ -96,29 +99,29 @@ class PoolingIOProcessor:
#######################################
# offline APIs
def pre_process_offline(
self,
prompts: PromptType | Sequence[PromptType],
tokenization_kwargs: dict[str, Any] | None = None,
) -> Sequence[EngineInput]:
def pre_process_offline(self, ctx: OfflineInputsContext) -> Sequence[EngineInput]:
assert not isinstance(ctx.prompts, ScoringData)
tok_params = self.renderer.default_cmpl_tok_params.with_kwargs(
**(ctx.tokenization_kwargs or {})
)
return self._preprocess_completion_offline(
prompts=prompts, tokenization_kwargs=tokenization_kwargs
prompts=ctx.prompts, tok_params=tok_params
)
async def pre_process_offline_async(self, *args, **kwargs):
return self.pre_process_offline(*args, **kwargs)
async def pre_process_offline_async(self, ctx: OfflineInputsContext):
return self.pre_process_offline(ctx)
def post_process_offline(
self,
outputs: list[PoolingRequestOutput],
ctx: OfflineOutputsContext,
) -> list[PoolingRequestOutput]:
return outputs
return ctx.outputs
async def post_process_offline_async(
self,
outputs: list[PoolingRequestOutput],
ctx: OfflineOutputsContext,
) -> list[PoolingRequestOutput]:
return self.post_process_offline(outputs)
return self.post_process_offline(ctx)
#######################################
# helpers
......@@ -204,28 +207,21 @@ class PoolingIOProcessor:
def _preprocess_completion_offline(
self,
prompts: PromptType | Sequence[PromptType],
tokenization_kwargs: dict[str, Any] | None = None,
tok_params: TokenizeParams,
prompt_extras: dict[str, Any] | None = None,
) -> Sequence[EngineInput]:
renderer = self.renderer
model_config = self.model_config
prompts = prompt_to_seq(prompts)
parsed_prompts = [
(
prompt
if isinstance(prompt, bytes)
else parse_model_prompt(model_config, prompt)
else parse_model_prompt(self.model_config, prompt)
)
for prompt in prompts
]
tok_params = renderer.default_cmpl_tok_params.with_kwargs(
**(tokenization_kwargs or {})
)
return renderer.render_cmpl(
parsed_prompts,
tok_params,
return self.renderer.render_cmpl(
parsed_prompts, tok_params, prompt_extras=prompt_extras
)
def _validate_chat_template(
......
......@@ -117,8 +117,16 @@ class PoolingServing:
else await self._get_trace_headers(ctx.raw_request.headers)
)
pooling_params = self.io_processor.create_pooling_params(ctx.request)
pooling_params.verify(self.model_config)
if ctx.pooling_params is None:
pooling_params = self.io_processor.create_pooling_params(ctx.request)
else:
pooling_params = ctx.pooling_params
if isinstance(pooling_params, list):
for params in pooling_params:
params.verify(self.model_config)
else:
pooling_params.verify(self.model_config)
for i, engine_input in enumerate(ctx.engine_inputs):
prompt_request_id = (
......@@ -127,16 +135,22 @@ class PoolingServing:
else ctx.prompt_request_ids[i]
)
params = (
pooling_params[i]
if isinstance(pooling_params, list)
else pooling_params
)
self._log_inputs(
prompt_request_id,
engine_input,
params=pooling_params,
params=params,
lora_request=ctx.lora_request,
)
generator = self.engine_client.encode(
engine_input,
pooling_params,
params,
prompt_request_id,
lora_request=ctx.lora_request,
trace_headers=trace_headers,
......
......@@ -5,6 +5,8 @@
from vllm.config import ModelConfig
from vllm.entrypoints.chat_utils import ChatTemplateConfig
from vllm.entrypoints.pooling.base.io_processor import PoolingIOProcessor
from vllm.entrypoints.pooling.scoring.io_processor import ScoringIOProcessors
from vllm.entrypoints.pooling.utils import enable_scoring_api
from vllm.renderers import BaseRenderer
from vllm.tasks import SupportedTask
......@@ -25,6 +27,11 @@ def init_pooling_io_processors(
processors.append(("embed", EmbedIOProcessor))
if enable_scoring_api(supported_tasks, model_config):
score_type = model_config.score_type
if score_type is not None and score_type in ScoringIOProcessors:
processors.append((score_type, ScoringIOProcessors[score_type]))
return {
task: processor_cls(
model_config=model_config,
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import time
from collections.abc import AsyncGenerator, Mapping
from concurrent.futures import ThreadPoolExecutor
from typing import Any
from fastapi import Request
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.engine.protocol import (
ErrorResponse,
UsageInfo,
)
from vllm.entrypoints.openai.engine.serving import OpenAIServing
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.entrypoints.pooling.score.protocol import (
RerankDocument,
RerankRequest,
RerankResponse,
RerankResult,
RerankUsage,
ScoreRequest,
ScoreResponse,
ScoreResponseData,
)
from vllm.entrypoints.pooling.score.utils import (
ScoreData,
ScoreInputs,
_cosine_similarity,
compress_token_type_ids,
get_score_prompt,
parse_score_data_single,
validate_score_input,
)
from vllm.inputs import EngineInput, TokensPrompt, tokens_input
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.outputs import PoolingRequestOutput, ScoringRequestOutput
from vllm.tokenizers import TokenizerLike
from vllm.utils.async_utils import make_async, merge_async_iterators
from vllm.utils.mistral import is_mistral_tokenizer
from vllm.v1.pool.late_interaction import (
build_late_interaction_doc_params,
build_late_interaction_query_params,
)
logger = init_logger(__name__)
class ServingScores(OpenAIServing):
def __init__(
    self,
    engine_client: EngineClient,
    models: OpenAIServingModels,
    *,
    request_logger: RequestLogger | None,
    score_template: str | None = None,
    log_error_stack: bool = False,
) -> None:
    """Set up the scoring service and pick the scoring strategy.

    Args:
        engine_client: Forwarded to the base class.
        models: Forwarded to the base class.
        request_logger: Forwarded to the base class.
        score_template: Stored and later passed to ``get_score_prompt``.
        log_error_stack: Accepted but not used by this constructor.
    """
    super().__init__(
        engine_client=engine_client,
        models=models,
        request_logger=request_logger,
    )
    self.score_template = score_template
    # Single worker thread used to run tokenization off the event loop.
    self._tokenizer_executor = ThreadPoolExecutor(max_workers=1)

    model_config = self.model_config
    self.score_type = model_config.score_type
    self.architecture = model_config.architecture
    self.is_multimodal_model = model_config.is_multimodal_model

    # Any score type other than the two listed here (i.e. "bi-encoder")
    # falls back to embedding-based scoring.
    scorers = {
        "cross-encoder": self._cross_encoding_score,
        "late-interaction": self._late_interaction_score,
    }
    self._score_func = scorers.get(self.score_type, self._embedding_score)
async def _embedding_score(
    self,
    data_1: list[ScoreData],
    data_2: list[ScoreData],
    request: RerankRequest | ScoreRequest,
    request_id: str,
    lora_request: LoRARequest | None = None,
    trace_headers: Mapping[str, str] | None = None,
) -> list[PoolingRequestOutput] | ErrorResponse:
    """Score text pairs by embedding both sides and comparing them.

    Every item of ``data_1`` and ``data_2`` is tokenized, validated and
    submitted to the engine as its own "embed" pooling request. The
    collected outputs are then passed to ``_cosine_similarity``.

    Args:
        data_1: First side of the pairs; a single item is repeated to
            match the number of ``data_2`` outputs.
        data_2: Second side of the pairs.
        request: Originating request; supplies tokenization parameters,
            pooling parameters and priority.
        request_id: Base id; each engine request gets ``-<index>`` appended.
        lora_request: Optional LoRA adapter forwarded to the engine.
        trace_headers: Optional tracing headers forwarded to the engine.

    Returns:
        The result of ``_cosine_similarity`` for the paired embeddings.

    Raises:
        NotImplementedError: If any input item is not a plain string.
    """
    input_texts: list[str] = []
    for text in data_1 + data_2:
        if not isinstance(text, str):
            raise NotImplementedError(
                "Embedding scores currently do not support multimodal input."
            )
        input_texts.append(text)

    model_config = self.model_config
    tokenizer = self.renderer.get_tokenizer()

    encode_async = make_async(
        tokenizer.encode,
        executor=self._tokenizer_executor,
    )

    tokenization_kwargs = request.build_tok_params(model_config).get_encode_kwargs()
    tokenized_prompts = await asyncio.gather(
        *(encode_async(t, **tokenization_kwargs) for t in input_texts)
    )

    engine_inputs: list[EngineInput] = []
    for tok_result, input_text in zip(tokenized_prompts, input_texts):
        text_token_prompt = self._validate_input(request, tok_result, input_text)
        engine_inputs.append(
            tokens_input(
                text_token_prompt["prompt_token_ids"],
                prompt=input_text,
            )
        )

    # Schedule the request and get the result generator.
    generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
    pooling_params = request.to_pooling_params("embed")

    for i, engine_input in enumerate(engine_inputs):
        request_id_item = f"{request_id}-{i}"

        self._log_inputs(
            request_id_item,
            engine_input,
            params=pooling_params,
            lora_request=lora_request,
        )

        generators.append(
            self.engine_client.encode(
                engine_input,
                pooling_params,
                request_id_item,
                lora_request=lora_request,
                trace_headers=trace_headers,
                priority=request.priority,
            )
        )

    # Non-streaming response: results arrive out of order, so slot them
    # back by the index reported by the merged iterator.
    embeddings: list[PoolingRequestOutput | None] = [None] * len(engine_inputs)
    async for i, res in merge_async_iterators(*generators):
        embeddings[i] = res

    # Bind the loop variable outside the assert. An assignment expression
    # inside `assert` vanishes under `python -O`, leaving the name unbound.
    completed: list[PoolingRequestOutput] = []
    for emb in embeddings:
        assert emb is not None
        completed.append(emb)

    num_first = len(data_1)
    emb_data_1 = completed[:num_first]
    emb_data_2 = completed[num_first:]

    # One-to-many scoring: repeat the single first-side embedding.
    if len(emb_data_1) == 1:
        emb_data_1 = emb_data_1 * len(emb_data_2)

    return _cosine_similarity(
        tokenizer=tokenizer, embed_1=emb_data_1, embed_2=emb_data_2
    )
def _preprocess_late_interaction_item(
    self,
    data: ScoreData,
    role: str,
    request: RerankRequest | ScoreRequest,
    tokenizer: TokenizerLike,
    tokenization_kwargs: dict[str, Any],
) -> TokensPrompt:
    """Turn one ScoreData item into a TokensPrompt for late-interaction
    encoding.

    Plain strings are tokenized as-is. Any other input is first split into
    text plus multimodal data / uuids by ``parse_score_data_single``.

    Args:
        data: The item to encode.
        role: Role passed through to ``parse_score_data_single``.
        request: Originating request; used for length validation and for
            its ``mm_processor_kwargs``.
        tokenizer: Tokenizer whose ``encode`` produces the prompt ids.
        tokenization_kwargs: Extra keyword arguments for ``encode``.

    Returns:
        A ``TokensPrompt`` with the token ids and text, plus multimodal
        fields when they are present.
    """
    if isinstance(data, str):
        text, mm_data, mm_uuids = data, None, None
    else:
        text, mm_data, mm_uuids = parse_score_data_single(
            data, role, self.model_config
        )

    prompt_ids = tokenizer.encode(text, **tokenization_kwargs)
    # Called for its length check only; the returned prompt is not needed.
    self._validate_input(request, prompt_ids, text)

    tok_prompt = TokensPrompt(
        prompt_token_ids=prompt_ids,
        prompt=text,
    )

    # Optional fields are attached only when they carry a value.
    optional_fields = (
        ("multi_modal_data", mm_data),
        ("multi_modal_uuids", mm_uuids),
        ("mm_processor_kwargs", request.mm_processor_kwargs),
    )
    for field_name, field_value in optional_fields:
        if field_value is not None:
            tok_prompt[field_name] = field_value

    return tok_prompt
async def _late_interaction_score(
    self,
    data_1: list[ScoreData],
    data_2: list[ScoreData],
    request: RerankRequest | ScoreRequest,
    request_id: str,
    lora_request: LoRARequest | None = None,
    trace_headers: Mapping[str, str] | None = None,
) -> list[PoolingRequestOutput] | ErrorResponse:
    """
    Late interaction scoring (ColBERT MaxSim).

    Queries (``data_1``) and documents (``data_2``) are encoded in two
    stages: all queries first, then all documents, each document request
    carrying the key of the query it is scored against.

    Returns one ``PoolingRequestOutput`` per (query, document) pair, built
    from the document's outputs and the concatenated token ids.
    """

    async def collect_in_order(
        generators: list[AsyncGenerator[PoolingRequestOutput, None]],
        expected: int,
    ) -> list[PoolingRequestOutput]:
        # Drain the merged stream into index-addressed slots.
        slots: list[PoolingRequestOutput | None] = [None] * expected
        if generators:
            async for idx, out in merge_async_iterators(*generators):
                slots[idx] = out
        assert all(out is not None for out in slots)
        return [out for out in slots if out is not None]

    tokenizer = self.renderer.get_tokenizer()
    tokenization_kwargs = request.build_tok_params(
        self.model_config
    ).get_encode_kwargs()

    num_queries = len(data_1)
    tagged_items = [(item, "query") for item in data_1]
    tagged_items += [(item, "document") for item in data_2]

    preprocess_async = make_async(
        self._preprocess_late_interaction_item,
        executor=self._tokenizer_executor,
    )
    tok_prompts = await asyncio.gather(
        *(
            preprocess_async(
                data=item,
                role=role,
                request=request,
                tokenizer=tokenizer,
                tokenization_kwargs=tokenization_kwargs,
            )
            for item, role in tagged_items
        )
    )
    query_prompts = tok_prompts[:num_queries]
    doc_prompts = tok_prompts[num_queries:]

    default_pooling_params = request.to_pooling_params("token_embed")
    single_query = len(query_prompts) == 1

    # stage 1: encode queries and cache token embeddings on workers.
    # The query key doubles as the engine request id of the query.
    query_keys = [f"{request_id}-query-{i}" for i in range(len(query_prompts))]
    # A lone query is consumed once per document, otherwise once.
    uses_per_query = len(doc_prompts) if single_query else 1

    query_generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
    for query_key, tok_prompt in zip(query_keys, query_prompts):
        pooling_params = default_pooling_params.clone()
        pooling_params.late_interaction_params = (
            build_late_interaction_query_params(
                query_key=query_key,
                query_uses=uses_per_query,
            )
        )

        self._log_inputs(
            query_key,
            tok_prompt,
            params=pooling_params,
            lora_request=lora_request,
        )
        query_generators.append(
            self.engine_client.encode(
                tok_prompt,
                pooling_params,
                query_key,
                lora_request=lora_request,
                trace_headers=trace_headers,
                priority=request.priority,
            )
        )

    query_results = await collect_in_order(query_generators, len(query_prompts))

    # stage 2: encode docs and return scalar scores from workers.
    doc_generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
    for i, tok_prompt in enumerate(doc_prompts):
        request_id_item = f"{request_id}-doc-{i}"
        paired_key = query_keys[0] if single_query else query_keys[i]

        pooling_params = default_pooling_params.clone()
        pooling_params.late_interaction_params = build_late_interaction_doc_params(
            query_key=paired_key
        )

        self._log_inputs(
            request_id_item,
            tok_prompt,
            params=pooling_params,
            lora_request=lora_request,
        )
        doc_generators.append(
            self.engine_client.encode(
                tok_prompt,
                pooling_params,
                request_id_item,
                lora_request=lora_request,
                trace_headers=trace_headers,
                priority=request.priority,
            )
        )

    doc_results = await collect_in_order(doc_generators, len(doc_prompts))

    # Query and document token ids are joined with one pad token when the
    # tokenizer defines one.
    pad_token_id = tokenizer.pad_token_id
    padding: list[int] = [] if pad_token_id is None else [pad_token_id]

    if len(query_results) == 1:
        query_results = query_results * len(doc_results)

    return [
        PoolingRequestOutput(
            request_id=f"{query_result.request_id}_{doc_result.request_id}",
            outputs=doc_result.outputs,
            prompt_token_ids=(
                query_result.prompt_token_ids + padding + doc_result.prompt_token_ids
            ),
            num_cached_tokens=(
                query_result.num_cached_tokens + doc_result.num_cached_tokens
            ),
            finished=True,
        )
        for query_result, doc_result in zip(query_results, doc_results)
    ]
async def _cross_encoding_score(
    self,
    data_1: list[ScoreData],
    data_2: list[ScoreData],
    request: RerankRequest | ScoreRequest,
    request_id: str,
    lora_request: LoRARequest | None = None,
    trace_headers: Mapping[str, str] | None = None,
) -> list[PoolingRequestOutput] | ErrorResponse:
    """Score pairs by feeding each (data_1, data_2) pair to the model as
    one "classify" pooling request.

    Args:
        data_1: First side of the pairs; a single item is repeated to
            match the length of ``data_2``.
        data_2: Second side of the pairs.
        request: Originating request; supplies tokenization parameters,
            pooling parameters and priority.
        request_id: Base id; each engine request gets ``-<index>`` appended.
        lora_request: Optional LoRA adapter forwarded to the engine.
        trace_headers: Optional tracing headers forwarded to the engine.

    Returns:
        The engine outputs in pair order, omitting any pair for which no
        output was received.

    Raises:
        ValueError: If the renderer's tokenizer is a Mistral tokenizer.
    """
    tokenizer = self.renderer.get_tokenizer()
    if is_mistral_tokenizer(tokenizer):
        raise ValueError("MistralTokenizer not supported for cross-encoding")

    model_config = self.model_config

    # One-to-many scoring: repeat the single first-side item.
    if len(data_1) == 1:
        data_1 = data_1 * len(data_2)

    tok_kwargs = request.build_tok_params(model_config).get_encode_kwargs()

    preprocess_async = make_async(
        self._preprocess_score,
        executor=self._tokenizer_executor,
    )

    # Iterate the pairs straight from zip; no intermediate list is needed.
    preprocessed_prompts = await asyncio.gather(
        *(
            preprocess_async(
                request=request,
                tokenizer=tokenizer,
                tokenization_kwargs=tok_kwargs,
                data_1=t1,
                data_2=t2,
            )
            for t1, t2 in zip(data_1, data_2)
        )
    )

    # Schedule the request and get the result generator.
    generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []

    default_pooling_params = request.to_pooling_params("classify")

    for i, (full_prompt, tok_prompt) in enumerate(preprocessed_prompts):
        request_id_item = f"{request_id}-{i}"

        self._log_inputs(
            request_id_item,
            full_prompt,
            params=default_pooling_params,
            lora_request=lora_request,
        )

        # token_type_ids are removed from the prompt and carried on a
        # per-request copy of the pooling params, in compressed form.
        if token_type_ids := tok_prompt.pop("token_type_ids", None):
            pooling_params = default_pooling_params.clone()
            compressed = compress_token_type_ids(token_type_ids)
            pooling_params.extra_kwargs = {"compressed_token_type_ids": compressed}
        else:
            pooling_params = default_pooling_params

        generators.append(
            self.engine_client.encode(
                tok_prompt,
                pooling_params,
                request_id_item,
                lora_request=lora_request,
                trace_headers=trace_headers,
                priority=request.priority,
            )
        )

    # Non-streaming response: slot results back by their original index.
    final_res_batch: list[PoolingRequestOutput | None] = [None] * len(
        preprocessed_prompts
    )
    async for i, res in merge_async_iterators(*generators):
        final_res_batch[i] = res

    # NOTE(review): missing results are dropped silently here, whereas the
    # other scoring paths assert on them. Dropping shifts the positions of
    # later results relative to the inputs - confirm this is intended.
    return [out for out in final_res_batch if out is not None]
def _preprocess_score(
    self,
    request: RerankRequest | ScoreRequest,
    tokenizer: TokenizerLike,
    tokenization_kwargs: dict[str, Any],
    data_1: ScoreData,
    data_2: ScoreData,
) -> tuple[str, TokensPrompt]:
    """Build and validate the engine prompt for one cross-encoder pair.

    Args:
        request: Originating request; used for length validation and for
            its ``mm_processor_kwargs``.
        tokenizer: Tokenizer handed to ``get_score_prompt``.
        tokenization_kwargs: Tokenization options for ``get_score_prompt``.
        data_1: First element of the pair.
        data_2: Second element of the pair.

    Returns:
        ``(full_prompt, engine_input)`` as produced by ``get_score_prompt``,
        with ``mm_processor_kwargs`` added to the engine input when the
        request provides them.
    """
    full_prompt, engine_input = get_score_prompt(
        model_config=self.model_config,
        data_1=data_1,
        data_2=data_2,
        tokenizer=tokenizer,
        tokenization_kwargs=tokenization_kwargs,
        score_template=self.score_template,
    )

    # Called for its length check only; the returned prompt is not needed.
    self._validate_input(request, engine_input["prompt_token_ids"], full_prompt)

    mm_kwargs = request.mm_processor_kwargs
    if mm_kwargs is not None:
        engine_input["mm_processor_kwargs"] = mm_kwargs

    return full_prompt, engine_input
async def _run_scoring(
    self,
    data_1: ScoreInputs,
    data_2: ScoreInputs,
    request: ScoreRequest | RerankRequest,
    request_id: str,
    raw_request: Request | None = None,
) -> list[PoolingRequestOutput] | ErrorResponse:
    """Validate the score inputs and run the configured scoring function.

    Args:
        data_1: First side of the inputs, as received from the request.
        data_2: Second side of the inputs, as received from the request.
        request: Originating score or rerank request.
        request_id: Base id for the engine requests.
        raw_request: Optional HTTP request; its headers are the source of
            the trace headers.

    Returns:
        Whatever the selected ``_score_func`` returns.
    """
    lora_request = self._maybe_get_adapters(request)

    if raw_request is None:
        trace_headers = None
    else:
        trace_headers = await self._get_trace_headers(raw_request.headers)

    validated_1, validated_2 = validate_score_input(
        data_1,
        data_2,
        is_multimodal_model=self.is_multimodal_model,
        architecture=self.architecture,
    )

    return await self._score_func(
        data_1=validated_1,
        data_2=validated_2,
        request=request,
        request_id=request_id,
        lora_request=lora_request,
        trace_headers=trace_headers,
    )
async def create_score(
    self,
    request: ScoreRequest,
    raw_request: Request | None = None,
) -> ScoreResponse | ErrorResponse:
    """
    Score API similar to Sentence Transformers cross encoder

    See https://sbert.net/docs/package_reference/cross_encoder

    Returns an ``ErrorResponse`` when the model check fails, when scoring
    itself yields one, or when the task is cancelled.
    """
    model_error = await self._check_model(request)
    if model_error is not None:
        return model_error

    request_id = f"score-{self._base_request_id(raw_request)}"
    # Captured before scoring starts and reported in the response.
    created_time = int(time.time())

    try:
        outputs = await self._run_scoring(
            request.data_1,
            request.data_2,
            request,
            request_id,
            raw_request,
        )
        if isinstance(outputs, ErrorResponse):
            return outputs

        return self.request_output_to_score_response(
            final_res_batch=outputs,
            request_id=request_id,
            created_time=created_time,
            model_name=self.models.model_name(),
        )
    except asyncio.CancelledError:
        return self.create_error_response("Client disconnected")
async def do_rerank(
    self, request: RerankRequest, raw_request: Request | None = None
) -> RerankResponse | ErrorResponse:
    """
    Rerank API based on JinaAI's rerank API; implements the same
    API interface. Designed for compatibility with off-the-shelf
    tooling, since this is a common standard for reranking APIs

    See example client implementations at
    https://github.com/infiniflow/ragflow/blob/main/rag/llm/rerank_model.py
    numerous clients use this standard.

    Returns an ``ErrorResponse`` when the model check fails, when scoring
    itself yields one, or when the task is cancelled.
    """
    model_error = await self._check_model(request)
    if model_error is not None:
        return model_error

    request_id = f"rerank-{self._base_request_id(raw_request)}"
    documents = request.documents

    try:
        outputs = await self._run_scoring(
            request.query,
            documents,
            request,
            request_id,
            raw_request,
        )
        if isinstance(outputs, ErrorResponse):
            return outputs

        # A non-positive top_n means "return every result".
        if request.top_n > 0:
            top_n = request.top_n
        else:
            top_n = len(outputs)

        return self.request_output_to_rerank_response(
            final_res_batch=outputs,
            request_id=request_id,
            model_name=self.models.model_name(),
            documents=documents,
            top_n=top_n,
        )
    except asyncio.CancelledError:
        return self.create_error_response("Client disconnected")
def request_output_to_score_response(
    self,
    final_res_batch: list[PoolingRequestOutput],
    request_id: str,
    created_time: int,
    model_name: str,
) -> ScoreResponse:
    """Convert engine outputs into a ScoreResponse.

    Args:
        final_res_batch: One pooling output per scored pair, in order.
        request_id: Used as the response id.
        created_time: Used as the response ``created`` value.
        model_name: Used as the response ``model`` value.

    Returns:
        A ``ScoreResponse`` with one indexed score per output and usage
        totals equal to the summed prompt token counts.
    """
    scored_items: list[ScoreResponseData] = []
    prompt_token_total = 0

    for position, output in enumerate(final_res_batch):
        score = ScoringRequestOutput.from_base(output).outputs.score
        scored_items.append(ScoreResponseData(index=position, score=score))
        prompt_token_total += len(output.prompt_token_ids)

    return ScoreResponse(
        id=request_id,
        created=created_time,
        model=model_name,
        data=scored_items,
        # No tokens are generated, so total equals prompt tokens.
        usage=UsageInfo(
            prompt_tokens=prompt_token_total,
            total_tokens=prompt_token_total,
        ),
    )
def request_output_to_rerank_response(
    self,
    final_res_batch: list[PoolingRequestOutput],
    request_id: str,
    model_name: str,
    documents: ScoreInputs,
    top_n: int,
) -> RerankResponse:
    """
    Convert the output of do_rerank to a RerankResponse

    Results are sorted by descending relevance score and truncated to
    ``top_n`` when ``top_n`` is smaller than the number of documents.
    Each result keeps the index of its document in the input.
    """
    # A single document may be passed bare; normalise to a list.
    if not isinstance(documents, list):
        documents = [documents]

    ranked: list[RerankResult] = []
    prompt_token_total = 0

    for position, output in enumerate(final_res_batch):
        score = ScoringRequestOutput.from_base(output).outputs.score

        doc = documents[position]
        if isinstance(doc, str):
            rerank_document = RerankDocument(text=doc)
        else:
            rerank_document = RerankDocument(multi_modal=doc.get("content", []))

        ranked.append(
            RerankResult(
                index=position,
                document=rerank_document,
                relevance_score=score,
            )
        )
        prompt_token_total += len(output.prompt_token_ids)

    # sort by relevance, then return the top n if set
    ranked = sorted(ranked, key=lambda item: item.relevance_score, reverse=True)
    if top_n < len(documents):
        ranked = ranked[:top_n]

    return RerankResponse(
        id=request_id,
        model=model_name,
        results=ranked,
        usage=RerankUsage(
            total_tokens=prompt_token_total, prompt_tokens=prompt_token_total
        ),
    )
......@@ -3,21 +3,15 @@
from http import HTTPStatus
from fastapi import APIRouter, Depends, Request
from fastapi.responses import JSONResponse
from typing_extensions import assert_never
from vllm.entrypoints.openai.engine.protocol import ErrorResponse
from vllm.entrypoints.openai.utils import validate_json_request
from vllm.entrypoints.pooling.score.protocol import (
RerankRequest,
RerankResponse,
ScoreRequest,
ScoreResponse,
)
from vllm.entrypoints.pooling.score.serving import ServingScores
from vllm.entrypoints.utils import load_aware_call, with_cancellation
from vllm.logger import init_logger
from .protocol import RerankRequest, ScoreRequest
from .serving import ServingScores
router = APIRouter()
logger = init_logger(__name__)
......@@ -46,16 +40,7 @@ async def create_score(request: ScoreRequest, raw_request: Request):
if handler is None:
raise NotImplementedError("The model does not support Score API")
generator = await handler.create_score(request, raw_request)
if isinstance(generator, ErrorResponse):
return JSONResponse(
content=generator.model_dump(), status_code=generator.error.code
)
elif isinstance(generator, ScoreResponse):
return JSONResponse(content=generator.model_dump())
assert_never(generator)
return await handler(request, raw_request)
@router.post(
......@@ -92,16 +77,7 @@ async def do_rerank(request: RerankRequest, raw_request: Request):
if handler is None:
raise NotImplementedError("The model does not support Rerank (Score) API")
generator = await handler.do_rerank(request, raw_request)
if isinstance(generator, ErrorResponse):
return JSONResponse(
content=generator.model_dump(), status_code=generator.error.code
)
elif isinstance(generator, RerankResponse):
return JSONResponse(content=generator.model_dump())
assert_never(generator)
return await handler(request, raw_request)
@router.post(
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import time
from collections.abc import Sequence
from typing import Any, TypeAlias, cast
import torch.nn.functional as F
from vllm import PoolingParams, PoolingRequestOutput, TokensPrompt
from vllm.entrypoints.pooling.base.io_processor import PoolingIOProcessor
from vllm.entrypoints.pooling.typing import (
OfflineInputsContext,
OfflineOutputsContext,
PoolingServeContext,
)
from vllm.inputs import EngineInput
from vllm.renderers import TokenizeParams
from vllm.renderers.hf import safe_apply_chat_template
from vllm.tasks import PoolingTask, ScoreType
from vllm.utils.mistral import is_mistral_tokenizer
from ...chat_utils import ChatTemplateResolutionError
from .protocol import RerankRequest, ScoreRequest, ScoringRequest
from .typing import ScoreData, ScoreInput, ScoringData
from .utils import (
compress_token_type_ids,
compute_maxsim_score,
parse_score_data,
score_data_to_prompts,
validate_score_input,
)
ScoringServeContext: TypeAlias = PoolingServeContext[ScoringRequest]
class ScoringIOProcessor(PoolingIOProcessor):
    """Common base for the scoring I/O processors.

    Caches the tokenizer, the model architecture, the multimodal flag and the
    tokenizer's pad token id so subclasses can use them without going back to
    the renderer or model config.
    """

    name: ScoreType
    pooling_task: PoolingTask

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        tokenizer = self.renderer.get_tokenizer()
        config = self.model_config

        self.tokenizer = tokenizer
        self.architecture = config.architecture
        self.is_multimodal_model = config.is_multimodal_model
        self.pad_token_id = tokenizer.pad_token_id

    def create_pooling_params(self, request):
        """Derive pooling params from ``request`` for this processor's task."""
        return request.to_pooling_params(self.pooling_task)

    def valid_inputs(
        self,
        data_1: ScoreInput | list[ScoreInput],
        data_2: ScoreInput | list[ScoreInput],
    ) -> ScoringData:
        """Validate both sides of a scoring request via ``validate_score_input``."""
        return validate_score_input(
            data_1,
            data_2,
            is_multimodal_model=self.is_multimodal_model,
            architecture=self.architecture,
        )
class BiEncoderIOProcessor(ScoringIOProcessor):
    """Scoring processor that encodes each side separately and compares embeddings.

    All ``data_1`` prompts are submitted ahead of the ``data_2`` prompts in one
    batch. The number of ``data_1`` entries is remembered as an offset so the
    pooled outputs can be split again and paired with cosine similarity.
    """

    name: ScoreType = "bi-encoder"
    pooling_task: PoolingTask = "embed"

    #######################################
    # online APIs

    def pre_process_online(self, ctx: ScoringServeContext):
        """Validate the request inputs and store the engine inputs on ``ctx``.

        ``ctx.intermediates`` is set to the number of ``data_1`` entries so that
        ``post_process_online`` knows where the ``data_2`` outputs begin.

        Raises:
            ValueError: If the request is neither a score nor a rerank request.
        """
        request = ctx.request

        if isinstance(request, ScoreRequest):
            data_1 = request.data_1
            data_2 = request.data_2
        elif isinstance(request, RerankRequest):
            data_1 = request.query
            data_2 = request.documents
        else:
            raise ValueError(f"Invalid {self.name} request type")

        scoring_data = self.valid_inputs(data_1, data_2)
        tok_params = request.build_tok_params(self.model_config)

        engine_inputs = self._pre_process(
            scoring_data,
            tok_params,
            # Forward only the optional extras the request actually carries.
            prompt_extras={
                k: v
                for k in ("mm_processor_kwargs", "cache_salt")
                if (v := getattr(request, k, None)) is not None
            },
        )

        ctx.engine_inputs = engine_inputs
        ctx.intermediates = len(scoring_data.data_1)

    def post_process_online(
        self,
        ctx: ScoringServeContext,
    ):
        """Replace the raw embedding outputs on ``ctx`` with pairwise scores.

        Raises:
            ValueError: If the final batch or the stored ``data_1`` length is
                missing from ``ctx``.
        """
        if ctx.final_res_batch is None:
            raise ValueError("Final response batch not available")

        if ctx.intermediates is None:
            raise ValueError("data_1 len not available")

        ctx.final_res_batch = self._post_process(
            outputs=ctx.final_res_batch, offset=cast(int, ctx.intermediates)
        )

    #######################################
    # offline APIs

    def pre_process_offline(self, ctx: OfflineInputsContext) -> Sequence[EngineInput]:
        """Tokenize already-validated ``ScoringData`` with the default params."""
        assert isinstance(ctx.prompts, ScoringData)
        tok_params = self.renderer.default_cmpl_tok_params.with_kwargs(
            **(ctx.tokenization_kwargs or {})
        )
        return self._pre_process(ctx.prompts, tok_params)

    def post_process_offline(
        self,
        ctx: OfflineOutputsContext,
    ) -> list[PoolingRequestOutput]:
        """Turn the offline embedding outputs into pairwise scores."""
        assert ctx.offset is not None
        return self._post_process(outputs=ctx.outputs, offset=ctx.offset)

    #######################################
    # helpers

    def _pre_process(
        self,
        scoring_data: ScoringData,
        tok_params: TokenizeParams,
        prompt_extras: dict[str, Any] | None = None,
    ) -> Sequence[EngineInput]:
        """Convert both sides to prompts and preprocess them as one batch.

        The ``data_1`` ("query") prompts come first, followed by the ``data_2``
        ("document") prompts; ``_post_process`` relies on that ordering.
        """
        data_1 = score_data_to_prompts(scoring_data.data_1, "query", self.model_config)
        data_2 = score_data_to_prompts(
            scoring_data.data_2, "document", self.model_config
        )

        return self._preprocess_completion_offline(
            prompts=data_1 + data_2, tok_params=tok_params, prompt_extras=prompt_extras
        )

    def _post_process(self, outputs: list[PoolingRequestOutput], offset: int):
        """Pair the embeddings and score each pair with cosine similarity.

        Args:
            outputs: Pooled outputs, ``data_1`` results first.
            offset: Number of leading entries in ``outputs`` belonging to
                ``data_1``.

        Returns:
            One ``PoolingRequestOutput`` per pair whose ``outputs`` is the
            similarity and whose token ids are both prompts joined by the pad
            token (when the tokenizer defines one).
        """
        emb_data_1 = outputs[:offset]
        emb_data_2 = outputs[offset:]

        # 1:N scoring: reuse the single data_1 embedding for every data_2 entry.
        if len(emb_data_1) == 1:
            emb_data_1 = emb_data_1 * len(emb_data_2)

        # The separator does not depend on the pair, so build it once up front
        # (mirrors LateInteractionIOProcessor._post_process).
        padding: list[int] = []
        if self.pad_token_id is not None:
            padding = [self.pad_token_id]

        final_res_batch: list[PoolingRequestOutput] = []
        for emb_1, emb_2 in zip(emb_data_1, emb_data_2):
            # dim=0 presumes each pooled embedding is a 1-D vector.
            pair_score = F.cosine_similarity(
                emb_1.outputs.data.float(), emb_2.outputs.data.float(), dim=0
            )

            tokens = emb_1.prompt_token_ids + padding + emb_2.prompt_token_ids

            final_res_batch.append(
                PoolingRequestOutput(
                    request_id=f"{emb_1.request_id}_{emb_2.request_id}",
                    outputs=pair_score,
                    prompt_token_ids=tokens,
                    num_cached_tokens=emb_1.num_cached_tokens + emb_2.num_cached_tokens,
                    finished=True,
                )
            )

        return final_res_batch
class LateInteractionIOProcessor(BiEncoderIOProcessor):
    """Bi-encoder variant that scores token-level embeddings with MaxSim."""

    name: ScoreType = "late-interaction"
    pooling_task: PoolingTask = "token_embed"

    def _post_process(self, outputs: list[PoolingRequestOutput], offset: int):
        """Pair query/document outputs and score each pair with MaxSim.

        The first ``offset`` entries of ``outputs`` are the query embeddings and
        the remainder are the document embeddings.
        """
        query_outputs = outputs[:offset]
        doc_outputs = outputs[offset:]

        # A single query is broadcast against every document (1:N scoring).
        if len(query_outputs) == 1:
            query_outputs = query_outputs * len(doc_outputs)

        pad_id = self.pad_token_id
        separator: list[int] = [] if pad_id is None else [pad_id]

        scored: list[PoolingRequestOutput] = []
        for query_out, doc_out in zip(query_outputs, doc_outputs):
            # Per the original notes the embeddings are [query_len, dim] and
            # [doc_len, dim] respectively.
            score = compute_maxsim_score(query_out.outputs.data, doc_out.outputs.data)
            joined_tokens = (
                query_out.prompt_token_ids + separator + doc_out.prompt_token_ids
            )
            cached = query_out.num_cached_tokens + doc_out.num_cached_tokens

            scored.append(
                PoolingRequestOutput(
                    request_id=f"{query_out.request_id}_{doc_out.request_id}",
                    outputs=score,
                    prompt_token_ids=joined_tokens,
                    num_cached_tokens=cached,
                    finished=True,
                )
            )

        return scored
class CrossEncoderIOProcessor(ScoringIOProcessor):
    """Scoring processor that tokenizes each (data_1, data_2) pair as one prompt.

    Every pair yields a single engine input plus a matching ``PoolingParams``
    entry. When the tokenizer returns ``token_type_ids`` they are compressed and
    attached to that pair's pooling params.
    """

    name: ScoreType = "cross-encoder"
    pooling_task: PoolingTask = "classify"

    def __init__(self, *args, **kwargs):
        """Set up the processor and look up the model's score-template support.

        Raises:
            ValueError: If the tokenizer is a Mistral tokenizer.
        """
        super().__init__(*args, **kwargs)

        if is_mistral_tokenizer(self.tokenizer):
            raise ValueError("MistralTokenizer not supported for cross-encoding")

        # Imported inside the constructor rather than at module level.
        from vllm.model_executor.model_loader import get_model_cls
        from vllm.model_executor.models.interfaces import supports_score_template

        model = get_model_cls(self.model_config)
        self.supports_score_template = supports_score_template(model)
        # Keep the model class only when it provides score-template hooks.
        self.model = model if self.supports_score_template else None
        self.use_sep_token = self.model_config.use_sep_token

    #######################################
    # online APIs

    def pre_process_online(self, ctx: ScoringServeContext):
        """Validate the request and store engine inputs and pooling params on ``ctx``.

        Raises:
            ValueError: If the request is neither a score nor a rerank request.
        """
        request = ctx.request

        if isinstance(request, ScoreRequest):
            data_1 = request.data_1
            data_2 = request.data_2
        elif isinstance(request, RerankRequest):
            data_1 = request.query
            data_2 = request.documents
        else:
            raise ValueError(f"Invalid {self.name} request type")

        scoring_data = self.valid_inputs(data_1, data_2)
        tok_params = request.build_tok_params(self.model_config)
        pooling_params = self.create_pooling_params(request)

        engine_inputs, pooling_params_list = self._pre_process(
            scoring_data,
            tok_params,
            pooling_params,
            chat_template=self.chat_template,
            # Collected here, but _pre_process does not consume them yet.
            prompt_extras={
                k: v
                for k in ("mm_processor_kwargs", "cache_salt")
                if (v := getattr(request, k, None)) is not None
            },
        )

        ctx.engine_inputs = engine_inputs
        ctx.pooling_params = pooling_params_list

    #######################################
    # offline APIs

    def pre_process_offline(self, ctx: OfflineInputsContext) -> Sequence[EngineInput]:
        """Build engine inputs offline and replace ``ctx.pooling_params`` with a list.

        ``ctx.pooling_params`` must be a single value (or ``None``) on entry; it
        is overwritten with one entry per scored pair.
        """
        assert isinstance(ctx.prompts, ScoringData)
        assert not isinstance(ctx.pooling_params, list)

        tok_params = self.renderer.default_cmpl_tok_params.with_kwargs(
            **(ctx.tokenization_kwargs or {})
        )

        engine_inputs, pooling_params_list = self._pre_process(
            ctx.prompts, tok_params, ctx.pooling_params, ctx.chat_template
        )

        ctx.pooling_params = pooling_params_list
        return engine_inputs

    #######################################
    # helpers

    def _pre_process(
        self,
        scoring_data: ScoringData,
        tok_params: TokenizeParams,
        pooling_params: PoolingParams | None,
        chat_template: str | None = None,
        prompt_extras: dict[str, Any] | None = None,
    ) -> tuple[Sequence[EngineInput], list[PoolingParams]]:
        """Tokenize every pair and return parallel engine-input / pooling-param lists.

        Args:
            scoring_data: Validated inputs; a single ``data_1`` entry is repeated
                for every ``data_2`` entry.
            tok_params: Tokenization parameters used for encoding and for the
                post-tokenization step.
            pooling_params: Base pooling params; defaults to a "classify" task
                when ``None``.
            chat_template: Optional score template forwarded to
                ``get_score_prompt``.
            prompt_extras: Accepted for interface parity but currently unused.

        Returns:
            ``(engine_inputs, pooling_params_list)`` with one entry per pair.
        """
        # todo: support prompt_extras
        # A single arrival timestamp is shared by all pairs of this call.
        arrival_time = time.time()

        data_1 = scoring_data.data_1
        data_2 = scoring_data.data_2

        # 1:N scoring: repeat the single data_1 entry for every data_2 entry.
        if len(data_1) == 1:
            data_1 = data_1 * len(data_2)

        if pooling_params is None:
            pooling_params = PoolingParams(task="classify")

        pooling_params_list = list[PoolingParams]()
        engine_inputs = list[EngineInput]()

        for q, d in zip(data_1, data_2):
            _, engine_prompt = self.get_score_prompt(
                data_1=q,
                data_2=d,
                encode_kwargs=tok_params.get_encode_kwargs(),
                chat_template=chat_template,
            )

            # token_type_ids are removed from the prompt and carried in a cloned
            # PoolingParams instead; pairs without them share the base params.
            if token_type_ids := engine_prompt.pop("token_type_ids", None):
                params = pooling_params.clone()
                compressed = compress_token_type_ids(token_type_ids)
                params.extra_kwargs = {"compressed_token_type_ids": compressed}
                pooling_params_list.append(params)
            else:
                pooling_params_list.append(pooling_params)

            tok_params.apply_post_tokenization(self.tokenizer, engine_prompt)
            engine_inputs.append(
                self.renderer.process_for_engine(engine_prompt, arrival_time)
            )

        return engine_inputs, pooling_params_list

    def get_score_prompt(
        self,
        data_1: ScoreData,
        data_2: ScoreData,
        encode_kwargs: dict[str, Any],
        chat_template: str | None = None,
    ):
        """Build the text prompt and tokenized engine prompt for one pair.

        Args:
            data_1: First side of the pair (parsed with ``data_2`` by
                ``parse_score_data``).
            data_2: Second side of the pair.
            encode_kwargs: Keyword arguments forwarded to the tokenizer call.
            chat_template: Explicit score template; when ``None`` the default
                tokenizer-based encoding is used.

        Returns:
            ``(full_prompt, engine_prompt)`` where ``engine_prompt`` is a
            ``TokensPrompt`` that may also carry ``token_type_ids`` and
            multi-modal data/uuids.

        Raises:
            ValueError: If the model's score template returns ``None``.
        """
        model_config = self.model_config
        tokenizer = self.tokenizer

        prompt_1, prompt_2, mm_data, mm_uuids = parse_score_data(
            data_1,
            data_2,
            model_config,
        )

        def default_tokenizer_encode():
            # Default path: model score template if supported, otherwise plain
            # tokenizer encoding with or without a separator token.
            if self.supports_score_template:
                assert self.model is not None
                full_prompt = self.model.get_score_template(prompt_1, prompt_2)
                if full_prompt is None:
                    raise ValueError("Get empty score template from model")

                prompt_inputs = tokenizer(full_prompt, **encode_kwargs)
            else:
                if self.use_sep_token:
                    # Cross-encoder models default to using a separating token.
                    prompt_inputs = tokenizer(
                        text=prompt_1, text_pair=prompt_2, **encode_kwargs
                    )
                    full_prompt = tokenizer.decode(prompt_inputs["input_ids"])
                else:
                    # `llm as reranker` defaults to not using a separating token.
                    full_prompt = prompt_1 + prompt_2
                    prompt_inputs = tokenizer(text=full_prompt, **encode_kwargs)
            return full_prompt, prompt_inputs

        # FIXME: For now, we only apply a template when one is explicitly provided.
        # We cannot rely on the tokenizer's chat template because many models
        # inherit junk templates from their base LLM, which breaks both the models
        # and the tests that use them.
        if chat_template is None:
            full_prompt, prompt_inputs = default_tokenizer_encode()
        else:
            # FIXME:
            # Try applying a score template from the CLI arg or tokenizer_config.json
            # If that fails because there is no such template,
            # fall back to the default implementation.
            try:
                full_prompt = safe_apply_chat_template(
                    model_config,
                    tokenizer,
                    [
                        {"role": "query", "content": prompt_1},
                        {"role": "document", "content": prompt_2},
                    ],
                    chat_template=chat_template,
                    tools=None,
                    tokenize=False,
                )
                prompt_inputs = tokenizer(full_prompt, **encode_kwargs)
            except ChatTemplateResolutionError:
                full_prompt, prompt_inputs = default_tokenizer_encode()

        engine_prompt = TokensPrompt(prompt_token_ids=prompt_inputs["input_ids"])

        if (token_type_ids := prompt_inputs.get("token_type_ids")) is not None:
            engine_prompt["token_type_ids"] = token_type_ids

        # self.model is only set when the model class supports score templates.
        if self.model is not None:
            self.model.post_process_tokens(engine_prompt)

        if mm_data is not None:
            engine_prompt["multi_modal_data"] = mm_data
        if mm_uuids is not None:
            engine_prompt["multi_modal_uuids"] = mm_uuids

        return full_prompt, engine_prompt
# Registry mapping each score type name to the I/O processor class handling it.
ScoringIOProcessors: dict[ScoreType, type[ScoringIOProcessor]] = {
    "bi-encoder": BiEncoderIOProcessor,
    "late-interaction": LateInteractionIOProcessor,
    "cross-encoder": CrossEncoderIOProcessor,
}
......@@ -12,15 +12,12 @@ from vllm.entrypoints.pooling.base.protocol import (
ClassifyRequestMixin,
PoolingBasicRequestMixin,
)
from vllm.entrypoints.pooling.score.utils import (
ScoreContentPartParam,
ScoreInput,
ScoreInputs,
)
from vllm.renderers import TokenizeParams
from vllm.tasks import PoolingTask
from vllm.utils import random_uuid
from .typing import ScoreContentPartParam, ScoreInput
class ScoreRequestMixin(PoolingBasicRequestMixin, ClassifyRequestMixin):
def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
......@@ -43,13 +40,13 @@ class ScoreRequestMixin(PoolingBasicRequestMixin, ClassifyRequestMixin):
class ScoreDataRequest(ScoreRequestMixin):
data_1: ScoreInputs
data_2: ScoreInputs
data_1: ScoreInput | list[ScoreInput]
data_2: ScoreInput | list[ScoreInput]
class ScoreQueriesDocumentsRequest(ScoreRequestMixin):
queries: ScoreInputs
documents: ScoreInputs
queries: ScoreInput | list[ScoreInput]
documents: ScoreInput | list[ScoreInput]
@property
def data_1(self):
......@@ -61,8 +58,8 @@ class ScoreQueriesDocumentsRequest(ScoreRequestMixin):
class ScoreQueriesItemsRequest(ScoreRequestMixin):
queries: ScoreInputs
items: ScoreInputs
queries: ScoreInput | list[ScoreInput]
items: ScoreInput | list[ScoreInput]
@property
def data_1(self):
......@@ -74,8 +71,8 @@ class ScoreQueriesItemsRequest(ScoreRequestMixin):
class ScoreTextRequest(ScoreRequestMixin):
text_1: ScoreInputs
text_2: ScoreInputs
text_1: ScoreInput | list[ScoreInput]
text_2: ScoreInput | list[ScoreInput]
@property
def data_1(self):
......@@ -96,7 +93,7 @@ ScoreRequest: TypeAlias = (
class RerankRequest(PoolingBasicRequestMixin, ClassifyRequestMixin):
query: ScoreInput
documents: ScoreInputs
documents: ScoreInput | list[ScoreInput]
top_n: int = Field(default_factory=lambda: 0)
def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
......@@ -118,6 +115,9 @@ class RerankRequest(PoolingBasicRequestMixin, ClassifyRequestMixin):
)
ScoringRequest: TypeAlias = ScoreRequest | RerankRequest
class RerankDocument(BaseModel):
text: str | None = None
multi_modal: list[ScoreContentPartParam] | None = None
......@@ -154,3 +154,6 @@ class ScoreResponse(OpenAIBaseModel):
model: str
data: list[ScoreResponseData]
usage: UsageInfo
ScoringResponse: TypeAlias = RerankResponse | ScoreResponse
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from fastapi.responses import JSONResponse
from vllm.config import ModelConfig
from vllm.entrypoints.chat_utils import ChatTemplateConfig
from vllm.entrypoints.openai.engine.protocol import UsageInfo
from vllm.entrypoints.pooling.base.io_processor import PoolingIOProcessor
from vllm.entrypoints.pooling.base.serving import PoolingServing
from vllm.logger import init_logger
from vllm.outputs import PoolingRequestOutput, ScoringRequestOutput
from vllm.renderers import BaseRenderer
from .io_processor import ScoringIOProcessors, ScoringServeContext
from .protocol import (
RerankDocument,
RerankRequest,
RerankResponse,
RerankResult,
RerankUsage,
ScoreRequest,
ScoreResponse,
ScoreResponseData,
)
from .typing import ScoreInput
logger = init_logger(__name__)
class ServingScores(PoolingServing):
    """Serving handler that answers score and rerank requests."""

    request_id_prefix = "score"

    def init_io_processor(
        self,
        model_config: ModelConfig,
        renderer: BaseRenderer,
        chat_template_config: ChatTemplateConfig,
    ) -> PoolingIOProcessor:
        """Instantiate the I/O processor registered for the model's score type."""
        score_type = model_config.score_type
        assert score_type in ScoringIOProcessors
        processor_cls = ScoringIOProcessors[score_type]
        return processor_cls(
            model_config=model_config,
            renderer=renderer,
            chat_template_config=chat_template_config,
        )

    async def _build_response(
        self,
        ctx: ScoringServeContext,
    ) -> JSONResponse:
        """Serialize the final batch on ``ctx`` into a score or rerank response.

        Raises:
            ValueError: If ``ctx.final_res_batch`` has not been populated.
            NotImplementedError: If the request is neither a score nor a rerank
                request.
        """
        final_res_batch = ctx.final_res_batch
        # Fail with a clear message instead of a TypeError from iterating None.
        if final_res_batch is None:
            raise ValueError("Final response batch not available")

        request_id = ctx.request_id
        created_time = ctx.created_time
        model_name = self.models.model_name()

        if isinstance(ctx.request, ScoreRequest):
            return self._request_output_to_score_response(
                final_res_batch,
                request_id,
                created_time,
                model_name,
            )
        elif isinstance(ctx.request, RerankRequest):
            # A non-positive top_n means "return every result".
            return self._request_output_to_rerank_response(
                final_res_batch,
                request_id,
                model_name,
                ctx.request.documents,
                ctx.request.top_n if ctx.request.top_n > 0 else len(final_res_batch),
            )
        else:
            raise NotImplementedError(
                f"Unsupported scoring request type: {type(ctx.request).__name__}"
            )

    def _request_output_to_score_response(
        self,
        final_res_batch: list[PoolingRequestOutput],
        request_id: str,
        created_time: int,
        model_name: str,
    ) -> JSONResponse:
        """Build the JSON body of a score response.

        Each output becomes one entry indexed by its position in
        ``final_res_batch``; usage counts prompt tokens only.
        """
        items: list[ScoreResponseData] = []
        num_prompt_tokens = 0

        for idx, final_res in enumerate(final_res_batch):
            classify_res = ScoringRequestOutput.from_base(final_res)

            item = ScoreResponseData(
                index=idx,
                score=classify_res.outputs.score,
            )
            prompt_token_ids = final_res.prompt_token_ids

            items.append(item)
            num_prompt_tokens += len(prompt_token_ids)

        usage = UsageInfo(
            prompt_tokens=num_prompt_tokens,
            total_tokens=num_prompt_tokens,
        )

        response = ScoreResponse(
            id=request_id,
            created=created_time,
            model=model_name,
            data=items,
            usage=usage,
        )
        return JSONResponse(content=response.model_dump())

    def _request_output_to_rerank_response(
        self,
        final_res_batch: list[PoolingRequestOutput],
        request_id: str,
        model_name: str,
        documents: ScoreInput | list[ScoreInput],
        top_n: int,
    ) -> JSONResponse:
        """Build the JSON body of a rerank response.

        Output ``i`` is paired with document ``i``. Results are sorted by
        descending relevance and truncated to ``top_n`` when that is smaller
        than the number of documents.
        """
        if not isinstance(documents, list):
            documents = [documents]

        results: list[RerankResult] = []
        num_prompt_tokens = 0
        for idx, final_res in enumerate(final_res_batch):
            classify_res = ScoringRequestOutput.from_base(final_res)

            document = documents[idx]
            # Strings are echoed as text; other inputs expose their parts under
            # the "content" key.
            if isinstance(document, str):
                rerank_document = RerankDocument(text=document)
            else:
                rerank_document = RerankDocument(
                    multi_modal=document.get("content", [])
                )

            result = RerankResult(
                index=idx,
                document=rerank_document,
                relevance_score=classify_res.outputs.score,
            )
            results.append(result)
            prompt_token_ids = final_res.prompt_token_ids
            num_prompt_tokens += len(prompt_token_ids)

        # sort by relevance, then return the top n if set
        results.sort(key=lambda x: x.relevance_score, reverse=True)
        if top_n < len(documents):
            results = results[:top_n]

        response = RerankResponse(
            id=request_id,
            model=model_name,
            results=results,
            usage=RerankUsage(
                total_tokens=num_prompt_tokens, prompt_tokens=num_prompt_tokens
            ),
        )
        return JSONResponse(content=response.model_dump())
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import TypeAlias
from typing_extensions import Required, TypedDict
from vllm.entrypoints.chat_utils import (
ChatCompletionContentPartImageEmbedsParam,
ChatCompletionContentPartImageParam,
ChatCompletionContentPartTextParam,
ChatCompletionContentPartVideoParam,
)
ScoreContentPartParam: TypeAlias = (
ChatCompletionContentPartImageParam
| ChatCompletionContentPartImageEmbedsParam
| ChatCompletionContentPartTextParam
| ChatCompletionContentPartVideoParam
)
class ScoreMultiModalParam(TypedDict, total=False):
    """
    A specialized parameter type for scoring multimodal content.

    Reasons for not reusing `CustomChatCompletionMessageParam` directly:

    1. Score tasks don't need the 'role' field (user/assistant/system) that
       is required in chat completions.
    2. Including chat-specific fields would confuse users about their purpose
       in scoring.
    3. This is a more focused interface that only exposes what is needed for
       scoring.
    """

    content: Required[list[ScoreContentPartParam]]
    """The multimodal contents"""
# Raw input data with content key in ScoreMultiModalParam.
ScoreInput = str | ScoreMultiModalParam
# Score data without content key.
ScoreData = str | list[ScoreContentPartParam]
@dataclass
class ScoringData:
    """Container holding the two sides of a scoring request as ``ScoreData`` lists."""

    # First side of the scoring pairs.
    data_1: list[ScoreData]
    # Second side of the scoring pairs.
    data_2: list[ScoreData]
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable
from typing import Any, TypeAlias, cast
from typing import cast
import torch
from torch.nn import CosineSimilarity
from typing_extensions import Required, TypedDict
from vllm import PromptType, TextPrompt
from vllm.config import ModelConfig
from vllm.entrypoints.chat_utils import (
BaseMultiModalItemTracker,
ChatCompletionContentPartImageEmbedsParam,
ChatCompletionContentPartImageParam,
ChatCompletionContentPartParam,
ChatCompletionContentPartTextParam,
ChatCompletionContentPartVideoParam,
ChatTemplateResolutionError,
ConversationMessage,
MultiModalItemTracker,
_parse_chat_message_content_parts,
)
from vllm.inputs import (
MultiModalDataDict,
MultiModalUUIDDict,
PromptType,
TextPrompt,
TokensPrompt,
)
from vllm.model_executor.models.interfaces import supports_score_template
from vllm.outputs import PoolingRequestOutput
from vllm.renderers.hf import safe_apply_chat_template
from vllm.tokenizers import TokenizerLike
ScoreContentPartParam: TypeAlias = (
ChatCompletionContentPartImageParam
| ChatCompletionContentPartImageEmbedsParam
| ChatCompletionContentPartTextParam
| ChatCompletionContentPartVideoParam
from vllm.inputs import MultiModalDataDict, MultiModalUUIDDict
from .typing import (
ScoreContentPartParam,
ScoreData,
ScoreInput,
ScoringData,
)
......@@ -57,72 +42,6 @@ def compute_maxsim_score(q_emb: torch.Tensor, d_emb: torch.Tensor) -> torch.Tens
return token_scores.amax(dim=-1).sum()
class ScoreMultiModalParam(TypedDict, total=False):
    """
    A specialized parameter type for scoring multimodal content.

    Reasons for not reusing `CustomChatCompletionMessageParam` directly:

    1. Score tasks don't need the 'role' field (user/assistant/system) that
       is required in chat completions.
    2. Including chat-specific fields would confuse users about their purpose
       in scoring.
    3. This is a more focused interface that only exposes what is needed for
       scoring.
    """

    content: Required[list[ScoreContentPartParam]]
    """The multimodal contents"""
# Raw input data with content key in ScoreMultiModalParam.
ScoreInput = str | ScoreMultiModalParam
ScoreInputs = ScoreInput | list[ScoreInput]
# Score data without content key.
ScoreData = str | list[ScoreContentPartParam]
def _cosine_similarity(
    tokenizer: TokenizerLike,
    embed_1: list[PoolingRequestOutput],
    embed_2: list[PoolingRequestOutput],
) -> list[PoolingRequestOutput]:
    """Score paired embeddings with cosine similarity.

    Returns one ``PoolingRequestOutput`` per pair. Its token ids are both
    prompts joined by the tokenizer's pad token when one is defined.
    """
    similarity = CosineSimilarity(dim=0)
    results: list[PoolingRequestOutput] = []

    for left, right in zip(embed_1, embed_2):
        score = similarity(left.outputs.data, right.outputs.data)

        pad_id = tokenizer.pad_token_id
        separator: list[int] = [] if pad_id is None else [pad_id]
        joined_tokens = left.prompt_token_ids + separator + right.prompt_token_ids

        results.append(
            PoolingRequestOutput(
                request_id=f"{left.request_id}_{right.request_id}",
                outputs=score,
                prompt_token_ids=joined_tokens,
                num_cached_tokens=left.num_cached_tokens + right.num_cached_tokens,
                finished=True,
            )
        )

    return results
def _validate_score_input_lens(
data_1: list[ScoreData],
data_2: list[ScoreData],
):
len_1 = len(data_1)
len_2 = len(data_2)
if len_1 > 1 and len_1 != len_2:
raise ValueError("Input lengths must be either 1:1, 1:N or N:N")
if len_1 == 0:
raise ValueError("At least one text element must be given")
if len_2 == 0:
raise ValueError("At least one text_pair element must be given")
def _validate_mm_score_input(
data: list[ScoreInput],
is_multimodal_model: bool,
......@@ -140,12 +59,27 @@ def _validate_mm_score_input(
return out
def _validate_score_input_lens(
data_1: list[ScoreData],
data_2: list[ScoreData],
):
len_1 = len(data_1)
len_2 = len(data_2)
if len_1 > 1 and len_1 != len_2:
raise ValueError("Input lengths must be either 1:1, 1:N or N:N")
if len_1 == 0:
raise ValueError("At least one text element must be given")
if len_2 == 0:
raise ValueError("At least one text_pair element must be given")
def validate_score_input(
data_1: ScoreInputs,
data_2: ScoreInputs,
data_1: ScoreInput | list[ScoreInput],
data_2: ScoreInput | list[ScoreInput],
is_multimodal_model: bool,
architecture: str,
) -> tuple[list[ScoreData], list[ScoreData]]:
) -> ScoringData:
if not isinstance(data_1, list):
data_1 = [data_1]
......@@ -155,62 +89,7 @@ def validate_score_input(
score_input_1 = _validate_mm_score_input(data_1, is_multimodal_model, architecture)
score_input_2 = _validate_mm_score_input(data_2, is_multimodal_model, architecture)
_validate_score_input_lens(score_input_1, score_input_2)
return score_input_1, score_input_2
def _ensure_str(content: list[ConversationMessage]) -> str:
"""Extract a single string prompt from parsed conversation content."""
assert len(content) == 1
prompt = content[0]["content"]
if prompt is not None and isinstance(prompt, str):
return cast(str, prompt)
raise ValueError(f"Only string content is supported, but got {content}.")
def parse_score_data(
    data_1: ScoreData,
    data_2: ScoreData,
    model_config: ModelConfig,
) -> tuple[str, str, MultiModalDataDict | None, MultiModalUUIDDict | None]:
    """Turn a query-document pair into two text prompts plus shared
    multi-modal data.

    Both inputs are parsed against **one** :class:`MultiModalItemTracker`, so
    their multi-modal items end up merged in a single ``mm_data`` dict. That
    matches cross-encoder scoring, where query and document are concatenated
    into one model prompt.
    """
    tracker = MultiModalItemTracker(model_config)

    # Parse both sides before extracting text so the tracker sees the query's
    # items first and the document's items second.
    query_content = _parse_score_content("query", data_1, tracker)
    document_content = _parse_score_content("document", data_2, tracker)

    query_text = _ensure_str(query_content)
    document_text = _ensure_str(document_content)

    items, uuids = tracker.resolve_items()
    return query_text, document_text, items, uuids
def parse_score_data_single(
    data: ScoreData,
    role: str,
    model_config: ModelConfig,
) -> tuple[str, MultiModalDataDict | None, MultiModalUUIDDict | None]:
    """Turn **one** ScoreData into a text prompt plus its own multi-modal
    data.

    In contrast to :func:`parse_score_data`, every call builds a fresh
    :class:`MultiModalItemTracker`, so multi-modal items of different inputs
    never mix. That matches late-interaction scoring, where query and document
    are encoded independently.
    """
    tracker = MultiModalItemTracker(model_config)
    text = _ensure_str(_parse_score_content(role, data, tracker))

    items, uuids = tracker.resolve_items()
    return text, items, uuids
return ScoringData(data_1=score_input_1, data_2=score_input_2)
def score_data_to_prompts(
......@@ -243,6 +122,15 @@ def score_data_to_prompts(
return prompts
def _ensure_str(content: list[ConversationMessage]) -> str:
"""Extract a single string prompt from parsed conversation content."""
assert len(content) == 1
prompt = content[0]["content"]
if prompt is not None and isinstance(prompt, str):
return cast(str, prompt)
raise ValueError(f"Only string content is supported, but got {content}.")
def _parse_score_content(
role: str,
data: ScoreData,
......@@ -278,113 +166,50 @@ def _parse_score_content(
return next(iter(mm_placeholder_storage.values()))[0]
def _apply_model_score_template(
    model_config: ModelConfig, prompt_1: str, prompt_2: str
) -> str:
    """Render the pair of prompts with the model class's own score template.

    Raises:
        ValueError: If the model class does not support score templates, or if
            its template returns ``None``.
    """
    # NOTE(Simon): lazy import to avoid bring in all dependencies (e.g. gguf)
    from vllm.model_executor.model_loader import get_model_cls

    model_cls = get_model_cls(model_config)
    if not supports_score_template(model_cls):
        raise ValueError(f"Unsupported model architecture: {model_config.architecture}")

    rendered = model_cls.get_score_template(prompt_1, prompt_2)
    if rendered is None:
        raise ValueError("Get empty score template from model")
    return rendered
def post_process_tokens(
def parse_score_data_single(
data: ScoreData,
role: str,
model_config: ModelConfig,
prompt: TokensPrompt,
) -> None:
"""
Perform architecture-specific manipulations on the input tokens.
) -> tuple[str, MultiModalDataDict | None, MultiModalUUIDDict | None]:
"""Parse **one** ScoreData into a text prompt and its own multi-modal
data.
Note:
This is an in-place operation.
Unlike :func:`parse_score_data`, each call creates an **independent**
:class:`MultiModalItemTracker` so multi-modal items are kept separate.
This is the correct behaviour for late-interaction scoring, where
query and document are encoded independently.
"""
# NOTE(Simon): lazy import to avoid bring in all dependencies (e.g. gguf)
from vllm.model_executor.model_loader import get_model_cls
mm_tracker = MultiModalItemTracker(model_config)
content = _parse_score_content(role, data, mm_tracker)
model = get_model_cls(model_config)
if supports_score_template(model):
model.post_process_tokens(prompt)
prompt = _ensure_str(content)
mm_items, mm_uuids = mm_tracker.resolve_items()
return prompt, mm_items, mm_uuids
def get_score_prompt(
model_config: ModelConfig,
tokenizer: TokenizerLike,
tokenization_kwargs: dict[str, Any],
def parse_score_data(
data_1: ScoreData,
data_2: ScoreData,
score_template: str | None = None,
) -> tuple[str, TokensPrompt]:
prompt_1, prompt_2, mm_data, mm_uuids = parse_score_data(
data_1,
data_2,
model_config,
)
from vllm.model_executor.model_loader import get_model_cls
model_config: ModelConfig,
) -> tuple[str, str, MultiModalDataDict | None, MultiModalUUIDDict | None]:
"""Parse a query-document pair into text prompts and shared multi-modal
data.
model = get_model_cls(model_config)
Uses a **single** :class:`MultiModalItemTracker` so that multi-modal
items from both inputs are merged into one ``mm_data`` dict. This is
the correct behaviour for cross-encoder scoring, where query and
document are concatenated into a single model prompt.
"""
mm_tracker = MultiModalItemTracker(model_config)
def default_tokenizer_encode():
if supports_score_template(model):
full_prompt = _apply_model_score_template(model_config, prompt_1, prompt_2)
prompt_inputs = tokenizer(full_prompt, **tokenization_kwargs)
else:
if model_config.use_sep_token:
# cross_encoder models default to using a separating token.
prompt_inputs = tokenizer(
text=prompt_1, text_pair=prompt_2, **tokenization_kwargs
)
full_prompt = tokenizer.decode(prompt_inputs["input_ids"])
else:
# `llm as reranker` defaults to not using separating token.
full_prompt = prompt_1 + prompt_2
prompt_inputs = tokenizer(text=full_prompt, **tokenization_kwargs)
return full_prompt, prompt_inputs
# FIXME: For now, we only apply a template when one is explicitly provided.
# We cannot rely on the tokenizer's chat template because many models
# inherit junk templates from their base LLM, which breaks both the models
# and the tests that use them.
if score_template is None:
full_prompt, prompt_inputs = default_tokenizer_encode()
else:
# FIXME: Try applying a score template from the CLI arg or tokenizer_config.json
# If that fails because there is no such template,
# fall back to the default implementation.
try:
full_prompt = safe_apply_chat_template(
model_config,
tokenizer,
[
{"role": "query", "content": prompt_1},
{"role": "document", "content": prompt_2},
],
chat_template=score_template,
tools=None,
tokenize=False,
)
prompt_inputs = tokenizer(full_prompt, **tokenization_kwargs)
except ChatTemplateResolutionError:
full_prompt, prompt_inputs = default_tokenizer_encode()
engine_prompt = TokensPrompt(prompt_token_ids=prompt_inputs["input_ids"])
if (token_type_ids := prompt_inputs.get("token_type_ids")) is not None:
engine_prompt["token_type_ids"] = token_type_ids
post_process_tokens(model_config, engine_prompt)
if mm_data is not None:
engine_prompt["multi_modal_data"] = mm_data
if mm_uuids is not None:
engine_prompt["multi_modal_uuids"] = mm_uuids
return full_prompt, engine_prompt
content_1 = _parse_score_content("query", data_1, mm_tracker)
content_2 = _parse_score_content("document", data_2, mm_tracker)
prompt_1 = _ensure_str(content_1)
prompt_2 = _ensure_str(content_2)
mm_items, mm_uuids = mm_tracker.resolve_items()
return prompt_1, prompt_2, mm_items, mm_uuids
def compress_token_type_ids(token_type_ids: list[int]) -> int:
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import time
from collections.abc import AsyncGenerator
from collections.abc import AsyncGenerator, Sequence
from dataclasses import dataclass, field
from typing import Any, Generic, TypeAlias, TypeVar
from fastapi import Request
from pydantic import ConfigDict
from vllm import PoolingRequestOutput
from vllm import PoolingParams, PoolingRequestOutput, PromptType
from vllm.entrypoints.pooling.classify.protocol import (
ClassificationChatRequest,
ClassificationCompletionRequest,
......@@ -23,15 +23,13 @@ from vllm.entrypoints.pooling.embed.protocol import (
)
from vllm.entrypoints.pooling.pooling.protocol import (
IOProcessorRequest,
PoolingBytesResponse,
PoolingChatRequest,
PoolingCompletionRequest,
PoolingResponse,
)
from vllm.entrypoints.pooling.score.protocol import (
RerankRequest,
ScoreRequest,
ScoreResponse,
)
from vllm.entrypoints.pooling.scoring.protocol import ScoringRequest, ScoringResponse
from vllm.entrypoints.pooling.scoring.typing import ScoringData
from vllm.inputs import EngineInput
from vllm.lora.request import LoRARequest
......@@ -49,8 +47,7 @@ AnyPoolingRequest: TypeAlias = (
PoolingCompletionLikeRequest
| PoolingChatLikeRequest
| IOProcessorRequest
| RerankRequest
| ScoreRequest
| ScoringRequest
| CohereEmbedRequest
)
......@@ -59,7 +56,8 @@ AnyPoolingResponse: TypeAlias = (
| EmbeddingResponse
| EmbeddingBytesResponse
| PoolingResponse
| ScoreResponse
| PoolingBytesResponse
| ScoringResponse
)
PoolingRequestT = TypeVar("PoolingRequestT", bound=AnyPoolingRequest)
......@@ -73,8 +71,8 @@ class PoolingServeContext(Generic[PoolingRequestT]):
request_id: str
created_time: int = field(default_factory=lambda: int(time.time()))
lora_request: LoRARequest | None = None
engine_inputs: list[EngineInput] | None = None
pooling_params: PoolingParams | list[PoolingParams] | None = None
engine_inputs: Sequence[EngineInput] | None = None
prompt_request_ids: list[str] | None = None
intermediates: Any | None = None
......@@ -84,3 +82,22 @@ class PoolingServeContext(Generic[PoolingRequestT]):
final_res_batch: list[PoolingRequestOutput] = field(default_factory=list)
model_config = ConfigDict(arbitrary_types_allowed=True)
@dataclass
class OfflineInputsContext:
    """Container grouping the inputs of one offline pooling/scoring call.

    NOTE(review): the consumer of this context is not visible in this part of
    the file; the field notes below describe only what the annotations and
    defaults establish.
    """

    # Required payload: a single prompt, a sequence of prompts, or scoring data.
    prompts: PromptType | Sequence[PromptType] | ScoringData
    # One PoolingParams, a list of PoolingParams, or None (the default).
    pooling_params: PoolingParams | list[PoolingParams] | None = None
    # Optional keyword arguments for tokenization; None by default.
    # NOTE(review): presumably forwarded to the tokenizer -- confirm at call site.
    tokenization_kwargs: dict[str, Any] | None = None
    # Optional chat template string; None by default.
    chat_template: str | None = None

    ## for bi-encoder & late-interaction
    # NOTE(review): the exact meaning of this offset is not shown here -- confirm
    # against the code that populates it.
    offset: int | None = None
@dataclass
class OfflineOutputsContext:
    """Container grouping the outputs of one offline pooling/scoring call.

    NOTE(review): the producer and consumer of this context are not visible in
    this part of the file; the field notes below describe only what the
    annotations and defaults establish.
    """

    # Required list of pooling request outputs.
    outputs: list[PoolingRequestOutput]

    ## for bi-encoder & late-interaction
    # NOTE(review): presumably mirrors OfflineInputsContext.offset -- confirm.
    offset: int | None = None
......@@ -11,8 +11,10 @@ import pybase64
import torch
from fastapi.responses import JSONResponse
from vllm.config import ModelConfig
from vllm.logger import init_logger
from vllm.outputs import PoolingRequestOutput
from vllm.tasks import SupportedTask
from vllm.utils.serial_utils import (
EMBED_DTYPES,
EmbedDType,
......@@ -133,3 +135,20 @@ def get_json_response_cls() -> type[JSONResponse]:
"To make v1/embeddings API fast, please install orjson by `pip install orjson`"
)
return JSONResponse
def enable_scoring_api(
    supported_tasks: tuple["SupportedTask", ...],
    model_config: ModelConfig | None = None,
) -> bool:
    """Decide whether the scoring API should be enabled.

    Scoring is enabled when ``supported_tasks`` contains ``"embed"`` or
    ``"token_embed"``. Otherwise, if a ``model_config`` is supplied and
    ``"classify"`` is among the supported tasks, scoring is enabled only when
    ``model_config.hf_config.num_labels`` equals 1 (a missing attribute is
    treated as 0). In every other case scoring is disabled.

    Args:
        supported_tasks: Task names to check membership against.
        model_config: Optional model configuration; only its ``hf_config`` is
            read, and only on the ``"classify"`` path.

    Returns:
        ``True`` if the scoring API should be enabled, ``False`` otherwise.
    """
    # Embedding-style tasks enable scoring unconditionally.
    if "embed" in supported_tasks or "token_embed" in supported_tasks:
        return True

    # Without a config, or without classify support, there is nothing to check.
    if model_config is None or "classify" not in supported_tasks:
        return False

    if getattr(model_config.hf_config, "num_labels", 0) == 1:
        return True

    logger.debug_once("Scoring API is only enabled for num_labels == 1.")
    return False
......@@ -14,8 +14,8 @@ from vllm.config import ModelConfig
from vllm.entrypoints.openai.engine.protocol import ErrorResponse
from vllm.entrypoints.openai.engine.serving import OpenAIServing
from vllm.entrypoints.openai.utils import validate_json_request
from vllm.entrypoints.pooling import enable_scoring_api
from vllm.entrypoints.pooling.base.serving import PoolingServing
from vllm.entrypoints.pooling.utils import enable_scoring_api
from vllm.entrypoints.serve.instrumentator.basic import base
from vllm.entrypoints.serve.instrumentator.health import health
from vllm.tasks import POOLING_TASKS, SupportedTask
......@@ -76,15 +76,15 @@ def get_invocation_types(
]
if enable_scoring_api(supported_tasks, model_config):
from vllm.entrypoints.pooling.score.api_router import do_rerank, rerank
from vllm.entrypoints.pooling.score.protocol import RerankRequest
from vllm.entrypoints.pooling.scoring.api_router import do_rerank, rerank
from vllm.entrypoints.pooling.scoring.protocol import RerankRequest
INVOCATION_TYPES += [
(RerankRequest, (rerank, do_rerank)),
]
from vllm.entrypoints.pooling.score.api_router import create_score, score
from vllm.entrypoints.pooling.score.protocol import ScoreRequest
from vllm.entrypoints.pooling.scoring.api_router import create_score, score
from vllm.entrypoints.pooling.scoring.protocol import ScoreRequest
INVOCATION_TYPES += [
(ScoreRequest, (score, create_score)),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment