# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from abc import ABC, abstractmethod
from collections import UserDict, defaultdict
from collections.abc import Mapping, Sequence
from dataclasses import dataclass
from functools import partial
from itertools import accumulate
from typing import TYPE_CHECKING, Any, Literal, Optional, TypedDict, Union, cast, final

import numpy as np
from typing_extensions import NotRequired, TypeAlias, TypeVar, deprecated

from vllm.utils import LazyLoader, full_groupby, is_list_of
from vllm.utils.jsontree import json_map_leaves

if TYPE_CHECKING:
    import torch
    import torch.types
    from PIL.Image import Image
    from transformers.feature_extraction_utils import BatchFeature

    from .processing import MultiModalHashes
else:
    torch = LazyLoader("torch", globals(), "torch")
28

29
30
_T = TypeVar("_T")

31
HfImageItem: TypeAlias = Union["Image", np.ndarray, "torch.Tensor"]
32
"""
33
A `transformers.image_utils.ImageInput` representing a single image
34
item, which can be passed to a HuggingFace `ImageProcessor`.
35
36
"""

37
38
39
HfVideoItem: TypeAlias = Union[
    list["Image"], np.ndarray, "torch.Tensor", list[np.ndarray], list["torch.Tensor"]
]
40
"""
41
A `transformers.image_utils.VideoInput` representing a single video
42
item, which can be passed to a HuggingFace `VideoProcessor`.
43
44
"""

45
HfAudioItem: TypeAlias = Union[list[float], np.ndarray, "torch.Tensor"]
46
"""
47
Represents a single audio
48
item, which can be passed to a HuggingFace `AudioProcessor`.
49
50
"""

51
ImageItem: TypeAlias = Union[HfImageItem, "torch.Tensor"]
52
"""
53
A `transformers.image_utils.ImageInput` representing a single image
54
item, which can be passed to a HuggingFace `ImageProcessor`.
55
56
57
58
59
60

Alternatively, a 3-D tensor or batch of 2-D tensors,
which are treated as image embeddings;
these are directly passed to the model without HF processing.
"""

61
62
63
VideoItem: TypeAlias = Union[
    HfVideoItem, "torch.Tensor", tuple[HfVideoItem, dict[str, Any]]
]
64
"""
65
66
67
A `transformers.video_utils.VideoInput` representing a single video item. 
This can be passed to a HuggingFace `VideoProcessor` 
with `transformers.video_utils.VideoMetadata`.
68
69
70
71
72
73

Alternatively, a 3-D tensor or batch of 2-D tensors,
which are treated as video embeddings;
these are directly passed to the model without HF processing.
"""

74
AudioItem: TypeAlias = Union[HfAudioItem, tuple[np.ndarray, float], "torch.Tensor"]
75
76
"""
Represents a single audio
77
item, which can be passed to a HuggingFace `AudioProcessor`.
78
79
80
81
82
83
84
85
86
87

Alternatively, a tuple `(audio, sampling_rate)`, where the sampling rate
is different from that expected by the model;
these are resampled to the model's sampling rate before being processed by HF.

Alternatively, a 3-D tensor or batch of 2-D tensors,
which are treated as audio embeddings;
these are directly passed to the model without HF processing.
"""

88
ModalityData: TypeAlias = Union[_T, list[Optional[_T]], None]
89
"""
90
91
Either a single data item, or a list of data items. Can only be None if UUID
is provided.
92
93

The number of data items allowed per modality is restricted by
94
`--limit-mm-per-prompt`.
95
96
97
98
99
100
101
"""


@final
class MultiModalDataBuiltins(TypedDict, total=False):
    """Type annotations for modality types predefined by vLLM."""

    image: ModalityData[ImageItem]
    """The input image(s)."""

    video: ModalityData[VideoItem]
    """The input video(s)."""

    audio: ModalityData[AudioItem]
    """The input audio(s)."""


112
MultiModalDataDict: TypeAlias = Mapping[str, ModalityData[Any]]
"""
A dictionary containing an entry for each modality type to input.

The built-in modalities are defined by
[`MultiModalDataBuiltins`][vllm.multimodal.inputs.MultiModalDataBuiltins].
"""

MultiModalUUIDDict: TypeAlias = Mapping[str, Union[list[Optional[str]], str]]
"""
A dictionary containing user-provided UUIDs for items in each modality.
If a UUID for an item is not provided, its entry will be `None` and
MultiModalHasher will compute a hash for the item.

The UUID will be used to identify the item for all caching purposes
(input processing caching, embedding caching, prefix caching, etc).
"""

130

131
132
@dataclass(frozen=True)
class PlaceholderRange:
    """
    Placeholder location information for multi-modal data.

    Example:

    Prompt: `AAAA BBBB What is in these images?`

    Images A and B will have:

    ```
    A: PlaceholderRange(offset=0, length=4)
    B: PlaceholderRange(offset=5, length=4)
    ```
    """

    offset: int
    """The start index of the placeholder in the prompt."""

    length: int
    """The length of the placeholder."""

    is_embed: Optional["torch.Tensor"] = None
    """
    A boolean mask of shape `(length,)` indicating which positions
    between `offset` and `offset + length` to assign embeddings to.
    """

    def get_num_embeds(self) -> int:
        """
        Return the number of positions that receive an embedding.

        This is `length` when no mask is set, otherwise the sum of the
        `is_embed` mask.
        """
        if self.is_embed is None:
            return self.length

        return int(self.is_embed.sum().item())

    def __eq__(self, other: object) -> bool:
        """
        Compare offset, length and the `is_embed` mask.

        Objects that are not instances of this class compare unequal.
        Two ranges without a mask are equal when offset and length match;
        a masked range never equals an unmasked one.
        """
        if not isinstance(other, self.__class__):
            return False
        if (self.offset, self.length) != (other.offset, other.length):
            return False

        # Tensors cannot be compared with `==` for a single bool result,
        # so the "one or both masks missing" case is handled explicitly.
        if self.is_embed is None or other.is_embed is None:
            return self.is_embed is None and other.is_embed is None

        return nested_tensors_equal(self.is_embed, other.is_embed)

179

180
181
182
183
184
185
NestedTensors: TypeAlias = Union[
    list["NestedTensors"],
    list["torch.Tensor"],
    "torch.Tensor",
    tuple["torch.Tensor", ...],
]
186
187
188
189
"""
Uses a list instead of a tensor if the dimensions of each element do not match.
"""

190
191

def nested_tensors_equal(a: NestedTensors, b: NestedTensors) -> bool:
    """Equality check between
    [`NestedTensors`][vllm.multimodal.inputs.NestedTensors] objects.

    Tensors are compared with `torch.equal`. Lists and tuples are compared
    element-wise and must have the same container kind and the same length.
    Anything else falls back to `==`.
    """
    if isinstance(a, torch.Tensor) or isinstance(b, torch.Tensor):
        return (
            isinstance(a, torch.Tensor)
            and isinstance(b, torch.Tensor)
            and torch.equal(a, b)
        )

    for container in (list, tuple):
        if isinstance(a, container) or isinstance(b, container):
            return (
                isinstance(a, container)
                and isinstance(b, container)
                # `zip` silently truncates to the shorter input, so a
                # missing length check would make `[t]` equal `[t, t]`.
                and len(a) == len(b)
                and all(nested_tensors_equal(a_, b_) for a_, b_ in zip(a, b))
            )

    # Both a and b are scalars
    return a == b


212
BatchedTensorInputs: TypeAlias = dict[str, NestedTensors]
"""
A dictionary containing nested tensors which have been batched via
[`MultiModalKwargs.batch`][vllm.multimodal.inputs.MultiModalKwargs.batch].
"""


219
220
221
222
@dataclass
class MultiModalFeatureSpec:
    """
    Represents a single multimodal input with its processed data and metadata.

    Used by the V1 engine to track multimodal data through processing and
    caching. A request containing multiple multimodal items will have one
    MultiModalFeatureSpec per item.
    """

    data: Optional["MultiModalKwargsItem"]
    """Multimodal data for this feature"""

    modality: str
    """Based on the input, e.g., "image", "audio", "video"."""

    identifier: str
    """mm_hash or uuid for caching encoder outputs."""

    mm_position: PlaceholderRange
    """e.g., PlaceholderRange(offset=2, length=336)"""


242
@dataclass
class MultiModalFieldElem:
    """
    Represents a keyword argument corresponding to a multi-modal item
    in [`MultiModalKwargs`][vllm.multimodal.inputs.MultiModalKwargs].
    """

    modality: str
    """
    The modality of the corresponding multi-modal item.
    Each multi-modal item can consist of multiple keyword arguments.
    """

    key: str
    """
    The key of this field in
    [`MultiModalKwargs`][vllm.multimodal.inputs.MultiModalKwargs],
    i.e. the name of the keyword argument to be passed to the model.
    """

    data: NestedTensors
    """
    The tensor data of this field in
    [`MultiModalKwargs`][vllm.multimodal.inputs.MultiModalKwargs],
    i.e. the value of the keyword argument to be passed to the model.

    It may be set to `None` if it is determined that the item is cached
    in `EngineCore`.
    """

    field: "BaseMultiModalField"
    """
    Defines how to combine the tensor data of this field with others
    in order to batch multi-modal items together for model inference.
    """

    def __eq__(self, other: object) -> bool:
        """
        Compare modality, key, field type and tensor data.

        Only the *type* of `field` is compared, not its contents. Two
        elements whose `data` is `None` are considered to have equal data.
        """
        if not isinstance(other, self.__class__):
            return False

        # Cheap metadata checks first; the tensor comparison runs last.
        if (self.modality, self.key) != (other.modality, other.key):
            return False
        if type(self.field) != type(other.field):  # noqa: E721
            return False

        if self.data is None or other.data is None:
            return self.data is None and other.data is None

        return nested_tensors_equal(self.data, other.data)
294
295
296
297


@dataclass(frozen=True)
class BaseMultiModalField(ABC):
    """
    Defines how to interpret tensor data belonging to a keyword argument in
    [`MultiModalKwargs`][vllm.multimodal.inputs.MultiModalKwargs] for multiple
    multi-modal items, and vice versa.
    """

    def _field_factory(self, *, modality: str, key: str):
        """
        Return a callable that wraps tensor data into a
        [`MultiModalFieldElem`][vllm.multimodal.inputs.MultiModalFieldElem]
        bound to this field, `modality` and `key`.
        """
        f = partial(
            MultiModalFieldElem,
            modality=modality,
            key=key,
            field=self,
        )

        # Allow passing data as positional argument
        def factory(data: NestedTensors) -> MultiModalFieldElem:
            return f(data=data)

        return factory

    @abstractmethod
    def build_elems(
        self,
        modality: str,
        key: str,
        data: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        """
        Construct
        [`MultiModalFieldElem`][vllm.multimodal.inputs.MultiModalFieldElem]
        instances to represent the provided data.

        This is the inverse of
        [`reduce_data`][vllm.multimodal.inputs.BaseMultiModalField.reduce_data].
        """
        raise NotImplementedError

    @abstractmethod
    def _reduce_data(
        self,
        batch: list[NestedTensors],
        *,
        pin_memory: bool,
    ) -> NestedTensors:
        """Merge the per-element data in `batch`; implemented by subclasses."""
        raise NotImplementedError

    def reduce_data(
        self,
        elems: list[MultiModalFieldElem],
        *,
        pin_memory: bool = False,
    ) -> NestedTensors:
        """
        Merge the data from multiple instances of
        [`MultiModalFieldElem`][vllm.multimodal.inputs.MultiModalFieldElem].

        This is the inverse of
        [`build_elems`][vllm.multimodal.inputs.BaseMultiModalField.build_elems].

        Raises:
            ValueError: If the elements do not all share the same field type.
        """
        field_types = [type(item.field) for item in elems]
        if len(set(field_types)) > 1:
            raise ValueError(f"Cannot merge different {field_types=}")

        batch = [elem.data for elem in elems]
        return self._reduce_data(batch, pin_memory=pin_memory)
363
364
365
366
367


@dataclass(frozen=True)
class MultiModalBatchedField(BaseMultiModalField):
    """
    Info:
        [`MultiModalFieldConfig.batched`][vllm.multimodal.inputs.MultiModalFieldConfig.batched]
    """

    def build_elems(
        self,
        modality: str,
        key: str,
        data: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        """Create one element per entry obtained by iterating over `data`."""
        field_factory = self._field_factory(modality=modality, key=key)
        return [field_factory(item) for item in data]

    def _reduce_data(
        self,
        batch: list[NestedTensors],
        *,
        pin_memory: bool,
    ) -> NestedTensors:
        """
        Stack the batch into one tensor when every entry is a tensor of the
        same shape; otherwise return the batch list unchanged.

        `pin_memory` is forwarded to the `torch.empty` call that allocates
        the stacked output.
        """
        if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"):
            batch = cast(list[torch.Tensor], batch)
            if len(batch) == 1:
                # An optimization when `batch` contains only one tensor:
                # - produce exactly same result as `torch.stack(batch)`
                # - will achieve zero-copy if the tensor is contiguous
                return batch[0].unsqueeze(0).contiguous()

            first_shape = batch[0].shape
            if all(elem.shape == first_shape for elem in batch):
                out = torch.empty(
                    (len(batch), *first_shape),
                    dtype=batch[0].dtype,
                    device=batch[0].device,
                    pin_memory=pin_memory,
                )
                return torch.stack(batch, out=out)

        return batch


@dataclass(frozen=True)
class MultiModalFlatField(BaseMultiModalField):
    """
    Info:
        [`MultiModalFieldConfig.flat`][vllm.multimodal.inputs.MultiModalFieldConfig.flat]
        [`MultiModalFieldConfig.flat_from_sizes`][vllm.multimodal.inputs.MultiModalFieldConfig.flat_from_sizes]
    """

    slices: Union[Sequence[slice], Sequence[Sequence[slice]]]
    dim: int = 0

    def build_elems(
        self,
        modality: str,
        key: str,
        data: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        """Create one element per entry of `slices` by indexing `data` with it."""
        field_factory = self._field_factory(modality=modality, key=key)
        if not is_list_of(self.slices, slice, check="all"):
            # Entries that are not plain slices are used as multi-dimensional
            # indices, which a nested Python list cannot handle.
            assert isinstance(data, torch.Tensor), (
                "torch.Tensor is required for multiple slices"
            )
        return [field_factory(data[cast(slice, s)]) for s in self.slices]

    def _reduce_data(
        self,
        batch: list[NestedTensors],
        *,
        pin_memory: bool,
    ) -> NestedTensors:
        """
        Concatenate the batch along `dim` when every entry is a tensor and
        the shapes agree on all other dimensions; otherwise flatten the
        entries into one list (only supported for `dim == 0`).

        `pin_memory` is forwarded to the `torch.empty` call that allocates
        the concatenated output.
        """
        if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"):
            batch = cast(list[torch.Tensor], batch)
            if len(batch) == 1:
                # An optimization when `batch` contains only one tensor:
                # - produce exactly same result as `torch.concat(batch)`
                # - will achieve zero-copy if the tensor is contiguous
                return batch[0].contiguous()

            # Normalize a negative `dim` so it can be used in shape slicing.
            dim = self.dim if self.dim >= 0 else self.dim + len(batch[0].shape)

            def _shape_before_after(tensor: torch.Tensor):
                return tensor.shape[:dim], tensor.shape[dim + 1 :]

            first_shape = _shape_before_after(batch[0])

            if all(_shape_before_after(elem) == first_shape for elem in batch):
                shape_before, shape_after = first_shape
                shape_concat = sum(item.shape[dim] for item in batch)
                out = torch.empty(
                    (*shape_before, shape_concat, *shape_after),
                    dtype=batch[0].dtype,
                    device=batch[0].device,
                    pin_memory=pin_memory,
                )
                return torch.concat(batch, dim=self.dim, out=out)

        assert self.dim == 0, "dim == 0 is required for nested list"
        return [e for elem in batch for e in elem]
465
466


467
468
469
@dataclass(frozen=True)
class MultiModalSharedField(BaseMultiModalField):
    """
    Info:
        [`MultiModalFieldConfig.shared`][vllm.multimodal.inputs.MultiModalFieldConfig.shared]
    """

    batch_size: int

    def build_elems(
        self,
        modality: str,
        key: str,
        data: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        """Return `batch_size` references to one element wrapping `data`."""
        field_factory = self._field_factory(modality=modality, key=key)
        # List repetition: every entry is the same element object.
        return [field_factory(data)] * self.batch_size

    def _reduce_data(
        self,
        batch: list[NestedTensors],
        *,
        pin_memory: bool,
    ) -> NestedTensors:
        """Return the first entry of `batch`; `pin_memory` is not used."""
        return batch[0]


494
495
496
class MultiModalFieldConfig:
    """
    Pairs a
    [`BaseMultiModalField`][vllm.multimodal.inputs.BaseMultiModalField]
    with the modality of the multi-modal item that uses it.
    """

    @staticmethod
    def batched(modality: str):
        """
        Defines a field where an element in the batch is obtained by
        indexing into the first dimension of the underlying data.

        Args:
            modality: The modality of the multi-modal item that uses this
                keyword argument.

        Example:

        ```
        Input:
            Data: [[AAAA]
                [BBBB]
                [CCCC]]

        Output:
            Element 1: [AAAA]
            Element 2: [BBBB]
            Element 3: [CCCC]
        ```
        """
        return MultiModalFieldConfig(
            field=MultiModalBatchedField(),
            modality=modality,
        )

    @staticmethod
    def flat(
        modality: str,
        slices: Union[Sequence[slice], Sequence[Sequence[slice]]],
        dim: int = 0,
    ):
        """
        Defines a field where an element in the batch is obtained by
        slicing along dimension `dim` of the underlying data.

        Args:
            modality: The modality of the multi-modal item that uses this
                keyword argument.
            slices: For each multi-modal item, a slice (dim=0) or a tuple of
                slices (dim>0) that is used to extract the data corresponding
                to it.
            dim: The dimension to extract data, default to 0.

        Example:

        ```
        Given:
            slices: [slice(0, 3), slice(3, 7), slice(7, 9)]

        Input:
            Data: [AAABBBBCC]

        Output:
            Element 1: [AAA]
            Element 2: [BBBB]
            Element 3: [CC]
        ```

        ```
        Given:
            slices: [
                (slice(None), slice(0, 3)),
                (slice(None), slice(3, 7)),
                (slice(None), slice(7, 9))]
            dim: 1

        Input:
            Data: [[A],[A],[A],[B],[B],[B],[B],[C],[C]]

        Output:
            Element 1: [[A],[A],[A]]
            Element 2: [[B],[B],[B],[B]]
            Element 3: [[C],[C]]
        ```
        """
        return MultiModalFieldConfig(
            field=MultiModalFlatField(slices=slices, dim=dim),
            modality=modality,
        )

    @staticmethod
    def flat_from_sizes(modality: str, size_per_item: "torch.Tensor", dim: int = 0):
        """
        Defines a field where an element in the batch is obtained by
        slicing along dimension `dim` of the underlying data.

        Args:
            modality: The modality of the multi-modal item that uses this
                keyword argument.
            size_per_item: For each multi-modal item, the size of the slice
                that is used to extract the data corresponding to it.
            dim: The dimension to slice, default to 0.

        Raises:
            ValueError: If `size_per_item` is not a 1-D tensor.

        Example:

        ```
        Given:
            size_per_item: [3, 4, 2]

        Input:
            Data: [AAABBBBCC]

        Output:
            Element 1: [AAA]
            Element 2: [BBBB]
            Element 3: [CC]
        ```

        ```
        Given:
            size_per_item: [3, 4, 2]
            dim: 1

        Input:
            Data: [[A],[A],[A],[B],[B],[B],[B],[C],[C]]

        Output:
            Element 1: [[A],[A],[A]]
            Element 2: [[B],[B],[B],[B]]
            Element 3: [[C],[C]]
        ```

        Info:
            [`MultiModalFieldConfig.flat`][vllm.multimodal.inputs.MultiModalFieldConfig.flat]
        """

        if size_per_item.ndim != 1:
            raise ValueError(
                "size_per_item should be a 1-D tensor, "
                f"but found shape: {size_per_item.shape}"
            )

        # Convert to plain Python numbers once so that the slice bounds are
        # not 0-d tensors produced by element-wise tensor additions.
        sizes = size_per_item.tolist()
        slice_idxs = [0, *accumulate(sizes)]

        # Leading `slice(None)` entries select everything in the dimensions
        # before `dim`.
        leading = (slice(None, None, None),) * dim
        slices = [
            leading + (slice(start, stop),)
            for start, stop in zip(slice_idxs, slice_idxs[1:])
        ]

        return MultiModalFieldConfig.flat(modality, slices, dim=dim)

    @staticmethod
    def shared(modality: str, batch_size: int):
        """
        Defines a field where an element in the batch is obtained by
        taking the entirety of the underlying data.

        This means that the data is the same for each element in the batch.

        Args:
            modality: The modality of the multi-modal item that uses this
                keyword argument.
            batch_size: The number of multi-modal items which share this data.

        Example:

        ```
        Given:
            batch_size: 4

        Input:
            Data: [XYZ]

        Output:
            Element 1: [XYZ]
            Element 2: [XYZ]
            Element 3: [XYZ]
            Element 4: [XYZ]
        ```
        """
        return MultiModalFieldConfig(
            field=MultiModalSharedField(batch_size),
            modality=modality,
        )

    def __init__(self, field: BaseMultiModalField, modality: str) -> None:
        super().__init__()

        self.field = field
        self.modality = modality

    def build_elems(
        self,
        key: str,
        batch: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        """Delegate to `field.build_elems` using this config's modality."""
        return self.field.build_elems(self.modality, key, batch)
686
687


688
689
class MultiModalKwargsItem(UserDict[str, MultiModalFieldElem]):
    """
    A collection of
    [`MultiModalFieldElem`][vllm.multimodal.inputs.MultiModalFieldElem]
    corresponding to a data item in
    [`MultiModalDataItems`][vllm.multimodal.parse.MultiModalDataItems].
    """

    @staticmethod
    def dummy(modality: str):
        """Convenience method for testing: build an item with one dummy field."""
        mm_elem = MultiModalFieldElem(
            modality=modality,
            key="dummy",
            data=torch.empty(1),
            field=MultiModalSharedField(1),
        )
        return MultiModalKwargsItem.from_elems([mm_elem])

    @staticmethod
    def from_elems(elems: Sequence[MultiModalFieldElem]):
        """Build an item from `elems`, keyed by each element's `key`."""
        return MultiModalKwargsItem({elem.key: elem for elem in elems})

    def __init__(
        self, data: Optional[Mapping[str, MultiModalFieldElem]] = None
    ) -> None:
        """
        Args:
            data: Mapping from keyword-argument name to its element. All
                elements must share a single modality; an empty or missing
                mapping therefore fails the assertion below.
        """
        # `None` replaces the former mutable `{}` default; `UserDict`
        # treats it as "no initial data".
        super().__init__(data)

        modalities = {elem.modality for elem in self.values()}
        assert len(modalities) == 1, f"Found different modalities={modalities}"
        self._modality = next(iter(modalities))

    @property
    def modality(self) -> str:
        """The single modality shared by all elements of this item."""
        return self._modality

    def get_data(self) -> dict[str, NestedTensors]:
        """Return the raw `data` of every element, keyed like this item."""
        return {key: elem.data for key, elem in self.items()}
724
725


726
727
728
729
730
731
732
733
734
# Item type stored in `MultiModalKwargsItems`: either always-present items
# or items that may be `None`. Defaults to the always-present form.
_I = TypeVar(
    "_I",
    MultiModalKwargsItem,
    Optional[MultiModalKwargsItem],
    default=MultiModalKwargsItem,
)


class MultiModalKwargsItems(UserDict[str, Sequence[_I]]):
    """
    A dictionary of
    [`MultiModalKwargsItem`][vllm.multimodal.inputs.MultiModalKwargsItem]s
    by modality.
    """

    @staticmethod
    def from_hf_inputs(
        hf_inputs: "BatchFeature",
        config_by_key: Mapping[str, MultiModalFieldConfig],
    ):
        """
        Split `hf_inputs` into per-item keyword arguments.

        Args:
            hf_inputs: The processor outputs to read tensor data from.
            config_by_key: For each key to extract, the field config that
                describes how to split its data into per-item elements.

        Raises:
            ValueError: If the keys of one modality yield different numbers
                of elements.
        """
        # NOTE: This skips fields in `hf_inputs` that are not in `config_by_key`
        # We assume that those fields are not used in vLLM
        elems_by_key = dict[str, Sequence[MultiModalFieldElem]]()
        keys_by_modality = defaultdict[str, set[str]](set)
        for key, config in config_by_key.items():
            batch = hf_inputs.get(key)
            if batch is not None:
                elems = config.build_elems(key, batch)
                if len(elems) > 0:
                    elems_by_key[key] = elems
                    keys_by_modality[config.modality].add(key)

        items = list[MultiModalKwargsItem]()
        for modality, keys in keys_by_modality.items():
            elems_in_modality = {k: elems_by_key[k] for k in keys}
            batch_sizes = {k: len(v) for k, v in elems_in_modality.items()}

            if len(set(batch_sizes.values())) > 1:
                raise ValueError(
                    f"Cannot merge different batch sizes for {modality=}! "
                    f"Found: {batch_sizes=}"
                )

            # All keys agree on the batch size, so any value will do.
            batch_size = next(iter(batch_sizes.values()))
            for item_idx in range(batch_size):
                elems = [v[item_idx] for v in elems_in_modality.values()]
                items.append(MultiModalKwargsItem.from_elems(elems))

        return MultiModalKwargsItems.from_seq(items)

    @staticmethod
    def from_seq(items: Sequence[MultiModalKwargsItem]):
        """Group `items` by their `modality` attribute."""
        items_by_modality = full_groupby(items, key=lambda x: x.modality)
        return MultiModalKwargsItems(items_by_modality)

    def __getitem__(self, modality: str) -> Sequence[_I]:
        """
        Return the items of `modality`.

        Raises:
            KeyError: If the modality is absent; the message lists the
                available modalities.
        """
        if modality not in self:
            raise KeyError(
                f"Modality {modality!r} not found. "
                f"Available modalities: {set(self.keys())}"
            )

        return super().__getitem__(modality)  # type: ignore[return-value]

    def require_data(self) -> "MultiModalKwargsItems[MultiModalKwargsItem]":
        """
        Return `self` after checking that no item is `None`.

        Raises:
            RuntimeError: If any item is `None`.
        """
        for modality, items in self.items():
            for i, item in enumerate(items):
                if item is None:
                    raise RuntimeError(f"Found empty mm_items[{modality}][{i}]")

        return self  # type: ignore[return-value]

    def get_data(self, *, pin_memory: bool = False) -> "MultiModalKwargs":
        """
        Merge the elements of all items into one
        [`MultiModalKwargs`][vllm.multimodal.inputs.MultiModalKwargs].

        Elements are grouped by key and reduced with the `field` of the
        first element of each group.

        Raises:
            RuntimeError: If any item is `None`.
        """
        elems_by_key = defaultdict[str, list[MultiModalFieldElem]](list)
        for modality, items in self.items():
            for i, item in enumerate(items):
                if item is None:
                    raise RuntimeError(
                        f"Cannot build data from empty mm_items[{modality}][{i}]"
                    )

                for key, elem in item.items():
                    elems_by_key[key].append(elem)

        return MultiModalKwargs(
            {
                key: elems[0].field.reduce_data(elems, pin_memory=pin_memory)
                for key, elems in elems_by_key.items()
            }
        )
816

817

818
819
820
821
822
823
# Either parameterization of `MultiModalKwargsItems`: one whose items are
# always present, or one whose items may be `None`.
MultiModalKwargsOptionalItems: TypeAlias = Union[
    MultiModalKwargsItems[MultiModalKwargsItem],
    MultiModalKwargsItems[Optional[MultiModalKwargsItem]],
]


824
825
826
827
828
829
830
class MultiModalKwargs(UserDict[str, NestedTensors]):
    """
    A dictionary that represents the keyword arguments to
    [`torch.nn.Module.forward`][].
    """

    @staticmethod
    @deprecated(
        "`MultiModalKwargs.from_hf_inputs` is deprecated and "
        "will be removed in v0.13. "
        "Please use `MultiModalKwargsItems.from_hf_inputs` and "
        "access the tensor data using `.get_data()`."
    )
    def from_hf_inputs(
        hf_inputs: "BatchFeature",
        config_by_key: Mapping[str, MultiModalFieldConfig],
    ):
        """Deprecated shortcut for `MultiModalKwargsItems.from_hf_inputs(...).get_data()`."""
        return MultiModalKwargsItems.from_hf_inputs(hf_inputs, config_by_key).get_data()

    @staticmethod
    @deprecated(
        "`MultiModalKwargs.from_items` is deprecated and "
        "will be removed in v0.13. "
        "Please use `MultiModalKwargsItems.from_seq` and "
        "access the tensor data using `.get_data()`."
    )
    def from_items(
        items: Sequence[MultiModalKwargsItem],
        *,
        pin_memory: bool = False,
    ):
        """Deprecated shortcut for `MultiModalKwargsItems.from_seq(...).get_data(...)`."""
        return MultiModalKwargsItems.from_seq(items).get_data(pin_memory=pin_memory)

    @staticmethod
    def _try_stack(
        nested_tensors: NestedTensors, pin_memory: bool = False
    ) -> NestedTensors:
        """
        Stack the inner dimensions that have the same shape in
        a nested list of tensors.

        Thus, a dimension represented by a list means that the inner
        dimensions are different for each element along that dimension.
        """
        if isinstance(nested_tensors, torch.Tensor):
            return nested_tensors

        # TODO: Remove these once all models have been migrated
        if isinstance(nested_tensors, np.ndarray):
            return torch.from_numpy(nested_tensors)
        if isinstance(nested_tensors, (int, float)):
            return torch.tensor(nested_tensors)

        stacked = [MultiModalKwargs._try_stack(t, pin_memory) for t in nested_tensors]
        if not is_list_of(stacked, torch.Tensor, check="all"):
            # Only tensors (not lists) can be stacked.
            return stacked

        tensors_ = cast(list[torch.Tensor], stacked)
        if len(tensors_) == 1:
            # An optimization when `tensors_` contains only one tensor:
            # - produce exactly same result as `torch.stack(tensors_)`
            # - will achieve zero-copy if the tensor is contiguous
            return tensors_[0].unsqueeze(0).contiguous()

        if any(t.shape != tensors_[0].shape for t in tensors_):
            # The tensors have incompatible shapes and can't be stacked.
            return tensors_

        outputs = torch.empty(
            len(tensors_),
            *tensors_[0].shape,
            dtype=tensors_[0].dtype,
            device=tensors_[0].device,
            pin_memory=pin_memory,
        )
        return torch.stack(tensors_, out=outputs)

    @staticmethod
    def batch(
        inputs_list: list["MultiModalKwargs"], pin_memory: bool = False
    ) -> BatchedTensorInputs:
        """
        Batch multiple inputs together into a dictionary.

        The resulting dictionary has the same keys as the inputs.
        If the corresponding value from each input is a tensor and they all
        share the same shape, the output value is a single batched tensor;
        otherwise, the output value is a list containing the original value
        from each input.
        """
        if len(inputs_list) == 0:
            return {}

        # We need to consider the case where each item in the batch
        # contains different modalities (i.e. different keys).
        item_lists = defaultdict[str, list[NestedTensors]](list)

        for inputs in inputs_list:
            for k, v in inputs.items():
                item_lists[k].append(v)

        return {
            k: MultiModalKwargs._try_stack(item_list, pin_memory)
            for k, item_list in item_lists.items()
        }

    @staticmethod
    def as_kwargs(
        batched_inputs: BatchedTensorInputs,
        *,
        device: torch.types.Device,
    ) -> BatchedTensorInputs:
        """Apply `.to(device=device, non_blocking=True)` to every leaf of `batched_inputs`."""
        return json_map_leaves(
            lambda x: x.to(device=device, non_blocking=True),
            batched_inputs,
        )

    def __getitem__(self, key: str):
        """
        Return the value stored under `key`.

        Raises:
            KeyError: If the key is absent; the message lists the
                available keys.
        """
        if key not in self:
            raise KeyError(
                f"Keyword argument {key!r} not found. "
                f"Available keys: {set(self.keys())}"
            )

        return super().__getitem__(key)

    def __eq__(self, other: object) -> bool:
        """
        Compare two instances key by key using `nested_tensors_equal`.

        Both sides must have exactly the same keys, so the comparison is
        symmetric: extra keys on either side make the instances unequal.
        """
        if not isinstance(other, self.__class__):
            return False

        # Checking only `self`'s keys would ignore keys that exist solely
        # in `other`, making `a == b` and `b == a` disagree.
        if self.keys() != other.keys():
            return False

        return all(nested_tensors_equal(self[k], other[k]) for k in self)
962

963

964
MultiModalPlaceholderDict: TypeAlias = Mapping[str, Sequence[PlaceholderRange]]
"""
A dictionary containing placeholder ranges for each modality.
"""


970
class MultiModalInputs(TypedDict):
971
    """
972
    Represents the outputs of
973
    [`BaseMultiModalProcessor`][vllm.multimodal.processing.BaseMultiModalProcessor],
974
975
976
977
978
979
    ready to be passed to vLLM internals.
    """

    type: Literal["multimodal"]
    """The type of inputs."""

980
    prompt_token_ids: list[int]
981
982
    """The processed token IDs which includes placeholder tokens."""

983
    mm_kwargs: MultiModalKwargsOptionalItems
984
985
    """Keyword arguments to be directly passed to the model after batching."""

986
    mm_hashes: "MultiModalHashes"
987
988
    """The hashes of the multi-modal data."""

989
    mm_placeholders: "MultiModalPlaceholderDict"
990
991
    """
    For each modality, information about the placeholder tokens in
992
    `prompt_token_ids`.
993
    """
994

995
996
997
998
999
    cache_salt: NotRequired[str]
    """
    Optional cache salt to be used for prefix caching.
    """

1000
1001
1002

class MultiModalEncDecInputs(MultiModalInputs):
    """
    Represents the outputs of
    [`EncDecMultiModalProcessor`][vllm.multimodal.processing.EncDecMultiModalProcessor]
    ready to be passed to vLLM internals.
    """

    encoder_prompt_token_ids: list[int]
    """The processed token IDs of the encoder prompt."""