inputs.py 33.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from abc import ABC, abstractmethod
5
from collections import UserDict, defaultdict
6
from collections.abc import Mapping, Sequence
7
from dataclasses import dataclass
8
from functools import cached_property, partial
9
from itertools import accumulate
10
11
12
13
14
15
16
17
18
19
from typing import (
    TYPE_CHECKING,
    Any,
    Literal,
    TypeAlias,
    TypedDict,
    Union,
    cast,
    final,
)
20
21

import numpy as np
Roger Wang's avatar
Roger Wang committed
22
from PIL.Image import Image
23
from typing_extensions import TypeVar
24

25
from vllm.utils.collection_utils import is_list_of
26
from vllm.utils.import_utils import LazyLoader
27
from vllm.utils.jsontree import json_map_leaves
28

29
30
from .media import MediaWithBytes

31
if TYPE_CHECKING:
32
33
34
    import torch
    import torch.types
    from transformers.feature_extraction_utils import BatchFeature
35
36

    from vllm.inputs.data import _InputOptions
37
38
else:
    torch = LazyLoader("torch", globals(), "torch")
39

40
41
    _InputOptions = dict

42
43
_T = TypeVar("_T")

44
HfImageItem: TypeAlias = Union["Image", np.ndarray, "torch.Tensor"]
45
"""
46
A `transformers.image_utils.ImageInput` representing a single image
47
item, which can be passed to a HuggingFace `ImageProcessor`.
48
49
"""

50
51
52
HfVideoItem: TypeAlias = Union[
    list["Image"], np.ndarray, "torch.Tensor", list[np.ndarray], list["torch.Tensor"]
]
53
"""
54
A `transformers.image_utils.VideoInput` representing a single video
55
item, which can be passed to a HuggingFace `VideoProcessor`.
56
57
"""

58
HfAudioItem: TypeAlias = Union[list[float], np.ndarray, "torch.Tensor"]
59
"""
60
Represents a single audio
61
item, which can be passed to a HuggingFace `AudioProcessor`.
62
63
"""

ImageItem: TypeAlias = Union[HfImageItem, "torch.Tensor", MediaWithBytes[HfImageItem]]
"""
A single image item, supplied as one of:

- an `HfImageItem`, i.e. a `transformers.image_utils.ImageInput` that a
  HuggingFace `ImageProcessor` accepts;
- a `MediaWithBytes` wrapper around such an item;
- a 3-D tensor or a batch of 2-D tensors, taken to be image embeddings that
  skip HF processing and are handed to the model directly.
"""

VideoItem: TypeAlias = Union[
    HfVideoItem, "torch.Tensor", tuple[HfVideoItem, dict[str, Any]]
]
"""
A single video item, supplied as one of:

- a `transformers.video_utils.VideoInput` that a HuggingFace
  `VideoProcessor` accepts;
- a tuple `(video, metadata)` pairing such an item with a dict of video
  metadata (cf. `transformers.video_utils.VideoMetadata`);
- a 3-D tensor or a batch of 2-D tensors, taken to be video embeddings that
  skip HF processing and are handed to the model directly.
"""

AudioItem: TypeAlias = Union[HfAudioItem, tuple[np.ndarray, float], "torch.Tensor"]
"""
A single audio item, supplied as one of:

- an item that a HuggingFace `AudioProcessor` accepts;
- a tuple `(audio, sampling_rate)` for audio whose sampling rate differs from
  the one the model expects; such audio is resampled to the model's rate
  before HF processing;
- a 3-D tensor or a batch of 2-D tensors, taken to be audio embeddings that
  skip HF processing and are handed to the model directly.
"""

ModalityData: TypeAlias = _T | list[_T | None] | None
"""
The data for one modality: a single item or a list of items.

`None` (for the whole entry or for a list element) is only allowed when a
UUID is supplied for the item instead.

`--limit-mm-per-prompt` caps how many items each modality may have.
"""


class VisionChunkImage(TypedDict):
    """A single still image, wrapped so it can travel as a vision chunk."""

    # Discriminator distinguishing this variant from `VisionChunkVideo`.
    type: Literal["image"]
    image: Image
    # Optional identifier of the item.
    uuid: str | None


class VisionChunkVideo(TypedDict):
    """A chunk of video frames, wrapped as a vision chunk with its metadata."""

    # Discriminator distinguishing this variant from `VisionChunkImage`.
    type: Literal["video_chunk"]
    # The frames belonging to this chunk.
    video_chunk: list[Image]
    # Optional identifier of the item.
    uuid: str | None
    prompt: str
    # presumably the index of the source video in the request - TODO confirm
    video_idx: int


VisionChunk = VisionChunkImage | VisionChunkVideo
"""A vision chunk: either a single image or a chunk of video frames."""


@final
class MultiModalDataBuiltins(TypedDict, total=False):
    """
    Type annotations for the modalities that vLLM defines out of the box.

    Every key is optional (`total=False`).
    """

    image: ModalityData[ImageItem]
    """One or more input images."""

    video: ModalityData[VideoItem]
    """One or more input videos."""

    audio: ModalityData[AudioItem]
    """One or more input audio clips."""

    vision_chunk: ModalityData[VisionChunk]
    """
    One or more vision chunks: a single modality that unifies images and
    video chunks.
    """


MultiModalDataDict: TypeAlias = Mapping[str, ModalityData[Any]]
"""
Mapping from modality name to the input data for that modality.

The modalities vLLM supports natively are listed in
[`MultiModalDataBuiltins`][vllm.multimodal.inputs.MultiModalDataBuiltins].
"""

MultiModalUUIDDict: TypeAlias = Mapping[str, list[str | None] | str]
"""
Mapping from modality name to user-provided UUIDs for that modality's items.

An item whose entry is `None` has no UUID, and MultiModalHasher computes a
hash for it instead.

The resulting identifier keys the item in every cache layer (input
processing cache, embedding cache, prefix cache, etc.).
"""


@dataclass(frozen=True)
class PlaceholderRange:
171
172
173
    """
    Placeholder location information for multi-modal data.

174
175
    Example:

176
    Prompt: `AAAA BBBB What is in these images?`
177

178
    Images A and B will have:
179

180
181
182
183
    ```
    A: PlaceholderRange(offset=0, length=4)
    B: PlaceholderRange(offset=5, length=4)
    ```
184
185
186
187
188
189
190
191
    """

    offset: int
    """The start index of the placeholder in the prompt."""

    length: int
    """The length of the placeholder."""

192
    is_embed: "torch.Tensor | None" = None
193
194
195
196
197
    """
    A boolean mask of shape `(length,)` indicating which positions
    between `offset` and `offset + length` to assign embeddings to.
    """

198
199
    @cached_property
    def embeds_cumsum(self) -> torch.Tensor | None:
200
        return None if self.is_embed is None else self.is_embed.cumsum(dim=0)
201
202
203
204

    @cached_property
    def get_num_embeds(self) -> int:
        if self.embeds_cumsum is None:
205
206
            return self.length

207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
        return int(self.embeds_cumsum[-1])

    def get_embeds_indices_in_range(
        self, start_idx: int, end_idx: int
    ) -> tuple[int, int]:
        """
        Returns the starting and ending indices of the embeddings of encoder outputs
        in the range of [start_idx, end_idx) in the placeholders.

        For example, given:
        PlaceholderRange(offset=2, length=5, is_embed=[False, True, False, True, True])

        If start_idx=3 and end_idx=5, the output is (1, 3) because we want to get
        the second and the third embeddings from the encoder output.
        """
        if self.embeds_cumsum is None:
            return start_idx, end_idx

        embeds_start_idx = (
            int(self.embeds_cumsum[start_idx - 1]) if start_idx > 0 else 0
        )
        embeds_end_idx = int(self.embeds_cumsum[end_idx - 1])

        return embeds_start_idx, embeds_end_idx
231

232
233
234
235
236
237
238
239
240
241
242
243
244
    def extract_embeds_range(self) -> list[tuple[int, int]]:
        """Extract the start and end indices of the embedded region in prompt.

        For example, given `PlaceholderRange(offset=2, length=5)` and
        `is_embed = [False, True, False, True, True]`, the output is
        `[(1 + offset, 1 + offset), (3 + offset, 4 + offset)]`.

        Returns:
            A tuple `(start, end)` representing the start and end
            indices (inclusive) of the embedded region.
            Returns full placeholder range if `is_embed` is `None`.
        """
        if self.is_embed is None:
245
            return [(self.offset, self.offset + self.length - 1)]
246
247
248
249
250
251
252
253
254
255
256

        mask_i = self.is_embed.int()
        starts = torch.nonzero(
            torch.diff(mask_i, prepend=mask_i.new_zeros(1)) == 1
        ).flatten()
        ends = torch.nonzero(
            torch.diff(mask_i, append=mask_i.new_zeros(1)) == -1
        ).flatten()
        ranges = torch.stack((starts, ends), dim=1) + self.offset
        return [tuple(x) for x in ranges.tolist()]

257
258
259
260
261
262
263
264
265
266
267
268
269
    def __eq__(self, other: object) -> bool:
        if not isinstance(other, self.__class__):
            return False
        if not (self.offset, self.length) == (other.offset, other.length):
            return False

        if self.is_embed is None:
            return other.is_embed is None
        if other.is_embed is None:
            return self.is_embed is None

        return nested_tensors_equal(self.is_embed, other.is_embed)

270

271
272
273
274
275
276
NestedTensors: TypeAlias = Union[
    list["NestedTensors"],
    list["torch.Tensor"],
    "torch.Tensor",
    tuple["torch.Tensor", ...],
]
277
278
279
280
"""
Uses a list instead of a tensor if the dimensions of each element do not match.
"""

281
282

def nested_tensors_equal(a: "NestedTensors", b: "NestedTensors") -> bool:
    """
    Equality check between
    [`NestedTensors`][vllm.multimodal.inputs.NestedTensors] objects.

    Tensors are compared with `torch.equal`. Lists and tuples are compared
    element-wise (recursively) and must have the same container type and the
    same length. Anything else falls back to `==`.
    """
    if isinstance(a, torch.Tensor) or isinstance(b, torch.Tensor):
        return (
            isinstance(a, torch.Tensor)
            and isinstance(b, torch.Tensor)
            and torch.equal(a, b)
        )

    for seq_type in (list, tuple):
        if isinstance(a, seq_type) or isinstance(b, seq_type):
            # `zip` stops at the shorter input, so the lengths must be checked
            # explicitly; otherwise `[x]` would compare equal to `[x, y]`.
            # Tuples are recursed into too, because `==` on tuples of
            # multi-element tensors raises an "ambiguous truth value" error.
            return (
                isinstance(a, seq_type)
                and isinstance(b, seq_type)
                and len(a) == len(b)
                and all(nested_tensors_equal(a_, b_) for a_, b_ in zip(a, b))
            )

    # Neither side is a tensor or a sequence (e.g. both are scalars)
    return a == b


def _nested_tensors_h2d(
    tensors: NestedTensors,
    device: torch.types.Device,
) -> NestedTensors:
    """
    Copy every tensor leaf of `tensors` to `device` using a non-blocking
    transfer.

    Leaves that are not tensors are kept unchanged. If `device` is `None`,
    `tensors` itself is returned untouched.
    """
    if device is None:
        return tensors

    def _move(leaf):
        if not isinstance(leaf, torch.Tensor):
            return leaf
        return leaf.to(device=device, non_blocking=True)

    return json_map_leaves(_move, tensors)


BatchedTensorInputs: TypeAlias = dict[str, NestedTensors]
"""
Mapping from keyword-argument name to nested tensor data, after items have
been batched together via
[`MultiModalKwargsItems.get_data`][vllm.multimodal.inputs.MultiModalKwargsItems.get_data].
"""


def batched_tensors_equal(
    a: "BatchedTensorInputs", b: "BatchedTensorInputs"
) -> bool:
    """
    Equality check between
    [`BatchedTensorInputs`][vllm.multimodal.inputs.BatchedTensorInputs] objects.

    Both mappings must have exactly the same keys. The values under each key
    must also be equal according to `nested_tensors_equal`.
    """
    # Compare key sets first. Checking only that every key of `a` is in `b`
    # would make `a` compare equal to any superset `b` (e.g. `{}` to anything).
    if a.keys() != b.keys():
        return False

    return all(nested_tensors_equal(a[k], b[k]) for k in a)


@dataclass
class MultiModalFeatureSpec:
    """
    One multimodal input item, bundled with its processed data and metadata.

    Instances track a multimodal item through processing and caching. A
    request with several multimodal items carries one
    `MultiModalFeatureSpec` for each of them.
    """

    data: "MultiModalKwargsItem | None"
    """
    The processed multimodal data of this item.

    Left as `None` when the item is cached, to avoid IPC between the API
    server and the engine core processes.
    """

    modality: str
    """Name of the input modality, such as `"image"`, `"audio"` or `"video"`."""

    identifier: str
    """Hash used to cache encoder outputs (includes the LoRA prefix, if any)."""

    mm_position: PlaceholderRange
    """
    Where this item's `modality` tokens sit in the prompt, for example
    `PlaceholderRange(offset=2, length=336)`.
    """

    mm_hash: str | None = None
    """Hash used to cache processor outputs (excludes any LoRA prefix)."""

    @staticmethod
    def gather_kwargs(features: list["MultiModalFeatureSpec"], keys: set[str]):
        """
        Collect, for each requested key, the data of every feature that has it.

        Features whose `data` is `None` are skipped, as are keys missing from a
        feature's data. The result maps every key found at least once to the
        list of its values, in feature order.
        """
        gathered: defaultdict[str, list[NestedTensors]] = defaultdict(list)
        available = (feat.data for feat in features if feat.data is not None)
        for mm_item in available:
            for key in keys:
                if key in mm_item:
                    gathered[key].append(mm_item[key].data)

        return dict(gathered)


@dataclass
class MultiModalFieldElem:
    """
    One processed keyword argument for a
    [`MultiModalKwargsItem`][vllm.multimodal.inputs.MultiModalKwargsItem],
    i.e. one value to hand to the model.
    """

    data: NestedTensors
    """
    The tensor data of this field in
    [`MultiModalKwargsItem`][vllm.multimodal.inputs.MultiModalKwargsItem],
    i.e. the value passed to the model as the keyword argument.

    Set to `None` when the item is known to be cached in `EngineCore`.
    """

    field: "BaseMultiModalField"
    """
    Describes how this field's tensor data is merged with that of other
    items when multi-modal items are batched for model inference.
    """

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, self.__class__):
            return False

        if self.data is None or other.data is None:
            # A missing payload only matches another missing payload.
            same_data = self.data is None and other.data is None
        else:
            same_data = nested_tensors_equal(self.data, other.data)

        # Only the *kind* of field is compared, not its configuration.
        return same_data and type(self.field) is type(other.field)  # noqa: E721


@dataclass(frozen=True, kw_only=True)
class BaseMultiModalField(ABC):
    """
    Describes how the tensor data of one keyword argument is split into
    per-item elements of
    [`MultiModalKwargsItems`][vllm.multimodal.inputs.MultiModalKwargsItems],
    and how those elements are merged back together.
    """

    keep_on_cpu: bool = False
    """
    When `True`, this field stays on the CPU instead of being moved to the
    accelerator when
    [`group_and_batch_mm_items`][vllm.multimodal.utils.group_and_batch_mm_items]
    batches the data.
    """

    def _field_factory(self):
        """
        Return a callable that wraps data in a `MultiModalFieldElem` bound to
        this field.
        """
        make_elem = partial(MultiModalFieldElem, field=self)

        # Thin wrapper so callers may pass `data` positionally.
        def factory(data: NestedTensors) -> MultiModalFieldElem:
            return make_elem(data=data)

        return factory

    @abstractmethod
    def build_elems(
        self,
        modality: str,
        key: str,
        data: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        """
        Split `data` into
        [`MultiModalFieldElem`][vllm.multimodal.inputs.MultiModalFieldElem]
        instances, one per element of the batch.

        This is the inverse of
        [`reduce_data`][vllm.multimodal.inputs.BaseMultiModalField.reduce_data].
        """
        raise NotImplementedError

    @abstractmethod
    def _reduce_data(
        self,
        batch: list[NestedTensors],
        *,
        pin_memory: bool,
    ) -> NestedTensors:
        """Merge the raw data of several elements into a single batch."""
        raise NotImplementedError

    def reduce_data(
        self,
        elems: list[MultiModalFieldElem],
        *,
        device: torch.types.Device = None,
        pin_memory: bool = False,
    ) -> NestedTensors:
        """
        Merge the data of several
        [`MultiModalFieldElem`][vllm.multimodal.inputs.MultiModalFieldElem]
        instances into one batch, then move it to `device` (if given).

        This is the inverse of
        [`build_elems`][vllm.multimodal.inputs.BaseMultiModalField.build_elems].

        Raises:
            ValueError: If the elements do not all share the same field type.
        """
        field_types = [type(elem.field) for elem in elems]
        if len(set(field_types)) > 1:
            raise ValueError(f"Cannot merge different {field_types=}")

        if self.keep_on_cpu:
            # CPU-only fields are never pinned nor moved to an accelerator.
            pin_memory = False
            if device is not None:
                device = "cpu"

        merged = self._reduce_data(
            [elem.data for elem in elems],
            pin_memory=pin_memory,
        )
        return _nested_tensors_h2d(merged, device=device)


@dataclass(frozen=True, kw_only=True)
class MultiModalBatchedField(BaseMultiModalField):
    """
    A field whose per-item elements are the entries along the first
    dimension of the data; batching stacks them back together.

    Info:
        [`MultiModalFieldConfig.batched`][vllm.multimodal.inputs.MultiModalFieldConfig.batched]
    """

    def build_elems(
        self,
        modality: str,
        key: str,
        data: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        # Each entry of `data` becomes its own element.
        make_elem = self._field_factory()
        return list(map(make_elem, data))

    def _reduce_data(
        self,
        batch: list[NestedTensors],
        *,
        pin_memory: bool,
    ) -> NestedTensors:
        # Only a non-empty batch made up entirely of tensors can be stacked.
        if not batch or not is_list_of(batch, torch.Tensor, check="all"):
            return batch

        tensors = cast(list[torch.Tensor], batch)
        head = tensors[0]

        if len(tensors) == 1:
            # Fast path for a single tensor: the result equals
            # `torch.stack(tensors)`, and no copy is made when `head` is
            # already contiguous.
            return head.unsqueeze(0).contiguous()

        if any(t.shape != head.shape for t in tensors):
            # Ragged shapes cannot be stacked; keep them as a list.
            return batch

        stacked = torch.empty(
            (len(tensors), *head.shape),
            dtype=head.dtype,
            device=head.device,
            pin_memory=pin_memory,
        )
        return torch.stack(tensors, out=stacked)


@dataclass(frozen=True, kw_only=True)
class MultiModalFlatField(BaseMultiModalField):
    """
    A field whose per-item elements are slices of the underlying data;
    batching concatenates them back together along `dim`.

    Info:
        [`MultiModalFieldConfig.flat`][vllm.multimodal.inputs.MultiModalFieldConfig.flat]
        [`MultiModalFieldConfig.flat_from_sizes`][vllm.multimodal.inputs.MultiModalFieldConfig.flat_from_sizes]
    """

    slices: Sequence[slice] | Sequence[Sequence[slice]]
    dim: int = 0

    def build_elems(
        self,
        modality: str,
        key: str,
        data: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        make_elem = self._field_factory()

        # Anything other than plain slices (e.g. tuples of slices indexing
        # several dimensions at once) can only be applied to a tensor.
        if not is_list_of(self.slices, slice, check="all"):
            assert isinstance(data, torch.Tensor), (
                "torch.Tensor is required for multiple slices"
            )

        return [make_elem(data[cast(slice, sl)]) for sl in self.slices]

    def _reduce_data(
        self,
        batch: list[NestedTensors],
        *,
        pin_memory: bool,
    ) -> NestedTensors:
        if not batch or not is_list_of(batch, torch.Tensor, check="all"):
            # Nested lists are flattened by one level, which only corresponds
            # to concatenation along the first dimension.
            assert self.dim == 0, "dim == 0 is required for nested list"
            return [e for elem in batch for e in elem]

        tensors = cast(list[torch.Tensor], batch)
        head = tensors[0]

        if len(tensors) == 1:
            # Fast path for a single tensor: the result equals
            # `torch.concat(tensors)`, and no copy is made when `head` is
            # already contiguous.
            return head.contiguous()

        ndim = head.ndim
        cat_dim = self.dim + ndim if self.dim < 0 else self.dim

        def outer_shape(t: torch.Tensor):
            # The parts of the shape before and after the concat dimension.
            return t.shape[:cat_dim], t.shape[cat_dim + 1 :]

        head_outer = outer_shape(head)
        if all(outer_shape(t) == head_outer for t in tensors):
            # Uniform case: only the size along the concat dimension varies.
            before, after = head_outer
            out = torch.empty(
                (*before, sum(t.shape[cat_dim] for t in tensors), *after),
                dtype=head.dtype,
                device=head.device,
                pin_memory=pin_memory,
            )
            return torch.concat(tensors, dim=self.dim, out=out)

        # Ragged case: the non-concat dimensions differ between tensors
        # (e.g. Ultravox with audio of different durations). Instead of padding
        # every tensor first, allocate one zero-filled output that fits all of
        # them and copy each tensor into its own window.
        # See: https://github.com/vllm-project/vllm/issues/31658
        out_shape = [
            (
                sum(t.shape[d] for t in tensors)
                if d == cat_dim
                else max(t.shape[d] for t in tensors)
            )
            for d in range(ndim)
        ]
        out = torch.zeros(
            out_shape,
            dtype=head.dtype,
            device=head.device,
            pin_memory=pin_memory,
        )

        cat_start = 0
        for t in tensors:
            window = tuple(
                (
                    slice(cat_start, cat_start + t.shape[d])
                    if d == cat_dim
                    else slice(0, t.shape[d])
                )
                for d in range(ndim)
            )
            out[window] = t
            cat_start += t.shape[cat_dim]

        return out


@dataclass(frozen=True, kw_only=True)
class MultiModalSharedField(BaseMultiModalField):
    """
    A field whose data is shared, unchanged, by `batch_size` items.

    Info:
        [`MultiModalFieldConfig.shared`][vllm.multimodal.inputs.MultiModalFieldConfig.shared]
    """

    batch_size: int

    def build_elems(
        self,
        modality: str,
        key: str,
        data: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        # One element object is built and the same object is repeated
        # `batch_size` times.
        shared_elem = self._field_factory()(data)
        return [shared_elem] * self.batch_size

    def _reduce_data(
        self,
        batch: list[NestedTensors],
        *,
        pin_memory: bool,
    ) -> NestedTensors:
        # All elements carry the same data, so the first stands for the batch.
        return batch[0]


@dataclass(frozen=True)
class MultiModalFieldConfig:
    """
    Associates a
    [`BaseMultiModalField`][vllm.multimodal.inputs.BaseMultiModalField]
    with the modality of the multi-modal items it applies to.

    The static methods below construct an instance for each kind of field.
    """

    @staticmethod
    def batched(modality: str, *, keep_on_cpu: bool = False):
        """
        Defines a field where an element in the batch is obtained by
        indexing into the first dimension of the underlying data.

        Args:
            modality: The modality of the multi-modal item that uses this
                keyword argument.
            keep_on_cpu: Whether to keep this field on the CPU for the model inputs.

        Example:

        ```
        Input:
            Data: [[AAAA]
                   [BBBB]
                   [CCCC]]

        Output:
            Element 1: [AAAA]
            Element 2: [BBBB]
            Element 3: [CCCC]
        ```
        """
        return MultiModalFieldConfig(
            field=MultiModalBatchedField(keep_on_cpu=keep_on_cpu),
            modality=modality,
        )

    @staticmethod
    def flat(
        modality: str,
        slices: Sequence[slice] | Sequence[Sequence[slice]],
        dim: int = 0,
        *,
        keep_on_cpu: bool = False,
    ):
        """
        Defines a field where an element in the batch is obtained by
        slicing the underlying data along dimension `dim`. When batching,
        the elements are concatenated back together along that dimension.

        Args:
            modality: The modality of the multi-modal item that uses this
                keyword argument.
            slices: For each multi-modal item, a slice (dim=0) or a tuple of
                slices (dim>0) that is used to extract the data corresponding
                to it.
            dim: The dimension to extract data, default to 0.
            keep_on_cpu: Whether to keep this field on the CPU for the model inputs.

        Example:

        ```
        Given:
            slices: [slice(0, 3), slice(3, 7), slice(7, 9)]

        Input:
            Data: [AAABBBBCC]

        Output:
            Element 1: [AAA]
            Element 2: [BBBB]
            Element 3: [CC]
        ```

        ```
        Given:
            slices: [
                (slice(None), slice(0, 3)),
                (slice(None), slice(3, 7)),
                (slice(None), slice(7, 9))]
            dim: 1

        Input:
            Data: [[AAABBBBCC]]

        Output:
            Element 1: [[AAA]]
            Element 2: [[BBBB]]
            Element 3: [[CC]]
        ```
        """
        return MultiModalFieldConfig(
            field=MultiModalFlatField(
                slices=slices,
                dim=dim,
                keep_on_cpu=keep_on_cpu,
            ),
            modality=modality,
        )

    @staticmethod
    def flat_from_sizes(
        modality: str,
        size_per_item: "torch.Tensor",
        dim: int = 0,
        *,
        keep_on_cpu: bool = False,
    ):
        """
        Defines a field where an element in the batch is obtained by
        slicing the underlying data along dimension `dim`, using
        consecutive slices whose lengths are given by `size_per_item`.

        Args:
            modality: The modality of the multi-modal item that uses this
                keyword argument.
            size_per_item: A 1-D tensor holding, for each multi-modal item,
                the size of the slice used to extract its data.
            dim: The dimension to slice, default to 0.
            keep_on_cpu: Whether to keep this field on the CPU for the model inputs.

        Raises:
            ValueError: If `size_per_item` is not 1-D.

        Example:

        ```
        Given:
            size_per_item: [3, 4, 2]

        Input:
            Data: [AAABBBBCC]

        Output:
            Element 1: [AAA]
            Element 2: [BBBB]
            Element 3: [CC]
        ```

        ```
        Given:
            size_per_item: [3, 4, 2]
            dim: 1

        Input:
            Data: [[AAABBBBCC]]

        Output:
            Element 1: [[AAA]]
            Element 2: [[BBBB]]
            Element 3: [[CC]]
        ```

        Info:
            [`MultiModalFieldConfig.flat`][vllm.multimodal.inputs.MultiModalFieldConfig.flat]
        """

        if size_per_item.ndim != 1:
            raise ValueError(
                "size_per_item should be a 1-D tensor, "
                f"but found shape: {size_per_item.shape}"
            )

        # Convert the sizes to Python ints once. Otherwise every slice bound
        # would be a 0-d tensor (possibly on an accelerator) that must be
        # converted each time the slice is applied.
        sizes = size_per_item.tolist()
        slice_idxs = [0, *accumulate(sizes)]
        # The leading full slices select everything in the dimensions before `dim`.
        slices = [
            (slice(None, None, None),) * dim
            + (slice(slice_idxs[i], slice_idxs[i + 1]),)
            for i in range(len(sizes))
        ]

        return MultiModalFieldConfig.flat(
            modality,
            slices,
            dim=dim,
            keep_on_cpu=keep_on_cpu,
        )

    @staticmethod
    def shared(
        modality: str,
        batch_size: int,
        *,
        keep_on_cpu: bool = False,
    ):
        """
        Defines a field where an element in the batch is obtained by
        taking the entirety of the underlying data.

        This means that the data is the same for each element in the batch.

        Args:
            modality: The modality of the multi-modal item that uses this
                keyword argument.
            batch_size: The number of multi-modal items which share this data.
            keep_on_cpu: Whether to keep this field on the CPU for the model inputs.

        Example:

        ```
        Given:
            batch_size: 4

        Input:
            Data: [XYZ]

        Output:
            Element 1: [XYZ]
            Element 2: [XYZ]
            Element 3: [XYZ]
            Element 4: [XYZ]
        ```
        """
        return MultiModalFieldConfig(
            field=MultiModalSharedField(
                batch_size=batch_size,
                keep_on_cpu=keep_on_cpu,
            ),
            modality=modality,
        )

    field: BaseMultiModalField
    """Defines how the keyword argument's data is split and merged."""

    modality: str
    """The modality of the multi-modal items that use this keyword argument."""

    def build_elems(
        self,
        key: str,
        batch: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        """
        Split `batch` (the data of keyword argument `key`) into elements,
        delegating to `self.field` with this config's modality.
        """
        return self.field.build_elems(self.modality, key, batch)


class MultiModalKwargsItem(UserDict[str, MultiModalFieldElem]):
    """
    The processed keyword arguments for the model, keyed by argument name,
    that belong to a single item of
    [`MultiModalDataItems`][vllm.multimodal.parse.MultiModalDataItems].
    """

    @staticmethod
    def dummy(nbytes: int = 1):
        """
        Build a placeholder item for tests: a single `"dummy"` field holding
        `nbytes` uninitialized `uint8` values.
        """
        placeholder = MultiModalFieldElem(
            data=torch.empty(nbytes, dtype=torch.uint8),
            field=MultiModalSharedField(batch_size=1),
        )
        return MultiModalKwargsItem({"dummy": placeholder})

    def get_data(self) -> dict[str, NestedTensors]:
        """Return the raw data of every field, keyed by argument name."""
        return {name: field_elem.data for name, field_elem in self.items()}


# Item type stored in `MultiModalKwargsItems`. It is constrained to either
# `MultiModalKwargsItem` or `MultiModalKwargsItem | None` (items may be
# absent), and defaults to the non-optional form.
_I = TypeVar(
    "_I",
    MultiModalKwargsItem,
    MultiModalKwargsItem | None,
    default=MultiModalKwargsItem,
)


class MultiModalKwargsItems(UserDict[str, Sequence[_I]]):
    """
    A dictionary of processed multi-modal inputs by modality.

    For example, given a processor that processes
    images into `pixel_values` and `image_grid_thw`,
    and audios into `input_audio_features`,
    a prompt with 2 images and 1 audio will be processed
    into a `MultiModalKwargsItems` with the following structure:

    ```python
    MultiModalKwargsItems(
        {
            "image": [
                # For the first image
                MultiModalKwargsItem({"pixel_values": ..., "image_grid_thw": ...}),
                # For the second image
                MultiModalKwargsItem({"pixel_values": ..., "image_grid_thw": ...}),
            ],
            "audio": [
                # For the first audio
                MultiModalKwargsItem({"input_audio_features": ...}),
            ],
        }
    )
    ```

    Unlike HF processing which returns all items
    in a single dictionary with batched keyword arguments,
    we split up the items because some of them may already be cached.
    Also, items from multiple requests may be batched together to improve throughput,
    using the logic defined by the
    [`BaseMultiModalField`][vllm.multimodal.inputs.BaseMultiModalField]
    for each keyword argument.

    When parameterized as `MultiModalKwargsItems[MultiModalKwargsItem | None]`,
    individual items may be `None`; call `require_data` before reading data.
    """

    @staticmethod
    def from_hf_inputs(
        hf_inputs: "BatchFeature",
        config_by_key: Mapping[str, MultiModalFieldConfig],
    ) -> "MultiModalKwargsItems[MultiModalKwargsItem]":
        """
        Split the batched outputs of a HuggingFace processor into per-item
        keyword arguments, grouped by modality.

        Args:
            hf_inputs: The outputs of the HuggingFace processor.
            config_by_key: For each keyword argument, the config describing
                how to split it and which modality it belongs to.

        Returns:
            One `MultiModalKwargsItem` per item of each modality. Modalities
            for which no keyword argument produced any element are omitted.

        Raises:
            ValueError: If the keyword arguments of one modality produce
                different numbers of items.
        """
        # NOTE: This skips fields in `hf_inputs` that are not in `config_by_key`
        # We assume that those fields are not used in vLLM
        elems_by_key = dict[str, Sequence[MultiModalFieldElem]]()
        keys_by_modality = defaultdict[str, set[str]](set)
        for key, config in config_by_key.items():
            batch = hf_inputs.get(key)
            if batch is not None:
                elems = config.build_elems(key, batch)
                if len(elems) > 0:
                    elems_by_key[key] = elems
                    keys_by_modality[config.modality].add(key)

        items_by_modality = dict[str, list[MultiModalKwargsItem]]()
        for modality, keys in keys_by_modality.items():
            elems_in_modality = {k: elems_by_key[k] for k in keys}
            batch_sizes = {k: len(v) for k, v in elems_in_modality.items()}

            # Every keyword argument of a modality must split into the same
            # number of items, otherwise they cannot be zipped per item.
            if len(set(batch_sizes.values())) > 1:
                raise ValueError(
                    f"Cannot merge different batch sizes for {modality=}! "
                    f"Found: {batch_sizes=}"
                )

            batch_size = next(iter(batch_sizes.values()))
            items_by_modality[modality] = [
                MultiModalKwargsItem({k: v[i] for k, v in elems_in_modality.items()})
                for i in range(batch_size)
            ]

        return MultiModalKwargsItems(items_by_modality)

    def __getitem__(self, modality: str) -> Sequence[_I]:
        """
        Return the items of `modality`.

        Raises:
            KeyError: If `modality` is absent; the message lists the
                modalities that are present.
        """
        if modality not in self:
            raise KeyError(
                f"Modality {modality!r} not found. "
                f"Available modalities: {set(self.keys())}"
            )

        return super().__getitem__(modality)  # type: ignore[return-value]

    def require_data(self) -> "MultiModalKwargsItems[MultiModalKwargsItem]":
        """
        Check that no item is `None`.

        Returns:
            `self`, typed as containing only present items.

        Raises:
            RuntimeError: If any item is `None`.
        """
        for modality, items in self.items():
            for i, item in enumerate(items):
                if item is None:
                    raise RuntimeError(f"Found empty mm_items[{modality}][{i}]")

        return self  # type: ignore[return-value]

    def get_data(
        self,
        *,
        # Quoted so that defining this method does not force the lazily
        # loaded `torch` module to be imported.
        device: "torch.types.Device" = None,
        pin_memory: bool = False,
    ) -> BatchedTensorInputs:
        """
        Construct a dictionary of keyword arguments to pass to the model.

        The items of each modality are batched with `group_and_batch_mm_items`,
        and the resulting per-modality batches are merged into one dictionary.

        Args:
            device: Forwarded to `group_and_batch_mm_items`.
            pin_memory: Forwarded to `group_and_batch_mm_items`.

        Raises:
            RuntimeError: If any item is `None`, or if the items of some
                modality are not batched into exactly one batch.
        """
        from .utils import group_and_batch_mm_items

        items_by_modality = self.require_data()
        batches_by_modality = {
            modality: [
                data
                for _, data in group_and_batch_mm_items(
                    items,
                    device=device,
                    pin_memory=pin_memory,
                )
            ]
            for modality, items in items_by_modality.items()
            # Modalities without items contribute no keyword arguments.
            if len(items) > 0
        }

        out_data: BatchedTensorInputs = {}
        for batches in batches_by_modality.values():
            if len(batches) != 1:
                num_batches_by_modality = {
                    modality: len(modality_batches)
                    for modality, modality_batches in batches_by_modality.items()
                }

                raise RuntimeError(
                    f"Some modalities cannot be merged into a single batch "
                    f"({num_batches_by_modality=})"
                )

            out_data.update(batches[0])

        return out_data
1046

1047

1048
1049
1050
1051
# Either items that are all present, or items where some entries may be `None`.
MultiModalKwargsOptionalItems: TypeAlias = Union[
    MultiModalKwargsItems[MultiModalKwargsItem],
    MultiModalKwargsItems[MultiModalKwargsItem | None],
]
1052
1053


1054
1055
1056
1057
1058
1059
MultiModalHashes = dict[str, list[str]]
"""
A dictionary containing per-item hashes for each modality.
"""


1060
MultiModalPlaceholderDict: TypeAlias = Mapping[str, Sequence[PlaceholderRange]]
"""
A dictionary containing per-item placeholder ranges for each modality.
"""


1066
class MultiModalInputs(_InputOptions):
    """
    Represents the outputs of
    [`BaseMultiModalProcessor`][vllm.multimodal.processing.BaseMultiModalProcessor],
    ready to be passed to vLLM internals.
    """

    type: Literal["multimodal"]
    """The type of inputs."""

    prompt_token_ids: list[int]
    """The processed token IDs, which include placeholder tokens."""

    mm_kwargs: MultiModalKwargsOptionalItems
    """Keyword arguments to be directly passed to the model after batching."""

    mm_hashes: MultiModalHashes
    """The hashes of the multi-modal data."""

    mm_placeholders: MultiModalPlaceholderDict
    """
    For each modality, information about the placeholder tokens in
    `prompt_token_ids`.
    """
1090
1091
1092
1093


class MultiModalEncDecInputs(MultiModalInputs):
    """
1094
1095
    Represents the outputs of
    [`EncDecMultiModalProcessor`][vllm.multimodal.processing.EncDecMultiModalProcessor]
1096
    ready to be passed to vLLM internals.
1097
1098
1099

    Note: Even text-only encoder-decoder models are currently implemented
    as multi-modal models for convenience.
1100
    (Example: https://github.com/vllm-project/bart-plugin)
1101
1102
1103
1104
    """

    encoder_prompt_token_ids: list[int]
    """The processed token IDs of the encoder prompt."""