inputs.py 30.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from abc import ABC, abstractmethod
5
from collections import UserDict, defaultdict
6
from collections.abc import Mapping, Sequence
7
from dataclasses import dataclass
8
from functools import cached_property, partial
9
from itertools import accumulate
10
11
12
13
14
15
16
17
18
19
20
from typing import (
    TYPE_CHECKING,
    Any,
    Literal,
    Optional,
    TypeAlias,
    TypedDict,
    Union,
    cast,
    final,
)
21
22

import numpy as np
23
from typing_extensions import NotRequired, TypeVar
24

25
from vllm.utils.collection_utils import full_groupby, is_list_of
26
from vllm.utils.import_utils import LazyLoader
27
from vllm.utils.jsontree import json_map_leaves
28

29
if TYPE_CHECKING:
    import torch
    import torch.types
    from PIL.Image import Image
    from transformers.feature_extraction_utils import BatchFeature

    from .base import MediaWithBytes
    from .processing import MultiModalHashes
else:
    # At runtime `torch` is bound through `LazyLoader`; type checkers see the
    # real imports above instead.
    torch = LazyLoader("torch", globals(), "torch")

_T = TypeVar("_T")

43
HfImageItem: TypeAlias = Union["Image", np.ndarray, "torch.Tensor"]
44
"""
45
A `transformers.image_utils.ImageInput` representing a single image
46
item, which can be passed to a HuggingFace `ImageProcessor`.
47
48
"""

49
50
51
HfVideoItem: TypeAlias = Union[
    list["Image"], np.ndarray, "torch.Tensor", list[np.ndarray], list["torch.Tensor"]
]
52
"""
53
A `transformers.image_utils.VideoInput` representing a single video
54
item, which can be passed to a HuggingFace `VideoProcessor`.
55
56
"""

57
HfAudioItem: TypeAlias = Union[list[float], np.ndarray, "torch.Tensor"]
58
"""
59
Represents a single audio
60
item, which can be passed to a HuggingFace `AudioProcessor`.
61
62
"""

63
ImageItem: TypeAlias = Union[HfImageItem, "torch.Tensor", "MediaWithBytes[HfImageItem]"]
64
"""
65
A `transformers.image_utils.ImageInput` representing a single image
66
item, which can be passed to a HuggingFace `ImageProcessor`.
67
68
69
70
71
72

Alternatively, a 3-D tensor or batch of 2-D tensors,
which are treated as image embeddings;
these are directly passed to the model without HF processing.
"""

73
74
75
VideoItem: TypeAlias = Union[
    HfVideoItem, "torch.Tensor", tuple[HfVideoItem, dict[str, Any]]
]
76
"""
77
78
79
A `transformers.video_utils.VideoInput` representing a single video item. 
This can be passed to a HuggingFace `VideoProcessor` 
with `transformers.video_utils.VideoMetadata`.
80
81
82
83
84
85

Alternatively, a 3-D tensor or batch of 2-D tensors,
which are treated as video embeddings;
these are directly passed to the model without HF processing.
"""

86
AudioItem: TypeAlias = Union[HfAudioItem, tuple[np.ndarray, float], "torch.Tensor"]
87
88
"""
Represents a single audio
89
item, which can be passed to a HuggingFace `AudioProcessor`.
90
91
92
93
94
95
96
97
98
99

Alternatively, a tuple `(audio, sampling_rate)`, where the sampling rate
is different from that expected by the model;
these are resampled to the model's sampling rate before being processed by HF.

Alternatively, a 3-D tensor or batch of 2-D tensors,
which are treated as audio embeddings;
these are directly passed to the model without HF processing.
"""

100
ModalityData: TypeAlias = _T | list[_T | None] | None
101
"""
102
103
Either a single data item, or a list of data items. Can only be None if UUID
is provided.
104
105

The number of data items allowed per modality is restricted by
106
`--limit-mm-per-prompt`.
107
108
109
110
111
112
113
"""


@final
class MultiModalDataBuiltins(TypedDict, total=False):
    """Type annotations for modality types predefined by vLLM."""

    image: ModalityData[ImageItem]
    """The input image(s)."""

    video: ModalityData[VideoItem]
    """The input video(s)."""

    audio: ModalityData[AudioItem]
    """The input audio(s)."""


MultiModalDataDict: TypeAlias = Mapping[str, ModalityData[Any]]
"""
A dictionary containing an entry for each modality type to input.

The built-in modalities are defined by
[`MultiModalDataBuiltins`][vllm.multimodal.inputs.MultiModalDataBuiltins].
"""

MultiModalUUIDDict: TypeAlias = Mapping[str, list[str | None] | str]
133
134
135
136
137
138
139
140
141
"""
A dictionary containing user-provided UUIDs for items in each modality.
If a UUID for an item is not provided, its entry will be `None` and
MultiModalHasher will compute a hash for the item.

The UUID will be used to identify the item for all caching purposes
(input processing caching, embedding caching, prefix caching, etc).
"""

142

143
144
@dataclass(frozen=True)
class PlaceholderRange:
145
146
147
    """
    Placeholder location information for multi-modal data.

148
149
    Example:

150
    Prompt: `AAAA BBBB What is in these images?`
151

152
    Images A and B will have:
153

154
155
156
157
    ```
    A: PlaceholderRange(offset=0, length=4)
    B: PlaceholderRange(offset=5, length=4)
    ```
158
159
160
161
162
163
164
165
    """

    offset: int
    """The start index of the placeholder in the prompt."""

    length: int
    """The length of the placeholder."""

166
    is_embed: Optional["torch.Tensor"] = None
167
168
169
170
171
    """
    A boolean mask of shape `(length,)` indicating which positions
    between `offset` and `offset + length` to assign embeddings to.
    """

172
173
    @cached_property
    def embeds_cumsum(self) -> torch.Tensor | None:
174
        return None if self.is_embed is None else self.is_embed.cumsum(dim=0)
175
176
177
178

    @cached_property
    def get_num_embeds(self) -> int:
        if self.embeds_cumsum is None:
179
180
            return self.length

181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
        return int(self.embeds_cumsum[-1])

    def get_embeds_indices_in_range(
        self, start_idx: int, end_idx: int
    ) -> tuple[int, int]:
        """
        Returns the starting and ending indices of the embeddings of encoder outputs
        in the range of [start_idx, end_idx) in the placeholders.

        For example, given:
        PlaceholderRange(offset=2, length=5, is_embed=[False, True, False, True, True])

        If start_idx=3 and end_idx=5, the output is (1, 3) because we want to get
        the second and the third embeddings from the encoder output.
        """
        if self.embeds_cumsum is None:
            return start_idx, end_idx

        embeds_start_idx = (
            int(self.embeds_cumsum[start_idx - 1]) if start_idx > 0 else 0
        )
        embeds_end_idx = int(self.embeds_cumsum[end_idx - 1])

        return embeds_start_idx, embeds_end_idx
205

206
207
208
209
210
211
212
213
214
215
216
217
218
    def extract_embeds_range(self) -> list[tuple[int, int]]:
        """Extract the start and end indices of the embedded region in prompt.

        For example, given `PlaceholderRange(offset=2, length=5)` and
        `is_embed = [False, True, False, True, True]`, the output is
        `[(1 + offset, 1 + offset), (3 + offset, 4 + offset)]`.

        Returns:
            A tuple `(start, end)` representing the start and end
            indices (inclusive) of the embedded region.
            Returns full placeholder range if `is_embed` is `None`.
        """
        if self.is_embed is None:
219
            return [(self.offset, self.offset + self.length - 1)]
220
221
222
223
224
225
226
227
228
229
230

        mask_i = self.is_embed.int()
        starts = torch.nonzero(
            torch.diff(mask_i, prepend=mask_i.new_zeros(1)) == 1
        ).flatten()
        ends = torch.nonzero(
            torch.diff(mask_i, append=mask_i.new_zeros(1)) == -1
        ).flatten()
        ranges = torch.stack((starts, ends), dim=1) + self.offset
        return [tuple(x) for x in ranges.tolist()]

231
232
233
234
235
236
237
238
239
240
241
242
243
    def __eq__(self, other: object) -> bool:
        if not isinstance(other, self.__class__):
            return False
        if not (self.offset, self.length) == (other.offset, other.length):
            return False

        if self.is_embed is None:
            return other.is_embed is None
        if other.is_embed is None:
            return self.is_embed is None

        return nested_tensors_equal(self.is_embed, other.is_embed)

244

245
246
247
248
249
250
NestedTensors: TypeAlias = Union[
    list["NestedTensors"],
    list["torch.Tensor"],
    "torch.Tensor",
    tuple["torch.Tensor", ...],
]
251
252
253
254
"""
Uses a list instead of a tensor if the dimensions of each element do not match.
"""

255
256

def nested_tensors_equal(a: "NestedTensors", b: "NestedTensors") -> bool:
    """
    Equality check between
    [`NestedTensors`][vllm.multimodal.inputs.NestedTensors] objects.

    Tensors are compared with `torch.equal`. Lists and tuples are compared
    element-wise and recursively, and must have the same container type and
    length. Any other values fall back to `==`.
    """
    if isinstance(a, torch.Tensor) or isinstance(b, torch.Tensor):
        return (
            isinstance(a, torch.Tensor)
            and isinstance(b, torch.Tensor)
            and torch.equal(a, b)
        )

    # Tuples are handled like lists: a plain `==` on tuples of tensors would
    # call `bool()` on an element-wise tensor comparison and raise.
    for seq_type in (list, tuple):
        if isinstance(a, seq_type) or isinstance(b, seq_type):
            return (
                isinstance(a, seq_type)
                and isinstance(b, seq_type)
                # `zip` stops at the shorter input, so lengths must be
                # checked explicitly.
                and len(a) == len(b)
                and all(nested_tensors_equal(a_, b_) for a_, b_ in zip(a, b))
            )

    # Both a and b are scalars
    return a == b


def _nested_tensors_h2d(
    tensors: NestedTensors,
    device: torch.types.Device,
) -> NestedTensors:
    """
    Copy every tensor leaf of `tensors` to `device` using `non_blocking=True`.

    Non-tensor leaves are returned unchanged. When `device` is `None`, the
    input is returned as-is without being traversed.
    """
    if device is not None:

        def _move_leaf(leaf):
            if not isinstance(leaf, torch.Tensor):
                return leaf
            return leaf.to(device=device, non_blocking=True)

        return json_map_leaves(_move_leaf, tensors)

    return tensors


BatchedTensorInputs: TypeAlias = dict[str, NestedTensors]
"""
A dictionary containing nested tensors which have been batched via
[`MultiModalKwargsItems.get_data`][vllm.multimodal.inputs.MultiModalKwargsItems.get_data].
"""


def batched_tensors_equal(
    a: "BatchedTensorInputs", b: "BatchedTensorInputs"
) -> bool:
    """
    Equality check between
    [`BatchedTensorInputs`][vllm.multimodal.inputs.BatchedTensorInputs] objects.

    Both inputs must have exactly the same keys, and the values under each
    key must be equal according to
    [`nested_tensors_equal`][vllm.multimodal.inputs.nested_tensors_equal].
    """
    # Comparing the key sets first keeps the check symmetric: a key present
    # in only one of the inputs makes them unequal.
    return a.keys() == b.keys() and all(
        nested_tensors_equal(a[k], b[k]) for k in a
    )


@dataclass
class MultiModalFeatureSpec:
    """
    Represents a single multimodal input with its processed data and metadata.

    Used by the V1 engine to track multimodal data through processing and
    caching. A request containing multiple multimodal items will have one
    MultiModalFeatureSpec per item.
    """

    data: Optional["MultiModalKwargsItem"]
    """Multimodal data for this feature."""

    modality: str
    """The modality of the input, e.g., "image", "audio", "video"."""

    identifier: str
    """mm_hash or uuid for caching encoder outputs."""

    mm_position: PlaceholderRange
    """e.g., PlaceholderRange(offset=2, length=336)"""

    @staticmethod
    def gather_kwargs(features: list["MultiModalFeatureSpec"], keys: set[str]):
        """
        Collect the raw data of selected keyword arguments across features.

        Args:
            features: The features to gather from. Features whose `data` is
                `None` are skipped.
            keys: Names of the keyword arguments to collect. A key missing
                from an item is skipped for that item.

        Returns:
            A dict mapping each key found in at least one item to the list of
            that key's element data, in the order of `features`.
        """
        kwargs = defaultdict[str, list[NestedTensors]](list)

        for f in features:
            item = f.data
            if item is not None:
                for k in keys:
                    if k in item:
                        kwargs[k].append(item[k].data)

        return dict(kwargs)


@dataclass
class MultiModalFieldElem:
    """
    Represents a keyword argument inside a
    [`MultiModalKwargsItem`][vllm.multimodal.inputs.MultiModalKwargsItem].
    """

    modality: str
    """
    The modality of the corresponding multi-modal item.
    Each multi-modal item can consist of multiple keyword arguments.
    """

    key: str
    """
    The key of this field in
    [`MultiModalKwargsItem`][vllm.multimodal.inputs.MultiModalKwargsItem],
    i.e. the name of the keyword argument to be passed to the model.
    """

    data: NestedTensors
    """
    The tensor data of this field in
    [`MultiModalKwargsItem`][vllm.multimodal.inputs.MultiModalKwargsItem],
    i.e. the value of the keyword argument to be passed to the model.

    It may be set to `None` if it is determined that the item is cached
    in `EngineCore`.
    """

    field: "BaseMultiModalField"
    """
    Defines how to combine the tensor data of this field with others
    in order to batch multi-modal items together for model inference.
    """

    def __eq__(self, other: object) -> bool:
        """
        Two elements are equal when their modality, key, and field type match
        and their data is equal (both `None`, or equal per
        `nested_tensors_equal`).
        """
        if not isinstance(other, self.__class__):
            return False

        # Cheap metadata checks first; the data comparison may touch tensors.
        if (self.modality, self.key) != (other.modality, other.key):
            return False
        if type(self.field) is not type(other.field):  # noqa: E721
            return False

        if self.data is None or other.data is None:
            return self.data is None and other.data is None

        return nested_tensors_equal(self.data, other.data)


@dataclass(frozen=True, kw_only=True)
class BaseMultiModalField(ABC):
    """
    Defines how to interpret tensor data belonging to a keyword argument for
    [`MultiModalKwargsItems`][vllm.multimodal.inputs.MultiModalKwargsItems],
    and vice versa.
    """

    keep_on_cpu: bool = False
    """
    If `True`, then this field is excluded from being moved to the accelerator
    when `MultiModalKwargsItems.get_data()` is called to batch the data.
    """

    def _field_factory(self, *, modality: str, key: str):
        """Return a callable that wraps data into a field element of `self`."""
        f = partial(
            MultiModalFieldElem,
            modality=modality,
            key=key,
            field=self,
        )

        # Allow passing data as positional argument
        def factory(data: NestedTensors) -> MultiModalFieldElem:
            return f(data=data)

        return factory

    @abstractmethod
    def build_elems(
        self,
        modality: str,
        key: str,
        data: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        """
        Construct
        [`MultiModalFieldElem`][vllm.multimodal.inputs.MultiModalFieldElem]
        instances to represent the provided data.

        This is the inverse of
        [`reduce_data`][vllm.multimodal.inputs.BaseMultiModalField.reduce_data].
        """
        raise NotImplementedError

    @abstractmethod
    def _reduce_data(
        self,
        batch: list[NestedTensors],
        *,
        pin_memory: bool,
    ) -> NestedTensors:
        """Combine the data of several elements into one batched value."""
        raise NotImplementedError

    def reduce_data(
        self,
        elems: list[MultiModalFieldElem],
        *,
        device: torch.types.Device = None,
        pin_memory: bool = False,
    ) -> NestedTensors:
        """
        Merge the data from multiple instances of
        [`MultiModalFieldElem`][vllm.multimodal.inputs.MultiModalFieldElem].

        This is the inverse of
        [`build_elems`][vllm.multimodal.inputs.BaseMultiModalField.build_elems].

        Args:
            elems: The elements to merge; all must use the same field type.
            device: If not `None`, tensors in the result are moved to this
                device (or to `"cpu"` when `keep_on_cpu` is set).
            pin_memory: Whether to request pinned memory for the merged
                output. Ignored when `keep_on_cpu` is set.

        Raises:
            ValueError: If `elems` use different field types.
        """
        field_types = [type(item.field) for item in elems]
        if len(set(field_types)) > 1:
            raise ValueError(f"Cannot merge different {field_types=}")

        if self.keep_on_cpu:
            # CPU-only fields are neither moved to the accelerator nor pinned.
            if device is not None:
                device = "cpu"
            pin_memory = False

        batch = [elem.data for elem in elems]

        out = self._reduce_data(batch, pin_memory=pin_memory)
        return _nested_tensors_h2d(out, device=device)


@dataclass(frozen=True, kw_only=True)
class MultiModalBatchedField(BaseMultiModalField):
    """
    Info:
        [`MultiModalFieldConfig.batched`][vllm.multimodal.inputs.MultiModalFieldConfig.batched]
    """

    def build_elems(
        self,
        modality: str,
        key: str,
        data: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        # Each entry obtained by iterating `data` (the first dimension of a
        # tensor, or each list element) becomes one element.
        field_factory = self._field_factory(modality=modality, key=key)
        return [field_factory(item) for item in data]

    def _reduce_data(
        self,
        batch: list[NestedTensors],
        *,
        pin_memory: bool,
    ) -> NestedTensors:
        if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"):
            batch = cast(list[torch.Tensor], batch)
            if len(batch) == 1:
                # An optimization when `batch` contains only one tensor:
                # - produce exactly same result as `torch.stack(batch)`
                # - will achieve zero-copy if the tensor is contiguous
                return batch[0].unsqueeze(0).contiguous()

            first_shape = batch[0].shape
            if all(elem.shape == first_shape for elem in batch):
                # Preallocate the output so it can live in pinned memory.
                # Pinning only applies to CPU tensors.
                out = torch.empty(
                    (len(batch), *first_shape),
                    dtype=batch[0].dtype,
                    device=batch[0].device,
                    pin_memory=pin_memory and batch[0].device.type == "cpu",
                )
                return torch.stack(batch, out=out)

        # Mismatched shapes (or non-tensor data) cannot be stacked,
        # so keep them as a list.
        return batch


@dataclass(frozen=True, kw_only=True)
class MultiModalFlatField(BaseMultiModalField):
    """
    Info:
        [`MultiModalFieldConfig.flat`][vllm.multimodal.inputs.MultiModalFieldConfig.flat]
        [`MultiModalFieldConfig.flat_from_sizes`][vllm.multimodal.inputs.MultiModalFieldConfig.flat_from_sizes]
    """

    slices: Sequence[slice] | Sequence[Sequence[slice]]
    dim: int = 0

    def build_elems(
        self,
        modality: str,
        key: str,
        data: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        field_factory = self._field_factory(modality=modality, key=key)
        if not is_list_of(self.slices, slice, check="all"):
            # Tuples of slices (multi-dimensional indexing) need a tensor.
            assert isinstance(data, torch.Tensor), (
                "torch.Tensor is required for multiple slices"
            )
        return [field_factory(data[cast(slice, s)]) for s in self.slices]

    def _reduce_data(
        self,
        batch: list[NestedTensors],
        *,
        pin_memory: bool,
    ) -> NestedTensors:
        if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"):
            batch = cast(list[torch.Tensor], batch)
            if len(batch) == 1:
                # An optimization when `batch` contains only one tensor:
                # - produce exactly same result as `torch.concat(batch)`
                # - will achieve zero-copy if the tensor is contiguous
                return batch[0].contiguous()

            # Normalize a negative `dim` to its non-negative equivalent.
            dim = self.dim + (self.dim < 0) * len(batch[0].shape)

            def _shape_before_after(tensor: torch.Tensor):
                return tensor.shape[:dim], tensor.shape[dim + 1 :]

            first_shape = _shape_before_after(batch[0])

            # Tensors can be concatenated only if they agree on every
            # dimension except `dim`.
            if all(_shape_before_after(elem) == first_shape for elem in batch):
                shape_before, shape_after = first_shape
                shape_concat = sum(item.shape[dim] for item in batch)
                # Pinning only applies to CPU tensors.
                out = torch.empty(
                    (*shape_before, shape_concat, *shape_after),
                    dtype=batch[0].dtype,
                    device=batch[0].device,
                    pin_memory=pin_memory and batch[0].device.type == "cpu",
                )
                return torch.concat(batch, dim=dim, out=out)

        assert self.dim == 0, "dim == 0 is required for nested list"
        return [e for elem in batch for e in elem]


@dataclass(frozen=True, kw_only=True)
class MultiModalSharedField(BaseMultiModalField):
    """
    Info:
        [`MultiModalFieldConfig.shared`][vllm.multimodal.inputs.MultiModalFieldConfig.shared]
    """

    batch_size: int

    def build_elems(
        self,
        modality: str,
        key: str,
        data: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        field_factory = self._field_factory(modality=modality, key=key)
        # The same element object (and thus the same data) is repeated
        # once per item.
        return [field_factory(data)] * self.batch_size

    def _reduce_data(
        self,
        batch: list[NestedTensors],
        *,
        pin_memory: bool,
    ) -> NestedTensors:
        # Elements built by `build_elems` all carry the same data, so the
        # first one represents the whole batch.
        return batch[0]


@dataclass(frozen=True)
class MultiModalFieldConfig:
    """
    Specifies how the data of a keyword argument is split into
    [`MultiModalFieldElem`][vllm.multimodal.inputs.MultiModalFieldElem]s
    for the multi-modal items of `modality`.

    The static factory methods below construct an instance for each
    supported field type.
    """

    @staticmethod
    def batched(modality: str, *, keep_on_cpu: bool = False):
        """
        Defines a field where an element in the batch is obtained by
        indexing into the first dimension of the underlying data.

        Args:
            modality: The modality of the multi-modal item that uses this
                keyword argument.
            keep_on_cpu: Whether to keep this field on the CPU for the model inputs.

        Example:

        ```
        Input:
            Data: [[AAAA]
                [BBBB]
                [CCCC]]

        Output:
            Element 1: [AAAA]
            Element 2: [BBBB]
            Element 3: [CCCC]
        ```
        """
        return MultiModalFieldConfig(
            field=MultiModalBatchedField(keep_on_cpu=keep_on_cpu),
            modality=modality,
        )

    @staticmethod
    def flat(
        modality: str,
        slices: Sequence[slice] | Sequence[Sequence[slice]],
        dim: int = 0,
        *,
        keep_on_cpu: bool = False,
    ):
        """
        Defines a field where an element in the batch is obtained by
        slicing along dimension `dim` of the underlying data.

        Args:
            modality: The modality of the multi-modal item that uses this
                keyword argument.
            slices: For each multi-modal item, a slice (dim=0) or a tuple of
                slices (dim>0) that is used to extract the data corresponding
                to it.
            dim: The dimension along which data is extracted; defaults to 0.
            keep_on_cpu: Whether to keep this field on the CPU for the model inputs.

        Example:

        ```
        Given:
            slices: [slice(0, 3), slice(3, 7), slice(7, 9)]

        Input:
            Data: [AAABBBBCC]

        Output:
            Element 1: [AAA]
            Element 2: [BBBB]
            Element 3: [CC]
        ```

        ```
        Given:
            slices: [
                (slice(None), slice(0, 3)),
                (slice(None), slice(3, 7)),
                (slice(None), slice(7, 9))]
            dim: 1

        Input:
            Data: [[AAABBBBCC]]

        Output:
            Element 1: [[AAA]]
            Element 2: [[BBBB]]
            Element 3: [[CC]]
        ```
        """
        return MultiModalFieldConfig(
            field=MultiModalFlatField(
                slices=slices,
                dim=dim,
                keep_on_cpu=keep_on_cpu,
            ),
            modality=modality,
        )

    @staticmethod
    def flat_from_sizes(
        modality: str,
        size_per_item: "torch.Tensor",
        dim: int = 0,
        *,
        keep_on_cpu: bool = False,
    ):
        """
        Defines a field where an element in the batch is obtained by
        slicing along dimension `dim` of the underlying data.

        Args:
            modality: The modality of the multi-modal item that uses this
                keyword argument.
            size_per_item: For each multi-modal item, the size of the slice
                that is used to extract the data corresponding to it.
            dim: The dimension to slice; defaults to 0.
            keep_on_cpu: Whether to keep this field on the CPU for the model inputs.

        Raises:
            ValueError: If `size_per_item` is not a 1-D tensor.

        Example:

        ```
        Given:
            size_per_item: [3, 4, 2]

        Input:
            Data: [AAABBBBCC]

        Output:
            Element 1: [AAA]
            Element 2: [BBBB]
            Element 3: [CC]
        ```

        ```
        Given:
            size_per_item: [3, 4, 2]
            dim: 1

        Input:
            Data: [[AAABBBBCC]]

        Output:
            Element 1: [[AAA]]
            Element 2: [[BBBB]]
            Element 3: [[CC]]
        ```

        Info:
            [`MultiModalFieldConfig.flat`][vllm.multimodal.inputs.MultiModalFieldConfig.flat]
        """

        if size_per_item.ndim != 1:
            raise ValueError(
                "size_per_item should be a 1-D tensor, "
                f"but found shape: {size_per_item.shape}"
            )

        # Boundaries between consecutive items: [0, s0, s0 + s1, ...]
        slice_idxs = [0, *accumulate(size_per_item)]
        # Leading full slices select everything before `dim`.
        slices = [
            (slice(None, None, None),) * dim
            + (slice(slice_idxs[i], slice_idxs[i + 1]),)
            for i in range(len(size_per_item))
        ]

        return MultiModalFieldConfig.flat(
            modality,
            slices,
            dim=dim,
            keep_on_cpu=keep_on_cpu,
        )

    @staticmethod
    def shared(
        modality: str,
        batch_size: int,
        *,
        keep_on_cpu: bool = False,
    ):
        """
        Defines a field where an element in the batch is obtained by
        taking the entirety of the underlying data.

        This means that the data is the same for each element in the batch.

        Args:
            modality: The modality of the multi-modal item that uses this
                keyword argument.
            batch_size: The number of multi-modal items which share this data.
            keep_on_cpu: Whether to keep this field on the CPU for the model inputs.

        Example:

        ```
        Given:
            batch_size: 4

        Input:
            Data: [XYZ]

        Output:
            Element 1: [XYZ]
            Element 2: [XYZ]
            Element 3: [XYZ]
            Element 4: [XYZ]
        ```
        """
        return MultiModalFieldConfig(
            field=MultiModalSharedField(
                batch_size=batch_size,
                keep_on_cpu=keep_on_cpu,
            ),
            modality=modality,
        )

    field: BaseMultiModalField
    modality: str

    def build_elems(
        self,
        key: str,
        batch: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        """Split `batch`, the data of keyword argument `key`, into elements."""
        return self.field.build_elems(self.modality, key, batch)


class MultiModalKwargsItem(UserDict[str, MultiModalFieldElem]):
    """
    A collection of
    [`MultiModalFieldElem`][vllm.multimodal.inputs.MultiModalFieldElem]
    corresponding to a data item in
    [`MultiModalDataItems`][vllm.multimodal.parse.MultiModalDataItems].

    All elements must belong to exactly one modality.
    """

    @staticmethod
    def dummy(modality: str, nbytes: int = 1):
        """
        Convenience method for testing.

        Build an item holding a single `"dummy"` element whose data is an
        uninitialized `uint8` tensor of `nbytes` elements.
        """
        mm_elem = MultiModalFieldElem(
            modality=modality,
            key="dummy",
            data=torch.empty(nbytes, dtype=torch.uint8),
            field=MultiModalSharedField(batch_size=1),
        )
        return MultiModalKwargsItem.from_elems([mm_elem])

    @staticmethod
    def from_elems(elems: Sequence[MultiModalFieldElem]):
        """Build an item that maps each element's `key` to the element."""
        return MultiModalKwargsItem({elem.key: elem for elem in elems})

    def __init__(self, data: Mapping[str, MultiModalFieldElem] | None = None) -> None:
        """
        Args:
            data: The elements of this item, keyed by keyword-argument name.

        Raises:
            AssertionError: If the elements do not share exactly one modality
                (including when `data` is empty).
        """
        # `None` instead of a mutable `{}` default; `UserDict` treats `None`
        # as "no initial data".
        super().__init__(data)

        modalities = {elem.modality for elem in self.values()}
        assert len(modalities) == 1, f"Found different modalities={modalities}"
        self._modality = next(iter(modalities))

    @property
    def modality(self) -> str:
        """The modality shared by all elements of this item."""
        return self._modality

    def get_data(self) -> dict[str, NestedTensors]:
        """Return each element's data, keyed by element key."""
        return {key: elem.data for key, elem in self.items()}


# Item type stored in `MultiModalKwargsItems`: either an item that is always
# present, or one that may be `None`.
_I = TypeVar(
    "_I",
    MultiModalKwargsItem,
    MultiModalKwargsItem | None,
    default=MultiModalKwargsItem,
)


class MultiModalKwargsItems(UserDict[str, Sequence[_I]]):
    """
    A dictionary of
    [`MultiModalKwargsItem`][vllm.multimodal.inputs.MultiModalKwargsItem]s
    by modality.
    """

    @staticmethod
    def from_hf_inputs(
        hf_inputs: "BatchFeature",
        config_by_key: Mapping[str, MultiModalFieldConfig],
    ):
        """
        Split the outputs of a HuggingFace processor into per-item kwargs.

        Args:
            hf_inputs: The outputs of the HuggingFace processor.
            config_by_key: For each keyword argument, how its data is split
                into elements. Keys of `hf_inputs` missing here are ignored,
                as are keys whose data is `None` or yields no elements.

        Raises:
            ValueError: If the keyword arguments of one modality yield
                different numbers of elements.
        """
        # NOTE: This skips fields in `hf_inputs` that are not in `config_by_key`
        # We assume that those fields are not used in vLLM
        elems_by_key = dict[str, Sequence[MultiModalFieldElem]]()
        keys_by_modality = defaultdict[str, set[str]](set)
        for key, config in config_by_key.items():
            batch = hf_inputs.get(key)
            if batch is not None:
                elems = config.build_elems(key, batch)
                if len(elems) > 0:
                    elems_by_key[key] = elems
                    keys_by_modality[config.modality].add(key)

        items = list[MultiModalKwargsItem]()
        for modality, keys in keys_by_modality.items():
            elems_in_modality = {k: elems_by_key[k] for k in keys}
            batch_sizes = {k: len(v) for k, v in elems_in_modality.items()}

            if len(set(batch_sizes.values())) > 1:
                raise ValueError(
                    f"Cannot merge different batch sizes for {modality=}! "
                    f"Found: {batch_sizes=}"
                )

            # Item `i` of this modality combines the i-th element of each key.
            batch_size = next(iter(batch_sizes.values()))
            for item_idx in range(batch_size):
                elems = [v[item_idx] for v in elems_in_modality.values()]
                items.append(MultiModalKwargsItem.from_elems(elems))

        return MultiModalKwargsItems.from_seq(items)

    @staticmethod
    def from_seq(items: Sequence[MultiModalKwargsItem]):
        """Group `items` by their modality."""
        items_by_modality = full_groupby(items, key=lambda x: x.modality)
        return MultiModalKwargsItems(items_by_modality)

    def __getitem__(self, modality: str) -> Sequence[_I]:
        """
        Return the items of `modality`.

        Raises:
            KeyError: If there are no items for `modality`; the message lists
                the available modalities.
        """
        if modality not in self:
            raise KeyError(
                f"Modality {modality!r} not found. "
                f"Available modalities: {set(self.keys())}"
            )

        return super().__getitem__(modality)  # type: ignore[return-value]

    def require_data(self) -> "MultiModalKwargsItems[MultiModalKwargsItem]":
        """
        Return `self`, typed as containing no `None` items.

        Raises:
            RuntimeError: If any item is `None`.
        """
        for modality, items in self.items():
            for i, item in enumerate(items):
                if item is None:
                    raise RuntimeError(f"Found empty mm_items[{modality}][{i}]")

        return self  # type: ignore[return-value]

    def get_data(
        self,
        *,
        device: torch.types.Device = None,
        pin_memory: bool = False,
    ) -> BatchedTensorInputs:
        """
        Construct a dictionary of keyword arguments to pass to the model.

        Elements sharing a key (across all items and modalities) are merged
        with the `reduce_data` of the first element's field.

        Args:
            device: Passed to `reduce_data` for every key.
            pin_memory: Passed to `reduce_data` for every key.

        Raises:
            RuntimeError: If any item is `None`.
        """
        elems_by_key = defaultdict[str, list[MultiModalFieldElem]](list)
        for modality, items in self.items():
            for i, item in enumerate(items):
                if item is None:
                    raise RuntimeError(
                        f"Cannot build data from empty mm_items[{modality}][{i}]"
                    )

                for key, elem in item.items():
                    elems_by_key[key].append(elem)

        data = {
            key: elems[0].field.reduce_data(
                elems,
                device=device,
                pin_memory=pin_memory,
            )
            for key, elems in elems_by_key.items()
        }

        return data


MultiModalKwargsOptionalItems: TypeAlias = (
    MultiModalKwargsItems[MultiModalKwargsItem]
    | MultiModalKwargsItems[MultiModalKwargsItem | None]
)
"""
Multi-modal kwargs items whose individual items may or may not be `None`.
"""


MultiModalPlaceholderDict: TypeAlias = Mapping[str, Sequence[PlaceholderRange]]
"""
A dictionary containing placeholder ranges for each modality.
"""


class MultiModalInputs(TypedDict):
    """
    Represents the outputs of
    [`BaseMultiModalProcessor`][vllm.multimodal.processing.BaseMultiModalProcessor],
    ready to be passed to vLLM internals.
    """

    type: Literal["multimodal"]
    """The type of inputs."""

    prompt_token_ids: list[int]
    """The processed token IDs, which include placeholder tokens."""

    mm_kwargs: MultiModalKwargsOptionalItems
    """Keyword arguments to be directly passed to the model after batching."""

    mm_hashes: "MultiModalHashes"
    """The hashes of the multi-modal data."""

    mm_placeholders: "MultiModalPlaceholderDict"
    """
    For each modality, information about the placeholder tokens in
    `prompt_token_ids`.
    """

    cache_salt: NotRequired[str]
    """
    Optional cache salt to be used for prefix caching.
    """


class MultiModalEncDecInputs(MultiModalInputs):
    """
    Represents the outputs of
    [`EncDecMultiModalProcessor`][vllm.multimodal.processing.EncDecMultiModalProcessor]
    ready to be passed to vLLM internals.
    """

    encoder_prompt_token_ids: list[int]
    """The processed token IDs of the encoder prompt."""