# inputs.py
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from abc import ABC, abstractmethod
5
from collections import UserDict, defaultdict
6
from collections.abc import Mapping, Sequence
7
from dataclasses import dataclass
8
from functools import partial
9
from itertools import accumulate
10
11
12
13
14
15
16
17
18
19
20
from typing import (
    TYPE_CHECKING,
    Any,
    Literal,
    Optional,
    TypeAlias,
    TypedDict,
    Union,
    cast,
    final,
)
21
22

import numpy as np
23
from typing_extensions import NotRequired, TypeVar, deprecated
24

25
from vllm.utils.collection_utils import full_groupby, is_list_of
26
from vllm.utils.import_utils import LazyLoader
27
from vllm.utils.jsontree import json_map_leaves
28

29
if TYPE_CHECKING:
30
31
32
33
34
    import torch
    import torch.types
    from PIL.Image import Image
    from transformers.feature_extraction_utils import BatchFeature

35
    from .base import MediaWithBytes
36
37
    from .processing import MultiModalHashes

38
39
else:
    torch = LazyLoader("torch", globals(), "torch")
40

41
42
_T = TypeVar("_T")

43
HfImageItem: TypeAlias = Union["Image", np.ndarray, "torch.Tensor"]
44
"""
45
A `transformers.image_utils.ImageInput` representing a single image
46
item, which can be passed to a HuggingFace `ImageProcessor`.
47
48
"""

49
50
51
HfVideoItem: TypeAlias = Union[
    list["Image"], np.ndarray, "torch.Tensor", list[np.ndarray], list["torch.Tensor"]
]
52
"""
53
A `transformers.image_utils.VideoInput` representing a single video
54
item, which can be passed to a HuggingFace `VideoProcessor`.
55
56
"""

57
HfAudioItem: TypeAlias = Union[list[float], np.ndarray, "torch.Tensor"]
58
"""
59
Represents a single audio
60
item, which can be passed to a HuggingFace `AudioProcessor`.
61
62
"""

63
ImageItem: TypeAlias = Union[HfImageItem, "torch.Tensor", "MediaWithBytes[HfImageItem]"]
64
"""
65
A `transformers.image_utils.ImageInput` representing a single image
66
item, which can be passed to a HuggingFace `ImageProcessor`.
67
68
69
70
71
72

Alternatively, a 3-D tensor or batch of 2-D tensors,
which are treated as image embeddings;
these are directly passed to the model without HF processing.
"""

73
74
75
VideoItem: TypeAlias = Union[
    HfVideoItem, "torch.Tensor", tuple[HfVideoItem, dict[str, Any]]
]
76
"""
77
78
79
A `transformers.video_utils.VideoInput` representing a single video item. 
This can be passed to a HuggingFace `VideoProcessor` 
with `transformers.video_utils.VideoMetadata`.
80
81
82
83
84
85

Alternatively, a 3-D tensor or batch of 2-D tensors,
which are treated as video embeddings;
these are directly passed to the model without HF processing.
"""

86
AudioItem: TypeAlias = Union[HfAudioItem, tuple[np.ndarray, float], "torch.Tensor"]
87
88
"""
Represents a single audio
89
item, which can be passed to a HuggingFace `AudioProcessor`.
90
91
92
93
94
95
96
97
98
99

Alternatively, a tuple `(audio, sampling_rate)`, where the sampling rate
is different from that expected by the model;
these are resampled to the model's sampling rate before being processed by HF.

Alternatively, a 3-D tensor or batch of 2-D tensors,
which are treated as audio embeddings;
these are directly passed to the model without HF processing.
"""

100
ModalityData: TypeAlias = _T | list[_T | None] | None
101
"""
102
103
Either a single data item, or a list of data items. Can only be None if UUID
is provided.
104
105

The number of data items allowed per modality is restricted by
106
`--limit-mm-per-prompt`.
107
108
109
110
111
112
113
"""


@final
class MultiModalDataBuiltins(TypedDict, total=False):
    """
    Type annotations for the modality types predefined by vLLM.

    Every key is optional (`total=False`). Each value follows
    `ModalityData`: a single item or a list of items of that modality.
    """

    image: ModalityData[ImageItem]
    """The input image(s)."""

    video: ModalityData[VideoItem]
    """The input video(s)."""

    audio: ModalityData[AudioItem]
    """The input audio(s)."""


124
MultiModalDataDict: TypeAlias = Mapping[str, ModalityData[Any]]
"""
A dictionary containing an entry for each modality type to input.

The built-in modalities are defined by
[`MultiModalDataBuiltins`][vllm.multimodal.inputs.MultiModalDataBuiltins].
"""

132
MultiModalUUIDDict: TypeAlias = Mapping[str, list[str | None] | str]
133
134
135
136
137
138
139
140
141
"""
A dictionary containing user-provided UUIDs for items in each modality.
If a UUID for an item is not provided, its entry will be `None` and
MultiModalHasher will compute a hash for the item.

The UUID will be used to identify the item for all caching purposes
(input processing caching, embedding caching, prefix caching, etc).
"""

142

143
144
@dataclass(frozen=True)
class PlaceholderRange:
    """
    Placeholder location information for multi-modal data.

    Example:

    Prompt: `AAAA BBBB What is in these images?`

    Images A and B will have:

    ```
    A: PlaceholderRange(offset=0, length=4)
    B: PlaceholderRange(offset=5, length=4)
    ```
    """

    offset: int
    """The start index of the placeholder in the prompt."""

    length: int
    """The length of the placeholder."""

    is_embed: Optional["torch.Tensor"] = None
    """
    A boolean mask of shape `(length,)` indicating which positions
    between `offset` and `offset + length` to assign embeddings to.
    """

    def get_num_embeds(self) -> int:
        """
        Return the number of positions that receive embeddings.

        This is `length` when no mask is set; otherwise it is the number of
        `True` entries in `is_embed`.
        """
        if self.is_embed is None:
            return self.length

        return int(self.is_embed.sum().item())

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, self.__class__):
            return False
        if (self.offset, self.length) != (other.offset, other.length):
            return False

        # A missing mask is only equal to another missing mask.
        if self.is_embed is None or other.is_embed is None:
            return self.is_embed is None and other.is_embed is None

        return nested_tensors_equal(self.is_embed, other.is_embed)

    def __hash__(self) -> int:
        # `__eq__` compares `is_embed` by value, but tensors hash by identity.
        # The hash that `@dataclass(frozen=True)` would generate includes
        # `is_embed`, so equal instances could get different hashes. Hash only
        # the value-compared fields to keep `a == b` => `hash(a) == hash(b)`.
        return hash((self.offset, self.length))

191

192
193
194
195
196
197
NestedTensors: TypeAlias = Union[
    list["NestedTensors"],
    list["torch.Tensor"],
    "torch.Tensor",
    tuple["torch.Tensor", ...],
]
198
199
200
201
"""
Uses a list instead of a tensor if the dimensions of each element do not match.
"""

202
203

def nested_tensors_equal(a: "NestedTensors", b: "NestedTensors") -> bool:
    """Equality check between
    [`NestedTensors`][vllm.multimodal.inputs.NestedTensors] objects.

    Tensors are compared with `torch.equal` (same shape and values) and are
    never equal to non-tensors. Lists and tuples are compared element-wise;
    they must have the same container type and the same length. Any other
    values fall back to `==`.
    """
    if isinstance(a, torch.Tensor) or isinstance(b, torch.Tensor):
        return (
            isinstance(a, torch.Tensor)
            and isinstance(b, torch.Tensor)
            and torch.equal(a, b)
        )

    # A list only equals a list and a tuple only equals a tuple. The length
    # check matters: `zip` alone silently ignores the extra trailing elements
    # of the longer sequence.
    for seq_type in (list, tuple):
        if isinstance(a, seq_type) or isinstance(b, seq_type):
            return (
                isinstance(a, seq_type)
                and isinstance(b, seq_type)
                and len(a) == len(b)
                and all(nested_tensors_equal(x, y) for x, y in zip(a, b))
            )

    # Both a and b are scalars
    return a == b


224
BatchedTensorInputs: TypeAlias = dict[str, NestedTensors]
"""
A dictionary containing nested tensors which have been batched via
[`MultiModalKwargs.batch`][vllm.multimodal.inputs.MultiModalKwargs.batch].
"""


231
232
233
234
@dataclass
class MultiModalFeatureSpec:
    """
    A single multimodal input together with its processed data and metadata.

    The V1 engine uses this to track multimodal data through processing and
    caching. A request that carries several multimodal items holds one
    `MultiModalFeatureSpec` per item.
    """

    data: Optional["MultiModalKwargsItem"]
    """Multimodal data for this feature"""

    modality: str
    """Based on the input, e.g., "image", "audio", "video"."""

    identifier: str
    """mm_hash or uuid for caching encoder outputs."""

    mm_position: PlaceholderRange
    """e.g., PlaceholderRange(offset=2, length=336)"""

    @staticmethod
    def gather_kwargs(features: list["MultiModalFeatureSpec"], keys: set[str]):
        """
        For each requested key, collect the element data from every feature
        whose item contains that key.

        Features without data (`data is None`) are skipped. Keys that no
        feature provides are absent from the returned dict.
        """
        gathered: dict[str, list[NestedTensors]] = {}

        for feature in features:
            mm_item = feature.data
            if mm_item is None:
                continue

            for key in keys:
                if key in mm_item:
                    gathered.setdefault(key, []).append(mm_item[key].data)

        return gathered

266

267
@dataclass
class MultiModalFieldElem:
    """
    One keyword argument of a single multi-modal item in
    [`MultiModalKwargs`][vllm.multimodal.inputs.MultiModalKwargs].
    """

    modality: str
    """
    Modality of the multi-modal item this element belongs to.
    A multi-modal item may be made up of several keyword arguments.
    """

    key: str
    """
    Name of this field in
    [`MultiModalKwargs`][vllm.multimodal.inputs.MultiModalKwargs],
    i.e. the keyword argument name passed to the model.
    """

    data: NestedTensors
    """
    Tensor data of this field in
    [`MultiModalKwargs`][vllm.multimodal.inputs.MultiModalKwargs],
    i.e. the keyword argument value passed to the model.

    It may be `None` when the item has been determined to be cached
    in `EngineCore`.
    """

    field: "BaseMultiModalField"
    """
    Strategy used to combine this element's data with that of other items
    when batching multi-modal items for model inference.
    """

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, self.__class__):
            return False

        # `data` may be None (cached item); None only matches None.
        if self.data is None or other.data is None:
            data_equal = self.data is None and other.data is None
        else:
            data_equal = nested_tensors_equal(self.data, other.data)

        same_identity = (self.modality, self.key) == (other.modality, other.key)
        same_field_type = type(self.field) is type(other.field)
        return same_identity and data_equal and same_field_type
319
320
321
322


@dataclass(frozen=True)
class BaseMultiModalField(ABC):
    """
    Describes how the tensor data of one keyword argument in
    [`MultiModalKwargs`][vllm.multimodal.inputs.MultiModalKwargs] is split
    into per-item elements, and how those elements are merged back.
    """

    def _field_factory(self, *, modality: str, key: str):
        """Return a callable that wraps `data` into a `MultiModalFieldElem`
        bound to this field, `modality` and `key`."""

        def factory(data: NestedTensors) -> MultiModalFieldElem:
            return MultiModalFieldElem(
                modality=modality,
                key=key,
                data=data,
                field=self,
            )

        return factory

    @abstractmethod
    def build_elems(
        self,
        modality: str,
        key: str,
        data: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        """
        Split `data` into
        [`MultiModalFieldElem`][vllm.multimodal.inputs.MultiModalFieldElem]
        instances, one per multi-modal item.

        Inverse of
        [`reduce_data`][vllm.multimodal.inputs.BaseMultiModalField.reduce_data].
        """
        raise NotImplementedError

    @abstractmethod
    def _reduce_data(
        self,
        batch: list[NestedTensors],
        *,
        pin_memory: bool,
    ) -> NestedTensors:
        raise NotImplementedError

    def reduce_data(
        self,
        elems: list[MultiModalFieldElem],
        *,
        pin_memory: bool = False,
    ) -> NestedTensors:
        """
        Combine the data of several
        [`MultiModalFieldElem`][vllm.multimodal.inputs.MultiModalFieldElem]
        instances into a single value.

        Inverse of
        [`build_elems`][vllm.multimodal.inputs.BaseMultiModalField.build_elems].

        Raises:
            ValueError: If the elements do not all use the same field type.
        """
        field_types = [type(item.field) for item in elems]
        if len({*field_types}) > 1:
            raise ValueError(f"Cannot merge different {field_types=}")

        payloads = [elem.data for elem in elems]
        return self._reduce_data(payloads, pin_memory=pin_memory)
388
389
390
391
392


@dataclass(frozen=True)
class MultiModalBatchedField(BaseMultiModalField):
    """
    Info:
        [`MultiModalFieldConfig.batched`][vllm.multimodal.inputs.MultiModalFieldConfig.batched]
    """

    def build_elems(
        self,
        modality: str,
        key: str,
        data: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        make_elem = self._field_factory(modality=modality, key=key)
        # One element per entry along the first dimension of `data`.
        return [make_elem(entry) for entry in data]

    def _reduce_data(
        self,
        batch: list[NestedTensors],
        *,
        pin_memory: bool,
    ) -> NestedTensors:
        # Anything other than a non-empty list of tensors is returned as-is.
        if not batch or not is_list_of(batch, torch.Tensor, check="all"):
            return batch

        tensors = cast(list[torch.Tensor], batch)
        if len(tensors) == 1:
            # Same result as `torch.stack(tensors)`, and no copy is made
            # when the single tensor is already contiguous.
            return tensors[0].unsqueeze(0).contiguous()

        head = tensors[0]
        if any(t.shape != head.shape for t in tensors):
            # Ragged shapes cannot be stacked; keep them as a list.
            return tensors

        out = torch.empty(
            (len(tensors), *head.shape),
            dtype=head.dtype,
            device=head.device,
            pin_memory=pin_memory,
        )
        return torch.stack(tensors, out=out)


@dataclass(frozen=True)
class MultiModalFlatField(BaseMultiModalField):
    """
    Info:
        [`MultiModalFieldConfig.flat`][vllm.multimodal.inputs.MultiModalFieldConfig.flat]
        [`MultiModalFieldConfig.flat_from_sizes`][vllm.multimodal.inputs.MultiModalFieldConfig.flat_from_sizes]
    """

    slices: Sequence[slice] | Sequence[Sequence[slice]]
    dim: int = 0

    def build_elems(
        self,
        modality: str,
        key: str,
        data: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        make_elem = self._field_factory(modality=modality, key=key)
        single_slices = is_list_of(self.slices, slice, check="all")
        if not single_slices:
            # Tuples of slices (multi-dimensional indexing) need a tensor.
            assert isinstance(data, torch.Tensor), (
                "torch.Tensor is required for multiple slices"
            )
        return [make_elem(data[cast(slice, part)]) for part in self.slices]

    def _reduce_data(
        self,
        batch: list[NestedTensors],
        *,
        pin_memory: bool,
    ) -> NestedTensors:
        if batch and is_list_of(batch, torch.Tensor, check="all"):
            tensors = cast(list[torch.Tensor], batch)
            if len(tensors) == 1:
                # Same result as `torch.concat(tensors)`, and no copy is made
                # when the single tensor is already contiguous.
                return tensors[0].contiguous()

            # Turn a negative `dim` into its non-negative equivalent so it
            # can be used to index into `.shape`.
            concat_dim = self.dim if self.dim >= 0 else self.dim + len(tensors[0].shape)

            def split_shape(t):
                return t.shape[:concat_dim], t.shape[concat_dim + 1 :]

            reference = split_shape(tensors[0])
            if all(split_shape(t) == reference for t in tensors):
                before, after = reference
                concat_size = sum(t.shape[concat_dim] for t in tensors)
                out = torch.empty(
                    (*before, concat_size, *after),
                    dtype=tensors[0].dtype,
                    device=tensors[0].device,
                    pin_memory=pin_memory,
                )
                return torch.concat(tensors, dim=self.dim, out=out)

        assert self.dim == 0, "dim == 0 is required for nested list"
        return [piece for chunk in batch for piece in chunk]
490
491


492
493
494
@dataclass(frozen=True)
class MultiModalSharedField(BaseMultiModalField):
    """
    Info:
        [`MultiModalFieldConfig.shared`][vllm.multimodal.inputs.MultiModalFieldConfig.shared]
    """

    batch_size: int

    def build_elems(
        self,
        modality: str,
        key: str,
        data: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        shared_elem = self._field_factory(modality=modality, key=key)(data)
        # Every item refers to the very same element object.
        return [shared_elem for _ in range(self.batch_size)]

    def _reduce_data(
        self,
        batch: list[NestedTensors],
        *,
        pin_memory: bool,
    ) -> NestedTensors:
        # All items share identical data, so the first entry represents them.
        shared_data = batch[0]
        return shared_data


519
520
521
class MultiModalFieldConfig:
    """
    Associates a
    [`BaseMultiModalField`][vllm.multimodal.inputs.BaseMultiModalField]
    (how a keyword argument is split into per-item elements) with the
    modality of the items that use it. Build instances with the static
    constructors `batched`, `flat`, `flat_from_sizes` and `shared`.
    """

    @staticmethod
    def batched(modality: str):
        """
        Defines a field where an element in the batch is obtained by
        indexing into the first dimension of the underlying data.

        Args:
            modality: The modality of the multi-modal item that uses this
                keyword argument.

        Example:

        ```
        Input:
            Data: [[AAAA]
                   [BBBB]
                   [CCCC]]

        Output:
            Element 1: [AAAA]
            Element 2: [BBBB]
            Element 3: [CCCC]
        ```
        """
        return MultiModalFieldConfig(
            field=MultiModalBatchedField(),
            modality=modality,
        )

    @staticmethod
    def flat(
        modality: str,
        slices: Sequence[slice] | Sequence[Sequence[slice]],
        dim: int = 0,
    ):
        """
        Defines a field where an element in the batch is obtained by
        slicing the underlying data along dimension `dim` (the first
        dimension by default).

        Args:
            modality: The modality of the multi-modal item that uses this
                keyword argument.
            slices: For each multi-modal item, a slice (dim=0) or a tuple of
                slices (dim>0) that is used to extract the data corresponding
                to it.
            dim: The dimension to extract data, default to 0.

        Example:

        ```
        Given:
            slices: [slice(0, 3), slice(3, 7), slice(7, 9)]

        Input:
            Data: [AAABBBBCC]

        Output:
            Element 1: [AAA]
            Element 2: [BBBB]
            Element 3: [CC]
        ```

        ```
        Given:
            slices: [
                (slice(None), slice(0, 3)),
                (slice(None), slice(3, 7)),
                (slice(None), slice(7, 9))]
            dim: 1

        Input:
            Data: [[A],[A],[A],[B],[B],[B],[B],[C],[C]]

        Output:
            Element 1: [[A],[A],[A]]
            Element 2: [[B],[B],[B],[B]]
            Element 3: [[C],[C]]
        ```
        """
        return MultiModalFieldConfig(
            field=MultiModalFlatField(slices=slices, dim=dim),
            modality=modality,
        )

    @staticmethod
    def flat_from_sizes(modality: str, size_per_item: "torch.Tensor", dim: int = 0):
        """
        Defines a field where an element in the batch is obtained by
        slicing the underlying data along dimension `dim` (the first
        dimension by default). The slices are consecutive, with lengths
        given by `size_per_item`.

        Args:
            modality: The modality of the multi-modal item that uses this
                keyword argument.
            size_per_item: For each multi-modal item, the size of the slice
                that is used to extract the data corresponding to it.
            dim: The dimension to slice, default to 0.

        Raises:
            ValueError: If `size_per_item` is not a 1-D tensor.

        Example:

        ```
        Given:
            size_per_item: [3, 4, 2]

        Input:
            Data: [AAABBBBCC]

        Output:
            Element 1: [AAA]
            Element 2: [BBBB]
            Element 3: [CC]
        ```

        ```
        Given:
            size_per_item: [3, 4, 2]
            dim: 1

        Input:
            Data: [[A],[A],[A],[B],[B],[B],[B],[C],[C]]

        Output:
            Element 1: [[A],[A],[A]]
            Element 2: [[B],[B],[B],[B]]
            Element 3: [[C],[C]]
        ```

        Info:
            [`MultiModalFieldConfig.flat`][vllm.multimodal.inputs.MultiModalFieldConfig.flat]
        """

        if size_per_item.ndim != 1:
            raise ValueError(
                "size_per_item should be a 1-D tensor, "
                f"but found shape: {size_per_item.shape}"
            )

        # Boundaries of consecutive slices: [0, s0, s0+s1, ...].
        slice_idxs = [0, *accumulate(size_per_item)]
        # Full slices for the leading `dim` axes, then the item's range.
        slices = [
            (slice(None, None, None),) * dim
            + (slice(slice_idxs[i], slice_idxs[i + 1]),)
            for i in range(len(size_per_item))
        ]

        return MultiModalFieldConfig.flat(modality, slices, dim=dim)

    @staticmethod
    def shared(modality: str, batch_size: int):
        """
        Defines a field where an element in the batch is obtained by
        taking the entirety of the underlying data.

        This means that the data is the same for each element in the batch.

        Args:
            modality: The modality of the multi-modal item that uses this
                keyword argument.
            batch_size: The number of multi-modal items which share this data.

        Example:

        ```
        Given:
            batch_size: 4

        Input:
            Data: [XYZ]

        Output:
            Element 1: [XYZ]
            Element 2: [XYZ]
            Element 3: [XYZ]
            Element 4: [XYZ]
        ```
        """
        return MultiModalFieldConfig(
            field=MultiModalSharedField(batch_size),
            modality=modality,
        )

    def __init__(self, field: BaseMultiModalField, modality: str) -> None:
        super().__init__()

        self.field = field
        self.modality = modality

    def __repr__(self) -> str:
        return f"MultiModalFieldConfig(field={self.field}, modality={self.modality})"

    def build_elems(
        self,
        key: str,
        batch: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        """Split `batch` (the data stored under `key`) into per-item
        elements, using this config's field and modality."""
        return self.field.build_elems(self.modality, key, batch)
714
715


716
717
class MultiModalKwargsItem(UserDict[str, MultiModalFieldElem]):
    """
    A collection of
    [`MultiModalFieldElem`][vllm.multimodal.inputs.MultiModalFieldElem]
    corresponding to a data item in
    [`MultiModalDataItems`][vllm.multimodal.parse.MultiModalDataItems],
    keyed by field key. All elements must share one modality.
    """

    @staticmethod
    def dummy(modality: str, nbytes: int = 1):
        """
        Create an item holding one shared element with `nbytes` bytes of
        uninitialized `uint8` data under the key `"dummy"`.

        Convenience method for testing.
        """
        mm_elem = MultiModalFieldElem(
            modality=modality,
            key="dummy",
            data=torch.empty(nbytes, dtype=torch.uint8),
            field=MultiModalSharedField(1),
        )
        return MultiModalKwargsItem.from_elems([mm_elem])

    @staticmethod
    def from_elems(elems: Sequence[MultiModalFieldElem]):
        """Build an item keyed by each element's `key` (a later element
        replaces an earlier one with the same key)."""
        return MultiModalKwargsItem({elem.key: elem for elem in elems})

    def __init__(self, data: Mapping[str, MultiModalFieldElem] | None = None) -> None:
        # `None` replaces a shared mutable `{}` default; `UserDict` treats a
        # `None` argument as "no initial data".
        super().__init__(data)

        modalities = {elem.modality for elem in self.values()}
        # This also rejects an empty item, which has no modality at all.
        assert len(modalities) == 1, f"Found different modalities={modalities}"
        self._modality = next(iter(modalities))

    @property
    def modality(self) -> str:
        """The modality shared by all elements of this item."""
        return self._modality

    def get_data(self) -> dict[str, NestedTensors]:
        """Return the `data` of each element, keyed by field key."""
        return {key: elem.data for key, elem in self.items()}
752
753


754
755
756
# Element type of `MultiModalKwargsItems`: items that are always present, or
# items that may be `None` (see `MultiModalKwargsItems.require_data`).
_I = TypeVar(
    "_I",
    MultiModalKwargsItem,
    MultiModalKwargsItem | None,
    default=MultiModalKwargsItem,
)


class MultiModalKwargsItems(UserDict[str, Sequence[_I]]):
    """
    A dictionary of
    [`MultiModalKwargsItem`][vllm.multimodal.inputs.MultiModalKwargsItem]s
    by modality.
    """

    @staticmethod
    def from_hf_inputs(
        hf_inputs: "BatchFeature",
        config_by_key: Mapping[str, MultiModalFieldConfig],
    ):
        """
        Split the outputs of a HuggingFace processor into per-item
        [`MultiModalKwargsItem`][vllm.multimodal.inputs.MultiModalKwargsItem]s.

        Each key of `config_by_key` whose value in `hf_inputs` is not `None`
        is split into elements by its config. Within one modality, the
        elements at the same index across all keys form one item.

        Raises:
            ValueError: If the keys of one modality produce different
                numbers of elements.
        """
        # NOTE: This skips fields in `hf_inputs` that are not in `config_by_key`
        # We assume that those fields are not used in vLLM
        elems_by_key = dict[str, Sequence[MultiModalFieldElem]]()
        # Keys are collected in a list rather than a set so that they keep the
        # order of `config_by_key`. Iterating a `set[str]` depends on string
        # hash randomization, which would make the key order of the built
        # items differ between processes. Mapping keys are unique, so no key
        # is appended twice.
        keys_by_modality = defaultdict[str, list[str]](list)
        for key, config in config_by_key.items():
            batch = hf_inputs.get(key)
            if batch is not None:
                elems = config.build_elems(key, batch)
                if len(elems) > 0:
                    elems_by_key[key] = elems
                    keys_by_modality[config.modality].append(key)

        items = list[MultiModalKwargsItem]()
        for modality, keys in keys_by_modality.items():
            elems_in_modality = {k: elems_by_key[k] for k in keys}
            batch_sizes = {k: len(v) for k, v in elems_in_modality.items()}

            if len(set(batch_sizes.values())) > 1:
                raise ValueError(
                    f"Cannot merge different batch sizes for {modality=}! "
                    f"Found: {batch_sizes=}"
                )

            batch_size = next(iter(batch_sizes.values()))
            for item_idx in range(batch_size):
                elems = [v[item_idx] for v in elems_in_modality.values()]
                items.append(MultiModalKwargsItem.from_elems(elems))

        return MultiModalKwargsItems.from_seq(items)

    @staticmethod
    def from_seq(items: Sequence[MultiModalKwargsItem]):
        """Group `items` by their `modality`."""
        items_by_modality = full_groupby(items, key=lambda x: x.modality)
        return MultiModalKwargsItems(items_by_modality)

    def __getitem__(self, modality: str) -> Sequence[_I]:
        """Like `dict.__getitem__`, but the `KeyError` message lists the
        available modalities."""
        if modality not in self:
            raise KeyError(
                f"Modality {modality!r} not found. "
                f"Available modalities: {set(self.keys())}"
            )

        return super().__getitem__(modality)  # type: ignore[return-value]

    def require_data(self) -> "MultiModalKwargsItems[MultiModalKwargsItem]":
        """
        Return `self`, typed as having every item present.

        Raises:
            RuntimeError: If any item is `None`.
        """
        for modality, items in self.items():
            for i, item in enumerate(items):
                if item is None:
                    raise RuntimeError(f"Found empty mm_items[{modality}][{i}]")

        return self  # type: ignore[return-value]

    def get_data(self, *, pin_memory: bool = False) -> "MultiModalKwargs":
        """
        Merge the elements of all items, key by key, into one
        `MultiModalKwargs`. Each key is reduced with the field of its first
        element.

        Raises:
            RuntimeError: If any item is `None`.
        """
        elems_by_key = defaultdict[str, list[MultiModalFieldElem]](list)
        for modality, items in self.items():
            for i, item in enumerate(items):
                if item is None:
                    raise RuntimeError(
                        f"Cannot build data from empty mm_items[{modality}][{i}]"
                    )

                for key, elem in item.items():
                    elems_by_key[key].append(elem)

        return MultiModalKwargs(
            {
                key: elems[0].field.reduce_data(elems, pin_memory=pin_memory)
                for key, elems in elems_by_key.items()
            }
        )
844

845

846
847
848
849
# Either fully populated items, or items that may be `None`
# (see `MultiModalKwargsItems.require_data`).
MultiModalKwargsOptionalItems: TypeAlias = Union[
    MultiModalKwargsItems[MultiModalKwargsItem],
    MultiModalKwargsItems[MultiModalKwargsItem | None],
]
850
851


852
853
854
855
856
857
858
class MultiModalKwargs(UserDict[str, NestedTensors]):
    """
    A dictionary that represents the keyword arguments to
    [`torch.nn.Module.forward`][].
    """

    @staticmethod
    @deprecated(
        "`MultiModalKwargs.from_hf_inputs` is deprecated and "
        "will be removed in v0.13. "
        "Please use `MultiModalKwargsItems.from_hf_inputs` and "
        "access the tensor data using `.get_data()`."
    )
    def from_hf_inputs(
        hf_inputs: "BatchFeature",
        config_by_key: Mapping[str, MultiModalFieldConfig],
    ):
        """Deprecated; equivalent to
        `MultiModalKwargsItems.from_hf_inputs(hf_inputs, config_by_key).get_data()`.
        """
        return MultiModalKwargsItems.from_hf_inputs(hf_inputs, config_by_key).get_data()

    @staticmethod
    @deprecated(
        "`MultiModalKwargs.from_items` is deprecated and "
        "will be removed in v0.13. "
        "Please use `MultiModalKwargsItems.from_seq` and "
        "access the tensor data using `.get_data()`."
    )
    def from_items(
        items: Sequence[MultiModalKwargsItem],
        *,
        pin_memory: bool = False,
    ):
        """Deprecated; equivalent to
        `MultiModalKwargsItems.from_seq(items).get_data(pin_memory=pin_memory)`.
        """
        return MultiModalKwargsItems.from_seq(items).get_data(pin_memory=pin_memory)

    @staticmethod
    def _try_stack(
        nested_tensors: NestedTensors, pin_memory: bool = False
    ) -> NestedTensors:
        """
        Stack the inner dimensions that have the same shape in
        a nested list of tensors.

        Thus, a dimension represented by a list means that the inner
        dimensions are different for each element along that dimension.
        """
        if isinstance(nested_tensors, torch.Tensor):
            return nested_tensors

        # TODO: Remove these once all models have been migrated
        if isinstance(nested_tensors, np.ndarray):
            return torch.from_numpy(nested_tensors)
        if isinstance(nested_tensors, (int, float)):
            return torch.tensor(nested_tensors)

        stacked = [MultiModalKwargs._try_stack(t, pin_memory) for t in nested_tensors]
        if not stacked:
            # An empty list has nothing to stack. Return it as-is rather than
            # reaching the `tensors_[0]` indexing below, which would raise
            # IndexError.
            return stacked

        if not is_list_of(stacked, torch.Tensor, check="all"):
            # Only tensors (not lists) can be stacked.
            return stacked

        tensors_ = cast(list[torch.Tensor], stacked)
        if len(tensors_) == 1:
            # An optimization when `tensors_` contains only one tensor:
            # - produce exactly same result as `torch.stack(tensors_)`
            # - will achieve zero-copy if the tensor is contiguous
            # NOTE(review): this fast path does not apply `pin_memory`.
            return tensors_[0].unsqueeze(0).contiguous()

        first = tensors_[0]
        if any(t.shape != first.shape for t in tensors_):
            # The tensors have incompatible shapes and can't be stacked.
            return tensors_

        outputs = torch.empty(
            len(tensors_),
            *first.shape,
            dtype=first.dtype,
            device=first.device,
            pin_memory=pin_memory,
        )
        return torch.stack(tensors_, out=outputs)

    @staticmethod
    def batch(
        inputs_list: list["MultiModalKwargs"], pin_memory: bool = False
    ) -> BatchedTensorInputs:
        """
        Batch multiple inputs together into a dictionary.

        The resulting dictionary has the same keys as the inputs.
        If the corresponding value from each input is a tensor and they all
        share the same shape, the output value is a single batched tensor;
        otherwise, the output value is a list containing the original value
        from each input.
        """
        if not inputs_list:
            return {}

        # We need to consider the case where each item in the batch
        # contains different modalities (i.e. different keys).
        item_lists = defaultdict[str, list[NestedTensors]](list)

        for inputs in inputs_list:
            for k, v in inputs.items():
                item_lists[k].append(v)

        return {
            k: MultiModalKwargs._try_stack(item_list, pin_memory)
            for k, item_list in item_lists.items()
        }

    @staticmethod
    def as_kwargs(
        batched_inputs: BatchedTensorInputs,
        *,
        device: "torch.types.Device",
    ) -> BatchedTensorInputs:
        """
        Move every leaf of `batched_inputs` to `device` using
        `.to(device=device, non_blocking=True)`.

        The nested structure is preserved via `json_map_leaves`.
        """
        return json_map_leaves(
            lambda x: x.to(device=device, non_blocking=True),
            batched_inputs,
        )

    def __getitem__(self, key: str):
        """
        Return the value stored for `key`.

        Raises:
            KeyError: If `key` is missing. The message lists the available
                keys.
        """
        if key not in self:
            raise KeyError(
                f"Keyword argument {key!r} not found. "
                f"Available keys: {set(self.keys())}"
            )

        return super().__getitem__(key)

    def __eq__(self, other: object) -> bool:
        """
        Two instances are equal when they have the same class, exactly the
        same keys, and `nested_tensors_equal` holds for every key.
        """
        if not isinstance(other, self.__class__):
            return False

        # Compare key sets first so that equality is symmetric. Otherwise a
        # key present only in `other` would be ignored.
        if self.keys() != other.keys():
            return False

        return all(nested_tensors_equal(self[k], other[k]) for k in self)


MultiModalPlaceholderDict: TypeAlias = Mapping[str, Sequence[PlaceholderRange]]
"""
A dictionary containing placeholder ranges for each modality.
"""


class MultiModalInputs(TypedDict):
    """
    Represents the outputs of
    [`BaseMultiModalProcessor`][vllm.multimodal.processing.BaseMultiModalProcessor],
    ready to be passed to vLLM internals.
    """

    type: Literal["multimodal"]
    """The type of inputs."""

    prompt_token_ids: list[int]
    """The processed token IDs, including placeholder tokens."""

    mm_kwargs: MultiModalKwargsOptionalItems
    """Keyword arguments to be directly passed to the model after batching."""

    mm_hashes: "MultiModalHashes"
    """The hashes of the multi-modal data."""

    mm_placeholders: "MultiModalPlaceholderDict"
    """
    For each modality, information about the placeholder tokens in
    `prompt_token_ids`.
    """

    cache_salt: NotRequired[str]
    """
    Optional cache salt to be used for prefix caching.
    """


class MultiModalEncDecInputs(MultiModalInputs):
    """
1031
1032
    Represents the outputs of
    [`EncDecMultiModalProcessor`][vllm.multimodal.processing.EncDecMultiModalProcessor]
1033
1034
1035
1036
1037
    ready to be passed to vLLM internals.
    """

    encoder_prompt_token_ids: list[int]
    """The processed token IDs of the encoder prompt."""