Unverified commit c831911b, authored by Cyrus Leung, committed by GitHub
Browse files

[Frontend] Reduce mixin usage in serving pooling (#33101)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent 157caf51
...@@ -64,13 +64,12 @@ from vllm.entrypoints.openai.translations.protocol import ( ...@@ -64,13 +64,12 @@ from vllm.entrypoints.openai.translations.protocol import (
from vllm.entrypoints.pooling.classify.protocol import ( from vllm.entrypoints.pooling.classify.protocol import (
ClassificationChatRequest, ClassificationChatRequest,
ClassificationCompletionRequest, ClassificationCompletionRequest,
ClassificationRequest,
ClassificationResponse, ClassificationResponse,
) )
from vllm.entrypoints.pooling.embed.protocol import ( from vllm.entrypoints.pooling.embed.protocol import (
EmbeddingBytesResponse,
EmbeddingChatRequest, EmbeddingChatRequest,
EmbeddingCompletionRequest, EmbeddingCompletionRequest,
EmbeddingRequest,
EmbeddingResponse, EmbeddingResponse,
) )
from vllm.entrypoints.pooling.pooling.protocol import ( from vllm.entrypoints.pooling.pooling.protocol import (
...@@ -170,6 +169,7 @@ AnyResponse: TypeAlias = ( ...@@ -170,6 +169,7 @@ AnyResponse: TypeAlias = (
CompletionResponse CompletionResponse
| ChatCompletionResponse | ChatCompletionResponse
| EmbeddingResponse | EmbeddingResponse
| EmbeddingBytesResponse
| TranscriptionResponse | TranscriptionResponse
| TokenizeResponse | TokenizeResponse
| PoolingResponse | PoolingResponse
...@@ -183,51 +183,21 @@ RequestT = TypeVar("RequestT", bound=AnyRequest) ...@@ -183,51 +183,21 @@ RequestT = TypeVar("RequestT", bound=AnyRequest)
@dataclass(kw_only=True) @dataclass(kw_only=True)
class RequestProcessingMixin: class ServeContext(Generic[RequestT]):
"""
Mixin for request processing,
handling prompt preparation and engine input.
"""
engine_prompts: list[TokensPrompt] | None = field(default_factory=list)
@dataclass(kw_only=True)
class ResponseGenerationMixin:
"""
Mixin for response generation,
managing result generators and final batch results.
"""
result_generator: (
AsyncGenerator[tuple[int, RequestOutput | PoolingRequestOutput], None] | None
) = None
final_res_batch: list[RequestOutput | PoolingRequestOutput] = field(
default_factory=list
)
model_config = ConfigDict(arbitrary_types_allowed=True)
@dataclass(kw_only=True)
class ServeContext(RequestProcessingMixin, ResponseGenerationMixin, Generic[RequestT]):
request: RequestT request: RequestT
raw_request: Request | None = None raw_request: Request | None = None
model_name: str model_name: str
request_id: str request_id: str
created_time: int = field(default_factory=lambda: int(time.time())) created_time: int = field(default_factory=lambda: int(time.time()))
lora_request: LoRARequest | None = None lora_request: LoRARequest | None = None
engine_prompts: list[TokensPrompt] | None = None
result_generator: AsyncGenerator[tuple[int, PoolingRequestOutput], None] | None = (
None
)
final_res_batch: list[PoolingRequestOutput] = field(default_factory=list)
@dataclass(kw_only=True) model_config = ConfigDict(arbitrary_types_allowed=True)
class ClassificationServeContext(ServeContext[ClassificationRequest]):
pass
@dataclass(kw_only=True)
class EmbeddingServeContext(ServeContext[EmbeddingRequest]):
chat_template: str | None = None
chat_template_content_format: ChatTemplateContentFormatOption
class OpenAIServing: class OpenAIServing:
...@@ -605,10 +575,7 @@ class OpenAIServing: ...@@ -605,10 +575,7 @@ class OpenAIServing:
self, self,
ctx: ServeContext, ctx: ServeContext,
) -> AnyResponse | ErrorResponse: ) -> AnyResponse | ErrorResponse:
generation: AsyncGenerator[AnyResponse | ErrorResponse, None] async for response in self._pipeline(ctx):
generation = self._pipeline(ctx)
async for response in generation:
return response return response
return self.create_error_response("No response yielded from pipeline") return self.create_error_response("No response yielded from pipeline")
...@@ -667,9 +634,7 @@ class OpenAIServing: ...@@ -667,9 +634,7 @@ class OpenAIServing:
ctx: ServeContext, ctx: ServeContext,
) -> ErrorResponse | None: ) -> ErrorResponse | None:
"""Schedule the request and get the result generator.""" """Schedule the request and get the result generator."""
generators: list[ generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
AsyncGenerator[RequestOutput | PoolingRequestOutput, None]
] = []
try: try:
trace_headers = ( trace_headers = (
...@@ -723,7 +688,7 @@ class OpenAIServing: ...@@ -723,7 +688,7 @@ class OpenAIServing:
return self.create_error_response("Engine prompts not available") return self.create_error_response("Engine prompts not available")
num_prompts = len(ctx.engine_prompts) num_prompts = len(ctx.engine_prompts)
final_res_batch: list[RequestOutput | PoolingRequestOutput | None] final_res_batch: list[PoolingRequestOutput | None]
final_res_batch = [None] * num_prompts final_res_batch = [None] * num_prompts
if ctx.result_generator is None: if ctx.result_generator is None:
...@@ -1011,7 +976,7 @@ class OpenAIServing: ...@@ -1011,7 +976,7 @@ class OpenAIServing:
def _validate_input( def _validate_input(
self, self,
request: AnyRequest, request: object,
input_ids: list[int], input_ids: list[int],
input_text: str, input_text: str,
) -> TokensPrompt: ) -> TokensPrompt:
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from http import HTTPStatus from http import HTTPStatus
from typing import cast from typing import Final, cast
import jinja2 import jinja2
import numpy as np import numpy as np
...@@ -11,18 +11,8 @@ from fastapi import Request ...@@ -11,18 +11,8 @@ from fastapi import Request
from vllm.engine.protocol import EngineClient from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.chat_completion.protocol import ( from vllm.entrypoints.openai.engine.protocol import ErrorResponse, UsageInfo
ChatCompletionRequest, from vllm.entrypoints.openai.engine.serving import OpenAIServing, ServeContext
)
from vllm.entrypoints.openai.engine.protocol import (
ErrorResponse,
UsageInfo,
)
from vllm.entrypoints.openai.engine.serving import (
ClassificationServeContext,
OpenAIServing,
ServeContext,
)
from vllm.entrypoints.openai.models.serving import OpenAIServingModels from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.entrypoints.pooling.classify.protocol import ( from vllm.entrypoints.pooling.classify.protocol import (
ClassificationChatRequest, ClassificationChatRequest,
...@@ -39,60 +29,68 @@ from vllm.pooling_params import PoolingParams ...@@ -39,60 +29,68 @@ from vllm.pooling_params import PoolingParams
logger = init_logger(__name__) logger = init_logger(__name__)
class ClassificationMixin(OpenAIServing): ClassificationServeContext = ServeContext[ClassificationRequest]
chat_template: str | None
chat_template_content_format: ChatTemplateContentFormatOption
trust_request_chat_template: bool class ServingClassification(OpenAIServing):
request_id_prefix = "classify"
def __init__(
self,
engine_client: EngineClient,
models: OpenAIServingModels,
*,
request_logger: RequestLogger | None,
chat_template: str | None = None,
chat_template_content_format: ChatTemplateContentFormatOption = "auto",
trust_request_chat_template: bool = False,
log_error_stack: bool = False,
) -> None:
super().__init__(
engine_client=engine_client,
models=models,
request_logger=request_logger,
log_error_stack=log_error_stack,
)
self.chat_template = chat_template
self.chat_template_content_format: Final = chat_template_content_format
self.trust_request_chat_template = trust_request_chat_template
async def _preprocess( async def _preprocess(
self, self,
ctx: ServeContext, ctx: ClassificationServeContext,
) -> ErrorResponse | None: ) -> ErrorResponse | None:
""" """
Process classification inputs: tokenize text, resolve adapters, Process classification inputs: tokenize text, resolve adapters,
and prepare model-specific inputs. and prepare model-specific inputs.
""" """
ctx = cast(ClassificationServeContext, ctx)
try: try:
request_obj = ctx.request ctx.lora_request = self._maybe_get_adapters(ctx.request)
if isinstance(request_obj, ClassificationChatRequest): if isinstance(ctx.request, ClassificationChatRequest):
chat_request = request_obj error_check_ret = self._validate_chat_template(
messages = chat_request.messages request_chat_template=ctx.request.chat_template,
trust_request_chat_template = getattr( chat_template_kwargs=ctx.request.chat_template_kwargs,
self, trust_request_chat_template=self.trust_request_chat_template,
"trust_request_chat_template",
False,
)
ret = self._validate_chat_template(
request_chat_template=chat_request.chat_template,
chat_template_kwargs=chat_request.chat_template_kwargs,
trust_request_chat_template=trust_request_chat_template,
) )
if ret: if error_check_ret:
return ret return error_check_ret
_, engine_prompts = await self._preprocess_chat( _, engine_prompts = await self._preprocess_chat(
cast(ChatCompletionRequest, chat_request), ctx.request,
self.renderer, self.renderer,
messages, ctx.request.messages,
chat_template=( chat_template=ctx.request.chat_template or self.chat_template,
chat_request.chat_template chat_template_content_format=self.chat_template_content_format,
or getattr(self, "chat_template", None) add_generation_prompt=ctx.request.add_generation_prompt,
), continue_final_message=ctx.request.continue_final_message,
chat_template_content_format=cast( add_special_tokens=ctx.request.add_special_tokens,
ChatTemplateContentFormatOption,
getattr(self, "chat_template_content_format", "auto"),
),
add_generation_prompt=chat_request.add_generation_prompt,
continue_final_message=chat_request.continue_final_message,
add_special_tokens=chat_request.add_special_tokens,
) )
ctx.engine_prompts = engine_prompts ctx.engine_prompts = engine_prompts
elif isinstance(request_obj, ClassificationCompletionRequest): elif isinstance(ctx.request, ClassificationCompletionRequest):
completion_request = request_obj input_data = ctx.request.input
input_data = completion_request.input
if input_data in (None, ""): if input_data in (None, ""):
return self.create_error_response( return self.create_error_response(
"Input or messages must be provided", "Input or messages must be provided",
...@@ -106,13 +104,10 @@ class ClassificationMixin(OpenAIServing): ...@@ -106,13 +104,10 @@ class ClassificationMixin(OpenAIServing):
prompt_input = cast(str | list[str], input_data) prompt_input = cast(str | list[str], input_data)
ctx.engine_prompts = await renderer.render_prompt( ctx.engine_prompts = await renderer.render_prompt(
prompt_or_prompts=prompt_input, prompt_or_prompts=prompt_input,
config=self._build_render_config(completion_request), config=self._build_render_config(ctx.request),
) )
else: else:
return self.create_error_response( return self.create_error_response("Invalid classification request type")
"Invalid classification request type",
status_code=HTTPStatus.BAD_REQUEST,
)
return None return None
...@@ -122,13 +117,14 @@ class ClassificationMixin(OpenAIServing): ...@@ -122,13 +117,14 @@ class ClassificationMixin(OpenAIServing):
def _build_response( def _build_response(
self, self,
ctx: ServeContext, ctx: ClassificationServeContext,
) -> ClassificationResponse | ErrorResponse: ) -> ClassificationResponse | ErrorResponse:
""" """
Convert model outputs to a formatted classification response Convert model outputs to a formatted classification response
with probabilities and labels. with probabilities and labels.
""" """
ctx = cast(ClassificationServeContext, ctx) id2label = getattr(self.model_config.hf_config, "id2label", {})
items: list[ClassificationData] = [] items: list[ClassificationData] = []
num_prompt_tokens = 0 num_prompt_tokens = 0
...@@ -139,9 +135,7 @@ class ClassificationMixin(OpenAIServing): ...@@ -139,9 +135,7 @@ class ClassificationMixin(OpenAIServing):
probs = classify_res.probs probs = classify_res.probs
predicted_index = int(np.argmax(probs)) predicted_index = int(np.argmax(probs))
label = getattr(self.model_config.hf_config, "id2label", {}).get( label = id2label.get(predicted_index)
predicted_index
)
item = ClassificationData( item = ClassificationData(
index=idx, index=idx,
...@@ -174,32 +168,6 @@ class ClassificationMixin(OpenAIServing): ...@@ -174,32 +168,6 @@ class ClassificationMixin(OpenAIServing):
add_special_tokens=request.add_special_tokens, add_special_tokens=request.add_special_tokens,
) )
class ServingClassification(ClassificationMixin):
request_id_prefix = "classify"
def __init__(
self,
engine_client: EngineClient,
models: OpenAIServingModels,
*,
request_logger: RequestLogger | None,
chat_template: str | None = None,
chat_template_content_format: ChatTemplateContentFormatOption = "auto",
trust_request_chat_template: bool = False,
log_error_stack: bool = False,
) -> None:
super().__init__(
engine_client=engine_client,
models=models,
request_logger=request_logger,
log_error_stack=log_error_stack,
)
self.chat_template = chat_template
self.chat_template_content_format = chat_template_content_format
self.trust_request_chat_template = trust_request_chat_template
async def create_classify( async def create_classify(
self, self,
request: ClassificationRequest, request: ClassificationRequest,
...@@ -215,11 +183,11 @@ class ServingClassification(ClassificationMixin): ...@@ -215,11 +183,11 @@ class ServingClassification(ClassificationMixin):
request_id=request_id, request_id=request_id,
) )
return await super().handle(ctx) # type: ignore return await self.handle(ctx) # type: ignore[return-value]
def _create_pooling_params( def _create_pooling_params(
self, self,
ctx: ServeContext[ClassificationRequest], ctx: ClassificationServeContext,
) -> PoolingParams | ErrorResponse: ) -> PoolingParams | ErrorResponse:
pooling_params = super()._create_pooling_params(ctx) pooling_params = super()._create_pooling_params(ctx)
if isinstance(pooling_params, ErrorResponse): if isinstance(pooling_params, ErrorResponse):
......
...@@ -6,21 +6,13 @@ from typing import Any, Final, cast ...@@ -6,21 +6,13 @@ from typing import Any, Final, cast
import torch import torch
from fastapi import Request from fastapi import Request
from fastapi.responses import Response from typing_extensions import assert_never
from typing_extensions import assert_never, override
from vllm.engine.protocol import EngineClient from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.engine.protocol import ( from vllm.entrypoints.openai.engine.protocol import ErrorResponse, UsageInfo
ErrorResponse, from vllm.entrypoints.openai.engine.serving import OpenAIServing, ServeContext
UsageInfo,
)
from vllm.entrypoints.openai.engine.serving import (
EmbeddingServeContext,
OpenAIServing,
ServeContext,
)
from vllm.entrypoints.openai.models.serving import OpenAIServingModels from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.entrypoints.pooling.embed.protocol import ( from vllm.entrypoints.pooling.embed.protocol import (
EmbeddingBytesResponse, EmbeddingBytesResponse,
...@@ -33,19 +25,11 @@ from vllm.entrypoints.pooling.embed.protocol import ( ...@@ -33,19 +25,11 @@ from vllm.entrypoints.pooling.embed.protocol import (
from vllm.entrypoints.renderer import RenderConfig from vllm.entrypoints.renderer import RenderConfig
from vllm.inputs.data import TokensPrompt from vllm.inputs.data import TokensPrompt
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.outputs import ( from vllm.outputs import PoolingOutput, PoolingRequestOutput
EmbeddingRequestOutput,
PoolingOutput,
PoolingRequestOutput,
RequestOutput,
)
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.utils.async_utils import merge_async_iterators from vllm.utils.async_utils import merge_async_iterators
from vllm.utils.collection_utils import chunk_list from vllm.utils.collection_utils import chunk_list
from vllm.utils.serial_utils import ( from vllm.utils.serial_utils import (
EmbedDType,
EncodingFormat,
Endianness,
encode_pooling_bytes, encode_pooling_bytes,
encode_pooling_output, encode_pooling_output,
) )
...@@ -53,9 +37,33 @@ from vllm.utils.serial_utils import ( ...@@ -53,9 +37,33 @@ from vllm.utils.serial_utils import (
logger = init_logger(__name__) logger = init_logger(__name__)
class EmbeddingMixin(OpenAIServing): EmbeddingServeContext = ServeContext[EmbeddingRequest]
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
class OpenAIServingEmbedding(OpenAIServing):
request_id_prefix = "embd"
def __init__(
self,
engine_client: EngineClient,
models: OpenAIServingModels,
*,
request_logger: RequestLogger | None,
chat_template: str | None,
chat_template_content_format: ChatTemplateContentFormatOption,
trust_request_chat_template: bool = False,
log_error_stack: bool = False,
) -> None:
super().__init__(
engine_client=engine_client,
models=models,
request_logger=request_logger,
log_error_stack=log_error_stack,
)
self.chat_template = chat_template
self.chat_template_content_format: Final = chat_template_content_format
self.trust_request_chat_template = trust_request_chat_template
pooler_config = self.model_config.pooler_config pooler_config = self.model_config.pooler_config
...@@ -69,32 +77,41 @@ class EmbeddingMixin(OpenAIServing): ...@@ -69,32 +77,41 @@ class EmbeddingMixin(OpenAIServing):
else None else None
) )
@override
async def _preprocess( async def _preprocess(
self, self,
ctx: ServeContext, ctx: EmbeddingServeContext,
) -> ErrorResponse | None: ) -> ErrorResponse | None:
ctx = cast(EmbeddingServeContext, ctx)
try: try:
ctx.lora_request = self._maybe_get_adapters(ctx.request) ctx.lora_request = self._maybe_get_adapters(ctx.request)
if isinstance(ctx.request, EmbeddingChatRequest): if isinstance(ctx.request, EmbeddingChatRequest):
error_check_ret = self._validate_chat_template(
request_chat_template=ctx.request.chat_template,
chat_template_kwargs=ctx.request.chat_template_kwargs,
trust_request_chat_template=self.trust_request_chat_template,
)
if error_check_ret is not None:
return error_check_ret
_, ctx.engine_prompts = await self._preprocess_chat( _, ctx.engine_prompts = await self._preprocess_chat(
ctx.request, ctx.request,
self.renderer, self.renderer,
ctx.request.messages, ctx.request.messages,
chat_template=ctx.request.chat_template or ctx.chat_template, chat_template=ctx.request.chat_template or self.chat_template,
chat_template_content_format=ctx.chat_template_content_format, chat_template_content_format=self.chat_template_content_format,
add_generation_prompt=ctx.request.add_generation_prompt, add_generation_prompt=ctx.request.add_generation_prompt,
continue_final_message=ctx.request.continue_final_message, continue_final_message=ctx.request.continue_final_message,
add_special_tokens=ctx.request.add_special_tokens, add_special_tokens=ctx.request.add_special_tokens,
) )
else: elif isinstance(ctx.request, EmbeddingCompletionRequest):
renderer = self._get_completion_renderer() renderer = self._get_completion_renderer()
ctx.engine_prompts = await renderer.render_prompt( ctx.engine_prompts = await renderer.render_prompt(
prompt_or_prompts=ctx.request.input, prompt_or_prompts=ctx.request.input,
config=self._build_render_config(ctx.request), config=self._build_render_config(ctx.request),
) )
else:
return self.create_error_response("Invalid classification request type")
return None return None
except (ValueError, TypeError) as e: except (ValueError, TypeError) as e:
logger.exception("Error in preprocessing prompt inputs") logger.exception("Error in preprocessing prompt inputs")
...@@ -113,16 +130,15 @@ class EmbeddingMixin(OpenAIServing): ...@@ -113,16 +130,15 @@ class EmbeddingMixin(OpenAIServing):
add_special_tokens=request.add_special_tokens, add_special_tokens=request.add_special_tokens,
) )
@override
def _build_response( def _build_response(
self, self,
ctx: ServeContext, ctx: EmbeddingServeContext,
) -> EmbeddingResponse | Response | ErrorResponse: ) -> EmbeddingResponse | EmbeddingBytesResponse | ErrorResponse:
final_res_batch_checked = cast(list[PoolingRequestOutput], ctx.final_res_batch) final_res_batch_checked = ctx.final_res_batch
encoding_format: EncodingFormat = ctx.request.encoding_format encoding_format = ctx.request.encoding_format
embed_dtype: EmbedDType = ctx.request.embed_dtype embed_dtype = ctx.request.embed_dtype
endianness: Endianness = ctx.request.endianness endianness = ctx.request.endianness
def encode_float_base64(): def encode_float_base64():
items: list[EmbeddingResponseData] = [] items: list[EmbeddingResponseData] = []
...@@ -203,8 +219,8 @@ class EmbeddingMixin(OpenAIServing): ...@@ -203,8 +219,8 @@ class EmbeddingMixin(OpenAIServing):
self, self,
ctx: EmbeddingServeContext, ctx: EmbeddingServeContext,
token_ids: list[int], token_ids: list[int],
pooling_params, pooling_params: PoolingParams,
trace_headers, trace_headers: Mapping[str, str] | None,
prompt_idx: int, prompt_idx: int,
) -> list[AsyncGenerator[PoolingRequestOutput, None]]: ) -> list[AsyncGenerator[PoolingRequestOutput, None]]:
"""Process a single prompt using chunked processing.""" """Process a single prompt using chunked processing."""
...@@ -246,7 +262,7 @@ class EmbeddingMixin(OpenAIServing): ...@@ -246,7 +262,7 @@ class EmbeddingMixin(OpenAIServing):
def _validate_input( def _validate_input(
self, self,
request, request: object,
input_ids: list[int], input_ids: list[int],
input_text: str, input_text: str,
) -> TokensPrompt: ) -> TokensPrompt:
...@@ -326,7 +342,7 @@ class EmbeddingMixin(OpenAIServing): ...@@ -326,7 +342,7 @@ class EmbeddingMixin(OpenAIServing):
pooling_params: PoolingParams, pooling_params: PoolingParams,
trace_headers: Mapping[str, str] | None, trace_headers: Mapping[str, str] | None,
prompt_index: int, prompt_index: int,
) -> AsyncGenerator[RequestOutput | PoolingRequestOutput, None]: ) -> AsyncGenerator[PoolingRequestOutput, None]:
"""Create a generator for a single prompt using standard processing.""" """Create a generator for a single prompt using standard processing."""
request_id_item = f"{ctx.request_id}-{prompt_index}" request_id_item = f"{ctx.request_id}-{prompt_index}"
...@@ -347,7 +363,6 @@ class EmbeddingMixin(OpenAIServing): ...@@ -347,7 +363,6 @@ class EmbeddingMixin(OpenAIServing):
priority=getattr(ctx.request, "priority", 0), priority=getattr(ctx.request, "priority", 0),
) )
@override
async def _prepare_generators( async def _prepare_generators(
self, self,
ctx: ServeContext, ctx: ServeContext,
...@@ -363,9 +378,7 @@ class EmbeddingMixin(OpenAIServing): ...@@ -363,9 +378,7 @@ class EmbeddingMixin(OpenAIServing):
return await super()._prepare_generators(ctx) return await super()._prepare_generators(ctx)
# Custom logic for chunked processing # Custom logic for chunked processing
generators: list[ generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
AsyncGenerator[RequestOutput | PoolingRequestOutput, None]
] = []
try: try:
trace_headers = ( trace_headers = (
...@@ -419,10 +432,9 @@ class EmbeddingMixin(OpenAIServing): ...@@ -419,10 +432,9 @@ class EmbeddingMixin(OpenAIServing):
# TODO: Use a vllm-specific Validation Error # TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e)) return self.create_error_response(str(e))
@override
async def _collect_batch( async def _collect_batch(
self, self,
ctx: ServeContext, ctx: EmbeddingServeContext,
) -> ErrorResponse | None: ) -> ErrorResponse | None:
"""Collect and aggregate batch results """Collect and aggregate batch results
with support for chunked processing. with support for chunked processing.
...@@ -431,7 +443,6 @@ class EmbeddingMixin(OpenAIServing): ...@@ -431,7 +443,6 @@ class EmbeddingMixin(OpenAIServing):
minimize memory usage. minimize memory usage.
For regular requests, collects results normally. For regular requests, collects results normally.
""" """
ctx = cast(EmbeddingServeContext, ctx)
try: try:
if ctx.engine_prompts is None: if ctx.engine_prompts is None:
return self.create_error_response("Engine prompts not available") return self.create_error_response("Engine prompts not available")
...@@ -527,12 +538,10 @@ class EmbeddingMixin(OpenAIServing): ...@@ -527,12 +538,10 @@ class EmbeddingMixin(OpenAIServing):
except (ValueError, IndexError): except (ValueError, IndexError):
prompt_idx = result_idx # Fallback to result_idx prompt_idx = result_idx # Fallback to result_idx
short_prompts_results[prompt_idx] = cast( short_prompts_results[prompt_idx] = result
PoolingRequestOutput, result
)
# Finalize aggregated results # Finalize aggregated results
final_res_batch: list[PoolingRequestOutput | EmbeddingRequestOutput] = [] final_res_batch: list[PoolingRequestOutput] = []
num_prompts = len(ctx.engine_prompts) num_prompts = len(ctx.engine_prompts)
for prompt_idx in range(num_prompts): for prompt_idx in range(num_prompts):
...@@ -580,49 +589,19 @@ class EmbeddingMixin(OpenAIServing): ...@@ -580,49 +589,19 @@ class EmbeddingMixin(OpenAIServing):
f"Failed to aggregate chunks for prompt {prompt_idx}" f"Failed to aggregate chunks for prompt {prompt_idx}"
) )
elif prompt_idx in short_prompts_results: elif prompt_idx in short_prompts_results:
final_res_batch.append( final_res_batch.append(short_prompts_results[prompt_idx])
cast(PoolingRequestOutput, short_prompts_results[prompt_idx])
)
else: else:
return self.create_error_response( return self.create_error_response(
f"Result not found for prompt {prompt_idx}" f"Result not found for prompt {prompt_idx}"
) )
ctx.final_res_batch = cast( ctx.final_res_batch = final_res_batch
list[RequestOutput | PoolingRequestOutput], final_res_batch
)
return None return None
except Exception as e: except Exception as e:
return self.create_error_response(str(e)) return self.create_error_response(str(e))
class OpenAIServingEmbedding(EmbeddingMixin):
request_id_prefix = "embd"
def __init__(
self,
engine_client: EngineClient,
models: OpenAIServingModels,
*,
request_logger: RequestLogger | None,
chat_template: str | None,
chat_template_content_format: ChatTemplateContentFormatOption,
trust_request_chat_template: bool = False,
log_error_stack: bool = False,
) -> None:
super().__init__(
engine_client=engine_client,
models=models,
request_logger=request_logger,
log_error_stack=log_error_stack,
)
self.chat_template = chat_template
self.chat_template_content_format: Final = chat_template_content_format
self.trust_request_chat_template = trust_request_chat_template
async def create_embedding( async def create_embedding(
self, self,
request: EmbeddingRequest, request: EmbeddingRequest,
...@@ -645,16 +624,13 @@ class OpenAIServingEmbedding(EmbeddingMixin): ...@@ -645,16 +624,13 @@ class OpenAIServingEmbedding(EmbeddingMixin):
raw_request=raw_request, raw_request=raw_request,
model_name=model_name, model_name=model_name,
request_id=request_id, request_id=request_id,
chat_template=self.chat_template,
chat_template_content_format=self.chat_template_content_format,
) )
return await super().handle(ctx) # type: ignore return await self.handle(ctx) # type: ignore[return-value]
@override
def _create_pooling_params( def _create_pooling_params(
self, self,
ctx: ServeContext[EmbeddingRequest], ctx: EmbeddingServeContext,
) -> PoolingParams | ErrorResponse: ) -> PoolingParams | ErrorResponse:
pooling_params = super()._create_pooling_params(ctx) pooling_params = super()._create_pooling_params(ctx)
if isinstance(pooling_params, ErrorResponse): if isinstance(pooling_params, ErrorResponse):
...@@ -666,17 +642,3 @@ class OpenAIServingEmbedding(EmbeddingMixin): ...@@ -666,17 +642,3 @@ class OpenAIServingEmbedding(EmbeddingMixin):
return self.create_error_response(str(e)) return self.create_error_response(str(e))
return pooling_params return pooling_params
async def _preprocess(
self,
ctx: ServeContext,
) -> ErrorResponse | None:
if isinstance(ctx.request, EmbeddingChatRequest):
error_check_ret = self._validate_chat_template(
request_chat_template=ctx.request.chat_template,
chat_template_kwargs=ctx.request.chat_template_kwargs,
trust_request_chat_template=self.trust_request_chat_template,
)
if error_check_ret is not None:
return error_check_ret
return await super()._preprocess(ctx)
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment