Unverified Commit 27e8d1ea authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Refactor] Define MultiModalKwargsItems separate from MultiModalKwargs (#23053)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent 5c79b0d6
......@@ -23,7 +23,7 @@ from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalKwargs, NestedTensors)
MultiModalKwargsItems, NestedTensors)
from vllm.multimodal.parse import MultiModalDataItems, MultiModalDataParser
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement,
......@@ -194,7 +194,7 @@ class UltravoxMultiModalProcessor(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, Any],
out_mm_kwargs: MultiModalKwargs,
out_mm_kwargs: MultiModalKwargsItems,
) -> Sequence[PromptUpdate]:
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
......@@ -203,7 +203,8 @@ class UltravoxMultiModalProcessor(
# Each audio can be split into multiple chunks.
# chunks_start_idx[i] indicates the start index of the chunks
# belonging to the i-th audio.
num_chunks = out_mm_kwargs.get("audio_num_chunks", torch.zeros(0))
out_mm_data = out_mm_kwargs.get_data()
num_chunks = out_mm_data.get("audio_num_chunks", torch.zeros(0))
chunks_start_idx: torch.Tensor = torch.cumsum(num_chunks,
dim=0,
dtype=torch.int32)
......@@ -213,7 +214,7 @@ class UltravoxMultiModalProcessor(
def get_replacement_ultravox(item_idx: int):
start = chunks_start_idx[item_idx]
end = chunks_start_idx[item_idx + 1]
audio_token_len = out_mm_kwargs["audio_token_len"][start:end].sum()
audio_token_len = out_mm_data["audio_token_len"][start:end].sum()
return [replacement_id] * int(audio_token_len) # type: ignore
return [
......
......@@ -31,7 +31,7 @@ from vllm.model_executor.models.whisper import WhisperEncoder
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalKwargs, NestedTensors)
MultiModalKwargsItems, NestedTensors)
from vllm.multimodal.parse import (AudioProcessorItems, MultiModalDataItems,
MultiModalDataParser)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
......@@ -259,7 +259,7 @@ class VoxtralMultiModalProcessor(BaseMultiModalProcessor[VoxtralProcessingInfo]
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
out_mm_kwargs: MultiModalKwargsItems,
) -> Sequence[PromptUpdate]:
processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
......@@ -289,7 +289,8 @@ class VoxtralMultiModalProcessor(BaseMultiModalProcessor[VoxtralProcessingInfo]
tokenization_kwargs: Mapping[str, object],
*,
return_mm_hashes: bool,
) -> tuple[list[int], MultiModalKwargs, Optional[MultiModalHashes], bool]:
) -> tuple[list[int], MultiModalKwargsItems, Optional[MultiModalHashes],
bool]:
prompt_ids, mm_kwargs, mm_hashes, _ = super(
)._cached_apply_hf_processor(
prompt=prompt,
......
......@@ -33,7 +33,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY, NestedTensors
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalKwargs)
MultiModalKwargsItems)
from vllm.multimodal.parse import MultiModalDataItems, MultiModalDataParser
from vllm.multimodal.processing import (BaseProcessingInfo,
EncDecMultiModalProcessor,
......@@ -728,7 +728,7 @@ class WhisperMultiModalProcessor(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
out_mm_kwargs: MultiModalKwargsItems,
) -> Sequence[PromptUpdate]:
num_tokens = self.info.get_num_audio_tokens()
return [
......
......@@ -4,7 +4,8 @@ from .base import MultiModalPlaceholderMap
from .hasher import MultiModalHashDict, MultiModalHasher
from .inputs import (BatchedTensorInputs, ModalityData, MultiModalDataBuiltins,
MultiModalDataDict, MultiModalKwargs,
MultiModalPlaceholderDict, NestedTensors)
MultiModalKwargsItems, MultiModalPlaceholderDict,
NestedTensors)
from .registry import MultiModalRegistry
MULTIMODAL_REGISTRY = MultiModalRegistry()
......@@ -25,6 +26,7 @@ __all__ = [
"MultiModalHashDict",
"MultiModalHasher",
"MultiModalKwargs",
"MultiModalKwargsItems",
"MultiModalPlaceholderDict",
"MultiModalPlaceholderMap",
"NestedTensors",
......
......@@ -9,7 +9,7 @@ from typing import TYPE_CHECKING, Generic, NamedTuple, TypeVar
if TYPE_CHECKING:
from vllm.sequence import SequenceGroupMetadata
from .inputs import MultiModalKwargs, NestedTensors, PlaceholderRange
from .inputs import MultiModalKwargs, PlaceholderRange
_T = TypeVar("_T")
......@@ -56,8 +56,7 @@ class MultiModalPlaceholderMap:
@classmethod
def from_seq_group(
cls, seq_group: "SequenceGroupMetadata", positions: range
) -> tuple[dict[str, NestedTensors], dict[str,
"MultiModalPlaceholderMap"]]:
) -> tuple[MultiModalKwargs, dict[str, "MultiModalPlaceholderMap"]]:
"""
Returns the multi-modal items that intersect with the portion of a
prompt (``seq_group``) represented by ``positions``, as well as a
......@@ -100,7 +99,7 @@ class MultiModalPlaceholderMap:
seq_mm_placeholders = seq_group.multi_modal_placeholders
if not seq_mm_data or not seq_mm_placeholders:
return MultiModalKwargs().get_data(), {}
return MultiModalKwargs(), {}
placeholder_maps = dict[str, MultiModalPlaceholderMap]()
......@@ -117,8 +116,6 @@ class MultiModalPlaceholderMap:
placeholder_maps[modality] = placeholder_map
seq_mm_data = seq_mm_data if isinstance(
seq_mm_data, dict) else seq_mm_data.get_data()
return seq_mm_data, placeholder_maps
def append_items_from_seq_group(
......
......@@ -11,7 +11,9 @@ from vllm.logger import init_logger
from vllm.utils import GiB_bytes, LRUCache
from vllm.utils.jsontree import json_map_leaves, json_reduce_leaves
from .inputs import MultiModalKwargs, MultiModalKwargsItem, NestedTensors
from .inputs import (MultiModalFieldElem, MultiModalKwargs,
MultiModalKwargsItem, MultiModalKwargsItems,
NestedTensors)
logger = init_logger(__name__)
......@@ -26,8 +28,9 @@ class MultiModalCacheItemMetadata:
MultiModalCacheValue = Union[
MultiModalKwargs,
MultiModalKwargsItems,
MultiModalKwargsItem,
MultiModalKwargs,
Mapping[str, NestedTensors],
MultiModalCacheItemMetadata,
]
......@@ -44,14 +47,16 @@ class MultiModalCache:
*,
debug: bool = False,
) -> int:
# MultiModalKwargs is not a subclass of dict
if isinstance(leaf, MultiModalKwargs):
return cls.get_item_size(leaf.get_data(), debug=debug)
if isinstance(leaf, MultiModalFieldElem):
return cls.get_item_size(leaf.data) # type: ignore
# MultiModalKwargsItem is not a subclass of dict
# These are not subclasses of dict
if isinstance(leaf, MultiModalKwargsItems):
return cls.get_item_size(leaf.data) # type: ignore
if isinstance(leaf, MultiModalKwargsItem):
leaf_data = {k: v.data for k, v in leaf.items()}
return cls.get_item_size(leaf_data, debug=debug)
return cls.get_item_size(leaf.data) # type: ignore
if isinstance(leaf, MultiModalKwargs):
return cls.get_item_size(leaf.data) # type: ignore
# sys.getsizeof doesn't work for tensors
if isinstance(leaf, torch.Tensor):
......
......@@ -11,7 +11,7 @@ from typing import (TYPE_CHECKING, Any, Literal, Optional, TypedDict, TypeVar,
Union, cast, final)
import numpy as np
from typing_extensions import NotRequired, TypeAlias
from typing_extensions import NotRequired, TypeAlias, deprecated
from vllm.utils import LazyLoader, full_groupby, is_list_of
from vllm.utils.jsontree import JSONTree, json_map_leaves
......@@ -656,7 +656,7 @@ class MultiModalKwargsItem(UserDict[str, MultiModalFieldElem]):
def __init__(self, data: Mapping[str, MultiModalFieldElem] = {}) -> None:
super().__init__(data)
modalities = {elem.modality for elem in self.data.values()}
modalities = {elem.modality for elem in self.values()}
assert len(modalities) == 1, f"Found different modalities={modalities}"
self._modality = next(iter(modalities))
......@@ -668,16 +668,11 @@ class MultiModalKwargsItem(UserDict[str, MultiModalFieldElem]):
return {key: elem.data for key, elem in self.items()}
class MultiModalKwargs:
class MultiModalKwargsItems(UserDict[str, Sequence[MultiModalKwargsItem]]):
"""
A dictionary that represents the keyword arguments to
[`torch.nn.Module.forward`][].
The metadata `items` enables us to obtain the keyword arguments
corresponding to each data item in
[`MultiModalDataItems`][vllm.multimodal.parse.MultiModalDataItems], via
[`get_item`][vllm.multimodal.inputs.MultiModalKwargs.get_item] and
[`get_items`][vllm.multimodal.inputs.MultiModalKwargs.get_items].
A dictionary of
[`MultiModalKwargsItem`][vllm.multimodal.inputs.MultiModalKwargsItem]s
by modality.
"""
@staticmethod
......@@ -712,19 +707,64 @@ class MultiModalKwargs:
elems = [v[item_idx] for v in elems_in_modality.values()]
items.append(MultiModalKwargsItem.from_elems(elems))
return MultiModalKwargs(items)
def __init__(self, items: Sequence[MultiModalKwargsItem] = ()) -> None:
super().__init__()
return MultiModalKwargsItems.from_seq(items)
@staticmethod
def from_seq(items: Sequence[MultiModalKwargsItem]):
items_by_modality = full_groupby(items, key=lambda x: x.modality)
self._items_by_modality = dict(items_by_modality)
return MultiModalKwargsItems(items_by_modality)
self._data: Optional[dict[str, NestedTensors]] = None
def __getitem__(self, modality: str):
if modality not in self:
raise KeyError(f"Modality {modality!r} not found. "
f"Available modalities: {set(self.keys())}")
return super().__getitem__(modality)
def get_data(self, *, pin_memory: bool = False) -> "MultiModalKwargs":
elems_by_key = defaultdict[str, list[MultiModalFieldElem]](list)
for items in self.values():
for item in items:
for key, elem in item.items():
elems_by_key[key].append(elem)
return MultiModalKwargs({
key:
elems[0].field.reduce_data(elems, pin_memory=pin_memory)
for key, elems in elems_by_key.items() if len(elems) > 0
})
@property
def modalities(self):
return self._items_by_modality.keys()
class MultiModalKwargs(UserDict[str, NestedTensors]):
"""
A dictionary that represents the keyword arguments to
[`torch.nn.Module.forward`][].
"""
@staticmethod
@deprecated("`MultiModalKwargs.from_hf_inputs` is deprecated and "
"will be removed in v0.13. "
"Please use `MultiModalKwargsItems.from_hf_inputs` and "
"access the tensor data using `.get_data()`.")
def from_hf_inputs(
hf_inputs: "BatchFeature",
config_by_key: Mapping[str, MultiModalFieldConfig],
):
return MultiModalKwargsItems.from_hf_inputs(hf_inputs, config_by_key) \
.get_data()
@staticmethod
@deprecated("`MultiModalKwargs.from_items` is deprecated and "
"will be removed in v0.13. "
"Please use `MultiModalKwargsItems.from_seq` and "
"access the tensor data using `.get_data()`.")
def from_items(
items: Sequence[MultiModalKwargsItem],
*,
pin_memory: bool = False,
):
return MultiModalKwargsItems.from_seq(items) \
.get_data(pin_memory=pin_memory)
@staticmethod
def _try_stack(nested_tensors: NestedTensors,
......@@ -813,92 +853,24 @@ class MultiModalKwargs:
return cast(BatchedTensorInputs, json_mapped)
def keys(self):
return self.get_data().keys()
def values(self):
return self.get_data().values()
def items(self):
return self.get_data().items()
def get(self, key: str, /, default=None):
return self.get_data().get(key, default)
def pop(self, key: str, *args, **kwargs):
data = dict(self.get_data())
res = data.pop(key, *args, **kwargs)
for items in self._items_by_modality.values():
for item in items:
item.pop(key, *args, **kwargs)
self._data = None
return res
def __iter__(self):
return iter(self.get_data())
def __getitem__(self, key: str):
return self.get_data()[key]
if key not in self:
raise KeyError(f"Keyword argument {key!r} not found. "
f"Available keys: {set(self.keys())}")
return super().__getitem__(key)
def __eq__(self, other: object) -> bool:
if not isinstance(other, self.__class__):
return False
return self._items_by_modality == other._items_by_modality
def _validate_modality(self, method_name: str, modality: str) -> None:
if not self._items_by_modality:
raise RuntimeError(
f"`{method_name}` is not supported when "
"MultiModalKwargs is not initialized with `items`")
if modality not in self._items_by_modality:
available_modalities = set(self._items_by_modality.keys())
raise KeyError(f"Modality {modality!r} not found. "
f"Available modalities: {available_modalities}")
def get_item_count(self, modality: str) -> int:
"""Get the number of items belonging to a modality."""
self._validate_modality("get_item_count", modality)
return len(self._items_by_modality[modality])
def get_item(self, modality: str, item_index: int) -> MultiModalKwargsItem:
"""
Get the keyword arguments corresponding to an item identified by
its modality and index.
"""
self._validate_modality("get_item", modality)
return self._items_by_modality[modality][item_index]
def get_items(self, modality: str) -> Sequence[MultiModalKwargsItem]:
"""
Get the keyword arguments corresponding to each item belonging to
a modality.
"""
self._validate_modality("get_items", modality)
return self._items_by_modality[modality]
def get_data(self,
*,
pin_memory: bool = False) -> dict[str, NestedTensors]:
if self._data is not None:
return self._data
for k in self:
if k not in other:
return False
if not nested_tensors_equal(self[k], other[k]):
return False
elems_by_key = defaultdict[str, list[MultiModalFieldElem]](list)
for items in self._items_by_modality.values():
for item in items:
for key, elem in item.items():
elems_by_key[key].append(elem)
data = {
key: elems[0].field.reduce_data(elems, pin_memory=pin_memory)
for key, elems in elems_by_key.items() if len(elems) > 0
}
self._data = data
return data
return True
MultiModalPlaceholderDict: TypeAlias = Mapping[str, Sequence[PlaceholderRange]]
......@@ -926,7 +898,7 @@ class MultiModalInputs(TypedDict):
token_type_ids: NotRequired[list[int]]
"""The token type IDs of the prompt."""
mm_kwargs: MultiModalKwargs
mm_kwargs: MultiModalKwargsItems
"""Keyword arguments to be directly passed to the model after batching."""
mm_hashes: Optional["MultiModalHashDict"]
......
......@@ -16,7 +16,7 @@ from vllm.utils import LazyLoader, is_list_of
from .audio import AudioResampler
from .inputs import (AudioItem, HfAudioItem, HfImageItem, HfVideoItem,
ImageItem, ModalityData, MultiModalDataDict,
MultiModalFieldConfig, MultiModalKwargs, VideoItem)
MultiModalFieldConfig, MultiModalKwargsItems, VideoItem)
_T = TypeVar("_T")
_I = TypeVar("_I")
......@@ -157,19 +157,16 @@ class DictEmbeddingItems(ModalityDataItems[Mapping[str, torch.Tensor],
self.fields_config = fields_config
self.required_fields = required_fields
self._kwargs = MultiModalKwargs.from_hf_inputs(
self._kwargs = MultiModalKwargsItems.from_hf_inputs(
BatchFeature(dict(data)),
fields_config,
)
def get_count(self) -> int:
return self._kwargs.get_item_count(self.modality)
return len(self._kwargs[self.modality])
def get(self, index: int) -> Mapping[str, torch.Tensor]:
return {
k: v.data
for k, v in self._kwargs.get_item(self.modality, index).items()
}
return self._kwargs[self.modality][index].get_data()
def get_processor_data(self) -> Mapping[str, object]:
return {}
......
......@@ -23,8 +23,9 @@ from vllm.utils import flatten_2d_lists, full_groupby
from .cache import MultiModalCache
from .hasher import MultiModalHasher
from .inputs import (MultiModalDataDict, MultiModalEncDecInputs,
MultiModalFieldConfig, MultiModalInputs, MultiModalKwargs,
MultiModalKwargsItem, PlaceholderRange)
MultiModalFieldConfig, MultiModalInputs,
MultiModalKwargsItem, MultiModalKwargsItems,
PlaceholderRange)
from .parse import (DictEmbeddingItems, EmbeddingItems, MultiModalDataItems,
MultiModalDataParser)
......@@ -985,7 +986,7 @@ _I = TypeVar("_I", bound=BaseProcessingInfo)
MultiModalHashes = dict[str, list[str]]
"""
A collection of hashes with a similar structure as
[`MultiModalKwargs`][vllm.multimodal.inputs.MultiModalKwargs].
[`MultiModalKwargsItems`][vllm.multimodal.inputs.MultiModalKwargsItems].
"""
......@@ -1095,7 +1096,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
out_mm_kwargs: MultiModalKwargsItems,
) -> Sequence[PromptUpdate]:
"""
Given the original multi-modal items for this modality
......@@ -1361,7 +1362,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
self,
cache: ProcessingCache,
mm_cache_items_or_hashes: dict[str, list[_CacheItemOrHash]],
mm_missing_kwargs: MultiModalKwargs,
mm_missing_kwargs: MultiModalKwargsItems,
) -> dict[str, list[MultiModalKwargsItem]]:
mm_missing_next_idx = defaultdict[str, int](lambda: 0)
......@@ -1369,10 +1370,8 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
for modality, items_or_hashes in mm_cache_items_or_hashes.items():
for item_or_hash in items_or_hashes:
if isinstance(item_or_hash, str):
kw_item = mm_missing_kwargs.get_item(
modality,
mm_missing_next_idx[modality],
)
kw_item = mm_missing_kwargs[modality][
mm_missing_next_idx[modality]]
cache.put(item_or_hash, kw_item)
mm_missing_next_idx[modality] += 1
else:
......@@ -1390,7 +1389,8 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
tokenization_kwargs: Mapping[str, object],
*,
return_mm_hashes: bool,
) -> tuple[list[int], MultiModalKwargs, Optional[MultiModalHashes], bool]:
) -> tuple[list[int], MultiModalKwargsItems, Optional[MultiModalHashes],
bool]:
(
prompt_ids,
mm_processed_data,
......@@ -1403,7 +1403,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
enable_hf_prompt_update=True,
)
mm_kwargs = MultiModalKwargs.from_hf_inputs(
mm_kwargs = MultiModalKwargsItems.from_hf_inputs(
mm_processed_data,
self._get_mm_fields_config(mm_processed_data,
hf_processor_mm_kwargs),
......@@ -1423,7 +1423,8 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
tokenization_kwargs: Mapping[str, object],
*,
return_mm_hashes: bool,
) -> tuple[list[int], MultiModalKwargs, Optional[MultiModalHashes], bool]:
) -> tuple[list[int], MultiModalKwargsItems, Optional[MultiModalHashes],
bool]:
"""
Apply the HF processor on the full prompt text,
caching the results and reusing cached results.
......@@ -1468,7 +1469,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
enable_hf_prompt_update=False,
)
mm_missing_kwargs = MultiModalKwargs.from_hf_inputs(
mm_missing_kwargs = MultiModalKwargsItems.from_hf_inputs(
mm_missing_processed_data,
self._get_mm_fields_config(mm_missing_processed_data,
hf_processor_mm_kwargs),
......@@ -1480,7 +1481,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
mm_missing_kwargs=mm_missing_kwargs,
)
mm_kwargs = MultiModalKwargs([
mm_kwargs = MultiModalKwargsItems.from_seq([
item for cache_items in mm_cache_items_merged.values()
for item in cache_items
])
......@@ -1585,14 +1586,11 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
def _validate_mm_kwargs(
self,
mm_kwargs: MultiModalKwargs,
mm_kwargs: MultiModalKwargsItems,
mm_item_counts: Mapping[str, int],
) -> None:
for modality, item_count in mm_item_counts.items():
if modality in mm_kwargs.modalities:
items = mm_kwargs.get_items(modality)
else:
items = []
items = mm_kwargs.get(modality, [])
if len(items) != item_count:
raise RuntimeError(
......@@ -1630,7 +1628,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
prompt_ids: list[int],
mm_kwargs: MultiModalKwargs,
mm_kwargs: MultiModalKwargsItems,
is_update_applied: bool,
) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]:
unbound_prompt_updates = self._get_prompt_updates(
......
......@@ -13,7 +13,7 @@ import vllm.envs as envs
from vllm.logger import init_logger
from .inputs import (MultiModalDataDict, MultiModalEncDecInputs,
MultiModalInputs, MultiModalKwargs,
MultiModalInputs, MultiModalKwargsItems,
MultiModalPlaceholderDict)
from .processing import (BaseMultiModalProcessor, BaseProcessingInfo,
EncDecMultiModalProcessor)
......@@ -43,7 +43,7 @@ class DummyDecoderData(NamedTuple):
"""Dummy data used for profiling."""
prompt_token_ids: list[int]
multi_modal_data: MultiModalKwargs
multi_modal_data: MultiModalKwargsItems
multi_modal_placeholders: MultiModalPlaceholderDict
......
......@@ -32,11 +32,13 @@ _M = TypeVar("_M")
if TYPE_CHECKING:
from .inputs import (BatchedTensorInputs, MultiModalKwargs,
MultiModalKwargsItem, MultiModalPlaceholderDict)
MultiModalKwargsItem, MultiModalKwargsItems,
MultiModalPlaceholderDict)
else:
BatchedTensorInputs = Any
MultiModalKwargs = Any
MultiModalKwargsItem = Any
MultiModalKwargsItems = Any
MultiModalPlaceholderDict = Any
global_thread_pool = ThreadPoolExecutor(
......@@ -359,18 +361,20 @@ def argsort_mm_positions(
"`group_mm_kwargs_by_modality` and will be removed in v0.13. "
"Please use `group_mm_kwargs_by_modality` instead.")
def group_mm_inputs_by_modality(
mm_inputs: list[MultiModalKwargs]) -> list[list[MultiModalKwargs]]:
mm_inputs: list[MultiModalKwargsItems]
) -> list[list[MultiModalKwargsItems]]:
if not mm_inputs:
return []
def modality_group_func(mm_input: MultiModalKwargs) -> Union[str, int]:
def modality_group_func(
mm_input: MultiModalKwargsItems) -> Union[str, int]:
# If the input has multiple modalities, return a id as the unique key
# for the mm_input input.
if len(mm_input.modalities) > 1:
if len(mm_input) > 1:
return id(mm_input)
elif len(mm_input.modalities) == 1:
return list(mm_input.modalities)[0]
elif len(mm_input) == 1:
return next(iter(mm_input.keys()))
# FIXME(Isotr0py): Modality of mm_input from legacy pipeline is empty,
# this is used to make InternVL with legacy pipeline still work with v1.
......@@ -397,12 +401,12 @@ def group_mm_kwargs_by_modality(
Yields:
A tuple `(modality, num_items, grouped_kwargs)`.
"""
from vllm.multimodal.inputs import MultiModalKwargs
from vllm.multimodal.inputs import MultiModalKwargs, MultiModalKwargsItems
for modality, items in groupby(mm_kwargs, key=lambda item: item.modality):
items_lst = list(items)
# mm_kwargs_group = MultiModalKwargs(items_lst) \
# mm_kwargs_group = MultiModalKwargsItems.from_items(items_lst) \
# .get_data(pin_memory=pin_memory)
# if device is not None:
......@@ -417,7 +421,10 @@ def group_mm_kwargs_by_modality(
# We will also need to update each model to remove `flatten_bn`.
mm_kwargs_group = MultiModalKwargs.as_kwargs(
MultiModalKwargs.batch(
[MultiModalKwargs([item]) for item in items_lst],
[
MultiModalKwargsItems.from_seq([item]).get_data()
for item in items_lst
],
pin_memory=pin_memory,
),
device=device,
......
......@@ -22,7 +22,6 @@ from vllm.pooling_params import PoolingParams
from vllm.sampling_params import RequestOutputKind, SamplingParams
if TYPE_CHECKING:
from vllm.multimodal.inputs import NestedTensors
from vllm.v1.worker.kv_connector_model_runner_mixin import (
KVConnectorOutput)
......@@ -523,7 +522,7 @@ class Sequence:
@property
def multi_modal_data(self) -> MultiModalKwargs:
if self.inputs["type"] == "multimodal":
return self.inputs["mm_kwargs"]
return self.inputs["mm_kwargs"].get_data()
return MultiModalKwargs()
......@@ -979,8 +978,7 @@ class SequenceGroupMetadata(
state: Optional[SequenceGroupState] = msgspec.field(
default_factory=lambda: SequenceGroupState())
token_type_ids: Optional[list[int]] = None
multi_modal_data: Optional[Union[MultiModalKwargs,
dict[str, "NestedTensors"]]] = None
multi_modal_data: Optional[MultiModalKwargs] = None
multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None
encoder_seq_data: Optional[SequenceData] = None
cross_block_table: Optional[list[int]] = None
......
......@@ -310,7 +310,7 @@ class Processor:
sorted_mm_idxs = argsort_mm_positions(decoder_mm_positions)
orig_sorted_mm_inputs = [
decoder_mm_inputs.get_item(modality, idx)
decoder_mm_inputs[modality][idx]
for modality, idx in sorted_mm_idxs
]
sorted_mm_positions = [
......
......@@ -18,12 +18,15 @@ from msgspec import msgpack
from vllm import envs
from vllm.logger import init_logger
# yapf: disable
from vllm.multimodal.inputs import (BaseMultiModalField,
MultiModalBatchedField,
MultiModalFieldConfig, MultiModalFieldElem,
MultiModalFlatField, MultiModalKwargs,
MultiModalKwargsItem,
MultiModalKwargsItems,
MultiModalSharedField, NestedTensors)
# yapf: enable
from vllm.v1.engine import UtilityResult
logger = init_logger(__name__)
......@@ -116,12 +119,11 @@ class MsgpackEncoder:
if isinstance(obj, MultiModalKwargsItem):
return self._encode_mm_item(obj)
if isinstance(obj, MultiModalKwargsItems):
return self._encode_mm_items(obj)
if isinstance(obj, MultiModalKwargs):
return [
self._encode_mm_item(item)
for itemlist in obj._items_by_modality.values()
for item in itemlist
]
return self._encode_mm_kwargs(obj)
if isinstance(obj, UtilityResult):
result = obj.result
......@@ -183,6 +185,12 @@ class MsgpackEncoder:
dtype = str(obj.dtype).removeprefix("torch.")
return dtype, obj.shape, data
def _encode_mm_items(self, items: MultiModalKwargsItems) -> dict[str, Any]:
return {
modality: [self._encode_mm_item(item) for item in itemlist]
for modality, itemlist in items.items()
}
def _encode_mm_item(self,
item: MultiModalKwargsItem) -> list[dict[str, Any]]:
return [self._encode_mm_field_elem(elem) for elem in item.values()]
......@@ -200,6 +208,12 @@ class MsgpackEncoder:
self._encode_mm_field(elem.field),
}
def _encode_mm_kwargs(self, kw: MultiModalKwargs) -> dict[str, Any]:
return {
modality: self._encode_nested_tensors(data)
for modality, data in kw.items()
}
def _encode_nested_tensors(self, nt: NestedTensors) -> Any:
if isinstance(nt, torch.Tensor):
return self._encode_tensor(nt)
......@@ -260,8 +274,10 @@ class MsgpackDecoder:
return slice(*obj)
if issubclass(t, MultiModalKwargsItem):
return self._decode_mm_item(obj)
if issubclass(t, MultiModalKwargsItems):
return self._decode_mm_items(obj)
if issubclass(t, MultiModalKwargs):
return MultiModalKwargs(self._decode_mm_items(obj))
return self._decode_mm_kwargs(obj)
if t is UtilityResult:
return self._decode_utility_result(obj)
return obj
......@@ -315,8 +331,11 @@ class MsgpackDecoder:
# Convert back to proper shape & type
return arr.view(torch_dtype).view(shape)
def _decode_mm_items(self, obj: list[Any]) -> list[MultiModalKwargsItem]:
return [self._decode_mm_item(v) for v in obj]
def _decode_mm_items(self, obj: dict[str, Any]) -> MultiModalKwargsItems:
return MultiModalKwargsItems({
modality: [self._decode_mm_item(item) for item in itemlist]
for modality, itemlist in obj.items()
})
def _decode_mm_item(self, obj: list[Any]) -> MultiModalKwargsItem:
return MultiModalKwargsItem.from_elems(
......@@ -339,6 +358,12 @@ class MsgpackDecoder:
obj["field"] = factory_meth(None, *field_args).field
return MultiModalFieldElem(**obj)
def _decode_mm_kwargs(self, obj: dict[str, Any]) -> MultiModalKwargs:
return MultiModalKwargs({
modality: self._decode_nested_tensors(data)
for modality, data in obj.items()
})
def _decode_nested_tensors(self, obj: Any) -> NestedTensors:
if isinstance(obj, (int, float)):
# Although it violates NestedTensors type, MultiModalKwargs
......
......@@ -10,8 +10,8 @@ import torch
from typing_extensions import deprecated
from vllm.lora.request import LoRARequest
from vllm.multimodal.inputs import (MultiModalKwargs, MultiModalKwargsItem,
PlaceholderRange)
from vllm.multimodal.inputs import (MultiModalKwargsItem,
MultiModalKwargsItems, PlaceholderRange)
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.utils import swap_dict_values
......@@ -57,8 +57,10 @@ class CachedRequestState:
@property
@deprecated("`mm_inputs` is superseded by `mm_kwargs` and will be "
"removed in v0.13. Please use `mm_kwargs` instead.")
def mm_inputs(self) -> list[MultiModalKwargs]:
return [MultiModalKwargs([item]) for item in self.mm_kwargs]
def mm_inputs(self) -> list[MultiModalKwargsItems]:
return [
MultiModalKwargsItems.from_seq([item]) for item in self.mm_kwargs
]
def get_token_id(self, idx: int) -> int:
if idx < self.num_prompt_tokens:
......
......@@ -2218,11 +2218,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
dummy_mm_data = dummy_decoder_data.multi_modal_data
# Result in the maximum GPU consumption of the model
dummy_mm_item = dummy_mm_data.get_item(modality=modality, item_index=0)
dummy_mm_item = dummy_mm_data[modality][0]
dummy_mm_items = [dummy_mm_item] * max_items_per_batch
return next(mm_kwargs_group
for _, _, mm_kwargs_group in group_mm_kwargs_by_modality(
[dummy_mm_item] * max_items_per_batch,
dummy_mm_items,
device=self.device,
pin_memory=self.pin_memory,
))
......
......@@ -1824,11 +1824,12 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
dummy_mm_data = dummy_decoder_data.multi_modal_data
# Result in the maximum GPU consumption of the model
dummy_mm_item = dummy_mm_data.get_item(modality=modality, item_index=0)
dummy_mm_item = dummy_mm_data[modality][0]
dummy_mm_items = [dummy_mm_item] * max_items_per_batch
return next(grouped_mm_kwargs
for _, _, grouped_mm_kwargs in group_mm_kwargs_by_modality(
[dummy_mm_item] * max_items_per_batch,
dummy_mm_items,
device=self.device,
pin_memory=self.pin_memory,
))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment