# inputs.py 30.7 KB
# Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from abc import ABC, abstractmethod
5
from collections import UserDict, defaultdict
6
from collections.abc import Mapping, Sequence
7
from dataclasses import dataclass
8
from functools import cached_property, partial
9
from itertools import accumulate
10
11
12
13
14
15
16
17
18
19
20
from typing import (
    TYPE_CHECKING,
    Any,
    Literal,
    Optional,
    TypeAlias,
    TypedDict,
    Union,
    cast,
    final,
)
21
22

import numpy as np
23
from typing_extensions import NotRequired, TypeVar
24

25
from vllm.utils.collection_utils import full_groupby, is_list_of
26
from vllm.utils.import_utils import LazyLoader
27
from vllm.utils.jsontree import json_map_leaves
28

29
if TYPE_CHECKING:
    # These imports are only needed for static type checking / annotations.
    import torch
    import torch.types
    from PIL.Image import Image
    from transformers.feature_extraction_utils import BatchFeature

    from .media import MediaWithBytes
else:
    # At runtime `torch` is bound to the project's `LazyLoader` wrapper,
    # presumably so the real import is deferred until first attribute access
    # (NOTE(review): confirm against `vllm.utils.import_utils.LazyLoader`).
    torch = LazyLoader("torch", globals(), "torch")
38

39
40
_T = TypeVar("_T")

41
HfImageItem: TypeAlias = Union["Image", np.ndarray, "torch.Tensor"]
42
"""
43
A `transformers.image_utils.ImageInput` representing a single image
44
item, which can be passed to a HuggingFace `ImageProcessor`.
45
46
"""

47
48
49
HfVideoItem: TypeAlias = Union[
    list["Image"], np.ndarray, "torch.Tensor", list[np.ndarray], list["torch.Tensor"]
]
50
"""
51
A `transformers.image_utils.VideoInput` representing a single video
52
item, which can be passed to a HuggingFace `VideoProcessor`.
53
54
"""

55
HfAudioItem: TypeAlias = Union[list[float], np.ndarray, "torch.Tensor"]
56
"""
57
Represents a single audio
58
item, which can be passed to a HuggingFace `AudioProcessor`.
59
60
"""

61
ImageItem: TypeAlias = Union[HfImageItem, "torch.Tensor", "MediaWithBytes[HfImageItem]"]
62
"""
63
A `transformers.image_utils.ImageInput` representing a single image
64
item, which can be passed to a HuggingFace `ImageProcessor`.
65
66
67
68
69
70

Alternatively, a 3-D tensor or batch of 2-D tensors,
which are treated as image embeddings;
these are directly passed to the model without HF processing.
"""

71
72
73
VideoItem: TypeAlias = Union[
    HfVideoItem, "torch.Tensor", tuple[HfVideoItem, dict[str, Any]]
]
74
"""
75
76
77
A `transformers.video_utils.VideoInput` representing a single video item. 
This can be passed to a HuggingFace `VideoProcessor` 
with `transformers.video_utils.VideoMetadata`.
78
79
80
81
82
83

Alternatively, a 3-D tensor or batch of 2-D tensors,
which are treated as video embeddings;
these are directly passed to the model without HF processing.
"""

84
AudioItem: TypeAlias = Union[HfAudioItem, tuple[np.ndarray, float], "torch.Tensor"]
85
86
"""
Represents a single audio
87
item, which can be passed to a HuggingFace `AudioProcessor`.
88
89
90
91
92
93
94
95
96
97

Alternatively, a tuple `(audio, sampling_rate)`, where the sampling rate
is different from that expected by the model;
these are resampled to the model's sampling rate before being processed by HF.

Alternatively, a 3-D tensor or batch of 2-D tensors,
which are treated as audio embeddings;
these are directly passed to the model without HF processing.
"""

98
ModalityData: TypeAlias = _T | list[_T | None] | None
99
"""
100
101
Either a single data item, or a list of data items. Can only be None if UUID
is provided.
102
103

The number of data items allowed per modality is restricted by
104
`--limit-mm-per-prompt`.
105
106
107
108
109
110
111
"""


@final
class MultiModalDataBuiltins(TypedDict, total=False):
    """Type annotations for modality types predefined by vLLM."""

112
    image: ModalityData[ImageItem]
113
114
    """The input image(s)."""

115
    video: ModalityData[VideoItem]
116
117
    """The input video(s)."""

118
    audio: ModalityData[AudioItem]
119
120
121
    """The input audio(s)."""


122
MultiModalDataDict: TypeAlias = Mapping[str, ModalityData[Any]]
123
124
"""
A dictionary containing an entry for each modality type to input.
125

126
127
The built-in modalities are defined by
[`MultiModalDataBuiltins`][vllm.multimodal.inputs.MultiModalDataBuiltins].
128
129
"""

130
MultiModalUUIDDict: TypeAlias = Mapping[str, list[str | None] | str]
131
132
133
134
135
136
137
138
139
"""
A dictionary containing user-provided UUIDs for items in each modality.
If a UUID for an item is not provided, its entry will be `None` and
MultiModalHasher will compute a hash for the item.

The UUID will be used to identify the item for all caching purposes
(input processing caching, embedding caching, prefix caching, etc).
"""

140

141
142
@dataclass(frozen=True)
class PlaceholderRange:
143
144
145
    """
    Placeholder location information for multi-modal data.

146
147
    Example:

148
    Prompt: `AAAA BBBB What is in these images?`
149

150
    Images A and B will have:
151

152
153
154
155
    ```
    A: PlaceholderRange(offset=0, length=4)
    B: PlaceholderRange(offset=5, length=4)
    ```
156
157
158
159
160
161
162
163
    """

    offset: int
    """The start index of the placeholder in the prompt."""

    length: int
    """The length of the placeholder."""

164
    is_embed: Optional["torch.Tensor"] = None
165
166
167
168
169
    """
    A boolean mask of shape `(length,)` indicating which positions
    between `offset` and `offset + length` to assign embeddings to.
    """

170
171
    @cached_property
    def embeds_cumsum(self) -> torch.Tensor | None:
172
        return None if self.is_embed is None else self.is_embed.cumsum(dim=0)
173
174
175
176

    @cached_property
    def get_num_embeds(self) -> int:
        if self.embeds_cumsum is None:
177
178
            return self.length

179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
        return int(self.embeds_cumsum[-1])

    def get_embeds_indices_in_range(
        self, start_idx: int, end_idx: int
    ) -> tuple[int, int]:
        """
        Returns the starting and ending indices of the embeddings of encoder outputs
        in the range of [start_idx, end_idx) in the placeholders.

        For example, given:
        PlaceholderRange(offset=2, length=5, is_embed=[False, True, False, True, True])

        If start_idx=3 and end_idx=5, the output is (1, 3) because we want to get
        the second and the third embeddings from the encoder output.
        """
        if self.embeds_cumsum is None:
            return start_idx, end_idx

        embeds_start_idx = (
            int(self.embeds_cumsum[start_idx - 1]) if start_idx > 0 else 0
        )
        embeds_end_idx = int(self.embeds_cumsum[end_idx - 1])

        return embeds_start_idx, embeds_end_idx
203

204
205
206
207
208
209
210
211
212
213
214
215
216
    def extract_embeds_range(self) -> list[tuple[int, int]]:
        """Extract the start and end indices of the embedded region in prompt.

        For example, given `PlaceholderRange(offset=2, length=5)` and
        `is_embed = [False, True, False, True, True]`, the output is
        `[(1 + offset, 1 + offset), (3 + offset, 4 + offset)]`.

        Returns:
            A tuple `(start, end)` representing the start and end
            indices (inclusive) of the embedded region.
            Returns full placeholder range if `is_embed` is `None`.
        """
        if self.is_embed is None:
217
            return [(self.offset, self.offset + self.length - 1)]
218
219
220
221
222
223
224
225
226
227
228

        mask_i = self.is_embed.int()
        starts = torch.nonzero(
            torch.diff(mask_i, prepend=mask_i.new_zeros(1)) == 1
        ).flatten()
        ends = torch.nonzero(
            torch.diff(mask_i, append=mask_i.new_zeros(1)) == -1
        ).flatten()
        ranges = torch.stack((starts, ends), dim=1) + self.offset
        return [tuple(x) for x in ranges.tolist()]

229
230
231
232
233
234
235
236
237
238
239
240
241
    def __eq__(self, other: object) -> bool:
        if not isinstance(other, self.__class__):
            return False
        if not (self.offset, self.length) == (other.offset, other.length):
            return False

        if self.is_embed is None:
            return other.is_embed is None
        if other.is_embed is None:
            return self.is_embed is None

        return nested_tensors_equal(self.is_embed, other.is_embed)

242

243
244
245
246
247
248
NestedTensors: TypeAlias = Union[
    list["NestedTensors"],
    list["torch.Tensor"],
    "torch.Tensor",
    tuple["torch.Tensor", ...],
]
249
250
251
252
"""
Uses a list instead of a tensor if the dimensions of each element do not match.
"""

253
254

def nested_tensors_equal(a: NestedTensors, b: NestedTensors) -> bool:
255
256
257
258
    """
    Equality check between
    [`NestedTensors`][vllm.multimodal.inputs.NestedTensors] objects.
    """
259
    if isinstance(a, torch.Tensor):
260
        return isinstance(b, torch.Tensor) and torch.equal(a, b)
261
    elif isinstance(b, torch.Tensor):
262
        return isinstance(a, torch.Tensor) and torch.equal(b, a)
263
264

    if isinstance(a, list):
265
266
267
        return isinstance(b, list) and all(
            nested_tensors_equal(a_, b_) for a_, b_ in zip(a, b)
        )
268
    if isinstance(b, list):
269
270
271
        return isinstance(a, list) and all(
            nested_tensors_equal(b_, a_) for b_, a_ in zip(b, a)
        )
272
273
274
275
276

    # Both a and b are scalars
    return a == b


277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
def _nested_tensors_h2d(
    tensors: NestedTensors,
    device: torch.types.Device,
) -> NestedTensors:
    """Move every tensor leaf of `tensors` to `device` (non-blocking).

    Non-tensor leaves are kept as-is. If `device` is `None`, `tensors` is
    returned unchanged (same object).
    """
    if device is None:
        return tensors

    def _to_device(leaf):
        if not isinstance(leaf, torch.Tensor):
            return leaf
        return leaf.to(device=device, non_blocking=True)

    return json_map_leaves(_to_device, tensors)


294
BatchedTensorInputs: TypeAlias = dict[str, NestedTensors]
"""
A dictionary containing nested tensors which have been batched via
[`MultiModalKwargsItems.get_data`][vllm.multimodal.inputs.MultiModalKwargsItems.get_data].
"""


def batched_tensors_equal(a: BatchedTensorInputs, b: BatchedTensorInputs) -> bool:
    """
    Equality check between
    [`BatchedTensorInputs`][vllm.multimodal.inputs.BatchedTensorInputs] objects.

    Both dictionaries must have exactly the same keys, and the values under
    each key must be equal according to `nested_tensors_equal`.
    """
    # Compare the key sets first. Checking only `k in b for k in a` would make
    # the comparison asymmetric: `b` could carry extra keys and still be equal.
    if a.keys() != b.keys():
        return False

    return all(nested_tensors_equal(a[k], b[k]) for k in a)
307
308


309
310
311
312
@dataclass
class MultiModalFeatureSpec:
    """
    Represents a single multimodal input with its processed data and metadata.

    Used by the V1 engine to track multimodal data through processing and
    caching. A request containing multiple multimodal items will have one
    MultiModalFeatureSpec per item.
    """

    data: Optional["MultiModalKwargsItem"]
    """Multimodal data for this feature."""

    modality: str
    """The modality of the input, e.g., "image", "audio", "video"."""

    identifier: str
    """mm_hash or uuid for caching encoder outputs."""

    mm_position: PlaceholderRange
    """e.g., PlaceholderRange(offset=2, length=336)"""

    mm_hash: str | None = None
    """Base mm_hash for processor cache (without LoRA prefix)."""

    @staticmethod
    def gather_kwargs(features: list["MultiModalFeatureSpec"], keys: set[str]):
        """Collect the per-feature data for each of `keys`.

        Returns a dict mapping each requested key to the list of `.data`
        values stored under that key, one entry per feature that provides it,
        in feature order. Features whose `data` is `None` are skipped, and keys
        that no feature provides are absent from the result.
        """
        kwargs = defaultdict[str, list[NestedTensors]](list)

        for f in features:
            item = f.data
            if item is None:
                continue

            for k in keys:
                if k in item:
                    kwargs[k].append(item[k].data)

        return dict(kwargs)

347

348
@dataclass
class MultiModalFieldElem:
    """
    Represents a keyword argument inside a
    [`MultiModalKwargsItem`][vllm.multimodal.inputs.MultiModalKwargsItem].
    """

    modality: str
    """
    The modality of the corresponding multi-modal item.
    Each multi-modal item can consist of multiple keyword arguments.
    """

    key: str
    """
    The key of this field in
    [`MultiModalKwargsItem`][vllm.multimodal.inputs.MultiModalKwargsItem],
    i.e. the name of the keyword argument to be passed to the model.
    """

    data: NestedTensors
    """
    The tensor data of this field in
    [`MultiModalKwargsItem`][vllm.multimodal.inputs.MultiModalKwargsItem],
    i.e. the value of the keyword argument to be passed to the model.

    It may be set to `None` if it is determined that the item is cached
    in `EngineCore`.
    """

    field: "BaseMultiModalField"
    """
    Defines how to combine the tensor data of this field with others
    in order to batch multi-modal items together for model inference.
    """

    def __eq__(self, other: object) -> bool:
        """Compare modality, key, data, and the *type* of `field`.

        Data is compared with `nested_tensors_equal`; `None` data only equals
        `None` data. The field's own parameters are not compared.
        """
        if not isinstance(other, self.__class__):
            return False

        if self.data is None or other.data is None:
            data_equal = self.data is None and other.data is None
        else:
            data_equal = nested_tensors_equal(self.data, other.data)

        return (
            (self.modality, self.key) == (other.modality, other.key)
            and data_equal
            and type(self.field) is type(other.field)  # noqa: E721
        )
400
401


402
@dataclass(frozen=True, kw_only=True)
class BaseMultiModalField(ABC):
    """
    Defines how to interpret tensor data belonging to a keyword argument for
    [`MultiModalKwargsItems`][vllm.multimodal.inputs.MultiModalKwargsItems],
    and vice versa.
    """

    keep_on_cpu: bool = False
    """
    If `True`, then this field is excluded from being moved to the accelerator
    when `MultiModalKwargsItems.get_data()` is called to batch the data.
    """

    def _field_factory(self, *, modality: str, key: str):
        """Return a one-argument constructor for elements of this field.

        The returned callable builds a `MultiModalFieldElem` with the given
        `modality` and `key`, this field as `field`, and its single positional
        argument as `data`.
        """
        f = partial(
            MultiModalFieldElem,
            modality=modality,
            key=key,
            field=self,
        )

        # Allow passing data as positional argument
        def factory(data: NestedTensors) -> MultiModalFieldElem:
            return f(data=data)

        return factory

    @abstractmethod
    def build_elems(
        self,
        modality: str,
        key: str,
        data: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        """
        Construct
        [`MultiModalFieldElem`][vllm.multimodal.inputs.MultiModalFieldElem]
        instances to represent the provided data.

        This is the inverse of
        [`reduce_data`][vllm.multimodal.inputs.BaseMultiModalField.reduce_data].
        """
        raise NotImplementedError

    @abstractmethod
    def _reduce_data(
        self,
        batch: list[NestedTensors],
        *,
        pin_memory: bool,
    ) -> NestedTensors:
        """Merge the raw data of a batch of elements into a single value.

        `pin_memory` is forwarded by the built-in subclasses to the allocation
        of any new output tensor.
        """
        raise NotImplementedError

    def reduce_data(
        self,
        elems: list[MultiModalFieldElem],
        *,
        device: torch.types.Device = None,
        pin_memory: bool = False,
    ) -> NestedTensors:
        """
        Merge the data from multiple instances of
        [`MultiModalFieldElem`][vllm.multimodal.inputs.MultiModalFieldElem].

        This is the inverse of
        [`build_elems`][vllm.multimodal.inputs.BaseMultiModalField.build_elems].

        Args:
            elems: The elements to merge; they must all use the same field type.
            device: If not `None`, move the merged tensors to this device
                (redirected to `"cpu"` when `keep_on_cpu` is set).
            pin_memory: Request pinned output memory (ignored when
                `keep_on_cpu` is set).

        Raises:
            ValueError: If `elems` use different field types.
        """
        field_types = [type(item.field) for item in elems]
        if len(set(field_types)) > 1:
            raise ValueError(f"Cannot merge different {field_types=}")

        if device is not None and self.keep_on_cpu:
            device = "cpu"
        if pin_memory and self.keep_on_cpu:
            pin_memory = False

        batch = [elem.data for elem in elems]
        out = self._reduce_data(batch, pin_memory=pin_memory)
        return _nested_tensors_h2d(out, device=device)
482
483


484
@dataclass(frozen=True, kw_only=True)
class MultiModalBatchedField(BaseMultiModalField):
    """
    Info:
        [`MultiModalFieldConfig.batched`][vllm.multimodal.inputs.MultiModalFieldConfig.batched]
    """

    def build_elems(
        self,
        modality: str,
        key: str,
        data: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        """Create one element per entry obtained by iterating over `data`."""
        field_factory = self._field_factory(modality=modality, key=key)
        return [field_factory(item) for item in data]

    def _reduce_data(
        self,
        batch: list[NestedTensors],
        *,
        pin_memory: bool,
    ) -> NestedTensors:
        """Stack the per-item data back into a batch.

        If every entry is a tensor and all shapes match, the tensors are
        stacked along a new leading dimension; otherwise `batch` is returned
        unchanged.
        """
        if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"):
            batch = cast(list[torch.Tensor], batch)
            if len(batch) == 1:
                # An optimization when `batch` contains only one tensor:
                # - produce exactly same result as `torch.stack(batch)`
                # - will achieve zero-copy if the tensor is contiguous
                return batch[0].unsqueeze(0).contiguous()

            first_shape = batch[0].shape
            if all(elem.shape == first_shape for elem in batch):
                # Preallocate the output so it can be placed in pinned memory.
                out = torch.empty(
                    (len(batch), *batch[0].shape),
                    dtype=batch[0].dtype,
                    device=batch[0].device,
                    pin_memory=pin_memory,
                )
                return torch.stack(batch, out=out)

        return batch


526
@dataclass(frozen=True, kw_only=True)
class MultiModalFlatField(BaseMultiModalField):
    """
    Info:
        [`MultiModalFieldConfig.flat`][vllm.multimodal.inputs.MultiModalFieldConfig.flat]
        [`MultiModalFieldConfig.flat_from_sizes`][vllm.multimodal.inputs.MultiModalFieldConfig.flat_from_sizes]
    """

    slices: Sequence[slice] | Sequence[Sequence[slice]]
    """Per item, the index (a slice or a tuple of slices) selecting its data."""

    dim: int = 0
    """The dimension along which the items' data are concatenated."""

    def build_elems(
        self,
        modality: str,
        key: str,
        data: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        """Create one element per entry of `slices` by indexing into `data`."""
        field_factory = self._field_factory(modality=modality, key=key)
        if not is_list_of(self.slices, slice, check="all"):
            # Tuple-of-slices indexing is only supported by tensors.
            assert isinstance(data, torch.Tensor), (
                "torch.Tensor is required for multiple slices"
            )
        return [field_factory(data[cast(slice, s)]) for s in self.slices]

    def _reduce_data(
        self,
        batch: list[NestedTensors],
        *,
        pin_memory: bool,
    ) -> NestedTensors:
        """Concatenate the per-item data along `dim`.

        If every entry is a tensor and the shapes agree on all dimensions
        except `dim`, the tensors are concatenated into a single tensor.
        Otherwise the entries are flattened one level into a list, which
        requires `dim == 0`.
        """
        if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"):
            batch = cast(list[torch.Tensor], batch)
            if len(batch) == 1:
                # An optimization when `batch` contains only one tensor:
                # - produce exactly same result as `torch.concat(batch)`
                # - will achieve zero-copy if the tensor is contiguous
                return batch[0].contiguous()

            # Normalize a negative `dim` so it can be used for shape slicing.
            dim = self.dim + (self.dim < 0) * len(batch[0].shape)

            def _shape_before_after(tensor: torch.Tensor):
                return tensor.shape[:dim], tensor.shape[dim + 1 :]

            first_shape = _shape_before_after(batch[0])

            if all(_shape_before_after(elem) == first_shape for elem in batch):
                shape_before, shape_after = first_shape
                shape_concat = sum(item.shape[dim] for item in batch)
                # Preallocate the output so it can be placed in pinned memory.
                out = torch.empty(
                    (*shape_before, shape_concat, *shape_after),
                    dtype=batch[0].dtype,
                    device=batch[0].device,
                    pin_memory=pin_memory,
                )
                return torch.concat(batch, dim=self.dim, out=out)

        assert self.dim == 0, "dim == 0 is required for nested list"
        return [e for elem in batch for e in elem]
584
585


586
@dataclass(frozen=True, kw_only=True)
class MultiModalSharedField(BaseMultiModalField):
    """
    Info:
        [`MultiModalFieldConfig.shared`][vllm.multimodal.inputs.MultiModalFieldConfig.shared]
    """

    batch_size: int
    """The number of multi-modal items which share this field's data."""

    def build_elems(
        self,
        modality: str,
        key: str,
        data: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        """Create `batch_size` elements that each carry the full `data`.

        NOTE: the returned list repeats the *same* element object
        `batch_size` times, so mutating one element affects all of them.
        """
        field_factory = self._field_factory(modality=modality, key=key)
        return [field_factory(data)] * self.batch_size

    def _reduce_data(
        self,
        batch: list[NestedTensors],
        *,
        pin_memory: bool,
    ) -> NestedTensors:
        """Return the first entry; every element is expected to share it."""
        return batch[0]


613
@dataclass(frozen=True)
class MultiModalFieldConfig:
    @staticmethod
    def batched(modality: str, *, keep_on_cpu: bool = False):
        """
        Defines a field where an element in the batch is obtained by
        indexing into the first dimension of the underlying data.

        Args:
            modality: The modality of the multi-modal item that uses this
                keyword argument.
            keep_on_cpu: Whether to keep this field on the CPU for the model inputs.

        Example:

        ```
        Input:
            Data: [[AAAA]
                [BBBB]
                [CCCC]]

        Output:
            Element 1: [AAAA]
            Element 2: [BBBB]
            Element 3: [CCCC]
        ```
        """
        return MultiModalFieldConfig(
            field=MultiModalBatchedField(keep_on_cpu=keep_on_cpu),
            modality=modality,
        )

    @staticmethod
    def flat(
        modality: str,
        slices: Sequence[slice] | Sequence[Sequence[slice]],
        dim: int = 0,
        *,
        keep_on_cpu: bool = False,
    ):
        """
        Defines a field where an element in the batch is obtained by
        slicing the underlying data along dimension `dim`.

        Args:
            modality: The modality of the multi-modal item that uses this
                keyword argument.
            slices: For each multi-modal item, a slice (dim=0) or a tuple of
                slices (dim>0) that is used to extract the data corresponding
                to it.
            dim: The dimension to extract data, default to 0.
            keep_on_cpu: Whether to keep this field on the CPU for the model inputs.

        Example:

        ```
        Given:
            slices: [slice(0, 3), slice(3, 7), slice(7, 9)]

        Input:
            Data: [AAABBBBCC]

        Output:
            Element 1: [AAA]
            Element 2: [BBBB]
            Element 3: [CC]
        ```

        ```
        Given:
            slices: [
                (slice(None), slice(0, 3)),
                (slice(None), slice(3, 7)),
                (slice(None), slice(7, 9))]
            dim: 1

        Input:
            Data: [[AAABBBBCC]]

        Output:
            Element 1: [[AAA]]
            Element 2: [[BBBB]]
            Element 3: [[CC]]
        ```
        """
        return MultiModalFieldConfig(
            field=MultiModalFlatField(
                slices=slices,
                dim=dim,
                keep_on_cpu=keep_on_cpu,
            ),
            modality=modality,
        )

    @staticmethod
    def flat_from_sizes(
        modality: str,
        size_per_item: "torch.Tensor",
        dim: int = 0,
        *,
        keep_on_cpu: bool = False,
    ):
        """
        Defines a field where an element in the batch is obtained by
        slicing the underlying data along dimension `dim`, using consecutive
        chunks of the given sizes.

        Args:
            modality: The modality of the multi-modal item that uses this
                keyword argument.
            size_per_item: For each multi-modal item, the size of the slice
                that is used to extract the data corresponding to it.
            dim: The dimension to slice, default to 0.
            keep_on_cpu: Whether to keep this field on the CPU for the model inputs.

        Raises:
            ValueError: If `size_per_item` is not 1-D.

        Example:

        ```
        Given:
            size_per_item: [3, 4, 2]

        Input:
            Data: [AAABBBBCC]

        Output:
            Element 1: [AAA]
            Element 2: [BBBB]
            Element 3: [CC]
        ```

        ```
        Given:
            size_per_item: [3, 4, 2]
            dim: 1

        Input:
            Data: [[AAABBBBCC]]

        Output:
            Element 1: [[AAA]]
            Element 2: [[BBBB]]
            Element 3: [[CC]]
        ```

        Info:
            [`MultiModalFieldConfig.flat`][vllm.multimodal.inputs.MultiModalFieldConfig.flat]
        """

        if size_per_item.ndim != 1:
            raise ValueError(
                "size_per_item should be a 1-D tensor, "
                f"but found shape: {size_per_item.shape}"
            )

        # Convert the sizes to Python scalars once. Accumulating over the
        # tensor itself would make every slice bound a 0-d tensor, which must
        # be converted back to an int each time a slice is applied (a host
        # sync per bound if the sizes live on an accelerator) and makes the
        # stored slices heavier to compare and serialize.
        sizes = size_per_item.tolist()
        slice_idxs = [0, *accumulate(sizes)]
        slices = [
            (slice(None, None, None),) * dim
            + (slice(slice_idxs[i], slice_idxs[i + 1]),)
            for i in range(len(sizes))
        ]

        return MultiModalFieldConfig.flat(
            modality,
            slices,
            dim=dim,
            keep_on_cpu=keep_on_cpu,
        )

    @staticmethod
    def shared(
        modality: str,
        batch_size: int,
        *,
        keep_on_cpu: bool = False,
    ):
        """
        Defines a field where an element in the batch is obtained by
        taking the entirety of the underlying data.

        This means that the data is the same for each element in the batch.

        Args:
            modality: The modality of the multi-modal item that uses this
                keyword argument.
            batch_size: The number of multi-modal items which share this data.
            keep_on_cpu: Whether to keep this field on the CPU for the model inputs.

        Example:

        ```
        Given:
            batch_size: 4

        Input:
            Data: [XYZ]

        Output:
            Element 1: [XYZ]
            Element 2: [XYZ]
            Element 3: [XYZ]
            Element 4: [XYZ]
        ```
        """
        return MultiModalFieldConfig(
            field=MultiModalSharedField(
                batch_size=batch_size,
                keep_on_cpu=keep_on_cpu,
            ),
            modality=modality,
        )

    field: BaseMultiModalField
    modality: str

    def build_elems(
        self,
        key: str,
        batch: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        """Split `batch` into per-item elements according to `field`."""
        return self.field.build_elems(self.modality, key, batch)
832
833


834
835
class MultiModalKwargsItem(UserDict[str, MultiModalFieldElem]):
    """
    A collection of
    [`MultiModalFieldElem`][vllm.multimodal.inputs.MultiModalFieldElem]
    corresponding to a data item in
    [`MultiModalDataItems`][vllm.multimodal.parse.MultiModalDataItems].
    """

    @staticmethod
    def dummy(modality: str, nbytes: int = 1):
        """Create a minimal item for `modality`, for use in tests.

        The item holds a single shared field named `"dummy"` whose data is an
        uninitialized `uint8` tensor with `nbytes` elements.
        """
        mm_elem = MultiModalFieldElem(
            modality=modality,
            key="dummy",
            data=torch.empty(nbytes, dtype=torch.uint8),
            field=MultiModalSharedField(batch_size=1),
        )
        return MultiModalKwargsItem.from_elems([mm_elem])

    @staticmethod
    def from_elems(elems: Sequence[MultiModalFieldElem]):
        """Build an item keyed by each element's `key` (later duplicates win)."""
        return MultiModalKwargsItem({elem.key: elem for elem in elems})

    def __init__(
        self,
        data: Mapping[str, MultiModalFieldElem] | None = None,
    ) -> None:
        # `None` replaces the former mutable `{}` default; `UserDict.__init__`
        # treats `None` as "no initial data".
        super().__init__(data)

        # Every element must belong to one and the same modality. This also
        # rejects an empty mapping, because then there is no modality at all.
        modalities = {elem.modality for elem in self.values()}
        assert len(modalities) == 1, f"Found different modalities={modalities}"
        self._modality = next(iter(modalities))

    @property
    def modality(self) -> str:
        """The modality shared by all elements of this item."""
        return self._modality

    def get_data(self) -> dict[str, NestedTensors]:
        """Return a plain dict mapping each key to its element's `data`."""
        return {key: elem.data for key, elem in self.items()}
870
871


872
873
874
_I = TypeVar(
    "_I",
    MultiModalKwargsItem,
    MultiModalKwargsItem | None,
    default=MultiModalKwargsItem,
)


class MultiModalKwargsItems(UserDict[str, Sequence[_I]]):
    """
    A dictionary of
    [`MultiModalKwargsItem`][vllm.multimodal.inputs.MultiModalKwargsItem]s
    by modality.
    """

    @staticmethod
    def from_hf_inputs(
        hf_inputs: "BatchFeature",
        config_by_key: Mapping[str, MultiModalFieldConfig],
    ):
        """Split HF processor outputs into per-item kwargs, grouped by modality.

        For every key in `config_by_key` that is present (and not `None`) in
        `hf_inputs`, the corresponding config builds one element per item.
        Elements of the same modality are then zipped into items.

        Raises:
            ValueError: If the keys of one modality yield different numbers of
                elements.
        """
        # NOTE: This skips fields in `hf_inputs` that are not in `config_by_key`
        # We assume that those fields are not used in vLLM
        elems_by_key = dict[str, Sequence[MultiModalFieldElem]]()
        keys_by_modality = defaultdict[str, set[str]](set)
        for key, config in config_by_key.items():
            batch = hf_inputs.get(key)
            if batch is not None:
                elems = config.build_elems(key, batch)
                if len(elems) > 0:
                    elems_by_key[key] = elems
                    keys_by_modality[config.modality].add(key)

        items = list[MultiModalKwargsItem]()
        for modality, keys in keys_by_modality.items():
            elems_in_modality = {k: elems_by_key[k] for k in keys}
            batch_sizes = {k: len(v) for k, v in elems_in_modality.items()}

            if len(set(batch_sizes.values())) > 1:
                raise ValueError(
                    f"Cannot merge different batch sizes for {modality=}! "
                    f"Found: {batch_sizes=}"
                )

            # All keys agree on the batch size; item `i` takes element `i`
            # from each key of this modality.
            batch_size = next(iter(batch_sizes.values()))
            for item_idx in range(batch_size):
                elems = [v[item_idx] for v in elems_in_modality.values()]
                items.append(MultiModalKwargsItem.from_elems(elems))

        return MultiModalKwargsItems.from_seq(items)

    @staticmethod
    def from_seq(items: Sequence[MultiModalKwargsItem]):
        """Group `items` by their modality, keeping their relative order."""
        items_by_modality = full_groupby(items, key=lambda x: x.modality)
        return MultiModalKwargsItems(items_by_modality)

    def __getitem__(self, modality: str) -> Sequence[_I]:
        """Return the items of `modality`, with a descriptive `KeyError`."""
        if modality not in self:
            raise KeyError(
                f"Modality {modality!r} not found. "
                f"Available modalities: {set(self.keys())}"
            )

        return super().__getitem__(modality)  # type: ignore[return-value]

    def require_data(self) -> "MultiModalKwargsItems[MultiModalKwargsItem]":
        """Return `self` after checking that no item is `None`.

        Raises:
            RuntimeError: If any item is `None`.
        """
        for modality, items in self.items():
            for i, item in enumerate(items):
                if item is None:
                    raise RuntimeError(f"Found empty mm_items[{modality}][{i}]")

        return self  # type: ignore[return-value]

    def get_data(
        self,
        *,
        device: torch.types.Device = None,
        pin_memory: bool = False,
    ) -> BatchedTensorInputs:
        """Construct a dictionary of keyword arguments to pass to the model.

        Elements sharing a key (across all modalities) are merged with that
        key's field via `reduce_data`.

        Raises:
            RuntimeError: If any item is `None`.
        """
        elems_by_key = defaultdict[str, list[MultiModalFieldElem]](list)
        for modality, items in self.items():
            for i, item in enumerate(items):
                if item is None:
                    raise RuntimeError(
                        f"Cannot build data from empty mm_items[{modality}][{i}]"
                    )

                for key, elem in item.items():
                    elems_by_key[key].append(elem)

        data = {
            key: elems[0].field.reduce_data(
                elems,
                device=device,
                pin_memory=pin_memory,
            )
            for key, elems in elems_by_key.items()
        }

        return data
972

973

974
975
976
977
MultiModalKwargsOptionalItems: TypeAlias = (
    MultiModalKwargsItems[MultiModalKwargsItem]
    | MultiModalKwargsItems[MultiModalKwargsItem | None]
)
"""
Either fully populated items, or items in which some entries may be `None`
(use `MultiModalKwargsItems.require_data` to check that none are).
"""


MultiModalHashes: TypeAlias = dict[str, list[str]]
"""
A dictionary containing per-item hashes for each modality.
"""


MultiModalPlaceholderDict: TypeAlias = Mapping[str, Sequence[PlaceholderRange]]
"""
A dictionary containing per-item placeholder ranges for each modality.
"""


992
class MultiModalInputs(TypedDict):
    """
    Represents the outputs of
    [`BaseMultiModalProcessor`][vllm.multimodal.processing.BaseMultiModalProcessor],
    ready to be passed to vLLM internals.
    """

    type: Literal["multimodal"]
    """The type of inputs."""

    prompt_token_ids: list[int]
    """The processed token IDs, which include placeholder tokens."""

    mm_kwargs: MultiModalKwargsOptionalItems
    """Keyword arguments to be directly passed to the model after batching."""

    mm_hashes: MultiModalHashes
    """The hashes of the multi-modal data."""

    mm_placeholders: MultiModalPlaceholderDict
    """
    For each modality, information about the placeholder tokens in
    `prompt_token_ids`.
    """

    cache_salt: NotRequired[str]
    """
    Optional cache salt to be used for prefix caching.
    """

1022
1023
1024

class MultiModalEncDecInputs(MultiModalInputs):
    """
    Represents the outputs of
    [`EncDecMultiModalProcessor`][vllm.multimodal.processing.EncDecMultiModalProcessor],
    ready to be passed to vLLM internals.
    """

    encoder_prompt_token_ids: list[int]
    """The processed token IDs of the encoder prompt."""