Unverified Commit f0a1c845 authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Frontend] Use new Renderer for Completions and Tokenize API (#32863)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent 8980001c
......@@ -5,65 +5,10 @@ import pytest
from vllm.config import ModelConfig
from vllm.inputs import zip_enc_dec_prompts
from vllm.inputs.parse import parse_raw_prompts
from vllm.inputs.preprocess import InputPreprocessor
pytestmark = pytest.mark.cpu_test
STRING_INPUTS = [
"",
"foo",
"foo bar",
"foo baz bar",
"foo bar qux baz",
]
TOKEN_INPUTS = [
[-1],
[1],
[1, 2],
[1, 3, 4],
[1, 2, 4, 3],
]
INPUTS_SLICES = [
slice(None, None, -1),
slice(None, None, 2),
slice(None, None, -2),
]
# Test that a nested mixed-type list of lists raises a TypeError.
@pytest.mark.parametrize("invalid_input", [[[1, 2], ["foo", "bar"]]])
def test_invalid_input_raise_type_error(invalid_input):
with pytest.raises(TypeError):
parse_raw_prompts(invalid_input)
def test_parse_raw_single_batch_empty():
with pytest.raises(ValueError, match="at least one prompt"):
parse_raw_prompts([])
with pytest.raises(ValueError, match="at least one prompt"):
parse_raw_prompts([[]])
@pytest.mark.parametrize("string_input", STRING_INPUTS)
def test_parse_raw_single_batch_string_consistent(string_input: str):
assert parse_raw_prompts(string_input) == parse_raw_prompts([string_input])
@pytest.mark.parametrize("token_input", TOKEN_INPUTS)
def test_parse_raw_single_batch_token_consistent(token_input: list[int]):
assert parse_raw_prompts(token_input) == parse_raw_prompts([token_input])
@pytest.mark.parametrize("inputs_slice", INPUTS_SLICES)
def test_parse_raw_single_batch_string_slice(inputs_slice: slice):
assert parse_raw_prompts(STRING_INPUTS)[inputs_slice] == parse_raw_prompts(
STRING_INPUTS[inputs_slice]
)
@pytest.mark.parametrize(
"mm_processor_kwargs,expected_mm_kwargs",
......
......@@ -768,7 +768,7 @@ class ModelConfig:
)
self.tokenizer = object_storage_tokenizer.dir
def _get_encoder_config(self):
def _get_encoder_config(self) -> dict[str, Any] | None:
model = self.model
if is_remote_gguf(model):
model, _ = split_remote_gguf(model)
......@@ -1918,7 +1918,7 @@ def _get_and_verify_max_len(
disable_sliding_window: bool,
sliding_window: int | None,
spec_target_max_model_len: int | None = None,
encoder_config: Any | None = None,
encoder_config: dict[str, Any] | None = None,
) -> int:
"""Get and verify the model's maximum length."""
(derived_max_model_len, max_len_key) = (
......
......@@ -72,14 +72,9 @@ class EngineClient(ABC):
lora_request: LoRARequest | None = None,
trace_headers: Mapping[str, str] | None = None,
priority: int = 0,
truncate_prompt_tokens: int | None = None,
tokenization_kwargs: dict[str, Any] | None = None,
) -> AsyncGenerator[PoolingRequestOutput, None]:
"""Generate outputs for a request from a pooling model.
NOTE: truncate_prompt_tokens is deprecated in v0.14.
TODO: Remove this argument in v0.15.
"""
"""Generate outputs for a request from a pooling model."""
...
@abstractmethod
......
This diff is collapsed.
......@@ -13,12 +13,13 @@ from openai.types.chat.chat_completion_audio import (
ChatCompletionAudio as OpenAIChatCompletionAudio,
)
from openai.types.chat.chat_completion_message import Annotation as OpenAIAnnotation
from pydantic import (
Field,
model_validator,
)
from pydantic import Field, model_validator
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
from vllm.config import ModelConfig
from vllm.entrypoints.chat_utils import (
ChatCompletionMessageParam,
ChatTemplateContentFormatOption,
)
from vllm.entrypoints.openai.engine.protocol import (
AnyResponseFormat,
DeltaMessage,
......@@ -36,6 +37,7 @@ from vllm.entrypoints.openai.engine.protocol import (
from vllm.exceptions import VLLMValidationError
from vllm.logger import init_logger
from vllm.logprobs import Logprob
from vllm.renderers import ChatParams, TokenizeParams, merge_kwargs
from vllm.sampling_params import (
BeamSearchParams,
RequestOutputKind,
......@@ -348,6 +350,43 @@ class ChatCompletionRequest(OpenAIBaseModel):
# --8<-- [end:chat-completion-extra-params]
def build_chat_params(
self,
default_template: str | None,
default_template_content_format: ChatTemplateContentFormatOption,
) -> ChatParams:
return ChatParams(
chat_template=self.chat_template or default_template,
chat_template_content_format=default_template_content_format,
chat_template_kwargs=merge_kwargs(
self.chat_template_kwargs,
dict(
add_generation_prompt=self.add_generation_prompt,
continue_final_message=self.continue_final_message,
documents=self.documents,
reasoning_effort=self.reasoning_effort,
),
),
)
def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
if self.max_completion_tokens is not None:
max_output_tokens: int | None = self.max_completion_tokens
max_output_tokens_param = "max_completion_tokens"
else:
max_output_tokens = self.max_tokens
max_output_tokens_param = "max_tokens"
return TokenizeParams(
max_total_tokens=model_config.max_model_len,
max_output_tokens=max_output_tokens or 0,
truncate_prompt_tokens=self.truncate_prompt_tokens,
add_special_tokens=self.add_special_tokens,
needs_detokenization=bool(self.echo and not self.return_token_ids),
max_total_tokens_param="max_model_len",
max_output_tokens_param=max_output_tokens_param,
)
# Default sampling parameters for chat completion requests
_DEFAULT_SAMPLING_PARAMS: dict = {
"repetition_penalty": 1.0,
......
......@@ -67,7 +67,7 @@ from vllm.entrypoints.openai.parser.harmony_utils import (
)
from vllm.entrypoints.openai.utils import maybe_filter_parallel_tool_calls
from vllm.entrypoints.utils import get_max_tokens, should_include_usage
from vllm.inputs.data import TokensPrompt
from vllm.inputs.data import EmbedsPrompt, TokensPrompt
from vllm.inputs.parse import get_prompt_components
from vllm.logger import init_logger
from vllm.logprobs import Logprob
......@@ -185,8 +185,6 @@ class OpenAIServingChat(OpenAIServing):
start_time = time.perf_counter()
try:
renderer = self.engine_client.renderer
# Create a minimal dummy request
dummy_request = ChatCompletionRequest(
messages=[{"role": "user", "content": "warmup"}],
......@@ -201,18 +199,10 @@ class OpenAIServingChat(OpenAIServing):
# 3. Tokenizer initialization for chat
await self._preprocess_chat(
dummy_request,
renderer,
dummy_request.messages,
chat_template=self.chat_template,
chat_template_content_format=self.chat_template_content_format,
add_generation_prompt=True,
continue_final_message=False,
tool_dicts=None,
documents=None,
chat_template_kwargs=None,
default_chat_template_kwargs=self.default_chat_template_kwargs,
tool_parser=None,
add_special_tokens=False,
default_template=self.chat_template,
default_template_content_format=self.chat_template_content_format,
default_template_kwargs=self.default_chat_template_kwargs,
)
elapsed = (time.perf_counter() - start_time) * 1000
......@@ -225,7 +215,10 @@ class OpenAIServingChat(OpenAIServing):
async def render_chat_request(
self,
request: ChatCompletionRequest,
) -> tuple[list[ConversationMessage], list[Any]] | ErrorResponse:
) -> (
tuple[list[ConversationMessage], list[TokensPrompt | EmbedsPrompt]]
| ErrorResponse
):
"""
render chat request by validating and preprocessing inputs.
......@@ -302,23 +295,14 @@ class OpenAIServingChat(OpenAIServing):
if error_check_ret is not None:
return error_check_ret
chat_template_kwargs = request.chat_template_kwargs or {}
chat_template_kwargs.update(reasoning_effort=request.reasoning_effort)
conversation, engine_prompts = await self._preprocess_chat(
request,
renderer,
request.messages,
chat_template=request.chat_template or self.chat_template,
chat_template_content_format=self.chat_template_content_format,
add_generation_prompt=request.add_generation_prompt,
continue_final_message=request.continue_final_message,
default_template=self.chat_template,
default_template_content_format=self.chat_template_content_format,
default_template_kwargs=self.default_chat_template_kwargs,
tool_dicts=tool_dicts,
documents=request.documents,
chat_template_kwargs=chat_template_kwargs,
default_chat_template_kwargs=self.default_chat_template_kwargs,
tool_parser=tool_parser,
add_special_tokens=request.add_special_tokens,
)
else:
# For GPT-OSS.
......@@ -428,11 +412,15 @@ class OpenAIServingChat(OpenAIServing):
trace_headers=trace_headers,
)
else:
engine_request, tokenization_kwargs = await self._process_inputs(
tok_params = request.build_tok_params(self.model_config)
tokenization_kwargs = tok_params.get_encode_kwargs()
engine_request = self.input_processor.process_inputs(
sub_request_id,
engine_prompt,
sampling_params,
lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
trace_headers=trace_headers,
priority=request.priority,
data_parallel_rank=data_parallel_rank,
......
......@@ -9,11 +9,9 @@ from dataclasses import replace
from typing import Annotated, Any, Literal
import torch
from pydantic import (
Field,
model_validator,
)
from pydantic import Field, model_validator
from vllm.config import ModelConfig
from vllm.entrypoints.openai.engine.protocol import (
AnyResponseFormat,
LegacyStructuralTagResponseFormat,
......@@ -27,6 +25,7 @@ from vllm.entrypoints.openai.engine.protocol import (
from vllm.exceptions import VLLMValidationError
from vllm.logger import init_logger
from vllm.logprobs import Logprob
from vllm.renderers import TokenizeParams
from vllm.sampling_params import (
BeamSearchParams,
RequestOutputKind,
......@@ -178,6 +177,17 @@ class CompletionRequest(OpenAIBaseModel):
# --8<-- [end:completion-extra-params]
def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
return TokenizeParams(
max_total_tokens=model_config.max_model_len,
max_output_tokens=self.max_tokens or 0,
truncate_prompt_tokens=self.truncate_prompt_tokens,
add_special_tokens=self.add_special_tokens,
needs_detokenization=bool(self.echo and not self.return_token_ids),
max_total_tokens_param="max_model_len",
max_output_tokens_param="max_tokens",
)
# Default sampling parameters for completion requests
_DEFAULT_SAMPLING_PARAMS: dict = {
"repetition_penalty": 1.0,
......
......@@ -32,7 +32,6 @@ from vllm.entrypoints.openai.engine.serving import (
clamp_prompt_logprobs,
)
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.entrypoints.renderer import RenderConfig
from vllm.entrypoints.utils import get_max_tokens, should_include_usage
from vllm.exceptions import VLLMValidationError
from vllm.inputs.data import EmbedsPrompt, TokensPrompt, is_embeds_prompt
......@@ -111,11 +110,10 @@ class OpenAIServingCompletion(OpenAIServing):
)
try:
renderer = self._get_completion_renderer()
engine_prompts = await renderer.render_prompt_and_embeds(
prompt_or_prompts=request.prompt,
engine_prompts = await self._preprocess_completion(
request,
prompt_input=request.prompt,
prompt_embeds=request.prompt_embeds,
config=self._build_render_config(request),
)
except (ValueError, TypeError, RuntimeError, jinja2.TemplateError) as e:
logger.exception("Error in preprocessing prompt inputs")
......@@ -203,10 +201,6 @@ class OpenAIServingCompletion(OpenAIServing):
else await self._get_trace_headers(raw_request.headers)
)
# Mypy inconsistently requires this second cast in different
# environments. It shouldn't be necessary (redundant from above)
# but pre-commit in CI fails without it.
engine_prompt = cast(EmbedsPrompt | TokensPrompt, engine_prompt)
if isinstance(sampling_params, BeamSearchParams):
generator = self.beam_search(
prompt=engine_prompt,
......@@ -216,11 +210,15 @@ class OpenAIServingCompletion(OpenAIServing):
trace_headers=trace_headers,
)
else:
engine_request, tokenization_kwargs = await self._process_inputs(
tok_params = request.build_tok_params(self.model_config)
tokenization_kwargs = tok_params.get_encode_kwargs()
engine_request = self.input_processor.process_inputs(
request_id_item,
engine_prompt,
sampling_params,
lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
trace_headers=trace_headers,
priority=request.priority,
data_parallel_rank=data_parallel_rank,
......@@ -709,26 +707,3 @@ class OpenAIServingCompletion(OpenAIServing):
tokens=out_tokens,
top_logprobs=out_top_logprobs,
)
def _build_render_config(
self,
request: CompletionRequest,
max_input_length: int | None = None,
) -> RenderConfig:
# Validate max_tokens before using it
if request.max_tokens is not None and request.max_tokens > self.max_model_len:
raise VLLMValidationError(
f"'max_tokens' ({request.max_tokens}) cannot be greater than "
f"the model's maximum context length ({self.max_model_len}).",
parameter="max_tokens",
value=request.max_tokens,
)
max_input_tokens_len = self.max_model_len - (request.max_tokens or 0)
return RenderConfig(
max_length=max_input_tokens_len,
truncate_prompt_tokens=request.truncate_prompt_tokens,
add_special_tokens=request.add_special_tokens,
cache_salt=request.cache_salt,
needs_detokenization=bool(request.echo and not request.return_token_ids),
)
......@@ -16,9 +16,7 @@ from pydantic import (
from vllm.entrypoints.chat_utils import make_tool_call_id
from vllm.logger import init_logger
from vllm.sampling_params import (
SamplingParams,
)
from vllm.sampling_params import SamplingParams
from vllm.utils import random_uuid
from vllm.utils.import_utils import resolve_obj_by_qualname
......
This diff is collapsed.
......@@ -43,7 +43,6 @@ from vllm.entrypoints.openai.responses.protocol import (
from vllm.entrypoints.openai.responses.utils import construct_tool_dicts
from vllm.outputs import RequestOutput
from vllm.reasoning.abs_reasoning_parsers import ReasoningParser
from vllm.renderers import RendererLike
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import ToolParser
from vllm.utils import random_uuid
......@@ -261,7 +260,7 @@ class ParsableContext(ConversationContext):
self,
*,
response_messages: list[ResponseInputOutputItem],
renderer: RendererLike,
tokenizer: TokenizerLike,
reasoning_parser_cls: Callable[[TokenizerLike], ReasoningParser] | None,
request: ResponsesRequest,
available_tools: list[str] | None,
......@@ -280,7 +279,6 @@ class ParsableContext(ConversationContext):
if reasoning_parser_cls is None:
raise ValueError("reasoning_parser_cls must be provided.")
tokenizer = renderer.get_tokenizer()
self.parser = get_responses_parser_for_simple_context(
tokenizer=tokenizer,
reasoning_parser_cls=reasoning_parser_cls,
......@@ -290,8 +288,6 @@ class ParsableContext(ConversationContext):
)
self.tool_parser_cls = tool_parser_cls
self.request = request
self.renderer = renderer
self.tokenizer = tokenizer
self.available_tools = available_tools or []
self._tool_sessions: dict[str, ClientSession | Tool] = {}
......
......@@ -59,12 +59,15 @@ from pydantic import (
model_validator,
)
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
from vllm.entrypoints.openai.engine.protocol import (
OpenAIBaseModel,
from vllm.config import ModelConfig
from vllm.entrypoints.chat_utils import (
ChatCompletionMessageParam,
ChatTemplateContentFormatOption,
)
from vllm.entrypoints.openai.engine.protocol import OpenAIBaseModel
from vllm.exceptions import VLLMValidationError
from vllm.logger import init_logger
from vllm.renderers import ChatParams, TokenizeParams, merge_kwargs
from vllm.sampling_params import (
RequestOutputKind,
SamplingParams,
......@@ -230,6 +233,42 @@ class ResponsesRequest(OpenAIBaseModel):
previous_input_messages: list[OpenAIHarmonyMessage | dict] | None = None
# --8<-- [end:responses-extra-params]
def build_chat_params(
self,
default_template: str | None,
default_template_content_format: ChatTemplateContentFormatOption,
) -> ChatParams:
from .utils import should_continue_final_message
# Check if we should continue the final message (partial completion)
# This enables Anthropic-style partial message completion where the
# user provides an incomplete assistant message to continue from.
continue_final = should_continue_final_message(self.input)
reasoning = self.reasoning
return ChatParams(
chat_template=default_template,
chat_template_content_format=default_template_content_format,
chat_template_kwargs=merge_kwargs( # To remove unset values
{},
dict(
add_generation_prompt=not continue_final,
continue_final_message=continue_final,
reasoning_effort=None if reasoning is None else reasoning.effort,
),
),
)
def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
return TokenizeParams(
max_total_tokens=model_config.max_model_len,
max_output_tokens=self.max_output_tokens or 0,
truncate_prompt_tokens=-1 if self.truncation != "disabled" else None,
max_total_tokens_param="max_model_len",
max_output_tokens_param="max_output_tokens",
)
_DEFAULT_SAMPLING_PARAMS = {
"temperature": 1.0,
"top_p": 1.0,
......
......@@ -114,16 +114,15 @@ from vllm.entrypoints.openai.responses.utils import (
construct_input_messages,
construct_tool_dicts,
extract_tool_types,
should_continue_final_message,
)
from vllm.entrypoints.utils import get_max_tokens
from vllm.exceptions import VLLMValidationError
from vllm.inputs.data import TokensPrompt
from vllm.inputs.data import EmbedsPrompt, TokensPrompt
from vllm.inputs.parse import get_prompt_len
from vllm.logger import init_logger
from vllm.logprobs import Logprob as SampleLogprob
from vllm.logprobs import SampleLogprobs
from vllm.outputs import CompletionOutput
from vllm.renderers import RendererLike
from vllm.sampling_params import SamplingParams, StructuredOutputsParams
from vllm.tokenizers import TokenizerLike
from vllm.utils import random_uuid
......@@ -291,13 +290,14 @@ class OpenAIServingResponses(OpenAIServing):
self.tool_server = tool_server
def _validate_generator_input(
self, engine_prompt: TokensPrompt
self,
engine_prompt: TokensPrompt | EmbedsPrompt,
) -> ErrorResponse | None:
"""Add validations to the input to the generator here."""
if self.max_model_len <= len(engine_prompt["prompt_token_ids"]):
prompt_len = get_prompt_len(engine_prompt)
if self.max_model_len <= prompt_len:
error_message = (
"The engine prompt length"
f" {len(engine_prompt['prompt_token_ids'])} "
f"The engine prompt length {prompt_len} "
f"exceeds the max_model_len {self.max_model_len}. "
"Please reduce prompt."
)
......@@ -307,6 +307,7 @@ class OpenAIServingResponses(OpenAIServing):
status_code=HTTPStatus.BAD_REQUEST,
param="input",
)
return None
def _validate_create_responses_input(
......@@ -387,8 +388,6 @@ class OpenAIServingResponses(OpenAIServing):
try:
lora_request = self._maybe_get_adapters(request)
model_name = self.models.model_name(lora_request)
renderer = self.engine_client.renderer
tokenizer = renderer.get_tokenizer()
if self.use_harmony:
messages, engine_prompts = self._make_request_with_harmony(
......@@ -396,7 +395,7 @@ class OpenAIServingResponses(OpenAIServing):
)
else:
messages, engine_prompts = await self._make_request(
request, prev_response, renderer
request, prev_response
)
except (
......@@ -431,6 +430,9 @@ class OpenAIServingResponses(OpenAIServing):
assert len(builtin_tool_list) == 0
available_tools = []
try:
renderer = self.engine_client.renderer
tokenizer = renderer.get_tokenizer()
for engine_prompt in engine_prompts:
maybe_error = self._validate_generator_input(engine_prompt)
if maybe_error is not None:
......@@ -446,6 +448,7 @@ class OpenAIServingResponses(OpenAIServing):
sampling_params = request.to_sampling_params(
default_max_tokens, self.default_sampling_params
)
tok_params = request.build_tok_params(self.model_config)
trace_headers = (
None
......@@ -465,7 +468,7 @@ class OpenAIServingResponses(OpenAIServing):
# tokens during generation instead of at the end
context = ParsableContext(
response_messages=messages,
renderer=renderer,
tokenizer=tokenizer,
reasoning_parser_cls=self.reasoning_parser,
request=request,
tool_parser_cls=self.tool_parser,
......@@ -495,6 +498,7 @@ class OpenAIServingResponses(OpenAIServing):
request_id=request.request_id,
engine_prompt=engine_prompt,
sampling_params=sampling_params,
tok_params=tok_params,
context=context,
lora_request=lora_request,
priority=request.priority,
......@@ -596,7 +600,6 @@ class OpenAIServingResponses(OpenAIServing):
self,
request: ResponsesRequest,
prev_response: ResponsesResponse | None,
renderer: RendererLike,
):
tool_dicts = construct_tool_dicts(request.tools, request.tool_choice)
# Construct the input messages.
......@@ -606,30 +609,15 @@ class OpenAIServingResponses(OpenAIServing):
prev_msg=self.msg_store.get(prev_response.id) if prev_response else None,
prev_response_output=prev_response.output if prev_response else None,
)
# Check if we should continue the final message (partial completion)
# This enables Anthropic-style partial message completion where the
# user provides an incomplete assistant message to continue from.
continue_final = should_continue_final_message(request.input)
chat_template_kwargs = dict(
reasoning_effort=None
if request.reasoning is None
else request.reasoning.effort
)
_, engine_prompts = await self._preprocess_chat(
request,
renderer,
messages,
default_template=self.chat_template,
default_template_content_format=self.chat_template_content_format,
default_template_kwargs=None,
tool_dicts=tool_dicts,
tool_parser=self.tool_parser,
chat_template=self.chat_template,
chat_template_content_format=self.chat_template_content_format,
# When continuing a partial message, we set continue_final_message=True
# and add_generation_prompt=False so the model continues the message
# rather than starting a new one.
add_generation_prompt=not continue_final,
continue_final_message=continue_final,
chat_template_kwargs=chat_template_kwargs,
)
return messages, engine_prompts
......
......@@ -8,8 +8,12 @@ from pydantic import Field, model_validator
from vllm import PoolingParams
from vllm.config.pooler import get_use_activation
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
from vllm.entrypoints.chat_utils import (
ChatCompletionMessageParam,
ChatTemplateContentFormatOption,
)
from vllm.entrypoints.openai.engine.protocol import OpenAIBaseModel
from vllm.renderers import ChatParams, merge_kwargs
from vllm.utils import random_uuid
from vllm.utils.serial_utils import EmbedDType, EncodingFormat, Endianness
......@@ -119,6 +123,23 @@ class ChatRequestMixin(OpenAIBaseModel):
)
return data
def build_chat_params(
self,
default_template: str | None,
default_template_content_format: ChatTemplateContentFormatOption,
) -> ChatParams:
return ChatParams(
chat_template=self.chat_template or default_template,
chat_template_content_format=default_template_content_format,
chat_template_kwargs=merge_kwargs(
self.chat_template_kwargs,
dict(
add_generation_prompt=self.add_generation_prompt,
continue_final_message=self.continue_final_message,
),
),
)
class EncodingRequestMixin(OpenAIBaseModel):
# --8<-- [start:encoding-params]
......
......@@ -4,10 +4,9 @@
import time
from typing import Any, TypeAlias
from pydantic import (
Field,
)
from pydantic import Field
from vllm.config import ModelConfig
from vllm.entrypoints.openai.engine.protocol import OpenAIBaseModel, UsageInfo
from vllm.entrypoints.pooling.base.protocol import (
ChatRequestMixin,
......@@ -15,13 +14,24 @@ from vllm.entrypoints.pooling.base.protocol import (
CompletionRequestMixin,
PoolingBasicRequestMixin,
)
from vllm.renderers import TokenizeParams
from vllm.utils import random_uuid
class ClassificationCompletionRequest(
PoolingBasicRequestMixin, CompletionRequestMixin, ClassifyRequestMixin
):
pass
def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
encoder_config = model_config.encoder_config or {}
return TokenizeParams(
max_total_tokens=model_config.max_model_len,
max_output_tokens=0,
truncate_prompt_tokens=self.truncate_prompt_tokens,
do_lower_case=encoder_config.get("do_lower_case", False),
add_special_tokens=self.add_special_tokens,
max_total_tokens_param="max_model_len",
)
class ClassificationChatRequest(
......@@ -33,6 +43,18 @@ class ClassificationChatRequest(
description=("Additional kwargs to pass to the HF processor."),
)
def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
encoder_config = model_config.encoder_config or {}
return TokenizeParams(
max_total_tokens=model_config.max_model_len,
max_output_tokens=0,
truncate_prompt_tokens=self.truncate_prompt_tokens,
do_lower_case=encoder_config.get("do_lower_case", False),
add_special_tokens=self.add_special_tokens,
max_total_tokens_param="max_model_len",
)
ClassificationRequest: TypeAlias = (
ClassificationCompletionRequest | ClassificationChatRequest
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from http import HTTPStatus
from typing import Final, cast
from typing import Final, TypeAlias
import jinja2
import numpy as np
......@@ -21,15 +20,14 @@ from vllm.entrypoints.pooling.classify.protocol import (
ClassificationRequest,
ClassificationResponse,
)
from vllm.entrypoints.renderer import RenderConfig
from vllm.logger import init_logger
from vllm.outputs import ClassificationOutput, PoolingRequestOutput
from vllm.outputs import ClassificationOutput
from vllm.pooling_params import PoolingParams
logger = init_logger(__name__)
ClassificationServeContext = ServeContext[ClassificationRequest]
ClassificationServeContext: TypeAlias = ServeContext[ClassificationRequest]
class ServingClassification(OpenAIServing):
......@@ -77,34 +75,18 @@ class ServingClassification(OpenAIServing):
if error_check_ret:
return error_check_ret
_, engine_prompts = await self._preprocess_chat(
_, ctx.engine_prompts = await self._preprocess_chat(
ctx.request,
self.renderer,
ctx.request.messages,
chat_template=ctx.request.chat_template or self.chat_template,
chat_template_content_format=self.chat_template_content_format,
add_generation_prompt=ctx.request.add_generation_prompt,
continue_final_message=ctx.request.continue_final_message,
add_special_tokens=ctx.request.add_special_tokens,
default_template=self.chat_template,
default_template_content_format=self.chat_template_content_format,
default_template_kwargs=None,
)
ctx.engine_prompts = engine_prompts
elif isinstance(ctx.request, ClassificationCompletionRequest):
input_data = ctx.request.input
if input_data in (None, ""):
return self.create_error_response(
"Input or messages must be provided",
status_code=HTTPStatus.BAD_REQUEST,
)
if isinstance(input_data, list) and not input_data:
ctx.engine_prompts = []
return None
renderer = self._get_completion_renderer()
prompt_input = cast(str | list[str], input_data)
ctx.engine_prompts = await renderer.render_prompt(
prompt_or_prompts=prompt_input,
config=self._build_render_config(ctx.request),
ctx.engine_prompts = await self._preprocess_completion(
ctx.request,
prompt_input=ctx.request.input,
prompt_embeds=None,
)
else:
return self.create_error_response("Invalid classification request type")
......@@ -128,7 +110,7 @@ class ServingClassification(OpenAIServing):
items: list[ClassificationData] = []
num_prompt_tokens = 0
final_res_batch_checked = cast(list[PoolingRequestOutput], ctx.final_res_batch)
final_res_batch_checked = ctx.final_res_batch
for idx, final_res in enumerate(final_res_batch_checked):
classify_res = ClassificationOutput.from_base(final_res.outputs)
......@@ -161,13 +143,6 @@ class ServingClassification(OpenAIServing):
usage=usage,
)
def _build_render_config(self, request: ClassificationRequest) -> RenderConfig:
return RenderConfig(
max_length=self.max_model_len,
truncate_prompt_tokens=request.truncate_prompt_tokens,
add_special_tokens=request.add_special_tokens,
)
async def create_classify(
self,
request: ClassificationRequest,
......
......@@ -3,10 +3,9 @@
import time
from typing import Any, TypeAlias
from pydantic import (
Field,
)
from pydantic import Field
from vllm.config import ModelConfig
from vllm.entrypoints.openai.engine.protocol import OpenAIBaseModel, UsageInfo
from vllm.entrypoints.pooling.base.protocol import (
ChatRequestMixin,
......@@ -14,15 +13,47 @@ from vllm.entrypoints.pooling.base.protocol import (
EmbedRequestMixin,
PoolingBasicRequestMixin,
)
from vllm.renderers import TokenizeParams
from vllm.utils import random_uuid
def _get_max_total_output_tokens(
model_config: ModelConfig,
) -> tuple[int | None, int]:
max_total_tokens = model_config.max_model_len
pooler_config = model_config.pooler_config
if pooler_config is None:
return max_total_tokens, 0
if pooler_config.enable_chunked_processing:
return None, 0
max_embed_len = pooler_config.max_embed_len or max_total_tokens
max_output_tokens = max_total_tokens - max_embed_len
return max_total_tokens, max_output_tokens
class EmbeddingCompletionRequest(
PoolingBasicRequestMixin, CompletionRequestMixin, EmbedRequestMixin
):
# Ordered by official OpenAI API documentation
# https://platform.openai.com/docs/api-reference/embeddings
pass
def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
encoder_config = model_config.encoder_config or {}
(
max_total_tokens,
max_output_tokens,
) = _get_max_total_output_tokens(model_config)
return TokenizeParams(
max_total_tokens=max_total_tokens,
max_output_tokens=max_output_tokens,
truncate_prompt_tokens=self.truncate_prompt_tokens,
do_lower_case=encoder_config.get("do_lower_case", False),
add_special_tokens=self.add_special_tokens,
max_total_tokens_param="max_model_len",
max_output_tokens_param="max_model_len - max_embed_len",
)
class EmbeddingChatRequest(
......@@ -33,6 +64,24 @@ class EmbeddingChatRequest(
description=("Additional kwargs to pass to the HF processor."),
)
def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
encoder_config = model_config.encoder_config or {}
(
max_total_tokens,
max_output_tokens,
) = _get_max_total_output_tokens(model_config)
return TokenizeParams(
max_total_tokens=max_total_tokens,
max_output_tokens=max_output_tokens,
truncate_prompt_tokens=self.truncate_prompt_tokens,
do_lower_case=encoder_config.get("do_lower_case", False),
add_special_tokens=self.add_special_tokens,
max_total_tokens_param="max_model_len",
max_output_tokens_param="max_model_len - max_embed_len",
)
EmbeddingRequest: TypeAlias = EmbeddingCompletionRequest | EmbeddingChatRequest
......
......@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
from collections.abc import AsyncGenerator, Mapping
from typing import Any, Final, cast
from typing import Any, Final, TypeAlias
import torch
from fastapi import Request
......@@ -22,8 +22,7 @@ from vllm.entrypoints.pooling.embed.protocol import (
EmbeddingResponse,
EmbeddingResponseData,
)
from vllm.entrypoints.renderer import RenderConfig
from vllm.inputs.data import TokensPrompt
from vllm.inputs.data import EmbedsPrompt, TokensPrompt
from vllm.logger import init_logger
from vllm.outputs import PoolingOutput, PoolingRequestOutput
from vllm.pooling_params import PoolingParams
......@@ -37,7 +36,7 @@ from vllm.utils.serial_utils import (
logger = init_logger(__name__)
EmbeddingServeContext = ServeContext[EmbeddingRequest]
EmbeddingServeContext: TypeAlias = ServeContext[EmbeddingRequest]
class OpenAIServingEmbedding(OpenAIServing):
......@@ -95,19 +94,16 @@ class OpenAIServingEmbedding(OpenAIServing):
_, ctx.engine_prompts = await self._preprocess_chat(
ctx.request,
self.renderer,
ctx.request.messages,
chat_template=ctx.request.chat_template or self.chat_template,
chat_template_content_format=self.chat_template_content_format,
add_generation_prompt=ctx.request.add_generation_prompt,
continue_final_message=ctx.request.continue_final_message,
add_special_tokens=ctx.request.add_special_tokens,
default_template=self.chat_template,
default_template_content_format=self.chat_template_content_format,
default_template_kwargs=None,
)
elif isinstance(ctx.request, EmbeddingCompletionRequest):
renderer = self._get_completion_renderer()
ctx.engine_prompts = await renderer.render_prompt(
prompt_or_prompts=ctx.request.input,
config=self._build_render_config(ctx.request),
ctx.engine_prompts = await self._preprocess_completion(
ctx.request,
prompt_input=ctx.request.input,
prompt_embeds=None,
)
else:
return self.create_error_response("Invalid classification request type")
......@@ -117,19 +113,6 @@ class OpenAIServingEmbedding(OpenAIServing):
logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(str(e))
def _build_render_config(self, request: EmbeddingCompletionRequest) -> RenderConfig:
# Set max_length based on chunked processing capability
if self._should_use_chunked_processing(request):
max_length = None
else:
max_length = self.max_embed_len or self.max_model_len
return RenderConfig(
max_length=max_length,
truncate_prompt_tokens=request.truncate_prompt_tokens,
add_special_tokens=request.add_special_tokens,
)
def _build_response(
self,
ctx: EmbeddingServeContext,
......@@ -246,14 +229,18 @@ class OpenAIServingEmbedding(OpenAIServing):
lora_request=ctx.lora_request,
)
tok_params = ctx.request.build_tok_params(self.model_config)
tokenization_kwargs = tok_params.get_encode_kwargs()
# Create generator for this chunk and wrap it to return indices
original_generator = self.engine_client.encode(
chunk_engine_prompt,
pooling_params,
chunk_request_id,
lora_request=ctx.lora_request,
tokenization_kwargs=tokenization_kwargs,
trace_headers=trace_headers,
priority=getattr(ctx.request, "priority", 0),
priority=ctx.request.priority,
)
generators.append(original_generator)
......@@ -338,7 +325,7 @@ class OpenAIServingEmbedding(OpenAIServing):
async def _create_single_prompt_generator(
self,
ctx: EmbeddingServeContext,
engine_prompt: TokensPrompt,
engine_prompt: TokensPrompt | EmbedsPrompt,
pooling_params: PoolingParams,
trace_headers: Mapping[str, str] | None,
prompt_index: int,
......@@ -353,23 +340,25 @@ class OpenAIServingEmbedding(OpenAIServing):
lora_request=ctx.lora_request,
)
tok_params = ctx.request.build_tok_params(self.model_config)
tokenization_kwargs = tok_params.get_encode_kwargs()
# Return the original generator without wrapping
return self.engine_client.encode(
engine_prompt,
pooling_params,
request_id_item,
lora_request=ctx.lora_request,
tokenization_kwargs=tokenization_kwargs,
trace_headers=trace_headers,
priority=getattr(ctx.request, "priority", 0),
priority=ctx.request.priority,
)
async def _prepare_generators(
self,
ctx: ServeContext,
ctx: EmbeddingServeContext,
) -> ErrorResponse | None:
"""Override to support chunked processing."""
ctx = cast(EmbeddingServeContext, ctx)
# Check if we should use chunked processing
use_chunked = self._should_use_chunked_processing(ctx.request)
......@@ -405,7 +394,8 @@ class OpenAIServingEmbedding(OpenAIServing):
for i, engine_prompt in enumerate(ctx.engine_prompts):
# Check if this specific prompt needs chunked processing
if "prompt_token_ids" in engine_prompt:
prompt_token_ids = engine_prompt["prompt_token_ids"]
prompt_token_ids = engine_prompt["prompt_token_ids"] # type: ignore[typeddict-item]
if len(prompt_token_ids) > max_pos_embeddings:
# Use chunked processing for this prompt
chunk_generators = await self._process_chunked_request(
......@@ -573,7 +563,7 @@ class OpenAIServingEmbedding(OpenAIServing):
"token IDs"
)
original_token_ids = original_prompt["prompt_token_ids"]
original_token_ids = original_prompt["prompt_token_ids"] # type: ignore[typeddict-item]
pooling_request_output = PoolingRequestOutput(
request_id=aggregator["request_id"],
......
......@@ -3,11 +3,10 @@
import time
from typing import Any, Generic, TypeAlias, TypeVar
from pydantic import (
Field,
)
from pydantic import Field
from vllm import PoolingParams
from vllm.config import ModelConfig
from vllm.config.pooler import get_use_activation
from vllm.entrypoints.openai.engine.protocol import OpenAIBaseModel, UsageInfo
from vllm.entrypoints.pooling.base.protocol import (
......@@ -18,6 +17,7 @@ from vllm.entrypoints.pooling.base.protocol import (
EncodingRequestMixin,
PoolingBasicRequestMixin,
)
from vllm.renderers import TokenizeParams
from vllm.tasks import PoolingTask
from vllm.utils import random_uuid
......@@ -30,6 +30,18 @@ class PoolingCompletionRequest(
):
task: PoolingTask | None = None
def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
encoder_config = model_config.encoder_config or {}
return TokenizeParams(
max_total_tokens=model_config.max_model_len,
max_output_tokens=0,
truncate_prompt_tokens=self.truncate_prompt_tokens,
do_lower_case=encoder_config.get("do_lower_case", False),
add_special_tokens=self.add_special_tokens,
max_total_tokens_param="max_model_len",
)
def to_pooling_params(self):
return PoolingParams(
truncate_prompt_tokens=self.truncate_prompt_tokens,
......@@ -48,6 +60,18 @@ class PoolingChatRequest(
description=("Additional kwargs to pass to the HF processor."),
)
def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
encoder_config = model_config.encoder_config or {}
return TokenizeParams(
max_total_tokens=model_config.max_model_len,
max_output_tokens=0,
truncate_prompt_tokens=self.truncate_prompt_tokens,
do_lower_case=encoder_config.get("do_lower_case", False),
add_special_tokens=self.add_special_tokens,
max_total_tokens_param="max_model_len",
)
def to_pooling_params(self):
return PoolingParams(
truncate_prompt_tokens=self.truncate_prompt_tokens,
......
......@@ -5,7 +5,7 @@ import asyncio
import json
import time
from collections.abc import AsyncGenerator, Sequence
from typing import Final, cast
from typing import Any, Final, cast
import jinja2
from fastapi import Request
......@@ -14,10 +14,7 @@ from typing_extensions import assert_never
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.engine.protocol import (
ErrorResponse,
UsageInfo,
)
from vllm.entrypoints.openai.engine.protocol import ErrorResponse, UsageInfo
from vllm.entrypoints.openai.engine.serving import OpenAIServing
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.entrypoints.pooling.pooling.protocol import (
......@@ -30,8 +27,6 @@ from vllm.entrypoints.pooling.pooling.protocol import (
PoolingResponse,
PoolingResponseData,
)
from vllm.entrypoints.renderer import RenderConfig
from vllm.entrypoints.utils import _validate_truncation_size
from vllm.logger import init_logger
from vllm.outputs import PoolingRequestOutput
from vllm.tasks import PoolingTask, SupportedTask
......@@ -99,11 +94,6 @@ class OpenAIServingPooling(OpenAIServing):
"dimensions is currently not supported"
)
truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens", None)
truncate_prompt_tokens = _validate_truncation_size(
self.max_model_len, truncate_prompt_tokens
)
if is_io_processor_request:
if self.io_processor is None:
raise ValueError(
......@@ -134,19 +124,16 @@ class OpenAIServingPooling(OpenAIServing):
_, engine_prompts = await self._preprocess_chat(
request,
self.renderer,
request.messages,
chat_template=request.chat_template or self.chat_template,
chat_template_content_format=self.chat_template_content_format,
add_generation_prompt=request.add_generation_prompt,
continue_final_message=request.continue_final_message,
add_special_tokens=request.add_special_tokens,
default_template=self.chat_template,
default_template_content_format=self.chat_template_content_format,
default_template_kwargs=None,
)
elif isinstance(request, PoolingCompletionRequest):
renderer = self._get_completion_renderer()
engine_prompts = await renderer.render_prompt(
prompt_or_prompts=request.input,
config=self._build_render_config(request),
engine_prompts = await self._preprocess_completion(
request,
prompt_input=request.input,
prompt_embeds=None,
)
else:
raise ValueError(f"Unsupported request of type {type(request)}")
......@@ -207,11 +194,18 @@ class OpenAIServingPooling(OpenAIServing):
else await self._get_trace_headers(raw_request.headers)
)
if is_io_processor_request:
tokenization_kwargs: dict[str, Any] = {}
else:
tok_params = request.build_tok_params(self.model_config) # type: ignore
tokenization_kwargs = tok_params.get_encode_kwargs()
generator = self.engine_client.encode(
engine_prompt,
pooling_params,
request_id_item,
lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
trace_headers=trace_headers,
priority=request.priority,
)
......@@ -338,10 +332,3 @@ class OpenAIServingPooling(OpenAIServing):
return encode_bytes(bytes_only=encoding_format == "bytes_only")
else:
assert_never(encoding_format)
def _build_render_config(self, request: PoolingCompletionRequest) -> RenderConfig:
return RenderConfig(
max_length=self.max_model_len,
truncate_prompt_tokens=request.truncate_prompt_tokens,
add_special_tokens=request.add_special_tokens,
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment