Unverified commit f0a1c845, authored by Cyrus Leung, committed by GitHub
Browse files

[Frontend] Use new Renderer for Completions and Tokenize API (#32863)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent 8980001c
......@@ -3,12 +3,10 @@
import time
from typing import Any, TypeAlias
from pydantic import (
BaseModel,
Field,
)
from pydantic import BaseModel, Field
from vllm import PoolingParams
from vllm.config import ModelConfig
from vllm.config.pooler import get_use_activation
from vllm.entrypoints.openai.engine.protocol import OpenAIBaseModel, UsageInfo
from vllm.entrypoints.pooling.base.protocol import (
......@@ -19,6 +17,7 @@ from vllm.entrypoints.pooling.score.utils import (
ScoreContentPartParam,
ScoreMultiModalParam,
)
from vllm.renderers import TokenizeParams
from vllm.utils import random_uuid
......@@ -30,6 +29,17 @@ class ScoreRequestMixin(PoolingBasicRequestMixin, ClassifyRequestMixin):
)
# --8<-- [end:score-extra-params]
def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
    """Build the `TokenizeParams` used to tokenize this score request.

    The total token budget is `model_config.max_model_len`,
    `max_output_tokens` is fixed at 0, the request's own
    `truncate_prompt_tokens` is forwarded, and `do_lower_case` is read
    from the model's encoder config (defaulting to `False` when the
    encoder config is missing or does not define it).
    """
    # `encoder_config` may be None/empty; fall back to an empty mapping.
    encoder_settings = model_config.encoder_config or {}
    lower_case = encoder_settings.get("do_lower_case", False)

    return TokenizeParams(
        max_total_tokens=model_config.max_model_len,
        max_output_tokens=0,
        truncate_prompt_tokens=self.truncate_prompt_tokens,
        do_lower_case=lower_case,
        max_total_tokens_param="max_model_len",
    )
def to_pooling_params(self):
return PoolingParams(
truncate_prompt_tokens=self.truncate_prompt_tokens,
......@@ -85,6 +95,17 @@ class RerankRequest(PoolingBasicRequestMixin, ClassifyRequestMixin):
)
# --8<-- [end:rerank-extra-params]
def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
    """Build the `TokenizeParams` used to tokenize this rerank request.

    The total token budget is `model_config.max_model_len`,
    `max_output_tokens` is fixed at 0, the request's own
    `truncate_prompt_tokens` is forwarded, and `do_lower_case` is read
    from the model's encoder config (defaulting to `False` when the
    encoder config is missing or does not define it).
    """
    # `encoder_config` may be None/empty; fall back to an empty mapping.
    encoder_settings = model_config.encoder_config or {}
    lower_case = encoder_settings.get("do_lower_case", False)

    return TokenizeParams(
        max_total_tokens=model_config.max_model_len,
        max_output_tokens=0,
        truncate_prompt_tokens=self.truncate_prompt_tokens,
        do_lower_case=lower_case,
        max_total_tokens_param="max_model_len",
    )
class RerankDocument(BaseModel):
text: str | None = None
......
......@@ -34,7 +34,6 @@ from vllm.entrypoints.pooling.score.utils import (
compress_token_type_ids,
get_score_prompt,
)
from vllm.entrypoints.utils import _validate_truncation_size
from vllm.inputs.data import TokensPrompt
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
......@@ -68,31 +67,31 @@ class ServingScores(OpenAIServing):
async def _embedding_score(
self,
tokenizer: TokenizerLike,
data_1: list[str],
data_2: list[str],
request: RerankRequest | ScoreRequest,
request_id: str,
tokenization_kwargs: dict[str, Any] | None = None,
lora_request: LoRARequest | None | None = None,
trace_headers: Mapping[str, str] | None = None,
) -> list[PoolingRequestOutput] | ErrorResponse:
input_texts = data_1 + data_2
model_config = self.model_config
tokenizer = self.renderer.get_tokenizer()
engine_prompts: list[TokensPrompt] = []
tokenize_async = make_async(
tokenizer.__call__, executor=self._tokenizer_executor
encode_async = make_async(
tokenizer.encode,
executor=self._tokenizer_executor,
)
tokenization_kwargs = tokenization_kwargs or {}
input_texts = data_1 + data_2
tokenization_kwargs = request.build_tok_params(model_config).get_encode_kwargs()
tokenized_prompts = await asyncio.gather(
*(tokenize_async(t, **tokenization_kwargs) for t in input_texts)
*(encode_async(t, **tokenization_kwargs) for t in input_texts)
)
engine_prompts: list[TokensPrompt] = []
for tok_result, input_text in zip(tokenized_prompts, input_texts):
text_token_prompt = self._validate_input(
request, tok_result["input_ids"], input_text
)
text_token_prompt = self._validate_input(request, tok_result, input_text)
engine_prompts.append(
TokensPrompt(prompt_token_ids=text_token_prompt["prompt_token_ids"])
......@@ -184,15 +183,16 @@ class ServingScores(OpenAIServing):
async def _cross_encoding_score(
self,
tokenizer: TokenizerLike,
data_1: list[str] | list[ScoreContentPartParam],
data_2: list[str] | list[ScoreContentPartParam],
request: RerankRequest | ScoreRequest,
request_id: str,
tokenization_kwargs: dict[str, Any] | None = None,
lora_request: LoRARequest | None | None = None,
trace_headers: Mapping[str, str] | None = None,
) -> list[PoolingRequestOutput] | ErrorResponse:
model_config = self.model_config
tokenizer = self.renderer.get_tokenizer()
request_prompts: list[str] = []
engine_prompts: list[TokensPrompt] = []
......@@ -202,12 +202,13 @@ class ServingScores(OpenAIServing):
if isinstance(tokenizer, MistralTokenizer):
raise ValueError("MistralTokenizer not supported for cross-encoding")
tokenization_kwargs = tokenization_kwargs or {}
tok_kwargs = request.build_tok_params(model_config).get_encode_kwargs()
input_pairs = [(t1, t2) for t1, t2 in zip(data_1, data_2)]
preprocess_async = make_async(
self._preprocess_score, executor=self._tokenizer_executor
self._preprocess_score,
executor=self._tokenizer_executor,
)
preprocessed_prompts = await asyncio.gather(
......@@ -215,7 +216,7 @@ class ServingScores(OpenAIServing):
preprocess_async(
request=request,
tokenizer=tokenizer,
tokenization_kwargs=tokenization_kwargs,
tokenization_kwargs=tok_kwargs,
data_1=t1,
data_2=t2,
)
......@@ -286,14 +287,6 @@ class ServingScores(OpenAIServing):
raw_request: Request | None = None,
) -> list[PoolingRequestOutput] | ErrorResponse:
lora_request = self._maybe_get_adapters(request)
tokenizer = self.renderer.get_tokenizer()
truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens", None)
tokenization_kwargs: dict[str, Any] = {}
_validate_truncation_size(
self.max_model_len, truncate_prompt_tokens, tokenization_kwargs
)
trace_headers = (
None
......@@ -322,24 +315,20 @@ class ServingScores(OpenAIServing):
if self.model_config.is_cross_encoder:
return await self._cross_encoding_score(
tokenizer=tokenizer,
data_1=data_1, # type: ignore[arg-type]
data_2=data_2, # type: ignore[arg-type]
request=request,
request_id=request_id,
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
trace_headers=trace_headers,
)
else:
return await self._embedding_score(
tokenizer=tokenizer,
data_1=data_1, # type: ignore[arg-type]
data_2=data_2, # type: ignore[arg-type]
request=request,
request_id=request_id,
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
trace_headers=trace_headers,
)
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import io
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Annotated
import pybase64
import torch
from pydantic import Field
from vllm.config import ModelConfig
from vllm.exceptions import VLLMValidationError
from vllm.inputs.data import EmbedsPrompt, TextPrompt, TokensPrompt
from vllm.inputs.parse import get_prompt_components, parse_raw_prompts
from vllm.tokenizers import TokenizerLike
from vllm.utils.async_utils import AsyncMicrobatchTokenizer
@dataclass(frozen=True)
class RenderConfig:
    """Immutable settings that control how prompts are prepared."""

    max_length: int | None = None
    """Upper bound on the total number of input tokens. When set,
    token inputs longer than this are rejected with an error."""

    truncate_prompt_tokens: int | None = None
    """How many tokens to keep. `None` disables truncation.
    `0` yields an empty list (and skips embeds).
    A negative value maps to `model_config.max_model_len`."""

    add_special_tokens: bool = True
    """Whether model-specific special tokens are added while tokenizing."""

    cache_salt: str | None = None
    """String used to disambiguate prefix cache entries."""

    needs_detokenization: bool | None = False
    """If True, token IDs are decoded back to text for the outputs."""

    def verify_truncate_prompt_tokens(self, model_config: ModelConfig) -> int | None:
        """Validate `truncate_prompt_tokens` and return its normalized value.

        `None` and `0` are returned unchanged. A negative value is replaced
        by `model_config.max_model_len`.

        Raises:
            ValueError: If the normalized value exceeds `max_length`
                (only checked when `max_length` is set).
        """
        truncate_prompt_tokens = self.truncate_prompt_tokens

        # Both `None` (no truncation) and `0` pass through untouched.
        if not truncate_prompt_tokens:
            return truncate_prompt_tokens

        if truncate_prompt_tokens < 0:
            truncate_prompt_tokens = model_config.max_model_len

        max_length = self.max_length
        exceeds_limit = (
            max_length is not None
            and truncate_prompt_tokens > max_length  # type: ignore[operator]
        )
        if exceeds_limit:
            raise ValueError(
                f"{truncate_prompt_tokens=} cannot be greater than "
                f"{max_length=}. Please select a smaller truncation size."
            )

        return truncate_prompt_tokens
class BaseRenderer(ABC):
    """
    Base class for unified input processing and rendering.

    The Renderer serves as a unified input processor that consolidates
    tokenization, chat template formatting, and multimodal input handling
    into a single component.
    It converts high-level API requests (OpenAI-style JSON) into token IDs and
    multimodal features ready for engine consumption.

    Key responsibilities:
    - Convert text prompts to token sequences with proper special tokens
    - Apply chat templates and format conversations
    - Handle multimodal inputs (images, audio, etc.) when applicable
    - Manage prompt truncation and length validation
    - Provide clean separation between API layer and engine core
    """

    def __init__(
        self,
        model_config: ModelConfig,
        tokenizer: TokenizerLike | None = None,
    ):
        """Store the model configuration and the (optional) tokenizer."""
        super().__init__()
        self.model_config = model_config
        self.tokenizer = tokenizer

    @abstractmethod
    async def render_prompt(
        self,
        *,
        prompt_or_prompts: str | list[str] | list[int] | list[list[int]],
        config: RenderConfig,
    ) -> list[TokensPrompt]:
        """
        Convert text or token inputs into engine-ready TokensPrompt objects.

        This method accepts text or token inputs and produces a
        list of [`TokensPrompt`][vllm.inputs.data.TokensPrompt] objects
        for the engine.

        Args:
            prompt_or_prompts: One of:
                - `str`: Single text prompt.
                - `list[str]`: Batch of text prompts.
                - `list[int]`: Single pre-tokenized sequence.
                - `list[list[int]]`: Batch of pre-tokenized sequences.
            config: Render configuration controlling how prompts are prepared
                (e.g., tokenization and length handling).

        Returns:
            list[TokensPrompt]: Engine-ready token prompts.

        Raises:
            ValueError: If input formats are invalid or length limits exceeded.
        """
        raise NotImplementedError

    @abstractmethod
    async def render_prompt_and_embeds(
        self,
        *,
        prompt_or_prompts: str | list[str] | list[int] | list[list[int]] | None = None,
        prompt_embeds: bytes | list[bytes] | None = None,
        config: RenderConfig,
    ) -> list[TokensPrompt | EmbedsPrompt]:
        """
        Convert text/token and/or base64-encoded embeddings inputs into
        engine-ready prompt objects using a unified RenderConfig.

        At least one of `prompt_or_prompts` or `prompt_embeds` must be
        provided and non-empty. If both are omitted or empty (e.g., empty
        string and empty list), a `ValueError` is raised.

        Args:
            prompt_or_prompts: Text or token inputs to include.
            prompt_embeds: Base64-encoded bytes (or list thereof) containing a
                torch-saved tensor to be used as prompt embeddings.
            config: Render configuration controlling how prompts are prepared
                (e.g., tokenization and length handling).

        Returns:
            list[Union[TokensPrompt, EmbedsPrompt]]:
                Engine-ready prompt objects.

        Raises:
            ValueError: If both `prompt_or_prompts` and `prompt_embeds`
                are omitted or empty (decoder prompt cannot be empty), or if
                length limits are exceeded.
        """
        raise NotImplementedError

    def load_prompt_embeds(
        self,
        prompt_embeds: bytes | list[bytes],
        truncate_prompt_tokens: Annotated[int, Field(ge=0)] | None = None,
        cache_salt: str | None = None,
    ) -> list[EmbedsPrompt]:
        """Load and validate base64-encoded embeddings into prompt objects.

        Args:
            prompt_embeds: One base64-encoded, torch-saved tensor, or a list
                of them.
            truncate_prompt_tokens: If not `None`, only the last
                `truncate_prompt_tokens` rows of each tensor are kept.
            cache_salt: Optional cache salt attached to every returned prompt.

        Returns:
            One `EmbedsPrompt` per supplied embedding, in input order.

        Raises:
            VLLMValidationError: If prompt embeds are not enabled in the
                model config, or if a decoded payload is not a floating-point
                `torch.Tensor` of an accepted dtype, or cannot be reduced to
                2 dimensions.
        """
        if not self.model_config.enable_prompt_embeds:
            raise VLLMValidationError(
                "You must set `--enable-prompt-embeds` to input `prompt_embeds`.",
                parameter="prompt_embeds",
            )

        allowed_dtypes = (torch.float32, torch.bfloat16, torch.float16)

        def _load_and_validate_embed(embed: bytes) -> EmbedsPrompt:
            # Enable sparse tensor integrity checks to prevent out-of-bounds
            # writes from maliciously crafted tensors
            with torch.sparse.check_sparse_tensor_invariants():
                tensor = torch.load(
                    io.BytesIO(pybase64.b64decode(embed, validate=True)),
                    weights_only=True,
                    map_location=torch.device("cpu"),
                )
                # The payload comes from the request, so it must be validated
                # with an explicit raise: `assert` statements are removed when
                # Python runs with `-O`, which would silently disable the check.
                if (
                    not isinstance(tensor, torch.Tensor)
                    or tensor.dtype not in allowed_dtypes
                ):
                    raise VLLMValidationError(
                        "`prompt_embeds` must decode to a torch.Tensor with "
                        "dtype float32, bfloat16 or float16.",
                        parameter="prompt_embeds",
                    )
                tensor = tensor.to_dense()

            if tensor.dim() > 2:
                # Drop a leading singleton dimension, if there is one.
                tensor = tensor.squeeze(0)
                # NOTE(review): the rank check is applied only on this branch,
                # mirroring the original placement as best it can be read;
                # confirm whether tensors with fewer than 2 dims should also
                # be rejected.
                if tensor.dim() != 2:
                    raise VLLMValidationError(
                        "`prompt_embeds` must be a 2-dimensional tensor "
                        "(optionally with one leading dimension of size 1).",
                        parameter="prompt_embeds",
                    )

            if truncate_prompt_tokens is not None:
                tensor = tensor[-truncate_prompt_tokens:]

            embeds_prompt = EmbedsPrompt(prompt_embeds=tensor)
            if cache_salt is not None:
                embeds_prompt["cache_salt"] = cache_salt
            return embeds_prompt

        if isinstance(prompt_embeds, list):
            return [_load_and_validate_embed(embed) for embed in prompt_embeds]

        return [_load_and_validate_embed(prompt_embeds)]
class CompletionRenderer(BaseRenderer):
    """Renderer for completion-style inputs (raw text, token IDs, embeds)."""

    def __init__(
        self,
        model_config: ModelConfig,
        tokenizer: TokenizerLike | None = None,
        async_tokenizer_pool: dict[TokenizerLike, AsyncMicrobatchTokenizer]
        | None = None,
    ):
        """Set up the renderer; the async tokenizer is created on first use."""
        super().__init__(model_config, tokenizer)
        self.async_tokenizer_pool = async_tokenizer_pool
        self.async_tokenizer: AsyncMicrobatchTokenizer | None = None

    async def render_prompt(
        self,
        *,
        prompt_or_prompts: str | list[str] | list[int] | list[list[int]],
        config: RenderConfig,
    ) -> list[TokensPrompt]:
        """Render completion-style text/token inputs into token prompts.

        All parsed prompts are processed concurrently. A normalized
        truncation size of 0 short-circuits to an empty list. See the base
        class for detailed parameter documentation.
        """
        keep_last = config.verify_truncate_prompt_tokens(self.model_config)
        if keep_last == 0:
            return []

        pending = [
            self._create_prompt(
                parsed,
                config=config,
                truncate_prompt_tokens=keep_last,
            )
            for parsed in parse_raw_prompts(prompt_or_prompts)
        ]
        return await asyncio.gather(*pending)

    async def render_prompt_and_embeds(
        self,
        *,
        prompt_or_prompts: str | list[str] | list[int] | list[list[int]] | None = None,
        prompt_embeds: bytes | list[bytes] | None = None,
        config: RenderConfig,
    ) -> list[TokensPrompt | EmbedsPrompt]:
        """
        Render precomputed embedding prompts and/or text/token prompts.

        Embedding prompts (if any) come first in the result, followed by the
        rendered text/token prompts. A normalized truncation size of 0
        short-circuits to an empty list.
        """
        keep_last = config.verify_truncate_prompt_tokens(self.model_config)
        if keep_last == 0:
            return []

        results: list[TokensPrompt | EmbedsPrompt] = []

        if prompt_embeds is not None:
            embeds = self.load_prompt_embeds(
                prompt_embeds, keep_last, config.cache_salt
            )
            results.extend(embeds)

        # Text/token rendering is skipped for a missing or empty-string input.
        has_text_or_tokens = prompt_or_prompts is not None and prompt_or_prompts != ""
        if has_text_or_tokens:
            results.extend(
                await self.render_prompt(
                    prompt_or_prompts=prompt_or_prompts,
                    config=config,
                )
            )

        return results

    def _maybe_apply_truncation(
        self, token_ids: list[int], truncate_prompt_tokens: int | None
    ) -> list[int]:
        """Keep only the last `truncate_prompt_tokens` IDs, when applicable."""
        no_truncation_needed = (
            truncate_prompt_tokens is None
            or truncate_prompt_tokens >= len(token_ids)
        )
        if no_truncation_needed:
            return token_ids

        return token_ids[-truncate_prompt_tokens:]

    async def _create_prompt(
        self,
        prompt_input: TextPrompt | TokensPrompt,
        config: RenderConfig,
        truncate_prompt_tokens: int | None,
    ) -> TokensPrompt:
        """Dispatch one parsed prompt to the token-ID or text code path."""
        text, token_ids, _ = get_prompt_components(prompt_input)

        if token_ids is not None:
            # NOTE: detokenization is needed when echo is enabled,
            # where the input token IDs are decoded back to text.
            return await self._create_prompt_from_token_ids(
                token_ids,
                config.max_length,
                truncate_prompt_tokens,
                config.cache_salt,
                config.needs_detokenization,
            )

        if text is None:
            # TODO: Also handle embeds prompt using this method
            raise NotImplementedError

        return await self._create_prompt_from_text(
            text,
            config.max_length,
            truncate_prompt_tokens,
            config.add_special_tokens,
            config.cache_salt,
        )

    async def _create_prompt_from_text(
        self,
        text: str,
        max_length: int | None,
        truncate_prompt_tokens: int | None,
        add_special_tokens: bool,
        cache_salt: str | None,
    ) -> TokensPrompt:
        """Tokenize a text prompt asynchronously and wrap the result."""
        async_tokenizer = self._get_async_tokenizer()

        # Encoder-specific preprocessing: lower-case when the model asks for it.
        encoder_config = self.model_config.encoder_config
        if encoder_config is not None and encoder_config.get("do_lower_case", False):
            text = text.lower()

        # Truncation options are only passed when a truncation size is given.
        encode_kwargs: dict = {"add_special_tokens": add_special_tokens}
        if truncate_prompt_tokens is not None:
            encode_kwargs["truncation"] = True
            encode_kwargs["max_length"] = truncate_prompt_tokens

        encoded = await async_tokenizer(text, **encode_kwargs)

        return self._create_tokens_prompt(
            encoded.input_ids, max_length, cache_salt, text
        )

    async def _create_prompt_from_token_ids(
        self,
        token_ids: list[int],
        max_length: int | None,
        truncate_prompt_tokens: int | None,
        cache_salt: str | None,
        needs_detokenization: bool | None = False,
    ) -> TokensPrompt:
        """Truncate token IDs, optionally decode them, and wrap the result."""
        token_ids = self._maybe_apply_truncation(token_ids, truncate_prompt_tokens)

        decoded_text = None
        if needs_detokenization:
            decoded_text = await self._get_async_tokenizer().decode(token_ids)

        return self._create_tokens_prompt(
            token_ids=token_ids,
            max_length=max_length,
            cache_salt=cache_salt,
            prompt=decoded_text,
        )

    def _get_async_tokenizer(self) -> AsyncMicrobatchTokenizer:
        """Return the cached async tokenizer, creating it on first use.

        When a shared pool is configured, the tokenizer is looked up there
        and a newly created one is stored back into the pool.
        """
        if self.async_tokenizer is not None:
            return self.async_tokenizer

        base_tokenizer = self.tokenizer
        if base_tokenizer is None:
            raise ValueError("No tokenizer available for text input processing")

        pool = self.async_tokenizer_pool
        wrapped = None if pool is None else pool.get(base_tokenizer)
        if wrapped is None:
            wrapped = AsyncMicrobatchTokenizer(base_tokenizer)
            if pool is not None:
                pool[base_tokenizer] = wrapped

        self.async_tokenizer = wrapped
        return wrapped

    def _create_tokens_prompt(
        self,
        token_ids: list[int],
        max_length: int | None = None,
        cache_salt: str | None = None,
        prompt: str | None = None,
    ) -> TokensPrompt:
        """Build a `TokensPrompt` after enforcing the optional length limit.

        Raises:
            VLLMValidationError: If `max_length` is set and the number of
                token IDs exceeds it.
        """
        too_long = max_length is not None and len(token_ids) > max_length
        if too_long:
            raise VLLMValidationError(
                f"This model's maximum context length is {max_length} tokens. "
                f"However, your request has {len(token_ids)} input tokens. "
                "Please reduce the length of the input messages.",
                parameter="input_tokens",
                value=len(token_ids),
            )

        result = TokensPrompt(prompt_token_ids=token_ids)
        # Optional keys are only set when a value was supplied.
        if cache_salt is not None:
            result["cache_salt"] = cache_salt
        if prompt is not None:
            result["prompt"] = prompt

        return result
......@@ -4,12 +4,14 @@ from typing import Any
from pydantic import BaseModel, Field
from vllm.config import ModelConfig
from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionLogProbs
from vllm.entrypoints.openai.engine.protocol import (
SamplingParams,
StreamOptions,
)
from vllm.logprobs import Logprob
from vllm.renderers import TokenizeParams
from vllm.utils import random_uuid
......@@ -62,6 +64,12 @@ class GenerateRequest(BaseModel):
description="KVTransfer parameters used for disaggregated serving.",
)
def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
    """Build the `TokenizeParams` for a generate request.

    `max_total_tokens` is `None` and `max_output_tokens` is 0;
    `model_config` is accepted but not consulted here.
    """
    tok_params = TokenizeParams(
        max_total_tokens=None,
        max_output_tokens=0,
    )
    return tok_params
class GenerateResponseChoice(BaseModel):
index: int
......
......@@ -101,12 +101,13 @@ class ServingTokens(OpenAIServing):
# TODO(NickLucche): Change to EngineCoreRequest once Renderer work is
# completed
engine_prompt = TokensPrompt(prompt_token_ids=request.token_ids)
if request.features is not None:
engine_prompt["multi_modal_data"] = None
if hasattr(request, "cache_salt") and request.cache_salt is not None:
engine_prompt["cache_salt"] = request.cache_salt
engine_prompts = await self._preprocess_completion(
request,
prompt_input=request.token_ids,
prompt_embeds=None,
)
assert len(engine_prompts) == 1
engine_prompt = engine_prompts[0]
# Schedule the request and get the result generator.
result_generator: AsyncGenerator[RequestOutput, None] | None = None
......@@ -128,11 +129,15 @@ class ServingTokens(OpenAIServing):
else await self._get_trace_headers(raw_request.headers)
)
tok_params = request.build_tok_params(self.model_config)
tokenization_kwargs = tok_params.get_encode_kwargs()
result_generator = self.engine_client.generate(
engine_prompt,
sampling_params,
request_id,
lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
trace_headers=trace_headers,
priority=request.priority,
)
......
......@@ -6,8 +6,10 @@ from typing import Any, TypeAlias
from pydantic import ConfigDict, Field, model_validator
from vllm.config import ModelConfig
from vllm.entrypoints.chat_utils import (
ChatCompletionMessageParam,
ChatTemplateContentFormatOption,
)
from vllm.entrypoints.openai.chat_completion.protocol import (
ChatCompletionToolsParam,
......@@ -15,6 +17,7 @@ from vllm.entrypoints.openai.chat_completion.protocol import (
from vllm.entrypoints.openai.engine.protocol import (
OpenAIBaseModel,
)
from vllm.renderers import ChatParams, TokenizeParams, merge_kwargs
class TokenizeCompletionRequest(OpenAIBaseModel):
......@@ -35,6 +38,13 @@ class TokenizeCompletionRequest(OpenAIBaseModel):
),
)
def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
    """Build the `TokenizeParams` for a completion-style tokenize request.

    `max_total_tokens` is `None`, `max_output_tokens` is 0, and the
    request's `add_special_tokens` flag is forwarded; `model_config` is
    accepted but not consulted here.
    """
    tok_params = TokenizeParams(
        max_total_tokens=None,
        max_output_tokens=0,
        add_special_tokens=self.add_special_tokens,
    )
    return tok_params
class TokenizeChatRequest(OpenAIBaseModel):
model: str | None = None
......@@ -109,6 +119,30 @@ class TokenizeChatRequest(OpenAIBaseModel):
)
return data
def build_chat_params(
    self,
    default_template: str | None,
    default_template_content_format: ChatTemplateContentFormatOption,
) -> ChatParams:
    """Build the `ChatParams` for a chat-style tokenize request.

    The request's own `chat_template` takes precedence over
    `default_template` when it is truthy. The request's
    `chat_template_kwargs` are combined (via `merge_kwargs`) with its
    `add_generation_prompt` and `continue_final_message` flags.
    """
    template = self.chat_template or default_template

    # Flags that are always passed alongside the user-supplied kwargs.
    request_flags = {
        "add_generation_prompt": self.add_generation_prompt,
        "continue_final_message": self.continue_final_message,
    }
    template_kwargs = merge_kwargs(self.chat_template_kwargs, request_flags)

    return ChatParams(
        chat_template=template,
        chat_template_content_format=default_template_content_format,
        chat_template_kwargs=template_kwargs,
    )
def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
    """Build the `TokenizeParams` for a chat-style tokenize request.

    `max_total_tokens` is `None`, `max_output_tokens` is 0, and the
    request's `add_special_tokens` flag is forwarded; `model_config` is
    accepted but not consulted here.
    """
    tok_params = TokenizeParams(
        max_total_tokens=None,
        max_output_tokens=0,
        add_special_tokens=self.add_special_tokens,
    )
    return tok_params
TokenizeRequest: TypeAlias = TokenizeCompletionRequest | TokenizeChatRequest
......@@ -124,6 +158,13 @@ class DetokenizeRequest(OpenAIBaseModel):
model: str | None = None
tokens: list[int]
def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
    """Build the `TokenizeParams` for a detokenize request.

    `max_total_tokens` is `None`, `max_output_tokens` is 0, and
    `needs_detokenization` is enabled; `model_config` is accepted but not
    consulted here.
    """
    tok_params = TokenizeParams(
        max_total_tokens=None,
        max_output_tokens=0,
        needs_detokenization=True,
    )
    return tok_params
class DetokenizeResponse(OpenAIBaseModel):
prompt: str
......
......@@ -9,12 +9,9 @@ from fastapi import Request
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.engine.protocol import (
ErrorResponse,
)
from vllm.entrypoints.openai.engine.protocol import ErrorResponse
from vllm.entrypoints.openai.engine.serving import OpenAIServing
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.entrypoints.renderer import RenderConfig
from vllm.entrypoints.serve.tokenize.protocol import (
DetokenizeRequest,
DetokenizeResponse,
......@@ -83,21 +80,17 @@ class OpenAIServingTokenization(OpenAIServing):
_, engine_prompts = await self._preprocess_chat(
request,
self.renderer,
request.messages,
default_template=self.chat_template,
default_template_content_format=self.chat_template_content_format,
default_template_kwargs=None,
tool_dicts=tool_dicts,
chat_template=request.chat_template or self.chat_template,
chat_template_content_format=self.chat_template_content_format,
add_generation_prompt=request.add_generation_prompt,
continue_final_message=request.continue_final_message,
chat_template_kwargs=request.chat_template_kwargs,
add_special_tokens=request.add_special_tokens,
)
else:
renderer = self._get_completion_renderer()
engine_prompts = await renderer.render_prompt(
prompt_or_prompts=request.prompt,
config=self._build_render_config(request),
engine_prompts = await self._preprocess_completion(
request,
prompt_input=request.prompt,
prompt_embeds=None,
)
except (ValueError, TypeError, jinja2.TemplateError) as e:
logger.exception("Error in preprocessing prompt inputs")
......@@ -106,11 +99,14 @@ class OpenAIServingTokenization(OpenAIServing):
input_ids: list[int] = []
for engine_prompt in engine_prompts:
self._log_inputs(
request_id, engine_prompt, params=None, lora_request=lora_request
request_id,
engine_prompt,
params=None,
lora_request=lora_request,
)
if isinstance(engine_prompt, dict) and "prompt_token_ids" in engine_prompt:
input_ids.extend(engine_prompt["prompt_token_ids"])
if "prompt_token_ids" in engine_prompt:
input_ids.extend(engine_prompt["prompt_token_ids"]) # type: ignore[typeddict-item]
token_strs = None
if request.return_token_strs:
......@@ -136,7 +132,6 @@ class OpenAIServingTokenization(OpenAIServing):
request_id = f"tokenize-{self._base_request_id(raw_request)}"
lora_request = self._maybe_get_adapters(request)
tokenizer = self.renderer.get_tokenizer()
self._log_inputs(
request_id,
......@@ -145,14 +140,13 @@ class OpenAIServingTokenization(OpenAIServing):
lora_request=lora_request,
)
prompt_input = await self._tokenize_prompt_input_async(
request,
tokenizer,
request.tokens,
engine_prompt = await self.renderer.tokenize_prompt_async(
TokensPrompt(prompt_token_ids=request.tokens),
request.build_tok_params(self.model_config),
)
input_text = prompt_input["prompt"]
prompt_text = engine_prompt["prompt"] # type: ignore[typeddict-item]
return DetokenizeResponse(prompt=input_text)
return DetokenizeResponse(prompt=prompt_text)
async def get_tokenizer_info(
self,
......@@ -165,9 +159,6 @@ class OpenAIServingTokenization(OpenAIServing):
except Exception as e:
return self.create_error_response(f"Failed to get tokenizer info: {str(e)}")
def _build_render_config(self, request: TokenizeRequest) -> RenderConfig:
return RenderConfig(add_special_tokens=request.add_special_tokens)
@dataclass
class TokenizerInfo:
......
......@@ -8,7 +8,7 @@ import os
from argparse import Namespace
from logging import Logger
from string import Template
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING
import regex as re
from fastapi import Request
......@@ -18,9 +18,9 @@ from starlette.background import BackgroundTask, BackgroundTasks
from vllm import envs
from vllm.engine.arg_utils import EngineArgs
from vllm.inputs import EmbedsPrompt, TokensPrompt
from vllm.inputs.parse import get_prompt_len
from vllm.logger import current_formatter_type, init_logger
from vllm.platforms import current_platform
from vllm.utils import length_from_prompt_token_ids_or_embeds
from vllm.utils.argparse_utils import FlexibleArgumentParser
if TYPE_CHECKING:
......@@ -34,9 +34,7 @@ if TYPE_CHECKING:
StreamOptions,
)
from vllm.entrypoints.openai.models.protocol import LoRAModulePath
from vllm.entrypoints.openai.responses.protocol import (
ResponsesRequest,
)
from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
else:
ChatCompletionRequest = object
CompletionRequest = object
......@@ -188,33 +186,6 @@ def cli_env_setup():
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
def _validate_truncation_size(
max_model_len: int,
truncate_prompt_tokens: int | None,
tokenization_kwargs: dict[str, Any] | None = None,
) -> int | None:
if truncate_prompt_tokens is not None:
if truncate_prompt_tokens <= -1:
truncate_prompt_tokens = max_model_len
if truncate_prompt_tokens > max_model_len:
raise ValueError(
f"truncate_prompt_tokens value ({truncate_prompt_tokens}) "
f"is greater than max_model_len ({max_model_len})."
f" Please, select a smaller truncation size."
)
if tokenization_kwargs is not None:
tokenization_kwargs["truncation"] = True
tokenization_kwargs["max_length"] = truncate_prompt_tokens
else:
if tokenization_kwargs is not None:
tokenization_kwargs["truncation"] = False
return truncate_prompt_tokens
def get_max_tokens(
max_model_len: int,
request: "CompletionRequest | ChatCompletionRequest | ResponsesRequest",
......@@ -233,10 +204,7 @@ def get_max_tokens(
# CompletionRequest (also a fallback for ChatCompletionRequest)
max_tokens = getattr(request, "max_tokens", None)
input_length = length_from_prompt_token_ids_or_embeds(
prompt.get("prompt_token_ids"), # type: ignore[arg-type]
prompt.get("prompt_embeds"), # type: ignore[arg-type]
)
input_length = get_prompt_len(prompt)
default_max_tokens = max_model_len - input_length
max_output_tokens = current_platform.get_max_output_tokens(input_length)
......
......@@ -21,12 +21,7 @@ else:
MultiModalUUIDDict = object
class TextPrompt(TypedDict):
"""Schema for a text prompt."""
prompt: str
"""The input text to be tokenized before passing to the model."""
class _CommonKeys(TypedDict):
multi_modal_data: NotRequired[MultiModalDataDict | None]
"""
Optional multi-modal data to pass to the model,
......@@ -56,7 +51,14 @@ class TextPrompt(TypedDict):
"""
class TokensPrompt(TypedDict):
class TextPrompt(_CommonKeys):
"""Schema for a text prompt."""
prompt: str
"""The input text to be tokenized before passing to the model."""
class TokensPrompt(_CommonKeys):
"""Schema for a tokenized prompt."""
prompt_token_ids: list[int]
......@@ -68,47 +70,15 @@ class TokensPrompt(TypedDict):
token_type_ids: NotRequired[list[int]]
"""A list of token type IDs to pass to the cross encoder model."""
multi_modal_data: NotRequired[MultiModalDataDict | None]
"""
Optional multi-modal data to pass to the model,
if the model supports it.
"""
mm_processor_kwargs: NotRequired[dict[str, Any] | None]
"""
Optional multi-modal processor kwargs to be forwarded to the
multimodal input mapper & processor. Note that if multiple modalities
have registered mappers etc for the model being considered, we attempt
to pass the mm_processor_kwargs to each of them.
"""
multi_modal_uuids: NotRequired[MultiModalUUIDDict]
"""
Optional user-specified UUIDs for multimodal items, mapped by modality.
Lists must match the number of items per modality and may contain `None`.
For `None` entries, the hasher will compute IDs automatically; non-None
entries override the default hashes for caching.
"""
cache_salt: NotRequired[str]
"""
Optional cache salt to be used for prefix caching.
"""
class EmbedsPrompt(TypedDict):
class EmbedsPrompt(_CommonKeys):
"""Schema for a prompt provided via token embeddings."""
prompt_embeds: torch.Tensor
"""The embeddings of the prompt."""
cache_salt: NotRequired[str]
"""
Optional cache salt to be used for prefix caching.
"""
class DataPrompt(TypedDict):
class DataPrompt(_CommonKeys):
"""Represents generic inputs handled by IO processor plugins."""
data: Any
......@@ -197,7 +167,7 @@ class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]):
mm_processor_kwargs: NotRequired[dict[str, Any]]
PromptType: TypeAlias = SingletonPrompt | ExplicitEncoderDecoderPrompt
PromptType: TypeAlias = SingletonPrompt | ExplicitEncoderDecoderPrompt[Any, Any]
"""
Set of possible schemas for an LLM input, including
both decoder-only and encoder/decoder input types:
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Sequence
from typing import TYPE_CHECKING, Literal, NamedTuple, TypeAlias, TypedDict, cast
from typing import TYPE_CHECKING, Literal, NamedTuple, TypeAlias, TypedDict
from typing_extensions import TypeIs
from vllm.utils.collection_utils import is_list_of
from vllm.utils import length_from_prompt_token_ids_or_embeds
from .data import (
EmbedsPrompt,
......@@ -22,50 +21,6 @@ if TYPE_CHECKING:
import torch
def parse_raw_prompts(
    prompt: str | list[str] | list[int] | list[list[int]],
) -> Sequence[TextPrompt] | Sequence[TokensPrompt]:
    """Normalize a raw completion prompt into a list of prompt dicts.

    Accepted shapes are a single string, a list of strings, a single list
    of token IDs, or a list of token-ID lists.

    Raises:
        ValueError: If the list is empty, or consists of exactly one empty
            token list.
        TypeError: If the input matches none of the accepted shapes, or a
            nested list contains non-list / non-integer elements.
    """
    # case 1: a string
    if isinstance(prompt, str):
        return [TextPrompt(prompt=prompt)]

    if isinstance(prompt, list):
        if not prompt:
            raise ValueError("please provide at least one prompt")

        # case 2: array of strings
        if is_list_of(prompt, str):
            texts = cast(list[str], prompt)
            return [TextPrompt(prompt=text) for text in texts]

        # case 3: array of tokens
        if is_list_of(prompt, int):
            single = cast(list[int], prompt)
            return [TokensPrompt(prompt_token_ids=single)]

        # case 4: array of token arrays
        if is_list_of(prompt, list):
            only_item_is_empty_list = (
                len(prompt) == 1
                and isinstance(prompt[0], list)
                and len(prompt[0]) == 0
            )
            if only_item_is_empty_list:
                raise ValueError("please provide at least one prompt")

            # Every element is checked explicitly before building prompts.
            for item in prompt:
                if not isinstance(item, list):
                    raise TypeError(
                        "prompt must be a list of lists, but found a non-list element."
                    )
                if not is_list_of(item, int):
                    raise TypeError(
                        "Nested lists of tokens must contain only integers."
                    )

            batches = cast(list[list[int]], prompt)
            return [TokensPrompt(prompt_token_ids=ids) for ids in batches]

    raise TypeError(
        "prompt must be a string, array of strings, "
        "array of tokens, or array of token arrays"
    )
class ParsedStrPrompt(TypedDict):
type: Literal["str"]
content: str
......@@ -145,3 +100,10 @@ def get_prompt_components(prompt: PromptType) -> PromptComponents:
token_ids=prompt.get("prompt_token_ids"), # type: ignore[arg-type]
embeds=prompt.get("prompt_embeds"),
)
def get_prompt_len(prompt: TokensPrompt | EmbedsPrompt):
    """
    Return the length of a tokenized or embedded prompt.

    Looks up the `prompt_token_ids` and `prompt_embeds` entries (either may
    be absent) and delegates to `length_from_prompt_token_ids_or_embeds`.
    """
    token_ids = prompt.get("prompt_token_ids")  # type: ignore[arg-type]
    embeds = prompt.get("prompt_embeds")  # type: ignore[arg-type]
    return length_from_prompt_token_ids_or_embeds(token_ids, embeds)
......@@ -209,6 +209,7 @@ class UltravoxMultiModalProcessor(BaseMultiModalProcessor[UltravoxProcessingInfo
item_processor_data = dict(**mm_data, audios=audios)
# some tokenizer kwargs are incompatible with UltravoxProcessor
tok_kwargs.pop("add_special_tokens", None)
tok_kwargs.pop("padding", None)
tok_kwargs.pop("truncation", None)
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from .params import ChatParams, TokenizeParams, merge_kwargs
from .protocol import RendererLike
from .registry import RendererRegistry, renderer_from_config
__all__ = ["RendererLike", "RendererRegistry", "renderer_from_config"]
__all__ = [
"RendererLike",
"RendererRegistry",
"renderer_from_config",
"ChatParams",
"TokenizeParams",
"merge_kwargs",
]
......@@ -9,11 +9,12 @@ from vllm.entrypoints.chat_utils import (
parse_chat_messages,
parse_chat_messages_async,
)
from vllm.inputs import TextPrompt, TokensPrompt
from vllm.inputs import EmbedsPrompt, TextPrompt, TokensPrompt
from vllm.logger import init_logger
from vllm.tokenizers import cached_get_tokenizer
from vllm.tokenizers.deepseek_v32 import DeepseekV32Tokenizer
from .params import ChatParams
from .protocol import RendererLike
logger = init_logger(__name__)
......@@ -61,8 +62,8 @@ class DeepseekV32Renderer(RendererLike):
def render_messages(
self,
messages: list[ChatCompletionMessageParam],
**kwargs,
) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt]:
params: ChatParams,
) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt | EmbedsPrompt]:
tokenizer = self.get_tokenizer()
conversation, mm_data, mm_uuids = parse_chat_messages(
messages,
......@@ -74,26 +75,22 @@ class DeepseekV32Renderer(RendererLike):
prompt_raw = tokenizer.apply_chat_template(
conversation=conversation,
messages=messages,
**kwargs,
**params.get_apply_chat_template_kwargs(),
)
prompt = (
TextPrompt(prompt=prompt_raw)
if isinstance(prompt_raw, str)
else TokensPrompt(prompt_token_ids=prompt_raw)
)
prompt = self.render_completion(prompt_raw)
if mm_data is not None:
prompt["multi_modal_data"] = mm_data
if mm_uuids is not None:
prompt["multi_modal_uuids"] = mm_uuids
return conversation, prompt # type: ignore[return-value]
return conversation, prompt
async def render_messages_async(
self,
messages: list[ChatCompletionMessageParam],
**kwargs,
) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt]:
params: ChatParams,
) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt | EmbedsPrompt]:
tokenizer = self.get_tokenizer()
conversation, mm_data, mm_uuids = await parse_chat_messages_async(
messages,
......@@ -105,17 +102,13 @@ class DeepseekV32Renderer(RendererLike):
prompt_raw = tokenizer.apply_chat_template(
conversation=conversation,
messages=messages,
**kwargs,
**params.get_apply_chat_template_kwargs(),
)
prompt = (
TextPrompt(prompt=prompt_raw)
if isinstance(prompt_raw, str)
else TokensPrompt(prompt_token_ids=prompt_raw)
)
prompt = self.render_completion(prompt_raw)
if mm_data is not None:
prompt["multi_modal_data"] = mm_data
if mm_uuids is not None:
prompt["multi_modal_uuids"] = mm_uuids
return conversation, prompt # type: ignore[return-value]
return conversation, prompt
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from io import BytesIO
from typing import TYPE_CHECKING
import pybase64
import torch
from vllm.exceptions import VLLMValidationError
if TYPE_CHECKING:
from vllm.config import ModelConfig
def safe_load_prompt_embeds(
    model_config: "ModelConfig",
    embed: bytes,
) -> torch.Tensor:
    """
    Decode and validate a base64-encoded, `torch.save`-serialized tensor.

    Args:
        model_config: Model configuration; `enable_prompt_embeds` must be set.
        embed: Base64-encoded bytes of a serialized tensor.

    Returns:
        A dense 2-D tensor loaded onto the CPU.
        NOTE(review): presumably shaped (num_tokens, hidden_size); only the
        rank is checked here - confirm against the consumer.

    Raises:
        VLLMValidationError: If prompt embeddings are disabled, or the
            payload is not a float32/bfloat16/float16 tensor that reduces
            to exactly two dimensions.
    """
    if not model_config.enable_prompt_embeds:
        raise VLLMValidationError(
            "You must set `--enable-prompt-embeds` to input `prompt_embeds`.",
            parameter="prompt_embeds",
        )

    # Enable sparse tensor integrity checks to prevent out-of-bounds
    # writes from maliciously crafted tensors
    with torch.sparse.check_sparse_tensor_invariants():
        tensor = torch.load(
            BytesIO(pybase64.b64decode(embed, validate=True)),
            weights_only=True,
            map_location=torch.device("cpu"),
        )

        # The payload is untrusted input, so validate it with explicit
        # errors instead of `assert` (which is removed under `python -O`).
        if not isinstance(tensor, torch.Tensor):
            raise VLLMValidationError(
                "`prompt_embeds` must decode to a single tensor, "
                f"but got {type(tensor).__name__}.",
                parameter="prompt_embeds",
            )
        if tensor.dtype not in (torch.float32, torch.bfloat16, torch.float16):
            raise VLLMValidationError(
                "`prompt_embeds` must have dtype float32, bfloat16 or "
                f"float16, but got {tensor.dtype}.",
                parameter="prompt_embeds",
            )

        tensor = tensor.to_dense()

    # Drop a leading singleton batch dimension if one is present.
    if tensor.dim() > 2:
        tensor = tensor.squeeze(0)

    if tensor.dim() != 2:
        raise VLLMValidationError(
            "`prompt_embeds` must be a 2-D tensor (optionally with a "
            f"leading batch dimension of size 1), but got {tensor.dim()} "
            "dimensions.",
            parameter="prompt_embeds",
        )

    return tensor
......@@ -9,11 +9,12 @@ from vllm.entrypoints.chat_utils import (
parse_chat_messages,
parse_chat_messages_async,
)
from vllm.inputs import TextPrompt, TokensPrompt
from vllm.inputs import EmbedsPrompt, TextPrompt, TokensPrompt
from vllm.logger import init_logger
from vllm.tokenizers import cached_get_tokenizer
from vllm.tokenizers.grok2 import Grok2Tokenizer
from .params import ChatParams
from .protocol import RendererLike
logger = init_logger(__name__)
......@@ -61,8 +62,8 @@ class Grok2Renderer(RendererLike):
def render_messages(
self,
messages: list[ChatCompletionMessageParam],
**kwargs,
) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt]:
params: ChatParams,
) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt | EmbedsPrompt]:
tokenizer = self.get_tokenizer()
conversation, mm_data, mm_uuids = parse_chat_messages(
messages,
......@@ -74,26 +75,22 @@ class Grok2Renderer(RendererLike):
prompt_raw = tokenizer.apply_chat_template(
conversation=conversation,
messages=messages,
**kwargs,
**params.get_apply_chat_template_kwargs(),
)
prompt = (
TextPrompt(prompt=prompt_raw)
if isinstance(prompt_raw, str)
else TokensPrompt(prompt_token_ids=prompt_raw)
)
prompt = self.render_completion(prompt_raw)
if mm_data is not None:
prompt["multi_modal_data"] = mm_data
if mm_uuids is not None:
prompt["multi_modal_uuids"] = mm_uuids
return conversation, prompt # type: ignore[return-value]
return conversation, prompt
async def render_messages_async(
self,
messages: list[ChatCompletionMessageParam],
**kwargs,
) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt]:
params: ChatParams,
) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt | EmbedsPrompt]:
tokenizer = self.get_tokenizer()
conversation, mm_data, mm_uuids = await parse_chat_messages_async(
messages,
......@@ -105,17 +102,13 @@ class Grok2Renderer(RendererLike):
prompt_raw = tokenizer.apply_chat_template(
conversation=conversation,
messages=messages,
**kwargs,
**params.get_apply_chat_template_kwargs(),
)
prompt = (
TextPrompt(prompt=prompt_raw)
if isinstance(prompt_raw, str)
else TokensPrompt(prompt_token_ids=prompt_raw)
)
prompt = self.render_completion(prompt_raw)
if mm_data is not None:
prompt["multi_modal_data"] = mm_data
if mm_uuids is not None:
prompt["multi_modal_uuids"] = mm_uuids
return conversation, prompt # type: ignore[return-value]
return conversation, prompt
......@@ -25,7 +25,7 @@ from vllm.entrypoints.chat_utils import (
parse_chat_messages,
parse_chat_messages_async,
)
from vllm.inputs import TextPrompt, TokensPrompt
from vllm.inputs import EmbedsPrompt, TextPrompt, TokensPrompt
from vllm.logger import init_logger
from vllm.tokenizers import cached_get_tokenizer
from vllm.tokenizers.hf import CachedHfTokenizer, HfTokenizer
......@@ -33,6 +33,7 @@ from vllm.transformers_utils.chat_templates import get_chat_template_fallback_pa
from vllm.transformers_utils.processor import cached_get_processor
from vllm.utils.func_utils import supports_kw
from .params import ChatParams
from .protocol import RendererLike
if TYPE_CHECKING:
......@@ -632,9 +633,8 @@ class HfRenderer(RendererLike):
def render_messages(
self,
messages: list[ChatCompletionMessageParam],
chat_template_content_format: ChatTemplateContentFormatOption = "auto",
**kwargs,
) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt]:
params: ChatParams,
) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt | EmbedsPrompt]:
model_config = self.config
tokenizer = self.get_tokenizer()
......@@ -642,9 +642,9 @@ class HfRenderer(RendererLike):
messages,
model_config,
content_format=resolve_chat_template_content_format(
chat_template=kwargs.get("chat_template"),
tools=kwargs.get("tools"),
given_format=chat_template_content_format,
chat_template=params.chat_template,
tools=params.chat_template_kwargs.get("tools"),
given_format=params.chat_template_content_format,
tokenizer=tokenizer,
model_config=model_config,
),
......@@ -654,7 +654,7 @@ class HfRenderer(RendererLike):
model_config,
tokenizer,
conversation,
**kwargs,
**params.get_apply_chat_template_kwargs(),
)
# NOTE: use_unified_vision_chunk is currently specific to Kimi-K2.5
......@@ -666,7 +666,7 @@ class HfRenderer(RendererLike):
):
mm_uuids = rebuild_mm_uuids_from_mm_data(mm_uuids, mm_data)
# get video placehoder, replace it with runtime video-chunk prompts
# get video placeholder, replace it with runtime video-chunk prompts
video_placeholder = getattr(
model_config.hf_config, "video_placeholder", None
)
......@@ -676,24 +676,19 @@ class HfRenderer(RendererLike):
video_placeholder,
)
prompt = (
TextPrompt(prompt=prompt_raw)
if isinstance(prompt_raw, str)
else TokensPrompt(prompt_token_ids=prompt_raw)
)
prompt = self.render_completion(prompt_raw)
if mm_data is not None:
prompt["multi_modal_data"] = mm_data
if mm_uuids is not None:
prompt["multi_modal_uuids"] = mm_uuids
return conversation, prompt # type: ignore[return-value]
return conversation, prompt
async def render_messages_async(
self,
messages: list[ChatCompletionMessageParam],
chat_template_content_format: ChatTemplateContentFormatOption = "auto",
**kwargs,
) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt]:
params: ChatParams,
) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt | EmbedsPrompt]:
model_config = self.config
tokenizer = self.get_tokenizer()
......@@ -701,9 +696,9 @@ class HfRenderer(RendererLike):
messages,
model_config,
content_format=resolve_chat_template_content_format(
chat_template=kwargs.get("chat_template"),
tools=kwargs.get("tools"),
given_format=chat_template_content_format,
chat_template=params.chat_template,
tools=params.chat_template_kwargs.get("tools"),
given_format=params.chat_template_content_format,
tokenizer=tokenizer,
model_config=model_config,
),
......@@ -713,7 +708,7 @@ class HfRenderer(RendererLike):
model_config,
tokenizer,
conversation,
**kwargs,
**params.get_apply_chat_template_kwargs(),
)
# NOTE: use_unified_vision_chunk is currently specific to Kimi-K2.5
......@@ -723,9 +718,7 @@ class HfRenderer(RendererLike):
and mm_uuids is not None
and mm_data is not None
):
mm_uuids = rebuild_mm_uuids_from_mm_data(mm_uuids, mm_data)
# get video placehoder, replace it with runtime video-chunk prompts
# get video placeholder, replace it with runtime video-chunk prompts
video_placeholder = getattr(
model_config.hf_config, "video_placeholder", None
)
......@@ -735,14 +728,10 @@ class HfRenderer(RendererLike):
video_placeholder,
)
prompt = (
TextPrompt(prompt=prompt_raw)
if isinstance(prompt_raw, str)
else TokensPrompt(prompt_token_ids=prompt_raw)
)
prompt = self.render_completion(prompt_raw)
if mm_data is not None:
prompt["multi_modal_data"] = mm_data
if mm_uuids is not None:
prompt["multi_modal_uuids"] = mm_uuids
return conversation, prompt # type: ignore[return-value]
return conversation, prompt
......@@ -10,12 +10,13 @@ from vllm.entrypoints.chat_utils import (
parse_chat_messages,
parse_chat_messages_async,
)
from vllm.inputs import TextPrompt, TokensPrompt
from vllm.inputs import EmbedsPrompt, TextPrompt, TokensPrompt
from vllm.logger import init_logger
from vllm.tokenizers import cached_get_tokenizer
from vllm.tokenizers.mistral import MistralTokenizer
from vllm.utils.async_utils import make_async
from .params import ChatParams
from .protocol import RendererLike
logger = init_logger(__name__)
......@@ -95,8 +96,8 @@ class MistralRenderer(RendererLike):
def render_messages(
self,
messages: list[ChatCompletionMessageParam],
**kwargs,
) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt]:
params: ChatParams,
) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt | EmbedsPrompt]:
tokenizer = self.get_tokenizer()
conversation, mm_data, mm_uuids = parse_chat_messages(
messages,
......@@ -104,25 +105,25 @@ class MistralRenderer(RendererLike):
content_format="string",
)
prompt_raw = safe_apply_chat_template(tokenizer, messages, **kwargs)
prompt = (
TextPrompt(prompt=prompt_raw)
if isinstance(prompt_raw, str)
else TokensPrompt(prompt_token_ids=prompt_raw)
prompt_raw = safe_apply_chat_template(
tokenizer,
messages,
**params.get_apply_chat_template_kwargs(),
)
prompt = self.render_completion(prompt_raw)
if mm_data is not None:
prompt["multi_modal_data"] = mm_data
if mm_uuids is not None:
prompt["multi_modal_uuids"] = mm_uuids
return conversation, prompt # type: ignore[return-value]
return conversation, prompt
async def render_messages_async(
self,
messages: list[ChatCompletionMessageParam],
**kwargs,
) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt]:
params: ChatParams,
) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt | EmbedsPrompt]:
tokenizer = self.get_tokenizer()
conversation, mm_data, mm_uuids = await parse_chat_messages_async(
messages,
......@@ -131,17 +132,15 @@ class MistralRenderer(RendererLike):
)
prompt_raw = await self._apply_chat_template_async(
tokenizer, messages, **kwargs
tokenizer,
messages,
**params.get_apply_chat_template_kwargs(),
)
prompt = (
TextPrompt(prompt=prompt_raw)
if isinstance(prompt_raw, str)
else TokensPrompt(prompt_token_ids=prompt_raw)
)
prompt = self.render_completion(prompt_raw)
if mm_data is not None:
prompt["multi_modal_data"] = mm_data
if mm_uuids is not None:
prompt["multi_modal_uuids"] = mm_uuids
return conversation, prompt # type: ignore[return-value]
return conversation, prompt
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, TypeVar
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
from vllm.exceptions import VLLMValidationError
from vllm.inputs import EmbedsPrompt, TextPrompt, TokensPrompt
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
from vllm.utils.import_utils import LazyLoader
if TYPE_CHECKING:
import torch
else:
torch = LazyLoader("torch", globals(), "torch")
logger = init_logger(__name__)
_S = TypeVar("_S", list[int], "torch.Tensor")
def merge_kwargs(
defaults: dict[str, Any] | None,
overrides: dict[str, Any] | None,
/,
*,
unset_values: tuple[object, ...] = (None, "auto"),
) -> dict[str, Any]:
if defaults is None:
defaults = {}
if overrides is None:
overrides = {}
return defaults | {k: v for k, v in overrides.items() if v not in unset_values}
@dataclass(frozen=True)
class ChatParams:
    """Settings that govern how chat messages are parsed and templated."""

    chat_template: str | None = None
    """Chat template to apply; `None` leaves the choice to the renderer."""

    chat_template_content_format: ChatTemplateContentFormatOption = "auto"
    """Content format expected by the chat template."""

    chat_template_kwargs: dict[str, Any] = field(default_factory=dict)
    """Extra keyword arguments forwarded to the chat template."""

    def with_defaults(self, default_chat_template_kwargs: dict[str, Any] | None):
        """
        Return params whose template kwargs fall back to the given defaults.

        Values already present in `chat_template_kwargs` win over the
        defaults. When no defaults are supplied, `self` is returned as-is.
        """
        if default_chat_template_kwargs:
            combined_kwargs = merge_kwargs(
                default_chat_template_kwargs,
                self.chat_template_kwargs,
            )
            return ChatParams(
                chat_template=self.chat_template,
                chat_template_content_format=self.chat_template_content_format,
                chat_template_kwargs=combined_kwargs,
            )

        return self

    def get_apply_chat_template_kwargs(self) -> dict[str, Any]:
        """Build the arguments to pass to `tokenizer.apply_chat_template`."""
        # An unset `chat_template` must not clobber a value supplied
        # through `chat_template_kwargs`; `merge_kwargs` skips it.
        template_override = {"chat_template": self.chat_template}
        return merge_kwargs(self.chat_template_kwargs, template_override)
@dataclass(frozen=True)
class TokenizeParams:
    """Configuration to control how prompts are tokenized."""

    max_total_tokens: int | None
    """
    Maximum allowed number of input + output tokens.

    Usually, this refers to the model's context length.
    """

    max_output_tokens: int = 0
    """Maximum requested number of output tokens."""

    pad_prompt_tokens: int | None = None
    """
    Number of tokens to pad to:
    - `None` means no padding.
    - `-1` maps to `max_input_tokens`.
    """

    truncate_prompt_tokens: int | None = None
    """
    Number of tokens to keep:
    - `None` means no truncation.
    - `-1` maps to `max_input_tokens`.
    """

    do_lower_case: bool = False
    """Whether to normalize text to lower case before tokenization."""

    add_special_tokens: bool = True
    """Whether to add special tokens."""

    needs_detokenization: bool = False
    """
    Whether the tokenized prompt needs to contain the original text.

    Not to be confused with `SamplingParams.detokenize` which deals
    with the output generated by the model.
    """

    max_total_tokens_param: str = "max_total_tokens"
    """Override this to edit the message for validation errors."""

    max_output_tokens_param: str = "max_output_tokens"
    """Override this to edit the message for validation errors."""

    truncate_prompt_tokens_param: str = "truncate_prompt_tokens"
    """Override this to edit the message for validation errors."""

    @property
    def max_input_tokens(self) -> int | None:
        """Maximum allowed number of input tokens."""
        if self.max_total_tokens is None:
            return None

        return self.max_total_tokens - self.max_output_tokens

    def __post_init__(self) -> None:
        """
        Validate that the configured limits are mutually consistent.

        Raises:
            VLLMValidationError: If more output tokens are requested than the
                total budget, or the truncation size exceeds the input budget.
        """
        max_total_tokens = self.max_total_tokens
        max_output_tokens = self.max_output_tokens
        max_input_tokens = self.max_input_tokens
        truncate_prompt_tokens = self.truncate_prompt_tokens

        if (
            max_output_tokens is not None
            and max_total_tokens is not None
            and max_output_tokens > max_total_tokens
        ):
            raise VLLMValidationError(
                f"{self.max_output_tokens_param}={max_output_tokens} "
                f"cannot be greater than "
                f"{self.max_total_tokens_param}={max_total_tokens}. "
                f"Please request fewer output tokens.",
                parameter=self.max_output_tokens_param,
                value=max_output_tokens,
            )

        if (
            max_input_tokens is not None
            and truncate_prompt_tokens is not None
            and truncate_prompt_tokens > max_input_tokens
        ):
            raise VLLMValidationError(
                f"{self.truncate_prompt_tokens_param}={truncate_prompt_tokens} "
                f"cannot be greater than {self.max_total_tokens_param} - "
                f"{self.max_output_tokens_param} = {max_input_tokens}. "
                f"Please request a smaller truncation size.",
                parameter=self.truncate_prompt_tokens_param,
                value=truncate_prompt_tokens,
            )

    def with_kwargs(self, tokenization_kwargs: dict[str, Any] | None):
        """
        Return new params with HF-style tokenization kwargs applied on top.

        Recognized keys are popped from `tokenization_kwargs`, so the passed
        dict is modified in place. Any keys left over are not supported and
        are reported with a warning.

        Args:
            tokenization_kwargs: Overrides such as `max_length`, `padding`,
                `truncation`, `add_special_tokens`; `None` means no overrides.
        """
        if tokenization_kwargs is None:
            tokenization_kwargs = {}

        max_length = tokenization_kwargs.pop("max_length", self.max_input_tokens)
        pad_prompt_tokens = tokenization_kwargs.pop(
            "pad_prompt_tokens", self.pad_prompt_tokens
        )
        truncate_prompt_tokens = tokenization_kwargs.pop(
            "truncate_prompt_tokens", self.truncate_prompt_tokens
        )
        do_lower_case = tokenization_kwargs.pop("do_lower_case", self.do_lower_case)
        add_special_tokens = tokenization_kwargs.pop(
            "add_special_tokens", self.add_special_tokens
        )
        needs_detokenization = tokenization_kwargs.pop(
            "needs_detokenization", self.needs_detokenization
        )

        # https://huggingface.co/docs/transformers/en/pad_truncation
        # Compare against `None` rather than relying on truthiness so that an
        # explicit `False` reaches the "disable" branch below.
        padding = tokenization_kwargs.pop("padding", None)
        if padding is not None:
            if padding == "max_length":
                pad_prompt_tokens = max_length
            elif padding in (False, "do_not_pad"):
                pad_prompt_tokens = None
            else:
                # To emit the below warning
                tokenization_kwargs["padding"] = padding

        truncation = tokenization_kwargs.pop("truncation", None)
        if truncation is not None:
            if truncation in (True, "longest_first"):
                truncate_prompt_tokens = max_length
            elif truncation in (False, "do_not_truncate"):
                truncate_prompt_tokens = None
            else:
                # To emit the below warning
                tokenization_kwargs["truncation"] = truncation

        if tokenization_kwargs:
            logger.warning(
                "The following tokenization arguments are not supported "
                "by vLLM Renderer and will be ignored: %s",
                tokenization_kwargs,
            )

        max_total_tokens = self.max_total_tokens

        return TokenizeParams(
            max_total_tokens=max_total_tokens,
            # Express `max_length` as the input budget left after reserving
            # output tokens.
            max_output_tokens=(
                0
                if max_total_tokens is None or max_length is None
                else max_total_tokens - max_length
            ),
            pad_prompt_tokens=pad_prompt_tokens,
            truncate_prompt_tokens=truncate_prompt_tokens,
            do_lower_case=do_lower_case,
            add_special_tokens=add_special_tokens,
            needs_detokenization=needs_detokenization,
            # Keep the customized parameter names so validation errors
            # raised by the derived params use the same wording.
            max_total_tokens_param=self.max_total_tokens_param,
            max_output_tokens_param=self.max_output_tokens_param,
            truncate_prompt_tokens_param=self.truncate_prompt_tokens_param,
        )

    def get_encode_kwargs(self) -> dict[str, Any]:
        """The arguments to pass to `tokenizer.encode`."""
        max_length = self.truncate_prompt_tokens
        # A negative truncation size stands for "use the full input budget".
        if max_length is not None and max_length < 0:
            max_length = self.max_input_tokens

        return dict(
            truncation=self.truncate_prompt_tokens is not None,
            max_length=max_length,
            add_special_tokens=self.add_special_tokens,
        )

    def _apply_lowercase(self, tokenizer: TokenizerLike | None, text: str) -> str:
        """Lower-case the text when `do_lower_case` is set."""
        if self.do_lower_case:
            text = text.lower()

        return text

    def _validate_text(self, tokenizer: TokenizerLike | None, text: str) -> str:
        """Apply all validators to prompt text."""
        # TODO: Implement https://github.com/vllm-project/vllm/pull/31366
        for validator in (self._apply_lowercase,):
            text = validator(tokenizer, text)

        return text

    def apply_pre_tokenization(
        self,
        tokenizer: TokenizerLike | None,
        prompt: TextPrompt,
    ) -> TextPrompt:
        """
        Ensure that the prompt meets the requirements set out by this config.
        If that is not possible, raise a `VLLMValidationError`.

        This method is run before tokenization occurs.
        The `prompt` mapping is updated in place and returned.
        """
        prompt["prompt"] = self._validate_text(tokenizer, prompt["prompt"])

        return prompt

    def _apply_padding(self, tokenizer: TokenizerLike | None, tokens: _S) -> _S:
        """
        Apply padding to a token sequence.

        Raises:
            ValueError: If padding is required but no tokenizer is available,
                or the input is not a list of token IDs.
        """
        pad_length = self.pad_prompt_tokens
        if pad_length is not None and pad_length < 0:
            pad_length = self.max_input_tokens

        if pad_length is None or pad_length <= len(tokens):
            return tokens

        if tokenizer is None:
            raise ValueError("Cannot pad tokens when `skip_tokenizer_init=True`")
        if not isinstance(tokens, list):
            raise ValueError("Cannot pad tokens for embedding inputs")

        return tokens + [tokenizer.pad_token_id] * (pad_length - len(tokens))

    def _apply_truncation(self, tokenizer: TokenizerLike | None, tokens: _S) -> _S:
        """Apply truncation to a token sequence."""
        max_length = self.truncate_prompt_tokens
        if max_length is not None and max_length < 0:
            max_length = self.max_input_tokens

        if max_length is None or max_length >= len(tokens):
            return tokens
        # `tokens[-0:]` would keep everything, so handle zero explicitly.
        if max_length == 0:
            return tokens[:0]

        # Left-side truncation keeps the tail; it is also the fallback when
        # the tokenizer does not expose `truncation_side`.
        if getattr(tokenizer, "truncation_side", "left") == "left":
            return tokens[-max_length:]

        return tokens[:max_length]

    def _apply_length_check(self, tokenizer: TokenizerLike | None, tokens: _S) -> _S:
        """
        Apply length checks to a token sequence.

        Raises:
            VLLMValidationError: If the sequence exceeds `max_input_tokens`.
        """
        max_input_tokens = self.max_input_tokens

        if max_input_tokens is not None and len(tokens) > max_input_tokens:
            raise VLLMValidationError(
                f"You passed {len(tokens)} input tokens and "
                f"requested {self.max_output_tokens} output tokens. "
                f"However, the model's context length is only "
                f"{self.max_total_tokens}, resulting in a maximum "
                f"input length of {max_input_tokens}. "
                f"Please reduce the length of the input messages.",
                parameter="input_tokens",
                value=len(tokens),
            )

        return tokens

    def _validate_tokens(self, tokenizer: TokenizerLike | None, tokens: _S) -> _S:
        """Apply all validators to a token sequence."""
        # Order matters: pad, then truncate, then check the final length.
        for validator in (
            self._apply_padding,
            self._apply_truncation,
            self._apply_length_check,
        ):
            tokens = validator(tokenizer, tokens)

        return tokens

    def apply_post_tokenization(
        self,
        tokenizer: TokenizerLike | None,
        prompt: TokensPrompt | EmbedsPrompt,
    ) -> TokensPrompt | EmbedsPrompt:
        """
        Ensure that the prompt meets the requirements set out by this config.
        If that is not possible, raise a `VLLMValidationError`.

        This method is run after tokenization occurs.
        The `prompt` mapping is updated in place and returned.
        """
        if "prompt_token_ids" in prompt:
            prompt["prompt_token_ids"] = self._validate_tokens(  # type: ignore[typeddict-unknown-key]
                tokenizer,
                prompt["prompt_token_ids"],  # type: ignore[typeddict-item]
            )
        if "prompt_embeds" in prompt:
            prompt["prompt_embeds"] = self._validate_tokens(  # type: ignore[typeddict-unknown-key]
                tokenizer,
                prompt["prompt_embeds"],  # type: ignore[typeddict-item]
            )

        return prompt
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
from typing import TYPE_CHECKING, Any, Protocol
from vllm.inputs import TextPrompt, TokensPrompt
from vllm.inputs import EmbedsPrompt, TextPrompt, TokensPrompt
from vllm.tokenizers import TokenizerLike
from vllm.utils.async_utils import AsyncMicrobatchTokenizer
from vllm.utils.collection_utils import is_list_of
from .embed_utils import safe_load_prompt_embeds
from .params import ChatParams, TokenizeParams
if TYPE_CHECKING:
from vllm.config import ModelConfig
......@@ -14,6 +20,9 @@ if TYPE_CHECKING:
class RendererLike(Protocol):
config: "ModelConfig"
_async_tokenizer: AsyncMicrobatchTokenizer
@classmethod
def from_config(
cls,
......@@ -33,16 +42,147 @@ class RendererLike(Protocol):
return tokenizer
def get_async_tokenizer(self) -> AsyncMicrobatchTokenizer:
    """Return the async tokenizer wrapper, creating it on first use."""
    # Built on demand because the offline LLM never needs the async path.
    if hasattr(self, "_async_tokenizer"):
        return self._async_tokenizer

    self._async_tokenizer = AsyncMicrobatchTokenizer(self.get_tokenizer())
    return self._async_tokenizer
# Step 1: Convert raw inputs to prompts
def render_completion(
    self,
    prompt_raw: str | list[int] | bytes,
) -> TextPrompt | TokensPrompt | EmbedsPrompt:
    """
    Wrap one raw completion input in the matching prompt type.

    - `str` becomes a `TextPrompt`
    - a list of ints becomes a `TokensPrompt`
    - `bytes` are loaded as prompt embeddings into an `EmbedsPrompt`

    Raises:
        TypeError: If the input is none of the above.
    """
    error_msg = "Each prompt must be a string or an array of tokens"

    if isinstance(prompt_raw, str):
        return TextPrompt(prompt=prompt_raw)

    if isinstance(prompt_raw, bytes):
        return EmbedsPrompt(
            prompt_embeds=safe_load_prompt_embeds(self.config, prompt_raw)
        )

    if isinstance(prompt_raw, list) and is_list_of(prompt_raw, int):
        return TokensPrompt(prompt_token_ids=prompt_raw)

    # Covers both lists with non-int content and unsupported types.
    raise TypeError(error_msg)
def render_completions(
    self,
    prompt_input: str | list[str] | list[int] | list[list[int]] | None = None,
    prompt_embeds: bytes | list[bytes] | None = None,
) -> list[TextPrompt | TokensPrompt | EmbedsPrompt]:
    """
    Expand completion-style inputs into a flat list of prompts.

    Embedding inputs are placed before text/token inputs in the result.

    Raises:
        ValueError: If no prompt is provided at all.
    """
    raw_items: list[str | list[int] | bytes] = []

    if prompt_embeds is not None:  # embeds take higher priority
        if isinstance(prompt_embeds, bytes):
            raw_items.append(prompt_embeds)
        else:
            raw_items.extend(prompt_embeds)

    if prompt_input is not None:
        # A string or a non-empty flat list of ints is ONE prompt;
        # anything else is treated as a batch of prompts.
        is_single_prompt = isinstance(prompt_input, str) or (
            len(prompt_input) > 0 and is_list_of(prompt_input, int)
        )
        if is_single_prompt:
            raw_items.append(prompt_input)  # type: ignore[arg-type]
        else:
            raw_items.extend(prompt_input)  # type: ignore[arg-type]

    if len(raw_items) == 0:
        raise ValueError("You must pass at least one prompt")

    rendered = []
    for raw_item in raw_items:
        rendered.append(self.render_completion(raw_item))
    return rendered
async def render_completions_async(
    self,
    prompt_input: str | list[str] | list[int] | list[list[int]] | None = None,
    prompt_embeds: bytes | list[bytes] | None = None,
) -> list[TextPrompt | TokensPrompt | EmbedsPrompt]:
    """Async counterpart of `render_completions`; delegates to the sync version."""
    rendered = self.render_completions(prompt_input, prompt_embeds)
    return rendered
def render_messages(
self,
messages: list["ChatCompletionMessageParam"],
**kwargs,
) -> tuple[list["ConversationMessage"], TextPrompt | TokensPrompt]:
params: ChatParams,
) -> tuple[list["ConversationMessage"], TextPrompt | TokensPrompt | EmbedsPrompt]:
raise NotImplementedError
async def render_messages_async(
self,
messages: list["ChatCompletionMessageParam"],
**kwargs,
) -> tuple[list["ConversationMessage"], TextPrompt | TokensPrompt]:
return self.render_messages(messages, **kwargs)
params: ChatParams,
) -> tuple[list["ConversationMessage"], TextPrompt | TokensPrompt | EmbedsPrompt]:
return self.render_messages(messages, params)
# Step 2: Tokenize prompts if necessary
def tokenize_prompt(
    self,
    prompt: TextPrompt | TokensPrompt | EmbedsPrompt,
    params: TokenizeParams,
) -> TokensPrompt | EmbedsPrompt:
    """
    Turn a rendered prompt into a tokenized (or embedded) prompt.

    Text-only prompts are pre-processed and encoded. When
    `params.needs_detokenization` is set and the prompt has no text,
    the token IDs are decoded back into text. The result always goes
    through `params.apply_post_tokenization`.

    Raises:
        RuntimeError: If detokenization is requested for an embeddings prompt.
    """
    already_processed = "prompt_token_ids" in prompt or "prompt_embeds" in prompt
    if not already_processed:
        prompt = params.apply_pre_tokenization(self.tokenizer, prompt)

        encoder = self.get_tokenizer()
        token_ids = encoder.encode(
            prompt["prompt"],
            **params.get_encode_kwargs(),
        )

        # Carry over the remaining fields (including the original text).
        prompt = TokensPrompt(prompt_token_ids=token_ids, **prompt)

    if params.needs_detokenization and "prompt" not in prompt:
        if "prompt_token_ids" not in prompt:
            raise RuntimeError("Cannot run detokenization on embeddings")

        decoder = self.get_tokenizer()
        prompt["prompt"] = decoder.decode(prompt["prompt_token_ids"])  # type: ignore[typeddict-item,typeddict-unknown-key]

    return params.apply_post_tokenization(self.tokenizer, prompt)  # type: ignore[arg-type]
def tokenize_prompts(
    self,
    prompts: list[TextPrompt | TokensPrompt | EmbedsPrompt],
    params: TokenizeParams,
) -> list[TokensPrompt | EmbedsPrompt]:
    """Tokenize each prompt in order with the same `params`."""
    tokenized = []
    for item in prompts:
        tokenized.append(self.tokenize_prompt(item, params))
    return tokenized
async def tokenize_prompt_async(
    self,
    prompt: TextPrompt | TokensPrompt | EmbedsPrompt,
    params: TokenizeParams,
) -> TokensPrompt | EmbedsPrompt:
    """
    Async variant of `tokenize_prompt` using the async tokenizer.

    Text-only prompts are pre-processed and encoded. When
    `params.needs_detokenization` is set and the prompt has no text,
    the token IDs are decoded back into text. The result always goes
    through `params.apply_post_tokenization`.

    Raises:
        RuntimeError: If detokenization is requested for an embeddings prompt.
    """
    already_processed = "prompt_token_ids" in prompt or "prompt_embeds" in prompt
    if not already_processed:
        prompt = params.apply_pre_tokenization(self.tokenizer, prompt)

        encoder = self.get_async_tokenizer()
        token_ids = await encoder.encode(
            prompt["prompt"],
            **params.get_encode_kwargs(),
        )

        # Carry over the remaining fields (including the original text).
        prompt = TokensPrompt(prompt_token_ids=token_ids, **prompt)

    if params.needs_detokenization and "prompt" not in prompt:
        if "prompt_token_ids" not in prompt:
            raise RuntimeError("Cannot run detokenization on embeddings")

        decoder = self.get_async_tokenizer()
        decoded_text = await decoder.decode(prompt["prompt_token_ids"])  # type: ignore[typeddict-item]
        prompt["prompt"] = decoded_text  # type: ignore[typeddict-unknown-key]

    return params.apply_post_tokenization(self.tokenizer, prompt)  # type: ignore[arg-type]
async def tokenize_prompts_async(
    self,
    prompts: list[TextPrompt | TokensPrompt | EmbedsPrompt],
    params: TokenizeParams,
) -> list[TokensPrompt | EmbedsPrompt]:
    """Tokenize all prompts concurrently; results keep the input order."""
    pending = [self.tokenize_prompt_async(item, params) for item in prompts]
    return await asyncio.gather(*pending)
......@@ -9,10 +9,11 @@ from vllm.entrypoints.chat_utils import (
parse_chat_messages,
parse_chat_messages_async,
)
from vllm.inputs import TextPrompt, TokensPrompt
from vllm.inputs import EmbedsPrompt, TextPrompt, TokensPrompt
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
from .params import ChatParams
from .protocol import RendererLike
logger = init_logger(__name__)
......@@ -45,8 +46,8 @@ class TerratorchRenderer(RendererLike):
def render_messages(
self,
messages: list[ChatCompletionMessageParam],
**kwargs,
) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt]:
params: ChatParams,
) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt | EmbedsPrompt]:
model_config = self.config
conversation, mm_data, mm_uuids = parse_chat_messages(
......@@ -55,7 +56,7 @@ class TerratorchRenderer(RendererLike):
content_format="string",
)
prompt = TokensPrompt(prompt_token_ids=[1])
prompt = self.render_completion([1]) # Dummy token IDs
if mm_data is not None:
prompt["multi_modal_data"] = mm_data
if mm_uuids is not None:
......@@ -66,8 +67,8 @@ class TerratorchRenderer(RendererLike):
async def render_messages_async(
self,
messages: list[ChatCompletionMessageParam],
**kwargs,
) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt]:
params: ChatParams,
) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt | EmbedsPrompt]:
model_config = self.config
conversation, mm_data, mm_uuids = await parse_chat_messages_async(
......@@ -76,7 +77,7 @@ class TerratorchRenderer(RendererLike):
content_format="string",
)
prompt = TokensPrompt(prompt_token_ids=[1]) # Dummy token IDs
prompt = self.render_completion([1]) # Dummy token IDs
if mm_data is not None:
prompt["multi_modal_data"] = mm_data
if mm_uuids is not None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment