Unverified Commit f0a1c845 authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Frontend] Use new Renderer for Completions and Tokenize API (#32863)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent 8980001c
......@@ -5,65 +5,10 @@ import pytest
from vllm.config import ModelConfig
from vllm.inputs import zip_enc_dec_prompts
from vllm.inputs.parse import parse_raw_prompts
from vllm.inputs.preprocess import InputPreprocessor
pytestmark = pytest.mark.cpu_test
STRING_INPUTS = [
"",
"foo",
"foo bar",
"foo baz bar",
"foo bar qux baz",
]
TOKEN_INPUTS = [
[-1],
[1],
[1, 2],
[1, 3, 4],
[1, 2, 4, 3],
]
INPUTS_SLICES = [
slice(None, None, -1),
slice(None, None, 2),
slice(None, None, -2),
]
# Test that a nested mixed-type list of lists raises a TypeError.
@pytest.mark.parametrize("invalid_input", [[[1, 2], ["foo", "bar"]]])
def test_invalid_input_raise_type_error(invalid_input):
with pytest.raises(TypeError):
parse_raw_prompts(invalid_input)
def test_parse_raw_single_batch_empty():
with pytest.raises(ValueError, match="at least one prompt"):
parse_raw_prompts([])
with pytest.raises(ValueError, match="at least one prompt"):
parse_raw_prompts([[]])
@pytest.mark.parametrize("string_input", STRING_INPUTS)
def test_parse_raw_single_batch_string_consistent(string_input: str):
assert parse_raw_prompts(string_input) == parse_raw_prompts([string_input])
@pytest.mark.parametrize("token_input", TOKEN_INPUTS)
def test_parse_raw_single_batch_token_consistent(token_input: list[int]):
assert parse_raw_prompts(token_input) == parse_raw_prompts([token_input])
@pytest.mark.parametrize("inputs_slice", INPUTS_SLICES)
def test_parse_raw_single_batch_string_slice(inputs_slice: slice):
assert parse_raw_prompts(STRING_INPUTS)[inputs_slice] == parse_raw_prompts(
STRING_INPUTS[inputs_slice]
)
@pytest.mark.parametrize(
"mm_processor_kwargs,expected_mm_kwargs",
......
......@@ -768,7 +768,7 @@ class ModelConfig:
)
self.tokenizer = object_storage_tokenizer.dir
def _get_encoder_config(self):
def _get_encoder_config(self) -> dict[str, Any] | None:
model = self.model
if is_remote_gguf(model):
model, _ = split_remote_gguf(model)
......@@ -1918,7 +1918,7 @@ def _get_and_verify_max_len(
disable_sliding_window: bool,
sliding_window: int | None,
spec_target_max_model_len: int | None = None,
encoder_config: Any | None = None,
encoder_config: dict[str, Any] | None = None,
) -> int:
"""Get and verify the model's maximum length."""
(derived_max_model_len, max_len_key) = (
......
......@@ -72,14 +72,9 @@ class EngineClient(ABC):
lora_request: LoRARequest | None = None,
trace_headers: Mapping[str, str] | None = None,
priority: int = 0,
truncate_prompt_tokens: int | None = None,
tokenization_kwargs: dict[str, Any] | None = None,
) -> AsyncGenerator[PoolingRequestOutput, None]:
"""Generate outputs for a request from a pooling model.
NOTE: truncate_prompt_tokens is deprecated in v0.14.
TODO: Remove this argument in v0.15.
"""
"""Generate outputs for a request from a pooling model."""
...
@abstractmethod
......
......@@ -2,8 +2,9 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import itertools
import warnings
from collections.abc import Callable, Sequence
from typing import TYPE_CHECKING, Any, cast
from typing import TYPE_CHECKING, Any, TypeAlias, cast
import cloudpickle
import torch.nn as nn
......@@ -46,15 +47,17 @@ from vllm.entrypoints.pooling.score.utils import (
compress_token_type_ids,
get_score_prompt,
)
from vllm.entrypoints.utils import _validate_truncation_size, log_non_default_args
from vllm.entrypoints.utils import log_non_default_args
from vllm.inputs import (
DataPrompt,
EmbedsPrompt,
ExplicitEncoderDecoderPrompt,
PromptType,
SingletonPrompt,
TextPrompt,
TokensPrompt,
)
from vllm.inputs.parse import get_prompt_components
from vllm.inputs.parse import get_prompt_components, is_explicit_encoder_decoder_prompt
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.quantization import QuantizationMethods
......@@ -67,6 +70,7 @@ from vllm.outputs import (
)
from vllm.platforms import current_platform
from vllm.pooling_params import PoolingParams
from vllm.renderers import ChatParams, TokenizeParams, merge_kwargs
from vllm.sampling_params import BeamSearchParams, RequestOutputKind, SamplingParams
from vllm.tasks import PoolingTask
from vllm.tokenizers import TokenizerLike
......@@ -74,7 +78,6 @@ from vllm.tokenizers.mistral import MistralTokenizer
from vllm.usage.usage_lib import UsageContext
from vllm.utils.collection_utils import as_iter, is_list_of
from vllm.utils.counter import Counter
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.llm_engine import LLMEngine
from vllm.v1.sample.logits_processor import LogitsProcessor
......@@ -85,6 +88,9 @@ logger = init_logger(__name__)
_R = TypeVar("_R", default=Any)
EnginePrompt: TypeAlias = TextPrompt | TokensPrompt | EmbedsPrompt
EngineEncDecPrompt: TypeAlias = ExplicitEncoderDecoderPrompt[EnginePrompt, EnginePrompt]
class LLM:
"""An LLM for generating texts from given prompts and sampling parameters.
......@@ -372,6 +378,7 @@ class LLM:
use_tqdm: bool | Callable[..., tqdm] = True,
lora_request: list[LoRARequest] | LoRARequest | None = None,
priority: list[int] | None = None,
tokenization_kwargs: dict[str, Any] | None = None,
) -> list[RequestOutput]:
"""Generates the completions for the input prompts.
......@@ -398,15 +405,11 @@ class LLM:
If provided, must be a list of integers matching the length
of `prompts`, where each priority value corresponds to the prompt
at the same index.
tokenization_kwargs: Overrides for `tokenizer.encode`.
Returns:
A list of `RequestOutput` objects containing the
generated completions in the same order as the input prompts.
Note:
Using `prompts` and `prompt_token_ids` as keyword parameters is
considered legacy and may be deprecated in the future. You should
instead pass them via the `inputs` parameter.
"""
model_config = self.model_config
runner_type = model_config.runner_type
......@@ -418,17 +421,14 @@ class LLM:
)
if sampling_params is None:
# Use default sampling params.
sampling_params = self.get_default_sampling_params()
# Add any modality specific loras to the corresponding prompts
lora_request = self._get_modality_specific_lora_reqs(prompts, lora_request)
self._validate_and_add_requests(
prompts=prompts,
params=sampling_params,
use_tqdm=use_tqdm,
lora_request=lora_request,
lora_request=self._get_modality_specific_lora_reqs(prompts, lora_request),
tokenization_kwargs=tokenization_kwargs,
priority=priority,
)
......@@ -771,65 +771,169 @@ class LLM:
return outputs
def preprocess_chat(
def _get_cmpl_tok_params(self, tokenization_kwargs: dict[str, Any] | None):
model_config = self.model_config
encoder_config = model_config.encoder_config or {}
return TokenizeParams(
max_total_tokens=model_config.max_model_len,
do_lower_case=encoder_config.get("do_lower_case", False),
# For Whisper, special tokens should be provided by the user based
# on the task and language of their request. Also needed to avoid
# appending an EOS token to the prompt which disrupts generation.
add_special_tokens=not model_config.is_encoder_decoder,
).with_kwargs(tokenization_kwargs)
def _normalize_prompts(
self,
messages: list[ChatCompletionMessageParam]
prompts: PromptType | Sequence[PromptType],
) -> list[EnginePrompt | EngineEncDecPrompt]:
if isinstance(prompts, str):
prompts = TextPrompt(prompt=prompts)
return prompts if isinstance(prompts, Sequence) else [prompts] # type: ignore[return-value]
def _preprocess_cmpl_singleton(
self,
prompt: SingletonPrompt,
tok_params: TokenizeParams,
*,
tokenize: bool,
) -> EnginePrompt:
renderer = self.llm_engine.renderer
if not isinstance(prompt, dict):
prompt = renderer.render_completion(prompt)
return renderer.tokenize_prompt(prompt, tok_params) if tokenize else prompt
def _preprocess_cmpl_enc_dec(
self,
prompt: ExplicitEncoderDecoderPrompt,
tok_params: TokenizeParams,
) -> EngineEncDecPrompt:
enc_prompt = prompt["encoder_prompt"]
dec_prompt = prompt["decoder_prompt"]
return EngineEncDecPrompt(
encoder_prompt=self._preprocess_cmpl_singleton(
enc_prompt,
tok_params,
# TODO: Move multi-modal processor into tokenization
tokenize=not self.model_config.is_multimodal_model,
),
decoder_prompt=(
None
if dec_prompt is None
else self._preprocess_cmpl_singleton(
dec_prompt,
tok_params,
# TODO: Move multi-modal processor into tokenization
tokenize=not self.model_config.is_multimodal_model,
)
),
)
def _preprocess_completion(
self,
prompts: PromptType | Sequence[PromptType],
tokenization_kwargs: dict[str, Any] | None = None,
) -> list[EnginePrompt | EngineEncDecPrompt]:
"""
Convert prompt inputs from LLM APIs (other than [LLM.chat][]) into
a format that can be passed to `_add_request`.
Refer to [LLM.generate][] for a complete description of the arguments.
Returns:
A list of `TokensPrompts` objects containing the tokenized prompt
after chat template interpolation, and the raw multi-modal inputs.
"""
tok_params = self._get_cmpl_tok_params(tokenization_kwargs)
engine_prompts = list[EnginePrompt | EngineEncDecPrompt]()
for prompt in self._normalize_prompts(prompts):
if is_explicit_encoder_decoder_prompt(prompt):
engine_prompts.append(self._preprocess_cmpl_enc_dec(prompt, tok_params))
else:
# Some MM models have non-default `add_special_tokens`
# TODO: Move multi-modal processor into tokenization
engine_prompts.append(
self._preprocess_cmpl_singleton(
prompt,
tok_params,
tokenize=not self.model_config.is_multimodal_model,
)
)
return engine_prompts
def _normalize_conversations(
self,
conversations: list[ChatCompletionMessageParam]
| list[list[ChatCompletionMessageParam]],
) -> list[list[ChatCompletionMessageParam]]:
return conversations if is_list_of(conversations, list) else [conversations] # type: ignore[list-item,return-value]
def _get_chat_tok_params(self, tokenization_kwargs: dict[str, Any] | None):
model_config = self.model_config
encoder_config = model_config.encoder_config or {}
return TokenizeParams(
max_total_tokens=model_config.max_model_len,
do_lower_case=encoder_config.get("do_lower_case", False),
add_special_tokens=False,
).with_kwargs(tokenization_kwargs)
def _preprocess_chat(
self,
conversations: list[ChatCompletionMessageParam]
| list[list[ChatCompletionMessageParam]],
chat_template: str | None = None,
chat_template_content_format: ChatTemplateContentFormatOption = "auto",
chat_template_kwargs: dict[str, Any] | None = None,
add_generation_prompt: bool = True,
continue_final_message: bool = False,
tools: list[dict[str, Any]] | None = None,
chat_template_kwargs: dict[str, Any] | None = None,
tokenization_kwargs: dict[str, Any] | None = None,
mm_processor_kwargs: dict[str, Any] | None = None,
) -> list[TextPrompt | TokensPrompt]:
) -> list[EnginePrompt]:
"""
Generate prompt for a chat conversation. The pre-processed
prompt can then be used as input for the other LLM methods.
Convert a list of conversations into prompts so that they can then
be used as input for other LLM APIs.
Refer to [LLM.chat][] for a complete description of the arguments.
Refer to `chat` for a complete description of the arguments.
Returns:
A list of `TokensPrompts` objects containing the tokenized
prompt after chat template interpolation, and the
pre-processed multi-modal inputs.
A list of `TokensPrompts` objects containing the tokenized prompt
after chat template interpolation, and the raw multi-modal inputs.
"""
list_of_messages: list[list[ChatCompletionMessageParam]]
# Handle multi and single conversations
if is_list_of(messages, list):
# messages is list[list[...]]
list_of_messages = cast(list[list[ChatCompletionMessageParam]], messages)
else:
# messages is list[...]
list_of_messages = [cast(list[ChatCompletionMessageParam], messages)]
renderer = self.llm_engine.renderer
chat_template_kwargs = {
"chat_template": chat_template,
"add_generation_prompt": add_generation_prompt,
"continue_final_message": continue_final_message,
"tools": tools,
**(chat_template_kwargs or {}),
}
prompts = list[TextPrompt | TokensPrompt]()
for msgs in list_of_messages:
# NOTE: renderer.render_messages() currently doesn't
# handle mm_processor_kwargs, since there is no implementation in
# the chat message parsing for it.
_, prompt = renderer.render_messages(
msgs,
chat_template_content_format=chat_template_content_format,
**chat_template_kwargs,
)
chat_params = ChatParams(
chat_template=chat_template,
chat_template_content_format=chat_template_content_format,
chat_template_kwargs=merge_kwargs(
chat_template_kwargs,
dict(
add_generation_prompt=add_generation_prompt,
continue_final_message=continue_final_message,
tools=tools,
tokenize=isinstance(renderer.tokenizer, MistralTokenizer),
),
),
)
tok_params = self._get_chat_tok_params(tokenization_kwargs)
engine_prompts = list[EnginePrompt]()
for conversation in self._normalize_conversations(conversations):
_, in_prompt = renderer.render_messages(conversation, chat_params)
if mm_processor_kwargs is not None:
prompt["mm_processor_kwargs"] = mm_processor_kwargs
in_prompt["mm_processor_kwargs"] = mm_processor_kwargs
prompts.append(prompt)
engine_prompts.append(renderer.tokenize_prompt(in_prompt, tok_params))
return prompts
return engine_prompts
def chat(
self,
......@@ -844,6 +948,7 @@ class LLM:
continue_final_message: bool = False,
tools: list[dict[str, Any]] | None = None,
chat_template_kwargs: dict[str, Any] | None = None,
tokenization_kwargs: dict[str, Any] | None = None,
mm_processor_kwargs: dict[str, Any] | None = None,
) -> list[RequestOutput]:
"""
......@@ -889,22 +994,22 @@ class LLM:
`True` if `add_generation_prompt` is also `True`.
chat_template_kwargs: Additional kwargs to pass to the chat
template.
mm_processor_kwargs: Multimodal processor kwarg overrides for this
chat request. Only used for offline requests.
tokenization_kwargs: Overrides for `tokenizer.encode`.
mm_processor_kwargs: Overrides for `processor.__call__`.
Returns:
A list of `RequestOutput` objects containing the generated
responses in the same order as the input messages.
"""
prompts = self.preprocess_chat(
messages=messages,
prompts = self._preprocess_chat(
messages,
chat_template=chat_template,
chat_template_content_format=chat_template_content_format,
chat_template_kwargs=chat_template_kwargs,
add_generation_prompt=add_generation_prompt,
continue_final_message=continue_final_message,
tools=tools,
chat_template_kwargs=chat_template_kwargs,
tokenization_kwargs=tokenization_kwargs,
mm_processor_kwargs=mm_processor_kwargs,
)
......@@ -913,6 +1018,7 @@ class LLM:
sampling_params=sampling_params,
use_tqdm=use_tqdm,
lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
)
def encode(
......@@ -945,37 +1051,29 @@ class LLM:
If `False`, no progress bar is created.
lora_request: LoRA request to use for generation, if any.
pooling_task: Override the pooling task to use.
tokenization_kwargs: overrides tokenization_kwargs set in
pooling_params
tokenization_kwargs: Overrides for `tokenizer.encode`.
Returns:
A list of `PoolingRequestOutput` objects containing the
pooled hidden states in the same order as the input prompts.
Note:
Using `prompts` and `prompt_token_ids` as keyword parameters is
considered legacy and may be deprecated in the future. You should
instead pass them via the `inputs` parameter.
"""
error_str = (
"pooling_task required for `LLM.encode`\n"
"Please use one of the more specific methods or set the "
"pooling_task when using `LLM.encode`:\n"
" - For embeddings, use `LLM.embed(...)` "
'or `pooling_task="embed"`.\n'
" - For classification logits, use `LLM.classify(...)` "
'or `pooling_task="classify"`.\n'
" - For similarity scores, use `LLM.score(...)`.\n"
" - For rewards, use `LLM.reward(...)` "
'or `pooling_task="token_classify"`\n'
" - For token classification, "
'use `pooling_task="token_classify"`\n'
' - For multi-vector retrieval, use `pooling_task="token_embed"`'
)
if pooling_task is None:
raise ValueError(error_str)
raise ValueError(
"pooling_task required for `LLM.encode`\n"
"Please use one of the more specific methods or set the "
"pooling_task when using `LLM.encode`:\n"
" - For embeddings, use `LLM.embed(...)` "
'or `pooling_task="embed"`.\n'
" - For classification logits, use `LLM.classify(...)` "
'or `pooling_task="classify"`.\n'
" - For similarity scores, use `LLM.score(...)`.\n"
" - For rewards, use `LLM.reward(...)` "
'or `pooling_task="token_classify"`\n'
" - For token classification, "
'use `pooling_task="token_classify"`\n'
' - For multi-vector retrieval, use `pooling_task="token_embed"`'
)
model_config = self.model_config
runner_type = model_config.runner_type
......@@ -986,6 +1084,20 @@ class LLM:
"pooling model."
)
if truncate_prompt_tokens is not None:
warnings.warn(
"The `truncate_prompt_tokens` parameter in `LLM.encode()` "
"is deprecated and will be removed in v0.16. "
"Please pass it via `tokenization_kwargs` instead.",
DeprecationWarning,
stacklevel=2,
)
tokenization_kwargs = merge_kwargs(
tokenization_kwargs,
dict(truncate_prompt_tokens=truncate_prompt_tokens),
)
io_processor_prompt = False
if isinstance(prompts, dict) and "data" in prompts:
io_processor_prompt = True
......@@ -1017,19 +1129,16 @@ class LLM:
pooling_params = self.io_processor.validate_or_generate_params(
pooling_params
)
else:
if pooling_params is None:
# Use default pooling params.
pooling_params = PoolingParams()
if pooling_params is None:
# Use default pooling params.
pooling_params = PoolingParams()
if pooling_task not in self.supported_tasks:
raise ValueError(f"pooling_task must be one of {self.supported_tasks}.")
for param in as_iter(pooling_params):
param.verify(pooling_task, model_config)
# for backwards compatibility
if truncate_prompt_tokens is not None:
param.truncate_prompt_tokens = truncate_prompt_tokens
self._validate_and_add_requests(
prompts=prompts,
......@@ -1094,6 +1203,7 @@ class LLM:
it is used to create the progress bar.
If `False`, no progress bar is created.
lora_request: LoRA request to use for generation, if any.
tokenization_kwargs: Overrides for `tokenizer.encode`.
Returns:
A list of `EmbeddingRequestOutput` objects containing the
......@@ -1105,9 +1215,14 @@ class LLM:
"Try converting the model using `--convert embed`."
)
if truncate_prompt_tokens is not None:
tokenization_kwargs = merge_kwargs(
tokenization_kwargs,
dict(truncate_prompt_tokens=truncate_prompt_tokens),
)
items = self.encode(
prompts,
truncate_prompt_tokens=truncate_prompt_tokens,
use_tqdm=use_tqdm,
pooling_params=pooling_params,
lora_request=lora_request,
......@@ -1121,8 +1236,8 @@ class LLM:
self,
prompts: PromptType | Sequence[PromptType],
*,
use_tqdm: bool | Callable[..., tqdm] = True,
pooling_params: PoolingParams | Sequence[PoolingParams] | None = None,
use_tqdm: bool | Callable[..., tqdm] = True,
lora_request: list[LoRARequest] | LoRARequest | None = None,
tokenization_kwargs: dict[str, Any] | None = None,
) -> list[ClassificationRequestOutput]:
......@@ -1137,13 +1252,15 @@ class LLM:
prompts: The prompts to the LLM. You may pass a sequence of prompts
for batch inference. See [PromptType][vllm.inputs.PromptType]
for more details about the format of each prompt.
pooling_params: The pooling parameters for pooling. If None, we
use the default pooling parameters.
use_tqdm: If `True`, shows a tqdm progress bar.
If a callable (e.g., `functools.partial(tqdm, leave=False)`),
it is used to create the progress bar.
If `False`, no progress bar is created.
lora_request: LoRA request to use for generation, if any.
pooling_params: The pooling parameters for pooling. If None, we
use the default pooling parameters.
tokenization_kwargs: Overrides for `tokenizer.encode`.
Returns:
A list of `ClassificationRequestOutput` objects containing the
embedding vectors in the same order as the input prompts.
......@@ -1170,9 +1287,9 @@ class LLM:
prompts: PromptType | Sequence[PromptType],
/,
*,
pooling_params: PoolingParams | Sequence[PoolingParams] | None = None,
truncate_prompt_tokens: int | None = None,
use_tqdm: bool | Callable[..., tqdm] = True,
pooling_params: PoolingParams | Sequence[PoolingParams] | None = None,
lora_request: list[LoRARequest] | LoRARequest | None = None,
tokenization_kwargs: dict[str, Any] | None = None,
) -> list[PoolingRequestOutput]:
......@@ -1183,13 +1300,15 @@ class LLM:
prompts: The prompts to the LLM. You may pass a sequence of prompts
for batch inference. See [PromptType][vllm.inputs.PromptType]
for more details about the format of each prompt.
pooling_params: The pooling parameters for pooling. If None, we
use the default pooling parameters.
use_tqdm: If `True`, shows a tqdm progress bar.
If a callable (e.g., `functools.partial(tqdm, leave=False)`),
it is used to create the progress bar.
If `False`, no progress bar is created.
lora_request: LoRA request to use for generation, if any.
pooling_params: The pooling parameters for pooling. If None, we
use the default pooling parameters.
tokenization_kwargs: Overrides for `tokenizer.encode`.
Returns:
A list of `PoolingRequestOutput` objects containing the
pooled hidden states in the same order as the input prompts.
......@@ -1207,18 +1326,18 @@ class LLM:
def _embedding_score(
self,
tokenizer: TokenizerLike,
text_1: list[str | TextPrompt | TokensPrompt],
text_2: list[str | TextPrompt | TokensPrompt],
truncate_prompt_tokens: int | None = None,
use_tqdm: bool | Callable[..., tqdm] = True,
pooling_params: PoolingParams | None = None,
lora_request: list[LoRARequest] | LoRARequest | None = None,
tokenization_kwargs: dict[str, Any] | None = None,
text_1: list[SingletonPrompt],
text_2: list[SingletonPrompt],
*,
use_tqdm: bool | Callable[..., tqdm],
pooling_params: PoolingParams | None,
lora_request: list[LoRARequest] | LoRARequest | None,
tokenization_kwargs: dict[str, Any],
) -> list[ScoringRequestOutput]:
encoded_output: list[PoolingRequestOutput] = self.encode(
tokenizer = self.get_tokenizer()
encoded_output = self.encode(
text_1 + text_2,
truncate_prompt_tokens=truncate_prompt_tokens,
use_tqdm=use_tqdm,
lora_request=lora_request,
pooling_params=pooling_params,
......@@ -1226,14 +1345,16 @@ class LLM:
tokenization_kwargs=tokenization_kwargs,
)
encoded_output_1: list[PoolingRequestOutput] = encoded_output[0 : len(text_1)]
encoded_output_2: list[PoolingRequestOutput] = encoded_output[len(text_1) :]
encoded_output_1 = encoded_output[0 : len(text_1)]
encoded_output_2 = encoded_output[len(text_1) :]
if len(encoded_output_1) == 1:
encoded_output_1 = encoded_output_1 * len(encoded_output_2)
scores = _cosine_similarity(
tokenizer=tokenizer, embed_1=encoded_output_1, embed_2=encoded_output_2
tokenizer=tokenizer,
embed_1=encoded_output_1,
embed_2=encoded_output_2,
)
items = self.engine_class.validate_outputs(scores, PoolingRequestOutput)
......@@ -1241,17 +1362,17 @@ class LLM:
def _cross_encoding_score(
self,
tokenizer: TokenizerLike,
data_1: list[str] | list[ScoreContentPartParam],
data_2: list[str] | list[ScoreContentPartParam],
truncate_prompt_tokens: int | None = None,
use_tqdm: bool | Callable[..., tqdm] = True,
pooling_params: PoolingParams | None = None,
lora_request: list[LoRARequest] | LoRARequest | None = None,
tokenization_kwargs: dict[str, Any] | None = None,
score_template: str | None = None,
*,
use_tqdm: bool | Callable[..., tqdm],
pooling_params: PoolingParams | None,
lora_request: list[LoRARequest] | LoRARequest | None,
tokenization_kwargs: dict[str, Any],
score_template: str | None,
) -> list[ScoringRequestOutput]:
model_config = self.model_config
tokenizer = self.get_tokenizer()
if isinstance(tokenizer, MistralTokenizer):
raise ValueError("Score API is not supported for Mistral tokenizer")
......@@ -1265,13 +1386,6 @@ class LLM:
pooling_params.verify("score", model_config)
pooling_params_list = list[PoolingParams]()
local_kwargs = tokenization_kwargs or {}
tokenization_kwargs = local_kwargs.copy()
_validate_truncation_size(
model_config.max_model_len, truncate_prompt_tokens, tokenization_kwargs
)
prompts = list[PromptType]()
input_pairs = [(t1, t2) for t1, t2 in zip(data_1, data_2)]
......@@ -1314,10 +1428,10 @@ class LLM:
data_2: SingletonPrompt | Sequence[SingletonPrompt] | ScoreMultiModalParam,
/,
*,
truncate_prompt_tokens: int | None = None,
use_tqdm: bool | Callable[..., tqdm] = True,
pooling_params: PoolingParams | None = None,
lora_request: list[LoRARequest] | LoRARequest | None = None,
tokenization_kwargs: dict[str, Any] | None = None,
chat_template: str | None = None,
) -> list[ScoringRequestOutput]:
"""Generate similarity scores for all pairs `<text,text_pair>` or
......@@ -1344,20 +1458,22 @@ class LLM:
the LLM. Can be text or multi-modal data. See [PromptType]
[vllm.inputs.PromptType] for more details about the format of
each prompt.
pooling_params: The pooling parameters for pooling. If None, we
use the default pooling parameters.
use_tqdm: If `True`, shows a tqdm progress bar.
If a callable (e.g., `functools.partial(tqdm, leave=False)`),
it is used to create the progress bar.
If `False`, no progress bar is created.
lora_request: LoRA request to use for generation, if any.
pooling_params: The pooling parameters for pooling. If None, we
use the default pooling parameters.
chat_template: The chat template to use for the scoring. If None, we
use the model's default chat template.
tokenization_kwargs: Overrides for `tokenizer.encode`.
Returns:
A list of `ScoringRequestOutput` objects containing the
generated scores in the same order as the input prompts.
"""
model_config = self.model_config
runner_type = model_config.runner_type
if runner_type != "pooling":
raise ValueError(
......@@ -1445,26 +1561,27 @@ class LLM:
_validate_score_input_lens(data_1, data_2) # type: ignore[arg-type]
tok_params = self._get_cmpl_tok_params(tokenization_kwargs)
encode_kwargs = tok_params.get_encode_kwargs()
if model_config.is_cross_encoder:
return self._cross_encoding_score(
tokenizer,
data_1, # type: ignore[arg-type]
data_2, # type: ignore[arg-type]
truncate_prompt_tokens,
use_tqdm,
pooling_params,
lora_request,
use_tqdm=use_tqdm,
pooling_params=pooling_params,
lora_request=lora_request,
tokenization_kwargs=encode_kwargs,
score_template=chat_template,
)
else:
return self._embedding_score(
tokenizer,
data_1, # type: ignore[arg-type]
data_2, # type: ignore[arg-type]
truncate_prompt_tokens,
use_tqdm,
pooling_params,
lora_request,
use_tqdm=use_tqdm,
pooling_params=pooling_params,
lora_request=lora_request,
tokenization_kwargs=encode_kwargs,
)
def start_profile(self) -> None:
......@@ -1530,42 +1647,79 @@ class LLM:
def _validate_and_add_requests(
self,
prompts: PromptType | Sequence[PromptType] | DataPrompt,
prompts: PromptType | Sequence[PromptType],
params: SamplingParams
| Sequence[SamplingParams]
| PoolingParams
| Sequence[PoolingParams],
*,
use_tqdm: bool | Callable[..., tqdm] = True,
lora_request: Sequence[LoRARequest] | LoRARequest | None,
priority: list[int] | None = None,
lora_request: Sequence[LoRARequest | None] | LoRARequest | None,
tokenization_kwargs: dict[str, Any] | None = None,
priority: list[int] | None = None,
) -> None:
if isinstance(prompts, (str, dict)):
# Convert a single prompt to a list.
prompts = [prompts] # type: ignore[list-item]
num_requests = len(prompts)
if isinstance(params, Sequence) and len(params) != num_requests:
raise ValueError("The lengths of prompts and params must be the same.")
if isinstance(lora_request, Sequence) and len(lora_request) != num_requests:
raise ValueError(
"The lengths of prompts and lora_request must be the same."
)
if priority is not None and len(priority) != num_requests:
raise ValueError(
"The lengths of prompts "
f"({num_requests}) and priority ({len(priority)}) "
"must be the same."
in_prompts = self._normalize_prompts(prompts)
num_requests = len(in_prompts)
if isinstance(params, Sequence):
if len(params) != num_requests:
raise ValueError(
f"The lengths of prompts ({params}) "
f"and lora_request ({len(params)}) must be the same."
)
engine_params = params
else:
engine_params = [params] * num_requests
if isinstance(lora_request, Sequence):
if len(lora_request) != num_requests:
raise ValueError(
f"The lengths of prompts ({num_requests}) "
f"and lora_request ({len(lora_request)}) must be the same."
)
engine_lora_requests: Sequence[LoRARequest | None] = lora_request
else:
engine_lora_requests = [lora_request] * num_requests
if priority is not None:
if len(priority) != num_requests:
raise ValueError(
f"The lengths of prompts ({num_requests}) "
f"and priority ({len(priority)}) must be the same."
)
else:
priority = [0] * num_requests
if any(param.truncate_prompt_tokens is not None for param in engine_params):
# TODO: Remove this after deprecating `param.truncate_prompt_tokens`
# Then, move the code from the `else` block to the top and let
# `self._preprocess_completion` handle prompt normalization
engine_prompts = [
engine_prompt
for in_prompt, param in zip(in_prompts, engine_params)
for engine_prompt in self._preprocess_completion(
[in_prompt],
tokenization_kwargs=merge_kwargs(
tokenization_kwargs,
dict(truncate_prompt_tokens=param.truncate_prompt_tokens),
),
)
]
else:
engine_prompts = self._preprocess_completion(
in_prompts,
tokenization_kwargs=tokenization_kwargs,
)
for sp in params if isinstance(params, Sequence) else (params,):
for sp in engine_params:
if isinstance(sp, SamplingParams):
# We only care about the final output
sp.output_kind = RequestOutputKind.FINAL_ONLY
# Add requests to the engine.
it = prompts
it = engine_prompts
if use_tqdm:
tqdm_func = use_tqdm if callable(use_tqdm) else tqdm
it = tqdm_func(it, desc="Adding requests")
......@@ -1576,12 +1730,10 @@ class LLM:
for i, prompt in enumerate(it):
request_id = self._add_request(
prompt,
params[i] if isinstance(params, Sequence) else params,
lora_request=lora_request[i]
if isinstance(lora_request, Sequence)
else lora_request,
priority=priority[i] if priority else 0,
engine_params[i],
lora_request=engine_lora_requests[i],
tokenization_kwargs=tokenization_kwargs,
priority=priority[i],
)
added_request_ids.append(request_id)
except Exception as e:
......@@ -1589,54 +1741,42 @@ class LLM:
self.llm_engine.abort_request(added_request_ids, internal=True)
raise e
def _process_inputs(
self,
request_id: str,
engine_prompt: PromptType,
params: SamplingParams | PoolingParams,
*,
lora_request: LoRARequest | None,
priority: int,
tokenization_kwargs: dict[str, Any] | None = None,
) -> tuple[EngineCoreRequest, dict[str, Any]]:
"""Use the Processor to process inputs for LLMEngine."""
local_kwargs = tokenization_kwargs or {}
tokenization_kwargs = local_kwargs.copy()
_validate_truncation_size(
self.model_config.max_model_len,
params.truncate_prompt_tokens,
tokenization_kwargs,
)
engine_request = self.input_processor.process_inputs(
request_id,
engine_prompt,
params,
lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
priority=priority,
)
return engine_request, tokenization_kwargs
def _add_request(
self,
prompt: PromptType,
params: SamplingParams | PoolingParams,
lora_request: LoRARequest | None = None,
priority: int = 0,
tokenization_kwargs: dict[str, Any] | None = None,
priority: int = 0,
) -> str:
prompt_text, _, _ = get_prompt_components(prompt)
request_id = str(next(self.request_counter))
engine_request, tokenization_kwargs = self._process_inputs(
if params.truncate_prompt_tokens is not None:
params_type = type(params).__name__
warnings.warn(
f"The `truncate_prompt_tokens` parameter in `{params_type}` "
"is deprecated and will be removed in v0.16. "
"Please pass it via `tokenization_kwargs` instead.",
DeprecationWarning,
stacklevel=2,
)
tokenization_kwargs = merge_kwargs(
tokenization_kwargs,
dict(truncate_prompt_tokens=params.truncate_prompt_tokens),
)
tok_params = self._get_cmpl_tok_params(tokenization_kwargs)
tokenization_kwargs = tok_params.get_encode_kwargs()
engine_request = self.input_processor.process_inputs(
request_id,
prompt,
params,
lora_request=lora_request,
priority=priority,
tokenization_kwargs=tokenization_kwargs,
priority=priority,
)
self.llm_engine.add_request(
......
......@@ -13,12 +13,13 @@ from openai.types.chat.chat_completion_audio import (
ChatCompletionAudio as OpenAIChatCompletionAudio,
)
from openai.types.chat.chat_completion_message import Annotation as OpenAIAnnotation
from pydantic import (
Field,
model_validator,
)
from pydantic import Field, model_validator
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
from vllm.config import ModelConfig
from vllm.entrypoints.chat_utils import (
ChatCompletionMessageParam,
ChatTemplateContentFormatOption,
)
from vllm.entrypoints.openai.engine.protocol import (
AnyResponseFormat,
DeltaMessage,
......@@ -36,6 +37,7 @@ from vllm.entrypoints.openai.engine.protocol import (
from vllm.exceptions import VLLMValidationError
from vllm.logger import init_logger
from vllm.logprobs import Logprob
from vllm.renderers import ChatParams, TokenizeParams, merge_kwargs
from vllm.sampling_params import (
BeamSearchParams,
RequestOutputKind,
......@@ -348,6 +350,43 @@ class ChatCompletionRequest(OpenAIBaseModel):
# --8<-- [end:chat-completion-extra-params]
def build_chat_params(
self,
default_template: str | None,
default_template_content_format: ChatTemplateContentFormatOption,
) -> ChatParams:
return ChatParams(
chat_template=self.chat_template or default_template,
chat_template_content_format=default_template_content_format,
chat_template_kwargs=merge_kwargs(
self.chat_template_kwargs,
dict(
add_generation_prompt=self.add_generation_prompt,
continue_final_message=self.continue_final_message,
documents=self.documents,
reasoning_effort=self.reasoning_effort,
),
),
)
def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
if self.max_completion_tokens is not None:
max_output_tokens: int | None = self.max_completion_tokens
max_output_tokens_param = "max_completion_tokens"
else:
max_output_tokens = self.max_tokens
max_output_tokens_param = "max_tokens"
return TokenizeParams(
max_total_tokens=model_config.max_model_len,
max_output_tokens=max_output_tokens or 0,
truncate_prompt_tokens=self.truncate_prompt_tokens,
add_special_tokens=self.add_special_tokens,
needs_detokenization=bool(self.echo and not self.return_token_ids),
max_total_tokens_param="max_model_len",
max_output_tokens_param=max_output_tokens_param,
)
# Default sampling parameters for chat completion requests
_DEFAULT_SAMPLING_PARAMS: dict = {
"repetition_penalty": 1.0,
......
......@@ -67,7 +67,7 @@ from vllm.entrypoints.openai.parser.harmony_utils import (
)
from vllm.entrypoints.openai.utils import maybe_filter_parallel_tool_calls
from vllm.entrypoints.utils import get_max_tokens, should_include_usage
from vllm.inputs.data import TokensPrompt
from vllm.inputs.data import EmbedsPrompt, TokensPrompt
from vllm.inputs.parse import get_prompt_components
from vllm.logger import init_logger
from vllm.logprobs import Logprob
......@@ -185,8 +185,6 @@ class OpenAIServingChat(OpenAIServing):
start_time = time.perf_counter()
try:
renderer = self.engine_client.renderer
# Create a minimal dummy request
dummy_request = ChatCompletionRequest(
messages=[{"role": "user", "content": "warmup"}],
......@@ -201,18 +199,10 @@ class OpenAIServingChat(OpenAIServing):
# 3. Tokenizer initialization for chat
await self._preprocess_chat(
dummy_request,
renderer,
dummy_request.messages,
chat_template=self.chat_template,
chat_template_content_format=self.chat_template_content_format,
add_generation_prompt=True,
continue_final_message=False,
tool_dicts=None,
documents=None,
chat_template_kwargs=None,
default_chat_template_kwargs=self.default_chat_template_kwargs,
tool_parser=None,
add_special_tokens=False,
default_template=self.chat_template,
default_template_content_format=self.chat_template_content_format,
default_template_kwargs=self.default_chat_template_kwargs,
)
elapsed = (time.perf_counter() - start_time) * 1000
......@@ -225,7 +215,10 @@ class OpenAIServingChat(OpenAIServing):
async def render_chat_request(
self,
request: ChatCompletionRequest,
) -> tuple[list[ConversationMessage], list[Any]] | ErrorResponse:
) -> (
tuple[list[ConversationMessage], list[TokensPrompt | EmbedsPrompt]]
| ErrorResponse
):
"""
render chat request by validating and preprocessing inputs.
......@@ -302,23 +295,14 @@ class OpenAIServingChat(OpenAIServing):
if error_check_ret is not None:
return error_check_ret
chat_template_kwargs = request.chat_template_kwargs or {}
chat_template_kwargs.update(reasoning_effort=request.reasoning_effort)
conversation, engine_prompts = await self._preprocess_chat(
request,
renderer,
request.messages,
chat_template=request.chat_template or self.chat_template,
chat_template_content_format=self.chat_template_content_format,
add_generation_prompt=request.add_generation_prompt,
continue_final_message=request.continue_final_message,
default_template=self.chat_template,
default_template_content_format=self.chat_template_content_format,
default_template_kwargs=self.default_chat_template_kwargs,
tool_dicts=tool_dicts,
documents=request.documents,
chat_template_kwargs=chat_template_kwargs,
default_chat_template_kwargs=self.default_chat_template_kwargs,
tool_parser=tool_parser,
add_special_tokens=request.add_special_tokens,
)
else:
# For GPT-OSS.
......@@ -428,11 +412,15 @@ class OpenAIServingChat(OpenAIServing):
trace_headers=trace_headers,
)
else:
engine_request, tokenization_kwargs = await self._process_inputs(
tok_params = request.build_tok_params(self.model_config)
tokenization_kwargs = tok_params.get_encode_kwargs()
engine_request = self.input_processor.process_inputs(
sub_request_id,
engine_prompt,
sampling_params,
lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
trace_headers=trace_headers,
priority=request.priority,
data_parallel_rank=data_parallel_rank,
......
......@@ -9,11 +9,9 @@ from dataclasses import replace
from typing import Annotated, Any, Literal
import torch
from pydantic import (
Field,
model_validator,
)
from pydantic import Field, model_validator
from vllm.config import ModelConfig
from vllm.entrypoints.openai.engine.protocol import (
AnyResponseFormat,
LegacyStructuralTagResponseFormat,
......@@ -27,6 +25,7 @@ from vllm.entrypoints.openai.engine.protocol import (
from vllm.exceptions import VLLMValidationError
from vllm.logger import init_logger
from vllm.logprobs import Logprob
from vllm.renderers import TokenizeParams
from vllm.sampling_params import (
BeamSearchParams,
RequestOutputKind,
......@@ -178,6 +177,17 @@ class CompletionRequest(OpenAIBaseModel):
# --8<-- [end:completion-extra-params]
def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
return TokenizeParams(
max_total_tokens=model_config.max_model_len,
max_output_tokens=self.max_tokens or 0,
truncate_prompt_tokens=self.truncate_prompt_tokens,
add_special_tokens=self.add_special_tokens,
needs_detokenization=bool(self.echo and not self.return_token_ids),
max_total_tokens_param="max_model_len",
max_output_tokens_param="max_tokens",
)
# Default sampling parameters for completion requests
_DEFAULT_SAMPLING_PARAMS: dict = {
"repetition_penalty": 1.0,
......
......@@ -32,7 +32,6 @@ from vllm.entrypoints.openai.engine.serving import (
clamp_prompt_logprobs,
)
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.entrypoints.renderer import RenderConfig
from vllm.entrypoints.utils import get_max_tokens, should_include_usage
from vllm.exceptions import VLLMValidationError
from vllm.inputs.data import EmbedsPrompt, TokensPrompt, is_embeds_prompt
......@@ -111,11 +110,10 @@ class OpenAIServingCompletion(OpenAIServing):
)
try:
renderer = self._get_completion_renderer()
engine_prompts = await renderer.render_prompt_and_embeds(
prompt_or_prompts=request.prompt,
engine_prompts = await self._preprocess_completion(
request,
prompt_input=request.prompt,
prompt_embeds=request.prompt_embeds,
config=self._build_render_config(request),
)
except (ValueError, TypeError, RuntimeError, jinja2.TemplateError) as e:
logger.exception("Error in preprocessing prompt inputs")
......@@ -203,10 +201,6 @@ class OpenAIServingCompletion(OpenAIServing):
else await self._get_trace_headers(raw_request.headers)
)
# Mypy inconsistently requires this second cast in different
# environments. It shouldn't be necessary (redundant from above)
# but pre-commit in CI fails without it.
engine_prompt = cast(EmbedsPrompt | TokensPrompt, engine_prompt)
if isinstance(sampling_params, BeamSearchParams):
generator = self.beam_search(
prompt=engine_prompt,
......@@ -216,11 +210,15 @@ class OpenAIServingCompletion(OpenAIServing):
trace_headers=trace_headers,
)
else:
engine_request, tokenization_kwargs = await self._process_inputs(
tok_params = request.build_tok_params(self.model_config)
tokenization_kwargs = tok_params.get_encode_kwargs()
engine_request = self.input_processor.process_inputs(
request_id_item,
engine_prompt,
sampling_params,
lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
trace_headers=trace_headers,
priority=request.priority,
data_parallel_rank=data_parallel_rank,
......@@ -709,26 +707,3 @@ class OpenAIServingCompletion(OpenAIServing):
tokens=out_tokens,
top_logprobs=out_top_logprobs,
)
def _build_render_config(
self,
request: CompletionRequest,
max_input_length: int | None = None,
) -> RenderConfig:
# Validate max_tokens before using it
if request.max_tokens is not None and request.max_tokens > self.max_model_len:
raise VLLMValidationError(
f"'max_tokens' ({request.max_tokens}) cannot be greater than "
f"the model's maximum context length ({self.max_model_len}).",
parameter="max_tokens",
value=request.max_tokens,
)
max_input_tokens_len = self.max_model_len - (request.max_tokens or 0)
return RenderConfig(
max_length=max_input_tokens_len,
truncate_prompt_tokens=request.truncate_prompt_tokens,
add_special_tokens=request.add_special_tokens,
cache_salt=request.cache_salt,
needs_detokenization=bool(request.echo and not request.return_token_ids),
)
......@@ -16,9 +16,7 @@ from pydantic import (
from vllm.entrypoints.chat_utils import make_tool_call_id
from vllm.logger import init_logger
from vllm.sampling_params import (
SamplingParams,
)
from vllm.sampling_params import SamplingParams
from vllm.utils import random_uuid
from vllm.utils.import_utils import resolve_obj_by_qualname
......
......@@ -5,10 +5,10 @@ import json
import sys
import time
import traceback
from collections.abc import AsyncGenerator, Callable, Iterable, Mapping
from collections.abc import AsyncGenerator, Callable, Mapping
from dataclasses import dataclass, field
from http import HTTPStatus
from typing import Any, ClassVar, Generic, TypeAlias, TypeVar, cast
from typing import Any, ClassVar, Generic, Protocol, TypeAlias, TypeVar
import numpy as np
from fastapi import Request
......@@ -20,6 +20,7 @@ from starlette.datastructures import Headers
import vllm.envs as envs
from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function
from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import (
ChatCompletionMessageParam,
......@@ -86,7 +87,6 @@ from vllm.entrypoints.pooling.score.protocol import (
ScoreResponse,
ScoreTextRequest,
)
from vllm.entrypoints.renderer import BaseRenderer, CompletionRenderer, RenderConfig
from vllm.entrypoints.serve.disagg.protocol import GenerateRequest, GenerateResponse
from vllm.entrypoints.serve.tokenize.protocol import (
DetokenizeRequest,
......@@ -94,13 +94,9 @@ from vllm.entrypoints.serve.tokenize.protocol import (
TokenizeCompletionRequest,
TokenizeResponse,
)
from vllm.entrypoints.utils import (
_validate_truncation_size,
get_max_tokens,
sanitize_message,
)
from vllm.entrypoints.utils import get_max_tokens, sanitize_message
from vllm.exceptions import VLLMValidationError
from vllm.inputs.data import PromptType, TokensPrompt
from vllm.inputs.data import EmbedsPrompt, PromptType, TokensPrompt
from vllm.inputs.parse import (
get_prompt_components,
is_explicit_encoder_decoder_prompt,
......@@ -112,7 +108,7 @@ from vllm.multimodal import MultiModalDataDict
from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.reasoning import ReasoningParser, ReasoningParserManager
from vllm.renderers import RendererLike
from vllm.renderers import ChatParams, TokenizeParams, merge_kwargs
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers import ToolParser, ToolParserManager
......@@ -123,11 +119,9 @@ from vllm.tracing import (
)
from vllm.utils import random_uuid
from vllm.utils.async_utils import (
AsyncMicrobatchTokenizer,
collect_from_async_generator,
merge_async_iterators,
)
from vllm.v1.engine import EngineCoreRequest
class GenerationError(Exception):
......@@ -140,6 +134,21 @@ class GenerationError(Exception):
logger = init_logger(__name__)
class RendererRequest(Protocol):
def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
raise NotImplementedError
class RendererChatRequest(RendererRequest, Protocol):
def build_chat_params(
self,
default_template: str | None,
default_template_content_format: ChatTemplateContentFormatOption,
) -> ChatParams:
raise NotImplementedError
CompletionLikeRequest: TypeAlias = (
CompletionRequest
| TokenizeCompletionRequest
......@@ -158,7 +167,9 @@ ChatLikeRequest: TypeAlias = (
| ClassificationChatRequest
| PoolingChatRequest
)
SpeechToTextRequest: TypeAlias = TranscriptionRequest | TranslationRequest
AnyRequest: TypeAlias = (
CompletionLikeRequest
| ChatLikeRequest
......@@ -193,7 +204,7 @@ class ServeContext(Generic[RequestT]):
request_id: str
created_time: int = field(default_factory=lambda: int(time.time()))
lora_request: LoRARequest | None = None
engine_prompts: list[TokensPrompt] | None = None
engine_prompts: list[TokensPrompt | EmbedsPrompt] | None = None
result_generator: AsyncGenerator[tuple[int, PoolingRequestOutput], None] | None = (
None
......@@ -227,7 +238,6 @@ class OpenAIServing:
self.request_logger = request_logger
self.return_tokens_as_token_ids = return_tokens_as_token_ids
self._async_tokenizer_pool: dict[TokenizerLike, AsyncMicrobatchTokenizer] = {}
self.log_error_stack = log_error_stack
self.input_processor = self.models.input_processor
......@@ -519,41 +529,6 @@ class OpenAIServing:
prompt_logprobs=None,
)
def _get_completion_renderer(self) -> BaseRenderer:
"""
Get a Renderer instance with the provided tokenizer.
Uses shared async tokenizer pool for efficiency.
"""
return CompletionRenderer(
model_config=self.model_config,
tokenizer=self.renderer.tokenizer,
async_tokenizer_pool=self._async_tokenizer_pool,
)
def _build_render_config(
self,
request: Any,
) -> RenderConfig:
"""
Build and return a `RenderConfig` for an endpoint.
Used by the renderer to control how prompts are prepared
(e.g., tokenization and length handling). Endpoints should
implement this with logic appropriate to their request type.
"""
raise NotImplementedError
def _get_async_tokenizer(self, tokenizer) -> AsyncMicrobatchTokenizer:
"""
Return (and cache) an `AsyncMicrobatchTokenizer` bound to the
given tokenizer.
"""
async_tokenizer = self._async_tokenizer_pool.get(tokenizer)
if async_tokenizer is None:
async_tokenizer = AsyncMicrobatchTokenizer(tokenizer)
self._async_tokenizer_pool[tokenizer] = async_tokenizer
return async_tokenizer
async def _preprocess(
self,
ctx: ServeContext,
......@@ -912,71 +887,6 @@ class OpenAIServing:
message_types.add(content_dict["type"].split("_")[0])
return message_types
async def _normalize_prompt_text_to_input(
self,
request: AnyRequest,
prompt: str,
tokenizer: TokenizerLike,
add_special_tokens: bool,
) -> TokensPrompt:
async_tokenizer = self._get_async_tokenizer(tokenizer)
if (
self.model_config.encoder_config is not None
and self.model_config.encoder_config.get("do_lower_case", False)
):
prompt = prompt.lower()
truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens", None)
if truncate_prompt_tokens is None:
encoded = await async_tokenizer(
prompt, add_special_tokens=add_special_tokens
)
elif truncate_prompt_tokens < 0:
# Negative means we cap at the model's max length
encoded = await async_tokenizer(
prompt,
add_special_tokens=add_special_tokens,
truncation=True,
max_length=self.max_model_len,
)
else:
encoded = await async_tokenizer(
prompt,
add_special_tokens=add_special_tokens,
truncation=True,
max_length=truncate_prompt_tokens,
)
input_ids = encoded.input_ids
input_text = prompt
return self._validate_input(request, input_ids, input_text)
async def _normalize_prompt_tokens_to_input(
self,
request: AnyRequest,
prompt_ids: list[int],
tokenizer: TokenizerLike | None,
) -> TokensPrompt:
truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens", None)
if truncate_prompt_tokens is None:
input_ids = prompt_ids
elif truncate_prompt_tokens < 0:
input_ids = prompt_ids[-self.max_model_len :]
else:
input_ids = prompt_ids[-truncate_prompt_tokens:]
if tokenizer is None:
input_text = ""
else:
async_tokenizer = self._get_async_tokenizer(tokenizer)
input_text = await async_tokenizer.decode(input_ids)
return self._validate_input(request, input_ids, input_text)
def _validate_input(
self,
request: object,
......@@ -1061,50 +971,6 @@ class OpenAIServing:
return TokensPrompt(prompt=input_text, prompt_token_ids=input_ids)
async def _tokenize_prompt_input_async(
self,
request: AnyRequest,
tokenizer: TokenizerLike,
prompt_input: str | list[int],
add_special_tokens: bool = True,
) -> TokensPrompt:
"""
A simpler implementation that tokenizes a single prompt input.
"""
async for result in self._tokenize_prompt_inputs_async(
request,
tokenizer,
[prompt_input],
add_special_tokens=add_special_tokens,
):
return result
raise ValueError("No results yielded from tokenization")
async def _tokenize_prompt_inputs_async(
self,
request: AnyRequest,
tokenizer: TokenizerLike,
prompt_inputs: Iterable[str | list[int]],
add_special_tokens: bool = True,
) -> AsyncGenerator[TokensPrompt, None]:
"""
A simpler implementation that tokenizes multiple prompt inputs.
"""
for prompt in prompt_inputs:
if isinstance(prompt, str):
yield await self._normalize_prompt_text_to_input(
request,
prompt=prompt,
tokenizer=tokenizer,
add_special_tokens=add_special_tokens,
)
else:
yield await self._normalize_prompt_tokens_to_input(
request,
prompt_ids=prompt,
tokenizer=tokenizer,
)
def _validate_chat_template(
self,
request_chat_template: str | None,
......@@ -1137,131 +1003,94 @@ class OpenAIServing:
# Apply server defaults first, then request kwargs override.
return default_chat_template_kwargs | request_chat_template_kwargs
async def _preprocess_completion(
self,
request: RendererRequest,
prompt_input: str | list[str] | list[int] | list[list[int]] | None,
prompt_embeds: bytes | list[bytes] | None,
) -> list[TokensPrompt | EmbedsPrompt]:
renderer = self.renderer
tok_params = request.build_tok_params(self.model_config)
in_prompts = await renderer.render_completions_async(
prompt_input, prompt_embeds
)
engine_prompts = await renderer.tokenize_prompts_async(in_prompts, tok_params)
extra_items = {
k: v
for k in ("mm_processor_kwargs", "cache_salt")
if (v := getattr(request, k, None)) is not None
}
for prompt in engine_prompts:
prompt.update(extra_items) # type: ignore
return engine_prompts
async def _preprocess_chat(
self,
request: ChatLikeRequest | ResponsesRequest,
renderer: RendererLike,
request: RendererChatRequest,
messages: list[ChatCompletionMessageParam],
chat_template: str | None,
chat_template_content_format: ChatTemplateContentFormatOption,
add_generation_prompt: bool = True,
continue_final_message: bool = False,
default_template: str | None,
default_template_content_format: ChatTemplateContentFormatOption,
default_template_kwargs: dict[str, Any] | None,
tool_dicts: list[dict[str, Any]] | None = None,
documents: list[dict[str, str]] | None = None,
chat_template_kwargs: dict[str, Any] | None = None,
default_chat_template_kwargs: dict[str, Any] | None = None,
tool_parser: Callable[[TokenizerLike], ToolParser] | None = None,
add_special_tokens: bool = False,
) -> tuple[list[ConversationMessage], list[TokensPrompt]]:
chat_template_kwargs = {
"chat_template": chat_template,
"add_generation_prompt": add_generation_prompt,
"continue_final_message": continue_final_message,
"tools": tool_dicts,
"documents": documents,
**(chat_template_kwargs or {}),
}
chat_template_kwargs = self._prepare_extra_chat_template_kwargs(
chat_template_kwargs,
default_chat_template_kwargs,
)
# Use the async tokenizer in `OpenAIServing` if possible.
# Later we can move it into the renderer so that we can return both
# text and token IDs in the same prompt from `render_messages_async`
# which is used for logging and `enable_response_messages`.
) -> tuple[list[ConversationMessage], list[TokensPrompt | EmbedsPrompt]]:
from vllm.tokenizers.mistral import MistralTokenizer
conversation, engine_prompt = await renderer.render_messages_async(
messages,
chat_template_content_format=chat_template_content_format,
tokenize=(
chat_template_kwargs.pop("tokenize", False)
or isinstance(renderer.tokenizer, MistralTokenizer)
renderer = self.renderer
default_template_kwargs = merge_kwargs(
default_template_kwargs,
dict(
tools=tool_dicts,
tokenize=isinstance(renderer.tokenizer, MistralTokenizer),
),
**chat_template_kwargs,
)
if "prompt_token_ids" not in engine_prompt:
extra_data = engine_prompt
engine_prompt = await self._tokenize_prompt_input_async(
request,
renderer.get_tokenizer(),
engine_prompt["prompt"],
add_special_tokens=add_special_tokens,
)
# Fill in other keys like MM data
engine_prompt.update(extra_data) # type: ignore
else:
self._validate_input(
request=request,
input_ids=engine_prompt["prompt_token_ids"], # type: ignore
input_text="",
)
tok_params = request.build_tok_params(self.model_config)
chat_params = request.build_chat_params(
default_template, default_template_content_format
).with_defaults(default_template_kwargs)
engine_prompt = cast(TokensPrompt, engine_prompt)
conversation, prompt = await renderer.render_messages_async(
messages, chat_params
)
engine_prompt = await renderer.tokenize_prompt_async(prompt, tok_params)
if request.mm_processor_kwargs is not None:
engine_prompt["mm_processor_kwargs"] = request.mm_processor_kwargs
if (cache_salt := getattr(request, "cache_salt", None)) is not None:
engine_prompt["cache_salt"] = cache_salt
extra_items = {
k: v
for k in ("mm_processor_kwargs", "cache_salt")
if (v := getattr(request, k, None)) is not None
}
engine_prompt.update(extra_items) # type: ignore
# tool parsing is done only if a tool_parser has been set and if
# tool_choice is not "none" (if tool_choice is "none" but a tool_parser
# is set, we want to prevent parsing a tool_call hallucinated by the LLM
should_parse_tools = tool_parser is not None and (
hasattr(request, "tool_choice") and request.tool_choice != "none"
)
if should_parse_tools:
if not isinstance(request, ChatCompletionRequest | ResponsesRequest):
msg = (
"Tool usage is only supported for Chat Completions API "
"or Responses API requests."
)
raise NotImplementedError(msg)
if tool_parser is not None:
tool_choice = getattr(request, "tool_choice", "none")
if tool_choice != "none":
if not isinstance(request, ChatCompletionRequest | ResponsesRequest):
msg = (
"Tool usage is only supported for Chat Completions API "
"or Responses API requests."
)
raise NotImplementedError(msg)
tokenizer = renderer.get_tokenizer()
request = tool_parser(tokenizer).adjust_request(request=request) # type: ignore
# TODO: Update adjust_request to accept ResponsesRequest
tokenizer = renderer.get_tokenizer()
request = tool_parser(tokenizer).adjust_request(request=request) # type: ignore[arg-type]
return conversation, [engine_prompt]
async def _process_inputs(
self,
request_id: str,
engine_prompt: PromptType,
params: SamplingParams | PoolingParams,
*,
lora_request: LoRARequest | None,
trace_headers: Mapping[str, str] | None,
priority: int,
data_parallel_rank: int | None = None,
) -> tuple[EngineCoreRequest, dict[str, Any]]:
"""Use the Processor to process inputs for AsyncLLM."""
tokenization_kwargs: dict[str, Any] = {}
_validate_truncation_size(
self.max_model_len, params.truncate_prompt_tokens, tokenization_kwargs
)
engine_request = self.input_processor.process_inputs(
request_id,
engine_prompt,
params,
lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
trace_headers=trace_headers,
priority=priority,
data_parallel_rank=data_parallel_rank,
)
return engine_request, tokenization_kwargs
async def _render_next_turn(
self,
request: ResponsesRequest,
renderer: RendererLike,
messages: list[ResponseInputOutputItem],
tool_dicts: list[dict[str, Any]] | None,
tool_parser,
tool_parser: Callable[[TokenizerLike], ToolParser] | None,
chat_template: str | None,
chat_template_content_format: ChatTemplateContentFormatOption,
):
......@@ -1271,24 +1100,25 @@ class OpenAIServing:
_, engine_prompts = await self._preprocess_chat(
request,
renderer,
new_messages,
default_template=chat_template,
default_template_content_format=chat_template_content_format,
default_template_kwargs=None,
tool_dicts=tool_dicts,
tool_parser=tool_parser,
chat_template=chat_template,
chat_template_content_format=chat_template_content_format,
)
return engine_prompts
async def _generate_with_builtin_tools(
self,
request_id: str,
engine_prompt: TokensPrompt,
engine_prompt: TokensPrompt | EmbedsPrompt,
sampling_params: SamplingParams,
tok_params: TokenizeParams,
context: ConversationContext,
lora_request: LoRARequest | None = None,
priority: int = 0,
**kwargs,
trace_headers: Mapping[str, str] | None = None,
):
prompt_text, _, _ = get_prompt_components(engine_prompt)
......@@ -1297,18 +1127,21 @@ class OpenAIServing:
while True:
# Ensure that each sub-request has a unique request id.
sub_request_id = f"{request_id}_{sub_request}"
self._log_inputs(
sub_request_id,
engine_prompt,
params=sampling_params,
lora_request=lora_request,
)
trace_headers = kwargs.get("trace_headers")
engine_request, tokenization_kwargs = await self._process_inputs(
tokenization_kwargs = tok_params.get_encode_kwargs()
engine_request = self.input_processor.process_inputs(
sub_request_id,
engine_prompt,
sampling_params,
lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
trace_headers=trace_headers,
priority=priority,
)
......@@ -1318,10 +1151,10 @@ class OpenAIServing:
sampling_params,
sub_request_id,
lora_request=lora_request,
trace_headers=trace_headers,
priority=priority,
prompt_text=prompt_text,
tokenization_kwargs=tokenization_kwargs,
**kwargs,
)
async for res in generator:
......@@ -1350,7 +1183,6 @@ class OpenAIServing:
elif isinstance(context, ParsableContext):
engine_prompts = await self._render_next_turn(
context.request,
context.renderer,
context.parser.response_messages,
context.tool_dicts,
context.tool_parser_cls,
......
......@@ -43,7 +43,6 @@ from vllm.entrypoints.openai.responses.protocol import (
from vllm.entrypoints.openai.responses.utils import construct_tool_dicts
from vllm.outputs import RequestOutput
from vllm.reasoning.abs_reasoning_parsers import ReasoningParser
from vllm.renderers import RendererLike
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import ToolParser
from vllm.utils import random_uuid
......@@ -261,7 +260,7 @@ class ParsableContext(ConversationContext):
self,
*,
response_messages: list[ResponseInputOutputItem],
renderer: RendererLike,
tokenizer: TokenizerLike,
reasoning_parser_cls: Callable[[TokenizerLike], ReasoningParser] | None,
request: ResponsesRequest,
available_tools: list[str] | None,
......@@ -280,7 +279,6 @@ class ParsableContext(ConversationContext):
if reasoning_parser_cls is None:
raise ValueError("reasoning_parser_cls must be provided.")
tokenizer = renderer.get_tokenizer()
self.parser = get_responses_parser_for_simple_context(
tokenizer=tokenizer,
reasoning_parser_cls=reasoning_parser_cls,
......@@ -290,8 +288,6 @@ class ParsableContext(ConversationContext):
)
self.tool_parser_cls = tool_parser_cls
self.request = request
self.renderer = renderer
self.tokenizer = tokenizer
self.available_tools = available_tools or []
self._tool_sessions: dict[str, ClientSession | Tool] = {}
......
......@@ -59,12 +59,15 @@ from pydantic import (
model_validator,
)
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
from vllm.entrypoints.openai.engine.protocol import (
OpenAIBaseModel,
from vllm.config import ModelConfig
from vllm.entrypoints.chat_utils import (
ChatCompletionMessageParam,
ChatTemplateContentFormatOption,
)
from vllm.entrypoints.openai.engine.protocol import OpenAIBaseModel
from vllm.exceptions import VLLMValidationError
from vllm.logger import init_logger
from vllm.renderers import ChatParams, TokenizeParams, merge_kwargs
from vllm.sampling_params import (
RequestOutputKind,
SamplingParams,
......@@ -230,6 +233,42 @@ class ResponsesRequest(OpenAIBaseModel):
previous_input_messages: list[OpenAIHarmonyMessage | dict] | None = None
# --8<-- [end:responses-extra-params]
def build_chat_params(
self,
default_template: str | None,
default_template_content_format: ChatTemplateContentFormatOption,
) -> ChatParams:
from .utils import should_continue_final_message
# Check if we should continue the final message (partial completion)
# This enables Anthropic-style partial message completion where the
# user provides an incomplete assistant message to continue from.
continue_final = should_continue_final_message(self.input)
reasoning = self.reasoning
return ChatParams(
chat_template=default_template,
chat_template_content_format=default_template_content_format,
chat_template_kwargs=merge_kwargs( # To remove unset values
{},
dict(
add_generation_prompt=not continue_final,
continue_final_message=continue_final,
reasoning_effort=None if reasoning is None else reasoning.effort,
),
),
)
def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
return TokenizeParams(
max_total_tokens=model_config.max_model_len,
max_output_tokens=self.max_output_tokens or 0,
truncate_prompt_tokens=-1 if self.truncation != "disabled" else None,
max_total_tokens_param="max_model_len",
max_output_tokens_param="max_output_tokens",
)
_DEFAULT_SAMPLING_PARAMS = {
"temperature": 1.0,
"top_p": 1.0,
......
......@@ -114,16 +114,15 @@ from vllm.entrypoints.openai.responses.utils import (
construct_input_messages,
construct_tool_dicts,
extract_tool_types,
should_continue_final_message,
)
from vllm.entrypoints.utils import get_max_tokens
from vllm.exceptions import VLLMValidationError
from vllm.inputs.data import TokensPrompt
from vllm.inputs.data import EmbedsPrompt, TokensPrompt
from vllm.inputs.parse import get_prompt_len
from vllm.logger import init_logger
from vllm.logprobs import Logprob as SampleLogprob
from vllm.logprobs import SampleLogprobs
from vllm.outputs import CompletionOutput
from vllm.renderers import RendererLike
from vllm.sampling_params import SamplingParams, StructuredOutputsParams
from vllm.tokenizers import TokenizerLike
from vllm.utils import random_uuid
......@@ -291,13 +290,14 @@ class OpenAIServingResponses(OpenAIServing):
self.tool_server = tool_server
def _validate_generator_input(
self, engine_prompt: TokensPrompt
self,
engine_prompt: TokensPrompt | EmbedsPrompt,
) -> ErrorResponse | None:
"""Add validations to the input to the generator here."""
if self.max_model_len <= len(engine_prompt["prompt_token_ids"]):
prompt_len = get_prompt_len(engine_prompt)
if self.max_model_len <= prompt_len:
error_message = (
"The engine prompt length"
f" {len(engine_prompt['prompt_token_ids'])} "
f"The engine prompt length {prompt_len} "
f"exceeds the max_model_len {self.max_model_len}. "
"Please reduce prompt."
)
......@@ -307,6 +307,7 @@ class OpenAIServingResponses(OpenAIServing):
status_code=HTTPStatus.BAD_REQUEST,
param="input",
)
return None
def _validate_create_responses_input(
......@@ -387,8 +388,6 @@ class OpenAIServingResponses(OpenAIServing):
try:
lora_request = self._maybe_get_adapters(request)
model_name = self.models.model_name(lora_request)
renderer = self.engine_client.renderer
tokenizer = renderer.get_tokenizer()
if self.use_harmony:
messages, engine_prompts = self._make_request_with_harmony(
......@@ -396,7 +395,7 @@ class OpenAIServingResponses(OpenAIServing):
)
else:
messages, engine_prompts = await self._make_request(
request, prev_response, renderer
request, prev_response
)
except (
......@@ -431,6 +430,9 @@ class OpenAIServingResponses(OpenAIServing):
assert len(builtin_tool_list) == 0
available_tools = []
try:
renderer = self.engine_client.renderer
tokenizer = renderer.get_tokenizer()
for engine_prompt in engine_prompts:
maybe_error = self._validate_generator_input(engine_prompt)
if maybe_error is not None:
......@@ -446,6 +448,7 @@ class OpenAIServingResponses(OpenAIServing):
sampling_params = request.to_sampling_params(
default_max_tokens, self.default_sampling_params
)
tok_params = request.build_tok_params(self.model_config)
trace_headers = (
None
......@@ -465,7 +468,7 @@ class OpenAIServingResponses(OpenAIServing):
# tokens during generation instead of at the end
context = ParsableContext(
response_messages=messages,
renderer=renderer,
tokenizer=tokenizer,
reasoning_parser_cls=self.reasoning_parser,
request=request,
tool_parser_cls=self.tool_parser,
......@@ -495,6 +498,7 @@ class OpenAIServingResponses(OpenAIServing):
request_id=request.request_id,
engine_prompt=engine_prompt,
sampling_params=sampling_params,
tok_params=tok_params,
context=context,
lora_request=lora_request,
priority=request.priority,
......@@ -596,7 +600,6 @@ class OpenAIServingResponses(OpenAIServing):
self,
request: ResponsesRequest,
prev_response: ResponsesResponse | None,
renderer: RendererLike,
):
tool_dicts = construct_tool_dicts(request.tools, request.tool_choice)
# Construct the input messages.
......@@ -606,30 +609,15 @@ class OpenAIServingResponses(OpenAIServing):
prev_msg=self.msg_store.get(prev_response.id) if prev_response else None,
prev_response_output=prev_response.output if prev_response else None,
)
# Check if we should continue the final message (partial completion)
# This enables Anthropic-style partial message completion where the
# user provides an incomplete assistant message to continue from.
continue_final = should_continue_final_message(request.input)
chat_template_kwargs = dict(
reasoning_effort=None
if request.reasoning is None
else request.reasoning.effort
)
_, engine_prompts = await self._preprocess_chat(
request,
renderer,
messages,
default_template=self.chat_template,
default_template_content_format=self.chat_template_content_format,
default_template_kwargs=None,
tool_dicts=tool_dicts,
tool_parser=self.tool_parser,
chat_template=self.chat_template,
chat_template_content_format=self.chat_template_content_format,
# When continuing a partial message, we set continue_final_message=True
# and add_generation_prompt=False so the model continues the message
# rather than starting a new one.
add_generation_prompt=not continue_final,
continue_final_message=continue_final,
chat_template_kwargs=chat_template_kwargs,
)
return messages, engine_prompts
......
......@@ -8,8 +8,12 @@ from pydantic import Field, model_validator
from vllm import PoolingParams
from vllm.config.pooler import get_use_activation
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
from vllm.entrypoints.chat_utils import (
ChatCompletionMessageParam,
ChatTemplateContentFormatOption,
)
from vllm.entrypoints.openai.engine.protocol import OpenAIBaseModel
from vllm.renderers import ChatParams, merge_kwargs
from vllm.utils import random_uuid
from vllm.utils.serial_utils import EmbedDType, EncodingFormat, Endianness
......@@ -119,6 +123,23 @@ class ChatRequestMixin(OpenAIBaseModel):
)
return data
def build_chat_params(
self,
default_template: str | None,
default_template_content_format: ChatTemplateContentFormatOption,
) -> ChatParams:
return ChatParams(
chat_template=self.chat_template or default_template,
chat_template_content_format=default_template_content_format,
chat_template_kwargs=merge_kwargs(
self.chat_template_kwargs,
dict(
add_generation_prompt=self.add_generation_prompt,
continue_final_message=self.continue_final_message,
),
),
)
class EncodingRequestMixin(OpenAIBaseModel):
# --8<-- [start:encoding-params]
......
......@@ -4,10 +4,9 @@
import time
from typing import Any, TypeAlias
from pydantic import (
Field,
)
from pydantic import Field
from vllm.config import ModelConfig
from vllm.entrypoints.openai.engine.protocol import OpenAIBaseModel, UsageInfo
from vllm.entrypoints.pooling.base.protocol import (
ChatRequestMixin,
......@@ -15,13 +14,24 @@ from vllm.entrypoints.pooling.base.protocol import (
CompletionRequestMixin,
PoolingBasicRequestMixin,
)
from vllm.renderers import TokenizeParams
from vllm.utils import random_uuid
class ClassificationCompletionRequest(
PoolingBasicRequestMixin, CompletionRequestMixin, ClassifyRequestMixin
):
pass
def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
encoder_config = model_config.encoder_config or {}
return TokenizeParams(
max_total_tokens=model_config.max_model_len,
max_output_tokens=0,
truncate_prompt_tokens=self.truncate_prompt_tokens,
do_lower_case=encoder_config.get("do_lower_case", False),
add_special_tokens=self.add_special_tokens,
max_total_tokens_param="max_model_len",
)
class ClassificationChatRequest(
......@@ -33,6 +43,18 @@ class ClassificationChatRequest(
description=("Additional kwargs to pass to the HF processor."),
)
def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
encoder_config = model_config.encoder_config or {}
return TokenizeParams(
max_total_tokens=model_config.max_model_len,
max_output_tokens=0,
truncate_prompt_tokens=self.truncate_prompt_tokens,
do_lower_case=encoder_config.get("do_lower_case", False),
add_special_tokens=self.add_special_tokens,
max_total_tokens_param="max_model_len",
)
ClassificationRequest: TypeAlias = (
ClassificationCompletionRequest | ClassificationChatRequest
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from http import HTTPStatus
from typing import Final, cast
from typing import Final, TypeAlias
import jinja2
import numpy as np
......@@ -21,15 +20,14 @@ from vllm.entrypoints.pooling.classify.protocol import (
ClassificationRequest,
ClassificationResponse,
)
from vllm.entrypoints.renderer import RenderConfig
from vllm.logger import init_logger
from vllm.outputs import ClassificationOutput, PoolingRequestOutput
from vllm.outputs import ClassificationOutput
from vllm.pooling_params import PoolingParams
logger = init_logger(__name__)
ClassificationServeContext = ServeContext[ClassificationRequest]
ClassificationServeContext: TypeAlias = ServeContext[ClassificationRequest]
class ServingClassification(OpenAIServing):
......@@ -77,34 +75,18 @@ class ServingClassification(OpenAIServing):
if error_check_ret:
return error_check_ret
_, engine_prompts = await self._preprocess_chat(
_, ctx.engine_prompts = await self._preprocess_chat(
ctx.request,
self.renderer,
ctx.request.messages,
chat_template=ctx.request.chat_template or self.chat_template,
chat_template_content_format=self.chat_template_content_format,
add_generation_prompt=ctx.request.add_generation_prompt,
continue_final_message=ctx.request.continue_final_message,
add_special_tokens=ctx.request.add_special_tokens,
default_template=self.chat_template,
default_template_content_format=self.chat_template_content_format,
default_template_kwargs=None,
)
ctx.engine_prompts = engine_prompts
elif isinstance(ctx.request, ClassificationCompletionRequest):
input_data = ctx.request.input
if input_data in (None, ""):
return self.create_error_response(
"Input or messages must be provided",
status_code=HTTPStatus.BAD_REQUEST,
)
if isinstance(input_data, list) and not input_data:
ctx.engine_prompts = []
return None
renderer = self._get_completion_renderer()
prompt_input = cast(str | list[str], input_data)
ctx.engine_prompts = await renderer.render_prompt(
prompt_or_prompts=prompt_input,
config=self._build_render_config(ctx.request),
ctx.engine_prompts = await self._preprocess_completion(
ctx.request,
prompt_input=ctx.request.input,
prompt_embeds=None,
)
else:
return self.create_error_response("Invalid classification request type")
......@@ -128,7 +110,7 @@ class ServingClassification(OpenAIServing):
items: list[ClassificationData] = []
num_prompt_tokens = 0
final_res_batch_checked = cast(list[PoolingRequestOutput], ctx.final_res_batch)
final_res_batch_checked = ctx.final_res_batch
for idx, final_res in enumerate(final_res_batch_checked):
classify_res = ClassificationOutput.from_base(final_res.outputs)
......@@ -161,13 +143,6 @@ class ServingClassification(OpenAIServing):
usage=usage,
)
def _build_render_config(self, request: ClassificationRequest) -> RenderConfig:
return RenderConfig(
max_length=self.max_model_len,
truncate_prompt_tokens=request.truncate_prompt_tokens,
add_special_tokens=request.add_special_tokens,
)
async def create_classify(
self,
request: ClassificationRequest,
......
......@@ -3,10 +3,9 @@
import time
from typing import Any, TypeAlias
from pydantic import (
Field,
)
from pydantic import Field
from vllm.config import ModelConfig
from vllm.entrypoints.openai.engine.protocol import OpenAIBaseModel, UsageInfo
from vllm.entrypoints.pooling.base.protocol import (
ChatRequestMixin,
......@@ -14,15 +13,47 @@ from vllm.entrypoints.pooling.base.protocol import (
EmbedRequestMixin,
PoolingBasicRequestMixin,
)
from vllm.renderers import TokenizeParams
from vllm.utils import random_uuid
def _get_max_total_output_tokens(
model_config: ModelConfig,
) -> tuple[int | None, int]:
max_total_tokens = model_config.max_model_len
pooler_config = model_config.pooler_config
if pooler_config is None:
return max_total_tokens, 0
if pooler_config.enable_chunked_processing:
return None, 0
max_embed_len = pooler_config.max_embed_len or max_total_tokens
max_output_tokens = max_total_tokens - max_embed_len
return max_total_tokens, max_output_tokens
class EmbeddingCompletionRequest(
PoolingBasicRequestMixin, CompletionRequestMixin, EmbedRequestMixin
):
# Ordered by official OpenAI API documentation
# https://platform.openai.com/docs/api-reference/embeddings
pass
def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
encoder_config = model_config.encoder_config or {}
(
max_total_tokens,
max_output_tokens,
) = _get_max_total_output_tokens(model_config)
return TokenizeParams(
max_total_tokens=max_total_tokens,
max_output_tokens=max_output_tokens,
truncate_prompt_tokens=self.truncate_prompt_tokens,
do_lower_case=encoder_config.get("do_lower_case", False),
add_special_tokens=self.add_special_tokens,
max_total_tokens_param="max_model_len",
max_output_tokens_param="max_model_len - max_embed_len",
)
class EmbeddingChatRequest(
......@@ -33,6 +64,24 @@ class EmbeddingChatRequest(
description=("Additional kwargs to pass to the HF processor."),
)
def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
encoder_config = model_config.encoder_config or {}
(
max_total_tokens,
max_output_tokens,
) = _get_max_total_output_tokens(model_config)
return TokenizeParams(
max_total_tokens=max_total_tokens,
max_output_tokens=max_output_tokens,
truncate_prompt_tokens=self.truncate_prompt_tokens,
do_lower_case=encoder_config.get("do_lower_case", False),
add_special_tokens=self.add_special_tokens,
max_total_tokens_param="max_model_len",
max_output_tokens_param="max_model_len - max_embed_len",
)
EmbeddingRequest: TypeAlias = EmbeddingCompletionRequest | EmbeddingChatRequest
......
......@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
from collections.abc import AsyncGenerator, Mapping
from typing import Any, Final, cast
from typing import Any, Final, TypeAlias
import torch
from fastapi import Request
......@@ -22,8 +22,7 @@ from vllm.entrypoints.pooling.embed.protocol import (
EmbeddingResponse,
EmbeddingResponseData,
)
from vllm.entrypoints.renderer import RenderConfig
from vllm.inputs.data import TokensPrompt
from vllm.inputs.data import EmbedsPrompt, TokensPrompt
from vllm.logger import init_logger
from vllm.outputs import PoolingOutput, PoolingRequestOutput
from vllm.pooling_params import PoolingParams
......@@ -37,7 +36,7 @@ from vllm.utils.serial_utils import (
logger = init_logger(__name__)
EmbeddingServeContext = ServeContext[EmbeddingRequest]
EmbeddingServeContext: TypeAlias = ServeContext[EmbeddingRequest]
class OpenAIServingEmbedding(OpenAIServing):
......@@ -95,19 +94,16 @@ class OpenAIServingEmbedding(OpenAIServing):
_, ctx.engine_prompts = await self._preprocess_chat(
ctx.request,
self.renderer,
ctx.request.messages,
chat_template=ctx.request.chat_template or self.chat_template,
chat_template_content_format=self.chat_template_content_format,
add_generation_prompt=ctx.request.add_generation_prompt,
continue_final_message=ctx.request.continue_final_message,
add_special_tokens=ctx.request.add_special_tokens,
default_template=self.chat_template,
default_template_content_format=self.chat_template_content_format,
default_template_kwargs=None,
)
elif isinstance(ctx.request, EmbeddingCompletionRequest):
renderer = self._get_completion_renderer()
ctx.engine_prompts = await renderer.render_prompt(
prompt_or_prompts=ctx.request.input,
config=self._build_render_config(ctx.request),
ctx.engine_prompts = await self._preprocess_completion(
ctx.request,
prompt_input=ctx.request.input,
prompt_embeds=None,
)
else:
return self.create_error_response("Invalid classification request type")
......@@ -117,19 +113,6 @@ class OpenAIServingEmbedding(OpenAIServing):
logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(str(e))
def _build_render_config(self, request: EmbeddingCompletionRequest) -> RenderConfig:
# Set max_length based on chunked processing capability
if self._should_use_chunked_processing(request):
max_length = None
else:
max_length = self.max_embed_len or self.max_model_len
return RenderConfig(
max_length=max_length,
truncate_prompt_tokens=request.truncate_prompt_tokens,
add_special_tokens=request.add_special_tokens,
)
def _build_response(
self,
ctx: EmbeddingServeContext,
......@@ -246,14 +229,18 @@ class OpenAIServingEmbedding(OpenAIServing):
lora_request=ctx.lora_request,
)
tok_params = ctx.request.build_tok_params(self.model_config)
tokenization_kwargs = tok_params.get_encode_kwargs()
# Create generator for this chunk and wrap it to return indices
original_generator = self.engine_client.encode(
chunk_engine_prompt,
pooling_params,
chunk_request_id,
lora_request=ctx.lora_request,
tokenization_kwargs=tokenization_kwargs,
trace_headers=trace_headers,
priority=getattr(ctx.request, "priority", 0),
priority=ctx.request.priority,
)
generators.append(original_generator)
......@@ -338,7 +325,7 @@ class OpenAIServingEmbedding(OpenAIServing):
async def _create_single_prompt_generator(
self,
ctx: EmbeddingServeContext,
engine_prompt: TokensPrompt,
engine_prompt: TokensPrompt | EmbedsPrompt,
pooling_params: PoolingParams,
trace_headers: Mapping[str, str] | None,
prompt_index: int,
......@@ -353,23 +340,25 @@ class OpenAIServingEmbedding(OpenAIServing):
lora_request=ctx.lora_request,
)
tok_params = ctx.request.build_tok_params(self.model_config)
tokenization_kwargs = tok_params.get_encode_kwargs()
# Return the original generator without wrapping
return self.engine_client.encode(
engine_prompt,
pooling_params,
request_id_item,
lora_request=ctx.lora_request,
tokenization_kwargs=tokenization_kwargs,
trace_headers=trace_headers,
priority=getattr(ctx.request, "priority", 0),
priority=ctx.request.priority,
)
async def _prepare_generators(
self,
ctx: ServeContext,
ctx: EmbeddingServeContext,
) -> ErrorResponse | None:
"""Override to support chunked processing."""
ctx = cast(EmbeddingServeContext, ctx)
# Check if we should use chunked processing
use_chunked = self._should_use_chunked_processing(ctx.request)
......@@ -405,7 +394,8 @@ class OpenAIServingEmbedding(OpenAIServing):
for i, engine_prompt in enumerate(ctx.engine_prompts):
# Check if this specific prompt needs chunked processing
if "prompt_token_ids" in engine_prompt:
prompt_token_ids = engine_prompt["prompt_token_ids"]
prompt_token_ids = engine_prompt["prompt_token_ids"] # type: ignore[typeddict-item]
if len(prompt_token_ids) > max_pos_embeddings:
# Use chunked processing for this prompt
chunk_generators = await self._process_chunked_request(
......@@ -573,7 +563,7 @@ class OpenAIServingEmbedding(OpenAIServing):
"token IDs"
)
original_token_ids = original_prompt["prompt_token_ids"]
original_token_ids = original_prompt["prompt_token_ids"] # type: ignore[typeddict-item]
pooling_request_output = PoolingRequestOutput(
request_id=aggregator["request_id"],
......
......@@ -3,11 +3,10 @@
import time
from typing import Any, Generic, TypeAlias, TypeVar
from pydantic import (
Field,
)
from pydantic import Field
from vllm import PoolingParams
from vllm.config import ModelConfig
from vllm.config.pooler import get_use_activation
from vllm.entrypoints.openai.engine.protocol import OpenAIBaseModel, UsageInfo
from vllm.entrypoints.pooling.base.protocol import (
......@@ -18,6 +17,7 @@ from vllm.entrypoints.pooling.base.protocol import (
EncodingRequestMixin,
PoolingBasicRequestMixin,
)
from vllm.renderers import TokenizeParams
from vllm.tasks import PoolingTask
from vllm.utils import random_uuid
......@@ -30,6 +30,18 @@ class PoolingCompletionRequest(
):
task: PoolingTask | None = None
def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
encoder_config = model_config.encoder_config or {}
return TokenizeParams(
max_total_tokens=model_config.max_model_len,
max_output_tokens=0,
truncate_prompt_tokens=self.truncate_prompt_tokens,
do_lower_case=encoder_config.get("do_lower_case", False),
add_special_tokens=self.add_special_tokens,
max_total_tokens_param="max_model_len",
)
def to_pooling_params(self):
return PoolingParams(
truncate_prompt_tokens=self.truncate_prompt_tokens,
......@@ -48,6 +60,18 @@ class PoolingChatRequest(
description=("Additional kwargs to pass to the HF processor."),
)
def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
encoder_config = model_config.encoder_config or {}
return TokenizeParams(
max_total_tokens=model_config.max_model_len,
max_output_tokens=0,
truncate_prompt_tokens=self.truncate_prompt_tokens,
do_lower_case=encoder_config.get("do_lower_case", False),
add_special_tokens=self.add_special_tokens,
max_total_tokens_param="max_model_len",
)
def to_pooling_params(self):
return PoolingParams(
truncate_prompt_tokens=self.truncate_prompt_tokens,
......
......@@ -5,7 +5,7 @@ import asyncio
import json
import time
from collections.abc import AsyncGenerator, Sequence
from typing import Final, cast
from typing import Any, Final, cast
import jinja2
from fastapi import Request
......@@ -14,10 +14,7 @@ from typing_extensions import assert_never
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.engine.protocol import (
ErrorResponse,
UsageInfo,
)
from vllm.entrypoints.openai.engine.protocol import ErrorResponse, UsageInfo
from vllm.entrypoints.openai.engine.serving import OpenAIServing
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.entrypoints.pooling.pooling.protocol import (
......@@ -30,8 +27,6 @@ from vllm.entrypoints.pooling.pooling.protocol import (
PoolingResponse,
PoolingResponseData,
)
from vllm.entrypoints.renderer import RenderConfig
from vllm.entrypoints.utils import _validate_truncation_size
from vllm.logger import init_logger
from vllm.outputs import PoolingRequestOutput
from vllm.tasks import PoolingTask, SupportedTask
......@@ -99,11 +94,6 @@ class OpenAIServingPooling(OpenAIServing):
"dimensions is currently not supported"
)
truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens", None)
truncate_prompt_tokens = _validate_truncation_size(
self.max_model_len, truncate_prompt_tokens
)
if is_io_processor_request:
if self.io_processor is None:
raise ValueError(
......@@ -134,19 +124,16 @@ class OpenAIServingPooling(OpenAIServing):
_, engine_prompts = await self._preprocess_chat(
request,
self.renderer,
request.messages,
chat_template=request.chat_template or self.chat_template,
chat_template_content_format=self.chat_template_content_format,
add_generation_prompt=request.add_generation_prompt,
continue_final_message=request.continue_final_message,
add_special_tokens=request.add_special_tokens,
default_template=self.chat_template,
default_template_content_format=self.chat_template_content_format,
default_template_kwargs=None,
)
elif isinstance(request, PoolingCompletionRequest):
renderer = self._get_completion_renderer()
engine_prompts = await renderer.render_prompt(
prompt_or_prompts=request.input,
config=self._build_render_config(request),
engine_prompts = await self._preprocess_completion(
request,
prompt_input=request.input,
prompt_embeds=None,
)
else:
raise ValueError(f"Unsupported request of type {type(request)}")
......@@ -207,11 +194,18 @@ class OpenAIServingPooling(OpenAIServing):
else await self._get_trace_headers(raw_request.headers)
)
if is_io_processor_request:
tokenization_kwargs: dict[str, Any] = {}
else:
tok_params = request.build_tok_params(self.model_config) # type: ignore
tokenization_kwargs = tok_params.get_encode_kwargs()
generator = self.engine_client.encode(
engine_prompt,
pooling_params,
request_id_item,
lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
trace_headers=trace_headers,
priority=request.priority,
)
......@@ -338,10 +332,3 @@ class OpenAIServingPooling(OpenAIServing):
return encode_bytes(bytes_only=encoding_format == "bytes_only")
else:
assert_never(encoding_format)
def _build_render_config(self, request: PoolingCompletionRequest) -> RenderConfig:
return RenderConfig(
max_length=self.max_model_len,
truncate_prompt_tokens=request.truncate_prompt_tokens,
add_special_tokens=request.add_special_tokens,
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment