inputs.py 30.1 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from abc import ABC, abstractmethod
5
from collections import UserDict, defaultdict
6
from collections.abc import Mapping, Sequence
7
from dataclasses import dataclass
8
from functools import partial
9
from itertools import accumulate
10
11
from typing import (TYPE_CHECKING, Any, Literal, Optional, TypedDict, Union,
                    cast, final)
12
13

import numpy as np
14
from typing_extensions import NotRequired, TypeAlias, TypeVar, deprecated
15

16
from vllm.utils import LazyLoader, full_groupby, is_list_of
17
from vllm.utils.jsontree import JSONTree, json_map_leaves
18

19
if TYPE_CHECKING:
20
21
22
23
24
    import torch
    import torch.types
    from PIL.Image import Image
    from transformers.feature_extraction_utils import BatchFeature

25
    from .hasher import MultiModalHashDict
26
27
else:
    torch = LazyLoader("torch", globals(), "torch")
28

29
30
_T = TypeVar("_T")

31
HfImageItem: TypeAlias = Union["Image", np.ndarray, "torch.Tensor"]
32
"""
33
A `transformers.image_utils.ImageInput` representing a single image
34
item, which can be passed to a HuggingFace `ImageProcessor`.
35
36
"""

37
38
HfVideoItem: TypeAlias = Union[list["Image"], np.ndarray, "torch.Tensor",
                               list[np.ndarray], list["torch.Tensor"]]
39
"""
40
A `transformers.image_utils.VideoInput` representing a single video
41
item, which can be passed to a HuggingFace `VideoProcessor`.
42
43
"""

44
HfAudioItem: TypeAlias = Union[list[float], np.ndarray, "torch.Tensor"]
45
"""
46
Represents a single audio
47
item, which can be passed to a HuggingFace `AudioProcessor`.
48
49
"""

50
ImageItem: TypeAlias = Union[HfImageItem, "torch.Tensor"]
51
"""
52
A `transformers.image_utils.ImageInput` representing a single image
53
item, which can be passed to a HuggingFace `ImageProcessor`.
54
55
56
57
58
59

Alternatively, a 3-D tensor or batch of 2-D tensors,
which are treated as image embeddings;
these are directly passed to the model without HF processing.
"""

60
61
VideoItem: TypeAlias = Union[HfVideoItem, "torch.Tensor",
                             tuple[HfVideoItem, dict[str, Any]]]
62
"""
63
64
65
A `transformers.video_utils.VideoInput` representing a single video item. 
This can be passed to a HuggingFace `VideoProcessor` 
with `transformers.video_utils.VideoMetadata`.
66
67
68
69
70
71
72

Alternatively, a 3-D tensor or batch of 2-D tensors,
which are treated as video embeddings;
these are directly passed to the model without HF processing.
"""

AudioItem: TypeAlias = Union[HfAudioItem, tuple[np.ndarray, float],
73
                             "torch.Tensor"]
74
75
"""
Represents a single audio
76
item, which can be passed to a HuggingFace `AudioProcessor`.
77
78
79
80
81
82
83
84
85
86
87

Alternatively, a tuple `(audio, sampling_rate)`, where the sampling rate
is different from that expected by the model;
these are resampled to the model's sampling rate before being processed by HF.

Alternatively, a 3-D tensor or batch of 2-D tensors,
which are treated as audio embeddings;
these are directly passed to the model without HF processing.
"""

ModalityData: TypeAlias = Union[_T, list[_T]]
88
89
90
91
"""
Either a single data item, or a list of data items.

The number of data items allowed per modality is restricted by
92
`--limit-mm-per-prompt`.
93
94
95
96
97
98
99
"""


@final
class MultiModalDataBuiltins(TypedDict, total=False):
    """
    Type annotations for the modalities that vLLM defines out of the box.

    Every key is optional (``total=False``). Each value is either a single
    item or a list of items of that modality.
    """

    image: ModalityData[ImageItem]
    """Image input(s) of the prompt."""

    video: ModalityData[VideoItem]
    """Video input(s) of the prompt."""

    audio: ModalityData[AudioItem]
    """Audio input(s) of the prompt."""


110
MultiModalDataDict: TypeAlias = Mapping[str, ModalityData[Any]]
"""
Maps each modality name to the input data (one item or a list of items)
given for that modality.

The modalities supported natively are described by
[`MultiModalDataBuiltins`][vllm.multimodal.inputs.MultiModalDataBuiltins].
"""


119
120
@dataclass(frozen=True)
class PlaceholderRange:
    """
    Location of the placeholder tokens of one multi-modal item in a prompt.

    Example:

    Prompt: `AAAA BBBB What is in these images?`

    Images A and B have the following placeholders:

    ```
    A: PlaceholderRange(offset=0, length=4)
    B: PlaceholderRange(offset=5, length=4)
    ```
    """

    offset: int
    """Index of the first placeholder token in the prompt."""

    length: int
    """Number of tokens spanned by the placeholder."""

    is_embed: Optional["torch.Tensor"] = None
    """
    Optional boolean mask of shape `(length,)` selecting which positions
    between `offset` and `offset + length` are assigned embeddings.
    When `None`, every position in the range is assigned one.
    """

    def get_num_embeds(self) -> int:
        """Return how many positions in the range are assigned embeddings."""
        mask = self.is_embed
        if mask is not None:
            return int(mask.sum().item())
        return self.length

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, self.__class__):
            return False
        if (self.offset, self.length) != (other.offset, other.length):
            return False

        mine, theirs = self.is_embed, other.is_embed
        # The masks match when both are absent, or both present and equal.
        if mine is None or theirs is None:
            return mine is None and theirs is None
        return nested_tensors_equal(mine, theirs)

167

168
169
NestedTensors: TypeAlias = Union[list["NestedTensors"], list["torch.Tensor"],
                                 "torch.Tensor", tuple["torch.Tensor", ...]]
170
171
172
173
"""
Uses a list instead of a tensor if the dimensions of each element do not match.
"""

174
175

def nested_tensors_equal(a: "NestedTensors", b: "NestedTensors") -> bool:
    """
    Equality check between
    [`NestedTensors`][vllm.multimodal.inputs.NestedTensors] objects.

    - Tensors are compared with `torch.equal` (same shape and values). A
      tensor is never equal to a non-tensor.
    - Lists and tuples are compared element-wise and recursively. They are
      equal only when both sides are the same kind of container and have the
      same length.
    - Anything else falls back to `==`.
    """
    if isinstance(a, torch.Tensor) or isinstance(b, torch.Tensor):
        return (isinstance(a, torch.Tensor) and isinstance(b, torch.Tensor)
                and torch.equal(a, b))

    for container in (list, tuple):
        a_is = isinstance(a, container)
        b_is = isinstance(b, container)
        if a_is or b_is:
            # `zip` alone stops at the shorter sequence and would report
            # e.g. `[x]` and `[x, y]` as equal, so compare lengths first.
            # Tuples are handled here too: a plain `==` on tuples of
            # multi-element tensors raises an ambiguous-truth error.
            return (a_is and b_is and len(a) == len(b)
                    and all(nested_tensors_equal(x, y) for x, y in zip(a, b)))

    # Both a and b are scalars
    return a == b


BatchedTensorInputs: TypeAlias = Mapping[str, NestedTensors]
"""
Maps each keyword argument name to its nested tensors after batching with
[`MultiModalKwargs.batch`][vllm.multimodal.inputs.MultiModalKwargs.batch].
"""


201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
@dataclass
class MultiModalFeatureSpec:
    """
    One multi-modal input of a request, together with its processed data
    and the metadata needed to track it.

    The V1 engine uses these to follow multi-modal data through processing
    and caching. A request with several multi-modal items carries one
    `MultiModalFeatureSpec` per item.
    """

    data: Optional["MultiModalKwargsItem"]
    """The processed multi-modal data of this item, if present."""

    modality: str
    """The item's modality, e.g. "image", "audio" or "video"."""

    identifier: str
    """The mm_hash or UUID used as the key for caching encoder outputs."""

    mm_position: PlaceholderRange
    """Where the item's placeholder sits in the prompt,
    e.g. `PlaceholderRange(offset=2, length=336)`."""


224
@dataclass
class MultiModalFieldElem:
    """
    One keyword argument that belongs to a single multi-modal item in
    [`MultiModalKwargs`][vllm.multimodal.inputs.MultiModalKwargs].
    """

    modality: str
    """
    Modality of the multi-modal item this element belongs to.
    A single item may be made up of several keyword arguments.
    """

    key: str
    """
    Name of this field in
    [`MultiModalKwargs`][vllm.multimodal.inputs.MultiModalKwargs],
    i.e. the keyword argument name passed to the model.
    """

    data: NestedTensors
    """
    Tensor data of this field in
    [`MultiModalKwargs`][vllm.multimodal.inputs.MultiModalKwargs],
    i.e. the keyword argument value passed to the model.

    This may be `None` when the item has been found to be cached
    in `EngineCore`.
    """

    field: "BaseMultiModalField"
    """
    Describes how this element's data is combined with that of other
    elements when multi-modal items are batched for model inference.
    """

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, self.__class__):
            return False

        mine, theirs = self.data, other.data
        # `None` data only matches `None` data; otherwise compare tensors.
        if mine is None or theirs is None:
            same_data = mine is None and theirs is None
        else:
            same_data = nested_tensors_equal(mine, theirs)

        return ((self.modality, self.key) == (other.modality, other.key)
                and same_data
                and type(self.field) is type(other.field))
274
275
276
277


@dataclass(frozen=True)
class BaseMultiModalField(ABC):
    """
    Describes how the tensor data of a keyword argument in
    [`MultiModalKwargs`][vllm.multimodal.inputs.MultiModalKwargs] maps onto
    multiple multi-modal items, and how per-item data is merged back.
    """

    def _field_factory(self, *, modality: str, key: str):
        """Return a callable that builds elements of this field from data."""
        make_elem = partial(
            MultiModalFieldElem,
            field=self,
            key=key,
            modality=modality,
        )

        # Wrap the partial so that `data` can be supplied positionally.
        def factory(data: NestedTensors) -> MultiModalFieldElem:
            return make_elem(data=data)

        return factory

    @abstractmethod
    def build_elems(
        self,
        modality: str,
        key: str,
        data: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        """
        Build the
        [`MultiModalFieldElem`][vllm.multimodal.inputs.MultiModalFieldElem]
        instances that represent `data`.

        This is the inverse of
        [`reduce_data`][vllm.multimodal.inputs.BaseMultiModalField.reduce_data].
        """
        raise NotImplementedError

    @abstractmethod
    def _reduce_data(
        self,
        batch: list[NestedTensors],
        *,
        pin_memory: bool,
    ) -> NestedTensors:
        """Subclass hook: merge the raw data of several elements."""
        raise NotImplementedError

    def reduce_data(
        self,
        elems: list[MultiModalFieldElem],
        *,
        pin_memory: bool = False,
    ) -> NestedTensors:
        """
        Combine the data held by several
        [`MultiModalFieldElem`][vllm.multimodal.inputs.MultiModalFieldElem]
        instances.

        This is the inverse of
        [`build_elems`][vllm.multimodal.inputs.BaseMultiModalField.build_elems].

        Raises:
            ValueError: If the elements come from different field types.
        """
        field_types = [type(elem.field) for elem in elems]
        if len(set(field_types)) <= 1:
            return self._reduce_data([elem.data for elem in elems],
                                     pin_memory=pin_memory)

        raise ValueError(f"Cannot merge different {field_types=}")
343
344
345
346
347


@dataclass(frozen=True)
class MultiModalBatchedField(BaseMultiModalField):
    """
    Info:
        [`MultiModalFieldConfig.batched`][vllm.multimodal.inputs.MultiModalFieldConfig.batched]
    """

    def build_elems(
        self,
        modality: str,
        key: str,
        data: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        make_elem = self._field_factory(modality=modality, key=key)
        # Every entry along the leading axis of `data` becomes one element.
        return list(map(make_elem, data))

    def _reduce_data(
        self,
        batch: list[NestedTensors],
        *,
        pin_memory: bool,
    ) -> NestedTensors:
        all_tensors = len(batch) > 0 and is_list_of(
            batch, torch.Tensor, check="all")
        if not all_tensors:
            return batch

        head = batch[0]
        if len(batch) == 1:
            # Same result as `torch.stack(batch)`, but zero-copy when `head`
            # is already contiguous.
            return head.unsqueeze(0).contiguous()

        if any(elem.shape != head.shape for elem in batch):
            # Tensors of differing shapes cannot be stacked; keep the list.
            return batch

        out = torch.empty((len(batch), *head.shape),
                          dtype=head.dtype,
                          device=head.device,
                          pin_memory=pin_memory)
        return torch.stack(batch, out=out)


@dataclass(frozen=True)
class MultiModalFlatField(BaseMultiModalField):
    """
    Info:
        [`MultiModalFieldConfig.flat`][vllm.multimodal.inputs.MultiModalFieldConfig.flat]
        [`MultiModalFieldConfig.flat_from_sizes`][vllm.multimodal.inputs.MultiModalFieldConfig.flat_from_sizes]
    """

    slices: Union[Sequence[slice], Sequence[Sequence[slice]]]
    dim: int = 0

    def build_elems(
        self,
        modality: str,
        key: str,
        data: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        make_elem = self._field_factory(modality=modality, key=key)

        plain_slices_only = is_list_of(self.slices, slice, check="all")
        if not plain_slices_only:
            # Tuples of slices index several dimensions at once, which only
            # tensors support.
            assert isinstance(data, torch.Tensor), \
                "torch.Tensor is required for multiple slices"

        return [make_elem(data[cast(slice, s)]) for s in self.slices]

    def _reduce_data(
        self,
        batch: list[NestedTensors],
        *,
        pin_memory: bool,
    ) -> NestedTensors:
        if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"):
            head = batch[0]
            if len(batch) == 1:
                # Same result as `torch.concat(batch)`, but zero-copy when
                # `head` is already contiguous.
                return head.contiguous()

            # Turn a negative `dim` into its positive index for shape slicing.
            axis = self.dim + len(head.shape) if self.dim < 0 else self.dim

            def _outer_dims(tensor):
                return tensor.shape[:axis], tensor.shape[axis + 1:]

            reference = _outer_dims(head)
            if all(_outer_dims(t) == reference for t in batch):
                before, after = reference
                concat_len = sum(t.shape[axis] for t in batch)
                out = torch.empty((*before, concat_len, *after),
                                  dtype=head.dtype,
                                  device=head.device,
                                  pin_memory=pin_memory)
                return torch.concat(batch, dim=self.dim, out=out)

        # Fall back to flattening one level of nesting.
        assert self.dim == 0, "dim == 0 is required for nested list"
        return [piece for elem in batch for piece in elem]
437
438


439
440
441
@dataclass(frozen=True)
class MultiModalSharedField(BaseMultiModalField):
    """
    Info:
        [`MultiModalFieldConfig.shared`][vllm.multimodal.inputs.MultiModalFieldConfig.shared]
    """
    batch_size: int

    def build_elems(
        self,
        modality: str,
        key: str,
        data: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        # A single element object is reused for every item in the batch.
        shared_elem = self._field_factory(modality=modality, key=key)(data)
        return [shared_elem] * self.batch_size

    def _reduce_data(
        self,
        batch: list[NestedTensors],
        *,
        pin_memory: bool,
    ) -> NestedTensors:
        # `build_elems` gives every item the same data, so the first entry
        # stands for all of them.
        return batch[0]


465
466
467
468
class MultiModalFieldConfig:
    """
    Describes how the tensor data of one keyword argument is split into
    per-item
    [`MultiModalFieldElem`][vllm.multimodal.inputs.MultiModalFieldElem]
    instances for a given modality.

    Create instances through the static constructors `batched`, `flat`,
    `flat_from_sizes` and `shared`.
    """

    @staticmethod
    def batched(modality: str):
        """
        Defines a field where an element in the batch is obtained by
        indexing into the first dimension of the underlying data.

        Args:
            modality: The modality of the multi-modal item that uses this
                keyword argument.

        Example:

        ```
        Input:
            Data: [[AAAA]
                [BBBB]
                [CCCC]]

        Output:
            Element 1: [AAAA]
            Element 2: [BBBB]
            Element 3: [CCCC]
        ```
        """
        return MultiModalFieldConfig(
            field=MultiModalBatchedField(),
            modality=modality,
        )

    @staticmethod
    def flat(modality: str,
             slices: Union[Sequence[slice], Sequence[Sequence[slice]]],
             dim: int = 0):
        """
        Defines a field where an element in the batch is obtained by
        slicing along dimension `dim` (the first dimension by default) of
        the underlying data.

        Args:
            modality: The modality of the multi-modal item that uses this
                keyword argument.
            slices: For each multi-modal item, a slice (dim=0) or a tuple of
                slices (dim>0) that is used to extract the data corresponding
                to it.
            dim: The dimension to extract data, default to 0.

        Example:

        ```
        Given:
            slices: [slice(0, 3), slice(3, 7), slice(7, 9)]

        Input:
            Data: [AAABBBBCC]

        Output:
            Element 1: [AAA]
            Element 2: [BBBB]
            Element 3: [CC]
        ```

        ```
        Given:
            slices: [
                (slice(None), slice(0, 3)),
                (slice(None), slice(3, 7)),
                (slice(None), slice(7, 9))]
            dim: 1

        Input:
            Data: [[AAABBBBCC]]

        Output:
            Element 1: [[AAA]]
            Element 2: [[BBBB]]
            Element 3: [[CC]]
        ```
        """
        return MultiModalFieldConfig(
            field=MultiModalFlatField(slices=slices, dim=dim),
            modality=modality,
        )

    @staticmethod
    def flat_from_sizes(modality: str,
                        size_per_item: "torch.Tensor",
                        dim: int = 0):
        """
        Defines a field where an element in the batch is obtained by
        slicing along dimension `dim` (the first dimension by default) of
        the underlying data, using consecutive slices of the given sizes.

        Args:
            modality: The modality of the multi-modal item that uses this
                keyword argument.
            size_per_item: A 1-D tensor holding, for each multi-modal item,
                the size of the slice that is used to extract the data
                corresponding to it.
            dim: The dimension to slice, default to 0.

        Raises:
            ValueError: If `size_per_item` is not a 1-D tensor.

        Example:

        ```
        Given:
            size_per_item: [3, 4, 2]

        Input:
            Data: [AAABBBBCC]

        Output:
            Element 1: [AAA]
            Element 2: [BBBB]
            Element 3: [CC]
        ```

        ```
        Given:
            size_per_item: [3, 4, 2]
            dim: 1

        Input:
            Data: [[AAABBBBCC]]

        Output:
            Element 1: [[AAA]]
            Element 2: [[BBBB]]
            Element 3: [[CC]]
        ```

        Info:
            [`MultiModalFieldConfig.flat`][vllm.multimodal.inputs.MultiModalFieldConfig.flat]
        """

        if size_per_item.ndim != 1:
            raise ValueError("size_per_item should be a 1-D tensor, "
                             f"but found shape: {size_per_item.shape}")

        # Boundaries between consecutive items: [0, s0, s0+s1, ...].
        slice_idxs = [0, *accumulate(size_per_item)]
        # Leading `slice(None)`s keep every dimension before `dim` intact.
        slices = [(slice(None, None, None), ) * dim +
                  (slice(slice_idxs[i], slice_idxs[i + 1]), )
                  for i in range(len(size_per_item))]

        return MultiModalFieldConfig.flat(modality, slices, dim=dim)

    @staticmethod
    def shared(modality: str, batch_size: int):
        """
        Defines a field where an element in the batch is obtained by
        taking the entirety of the underlying data.

        This means that the data is the same for each element in the batch.

        Args:
            modality: The modality of the multi-modal item that uses this
                keyword argument.
            batch_size: The number of multi-modal items which share this data.

        Example:

        ```
        Given:
            batch_size: 4

        Input:
            Data: [XYZ]

        Output:
            Element 1: [XYZ]
            Element 2: [XYZ]
            Element 3: [XYZ]
            Element 4: [XYZ]
        ```
        """
        return MultiModalFieldConfig(
            field=MultiModalSharedField(batch_size),
            modality=modality,
        )

    def __init__(self, field: BaseMultiModalField, modality: str) -> None:
        super().__init__()

        self.field = field
        self.modality = modality

    def build_elems(
        self,
        key: str,
        batch: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        """
        Split `batch`, the data of keyword argument `key`, into elements
        according to `self.field`, tagged with `self.modality`.
        """
        return self.field.build_elems(self.modality, key, batch)
654
655


656
657
class MultiModalKwargsItem(UserDict[str, MultiModalFieldElem]):
    """
    A collection of
    [`MultiModalFieldElem`][vllm.multimodal.inputs.MultiModalFieldElem]
    corresponding to a data item in
    [`MultiModalDataItems`][vllm.multimodal.parse.MultiModalDataItems].

    All elements must share a single modality, exposed as `modality`.
    """

    @staticmethod
    def dummy(modality: str):
        """Build a minimal single-element item of `modality` for testing."""
        mm_elem = MultiModalFieldElem(
            modality=modality,
            key="dummy",
            data=torch.empty(1),
            field=MultiModalSharedField(1),
        )
        return MultiModalKwargsItem.from_elems([mm_elem])

    @staticmethod
    def from_elems(elems: Sequence[MultiModalFieldElem]):
        """Build an item keyed by each element's `key`."""
        return MultiModalKwargsItem({elem.key: elem for elem in elems})

    def __init__(
        self,
        data: Optional[Mapping[str, MultiModalFieldElem]] = None,
    ) -> None:
        # `None` instead of a shared mutable `{}` default; `UserDict` treats
        # `None` as "no initial data".
        super().__init__(data)

        modalities = {elem.modality for elem in self.values()}
        assert len(modalities) == 1, f"Found different modalities={modalities}"
        self._modality = next(iter(modalities))

    @property
    def modality(self) -> str:
        """The modality shared by every element of this item."""
        return self._modality

    def get_data(self) -> dict[str, NestedTensors]:
        """Return the raw data of each element, keyed by element key."""
        return {key: elem.data for key, elem in self.items()}
692
693


694
695
696
697
698
699
700
701
702
# Item type held by `MultiModalKwargsItems`: `MultiModalKwargsItem` (the
# default), or `Optional[MultiModalKwargsItem]` when some entries may be
# `None`.
_I = TypeVar(
    "_I",
    MultiModalKwargsItem,
    Optional[MultiModalKwargsItem],
    default=MultiModalKwargsItem,
)


class MultiModalKwargsItems(UserDict[str, Sequence[_I]]):
    """
    Groups
    [`MultiModalKwargsItem`][vllm.multimodal.inputs.MultiModalKwargsItem]s
    by their modality.
    """

    @staticmethod
    def from_hf_inputs(
        hf_inputs: "BatchFeature",
        config_by_key: Mapping[str, MultiModalFieldConfig],
    ):
        """
        Split HuggingFace processor outputs into per-item
        `MultiModalKwargsItem`s, grouped by modality.

        Entries of `hf_inputs` without a config in `config_by_key` are
        skipped, on the assumption that vLLM does not use them.

        Raises:
            ValueError: If the keys of one modality yield different numbers
                of items.
        """
        elems_by_key = dict[str, Sequence[MultiModalFieldElem]]()
        keys_by_modality = defaultdict[str, set[str]](set)
        for key, config in config_by_key.items():
            batch = hf_inputs.get(key)
            if batch is None:
                continue

            elems = config.build_elems(key, batch)
            if len(elems) == 0:
                continue

            elems_by_key[key] = elems
            keys_by_modality[config.modality].add(key)

        items = list[MultiModalKwargsItem]()
        for modality, keys in keys_by_modality.items():
            batch_sizes = {k: len(elems_by_key[k]) for k in keys}
            if len(set(batch_sizes.values())) > 1:
                raise ValueError(
                    f"Cannot merge different batch sizes for {modality=}! "
                    f"Found: {batch_sizes=}")

            # Every key has the same number of elements, so zipping them
            # yields exactly one group of elements per item.
            per_key_elems = [elems_by_key[k] for k in keys]
            items.extend(
                MultiModalKwargsItem.from_elems(item_elems)
                for item_elems in zip(*per_key_elems))

        return MultiModalKwargsItems.from_seq(items)

    @staticmethod
    def from_seq(items: Sequence[MultiModalKwargsItem]):
        """Group `items` by their modality."""
        grouped = full_groupby(items, key=lambda item: item.modality)
        return MultiModalKwargsItems(grouped)

    def __getitem__(self, modality: str) -> Sequence[_I]:
        if modality in self:
            return super().__getitem__(modality)  # type: ignore[return-value]

        raise KeyError(f"Modality {modality!r} not found. "
                       f"Available modalities: {set(self.keys())}")

    def get_data(self, *, pin_memory: bool = False) -> "MultiModalKwargs":
        """
        Merge the elements of every item into batched keyword arguments.

        Raises:
            RuntimeError: If any item is `None`.
        """
        elems_by_key = defaultdict[str, list[MultiModalFieldElem]](list)
        for modality, items in self.items():
            for i, item in enumerate(items):
                if item is None:
                    raise RuntimeError("Cannot build data from empty "
                                       f"mm_items[{modality}][{i}]")

                for key, elem in item.items():
                    elems_by_key[key].append(elem)

        merged = {}
        for key, elems in elems_by_key.items():
            merged[key] = elems[0].field.reduce_data(elems,
                                                     pin_memory=pin_memory)
        return MultiModalKwargs(merged)
771

772

773
774
775
776
777
778
# `MultiModalKwargsItems` whose entries are either all present, or may be
# `None`; `MultiModalKwargsItems.get_data` raises `RuntimeError` on a `None`
# entry.
MultiModalKwargsOptionalItems: TypeAlias = Union[
    MultiModalKwargsItems[MultiModalKwargsItem],
    MultiModalKwargsItems[Optional[MultiModalKwargsItem]],
]


779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
class MultiModalKwargs(UserDict[str, NestedTensors]):
    """
    A dictionary that represents the keyword arguments to
    [`torch.nn.Module.forward`][].
    """

    @staticmethod
    @deprecated("`MultiModalKwargs.from_hf_inputs` is deprecated and "
                "will be removed in v0.13. "
                "Please use `MultiModalKwargsItems.from_hf_inputs` and "
                "access the tensor data using `.get_data()`.")
    def from_hf_inputs(
        hf_inputs: "BatchFeature",
        config_by_key: Mapping[str, MultiModalFieldConfig],
    ):
        return MultiModalKwargsItems.from_hf_inputs(hf_inputs, config_by_key) \
            .get_data()

    @staticmethod
    @deprecated("`MultiModalKwargs.from_items` is deprecated and "
                "will be removed in v0.13. "
                "Please use `MultiModalKwargsItems.from_seq` and "
                "access the tensor data using `.get_data()`.")
    def from_items(
        items: Sequence[MultiModalKwargsItem],
        *,
        pin_memory: bool = False,
    ):
        return MultiModalKwargsItems.from_seq(items) \
            .get_data(pin_memory=pin_memory)

    @staticmethod
    def _try_stack(nested_tensors: NestedTensors,
                   pin_memory: bool = False) -> NestedTensors:
        """
        Stack the inner dimensions that have the same shape in
        a nested list of tensors.

        Thus, a dimension represented by a list means that the inner
        dimensions are different for each element along that dimension.
        An empty list is returned unchanged.
        """
        if isinstance(nested_tensors, torch.Tensor):
            return nested_tensors

        # TODO: Remove these once all models have been migrated
        if isinstance(nested_tensors, np.ndarray):
            return torch.from_numpy(nested_tensors)
        if isinstance(nested_tensors, (int, float)):
            return torch.tensor(nested_tensors)

        stacked = [
            MultiModalKwargs._try_stack(t, pin_memory) for t in nested_tensors
        ]
        if not is_list_of(stacked, torch.Tensor, check="all"):
            # Only tensors (not lists) can be stacked.
            return stacked

        tensors_ = cast(list[torch.Tensor], stacked)
        if len(tensors_) == 0:
            # Nothing to stack; `tensors_[0]` below would raise IndexError.
            return tensors_

        if len(tensors_) == 1:
            # An optimization when `tensors_` contains only one tensor:
            # - produce exactly same result as `torch.stack(tensors_)`
            # - will achieve zero-copy if the tensor is contiguous
            return tensors_[0].unsqueeze(0).contiguous()

        if any(t.shape != tensors_[0].shape for t in tensors_):
            # The tensors have incompatible shapes and can't be stacked.
            return tensors_

        outputs = torch.empty(len(tensors_),
                              *tensors_[0].shape,
                              dtype=tensors_[0].dtype,
                              device=tensors_[0].device,
                              pin_memory=pin_memory)
        return torch.stack(tensors_, out=outputs)

    @staticmethod
    def batch(inputs_list: list["MultiModalKwargs"],
              pin_memory: bool = False) -> BatchedTensorInputs:
        """
        Batch multiple inputs together into a dictionary.

        The resulting dictionary has the same keys as the inputs.
        If the corresponding value from each input is a tensor and they all
        share the same shape, the output value is a single batched tensor;
        otherwise, the output value is a list containing the original value
        from each input.
        """
        if len(inputs_list) == 0:
            return {}

        # We need to consider the case where each item in the batch
        # contains different modalities (i.e. different keys).
        item_lists = defaultdict[str, list[NestedTensors]](list)

        for inputs in inputs_list:
            for k, v in inputs.items():
                item_lists[k].append(v)

        return {
            k: MultiModalKwargs._try_stack(item_list, pin_memory)
            for k, item_list in item_lists.items()
        }

    @staticmethod
    def as_kwargs(
        batched_inputs: BatchedTensorInputs,
        *,
        device: "torch.types.Device",
    ) -> BatchedTensorInputs:
        """
        Move every tensor leaf of `batched_inputs` to `device`
        (non-blocking copies).
        """
        # NOTE: `device` is annotated as a string so that defining this class
        # does not force the lazily-loaded `torch` module to be imported.
        json_inputs = cast(JSONTree[torch.Tensor], batched_inputs)

        json_mapped = json_map_leaves(
            lambda x: x.to(device=device, non_blocking=True),
            json_inputs,
        )

        return cast(BatchedTensorInputs, json_mapped)

    def __getitem__(self, key: str):
        if key not in self:
            raise KeyError(f"Keyword argument {key!r} not found. "
                           f"Available keys: {set(self.keys())}")

        return super().__getitem__(key)

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, self.__class__):
            return False

        # Compare the key sets first so that equality is symmetric: keys
        # present only in `other` must also make the two unequal.
        if self.keys() != other.keys():
            return False

        return all(nested_tensors_equal(self[k], other[k]) for k in self)
915

916

917
MultiModalPlaceholderDict: TypeAlias = Mapping[str, Sequence[PlaceholderRange]]
"""
For each modality, the placeholder ranges of its items within the prompt.
"""


923
class MultiModalInputs(TypedDict):
    """
    The outputs of
    [`BaseMultiModalProcessor`][vllm.multimodal.processing.BaseMultiModalProcessor]
    in a form that can be handed to vLLM internals.
    """

    type: Literal["multimodal"]
    """Tag identifying the kind of inputs."""

    prompt: str
    """The prompt text after processing."""

    prompt_token_ids: list[int]
    """The token IDs after processing, placeholder tokens included."""

    mm_kwargs: MultiModalKwargsOptionalItems
    """Keyword arguments passed directly to the model once batched."""

    mm_hashes: "MultiModalHashDict"
    """Hashes of the multi-modal data."""

    mm_placeholders: "MultiModalPlaceholderDict"
    """
    Per modality, where the placeholder tokens lie within
    `prompt_token_ids`.
    """

    cache_salt: NotRequired[str]
    """
    Optional salt used for prefix caching.
    """

956
957
958

class MultiModalEncDecInputs(MultiModalInputs):
    """
    The outputs of
    [`EncDecMultiModalProcessor`][vllm.multimodal.processing.EncDecMultiModalProcessor]
    in a form that can be handed to vLLM internals.

    Adds the processed encoder prompt to `MultiModalInputs`.
    """

    encoder_prompt: str
    """The encoder prompt text after processing."""

    encoder_prompt_token_ids: list[int]
    """Token IDs of the encoder prompt after processing."""