inputs.py 27.6 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
from abc import ABC, abstractmethod
4
from collections import UserDict, defaultdict
5
6
from collections.abc import Mapping, Sequence
from dataclasses import dataclass
7
from functools import partial
8
from itertools import accumulate
9
10
from typing import (TYPE_CHECKING, Any, Literal, Optional, TypedDict, TypeVar,
                    Union, cast, final)
11
12

import numpy as np
13
from typing_extensions import NotRequired, TypeAlias
14

15
from vllm.jsontree import JSONTree, json_map_leaves
16
from vllm.utils import LazyLoader, full_groupby, is_list_of
17

18
if TYPE_CHECKING:
19
20
21
22
23
    import torch
    import torch.types
    from PIL.Image import Image
    from transformers.feature_extraction_utils import BatchFeature

24
    from .hasher import MultiModalHashDict
25
26
else:
    torch = LazyLoader("torch", globals(), "torch")
27

28
29
_T = TypeVar("_T")

30
HfImageItem: TypeAlias = Union["Image", np.ndarray, "torch.Tensor"]
31
"""
32
A `transformers.image_utils.ImageInput` representing a single image
33
item, which can be passed to a HuggingFace `ImageProcessor`.
34
35
"""

36
37
HfVideoItem: TypeAlias = Union[list["Image"], np.ndarray, "torch.Tensor",
                               list[np.ndarray], list["torch.Tensor"]]
38
"""
39
A `transformers.image_utils.VideoInput` representing a single video
40
item, which can be passed to a HuggingFace `VideoProcessor`.
41
42
"""

43
HfAudioItem: TypeAlias = Union[list[float], np.ndarray, "torch.Tensor"]
44
"""
45
Represents a single audio
46
item, which can be passed to a HuggingFace `AudioProcessor`.
47
48
"""

49
ImageItem: TypeAlias = Union[HfImageItem, "torch.Tensor"]
50
"""
51
A `transformers.image_utils.ImageInput` representing a single image
52
item, which can be passed to a HuggingFace `ImageProcessor`.
53
54
55
56
57
58

Alternatively, a 3-D tensor or batch of 2-D tensors,
which are treated as image embeddings;
these are directly passed to the model without HF processing.
"""

59
VideoItem: TypeAlias = Union[HfVideoItem, "torch.Tensor"]
60
"""
61
A `transformers.image_utils.VideoInput` representing a single video
62
item, which can be passed to a HuggingFace `VideoProcessor`.
63
64
65
66
67
68
69

Alternatively, a 3-D tensor or batch of 2-D tensors,
which are treated as video embeddings;
these are directly passed to the model without HF processing.
"""

AudioItem: TypeAlias = Union[HfAudioItem, tuple[np.ndarray, float],
70
                             "torch.Tensor"]
71
72
"""
Represents a single audio
73
item, which can be passed to a HuggingFace `AudioProcessor`.
74
75
76
77
78
79
80
81
82
83
84

Alternatively, a tuple `(audio, sampling_rate)`, where the sampling rate
is different from that expected by the model;
these are resampled to the model's sampling rate before being processed by HF.

Alternatively, a 3-D tensor or batch of 2-D tensors,
which are treated as audio embeddings;
these are directly passed to the model without HF processing.
"""

ModalityData: TypeAlias = Union[_T, list[_T]]
85
86
87
88
"""
Either a single data item, or a list of data items.

The number of data items allowed per modality is restricted by
89
`--limit-mm-per-prompt`.
90
91
92
93
94
95
96
"""


@final
class MultiModalDataBuiltins(TypedDict, total=False):
    """
    Type annotations for modality types predefined by vLLM.

    Every key is optional (`total=False`), so a prompt may supply any
    subset of these modalities.
    """

    image: ModalityData[ImageItem]
    """The input image(s)."""

    video: ModalityData[VideoItem]
    """The input video(s)."""

    audio: ModalityData[AudioItem]
    """The input audio(s)."""


107
MultiModalDataDict: TypeAlias = Mapping[str, ModalityData[Any]]
"""
A dictionary containing an entry for each modality type to input.

The built-in modalities are defined by
[`MultiModalDataBuiltins`][vllm.multimodal.inputs.MultiModalDataBuiltins].
"""


116
117
@dataclass(frozen=True)
class PlaceholderRange:
    """
    Placeholder location information for multi-modal data.

    Example:

    Prompt: `AAAA BBBB What is in these images?`

    Images A and B will have:

    ```
    A: PlaceholderRange(offset=0, length=4)
    B: PlaceholderRange(offset=5, length=4)
    ```
    """

    offset: int
    """The start index of the placeholder in the prompt."""

    length: int
    """The length of the placeholder."""

    is_embed: Optional["torch.Tensor"] = None
    """
    A boolean mask of shape `(length,)` indicating which positions
    between `offset` and `offset + length` to assign embeddings to.
    """

    def get_num_embeds(self) -> int:
        """
        Return the number of positions that receive embeddings.

        This is `length` when no mask is set; otherwise it is the number of
        truthy entries in `is_embed`.
        """
        if self.is_embed is None:
            return self.length

        return int(self.is_embed.sum().item())

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, self.__class__):
            return False
        if (self.offset, self.length) != (other.offset, other.length):
            return False

        # The masks must either both be absent, or both be present and equal
        # (tensors cannot be compared with a plain `==`).
        if self.is_embed is None or other.is_embed is None:
            return self.is_embed is None and other.is_embed is None

        return nested_tensors_equal(self.is_embed, other.is_embed)

164

165
166
NestedTensors: TypeAlias = Union[list["NestedTensors"], list["torch.Tensor"],
                                 "torch.Tensor", tuple["torch.Tensor", ...]]
167
168
169
170
"""
Uses a list instead of a tensor if the dimensions of each element do not match.
"""

171
172

def nested_tensors_equal(a: "NestedTensors", b: "NestedTensors") -> bool:
    """
    Equality check between
    [`NestedTensors`][vllm.multimodal.inputs.NestedTensors] objects.

    Two values are equal when they have the same nesting structure (same
    container type and same length at every level) and every pair of
    corresponding tensors is equal according to `torch.equal`.
    Leaves that are neither tensors nor containers are compared with `==`.
    """
    if isinstance(a, torch.Tensor) or isinstance(b, torch.Tensor):
        return (isinstance(a, torch.Tensor) and isinstance(b, torch.Tensor)
                and torch.equal(a, b))

    # Tuples are handled like lists: comparing tuples of tensors with `==`
    # would call `bool()` on element-wise results and raise for tensors
    # with more than one element.
    for seq_type in (list, tuple):
        if isinstance(a, seq_type) or isinstance(b, seq_type):
            if not (isinstance(a, seq_type) and isinstance(b, seq_type)):
                return False
            # `zip` silently truncates, so the lengths must be compared
            # explicitly; otherwise `[x]` would equal `[x, y]`.
            return len(a) == len(b) and all(
                nested_tensors_equal(a_, b_) for a_, b_ in zip(a, b))

    # Both a and b are scalars
    return a == b


BatchedTensorInputs: TypeAlias = Mapping[str, NestedTensors]
"""
A dictionary containing nested tensors which have been batched via
[`MultiModalKwargs.batch`][vllm.multimodal.inputs.MultiModalKwargs.batch].
"""


198
@dataclass(frozen=True)
class MultiModalFieldElem:
    """
    Represents a keyword argument corresponding to a multi-modal item
    in [`MultiModalKwargs`][vllm.multimodal.inputs.MultiModalKwargs].
    """

    modality: str
    """
    The modality of the corresponding multi-modal item.
    Each multi-modal item can consist of multiple keyword arguments.
    """

    key: str
    """
    The key of this field in
    [`MultiModalKwargs`][vllm.multimodal.inputs.MultiModalKwargs],
    i.e. the name of the keyword argument to be passed to the model.
    """

    data: NestedTensors
    """
    The tensor data of this field in
    [`MultiModalKwargs`][vllm.multimodal.inputs.MultiModalKwargs],
    i.e. the value of the keyword argument to be passed to the model.
    """

    field: "BaseMultiModalField"
    """
    Defines how to combine the tensor data of this field with others
    in order to batch multi-modal items together for model inference.
    """

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, self.__class__):
            return False

        # `data` may contain tensors, so it is compared structurally.
        # `field` is compared by type only, not by its attribute values.
        return ((self.modality, self.key) == (other.modality, other.key)
                and nested_tensors_equal(self.data, other.data)
                and type(self.field) is type(other.field))


@dataclass(frozen=True)
class BaseMultiModalField(ABC):
    """
    Defines how to interpret tensor data belonging to a keyword argument in
    [`MultiModalKwargs`][vllm.multimodal.inputs.MultiModalKwargs] for multiple
    multi-modal items, and vice versa.
    """

    def _field_factory(self, *, modality: str, key: str):
        """
        Return a callable that wraps a piece of data into a
        [`MultiModalFieldElem`][vllm.multimodal.inputs.MultiModalFieldElem]
        bound to this field and to the given `modality` and `key`.
        """
        f = partial(
            MultiModalFieldElem,
            modality=modality,
            key=key,
            field=self,
        )

        # Allow passing data as positional argument
        def factory(data: NestedTensors) -> MultiModalFieldElem:
            return f(data=data)

        return factory

    @abstractmethod
    def build_elems(
        self,
        modality: str,
        key: str,
        data: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        """
        Construct
        [`MultiModalFieldElem`][vllm.multimodal.inputs.MultiModalFieldElem]
        instances to represent the provided data.

        This is the inverse of
        [`reduce_data`][vllm.multimodal.inputs.BaseMultiModalField.reduce_data].
        """
        raise NotImplementedError

    @abstractmethod
    def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
        """Merge the raw data of several elements; implemented by subclasses."""
        raise NotImplementedError

    def reduce_data(self, elems: list[MultiModalFieldElem]) -> NestedTensors:
        """
        Merge the data from multiple instances of
        [`MultiModalFieldElem`][vllm.multimodal.inputs.MultiModalFieldElem].

        This is the inverse of
        [`build_elems`][vllm.multimodal.inputs.BaseMultiModalField.build_elems].

        Raises:
            ValueError: If the elements were built by fields of different
                types.
        """
        field_types = [type(item.field) for item in elems]
        if len(set(field_types)) > 1:
            raise ValueError(f"Cannot merge different {field_types=}")

        return self._reduce_data([item.data for item in elems])


@dataclass(frozen=True)
class MultiModalBatchedField(BaseMultiModalField):
    """
    Info:
        [`MultiModalFieldConfig.batched`][vllm.multimodal.inputs.MultiModalFieldConfig.batched]
    """

    def build_elems(
        self,
        modality: str,
        key: str,
        data: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        field_factory = self._field_factory(modality=modality, key=key)
        # One element per entry of `data` (for a tensor, per index along
        # its first dimension).
        return [field_factory(item) for item in data]

    def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
        """
        Stack the data along a new leading dimension when every entry is a
        tensor of the same shape; otherwise return the list unchanged.
        """
        if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"):
            if len(batch) == 1:
                # An optimization when `batch` contains only one tensor:
                # - produce exactly same result as `torch.stack(batch)`
                # - will achieve zero-copy if the tensor is contiguous
                return batch[0].unsqueeze(0).contiguous()

            first_shape = batch[0].shape
            if all(elem.shape == first_shape for elem in batch):
                return torch.stack(batch)

        return batch


@dataclass(frozen=True)
class MultiModalFlatField(BaseMultiModalField):
    """
    Info:
        [`MultiModalFieldConfig.flat`][vllm.multimodal.inputs.MultiModalFieldConfig.flat]
        [`MultiModalFieldConfig.flat_from_sizes`][vllm.multimodal.inputs.MultiModalFieldConfig.flat_from_sizes]
    """

    # One entry per multi-modal item: either a single `slice`, or a sequence
    # of slices used to index `data` across several dimensions.
    slices: Union[Sequence[slice], Sequence[Sequence[slice]]]
    # The dimension along which per-item data is concatenated when reducing.
    dim: int = 0

    def build_elems(
        self,
        modality: str,
        key: str,
        data: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        field_factory = self._field_factory(modality=modality, key=key)
        if not is_list_of(self.slices, slice, check="all"):
            # Indexing with multiple slices at once is only supported by
            # tensors, not by (nested) Python lists.
            assert isinstance(data, torch.Tensor), \
                "torch.Tensor is required for multiple slices"
        return [field_factory(data[cast(slice, s)]) for s in self.slices]

    def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
        """
        Concatenate the data along `dim` when every entry is a tensor whose
        shape matches the others outside of `dim`; otherwise flatten the
        entries into a single list (only supported for `dim == 0`).
        """
        if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"):
            if len(batch) == 1:
                # An optimization when `batch` contains only one tensor:
                # - produce exactly same result as `torch.concat(batch)`
                # - will achieve zero-copy if the tensor is contiguous
                return batch[0].contiguous()

            def _expect_same_shape(tensor: torch.Tensor):
                # Shape with `dim` removed: these must agree to concatenate.
                return tensor.shape[:self.dim] + tensor.shape[self.dim + 1:]

            first_shape = _expect_same_shape(batch[0])

            if all(_expect_same_shape(elem) == first_shape for elem in batch):
                return torch.concat(batch, dim=self.dim)

        assert self.dim == 0, "dim == 0 is required for nested list"
        return [e for elem in batch for e in elem]


370
371
372
@dataclass(frozen=True)
class MultiModalSharedField(BaseMultiModalField):
    """
    Info:
        [`MultiModalFieldConfig.shared`][vllm.multimodal.inputs.MultiModalFieldConfig.shared]
    """

    # The number of multi-modal items which share the same data.
    batch_size: int

    def build_elems(
        self,
        modality: str,
        key: str,
        data: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        field_factory = self._field_factory(modality=modality, key=key)
        # Every item refers to the very same element object.
        return [field_factory(data)] * self.batch_size

    def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
        # Elements of a shared field are expected to carry identical data
        # (see `build_elems`), so only the first one is kept.
        return batch[0]


391
392
393
394
class MultiModalFieldConfig:
    """
    Describes how the value of a keyword argument is split into per-item
    [`MultiModalFieldElem`][vllm.multimodal.inputs.MultiModalFieldElem]
    instances for a given modality.

    Instances are created via the static constructors `batched`, `flat`,
    `flat_from_sizes` and `shared`.
    """

    @staticmethod
    def batched(modality: str):
        """
        Defines a field where an element in the batch is obtained by
        indexing into the first dimension of the underlying data.

        Args:
            modality: The modality of the multi-modal item that uses this
                keyword argument.

        Example:

        ```
        Input:
            Data: [[AAAA]
                   [BBBB]
                   [CCCC]]

        Output:
            Element 1: [AAAA]
            Element 2: [BBBB]
            Element 3: [CCCC]
        ```
        """
        return MultiModalFieldConfig(
            field=MultiModalBatchedField(),
            modality=modality,
        )

    @staticmethod
    def flat(modality: str,
             slices: Union[Sequence[slice], Sequence[Sequence[slice]]],
             dim: int = 0):
        """
        Defines a field where an element in the batch is obtained by
        slicing along dimension `dim` (the first dimension by default)
        of the underlying data.

        Args:
            modality: The modality of the multi-modal item that uses this
                keyword argument.
            slices: For each multi-modal item, a slice (dim=0) or a tuple of
                slices (dim>0) that is used to extract the data corresponding
                to it.
            dim: The dimension to extract data, default to 0.

        Example:

        ```
        Given:
            slices: [slice(0, 3), slice(3, 7), slice(7, 9)]

        Input:
            Data: [AAABBBBCC]

        Output:
            Element 1: [AAA]
            Element 2: [BBBB]
            Element 3: [CC]
        ```

        ```
        Given:
            slices: [
                (slice(None), slice(0, 3)),
                (slice(None), slice(3, 7)),
                (slice(None), slice(7, 9))]
            dim: 1

        Input:
            Data: [[AAABBBBCC]]

        Output:
            Element 1: [[AAA]]
            Element 2: [[BBBB]]
            Element 3: [[CC]]
        ```
        """
        return MultiModalFieldConfig(
            field=MultiModalFlatField(slices=slices, dim=dim),
            modality=modality,
        )

    @staticmethod
    def flat_from_sizes(modality: str,
                        size_per_item: "torch.Tensor",
                        dim: int = 0):
        """
        Defines a field where an element in the batch is obtained by
        slicing along dimension `dim` (the first dimension by default)
        of the underlying data.

        Args:
            modality: The modality of the multi-modal item that uses this
                keyword argument.
            size_per_item: For each multi-modal item, the size of the slice
                that is used to extract the data corresponding to it.
            dim: The dimension to slice, default to 0.

        Raises:
            ValueError: If `size_per_item` is not a 1-D tensor.

        Example:

        ```
        Given:
            size_per_item: [3, 4, 2]

        Input:
            Data: [AAABBBBCC]

        Output:
            Element 1: [AAA]
            Element 2: [BBBB]
            Element 3: [CC]
        ```

        ```
        Given:
            size_per_item: [3, 4, 2]
            dim: 1

        Input:
            Data: [[AAABBBBCC]]

        Output:
            Element 1: [[AAA]]
            Element 2: [[BBBB]]
            Element 3: [[CC]]
        ```

        Info:
            [`MultiModalFieldConfig.flat`][vllm.multimodal.inputs.MultiModalFieldConfig.flat]
        """

        if size_per_item.ndim != 1:
            raise ValueError("size_per_item should be a 1-D tensor, "
                             f"but found shape: {size_per_item.shape}")

        # Cumulative offsets: item i spans [slice_idxs[i], slice_idxs[i + 1]).
        slice_idxs = [0, *accumulate(size_per_item)]
        # Prefix `dim` full slices so that only dimension `dim` is restricted.
        slices = [(slice(None, None, None), ) * dim +
                  (slice(slice_idxs[i], slice_idxs[i + 1]), )
                  for i in range(len(size_per_item))]

        return MultiModalFieldConfig.flat(modality, slices, dim=dim)

    @staticmethod
    def shared(modality: str, batch_size: int):
        """
        Defines a field where an element in the batch is obtained by
        taking the entirety of the underlying data.

        This means that the data is the same for each element in the batch.

        Args:
            modality: The modality of the multi-modal item that uses this
                keyword argument.
            batch_size: The number of multi-modal items which share this data.

        Example:

        ```
        Given:
            batch_size: 4

        Input:
            Data: [XYZ]

        Output:
            Element 1: [XYZ]
            Element 2: [XYZ]
            Element 3: [XYZ]
            Element 4: [XYZ]
        ```
        """
        return MultiModalFieldConfig(
            field=MultiModalSharedField(batch_size),
            modality=modality,
        )

    def __init__(self, field: BaseMultiModalField, modality: str) -> None:
        super().__init__()

        self.field = field
        self.modality = modality

    def build_elems(
        self,
        key: str,
        batch: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        """
        Split `batch`, the value of keyword argument `key`, into per-item
        elements of this config's modality using its field.
        """
        return self.field.build_elems(self.modality, key, batch)


582
583
class MultiModalKwargsItem(UserDict[str, MultiModalFieldElem]):
    """
    A collection of
    [`MultiModalFieldElem`][vllm.multimodal.inputs.MultiModalFieldElem]
    corresponding to a data item in
    [`MultiModalDataItems`][vllm.multimodal.parse.MultiModalDataItems].
    """

    @staticmethod
    def from_elems(elems: Sequence[MultiModalFieldElem]):
        """
        Build an item keyed by each element's `key`.

        If several elements share the same key, the last one wins.
        """
        return MultiModalKwargsItem({elem.key: elem for elem in elems})

    @property
    def modality(self) -> str:
        """
        The modality shared by all elements of this item.

        Raises:
            AssertionError: If the elements have different modalities,
                or if the item contains no elements.
        """
        modalities = {elem.modality for elem in self.data.values()}
        assert len(modalities) == 1, f"Found different modalities={modalities}"
        return next(iter(modalities))


601
602
603
604
605
# NOTE: UserDict is for V0 compatibility.
# V1 should access individual items via `get_item`.
class MultiModalKwargs(UserDict[str, NestedTensors]):
    """
    A dictionary that represents the keyword arguments to
    [`torch.nn.Module.forward`][].

    The metadata `items` enables us to obtain the keyword arguments
    corresponding to each data item in
    [`MultiModalDataItems`][vllm.multimodal.parse.MultiModalDataItems], via
    [`get_item`][vllm.multimodal.inputs.MultiModalKwargs.get_item] and
    [`get_items`][vllm.multimodal.inputs.MultiModalKwargs.get_items].
    """

    @staticmethod
    def from_hf_inputs(
        hf_inputs: "BatchFeature",
        config_by_key: Mapping[str, MultiModalFieldConfig],
    ):
        """
        Construct a new
        [`MultiModalKwargs`][vllm.multimodal.inputs.MultiModalKwargs]
        from a `BatchFeature` (presumably the output of a HF processor).

        Args:
            hf_inputs: The batched keyword arguments to split per item.
            config_by_key: For each keyword argument, how to split its value
                into per-item elements.

        Raises:
            ValueError: If the keyword arguments belonging to one modality
                yield different numbers of items.
        """
        # NOTE: This skips fields in `hf_inputs` that are not in `config_by_key`
        # We assume that those fields are not used in vLLM
        elems_by_key = dict[str, Sequence[MultiModalFieldElem]]()
        keys_by_modality = defaultdict[str, set[str]](set)
        for key, config in config_by_key.items():
            batch = hf_inputs.get(key)
            if batch is not None:
                elems = config.build_elems(key, batch)
                if len(elems) > 0:
                    elems_by_key[key] = elems
                    keys_by_modality[config.modality].add(key)

        # Regroup the per-key elements into one item per multi-modal input.
        items = list[MultiModalKwargsItem]()
        for modality, keys in keys_by_modality.items():
            elems_in_modality = {k: elems_by_key[k] for k in keys}
            batch_sizes = {k: len(v) for k, v in elems_in_modality.items()}

            if len(set(batch_sizes.values())) > 1:
                raise ValueError(
                    f"Cannot merge different batch sizes for {modality=}! "
                    f"Found: {batch_sizes=}")

            batch_size = next(iter(batch_sizes.values()))
            for item_idx in range(batch_size):
                elems = [v[item_idx] for v in elems_in_modality.values()]
                items.append(MultiModalKwargsItem.from_elems(elems))

        return MultiModalKwargs.from_items(items)

    @staticmethod
    def from_items(items: Sequence[MultiModalKwargsItem]):
        """Construct a new
        [`MultiModalKwargs`][vllm.multimodal.inputs.MultiModalKwargs]
        from multiple items."""
        elems_by_key = defaultdict[str, list[MultiModalFieldElem]](list)
        for item in items:
            for key, elem in item.items():
                elems_by_key[key].append(elem)

        # The elements of each key are merged by the field of the first one.
        data = {
            key: elems[0].field.reduce_data(elems)
            for key, elems in elems_by_key.items() if len(elems) > 0
        }

        return MultiModalKwargs(data, items=items)

    def __init__(
        self,
        data: Mapping[str, NestedTensors],
        *,
        items: Optional[Sequence[MultiModalKwargsItem]] = None,
    ) -> None:
        super().__init__(data)

        items_by_modality = full_groupby(items or [], key=lambda x: x.modality)
        self._items_by_modality = dict(items_by_modality)

    @property
    def modalities(self):
        """The modalities for which per-item data was provided."""
        return self._items_by_modality.keys()

    @staticmethod
    def _try_stack(nested_tensors: NestedTensors) -> NestedTensors:
        """
        Stack the inner dimensions that have the same shape in
        a nested list of tensors.

        Thus, a dimension represented by a list means that the inner
        dimensions are different for each element along that dimension.
        """
        if isinstance(nested_tensors, torch.Tensor):
            return nested_tensors

        # TODO: Remove these once all models have been migrated
        if isinstance(nested_tensors, np.ndarray):
            return torch.from_numpy(nested_tensors)
        if isinstance(nested_tensors, (int, float)):
            return torch.tensor(nested_tensors)

        stacked = [MultiModalKwargs._try_stack(t) for t in nested_tensors]
        if not is_list_of(stacked, torch.Tensor, check="all"):
            # Only tensors (not lists) can be stacked.
            return stacked

        tensors_ = cast(list[torch.Tensor], stacked)
        if len(tensors_) == 1:
            # An optimization when `tensors_` contains only one tensor:
            # - produce exactly same result as `torch.stack(tensors_)`
            # - will achieve zero-copy if the tensor is contiguous
            return tensors_[0].unsqueeze(0).contiguous()

        if any(t.shape != tensors_[0].shape for t in tensors_):
            # The tensors have incompatible shapes and can't be stacked.
            return tensors_

        return torch.stack(tensors_)

    @staticmethod
    def batch(inputs_list: list["MultiModalKwargs"]) -> BatchedTensorInputs:
        """
        Batch multiple inputs together into a dictionary.

        The resulting dictionary has the same keys as the inputs.
        If the corresponding value from each input is a tensor and they all
        share the same shape, the output value is a single batched tensor;
        otherwise, the output value is a list containing the original value
        from each input.
        """
        if len(inputs_list) == 0:
            return {}

        # We need to consider the case where each item in the batch
        # contains different modalities (i.e. different keys).
        item_lists = defaultdict[str, list[NestedTensors]](list)

        for inputs in inputs_list:
            for k, v in inputs.items():
                item_lists[k].append(v)

        return {
            k: MultiModalKwargs._try_stack(item_list)
            for k, item_list in item_lists.items()
        }

    @staticmethod
    def as_kwargs(
        batched_inputs: BatchedTensorInputs,
        *,
        device: torch.types.Device,
        dtype: Optional[torch.dtype] = None,
    ) -> BatchedTensorInputs:
        """
        Move every tensor leaf of `batched_inputs` to `device`.

        Floating-point tensors are first cast to `dtype`; other tensors keep
        their dtype.
        """
        json_inputs = cast(JSONTree[torch.Tensor], batched_inputs)

        def maybe_cast_dtype(x: torch.Tensor):
            # This mimics the behavior of transformers.BatchFeature
            return x.to(dtype=dtype) if x.is_floating_point() else x

        json_mapped = json_map_leaves(
            # NOTE: Cast the dtype before sending it to device
            lambda x: maybe_cast_dtype(x).to(device=device, non_blocking=True),
            json_inputs,
        )

        return cast(BatchedTensorInputs, json_mapped)

    def __delitem__(self, key: str) -> None:
        """Remove `key` from the batched data and from every item."""
        super().__delitem__(key)

        for items in self._items_by_modality.values():
            for item in items:
                item.pop(key, None)

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, self.__class__):
            return False
        if self._items_by_modality != other._items_by_modality:
            return False

        # Values may contain tensors, so they are compared structurally.
        ks = self.keys()
        return (ks == other.keys()
                and all(nested_tensors_equal(self[k], other[k]) for k in ks))

    def _validate_modality(self, method_name: str, modality: str) -> None:
        """
        Ensure that per-item data for `modality` is available.

        Raises:
            RuntimeError: If this instance was created without `items`.
            KeyError: If no item belongs to `modality`.
        """
        if not self._items_by_modality:
            raise RuntimeError(
                f"`{method_name}` is not supported when "
                "MultiModalKwargs is not initialized with `items`")

        if modality not in self._items_by_modality:
            available_modalities = set(self._items_by_modality.keys())
            raise KeyError(f"Modality {modality!r} not found. "
                           f"Available modalities: {available_modalities}")

    def get_item_count(self, modality: str) -> int:
        """Get the number of items belonging to a modality."""
        self._validate_modality("get_item_count", modality)
        return len(self._items_by_modality[modality])

    def get_item(self, modality: str, item_index: int) -> MultiModalKwargsItem:
        """
        Get the keyword arguments corresponding to an item identified by
        its modality and index.
        """
        self._validate_modality("get_item", modality)
        return self._items_by_modality[modality][item_index]

    def get_items(self, modality: str) -> Sequence[MultiModalKwargsItem]:
        """
        Get the keyword arguments corresponding to each item belonging to
        a modality.
        """
        self._validate_modality("get_items", modality)
        return self._items_by_modality[modality]

814

815
MultiModalPlaceholderDict: TypeAlias = Mapping[str, Sequence[PlaceholderRange]]
"""
A dictionary containing placeholder ranges for each modality.
"""


821
class MultiModalInputs(TypedDict):
    """
    Represents the outputs of
    [`BaseMultiModalProcessor`][vllm.multimodal.processing.BaseMultiModalProcessor],
    ready to be passed to vLLM internals.
    """

    type: Literal["multimodal"]
    """The type of inputs."""

    prompt: str
    """The processed prompt text."""

    prompt_token_ids: list[int]
    """The processed token IDs, which include placeholder tokens."""

    token_type_ids: NotRequired[list[int]]
    """The token type IDs of the prompt."""

    mm_kwargs: MultiModalKwargs
    """Keyword arguments to be directly passed to the model after batching."""

    mm_hashes: Optional["MultiModalHashDict"]
    """The hashes of the multi-modal data."""

    mm_placeholders: "MultiModalPlaceholderDict"
    """
    For each modality, information about the placeholder tokens in
    `prompt_token_ids`.
    """

    cache_salt: NotRequired[str]
    """
    Optional cache salt to be used for prefix caching.
    """

857
858
859

class MultiModalEncDecInputs(MultiModalInputs):
    """
    Represents the outputs of
    [`EncDecMultiModalProcessor`][vllm.multimodal.processing.EncDecMultiModalProcessor]
    ready to be passed to vLLM internals.
    """

    encoder_prompt: str
    """The processed encoder prompt text."""

    encoder_prompt_token_ids: list[int]
    """The processed token IDs of the encoder prompt."""

    encoder_token_type_ids: NotRequired[list[int]]
    """The token type IDs of the encoder prompt."""