inputs.py 32.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from abc import ABC, abstractmethod
5
from collections import UserDict, defaultdict
6
from collections.abc import Mapping, Sequence
7
from dataclasses import dataclass
8
from functools import cached_property, partial
9
from itertools import accumulate
10
11
12
13
14
15
16
17
18
19
20
from typing import (
    TYPE_CHECKING,
    Any,
    Literal,
    Optional,
    TypeAlias,
    TypedDict,
    Union,
    cast,
    final,
)
21
22

import numpy as np
Roger Wang's avatar
Roger Wang committed
23
from PIL.Image import Image
24
from typing_extensions import NotRequired, TypeVar
25

26
from vllm.utils.collection_utils import full_groupby, is_list_of
27
from vllm.utils.import_utils import LazyLoader
28
from vllm.utils.jsontree import json_map_leaves
29

30
if TYPE_CHECKING:
31
32
33
34
    import torch
    import torch.types
    from transformers.feature_extraction_utils import BatchFeature

35
    from .media import MediaWithBytes
36
37
else:
    torch = LazyLoader("torch", globals(), "torch")
38

39
40
_T = TypeVar("_T")

41
HfImageItem: TypeAlias = Union["Image", np.ndarray, "torch.Tensor"]
42
"""
43
A `transformers.image_utils.ImageInput` representing a single image
44
item, which can be passed to a HuggingFace `ImageProcessor`.
45
46
"""

47
48
49
HfVideoItem: TypeAlias = Union[
    list["Image"], np.ndarray, "torch.Tensor", list[np.ndarray], list["torch.Tensor"]
]
50
"""
51
A `transformers.image_utils.VideoInput` representing a single video
52
item, which can be passed to a HuggingFace `VideoProcessor`.
53
54
"""

55
HfAudioItem: TypeAlias = Union[list[float], np.ndarray, "torch.Tensor"]
56
"""
57
Represents a single audio
58
item, which can be passed to a HuggingFace `AudioProcessor`.
59
60
"""

61
ImageItem: TypeAlias = Union[HfImageItem, "torch.Tensor", "MediaWithBytes[HfImageItem]"]
62
"""
63
A `transformers.image_utils.ImageInput` representing a single image
64
item, which can be passed to a HuggingFace `ImageProcessor`.
65
66
67
68
69
70

Alternatively, a 3-D tensor or batch of 2-D tensors,
which are treated as image embeddings;
these are directly passed to the model without HF processing.
"""

71
72
73
VideoItem: TypeAlias = Union[
    HfVideoItem, "torch.Tensor", tuple[HfVideoItem, dict[str, Any]]
]
74
"""
75
76
77
A `transformers.video_utils.VideoInput` representing a single video item. 
This can be passed to a HuggingFace `VideoProcessor` 
with `transformers.video_utils.VideoMetadata`.
78
79
80
81
82
83

Alternatively, a 3-D tensor or batch of 2-D tensors,
which are treated as video embeddings;
these are directly passed to the model without HF processing.
"""

84
AudioItem: TypeAlias = Union[HfAudioItem, tuple[np.ndarray, float], "torch.Tensor"]
85
86
"""
Represents a single audio
87
item, which can be passed to a HuggingFace `AudioProcessor`.
88
89
90
91
92
93
94
95
96
97

Alternatively, a tuple `(audio, sampling_rate)`, where the sampling rate
is different from that expected by the model;
these are resampled to the model's sampling rate before being processed by HF.

Alternatively, a 3-D tensor or batch of 2-D tensors,
which are treated as audio embeddings;
these are directly passed to the model without HF processing.
"""

98
ModalityData: TypeAlias = _T | list[_T | None] | None
99
"""
100
101
Either a single data item, or a list of data items. Can only be None if UUID
is provided.
102
103

The number of data items allowed per modality is restricted by
104
`--limit-mm-per-prompt`.
105
106
107
"""


Roger Wang's avatar
Roger Wang committed
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
class VisionChunkImage(TypedDict):
    """Represents an image wrapped as a vision chunk.

    One of the two members of the `VisionChunk` union; the literal `type`
    field distinguishes it from `VisionChunkVideo`.
    """

    type: Literal["image"]  # always "image" for this variant
    image: Image  # the wrapped PIL image
    uuid: str | None  # optional identifier for the item; may be None


class VisionChunkVideo(TypedDict):
    """Represents a video chunk with metadata.

    One of the two members of the `VisionChunk` union; the literal `type`
    field distinguishes it from `VisionChunkImage`.
    """

    type: Literal["video_chunk"]  # always "video_chunk" for this variant
    video_chunk: list[Image]  # the frames of this chunk, as PIL images
    uuid: str | None  # optional identifier for the chunk; may be None
    # NOTE(review): presumably the text associated with this chunk — confirm
    # against the code that builds these dicts.
    prompt: str
    # NOTE(review): presumably the index of the source video this chunk was
    # cut from — confirm.
    video_idx: int


# Discriminated by the `type` key ("image" vs "video_chunk").
VisionChunk = VisionChunkImage | VisionChunkVideo
"""A vision chunk is either an image or a video chunk."""


130
131
132
133
@final
class MultiModalDataBuiltins(TypedDict, total=False):
    """Type annotations for modality types predefined by vLLM.

    Every key is optional (`total=False`). Each value is either a single
    item or a list of items of that modality (see `ModalityData`).
    """

    image: ModalityData[ImageItem]
    """The input image(s)."""

    video: ModalityData[VideoItem]
    """The input video(s)."""

    audio: ModalityData[AudioItem]
    """The input audio(s)."""

    vision_chunk: ModalityData[VisionChunk]
    """The input visual atom(s) - unified modality for images and video chunks."""

146

147
MultiModalDataDict: TypeAlias = Mapping[str, ModalityData[Any]]
"""
A dictionary containing an entry for each modality type to input.

The built-in modalities are defined by
[`MultiModalDataBuiltins`][vllm.multimodal.inputs.MultiModalDataBuiltins].
"""

MultiModalUUIDDict: TypeAlias = Mapping[str, list[str | None] | str]
"""
A dictionary containing user-provided UUIDs for items in each modality.

Each value is either a single UUID string or a list of per-item entries.
If a UUID for an item is not provided, its entry will be `None` and
MultiModalHasher will compute a hash for the item.

The UUID will be used to identify the item for all caching purposes
(input processing caching, embedding caching, prefix caching, etc).
"""

165

166
167
@dataclass(frozen=True)
class PlaceholderRange:
168
169
170
    """
    Placeholder location information for multi-modal data.

171
172
    Example:

173
    Prompt: `AAAA BBBB What is in these images?`
174

175
    Images A and B will have:
176

177
178
179
180
    ```
    A: PlaceholderRange(offset=0, length=4)
    B: PlaceholderRange(offset=5, length=4)
    ```
181
182
183
184
185
186
187
188
    """

    offset: int
    """The start index of the placeholder in the prompt."""

    length: int
    """The length of the placeholder."""

189
    is_embed: Optional["torch.Tensor"] = None
190
191
192
193
194
    """
    A boolean mask of shape `(length,)` indicating which positions
    between `offset` and `offset + length` to assign embeddings to.
    """

195
196
    @cached_property
    def embeds_cumsum(self) -> torch.Tensor | None:
197
        return None if self.is_embed is None else self.is_embed.cumsum(dim=0)
198
199
200
201

    @cached_property
    def get_num_embeds(self) -> int:
        if self.embeds_cumsum is None:
202
203
            return self.length

204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
        return int(self.embeds_cumsum[-1])

    def get_embeds_indices_in_range(
        self, start_idx: int, end_idx: int
    ) -> tuple[int, int]:
        """
        Returns the starting and ending indices of the embeddings of encoder outputs
        in the range of [start_idx, end_idx) in the placeholders.

        For example, given:
        PlaceholderRange(offset=2, length=5, is_embed=[False, True, False, True, True])

        If start_idx=3 and end_idx=5, the output is (1, 3) because we want to get
        the second and the third embeddings from the encoder output.
        """
        if self.embeds_cumsum is None:
            return start_idx, end_idx

        embeds_start_idx = (
            int(self.embeds_cumsum[start_idx - 1]) if start_idx > 0 else 0
        )
        embeds_end_idx = int(self.embeds_cumsum[end_idx - 1])

        return embeds_start_idx, embeds_end_idx
228

229
230
231
232
233
234
235
236
237
238
239
240
241
    def extract_embeds_range(self) -> list[tuple[int, int]]:
        """Extract the start and end indices of the embedded region in prompt.

        For example, given `PlaceholderRange(offset=2, length=5)` and
        `is_embed = [False, True, False, True, True]`, the output is
        `[(1 + offset, 1 + offset), (3 + offset, 4 + offset)]`.

        Returns:
            A tuple `(start, end)` representing the start and end
            indices (inclusive) of the embedded region.
            Returns full placeholder range if `is_embed` is `None`.
        """
        if self.is_embed is None:
242
            return [(self.offset, self.offset + self.length - 1)]
243
244
245
246
247
248
249
250
251
252
253

        mask_i = self.is_embed.int()
        starts = torch.nonzero(
            torch.diff(mask_i, prepend=mask_i.new_zeros(1)) == 1
        ).flatten()
        ends = torch.nonzero(
            torch.diff(mask_i, append=mask_i.new_zeros(1)) == -1
        ).flatten()
        ranges = torch.stack((starts, ends), dim=1) + self.offset
        return [tuple(x) for x in ranges.tolist()]

254
255
256
257
258
259
260
261
262
263
264
265
266
    def __eq__(self, other: object) -> bool:
        if not isinstance(other, self.__class__):
            return False
        if not (self.offset, self.length) == (other.offset, other.length):
            return False

        if self.is_embed is None:
            return other.is_embed is None
        if other.is_embed is None:
            return self.is_embed is None

        return nested_tensors_equal(self.is_embed, other.is_embed)

267

268
269
270
271
272
273
NestedTensors: TypeAlias = Union[
    list["NestedTensors"],
    list["torch.Tensor"],
    "torch.Tensor",
    tuple["torch.Tensor", ...],
]
274
275
276
277
"""
Uses a list instead of a tensor if the dimensions of each element do not match.
"""

278
279

def nested_tensors_equal(a: "NestedTensors", b: "NestedTensors") -> bool:
    """
    Equality check between
    [`NestedTensors`][vllm.multimodal.inputs.NestedTensors] objects.

    Tensors are compared with `torch.equal` (same shape and values). Lists and
    tuples are compared recursively. They must be the same kind of container
    (a list never equals a tuple) and have the same length. Any other values
    are compared with `==`.
    """
    if isinstance(a, torch.Tensor) or isinstance(b, torch.Tensor):
        return (
            isinstance(a, torch.Tensor)
            and isinstance(b, torch.Tensor)
            and torch.equal(a, b)
        )

    for container in (list, tuple):
        if isinstance(a, container) or isinstance(b, container):
            return (
                isinstance(a, container)
                and isinstance(b, container)
                # `zip` truncates silently, so the lengths must be checked
                # explicitly; otherwise `[x]` would equal `[x, y]`.
                and len(a) == len(b)
                and all(nested_tensors_equal(x, y) for x, y in zip(a, b))
            )

    # Both a and b are scalars
    return a == b


302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
def _nested_tensors_h2d(
    tensors: NestedTensors,
    device: torch.types.Device,
) -> NestedTensors:
    """Move every tensor leaf of `tensors` to `device` without blocking.

    Non-tensor leaves are passed through untouched. When `device` is `None`,
    the input is returned as-is.
    """
    if device is None:
        return tensors

    def _to_device(leaf):
        if not isinstance(leaf, torch.Tensor):
            return leaf
        return leaf.to(device=device, non_blocking=True)

    return json_map_leaves(_to_device, tensors)


319
BatchedTensorInputs: TypeAlias = dict[str, NestedTensors]
"""
A dictionary containing nested tensors which have been batched via
[`MultiModalKwargsItems.get_data`][vllm.multimodal.inputs.MultiModalKwargsItems.get_data].
"""


326
327
328
329
330
def batched_tensors_equal(a: BatchedTensorInputs, b: BatchedTensorInputs) -> bool:
    """
    Equality check between
    [`BatchedTensorInputs`][vllm.multimodal.inputs.BatchedTensorInputs] objects.

    Both dictionaries must have exactly the same keys. The values under each
    key must be equal according to `nested_tensors_equal`.
    """
    # Compare the key sets first. Only checking `k in b` for each `k in a`
    # makes the comparison asymmetric (e.g. `{}` would equal any dict).
    if a.keys() != b.keys():
        return False

    return all(nested_tensors_equal(a[k], b[k]) for k in a)
332
333


334
335
336
337
@dataclass
class MultiModalFeatureSpec:
    """
    Represents a single multimodal input with its processed data and metadata.

    The V1 engine uses it to track multimodal data through processing and
    caching. A request with several multimodal items carries one
    MultiModalFeatureSpec per item.
    """

    data: Optional["MultiModalKwargsItem"]
    """Multimodal data for this feature (may be `None`)."""

    modality: str
    """Based on the input, e.g., "image", "audio", "video"."""

    identifier: str
    """mm_hash or uuid for caching encoder outputs."""

    mm_position: PlaceholderRange
    """e.g., PlaceholderRange(offset=2, length=336)"""

    mm_hash: str | None = None
    """Base mm_hash for processor cache (without LoRA prefix)."""

    @staticmethod
    def gather_kwargs(features: list["MultiModalFeatureSpec"], keys: set[str]):
        """Collect the raw `.data` of the requested keys across `features`.

        Features whose `data` is `None` are skipped, and so are keys that a
        feature's data does not contain. The result maps each key found at
        least once to its values, in feature order.
        """
        gathered: dict[str, list[NestedTensors]] = {}

        for feature in features:
            kwargs_item = feature.data
            if kwargs_item is None:
                continue

            for key in keys:
                if key in kwargs_item:
                    gathered.setdefault(key, []).append(kwargs_item[key].data)

        return gathered

372

373
@dataclass
class MultiModalFieldElem:
    """
    Represents a keyword argument inside a
    [`MultiModalKwargsItem`][vllm.multimodal.inputs.MultiModalKwargsItem].
    """

    modality: str
    """
    The modality of the corresponding multi-modal item.
    Each multi-modal item can consist of multiple keyword arguments.
    """

    key: str
    """
    The key of this field in
    [`MultiModalKwargsItem`][vllm.multimodal.inputs.MultiModalKwargsItem],
    i.e. the name of the keyword argument to be passed to the model.
    """

    data: NestedTensors
    """
    The tensor data of this field in
    [`MultiModalKwargsItem`][vllm.multimodal.inputs.MultiModalKwargsItem],
    i.e. the value of the keyword argument to be passed to the model.

    It may be set to `None` if it is determined that the item is cached
    in `EngineCore`.
    """

    field: "BaseMultiModalField"
    """
    Defines how to combine the tensor data of this field with others
    in order to batch multi-modal items together for model inference.
    """

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, self.__class__):
            return False

        # `data` may hold tensors, whose `==` is elementwise, so it is
        # compared explicitly. `None` only ever equals `None`.
        if self.data is None or other.data is None:
            same_data = self.data is None and other.data is None
        else:
            same_data = nested_tensors_equal(self.data, other.data)

        same_ids = (self.modality, self.key) == (other.modality, other.key)
        # Only the kind of field matters here, not its configuration.
        same_field_kind = type(self.field) is type(other.field)

        return same_ids and same_data and same_field_kind
425
426


427
@dataclass(frozen=True, kw_only=True)
class BaseMultiModalField(ABC):
    """
    Defines how to interpret tensor data belonging to a keyword argument for
    [`MultiModalKwargsItems`][vllm.multimodal.inputs.MultiModalKwargsItems],
    and vice versa.
    """

    keep_on_cpu: bool = False
    """
    If `True`, then this field is excluded from being moved to the accelerator
    when `MultiModalKwargsItems.get_data()` is called to batch the data.
    """

    def _field_factory(self, *, modality: str, key: str):
        """Return a one-argument constructor for elements of this field.

        The returned callable takes the element's `data` (positionally) and
        fills in `modality`, `key` and `field=self`.
        """
        make_elem = partial(
            MultiModalFieldElem,
            modality=modality,
            key=key,
            field=self,
        )

        # A positional argument to `make_elem` would bind to `modality`, so
        # wrap it to accept `data` positionally.
        return lambda data: make_elem(data=data)

    @abstractmethod
    def build_elems(
        self,
        modality: str,
        key: str,
        data: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        """
        Construct
        [`MultiModalFieldElem`][vllm.multimodal.inputs.MultiModalFieldElem]
        instances to represent the provided data.

        This is the inverse of
        [`reduce_data`][vllm.multimodal.inputs.BaseMultiModalField.reduce_data].
        """
        raise NotImplementedError

    @abstractmethod
    def _reduce_data(
        self,
        batch: list[NestedTensors],
        *,
        pin_memory: bool,
    ) -> NestedTensors:
        """Combine the raw data of several elements (subclass hook)."""
        raise NotImplementedError

    def reduce_data(
        self,
        elems: list[MultiModalFieldElem],
        *,
        device: torch.types.Device = None,
        pin_memory: bool = False,
    ) -> NestedTensors:
        """
        Merge the data from multiple instances of
        [`MultiModalFieldElem`][vllm.multimodal.inputs.MultiModalFieldElem].

        This is the inverse of
        [`build_elems`][vllm.multimodal.inputs.BaseMultiModalField.build_elems].

        Raises:
            ValueError: If the elements belong to different field types.
        """
        field_types = [type(item.field) for item in elems]
        if len(set(field_types)) > 1:
            raise ValueError(f"Cannot merge different {field_types=}")

        if self.keep_on_cpu:
            # CPU-only fields are never pinned, and any requested transfer is
            # redirected to the CPU.
            pin_memory = False
            if device is not None:
                device = "cpu"

        reduced = self._reduce_data(
            [elem.data for elem in elems], pin_memory=pin_memory
        )
        return _nested_tensors_h2d(reduced, device=device)
507
508


509
@dataclass(frozen=True, kw_only=True)
class MultiModalBatchedField(BaseMultiModalField):
    """
    Info:
        [`MultiModalFieldConfig.batched`][vllm.multimodal.inputs.MultiModalFieldConfig.batched]
    """

    def build_elems(
        self,
        modality: str,
        key: str,
        data: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        """Create one element per entry obtained by iterating over `data`."""
        make_elem = self._field_factory(modality=modality, key=key)
        return list(map(make_elem, data))

    def _reduce_data(
        self,
        batch: list[NestedTensors],
        *,
        pin_memory: bool,
    ) -> NestedTensors:
        """Stack same-shaped tensors along a new leading dimension.

        Anything else (an empty batch, non-tensor entries or tensors of
        differing shapes) is returned unchanged as a list.
        """
        if not batch or not is_list_of(batch, torch.Tensor, check="all"):
            return batch

        tensors = cast(list[torch.Tensor], batch)
        head = tensors[0]

        if len(tensors) == 1:
            # Same result as `torch.stack(tensors)`, but zero-copy when the
            # tensor is already contiguous.
            return head.unsqueeze(0).contiguous()

        if any(t.shape != head.shape for t in tensors):
            return batch

        out = torch.empty(
            (len(tensors), *head.shape),
            dtype=head.dtype,
            device=head.device,
            pin_memory=pin_memory,
        )
        return torch.stack(tensors, out=out)


551
@dataclass(frozen=True, kw_only=True)
class MultiModalFlatField(BaseMultiModalField):
    """
    Info:
        [`MultiModalFieldConfig.flat`][vllm.multimodal.inputs.MultiModalFieldConfig.flat]
        [`MultiModalFieldConfig.flat_from_sizes`][vllm.multimodal.inputs.MultiModalFieldConfig.flat_from_sizes]
    """

    slices: Sequence[slice] | Sequence[Sequence[slice]]
    dim: int = 0

    def build_elems(
        self,
        modality: str,
        key: str,
        data: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        """Create one element per entry of `slices`, indexing into `data`."""
        make_elem = self._field_factory(modality=modality, key=key)

        if not is_list_of(self.slices, slice, check="all"):
            # Tuples of slices need multi-dimensional indexing, which only
            # tensors support.
            assert isinstance(data, torch.Tensor), (
                "torch.Tensor is required for multiple slices"
            )

        return [make_elem(data[cast(slice, region)]) for region in self.slices]

    def _reduce_data(
        self,
        batch: list[NestedTensors],
        *,
        pin_memory: bool,
    ) -> NestedTensors:
        """Concatenate tensors along `dim`, or flatten a nested list.

        If all other dimensions agree, a plain concatenation is used.
        Otherwise each tensor is copied into a zero-filled buffer that is
        large enough in every non-concatenated dimension.
        """
        if not batch or not is_list_of(batch, torch.Tensor, check="all"):
            assert self.dim == 0, "dim == 0 is required for nested list"
            return [e for elem in batch for e in elem]

        tensors = cast(list[torch.Tensor], batch)
        head = tensors[0]

        if len(tensors) == 1:
            # Same result as `torch.concat(tensors)`, but zero-copy when the
            # tensor is already contiguous.
            return head.contiguous()

        ndim = head.ndim
        cat_dim = self.dim + ndim if self.dim < 0 else self.dim

        def split_shape(t: torch.Tensor):
            return t.shape[:cat_dim], t.shape[cat_dim + 1 :]

        head_split = split_shape(head)
        if all(split_shape(t) == head_split for t in tensors):
            before, after = head_split
            total = sum(t.shape[cat_dim] for t in tensors)
            out = torch.empty(
                (*before, total, *after),
                dtype=head.dtype,
                device=head.device,
                pin_memory=pin_memory,
            )
            return torch.concat(tensors, dim=self.dim, out=out)

        # Variable-length case: the non-concat dimensions differ
        # (e.g., Ultravox with different audio durations). Slice-assigning
        # into one buffer is more efficient than padding each tensor.
        # See: https://github.com/vllm-project/vllm/issues/31658
        out_shape = [
            sum(t.shape[d] for t in tensors)
            if d == cat_dim
            else max(t.shape[d] for t in tensors)
            for d in range(ndim)
        ]
        out = torch.zeros(
            out_shape,
            dtype=head.dtype,
            device=head.device,
            pin_memory=pin_memory,
        )

        offset = 0
        for t in tensors:
            region = tuple(
                slice(offset, offset + t.shape[d])
                if d == cat_dim
                else slice(0, t.shape[d])
                for d in range(ndim)
            )
            out[region] = t
            offset += t.shape[cat_dim]

        return out
650
651


652
@dataclass(frozen=True, kw_only=True)
class MultiModalSharedField(BaseMultiModalField):
    """
    Info:
        [`MultiModalFieldConfig.shared`][vllm.multimodal.inputs.MultiModalFieldConfig.shared]
    """

    batch_size: int

    def build_elems(
        self,
        modality: str,
        key: str,
        data: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        """Wrap `data` once and reuse that same element `batch_size` times."""
        shared_elem = self._field_factory(modality=modality, key=key)(data)
        return [shared_elem] * self.batch_size

    def _reduce_data(
        self,
        batch: list[NestedTensors],
        *,
        pin_memory: bool,
    ) -> NestedTensors:
        """Return the first entry; every element carries the same data."""
        first, *_ = batch[:1]
        return first


679
@dataclass(frozen=True)
class MultiModalFieldConfig:
    """
    Associates a
    [`BaseMultiModalField`][vllm.multimodal.inputs.BaseMultiModalField]
    with the modality of the multi-modal item that uses it.

    Use the `batched`, `flat`, `flat_from_sizes` or `shared` factory methods
    to create instances.
    """

    @staticmethod
    def batched(modality: str, *, keep_on_cpu: bool = False):
        """
        Defines a field where an element in the batch is obtained by
        indexing into the first dimension of the underlying data.

        Args:
            modality: The modality of the multi-modal item that uses this
                keyword argument.
            keep_on_cpu: Whether to keep this field on the CPU for the model inputs.

        Example:

        ```
        Input:
            Data: [[AAAA]
                [BBBB]
                [CCCC]]

        Output:
            Element 1: [AAAA]
            Element 2: [BBBB]
            Element 3: [CCCC]
        ```
        """
        return MultiModalFieldConfig(
            field=MultiModalBatchedField(keep_on_cpu=keep_on_cpu),
            modality=modality,
        )

    @staticmethod
    def flat(
        modality: str,
        slices: Sequence[slice] | Sequence[Sequence[slice]],
        dim: int = 0,
        *,
        keep_on_cpu: bool = False,
    ):
        """
        Defines a field where an element in the batch is obtained by
        slicing the underlying data along dimension `dim` (the first
        dimension by default).

        Args:
            modality: The modality of the multi-modal item that uses this
                keyword argument.
            slices: For each multi-modal item, a slice (dim=0) or a tuple of
                slices (dim>0) that is used to extract the data corresponding
                to it.
            dim: The dimension to extract data, default to 0.
            keep_on_cpu: Whether to keep this field on the CPU for the model inputs.

        Example:

        ```
        Given:
            slices: [slice(0, 3), slice(3, 7), slice(7, 9)]

        Input:
            Data: [AAABBBBCC]

        Output:
            Element 1: [AAA]
            Element 2: [BBBB]
            Element 3: [CC]
        ```

        ```
        Given:
            slices: [
                (slice(None), slice(0, 3)),
                (slice(None), slice(3, 7)),
                (slice(None), slice(7, 9))]
            dim: 1

        Input:
            Data: [[AAABBBBCC]]

        Output:
            Element 1: [[AAA]]
            Element 2: [[BBBB]]
            Element 3: [[CC]]
        ```
        """
        return MultiModalFieldConfig(
            field=MultiModalFlatField(
                slices=slices,
                dim=dim,
                keep_on_cpu=keep_on_cpu,
            ),
            modality=modality,
        )

    @staticmethod
    def flat_from_sizes(
        modality: str,
        size_per_item: "torch.Tensor",
        dim: int = 0,
        *,
        keep_on_cpu: bool = False,
    ):
        """
        Defines a field where an element in the batch is obtained by
        slicing the underlying data along dimension `dim` (the first
        dimension by default).

        Args:
            modality: The modality of the multi-modal item that uses this
                keyword argument.
            size_per_item: For each multi-modal item, the size of the slice
                that is used to extract the data corresponding to it.
            dim: The dimension to slice, default to 0. Negative values count
                from the last dimension.
            keep_on_cpu: Whether to keep this field on the CPU for the model inputs.

        Raises:
            ValueError: If `size_per_item` is not 1-D.

        Example:

        ```
        Given:
            size_per_item: [3, 4, 2]

        Input:
            Data: [AAABBBBCC]

        Output:
            Element 1: [AAA]
            Element 2: [BBBB]
            Element 3: [CC]
        ```

        ```
        Given:
            size_per_item: [3, 4, 2]
            dim: 1

        Input:
            Data: [[AAABBBBCC]]

        Output:
            Element 1: [[AAA]]
            Element 2: [[BBBB]]
            Element 3: [[CC]]
        ```

        Info:
            [`MultiModalFieldConfig.flat`][vllm.multimodal.inputs.MultiModalFieldConfig.flat]
        """

        if size_per_item.ndim != 1:
            raise ValueError(
                "size_per_item should be a 1-D tensor, "
                f"but found shape: {size_per_item.shape}"
            )

        # Accumulate plain Python ints so that the slice bounds are ints
        # rather than one 0-d tensor per bound.
        slice_idxs = [0, *accumulate(size_per_item.tolist())]

        # Index prefix selecting everything before the target dimension. A
        # negative `dim` is anchored to the end with `Ellipsis`. Repeating
        # `slice(None)` a negative number of times would give `()` and
        # silently slice dimension 0 instead.
        if dim >= 0:
            prefix = (slice(None, None, None),) * dim
        else:
            prefix = (Ellipsis,) + (slice(None, None, None),) * (-dim - 1)

        slices = [
            prefix + (slice(start, end),)
            for start, end in zip(slice_idxs, slice_idxs[1:])
        ]

        return MultiModalFieldConfig.flat(
            modality,
            slices,
            dim=dim,
            keep_on_cpu=keep_on_cpu,
        )

    @staticmethod
    def shared(
        modality: str,
        batch_size: int,
        *,
        keep_on_cpu: bool = False,
    ):
        """
        Defines a field where an element in the batch is obtained by
        taking the entirety of the underlying data.

        This means that the data is the same for each element in the batch.

        Args:
            modality: The modality of the multi-modal item that uses this
                keyword argument.
            batch_size: The number of multi-modal items which share this data.
            keep_on_cpu: Whether to keep this field on the CPU for the model inputs.

        Example:

        ```
        Given:
            batch_size: 4

        Input:
            Data: [XYZ]

        Output:
            Element 1: [XYZ]
            Element 2: [XYZ]
            Element 3: [XYZ]
            Element 4: [XYZ]
        ```
        """
        return MultiModalFieldConfig(
            field=MultiModalSharedField(
                batch_size=batch_size,
                keep_on_cpu=keep_on_cpu,
            ),
            modality=modality,
        )

    field: BaseMultiModalField
    modality: str

    def build_elems(
        self,
        key: str,
        batch: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        """Split `batch` into per-item elements using this config's field."""
        return self.field.build_elems(self.modality, key, batch)
898
899


900
901
class MultiModalKwargsItem(UserDict[str, MultiModalFieldElem]):
    """
    A collection of
    [`MultiModalFieldElem`][vllm.multimodal.inputs.MultiModalFieldElem]
    corresponding to a data item in
    [`MultiModalDataItems`][vllm.multimodal.parse.MultiModalDataItems].

    All elements must belong to exactly one modality.
    """

    @staticmethod
    def dummy(modality: str, nbytes: int = 1):
        """Create a single-element item, for testing.

        The item holds one `"dummy"` field whose data is an uninitialized
        `uint8` tensor with `nbytes` elements.
        """
        mm_elem = MultiModalFieldElem(
            modality=modality,
            key="dummy",
            data=torch.empty(nbytes, dtype=torch.uint8),
            field=MultiModalSharedField(batch_size=1),
        )
        return MultiModalKwargsItem.from_elems([mm_elem])

    @staticmethod
    def from_elems(elems: Sequence[MultiModalFieldElem]):
        """Build an item keyed by each element's `key`.

        If several elements share a key, the last one wins.
        """
        return MultiModalKwargsItem({elem.key: elem for elem in elems})

    def __init__(self, data: Mapping[str, MultiModalFieldElem] | None = None) -> None:
        # `None` replaces the previous `{}` default to avoid sharing one
        # mutable default object across calls; `UserDict` treats `None` as
        # "no initial data".
        super().__init__(data)

        modalities = {elem.modality for elem in self.values()}
        assert len(modalities) == 1, f"Found different modalities={modalities}"
        self._modality = next(iter(modalities))

    @property
    def modality(self) -> str:
        """The single modality shared by all elements of this item."""
        return self._modality

    def get_data(self) -> dict[str, NestedTensors]:
        """Map each key to the raw data of its element."""
        return {key: elem.data for key, elem in self.items()}
936
937


938
939
940
# Element type of `MultiModalKwargsItems`: either items that are always
# present, or items that may be `None`. Defaults to `MultiModalKwargsItem`.
_I = TypeVar(
    "_I",
    MultiModalKwargsItem,
    MultiModalKwargsItem | None,
    default=MultiModalKwargsItem,
)


class MultiModalKwargsItems(UserDict[str, Sequence[_I]]):
    """
    A dictionary of
    [`MultiModalKwargsItem`][vllm.multimodal.inputs.MultiModalKwargsItem]s
    by modality.

    When parameterized with `MultiModalKwargsItem | None`, individual items
    may be `None`; `require_data` checks that every item is populated.
    """

    @staticmethod
    def from_hf_inputs(
        hf_inputs: "BatchFeature",
        config_by_key: Mapping[str, MultiModalFieldConfig],
    ):
        """
        Split batched HuggingFace processor outputs into per-item
        `MultiModalKwargsItem`s, grouped by modality.

        Args:
            hf_inputs: The processor outputs; only keys present in
                `config_by_key` are read.
            config_by_key: For each key, how its value is split into
                elements and which modality it belongs to.

        Raises:
            ValueError: If keys of the same modality produce different
                numbers of elements.
        """
        # NOTE: This skips fields in `hf_inputs` that are not in `config_by_key`
        # We assume that those fields are not used in vLLM
        elems_by_key = dict[str, Sequence[MultiModalFieldElem]]()
        # Lists (not sets) keep keys in `config_by_key` order: string set
        # iteration order depends on the per-process hash seed, which made the
        # key order of each built item non-deterministic. A key appears at
        # most once since it comes from a mapping.
        keys_by_modality = defaultdict[str, list[str]](list)
        for key, config in config_by_key.items():
            batch = hf_inputs.get(key)
            if batch is not None:
                elems = config.build_elems(key, batch)
                if len(elems) > 0:
                    elems_by_key[key] = elems
                    keys_by_modality[config.modality].append(key)

        items = list[MultiModalKwargsItem]()
        for modality, keys in keys_by_modality.items():
            elems_in_modality = {k: elems_by_key[k] for k in keys}
            batch_sizes = {k: len(v) for k, v in elems_in_modality.items()}

            if len(set(batch_sizes.values())) > 1:
                raise ValueError(
                    f"Cannot merge different batch sizes for {modality=}! "
                    f"Found: {batch_sizes=}"
                )

            # Every modality listed here has at least one key, so this exists.
            batch_size = next(iter(batch_sizes.values()))
            for item_idx in range(batch_size):
                elems = [v[item_idx] for v in elems_in_modality.values()]
                items.append(MultiModalKwargsItem.from_elems(elems))

        return MultiModalKwargsItems.from_seq(items)

    @staticmethod
    def from_seq(items: Sequence[MultiModalKwargsItem]):
        """Group `items` by their `modality`."""
        items_by_modality = full_groupby(items, key=lambda x: x.modality)
        return MultiModalKwargsItems(items_by_modality)

    def __getitem__(self, modality: str) -> Sequence[_I]:
        """
        Return the items of `modality`.

        Raises:
            KeyError: If `modality` is absent; the message lists the
                available modalities.
        """
        if modality not in self:
            raise KeyError(
                f"Modality {modality!r} not found. "
                f"Available modalities: {set(self.keys())}"
            )

        return super().__getitem__(modality)  # type: ignore[return-value]

    def require_data(self) -> "MultiModalKwargsItems[MultiModalKwargsItem]":
        """
        Check that no item is `None` and return `self` with the narrowed type.

        Raises:
            RuntimeError: If any item is `None`.
        """
        for modality, items in self.items():
            for i, item in enumerate(items):
                if item is None:
                    raise RuntimeError(f"Found empty mm_items[{modality}][{i}]")

        return self  # type: ignore[return-value]

    def get_data(
        self,
        *,
        # Quoted, like the other `torch` annotations in this module, so that
        # defining the class does not access the lazily loaded `torch`.
        device: "torch.types.Device" = None,
        pin_memory: bool = False,
    ) -> BatchedTensorInputs:
        """
        Construct a dictionary of keyword arguments to pass to the model.

        Elements of all items (across every modality) are grouped by key, and
        each group is combined by the `reduce_data` method of the field of
        its first element.

        Args:
            device: Forwarded to `reduce_data`.
            pin_memory: Forwarded to `reduce_data`.

        Raises:
            RuntimeError: If any item is `None`.
        """
        elems_by_key = defaultdict[str, list[MultiModalFieldElem]](list)
        for modality, items in self.items():
            for i, item in enumerate(items):
                if item is None:
                    raise RuntimeError(
                        f"Cannot build data from empty mm_items[{modality}][{i}]"
                    )

                for key, elem in item.items():
                    elems_by_key[key].append(elem)

        data = {
            key: elems[0].field.reduce_data(
                elems,
                device=device,
                pin_memory=pin_memory,
            )
            for key, elems in elems_by_key.items()
        }

        return data
1038

1039

1040
1041
1042
1043
# Either a fully populated `MultiModalKwargsItems`, or one whose items may be
# `None`; `MultiModalKwargsItems.require_data` narrows to the former.
MultiModalKwargsOptionalItems: TypeAlias = (
    MultiModalKwargsItems[MultiModalKwargsItem]
    | MultiModalKwargsItems[MultiModalKwargsItem | None]
)
1044
1045


1046
1047
1048
1049
1050
1051
MultiModalHashes = dict[str, list[str]]
"""
A dictionary containing per-item hashes for each modality.
"""


1052
MultiModalPlaceholderDict: TypeAlias = Mapping[str, Sequence[PlaceholderRange]]
"""
A dictionary containing per-item placeholder ranges for each modality.
"""


1058
class MultiModalInputs(TypedDict):
    """
    Represents the outputs of
    [`BaseMultiModalProcessor`][vllm.multimodal.processing.BaseMultiModalProcessor],
    ready to be passed to vLLM internals.
    """

    type: Literal["multimodal"]
    """The type of inputs."""

    prompt_token_ids: list[int]
    """The processed token IDs, which include placeholder tokens."""

    mm_kwargs: MultiModalKwargsOptionalItems
    """Keyword arguments to be directly passed to the model after batching."""

    mm_hashes: MultiModalHashes
    """The hashes of the multi-modal data."""

    mm_placeholders: MultiModalPlaceholderDict
    """
    For each modality, information about the placeholder tokens in
    `prompt_token_ids`.
    """

    cache_salt: NotRequired[str]
    """
    Optional cache salt to be used for prefix caching.
    """

1088
1089
1090

class MultiModalEncDecInputs(MultiModalInputs):
    """
    Represents the outputs of
    [`EncDecMultiModalProcessor`][vllm.multimodal.processing.EncDecMultiModalProcessor]
    ready to be passed to vLLM internals.
    """

    encoder_prompt_token_ids: list[int]
    """The processed token IDs of the encoder prompt."""