inputs.py 30.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from abc import ABC, abstractmethod
5
from collections import UserDict, defaultdict
6
from collections.abc import Mapping, Sequence
7
from dataclasses import dataclass
8
from functools import partial
9
from itertools import accumulate
10
11
12
13
14
15
16
17
18
19
20
from typing import (
    TYPE_CHECKING,
    Any,
    Literal,
    Optional,
    TypeAlias,
    TypedDict,
    Union,
    cast,
    final,
)
21
22

import numpy as np
23
from typing_extensions import NotRequired, TypeVar, deprecated
24

25
from vllm.utils.collection_utils import full_groupby, is_list_of
26
from vllm.utils.import_utils import LazyLoader
27
from vllm.utils.jsontree import json_map_leaves
28

29
if TYPE_CHECKING:
30
31
32
33
34
    import torch
    import torch.types
    from PIL.Image import Image
    from transformers.feature_extraction_utils import BatchFeature

35
36
    from .processing import MultiModalHashes

37
38
else:
    torch = LazyLoader("torch", globals(), "torch")
39

40
41
_T = TypeVar("_T")

42
HfImageItem: TypeAlias = Union["Image", np.ndarray, "torch.Tensor"]
43
"""
44
A `transformers.image_utils.ImageInput` representing a single image
45
item, which can be passed to a HuggingFace `ImageProcessor`.
46
47
"""

48
49
50
HfVideoItem: TypeAlias = Union[
    list["Image"], np.ndarray, "torch.Tensor", list[np.ndarray], list["torch.Tensor"]
]
51
"""
52
A `transformers.image_utils.VideoInput` representing a single video
53
item, which can be passed to a HuggingFace `VideoProcessor`.
54
55
"""

56
HfAudioItem: TypeAlias = Union[list[float], np.ndarray, "torch.Tensor"]
57
"""
58
Represents a single audio
59
item, which can be passed to a HuggingFace `AudioProcessor`.
60
61
"""

62
ImageItem: TypeAlias = Union[HfImageItem, "torch.Tensor"]
63
"""
64
A `transformers.image_utils.ImageInput` representing a single image
65
item, which can be passed to a HuggingFace `ImageProcessor`.
66
67
68
69
70
71

Alternatively, a 3-D tensor or batch of 2-D tensors,
which are treated as image embeddings;
these are directly passed to the model without HF processing.
"""

72
73
74
VideoItem: TypeAlias = Union[
    HfVideoItem, "torch.Tensor", tuple[HfVideoItem, dict[str, Any]]
]
75
"""
76
77
78
A `transformers.video_utils.VideoInput` representing a single video item. 
This can be passed to a HuggingFace `VideoProcessor` 
with `transformers.video_utils.VideoMetadata`.
79
80
81
82
83
84

Alternatively, a 3-D tensor or batch of 2-D tensors,
which are treated as video embeddings;
these are directly passed to the model without HF processing.
"""

85
AudioItem: TypeAlias = Union[HfAudioItem, tuple[np.ndarray, float], "torch.Tensor"]
86
87
"""
Represents a single audio
88
item, which can be passed to a HuggingFace `AudioProcessor`.
89
90
91
92
93
94
95
96
97
98

Alternatively, a tuple `(audio, sampling_rate)`, where the sampling rate
is different from that expected by the model;
these are resampled to the model's sampling rate before being processed by HF.

Alternatively, a 3-D tensor or batch of 2-D tensors,
which are treated as audio embeddings;
these are directly passed to the model without HF processing.
"""

99
ModalityData: TypeAlias = _T | list[_T | None] | None
100
"""
101
102
Either a single data item, or a list of data items. Can only be None if UUID
is provided.
103
104

The number of data items allowed per modality is restricted by
105
`--limit-mm-per-prompt`.
106
107
108
109
110
111
112
"""


@final
class MultiModalDataBuiltins(TypedDict, total=False):
    """
    Type annotations for the modality types that vLLM predefines.

    Every key is optional, since the dict is declared with `total=False`.
    """

    image: ModalityData[ImageItem]
    """One or more input images."""

    video: ModalityData[VideoItem]
    """One or more input videos."""

    audio: ModalityData[AudioItem]
    """One or more input audio clips."""


123
MultiModalDataDict: TypeAlias = Mapping[str, ModalityData[Any]]
"""
Maps each input modality name to the data supplied for it.

The modalities that vLLM knows about out of the box are listed in
[`MultiModalDataBuiltins`][vllm.multimodal.inputs.MultiModalDataBuiltins].
"""

MultiModalUUIDDict: TypeAlias = Mapping[str, list[str | None] | str]
"""
Maps each modality name to the UUIDs the user supplied for its items.
An item whose UUID was not supplied has a `None` entry, and its hash is
then computed by MultiModalHasher instead.

Whichever identifier results is used for every kind of caching of the item
(input processing caching, embedding caching, prefix caching, etc).
"""

141

142
143
@dataclass(frozen=True)
class PlaceholderRange:
    """
    Placeholder location information for multi-modal data.

    Example:

    Prompt: `AAAA BBBB What is in these images?`

    Images A and B will have:

    ```
    A: PlaceholderRange(offset=0, length=4)
    B: PlaceholderRange(offset=5, length=4)
    ```
    """

    offset: int
    """The start index of the placeholder in the prompt."""

    length: int
    """The length of the placeholder."""

    is_embed: Optional["torch.Tensor"] = None
    """
    A boolean mask of shape `(length,)` indicating which positions
    between `offset` and `offset + length` to assign embeddings to.
    """

    def get_num_embeds(self) -> int:
        """
        Return how many positions of this range receive embeddings.

        Without a mask this is `length`; otherwise it is the sum of
        `is_embed` (the number of `True` entries of a boolean mask).
        """
        if self.is_embed is None:
            return self.length

        return int(self.is_embed.sum().item())

    def __eq__(self, other: object) -> bool:
        """
        Compare `offset` and `length`, then compare the `is_embed` masks by
        value. A missing mask only equals another missing mask.
        """
        if not isinstance(other, self.__class__):
            return False
        if (self.offset, self.length) != (other.offset, other.length):
            return False

        if self.is_embed is None or other.is_embed is None:
            return self.is_embed is None and other.is_embed is None

        return nested_tensors_equal(self.is_embed, other.is_embed)

    def __hash__(self) -> int:
        # `__eq__` compares `is_embed` by value, but tensors hash by identity.
        # The dataclass-generated hash would include `is_embed` and could give
        # equal ranges different hashes, so hash only the plain fields.
        return hash((self.offset, self.length))

190

191
192
193
194
195
196
NestedTensors: TypeAlias = Union[
    list["NestedTensors"],
    list["torch.Tensor"],
    "torch.Tensor",
    tuple["torch.Tensor", ...],
]
197
198
199
200
"""
Uses a list instead of a tensor if the dimensions of each element do not match.
"""

201
202

def nested_tensors_equal(a: "NestedTensors", b: "NestedTensors") -> bool:
    """
    Equality check between
    [`NestedTensors`][vllm.multimodal.inputs.NestedTensors] objects.

    - Tensors are equal only to tensors with the same shape and values
      (`torch.equal`).
    - Lists are equal only to lists, and tuples only to tuples, of the same
      length whose elements are pairwise equal (checked recursively).
    - Any other values are compared with `==`.
    """
    if isinstance(a, torch.Tensor) or isinstance(b, torch.Tensor):
        return (
            isinstance(a, torch.Tensor)
            and isinstance(b, torch.Tensor)
            and torch.equal(a, b)
        )

    for seq_type in (list, tuple):
        if isinstance(a, seq_type) or isinstance(b, seq_type):
            # `zip` stops at the shorter input, so the lengths must be
            # compared explicitly; otherwise `[t]` would equal `[t, t]`.
            return (
                isinstance(a, seq_type)
                and isinstance(b, seq_type)
                and len(a) == len(b)
                and all(nested_tensors_equal(a_, b_) for a_, b_ in zip(a, b))
            )

    # Both a and b are scalars
    return a == b


223
BatchedTensorInputs: TypeAlias = dict[str, NestedTensors]
"""
Maps each keyword-argument name to its nested tensors after they have been
batched together via
[`MultiModalKwargs.batch`][vllm.multimodal.inputs.MultiModalKwargs.batch].
"""


230
231
232
233
@dataclass
class MultiModalFeatureSpec:
    """
    A single multimodal input, together with its processed data and metadata.

    The V1 engine uses this to follow multimodal data through processing and
    caching. A request with several multimodal items carries one
    `MultiModalFeatureSpec` per item.
    """

    data: Optional["MultiModalKwargsItem"]
    """The processed multimodal data of this feature (may be `None`)."""

    modality: str
    """The modality of the input, e.g. "image", "audio" or "video"."""

    identifier: str
    """The mm_hash or UUID used as the key when caching encoder outputs."""

    mm_position: PlaceholderRange
    """
    Where the item's placeholder lies in the prompt,
    e.g. `PlaceholderRange(offset=2, length=336)`.
    """


253
@dataclass
class MultiModalFieldElem:
    """
    One keyword argument that belongs to a single multi-modal item in
    [`MultiModalKwargs`][vllm.multimodal.inputs.MultiModalKwargs].
    """

    modality: str
    """
    Modality of the multi-modal item this element belongs to.
    A single multi-modal item may be made up of several keyword arguments.
    """

    key: str
    """
    Name of this field in
    [`MultiModalKwargs`][vllm.multimodal.inputs.MultiModalKwargs],
    i.e. the name of the keyword argument passed to the model.
    """

    data: NestedTensors
    """
    Tensor data of this field in
    [`MultiModalKwargs`][vllm.multimodal.inputs.MultiModalKwargs],
    i.e. the value of the keyword argument passed to the model.

    It can be `None` once the item is known to be cached
    in `EngineCore`.
    """

    field: "BaseMultiModalField"
    """
    Describes how this field's tensor data is combined with that of other
    elements, so that multi-modal items can be batched for model inference.
    """

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, self.__class__):
            return False

        # A missing payload is only equal to another missing payload.
        if self.data is not None and other.data is not None:
            same_data = nested_tensors_equal(self.data, other.data)
        else:
            same_data = self.data is None and other.data is None

        same_identity = (self.modality, self.key) == (other.modality, other.key)
        # Fields are compared by their type only.
        same_field_kind = type(self.field) is type(other.field)  # noqa: E721

        return same_identity and same_data and same_field_kind
305
306
307
308


@dataclass(frozen=True)
class BaseMultiModalField(ABC):
    """
    Describes how the tensor data of one keyword argument in
    [`MultiModalKwargs`][vllm.multimodal.inputs.MultiModalKwargs] is split
    across several multi-modal items, and how it is merged back together.
    """

    def _field_factory(self, *, modality: str, key: str):
        """
        Return a callable taking `data` positionally and wrapping it in a
        `MultiModalFieldElem` tagged with `modality`, `key` and this field.
        """

        def factory(data: NestedTensors) -> MultiModalFieldElem:
            return MultiModalFieldElem(
                modality=modality,
                key=key,
                data=data,
                field=self,
            )

        return factory

    @abstractmethod
    def build_elems(
        self,
        modality: str,
        key: str,
        data: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        """
        Split `data` into
        [`MultiModalFieldElem`][vllm.multimodal.inputs.MultiModalFieldElem]
        instances, one per multi-modal item.

        The inverse operation is
        [`reduce_data`][vllm.multimodal.inputs.BaseMultiModalField.reduce_data].
        """
        raise NotImplementedError

    @abstractmethod
    def _reduce_data(
        self,
        batch: list[NestedTensors],
        *,
        pin_memory: bool,
    ) -> NestedTensors:
        """Combine the raw data of several elements into a single value."""
        raise NotImplementedError

    def reduce_data(
        self,
        elems: list[MultiModalFieldElem],
        *,
        pin_memory: bool = False,
    ) -> NestedTensors:
        """
        Combine the data held by several
        [`MultiModalFieldElem`][vllm.multimodal.inputs.MultiModalFieldElem]
        instances.

        The inverse operation is
        [`build_elems`][vllm.multimodal.inputs.BaseMultiModalField.build_elems].

        Raises:
            ValueError: If the elements do not all share one field type.
        """
        field_types = [type(elem.field) for elem in elems]
        if len(set(field_types)) > 1:
            raise ValueError(f"Cannot merge different {field_types=}")

        payloads = [elem.data for elem in elems]
        return self._reduce_data(payloads, pin_memory=pin_memory)
374
375
376
377
378


@dataclass(frozen=True)
class MultiModalBatchedField(BaseMultiModalField):
    """
    Info:
        [`MultiModalFieldConfig.batched`][vllm.multimodal.inputs.MultiModalFieldConfig.batched]
    """

    def build_elems(
        self,
        modality: str,
        key: str,
        data: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        make_elem = self._field_factory(modality=modality, key=key)
        # Each entry obtained by iterating `data` becomes one item's element.
        return list(map(make_elem, data))

    def _reduce_data(
        self,
        batch: list[NestedTensors],
        *,
        pin_memory: bool,
    ) -> NestedTensors:
        all_tensors = len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all")
        if not all_tensors:
            return batch

        tensors = cast(list[torch.Tensor], batch)
        if len(tensors) == 1:
            # Shortcut for a single tensor: same result as `torch.stack`,
            # and zero-copy when the tensor is already contiguous.
            return tensors[0].unsqueeze(0).contiguous()

        head = tensors[0]
        if any(t.shape != head.shape for t in tensors):
            # Ragged shapes cannot be stacked; keep them as a list.
            return batch

        out = torch.empty(
            (len(tensors), *head.shape),
            dtype=head.dtype,
            device=head.device,
            pin_memory=pin_memory,
        )
        return torch.stack(tensors, out=out)


@dataclass(frozen=True)
class MultiModalFlatField(BaseMultiModalField):
    """
    Info:
        [`MultiModalFieldConfig.flat`][vllm.multimodal.inputs.MultiModalFieldConfig.flat]
        [`MultiModalFieldConfig.flat_from_sizes`][vllm.multimodal.inputs.MultiModalFieldConfig.flat_from_sizes]
    """

    slices: Sequence[slice] | Sequence[Sequence[slice]]
    dim: int = 0

    def build_elems(
        self,
        modality: str,
        key: str,
        data: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        make_elem = self._field_factory(modality=modality, key=key)
        only_plain_slices = is_list_of(self.slices, slice, check="all")
        if not only_plain_slices:
            # Tuples of slices (multi-dimensional indexing) need a real tensor.
            assert isinstance(data, torch.Tensor), (
                "torch.Tensor is required for multiple slices"
            )
        return [make_elem(data[cast(slice, sl)]) for sl in self.slices]

    def _reduce_data(
        self,
        batch: list[NestedTensors],
        *,
        pin_memory: bool,
    ) -> NestedTensors:
        if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"):
            tensors = cast(list[torch.Tensor], batch)
            if len(tensors) == 1:
                # Shortcut for a single tensor: same result as `torch.concat`,
                # and zero-copy when the tensor is already contiguous.
                return tensors[0].contiguous()

            first = tensors[0]
            # Turn a negative `dim` into its positive equivalent.
            axis = self.dim if self.dim >= 0 else self.dim + first.ndim

            def split_around_axis(t: torch.Tensor):
                return t.shape[:axis], t.shape[axis + 1 :]

            reference = split_around_axis(first)
            if all(split_around_axis(t) == reference for t in tensors):
                before, after = reference
                total = sum(t.shape[axis] for t in tensors)
                out = torch.empty(
                    (*before, total, *after),
                    dtype=first.dtype,
                    device=first.device,
                    pin_memory=pin_memory,
                )
                return torch.concat(tensors, dim=self.dim, out=out)

        # Fallback: flatten one level of nesting into a single list.
        assert self.dim == 0, "dim == 0 is required for nested list"
        return [part for elem in batch for part in elem]
476
477


478
479
480
@dataclass(frozen=True)
class MultiModalSharedField(BaseMultiModalField):
    """
    Info:
        [`MultiModalFieldConfig.shared`][vllm.multimodal.inputs.MultiModalFieldConfig.shared]
    """

    batch_size: int

    def build_elems(
        self,
        modality: str,
        key: str,
        data: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        make_elem = self._field_factory(modality=modality, key=key)
        # A single element object is repeated for every item in the batch.
        shared_elem = make_elem(data)
        return [shared_elem] * self.batch_size

    def _reduce_data(
        self,
        batch: list[NestedTensors],
        *,
        pin_memory: bool,
    ) -> NestedTensors:
        # The data is shared by all items, so the first entry stands for all.
        first, *_ = batch
        return first


505
506
507
class MultiModalFieldConfig:
    """
    Pairs a
    [`BaseMultiModalField`][vllm.multimodal.inputs.BaseMultiModalField]
    with the modality of the multi-modal items that use it.

    Instances are created through the static constructors `batched`, `flat`,
    `flat_from_sizes` and `shared`.
    """

    @staticmethod
    def batched(modality: str):
        """
        Defines a field where an element in the batch is obtained by
        indexing into the first dimension of the underlying data.

        Args:
            modality: The modality of the multi-modal item that uses this
                keyword argument.

        Example:

        ```
        Input:
            Data: [[AAAA]
                [BBBB]
                [CCCC]]

        Output:
            Element 1: [AAAA]
            Element 2: [BBBB]
            Element 3: [CCCC]
        ```
        """
        return MultiModalFieldConfig(
            field=MultiModalBatchedField(),
            modality=modality,
        )

    @staticmethod
    def flat(
        modality: str,
        slices: Sequence[slice] | Sequence[Sequence[slice]],
        dim: int = 0,
    ):
        """
        Defines a field where an element in the batch is obtained by
        slicing the underlying data along dimension `dim` (the first
        dimension by default).

        Args:
            modality: The modality of the multi-modal item that uses this
                keyword argument.
            slices: For each multi-modal item, a slice (dim=0) or a tuple of
                slices (dim>0) that is used to extract the data corresponding
                to it.
            dim: The dimension along which data is extracted. Defaults to 0.

        Example:

        ```
        Given:
            slices: [slice(0, 3), slice(3, 7), slice(7, 9)]

        Input:
            Data: [AAABBBBCC]

        Output:
            Element 1: [AAA]
            Element 2: [BBBB]
            Element 3: [CC]
        ```

        ```
        Given:
            slices: [
                (slice(None), slice(0, 3)),
                (slice(None), slice(3, 7)),
                (slice(None), slice(7, 9))]
            dim: 1

        Input:
            Data: [[AAABBBBCC]]

        Output:
            Element 1: [[AAA]]
            Element 2: [[BBBB]]
            Element 3: [[CC]]
        ```
        """
        return MultiModalFieldConfig(
            field=MultiModalFlatField(slices=slices, dim=dim),
            modality=modality,
        )

    @staticmethod
    def flat_from_sizes(modality: str, size_per_item: "torch.Tensor", dim: int = 0):
        """
        Defines a field where an element in the batch is obtained by
        slicing the underlying data along dimension `dim` (the first
        dimension by default), using consecutive slices of the given sizes.

        Args:
            modality: The modality of the multi-modal item that uses this
                keyword argument.
            size_per_item: A 1-D tensor holding, for each multi-modal item,
                the size of the slice that is used to extract the data
                corresponding to it.
            dim: The dimension to slice. Defaults to 0.

        Raises:
            ValueError: If `size_per_item` is not 1-D.

        Example:

        ```
        Given:
            size_per_item: [3, 4, 2]

        Input:
            Data: [AAABBBBCC]

        Output:
            Element 1: [AAA]
            Element 2: [BBBB]
            Element 3: [CC]
        ```

        ```
        Given:
            size_per_item: [3, 4, 2]
            dim: 1

        Input:
            Data: [[AAABBBBCC]]

        Output:
            Element 1: [[AAA]]
            Element 2: [[BBBB]]
            Element 3: [[CC]]
        ```

        Info:
            [`MultiModalFieldConfig.flat`][vllm.multimodal.inputs.MultiModalFieldConfig.flat]
        """

        if size_per_item.ndim != 1:
            raise ValueError(
                "size_per_item should be a 1-D tensor, "
                f"but found shape: {size_per_item.shape}"
            )

        # Running offsets: item i spans [slice_idxs[i], slice_idxs[i + 1]).
        slice_idxs = [0, *accumulate(size_per_item)]
        # Take every index of the leading `dim` dimensions, then slice `dim`.
        slices = [
            (slice(None, None, None),) * dim
            + (slice(slice_idxs[i], slice_idxs[i + 1]),)
            for i in range(len(size_per_item))
        ]

        return MultiModalFieldConfig.flat(modality, slices, dim=dim)

    @staticmethod
    def shared(modality: str, batch_size: int):
        """
        Defines a field where an element in the batch is obtained by
        taking the entirety of the underlying data.

        This means that the data is the same for each element in the batch.

        Args:
            modality: The modality of the multi-modal item that uses this
                keyword argument.
            batch_size: The number of multi-modal items which share this data.

        Example:

        ```
        Given:
            batch_size: 4

        Input:
            Data: [XYZ]

        Output:
            Element 1: [XYZ]
            Element 2: [XYZ]
            Element 3: [XYZ]
            Element 4: [XYZ]
        ```
        """
        return MultiModalFieldConfig(
            field=MultiModalSharedField(batch_size),
            modality=modality,
        )

    def __init__(self, field: BaseMultiModalField, modality: str) -> None:
        super().__init__()

        self.field = field
        self.modality = modality

    def __repr__(self) -> str:
        return f"MultiModalFieldConfig(field={self.field}, modality={self.modality})"

    def build_elems(
        self,
        key: str,
        batch: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        """Split `batch` into per-item elements via the configured field."""
        return self.field.build_elems(self.modality, key, batch)
700
701


702
703
class MultiModalKwargsItem(UserDict[str, MultiModalFieldElem]):
    """
    A collection of
    [`MultiModalFieldElem`][vllm.multimodal.inputs.MultiModalFieldElem]
    corresponding to a data item in
    [`MultiModalDataItems`][vllm.multimodal.parse.MultiModalDataItems].

    All elements of one item must share the same modality.
    """

    @staticmethod
    def dummy(modality: str):
        """Convenience constructor for testing: one shared `"dummy"` element."""
        mm_elem = MultiModalFieldElem(
            modality=modality,
            key="dummy",
            data=torch.empty(1),
            field=MultiModalSharedField(1),
        )
        return MultiModalKwargsItem.from_elems([mm_elem])

    @staticmethod
    def from_elems(elems: Sequence[MultiModalFieldElem]):
        """Build an item keyed by each element's `key` (later duplicates win)."""
        return MultiModalKwargsItem({elem.key: elem for elem in elems})

    def __init__(self, data: Mapping[str, MultiModalFieldElem] | None = None) -> None:
        # `None` rather than a `{}` default avoids a shared mutable default;
        # `UserDict.__init__` treats `None` as "no initial data".
        super().__init__(data)

        modalities = {elem.modality for elem in self.values()}
        assert len(modalities) == 1, f"Found different modalities={modalities}"
        self._modality = next(iter(modalities))

    @property
    def modality(self) -> str:
        """The modality shared by every element of this item."""
        return self._modality

    def get_data(self) -> dict[str, NestedTensors]:
        """Return a mapping from each key to its element's raw data."""
        return {key: elem.data for key, elem in self.items()}
738
739


740
741
742
# Element type of `MultiModalKwargsItems`: either items that are always present,
# or items that may be `None` (see `MultiModalKwargsItems.require_data`).
_I = TypeVar(
    "_I",
    MultiModalKwargsItem,
    MultiModalKwargsItem | None,
    default=MultiModalKwargsItem,
)


class MultiModalKwargsItems(UserDict[str, Sequence[_I]]):
    """
    Groups
    [`MultiModalKwargsItem`][vllm.multimodal.inputs.MultiModalKwargsItem]s
    by their modality.
    """

    @staticmethod
    def from_hf_inputs(
        hf_inputs: "BatchFeature",
        config_by_key: Mapping[str, MultiModalFieldConfig],
    ):
        """
        Split the batched outputs of a HuggingFace processor into per-item
        keyword arguments, grouped by modality.

        Keys of `hf_inputs` that have no entry in `config_by_key` are skipped,
        on the assumption that vLLM does not use them.

        Raises:
            ValueError: If the fields of one modality disagree on the number
                of items.
        """
        elems_by_key: dict[str, Sequence[MultiModalFieldElem]] = {}
        keys_by_modality: defaultdict[str, set[str]] = defaultdict(set)

        for key, config in config_by_key.items():
            batch = hf_inputs.get(key)
            if batch is None:
                continue

            elems = config.build_elems(key, batch)
            if len(elems) == 0:
                continue

            elems_by_key[key] = elems
            keys_by_modality[config.modality].add(key)

        items: list[MultiModalKwargsItem] = []
        for modality, keys in keys_by_modality.items():
            modality_elems = {k: elems_by_key[k] for k in keys}
            batch_sizes = {k: len(v) for k, v in modality_elems.items()}

            distinct_sizes = set(batch_sizes.values())
            if len(distinct_sizes) > 1:
                raise ValueError(
                    f"Cannot merge different batch sizes for {modality=}! "
                    f"Found: {batch_sizes=}"
                )

            (batch_size,) = distinct_sizes
            for item_idx in range(batch_size):
                per_item = [elems[item_idx] for elems in modality_elems.values()]
                items.append(MultiModalKwargsItem.from_elems(per_item))

        return MultiModalKwargsItems.from_seq(items)

    @staticmethod
    def from_seq(items: Sequence[MultiModalKwargsItem]):
        """Group `items` by their `modality` attribute."""
        grouped = full_groupby(items, key=lambda item: item.modality)
        return MultiModalKwargsItems(grouped)

    def __getitem__(self, modality: str) -> Sequence[_I]:
        if modality in self:
            return super().__getitem__(modality)  # type: ignore[return-value]

        raise KeyError(
            f"Modality {modality!r} not found. "
            f"Available modalities: {set(self.keys())}"
        )

    def require_data(self) -> "MultiModalKwargsItems[MultiModalKwargsItem]":
        """
        Return `self` unchanged after checking that no item is `None`.

        Raises:
            RuntimeError: If any item is `None`.
        """
        for modality, items in self.items():
            missing = [i for i, item in enumerate(items) if item is None]
            if missing:
                i = missing[0]
                raise RuntimeError(f"Found empty mm_items[{modality}][{i}]")

        return self  # type: ignore[return-value]

    def get_data(self, *, pin_memory: bool = False) -> "MultiModalKwargs":
        """
        Merge the elements of all items, key by key, into model keyword
        arguments using each key's field.

        Raises:
            RuntimeError: If any item is `None`.
        """
        elems_by_key: defaultdict[str, list[MultiModalFieldElem]] = defaultdict(list)

        for modality, items in self.items():
            for i, item in enumerate(items):
                if item is None:
                    raise RuntimeError(
                        f"Cannot build data from empty mm_items[{modality}][{i}]"
                    )

                for key, elem in item.items():
                    elems_by_key[key].append(elem)

        reduced = {
            key: elems[0].field.reduce_data(elems, pin_memory=pin_memory)
            for key, elems in elems_by_key.items()
        }
        return MultiModalKwargs(reduced)
830

831

832
833
834
835
MultiModalKwargsOptionalItems: TypeAlias = (
    MultiModalKwargsItems[MultiModalKwargsItem]
    | MultiModalKwargsItems[MultiModalKwargsItem | None]
)
"""
`MultiModalKwargsItems` whose items are either all present, or may include
`None` entries.
"""
836
837


838
839
840
841
842
843
844
class MultiModalKwargs(UserDict[str, NestedTensors]):
    """
    A dictionary that represents the keyword arguments to
    [`torch.nn.Module.forward`][].
    """

    @staticmethod
    @deprecated(
        "`MultiModalKwargs.from_hf_inputs` is deprecated and "
        "will be removed in v0.13. "
        "Please use `MultiModalKwargsItems.from_hf_inputs` and "
        "access the tensor data using `.get_data()`."
    )
    def from_hf_inputs(
        hf_inputs: "BatchFeature",
        config_by_key: Mapping[str, MultiModalFieldConfig],
    ):
        return MultiModalKwargsItems.from_hf_inputs(hf_inputs, config_by_key).get_data()

    @staticmethod
    @deprecated(
        "`MultiModalKwargs.from_items` is deprecated and "
        "will be removed in v0.13. "
        "Please use `MultiModalKwargsItems.from_seq` and "
        "access the tensor data using `.get_data()`."
    )
    def from_items(
        items: Sequence[MultiModalKwargsItem],
        *,
        pin_memory: bool = False,
    ):
        return MultiModalKwargsItems.from_seq(items).get_data(pin_memory=pin_memory)

    @staticmethod
    def _try_stack(
        nested_tensors: NestedTensors, pin_memory: bool = False
    ) -> NestedTensors:
        """
        Stack the inner dimensions that have the same shape in
        a nested list of tensors.

        Thus, a dimension represented by a list means that the inner
        dimensions are different for each element along that dimension.
        """
        if isinstance(nested_tensors, torch.Tensor):
            return nested_tensors

        # TODO: Remove these once all models have been migrated
        if isinstance(nested_tensors, np.ndarray):
            return torch.from_numpy(nested_tensors)
        if isinstance(nested_tensors, (int, float)):
            return torch.tensor(nested_tensors)

        stacked = [MultiModalKwargs._try_stack(t, pin_memory) for t in nested_tensors]
        if not is_list_of(stacked, torch.Tensor, check="all"):
            # Only tensors (not lists) can be stacked.
            return stacked

        tensors_ = cast(list[torch.Tensor], stacked)
        if len(tensors_) == 1:
            # An optimization when `tensors_` contains only one tensor:
            # - produce exactly same result as `torch.stack(tensors_)`
            # - will achieve zero-copy if the tensor is contiguous
            return tensors_[0].unsqueeze(0).contiguous()

        if any(t.shape != tensors_[0].shape for t in tensors_):
            # The tensors have incompatible shapes and can't be stacked.
            return tensors_

        outputs = torch.empty(
            len(tensors_),
            *tensors_[0].shape,
            dtype=tensors_[0].dtype,
            device=tensors_[0].device,
            pin_memory=pin_memory,
        )
        return torch.stack(tensors_, out=outputs)

    @staticmethod
    def batch(
        inputs_list: list["MultiModalKwargs"], pin_memory: bool = False
    ) -> BatchedTensorInputs:
        """
        Batch multiple inputs together into a dictionary.

        The resulting dictionary has the same keys as the inputs.
        If the corresponding value from each input is a tensor and they all
        share the same shape, the output value is a single batched tensor;
        otherwise, the output value is a list containing the original value
        from each input.
        """
        if len(inputs_list) == 0:
            return {}

        # We need to consider the case where each item in the batch
        # contains different modalities (i.e. different keys).
        item_lists = defaultdict[str, list[NestedTensors]](list)

        for inputs in inputs_list:
            for k, v in inputs.items():
                item_lists[k].append(v)

        return {
            k: MultiModalKwargs._try_stack(item_list, pin_memory)
            for k, item_list in item_lists.items()
        }

    @staticmethod
    def as_kwargs(
        batched_inputs: BatchedTensorInputs,
        *,
        device: torch.types.Device,
    ) -> BatchedTensorInputs:
        """Move every tensor leaf of `batched_inputs` to `device` (non-blocking)."""
        return json_map_leaves(
            lambda x: x.to(device=device, non_blocking=True),
            batched_inputs,
        )

    def __getitem__(self, key: str):
        if key not in self:
            raise KeyError(
                f"Keyword argument {key!r} not found. "
                f"Available keys: {set(self.keys())}"
            )

        return super().__getitem__(key)

    def __eq__(self, other: object) -> bool:
        """
        Two instances are equal when they define exactly the same keys and
        every value compares equal via `nested_tensors_equal`.
        """
        if not isinstance(other, self.__class__):
            return False

        # Checking only that `self`'s keys exist in `other` would make
        # equality asymmetric whenever `other` has extra keys.
        if self.keys() != other.keys():
            return False

        return all(nested_tensors_equal(self[k], other[k]) for k in self)
976

977

978
MultiModalPlaceholderDict: TypeAlias = Mapping[str, Sequence[PlaceholderRange]]
"""
Maps each modality to the placeholder ranges it occupies in the prompt.
"""


984
class MultiModalInputs(TypedDict):
    """
    The output of
    [`BaseMultiModalProcessor`][vllm.multimodal.processing.BaseMultiModalProcessor],
    in the form that is handed over to vLLM internals.
    """

    type: Literal["multimodal"]
    """Type tag of these inputs; always `"multimodal"`."""

    prompt_token_ids: list[int]
    """Token IDs after processing, placeholder tokens included."""

    mm_kwargs: MultiModalKwargsOptionalItems
    """Keyword arguments that are passed straight to the model once batched."""

    mm_hashes: "MultiModalHashes"
    """Hashes of the multi-modal data."""

    mm_placeholders: "MultiModalPlaceholderDict"
    """
    Per modality, where its placeholder tokens are located in
    `prompt_token_ids`.
    """

    cache_salt: NotRequired[str]
    """
    Optional salt mixed into prefix caching.
    """

1014
1015
1016

class MultiModalEncDecInputs(MultiModalInputs):
    """
    The output of
    [`EncDecMultiModalProcessor`][vllm.multimodal.processing.EncDecMultiModalProcessor],
    in the form that is handed over to vLLM internals.
    """

    encoder_prompt_token_ids: list[int]
    """Token IDs of the encoder prompt after processing."""