Unverified Commit cd8b405b authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Refactor] Consolidate sequence normalization and enc-dec parsing (#33928)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent 4707f7eb
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Generic, Literal, TypeAlias, cast
from typing import TYPE_CHECKING, Any, Literal, TypeAlias
import torch
from typing_extensions import NotRequired, TypedDict, TypeVar
from typing_extensions import NotRequired, TypedDict
from vllm.sampling_params import SamplingParams
......@@ -23,7 +22,13 @@ else:
MultiModalUUIDDict = object
class _CommonKeys(TypedDict):
# Inputs to LLM API
class _PromptOptions(TypedDict):
"""
Additional options available to all
[`SingletonPrompt`][vllm.inputs.data.SingletonPrompt].
"""
multi_modal_data: NotRequired[MultiModalDataDict | None]
"""
Optional multi-modal data to pass to the model,
......@@ -53,14 +58,14 @@ class _CommonKeys(TypedDict):
"""
class TextPrompt(_CommonKeys):
class TextPrompt(_PromptOptions):
"""Schema for a text prompt."""
prompt: str
"""The input text to be tokenized before passing to the model."""
class TokensPrompt(_CommonKeys):
class TokensPrompt(_PromptOptions):
"""Schema for a tokenized prompt."""
prompt_token_ids: list[int]
......@@ -73,7 +78,7 @@ class TokensPrompt(_CommonKeys):
"""A list of token type IDs to pass to the cross encoder model."""
class EmbedsPrompt(_CommonKeys):
class EmbedsPrompt(_PromptOptions):
"""Schema for a prompt provided via token embeddings."""
prompt_embeds: torch.Tensor
......@@ -83,93 +88,113 @@ class EmbedsPrompt(_CommonKeys):
"""The prompt text corresponding to the token embeddings, if available."""
class DataPrompt(_CommonKeys):
"""Represents generic inputs handled by IO processor plugins."""
DecoderOnlyPrompt: TypeAlias = (
str | TextPrompt | list[int] | TokensPrompt | EmbedsPrompt
)
"""
Schema of a prompt for a decoder-only model:
data: Any
"""The input data"""
- A text prompt (string or [`TextPrompt`][vllm.inputs.data.TextPrompt])
- A tokenized prompt (list of token IDs, or
[`TokensPrompt`][vllm.inputs.data.TokensPrompt])
- An embeddings prompt ([`EmbedsPrompt`][vllm.inputs.data.EmbedsPrompt])
data_format: str
"""The input data format"""
For encoder-decoder models, passing a singleton prompt is shorthand for passing
`ExplicitEncoderDecoderPrompt(encoder_prompt=prompt, decoder_prompt=None)`.
"""
SingletonPrompt: TypeAlias = str | TextPrompt | TokensPrompt | EmbedsPrompt
EncoderPrompt: TypeAlias = str | TextPrompt | list[int] | TokensPrompt
"""
Schema of a prompt for the encoder part of an encoder-decoder model:
- A text prompt (string or [`TextPrompt`][vllm.inputs.data.TextPrompt])
- A tokenized prompt (list of token IDs, or
[`TokensPrompt`][vllm.inputs.data.TokensPrompt])
"""
Set of possible schemas for a single prompt:
- A text prompt ([`str`][] or [`TextPrompt`][vllm.inputs.data.TextPrompt])
- A tokenized prompt ([`TokensPrompt`][vllm.inputs.data.TokensPrompt])
- An embeddings prompt ([`EmbedsPrompt`][vllm.inputs.data.EmbedsPrompt])
Note that "singleton" is as opposed to a data structure
which encapsulates multiple prompts, i.e. of the sort
which may be utilized for encoder/decoder models when
the user desires to express both the encoder & decoder
prompts explicitly, i.e.
[`ExplicitEncoderDecoderPrompt`][vllm.inputs.data.ExplicitEncoderDecoderPrompt]
A prompt of type [`SingletonPrompt`][vllm.inputs.data.SingletonPrompt] may be
employed as (1) input to a decoder-only model, (2) input to
the encoder of an encoder/decoder model, in the scenario
where the decoder-prompt is not specified explicitly, or
(3) as a member of a larger data structure encapsulating
more than one prompt, i.e.
[`ExplicitEncoderDecoderPrompt`][vllm.inputs.data.ExplicitEncoderDecoderPrompt]
DecoderPrompt: TypeAlias = str | TextPrompt | list[int] | TokensPrompt
"""
Schema of a prompt for the decoder part of an encoder-decoder model:
- A text prompt (string or [`TextPrompt`][vllm.inputs.data.TextPrompt])
- A tokenized prompt (list of token IDs, or
[`TokensPrompt`][vllm.inputs.data.TokensPrompt])
_T1_co = TypeVar(
"_T1_co", bound=SingletonPrompt, default=SingletonPrompt, covariant=True
)
_T2_co = TypeVar(
"_T2_co", bound=SingletonPrompt, default=SingletonPrompt, covariant=True
)
Note:
Multi-modal inputs are not supported for decoder prompts.
"""
# TODO: Make fields ReadOnly once mypy supports it
class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]):
class ExplicitEncoderDecoderPrompt(TypedDict):
"""
Represents an encoder/decoder model input prompt,
comprising an explicit encoder prompt and a decoder prompt.
The encoder and decoder prompts, respectively, may be formatted
according to any of the
[`SingletonPrompt`][vllm.inputs.data.SingletonPrompt] schemas,
and are not required to have the same schema.
Only the encoder prompt may have multi-modal data. mm_processor_kwargs
should be at the top-level, and should not be set in the encoder/decoder
prompts, since they are agnostic to the encoder/decoder.
Note that an
[`ExplicitEncoderDecoderPrompt`][vllm.inputs.data.ExplicitEncoderDecoderPrompt]
may not be used as an input to a decoder-only model,
and that the `encoder_prompt` and `decoder_prompt`
fields of this data structure themselves must be
[`SingletonPrompt`][vllm.inputs.data.SingletonPrompt] instances.
Schema for a pair of encoder and decoder singleton prompts.
Note:
This schema is not valid for decoder-only models.
"""
encoder_prompt: _T1_co
encoder_prompt: EncoderPrompt
"""The prompt for the encoder part of the model."""
decoder_prompt: _T2_co | None
decoder_prompt: DecoderPrompt | None
"""
The prompt for the decoder part of the model.
mm_processor_kwargs: NotRequired[dict[str, Any]]
Passing `None` will cause the prompt to be inferred automatically.
"""
PromptType: TypeAlias = SingletonPrompt | ExplicitEncoderDecoderPrompt[Any, Any]
EncoderDecoderPrompt: TypeAlias = EncoderPrompt | ExplicitEncoderDecoderPrompt
"""
Set of possible schemas for an LLM input, including
both decoder-only and encoder/decoder input types:
Schema for a prompt for an encoder-decoder model.
- A text prompt ([`str`][] or [`TextPrompt`][vllm.inputs.data.TextPrompt])
- A tokenized prompt ([`TokensPrompt`][vllm.inputs.data.TokensPrompt])
- An embeddings prompt ([`EmbedsPrompt`][vllm.inputs.data.EmbedsPrompt])
- A single data structure containing both an encoder and a decoder prompt
([`ExplicitEncoderDecoderPrompt`][vllm.inputs.data.ExplicitEncoderDecoderPrompt])
You can pass a singleton encoder prompt, in which case the decoder prompt is
considered to be `None` (i.e., infer automatically).
"""
SingletonPrompt: TypeAlias = DecoderOnlyPrompt | EncoderPrompt | DecoderPrompt
"""
Schema for a single prompt. This is as opposed to a data structure
which encapsulates multiple prompts, such as
[`ExplicitEncoderDecoderPrompt`][vllm.inputs.data.ExplicitEncoderDecoderPrompt].
"""
class TokenInputs(TypedDict):
PromptType: TypeAlias = DecoderOnlyPrompt | EncoderDecoderPrompt
"""
Schema for any prompt, regardless of model type.
This is the input format accepted by most [`LLM`][vllm.entrypoints.llm.LLM] APIs.
"""
class DataPrompt(_PromptOptions):
"""
Represents generic inputs that are converted to
[`PromptType`][vllm.inputs.data.PromptType] by IO processor plugins.
"""
data: Any
"""The input data."""
data_format: str
"""The input data format."""
# Outputs of processor
class _InputOptions(TypedDict):
"""
Additional options available to all input types.
"""
cache_salt: NotRequired[str]
"""Optional cache salt to be used for prefix caching."""
class TokenInputs(_InputOptions):
"""Represents token-based inputs."""
type: Literal["token"]
......@@ -178,11 +203,6 @@ class TokenInputs(TypedDict):
prompt_token_ids: list[int]
"""The token IDs of the prompt."""
cache_salt: NotRequired[str]
"""
Optional cache salt to be used for prefix caching.
"""
def token_inputs(
prompt_token_ids: list[int],
......@@ -198,7 +218,7 @@ def token_inputs(
return inputs
class EmbedsInputs(TypedDict):
class EmbedsInputs(_InputOptions):
"""Represents embeddings-based inputs."""
type: Literal["embeds"]
......@@ -207,11 +227,6 @@ class EmbedsInputs(TypedDict):
prompt_embeds: torch.Tensor
"""The embeddings of the prompt."""
cache_salt: NotRequired[str]
"""
Optional cache salt to be used for prefix caching.
"""
def embeds_inputs(
prompt_embeds: torch.Tensor,
......@@ -229,96 +244,60 @@ def embeds_inputs(
DecoderOnlyInputs: TypeAlias = TokenInputs | EmbedsInputs | MultiModalInputs
"""
The inputs in [`LLMEngine`][vllm.engine.llm_engine.LLMEngine] before they are
passed to the model executor.
This specifies the data required for decoder-only models.
A processed prompt from
[`InputPreprocessor`][vllm.inputs.preprocess.InputPreprocessor]
which can be passed to
[`InputProcessor`][vllm.v1.engine.input_processor.InputProcessor]
for decoder-only models.
"""
EncoderInputs: TypeAlias = TokenInputs | MultiModalEncDecInputs
"""
A processed encoder prompt from
[`InputPreprocessor`][vllm.inputs.preprocess.InputPreprocessor]
which can be passed to
[`InputProcessor`][vllm.v1.engine.input_processor.InputProcessor]
for encoder-decoder models.
"""
DecoderInputs: TypeAlias = TokenInputs | MultiModalInputs
"""
A processed decoder prompt from
[`InputPreprocessor`][vllm.inputs.preprocess.InputPreprocessor]
which can be passed to
[`InputProcessor`][vllm.v1.engine.input_processor.InputProcessor]
for encoder-decoder models.
"""
class EncoderDecoderInputs(TypedDict):
"""
The inputs in [`LLMEngine`][vllm.engine.llm_engine.LLMEngine] before they
are passed to the model executor.
This specifies the required data for encoder-decoder models.
A processed pair of encoder and decoder singleton prompts.
[`InputPreprocessor`][vllm.inputs.preprocess.InputPreprocessor]
which can be passed to
[`InputProcessor`][vllm.v1.engine.input_processor.InputProcessor]
for encoder-decoder models.
"""
encoder: TokenInputs | MultiModalEncDecInputs
encoder: EncoderInputs
"""The inputs for the encoder portion."""
decoder: TokenInputs | MultiModalInputs
decoder: DecoderInputs
"""The inputs for the decoder portion."""
SingletonInputs: TypeAlias = TokenInputs | EmbedsInputs | MultiModalInputs
"""
A processed [`SingletonPrompt`][vllm.inputs.data.SingletonPrompt] which can be
passed to [`Sequence`][collections.abc.Sequence].
"""
ProcessorInputs: TypeAlias = DecoderOnlyInputs | EncoderDecoderInputs
"""
The outputs from [`vllm.inputs.preprocess.InputPreprocessor`][].
A processed prompt from
[`InputPreprocessor`][vllm.inputs.preprocess.InputPreprocessor]
which can be passed to
[`InputProcessor`][vllm.v1.engine.input_processor.InputProcessor].
"""
_T1 = TypeVar("_T1", bound=SingletonPrompt, default=SingletonPrompt)
_T2 = TypeVar("_T2", bound=SingletonPrompt, default=SingletonPrompt)
def build_explicit_enc_dec_prompt(
    encoder_prompt: _T1,
    decoder_prompt: _T2 | None,
    mm_processor_kwargs: dict[str, Any] | None = None,
) -> ExplicitEncoderDecoderPrompt[_T1, _T2]:
    """
    Bundle an encoder prompt and a decoder prompt into one
    `ExplicitEncoderDecoderPrompt`.

    Args:
        encoder_prompt: Value stored under the `encoder_prompt` key.
        decoder_prompt: Value stored under the `decoder_prompt` key;
            may be `None`.
        mm_processor_kwargs: Value stored under the `mm_processor_kwargs`
            key. A fresh empty dict is used when this is `None`.

    Returns:
        The assembled `ExplicitEncoderDecoderPrompt`.
    """
    resolved_kwargs: dict[str, Any] = (
        {} if mm_processor_kwargs is None else mm_processor_kwargs
    )

    return ExplicitEncoderDecoderPrompt(
        encoder_prompt=encoder_prompt,
        decoder_prompt=decoder_prompt,
        mm_processor_kwargs=resolved_kwargs,
    )
def zip_enc_dec_prompts(
    enc_prompts: Iterable[_T1],
    dec_prompts: Iterable[_T2 | None],
    mm_processor_kwargs: Iterable[dict[str, Any]] | dict[str, Any] | None = None,
) -> list[ExplicitEncoderDecoderPrompt[_T1, _T2]]:
    """
    Pair up encoder and decoder prompts into a list of
    [`ExplicitEncoderDecoderPrompt`][vllm.inputs.data.ExplicitEncoderDecoderPrompt]
    instances.

    Args:
        enc_prompts: Encoder prompts, consumed in order.
        dec_prompts: Decoder prompts (entries may be `None`), zipped with
            `enc_prompts`.
        mm_processor_kwargs: Either a single dict, which is attached to
            every resulting prompt (the same object each time), or an
            iterable of dicts that is zipped with the prompts. `None` is
            treated as one empty dict shared by all prompts.

    Returns:
        One prompt per zipped element; zipping stops at the shortest input.
    """
    kwargs_or_iter: Iterable[dict[str, Any]] | dict[str, Any] = (
        cast(dict[str, Any], {}) if mm_processor_kwargs is None else mm_processor_kwargs
    )

    # Per-prompt kwargs: zip all three iterables together.
    if not isinstance(kwargs_or_iter, dict):
        return [
            build_explicit_enc_dec_prompt(enc, dec, per_prompt_kwargs)
            for enc, dec, per_prompt_kwargs in zip(
                enc_prompts, dec_prompts, kwargs_or_iter
            )
        ]

    # A single dict: every prompt receives this same dictionary.
    shared_kwargs = cast(dict[str, Any], kwargs_or_iter)
    return [
        build_explicit_enc_dec_prompt(enc, dec, shared_kwargs)
        for enc, dec in zip(enc_prompts, dec_prompts)
    ]
def to_enc_dec_tuple_list(
    enc_dec_prompts: Iterable[ExplicitEncoderDecoderPrompt[_T1, _T2]],
) -> list[tuple[_T1, _T2 | None]]:
    """
    Flatten explicit encoder/decoder prompts into
    `(encoder_prompt, decoder_prompt)` tuples, preserving input order.
    """
    pairs = []
    for explicit_prompt in enc_dec_prompts:
        encoder_part = explicit_prompt["encoder_prompt"]
        decoder_part = explicit_prompt["decoder_prompt"]
        pairs.append((encoder_part, decoder_part))
    return pairs
SingletonInputs: TypeAlias = DecoderOnlyInputs | MultiModalEncDecInputs
@dataclass
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import TYPE_CHECKING, Literal, NamedTuple, TypeAlias, TypedDict
from typing_extensions import TypeIs
from vllm.utils import length_from_prompt_token_ids_or_embeds
from .data import (
EmbedsPrompt,
ExplicitEncoderDecoderPrompt,
ProcessorInputs,
PromptType,
SingletonInputs,
SingletonPrompt,
TextPrompt,
TokensPrompt,
)
if TYPE_CHECKING:
import torch
class ParsedStrPrompt(TypedDict):
    """Tagged wrapper that `parse_singleton_prompt` builds for a plain `str` prompt."""

    # Discriminator tag; always "str" for this variant.
    type: Literal["str"]
    # The prompt string exactly as it was passed in.
    content: str
class ParsedTextPrompt(TypedDict):
    """Tagged wrapper that `parse_singleton_prompt` builds for a dict with a `prompt` key."""

    # Discriminator tag; always "text" for this variant.
    type: Literal["text"]
    # The prompt dictionary exactly as it was passed in.
    content: TextPrompt
class ParsedTokensPrompt(TypedDict):
    """Tagged wrapper that `parse_singleton_prompt` builds for a dict with a `prompt_token_ids` key."""

    # Discriminator tag; always "tokens" for this variant.
    type: Literal["tokens"]
    # The prompt dictionary exactly as it was passed in.
    content: TokensPrompt
class ParsedEmbedsPrompt(TypedDict):
    """Tagged wrapper that `parse_singleton_prompt` builds for a dict with a `prompt_embeds` key."""

    # Discriminator tag; always "embeds" for this variant.
    type: Literal["embeds"]
    # The prompt dictionary exactly as it was passed in.
    content: EmbedsPrompt
# Union of every tagged wrapper that `parse_singleton_prompt` can return;
# the `type` field of each member identifies which variant it is.
ParsedSingletonPrompt: TypeAlias = (
    ParsedStrPrompt | ParsedTextPrompt | ParsedTokensPrompt | ParsedEmbedsPrompt
)
def parse_singleton_prompt(prompt: SingletonPrompt) -> ParsedSingletonPrompt:
    """
    Classify a singleton prompt and wrap it in a tagged dictionary.

    A `str` becomes a `ParsedStrPrompt`. A `dict` is classified by the first
    key found, checked in this order: `prompt_embeds`, `prompt_token_ids`,
    `prompt`.

    Raises:
        TypeError: If `prompt` is neither a string nor a dict containing
            one of the recognised keys.
    """
    unsupported_message = (
        "inputs must be a string, TextPrompt, TokensPrompt, or EmbedsPrompt"
    )

    if isinstance(prompt, str):
        return ParsedStrPrompt(type="str", content=prompt)

    if isinstance(prompt, dict):
        # mypy cannot narrow the TypedDict union from these key checks
        # (pyright can), hence the ignores below.
        if "prompt_embeds" in prompt:
            return ParsedEmbedsPrompt(type="embeds", content=prompt)  # type: ignore[typeddict-item]
        if "prompt_token_ids" in prompt:
            return ParsedTokensPrompt(type="tokens", content=prompt)  # type: ignore[typeddict-item]
        if "prompt" in prompt:
            return ParsedTextPrompt(type="text", content=prompt)

    raise TypeError(unsupported_message)
def is_explicit_encoder_decoder_prompt(
    prompt: PromptType,
) -> TypeIs[ExplicitEncoderDecoderPrompt]:
    """
    Report whether `prompt` is a dict that carries an `encoder_prompt` key.
    """
    if not isinstance(prompt, dict):
        return False
    return "encoder_prompt" in prompt
def split_enc_dec_prompt(
    prompt: PromptType,
) -> tuple[SingletonPrompt, SingletonPrompt | None]:
    """
    Split a prompt into `(encoder_prompt, decoder_prompt)`.

    A prompt containing both the `encoder_prompt` and `decoder_prompt` keys
    is unpacked into its two parts. Any other prompt (including every
    string) is returned as the encoder prompt with `None` as the decoder
    prompt.
    """
    has_both_parts = (
        not isinstance(prompt, str)
        and "encoder_prompt" in prompt
        and "decoder_prompt" in prompt
    )
    if not has_both_parts:
        return prompt, None

    # NOTE: This passes pyright but not mypy
    encoder_part = prompt["encoder_prompt"]  # type: ignore[typeddict-item]
    decoder_part = prompt["decoder_prompt"]  # type: ignore[typeddict-item]
    return encoder_part, decoder_part
from .data import ProcessorInputs, SingletonInputs
def split_enc_dec_inputs(
......@@ -96,30 +15,3 @@ def split_enc_dec_inputs(
)
return None, inputs
class PromptComponents(NamedTuple):
    """The raw pieces a prompt may carry; a piece that is absent is `None`."""

    text: str | None = None
    token_ids: list[int] | None = None
    embeds: "torch.Tensor | None" = None


def get_prompt_components(prompt: PromptType) -> PromptComponents:
    """
    Extract the text, token IDs and embeddings carried by a prompt.

    Args:
        prompt: A string, a singleton prompt dictionary, or a dictionary
            with an `encoder_prompt` key.

    Returns:
        - For a string: components with only `text` set.
        - For a dictionary with a non-`None` `encoder_prompt`: the
          components of that encoder prompt (the decoder prompt is not
          inspected).
        - Otherwise: the values found under `prompt`, `prompt_token_ids`
          and `prompt_embeds`, each `None` when the key is missing.
    """
    if isinstance(prompt, str):
        return PromptComponents(text=prompt)

    # Compare against None instead of relying on truthiness: an empty
    # encoder prompt such as "" is falsy but must still be unwrapped,
    # otherwise its text would be reported as None rather than "".
    if (encoder_prompt := prompt.get("encoder_prompt")) is not None:
        return get_prompt_components(encoder_prompt)  # type: ignore[arg-type]

    return PromptComponents(
        text=prompt.get("prompt"),  # type: ignore[arg-type]
        token_ids=prompt.get("prompt_token_ids"),  # type: ignore[arg-type]
        embeds=prompt.get("prompt_embeds"),
    )
def get_prompt_len(prompt: TokensPrompt | EmbedsPrompt):
    """
    Return the length of a tokenized or embeddings prompt.

    The `prompt_token_ids` and `prompt_embeds` entries (each `None` when
    missing) are handed to `length_from_prompt_token_ids_or_embeds`, whose
    result is returned unchanged.
    """
    token_ids = prompt.get("prompt_token_ids")
    embeds = prompt.get("prompt_embeds")

    return length_from_prompt_token_ids_or_embeds(
        token_ids,  # type: ignore[arg-type]
        embeds,  # type: ignore[arg-type]
    )
......@@ -2,43 +2,51 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Mapping
from typing import Any
from typing import Any, overload
from typing_extensions import assert_never
from vllm.config import ModelConfig, ObservabilityConfig
from vllm.inputs.parse import split_enc_dec_prompt
from vllm.logger import init_logger
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.multimodal.cache import BaseMultiModalProcessorCache
from vllm.multimodal.inputs import (
MultiModalDataDict,
MultiModalEncDecInputs,
MultiModalInputs,
MultiModalUUIDDict,
)
from vllm.multimodal.processing import BaseMultiModalProcessor
from vllm.renderers import renderer_from_config
from vllm.renderers.inputs import (
DecoderDictPrompt,
DecoderOnlyDictPrompt,
DictPrompt,
EncoderDecoderDictPrompt,
EncoderDictPrompt,
SingletonDictPrompt,
TokPrompt,
)
from vllm.renderers.inputs.preprocess import parse_dec_only_prompt, parse_enc_dec_prompt
from vllm.tokenizers import TokenizerLike
from vllm.utils.jsontree import json_iter_leaves
from vllm.v1.metrics.stats import MultiModalCacheStats
from .data import (
DecoderInputs,
DecoderOnlyInputs,
EmbedsInputs,
EmbedsPrompt,
EncoderDecoderInputs,
EncoderInputs,
ProcessorInputs,
PromptType,
SingletonInputs,
SingletonPrompt,
TextPrompt,
TokenInputs,
TokensPrompt,
embeds_inputs,
token_inputs,
)
from .parse import is_explicit_encoder_decoder_prompt, parse_singleton_prompt
logger = init_logger(__name__)
......@@ -328,9 +336,36 @@ class InputPreprocessor:
return inputs
@overload
def _prompt_to_llm_inputs(
self,
prompt: SingletonPrompt,
prompt: EncoderDictPrompt,
tokenization_kwargs: dict[str, Any] | None = None,
*,
mm_uuids: MultiModalUUIDDict | None = None,
) -> EncoderInputs: ...
@overload
def _prompt_to_llm_inputs( # type: ignore[misc]
self,
prompt: DecoderDictPrompt,
tokenization_kwargs: dict[str, Any] | None = None,
*,
mm_uuids: MultiModalUUIDDict | None = None,
) -> DecoderInputs: ...
@overload
def _prompt_to_llm_inputs( # type: ignore[misc]
self,
prompt: DecoderOnlyDictPrompt,
tokenization_kwargs: dict[str, Any] | None = None,
*,
mm_uuids: MultiModalUUIDDict | None = None,
) -> DecoderOnlyInputs: ...
def _prompt_to_llm_inputs(
self,
prompt: SingletonDictPrompt,
tokenization_kwargs: dict[str, Any] | None = None,
*,
mm_uuids: MultiModalUUIDDict | None = None,
......@@ -346,34 +381,25 @@ class InputPreprocessor:
* [`SingletonInputs`][vllm.inputs.data.SingletonInputs] instance
"""
parsed = parse_singleton_prompt(prompt)
if "prompt_embeds" in prompt:
return self._process_embeds(prompt) # type: ignore[arg-type]
if parsed["type"] == "embeds":
return self._process_embeds(parsed["content"])
if parsed["type"] == "tokens":
if "prompt_token_ids" in prompt:
return self._process_tokens(
parsed["content"],
prompt, # type: ignore[arg-type]
mm_uuids=mm_uuids,
)
if parsed["type"] == "text":
return self._process_text(
parsed["content"],
tokenization_kwargs=tokenization_kwargs,
mm_uuids=mm_uuids,
)
if parsed["type"] == "str":
if "prompt" in prompt:
return self._process_text(
TextPrompt(prompt=parsed["content"]),
prompt, # type: ignore[arg-type]
tokenization_kwargs=tokenization_kwargs,
mm_uuids=mm_uuids,
)
assert_never(parsed)
assert_never(prompt) # type: ignore[arg-type]
def _validate_enc_inputs(
self,
inputs: SingletonInputs,
) -> TokenInputs | MultiModalEncDecInputs:
def _validate_enc_inputs(self, inputs: SingletonInputs) -> EncoderInputs:
if inputs["type"] == "embeds":
raise ValueError(
"Embedding inputs are not supported for encoder-decoder models"
......@@ -387,10 +413,7 @@ class InputPreprocessor:
return inputs # type: ignore[return-value]
def _validate_dec_inputs(
self,
inputs: SingletonInputs,
) -> TokenInputs | MultiModalInputs:
def _validate_dec_inputs(self, inputs: SingletonInputs) -> DecoderInputs:
if inputs["type"] == "embeds":
raise ValueError(
"Embedding inputs are not supported for encoder-decoder models"
......@@ -403,14 +426,15 @@ class InputPreprocessor:
encoder_inputs: SingletonInputs,
decoder_inputs: SingletonInputs | None = None,
) -> EncoderDecoderInputs:
if decoder_inputs is None:
decoder_inputs = encoder_inputs
enc_inputs = self._validate_enc_inputs(encoder_inputs)
dec_inputs = self._validate_dec_inputs(decoder_inputs)
enc_inputs_new: TokenInputs | MultiModalEncDecInputs
dec_inputs_new: TokenInputs | MultiModalInputs
if decoder_inputs is None:
dec_inputs: DecoderInputs = enc_inputs # type: ignore[assignment]
else:
dec_inputs = self._validate_dec_inputs(decoder_inputs)
enc_inputs_new: EncoderInputs
dec_inputs_new: DecoderInputs
if enc_inputs["type"] == "multimodal":
enc_inputs_new = token_inputs(enc_inputs["encoder_prompt_token_ids"])
......@@ -437,7 +461,7 @@ class InputPreprocessor:
def _process_encoder_decoder_prompt(
self,
prompt: PromptType,
prompt: EncoderDecoderDictPrompt,
tokenization_kwargs: dict[str, Any] | None = None,
*,
mm_uuids: MultiModalUUIDDict | None = None,
......@@ -448,24 +472,6 @@ class InputPreprocessor:
[`EncoderDecoderInputs`][vllm.inputs.data.EncoderDecoderInputs]
instance.
There are two types of input prompts:
singleton prompts which carry only the
encoder prompt, and explicit encoder/decoder
prompts which carry both the encoder and the
decoder prompts as member variables.
This function handles the following scenarios:
* Singleton encoder prompt: extract encoder prompt
token ids & infer default decoder prompt token ids
* Explicit encoder/decoder prompt: extract encoder
and decoder prompt token ids
Note that for Explicit encoder/decoder prompts,
each sub-prompt (encoder or decoder prompt) can
have any possible singleton type; thus this
method relies on helper functions to obtain
token ids for the sub-prompts.
Arguments:
* prompt: an input prompt
......@@ -475,7 +481,8 @@ class InputPreprocessor:
* [`EncoderDecoderInputs`][vllm.inputs.data.EncoderDecoderInputs]
instance
"""
encoder_prompt, decoder_prompt = split_enc_dec_prompt(prompt)
encoder_prompt = prompt["encoder_prompt"]
decoder_prompt = prompt["decoder_prompt"]
return self._build_enc_dec_inputs(
encoder_inputs=self._prompt_to_llm_inputs(
......@@ -495,7 +502,7 @@ class InputPreprocessor:
def _process_decoder_only_prompt(
self,
prompt: SingletonPrompt,
prompt: DecoderOnlyDictPrompt,
tokenization_kwargs: dict[str, Any] | None = None,
*,
mm_uuids: MultiModalUUIDDict | None = None,
......@@ -521,7 +528,7 @@ class InputPreprocessor:
def _preprocess(
self,
prompt: PromptType,
prompt: PromptType | DictPrompt | TokPrompt,
tokenization_kwargs: dict[str, Any] | None = None,
*,
mm_uuids: MultiModalUUIDDict | None = None,
......@@ -530,25 +537,20 @@ class InputPreprocessor:
# Encoder-decoder model requires special mapping of
# input prompts to encoder & decoder.
return self._process_encoder_decoder_prompt(
prompt,
parse_enc_dec_prompt(prompt),
tokenization_kwargs,
mm_uuids=mm_uuids,
)
if is_explicit_encoder_decoder_prompt(prompt):
raise ValueError(
"Cannot pass encoder-decoder prompt to decoder-only models"
)
return self._process_decoder_only_prompt(
prompt,
parse_dec_only_prompt(prompt),
tokenization_kwargs=tokenization_kwargs,
mm_uuids=mm_uuids,
)
def preprocess(
self,
prompt: PromptType,
prompt: PromptType | DictPrompt | TokPrompt,
tokenization_kwargs: dict[str, Any] | None = None,
*,
mm_uuids: MultiModalUUIDDict | None = None,
......
......@@ -20,7 +20,7 @@ from typing import (
import numpy as np
from PIL.Image import Image
from typing_extensions import NotRequired, TypeVar
from typing_extensions import TypeVar
from vllm.utils.collection_utils import is_list_of
from vllm.utils.import_utils import LazyLoader
......@@ -32,9 +32,13 @@ if TYPE_CHECKING:
import torch
import torch.types
from transformers.feature_extraction_utils import BatchFeature
from vllm.inputs.data import _InputOptions
else:
torch = LazyLoader("torch", globals(), "torch")
_InputOptions = dict
_T = TypeVar("_T")
HfImageItem: TypeAlias = Union["Image", np.ndarray, "torch.Tensor"]
......@@ -1059,7 +1063,7 @@ A dictionary containing per-item placeholder ranges for each modality.
"""
class MultiModalInputs(TypedDict):
class MultiModalInputs(_InputOptions):
"""
Represents the outputs of
[`BaseMultiModalProcessor`][vllm.multimodal.processing.BaseMultiModalProcessor],
......@@ -1084,11 +1088,6 @@ class MultiModalInputs(TypedDict):
`prompt_token_ids`.
"""
cache_salt: NotRequired[str]
"""
Optional cache salt to be used for prefix caching.
"""
class MultiModalEncDecInputs(MultiModalInputs):
"""
......
......@@ -19,6 +19,7 @@ if TYPE_CHECKING:
from vllm.config import VllmConfig
from vllm.inputs import ProcessorInputs, PromptType
from vllm.pooling_params import PoolingParams
from vllm.renderers.inputs import DictPrompt, TokPrompt
from vllm.sampling_params import SamplingParams
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.v1.attention.selector import AttentionSelectorConfig
......@@ -565,7 +566,7 @@ class Platform:
@classmethod
def validate_request(
cls,
prompt: "PromptType",
prompt: "PromptType | DictPrompt | TokPrompt",
params: "SamplingParams | PoolingParams",
processed_inputs: "ProcessorInputs",
) -> None:
......
......@@ -9,11 +9,12 @@ from vllm.entrypoints.chat_utils import (
parse_chat_messages,
parse_chat_messages_async,
)
from vllm.inputs import EmbedsPrompt, TextPrompt, TokensPrompt
from vllm.logger import init_logger
from vllm.tokenizers import cached_get_tokenizer
from vllm.tokenizers.deepseek_v32 import DeepseekV32Tokenizer
from .inputs import DictPrompt
from .inputs.preprocess import parse_dec_only_prompt
from .params import ChatParams
from .protocol import BaseRenderer
......@@ -61,7 +62,7 @@ class DeepseekV32Renderer(BaseRenderer):
self,
messages: list[ChatCompletionMessageParam],
params: ChatParams,
) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt | EmbedsPrompt]:
) -> tuple[list[ConversationMessage], DictPrompt]:
tokenizer = self.get_tokenizer()
conversation, mm_data, mm_uuids = parse_chat_messages(
messages,
......@@ -75,7 +76,7 @@ class DeepseekV32Renderer(BaseRenderer):
**params.get_apply_chat_template_kwargs(),
)
prompt = self.render_completion(prompt_raw)
prompt = parse_dec_only_prompt(prompt_raw)
if mm_data is not None:
prompt["multi_modal_data"] = mm_data
if mm_uuids is not None:
......@@ -87,7 +88,7 @@ class DeepseekV32Renderer(BaseRenderer):
self,
messages: list[ChatCompletionMessageParam],
params: ChatParams,
) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt | EmbedsPrompt]:
) -> tuple[list[ConversationMessage], DictPrompt]:
tokenizer = self.get_tokenizer()
conversation, mm_data, mm_uuids = await parse_chat_messages_async(
messages,
......@@ -101,7 +102,7 @@ class DeepseekV32Renderer(BaseRenderer):
**params.get_apply_chat_template_kwargs(),
)
prompt = self.render_completion(prompt_raw)
prompt = parse_dec_only_prompt(prompt_raw)
if mm_data is not None:
prompt["multi_modal_data"] = mm_data
if mm_uuids is not None:
......
......@@ -9,11 +9,12 @@ from vllm.entrypoints.chat_utils import (
parse_chat_messages,
parse_chat_messages_async,
)
from vllm.inputs import EmbedsPrompt, TextPrompt, TokensPrompt
from vllm.logger import init_logger
from vllm.tokenizers import cached_get_tokenizer
from vllm.tokenizers.grok2 import Grok2Tokenizer
from .inputs import DictPrompt
from .inputs.preprocess import parse_dec_only_prompt
from .params import ChatParams
from .protocol import BaseRenderer
......@@ -61,7 +62,7 @@ class Grok2Renderer(BaseRenderer):
self,
messages: list[ChatCompletionMessageParam],
params: ChatParams,
) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt | EmbedsPrompt]:
) -> tuple[list[ConversationMessage], DictPrompt]:
tokenizer = self.get_tokenizer()
conversation, mm_data, mm_uuids = parse_chat_messages(
messages,
......@@ -75,7 +76,7 @@ class Grok2Renderer(BaseRenderer):
**params.get_apply_chat_template_kwargs(),
)
prompt = self.render_completion(prompt_raw)
prompt = parse_dec_only_prompt(prompt_raw)
if mm_data is not None:
prompt["multi_modal_data"] = mm_data
if mm_uuids is not None:
......@@ -87,7 +88,7 @@ class Grok2Renderer(BaseRenderer):
self,
messages: list[ChatCompletionMessageParam],
params: ChatParams,
) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt | EmbedsPrompt]:
) -> tuple[list[ConversationMessage], DictPrompt]:
tokenizer = self.get_tokenizer()
conversation, mm_data, mm_uuids = await parse_chat_messages_async(
messages,
......@@ -101,7 +102,7 @@ class Grok2Renderer(BaseRenderer):
**params.get_apply_chat_template_kwargs(),
)
prompt = self.render_completion(prompt_raw)
prompt = parse_dec_only_prompt(prompt_raw)
if mm_data is not None:
prompt["multi_modal_data"] = mm_data
if mm_uuids is not None:
......
......@@ -25,7 +25,6 @@ from vllm.entrypoints.chat_utils import (
parse_chat_messages,
parse_chat_messages_async,
)
from vllm.inputs import EmbedsPrompt, TextPrompt, TokensPrompt
from vllm.logger import init_logger
from vllm.tokenizers import cached_get_tokenizer
from vllm.tokenizers.hf import CachedHfTokenizer, HfTokenizer
......@@ -33,6 +32,8 @@ from vllm.transformers_utils.chat_templates import get_chat_template_fallback_pa
from vllm.transformers_utils.processor import cached_get_processor
from vllm.utils.func_utils import supports_kw
from .inputs import DictPrompt
from .inputs.preprocess import parse_dec_only_prompt
from .params import ChatParams
from .protocol import BaseRenderer
......@@ -632,7 +633,7 @@ class HfRenderer(BaseRenderer):
self,
messages: list[ChatCompletionMessageParam],
params: ChatParams,
) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt | EmbedsPrompt]:
) -> tuple[list[ConversationMessage], DictPrompt]:
model_config = self.config
tokenizer = self.get_tokenizer()
......@@ -674,7 +675,7 @@ class HfRenderer(BaseRenderer):
video_placeholder,
)
prompt = self.render_completion(prompt_raw)
prompt = parse_dec_only_prompt(prompt_raw)
if mm_data is not None:
prompt["multi_modal_data"] = mm_data
if mm_uuids is not None:
......@@ -686,7 +687,7 @@ class HfRenderer(BaseRenderer):
self,
messages: list[ChatCompletionMessageParam],
params: ChatParams,
) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt | EmbedsPrompt]:
) -> tuple[list[ConversationMessage], DictPrompt]:
model_config = self.config
tokenizer = self.get_tokenizer()
......@@ -726,7 +727,7 @@ class HfRenderer(BaseRenderer):
video_placeholder,
)
prompt = self.render_completion(prompt_raw)
prompt = parse_dec_only_prompt(prompt_raw)
if mm_data is not None:
prompt["multi_modal_data"] = mm_data
if mm_uuids is not None:
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from .preprocess import (
DecoderDictPrompt,
DecoderOnlyDictPrompt,
DictPrompt,
EncoderDecoderDictPrompt,
EncoderDictPrompt,
SingletonDictPrompt,
)
from .tokenize import (
DecoderOnlyTokPrompt,
DecoderTokPrompt,
EncoderDecoderTokPrompt,
EncoderTokPrompt,
SingletonTokPrompt,
TokPrompt,
)
__all__ = [
"DecoderOnlyDictPrompt",
"EncoderDictPrompt",
"DecoderDictPrompt",
"EncoderDecoderDictPrompt",
"SingletonDictPrompt",
"DictPrompt",
"DecoderOnlyTokPrompt",
"EncoderTokPrompt",
"DecoderTokPrompt",
"EncoderDecoderTokPrompt",
"SingletonTokPrompt",
"TokPrompt",
]
"""
Schemas and utilities for preprocessing inputs.
"""
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Sequence
from typing import TYPE_CHECKING, NamedTuple, TypeAlias, TypedDict, overload
from vllm.inputs import (
EmbedsPrompt,
ExplicitEncoderDecoderPrompt,
PromptType,
SingletonPrompt,
TextPrompt,
TokensPrompt,
)
from vllm.utils import length_from_prompt_token_ids_or_embeds
from vllm.utils.collection_utils import is_list_of
if TYPE_CHECKING:
import torch
from vllm.config import ModelConfig
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
@overload
def prompt_to_seq(
    prompt_or_prompts: SingletonPrompt | bytes | Sequence[SingletonPrompt | bytes],
) -> Sequence[SingletonPrompt]: ...


@overload
def prompt_to_seq(  # type: ignore[misc]
    prompt_or_prompts: ExplicitEncoderDecoderPrompt
    | Sequence[ExplicitEncoderDecoderPrompt],
) -> Sequence[ExplicitEncoderDecoderPrompt]: ...


@overload
def prompt_to_seq(  # type: ignore[misc]
    prompt_or_prompts: PromptType | Sequence[PromptType],
) -> Sequence[PromptType]: ...


def prompt_to_seq(
    prompt_or_prompts: PromptType | bytes | Sequence[PromptType | bytes],
) -> Sequence[PromptType]:
    """
    Normalize "one prompt or many prompts" into a sequence of prompts.

    A `dict`, `str` or `bytes` value is treated as a single prompt and
    wrapped in a one-element list. A non-empty value for which
    `is_list_of(..., int)` holds (presumably a list of token IDs) is
    likewise treated as a single prompt. Anything else is assumed to
    already be a sequence of prompts and is returned unchanged.
    """
    # Scalar-like prompt types are always a single prompt.
    if isinstance(prompt_or_prompts, (dict, str, bytes)):
        return [prompt_or_prompts]  # type: ignore[list-item]

    # A non-empty run of ints is one tokenized prompt, not many prompts.
    if len(prompt_or_prompts) > 0 and is_list_of(prompt_or_prompts, int):
        return [prompt_or_prompts]  # type: ignore[list-item]

    return prompt_or_prompts  # type: ignore[return-value]
def conversation_to_seq(
conversation_or_conversations: list["ChatCompletionMessageParam"]
| Sequence[list["ChatCompletionMessageParam"]],
) -> Sequence[list["ChatCompletionMessageParam"]]:
if len(conversation_or_conversations) > 0 and is_list_of(
conversation_or_conversations, dict
):
return [conversation_or_conversations] # type: ignore[list-item]
return conversation_or_conversations # type: ignore[return-value]
DecoderOnlyDictPrompt: TypeAlias = TextPrompt | TokensPrompt | EmbedsPrompt
"""
A [`DecoderOnlyPrompt`][vllm.inputs.data.DecoderOnlyPrompt]
that has been standardized into a dictionary.
"""
EncoderDictPrompt: TypeAlias = TextPrompt | TokensPrompt
"""
An [`EncoderPrompt`][vllm.inputs.data.EncoderPrompt]
that has been standardized into a dictionary.
"""
DecoderDictPrompt: TypeAlias = TextPrompt | TokensPrompt
"""
A [`DecoderPrompt`][vllm.inputs.data.DecoderPrompt]
that has been standardized into a dictionary.
"""
class EncoderDecoderDictPrompt(TypedDict):
    """
    An [`EncoderDecoderPrompt`][vllm.inputs.data.EncoderDecoderPrompt]
    that has been standardized into a dictionary.
    """

    # Encoder-side prompt in dictionary form (text or tokens).
    encoder_prompt: EncoderDictPrompt

    # Decoder-side prompt in dictionary form (text or tokens);
    # `None` when no explicit decoder prompt is available.
    decoder_prompt: DecoderDictPrompt | None
SingletonDictPrompt: TypeAlias = (
DecoderOnlyDictPrompt | EncoderDictPrompt | DecoderDictPrompt
)
"""
A [`SingletonPrompt`][vllm.inputs.data.SingletonPrompt]
that has been standardized into a dictionary.
"""
DictPrompt: TypeAlias = DecoderOnlyDictPrompt | EncoderDecoderDictPrompt
"""
A [`PromptType`][vllm.inputs.data.PromptType]
that has been standardized into a dictionary.
"""
def parse_dec_only_prompt(prompt: object) -> DecoderOnlyDictPrompt:
    """
    Normalize a decoder-only model prompt into its dictionary form.

    - A string becomes a `TextPrompt`.
    - A list of integers becomes a `TokensPrompt`.
    - A dictionary is returned as-is, provided that it is not an
      encoder-decoder prompt and contains text, tokens, or embeddings.

    Raises:
        TypeError: If the prompt has an unsupported type or structure.
    """
    if isinstance(prompt, str):
        return TextPrompt(prompt=prompt)

    if isinstance(prompt, list):
        if is_list_of(prompt, int):
            return TokensPrompt(prompt_token_ids=prompt)

        raise TypeError("Token prompt should be a list of integers")

    if not isinstance(prompt, dict):
        raise TypeError("Prompt should be a string, list of tokens, or dictionary")

    # Explicit encoder-decoder prompts are rejected before the content check.
    if "encoder_prompt" in prompt:
        raise TypeError("Cannot pass encoder-decoder prompt to decoder-only models")

    content_keys = ("prompt", "prompt_token_ids", "prompt_embeds")
    if not any(key in prompt for key in content_keys):
        raise TypeError("Prompt dictionary must contain text, tokens, or embeddings")

    return prompt  # type: ignore[return-value]
def _parse_enc_prompt(prompt: object) -> EncoderDictPrompt:
    """
    Normalize the encoder part of an encoder-decoder prompt
    into its dictionary form.

    Strings and lists of integers are wrapped into `TextPrompt` and
    `TokensPrompt` respectively. A dictionary is returned as-is if it
    contains text or tokens; embeddings are rejected.

    Raises:
        TypeError: If the prompt has an unsupported type or structure.
    """
    if isinstance(prompt, dict):
        # Embeddings are refused even if text or tokens are also present.
        if "prompt_embeds" in prompt:
            raise TypeError("Cannot pass embeddings prompt to encoder-decoder models")

        has_text_or_tokens = "prompt" in prompt or "prompt_token_ids" in prompt
        if not has_text_or_tokens:
            raise TypeError("Prompt dictionary must contain text or tokens")

        return prompt  # type: ignore[return-value]

    if isinstance(prompt, str):
        return TextPrompt(prompt=prompt)

    if isinstance(prompt, list):
        if is_list_of(prompt, int):
            return TokensPrompt(prompt_token_ids=prompt)

        raise TypeError("Token prompt should be a list of integers")

    raise TypeError("Prompt should be a string, list of tokens, or dictionary")
def _parse_dec_prompt(prompt: object) -> DecoderDictPrompt:
    """
    Normalize the decoder part of an encoder-decoder prompt
    into its dictionary form.

    Strings and lists of integers are wrapped into `TextPrompt` and
    `TokensPrompt` respectively. A dictionary is returned as-is if it
    contains text or tokens; embeddings and multi-modal inputs are rejected.

    Raises:
        TypeError: If the prompt has an unsupported type or structure.
    """
    if isinstance(prompt, str):
        return TextPrompt(prompt=prompt)

    if isinstance(prompt, list):
        if is_list_of(prompt, int):
            return TokensPrompt(prompt_token_ids=prompt)

        raise TypeError("Token prompt should be a list of integers")

    if not isinstance(prompt, dict):
        raise TypeError("Prompt should be a string, list of tokens, or dictionary")

    # The embeddings check takes precedence over the multi-modal check.
    if "prompt_embeds" in prompt:
        raise TypeError("Cannot pass embeddings prompt to encoder-decoder models")

    mm_keys = ("multi_modal_data", "mm_processor_kwargs", "multi_modal_uuids")
    if any(key in prompt for key in mm_keys):
        raise TypeError("Cannot pass multi-modal inputs to decoder prompt")

    if not ("prompt" in prompt or "prompt_token_ids" in prompt):
        raise TypeError("Prompt dictionary must contain text or tokens")

    return prompt  # type: ignore[return-value]
def parse_enc_dec_prompt(prompt: object) -> EncoderDecoderDictPrompt:
    """
    Normalize an encoder-decoder model prompt into its dictionary form.

    A dictionary containing `"encoder_prompt"` is treated as an explicit
    encoder-decoder prompt and its two parts are parsed separately.
    Any other input is parsed as the encoder prompt, and the decoder
    prompt is set to `None`.
    """
    # Default: the whole input is the encoder prompt.
    enc_raw: object = prompt
    dec_raw: object | None = None

    if isinstance(prompt, dict) and "encoder_prompt" in prompt:
        enc_raw = prompt["encoder_prompt"]  # type: ignore[typeddict-item]
        dec_raw = prompt["decoder_prompt"]  # type: ignore[typeddict-item]

    # The encoder part is parsed (and validated) before the decoder part.
    enc_parsed = _parse_enc_prompt(enc_raw)
    dec_parsed = _parse_dec_prompt(dec_raw) if dec_raw is not None else None

    return EncoderDecoderDictPrompt(
        encoder_prompt=enc_parsed,
        decoder_prompt=dec_parsed,
    )
def parse_model_prompt(
    model_config: "ModelConfig",
    prompt: object,
) -> "DictPrompt":
    """
    Parse a prompt according to the model architecture and normalize it
    to a dictionary.

    Args:
        model_config: Model configuration; its `is_encoder_decoder` flag
            selects which parser is applied.
        prompt: The raw prompt to parse.

    Returns:
        The result of `parse_enc_dec_prompt` for encoder-decoder models,
        otherwise the result of `parse_dec_only_prompt`.

    Raises:
        TypeError: Propagated from the selected parser when the prompt
            has an unsupported type or structure.
    """
    if model_config.is_encoder_decoder:
        return parse_enc_dec_prompt(prompt)

    return parse_dec_only_prompt(prompt)
class PromptComponents(NamedTuple):
    """
    The text, token IDs, and embeddings that a prompt may carry.

    Every field defaults to `None`.
    """

    # Prompt text.
    text: str | None = None

    # Prompt token IDs.
    token_ids: list[int] | None = None

    # Prompt embeddings. The annotation is a string because `torch`
    # is only imported under `TYPE_CHECKING` in this module.
    embeds: "torch.Tensor | None" = None
def extract_prompt_components(
    model_config: "ModelConfig",
    prompt: object,
) -> PromptComponents:
    """
    Parse a prompt and collect its text, token IDs, and embeddings.

    For encoder-decoder models, the components are read from the encoder
    prompt; otherwise they are read from the decoder-only prompt.
    Components that are absent from the parsed prompt are `None`.
    """
    target_prompt = (
        parse_dec_only_prompt(prompt)
        if not model_config.is_encoder_decoder
        else parse_enc_dec_prompt(prompt)["encoder_prompt"]
    )

    # Positional order matches the fields: text, token_ids, embeds.
    return PromptComponents(
        target_prompt.get("prompt"),
        target_prompt.get("prompt_token_ids"),  # type: ignore[arg-type]
        target_prompt.get("prompt_embeds"),
    )
def extract_prompt_len(model_config: "ModelConfig", prompt: object):
    """
    Parse a prompt and compute its length from its token IDs or embeddings.

    For encoder-decoder models, the encoder prompt is measured; otherwise
    the decoder-only prompt is measured. The computation itself is
    delegated to `length_from_prompt_token_ids_or_embeds`.
    """
    target_prompt = (
        parse_dec_only_prompt(prompt)
        if not model_config.is_encoder_decoder
        else parse_enc_dec_prompt(prompt)["encoder_prompt"]
    )

    # Either value may be missing (`None`) from the parsed prompt.
    token_ids = target_prompt.get("prompt_token_ids")
    embeds = target_prompt.get("prompt_embeds")

    return length_from_prompt_token_ids_or_embeds(token_ids, embeds)  # type: ignore[arg-type]
"""
Schemas and utilities for tokenization inputs.
"""
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import TypeAlias, TypedDict
from vllm.inputs import EmbedsPrompt, TokensPrompt
DecoderOnlyTokPrompt: TypeAlias = TokensPrompt | EmbedsPrompt
"""
A [`DecoderOnlyDictPrompt`][vllm.renderers.inputs.preprocess.DecoderOnlyDictPrompt]
that has been tokenized.
"""
EncoderTokPrompt: TypeAlias = TokensPrompt
"""
An [`EncoderDictPrompt`][vllm.renderers.inputs.preprocess.EncoderDictPrompt]
that has been tokenized.
"""
DecoderTokPrompt: TypeAlias = TokensPrompt
"""
A [`DecoderDictPrompt`][vllm.renderers.inputs.preprocess.DecoderDictPrompt]
that has been tokenized.
"""
class EncoderDecoderTokPrompt(TypedDict):
    """
    An
    [`EncoderDecoderDictPrompt`][vllm.renderers.inputs.preprocess.EncoderDecoderDictPrompt]
    that has been tokenized.
    """

    # Tokenized encoder-side prompt.
    encoder_prompt: EncoderTokPrompt

    # Tokenized decoder-side prompt; `None` when there is no decoder prompt.
    decoder_prompt: DecoderTokPrompt | None
SingletonTokPrompt: TypeAlias = (
DecoderOnlyTokPrompt | EncoderTokPrompt | DecoderTokPrompt
)
"""
A [`SingletonDictPrompt`][vllm.renderers.inputs.preprocess.SingletonDictPrompt]
that has been tokenized.
"""
TokPrompt: TypeAlias = DecoderOnlyTokPrompt | EncoderDecoderTokPrompt
"""
A [`DictPrompt`][vllm.renderers.inputs.preprocess.DictPrompt]
that has been tokenized.
"""
......@@ -10,12 +10,13 @@ from vllm.entrypoints.chat_utils import (
parse_chat_messages,
parse_chat_messages_async,
)
from vllm.inputs import EmbedsPrompt, TextPrompt, TokensPrompt
from vllm.logger import init_logger
from vllm.tokenizers import cached_get_tokenizer
from vllm.tokenizers.mistral import MistralTokenizer
from vllm.utils.async_utils import make_async
from .inputs import DictPrompt
from .inputs.preprocess import parse_dec_only_prompt
from .params import ChatParams
from .protocol import BaseRenderer
......@@ -95,7 +96,7 @@ class MistralRenderer(BaseRenderer):
self,
messages: list[ChatCompletionMessageParam],
params: ChatParams,
) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt | EmbedsPrompt]:
) -> tuple[list[ConversationMessage], DictPrompt]:
tokenizer = self.get_tokenizer()
conversation, mm_data, mm_uuids = parse_chat_messages(
messages,
......@@ -109,7 +110,7 @@ class MistralRenderer(BaseRenderer):
**params.get_apply_chat_template_kwargs(),
)
prompt = self.render_completion(prompt_raw)
prompt = parse_dec_only_prompt(prompt_raw)
if mm_data is not None:
prompt["multi_modal_data"] = mm_data
if mm_uuids is not None:
......@@ -121,7 +122,7 @@ class MistralRenderer(BaseRenderer):
self,
messages: list[ChatCompletionMessageParam],
params: ChatParams,
) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt | EmbedsPrompt]:
) -> tuple[list[ConversationMessage], DictPrompt]:
tokenizer = self.get_tokenizer()
conversation, mm_data, mm_uuids = await parse_chat_messages_async(
messages,
......@@ -135,7 +136,7 @@ class MistralRenderer(BaseRenderer):
**params.get_apply_chat_template_kwargs(),
)
prompt = self.render_completion(prompt_raw)
prompt = parse_dec_only_prompt(prompt_raw)
if mm_data is not None:
prompt["multi_modal_data"] = mm_data
if mm_uuids is not None:
......
......@@ -2,14 +2,20 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any
from collections.abc import Sequence
from typing import TYPE_CHECKING, Any, overload
from vllm.inputs import EmbedsPrompt, TextPrompt, TokensPrompt
from vllm.tokenizers import TokenizerLike
from vllm.utils.async_utils import AsyncMicrobatchTokenizer
from vllm.utils.collection_utils import is_list_of
from .embed_utils import safe_load_prompt_embeds
from .inputs import (
DictPrompt,
EncoderDecoderDictPrompt,
EncoderDecoderTokPrompt,
TokPrompt,
)
from .params import ChatParams, TokenizeParams
if TYPE_CHECKING:
......@@ -57,140 +63,217 @@ class BaseRenderer(ABC):
return self._async_tokenizer
# Step 1: Convert raw inputs to prompts
def render_completion(
def render_prompt(
self,
prompt_raw: str | list[int] | bytes,
) -> TextPrompt | TokensPrompt | EmbedsPrompt:
error_msg = "Each prompt must be a string or an array of tokens"
prompt: DictPrompt | bytes,
) -> DictPrompt:
if isinstance(prompt, bytes):
embeds = safe_load_prompt_embeds(self.config, prompt)
prompt = EmbedsPrompt(prompt_embeds=embeds)
if isinstance(prompt_raw, str):
return TextPrompt(prompt=prompt_raw)
return prompt
if isinstance(prompt_raw, list):
if not is_list_of(prompt_raw, int):
raise TypeError(error_msg)
return TokensPrompt(prompt_token_ids=prompt_raw)
if isinstance(prompt_raw, bytes):
embeds = safe_load_prompt_embeds(self.config, prompt_raw)
return EmbedsPrompt(prompt_embeds=embeds)
raise TypeError(error_msg)
def render_completions(
def render_prompts(
self,
prompt_input: str | list[str] | list[int] | list[list[int]] | None = None,
prompt_embeds: bytes | list[bytes] | None = None,
) -> list[TextPrompt | TokensPrompt | EmbedsPrompt]:
prompts_raw = list[str | list[int] | bytes]()
if prompt_embeds is not None: # embeds take higher priority
if isinstance(prompt_embeds, bytes):
prompts_raw.append(prompt_embeds)
else:
prompts_raw.extend(prompt_embeds)
if prompt_input is not None:
if isinstance(prompt_input, str) or (
len(prompt_input) > 0 and is_list_of(prompt_input, int)
):
prompts_raw.append(prompt_input) # type: ignore[arg-type]
else:
prompts_raw.extend(prompt_input) # type: ignore[arg-type]
if len(prompts_raw) == 0:
prompts: Sequence[DictPrompt | bytes],
) -> list[DictPrompt]:
if len(prompts) == 0:
raise ValueError("You must pass at least one prompt")
return [self.render_completion(prompt) for prompt in prompts_raw]
return [self.render_prompt(prompt) for prompt in prompts]
async def render_completions_async(
async def render_prompts_async(
self,
prompt_input: str | list[str] | list[int] | list[list[int]] | None = None,
prompt_embeds: bytes | list[bytes] | None = None,
) -> list[TextPrompt | TokensPrompt | EmbedsPrompt]:
return self.render_completions(prompt_input, prompt_embeds)
prompts: Sequence[DictPrompt | bytes],
) -> list[DictPrompt]:
return self.render_prompts(prompts)
@abstractmethod
def render_messages(
self,
messages: list["ChatCompletionMessageParam"],
params: ChatParams,
) -> tuple[list["ConversationMessage"], TextPrompt | TokensPrompt | EmbedsPrompt]:
) -> tuple[list["ConversationMessage"], DictPrompt]:
raise NotImplementedError
async def render_messages_async(
self,
messages: list["ChatCompletionMessageParam"],
params: ChatParams,
) -> tuple[list["ConversationMessage"], TextPrompt | TokensPrompt | EmbedsPrompt]:
) -> tuple[list["ConversationMessage"], DictPrompt]:
return self.render_messages(messages, params)
# Step 2: Tokenize prompts if necessary
def _tokenize_prompt(
    self,
    prompt: TextPrompt,
    params: TokenizeParams,
) -> TokensPrompt:
    """
    Encode the text of a prompt using the renderer's tokenizer.

    The returned `TokensPrompt` holds the resulting token IDs together
    with every key of the original text prompt.
    """
    tokenizer = self.get_tokenizer()

    text = prompt["prompt"]
    encode_kwargs = params.get_encode_kwargs()
    token_ids = tokenizer.encode(text, **encode_kwargs)

    return TokensPrompt(prompt_token_ids=token_ids, **prompt)
async def _tokenize_prompt_async(
    self,
    prompt: TextPrompt,
    params: TokenizeParams,
) -> TokensPrompt:
    """
    Encode the text of a prompt using the renderer's async tokenizer.

    The returned `TokensPrompt` holds the resulting token IDs together
    with every key of the original text prompt.
    """
    tokenizer = self.get_async_tokenizer()

    text = prompt["prompt"]
    encode_kwargs = params.get_encode_kwargs()
    token_ids = await tokenizer.encode(text, **encode_kwargs)

    return TokensPrompt(prompt_token_ids=token_ids, **prompt)
def _detokenize_prompt(self, prompt: TokensPrompt) -> TokensPrompt:
    """
    Decode the token IDs of a prompt back into text.

    The text is stored under the `"prompt"` key of the given prompt,
    which is modified in place and returned.
    """
    token_ids = prompt["prompt_token_ids"]
    prompt["prompt"] = self.get_tokenizer().decode(token_ids)

    return prompt
async def _detokenize_prompt_async(self, prompt: TokensPrompt) -> TokensPrompt:
    """
    Decode the token IDs of a prompt back into text asynchronously.

    The text is stored under the `"prompt"` key of the given prompt,
    which is modified in place and returned.
    """
    async_tokenizer = self.get_async_tokenizer()

    decoded_text = await async_tokenizer.decode(prompt["prompt_token_ids"])
    prompt["prompt"] = decoded_text

    return prompt
def _tokenize_enc_dec_prompt(
    self,
    prompt: EncoderDecoderDictPrompt,
    params: TokenizeParams,
) -> EncoderDecoderTokPrompt:
    """
    Tokenize both parts of an encoder-decoder prompt.

    The encoder prompt is tokenized first. The decoder prompt is
    tokenized only if it is not `None`; otherwise it stays `None`.
    """
    enc_tok = self.tokenize_prompt(prompt["encoder_prompt"], params)

    dec_raw = prompt["decoder_prompt"]
    dec_tok = None if dec_raw is None else self.tokenize_prompt(dec_raw, params)

    return EncoderDecoderTokPrompt(
        encoder_prompt=enc_tok,
        decoder_prompt=dec_tok,
    )
async def _tokenize_enc_dec_prompt_async(
    self,
    prompt: EncoderDecoderDictPrompt,
    params: TokenizeParams,
) -> EncoderDecoderTokPrompt:
    """
    Tokenize both parts of an encoder-decoder prompt concurrently.

    The encoder and decoder prompts are awaited together via
    `asyncio.gather`. If the decoder prompt is `None`, it stays `None`.
    """
    enc_coro = self.tokenize_prompt_async(prompt["encoder_prompt"], params)

    dec_raw = prompt["decoder_prompt"]
    # `asyncio.sleep(0)` is a placeholder awaitable that resolves to `None`,
    # which keeps the `gather` call uniform when there is no decoder prompt.
    dec_coro = (
        self.tokenize_prompt_async(dec_raw, params)
        if dec_raw is not None
        else asyncio.sleep(0)
    )

    enc_tok, dec_tok = await asyncio.gather(enc_coro, dec_coro)

    return EncoderDecoderTokPrompt(
        encoder_prompt=enc_tok,
        decoder_prompt=dec_tok,
    )
@overload
def tokenize_prompt(
self,
prompt: TextPrompt | TokensPrompt | EmbedsPrompt,
prompt: TextPrompt | TokensPrompt,
params: TokenizeParams,
) -> TokensPrompt | EmbedsPrompt:
if "prompt_token_ids" not in prompt and "prompt_embeds" not in prompt:
prompt = params.apply_pre_tokenization(self.tokenizer, prompt)
) -> TokensPrompt: ...
tokenizer = self.get_tokenizer()
prompt_token_ids = tokenizer.encode(
prompt["prompt"],
**params.get_encode_kwargs(),
)
@overload
def tokenize_prompt( # type: ignore[misc]
self,
prompt: EmbedsPrompt,
params: TokenizeParams,
) -> EmbedsPrompt: ...
prompt = TokensPrompt(prompt_token_ids=prompt_token_ids, **prompt)
@overload
def tokenize_prompt( # type: ignore[misc]
self,
prompt: EncoderDecoderDictPrompt,
params: TokenizeParams,
) -> EncoderDecoderTokPrompt: ...
def tokenize_prompt(
self,
prompt: DictPrompt,
params: TokenizeParams,
) -> TokPrompt:
if "encoder_prompt" in prompt:
return self._tokenize_enc_dec_prompt(prompt, params) # type: ignore[arg-type]
if "prompt_token_ids" not in prompt and "prompt_embeds" not in prompt:
prompt = params.apply_pre_tokenization(self.tokenizer, prompt)
prompt = self._tokenize_prompt(prompt, params)
if params.needs_detokenization and "prompt" not in prompt:
if "prompt_token_ids" not in prompt:
raise RuntimeError("Cannot run detokenization on embeddings")
tokenizer = self.get_tokenizer()
prompt_text = tokenizer.decode(prompt["prompt_token_ids"]) # type: ignore[typeddict-item]
prompt["prompt"] = prompt_text # type: ignore[typeddict-unknown-key]
prompt = self._detokenize_prompt(prompt) # type: ignore[arg-type]
return params.apply_post_tokenization(self.tokenizer, prompt) # type: ignore[arg-type]
def tokenize_prompts(
self,
prompts: list[TextPrompt | TokensPrompt | EmbedsPrompt],
prompts: Sequence[DictPrompt],
params: TokenizeParams,
) -> list[TokensPrompt | EmbedsPrompt]:
) -> list[TokPrompt]:
return [self.tokenize_prompt(prompt, params) for prompt in prompts]
@overload
async def tokenize_prompt_async(
self,
prompt: TextPrompt | TokensPrompt | EmbedsPrompt,
prompt: TextPrompt | TokensPrompt,
params: TokenizeParams,
) -> TokensPrompt | EmbedsPrompt:
if "prompt_token_ids" not in prompt and "prompt_embeds" not in prompt:
prompt = params.apply_pre_tokenization(self.tokenizer, prompt)
) -> TokensPrompt: ...
@overload
async def tokenize_prompt_async( # type: ignore[misc]
self,
prompt: EmbedsPrompt,
params: TokenizeParams,
) -> EmbedsPrompt: ...
@overload
async def tokenize_prompt_async( # type: ignore[misc]
self,
prompt: EncoderDecoderDictPrompt,
params: TokenizeParams,
) -> EncoderDecoderTokPrompt: ...
tokenizer = self.get_async_tokenizer()
prompt_token_ids = await tokenizer.encode(
prompt["prompt"],
**params.get_encode_kwargs(),
)
async def tokenize_prompt_async(
self,
prompt: DictPrompt,
params: TokenizeParams,
) -> TokPrompt:
if "encoder_prompt" in prompt:
return await self._tokenize_enc_dec_prompt_async(prompt, params) # type: ignore[arg-type]
prompt = TokensPrompt(prompt_token_ids=prompt_token_ids, **prompt)
if "prompt_token_ids" not in prompt and "prompt_embeds" not in prompt:
prompt = params.apply_pre_tokenization(self.tokenizer, prompt)
prompt = await self._tokenize_prompt_async(prompt, params)
if params.needs_detokenization and "prompt" not in prompt:
if "prompt_token_ids" not in prompt:
raise RuntimeError("Cannot run detokenization on embeddings")
tokenizer = self.get_async_tokenizer()
prompt_text = await tokenizer.decode(prompt["prompt_token_ids"]) # type: ignore[typeddict-item]
prompt["prompt"] = prompt_text # type: ignore[typeddict-unknown-key]
prompt = await self._detokenize_prompt_async(prompt) # type: ignore[arg-type]
return params.apply_post_tokenization(self.tokenizer, prompt) # type: ignore[arg-type]
async def tokenize_prompts_async(
self,
prompts: list[TextPrompt | TokensPrompt | EmbedsPrompt],
prompts: Sequence[DictPrompt],
params: TokenizeParams,
) -> list[TokensPrompt | EmbedsPrompt]:
) -> list[TokPrompt]:
return await asyncio.gather(
*(self.tokenize_prompt_async(prompt, params) for prompt in prompts)
)
......@@ -9,10 +9,11 @@ from vllm.entrypoints.chat_utils import (
parse_chat_messages,
parse_chat_messages_async,
)
from vllm.inputs import EmbedsPrompt, TextPrompt, TokensPrompt
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
from .inputs import DictPrompt
from .inputs.preprocess import parse_dec_only_prompt
from .params import ChatParams
from .protocol import BaseRenderer
......@@ -45,7 +46,7 @@ class TerratorchRenderer(BaseRenderer):
self,
messages: list[ChatCompletionMessageParam],
params: ChatParams,
) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt | EmbedsPrompt]:
) -> tuple[list[ConversationMessage], DictPrompt]:
model_config = self.config
conversation, mm_data, mm_uuids = parse_chat_messages(
......@@ -54,7 +55,7 @@ class TerratorchRenderer(BaseRenderer):
content_format="string",
)
prompt = self.render_completion([1]) # Dummy token IDs
prompt = parse_dec_only_prompt([1]) # Dummy token IDs
if mm_data is not None:
prompt["multi_modal_data"] = mm_data
if mm_uuids is not None:
......@@ -66,7 +67,7 @@ class TerratorchRenderer(BaseRenderer):
self,
messages: list[ChatCompletionMessageParam],
params: ChatParams,
) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt | EmbedsPrompt]:
) -> tuple[list[ConversationMessage], DictPrompt]:
model_config = self.config
conversation, mm_data, mm_uuids = await parse_chat_messages_async(
......@@ -75,7 +76,7 @@ class TerratorchRenderer(BaseRenderer):
content_format="string",
)
prompt = self.render_completion([1]) # Dummy token IDs
prompt = parse_dec_only_prompt([1]) # Dummy token IDs
if mm_data is not None:
prompt["multi_modal_data"] = mm_data
if mm_uuids is not None:
......
......@@ -28,6 +28,8 @@ from vllm.outputs import STREAM_FINISHED, PoolingRequestOutput, RequestOutput
from vllm.plugins.io_processors import get_io_processor
from vllm.pooling_params import PoolingParams
from vllm.renderers import BaseRenderer, merge_kwargs
from vllm.renderers.inputs import DictPrompt, TokPrompt
from vllm.renderers.inputs.preprocess import extract_prompt_components
from vllm.sampling_params import RequestOutputKind, SamplingParams
from vllm.tasks import SupportedTask
from vllm.tokenizers import TokenizerLike
......@@ -42,7 +44,6 @@ from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError
from vllm.v1.engine.input_processor import InputProcessor
from vllm.v1.engine.output_processor import OutputProcessor, RequestOutputCollector
from vllm.v1.engine.parallel_sampling import ParentRequest
from vllm.v1.engine.utils import get_prompt_text
from vllm.v1.executor import Executor
from vllm.v1.metrics.loggers import (
StatLoggerFactory,
......@@ -284,7 +285,11 @@ class AsyncLLM(EngineClient):
async def add_request(
self,
request_id: str,
prompt: EngineCoreRequest | PromptType | AsyncGenerator[StreamingInput, None],
prompt: EngineCoreRequest
| PromptType
| DictPrompt
| TokPrompt
| AsyncGenerator[StreamingInput, None],
params: SamplingParams | PoolingParams,
arrival_time: float | None = None,
lora_request: LoRARequest | None = None,
......@@ -367,7 +372,7 @@ class AsyncLLM(EngineClient):
data_parallel_rank=data_parallel_rank,
supported_tasks=await self.get_supported_tasks(),
)
prompt_text = get_prompt_text(prompt)
prompt_text, _, _ = extract_prompt_components(self.model_config, prompt)
self.input_processor.assign_request_id(request)
......@@ -484,7 +489,9 @@ class AsyncLLM(EngineClient):
raise ValueError(
"prompt_embeds not supported for streaming inputs"
)
prompt_text = get_prompt_text(input_chunk.prompt)
prompt_text, _, _ = extract_prompt_components(
self.model_config, input_chunk.prompt
)
await self._add_request(req, prompt_text, None, 0, queue)
except (asyncio.CancelledError, GeneratorExit):
cancelled = True
......@@ -528,7 +535,11 @@ class AsyncLLM(EngineClient):
# re-multiplexed in the API server anyhow.
async def generate(
self,
prompt: EngineCoreRequest | PromptType | AsyncGenerator[StreamingInput, None],
prompt: EngineCoreRequest
| PromptType
| DictPrompt
| TokPrompt
| AsyncGenerator[StreamingInput, None],
sampling_params: SamplingParams,
request_id: str,
*,
......@@ -769,7 +780,7 @@ class AsyncLLM(EngineClient):
async def encode(
self,
prompt: PromptType,
prompt: PromptType | DictPrompt | TokPrompt,
pooling_params: PoolingParams,
request_id: str,
lora_request: LoRARequest | None = None,
......
......@@ -7,14 +7,13 @@ from typing import Any, Literal, cast
from vllm.config import VllmConfig
from vllm.exceptions import VLLMValidationError
from vllm.inputs import (
from vllm.inputs.data import (
ProcessorInputs,
PromptType,
SingletonInputs,
SingletonPrompt,
TextPrompt,
)
from vllm.inputs.parse import is_explicit_encoder_decoder_prompt, split_enc_dec_inputs
from vllm.inputs.parse import split_enc_dec_inputs
from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
......@@ -30,6 +29,7 @@ from vllm.multimodal.processing.context import set_request_id
from vllm.multimodal.utils import argsort_mm_positions
from vllm.pooling_params import PoolingParams
from vllm.renderers import BaseRenderer
from vllm.renderers.inputs import DictPrompt, TokPrompt
from vllm.sampling_params import _SAMPLING_EPS, SamplingParams
from vllm.tasks import POOLING_TASKS, SupportedTask
from vllm.tokenizers import TokenizerLike
......@@ -243,8 +243,8 @@ class InputProcessor:
return mm_processor.info.parse_mm_data(mm_data)
def _validate_singleton_mm_uuids(self, prompt: SingletonPrompt) -> None:
if isinstance(prompt, str):
prompt = TextPrompt(prompt=prompt)
if not isinstance(prompt, dict):
return
mm_data = cast(MultiModalDataDict, prompt.get("multi_modal_data") or {})
mm_uuids = cast(MultiModalUUIDDict, prompt.get("multi_modal_uuids") or {})
......@@ -297,7 +297,7 @@ class InputProcessor:
f"multi_modal_uuids[{modality!r}] is missing."
)
def _validate_mm_uuids(self, prompt: PromptType) -> None:
def _validate_mm_uuids(self, prompt: PromptType | DictPrompt | TokPrompt) -> None:
"""
Validate that user-provided multi_modal_uuids align with
multi_modal_data in the incoming request prompt(s).
......@@ -305,10 +305,10 @@ class InputProcessor:
auto-hashed downstream.
"""
if is_explicit_encoder_decoder_prompt(prompt):
self._validate_singleton_mm_uuids(prompt["encoder_prompt"])
if isinstance(prompt, dict) and "encoder_prompt" in prompt:
self._validate_singleton_mm_uuids(prompt["encoder_prompt"]) # type: ignore[typeddict-item]
if (dec_prompt := prompt["decoder_prompt"]) is not None:
if (dec_prompt := prompt["decoder_prompt"]) is not None: # type: ignore[typeddict-item]
self._validate_singleton_mm_uuids(dec_prompt)
else:
self._validate_singleton_mm_uuids(prompt)
......@@ -449,21 +449,23 @@ class InputProcessor:
def _extract_singleton_mm_data(
self, prompt: SingletonPrompt
) -> MultiModalDataDict | None:
if isinstance(prompt, str):
if not isinstance(prompt, dict):
return None
return prompt.get("multi_modal_data") # type: ignore[return-value]
return prompt.get("multi_modal_data")
def _extract_mm_data(self, prompt: PromptType) -> MultiModalDataDict | None:
if is_explicit_encoder_decoder_prompt(prompt):
return self._extract_singleton_mm_data(prompt["encoder_prompt"])
def _extract_mm_data(
self, prompt: PromptType | DictPrompt | TokPrompt
) -> MultiModalDataDict | None:
if isinstance(prompt, dict) and "encoder_prompt" in prompt:
return self._extract_singleton_mm_data(prompt["encoder_prompt"]) # type: ignore[typeddict-item]
else:
return self._extract_singleton_mm_data(prompt)
def _maybe_build_mm_uuids(
self,
request_id: str,
prompt: PromptType,
prompt: PromptType | DictPrompt | TokPrompt,
) -> MultiModalUUIDDict | None:
"""Build per-item multimodal hash overrides when enabled. In this case,
multimodal data items are identified by their request id, modality and
......@@ -519,7 +521,7 @@ class InputProcessor:
def process_inputs(
self,
request_id: str,
prompt: PromptType,
prompt: PromptType | DictPrompt | TokPrompt,
params: SamplingParams | PoolingParams,
arrival_time: float | None = None,
lora_request: LoRARequest | None = None,
......
......@@ -22,6 +22,8 @@ from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.plugins.io_processors import get_io_processor
from vllm.pooling_params import PoolingParams
from vllm.renderers import BaseRenderer
from vllm.renderers.inputs import DictPrompt, TokPrompt
from vllm.renderers.inputs.preprocess import extract_prompt_components
from vllm.sampling_params import SamplingParams
from vllm.tasks import SupportedTask
from vllm.tokenizers import TokenizerLike
......@@ -32,7 +34,6 @@ from vllm.v1.engine.core_client import EngineCoreClient
from vllm.v1.engine.input_processor import InputProcessor
from vllm.v1.engine.output_processor import OutputProcessor
from vllm.v1.engine.parallel_sampling import ParentRequest
from vllm.v1.engine.utils import get_prompt_text
from vllm.v1.executor import Executor
from vllm.v1.metrics.loggers import StatLoggerFactory, StatLoggerManager
from vllm.v1.metrics.reader import Metric, get_metrics_snapshot
......@@ -216,7 +217,7 @@ class LLMEngine:
def add_request(
self,
request_id: str,
prompt: EngineCoreRequest | PromptType,
prompt: EngineCoreRequest | PromptType | DictPrompt | TokPrompt,
params: SamplingParams | PoolingParams,
arrival_time: float | None = None,
lora_request: LoRARequest | None = None,
......@@ -251,7 +252,7 @@ class LLMEngine:
priority,
supported_tasks=self.get_supported_tasks(),
)
prompt_text = get_prompt_text(prompt)
prompt_text, _, _ = extract_prompt_components(self.model_config, prompt)
self.input_processor.assign_request_id(request)
......
......@@ -17,8 +17,6 @@ import zmq
from vllm import envs
from vllm.config import CacheConfig, ParallelConfig, VllmConfig
from vllm.inputs import PromptType
from vllm.inputs.parse import get_prompt_components
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.ray.ray_env import get_env_vars_to_copy
......@@ -226,10 +224,6 @@ def get_device_indices(
return value
def get_prompt_text(prompt: PromptType) -> str | None:
return get_prompt_components(prompt)[0]
class CoreEngineActorManager:
"""
Utility class to handle creation, readiness, and shutdown
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment