inputs.py 33.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from abc import ABC, abstractmethod
5
from collections import UserDict, defaultdict
6
from collections.abc import Mapping, Sequence
7
from dataclasses import dataclass
8
from functools import cached_property, partial
9
from itertools import accumulate
10
11
12
13
14
15
16
17
18
19
from typing import (
    TYPE_CHECKING,
    Any,
    Literal,
    TypeAlias,
    TypedDict,
    Union,
    cast,
    final,
)
20
21

import numpy as np
Roger Wang's avatar
Roger Wang committed
22
from PIL.Image import Image
23
from typing_extensions import TypeVar
24

25
from vllm.utils.collection_utils import is_list_of
26
from vllm.utils.import_utils import LazyLoader
27
from vllm.utils.jsontree import json_map_leaves
28

29
30
from .media import MediaWithBytes

31
if TYPE_CHECKING:
32
33
34
    import torch
    import torch.types
    from transformers.feature_extraction_utils import BatchFeature
35
36

    from vllm.inputs.data import _InputOptions
37
38
else:
    torch = LazyLoader("torch", globals(), "torch")
39

40
41
    _InputOptions = dict

42
43
_T = TypeVar("_T")

HfImageItem: TypeAlias = Union["Image", np.ndarray, "torch.Tensor"]
"""
A single image item, in any form accepted by a HuggingFace `ImageProcessor`
(see `transformers.image_utils.ImageInput`).
"""

HfVideoItem: TypeAlias = Union[
    list["Image"], np.ndarray, "torch.Tensor", list[np.ndarray], list["torch.Tensor"]
]
"""
A single video item, in any form accepted by a HuggingFace `VideoProcessor`
(see `transformers.image_utils.VideoInput`).
"""

HfAudioItem: TypeAlias = Union[list[float], np.ndarray, "torch.Tensor"]
"""
A single audio item, in any form accepted by a HuggingFace `AudioProcessor`.
"""

ImageItem: TypeAlias = Union[HfImageItem, "torch.Tensor", MediaWithBytes[HfImageItem]]
"""
A single image item as accepted by vLLM.

Normally this is a `transformers.image_utils.ImageInput` (possibly wrapped in
`MediaWithBytes`) that is handed to a HuggingFace `ImageProcessor`.

It may instead be a 3-D tensor or a batch of 2-D tensors; these are treated as
image embeddings and passed straight to the model, skipping HF processing.
"""

VideoItem: TypeAlias = Union[
    HfVideoItem, "torch.Tensor", tuple[HfVideoItem, dict[str, Any]]
]
"""
A single video item as accepted by vLLM.

Normally this is a `transformers.video_utils.VideoInput`, optionally paired with
a metadata dict, which is handed to a HuggingFace `VideoProcessor` together with
`transformers.video_utils.VideoMetadata`.

It may instead be a 3-D tensor or a batch of 2-D tensors; these are treated as
video embeddings and passed straight to the model, skipping HF processing.
"""

AudioItem: TypeAlias = Union[HfAudioItem, tuple[np.ndarray, float], "torch.Tensor"]
"""
A single audio item as accepted by vLLM.

Normally this is audio data that is handed to a HuggingFace `AudioProcessor`.

It may instead be an `(audio, sampling_rate)` tuple whose sampling rate differs
from the one the model expects; such audio is resampled to the model's sampling
rate before HF processing.

It may also be a 3-D tensor or a batch of 2-D tensors; these are treated as
audio embeddings and passed straight to the model, skipping HF processing.
"""

ModalityData: TypeAlias = _T | list[_T | None] | None
"""
Either one data item or a list of data items. A value may only be `None`
when a UUID is supplied for it.

How many items each modality may contain is capped by
`--limit-mm-per-prompt`.
"""


class VisionChunkImage(TypedDict):
    """A single image, wrapped so it can travel as a vision chunk."""

    # Discriminator tag for this member of `VisionChunk`.
    type: Literal["image"]
    # The image itself.
    image: Image
    # Optional identifier for this chunk.
    uuid: str | None


class VisionChunkVideo(TypedDict):
    """A contiguous chunk of video frames, together with its metadata."""

    # Discriminator tag for this member of `VisionChunk`.
    type: Literal["video_chunk"]
    # The frames that make up this chunk.
    video_chunk: list[Image]
    # Optional identifier for this chunk.
    uuid: str | None
    # Prompt text associated with this chunk (how it is used is up to the
    # consumer; NOTE(review): confirm against callers).
    prompt: str
    # Presumably the index of the source video this chunk was cut from; verify.
    video_idx: int


VisionChunk = VisionChunkImage | VisionChunkVideo
"""A vision chunk: either a single image or a chunk of video frames."""


@final
class MultiModalDataBuiltins(TypedDict, total=False):
    """Type annotations for the modalities that vLLM predefines."""

    image: ModalityData[ImageItem]
    """Image input(s)."""

    video: ModalityData[VideoItem]
    """Video input(s)."""

    audio: ModalityData[AudioItem]
    """Audio input(s)."""

    vision_chunk: ModalityData[VisionChunk]
    """Visual atom input(s): a unified modality covering images and video chunks."""


MultiModalDataDict: TypeAlias = Mapping[str, ModalityData[Any]]
"""
A mapping from each input modality to the data provided for it.

The modalities that vLLM predefines are listed in
[`MultiModalDataBuiltins`][vllm.multimodal.inputs.MultiModalDataBuiltins].
"""

MultiModalUUIDDict: TypeAlias = Mapping[str, list[str | None] | str]
"""
A mapping from each modality to the user-provided UUIDs of its items.
An item without a UUID has a `None` entry, and MultiModalHasher computes
a hash for it instead.

The UUID identifies the item for every kind of caching (input processing
caching, embedding caching, prefix caching, etc).
"""


@dataclass(frozen=True)
class PlaceholderRange:
171
172
173
    """
    Placeholder location information for multi-modal data.

174
175
    Example:

176
    Prompt: `AAAA BBBB What is in these images?`
177

178
    Images A and B will have:
179

180
181
182
183
    ```
    A: PlaceholderRange(offset=0, length=4)
    B: PlaceholderRange(offset=5, length=4)
    ```
184
185
186
187
188
189
190
191
    """

    offset: int
    """The start index of the placeholder in the prompt."""

    length: int
    """The length of the placeholder."""

192
    is_embed: "torch.Tensor | None" = None
193
194
195
196
197
    """
    A boolean mask of shape `(length,)` indicating which positions
    between `offset` and `offset + length` to assign embeddings to.
    """

198
199
    @cached_property
    def embeds_cumsum(self) -> torch.Tensor | None:
200
        return None if self.is_embed is None else self.is_embed.cumsum(dim=0)
201
202
203

    def get_num_embeds(self) -> int:
        if self.embeds_cumsum is None:
204
205
            return self.length

206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
        return int(self.embeds_cumsum[-1])

    def get_embeds_indices_in_range(
        self, start_idx: int, end_idx: int
    ) -> tuple[int, int]:
        """
        Returns the starting and ending indices of the embeddings of encoder outputs
        in the range of [start_idx, end_idx) in the placeholders.

        For example, given:
        PlaceholderRange(offset=2, length=5, is_embed=[False, True, False, True, True])

        If start_idx=3 and end_idx=5, the output is (1, 3) because we want to get
        the second and the third embeddings from the encoder output.
        """
        if self.embeds_cumsum is None:
            return start_idx, end_idx

        embeds_start_idx = (
            int(self.embeds_cumsum[start_idx - 1]) if start_idx > 0 else 0
        )
        embeds_end_idx = int(self.embeds_cumsum[end_idx - 1])

        return embeds_start_idx, embeds_end_idx
230

231
232
233
234
235
236
237
238
239
240
241
242
243
    def extract_embeds_range(self) -> list[tuple[int, int]]:
        """Extract the start and end indices of the embedded region in prompt.

        For example, given `PlaceholderRange(offset=2, length=5)` and
        `is_embed = [False, True, False, True, True]`, the output is
        `[(1 + offset, 1 + offset), (3 + offset, 4 + offset)]`.

        Returns:
            A tuple `(start, end)` representing the start and end
            indices (inclusive) of the embedded region.
            Returns full placeholder range if `is_embed` is `None`.
        """
        if self.is_embed is None:
244
            return [(self.offset, self.offset + self.length - 1)]
245
246
247
248
249
250
251
252
253
254
255

        mask_i = self.is_embed.int()
        starts = torch.nonzero(
            torch.diff(mask_i, prepend=mask_i.new_zeros(1)) == 1
        ).flatten()
        ends = torch.nonzero(
            torch.diff(mask_i, append=mask_i.new_zeros(1)) == -1
        ).flatten()
        ranges = torch.stack((starts, ends), dim=1) + self.offset
        return [tuple(x) for x in ranges.tolist()]

256
257
258
259
260
261
262
263
264
265
266
267
268
    def __eq__(self, other: object) -> bool:
        if not isinstance(other, self.__class__):
            return False
        if not (self.offset, self.length) == (other.offset, other.length):
            return False

        if self.is_embed is None:
            return other.is_embed is None
        if other.is_embed is None:
            return self.is_embed is None

        return nested_tensors_equal(self.is_embed, other.is_embed)

269

270
271
272
273
274
275
NestedTensors: TypeAlias = Union[
    list["NestedTensors"],
    list["torch.Tensor"],
    "torch.Tensor",
    tuple["torch.Tensor", ...],
]
276
277
278
279
"""
Uses a list instead of a tensor if the dimensions of each element do not match.
"""

280
281

def nested_tensors_equal(a: NestedTensors, b: NestedTensors) -> bool:
282
283
284
285
    """
    Equality check between
    [`NestedTensors`][vllm.multimodal.inputs.NestedTensors] objects.
    """
286
    if isinstance(a, torch.Tensor):
287
        return isinstance(b, torch.Tensor) and torch.equal(a, b)
288
    elif isinstance(b, torch.Tensor):
289
        return isinstance(a, torch.Tensor) and torch.equal(b, a)
290
291

    if isinstance(a, list):
292
293
294
        return isinstance(b, list) and all(
            nested_tensors_equal(a_, b_) for a_, b_ in zip(a, b)
        )
295
    if isinstance(b, list):
296
297
298
        return isinstance(a, list) and all(
            nested_tensors_equal(b_, a_) for b_, a_ in zip(b, a)
        )
299
300
301
302
303

    # Both a and b are scalars
    return a == b


304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
def _nested_tensors_h2d(
    tensors: NestedTensors,
    device: torch.types.Device,
) -> NestedTensors:
    """
    Copy the tensor leaves of `tensors` to `device` with `non_blocking=True`.

    Leaves are visited via `json_map_leaves`; non-tensor leaves are kept as-is.
    When `device` is `None`, `tensors` is returned untouched.
    """
    if device is not None:

        def _move_leaf(leaf):
            if not isinstance(leaf, torch.Tensor):
                return leaf
            return leaf.to(device=device, non_blocking=True)

        return json_map_leaves(_move_leaf, tensors)

    return tensors


BatchedTensorInputs: TypeAlias = dict[str, NestedTensors]
"""
A dictionary containing nested tensors which have been batched via
[`MultiModalKwargsItems.get_data`][vllm.multimodal.inputs.MultiModalKwargsItems.get_data].
"""


def batched_tensors_equal(a: BatchedTensorInputs, b: BatchedTensorInputs) -> bool:
    """
    Equality check between
    [`BatchedTensorInputs`][vllm.multimodal.inputs.BatchedTensorInputs] objects.

    The two inputs are equal when they have exactly the same keys and the
    values under each key are equal according to `nested_tensors_equal`.
    """
    # Compare the key sets first: checking only `k in b` for each `k in a`
    # would ignore extra keys in `b` and make the comparison asymmetric.
    if a.keys() != b.keys():
        return False

    return all(nested_tensors_equal(a[k], b[k]) for k in a)


@dataclass
class MultiModalFeatureSpec:
    """
    A single multimodal input, together with its processed data and metadata.

    Used to follow multimodal data through processing and caching. A request
    with several multimodal items carries one `MultiModalFeatureSpec` per item.
    """

    data: "MultiModalKwargsItem | None"
    """
    The processed multimodal data of this feature.

    May be `None` when the item is cached, which avoids sending it over IPC
    between the API server and engine core processes.
    """

    modality: str
    """The input modality, e.g., `"image"`, `"audio"`, `"video"`."""

    identifier: str
    """The hash for caching encoder outputs (with LoRA prefix if applicable)."""

    mm_position: PlaceholderRange
    """
    Where this item's `modality` tokens are located in the prompt,
    e.g., `PlaceholderRange(offset=2, length=336)`.
    """

    mm_hash: str | None = None
    """The hash for caching processor outputs (without LoRA prefix)."""

    @staticmethod
    def gather_kwargs(features: list["MultiModalFeatureSpec"], keys: set[str]):
        """
        Collect the data stored under each of `keys` across `features`.

        Features whose `data` is `None` are skipped, as are keys that an item
        does not contain.

        Returns:
            A dict mapping each key found in at least one item to the list of
            its data, in the order of `features`.
        """
        gathered = defaultdict[str, list[NestedTensors]](list)

        present_items = (feat.data for feat in features if feat.data is not None)
        for mm_item in present_items:
            for key in keys:
                if key not in mm_item:
                    continue
                gathered[key].append(mm_item[key].data)

        return dict(gathered)


@dataclass
class MultiModalFieldElem:
    """
    One processed keyword argument to pass to a model, belonging to a
    [`MultiModalKwargsItem`][vllm.multimodal.inputs.MultiModalKwargsItem].
    """

    data: NestedTensors
    """
    The tensor data held by this field of the
    [`MultiModalKwargsItem`][vllm.multimodal.inputs.MultiModalKwargsItem],
    i.e. the value of the keyword argument that is passed to the model.

    It may be `None` when the item has been determined to be cached
    in `EngineCore`.
    """

    field: "BaseMultiModalField"
    """
    Describes how this field's tensor data is combined with other items' data
    when multi-modal items are batched together for model inference.
    """

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, self.__class__):
            return False

        # Missing data only matches missing data; otherwise compare contents.
        if self.data is None or other.data is None:
            data_equal = self.data is None and other.data is None
        else:
            data_equal = nested_tensors_equal(self.data, other.data)

        return data_equal and type(self.field) is type(other.field)  # noqa: E721


@dataclass(frozen=True, kw_only=True)
class BaseMultiModalField(ABC):
    """
    Describes how the tensor data of a keyword argument is split into, and
    merged back from, the items of
    [`MultiModalKwargsItems`][vllm.multimodal.inputs.MultiModalKwargsItems].
    """

    keep_on_cpu: bool = False
    """
    When `True`, this field is not moved to the accelerator by
    [`group_and_batch_mm_items`][vllm.multimodal.utils.group_and_batch_mm_items]
    while it batches the data.
    """

    def _field_factory(self):
        make_elem = partial(MultiModalFieldElem, field=self)

        # Thin wrapper so that callers can pass `data` positionally.
        def factory(data: NestedTensors) -> MultiModalFieldElem:
            return make_elem(data=data)

        return factory

    @abstractmethod
    def build_elems(
        self,
        modality: str,
        key: str,
        data: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        """
        Create the
        [`MultiModalFieldElem`][vllm.multimodal.inputs.MultiModalFieldElem]
        instances that represent `data`.

        This undoes
        [`reduce_data`][vllm.multimodal.inputs.BaseMultiModalField.reduce_data].
        """
        raise NotImplementedError

    @abstractmethod
    def _reduce_data(
        self,
        batch: list[NestedTensors],
        *,
        pin_memory: bool,
    ) -> NestedTensors:
        raise NotImplementedError

    def reduce_data(
        self,
        elems: list[MultiModalFieldElem],
        *,
        device: torch.types.Device = None,
        pin_memory: bool = False,
    ) -> NestedTensors:
        """
        Combine the data of several
        [`MultiModalFieldElem`][vllm.multimodal.inputs.MultiModalFieldElem]
        instances into one value.

        This undoes
        [`build_elems`][vllm.multimodal.inputs.BaseMultiModalField.build_elems].

        Raises:
            ValueError: If the elements belong to fields of different types.
        """
        field_types = [type(item.field) for item in elems]
        if len(set(field_types)) > 1:
            raise ValueError(f"Cannot merge different {field_types=}")

        if self.keep_on_cpu:
            # CPU-only fields are never moved to the accelerator or pinned.
            if device is not None:
                device = "cpu"
            pin_memory = False

        reduced = self._reduce_data(
            [elem.data for elem in elems],
            pin_memory=pin_memory,
        )
        return _nested_tensors_h2d(reduced, device=device)


@dataclass(frozen=True, kw_only=True)
class MultiModalBatchedField(BaseMultiModalField):
    """
    A field whose per-item elements are the entries along the first dimension
    of the underlying data.

    Info:
        [`MultiModalFieldConfig.batched`][vllm.multimodal.inputs.MultiModalFieldConfig.batched]
    """

    def build_elems(
        self,
        modality: str,
        key: str,
        data: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        make_elem = self._field_factory()
        return [make_elem(entry) for entry in data]

    def _reduce_data(
        self,
        batch: list[NestedTensors],
        *,
        pin_memory: bool,
    ) -> NestedTensors:
        all_tensors = len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all")
        if not all_tensors:
            return batch

        tensors = cast(list[torch.Tensor], batch)
        if len(tensors) == 1:
            # Same result as `torch.stack(tensors)`, but zero-copy when the
            # single tensor is already contiguous.
            return tensors[0].unsqueeze(0).contiguous()

        ref = tensors[0]
        if any(t.shape != ref.shape for t in tensors):
            # Shapes differ, so the tensors cannot be stacked; keep the list.
            return batch

        out = torch.empty(
            (len(tensors), *ref.shape),
            dtype=ref.dtype,
            device=ref.device,
            pin_memory=pin_memory,
        )
        return torch.stack(tensors, out=out)


@dataclass(frozen=True, kw_only=True)
class MultiModalFlatField(BaseMultiModalField):
    """
    A field whose per-item elements are slices of one concatenated piece of
    data.

    Info:
        [`MultiModalFieldConfig.flat`][vllm.multimodal.inputs.MultiModalFieldConfig.flat]
        [`MultiModalFieldConfig.flat_from_sizes`][vllm.multimodal.inputs.MultiModalFieldConfig.flat_from_sizes]
    """

    slices: Sequence[slice] | Sequence[Sequence[slice]]
    dim: int = 0

    def build_elems(
        self,
        modality: str,
        key: str,
        data: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        make_elem = self._field_factory()

        # Tuple-of-slices indexing is only supported on tensors.
        if not is_list_of(self.slices, slice, check="all"):
            assert isinstance(data, torch.Tensor), (
                "torch.Tensor is required for multiple slices"
            )

        return [
            make_elem(data[cast(slice, item_slice)]) for item_slice in self.slices
        ]

    def _reduce_data(
        self,
        batch: list[NestedTensors],
        *,
        pin_memory: bool,
    ) -> NestedTensors:
        if not (len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all")):
            assert self.dim == 0, "dim == 0 is required for nested list"
            return [e for elem in batch for e in elem]

        tensors = cast(list[torch.Tensor], batch)
        if len(tensors) == 1:
            # Same result as `torch.concat(tensors)`, but zero-copy when the
            # single tensor is already contiguous.
            return tensors[0].contiguous()

        head = tensors[0]
        cat_dim = self.dim if self.dim >= 0 else self.dim + len(head.shape)

        def _outer_dims(t: torch.Tensor):
            # The shape on either side of the concatenation dimension.
            return t.shape[:cat_dim], t.shape[cat_dim + 1 :]

        ref_outer = _outer_dims(head)
        if all(_outer_dims(t) == ref_outer for t in tensors):
            before, after = ref_outer
            total_len = sum(t.shape[cat_dim] for t in tensors)
            out = torch.empty(
                (*before, total_len, *after),
                dtype=head.dtype,
                device=head.device,
                pin_memory=pin_memory,
            )
            return torch.concat(tensors, dim=self.dim, out=out)

        # Variable-length case: non-concat dimensions differ
        # (e.g., Ultravox with different audio durations).
        # Use slice-assign approach (more efficient than padding).
        # See: https://github.com/vllm-project/vllm/issues/31658
        ndim = head.ndim

        # The output is as large as the largest input on every non-concat
        # dimension and as long as all inputs combined on the concat dimension.
        padded_shape = [
            sum(t.shape[d] for t in tensors)
            if d == cat_dim
            else max(t.shape[d] for t in tensors)
            for d in range(ndim)
        ]

        # Zero-filled, so regions that no input covers stay zero.
        out = torch.zeros(
            padded_shape,
            dtype=head.dtype,
            device=head.device,
            pin_memory=pin_memory,
        )

        # Copy each input into its own window along the concat dimension.
        cursor = 0
        for t in tensors:
            region = tuple(
                slice(cursor, cursor + t.shape[d]) if d == cat_dim else slice(0, t.shape[d])
                for d in range(ndim)
            )
            out[region] = t
            cursor += t.shape[cat_dim]

        return out


@dataclass(frozen=True, kw_only=True)
class MultiModalSharedField(BaseMultiModalField):
    """
    A field whose underlying data is shared in full by every item.

    Info:
        [`MultiModalFieldConfig.shared`][vllm.multimodal.inputs.MultiModalFieldConfig.shared]
    """

    batch_size: int

    def build_elems(
        self,
        modality: str,
        key: str,
        data: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        shared_elem = self._field_factory()(data)
        # The very same element object is reused for every item.
        return [shared_elem for _ in range(self.batch_size)]

    def _reduce_data(
        self,
        batch: list[NestedTensors],
        *,
        pin_memory: bool,
    ) -> NestedTensors:
        # Every element holds the same data, so the first one represents all.
        return batch[0]


@dataclass(frozen=True)
class MultiModalFieldConfig:
    """
    Pairs a [`BaseMultiModalField`][vllm.multimodal.inputs.BaseMultiModalField]
    with the modality of the items that use it.

    Instances are normally created through the static factory methods below.
    """

    @staticmethod
    def batched(modality: str, *, keep_on_cpu: bool = False):
        """
        Defines a field where an element in the batch is obtained by
        indexing into the first dimension of the underlying data.

        Args:
            modality: The modality of the multi-modal item that uses this
                keyword argument.
            keep_on_cpu: Whether to keep this field on the CPU for the model inputs.

        Example:

        ```
        Input:
            Data: [[AAAA]
                [BBBB]
                [CCCC]]

        Output:
            Element 1: [AAAA]
            Element 2: [BBBB]
            Element 3: [CCCC]
        ```
        """
        return MultiModalFieldConfig(
            field=MultiModalBatchedField(keep_on_cpu=keep_on_cpu),
            modality=modality,
        )

    @staticmethod
    def flat(
        modality: str,
        slices: Sequence[slice] | Sequence[Sequence[slice]],
        dim: int = 0,
        *,
        keep_on_cpu: bool = False,
    ):
        """
        Defines a field where an element in the batch is obtained by
        slicing along dimension `dim` (the first dimension by default)
        of the underlying data.

        Args:
            modality: The modality of the multi-modal item that uses this
                keyword argument.
            slices: For each multi-modal item, a slice (dim=0) or a tuple of
                slices (dim>0) that is used to extract the data corresponding
                to it.
            dim: The dimension to extract data, default to 0.
            keep_on_cpu: Whether to keep this field on the CPU for the model inputs.

        Example:

        ```
        Given:
            slices: [slice(0, 3), slice(3, 7), slice(7, 9)]

        Input:
            Data: [AAABBBBCC]

        Output:
            Element 1: [AAA]
            Element 2: [BBBB]
            Element 3: [CC]
        ```

        ```
        Given:
            slices: [
                (slice(None), slice(0, 3)),
                (slice(None), slice(3, 7)),
                (slice(None), slice(7, 9))]
            dim: 1

        Input:
            Data: [[A],[A],[A],[B],[B],[B],[B],[C],[C]]

        Output:
            Element 1: [[A],[A],[A]]
            Element 2: [[B],[B],[B],[B]]
            Element 3: [[C],[C]]
        ```
        """
        return MultiModalFieldConfig(
            field=MultiModalFlatField(
                slices=slices,
                dim=dim,
                keep_on_cpu=keep_on_cpu,
            ),
            modality=modality,
        )

    @staticmethod
    def flat_from_sizes(
        modality: str,
        size_per_item: "torch.Tensor",
        dim: int = 0,
        *,
        keep_on_cpu: bool = False,
    ):
        """
        Defines a field where an element in the batch is obtained by
        slicing along dimension `dim` (the first dimension by default)
        of the underlying data.

        Args:
            modality: The modality of the multi-modal item that uses this
                keyword argument.
            size_per_item: For each multi-modal item, the size of the slice
                that is used to extract the data corresponding to it.
            dim: The dimension to slice, default to 0. Negative values count
                from the last dimension.
            keep_on_cpu: Whether to keep this field on the CPU for the model inputs.

        Raises:
            ValueError: If `size_per_item` is not a 1-D tensor.

        Example:

        ```
        Given:
            size_per_item: [3, 4, 2]

        Input:
            Data: [AAABBBBCC]

        Output:
            Element 1: [AAA]
            Element 2: [BBBB]
            Element 3: [CC]
        ```

        ```
        Given:
            size_per_item: [3, 4, 2]
            dim: 1

        Input:
            Data: [[A],[A],[A],[B],[B],[B],[B],[C],[C]]

        Output:
            Element 1: [[A],[A],[A]]
            Element 2: [[B],[B],[B],[B]]
            Element 3: [[C],[C]]
        ```

        Info:
            [`MultiModalFieldConfig.flat`][vllm.multimodal.inputs.MultiModalFieldConfig.flat]
        """

        if size_per_item.ndim != 1:
            raise ValueError(
                "size_per_item should be a 1-D tensor, "
                f"but found shape: {size_per_item.shape}"
            )

        slice_idxs = [0, *accumulate(size_per_item)]

        # Index parts placed before/after each item's slice so that the slice
        # lands on dimension `dim`.
        leading: tuple[Any, ...]
        trailing: tuple[Any, ...]
        if dim >= 0:
            # `[:, ..., :, start:end]` with `dim` leading full slices.
            leading = (slice(None, None, None),) * dim
            trailing = ()
        else:
            # `[..., start:end, :, ..., :]` with `-dim - 1` trailing full slices.
            # Note: `(slice(None),) * dim` would be empty for dim < 0 and
            # silently slice dimension 0 instead of the requested one.
            leading = (Ellipsis,)
            trailing = (slice(None, None, None),) * (-dim - 1)

        slices = [
            leading + (slice(slice_idxs[i], slice_idxs[i + 1]),) + trailing
            for i in range(len(size_per_item))
        ]

        return MultiModalFieldConfig.flat(
            modality,
            slices,
            dim=dim,
            keep_on_cpu=keep_on_cpu,
        )

    @staticmethod
    def shared(
        modality: str,
        batch_size: int,
        *,
        keep_on_cpu: bool = False,
    ):
        """
        Defines a field where an element in the batch is obtained by
        taking the entirety of the underlying data.

        This means that the data is the same for each element in the batch.

        Args:
            modality: The modality of the multi-modal item that uses this
                keyword argument.
            batch_size: The number of multi-modal items which share this data.
            keep_on_cpu: Whether to keep this field on the CPU for the model inputs.

        Example:

        ```
        Given:
            batch_size: 4

        Input:
            Data: [XYZ]

        Output:
            Element 1: [XYZ]
            Element 2: [XYZ]
            Element 3: [XYZ]
            Element 4: [XYZ]
        ```
        """
        return MultiModalFieldConfig(
            field=MultiModalSharedField(
                batch_size=batch_size,
                keep_on_cpu=keep_on_cpu,
            ),
            modality=modality,
        )

    field: BaseMultiModalField
    modality: str

    def build_elems(
        self,
        key: str,
        batch: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        """Split `batch` into per-item elements using this config's field."""
        return self.field.build_elems(self.modality, key, batch)


class MultiModalKwargsItem(UserDict[str, MultiModalFieldElem]):
    """
    Maps keyword-argument names to the processed values to pass to the model
    for one item of
    [`MultiModalDataItems`][vllm.multimodal.parse.MultiModalDataItems].
    """

    @staticmethod
    def dummy(nbytes: int = 1):
        """Build an item with one `nbytes`-byte placeholder field (for testing)."""
        placeholder = torch.empty(nbytes, dtype=torch.uint8)
        shared_field = MultiModalSharedField(batch_size=1)
        return MultiModalKwargsItem(
            {"dummy": MultiModalFieldElem(data=placeholder, field=shared_field)}
        )

    def get_data(self) -> dict[str, NestedTensors]:
        """Return the raw data of every keyword argument, keyed by name."""
        return {name: field_elem.data for name, field_elem in self.items()}


# Item type parameter: constrained to either `MultiModalKwargsItem` or its
# optional form (allowing `None` entries); defaults to `MultiModalKwargsItem`.
_I = TypeVar(
    "_I",
    MultiModalKwargsItem,
    MultiModalKwargsItem | None,
    default=MultiModalKwargsItem,
)


class MultiModalKwargsItems(UserDict[str, Sequence[_I]]):
    """
    A dictionary of processed multi-modal inputs by modality.

    For example, given a processor that processes
    images into `pixel_values` and `image_grid_thw`,
    and audios into `input_audio_features`,
    a prompt with 2 images and 1 audio will be processed
    into a `MultiModalKwargsItems` with the following structure:

    ```python
    MultiModalKwargsItems(
        {
            "image": [
                # For the first image
                MultiModalKwargsItem({"pixel_values": ..., "image_grid_thw": ...}),
                # For the second image
                MultiModalKwargsItem({"pixel_values": ..., "image_grid_thw": ...}),
            ],
            "audio": [
                # For the first audio
                MultiModalKwargsItem({"input_audio_features": ...}),
            ],
        }
    )
    ```

    Unlike HF processing which returns all items
    in a single dictionary with batched keyword arguments,
    we split up the items because some of them may already be cached.
    Also, items from multiple requests may be batched together to improve throughput,
    using the logic defined by the
    [`BaseMultiModalField`][vllm.multimodal.inputs.BaseMultiModalField]
    for each keyword argument.
    """

    @staticmethod
    def from_hf_inputs(
        hf_inputs: "BatchFeature",
        config_by_key: Mapping[str, MultiModalFieldConfig],
    ) -> "MultiModalKwargsItems[MultiModalKwargsItem]":
        """
        Split the batched outputs of a HF processor into per-item kwargs.

        Args:
            hf_inputs: The outputs of the HF processor.
            config_by_key: For each keyword argument to keep, the config that
                determines its modality and how to build its elements.

        Returns:
            A mapping from each modality to one `MultiModalKwargsItem` per
            item. Modalities whose keys are absent or yield no elements are
            omitted.

        Raises:
            ValueError: If the keyword arguments of one modality yield
                different numbers of elements.
        """
        # NOTE: This skips fields in `hf_inputs` that are not in `config_by_key`
        # We assume that those fields are not used in vLLM
        elems_by_key = dict[str, Sequence[MultiModalFieldElem]]()
        # Use lists rather than sets so that the key order inside each item
        # (and in the error message below) follows `config_by_key`, instead of
        # depending on per-process string hash randomization. Mapping keys are
        # unique, so a key is never appended twice.
        keys_by_modality = defaultdict[str, list[str]](list)
        for key, config in config_by_key.items():
            batch = hf_inputs.get(key)
            if batch is not None:
                elems = config.build_elems(key, batch)
                if len(elems) > 0:
                    elems_by_key[key] = elems
                    keys_by_modality[config.modality].append(key)

        items_by_modality = dict[str, list[MultiModalKwargsItem]]()
        for modality, keys in keys_by_modality.items():
            elems_in_modality = {k: elems_by_key[k] for k in keys}
            batch_sizes = {k: len(v) for k, v in elems_in_modality.items()}

            if len(set(batch_sizes.values())) > 1:
                raise ValueError(
                    f"Cannot merge different batch sizes for {modality=}! "
                    f"Found: {batch_sizes=}"
                )

            # Every key of a modality has at least one element (see above),
            # so `batch_sizes` is non-empty here.
            batch_size = next(iter(batch_sizes.values()))
            items_by_modality[modality] = [
                MultiModalKwargsItem({k: v[i] for k, v in elems_in_modality.items()})
                for i in range(batch_size)
            ]

        return MultiModalKwargsItems(items_by_modality)

    def __getitem__(self, modality: str) -> Sequence[_I]:
        """
        Return the items of `modality`.

        Raises:
            KeyError: If `modality` is not present; the message lists the
                modalities that are available.
        """
        if modality not in self:
            raise KeyError(
                f"Modality {modality!r} not found. "
                f"Available modalities: {set(self.keys())}"
            )

        return super().__getitem__(modality)  # type: ignore[return-value]

    def require_data(self) -> "MultiModalKwargsItems[MultiModalKwargsItem]":
        """
        Check that no item is `None` and return `self` with the narrowed type.

        Raises:
            RuntimeError: If any item of any modality is `None`.
        """
        for modality, items in self.items():
            for i, item in enumerate(items):
                if item is None:
                    raise RuntimeError(f"Found empty mm_items[{modality}][{i}]")

        return self  # type: ignore[return-value]

    def get_data(
        self,
        *,
        device: torch.types.Device = None,
        pin_memory: bool = False,
    ) -> BatchedTensorInputs:
        """
        Construct a dictionary of keyword arguments to pass to the model.

        Args:
            device: Forwarded to `group_and_batch_mm_items`.
            pin_memory: Forwarded to `group_and_batch_mm_items`.

        Returns:
            The batched keyword arguments of all modalities merged into one
            dict. If two modalities produce the same key, the one processed
            later wins (plain `dict.update`).

        Raises:
            RuntimeError: If any item is `None`, or if the items of some
                modality cannot be merged into a single batch.
        """
        from .utils import group_and_batch_mm_items

        items_by_modality = self.require_data()
        batches_by_modality = {
            modality: [
                data
                for _, data in group_and_batch_mm_items(
                    items,
                    device=device,
                    pin_memory=pin_memory,
                )
            ]
            for modality, items in items_by_modality.items()
            if len(items) > 0
        }

        out_data: BatchedTensorInputs = {}
        for batches in batches_by_modality.values():
            if len(batches) != 1:
                num_batches_by_modality = {
                    modality: len(mm_batches)
                    for modality, mm_batches in batches_by_modality.items()
                }

                raise RuntimeError(
                    f"Some modalities cannot be merged into a single batch "
                    f"({num_batches_by_modality=})"
                )

            out_data.update(batches[0])

        return out_data
1045

1046

1047
1048
1049
1050
MultiModalKwargsOptionalItems: TypeAlias = (
    MultiModalKwargsItems[MultiModalKwargsItem]
    | MultiModalKwargsItems[MultiModalKwargsItem | None]
)
"""
A `MultiModalKwargsItems` whose per-modality entries are either all present,
or may include `None` (use `MultiModalKwargsItems.require_data` to rule that out).
"""
1051
1052


1053
1054
1055
1056
1057
1058
MultiModalHashes = dict[str, list[str]]
"""
A dictionary containing per-item hashes for each modality.
"""


1059
MultiModalPlaceholderDict: TypeAlias = Mapping[
    str,
    Sequence[PlaceholderRange],
]
"""
Maps each modality to the placeholder ranges of its items, one entry per item.
"""


1065
class MultiModalInputs(_InputOptions):
    """
    The result of running a
    [`BaseMultiModalProcessor`][vllm.multimodal.processing.BaseMultiModalProcessor],
    in a form that can be handed directly to vLLM internals.
    """

    type: Literal["multimodal"]
    """The type of inputs; always `"multimodal"` for this class."""

    prompt_token_ids: list[int]
    """The processed prompt token IDs, including the placeholder tokens."""

    mm_kwargs: MultiModalKwargsOptionalItems
    """Keyword arguments that are batched and then passed directly to the model."""

    mm_hashes: MultiModalHashes
    """The hashes of the multi-modal data."""

    mm_placeholders: MultiModalPlaceholderDict
    """
    For each modality, information about where its placeholder tokens are
    located in `prompt_token_ids`.
    """
1089
1090
1091
1092


class MultiModalEncDecInputs(MultiModalInputs):
    """
1093
1094
    Represents the outputs of
    [`EncDecMultiModalProcessor`][vllm.multimodal.processing.EncDecMultiModalProcessor]
1095
    ready to be passed to vLLM internals.
1096
1097
1098

    Note: Even text-only encoder-decoder models are currently implemented
    as multi-modal models for convenience.
1099
    (Example: https://github.com/vllm-project/bart-plugin)
1100
1101
1102
1103
    """

    encoder_prompt_token_ids: list[int]
    """The processed token IDs of the encoder prompt."""