inputs.py 30.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from abc import ABC, abstractmethod
5
from collections import UserDict, defaultdict
6
from collections.abc import Mapping, Sequence
7
from dataclasses import dataclass
8
from functools import partial
9
from itertools import accumulate
10
11
12
13
14
15
16
17
18
19
20
from typing import (
    TYPE_CHECKING,
    Any,
    Literal,
    Optional,
    TypeAlias,
    TypedDict,
    Union,
    cast,
    final,
)
21
22

import numpy as np
23
from typing_extensions import NotRequired, TypeVar, deprecated
24

25
from vllm.utils.collection_utils import full_groupby, is_list_of
26
from vllm.utils.import_utils import LazyLoader
27
from vllm.utils.jsontree import json_map_leaves
28

29
if TYPE_CHECKING:
30
31
32
33
34
    import torch
    import torch.types
    from PIL.Image import Image
    from transformers.feature_extraction_utils import BatchFeature

35
    from .base import MediaWithBytes
36
37
    from .processing import MultiModalHashes

38
39
else:
    torch = LazyLoader("torch", globals(), "torch")
40

41
42
_T = TypeVar("_T")

43
HfImageItem: TypeAlias = Union["Image", np.ndarray, "torch.Tensor"]
44
"""
45
A `transformers.image_utils.ImageInput` representing a single image
46
item, which can be passed to a HuggingFace `ImageProcessor`.
47
48
"""

49
50
51
HfVideoItem: TypeAlias = Union[
    list["Image"], np.ndarray, "torch.Tensor", list[np.ndarray], list["torch.Tensor"]
]
52
"""
53
A `transformers.image_utils.VideoInput` representing a single video
54
item, which can be passed to a HuggingFace `VideoProcessor`.
55
56
"""

57
HfAudioItem: TypeAlias = Union[list[float], np.ndarray, "torch.Tensor"]
58
"""
59
Represents a single audio
60
item, which can be passed to a HuggingFace `AudioProcessor`.
61
62
"""

ImageItem: TypeAlias = Union[HfImageItem, "torch.Tensor", "MediaWithBytes[HfImageItem]"]
"""
A `transformers.image_utils.ImageInput` representing a single image
item, which can be passed to a HuggingFace `ImageProcessor`.
The image may also be wrapped in `MediaWithBytes`.

Alternatively, a 3-D tensor or batch of 2-D tensors,
which are treated as image embeddings;
these are directly passed to the model without HF processing.
"""

VideoItem: TypeAlias = Union[
    HfVideoItem, "torch.Tensor", tuple[HfVideoItem, dict[str, Any]]
]
"""
A `transformers.video_utils.VideoInput` representing a single video item.
This can be passed to a HuggingFace `VideoProcessor`
with `transformers.video_utils.VideoMetadata`, in which case the item is given
as a tuple `(video, metadata)` where `metadata` is a `dict`.

Alternatively, a 3-D tensor or batch of 2-D tensors,
which are treated as video embeddings;
these are directly passed to the model without HF processing.
"""

AudioItem: TypeAlias = Union[HfAudioItem, tuple[np.ndarray, float], "torch.Tensor"]
"""
Represents a single audio item, which can be passed to a HuggingFace
`AudioProcessor`.

Alternatively, a tuple `(audio, sampling_rate)`, where the sampling rate
is different from that expected by the model;
these are resampled to the model's sampling rate before being processed by HF.

Alternatively, a 3-D tensor or batch of 2-D tensors,
which are treated as audio embeddings;
these are directly passed to the model without HF processing.
"""

ModalityData: TypeAlias = _T | list[_T | None] | None
"""
Either a single data item, or a list of data items.

The entry (or an item inside the list) can only be `None` if a corresponding
UUID is provided.

The number of data items allowed per modality is restricted by
`--limit-mm-per-prompt`.
"""


@final
class MultiModalDataBuiltins(TypedDict, total=False):
    """Type annotations for modality types predefined by vLLM."""

    image: ModalityData[ImageItem]
    """The input image(s)."""

    video: ModalityData[VideoItem]
    """The input video(s)."""

    audio: ModalityData[AudioItem]
    """The input audio(s)."""


MultiModalDataDict: TypeAlias = Mapping[str, ModalityData[Any]]
"""
A dictionary containing an entry for each modality type to input.

The built-in modalities are defined by
[`MultiModalDataBuiltins`][vllm.multimodal.inputs.MultiModalDataBuiltins].
"""

132
MultiModalUUIDDict: TypeAlias = Mapping[str, list[str | None] | str]
133
134
135
136
137
138
139
140
141
"""
A dictionary containing user-provided UUIDs for items in each modality.
If a UUID for an item is not provided, its entry will be `None` and
MultiModalHasher will compute a hash for the item.

The UUID will be used to identify the item for all caching purposes
(input processing caching, embedding caching, prefix caching, etc).
"""

@dataclass(frozen=True)
class PlaceholderRange:
    """
    Placeholder location information for multi-modal data.

    Example:

    Prompt: `AAAA BBBB What is in these images?`

    Images A and B will have:

    ```
    A: PlaceholderRange(offset=0, length=4)
    B: PlaceholderRange(offset=5, length=4)
    ```
    """

    offset: int
    """The start index of the placeholder in the prompt."""

    length: int
    """The length of the placeholder."""

    is_embed: Optional["torch.Tensor"] = None
    """
    A boolean mask of shape `(length,)` indicating which positions
    between `offset` and `offset + length` to assign embeddings to.
    """

    def get_num_embeds(self) -> int:
        """Return how many positions in this range receive embeddings."""
        if self.is_embed is None:
            return self.length

        return int(self.is_embed.sum().item())

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, self.__class__):
            return False
        if (self.offset, self.length) != (other.offset, other.length):
            return False

        # The masks match only if both are absent or both hold equal values.
        if self.is_embed is None or other.is_embed is None:
            return self.is_embed is None and other.is_embed is None

        return nested_tensors_equal(self.is_embed, other.is_embed)

    def __hash__(self) -> int:
        # `__eq__` compares `is_embed` by value, but tensors hash by identity,
        # so the mask is left out to keep equal instances hashing equally.
        return hash((self.offset, self.length))

191

192
193
194
195
196
197
NestedTensors: TypeAlias = Union[
    list["NestedTensors"],
    list["torch.Tensor"],
    "torch.Tensor",
    tuple["torch.Tensor", ...],
]
198
199
200
201
"""
Uses a list instead of a tensor if the dimensions of each element do not match.
"""

202
203

def nested_tensors_equal(a: NestedTensors, b: NestedTensors) -> bool:
204
205
206
207
    """
    Equality check between
    [`NestedTensors`][vllm.multimodal.inputs.NestedTensors] objects.
    """
208
    if isinstance(a, torch.Tensor):
209
        return isinstance(b, torch.Tensor) and torch.equal(a, b)
210
    elif isinstance(b, torch.Tensor):
211
        return isinstance(a, torch.Tensor) and torch.equal(b, a)
212
213

    if isinstance(a, list):
214
215
216
        return isinstance(b, list) and all(
            nested_tensors_equal(a_, b_) for a_, b_ in zip(a, b)
        )
217
    if isinstance(b, list):
218
219
220
        return isinstance(a, list) and all(
            nested_tensors_equal(b_, a_) for b_, a_ in zip(b, a)
        )
221
222
223
224
225

    # Both a and b are scalars
    return a == b


def _nested_tensors_h2d(
    tensors: NestedTensors,
    device: torch.types.Device,
) -> NestedTensors:
    """
    Move every tensor leaf of `tensors` to `device`.

    Tensor leaves are transferred with `non_blocking=True`; all other leaves
    are returned unchanged. When `device` is `None`, `tensors` is returned
    as-is without being traversed.
    """
    if device is not None:

        def _to_device(leaf):
            # Only tensors can be moved; any other leaf passes through.
            if isinstance(leaf, torch.Tensor):
                return leaf.to(device=device, non_blocking=True)
            return leaf

        return json_map_leaves(_to_device, tensors)

    return tensors


BatchedTensorInputs: TypeAlias = dict[str, NestedTensors]
"""
A dictionary containing nested tensors which have been batched via
[`MultiModalKwargsItems.get_data`][vllm.multimodal.inputs.MultiModalKwargsItems.get_data].
"""


def batched_tensors_equal(a: BatchedTensorInputs, b: BatchedTensorInputs) -> bool:
    """
    Equality check between
    [`BatchedTensorInputs`][vllm.multimodal.inputs.BatchedTensorInputs] objects.

    Both inputs must contain exactly the same keys, and the values under each
    key must be equal according to `nested_tensors_equal`.
    """
    # Compare the key sets first so that keys present only in `b` are not
    # silently ignored (which would make the check asymmetric).
    if a.keys() != b.keys():
        return False

    return all(nested_tensors_equal(a[k], b[k]) for k in a)


@dataclass
class MultiModalFeatureSpec:
    """
    Represents a single multimodal input with its processed data and metadata.

    Used by the V1 engine to track multimodal data through processing and
    caching. A request containing multiple multimodal items will have one
    MultiModalFeatureSpec per item.
    """

    data: Optional["MultiModalKwargsItem"]
    """Multimodal data for this feature (may be `None`)."""

    modality: str
    """Based on the input, e.g., "image", "audio", "video"."""

    identifier: str
    """mm_hash or uuid for caching encoder outputs."""

    mm_position: PlaceholderRange
    """e.g., PlaceholderRange(offset=2, length=336)"""

    @staticmethod
    def gather_kwargs(
        features: list["MultiModalFeatureSpec"],
        keys: set[str],
    ) -> dict[str, list[NestedTensors]]:
        """
        Collect the tensor data stored under `keys` across `features`.

        Features whose `data` is `None` are skipped, as are keys that a
        feature's item does not contain.

        Returns:
            A mapping from each key to the list of tensor data found for it,
            in feature order. Keys with no data at all are omitted.
        """
        kwargs = defaultdict[str, list[NestedTensors]](list)

        for f in features:
            item = f.data
            if item is None:
                continue
            for k in keys:
                if k in item:
                    kwargs[k].append(item[k].data)

        return dict(kwargs)

@dataclass
class MultiModalFieldElem:
    """
    Represents a keyword argument corresponding to a multi-modal item
    in [`MultiModalKwargs`][vllm.multimodal.inputs.MultiModalKwargs].
    """

    modality: str
    """
    The modality of the corresponding multi-modal item.
    Each multi-modal item can consist of multiple keyword arguments.
    """

    key: str
    """
    The key of this field in
    [`MultiModalKwargs`][vllm.multimodal.inputs.MultiModalKwargs],
    i.e. the name of the keyword argument to be passed to the model.
    """

    data: NestedTensors
    """
    The tensor data of this field in
    [`MultiModalKwargs`][vllm.multimodal.inputs.MultiModalKwargs],
    i.e. the value of the keyword argument to be passed to the model.

    It may be set to `None` if it is determined that the item is cached
    in `EngineCore`.
    """

    field: "BaseMultiModalField"
    """
    Defines how to combine the tensor data of this field with others
    in order to batch multi-modal items together for model inference.
    """

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, self.__class__):
            return False

        # `data` may be `None`; the data only matches when both sides are
        # `None` or both hold equal tensors.
        if self.data is None or other.data is None:
            data_equal = self.data is None and other.data is None
        else:
            data_equal = nested_tensors_equal(self.data, other.data)

        return (
            (self.modality, self.key) == (other.modality, other.key)
            and data_equal
            and type(self.field) is type(other.field)
        )  # noqa: E721


@dataclass(frozen=True, kw_only=True)
class BaseMultiModalField(ABC):
    """
    Defines how to interpret tensor data belonging to a keyword argument in
    [`MultiModalKwargs`][vllm.multimodal.inputs.MultiModalKwargs] for multiple
    multi-modal items, and vice versa.
    """

    keep_on_cpu: bool = False
    """
    If `True`, then this field is excluded from being moved to the accelerator
    when `MultiModalKwargsItems.get_data()` is called to batch the data.
    """

    def _field_factory(self, *, modality: str, key: str):
        """
        Return a callable that wraps data into a
        [`MultiModalFieldElem`][vllm.multimodal.inputs.MultiModalFieldElem]
        bound to this field and the given `modality` and `key`.
        """
        f = partial(
            MultiModalFieldElem,
            modality=modality,
            key=key,
            field=self,
        )

        # Allow passing data as positional argument
        def factory(data: NestedTensors) -> MultiModalFieldElem:
            return f(data=data)

        return factory

    @abstractmethod
    def build_elems(
        self,
        modality: str,
        key: str,
        data: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        """
        Construct
        [`MultiModalFieldElem`][vllm.multimodal.inputs.MultiModalFieldElem]
        instances to represent the provided data.

        This is the inverse of
        [`reduce_data`][vllm.multimodal.inputs.BaseMultiModalField.reduce_data].
        """
        raise NotImplementedError

    @abstractmethod
    def _reduce_data(
        self,
        batch: list[NestedTensors],
        *,
        pin_memory: bool,
    ) -> NestedTensors:
        """
        Merge the raw data of a batch of elements into a single value.

        Implemented by subclasses; device placement is handled afterwards by
        [`reduce_data`][vllm.multimodal.inputs.BaseMultiModalField.reduce_data].
        """
        raise NotImplementedError

    def reduce_data(
        self,
        elems: list[MultiModalFieldElem],
        *,
        device: torch.types.Device = None,
        pin_memory: bool = False,
    ) -> NestedTensors:
        """
        Merge the data from multiple instances of
        [`MultiModalFieldElem`][vllm.multimodal.inputs.MultiModalFieldElem].

        This is the inverse of
        [`build_elems`][vllm.multimodal.inputs.BaseMultiModalField.build_elems].

        Args:
            elems: The elements to merge; they must all use the same field type.
            device: If not `None`, tensor leaves of the merged data are moved to
                this device (or to the CPU when `keep_on_cpu` is set).
            pin_memory: Forwarded to `_reduce_data` to request pinned output;
                forced to `False` when `keep_on_cpu` is set.

        Raises:
            ValueError: If `elems` use different field types.
        """
        field_types = [type(item.field) for item in elems]
        if len(set(field_types)) > 1:
            raise ValueError(f"Cannot merge different {field_types=}")

        # Fields marked `keep_on_cpu` are neither moved off the host nor pinned.
        if device is not None and self.keep_on_cpu:
            device = "cpu"
        if pin_memory and self.keep_on_cpu:
            pin_memory = False

        batch = [elem.data for elem in elems]
        out = self._reduce_data(batch, pin_memory=pin_memory)
        return _nested_tensors_h2d(out, device=device)


@dataclass(frozen=True, kw_only=True)
class MultiModalBatchedField(BaseMultiModalField):
    """
    Info:
        [`MultiModalFieldConfig.batched`][vllm.multimodal.inputs.MultiModalFieldConfig.batched]
    """

    def build_elems(
        self,
        modality: str,
        key: str,
        data: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        field_factory = self._field_factory(modality=modality, key=key)
        # Each entry along the first dimension of `data` is one item.
        return [field_factory(item) for item in data]

    def _reduce_data(
        self,
        batch: list[NestedTensors],
        *,
        pin_memory: bool,
    ) -> NestedTensors:
        if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"):
            batch = cast(list[torch.Tensor], batch)
            if len(batch) == 1:
                # An optimization when `batch` contains only one tensor:
                # - produce exactly same result as `torch.stack(batch)`
                # - will achieve zero-copy if the tensor is contiguous
                return batch[0].unsqueeze(0).contiguous()

            first_shape = batch[0].shape
            if all(elem.shape == first_shape for elem in batch):
                # Stack into a preallocated buffer so it can be pinned.
                out = torch.empty(
                    (len(batch), *first_shape),
                    dtype=batch[0].dtype,
                    device=batch[0].device,
                    pin_memory=pin_memory,
                )
                return torch.stack(batch, out=out)

        # Tensors of mismatched shapes (or non-tensor data) cannot be stacked,
        # so the batch is returned as a list.
        return batch


@dataclass(frozen=True, kw_only=True)
class MultiModalFlatField(BaseMultiModalField):
    """
    Info:
        [`MultiModalFieldConfig.flat`][vllm.multimodal.inputs.MultiModalFieldConfig.flat]
        [`MultiModalFieldConfig.flat_from_sizes`][vllm.multimodal.inputs.MultiModalFieldConfig.flat_from_sizes]
    """

    slices: Sequence[slice] | Sequence[Sequence[slice]]
    """
    For each multi-modal item, a slice (or a sequence of slices, one per
    leading dimension) that extracts its data from the underlying data.
    """

    dim: int = 0
    """The dimension along which the items' data is sliced and concatenated."""

    def build_elems(
        self,
        modality: str,
        key: str,
        data: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        field_factory = self._field_factory(modality=modality, key=key)
        if not is_list_of(self.slices, slice, check="all"):
            # Indexing with a tuple of slices is only supported by tensors.
            assert isinstance(data, torch.Tensor), (
                "torch.Tensor is required for multiple slices"
            )
        return [field_factory(data[cast(slice, s)]) for s in self.slices]

    def _reduce_data(
        self,
        batch: list[NestedTensors],
        *,
        pin_memory: bool,
    ) -> NestedTensors:
        if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"):
            batch = cast(list[torch.Tensor], batch)
            if len(batch) == 1:
                # An optimization when `batch` contains only one tensor:
                # - produce exactly same result as `torch.concat(batch)`
                # - will achieve zero-copy if the tensor is contiguous
                return batch[0].contiguous()

            # Normalize a negative `dim` so it can be used for shape slicing.
            dim = self.dim + (self.dim < 0) * len(batch[0].shape)

            def _shape_before_after(tensor: torch.Tensor):
                return tensor.shape[:dim], tensor.shape[dim + 1 :]

            first_shape = _shape_before_after(batch[0])

            # Concatenation requires every dimension except `dim` to agree.
            if all(_shape_before_after(elem) == first_shape for elem in batch):
                shape_before, shape_after = first_shape
                shape_concat = sum(item.shape[dim] for item in batch)
                out = torch.empty(
                    (*shape_before, shape_concat, *shape_after),
                    dtype=batch[0].dtype,
                    device=batch[0].device,
                    pin_memory=pin_memory,
                )
                return torch.concat(batch, dim=self.dim, out=out)

        # Otherwise flatten one level of nesting into a single list.
        assert self.dim == 0, "dim == 0 is required for nested list"
        return [e for elem in batch for e in elem]


@dataclass(frozen=True, kw_only=True)
class MultiModalSharedField(BaseMultiModalField):
    """
    Info:
        [`MultiModalFieldConfig.shared`][vllm.multimodal.inputs.MultiModalFieldConfig.shared]
    """

    batch_size: int
    """The number of multi-modal items which share this data."""

    def build_elems(
        self,
        modality: str,
        key: str,
        data: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        field_factory = self._field_factory(modality=modality, key=key)
        # Every item refers to the same element instance.
        return [field_factory(data)] * self.batch_size

    def _reduce_data(
        self,
        batch: list[NestedTensors],
        *,
        pin_memory: bool,
    ) -> NestedTensors:
        # All elements carry the same data, so the first one represents them.
        return batch[0]


@dataclass(frozen=True)
class MultiModalFieldConfig:
    """
    Describes how the data of one keyword argument is split into
    per-item elements for a given modality.
    """

    @staticmethod
    def batched(modality: str, *, keep_on_cpu: bool = False):
        """
        Defines a field where an element in the batch is obtained by
        indexing into the first dimension of the underlying data.

        Args:
            modality: The modality of the multi-modal item that uses this
                keyword argument.
            keep_on_cpu: Whether to keep this field on the CPU for the model inputs.

        Example:

        ```
        Input:
            Data: [[AAAA]
                [BBBB]
                [CCCC]]

        Output:
            Element 1: [AAAA]
            Element 2: [BBBB]
            Element 3: [CCCC]
        ```
        """
        return MultiModalFieldConfig(
            field=MultiModalBatchedField(keep_on_cpu=keep_on_cpu),
            modality=modality,
        )

    @staticmethod
    def flat(
        modality: str,
        slices: Sequence[slice] | Sequence[Sequence[slice]],
        dim: int = 0,
        *,
        keep_on_cpu: bool = False,
    ):
        """
        Defines a field where an element in the batch is obtained by
        slicing along dimension `dim` (the first dimension by default)
        of the underlying data.

        Args:
            modality: The modality of the multi-modal item that uses this
                keyword argument.
            slices: For each multi-modal item, a slice (dim=0) or a tuple of
                slices (dim>0) that is used to extract the data corresponding
                to it.
            dim: The dimension to extract data, default to 0.
            keep_on_cpu: Whether to keep this field on the CPU for the model inputs.

        Example:

        ```
        Given:
            slices: [slice(0, 3), slice(3, 7), slice(7, 9)]

        Input:
            Data: [AAABBBBCC]

        Output:
            Element 1: [AAA]
            Element 2: [BBBB]
            Element 3: [CC]
        ```

        ```
        Given:
            slices: [
                (slice(None), slice(0, 3)),
                (slice(None), slice(3, 7)),
                (slice(None), slice(7, 9))]
            dim: 1

        Input:
            Data: [[AAABBBBCC]]

        Output:
            Element 1: [[AAA]]
            Element 2: [[BBBB]]
            Element 3: [[CC]]
        ```
        """
        return MultiModalFieldConfig(
            field=MultiModalFlatField(
                slices=slices,
                dim=dim,
                keep_on_cpu=keep_on_cpu,
            ),
            modality=modality,
        )

    @staticmethod
    def flat_from_sizes(
        modality: str,
        size_per_item: "torch.Tensor",
        dim: int = 0,
        *,
        keep_on_cpu: bool = False,
    ):
        """
        Defines a field where an element in the batch is obtained by
        slicing along dimension `dim` (the first dimension by default)
        of the underlying data.

        Args:
            modality: The modality of the multi-modal item that uses this
                keyword argument.
            size_per_item: For each multi-modal item, the size of the slice
                that is used to extract the data corresponding to it.
            dim: The dimension to slice, default to 0.
            keep_on_cpu: Whether to keep this field on the CPU for the model inputs.

        Raises:
            ValueError: If `size_per_item` is not a 1-D tensor.

        Example:

        ```
        Given:
            size_per_item: [3, 4, 2]

        Input:
            Data: [AAABBBBCC]

        Output:
            Element 1: [AAA]
            Element 2: [BBBB]
            Element 3: [CC]
        ```

        ```
        Given:
            size_per_item: [3, 4, 2]
            dim: 1

        Input:
            Data: [[AAABBBBCC]]

        Output:
            Element 1: [[AAA]]
            Element 2: [[BBBB]]
            Element 3: [[CC]]
        ```

        Info:
            [`MultiModalFieldConfig.flat`][vllm.multimodal.inputs.MultiModalFieldConfig.flat]
        """

        if size_per_item.ndim != 1:
            raise ValueError(
                "size_per_item should be a 1-D tensor, "
                f"but found shape: {size_per_item.shape}"
            )

        # Boundaries of each item: [0, s0, s0 + s1, ...]
        slice_idxs = [0, *accumulate(size_per_item)]
        # Leading full slices keep every dimension before `dim` intact.
        slices = [
            (slice(None, None, None),) * dim
            + (slice(slice_idxs[i], slice_idxs[i + 1]),)
            for i in range(len(size_per_item))
        ]

        return MultiModalFieldConfig.flat(
            modality,
            slices,
            dim=dim,
            keep_on_cpu=keep_on_cpu,
        )

    @staticmethod
    def shared(
        modality: str,
        batch_size: int,
        *,
        keep_on_cpu: bool = False,
    ):
        """
        Defines a field where an element in the batch is obtained by
        taking the entirety of the underlying data.

        This means that the data is the same for each element in the batch.

        Args:
            modality: The modality of the multi-modal item that uses this
                keyword argument.
            batch_size: The number of multi-modal items which share this data.
            keep_on_cpu: Whether to keep this field on the CPU for the model inputs.

        Example:

        ```
        Given:
            batch_size: 4

        Input:
            Data: [XYZ]

        Output:
            Element 1: [XYZ]
            Element 2: [XYZ]
            Element 3: [XYZ]
            Element 4: [XYZ]
        ```
        """
        return MultiModalFieldConfig(
            field=MultiModalSharedField(
                batch_size=batch_size,
                keep_on_cpu=keep_on_cpu,
            ),
            modality=modality,
        )

    field: BaseMultiModalField
    """How the data is split into, and merged back from, per-item elements."""

    modality: str
    """The modality of the multi-modal items that use this keyword argument."""

    def build_elems(
        self,
        key: str,
        batch: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        """Split `batch` (the data under `key`) into per-item elements."""
        return self.field.build_elems(self.modality, key, batch)


class MultiModalKwargsItem(UserDict[str, MultiModalFieldElem]):
    """
    A collection of
    [`MultiModalFieldElem`][vllm.multimodal.inputs.MultiModalFieldElem]
    corresponding to a data item in
    [`MultiModalDataItems`][vllm.multimodal.parse.MultiModalDataItems].
    """

    @staticmethod
    def dummy(modality: str, nbytes: int = 1):
        """
        Convenience method for testing: build an item with a single `dummy`
        field holding `nbytes` uninitialized bytes.
        """
        mm_elem = MultiModalFieldElem(
            modality=modality,
            key="dummy",
            data=torch.empty(nbytes, dtype=torch.uint8),
            field=MultiModalSharedField(batch_size=1),
        )
        return MultiModalKwargsItem.from_elems([mm_elem])

    @staticmethod
    def from_elems(elems: Sequence[MultiModalFieldElem]):
        """Build an item that maps each element's `key` to the element."""
        return MultiModalKwargsItem({elem.key: elem for elem in elems})

    def __init__(self, data: Mapping[str, MultiModalFieldElem] | None = None) -> None:
        # `None` avoids a shared mutable default; `UserDict` treats it as empty.
        super().__init__(data)

        # All elements of one item must share a single modality. This also
        # rejects an empty item, which has no modality at all.
        modalities = {elem.modality for elem in self.values()}
        assert len(modalities) == 1, f"Found different modalities={modalities}"
        self._modality = next(iter(modalities))

    @property
    def modality(self) -> str:
        """The modality shared by every element of this item."""
        return self._modality

    def get_data(self) -> dict[str, NestedTensors]:
        """Return the raw tensor data of each element, keyed by its key."""
        return {key: elem.data for key, elem in self.items()}


# Item type stored in `MultiModalKwargsItems`. `None` entries are permitted;
# `MultiModalKwargsItems.require_data()` checks that none are present.
_I = TypeVar(
    "_I",
    MultiModalKwargsItem,
    MultiModalKwargsItem | None,
    default=MultiModalKwargsItem,
)


class MultiModalKwargsItems(UserDict[str, Sequence[_I]]):
    """
    A dictionary of
    [`MultiModalKwargsItem`][vllm.multimodal.inputs.MultiModalKwargsItem]s
    by modality.
    """

    @staticmethod
    def from_hf_inputs(
        hf_inputs: "BatchFeature",
        config_by_key: Mapping[str, MultiModalFieldConfig],
    ):
        """
        Split the outputs of a HuggingFace processor into per-item kwargs.

        Args:
            hf_inputs: The outputs of the HuggingFace processor.
            config_by_key: For each key of `hf_inputs` to keep, how to split
                its data into per-item elements.

        Raises:
            ValueError: If the keys belonging to one modality produce
                different numbers of items.
        """
        # NOTE: This skips fields in `hf_inputs` that are not in `config_by_key`
        # We assume that those fields are not used in vLLM
        elems_by_key = dict[str, Sequence[MultiModalFieldElem]]()
        # A list (rather than a set) keeps the keys in `config_by_key` order,
        # so the key order inside each built item is deterministic.
        keys_by_modality = defaultdict[str, list[str]](list)
        for key, config in config_by_key.items():
            batch = hf_inputs.get(key)
            if batch is not None:
                elems = config.build_elems(key, batch)
                if len(elems) > 0:
                    elems_by_key[key] = elems
                    keys_by_modality[config.modality].append(key)

        items = list[MultiModalKwargsItem]()
        for modality, keys in keys_by_modality.items():
            elems_in_modality = {k: elems_by_key[k] for k in keys}
            batch_sizes = {k: len(v) for k, v in elems_in_modality.items()}

            if len(set(batch_sizes.values())) > 1:
                raise ValueError(
                    f"Cannot merge different batch sizes for {modality=}! "
                    f"Found: {batch_sizes=}"
                )

            # Item `i` of a modality gathers the `i`-th element of each key.
            batch_size = next(iter(batch_sizes.values()))
            for item_idx in range(batch_size):
                elems = [v[item_idx] for v in elems_in_modality.values()]
                items.append(MultiModalKwargsItem.from_elems(elems))

        return MultiModalKwargsItems.from_seq(items)

    @staticmethod
    def from_seq(items: Sequence[MultiModalKwargsItem]):
        """Build an instance from `items`, grouped by their modality."""
        items_by_modality = full_groupby(items, key=lambda x: x.modality)
        return MultiModalKwargsItems(items_by_modality)

    def __getitem__(self, modality: str) -> Sequence[_I]:
        if modality not in self:
            raise KeyError(
                f"Modality {modality!r} not found. "
                f"Available modalities: {set(self.keys())}"
            )

        return super().__getitem__(modality)  # type: ignore[return-value]

    def require_data(self) -> "MultiModalKwargsItems[MultiModalKwargsItem]":
        """
        Check that no item is `None` and return `self` with a narrowed type.

        Raises:
            RuntimeError: If any item is `None`.
        """
        for modality, items in self.items():
            for i, item in enumerate(items):
                if item is None:
                    raise RuntimeError(f"Found empty mm_items[{modality}][{i}]")

        return self  # type: ignore[return-value]

    def get_data(
        self,
        *,
        device: torch.types.Device = None,
        pin_memory: bool = False,
    ) -> BatchedTensorInputs:
        """
        Construct a dictionary of keyword arguments to pass to the model.

        The elements sharing a key are merged by their field's `reduce_data`.

        Raises:
            RuntimeError: If any item is `None`.
        """
        elems_by_key = defaultdict[str, list[MultiModalFieldElem]](list)
        for modality, items in self.items():
            for i, item in enumerate(items):
                if item is None:
                    raise RuntimeError(
                        f"Cannot build data from empty mm_items[{modality}][{i}]"
                    )

                for key, elem in item.items():
                    elems_by_key[key].append(elem)

        data = {
            key: elems[0].field.reduce_data(
                elems,
                device=device,
                pin_memory=pin_memory,
            )
            for key, elems in elems_by_key.items()
        }

        return data

MultiModalKwargsOptionalItems: TypeAlias = (
    MultiModalKwargsItems[MultiModalKwargsItem]
    | MultiModalKwargsItems[MultiModalKwargsItem | None]
)
"""
[`MultiModalKwargsItems`][vllm.multimodal.inputs.MultiModalKwargsItems]
whose items are either all present or may include `None` entries.
"""


@deprecated("`MultiModalKwargs` is deprecated and will be removed in v0.13.")
class MultiModalKwargs(UserDict[str, NestedTensors]):
    """
    A dictionary that represents the keyword arguments to
    [`torch.nn.Module.forward`][].
    """

    @staticmethod
    @deprecated(
        "`MultiModalKwargs.from_hf_inputs` is deprecated and "
        "will be removed in v0.13. "
        "Please use `MultiModalKwargsItems.from_hf_inputs` and "
        "access the tensor data using `.get_data()`."
    )
    def from_hf_inputs(
        hf_inputs: "BatchFeature",
        config_by_key: Mapping[str, MultiModalFieldConfig],
    ):
        return MultiModalKwargsItems.from_hf_inputs(hf_inputs, config_by_key).get_data()

    @staticmethod
    @deprecated(
        "`MultiModalKwargs.from_items` is deprecated and "
        "will be removed in v0.13. "
        "Please use `MultiModalKwargsItems.from_seq` and "
        "access the tensor data using `.get_data()`."
    )
    def from_items(
        items: Sequence[MultiModalKwargsItem],
        *,
        pin_memory: bool = False,
    ):
        return MultiModalKwargsItems.from_seq(items).get_data(pin_memory=pin_memory)

    def __getitem__(self, key: str):
        if key not in self:
            raise KeyError(
                f"Keyword argument {key!r} not found. "
                f"Available keys: {set(self.keys())}"
            )

        return super().__getitem__(key)

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, self.__class__):
            return False

        # Compare the key sets first so that keys present only in `other` are
        # not silently ignored (which would make equality asymmetric).
        if self.keys() != other.keys():
            return False

        return all(nested_tensors_equal(self[k], other[k]) for k in self)

MultiModalPlaceholderDict: TypeAlias = Mapping[str, Sequence[PlaceholderRange]]
"""
A dictionary containing placeholder ranges for each modality.
"""


class MultiModalInputs(TypedDict):
    """
    Represents the outputs of
    [`BaseMultiModalProcessor`][vllm.multimodal.processing.BaseMultiModalProcessor],
    ready to be passed to vLLM internals.
    """

    type: Literal["multimodal"]
    """The type of inputs."""

    prompt_token_ids: list[int]
    """The processed token IDs, which include placeholder tokens."""

    mm_kwargs: MultiModalKwargsOptionalItems
    """Keyword arguments to be directly passed to the model after batching."""

    mm_hashes: "MultiModalHashes"
    """The hashes of the multi-modal data."""

    mm_placeholders: "MultiModalPlaceholderDict"
    """
    For each modality, information about the placeholder tokens in
    `prompt_token_ids`.
    """

    cache_salt: NotRequired[str]
    """
    Optional cache salt to be used for prefix caching.
    """

class MultiModalEncDecInputs(MultiModalInputs):
    """
    Represents the outputs of
    [`EncDecMultiModalProcessor`][vllm.multimodal.processing.EncDecMultiModalProcessor],
    ready to be passed to vLLM internals.

    In addition to the fields of `MultiModalInputs`, it carries the processed
    token IDs of the encoder prompt.
    """

    encoder_prompt_token_ids: list[int]
    """The processed token IDs of the encoder prompt."""