Unverified commit ba2f0acc, authored by Cyrus Leung, committed by GitHub
Browse files

[Misc] Reorganize inputs (#35182)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent 678b3c99
......@@ -23,13 +23,13 @@ from transformers.models.whisper.modeling_whisper import (
from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.inputs import MultiModalDataDict
from vllm.model_executor.layers.activation import MulAndSilu, get_act_fn
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.model_loader import DefaultModelLoader
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (
MultiModalDataDict,
MultiModalFieldConfig,
MultiModalKwargsItems,
NestedTensors,
......
......@@ -19,7 +19,7 @@ from transformers import BatchFeature, WhisperConfig
from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.inputs.data import PromptType, TokensPrompt
from vllm.inputs import MultiModalDataDict, PromptType, TokensPrompt
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
......@@ -32,7 +32,6 @@ from vllm.model_executor.models.whisper import (
from vllm.model_executor.models.whisper_causal import WhisperCausalEncoder
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (
MultiModalDataDict,
MultiModalFieldConfig,
MultiModalKwargsItems,
NestedTensors,
......
......@@ -20,7 +20,7 @@ from vllm.compilation.decorators import support_torch_compile
from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
from vllm.engine.protocol import StreamingInput
from vllm.envs import VLLM_ENGINE_ITERATION_TIMEOUT_S
from vllm.inputs.data import PromptType, TokensPrompt
from vllm.inputs import PromptType, TokensPrompt
from vllm.logger import init_logger
from vllm.model_executor.models.interfaces import MultiModalEmbeddings, SupportsRealtime
from vllm.model_executor.models.voxtral import (
......@@ -31,9 +31,7 @@ from vllm.model_executor.models.voxtral import (
)
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.cache import _I, BaseMultiModalProcessorCache
from vllm.multimodal.inputs import (
MultiModalKwargsOptionalItems,
)
from vllm.multimodal.inputs import MultiModalKwargsOptionalItems
from vllm.multimodal.parse import MultiModalDataItems
from vllm.multimodal.processing import BaseDummyInputsBuilder
from vllm.multimodal.processing.processor import (
......
......@@ -21,7 +21,12 @@ from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, ModelConfig, SpeechToTextConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.inputs.data import ExplicitEncoderDecoderPrompt, PromptType, TextPrompt
from vllm.inputs import (
ExplicitEncoderDecoderPrompt,
MultiModalDataDict,
PromptType,
TextPrompt,
)
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.attention import (
......@@ -44,7 +49,6 @@ from vllm.model_executor.models.whisper_utils import (
)
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (
MultiModalDataDict,
MultiModalFieldConfig,
MultiModalKwargsItems,
)
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from .hasher import MultiModalHasher
from .inputs import (
BatchedTensorInputs,
ModalityData,
MultiModalDataBuiltins,
MultiModalDataDict,
MultiModalKwargsItems,
MultiModalPlaceholderDict,
MultiModalUUIDDict,
NestedTensors,
)
from .inputs import BatchedTensorInputs, MultiModalKwargsItems, NestedTensors
from .registry import MultiModalRegistry
MULTIMODAL_REGISTRY = MultiModalRegistry()
......@@ -25,13 +16,8 @@ Info:
__all__ = [
"BatchedTensorInputs",
"ModalityData",
"MultiModalDataBuiltins",
"MultiModalDataDict",
"MultiModalHasher",
"MultiModalKwargsItems",
"MultiModalPlaceholderDict",
"MultiModalUUIDDict",
"NestedTensors",
"MULTIMODAL_REGISTRY",
"MultiModalRegistry",
......
......@@ -15,12 +15,11 @@ from typing import (
TypedDict,
Union,
cast,
final,
)
import numpy as np
from PIL.Image import Image
from typing_extensions import NotRequired, TypeVar
from typing_extensions import TypeVar
from vllm.utils.collection_utils import is_list_of
from vllm.utils.import_utils import LazyLoader
......@@ -32,14 +31,9 @@ if TYPE_CHECKING:
import torch
import torch.types
from transformers.feature_extraction_utils import BatchFeature
from vllm.inputs.data import _InputOptions
else:
torch = LazyLoader("torch", globals(), "torch")
_InputOptions = dict
_T = TypeVar("_T")
HfImageItem: TypeAlias = Union["Image", np.ndarray, "torch.Tensor"]
"""
......@@ -98,15 +92,6 @@ which are treated as audio embeddings;
these are directly passed to the model without HF processing.
"""
ModalityData: TypeAlias = _T | list[_T | None] | None
"""
Either a single data item, or a list of data items. Can only be None if UUID
is provided.
The number of data items allowed per modality is restricted by
`--limit-mm-per-prompt`.
"""
class VisionChunkImage(TypedDict):
"""Represents an image wrapped as a vision chunk."""
......@@ -126,46 +111,10 @@ class VisionChunkVideo(TypedDict):
video_idx: int
VisionChunk = VisionChunkImage | VisionChunkVideo
VisionChunk: TypeAlias = VisionChunkImage | VisionChunkVideo
"""A vision chunk is either an image or a video chunk."""
@final
class MultiModalDataBuiltins(TypedDict, total=False):
    """
    Typed schema for the modality keys that vLLM predefines.

    Declared with `total=False`, so every key is optional.
    """

    image: ModalityData[ImageItem]
    """Image input(s)."""

    video: ModalityData[VideoItem]
    """Video input(s)."""

    audio: ModalityData[AudioItem]
    """Audio input(s)."""

    vision_chunk: ModalityData[VisionChunk]
    """Visual atom input(s): a unified modality for images and video chunks."""
MultiModalDataDict: TypeAlias = Mapping[str, ModalityData[Any]]
"""
A dictionary containing an entry for each modality type to input.
The built-in modalities are defined by
[`MultiModalDataBuiltins`][vllm.multimodal.inputs.MultiModalDataBuiltins].
"""
MultiModalUUIDDict: TypeAlias = Mapping[str, Sequence[str | None] | str]
"""
A dictionary containing user-provided UUIDs for items in each modality.
If a UUID for an item is not provided, its entry will be `None` and
MultiModalHasher will compute a hash for the item.
The UUID will be used to identify the item for all caching purposes
(input processing caching, embedding caching, prefix caching, etc).
"""
@dataclass(frozen=True)
class PlaceholderRange:
"""
......@@ -1048,112 +997,3 @@ MultiModalKwargsOptionalItems: TypeAlias = (
MultiModalKwargsItems[MultiModalKwargsItem]
| MultiModalKwargsItems[MultiModalKwargsItem | None]
)
MultiModalHashes = dict[str, list[str]]
"""
A dictionary containing per-item hashes for each modality.
"""
MultiModalPlaceholderDict: TypeAlias = Mapping[str, Sequence[PlaceholderRange]]
"""
A dictionary containing per-item placeholder ranges for each modality.
"""
class MultiModalInputs(_InputOptions):
    """
    The result produced by
    [`BaseMultiModalProcessor`][vllm.multimodal.processing.BaseMultiModalProcessor],
    in the form that is passed on to vLLM internals.
    """

    type: Literal["multimodal"]
    """Tag for this kind of inputs; always `"multimodal"`."""

    prompt_token_ids: list[int]
    """Processed token IDs, placeholder tokens included."""

    prompt: NotRequired[str]
    """Prompt text matching the token IDs, when it is available."""

    mm_kwargs: MultiModalKwargsOptionalItems
    """Keyword arguments handed directly to the model once batched."""

    mm_hashes: MultiModalHashes
    """Hashes of the multi-modal data."""

    mm_placeholders: MultiModalPlaceholderDict
    """
    Per-modality information about the placeholder tokens found in
    `prompt_token_ids`.
    """
def mm_inputs(
    prompt_token_ids: list[int],
    mm_kwargs: MultiModalKwargsOptionalItems,
    mm_hashes: MultiModalHashes,
    mm_placeholders: MultiModalPlaceholderDict,
    *,
    prompt: str | None = None,
    cache_salt: str | None = None,
) -> MultiModalInputs:
    """
    Build a `MultiModalInputs` dictionary tagged as `"multimodal"`.

    Args:
        prompt_token_ids: Stored under `"prompt_token_ids"`.
        mm_kwargs: Stored under `"mm_kwargs"`.
        mm_hashes: Stored under `"mm_hashes"`.
        mm_placeholders: Stored under `"mm_placeholders"`.
        prompt: Stored under `"prompt"` only when it is not `None`.
        cache_salt: Stored under `"cache_salt"` only when it is not `None`.

    Returns:
        The assembled dictionary; optional keys are absent (rather than
        set to `None`) when their argument was not provided.
    """
    result = MultiModalInputs(
        type="multimodal",
        prompt_token_ids=prompt_token_ids,
        mm_kwargs=mm_kwargs,
        mm_hashes=mm_hashes,
        mm_placeholders=mm_placeholders,
    )

    # Optional keys are omitted entirely instead of being stored as `None`.
    if prompt is not None:
        result["prompt"] = prompt

    if cache_salt is not None:
        result["cache_salt"] = cache_salt

    return result
class MultiModalEncDecInputs(MultiModalInputs):
    """
    The result produced by
    [`EncDecMultiModalProcessor`][vllm.multimodal.processing.EncDecMultiModalProcessor],
    in the form that is passed on to vLLM internals.

    Note: For convenience, even text-only encoder-decoder models are
    currently implemented as multi-modal models.
    (Example: https://github.com/vllm-project/bart-plugin)
    """

    encoder_prompt_token_ids: list[int]
    """Processed token IDs of the encoder prompt."""

    encoder_prompt: NotRequired[str]
    """Prompt text matching the encoder token IDs, when it is available."""
def mm_enc_dec_inputs(
    encoder_inputs: MultiModalInputs,
    decoder_prompt_token_ids: list[int],
    *,
    decoder_prompt: str | None = None,
) -> MultiModalEncDecInputs:
    """
    Combine encoder-side multi-modal inputs with a decoder prompt.

    Args:
        encoder_inputs: Supplies the encoder token IDs together with
            `mm_kwargs`, `mm_hashes` and `mm_placeholders`; its `"prompt"`
            and `"cache_salt"` entries are carried over when present.
        decoder_prompt_token_ids: Stored under `"prompt_token_ids"`.
        decoder_prompt: Stored under `"prompt"` only when it is not `None`.

    Returns:
        A `MultiModalEncDecInputs` dictionary tagged as `"multimodal"`.
    """
    encoder_token_ids = encoder_inputs["prompt_token_ids"]
    kwargs_items = encoder_inputs["mm_kwargs"]
    hashes = encoder_inputs["mm_hashes"]
    placeholders = encoder_inputs["mm_placeholders"]

    result = MultiModalEncDecInputs(
        type="multimodal",
        prompt_token_ids=decoder_prompt_token_ids,
        encoder_prompt_token_ids=encoder_token_ids,
        mm_kwargs=kwargs_items,
        mm_hashes=hashes,
        mm_placeholders=placeholders,
    )

    if decoder_prompt is not None:
        result["prompt"] = decoder_prompt

    # The encoder's own prompt text is kept under a separate key so that
    # "prompt" always refers to the decoder side.
    if "prompt" in encoder_inputs:
        result["encoder_prompt"] = encoder_inputs["prompt"]

    if "cache_salt" in encoder_inputs:
        result["cache_salt"] = encoder_inputs["cache_salt"]

    return result
......@@ -19,6 +19,7 @@ import numpy as np
import torch
from typing_extensions import assert_never
from vllm.inputs import ModalityData, MultiModalDataDict, MultiModalUUIDDict
from vllm.utils.collection_utils import is_list_of
from vllm.utils.import_utils import LazyLoader
......@@ -29,11 +30,8 @@ from .inputs import (
HfImageItem,
HfVideoItem,
ImageItem,
ModalityData,
MultiModalDataDict,
MultiModalFieldConfig,
MultiModalKwargsItems,
MultiModalUUIDDict,
VideoItem,
)
from .media import MediaWithBytes
......@@ -407,8 +405,8 @@ _D = TypeVar("_D", bound=ModalityDataItems[Any, Any])
class MultiModalDataItems(UserDict[str, ModalityDataItems[Any, Any]]):
"""
As [`MultiModalDataDict`][vllm.multimodal.inputs.MultiModalDataDict], but
normalized such that each entry corresponds to a list.
A normalized [`MultiModalDataDict`][vllm.inputs.MultiModalDataDict]
such that each entry corresponds to a list.
"""
def select(self, modalities: Set[str]):
......@@ -477,7 +475,7 @@ ModalityDataParser: TypeAlias = Callable[
class MultiModalDataParser:
"""
Parses [`MultiModalDataDict`][vllm.multimodal.inputs.MultiModalDataDict]
Parses [`MultiModalDataDict`][vllm.inputs.MultiModalDataDict]
into [`MultiModalDataItems`][vllm.multimodal.parse.MultiModalDataItems].
Args:
......@@ -695,8 +693,8 @@ class MultiModalDataParser:
MultiModalUUIDItems: TypeAlias = dict[str, Sequence[str | None]]
"""
As [`MultiModalUUIDDict`][vllm.multimodal.inputs.MultiModalUUIDDict], but
normalized such that each entry corresponds to a list.
A normalized [`MultiModalUUIDDict`][vllm.inputs.MultiModalUUIDDict]
such that each entry corresponds to a list.
"""
......
......@@ -11,15 +11,14 @@ from typing import TYPE_CHECKING, Any, overload
import torch
from typing_extensions import TypeVar
from vllm.inputs import MultiModalDataDict
from vllm.logger import init_logger
from vllm.multimodal.inputs import MultiModalDataDict
from vllm.multimodal.parse import (
DictEmbeddingItems,
EmbeddingItems,
MultiModalDataItems,
MultiModalDataParser,
)
from vllm.renderers import TokenizeParams
from vllm.tokenizers import TokenizerLike
from vllm.transformers_utils.processor import cached_processor_from_config
from vllm.utils.func_utils import get_allowed_kwarg_only_overrides
......@@ -32,12 +31,14 @@ if TYPE_CHECKING:
from transformers.processing_utils import ProcessorMixin
from vllm.config import ModelConfig
from vllm.renderers import TokenizeParams
else:
PretrainedConfig = object
BatchFeature = object
ProcessorMixin = object
ModelConfig = object
TokenizeParams = object
logger = init_logger(__name__)
......@@ -339,6 +340,8 @@ class BaseProcessingInfo:
def get_default_tok_params(self) -> TokenizeParams:
"""Construct the default parameters for tokenization."""
from vllm.renderers import TokenizeParams
model_config = self.ctx.model_config
encoder_config = model_config.encoder_config or {}
......@@ -451,8 +454,7 @@ class BaseProcessingInfo:
validate: bool = True,
) -> MultiModalDataItems:
"""
Normalize
[`MultiModalDataDict`][vllm.multimodal.inputs.MultiModalDataDict]
Normalize [`MultiModalDataDict`][vllm.inputs.MultiModalDataDict]
to [`MultiModalDataItems`][vllm.multimodal.parse.MultiModalDataItems]
before passing them to
[`_get_hf_mm_data`][vllm.multimodal.processing.BaseMultiModalProcessor._get_hf_mm_data].
......
......@@ -14,9 +14,9 @@ from vllm.config.multimodal import (
ImageDummyOptions,
VideoDummyOptions,
)
from vllm.inputs import MultiModalDataDict
from vllm.logger import init_logger
from ..inputs import MultiModalDataDict
from .context import BaseProcessingInfo
from .inputs import ProcessorInputs
......
......@@ -3,8 +3,9 @@
from collections.abc import Mapping
from dataclasses import dataclass, field
from vllm.inputs import MultiModalHashes
from ..hasher import MultiModalHasher
from ..inputs import MultiModalHashes
from ..parse import MultiModalDataItems, MultiModalUUIDItems
......@@ -26,7 +27,7 @@ class ProcessorInputs:
mm_uuid_items = self.mm_uuid_items or {}
hf_processor_mm_kwargs = self.hf_processor_mm_kwargs
mm_hashes: MultiModalHashes = {}
mm_hashes = dict[str, list[str]]()
hasher = MultiModalHasher
for modality, data_items in mm_data_items.items():
......
......@@ -6,34 +6,29 @@ from collections.abc import Callable, Generator, ItemsView, Iterable, Mapping, S
from dataclasses import dataclass, field, replace
from enum import Enum
from functools import lru_cache
from typing import (
TYPE_CHECKING,
Generic,
NamedTuple,
Protocol,
TypeAlias,
cast,
)
from typing import TYPE_CHECKING, Generic, NamedTuple, Protocol, TypeAlias, cast
import regex as re
import torch
from typing_extensions import TypeVar, assert_never
from vllm.inputs import (
MultiModalEncDecInput,
MultiModalHashes,
MultiModalInput,
mm_enc_dec_input,
mm_input,
)
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
from vllm.utils.collection_utils import flatten_2d_lists, full_groupby
from ..inputs import (
MultiModalEncDecInputs,
MultiModalFieldConfig,
MultiModalHashes,
MultiModalInputs,
MultiModalKwargsItem,
MultiModalKwargsItems,
MultiModalKwargsOptionalItems,
PlaceholderRange,
mm_enc_dec_inputs,
mm_inputs,
)
from ..parse import (
DictEmbeddingItems,
......@@ -994,7 +989,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
mm_items: MultiModalDataItems,
mm_uuid_items: MultiModalUUIDItems | None = None,
hf_processor_mm_kwargs: Mapping[str, object] | None = None,
) -> MultiModalInputs:
) -> MultiModalInput:
processor_inputs = ProcessorInputs(
prompt,
mm_items,
......@@ -1638,7 +1633,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
self,
inputs: ProcessorInputs,
timing_ctx: TimingContext,
) -> MultiModalInputs:
) -> MultiModalInput:
"""
Process multi-modal inputs to be used in vLLM.
......@@ -1673,7 +1668,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
for modality, placeholders in mm_placeholders.items()
}
return mm_inputs(
return mm_input(
prompt_token_ids=prompt_ids,
mm_kwargs=mm_info.kwargs,
mm_hashes=mm_info.hashes,
......@@ -1708,7 +1703,7 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]):
self,
prompt: str | list[int],
mm_items: MultiModalDataItems,
encoder_inputs: MultiModalInputs,
encoder_inputs: MultiModalInput,
):
tokenizer = self.info.get_tokenizer()
decoder_prompt_raw = self.create_decoder_prompt(prompt, mm_items)
......@@ -1721,7 +1716,7 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]):
decoder_prompt_text = None
decoder_prompt_ids = decoder_prompt_raw
return mm_enc_dec_inputs(
return mm_enc_dec_input(
encoder_inputs,
decoder_prompt_ids,
decoder_prompt=decoder_prompt_text,
......@@ -1731,7 +1726,7 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]):
self,
inputs: ProcessorInputs,
timing_ctx: TimingContext,
) -> MultiModalEncDecInputs:
) -> MultiModalEncDecInput:
"""
Process multi-modal inputs to be used in vLLM.
The main processing steps are modified to fit encoder-decoder model:
......
......@@ -7,6 +7,7 @@ from dataclasses import dataclass
from multiprocessing.synchronize import Lock as LockType
from typing import TYPE_CHECKING, Generic, Literal, Protocol, TypeVar, cast
from vllm.inputs import MultiModalInput
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike, cached_tokenizer_from_config
......@@ -19,7 +20,6 @@ from .cache import (
ShmObjectStoreReceiverCache,
ShmObjectStoreSenderCache,
)
from .inputs import MultiModalInputs
from .processing import (
BaseDummyInputsBuilder,
BaseMultiModalProcessor,
......@@ -220,7 +220,7 @@ class MultiModalRegistry:
*,
cache: BaseMultiModalProcessorCache | None = None,
processor: BaseMultiModalProcessor | None = None,
) -> MultiModalInputs:
) -> MultiModalInput:
"""
Create dummy data for profiling the memory usage of a model.
......
......@@ -12,6 +12,7 @@ import numpy.typing as npt
from PIL import Image
from typing_extensions import deprecated
from vllm.inputs import MultiModalPlaceholders
from vllm.utils.import_utils import LazyLoader
from .hasher import MultiModalHasher
......@@ -19,7 +20,6 @@ from .inputs import (
BatchedTensorInputs,
MultiModalFieldElem,
MultiModalKwargsItem,
MultiModalPlaceholderDict,
MultiModalSharedField,
)
from .media import AudioMediaIO, ImageMediaIO, MediaConnector, VideoMediaIO
......@@ -110,10 +110,10 @@ def encode_video_url(
def argsort_mm_positions(
mm_positions: MultiModalPlaceholderDict,
mm_positions: MultiModalPlaceholders,
) -> list[tuple[str, int]]:
"""
Given a `MultiModalPlaceholderDict`, output a sequence of keys to
Given a `MultiModalPlaceholders`, output a sequence of keys to
sort the dictionary by `offset` (starting index in the input sequence)
in ascending order.
......
......@@ -17,7 +17,7 @@ if TYPE_CHECKING:
from torch.distributed import PrefixStore, ProcessGroup
from vllm.config import VllmConfig
from vllm.inputs import ProcessorInputs
from vllm.inputs import EngineInput
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.utils.argparse_utils import FlexibleArgumentParser
......@@ -635,7 +635,7 @@ class Platform:
@classmethod
def validate_request(
cls,
processed_inputs: "ProcessorInputs",
processed_inputs: "EngineInput",
params: "SamplingParams | PoolingParams",
) -> None:
"""Raises if this request is unsupported on this platform"""
......
......@@ -6,7 +6,7 @@ from collections.abc import AsyncGenerator, Sequence
from typing import Generic, TypeVar
from vllm.config import VllmConfig
from vllm.inputs.data import PromptType
from vllm.inputs import PromptType
from vllm.outputs import PoolingRequestOutput
from vllm.pooling_params import PoolingParams
from vllm.renderers import BaseRenderer
......
......@@ -11,17 +11,32 @@ from typing import TYPE_CHECKING, Any, Generic, overload
from typing_extensions import TypeVar
from vllm.inputs import (
EmbedsInputs,
EmbedsInput,
EmbedsPrompt,
EncoderDecoderInputs,
ProcessorInputs,
SingletonInputs,
EncoderDecoderInput,
EngineInput,
MultiModalDataDict,
MultiModalInput,
MultiModalUUIDDict,
SingletonInput,
TextPrompt,
TokenInputs,
TokensInput,
TokensPrompt,
build_enc_dec_input,
embeds_input,
tokens_input,
)
from vllm.inputs.data import build_enc_dec_inputs, embeds_inputs, token_inputs
from vllm.logger import init_logger
from vllm.multimodal import MULTIMODAL_REGISTRY as mm_registry
from vllm.multimodal.cache import BaseMultiModalProcessorCache
from vllm.multimodal.parse import (
MultiModalDataItems,
MultiModalUUIDItems,
parse_mm_uuids,
)
from vllm.multimodal.processing import BaseMultiModalProcessor
from vllm.multimodal.processing import ProcessorInputs as MMProcessorInputs
from vllm.multimodal.registry import MultiModalTimingRegistry
from vllm.tokenizers import TokenizerLike
from vllm.utils.async_utils import AsyncMicrobatchTokenizer
from vllm.utils.counter import AtomicCounter
......@@ -46,14 +61,6 @@ if TYPE_CHECKING:
ChatCompletionMessageParam,
ConversationMessage,
)
from vllm.multimodal.cache import BaseMultiModalProcessorCache
from vllm.multimodal.inputs import (
MultiModalDataDict,
MultiModalInputs,
MultiModalUUIDDict,
)
from vllm.multimodal.parse import MultiModalDataItems, MultiModalUUIDItems
from vllm.multimodal.processing import BaseMultiModalProcessor
logger = init_logger(__name__)
......@@ -86,9 +93,6 @@ class BaseRenderer(ABC, Generic[_T]):
self.mm_processor: BaseMultiModalProcessor | None = None
self._mm_cache_stats: MultiModalCacheStats | None = None
if config.model_config.is_multimodal_model:
from vllm.multimodal import MULTIMODAL_REGISTRY as mm_registry
from vllm.multimodal.registry import MultiModalTimingRegistry
mm_processor_cache = mm_registry.processor_cache_from_config(config)
# Deep-copy the tokenizer so the multimodal processor gets its
......@@ -524,9 +528,9 @@ class BaseRenderer(ABC, Generic[_T]):
# Step 4: Convert to engine inputs
def _validate_mm_uuids(
self,
mm_data: "MultiModalDataDict",
mm_data_items: "MultiModalDataItems",
mm_uuid_items: "MultiModalUUIDItems",
mm_data: MultiModalDataDict,
mm_data_items: MultiModalDataItems,
mm_uuid_items: MultiModalUUIDItems,
) -> None:
# NOTE: Keys corresponding to `None` in `mm_data` don't appear in
# `mm_data_items`
......@@ -560,11 +564,11 @@ class BaseRenderer(ABC, Generic[_T]):
def _process_mm_uuids(
self,
mm_data: "MultiModalDataDict",
mm_data_items: "MultiModalDataItems",
mm_uuid_items: "MultiModalUUIDItems",
mm_data: MultiModalDataDict,
mm_data_items: MultiModalDataItems,
mm_uuid_items: MultiModalUUIDItems,
mm_req_id: str,
):
) -> MultiModalUUIDItems:
model_config = self.model_config
# NOTE: When users explicitly turn off BOTH prefix caching and input
......@@ -590,14 +594,11 @@ class BaseRenderer(ABC, Generic[_T]):
def _process_multimodal(
self,
prompt: list[int] | str,
mm_data: "MultiModalDataDict",
mm_uuids: "MultiModalUUIDDict | None",
mm_data: MultiModalDataDict,
mm_uuids: MultiModalUUIDDict | None,
mm_processor_kwargs: Mapping[str, object] | None,
tokenization_kwargs: dict[str, Any] | None,
) -> "MultiModalInputs":
from vllm.multimodal.parse import parse_mm_uuids
from vllm.multimodal.processing import ProcessorInputs as MMProcessorInputs
) -> "MultiModalInput":
mm_req_id = f"renderer{self.api_process_rank}-mm-{self._mm_req_counter.inc(1)}"
mm_processor = self.get_mm_processor()
......@@ -628,12 +629,12 @@ class BaseRenderer(ABC, Generic[_T]):
def _process_tokens(
self,
prompt: TokensPrompt,
) -> "TokenInputs | MultiModalInputs":
) -> TokensInput | MultiModalInput:
prompt_token_ids = prompt["prompt_token_ids"]
inputs: TokenInputs | MultiModalInputs
engine_input: TokensInput | MultiModalInput
if multi_modal_data := prompt.get("multi_modal_data"):
inputs = self._process_multimodal(
engine_input = self._process_multimodal(
prompt_token_ids,
multi_modal_data,
mm_processor_kwargs=prompt.get("mm_processor_kwargs"),
......@@ -641,19 +642,16 @@ class BaseRenderer(ABC, Generic[_T]):
mm_uuids=prompt.get("multi_modal_uuids"),
)
else:
inputs = token_inputs(prompt_token_ids)
engine_input = tokens_input(prompt_token_ids)
if prompt_text := prompt.get("prompt"):
inputs["prompt"] = prompt_text
engine_input["prompt"] = prompt_text
if cache_salt := prompt.get("cache_salt"):
inputs["cache_salt"] = cache_salt
engine_input["cache_salt"] = cache_salt
return inputs
return engine_input
def _process_embeds(
self,
prompt: EmbedsPrompt,
) -> EmbedsInputs:
def _process_embeds(self, prompt: EmbedsPrompt) -> EmbedsInput:
if not self.model_config.enable_prompt_embeds:
raise ValueError(
"You must set `--enable-prompt-embeds` to input `prompt_embeds`."
......@@ -676,15 +674,12 @@ class BaseRenderer(ABC, Generic[_T]):
# hidden device transfer in the critical path of generation.
prompt_embeds = prompt_embeds.cpu()
return embeds_inputs(
return embeds_input(
prompt_embeds=prompt_embeds,
cache_salt=prompt.get("cache_salt"),
)
def _process_singleton(
self,
prompt: SingletonTokPrompt,
) -> SingletonInputs:
def _process_singleton(self, prompt: SingletonTokPrompt) -> SingletonInput:
if "prompt_embeds" in prompt:
return self._process_embeds(prompt) # type: ignore[arg-type]
......@@ -693,7 +688,7 @@ class BaseRenderer(ABC, Generic[_T]):
def _process_enc_dec(
self,
prompt: EncoderDecoderTokPrompt,
) -> EncoderDecoderInputs:
) -> EncoderDecoderInput:
enc_prompt = prompt["encoder_prompt"]
dec_prompt = prompt["decoder_prompt"]
......@@ -704,27 +699,25 @@ class BaseRenderer(ABC, Generic[_T]):
if isinstance(self.mm_processor, EncDecMultiModalProcessor):
skip_decoder_start_token = self.mm_processor.skip_decoder_start_token
return build_enc_dec_inputs(
encoder_inputs=self._process_singleton(enc_prompt),
decoder_inputs=(
return build_enc_dec_input(
encoder_input=self._process_singleton(enc_prompt),
decoder_input=(
None if dec_prompt is None else self._process_singleton(dec_prompt)
),
decoder_start_token_id=self.get_dec_start_token_id(),
skip_decoder_start_token=skip_decoder_start_token,
)
def process_for_engine(
self, prompt: TokPrompt, arrival_time: float
) -> ProcessorInputs:
engine_prompt: ProcessorInputs
def process_for_engine(self, prompt: TokPrompt, arrival_time: float) -> EngineInput:
engine_input: EngineInput
if "encoder_prompt" in prompt:
engine_prompt = self._process_enc_dec(prompt) # type: ignore[arg-type]
engine_input = self._process_enc_dec(prompt) # type: ignore[arg-type]
else:
engine_prompt = self._process_singleton(prompt)
engine_input = self._process_singleton(prompt)
engine_prompt["arrival_time"] = arrival_time
engine_input["arrival_time"] = arrival_time
return engine_prompt
return engine_input
# Top-level methods
def render_cmpl(
......
......@@ -5,7 +5,7 @@ import itertools
from collections import defaultdict, deque
from collections.abc import Set
from functools import lru_cache
from typing import TYPE_CHECKING, Any, Literal, cast, overload
from typing import Any, Literal, cast, overload
import jinja2
import jinja2.ext
......@@ -25,6 +25,7 @@ from vllm.entrypoints.chat_utils import (
parse_chat_messages,
parse_chat_messages_async,
)
from vllm.inputs import MultiModalDataDict, MultiModalUUIDDict
from vllm.logger import init_logger
from vllm.tokenizers import cached_get_tokenizer
from vllm.tokenizers.hf import CachedHfTokenizer, HfTokenizer
......@@ -37,13 +38,6 @@ from .inputs import DictPrompt
from .inputs.preprocess import parse_dec_only_prompt
from .params import ChatParams
if TYPE_CHECKING:
from vllm.multimodal.inputs import MultiModalDataDict, MultiModalUUIDDict
else:
MultiModalDataDict = dict[str, Any]
MultiModalUUIDDict = dict[str, Any]
logger = init_logger(__name__)
......@@ -512,9 +506,9 @@ def safe_apply_chat_template(
def rebuild_mm_uuids_from_mm_data(
mm_uuids: "MultiModalUUIDDict",
mm_data: "MultiModalDataDict",
) -> "MultiModalUUIDDict":
mm_uuids: MultiModalUUIDDict,
mm_data: MultiModalDataDict,
) -> MultiModalUUIDDict:
"""Rebuild mm_uuids after vision_chunk processing.
When videos are split into chunks, the original UUIDs need to be updated
......@@ -547,7 +541,7 @@ def rebuild_mm_uuids_from_mm_data(
def build_video_prompts_from_mm_data(
mm_data: "MultiModalDataDict",
mm_data: MultiModalDataDict,
) -> list[str]:
"""Build video prompts from vision_chunk data.
......@@ -585,7 +579,7 @@ def build_video_prompts_from_mm_data(
def replace_vision_chunk_video_placeholder(
prompt_raw: str | list[int],
mm_data: "MultiModalDataDict",
mm_data: MultiModalDataDict,
video_placeholder: str | None,
) -> str | list[int]:
# get video placeholder, replace it with runtime video-chunk prompts
......
......@@ -9,8 +9,8 @@ from typing import TYPE_CHECKING, NamedTuple, TypeAlias, TypedDict, overload
from vllm.inputs import (
EmbedsPrompt,
EngineInput,
ExplicitEncoderDecoderPrompt,
ProcessorInputs,
PromptType,
SingletonPrompt,
TextPrompt,
......@@ -70,28 +70,28 @@ def conversation_to_seq(
DecoderOnlyDictPrompt: TypeAlias = TextPrompt | TokensPrompt | EmbedsPrompt
"""
A [`DecoderOnlyPrompt`][vllm.inputs.data.DecoderOnlyPrompt]
A [`DecoderOnlyPrompt`][vllm.inputs.llm.DecoderOnlyPrompt]
that has been standardized into a dictionary.
"""
EncoderDictPrompt: TypeAlias = TextPrompt | TokensPrompt
"""
A [`EncoderPrompt`][vllm.inputs.data.EncoderPrompt]
A [`EncoderPrompt`][vllm.inputs.llm.EncoderPrompt]
that has been standardized into a dictionary.
"""
DecoderDictPrompt: TypeAlias = TextPrompt | TokensPrompt
"""
A [`DecoderPrompt`][vllm.inputs.data.DecoderPrompt]
A [`DecoderPrompt`][vllm.inputs.llm.DecoderPrompt]
that has been standardized into a dictionary.
"""
class EncoderDecoderDictPrompt(TypedDict):
"""
A [`EncoderDecoderPrompt`][vllm.inputs.data.EncoderDecoderPrompt]
A [`EncoderDecoderPrompt`][vllm.inputs.llm.EncoderDecoderPrompt]
that has been standardized into a dictionary.
"""
......@@ -104,14 +104,14 @@ SingletonDictPrompt: TypeAlias = (
DecoderOnlyDictPrompt | EncoderDictPrompt | DecoderDictPrompt
)
"""
A [`SingletonPrompt`][vllm.inputs.data.SingletonPrompt]
A [`SingletonPrompt`][vllm.inputs.llm.SingletonPrompt]
that has been standardized into a dictionary.
"""
DictPrompt: TypeAlias = DecoderOnlyDictPrompt | EncoderDecoderDictPrompt
"""
A [`PromptType`][vllm.inputs.data.PromptType]
A [`PromptType`][vllm.inputs.llm.PromptType]
that has been standardized into a dictionary.
"""
......@@ -236,7 +236,7 @@ def extract_target_prompt(model_config: "ModelConfig", prompt: object):
def extract_prompt_components(
model_config: "ModelConfig",
prompt: PromptType | ProcessorInputs,
prompt: PromptType | EngineInput,
) -> PromptComponents:
target_prompt = extract_target_prompt(model_config, prompt)
......@@ -248,7 +248,8 @@ def extract_prompt_components(
def extract_prompt_len(
model_config: "ModelConfig", prompt: PromptType | ProcessorInputs
model_config: "ModelConfig",
prompt: PromptType | EngineInput,
):
target_prompt = extract_target_prompt(model_config, prompt)
......
......@@ -21,7 +21,7 @@ from vllm.distributed.weight_transfer.base import (
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.protocol import EngineClient, StreamingInput
from vllm.entrypoints.serve.elastic_ep.middleware import set_scaling_elastic_ep
from vllm.inputs import ProcessorInputs, PromptType
from vllm.inputs import EngineInput, PromptType
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
......@@ -139,7 +139,7 @@ class AsyncLLM(EngineClient):
self.model_config.io_processor_plugin,
)
# Convert TokPrompt --> EngineCoreRequest.
# Convert EngineInput --> EngineCoreRequest.
self.input_processor = InputProcessor(self.vllm_config, renderer)
# Converts EngineCoreOutputs --> RequestOutput.
......@@ -290,7 +290,7 @@ class AsyncLLM(EngineClient):
request_id: str,
prompt: EngineCoreRequest
| PromptType
| ProcessorInputs
| EngineInput
| AsyncGenerator[StreamingInput, None],
params: SamplingParams | PoolingParams,
arrival_time: float | None = None,
......@@ -530,7 +530,7 @@ class AsyncLLM(EngineClient):
self,
prompt: EngineCoreRequest
| PromptType
| ProcessorInputs
| EngineInput
| AsyncGenerator[StreamingInput, None],
sampling_params: SamplingParams,
request_id: str,
......@@ -776,7 +776,7 @@ class AsyncLLM(EngineClient):
async def encode(
self,
prompt: PromptType | ProcessorInputs,
prompt: PromptType | EngineInput,
pooling_params: PoolingParams,
request_id: str,
lora_request: LoRARequest | None = None,
......
......@@ -7,20 +7,18 @@ from typing import Any, Literal
import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.inputs.data import (
ProcessorInputs,
from vllm.inputs import (
EngineInput,
PromptType,
SingletonInputs,
SingletonInput,
split_enc_dec_input,
)
from vllm.inputs.parse import split_enc_dec_inputs
from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.multimodal.encoder_budget import MultiModalBudget
from vllm.multimodal.inputs import (
MultiModalFeatureSpec,
)
from vllm.multimodal.inputs import MultiModalFeatureSpec
from vllm.multimodal.utils import argsort_mm_positions
from vllm.platforms import current_platform
from vllm.pooling_params import PoolingParams
......@@ -197,7 +195,7 @@ class InputProcessor:
def process_inputs(
self,
request_id: str,
prompt: PromptType | ProcessorInputs,
prompt: PromptType | EngineInput,
params: SamplingParams | PoolingParams,
supported_tasks: tuple[SupportedTask, ...],
arrival_time: float | None = None,
......@@ -232,7 +230,7 @@ class InputProcessor:
if arrival_time is None:
arrival_time = prompt.get("arrival_time", time.time()) # type: ignore[assignment]
processed_inputs: ProcessorInputs = prompt # type: ignore[assignment]
processed_inputs: EngineInput = prompt # type: ignore[assignment]
else:
logger.warning_once(
"Passing raw prompts to InputProcessor is deprecated "
......@@ -250,7 +248,7 @@ class InputProcessor:
current_platform.validate_request(processed_inputs, params)
encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs)
encoder_inputs, decoder_inputs = split_enc_dec_input(processed_inputs)
self._validate_model_inputs(encoder_inputs, decoder_inputs)
# Mypy can be conservative for TypedDict unions; normalize access.
......@@ -385,7 +383,7 @@ class InputProcessor:
def _validate_model_input(
self,
prompt_inputs: SingletonInputs,
prompt_input: SingletonInput,
prompt_type: Literal["encoder", "decoder"],
) -> None:
model_config = self.model_config
......@@ -393,20 +391,18 @@ class InputProcessor:
prompt_ids = (
None
if prompt_inputs["type"] == "embeds"
else prompt_inputs["prompt_token_ids"]
if prompt_input["type"] == "embeds"
else prompt_input["prompt_token_ids"]
)
prompt_embeds = (
prompt_inputs["prompt_embeds"]
if prompt_inputs["type"] == "embeds"
else None
prompt_input["prompt_embeds"] if prompt_input["type"] == "embeds" else None
)
prompt_len = length_from_prompt_token_ids_or_embeds(prompt_ids, prompt_embeds)
self._validate_prompt_len(prompt_len, prompt_type)
if prompt_inputs["type"] == "multimodal":
decoder_mm_positions = prompt_inputs["mm_placeholders"]
if prompt_input["type"] == "multimodal":
decoder_mm_positions = prompt_input["mm_placeholders"]
for modality, mm_positions in decoder_mm_positions.items():
for mm_position in mm_positions:
embed_length = mm_position.get_num_embeds()
......@@ -439,10 +435,10 @@ class InputProcessor:
def _validate_model_inputs(
self,
encoder_inputs: SingletonInputs | None,
decoder_inputs: SingletonInputs,
encoder_input: SingletonInput | None,
decoder_input: SingletonInput,
):
if encoder_inputs is not None:
self._validate_model_input(encoder_inputs, prompt_type="encoder")
if encoder_input is not None:
self._validate_model_input(encoder_input, prompt_type="encoder")
self._validate_model_input(decoder_inputs, prompt_type="decoder")
self._validate_model_input(decoder_input, prompt_type="decoder")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment