Unverified Commit f0a1c845 authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Frontend] Use new Renderer for Completions and Tokenize API (#32863)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent 8980001c
...@@ -5,65 +5,10 @@ import pytest ...@@ -5,65 +5,10 @@ import pytest
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.inputs import zip_enc_dec_prompts from vllm.inputs import zip_enc_dec_prompts
from vllm.inputs.parse import parse_raw_prompts
from vllm.inputs.preprocess import InputPreprocessor from vllm.inputs.preprocess import InputPreprocessor
pytestmark = pytest.mark.cpu_test pytestmark = pytest.mark.cpu_test
STRING_INPUTS = [
"",
"foo",
"foo bar",
"foo baz bar",
"foo bar qux baz",
]
TOKEN_INPUTS = [
[-1],
[1],
[1, 2],
[1, 3, 4],
[1, 2, 4, 3],
]
INPUTS_SLICES = [
slice(None, None, -1),
slice(None, None, 2),
slice(None, None, -2),
]
# Test that a nested mixed-type list of lists raises a TypeError.
@pytest.mark.parametrize("invalid_input", [[[1, 2], ["foo", "bar"]]])
def test_invalid_input_raise_type_error(invalid_input):
with pytest.raises(TypeError):
parse_raw_prompts(invalid_input)
def test_parse_raw_single_batch_empty():
with pytest.raises(ValueError, match="at least one prompt"):
parse_raw_prompts([])
with pytest.raises(ValueError, match="at least one prompt"):
parse_raw_prompts([[]])
@pytest.mark.parametrize("string_input", STRING_INPUTS)
def test_parse_raw_single_batch_string_consistent(string_input: str):
assert parse_raw_prompts(string_input) == parse_raw_prompts([string_input])
@pytest.mark.parametrize("token_input", TOKEN_INPUTS)
def test_parse_raw_single_batch_token_consistent(token_input: list[int]):
assert parse_raw_prompts(token_input) == parse_raw_prompts([token_input])
@pytest.mark.parametrize("inputs_slice", INPUTS_SLICES)
def test_parse_raw_single_batch_string_slice(inputs_slice: slice):
assert parse_raw_prompts(STRING_INPUTS)[inputs_slice] == parse_raw_prompts(
STRING_INPUTS[inputs_slice]
)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"mm_processor_kwargs,expected_mm_kwargs", "mm_processor_kwargs,expected_mm_kwargs",
......
...@@ -768,7 +768,7 @@ class ModelConfig: ...@@ -768,7 +768,7 @@ class ModelConfig:
) )
self.tokenizer = object_storage_tokenizer.dir self.tokenizer = object_storage_tokenizer.dir
def _get_encoder_config(self): def _get_encoder_config(self) -> dict[str, Any] | None:
model = self.model model = self.model
if is_remote_gguf(model): if is_remote_gguf(model):
model, _ = split_remote_gguf(model) model, _ = split_remote_gguf(model)
...@@ -1918,7 +1918,7 @@ def _get_and_verify_max_len( ...@@ -1918,7 +1918,7 @@ def _get_and_verify_max_len(
disable_sliding_window: bool, disable_sliding_window: bool,
sliding_window: int | None, sliding_window: int | None,
spec_target_max_model_len: int | None = None, spec_target_max_model_len: int | None = None,
encoder_config: Any | None = None, encoder_config: dict[str, Any] | None = None,
) -> int: ) -> int:
"""Get and verify the model's maximum length.""" """Get and verify the model's maximum length."""
(derived_max_model_len, max_len_key) = ( (derived_max_model_len, max_len_key) = (
......
...@@ -72,14 +72,9 @@ class EngineClient(ABC): ...@@ -72,14 +72,9 @@ class EngineClient(ABC):
lora_request: LoRARequest | None = None, lora_request: LoRARequest | None = None,
trace_headers: Mapping[str, str] | None = None, trace_headers: Mapping[str, str] | None = None,
priority: int = 0, priority: int = 0,
truncate_prompt_tokens: int | None = None,
tokenization_kwargs: dict[str, Any] | None = None, tokenization_kwargs: dict[str, Any] | None = None,
) -> AsyncGenerator[PoolingRequestOutput, None]: ) -> AsyncGenerator[PoolingRequestOutput, None]:
"""Generate outputs for a request from a pooling model. """Generate outputs for a request from a pooling model."""
NOTE: truncate_prompt_tokens is deprecated in v0.14.
TODO: Remove this argument in v0.15.
"""
... ...
@abstractmethod @abstractmethod
......
...@@ -2,8 +2,9 @@ ...@@ -2,8 +2,9 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import itertools import itertools
import warnings
from collections.abc import Callable, Sequence from collections.abc import Callable, Sequence
from typing import TYPE_CHECKING, Any, cast from typing import TYPE_CHECKING, Any, TypeAlias, cast
import cloudpickle import cloudpickle
import torch.nn as nn import torch.nn as nn
...@@ -46,15 +47,17 @@ from vllm.entrypoints.pooling.score.utils import ( ...@@ -46,15 +47,17 @@ from vllm.entrypoints.pooling.score.utils import (
compress_token_type_ids, compress_token_type_ids,
get_score_prompt, get_score_prompt,
) )
from vllm.entrypoints.utils import _validate_truncation_size, log_non_default_args from vllm.entrypoints.utils import log_non_default_args
from vllm.inputs import ( from vllm.inputs import (
DataPrompt, DataPrompt,
EmbedsPrompt,
ExplicitEncoderDecoderPrompt,
PromptType, PromptType,
SingletonPrompt, SingletonPrompt,
TextPrompt, TextPrompt,
TokensPrompt, TokensPrompt,
) )
from vllm.inputs.parse import get_prompt_components from vllm.inputs.parse import get_prompt_components, is_explicit_encoder_decoder_prompt
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.quantization import QuantizationMethods from vllm.model_executor.layers.quantization import QuantizationMethods
...@@ -67,6 +70,7 @@ from vllm.outputs import ( ...@@ -67,6 +70,7 @@ from vllm.outputs import (
) )
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.renderers import ChatParams, TokenizeParams, merge_kwargs
from vllm.sampling_params import BeamSearchParams, RequestOutputKind, SamplingParams from vllm.sampling_params import BeamSearchParams, RequestOutputKind, SamplingParams
from vllm.tasks import PoolingTask from vllm.tasks import PoolingTask
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
...@@ -74,7 +78,6 @@ from vllm.tokenizers.mistral import MistralTokenizer ...@@ -74,7 +78,6 @@ from vllm.tokenizers.mistral import MistralTokenizer
from vllm.usage.usage_lib import UsageContext from vllm.usage.usage_lib import UsageContext
from vllm.utils.collection_utils import as_iter, is_list_of from vllm.utils.collection_utils import as_iter, is_list_of
from vllm.utils.counter import Counter from vllm.utils.counter import Counter
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.llm_engine import LLMEngine from vllm.v1.engine.llm_engine import LLMEngine
from vllm.v1.sample.logits_processor import LogitsProcessor from vllm.v1.sample.logits_processor import LogitsProcessor
...@@ -85,6 +88,9 @@ logger = init_logger(__name__) ...@@ -85,6 +88,9 @@ logger = init_logger(__name__)
_R = TypeVar("_R", default=Any) _R = TypeVar("_R", default=Any)
EnginePrompt: TypeAlias = TextPrompt | TokensPrompt | EmbedsPrompt
EngineEncDecPrompt: TypeAlias = ExplicitEncoderDecoderPrompt[EnginePrompt, EnginePrompt]
class LLM: class LLM:
"""An LLM for generating texts from given prompts and sampling parameters. """An LLM for generating texts from given prompts and sampling parameters.
...@@ -372,6 +378,7 @@ class LLM: ...@@ -372,6 +378,7 @@ class LLM:
use_tqdm: bool | Callable[..., tqdm] = True, use_tqdm: bool | Callable[..., tqdm] = True,
lora_request: list[LoRARequest] | LoRARequest | None = None, lora_request: list[LoRARequest] | LoRARequest | None = None,
priority: list[int] | None = None, priority: list[int] | None = None,
tokenization_kwargs: dict[str, Any] | None = None,
) -> list[RequestOutput]: ) -> list[RequestOutput]:
"""Generates the completions for the input prompts. """Generates the completions for the input prompts.
...@@ -398,15 +405,11 @@ class LLM: ...@@ -398,15 +405,11 @@ class LLM:
If provided, must be a list of integers matching the length If provided, must be a list of integers matching the length
of `prompts`, where each priority value corresponds to the prompt of `prompts`, where each priority value corresponds to the prompt
at the same index. at the same index.
tokenization_kwargs: Overrides for `tokenizer.encode`.
Returns: Returns:
A list of `RequestOutput` objects containing the A list of `RequestOutput` objects containing the
generated completions in the same order as the input prompts. generated completions in the same order as the input prompts.
Note:
Using `prompts` and `prompt_token_ids` as keyword parameters is
considered legacy and may be deprecated in the future. You should
instead pass them via the `inputs` parameter.
""" """
model_config = self.model_config model_config = self.model_config
runner_type = model_config.runner_type runner_type = model_config.runner_type
...@@ -418,17 +421,14 @@ class LLM: ...@@ -418,17 +421,14 @@ class LLM:
) )
if sampling_params is None: if sampling_params is None:
# Use default sampling params.
sampling_params = self.get_default_sampling_params() sampling_params = self.get_default_sampling_params()
# Add any modality specific loras to the corresponding prompts
lora_request = self._get_modality_specific_lora_reqs(prompts, lora_request)
self._validate_and_add_requests( self._validate_and_add_requests(
prompts=prompts, prompts=prompts,
params=sampling_params, params=sampling_params,
use_tqdm=use_tqdm, use_tqdm=use_tqdm,
lora_request=lora_request, lora_request=self._get_modality_specific_lora_reqs(prompts, lora_request),
tokenization_kwargs=tokenization_kwargs,
priority=priority, priority=priority,
) )
...@@ -771,65 +771,169 @@ class LLM: ...@@ -771,65 +771,169 @@ class LLM:
return outputs return outputs
def preprocess_chat( def _get_cmpl_tok_params(self, tokenization_kwargs: dict[str, Any] | None):
model_config = self.model_config
encoder_config = model_config.encoder_config or {}
return TokenizeParams(
max_total_tokens=model_config.max_model_len,
do_lower_case=encoder_config.get("do_lower_case", False),
# For Whisper, special tokens should be provided by the user based
# on the task and language of their request. Also needed to avoid
# appending an EOS token to the prompt which disrupts generation.
add_special_tokens=not model_config.is_encoder_decoder,
).with_kwargs(tokenization_kwargs)
def _normalize_prompts(
self, self,
messages: list[ChatCompletionMessageParam] prompts: PromptType | Sequence[PromptType],
) -> list[EnginePrompt | EngineEncDecPrompt]:
if isinstance(prompts, str):
prompts = TextPrompt(prompt=prompts)
return prompts if isinstance(prompts, Sequence) else [prompts] # type: ignore[return-value]
def _preprocess_cmpl_singleton(
self,
prompt: SingletonPrompt,
tok_params: TokenizeParams,
*,
tokenize: bool,
) -> EnginePrompt:
renderer = self.llm_engine.renderer
if not isinstance(prompt, dict):
prompt = renderer.render_completion(prompt)
return renderer.tokenize_prompt(prompt, tok_params) if tokenize else prompt
def _preprocess_cmpl_enc_dec(
self,
prompt: ExplicitEncoderDecoderPrompt,
tok_params: TokenizeParams,
) -> EngineEncDecPrompt:
enc_prompt = prompt["encoder_prompt"]
dec_prompt = prompt["decoder_prompt"]
return EngineEncDecPrompt(
encoder_prompt=self._preprocess_cmpl_singleton(
enc_prompt,
tok_params,
# TODO: Move multi-modal processor into tokenization
tokenize=not self.model_config.is_multimodal_model,
),
decoder_prompt=(
None
if dec_prompt is None
else self._preprocess_cmpl_singleton(
dec_prompt,
tok_params,
# TODO: Move multi-modal processor into tokenization
tokenize=not self.model_config.is_multimodal_model,
)
),
)
def _preprocess_completion(
self,
prompts: PromptType | Sequence[PromptType],
tokenization_kwargs: dict[str, Any] | None = None,
) -> list[EnginePrompt | EngineEncDecPrompt]:
"""
Convert prompt inputs from LLM APIs (other than [LLM.chat][]) into
a format that can be passed to `_add_request`.
Refer to [LLM.generate][] for a complete description of the arguments.
Returns:
A list of `TokensPrompts` objects containing the tokenized prompt
after chat template interpolation, and the raw multi-modal inputs.
"""
tok_params = self._get_cmpl_tok_params(tokenization_kwargs)
engine_prompts = list[EnginePrompt | EngineEncDecPrompt]()
for prompt in self._normalize_prompts(prompts):
if is_explicit_encoder_decoder_prompt(prompt):
engine_prompts.append(self._preprocess_cmpl_enc_dec(prompt, tok_params))
else:
# Some MM models have non-default `add_special_tokens`
# TODO: Move multi-modal processor into tokenization
engine_prompts.append(
self._preprocess_cmpl_singleton(
prompt,
tok_params,
tokenize=not self.model_config.is_multimodal_model,
)
)
return engine_prompts
def _normalize_conversations(
self,
conversations: list[ChatCompletionMessageParam]
| list[list[ChatCompletionMessageParam]],
) -> list[list[ChatCompletionMessageParam]]:
return conversations if is_list_of(conversations, list) else [conversations] # type: ignore[list-item,return-value]
def _get_chat_tok_params(self, tokenization_kwargs: dict[str, Any] | None):
model_config = self.model_config
encoder_config = model_config.encoder_config or {}
return TokenizeParams(
max_total_tokens=model_config.max_model_len,
do_lower_case=encoder_config.get("do_lower_case", False),
add_special_tokens=False,
).with_kwargs(tokenization_kwargs)
def _preprocess_chat(
self,
conversations: list[ChatCompletionMessageParam]
| list[list[ChatCompletionMessageParam]], | list[list[ChatCompletionMessageParam]],
chat_template: str | None = None, chat_template: str | None = None,
chat_template_content_format: ChatTemplateContentFormatOption = "auto", chat_template_content_format: ChatTemplateContentFormatOption = "auto",
chat_template_kwargs: dict[str, Any] | None = None,
add_generation_prompt: bool = True, add_generation_prompt: bool = True,
continue_final_message: bool = False, continue_final_message: bool = False,
tools: list[dict[str, Any]] | None = None, tools: list[dict[str, Any]] | None = None,
chat_template_kwargs: dict[str, Any] | None = None, tokenization_kwargs: dict[str, Any] | None = None,
mm_processor_kwargs: dict[str, Any] | None = None, mm_processor_kwargs: dict[str, Any] | None = None,
) -> list[TextPrompt | TokensPrompt]: ) -> list[EnginePrompt]:
""" """
Generate prompt for a chat conversation. The pre-processed Convert a list of conversations into prompts so that they can then
prompt can then be used as input for the other LLM methods. be used as input for other LLM APIs.
Refer to [LLM.chat][] for a complete description of the arguments.
Refer to `chat` for a complete description of the arguments.
Returns: Returns:
A list of `TokensPrompts` objects containing the tokenized A list of `TokensPrompts` objects containing the tokenized prompt
prompt after chat template interpolation, and the after chat template interpolation, and the raw multi-modal inputs.
pre-processed multi-modal inputs.
""" """
list_of_messages: list[list[ChatCompletionMessageParam]]
# Handle multi and single conversations
if is_list_of(messages, list):
# messages is list[list[...]]
list_of_messages = cast(list[list[ChatCompletionMessageParam]], messages)
else:
# messages is list[...]
list_of_messages = [cast(list[ChatCompletionMessageParam], messages)]
renderer = self.llm_engine.renderer renderer = self.llm_engine.renderer
chat_template_kwargs = { chat_params = ChatParams(
"chat_template": chat_template, chat_template=chat_template,
"add_generation_prompt": add_generation_prompt, chat_template_content_format=chat_template_content_format,
"continue_final_message": continue_final_message, chat_template_kwargs=merge_kwargs(
"tools": tools, chat_template_kwargs,
**(chat_template_kwargs or {}), dict(
} add_generation_prompt=add_generation_prompt,
continue_final_message=continue_final_message,
prompts = list[TextPrompt | TokensPrompt]() tools=tools,
tokenize=isinstance(renderer.tokenizer, MistralTokenizer),
for msgs in list_of_messages: ),
# NOTE: renderer.render_messages() currently doesn't ),
# handle mm_processor_kwargs, since there is no implementation in )
# the chat message parsing for it. tok_params = self._get_chat_tok_params(tokenization_kwargs)
_, prompt = renderer.render_messages(
msgs, engine_prompts = list[EnginePrompt]()
chat_template_content_format=chat_template_content_format, for conversation in self._normalize_conversations(conversations):
**chat_template_kwargs, _, in_prompt = renderer.render_messages(conversation, chat_params)
)
if mm_processor_kwargs is not None: if mm_processor_kwargs is not None:
prompt["mm_processor_kwargs"] = mm_processor_kwargs in_prompt["mm_processor_kwargs"] = mm_processor_kwargs
prompts.append(prompt) engine_prompts.append(renderer.tokenize_prompt(in_prompt, tok_params))
return prompts return engine_prompts
def chat( def chat(
self, self,
...@@ -844,6 +948,7 @@ class LLM: ...@@ -844,6 +948,7 @@ class LLM:
continue_final_message: bool = False, continue_final_message: bool = False,
tools: list[dict[str, Any]] | None = None, tools: list[dict[str, Any]] | None = None,
chat_template_kwargs: dict[str, Any] | None = None, chat_template_kwargs: dict[str, Any] | None = None,
tokenization_kwargs: dict[str, Any] | None = None,
mm_processor_kwargs: dict[str, Any] | None = None, mm_processor_kwargs: dict[str, Any] | None = None,
) -> list[RequestOutput]: ) -> list[RequestOutput]:
""" """
...@@ -889,22 +994,22 @@ class LLM: ...@@ -889,22 +994,22 @@ class LLM:
`True` if `add_generation_prompt` is also `True`. `True` if `add_generation_prompt` is also `True`.
chat_template_kwargs: Additional kwargs to pass to the chat chat_template_kwargs: Additional kwargs to pass to the chat
template. template.
mm_processor_kwargs: Multimodal processor kwarg overrides for this tokenization_kwargs: Overrides for `tokenizer.encode`.
chat request. Only used for offline requests. mm_processor_kwargs: Overrides for `processor.__call__`.
Returns: Returns:
A list of `RequestOutput` objects containing the generated A list of `RequestOutput` objects containing the generated
responses in the same order as the input messages. responses in the same order as the input messages.
""" """
prompts = self._preprocess_chat(
prompts = self.preprocess_chat( messages,
messages=messages,
chat_template=chat_template, chat_template=chat_template,
chat_template_content_format=chat_template_content_format, chat_template_content_format=chat_template_content_format,
chat_template_kwargs=chat_template_kwargs,
add_generation_prompt=add_generation_prompt, add_generation_prompt=add_generation_prompt,
continue_final_message=continue_final_message, continue_final_message=continue_final_message,
tools=tools, tools=tools,
chat_template_kwargs=chat_template_kwargs, tokenization_kwargs=tokenization_kwargs,
mm_processor_kwargs=mm_processor_kwargs, mm_processor_kwargs=mm_processor_kwargs,
) )
...@@ -913,6 +1018,7 @@ class LLM: ...@@ -913,6 +1018,7 @@ class LLM:
sampling_params=sampling_params, sampling_params=sampling_params,
use_tqdm=use_tqdm, use_tqdm=use_tqdm,
lora_request=lora_request, lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
) )
def encode( def encode(
...@@ -945,37 +1051,29 @@ class LLM: ...@@ -945,37 +1051,29 @@ class LLM:
If `False`, no progress bar is created. If `False`, no progress bar is created.
lora_request: LoRA request to use for generation, if any. lora_request: LoRA request to use for generation, if any.
pooling_task: Override the pooling task to use. pooling_task: Override the pooling task to use.
tokenization_kwargs: overrides tokenization_kwargs set in tokenization_kwargs: Overrides for `tokenizer.encode`.
pooling_params
Returns: Returns:
A list of `PoolingRequestOutput` objects containing the A list of `PoolingRequestOutput` objects containing the
pooled hidden states in the same order as the input prompts. pooled hidden states in the same order as the input prompts.
Note:
Using `prompts` and `prompt_token_ids` as keyword parameters is
considered legacy and may be deprecated in the future. You should
instead pass them via the `inputs` parameter.
""" """
error_str = (
"pooling_task required for `LLM.encode`\n"
"Please use one of the more specific methods or set the "
"pooling_task when using `LLM.encode`:\n"
" - For embeddings, use `LLM.embed(...)` "
'or `pooling_task="embed"`.\n'
" - For classification logits, use `LLM.classify(...)` "
'or `pooling_task="classify"`.\n'
" - For similarity scores, use `LLM.score(...)`.\n"
" - For rewards, use `LLM.reward(...)` "
'or `pooling_task="token_classify"`\n'
" - For token classification, "
'use `pooling_task="token_classify"`\n'
' - For multi-vector retrieval, use `pooling_task="token_embed"`'
)
if pooling_task is None: if pooling_task is None:
raise ValueError(error_str) raise ValueError(
"pooling_task required for `LLM.encode`\n"
"Please use one of the more specific methods or set the "
"pooling_task when using `LLM.encode`:\n"
" - For embeddings, use `LLM.embed(...)` "
'or `pooling_task="embed"`.\n'
" - For classification logits, use `LLM.classify(...)` "
'or `pooling_task="classify"`.\n'
" - For similarity scores, use `LLM.score(...)`.\n"
" - For rewards, use `LLM.reward(...)` "
'or `pooling_task="token_classify"`\n'
" - For token classification, "
'use `pooling_task="token_classify"`\n'
' - For multi-vector retrieval, use `pooling_task="token_embed"`'
)
model_config = self.model_config model_config = self.model_config
runner_type = model_config.runner_type runner_type = model_config.runner_type
...@@ -986,6 +1084,20 @@ class LLM: ...@@ -986,6 +1084,20 @@ class LLM:
"pooling model." "pooling model."
) )
if truncate_prompt_tokens is not None:
warnings.warn(
"The `truncate_prompt_tokens` parameter in `LLM.encode()` "
"is deprecated and will be removed in v0.16. "
"Please pass it via `tokenization_kwargs` instead.",
DeprecationWarning,
stacklevel=2,
)
tokenization_kwargs = merge_kwargs(
tokenization_kwargs,
dict(truncate_prompt_tokens=truncate_prompt_tokens),
)
io_processor_prompt = False io_processor_prompt = False
if isinstance(prompts, dict) and "data" in prompts: if isinstance(prompts, dict) and "data" in prompts:
io_processor_prompt = True io_processor_prompt = True
...@@ -1017,19 +1129,16 @@ class LLM: ...@@ -1017,19 +1129,16 @@ class LLM:
pooling_params = self.io_processor.validate_or_generate_params( pooling_params = self.io_processor.validate_or_generate_params(
pooling_params pooling_params
) )
else:
if pooling_params is None: if pooling_params is None:
# Use default pooling params. # Use default pooling params.
pooling_params = PoolingParams() pooling_params = PoolingParams()
if pooling_task not in self.supported_tasks: if pooling_task not in self.supported_tasks:
raise ValueError(f"pooling_task must be one of {self.supported_tasks}.") raise ValueError(f"pooling_task must be one of {self.supported_tasks}.")
for param in as_iter(pooling_params): for param in as_iter(pooling_params):
param.verify(pooling_task, model_config) param.verify(pooling_task, model_config)
# for backwards compatibility
if truncate_prompt_tokens is not None:
param.truncate_prompt_tokens = truncate_prompt_tokens
self._validate_and_add_requests( self._validate_and_add_requests(
prompts=prompts, prompts=prompts,
...@@ -1094,6 +1203,7 @@ class LLM: ...@@ -1094,6 +1203,7 @@ class LLM:
it is used to create the progress bar. it is used to create the progress bar.
If `False`, no progress bar is created. If `False`, no progress bar is created.
lora_request: LoRA request to use for generation, if any. lora_request: LoRA request to use for generation, if any.
tokenization_kwargs: Overrides for `tokenizer.encode`.
Returns: Returns:
A list of `EmbeddingRequestOutput` objects containing the A list of `EmbeddingRequestOutput` objects containing the
...@@ -1105,9 +1215,14 @@ class LLM: ...@@ -1105,9 +1215,14 @@ class LLM:
"Try converting the model using `--convert embed`." "Try converting the model using `--convert embed`."
) )
if truncate_prompt_tokens is not None:
tokenization_kwargs = merge_kwargs(
tokenization_kwargs,
dict(truncate_prompt_tokens=truncate_prompt_tokens),
)
items = self.encode( items = self.encode(
prompts, prompts,
truncate_prompt_tokens=truncate_prompt_tokens,
use_tqdm=use_tqdm, use_tqdm=use_tqdm,
pooling_params=pooling_params, pooling_params=pooling_params,
lora_request=lora_request, lora_request=lora_request,
...@@ -1121,8 +1236,8 @@ class LLM: ...@@ -1121,8 +1236,8 @@ class LLM:
self, self,
prompts: PromptType | Sequence[PromptType], prompts: PromptType | Sequence[PromptType],
*, *,
use_tqdm: bool | Callable[..., tqdm] = True,
pooling_params: PoolingParams | Sequence[PoolingParams] | None = None, pooling_params: PoolingParams | Sequence[PoolingParams] | None = None,
use_tqdm: bool | Callable[..., tqdm] = True,
lora_request: list[LoRARequest] | LoRARequest | None = None, lora_request: list[LoRARequest] | LoRARequest | None = None,
tokenization_kwargs: dict[str, Any] | None = None, tokenization_kwargs: dict[str, Any] | None = None,
) -> list[ClassificationRequestOutput]: ) -> list[ClassificationRequestOutput]:
...@@ -1137,13 +1252,15 @@ class LLM: ...@@ -1137,13 +1252,15 @@ class LLM:
prompts: The prompts to the LLM. You may pass a sequence of prompts prompts: The prompts to the LLM. You may pass a sequence of prompts
for batch inference. See [PromptType][vllm.inputs.PromptType] for batch inference. See [PromptType][vllm.inputs.PromptType]
for more details about the format of each prompt. for more details about the format of each prompt.
pooling_params: The pooling parameters for pooling. If None, we
use the default pooling parameters.
use_tqdm: If `True`, shows a tqdm progress bar. use_tqdm: If `True`, shows a tqdm progress bar.
If a callable (e.g., `functools.partial(tqdm, leave=False)`), If a callable (e.g., `functools.partial(tqdm, leave=False)`),
it is used to create the progress bar. it is used to create the progress bar.
If `False`, no progress bar is created. If `False`, no progress bar is created.
lora_request: LoRA request to use for generation, if any. lora_request: LoRA request to use for generation, if any.
pooling_params: The pooling parameters for pooling. If None, we tokenization_kwargs: Overrides for `tokenizer.encode`.
use the default pooling parameters.
Returns: Returns:
A list of `ClassificationRequestOutput` objects containing the A list of `ClassificationRequestOutput` objects containing the
embedding vectors in the same order as the input prompts. embedding vectors in the same order as the input prompts.
...@@ -1170,9 +1287,9 @@ class LLM: ...@@ -1170,9 +1287,9 @@ class LLM:
prompts: PromptType | Sequence[PromptType], prompts: PromptType | Sequence[PromptType],
/, /,
*, *,
pooling_params: PoolingParams | Sequence[PoolingParams] | None = None,
truncate_prompt_tokens: int | None = None, truncate_prompt_tokens: int | None = None,
use_tqdm: bool | Callable[..., tqdm] = True, use_tqdm: bool | Callable[..., tqdm] = True,
pooling_params: PoolingParams | Sequence[PoolingParams] | None = None,
lora_request: list[LoRARequest] | LoRARequest | None = None, lora_request: list[LoRARequest] | LoRARequest | None = None,
tokenization_kwargs: dict[str, Any] | None = None, tokenization_kwargs: dict[str, Any] | None = None,
) -> list[PoolingRequestOutput]: ) -> list[PoolingRequestOutput]:
...@@ -1183,13 +1300,15 @@ class LLM: ...@@ -1183,13 +1300,15 @@ class LLM:
prompts: The prompts to the LLM. You may pass a sequence of prompts prompts: The prompts to the LLM. You may pass a sequence of prompts
for batch inference. See [PromptType][vllm.inputs.PromptType] for batch inference. See [PromptType][vllm.inputs.PromptType]
for more details about the format of each prompt. for more details about the format of each prompt.
pooling_params: The pooling parameters for pooling. If None, we
use the default pooling parameters.
use_tqdm: If `True`, shows a tqdm progress bar. use_tqdm: If `True`, shows a tqdm progress bar.
If a callable (e.g., `functools.partial(tqdm, leave=False)`), If a callable (e.g., `functools.partial(tqdm, leave=False)`),
it is used to create the progress bar. it is used to create the progress bar.
If `False`, no progress bar is created. If `False`, no progress bar is created.
lora_request: LoRA request to use for generation, if any. lora_request: LoRA request to use for generation, if any.
pooling_params: The pooling parameters for pooling. If None, we tokenization_kwargs: Overrides for `tokenizer.encode`.
use the default pooling parameters.
Returns: Returns:
A list of `PoolingRequestOutput` objects containing the A list of `PoolingRequestOutput` objects containing the
pooled hidden states in the same order as the input prompts. pooled hidden states in the same order as the input prompts.
...@@ -1207,18 +1326,18 @@ class LLM: ...@@ -1207,18 +1326,18 @@ class LLM:
def _embedding_score( def _embedding_score(
self, self,
tokenizer: TokenizerLike, text_1: list[SingletonPrompt],
text_1: list[str | TextPrompt | TokensPrompt], text_2: list[SingletonPrompt],
text_2: list[str | TextPrompt | TokensPrompt], *,
truncate_prompt_tokens: int | None = None, use_tqdm: bool | Callable[..., tqdm],
use_tqdm: bool | Callable[..., tqdm] = True, pooling_params: PoolingParams | None,
pooling_params: PoolingParams | None = None, lora_request: list[LoRARequest] | LoRARequest | None,
lora_request: list[LoRARequest] | LoRARequest | None = None, tokenization_kwargs: dict[str, Any],
tokenization_kwargs: dict[str, Any] | None = None,
) -> list[ScoringRequestOutput]: ) -> list[ScoringRequestOutput]:
encoded_output: list[PoolingRequestOutput] = self.encode( tokenizer = self.get_tokenizer()
encoded_output = self.encode(
text_1 + text_2, text_1 + text_2,
truncate_prompt_tokens=truncate_prompt_tokens,
use_tqdm=use_tqdm, use_tqdm=use_tqdm,
lora_request=lora_request, lora_request=lora_request,
pooling_params=pooling_params, pooling_params=pooling_params,
...@@ -1226,14 +1345,16 @@ class LLM: ...@@ -1226,14 +1345,16 @@ class LLM:
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
) )
encoded_output_1: list[PoolingRequestOutput] = encoded_output[0 : len(text_1)] encoded_output_1 = encoded_output[0 : len(text_1)]
encoded_output_2: list[PoolingRequestOutput] = encoded_output[len(text_1) :] encoded_output_2 = encoded_output[len(text_1) :]
if len(encoded_output_1) == 1: if len(encoded_output_1) == 1:
encoded_output_1 = encoded_output_1 * len(encoded_output_2) encoded_output_1 = encoded_output_1 * len(encoded_output_2)
scores = _cosine_similarity( scores = _cosine_similarity(
tokenizer=tokenizer, embed_1=encoded_output_1, embed_2=encoded_output_2 tokenizer=tokenizer,
embed_1=encoded_output_1,
embed_2=encoded_output_2,
) )
items = self.engine_class.validate_outputs(scores, PoolingRequestOutput) items = self.engine_class.validate_outputs(scores, PoolingRequestOutput)
...@@ -1241,17 +1362,17 @@ class LLM: ...@@ -1241,17 +1362,17 @@ class LLM:
def _cross_encoding_score( def _cross_encoding_score(
self, self,
tokenizer: TokenizerLike,
data_1: list[str] | list[ScoreContentPartParam], data_1: list[str] | list[ScoreContentPartParam],
data_2: list[str] | list[ScoreContentPartParam], data_2: list[str] | list[ScoreContentPartParam],
truncate_prompt_tokens: int | None = None, *,
use_tqdm: bool | Callable[..., tqdm] = True, use_tqdm: bool | Callable[..., tqdm],
pooling_params: PoolingParams | None = None, pooling_params: PoolingParams | None,
lora_request: list[LoRARequest] | LoRARequest | None = None, lora_request: list[LoRARequest] | LoRARequest | None,
tokenization_kwargs: dict[str, Any] | None = None, tokenization_kwargs: dict[str, Any],
score_template: str | None = None, score_template: str | None,
) -> list[ScoringRequestOutput]: ) -> list[ScoringRequestOutput]:
model_config = self.model_config model_config = self.model_config
tokenizer = self.get_tokenizer()
if isinstance(tokenizer, MistralTokenizer): if isinstance(tokenizer, MistralTokenizer):
raise ValueError("Score API is not supported for Mistral tokenizer") raise ValueError("Score API is not supported for Mistral tokenizer")
...@@ -1265,13 +1386,6 @@ class LLM: ...@@ -1265,13 +1386,6 @@ class LLM:
pooling_params.verify("score", model_config) pooling_params.verify("score", model_config)
pooling_params_list = list[PoolingParams]() pooling_params_list = list[PoolingParams]()
local_kwargs = tokenization_kwargs or {}
tokenization_kwargs = local_kwargs.copy()
_validate_truncation_size(
model_config.max_model_len, truncate_prompt_tokens, tokenization_kwargs
)
prompts = list[PromptType]() prompts = list[PromptType]()
input_pairs = [(t1, t2) for t1, t2 in zip(data_1, data_2)] input_pairs = [(t1, t2) for t1, t2 in zip(data_1, data_2)]
...@@ -1314,10 +1428,10 @@ class LLM: ...@@ -1314,10 +1428,10 @@ class LLM:
data_2: SingletonPrompt | Sequence[SingletonPrompt] | ScoreMultiModalParam, data_2: SingletonPrompt | Sequence[SingletonPrompt] | ScoreMultiModalParam,
/, /,
*, *,
truncate_prompt_tokens: int | None = None,
use_tqdm: bool | Callable[..., tqdm] = True, use_tqdm: bool | Callable[..., tqdm] = True,
pooling_params: PoolingParams | None = None, pooling_params: PoolingParams | None = None,
lora_request: list[LoRARequest] | LoRARequest | None = None, lora_request: list[LoRARequest] | LoRARequest | None = None,
tokenization_kwargs: dict[str, Any] | None = None,
chat_template: str | None = None, chat_template: str | None = None,
) -> list[ScoringRequestOutput]: ) -> list[ScoringRequestOutput]:
"""Generate similarity scores for all pairs `<text,text_pair>` or """Generate similarity scores for all pairs `<text,text_pair>` or
...@@ -1344,20 +1458,22 @@ class LLM: ...@@ -1344,20 +1458,22 @@ class LLM:
the LLM. Can be text or multi-modal data. See [PromptType] the LLM. Can be text or multi-modal data. See [PromptType]
[vllm.inputs.PromptType] for more details about the format of [vllm.inputs.PromptType] for more details about the format of
each prompt. each prompt.
pooling_params: The pooling parameters for pooling. If None, we
use the default pooling parameters.
use_tqdm: If `True`, shows a tqdm progress bar. use_tqdm: If `True`, shows a tqdm progress bar.
If a callable (e.g., `functools.partial(tqdm, leave=False)`), If a callable (e.g., `functools.partial(tqdm, leave=False)`),
it is used to create the progress bar. it is used to create the progress bar.
If `False`, no progress bar is created. If `False`, no progress bar is created.
lora_request: LoRA request to use for generation, if any. lora_request: LoRA request to use for generation, if any.
pooling_params: The pooling parameters for pooling. If None, we
use the default pooling parameters.
chat_template: The chat template to use for the scoring. If None, we chat_template: The chat template to use for the scoring. If None, we
use the model's default chat template. use the model's default chat template.
tokenization_kwargs: Overrides for `tokenizer.encode`.
Returns: Returns:
A list of `ScoringRequestOutput` objects containing the A list of `ScoringRequestOutput` objects containing the
generated scores in the same order as the input prompts. generated scores in the same order as the input prompts.
""" """
model_config = self.model_config model_config = self.model_config
runner_type = model_config.runner_type runner_type = model_config.runner_type
if runner_type != "pooling": if runner_type != "pooling":
raise ValueError( raise ValueError(
...@@ -1445,26 +1561,27 @@ class LLM: ...@@ -1445,26 +1561,27 @@ class LLM:
_validate_score_input_lens(data_1, data_2) # type: ignore[arg-type] _validate_score_input_lens(data_1, data_2) # type: ignore[arg-type]
tok_params = self._get_cmpl_tok_params(tokenization_kwargs)
encode_kwargs = tok_params.get_encode_kwargs()
if model_config.is_cross_encoder: if model_config.is_cross_encoder:
return self._cross_encoding_score( return self._cross_encoding_score(
tokenizer,
data_1, # type: ignore[arg-type] data_1, # type: ignore[arg-type]
data_2, # type: ignore[arg-type] data_2, # type: ignore[arg-type]
truncate_prompt_tokens, use_tqdm=use_tqdm,
use_tqdm, pooling_params=pooling_params,
pooling_params, lora_request=lora_request,
lora_request, tokenization_kwargs=encode_kwargs,
score_template=chat_template, score_template=chat_template,
) )
else: else:
return self._embedding_score( return self._embedding_score(
tokenizer,
data_1, # type: ignore[arg-type] data_1, # type: ignore[arg-type]
data_2, # type: ignore[arg-type] data_2, # type: ignore[arg-type]
truncate_prompt_tokens, use_tqdm=use_tqdm,
use_tqdm, pooling_params=pooling_params,
pooling_params, lora_request=lora_request,
lora_request, tokenization_kwargs=encode_kwargs,
) )
def start_profile(self) -> None: def start_profile(self) -> None:
...@@ -1530,42 +1647,79 @@ class LLM: ...@@ -1530,42 +1647,79 @@ class LLM:
def _validate_and_add_requests( def _validate_and_add_requests(
self, self,
prompts: PromptType | Sequence[PromptType] | DataPrompt, prompts: PromptType | Sequence[PromptType],
params: SamplingParams params: SamplingParams
| Sequence[SamplingParams] | Sequence[SamplingParams]
| PoolingParams | PoolingParams
| Sequence[PoolingParams], | Sequence[PoolingParams],
*, *,
use_tqdm: bool | Callable[..., tqdm] = True, use_tqdm: bool | Callable[..., tqdm] = True,
lora_request: Sequence[LoRARequest] | LoRARequest | None, lora_request: Sequence[LoRARequest | None] | LoRARequest | None,
priority: list[int] | None = None,
tokenization_kwargs: dict[str, Any] | None = None, tokenization_kwargs: dict[str, Any] | None = None,
priority: list[int] | None = None,
) -> None: ) -> None:
if isinstance(prompts, (str, dict)): in_prompts = self._normalize_prompts(prompts)
# Convert a single prompt to a list. num_requests = len(in_prompts)
prompts = [prompts] # type: ignore[list-item]
if isinstance(params, Sequence):
num_requests = len(prompts) if len(params) != num_requests:
if isinstance(params, Sequence) and len(params) != num_requests: raise ValueError(
raise ValueError("The lengths of prompts and params must be the same.") f"The lengths of prompts ({params}) "
if isinstance(lora_request, Sequence) and len(lora_request) != num_requests: f"and lora_request ({len(params)}) must be the same."
raise ValueError( )
"The lengths of prompts and lora_request must be the same."
) engine_params = params
if priority is not None and len(priority) != num_requests: else:
raise ValueError( engine_params = [params] * num_requests
"The lengths of prompts "
f"({num_requests}) and priority ({len(priority)}) " if isinstance(lora_request, Sequence):
"must be the same." if len(lora_request) != num_requests:
raise ValueError(
f"The lengths of prompts ({num_requests}) "
f"and lora_request ({len(lora_request)}) must be the same."
)
engine_lora_requests: Sequence[LoRARequest | None] = lora_request
else:
engine_lora_requests = [lora_request] * num_requests
if priority is not None:
if len(priority) != num_requests:
raise ValueError(
f"The lengths of prompts ({num_requests}) "
f"and priority ({len(priority)}) must be the same."
)
else:
priority = [0] * num_requests
if any(param.truncate_prompt_tokens is not None for param in engine_params):
# TODO: Remove this after deprecating `param.truncate_prompt_tokens`
# Then, move the code from the `else` block to the top and let
# `self._preprocess_completion` handle prompt normalization
engine_prompts = [
engine_prompt
for in_prompt, param in zip(in_prompts, engine_params)
for engine_prompt in self._preprocess_completion(
[in_prompt],
tokenization_kwargs=merge_kwargs(
tokenization_kwargs,
dict(truncate_prompt_tokens=param.truncate_prompt_tokens),
),
)
]
else:
engine_prompts = self._preprocess_completion(
in_prompts,
tokenization_kwargs=tokenization_kwargs,
) )
for sp in params if isinstance(params, Sequence) else (params,): for sp in engine_params:
if isinstance(sp, SamplingParams): if isinstance(sp, SamplingParams):
# We only care about the final output # We only care about the final output
sp.output_kind = RequestOutputKind.FINAL_ONLY sp.output_kind = RequestOutputKind.FINAL_ONLY
# Add requests to the engine. # Add requests to the engine.
it = prompts it = engine_prompts
if use_tqdm: if use_tqdm:
tqdm_func = use_tqdm if callable(use_tqdm) else tqdm tqdm_func = use_tqdm if callable(use_tqdm) else tqdm
it = tqdm_func(it, desc="Adding requests") it = tqdm_func(it, desc="Adding requests")
...@@ -1576,12 +1730,10 @@ class LLM: ...@@ -1576,12 +1730,10 @@ class LLM:
for i, prompt in enumerate(it): for i, prompt in enumerate(it):
request_id = self._add_request( request_id = self._add_request(
prompt, prompt,
params[i] if isinstance(params, Sequence) else params, engine_params[i],
lora_request=lora_request[i] lora_request=engine_lora_requests[i],
if isinstance(lora_request, Sequence)
else lora_request,
priority=priority[i] if priority else 0,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
priority=priority[i],
) )
added_request_ids.append(request_id) added_request_ids.append(request_id)
except Exception as e: except Exception as e:
...@@ -1589,54 +1741,42 @@ class LLM: ...@@ -1589,54 +1741,42 @@ class LLM:
self.llm_engine.abort_request(added_request_ids, internal=True) self.llm_engine.abort_request(added_request_ids, internal=True)
raise e raise e
def _process_inputs(
self,
request_id: str,
engine_prompt: PromptType,
params: SamplingParams | PoolingParams,
*,
lora_request: LoRARequest | None,
priority: int,
tokenization_kwargs: dict[str, Any] | None = None,
) -> tuple[EngineCoreRequest, dict[str, Any]]:
"""Use the Processor to process inputs for LLMEngine."""
local_kwargs = tokenization_kwargs or {}
tokenization_kwargs = local_kwargs.copy()
_validate_truncation_size(
self.model_config.max_model_len,
params.truncate_prompt_tokens,
tokenization_kwargs,
)
engine_request = self.input_processor.process_inputs(
request_id,
engine_prompt,
params,
lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
priority=priority,
)
return engine_request, tokenization_kwargs
def _add_request( def _add_request(
self, self,
prompt: PromptType, prompt: PromptType,
params: SamplingParams | PoolingParams, params: SamplingParams | PoolingParams,
lora_request: LoRARequest | None = None, lora_request: LoRARequest | None = None,
priority: int = 0,
tokenization_kwargs: dict[str, Any] | None = None, tokenization_kwargs: dict[str, Any] | None = None,
priority: int = 0,
) -> str: ) -> str:
prompt_text, _, _ = get_prompt_components(prompt) prompt_text, _, _ = get_prompt_components(prompt)
request_id = str(next(self.request_counter)) request_id = str(next(self.request_counter))
engine_request, tokenization_kwargs = self._process_inputs( if params.truncate_prompt_tokens is not None:
params_type = type(params).__name__
warnings.warn(
f"The `truncate_prompt_tokens` parameter in `{params_type}` "
"is deprecated and will be removed in v0.16. "
"Please pass it via `tokenization_kwargs` instead.",
DeprecationWarning,
stacklevel=2,
)
tokenization_kwargs = merge_kwargs(
tokenization_kwargs,
dict(truncate_prompt_tokens=params.truncate_prompt_tokens),
)
tok_params = self._get_cmpl_tok_params(tokenization_kwargs)
tokenization_kwargs = tok_params.get_encode_kwargs()
engine_request = self.input_processor.process_inputs(
request_id, request_id,
prompt, prompt,
params, params,
lora_request=lora_request, lora_request=lora_request,
priority=priority,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
priority=priority,
) )
self.llm_engine.add_request( self.llm_engine.add_request(
......
...@@ -13,12 +13,13 @@ from openai.types.chat.chat_completion_audio import ( ...@@ -13,12 +13,13 @@ from openai.types.chat.chat_completion_audio import (
ChatCompletionAudio as OpenAIChatCompletionAudio, ChatCompletionAudio as OpenAIChatCompletionAudio,
) )
from openai.types.chat.chat_completion_message import Annotation as OpenAIAnnotation from openai.types.chat.chat_completion_message import Annotation as OpenAIAnnotation
from pydantic import ( from pydantic import Field, model_validator
Field,
model_validator,
)
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam from vllm.config import ModelConfig
from vllm.entrypoints.chat_utils import (
ChatCompletionMessageParam,
ChatTemplateContentFormatOption,
)
from vllm.entrypoints.openai.engine.protocol import ( from vllm.entrypoints.openai.engine.protocol import (
AnyResponseFormat, AnyResponseFormat,
DeltaMessage, DeltaMessage,
...@@ -36,6 +37,7 @@ from vllm.entrypoints.openai.engine.protocol import ( ...@@ -36,6 +37,7 @@ from vllm.entrypoints.openai.engine.protocol import (
from vllm.exceptions import VLLMValidationError from vllm.exceptions import VLLMValidationError
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.logprobs import Logprob from vllm.logprobs import Logprob
from vllm.renderers import ChatParams, TokenizeParams, merge_kwargs
from vllm.sampling_params import ( from vllm.sampling_params import (
BeamSearchParams, BeamSearchParams,
RequestOutputKind, RequestOutputKind,
...@@ -348,6 +350,43 @@ class ChatCompletionRequest(OpenAIBaseModel): ...@@ -348,6 +350,43 @@ class ChatCompletionRequest(OpenAIBaseModel):
# --8<-- [end:chat-completion-extra-params] # --8<-- [end:chat-completion-extra-params]
def build_chat_params(
self,
default_template: str | None,
default_template_content_format: ChatTemplateContentFormatOption,
) -> ChatParams:
return ChatParams(
chat_template=self.chat_template or default_template,
chat_template_content_format=default_template_content_format,
chat_template_kwargs=merge_kwargs(
self.chat_template_kwargs,
dict(
add_generation_prompt=self.add_generation_prompt,
continue_final_message=self.continue_final_message,
documents=self.documents,
reasoning_effort=self.reasoning_effort,
),
),
)
def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
if self.max_completion_tokens is not None:
max_output_tokens: int | None = self.max_completion_tokens
max_output_tokens_param = "max_completion_tokens"
else:
max_output_tokens = self.max_tokens
max_output_tokens_param = "max_tokens"
return TokenizeParams(
max_total_tokens=model_config.max_model_len,
max_output_tokens=max_output_tokens or 0,
truncate_prompt_tokens=self.truncate_prompt_tokens,
add_special_tokens=self.add_special_tokens,
needs_detokenization=bool(self.echo and not self.return_token_ids),
max_total_tokens_param="max_model_len",
max_output_tokens_param=max_output_tokens_param,
)
# Default sampling parameters for chat completion requests # Default sampling parameters for chat completion requests
_DEFAULT_SAMPLING_PARAMS: dict = { _DEFAULT_SAMPLING_PARAMS: dict = {
"repetition_penalty": 1.0, "repetition_penalty": 1.0,
......
...@@ -67,7 +67,7 @@ from vllm.entrypoints.openai.parser.harmony_utils import ( ...@@ -67,7 +67,7 @@ from vllm.entrypoints.openai.parser.harmony_utils import (
) )
from vllm.entrypoints.openai.utils import maybe_filter_parallel_tool_calls from vllm.entrypoints.openai.utils import maybe_filter_parallel_tool_calls
from vllm.entrypoints.utils import get_max_tokens, should_include_usage from vllm.entrypoints.utils import get_max_tokens, should_include_usage
from vllm.inputs.data import TokensPrompt from vllm.inputs.data import EmbedsPrompt, TokensPrompt
from vllm.inputs.parse import get_prompt_components from vllm.inputs.parse import get_prompt_components
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.logprobs import Logprob from vllm.logprobs import Logprob
...@@ -185,8 +185,6 @@ class OpenAIServingChat(OpenAIServing): ...@@ -185,8 +185,6 @@ class OpenAIServingChat(OpenAIServing):
start_time = time.perf_counter() start_time = time.perf_counter()
try: try:
renderer = self.engine_client.renderer
# Create a minimal dummy request # Create a minimal dummy request
dummy_request = ChatCompletionRequest( dummy_request = ChatCompletionRequest(
messages=[{"role": "user", "content": "warmup"}], messages=[{"role": "user", "content": "warmup"}],
...@@ -201,18 +199,10 @@ class OpenAIServingChat(OpenAIServing): ...@@ -201,18 +199,10 @@ class OpenAIServingChat(OpenAIServing):
# 3. Tokenizer initialization for chat # 3. Tokenizer initialization for chat
await self._preprocess_chat( await self._preprocess_chat(
dummy_request, dummy_request,
renderer,
dummy_request.messages, dummy_request.messages,
chat_template=self.chat_template, default_template=self.chat_template,
chat_template_content_format=self.chat_template_content_format, default_template_content_format=self.chat_template_content_format,
add_generation_prompt=True, default_template_kwargs=self.default_chat_template_kwargs,
continue_final_message=False,
tool_dicts=None,
documents=None,
chat_template_kwargs=None,
default_chat_template_kwargs=self.default_chat_template_kwargs,
tool_parser=None,
add_special_tokens=False,
) )
elapsed = (time.perf_counter() - start_time) * 1000 elapsed = (time.perf_counter() - start_time) * 1000
...@@ -225,7 +215,10 @@ class OpenAIServingChat(OpenAIServing): ...@@ -225,7 +215,10 @@ class OpenAIServingChat(OpenAIServing):
async def render_chat_request( async def render_chat_request(
self, self,
request: ChatCompletionRequest, request: ChatCompletionRequest,
) -> tuple[list[ConversationMessage], list[Any]] | ErrorResponse: ) -> (
tuple[list[ConversationMessage], list[TokensPrompt | EmbedsPrompt]]
| ErrorResponse
):
""" """
render chat request by validating and preprocessing inputs. render chat request by validating and preprocessing inputs.
...@@ -302,23 +295,14 @@ class OpenAIServingChat(OpenAIServing): ...@@ -302,23 +295,14 @@ class OpenAIServingChat(OpenAIServing):
if error_check_ret is not None: if error_check_ret is not None:
return error_check_ret return error_check_ret
chat_template_kwargs = request.chat_template_kwargs or {}
chat_template_kwargs.update(reasoning_effort=request.reasoning_effort)
conversation, engine_prompts = await self._preprocess_chat( conversation, engine_prompts = await self._preprocess_chat(
request, request,
renderer,
request.messages, request.messages,
chat_template=request.chat_template or self.chat_template, default_template=self.chat_template,
chat_template_content_format=self.chat_template_content_format, default_template_content_format=self.chat_template_content_format,
add_generation_prompt=request.add_generation_prompt, default_template_kwargs=self.default_chat_template_kwargs,
continue_final_message=request.continue_final_message,
tool_dicts=tool_dicts, tool_dicts=tool_dicts,
documents=request.documents,
chat_template_kwargs=chat_template_kwargs,
default_chat_template_kwargs=self.default_chat_template_kwargs,
tool_parser=tool_parser, tool_parser=tool_parser,
add_special_tokens=request.add_special_tokens,
) )
else: else:
# For GPT-OSS. # For GPT-OSS.
...@@ -428,11 +412,15 @@ class OpenAIServingChat(OpenAIServing): ...@@ -428,11 +412,15 @@ class OpenAIServingChat(OpenAIServing):
trace_headers=trace_headers, trace_headers=trace_headers,
) )
else: else:
engine_request, tokenization_kwargs = await self._process_inputs( tok_params = request.build_tok_params(self.model_config)
tokenization_kwargs = tok_params.get_encode_kwargs()
engine_request = self.input_processor.process_inputs(
sub_request_id, sub_request_id,
engine_prompt, engine_prompt,
sampling_params, sampling_params,
lora_request=lora_request, lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
trace_headers=trace_headers, trace_headers=trace_headers,
priority=request.priority, priority=request.priority,
data_parallel_rank=data_parallel_rank, data_parallel_rank=data_parallel_rank,
......
...@@ -9,11 +9,9 @@ from dataclasses import replace ...@@ -9,11 +9,9 @@ from dataclasses import replace
from typing import Annotated, Any, Literal from typing import Annotated, Any, Literal
import torch import torch
from pydantic import ( from pydantic import Field, model_validator
Field,
model_validator,
)
from vllm.config import ModelConfig
from vllm.entrypoints.openai.engine.protocol import ( from vllm.entrypoints.openai.engine.protocol import (
AnyResponseFormat, AnyResponseFormat,
LegacyStructuralTagResponseFormat, LegacyStructuralTagResponseFormat,
...@@ -27,6 +25,7 @@ from vllm.entrypoints.openai.engine.protocol import ( ...@@ -27,6 +25,7 @@ from vllm.entrypoints.openai.engine.protocol import (
from vllm.exceptions import VLLMValidationError from vllm.exceptions import VLLMValidationError
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.logprobs import Logprob from vllm.logprobs import Logprob
from vllm.renderers import TokenizeParams
from vllm.sampling_params import ( from vllm.sampling_params import (
BeamSearchParams, BeamSearchParams,
RequestOutputKind, RequestOutputKind,
...@@ -178,6 +177,17 @@ class CompletionRequest(OpenAIBaseModel): ...@@ -178,6 +177,17 @@ class CompletionRequest(OpenAIBaseModel):
# --8<-- [end:completion-extra-params] # --8<-- [end:completion-extra-params]
def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
return TokenizeParams(
max_total_tokens=model_config.max_model_len,
max_output_tokens=self.max_tokens or 0,
truncate_prompt_tokens=self.truncate_prompt_tokens,
add_special_tokens=self.add_special_tokens,
needs_detokenization=bool(self.echo and not self.return_token_ids),
max_total_tokens_param="max_model_len",
max_output_tokens_param="max_tokens",
)
# Default sampling parameters for completion requests # Default sampling parameters for completion requests
_DEFAULT_SAMPLING_PARAMS: dict = { _DEFAULT_SAMPLING_PARAMS: dict = {
"repetition_penalty": 1.0, "repetition_penalty": 1.0,
......
...@@ -32,7 +32,6 @@ from vllm.entrypoints.openai.engine.serving import ( ...@@ -32,7 +32,6 @@ from vllm.entrypoints.openai.engine.serving import (
clamp_prompt_logprobs, clamp_prompt_logprobs,
) )
from vllm.entrypoints.openai.models.serving import OpenAIServingModels from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.entrypoints.renderer import RenderConfig
from vllm.entrypoints.utils import get_max_tokens, should_include_usage from vllm.entrypoints.utils import get_max_tokens, should_include_usage
from vllm.exceptions import VLLMValidationError from vllm.exceptions import VLLMValidationError
from vllm.inputs.data import EmbedsPrompt, TokensPrompt, is_embeds_prompt from vllm.inputs.data import EmbedsPrompt, TokensPrompt, is_embeds_prompt
...@@ -111,11 +110,10 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -111,11 +110,10 @@ class OpenAIServingCompletion(OpenAIServing):
) )
try: try:
renderer = self._get_completion_renderer() engine_prompts = await self._preprocess_completion(
engine_prompts = await renderer.render_prompt_and_embeds( request,
prompt_or_prompts=request.prompt, prompt_input=request.prompt,
prompt_embeds=request.prompt_embeds, prompt_embeds=request.prompt_embeds,
config=self._build_render_config(request),
) )
except (ValueError, TypeError, RuntimeError, jinja2.TemplateError) as e: except (ValueError, TypeError, RuntimeError, jinja2.TemplateError) as e:
logger.exception("Error in preprocessing prompt inputs") logger.exception("Error in preprocessing prompt inputs")
...@@ -203,10 +201,6 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -203,10 +201,6 @@ class OpenAIServingCompletion(OpenAIServing):
else await self._get_trace_headers(raw_request.headers) else await self._get_trace_headers(raw_request.headers)
) )
# Mypy inconsistently requires this second cast in different
# environments. It shouldn't be necessary (redundant from above)
# but pre-commit in CI fails without it.
engine_prompt = cast(EmbedsPrompt | TokensPrompt, engine_prompt)
if isinstance(sampling_params, BeamSearchParams): if isinstance(sampling_params, BeamSearchParams):
generator = self.beam_search( generator = self.beam_search(
prompt=engine_prompt, prompt=engine_prompt,
...@@ -216,11 +210,15 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -216,11 +210,15 @@ class OpenAIServingCompletion(OpenAIServing):
trace_headers=trace_headers, trace_headers=trace_headers,
) )
else: else:
engine_request, tokenization_kwargs = await self._process_inputs( tok_params = request.build_tok_params(self.model_config)
tokenization_kwargs = tok_params.get_encode_kwargs()
engine_request = self.input_processor.process_inputs(
request_id_item, request_id_item,
engine_prompt, engine_prompt,
sampling_params, sampling_params,
lora_request=lora_request, lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
trace_headers=trace_headers, trace_headers=trace_headers,
priority=request.priority, priority=request.priority,
data_parallel_rank=data_parallel_rank, data_parallel_rank=data_parallel_rank,
...@@ -709,26 +707,3 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -709,26 +707,3 @@ class OpenAIServingCompletion(OpenAIServing):
tokens=out_tokens, tokens=out_tokens,
top_logprobs=out_top_logprobs, top_logprobs=out_top_logprobs,
) )
def _build_render_config(
self,
request: CompletionRequest,
max_input_length: int | None = None,
) -> RenderConfig:
# Validate max_tokens before using it
if request.max_tokens is not None and request.max_tokens > self.max_model_len:
raise VLLMValidationError(
f"'max_tokens' ({request.max_tokens}) cannot be greater than "
f"the model's maximum context length ({self.max_model_len}).",
parameter="max_tokens",
value=request.max_tokens,
)
max_input_tokens_len = self.max_model_len - (request.max_tokens or 0)
return RenderConfig(
max_length=max_input_tokens_len,
truncate_prompt_tokens=request.truncate_prompt_tokens,
add_special_tokens=request.add_special_tokens,
cache_salt=request.cache_salt,
needs_detokenization=bool(request.echo and not request.return_token_ids),
)
...@@ -16,9 +16,7 @@ from pydantic import ( ...@@ -16,9 +16,7 @@ from pydantic import (
from vllm.entrypoints.chat_utils import make_tool_call_id from vllm.entrypoints.chat_utils import make_tool_call_id
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.sampling_params import ( from vllm.sampling_params import SamplingParams
SamplingParams,
)
from vllm.utils import random_uuid from vllm.utils import random_uuid
from vllm.utils.import_utils import resolve_obj_by_qualname from vllm.utils.import_utils import resolve_obj_by_qualname
......
...@@ -5,10 +5,10 @@ import json ...@@ -5,10 +5,10 @@ import json
import sys import sys
import time import time
import traceback import traceback
from collections.abc import AsyncGenerator, Callable, Iterable, Mapping from collections.abc import AsyncGenerator, Callable, Mapping
from dataclasses import dataclass, field from dataclasses import dataclass, field
from http import HTTPStatus from http import HTTPStatus
from typing import Any, ClassVar, Generic, TypeAlias, TypeVar, cast from typing import Any, ClassVar, Generic, Protocol, TypeAlias, TypeVar
import numpy as np import numpy as np
from fastapi import Request from fastapi import Request
...@@ -20,6 +20,7 @@ from starlette.datastructures import Headers ...@@ -20,6 +20,7 @@ from starlette.datastructures import Headers
import vllm.envs as envs import vllm.envs as envs
from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function
from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import ( from vllm.entrypoints.chat_utils import (
ChatCompletionMessageParam, ChatCompletionMessageParam,
...@@ -86,7 +87,6 @@ from vllm.entrypoints.pooling.score.protocol import ( ...@@ -86,7 +87,6 @@ from vllm.entrypoints.pooling.score.protocol import (
ScoreResponse, ScoreResponse,
ScoreTextRequest, ScoreTextRequest,
) )
from vllm.entrypoints.renderer import BaseRenderer, CompletionRenderer, RenderConfig
from vllm.entrypoints.serve.disagg.protocol import GenerateRequest, GenerateResponse from vllm.entrypoints.serve.disagg.protocol import GenerateRequest, GenerateResponse
from vllm.entrypoints.serve.tokenize.protocol import ( from vllm.entrypoints.serve.tokenize.protocol import (
DetokenizeRequest, DetokenizeRequest,
...@@ -94,13 +94,9 @@ from vllm.entrypoints.serve.tokenize.protocol import ( ...@@ -94,13 +94,9 @@ from vllm.entrypoints.serve.tokenize.protocol import (
TokenizeCompletionRequest, TokenizeCompletionRequest,
TokenizeResponse, TokenizeResponse,
) )
from vllm.entrypoints.utils import ( from vllm.entrypoints.utils import get_max_tokens, sanitize_message
_validate_truncation_size,
get_max_tokens,
sanitize_message,
)
from vllm.exceptions import VLLMValidationError from vllm.exceptions import VLLMValidationError
from vllm.inputs.data import PromptType, TokensPrompt from vllm.inputs.data import EmbedsPrompt, PromptType, TokensPrompt
from vllm.inputs.parse import ( from vllm.inputs.parse import (
get_prompt_components, get_prompt_components,
is_explicit_encoder_decoder_prompt, is_explicit_encoder_decoder_prompt,
...@@ -112,7 +108,7 @@ from vllm.multimodal import MultiModalDataDict ...@@ -112,7 +108,7 @@ from vllm.multimodal import MultiModalDataDict
from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.reasoning import ReasoningParser, ReasoningParserManager from vllm.reasoning import ReasoningParser, ReasoningParserManager
from vllm.renderers import RendererLike from vllm.renderers import ChatParams, TokenizeParams, merge_kwargs
from vllm.sampling_params import BeamSearchParams, SamplingParams from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers import ToolParser, ToolParserManager from vllm.tool_parsers import ToolParser, ToolParserManager
...@@ -123,11 +119,9 @@ from vllm.tracing import ( ...@@ -123,11 +119,9 @@ from vllm.tracing import (
) )
from vllm.utils import random_uuid from vllm.utils import random_uuid
from vllm.utils.async_utils import ( from vllm.utils.async_utils import (
AsyncMicrobatchTokenizer,
collect_from_async_generator, collect_from_async_generator,
merge_async_iterators, merge_async_iterators,
) )
from vllm.v1.engine import EngineCoreRequest
class GenerationError(Exception): class GenerationError(Exception):
...@@ -140,6 +134,21 @@ class GenerationError(Exception): ...@@ -140,6 +134,21 @@ class GenerationError(Exception):
logger = init_logger(__name__) logger = init_logger(__name__)
class RendererRequest(Protocol):
def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
raise NotImplementedError
class RendererChatRequest(RendererRequest, Protocol):
def build_chat_params(
self,
default_template: str | None,
default_template_content_format: ChatTemplateContentFormatOption,
) -> ChatParams:
raise NotImplementedError
CompletionLikeRequest: TypeAlias = ( CompletionLikeRequest: TypeAlias = (
CompletionRequest CompletionRequest
| TokenizeCompletionRequest | TokenizeCompletionRequest
...@@ -158,7 +167,9 @@ ChatLikeRequest: TypeAlias = ( ...@@ -158,7 +167,9 @@ ChatLikeRequest: TypeAlias = (
| ClassificationChatRequest | ClassificationChatRequest
| PoolingChatRequest | PoolingChatRequest
) )
SpeechToTextRequest: TypeAlias = TranscriptionRequest | TranslationRequest SpeechToTextRequest: TypeAlias = TranscriptionRequest | TranslationRequest
AnyRequest: TypeAlias = ( AnyRequest: TypeAlias = (
CompletionLikeRequest CompletionLikeRequest
| ChatLikeRequest | ChatLikeRequest
...@@ -193,7 +204,7 @@ class ServeContext(Generic[RequestT]): ...@@ -193,7 +204,7 @@ class ServeContext(Generic[RequestT]):
request_id: str request_id: str
created_time: int = field(default_factory=lambda: int(time.time())) created_time: int = field(default_factory=lambda: int(time.time()))
lora_request: LoRARequest | None = None lora_request: LoRARequest | None = None
engine_prompts: list[TokensPrompt] | None = None engine_prompts: list[TokensPrompt | EmbedsPrompt] | None = None
result_generator: AsyncGenerator[tuple[int, PoolingRequestOutput], None] | None = ( result_generator: AsyncGenerator[tuple[int, PoolingRequestOutput], None] | None = (
None None
...@@ -227,7 +238,6 @@ class OpenAIServing: ...@@ -227,7 +238,6 @@ class OpenAIServing:
self.request_logger = request_logger self.request_logger = request_logger
self.return_tokens_as_token_ids = return_tokens_as_token_ids self.return_tokens_as_token_ids = return_tokens_as_token_ids
self._async_tokenizer_pool: dict[TokenizerLike, AsyncMicrobatchTokenizer] = {}
self.log_error_stack = log_error_stack self.log_error_stack = log_error_stack
self.input_processor = self.models.input_processor self.input_processor = self.models.input_processor
...@@ -519,41 +529,6 @@ class OpenAIServing: ...@@ -519,41 +529,6 @@ class OpenAIServing:
prompt_logprobs=None, prompt_logprobs=None,
) )
def _get_completion_renderer(self) -> BaseRenderer:
"""
Get a Renderer instance with the provided tokenizer.
Uses shared async tokenizer pool for efficiency.
"""
return CompletionRenderer(
model_config=self.model_config,
tokenizer=self.renderer.tokenizer,
async_tokenizer_pool=self._async_tokenizer_pool,
)
def _build_render_config(
self,
request: Any,
) -> RenderConfig:
"""
Build and return a `RenderConfig` for an endpoint.
Used by the renderer to control how prompts are prepared
(e.g., tokenization and length handling). Endpoints should
implement this with logic appropriate to their request type.
"""
raise NotImplementedError
def _get_async_tokenizer(self, tokenizer) -> AsyncMicrobatchTokenizer:
"""
Return (and cache) an `AsyncMicrobatchTokenizer` bound to the
given tokenizer.
"""
async_tokenizer = self._async_tokenizer_pool.get(tokenizer)
if async_tokenizer is None:
async_tokenizer = AsyncMicrobatchTokenizer(tokenizer)
self._async_tokenizer_pool[tokenizer] = async_tokenizer
return async_tokenizer
async def _preprocess( async def _preprocess(
self, self,
ctx: ServeContext, ctx: ServeContext,
...@@ -912,71 +887,6 @@ class OpenAIServing: ...@@ -912,71 +887,6 @@ class OpenAIServing:
message_types.add(content_dict["type"].split("_")[0]) message_types.add(content_dict["type"].split("_")[0])
return message_types return message_types
async def _normalize_prompt_text_to_input(
self,
request: AnyRequest,
prompt: str,
tokenizer: TokenizerLike,
add_special_tokens: bool,
) -> TokensPrompt:
async_tokenizer = self._get_async_tokenizer(tokenizer)
if (
self.model_config.encoder_config is not None
and self.model_config.encoder_config.get("do_lower_case", False)
):
prompt = prompt.lower()
truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens", None)
if truncate_prompt_tokens is None:
encoded = await async_tokenizer(
prompt, add_special_tokens=add_special_tokens
)
elif truncate_prompt_tokens < 0:
# Negative means we cap at the model's max length
encoded = await async_tokenizer(
prompt,
add_special_tokens=add_special_tokens,
truncation=True,
max_length=self.max_model_len,
)
else:
encoded = await async_tokenizer(
prompt,
add_special_tokens=add_special_tokens,
truncation=True,
max_length=truncate_prompt_tokens,
)
input_ids = encoded.input_ids
input_text = prompt
return self._validate_input(request, input_ids, input_text)
async def _normalize_prompt_tokens_to_input(
self,
request: AnyRequest,
prompt_ids: list[int],
tokenizer: TokenizerLike | None,
) -> TokensPrompt:
truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens", None)
if truncate_prompt_tokens is None:
input_ids = prompt_ids
elif truncate_prompt_tokens < 0:
input_ids = prompt_ids[-self.max_model_len :]
else:
input_ids = prompt_ids[-truncate_prompt_tokens:]
if tokenizer is None:
input_text = ""
else:
async_tokenizer = self._get_async_tokenizer(tokenizer)
input_text = await async_tokenizer.decode(input_ids)
return self._validate_input(request, input_ids, input_text)
def _validate_input( def _validate_input(
self, self,
request: object, request: object,
...@@ -1061,50 +971,6 @@ class OpenAIServing: ...@@ -1061,50 +971,6 @@ class OpenAIServing:
return TokensPrompt(prompt=input_text, prompt_token_ids=input_ids) return TokensPrompt(prompt=input_text, prompt_token_ids=input_ids)
async def _tokenize_prompt_input_async(
self,
request: AnyRequest,
tokenizer: TokenizerLike,
prompt_input: str | list[int],
add_special_tokens: bool = True,
) -> TokensPrompt:
"""
A simpler implementation that tokenizes a single prompt input.
"""
async for result in self._tokenize_prompt_inputs_async(
request,
tokenizer,
[prompt_input],
add_special_tokens=add_special_tokens,
):
return result
raise ValueError("No results yielded from tokenization")
async def _tokenize_prompt_inputs_async(
self,
request: AnyRequest,
tokenizer: TokenizerLike,
prompt_inputs: Iterable[str | list[int]],
add_special_tokens: bool = True,
) -> AsyncGenerator[TokensPrompt, None]:
"""
A simpler implementation that tokenizes multiple prompt inputs.
"""
for prompt in prompt_inputs:
if isinstance(prompt, str):
yield await self._normalize_prompt_text_to_input(
request,
prompt=prompt,
tokenizer=tokenizer,
add_special_tokens=add_special_tokens,
)
else:
yield await self._normalize_prompt_tokens_to_input(
request,
prompt_ids=prompt,
tokenizer=tokenizer,
)
def _validate_chat_template( def _validate_chat_template(
self, self,
request_chat_template: str | None, request_chat_template: str | None,
...@@ -1137,131 +1003,94 @@ class OpenAIServing: ...@@ -1137,131 +1003,94 @@ class OpenAIServing:
# Apply server defaults first, then request kwargs override. # Apply server defaults first, then request kwargs override.
return default_chat_template_kwargs | request_chat_template_kwargs return default_chat_template_kwargs | request_chat_template_kwargs
async def _preprocess_completion(
self,
request: RendererRequest,
prompt_input: str | list[str] | list[int] | list[list[int]] | None,
prompt_embeds: bytes | list[bytes] | None,
) -> list[TokensPrompt | EmbedsPrompt]:
renderer = self.renderer
tok_params = request.build_tok_params(self.model_config)
in_prompts = await renderer.render_completions_async(
prompt_input, prompt_embeds
)
engine_prompts = await renderer.tokenize_prompts_async(in_prompts, tok_params)
extra_items = {
k: v
for k in ("mm_processor_kwargs", "cache_salt")
if (v := getattr(request, k, None)) is not None
}
for prompt in engine_prompts:
prompt.update(extra_items) # type: ignore
return engine_prompts
async def _preprocess_chat( async def _preprocess_chat(
self, self,
request: ChatLikeRequest | ResponsesRequest, request: RendererChatRequest,
renderer: RendererLike,
messages: list[ChatCompletionMessageParam], messages: list[ChatCompletionMessageParam],
chat_template: str | None, default_template: str | None,
chat_template_content_format: ChatTemplateContentFormatOption, default_template_content_format: ChatTemplateContentFormatOption,
add_generation_prompt: bool = True, default_template_kwargs: dict[str, Any] | None,
continue_final_message: bool = False,
tool_dicts: list[dict[str, Any]] | None = None, tool_dicts: list[dict[str, Any]] | None = None,
documents: list[dict[str, str]] | None = None,
chat_template_kwargs: dict[str, Any] | None = None,
default_chat_template_kwargs: dict[str, Any] | None = None,
tool_parser: Callable[[TokenizerLike], ToolParser] | None = None, tool_parser: Callable[[TokenizerLike], ToolParser] | None = None,
add_special_tokens: bool = False, ) -> tuple[list[ConversationMessage], list[TokensPrompt | EmbedsPrompt]]:
) -> tuple[list[ConversationMessage], list[TokensPrompt]]:
chat_template_kwargs = {
"chat_template": chat_template,
"add_generation_prompt": add_generation_prompt,
"continue_final_message": continue_final_message,
"tools": tool_dicts,
"documents": documents,
**(chat_template_kwargs or {}),
}
chat_template_kwargs = self._prepare_extra_chat_template_kwargs(
chat_template_kwargs,
default_chat_template_kwargs,
)
# Use the async tokenizer in `OpenAIServing` if possible.
# Later we can move it into the renderer so that we can return both
# text and token IDs in the same prompt from `render_messages_async`
# which is used for logging and `enable_response_messages`.
from vllm.tokenizers.mistral import MistralTokenizer from vllm.tokenizers.mistral import MistralTokenizer
conversation, engine_prompt = await renderer.render_messages_async( renderer = self.renderer
messages,
chat_template_content_format=chat_template_content_format, default_template_kwargs = merge_kwargs(
tokenize=( default_template_kwargs,
chat_template_kwargs.pop("tokenize", False) dict(
or isinstance(renderer.tokenizer, MistralTokenizer) tools=tool_dicts,
tokenize=isinstance(renderer.tokenizer, MistralTokenizer),
), ),
**chat_template_kwargs,
) )
if "prompt_token_ids" not in engine_prompt: tok_params = request.build_tok_params(self.model_config)
extra_data = engine_prompt chat_params = request.build_chat_params(
engine_prompt = await self._tokenize_prompt_input_async( default_template, default_template_content_format
request, ).with_defaults(default_template_kwargs)
renderer.get_tokenizer(),
engine_prompt["prompt"],
add_special_tokens=add_special_tokens,
)
# Fill in other keys like MM data
engine_prompt.update(extra_data) # type: ignore
else:
self._validate_input(
request=request,
input_ids=engine_prompt["prompt_token_ids"], # type: ignore
input_text="",
)
engine_prompt = cast(TokensPrompt, engine_prompt) conversation, prompt = await renderer.render_messages_async(
messages, chat_params
)
engine_prompt = await renderer.tokenize_prompt_async(prompt, tok_params)
if request.mm_processor_kwargs is not None: extra_items = {
engine_prompt["mm_processor_kwargs"] = request.mm_processor_kwargs k: v
if (cache_salt := getattr(request, "cache_salt", None)) is not None: for k in ("mm_processor_kwargs", "cache_salt")
engine_prompt["cache_salt"] = cache_salt if (v := getattr(request, k, None)) is not None
}
engine_prompt.update(extra_items) # type: ignore
# tool parsing is done only if a tool_parser has been set and if # tool parsing is done only if a tool_parser has been set and if
# tool_choice is not "none" (if tool_choice is "none" but a tool_parser # tool_choice is not "none" (if tool_choice is "none" but a tool_parser
# is set, we want to prevent parsing a tool_call hallucinated by the LLM # is set, we want to prevent parsing a tool_call hallucinated by the LLM
should_parse_tools = tool_parser is not None and ( if tool_parser is not None:
hasattr(request, "tool_choice") and request.tool_choice != "none" tool_choice = getattr(request, "tool_choice", "none")
) if tool_choice != "none":
if not isinstance(request, ChatCompletionRequest | ResponsesRequest):
if should_parse_tools: msg = (
if not isinstance(request, ChatCompletionRequest | ResponsesRequest): "Tool usage is only supported for Chat Completions API "
msg = ( "or Responses API requests."
"Tool usage is only supported for Chat Completions API " )
"or Responses API requests." raise NotImplementedError(msg)
)
raise NotImplementedError(msg)
tokenizer = renderer.get_tokenizer() # TODO: Update adjust_request to accept ResponsesRequest
request = tool_parser(tokenizer).adjust_request(request=request) # type: ignore tokenizer = renderer.get_tokenizer()
request = tool_parser(tokenizer).adjust_request(request=request) # type: ignore[arg-type]
return conversation, [engine_prompt] return conversation, [engine_prompt]
async def _process_inputs(
self,
request_id: str,
engine_prompt: PromptType,
params: SamplingParams | PoolingParams,
*,
lora_request: LoRARequest | None,
trace_headers: Mapping[str, str] | None,
priority: int,
data_parallel_rank: int | None = None,
) -> tuple[EngineCoreRequest, dict[str, Any]]:
"""Use the Processor to process inputs for AsyncLLM."""
tokenization_kwargs: dict[str, Any] = {}
_validate_truncation_size(
self.max_model_len, params.truncate_prompt_tokens, tokenization_kwargs
)
engine_request = self.input_processor.process_inputs(
request_id,
engine_prompt,
params,
lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
trace_headers=trace_headers,
priority=priority,
data_parallel_rank=data_parallel_rank,
)
return engine_request, tokenization_kwargs
async def _render_next_turn( async def _render_next_turn(
self, self,
request: ResponsesRequest, request: ResponsesRequest,
renderer: RendererLike,
messages: list[ResponseInputOutputItem], messages: list[ResponseInputOutputItem],
tool_dicts: list[dict[str, Any]] | None, tool_dicts: list[dict[str, Any]] | None,
tool_parser, tool_parser: Callable[[TokenizerLike], ToolParser] | None,
chat_template: str | None, chat_template: str | None,
chat_template_content_format: ChatTemplateContentFormatOption, chat_template_content_format: ChatTemplateContentFormatOption,
): ):
...@@ -1271,24 +1100,25 @@ class OpenAIServing: ...@@ -1271,24 +1100,25 @@ class OpenAIServing:
_, engine_prompts = await self._preprocess_chat( _, engine_prompts = await self._preprocess_chat(
request, request,
renderer,
new_messages, new_messages,
default_template=chat_template,
default_template_content_format=chat_template_content_format,
default_template_kwargs=None,
tool_dicts=tool_dicts, tool_dicts=tool_dicts,
tool_parser=tool_parser, tool_parser=tool_parser,
chat_template=chat_template,
chat_template_content_format=chat_template_content_format,
) )
return engine_prompts return engine_prompts
async def _generate_with_builtin_tools( async def _generate_with_builtin_tools(
self, self,
request_id: str, request_id: str,
engine_prompt: TokensPrompt, engine_prompt: TokensPrompt | EmbedsPrompt,
sampling_params: SamplingParams, sampling_params: SamplingParams,
tok_params: TokenizeParams,
context: ConversationContext, context: ConversationContext,
lora_request: LoRARequest | None = None, lora_request: LoRARequest | None = None,
priority: int = 0, priority: int = 0,
**kwargs, trace_headers: Mapping[str, str] | None = None,
): ):
prompt_text, _, _ = get_prompt_components(engine_prompt) prompt_text, _, _ = get_prompt_components(engine_prompt)
...@@ -1297,18 +1127,21 @@ class OpenAIServing: ...@@ -1297,18 +1127,21 @@ class OpenAIServing:
while True: while True:
# Ensure that each sub-request has a unique request id. # Ensure that each sub-request has a unique request id.
sub_request_id = f"{request_id}_{sub_request}" sub_request_id = f"{request_id}_{sub_request}"
self._log_inputs( self._log_inputs(
sub_request_id, sub_request_id,
engine_prompt, engine_prompt,
params=sampling_params, params=sampling_params,
lora_request=lora_request, lora_request=lora_request,
) )
trace_headers = kwargs.get("trace_headers")
engine_request, tokenization_kwargs = await self._process_inputs( tokenization_kwargs = tok_params.get_encode_kwargs()
engine_request = self.input_processor.process_inputs(
sub_request_id, sub_request_id,
engine_prompt, engine_prompt,
sampling_params, sampling_params,
lora_request=lora_request, lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
trace_headers=trace_headers, trace_headers=trace_headers,
priority=priority, priority=priority,
) )
...@@ -1318,10 +1151,10 @@ class OpenAIServing: ...@@ -1318,10 +1151,10 @@ class OpenAIServing:
sampling_params, sampling_params,
sub_request_id, sub_request_id,
lora_request=lora_request, lora_request=lora_request,
trace_headers=trace_headers,
priority=priority, priority=priority,
prompt_text=prompt_text, prompt_text=prompt_text,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
**kwargs,
) )
async for res in generator: async for res in generator:
...@@ -1350,7 +1183,6 @@ class OpenAIServing: ...@@ -1350,7 +1183,6 @@ class OpenAIServing:
elif isinstance(context, ParsableContext): elif isinstance(context, ParsableContext):
engine_prompts = await self._render_next_turn( engine_prompts = await self._render_next_turn(
context.request, context.request,
context.renderer,
context.parser.response_messages, context.parser.response_messages,
context.tool_dicts, context.tool_dicts,
context.tool_parser_cls, context.tool_parser_cls,
......
...@@ -43,7 +43,6 @@ from vllm.entrypoints.openai.responses.protocol import ( ...@@ -43,7 +43,6 @@ from vllm.entrypoints.openai.responses.protocol import (
from vllm.entrypoints.openai.responses.utils import construct_tool_dicts from vllm.entrypoints.openai.responses.utils import construct_tool_dicts
from vllm.outputs import RequestOutput from vllm.outputs import RequestOutput
from vllm.reasoning.abs_reasoning_parsers import ReasoningParser from vllm.reasoning.abs_reasoning_parsers import ReasoningParser
from vllm.renderers import RendererLike
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import ToolParser from vllm.tool_parsers.abstract_tool_parser import ToolParser
from vllm.utils import random_uuid from vllm.utils import random_uuid
...@@ -261,7 +260,7 @@ class ParsableContext(ConversationContext): ...@@ -261,7 +260,7 @@ class ParsableContext(ConversationContext):
self, self,
*, *,
response_messages: list[ResponseInputOutputItem], response_messages: list[ResponseInputOutputItem],
renderer: RendererLike, tokenizer: TokenizerLike,
reasoning_parser_cls: Callable[[TokenizerLike], ReasoningParser] | None, reasoning_parser_cls: Callable[[TokenizerLike], ReasoningParser] | None,
request: ResponsesRequest, request: ResponsesRequest,
available_tools: list[str] | None, available_tools: list[str] | None,
...@@ -280,7 +279,6 @@ class ParsableContext(ConversationContext): ...@@ -280,7 +279,6 @@ class ParsableContext(ConversationContext):
if reasoning_parser_cls is None: if reasoning_parser_cls is None:
raise ValueError("reasoning_parser_cls must be provided.") raise ValueError("reasoning_parser_cls must be provided.")
tokenizer = renderer.get_tokenizer()
self.parser = get_responses_parser_for_simple_context( self.parser = get_responses_parser_for_simple_context(
tokenizer=tokenizer, tokenizer=tokenizer,
reasoning_parser_cls=reasoning_parser_cls, reasoning_parser_cls=reasoning_parser_cls,
...@@ -290,8 +288,6 @@ class ParsableContext(ConversationContext): ...@@ -290,8 +288,6 @@ class ParsableContext(ConversationContext):
) )
self.tool_parser_cls = tool_parser_cls self.tool_parser_cls = tool_parser_cls
self.request = request self.request = request
self.renderer = renderer
self.tokenizer = tokenizer
self.available_tools = available_tools or [] self.available_tools = available_tools or []
self._tool_sessions: dict[str, ClientSession | Tool] = {} self._tool_sessions: dict[str, ClientSession | Tool] = {}
......
...@@ -59,12 +59,15 @@ from pydantic import ( ...@@ -59,12 +59,15 @@ from pydantic import (
model_validator, model_validator,
) )
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam from vllm.config import ModelConfig
from vllm.entrypoints.openai.engine.protocol import ( from vllm.entrypoints.chat_utils import (
OpenAIBaseModel, ChatCompletionMessageParam,
ChatTemplateContentFormatOption,
) )
from vllm.entrypoints.openai.engine.protocol import OpenAIBaseModel
from vllm.exceptions import VLLMValidationError from vllm.exceptions import VLLMValidationError
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.renderers import ChatParams, TokenizeParams, merge_kwargs
from vllm.sampling_params import ( from vllm.sampling_params import (
RequestOutputKind, RequestOutputKind,
SamplingParams, SamplingParams,
...@@ -230,6 +233,42 @@ class ResponsesRequest(OpenAIBaseModel): ...@@ -230,6 +233,42 @@ class ResponsesRequest(OpenAIBaseModel):
previous_input_messages: list[OpenAIHarmonyMessage | dict] | None = None previous_input_messages: list[OpenAIHarmonyMessage | dict] | None = None
# --8<-- [end:responses-extra-params] # --8<-- [end:responses-extra-params]
def build_chat_params(
self,
default_template: str | None,
default_template_content_format: ChatTemplateContentFormatOption,
) -> ChatParams:
from .utils import should_continue_final_message
# Check if we should continue the final message (partial completion)
# This enables Anthropic-style partial message completion where the
# user provides an incomplete assistant message to continue from.
continue_final = should_continue_final_message(self.input)
reasoning = self.reasoning
return ChatParams(
chat_template=default_template,
chat_template_content_format=default_template_content_format,
chat_template_kwargs=merge_kwargs( # To remove unset values
{},
dict(
add_generation_prompt=not continue_final,
continue_final_message=continue_final,
reasoning_effort=None if reasoning is None else reasoning.effort,
),
),
)
def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
return TokenizeParams(
max_total_tokens=model_config.max_model_len,
max_output_tokens=self.max_output_tokens or 0,
truncate_prompt_tokens=-1 if self.truncation != "disabled" else None,
max_total_tokens_param="max_model_len",
max_output_tokens_param="max_output_tokens",
)
_DEFAULT_SAMPLING_PARAMS = { _DEFAULT_SAMPLING_PARAMS = {
"temperature": 1.0, "temperature": 1.0,
"top_p": 1.0, "top_p": 1.0,
......
...@@ -114,16 +114,15 @@ from vllm.entrypoints.openai.responses.utils import ( ...@@ -114,16 +114,15 @@ from vllm.entrypoints.openai.responses.utils import (
construct_input_messages, construct_input_messages,
construct_tool_dicts, construct_tool_dicts,
extract_tool_types, extract_tool_types,
should_continue_final_message,
) )
from vllm.entrypoints.utils import get_max_tokens from vllm.entrypoints.utils import get_max_tokens
from vllm.exceptions import VLLMValidationError from vllm.exceptions import VLLMValidationError
from vllm.inputs.data import TokensPrompt from vllm.inputs.data import EmbedsPrompt, TokensPrompt
from vllm.inputs.parse import get_prompt_len
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.logprobs import Logprob as SampleLogprob from vllm.logprobs import Logprob as SampleLogprob
from vllm.logprobs import SampleLogprobs from vllm.logprobs import SampleLogprobs
from vllm.outputs import CompletionOutput from vllm.outputs import CompletionOutput
from vllm.renderers import RendererLike
from vllm.sampling_params import SamplingParams, StructuredOutputsParams from vllm.sampling_params import SamplingParams, StructuredOutputsParams
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
from vllm.utils import random_uuid from vllm.utils import random_uuid
...@@ -291,13 +290,14 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -291,13 +290,14 @@ class OpenAIServingResponses(OpenAIServing):
self.tool_server = tool_server self.tool_server = tool_server
def _validate_generator_input( def _validate_generator_input(
self, engine_prompt: TokensPrompt self,
engine_prompt: TokensPrompt | EmbedsPrompt,
) -> ErrorResponse | None: ) -> ErrorResponse | None:
"""Add validations to the input to the generator here.""" """Add validations to the input to the generator here."""
if self.max_model_len <= len(engine_prompt["prompt_token_ids"]): prompt_len = get_prompt_len(engine_prompt)
if self.max_model_len <= prompt_len:
error_message = ( error_message = (
"The engine prompt length" f"The engine prompt length {prompt_len} "
f" {len(engine_prompt['prompt_token_ids'])} "
f"exceeds the max_model_len {self.max_model_len}. " f"exceeds the max_model_len {self.max_model_len}. "
"Please reduce prompt." "Please reduce prompt."
) )
...@@ -307,6 +307,7 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -307,6 +307,7 @@ class OpenAIServingResponses(OpenAIServing):
status_code=HTTPStatus.BAD_REQUEST, status_code=HTTPStatus.BAD_REQUEST,
param="input", param="input",
) )
return None return None
def _validate_create_responses_input( def _validate_create_responses_input(
...@@ -387,8 +388,6 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -387,8 +388,6 @@ class OpenAIServingResponses(OpenAIServing):
try: try:
lora_request = self._maybe_get_adapters(request) lora_request = self._maybe_get_adapters(request)
model_name = self.models.model_name(lora_request) model_name = self.models.model_name(lora_request)
renderer = self.engine_client.renderer
tokenizer = renderer.get_tokenizer()
if self.use_harmony: if self.use_harmony:
messages, engine_prompts = self._make_request_with_harmony( messages, engine_prompts = self._make_request_with_harmony(
...@@ -396,7 +395,7 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -396,7 +395,7 @@ class OpenAIServingResponses(OpenAIServing):
) )
else: else:
messages, engine_prompts = await self._make_request( messages, engine_prompts = await self._make_request(
request, prev_response, renderer request, prev_response
) )
except ( except (
...@@ -431,6 +430,9 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -431,6 +430,9 @@ class OpenAIServingResponses(OpenAIServing):
assert len(builtin_tool_list) == 0 assert len(builtin_tool_list) == 0
available_tools = [] available_tools = []
try: try:
renderer = self.engine_client.renderer
tokenizer = renderer.get_tokenizer()
for engine_prompt in engine_prompts: for engine_prompt in engine_prompts:
maybe_error = self._validate_generator_input(engine_prompt) maybe_error = self._validate_generator_input(engine_prompt)
if maybe_error is not None: if maybe_error is not None:
...@@ -446,6 +448,7 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -446,6 +448,7 @@ class OpenAIServingResponses(OpenAIServing):
sampling_params = request.to_sampling_params( sampling_params = request.to_sampling_params(
default_max_tokens, self.default_sampling_params default_max_tokens, self.default_sampling_params
) )
tok_params = request.build_tok_params(self.model_config)
trace_headers = ( trace_headers = (
None None
...@@ -465,7 +468,7 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -465,7 +468,7 @@ class OpenAIServingResponses(OpenAIServing):
# tokens during generation instead of at the end # tokens during generation instead of at the end
context = ParsableContext( context = ParsableContext(
response_messages=messages, response_messages=messages,
renderer=renderer, tokenizer=tokenizer,
reasoning_parser_cls=self.reasoning_parser, reasoning_parser_cls=self.reasoning_parser,
request=request, request=request,
tool_parser_cls=self.tool_parser, tool_parser_cls=self.tool_parser,
...@@ -495,6 +498,7 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -495,6 +498,7 @@ class OpenAIServingResponses(OpenAIServing):
request_id=request.request_id, request_id=request.request_id,
engine_prompt=engine_prompt, engine_prompt=engine_prompt,
sampling_params=sampling_params, sampling_params=sampling_params,
tok_params=tok_params,
context=context, context=context,
lora_request=lora_request, lora_request=lora_request,
priority=request.priority, priority=request.priority,
...@@ -596,7 +600,6 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -596,7 +600,6 @@ class OpenAIServingResponses(OpenAIServing):
self, self,
request: ResponsesRequest, request: ResponsesRequest,
prev_response: ResponsesResponse | None, prev_response: ResponsesResponse | None,
renderer: RendererLike,
): ):
tool_dicts = construct_tool_dicts(request.tools, request.tool_choice) tool_dicts = construct_tool_dicts(request.tools, request.tool_choice)
# Construct the input messages. # Construct the input messages.
...@@ -606,30 +609,15 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -606,30 +609,15 @@ class OpenAIServingResponses(OpenAIServing):
prev_msg=self.msg_store.get(prev_response.id) if prev_response else None, prev_msg=self.msg_store.get(prev_response.id) if prev_response else None,
prev_response_output=prev_response.output if prev_response else None, prev_response_output=prev_response.output if prev_response else None,
) )
# Check if we should continue the final message (partial completion)
# This enables Anthropic-style partial message completion where the
# user provides an incomplete assistant message to continue from.
continue_final = should_continue_final_message(request.input)
chat_template_kwargs = dict(
reasoning_effort=None
if request.reasoning is None
else request.reasoning.effort
)
_, engine_prompts = await self._preprocess_chat( _, engine_prompts = await self._preprocess_chat(
request, request,
renderer,
messages, messages,
default_template=self.chat_template,
default_template_content_format=self.chat_template_content_format,
default_template_kwargs=None,
tool_dicts=tool_dicts, tool_dicts=tool_dicts,
tool_parser=self.tool_parser, tool_parser=self.tool_parser,
chat_template=self.chat_template,
chat_template_content_format=self.chat_template_content_format,
# When continuing a partial message, we set continue_final_message=True
# and add_generation_prompt=False so the model continues the message
# rather than starting a new one.
add_generation_prompt=not continue_final,
continue_final_message=continue_final,
chat_template_kwargs=chat_template_kwargs,
) )
return messages, engine_prompts return messages, engine_prompts
......
...@@ -8,8 +8,12 @@ from pydantic import Field, model_validator ...@@ -8,8 +8,12 @@ from pydantic import Field, model_validator
from vllm import PoolingParams from vllm import PoolingParams
from vllm.config.pooler import get_use_activation from vllm.config.pooler import get_use_activation
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam from vllm.entrypoints.chat_utils import (
ChatCompletionMessageParam,
ChatTemplateContentFormatOption,
)
from vllm.entrypoints.openai.engine.protocol import OpenAIBaseModel from vllm.entrypoints.openai.engine.protocol import OpenAIBaseModel
from vllm.renderers import ChatParams, merge_kwargs
from vllm.utils import random_uuid from vllm.utils import random_uuid
from vllm.utils.serial_utils import EmbedDType, EncodingFormat, Endianness from vllm.utils.serial_utils import EmbedDType, EncodingFormat, Endianness
...@@ -119,6 +123,23 @@ class ChatRequestMixin(OpenAIBaseModel): ...@@ -119,6 +123,23 @@ class ChatRequestMixin(OpenAIBaseModel):
) )
return data return data
def build_chat_params(
self,
default_template: str | None,
default_template_content_format: ChatTemplateContentFormatOption,
) -> ChatParams:
return ChatParams(
chat_template=self.chat_template or default_template,
chat_template_content_format=default_template_content_format,
chat_template_kwargs=merge_kwargs(
self.chat_template_kwargs,
dict(
add_generation_prompt=self.add_generation_prompt,
continue_final_message=self.continue_final_message,
),
),
)
class EncodingRequestMixin(OpenAIBaseModel): class EncodingRequestMixin(OpenAIBaseModel):
# --8<-- [start:encoding-params] # --8<-- [start:encoding-params]
......
...@@ -4,10 +4,9 @@ ...@@ -4,10 +4,9 @@
import time import time
from typing import Any, TypeAlias from typing import Any, TypeAlias
from pydantic import ( from pydantic import Field
Field,
)
from vllm.config import ModelConfig
from vllm.entrypoints.openai.engine.protocol import OpenAIBaseModel, UsageInfo from vllm.entrypoints.openai.engine.protocol import OpenAIBaseModel, UsageInfo
from vllm.entrypoints.pooling.base.protocol import ( from vllm.entrypoints.pooling.base.protocol import (
ChatRequestMixin, ChatRequestMixin,
...@@ -15,13 +14,24 @@ from vllm.entrypoints.pooling.base.protocol import ( ...@@ -15,13 +14,24 @@ from vllm.entrypoints.pooling.base.protocol import (
CompletionRequestMixin, CompletionRequestMixin,
PoolingBasicRequestMixin, PoolingBasicRequestMixin,
) )
from vllm.renderers import TokenizeParams
from vllm.utils import random_uuid from vllm.utils import random_uuid
class ClassificationCompletionRequest( class ClassificationCompletionRequest(
PoolingBasicRequestMixin, CompletionRequestMixin, ClassifyRequestMixin PoolingBasicRequestMixin, CompletionRequestMixin, ClassifyRequestMixin
): ):
pass def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
encoder_config = model_config.encoder_config or {}
return TokenizeParams(
max_total_tokens=model_config.max_model_len,
max_output_tokens=0,
truncate_prompt_tokens=self.truncate_prompt_tokens,
do_lower_case=encoder_config.get("do_lower_case", False),
add_special_tokens=self.add_special_tokens,
max_total_tokens_param="max_model_len",
)
class ClassificationChatRequest( class ClassificationChatRequest(
...@@ -33,6 +43,18 @@ class ClassificationChatRequest( ...@@ -33,6 +43,18 @@ class ClassificationChatRequest(
description=("Additional kwargs to pass to the HF processor."), description=("Additional kwargs to pass to the HF processor."),
) )
def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
encoder_config = model_config.encoder_config or {}
return TokenizeParams(
max_total_tokens=model_config.max_model_len,
max_output_tokens=0,
truncate_prompt_tokens=self.truncate_prompt_tokens,
do_lower_case=encoder_config.get("do_lower_case", False),
add_special_tokens=self.add_special_tokens,
max_total_tokens_param="max_model_len",
)
ClassificationRequest: TypeAlias = ( ClassificationRequest: TypeAlias = (
ClassificationCompletionRequest | ClassificationChatRequest ClassificationCompletionRequest | ClassificationChatRequest
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from http import HTTPStatus from typing import Final, TypeAlias
from typing import Final, cast
import jinja2 import jinja2
import numpy as np import numpy as np
...@@ -21,15 +20,14 @@ from vllm.entrypoints.pooling.classify.protocol import ( ...@@ -21,15 +20,14 @@ from vllm.entrypoints.pooling.classify.protocol import (
ClassificationRequest, ClassificationRequest,
ClassificationResponse, ClassificationResponse,
) )
from vllm.entrypoints.renderer import RenderConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.outputs import ClassificationOutput, PoolingRequestOutput from vllm.outputs import ClassificationOutput
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
logger = init_logger(__name__) logger = init_logger(__name__)
ClassificationServeContext = ServeContext[ClassificationRequest] ClassificationServeContext: TypeAlias = ServeContext[ClassificationRequest]
class ServingClassification(OpenAIServing): class ServingClassification(OpenAIServing):
...@@ -77,34 +75,18 @@ class ServingClassification(OpenAIServing): ...@@ -77,34 +75,18 @@ class ServingClassification(OpenAIServing):
if error_check_ret: if error_check_ret:
return error_check_ret return error_check_ret
_, engine_prompts = await self._preprocess_chat( _, ctx.engine_prompts = await self._preprocess_chat(
ctx.request, ctx.request,
self.renderer,
ctx.request.messages, ctx.request.messages,
chat_template=ctx.request.chat_template or self.chat_template, default_template=self.chat_template,
chat_template_content_format=self.chat_template_content_format, default_template_content_format=self.chat_template_content_format,
add_generation_prompt=ctx.request.add_generation_prompt, default_template_kwargs=None,
continue_final_message=ctx.request.continue_final_message,
add_special_tokens=ctx.request.add_special_tokens,
) )
ctx.engine_prompts = engine_prompts
elif isinstance(ctx.request, ClassificationCompletionRequest): elif isinstance(ctx.request, ClassificationCompletionRequest):
input_data = ctx.request.input ctx.engine_prompts = await self._preprocess_completion(
if input_data in (None, ""): ctx.request,
return self.create_error_response( prompt_input=ctx.request.input,
"Input or messages must be provided", prompt_embeds=None,
status_code=HTTPStatus.BAD_REQUEST,
)
if isinstance(input_data, list) and not input_data:
ctx.engine_prompts = []
return None
renderer = self._get_completion_renderer()
prompt_input = cast(str | list[str], input_data)
ctx.engine_prompts = await renderer.render_prompt(
prompt_or_prompts=prompt_input,
config=self._build_render_config(ctx.request),
) )
else: else:
return self.create_error_response("Invalid classification request type") return self.create_error_response("Invalid classification request type")
...@@ -128,7 +110,7 @@ class ServingClassification(OpenAIServing): ...@@ -128,7 +110,7 @@ class ServingClassification(OpenAIServing):
items: list[ClassificationData] = [] items: list[ClassificationData] = []
num_prompt_tokens = 0 num_prompt_tokens = 0
final_res_batch_checked = cast(list[PoolingRequestOutput], ctx.final_res_batch) final_res_batch_checked = ctx.final_res_batch
for idx, final_res in enumerate(final_res_batch_checked): for idx, final_res in enumerate(final_res_batch_checked):
classify_res = ClassificationOutput.from_base(final_res.outputs) classify_res = ClassificationOutput.from_base(final_res.outputs)
...@@ -161,13 +143,6 @@ class ServingClassification(OpenAIServing): ...@@ -161,13 +143,6 @@ class ServingClassification(OpenAIServing):
usage=usage, usage=usage,
) )
def _build_render_config(self, request: ClassificationRequest) -> RenderConfig:
return RenderConfig(
max_length=self.max_model_len,
truncate_prompt_tokens=request.truncate_prompt_tokens,
add_special_tokens=request.add_special_tokens,
)
async def create_classify( async def create_classify(
self, self,
request: ClassificationRequest, request: ClassificationRequest,
......
...@@ -3,10 +3,9 @@ ...@@ -3,10 +3,9 @@
import time import time
from typing import Any, TypeAlias from typing import Any, TypeAlias
from pydantic import ( from pydantic import Field
Field,
)
from vllm.config import ModelConfig
from vllm.entrypoints.openai.engine.protocol import OpenAIBaseModel, UsageInfo from vllm.entrypoints.openai.engine.protocol import OpenAIBaseModel, UsageInfo
from vllm.entrypoints.pooling.base.protocol import ( from vllm.entrypoints.pooling.base.protocol import (
ChatRequestMixin, ChatRequestMixin,
...@@ -14,15 +13,47 @@ from vllm.entrypoints.pooling.base.protocol import ( ...@@ -14,15 +13,47 @@ from vllm.entrypoints.pooling.base.protocol import (
EmbedRequestMixin, EmbedRequestMixin,
PoolingBasicRequestMixin, PoolingBasicRequestMixin,
) )
from vllm.renderers import TokenizeParams
from vllm.utils import random_uuid from vllm.utils import random_uuid
def _get_max_total_output_tokens(
model_config: ModelConfig,
) -> tuple[int | None, int]:
max_total_tokens = model_config.max_model_len
pooler_config = model_config.pooler_config
if pooler_config is None:
return max_total_tokens, 0
if pooler_config.enable_chunked_processing:
return None, 0
max_embed_len = pooler_config.max_embed_len or max_total_tokens
max_output_tokens = max_total_tokens - max_embed_len
return max_total_tokens, max_output_tokens
class EmbeddingCompletionRequest( class EmbeddingCompletionRequest(
PoolingBasicRequestMixin, CompletionRequestMixin, EmbedRequestMixin PoolingBasicRequestMixin, CompletionRequestMixin, EmbedRequestMixin
): ):
# Ordered by official OpenAI API documentation def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
# https://platform.openai.com/docs/api-reference/embeddings encoder_config = model_config.encoder_config or {}
pass
(
max_total_tokens,
max_output_tokens,
) = _get_max_total_output_tokens(model_config)
return TokenizeParams(
max_total_tokens=max_total_tokens,
max_output_tokens=max_output_tokens,
truncate_prompt_tokens=self.truncate_prompt_tokens,
do_lower_case=encoder_config.get("do_lower_case", False),
add_special_tokens=self.add_special_tokens,
max_total_tokens_param="max_model_len",
max_output_tokens_param="max_model_len - max_embed_len",
)
class EmbeddingChatRequest( class EmbeddingChatRequest(
...@@ -33,6 +64,24 @@ class EmbeddingChatRequest( ...@@ -33,6 +64,24 @@ class EmbeddingChatRequest(
description=("Additional kwargs to pass to the HF processor."), description=("Additional kwargs to pass to the HF processor."),
) )
def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
encoder_config = model_config.encoder_config or {}
(
max_total_tokens,
max_output_tokens,
) = _get_max_total_output_tokens(model_config)
return TokenizeParams(
max_total_tokens=max_total_tokens,
max_output_tokens=max_output_tokens,
truncate_prompt_tokens=self.truncate_prompt_tokens,
do_lower_case=encoder_config.get("do_lower_case", False),
add_special_tokens=self.add_special_tokens,
max_total_tokens_param="max_model_len",
max_output_tokens_param="max_model_len - max_embed_len",
)
EmbeddingRequest: TypeAlias = EmbeddingCompletionRequest | EmbeddingChatRequest EmbeddingRequest: TypeAlias = EmbeddingCompletionRequest | EmbeddingChatRequest
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json import json
from collections.abc import AsyncGenerator, Mapping from collections.abc import AsyncGenerator, Mapping
from typing import Any, Final, cast from typing import Any, Final, TypeAlias
import torch import torch
from fastapi import Request from fastapi import Request
...@@ -22,8 +22,7 @@ from vllm.entrypoints.pooling.embed.protocol import ( ...@@ -22,8 +22,7 @@ from vllm.entrypoints.pooling.embed.protocol import (
EmbeddingResponse, EmbeddingResponse,
EmbeddingResponseData, EmbeddingResponseData,
) )
from vllm.entrypoints.renderer import RenderConfig from vllm.inputs.data import EmbedsPrompt, TokensPrompt
from vllm.inputs.data import TokensPrompt
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.outputs import PoolingOutput, PoolingRequestOutput from vllm.outputs import PoolingOutput, PoolingRequestOutput
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
...@@ -37,7 +36,7 @@ from vllm.utils.serial_utils import ( ...@@ -37,7 +36,7 @@ from vllm.utils.serial_utils import (
logger = init_logger(__name__) logger = init_logger(__name__)
EmbeddingServeContext = ServeContext[EmbeddingRequest] EmbeddingServeContext: TypeAlias = ServeContext[EmbeddingRequest]
class OpenAIServingEmbedding(OpenAIServing): class OpenAIServingEmbedding(OpenAIServing):
...@@ -95,19 +94,16 @@ class OpenAIServingEmbedding(OpenAIServing): ...@@ -95,19 +94,16 @@ class OpenAIServingEmbedding(OpenAIServing):
_, ctx.engine_prompts = await self._preprocess_chat( _, ctx.engine_prompts = await self._preprocess_chat(
ctx.request, ctx.request,
self.renderer,
ctx.request.messages, ctx.request.messages,
chat_template=ctx.request.chat_template or self.chat_template, default_template=self.chat_template,
chat_template_content_format=self.chat_template_content_format, default_template_content_format=self.chat_template_content_format,
add_generation_prompt=ctx.request.add_generation_prompt, default_template_kwargs=None,
continue_final_message=ctx.request.continue_final_message,
add_special_tokens=ctx.request.add_special_tokens,
) )
elif isinstance(ctx.request, EmbeddingCompletionRequest): elif isinstance(ctx.request, EmbeddingCompletionRequest):
renderer = self._get_completion_renderer() ctx.engine_prompts = await self._preprocess_completion(
ctx.engine_prompts = await renderer.render_prompt( ctx.request,
prompt_or_prompts=ctx.request.input, prompt_input=ctx.request.input,
config=self._build_render_config(ctx.request), prompt_embeds=None,
) )
else: else:
return self.create_error_response("Invalid classification request type") return self.create_error_response("Invalid classification request type")
...@@ -117,19 +113,6 @@ class OpenAIServingEmbedding(OpenAIServing): ...@@ -117,19 +113,6 @@ class OpenAIServingEmbedding(OpenAIServing):
logger.exception("Error in preprocessing prompt inputs") logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(str(e)) return self.create_error_response(str(e))
def _build_render_config(self, request: EmbeddingCompletionRequest) -> RenderConfig:
# Set max_length based on chunked processing capability
if self._should_use_chunked_processing(request):
max_length = None
else:
max_length = self.max_embed_len or self.max_model_len
return RenderConfig(
max_length=max_length,
truncate_prompt_tokens=request.truncate_prompt_tokens,
add_special_tokens=request.add_special_tokens,
)
def _build_response( def _build_response(
self, self,
ctx: EmbeddingServeContext, ctx: EmbeddingServeContext,
...@@ -246,14 +229,18 @@ class OpenAIServingEmbedding(OpenAIServing): ...@@ -246,14 +229,18 @@ class OpenAIServingEmbedding(OpenAIServing):
lora_request=ctx.lora_request, lora_request=ctx.lora_request,
) )
tok_params = ctx.request.build_tok_params(self.model_config)
tokenization_kwargs = tok_params.get_encode_kwargs()
# Create generator for this chunk and wrap it to return indices # Create generator for this chunk and wrap it to return indices
original_generator = self.engine_client.encode( original_generator = self.engine_client.encode(
chunk_engine_prompt, chunk_engine_prompt,
pooling_params, pooling_params,
chunk_request_id, chunk_request_id,
lora_request=ctx.lora_request, lora_request=ctx.lora_request,
tokenization_kwargs=tokenization_kwargs,
trace_headers=trace_headers, trace_headers=trace_headers,
priority=getattr(ctx.request, "priority", 0), priority=ctx.request.priority,
) )
generators.append(original_generator) generators.append(original_generator)
...@@ -338,7 +325,7 @@ class OpenAIServingEmbedding(OpenAIServing): ...@@ -338,7 +325,7 @@ class OpenAIServingEmbedding(OpenAIServing):
async def _create_single_prompt_generator( async def _create_single_prompt_generator(
self, self,
ctx: EmbeddingServeContext, ctx: EmbeddingServeContext,
engine_prompt: TokensPrompt, engine_prompt: TokensPrompt | EmbedsPrompt,
pooling_params: PoolingParams, pooling_params: PoolingParams,
trace_headers: Mapping[str, str] | None, trace_headers: Mapping[str, str] | None,
prompt_index: int, prompt_index: int,
...@@ -353,23 +340,25 @@ class OpenAIServingEmbedding(OpenAIServing): ...@@ -353,23 +340,25 @@ class OpenAIServingEmbedding(OpenAIServing):
lora_request=ctx.lora_request, lora_request=ctx.lora_request,
) )
tok_params = ctx.request.build_tok_params(self.model_config)
tokenization_kwargs = tok_params.get_encode_kwargs()
# Return the original generator without wrapping # Return the original generator without wrapping
return self.engine_client.encode( return self.engine_client.encode(
engine_prompt, engine_prompt,
pooling_params, pooling_params,
request_id_item, request_id_item,
lora_request=ctx.lora_request, lora_request=ctx.lora_request,
tokenization_kwargs=tokenization_kwargs,
trace_headers=trace_headers, trace_headers=trace_headers,
priority=getattr(ctx.request, "priority", 0), priority=ctx.request.priority,
) )
async def _prepare_generators( async def _prepare_generators(
self, self,
ctx: ServeContext, ctx: EmbeddingServeContext,
) -> ErrorResponse | None: ) -> ErrorResponse | None:
"""Override to support chunked processing.""" """Override to support chunked processing."""
ctx = cast(EmbeddingServeContext, ctx)
# Check if we should use chunked processing # Check if we should use chunked processing
use_chunked = self._should_use_chunked_processing(ctx.request) use_chunked = self._should_use_chunked_processing(ctx.request)
...@@ -405,7 +394,8 @@ class OpenAIServingEmbedding(OpenAIServing): ...@@ -405,7 +394,8 @@ class OpenAIServingEmbedding(OpenAIServing):
for i, engine_prompt in enumerate(ctx.engine_prompts): for i, engine_prompt in enumerate(ctx.engine_prompts):
# Check if this specific prompt needs chunked processing # Check if this specific prompt needs chunked processing
if "prompt_token_ids" in engine_prompt: if "prompt_token_ids" in engine_prompt:
prompt_token_ids = engine_prompt["prompt_token_ids"] prompt_token_ids = engine_prompt["prompt_token_ids"] # type: ignore[typeddict-item]
if len(prompt_token_ids) > max_pos_embeddings: if len(prompt_token_ids) > max_pos_embeddings:
# Use chunked processing for this prompt # Use chunked processing for this prompt
chunk_generators = await self._process_chunked_request( chunk_generators = await self._process_chunked_request(
...@@ -573,7 +563,7 @@ class OpenAIServingEmbedding(OpenAIServing): ...@@ -573,7 +563,7 @@ class OpenAIServingEmbedding(OpenAIServing):
"token IDs" "token IDs"
) )
original_token_ids = original_prompt["prompt_token_ids"] original_token_ids = original_prompt["prompt_token_ids"] # type: ignore[typeddict-item]
pooling_request_output = PoolingRequestOutput( pooling_request_output = PoolingRequestOutput(
request_id=aggregator["request_id"], request_id=aggregator["request_id"],
......
...@@ -3,11 +3,10 @@ ...@@ -3,11 +3,10 @@
import time import time
from typing import Any, Generic, TypeAlias, TypeVar from typing import Any, Generic, TypeAlias, TypeVar
from pydantic import ( from pydantic import Field
Field,
)
from vllm import PoolingParams from vllm import PoolingParams
from vllm.config import ModelConfig
from vllm.config.pooler import get_use_activation from vllm.config.pooler import get_use_activation
from vllm.entrypoints.openai.engine.protocol import OpenAIBaseModel, UsageInfo from vllm.entrypoints.openai.engine.protocol import OpenAIBaseModel, UsageInfo
from vllm.entrypoints.pooling.base.protocol import ( from vllm.entrypoints.pooling.base.protocol import (
...@@ -18,6 +17,7 @@ from vllm.entrypoints.pooling.base.protocol import ( ...@@ -18,6 +17,7 @@ from vllm.entrypoints.pooling.base.protocol import (
EncodingRequestMixin, EncodingRequestMixin,
PoolingBasicRequestMixin, PoolingBasicRequestMixin,
) )
from vllm.renderers import TokenizeParams
from vllm.tasks import PoolingTask from vllm.tasks import PoolingTask
from vllm.utils import random_uuid from vllm.utils import random_uuid
...@@ -30,6 +30,18 @@ class PoolingCompletionRequest( ...@@ -30,6 +30,18 @@ class PoolingCompletionRequest(
): ):
task: PoolingTask | None = None task: PoolingTask | None = None
def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
encoder_config = model_config.encoder_config or {}
return TokenizeParams(
max_total_tokens=model_config.max_model_len,
max_output_tokens=0,
truncate_prompt_tokens=self.truncate_prompt_tokens,
do_lower_case=encoder_config.get("do_lower_case", False),
add_special_tokens=self.add_special_tokens,
max_total_tokens_param="max_model_len",
)
def to_pooling_params(self): def to_pooling_params(self):
return PoolingParams( return PoolingParams(
truncate_prompt_tokens=self.truncate_prompt_tokens, truncate_prompt_tokens=self.truncate_prompt_tokens,
...@@ -48,6 +60,18 @@ class PoolingChatRequest( ...@@ -48,6 +60,18 @@ class PoolingChatRequest(
description=("Additional kwargs to pass to the HF processor."), description=("Additional kwargs to pass to the HF processor."),
) )
def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
encoder_config = model_config.encoder_config or {}
return TokenizeParams(
max_total_tokens=model_config.max_model_len,
max_output_tokens=0,
truncate_prompt_tokens=self.truncate_prompt_tokens,
do_lower_case=encoder_config.get("do_lower_case", False),
add_special_tokens=self.add_special_tokens,
max_total_tokens_param="max_model_len",
)
def to_pooling_params(self): def to_pooling_params(self):
return PoolingParams( return PoolingParams(
truncate_prompt_tokens=self.truncate_prompt_tokens, truncate_prompt_tokens=self.truncate_prompt_tokens,
......
...@@ -5,7 +5,7 @@ import asyncio ...@@ -5,7 +5,7 @@ import asyncio
import json import json
import time import time
from collections.abc import AsyncGenerator, Sequence from collections.abc import AsyncGenerator, Sequence
from typing import Final, cast from typing import Any, Final, cast
import jinja2 import jinja2
from fastapi import Request from fastapi import Request
...@@ -14,10 +14,7 @@ from typing_extensions import assert_never ...@@ -14,10 +14,7 @@ from typing_extensions import assert_never
from vllm.engine.protocol import EngineClient from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.engine.protocol import ( from vllm.entrypoints.openai.engine.protocol import ErrorResponse, UsageInfo
ErrorResponse,
UsageInfo,
)
from vllm.entrypoints.openai.engine.serving import OpenAIServing from vllm.entrypoints.openai.engine.serving import OpenAIServing
from vllm.entrypoints.openai.models.serving import OpenAIServingModels from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.entrypoints.pooling.pooling.protocol import ( from vllm.entrypoints.pooling.pooling.protocol import (
...@@ -30,8 +27,6 @@ from vllm.entrypoints.pooling.pooling.protocol import ( ...@@ -30,8 +27,6 @@ from vllm.entrypoints.pooling.pooling.protocol import (
PoolingResponse, PoolingResponse,
PoolingResponseData, PoolingResponseData,
) )
from vllm.entrypoints.renderer import RenderConfig
from vllm.entrypoints.utils import _validate_truncation_size
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.outputs import PoolingRequestOutput from vllm.outputs import PoolingRequestOutput
from vllm.tasks import PoolingTask, SupportedTask from vllm.tasks import PoolingTask, SupportedTask
...@@ -99,11 +94,6 @@ class OpenAIServingPooling(OpenAIServing): ...@@ -99,11 +94,6 @@ class OpenAIServingPooling(OpenAIServing):
"dimensions is currently not supported" "dimensions is currently not supported"
) )
truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens", None)
truncate_prompt_tokens = _validate_truncation_size(
self.max_model_len, truncate_prompt_tokens
)
if is_io_processor_request: if is_io_processor_request:
if self.io_processor is None: if self.io_processor is None:
raise ValueError( raise ValueError(
...@@ -134,19 +124,16 @@ class OpenAIServingPooling(OpenAIServing): ...@@ -134,19 +124,16 @@ class OpenAIServingPooling(OpenAIServing):
_, engine_prompts = await self._preprocess_chat( _, engine_prompts = await self._preprocess_chat(
request, request,
self.renderer,
request.messages, request.messages,
chat_template=request.chat_template or self.chat_template, default_template=self.chat_template,
chat_template_content_format=self.chat_template_content_format, default_template_content_format=self.chat_template_content_format,
add_generation_prompt=request.add_generation_prompt, default_template_kwargs=None,
continue_final_message=request.continue_final_message,
add_special_tokens=request.add_special_tokens,
) )
elif isinstance(request, PoolingCompletionRequest): elif isinstance(request, PoolingCompletionRequest):
renderer = self._get_completion_renderer() engine_prompts = await self._preprocess_completion(
engine_prompts = await renderer.render_prompt( request,
prompt_or_prompts=request.input, prompt_input=request.input,
config=self._build_render_config(request), prompt_embeds=None,
) )
else: else:
raise ValueError(f"Unsupported request of type {type(request)}") raise ValueError(f"Unsupported request of type {type(request)}")
...@@ -207,11 +194,18 @@ class OpenAIServingPooling(OpenAIServing): ...@@ -207,11 +194,18 @@ class OpenAIServingPooling(OpenAIServing):
else await self._get_trace_headers(raw_request.headers) else await self._get_trace_headers(raw_request.headers)
) )
if is_io_processor_request:
tokenization_kwargs: dict[str, Any] = {}
else:
tok_params = request.build_tok_params(self.model_config) # type: ignore
tokenization_kwargs = tok_params.get_encode_kwargs()
generator = self.engine_client.encode( generator = self.engine_client.encode(
engine_prompt, engine_prompt,
pooling_params, pooling_params,
request_id_item, request_id_item,
lora_request=lora_request, lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
trace_headers=trace_headers, trace_headers=trace_headers,
priority=request.priority, priority=request.priority,
) )
...@@ -338,10 +332,3 @@ class OpenAIServingPooling(OpenAIServing): ...@@ -338,10 +332,3 @@ class OpenAIServingPooling(OpenAIServing):
return encode_bytes(bytes_only=encoding_format == "bytes_only") return encode_bytes(bytes_only=encoding_format == "bytes_only")
else: else:
assert_never(encoding_format) assert_never(encoding_format)
def _build_render_config(self, request: PoolingCompletionRequest) -> RenderConfig:
return RenderConfig(
max_length=self.max_model_len,
truncate_prompt_tokens=request.truncate_prompt_tokens,
add_special_tokens=request.add_special_tokens,
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment