"...crd/bases/nvidia.com_dynamocomponentdeployments.yaml" did not exist on "a82f350a0c63cfa125fce23d30630ca2608d2cb6"
inputs.py 31.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from abc import ABC, abstractmethod
5
from collections import UserDict, defaultdict
6
from collections.abc import Mapping, Sequence
7
from dataclasses import dataclass
8
from functools import partial
9
from itertools import accumulate
10
11
12
13
14
15
16
17
18
19
20
from typing import (
    TYPE_CHECKING,
    Any,
    Literal,
    Optional,
    TypeAlias,
    TypedDict,
    Union,
    cast,
    final,
)
21
22

import numpy as np
23
from typing_extensions import NotRequired, TypeVar, deprecated
24

25
from vllm.utils.collection_utils import full_groupby, is_list_of
26
from vllm.utils.import_utils import LazyLoader
27
from vllm.utils.jsontree import json_map_leaves
28

29
if TYPE_CHECKING:
30
31
32
33
34
    import torch
    import torch.types
    from PIL.Image import Image
    from transformers.feature_extraction_utils import BatchFeature

35
    from .base import MediaWithBytes
36
37
    from .processing import MultiModalHashes

38
39
else:
    torch = LazyLoader("torch", globals(), "torch")
40

41
42
_T = TypeVar("_T")

43
HfImageItem: TypeAlias = Union["Image", np.ndarray, "torch.Tensor"]
44
"""
45
A `transformers.image_utils.ImageInput` representing a single image
46
item, which can be passed to a HuggingFace `ImageProcessor`.
47
48
"""

49
50
51
HfVideoItem: TypeAlias = Union[
    list["Image"], np.ndarray, "torch.Tensor", list[np.ndarray], list["torch.Tensor"]
]
52
"""
53
A `transformers.image_utils.VideoInput` representing a single video
54
item, which can be passed to a HuggingFace `VideoProcessor`.
55
56
"""

57
HfAudioItem: TypeAlias = Union[list[float], np.ndarray, "torch.Tensor"]
58
"""
59
Represents a single audio
60
item, which can be passed to a HuggingFace `AudioProcessor`.
61
62
"""

63
ImageItem: TypeAlias = Union[HfImageItem, "torch.Tensor", "MediaWithBytes[HfImageItem]"]
64
"""
65
A `transformers.image_utils.ImageInput` representing a single image
66
item, which can be passed to a HuggingFace `ImageProcessor`.
67
68
69
70
71
72

Alternatively, a 3-D tensor or batch of 2-D tensors,
which are treated as image embeddings;
these are directly passed to the model without HF processing.
"""

73
74
75
VideoItem: TypeAlias = Union[
    HfVideoItem, "torch.Tensor", tuple[HfVideoItem, dict[str, Any]]
]
76
"""
77
78
79
A `transformers.video_utils.VideoInput` representing a single video item. 
This can be passed to a HuggingFace `VideoProcessor` 
with `transformers.video_utils.VideoMetadata`.
80
81
82
83
84
85

Alternatively, a 3-D tensor or batch of 2-D tensors,
which are treated as video embeddings;
these are directly passed to the model without HF processing.
"""

86
AudioItem: TypeAlias = Union[HfAudioItem, tuple[np.ndarray, float], "torch.Tensor"]
87
88
"""
Represents a single audio
89
item, which can be passed to a HuggingFace `AudioProcessor`.
90
91
92
93
94
95
96
97
98
99

Alternatively, a tuple `(audio, sampling_rate)`, where the sampling rate
is different from that expected by the model;
these are resampled to the model's sampling rate before being processed by HF.

Alternatively, a 3-D tensor or batch of 2-D tensors,
which are treated as audio embeddings;
these are directly passed to the model without HF processing.
"""

100
ModalityData: TypeAlias = _T | list[_T | None] | None
101
"""
102
103
Either a single data item, or a list of data items. Can only be None if UUID
is provided.
104
105

The number of data items allowed per modality is restricted by
106
`--limit-mm-per-prompt`.
107
108
109
110
111
112
113
"""


@final
class MultiModalDataBuiltins(TypedDict, total=False):
    """Type annotations for modality types predefined by vLLM."""

114
    image: ModalityData[ImageItem]
115
116
    """The input image(s)."""

117
    video: ModalityData[VideoItem]
118
119
    """The input video(s)."""

120
    audio: ModalityData[AudioItem]
121
122
123
    """The input audio(s)."""


124
MultiModalDataDict: TypeAlias = Mapping[str, ModalityData[Any]]
125
126
"""
A dictionary containing an entry for each modality type to input.
127

128
129
The built-in modalities are defined by
[`MultiModalDataBuiltins`][vllm.multimodal.inputs.MultiModalDataBuiltins].
130
131
"""

132
MultiModalUUIDDict: TypeAlias = Mapping[str, list[str | None] | str]
133
134
135
136
137
138
139
140
141
"""
A dictionary containing user-provided UUIDs for items in each modality.
If a UUID for an item is not provided, its entry will be `None` and
MultiModalHasher will compute a hash for the item.

The UUID will be used to identify the item for all caching purposes
(input processing caching, embedding caching, prefix caching, etc).
"""

142

143
144
@dataclass(frozen=True)
class PlaceholderRange:
    """
    Placeholder location information for multi-modal data.

    Example:

    Prompt: `AAAA BBBB What is in these images?`

    Images A and B will have:

    ```
    A: PlaceholderRange(offset=0, length=4)
    B: PlaceholderRange(offset=5, length=4)
    ```
    """

    offset: int
    """The start index of the placeholder in the prompt."""

    length: int
    """The length of the placeholder."""

    is_embed: Optional["torch.Tensor"] = None
    """
    A boolean mask of shape `(length,)` indicating which positions
    between `offset` and `offset + length` to assign embeddings to.
    """

    def get_num_embeds(self) -> int:
        """Return how many positions of the placeholder receive embeddings.

        This is `length` when `is_embed` is `None`, otherwise the number of
        `True` entries in the mask.
        """
        if self.is_embed is None:
            return self.length

        return int(self.is_embed.sum().item())

    def extract_embeds_range(self) -> list[tuple[int, int]]:
        """Extract the embedded regions of the placeholder as prompt indices.

        For example, given `PlaceholderRange(offset=2, length=5)` and
        `is_embed = [False, True, False, True, True]`, the output is
        `[(1 + offset, 1 + offset), (3 + offset, 4 + offset)]`.

        Returns:
            A list of `(start, end)` tuples, one per contiguous run of
            embedded positions. Both values are absolute prompt indices and
            `end` is **inclusive**. If `is_embed` is `None`, the whole
            placeholder `(offset, offset + length - 1)` is returned as a
            single run (or no run at all when `length == 0`).
        """
        if self.is_embed is None:
            if self.length <= 0:
                # Same result as an empty / all-False mask.
                return []
            # Inclusive end, consistent with the masked branch below.
            return [(self.offset, self.offset + self.length - 1)]

        mask_i = self.is_embed.int()
        # A run starts where the mask rises 0 -> 1 (diff == 1) ...
        starts = torch.nonzero(
            torch.diff(mask_i, prepend=mask_i.new_zeros(1)) == 1
        ).flatten()
        # ... and ends on the last 1 before it falls 1 -> 0 (diff == -1).
        ends = torch.nonzero(
            torch.diff(mask_i, append=mask_i.new_zeros(1)) == -1
        ).flatten()
        ranges = torch.stack((starts, ends), dim=1) + self.offset
        return [tuple(x) for x in ranges.tolist()]

    def __eq__(self, other: object) -> bool:
        """Compare `offset`, `length` and the *values* of `is_embed`."""
        # NOTE(review): the dataclass-generated `__hash__` still hashes
        # `is_embed` by tensor identity, so equal ranges with distinct mask
        # tensors may hash differently.
        if not isinstance(other, self.__class__):
            return False
        if (self.offset, self.length) != (other.offset, other.length):
            return False

        if self.is_embed is None or other.is_embed is None:
            return self.is_embed is None and other.is_embed is None

        return nested_tensors_equal(self.is_embed, other.is_embed)

216

217
218
219
220
221
222
NestedTensors: TypeAlias = Union[
    list["NestedTensors"],
    list["torch.Tensor"],
    "torch.Tensor",
    tuple["torch.Tensor", ...],
]
223
224
225
226
"""
Uses a list instead of a tensor if the dimensions of each element do not match.
"""

227
228

def nested_tensors_equal(a: NestedTensors, b: NestedTensors) -> bool:
229
230
231
232
    """
    Equality check between
    [`NestedTensors`][vllm.multimodal.inputs.NestedTensors] objects.
    """
233
    if isinstance(a, torch.Tensor):
234
        return isinstance(b, torch.Tensor) and torch.equal(a, b)
235
    elif isinstance(b, torch.Tensor):
236
        return isinstance(a, torch.Tensor) and torch.equal(b, a)
237
238

    if isinstance(a, list):
239
240
241
        return isinstance(b, list) and all(
            nested_tensors_equal(a_, b_) for a_, b_ in zip(a, b)
        )
242
    if isinstance(b, list):
243
244
245
        return isinstance(a, list) and all(
            nested_tensors_equal(b_, a_) for b_, a_ in zip(b, a)
        )
246
247
248
249
250

    # Both a and b are scalars
    return a == b


251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
def _nested_tensors_h2d(
    tensors: NestedTensors,
    device: torch.types.Device,
) -> NestedTensors:
    if device is None:
        return tensors

    return json_map_leaves(
        (
            lambda x: x.to(device=device, non_blocking=True)
            if isinstance(x, torch.Tensor)
            else x
        ),
        tensors,
    )


268
BatchedTensorInputs: TypeAlias = dict[str, NestedTensors]
269
270
"""
A dictionary containing nested tensors which have been batched via
271
[`MultiModalKwargsItems.get_data`][vllm.multimodal.inputs.MultiModalKwargsItems.get_data].
272
273
274
"""


275
276
277
278
279
280
281
282
283
284
285
286
287
288
def batched_tensors_equal(a: BatchedTensorInputs, b: BatchedTensorInputs) -> bool:
    """
    Equality check between
    [`BatchedTensorInputs`][vllm.multimodal.inputs.BatchedTensorInputs] objects.
    """
    for k in a:
        if k not in b:
            return False
        if not nested_tensors_equal(a[k], b[k]):
            return False

    return True


289
290
291
292
@dataclass
class MultiModalFeatureSpec:
    """
    One multimodal input together with its processed data and metadata.

    The V1 engine relies on this to follow multimodal data through processing
    and caching. A request that carries several multimodal items has one
    `MultiModalFeatureSpec` per item.
    """

    data: Optional["MultiModalKwargsItem"]
    """Multimodal data for this feature"""

    modality: str
    """Based on the input, e.g., "image", "audio", "video"."""

    identifier: str
    """mm_hash or uuid for caching encoder outputs."""

    mm_position: PlaceholderRange
    """e.g., PlaceholderRange(offset=2, length=336)"""

    @staticmethod
    def gather_kwargs(features: list["MultiModalFeatureSpec"], keys: set[str]):
        """Collect the element data stored under `keys` across `features`.

        Features whose `data` is `None` are skipped, and so are keys absent
        from a feature's item. Each key found at least once maps to the list
        of its `.data` values, in feature order.
        """
        collected = defaultdict[str, list[NestedTensors]](list)

        present_items = [spec.data for spec in features if spec.data is not None]
        for mm_item in present_items:
            for wanted in keys:
                if wanted in mm_item:
                    collected[wanted].append(mm_item[wanted].data)

        return dict(collected)

324

325
@dataclass
class MultiModalFieldElem:
    """
    A single keyword argument that belongs to one multi-modal item in
    [`MultiModalKwargs`][vllm.multimodal.inputs.MultiModalKwargs].
    """

    modality: str
    """
    The modality of the corresponding multi-modal item.
    Each multi-modal item can consist of multiple keyword arguments.
    """

    key: str
    """
    The key of this field in
    [`MultiModalKwargs`][vllm.multimodal.inputs.MultiModalKwargs],
    i.e. the name of the keyword argument to be passed to the model.
    """

    data: NestedTensors
    """
    The tensor data of this field in
    [`MultiModalKwargs`][vllm.multimodal.inputs.MultiModalKwargs],
    i.e. the value of the keyword argument to be passed to the model.

    It may be set to `None` if it is determined that the item is cached
    in `EngineCore`.
    """

    field: "BaseMultiModalField"
    """
    Defines how to combine the tensor data of this field with others
    in order to batch multi-modal items together for model inference.
    """

    def __eq__(self, other: object) -> bool:
        # Tensor payloads need value comparison, so this replaces the
        # dataclass-generated __eq__ (which would apply `==` to tensors).
        if not isinstance(other, self.__class__):
            return False

        if self.data is None or other.data is None:
            # Missing data only matches missing data.
            same_data = self.data is None and other.data is None
        else:
            same_data = nested_tensors_equal(self.data, other.data)

        same_slot = (self.modality, self.key) == (other.modality, other.key)
        # Fields are compared by type only; their configuration is ignored.
        same_field_kind = type(self.field) is type(other.field)  # noqa: E721
        return same_slot and same_data and same_field_kind
377
378


379
@dataclass(frozen=True, kw_only=True)
class BaseMultiModalField(ABC):
    """
    Describes how the tensor data of one keyword argument in
    [`MultiModalKwargs`][vllm.multimodal.inputs.MultiModalKwargs] is split
    across multiple multi-modal items, and how it is merged back together.
    """

    keep_on_cpu: bool = False
    """
    If `True`, then this field is excluded from being moved to the accelerator
    when `MultiModalKwargsItems.get_data()` is called to batch the data.
    """

    def _field_factory(self, *, modality: str, key: str):
        """Return a callable that wraps data into a field element.

        The callable builds a `MultiModalFieldElem` bound to `modality`,
        `key` and this field. It takes the data as its only (positional)
        argument.
        """
        bound_elem = partial(
            MultiModalFieldElem,
            modality=modality,
            key=key,
            field=self,
        )

        def factory(data: NestedTensors) -> MultiModalFieldElem:
            return bound_elem(data=data)

        return factory

    @abstractmethod
    def build_elems(
        self,
        modality: str,
        key: str,
        data: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        """
        Construct
        [`MultiModalFieldElem`][vllm.multimodal.inputs.MultiModalFieldElem]
        instances to represent the provided data.

        This is the inverse of
        [`reduce_data`][vllm.multimodal.inputs.BaseMultiModalField.reduce_data].
        """
        raise NotImplementedError

    @abstractmethod
    def _reduce_data(
        self,
        batch: list[NestedTensors],
        *,
        pin_memory: bool,
    ) -> NestedTensors:
        """Merge the raw data of several elements (subclass hook)."""
        raise NotImplementedError

    def reduce_data(
        self,
        elems: list[MultiModalFieldElem],
        *,
        device: torch.types.Device = None,
        pin_memory: bool = False,
    ) -> NestedTensors:
        """
        Merge the data from multiple instances of
        [`MultiModalFieldElem`][vllm.multimodal.inputs.MultiModalFieldElem].

        This is the inverse of
        [`build_elems`][vllm.multimodal.inputs.BaseMultiModalField.build_elems].

        Raises:
            ValueError: If the elements do not all share the same field type.
        """
        field_types = [type(elem.field) for elem in elems]
        if len(set(field_types)) > 1:
            raise ValueError(f"Cannot merge different {field_types=}")

        if self.keep_on_cpu:
            # Fields kept on the CPU never use pinned staging memory, and any
            # requested device is redirected to the CPU.
            pin_memory = False
            if device is not None:
                device = "cpu"

        merged = self._reduce_data(
            [elem.data for elem in elems],
            pin_memory=pin_memory,
        )
        return _nested_tensors_h2d(merged, device=device)
459
460


461
@dataclass(frozen=True, kw_only=True)
class MultiModalBatchedField(BaseMultiModalField):
    """
    Info:
        [`MultiModalFieldConfig.batched`][vllm.multimodal.inputs.MultiModalFieldConfig.batched]
    """

    def build_elems(
        self,
        modality: str,
        key: str,
        data: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        # One element per entry obtained by iterating over `data`.
        return list(map(self._field_factory(modality=modality, key=key), data))

    def _reduce_data(
        self,
        batch: list[NestedTensors],
        *,
        pin_memory: bool,
    ) -> NestedTensors:
        # Anything other than a non-empty list of tensors is returned as-is.
        if not batch or not is_list_of(batch, torch.Tensor, check="all"):
            return batch

        tensors = cast(list[torch.Tensor], batch)
        if len(tensors) == 1:
            # An optimization when `batch` contains only one tensor:
            # - produce exactly same result as `torch.stack(batch)`
            # - will achieve zero-copy if the tensor is contiguous
            return tensors[0].unsqueeze(0).contiguous()

        head = tensors[0]
        if any(t.shape != head.shape for t in tensors):
            # Ragged shapes cannot be stacked; keep them as a list.
            return batch

        stacked = torch.empty(
            (len(tensors), *head.shape),
            dtype=head.dtype,
            device=head.device,
            pin_memory=pin_memory,
        )
        return torch.stack(tensors, out=stacked)


503
@dataclass(frozen=True, kw_only=True)
class MultiModalFlatField(BaseMultiModalField):
    """
    Info:
        [`MultiModalFieldConfig.flat`][vllm.multimodal.inputs.MultiModalFieldConfig.flat]
        [`MultiModalFieldConfig.flat_from_sizes`][vllm.multimodal.inputs.MultiModalFieldConfig.flat_from_sizes]
    """

    slices: Sequence[slice] | Sequence[Sequence[slice]]
    dim: int = 0

    def build_elems(
        self,
        modality: str,
        key: str,
        data: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        make_elem = self._field_factory(modality=modality, key=key)
        plain_slices = is_list_of(self.slices, slice, check="all")
        if not plain_slices:
            # Tuple (multi-dimensional) indexing is only supported on tensors.
            assert isinstance(data, torch.Tensor), (
                "torch.Tensor is required for multiple slices"
            )
        return [make_elem(data[cast(slice, idx)]) for idx in self.slices]

    def _reduce_data(
        self,
        batch: list[NestedTensors],
        *,
        pin_memory: bool,
    ) -> NestedTensors:
        if batch and is_list_of(batch, torch.Tensor, check="all"):
            tensors = cast(list[torch.Tensor], batch)

            if len(tensors) == 1:
                # An optimization when `batch` contains only one tensor:
                # - produce exactly same result as `torch.concat(batch)`
                # - will achieve zero-copy if the tensor is contiguous
                return tensors[0].contiguous()

            # Normalize a negative `dim` against the rank of the first tensor.
            cat_dim = self.dim + len(tensors[0].shape) if self.dim < 0 else self.dim

            def split_around(t: torch.Tensor):
                return t.shape[:cat_dim], t.shape[cat_dim + 1 :]

            before, after = split_around(tensors[0])
            if all(split_around(t) == (before, after) for t in tensors):
                total = sum(t.shape[cat_dim] for t in tensors)
                out = torch.empty(
                    (*before, total, *after),
                    dtype=tensors[0].dtype,
                    device=tensors[0].device,
                    pin_memory=pin_memory,
                )
                return torch.concat(batch, dim=self.dim, out=out)

        # Fallback: flatten one level of nesting.
        assert self.dim == 0, "dim == 0 is required for nested list"
        return [e for elem in batch for e in elem]
561
562


563
@dataclass(frozen=True, kw_only=True)
class MultiModalSharedField(BaseMultiModalField):
    """
    Info:
        [`MultiModalFieldConfig.shared`][vllm.multimodal.inputs.MultiModalFieldConfig.shared]
    """

    batch_size: int

    def build_elems(
        self,
        modality: str,
        key: str,
        data: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        # One element object is built and repeated `batch_size` times, so
        # every entry in the returned list is the same instance.
        shared_elem = self._field_factory(modality=modality, key=key)(data)
        return [shared_elem] * self.batch_size

    def _reduce_data(
        self,
        batch: list[NestedTensors],
        *,
        pin_memory: bool,
    ) -> NestedTensors:
        # Assumes every element carries the same data (as produced by
        # `build_elems`); only the first one is returned.
        return batch[0]


590
@dataclass(frozen=True)
class MultiModalFieldConfig:
    """
    Pairs a [`BaseMultiModalField`][vllm.multimodal.inputs.BaseMultiModalField]
    with the modality it belongs to. Use the static constructors below.
    """

    @staticmethod
    def batched(modality: str, *, keep_on_cpu: bool = False):
        """
        Defines a field where an element in the batch is obtained by
        indexing into the first dimension of the underlying data.

        Args:
            modality: The modality of the multi-modal item that uses this
                keyword argument.
            keep_on_cpu: Whether to keep this field on the CPU for the model inputs.

        Example:

        ```
        Input:
            Data: [[AAAA]
                [BBBB]
                [CCCC]]

        Output:
            Element 1: [AAAA]
            Element 2: [BBBB]
            Element 3: [CCCC]
        ```
        """
        return MultiModalFieldConfig(
            field=MultiModalBatchedField(keep_on_cpu=keep_on_cpu),
            modality=modality,
        )

    @staticmethod
    def flat(
        modality: str,
        slices: Sequence[slice] | Sequence[Sequence[slice]],
        dim: int = 0,
        *,
        keep_on_cpu: bool = False,
    ):
        """
        Defines a field where an element in the batch is obtained by
        slicing along dimension `dim` of the underlying data.

        Args:
            modality: The modality of the multi-modal item that uses this
                keyword argument.
            slices: For each multi-modal item, a slice (dim=0) or a tuple of
                indices (dim != 0) that is used to extract the data
                corresponding to it.
            dim: The dimension to extract data, default to 0.
            keep_on_cpu: Whether to keep this field on the CPU for the model inputs.

        Example:

        ```
        Given:
            slices: [slice(0, 3), slice(3, 7), slice(7, 9)]

        Input:
            Data: [AAABBBBCC]

        Output:
            Element 1: [AAA]
            Element 2: [BBBB]
            Element 3: [CC]
        ```

        ```
        Given:
            slices: [
                (slice(None), slice(0, 3)),
                (slice(None), slice(3, 7)),
                (slice(None), slice(7, 9))]
            dim: 1

        Input:
            Data: [[A],[A],[A],[B],[B],[B],[B],[C],[C]]

        Output:
            Element 1: [[A],[A],[A]]
            Element 2: [[B],[B],[B],[B]]
            Element 3: [[C],[C]]
        ```
        """
        return MultiModalFieldConfig(
            field=MultiModalFlatField(
                slices=slices,
                dim=dim,
                keep_on_cpu=keep_on_cpu,
            ),
            modality=modality,
        )

    @staticmethod
    def flat_from_sizes(
        modality: str,
        size_per_item: "torch.Tensor",
        dim: int = 0,
        *,
        keep_on_cpu: bool = False,
    ):
        """
        Defines a field where an element in the batch is obtained by
        slicing along dimension `dim` of the underlying data.

        Args:
            modality: The modality of the multi-modal item that uses this
                keyword argument.
            size_per_item: For each multi-modal item, the size of the slice
                that is used to extract the data corresponding to it.
            dim: The dimension to slice, default to 0. Negative values count
                from the last dimension.
            keep_on_cpu: Whether to keep this field on the CPU for the model inputs.

        Raises:
            ValueError: If `size_per_item` is not a 1-D tensor.

        Example:

        ```
        Given:
            size_per_item: [3, 4, 2]

        Input:
            Data: [AAABBBBCC]

        Output:
            Element 1: [AAA]
            Element 2: [BBBB]
            Element 3: [CC]
        ```

        ```
        Given:
            size_per_item: [3, 4, 2]
            dim: 1

        Input:
            Data: [[A],[A],[A],[B],[B],[B],[B],[C],[C]]

        Output:
            Element 1: [[A],[A],[A]]
            Element 2: [[B],[B],[B],[B]]
            Element 3: [[C],[C]]
        ```

        Info:
            [`MultiModalFieldConfig.flat`][vllm.multimodal.inputs.MultiModalFieldConfig.flat]
        """

        if size_per_item.ndim != 1:
            raise ValueError(
                "size_per_item should be a 1-D tensor, "
                f"but found shape: {size_per_item.shape}"
            )

        # Cumulative offsets as plain Python ints, e.g. [3, 4, 2] -> [0, 3, 7, 9].
        slice_idxs = [0, *accumulate(size_per_item.tolist())]

        lead: tuple[Any, ...]
        trail: tuple[Any, ...]
        if dim >= 0:
            # Full slices select everything before `dim`.
            lead = (slice(None, None, None),) * dim
            trail = ()
        else:
            # Index from the end, i.e. `data[..., a:b, :, ...]`. Without this,
            # `(slice(None),) * dim` is empty and dimension 0 would be sliced.
            lead = (Ellipsis,)
            trail = (slice(None, None, None),) * (-dim - 1)

        slices = [
            lead + (slice(start, stop),) + trail
            for start, stop in zip(slice_idxs[:-1], slice_idxs[1:])
        ]

        return MultiModalFieldConfig.flat(
            modality,
            slices,
            dim=dim,
            keep_on_cpu=keep_on_cpu,
        )

    @staticmethod
    def shared(
        modality: str,
        batch_size: int,
        *,
        keep_on_cpu: bool = False,
    ):
        """
        Defines a field where an element in the batch is obtained by
        taking the entirety of the underlying data.

        This means that the data is the same for each element in the batch.

        Args:
            modality: The modality of the multi-modal item that uses this
                keyword argument.
            batch_size: The number of multi-modal items which share this data.
            keep_on_cpu: Whether to keep this field on the CPU for the model inputs.

        Example:

        ```
        Given:
            batch_size: 4

        Input:
            Data: [XYZ]

        Output:
            Element 1: [XYZ]
            Element 2: [XYZ]
            Element 3: [XYZ]
            Element 4: [XYZ]
        ```
        """
        return MultiModalFieldConfig(
            field=MultiModalSharedField(
                batch_size=batch_size,
                keep_on_cpu=keep_on_cpu,
            ),
            modality=modality,
        )

    field: BaseMultiModalField
    modality: str

    def build_elems(
        self,
        key: str,
        batch: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        """Split `batch` into per-item elements using this config's field."""
        return self.field.build_elems(self.modality, key, batch)
809
810


811
812
class MultiModalKwargsItem(UserDict[str, MultiModalFieldElem]):
    """
    A collection of
    [`MultiModalFieldElem`][vllm.multimodal.inputs.MultiModalFieldElem]
    corresponding to a data item in
    [`MultiModalDataItems`][vllm.multimodal.parse.MultiModalDataItems].

    All elements of an item must belong to the same modality.
    """

    @staticmethod
    def dummy(modality: str, nbytes: int = 1):
        """Build a placeholder item holding `nbytes` uninitialized bytes.

        Convenience constructor for testing.
        """
        mm_elem = MultiModalFieldElem(
            modality=modality,
            key="dummy",
            data=torch.empty(nbytes, dtype=torch.uint8),
            field=MultiModalSharedField(batch_size=1),
        )
        return MultiModalKwargsItem.from_elems([mm_elem])

    @staticmethod
    def from_elems(elems: Sequence[MultiModalFieldElem]):
        """Build an item keyed by each element's `key` (later duplicates win)."""
        return MultiModalKwargsItem({elem.key: elem for elem in elems})

    def __init__(self, data: Mapping[str, MultiModalFieldElem] | None = None) -> None:
        """Wrap `data`, requiring every element to share a single modality.

        Raises:
            AssertionError: If `data` is empty or mixes modalities.
        """
        # `None` instead of a shared mutable `{}` default; UserDict treats
        # `None` as "no initial data".
        super().__init__(data)

        modalities = {elem.modality for elem in self.values()}
        assert len(modalities) == 1, f"Found different modalities={modalities}"
        self._modality = next(iter(modalities))

    @property
    def modality(self) -> str:
        """The single modality shared by all elements of this item."""
        return self._modality

    def get_data(self) -> dict[str, NestedTensors]:
        """Return the raw element data keyed by keyword-argument name."""
        return {key: elem.data for key, elem in self.items()}
847
848


849
850
851
_I = TypeVar(
    "_I",
    MultiModalKwargsItem,
    MultiModalKwargsItem | None,
    default=MultiModalKwargsItem,
)


class MultiModalKwargsItems(UserDict[str, Sequence[_I]]):
    """
    A mapping from modality to the
    [`MultiModalKwargsItem`][vllm.multimodal.inputs.MultiModalKwargsItem]s
    of that modality.
    """

    @staticmethod
    def from_hf_inputs(
        hf_inputs: "BatchFeature",
        config_by_key: Mapping[str, MultiModalFieldConfig],
    ):
        """Split HF processor outputs into per-item `MultiModalKwargsItem`s.

        Raises:
            ValueError: If keys of the same modality yield different numbers
                of elements.
        """
        # NOTE: This skips fields in `hf_inputs` that are not in `config_by_key`
        # We assume that those fields are not used in vLLM
        elems_by_key: dict[str, Sequence[MultiModalFieldElem]] = {}
        keys_by_modality = defaultdict[str, set[str]](set)
        for key, config in config_by_key.items():
            batch = hf_inputs.get(key)
            if batch is None:
                continue

            elems = config.build_elems(key, batch)
            if len(elems) == 0:
                continue

            elems_by_key[key] = elems
            keys_by_modality[config.modality].add(key)

        items: list[MultiModalKwargsItem] = []
        for modality, keys in keys_by_modality.items():
            per_key = {k: elems_by_key[k] for k in keys}
            batch_sizes = {k: len(v) for k, v in per_key.items()}

            if len(set(batch_sizes.values())) > 1:
                raise ValueError(
                    f"Cannot merge different batch sizes for {modality=}! "
                    f"Found: {batch_sizes=}"
                )

            # Transpose the key-major element lists into one item per index.
            items.extend(
                MultiModalKwargsItem.from_elems(item_elems)
                for item_elems in zip(*per_key.values())
            )

        return MultiModalKwargsItems.from_seq(items)

    @staticmethod
    def from_seq(items: Sequence[MultiModalKwargsItem]):
        """Group `items` by their modality."""
        return MultiModalKwargsItems(full_groupby(items, key=lambda x: x.modality))

    def __getitem__(self, modality: str) -> Sequence[_I]:
        if modality in self:
            return super().__getitem__(modality)  # type: ignore[return-value]

        raise KeyError(
            f"Modality {modality!r} not found. "
            f"Available modalities: {set(self.keys())}"
        )

    def require_data(self) -> "MultiModalKwargsItems[MultiModalKwargsItem]":
        """Return `self`, raising `RuntimeError` if any item is `None`."""
        for modality, items in self.items():
            missing_idx = next(
                (idx for idx, item in enumerate(items) if item is None), None
            )
            if missing_idx is not None:
                raise RuntimeError(f"Found empty mm_items[{modality}][{missing_idx}]")

        return self  # type: ignore[return-value]

    def get_data(
        self,
        *,
        device: torch.types.Device = None,
        pin_memory: bool = False,
    ) -> BatchedTensorInputs:
        """Construct a dictionary of keyword arguments to pass to the model."""
        elems_by_key = defaultdict[str, list[MultiModalFieldElem]](list)
        for modality, items in self.items():
            for idx, item in enumerate(items):
                if item is None:
                    raise RuntimeError(
                        f"Cannot build data from empty mm_items[{modality}][{idx}]"
                    )

                for key, elem in item.items():
                    elems_by_key[key].append(elem)

        # Each key is merged by the field attached to its first element.
        return {
            key: elems[0].field.reduce_data(
                elems,
                device=device,
                pin_memory=pin_memory,
            )
            for key, elems in elems_by_key.items()
        }
949

950

951
952
953
954
# Either fully populated `MultiModalKwargsItems`, or items whose entries may be
# `None`; `MultiModalKwargsItems.require_data()` raises on `None` entries and
# narrows to the former.
MultiModalKwargsOptionalItems: TypeAlias = (
    MultiModalKwargsItems[MultiModalKwargsItem]
    | MultiModalKwargsItems[MultiModalKwargsItem | None]
)
955
956


957
@deprecated("`MultiModalKwargs` is deprecated and will be removed in v0.13.")
class MultiModalKwargs(UserDict[str, NestedTensors]):
    """
    A dictionary that represents the keyword arguments to
    [`torch.nn.Module.forward`][].

    Deprecated: build a `MultiModalKwargsItems` and call its `get_data()`
    method instead.
    """

    @staticmethod
    @deprecated(
        "`MultiModalKwargs.from_hf_inputs` is deprecated and "
        "will be removed in v0.13. "
        "Please use `MultiModalKwargsItems.from_hf_inputs` and "
        "access the tensor data using `.get_data()`."
    )
    def from_hf_inputs(
        hf_inputs: "BatchFeature",
        config_by_key: Mapping[str, MultiModalFieldConfig],
    ):
        """Build model inputs from HF processor outputs.

        Thin wrapper over `MultiModalKwargsItems.from_hf_inputs(...).get_data()`;
        it returns that plain data mapping, not a `MultiModalKwargs`.
        """
        return MultiModalKwargsItems.from_hf_inputs(hf_inputs, config_by_key).get_data()

    @staticmethod
    @deprecated(
        "`MultiModalKwargs.from_items` is deprecated and "
        "will be removed in v0.13. "
        "Please use `MultiModalKwargsItems.from_seq` and "
        "access the tensor data using `.get_data()`."
    )
    def from_items(
        items: Sequence[MultiModalKwargsItem],
        *,
        pin_memory: bool = False,
    ):
        """Build model inputs from a sequence of items.

        Thin wrapper over `MultiModalKwargsItems.from_seq(items).get_data()`;
        it returns that plain data mapping, not a `MultiModalKwargs`.
        """
        return MultiModalKwargsItems.from_seq(items).get_data(pin_memory=pin_memory)

    def __getitem__(self, key: str):
        # Raise with the list of available keys instead of a bare KeyError.
        if key not in self:
            raise KeyError(
                f"Keyword argument {key!r} not found. "
                f"Available keys: {set(self.keys())}"
            )

        return super().__getitem__(key)

    def __eq__(self, other: object) -> bool:
        """Return True if both hold the same keys with equal nested tensors."""
        if not isinstance(other, self.__class__):
            return False

        # Require identical key sets so that `a == b` implies `b == a`;
        # checking only `self`'s keys would ignore extra keys in `other`.
        if self.keys() != other.keys():
            return False

        return all(nested_tensors_equal(self[k], other[k]) for k in self)


MultiModalPlaceholderDict: TypeAlias = Mapping[str, Sequence[PlaceholderRange]]
"""
A dictionary containing placeholder ranges for each modality.
"""


class MultiModalInputs(TypedDict):
    """
    Represents the outputs of
    [`BaseMultiModalProcessor`][vllm.multimodal.processing.BaseMultiModalProcessor],
    ready to be passed to vLLM internals.
    """

    type: Literal["multimodal"]
    """The type of inputs."""

    prompt_token_ids: list[int]
    """The processed token IDs, which include placeholder tokens."""

    mm_kwargs: MultiModalKwargsOptionalItems
    """Keyword arguments to be directly passed to the model after batching."""

    mm_hashes: "MultiModalHashes"
    """The hashes of the multi-modal data."""

    mm_placeholders: "MultiModalPlaceholderDict"
    """
    For each modality, information about the placeholder tokens in
    `prompt_token_ids`.
    """

    cache_salt: NotRequired[str]
    """
    Optional cache salt to be used for prefix caching.
    """


class MultiModalEncDecInputs(MultiModalInputs):
    """
    Represents the outputs of
    [`EncDecMultiModalProcessor`][vllm.multimodal.processing.EncDecMultiModalProcessor],
    ready to be passed to vLLM internals.
    """

    encoder_prompt_token_ids: list[int]
    """The processed token IDs of the encoder prompt."""