Unverified Commit ba2f0acc authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Misc] Reorganize inputs (#35182)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent 678b3c99
...@@ -23,13 +23,13 @@ from transformers.models.whisper.modeling_whisper import ( ...@@ -23,13 +23,13 @@ from transformers.models.whisper.modeling_whisper import (
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions from vllm.config.multimodal import BaseDummyOptions
from vllm.inputs import MultiModalDataDict
from vllm.model_executor.layers.activation import MulAndSilu, get_act_fn from vllm.model_executor.layers.activation import MulAndSilu, get_act_fn
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.model_loader import DefaultModelLoader from vllm.model_executor.model_loader import DefaultModelLoader
from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import ( from vllm.multimodal.inputs import (
MultiModalDataDict,
MultiModalFieldConfig, MultiModalFieldConfig,
MultiModalKwargsItems, MultiModalKwargsItems,
NestedTensors, NestedTensors,
......
...@@ -19,7 +19,7 @@ from transformers import BatchFeature, WhisperConfig ...@@ -19,7 +19,7 @@ from transformers import BatchFeature, WhisperConfig
from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions from vllm.config.multimodal import BaseDummyOptions
from vllm.inputs.data import PromptType, TokensPrompt from vllm.inputs import MultiModalDataDict, PromptType, TokensPrompt
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
...@@ -32,7 +32,6 @@ from vllm.model_executor.models.whisper import ( ...@@ -32,7 +32,6 @@ from vllm.model_executor.models.whisper import (
from vllm.model_executor.models.whisper_causal import WhisperCausalEncoder from vllm.model_executor.models.whisper_causal import WhisperCausalEncoder
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import ( from vllm.multimodal.inputs import (
MultiModalDataDict,
MultiModalFieldConfig, MultiModalFieldConfig,
MultiModalKwargsItems, MultiModalKwargsItems,
NestedTensors, NestedTensors,
......
...@@ -20,7 +20,7 @@ from vllm.compilation.decorators import support_torch_compile ...@@ -20,7 +20,7 @@ from vllm.compilation.decorators import support_torch_compile
from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
from vllm.engine.protocol import StreamingInput from vllm.engine.protocol import StreamingInput
from vllm.envs import VLLM_ENGINE_ITERATION_TIMEOUT_S from vllm.envs import VLLM_ENGINE_ITERATION_TIMEOUT_S
from vllm.inputs.data import PromptType, TokensPrompt from vllm.inputs import PromptType, TokensPrompt
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.models.interfaces import MultiModalEmbeddings, SupportsRealtime from vllm.model_executor.models.interfaces import MultiModalEmbeddings, SupportsRealtime
from vllm.model_executor.models.voxtral import ( from vllm.model_executor.models.voxtral import (
...@@ -31,9 +31,7 @@ from vllm.model_executor.models.voxtral import ( ...@@ -31,9 +31,7 @@ from vllm.model_executor.models.voxtral import (
) )
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.cache import _I, BaseMultiModalProcessorCache from vllm.multimodal.cache import _I, BaseMultiModalProcessorCache
from vllm.multimodal.inputs import ( from vllm.multimodal.inputs import MultiModalKwargsOptionalItems
MultiModalKwargsOptionalItems,
)
from vllm.multimodal.parse import MultiModalDataItems from vllm.multimodal.parse import MultiModalDataItems
from vllm.multimodal.processing import BaseDummyInputsBuilder from vllm.multimodal.processing import BaseDummyInputsBuilder
from vllm.multimodal.processing.processor import ( from vllm.multimodal.processing.processor import (
......
...@@ -21,7 +21,12 @@ from vllm.compilation.decorators import support_torch_compile ...@@ -21,7 +21,12 @@ from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, ModelConfig, SpeechToTextConfig, VllmConfig from vllm.config import CacheConfig, ModelConfig, SpeechToTextConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.inputs.data import ExplicitEncoderDecoderPrompt, PromptType, TextPrompt from vllm.inputs import (
ExplicitEncoderDecoderPrompt,
MultiModalDataDict,
PromptType,
TextPrompt,
)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.attention import ( from vllm.model_executor.layers.attention import (
...@@ -44,7 +49,6 @@ from vllm.model_executor.models.whisper_utils import ( ...@@ -44,7 +49,6 @@ from vllm.model_executor.models.whisper_utils import (
) )
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import ( from vllm.multimodal.inputs import (
MultiModalDataDict,
MultiModalFieldConfig, MultiModalFieldConfig,
MultiModalKwargsItems, MultiModalKwargsItems,
) )
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from .hasher import MultiModalHasher from .hasher import MultiModalHasher
from .inputs import ( from .inputs import BatchedTensorInputs, MultiModalKwargsItems, NestedTensors
BatchedTensorInputs,
ModalityData,
MultiModalDataBuiltins,
MultiModalDataDict,
MultiModalKwargsItems,
MultiModalPlaceholderDict,
MultiModalUUIDDict,
NestedTensors,
)
from .registry import MultiModalRegistry from .registry import MultiModalRegistry
MULTIMODAL_REGISTRY = MultiModalRegistry() MULTIMODAL_REGISTRY = MultiModalRegistry()
...@@ -25,13 +16,8 @@ Info: ...@@ -25,13 +16,8 @@ Info:
__all__ = [ __all__ = [
"BatchedTensorInputs", "BatchedTensorInputs",
"ModalityData",
"MultiModalDataBuiltins",
"MultiModalDataDict",
"MultiModalHasher", "MultiModalHasher",
"MultiModalKwargsItems", "MultiModalKwargsItems",
"MultiModalPlaceholderDict",
"MultiModalUUIDDict",
"NestedTensors", "NestedTensors",
"MULTIMODAL_REGISTRY", "MULTIMODAL_REGISTRY",
"MultiModalRegistry", "MultiModalRegistry",
......
...@@ -15,12 +15,11 @@ from typing import ( ...@@ -15,12 +15,11 @@ from typing import (
TypedDict, TypedDict,
Union, Union,
cast, cast,
final,
) )
import numpy as np import numpy as np
from PIL.Image import Image from PIL.Image import Image
from typing_extensions import NotRequired, TypeVar from typing_extensions import TypeVar
from vllm.utils.collection_utils import is_list_of from vllm.utils.collection_utils import is_list_of
from vllm.utils.import_utils import LazyLoader from vllm.utils.import_utils import LazyLoader
...@@ -32,14 +31,9 @@ if TYPE_CHECKING: ...@@ -32,14 +31,9 @@ if TYPE_CHECKING:
import torch import torch
import torch.types import torch.types
from transformers.feature_extraction_utils import BatchFeature from transformers.feature_extraction_utils import BatchFeature
from vllm.inputs.data import _InputOptions
else: else:
torch = LazyLoader("torch", globals(), "torch") torch = LazyLoader("torch", globals(), "torch")
_InputOptions = dict
_T = TypeVar("_T")
HfImageItem: TypeAlias = Union["Image", np.ndarray, "torch.Tensor"] HfImageItem: TypeAlias = Union["Image", np.ndarray, "torch.Tensor"]
""" """
...@@ -98,15 +92,6 @@ which are treated as audio embeddings; ...@@ -98,15 +92,6 @@ which are treated as audio embeddings;
these are directly passed to the model without HF processing. these are directly passed to the model without HF processing.
""" """
ModalityData: TypeAlias = _T | list[_T | None] | None
"""
Either a single data item, or a list of data items. Can only be None if UUID
is provided.
The number of data items allowed per modality is restricted by
`--limit-mm-per-prompt`.
"""
class VisionChunkImage(TypedDict): class VisionChunkImage(TypedDict):
"""Represents an image wrapped as a vision chunk.""" """Represents an image wrapped as a vision chunk."""
...@@ -126,46 +111,10 @@ class VisionChunkVideo(TypedDict): ...@@ -126,46 +111,10 @@ class VisionChunkVideo(TypedDict):
video_idx: int video_idx: int
VisionChunk = VisionChunkImage | VisionChunkVideo VisionChunk: TypeAlias = VisionChunkImage | VisionChunkVideo
"""A vision chunk is either an image or a video chunk.""" """A vision chunk is either an image or a video chunk."""
@final
class MultiModalDataBuiltins(TypedDict, total=False):
"""Type annotations for modality types predefined by vLLM."""
image: ModalityData[ImageItem]
"""The input image(s)."""
video: ModalityData[VideoItem]
"""The input video(s)."""
audio: ModalityData[AudioItem]
"""The input audio(s)."""
vision_chunk: ModalityData[VisionChunk]
"""The input visual atom(s) - unified modality for images and video chunks."""
MultiModalDataDict: TypeAlias = Mapping[str, ModalityData[Any]]
"""
A dictionary containing an entry for each modality type to input.
The built-in modalities are defined by
[`MultiModalDataBuiltins`][vllm.multimodal.inputs.MultiModalDataBuiltins].
"""
MultiModalUUIDDict: TypeAlias = Mapping[str, Sequence[str | None] | str]
"""
A dictionary containing user-provided UUIDs for items in each modality.
If a UUID for an item is not provided, its entry will be `None` and
MultiModalHasher will compute a hash for the item.
The UUID will be used to identify the item for all caching purposes
(input processing caching, embedding caching, prefix caching, etc).
"""
@dataclass(frozen=True) @dataclass(frozen=True)
class PlaceholderRange: class PlaceholderRange:
""" """
...@@ -1048,112 +997,3 @@ MultiModalKwargsOptionalItems: TypeAlias = ( ...@@ -1048,112 +997,3 @@ MultiModalKwargsOptionalItems: TypeAlias = (
MultiModalKwargsItems[MultiModalKwargsItem] MultiModalKwargsItems[MultiModalKwargsItem]
| MultiModalKwargsItems[MultiModalKwargsItem | None] | MultiModalKwargsItems[MultiModalKwargsItem | None]
) )
MultiModalHashes = dict[str, list[str]]
"""
A dictionary containing per-item hashes for each modality.
"""
MultiModalPlaceholderDict: TypeAlias = Mapping[str, Sequence[PlaceholderRange]]
"""
A dictionary containing per-item placeholder ranges for each modality.
"""
class MultiModalInputs(_InputOptions):
"""
Represents the outputs of
[`BaseMultiModalProcessor`][vllm.multimodal.processing.BaseMultiModalProcessor],
ready to be passed to vLLM internals.
"""
type: Literal["multimodal"]
"""The type of inputs."""
prompt_token_ids: list[int]
"""The processed token IDs which includes placeholder tokens."""
prompt: NotRequired[str]
"""The prompt text corresponding to the token IDs, if available."""
mm_kwargs: MultiModalKwargsOptionalItems
"""Keyword arguments to be directly passed to the model after batching."""
mm_hashes: MultiModalHashes
"""The hashes of the multi-modal data."""
mm_placeholders: MultiModalPlaceholderDict
"""
For each modality, information about the placeholder tokens in
`prompt_token_ids`.
"""
def mm_inputs(
prompt_token_ids: list[int],
mm_kwargs: MultiModalKwargsOptionalItems,
mm_hashes: MultiModalHashes,
mm_placeholders: MultiModalPlaceholderDict,
*,
prompt: str | None = None,
cache_salt: str | None = None,
) -> MultiModalInputs:
inputs = MultiModalInputs(
type="multimodal",
prompt_token_ids=prompt_token_ids,
mm_kwargs=mm_kwargs,
mm_hashes=mm_hashes,
mm_placeholders=mm_placeholders,
)
if prompt is not None:
inputs["prompt"] = prompt
if cache_salt is not None:
inputs["cache_salt"] = cache_salt
return inputs
class MultiModalEncDecInputs(MultiModalInputs):
"""
Represents the outputs of
[`EncDecMultiModalProcessor`][vllm.multimodal.processing.EncDecMultiModalProcessor]
ready to be passed to vLLM internals.
Note: Even text-only encoder-decoder models are currently implemented
as multi-modal models for convenience.
(Example: https://github.com/vllm-project/bart-plugin)
"""
encoder_prompt_token_ids: list[int]
"""The processed token IDs of the encoder prompt."""
encoder_prompt: NotRequired[str]
"""The prompt text corresponding to the encoder token IDs, if available."""
def mm_enc_dec_inputs(
encoder_inputs: MultiModalInputs,
decoder_prompt_token_ids: list[int],
*,
decoder_prompt: str | None = None,
) -> MultiModalEncDecInputs:
inputs = MultiModalEncDecInputs(
type="multimodal",
prompt_token_ids=decoder_prompt_token_ids,
encoder_prompt_token_ids=encoder_inputs["prompt_token_ids"],
mm_kwargs=encoder_inputs["mm_kwargs"],
mm_hashes=encoder_inputs["mm_hashes"],
mm_placeholders=encoder_inputs["mm_placeholders"],
)
if decoder_prompt is not None:
inputs["prompt"] = decoder_prompt
if "prompt" in encoder_inputs:
inputs["encoder_prompt"] = encoder_inputs["prompt"]
if "cache_salt" in encoder_inputs:
inputs["cache_salt"] = encoder_inputs["cache_salt"]
return inputs
...@@ -19,6 +19,7 @@ import numpy as np ...@@ -19,6 +19,7 @@ import numpy as np
import torch import torch
from typing_extensions import assert_never from typing_extensions import assert_never
from vllm.inputs import ModalityData, MultiModalDataDict, MultiModalUUIDDict
from vllm.utils.collection_utils import is_list_of from vllm.utils.collection_utils import is_list_of
from vllm.utils.import_utils import LazyLoader from vllm.utils.import_utils import LazyLoader
...@@ -29,11 +30,8 @@ from .inputs import ( ...@@ -29,11 +30,8 @@ from .inputs import (
HfImageItem, HfImageItem,
HfVideoItem, HfVideoItem,
ImageItem, ImageItem,
ModalityData,
MultiModalDataDict,
MultiModalFieldConfig, MultiModalFieldConfig,
MultiModalKwargsItems, MultiModalKwargsItems,
MultiModalUUIDDict,
VideoItem, VideoItem,
) )
from .media import MediaWithBytes from .media import MediaWithBytes
...@@ -407,8 +405,8 @@ _D = TypeVar("_D", bound=ModalityDataItems[Any, Any]) ...@@ -407,8 +405,8 @@ _D = TypeVar("_D", bound=ModalityDataItems[Any, Any])
class MultiModalDataItems(UserDict[str, ModalityDataItems[Any, Any]]): class MultiModalDataItems(UserDict[str, ModalityDataItems[Any, Any]]):
""" """
As [`MultiModalDataDict`][vllm.multimodal.inputs.MultiModalDataDict], but A normalized [`MultiModalDataDict`][vllm.inputs.MultiModalDataDict]
normalized such that each entry corresponds to a list. such that each entry corresponds to a list.
""" """
def select(self, modalities: Set[str]): def select(self, modalities: Set[str]):
...@@ -477,7 +475,7 @@ ModalityDataParser: TypeAlias = Callable[ ...@@ -477,7 +475,7 @@ ModalityDataParser: TypeAlias = Callable[
class MultiModalDataParser: class MultiModalDataParser:
""" """
Parses [`MultiModalDataDict`][vllm.multimodal.inputs.MultiModalDataDict] Parses [`MultiModalDataDict`][vllm.inputs.MultiModalDataDict]
into [`MultiModalDataItems`][vllm.multimodal.parse.MultiModalDataItems]. into [`MultiModalDataItems`][vllm.multimodal.parse.MultiModalDataItems].
Args: Args:
...@@ -695,8 +693,8 @@ class MultiModalDataParser: ...@@ -695,8 +693,8 @@ class MultiModalDataParser:
MultiModalUUIDItems: TypeAlias = dict[str, Sequence[str | None]] MultiModalUUIDItems: TypeAlias = dict[str, Sequence[str | None]]
""" """
As [`MultiModalUUIDDict`][vllm.multimodal.inputs.MultiModalUUIDDict], but A normalized [`MultiModalUUIDDict`][vllm.inputs.MultiModalUUIDDict]
normalized such that each entry corresponds to a list. such that each entry corresponds to a list.
""" """
......
...@@ -11,15 +11,14 @@ from typing import TYPE_CHECKING, Any, overload ...@@ -11,15 +11,14 @@ from typing import TYPE_CHECKING, Any, overload
import torch import torch
from typing_extensions import TypeVar from typing_extensions import TypeVar
from vllm.inputs import MultiModalDataDict
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.multimodal.inputs import MultiModalDataDict
from vllm.multimodal.parse import ( from vllm.multimodal.parse import (
DictEmbeddingItems, DictEmbeddingItems,
EmbeddingItems, EmbeddingItems,
MultiModalDataItems, MultiModalDataItems,
MultiModalDataParser, MultiModalDataParser,
) )
from vllm.renderers import TokenizeParams
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
from vllm.transformers_utils.processor import cached_processor_from_config from vllm.transformers_utils.processor import cached_processor_from_config
from vllm.utils.func_utils import get_allowed_kwarg_only_overrides from vllm.utils.func_utils import get_allowed_kwarg_only_overrides
...@@ -32,12 +31,14 @@ if TYPE_CHECKING: ...@@ -32,12 +31,14 @@ if TYPE_CHECKING:
from transformers.processing_utils import ProcessorMixin from transformers.processing_utils import ProcessorMixin
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.renderers import TokenizeParams
else: else:
PretrainedConfig = object PretrainedConfig = object
BatchFeature = object BatchFeature = object
ProcessorMixin = object ProcessorMixin = object
ModelConfig = object ModelConfig = object
TokenizeParams = object
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -339,6 +340,8 @@ class BaseProcessingInfo: ...@@ -339,6 +340,8 @@ class BaseProcessingInfo:
def get_default_tok_params(self) -> TokenizeParams: def get_default_tok_params(self) -> TokenizeParams:
"""Construct the default parameters for tokenization.""" """Construct the default parameters for tokenization."""
from vllm.renderers import TokenizeParams
model_config = self.ctx.model_config model_config = self.ctx.model_config
encoder_config = model_config.encoder_config or {} encoder_config = model_config.encoder_config or {}
...@@ -451,8 +454,7 @@ class BaseProcessingInfo: ...@@ -451,8 +454,7 @@ class BaseProcessingInfo:
validate: bool = True, validate: bool = True,
) -> MultiModalDataItems: ) -> MultiModalDataItems:
""" """
Normalize Normalize [`MultiModalDataDict`][vllm.inputs.MultiModalDataDict]
[`MultiModalDataDict`][vllm.multimodal.inputs.MultiModalDataDict]
to [`MultiModalDataItems`][vllm.multimodal.parse.MultiModalDataItems] to [`MultiModalDataItems`][vllm.multimodal.parse.MultiModalDataItems]
before passing them to before passing them to
[`_get_hf_mm_data`][vllm.multimodal.processing.BaseMultiModalProcessor._get_hf_mm_data]. [`_get_hf_mm_data`][vllm.multimodal.processing.BaseMultiModalProcessor._get_hf_mm_data].
......
...@@ -14,9 +14,9 @@ from vllm.config.multimodal import ( ...@@ -14,9 +14,9 @@ from vllm.config.multimodal import (
ImageDummyOptions, ImageDummyOptions,
VideoDummyOptions, VideoDummyOptions,
) )
from vllm.inputs import MultiModalDataDict
from vllm.logger import init_logger from vllm.logger import init_logger
from ..inputs import MultiModalDataDict
from .context import BaseProcessingInfo from .context import BaseProcessingInfo
from .inputs import ProcessorInputs from .inputs import ProcessorInputs
......
...@@ -3,8 +3,9 @@ ...@@ -3,8 +3,9 @@
from collections.abc import Mapping from collections.abc import Mapping
from dataclasses import dataclass, field from dataclasses import dataclass, field
from vllm.inputs import MultiModalHashes
from ..hasher import MultiModalHasher from ..hasher import MultiModalHasher
from ..inputs import MultiModalHashes
from ..parse import MultiModalDataItems, MultiModalUUIDItems from ..parse import MultiModalDataItems, MultiModalUUIDItems
...@@ -26,7 +27,7 @@ class ProcessorInputs: ...@@ -26,7 +27,7 @@ class ProcessorInputs:
mm_uuid_items = self.mm_uuid_items or {} mm_uuid_items = self.mm_uuid_items or {}
hf_processor_mm_kwargs = self.hf_processor_mm_kwargs hf_processor_mm_kwargs = self.hf_processor_mm_kwargs
mm_hashes: MultiModalHashes = {} mm_hashes = dict[str, list[str]]()
hasher = MultiModalHasher hasher = MultiModalHasher
for modality, data_items in mm_data_items.items(): for modality, data_items in mm_data_items.items():
......
...@@ -6,34 +6,29 @@ from collections.abc import Callable, Generator, ItemsView, Iterable, Mapping, S ...@@ -6,34 +6,29 @@ from collections.abc import Callable, Generator, ItemsView, Iterable, Mapping, S
from dataclasses import dataclass, field, replace from dataclasses import dataclass, field, replace
from enum import Enum from enum import Enum
from functools import lru_cache from functools import lru_cache
from typing import ( from typing import TYPE_CHECKING, Generic, NamedTuple, Protocol, TypeAlias, cast
TYPE_CHECKING,
Generic,
NamedTuple,
Protocol,
TypeAlias,
cast,
)
import regex as re import regex as re
import torch import torch
from typing_extensions import TypeVar, assert_never from typing_extensions import TypeVar, assert_never
from vllm.inputs import (
MultiModalEncDecInput,
MultiModalHashes,
MultiModalInput,
mm_enc_dec_input,
mm_input,
)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
from vllm.utils.collection_utils import flatten_2d_lists, full_groupby from vllm.utils.collection_utils import flatten_2d_lists, full_groupby
from ..inputs import ( from ..inputs import (
MultiModalEncDecInputs,
MultiModalFieldConfig, MultiModalFieldConfig,
MultiModalHashes,
MultiModalInputs,
MultiModalKwargsItem, MultiModalKwargsItem,
MultiModalKwargsItems, MultiModalKwargsItems,
MultiModalKwargsOptionalItems, MultiModalKwargsOptionalItems,
PlaceholderRange, PlaceholderRange,
mm_enc_dec_inputs,
mm_inputs,
) )
from ..parse import ( from ..parse import (
DictEmbeddingItems, DictEmbeddingItems,
...@@ -994,7 +989,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -994,7 +989,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
mm_items: MultiModalDataItems, mm_items: MultiModalDataItems,
mm_uuid_items: MultiModalUUIDItems | None = None, mm_uuid_items: MultiModalUUIDItems | None = None,
hf_processor_mm_kwargs: Mapping[str, object] | None = None, hf_processor_mm_kwargs: Mapping[str, object] | None = None,
) -> MultiModalInputs: ) -> MultiModalInput:
processor_inputs = ProcessorInputs( processor_inputs = ProcessorInputs(
prompt, prompt,
mm_items, mm_items,
...@@ -1638,7 +1633,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1638,7 +1633,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
self, self,
inputs: ProcessorInputs, inputs: ProcessorInputs,
timing_ctx: TimingContext, timing_ctx: TimingContext,
) -> MultiModalInputs: ) -> MultiModalInput:
""" """
Process multi-modal inputs to be used in vLLM. Process multi-modal inputs to be used in vLLM.
...@@ -1673,7 +1668,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1673,7 +1668,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
for modality, placeholders in mm_placeholders.items() for modality, placeholders in mm_placeholders.items()
} }
return mm_inputs( return mm_input(
prompt_token_ids=prompt_ids, prompt_token_ids=prompt_ids,
mm_kwargs=mm_info.kwargs, mm_kwargs=mm_info.kwargs,
mm_hashes=mm_info.hashes, mm_hashes=mm_info.hashes,
...@@ -1708,7 +1703,7 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]): ...@@ -1708,7 +1703,7 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]):
self, self,
prompt: str | list[int], prompt: str | list[int],
mm_items: MultiModalDataItems, mm_items: MultiModalDataItems,
encoder_inputs: MultiModalInputs, encoder_inputs: MultiModalInput,
): ):
tokenizer = self.info.get_tokenizer() tokenizer = self.info.get_tokenizer()
decoder_prompt_raw = self.create_decoder_prompt(prompt, mm_items) decoder_prompt_raw = self.create_decoder_prompt(prompt, mm_items)
...@@ -1721,7 +1716,7 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]): ...@@ -1721,7 +1716,7 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]):
decoder_prompt_text = None decoder_prompt_text = None
decoder_prompt_ids = decoder_prompt_raw decoder_prompt_ids = decoder_prompt_raw
return mm_enc_dec_inputs( return mm_enc_dec_input(
encoder_inputs, encoder_inputs,
decoder_prompt_ids, decoder_prompt_ids,
decoder_prompt=decoder_prompt_text, decoder_prompt=decoder_prompt_text,
...@@ -1731,7 +1726,7 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]): ...@@ -1731,7 +1726,7 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]):
self, self,
inputs: ProcessorInputs, inputs: ProcessorInputs,
timing_ctx: TimingContext, timing_ctx: TimingContext,
) -> MultiModalEncDecInputs: ) -> MultiModalEncDecInput:
""" """
Process multi-modal inputs to be used in vLLM. Process multi-modal inputs to be used in vLLM.
The main processing steps are modified to fit encoder-decoder model: The main processing steps are modified to fit encoder-decoder model:
......
...@@ -7,6 +7,7 @@ from dataclasses import dataclass ...@@ -7,6 +7,7 @@ from dataclasses import dataclass
from multiprocessing.synchronize import Lock as LockType from multiprocessing.synchronize import Lock as LockType
from typing import TYPE_CHECKING, Generic, Literal, Protocol, TypeVar, cast from typing import TYPE_CHECKING, Generic, Literal, Protocol, TypeVar, cast
from vllm.inputs import MultiModalInput
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike, cached_tokenizer_from_config from vllm.tokenizers import TokenizerLike, cached_tokenizer_from_config
...@@ -19,7 +20,6 @@ from .cache import ( ...@@ -19,7 +20,6 @@ from .cache import (
ShmObjectStoreReceiverCache, ShmObjectStoreReceiverCache,
ShmObjectStoreSenderCache, ShmObjectStoreSenderCache,
) )
from .inputs import MultiModalInputs
from .processing import ( from .processing import (
BaseDummyInputsBuilder, BaseDummyInputsBuilder,
BaseMultiModalProcessor, BaseMultiModalProcessor,
...@@ -220,7 +220,7 @@ class MultiModalRegistry: ...@@ -220,7 +220,7 @@ class MultiModalRegistry:
*, *,
cache: BaseMultiModalProcessorCache | None = None, cache: BaseMultiModalProcessorCache | None = None,
processor: BaseMultiModalProcessor | None = None, processor: BaseMultiModalProcessor | None = None,
) -> MultiModalInputs: ) -> MultiModalInput:
""" """
Create dummy data for profiling the memory usage of a model. Create dummy data for profiling the memory usage of a model.
......
...@@ -12,6 +12,7 @@ import numpy.typing as npt ...@@ -12,6 +12,7 @@ import numpy.typing as npt
from PIL import Image from PIL import Image
from typing_extensions import deprecated from typing_extensions import deprecated
from vllm.inputs import MultiModalPlaceholders
from vllm.utils.import_utils import LazyLoader from vllm.utils.import_utils import LazyLoader
from .hasher import MultiModalHasher from .hasher import MultiModalHasher
...@@ -19,7 +20,6 @@ from .inputs import ( ...@@ -19,7 +20,6 @@ from .inputs import (
BatchedTensorInputs, BatchedTensorInputs,
MultiModalFieldElem, MultiModalFieldElem,
MultiModalKwargsItem, MultiModalKwargsItem,
MultiModalPlaceholderDict,
MultiModalSharedField, MultiModalSharedField,
) )
from .media import AudioMediaIO, ImageMediaIO, MediaConnector, VideoMediaIO from .media import AudioMediaIO, ImageMediaIO, MediaConnector, VideoMediaIO
...@@ -110,10 +110,10 @@ def encode_video_url( ...@@ -110,10 +110,10 @@ def encode_video_url(
def argsort_mm_positions( def argsort_mm_positions(
mm_positions: MultiModalPlaceholderDict, mm_positions: MultiModalPlaceholders,
) -> list[tuple[str, int]]: ) -> list[tuple[str, int]]:
""" """
Given a `MultiModalPlaceholderDict`, output a sequence of keys to Given a `MultiModalPlaceholders`, output a sequence of keys to
sort the dictionary by `offset` (starting index in the input sequence) sort the dictionary by `offset` (starting index in the input sequence)
in ascending order. in ascending order.
......
...@@ -17,7 +17,7 @@ if TYPE_CHECKING: ...@@ -17,7 +17,7 @@ if TYPE_CHECKING:
from torch.distributed import PrefixStore, ProcessGroup from torch.distributed import PrefixStore, ProcessGroup
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.inputs import ProcessorInputs from vllm.inputs import EngineInput
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.utils.argparse_utils import FlexibleArgumentParser from vllm.utils.argparse_utils import FlexibleArgumentParser
...@@ -635,7 +635,7 @@ class Platform: ...@@ -635,7 +635,7 @@ class Platform:
@classmethod @classmethod
def validate_request( def validate_request(
cls, cls,
processed_inputs: "ProcessorInputs", processed_inputs: "EngineInput",
params: "SamplingParams | PoolingParams", params: "SamplingParams | PoolingParams",
) -> None: ) -> None:
"""Raises if this request is unsupported on this platform""" """Raises if this request is unsupported on this platform"""
......
...@@ -6,7 +6,7 @@ from collections.abc import AsyncGenerator, Sequence ...@@ -6,7 +6,7 @@ from collections.abc import AsyncGenerator, Sequence
from typing import Generic, TypeVar from typing import Generic, TypeVar
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.inputs.data import PromptType from vllm.inputs import PromptType
from vllm.outputs import PoolingRequestOutput from vllm.outputs import PoolingRequestOutput
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.renderers import BaseRenderer from vllm.renderers import BaseRenderer
......
...@@ -11,17 +11,32 @@ from typing import TYPE_CHECKING, Any, Generic, overload ...@@ -11,17 +11,32 @@ from typing import TYPE_CHECKING, Any, Generic, overload
from typing_extensions import TypeVar from typing_extensions import TypeVar
from vllm.inputs import ( from vllm.inputs import (
EmbedsInputs, EmbedsInput,
EmbedsPrompt, EmbedsPrompt,
EncoderDecoderInputs, EncoderDecoderInput,
ProcessorInputs, EngineInput,
SingletonInputs, MultiModalDataDict,
MultiModalInput,
MultiModalUUIDDict,
SingletonInput,
TextPrompt, TextPrompt,
TokenInputs, TokensInput,
TokensPrompt, TokensPrompt,
build_enc_dec_input,
embeds_input,
tokens_input,
) )
from vllm.inputs.data import build_enc_dec_inputs, embeds_inputs, token_inputs
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.multimodal import MULTIMODAL_REGISTRY as mm_registry
from vllm.multimodal.cache import BaseMultiModalProcessorCache
from vllm.multimodal.parse import (
MultiModalDataItems,
MultiModalUUIDItems,
parse_mm_uuids,
)
from vllm.multimodal.processing import BaseMultiModalProcessor
from vllm.multimodal.processing import ProcessorInputs as MMProcessorInputs
from vllm.multimodal.registry import MultiModalTimingRegistry
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
from vllm.utils.async_utils import AsyncMicrobatchTokenizer from vllm.utils.async_utils import AsyncMicrobatchTokenizer
from vllm.utils.counter import AtomicCounter from vllm.utils.counter import AtomicCounter
...@@ -46,14 +61,6 @@ if TYPE_CHECKING: ...@@ -46,14 +61,6 @@ if TYPE_CHECKING:
ChatCompletionMessageParam, ChatCompletionMessageParam,
ConversationMessage, ConversationMessage,
) )
from vllm.multimodal.cache import BaseMultiModalProcessorCache
from vllm.multimodal.inputs import (
MultiModalDataDict,
MultiModalInputs,
MultiModalUUIDDict,
)
from vllm.multimodal.parse import MultiModalDataItems, MultiModalUUIDItems
from vllm.multimodal.processing import BaseMultiModalProcessor
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -86,9 +93,6 @@ class BaseRenderer(ABC, Generic[_T]): ...@@ -86,9 +93,6 @@ class BaseRenderer(ABC, Generic[_T]):
self.mm_processor: BaseMultiModalProcessor | None = None self.mm_processor: BaseMultiModalProcessor | None = None
self._mm_cache_stats: MultiModalCacheStats | None = None self._mm_cache_stats: MultiModalCacheStats | None = None
if config.model_config.is_multimodal_model: if config.model_config.is_multimodal_model:
from vllm.multimodal import MULTIMODAL_REGISTRY as mm_registry
from vllm.multimodal.registry import MultiModalTimingRegistry
mm_processor_cache = mm_registry.processor_cache_from_config(config) mm_processor_cache = mm_registry.processor_cache_from_config(config)
# Deep-copy the tokenizer so the multimodal processor gets its # Deep-copy the tokenizer so the multimodal processor gets its
...@@ -524,9 +528,9 @@ class BaseRenderer(ABC, Generic[_T]): ...@@ -524,9 +528,9 @@ class BaseRenderer(ABC, Generic[_T]):
# Step 4: Convert to engine inputs # Step 4: Convert to engine inputs
def _validate_mm_uuids( def _validate_mm_uuids(
self, self,
mm_data: "MultiModalDataDict", mm_data: MultiModalDataDict,
mm_data_items: "MultiModalDataItems", mm_data_items: MultiModalDataItems,
mm_uuid_items: "MultiModalUUIDItems", mm_uuid_items: MultiModalUUIDItems,
) -> None: ) -> None:
# NOTE: Keys corresponding to `None` in `mm_data` don't appear in # NOTE: Keys corresponding to `None` in `mm_data` don't appear in
# `mm_data_items` # `mm_data_items`
...@@ -560,11 +564,11 @@ class BaseRenderer(ABC, Generic[_T]): ...@@ -560,11 +564,11 @@ class BaseRenderer(ABC, Generic[_T]):
def _process_mm_uuids( def _process_mm_uuids(
self, self,
mm_data: "MultiModalDataDict", mm_data: MultiModalDataDict,
mm_data_items: "MultiModalDataItems", mm_data_items: MultiModalDataItems,
mm_uuid_items: "MultiModalUUIDItems", mm_uuid_items: MultiModalUUIDItems,
mm_req_id: str, mm_req_id: str,
): ) -> MultiModalUUIDItems:
model_config = self.model_config model_config = self.model_config
# NOTE: When users explicitly turn off BOTH prefix caching and input # NOTE: When users explicitly turn off BOTH prefix caching and input
...@@ -590,14 +594,11 @@ class BaseRenderer(ABC, Generic[_T]): ...@@ -590,14 +594,11 @@ class BaseRenderer(ABC, Generic[_T]):
def _process_multimodal( def _process_multimodal(
self, self,
prompt: list[int] | str, prompt: list[int] | str,
mm_data: "MultiModalDataDict", mm_data: MultiModalDataDict,
mm_uuids: "MultiModalUUIDDict | None", mm_uuids: MultiModalUUIDDict | None,
mm_processor_kwargs: Mapping[str, object] | None, mm_processor_kwargs: Mapping[str, object] | None,
tokenization_kwargs: dict[str, Any] | None, tokenization_kwargs: dict[str, Any] | None,
) -> "MultiModalInputs": ) -> "MultiModalInput":
from vllm.multimodal.parse import parse_mm_uuids
from vllm.multimodal.processing import ProcessorInputs as MMProcessorInputs
mm_req_id = f"renderer{self.api_process_rank}-mm-{self._mm_req_counter.inc(1)}" mm_req_id = f"renderer{self.api_process_rank}-mm-{self._mm_req_counter.inc(1)}"
mm_processor = self.get_mm_processor() mm_processor = self.get_mm_processor()
...@@ -628,12 +629,12 @@ class BaseRenderer(ABC, Generic[_T]): ...@@ -628,12 +629,12 @@ class BaseRenderer(ABC, Generic[_T]):
def _process_tokens( def _process_tokens(
self, self,
prompt: TokensPrompt, prompt: TokensPrompt,
) -> "TokenInputs | MultiModalInputs": ) -> TokensInput | MultiModalInput:
prompt_token_ids = prompt["prompt_token_ids"] prompt_token_ids = prompt["prompt_token_ids"]
inputs: TokenInputs | MultiModalInputs engine_input: TokensInput | MultiModalInput
if multi_modal_data := prompt.get("multi_modal_data"): if multi_modal_data := prompt.get("multi_modal_data"):
inputs = self._process_multimodal( engine_input = self._process_multimodal(
prompt_token_ids, prompt_token_ids,
multi_modal_data, multi_modal_data,
mm_processor_kwargs=prompt.get("mm_processor_kwargs"), mm_processor_kwargs=prompt.get("mm_processor_kwargs"),
...@@ -641,19 +642,16 @@ class BaseRenderer(ABC, Generic[_T]): ...@@ -641,19 +642,16 @@ class BaseRenderer(ABC, Generic[_T]):
mm_uuids=prompt.get("multi_modal_uuids"), mm_uuids=prompt.get("multi_modal_uuids"),
) )
else: else:
inputs = token_inputs(prompt_token_ids) engine_input = tokens_input(prompt_token_ids)
if prompt_text := prompt.get("prompt"): if prompt_text := prompt.get("prompt"):
inputs["prompt"] = prompt_text engine_input["prompt"] = prompt_text
if cache_salt := prompt.get("cache_salt"): if cache_salt := prompt.get("cache_salt"):
inputs["cache_salt"] = cache_salt engine_input["cache_salt"] = cache_salt
return inputs return engine_input
def _process_embeds( def _process_embeds(self, prompt: EmbedsPrompt) -> EmbedsInput:
self,
prompt: EmbedsPrompt,
) -> EmbedsInputs:
if not self.model_config.enable_prompt_embeds: if not self.model_config.enable_prompt_embeds:
raise ValueError( raise ValueError(
"You must set `--enable-prompt-embeds` to input `prompt_embeds`." "You must set `--enable-prompt-embeds` to input `prompt_embeds`."
...@@ -676,15 +674,12 @@ class BaseRenderer(ABC, Generic[_T]): ...@@ -676,15 +674,12 @@ class BaseRenderer(ABC, Generic[_T]):
# hidden device transfer in the critical path of generation. # hidden device transfer in the critical path of generation.
prompt_embeds = prompt_embeds.cpu() prompt_embeds = prompt_embeds.cpu()
return embeds_inputs( return embeds_input(
prompt_embeds=prompt_embeds, prompt_embeds=prompt_embeds,
cache_salt=prompt.get("cache_salt"), cache_salt=prompt.get("cache_salt"),
) )
def _process_singleton( def _process_singleton(self, prompt: SingletonTokPrompt) -> SingletonInput:
self,
prompt: SingletonTokPrompt,
) -> SingletonInputs:
if "prompt_embeds" in prompt: if "prompt_embeds" in prompt:
return self._process_embeds(prompt) # type: ignore[arg-type] return self._process_embeds(prompt) # type: ignore[arg-type]
...@@ -693,7 +688,7 @@ class BaseRenderer(ABC, Generic[_T]): ...@@ -693,7 +688,7 @@ class BaseRenderer(ABC, Generic[_T]):
def _process_enc_dec( def _process_enc_dec(
self, self,
prompt: EncoderDecoderTokPrompt, prompt: EncoderDecoderTokPrompt,
) -> EncoderDecoderInputs: ) -> EncoderDecoderInput:
enc_prompt = prompt["encoder_prompt"] enc_prompt = prompt["encoder_prompt"]
dec_prompt = prompt["decoder_prompt"] dec_prompt = prompt["decoder_prompt"]
...@@ -704,27 +699,25 @@ class BaseRenderer(ABC, Generic[_T]): ...@@ -704,27 +699,25 @@ class BaseRenderer(ABC, Generic[_T]):
if isinstance(self.mm_processor, EncDecMultiModalProcessor): if isinstance(self.mm_processor, EncDecMultiModalProcessor):
skip_decoder_start_token = self.mm_processor.skip_decoder_start_token skip_decoder_start_token = self.mm_processor.skip_decoder_start_token
return build_enc_dec_inputs( return build_enc_dec_input(
encoder_inputs=self._process_singleton(enc_prompt), encoder_input=self._process_singleton(enc_prompt),
decoder_inputs=( decoder_input=(
None if dec_prompt is None else self._process_singleton(dec_prompt) None if dec_prompt is None else self._process_singleton(dec_prompt)
), ),
decoder_start_token_id=self.get_dec_start_token_id(), decoder_start_token_id=self.get_dec_start_token_id(),
skip_decoder_start_token=skip_decoder_start_token, skip_decoder_start_token=skip_decoder_start_token,
) )
def process_for_engine( def process_for_engine(self, prompt: TokPrompt, arrival_time: float) -> EngineInput:
self, prompt: TokPrompt, arrival_time: float engine_input: EngineInput
) -> ProcessorInputs:
engine_prompt: ProcessorInputs
if "encoder_prompt" in prompt: if "encoder_prompt" in prompt:
engine_prompt = self._process_enc_dec(prompt) # type: ignore[arg-type] engine_input = self._process_enc_dec(prompt) # type: ignore[arg-type]
else: else:
engine_prompt = self._process_singleton(prompt) engine_input = self._process_singleton(prompt)
engine_prompt["arrival_time"] = arrival_time engine_input["arrival_time"] = arrival_time
return engine_prompt return engine_input
# Top-level methods # Top-level methods
def render_cmpl( def render_cmpl(
......
...@@ -5,7 +5,7 @@ import itertools ...@@ -5,7 +5,7 @@ import itertools
from collections import defaultdict, deque from collections import defaultdict, deque
from collections.abc import Set from collections.abc import Set
from functools import lru_cache from functools import lru_cache
from typing import TYPE_CHECKING, Any, Literal, cast, overload from typing import Any, Literal, cast, overload
import jinja2 import jinja2
import jinja2.ext import jinja2.ext
...@@ -25,6 +25,7 @@ from vllm.entrypoints.chat_utils import ( ...@@ -25,6 +25,7 @@ from vllm.entrypoints.chat_utils import (
parse_chat_messages, parse_chat_messages,
parse_chat_messages_async, parse_chat_messages_async,
) )
from vllm.inputs import MultiModalDataDict, MultiModalUUIDDict
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.tokenizers import cached_get_tokenizer from vllm.tokenizers import cached_get_tokenizer
from vllm.tokenizers.hf import CachedHfTokenizer, HfTokenizer from vllm.tokenizers.hf import CachedHfTokenizer, HfTokenizer
...@@ -37,13 +38,6 @@ from .inputs import DictPrompt ...@@ -37,13 +38,6 @@ from .inputs import DictPrompt
from .inputs.preprocess import parse_dec_only_prompt from .inputs.preprocess import parse_dec_only_prompt
from .params import ChatParams from .params import ChatParams
if TYPE_CHECKING:
from vllm.multimodal.inputs import MultiModalDataDict, MultiModalUUIDDict
else:
MultiModalDataDict = dict[str, Any]
MultiModalUUIDDict = dict[str, Any]
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -512,9 +506,9 @@ def safe_apply_chat_template( ...@@ -512,9 +506,9 @@ def safe_apply_chat_template(
def rebuild_mm_uuids_from_mm_data( def rebuild_mm_uuids_from_mm_data(
mm_uuids: "MultiModalUUIDDict", mm_uuids: MultiModalUUIDDict,
mm_data: "MultiModalDataDict", mm_data: MultiModalDataDict,
) -> "MultiModalUUIDDict": ) -> MultiModalUUIDDict:
"""Rebuild mm_uuids after vision_chunk processing. """Rebuild mm_uuids after vision_chunk processing.
When videos are split into chunks, the original UUIDs need to be updated When videos are split into chunks, the original UUIDs need to be updated
...@@ -547,7 +541,7 @@ def rebuild_mm_uuids_from_mm_data( ...@@ -547,7 +541,7 @@ def rebuild_mm_uuids_from_mm_data(
def build_video_prompts_from_mm_data( def build_video_prompts_from_mm_data(
mm_data: "MultiModalDataDict", mm_data: MultiModalDataDict,
) -> list[str]: ) -> list[str]:
"""Build video prompts from vision_chunk data. """Build video prompts from vision_chunk data.
...@@ -585,7 +579,7 @@ def build_video_prompts_from_mm_data( ...@@ -585,7 +579,7 @@ def build_video_prompts_from_mm_data(
def replace_vision_chunk_video_placeholder( def replace_vision_chunk_video_placeholder(
prompt_raw: str | list[int], prompt_raw: str | list[int],
mm_data: "MultiModalDataDict", mm_data: MultiModalDataDict,
video_placeholder: str | None, video_placeholder: str | None,
) -> str | list[int]: ) -> str | list[int]:
# get video placeholder, replace it with runtime video-chunk prompts # get video placeholder, replace it with runtime video-chunk prompts
......
...@@ -9,8 +9,8 @@ from typing import TYPE_CHECKING, NamedTuple, TypeAlias, TypedDict, overload ...@@ -9,8 +9,8 @@ from typing import TYPE_CHECKING, NamedTuple, TypeAlias, TypedDict, overload
from vllm.inputs import ( from vllm.inputs import (
EmbedsPrompt, EmbedsPrompt,
EngineInput,
ExplicitEncoderDecoderPrompt, ExplicitEncoderDecoderPrompt,
ProcessorInputs,
PromptType, PromptType,
SingletonPrompt, SingletonPrompt,
TextPrompt, TextPrompt,
...@@ -70,28 +70,28 @@ def conversation_to_seq( ...@@ -70,28 +70,28 @@ def conversation_to_seq(
DecoderOnlyDictPrompt: TypeAlias = TextPrompt | TokensPrompt | EmbedsPrompt DecoderOnlyDictPrompt: TypeAlias = TextPrompt | TokensPrompt | EmbedsPrompt
""" """
A [`DecoderOnlyPrompt`][vllm.inputs.data.DecoderOnlyPrompt] A [`DecoderOnlyPrompt`][vllm.inputs.llm.DecoderOnlyPrompt]
that has been standardized into a dictionary. that has been standardized into a dictionary.
""" """
EncoderDictPrompt: TypeAlias = TextPrompt | TokensPrompt EncoderDictPrompt: TypeAlias = TextPrompt | TokensPrompt
""" """
A [`EncoderPrompt`][vllm.inputs.data.EncoderPrompt] A [`EncoderPrompt`][vllm.inputs.llm.EncoderPrompt]
that has been standardized into a dictionary. that has been standardized into a dictionary.
""" """
DecoderDictPrompt: TypeAlias = TextPrompt | TokensPrompt DecoderDictPrompt: TypeAlias = TextPrompt | TokensPrompt
""" """
A [`DecoderPrompt`][vllm.inputs.data.DecoderPrompt] A [`DecoderPrompt`][vllm.inputs.llm.DecoderPrompt]
that has been standardized into a dictionary. that has been standardized into a dictionary.
""" """
class EncoderDecoderDictPrompt(TypedDict): class EncoderDecoderDictPrompt(TypedDict):
""" """
A [`EncoderDecoderPrompt`][vllm.inputs.data.EncoderDecoderPrompt] A [`EncoderDecoderPrompt`][vllm.inputs.llm.EncoderDecoderPrompt]
that has been standardized into a dictionary. that has been standardized into a dictionary.
""" """
...@@ -104,14 +104,14 @@ SingletonDictPrompt: TypeAlias = ( ...@@ -104,14 +104,14 @@ SingletonDictPrompt: TypeAlias = (
DecoderOnlyDictPrompt | EncoderDictPrompt | DecoderDictPrompt DecoderOnlyDictPrompt | EncoderDictPrompt | DecoderDictPrompt
) )
""" """
A [`SingletonPrompt`][vllm.inputs.data.SingletonPrompt] A [`SingletonPrompt`][vllm.inputs.llm.SingletonPrompt]
that has been standardized into a dictionary. that has been standardized into a dictionary.
""" """
DictPrompt: TypeAlias = DecoderOnlyDictPrompt | EncoderDecoderDictPrompt DictPrompt: TypeAlias = DecoderOnlyDictPrompt | EncoderDecoderDictPrompt
""" """
A [`PromptType`][vllm.inputs.data.PromptType] A [`PromptType`][vllm.inputs.llm.PromptType]
that has been standardized into a dictionary. that has been standardized into a dictionary.
""" """
...@@ -236,7 +236,7 @@ def extract_target_prompt(model_config: "ModelConfig", prompt: object): ...@@ -236,7 +236,7 @@ def extract_target_prompt(model_config: "ModelConfig", prompt: object):
def extract_prompt_components( def extract_prompt_components(
model_config: "ModelConfig", model_config: "ModelConfig",
prompt: PromptType | ProcessorInputs, prompt: PromptType | EngineInput,
) -> PromptComponents: ) -> PromptComponents:
target_prompt = extract_target_prompt(model_config, prompt) target_prompt = extract_target_prompt(model_config, prompt)
...@@ -248,7 +248,8 @@ def extract_prompt_components( ...@@ -248,7 +248,8 @@ def extract_prompt_components(
def extract_prompt_len( def extract_prompt_len(
model_config: "ModelConfig", prompt: PromptType | ProcessorInputs model_config: "ModelConfig",
prompt: PromptType | EngineInput,
): ):
target_prompt = extract_target_prompt(model_config, prompt) target_prompt = extract_target_prompt(model_config, prompt)
......
...@@ -21,7 +21,7 @@ from vllm.distributed.weight_transfer.base import ( ...@@ -21,7 +21,7 @@ from vllm.distributed.weight_transfer.base import (
from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.protocol import EngineClient, StreamingInput from vllm.engine.protocol import EngineClient, StreamingInput
from vllm.entrypoints.serve.elastic_ep.middleware import set_scaling_elastic_ep from vllm.entrypoints.serve.elastic_ep.middleware import set_scaling_elastic_ep
from vllm.inputs import ProcessorInputs, PromptType from vllm.inputs import EngineInput, PromptType
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
...@@ -139,7 +139,7 @@ class AsyncLLM(EngineClient): ...@@ -139,7 +139,7 @@ class AsyncLLM(EngineClient):
self.model_config.io_processor_plugin, self.model_config.io_processor_plugin,
) )
# Convert TokPrompt --> EngineCoreRequest. # Convert EngineInput --> EngineCoreRequest.
self.input_processor = InputProcessor(self.vllm_config, renderer) self.input_processor = InputProcessor(self.vllm_config, renderer)
# Converts EngineCoreOutputs --> RequestOutput. # Converts EngineCoreOutputs --> RequestOutput.
...@@ -290,7 +290,7 @@ class AsyncLLM(EngineClient): ...@@ -290,7 +290,7 @@ class AsyncLLM(EngineClient):
request_id: str, request_id: str,
prompt: EngineCoreRequest prompt: EngineCoreRequest
| PromptType | PromptType
| ProcessorInputs | EngineInput
| AsyncGenerator[StreamingInput, None], | AsyncGenerator[StreamingInput, None],
params: SamplingParams | PoolingParams, params: SamplingParams | PoolingParams,
arrival_time: float | None = None, arrival_time: float | None = None,
...@@ -530,7 +530,7 @@ class AsyncLLM(EngineClient): ...@@ -530,7 +530,7 @@ class AsyncLLM(EngineClient):
self, self,
prompt: EngineCoreRequest prompt: EngineCoreRequest
| PromptType | PromptType
| ProcessorInputs | EngineInput
| AsyncGenerator[StreamingInput, None], | AsyncGenerator[StreamingInput, None],
sampling_params: SamplingParams, sampling_params: SamplingParams,
request_id: str, request_id: str,
...@@ -776,7 +776,7 @@ class AsyncLLM(EngineClient): ...@@ -776,7 +776,7 @@ class AsyncLLM(EngineClient):
async def encode( async def encode(
self, self,
prompt: PromptType | ProcessorInputs, prompt: PromptType | EngineInput,
pooling_params: PoolingParams, pooling_params: PoolingParams,
request_id: str, request_id: str,
lora_request: LoRARequest | None = None, lora_request: LoRARequest | None = None,
......
...@@ -7,20 +7,18 @@ from typing import Any, Literal ...@@ -7,20 +7,18 @@ from typing import Any, Literal
import vllm.envs as envs import vllm.envs as envs
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.inputs.data import ( from vllm.inputs import (
ProcessorInputs, EngineInput,
PromptType, PromptType,
SingletonInputs, SingletonInput,
split_enc_dec_input,
) )
from vllm.inputs.parse import split_enc_dec_inputs
from vllm.inputs.preprocess import InputPreprocessor from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.multimodal.encoder_budget import MultiModalBudget from vllm.multimodal.encoder_budget import MultiModalBudget
from vllm.multimodal.inputs import ( from vllm.multimodal.inputs import MultiModalFeatureSpec
MultiModalFeatureSpec,
)
from vllm.multimodal.utils import argsort_mm_positions from vllm.multimodal.utils import argsort_mm_positions
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
...@@ -197,7 +195,7 @@ class InputProcessor: ...@@ -197,7 +195,7 @@ class InputProcessor:
def process_inputs( def process_inputs(
self, self,
request_id: str, request_id: str,
prompt: PromptType | ProcessorInputs, prompt: PromptType | EngineInput,
params: SamplingParams | PoolingParams, params: SamplingParams | PoolingParams,
supported_tasks: tuple[SupportedTask, ...], supported_tasks: tuple[SupportedTask, ...],
arrival_time: float | None = None, arrival_time: float | None = None,
...@@ -232,7 +230,7 @@ class InputProcessor: ...@@ -232,7 +230,7 @@ class InputProcessor:
if arrival_time is None: if arrival_time is None:
arrival_time = prompt.get("arrival_time", time.time()) # type: ignore[assignment] arrival_time = prompt.get("arrival_time", time.time()) # type: ignore[assignment]
processed_inputs: ProcessorInputs = prompt # type: ignore[assignment] processed_inputs: EngineInput = prompt # type: ignore[assignment]
else: else:
logger.warning_once( logger.warning_once(
"Passing raw prompts to InputProcessor is deprecated " "Passing raw prompts to InputProcessor is deprecated "
...@@ -250,7 +248,7 @@ class InputProcessor: ...@@ -250,7 +248,7 @@ class InputProcessor:
current_platform.validate_request(processed_inputs, params) current_platform.validate_request(processed_inputs, params)
encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs) encoder_inputs, decoder_inputs = split_enc_dec_input(processed_inputs)
self._validate_model_inputs(encoder_inputs, decoder_inputs) self._validate_model_inputs(encoder_inputs, decoder_inputs)
# Mypy can be conservative for TypedDict unions; normalize access. # Mypy can be conservative for TypedDict unions; normalize access.
...@@ -385,7 +383,7 @@ class InputProcessor: ...@@ -385,7 +383,7 @@ class InputProcessor:
def _validate_model_input( def _validate_model_input(
self, self,
prompt_inputs: SingletonInputs, prompt_input: SingletonInput,
prompt_type: Literal["encoder", "decoder"], prompt_type: Literal["encoder", "decoder"],
) -> None: ) -> None:
model_config = self.model_config model_config = self.model_config
...@@ -393,20 +391,18 @@ class InputProcessor: ...@@ -393,20 +391,18 @@ class InputProcessor:
prompt_ids = ( prompt_ids = (
None None
if prompt_inputs["type"] == "embeds" if prompt_input["type"] == "embeds"
else prompt_inputs["prompt_token_ids"] else prompt_input["prompt_token_ids"]
) )
prompt_embeds = ( prompt_embeds = (
prompt_inputs["prompt_embeds"] prompt_input["prompt_embeds"] if prompt_input["type"] == "embeds" else None
if prompt_inputs["type"] == "embeds"
else None
) )
prompt_len = length_from_prompt_token_ids_or_embeds(prompt_ids, prompt_embeds) prompt_len = length_from_prompt_token_ids_or_embeds(prompt_ids, prompt_embeds)
self._validate_prompt_len(prompt_len, prompt_type) self._validate_prompt_len(prompt_len, prompt_type)
if prompt_inputs["type"] == "multimodal": if prompt_input["type"] == "multimodal":
decoder_mm_positions = prompt_inputs["mm_placeholders"] decoder_mm_positions = prompt_input["mm_placeholders"]
for modality, mm_positions in decoder_mm_positions.items(): for modality, mm_positions in decoder_mm_positions.items():
for mm_position in mm_positions: for mm_position in mm_positions:
embed_length = mm_position.get_num_embeds() embed_length = mm_position.get_num_embeds()
...@@ -439,10 +435,10 @@ class InputProcessor: ...@@ -439,10 +435,10 @@ class InputProcessor:
def _validate_model_inputs( def _validate_model_inputs(
self, self,
encoder_inputs: SingletonInputs | None, encoder_input: SingletonInput | None,
decoder_inputs: SingletonInputs, decoder_input: SingletonInput,
): ):
if encoder_inputs is not None: if encoder_input is not None:
self._validate_model_input(encoder_inputs, prompt_type="encoder") self._validate_model_input(encoder_input, prompt_type="encoder")
self._validate_model_input(decoder_inputs, prompt_type="decoder") self._validate_model_input(decoder_input, prompt_type="decoder")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment