Unverified Commit 27e8d1ea authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Refactor] Define MultiModalKwargsItems separate from MultiModalKwargs (#23053)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent 5c79b0d6
...@@ -23,7 +23,7 @@ from vllm.model_executor.models.module_mapping import MultiModelKeys ...@@ -23,7 +23,7 @@ from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalKwargs, NestedTensors) MultiModalKwargsItems, NestedTensors)
from vllm.multimodal.parse import MultiModalDataItems, MultiModalDataParser from vllm.multimodal.parse import MultiModalDataItems, MultiModalDataParser
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement, BaseProcessingInfo, PromptReplacement,
...@@ -194,7 +194,7 @@ class UltravoxMultiModalProcessor( ...@@ -194,7 +194,7 @@ class UltravoxMultiModalProcessor(
self, self,
mm_items: MultiModalDataItems, mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, Any], hf_processor_mm_kwargs: Mapping[str, Any],
out_mm_kwargs: MultiModalKwargs, out_mm_kwargs: MultiModalKwargsItems,
) -> Sequence[PromptUpdate]: ) -> Sequence[PromptUpdate]:
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
...@@ -203,7 +203,8 @@ class UltravoxMultiModalProcessor( ...@@ -203,7 +203,8 @@ class UltravoxMultiModalProcessor(
# Each audio can be split into multiple chunks. # Each audio can be split into multiple chunks.
# chunks_start_idx[i] indicates the start index of the chunks # chunks_start_idx[i] indicates the start index of the chunks
# belonging to the i-th audio. # belonging to the i-th audio.
num_chunks = out_mm_kwargs.get("audio_num_chunks", torch.zeros(0)) out_mm_data = out_mm_kwargs.get_data()
num_chunks = out_mm_data.get("audio_num_chunks", torch.zeros(0))
chunks_start_idx: torch.Tensor = torch.cumsum(num_chunks, chunks_start_idx: torch.Tensor = torch.cumsum(num_chunks,
dim=0, dim=0,
dtype=torch.int32) dtype=torch.int32)
...@@ -213,7 +214,7 @@ class UltravoxMultiModalProcessor( ...@@ -213,7 +214,7 @@ class UltravoxMultiModalProcessor(
def get_replacement_ultravox(item_idx: int): def get_replacement_ultravox(item_idx: int):
start = chunks_start_idx[item_idx] start = chunks_start_idx[item_idx]
end = chunks_start_idx[item_idx + 1] end = chunks_start_idx[item_idx + 1]
audio_token_len = out_mm_kwargs["audio_token_len"][start:end].sum() audio_token_len = out_mm_data["audio_token_len"][start:end].sum()
return [replacement_id] * int(audio_token_len) # type: ignore return [replacement_id] * int(audio_token_len) # type: ignore
return [ return [
......
...@@ -31,7 +31,7 @@ from vllm.model_executor.models.whisper import WhisperEncoder ...@@ -31,7 +31,7 @@ from vllm.model_executor.models.whisper import WhisperEncoder
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalKwargs, NestedTensors) MultiModalKwargsItems, NestedTensors)
from vllm.multimodal.parse import (AudioProcessorItems, MultiModalDataItems, from vllm.multimodal.parse import (AudioProcessorItems, MultiModalDataItems,
MultiModalDataParser) MultiModalDataParser)
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
...@@ -259,7 +259,7 @@ class VoxtralMultiModalProcessor(BaseMultiModalProcessor[VoxtralProcessingInfo] ...@@ -259,7 +259,7 @@ class VoxtralMultiModalProcessor(BaseMultiModalProcessor[VoxtralProcessingInfo]
self, self,
mm_items: MultiModalDataItems, mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs, out_mm_kwargs: MultiModalKwargsItems,
) -> Sequence[PromptUpdate]: ) -> Sequence[PromptUpdate]:
processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
...@@ -289,7 +289,8 @@ class VoxtralMultiModalProcessor(BaseMultiModalProcessor[VoxtralProcessingInfo] ...@@ -289,7 +289,8 @@ class VoxtralMultiModalProcessor(BaseMultiModalProcessor[VoxtralProcessingInfo]
tokenization_kwargs: Mapping[str, object], tokenization_kwargs: Mapping[str, object],
*, *,
return_mm_hashes: bool, return_mm_hashes: bool,
) -> tuple[list[int], MultiModalKwargs, Optional[MultiModalHashes], bool]: ) -> tuple[list[int], MultiModalKwargsItems, Optional[MultiModalHashes],
bool]:
prompt_ids, mm_kwargs, mm_hashes, _ = super( prompt_ids, mm_kwargs, mm_hashes, _ = super(
)._cached_apply_hf_processor( )._cached_apply_hf_processor(
prompt=prompt, prompt=prompt,
......
...@@ -33,7 +33,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader ...@@ -33,7 +33,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY, NestedTensors from vllm.multimodal import MULTIMODAL_REGISTRY, NestedTensors
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalKwargs) MultiModalKwargsItems)
from vllm.multimodal.parse import MultiModalDataItems, MultiModalDataParser from vllm.multimodal.parse import MultiModalDataItems, MultiModalDataParser
from vllm.multimodal.processing import (BaseProcessingInfo, from vllm.multimodal.processing import (BaseProcessingInfo,
EncDecMultiModalProcessor, EncDecMultiModalProcessor,
...@@ -728,7 +728,7 @@ class WhisperMultiModalProcessor( ...@@ -728,7 +728,7 @@ class WhisperMultiModalProcessor(
self, self,
mm_items: MultiModalDataItems, mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs, out_mm_kwargs: MultiModalKwargsItems,
) -> Sequence[PromptUpdate]: ) -> Sequence[PromptUpdate]:
num_tokens = self.info.get_num_audio_tokens() num_tokens = self.info.get_num_audio_tokens()
return [ return [
......
...@@ -4,7 +4,8 @@ from .base import MultiModalPlaceholderMap ...@@ -4,7 +4,8 @@ from .base import MultiModalPlaceholderMap
from .hasher import MultiModalHashDict, MultiModalHasher from .hasher import MultiModalHashDict, MultiModalHasher
from .inputs import (BatchedTensorInputs, ModalityData, MultiModalDataBuiltins, from .inputs import (BatchedTensorInputs, ModalityData, MultiModalDataBuiltins,
MultiModalDataDict, MultiModalKwargs, MultiModalDataDict, MultiModalKwargs,
MultiModalPlaceholderDict, NestedTensors) MultiModalKwargsItems, MultiModalPlaceholderDict,
NestedTensors)
from .registry import MultiModalRegistry from .registry import MultiModalRegistry
MULTIMODAL_REGISTRY = MultiModalRegistry() MULTIMODAL_REGISTRY = MultiModalRegistry()
...@@ -25,6 +26,7 @@ __all__ = [ ...@@ -25,6 +26,7 @@ __all__ = [
"MultiModalHashDict", "MultiModalHashDict",
"MultiModalHasher", "MultiModalHasher",
"MultiModalKwargs", "MultiModalKwargs",
"MultiModalKwargsItems",
"MultiModalPlaceholderDict", "MultiModalPlaceholderDict",
"MultiModalPlaceholderMap", "MultiModalPlaceholderMap",
"NestedTensors", "NestedTensors",
......
...@@ -9,7 +9,7 @@ from typing import TYPE_CHECKING, Generic, NamedTuple, TypeVar ...@@ -9,7 +9,7 @@ from typing import TYPE_CHECKING, Generic, NamedTuple, TypeVar
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.sequence import SequenceGroupMetadata from vllm.sequence import SequenceGroupMetadata
from .inputs import MultiModalKwargs, NestedTensors, PlaceholderRange from .inputs import MultiModalKwargs, PlaceholderRange
_T = TypeVar("_T") _T = TypeVar("_T")
...@@ -56,8 +56,7 @@ class MultiModalPlaceholderMap: ...@@ -56,8 +56,7 @@ class MultiModalPlaceholderMap:
@classmethod @classmethod
def from_seq_group( def from_seq_group(
cls, seq_group: "SequenceGroupMetadata", positions: range cls, seq_group: "SequenceGroupMetadata", positions: range
) -> tuple[dict[str, NestedTensors], dict[str, ) -> tuple[MultiModalKwargs, dict[str, "MultiModalPlaceholderMap"]]:
"MultiModalPlaceholderMap"]]:
""" """
Returns the multi-modal items that intersect with the portion of a Returns the multi-modal items that intersect with the portion of a
prompt (``seq_group``) represented by ``positions``, as well as a prompt (``seq_group``) represented by ``positions``, as well as a
...@@ -100,7 +99,7 @@ class MultiModalPlaceholderMap: ...@@ -100,7 +99,7 @@ class MultiModalPlaceholderMap:
seq_mm_placeholders = seq_group.multi_modal_placeholders seq_mm_placeholders = seq_group.multi_modal_placeholders
if not seq_mm_data or not seq_mm_placeholders: if not seq_mm_data or not seq_mm_placeholders:
return MultiModalKwargs().get_data(), {} return MultiModalKwargs(), {}
placeholder_maps = dict[str, MultiModalPlaceholderMap]() placeholder_maps = dict[str, MultiModalPlaceholderMap]()
...@@ -117,8 +116,6 @@ class MultiModalPlaceholderMap: ...@@ -117,8 +116,6 @@ class MultiModalPlaceholderMap:
placeholder_maps[modality] = placeholder_map placeholder_maps[modality] = placeholder_map
seq_mm_data = seq_mm_data if isinstance(
seq_mm_data, dict) else seq_mm_data.get_data()
return seq_mm_data, placeholder_maps return seq_mm_data, placeholder_maps
def append_items_from_seq_group( def append_items_from_seq_group(
......
...@@ -11,7 +11,9 @@ from vllm.logger import init_logger ...@@ -11,7 +11,9 @@ from vllm.logger import init_logger
from vllm.utils import GiB_bytes, LRUCache from vllm.utils import GiB_bytes, LRUCache
from vllm.utils.jsontree import json_map_leaves, json_reduce_leaves from vllm.utils.jsontree import json_map_leaves, json_reduce_leaves
from .inputs import MultiModalKwargs, MultiModalKwargsItem, NestedTensors from .inputs import (MultiModalFieldElem, MultiModalKwargs,
MultiModalKwargsItem, MultiModalKwargsItems,
NestedTensors)
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -26,8 +28,9 @@ class MultiModalCacheItemMetadata: ...@@ -26,8 +28,9 @@ class MultiModalCacheItemMetadata:
MultiModalCacheValue = Union[ MultiModalCacheValue = Union[
MultiModalKwargs, MultiModalKwargsItems,
MultiModalKwargsItem, MultiModalKwargsItem,
MultiModalKwargs,
Mapping[str, NestedTensors], Mapping[str, NestedTensors],
MultiModalCacheItemMetadata, MultiModalCacheItemMetadata,
] ]
...@@ -44,14 +47,16 @@ class MultiModalCache: ...@@ -44,14 +47,16 @@ class MultiModalCache:
*, *,
debug: bool = False, debug: bool = False,
) -> int: ) -> int:
# MultiModalKwargs is not a subclass of dict if isinstance(leaf, MultiModalFieldElem):
if isinstance(leaf, MultiModalKwargs): return cls.get_item_size(leaf.data) # type: ignore
return cls.get_item_size(leaf.get_data(), debug=debug)
# MultiModalKwargsItem is not a subclass of dict # These are not subclasses of dict
if isinstance(leaf, MultiModalKwargsItems):
return cls.get_item_size(leaf.data) # type: ignore
if isinstance(leaf, MultiModalKwargsItem): if isinstance(leaf, MultiModalKwargsItem):
leaf_data = {k: v.data for k, v in leaf.items()} return cls.get_item_size(leaf.data) # type: ignore
return cls.get_item_size(leaf_data, debug=debug) if isinstance(leaf, MultiModalKwargs):
return cls.get_item_size(leaf.data) # type: ignore
# sys.getsizeof doesn't work for tensors # sys.getsizeof doesn't work for tensors
if isinstance(leaf, torch.Tensor): if isinstance(leaf, torch.Tensor):
......
...@@ -11,7 +11,7 @@ from typing import (TYPE_CHECKING, Any, Literal, Optional, TypedDict, TypeVar, ...@@ -11,7 +11,7 @@ from typing import (TYPE_CHECKING, Any, Literal, Optional, TypedDict, TypeVar,
Union, cast, final) Union, cast, final)
import numpy as np import numpy as np
from typing_extensions import NotRequired, TypeAlias from typing_extensions import NotRequired, TypeAlias, deprecated
from vllm.utils import LazyLoader, full_groupby, is_list_of from vllm.utils import LazyLoader, full_groupby, is_list_of
from vllm.utils.jsontree import JSONTree, json_map_leaves from vllm.utils.jsontree import JSONTree, json_map_leaves
...@@ -656,7 +656,7 @@ class MultiModalKwargsItem(UserDict[str, MultiModalFieldElem]): ...@@ -656,7 +656,7 @@ class MultiModalKwargsItem(UserDict[str, MultiModalFieldElem]):
def __init__(self, data: Mapping[str, MultiModalFieldElem] = {}) -> None: def __init__(self, data: Mapping[str, MultiModalFieldElem] = {}) -> None:
super().__init__(data) super().__init__(data)
modalities = {elem.modality for elem in self.data.values()} modalities = {elem.modality for elem in self.values()}
assert len(modalities) == 1, f"Found different modalities={modalities}" assert len(modalities) == 1, f"Found different modalities={modalities}"
self._modality = next(iter(modalities)) self._modality = next(iter(modalities))
...@@ -668,16 +668,11 @@ class MultiModalKwargsItem(UserDict[str, MultiModalFieldElem]): ...@@ -668,16 +668,11 @@ class MultiModalKwargsItem(UserDict[str, MultiModalFieldElem]):
return {key: elem.data for key, elem in self.items()} return {key: elem.data for key, elem in self.items()}
class MultiModalKwargs: class MultiModalKwargsItems(UserDict[str, Sequence[MultiModalKwargsItem]]):
""" """
A dictionary that represents the keyword arguments to A dictionary of
[`torch.nn.Module.forward`][]. [`MultiModalKwargsItem`][vllm.multimodal.inputs.MultiModalKwargsItem]s
by modality.
The metadata `items` enables us to obtain the keyword arguments
corresponding to each data item in
[`MultiModalDataItems`][vllm.multimodal.parse.MultiModalDataItems], via
[`get_item`][vllm.multimodal.inputs.MultiModalKwargs.get_item] and
[`get_items`][vllm.multimodal.inputs.MultiModalKwargs.get_items].
""" """
@staticmethod @staticmethod
...@@ -712,19 +707,64 @@ class MultiModalKwargs: ...@@ -712,19 +707,64 @@ class MultiModalKwargs:
elems = [v[item_idx] for v in elems_in_modality.values()] elems = [v[item_idx] for v in elems_in_modality.values()]
items.append(MultiModalKwargsItem.from_elems(elems)) items.append(MultiModalKwargsItem.from_elems(elems))
return MultiModalKwargs(items) return MultiModalKwargsItems.from_seq(items)
def __init__(self, items: Sequence[MultiModalKwargsItem] = ()) -> None:
super().__init__()
@staticmethod
def from_seq(items: Sequence[MultiModalKwargsItem]):
items_by_modality = full_groupby(items, key=lambda x: x.modality) items_by_modality = full_groupby(items, key=lambda x: x.modality)
self._items_by_modality = dict(items_by_modality) return MultiModalKwargsItems(items_by_modality)
self._data: Optional[dict[str, NestedTensors]] = None def __getitem__(self, modality: str):
if modality not in self:
raise KeyError(f"Modality {modality!r} not found. "
f"Available modalities: {set(self.keys())}")
return super().__getitem__(modality)
def get_data(self, *, pin_memory: bool = False) -> "MultiModalKwargs":
elems_by_key = defaultdict[str, list[MultiModalFieldElem]](list)
for items in self.values():
for item in items:
for key, elem in item.items():
elems_by_key[key].append(elem)
return MultiModalKwargs({
key:
elems[0].field.reduce_data(elems, pin_memory=pin_memory)
for key, elems in elems_by_key.items() if len(elems) > 0
})
@property
def modalities(self): class MultiModalKwargs(UserDict[str, NestedTensors]):
return self._items_by_modality.keys() """
A dictionary that represents the keyword arguments to
[`torch.nn.Module.forward`][].
"""
@staticmethod
@deprecated("`MultiModalKwargs.from_hf_inputs` is deprecated and "
"will be removed in v0.13. "
"Please use `MultiModalKwargsItems.from_hf_inputs` and "
"access the tensor data using `.get_data()`.")
def from_hf_inputs(
hf_inputs: "BatchFeature",
config_by_key: Mapping[str, MultiModalFieldConfig],
):
return MultiModalKwargsItems.from_hf_inputs(hf_inputs, config_by_key) \
.get_data()
@staticmethod
@deprecated("`MultiModalKwargs.from_items` is deprecated and "
"will be removed in v0.13. "
"Please use `MultiModalKwargsItems.from_seq` and "
"access the tensor data using `.get_data()`.")
def from_items(
items: Sequence[MultiModalKwargsItem],
*,
pin_memory: bool = False,
):
return MultiModalKwargsItems.from_seq(items) \
.get_data(pin_memory=pin_memory)
@staticmethod @staticmethod
def _try_stack(nested_tensors: NestedTensors, def _try_stack(nested_tensors: NestedTensors,
...@@ -813,92 +853,24 @@ class MultiModalKwargs: ...@@ -813,92 +853,24 @@ class MultiModalKwargs:
return cast(BatchedTensorInputs, json_mapped) return cast(BatchedTensorInputs, json_mapped)
def keys(self):
return self.get_data().keys()
def values(self):
return self.get_data().values()
def items(self):
return self.get_data().items()
def get(self, key: str, /, default=None):
return self.get_data().get(key, default)
def pop(self, key: str, *args, **kwargs):
data = dict(self.get_data())
res = data.pop(key, *args, **kwargs)
for items in self._items_by_modality.values():
for item in items:
item.pop(key, *args, **kwargs)
self._data = None
return res
def __iter__(self):
return iter(self.get_data())
def __getitem__(self, key: str): def __getitem__(self, key: str):
return self.get_data()[key] if key not in self:
raise KeyError(f"Keyword argument {key!r} not found. "
f"Available keys: {set(self.keys())}")
return super().__getitem__(key)
def __eq__(self, other: object) -> bool: def __eq__(self, other: object) -> bool:
if not isinstance(other, self.__class__): if not isinstance(other, self.__class__):
return False return False
return self._items_by_modality == other._items_by_modality for k in self:
if k not in other:
def _validate_modality(self, method_name: str, modality: str) -> None: return False
if not self._items_by_modality: if not nested_tensors_equal(self[k], other[k]):
raise RuntimeError( return False
f"`{method_name}` is not supported when "
"MultiModalKwargs is not initialized with `items`")
if modality not in self._items_by_modality:
available_modalities = set(self._items_by_modality.keys())
raise KeyError(f"Modality {modality!r} not found. "
f"Available modalities: {available_modalities}")
def get_item_count(self, modality: str) -> int:
"""Get the number of items belonging to a modality."""
self._validate_modality("get_item_count", modality)
return len(self._items_by_modality[modality])
def get_item(self, modality: str, item_index: int) -> MultiModalKwargsItem:
"""
Get the keyword arguments corresponding to an item identified by
its modality and index.
"""
self._validate_modality("get_item", modality)
return self._items_by_modality[modality][item_index]
def get_items(self, modality: str) -> Sequence[MultiModalKwargsItem]:
"""
Get the keyword arguments corresponding to each item belonging to
a modality.
"""
self._validate_modality("get_items", modality)
return self._items_by_modality[modality]
def get_data(self,
*,
pin_memory: bool = False) -> dict[str, NestedTensors]:
if self._data is not None:
return self._data
elems_by_key = defaultdict[str, list[MultiModalFieldElem]](list) return True
for items in self._items_by_modality.values():
for item in items:
for key, elem in item.items():
elems_by_key[key].append(elem)
data = {
key: elems[0].field.reduce_data(elems, pin_memory=pin_memory)
for key, elems in elems_by_key.items() if len(elems) > 0
}
self._data = data
return data
MultiModalPlaceholderDict: TypeAlias = Mapping[str, Sequence[PlaceholderRange]] MultiModalPlaceholderDict: TypeAlias = Mapping[str, Sequence[PlaceholderRange]]
...@@ -926,7 +898,7 @@ class MultiModalInputs(TypedDict): ...@@ -926,7 +898,7 @@ class MultiModalInputs(TypedDict):
token_type_ids: NotRequired[list[int]] token_type_ids: NotRequired[list[int]]
"""The token type IDs of the prompt.""" """The token type IDs of the prompt."""
mm_kwargs: MultiModalKwargs mm_kwargs: MultiModalKwargsItems
"""Keyword arguments to be directly passed to the model after batching.""" """Keyword arguments to be directly passed to the model after batching."""
mm_hashes: Optional["MultiModalHashDict"] mm_hashes: Optional["MultiModalHashDict"]
......
...@@ -16,7 +16,7 @@ from vllm.utils import LazyLoader, is_list_of ...@@ -16,7 +16,7 @@ from vllm.utils import LazyLoader, is_list_of
from .audio import AudioResampler from .audio import AudioResampler
from .inputs import (AudioItem, HfAudioItem, HfImageItem, HfVideoItem, from .inputs import (AudioItem, HfAudioItem, HfImageItem, HfVideoItem,
ImageItem, ModalityData, MultiModalDataDict, ImageItem, ModalityData, MultiModalDataDict,
MultiModalFieldConfig, MultiModalKwargs, VideoItem) MultiModalFieldConfig, MultiModalKwargsItems, VideoItem)
_T = TypeVar("_T") _T = TypeVar("_T")
_I = TypeVar("_I") _I = TypeVar("_I")
...@@ -157,19 +157,16 @@ class DictEmbeddingItems(ModalityDataItems[Mapping[str, torch.Tensor], ...@@ -157,19 +157,16 @@ class DictEmbeddingItems(ModalityDataItems[Mapping[str, torch.Tensor],
self.fields_config = fields_config self.fields_config = fields_config
self.required_fields = required_fields self.required_fields = required_fields
self._kwargs = MultiModalKwargs.from_hf_inputs( self._kwargs = MultiModalKwargsItems.from_hf_inputs(
BatchFeature(dict(data)), BatchFeature(dict(data)),
fields_config, fields_config,
) )
def get_count(self) -> int: def get_count(self) -> int:
return self._kwargs.get_item_count(self.modality) return len(self._kwargs[self.modality])
def get(self, index: int) -> Mapping[str, torch.Tensor]: def get(self, index: int) -> Mapping[str, torch.Tensor]:
return { return self._kwargs[self.modality][index].get_data()
k: v.data
for k, v in self._kwargs.get_item(self.modality, index).items()
}
def get_processor_data(self) -> Mapping[str, object]: def get_processor_data(self) -> Mapping[str, object]:
return {} return {}
......
...@@ -23,8 +23,9 @@ from vllm.utils import flatten_2d_lists, full_groupby ...@@ -23,8 +23,9 @@ from vllm.utils import flatten_2d_lists, full_groupby
from .cache import MultiModalCache from .cache import MultiModalCache
from .hasher import MultiModalHasher from .hasher import MultiModalHasher
from .inputs import (MultiModalDataDict, MultiModalEncDecInputs, from .inputs import (MultiModalDataDict, MultiModalEncDecInputs,
MultiModalFieldConfig, MultiModalInputs, MultiModalKwargs, MultiModalFieldConfig, MultiModalInputs,
MultiModalKwargsItem, PlaceholderRange) MultiModalKwargsItem, MultiModalKwargsItems,
PlaceholderRange)
from .parse import (DictEmbeddingItems, EmbeddingItems, MultiModalDataItems, from .parse import (DictEmbeddingItems, EmbeddingItems, MultiModalDataItems,
MultiModalDataParser) MultiModalDataParser)
...@@ -985,7 +986,7 @@ _I = TypeVar("_I", bound=BaseProcessingInfo) ...@@ -985,7 +986,7 @@ _I = TypeVar("_I", bound=BaseProcessingInfo)
MultiModalHashes = dict[str, list[str]] MultiModalHashes = dict[str, list[str]]
""" """
A collection of hashes with a similar structure as A collection of hashes with a similar structure as
[`MultiModalKwargs`][vllm.multimodal.inputs.MultiModalKwargs]. [`MultiModalKwargsItems`][vllm.multimodal.inputs.MultiModalKwargsItems].
""" """
...@@ -1095,7 +1096,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1095,7 +1096,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
self, self,
mm_items: MultiModalDataItems, mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs, out_mm_kwargs: MultiModalKwargsItems,
) -> Sequence[PromptUpdate]: ) -> Sequence[PromptUpdate]:
""" """
Given the original multi-modal items for this modality Given the original multi-modal items for this modality
...@@ -1361,7 +1362,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1361,7 +1362,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
self, self,
cache: ProcessingCache, cache: ProcessingCache,
mm_cache_items_or_hashes: dict[str, list[_CacheItemOrHash]], mm_cache_items_or_hashes: dict[str, list[_CacheItemOrHash]],
mm_missing_kwargs: MultiModalKwargs, mm_missing_kwargs: MultiModalKwargsItems,
) -> dict[str, list[MultiModalKwargsItem]]: ) -> dict[str, list[MultiModalKwargsItem]]:
mm_missing_next_idx = defaultdict[str, int](lambda: 0) mm_missing_next_idx = defaultdict[str, int](lambda: 0)
...@@ -1369,10 +1370,8 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1369,10 +1370,8 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
for modality, items_or_hashes in mm_cache_items_or_hashes.items(): for modality, items_or_hashes in mm_cache_items_or_hashes.items():
for item_or_hash in items_or_hashes: for item_or_hash in items_or_hashes:
if isinstance(item_or_hash, str): if isinstance(item_or_hash, str):
kw_item = mm_missing_kwargs.get_item( kw_item = mm_missing_kwargs[modality][
modality, mm_missing_next_idx[modality]]
mm_missing_next_idx[modality],
)
cache.put(item_or_hash, kw_item) cache.put(item_or_hash, kw_item)
mm_missing_next_idx[modality] += 1 mm_missing_next_idx[modality] += 1
else: else:
...@@ -1390,7 +1389,8 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1390,7 +1389,8 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
tokenization_kwargs: Mapping[str, object], tokenization_kwargs: Mapping[str, object],
*, *,
return_mm_hashes: bool, return_mm_hashes: bool,
) -> tuple[list[int], MultiModalKwargs, Optional[MultiModalHashes], bool]: ) -> tuple[list[int], MultiModalKwargsItems, Optional[MultiModalHashes],
bool]:
( (
prompt_ids, prompt_ids,
mm_processed_data, mm_processed_data,
...@@ -1403,7 +1403,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1403,7 +1403,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
enable_hf_prompt_update=True, enable_hf_prompt_update=True,
) )
mm_kwargs = MultiModalKwargs.from_hf_inputs( mm_kwargs = MultiModalKwargsItems.from_hf_inputs(
mm_processed_data, mm_processed_data,
self._get_mm_fields_config(mm_processed_data, self._get_mm_fields_config(mm_processed_data,
hf_processor_mm_kwargs), hf_processor_mm_kwargs),
...@@ -1423,7 +1423,8 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1423,7 +1423,8 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
tokenization_kwargs: Mapping[str, object], tokenization_kwargs: Mapping[str, object],
*, *,
return_mm_hashes: bool, return_mm_hashes: bool,
) -> tuple[list[int], MultiModalKwargs, Optional[MultiModalHashes], bool]: ) -> tuple[list[int], MultiModalKwargsItems, Optional[MultiModalHashes],
bool]:
""" """
Apply the HF processor on the full prompt text, Apply the HF processor on the full prompt text,
caching the results and reusing cached results. caching the results and reusing cached results.
...@@ -1468,7 +1469,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1468,7 +1469,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
enable_hf_prompt_update=False, enable_hf_prompt_update=False,
) )
mm_missing_kwargs = MultiModalKwargs.from_hf_inputs( mm_missing_kwargs = MultiModalKwargsItems.from_hf_inputs(
mm_missing_processed_data, mm_missing_processed_data,
self._get_mm_fields_config(mm_missing_processed_data, self._get_mm_fields_config(mm_missing_processed_data,
hf_processor_mm_kwargs), hf_processor_mm_kwargs),
...@@ -1480,7 +1481,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1480,7 +1481,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
mm_missing_kwargs=mm_missing_kwargs, mm_missing_kwargs=mm_missing_kwargs,
) )
mm_kwargs = MultiModalKwargs([ mm_kwargs = MultiModalKwargsItems.from_seq([
item for cache_items in mm_cache_items_merged.values() item for cache_items in mm_cache_items_merged.values()
for item in cache_items for item in cache_items
]) ])
...@@ -1585,14 +1586,11 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1585,14 +1586,11 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
def _validate_mm_kwargs( def _validate_mm_kwargs(
self, self,
mm_kwargs: MultiModalKwargs, mm_kwargs: MultiModalKwargsItems,
mm_item_counts: Mapping[str, int], mm_item_counts: Mapping[str, int],
) -> None: ) -> None:
for modality, item_count in mm_item_counts.items(): for modality, item_count in mm_item_counts.items():
if modality in mm_kwargs.modalities: items = mm_kwargs.get(modality, [])
items = mm_kwargs.get_items(modality)
else:
items = []
if len(items) != item_count: if len(items) != item_count:
raise RuntimeError( raise RuntimeError(
...@@ -1630,7 +1628,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1630,7 +1628,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
mm_items: MultiModalDataItems, mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
prompt_ids: list[int], prompt_ids: list[int],
mm_kwargs: MultiModalKwargs, mm_kwargs: MultiModalKwargsItems,
is_update_applied: bool, is_update_applied: bool,
) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]: ) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]:
unbound_prompt_updates = self._get_prompt_updates( unbound_prompt_updates = self._get_prompt_updates(
......
...@@ -13,7 +13,7 @@ import vllm.envs as envs ...@@ -13,7 +13,7 @@ import vllm.envs as envs
from vllm.logger import init_logger from vllm.logger import init_logger
from .inputs import (MultiModalDataDict, MultiModalEncDecInputs, from .inputs import (MultiModalDataDict, MultiModalEncDecInputs,
MultiModalInputs, MultiModalKwargs, MultiModalInputs, MultiModalKwargsItems,
MultiModalPlaceholderDict) MultiModalPlaceholderDict)
from .processing import (BaseMultiModalProcessor, BaseProcessingInfo, from .processing import (BaseMultiModalProcessor, BaseProcessingInfo,
EncDecMultiModalProcessor) EncDecMultiModalProcessor)
...@@ -43,7 +43,7 @@ class DummyDecoderData(NamedTuple): ...@@ -43,7 +43,7 @@ class DummyDecoderData(NamedTuple):
"""Dummy data used for profiling.""" """Dummy data used for profiling."""
prompt_token_ids: list[int] prompt_token_ids: list[int]
multi_modal_data: MultiModalKwargs multi_modal_data: MultiModalKwargsItems
multi_modal_placeholders: MultiModalPlaceholderDict multi_modal_placeholders: MultiModalPlaceholderDict
......
...@@ -32,11 +32,13 @@ _M = TypeVar("_M") ...@@ -32,11 +32,13 @@ _M = TypeVar("_M")
if TYPE_CHECKING: if TYPE_CHECKING:
from .inputs import (BatchedTensorInputs, MultiModalKwargs, from .inputs import (BatchedTensorInputs, MultiModalKwargs,
MultiModalKwargsItem, MultiModalPlaceholderDict) MultiModalKwargsItem, MultiModalKwargsItems,
MultiModalPlaceholderDict)
else: else:
BatchedTensorInputs = Any BatchedTensorInputs = Any
MultiModalKwargs = Any MultiModalKwargs = Any
MultiModalKwargsItem = Any MultiModalKwargsItem = Any
MultiModalKwargsItems = Any
MultiModalPlaceholderDict = Any MultiModalPlaceholderDict = Any
global_thread_pool = ThreadPoolExecutor( global_thread_pool = ThreadPoolExecutor(
...@@ -359,18 +361,20 @@ def argsort_mm_positions( ...@@ -359,18 +361,20 @@ def argsort_mm_positions(
"`group_mm_kwargs_by_modality` and will be removed in v0.13. " "`group_mm_kwargs_by_modality` and will be removed in v0.13. "
"Please use `group_mm_kwargs_by_modality` instead.") "Please use `group_mm_kwargs_by_modality` instead.")
def group_mm_inputs_by_modality( def group_mm_inputs_by_modality(
mm_inputs: list[MultiModalKwargs]) -> list[list[MultiModalKwargs]]: mm_inputs: list[MultiModalKwargsItems]
) -> list[list[MultiModalKwargsItems]]:
if not mm_inputs: if not mm_inputs:
return [] return []
def modality_group_func(mm_input: MultiModalKwargs) -> Union[str, int]: def modality_group_func(
mm_input: MultiModalKwargsItems) -> Union[str, int]:
# If the input has multiple modalities, return a id as the unique key # If the input has multiple modalities, return a id as the unique key
# for the mm_input input. # for the mm_input input.
if len(mm_input.modalities) > 1: if len(mm_input) > 1:
return id(mm_input) return id(mm_input)
elif len(mm_input.modalities) == 1: elif len(mm_input) == 1:
return list(mm_input.modalities)[0] return next(iter(mm_input.keys()))
# FIXME(Isotr0py): Modality of mm_input from legacy pipeline is empty, # FIXME(Isotr0py): Modality of mm_input from legacy pipeline is empty,
# this is used to make InternVL with legacy pipeline still work with v1. # this is used to make InternVL with legacy pipeline still work with v1.
...@@ -397,12 +401,12 @@ def group_mm_kwargs_by_modality( ...@@ -397,12 +401,12 @@ def group_mm_kwargs_by_modality(
Yields: Yields:
A tuple `(modality, num_items, grouped_kwargs)`. A tuple `(modality, num_items, grouped_kwargs)`.
""" """
from vllm.multimodal.inputs import MultiModalKwargs from vllm.multimodal.inputs import MultiModalKwargs, MultiModalKwargsItems
for modality, items in groupby(mm_kwargs, key=lambda item: item.modality): for modality, items in groupby(mm_kwargs, key=lambda item: item.modality):
items_lst = list(items) items_lst = list(items)
# mm_kwargs_group = MultiModalKwargs(items_lst) \ # mm_kwargs_group = MultiModalKwargsItems.from_items(items_lst) \
# .get_data(pin_memory=pin_memory) # .get_data(pin_memory=pin_memory)
# if device is not None: # if device is not None:
...@@ -417,7 +421,10 @@ def group_mm_kwargs_by_modality( ...@@ -417,7 +421,10 @@ def group_mm_kwargs_by_modality(
# We will also need to update each model to remove `flatten_bn`. # We will also need to update each model to remove `flatten_bn`.
mm_kwargs_group = MultiModalKwargs.as_kwargs( mm_kwargs_group = MultiModalKwargs.as_kwargs(
MultiModalKwargs.batch( MultiModalKwargs.batch(
[MultiModalKwargs([item]) for item in items_lst], [
MultiModalKwargsItems.from_seq([item]).get_data()
for item in items_lst
],
pin_memory=pin_memory, pin_memory=pin_memory,
), ),
device=device, device=device,
......
...@@ -22,7 +22,6 @@ from vllm.pooling_params import PoolingParams ...@@ -22,7 +22,6 @@ from vllm.pooling_params import PoolingParams
from vllm.sampling_params import RequestOutputKind, SamplingParams from vllm.sampling_params import RequestOutputKind, SamplingParams
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.multimodal.inputs import NestedTensors
from vllm.v1.worker.kv_connector_model_runner_mixin import ( from vllm.v1.worker.kv_connector_model_runner_mixin import (
KVConnectorOutput) KVConnectorOutput)
...@@ -523,7 +522,7 @@ class Sequence: ...@@ -523,7 +522,7 @@ class Sequence:
@property @property
def multi_modal_data(self) -> MultiModalKwargs: def multi_modal_data(self) -> MultiModalKwargs:
if self.inputs["type"] == "multimodal": if self.inputs["type"] == "multimodal":
return self.inputs["mm_kwargs"] return self.inputs["mm_kwargs"].get_data()
return MultiModalKwargs() return MultiModalKwargs()
...@@ -979,8 +978,7 @@ class SequenceGroupMetadata( ...@@ -979,8 +978,7 @@ class SequenceGroupMetadata(
state: Optional[SequenceGroupState] = msgspec.field( state: Optional[SequenceGroupState] = msgspec.field(
default_factory=lambda: SequenceGroupState()) default_factory=lambda: SequenceGroupState())
token_type_ids: Optional[list[int]] = None token_type_ids: Optional[list[int]] = None
multi_modal_data: Optional[Union[MultiModalKwargs, multi_modal_data: Optional[MultiModalKwargs] = None
dict[str, "NestedTensors"]]] = None
multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None
encoder_seq_data: Optional[SequenceData] = None encoder_seq_data: Optional[SequenceData] = None
cross_block_table: Optional[list[int]] = None cross_block_table: Optional[list[int]] = None
......
...@@ -310,7 +310,7 @@ class Processor: ...@@ -310,7 +310,7 @@ class Processor:
sorted_mm_idxs = argsort_mm_positions(decoder_mm_positions) sorted_mm_idxs = argsort_mm_positions(decoder_mm_positions)
orig_sorted_mm_inputs = [ orig_sorted_mm_inputs = [
decoder_mm_inputs.get_item(modality, idx) decoder_mm_inputs[modality][idx]
for modality, idx in sorted_mm_idxs for modality, idx in sorted_mm_idxs
] ]
sorted_mm_positions = [ sorted_mm_positions = [
......
...@@ -18,12 +18,15 @@ from msgspec import msgpack ...@@ -18,12 +18,15 @@ from msgspec import msgpack
from vllm import envs from vllm import envs
from vllm.logger import init_logger from vllm.logger import init_logger
# yapf: disable
from vllm.multimodal.inputs import (BaseMultiModalField, from vllm.multimodal.inputs import (BaseMultiModalField,
MultiModalBatchedField, MultiModalBatchedField,
MultiModalFieldConfig, MultiModalFieldElem, MultiModalFieldConfig, MultiModalFieldElem,
MultiModalFlatField, MultiModalKwargs, MultiModalFlatField, MultiModalKwargs,
MultiModalKwargsItem, MultiModalKwargsItem,
MultiModalKwargsItems,
MultiModalSharedField, NestedTensors) MultiModalSharedField, NestedTensors)
# yapf: enable
from vllm.v1.engine import UtilityResult from vllm.v1.engine import UtilityResult
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -116,12 +119,11 @@ class MsgpackEncoder: ...@@ -116,12 +119,11 @@ class MsgpackEncoder:
if isinstance(obj, MultiModalKwargsItem): if isinstance(obj, MultiModalKwargsItem):
return self._encode_mm_item(obj) return self._encode_mm_item(obj)
if isinstance(obj, MultiModalKwargsItems):
return self._encode_mm_items(obj)
if isinstance(obj, MultiModalKwargs): if isinstance(obj, MultiModalKwargs):
return [ return self._encode_mm_kwargs(obj)
self._encode_mm_item(item)
for itemlist in obj._items_by_modality.values()
for item in itemlist
]
if isinstance(obj, UtilityResult): if isinstance(obj, UtilityResult):
result = obj.result result = obj.result
...@@ -183,6 +185,12 @@ class MsgpackEncoder: ...@@ -183,6 +185,12 @@ class MsgpackEncoder:
dtype = str(obj.dtype).removeprefix("torch.") dtype = str(obj.dtype).removeprefix("torch.")
return dtype, obj.shape, data return dtype, obj.shape, data
def _encode_mm_items(self, items: MultiModalKwargsItems) -> dict[str, Any]:
return {
modality: [self._encode_mm_item(item) for item in itemlist]
for modality, itemlist in items.items()
}
def _encode_mm_item(self, def _encode_mm_item(self,
item: MultiModalKwargsItem) -> list[dict[str, Any]]: item: MultiModalKwargsItem) -> list[dict[str, Any]]:
return [self._encode_mm_field_elem(elem) for elem in item.values()] return [self._encode_mm_field_elem(elem) for elem in item.values()]
...@@ -200,6 +208,12 @@ class MsgpackEncoder: ...@@ -200,6 +208,12 @@ class MsgpackEncoder:
self._encode_mm_field(elem.field), self._encode_mm_field(elem.field),
} }
def _encode_mm_kwargs(self, kw: MultiModalKwargs) -> dict[str, Any]:
return {
modality: self._encode_nested_tensors(data)
for modality, data in kw.items()
}
def _encode_nested_tensors(self, nt: NestedTensors) -> Any: def _encode_nested_tensors(self, nt: NestedTensors) -> Any:
if isinstance(nt, torch.Tensor): if isinstance(nt, torch.Tensor):
return self._encode_tensor(nt) return self._encode_tensor(nt)
...@@ -260,8 +274,10 @@ class MsgpackDecoder: ...@@ -260,8 +274,10 @@ class MsgpackDecoder:
return slice(*obj) return slice(*obj)
if issubclass(t, MultiModalKwargsItem): if issubclass(t, MultiModalKwargsItem):
return self._decode_mm_item(obj) return self._decode_mm_item(obj)
if issubclass(t, MultiModalKwargsItems):
return self._decode_mm_items(obj)
if issubclass(t, MultiModalKwargs): if issubclass(t, MultiModalKwargs):
return MultiModalKwargs(self._decode_mm_items(obj)) return self._decode_mm_kwargs(obj)
if t is UtilityResult: if t is UtilityResult:
return self._decode_utility_result(obj) return self._decode_utility_result(obj)
return obj return obj
...@@ -315,8 +331,11 @@ class MsgpackDecoder: ...@@ -315,8 +331,11 @@ class MsgpackDecoder:
# Convert back to proper shape & type # Convert back to proper shape & type
return arr.view(torch_dtype).view(shape) return arr.view(torch_dtype).view(shape)
def _decode_mm_items(self, obj: list[Any]) -> list[MultiModalKwargsItem]: def _decode_mm_items(self, obj: dict[str, Any]) -> MultiModalKwargsItems:
return [self._decode_mm_item(v) for v in obj] return MultiModalKwargsItems({
modality: [self._decode_mm_item(item) for item in itemlist]
for modality, itemlist in obj.items()
})
def _decode_mm_item(self, obj: list[Any]) -> MultiModalKwargsItem: def _decode_mm_item(self, obj: list[Any]) -> MultiModalKwargsItem:
return MultiModalKwargsItem.from_elems( return MultiModalKwargsItem.from_elems(
...@@ -339,6 +358,12 @@ class MsgpackDecoder: ...@@ -339,6 +358,12 @@ class MsgpackDecoder:
obj["field"] = factory_meth(None, *field_args).field obj["field"] = factory_meth(None, *field_args).field
return MultiModalFieldElem(**obj) return MultiModalFieldElem(**obj)
def _decode_mm_kwargs(self, obj: dict[str, Any]) -> MultiModalKwargs:
return MultiModalKwargs({
modality: self._decode_nested_tensors(data)
for modality, data in obj.items()
})
def _decode_nested_tensors(self, obj: Any) -> NestedTensors: def _decode_nested_tensors(self, obj: Any) -> NestedTensors:
if isinstance(obj, (int, float)): if isinstance(obj, (int, float)):
# Although it violates NestedTensors type, MultiModalKwargs # Although it violates NestedTensors type, MultiModalKwargs
......
...@@ -10,8 +10,8 @@ import torch ...@@ -10,8 +10,8 @@ import torch
from typing_extensions import deprecated from typing_extensions import deprecated
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.multimodal.inputs import (MultiModalKwargs, MultiModalKwargsItem, from vllm.multimodal.inputs import (MultiModalKwargsItem,
PlaceholderRange) MultiModalKwargsItems, PlaceholderRange)
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams, SamplingType from vllm.sampling_params import SamplingParams, SamplingType
from vllm.utils import swap_dict_values from vllm.utils import swap_dict_values
...@@ -57,8 +57,10 @@ class CachedRequestState: ...@@ -57,8 +57,10 @@ class CachedRequestState:
@property @property
@deprecated("`mm_inputs` is superseded by `mm_kwargs` and will be " @deprecated("`mm_inputs` is superseded by `mm_kwargs` and will be "
"removed in v0.13. Please use `mm_kwargs` instead.") "removed in v0.13. Please use `mm_kwargs` instead.")
def mm_inputs(self) -> list[MultiModalKwargs]: def mm_inputs(self) -> list[MultiModalKwargsItems]:
return [MultiModalKwargs([item]) for item in self.mm_kwargs] return [
MultiModalKwargsItems.from_seq([item]) for item in self.mm_kwargs
]
def get_token_id(self, idx: int) -> int: def get_token_id(self, idx: int) -> int:
if idx < self.num_prompt_tokens: if idx < self.num_prompt_tokens:
......
...@@ -2218,11 +2218,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -2218,11 +2218,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
dummy_mm_data = dummy_decoder_data.multi_modal_data dummy_mm_data = dummy_decoder_data.multi_modal_data
# Result in the maximum GPU consumption of the model # Result in the maximum GPU consumption of the model
dummy_mm_item = dummy_mm_data.get_item(modality=modality, item_index=0) dummy_mm_item = dummy_mm_data[modality][0]
dummy_mm_items = [dummy_mm_item] * max_items_per_batch
return next(mm_kwargs_group return next(mm_kwargs_group
for _, _, mm_kwargs_group in group_mm_kwargs_by_modality( for _, _, mm_kwargs_group in group_mm_kwargs_by_modality(
[dummy_mm_item] * max_items_per_batch, dummy_mm_items,
device=self.device, device=self.device,
pin_memory=self.pin_memory, pin_memory=self.pin_memory,
)) ))
......
...@@ -1824,11 +1824,12 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -1824,11 +1824,12 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
dummy_mm_data = dummy_decoder_data.multi_modal_data dummy_mm_data = dummy_decoder_data.multi_modal_data
# Result in the maximum GPU consumption of the model # Result in the maximum GPU consumption of the model
dummy_mm_item = dummy_mm_data.get_item(modality=modality, item_index=0) dummy_mm_item = dummy_mm_data[modality][0]
dummy_mm_items = [dummy_mm_item] * max_items_per_batch
return next(grouped_mm_kwargs return next(grouped_mm_kwargs
for _, _, grouped_mm_kwargs in group_mm_kwargs_by_modality( for _, _, grouped_mm_kwargs in group_mm_kwargs_by_modality(
[dummy_mm_item] * max_items_per_batch, dummy_mm_items,
device=self.device, device=self.device,
pin_memory=self.pin_memory, pin_memory=self.pin_memory,
)) ))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment