inputs.py 30.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from abc import ABC, abstractmethod
5
from collections import UserDict, defaultdict
6
from collections.abc import Mapping, Sequence
7
from dataclasses import dataclass
8
from functools import partial
9
from itertools import accumulate
10
11
from typing import (TYPE_CHECKING, Any, Literal, Optional, TypedDict, Union,
                    cast, final)
12
13

import numpy as np
14
from typing_extensions import NotRequired, TypeAlias, TypeVar, deprecated
15

16
from vllm.utils import LazyLoader, full_groupby, is_list_of
17
from vllm.utils.jsontree import JSONTree, json_map_leaves
18

19
if TYPE_CHECKING:
    import torch
    import torch.types
    from PIL.Image import Image
    from transformers.feature_extraction_utils import BatchFeature

    from .processing import MultiModalHashes
else:
    # At runtime `torch` is bound to a `LazyLoader` (presumably deferring the
    # real import until first use); the remaining names above are type-only.
    torch = LazyLoader("torch", globals(), "torch")

30
31
_T = TypeVar("_T")

32
HfImageItem: TypeAlias = Union["Image", np.ndarray, "torch.Tensor"]
33
"""
34
A `transformers.image_utils.ImageInput` representing a single image
35
item, which can be passed to a HuggingFace `ImageProcessor`.
36
37
"""

38
39
HfVideoItem: TypeAlias = Union[list["Image"], np.ndarray, "torch.Tensor",
                               list[np.ndarray], list["torch.Tensor"]]
40
"""
41
A `transformers.image_utils.VideoInput` representing a single video
42
item, which can be passed to a HuggingFace `VideoProcessor`.
43
44
"""

45
HfAudioItem: TypeAlias = Union[list[float], np.ndarray, "torch.Tensor"]
46
"""
47
Represents a single audio
48
item, which can be passed to a HuggingFace `AudioProcessor`.
49
50
"""

51
ImageItem: TypeAlias = Union[HfImageItem, "torch.Tensor"]
52
"""
53
A `transformers.image_utils.ImageInput` representing a single image
54
item, which can be passed to a HuggingFace `ImageProcessor`.
55
56
57
58
59
60

Alternatively, a 3-D tensor or batch of 2-D tensors,
which are treated as image embeddings;
these are directly passed to the model without HF processing.
"""

61
62
VideoItem: TypeAlias = Union[HfVideoItem, "torch.Tensor",
                             tuple[HfVideoItem, dict[str, Any]]]
63
"""
64
65
66
A `transformers.video_utils.VideoInput` representing a single video item. 
This can be passed to a HuggingFace `VideoProcessor` 
with `transformers.video_utils.VideoMetadata`.
67
68
69
70
71
72
73

Alternatively, a 3-D tensor or batch of 2-D tensors,
which are treated as video embeddings;
these are directly passed to the model without HF processing.
"""

AudioItem: TypeAlias = Union[HfAudioItem, tuple[np.ndarray, float],
74
                             "torch.Tensor"]
75
76
"""
Represents a single audio
77
item, which can be passed to a HuggingFace `AudioProcessor`.
78
79
80
81
82
83
84
85
86
87
88

Alternatively, a tuple `(audio, sampling_rate)`, where the sampling rate
is different from that expected by the model;
these are resampled to the model's sampling rate before being processed by HF.

Alternatively, a 3-D tensor or batch of 2-D tensors,
which are treated as audio embeddings;
these are directly passed to the model without HF processing.
"""

ModalityData: TypeAlias = Union[_T, list[_T]]
89
90
91
92
"""
Either a single data item, or a list of data items.

The number of data items allowed per modality is restricted by
93
`--limit-mm-per-prompt`.
94
95
96
97
98
99
100
"""


@final
class MultiModalDataBuiltins(TypedDict, total=False):
    """
    Type annotations for modality types predefined by vLLM.

    Every key is optional (`total=False`). Each value is either a single
    item or a list of items of that modality.
    """

    image: ModalityData[ImageItem]
    """The input image(s)."""

    video: ModalityData[VideoItem]
    """The input video(s)."""

    audio: ModalityData[AudioItem]
    """The input audio(s)."""


MultiModalDataDict: TypeAlias = Mapping[str, ModalityData[Any]]
"""
A dictionary containing an entry for each modality type to input.

Each value is either a single data item or a list of data items.

The built-in modalities are defined by
[`MultiModalDataBuiltins`][vllm.multimodal.inputs.MultiModalDataBuiltins].
"""

119
120
121
122
123
124
125
126
127
128
MultiModalUUIDDict: TypeAlias = Mapping[str, Union[list[Optional[str]], str]]
"""
A dictionary containing user-provided UUIDs for items in each modality.
If a UUID for an item is not provided, its entry will be `None` and
MultiModalHasher will compute a hash for the item.

The UUID will be used to identify the item for all caching purposes
(input processing caching, embedding caching, prefix caching, etc).
"""

129

@dataclass(frozen=True)
class PlaceholderRange:
    """
    Placeholder location information for multi-modal data.

    Example:

    Prompt: `AAAA BBBB What is in these images?`

    Images A and B will have:

    ```
    A: PlaceholderRange(offset=0, length=4)
    B: PlaceholderRange(offset=5, length=4)
    ```
    """

    offset: int
    """The start index of the placeholder in the prompt."""

    length: int
    """The length of the placeholder."""

    is_embed: Optional["torch.Tensor"] = None
    """
    A boolean mask of shape `(length,)` indicating which positions
    between `offset` and `offset + length` to assign embeddings to.
    """

    def get_num_embeds(self) -> int:
        """
        Return the number of positions that receive embeddings.

        This is `length` when no mask is set, otherwise the number of
        `True` entries in `is_embed`.
        """
        if self.is_embed is None:
            return self.length

        return int(self.is_embed.sum().item())

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, self.__class__):
            return False
        if (self.offset, self.length) != (other.offset, other.length):
            return False

        # A missing mask is only equal to another missing mask.
        if self.is_embed is None or other.is_embed is None:
            return self.is_embed is None and other.is_embed is None

        return nested_tensors_equal(self.is_embed, other.is_embed)

    def __hash__(self) -> int:
        # `__eq__` compares `is_embed` by value, but tensors hash by identity,
        # so the dataclass-generated hash could differ for equal ranges.
        # Hash only the scalar fields to keep `a == b => hash(a) == hash(b)`.
        return hash((self.offset, self.length))

178

179
180
NestedTensors: TypeAlias = Union[list["NestedTensors"], list["torch.Tensor"],
                                 "torch.Tensor", tuple["torch.Tensor", ...]]
181
182
183
184
"""
Uses a list instead of a tensor if the dimensions of each element do not match.
"""

185
186

def nested_tensors_equal(a: NestedTensors, b: NestedTensors) -> bool:
187
188
    """Equality check between
    [`NestedTensors`][vllm.multimodal.inputs.NestedTensors] objects."""
189
    if isinstance(a, torch.Tensor):
190
        return isinstance(b, torch.Tensor) and torch.equal(a, b)
191
    elif isinstance(b, torch.Tensor):
192
        return isinstance(a, torch.Tensor) and torch.equal(b, a)
193
194
195
196
197
198
199
200
201
202
203
204
205

    if isinstance(a, list):
        return (isinstance(b, list)
                and all(nested_tensors_equal(a_, b_) for a_, b_ in zip(a, b)))
    if isinstance(b, list):
        return (isinstance(a, list)
                and all(nested_tensors_equal(b_, a_) for b_, a_ in zip(b, a)))

    # Both a and b are scalars
    return a == b


BatchedTensorInputs: TypeAlias = Mapping[str, NestedTensors]
206
207
"""
A dictionary containing nested tensors which have been batched via
208
[`MultiModalKwargs.batch`][vllm.multimodal.inputs.MultiModalKwargs.batch].
209
210
211
"""


@dataclass
class MultiModalFeatureSpec:
    """
    Represents a single multimodal input with its processed data and metadata.

    Used by the V1 engine to track multimodal data through processing and
    caching. A request containing multiple multimodal items will have one
    MultiModalFeatureSpec per item.
    """

    data: Optional["MultiModalKwargsItem"]
    """Multimodal data for this feature, or `None` if it is not attached."""

    modality: str
    """Based on the input, e.g., "image", "audio", "video"."""

    identifier: str
    """mm_hash or uuid for caching encoder outputs."""

    mm_position: PlaceholderRange
    """e.g., PlaceholderRange(offset=2, length=336)"""


@dataclass
class MultiModalFieldElem:
    """
    Represents a keyword argument corresponding to a multi-modal item
    in [`MultiModalKwargs`][vllm.multimodal.inputs.MultiModalKwargs].
    """

    modality: str
    """
    The modality of the corresponding multi-modal item.
    Each multi-modal item can consist of multiple keyword arguments.
    """

    key: str
    """
    The key of this field in
    [`MultiModalKwargs`][vllm.multimodal.inputs.MultiModalKwargs],
    i.e. the name of the keyword argument to be passed to the model.
    """

    data: NestedTensors
    """
    The tensor data of this field in
    [`MultiModalKwargs`][vllm.multimodal.inputs.MultiModalKwargs],
    i.e. the value of the keyword argument to be passed to the model.

    It may be set to `None` if it is determined that the item is cached
    in `EngineCore`.
    """

    field: "BaseMultiModalField"
    """
    Defines how to combine the tensor data of this field with others
    in order to batch multi-modal items together for model inference.
    """

    def __eq__(self, other: object) -> bool:
        """
        Compare modality, key, field type and data (tensors by value).

        The cheap scalar comparisons run first, so tensor data is only
        inspected when everything else already matches.
        """
        if not isinstance(other, self.__class__):
            return False
        if (self.modality, self.key) != (other.modality, other.key):
            return False
        if type(self.field) != type(other.field):  # noqa: E721
            return False

        # `None` data (see the `data` docs) only equals `None` data.
        if self.data is None or other.data is None:
            return self.data is None and other.data is None

        return nested_tensors_equal(self.data, other.data)


@dataclass(frozen=True)
class BaseMultiModalField(ABC):
    """
    Defines how to interpret tensor data belonging to a keyword argument in
    [`MultiModalKwargs`][vllm.multimodal.inputs.MultiModalKwargs] for multiple
    multi-modal items, and vice versa.
    """

    def _field_factory(self, *, modality: str, key: str):
        """
        Return a one-argument constructor of
        [`MultiModalFieldElem`][vllm.multimodal.inputs.MultiModalFieldElem]
        with `modality`, `key` and `field=self` already bound.
        """
        f = partial(
            MultiModalFieldElem,
            modality=modality,
            key=key,
            field=self,
        )

        # Allow passing data as positional argument
        def factory(data: NestedTensors) -> MultiModalFieldElem:
            return f(data=data)

        return factory

    @abstractmethod
    def build_elems(
        self,
        modality: str,
        key: str,
        data: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        """
        Construct
        [`MultiModalFieldElem`][vllm.multimodal.inputs.MultiModalFieldElem]
        instances to represent the provided data.

        This is the inverse of
        [`reduce_data`][vllm.multimodal.inputs.BaseMultiModalField.reduce_data].
        """
        raise NotImplementedError

    @abstractmethod
    def _reduce_data(
        self,
        batch: list[NestedTensors],
        *,
        pin_memory: bool,
    ) -> NestedTensors:
        """
        Merge the raw data of several elements into a single value.

        Args:
            batch: The `data` of each element, in order.
            pin_memory: Forwarded from `reduce_data`. The implementations in
                this module pass it to `torch.empty` when allocating the
                merged output tensor.
        """
        raise NotImplementedError

    def reduce_data(
        self,
        elems: list[MultiModalFieldElem],
        *,
        pin_memory: bool = False,
    ) -> NestedTensors:
        """
        Merge the data from multiple instances of
        [`MultiModalFieldElem`][vllm.multimodal.inputs.MultiModalFieldElem].

        This is the inverse of
        [`build_elems`][vllm.multimodal.inputs.BaseMultiModalField.build_elems].

        Raises:
            ValueError: If the elements do not all use the same field type.
        """
        field_types = [type(item.field) for item in elems]
        if len(set(field_types)) > 1:
            raise ValueError(f"Cannot merge different {field_types=}")

        batch = [elem.data for elem in elems]
        return self._reduce_data(batch, pin_memory=pin_memory)


@dataclass(frozen=True)
class MultiModalBatchedField(BaseMultiModalField):
    """
    Info:
        [`MultiModalFieldConfig.batched`][vllm.multimodal.inputs.MultiModalFieldConfig.batched]
    """

    def build_elems(
        self,
        modality: str,
        key: str,
        data: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        """Create one element per entry obtained by iterating over `data`
        (i.e. along its first dimension)."""
        field_factory = self._field_factory(modality=modality, key=key)
        return [field_factory(item) for item in data]

    def _reduce_data(
        self,
        batch: list[NestedTensors],
        *,
        pin_memory: bool,
    ) -> NestedTensors:
        """
        Stack `batch` along a new leading dimension if it is a non-empty list
        of tensors that all share one shape; otherwise return it unchanged.
        """
        if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"):
            if len(batch) == 1:
                # An optimization when `batch` contains only one tensor:
                # - produce exactly same result as `torch.stack(batch)`
                # - will achieve zero-copy if the tensor is contiguous
                return batch[0].unsqueeze(0).contiguous()

            first_shape = batch[0].shape
            if all(elem.shape == first_shape for elem in batch):
                out = torch.empty((len(batch), *first_shape),
                                  dtype=batch[0].dtype,
                                  device=batch[0].device,
                                  pin_memory=pin_memory)
                return torch.stack(batch, out=out)

        return batch


@dataclass(frozen=True)
class MultiModalFlatField(BaseMultiModalField):
    """
    Info:
        [`MultiModalFieldConfig.flat`][vllm.multimodal.inputs.MultiModalFieldConfig.flat]
        [`MultiModalFieldConfig.flat_from_sizes`][vllm.multimodal.inputs.MultiModalFieldConfig.flat_from_sizes]
    """

    slices: Union[Sequence[slice], Sequence[Sequence[slice]]]
    """
    For each multi-modal item, the index used to extract its data:
    a slice (dim=0) or a tuple of slices (dim>0).
    """

    dim: int = 0
    """The dimension along which data is sliced and later concatenated."""

    def build_elems(
        self,
        modality: str,
        key: str,
        data: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        """Create one element per entry of `slices` by indexing `data`."""
        field_factory = self._field_factory(modality=modality, key=key)
        if not is_list_of(self.slices, slice, check="all"):
            # Indexing with a tuple of slices needs a real tensor.
            assert isinstance(data, torch.Tensor), \
                "torch.Tensor is required for multiple slices"
        return [field_factory(data[cast(slice, s)]) for s in self.slices]

    def _reduce_data(
        self,
        batch: list[NestedTensors],
        *,
        pin_memory: bool,
    ) -> NestedTensors:
        """
        Concatenate `batch` along `dim` if it is a non-empty list of tensors
        whose shapes agree outside `dim`. Otherwise (only allowed for
        `dim == 0`), flatten the outer level into a single list.
        """
        if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"):
            if len(batch) == 1:
                # An optimization when `batch` contains only one tensor:
                # - produce exactly same result as `torch.concat(batch)`
                # - will achieve zero-copy if the tensor is contiguous
                return batch[0].contiguous()

            # Normalize a negative `dim` to its positive equivalent.
            dim = self.dim + (self.dim < 0) * len(batch[0].shape)

            def _shape_before_after(tensor: torch.Tensor):
                return tensor.shape[:dim], tensor.shape[dim + 1:]

            first_shape = _shape_before_after(batch[0])

            if all(_shape_before_after(elem) == first_shape for elem in batch):
                shape_before, shape_after = first_shape
                shape_concat = sum(item.shape[dim] for item in batch)
                out = torch.empty((*shape_before, shape_concat, *shape_after),
                                  dtype=batch[0].dtype,
                                  device=batch[0].device,
                                  pin_memory=pin_memory)
                return torch.concat(batch, dim=dim, out=out)

        assert self.dim == 0, "dim == 0 is required for nested list"
        return [e for elem in batch for e in elem]


@dataclass(frozen=True)
class MultiModalSharedField(BaseMultiModalField):
    """
    Info:
        [`MultiModalFieldConfig.shared`][vllm.multimodal.inputs.MultiModalFieldConfig.shared]
    """

    batch_size: int
    """The number of multi-modal items that share the same data."""

    def build_elems(
        self,
        modality: str,
        key: str,
        data: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        """Return `batch_size` references to one element holding `data`."""
        field_factory = self._field_factory(modality=modality, key=key)
        return [field_factory(data)] * self.batch_size

    def _reduce_data(
        self,
        batch: list[NestedTensors],
        *,
        pin_memory: bool,
    ) -> NestedTensors:
        """Return the data of the first element; the data is shared, so no
        merging takes place (`pin_memory` is unused)."""
        return batch[0]


class MultiModalFieldConfig:
    """
    Describes how the data of one keyword argument is split into
    per-item
    [`MultiModalFieldElem`][vllm.multimodal.inputs.MultiModalFieldElem]s,
    and which modality those items belong to.

    The static methods below construct a config for each supported field
    type.
    """

    @staticmethod
    def batched(modality: str):
        """
        Defines a field where an element in the batch is obtained by
        indexing into the first dimension of the underlying data.

        Args:
            modality: The modality of the multi-modal item that uses this
                keyword argument.

        Example:

        ```
        Input:
            Data: [[AAAA]
                [BBBB]
                [CCCC]]

        Output:
            Element 1: [AAAA]
            Element 2: [BBBB]
            Element 3: [CCCC]
        ```
        """
        return MultiModalFieldConfig(
            field=MultiModalBatchedField(),
            modality=modality,
        )

    @staticmethod
    def flat(modality: str,
             slices: Union[Sequence[slice], Sequence[Sequence[slice]]],
             dim: int = 0):
        """
        Defines a field where an element in the batch is obtained by
        slicing the underlying data along dimension `dim` (the first
        dimension by default).

        Args:
            modality: The modality of the multi-modal item that uses this
                keyword argument.
            slices: For each multi-modal item, a slice (dim=0) or a tuple of
                slices (dim>0) that is used to extract the data corresponding
                to it.
            dim: The dimension to extract data, defaults to 0.

        Example:

        ```
        Given:
            slices: [slice(0, 3), slice(3, 7), slice(7, 9)]

        Input:
            Data: [AAABBBBCC]

        Output:
            Element 1: [AAA]
            Element 2: [BBBB]
            Element 3: [CC]
        ```

        ```
        Given:
            slices: [
                (slice(None), slice(0, 3)),
                (slice(None), slice(3, 7)),
                (slice(None), slice(7, 9))]
            dim: 1

        Input:
            Data: [[AAABBBBCC]]

        Output:
            Element 1: [[AAA]]
            Element 2: [[BBBB]]
            Element 3: [[CC]]
        ```
        """
        return MultiModalFieldConfig(
            field=MultiModalFlatField(slices=slices, dim=dim),
            modality=modality,
        )

    @staticmethod
    def flat_from_sizes(modality: str,
                        size_per_item: "torch.Tensor",
                        dim: int = 0):
        """
        Defines a field where an element in the batch is obtained by
        slicing the underlying data along dimension `dim` (the first
        dimension by default), using consecutive slices whose lengths are
        given by `size_per_item`.

        Args:
            modality: The modality of the multi-modal item that uses this
                keyword argument.
            size_per_item: For each multi-modal item, the size of the slice
                that is used to extract the data corresponding to it.
            dim: The dimension to slice, defaults to 0.

        Raises:
            ValueError: If `size_per_item` is not a 1-D tensor.

        Example:

        ```
        Given:
            size_per_item: [3, 4, 2]

        Input:
            Data: [AAABBBBCC]

        Output:
            Element 1: [AAA]
            Element 2: [BBBB]
            Element 3: [CC]
        ```

        ```
        Given:
            size_per_item: [3, 4, 2]
            dim: 1

        Input:
            Data: [[AAABBBBCC]]

        Output:
            Element 1: [[AAA]]
            Element 2: [[BBBB]]
            Element 3: [[CC]]
        ```

        Info:
            [`MultiModalFieldConfig.flat`][vllm.multimodal.inputs.MultiModalFieldConfig.flat]
        """

        if size_per_item.ndim != 1:
            raise ValueError("size_per_item should be a 1-D tensor, "
                             f"but found shape: {size_per_item.shape}")

        # Boundaries of consecutive slices: [0, s0, s0 + s1, ...]
        slice_idxs = [0, *accumulate(size_per_item)]
        # Take everything in the leading `dim` dimensions, then the item's
        # range in dimension `dim`.
        slices = [(slice(None), ) * dim + (slice(start, end), )
                  for start, end in zip(slice_idxs[:-1], slice_idxs[1:])]

        return MultiModalFieldConfig.flat(modality, slices, dim=dim)

    @staticmethod
    def shared(modality: str, batch_size: int):
        """
        Defines a field where an element in the batch is obtained by
        taking the entirety of the underlying data.

        This means that the data is the same for each element in the batch.

        Args:
            modality: The modality of the multi-modal item that uses this
                keyword argument.
            batch_size: The number of multi-modal items which share this data.

        Example:

        ```
        Given:
            batch_size: 4

        Input:
            Data: [XYZ]

        Output:
            Element 1: [XYZ]
            Element 2: [XYZ]
            Element 3: [XYZ]
            Element 4: [XYZ]
        ```
        """
        return MultiModalFieldConfig(
            field=MultiModalSharedField(batch_size),
            modality=modality,
        )

    def __init__(self, field: BaseMultiModalField, modality: str) -> None:
        """
        Args:
            field: Defines how the data is split into and merged from
                per-item elements.
            modality: The modality of the items that use this keyword
                argument.
        """
        super().__init__()

        self.field = field
        self.modality = modality

    def build_elems(
        self,
        key: str,
        batch: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        """Split `batch`, the data of keyword argument `key`, into per-item
        elements using this config's field and modality."""
        return self.field.build_elems(self.modality, key, batch)


class MultiModalKwargsItem(UserDict[str, MultiModalFieldElem]):
    """
    A collection of
    [`MultiModalFieldElem`][vllm.multimodal.inputs.MultiModalFieldElem]
    corresponding to a data item in
    [`MultiModalDataItems`][vllm.multimodal.parse.MultiModalDataItems].

    All elements of an item must share the same modality.
    """

    @staticmethod
    def dummy(modality: str):
        """Build a single-element item for `modality` (convenience method for
        testing)."""
        mm_elem = MultiModalFieldElem(
            modality=modality,
            key="dummy",
            data=torch.empty(1),
            field=MultiModalSharedField(1),
        )
        return MultiModalKwargsItem.from_elems([mm_elem])

    @staticmethod
    def from_elems(elems: Sequence[MultiModalFieldElem]):
        """
        Build an item that maps each element's `key` to the element.

        If several elements share a key, the last one is kept.
        """
        return MultiModalKwargsItem({elem.key: elem for elem in elems})

    def __init__(
        self,
        data: Optional[Mapping[str, MultiModalFieldElem]] = None,
    ) -> None:
        """
        Args:
            data: Mapping from keyword-argument name to element. `None`
                means an empty mapping.

        Raises:
            AssertionError: If the elements do not share exactly one
                modality, which includes the case of an empty mapping.
        """
        # `UserDict` copies the entries into its own dict (and accepts None).
        super().__init__(data)

        modalities = {elem.modality for elem in self.values()}
        assert len(modalities) == 1, f"Found different modalities={modalities}"

        self._modality = next(iter(modalities))

    @property
    def modality(self) -> str:
        """The modality shared by every element of this item."""
        return self._modality

    def get_data(self) -> dict[str, NestedTensors]:
        """Return the raw data of each element, keyed by keyword name."""
        return {key: elem.data for key, elem in self.items()}


# Item type of `MultiModalKwargsItems`: either always-present items, or items
# that may be `None`. Defaults to `MultiModalKwargsItem`.
_I = TypeVar(
    "_I",
    MultiModalKwargsItem,
    Optional[MultiModalKwargsItem],
    default=MultiModalKwargsItem,
)


class MultiModalKwargsItems(UserDict[str, Sequence[_I]]):
    """
    A dictionary of
    [`MultiModalKwargsItem`][vllm.multimodal.inputs.MultiModalKwargsItem]s
    by modality.
    """

    @staticmethod
    def from_hf_inputs(
        hf_inputs: "BatchFeature",
        config_by_key: Mapping[str, MultiModalFieldConfig],
    ):
        """
        Split the outputs of a HuggingFace processor into per-item
        [`MultiModalKwargsItem`][vllm.multimodal.inputs.MultiModalKwargsItem]s.

        Args:
            hf_inputs: The processor outputs, looked up by key.
            config_by_key: For each key to keep, how to split its data and
                which modality it belongs to.

        Raises:
            ValueError: If the keys of one modality produce different numbers
                of items.
        """
        # NOTE: This skips fields in `hf_inputs` that are not in `config_by_key`
        # We assume that those fields are not used in vLLM
        elems_by_key = dict[str, Sequence[MultiModalFieldElem]]()
        # A list (rather than a set) keeps keys in `config_by_key` order, so
        # the key order inside each item does not depend on string hash
        # randomization. Keys are unique because they come from a mapping.
        keys_by_modality = defaultdict[str, list[str]](list)
        for key, config in config_by_key.items():
            batch = hf_inputs.get(key)
            if batch is not None:
                elems = config.build_elems(key, batch)
                if len(elems) > 0:
                    elems_by_key[key] = elems
                    keys_by_modality[config.modality].append(key)

        items = list[MultiModalKwargsItem]()
        for modality, keys in keys_by_modality.items():
            elems_in_modality = {k: elems_by_key[k] for k in keys}
            batch_sizes = {k: len(v) for k, v in elems_in_modality.items()}

            if len(set(batch_sizes.values())) > 1:
                raise ValueError(
                    f"Cannot merge different batch sizes for {modality=}! "
                    f"Found: {batch_sizes=}")

            batch_size = next(iter(batch_sizes.values()))
            for item_idx in range(batch_size):
                elems = [v[item_idx] for v in elems_in_modality.values()]
                items.append(MultiModalKwargsItem.from_elems(elems))

        return MultiModalKwargsItems.from_seq(items)

    @staticmethod
    def from_seq(items: Sequence[MultiModalKwargsItem]):
        """Group `items` by their modality."""
        items_by_modality = full_groupby(items, key=lambda x: x.modality)
        return MultiModalKwargsItems(items_by_modality)

    def __getitem__(self, modality: str) -> Sequence[_I]:
        """
        Return the items of `modality`.

        Raises:
            KeyError: If the modality is absent; the message lists the
                available modalities.
        """
        if modality not in self:
            raise KeyError(f"Modality {modality!r} not found. "
                           f"Available modalities: {set(self.keys())}")

        return super().__getitem__(modality)  # type: ignore[return-value]

    def get_data(self, *, pin_memory: bool = False) -> "MultiModalKwargs":
        """
        Merge the elements of all items, per keyword argument, into a
        [`MultiModalKwargs`][vllm.multimodal.inputs.MultiModalKwargs].

        Args:
            pin_memory: Forwarded to each field's `reduce_data`.

        Raises:
            RuntimeError: If any item is `None`.
        """
        elems_by_key = defaultdict[str, list[MultiModalFieldElem]](list)
        for modality, items in self.items():
            for i, item in enumerate(items):
                if item is None:
                    raise RuntimeError("Cannot build data from empty "
                                       f"mm_items[{modality}][{i}]")

                for key, elem in item.items():
                    elems_by_key[key].append(elem)

        return MultiModalKwargs({
            key:
            elems[0].field.reduce_data(elems, pin_memory=pin_memory)
            for key, elems in elems_by_key.items()
        })

783

MultiModalKwargsOptionalItems: TypeAlias = Union[
    MultiModalKwargsItems[MultiModalKwargsItem],
    MultiModalKwargsItems[Optional[MultiModalKwargsItem]],
]
"""
[`MultiModalKwargsItems`][vllm.multimodal.inputs.MultiModalKwargsItems]
whose items are either all present, or may include `None` entries.
"""


class MultiModalKwargs(UserDict[str, NestedTensors]):
    """
    A dictionary that represents the keyword arguments to
    [`torch.nn.Module.forward`][].
    """

    @staticmethod
    @deprecated("`MultiModalKwargs.from_hf_inputs` is deprecated and "
                "will be removed in v0.13. "
                "Please use `MultiModalKwargsItems.from_hf_inputs` and "
                "access the tensor data using `.get_data()`.")
    def from_hf_inputs(
        hf_inputs: "BatchFeature",
        config_by_key: Mapping[str, MultiModalFieldConfig],
    ):
        return MultiModalKwargsItems.from_hf_inputs(hf_inputs, config_by_key) \
            .get_data()

    @staticmethod
    @deprecated("`MultiModalKwargs.from_items` is deprecated and "
                "will be removed in v0.13. "
                "Please use `MultiModalKwargsItems.from_seq` and "
                "access the tensor data using `.get_data()`.")
    def from_items(
        items: Sequence[MultiModalKwargsItem],
        *,
        pin_memory: bool = False,
    ):
        return MultiModalKwargsItems.from_seq(items) \
            .get_data(pin_memory=pin_memory)

    @staticmethod
    def _try_stack(nested_tensors: NestedTensors,
                   pin_memory: bool = False) -> NestedTensors:
        """
        Stack the inner dimensions that have the same shape in
        a nested list of tensors.

        Thus, a dimension represented by a list means that the inner
        dimensions are different for each element along that dimension.
        """
        if isinstance(nested_tensors, torch.Tensor):
            return nested_tensors

        # TODO: Remove these once all models have been migrated
        if isinstance(nested_tensors, np.ndarray):
            return torch.from_numpy(nested_tensors)
        if isinstance(nested_tensors, (int, float)):
            return torch.tensor(nested_tensors)

        stacked = [
            MultiModalKwargs._try_stack(t, pin_memory) for t in nested_tensors
        ]
        if not is_list_of(stacked, torch.Tensor, check="all"):
            # Only tensors (not lists) can be stacked.
            return stacked

        tensors_ = cast(list[torch.Tensor], stacked)
        if len(tensors_) == 1:
            # An optimization when `tensors_` contains only one tensor:
            # - produce exactly same result as `torch.stack(tensors_)`
            # - will achieve zero-copy if the tensor is contiguous
            return tensors_[0].unsqueeze(0).contiguous()

        if any(t.shape != tensors_[0].shape for t in tensors_):
            # The tensors have incompatible shapes and can't be stacked.
            return tensors_

        outputs = torch.empty(len(tensors_),
                              *tensors_[0].shape,
                              dtype=tensors_[0].dtype,
                              device=tensors_[0].device,
                              pin_memory=pin_memory)
        return torch.stack(tensors_, out=outputs)

    @staticmethod
    def batch(inputs_list: list["MultiModalKwargs"],
              pin_memory: bool = False) -> BatchedTensorInputs:
        """
        Batch multiple inputs together into a dictionary.

        The resulting dictionary has the same keys as the inputs.
        If the corresponding value from each input is a tensor and they all
        share the same shape, the output value is a single batched tensor;
        otherwise, the output value is a list containing the original value
        from each input.
        """
        if len(inputs_list) == 0:
            return {}

        # We need to consider the case where each item in the batch
        # contains different modalities (i.e. different keys).
        item_lists = defaultdict[str, list[NestedTensors]](list)

        for inputs in inputs_list:
            for k, v in inputs.items():
                item_lists[k].append(v)

        return {
            k: MultiModalKwargs._try_stack(item_list, pin_memory)
            for k, item_list in item_lists.items()
        }

    @staticmethod
    def as_kwargs(
        batched_inputs: BatchedTensorInputs,
        *,
        device: torch.types.Device,
    ) -> BatchedTensorInputs:
        """Copy every tensor leaf of `batched_inputs` to `device`, using
        `non_blocking=True`, keeping the nesting structure."""
        json_inputs = cast(JSONTree[torch.Tensor], batched_inputs)

        json_mapped = json_map_leaves(
            lambda x: x.to(device=device, non_blocking=True),
            json_inputs,
        )

        return cast(BatchedTensorInputs, json_mapped)

    def __getitem__(self, key: str):
        """
        Return the value of keyword argument `key`.

        Raises:
            KeyError: If the key is absent; the message lists the available
                keys.
        """
        if key not in self:
            raise KeyError(f"Keyword argument {key!r} not found. "
                           f"Available keys: {set(self.keys())}")

        return super().__getitem__(key)

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, self.__class__):
            return False
        # Both sides must hold exactly the same keyword arguments; checking
        # only `self`'s keys would ignore extra keys in `other` and make the
        # comparison asymmetric.
        if self.keys() != other.keys():
            return False

        return all(nested_tensors_equal(self[k], other[k]) for k in self)

927

MultiModalPlaceholderDict: TypeAlias = Mapping[str, Sequence[PlaceholderRange]]
"""
A dictionary containing placeholder ranges for each modality.
"""


class MultiModalInputs(TypedDict):
    """
    Represents the outputs of
    [`BaseMultiModalProcessor`][vllm.multimodal.processing.BaseMultiModalProcessor],
    ready to be passed to vLLM internals.
    """

    type: Literal["multimodal"]
    """The type of inputs."""

    prompt: str
    """The processed prompt text."""

    prompt_token_ids: list[int]
    """The processed token IDs, which include placeholder tokens."""

    mm_kwargs: MultiModalKwargsOptionalItems
    """Keyword arguments to be directly passed to the model after batching."""

    mm_hashes: "MultiModalHashes"
    """The hashes of the multi-modal data."""

    mm_placeholders: "MultiModalPlaceholderDict"
    """
    For each modality, information about the placeholder tokens in
    `prompt_token_ids`.
    """

    cache_salt: NotRequired[str]
    """
    Optional cache salt to be used for prefix caching.
    """

967
968
969

class MultiModalEncDecInputs(MultiModalInputs):
    """
    Represents the outputs of
    [`EncDecMultiModalProcessor`][vllm.multimodal.processing.EncDecMultiModalProcessor]
    ready to be passed to vLLM internals.

    Extends `MultiModalInputs` with the encoder-side prompt.
    """

    encoder_prompt: str
    """The processed encoder prompt text."""

    encoder_prompt_token_ids: list[int]
    """The processed token IDs of the encoder prompt."""