inputs.py 30.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from abc import ABC, abstractmethod
5
from collections import UserDict, defaultdict
6
from collections.abc import Mapping, Sequence
7
from dataclasses import dataclass
8
from functools import cached_property, partial
9
from itertools import accumulate
10
11
12
13
14
15
16
17
18
19
20
from typing import (
    TYPE_CHECKING,
    Any,
    Literal,
    Optional,
    TypeAlias,
    TypedDict,
    Union,
    cast,
    final,
)
21
22

import numpy as np
23
from typing_extensions import NotRequired, TypeVar
24

25
from vllm.utils.collection_utils import full_groupby, is_list_of
26
from vllm.utils.import_utils import LazyLoader
27
from vllm.utils.jsontree import json_map_leaves
28

29
if TYPE_CHECKING:
30
31
32
33
34
    import torch
    import torch.types
    from PIL.Image import Image
    from transformers.feature_extraction_utils import BatchFeature

35
    from .base import MediaWithBytes
36
37
    from .processing import MultiModalHashes

38
39
else:
    torch = LazyLoader("torch", globals(), "torch")
40

41
42
_T = TypeVar("_T")

43
HfImageItem: TypeAlias = Union["Image", np.ndarray, "torch.Tensor"]
44
"""
45
A `transformers.image_utils.ImageInput` representing a single image
46
item, which can be passed to a HuggingFace `ImageProcessor`.
47
48
"""

49
50
51
HfVideoItem: TypeAlias = Union[
    list["Image"], np.ndarray, "torch.Tensor", list[np.ndarray], list["torch.Tensor"]
]
52
"""
53
A `transformers.image_utils.VideoInput` representing a single video
54
item, which can be passed to a HuggingFace `VideoProcessor`.
55
56
"""

57
HfAudioItem: TypeAlias = Union[list[float], np.ndarray, "torch.Tensor"]
58
"""
59
Represents a single audio
60
item, which can be passed to a HuggingFace `AudioProcessor`.
61
62
"""

63
ImageItem: TypeAlias = Union[HfImageItem, "torch.Tensor", "MediaWithBytes[HfImageItem]"]
64
"""
65
A `transformers.image_utils.ImageInput` representing a single image
66
item, which can be passed to a HuggingFace `ImageProcessor`.
67
68
69
70
71
72

Alternatively, a 3-D tensor or batch of 2-D tensors,
which are treated as image embeddings;
these are directly passed to the model without HF processing.
"""

73
74
75
VideoItem: TypeAlias = Union[
    HfVideoItem, "torch.Tensor", tuple[HfVideoItem, dict[str, Any]]
]
76
"""
77
78
79
A `transformers.video_utils.VideoInput` representing a single video item. 
This can be passed to a HuggingFace `VideoProcessor` 
with `transformers.video_utils.VideoMetadata`.
80
81
82
83
84
85

Alternatively, a 3-D tensor or batch of 2-D tensors,
which are treated as video embeddings;
these are directly passed to the model without HF processing.
"""

86
AudioItem: TypeAlias = Union[HfAudioItem, tuple[np.ndarray, float], "torch.Tensor"]
87
88
"""
Represents a single audio
89
item, which can be passed to a HuggingFace `AudioProcessor`.
90
91
92
93
94
95
96
97
98
99

Alternatively, a tuple `(audio, sampling_rate)`, where the sampling rate
is different from that expected by the model;
these are resampled to the model's sampling rate before being processed by HF.

Alternatively, a 3-D tensor or batch of 2-D tensors,
which are treated as audio embeddings;
these are directly passed to the model without HF processing.
"""

100
ModalityData: TypeAlias = _T | list[_T | None] | None
101
"""
102
103
Either a single data item, or a list of data items. Can only be None if UUID
is provided.
104
105

The number of data items allowed per modality is restricted by
106
`--limit-mm-per-prompt`.
107
108
109
110
111
112
113
"""


@final
class MultiModalDataBuiltins(TypedDict, total=False):
    """
    Type annotations for modality types predefined by vLLM.

    The TypedDict is declared with `total=False`, so every key is optional.
    Each value is a `ModalityData` wrapping the item type of that modality.
    """

    image: ModalityData[ImageItem]
    """The input image(s)."""

    video: ModalityData[VideoItem]
    """The input video(s)."""

    audio: ModalityData[AudioItem]
    """The input audio(s)."""


124
MultiModalDataDict: TypeAlias = Mapping[str, ModalityData[Any]]
125
126
"""
A dictionary containing an entry for each modality type to input.
127

128
129
The built-in modalities are defined by
[`MultiModalDataBuiltins`][vllm.multimodal.inputs.MultiModalDataBuiltins].
130
131
"""

132
MultiModalUUIDDict: TypeAlias = Mapping[str, list[str | None] | str]
133
134
135
136
137
138
139
140
141
"""
A dictionary containing user-provided UUIDs for items in each modality.
If a UUID for an item is not provided, its entry will be `None` and
MultiModalHasher will compute a hash for the item.

The UUID will be used to identify the item for all caching purposes
(input processing caching, embedding caching, prefix caching, etc).
"""

142

143
144
@dataclass(frozen=True)
class PlaceholderRange:
    """
    Placeholder location information for multi-modal data.

    Example:

    Prompt: `AAAA BBBB What is in these images?`

    Images A and B will have:

    ```
    A: PlaceholderRange(offset=0, length=4)
    B: PlaceholderRange(offset=5, length=4)
    ```
    """

    offset: int
    """The start index of the placeholder in the prompt."""

    length: int
    """The length of the placeholder."""

    is_embed: Optional["torch.Tensor"] = None
    """
    A boolean mask of shape `(length,)` indicating which positions
    between `offset` and `offset + length` to assign embeddings to.
    """

    @cached_property
    def embeds_cumsum(self) -> torch.Tensor | None:
        """
        Cumulative sum of `is_embed` along dim 0, or `None` when no mask is set.

        Entry `i` is the number of embedding positions in `is_embed[: i + 1]`.
        """
        return None if self.is_embed is None else self.is_embed.cumsum(dim=0)

    @cached_property
    def get_num_embeds(self) -> int:
        """
        Number of positions that receive an embedding.

        This is a cached property (accessed without parentheses). It equals
        `length` when `is_embed` is `None`, and the number of `True` entries
        in the mask otherwise (0 for an empty mask).
        """
        cumsum = self.embeds_cumsum
        if cumsum is None:
            return self.length

        # An empty mask has no last element to read the total from.
        if cumsum.numel() == 0:
            return 0

        return int(cumsum[-1])

    def get_embeds_indices_in_range(
        self, start_idx: int, end_idx: int
    ) -> tuple[int, int]:
        """
        Returns the starting and ending indices of the embeddings of encoder outputs
        in the range of [start_idx, end_idx) in the placeholders.

        For example, given:
        PlaceholderRange(offset=2, length=5, is_embed=[False, True, False, True, True])

        If start_idx=3 and end_idx=5, the output is (1, 3) because we want to get
        the second and the third embeddings from the encoder output.

        Args:
            start_idx: Inclusive start position, relative to the placeholder.
            end_idx: Exclusive end position, relative to the placeholder.

        Returns:
            `(embeds_start_idx, embeds_end_idx)`; the inputs are returned
            unchanged when `is_embed` is `None`.
        """
        cumsum = self.embeds_cumsum
        if cumsum is None:
            return start_idx, end_idx

        # `cumsum[i - 1]` counts the embeddings located strictly before
        # position `i`. A non-positive index has nothing before it, so it maps
        # to 0 rather than wrapping around to the end of the tensor.
        embeds_start_idx = int(cumsum[start_idx - 1]) if start_idx > 0 else 0
        embeds_end_idx = int(cumsum[end_idx - 1]) if end_idx > 0 else 0

        return embeds_start_idx, embeds_end_idx

    def extract_embeds_range(self) -> list[tuple[int, int]]:
        """Extract the start and end indices of the embedded regions in prompt.

        For example, given `PlaceholderRange(offset=2, length=5)` and
        `is_embed = [False, True, False, True, True]`, the output is
        `[(1 + offset, 1 + offset), (3 + offset, 4 + offset)]`.

        Returns:
            A list of tuples `(start, end)`, one per contiguous run of `True`
            values in `is_embed`, holding the start and end indices
            (inclusive) of that embedded region in the prompt.
            Returns the full placeholder range if `is_embed` is `None`.
        """
        if self.is_embed is None:
            return [(self.offset, self.offset + self.length - 1)]

        mask_i = self.is_embed.int()
        # A run starts where the mask goes 0 -> 1 (with an implicit leading 0)
        # and ends where it goes 1 -> 0 (with an implicit trailing 0).
        starts = torch.nonzero(
            torch.diff(mask_i, prepend=mask_i.new_zeros(1)) == 1
        ).flatten()
        ends = torch.nonzero(
            torch.diff(mask_i, append=mask_i.new_zeros(1)) == -1
        ).flatten()
        ranges = torch.stack((starts, ends), dim=1) + self.offset
        return [tuple(x) for x in ranges.tolist()]

    def __eq__(self, other: object) -> bool:
        """
        Compare offset, length and the `is_embed` mask.

        Two ranges without a mask are equal when offset and length match;
        a range with a mask never equals one without.
        """
        if not isinstance(other, self.__class__):
            return False
        if not (self.offset, self.length) == (other.offset, other.length):
            return False

        if self.is_embed is None or other.is_embed is None:
            return self.is_embed is None and other.is_embed is None

        return nested_tensors_equal(self.is_embed, other.is_embed)

244

245
246
247
248
249
250
NestedTensors: TypeAlias = Union[
    list["NestedTensors"],
    list["torch.Tensor"],
    "torch.Tensor",
    tuple["torch.Tensor", ...],
]
251
252
253
254
"""
Uses a list instead of a tensor if the dimensions of each element do not match.
"""

255
256

def nested_tensors_equal(a: NestedTensors, b: NestedTensors) -> bool:
    """
    Equality check between
    [`NestedTensors`][vllm.multimodal.inputs.NestedTensors] objects.

    Tensors are compared with `torch.equal`. Lists and tuples are compared
    element-wise and must have the same container type and the same length.
    Anything else is compared with `==`.

    Args:
        a: The first nested structure.
        b: The second nested structure.

    Returns:
        `True` if both structures have the same nesting and equal leaves.
    """
    if isinstance(a, torch.Tensor):
        return isinstance(b, torch.Tensor) and torch.equal(a, b)
    if isinstance(b, torch.Tensor):
        # `a` is not a tensor here, so the structures differ.
        return False

    for seq_type in (list, tuple):
        if isinstance(a, seq_type) or isinstance(b, seq_type):
            # The explicit length check matters: `zip` stops at the shorter
            # input, so a strict prefix would otherwise compare equal.
            return (
                isinstance(a, seq_type)
                and isinstance(b, seq_type)
                and len(a) == len(b)
                and all(nested_tensors_equal(a_, b_) for a_, b_ in zip(a, b))
            )

    # Both a and b are scalars
    return a == b


279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
def _nested_tensors_h2d(
    tensors: NestedTensors,
    device: torch.types.Device,
) -> NestedTensors:
    """
    Move every tensor leaf of a nested structure to `device`.

    Args:
        tensors: The nested structure to transfer.
        device: The target device. When `None`, `tensors` is returned as-is.

    Returns:
        The result of mapping the transfer over all leaves; leaves that are
        not tensors are passed through unchanged.
    """
    if device is None:
        return tensors

    def _to_device(leaf):
        if isinstance(leaf, torch.Tensor):
            # The copy is requested as non-blocking.
            return leaf.to(device=device, non_blocking=True)
        return leaf

    return json_map_leaves(_to_device, tensors)


296
BatchedTensorInputs: TypeAlias = dict[str, NestedTensors]
297
298
"""
A dictionary containing nested tensors which have been batched via
299
[`MultiModalKwargsItems.get_data`][vllm.multimodal.inputs.MultiModalKwargsItems.get_data].
300
301
302
"""


303
304
305
306
307
def batched_tensors_equal(a: BatchedTensorInputs, b: BatchedTensorInputs) -> bool:
    """
    Equality check between
    [`BatchedTensorInputs`][vllm.multimodal.inputs.BatchedTensorInputs] objects.

    Both dictionaries must have exactly the same keys, and the values under
    each key must be equal according to `nested_tensors_equal`.

    Args:
        a: The first batched input dictionary.
        b: The second batched input dictionary.

    Returns:
        `True` if the key sets match and all corresponding values are equal.
    """
    # Compare key sets first so that extra keys on either side are detected;
    # iterating over `a` alone would ignore keys that exist only in `b`.
    if a.keys() != b.keys():
        return False

    return all(nested_tensors_equal(a[k], b[k]) for k in a)
309
310


311
312
313
314
@dataclass
class MultiModalFeatureSpec:
    """
    A single multimodal input together with its processed data and metadata.

    Used by the V1 engine to track multimodal data through processing and
    caching. A request containing multiple multimodal items will have one
    MultiModalFeatureSpec per item.
    """

    data: Optional["MultiModalKwargsItem"]
    """Multimodal data for this feature"""

    modality: str
    """Based on the input, e.g., "image", "audio", "video"."""

    identifier: str
    """mm_hash or uuid for caching encoder outputs."""

    mm_position: PlaceholderRange
    """e.g., PlaceholderRange(offset=2, length=336)"""

    mm_hash: str | None = None
    """Base mm_hash for processor cache (without LoRA prefix)."""

    @staticmethod
    def gather_kwargs(features: list["MultiModalFeatureSpec"], keys: set[str]):
        """
        Collect the data of the requested keys across several features.

        Args:
            features: The features to read from; features whose `data` is
                `None` are skipped.
            keys: The field keys to collect.

        Returns:
            A dict mapping each key found in at least one feature to the list
            of `.data` values stored under that key, in feature order.
        """
        gathered: defaultdict[str, list[NestedTensors]] = defaultdict(list)

        for feature in features:
            mm_item = feature.data
            if mm_item is None:
                continue

            for name in keys:
                if name in mm_item:
                    gathered[name].append(mm_item[name].data)

        return dict(gathered)

349

350
@dataclass
class MultiModalFieldElem:
    """
    One keyword argument stored inside a
    [`MultiModalKwargsItem`][vllm.multimodal.inputs.MultiModalKwargsItem].
    """

    modality: str
    """
    The modality of the corresponding multi-modal item.
    Each multi-modal item can consist of multiple keyword arguments.
    """

    key: str
    """
    The key of this field in
    [`MultiModalKwargsItem`][vllm.multimodal.inputs.MultiModalKwargsItem],
    i.e. the name of the keyword argument to be passed to the model.
    """

    data: NestedTensors
    """
    The tensor data of this field in
    [`MultiModalKwargsItem`][vllm.multimodal.inputs.MultiModalKwargsItem],
    i.e. the value of the keyword argument to be passed to the model.

    It may be set to `None` if it is determined that the item is cached
    in `EngineCore`.
    """

    field: "BaseMultiModalField"
    """
    Defines how to combine the tensor data of this field with others
    in order to batch multi-modal items together for model inference.
    """

    def __eq__(self, other: object) -> bool:
        """
        Compare modality, key, data and the *type* of `field`.

        Data is compared with `nested_tensors_equal`, except that `None`
        only equals `None`. The field instances themselves are not compared.
        """
        if not isinstance(other, self.__class__):
            return False

        if self.data is None or other.data is None:
            same_data = self.data is None and other.data is None
        else:
            same_data = nested_tensors_equal(self.data, other.data)

        same_name = self.modality == other.modality and self.key == other.key
        same_field_type = type(self.field) is type(other.field)  # noqa: E721

        return same_name and same_data and same_field_type
402
403


404
@dataclass(frozen=True, kw_only=True)
405
class BaseMultiModalField(ABC):
406
    """
407
408
409
    Defines how to interpret tensor data belonging to a keyword argument for
    [`MultiModalKwargsItems`][vllm.multimodal.inputs.MultiModalKwargsItems],
    and vice versa.
410
411
    """

412
413
414
415
416
417
    keep_on_cpu: bool = False
    """
    If `True`, then this field is excluded from being moved to the accelerator
    when `MultiModalKwargsItems.get_data()` is called to batch the data.
    """

418
419
420
421
422
423
424
425
426
427
428
429
430
    def _field_factory(self, *, modality: str, key: str):
        f = partial(
            MultiModalFieldElem,
            modality=modality,
            key=key,
            field=self,
        )

        # Allow passing data as positional argument
        def factory(data: NestedTensors) -> MultiModalFieldElem:
            return f(data=data)

        return factory
431
432

    @abstractmethod
433
434
435
436
437
438
439
    def build_elems(
        self,
        modality: str,
        key: str,
        data: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        """
440
441
442
        Construct
        [`MultiModalFieldElem`][vllm.multimodal.inputs.MultiModalFieldElem]
        instances to represent the provided data.
443

444
445
        This is the inverse of
        [`reduce_data`][vllm.multimodal.inputs.BaseMultiModalField.reduce_data].
446
        """
447
448
        raise NotImplementedError

449
    @abstractmethod
450
451
452
453
454
455
    def _reduce_data(
        self,
        batch: list[NestedTensors],
        *,
        pin_memory: bool,
    ) -> NestedTensors:
456
        raise NotImplementedError
457

458
459
460
461
    def reduce_data(
        self,
        elems: list[MultiModalFieldElem],
        *,
462
        device: torch.types.Device = None,
463
464
        pin_memory: bool = False,
    ) -> NestedTensors:
465
        """
466
467
        Merge the data from multiple instances of
        [`MultiModalFieldElem`][vllm.multimodal.inputs.MultiModalFieldElem].
468

469
470
        This is the inverse of
        [`build_elems`][vllm.multimodal.inputs.BaseMultiModalField.build_elems].
471
472
473
474
        """
        field_types = [type(item.field) for item in elems]
        if len(set(field_types)) > 1:
            raise ValueError(f"Cannot merge different {field_types=}")
475

476
477
478
479
480
        if device is not None and self.keep_on_cpu:
            device = "cpu"
        if pin_memory and self.keep_on_cpu:
            pin_memory = False

481
        batch = [elem.data for elem in elems]
482
483
        out = self._reduce_data(batch, pin_memory=pin_memory)
        return _nested_tensors_h2d(out, device=device)
484
485


486
@dataclass(frozen=True, kw_only=True)
class MultiModalBatchedField(BaseMultiModalField):
    """
    Field whose elements are obtained by iterating over the data.

    Info:
        [`MultiModalFieldConfig.batched`][vllm.multimodal.inputs.MultiModalFieldConfig.batched]
    """

    def build_elems(
        self,
        modality: str,
        key: str,
        data: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        """Create one element per entry yielded by iterating over `data`."""
        make_elem = self._field_factory(modality=modality, key=key)
        return list(map(make_elem, data))

    def _reduce_data(
        self,
        batch: list[NestedTensors],
        *,
        pin_memory: bool,
    ) -> NestedTensors:
        """
        Stack the batch into one tensor when every entry is a tensor of the
        same shape; otherwise return the batch list unchanged.
        """
        if len(batch) == 0 or not is_list_of(batch, torch.Tensor, check="all"):
            return batch

        tensors = cast(list[torch.Tensor], batch)
        if len(tensors) == 1:
            # An optimization when `batch` contains only one tensor:
            # - produce exactly same result as `torch.stack(batch)`
            # - will achieve zero-copy if the tensor is contiguous
            return tensors[0].unsqueeze(0).contiguous()

        head = tensors[0]
        if any(t.shape != head.shape for t in tensors):
            # Mismatched shapes cannot be stacked.
            return tensors

        out = torch.empty(
            (len(tensors), *head.shape),
            dtype=head.dtype,
            device=head.device,
            pin_memory=pin_memory,
        )
        return torch.stack(tensors, out=out)


528
@dataclass(frozen=True, kw_only=True)
class MultiModalFlatField(BaseMultiModalField):
    """
    Field whose elements are obtained by slicing the data.

    Info:
        [`MultiModalFieldConfig.flat`][vllm.multimodal.inputs.MultiModalFieldConfig.flat]
        [`MultiModalFieldConfig.flat_from_sizes`][vllm.multimodal.inputs.MultiModalFieldConfig.flat_from_sizes]
    """

    slices: Sequence[slice] | Sequence[Sequence[slice]]
    dim: int = 0

    def build_elems(
        self,
        modality: str,
        key: str,
        data: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        """Create one element per entry of `slices` by indexing into `data`."""
        make_elem = self._field_factory(modality=modality, key=key)

        only_plain_slices = is_list_of(self.slices, slice, check="all")
        if not only_plain_slices:
            assert isinstance(data, torch.Tensor), (
                "torch.Tensor is required for multiple slices"
            )

        elems = []
        for s in self.slices:
            elems.append(make_elem(data[cast(slice, s)]))
        return elems

    def _reduce_data(
        self,
        batch: list[NestedTensors],
        *,
        pin_memory: bool,
    ) -> NestedTensors:
        """
        Concatenate the batch along `dim` when every entry is a tensor and the
        shapes agree outside of `dim`; otherwise flatten one nesting level
        into a list (which requires `dim == 0`).
        """
        if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"):
            tensors = cast(list[torch.Tensor], batch)
            if len(tensors) == 1:
                # An optimization when `batch` contains only one tensor:
                # - produce exactly same result as `torch.concat(batch)`
                # - will achieve zero-copy if the tensor is contiguous
                return tensors[0].contiguous()

            head = tensors[0]
            # Turn a negative dim into its non-negative equivalent so that it
            # can be used to split the shape tuples below.
            dim = self.dim if self.dim >= 0 else self.dim + head.ndim

            lead_shape = head.shape[:dim]
            tail_shape = head.shape[dim + 1 :]
            concatenable = all(
                t.shape[:dim] == lead_shape and t.shape[dim + 1 :] == tail_shape
                for t in tensors
            )

            if concatenable:
                concat_size = sum(t.shape[dim] for t in tensors)
                out = torch.empty(
                    (*lead_shape, concat_size, *tail_shape),
                    dtype=head.dtype,
                    device=head.device,
                    pin_memory=pin_memory,
                )
                return torch.concat(tensors, dim=self.dim, out=out)

        assert self.dim == 0, "dim == 0 is required for nested list"

        flattened = []
        for elem in batch:
            flattened.extend(elem)
        return flattened
586
587


588
@dataclass(frozen=True, kw_only=True)
class MultiModalSharedField(BaseMultiModalField):
    """
    Field whose data is shared by every item in the batch.

    Info:
        [`MultiModalFieldConfig.shared`][vllm.multimodal.inputs.MultiModalFieldConfig.shared]
    """

    batch_size: int

    def build_elems(
        self,
        modality: str,
        key: str,
        data: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        """Return `batch_size` references to a single element wrapping `data`."""
        make_elem = self._field_factory(modality=modality, key=key)
        shared_elem = make_elem(data)
        # The same element object is repeated, not copied.
        return [shared_elem] * self.batch_size

    def _reduce_data(
        self,
        batch: list[NestedTensors],
        *,
        pin_memory: bool,
    ) -> NestedTensors:
        """Return the first entry of the batch; `pin_memory` is not used."""
        first, *_ = batch[:1] or [batch[0]]
        return first


615
@dataclass(frozen=True)
class MultiModalFieldConfig:
    """
    Pairs a field definition with the modality that uses it.

    Instances are normally created through the static factory methods
    `batched`, `flat`, `flat_from_sizes` and `shared`.
    """

    field: BaseMultiModalField
    modality: str

    @staticmethod
    def batched(modality: str, *, keep_on_cpu: bool = False):
        """
        Defines a field where an element in the batch is obtained by
        indexing into the first dimension of the underlying data.

        Args:
            modality: The modality of the multi-modal item that uses this
                keyword argument.
            keep_on_cpu: Whether to keep this field on the CPU for the model inputs.

        Example:

        ```
        Input:
            Data: [[AAAA]
                   [BBBB]
                   [CCCC]]

        Output:
            Element 1: [AAAA]
            Element 2: [BBBB]
            Element 3: [CCCC]
        ```
        """
        batched_field = MultiModalBatchedField(keep_on_cpu=keep_on_cpu)
        return MultiModalFieldConfig(field=batched_field, modality=modality)

    @staticmethod
    def flat(
        modality: str,
        slices: Sequence[slice] | Sequence[Sequence[slice]],
        dim: int = 0,
        *,
        keep_on_cpu: bool = False,
    ):
        """
        Defines a field where an element in the batch is obtained by
        slicing along dimension `dim` of the underlying data.

        Args:
            modality: The modality of the multi-modal item that uses this
                keyword argument.
            slices: For each multi-modal item, a slice (dim=0) or a tuple of
                slices (dim>0) that is used to extract the data corresponding
                to it.
            dim: The dimension to extract data, default to 0.
            keep_on_cpu: Whether to keep this field on the CPU for the model inputs.

        Example:

        ```
        Given:
            slices: [slice(0, 3), slice(3, 7), slice(7, 9)]

        Input:
            Data: [AAABBBBCC]

        Output:
            Element 1: [AAA]
            Element 2: [BBBB]
            Element 3: [CC]
        ```

        ```
        Given:
            slices: [
                (slice(None), slice(0, 3)),
                (slice(None), slice(3, 7)),
                (slice(None), slice(7, 9))]
            dim: 1

        Input:
            Data: [[A],[A],[A],[B],[B],[B],[B],[C],[C]]

        Output:
            Element 1: [[A],[A],[A]]
            Element 2: [[B],[B],[B],[B]]
            Element 3: [[C],[C]]
        ```
        """
        flat_field = MultiModalFlatField(
            slices=slices,
            dim=dim,
            keep_on_cpu=keep_on_cpu,
        )
        return MultiModalFieldConfig(field=flat_field, modality=modality)

    @staticmethod
    def flat_from_sizes(
        modality: str,
        size_per_item: "torch.Tensor",
        dim: int = 0,
        *,
        keep_on_cpu: bool = False,
    ):
        """
        Defines a field where an element in the batch is obtained by
        slicing along dimension `dim` of the underlying data, with the
        slice boundaries derived from consecutive item sizes.

        Args:
            modality: The modality of the multi-modal item that uses this
                keyword argument.
            size_per_item: For each multi-modal item, the size of the slice
                that is used to extract the data corresponding to it.
            dim: The dimension to slice, default to 0.
            keep_on_cpu: Whether to keep this field on the CPU for the model inputs.

        Raises:
            ValueError: If `size_per_item` is not a 1-D tensor.

        Example:

        ```
        Given:
            size_per_item: [3, 4, 2]

        Input:
            Data: [AAABBBBCC]

        Output:
            Element 1: [AAA]
            Element 2: [BBBB]
            Element 3: [CC]
        ```

        ```
        Given:
            size_per_item: [3, 4, 2]
            dim: 1

        Input:
            Data: [[A],[A],[A],[B],[B],[B],[B],[C],[C]]

        Output:
            Element 1: [[A],[A],[A]]
            Element 2: [[B],[B],[B],[B]]
            Element 3: [[C],[C]]
        ```

        Info:
            [`MultiModalFieldConfig.flat`][vllm.multimodal.inputs.MultiModalFieldConfig.flat]
        """
        if size_per_item.ndim != 1:
            raise ValueError(
                "size_per_item should be a 1-D tensor, "
                f"but found shape: {size_per_item.shape}"
            )

        # Running totals give the boundaries between consecutive items.
        boundaries = [0, *accumulate(size_per_item)]

        # Leading dimensions before `dim` are taken in full.
        leading = (slice(None, None, None),) * dim
        slices = [
            leading + (slice(begin, end),)
            for begin, end in zip(boundaries, boundaries[1:])
        ]

        return MultiModalFieldConfig.flat(
            modality,
            slices,
            dim=dim,
            keep_on_cpu=keep_on_cpu,
        )

    @staticmethod
    def shared(
        modality: str,
        batch_size: int,
        *,
        keep_on_cpu: bool = False,
    ):
        """
        Defines a field where an element in the batch is obtained by
        taking the entirety of the underlying data.

        This means that the data is the same for each element in the batch.

        Args:
            modality: The modality of the multi-modal item that uses this
                keyword argument.
            batch_size: The number of multi-modal items which share this data.
            keep_on_cpu: Whether to keep this field on the CPU for the model inputs.

        Example:

        ```
        Given:
            batch_size: 4

        Input:
            Data: [XYZ]

        Output:
            Element 1: [XYZ]
            Element 2: [XYZ]
            Element 3: [XYZ]
            Element 4: [XYZ]
        ```
        """
        shared_field = MultiModalSharedField(
            batch_size=batch_size,
            keep_on_cpu=keep_on_cpu,
        )
        return MultiModalFieldConfig(field=shared_field, modality=modality)

    def build_elems(
        self,
        key: str,
        batch: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        """Delegate to the field's `build_elems` using this config's modality."""
        return self.field.build_elems(self.modality, key, batch)
834
835


836
837
class MultiModalKwargsItem(UserDict[str, MultiModalFieldElem]):
    """
    A collection of
    [`MultiModalFieldElem`][vllm.multimodal.inputs.MultiModalFieldElem]
    corresponding to a data item in
    [`MultiModalDataItems`][vllm.multimodal.parse.MultiModalDataItems].

    All elements of one item must belong to the same modality.
    """

    @staticmethod
    def dummy(modality: str, nbytes: int = 1):
        """
        Convenience method for testing.

        Builds an item holding a single `"dummy"` element whose data is an
        uninitialized `uint8` tensor with `nbytes` entries.
        """
        mm_elem = MultiModalFieldElem(
            modality=modality,
            key="dummy",
            data=torch.empty(nbytes, dtype=torch.uint8),
            field=MultiModalSharedField(batch_size=1),
        )
        return MultiModalKwargsItem.from_elems([mm_elem])

    @staticmethod
    def from_elems(elems: Sequence[MultiModalFieldElem]):
        """Build an item from elements, indexing each one by its `key`."""
        return MultiModalKwargsItem({elem.key: elem for elem in elems})

    def __init__(self, data: Mapping[str, MultiModalFieldElem] | None = None) -> None:
        """
        Args:
            data: Mapping from field key to element. It must contain at least
                one element, and all elements must share one modality.
        """
        # `None` replaces the former mutable `{}` default; `UserDict` treats
        # `None` as "no initial data", so the resulting state is unchanged.
        super().__init__(data)

        modalities = {elem.modality for elem in self.values()}
        # Exactly one modality is required; this also rejects empty items.
        assert len(modalities) == 1, f"Found different modalities={modalities}"
        self._modality = next(iter(modalities))

    @property
    def modality(self) -> str:
        """The single modality shared by all elements of this item."""
        return self._modality

    def get_data(self) -> dict[str, NestedTensors]:
        """Return a dict mapping each field key to the data of its element."""
        return {key: elem.data for key, elem in self.items()}
872
873


874
875
876
_I = TypeVar(
    "_I",
    MultiModalKwargsItem,
877
    MultiModalKwargsItem | None,
878
879
880
881
882
    default=MultiModalKwargsItem,
)


class MultiModalKwargsItems(UserDict[str, Sequence[_I]]):
    """
    A dictionary of
    [`MultiModalKwargsItem`][vllm.multimodal.inputs.MultiModalKwargsItem]s
    by modality.
    """

    @staticmethod
    def from_hf_inputs(
        hf_inputs: "BatchFeature",
        config_by_key: Mapping[str, MultiModalFieldConfig],
    ):
        """
        Split HF processor outputs into per-item keyword arguments.

        Args:
            hf_inputs: The batched processor outputs.
            config_by_key: For each key to extract, how to split its data.

        Raises:
            ValueError: If the keys of one modality yield different numbers
                of items.
        """
        # NOTE: This skips fields in `hf_inputs` that are not in `config_by_key`
        # We assume that those fields are not used in vLLM
        elems_by_key: dict[str, Sequence[MultiModalFieldElem]] = {}
        keys_by_modality: defaultdict[str, set[str]] = defaultdict(set)

        for key, config in config_by_key.items():
            batch = hf_inputs.get(key)
            if batch is None:
                continue

            elems = config.build_elems(key, batch)
            if len(elems) == 0:
                continue

            elems_by_key[key] = elems
            keys_by_modality[config.modality].add(key)

        items: list[MultiModalKwargsItem] = []
        for modality, keys in keys_by_modality.items():
            elems_in_modality = {k: elems_by_key[k] for k in keys}
            batch_sizes = {k: len(v) for k, v in elems_in_modality.items()}

            if len(set(batch_sizes.values())) > 1:
                raise ValueError(
                    f"Cannot merge different batch sizes for {modality=}! "
                    f"Found: {batch_sizes=}"
                )

            # All keys agree on the size, so any one of them gives the count.
            batch_size = next(iter(batch_sizes.values()))
            items.extend(
                MultiModalKwargsItem.from_elems(
                    [v[item_idx] for v in elems_in_modality.values()]
                )
                for item_idx in range(batch_size)
            )

        return MultiModalKwargsItems.from_seq(items)

    @staticmethod
    def from_seq(items: Sequence[MultiModalKwargsItem]):
        """Group a flat sequence of items by their modality."""

        def _modality_of(item):
            return item.modality

        return MultiModalKwargsItems(full_groupby(items, key=_modality_of))

    def __getitem__(self, modality: str) -> Sequence[_I]:
        """
        Return the items of a modality.

        Raises:
            KeyError: If the modality is absent; the message lists the
                available modalities.
        """
        if modality in self:
            return super().__getitem__(modality)  # type: ignore[return-value]

        raise KeyError(
            f"Modality {modality!r} not found. "
            f"Available modalities: {set(self.keys())}"
        )

    def require_data(self) -> "MultiModalKwargsItems[MultiModalKwargsItem]":
        """
        Check that no item is `None` and return `self`.

        Raises:
            RuntimeError: On the first item that is `None`.
        """
        for modality, items in self.items():
            for i, item in enumerate(items):
                if item is not None:
                    continue
                raise RuntimeError(f"Found empty mm_items[{modality}][{i}]")

        return self  # type: ignore[return-value]

    def get_data(
        self,
        *,
        device: torch.types.Device = None,
        pin_memory: bool = False,
    ) -> BatchedTensorInputs:
        """
        Construct a dictionary of keyword arguments to pass to the model.

        Elements are grouped by key across all modalities, then merged by the
        field of the first element of each group.

        Raises:
            RuntimeError: If any item is `None`.
        """
        elems_by_key: defaultdict[str, list[MultiModalFieldElem]] = defaultdict(list)

        for modality, items in self.items():
            for i, item in enumerate(items):
                if item is None:
                    raise RuntimeError(
                        f"Cannot build data from empty mm_items[{modality}][{i}]"
                    )

                for key, elem in item.items():
                    elems_by_key[key].append(elem)

        data = {}
        for key, elems in elems_by_key.items():
            reducer = elems[0].field
            data[key] = reducer.reduce_data(
                elems,
                device=device,
                pin_memory=pin_memory,
            )

        return data
974

975

976
977
978
979
MultiModalKwargsOptionalItems: TypeAlias = (
    MultiModalKwargsItems[MultiModalKwargsItem]
    | MultiModalKwargsItems[MultiModalKwargsItem | None]
)
980
981


982
MultiModalPlaceholderDict: TypeAlias = Mapping[str, Sequence[PlaceholderRange]]
983
"""
984
A dictionary containing placeholder ranges for each modality.
985
986
987
"""


988
class MultiModalInputs(TypedDict):
    """
    Represents the outputs of
    [`BaseMultiModalProcessor`][vllm.multimodal.processing.BaseMultiModalProcessor],
    ready to be passed to vLLM internals.

    Every key is required except `cache_salt`, which is declared as
    `NotRequired`.
    """

    type: Literal["multimodal"]
    """The type of inputs."""

    prompt_token_ids: list[int]
    """The processed token IDs which includes placeholder tokens."""

    mm_kwargs: MultiModalKwargsOptionalItems
    """Keyword arguments to be directly passed to the model after batching."""

    mm_hashes: "MultiModalHashes"
    """The hashes of the multi-modal data."""

    mm_placeholders: "MultiModalPlaceholderDict"
    """
    For each modality, information about the placeholder tokens in
    `prompt_token_ids`.
    """

    cache_salt: NotRequired[str]
    """
    Optional cache salt to be used for prefix caching.
    """

1018
1019
1020

class MultiModalEncDecInputs(MultiModalInputs):
    """
    Represents the outputs of
    [`EncDecMultiModalProcessor`][vllm.multimodal.processing.EncDecMultiModalProcessor]
    ready to be passed to vLLM internals.

    Extends `MultiModalInputs` with the token IDs of the encoder prompt.
    """

    encoder_prompt_token_ids: list[int]
    """The processed token IDs of the encoder prompt."""