# vllm/multimodal/inputs.py
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from abc import ABC, abstractmethod
5
from collections import UserDict, defaultdict
6
from collections.abc import Mapping, Sequence, Set
7
from dataclasses import dataclass
8
from functools import partial
9
from itertools import accumulate
10
11
12
13
14
15
16
17
18
19
20
from typing import (
    TYPE_CHECKING,
    Any,
    Literal,
    Optional,
    TypeAlias,
    TypedDict,
    Union,
    cast,
    final,
)
21
22

import numpy as np
23
from typing_extensions import NotRequired, TypeVar, deprecated
24

25
from vllm.utils.collection_utils import full_groupby, is_list_of
26
from vllm.utils.import_utils import LazyLoader
27
from vllm.utils.jsontree import json_map_leaves
28

29
if TYPE_CHECKING:
    # Imports needed only by static type checkers; not executed at runtime.
    import torch
    import torch.types
    from PIL.Image import Image
    from transformers.feature_extraction_utils import BatchFeature

    from .base import MediaWithBytes
    from .processing import MultiModalHashes
else:
    # At runtime, bind `torch` through `LazyLoader` instead of a plain import.
    torch = LazyLoader("torch", globals(), "torch")
40

41
42
_T = TypeVar("_T")

43
HfImageItem: TypeAlias = Union["Image", np.ndarray, "torch.Tensor"]
44
"""
45
A `transformers.image_utils.ImageInput` representing a single image
46
item, which can be passed to a HuggingFace `ImageProcessor`.
47
48
"""

49
50
51
HfVideoItem: TypeAlias = Union[
    list["Image"], np.ndarray, "torch.Tensor", list[np.ndarray], list["torch.Tensor"]
]
52
"""
53
A `transformers.image_utils.VideoInput` representing a single video
54
item, which can be passed to a HuggingFace `VideoProcessor`.
55
56
"""

57
HfAudioItem: TypeAlias = Union[list[float], np.ndarray, "torch.Tensor"]
58
"""
59
Represents a single audio
60
item, which can be passed to a HuggingFace `AudioProcessor`.
61
62
"""

63
ImageItem: TypeAlias = Union[HfImageItem, "torch.Tensor", "MediaWithBytes[HfImageItem]"]
64
"""
65
A `transformers.image_utils.ImageInput` representing a single image
66
item, which can be passed to a HuggingFace `ImageProcessor`.
67
68
69
70
71
72

Alternatively, a 3-D tensor or batch of 2-D tensors,
which are treated as image embeddings;
these are directly passed to the model without HF processing.
"""

73
74
75
VideoItem: TypeAlias = Union[
    HfVideoItem, "torch.Tensor", tuple[HfVideoItem, dict[str, Any]]
]
76
"""
77
78
79
A `transformers.video_utils.VideoInput` representing a single video item. 
This can be passed to a HuggingFace `VideoProcessor` 
with `transformers.video_utils.VideoMetadata`.
80
81
82
83
84
85

Alternatively, a 3-D tensor or batch of 2-D tensors,
which are treated as video embeddings;
these are directly passed to the model without HF processing.
"""

86
AudioItem: TypeAlias = Union[HfAudioItem, tuple[np.ndarray, float], "torch.Tensor"]
87
88
"""
Represents a single audio
89
item, which can be passed to a HuggingFace `AudioProcessor`.
90
91
92
93
94
95
96
97
98
99

Alternatively, a tuple `(audio, sampling_rate)`, where the sampling rate
is different from that expected by the model;
these are resampled to the model's sampling rate before being processed by HF.

Alternatively, a 3-D tensor or batch of 2-D tensors,
which are treated as audio embeddings;
these are directly passed to the model without HF processing.
"""

100
ModalityData: TypeAlias = _T | list[_T | None] | None
101
"""
102
103
Either a single data item, or a list of data items. Can only be None if UUID
is provided.
104
105

The number of data items allowed per modality is restricted by
106
`--limit-mm-per-prompt`.
107
108
109
110
111
112
113
"""


@final
class MultiModalDataBuiltins(TypedDict, total=False):
    """
    Type annotations for modality types predefined by vLLM.

    Every key is optional (`total=False`).
    """

    image: ModalityData[ImageItem]
    """The input image(s)."""

    video: ModalityData[VideoItem]
    """The input video(s)."""

    audio: ModalityData[AudioItem]
    """The input audio(s)."""


124
MultiModalDataDict: TypeAlias = Mapping[str, ModalityData[Any]]
"""
A dictionary containing an entry for each modality type to input.

The built-in modalities are defined by
[`MultiModalDataBuiltins`][vllm.multimodal.inputs.MultiModalDataBuiltins].
"""

MultiModalUUIDDict: TypeAlias = Mapping[str, list[str | None] | str]
"""
A dictionary containing user-provided UUIDs for items in each modality.
Each value is either a single UUID or a list with one entry per item.
If a UUID for an item is not provided, its entry will be `None` and
MultiModalHasher will compute a hash for the item.

The UUID will be used to identify the item for all caching purposes
(input processing caching, embedding caching, prefix caching, etc.).
"""

142

143
144
@dataclass(frozen=True)
class PlaceholderRange:
    """
    Placeholder location information for multi-modal data.

    Example:

    Prompt: `AAAA BBBB What is in these images?`

    Images A and B will have:

    ```
    A: PlaceholderRange(offset=0, length=4)
    B: PlaceholderRange(offset=5, length=4)
    ```
    """

    offset: int
    """The start index of the placeholder in the prompt."""

    length: int
    """The length of the placeholder."""

    is_embed: Optional["torch.Tensor"] = None
    """
    A boolean mask of shape `(length,)` indicating which positions
    between `offset` and `offset + length` to assign embeddings to.
    """

    def get_num_embeds(self) -> int:
        """
        Return the number of positions that are assigned embeddings.

        This is `length` when there is no mask, otherwise the number of
        `True` entries in `is_embed`.
        """
        if self.is_embed is None:
            return self.length

        return int(self.is_embed.sum().item())

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, self.__class__):
            return False
        if (self.offset, self.length) != (other.offset, other.length):
            return False

        # A missing mask is only equal to another missing mask.
        if self.is_embed is None or other.is_embed is None:
            return self.is_embed is other.is_embed

        return nested_tensors_equal(self.is_embed, other.is_embed)

    def __hash__(self) -> int:
        # `__eq__` compares `is_embed` by value, whereas hashing a tensor is
        # identity-based. The dataclass-generated hash would therefore give
        # equal instances different hashes, so hash only the plain fields.
        return hash((self.offset, self.length))

191

192
193
194
195
196
197
NestedTensors: TypeAlias = Union[
    list["NestedTensors"],
    list["torch.Tensor"],
    "torch.Tensor",
    tuple["torch.Tensor", ...],
]
198
199
200
201
"""
Uses a list instead of a tensor if the dimensions of each element do not match.
"""

202
203

def nested_tensors_equal(a: "NestedTensors", b: "NestedTensors") -> bool:
    """
    Equality check between
    [`NestedTensors`][vllm.multimodal.inputs.NestedTensors] objects.

    - Tensors are compared with `torch.equal` (same shape and values).
    - Lists and tuples are compared element-wise. Both sides must be the same
      kind of container and have the same length.
    - Anything else falls back to `==`.
    """
    if isinstance(a, torch.Tensor) or isinstance(b, torch.Tensor):
        return (
            isinstance(a, torch.Tensor)
            and isinstance(b, torch.Tensor)
            and torch.equal(a, b)
        )

    for seq_type in (list, tuple):
        if isinstance(a, seq_type) or isinstance(b, seq_type):
            # The length check matters: `zip` would silently truncate the
            # longer sequence and report a prefix match as equality.
            return (
                isinstance(a, seq_type)
                and isinstance(b, seq_type)
                and len(a) == len(b)
                and all(nested_tensors_equal(a_, b_) for a_, b_ in zip(a, b))
            )

    # Both a and b are scalars
    return a == b


226
BatchedTensorInputs: TypeAlias = dict[str, NestedTensors]
"""
A dictionary containing nested tensors which have been batched via
[`MultiModalKwargsItems.get_data`][vllm.multimodal.inputs.MultiModalKwargsItems.get_data].
"""


233
234
235
236
237
238
239
240
241
242
243
244
245
246
def batched_tensors_equal(
    a: "BatchedTensorInputs",
    b: "BatchedTensorInputs",
) -> bool:
    """
    Equality check between
    [`BatchedTensorInputs`][vllm.multimodal.inputs.BatchedTensorInputs] objects.

    Both dictionaries must have exactly the same keys. The values under each
    key must compare equal via `nested_tensors_equal`.
    """
    # Compare the key sets first. Checking only `a`'s keys would treat `b` as
    # equal even when it carries extra entries.
    if a.keys() != b.keys():
        return False

    return all(nested_tensors_equal(a[k], b[k]) for k in a)


247
248
249
250
@dataclass
class MultiModalFeatureSpec:
    """
    Represents a single multimodal input with its processed data and metadata.

    Used by the V1 engine to track multimodal data through processing and
    caching. A request containing multiple multimodal items will have one
    MultiModalFeatureSpec per item.
    """

    data: Optional["MultiModalKwargsItem"]
    """Multimodal data for this feature"""

    modality: str
    """Based on the input, e.g., "image", "audio", "video"."""

    identifier: str
    """mm_hash or uuid for caching encoder outputs."""

    mm_position: PlaceholderRange
    """e.g., PlaceholderRange(offset=2, length=336)"""

    @staticmethod
    def gather_kwargs(
        features: list["MultiModalFeatureSpec"],
        keys: set[str],
    ) -> dict[str, list[NestedTensors]]:
        """
        Collect the tensor data of the requested keyword arguments.

        Args:
            features: The features to gather from. Features whose `data` is
                `None` are skipped.
            keys: The keyword argument names to collect. A key that is absent
                from a feature's data is skipped for that feature.

        Returns:
            A mapping from each key found in at least one feature to the list
            of its data, in the order of `features`.
        """
        kwargs = defaultdict[str, list[NestedTensors]](list)

        for f in features:
            item = f.data
            if item is None:
                continue

            for k in keys:
                if k in item:
                    kwargs[k].append(item[k].data)

        return dict(kwargs)

282

283
@dataclass
class MultiModalFieldElem:
    """
    Represents a keyword argument corresponding to a multi-modal item
    in [`MultiModalKwargs`][vllm.multimodal.inputs.MultiModalKwargs].
    """

    modality: str
    """
    The modality of the corresponding multi-modal item.
    Each multi-modal item can consist of multiple keyword arguments.
    """

    key: str
    """
    The key of this field in
    [`MultiModalKwargs`][vllm.multimodal.inputs.MultiModalKwargs],
    i.e. the name of the keyword argument to be passed to the model.
    """

    data: NestedTensors
    """
    The tensor data of this field in
    [`MultiModalKwargs`][vllm.multimodal.inputs.MultiModalKwargs],
    i.e. the value of the keyword argument to be passed to the model.

    It may be set to `None` if it is determined that the item is cached
    in `EngineCore`.
    """

    field: "BaseMultiModalField"
    """
    Defines how to combine the tensor data of this field with others
    in order to batch multi-modal items together for model inference.
    """

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, self.__class__):
            return False

        if (self.modality, self.key) != (other.modality, other.key):
            return False

        # Only the *type* of the field is compared, not its parameters.
        if type(self.field) is not type(other.field):
            return False

        # `data` may be `None`; it is only equal to another `None`.
        if self.data is None or other.data is None:
            return self.data is other.data

        return nested_tensors_equal(self.data, other.data)
335
336
337
338


@dataclass(frozen=True)
class BaseMultiModalField(ABC):
    """
    Defines how to interpret tensor data belonging to a keyword argument in
    [`MultiModalKwargs`][vllm.multimodal.inputs.MultiModalKwargs] for multiple
    multi-modal items, and vice versa.
    """

    def _field_factory(self, *, modality: str, key: str):
        """
        Return a callable that wraps tensor data into a
        [`MultiModalFieldElem`][vllm.multimodal.inputs.MultiModalFieldElem]
        tagged with `modality`, `key` and this field.
        """
        f = partial(
            MultiModalFieldElem,
            modality=modality,
            key=key,
            field=self,
        )

        # Allow passing data as positional argument
        def factory(data: NestedTensors) -> MultiModalFieldElem:
            return f(data=data)

        return factory

    @abstractmethod
    def build_elems(
        self,
        modality: str,
        key: str,
        data: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        """
        Construct
        [`MultiModalFieldElem`][vllm.multimodal.inputs.MultiModalFieldElem]
        instances to represent the provided data.

        This is the inverse of
        [`reduce_data`][vllm.multimodal.inputs.BaseMultiModalField.reduce_data].
        """
        raise NotImplementedError

    @abstractmethod
    def _reduce_data(
        self,
        batch: list[NestedTensors],
        *,
        pin_memory: bool,
    ) -> NestedTensors:
        """
        Merge the raw data of several elements into a single value.

        `pin_memory` is forwarded unchanged from `reduce_data`.
        """
        raise NotImplementedError

    def reduce_data(
        self,
        elems: list[MultiModalFieldElem],
        *,
        pin_memory: bool = False,
    ) -> NestedTensors:
        """
        Merge the data from multiple instances of
        [`MultiModalFieldElem`][vllm.multimodal.inputs.MultiModalFieldElem].

        This is the inverse of
        [`build_elems`][vllm.multimodal.inputs.BaseMultiModalField.build_elems].

        Raises:
            ValueError: If the elements do not all use the same field type.
        """
        field_types = [type(item.field) for item in elems]
        if len(set(field_types)) > 1:
            raise ValueError(f"Cannot merge different {field_types=}")

        batch = [elem.data for elem in elems]
        return self._reduce_data(batch, pin_memory=pin_memory)
404
405
406
407
408


@dataclass(frozen=True)
class MultiModalBatchedField(BaseMultiModalField):
    """
    Info:
        [`MultiModalFieldConfig.batched`][vllm.multimodal.inputs.MultiModalFieldConfig.batched]
    """

    def build_elems(
        self,
        modality: str,
        key: str,
        data: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        # One element per entry obtained by iterating over `data`.
        field_factory = self._field_factory(modality=modality, key=key)
        return [field_factory(item) for item in data]

    def _reduce_data(
        self,
        batch: list[NestedTensors],
        *,
        pin_memory: bool,
    ) -> NestedTensors:
        if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"):
            batch = cast(list[torch.Tensor], batch)
            if len(batch) == 1:
                # An optimization when `batch` contains only one tensor:
                # - produce exactly same result as `torch.stack(batch)`
                # - will achieve zero-copy if the tensor is contiguous
                # NOTE(review): this shortcut does not apply `pin_memory`.
                return batch[0].unsqueeze(0).contiguous()

            first_shape = batch[0].shape
            if all(elem.shape == first_shape for elem in batch):
                # Preallocate the output so that it can be pinned if requested.
                out = torch.empty(
                    (len(batch), *first_shape),
                    dtype=batch[0].dtype,
                    device=batch[0].device,
                    pin_memory=pin_memory,
                )
                return torch.stack(batch, out=out)

        # Mismatched shapes or non-tensor data: keep the list as-is.
        return batch


@dataclass(frozen=True)
class MultiModalFlatField(BaseMultiModalField):
    """
    Info:
        [`MultiModalFieldConfig.flat`][vllm.multimodal.inputs.MultiModalFieldConfig.flat]
        [`MultiModalFieldConfig.flat_from_sizes`][vllm.multimodal.inputs.MultiModalFieldConfig.flat_from_sizes]
    """

    slices: Sequence[slice] | Sequence[Sequence[slice]]
    dim: int = 0

    def build_elems(
        self,
        modality: str,
        key: str,
        data: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        field_factory = self._field_factory(modality=modality, key=key)
        if not is_list_of(self.slices, slice, check="all"):
            # Tuples of slices (one per dimension) can only index a tensor.
            assert isinstance(data, torch.Tensor), (
                "torch.Tensor is required for multiple slices"
            )
        return [field_factory(data[cast(slice, s)]) for s in self.slices]

    def _reduce_data(
        self,
        batch: list[NestedTensors],
        *,
        pin_memory: bool,
    ) -> NestedTensors:
        if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"):
            batch = cast(list[torch.Tensor], batch)
            if len(batch) == 1:
                # An optimization when `batch` contains only one tensor:
                # - produce exactly same result as `torch.concat(batch)`
                # - will achieve zero-copy if the tensor is contiguous
                # NOTE(review): this shortcut does not apply `pin_memory`.
                return batch[0].contiguous()

            # Normalize a negative `dim` against the rank of the tensors.
            dim = self.dim + (self.dim < 0) * len(batch[0].shape)

            def _shape_before_after(tensor: torch.Tensor):
                return tensor.shape[:dim], tensor.shape[dim + 1 :]

            first_shape = _shape_before_after(batch[0])

            # Tensors can be concatenated if they agree on every dimension
            # other than `dim`.
            if all(_shape_before_after(elem) == first_shape for elem in batch):
                shape_before, shape_after = first_shape
                shape_concat = sum(item.shape[dim] for item in batch)
                out = torch.empty(
                    (*shape_before, shape_concat, *shape_after),
                    dtype=batch[0].dtype,
                    device=batch[0].device,
                    pin_memory=pin_memory,
                )
                return torch.concat(batch, dim=dim, out=out)

        # Fallback: flatten one level, which only corresponds to dim 0.
        assert self.dim == 0, "dim == 0 is required for nested list"
        return [e for elem in batch for e in elem]
506
507


508
509
510
@dataclass(frozen=True)
class MultiModalSharedField(BaseMultiModalField):
    """
    Info:
        [`MultiModalFieldConfig.shared`][vllm.multimodal.inputs.MultiModalFieldConfig.shared]
    """

    batch_size: int

    def build_elems(
        self,
        modality: str,
        key: str,
        data: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        field_factory = self._field_factory(modality=modality, key=key)
        # The same element object is repeated once per item in the batch.
        return [field_factory(data)] * self.batch_size

    def _reduce_data(
        self,
        batch: list[NestedTensors],
        *,
        pin_memory: bool,
    ) -> NestedTensors:
        # Assumes every entry holds the same shared data (as produced by
        # `build_elems`), so only the first one is kept.
        return batch[0]


535
536
537
class MultiModalFieldConfig:
    """
    Associates a
    [`BaseMultiModalField`][vllm.multimodal.inputs.BaseMultiModalField]
    with the modality of the multi-modal items that use it.

    The static constructors below build each supported field type.
    """

    @staticmethod
    def batched(modality: str) -> "MultiModalFieldConfig":
        """
        Defines a field where an element in the batch is obtained by
        indexing into the first dimension of the underlying data.

        Args:
            modality: The modality of the multi-modal item that uses this
                keyword argument.

        Example:

        ```
        Input:
            Data: [[AAAA]
                   [BBBB]
                   [CCCC]]

        Output:
            Element 1: [AAAA]
            Element 2: [BBBB]
            Element 3: [CCCC]
        ```
        """
        return MultiModalFieldConfig(
            field=MultiModalBatchedField(),
            modality=modality,
        )

    @staticmethod
    def flat(
        modality: str,
        slices: Sequence[slice] | Sequence[Sequence[slice]],
        dim: int = 0,
    ) -> "MultiModalFieldConfig":
        """
        Defines a field where an element in the batch is obtained by
        slicing the underlying data along dimension `dim`
        (the first dimension by default).

        Args:
            modality: The modality of the multi-modal item that uses this
                keyword argument.
            slices: For each multi-modal item, a slice (dim=0) or a tuple of
                slices (dim>0) that is used to extract the data corresponding
                to it.
            dim: The dimension to extract data, default to 0.

        Example:

        ```
        Given:
            slices: [slice(0, 3), slice(3, 7), slice(7, 9)]

        Input:
            Data: [AAABBBBCC]

        Output:
            Element 1: [AAA]
            Element 2: [BBBB]
            Element 3: [CC]
        ```

        ```
        Given:
            slices: [
                (slice(None), slice(0, 3)),
                (slice(None), slice(3, 7)),
                (slice(None), slice(7, 9))]
            dim: 1

        Input:
            Data: [[A],[A],[A],[B],[B],[B],[B],[C],[C]]

        Output:
            Element 1: [[A],[A],[A]]
            Element 2: [[B],[B],[B],[B]]
            Element 3: [[C],[C]]
        ```
        """
        return MultiModalFieldConfig(
            field=MultiModalFlatField(slices=slices, dim=dim),
            modality=modality,
        )

    @staticmethod
    def flat_from_sizes(
        modality: str,
        size_per_item: "torch.Tensor",
        dim: int = 0,
    ) -> "MultiModalFieldConfig":
        """
        Defines a field where an element in the batch is obtained by
        slicing the underlying data along dimension `dim`
        (the first dimension by default).

        Args:
            modality: The modality of the multi-modal item that uses this
                keyword argument.
            size_per_item: For each multi-modal item, the size of the slice
                that is used to extract the data corresponding to it.
            dim: The dimension to slice, default to 0. This should be
                non-negative: a negative value adds no leading `slice(None)`
                entries, so the data would be sliced along the first
                dimension instead.

        Raises:
            ValueError: If `size_per_item` is not a 1-D tensor.

        Example:

        ```
        Given:
            size_per_item: [3, 4, 2]

        Input:
            Data: [AAABBBBCC]

        Output:
            Element 1: [AAA]
            Element 2: [BBBB]
            Element 3: [CC]
        ```

        ```
        Given:
            size_per_item: [3, 4, 2]
            dim: 1

        Input:
            Data: [[A],[A],[A],[B],[B],[B],[B],[C],[C]]

        Output:
            Element 1: [[A],[A],[A]]
            Element 2: [[B],[B],[B],[B]]
            Element 3: [[C],[C]]
        ```

        Info:
            [`MultiModalFieldConfig.flat`][vllm.multimodal.inputs.MultiModalFieldConfig.flat]
        """

        if size_per_item.ndim != 1:
            raise ValueError(
                "size_per_item should be a 1-D tensor, "
                f"but found shape: {size_per_item.shape}"
            )

        # Cumulative boundaries: item i spans [slice_idxs[i], slice_idxs[i + 1]).
        slice_idxs = [0, *accumulate(size_per_item)]
        leading = (slice(None, None, None),) * dim
        slices = [
            leading + (slice(start, stop),)
            for start, stop in zip(slice_idxs[:-1], slice_idxs[1:])
        ]

        return MultiModalFieldConfig.flat(modality, slices, dim=dim)

    @staticmethod
    def shared(modality: str, batch_size: int) -> "MultiModalFieldConfig":
        """
        Defines a field where an element in the batch is obtained by
        taking the entirety of the underlying data.

        This means that the data is the same for each element in the batch.

        Args:
            modality: The modality of the multi-modal item that uses this
                keyword argument.
            batch_size: The number of multi-modal items which share this data.

        Example:

        ```
        Given:
            batch_size: 4

        Input:
            Data: [XYZ]

        Output:
            Element 1: [XYZ]
            Element 2: [XYZ]
            Element 3: [XYZ]
            Element 4: [XYZ]
        ```
        """
        return MultiModalFieldConfig(
            field=MultiModalSharedField(batch_size),
            modality=modality,
        )

    def __init__(self, field: BaseMultiModalField, modality: str) -> None:
        super().__init__()

        self.field = field
        self.modality = modality

    def __repr__(self) -> str:
        return f"MultiModalFieldConfig(field={self.field}, modality={self.modality})"

    def build_elems(
        self,
        key: str,
        batch: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        """Split `batch` for keyword `key` into per-item elements."""
        return self.field.build_elems(self.modality, key, batch)
730
731


732
733
class MultiModalKwargsItem(UserDict[str, MultiModalFieldElem]):
    """
    A collection of
    [`MultiModalFieldElem`][vllm.multimodal.inputs.MultiModalFieldElem]
    corresponding to a data item in
    [`MultiModalDataItems`][vllm.multimodal.parse.MultiModalDataItems].
    """

    @staticmethod
    def dummy(modality: str, nbytes: int = 1):
        """
        Convenience method for testing.

        Builds an item with a single shared `"dummy"` element whose data is an
        uninitialized `uint8` tensor of `nbytes` elements.
        """
        mm_elem = MultiModalFieldElem(
            modality=modality,
            key="dummy",
            data=torch.empty(nbytes, dtype=torch.uint8),
            field=MultiModalSharedField(1),
        )
        return MultiModalKwargsItem.from_elems([mm_elem])

    @staticmethod
    def from_elems(elems: Sequence[MultiModalFieldElem]):
        """
        Build an item keyed by each element's `key`.

        If several elements share a key, the last one wins.
        """
        return MultiModalKwargsItem({elem.key: elem for elem in elems})

    def __init__(
        self,
        data: Mapping[str, MultiModalFieldElem] | None = None,
    ) -> None:
        # `None` instead of a mutable `{}` default; `UserDict` treats both as
        # "start empty".
        super().__init__(data)

        # All elements of one item must belong to exactly one modality.
        modalities = {elem.modality for elem in self.values()}
        assert len(modalities) == 1, f"Found different modalities={modalities}"

        self._modality = next(iter(modalities))

    @property
    def modality(self) -> str:
        """The modality shared by every element of this item."""
        return self._modality

    def get_data(self) -> dict[str, NestedTensors]:
        """Return the raw data of each element, keyed by keyword argument."""
        return {key: elem.data for key, elem in self.items()}
768
769


770
771
772
# Item type of `MultiModalKwargsItems`: either always-present items, or items
# that may be `None` (see `MultiModalKwargsItems.require_data`).
_I = TypeVar(
    "_I",
    MultiModalKwargsItem,
    MultiModalKwargsItem | None,
    default=MultiModalKwargsItem,
)


class MultiModalKwargsItems(UserDict[str, Sequence[_I]]):
    """
    A dictionary of
    [`MultiModalKwargsItem`][vllm.multimodal.inputs.MultiModalKwargsItem]s
    by modality.
    """

    @staticmethod
    def from_hf_inputs(
        hf_inputs: "BatchFeature",
        config_by_key: Mapping[str, MultiModalFieldConfig],
    ):
        """
        Split the outputs of a HuggingFace processor into per-item
        [`MultiModalKwargsItem`][vllm.multimodal.inputs.MultiModalKwargsItem]s.

        Args:
            hf_inputs: The processor outputs, looked up by key.
            config_by_key: For each key to keep, how to split its data.

        Raises:
            ValueError: If keys of the same modality yield different numbers
                of items.
        """
        # NOTE: This skips fields in `hf_inputs` that are not in `config_by_key`
        # We assume that those fields are not used in vLLM
        elems_by_key = dict[str, Sequence[MultiModalFieldElem]]()
        keys_by_modality = defaultdict[str, set[str]](set)
        for key, config in config_by_key.items():
            batch = hf_inputs.get(key)
            if batch is not None:
                elems = config.build_elems(key, batch)
                if len(elems) > 0:
                    elems_by_key[key] = elems
                    keys_by_modality[config.modality].add(key)

        items = list[MultiModalKwargsItem]()
        for modality, keys in keys_by_modality.items():
            elems_in_modality = {k: elems_by_key[k] for k in keys}
            batch_sizes = {k: len(v) for k, v in elems_in_modality.items()}

            if len(set(batch_sizes.values())) > 1:
                raise ValueError(
                    f"Cannot merge different batch sizes for {modality=}! "
                    f"Found: {batch_sizes=}"
                )

            # Regroup column-wise: item i takes the i-th element of every key.
            batch_size = next(iter(batch_sizes.values()))
            for item_idx in range(batch_size):
                elems = [v[item_idx] for v in elems_in_modality.values()]
                items.append(MultiModalKwargsItem.from_elems(elems))

        return MultiModalKwargsItems.from_seq(items)

    @staticmethod
    def from_seq(items: Sequence[MultiModalKwargsItem]):
        """Group `items` by their modality."""
        items_by_modality = full_groupby(items, key=lambda x: x.modality)
        return MultiModalKwargsItems(items_by_modality)

    def __getitem__(self, modality: str) -> Sequence[_I]:
        """
        Return the items of `modality`.

        Raises:
            KeyError: If the modality is missing; the message lists the
                available modalities.
        """
        if modality not in self:
            raise KeyError(
                f"Modality {modality!r} not found. "
                f"Available modalities: {set(self.keys())}"
            )

        return super().__getitem__(modality)  # type: ignore[return-value]

    def require_data(self) -> "MultiModalKwargsItems[MultiModalKwargsItem]":
        """
        Check that no item is `None` and return `self`, typed as holding only
        present items.

        Raises:
            RuntimeError: If any item is `None`.
        """
        for modality, items in self.items():
            for i, item in enumerate(items):
                if item is None:
                    raise RuntimeError(f"Found empty mm_items[{modality}][{i}]")

        return self  # type: ignore[return-value]

    def get_data(
        self,
        *,
        device: torch.types.Device = None,
        pin_memory: bool = False,
        cpu_fields: Set[str] = frozenset(),
    ) -> BatchedTensorInputs:
        """
        Construct a dictionary of keyword arguments to pass to the model.

        Elements are grouped by key across all modalities and items. Each
        group is then merged by the `reduce_data` of its first element's field.

        Args:
            device: If given, tensors are moved to this device with
                `non_blocking=True`, except for keys in `cpu_fields`.
            pin_memory: Forwarded to `reduce_data`.
            cpu_fields: Keys whose data is left where it is.

        Raises:
            RuntimeError: If any item is `None`.
        """
        elems_by_key = defaultdict[str, list[MultiModalFieldElem]](list)
        for modality, items in self.items():
            for i, item in enumerate(items):
                if item is None:
                    raise RuntimeError(
                        f"Cannot build data from empty mm_items[{modality}][{i}]"
                    )

                for key, elem in item.items():
                    elems_by_key[key].append(elem)

        data = {
            key: elems[0].field.reduce_data(elems, pin_memory=pin_memory)
            for key, elems in elems_by_key.items()
        }

        if device is not None:
            for k in data.keys() - cpu_fields:
                data[k] = json_map_leaves(
                    (
                        lambda x: x.to(device=device, non_blocking=True)
                        if isinstance(x, torch.Tensor)
                        else x
                    ),
                    data[k],
                )

        return data
878

879

880
881
882
883
# `MultiModalKwargsItems` whose entries are either all present
# (`MultiModalKwargsItem`) or may be `None` (`MultiModalKwargsItem | None`).
MultiModalKwargsOptionalItems: TypeAlias = (
    MultiModalKwargsItems[MultiModalKwargsItem]
    | MultiModalKwargsItems[MultiModalKwargsItem | None]
)
884
885


886
@deprecated("`MultiModalKwargs` is deprecated and will be removed in v0.13.")
class MultiModalKwargs(UserDict[str, NestedTensors]):
    """
    A dictionary that represents the keyword arguments to
    [`torch.nn.Module.forward`][].
    """

    @staticmethod
    @deprecated(
        "`MultiModalKwargs.from_hf_inputs` is deprecated and "
        "will be removed in v0.13. "
        "Please use `MultiModalKwargsItems.from_hf_inputs` and "
        "access the tensor data using `.get_data()`."
    )
    def from_hf_inputs(
        hf_inputs: "BatchFeature",
        config_by_key: Mapping[str, MultiModalFieldConfig],
    ):
        """Deprecated shim around `MultiModalKwargsItems.from_hf_inputs`."""
        return MultiModalKwargsItems.from_hf_inputs(hf_inputs, config_by_key).get_data()

    @staticmethod
    @deprecated(
        "`MultiModalKwargs.from_items` is deprecated and "
        "will be removed in v0.13. "
        "Please use `MultiModalKwargsItems.from_seq` and "
        "access the tensor data using `.get_data()`."
    )
    def from_items(
        items: Sequence[MultiModalKwargsItem],
        *,
        pin_memory: bool = False,
    ):
        """Deprecated shim around `MultiModalKwargsItems.from_seq`."""
        return MultiModalKwargsItems.from_seq(items).get_data(pin_memory=pin_memory)

    def __getitem__(self, key: str):
        """
        Return the value of keyword argument `key`.

        Raises:
            KeyError: If the key is missing; the message lists the
                available keys.
        """
        if key not in self:
            raise KeyError(
                f"Keyword argument {key!r} not found. "
                f"Available keys: {set(self.keys())}"
            )

        return super().__getitem__(key)

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, self.__class__):
            return False

        # Both sides must hold the same keywords. Checking only `self`'s keys
        # would ignore extra entries in `other`.
        if self.keys() != other.keys():
            return False

        return all(nested_tensors_equal(self[k], other[k]) for k in self)
940

941

942
MultiModalPlaceholderDict: TypeAlias = Mapping[str, Sequence[PlaceholderRange]]
"""
A dictionary containing placeholder ranges for each modality.
"""


948
class MultiModalInputs(TypedDict):
    """
    Represents the outputs of
    [`BaseMultiModalProcessor`][vllm.multimodal.processing.BaseMultiModalProcessor],
    ready to be passed to vLLM internals.
    """

    type: Literal["multimodal"]
    """The type of inputs."""

    prompt_token_ids: list[int]
    """The processed token IDs, which include placeholder tokens."""

    mm_kwargs: MultiModalKwargsOptionalItems
    """Keyword arguments to be directly passed to the model after batching."""

    mm_hashes: "MultiModalHashes"
    """The hashes of the multi-modal data."""

    mm_placeholders: "MultiModalPlaceholderDict"
    """
    For each modality, information about the placeholder tokens in
    `prompt_token_ids`.
    """

    cache_salt: NotRequired[str]
    """
    Optional cache salt to be used for prefix caching.
    """

978
979
980

class MultiModalEncDecInputs(MultiModalInputs):
    """
    Represents the outputs of
    [`EncDecMultiModalProcessor`][vllm.multimodal.processing.EncDecMultiModalProcessor],
    ready to be passed to vLLM internals.
    """

    encoder_prompt_token_ids: list[int]
    """The processed token IDs of the encoder prompt."""