inputs.py 27.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from abc import ABC, abstractmethod
5
from collections import UserDict, defaultdict
6
7
from collections.abc import Mapping, Sequence
from dataclasses import dataclass
8
from functools import partial
9
from itertools import accumulate
10
11
from typing import (TYPE_CHECKING, Any, Literal, Optional, TypedDict, TypeVar,
                    Union, cast, final)
12
13

import numpy as np
14
from typing_extensions import NotRequired, TypeAlias
15

16
from vllm.jsontree import JSONTree, json_map_leaves
17
from vllm.utils import LazyLoader, full_groupby, is_list_of
18

19
if TYPE_CHECKING:
20
21
22
23
24
    import torch
    import torch.types
    from PIL.Image import Image
    from transformers.feature_extraction_utils import BatchFeature

25
    from .hasher import MultiModalHashDict
26
27
else:
    torch = LazyLoader("torch", globals(), "torch")
28

29
30
_T = TypeVar("_T")

31
HfImageItem: TypeAlias = Union["Image", np.ndarray, "torch.Tensor"]
32
"""
33
A `transformers.image_utils.ImageInput` representing a single image
34
item, which can be passed to a HuggingFace `ImageProcessor`.
35
36
"""

37
38
HfVideoItem: TypeAlias = Union[list["Image"], np.ndarray, "torch.Tensor",
                               list[np.ndarray], list["torch.Tensor"]]
39
"""
40
A `transformers.image_utils.VideoInput` representing a single video
41
item, which can be passed to a HuggingFace `VideoProcessor`.
42
43
"""

44
HfAudioItem: TypeAlias = Union[list[float], np.ndarray, "torch.Tensor"]
45
"""
46
Represents a single audio
47
item, which can be passed to a HuggingFace `AudioProcessor`.
48
49
"""

50
ImageItem: TypeAlias = Union[HfImageItem, "torch.Tensor"]
51
"""
52
A `transformers.image_utils.ImageInput` representing a single image
53
item, which can be passed to a HuggingFace `ImageProcessor`.
54
55
56
57
58
59

Alternatively, a 3-D tensor or batch of 2-D tensors,
which are treated as image embeddings;
these are directly passed to the model without HF processing.
"""

60
VideoItem: TypeAlias = Union[HfVideoItem, "torch.Tensor"]
61
"""
62
A `transformers.image_utils.VideoInput` representing a single video
63
item, which can be passed to a HuggingFace `VideoProcessor`.
64
65
66
67
68
69
70

Alternatively, a 3-D tensor or batch of 2-D tensors,
which are treated as video embeddings;
these are directly passed to the model without HF processing.
"""

AudioItem: TypeAlias = Union[HfAudioItem, tuple[np.ndarray, float],
71
                             "torch.Tensor"]
72
73
"""
Represents a single audio
74
item, which can be passed to a HuggingFace `AudioProcessor`.
75
76
77
78
79
80
81
82
83
84
85

Alternatively, a tuple `(audio, sampling_rate)`, where the sampling rate
is different from that expected by the model;
these are resampled to the model's sampling rate before being processed by HF.

Alternatively, a 3-D tensor or batch of 2-D tensors,
which are treated as audio embeddings;
these are directly passed to the model without HF processing.
"""

ModalityData: TypeAlias = Union[_T, list[_T]]
86
87
88
89
"""
Either a single data item, or a list of data items.

The number of data items allowed per modality is restricted by
90
`--limit-mm-per-prompt`.
91
92
93
94
95
96
97
"""


@final
class MultiModalDataBuiltins(TypedDict, total=False):
    """Type annotations for modality types predefined by vLLM.

    Declared with `total=False`, so every key is optional.
    """

    image: ModalityData[ImageItem]
    """The input image(s)."""

    video: ModalityData[VideoItem]
    """The input video(s)."""

    audio: ModalityData[AudioItem]
    """The input audio(s)."""


108
MultiModalDataDict: TypeAlias = Mapping[str, ModalityData[Any]]
"""
A dictionary containing an entry for each modality type to input.

The built-in modalities are defined by
[`MultiModalDataBuiltins`][vllm.multimodal.inputs.MultiModalDataBuiltins].
"""


117
118
@dataclass(frozen=True)
class PlaceholderRange:
    """
    Placeholder location information for multi-modal data.

    Example:

    Prompt: `AAAA BBBB What is in these images?`

    Images A and B will have:

    ```
    A: PlaceholderRange(offset=0, length=4)
    B: PlaceholderRange(offset=5, length=4)
    ```
    """

    offset: int
    """The start index of the placeholder in the prompt."""

    length: int
    """The length of the placeholder."""

    is_embed: Optional["torch.Tensor"] = None
    """
    A boolean mask of shape `(length,)` indicating which positions
    between `offset` and `offset + length` to assign embeddings to.
    """

    def get_num_embeds(self) -> int:
        """
        Return the number of positions that receive an embedding.

        Without a mask this is the full `length`; otherwise it is the
        sum of the `is_embed` mask.
        """
        if self.is_embed is None:
            return self.length

        return int(self.is_embed.sum().item())

    def __eq__(self, other: object) -> bool:
        """
        Compare two ranges by `offset`, `length` and mask contents.

        Objects that are not instances of this class compare unequal.
        """
        if not isinstance(other, self.__class__):
            return False
        if (self.offset, self.length) != (other.offset, other.length):
            return False

        # A missing mask only matches another missing mask.
        if self.is_embed is None or other.is_embed is None:
            return self.is_embed is None and other.is_embed is None

        return nested_tensors_equal(self.is_embed, other.is_embed)

165

166
167
NestedTensors: TypeAlias = Union[list["NestedTensors"], list["torch.Tensor"],
                                 "torch.Tensor", tuple["torch.Tensor", ...]]
168
169
170
171
"""
Uses a list instead of a tensor if the dimensions of each element do not match.
"""

172
173

def nested_tensors_equal(a: NestedTensors, b: NestedTensors) -> bool:
    """Equality check between
    [`NestedTensors`][vllm.multimodal.inputs.NestedTensors] objects.

    Tensors are compared with `torch.equal`. Lists and tuples are compared
    element-wise (recursively) and must have the same container kind and
    the same length. Anything else falls back to `==`.

    Args:
        a: The first nested structure.
        b: The second nested structure.

    Returns:
        `True` if both structures have the same layout and equal contents.
    """
    if isinstance(a, torch.Tensor):
        return isinstance(b, torch.Tensor) and torch.equal(a, b)
    if isinstance(b, torch.Tensor):
        # `a` is not a tensor at this point, so the kinds differ.
        return False

    a_is_seq = isinstance(a, (list, tuple))
    b_is_seq = isinstance(b, (list, tuple))
    if a_is_seq or b_is_seq:
        return (
            a_is_seq and b_is_seq
            # A list never equals a tuple.
            and isinstance(a, list) == isinstance(b, list)
            # `zip` silently truncates, so the lengths must be checked
            # explicitly; otherwise a prefix would compare equal.
            and len(a) == len(b)
            and all(nested_tensors_equal(a_, b_) for a_, b_ in zip(a, b)))

    # Both a and b are scalars
    return a == b


BatchedTensorInputs: TypeAlias = Mapping[str, NestedTensors]
"""
A dictionary containing nested tensors which have been batched via
[`MultiModalKwargs.batch`][vllm.multimodal.inputs.MultiModalKwargs.batch].
"""


199
@dataclass(frozen=True)
class MultiModalFieldElem:
    """
    Represents a keyword argument corresponding to a multi-modal item
    in [`MultiModalKwargs`][vllm.multimodal.inputs.MultiModalKwargs].
    """

    modality: str
    """
    The modality of the corresponding multi-modal item.
    Each multi-modal item can consist of multiple keyword arguments.
    """

    key: str
    """
    The key of this field in
    [`MultiModalKwargs`][vllm.multimodal.inputs.MultiModalKwargs],
    i.e. the name of the keyword argument to be passed to the model.
    """

    data: NestedTensors
    """
    The tensor data of this field in
    [`MultiModalKwargs`][vllm.multimodal.inputs.MultiModalKwargs],
    i.e. the value of the keyword argument to be passed to the model.
    """

    field: "BaseMultiModalField"
    """
    Defines how to combine the tensor data of this field with others
    in order to batch multi-modal items together for model inference.
    """

    def __eq__(self, other: object) -> bool:
        """
        Compare modality, key, tensor contents and the *type* of the field.

        The field instances themselves are not compared, only their classes.
        """
        if not isinstance(other, self.__class__):
            return False

        return ((self.modality, self.key) == (other.modality, other.key)
                and nested_tensors_equal(self.data, other.data)
                and type(self.field) is type(other.field))
239
240
241
242


@dataclass(frozen=True)
class BaseMultiModalField(ABC):
    """
    Defines how to interpret tensor data belonging to a keyword argument in
    [`MultiModalKwargs`][vllm.multimodal.inputs.MultiModalKwargs] for multiple
    multi-modal items, and vice versa.
    """

    def _field_factory(self, *, modality: str, key: str):
        """
        Return a callable that wraps tensor data into a
        [`MultiModalFieldElem`][vllm.multimodal.inputs.MultiModalFieldElem]
        bound to this field and the given `modality` and `key`.
        """
        f = partial(
            MultiModalFieldElem,
            modality=modality,
            key=key,
            field=self,
        )

        # Allow passing data as positional argument
        def factory(data: NestedTensors) -> MultiModalFieldElem:
            return f(data=data)

        return factory

    @abstractmethod
    def build_elems(
        self,
        modality: str,
        key: str,
        data: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        """
        Construct
        [`MultiModalFieldElem`][vllm.multimodal.inputs.MultiModalFieldElem]
        instances to represent the provided data.

        This is the inverse of
        [`reduce_data`][vllm.multimodal.inputs.BaseMultiModalField.reduce_data].
        """
        raise NotImplementedError

    @abstractmethod
    def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
        """Combine the raw data of several elements; see `reduce_data`."""
        raise NotImplementedError

    def reduce_data(self, elems: list[MultiModalFieldElem]) -> NestedTensors:
        """
        Merge the data from multiple instances of
        [`MultiModalFieldElem`][vllm.multimodal.inputs.MultiModalFieldElem].

        This is the inverse of
        [`build_elems`][vllm.multimodal.inputs.BaseMultiModalField.build_elems].

        Raises:
            ValueError: If the elements do not all share one field type.
        """
        field_types = [type(item.field) for item in elems]
        if len(set(field_types)) > 1:
            raise ValueError(f"Cannot merge different {field_types=}")

        return self._reduce_data([item.data for item in elems])
297
298
299
300
301


@dataclass(frozen=True)
class MultiModalBatchedField(BaseMultiModalField):
    """
    Info:
        [`MultiModalFieldConfig.batched`][vllm.multimodal.inputs.MultiModalFieldConfig.batched]
    """

    def build_elems(
        self,
        modality: str,
        key: str,
        data: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        """Create one element per entry obtained by iterating over `data`."""
        field_factory = self._field_factory(modality=modality, key=key)
        return [field_factory(item) for item in data]

    def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
        """
        Stack the batch into a single tensor when every entry is a tensor
        of the same shape; otherwise return the list unchanged.
        """
        if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"):
            if len(batch) == 1:
                # An optimization when `batch` contains only one tensor:
                # - produce exactly same result as `torch.stack(batch)`
                # - will achieve zero-copy if the tensor is contiguous
                return batch[0].unsqueeze(0).contiguous()

            first_shape = batch[0].shape
            if all(elem.shape == first_shape for elem in batch):
                return torch.stack(batch)

        return batch


@dataclass(frozen=True)
class MultiModalFlatField(BaseMultiModalField):
    """
    Info:
        [`MultiModalFieldConfig.flat`][vllm.multimodal.inputs.MultiModalFieldConfig.flat]
        [`MultiModalFieldConfig.flat_from_sizes`][vllm.multimodal.inputs.MultiModalFieldConfig.flat_from_sizes]
    """

    # One entry per multi-modal item: either a single slice, or a sequence
    # of slices used to index several dimensions at once.
    slices: Union[Sequence[slice], Sequence[Sequence[slice]]]
    # The dimension along which the per-item data is concatenated.
    dim: int = 0

    def build_elems(
        self,
        modality: str,
        key: str,
        data: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        """Create one element per entry of `slices` by indexing `data`."""
        field_factory = self._field_factory(modality=modality, key=key)

        if not is_list_of(self.slices, slice, check="all"):
            # Entries that are not plain slices require tensor indexing.
            assert isinstance(data, torch.Tensor), \
                "torch.Tensor is required for multiple slices"

        return [field_factory(data[cast(slice, s)]) for s in self.slices]

    def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
        """
        Concatenate the batch along `dim` when every entry is a tensor and
        all other dimensions agree; otherwise flatten it into one list.
        """
        if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"):
            if len(batch) == 1:
                # An optimization when `batch` contains only one tensor:
                # - produce exactly same result as `torch.concat(batch)`
                # - will achieve zero-copy if the tensor is contiguous
                return batch[0].contiguous()

            def _expect_same_shape(tensor: torch.Tensor):
                # The shape with the concatenation dimension removed.
                return tensor.shape[:self.dim] + tensor.shape[self.dim + 1:]

            first_shape = _expect_same_shape(batch[0])

            if all(_expect_same_shape(elem) == first_shape for elem in batch):
                return torch.concat(batch, dim=self.dim)

        assert self.dim == 0, "dim == 0 is required for nested list"
        return [e for elem in batch for e in elem]
369
370


371
372
373
@dataclass(frozen=True)
class MultiModalSharedField(BaseMultiModalField):
    """
    Info:
        [`MultiModalFieldConfig.shared`][vllm.multimodal.inputs.MultiModalFieldConfig.shared]
    """

    # The number of elements produced by `build_elems`.
    batch_size: int

    def build_elems(
        self,
        modality: str,
        key: str,
        data: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        """
        Return `batch_size` references to one element wrapping all of `data`.

        The list repeats the same element object; the data is not copied.
        """
        field_factory = self._field_factory(modality=modality, key=key)
        return [field_factory(data)] * self.batch_size

    def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
        """Return the first entry of the batch."""
        return batch[0]


392
393
394
395
class MultiModalFieldConfig:
    """
    Pairs a
    [`BaseMultiModalField`][vllm.multimodal.inputs.BaseMultiModalField]
    with the modality that uses it.

    The static methods construct a config for each kind of field.
    """

    @staticmethod
    def batched(modality: str):
        """
        Defines a field where an element in the batch is obtained by
        indexing into the first dimension of the underlying data.

        Args:
            modality: The modality of the multi-modal item that uses this
                keyword argument.

        Example:

        ```
        Input:
            Data: [[AAAA]
                [BBBB]
                [CCCC]]

        Output:
            Element 1: [AAAA]
            Element 2: [BBBB]
            Element 3: [CCCC]
        ```
        """
        return MultiModalFieldConfig(
            field=MultiModalBatchedField(),
            modality=modality,
        )

    @staticmethod
    def flat(modality: str,
             slices: Union[Sequence[slice], Sequence[Sequence[slice]]],
             dim: int = 0):
        """
        Defines a field where an element in the batch is obtained by
        slicing along a dimension of the underlying data
        (the first dimension by default).

        Args:
            modality: The modality of the multi-modal item that uses this
                keyword argument.
            slices: For each multi-modal item, a slice (dim=0) or a tuple of
                slices (dim>0) that is used to extract the data corresponding
                to it.
            dim: The dimension to extract data, default to 0.

        Example:

        ```
        Given:
            slices: [slice(0, 3), slice(3, 7), slice(7, 9)]

        Input:
            Data: [AAABBBBCC]

        Output:
            Element 1: [AAA]
            Element 2: [BBBB]
            Element 3: [CC]
        ```

        ```
        Given:
            slices: [
                (slice(None), slice(0, 3)),
                (slice(None), slice(3, 7)),
                (slice(None), slice(7, 9))]
            dim: 1

        Input:
            Data: [[A],[A],[A],[B],[B],[B],[B],[C],[C]]

        Output:
            Element 1: [[A],[A],[A]]
            Element 2: [[B],[B],[B],[B]]
            Element 3: [[C],[C]]
        ```
        """
        return MultiModalFieldConfig(
            field=MultiModalFlatField(slices=slices, dim=dim),
            modality=modality,
        )

    @staticmethod
    def flat_from_sizes(modality: str,
                        size_per_item: "torch.Tensor",
                        dim: int = 0):
        """
        Defines a field where an element in the batch is obtained by
        slicing along a dimension of the underlying data
        (the first dimension by default).

        Args:
            modality: The modality of the multi-modal item that uses this
                keyword argument.
            size_per_item: For each multi-modal item, the size of the slice
                that is used to extract the data corresponding to it.
                Must be a 1-D tensor.
            dim: The dimension to slice, default to 0.

        Raises:
            ValueError: If `size_per_item` is not 1-D.

        Example:

        ```
        Given:
            size_per_item: [3, 4, 2]

        Input:
            Data: [AAABBBBCC]

        Output:
            Element 1: [AAA]
            Element 2: [BBBB]
            Element 3: [CC]
        ```

        ```
        Given:
            size_per_item: [3, 4, 2]
            dim: 1

        Input:
            Data: [[A],[A],[A],[B],[B],[B],[B],[C],[C]]

        Output:
            Element 1: [[A],[A],[A]]
            Element 2: [[B],[B],[B],[B]]
            Element 3: [[C],[C]]
        ```

        Info:
            [`MultiModalFieldConfig.flat`][vllm.multimodal.inputs.MultiModalFieldConfig.flat]
        """

        if size_per_item.ndim != 1:
            raise ValueError("size_per_item should be a 1-D tensor, "
                             f"but found shape: {size_per_item.shape}")

        # Running totals give the boundaries between consecutive items.
        slice_idxs = [0, *accumulate(size_per_item)]
        # Take everything along the `dim` leading dimensions, then the
        # item's own range along dimension `dim`.
        slices = [(slice(None, None, None), ) * dim + (slice(start, stop), )
                  for start, stop in zip(slice_idxs, slice_idxs[1:])]

        return MultiModalFieldConfig.flat(modality, slices, dim=dim)

    @staticmethod
    def shared(modality: str, batch_size: int):
        """
        Defines a field where an element in the batch is obtained by
        taking the entirety of the underlying data.

        This means that the data is the same for each element in the batch.

        Args:
            modality: The modality of the multi-modal item that uses this
                keyword argument.
            batch_size: The number of multi-modal items which share this data.

        Example:

        ```
        Given:
            batch_size: 4

        Input:
            Data: [XYZ]

        Output:
            Element 1: [XYZ]
            Element 2: [XYZ]
            Element 3: [XYZ]
            Element 4: [XYZ]
        ```
        """
        return MultiModalFieldConfig(
            field=MultiModalSharedField(batch_size),
            modality=modality,
        )

    def __init__(self, field: BaseMultiModalField, modality: str) -> None:
        """
        Args:
            field: Defines how data is split into and merged from elements.
            modality: The modality of the items that use this field.
        """
        super().__init__()

        self.field = field
        self.modality = modality

    def build_elems(
        self,
        key: str,
        batch: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        """Delegate to the field's `build_elems` using this modality."""
        return self.field.build_elems(self.modality, key, batch)
581
582


583
584
class MultiModalKwargsItem(UserDict[str, MultiModalFieldElem]):
    """
    A collection of
    [`MultiModalFieldElem`][vllm.multimodal.inputs.MultiModalFieldElem]
    corresponding to a data item in
    [`MultiModalDataItems`][vllm.multimodal.parse.MultiModalDataItems].
    """

    @staticmethod
    def from_elems(elems: Sequence[MultiModalFieldElem]):
        """Build an item that maps each element's `key` to the element."""
        return MultiModalKwargsItem({elem.key: elem for elem in elems})

    @property
    def modality(self) -> str:
        """The single modality shared by every element of this item."""
        modalities = {elem.modality for elem in self.data.values()}
        assert len(modalities) == 1, f"Found different modalities={modalities}"
        return next(iter(modalities))
600
601


602
603
604
605
606
# NOTE: UserDict is for V0 compatibility.
# V1 should access individual items via `get_item`.
class MultiModalKwargs(UserDict[str, NestedTensors]):
    """
    A dictionary that represents the keyword arguments to
    [`torch.nn.Module.forward`][].

    The metadata `items` enables us to obtain the keyword arguments
    corresponding to each data item in
    [`MultiModalDataItems`][vllm.multimodal.parse.MultiModalDataItems], via
    [`get_item`][vllm.multimodal.inputs.MultiModalKwargs.get_item] and
    [`get_items`][vllm.multimodal.inputs.MultiModalKwargs.get_items].
    """

    @staticmethod
    def from_hf_inputs(
        hf_inputs: "BatchFeature",
        config_by_key: Mapping[str, MultiModalFieldConfig],
    ):
        """
        Construct a new
        [`MultiModalKwargs`][vllm.multimodal.inputs.MultiModalKwargs]
        from HF processor outputs, splitting each configured field into
        per-item elements.

        Raises:
            ValueError: If the fields of one modality yield different
                numbers of elements.
        """
        # NOTE: This skips fields in `hf_inputs` that are not in `config_by_key`
        # We assume that those fields are not used in vLLM
        elems_by_key = dict[str, Sequence[MultiModalFieldElem]]()
        keys_by_modality = defaultdict[str, set[str]](set)
        for key, config in config_by_key.items():
            batch = hf_inputs.get(key)
            if batch is not None:
                elems = config.build_elems(key, batch)
                if len(elems) > 0:
                    elems_by_key[key] = elems
                    keys_by_modality[config.modality].add(key)

        items = list[MultiModalKwargsItem]()
        for modality, keys in keys_by_modality.items():
            elems_in_modality = {k: elems_by_key[k] for k in keys}
            batch_sizes = {k: len(v) for k, v in elems_in_modality.items()}

            if len(set(batch_sizes.values())) > 1:
                raise ValueError(
                    f"Cannot merge different batch sizes for {modality=}! "
                    f"Found: {batch_sizes=}")

            # The i-th item takes the i-th element of every field.
            batch_size = next(iter(batch_sizes.values()))
            for item_idx in range(batch_size):
                elems = [v[item_idx] for v in elems_in_modality.values()]
                items.append(MultiModalKwargsItem.from_elems(elems))

        return MultiModalKwargs.from_items(items)

    @staticmethod
    def from_items(items: Sequence[MultiModalKwargsItem]):
        """Construct a new
        [`MultiModalKwargs`][vllm.multimodal.inputs.MultiModalKwargs]
        from multiple items."""
        elems_by_key = defaultdict[str, list[MultiModalFieldElem]](list)
        for item in items:
            for key, elem in item.items():
                elems_by_key[key].append(elem)

        # The field of the first element decides how a key is merged.
        data = {
            key: elems[0].field.reduce_data(elems)
            for key, elems in elems_by_key.items() if len(elems) > 0
        }

        return MultiModalKwargs(data, items=items)

    def __init__(
        self,
        data: Mapping[str, NestedTensors],
        *,
        items: Optional[Sequence[MultiModalKwargsItem]] = None,
    ) -> None:
        """
        Args:
            data: The keyword arguments to store in the dictionary.
            items: The per-item metadata, grouped by modality for the
                item accessors. May be omitted.
        """
        super().__init__(data)

        items_by_modality = full_groupby(items or [], key=lambda x: x.modality)
        self._items_by_modality = dict(items_by_modality)

    @property
    def modalities(self):
        """The modalities for which items are stored."""
        return self._items_by_modality.keys()

    @staticmethod
    def _try_stack(nested_tensors: NestedTensors) -> NestedTensors:
        """
        Stack the inner dimensions that have the same shape in
        a nested list of tensors.

        Thus, a dimension represented by a list means that the inner
        dimensions are different for each element along that dimension.
        """
        if isinstance(nested_tensors, torch.Tensor):
            return nested_tensors

        # TODO: Remove these once all models have been migrated
        if isinstance(nested_tensors, np.ndarray):
            return torch.from_numpy(nested_tensors)
        if isinstance(nested_tensors, (int, float)):
            return torch.tensor(nested_tensors)

        stacked = [MultiModalKwargs._try_stack(t) for t in nested_tensors]
        if not is_list_of(stacked, torch.Tensor, check="all"):
            # Only tensors (not lists) can be stacked.
            return stacked

        tensors_ = cast(list[torch.Tensor], stacked)
        if len(tensors_) == 1:
            # An optimization when `tensors_` contains only one tensor:
            # - produce exactly same result as `torch.stack(tensors_)`
            # - will achieve zero-copy if the tensor is contiguous
            return tensors_[0].unsqueeze(0).contiguous()

        if any(t.shape != tensors_[0].shape for t in tensors_):
            # The tensors have incompatible shapes and can't be stacked.
            return tensors_

        return torch.stack(tensors_)

    @staticmethod
    def batch(inputs_list: list["MultiModalKwargs"]) -> BatchedTensorInputs:
        """
        Batch multiple inputs together into a dictionary.

        The resulting dictionary has the same keys as the inputs.
        If the corresponding value from each input is a tensor and they all
        share the same shape, the output value is a single batched tensor;
        otherwise, the output value is a list containing the original value
        from each input.
        """
        if len(inputs_list) == 0:
            return {}

        # We need to consider the case where each item in the batch
        # contains different modalities (i.e. different keys).
        item_lists = defaultdict[str, list[NestedTensors]](list)

        for inputs in inputs_list:
            for k, v in inputs.items():
                item_lists[k].append(v)

        return {
            k: MultiModalKwargs._try_stack(item_list)
            for k, item_list in item_lists.items()
        }

    @staticmethod
    def as_kwargs(
        batched_inputs: BatchedTensorInputs,
        *,
        device: torch.types.Device,
    ) -> BatchedTensorInputs:
        """
        Return the batched inputs with every leaf moved to `device`
        using a non-blocking `.to(...)` call.
        """
        json_inputs = cast(JSONTree[torch.Tensor], batched_inputs)

        json_mapped = json_map_leaves(
            lambda x: x.to(device=device, non_blocking=True),
            json_inputs,
        )

        return cast(BatchedTensorInputs, json_mapped)

    def __delitem__(self, key: str) -> None:
        """Remove `key` from the dictionary and from every stored item."""
        super().__delitem__(key)

        for items in self._items_by_modality.values():
            for item in items:
                item.pop(key, None)

    def __eq__(self, other: object) -> bool:
        """Compare the per-modality items, the keys and the tensor values."""
        if not isinstance(other, self.__class__):
            return False
        if self._items_by_modality != other._items_by_modality:
            return False

        ks = self.keys()
        return (ks == other.keys()
                and all(nested_tensors_equal(self[k], other[k]) for k in ks))

    def _validate_modality(self, method_name: str, modality: str) -> None:
        """
        Ensure that items were provided and that `modality` is among them.

        Raises:
            RuntimeError: If no items are stored.
            KeyError: If `modality` has no stored items.
        """
        if not self._items_by_modality:
            raise RuntimeError(
                f"`{method_name}` is not supported when "
                "MultiModalKwargs is not initialized with `items`")

        if modality not in self._items_by_modality:
            available_modalities = set(self._items_by_modality.keys())
            raise KeyError(f"Modality {modality!r} not found. "
                           f"Available modalities: {available_modalities}")

    def get_item_count(self, modality: str) -> int:
        """Get the number of items belonging to a modality."""
        self._validate_modality("get_item_count", modality)
        return len(self._items_by_modality[modality])

    def get_item(self, modality: str, item_index: int) -> MultiModalKwargsItem:
        """
        Get the keyword arguments corresponding to an item identified by
        its modality and index.
        """
        self._validate_modality("get_item", modality)
        return self._items_by_modality[modality][item_index]

    def get_items(self, modality: str) -> Sequence[MultiModalKwargsItem]:
        """
        Get the keyword arguments corresponding to each item belonging to
        a modality.
        """
        self._validate_modality("get_items", modality)
        return self._items_by_modality[modality]
808

809

810
MultiModalPlaceholderDict: TypeAlias = Mapping[str, Sequence[PlaceholderRange]]
"""
A dictionary containing placeholder ranges for each modality.
"""


816
class MultiModalInputs(TypedDict):
    """
    Represents the outputs of
    [`BaseMultiModalProcessor`][vllm.multimodal.processing.BaseMultiModalProcessor],
    ready to be passed to vLLM internals.
    """

    type: Literal["multimodal"]
    """The type of inputs."""

    prompt: str
    """The processed prompt text."""

    prompt_token_ids: list[int]
    """The processed token IDs which includes placeholder tokens."""

    token_type_ids: NotRequired[list[int]]
    """The token type IDs of the prompt."""

    mm_kwargs: MultiModalKwargs
    """Keyword arguments to be directly passed to the model after batching."""

    mm_hashes: Optional["MultiModalHashDict"]
    """The hashes of the multi-modal data."""

    mm_placeholders: "MultiModalPlaceholderDict"
    """
    For each modality, information about the placeholder tokens in
    `prompt_token_ids`.
    """

    cache_salt: NotRequired[str]
    """
    Optional cache salt to be used for prefix caching.
    """

852
853
854

class MultiModalEncDecInputs(MultiModalInputs):
    """
    Represents the outputs of
    [`EncDecMultiModalProcessor`][vllm.multimodal.processing.EncDecMultiModalProcessor]
    ready to be passed to vLLM internals.
    """

    encoder_prompt: str
    """The processed encoder prompt text."""

    encoder_prompt_token_ids: list[int]
    """The processed token IDs of the encoder prompt."""

    encoder_token_type_ids: NotRequired[list[int]]
    """The token type IDs of the encoder prompt."""