Commit ea463978 (unverified), authored by wang.yuqi, committed by GitHub
Browse files

[Frontend][1/n] Improve pooling entrypoints | classify. (#35604)


Signed-off-by: wang.yuqi <yuqi.wang@daocloud.io>
Signed-off-by: wang.yuqi <noooop@126.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
parent 440f0e7d
...@@ -7,6 +7,7 @@ import warnings ...@@ -7,6 +7,7 @@ import warnings
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections import Counter, defaultdict from collections import Counter, defaultdict
from collections.abc import Awaitable, Callable, Iterable from collections.abc import Awaitable, Callable, Iterable
from dataclasses import dataclass
from functools import cached_property, lru_cache, partial from functools import cached_property, lru_cache, partial
from itertools import accumulate from itertools import accumulate
from pathlib import Path from pathlib import Path
...@@ -1024,6 +1025,13 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser): ...@@ -1024,6 +1025,13 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
self._add_placeholder("video", placeholder) self._add_placeholder("video", placeholder)
@dataclass
class ChatTemplateConfig:
    """Chat-template settings grouped into a single value object."""

    # Default chat template text; None when no template is configured.
    chat_template: str | None = None
    # Content-format option for the chat template; defaults to "auto".
    chat_template_content_format: ChatTemplateContentFormatOption = "auto"
    # Flag controlling whether request-supplied chat templates are trusted.
    trust_request_chat_template: bool = False
def validate_chat_template(chat_template: Path | str | None): def validate_chat_template(chat_template: Path | str | None):
"""Raises if the provided chat template appears invalid.""" """Raises if the provided chat template appears invalid."""
if chat_template is None: if chat_template is None:
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
import itertools import itertools
from collections.abc import Callable, Iterable, Sequence from collections.abc import Callable, Iterable, Sequence
from pathlib import Path
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any
import cloudpickle import cloudpickle
...@@ -40,8 +41,11 @@ from vllm.distributed.weight_transfer.base import ( ...@@ -40,8 +41,11 @@ from vllm.distributed.weight_transfer.base import (
from vllm.engine.arg_utils import EngineArgs from vllm.engine.arg_utils import EngineArgs
from vllm.entrypoints.chat_utils import ( from vllm.entrypoints.chat_utils import (
ChatCompletionMessageParam, ChatCompletionMessageParam,
ChatTemplateConfig,
ChatTemplateContentFormatOption, ChatTemplateContentFormatOption,
load_chat_template,
) )
from vllm.entrypoints.pooling.io_processor_factories import init_pooling_io_processors
from vllm.entrypoints.pooling.score.utils import ( from vllm.entrypoints.pooling.score.utils import (
ScoreData, ScoreData,
ScoreMultiModalParam, ScoreMultiModalParam,
...@@ -145,6 +149,7 @@ class LLM: ...@@ -145,6 +149,7 @@ class LLM:
a tag name, or a commit id. a tag name, or a commit id.
tokenizer_revision: The specific tokenizer version to use. It can be a tokenizer_revision: The specific tokenizer version to use. It can be a
branch name, a tag name, or a commit id. branch name, a tag name, or a commit id.
chat_template: The chat template to apply.
seed: The seed to initialize the random number generator for sampling. seed: The seed to initialize the random number generator for sampling.
gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
reserve for the model weights, activations, and KV cache. Higher reserve for the model weights, activations, and KV cache. Higher
...@@ -232,6 +237,7 @@ class LLM: ...@@ -232,6 +237,7 @@ class LLM:
quantization: QuantizationMethods | None = None, quantization: QuantizationMethods | None = None,
revision: str | None = None, revision: str | None = None,
tokenizer_revision: str | None = None, tokenizer_revision: str | None = None,
chat_template: Path | str | None = None,
seed: int = 0, seed: int = 0,
gpu_memory_utilization: float = 0.9, gpu_memory_utilization: float = 0.9,
swap_space: float = 4, swap_space: float = 4,
...@@ -384,9 +390,16 @@ class LLM: ...@@ -384,9 +390,16 @@ class LLM:
self.model_config = self.llm_engine.model_config self.model_config = self.llm_engine.model_config
self.renderer = self.llm_engine.renderer self.renderer = self.llm_engine.renderer
self.chat_template = load_chat_template(chat_template)
self.io_processor = self.llm_engine.io_processor self.io_processor = self.llm_engine.io_processor
self.input_processor = self.llm_engine.input_processor self.input_processor = self.llm_engine.input_processor
self.chat_template_config = ChatTemplateConfig(chat_template=self.chat_template)
self.init_pooling_io_processors = init_pooling_io_processors(
supported_tasks=supported_tasks,
model_config=self.model_config,
renderer=self.renderer,
chat_template_config=self.chat_template_config,
)
# Cache for __repr__ to avoid repeated collective_rpc calls # Cache for __repr__ to avoid repeated collective_rpc calls
self._cached_repr: str | None = None self._cached_repr: str | None = None
...@@ -1086,7 +1099,7 @@ class LLM: ...@@ -1086,7 +1099,7 @@ class LLM:
"pooling model." "pooling model."
) )
if use_io_processor := (isinstance(prompts, dict) and "data" in prompts): if isinstance(prompts, dict) and "data" in prompts:
if self.io_processor is None: if self.io_processor is None:
raise ValueError( raise ValueError(
"No IOProcessor plugin installed. Please refer " "No IOProcessor plugin installed. Please refer "
...@@ -1120,6 +1133,31 @@ class LLM: ...@@ -1120,6 +1133,31 @@ class LLM:
for p in params_seq: for p in params_seq:
if p.task is None: if p.task is None:
p.task = "plugin" p.task = "plugin"
outputs = self._run_completion(
prompts=prompts_seq,
params=params_seq,
output_type=PoolingRequestOutput,
use_tqdm=use_tqdm,
lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
)
# get the post-processed model outputs
assert self.io_processor is not None
processed_outputs = self.io_processor.post_process(outputs)
return [
PoolingRequestOutput[Any](
request_id="",
outputs=processed_outputs,
num_cached_tokens=getattr(
processed_outputs, "num_cached_tokens", 0
),
prompt_token_ids=[],
finished=True,
)
]
else: else:
if pooling_params is None: if pooling_params is None:
# Use default pooling params. # Use default pooling params.
...@@ -1137,32 +1175,36 @@ class LLM: ...@@ -1137,32 +1175,36 @@ class LLM:
) )
raise ValueError(msg) raise ValueError(msg)
outputs = self._run_completion( if pooling_task in self.init_pooling_io_processors:
prompts=prompts_seq, io_processor = self.init_pooling_io_processors[pooling_task]
params=params_seq, processor_inputs = io_processor.pre_process_offline(
output_type=PoolingRequestOutput, prompts_seq, tokenization_kwargs
use_tqdm=use_tqdm, )
lora_request=lora_request, seq_lora_requests = self._lora_request_to_seq(
tokenization_kwargs=tokenization_kwargs, lora_request, len(prompts_seq)
) )
seq_priority = self._priority_to_seq(None, len(prompts))
if use_io_processor:
# get the post-processed model outputs
assert self.io_processor is not None
processed_outputs = self.io_processor.post_process(outputs)
return [ self._render_and_add_requests(
PoolingRequestOutput[Any]( prompts=processor_inputs,
request_id="", params=params_seq,
outputs=processed_outputs, lora_requests=seq_lora_requests,
num_cached_tokens=getattr( priorities=seq_priority,
processed_outputs, "num_cached_tokens", 0
),
prompt_token_ids=[],
finished=True,
) )
]
outputs = self._run_engine(
use_tqdm=use_tqdm, output_type=PoolingRequestOutput
)
outputs = io_processor.post_process(outputs)
else:
outputs = self._run_completion(
prompts=prompts_seq,
params=params_seq,
output_type=PoolingRequestOutput,
use_tqdm=use_tqdm,
lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
)
return outputs return outputs
def embed( def embed(
......
...@@ -62,11 +62,6 @@ from vllm.entrypoints.openai.speech_to_text.protocol import ( ...@@ -62,11 +62,6 @@ from vllm.entrypoints.openai.speech_to_text.protocol import (
TranscriptionResponse, TranscriptionResponse,
TranslationRequest, TranslationRequest,
) )
from vllm.entrypoints.pooling.classify.protocol import (
ClassificationChatRequest,
ClassificationCompletionRequest,
ClassificationResponse,
)
from vllm.entrypoints.pooling.embed.protocol import ( from vllm.entrypoints.pooling.embed.protocol import (
EmbeddingBytesResponse, EmbeddingBytesResponse,
EmbeddingChatRequest, EmbeddingChatRequest,
...@@ -161,7 +156,6 @@ CompletionLikeRequest: TypeAlias = ( ...@@ -161,7 +156,6 @@ CompletionLikeRequest: TypeAlias = (
| TokenizeCompletionRequest | TokenizeCompletionRequest
| DetokenizeRequest | DetokenizeRequest
| EmbeddingCompletionRequest | EmbeddingCompletionRequest
| ClassificationCompletionRequest
| RerankRequest | RerankRequest
| ScoreRequest | ScoreRequest
| PoolingCompletionRequest | PoolingCompletionRequest
...@@ -171,7 +165,6 @@ ChatLikeRequest: TypeAlias = ( ...@@ -171,7 +165,6 @@ ChatLikeRequest: TypeAlias = (
ChatCompletionRequest ChatCompletionRequest
| TokenizeChatRequest | TokenizeChatRequest
| EmbeddingChatRequest | EmbeddingChatRequest
| ClassificationChatRequest
| PoolingChatRequest | PoolingChatRequest
) )
...@@ -194,12 +187,10 @@ AnyResponse: TypeAlias = ( ...@@ -194,12 +187,10 @@ AnyResponse: TypeAlias = (
| TranscriptionResponse | TranscriptionResponse
| TokenizeResponse | TokenizeResponse
| PoolingResponse | PoolingResponse
| ClassificationResponse
| ScoreResponse | ScoreResponse
| GenerateResponse | GenerateResponse
) )
RequestT = TypeVar("RequestT", bound=AnyRequest) RequestT = TypeVar("RequestT", bound=AnyRequest)
...@@ -223,8 +214,8 @@ class ServeContext(Generic[RequestT]): ...@@ -223,8 +214,8 @@ class ServeContext(Generic[RequestT]):
class OpenAIServing: class OpenAIServing:
request_id_prefix: ClassVar[str] = """ request_id_prefix: ClassVar[str] = """
A short string prepended to every request’s ID (e.g. "embd", "classify") A short string prepended to every request’s ID (e.g. "embd")
so you can easily tell “this ID came from Embedding vs Classification.” so you can easily tell “this ID came from Embedding.”
""" """
def __init__( def __init__(
...@@ -456,7 +447,7 @@ class OpenAIServing: ...@@ -456,7 +447,7 @@ class OpenAIServing:
) -> ErrorResponse | None: ) -> ErrorResponse | None:
""" """
Default preprocessing hook. Subclasses may override Default preprocessing hook. Subclasses may override
to prepare `ctx` (classification, embedding, etc.). to prepare `ctx` (embedding, etc.).
""" """
return None return None
...@@ -817,7 +808,7 @@ class OpenAIServing: ...@@ -817,7 +808,7 @@ class OpenAIServing:
token_num = len(input_ids) token_num = len(input_ids)
max_model_len = self.model_config.max_model_len max_model_len = self.model_config.max_model_len
# Note: EmbeddingRequest, ClassificationRequest, # Note: EmbeddingRequest,
# and ScoreRequest doesn't have max_tokens # and ScoreRequest doesn't have max_tokens
if isinstance( if isinstance(
request, request,
...@@ -828,8 +819,6 @@ class OpenAIServing: ...@@ -828,8 +819,6 @@ class OpenAIServing:
ScoreTextRequest, ScoreTextRequest,
ScoreQueriesDocumentsRequest, ScoreQueriesDocumentsRequest,
RerankRequest, RerankRequest,
ClassificationCompletionRequest,
ClassificationChatRequest,
), ),
): ):
# Note: input length can be up to the entire model context length # Note: input length can be up to the entire model context length
...@@ -839,8 +828,6 @@ class OpenAIServing: ...@@ -839,8 +828,6 @@ class OpenAIServing:
ScoreDataRequest: "score", ScoreDataRequest: "score",
ScoreTextRequest: "score", ScoreTextRequest: "score",
ScoreQueriesDocumentsRequest: "score", ScoreQueriesDocumentsRequest: "score",
ClassificationCompletionRequest: "classification",
ClassificationChatRequest: "classification",
} }
operation = operations.get(type(request), "embedding generation") operation = operations.get(type(request), "embedding generation")
raise VLLMValidationError( raise VLLMValidationError(
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable, Sequence
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Final
from vllm import PoolingRequestOutput, PromptType
from vllm.config import ModelConfig
from vllm.entrypoints.chat_utils import (
ChatCompletionMessageParam,
ChatTemplateConfig,
ChatTemplateContentFormatOption,
ConversationMessage,
)
from vllm.entrypoints.openai.engine.serving import RendererChatRequest, RendererRequest
from vllm.inputs import ProcessorInputs, SingletonPrompt
from vllm.renderers import BaseRenderer, merge_kwargs
from vllm.renderers.inputs import TokPrompt
from vllm.renderers.inputs.preprocess import parse_model_prompt, prompt_to_seq
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers import ToolParser
from vllm.utils.mistral import is_mistral_tokenizer
class PoolingIOProcessor:
    """Base class that turns pooling requests/prompts into rendered prompts.

    Subclasses are expected to override ``pre_process_online`` and
    ``pre_process_offline``; both raise ``NotImplementedError`` here. The
    ``_preprocess_*`` helpers delegate the actual rendering to ``renderer``.
    """

    def __init__(
        self,
        model_config: ModelConfig,
        renderer: BaseRenderer,
        chat_template_config: ChatTemplateConfig,
    ):
        # Single-worker pool; it is not used by any method of this class.
        self._tokenizer_executor = ThreadPoolExecutor(max_workers=1)

        self.model_config = model_config
        self.renderer = renderer

        # Unpack the chat-template settings onto the instance.
        self.chat_template = chat_template_config.chat_template
        self.chat_template_content_format: Final = (
            chat_template_config.chat_template_content_format
        )
        self.trust_request_chat_template = (
            chat_template_config.trust_request_chat_template
        )

    def pre_process_online(self, *args, **kwargs):
        """Build engine prompts for a served request (subclass hook)."""
        raise NotImplementedError

    async def pre_process_online_async(self, *args, **kwargs):
        """Async wrapper that simply calls ``pre_process_online``."""
        return self.pre_process_online(*args, **kwargs)

    def pre_process_offline(self, *args, **kwargs):
        """Build engine prompts for offline inference (subclass hook)."""
        raise NotImplementedError

    async def pre_process_offline_async(self, *args, **kwargs):
        """Async wrapper that simply calls ``pre_process_offline``."""
        return self.pre_process_offline(*args, **kwargs)

    def post_process(
        self, outputs: list[PoolingRequestOutput]
    ) -> list[PoolingRequestOutput]:
        """Return ``outputs`` unchanged; subclasses may transform them."""
        return outputs

    async def post_process_async(
        self, outputs: list[PoolingRequestOutput]
    ) -> list[PoolingRequestOutput]:
        """Async wrapper that simply calls ``post_process``."""
        return self.post_process(outputs)

    def create_pooling_params(self, request):
        """Ask the request object for its pooling parameters."""
        return request.to_pooling_params()

    @staticmethod
    def _collect_prompt_extras(request: Any) -> dict[str, Any]:
        """Gather the optional per-request extras that are not ``None``."""
        extras: dict[str, Any] = {}
        for key in ("mm_processor_kwargs", "cache_salt"):
            value = getattr(request, key, None)
            if value is not None:
                extras[key] = value
        return extras

    def _preprocess_completion_online(
        self,
        request: RendererRequest,
        prompt_input: str | list[str] | list[int] | list[list[int]] | None,
        prompt_embeds: bytes | list[bytes] | None,
    ) -> list[TokPrompt]:
        """Render completion-style inputs of an online request.

        Embeds, when given, are placed ahead of ``prompt_input`` entries.
        """
        raw_prompts: list[SingletonPrompt | bytes] = []
        for source in (prompt_embeds, prompt_input):  # embeds go first
            if source is not None:
                raw_prompts.extend(prompt_to_seq(source))

        parsed = []
        for item in raw_prompts:
            # bytes entries are forwarded untouched; everything else is parsed.
            if isinstance(item, bytes):
                parsed.append(item)
            else:
                parsed.append(parse_model_prompt(self.model_config, item))

        tok_params = request.build_tok_params(self.model_config)
        extras = self._collect_prompt_extras(request)
        return self.renderer.render_cmpl(
            parsed,
            tok_params,
            prompt_extras=extras,
        )

    def _preprocess_chat_online(
        self,
        request: RendererChatRequest,
        messages: list[ChatCompletionMessageParam],
        default_template: str | None,
        default_template_content_format: ChatTemplateContentFormatOption,
        default_template_kwargs: dict[str, Any] | None,
        tool_dicts: list[dict[str, Any]] | None = None,
        tool_parser: Callable[[TokenizerLike], ToolParser] | None = None,
    ) -> tuple[list[ConversationMessage], list[TokPrompt]]:
        """Render one chat conversation of an online request.

        Returns the rendered conversation and a one-element prompt list.
        ``tool_parser`` is accepted but not used in this method.
        """
        template_defaults = merge_kwargs(
            default_template_kwargs,
            {
                "tools": tool_dicts,
                "tokenize": is_mistral_tokenizer(self.renderer.tokenizer),
            },
        )

        tok_params = request.build_tok_params(self.model_config)
        base_chat_params = request.build_chat_params(
            default_template, default_template_content_format
        )
        chat_params = base_chat_params.with_defaults(template_defaults)
        extras = self._collect_prompt_extras(request)

        conversations, engine_prompts = self.renderer.render_chat(
            [messages],
            chat_params,
            tok_params,
            prompt_extras=extras,
        )
        # Exactly one conversation went in, so exactly one must come out.
        (conversation,) = conversations
        (engine_prompt,) = engine_prompts
        return conversation, [engine_prompt]

    def _preprocess_completion_offline(
        self,
        prompts: PromptType | Sequence[PromptType],
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> Sequence[ProcessorInputs]:
        """Render offline prompts using the renderer's default tok params."""
        parsed = []
        for item in prompt_to_seq(prompts):
            if isinstance(item, bytes):
                parsed.append(item)
            else:
                parsed.append(parse_model_prompt(self.model_config, item))

        overrides = tokenization_kwargs or {}
        tok_params = self.renderer.default_cmpl_tok_params.with_kwargs(**overrides)
        return self.renderer.render_cmpl(
            parsed,
            tok_params,
        )

    def _validate_chat_template(
        self,
        request_chat_template: str | None,
        chat_template_kwargs: dict[str, Any] | None,
        trust_request_chat_template: bool,
    ):
        """Reject request-supplied chat templates unless they are trusted.

        Raises:
            ValueError: a template came with the request (directly or via
                ``chat_template_kwargs``) and trust is not enabled.
        """
        if trust_request_chat_template:
            return None

        template_supplied = request_chat_template is not None
        if not template_supplied and chat_template_kwargs:
            template_supplied = chat_template_kwargs.get("chat_template") is not None

        if template_supplied:
            raise ValueError(
                "Chat template is passed with request, but "
                "--trust-request-chat-template is not set. "
                "Refused request with untrusted chat template."
            )
        return None
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import time
from collections.abc import AsyncGenerator, Mapping
from dataclasses import dataclass, field
from http import HTTPStatus
from typing import ClassVar, Generic, TypeVar
from fastapi import Request
from pydantic import ConfigDict
from starlette.datastructures import Headers
from starlette.responses import JSONResponse
from vllm import (
PoolingParams,
PoolingRequestOutput,
PromptType,
SamplingParams,
envs,
)
from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import (
ChatTemplateConfig,
ChatTemplateContentFormatOption,
)
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.engine.protocol import ErrorResponse
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.entrypoints.pooling.typing import AnyPoolingRequest, AnyPoolingResponse
from vllm.inputs import ProcessorInputs
from vllm.lora.request import LoRARequest
from vllm.renderers import BaseRenderer
from vllm.renderers.inputs.preprocess import extract_prompt_components
from vllm.sampling_params import BeamSearchParams
from vllm.tracing import (
contains_trace_headers,
extract_trace_headers,
log_tracing_disabled_warning,
)
from vllm.utils import random_uuid
from vllm.utils.async_utils import merge_async_iterators
from ...utils import create_error_response
from .io_processor import PoolingIOProcessor
PoolingRequestT = TypeVar("PoolingRequestT", bound=AnyPoolingRequest)
@dataclass(kw_only=True)
class PoolingServeContext(Generic[PoolingRequestT]):
    """Per-request state carried through the pooling serving steps."""

    # Inputs supplied when the context is created.
    request: PoolingRequestT
    raw_request: Request | None = None
    model_name: str
    request_id: str
    # Seconds since the epoch, captured when the context is constructed.
    created_time: int = field(default_factory=lambda: int(time.time()))

    # Optional state; each defaults to None or an empty list.
    lora_request: LoRARequest | None = None
    engine_prompts: list[ProcessorInputs] | None = None
    result_generator: (
        AsyncGenerator[tuple[int, PoolingRequestOutput], None] | None
    ) = None
    final_res_batch: list[PoolingRequestOutput] = field(default_factory=list)

    # Un-annotated, so this is a plain class attribute and not a dataclass field.
    model_config = ConfigDict(arbitrary_types_allowed=True)
class PoolingServing:
    """Base request handler for pooling endpoints.

    A call runs these steps in order: model check, request validation,
    adapter resolution, preprocessing, generator setup, batch collection
    and response building. Subclasses must set ``request_id_prefix`` and
    implement ``init_io_processor`` and ``_build_response``.
    """

    request_id_prefix: ClassVar[str]

    def __init__(
        self,
        engine_client: EngineClient,
        models: OpenAIServingModels,
        *,
        request_logger: RequestLogger | None,
        chat_template: str | None = None,
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
        trust_request_chat_template: bool = False,
        return_tokens_as_token_ids: bool = False,
        log_error_stack: bool = False,
    ):
        super().__init__()
        self.engine_client = engine_client
        self.models = models
        self.model_config = models.model_config
        self.max_model_len = self.model_config.max_model_len
        self.request_logger = request_logger
        self.return_tokens_as_token_ids = return_tokens_as_token_ids
        self.log_error_stack = log_error_stack
        self.chat_template_config = ChatTemplateConfig(
            chat_template=chat_template,
            chat_template_content_format=chat_template_content_format,
            trust_request_chat_template=trust_request_chat_template,
        )
        # The task-specific processor is supplied by the subclass hook.
        self.io_processor = self.init_io_processor(
            model_config=models.model_config,
            renderer=models.renderer,
            chat_template_config=self.chat_template_config,
        )

    def init_io_processor(
        self,
        model_config: ModelConfig,
        renderer: BaseRenderer,
        chat_template_config: ChatTemplateConfig,
    ) -> PoolingIOProcessor:
        """Create the IO processor for this endpoint (subclass hook)."""
        raise NotImplementedError

    async def __call__(
        self,
        request: AnyPoolingRequest,
        raw_request: Request,
    ) -> JSONResponse:
        """Serve one request and return a JSON response.

        Any exception raised by a step is converted into an error response
        carrying the status code from ``create_error_response``.
        """
        try:
            model_name = self.models.model_name()
            request_id = (
                f"{self.request_id_prefix}-{self._base_request_id(raw_request)}"
            )

            await self._check_model(request)

            ctx = PoolingServeContext(
                request=request,
                raw_request=raw_request,
                model_name=model_name,
                request_id=request_id,
            )

            self._validate_request(ctx)
            self._maybe_get_adapters(ctx)
            await self._preprocess(ctx)
            await self._prepare_generators(ctx)
            await self._collect_batch(ctx)
            response = await self._build_response(ctx)
            return JSONResponse(content=response.model_dump())
        except Exception as e:
            error_response = create_error_response(e)
            return JSONResponse(
                content=error_response.model_dump(),
                status_code=error_response.error.code,
            )

    async def _preprocess(
        self,
        ctx: PoolingServeContext,
    ):
        """Store the IO processor's engine prompts on ``ctx``."""
        ctx.engine_prompts = await self.io_processor.pre_process_online_async(
            ctx.request
        )

    async def _prepare_generators(
        self,
        ctx: PoolingServeContext,
    ):
        """Start one ``encode`` call per engine prompt and merge the streams.

        Raises:
            ValueError: ``ctx.engine_prompts`` has not been populated.
        """
        if ctx.engine_prompts is None:
            raise ValueError("Engine prompts not available")

        generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []

        trace_headers = (
            None
            if ctx.raw_request is None
            else await self._get_trace_headers(ctx.raw_request.headers)
        )

        # One params object is shared by every prompt of the request.
        pooling_params = self.io_processor.create_pooling_params(ctx.request)

        for i, engine_prompt in enumerate(ctx.engine_prompts):
            # Each prompt gets its own id derived from the request id.
            request_id_item = f"{ctx.request_id}-{i}"

            self._log_inputs(
                request_id_item,
                engine_prompt,
                params=pooling_params,
                lora_request=ctx.lora_request,
            )

            generator = self.engine_client.encode(
                engine_prompt,
                pooling_params,
                request_id_item,
                lora_request=ctx.lora_request,
                trace_headers=trace_headers,
                priority=getattr(ctx.request, "priority", 0),
            )
            generators.append(generator)

        ctx.result_generator = merge_async_iterators(*generators)

    async def _collect_batch(
        self,
        ctx: PoolingServeContext,
    ):
        """Drain the merged generator into ``ctx.final_res_batch``.

        Results are stored at the index they are yielded with, so the final
        list follows the order of ``ctx.engine_prompts``.

        Raises:
            ValueError: prompts or generator are missing, or some prompt
                produced no result.
        """
        if ctx.engine_prompts is None:
            raise ValueError("Engine prompts not available")

        if ctx.result_generator is None:
            raise ValueError("Result generator not available")

        num_prompts = len(ctx.engine_prompts)
        final_res_batch: list[PoolingRequestOutput | None]
        final_res_batch = [None] * num_prompts

        async for i, res in ctx.result_generator:
            final_res_batch[i] = res

        if None in final_res_batch:
            raise ValueError("Failed to generate results for all prompts")

        ctx.final_res_batch = [res for res in final_res_batch if res is not None]

    async def _build_response(
        self,
        ctx: PoolingServeContext,
    ) -> AnyPoolingResponse:
        """Build the endpoint-specific response (subclass hook)."""
        raise NotImplementedError

    @staticmethod
    def _base_request_id(
        raw_request: Request | None, default: str | None = None
    ) -> str | None:
        """Pulls the request id to use from a header, if provided"""
        if raw_request is not None and (
            (req_id := raw_request.headers.get("X-Request-Id")) is not None
        ):
            return req_id

        return random_uuid() if default is None else default

    def _is_model_supported(self, model_name: str | None) -> bool:
        """Return True for an empty model name or a base model name."""
        if not model_name:
            return True
        return self.models.is_base_model(model_name)

    async def _check_model(
        self,
        request: AnyPoolingRequest,
    ) -> ErrorResponse | None:
        """Check the requested model, resolving a LoRA at runtime if allowed.

        Every non-raising path returns ``None``.

        Raises:
            ValueError: runtime LoRA resolution returned an ``ErrorResponse``
                whose code is BAD_REQUEST.
        """
        if self._is_model_supported(request.model):
            return None
        if request.model in self.models.lora_requests:
            return None
        if (
            envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING
            and request.model
            and (load_result := await self.models.resolve_lora(request.model))
        ):
            if isinstance(load_result, LoRARequest):
                return None
            if (
                isinstance(load_result, ErrorResponse)
                and load_result.error.code == HTTPStatus.BAD_REQUEST.value
            ):
                raise ValueError(load_result.error.message)
        return None

    def _validate_request(self, ctx: PoolingServeContext) -> None:
        """Reject a ``truncate_prompt_tokens`` larger than ``max_model_len``."""
        truncate_prompt_tokens = getattr(ctx.request, "truncate_prompt_tokens", None)

        if (
            truncate_prompt_tokens is not None
            and truncate_prompt_tokens > self.max_model_len
        ):
            raise ValueError(
                "truncate_prompt_tokens value is "
                "greater than max_model_len."
                " Please, select a smaller truncation size."
            )
        return None

    async def _get_trace_headers(
        self,
        headers: Headers,
    ) -> Mapping[str, str] | None:
        """Extract trace headers when tracing is enabled, else return None."""
        is_tracing_enabled = await self.engine_client.is_tracing_enabled()

        if is_tracing_enabled:
            return extract_trace_headers(headers)

        if contains_trace_headers(headers):
            log_tracing_disabled_warning()

        return None

    def _maybe_get_adapters(
        self,
        ctx: PoolingServeContext,
        supports_default_mm_loras: bool = False,
    ):
        """Resolve the LoRA adapter for the request and store it on ``ctx``.

        Raises:
            ValueError: the requested model is neither a registered LoRA
                adapter nor a supported base model.
        """
        request = ctx.request
        lora_requests = self.models.lora_requests
        if request.model in lora_requests:
            ctx.lora_request = lora_requests[request.model]
            # A registered adapter name is a valid model. Return here: the
            # base-model check below rejects adapter names and would raise.
            return None

        # Currently only support default modality specific loras
        # if we have exactly one lora matched on the request.
        if supports_default_mm_loras:
            default_mm_lora = self._get_active_default_mm_loras(request)
            if default_mm_lora is not None:
                ctx.lora_request = default_mm_lora

        if self._is_model_supported(request.model):
            return None

        # if _check_model has been called earlier, this will be unreachable
        raise ValueError(f"The model `{request.model}` does not exist.")

    def _get_active_default_mm_loras(
        self, request: AnyPoolingRequest
    ) -> LoRARequest | None:
        """Determine if there are any active default multimodal loras."""
        # TODO: Currently this is only enabled for chat completions
        # to be better aligned with only being enabled for .generate
        # when run offline. It would be nice to support additional
        # tasks types in the future.
        message_types = self._get_message_types(request)
        default_mm_loras = set()

        for lora in self.models.lora_requests.values():
            # Best effort match for default multimodal lora adapters;
            # There is probably a better way to do this, but currently
            # this matches against the set of 'types' in any content lists
            # up until '_', e.g., to match audio_url -> audio
            if lora.lora_name in message_types:
                default_mm_loras.add(lora)

        # Currently only support default modality specific loras if
        # we have exactly one lora matched on the request.
        if len(default_mm_loras) == 1:
            return default_mm_loras.pop()
        return None

    def _get_message_types(self, request: AnyPoolingRequest) -> set[str]:
        """Retrieve the set of types from message content dicts up
        until `_`; we use this to match potential multimodal data
        with default per modality loras.
        """
        message_types: set[str] = set()

        if not hasattr(request, "messages"):
            return message_types

        messages = request.messages
        if messages is None or isinstance(messages, (str, bytes)):
            return message_types

        for message in messages:
            if (
                isinstance(message, dict)
                and "content" in message
                and isinstance(message["content"], list)
            ):
                for content_dict in message["content"]:
                    if "type" in content_dict:
                        message_types.add(content_dict["type"].split("_")[0])
        return message_types

    def _log_inputs(
        self,
        request_id: str,
        inputs: PromptType | ProcessorInputs,
        params: SamplingParams | PoolingParams | BeamSearchParams | None,
        lora_request: LoRARequest | None,
    ) -> None:
        """Forward the prompt components to the request logger, if any."""
        if self.request_logger is None:
            return

        components = extract_prompt_components(self.model_config, inputs)

        self.request_logger.log_inputs(
            request_id,
            components.text,
            components.token_ids,
            components.embeds,
            params=params,
            lora_request=lora_request,
        )
...@@ -3,16 +3,17 @@ ...@@ -3,16 +3,17 @@
from fastapi import APIRouter, Depends, Request from fastapi import APIRouter, Depends, Request
from starlette.responses import JSONResponse from starlette.responses import JSONResponse
from typing_extensions import assert_never
from vllm.entrypoints.openai.engine.protocol import ErrorResponse
from vllm.entrypoints.openai.utils import validate_json_request from vllm.entrypoints.openai.utils import validate_json_request
from vllm.entrypoints.pooling.classify.protocol import ( from vllm.entrypoints.pooling.classify.protocol import (
ClassificationRequest, ClassificationRequest,
ClassificationResponse,
) )
from vllm.entrypoints.pooling.classify.serving import ServingClassification from vllm.entrypoints.pooling.classify.serving import ServingClassification
from vllm.entrypoints.utils import load_aware_call, with_cancellation from vllm.entrypoints.utils import (
create_error_response,
load_aware_call,
with_cancellation,
)
router = APIRouter() router = APIRouter()
...@@ -24,25 +25,17 @@ def classify(request: Request) -> ServingClassification | None: ...@@ -24,25 +25,17 @@ def classify(request: Request) -> ServingClassification | None:
@router.post("/classify", dependencies=[Depends(validate_json_request)]) @router.post("/classify", dependencies=[Depends(validate_json_request)])
@with_cancellation @with_cancellation
@load_aware_call @load_aware_call
async def create_classify(request: ClassificationRequest, raw_request: Request): async def create_classify(
request: ClassificationRequest, raw_request: Request
) -> JSONResponse:
handler = classify(raw_request) handler = classify(raw_request)
if handler is None: if handler is None:
base_server = raw_request.app.state.openai_serving_tokenization error_response = create_error_response(
return base_server.create_error_response(
message="The model does not support Classification API" message="The model does not support Classification API"
) )
try:
generator = await handler.create_classify(request, raw_request)
except Exception as e:
generator = handler.create_error_response(e)
if isinstance(generator, ErrorResponse):
return JSONResponse( return JSONResponse(
content=generator.model_dump(), status_code=generator.error.code content=error_response.model_dump(),
status_code=error_response.error.code,
) )
elif isinstance(generator, ClassificationResponse): return await handler(request, raw_request)
return JSONResponse(content=generator.model_dump())
assert_never(generator)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Sequence
from typing import Any
from vllm import PromptType
from vllm.entrypoints.pooling.base.io_processor import PoolingIOProcessor
from vllm.entrypoints.pooling.classify.protocol import (
ClassificationChatRequest,
ClassificationCompletionRequest,
)
from vllm.inputs import ProcessorInputs
from vllm.renderers.inputs import TokPrompt
class ClassifyIOProcessor(PoolingIOProcessor):
    """IO processor that prepares prompts for classification requests."""

    def pre_process_online(
        self, request: ClassificationCompletionRequest | ClassificationChatRequest
    ) -> list[TokPrompt] | None:
        """Render a chat or completion classification request.

        Raises:
            ValueError: ``request`` is neither of the two supported types, or
                a chat request carries an untrusted chat template.
        """
        if isinstance(request, ClassificationChatRequest):
            # Refuse request-supplied templates before rendering anything.
            self._validate_chat_template(
                request_chat_template=request.chat_template,
                chat_template_kwargs=request.chat_template_kwargs,
                trust_request_chat_template=self.trust_request_chat_template,
            )
            rendered = self._preprocess_chat_online(
                request,
                request.messages,
                default_template=self.chat_template,
                default_template_content_format=self.chat_template_content_format,
                default_template_kwargs=None,
            )
            # Only the engine prompts are needed; the conversation is dropped.
            _, chat_prompts = rendered
            return chat_prompts

        if isinstance(request, ClassificationCompletionRequest):
            return self._preprocess_completion_online(
                request,
                prompt_input=request.input,
                prompt_embeds=None,
            )

        raise ValueError("Invalid classification request type")

    def pre_process_offline(
        self,
        prompts: PromptType | Sequence[PromptType],
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> Sequence[ProcessorInputs]:
        """Render offline prompts through the shared completion helper."""
        return self._preprocess_completion_offline(
            prompts=prompts, tokenization_kwargs=tokenization_kwargs
        )
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Final, TypeAlias from typing import TypeAlias
import jinja2
import numpy as np import numpy as np
from fastapi import Request
from vllm import ClassificationOutput
from vllm.engine.protocol import EngineClient from vllm.config import ModelConfig
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption from vllm.entrypoints.chat_utils import ChatTemplateConfig
from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.openai.engine.protocol import UsageInfo
from vllm.entrypoints.openai.engine.protocol import ErrorResponse, UsageInfo from vllm.entrypoints.pooling.base.serving import PoolingServeContext, PoolingServing
from vllm.entrypoints.openai.engine.serving import OpenAIServing, ServeContext from vllm.logger import init_logger
from vllm.entrypoints.openai.models.serving import OpenAIServingModels from vllm.renderers import BaseRenderer
from vllm.entrypoints.pooling.classify.protocol import (
ClassificationChatRequest, from .io_processor import ClassifyIOProcessor
ClassificationCompletionRequest, from .protocol import (
ClassificationData, ClassificationData,
ClassificationRequest, ClassificationRequest,
ClassificationResponse, ClassificationResponse,
) )
from vllm.logger import init_logger
from vllm.outputs import ClassificationOutput
logger = init_logger(__name__) logger = init_logger(__name__)
ClassificationServeContext: TypeAlias = ServeContext[ClassificationRequest] ClassificationServeContext: TypeAlias = PoolingServeContext[ClassificationRequest]
class ServingClassification(OpenAIServing): class ServingClassification(PoolingServing):
request_id_prefix = "classify" request_id_prefix = "classify"
def __init__( def init_io_processor(
self, self,
engine_client: EngineClient, model_config: ModelConfig,
models: OpenAIServingModels, renderer: BaseRenderer,
*, chat_template_config: ChatTemplateConfig,
request_logger: RequestLogger | None, ) -> ClassifyIOProcessor:
chat_template: str | None = None, return ClassifyIOProcessor(
chat_template_content_format: ChatTemplateContentFormatOption = "auto", model_config=model_config,
trust_request_chat_template: bool = False, renderer=renderer,
log_error_stack: bool = False, chat_template_config=chat_template_config,
) -> None:
super().__init__(
engine_client=engine_client,
models=models,
request_logger=request_logger,
log_error_stack=log_error_stack,
) )
self.chat_template = chat_template async def _build_response(
self.chat_template_content_format: Final = chat_template_content_format
self.trust_request_chat_template = trust_request_chat_template
async def _preprocess(
self, self,
ctx: ClassificationServeContext, ctx: ClassificationServeContext,
) -> ErrorResponse | None: ) -> ClassificationResponse:
""" final_res_batch_checked = await self.io_processor.post_process_async(
Process classification inputs: tokenize text, resolve adapters, ctx.final_res_batch
and prepare model-specific inputs. )
"""
try:
ctx.lora_request = self._maybe_get_adapters(ctx.request)
if isinstance(ctx.request, ClassificationChatRequest):
error_check_ret = self._validate_chat_template(
request_chat_template=ctx.request.chat_template,
chat_template_kwargs=ctx.request.chat_template_kwargs,
trust_request_chat_template=self.trust_request_chat_template,
)
if error_check_ret:
return error_check_ret
_, ctx.engine_prompts = await self._preprocess_chat(
ctx.request,
ctx.request.messages,
default_template=self.chat_template,
default_template_content_format=self.chat_template_content_format,
default_template_kwargs=None,
)
elif isinstance(ctx.request, ClassificationCompletionRequest):
ctx.engine_prompts = await self._preprocess_completion(
ctx.request,
prompt_input=ctx.request.input,
prompt_embeds=None,
)
else:
return self.create_error_response("Invalid classification request type")
return None
except (ValueError, TypeError, jinja2.TemplateError) as e:
logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(str(e))
def _build_response(
self,
ctx: ClassificationServeContext,
) -> ClassificationResponse | ErrorResponse:
"""
Convert model outputs to a formatted classification response
with probabilities and labels.
"""
id2label = getattr(self.model_config.hf_config, "id2label", {})
items: list[ClassificationData] = [] id2label = getattr(self.model_config.hf_config, "id2label", {})
num_prompt_tokens = 0 num_prompt_tokens = 0
items: list[ClassificationData] = []
final_res_batch_checked = ctx.final_res_batch
for idx, final_res in enumerate(final_res_batch_checked): for idx, final_res in enumerate(final_res_batch_checked):
classify_res = ClassificationOutput.from_base(final_res.outputs) classify_res = ClassificationOutput.from_base(final_res.outputs)
...@@ -141,20 +82,3 @@ class ServingClassification(OpenAIServing): ...@@ -141,20 +82,3 @@ class ServingClassification(OpenAIServing):
data=items, data=items,
usage=usage, usage=usage,
) )
async def create_classify(
self,
request: ClassificationRequest,
raw_request: Request,
) -> ClassificationResponse | ErrorResponse:
model_name = self.models.model_name()
request_id = f"{self.request_id_prefix}-{self._base_request_id(raw_request)}"
ctx = ClassificationServeContext(
request=request,
raw_request=raw_request,
model_name=model_name,
request_id=request_id,
)
return await self.handle(ctx) # type: ignore[return-value]
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.config import ModelConfig
from vllm.entrypoints.chat_utils import ChatTemplateConfig
from vllm.entrypoints.pooling.base.io_processor import PoolingIOProcessor
from vllm.renderers import BaseRenderer
from vllm.tasks import SupportedTask
def init_pooling_io_processors(
    supported_tasks: tuple[SupportedTask, ...],
    model_config: ModelConfig,
    renderer: BaseRenderer,
    chat_template_config: ChatTemplateConfig,
) -> dict[str, PoolingIOProcessor]:
    """Create the pooling I/O processors for the supported tasks.

    Returns a new dict keyed by task name. Currently only the ``"classify"``
    task gets an entry; when it is absent from ``supported_tasks`` the dict
    is empty.
    """
    # Every processor built here receives the same construction arguments.
    shared_kwargs = {
        "model_config": model_config,
        "renderer": renderer,
        "chat_template_config": chat_template_config,
    }
    processors: dict[str, PoolingIOProcessor] = {}

    if "classify" in supported_tasks:
        # Imported only when the task is actually enabled.
        from vllm.entrypoints.pooling.classify.io_processor import (
            ClassifyIOProcessor,
        )

        processors["classify"] = ClassifyIOProcessor(**shared_kwargs)

    return processors
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import TypeAlias
from vllm.entrypoints.pooling.classify.protocol import (
ClassificationChatRequest,
ClassificationCompletionRequest,
ClassificationResponse,
)
from vllm.entrypoints.pooling.embed.protocol import (
EmbeddingBytesResponse,
EmbeddingChatRequest,
EmbeddingCompletionRequest,
EmbeddingResponse,
)
from vllm.entrypoints.pooling.pooling.protocol import (
IOProcessorRequest,
PoolingChatRequest,
PoolingCompletionRequest,
PoolingResponse,
)
from vllm.entrypoints.pooling.score.protocol import (
RerankRequest,
ScoreRequest,
ScoreResponse,
)
# Union of the pooling request models grouped as "completion-like":
# the embedding, classification and pooling completion requests together
# with the rerank and score requests.
PoolingCompletionLikeRequest: TypeAlias = (
EmbeddingCompletionRequest
| ClassificationCompletionRequest
| RerankRequest
| ScoreRequest
| PoolingCompletionRequest
)

# Union of the pooling request models grouped as "chat-like":
# the embedding, classification and pooling chat requests.
PoolingChatLikeRequest: TypeAlias = (
EmbeddingChatRequest | ClassificationChatRequest | PoolingChatRequest
)

# Any pooling request: either of the two unions above, or an
# IOProcessorRequest.
AnyPoolingRequest: TypeAlias = (
PoolingCompletionLikeRequest | PoolingChatLikeRequest | IOProcessorRequest
)

# Any response model produced by the classification, embedding (regular or
# bytes), pooling and score endpoints' protocols.
AnyPoolingResponse: TypeAlias = (
ClassificationResponse
| EmbeddingResponse
| EmbeddingBytesResponse
| PoolingResponse
| ScoreResponse
)
...@@ -13,6 +13,7 @@ from fastapi.responses import JSONResponse, Response ...@@ -13,6 +13,7 @@ from fastapi.responses import JSONResponse, Response
from vllm.entrypoints.openai.engine.protocol import ErrorResponse from vllm.entrypoints.openai.engine.protocol import ErrorResponse
from vllm.entrypoints.openai.engine.serving import OpenAIServing from vllm.entrypoints.openai.engine.serving import OpenAIServing
from vllm.entrypoints.openai.utils import validate_json_request from vllm.entrypoints.openai.utils import validate_json_request
from vllm.entrypoints.pooling.base.serving import PoolingServing
from vllm.entrypoints.serve.instrumentator.basic import base from vllm.entrypoints.serve.instrumentator.basic import base
from vllm.entrypoints.serve.instrumentator.health import health from vllm.entrypoints.serve.instrumentator.health import health
from vllm.tasks import POOLING_TASKS, SupportedTask from vllm.tasks import POOLING_TASKS, SupportedTask
...@@ -20,7 +21,7 @@ from vllm.tasks import POOLING_TASKS, SupportedTask ...@@ -20,7 +21,7 @@ from vllm.tasks import POOLING_TASKS, SupportedTask
# TODO: RequestType = TypeForm[BaseModel] when recognized by type checkers # TODO: RequestType = TypeForm[BaseModel] when recognized by type checkers
# (requires typing_extensions >= 4.13) # (requires typing_extensions >= 4.13)
RequestType = Any RequestType = Any
GetHandlerFn = Callable[[Request], OpenAIServing | None] GetHandlerFn = Callable[[Request], OpenAIServing | PoolingServing | None]
EndpointFn = Callable[[RequestType, Request], Awaitable[Any]] EndpointFn = Callable[[RequestType, Request], Awaitable[Any]]
......
...@@ -5,7 +5,10 @@ import asyncio ...@@ -5,7 +5,10 @@ import asyncio
import dataclasses import dataclasses
import functools import functools
import os import os
import sys
import traceback
from argparse import Namespace from argparse import Namespace
from http import HTTPStatus
from logging import Logger from logging import Logger
from string import Template from string import Template
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
...@@ -17,17 +20,23 @@ from starlette.background import BackgroundTask, BackgroundTasks ...@@ -17,17 +20,23 @@ from starlette.background import BackgroundTask, BackgroundTasks
from vllm import envs from vllm import envs
from vllm.engine.arg_utils import EngineArgs from vllm.engine.arg_utils import EngineArgs
from vllm.exceptions import VLLMValidationError
from vllm.logger import current_formatter_type, init_logger from vllm.logger import current_formatter_type, init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.argparse_utils import FlexibleArgumentParser from vllm.utils.argparse_utils import FlexibleArgumentParser
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.entrypoints.openai.engine.protocol import StreamOptions from vllm.entrypoints.openai.engine.protocol import (
ErrorInfo,
ErrorResponse,
StreamOptions,
)
from vllm.entrypoints.openai.models.protocol import LoRAModulePath from vllm.entrypoints.openai.models.protocol import LoRAModulePath
else: else:
StreamOptions = object ErrorResponse = object
ErrorInfo = object
LoRAModulePath = object LoRAModulePath = object
StreamOptions = object
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -291,3 +300,59 @@ def log_version_and_model(lgr: Logger, version: str, model_name: str) -> None: ...@@ -291,3 +300,59 @@ def log_version_and_model(lgr: Logger, version: str, model_name: str) -> None:
message = logo_template.substitute(colors) message = logo_template.substitute(colors)
lgr.info(message, version, model_name) lgr.info(message, version, model_name)
def create_error_response(
    message: str | Exception,
    err_type: str = "BadRequestError",
    status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
    param: str | None = None,
    log_error_stack: bool = False,
) -> "ErrorResponse":
    """Build an ``ErrorResponse`` from a message string or an exception.

    Args:
        message: Error text, or the exception to report. When an exception
            is given, ``err_type``, ``status_code`` and ``param`` are
            derived from the exception class and the passed values are
            ignored; the message becomes ``str(exception)``.
        err_type: Error type string used when ``message`` is a plain string.
        status_code: HTTP status used when ``message`` is a plain string.
        param: Name of the offending parameter, if any.
        log_error_stack: If true, print the active exception's traceback,
            or the current call stack when no exception is being handled.

    Returns:
        An ``ErrorResponse`` whose message has been passed through
        ``sanitize_message``.
    """
    # Imported at call time; the module only imports these names under
    # TYPE_CHECKING.
    from vllm.entrypoints.openai.engine.protocol import ErrorInfo, ErrorResponse

    if isinstance(message, Exception):
        exc = message

        if isinstance(exc, VLLMValidationError):
            err_type = "BadRequestError"
            status_code = HTTPStatus.BAD_REQUEST
            param = exc.parameter
        elif isinstance(exc, NotImplementedError):
            # Must be tested before RuntimeError: NotImplementedError is a
            # RuntimeError subclass and would otherwise be reported as a
            # bad request instead of 501.
            err_type = "NotImplementedError"
            status_code = HTTPStatus.NOT_IMPLEMENTED
            param = None
        elif isinstance(exc, (ValueError, TypeError, RuntimeError, OverflowError)):
            # Common validation errors from user input
            err_type = "BadRequestError"
            status_code = HTTPStatus.BAD_REQUEST
            param = None
        elif any(klass.__name__ == "TemplateError" for klass in type(exc).__mro__):
            # jinja2.TemplateError and its subclasses, matched by class name
            # along the MRO to avoid importing jinja2.
            err_type = "BadRequestError"
            status_code = HTTPStatus.BAD_REQUEST
            param = None
        else:
            err_type = "InternalServerError"
            status_code = HTTPStatus.INTERNAL_SERVER_ERROR
            param = None

        message = str(exc)

    if log_error_stack:
        exc_type, _, _ = sys.exc_info()
        if exc_type is not None:
            # Called from inside an ``except`` block: show that traceback.
            traceback.print_exc()
        else:
            # No exception in flight: show where the error response was built.
            traceback.print_stack()

    return ErrorResponse(
        error=ErrorInfo(
            message=sanitize_message(message),
            type=err_type,
            code=status_code.value,
            param=param,
        )
    )
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment