Unverified Commit cd8b405b authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Refactor] Consolidate sequence normalization and enc-dec parsing (#33928)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent 4707f7eb
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Generic, Literal, TypeAlias, cast from typing import TYPE_CHECKING, Any, Literal, TypeAlias
import torch import torch
from typing_extensions import NotRequired, TypedDict, TypeVar from typing_extensions import NotRequired, TypedDict
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
...@@ -23,7 +22,13 @@ else: ...@@ -23,7 +22,13 @@ else:
MultiModalUUIDDict = object MultiModalUUIDDict = object
class _CommonKeys(TypedDict): # Inputs to LLM API
class _PromptOptions(TypedDict):
"""
Additional options available to all
[`SingletonPrompt`][vllm.inputs.data.SingletonPrompt].
"""
multi_modal_data: NotRequired[MultiModalDataDict | None] multi_modal_data: NotRequired[MultiModalDataDict | None]
""" """
Optional multi-modal data to pass to the model, Optional multi-modal data to pass to the model,
...@@ -53,14 +58,14 @@ class _CommonKeys(TypedDict): ...@@ -53,14 +58,14 @@ class _CommonKeys(TypedDict):
""" """
class TextPrompt(_CommonKeys): class TextPrompt(_PromptOptions):
"""Schema for a text prompt.""" """Schema for a text prompt."""
prompt: str prompt: str
"""The input text to be tokenized before passing to the model.""" """The input text to be tokenized before passing to the model."""
class TokensPrompt(_CommonKeys): class TokensPrompt(_PromptOptions):
"""Schema for a tokenized prompt.""" """Schema for a tokenized prompt."""
prompt_token_ids: list[int] prompt_token_ids: list[int]
...@@ -73,7 +78,7 @@ class TokensPrompt(_CommonKeys): ...@@ -73,7 +78,7 @@ class TokensPrompt(_CommonKeys):
"""A list of token type IDs to pass to the cross encoder model.""" """A list of token type IDs to pass to the cross encoder model."""
class EmbedsPrompt(_CommonKeys): class EmbedsPrompt(_PromptOptions):
"""Schema for a prompt provided via token embeddings.""" """Schema for a prompt provided via token embeddings."""
prompt_embeds: torch.Tensor prompt_embeds: torch.Tensor
...@@ -83,93 +88,113 @@ class EmbedsPrompt(_CommonKeys): ...@@ -83,93 +88,113 @@ class EmbedsPrompt(_CommonKeys):
"""The prompt text corresponding to the token embeddings, if available.""" """The prompt text corresponding to the token embeddings, if available."""
class DataPrompt(_CommonKeys): DecoderOnlyPrompt: TypeAlias = (
"""Represents generic inputs handled by IO processor plugins.""" str | TextPrompt | list[int] | TokensPrompt | EmbedsPrompt
)
"""
Schema of a prompt for a decoder-only model:
data: Any - A text prompt (string or [`TextPrompt`][vllm.inputs.data.TextPrompt])
"""The input data""" - A tokenized prompt (list of token IDs, or
[`TokensPrompt`][vllm.inputs.data.TokensPrompt])
- An embeddings prompt ([`EmbedsPrompt`][vllm.inputs.data.EmbedsPrompt])
data_format: str For encoder-decoder models, passing a singleton prompt is shorthand for passing
"""The input data format""" `ExplicitEncoderDecoderPrompt(encoder_prompt=prompt, decoder_prompt=None)`.
"""
SingletonPrompt: TypeAlias = str | TextPrompt | TokensPrompt | EmbedsPrompt EncoderPrompt: TypeAlias = str | TextPrompt | list[int] | TokensPrompt
"""
Schema of a prompt for the encoder part of a encoder-decoder model:
- A text prompt (string or [`TextPrompt`][vllm.inputs.data.TextPrompt])
- A tokenized prompt (list of token IDs, or
[`TokensPrompt`][vllm.inputs.data.TokensPrompt])
""" """
Set of possible schemas for a single prompt:
- A text prompt ([`str`][] or [`TextPrompt`][vllm.inputs.data.TextPrompt])
- A tokenized prompt ([`TokensPrompt`][vllm.inputs.data.TokensPrompt])
- An embeddings prompt ([`EmbedsPrompt`][vllm.inputs.data.EmbedsPrompt])
Note that "singleton" is as opposed to a data structure DecoderPrompt: TypeAlias = str | TextPrompt | list[int] | TokensPrompt
which encapsulates multiple prompts, i.e. of the sort
which may be utilized for encoder/decoder models when
the user desires to express both the encoder & decoder
prompts explicitly, i.e.
[`ExplicitEncoderDecoderPrompt`][vllm.inputs.data.ExplicitEncoderDecoderPrompt]
A prompt of type [`SingletonPrompt`][vllm.inputs.data.SingletonPrompt] may be
employed as (1) input to a decoder-only model, (2) input to
the encoder of an encoder/decoder model, in the scenario
where the decoder-prompt is not specified explicitly, or
(3) as a member of a larger data structure encapsulating
more than one prompt, i.e.
[`ExplicitEncoderDecoderPrompt`][vllm.inputs.data.ExplicitEncoderDecoderPrompt]
""" """
Schema of a prompt for the decoder part of an encoder-decoder model:
- A text prompt (string or [`TextPrompt`][vllm.inputs.data.TextPrompt])
- A tokenized prompt (list of token IDs, or
[`TokensPrompt`][vllm.inputs.data.TokensPrompt])
_T1_co = TypeVar( Note:
"_T1_co", bound=SingletonPrompt, default=SingletonPrompt, covariant=True Multi-modal inputs are not supported for decoder prompts.
) """
_T2_co = TypeVar(
"_T2_co", bound=SingletonPrompt, default=SingletonPrompt, covariant=True
)
# TODO: Make fields ReadOnly once mypy supports it class ExplicitEncoderDecoderPrompt(TypedDict):
class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]):
""" """
Represents an encoder/decoder model input prompt, Schema for a pair of encoder and decoder singleton prompts.
comprising an explicit encoder prompt and a decoder prompt.
Note:
The encoder and decoder prompts, respectively, may be formatted This schema is not valid for decoder-only models.
according to any of the
[`SingletonPrompt`][vllm.inputs.data.SingletonPrompt] schemas,
and are not required to have the same schema.
Only the encoder prompt may have multi-modal data. mm_processor_kwargs
should be at the top-level, and should not be set in the encoder/decoder
prompts, since they are agnostic to the encoder/decoder.
Note that an
[`ExplicitEncoderDecoderPrompt`][vllm.inputs.data.ExplicitEncoderDecoderPrompt]
may not be used as an input to a decoder-only model,
and that the `encoder_prompt` and `decoder_prompt`
fields of this data structure themselves must be
[`SingletonPrompt`][vllm.inputs.data.SingletonPrompt] instances.
""" """
encoder_prompt: _T1_co encoder_prompt: EncoderPrompt
"""The prompt for the encoder part of the model."""
decoder_prompt: _T2_co | None decoder_prompt: DecoderPrompt | None
"""
The prompt for the decoder part of the model.
mm_processor_kwargs: NotRequired[dict[str, Any]] Passing `None` will cause the prompt to be inferred automatically.
"""
PromptType: TypeAlias = SingletonPrompt | ExplicitEncoderDecoderPrompt[Any, Any] EncoderDecoderPrompt: TypeAlias = EncoderPrompt | ExplicitEncoderDecoderPrompt
""" """
Set of possible schemas for an LLM input, including Schema for a prompt for an encoder-decoder model.
both decoder-only and encoder/decoder input types:
- A text prompt ([`str`][] or [`TextPrompt`][vllm.inputs.data.TextPrompt]) You can pass a singleton encoder prompt, in which case the decoder prompt is
- A tokenized prompt ([`TokensPrompt`][vllm.inputs.data.TokensPrompt]) considered to be `None` (i.e., infer automatically).
- An embeddings prompt ([`EmbedsPrompt`][vllm.inputs.data.EmbedsPrompt]) """
- A single data structure containing both an encoder and a decoder prompt
([`ExplicitEncoderDecoderPrompt`][vllm.inputs.data.ExplicitEncoderDecoderPrompt])
SingletonPrompt: TypeAlias = DecoderOnlyPrompt | EncoderPrompt | DecoderPrompt
"""
Schema for a single prompt. This is as opposed to a data structure
which encapsulates multiple prompts, such as
[`ExplicitEncoderDecoderPrompt`][vllm.inputs.data.ExplicitEncoderDecoderPrompt].
""" """
class TokenInputs(TypedDict): PromptType: TypeAlias = DecoderOnlyPrompt | EncoderDecoderPrompt
"""
Schema for any prompt, regardless of model type.
This is the input format accepted by most [`LLM`][vllm.entrypoints.llm.LLM] APIs.
"""
class DataPrompt(_PromptOptions):
"""
Represents generic inputs that are converted to
[`PromptType`][vllm.inputs.data.PromptType] by IO processor plugins.
"""
data: Any
"""The input data."""
data_format: str
"""The input data format."""
# Outputs of processor
class _InputOptions(TypedDict):
"""
Additional options available to all input types.
"""
cache_salt: NotRequired[str]
"""Optional cache salt to be used for prefix caching."""
class TokenInputs(_InputOptions):
"""Represents token-based inputs.""" """Represents token-based inputs."""
type: Literal["token"] type: Literal["token"]
...@@ -178,11 +203,6 @@ class TokenInputs(TypedDict): ...@@ -178,11 +203,6 @@ class TokenInputs(TypedDict):
prompt_token_ids: list[int] prompt_token_ids: list[int]
"""The token IDs of the prompt.""" """The token IDs of the prompt."""
cache_salt: NotRequired[str]
"""
Optional cache salt to be used for prefix caching.
"""
def token_inputs( def token_inputs(
prompt_token_ids: list[int], prompt_token_ids: list[int],
...@@ -198,7 +218,7 @@ def token_inputs( ...@@ -198,7 +218,7 @@ def token_inputs(
return inputs return inputs
class EmbedsInputs(TypedDict): class EmbedsInputs(_InputOptions):
"""Represents embeddings-based inputs.""" """Represents embeddings-based inputs."""
type: Literal["embeds"] type: Literal["embeds"]
...@@ -207,11 +227,6 @@ class EmbedsInputs(TypedDict): ...@@ -207,11 +227,6 @@ class EmbedsInputs(TypedDict):
prompt_embeds: torch.Tensor prompt_embeds: torch.Tensor
"""The embeddings of the prompt.""" """The embeddings of the prompt."""
cache_salt: NotRequired[str]
"""
Optional cache salt to be used for prefix caching.
"""
def embeds_inputs( def embeds_inputs(
prompt_embeds: torch.Tensor, prompt_embeds: torch.Tensor,
...@@ -229,96 +244,60 @@ def embeds_inputs( ...@@ -229,96 +244,60 @@ def embeds_inputs(
DecoderOnlyInputs: TypeAlias = TokenInputs | EmbedsInputs | MultiModalInputs DecoderOnlyInputs: TypeAlias = TokenInputs | EmbedsInputs | MultiModalInputs
""" """
The inputs in [`LLMEngine`][vllm.engine.llm_engine.LLMEngine] before they are A processed prompt from
passed to the model executor. [`InputPreprocessor`][vllm.inputs.preprocess.InputPreprocessor]
This specifies the data required for decoder-only models. which can be passed to
[`InputProcessor`][vllm.v1.engine.input_processor.InputProcessor]
for decoder-only models.
"""
EncoderInputs: TypeAlias = TokenInputs | MultiModalEncDecInputs
"""
A processed encoder prompt from
[`InputPreprocessor`][vllm.inputs.preprocess.InputPreprocessor]
which can be passed to
[`InputProcessor`][vllm.v1.engine.input_processor.InputProcessor]
for encoder-decoder models.
"""
DecoderInputs: TypeAlias = TokenInputs | MultiModalInputs
"""
A processed decoder prompt from
[`InputPreprocessor`][vllm.inputs.preprocess.InputPreprocessor]
which can be passed to
[`InputProcessor`][vllm.v1.engine.input_processor.InputProcessor]
for encoder-decoder models.
""" """
class EncoderDecoderInputs(TypedDict): class EncoderDecoderInputs(TypedDict):
""" """
The inputs in [`LLMEngine`][vllm.engine.llm_engine.LLMEngine] before they A processed pair of encoder and decoder singleton prompts.
are passed to the model executor. [`InputPreprocessor`][vllm.inputs.preprocess.InputPreprocessor]
which can be passed to
This specifies the required data for encoder-decoder models. [`InputProcessor`][vllm.v1.engine.input_processor.InputProcessor]
for encoder-decoder models.
""" """
encoder: TokenInputs | MultiModalEncDecInputs encoder: EncoderInputs
"""The inputs for the encoder portion.""" """The inputs for the encoder portion."""
decoder: TokenInputs | MultiModalInputs decoder: DecoderInputs
"""The inputs for the decoder portion.""" """The inputs for the decoder portion."""
SingletonInputs: TypeAlias = TokenInputs | EmbedsInputs | MultiModalInputs
"""
A processed [`SingletonPrompt`][vllm.inputs.data.SingletonPrompt] which can be
passed to [`Sequence`][collections.abc.Sequence].
"""
ProcessorInputs: TypeAlias = DecoderOnlyInputs | EncoderDecoderInputs ProcessorInputs: TypeAlias = DecoderOnlyInputs | EncoderDecoderInputs
""" """
The outputs from [`vllm.inputs.preprocess.InputPreprocessor`][]. A processed prompt from
[`InputPreprocessor`][vllm.inputs.preprocess.InputPreprocessor]
which can be passed to
[`InputProcessor`][vllm.v1.engine.input_processor.InputProcessor].
""" """
_T1 = TypeVar("_T1", bound=SingletonPrompt, default=SingletonPrompt)
_T2 = TypeVar("_T2", bound=SingletonPrompt, default=SingletonPrompt)
def build_explicit_enc_dec_prompt(
encoder_prompt: _T1,
decoder_prompt: _T2 | None,
mm_processor_kwargs: dict[str, Any] | None = None,
) -> ExplicitEncoderDecoderPrompt[_T1, _T2]:
if mm_processor_kwargs is None:
mm_processor_kwargs = {}
return ExplicitEncoderDecoderPrompt(
encoder_prompt=encoder_prompt,
decoder_prompt=decoder_prompt,
mm_processor_kwargs=mm_processor_kwargs,
)
def zip_enc_dec_prompts(
enc_prompts: Iterable[_T1],
dec_prompts: Iterable[_T2 | None],
mm_processor_kwargs: Iterable[dict[str, Any]] | dict[str, Any] | None = None,
) -> list[ExplicitEncoderDecoderPrompt[_T1, _T2]]:
"""
Zip encoder and decoder prompts together into a list of
[`ExplicitEncoderDecoderPrompt`][vllm.inputs.data.ExplicitEncoderDecoderPrompt]
instances.
`mm_processor_kwargs` may also be provided; if a dict is passed, the same SingletonInputs: TypeAlias = DecoderOnlyInputs | MultiModalEncDecInputs
dictionary will be used for every encoder/decoder prompt. If an iterable is
provided, it will be zipped with the encoder/decoder prompts.
"""
if mm_processor_kwargs is None:
mm_processor_kwargs = cast(dict[str, Any], {})
if isinstance(mm_processor_kwargs, dict):
return [
build_explicit_enc_dec_prompt(
encoder_prompt,
decoder_prompt,
cast(dict[str, Any], mm_processor_kwargs),
)
for (encoder_prompt, decoder_prompt) in zip(enc_prompts, dec_prompts)
]
return [
build_explicit_enc_dec_prompt(encoder_prompt, decoder_prompt, mm_proc_kwargs)
for (encoder_prompt, decoder_prompt, mm_proc_kwargs) in zip(
enc_prompts, dec_prompts, mm_processor_kwargs
)
]
def to_enc_dec_tuple_list(
enc_dec_prompts: Iterable[ExplicitEncoderDecoderPrompt[_T1, _T2]],
) -> list[tuple[_T1, _T2 | None]]:
return [
(enc_dec_prompt["encoder_prompt"], enc_dec_prompt["decoder_prompt"])
for enc_dec_prompt in enc_dec_prompts
]
@dataclass @dataclass
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import TYPE_CHECKING, Literal, NamedTuple, TypeAlias, TypedDict
from typing_extensions import TypeIs from .data import ProcessorInputs, SingletonInputs
from vllm.utils import length_from_prompt_token_ids_or_embeds
from .data import (
EmbedsPrompt,
ExplicitEncoderDecoderPrompt,
ProcessorInputs,
PromptType,
SingletonInputs,
SingletonPrompt,
TextPrompt,
TokensPrompt,
)
if TYPE_CHECKING:
import torch
class ParsedStrPrompt(TypedDict):
type: Literal["str"]
content: str
class ParsedTextPrompt(TypedDict):
type: Literal["text"]
content: TextPrompt
class ParsedTokensPrompt(TypedDict):
type: Literal["tokens"]
content: TokensPrompt
class ParsedEmbedsPrompt(TypedDict):
type: Literal["embeds"]
content: EmbedsPrompt
ParsedSingletonPrompt: TypeAlias = (
ParsedStrPrompt | ParsedTextPrompt | ParsedTokensPrompt | ParsedEmbedsPrompt
)
def parse_singleton_prompt(prompt: SingletonPrompt) -> ParsedSingletonPrompt:
if isinstance(prompt, str):
return ParsedStrPrompt(type="str", content=prompt)
elif isinstance(prompt, dict):
# Type ignores are because mypy does not correctly infer the TypedDicts
# Pyright does succeed.
if "prompt_embeds" in prompt:
return ParsedEmbedsPrompt(type="embeds", content=prompt) # type: ignore[typeddict-item]
elif "prompt_token_ids" in prompt:
return ParsedTokensPrompt(type="tokens", content=prompt) # type: ignore[typeddict-item]
elif "prompt" in prompt:
return ParsedTextPrompt(type="text", content=prompt)
raise TypeError(
"inputs must be a string, TextPrompt, TokensPrompt, or EmbedsPrompt"
)
def is_explicit_encoder_decoder_prompt(
prompt: PromptType,
) -> TypeIs[ExplicitEncoderDecoderPrompt]:
return isinstance(prompt, dict) and "encoder_prompt" in prompt
def split_enc_dec_prompt(
prompt: PromptType,
) -> tuple[SingletonPrompt, SingletonPrompt | None]:
if isinstance(prompt, str):
return prompt, None
if "encoder_prompt" in prompt and "decoder_prompt" in prompt:
# NOTE: This passes pyright but not mypy
return (
prompt["encoder_prompt"], # type: ignore[typeddict-item]
prompt["decoder_prompt"], # type: ignore[typeddict-item]
)
return prompt, None
def split_enc_dec_inputs( def split_enc_dec_inputs(
...@@ -96,30 +15,3 @@ def split_enc_dec_inputs( ...@@ -96,30 +15,3 @@ def split_enc_dec_inputs(
) )
return None, inputs return None, inputs
class PromptComponents(NamedTuple):
text: str | None = None
token_ids: list[int] | None = None
embeds: "torch.Tensor | None" = None
def get_prompt_components(prompt: PromptType) -> PromptComponents:
if isinstance(prompt, str):
return PromptComponents(text=prompt)
if encoder_prompt := prompt.get("encoder_prompt"):
return get_prompt_components(encoder_prompt) # type: ignore[arg-type]
return PromptComponents(
text=prompt.get("prompt"), # type: ignore[arg-type]
token_ids=prompt.get("prompt_token_ids"), # type: ignore[arg-type]
embeds=prompt.get("prompt_embeds"),
)
def get_prompt_len(prompt: TokensPrompt | EmbedsPrompt):
return length_from_prompt_token_ids_or_embeds(
prompt.get("prompt_token_ids"), # type: ignore[arg-type]
prompt.get("prompt_embeds"), # type: ignore[arg-type]
)
...@@ -2,43 +2,51 @@ ...@@ -2,43 +2,51 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Mapping from collections.abc import Mapping
from typing import Any from typing import Any, overload
from typing_extensions import assert_never from typing_extensions import assert_never
from vllm.config import ModelConfig, ObservabilityConfig from vllm.config import ModelConfig, ObservabilityConfig
from vllm.inputs.parse import split_enc_dec_prompt
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.multimodal.cache import BaseMultiModalProcessorCache from vllm.multimodal.cache import BaseMultiModalProcessorCache
from vllm.multimodal.inputs import ( from vllm.multimodal.inputs import (
MultiModalDataDict, MultiModalDataDict,
MultiModalEncDecInputs,
MultiModalInputs, MultiModalInputs,
MultiModalUUIDDict, MultiModalUUIDDict,
) )
from vllm.multimodal.processing import BaseMultiModalProcessor from vllm.multimodal.processing import BaseMultiModalProcessor
from vllm.renderers import renderer_from_config from vllm.renderers import renderer_from_config
from vllm.renderers.inputs import (
DecoderDictPrompt,
DecoderOnlyDictPrompt,
DictPrompt,
EncoderDecoderDictPrompt,
EncoderDictPrompt,
SingletonDictPrompt,
TokPrompt,
)
from vllm.renderers.inputs.preprocess import parse_dec_only_prompt, parse_enc_dec_prompt
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
from vllm.utils.jsontree import json_iter_leaves from vllm.utils.jsontree import json_iter_leaves
from vllm.v1.metrics.stats import MultiModalCacheStats from vllm.v1.metrics.stats import MultiModalCacheStats
from .data import ( from .data import (
DecoderInputs,
DecoderOnlyInputs, DecoderOnlyInputs,
EmbedsInputs, EmbedsInputs,
EmbedsPrompt, EmbedsPrompt,
EncoderDecoderInputs, EncoderDecoderInputs,
EncoderInputs,
ProcessorInputs, ProcessorInputs,
PromptType, PromptType,
SingletonInputs, SingletonInputs,
SingletonPrompt,
TextPrompt, TextPrompt,
TokenInputs, TokenInputs,
TokensPrompt, TokensPrompt,
embeds_inputs, embeds_inputs,
token_inputs, token_inputs,
) )
from .parse import is_explicit_encoder_decoder_prompt, parse_singleton_prompt
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -328,9 +336,36 @@ class InputPreprocessor: ...@@ -328,9 +336,36 @@ class InputPreprocessor:
return inputs return inputs
@overload
def _prompt_to_llm_inputs( def _prompt_to_llm_inputs(
self, self,
prompt: SingletonPrompt, prompt: EncoderDictPrompt,
tokenization_kwargs: dict[str, Any] | None = None,
*,
mm_uuids: MultiModalUUIDDict | None = None,
) -> EncoderInputs: ...
@overload
def _prompt_to_llm_inputs( # type: ignore[misc]
self,
prompt: DecoderDictPrompt,
tokenization_kwargs: dict[str, Any] | None = None,
*,
mm_uuids: MultiModalUUIDDict | None = None,
) -> DecoderInputs: ...
@overload
def _prompt_to_llm_inputs( # type: ignore[misc]
self,
prompt: DecoderOnlyDictPrompt,
tokenization_kwargs: dict[str, Any] | None = None,
*,
mm_uuids: MultiModalUUIDDict | None = None,
) -> DecoderOnlyInputs: ...
def _prompt_to_llm_inputs(
self,
prompt: SingletonDictPrompt,
tokenization_kwargs: dict[str, Any] | None = None, tokenization_kwargs: dict[str, Any] | None = None,
*, *,
mm_uuids: MultiModalUUIDDict | None = None, mm_uuids: MultiModalUUIDDict | None = None,
...@@ -346,34 +381,25 @@ class InputPreprocessor: ...@@ -346,34 +381,25 @@ class InputPreprocessor:
* [`SingletonInputs`][vllm.inputs.data.SingletonInputs] instance * [`SingletonInputs`][vllm.inputs.data.SingletonInputs] instance
""" """
parsed = parse_singleton_prompt(prompt) if "prompt_embeds" in prompt:
return self._process_embeds(prompt) # type: ignore[arg-type]
if parsed["type"] == "embeds": if "prompt_token_ids" in prompt:
return self._process_embeds(parsed["content"])
if parsed["type"] == "tokens":
return self._process_tokens( return self._process_tokens(
parsed["content"], prompt, # type: ignore[arg-type]
mm_uuids=mm_uuids, mm_uuids=mm_uuids,
) )
if parsed["type"] == "text":
return self._process_text( if "prompt" in prompt:
parsed["content"],
tokenization_kwargs=tokenization_kwargs,
mm_uuids=mm_uuids,
)
if parsed["type"] == "str":
return self._process_text( return self._process_text(
TextPrompt(prompt=parsed["content"]), prompt, # type: ignore[arg-type]
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
mm_uuids=mm_uuids, mm_uuids=mm_uuids,
) )
assert_never(parsed) assert_never(prompt) # type: ignore[arg-type]
def _validate_enc_inputs( def _validate_enc_inputs(self, inputs: SingletonInputs) -> EncoderInputs:
self,
inputs: SingletonInputs,
) -> TokenInputs | MultiModalEncDecInputs:
if inputs["type"] == "embeds": if inputs["type"] == "embeds":
raise ValueError( raise ValueError(
"Embedding inputs are not supported for encoder-decoder models" "Embedding inputs are not supported for encoder-decoder models"
...@@ -387,10 +413,7 @@ class InputPreprocessor: ...@@ -387,10 +413,7 @@ class InputPreprocessor:
return inputs # type: ignore[return-value] return inputs # type: ignore[return-value]
def _validate_dec_inputs( def _validate_dec_inputs(self, inputs: SingletonInputs) -> DecoderInputs:
self,
inputs: SingletonInputs,
) -> TokenInputs | MultiModalInputs:
if inputs["type"] == "embeds": if inputs["type"] == "embeds":
raise ValueError( raise ValueError(
"Embedding inputs are not supported for encoder-decoder models" "Embedding inputs are not supported for encoder-decoder models"
...@@ -403,14 +426,15 @@ class InputPreprocessor: ...@@ -403,14 +426,15 @@ class InputPreprocessor:
encoder_inputs: SingletonInputs, encoder_inputs: SingletonInputs,
decoder_inputs: SingletonInputs | None = None, decoder_inputs: SingletonInputs | None = None,
) -> EncoderDecoderInputs: ) -> EncoderDecoderInputs:
if decoder_inputs is None:
decoder_inputs = encoder_inputs
enc_inputs = self._validate_enc_inputs(encoder_inputs) enc_inputs = self._validate_enc_inputs(encoder_inputs)
dec_inputs = self._validate_dec_inputs(decoder_inputs)
enc_inputs_new: TokenInputs | MultiModalEncDecInputs if decoder_inputs is None:
dec_inputs_new: TokenInputs | MultiModalInputs dec_inputs: DecoderInputs = enc_inputs # type: ignore[assignment]
else:
dec_inputs = self._validate_dec_inputs(decoder_inputs)
enc_inputs_new: EncoderInputs
dec_inputs_new: DecoderInputs
if enc_inputs["type"] == "multimodal": if enc_inputs["type"] == "multimodal":
enc_inputs_new = token_inputs(enc_inputs["encoder_prompt_token_ids"]) enc_inputs_new = token_inputs(enc_inputs["encoder_prompt_token_ids"])
...@@ -437,7 +461,7 @@ class InputPreprocessor: ...@@ -437,7 +461,7 @@ class InputPreprocessor:
def _process_encoder_decoder_prompt( def _process_encoder_decoder_prompt(
self, self,
prompt: PromptType, prompt: EncoderDecoderDictPrompt,
tokenization_kwargs: dict[str, Any] | None = None, tokenization_kwargs: dict[str, Any] | None = None,
*, *,
mm_uuids: MultiModalUUIDDict | None = None, mm_uuids: MultiModalUUIDDict | None = None,
...@@ -448,24 +472,6 @@ class InputPreprocessor: ...@@ -448,24 +472,6 @@ class InputPreprocessor:
[`EncoderDecoderInputs`][vllm.inputs.data.EncoderDecoderInputs] [`EncoderDecoderInputs`][vllm.inputs.data.EncoderDecoderInputs]
instance. instance.
There are two types of input prompts:
singleton prompts which carry only the
encoder prompt, and explicit encoder/decoder
prompts which carry both the encoder and the
decoder prompts as member variables.
This function handles the following scenarios:
* Singleton encoder prompt: extract encoder prompt
token ids & infer default decoder prompt token ids
* Explicit encoder/decoder prompt: extract encoder
and decoder prompt token ids
Note that for Explicit encoder/decoder prompts,
each sub-prompt (encoder or decoder prompt) can
have any possible singleton type; thus this
method relies on helper functions to obtain
token ids for the sub-prompts.
Arguments: Arguments:
* prompt: an input prompt * prompt: an input prompt
...@@ -475,7 +481,8 @@ class InputPreprocessor: ...@@ -475,7 +481,8 @@ class InputPreprocessor:
* [`EncoderDecoderInputs`][vllm.inputs.data.EncoderDecoderInputs] * [`EncoderDecoderInputs`][vllm.inputs.data.EncoderDecoderInputs]
instance instance
""" """
encoder_prompt, decoder_prompt = split_enc_dec_prompt(prompt) encoder_prompt = prompt["encoder_prompt"]
decoder_prompt = prompt["decoder_prompt"]
return self._build_enc_dec_inputs( return self._build_enc_dec_inputs(
encoder_inputs=self._prompt_to_llm_inputs( encoder_inputs=self._prompt_to_llm_inputs(
...@@ -495,7 +502,7 @@ class InputPreprocessor: ...@@ -495,7 +502,7 @@ class InputPreprocessor:
def _process_decoder_only_prompt( def _process_decoder_only_prompt(
self, self,
prompt: SingletonPrompt, prompt: DecoderOnlyDictPrompt,
tokenization_kwargs: dict[str, Any] | None = None, tokenization_kwargs: dict[str, Any] | None = None,
*, *,
mm_uuids: MultiModalUUIDDict | None = None, mm_uuids: MultiModalUUIDDict | None = None,
...@@ -521,7 +528,7 @@ class InputPreprocessor: ...@@ -521,7 +528,7 @@ class InputPreprocessor:
def _preprocess( def _preprocess(
self, self,
prompt: PromptType, prompt: PromptType | DictPrompt | TokPrompt,
tokenization_kwargs: dict[str, Any] | None = None, tokenization_kwargs: dict[str, Any] | None = None,
*, *,
mm_uuids: MultiModalUUIDDict | None = None, mm_uuids: MultiModalUUIDDict | None = None,
...@@ -530,25 +537,20 @@ class InputPreprocessor: ...@@ -530,25 +537,20 @@ class InputPreprocessor:
# Encoder-decoder model requires special mapping of # Encoder-decoder model requires special mapping of
# input prompts to encoder & decoder. # input prompts to encoder & decoder.
return self._process_encoder_decoder_prompt( return self._process_encoder_decoder_prompt(
prompt, parse_enc_dec_prompt(prompt),
tokenization_kwargs, tokenization_kwargs,
mm_uuids=mm_uuids, mm_uuids=mm_uuids,
) )
if is_explicit_encoder_decoder_prompt(prompt):
raise ValueError(
"Cannot pass encoder-decoder prompt to decoder-only models"
)
return self._process_decoder_only_prompt( return self._process_decoder_only_prompt(
prompt, parse_dec_only_prompt(prompt),
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
mm_uuids=mm_uuids, mm_uuids=mm_uuids,
) )
def preprocess( def preprocess(
self, self,
prompt: PromptType, prompt: PromptType | DictPrompt | TokPrompt,
tokenization_kwargs: dict[str, Any] | None = None, tokenization_kwargs: dict[str, Any] | None = None,
*, *,
mm_uuids: MultiModalUUIDDict | None = None, mm_uuids: MultiModalUUIDDict | None = None,
......
...@@ -20,7 +20,7 @@ from typing import ( ...@@ -20,7 +20,7 @@ from typing import (
import numpy as np import numpy as np
from PIL.Image import Image from PIL.Image import Image
from typing_extensions import NotRequired, TypeVar from typing_extensions import TypeVar
from vllm.utils.collection_utils import is_list_of from vllm.utils.collection_utils import is_list_of
from vllm.utils.import_utils import LazyLoader from vllm.utils.import_utils import LazyLoader
...@@ -32,9 +32,13 @@ if TYPE_CHECKING: ...@@ -32,9 +32,13 @@ if TYPE_CHECKING:
import torch import torch
import torch.types import torch.types
from transformers.feature_extraction_utils import BatchFeature from transformers.feature_extraction_utils import BatchFeature
from vllm.inputs.data import _InputOptions
else: else:
torch = LazyLoader("torch", globals(), "torch") torch = LazyLoader("torch", globals(), "torch")
_InputOptions = dict
_T = TypeVar("_T") _T = TypeVar("_T")
HfImageItem: TypeAlias = Union["Image", np.ndarray, "torch.Tensor"] HfImageItem: TypeAlias = Union["Image", np.ndarray, "torch.Tensor"]
...@@ -1059,7 +1063,7 @@ A dictionary containing per-item placeholder ranges for each modality. ...@@ -1059,7 +1063,7 @@ A dictionary containing per-item placeholder ranges for each modality.
""" """
class MultiModalInputs(TypedDict): class MultiModalInputs(_InputOptions):
""" """
Represents the outputs of Represents the outputs of
[`BaseMultiModalProcessor`][vllm.multimodal.processing.BaseMultiModalProcessor], [`BaseMultiModalProcessor`][vllm.multimodal.processing.BaseMultiModalProcessor],
...@@ -1084,11 +1088,6 @@ class MultiModalInputs(TypedDict): ...@@ -1084,11 +1088,6 @@ class MultiModalInputs(TypedDict):
`prompt_token_ids`. `prompt_token_ids`.
""" """
cache_salt: NotRequired[str]
"""
Optional cache salt to be used for prefix caching.
"""
class MultiModalEncDecInputs(MultiModalInputs): class MultiModalEncDecInputs(MultiModalInputs):
""" """
......
...@@ -19,6 +19,7 @@ if TYPE_CHECKING: ...@@ -19,6 +19,7 @@ if TYPE_CHECKING:
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.inputs import ProcessorInputs, PromptType from vllm.inputs import ProcessorInputs, PromptType
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.renderers.inputs import DictPrompt, TokPrompt
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.utils.argparse_utils import FlexibleArgumentParser from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.v1.attention.selector import AttentionSelectorConfig from vllm.v1.attention.selector import AttentionSelectorConfig
...@@ -565,7 +566,7 @@ class Platform: ...@@ -565,7 +566,7 @@ class Platform:
@classmethod @classmethod
def validate_request( def validate_request(
cls, cls,
prompt: "PromptType", prompt: "PromptType | DictPrompt | TokPrompt",
params: "SamplingParams | PoolingParams", params: "SamplingParams | PoolingParams",
processed_inputs: "ProcessorInputs", processed_inputs: "ProcessorInputs",
) -> None: ) -> None:
......
...@@ -9,11 +9,12 @@ from vllm.entrypoints.chat_utils import ( ...@@ -9,11 +9,12 @@ from vllm.entrypoints.chat_utils import (
parse_chat_messages, parse_chat_messages,
parse_chat_messages_async, parse_chat_messages_async,
) )
from vllm.inputs import EmbedsPrompt, TextPrompt, TokensPrompt
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.tokenizers import cached_get_tokenizer from vllm.tokenizers import cached_get_tokenizer
from vllm.tokenizers.deepseek_v32 import DeepseekV32Tokenizer from vllm.tokenizers.deepseek_v32 import DeepseekV32Tokenizer
from .inputs import DictPrompt
from .inputs.preprocess import parse_dec_only_prompt
from .params import ChatParams from .params import ChatParams
from .protocol import BaseRenderer from .protocol import BaseRenderer
...@@ -61,7 +62,7 @@ class DeepseekV32Renderer(BaseRenderer): ...@@ -61,7 +62,7 @@ class DeepseekV32Renderer(BaseRenderer):
self, self,
messages: list[ChatCompletionMessageParam], messages: list[ChatCompletionMessageParam],
params: ChatParams, params: ChatParams,
) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt | EmbedsPrompt]: ) -> tuple[list[ConversationMessage], DictPrompt]:
tokenizer = self.get_tokenizer() tokenizer = self.get_tokenizer()
conversation, mm_data, mm_uuids = parse_chat_messages( conversation, mm_data, mm_uuids = parse_chat_messages(
messages, messages,
...@@ -75,7 +76,7 @@ class DeepseekV32Renderer(BaseRenderer): ...@@ -75,7 +76,7 @@ class DeepseekV32Renderer(BaseRenderer):
**params.get_apply_chat_template_kwargs(), **params.get_apply_chat_template_kwargs(),
) )
prompt = self.render_completion(prompt_raw) prompt = parse_dec_only_prompt(prompt_raw)
if mm_data is not None: if mm_data is not None:
prompt["multi_modal_data"] = mm_data prompt["multi_modal_data"] = mm_data
if mm_uuids is not None: if mm_uuids is not None:
...@@ -87,7 +88,7 @@ class DeepseekV32Renderer(BaseRenderer): ...@@ -87,7 +88,7 @@ class DeepseekV32Renderer(BaseRenderer):
self, self,
messages: list[ChatCompletionMessageParam], messages: list[ChatCompletionMessageParam],
params: ChatParams, params: ChatParams,
) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt | EmbedsPrompt]: ) -> tuple[list[ConversationMessage], DictPrompt]:
tokenizer = self.get_tokenizer() tokenizer = self.get_tokenizer()
conversation, mm_data, mm_uuids = await parse_chat_messages_async( conversation, mm_data, mm_uuids = await parse_chat_messages_async(
messages, messages,
...@@ -101,7 +102,7 @@ class DeepseekV32Renderer(BaseRenderer): ...@@ -101,7 +102,7 @@ class DeepseekV32Renderer(BaseRenderer):
**params.get_apply_chat_template_kwargs(), **params.get_apply_chat_template_kwargs(),
) )
prompt = self.render_completion(prompt_raw) prompt = parse_dec_only_prompt(prompt_raw)
if mm_data is not None: if mm_data is not None:
prompt["multi_modal_data"] = mm_data prompt["multi_modal_data"] = mm_data
if mm_uuids is not None: if mm_uuids is not None:
......
...@@ -9,11 +9,12 @@ from vllm.entrypoints.chat_utils import ( ...@@ -9,11 +9,12 @@ from vllm.entrypoints.chat_utils import (
parse_chat_messages, parse_chat_messages,
parse_chat_messages_async, parse_chat_messages_async,
) )
from vllm.inputs import EmbedsPrompt, TextPrompt, TokensPrompt
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.tokenizers import cached_get_tokenizer from vllm.tokenizers import cached_get_tokenizer
from vllm.tokenizers.grok2 import Grok2Tokenizer from vllm.tokenizers.grok2 import Grok2Tokenizer
from .inputs import DictPrompt
from .inputs.preprocess import parse_dec_only_prompt
from .params import ChatParams from .params import ChatParams
from .protocol import BaseRenderer from .protocol import BaseRenderer
...@@ -61,7 +62,7 @@ class Grok2Renderer(BaseRenderer): ...@@ -61,7 +62,7 @@ class Grok2Renderer(BaseRenderer):
self, self,
messages: list[ChatCompletionMessageParam], messages: list[ChatCompletionMessageParam],
params: ChatParams, params: ChatParams,
) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt | EmbedsPrompt]: ) -> tuple[list[ConversationMessage], DictPrompt]:
tokenizer = self.get_tokenizer() tokenizer = self.get_tokenizer()
conversation, mm_data, mm_uuids = parse_chat_messages( conversation, mm_data, mm_uuids = parse_chat_messages(
messages, messages,
...@@ -75,7 +76,7 @@ class Grok2Renderer(BaseRenderer): ...@@ -75,7 +76,7 @@ class Grok2Renderer(BaseRenderer):
**params.get_apply_chat_template_kwargs(), **params.get_apply_chat_template_kwargs(),
) )
prompt = self.render_completion(prompt_raw) prompt = parse_dec_only_prompt(prompt_raw)
if mm_data is not None: if mm_data is not None:
prompt["multi_modal_data"] = mm_data prompt["multi_modal_data"] = mm_data
if mm_uuids is not None: if mm_uuids is not None:
...@@ -87,7 +88,7 @@ class Grok2Renderer(BaseRenderer): ...@@ -87,7 +88,7 @@ class Grok2Renderer(BaseRenderer):
self, self,
messages: list[ChatCompletionMessageParam], messages: list[ChatCompletionMessageParam],
params: ChatParams, params: ChatParams,
) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt | EmbedsPrompt]: ) -> tuple[list[ConversationMessage], DictPrompt]:
tokenizer = self.get_tokenizer() tokenizer = self.get_tokenizer()
conversation, mm_data, mm_uuids = await parse_chat_messages_async( conversation, mm_data, mm_uuids = await parse_chat_messages_async(
messages, messages,
...@@ -101,7 +102,7 @@ class Grok2Renderer(BaseRenderer): ...@@ -101,7 +102,7 @@ class Grok2Renderer(BaseRenderer):
**params.get_apply_chat_template_kwargs(), **params.get_apply_chat_template_kwargs(),
) )
prompt = self.render_completion(prompt_raw) prompt = parse_dec_only_prompt(prompt_raw)
if mm_data is not None: if mm_data is not None:
prompt["multi_modal_data"] = mm_data prompt["multi_modal_data"] = mm_data
if mm_uuids is not None: if mm_uuids is not None:
......
...@@ -25,7 +25,6 @@ from vllm.entrypoints.chat_utils import ( ...@@ -25,7 +25,6 @@ from vllm.entrypoints.chat_utils import (
parse_chat_messages, parse_chat_messages,
parse_chat_messages_async, parse_chat_messages_async,
) )
from vllm.inputs import EmbedsPrompt, TextPrompt, TokensPrompt
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.tokenizers import cached_get_tokenizer from vllm.tokenizers import cached_get_tokenizer
from vllm.tokenizers.hf import CachedHfTokenizer, HfTokenizer from vllm.tokenizers.hf import CachedHfTokenizer, HfTokenizer
...@@ -33,6 +32,8 @@ from vllm.transformers_utils.chat_templates import get_chat_template_fallback_pa ...@@ -33,6 +32,8 @@ from vllm.transformers_utils.chat_templates import get_chat_template_fallback_pa
from vllm.transformers_utils.processor import cached_get_processor from vllm.transformers_utils.processor import cached_get_processor
from vllm.utils.func_utils import supports_kw from vllm.utils.func_utils import supports_kw
from .inputs import DictPrompt
from .inputs.preprocess import parse_dec_only_prompt
from .params import ChatParams from .params import ChatParams
from .protocol import BaseRenderer from .protocol import BaseRenderer
...@@ -632,7 +633,7 @@ class HfRenderer(BaseRenderer): ...@@ -632,7 +633,7 @@ class HfRenderer(BaseRenderer):
self, self,
messages: list[ChatCompletionMessageParam], messages: list[ChatCompletionMessageParam],
params: ChatParams, params: ChatParams,
) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt | EmbedsPrompt]: ) -> tuple[list[ConversationMessage], DictPrompt]:
model_config = self.config model_config = self.config
tokenizer = self.get_tokenizer() tokenizer = self.get_tokenizer()
...@@ -674,7 +675,7 @@ class HfRenderer(BaseRenderer): ...@@ -674,7 +675,7 @@ class HfRenderer(BaseRenderer):
video_placeholder, video_placeholder,
) )
prompt = self.render_completion(prompt_raw) prompt = parse_dec_only_prompt(prompt_raw)
if mm_data is not None: if mm_data is not None:
prompt["multi_modal_data"] = mm_data prompt["multi_modal_data"] = mm_data
if mm_uuids is not None: if mm_uuids is not None:
...@@ -686,7 +687,7 @@ class HfRenderer(BaseRenderer): ...@@ -686,7 +687,7 @@ class HfRenderer(BaseRenderer):
self, self,
messages: list[ChatCompletionMessageParam], messages: list[ChatCompletionMessageParam],
params: ChatParams, params: ChatParams,
) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt | EmbedsPrompt]: ) -> tuple[list[ConversationMessage], DictPrompt]:
model_config = self.config model_config = self.config
tokenizer = self.get_tokenizer() tokenizer = self.get_tokenizer()
...@@ -726,7 +727,7 @@ class HfRenderer(BaseRenderer): ...@@ -726,7 +727,7 @@ class HfRenderer(BaseRenderer):
video_placeholder, video_placeholder,
) )
prompt = self.render_completion(prompt_raw) prompt = parse_dec_only_prompt(prompt_raw)
if mm_data is not None: if mm_data is not None:
prompt["multi_modal_data"] = mm_data prompt["multi_modal_data"] = mm_data
if mm_uuids is not None: if mm_uuids is not None:
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from .preprocess import (
DecoderDictPrompt,
DecoderOnlyDictPrompt,
DictPrompt,
EncoderDecoderDictPrompt,
EncoderDictPrompt,
SingletonDictPrompt,
)
from .tokenize import (
DecoderOnlyTokPrompt,
DecoderTokPrompt,
EncoderDecoderTokPrompt,
EncoderTokPrompt,
SingletonTokPrompt,
TokPrompt,
)
__all__ = [
"DecoderOnlyDictPrompt",
"EncoderDictPrompt",
"DecoderDictPrompt",
"EncoderDecoderDictPrompt",
"SingletonDictPrompt",
"DictPrompt",
"DecoderOnlyTokPrompt",
"EncoderTokPrompt",
"DecoderTokPrompt",
"EncoderDecoderTokPrompt",
"SingletonTokPrompt",
"TokPrompt",
]
"""
Schemas and utilites for preprocessing inputs.
"""
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Sequence
from typing import TYPE_CHECKING, NamedTuple, TypeAlias, TypedDict, overload
from vllm.inputs import (
EmbedsPrompt,
ExplicitEncoderDecoderPrompt,
PromptType,
SingletonPrompt,
TextPrompt,
TokensPrompt,
)
from vllm.utils import length_from_prompt_token_ids_or_embeds
from vllm.utils.collection_utils import is_list_of
if TYPE_CHECKING:
import torch
from vllm.config import ModelConfig
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
@overload
def prompt_to_seq(
prompt_or_prompts: SingletonPrompt | bytes | Sequence[SingletonPrompt | bytes],
) -> Sequence[SingletonPrompt]: ...
@overload
def prompt_to_seq( # type: ignore[misc]
prompt_or_prompts: ExplicitEncoderDecoderPrompt
| Sequence[ExplicitEncoderDecoderPrompt],
) -> Sequence[ExplicitEncoderDecoderPrompt]: ...
@overload
def prompt_to_seq( # type: ignore[misc]
prompt_or_prompts: PromptType | Sequence[PromptType],
) -> Sequence[PromptType]: ...
def prompt_to_seq(
prompt_or_prompts: PromptType | bytes | Sequence[PromptType | bytes],
) -> Sequence[PromptType]:
if isinstance(prompt_or_prompts, (dict, str, bytes)) or (
len(prompt_or_prompts) > 0 and is_list_of(prompt_or_prompts, int)
):
return [prompt_or_prompts] # type: ignore[list-item]
return prompt_or_prompts # type: ignore[return-value]
def conversation_to_seq(
conversation_or_conversations: list["ChatCompletionMessageParam"]
| Sequence[list["ChatCompletionMessageParam"]],
) -> Sequence[list["ChatCompletionMessageParam"]]:
if len(conversation_or_conversations) > 0 and is_list_of(
conversation_or_conversations, dict
):
return [conversation_or_conversations] # type: ignore[list-item]
return conversation_or_conversations # type: ignore[return-value]
DecoderOnlyDictPrompt: TypeAlias = TextPrompt | TokensPrompt | EmbedsPrompt
"""
A [`DecoderOnlyPrompt`][vllm.inputs.data.DecoderOnlyPrompt]
that has been standardized into a dictionary.
"""
EncoderDictPrompt: TypeAlias = TextPrompt | TokensPrompt
"""
A [`EncoderPrompt`][vllm.inputs.data.EncoderPrompt]
that has been standardized into a dictionary.
"""
DecoderDictPrompt: TypeAlias = TextPrompt | TokensPrompt
"""
A [`DecoderPrompt`][vllm.inputs.data.DecoderPrompt]
that has been standardized into a dictionary.
"""
class EncoderDecoderDictPrompt(TypedDict):
"""
A [`EncoderDecoderPrompt`][vllm.inputs.data.EncoderDecoderPrompt]
that has been standardized into a dictionary.
"""
encoder_prompt: EncoderDictPrompt
decoder_prompt: DecoderDictPrompt | None
SingletonDictPrompt: TypeAlias = (
DecoderOnlyDictPrompt | EncoderDictPrompt | DecoderDictPrompt
)
"""
A [`SingletonPrompt`][vllm.inputs.data.SingletonPrompt]
that has been standardized into a dictionary.
"""
DictPrompt: TypeAlias = DecoderOnlyDictPrompt | EncoderDecoderDictPrompt
"""
A [`PromptType`][vllm.inputs.data.PromptType]
that has been standardized into a dictionary.
"""
def parse_dec_only_prompt(prompt: object) -> DecoderOnlyDictPrompt:
"""
Parse a prompt for a decoder-only model and normalize it to a dictionary.
"""
if isinstance(prompt, str):
return TextPrompt(prompt=prompt)
if isinstance(prompt, list):
if not is_list_of(prompt, int):
raise TypeError("Token prompt should be a list of integers")
return TokensPrompt(prompt_token_ids=prompt)
if isinstance(prompt, dict):
if "encoder_prompt" in prompt:
raise TypeError("Cannot pass encoder-decoder prompt to decoder-only models")
if (
"prompt" in prompt
or "prompt_token_ids" in prompt
or "prompt_embeds" in prompt
):
return prompt # type: ignore[return-value]
raise TypeError("Prompt dictionary must contain text, tokens, or embeddings")
raise TypeError("Prompt should be a string, list of tokens, or dictionary")
def _parse_enc_prompt(prompt: object) -> EncoderDictPrompt:
if isinstance(prompt, str):
return TextPrompt(prompt=prompt)
if isinstance(prompt, list):
if not is_list_of(prompt, int):
raise TypeError("Token prompt should be a list of integers")
return TokensPrompt(prompt_token_ids=prompt)
if isinstance(prompt, dict):
if "prompt_embeds" in prompt:
raise TypeError("Cannot pass embeddings prompt to encoder-decoder models")
if "prompt" in prompt or "prompt_token_ids" in prompt:
return prompt # type: ignore[return-value]
raise TypeError("Prompt dictionary must contain text or tokens")
raise TypeError("Prompt should be a string, list of tokens, or dictionary")
def _parse_dec_prompt(prompt: object) -> DecoderDictPrompt:
if isinstance(prompt, str):
return TextPrompt(prompt=prompt)
if isinstance(prompt, list):
if not is_list_of(prompt, int):
raise TypeError("Token prompt should be a list of integers")
return TokensPrompt(prompt_token_ids=prompt)
if isinstance(prompt, dict):
if "prompt_embeds" in prompt:
raise TypeError("Cannot pass embeddings prompt to encoder-decoder models")
if (
"multi_modal_data" in prompt
or "mm_processor_kwargs" in prompt
or "multi_modal_uuids" in prompt
):
raise TypeError("Cannot pass multi-modal inputs to decoder prompt")
if "prompt" in prompt or "prompt_token_ids" in prompt:
return prompt # type: ignore[return-value]
raise TypeError("Prompt dictionary must contain text or tokens")
raise TypeError("Prompt should be a string, list of tokens, or dictionary")
def parse_enc_dec_prompt(prompt: object) -> EncoderDecoderDictPrompt:
"""
Parse a prompt for an encoder-decoder model and normalize it to a dictionary.
"""
if isinstance(prompt, dict) and "encoder_prompt" in prompt:
enc_prompt: object = prompt["encoder_prompt"] # type: ignore[typeddict-item]
dec_prompt: object | None = prompt["decoder_prompt"] # type: ignore[typeddict-item]
else:
enc_prompt = prompt
dec_prompt = None
return EncoderDecoderDictPrompt(
encoder_prompt=_parse_enc_prompt(enc_prompt),
decoder_prompt=None if dec_prompt is None else _parse_dec_prompt(dec_prompt),
)
def parse_model_prompt(model_config: "ModelConfig", prompt: object):
if model_config.is_encoder_decoder:
return parse_enc_dec_prompt(prompt)
return parse_dec_only_prompt(prompt)
class PromptComponents(NamedTuple):
text: str | None = None
token_ids: list[int] | None = None
embeds: "torch.Tensor | None" = None
def extract_prompt_components(
model_config: "ModelConfig",
prompt: object,
) -> PromptComponents:
target_prompt = (
parse_enc_dec_prompt(prompt)["encoder_prompt"]
if model_config.is_encoder_decoder
else parse_dec_only_prompt(prompt)
)
return PromptComponents(
text=target_prompt.get("prompt"),
token_ids=target_prompt.get("prompt_token_ids"), # type: ignore[arg-type]
embeds=target_prompt.get("prompt_embeds"),
)
def extract_prompt_len(model_config: "ModelConfig", prompt: object):
target_prompt = (
parse_enc_dec_prompt(prompt)["encoder_prompt"]
if model_config.is_encoder_decoder
else parse_dec_only_prompt(prompt)
)
return length_from_prompt_token_ids_or_embeds(
target_prompt.get("prompt_token_ids"), # type: ignore[arg-type]
target_prompt.get("prompt_embeds"),
)
"""
Schemas and utilites for tokenization inputs.
"""
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import TypeAlias, TypedDict
from vllm.inputs import EmbedsPrompt, TokensPrompt
DecoderOnlyTokPrompt: TypeAlias = TokensPrompt | EmbedsPrompt
"""
A [`DecoderOnlyDictPrompt`][vllm.renderers.inputs.preprocess.DecoderOnlyDictPrompt]
that has been tokenized.
"""
EncoderTokPrompt: TypeAlias = TokensPrompt
"""
A [`EncoderDictPrompt`][vllm.renderers.inputs.preprocess.EncoderDictPrompt]
that has been tokenized.
"""
DecoderTokPrompt: TypeAlias = TokensPrompt
"""
A [`DecoderDictPrompt`][vllm.renderers.inputs.preprocess.DecoderDictPrompt]
that has been tokenized.
"""
class EncoderDecoderTokPrompt(TypedDict):
"""
A
[`EncoderDecoderDictPrompt`][vllm.renderers.inputs.preprocess.EncoderDecoderDictPrompt]
that has been tokenized.
"""
encoder_prompt: EncoderTokPrompt
decoder_prompt: DecoderTokPrompt | None
SingletonTokPrompt: TypeAlias = (
DecoderOnlyTokPrompt | EncoderTokPrompt | DecoderTokPrompt
)
"""
A [`SingletonDictPrompt`][vllm.renderers.inputs.preprocess.SingletonDictPrompt]
that has been tokenized.
"""
TokPrompt: TypeAlias = DecoderOnlyTokPrompt | EncoderDecoderTokPrompt
"""
A [`DictPrompt`][vllm.renderers.inputs.preprocess.DictPrompt]
that has been tokenized.
"""
...@@ -10,12 +10,13 @@ from vllm.entrypoints.chat_utils import ( ...@@ -10,12 +10,13 @@ from vllm.entrypoints.chat_utils import (
parse_chat_messages, parse_chat_messages,
parse_chat_messages_async, parse_chat_messages_async,
) )
from vllm.inputs import EmbedsPrompt, TextPrompt, TokensPrompt
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.tokenizers import cached_get_tokenizer from vllm.tokenizers import cached_get_tokenizer
from vllm.tokenizers.mistral import MistralTokenizer from vllm.tokenizers.mistral import MistralTokenizer
from vllm.utils.async_utils import make_async from vllm.utils.async_utils import make_async
from .inputs import DictPrompt
from .inputs.preprocess import parse_dec_only_prompt
from .params import ChatParams from .params import ChatParams
from .protocol import BaseRenderer from .protocol import BaseRenderer
...@@ -95,7 +96,7 @@ class MistralRenderer(BaseRenderer): ...@@ -95,7 +96,7 @@ class MistralRenderer(BaseRenderer):
self, self,
messages: list[ChatCompletionMessageParam], messages: list[ChatCompletionMessageParam],
params: ChatParams, params: ChatParams,
) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt | EmbedsPrompt]: ) -> tuple[list[ConversationMessage], DictPrompt]:
tokenizer = self.get_tokenizer() tokenizer = self.get_tokenizer()
conversation, mm_data, mm_uuids = parse_chat_messages( conversation, mm_data, mm_uuids = parse_chat_messages(
messages, messages,
...@@ -109,7 +110,7 @@ class MistralRenderer(BaseRenderer): ...@@ -109,7 +110,7 @@ class MistralRenderer(BaseRenderer):
**params.get_apply_chat_template_kwargs(), **params.get_apply_chat_template_kwargs(),
) )
prompt = self.render_completion(prompt_raw) prompt = parse_dec_only_prompt(prompt_raw)
if mm_data is not None: if mm_data is not None:
prompt["multi_modal_data"] = mm_data prompt["multi_modal_data"] = mm_data
if mm_uuids is not None: if mm_uuids is not None:
...@@ -121,7 +122,7 @@ class MistralRenderer(BaseRenderer): ...@@ -121,7 +122,7 @@ class MistralRenderer(BaseRenderer):
self, self,
messages: list[ChatCompletionMessageParam], messages: list[ChatCompletionMessageParam],
params: ChatParams, params: ChatParams,
) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt | EmbedsPrompt]: ) -> tuple[list[ConversationMessage], DictPrompt]:
tokenizer = self.get_tokenizer() tokenizer = self.get_tokenizer()
conversation, mm_data, mm_uuids = await parse_chat_messages_async( conversation, mm_data, mm_uuids = await parse_chat_messages_async(
messages, messages,
...@@ -135,7 +136,7 @@ class MistralRenderer(BaseRenderer): ...@@ -135,7 +136,7 @@ class MistralRenderer(BaseRenderer):
**params.get_apply_chat_template_kwargs(), **params.get_apply_chat_template_kwargs(),
) )
prompt = self.render_completion(prompt_raw) prompt = parse_dec_only_prompt(prompt_raw)
if mm_data is not None: if mm_data is not None:
prompt["multi_modal_data"] = mm_data prompt["multi_modal_data"] = mm_data
if mm_uuids is not None: if mm_uuids is not None:
......
...@@ -2,14 +2,20 @@ ...@@ -2,14 +2,20 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio import asyncio
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any from collections.abc import Sequence
from typing import TYPE_CHECKING, Any, overload
from vllm.inputs import EmbedsPrompt, TextPrompt, TokensPrompt from vllm.inputs import EmbedsPrompt, TextPrompt, TokensPrompt
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
from vllm.utils.async_utils import AsyncMicrobatchTokenizer from vllm.utils.async_utils import AsyncMicrobatchTokenizer
from vllm.utils.collection_utils import is_list_of
from .embed_utils import safe_load_prompt_embeds from .embed_utils import safe_load_prompt_embeds
from .inputs import (
DictPrompt,
EncoderDecoderDictPrompt,
EncoderDecoderTokPrompt,
TokPrompt,
)
from .params import ChatParams, TokenizeParams from .params import ChatParams, TokenizeParams
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -57,140 +63,217 @@ class BaseRenderer(ABC): ...@@ -57,140 +63,217 @@ class BaseRenderer(ABC):
return self._async_tokenizer return self._async_tokenizer
# Step 1: Convert raw inputs to prompts # Step 1: Convert raw inputs to prompts
def render_completion( def render_prompt(
self, self,
prompt_raw: str | list[int] | bytes, prompt: DictPrompt | bytes,
) -> TextPrompt | TokensPrompt | EmbedsPrompt: ) -> DictPrompt:
error_msg = "Each prompt must be a string or an array of tokens" if isinstance(prompt, bytes):
embeds = safe_load_prompt_embeds(self.config, prompt)
prompt = EmbedsPrompt(prompt_embeds=embeds)
if isinstance(prompt_raw, str): return prompt
return TextPrompt(prompt=prompt_raw)
if isinstance(prompt_raw, list): def render_prompts(
if not is_list_of(prompt_raw, int):
raise TypeError(error_msg)
return TokensPrompt(prompt_token_ids=prompt_raw)
if isinstance(prompt_raw, bytes):
embeds = safe_load_prompt_embeds(self.config, prompt_raw)
return EmbedsPrompt(prompt_embeds=embeds)
raise TypeError(error_msg)
def render_completions(
self, self,
prompt_input: str | list[str] | list[int] | list[list[int]] | None = None, prompts: Sequence[DictPrompt | bytes],
prompt_embeds: bytes | list[bytes] | None = None, ) -> list[DictPrompt]:
) -> list[TextPrompt | TokensPrompt | EmbedsPrompt]: if len(prompts) == 0:
prompts_raw = list[str | list[int] | bytes]()
if prompt_embeds is not None: # embeds take higher priority
if isinstance(prompt_embeds, bytes):
prompts_raw.append(prompt_embeds)
else:
prompts_raw.extend(prompt_embeds)
if prompt_input is not None:
if isinstance(prompt_input, str) or (
len(prompt_input) > 0 and is_list_of(prompt_input, int)
):
prompts_raw.append(prompt_input) # type: ignore[arg-type]
else:
prompts_raw.extend(prompt_input) # type: ignore[arg-type]
if len(prompts_raw) == 0:
raise ValueError("You must pass at least one prompt") raise ValueError("You must pass at least one prompt")
return [self.render_completion(prompt) for prompt in prompts_raw] return [self.render_prompt(prompt) for prompt in prompts]
async def render_completions_async( async def render_prompts_async(
self, self,
prompt_input: str | list[str] | list[int] | list[list[int]] | None = None, prompts: Sequence[DictPrompt | bytes],
prompt_embeds: bytes | list[bytes] | None = None, ) -> list[DictPrompt]:
) -> list[TextPrompt | TokensPrompt | EmbedsPrompt]: return self.render_prompts(prompts)
return self.render_completions(prompt_input, prompt_embeds)
@abstractmethod @abstractmethod
def render_messages( def render_messages(
self, self,
messages: list["ChatCompletionMessageParam"], messages: list["ChatCompletionMessageParam"],
params: ChatParams, params: ChatParams,
) -> tuple[list["ConversationMessage"], TextPrompt | TokensPrompt | EmbedsPrompt]: ) -> tuple[list["ConversationMessage"], DictPrompt]:
raise NotImplementedError raise NotImplementedError
async def render_messages_async( async def render_messages_async(
self, self,
messages: list["ChatCompletionMessageParam"], messages: list["ChatCompletionMessageParam"],
params: ChatParams, params: ChatParams,
) -> tuple[list["ConversationMessage"], TextPrompt | TokensPrompt | EmbedsPrompt]: ) -> tuple[list["ConversationMessage"], DictPrompt]:
return self.render_messages(messages, params) return self.render_messages(messages, params)
# Step 2: Tokenize prompts if necessary # Step 2: Tokenize prompts if necessary
def _tokenize_prompt(
self,
prompt: TextPrompt,
params: TokenizeParams,
) -> TokensPrompt:
tokenizer = self.get_tokenizer()
prompt_token_ids = tokenizer.encode(
prompt["prompt"],
**params.get_encode_kwargs(),
)
return TokensPrompt(prompt_token_ids=prompt_token_ids, **prompt)
async def _tokenize_prompt_async(
self,
prompt: TextPrompt,
params: TokenizeParams,
) -> TokensPrompt:
tokenizer = self.get_async_tokenizer()
prompt_token_ids = await tokenizer.encode(
prompt["prompt"],
**params.get_encode_kwargs(),
)
return TokensPrompt(prompt_token_ids=prompt_token_ids, **prompt)
def _detokenize_prompt(self, prompt: TokensPrompt) -> TokensPrompt:
tokenizer = self.get_tokenizer()
prompt["prompt"] = tokenizer.decode(prompt["prompt_token_ids"])
return prompt
async def _detokenize_prompt_async(self, prompt: TokensPrompt) -> TokensPrompt:
tokenizer = self.get_async_tokenizer()
prompt["prompt"] = await tokenizer.decode(prompt["prompt_token_ids"])
return prompt
def _tokenize_enc_dec_prompt(
self,
prompt: EncoderDecoderDictPrompt,
params: TokenizeParams,
) -> EncoderDecoderTokPrompt:
enc_prompt, dec_prompt = (
self.tokenize_prompt(prompt["encoder_prompt"], params),
(
None
if prompt["decoder_prompt"] is None
else self.tokenize_prompt(prompt["decoder_prompt"], params)
),
)
return EncoderDecoderTokPrompt(
encoder_prompt=enc_prompt,
decoder_prompt=dec_prompt,
)
async def _tokenize_enc_dec_prompt_async(
self,
prompt: EncoderDecoderDictPrompt,
params: TokenizeParams,
) -> EncoderDecoderTokPrompt:
enc_prompt, dec_prompt = await asyncio.gather(
self.tokenize_prompt_async(prompt["encoder_prompt"], params),
(
asyncio.sleep(0)
if prompt["decoder_prompt"] is None
else self.tokenize_prompt_async(prompt["decoder_prompt"], params)
),
)
return EncoderDecoderTokPrompt(
encoder_prompt=enc_prompt,
decoder_prompt=dec_prompt,
)
@overload
def tokenize_prompt( def tokenize_prompt(
self, self,
prompt: TextPrompt | TokensPrompt | EmbedsPrompt, prompt: TextPrompt | TokensPrompt,
params: TokenizeParams, params: TokenizeParams,
) -> TokensPrompt | EmbedsPrompt: ) -> TokensPrompt: ...
if "prompt_token_ids" not in prompt and "prompt_embeds" not in prompt:
prompt = params.apply_pre_tokenization(self.tokenizer, prompt)
tokenizer = self.get_tokenizer() @overload
prompt_token_ids = tokenizer.encode( def tokenize_prompt( # type: ignore[misc]
prompt["prompt"], self,
**params.get_encode_kwargs(), prompt: EmbedsPrompt,
) params: TokenizeParams,
) -> EmbedsPrompt: ...
prompt = TokensPrompt(prompt_token_ids=prompt_token_ids, **prompt) @overload
def tokenize_prompt( # type: ignore[misc]
self,
prompt: EncoderDecoderDictPrompt,
params: TokenizeParams,
) -> EncoderDecoderTokPrompt: ...
def tokenize_prompt(
self,
prompt: DictPrompt,
params: TokenizeParams,
) -> TokPrompt:
if "encoder_prompt" in prompt:
return self._tokenize_enc_dec_prompt(prompt, params) # type: ignore[arg-type]
if "prompt_token_ids" not in prompt and "prompt_embeds" not in prompt:
prompt = params.apply_pre_tokenization(self.tokenizer, prompt)
prompt = self._tokenize_prompt(prompt, params)
if params.needs_detokenization and "prompt" not in prompt: if params.needs_detokenization and "prompt" not in prompt:
if "prompt_token_ids" not in prompt: if "prompt_token_ids" not in prompt:
raise RuntimeError("Cannot run detokenization on embeddings") raise RuntimeError("Cannot run detokenization on embeddings")
tokenizer = self.get_tokenizer() prompt = self._detokenize_prompt(prompt) # type: ignore[arg-type]
prompt_text = tokenizer.decode(prompt["prompt_token_ids"]) # type: ignore[typeddict-item]
prompt["prompt"] = prompt_text # type: ignore[typeddict-unknown-key]
return params.apply_post_tokenization(self.tokenizer, prompt) # type: ignore[arg-type] return params.apply_post_tokenization(self.tokenizer, prompt) # type: ignore[arg-type]
def tokenize_prompts( def tokenize_prompts(
self, self,
prompts: list[TextPrompt | TokensPrompt | EmbedsPrompt], prompts: Sequence[DictPrompt],
params: TokenizeParams, params: TokenizeParams,
) -> list[TokensPrompt | EmbedsPrompt]: ) -> list[TokPrompt]:
return [self.tokenize_prompt(prompt, params) for prompt in prompts] return [self.tokenize_prompt(prompt, params) for prompt in prompts]
@overload
async def tokenize_prompt_async( async def tokenize_prompt_async(
self, self,
prompt: TextPrompt | TokensPrompt | EmbedsPrompt, prompt: TextPrompt | TokensPrompt,
params: TokenizeParams, params: TokenizeParams,
) -> TokensPrompt | EmbedsPrompt: ) -> TokensPrompt: ...
if "prompt_token_ids" not in prompt and "prompt_embeds" not in prompt:
prompt = params.apply_pre_tokenization(self.tokenizer, prompt) @overload
async def tokenize_prompt_async( # type: ignore[misc]
self,
prompt: EmbedsPrompt,
params: TokenizeParams,
) -> EmbedsPrompt: ...
@overload
async def tokenize_prompt_async( # type: ignore[misc]
self,
prompt: EncoderDecoderDictPrompt,
params: TokenizeParams,
) -> EncoderDecoderTokPrompt: ...
tokenizer = self.get_async_tokenizer() async def tokenize_prompt_async(
prompt_token_ids = await tokenizer.encode( self,
prompt["prompt"], prompt: DictPrompt,
**params.get_encode_kwargs(), params: TokenizeParams,
) ) -> TokPrompt:
if "encoder_prompt" in prompt:
return await self._tokenize_enc_dec_prompt_async(prompt, params) # type: ignore[arg-type]
prompt = TokensPrompt(prompt_token_ids=prompt_token_ids, **prompt) if "prompt_token_ids" not in prompt and "prompt_embeds" not in prompt:
prompt = params.apply_pre_tokenization(self.tokenizer, prompt)
prompt = await self._tokenize_prompt_async(prompt, params)
if params.needs_detokenization and "prompt" not in prompt: if params.needs_detokenization and "prompt" not in prompt:
if "prompt_token_ids" not in prompt: if "prompt_token_ids" not in prompt:
raise RuntimeError("Cannot run detokenization on embeddings") raise RuntimeError("Cannot run detokenization on embeddings")
tokenizer = self.get_async_tokenizer() prompt = await self._detokenize_prompt_async(prompt) # type: ignore[arg-type]
prompt_text = await tokenizer.decode(prompt["prompt_token_ids"]) # type: ignore[typeddict-item]
prompt["prompt"] = prompt_text # type: ignore[typeddict-unknown-key]
return params.apply_post_tokenization(self.tokenizer, prompt) # type: ignore[arg-type] return params.apply_post_tokenization(self.tokenizer, prompt) # type: ignore[arg-type]
async def tokenize_prompts_async( async def tokenize_prompts_async(
self, self,
prompts: list[TextPrompt | TokensPrompt | EmbedsPrompt], prompts: Sequence[DictPrompt],
params: TokenizeParams, params: TokenizeParams,
) -> list[TokensPrompt | EmbedsPrompt]: ) -> list[TokPrompt]:
return await asyncio.gather( return await asyncio.gather(
*(self.tokenize_prompt_async(prompt, params) for prompt in prompts) *(self.tokenize_prompt_async(prompt, params) for prompt in prompts)
) )
...@@ -9,10 +9,11 @@ from vllm.entrypoints.chat_utils import ( ...@@ -9,10 +9,11 @@ from vllm.entrypoints.chat_utils import (
parse_chat_messages, parse_chat_messages,
parse_chat_messages_async, parse_chat_messages_async,
) )
from vllm.inputs import EmbedsPrompt, TextPrompt, TokensPrompt
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
from .inputs import DictPrompt
from .inputs.preprocess import parse_dec_only_prompt
from .params import ChatParams from .params import ChatParams
from .protocol import BaseRenderer from .protocol import BaseRenderer
...@@ -45,7 +46,7 @@ class TerratorchRenderer(BaseRenderer): ...@@ -45,7 +46,7 @@ class TerratorchRenderer(BaseRenderer):
self, self,
messages: list[ChatCompletionMessageParam], messages: list[ChatCompletionMessageParam],
params: ChatParams, params: ChatParams,
) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt | EmbedsPrompt]: ) -> tuple[list[ConversationMessage], DictPrompt]:
model_config = self.config model_config = self.config
conversation, mm_data, mm_uuids = parse_chat_messages( conversation, mm_data, mm_uuids = parse_chat_messages(
...@@ -54,7 +55,7 @@ class TerratorchRenderer(BaseRenderer): ...@@ -54,7 +55,7 @@ class TerratorchRenderer(BaseRenderer):
content_format="string", content_format="string",
) )
prompt = self.render_completion([1]) # Dummy token IDs prompt = parse_dec_only_prompt([1]) # Dummy token IDs
if mm_data is not None: if mm_data is not None:
prompt["multi_modal_data"] = mm_data prompt["multi_modal_data"] = mm_data
if mm_uuids is not None: if mm_uuids is not None:
...@@ -66,7 +67,7 @@ class TerratorchRenderer(BaseRenderer): ...@@ -66,7 +67,7 @@ class TerratorchRenderer(BaseRenderer):
self, self,
messages: list[ChatCompletionMessageParam], messages: list[ChatCompletionMessageParam],
params: ChatParams, params: ChatParams,
) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt | EmbedsPrompt]: ) -> tuple[list[ConversationMessage], DictPrompt]:
model_config = self.config model_config = self.config
conversation, mm_data, mm_uuids = await parse_chat_messages_async( conversation, mm_data, mm_uuids = await parse_chat_messages_async(
...@@ -75,7 +76,7 @@ class TerratorchRenderer(BaseRenderer): ...@@ -75,7 +76,7 @@ class TerratorchRenderer(BaseRenderer):
content_format="string", content_format="string",
) )
prompt = self.render_completion([1]) # Dummy token IDs prompt = parse_dec_only_prompt([1]) # Dummy token IDs
if mm_data is not None: if mm_data is not None:
prompt["multi_modal_data"] = mm_data prompt["multi_modal_data"] = mm_data
if mm_uuids is not None: if mm_uuids is not None:
......
...@@ -28,6 +28,8 @@ from vllm.outputs import STREAM_FINISHED, PoolingRequestOutput, RequestOutput ...@@ -28,6 +28,8 @@ from vllm.outputs import STREAM_FINISHED, PoolingRequestOutput, RequestOutput
from vllm.plugins.io_processors import get_io_processor from vllm.plugins.io_processors import get_io_processor
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.renderers import BaseRenderer, merge_kwargs from vllm.renderers import BaseRenderer, merge_kwargs
from vllm.renderers.inputs import DictPrompt, TokPrompt
from vllm.renderers.inputs.preprocess import extract_prompt_components
from vllm.sampling_params import RequestOutputKind, SamplingParams from vllm.sampling_params import RequestOutputKind, SamplingParams
from vllm.tasks import SupportedTask from vllm.tasks import SupportedTask
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
...@@ -42,7 +44,6 @@ from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError ...@@ -42,7 +44,6 @@ from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError
from vllm.v1.engine.input_processor import InputProcessor from vllm.v1.engine.input_processor import InputProcessor
from vllm.v1.engine.output_processor import OutputProcessor, RequestOutputCollector from vllm.v1.engine.output_processor import OutputProcessor, RequestOutputCollector
from vllm.v1.engine.parallel_sampling import ParentRequest from vllm.v1.engine.parallel_sampling import ParentRequest
from vllm.v1.engine.utils import get_prompt_text
from vllm.v1.executor import Executor from vllm.v1.executor import Executor
from vllm.v1.metrics.loggers import ( from vllm.v1.metrics.loggers import (
StatLoggerFactory, StatLoggerFactory,
...@@ -284,7 +285,11 @@ class AsyncLLM(EngineClient): ...@@ -284,7 +285,11 @@ class AsyncLLM(EngineClient):
async def add_request( async def add_request(
self, self,
request_id: str, request_id: str,
prompt: EngineCoreRequest | PromptType | AsyncGenerator[StreamingInput, None], prompt: EngineCoreRequest
| PromptType
| DictPrompt
| TokPrompt
| AsyncGenerator[StreamingInput, None],
params: SamplingParams | PoolingParams, params: SamplingParams | PoolingParams,
arrival_time: float | None = None, arrival_time: float | None = None,
lora_request: LoRARequest | None = None, lora_request: LoRARequest | None = None,
...@@ -367,7 +372,7 @@ class AsyncLLM(EngineClient): ...@@ -367,7 +372,7 @@ class AsyncLLM(EngineClient):
data_parallel_rank=data_parallel_rank, data_parallel_rank=data_parallel_rank,
supported_tasks=await self.get_supported_tasks(), supported_tasks=await self.get_supported_tasks(),
) )
prompt_text = get_prompt_text(prompt) prompt_text, _, _ = extract_prompt_components(self.model_config, prompt)
self.input_processor.assign_request_id(request) self.input_processor.assign_request_id(request)
...@@ -484,7 +489,9 @@ class AsyncLLM(EngineClient): ...@@ -484,7 +489,9 @@ class AsyncLLM(EngineClient):
raise ValueError( raise ValueError(
"prompt_embeds not supported for streaming inputs" "prompt_embeds not supported for streaming inputs"
) )
prompt_text = get_prompt_text(input_chunk.prompt) prompt_text, _, _ = extract_prompt_components(
self.model_config, input_chunk.prompt
)
await self._add_request(req, prompt_text, None, 0, queue) await self._add_request(req, prompt_text, None, 0, queue)
except (asyncio.CancelledError, GeneratorExit): except (asyncio.CancelledError, GeneratorExit):
cancelled = True cancelled = True
...@@ -528,7 +535,11 @@ class AsyncLLM(EngineClient): ...@@ -528,7 +535,11 @@ class AsyncLLM(EngineClient):
# re-multiplexed in the API server anyhow. # re-multiplexed in the API server anyhow.
async def generate( async def generate(
self, self,
prompt: EngineCoreRequest | PromptType | AsyncGenerator[StreamingInput, None], prompt: EngineCoreRequest
| PromptType
| DictPrompt
| TokPrompt
| AsyncGenerator[StreamingInput, None],
sampling_params: SamplingParams, sampling_params: SamplingParams,
request_id: str, request_id: str,
*, *,
...@@ -769,7 +780,7 @@ class AsyncLLM(EngineClient): ...@@ -769,7 +780,7 @@ class AsyncLLM(EngineClient):
async def encode( async def encode(
self, self,
prompt: PromptType, prompt: PromptType | DictPrompt | TokPrompt,
pooling_params: PoolingParams, pooling_params: PoolingParams,
request_id: str, request_id: str,
lora_request: LoRARequest | None = None, lora_request: LoRARequest | None = None,
......
...@@ -7,14 +7,13 @@ from typing import Any, Literal, cast ...@@ -7,14 +7,13 @@ from typing import Any, Literal, cast
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.exceptions import VLLMValidationError from vllm.exceptions import VLLMValidationError
from vllm.inputs import ( from vllm.inputs.data import (
ProcessorInputs, ProcessorInputs,
PromptType, PromptType,
SingletonInputs, SingletonInputs,
SingletonPrompt, SingletonPrompt,
TextPrompt,
) )
from vllm.inputs.parse import is_explicit_encoder_decoder_prompt, split_enc_dec_inputs from vllm.inputs.parse import split_enc_dec_inputs
from vllm.inputs.preprocess import InputPreprocessor from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
...@@ -30,6 +29,7 @@ from vllm.multimodal.processing.context import set_request_id ...@@ -30,6 +29,7 @@ from vllm.multimodal.processing.context import set_request_id
from vllm.multimodal.utils import argsort_mm_positions from vllm.multimodal.utils import argsort_mm_positions
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.renderers import BaseRenderer from vllm.renderers import BaseRenderer
from vllm.renderers.inputs import DictPrompt, TokPrompt
from vllm.sampling_params import _SAMPLING_EPS, SamplingParams from vllm.sampling_params import _SAMPLING_EPS, SamplingParams
from vllm.tasks import POOLING_TASKS, SupportedTask from vllm.tasks import POOLING_TASKS, SupportedTask
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
...@@ -243,8 +243,8 @@ class InputProcessor: ...@@ -243,8 +243,8 @@ class InputProcessor:
return mm_processor.info.parse_mm_data(mm_data) return mm_processor.info.parse_mm_data(mm_data)
def _validate_singleton_mm_uuids(self, prompt: SingletonPrompt) -> None: def _validate_singleton_mm_uuids(self, prompt: SingletonPrompt) -> None:
if isinstance(prompt, str): if not isinstance(prompt, dict):
prompt = TextPrompt(prompt=prompt) return
mm_data = cast(MultiModalDataDict, prompt.get("multi_modal_data") or {}) mm_data = cast(MultiModalDataDict, prompt.get("multi_modal_data") or {})
mm_uuids = cast(MultiModalUUIDDict, prompt.get("multi_modal_uuids") or {}) mm_uuids = cast(MultiModalUUIDDict, prompt.get("multi_modal_uuids") or {})
...@@ -297,7 +297,7 @@ class InputProcessor: ...@@ -297,7 +297,7 @@ class InputProcessor:
f"multi_modal_uuids[{modality!r}] is missing." f"multi_modal_uuids[{modality!r}] is missing."
) )
def _validate_mm_uuids(self, prompt: PromptType) -> None: def _validate_mm_uuids(self, prompt: PromptType | DictPrompt | TokPrompt) -> None:
""" """
Validate that user-provided multi_modal_uuids align with Validate that user-provided multi_modal_uuids align with
multi_modal_data in the incoming request prompt(s). multi_modal_data in the incoming request prompt(s).
...@@ -305,10 +305,10 @@ class InputProcessor: ...@@ -305,10 +305,10 @@ class InputProcessor:
auto-hashed downstream. auto-hashed downstream.
""" """
if is_explicit_encoder_decoder_prompt(prompt): if isinstance(prompt, dict) and "encoder_prompt" in prompt:
self._validate_singleton_mm_uuids(prompt["encoder_prompt"]) self._validate_singleton_mm_uuids(prompt["encoder_prompt"]) # type: ignore[typeddict-item]
if (dec_prompt := prompt["decoder_prompt"]) is not None: if (dec_prompt := prompt["decoder_prompt"]) is not None: # type: ignore[typeddict-item]
self._validate_singleton_mm_uuids(dec_prompt) self._validate_singleton_mm_uuids(dec_prompt)
else: else:
self._validate_singleton_mm_uuids(prompt) self._validate_singleton_mm_uuids(prompt)
...@@ -449,21 +449,23 @@ class InputProcessor: ...@@ -449,21 +449,23 @@ class InputProcessor:
def _extract_singleton_mm_data( def _extract_singleton_mm_data(
self, prompt: SingletonPrompt self, prompt: SingletonPrompt
) -> MultiModalDataDict | None: ) -> MultiModalDataDict | None:
if isinstance(prompt, str): if not isinstance(prompt, dict):
return None return None
return prompt.get("multi_modal_data") # type: ignore[return-value] return prompt.get("multi_modal_data")
def _extract_mm_data(self, prompt: PromptType) -> MultiModalDataDict | None: def _extract_mm_data(
if is_explicit_encoder_decoder_prompt(prompt): self, prompt: PromptType | DictPrompt | TokPrompt
return self._extract_singleton_mm_data(prompt["encoder_prompt"]) ) -> MultiModalDataDict | None:
if isinstance(prompt, dict) and "encoder_prompt" in prompt:
return self._extract_singleton_mm_data(prompt["encoder_prompt"]) # type: ignore[typeddict-item]
else: else:
return self._extract_singleton_mm_data(prompt) return self._extract_singleton_mm_data(prompt)
def _maybe_build_mm_uuids( def _maybe_build_mm_uuids(
self, self,
request_id: str, request_id: str,
prompt: PromptType, prompt: PromptType | DictPrompt | TokPrompt,
) -> MultiModalUUIDDict | None: ) -> MultiModalUUIDDict | None:
"""Build per-item multimodal hash overrides when enabled. In this case, """Build per-item multimodal hash overrides when enabled. In this case,
multimodal data items are identified by their request id, modality and multimodal data items are identified by their request id, modality and
...@@ -519,7 +521,7 @@ class InputProcessor: ...@@ -519,7 +521,7 @@ class InputProcessor:
def process_inputs( def process_inputs(
self, self,
request_id: str, request_id: str,
prompt: PromptType, prompt: PromptType | DictPrompt | TokPrompt,
params: SamplingParams | PoolingParams, params: SamplingParams | PoolingParams,
arrival_time: float | None = None, arrival_time: float | None = None,
lora_request: LoRARequest | None = None, lora_request: LoRARequest | None = None,
......
...@@ -22,6 +22,8 @@ from vllm.outputs import PoolingRequestOutput, RequestOutput ...@@ -22,6 +22,8 @@ from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.plugins.io_processors import get_io_processor from vllm.plugins.io_processors import get_io_processor
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.renderers import BaseRenderer from vllm.renderers import BaseRenderer
from vllm.renderers.inputs import DictPrompt, TokPrompt
from vllm.renderers.inputs.preprocess import extract_prompt_components
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.tasks import SupportedTask from vllm.tasks import SupportedTask
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
...@@ -32,7 +34,6 @@ from vllm.v1.engine.core_client import EngineCoreClient ...@@ -32,7 +34,6 @@ from vllm.v1.engine.core_client import EngineCoreClient
from vllm.v1.engine.input_processor import InputProcessor from vllm.v1.engine.input_processor import InputProcessor
from vllm.v1.engine.output_processor import OutputProcessor from vllm.v1.engine.output_processor import OutputProcessor
from vllm.v1.engine.parallel_sampling import ParentRequest from vllm.v1.engine.parallel_sampling import ParentRequest
from vllm.v1.engine.utils import get_prompt_text
from vllm.v1.executor import Executor from vllm.v1.executor import Executor
from vllm.v1.metrics.loggers import StatLoggerFactory, StatLoggerManager from vllm.v1.metrics.loggers import StatLoggerFactory, StatLoggerManager
from vllm.v1.metrics.reader import Metric, get_metrics_snapshot from vllm.v1.metrics.reader import Metric, get_metrics_snapshot
...@@ -216,7 +217,7 @@ class LLMEngine: ...@@ -216,7 +217,7 @@ class LLMEngine:
def add_request( def add_request(
self, self,
request_id: str, request_id: str,
prompt: EngineCoreRequest | PromptType, prompt: EngineCoreRequest | PromptType | DictPrompt | TokPrompt,
params: SamplingParams | PoolingParams, params: SamplingParams | PoolingParams,
arrival_time: float | None = None, arrival_time: float | None = None,
lora_request: LoRARequest | None = None, lora_request: LoRARequest | None = None,
...@@ -251,7 +252,7 @@ class LLMEngine: ...@@ -251,7 +252,7 @@ class LLMEngine:
priority, priority,
supported_tasks=self.get_supported_tasks(), supported_tasks=self.get_supported_tasks(),
) )
prompt_text = get_prompt_text(prompt) prompt_text, _, _ = extract_prompt_components(self.model_config, prompt)
self.input_processor.assign_request_id(request) self.input_processor.assign_request_id(request)
......
...@@ -17,8 +17,6 @@ import zmq ...@@ -17,8 +17,6 @@ import zmq
from vllm import envs from vllm import envs
from vllm.config import CacheConfig, ParallelConfig, VllmConfig from vllm.config import CacheConfig, ParallelConfig, VllmConfig
from vllm.inputs import PromptType
from vllm.inputs.parse import get_prompt_components
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.ray.ray_env import get_env_vars_to_copy from vllm.ray.ray_env import get_env_vars_to_copy
...@@ -226,10 +224,6 @@ def get_device_indices( ...@@ -226,10 +224,6 @@ def get_device_indices(
return value return value
def get_prompt_text(prompt: PromptType) -> str | None:
return get_prompt_components(prompt)[0]
class CoreEngineActorManager: class CoreEngineActorManager:
""" """
Utility class to handle creation, readiness, and shutdown Utility class to handle creation, readiness, and shutdown
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment