inputs.py 30.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from abc import ABC, abstractmethod
5
from collections import UserDict, defaultdict
6
from collections.abc import Mapping, Sequence
7
from dataclasses import dataclass
8
from functools import cached_property, partial
9
from itertools import accumulate
10
11
12
13
14
15
16
17
18
from typing import (
    TYPE_CHECKING,
    Any,
    Literal,
    TypeAlias,
    TypedDict,
    Union,
    cast,
)
19
20

import numpy as np
Roger Wang's avatar
Roger Wang committed
21
from PIL.Image import Image
22
from typing_extensions import TypeVar
23

24
from vllm.utils.collection_utils import is_list_of
25
from vllm.utils.import_utils import LazyLoader
26
from vllm.utils.jsontree import json_map_leaves
27

28
29
from .media import MediaWithBytes

30
if TYPE_CHECKING:
31
32
33
34
35
    import torch
    import torch.types
    from transformers.feature_extraction_utils import BatchFeature
else:
    torch = LazyLoader("torch", globals(), "torch")
36

37

38
HfImageItem: TypeAlias = Union["Image", np.ndarray, "torch.Tensor"]
39
"""
40
A `transformers.image_utils.ImageInput` representing a single image
41
item, which can be passed to a HuggingFace `ImageProcessor`.
42
43
"""

44
45
46
HfVideoItem: TypeAlias = Union[
    list["Image"], np.ndarray, "torch.Tensor", list[np.ndarray], list["torch.Tensor"]
]
47
"""
48
A `transformers.image_utils.VideoInput` representing a single video
49
item, which can be passed to a HuggingFace `VideoProcessor`.
50
51
"""

52
HfAudioItem: TypeAlias = Union[list[float], np.ndarray, "torch.Tensor"]
53
"""
54
Represents a single audio
55
item, which can be passed to a HuggingFace `AudioProcessor`.
56
57
"""

58
ImageItem: TypeAlias = Union[HfImageItem, "torch.Tensor", MediaWithBytes[HfImageItem]]
59
"""
60
A `transformers.image_utils.ImageInput` representing a single image
61
item, which can be passed to a HuggingFace `ImageProcessor`.
62
63
64
65
66
67

Alternatively, a 3-D tensor or batch of 2-D tensors,
which are treated as image embeddings;
these are directly passed to the model without HF processing.
"""

68
69
70
VideoItem: TypeAlias = Union[
    HfVideoItem, "torch.Tensor", tuple[HfVideoItem, dict[str, Any]]
]
71
"""
72
73
74
A `transformers.video_utils.VideoInput` representing a single video item. 
This can be passed to a HuggingFace `VideoProcessor` 
with `transformers.video_utils.VideoMetadata`.
75
76
77
78
79
80

Alternatively, a 3-D tensor or batch of 2-D tensors,
which are treated as video embeddings;
these are directly passed to the model without HF processing.
"""

81
AudioItem: TypeAlias = Union[HfAudioItem, tuple[np.ndarray, float], "torch.Tensor"]
82
83
"""
Represents a single audio
84
item, which can be passed to a HuggingFace `AudioProcessor`.
85
86
87
88
89
90
91
92
93
94

Alternatively, a tuple `(audio, sampling_rate)`, where the sampling rate
is different from that expected by the model;
these are resampled to the model's sampling rate before being processed by HF.

Alternatively, a 3-D tensor or batch of 2-D tensors,
which are treated as audio embeddings;
these are directly passed to the model without HF processing.
"""

95

Roger Wang's avatar
Roger Wang committed
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
class VisionChunkImage(TypedDict):
    """Represents an image wrapped as a vision chunk."""

    # Tag identifying this chunk kind; always the literal "image".
    type: Literal["image"]
    # The image payload carried by this chunk.
    image: Image
    # Optional identifier for the chunk; may be None.
    # NOTE(review): how the uuid is produced/consumed is not visible here.
    uuid: str | None


class VisionChunkVideo(TypedDict):
    """Represents a video chunk with metadata."""

    # Tag identifying this chunk kind; always the literal "video_chunk".
    type: Literal["video_chunk"]
    # The images that make up this chunk (presumably frames -- TODO confirm).
    video_chunk: list[Image]
    # Optional identifier for the chunk; may be None.
    uuid: str | None
    # Text associated with this chunk.
    # NOTE(review): exact role of the prompt is defined by the consumer.
    prompt: str
    # Integer index of the video; presumably identifies which source video
    # this chunk came from -- verify against the producer.
    video_idx: int


114
# Both members carry a literal `type` key ("image" / "video_chunk"),
# which distinguishes them.
VisionChunk: TypeAlias = VisionChunkImage | VisionChunkVideo
"""A vision chunk is either an image or a video chunk."""


118
119
@dataclass(frozen=True)
class PlaceholderRange:
    """
    Placeholder location information for multi-modal data.

    Example:

    Prompt: `AAAA BBBB What is in these images?`

    Images A and B will have:

    ```
    A: PlaceholderRange(offset=0, length=4)
    B: PlaceholderRange(offset=5, length=4)
    ```
    """

    offset: int
    """The start index of the placeholder in the prompt."""

    length: int
    """The length of the placeholder."""

    is_embed: "torch.Tensor | None" = None
    """
    A boolean mask of shape `(length,)` indicating which positions
    between `offset` and `offset + length` to assign embeddings to.
    """

    @cached_property
    def embeds_cumsum(self) -> torch.Tensor | None:
        """Running count of embedding positions, or `None` without a mask."""
        return None if self.is_embed is None else self.is_embed.cumsum(dim=0)

    def get_num_embeds(self) -> int:
        """
        Return how many positions of this placeholder receive embeddings.

        Without a mask every position counts, so this is `length`.
        """
        cumsum = self.embeds_cumsum
        if cumsum is None:
            return self.length

        # An empty mask has no last element to read the total from.
        if cumsum.numel() == 0:
            return 0

        return int(cumsum[-1])

    def get_embeds_indices_in_range(
        self, start_idx: int, end_idx: int
    ) -> tuple[int, int]:
        """
        Returns the starting and ending indices of the embeddings of encoder outputs
        in the range of [start_idx, end_idx) in the placeholders.

        For example, given:
        PlaceholderRange(offset=2, length=5, is_embed=[False, True, False, True, True])

        If start_idx=3 and end_idx=5, the output is (1, 3) because we want to get
        the second and the third embeddings from the encoder output.
        """
        cumsum = self.embeds_cumsum
        if cumsum is None:
            return start_idx, end_idx

        # `cumsum[i - 1]` is the number of embeddings strictly before index `i`.
        # Index 0 has nothing before it; guarding it also avoids wrapping
        # around to `cumsum[-1]` (the total) for either bound.
        embeds_start_idx = int(cumsum[start_idx - 1]) if start_idx > 0 else 0
        embeds_end_idx = int(cumsum[end_idx - 1]) if end_idx > 0 else 0

        return embeds_start_idx, embeds_end_idx

    def extract_embeds_range(self) -> list[tuple[int, int]]:
        """Extract the start and end indices of the embedded regions in prompt.

        For example, given `PlaceholderRange(offset=2, length=5)` and
        `is_embed = [False, True, False, True, True]`, the output is
        `[(1 + offset, 1 + offset), (3 + offset, 4 + offset)]`.

        Returns:
            A list of `(start, end)` tuples, one per contiguous run of
            embedded positions, with both indices inclusive.
            Returns the full placeholder range as a single tuple
            if `is_embed` is `None`.
        """
        if self.is_embed is None:
            return [(self.offset, self.offset + self.length - 1)]

        mask_i = self.is_embed.int()
        # A run starts where the mask rises 0 -> 1 (zero-padded on the left)
        # and ends where it falls 1 -> 0 (zero-padded on the right).
        starts = torch.nonzero(
            torch.diff(mask_i, prepend=mask_i.new_zeros(1)) == 1
        ).flatten()
        ends = torch.nonzero(
            torch.diff(mask_i, append=mask_i.new_zeros(1)) == -1
        ).flatten()
        ranges = torch.stack((starts, ends), dim=1) + self.offset
        return [tuple(x) for x in ranges.tolist()]

    def __eq__(self, other: object) -> bool:
        """Compare offset, length and (element-wise) the embedding mask."""
        if not isinstance(other, self.__class__):
            return False
        if not (self.offset, self.length) == (other.offset, other.length):
            return False

        # Masks match only if both are absent or both are present and equal.
        if self.is_embed is None or other.is_embed is None:
            return self.is_embed is None and other.is_embed is None

        return nested_tensors_equal(self.is_embed, other.is_embed)

218

219
220
221
222
223
224
NestedTensors: TypeAlias = Union[
    list["NestedTensors"],
    list["torch.Tensor"],
    "torch.Tensor",
    tuple["torch.Tensor", ...],
]
225
226
227
228
"""
Uses a list instead of a tensor if the dimensions of each element do not match.
"""

229
230

def nested_tensors_equal(a: NestedTensors, b: NestedTensors) -> bool:
    """
    Equality check between
    [`NestedTensors`][vllm.multimodal.inputs.NestedTensors] objects.

    Tensors are compared with `torch.equal`. Lists and tuples are compared
    element-wise and must have the same container kind and the same length.
    Anything else falls back to `==`.

    Args:
        a: The first nested structure.
        b: The second nested structure.

    Returns:
        `True` if both structures have the same layout and equal leaves.
    """
    if isinstance(a, torch.Tensor) or isinstance(b, torch.Tensor):
        return (
            isinstance(a, torch.Tensor)
            and isinstance(b, torch.Tensor)
            and torch.equal(a, b)
        )

    for seq_type in (list, tuple):
        if isinstance(a, seq_type) or isinstance(b, seq_type):
            return (
                isinstance(a, seq_type)
                and isinstance(b, seq_type)
                # `zip` stops at the shorter input, so lengths must be
                # checked explicitly or a prefix would compare equal.
                and len(a) == len(b)
                and all(nested_tensors_equal(a_, b_) for a_, b_ in zip(a, b))
            )

    # Both a and b are scalars
    return a == b


253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
def _nested_tensors_h2d(
    tensors: NestedTensors,
    device: torch.types.Device,
) -> NestedTensors:
    """
    Move every tensor leaf of a nested structure to `device`.

    Non-tensor leaves are passed through unchanged. When `device` is `None`
    the input is returned as-is without being traversed.
    """
    if device is None:
        return tensors

    def _move_leaf(leaf):
        # Transfers are requested as non-blocking.
        if isinstance(leaf, torch.Tensor):
            return leaf.to(device=device, non_blocking=True)
        return leaf

    return json_map_leaves(_move_leaf, tensors)


270
BatchedTensorInputs: TypeAlias = dict[str, NestedTensors]
271
272
"""
A dictionary containing nested tensors which have been batched via
273
[`MultiModalKwargsItems.get_data`][vllm.multimodal.inputs.MultiModalKwargsItems.get_data].
274
275
276
"""


277
278
279
280
281
def batched_tensors_equal(a: BatchedTensorInputs, b: BatchedTensorInputs) -> bool:
    """
    Equality check between
    [`BatchedTensorInputs`][vllm.multimodal.inputs.BatchedTensorInputs] objects.

    Both dictionaries must have exactly the same keys, and the values under
    each key must be equal according to `nested_tensors_equal`.
    """
    # Compare the key sets first so that keys present only in `b`
    # are not silently ignored.
    if a.keys() != b.keys():
        return False

    return all(nested_tensors_equal(a[k], b[k]) for k in a)
283
284


285
286
287
288
@dataclass
class MultiModalFeatureSpec:
    """
    One multimodal input together with its processed data and metadata.

    Instances follow multimodal data through processing and caching.
    A request with several multimodal items carries one
    `MultiModalFeatureSpec` for each of them.
    """

    data: "MultiModalKwargsItem | None"
    """
    The multimodal data of this feature.

    `None` when the item is cached, so that it does not have to be sent over
    IPC between the API server and engine core processes.
    """

    modality: str
    """The input modality, e.g., `"image"`, `"audio"`, `"video"`."""

    identifier: str
    """The hash for caching encoder outputs (with LoRA prefix if applicable)."""

    mm_position: PlaceholderRange
    """
    Where the `modality` tokens of this item are located in the prompt,
    e.g., `PlaceholderRange(offset=2, length=336)`.
    """

    mm_hash: str | None = None
    """The hash for caching processor outputs (without LoRA prefix)."""

    @staticmethod
    def gather_kwargs(features: list["MultiModalFeatureSpec"], keys: set[str]):
        """
        Collect, per requested key, the data of every feature that has it.

        Features without data, and keys a feature does not contain, are
        skipped. Keys that no feature provides are absent from the result.
        """
        gathered: defaultdict[str, list[NestedTensors]] = defaultdict(list)

        for feature in features:
            mm_item = feature.data
            if mm_item is None:
                continue

            for key in keys:
                if key in mm_item:
                    gathered[key].append(mm_item[key].data)

        return dict(gathered)

331

332
@dataclass
class MultiModalFieldElem:
    """
    A single processed keyword argument of a
    [`MultiModalKwargsItem`][vllm.multimodal.inputs.MultiModalKwargsItem],
    to be passed to a model.
    """

    data: NestedTensors
    """
    The tensor data stored under this field of a
    [`MultiModalKwargsItem`][vllm.multimodal.inputs.MultiModalKwargsItem],
    i.e. the value of the keyword argument to be passed to the model.

    It may be set to `None` if it is determined that the item is cached
    in `EngineCore`.
    """

    field: "BaseMultiModalField"
    """
    Describes how the tensor data of this field is combined with that of
    other elements when multi-modal items are batched for model inference.
    """

    def __eq__(self, other: object) -> bool:
        """Equal when the data matches and both fields are of the same type."""
        if not isinstance(other, self.__class__):
            return False

        lhs, rhs = self.data, other.data
        if lhs is None or rhs is None:
            # Missing data only matches missing data.
            same_data = lhs is None and rhs is None
        else:
            same_data = nested_tensors_equal(lhs, rhs)

        # Only the field *type* is compared, not the field's attributes.
        return same_data and type(self.field) is type(other.field)  # noqa: E721
367
368


369
@dataclass(frozen=True, kw_only=True)
370
class BaseMultiModalField(ABC):
371
    """
372
373
374
    Defines how to interpret tensor data belonging to a keyword argument for
    [`MultiModalKwargsItems`][vllm.multimodal.inputs.MultiModalKwargsItems],
    and vice versa.
375
376
    """

377
378
    keep_on_cpu: bool = False
    """
379
380
381
    If `True`, then this field is excluded from being moved to the accelerator when
    [`group_and_batch_mm_items`][vllm.multimodal.utils.group_and_batch_mm_items]
    is called to batch the data.
382
383
    """

384
385
    def _field_factory(self):
        f = partial(MultiModalFieldElem, field=self)
386
387
388
389
390
391

        # Allow passing data as positional argument
        def factory(data: NestedTensors) -> MultiModalFieldElem:
            return f(data=data)

        return factory
392
393

    @abstractmethod
394
395
396
397
398
399
400
    def build_elems(
        self,
        modality: str,
        key: str,
        data: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        """
401
402
403
        Construct
        [`MultiModalFieldElem`][vllm.multimodal.inputs.MultiModalFieldElem]
        instances to represent the provided data.
404

405
406
        This is the inverse of
        [`reduce_data`][vllm.multimodal.inputs.BaseMultiModalField.reduce_data].
407
        """
408
409
        raise NotImplementedError

410
    @abstractmethod
411
412
413
414
415
416
    def _reduce_data(
        self,
        batch: list[NestedTensors],
        *,
        pin_memory: bool,
    ) -> NestedTensors:
417
        raise NotImplementedError
418

419
420
421
422
    def reduce_data(
        self,
        elems: list[MultiModalFieldElem],
        *,
423
        device: torch.types.Device = None,
424
425
        pin_memory: bool = False,
    ) -> NestedTensors:
426
        """
427
428
        Merge the data from multiple instances of
        [`MultiModalFieldElem`][vllm.multimodal.inputs.MultiModalFieldElem].
429

430
431
        This is the inverse of
        [`build_elems`][vllm.multimodal.inputs.BaseMultiModalField.build_elems].
432
433
434
435
        """
        field_types = [type(item.field) for item in elems]
        if len(set(field_types)) > 1:
            raise ValueError(f"Cannot merge different {field_types=}")
436

437
438
439
440
441
        if device is not None and self.keep_on_cpu:
            device = "cpu"
        if pin_memory and self.keep_on_cpu:
            pin_memory = False

442
        batch = [elem.data for elem in elems]
443
444
        out = self._reduce_data(batch, pin_memory=pin_memory)
        return _nested_tensors_h2d(out, device=device)
445
446


447
@dataclass(frozen=True, kw_only=True)
class MultiModalBatchedField(BaseMultiModalField):
    """
    Info:
        [`MultiModalFieldConfig.batched`][vllm.multimodal.inputs.MultiModalFieldConfig.batched]
    """

    def build_elems(
        self,
        modality: str,
        key: str,
        data: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        """Create one element per entry obtained by iterating over `data`."""
        make_elem = self._field_factory()
        elems = []
        for entry in data:
            elems.append(make_elem(entry))
        return elems

    def _reduce_data(
        self,
        batch: list[NestedTensors],
        *,
        pin_memory: bool,
    ) -> NestedTensors:
        """
        Stack the batch into one tensor when every entry is a tensor of the
        same shape; otherwise return the batch list unchanged.
        """
        if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"):
            tensors = cast(list[torch.Tensor], batch)
            head = tensors[0]

            if len(tensors) == 1:
                # An optimization when `batch` contains only one tensor:
                # - produce exactly same result as `torch.stack(batch)`
                # - will achieve zero-copy if the tensor is contiguous
                return head.unsqueeze(0).contiguous()

            if all(t.shape == head.shape for t in tensors):
                # Preallocate the destination so that `pin_memory` is honoured.
                stacked = torch.empty(
                    (len(tensors), *head.shape),
                    dtype=head.dtype,
                    device=head.device,
                    pin_memory=pin_memory,
                )
                return torch.stack(tensors, out=stacked)

        return batch


489
@dataclass(frozen=True, kw_only=True)
class MultiModalFlatField(BaseMultiModalField):
    """
    Info:
        [`MultiModalFieldConfig.flat`][vllm.multimodal.inputs.MultiModalFieldConfig.flat]
        [`MultiModalFieldConfig.flat_from_sizes`][vllm.multimodal.inputs.MultiModalFieldConfig.flat_from_sizes]
    """

    slices: Sequence[slice] | Sequence[Sequence[slice]]
    dim: int = 0

    def build_elems(
        self,
        modality: str,
        key: str,
        data: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        """Create one element per entry of `self.slices` by indexing `data`."""
        make_elem = self._field_factory()

        plain_slices = is_list_of(self.slices, slice, check="all")
        if not plain_slices:
            assert isinstance(data, torch.Tensor), (
                "torch.Tensor is required for multiple slices"
            )

        elems = []
        for item_slice in self.slices:
            elems.append(make_elem(data[cast(slice, item_slice)]))
        return elems

    def _reduce_data(
        self,
        batch: list[NestedTensors],
        *,
        pin_memory: bool,
    ) -> NestedTensors:
        """
        Concatenate the batch along `self.dim`.

        All-tensor batches are concatenated directly when their other
        dimensions agree, and otherwise copied into a zero-initialized
        tensor sized to the largest extent of each other dimension.
        Any other batch is flattened by one list level (requires `dim == 0`).
        """
        if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"):
            tensors = cast(list[torch.Tensor], batch)
            head = tensors[0]

            if len(tensors) == 1:
                # An optimization when `batch` contains only one tensor:
                # - produce exactly same result as `torch.concat(batch)`
                # - will achieve zero-copy if the tensor is contiguous
                return head.contiguous()

            # Turn a negative `dim` into its non-negative equivalent.
            dim = self.dim + (self.dim < 0) * len(head.shape)

            def _shape_before_after(tensor: torch.Tensor):
                return tensor.shape[:dim], tensor.shape[dim + 1 :]

            first_shape = _shape_before_after(head)

            if all(_shape_before_after(t) == first_shape for t in tensors):
                shape_before, shape_after = first_shape
                shape_concat = sum(t.shape[dim] for t in tensors)
                out = torch.empty(
                    (*shape_before, shape_concat, *shape_after),
                    dtype=head.dtype,
                    device=head.device,
                    pin_memory=pin_memory,
                )
                return torch.concat(tensors, dim=self.dim, out=out)

            # Variable-length case: non-concat dimensions differ
            # (e.g., Ultravox with different audio durations).
            # Use slice-assign approach (more efficient than padding).
            # See: https://github.com/vllm-project/vllm/issues/31658

            ndim = head.ndim

            # Output shape: sum over the concat dim, max over every other dim.
            out_shape: list[int] = [
                sum(t.shape[d] for t in tensors)
                if d == dim
                else max(t.shape[d] for t in tensors)
                for d in range(ndim)
            ]

            # Zero-initialized so that positions no tensor covers hold zeros.
            out = torch.zeros(
                out_shape,
                dtype=head.dtype,
                device=head.device,
                pin_memory=pin_memory,
            )

            # Copy each tensor into its own region of the output.
            concat_offset = 0
            for t in tensors:
                region = tuple(
                    slice(concat_offset, concat_offset + t.shape[d])
                    if d == dim
                    else slice(0, t.shape[d])
                    for d in range(ndim)
                )
                out[region] = t
                concat_offset += t.shape[dim]

            return out

        assert self.dim == 0, "dim == 0 is required for nested list"
        return [e for elem in batch for e in elem]
588
589


590
@dataclass(frozen=True, kw_only=True)
class MultiModalSharedField(BaseMultiModalField):
    """
    Info:
        [`MultiModalFieldConfig.shared`][vllm.multimodal.inputs.MultiModalFieldConfig.shared]
    """

    batch_size: int

    def build_elems(
        self,
        modality: str,
        key: str,
        data: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        """Return `batch_size` references to one element wrapping `data`."""
        make_elem = self._field_factory()
        # The element is created once; every list slot refers to that object.
        shared_elem = make_elem(data)
        return [shared_elem] * self.batch_size

    def _reduce_data(
        self,
        batch: list[NestedTensors],
        *,
        pin_memory: bool,
    ) -> NestedTensors:
        """Return the first entry of the batch; `pin_memory` is not used."""
        first = batch[0]
        return first


617
@dataclass(frozen=True)
class MultiModalFieldConfig:
    """
    Pairs a [`BaseMultiModalField`][vllm.multimodal.inputs.BaseMultiModalField]
    with the modality of the multi-modal items that use it.

    Instances are normally created through the static constructors
    `batched`, `flat`, `flat_from_sizes` and `shared`.
    """

    @staticmethod
    def batched(modality: str, *, keep_on_cpu: bool = False):
        """
        Defines a field where an element in the batch is obtained by
        indexing into the first dimension of the underlying data.

        Args:
            modality: The modality of the multi-modal item that uses this
                keyword argument.
            keep_on_cpu: Whether to keep this field on the CPU for the model inputs.

        Example:

        ```
        Input:
            Data: [[AAAA]
                [BBBB]
                [CCCC]]

        Output:
            Element 1: [AAAA]
            Element 2: [BBBB]
            Element 3: [CCCC]
        ```
        """
        return MultiModalFieldConfig(
            field=MultiModalBatchedField(keep_on_cpu=keep_on_cpu),
            modality=modality,
        )

    @staticmethod
    def flat(
        modality: str,
        slices: Sequence[slice] | Sequence[Sequence[slice]],
        dim: int = 0,
        *,
        keep_on_cpu: bool = False,
    ):
        """
        Defines a field where an element in the batch is obtained by
        slicing along dimension `dim` (the first dimension by default)
        of the underlying data.

        Args:
            modality: The modality of the multi-modal item that uses this
                keyword argument.
            slices: For each multi-modal item, a slice (dim=0) or a tuple of
                slices (dim>0) that is used to extract the data corresponding
                to it.
            dim: The dimension to extract data, default to 0.
            keep_on_cpu: Whether to keep this field on the CPU for the model inputs.

        Example:

        ```
        Given:
            slices: [slice(0, 3), slice(3, 7), slice(7, 9)]

        Input:
            Data: [AAABBBBCC]

        Output:
            Element 1: [AAA]
            Element 2: [BBBB]
            Element 3: [CC]
        ```

        ```
        Given:
            slices: [
                (slice(None), slice(0, 3)),
                (slice(None), slice(3, 7)),
                (slice(None), slice(7, 9))]
            dim: 1

        Input:
            Data: [[A],[A],[A],[B],[B],[B],[B],[C],[C]]

        Output:
            Element 1: [[A],[A],[A]]
            Element 2: [[B],[B],[B],[B]]
            Element 3: [[C],[C]]
        ```
        """
        return MultiModalFieldConfig(
            field=MultiModalFlatField(
                slices=slices,
                dim=dim,
                keep_on_cpu=keep_on_cpu,
            ),
            modality=modality,
        )

    @staticmethod
    def flat_from_sizes(
        modality: str,
        size_per_item: "torch.Tensor",
        dim: int = 0,
        *,
        keep_on_cpu: bool = False,
    ):
        """
        Defines a field where an element in the batch is obtained by
        slicing along dimension `dim` (the first dimension by default)
        of the underlying data.

        Args:
            modality: The modality of the multi-modal item that uses this
                keyword argument.
            size_per_item: For each multi-modal item, the size of the slice
                that is used to extract the data corresponding to it.
            dim: The dimension to slice, default to 0.
            keep_on_cpu: Whether to keep this field on the CPU for the model inputs.

        Raises:
            ValueError: If `size_per_item` is not a 1-D tensor.

        Example:

        ```
        Given:
            size_per_item: [3, 4, 2]

        Input:
            Data: [AAABBBBCC]

        Output:
            Element 1: [AAA]
            Element 2: [BBBB]
            Element 3: [CC]
        ```

        ```
        Given:
            size_per_item: [3, 4, 2]
            dim: 1

        Input:
            Data: [[A],[A],[A],[B],[B],[B],[B],[C],[C]]

        Output:
            Element 1: [[A],[A],[A]]
            Element 2: [[B],[B],[B],[B]]
            Element 3: [[C],[C]]
        ```

        Info:
            [`MultiModalFieldConfig.flat`][vllm.multimodal.inputs.MultiModalFieldConfig.flat]
        """

        if size_per_item.ndim != 1:
            raise ValueError(
                "size_per_item should be a 1-D tensor, "
                f"but found shape: {size_per_item.shape}"
            )

        # Convert to Python numbers once, so the slice bounds are plain values
        # rather than 0-d tensors produced by iterating over the tensor.
        sizes = size_per_item.tolist()
        slice_idxs = [0, *accumulate(sizes)]

        # Select everything on the dimensions before `dim`,
        # then the item's own range on `dim`.
        leading = (slice(None, None, None),) * dim
        slices = [
            leading + (slice(start, stop),)
            for start, stop in zip(slice_idxs, slice_idxs[1:])
        ]

        return MultiModalFieldConfig.flat(
            modality,
            slices,
            dim=dim,
            keep_on_cpu=keep_on_cpu,
        )

    @staticmethod
    def shared(
        modality: str,
        batch_size: int,
        *,
        keep_on_cpu: bool = False,
    ):
        """
        Defines a field where an element in the batch is obtained by
        taking the entirety of the underlying data.

        This means that the data is the same for each element in the batch.

        Args:
            modality: The modality of the multi-modal item that uses this
                keyword argument.
            batch_size: The number of multi-modal items which share this data.
            keep_on_cpu: Whether to keep this field on the CPU for the model inputs.

        Example:

        ```
        Given:
            batch_size: 4

        Input:
            Data: [XYZ]

        Output:
            Element 1: [XYZ]
            Element 2: [XYZ]
            Element 3: [XYZ]
            Element 4: [XYZ]
        ```
        """
        return MultiModalFieldConfig(
            field=MultiModalSharedField(
                batch_size=batch_size,
                keep_on_cpu=keep_on_cpu,
            ),
            modality=modality,
        )

    field: BaseMultiModalField
    modality: str

    def build_elems(
        self,
        key: str,
        batch: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        """Delegate to `self.field.build_elems` using this config's modality."""
        return self.field.build_elems(self.modality, key, batch)
836
837


838
839
class MultiModalKwargsItem(UserDict[str, MultiModalFieldElem]):
    """
    The processed keyword arguments to pass to the model for a single item in
    [`MultiModalDataItems`][vllm.multimodal.parse.MultiModalDataItems],
    stored as a dictionary.
    """

    @staticmethod
    def dummy(nbytes: int = 1):
        """Convenience constructor for testing.

        Builds an item with a single `"dummy"` entry holding an
        uninitialized `uint8` tensor of `nbytes` elements.
        """
        payload = torch.empty(nbytes, dtype=torch.uint8)
        shared_field = MultiModalSharedField(batch_size=1)
        mm_elem = MultiModalFieldElem(data=payload, field=shared_field)
        return MultiModalKwargsItem({"dummy": mm_elem})

    def get_data(self) -> dict[str, NestedTensors]:
        """Return the raw data of every element, keyed like this item."""
        data_by_key = {}
        for key, elem in self.items():
            data_by_key[key] = elem.data
        return data_by_key
856
857


858
859
860
# Type variable for the per-item value type held by `MultiModalKwargsItems`:
# either `MultiModalKwargsItem` or `MultiModalKwargsItem | None`,
# defaulting to `MultiModalKwargsItem` when left unspecified.
_I = TypeVar(
    "_I",
    MultiModalKwargsItem,
    MultiModalKwargsItem | None,
    default=MultiModalKwargsItem,
)


class MultiModalKwargsItems(UserDict[str, Sequence[_I]]):
    """
    A dictionary of processed multi-modal inputs by modality.

    For example, given a processor that processes
    images into `pixel_values` and `image_grid_thw`,
    and audios into `input_audio_features`,
    a prompt with 2 images and 1 audio will be processed
    into a `MultiModalKwargsItems` with the following structure:

    ```python
    MultiModalKwargsItems(
        {
            "image": [
                # For the first image
                MultiModalKwargsItem({"pixel_values": ..., "image_grid_thw": ...}),
                # For the second image
                MultiModalKwargsItem({"pixel_values": ..., "image_grid_thw": ...}),
            ],
            "audio": [
                # For the first audio
                MultiModalKwargsItem({"input_audio_features": ...}),
            ],
        }
    )
    ```

    Unlike HF processing which returns all items
    in a single dictionary with batched keyword arguments,
    we split up the items because some of them may already be cached.
    Also, items from multiple requests may be batched together to improve throughput,
    using the logic defined by the
    [`BaseMultiModalField`][vllm.multimodal.inputs.BaseMultiModalField]
    for each keyword argument.
    """

    @staticmethod
    def from_hf_inputs(
        hf_inputs: "BatchFeature",
        config_by_key: Mapping[str, MultiModalFieldConfig],
    ):
        """
        Split batched HF processor outputs into per-item keyword arguments.

        Raises:
            ValueError: If the keys of one modality yield different numbers
                of elements.
        """
        # NOTE: This skips fields in `hf_inputs` that are not in `config_by_key`
        # We assume that those fields are not used in vLLM
        elems_by_key = dict[str, Sequence[MultiModalFieldElem]]()
        keys_by_modality = defaultdict[str, set[str]](set)
        for key, config in config_by_key.items():
            batch = hf_inputs.get(key)
            if batch is None:
                continue

            elems = config.build_elems(key, batch)
            if len(elems) == 0:
                continue

            elems_by_key[key] = elems
            keys_by_modality[config.modality].add(key)

        items_by_modality = dict[str, list[MultiModalKwargsItem]]()
        for modality, keys in keys_by_modality.items():
            elems_in_modality = {k: elems_by_key[k] for k in keys}
            batch_sizes = {k: len(v) for k, v in elems_in_modality.items()}

            # Every key of a modality must describe the same number of items.
            distinct_sizes = set(batch_sizes.values())
            if len(distinct_sizes) > 1:
                raise ValueError(
                    f"Cannot merge different batch sizes for {modality=}! "
                    f"Found: {batch_sizes=}"
                )

            batch_size = next(iter(batch_sizes.values()))
            modality_items = []
            for i in range(batch_size):
                item_elems = {k: v[i] for k, v in elems_in_modality.items()}
                modality_items.append(MultiModalKwargsItem(item_elems))
            items_by_modality[modality] = modality_items

        return MultiModalKwargsItems(items_by_modality)

    def __getitem__(self, modality: str) -> Sequence[_I]:
        """Look up a modality, listing the available ones if it is missing."""
        if modality in self:
            return super().__getitem__(modality)  # type: ignore[return-value]

        raise KeyError(
            f"Modality {modality!r} not found. "
            f"Available modalities: {set(self.keys())}"
        )

    def require_data(self) -> "MultiModalKwargsItems[MultiModalKwargsItem]":
        """
        Return `self` after checking that no item is `None`.

        Raises:
            RuntimeError: For the first item found to be `None`.
        """
        for modality, items in self.items():
            for i, item in enumerate(items):
                if item is not None:
                    continue
                raise RuntimeError(f"Found empty mm_items[{modality}][{i}]")

        return self  # type: ignore[return-value]

    def get_data(
        self,
        *,
        device: torch.types.Device = None,
        pin_memory: bool = False,
    ) -> BatchedTensorInputs:
        """Construct a dictionary of keyword arguments to pass to the model."""
        from .utils import group_and_batch_mm_items

        items_by_modality = self.require_data()

        # Batch every non-empty modality first.
        batches_by_modality = {}
        for modality, items in items_by_modality.items():
            if len(items) == 0:
                continue

            grouped = group_and_batch_mm_items(
                items,
                device=device,
                pin_memory=pin_memory,
            )
            batches_by_modality[modality] = [data for _, data in grouped]

        out_data: BatchedTensorInputs = {}
        for batches in batches_by_modality.values():
            # Each modality has to reduce to exactly one batch.
            if len(batches) != 1:
                num_batches_by_modality = {
                    name: len(modality_batches)
                    for name, modality_batches in batches_by_modality.items()
                }

                raise RuntimeError(
                    f"Some modalities cannot be merged into a single batch "
                    f"({num_batches_by_modality=})"
                )

            out_data.update(batches[0])

        return out_data
994

995

996
997
998
999
# Either parameterization of `MultiModalKwargsItems`: one whose items are all
# `MultiModalKwargsItem`, or one whose items may also be `None`.
MultiModalKwargsOptionalItems: TypeAlias = (
    MultiModalKwargsItems[MultiModalKwargsItem]
    | MultiModalKwargsItems[MultiModalKwargsItem | None]
)