# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from abc import ABC, abstractmethod
from collections import UserDict, defaultdict
from collections.abc import Mapping, Sequence
from dataclasses import dataclass
from functools import partial
from itertools import accumulate
from typing import (
    TYPE_CHECKING,
    Any,
    Literal,
    Optional,
    TypeAlias,
    TypedDict,
    Union,
    cast,
    final,
)

import numpy as np
from typing_extensions import NotRequired, TypeVar, deprecated

from vllm.utils import LazyLoader, full_groupby, is_list_of
from vllm.utils.jsontree import json_map_leaves

if TYPE_CHECKING:
    import torch
    import torch.types
    from PIL.Image import Image
    from transformers.feature_extraction_utils import BatchFeature

    from .processing import MultiModalHashes
else:
    torch = LazyLoader("torch", globals(), "torch")


_T = TypeVar("_T")

41
HfImageItem: TypeAlias = Union["Image", np.ndarray, "torch.Tensor"]
42
"""
43
A `transformers.image_utils.ImageInput` representing a single image
44
item, which can be passed to a HuggingFace `ImageProcessor`.
45
46
"""

47
48
49
HfVideoItem: TypeAlias = Union[
    list["Image"], np.ndarray, "torch.Tensor", list[np.ndarray], list["torch.Tensor"]
]
50
"""
51
A `transformers.image_utils.VideoInput` representing a single video
52
item, which can be passed to a HuggingFace `VideoProcessor`.
53
54
"""

55
HfAudioItem: TypeAlias = Union[list[float], np.ndarray, "torch.Tensor"]
56
"""
57
Represents a single audio
58
item, which can be passed to a HuggingFace `AudioProcessor`.
59
60
"""

61
ImageItem: TypeAlias = Union[HfImageItem, "torch.Tensor"]
62
"""
63
A `transformers.image_utils.ImageInput` representing a single image
64
item, which can be passed to a HuggingFace `ImageProcessor`.
65
66
67
68
69
70

Alternatively, a 3-D tensor or batch of 2-D tensors,
which are treated as image embeddings;
these are directly passed to the model without HF processing.
"""

71
72
73
VideoItem: TypeAlias = Union[
    HfVideoItem, "torch.Tensor", tuple[HfVideoItem, dict[str, Any]]
]
74
"""
75
76
77
A `transformers.video_utils.VideoInput` representing a single video item. 
This can be passed to a HuggingFace `VideoProcessor` 
with `transformers.video_utils.VideoMetadata`.
78
79
80
81
82
83

Alternatively, a 3-D tensor or batch of 2-D tensors,
which are treated as video embeddings;
these are directly passed to the model without HF processing.
"""

84
AudioItem: TypeAlias = Union[HfAudioItem, tuple[np.ndarray, float], "torch.Tensor"]
85
86
"""
Represents a single audio
87
item, which can be passed to a HuggingFace `AudioProcessor`.
88
89
90
91
92
93
94
95
96
97

Alternatively, a tuple `(audio, sampling_rate)`, where the sampling rate
is different from that expected by the model;
these are resampled to the model's sampling rate before being processed by HF.

Alternatively, a 3-D tensor or batch of 2-D tensors,
which are treated as audio embeddings;
these are directly passed to the model without HF processing.
"""

98
ModalityData: TypeAlias = _T | list[_T | None] | None
99
"""
100
101
Either a single data item, or a list of data items. Can only be None if UUID
is provided.
102
103

The number of data items allowed per modality is restricted by
104
`--limit-mm-per-prompt`.
105
106
107
108
109
110
111
"""


@final
class MultiModalDataBuiltins(TypedDict, total=False):
    """Type annotations for modality types predefined by vLLM."""

    image: ModalityData[ImageItem]
    """The input image(s)."""

    video: ModalityData[VideoItem]
    """The input video(s)."""

    audio: ModalityData[AudioItem]
    """The input audio(s)."""


MultiModalDataDict: TypeAlias = Mapping[str, ModalityData[Any]]
"""
A dictionary containing an entry for each modality type to input.

The built-in modalities are defined by
[`MultiModalDataBuiltins`][vllm.multimodal.inputs.MultiModalDataBuiltins].
"""

MultiModalUUIDDict: TypeAlias = Mapping[str, list[str | None] | str]
"""
A dictionary containing user-provided UUIDs for items in each modality.
If a UUID for an item is not provided, its entry will be `None` and
MultiModalHasher will compute a hash for the item.

The UUID will be used to identify the item for all caching purposes
(input processing caching, embedding caching, prefix caching, etc).
"""


@dataclass(frozen=True)
class PlaceholderRange:
    """
    Placeholder location information for multi-modal data.

    Example:

    Prompt: `AAAA BBBB What is in these images?`

    Images A and B will have:

    ```
    A: PlaceholderRange(offset=0, length=4)
    B: PlaceholderRange(offset=5, length=4)
    ```
    """

    offset: int
    """The start index of the placeholder in the prompt."""

    length: int
    """The length of the placeholder."""

    is_embed: Optional["torch.Tensor"] = None
    """
    A boolean mask of shape `(length,)` indicating which positions
    between `offset` and `offset + length` to assign embeddings to.
    """

    def get_num_embeds(self) -> int:
        """
        Return the number of positions that receive embeddings.

        This is `length` when there is no mask; otherwise it is the number
        of `True` entries in `is_embed`.
        """
        if self.is_embed is None:
            return self.length

        return int(self.is_embed.sum().item())

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, self.__class__):
            return False
        if (self.offset, self.length) != (other.offset, other.length):
            return False

        # A missing mask only equals another missing mask.
        if self.is_embed is None or other.is_embed is None:
            return self.is_embed is None and other.is_embed is None

        return nested_tensors_equal(self.is_embed, other.is_embed)

    def __hash__(self) -> int:
        # The dataclass-generated hash would include `is_embed`, and tensors
        # hash by identity. Two ranges that compare equal (same mask values,
        # different tensor objects) would then hash differently. Hash only the
        # fields that `__eq__` compares by value.
        return hash((self.offset, self.length))


NestedTensors: TypeAlias = Union[
    list["NestedTensors"],
    list["torch.Tensor"],
    "torch.Tensor",
    tuple["torch.Tensor", ...],
]
196
197
198
199
"""
Uses a list instead of a tensor if the dimensions of each element do not match.
"""

200
201

def nested_tensors_equal(a: "NestedTensors", b: "NestedTensors") -> bool:
    """
    Equality check between
    [`NestedTensors`][vllm.multimodal.inputs.NestedTensors] objects.

    - Tensors are compared with `torch.equal` (same shape and values).
    - Lists are compared element-wise and must have the same length; the
      same holds for tuples. A list never equals a tuple.
    - Anything else is compared with `==`.
    """
    if isinstance(a, torch.Tensor) or isinstance(b, torch.Tensor):
        return (
            isinstance(a, torch.Tensor)
            and isinstance(b, torch.Tensor)
            and torch.equal(a, b)
        )

    # `zip` alone would silently truncate to the shorter sequence, making
    # e.g. `[t]` equal to `[t, t]`, so the lengths are checked explicitly.
    if isinstance(a, list) or isinstance(b, list):
        return (
            isinstance(a, list)
            and isinstance(b, list)
            and len(a) == len(b)
            and all(nested_tensors_equal(a_, b_) for a_, b_ in zip(a, b))
        )

    # Tuples of tensors are part of `NestedTensors`. Comparing them with `==`
    # would evaluate `tensor == tensor` elementwise and fail on truthiness,
    # so recurse instead.
    if isinstance(a, tuple) or isinstance(b, tuple):
        return (
            isinstance(a, tuple)
            and isinstance(b, tuple)
            and len(a) == len(b)
            and all(nested_tensors_equal(a_, b_) for a_, b_ in zip(a, b))
        )

    # Both a and b are scalars
    return a == b


BatchedTensorInputs: TypeAlias = dict[str, NestedTensors]
"""
A dictionary containing nested tensors which have been batched via
[`MultiModalKwargs.batch`][vllm.multimodal.inputs.MultiModalKwargs.batch].
"""


@dataclass
class MultiModalFeatureSpec:
    """
    Represents a single multimodal input with its processed data and metadata.

    Used by the V1 engine to track multimodal data through processing and
    caching. A request containing multiple multimodal items will have one
    MultiModalFeatureSpec per item.
    """

    data: Optional["MultiModalKwargsItem"]
    """Multimodal data for this feature"""

    modality: str
    """Based on the input, e.g., "image", "audio", "video"."""

    identifier: str
    """mm_hash or uuid for caching encoder outputs."""

    mm_position: PlaceholderRange
    """e.g., PlaceholderRange(offset=2, length=336)"""


@dataclass
class MultiModalFieldElem:
    """
    Represents a keyword argument corresponding to a multi-modal item
    in [`MultiModalKwargs`][vllm.multimodal.inputs.MultiModalKwargs].
    """

    modality: str
    """
    The modality of the corresponding multi-modal item.
    Each multi-modal item can consist of multiple keyword arguments.
    """

    key: str
    """
    The key of this field in
    [`MultiModalKwargs`][vllm.multimodal.inputs.MultiModalKwargs],
    i.e. the name of the keyword argument to be passed to the model.
    """

    data: NestedTensors
    """
    The tensor data of this field in
    [`MultiModalKwargs`][vllm.multimodal.inputs.MultiModalKwargs],
    i.e. the value of the keyword argument to be passed to the model.

    It may be set to `None` if it is determined that the item is cached
    in `EngineCore`.
    """

    field: "BaseMultiModalField"
    """
    Defines how to combine the tensor data of this field with others
    in order to batch multi-modal items together for model inference.
    """

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, self.__class__):
            return False

        # Compare the cheap identifying attributes before the (potentially
        # expensive) tensor data.
        if (self.modality, self.key) != (other.modality, other.key):
            return False
        if type(self.field) is not type(other.field):
            return False

        # `data` may be `None` (see the `data` field docs); a `None` value
        # only equals another `None`.
        if self.data is None or other.data is None:
            return self.data is None and other.data is None

        return nested_tensors_equal(self.data, other.data)


@dataclass(frozen=True)
class BaseMultiModalField(ABC):
    """
    Defines how to interpret tensor data belonging to a keyword argument in
    [`MultiModalKwargs`][vllm.multimodal.inputs.MultiModalKwargs] for multiple
    multi-modal items, and vice versa.
    """

    def _field_factory(self, *, modality: str, key: str):
        """
        Return a callable that wraps a piece of data into a
        [`MultiModalFieldElem`][vllm.multimodal.inputs.MultiModalFieldElem]
        bound to `modality`, `key` and this field.
        """
        f = partial(
            MultiModalFieldElem,
            modality=modality,
            key=key,
            field=self,
        )

        # Allow passing data as positional argument
        def factory(data: NestedTensors) -> MultiModalFieldElem:
            return f(data=data)

        return factory

    @abstractmethod
    def build_elems(
        self,
        modality: str,
        key: str,
        data: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        """
        Construct
        [`MultiModalFieldElem`][vllm.multimodal.inputs.MultiModalFieldElem]
        instances to represent the provided data.

        This is the inverse of
        [`reduce_data`][vllm.multimodal.inputs.BaseMultiModalField.reduce_data].
        """
        raise NotImplementedError

    @abstractmethod
    def _reduce_data(
        self,
        batch: list[NestedTensors],
        *,
        pin_memory: bool,
    ) -> NestedTensors:
        """Combine the raw `data` of several elements (subclass hook)."""
        raise NotImplementedError

    def reduce_data(
        self,
        elems: list[MultiModalFieldElem],
        *,
        pin_memory: bool = False,
    ) -> NestedTensors:
        """
        Merge the data from multiple instances of
        [`MultiModalFieldElem`][vllm.multimodal.inputs.MultiModalFieldElem].

        This is the inverse of
        [`build_elems`][vllm.multimodal.inputs.BaseMultiModalField.build_elems].

        Raises:
            ValueError: If the elements were built by different field types.
        """
        field_types = [type(item.field) for item in elems]
        if len(set(field_types)) > 1:
            raise ValueError(f"Cannot merge different {field_types=}")

        batch = [elem.data for elem in elems]
        return self._reduce_data(batch, pin_memory=pin_memory)


@dataclass(frozen=True)
class MultiModalBatchedField(BaseMultiModalField):
    """
    Info:
        [`MultiModalFieldConfig.batched`][vllm.multimodal.inputs.MultiModalFieldConfig.batched]
    """

    def build_elems(
        self,
        modality: str,
        key: str,
        data: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        # One element per entry along the first dimension / list index.
        field_factory = self._field_factory(modality=modality, key=key)
        return [field_factory(item) for item in data]

    def _reduce_data(
        self,
        batch: list[NestedTensors],
        *,
        pin_memory: bool,
    ) -> NestedTensors:
        if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"):
            batch = cast(list[torch.Tensor], batch)
            if len(batch) == 1:
                # An optimization when `batch` contains only one tensor:
                # - produce exactly same result as `torch.stack(batch)`
                # - will achieve zero-copy if the tensor is contiguous
                return batch[0].unsqueeze(0).contiguous()

            first_shape = batch[0].shape
            if all(elem.shape == first_shape for elem in batch):
                # Preallocate so the result can live in pinned memory.
                out = torch.empty(
                    (len(batch), *batch[0].shape),
                    dtype=batch[0].dtype,
                    device=batch[0].device,
                    pin_memory=pin_memory,
                )
                return torch.stack(batch, out=out)

        # Shapes differ (or non-tensor data): keep as a list.
        return batch


@dataclass(frozen=True)
class MultiModalFlatField(BaseMultiModalField):
    """
    Info:
        [`MultiModalFieldConfig.flat`][vllm.multimodal.inputs.MultiModalFieldConfig.flat]
        [`MultiModalFieldConfig.flat_from_sizes`][vllm.multimodal.inputs.MultiModalFieldConfig.flat_from_sizes]
    """

    slices: Sequence[slice] | Sequence[Sequence[slice]]
    dim: int = 0

    def build_elems(
        self,
        modality: str,
        key: str,
        data: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        field_factory = self._field_factory(modality=modality, key=key)
        if not is_list_of(self.slices, slice, check="all"):
            # Tuples of slices (multi-dimensional indexing) need a tensor.
            assert isinstance(data, torch.Tensor), (
                "torch.Tensor is required for multiple slices"
            )
        return [field_factory(data[cast(slice, s)]) for s in self.slices]

    def _reduce_data(
        self,
        batch: list[NestedTensors],
        *,
        pin_memory: bool,
    ) -> NestedTensors:
        if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"):
            batch = cast(list[torch.Tensor], batch)
            if len(batch) == 1:
                # An optimization when `batch` contains only one tensor:
                # - produce exactly same result as `torch.concat(batch)`
                # - will achieve zero-copy if the tensor is contiguous
                return batch[0].contiguous()

            # Normalize a negative `dim` to its non-negative equivalent.
            dim = self.dim if self.dim >= 0 else self.dim + len(batch[0].shape)

            def _shape_before_after(tensor: torch.Tensor):
                return tensor.shape[:dim], tensor.shape[dim + 1 :]

            first_shape = _shape_before_after(batch[0])

            # Concatenation only requires the non-concatenated dims to match.
            if all(_shape_before_after(elem) == first_shape for elem in batch):
                shape_before, shape_after = first_shape
                shape_concat = sum(item.shape[dim] for item in batch)
                out = torch.empty(
                    (*shape_before, shape_concat, *shape_after),
                    dtype=batch[0].dtype,
                    device=batch[0].device,
                    pin_memory=pin_memory,
                )
                return torch.concat(batch, dim=self.dim, out=out)

        # Fallback: flatten one level of nesting into a single list.
        assert self.dim == 0, "dim == 0 is required for nested list"
        return [e for elem in batch for e in elem]


@dataclass(frozen=True)
class MultiModalSharedField(BaseMultiModalField):
    """
    Info:
        [`MultiModalFieldConfig.shared`][vllm.multimodal.inputs.MultiModalFieldConfig.shared]
    """

    batch_size: int

    def build_elems(
        self,
        modality: str,
        key: str,
        data: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        field_factory = self._field_factory(modality=modality, key=key)
        # Build a distinct element per item. `[elem] * n` would alias a single
        # mutable `MultiModalFieldElem`, so resetting one item's `data` would
        # silently affect all the others. The underlying data is still shared.
        return [field_factory(data) for _ in range(self.batch_size)]

    def _reduce_data(
        self,
        batch: list[NestedTensors],
        *,
        pin_memory: bool,
    ) -> NestedTensors:
        # Every element carries the same data, so any one represents the batch.
        return batch[0]


class MultiModalFieldConfig:
    """
    Associates a modality with a
    [`BaseMultiModalField`][vllm.multimodal.inputs.BaseMultiModalField]
    that describes how to split a keyword argument into per-item elements.
    """

    @staticmethod
    def batched(modality: str):
        """
        Defines a field where an element in the batch is obtained by
        indexing into the first dimension of the underlying data.

        Args:
            modality: The modality of the multi-modal item that uses this
                keyword argument.

        Example:

        ```
        Input:
            Data: [[AAAA]
                [BBBB]
                [CCCC]]

        Output:
            Element 1: [AAAA]
            Element 2: [BBBB]
            Element 3: [CCCC]
        ```
        """
        return MultiModalFieldConfig(
            field=MultiModalBatchedField(),
            modality=modality,
        )

    @staticmethod
    def flat(
        modality: str,
        slices: Sequence[slice] | Sequence[Sequence[slice]],
        dim: int = 0,
    ):
        """
        Defines a field where an element in the batch is obtained by
        slicing along dimension `dim` (the first dimension by default)
        of the underlying data.

        Args:
            modality: The modality of the multi-modal item that uses this
                keyword argument.
            slices: For each multi-modal item, a slice (dim=0) or a tuple of
                slices (dim>0) that is used to extract the data corresponding
                to it.
            dim: The dimension to extract data, default to 0.

        Example:

        ```
        Given:
            slices: [slice(0, 3), slice(3, 7), slice(7, 9)]

        Input:
            Data: [AAABBBBCC]

        Output:
            Element 1: [AAA]
            Element 2: [BBBB]
            Element 3: [CC]
        ```

        ```
        Given:
            slices: [
                (slice(None), slice(0, 3)),
                (slice(None), slice(3, 7)),
                (slice(None), slice(7, 9))]
            dim: 1

        Input:
            Data: [[A],[A],[A],[B],[B],[B],[B],[C],[C]]

        Output:
            Element 1: [[A],[A],[A]]
            Element 2: [[B],[B],[B],[B]]
            Element 3: [[C],[C]]
        ```
        """
        return MultiModalFieldConfig(
            field=MultiModalFlatField(slices=slices, dim=dim),
            modality=modality,
        )

    @staticmethod
    def flat_from_sizes(modality: str, size_per_item: "torch.Tensor", dim: int = 0):
        """
        Defines a field where an element in the batch is obtained by
        slicing along dimension `dim` (the first dimension by default)
        of the underlying data.

        Args:
            modality: The modality of the multi-modal item that uses this
                keyword argument.
            size_per_item: For each multi-modal item, the size of the slice
                that is used to extract the data corresponding to it.
            dim: The dimension to slice, default to 0.

        Raises:
            ValueError: If `size_per_item` is not a 1-D tensor.

        Example:

        ```
        Given:
            size_per_item: [3, 4, 2]

        Input:
            Data: [AAABBBBCC]

        Output:
            Element 1: [AAA]
            Element 2: [BBBB]
            Element 3: [CC]
        ```

        ```
        Given:
            size_per_item: [3, 4, 2]
            dim: 1

        Input:
            Data: [[A],[A],[A],[B],[B],[B],[B],[C],[C]]

        Output:
            Element 1: [[A],[A],[A]]
            Element 2: [[B],[B],[B],[B]]
            Element 3: [[C],[C]]
        ```

        Info:
            [`MultiModalFieldConfig.flat`][vllm.multimodal.inputs.MultiModalFieldConfig.flat]
        """

        if size_per_item.ndim != 1:
            raise ValueError(
                "size_per_item should be a 1-D tensor, "
                f"but found shape: {size_per_item.shape}"
            )

        # Cumulative offsets: item i spans [slice_idxs[i], slice_idxs[i + 1]).
        slice_idxs = [0, *accumulate(size_per_item)]
        # Leading full slices select everything in the dims before `dim`.
        slices = [
            (slice(None, None, None),) * dim
            + (slice(slice_idxs[i], slice_idxs[i + 1]),)
            for i in range(len(size_per_item))
        ]

        return MultiModalFieldConfig.flat(modality, slices, dim=dim)

    @staticmethod
    def shared(modality: str, batch_size: int):
        """
        Defines a field where an element in the batch is obtained by
        taking the entirety of the underlying data.

        This means that the data is the same for each element in the batch.

        Args:
            modality: The modality of the multi-modal item that uses this
                keyword argument.
            batch_size: The number of multi-modal items which share this data.

        Example:

        ```
        Given:
            batch_size: 4

        Input:
            Data: [XYZ]

        Output:
            Element 1: [XYZ]
            Element 2: [XYZ]
            Element 3: [XYZ]
            Element 4: [XYZ]
        ```
        """
        return MultiModalFieldConfig(
            field=MultiModalSharedField(batch_size),
            modality=modality,
        )

    def __init__(self, field: BaseMultiModalField, modality: str) -> None:
        super().__init__()

        self.field = field
        self.modality = modality

    def __repr__(self) -> str:
        return f"MultiModalFieldConfig(field={self.field}, modality={self.modality})"

    def build_elems(
        self,
        key: str,
        batch: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        """Split `batch` for keyword argument `key` into per-item elements."""
        return self.field.build_elems(self.modality, key, batch)


class MultiModalKwargsItem(UserDict[str, MultiModalFieldElem]):
    """
    A collection of
    [`MultiModalFieldElem`][vllm.multimodal.inputs.MultiModalFieldElem]
    corresponding to a data item in
    [`MultiModalDataItems`][vllm.multimodal.parse.MultiModalDataItems].

    All elements must belong to exactly one modality.
    """

    @staticmethod
    def dummy(modality: str):
        """Convenience constructor for testing: a single shared dummy elem."""
        mm_elem = MultiModalFieldElem(
            modality=modality,
            key="dummy",
            data=torch.empty(1),
            field=MultiModalSharedField(1),
        )
        return MultiModalKwargsItem.from_elems([mm_elem])

    @staticmethod
    def from_elems(elems: Sequence[MultiModalFieldElem]):
        """Build an item keyed by each element's `key`."""
        return MultiModalKwargsItem({elem.key: elem for elem in elems})

    def __init__(
        self,
        data: Mapping[str, MultiModalFieldElem] | None = None,
    ) -> None:
        # `None` replaces the former mutable `{}` default; `UserDict` treats
        # `None` as "no initial data".
        super().__init__(data)

        modalities = {elem.modality for elem in self.values()}
        assert len(modalities) == 1, f"Found different modalities={modalities}"
        self._modality = next(iter(modalities))

    @property
    def modality(self) -> str:
        """The single modality shared by all elements of this item."""
        return self._modality

    def get_data(self) -> dict[str, NestedTensors]:
        """Return the raw data of each element, keyed by element key."""
        return {key: elem.data for key, elem in self.items()}


# Item type stored in `MultiModalKwargsItems`: either fully-populated items,
# or items that may be `None` (rejected by `MultiModalKwargsItems.require_data`).
_I = TypeVar(
    "_I",
    MultiModalKwargsItem,
    MultiModalKwargsItem | None,
    default=MultiModalKwargsItem,
)


class MultiModalKwargsItems(UserDict[str, Sequence[_I]]):
    """
    A dictionary of
    [`MultiModalKwargsItem`][vllm.multimodal.inputs.MultiModalKwargsItem]s
    by modality.
    """

    @staticmethod
    def from_hf_inputs(
        hf_inputs: "BatchFeature",
        config_by_key: Mapping[str, MultiModalFieldConfig],
    ):
        """
        Split HuggingFace processor outputs into per-item
        [`MultiModalKwargsItem`][vllm.multimodal.inputs.MultiModalKwargsItem]s,
        using `config_by_key` to decide how each keyword argument is split.

        Raises:
            ValueError: If the keys of one modality yield different numbers
                of items.
        """
        # NOTE: This skips fields in `hf_inputs` that are not in `config_by_key`
        # We assume that those fields are not used in vLLM
        elems_by_key = dict[str, Sequence[MultiModalFieldElem]]()
        # A list (instead of a set) keeps keys in `config_by_key` order, so the
        # key order inside each item does not depend on the hash seed.
        keys_by_modality = defaultdict[str, list[str]](list)
        for key, config in config_by_key.items():
            batch = hf_inputs.get(key)
            if batch is not None:
                elems = config.build_elems(key, batch)
                if len(elems) > 0:
                    elems_by_key[key] = elems
                    keys_by_modality[config.modality].append(key)

        items = list[MultiModalKwargsItem]()
        for modality, keys in keys_by_modality.items():
            elems_in_modality = {k: elems_by_key[k] for k in keys}
            batch_sizes = {k: len(v) for k, v in elems_in_modality.items()}

            if len(set(batch_sizes.values())) > 1:
                raise ValueError(
                    f"Cannot merge different batch sizes for {modality=}! "
                    f"Found: {batch_sizes=}"
                )

            # Item i collects the i-th element of every key in this modality.
            batch_size = next(iter(batch_sizes.values()))
            for item_idx in range(batch_size):
                elems = [v[item_idx] for v in elems_in_modality.values()]
                items.append(MultiModalKwargsItem.from_elems(elems))

        return MultiModalKwargsItems.from_seq(items)

    @staticmethod
    def from_seq(items: Sequence[MultiModalKwargsItem]):
        """Group `items` by their modality."""
        items_by_modality = full_groupby(items, key=lambda x: x.modality)
        return MultiModalKwargsItems(items_by_modality)

    def __getitem__(self, modality: str) -> Sequence[_I]:
        if modality not in self:
            raise KeyError(
                f"Modality {modality!r} not found. "
                f"Available modalities: {set(self.keys())}"
            )

        return super().__getitem__(modality)  # type: ignore[return-value]

    def require_data(self) -> "MultiModalKwargsItems[MultiModalKwargsItem]":
        """
        Return `self`, narrowed to non-empty items.

        Raises:
            RuntimeError: If any item is `None`.
        """
        for modality, items in self.items():
            for i, item in enumerate(items):
                if item is None:
                    raise RuntimeError(f"Found empty mm_items[{modality}][{i}]")

        return self  # type: ignore[return-value]

    def get_data(self, *, pin_memory: bool = False) -> "MultiModalKwargs":
        """
        Merge the elements of all items, key by key, into model kwargs.

        Raises:
            RuntimeError: If any item is `None`.
        """
        elems_by_key = defaultdict[str, list[MultiModalFieldElem]](list)
        for modality, items in self.items():
            for i, item in enumerate(items):
                if item is None:
                    raise RuntimeError(
                        f"Cannot build data from empty mm_items[{modality}][{i}]"
                    )

                for key, elem in item.items():
                    elems_by_key[key].append(elem)

        return MultiModalKwargs(
            {
                key: elems[0].field.reduce_data(elems, pin_memory=pin_memory)
                for key, elems in elems_by_key.items()
            }
        )


# Either items that are all present, or items where some entries may be
# `None`; `MultiModalKwargsItems.require_data` narrows to the former.
MultiModalKwargsOptionalItems: TypeAlias = Union[
    MultiModalKwargsItems[MultiModalKwargsItem],
    MultiModalKwargsItems[Optional[MultiModalKwargsItem]],
]


class MultiModalKwargs(UserDict[str, NestedTensors]):
    """
    A dictionary that represents the keyword arguments to
    [`torch.nn.Module.forward`][].
    """

    @staticmethod
    @deprecated(
        "`MultiModalKwargs.from_hf_inputs` is deprecated and "
        "will be removed in v0.13. "
        "Please use `MultiModalKwargsItems.from_hf_inputs` and "
        "access the tensor data using `.get_data()`."
    )
    def from_hf_inputs(
        hf_inputs: "BatchFeature",
        config_by_key: Mapping[str, MultiModalFieldConfig],
    ):
        return MultiModalKwargsItems.from_hf_inputs(hf_inputs, config_by_key).get_data()

    @staticmethod
    @deprecated(
        "`MultiModalKwargs.from_items` is deprecated and "
        "will be removed in v0.13. "
        "Please use `MultiModalKwargsItems.from_seq` and "
        "access the tensor data using `.get_data()`."
    )
    def from_items(
        items: Sequence[MultiModalKwargsItem],
        *,
        pin_memory: bool = False,
    ):
        return MultiModalKwargsItems.from_seq(items).get_data(pin_memory=pin_memory)

    @staticmethod
    def _try_stack(
        nested_tensors: NestedTensors, pin_memory: bool = False
    ) -> NestedTensors:
        """
        Stack the inner dimensions that have the same shape in
        a nested list of tensors.

        Thus, a dimension represented by a list means that the inner
        dimensions are different for each element along that dimension.
        """
        if isinstance(nested_tensors, torch.Tensor):
            return nested_tensors

        # TODO: Remove these once all models have been migrated
        if isinstance(nested_tensors, np.ndarray):
            return torch.from_numpy(nested_tensors)
        if isinstance(nested_tensors, (int, float)):
            return torch.tensor(nested_tensors)

        stacked = [MultiModalKwargs._try_stack(t, pin_memory) for t in nested_tensors]
        if not is_list_of(stacked, torch.Tensor, check="all"):
            # Only tensors (not lists) can be stacked.
            return stacked

        tensors_ = cast(list[torch.Tensor], stacked)
        if len(tensors_) == 1:
            # An optimization when `tensors_` contains only one tensor:
            # - produce exactly same result as `torch.stack(tensors_)`
            # - will achieve zero-copy if the tensor is contiguous
            return tensors_[0].unsqueeze(0).contiguous()

        if any(t.shape != tensors_[0].shape for t in tensors_):
            # The tensors have incompatible shapes and can't be stacked.
            return tensors_

        # Preallocate so the result can live in pinned memory.
        outputs = torch.empty(
            len(tensors_),
            *tensors_[0].shape,
            dtype=tensors_[0].dtype,
            device=tensors_[0].device,
            pin_memory=pin_memory,
        )
        return torch.stack(tensors_, out=outputs)

    @staticmethod
    def batch(
        inputs_list: list["MultiModalKwargs"], pin_memory: bool = False
    ) -> BatchedTensorInputs:
        """
        Batch multiple inputs together into a dictionary.

        The resulting dictionary has the same keys as the inputs.
        If the corresponding value from each input is a tensor and they all
        share the same shape, the output value is a single batched tensor;
        otherwise, the output value is a list containing the original value
        from each input.
        """
        if len(inputs_list) == 0:
            return {}

        # We need to consider the case where each item in the batch
        # contains different modalities (i.e. different keys).
        item_lists = defaultdict[str, list[NestedTensors]](list)

        for inputs in inputs_list:
            for k, v in inputs.items():
                item_lists[k].append(v)

        return {
            k: MultiModalKwargs._try_stack(item_list, pin_memory)
            for k, item_list in item_lists.items()
        }

    @staticmethod
    def as_kwargs(
        batched_inputs: BatchedTensorInputs,
        *,
        device: torch.types.Device,
    ) -> BatchedTensorInputs:
        """Move every tensor leaf of `batched_inputs` to `device` (non-blocking)."""
        return json_map_leaves(
            lambda x: x.to(device=device, non_blocking=True),
            batched_inputs,
        )

    def __getitem__(self, key: str):
        if key not in self:
            raise KeyError(
                f"Keyword argument {key!r} not found. "
                f"Available keys: {set(self.keys())}"
            )

        return super().__getitem__(key)

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, self.__class__):
            return False

        # The key sets must match exactly. Only checking `self`'s keys would
        # treat an `other` with extra keys as equal, making `==` asymmetric.
        if self.keys() != other.keys():
            return False

        return all(nested_tensors_equal(self[k], other[k]) for k in self)


MultiModalPlaceholderDict: TypeAlias = Mapping[str, Sequence[PlaceholderRange]]
"""
A dictionary containing placeholder ranges for each modality.
"""


class MultiModalInputs(TypedDict):
    """
    Represents the outputs of
    [`BaseMultiModalProcessor`][vllm.multimodal.processing.BaseMultiModalProcessor],
    ready to be passed to vLLM internals.
    """

    type: Literal["multimodal"]
    """The type of inputs."""

    prompt_token_ids: list[int]
    """The processed token IDs, which include placeholder tokens."""

    mm_kwargs: MultiModalKwargsOptionalItems
    """Keyword arguments to be directly passed to the model after batching."""

    mm_hashes: "MultiModalHashes"
    """The hashes of the multi-modal data."""

    mm_placeholders: "MultiModalPlaceholderDict"
    """
    For each modality, information about the placeholder tokens in
    `prompt_token_ids`.
    """

    cache_salt: NotRequired[str]
    """
    Optional cache salt to be used for prefix caching.
    """


class MultiModalEncDecInputs(MultiModalInputs):
    """
    Represents the outputs of
    [`EncDecMultiModalProcessor`][vllm.multimodal.processing.EncDecMultiModalProcessor]
    ready to be passed to vLLM internals.
    """

    encoder_prompt_token_ids: list[int]
    """The processed token IDs of the encoder prompt."""