Unverified commit f0a1c845, authored by Cyrus Leung and committed by GitHub
Browse files

[Frontend] Use new Renderer for Completions and Tokenize API (#32863)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent 8980001c
...@@ -5,65 +5,10 @@ import pytest ...@@ -5,65 +5,10 @@ import pytest
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.inputs import zip_enc_dec_prompts from vllm.inputs import zip_enc_dec_prompts
from vllm.inputs.parse import parse_raw_prompts
from vllm.inputs.preprocess import InputPreprocessor from vllm.inputs.preprocess import InputPreprocessor
pytestmark = pytest.mark.cpu_test pytestmark = pytest.mark.cpu_test
# Raw text prompts, ranging from the empty string to a four-word string.
# Used by the string-consistency and slicing tests for `parse_raw_prompts`.
STRING_INPUTS = [
    "",
    "foo",
    "foo bar",
    "foo baz bar",
    "foo bar qux baz",
]

# Raw token-ID prompts (one list of ints per prompt).
# Used by the token-consistency test for `parse_raw_prompts`.
TOKEN_INPUTS = [
    [-1],
    [1],
    [1, 2],
    [1, 3, 4],
    [1, 2, 4, 3],
]

# Slices applied to STRING_INPUTS: reversed, every second item, and every
# second item walking backwards.
INPUTS_SLICES = [slice(None, None, step) for step in (-1, 2, -2)]
# A batch that mixes a token-ID list with a list of strings must be rejected.
@pytest.mark.parametrize("invalid_input", [[[1, 2], ["foo", "bar"]]])
def test_invalid_input_raise_type_error(invalid_input):
    """Check that `parse_raw_prompts` raises `TypeError` on a mixed nested list."""
    expect_type_error = pytest.raises(TypeError)
    with expect_type_error:
        parse_raw_prompts(invalid_input)
def test_parse_raw_single_batch_empty():
    """Check that empty batches are rejected with an "at least one prompt" error.

    Both the empty list and a list holding a single empty list must raise
    `ValueError`.
    """
    # Checked in the same order as before: `[]` first, then `[[]]`.
    for empty_batch in ([], [[]]):
        with pytest.raises(ValueError, match="at least one prompt"):
            parse_raw_prompts(empty_batch)
@pytest.mark.parametrize("string_input", STRING_INPUTS)
def test_parse_raw_single_batch_string_consistent(string_input: str):
    """Check that a bare string parses the same as a one-element list of it."""
    parsed_bare = parse_raw_prompts(string_input)
    parsed_wrapped = parse_raw_prompts([string_input])
    assert parsed_bare == parsed_wrapped
@pytest.mark.parametrize("token_input", TOKEN_INPUTS)
def test_parse_raw_single_batch_token_consistent(token_input: list[int]):
    """Check that a bare token-ID list parses the same as a list wrapping it."""
    parsed_bare = parse_raw_prompts(token_input)
    parsed_wrapped = parse_raw_prompts([token_input])
    assert parsed_bare == parsed_wrapped
@pytest.mark.parametrize("inputs_slice", INPUTS_SLICES)
def test_parse_raw_single_batch_string_slice(inputs_slice: slice):
    """Check that slicing commutes with parsing for a batch of strings.

    Slicing the parsed result must equal parsing the sliced input.
    """
    parse_then_slice = parse_raw_prompts(STRING_INPUTS)[inputs_slice]
    slice_then_parse = parse_raw_prompts(STRING_INPUTS[inputs_slice])
    assert parse_then_slice == slice_then_parse
@pytest.mark.parametrize( @pytest.mark.parametrize(
"mm_processor_kwargs,expected_mm_kwargs", "mm_processor_kwargs,expected_mm_kwargs",
......
...@@ -768,7 +768,7 @@ class ModelConfig: ...@@ -768,7 +768,7 @@ class ModelConfig:
) )
self.tokenizer = object_storage_tokenizer.dir self.tokenizer = object_storage_tokenizer.dir
def _get_encoder_config(self): def _get_encoder_config(self) -> dict[str, Any] | None:
model = self.model model = self.model
if is_remote_gguf(model): if is_remote_gguf(model):
model, _ = split_remote_gguf(model) model, _ = split_remote_gguf(model)
...@@ -1918,7 +1918,7 @@ def _get_and_verify_max_len( ...@@ -1918,7 +1918,7 @@ def _get_and_verify_max_len(
disable_sliding_window: bool, disable_sliding_window: bool,
sliding_window: int | None, sliding_window: int | None,
spec_target_max_model_len: int | None = None, spec_target_max_model_len: int | None = None,
encoder_config: Any | None = None, encoder_config: dict[str, Any] | None = None,
) -> int: ) -> int:
"""Get and verify the model's maximum length.""" """Get and verify the model's maximum length."""
(derived_max_model_len, max_len_key) = ( (derived_max_model_len, max_len_key) = (
......
...@@ -72,14 +72,9 @@ class EngineClient(ABC): ...@@ -72,14 +72,9 @@ class EngineClient(ABC):
lora_request: LoRARequest | None = None, lora_request: LoRARequest | None = None,
trace_headers: Mapping[str, str] | None = None, trace_headers: Mapping[str, str] | None = None,
priority: int = 0, priority: int = 0,
truncate_prompt_tokens: int | None = None,
tokenization_kwargs: dict[str, Any] | None = None, tokenization_kwargs: dict[str, Any] | None = None,
) -> AsyncGenerator[PoolingRequestOutput, None]: ) -> AsyncGenerator[PoolingRequestOutput, None]:
"""Generate outputs for a request from a pooling model. """Generate outputs for a request from a pooling model."""
NOTE: truncate_prompt_tokens is deprecated in v0.14.
TODO: Remove this argument in v0.15.
"""
... ...
@abstractmethod @abstractmethod
......
This diff is collapsed.
...@@ -13,12 +13,13 @@ from openai.types.chat.chat_completion_audio import ( ...@@ -13,12 +13,13 @@ from openai.types.chat.chat_completion_audio import (
ChatCompletionAudio as OpenAIChatCompletionAudio, ChatCompletionAudio as OpenAIChatCompletionAudio,
) )
from openai.types.chat.chat_completion_message import Annotation as OpenAIAnnotation from openai.types.chat.chat_completion_message import Annotation as OpenAIAnnotation
from pydantic import ( from pydantic import Field, model_validator
Field,
model_validator,
)
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam from vllm.config import ModelConfig
from vllm.entrypoints.chat_utils import (
ChatCompletionMessageParam,
ChatTemplateContentFormatOption,
)
from vllm.entrypoints.openai.engine.protocol import ( from vllm.entrypoints.openai.engine.protocol import (
AnyResponseFormat, AnyResponseFormat,
DeltaMessage, DeltaMessage,
...@@ -36,6 +37,7 @@ from vllm.entrypoints.openai.engine.protocol import ( ...@@ -36,6 +37,7 @@ from vllm.entrypoints.openai.engine.protocol import (
from vllm.exceptions import VLLMValidationError from vllm.exceptions import VLLMValidationError
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.logprobs import Logprob from vllm.logprobs import Logprob
from vllm.renderers import ChatParams, TokenizeParams, merge_kwargs
from vllm.sampling_params import ( from vllm.sampling_params import (
BeamSearchParams, BeamSearchParams,
RequestOutputKind, RequestOutputKind,
...@@ -348,6 +350,43 @@ class ChatCompletionRequest(OpenAIBaseModel): ...@@ -348,6 +350,43 @@ class ChatCompletionRequest(OpenAIBaseModel):
# --8<-- [end:chat-completion-extra-params] # --8<-- [end:chat-completion-extra-params]
def build_chat_params(
    self,
    default_template: str | None,
    default_template_content_format: ChatTemplateContentFormatOption,
) -> ChatParams:
    """Build the chat-template parameters for this chat completion request.

    Args:
        default_template: Template used when the request's own
            `chat_template` is unset (or otherwise falsy).
        default_template_content_format: Content format passed through
            unchanged; the request does not override it here.

    Returns:
        A `ChatParams` whose template kwargs combine the request's
        `chat_template_kwargs` with the request-level fields below.
    """
    # Request-level fields that are forwarded to the chat template.
    request_level_kwargs = {
        "add_generation_prompt": self.add_generation_prompt,
        "continue_final_message": self.continue_final_message,
        "documents": self.documents,
        "reasoning_effort": self.reasoning_effort,
    }
    # NOTE(review): precedence and None handling are decided by
    # `merge_kwargs` - confirm against its implementation.
    template_kwargs = merge_kwargs(self.chat_template_kwargs, request_level_kwargs)

    # A template supplied on the request takes priority over the default.
    selected_template = self.chat_template or default_template

    return ChatParams(
        chat_template=selected_template,
        chat_template_content_format=default_template_content_format,
        chat_template_kwargs=template_kwargs,
    )
def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
    """Build the tokenization parameters for this chat completion request.

    Args:
        model_config: Supplies `max_model_len`, used as the total token limit.

    Returns:
        A `TokenizeParams` carrying the token limits, truncation and
        special-token settings, and the names of the fields the limits
        came from.
    """
    # `max_completion_tokens` wins whenever it is set; otherwise fall back to
    # `max_tokens`. Track which request field supplied the value.
    if self.max_completion_tokens is None:
        output_budget: int | None = self.max_tokens
        output_budget_field = "max_tokens"
    else:
        output_budget = self.max_completion_tokens
        output_budget_field = "max_completion_tokens"

    # Detokenization is requested only when echoing without returning token IDs.
    echo_as_text = bool(self.echo and not self.return_token_ids)

    return TokenizeParams(
        max_total_tokens=model_config.max_model_len,
        # An unset (or zero) output budget is passed on as 0.
        max_output_tokens=output_budget or 0,
        truncate_prompt_tokens=self.truncate_prompt_tokens,
        add_special_tokens=self.add_special_tokens,
        needs_detokenization=echo_as_text,
        max_total_tokens_param="max_model_len",
        max_output_tokens_param=output_budget_field,
    )
# Default sampling parameters for chat completion requests # Default sampling parameters for chat completion requests
_DEFAULT_SAMPLING_PARAMS: dict = { _DEFAULT_SAMPLING_PARAMS: dict = {
"repetition_penalty": 1.0, "repetition_penalty": 1.0,
......
...@@ -67,7 +67,7 @@ from vllm.entrypoints.openai.parser.harmony_utils import ( ...@@ -67,7 +67,7 @@ from vllm.entrypoints.openai.parser.harmony_utils import (
) )
from vllm.entrypoints.openai.utils import maybe_filter_parallel_tool_calls from vllm.entrypoints.openai.utils import maybe_filter_parallel_tool_calls
from vllm.entrypoints.utils import get_max_tokens, should_include_usage from vllm.entrypoints.utils import get_max_tokens, should_include_usage
from vllm.inputs.data import TokensPrompt from vllm.inputs.data import EmbedsPrompt, TokensPrompt
from vllm.inputs.parse import get_prompt_components from vllm.inputs.parse import get_prompt_components
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.logprobs import Logprob from vllm.logprobs import Logprob
...@@ -185,8 +185,6 @@ class OpenAIServingChat(OpenAIServing): ...@@ -185,8 +185,6 @@ class OpenAIServingChat(OpenAIServing):
start_time = time.perf_counter() start_time = time.perf_counter()
try: try:
renderer = self.engine_client.renderer
# Create a minimal dummy request # Create a minimal dummy request
dummy_request = ChatCompletionRequest( dummy_request = ChatCompletionRequest(
messages=[{"role": "user", "content": "warmup"}], messages=[{"role": "user", "content": "warmup"}],
...@@ -201,18 +199,10 @@ class OpenAIServingChat(OpenAIServing): ...@@ -201,18 +199,10 @@ class OpenAIServingChat(OpenAIServing):
# 3. Tokenizer initialization for chat # 3. Tokenizer initialization for chat
await self._preprocess_chat( await self._preprocess_chat(
dummy_request, dummy_request,
renderer,
dummy_request.messages, dummy_request.messages,
chat_template=self.chat_template, default_template=self.chat_template,
chat_template_content_format=self.chat_template_content_format, default_template_content_format=self.chat_template_content_format,
add_generation_prompt=True, default_template_kwargs=self.default_chat_template_kwargs,
continue_final_message=False,
tool_dicts=None,
documents=None,
chat_template_kwargs=None,
default_chat_template_kwargs=self.default_chat_template_kwargs,
tool_parser=None,
add_special_tokens=False,
) )
elapsed = (time.perf_counter() - start_time) * 1000 elapsed = (time.perf_counter() - start_time) * 1000
...@@ -225,7 +215,10 @@ class OpenAIServingChat(OpenAIServing): ...@@ -225,7 +215,10 @@ class OpenAIServingChat(OpenAIServing):
async def render_chat_request( async def render_chat_request(
self, self,
request: ChatCompletionRequest, request: ChatCompletionRequest,
) -> tuple[list[ConversationMessage], list[Any]] | ErrorResponse: ) -> (
tuple[list[ConversationMessage], list[TokensPrompt | EmbedsPrompt]]
| ErrorResponse
):
""" """
render chat request by validating and preprocessing inputs. render chat request by validating and preprocessing inputs.
...@@ -302,23 +295,14 @@ class OpenAIServingChat(OpenAIServing): ...@@ -302,23 +295,14 @@ class OpenAIServingChat(OpenAIServing):
if error_check_ret is not None: if error_check_ret is not None:
return error_check_ret return error_check_ret
chat_template_kwargs = request.chat_template_kwargs or {}
chat_template_kwargs.update(reasoning_effort=request.reasoning_effort)
conversation, engine_prompts = await self._preprocess_chat( conversation, engine_prompts = await self._preprocess_chat(
request, request,
renderer,
request.messages, request.messages,
chat_template=request.chat_template or self.chat_template, default_template=self.chat_template,
chat_template_content_format=self.chat_template_content_format, default_template_content_format=self.chat_template_content_format,
add_generation_prompt=request.add_generation_prompt, default_template_kwargs=self.default_chat_template_kwargs,
continue_final_message=request.continue_final_message,
tool_dicts=tool_dicts, tool_dicts=tool_dicts,
documents=request.documents,
chat_template_kwargs=chat_template_kwargs,
default_chat_template_kwargs=self.default_chat_template_kwargs,
tool_parser=tool_parser, tool_parser=tool_parser,
add_special_tokens=request.add_special_tokens,
) )
else: else:
# For GPT-OSS. # For GPT-OSS.
...@@ -428,11 +412,15 @@ class OpenAIServingChat(OpenAIServing): ...@@ -428,11 +412,15 @@ class OpenAIServingChat(OpenAIServing):
trace_headers=trace_headers, trace_headers=trace_headers,
) )
else: else:
engine_request, tokenization_kwargs = await self._process_inputs( tok_params = request.build_tok_params(self.model_config)
tokenization_kwargs = tok_params.get_encode_kwargs()
engine_request = self.input_processor.process_inputs(
sub_request_id, sub_request_id,
engine_prompt, engine_prompt,
sampling_params, sampling_params,
lora_request=lora_request, lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
trace_headers=trace_headers, trace_headers=trace_headers,
priority=request.priority, priority=request.priority,
data_parallel_rank=data_parallel_rank, data_parallel_rank=data_parallel_rank,
......
...@@ -9,11 +9,9 @@ from dataclasses import replace ...@@ -9,11 +9,9 @@ from dataclasses import replace
from typing import Annotated, Any, Literal from typing import Annotated, Any, Literal
import torch import torch
from pydantic import ( from pydantic import Field, model_validator
Field,
model_validator,
)
from vllm.config import ModelConfig
from vllm.entrypoints.openai.engine.protocol import ( from vllm.entrypoints.openai.engine.protocol import (
AnyResponseFormat, AnyResponseFormat,
LegacyStructuralTagResponseFormat, LegacyStructuralTagResponseFormat,
...@@ -27,6 +25,7 @@ from vllm.entrypoints.openai.engine.protocol import ( ...@@ -27,6 +25,7 @@ from vllm.entrypoints.openai.engine.protocol import (
from vllm.exceptions import VLLMValidationError from vllm.exceptions import VLLMValidationError
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.logprobs import Logprob from vllm.logprobs import Logprob
from vllm.renderers import TokenizeParams
from vllm.sampling_params import ( from vllm.sampling_params import (
BeamSearchParams, BeamSearchParams,
RequestOutputKind, RequestOutputKind,
...@@ -178,6 +177,17 @@ class CompletionRequest(OpenAIBaseModel): ...@@ -178,6 +177,17 @@ class CompletionRequest(OpenAIBaseModel):
# --8<-- [end:completion-extra-params] # --8<-- [end:completion-extra-params]
def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
    """Build the tokenization parameters for this completion request.

    Args:
        model_config: Supplies `max_model_len`, used as the total token limit.

    Returns:
        A `TokenizeParams` carrying the token limits and the request's
        truncation, special-token and detokenization settings.
    """
    # An unset (or zero) `max_tokens` is passed on as 0.
    output_budget = self.max_tokens or 0

    # Detokenization is requested only when echoing without returning token IDs.
    echo_as_text = bool(self.echo and not self.return_token_ids)

    return TokenizeParams(
        max_total_tokens=model_config.max_model_len,
        max_output_tokens=output_budget,
        truncate_prompt_tokens=self.truncate_prompt_tokens,
        add_special_tokens=self.add_special_tokens,
        needs_detokenization=echo_as_text,
        max_total_tokens_param="max_model_len",
        max_output_tokens_param="max_tokens",
    )
# Default sampling parameters for completion requests # Default sampling parameters for completion requests
_DEFAULT_SAMPLING_PARAMS: dict = { _DEFAULT_SAMPLING_PARAMS: dict = {
"repetition_penalty": 1.0, "repetition_penalty": 1.0,
......
...@@ -32,7 +32,6 @@ from vllm.entrypoints.openai.engine.serving import ( ...@@ -32,7 +32,6 @@ from vllm.entrypoints.openai.engine.serving import (
clamp_prompt_logprobs, clamp_prompt_logprobs,
) )
from vllm.entrypoints.openai.models.serving import OpenAIServingModels from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.entrypoints.renderer import RenderConfig
from vllm.entrypoints.utils import get_max_tokens, should_include_usage from vllm.entrypoints.utils import get_max_tokens, should_include_usage
from vllm.exceptions import VLLMValidationError from vllm.exceptions import VLLMValidationError
from vllm.inputs.data import EmbedsPrompt, TokensPrompt, is_embeds_prompt from vllm.inputs.data import EmbedsPrompt, TokensPrompt, is_embeds_prompt
...@@ -111,11 +110,10 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -111,11 +110,10 @@ class OpenAIServingCompletion(OpenAIServing):
) )
try: try:
renderer = self._get_completion_renderer() engine_prompts = await self._preprocess_completion(
engine_prompts = await renderer.render_prompt_and_embeds( request,
prompt_or_prompts=request.prompt, prompt_input=request.prompt,
prompt_embeds=request.prompt_embeds, prompt_embeds=request.prompt_embeds,
config=self._build_render_config(request),
) )
except (ValueError, TypeError, RuntimeError, jinja2.TemplateError) as e: except (ValueError, TypeError, RuntimeError, jinja2.TemplateError) as e:
logger.exception("Error in preprocessing prompt inputs") logger.exception("Error in preprocessing prompt inputs")
...@@ -203,10 +201,6 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -203,10 +201,6 @@ class OpenAIServingCompletion(OpenAIServing):
else await self._get_trace_headers(raw_request.headers) else await self._get_trace_headers(raw_request.headers)
) )
# Mypy inconsistently requires this second cast in different
# environments. It shouldn't be necessary (redundant from above)
# but pre-commit in CI fails without it.
engine_prompt = cast(EmbedsPrompt | TokensPrompt, engine_prompt)
if isinstance(sampling_params, BeamSearchParams): if isinstance(sampling_params, BeamSearchParams):
generator = self.beam_search( generator = self.beam_search(
prompt=engine_prompt, prompt=engine_prompt,
...@@ -216,11 +210,15 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -216,11 +210,15 @@ class OpenAIServingCompletion(OpenAIServing):
trace_headers=trace_headers, trace_headers=trace_headers,
) )
else: else:
engine_request, tokenization_kwargs = await self._process_inputs( tok_params = request.build_tok_params(self.model_config)
tokenization_kwargs = tok_params.get_encode_kwargs()
engine_request = self.input_processor.process_inputs(
request_id_item, request_id_item,
engine_prompt, engine_prompt,
sampling_params, sampling_params,
lora_request=lora_request, lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
trace_headers=trace_headers, trace_headers=trace_headers,
priority=request.priority, priority=request.priority,
data_parallel_rank=data_parallel_rank, data_parallel_rank=data_parallel_rank,
...@@ -709,26 +707,3 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -709,26 +707,3 @@ class OpenAIServingCompletion(OpenAIServing):
tokens=out_tokens, tokens=out_tokens,
top_logprobs=out_top_logprobs, top_logprobs=out_top_logprobs,
) )
def _build_render_config(
    self,
    request: CompletionRequest,
    max_input_length: int | None = None,
) -> RenderConfig:
    """Build the `RenderConfig` used to render a completion request's prompt.

    Args:
        request: The completion request being rendered.
        max_input_length: Accepted for interface compatibility; this body
            does not read it.

    Returns:
        A `RenderConfig` whose `max_length` is the model's context length
        minus the requested `max_tokens` (treated as 0 when unset).

    Raises:
        VLLMValidationError: If `max_tokens` is set and exceeds
            `self.max_model_len`.
    """
    # Reject an output budget that cannot fit in the context window before
    # it is used to size the prompt budget below.
    requested_output = request.max_tokens
    if requested_output is not None and requested_output > self.max_model_len:
        raise VLLMValidationError(
            f"'max_tokens' ({request.max_tokens}) cannot be greater than "
            f"the model's maximum context length ({self.max_model_len}).",
            parameter="max_tokens",
            value=request.max_tokens,
        )

    # Whatever is not reserved for output is available to the prompt.
    prompt_budget = self.max_model_len - (request.max_tokens or 0)

    # Detokenization is requested only when echoing without returning token IDs.
    echo_as_text = bool(request.echo and not request.return_token_ids)

    return RenderConfig(
        max_length=prompt_budget,
        truncate_prompt_tokens=request.truncate_prompt_tokens,
        add_special_tokens=request.add_special_tokens,
        cache_salt=request.cache_salt,
        needs_detokenization=echo_as_text,
    )
...@@ -16,9 +16,7 @@ from pydantic import ( ...@@ -16,9 +16,7 @@ from pydantic import (
from vllm.entrypoints.chat_utils import make_tool_call_id from vllm.entrypoints.chat_utils import make_tool_call_id
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.sampling_params import ( from vllm.sampling_params import SamplingParams
SamplingParams,
)
from vllm.utils import random_uuid from vllm.utils import random_uuid
from vllm.utils.import_utils import resolve_obj_by_qualname from vllm.utils.import_utils import resolve_obj_by_qualname
......
This diff is collapsed.
...@@ -43,7 +43,6 @@ from vllm.entrypoints.openai.responses.protocol import ( ...@@ -43,7 +43,6 @@ from vllm.entrypoints.openai.responses.protocol import (
from vllm.entrypoints.openai.responses.utils import construct_tool_dicts from vllm.entrypoints.openai.responses.utils import construct_tool_dicts
from vllm.outputs import RequestOutput from vllm.outputs import RequestOutput
from vllm.reasoning.abs_reasoning_parsers import ReasoningParser from vllm.reasoning.abs_reasoning_parsers import ReasoningParser
from vllm.renderers import RendererLike
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import ToolParser from vllm.tool_parsers.abstract_tool_parser import ToolParser
from vllm.utils import random_uuid from vllm.utils import random_uuid
...@@ -261,7 +260,7 @@ class ParsableContext(ConversationContext): ...@@ -261,7 +260,7 @@ class ParsableContext(ConversationContext):
self, self,
*, *,
response_messages: list[ResponseInputOutputItem], response_messages: list[ResponseInputOutputItem],
renderer: RendererLike, tokenizer: TokenizerLike,
reasoning_parser_cls: Callable[[TokenizerLike], ReasoningParser] | None, reasoning_parser_cls: Callable[[TokenizerLike], ReasoningParser] | None,
request: ResponsesRequest, request: ResponsesRequest,
available_tools: list[str] | None, available_tools: list[str] | None,
...@@ -280,7 +279,6 @@ class ParsableContext(ConversationContext): ...@@ -280,7 +279,6 @@ class ParsableContext(ConversationContext):
if reasoning_parser_cls is None: if reasoning_parser_cls is None:
raise ValueError("reasoning_parser_cls must be provided.") raise ValueError("reasoning_parser_cls must be provided.")
tokenizer = renderer.get_tokenizer()
self.parser = get_responses_parser_for_simple_context( self.parser = get_responses_parser_for_simple_context(
tokenizer=tokenizer, tokenizer=tokenizer,
reasoning_parser_cls=reasoning_parser_cls, reasoning_parser_cls=reasoning_parser_cls,
...@@ -290,8 +288,6 @@ class ParsableContext(ConversationContext): ...@@ -290,8 +288,6 @@ class ParsableContext(ConversationContext):
) )
self.tool_parser_cls = tool_parser_cls self.tool_parser_cls = tool_parser_cls
self.request = request self.request = request
self.renderer = renderer
self.tokenizer = tokenizer
self.available_tools = available_tools or [] self.available_tools = available_tools or []
self._tool_sessions: dict[str, ClientSession | Tool] = {} self._tool_sessions: dict[str, ClientSession | Tool] = {}
......
...@@ -59,12 +59,15 @@ from pydantic import ( ...@@ -59,12 +59,15 @@ from pydantic import (
model_validator, model_validator,
) )
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam from vllm.config import ModelConfig
from vllm.entrypoints.openai.engine.protocol import ( from vllm.entrypoints.chat_utils import (
OpenAIBaseModel, ChatCompletionMessageParam,
ChatTemplateContentFormatOption,
) )
from vllm.entrypoints.openai.engine.protocol import OpenAIBaseModel
from vllm.exceptions import VLLMValidationError from vllm.exceptions import VLLMValidationError
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.renderers import ChatParams, TokenizeParams, merge_kwargs
from vllm.sampling_params import ( from vllm.sampling_params import (
RequestOutputKind, RequestOutputKind,
SamplingParams, SamplingParams,
...@@ -230,6 +233,42 @@ class ResponsesRequest(OpenAIBaseModel): ...@@ -230,6 +233,42 @@ class ResponsesRequest(OpenAIBaseModel):
previous_input_messages: list[OpenAIHarmonyMessage | dict] | None = None previous_input_messages: list[OpenAIHarmonyMessage | dict] | None = None
# --8<-- [end:responses-extra-params] # --8<-- [end:responses-extra-params]
def build_chat_params(
    self,
    default_template: str | None,
    default_template_content_format: ChatTemplateContentFormatOption,
) -> ChatParams:
    """Build the chat-template parameters for this Responses API request.

    Args:
        default_template: Chat template to use; it is passed through as-is.
        default_template_content_format: Content format, passed through as-is.

    Returns:
        A `ChatParams` whose template kwargs set the generation-prompt and
        continue-final-message flags plus the reasoning effort.
    """
    from .utils import should_continue_final_message

    # Partial completion (Anthropic style): when the caller supplies an
    # incomplete assistant message, the model should continue that message
    # instead of starting a new one, so the two flags below are opposites.
    resume_partial = should_continue_final_message(self.input)

    reasoning_cfg = self.reasoning
    effort = reasoning_cfg.effort if reasoning_cfg is not None else None

    # Merged with an empty dict to remove unset values.
    template_kwargs = merge_kwargs(
        {},
        {
            "add_generation_prompt": not resume_partial,
            "continue_final_message": resume_partial,
            "reasoning_effort": effort,
        },
    )

    return ChatParams(
        chat_template=default_template,
        chat_template_content_format=default_template_content_format,
        chat_template_kwargs=template_kwargs,
    )
def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
    """Build the tokenization parameters for this Responses API request.

    Args:
        model_config: Supplies `max_model_len`, used as the total token limit.

    Returns:
        A `TokenizeParams` carrying the token limits and a truncation
        setting derived from the request's `truncation` field.
    """
    # `truncation == "disabled"` maps to no truncation (None); any other value
    # maps to -1.
    # NOTE(review): -1 presumably means "truncate to the model limit" -
    # confirm against `TokenizeParams`.
    truncate_to = None if self.truncation == "disabled" else -1

    # An unset (or zero) `max_output_tokens` is passed on as 0.
    output_budget = self.max_output_tokens or 0

    return TokenizeParams(
        max_total_tokens=model_config.max_model_len,
        max_output_tokens=output_budget,
        truncate_prompt_tokens=truncate_to,
        max_total_tokens_param="max_model_len",
        max_output_tokens_param="max_output_tokens",
    )
_DEFAULT_SAMPLING_PARAMS = { _DEFAULT_SAMPLING_PARAMS = {
"temperature": 1.0, "temperature": 1.0,
"top_p": 1.0, "top_p": 1.0,
......
...@@ -114,16 +114,15 @@ from vllm.entrypoints.openai.responses.utils import ( ...@@ -114,16 +114,15 @@ from vllm.entrypoints.openai.responses.utils import (
construct_input_messages, construct_input_messages,
construct_tool_dicts, construct_tool_dicts,
extract_tool_types, extract_tool_types,
should_continue_final_message,
) )
from vllm.entrypoints.utils import get_max_tokens from vllm.entrypoints.utils import get_max_tokens
from vllm.exceptions import VLLMValidationError from vllm.exceptions import VLLMValidationError
from vllm.inputs.data import TokensPrompt from vllm.inputs.data import EmbedsPrompt, TokensPrompt
from vllm.inputs.parse import get_prompt_len
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.logprobs import Logprob as SampleLogprob from vllm.logprobs import Logprob as SampleLogprob
from vllm.logprobs import SampleLogprobs from vllm.logprobs import SampleLogprobs
from vllm.outputs import CompletionOutput from vllm.outputs import CompletionOutput
from vllm.renderers import RendererLike
from vllm.sampling_params import SamplingParams, StructuredOutputsParams from vllm.sampling_params import SamplingParams, StructuredOutputsParams
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
from vllm.utils import random_uuid from vllm.utils import random_uuid
...@@ -291,13 +290,14 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -291,13 +290,14 @@ class OpenAIServingResponses(OpenAIServing):
self.tool_server = tool_server self.tool_server = tool_server
def _validate_generator_input( def _validate_generator_input(
self, engine_prompt: TokensPrompt self,
engine_prompt: TokensPrompt | EmbedsPrompt,
) -> ErrorResponse | None: ) -> ErrorResponse | None:
"""Add validations to the input to the generator here.""" """Add validations to the input to the generator here."""
if self.max_model_len <= len(engine_prompt["prompt_token_ids"]): prompt_len = get_prompt_len(engine_prompt)
if self.max_model_len <= prompt_len:
error_message = ( error_message = (
"The engine prompt length" f"The engine prompt length {prompt_len} "
f" {len(engine_prompt['prompt_token_ids'])} "
f"exceeds the max_model_len {self.max_model_len}. " f"exceeds the max_model_len {self.max_model_len}. "
"Please reduce prompt." "Please reduce prompt."
) )
...@@ -307,6 +307,7 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -307,6 +307,7 @@ class OpenAIServingResponses(OpenAIServing):
status_code=HTTPStatus.BAD_REQUEST, status_code=HTTPStatus.BAD_REQUEST,
param="input", param="input",
) )
return None return None
def _validate_create_responses_input( def _validate_create_responses_input(
...@@ -387,8 +388,6 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -387,8 +388,6 @@ class OpenAIServingResponses(OpenAIServing):
try: try:
lora_request = self._maybe_get_adapters(request) lora_request = self._maybe_get_adapters(request)
model_name = self.models.model_name(lora_request) model_name = self.models.model_name(lora_request)
renderer = self.engine_client.renderer
tokenizer = renderer.get_tokenizer()
if self.use_harmony: if self.use_harmony:
messages, engine_prompts = self._make_request_with_harmony( messages, engine_prompts = self._make_request_with_harmony(
...@@ -396,7 +395,7 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -396,7 +395,7 @@ class OpenAIServingResponses(OpenAIServing):
) )
else: else:
messages, engine_prompts = await self._make_request( messages, engine_prompts = await self._make_request(
request, prev_response, renderer request, prev_response
) )
except ( except (
...@@ -431,6 +430,9 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -431,6 +430,9 @@ class OpenAIServingResponses(OpenAIServing):
assert len(builtin_tool_list) == 0 assert len(builtin_tool_list) == 0
available_tools = [] available_tools = []
try: try:
renderer = self.engine_client.renderer
tokenizer = renderer.get_tokenizer()
for engine_prompt in engine_prompts: for engine_prompt in engine_prompts:
maybe_error = self._validate_generator_input(engine_prompt) maybe_error = self._validate_generator_input(engine_prompt)
if maybe_error is not None: if maybe_error is not None:
...@@ -446,6 +448,7 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -446,6 +448,7 @@ class OpenAIServingResponses(OpenAIServing):
sampling_params = request.to_sampling_params( sampling_params = request.to_sampling_params(
default_max_tokens, self.default_sampling_params default_max_tokens, self.default_sampling_params
) )
tok_params = request.build_tok_params(self.model_config)
trace_headers = ( trace_headers = (
None None
...@@ -465,7 +468,7 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -465,7 +468,7 @@ class OpenAIServingResponses(OpenAIServing):
# tokens during generation instead of at the end # tokens during generation instead of at the end
context = ParsableContext( context = ParsableContext(
response_messages=messages, response_messages=messages,
renderer=renderer, tokenizer=tokenizer,
reasoning_parser_cls=self.reasoning_parser, reasoning_parser_cls=self.reasoning_parser,
request=request, request=request,
tool_parser_cls=self.tool_parser, tool_parser_cls=self.tool_parser,
...@@ -495,6 +498,7 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -495,6 +498,7 @@ class OpenAIServingResponses(OpenAIServing):
request_id=request.request_id, request_id=request.request_id,
engine_prompt=engine_prompt, engine_prompt=engine_prompt,
sampling_params=sampling_params, sampling_params=sampling_params,
tok_params=tok_params,
context=context, context=context,
lora_request=lora_request, lora_request=lora_request,
priority=request.priority, priority=request.priority,
...@@ -596,7 +600,6 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -596,7 +600,6 @@ class OpenAIServingResponses(OpenAIServing):
self, self,
request: ResponsesRequest, request: ResponsesRequest,
prev_response: ResponsesResponse | None, prev_response: ResponsesResponse | None,
renderer: RendererLike,
): ):
tool_dicts = construct_tool_dicts(request.tools, request.tool_choice) tool_dicts = construct_tool_dicts(request.tools, request.tool_choice)
# Construct the input messages. # Construct the input messages.
...@@ -606,30 +609,15 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -606,30 +609,15 @@ class OpenAIServingResponses(OpenAIServing):
prev_msg=self.msg_store.get(prev_response.id) if prev_response else None, prev_msg=self.msg_store.get(prev_response.id) if prev_response else None,
prev_response_output=prev_response.output if prev_response else None, prev_response_output=prev_response.output if prev_response else None,
) )
# Check if we should continue the final message (partial completion)
# This enables Anthropic-style partial message completion where the
# user provides an incomplete assistant message to continue from.
continue_final = should_continue_final_message(request.input)
chat_template_kwargs = dict(
reasoning_effort=None
if request.reasoning is None
else request.reasoning.effort
)
_, engine_prompts = await self._preprocess_chat( _, engine_prompts = await self._preprocess_chat(
request, request,
renderer,
messages, messages,
default_template=self.chat_template,
default_template_content_format=self.chat_template_content_format,
default_template_kwargs=None,
tool_dicts=tool_dicts, tool_dicts=tool_dicts,
tool_parser=self.tool_parser, tool_parser=self.tool_parser,
chat_template=self.chat_template,
chat_template_content_format=self.chat_template_content_format,
# When continuing a partial message, we set continue_final_message=True
# and add_generation_prompt=False so the model continues the message
# rather than starting a new one.
add_generation_prompt=not continue_final,
continue_final_message=continue_final,
chat_template_kwargs=chat_template_kwargs,
) )
return messages, engine_prompts return messages, engine_prompts
......
...@@ -8,8 +8,12 @@ from pydantic import Field, model_validator ...@@ -8,8 +8,12 @@ from pydantic import Field, model_validator
from vllm import PoolingParams from vllm import PoolingParams
from vllm.config.pooler import get_use_activation from vllm.config.pooler import get_use_activation
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam from vllm.entrypoints.chat_utils import (
ChatCompletionMessageParam,
ChatTemplateContentFormatOption,
)
from vllm.entrypoints.openai.engine.protocol import OpenAIBaseModel from vllm.entrypoints.openai.engine.protocol import OpenAIBaseModel
from vllm.renderers import ChatParams, merge_kwargs
from vllm.utils import random_uuid from vllm.utils import random_uuid
from vllm.utils.serial_utils import EmbedDType, EncodingFormat, Endianness from vllm.utils.serial_utils import EmbedDType, EncodingFormat, Endianness
...@@ -119,6 +123,23 @@ class ChatRequestMixin(OpenAIBaseModel): ...@@ -119,6 +123,23 @@ class ChatRequestMixin(OpenAIBaseModel):
) )
return data return data
def build_chat_params(
    self,
    default_template: str | None,
    default_template_content_format: ChatTemplateContentFormatOption,
) -> ChatParams:
    """Assemble the chat-rendering parameters for this request.

    Args:
        default_template: Chat template to use when the request does not
            carry a (truthy) ``chat_template`` of its own.
        default_template_content_format: Content format passed through
            unchanged as ``chat_template_content_format``.

    Returns:
        A ``ChatParams`` whose template kwargs combine the request's
        ``chat_template_kwargs`` with its ``add_generation_prompt`` and
        ``continue_final_message`` flags.
    """
    # A request-level template takes priority over the server default.
    template = self.chat_template or default_template

    # NOTE(review): which side wins on a key clash is decided by
    # ``merge_kwargs`` -- confirm against its implementation.
    prompt_flags = dict(
        add_generation_prompt=self.add_generation_prompt,
        continue_final_message=self.continue_final_message,
    )
    template_kwargs = merge_kwargs(self.chat_template_kwargs, prompt_flags)

    return ChatParams(
        chat_template=template,
        chat_template_content_format=default_template_content_format,
        chat_template_kwargs=template_kwargs,
    )
class EncodingRequestMixin(OpenAIBaseModel): class EncodingRequestMixin(OpenAIBaseModel):
# --8<-- [start:encoding-params] # --8<-- [start:encoding-params]
......
...@@ -4,10 +4,9 @@ ...@@ -4,10 +4,9 @@
import time import time
from typing import Any, TypeAlias from typing import Any, TypeAlias
from pydantic import ( from pydantic import Field
Field,
)
from vllm.config import ModelConfig
from vllm.entrypoints.openai.engine.protocol import OpenAIBaseModel, UsageInfo from vllm.entrypoints.openai.engine.protocol import OpenAIBaseModel, UsageInfo
from vllm.entrypoints.pooling.base.protocol import ( from vllm.entrypoints.pooling.base.protocol import (
ChatRequestMixin, ChatRequestMixin,
...@@ -15,13 +14,24 @@ from vllm.entrypoints.pooling.base.protocol import ( ...@@ -15,13 +14,24 @@ from vllm.entrypoints.pooling.base.protocol import (
CompletionRequestMixin, CompletionRequestMixin,
PoolingBasicRequestMixin, PoolingBasicRequestMixin,
) )
from vllm.renderers import TokenizeParams
from vllm.utils import random_uuid from vllm.utils import random_uuid
class ClassificationCompletionRequest( class ClassificationCompletionRequest(
PoolingBasicRequestMixin, CompletionRequestMixin, ClassifyRequestMixin PoolingBasicRequestMixin, CompletionRequestMixin, ClassifyRequestMixin
): ):
pass def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
encoder_config = model_config.encoder_config or {}
return TokenizeParams(
max_total_tokens=model_config.max_model_len,
max_output_tokens=0,
truncate_prompt_tokens=self.truncate_prompt_tokens,
do_lower_case=encoder_config.get("do_lower_case", False),
add_special_tokens=self.add_special_tokens,
max_total_tokens_param="max_model_len",
)
class ClassificationChatRequest( class ClassificationChatRequest(
...@@ -33,6 +43,18 @@ class ClassificationChatRequest( ...@@ -33,6 +43,18 @@ class ClassificationChatRequest(
description=("Additional kwargs to pass to the HF processor."), description=("Additional kwargs to pass to the HF processor."),
) )
def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
    """Build the tokenization parameters for this classification chat request.

    Args:
        model_config: Supplies ``max_model_len`` (used as the total token
            limit) and the optional ``encoder_config`` mapping.

    Returns:
        A ``TokenizeParams`` with ``max_output_tokens`` fixed at 0 and the
        request's own ``truncate_prompt_tokens`` / ``add_special_tokens``.
    """
    # ``encoder_config`` may be unset; lower-casing then defaults to False.
    lower_case = (model_config.encoder_config or {}).get("do_lower_case", False)

    return TokenizeParams(
        max_total_tokens=model_config.max_model_len,
        max_output_tokens=0,
        truncate_prompt_tokens=self.truncate_prompt_tokens,
        do_lower_case=lower_case,
        add_special_tokens=self.add_special_tokens,
        # presumably the parameter name reported to the user -- TODO confirm
        max_total_tokens_param="max_model_len",
    )
ClassificationRequest: TypeAlias = ( ClassificationRequest: TypeAlias = (
ClassificationCompletionRequest | ClassificationChatRequest ClassificationCompletionRequest | ClassificationChatRequest
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from http import HTTPStatus from typing import Final, TypeAlias
from typing import Final, cast
import jinja2 import jinja2
import numpy as np import numpy as np
...@@ -21,15 +20,14 @@ from vllm.entrypoints.pooling.classify.protocol import ( ...@@ -21,15 +20,14 @@ from vllm.entrypoints.pooling.classify.protocol import (
ClassificationRequest, ClassificationRequest,
ClassificationResponse, ClassificationResponse,
) )
from vllm.entrypoints.renderer import RenderConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.outputs import ClassificationOutput, PoolingRequestOutput from vllm.outputs import ClassificationOutput
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
logger = init_logger(__name__) logger = init_logger(__name__)
ClassificationServeContext = ServeContext[ClassificationRequest] ClassificationServeContext: TypeAlias = ServeContext[ClassificationRequest]
class ServingClassification(OpenAIServing): class ServingClassification(OpenAIServing):
...@@ -77,34 +75,18 @@ class ServingClassification(OpenAIServing): ...@@ -77,34 +75,18 @@ class ServingClassification(OpenAIServing):
if error_check_ret: if error_check_ret:
return error_check_ret return error_check_ret
_, engine_prompts = await self._preprocess_chat( _, ctx.engine_prompts = await self._preprocess_chat(
ctx.request, ctx.request,
self.renderer,
ctx.request.messages, ctx.request.messages,
chat_template=ctx.request.chat_template or self.chat_template, default_template=self.chat_template,
chat_template_content_format=self.chat_template_content_format, default_template_content_format=self.chat_template_content_format,
add_generation_prompt=ctx.request.add_generation_prompt, default_template_kwargs=None,
continue_final_message=ctx.request.continue_final_message,
add_special_tokens=ctx.request.add_special_tokens,
) )
ctx.engine_prompts = engine_prompts
elif isinstance(ctx.request, ClassificationCompletionRequest): elif isinstance(ctx.request, ClassificationCompletionRequest):
input_data = ctx.request.input ctx.engine_prompts = await self._preprocess_completion(
if input_data in (None, ""): ctx.request,
return self.create_error_response( prompt_input=ctx.request.input,
"Input or messages must be provided", prompt_embeds=None,
status_code=HTTPStatus.BAD_REQUEST,
)
if isinstance(input_data, list) and not input_data:
ctx.engine_prompts = []
return None
renderer = self._get_completion_renderer()
prompt_input = cast(str | list[str], input_data)
ctx.engine_prompts = await renderer.render_prompt(
prompt_or_prompts=prompt_input,
config=self._build_render_config(ctx.request),
) )
else: else:
return self.create_error_response("Invalid classification request type") return self.create_error_response("Invalid classification request type")
...@@ -128,7 +110,7 @@ class ServingClassification(OpenAIServing): ...@@ -128,7 +110,7 @@ class ServingClassification(OpenAIServing):
items: list[ClassificationData] = [] items: list[ClassificationData] = []
num_prompt_tokens = 0 num_prompt_tokens = 0
final_res_batch_checked = cast(list[PoolingRequestOutput], ctx.final_res_batch) final_res_batch_checked = ctx.final_res_batch
for idx, final_res in enumerate(final_res_batch_checked): for idx, final_res in enumerate(final_res_batch_checked):
classify_res = ClassificationOutput.from_base(final_res.outputs) classify_res = ClassificationOutput.from_base(final_res.outputs)
...@@ -161,13 +143,6 @@ class ServingClassification(OpenAIServing): ...@@ -161,13 +143,6 @@ class ServingClassification(OpenAIServing):
usage=usage, usage=usage,
) )
def _build_render_config(self, request: ClassificationRequest) -> RenderConfig:
return RenderConfig(
max_length=self.max_model_len,
truncate_prompt_tokens=request.truncate_prompt_tokens,
add_special_tokens=request.add_special_tokens,
)
async def create_classify( async def create_classify(
self, self,
request: ClassificationRequest, request: ClassificationRequest,
......
...@@ -3,10 +3,9 @@ ...@@ -3,10 +3,9 @@
import time import time
from typing import Any, TypeAlias from typing import Any, TypeAlias
from pydantic import ( from pydantic import Field
Field,
)
from vllm.config import ModelConfig
from vllm.entrypoints.openai.engine.protocol import OpenAIBaseModel, UsageInfo from vllm.entrypoints.openai.engine.protocol import OpenAIBaseModel, UsageInfo
from vllm.entrypoints.pooling.base.protocol import ( from vllm.entrypoints.pooling.base.protocol import (
ChatRequestMixin, ChatRequestMixin,
...@@ -14,15 +13,47 @@ from vllm.entrypoints.pooling.base.protocol import ( ...@@ -14,15 +13,47 @@ from vllm.entrypoints.pooling.base.protocol import (
EmbedRequestMixin, EmbedRequestMixin,
PoolingBasicRequestMixin, PoolingBasicRequestMixin,
) )
from vllm.renderers import TokenizeParams
from vllm.utils import random_uuid from vllm.utils import random_uuid
def _get_max_total_output_tokens(
    model_config: ModelConfig,
) -> tuple[int | None, int]:
    """Derive ``(max_total_tokens, max_output_tokens)`` from *model_config*.

    Returns:
        * ``(max_model_len, 0)`` when no pooler config is set.
        * ``(None, 0)`` when the pooler enables chunked processing.
        * ``(max_model_len, max_model_len - max_embed_len)`` otherwise, where
          an unset/zero ``max_embed_len`` falls back to ``max_model_len``.
          The difference is not clamped, so it is negative whenever
          ``max_embed_len`` exceeds ``max_model_len``.
    """
    model_len = model_config.max_model_len
    pooler = model_config.pooler_config

    if pooler is not None:
        if pooler.enable_chunked_processing:
            return None, 0

        embed_limit = pooler.max_embed_len or model_len
        # NOTE(review): a negative result looks intentional (it would let
        # prompts longer than max_model_len through) -- confirm with
        # TokenizeParams before clamping.
        return model_len, model_len - embed_limit

    return model_len, 0
class EmbeddingCompletionRequest( class EmbeddingCompletionRequest(
PoolingBasicRequestMixin, CompletionRequestMixin, EmbedRequestMixin PoolingBasicRequestMixin, CompletionRequestMixin, EmbedRequestMixin
): ):
# Ordered by official OpenAI API documentation def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
# https://platform.openai.com/docs/api-reference/embeddings encoder_config = model_config.encoder_config or {}
pass
(
max_total_tokens,
max_output_tokens,
) = _get_max_total_output_tokens(model_config)
return TokenizeParams(
max_total_tokens=max_total_tokens,
max_output_tokens=max_output_tokens,
truncate_prompt_tokens=self.truncate_prompt_tokens,
do_lower_case=encoder_config.get("do_lower_case", False),
add_special_tokens=self.add_special_tokens,
max_total_tokens_param="max_model_len",
max_output_tokens_param="max_model_len - max_embed_len",
)
class EmbeddingChatRequest( class EmbeddingChatRequest(
...@@ -33,6 +64,24 @@ class EmbeddingChatRequest( ...@@ -33,6 +64,24 @@ class EmbeddingChatRequest(
description=("Additional kwargs to pass to the HF processor."), description=("Additional kwargs to pass to the HF processor."),
) )
def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
    """Build the tokenization parameters for this embedding chat request.

    Args:
        model_config: Supplies the optional ``encoder_config`` mapping and
            is passed to ``_get_max_total_output_tokens`` to obtain the
            total/output token limits.

    Returns:
        A ``TokenizeParams`` combining those limits with the request's own
        ``truncate_prompt_tokens`` / ``add_special_tokens``.
    """
    # ``encoder_config`` may be unset; lower-casing then defaults to False.
    encoder_settings = model_config.encoder_config or {}
    total_limit, output_budget = _get_max_total_output_tokens(model_config)

    return TokenizeParams(
        max_total_tokens=total_limit,
        max_output_tokens=output_budget,
        truncate_prompt_tokens=self.truncate_prompt_tokens,
        do_lower_case=encoder_settings.get("do_lower_case", False),
        add_special_tokens=self.add_special_tokens,
        # presumably the names reported to the user -- TODO confirm
        max_total_tokens_param="max_model_len",
        max_output_tokens_param="max_model_len - max_embed_len",
    )
EmbeddingRequest: TypeAlias = EmbeddingCompletionRequest | EmbeddingChatRequest EmbeddingRequest: TypeAlias = EmbeddingCompletionRequest | EmbeddingChatRequest
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json import json
from collections.abc import AsyncGenerator, Mapping from collections.abc import AsyncGenerator, Mapping
from typing import Any, Final, cast from typing import Any, Final, TypeAlias
import torch import torch
from fastapi import Request from fastapi import Request
...@@ -22,8 +22,7 @@ from vllm.entrypoints.pooling.embed.protocol import ( ...@@ -22,8 +22,7 @@ from vllm.entrypoints.pooling.embed.protocol import (
EmbeddingResponse, EmbeddingResponse,
EmbeddingResponseData, EmbeddingResponseData,
) )
from vllm.entrypoints.renderer import RenderConfig from vllm.inputs.data import EmbedsPrompt, TokensPrompt
from vllm.inputs.data import TokensPrompt
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.outputs import PoolingOutput, PoolingRequestOutput from vllm.outputs import PoolingOutput, PoolingRequestOutput
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
...@@ -37,7 +36,7 @@ from vllm.utils.serial_utils import ( ...@@ -37,7 +36,7 @@ from vllm.utils.serial_utils import (
logger = init_logger(__name__) logger = init_logger(__name__)
EmbeddingServeContext = ServeContext[EmbeddingRequest] EmbeddingServeContext: TypeAlias = ServeContext[EmbeddingRequest]
class OpenAIServingEmbedding(OpenAIServing): class OpenAIServingEmbedding(OpenAIServing):
...@@ -95,19 +94,16 @@ class OpenAIServingEmbedding(OpenAIServing): ...@@ -95,19 +94,16 @@ class OpenAIServingEmbedding(OpenAIServing):
_, ctx.engine_prompts = await self._preprocess_chat( _, ctx.engine_prompts = await self._preprocess_chat(
ctx.request, ctx.request,
self.renderer,
ctx.request.messages, ctx.request.messages,
chat_template=ctx.request.chat_template or self.chat_template, default_template=self.chat_template,
chat_template_content_format=self.chat_template_content_format, default_template_content_format=self.chat_template_content_format,
add_generation_prompt=ctx.request.add_generation_prompt, default_template_kwargs=None,
continue_final_message=ctx.request.continue_final_message,
add_special_tokens=ctx.request.add_special_tokens,
) )
elif isinstance(ctx.request, EmbeddingCompletionRequest): elif isinstance(ctx.request, EmbeddingCompletionRequest):
renderer = self._get_completion_renderer() ctx.engine_prompts = await self._preprocess_completion(
ctx.engine_prompts = await renderer.render_prompt( ctx.request,
prompt_or_prompts=ctx.request.input, prompt_input=ctx.request.input,
config=self._build_render_config(ctx.request), prompt_embeds=None,
) )
else: else:
return self.create_error_response("Invalid classification request type") return self.create_error_response("Invalid classification request type")
...@@ -117,19 +113,6 @@ class OpenAIServingEmbedding(OpenAIServing): ...@@ -117,19 +113,6 @@ class OpenAIServingEmbedding(OpenAIServing):
logger.exception("Error in preprocessing prompt inputs") logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(str(e)) return self.create_error_response(str(e))
def _build_render_config(self, request: EmbeddingCompletionRequest) -> RenderConfig:
# Set max_length based on chunked processing capability
if self._should_use_chunked_processing(request):
max_length = None
else:
max_length = self.max_embed_len or self.max_model_len
return RenderConfig(
max_length=max_length,
truncate_prompt_tokens=request.truncate_prompt_tokens,
add_special_tokens=request.add_special_tokens,
)
def _build_response( def _build_response(
self, self,
ctx: EmbeddingServeContext, ctx: EmbeddingServeContext,
...@@ -246,14 +229,18 @@ class OpenAIServingEmbedding(OpenAIServing): ...@@ -246,14 +229,18 @@ class OpenAIServingEmbedding(OpenAIServing):
lora_request=ctx.lora_request, lora_request=ctx.lora_request,
) )
tok_params = ctx.request.build_tok_params(self.model_config)
tokenization_kwargs = tok_params.get_encode_kwargs()
# Create generator for this chunk and wrap it to return indices # Create generator for this chunk and wrap it to return indices
original_generator = self.engine_client.encode( original_generator = self.engine_client.encode(
chunk_engine_prompt, chunk_engine_prompt,
pooling_params, pooling_params,
chunk_request_id, chunk_request_id,
lora_request=ctx.lora_request, lora_request=ctx.lora_request,
tokenization_kwargs=tokenization_kwargs,
trace_headers=trace_headers, trace_headers=trace_headers,
priority=getattr(ctx.request, "priority", 0), priority=ctx.request.priority,
) )
generators.append(original_generator) generators.append(original_generator)
...@@ -338,7 +325,7 @@ class OpenAIServingEmbedding(OpenAIServing): ...@@ -338,7 +325,7 @@ class OpenAIServingEmbedding(OpenAIServing):
async def _create_single_prompt_generator( async def _create_single_prompt_generator(
self, self,
ctx: EmbeddingServeContext, ctx: EmbeddingServeContext,
engine_prompt: TokensPrompt, engine_prompt: TokensPrompt | EmbedsPrompt,
pooling_params: PoolingParams, pooling_params: PoolingParams,
trace_headers: Mapping[str, str] | None, trace_headers: Mapping[str, str] | None,
prompt_index: int, prompt_index: int,
...@@ -353,23 +340,25 @@ class OpenAIServingEmbedding(OpenAIServing): ...@@ -353,23 +340,25 @@ class OpenAIServingEmbedding(OpenAIServing):
lora_request=ctx.lora_request, lora_request=ctx.lora_request,
) )
tok_params = ctx.request.build_tok_params(self.model_config)
tokenization_kwargs = tok_params.get_encode_kwargs()
# Return the original generator without wrapping # Return the original generator without wrapping
return self.engine_client.encode( return self.engine_client.encode(
engine_prompt, engine_prompt,
pooling_params, pooling_params,
request_id_item, request_id_item,
lora_request=ctx.lora_request, lora_request=ctx.lora_request,
tokenization_kwargs=tokenization_kwargs,
trace_headers=trace_headers, trace_headers=trace_headers,
priority=getattr(ctx.request, "priority", 0), priority=ctx.request.priority,
) )
async def _prepare_generators( async def _prepare_generators(
self, self,
ctx: ServeContext, ctx: EmbeddingServeContext,
) -> ErrorResponse | None: ) -> ErrorResponse | None:
"""Override to support chunked processing.""" """Override to support chunked processing."""
ctx = cast(EmbeddingServeContext, ctx)
# Check if we should use chunked processing # Check if we should use chunked processing
use_chunked = self._should_use_chunked_processing(ctx.request) use_chunked = self._should_use_chunked_processing(ctx.request)
...@@ -405,7 +394,8 @@ class OpenAIServingEmbedding(OpenAIServing): ...@@ -405,7 +394,8 @@ class OpenAIServingEmbedding(OpenAIServing):
for i, engine_prompt in enumerate(ctx.engine_prompts): for i, engine_prompt in enumerate(ctx.engine_prompts):
# Check if this specific prompt needs chunked processing # Check if this specific prompt needs chunked processing
if "prompt_token_ids" in engine_prompt: if "prompt_token_ids" in engine_prompt:
prompt_token_ids = engine_prompt["prompt_token_ids"] prompt_token_ids = engine_prompt["prompt_token_ids"] # type: ignore[typeddict-item]
if len(prompt_token_ids) > max_pos_embeddings: if len(prompt_token_ids) > max_pos_embeddings:
# Use chunked processing for this prompt # Use chunked processing for this prompt
chunk_generators = await self._process_chunked_request( chunk_generators = await self._process_chunked_request(
...@@ -573,7 +563,7 @@ class OpenAIServingEmbedding(OpenAIServing): ...@@ -573,7 +563,7 @@ class OpenAIServingEmbedding(OpenAIServing):
"token IDs" "token IDs"
) )
original_token_ids = original_prompt["prompt_token_ids"] original_token_ids = original_prompt["prompt_token_ids"] # type: ignore[typeddict-item]
pooling_request_output = PoolingRequestOutput( pooling_request_output = PoolingRequestOutput(
request_id=aggregator["request_id"], request_id=aggregator["request_id"],
......
...@@ -3,11 +3,10 @@ ...@@ -3,11 +3,10 @@
import time import time
from typing import Any, Generic, TypeAlias, TypeVar from typing import Any, Generic, TypeAlias, TypeVar
from pydantic import ( from pydantic import Field
Field,
)
from vllm import PoolingParams from vllm import PoolingParams
from vllm.config import ModelConfig
from vllm.config.pooler import get_use_activation from vllm.config.pooler import get_use_activation
from vllm.entrypoints.openai.engine.protocol import OpenAIBaseModel, UsageInfo from vllm.entrypoints.openai.engine.protocol import OpenAIBaseModel, UsageInfo
from vllm.entrypoints.pooling.base.protocol import ( from vllm.entrypoints.pooling.base.protocol import (
...@@ -18,6 +17,7 @@ from vllm.entrypoints.pooling.base.protocol import ( ...@@ -18,6 +17,7 @@ from vllm.entrypoints.pooling.base.protocol import (
EncodingRequestMixin, EncodingRequestMixin,
PoolingBasicRequestMixin, PoolingBasicRequestMixin,
) )
from vllm.renderers import TokenizeParams
from vllm.tasks import PoolingTask from vllm.tasks import PoolingTask
from vllm.utils import random_uuid from vllm.utils import random_uuid
...@@ -30,6 +30,18 @@ class PoolingCompletionRequest( ...@@ -30,6 +30,18 @@ class PoolingCompletionRequest(
): ):
task: PoolingTask | None = None task: PoolingTask | None = None
def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
    """Build the tokenization parameters for this pooling completion request.

    Args:
        model_config: Supplies ``max_model_len`` (used as the total token
            limit) and the optional ``encoder_config`` mapping.

    Returns:
        A ``TokenizeParams`` with ``max_output_tokens`` fixed at 0 and the
        request's own ``truncate_prompt_tokens`` / ``add_special_tokens``.
    """
    # ``encoder_config`` may be unset; lower-casing then defaults to False.
    lower_case = (model_config.encoder_config or {}).get("do_lower_case", False)

    return TokenizeParams(
        max_total_tokens=model_config.max_model_len,
        max_output_tokens=0,
        truncate_prompt_tokens=self.truncate_prompt_tokens,
        do_lower_case=lower_case,
        add_special_tokens=self.add_special_tokens,
        # presumably the parameter name reported to the user -- TODO confirm
        max_total_tokens_param="max_model_len",
    )
def to_pooling_params(self): def to_pooling_params(self):
return PoolingParams( return PoolingParams(
truncate_prompt_tokens=self.truncate_prompt_tokens, truncate_prompt_tokens=self.truncate_prompt_tokens,
...@@ -48,6 +60,18 @@ class PoolingChatRequest( ...@@ -48,6 +60,18 @@ class PoolingChatRequest(
description=("Additional kwargs to pass to the HF processor."), description=("Additional kwargs to pass to the HF processor."),
) )
def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
    """Build the tokenization parameters for this pooling chat request.

    Args:
        model_config: Supplies ``max_model_len`` (used as the total token
            limit) and the optional ``encoder_config`` mapping.

    Returns:
        A ``TokenizeParams`` with ``max_output_tokens`` fixed at 0 and the
        request's own ``truncate_prompt_tokens`` / ``add_special_tokens``.
    """
    # ``encoder_config`` may be unset; lower-casing then defaults to False.
    lower_case = (model_config.encoder_config or {}).get("do_lower_case", False)

    return TokenizeParams(
        max_total_tokens=model_config.max_model_len,
        max_output_tokens=0,
        truncate_prompt_tokens=self.truncate_prompt_tokens,
        do_lower_case=lower_case,
        add_special_tokens=self.add_special_tokens,
        # presumably the parameter name reported to the user -- TODO confirm
        max_total_tokens_param="max_model_len",
    )
def to_pooling_params(self): def to_pooling_params(self):
return PoolingParams( return PoolingParams(
truncate_prompt_tokens=self.truncate_prompt_tokens, truncate_prompt_tokens=self.truncate_prompt_tokens,
......
...@@ -5,7 +5,7 @@ import asyncio ...@@ -5,7 +5,7 @@ import asyncio
import json import json
import time import time
from collections.abc import AsyncGenerator, Sequence from collections.abc import AsyncGenerator, Sequence
from typing import Final, cast from typing import Any, Final, cast
import jinja2 import jinja2
from fastapi import Request from fastapi import Request
...@@ -14,10 +14,7 @@ from typing_extensions import assert_never ...@@ -14,10 +14,7 @@ from typing_extensions import assert_never
from vllm.engine.protocol import EngineClient from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.engine.protocol import ( from vllm.entrypoints.openai.engine.protocol import ErrorResponse, UsageInfo
ErrorResponse,
UsageInfo,
)
from vllm.entrypoints.openai.engine.serving import OpenAIServing from vllm.entrypoints.openai.engine.serving import OpenAIServing
from vllm.entrypoints.openai.models.serving import OpenAIServingModels from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.entrypoints.pooling.pooling.protocol import ( from vllm.entrypoints.pooling.pooling.protocol import (
...@@ -30,8 +27,6 @@ from vllm.entrypoints.pooling.pooling.protocol import ( ...@@ -30,8 +27,6 @@ from vllm.entrypoints.pooling.pooling.protocol import (
PoolingResponse, PoolingResponse,
PoolingResponseData, PoolingResponseData,
) )
from vllm.entrypoints.renderer import RenderConfig
from vllm.entrypoints.utils import _validate_truncation_size
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.outputs import PoolingRequestOutput from vllm.outputs import PoolingRequestOutput
from vllm.tasks import PoolingTask, SupportedTask from vllm.tasks import PoolingTask, SupportedTask
...@@ -99,11 +94,6 @@ class OpenAIServingPooling(OpenAIServing): ...@@ -99,11 +94,6 @@ class OpenAIServingPooling(OpenAIServing):
"dimensions is currently not supported" "dimensions is currently not supported"
) )
truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens", None)
truncate_prompt_tokens = _validate_truncation_size(
self.max_model_len, truncate_prompt_tokens
)
if is_io_processor_request: if is_io_processor_request:
if self.io_processor is None: if self.io_processor is None:
raise ValueError( raise ValueError(
...@@ -134,19 +124,16 @@ class OpenAIServingPooling(OpenAIServing): ...@@ -134,19 +124,16 @@ class OpenAIServingPooling(OpenAIServing):
_, engine_prompts = await self._preprocess_chat( _, engine_prompts = await self._preprocess_chat(
request, request,
self.renderer,
request.messages, request.messages,
chat_template=request.chat_template or self.chat_template, default_template=self.chat_template,
chat_template_content_format=self.chat_template_content_format, default_template_content_format=self.chat_template_content_format,
add_generation_prompt=request.add_generation_prompt, default_template_kwargs=None,
continue_final_message=request.continue_final_message,
add_special_tokens=request.add_special_tokens,
) )
elif isinstance(request, PoolingCompletionRequest): elif isinstance(request, PoolingCompletionRequest):
renderer = self._get_completion_renderer() engine_prompts = await self._preprocess_completion(
engine_prompts = await renderer.render_prompt( request,
prompt_or_prompts=request.input, prompt_input=request.input,
config=self._build_render_config(request), prompt_embeds=None,
) )
else: else:
raise ValueError(f"Unsupported request of type {type(request)}") raise ValueError(f"Unsupported request of type {type(request)}")
...@@ -207,11 +194,18 @@ class OpenAIServingPooling(OpenAIServing): ...@@ -207,11 +194,18 @@ class OpenAIServingPooling(OpenAIServing):
else await self._get_trace_headers(raw_request.headers) else await self._get_trace_headers(raw_request.headers)
) )
if is_io_processor_request:
tokenization_kwargs: dict[str, Any] = {}
else:
tok_params = request.build_tok_params(self.model_config) # type: ignore
tokenization_kwargs = tok_params.get_encode_kwargs()
generator = self.engine_client.encode( generator = self.engine_client.encode(
engine_prompt, engine_prompt,
pooling_params, pooling_params,
request_id_item, request_id_item,
lora_request=lora_request, lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
trace_headers=trace_headers, trace_headers=trace_headers,
priority=request.priority, priority=request.priority,
) )
...@@ -338,10 +332,3 @@ class OpenAIServingPooling(OpenAIServing): ...@@ -338,10 +332,3 @@ class OpenAIServingPooling(OpenAIServing):
return encode_bytes(bytes_only=encoding_format == "bytes_only") return encode_bytes(bytes_only=encoding_format == "bytes_only")
else: else:
assert_never(encoding_format) assert_never(encoding_format)
def _build_render_config(self, request: PoolingCompletionRequest) -> RenderConfig:
return RenderConfig(
max_length=self.max_model_len,
truncate_prompt_tokens=request.truncate_prompt_tokens,
add_special_tokens=request.add_special_tokens,
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment