inputs.py 32.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from abc import ABC, abstractmethod
5
from collections import UserDict, defaultdict
6
from collections.abc import Mapping, Sequence
7
from dataclasses import dataclass
8
from functools import cached_property, partial
9
from itertools import accumulate
10
11
12
13
14
15
16
17
18
19
from typing import (
    TYPE_CHECKING,
    Any,
    Literal,
    TypeAlias,
    TypedDict,
    Union,
    cast,
    final,
)
20
21

import numpy as np
Roger Wang's avatar
Roger Wang committed
22
from PIL.Image import Image
23
from typing_extensions import NotRequired, TypeVar
24

25
from vllm.utils.collection_utils import is_list_of
26
from vllm.utils.import_utils import LazyLoader
27
from vllm.utils.jsontree import json_map_leaves
28

29
30
from .media import MediaWithBytes

31
if TYPE_CHECKING:
    import torch
    import torch.types
    from transformers.feature_extraction_utils import BatchFeature
else:
    # At runtime `torch` is bound to a project `LazyLoader` proxy; presumably the
    # real import is deferred until an attribute is first accessed.
    torch = LazyLoader("torch", globals(), "torch")


_T = TypeVar("_T")

40
HfImageItem: TypeAlias = Union["Image", np.ndarray, "torch.Tensor"]
41
"""
42
A `transformers.image_utils.ImageInput` representing a single image
43
item, which can be passed to a HuggingFace `ImageProcessor`.
44
45
"""

46
47
48
HfVideoItem: TypeAlias = Union[
    list["Image"], np.ndarray, "torch.Tensor", list[np.ndarray], list["torch.Tensor"]
]
49
"""
50
A `transformers.image_utils.VideoInput` representing a single video
51
item, which can be passed to a HuggingFace `VideoProcessor`.
52
53
"""

54
HfAudioItem: TypeAlias = Union[list[float], np.ndarray, "torch.Tensor"]
55
"""
56
Represents a single audio
57
item, which can be passed to a HuggingFace `AudioProcessor`.
58
59
"""

60
ImageItem: TypeAlias = Union[HfImageItem, "torch.Tensor", MediaWithBytes[HfImageItem]]
"""
A `transformers.image_utils.ImageInput` representing a single image
item, which can be passed to a HuggingFace `ImageProcessor`.

Alternatively, a 3-D tensor or batch of 2-D tensors,
which are treated as image embeddings;
these are directly passed to the model without HF processing.

Alternatively, a `MediaWithBytes[HfImageItem]` (see `.media.MediaWithBytes`).
"""

VideoItem: TypeAlias = Union[
    HfVideoItem, "torch.Tensor", tuple[HfVideoItem, dict[str, Any]]
]
"""
A `transformers.video_utils.VideoInput` representing a single video item.
This can be passed to a HuggingFace `VideoProcessor`, optionally paired with
metadata as a tuple `(video, metadata_dict)`
(cf. `transformers.video_utils.VideoMetadata`).

Alternatively, a 3-D tensor or batch of 2-D tensors,
which are treated as video embeddings;
these are directly passed to the model without HF processing.
"""

AudioItem: TypeAlias = Union[HfAudioItem, tuple[np.ndarray, float], "torch.Tensor"]
"""
Represents a single audio item, which can be passed to a HuggingFace
`AudioProcessor`.

Alternatively, a tuple `(audio, sampling_rate)`, where the sampling rate
is different from that expected by the model;
these are resampled to the model's sampling rate before being processed by HF.

Alternatively, a 3-D tensor or batch of 2-D tensors,
which are treated as audio embeddings;
these are directly passed to the model without HF processing.
"""

ModalityData: TypeAlias = _T | list[_T | None] | None
"""
Either a single data item, or a list of data items. The entry, or an
individual item in the list, can only be `None` if a UUID is provided for it.

The number of data items allowed per modality is restricted by
`--limit-mm-per-prompt`.
"""


class VisionChunkImage(TypedDict):
    """Represents an image wrapped as a vision chunk."""

    type: Literal["image"]  # tag distinguishing the members of `VisionChunk`
    image: Image
    uuid: str | None


class VisionChunkVideo(TypedDict):
    """Represents a video chunk with metadata."""

    type: Literal["video_chunk"]  # tag distinguishing the members of `VisionChunk`
    video_chunk: list[Image]  # frames of this chunk (presumably in temporal order)
    uuid: str | None
    prompt: str
    video_idx: int


VisionChunk = VisionChunkImage | VisionChunkVideo
"""A vision chunk is either an image or a video chunk."""


@final
class MultiModalDataBuiltins(TypedDict, total=False):
    """Type annotations for modality types predefined by vLLM."""

    image: ModalityData[ImageItem]
    """The input image(s)."""

    video: ModalityData[VideoItem]
    """The input video(s)."""

    audio: ModalityData[AudioItem]
    """The input audio(s)."""

    vision_chunk: ModalityData[VisionChunk]
    """The input visual atom(s) - unified modality for images and video chunks."""


MultiModalDataDict: TypeAlias = Mapping[str, ModalityData[Any]]
"""
A dictionary containing an entry for each modality type to input.

The built-in modalities are defined by
[`MultiModalDataBuiltins`][vllm.multimodal.inputs.MultiModalDataBuiltins].
"""

MultiModalUUIDDict: TypeAlias = Mapping[str, list[str | None] | str]
155
156
157
158
159
160
161
162
163
"""
A dictionary containing user-provided UUIDs for items in each modality.
If a UUID for an item is not provided, its entry will be `None` and
MultiModalHasher will compute a hash for the item.

The UUID will be used to identify the item for all caching purposes
(input processing caching, embedding caching, prefix caching, etc).
"""

164

165
166
@dataclass(frozen=True)
class PlaceholderRange:
167
168
169
    """
    Placeholder location information for multi-modal data.

170
171
    Example:

172
    Prompt: `AAAA BBBB What is in these images?`
173

174
    Images A and B will have:
175

176
177
178
179
    ```
    A: PlaceholderRange(offset=0, length=4)
    B: PlaceholderRange(offset=5, length=4)
    ```
180
181
182
183
184
185
186
187
    """

    offset: int
    """The start index of the placeholder in the prompt."""

    length: int
    """The length of the placeholder."""

188
    is_embed: "torch.Tensor | None" = None
189
190
191
192
193
    """
    A boolean mask of shape `(length,)` indicating which positions
    between `offset` and `offset + length` to assign embeddings to.
    """

194
195
    @cached_property
    def embeds_cumsum(self) -> torch.Tensor | None:
196
        return None if self.is_embed is None else self.is_embed.cumsum(dim=0)
197
198
199
200

    @cached_property
    def get_num_embeds(self) -> int:
        if self.embeds_cumsum is None:
201
202
            return self.length

203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
        return int(self.embeds_cumsum[-1])

    def get_embeds_indices_in_range(
        self, start_idx: int, end_idx: int
    ) -> tuple[int, int]:
        """
        Returns the starting and ending indices of the embeddings of encoder outputs
        in the range of [start_idx, end_idx) in the placeholders.

        For example, given:
        PlaceholderRange(offset=2, length=5, is_embed=[False, True, False, True, True])

        If start_idx=3 and end_idx=5, the output is (1, 3) because we want to get
        the second and the third embeddings from the encoder output.
        """
        if self.embeds_cumsum is None:
            return start_idx, end_idx

        embeds_start_idx = (
            int(self.embeds_cumsum[start_idx - 1]) if start_idx > 0 else 0
        )
        embeds_end_idx = int(self.embeds_cumsum[end_idx - 1])

        return embeds_start_idx, embeds_end_idx
227

228
229
230
231
232
233
234
235
236
237
238
239
240
    def extract_embeds_range(self) -> list[tuple[int, int]]:
        """Extract the start and end indices of the embedded region in prompt.

        For example, given `PlaceholderRange(offset=2, length=5)` and
        `is_embed = [False, True, False, True, True]`, the output is
        `[(1 + offset, 1 + offset), (3 + offset, 4 + offset)]`.

        Returns:
            A tuple `(start, end)` representing the start and end
            indices (inclusive) of the embedded region.
            Returns full placeholder range if `is_embed` is `None`.
        """
        if self.is_embed is None:
241
            return [(self.offset, self.offset + self.length - 1)]
242
243
244
245
246
247
248
249
250
251
252

        mask_i = self.is_embed.int()
        starts = torch.nonzero(
            torch.diff(mask_i, prepend=mask_i.new_zeros(1)) == 1
        ).flatten()
        ends = torch.nonzero(
            torch.diff(mask_i, append=mask_i.new_zeros(1)) == -1
        ).flatten()
        ranges = torch.stack((starts, ends), dim=1) + self.offset
        return [tuple(x) for x in ranges.tolist()]

253
254
255
256
257
258
259
260
261
262
263
264
265
    def __eq__(self, other: object) -> bool:
        if not isinstance(other, self.__class__):
            return False
        if not (self.offset, self.length) == (other.offset, other.length):
            return False

        if self.is_embed is None:
            return other.is_embed is None
        if other.is_embed is None:
            return self.is_embed is None

        return nested_tensors_equal(self.is_embed, other.is_embed)

266

267
268
269
270
271
272
NestedTensors: TypeAlias = Union[
    list["NestedTensors"],
    list["torch.Tensor"],
    "torch.Tensor",
    tuple["torch.Tensor", ...],
]
273
274
275
276
"""
Uses a list instead of a tensor if the dimensions of each element do not match.
"""

277
278

def nested_tensors_equal(a: "NestedTensors", b: "NestedTensors") -> bool:
    """
    Equality check between
    [`NestedTensors`][vllm.multimodal.inputs.NestedTensors] objects.

    Tensors are compared with `torch.equal`. Lists and tuples are compared
    element-wise and must have the same container kind and the same length.
    Any other values fall back to `==`.
    """
    a_is_tensor = isinstance(a, torch.Tensor)
    b_is_tensor = isinstance(b, torch.Tensor)
    if a_is_tensor or b_is_tensor:
        return a_is_tensor and b_is_tensor and torch.equal(a, b)

    # Tuples must be recursed into as well: `tuple.__eq__` would call `bool()`
    # on element-wise tensor comparisons, which raises for multi-element tensors.
    for seq_type in (list, tuple):
        a_is_seq = isinstance(a, seq_type)
        b_is_seq = isinstance(b, seq_type)
        if a_is_seq or b_is_seq:
            # `zip` stops at the shorter input, so compare lengths explicitly.
            return (
                a_is_seq
                and b_is_seq
                and len(a) == len(b)
                and all(nested_tensors_equal(a_, b_) for a_, b_ in zip(a, b))
            )

    # Both a and b are scalars
    return a == b


def _nested_tensors_h2d(
    tensors: NestedTensors,
    device: torch.types.Device,
) -> NestedTensors:
    """
    Move every tensor leaf of `tensors` to `device` (non-blocking).

    Non-tensor leaves are kept as-is. When `device` is `None`, the input is
    returned unchanged (same object).
    """
    if device is None:
        return tensors

    def _move_leaf(leaf):
        if not isinstance(leaf, torch.Tensor):
            return leaf
        return leaf.to(device=device, non_blocking=True)

    return json_map_leaves(_move_leaf, tensors)


BatchedTensorInputs: TypeAlias = dict[str, NestedTensors]
"""
A dictionary containing nested tensors which have been batched via
[`MultiModalKwargsItems.get_data`][vllm.multimodal.inputs.MultiModalKwargsItems.get_data].
"""


def batched_tensors_equal(
    a: "BatchedTensorInputs", b: "BatchedTensorInputs"
) -> bool:
    """
    Equality check between
    [`BatchedTensorInputs`][vllm.multimodal.inputs.BatchedTensorInputs] objects.

    Both dictionaries must have exactly the same keys, and the values under
    each key must be equal according to `nested_tensors_equal`.
    """
    # Compare the key sets first so that an extra key in *either* argument
    # makes the result `False`, keeping the check symmetric.
    return a.keys() == b.keys() and all(
        nested_tensors_equal(a[k], b[k]) for k in a
    )


@dataclass
class MultiModalFeatureSpec:
    """
    Represents a single multimodal input with its processed data and metadata.

    Used to track multimodal data through processing and caching.
    A request containing multiple multimodal items will have one
    `MultiModalFeatureSpec` per item.
    """

    data: "MultiModalKwargsItem | None"
    """
    Represents multimodal data for this feature.

    Can be `None` if the item is cached, to skip IPC between API server
    and engine core processes.
    """

    modality: str
    """The input modality, e.g., `"image"`, `"audio"`, `"video"`."""

    identifier: str
    """The hash for caching encoder outputs (with LoRA prefix if applicable)."""

    mm_position: PlaceholderRange
    """
    The location of the `modality` tokens corresponding to this item
    in the prompt, e.g., `PlaceholderRange(offset=2, length=336)`.
    """

    mm_hash: str | None = None
    """The hash for caching processor outputs (without LoRA prefix)."""

    @staticmethod
    def gather_kwargs(features: list["MultiModalFeatureSpec"], keys: set[str]):
        """
        Collect, for each key in `keys`, the field data of every feature.

        Features whose `data` is `None` are skipped, as are keys missing from a
        given item. Keys that are found in no item are absent from the result.

        Returns:
            A dict mapping each found key to the list of `elem.data` values,
            in the order of `features`.
        """
        kwargs = defaultdict[str, list[NestedTensors]](list)

        for f in features:
            item = f.data
            if item is None:
                continue

            for k in keys:
                if k in item:
                    kwargs[k].append(item[k].data)

        return dict(kwargs)


@dataclass
class MultiModalFieldElem:
    """
    Represents a processed keyword argument to pass to a model for a
    [`MultiModalKwargsItem`][vllm.multimodal.inputs.MultiModalKwargsItem].
    """

    data: "NestedTensors"
    """
    The tensor data of this field in
    [`MultiModalKwargsItem`][vllm.multimodal.inputs.MultiModalKwargsItem],
    i.e. the value of the keyword argument to be passed to the model.

    It may be set to `None` if it is determined that the item is cached
    in `EngineCore`.
    """

    field: "BaseMultiModalField"
    """
    Defines how to combine the tensor data of this field with others
    in order to batch multi-modal items together for model inference.
    """

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, self.__class__):
            return False

        # Missing (`None`) data only equals missing data.
        if self.data is None or other.data is None:
            data_equal = self.data is None and other.data is None
        else:
            data_equal = nested_tensors_equal(self.data, other.data)

        # Fields are compared by type only, not by their configuration.
        return data_equal and type(self.field) is type(other.field)  # noqa: E721


@dataclass(frozen=True, kw_only=True)
class BaseMultiModalField(ABC):
    """
    Defines how to interpret tensor data belonging to a keyword argument for
    [`MultiModalKwargsItems`][vllm.multimodal.inputs.MultiModalKwargsItems],
    and vice versa.
    """

    keep_on_cpu: bool = False
    """
    If `True`, then this field is excluded from being moved to the accelerator
    when `MultiModalKwargsItems.get_data()` is called to batch the data.
    """

    def _field_factory(self):
        """Return a callable that wraps data in a `MultiModalFieldElem` of this field."""
        f = partial(MultiModalFieldElem, field=self)

        # Allow passing data as positional argument
        def factory(data: NestedTensors) -> MultiModalFieldElem:
            return f(data=data)

        return factory

    @abstractmethod
    def build_elems(
        self,
        modality: str,
        key: str,
        data: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        """
        Construct
        [`MultiModalFieldElem`][vllm.multimodal.inputs.MultiModalFieldElem]
        instances to represent the provided data.

        This is the inverse of
        [`reduce_data`][vllm.multimodal.inputs.BaseMultiModalField.reduce_data].
        """
        raise NotImplementedError

    @abstractmethod
    def _reduce_data(
        self,
        batch: list[NestedTensors],
        *,
        pin_memory: bool,
    ) -> NestedTensors:
        """Merge the raw `data` values of several elements into one batched value."""
        raise NotImplementedError

    def reduce_data(
        self,
        elems: list[MultiModalFieldElem],
        *,
        device: torch.types.Device = None,
        pin_memory: bool = False,
    ) -> NestedTensors:
        """
        Merge the data from multiple instances of
        [`MultiModalFieldElem`][vllm.multimodal.inputs.MultiModalFieldElem].

        This is the inverse of
        [`build_elems`][vllm.multimodal.inputs.BaseMultiModalField.build_elems].

        Args:
            elems: The elements to merge; they must all use the same field type.
            device: If not `None`, the merged tensors are moved to this device
                (overridden to `"cpu"` when `keep_on_cpu` is set).
            pin_memory: Forwarded to `_reduce_data` (forced to `False` when
                `keep_on_cpu` is set).

        Raises:
            ValueError: If `elems` use different field types.
        """
        field_types = [type(item.field) for item in elems]
        if len(set(field_types)) > 1:
            raise ValueError(f"Cannot merge different {field_types=}")

        if self.keep_on_cpu:
            # CPU-only fields are never moved to an accelerator nor pinned.
            if device is not None:
                device = "cpu"
            pin_memory = False

        batch = [elem.data for elem in elems]
        out = self._reduce_data(batch, pin_memory=pin_memory)
        return _nested_tensors_h2d(out, device=device)


@dataclass(frozen=True, kw_only=True)
class MultiModalBatchedField(BaseMultiModalField):
    """
    Info:
        [`MultiModalFieldConfig.batched`][vllm.multimodal.inputs.MultiModalFieldConfig.batched]
    """

    def build_elems(
        self,
        modality: str,
        key: str,
        data: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        """Create one element per entry along the first dimension of `data`."""
        field_factory = self._field_factory()
        return [field_factory(item) for item in data]

    def _reduce_data(
        self,
        batch: list[NestedTensors],
        *,
        pin_memory: bool,
    ) -> NestedTensors:
        """
        Stack same-shaped tensors along a new leading batch dimension.

        `batch` is returned unchanged (as a list) when it is empty, contains
        non-tensors, or holds tensors of differing shapes.
        """
        if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"):
            batch = cast(list[torch.Tensor], batch)
            if len(batch) == 1:
                # An optimization when `batch` contains only one tensor:
                # - produce exactly same result as `torch.stack(batch)`
                # - will achieve zero-copy if the tensor is contiguous
                return batch[0].unsqueeze(0).contiguous()

            first_shape = batch[0].shape
            if all(elem.shape == first_shape for elem in batch):
                out = torch.empty(
                    (len(batch), *batch[0].shape),
                    dtype=batch[0].dtype,
                    device=batch[0].device,
                    pin_memory=pin_memory,
                )
                return torch.stack(batch, out=out)

        return batch


@dataclass(frozen=True, kw_only=True)
class MultiModalFlatField(BaseMultiModalField):
    """
    Info:
        [`MultiModalFieldConfig.flat`][vllm.multimodal.inputs.MultiModalFieldConfig.flat]
        [`MultiModalFieldConfig.flat_from_sizes`][vllm.multimodal.inputs.MultiModalFieldConfig.flat_from_sizes]
    """

    slices: Sequence[slice] | Sequence[Sequence[slice]]
    dim: int = 0

    def build_elems(
        self,
        modality: str,
        key: str,
        data: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        """
        Create one element per entry of `slices` by indexing `data` with it.

        An entry is either a single slice or, when slicing a dimension other
        than the first, a tuple of slices (which requires a tensor `data`).
        """
        field_factory = self._field_factory()
        if not is_list_of(self.slices, slice, check="all"):
            assert isinstance(data, torch.Tensor), (
                "torch.Tensor is required for multiple slices"
            )
        return [field_factory(data[cast(slice, s)]) for s in self.slices]

    def _reduce_data(
        self,
        batch: list[NestedTensors],
        *,
        pin_memory: bool,
    ) -> NestedTensors:
        """
        Concatenate the tensors in `batch` along `dim`.

        If the tensors also differ in a non-concat dimension, they are placed
        into a zero-initialized tensor sized to the maximum along each such
        dimension. A batch that is empty or contains non-tensors is flattened
        one level instead (only supported for `dim == 0`).
        """
        if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"):
            batch = cast(list[torch.Tensor], batch)
            if len(batch) == 1:
                # An optimization when `batch` contains only one tensor:
                # - produce exactly same result as `torch.concat(batch)`
                # - will achieve zero-copy if the tensor is contiguous
                return batch[0].contiguous()

            # Normalize a negative `dim` so it can be compared to loop indices.
            dim = self.dim + (self.dim < 0) * len(batch[0].shape)

            def _shape_before_after(tensor: torch.Tensor):
                return tensor.shape[:dim], tensor.shape[dim + 1 :]

            first_shape = _shape_before_after(batch[0])

            if all(_shape_before_after(elem) == first_shape for elem in batch):
                shape_before, shape_after = first_shape
                shape_concat = sum(item.shape[dim] for item in batch)
                out = torch.empty(
                    (*shape_before, shape_concat, *shape_after),
                    dtype=batch[0].dtype,
                    device=batch[0].device,
                    pin_memory=pin_memory,
                )
                return torch.concat(batch, dim=self.dim, out=out)

            # Variable-length case: non-concat dimensions differ
            # (e.g., Ultravox with different audio durations).
            # Use slice-assign approach (more efficient than padding).
            # See: https://github.com/vllm-project/vllm/issues/31658
            ndim = batch[0].ndim

            # Output shape: sum along the concat dim, max along every other dim.
            max_sizes = [
                sum(t.shape[d] for t in batch)
                if d == dim
                else max(t.shape[d] for t in batch)
                for d in range(ndim)
            ]

            # Positions not covered by any input tensor stay zero.
            out = torch.zeros(
                max_sizes,
                dtype=batch[0].dtype,
                device=batch[0].device,
                pin_memory=pin_memory,
            )

            # Copy each tensor into its own region along the concat dim,
            # anchored at index 0 in every other dim.
            concat_offset = 0
            for tensor in batch:
                region = tuple(
                    slice(concat_offset, concat_offset + tensor.shape[d])
                    if d == dim
                    else slice(0, tensor.shape[d])
                    for d in range(ndim)
                )
                out[region] = tensor
                concat_offset += tensor.shape[dim]

            return out

        assert self.dim == 0, "dim == 0 is required for nested list"
        return [e for elem in batch for e in elem]


@dataclass(frozen=True, kw_only=True)
class MultiModalSharedField(BaseMultiModalField):
    """
    Info:
        [`MultiModalFieldConfig.shared`][vllm.multimodal.inputs.MultiModalFieldConfig.shared]
    """

    batch_size: int

    def build_elems(
        self,
        modality: str,
        key: str,
        data: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        """Give each of the `batch_size` items the same element (same object)."""
        field_factory = self._field_factory()
        return [field_factory(data)] * self.batch_size

    def _reduce_data(
        self,
        batch: list[NestedTensors],
        *,
        pin_memory: bool,
    ) -> NestedTensors:
        """All elements share the same data, so the first one is returned."""
        return batch[0]


@dataclass(frozen=True)
class MultiModalFieldConfig:
    """
    Associates a modality with the
    [`BaseMultiModalField`][vllm.multimodal.inputs.BaseMultiModalField]
    that determines how the data of one keyword argument is split into
    per-item elements and merged back into a batch.

    The static constructors below build the supported field types.
    """

    @staticmethod
    def batched(modality: str, *, keep_on_cpu: bool = False):
        """
        Defines a field where an element in the batch is obtained by
        indexing into the first dimension of the underlying data.

        Args:
            modality: The modality of the multi-modal item that uses this
                keyword argument.
            keep_on_cpu: Whether to keep this field on the CPU for the model inputs.

        Example:

        ```
        Input:
            Data: [[AAAA]
                [BBBB]
                [CCCC]]

        Output:
            Element 1: [AAAA]
            Element 2: [BBBB]
            Element 3: [CCCC]
        ```
        """
        return MultiModalFieldConfig(
            field=MultiModalBatchedField(keep_on_cpu=keep_on_cpu),
            modality=modality,
        )

    @staticmethod
    def flat(
        modality: str,
        slices: Sequence[slice] | Sequence[Sequence[slice]],
        dim: int = 0,
        *,
        keep_on_cpu: bool = False,
    ):
        """
        Defines a field where an element in the batch is obtained by
        slicing along dimension `dim` (the first dimension by default)
        of the underlying data.

        Args:
            modality: The modality of the multi-modal item that uses this
                keyword argument.
            slices: For each multi-modal item, a slice (dim=0) or a tuple of
                slices (dim>0) that is used to extract the data corresponding
                to it.
            dim: The dimension to extract data, default to 0.
            keep_on_cpu: Whether to keep this field on the CPU for the model inputs.

        Example:

        ```
        Given:
            slices: [slice(0, 3), slice(3, 7), slice(7, 9)]

        Input:
            Data: [AAABBBBCC]

        Output:
            Element 1: [AAA]
            Element 2: [BBBB]
            Element 3: [CC]
        ```

        ```
        Given:
            slices: [
                (slice(None), slice(0, 3)),
                (slice(None), slice(3, 7)),
                (slice(None), slice(7, 9))]
            dim: 1

        Input:
            Data: [[A],[A],[A],[B],[B],[B],[B],[C],[C]]

        Output:
            Element 1: [[A],[A],[A]]
            Element 2: [[B],[B],[B],[B]]
            Element 3: [[C],[C]]
        ```
        """
        return MultiModalFieldConfig(
            field=MultiModalFlatField(
                slices=slices,
                dim=dim,
                keep_on_cpu=keep_on_cpu,
            ),
            modality=modality,
        )

    @staticmethod
    def flat_from_sizes(
        modality: str,
        size_per_item: "torch.Tensor",
        dim: int = 0,
        *,
        keep_on_cpu: bool = False,
    ):
        """
        Defines a field where an element in the batch is obtained by
        slicing along dimension `dim` (the first dimension by default)
        of the underlying data.

        Args:
            modality: The modality of the multi-modal item that uses this
                keyword argument.
            size_per_item: For each multi-modal item, the size of the slice
                that is used to extract the data corresponding to it.
            dim: The dimension to slice, default to 0.
            keep_on_cpu: Whether to keep this field on the CPU for the model inputs.

        Raises:
            ValueError: If `size_per_item` is not 1-D.

        Example:

        ```
        Given:
            size_per_item: [3, 4, 2]

        Input:
            Data: [AAABBBBCC]

        Output:
            Element 1: [AAA]
            Element 2: [BBBB]
            Element 3: [CC]
        ```

        ```
        Given:
            size_per_item: [3, 4, 2]
            dim: 1

        Input:
            Data: [[A],[A],[A],[B],[B],[B],[B],[C],[C]]

        Output:
            Element 1: [[A],[A],[A]]
            Element 2: [[B],[B],[B],[B]]
            Element 3: [[C],[C]]
        ```

        Info:
            [`MultiModalFieldConfig.flat`][vllm.multimodal.inputs.MultiModalFieldConfig.flat]
        """

        if size_per_item.ndim != 1:
            raise ValueError(
                "size_per_item should be a 1-D tensor, "
                f"but found shape: {size_per_item.shape}"
            )

        # Running offsets: item i spans [slice_idxs[i], slice_idxs[i + 1]).
        slice_idxs = [0, *accumulate(size_per_item)]
        # Take every index of the leading `dim` dimensions, then slice `dim`.
        slices = [
            (slice(None, None, None),) * dim
            + (slice(slice_idxs[i], slice_idxs[i + 1]),)
            for i in range(len(size_per_item))
        ]

        return MultiModalFieldConfig.flat(
            modality,
            slices,
            dim=dim,
            keep_on_cpu=keep_on_cpu,
        )

    @staticmethod
    def shared(
        modality: str,
        batch_size: int,
        *,
        keep_on_cpu: bool = False,
    ):
        """
        Defines a field where an element in the batch is obtained by
        taking the entirety of the underlying data.

        This means that the data is the same for each element in the batch.

        Args:
            modality: The modality of the multi-modal item that uses this
                keyword argument.
            batch_size: The number of multi-modal items which share this data.
            keep_on_cpu: Whether to keep this field on the CPU for the model inputs.

        Example:

        ```
        Given:
            batch_size: 4

        Input:
            Data: [XYZ]

        Output:
            Element 1: [XYZ]
            Element 2: [XYZ]
            Element 3: [XYZ]
            Element 4: [XYZ]
        ```
        """
        return MultiModalFieldConfig(
            field=MultiModalSharedField(
                batch_size=batch_size,
                keep_on_cpu=keep_on_cpu,
            ),
            modality=modality,
        )

    field: BaseMultiModalField
    modality: str

    def build_elems(
        self,
        key: str,
        batch: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        """Split `batch` (the value of keyword argument `key`) using `field`."""
        return self.field.build_elems(self.modality, key, batch)


class MultiModalKwargsItem(UserDict[str, MultiModalFieldElem]):
    """
    A dictionary of processed keyword arguments to pass to the model,
    corresponding to a single item in
    [`MultiModalDataItems`][vllm.multimodal.parse.MultiModalDataItems].
    """

    @staticmethod
    def dummy(nbytes: int = 1):
        """
        Build a placeholder item for testing.

        The item has a single `"dummy"` field whose data is an uninitialized
        `uint8` tensor with `nbytes` elements, held by a shared field.
        """
        mm_elem = MultiModalFieldElem(
            data=torch.empty(nbytes, dtype=torch.uint8),
            field=MultiModalSharedField(batch_size=1),
        )
        return MultiModalKwargsItem({"dummy": mm_elem})

    def get_data(self) -> dict[str, NestedTensors]:
        """Return the raw `data` of every field, keyed by field name."""
        return {key: elem.data for key, elem in self.items()}


# Item type stored by `MultiModalKwargsItems`. The `| None` variant presumably
# lets entries be absent (e.g. items already cached) — confirm against callers.
_I = TypeVar(
    "_I",
    MultiModalKwargsItem,
    MultiModalKwargsItem | None,
    default=MultiModalKwargsItem,
)


class MultiModalKwargsItems(UserDict[str, Sequence[_I]]):
    """
    A dictionary of processed multi-modal inputs by modality.

    For example, given a processor that processes
    images into `pixel_values` and `image_grid_thw`,
    and audios into `input_audio_features`,
    a prompt with 2 images and 1 audio will be processed
    into a `MultiModalKwargsItems` with the following structure:

    ```python
    MultiModalKwargsItems(
        {
            "image": [
                # For the first image
                MultiModalKwargsItem({"pixel_values": ..., "image_grid_thw": ...}),
                # For the second image
                MultiModalKwargsItem({"pixel_values": ..., "image_grid_thw": ...}),
            ],
            "audio": [
                # For the first audio
                MultiModalKwargsItem({"input_audio_features": ...}),
            ],
        }
    )
    ```

    Unlike HF processing which returns all items
    in a single dictionary with batched keyword arguments,
    we split up the items because some of them may already be cached.
    Also, items from multiple requests may be batched together to improve throughput,
    using the logic defined by the
    [`BaseMultiModalField`][vllm.multimodal.inputs.BaseMultiModalField]
    for each keyword argument.
    """

    @staticmethod
    def from_hf_inputs(
        hf_inputs: "BatchFeature",
        config_by_key: Mapping[str, MultiModalFieldConfig],
    ) -> "MultiModalKwargsItems[MultiModalKwargsItem]":
        """
        Split the batched outputs of a HF processor into per-item kwargs.

        Args:
            hf_inputs: The outputs of the HF processor.
            config_by_key: For each keyword argument to keep, the config that
                gives its modality and how to split it into elements.

        Returns:
            The per-item keyword arguments, grouped by modality.

        Raises:
            ValueError: If the keyword arguments of one modality do not all
                produce the same number of elements.
        """
        # NOTE: This skips fields in `hf_inputs` that are not in `config_by_key`
        # We assume that those fields are not used in vLLM
        elems_by_key = dict[str, Sequence[MultiModalFieldElem]]()
        # A list (not a set) keeps keys in `config_by_key` order, so the key
        # order inside each item does not depend on string hash randomization.
        # Each key of `config_by_key` is visited once, so no duplicates arise.
        keys_by_modality = defaultdict[str, list[str]](list)
        for key, config in config_by_key.items():
            batch = hf_inputs.get(key)
            if batch is not None:
                elems = config.build_elems(key, batch)
                if len(elems) > 0:
                    elems_by_key[key] = elems
                    keys_by_modality[config.modality].append(key)

        items_by_modality = dict[str, list[MultiModalKwargsItem]]()
        for modality, keys in keys_by_modality.items():
            elems_in_modality = {k: elems_by_key[k] for k in keys}
            batch_sizes = {k: len(v) for k, v in elems_in_modality.items()}

            if len(set(batch_sizes.values())) > 1:
                raise ValueError(
                    f"Cannot merge different batch sizes for {modality=}! "
                    f"Found: {batch_sizes=}"
                )

            # `keys` is non-empty here: a modality is only recorded above
            # together with at least one key.
            batch_size = next(iter(batch_sizes.values()))
            items_by_modality[modality] = [
                MultiModalKwargsItem({k: v[i] for k, v in elems_in_modality.items()})
                for i in range(batch_size)
            ]

        return MultiModalKwargsItems(items_by_modality)

    def __getitem__(self, modality: str) -> Sequence[_I]:
        """
        Return the items of `modality`.

        Raises:
            KeyError: If `modality` is absent; the message lists the
                modalities that are available.
        """
        if modality not in self:
            raise KeyError(
                f"Modality {modality!r} not found. "
                f"Available modalities: {set(self.keys())}"
            )

        return super().__getitem__(modality)  # type: ignore[return-value]

    def require_data(self) -> "MultiModalKwargsItems[MultiModalKwargsItem]":
        """
        Check that every item is present (not `None`).

        Returns:
            `self`, typed as containing only fully-populated items.

        Raises:
            RuntimeError: If any item is `None`.
        """
        for modality, items in self.items():
            for i, item in enumerate(items):
                if item is None:
                    raise RuntimeError(f"Found empty mm_items[{modality}][{i}]")

        return self  # type: ignore[return-value]

    def get_data(
        self,
        *,
        device: torch.types.Device = None,
        pin_memory: bool = False,
    ) -> BatchedTensorInputs:
        """
        Construct a dictionary of keyword arguments to pass to the model.

        Elements are grouped by keyword-argument name across all modalities
        and items. Each group is then reduced by the `reduce_data` method of
        its first element's field.

        Args:
            device: Forwarded to each field's `reduce_data`.
            pin_memory: Forwarded to each field's `reduce_data`.

        Raises:
            RuntimeError: If any item is `None`.
        """
        elems_by_key = defaultdict[str, list[MultiModalFieldElem]](list)
        for modality, items in self.items():
            for i, item in enumerate(items):
                if item is None:
                    raise RuntimeError(
                        f"Cannot build data from empty mm_items[{modality}][{i}]"
                    )

                for key, elem in item.items():
                    elems_by_key[key].append(elem)

        data = {
            key: elems[0].field.reduce_data(
                elems,
                device=device,
                pin_memory=pin_memory,
            )
            for key, elems in elems_by_key.items()
        }

        return data
1030

1031

1032
1033
1034
1035
MultiModalKwargsOptionalItems: TypeAlias = (
    MultiModalKwargsItems[MultiModalKwargsItem]
    | MultiModalKwargsItems[MultiModalKwargsItem | None]
)
"""
Processed multi-modal keyword arguments whose items may or may not be `None`.

Use `MultiModalKwargsItems.require_data` to narrow it to
`MultiModalKwargsItems[MultiModalKwargsItem]`.
"""
1036
1037


1038
1039
1040
1041
1042
1043
MultiModalHashes = dict[str, list[str]]
"""
A dictionary containing per-item hashes for each modality.
"""


1044
MultiModalPlaceholderDict: TypeAlias = Mapping[str, Sequence[PlaceholderRange]]
"""
A dictionary containing per-item placeholder ranges for each modality.
"""


1050
class MultiModalInputs(TypedDict):
    """
    Represents the outputs of
    [`BaseMultiModalProcessor`][vllm.multimodal.processing.BaseMultiModalProcessor],
    ready to be passed to vLLM internals.
    """

    type: Literal["multimodal"]
    """The type of inputs."""

    prompt_token_ids: list[int]
    """The processed token IDs, which include placeholder tokens."""

    mm_kwargs: MultiModalKwargsOptionalItems
    """Keyword arguments to be directly passed to the model after batching."""

    mm_hashes: MultiModalHashes
    """The per-item hashes of the multi-modal data, grouped by modality."""

    mm_placeholders: MultiModalPlaceholderDict
    """
    For each modality, information about the placeholder tokens in
    `prompt_token_ids`.
    """

    cache_salt: NotRequired[str]
    """
    Optional cache salt to be used for prefix caching.
    """

1080
1081
1082

class MultiModalEncDecInputs(MultiModalInputs):
    """
    Represents the outputs of
    [`EncDecMultiModalProcessor`][vllm.multimodal.processing.EncDecMultiModalProcessor]
    ready to be passed to vLLM internals.
    """

    encoder_prompt_token_ids: list[int]
    """The processed token IDs of the encoder prompt."""