# inputs.py
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from abc import ABC, abstractmethod
5
from collections import UserDict, defaultdict
6
from collections.abc import Mapping, Sequence
7
from dataclasses import dataclass
8
from functools import cached_property, partial
9
from itertools import accumulate
10
11
12
13
14
15
16
17
18
19
20
from typing import (
    TYPE_CHECKING,
    Any,
    Literal,
    Optional,
    TypeAlias,
    TypedDict,
    Union,
    cast,
    final,
)
21
22

import numpy as np
23
from typing_extensions import NotRequired, TypeVar, deprecated
24

25
from vllm.utils.collection_utils import full_groupby, is_list_of
26
from vllm.utils.import_utils import LazyLoader
27
from vllm.utils.jsontree import json_map_leaves
28

29
if TYPE_CHECKING:
30
31
32
33
34
    import torch
    import torch.types
    from PIL.Image import Image
    from transformers.feature_extraction_utils import BatchFeature

35
    from .base import MediaWithBytes
36
37
    from .processing import MultiModalHashes

38
39
else:
    torch = LazyLoader("torch", globals(), "torch")
40

41
42
_T = TypeVar("_T")

43
HfImageItem: TypeAlias = Union["Image", np.ndarray, "torch.Tensor"]
44
"""
45
A `transformers.image_utils.ImageInput` representing a single image
46
item, which can be passed to a HuggingFace `ImageProcessor`.
47
48
"""

49
50
51
HfVideoItem: TypeAlias = Union[
    list["Image"], np.ndarray, "torch.Tensor", list[np.ndarray], list["torch.Tensor"]
]
52
"""
53
A `transformers.image_utils.VideoInput` representing a single video
54
item, which can be passed to a HuggingFace `VideoProcessor`.
55
56
"""

57
HfAudioItem: TypeAlias = Union[list[float], np.ndarray, "torch.Tensor"]
58
"""
59
Represents a single audio
60
item, which can be passed to a HuggingFace `AudioProcessor`.
61
62
"""

63
ImageItem: TypeAlias = Union[HfImageItem, "torch.Tensor", "MediaWithBytes[HfImageItem]"]
64
"""
65
A `transformers.image_utils.ImageInput` representing a single image
66
item, which can be passed to a HuggingFace `ImageProcessor`.
67
68
69
70
71
72

Alternatively, a 3-D tensor or batch of 2-D tensors,
which are treated as image embeddings;
these are directly passed to the model without HF processing.
"""

73
74
75
VideoItem: TypeAlias = Union[
    HfVideoItem, "torch.Tensor", tuple[HfVideoItem, dict[str, Any]]
]
76
"""
77
78
79
A `transformers.video_utils.VideoInput` representing a single video item. 
This can be passed to a HuggingFace `VideoProcessor` 
with `transformers.video_utils.VideoMetadata`.
80
81
82
83
84
85

Alternatively, a 3-D tensor or batch of 2-D tensors,
which are treated as video embeddings;
these are directly passed to the model without HF processing.
"""

86
AudioItem: TypeAlias = Union[HfAudioItem, tuple[np.ndarray, float], "torch.Tensor"]
87
88
"""
Represents a single audio
89
item, which can be passed to a HuggingFace `AudioProcessor`.
90
91
92
93
94
95
96
97
98
99

Alternatively, a tuple `(audio, sampling_rate)`, where the sampling rate
is different from that expected by the model;
these are resampled to the model's sampling rate before being processed by HF.

Alternatively, a 3-D tensor or batch of 2-D tensors,
which are treated as audio embeddings;
these are directly passed to the model without HF processing.
"""

100
ModalityData: TypeAlias = _T | list[_T | None] | None
"""
Either a single data item, or a list of data items. Can only be None if UUID
is provided.

The number of data items allowed per modality is restricted by
`--limit-mm-per-prompt`.
"""


@final
class MultiModalDataBuiltins(TypedDict, total=False):
    """Type annotations for modality types predefined by vLLM."""

    image: ModalityData[ImageItem]
    """The input image(s)."""

    video: ModalityData[VideoItem]
    """The input video(s)."""

    audio: ModalityData[AudioItem]
    """The input audio(s)."""


124
MultiModalDataDict: TypeAlias = Mapping[str, ModalityData[Any]]
"""
A dictionary containing an entry for each modality type to input.

The built-in modalities are defined by
[`MultiModalDataBuiltins`][vllm.multimodal.inputs.MultiModalDataBuiltins].
"""

MultiModalUUIDDict: TypeAlias = Mapping[str, list[str | None] | str]
"""
A dictionary containing user-provided UUIDs for items in each modality.
If a UUID for an item is not provided, its entry will be `None` and
MultiModalHasher will compute a hash for the item.

The UUID will be used to identify the item for all caching purposes
(input processing caching, embedding caching, prefix caching, etc).
"""

142

143
144
@dataclass(frozen=True)
class PlaceholderRange:
    """
    Placeholder location information for multi-modal data.

    Example:

    Prompt: `AAAA BBBB What is in these images?`

    Images A and B will have:

    ```
    A: PlaceholderRange(offset=0, length=4)
    B: PlaceholderRange(offset=5, length=4)
    ```
    """

    offset: int
    """The start index of the placeholder in the prompt."""

    length: int
    """The length of the placeholder."""

    is_embed: Optional["torch.Tensor"] = None
    """
    A boolean mask of shape `(length,)` indicating which positions
    between `offset` and `offset + length` to assign embeddings to.
    """

    @cached_property
    def embeds_cumsum(self) -> torch.Tensor | None:
        """Cumulative sum of `is_embed` along dim 0, or `None` without a mask."""
        if self.is_embed is None:
            return None

        return self.is_embed.cumsum(dim=0)

    @cached_property
    def get_num_embeds(self) -> int:
        """
        The number of positions that are assigned embeddings.

        This equals `length` when there is no `is_embed` mask.

        NOTE: Despite its name this is a cached property, so it is read as
        an attribute (`r.get_num_embeds`) and not called.
        """
        cumsum = self.embeds_cumsum
        if cumsum is None:
            return self.length

        return int(cumsum[-1])

    def get_embeds_indices_in_range(
        self, start_idx: int, end_idx: int
    ) -> tuple[int, int]:
        """
        Returns the starting and ending indices of the embeddings of encoder outputs
        in the range of [start_idx, end_idx) in the placeholders.

        Both indices are used to index `is_embed` directly, i.e. they are
        relative to the start of the placeholder.

        For example, given:
        PlaceholderRange(offset=2, length=5, is_embed=[False, True, False, True, True])

        If start_idx=3 and end_idx=5, the output is (1, 3) because we want to get
        the second and the third embeddings from the encoder output.

        Without an `is_embed` mask the inputs are returned unchanged.
        """
        cumsum = self.embeds_cumsum
        if cumsum is None:
            return start_idx, end_idx

        embeds_start_idx = int(cumsum[start_idx - 1]) if start_idx > 0 else 0
        # Guard `end_idx == 0` as well: indexing with -1 would wrap around to
        # the last element and report every embedding for an empty range.
        embeds_end_idx = int(cumsum[end_idx - 1]) if end_idx > 0 else 0

        return embeds_start_idx, embeds_end_idx

    def extract_embeds_range(self) -> list[tuple[int, int]]:
        """Extract the start and end indices of the embedded regions in prompt.

        For example, given `PlaceholderRange(offset=2, length=5)` and
        `is_embed = [False, True, False, True, True]`, the output is
        `[(1 + offset, 1 + offset), (3 + offset, 4 + offset)]`.

        Returns:
            A list of tuples `(start, end)`, one per contiguous run of
            `True` values in `is_embed`, giving the start and end
            indices (inclusive) of that embedded region.
            Returns the full placeholder range as a single tuple
            if `is_embed` is `None`.
        """
        if self.is_embed is None:
            return [(self.offset, self.offset + self.length - 1)]

        mask_i = self.is_embed.int()
        zero = mask_i.new_zeros(1)
        # A +1 step (with a leading 0) marks where a run of embeds begins.
        starts = torch.nonzero(torch.diff(mask_i, prepend=zero) == 1).flatten()
        # A -1 step (with a trailing 0) marks the last index of a run.
        ends = torch.nonzero(torch.diff(mask_i, append=zero) == -1).flatten()
        ranges = torch.stack((starts, ends), dim=1) + self.offset
        return [tuple(x) for x in ranges.tolist()]

    def __eq__(self, other: object) -> bool:
        """Compare offset, length and (element-wise) the `is_embed` mask."""
        if not isinstance(other, self.__class__):
            return False
        if (self.offset, self.length) != (other.offset, other.length):
            return False

        # If either mask is missing, the ranges are equal only if both are.
        if self.is_embed is None or other.is_embed is None:
            return self.is_embed is None and other.is_embed is None

        return nested_tensors_equal(self.is_embed, other.is_embed)

247

248
249
250
251
252
253
NestedTensors: TypeAlias = Union[
    list["NestedTensors"],
    list["torch.Tensor"],
    "torch.Tensor",
    tuple["torch.Tensor", ...],
]
254
255
256
257
"""
Uses a list instead of a tensor if the dimensions of each element do not match.
"""

258
259

def nested_tensors_equal(a: NestedTensors, b: NestedTensors) -> bool:
    """
    Equality check between
    [`NestedTensors`][vllm.multimodal.inputs.NestedTensors] objects.

    Tensors are compared with `torch.equal`. Lists and tuples are compared
    element by element and must have the same container type and the same
    length. Anything else falls back to `==`.
    """
    a_is_tensor = isinstance(a, torch.Tensor)
    b_is_tensor = isinstance(b, torch.Tensor)
    if a_is_tensor or b_is_tensor:
        return a_is_tensor and b_is_tensor and torch.equal(a, b)

    for seq_type in (list, tuple):
        a_is_seq = isinstance(a, seq_type)
        b_is_seq = isinstance(b, seq_type)
        if a_is_seq or b_is_seq:
            # The explicit length check matters: `zip` stops at the shorter
            # input, so without it `[x]` would compare equal to `[x, y]`.
            return (
                a_is_seq
                and b_is_seq
                and len(a) == len(b)
                and all(nested_tensors_equal(a_, b_) for a_, b_ in zip(a, b))
            )

    # Both a and b are scalars
    return a == b


282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
def _nested_tensors_h2d(
    tensors: NestedTensors,
    device: torch.types.Device,
) -> NestedTensors:
    """
    Move every tensor leaf of `tensors` to `device`.

    The input is returned untouched when `device` is `None`. Otherwise each
    leaf visited by `json_map_leaves` that is a `torch.Tensor` is transferred
    with `non_blocking=True`; any other leaf is passed through as-is.
    """
    if device is None:
        return tensors

    def _to_device(leaf):
        if isinstance(leaf, torch.Tensor):
            return leaf.to(device=device, non_blocking=True)
        return leaf

    return json_map_leaves(_to_device, tensors)


299
BatchedTensorInputs: TypeAlias = dict[str, NestedTensors]
"""
A dictionary containing nested tensors which have been batched via
[`MultiModalKwargsItems.get_data`][vllm.multimodal.inputs.MultiModalKwargsItems.get_data].
"""


306
307
308
309
310
311
312
313
314
315
316
317
318
319
def batched_tensors_equal(a: BatchedTensorInputs, b: BatchedTensorInputs) -> bool:
    """
    Equality check between
    [`BatchedTensorInputs`][vllm.multimodal.inputs.BatchedTensorInputs] objects.

    Two inputs are equal when they have exactly the same keys and the
    values under each key are equal according to `nested_tensors_equal`.
    """
    # Compare the key sets in both directions; iterating over `a` alone
    # would treat `b` as equal even when it carries additional keys.
    if a.keys() != b.keys():
        return False

    return all(nested_tensors_equal(a[k], b[k]) for k in a)


320
321
322
323
@dataclass
class MultiModalFeatureSpec:
    """
    Represents a single multimodal input with its processed data and metadata.

    Used by the V1 engine to track multimodal data through processing and
    caching. A request containing multiple multimodal items will have one
    MultiModalFeatureSpec per item.
    """

    data: Optional["MultiModalKwargsItem"]
    """Multimodal data for this feature"""

    modality: str
    """Based on the input, e.g., "image", "audio", "video"."""

    identifier: str
    """mm_hash or uuid for caching encoder outputs."""

    mm_position: PlaceholderRange
    """e.g., PlaceholderRange(offset=2, length=336)"""

    @staticmethod
    def gather_kwargs(
        features: list["MultiModalFeatureSpec"], keys: set[str]
    ) -> dict[str, list[NestedTensors]]:
        """
        Collect the data of the requested keyword arguments across features.

        Args:
            features: The features to read from. Features whose `data`
                is `None` are skipped.
            keys: The keyword argument names to collect.

        Returns:
            A dictionary mapping each key found in at least one feature to
            the list of its `data` values, in the order of `features`.
            Keys that no feature provides are absent from the result.
        """
        gathered: defaultdict[str, list[NestedTensors]] = defaultdict(list)

        for feature in features:
            item = feature.data
            if item is None:
                continue

            for key in keys:
                if key in item:
                    gathered[key].append(item[key].data)

        return dict(gathered)

355

356
@dataclass
class MultiModalFieldElem:
    """
    Represents a keyword argument corresponding to a multi-modal item
    in [`MultiModalKwargs`][vllm.multimodal.inputs.MultiModalKwargs].
    """

    modality: str
    """
    The modality of the corresponding multi-modal item.
    Each multi-modal item can consist of multiple keyword arguments.
    """

    key: str
    """
    The key of this field in
    [`MultiModalKwargs`][vllm.multimodal.inputs.MultiModalKwargs],
    i.e. the name of the keyword argument to be passed to the model.
    """

    data: NestedTensors
    """
    The tensor data of this field in
    [`MultiModalKwargs`][vllm.multimodal.inputs.MultiModalKwargs],
    i.e. the value of the keyword argument to be passed to the model.

    It may be set to `None` if it is determined that the item is cached
    in `EngineCore`.
    """

    field: "BaseMultiModalField"
    """
    Defines how to combine the tensor data of this field with others
    in order to batch multi-modal items together for model inference.
    """

    def __eq__(self, other: object) -> bool:
        """
        Compare modality, key, data and the *type* of `field`.

        The `field` instances themselves are not compared, only their
        classes. `data` is compared with `nested_tensors_equal`, with
        `None` only being equal to `None`.
        """
        if not isinstance(other, self.__class__):
            return False

        if (self.modality, self.key) != (other.modality, other.key):
            return False
        if type(self.field) is not type(other.field):
            return False

        # If either side has no data, they match only if both have none.
        if self.data is None or other.data is None:
            return self.data is None and other.data is None

        return nested_tensors_equal(self.data, other.data)
408
409


410
@dataclass(frozen=True, kw_only=True)
411
class BaseMultiModalField(ABC):
412
413
    """
    Defines how to interpret tensor data belonging to a keyword argument in
414
415
    [`MultiModalKwargs`][vllm.multimodal.inputs.MultiModalKwargs] for multiple
    multi-modal items, and vice versa.
416
417
    """

418
419
420
421
422
423
    keep_on_cpu: bool = False
    """
    If `True`, then this field is excluded from being moved to the accelerator
    when `MultiModalKwargsItems.get_data()` is called to batch the data.
    """

424
425
426
427
428
429
430
431
432
433
434
435
436
    def _field_factory(self, *, modality: str, key: str):
        f = partial(
            MultiModalFieldElem,
            modality=modality,
            key=key,
            field=self,
        )

        # Allow passing data as positional argument
        def factory(data: NestedTensors) -> MultiModalFieldElem:
            return f(data=data)

        return factory
437
438

    @abstractmethod
439
440
441
442
443
444
445
    def build_elems(
        self,
        modality: str,
        key: str,
        data: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        """
446
447
448
        Construct
        [`MultiModalFieldElem`][vllm.multimodal.inputs.MultiModalFieldElem]
        instances to represent the provided data.
449

450
451
        This is the inverse of
        [`reduce_data`][vllm.multimodal.inputs.BaseMultiModalField.reduce_data].
452
        """
453
454
        raise NotImplementedError

455
    @abstractmethod
456
457
458
459
460
461
    def _reduce_data(
        self,
        batch: list[NestedTensors],
        *,
        pin_memory: bool,
    ) -> NestedTensors:
462
        raise NotImplementedError
463

464
465
466
467
    def reduce_data(
        self,
        elems: list[MultiModalFieldElem],
        *,
468
        device: torch.types.Device = None,
469
470
        pin_memory: bool = False,
    ) -> NestedTensors:
471
        """
472
473
        Merge the data from multiple instances of
        [`MultiModalFieldElem`][vllm.multimodal.inputs.MultiModalFieldElem].
474

475
476
        This is the inverse of
        [`build_elems`][vllm.multimodal.inputs.BaseMultiModalField.build_elems].
477
478
479
480
        """
        field_types = [type(item.field) for item in elems]
        if len(set(field_types)) > 1:
            raise ValueError(f"Cannot merge different {field_types=}")
481

482
483
484
485
486
        if device is not None and self.keep_on_cpu:
            device = "cpu"
        if pin_memory and self.keep_on_cpu:
            pin_memory = False

487
        batch = [elem.data for elem in elems]
488
489
        out = self._reduce_data(batch, pin_memory=pin_memory)
        return _nested_tensors_h2d(out, device=device)
490
491


492
@dataclass(frozen=True, kw_only=True)
class MultiModalBatchedField(BaseMultiModalField):
    """
    A field whose elements are obtained by iterating over the data.

    Info:
        [`MultiModalFieldConfig.batched`][vllm.multimodal.inputs.MultiModalFieldConfig.batched]
    """

    def build_elems(
        self,
        modality: str,
        key: str,
        data: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        """Create one element for each item yielded by iterating `data`."""
        field_factory = self._field_factory(modality=modality, key=key)
        return [field_factory(item) for item in data]

    def _reduce_data(
        self,
        batch: list[NestedTensors],
        *,
        pin_memory: bool,
    ) -> NestedTensors:
        """
        Stack the batch into a single tensor when possible.

        Stacking happens only if every entry is a tensor and all tensors
        share the same shape; otherwise the batch is returned as a list.
        """
        if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"):
            batch = cast(list[torch.Tensor], batch)
            if len(batch) == 1:
                # An optimization when `batch` contains only one tensor:
                # - produce exactly same result as `torch.stack(batch)`
                # - will achieve zero-copy if the tensor is contiguous
                return batch[0].unsqueeze(0).contiguous()

            first_shape = batch[0].shape
            if all(elem.shape == first_shape for elem in batch):
                # Preallocate the output so that it can be pinned on request.
                out = torch.empty(
                    (len(batch), *batch[0].shape),
                    dtype=batch[0].dtype,
                    device=batch[0].device,
                    pin_memory=pin_memory,
                )
                return torch.stack(batch, out=out)

        return batch


534
@dataclass(frozen=True, kw_only=True)
class MultiModalFlatField(BaseMultiModalField):
    """
    A field whose elements are obtained by slicing the data.

    Info:
        [`MultiModalFieldConfig.flat`][vllm.multimodal.inputs.MultiModalFieldConfig.flat]
        [`MultiModalFieldConfig.flat_from_sizes`][vllm.multimodal.inputs.MultiModalFieldConfig.flat_from_sizes]
    """

    slices: Sequence[slice] | Sequence[Sequence[slice]]
    dim: int = 0

    def build_elems(
        self,
        modality: str,
        key: str,
        data: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        """Create one element per entry of `slices` by indexing `data` with it."""
        field_factory = self._field_factory(modality=modality, key=key)
        if not is_list_of(self.slices, slice, check="all"):
            # Entries that are not plain slices can only index a tensor.
            assert isinstance(data, torch.Tensor), (
                "torch.Tensor is required for multiple slices"
            )
        return [field_factory(data[cast(slice, s)]) for s in self.slices]

    def _reduce_data(
        self,
        batch: list[NestedTensors],
        *,
        pin_memory: bool,
    ) -> NestedTensors:
        """
        Concatenate the batch along `dim` when possible.

        Concatenation happens only if every entry is a tensor and all
        tensors agree on every dimension other than `dim`. Otherwise the
        entries are flattened into one list, which requires `dim == 0`.
        """
        if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"):
            batch = cast(list[torch.Tensor], batch)
            if len(batch) == 1:
                # An optimization when `batch` contains only one tensor:
                # - produce exactly same result as `torch.concat(batch)`
                # - will achieve zero-copy if the tensor is contiguous
                return batch[0].contiguous()

            # Convert a negative `dim` to its non-negative equivalent so
            # that the shape slicing below works.
            dim = self.dim + (self.dim < 0) * len(batch[0].shape)

            def _shape_before_after(tensor: torch.Tensor):
                return tensor.shape[:dim], tensor.shape[dim + 1 :]

            first_shape = _shape_before_after(batch[0])

            if all(_shape_before_after(elem) == first_shape for elem in batch):
                shape_before, shape_after = first_shape
                shape_concat = sum(item.shape[dim] for item in batch)
                # Preallocate the output so that it can be pinned on request.
                out = torch.empty(
                    (*shape_before, shape_concat, *shape_after),
                    dtype=batch[0].dtype,
                    device=batch[0].device,
                    pin_memory=pin_memory,
                )
                return torch.concat(batch, dim=self.dim, out=out)

        assert self.dim == 0, "dim == 0 is required for nested list"
        return [e for elem in batch for e in elem]
592
593


594
@dataclass(frozen=True, kw_only=True)
class MultiModalSharedField(BaseMultiModalField):
    """
    A field whose elements all refer to the same data.

    Info:
        [`MultiModalFieldConfig.shared`][vllm.multimodal.inputs.MultiModalFieldConfig.shared]
    """

    batch_size: int

    def build_elems(
        self,
        modality: str,
        key: str,
        data: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        """
        Return a list of length `batch_size` wrapping `data`.

        The list repeats a single element object `batch_size` times.
        """
        field_factory = self._field_factory(modality=modality, key=key)
        return [field_factory(data)] * self.batch_size

    def _reduce_data(
        self,
        batch: list[NestedTensors],
        *,
        pin_memory: bool,
    ) -> NestedTensors:
        """Return the first entry of `batch`; `pin_memory` is not used."""
        return batch[0]


621
@dataclass(frozen=True)
class MultiModalFieldConfig:
    """
    Pairs a field, which defines how data is split into and merged from
    per-item elements, with the modality that the keyword argument belongs to.
    """

    field: BaseMultiModalField
    modality: str

    @staticmethod
    def batched(modality: str, *, keep_on_cpu: bool = False):
        """
        Defines a field where an element in the batch is obtained by
        indexing into the first dimension of the underlying data.

        Args:
            modality: The modality of the multi-modal item that uses this
                keyword argument.
            keep_on_cpu: Whether to keep this field on the CPU for the model inputs.

        Example:

        ```
        Input:
            Data: [[AAAA]
                [BBBB]
                [CCCC]]

        Output:
            Element 1: [AAAA]
            Element 2: [BBBB]
            Element 3: [CCCC]
        ```
        """
        return MultiModalFieldConfig(
            field=MultiModalBatchedField(keep_on_cpu=keep_on_cpu),
            modality=modality,
        )

    @staticmethod
    def flat(
        modality: str,
        slices: Sequence[slice] | Sequence[Sequence[slice]],
        dim: int = 0,
        *,
        keep_on_cpu: bool = False,
    ):
        """
        Defines a field where an element in the batch is obtained by
        slicing the underlying data along dimension `dim`
        (the first dimension by default).

        Args:
            modality: The modality of the multi-modal item that uses this
                keyword argument.
            slices: For each multi-modal item, a slice (dim=0) or a tuple of
                slices (dim>0) that is used to extract the data corresponding
                to it.
            dim: The dimension to extract data, default to 0.
            keep_on_cpu: Whether to keep this field on the CPU for the model inputs.

        Example:

        ```
        Given:
            slices: [slice(0, 3), slice(3, 7), slice(7, 9)]

        Input:
            Data: [AAABBBBCC]

        Output:
            Element 1: [AAA]
            Element 2: [BBBB]
            Element 3: [CC]
        ```

        ```
        Given:
            slices: [
                (slice(None), slice(0, 3)),
                (slice(None), slice(3, 7)),
                (slice(None), slice(7, 9))]
            dim: 1

        Input:
            Data: [[A],[A],[A],[B],[B],[B],[B],[C],[C]]

        Output:
            Element 1: [[A],[A],[A]]
            Element 2: [[B],[B],[B],[B]]
            Element 3: [[C],[C]]
        ```
        """
        return MultiModalFieldConfig(
            field=MultiModalFlatField(
                slices=slices,
                dim=dim,
                keep_on_cpu=keep_on_cpu,
            ),
            modality=modality,
        )

    @staticmethod
    def flat_from_sizes(
        modality: str,
        size_per_item: "torch.Tensor",
        dim: int = 0,
        *,
        keep_on_cpu: bool = False,
    ):
        """
        Defines a field where an element in the batch is obtained by
        slicing the underlying data along dimension `dim`
        (the first dimension by default).

        Args:
            modality: The modality of the multi-modal item that uses this
                keyword argument.
            size_per_item: For each multi-modal item, the size of the slice
                that is used to extract the data corresponding to it.
            dim: The dimension to slice, default to 0. It is also the number
                of leading full slices placed before each item's slice, so
                a negative value does not select a dimension from the end
                here.
            keep_on_cpu: Whether to keep this field on the CPU for the model inputs.

        Raises:
            ValueError: If `size_per_item` is not 1-D.

        Example:

        ```
        Given:
            size_per_item: [3, 4, 2]

        Input:
            Data: [AAABBBBCC]

        Output:
            Element 1: [AAA]
            Element 2: [BBBB]
            Element 3: [CC]
        ```

        ```
        Given:
            size_per_item: [3, 4, 2]
            dim: 1

        Input:
            Data: [[A],[A],[A],[B],[B],[B],[B],[C],[C]]

        Output:
            Element 1: [[A],[A],[A]]
            Element 2: [[B],[B],[B],[B]]
            Element 3: [[C],[C]]
        ```

        Info:
            [`MultiModalFieldConfig.flat`][vllm.multimodal.inputs.MultiModalFieldConfig.flat]
        """
        if size_per_item.ndim != 1:
            raise ValueError(
                "size_per_item should be a 1-D tensor, "
                f"but found shape: {size_per_item.shape}"
            )

        # Convert to Python numbers once so that the slice bounds stored on
        # the field are plain values rather than 0-d tensors.
        slice_idxs = [0, *accumulate(size_per_item.tolist())]
        leading = (slice(None, None, None),) * dim
        slices = [
            leading + (slice(start, stop),)
            for start, stop in zip(slice_idxs, slice_idxs[1:])
        ]

        return MultiModalFieldConfig.flat(
            modality,
            slices,
            dim=dim,
            keep_on_cpu=keep_on_cpu,
        )

    @staticmethod
    def shared(
        modality: str,
        batch_size: int,
        *,
        keep_on_cpu: bool = False,
    ):
        """
        Defines a field where an element in the batch is obtained by
        taking the entirety of the underlying data.

        This means that the data is the same for each element in the batch.

        Args:
            modality: The modality of the multi-modal item that uses this
                keyword argument.
            batch_size: The number of multi-modal items which share this data.
            keep_on_cpu: Whether to keep this field on the CPU for the model inputs.

        Example:

        ```
        Given:
            batch_size: 4

        Input:
            Data: [XYZ]

        Output:
            Element 1: [XYZ]
            Element 2: [XYZ]
            Element 3: [XYZ]
            Element 4: [XYZ]
        ```
        """
        return MultiModalFieldConfig(
            field=MultiModalSharedField(
                batch_size=batch_size,
                keep_on_cpu=keep_on_cpu,
            ),
            modality=modality,
        )

    def build_elems(
        self,
        key: str,
        batch: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        """Delegate to `field.build_elems` using this config's modality."""
        return self.field.build_elems(self.modality, key, batch)
840
841


842
843
class MultiModalKwargsItem(UserDict[str, MultiModalFieldElem]):
    """
    A collection of
    [`MultiModalFieldElem`][vllm.multimodal.inputs.MultiModalFieldElem]
    corresponding to a data item in
    [`MultiModalDataItems`][vllm.multimodal.parse.MultiModalDataItems].
    """

    @staticmethod
    def dummy(modality: str, nbytes: int = 1):
        """
        Convenience constructor for testing.

        Builds an item with a single `"dummy"` field that holds an
        uninitialized `uint8` tensor of `nbytes` elements.
        """
        mm_elem = MultiModalFieldElem(
            modality=modality,
            key="dummy",
            data=torch.empty(nbytes, dtype=torch.uint8),
            field=MultiModalSharedField(batch_size=1),
        )
        return MultiModalKwargsItem.from_elems([mm_elem])

    @staticmethod
    def from_elems(elems: Sequence[MultiModalFieldElem]):
        """Build an item from elements, indexing each one by its `key`."""
        return MultiModalKwargsItem({elem.key: elem for elem in elems})

    def __init__(
        self,
        data: Mapping[str, MultiModalFieldElem] | None = None,
    ) -> None:
        """
        Args:
            data: The elements of this item by key. They must all belong
                to one modality, and there must be at least one, because
                the modality of the item is taken from its elements.
        """
        # `None` rather than a shared `{}` literal as the default value;
        # `UserDict` treats both the same way.
        super().__init__(data)

        modalities = {elem.modality for elem in self.values()}
        assert len(modalities) == 1, f"Found different modalities={modalities}"
        self._modality = next(iter(modalities))

    @property
    def modality(self) -> str:
        """The modality shared by all elements of this item."""
        return self._modality

    def get_data(self) -> dict[str, NestedTensors]:
        """Return the `data` of each element by key."""
        return {key: elem.data for key, elem in self.items()}
878
879


880
881
882
# Item type held by `MultiModalKwargsItems`: either always an item, or an
# item that may be `None`. Defaults to the former.
_I = TypeVar(
    "_I",
    MultiModalKwargsItem,
    MultiModalKwargsItem | None,
    default=MultiModalKwargsItem,
)


class MultiModalKwargsItems(UserDict[str, Sequence[_I]]):
    """
    Maps each modality to the sequence of
    [`MultiModalKwargsItem`][vllm.multimodal.inputs.MultiModalKwargsItem]s
    that belong to it.

    Depending on the type argument, individual entries of a sequence may
    be `None`; `require_data` and `get_data` reject such entries.
    """

    @staticmethod
    def from_hf_inputs(
        hf_inputs: "BatchFeature",
        config_by_key: Mapping[str, MultiModalFieldConfig],
    ):
        """
        Split the batched fields of `hf_inputs` into per-item collections.

        Args:
            hf_inputs: Source of the batched values, looked up by key.
            config_by_key: For each key to extract, the config that knows
                its modality and how to build its elements.

        Returns:
            The resulting items grouped by modality.

        Raises:
            ValueError: If two keys of the same modality yield a different
                number of elements.
        """
        # NOTE: This skips fields in `hf_inputs` that are not in `config_by_key`
        # We assume that those fields are not used in vLLM
        elems_by_key: dict[str, Sequence[MultiModalFieldElem]] = {}
        keys_by_modality: defaultdict[str, set[str]] = defaultdict(set)
        for key, config in config_by_key.items():
            batch = hf_inputs.get(key)
            if batch is None:
                continue

            elems = config.build_elems(key, batch)
            if len(elems) == 0:
                continue

            elems_by_key[key] = elems
            keys_by_modality[config.modality].add(key)

        items: list[MultiModalKwargsItem] = []
        for modality, keys in keys_by_modality.items():
            # One "column" of elements per key, in the iteration order of `keys`.
            columns = [elems_by_key[k] for k in keys]
            batch_sizes = {k: len(col) for k, col in zip(keys, columns)}

            distinct_sizes = set(batch_sizes.values())
            if len(distinct_sizes) > 1:
                raise ValueError(
                    f"Cannot merge different batch sizes for {modality=}! "
                    f"Found: {batch_sizes=}"
                )

            # `keys` is never empty here, so exactly one size remains.
            (batch_size,) = distinct_sizes
            items.extend(
                MultiModalKwargsItem.from_elems([col[item_idx] for col in columns])
                for item_idx in range(batch_size)
            )

        return MultiModalKwargsItems.from_seq(items)

    @staticmethod
    def from_seq(items: Sequence[MultiModalKwargsItem]):
        """Group `items` by their `modality` into a new container."""

        def modality_of(item):
            return item.modality

        return MultiModalKwargsItems(full_groupby(items, key=modality_of))

    def __getitem__(self, modality: str) -> Sequence[_I]:
        """
        Return the items stored for `modality`.

        Raises:
            KeyError: If `modality` is absent; the message lists the
                modalities that are available.
        """
        if modality in self:
            return super().__getitem__(modality)  # type: ignore[return-value]

        raise KeyError(
            f"Modality {modality!r} not found. "
            f"Available modalities: {set(self.keys())}"
        )

    def require_data(self) -> "MultiModalKwargsItems[MultiModalKwargsItem]":
        """
        Check that no entry is `None` and return `self`.

        Raises:
            RuntimeError: On the first `None` entry encountered.
        """
        for modality, items in self.items():
            # Index of the first missing entry of this modality, if any.
            i = next((idx for idx, item in enumerate(items) if item is None), None)
            if i is not None:
                raise RuntimeError(f"Found empty mm_items[{modality}][{i}]")

        return self  # type: ignore[return-value]

    def get_data(
        self,
        *,
        device: torch.types.Device = None,
        pin_memory: bool = False,
    ) -> BatchedTensorInputs:
        """
        Construct a dictionary of keyword arguments to pass to the model.

        Elements are collected per key across all modalities and items,
        then combined by the `reduce_data` method of the field of the
        first element collected for that key.

        Args:
            device: Forwarded to `reduce_data`.
            pin_memory: Forwarded to `reduce_data`.

        Raises:
            RuntimeError: If any entry is `None`.
        """
        collected: dict[str, list[MultiModalFieldElem]] = {}
        for modality, items in self.items():
            for i, item in enumerate(items):
                if item is None:
                    raise RuntimeError(
                        f"Cannot build data from empty mm_items[{modality}][{i}]"
                    )

                for key, elem in item.items():
                    collected.setdefault(key, []).append(elem)

        data = {}
        for key, elems in collected.items():
            reducer = elems[0].field
            data[key] = reducer.reduce_data(
                elems,
                device=device,
                pin_memory=pin_memory,
            )

        return data
980

981

982
983
984
985
# Either parameterization of `MultiModalKwargsItems`: the one whose entries
# are always `MultiModalKwargsItem`, or the one whose entries may be `None`.
MultiModalKwargsOptionalItems: TypeAlias = (
    MultiModalKwargsItems[MultiModalKwargsItem]
    | MultiModalKwargsItems[MultiModalKwargsItem | None]
)
986
987


988
@deprecated("`MultiModalKwargs` is deprecated and will be removed in v0.14.")
class MultiModalKwargs(UserDict[str, NestedTensors]):
    """
    A dictionary that represents the keyword arguments to
    [`torch.nn.Module.forward`][].
    """

    @staticmethod
    @deprecated(
        "`MultiModalKwargs.from_hf_inputs` is deprecated and "
        "will be removed in v0.14. "
        "Please use `MultiModalKwargsItems.from_hf_inputs` and "
        "access the tensor data using `.get_data()`."
    )
    def from_hf_inputs(
        hf_inputs: "BatchFeature",
        config_by_key: Mapping[str, MultiModalFieldConfig],
    ):
        """
        Deprecated shim around `MultiModalKwargsItems.from_hf_inputs`.

        Returns the result of calling `.get_data()` on the items built
        from `hf_inputs` and `config_by_key`.
        """
        return MultiModalKwargsItems.from_hf_inputs(hf_inputs, config_by_key).get_data()

    @staticmethod
    @deprecated(
        "`MultiModalKwargs.from_items` is deprecated and "
        "will be removed in v0.14. "
        "Please use `MultiModalKwargsItems.from_seq` and "
        "access the tensor data using `.get_data()`."
    )
    def from_items(
        items: Sequence[MultiModalKwargsItem],
        *,
        pin_memory: bool = False,
    ):
        """
        Deprecated shim around `MultiModalKwargsItems.from_seq`.

        Returns the result of calling `.get_data(pin_memory=pin_memory)`
        on the container built from `items`.
        """
        return MultiModalKwargsItems.from_seq(items).get_data(pin_memory=pin_memory)

    def __getitem__(self, key: str):
        """
        Return the value stored under `key`.

        Raises:
            KeyError: If `key` is absent; the message lists the keys that
                are available.
        """
        if key not in self:
            raise KeyError(
                f"Keyword argument {key!r} not found. "
                f"Available keys: {set(self.keys())}"
            )

        return super().__getitem__(key)

    def __eq__(self, other: object) -> bool:
        """
        Return whether `other` holds the same keys with equal tensors.

        Two instances are equal only if their key sets are identical and
        `nested_tensors_equal` holds for every key. Anything that is not
        an instance of this class compares unequal.
        """
        if not isinstance(other, self.__class__):
            return False

        # Compare the key sets in both directions. Checking only that the
        # keys of `self` exist in `other` would make `a == b` true while
        # `b == a` is false whenever `other` carries extra keys.
        if self.keys() != other.keys():
            return False

        return all(nested_tensors_equal(self[k], other[k]) for k in self)
1042

1043

1044
MultiModalPlaceholderDict: TypeAlias = Mapping[str, Sequence[PlaceholderRange]]
1045
"""
1046
A dictionary containing placeholder ranges for each modality.
1047
1048
1049
"""


1050
class MultiModalInputs(TypedDict):
    """
    Represents the outputs of
    [`BaseMultiModalProcessor`][vllm.multimodal.processing.BaseMultiModalProcessor],
    ready to be passed to vLLM internals.

    Every key is required except `cache_salt`, which is declared as
    `NotRequired`.
    """

    type: Literal["multimodal"]
    """The type of inputs."""

    prompt_token_ids: list[int]
    """The processed token IDs which includes placeholder tokens."""

    mm_kwargs: MultiModalKwargsOptionalItems
    """Keyword arguments to be directly passed to the model after batching."""

    mm_hashes: "MultiModalHashes"
    """The hashes of the multi-modal data."""

    mm_placeholders: "MultiModalPlaceholderDict"
    """
    For each modality, information about the placeholder tokens in
    `prompt_token_ids`.
    """

    cache_salt: NotRequired[str]
    """
    Optional cache salt to be used for prefix caching.
    """

1080
1081
1082

class MultiModalEncDecInputs(MultiModalInputs):
    """
    Represents the outputs of
    [`EncDecMultiModalProcessor`][vllm.multimodal.processing.EncDecMultiModalProcessor]
    ready to be passed to vLLM internals.

    Inherits every key of
    [`MultiModalInputs`][vllm.multimodal.inputs.MultiModalInputs] and adds
    the token IDs of the encoder prompt.
    """

    encoder_prompt_token_ids: list[int]
    """The processed token IDs of the encoder prompt."""