inputs.py 31.1 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from abc import ABC, abstractmethod
5
from collections import UserDict, defaultdict
6
from collections.abc import Mapping, Sequence
7
from dataclasses import dataclass
8
from functools import partial
9
from itertools import accumulate
10
11
12
13
14
15
16
17
18
19
20
from typing import (
    TYPE_CHECKING,
    Any,
    Literal,
    Optional,
    TypeAlias,
    TypedDict,
    Union,
    cast,
    final,
)
21
22

import numpy as np
23
from typing_extensions import NotRequired, TypeVar, deprecated
24

25
from vllm.utils.collection_utils import full_groupby, is_list_of
26
from vllm.utils.import_utils import LazyLoader
27
from vllm.utils.jsontree import json_map_leaves
28

29
if TYPE_CHECKING:
30
31
32
33
34
    import torch
    import torch.types
    from PIL.Image import Image
    from transformers.feature_extraction_utils import BatchFeature

35
36
    from .processing import MultiModalHashes

37
38
else:
    torch = LazyLoader("torch", globals(), "torch")
39

40
41
_T = TypeVar("_T")

42
HfImageItem: TypeAlias = Union["Image", np.ndarray, "torch.Tensor"]
43
"""
44
A `transformers.image_utils.ImageInput` representing a single image
45
item, which can be passed to a HuggingFace `ImageProcessor`.
46
47
"""

48
49
50
HfVideoItem: TypeAlias = Union[
    list["Image"], np.ndarray, "torch.Tensor", list[np.ndarray], list["torch.Tensor"]
]
51
"""
52
A `transformers.image_utils.VideoInput` representing a single video
53
item, which can be passed to a HuggingFace `VideoProcessor`.
54
55
"""

56
HfAudioItem: TypeAlias = Union[list[float], np.ndarray, "torch.Tensor"]
57
"""
58
Represents a single audio
59
item, which can be passed to a HuggingFace `AudioProcessor`.
60
61
"""

62
ImageItem: TypeAlias = Union[HfImageItem, "torch.Tensor"]
63
"""
64
A `transformers.image_utils.ImageInput` representing a single image
65
item, which can be passed to a HuggingFace `ImageProcessor`.
66
67
68
69
70
71

Alternatively, a 3-D tensor or batch of 2-D tensors,
which are treated as image embeddings;
these are directly passed to the model without HF processing.
"""

72
73
74
VideoItem: TypeAlias = Union[
    HfVideoItem, "torch.Tensor", tuple[HfVideoItem, dict[str, Any]]
]
75
"""
76
77
78
A `transformers.video_utils.VideoInput` representing a single video item. 
This can be passed to a HuggingFace `VideoProcessor` 
with `transformers.video_utils.VideoMetadata`.
79
80
81
82
83
84

Alternatively, a 3-D tensor or batch of 2-D tensors,
which are treated as video embeddings;
these are directly passed to the model without HF processing.
"""

85
AudioItem: TypeAlias = Union[HfAudioItem, tuple[np.ndarray, float], "torch.Tensor"]
86
87
"""
Represents a single audio
88
item, which can be passed to a HuggingFace `AudioProcessor`.
89
90
91
92
93
94
95
96
97
98

Alternatively, a tuple `(audio, sampling_rate)`, where the sampling rate
is different from that expected by the model;
these are resampled to the model's sampling rate before being processed by HF.

Alternatively, a 3-D tensor or batch of 2-D tensors,
which are treated as audio embeddings;
these are directly passed to the model without HF processing.
"""

99
ModalityData: TypeAlias = _T | list[_T | None] | None
100
"""
101
102
Either a single data item, or a list of data items. Can only be None if UUID
is provided.
103
104

The number of data items allowed per modality is restricted by
105
`--limit-mm-per-prompt`.
106
107
108
109
110
111
112
"""


@final
class MultiModalDataBuiltins(TypedDict, total=False):
    """Type annotations for modality types predefined by vLLM."""

    image: ModalityData[ImageItem]
    """The input image(s)."""

    video: ModalityData[VideoItem]
    """The input video(s)."""

    audio: ModalityData[AudioItem]
    """The input audio(s)."""


123
MultiModalDataDict: TypeAlias = Mapping[str, ModalityData[Any]]
"""
A dictionary containing an entry for each modality type to input.

The built-in modalities are defined by
[`MultiModalDataBuiltins`][vllm.multimodal.inputs.MultiModalDataBuiltins].
"""

131
MultiModalUUIDDict: TypeAlias = Mapping[str, list[str | None] | str]
132
133
134
135
136
137
138
139
140
"""
A dictionary containing user-provided UUIDs for items in each modality.
If a UUID for an item is not provided, its entry will be `None` and
MultiModalHasher will compute a hash for the item.

The UUID will be used to identify the item for all caching purposes
(input processing caching, embedding caching, prefix caching, etc).
"""

141

142
143
@dataclass(frozen=True)
class PlaceholderRange:
    """
    Placeholder location information for multi-modal data.

    Example:

    Prompt: `AAAA BBBB What is in these images?`

    Images A and B will have:

    ```
    A: PlaceholderRange(offset=0, length=4)
    B: PlaceholderRange(offset=5, length=4)
    ```
    """

    offset: int
    """The start index of the placeholder in the prompt."""

    length: int
    """The length of the placeholder."""

    is_embed: Optional["torch.Tensor"] = None
    """
    A boolean mask of shape `(length,)` indicating which positions
    between `offset` and `offset + length` to assign embeddings to.
    """

    def get_num_embeds(self) -> int:
        """
        Return the number of positions in this range that receive embeddings.

        This is `length` when no mask is set; otherwise it is the number of
        set entries in `is_embed`.
        """
        if self.is_embed is None:
            return self.length

        return int(self.is_embed.sum().item())

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, self.__class__):
            return False
        if (self.offset, self.length) != (other.offset, other.length):
            return False

        # A missing mask is only equal to another missing mask.
        if self.is_embed is None or other.is_embed is None:
            return self.is_embed is None and other.is_embed is None

        return nested_tensors_equal(self.is_embed, other.is_embed)

    def __hash__(self) -> int:
        # `__eq__` compares `is_embed` by value, but tensors hash by identity,
        # so the dataclass-generated hash would give equal ranges different
        # hashes. Hash only the fields that `__eq__` compares exactly.
        return hash((self.offset, self.length))

190

191
192
193
194
195
196
NestedTensors: TypeAlias = Union[
    list["NestedTensors"],
    list["torch.Tensor"],
    "torch.Tensor",
    tuple["torch.Tensor", ...],
]
197
198
199
200
"""
Uses a list instead of a tensor if the dimensions of each element do not match.
"""

201
202

def nested_tensors_equal(a: NestedTensors, b: NestedTensors) -> bool:
203
204
    """Equality check between
    [`NestedTensors`][vllm.multimodal.inputs.NestedTensors] objects."""
205
    if isinstance(a, torch.Tensor):
206
        return isinstance(b, torch.Tensor) and torch.equal(a, b)
207
    elif isinstance(b, torch.Tensor):
208
        return isinstance(a, torch.Tensor) and torch.equal(b, a)
209
210

    if isinstance(a, list):
211
212
213
        return isinstance(b, list) and all(
            nested_tensors_equal(a_, b_) for a_, b_ in zip(a, b)
        )
214
    if isinstance(b, list):
215
216
217
        return isinstance(a, list) and all(
            nested_tensors_equal(b_, a_) for b_, a_ in zip(b, a)
        )
218
219
220
221
222

    # Both a and b are scalars
    return a == b


223
BatchedTensorInputs: TypeAlias = dict[str, NestedTensors]
"""
A dictionary containing nested tensors which have been batched via
[`MultiModalKwargs.batch`][vllm.multimodal.inputs.MultiModalKwargs.batch].
"""


230
231
232
233
@dataclass
class MultiModalFeatureSpec:
    """
    Represents a single multimodal input with its processed data and metadata.

    Used by the V1 engine to track multimodal data through processing and
    caching. A request containing multiple multimodal items will have one
    MultiModalFeatureSpec per item.
    """

    data: Optional["MultiModalKwargsItem"]
    """Multimodal data for this feature"""

    modality: str
    """Based on the input, e.g., "image", "audio", "video"."""

    identifier: str
    """mm_hash or uuid for caching encoder outputs."""

    mm_position: PlaceholderRange
    """e.g., PlaceholderRange(offset=2, length=336)"""

    @staticmethod
    def gather_kwargs(
        features: list["MultiModalFeatureSpec"],
        keys: set[str],
    ) -> dict[str, list[NestedTensors]]:
        """
        Collect the tensor data of the requested keyword arguments.

        For each key in `keys`, the `data` of the matching field is appended
        from every feature, in feature order. Features whose `data` is `None`
        are skipped, as are keys a feature does not contain. Keys found in
        no feature are absent from the returned dictionary.
        """
        kwargs = defaultdict[str, list[NestedTensors]](list)

        for f in features:
            item = f.data
            if item is None:
                continue

            for k in keys:
                if k in item:
                    kwargs[k].append(item[k].data)

        return dict(kwargs)

265

266
@dataclass
class MultiModalFieldElem:
    """
    Represents a keyword argument corresponding to a multi-modal item
    in [`MultiModalKwargs`][vllm.multimodal.inputs.MultiModalKwargs].
    """

    modality: str
    """
    The modality of the corresponding multi-modal item.
    Each multi-modal item can consist of multiple keyword arguments.
    """

    key: str
    """
    The key of this field in
    [`MultiModalKwargs`][vllm.multimodal.inputs.MultiModalKwargs],
    i.e. the name of the keyword argument to be passed to the model.
    """

    data: NestedTensors
    """
    The tensor data of this field in
    [`MultiModalKwargs`][vllm.multimodal.inputs.MultiModalKwargs],
    i.e. the value of the keyword argument to be passed to the model.

    It may be set to `None` if it is determined that the item is cached
    in `EngineCore`.
    """

    field: "BaseMultiModalField"
    """
    Defines how to combine the tensor data of this field with others
    in order to batch multi-modal items together for model inference.
    """

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, self.__class__):
            return False

        if self.data is None or other.data is None:
            # Data that is `None` (cached elsewhere) only equals `None`.
            data_equal = self.data is None and other.data is None
        else:
            data_equal = nested_tensors_equal(self.data, other.data)

        # Fields are compared by type only, not by their configuration.
        return (
            (self.modality, self.key) == (other.modality, other.key)
            and data_equal
            and type(self.field) is type(other.field)
        )  # noqa: E721
318
319
320
321


@dataclass(frozen=True)
class BaseMultiModalField(ABC):
    """
    Defines how to interpret tensor data belonging to a keyword argument in
    [`MultiModalKwargs`][vllm.multimodal.inputs.MultiModalKwargs] for multiple
    multi-modal items, and vice versa.
    """

    def _field_factory(self, *, modality: str, key: str):
        """
        Return a callable that wraps one piece of data in a
        [`MultiModalFieldElem`][vllm.multimodal.inputs.MultiModalFieldElem]
        bound to this field, `modality` and `key`.
        """
        f = partial(
            MultiModalFieldElem,
            modality=modality,
            key=key,
            field=self,
        )

        # Allow passing data as positional argument
        def factory(data: NestedTensors) -> MultiModalFieldElem:
            return f(data=data)

        return factory

    @abstractmethod
    def build_elems(
        self,
        modality: str,
        key: str,
        data: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        """
        Construct
        [`MultiModalFieldElem`][vllm.multimodal.inputs.MultiModalFieldElem]
        instances to represent the provided data.

        This is the inverse of
        [`reduce_data`][vllm.multimodal.inputs.BaseMultiModalField.reduce_data].
        """
        raise NotImplementedError

    @abstractmethod
    def _reduce_data(
        self,
        batch: list[NestedTensors],
        *,
        pin_memory: bool,
    ) -> NestedTensors:
        """
        Merge the raw data of several elements into a single value.

        Subclasses implement the merging strategy; they may use `pin_memory`
        when allocating a new output tensor.
        """
        raise NotImplementedError

    def reduce_data(
        self,
        elems: list[MultiModalFieldElem],
        *,
        pin_memory: bool = False,
    ) -> NestedTensors:
        """
        Merge the data from multiple instances of
        [`MultiModalFieldElem`][vllm.multimodal.inputs.MultiModalFieldElem].

        This is the inverse of
        [`build_elems`][vllm.multimodal.inputs.BaseMultiModalField.build_elems].

        Raises:
            ValueError: If `elems` belong to fields of different types.
        """
        field_types = [type(item.field) for item in elems]
        if len(set(field_types)) > 1:
            raise ValueError(f"Cannot merge different {field_types=}")

        batch = [elem.data for elem in elems]
        return self._reduce_data(batch, pin_memory=pin_memory)
387
388
389
390
391


@dataclass(frozen=True)
class MultiModalBatchedField(BaseMultiModalField):
    """
    Info:
        [`MultiModalFieldConfig.batched`][vllm.multimodal.inputs.MultiModalFieldConfig.batched]
    """

    def build_elems(
        self,
        modality: str,
        key: str,
        data: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        # One element per entry obtained by iterating over `data`
        # (i.e. along its first dimension for a tensor).
        field_factory = self._field_factory(modality=modality, key=key)
        return [field_factory(item) for item in data]

    def _reduce_data(
        self,
        batch: list[NestedTensors],
        *,
        pin_memory: bool,
    ) -> NestedTensors:
        # Same-shaped tensors are stacked along a new leading dimension;
        # anything else is returned as a list of per-item data.
        if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"):
            batch = cast(list[torch.Tensor], batch)
            if len(batch) == 1:
                # An optimization when `batch` contains only one tensor:
                # - produce exactly same result as `torch.stack(batch)`
                # - will achieve zero-copy if the tensor is contiguous
                return batch[0].unsqueeze(0).contiguous()

            first_shape = batch[0].shape
            if all(elem.shape == first_shape for elem in batch):
                out = torch.empty(
                    (len(batch), *batch[0].shape),
                    dtype=batch[0].dtype,
                    device=batch[0].device,
                    pin_memory=pin_memory,
                )
                return torch.stack(batch, out=out)

        return batch


@dataclass(frozen=True)
class MultiModalFlatField(BaseMultiModalField):
    """
    Info:
        [`MultiModalFieldConfig.flat`][vllm.multimodal.inputs.MultiModalFieldConfig.flat]
        [`MultiModalFieldConfig.flat_from_sizes`][vllm.multimodal.inputs.MultiModalFieldConfig.flat_from_sizes]
    """

    slices: Sequence[slice] | Sequence[Sequence[slice]]
    dim: int = 0

    def build_elems(
        self,
        modality: str,
        key: str,
        data: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        field_factory = self._field_factory(modality=modality, key=key)
        if not is_list_of(self.slices, slice, check="all"):
            # Tuples of slices (used when `dim > 0`) can only index a tensor.
            assert isinstance(data, torch.Tensor), (
                "torch.Tensor is required for multiple slices"
            )
        return [field_factory(data[cast(slice, s)]) for s in self.slices]

    def _reduce_data(
        self,
        batch: list[NestedTensors],
        *,
        pin_memory: bool,
    ) -> NestedTensors:
        # Tensors that agree on every dimension except `dim` are concatenated
        # along `dim`; otherwise (only supported for `dim == 0`) the per-item
        # sequences are flattened into a single list.
        if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"):
            batch = cast(list[torch.Tensor], batch)
            if len(batch) == 1:
                # An optimization when `batch` contains only one tensor:
                # - produce exactly same result as `torch.concat(batch)`
                # - will achieve zero-copy if the tensor is contiguous
                return batch[0].contiguous()

            # Normalize a negative `dim` so it can be used for shape slicing.
            dim = self.dim + (self.dim < 0) * len(batch[0].shape)

            def _shape_before_after(tensor: torch.Tensor):
                return tensor.shape[:dim], tensor.shape[dim + 1 :]

            first_shape = _shape_before_after(batch[0])

            if all(_shape_before_after(elem) == first_shape for elem in batch):
                shape_before, shape_after = first_shape
                shape_concat = sum(item.shape[dim] for item in batch)
                out = torch.empty(
                    (*shape_before, shape_concat, *shape_after),
                    dtype=batch[0].dtype,
                    device=batch[0].device,
                    pin_memory=pin_memory,
                )
                return torch.concat(batch, dim=self.dim, out=out)

        assert self.dim == 0, "dim == 0 is required for nested list"
        return [e for elem in batch for e in elem]
489
490


491
492
493
@dataclass(frozen=True)
class MultiModalSharedField(BaseMultiModalField):
    """
    Info:
        [`MultiModalFieldConfig.shared`][vllm.multimodal.inputs.MultiModalFieldConfig.shared]
    """

    batch_size: int

    def build_elems(
        self,
        modality: str,
        key: str,
        data: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        # The same element instance is repeated once per item, so every item
        # refers to the entire `data`.
        field_factory = self._field_factory(modality=modality, key=key)
        return [field_factory(data)] * self.batch_size

    def _reduce_data(
        self,
        batch: list[NestedTensors],
        *,
        pin_memory: bool,
    ) -> NestedTensors:
        # The data is shared, so the first element's data is returned as-is.
        return batch[0]


518
519
520
class MultiModalFieldConfig:
    """
    Pairs a [`BaseMultiModalField`][vllm.multimodal.inputs.BaseMultiModalField]
    with the modality of the items that use it. It describes how a keyword
    argument is split into per-item
    [`MultiModalFieldElem`][vllm.multimodal.inputs.MultiModalFieldElem]s.

    Instances are usually created via the static constructors `batched`,
    `flat`, `flat_from_sizes`, and `shared`.
    """

    @staticmethod
    def batched(modality: str):
        """
        Defines a field where an element in the batch is obtained by
        indexing into the first dimension of the underlying data.

        Args:
            modality: The modality of the multi-modal item that uses this
                keyword argument.

        Example:

        ```
        Input:
            Data: [[AAAA]
                [BBBB]
                [CCCC]]

        Output:
            Element 1: [AAAA]
            Element 2: [BBBB]
            Element 3: [CCCC]
        ```
        """
        return MultiModalFieldConfig(
            field=MultiModalBatchedField(),
            modality=modality,
        )

    @staticmethod
    def flat(
        modality: str,
        slices: Sequence[slice] | Sequence[Sequence[slice]],
        dim: int = 0,
    ):
        """
        Defines a field where an element in the batch is obtained by
        slicing the underlying data (along the first dimension by default,
        or along `dim`).

        Args:
            modality: The modality of the multi-modal item that uses this
                keyword argument.
            slices: For each multi-modal item, a slice (dim=0) or a tuple of
                slices (dim>0) that is used to extract the data corresponding
                to it.
            dim: The dimension to extract data, default to 0.

        Example:

        ```
        Given:
            slices: [slice(0, 3), slice(3, 7), slice(7, 9)]

        Input:
            Data: [AAABBBBCC]

        Output:
            Element 1: [AAA]
            Element 2: [BBBB]
            Element 3: [CC]
        ```

        ```
        Given:
            slices: [
                (slice(None), slice(0, 3)),
                (slice(None), slice(3, 7)),
                (slice(None), slice(7, 9))]
            dim: 1

        Input:
            Data: [[A],[A],[A],[B],[B],[B],[B],[C],[C]]

        Output:
            Element 1: [[A],[A],[A]]
            Element 2: [[B],[B],[B],[B]]
            Element 3: [[C],[C]]
        ```
        """
        return MultiModalFieldConfig(
            field=MultiModalFlatField(slices=slices, dim=dim),
            modality=modality,
        )

    @staticmethod
    def flat_from_sizes(modality: str, size_per_item: "torch.Tensor", dim: int = 0):
        """
        Defines a field where an element in the batch is obtained by
        slicing the underlying data (along the first dimension by default,
        or along `dim`) into consecutive chunks of the given sizes.

        Args:
            modality: The modality of the multi-modal item that uses this
                keyword argument.
            size_per_item: For each multi-modal item, the size of the slice
                that is used to extract the data corresponding to it.
            dim: The dimension to slice, default to 0.

        Raises:
            ValueError: If `size_per_item` is not a 1-D tensor.

        Example:

        ```
        Given:
            size_per_item: [3, 4, 2]

        Input:
            Data: [AAABBBBCC]

        Output:
            Element 1: [AAA]
            Element 2: [BBBB]
            Element 3: [CC]
        ```

        ```
        Given:
            size_per_item: [3, 4, 2]
            dim: 1

        Input:
            Data: [[A],[A],[A],[B],[B],[B],[B],[C],[C]]

        Output:
            Element 1: [[A],[A],[A]]
            Element 2: [[B],[B],[B],[B]]
            Element 3: [[C],[C]]
        ```

        Info:
            [`MultiModalFieldConfig.flat`][vllm.multimodal.inputs.MultiModalFieldConfig.flat]
        """

        if size_per_item.ndim != 1:
            raise ValueError(
                "size_per_item should be a 1-D tensor, "
                f"but found shape: {size_per_item.shape}"
            )

        # Cumulative boundaries: item i spans [slice_idxs[i], slice_idxs[i + 1]).
        slice_idxs = [0, *accumulate(size_per_item)]
        # Leading full slices select everything before `dim`.
        slices = [
            (slice(None, None, None),) * dim
            + (slice(slice_idxs[i], slice_idxs[i + 1]),)
            for i in range(len(size_per_item))
        ]

        return MultiModalFieldConfig.flat(modality, slices, dim=dim)

    @staticmethod
    def shared(modality: str, batch_size: int):
        """
        Defines a field where an element in the batch is obtained by
        taking the entirety of the underlying data.

        This means that the data is the same for each element in the batch.

        Args:
            modality: The modality of the multi-modal item that uses this
                keyword argument.
            batch_size: The number of multi-modal items which share this data.

        Example:

        ```
        Given:
            batch_size: 4

        Input:
            Data: [XYZ]

        Output:
            Element 1: [XYZ]
            Element 2: [XYZ]
            Element 3: [XYZ]
            Element 4: [XYZ]
        ```
        """
        return MultiModalFieldConfig(
            field=MultiModalSharedField(batch_size),
            modality=modality,
        )

    def __init__(self, field: BaseMultiModalField, modality: str) -> None:
        super().__init__()

        self.field = field
        self.modality = modality

    def __repr__(self) -> str:
        return f"MultiModalFieldConfig(field={self.field}, modality={self.modality})"

    def build_elems(
        self,
        key: str,
        batch: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        """Split `batch` into per-item elements for the keyword argument `key`."""
        return self.field.build_elems(self.modality, key, batch)
713
714


715
716
class MultiModalKwargsItem(UserDict[str, MultiModalFieldElem]):
    """
    A collection of
    [`MultiModalFieldElem`][vllm.multimodal.inputs.MultiModalFieldElem]
    corresponding to a data item in
    [`MultiModalDataItems`][vllm.multimodal.parse.MultiModalDataItems].
    """

    @staticmethod
    def dummy(modality: str, nbytes: int = 1):
        """
        Create a single-field item for testing purposes.

        The item holds one `"dummy"` field whose data is an uninitialized
        `uint8` tensor with `nbytes` elements.
        """
        mm_elem = MultiModalFieldElem(
            modality=modality,
            key="dummy",
            data=torch.empty(nbytes, dtype=torch.uint8),
            field=MultiModalSharedField(1),
        )
        return MultiModalKwargsItem.from_elems([mm_elem])

    @staticmethod
    def from_elems(elems: Sequence[MultiModalFieldElem]):
        """Build an item that maps each element's `key` to the element."""
        return MultiModalKwargsItem({elem.key: elem for elem in elems})

    def __init__(self, data: Mapping[str, MultiModalFieldElem] | None = None) -> None:
        # `None` instead of a `{}` default avoids sharing one mutable default
        # object across calls; `UserDict` treats `None` as "no initial data".
        super().__init__(data)

        # All elements of a single item must belong to the same modality.
        modalities = {elem.modality for elem in self.values()}
        assert len(modalities) == 1, f"Found different modalities={modalities}"
        self._modality = next(iter(modalities))

    @property
    def modality(self) -> str:
        return self._modality

    def get_data(self) -> dict[str, NestedTensors]:
        """Return the raw data of each element, keyed by field name."""
        return {key: elem.data for key, elem in self.items()}
751
752


753
754
755
# Item type stored in `MultiModalKwargsItems`: either items that are always
# present, or items that may be `None` (which `require_data` and `get_data`
# reject at runtime).
_I = TypeVar(
    "_I",
    MultiModalKwargsItem,
    MultiModalKwargsItem | None,
    default=MultiModalKwargsItem,
)


class MultiModalKwargsItems(UserDict[str, Sequence[_I]]):
    """
    A dictionary of
    [`MultiModalKwargsItem`][vllm.multimodal.inputs.MultiModalKwargsItem]s
    by modality.
    """

    @staticmethod
    def from_hf_inputs(
        hf_inputs: "BatchFeature",
        config_by_key: Mapping[str, MultiModalFieldConfig],
    ):
        """
        Split the outputs of a HuggingFace processor into per-item keyword
        arguments, as described by `config_by_key`.

        Raises:
            ValueError: If the fields of one modality yield different
                numbers of items.
        """
        # NOTE: This skips fields in `hf_inputs` that are not in `config_by_key`
        # We assume that those fields are not used in vLLM
        elems_by_key = dict[str, Sequence[MultiModalFieldElem]]()
        keys_by_modality = defaultdict[str, set[str]](set)
        for key, config in config_by_key.items():
            batch = hf_inputs.get(key)
            if batch is not None:
                elems = config.build_elems(key, batch)
                if len(elems) > 0:
                    elems_by_key[key] = elems
                    keys_by_modality[config.modality].add(key)

        items = list[MultiModalKwargsItem]()
        for modality, keys in keys_by_modality.items():
            elems_in_modality = {k: elems_by_key[k] for k in keys}
            batch_sizes = {k: len(v) for k, v in elems_in_modality.items()}

            if len(set(batch_sizes.values())) > 1:
                raise ValueError(
                    f"Cannot merge different batch sizes for {modality=}! "
                    f"Found: {batch_sizes=}"
                )

            # Regroup the per-key elements into one item per index.
            batch_size = next(iter(batch_sizes.values()))
            for item_idx in range(batch_size):
                elems = [v[item_idx] for v in elems_in_modality.values()]
                items.append(MultiModalKwargsItem.from_elems(elems))

        return MultiModalKwargsItems.from_seq(items)

    @staticmethod
    def from_seq(items: Sequence[MultiModalKwargsItem]):
        """Group `items` by their modality."""
        items_by_modality = full_groupby(items, key=lambda x: x.modality)
        return MultiModalKwargsItems(items_by_modality)

    def __getitem__(self, modality: str) -> Sequence[_I]:
        if modality not in self:
            raise KeyError(
                f"Modality {modality!r} not found. "
                f"Available modalities: {set(self.keys())}"
            )

        return super().__getitem__(modality)  # type: ignore[return-value]

    def require_data(self) -> "MultiModalKwargsItems[MultiModalKwargsItem]":
        """
        Return `self`, typed as containing no `None` items.

        Raises:
            RuntimeError: If any item is `None`.
        """
        for modality, items in self.items():
            for i, item in enumerate(items):
                if item is None:
                    raise RuntimeError(f"Found empty mm_items[{modality}][{i}]")

        return self  # type: ignore[return-value]

    def get_data(self, *, pin_memory: bool = False) -> "MultiModalKwargs":
        """
        Merge the elements of all items into batched keyword arguments.

        Elements are grouped by key across all modalities, and each group is
        reduced with the field of its first element.

        Raises:
            RuntimeError: If any item is `None`.
        """
        elems_by_key = defaultdict[str, list[MultiModalFieldElem]](list)
        for modality, items in self.items():
            for i, item in enumerate(items):
                if item is None:
                    raise RuntimeError(
                        f"Cannot build data from empty mm_items[{modality}][{i}]"
                    )

                for key, elem in item.items():
                    elems_by_key[key].append(elem)

        return MultiModalKwargs(
            {
                key: elems[0].field.reduce_data(elems, pin_memory=pin_memory)
                for key, elems in elems_by_key.items()
            }
        )
843

844

845
846
847
848
# Either a `MultiModalKwargsItems` whose items are all present, or one in
# which some items may be `None`. Use `MultiModalKwargsItems.require_data`
# to narrow to the former; it raises if any item is `None`.
MultiModalKwargsOptionalItems: TypeAlias = (
    MultiModalKwargsItems[MultiModalKwargsItem]
    | MultiModalKwargsItems[MultiModalKwargsItem | None]
)
849
850


851
852
853
854
855
856
857
class MultiModalKwargs(UserDict[str, NestedTensors]):
    """
    A dictionary that represents the keyword arguments to
    [`torch.nn.Module.forward`][].
    """

    @staticmethod
    @deprecated(
        "`MultiModalKwargs.from_hf_inputs` is deprecated and "
        "will be removed in v0.13. "
        "Please use `MultiModalKwargsItems.from_hf_inputs` and "
        "access the tensor data using `.get_data()`."
    )
    def from_hf_inputs(
        hf_inputs: "BatchFeature",
        config_by_key: Mapping[str, MultiModalFieldConfig],
    ):
        return MultiModalKwargsItems.from_hf_inputs(hf_inputs, config_by_key).get_data()

    @staticmethod
    @deprecated(
        "`MultiModalKwargs.from_items` is deprecated and "
        "will be removed in v0.13. "
        "Please use `MultiModalKwargsItems.from_seq` and "
        "access the tensor data using `.get_data()`."
    )
    def from_items(
        items: Sequence[MultiModalKwargsItem],
        *,
        pin_memory: bool = False,
    ):
        return MultiModalKwargsItems.from_seq(items).get_data(pin_memory=pin_memory)

    @staticmethod
    def _try_stack(
        nested_tensors: NestedTensors, pin_memory: bool = False
    ) -> NestedTensors:
        """
        Stack the inner dimensions that have the same shape in
        a nested list of tensors.

        Thus, a dimension represented by a list means that the inner
        dimensions are different for each element along that dimension.
        """
        if isinstance(nested_tensors, torch.Tensor):
            return nested_tensors

        # TODO: Remove these once all models have been migrated
        if isinstance(nested_tensors, np.ndarray):
            return torch.from_numpy(nested_tensors)
        if isinstance(nested_tensors, (int, float)):
            return torch.tensor(nested_tensors)

        stacked = [MultiModalKwargs._try_stack(t, pin_memory) for t in nested_tensors]
        if not is_list_of(stacked, torch.Tensor, check="all"):
            # Only tensors (not lists) can be stacked.
            return stacked

        tensors_ = cast(list[torch.Tensor], stacked)
        if len(tensors_) == 1:
            # An optimization when `tensors_` contains only one tensor:
            # - produce exactly same result as `torch.stack(tensors_)`
            # - will achieve zero-copy if the tensor is contiguous
            return tensors_[0].unsqueeze(0).contiguous()

        if any(t.shape != tensors_[0].shape for t in tensors_):
            # The tensors have incompatible shapes and can't be stacked.
            return tensors_

        outputs = torch.empty(
            len(tensors_),
            *tensors_[0].shape,
            dtype=tensors_[0].dtype,
            device=tensors_[0].device,
            pin_memory=pin_memory,
        )
        return torch.stack(tensors_, out=outputs)

    @staticmethod
    def batch(
        inputs_list: list["MultiModalKwargs"], pin_memory: bool = False
    ) -> BatchedTensorInputs:
        """
        Batch multiple inputs together into a dictionary.

        The resulting dictionary has the same keys as the inputs.
        If the corresponding value from each input is a tensor and they all
        share the same shape, the output value is a single batched tensor;
        otherwise, the output value is a list containing the original value
        from each input.
        """
        if len(inputs_list) == 0:
            return {}

        # We need to consider the case where each item in the batch
        # contains different modalities (i.e. different keys).
        item_lists = defaultdict[str, list[NestedTensors]](list)

        for inputs in inputs_list:
            for k, v in inputs.items():
                item_lists[k].append(v)

        return {
            k: MultiModalKwargs._try_stack(item_list, pin_memory)
            for k, item_list in item_lists.items()
        }

    @staticmethod
    def as_kwargs(
        batched_inputs: BatchedTensorInputs,
        *,
        device: torch.types.Device,
    ) -> BatchedTensorInputs:
        """Move every tensor leaf of `batched_inputs` to `device`
        (with `non_blocking=True`)."""
        return json_map_leaves(
            lambda x: x.to(device=device, non_blocking=True),
            batched_inputs,
        )

    def __getitem__(self, key: str):
        if key not in self:
            raise KeyError(
                f"Keyword argument {key!r} not found. "
                f"Available keys: {set(self.keys())}"
            )

        return super().__getitem__(key)

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, self.__class__):
            return False

        # Compare the key sets first so that equality is symmetric: an extra
        # key on either side makes the two objects unequal.
        if self.keys() != other.keys():
            return False

        return all(nested_tensors_equal(self[k], other[k]) for k in self)
989

990

991
MultiModalPlaceholderDict: TypeAlias = Mapping[str, Sequence[PlaceholderRange]]
"""
A dictionary containing placeholder ranges for each modality.
"""


997
class MultiModalInputs(TypedDict):
    """
    Represents the outputs of
    [`BaseMultiModalProcessor`][vllm.multimodal.processing.BaseMultiModalProcessor],
    ready to be passed to vLLM internals.
    """

    type: Literal["multimodal"]
    """The type of inputs."""

    prompt_token_ids: list[int]
    """The processed token IDs, which include placeholder tokens."""

    mm_kwargs: MultiModalKwargsOptionalItems
    """Keyword arguments to be directly passed to the model after batching."""

    mm_hashes: "MultiModalHashes"
    """The hashes of the multi-modal data."""

    mm_placeholders: "MultiModalPlaceholderDict"
    """
    For each modality, information about the placeholder tokens in
    `prompt_token_ids`.
    """

    cache_salt: NotRequired[str]
    """
    Optional cache salt to be used for prefix caching.
    """

1027
1028
1029

class MultiModalEncDecInputs(MultiModalInputs):
    """
    Represents the outputs of
    [`EncDecMultiModalProcessor`][vllm.multimodal.processing.EncDecMultiModalProcessor]
    ready to be passed to vLLM internals.
    """

    encoder_prompt_token_ids: list[int]
    """The processed token IDs of the encoder prompt."""