Unverified commit c831911b, authored by Cyrus Leung and committed by GitHub
Browse files

[Frontend] Reduce mixin usage in serving pooling (#33101)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent 157caf51
......@@ -64,13 +64,12 @@ from vllm.entrypoints.openai.translations.protocol import (
from vllm.entrypoints.pooling.classify.protocol import (
ClassificationChatRequest,
ClassificationCompletionRequest,
ClassificationRequest,
ClassificationResponse,
)
from vllm.entrypoints.pooling.embed.protocol import (
EmbeddingBytesResponse,
EmbeddingChatRequest,
EmbeddingCompletionRequest,
EmbeddingRequest,
EmbeddingResponse,
)
from vllm.entrypoints.pooling.pooling.protocol import (
......@@ -170,6 +169,7 @@ AnyResponse: TypeAlias = (
CompletionResponse
| ChatCompletionResponse
| EmbeddingResponse
| EmbeddingBytesResponse
| TranscriptionResponse
| TokenizeResponse
| PoolingResponse
......@@ -183,51 +183,21 @@ RequestT = TypeVar("RequestT", bound=AnyRequest)
@dataclass(kw_only=True)
class RequestProcessingMixin:
"""
Mixin for request processing,
handling prompt preparation and engine input.
"""
engine_prompts: list[TokensPrompt] | None = field(default_factory=list)
@dataclass(kw_only=True)
class ResponseGenerationMixin:
"""
Mixin for response generation,
managing result generators and final batch results.
"""
result_generator: (
AsyncGenerator[tuple[int, RequestOutput | PoolingRequestOutput], None] | None
) = None
final_res_batch: list[RequestOutput | PoolingRequestOutput] = field(
default_factory=list
)
model_config = ConfigDict(arbitrary_types_allowed=True)
@dataclass(kw_only=True)
class ServeContext(RequestProcessingMixin, ResponseGenerationMixin, Generic[RequestT]):
class ServeContext(Generic[RequestT]):
request: RequestT
raw_request: Request | None = None
model_name: str
request_id: str
created_time: int = field(default_factory=lambda: int(time.time()))
lora_request: LoRARequest | None = None
engine_prompts: list[TokensPrompt] | None = None
result_generator: AsyncGenerator[tuple[int, PoolingRequestOutput], None] | None = (
None
)
final_res_batch: list[PoolingRequestOutput] = field(default_factory=list)
@dataclass(kw_only=True)
class ClassificationServeContext(ServeContext[ClassificationRequest]):
pass
@dataclass(kw_only=True)
class EmbeddingServeContext(ServeContext[EmbeddingRequest]):
chat_template: str | None = None
chat_template_content_format: ChatTemplateContentFormatOption
model_config = ConfigDict(arbitrary_types_allowed=True)
class OpenAIServing:
......@@ -605,10 +575,7 @@ class OpenAIServing:
self,
ctx: ServeContext,
) -> AnyResponse | ErrorResponse:
generation: AsyncGenerator[AnyResponse | ErrorResponse, None]
generation = self._pipeline(ctx)
async for response in generation:
async for response in self._pipeline(ctx):
return response
return self.create_error_response("No response yielded from pipeline")
......@@ -667,9 +634,7 @@ class OpenAIServing:
ctx: ServeContext,
) -> ErrorResponse | None:
"""Schedule the request and get the result generator."""
generators: list[
AsyncGenerator[RequestOutput | PoolingRequestOutput, None]
] = []
generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
try:
trace_headers = (
......@@ -723,7 +688,7 @@ class OpenAIServing:
return self.create_error_response("Engine prompts not available")
num_prompts = len(ctx.engine_prompts)
final_res_batch: list[RequestOutput | PoolingRequestOutput | None]
final_res_batch: list[PoolingRequestOutput | None]
final_res_batch = [None] * num_prompts
if ctx.result_generator is None:
......@@ -1011,7 +976,7 @@ class OpenAIServing:
def _validate_input(
self,
request: AnyRequest,
request: object,
input_ids: list[int],
input_text: str,
) -> TokensPrompt:
......
......@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from http import HTTPStatus
from typing import cast
from typing import Final, cast
import jinja2
import numpy as np
......@@ -11,18 +11,8 @@ from fastapi import Request
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.chat_completion.protocol import (
ChatCompletionRequest,
)
from vllm.entrypoints.openai.engine.protocol import (
ErrorResponse,
UsageInfo,
)
from vllm.entrypoints.openai.engine.serving import (
ClassificationServeContext,
OpenAIServing,
ServeContext,
)
from vllm.entrypoints.openai.engine.protocol import ErrorResponse, UsageInfo
from vllm.entrypoints.openai.engine.serving import OpenAIServing, ServeContext
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.entrypoints.pooling.classify.protocol import (
ClassificationChatRequest,
......@@ -39,60 +29,68 @@ from vllm.pooling_params import PoolingParams
logger = init_logger(__name__)
class ClassificationMixin(OpenAIServing):
chat_template: str | None
chat_template_content_format: ChatTemplateContentFormatOption
trust_request_chat_template: bool
ClassificationServeContext = ServeContext[ClassificationRequest]
class ServingClassification(OpenAIServing):
request_id_prefix = "classify"
def __init__(
self,
engine_client: EngineClient,
models: OpenAIServingModels,
*,
request_logger: RequestLogger | None,
chat_template: str | None = None,
chat_template_content_format: ChatTemplateContentFormatOption = "auto",
trust_request_chat_template: bool = False,
log_error_stack: bool = False,
) -> None:
super().__init__(
engine_client=engine_client,
models=models,
request_logger=request_logger,
log_error_stack=log_error_stack,
)
self.chat_template = chat_template
self.chat_template_content_format: Final = chat_template_content_format
self.trust_request_chat_template = trust_request_chat_template
async def _preprocess(
self,
ctx: ServeContext,
ctx: ClassificationServeContext,
) -> ErrorResponse | None:
"""
Process classification inputs: tokenize text, resolve adapters,
and prepare model-specific inputs.
"""
ctx = cast(ClassificationServeContext, ctx)
try:
request_obj = ctx.request
if isinstance(request_obj, ClassificationChatRequest):
chat_request = request_obj
messages = chat_request.messages
trust_request_chat_template = getattr(
self,
"trust_request_chat_template",
False,
)
ret = self._validate_chat_template(
request_chat_template=chat_request.chat_template,
chat_template_kwargs=chat_request.chat_template_kwargs,
trust_request_chat_template=trust_request_chat_template,
ctx.lora_request = self._maybe_get_adapters(ctx.request)
if isinstance(ctx.request, ClassificationChatRequest):
error_check_ret = self._validate_chat_template(
request_chat_template=ctx.request.chat_template,
chat_template_kwargs=ctx.request.chat_template_kwargs,
trust_request_chat_template=self.trust_request_chat_template,
)
if ret:
return ret
if error_check_ret:
return error_check_ret
_, engine_prompts = await self._preprocess_chat(
cast(ChatCompletionRequest, chat_request),
ctx.request,
self.renderer,
messages,
chat_template=(
chat_request.chat_template
or getattr(self, "chat_template", None)
),
chat_template_content_format=cast(
ChatTemplateContentFormatOption,
getattr(self, "chat_template_content_format", "auto"),
),
add_generation_prompt=chat_request.add_generation_prompt,
continue_final_message=chat_request.continue_final_message,
add_special_tokens=chat_request.add_special_tokens,
ctx.request.messages,
chat_template=ctx.request.chat_template or self.chat_template,
chat_template_content_format=self.chat_template_content_format,
add_generation_prompt=ctx.request.add_generation_prompt,
continue_final_message=ctx.request.continue_final_message,
add_special_tokens=ctx.request.add_special_tokens,
)
ctx.engine_prompts = engine_prompts
elif isinstance(request_obj, ClassificationCompletionRequest):
completion_request = request_obj
input_data = completion_request.input
elif isinstance(ctx.request, ClassificationCompletionRequest):
input_data = ctx.request.input
if input_data in (None, ""):
return self.create_error_response(
"Input or messages must be provided",
......@@ -106,13 +104,10 @@ class ClassificationMixin(OpenAIServing):
prompt_input = cast(str | list[str], input_data)
ctx.engine_prompts = await renderer.render_prompt(
prompt_or_prompts=prompt_input,
config=self._build_render_config(completion_request),
config=self._build_render_config(ctx.request),
)
else:
return self.create_error_response(
"Invalid classification request type",
status_code=HTTPStatus.BAD_REQUEST,
)
return self.create_error_response("Invalid classification request type")
return None
......@@ -122,13 +117,14 @@ class ClassificationMixin(OpenAIServing):
def _build_response(
self,
ctx: ServeContext,
ctx: ClassificationServeContext,
) -> ClassificationResponse | ErrorResponse:
"""
Convert model outputs to a formatted classification response
with probabilities and labels.
"""
ctx = cast(ClassificationServeContext, ctx)
id2label = getattr(self.model_config.hf_config, "id2label", {})
items: list[ClassificationData] = []
num_prompt_tokens = 0
......@@ -139,9 +135,7 @@ class ClassificationMixin(OpenAIServing):
probs = classify_res.probs
predicted_index = int(np.argmax(probs))
label = getattr(self.model_config.hf_config, "id2label", {}).get(
predicted_index
)
label = id2label.get(predicted_index)
item = ClassificationData(
index=idx,
......@@ -174,32 +168,6 @@ class ClassificationMixin(OpenAIServing):
add_special_tokens=request.add_special_tokens,
)
class ServingClassification(ClassificationMixin):
request_id_prefix = "classify"
def __init__(
self,
engine_client: EngineClient,
models: OpenAIServingModels,
*,
request_logger: RequestLogger | None,
chat_template: str | None = None,
chat_template_content_format: ChatTemplateContentFormatOption = "auto",
trust_request_chat_template: bool = False,
log_error_stack: bool = False,
) -> None:
super().__init__(
engine_client=engine_client,
models=models,
request_logger=request_logger,
log_error_stack=log_error_stack,
)
self.chat_template = chat_template
self.chat_template_content_format = chat_template_content_format
self.trust_request_chat_template = trust_request_chat_template
async def create_classify(
self,
request: ClassificationRequest,
......@@ -215,11 +183,11 @@ class ServingClassification(ClassificationMixin):
request_id=request_id,
)
return await super().handle(ctx) # type: ignore
return await self.handle(ctx) # type: ignore[return-value]
def _create_pooling_params(
self,
ctx: ServeContext[ClassificationRequest],
ctx: ClassificationServeContext,
) -> PoolingParams | ErrorResponse:
pooling_params = super()._create_pooling_params(ctx)
if isinstance(pooling_params, ErrorResponse):
......
......@@ -6,21 +6,13 @@ from typing import Any, Final, cast
import torch
from fastapi import Request
from fastapi.responses import Response
from typing_extensions import assert_never, override
from typing_extensions import assert_never
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.engine.protocol import (
ErrorResponse,
UsageInfo,
)
from vllm.entrypoints.openai.engine.serving import (
EmbeddingServeContext,
OpenAIServing,
ServeContext,
)
from vllm.entrypoints.openai.engine.protocol import ErrorResponse, UsageInfo
from vllm.entrypoints.openai.engine.serving import OpenAIServing, ServeContext
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.entrypoints.pooling.embed.protocol import (
EmbeddingBytesResponse,
......@@ -33,19 +25,11 @@ from vllm.entrypoints.pooling.embed.protocol import (
from vllm.entrypoints.renderer import RenderConfig
from vllm.inputs.data import TokensPrompt
from vllm.logger import init_logger
from vllm.outputs import (
EmbeddingRequestOutput,
PoolingOutput,
PoolingRequestOutput,
RequestOutput,
)
from vllm.outputs import PoolingOutput, PoolingRequestOutput
from vllm.pooling_params import PoolingParams
from vllm.utils.async_utils import merge_async_iterators
from vllm.utils.collection_utils import chunk_list
from vllm.utils.serial_utils import (
EmbedDType,
EncodingFormat,
Endianness,
encode_pooling_bytes,
encode_pooling_output,
)
......@@ -53,9 +37,33 @@ from vllm.utils.serial_utils import (
logger = init_logger(__name__)
class EmbeddingMixin(OpenAIServing):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
EmbeddingServeContext = ServeContext[EmbeddingRequest]
class OpenAIServingEmbedding(OpenAIServing):
request_id_prefix = "embd"
def __init__(
self,
engine_client: EngineClient,
models: OpenAIServingModels,
*,
request_logger: RequestLogger | None,
chat_template: str | None,
chat_template_content_format: ChatTemplateContentFormatOption,
trust_request_chat_template: bool = False,
log_error_stack: bool = False,
) -> None:
super().__init__(
engine_client=engine_client,
models=models,
request_logger=request_logger,
log_error_stack=log_error_stack,
)
self.chat_template = chat_template
self.chat_template_content_format: Final = chat_template_content_format
self.trust_request_chat_template = trust_request_chat_template
pooler_config = self.model_config.pooler_config
......@@ -69,32 +77,41 @@ class EmbeddingMixin(OpenAIServing):
else None
)
@override
async def _preprocess(
self,
ctx: ServeContext,
ctx: EmbeddingServeContext,
) -> ErrorResponse | None:
ctx = cast(EmbeddingServeContext, ctx)
try:
ctx.lora_request = self._maybe_get_adapters(ctx.request)
if isinstance(ctx.request, EmbeddingChatRequest):
error_check_ret = self._validate_chat_template(
request_chat_template=ctx.request.chat_template,
chat_template_kwargs=ctx.request.chat_template_kwargs,
trust_request_chat_template=self.trust_request_chat_template,
)
if error_check_ret is not None:
return error_check_ret
_, ctx.engine_prompts = await self._preprocess_chat(
ctx.request,
self.renderer,
ctx.request.messages,
chat_template=ctx.request.chat_template or ctx.chat_template,
chat_template_content_format=ctx.chat_template_content_format,
chat_template=ctx.request.chat_template or self.chat_template,
chat_template_content_format=self.chat_template_content_format,
add_generation_prompt=ctx.request.add_generation_prompt,
continue_final_message=ctx.request.continue_final_message,
add_special_tokens=ctx.request.add_special_tokens,
)
else:
elif isinstance(ctx.request, EmbeddingCompletionRequest):
renderer = self._get_completion_renderer()
ctx.engine_prompts = await renderer.render_prompt(
prompt_or_prompts=ctx.request.input,
config=self._build_render_config(ctx.request),
)
else:
return self.create_error_response("Invalid classification request type")
return None
except (ValueError, TypeError) as e:
logger.exception("Error in preprocessing prompt inputs")
......@@ -113,16 +130,15 @@ class EmbeddingMixin(OpenAIServing):
add_special_tokens=request.add_special_tokens,
)
@override
def _build_response(
self,
ctx: ServeContext,
) -> EmbeddingResponse | Response | ErrorResponse:
final_res_batch_checked = cast(list[PoolingRequestOutput], ctx.final_res_batch)
ctx: EmbeddingServeContext,
) -> EmbeddingResponse | EmbeddingBytesResponse | ErrorResponse:
final_res_batch_checked = ctx.final_res_batch
encoding_format: EncodingFormat = ctx.request.encoding_format
embed_dtype: EmbedDType = ctx.request.embed_dtype
endianness: Endianness = ctx.request.endianness
encoding_format = ctx.request.encoding_format
embed_dtype = ctx.request.embed_dtype
endianness = ctx.request.endianness
def encode_float_base64():
items: list[EmbeddingResponseData] = []
......@@ -203,8 +219,8 @@ class EmbeddingMixin(OpenAIServing):
self,
ctx: EmbeddingServeContext,
token_ids: list[int],
pooling_params,
trace_headers,
pooling_params: PoolingParams,
trace_headers: Mapping[str, str] | None,
prompt_idx: int,
) -> list[AsyncGenerator[PoolingRequestOutput, None]]:
"""Process a single prompt using chunked processing."""
......@@ -246,7 +262,7 @@ class EmbeddingMixin(OpenAIServing):
def _validate_input(
self,
request,
request: object,
input_ids: list[int],
input_text: str,
) -> TokensPrompt:
......@@ -326,7 +342,7 @@ class EmbeddingMixin(OpenAIServing):
pooling_params: PoolingParams,
trace_headers: Mapping[str, str] | None,
prompt_index: int,
) -> AsyncGenerator[RequestOutput | PoolingRequestOutput, None]:
) -> AsyncGenerator[PoolingRequestOutput, None]:
"""Create a generator for a single prompt using standard processing."""
request_id_item = f"{ctx.request_id}-{prompt_index}"
......@@ -347,7 +363,6 @@ class EmbeddingMixin(OpenAIServing):
priority=getattr(ctx.request, "priority", 0),
)
@override
async def _prepare_generators(
self,
ctx: ServeContext,
......@@ -363,9 +378,7 @@ class EmbeddingMixin(OpenAIServing):
return await super()._prepare_generators(ctx)
# Custom logic for chunked processing
generators: list[
AsyncGenerator[RequestOutput | PoolingRequestOutput, None]
] = []
generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
try:
trace_headers = (
......@@ -419,10 +432,9 @@ class EmbeddingMixin(OpenAIServing):
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
@override
async def _collect_batch(
self,
ctx: ServeContext,
ctx: EmbeddingServeContext,
) -> ErrorResponse | None:
"""Collect and aggregate batch results
with support for chunked processing.
......@@ -431,7 +443,6 @@ class EmbeddingMixin(OpenAIServing):
minimize memory usage.
For regular requests, collects results normally.
"""
ctx = cast(EmbeddingServeContext, ctx)
try:
if ctx.engine_prompts is None:
return self.create_error_response("Engine prompts not available")
......@@ -527,12 +538,10 @@ class EmbeddingMixin(OpenAIServing):
except (ValueError, IndexError):
prompt_idx = result_idx # Fallback to result_idx
short_prompts_results[prompt_idx] = cast(
PoolingRequestOutput, result
)
short_prompts_results[prompt_idx] = result
# Finalize aggregated results
final_res_batch: list[PoolingRequestOutput | EmbeddingRequestOutput] = []
final_res_batch: list[PoolingRequestOutput] = []
num_prompts = len(ctx.engine_prompts)
for prompt_idx in range(num_prompts):
......@@ -580,49 +589,19 @@ class EmbeddingMixin(OpenAIServing):
f"Failed to aggregate chunks for prompt {prompt_idx}"
)
elif prompt_idx in short_prompts_results:
final_res_batch.append(
cast(PoolingRequestOutput, short_prompts_results[prompt_idx])
)
final_res_batch.append(short_prompts_results[prompt_idx])
else:
return self.create_error_response(
f"Result not found for prompt {prompt_idx}"
)
ctx.final_res_batch = cast(
list[RequestOutput | PoolingRequestOutput], final_res_batch
)
ctx.final_res_batch = final_res_batch
return None
except Exception as e:
return self.create_error_response(str(e))
class OpenAIServingEmbedding(EmbeddingMixin):
request_id_prefix = "embd"
def __init__(
self,
engine_client: EngineClient,
models: OpenAIServingModels,
*,
request_logger: RequestLogger | None,
chat_template: str | None,
chat_template_content_format: ChatTemplateContentFormatOption,
trust_request_chat_template: bool = False,
log_error_stack: bool = False,
) -> None:
super().__init__(
engine_client=engine_client,
models=models,
request_logger=request_logger,
log_error_stack=log_error_stack,
)
self.chat_template = chat_template
self.chat_template_content_format: Final = chat_template_content_format
self.trust_request_chat_template = trust_request_chat_template
async def create_embedding(
self,
request: EmbeddingRequest,
......@@ -645,16 +624,13 @@ class OpenAIServingEmbedding(EmbeddingMixin):
raw_request=raw_request,
model_name=model_name,
request_id=request_id,
chat_template=self.chat_template,
chat_template_content_format=self.chat_template_content_format,
)
return await super().handle(ctx) # type: ignore
return await self.handle(ctx) # type: ignore[return-value]
@override
def _create_pooling_params(
self,
ctx: ServeContext[EmbeddingRequest],
ctx: EmbeddingServeContext,
) -> PoolingParams | ErrorResponse:
pooling_params = super()._create_pooling_params(ctx)
if isinstance(pooling_params, ErrorResponse):
......@@ -666,17 +642,3 @@ class OpenAIServingEmbedding(EmbeddingMixin):
return self.create_error_response(str(e))
return pooling_params
async def _preprocess(
self,
ctx: ServeContext,
) -> ErrorResponse | None:
if isinstance(ctx.request, EmbeddingChatRequest):
error_check_ret = self._validate_chat_template(
request_chat_template=ctx.request.chat_template,
chat_template_kwargs=ctx.request.chat_template_kwargs,
trust_request_chat_template=self.trust_request_chat_template,
)
if error_check_ret is not None:
return error_check_ret
return await super()._preprocess(ctx)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment