inputs.py 27.9 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from abc import ABC, abstractmethod
5
from collections import UserDict, defaultdict
6
7
from collections.abc import Mapping, Sequence
from dataclasses import dataclass
8
from functools import partial
9
from itertools import accumulate
10
11
from typing import (TYPE_CHECKING, Any, Literal, Optional, TypedDict, TypeVar,
                    Union, cast, final)
12
13

import numpy as np
14
from typing_extensions import NotRequired, TypeAlias
15

16
from vllm.jsontree import JSONTree, json_map_leaves
17
from vllm.utils import LazyLoader, full_groupby, is_list_of
18

19
if TYPE_CHECKING:
20
21
22
23
24
    import torch
    import torch.types
    from PIL.Image import Image
    from transformers.feature_extraction_utils import BatchFeature

25
    from .hasher import MultiModalHashDict
26
27
else:
    torch = LazyLoader("torch", globals(), "torch")
28

29
30
_T = TypeVar("_T")

31
HfImageItem: TypeAlias = Union["Image", np.ndarray, "torch.Tensor"]
32
"""
33
A `transformers.image_utils.ImageInput` representing a single image
34
item, which can be passed to a HuggingFace `ImageProcessor`.
35
36
"""

37
38
HfVideoItem: TypeAlias = Union[list["Image"], np.ndarray, "torch.Tensor",
                               list[np.ndarray], list["torch.Tensor"]]
39
"""
40
A `transformers.image_utils.VideoInput` representing a single video
41
item, which can be passed to a HuggingFace `VideoProcessor`.
42
43
"""

44
HfAudioItem: TypeAlias = Union[list[float], np.ndarray, "torch.Tensor"]
45
"""
46
Represents a single audio
47
item, which can be passed to a HuggingFace `AudioProcessor`.
48
49
"""

50
ImageItem: TypeAlias = Union[HfImageItem, "torch.Tensor"]
51
"""
52
A `transformers.image_utils.ImageInput` representing a single image
53
item, which can be passed to a HuggingFace `ImageProcessor`.
54
55
56
57
58
59

Alternatively, a 3-D tensor or batch of 2-D tensors,
which are treated as image embeddings;
these are directly passed to the model without HF processing.
"""

60
61
VideoItem: TypeAlias = Union[HfVideoItem, "torch.Tensor",
                             tuple[HfVideoItem, dict[str, Any]]]
62
"""
63
64
65
A `transformers.video_utils.VideoInput` representing a single video item. 
This can be passed to a HuggingFace `VideoProcessor` 
with `transformers.video_utils.VideoMetadata`.
66
67
68
69
70
71
72

Alternatively, a 3-D tensor or batch of 2-D tensors,
which are treated as video embeddings;
these are directly passed to the model without HF processing.
"""

AudioItem: TypeAlias = Union[HfAudioItem, tuple[np.ndarray, float],
73
                             "torch.Tensor"]
74
75
"""
Represents a single audio
76
item, which can be passed to a HuggingFace `AudioProcessor`.
77
78
79
80
81
82
83
84
85
86
87

Alternatively, a tuple `(audio, sampling_rate)`, where the sampling rate
is different from that expected by the model;
these are resampled to the model's sampling rate before being processed by HF.

Alternatively, a 3-D tensor or batch of 2-D tensors,
which are treated as audio embeddings;
these are directly passed to the model without HF processing.
"""

ModalityData: TypeAlias = Union[_T, list[_T]]
88
89
90
91
"""
Either a single data item, or a list of data items.

The number of data items allowed per modality is restricted by
92
`--limit-mm-per-prompt`.
93
94
95
96
97
98
99
"""


@final
class MultiModalDataBuiltins(TypedDict, total=False):
    """
    Type hints for the modalities that vLLM knows about out of the box.

    Every key is optional (`total=False`), and each value may be a single
    item or a list of items.
    """

    image: ModalityData[ImageItem]
    """One or more input images."""

    video: ModalityData[VideoItem]
    """One or more input videos."""

    audio: ModalityData[AudioItem]
    """One or more input audio clips."""


110
MultiModalDataDict: TypeAlias = Mapping[str, ModalityData[Any]]
"""
Maps each input modality name to its data, which is either a single item
or a list of items.

The modality names and item types built into vLLM are listed in
[`MultiModalDataBuiltins`][vllm.multimodal.inputs.MultiModalDataBuiltins].
"""


119
120
@dataclass(frozen=True)
class PlaceholderRange:
    """
    Placeholder location information for multi-modal data.

    Example:

    Prompt: `AAAA BBBB What is in these images?`

    Images A and B will have:

    ```
    A: PlaceholderRange(offset=0, length=4)
    B: PlaceholderRange(offset=5, length=4)
    ```
    """

    offset: int
    """The start index of the placeholder in the prompt."""

    length: int
    """The length of the placeholder."""

    is_embed: Optional["torch.Tensor"] = None
    """
    A boolean mask of shape `(length,)` indicating which positions
    between `offset` and `offset + length` to assign embeddings to.
    """

    def get_num_embeds(self) -> int:
        """
        Return how many positions in this range receive embeddings.

        Without a mask, all `length` positions count; otherwise the sum of
        the `is_embed` mask is returned.
        """
        if self.is_embed is None:
            return self.length

        return int(self.is_embed.sum().item())

    def __eq__(self, other: object) -> bool:
        """
        Compare offset and length exactly, and the `is_embed` masks by value
        (via `nested_tensors_equal`).
        """
        if not isinstance(other, self.__class__):
            return False
        if (self.offset, self.length) != (other.offset, other.length):
            return False

        if self.is_embed is None or other.is_embed is None:
            # Equal only when neither range carries a mask.
            return self.is_embed is None and other.is_embed is None

        return nested_tensors_equal(self.is_embed, other.is_embed)

    def __hash__(self) -> int:
        # Without this, the frozen dataclass would generate a hash over all
        # fields, including the `is_embed` tensor, which hashes by identity.
        # Two ranges that `__eq__` treats as equal (same mask values held in
        # different tensor objects) would then hash differently. Hash only
        # the fields that `__eq__` compares exactly.
        return hash((self.offset, self.length))

167

168
169
NestedTensors: TypeAlias = Union[list["NestedTensors"], list["torch.Tensor"],
                                 "torch.Tensor", tuple["torch.Tensor", ...]]
170
171
172
173
"""
Uses a list instead of a tensor if the dimensions of each element do not match.
"""

174
175

def nested_tensors_equal(a: NestedTensors, b: NestedTensors) -> bool:
176
177
    """Equality check between
    [`NestedTensors`][vllm.multimodal.inputs.NestedTensors] objects."""
178
    if isinstance(a, torch.Tensor):
179
        return isinstance(b, torch.Tensor) and torch.equal(a, b)
180
    elif isinstance(b, torch.Tensor):
181
        return isinstance(a, torch.Tensor) and torch.equal(b, a)
182
183
184
185
186
187
188
189
190
191
192
193
194

    if isinstance(a, list):
        return (isinstance(b, list)
                and all(nested_tensors_equal(a_, b_) for a_, b_ in zip(a, b)))
    if isinstance(b, list):
        return (isinstance(a, list)
                and all(nested_tensors_equal(b_, a_) for b_, a_ in zip(b, a)))

    # Both a and b are scalars
    return a == b


BatchedTensorInputs: TypeAlias = Mapping[str, NestedTensors]
"""
Maps each keyword-argument name to the nested tensors produced by
[`MultiModalKwargs.batch`][vllm.multimodal.inputs.MultiModalKwargs.batch].
"""


201
@dataclass(frozen=True)
class MultiModalFieldElem:
    """
    One keyword argument belonging to a single multi-modal item in
    [`MultiModalKwargs`][vllm.multimodal.inputs.MultiModalKwargs].
    """

    modality: str
    """
    Modality of the multi-modal item this element belongs to.
    A single item may consist of several keyword arguments.
    """

    key: str
    """
    Name of this field in
    [`MultiModalKwargs`][vllm.multimodal.inputs.MultiModalKwargs],
    i.e. the keyword argument name passed to the model.
    """

    data: NestedTensors
    """
    Tensor data of this field in
    [`MultiModalKwargs`][vllm.multimodal.inputs.MultiModalKwargs],
    i.e. the keyword argument value passed to the model.
    """

    field: "BaseMultiModalField"
    """
    Describes how this element's tensor data is combined with that of other
    elements when multi-modal items are batched for model inference.
    """

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, self.__class__):
            return False

        # Compare the string identifiers before the (possibly large) data.
        if self.modality != other.modality or self.key != other.key:
            return False
        if not nested_tensors_equal(self.data, other.data):
            return False

        # Only the kind of field matters here, not its parameters.
        return type(self.field) == type(other.field)  # noqa: E721
241
242
243
244


@dataclass(frozen=True)
class BaseMultiModalField(ABC):
    """
    Describes how the tensor data of one keyword argument in
    [`MultiModalKwargs`][vllm.multimodal.inputs.MultiModalKwargs] is split
    into per-item pieces, and how those pieces are merged back together.
    """

    def _field_factory(self, *, modality: str, key: str):
        """
        Return a callable that wraps a piece of data in a
        `MultiModalFieldElem` bound to `modality`, `key` and this field.
        The data is taken as the callable's only positional argument.
        """

        def factory(data: NestedTensors) -> MultiModalFieldElem:
            return MultiModalFieldElem(
                modality=modality,
                key=key,
                data=data,
                field=self,
            )

        return factory

    @abstractmethod
    def build_elems(
        self,
        modality: str,
        key: str,
        data: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        """
        Split `data` into
        [`MultiModalFieldElem`][vllm.multimodal.inputs.MultiModalFieldElem]
        instances, one per multi-modal item.

        Inverse operation:
        [`reduce_data`][vllm.multimodal.inputs.BaseMultiModalField.reduce_data].
        """
        raise NotImplementedError

    @abstractmethod
    def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
        raise NotImplementedError

    def reduce_data(self, elems: list[MultiModalFieldElem]) -> NestedTensors:
        """
        Combine the data held by several
        [`MultiModalFieldElem`][vllm.multimodal.inputs.MultiModalFieldElem]
        instances into a single value.

        Inverse operation:
        [`build_elems`][vllm.multimodal.inputs.BaseMultiModalField.build_elems].

        Raises:
            ValueError: If the elements come from different field types.
        """
        field_types = [type(elem.field) for elem in elems]
        if len(set(field_types)) > 1:
            raise ValueError(f"Cannot merge different {field_types=}")

        return self._reduce_data([elem.data for elem in elems])
299
300
301
302
303


@dataclass(frozen=True)
class MultiModalBatchedField(BaseMultiModalField):
    """
    Info:
        [`MultiModalFieldConfig.batched`][vllm.multimodal.inputs.MultiModalFieldConfig.batched]
    """

    def build_elems(
        self,
        modality: str,
        key: str,
        data: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        make_elem = self._field_factory(modality=modality, key=key)
        # One element per entry along the outermost dimension of `data`.
        return list(map(make_elem, data))

    def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
        if not batch or not is_list_of(batch, torch.Tensor, check="all"):
            # Empty, or contains non-tensors: keep as a plain list.
            return batch

        if len(batch) == 1:
            # Same result as `torch.stack(batch)`, but zero-copy when the
            # single tensor is already contiguous.
            return batch[0].unsqueeze(0).contiguous()

        reference_shape = batch[0].shape
        if any(t.shape != reference_shape for t in batch):
            # Ragged shapes cannot be stacked into one tensor.
            return batch

        return torch.stack(batch)


@dataclass(frozen=True)
class MultiModalFlatField(BaseMultiModalField):
    """
    Info:
        [`MultiModalFieldConfig.flat`][vllm.multimodal.inputs.MultiModalFieldConfig.flat]
        [`MultiModalFieldConfig.flat_from_sizes`][vllm.multimodal.inputs.MultiModalFieldConfig.flat_from_sizes]
    """
    slices: Union[Sequence[slice], Sequence[Sequence[slice]]]
    dim: int = 0

    def build_elems(
        self,
        modality: str,
        key: str,
        data: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        make_elem = self._field_factory(modality=modality, key=key)

        uses_plain_slices = is_list_of(self.slices, slice, check="all")
        if not uses_plain_slices:
            # Tuples of slices index several dimensions at once, so the
            # data has to be a tensor.
            assert isinstance(data, torch.Tensor), \
                "torch.Tensor is required for multiple slices"

        return [make_elem(data[cast(slice, s)]) for s in self.slices]

    def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
        if batch and is_list_of(batch, torch.Tensor, check="all"):
            if len(batch) == 1:
                # Same result as `torch.concat(batch)`, but zero-copy when
                # the single tensor is already contiguous.
                return batch[0].contiguous()

            concat_dim = self.dim

            def _shape_without_dim(tensor: torch.Tensor):
                # Concatenation only needs the other dimensions to agree.
                return (tensor.shape[:concat_dim] +
                        tensor.shape[concat_dim + 1:])

            reference_shape = _shape_without_dim(batch[0])
            if all(
                    _shape_without_dim(t) == reference_shape
                    for t in batch[1:]):
                return torch.concat(batch, dim=concat_dim)

        assert self.dim == 0, "dim == 0 is required for nested list"
        return [piece for part in batch for piece in part]
371
372


373
374
375
@dataclass(frozen=True)
class MultiModalSharedField(BaseMultiModalField):
    """
    Info:
        [`MultiModalFieldConfig.shared`][vllm.multimodal.inputs.MultiModalFieldConfig.shared]
    """
    batch_size: int

    def build_elems(
        self,
        modality: str,
        key: str,
        data: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        shared_elem = self._field_factory(modality=modality, key=key)(data)
        # Every item refers to the very same element object.
        return [shared_elem for _ in range(self.batch_size)]

    def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
        # All items were built from the same data, so any one of them
        # represents the whole batch.
        return batch[0]


394
395
396
397
class MultiModalFieldConfig:
    """
    Pairs a [`BaseMultiModalField`][vllm.multimodal.inputs.BaseMultiModalField]
    with the modality of the multi-modal items that use it.
    """

    @staticmethod
    def batched(modality: str):
        """
        Defines a field where an element in the batch is obtained by
        indexing into the first dimension of the underlying data.

        Args:
            modality: The modality of the multi-modal item that uses this
                keyword argument.

        Example:

        ```
        Input:
            Data: [[AAAA]
                [BBBB]
                [CCCC]]

        Output:
            Element 1: [AAAA]
            Element 2: [BBBB]
            Element 3: [CCCC]
        ```
        """
        return MultiModalFieldConfig(
            field=MultiModalBatchedField(),
            modality=modality,
        )

    @staticmethod
    def flat(modality: str,
             slices: Union[Sequence[slice], Sequence[Sequence[slice]]],
             dim: int = 0):
        """
        Defines a field where an element in the batch is obtained by
        slicing the underlying data along dimension `dim` (the first
        dimension by default).

        Args:
            modality: The modality of the multi-modal item that uses this
                keyword argument.
            slices: For each multi-modal item, a slice (dim=0) or a tuple of
                slices (dim>0) that is used to extract the data corresponding
                to it.
            dim: The dimension to extract data, default to 0.

        Example:

        ```
        Given:
            slices: [slice(0, 3), slice(3, 7), slice(7, 9)]

        Input:
            Data: [AAABBBBCC]

        Output:
            Element 1: [AAA]
            Element 2: [BBBB]
            Element 3: [CC]
        ```

        ```
        Given:
            slices: [
                (slice(None), slice(0, 3)),
                (slice(None), slice(3, 7)),
                (slice(None), slice(7, 9))]
            dim: 1

        Input:
            Data: [[AAABBBBCC],
                   [aaabbbbcc]]

        Output:
            Element 1: [[AAA], [aaa]]
            Element 2: [[BBBB], [bbbb]]
            Element 3: [[CC], [cc]]
        ```
        """
        return MultiModalFieldConfig(
            field=MultiModalFlatField(slices=slices, dim=dim),
            modality=modality,
        )

    @staticmethod
    def flat_from_sizes(modality: str,
                        size_per_item: "torch.Tensor",
                        dim: int = 0):
        """
        Defines a field where an element in the batch is obtained by
        slicing the underlying data along dimension `dim` (the first
        dimension by default), using consecutive slices whose lengths are
        given by `size_per_item`.

        Args:
            modality: The modality of the multi-modal item that uses this
                keyword argument.
            size_per_item: A 1-D tensor holding, for each multi-modal item,
                the size of the slice that is used to extract the data
                corresponding to it.
            dim: The dimension to slice, default to 0.

        Raises:
            ValueError: If `size_per_item` is not a 1-D tensor.

        Example:

        ```
        Given:
            size_per_item: [3, 4, 2]

        Input:
            Data: [AAABBBBCC]

        Output:
            Element 1: [AAA]
            Element 2: [BBBB]
            Element 3: [CC]
        ```

        ```
        Given:
            size_per_item: [3, 4, 2]
            dim: 1

        Input:
            Data: [[AAABBBBCC],
                   [aaabbbbcc]]

        Output:
            Element 1: [[AAA], [aaa]]
            Element 2: [[BBBB], [bbbb]]
            Element 3: [[CC], [cc]]
        ```

        Info:
            [`MultiModalFieldConfig.flat`][vllm.multimodal.inputs.MultiModalFieldConfig.flat]
        """

        if size_per_item.ndim != 1:
            raise ValueError("size_per_item should be a 1-D tensor, "
                             f"but found shape: {size_per_item.shape}")

        # Running offsets [0, s0, s0 + s1, ...]; each consecutive pair
        # bounds the slice of one item.
        slice_idxs = [0, *accumulate(size_per_item)]
        # Select everything along the dimensions in front of `dim`.
        leading = (slice(None, None, None), ) * dim
        slices = [
            leading + (slice(start, stop), )
            for start, stop in zip(slice_idxs[:-1], slice_idxs[1:])
        ]

        return MultiModalFieldConfig.flat(modality, slices, dim=dim)

    @staticmethod
    def shared(modality: str, batch_size: int):
        """
        Defines a field where an element in the batch is obtained by
        taking the entirety of the underlying data.

        This means that the data is the same for each element in the batch.

        Args:
            modality: The modality of the multi-modal item that uses this
                keyword argument.
            batch_size: The number of multi-modal items which share this data.

        Example:

        ```
        Given:
            batch_size: 4

        Input:
            Data: [XYZ]

        Output:
            Element 1: [XYZ]
            Element 2: [XYZ]
            Element 3: [XYZ]
            Element 4: [XYZ]
        ```
        """
        return MultiModalFieldConfig(
            field=MultiModalSharedField(batch_size),
            modality=modality,
        )

    def __init__(self, field: BaseMultiModalField, modality: str) -> None:
        super().__init__()

        self.field = field
        self.modality = modality

    def build_elems(
        self,
        key: str,
        batch: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        """Split `batch` (the value stored under `key`) into per-item
        elements using this config's field and modality."""
        return self.field.build_elems(self.modality, key, batch)
583
584


585
586
class MultiModalKwargsItem(UserDict[str, MultiModalFieldElem]):
    """
    The [`MultiModalFieldElem`][vllm.multimodal.inputs.MultiModalFieldElem]
    instances, keyed by field name, that belong to one data item in
    [`MultiModalDataItems`][vllm.multimodal.parse.MultiModalDataItems].
    """

    @staticmethod
    def from_elems(elems: Sequence[MultiModalFieldElem]):
        """Build an item keyed by each element's `key`; when keys repeat,
        the last element wins."""
        by_key: dict[str, MultiModalFieldElem] = {}
        for elem in elems:
            by_key[elem.key] = elem
        return MultiModalKwargsItem(by_key)

    @property
    def modality(self) -> str:
        """The single modality shared by all elements of this item."""
        modalities = {field_elem.modality for field_elem in self.data.values()}
        assert len(modalities) == 1, f"Found different modalities={modalities}"
        return next(iter(modalities))
602
603


604
605
606
607
608
# NOTE: UserDict is for V0 compatibility.
# V1 should access individual items via `get_item`.
class MultiModalKwargs(UserDict[str, NestedTensors]):
    """
    A dictionary that represents the keyword arguments to
    [`torch.nn.Module.forward`][].

    The metadata `items` enables us to obtain the keyword arguments
    corresponding to each data item in
    [`MultiModalDataItems`][vllm.multimodal.parse.MultiModalDataItems], via
    [`get_item`][vllm.multimodal.inputs.MultiModalKwargs.get_item] and
    [`get_items`][vllm.multimodal.inputs.MultiModalKwargs.get_items].
    """

    @staticmethod
    def from_hf_inputs(
        hf_inputs: "BatchFeature",
        config_by_key: Mapping[str, MultiModalFieldConfig],
    ):
        """
        Build keyword arguments from the output of a HuggingFace processor.

        Every key of `config_by_key` that has a non-`None` value in
        `hf_inputs` is split into per-item elements using its config. The
        elements are then regrouped into one `MultiModalKwargsItem` per data
        item of each modality.

        Raises:
            ValueError: If keys of the same modality produce different
                numbers of elements.
        """
        # NOTE: This skips fields in `hf_inputs` that are not in `config_by_key`
        # We assume that those fields are not used in vLLM
        elems_by_key = dict[str, Sequence[MultiModalFieldElem]]()
        keys_by_modality = defaultdict[str, set[str]](set)
        for key, config in config_by_key.items():
            batch = hf_inputs.get(key)
            if batch is not None:
                elems = config.build_elems(key, batch)
                if len(elems) > 0:
                    elems_by_key[key] = elems
                    keys_by_modality[config.modality].add(key)

        items = list[MultiModalKwargsItem]()
        for modality, keys in keys_by_modality.items():
            elems_in_modality = {k: elems_by_key[k] for k in keys}
            batch_sizes = {k: len(v) for k, v in elems_in_modality.items()}

            if len(set(batch_sizes.values())) > 1:
                raise ValueError(
                    f"Cannot merge different batch sizes for {modality=}! "
                    f"Found: {batch_sizes=}")

            batch_size = next(iter(batch_sizes.values()))
            for item_idx in range(batch_size):
                elems = [v[item_idx] for v in elems_in_modality.values()]
                items.append(MultiModalKwargsItem.from_elems(elems))

        return MultiModalKwargs.from_items(items)

    @staticmethod
    def from_items(items: Sequence[MultiModalKwargsItem]):
        """Construct a new
        [`MultiModalKwargs`][vllm.multimodal.inputs.MultiModalKwargs]
        from multiple items."""
        elems_by_key = defaultdict[str, list[MultiModalFieldElem]](list)
        for item in items:
            for key, elem in item.items():
                elems_by_key[key].append(elem)

        # Each key is merged back into one value by the field of its
        # first element.
        data = {
            key: elems[0].field.reduce_data(elems)
            for key, elems in elems_by_key.items() if len(elems) > 0
        }

        return MultiModalKwargs(data, items=items)

    def __init__(
        self,
        data: Mapping[str, NestedTensors],
        *,
        items: Optional[Sequence[MultiModalKwargsItem]] = None,
    ) -> None:
        super().__init__(data)

        items_by_modality = full_groupby(items or [], key=lambda x: x.modality)
        self._items_by_modality = dict(items_by_modality)

    @property
    def modalities(self):
        """The modalities of the items this instance was built from."""
        return self._items_by_modality.keys()

    @staticmethod
    def _try_stack(nested_tensors: NestedTensors,
                   pin_memory: bool = False) -> NestedTensors:
        """
        Stack the inner dimensions that have the same shape in
        a nested list of tensors.

        Thus, a dimension represented by a list means that the inner
        dimensions are different for each element along that dimension.

        NumPy arrays and Python scalars are converted to tensors. An empty
        list is returned unchanged, since it holds nothing to stack.
        """
        if isinstance(nested_tensors, torch.Tensor):
            return nested_tensors

        # TODO: Remove these once all models have been migrated
        if isinstance(nested_tensors, np.ndarray):
            return torch.from_numpy(nested_tensors)
        if isinstance(nested_tensors, (int, float)):
            return torch.tensor(nested_tensors)

        stacked = [
            MultiModalKwargs._try_stack(t, pin_memory) for t in nested_tensors
        ]
        if not is_list_of(stacked, torch.Tensor, check="all"):
            # Only tensors (not lists) can be stacked.
            return stacked

        tensors_ = cast(list[torch.Tensor], stacked)
        if len(tensors_) == 0:
            # The code below reads shape/dtype/device from `tensors_[0]`,
            # which would raise IndexError for an empty list.
            return tensors_

        if len(tensors_) == 1:
            # An optimization when `tensors_` contains only one tensor:
            # - produce exactly same result as `torch.stack(tensors_)`
            # - will achieve zero-copy if the tensor is contiguous
            return tensors_[0].unsqueeze(0).contiguous()

        if any(t.shape != tensors_[0].shape for t in tensors_):
            # The tensors have incompatible shapes and can't be stacked.
            return tensors_

        outputs = torch.empty(len(tensors_),
                              *tensors_[0].shape,
                              dtype=tensors_[0].dtype,
                              device=tensors_[0].device,
                              pin_memory=pin_memory)
        return torch.stack(tensors_, out=outputs)

    @staticmethod
    def batch(inputs_list: list["MultiModalKwargs"],
              pin_memory: bool = False) -> BatchedTensorInputs:
        """
        Batch multiple inputs together into a dictionary.

        The resulting dictionary has the same keys as the inputs.
        If the corresponding value from each input is a tensor and they all
        share the same shape, the output value is a single batched tensor;
        otherwise, the output value is a list containing the original value
        from each input.
        """
        if len(inputs_list) == 0:
            return {}

        # We need to consider the case where each item in the batch
        # contains different modalities (i.e. different keys).
        item_lists = defaultdict[str, list[NestedTensors]](list)

        for inputs in inputs_list:
            for k, v in inputs.items():
                item_lists[k].append(v)

        return {
            k: MultiModalKwargs._try_stack(item_list, pin_memory)
            for k, item_list in item_lists.items()
        }

    @staticmethod
    def as_kwargs(
        batched_inputs: BatchedTensorInputs,
        *,
        device: torch.types.Device,
    ) -> BatchedTensorInputs:
        """
        Apply `.to(device=device, non_blocking=True)` to every leaf of
        `batched_inputs` via `json_map_leaves`.
        """
        json_inputs = cast(JSONTree[torch.Tensor], batched_inputs)

        json_mapped = json_map_leaves(
            lambda x: x.to(device=device, non_blocking=True),
            json_inputs,
        )

        return cast(BatchedTensorInputs, json_mapped)

    def __delitem__(self, key: str) -> None:
        """Remove `key` from the data and from every per-item view."""
        super().__delitem__(key)

        for items in self._items_by_modality.values():
            for item in items:
                item.pop(key, None)

    def __eq__(self, other: object) -> bool:
        """Compare the per-item metadata, then every value by content."""
        if not isinstance(other, self.__class__):
            return False
        if self._items_by_modality != other._items_by_modality:
            return False

        ks = self.keys()
        return (ks == other.keys()
                and all(nested_tensors_equal(self[k], other[k]) for k in ks))

    def _validate_modality(self, method_name: str, modality: str) -> None:
        """
        Raises:
            RuntimeError: If this instance was created without `items`.
            KeyError: If no item has the given `modality`.
        """
        if not self._items_by_modality:
            raise RuntimeError(
                f"`{method_name}` is not supported when "
                "MultiModalKwargs is not initialized with `items`")

        if modality not in self._items_by_modality:
            available_modalities = set(self._items_by_modality.keys())
            raise KeyError(f"Modality {modality!r} not found. "
                           f"Available modalities: {available_modalities}")

    def get_item_count(self, modality: str) -> int:
        """Get the number of items belonging to a modality."""
        self._validate_modality("get_item_count", modality)
        return len(self._items_by_modality[modality])

    def get_item(self, modality: str, item_index: int) -> MultiModalKwargsItem:
        """
        Get the keyword arguments corresponding to an item identified by
        its modality and index.
        """
        self._validate_modality("get_item", modality)
        return self._items_by_modality[modality][item_index]

    def get_items(self, modality: str) -> Sequence[MultiModalKwargsItem]:
        """
        Get the keyword arguments corresponding to each item belonging to
        a modality.
        """
        self._validate_modality("get_items", modality)
        return self._items_by_modality[modality]
819

820

821
MultiModalPlaceholderDict: TypeAlias = Mapping[str, Sequence[PlaceholderRange]]
"""
Maps each modality name to the placeholder ranges of its items.
"""


827
class MultiModalInputs(TypedDict):
    """
    Output of
    [`BaseMultiModalProcessor`][vllm.multimodal.processing.BaseMultiModalProcessor],
    in the form handed to vLLM internals.
    """

    type: Literal["multimodal"]
    """Kind of inputs; always `"multimodal"`."""

    prompt: str
    """Prompt text after processing."""

    prompt_token_ids: list[int]
    """Token IDs after processing, placeholder tokens included."""

    token_type_ids: NotRequired[list[int]]
    """Token type IDs of the prompt."""

    mm_kwargs: MultiModalKwargs
    """Keyword arguments passed directly to the model once batched."""

    mm_hashes: Optional["MultiModalHashDict"]
    """Hashes of the multi-modal data; may be `None`."""

    mm_placeholders: "MultiModalPlaceholderDict"
    """
    Per modality, where the placeholder tokens sit in `prompt_token_ids`.
    """

    cache_salt: NotRequired[str]
    """
    Optional salt used for prefix caching.
    """

863
864
865

class MultiModalEncDecInputs(MultiModalInputs):
    """
    Output of
    [`EncDecMultiModalProcessor`][vllm.multimodal.processing.EncDecMultiModalProcessor],
    in the form handed to vLLM internals. Extends `MultiModalInputs` with
    the encoder-side prompt.
    """

    encoder_prompt: str
    """Encoder prompt text after processing."""

    encoder_prompt_token_ids: list[int]
    """Token IDs of the encoder prompt after processing."""

    encoder_token_type_ids: NotRequired[list[int]]
    """Token type IDs of the encoder prompt."""