inputs.py 30.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from abc import ABC, abstractmethod
5
from collections import UserDict, defaultdict
6
from collections.abc import Mapping, Sequence
7
from dataclasses import dataclass
8
from functools import cached_property, partial
9
from itertools import accumulate
10
11
12
13
14
15
16
17
18
from typing import (
    TYPE_CHECKING,
    Any,
    Literal,
    TypeAlias,
    TypedDict,
    Union,
    cast,
)
19
20

import numpy as np
Roger Wang's avatar
Roger Wang committed
21
from PIL.Image import Image
22
from typing_extensions import TypeVar
23

24
from vllm.utils.collection_utils import is_list_of
25
from vllm.utils.import_utils import LazyLoader
26
from vllm.utils.jsontree import json_map_leaves
27

28
29
from .media import MediaWithBytes

30
if TYPE_CHECKING:
31
32
33
34
35
    import torch
    import torch.types
    from transformers.feature_extraction_utils import BatchFeature
else:
    torch = LazyLoader("torch", globals(), "torch")
36

37

38
HfImageItem: TypeAlias = Union[
    "Image",
    np.ndarray,
    "torch.Tensor",
]
"""
A single image item in any form accepted as
`transformers.image_utils.ImageInput`, i.e. something that can be passed
to a HuggingFace `ImageProcessor`.
"""

HfVideoItem: TypeAlias = Union[
    list["Image"],
    np.ndarray,
    "torch.Tensor",
    list[np.ndarray],
    list["torch.Tensor"],
]
"""
A single video item in any form accepted as
`transformers.image_utils.VideoInput`, i.e. something that can be passed
to a HuggingFace `VideoProcessor`.
"""

HfAudioItem: TypeAlias = Union[
    list[float],
    np.ndarray,
    "torch.Tensor",
]
"""
A single audio item which can be passed to a HuggingFace `AudioProcessor`.
"""

ImageItem: TypeAlias = Union[
    HfImageItem,
    "torch.Tensor",
    MediaWithBytes[HfImageItem],
]
"""
A single image item as accepted by vLLM.

Either a `transformers.image_utils.ImageInput` which can be passed to a
HuggingFace `ImageProcessor`,

or a 3-D tensor or batch of 2-D tensors, which are treated as image
embeddings; these are directly passed to the model without HF processing.
"""

VideoItem: TypeAlias = Union[
    HfVideoItem,
    "torch.Tensor",
    tuple[HfVideoItem, dict[str, Any]],
]
"""
A single video item as accepted by vLLM.

Either a `transformers.video_utils.VideoInput`, which can be passed to a
HuggingFace `VideoProcessor` with `transformers.video_utils.VideoMetadata`,

or a 3-D tensor or batch of 2-D tensors, which are treated as video
embeddings; these are directly passed to the model without HF processing.
"""

AudioItem: TypeAlias = Union[
    HfAudioItem,
    tuple[np.ndarray, float],
    "torch.Tensor",
]
"""
A single audio item as accepted by vLLM.

Either an audio item which can be passed to a HuggingFace `AudioProcessor`,

or a tuple `(audio, sampling_rate)`, where the sampling rate is different
from that expected by the model; these are resampled to the model's
sampling rate before being processed by HF,

or a 3-D tensor or batch of 2-D tensors, which are treated as audio
embeddings; these are directly passed to the model without HF processing.
"""

95

Roger Wang's avatar
Roger Wang committed
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
class VisionChunkImage(TypedDict):
    """
    Represents an image wrapped as a vision chunk.

    The `type` key is fixed to `"image"`, which distinguishes this variant
    from `VisionChunkVideo` (whose `type` is `"video_chunk"`).
    """

    # Tag identifying this variant; always the literal "image".
    type: Literal["image"]
    # The image payload (a `PIL.Image.Image`).
    image: Image
    # Optional string identifier; may be None.
    # NOTE(review): how the uuid is produced/consumed is not visible here.
    uuid: str | None


class VisionChunkVideo(TypedDict):
    """
    Represents a video chunk with metadata.

    The `type` key is fixed to `"video_chunk"`, which distinguishes this
    variant from `VisionChunkImage` (whose `type` is `"image"`).
    """

    # Tag identifying this variant; always the literal "video_chunk".
    type: Literal["video_chunk"]
    # The chunk payload as a list of `PIL.Image.Image` objects
    # (presumably the frames of the chunk -- TODO confirm).
    video_chunk: list[Image]
    # Optional string identifier; may be None.
    # NOTE(review): how the uuid is produced/consumed is not visible here.
    uuid: str | None
    # Text associated with this chunk
    # (presumably the prompt segment for the chunk -- verify against caller).
    prompt: str
    # Integer index
    # (presumably of the source video this chunk belongs to -- TODO confirm).
    video_idx: int


114
VisionChunk: TypeAlias = (
    VisionChunkImage
    | VisionChunkVideo
)
"""
A vision chunk: either an image (`VisionChunkImage`)
or a video chunk (`VisionChunkVideo`).
"""


118
119
@dataclass(frozen=True)
class PlaceholderRange:
120
121
122
    """
    Placeholder location information for multi-modal data.

123
124
    Example:

125
    Prompt: `AAAA BBBB What is in these images?`
126

127
    Images A and B will have:
128

129
130
131
132
    ```
    A: PlaceholderRange(offset=0, length=4)
    B: PlaceholderRange(offset=5, length=4)
    ```
133
134
135
136
137
138
139
140
    """

    offset: int
    """The start index of the placeholder in the prompt."""

    length: int
    """The length of the placeholder."""

141
    is_embed: "torch.Tensor | None" = None
142
143
144
145
146
    """
    A boolean mask of shape `(length,)` indicating which positions
    between `offset` and `offset + length` to assign embeddings to.
    """

147
148
    @cached_property
    def embeds_cumsum(self) -> torch.Tensor | None:
149
        return None if self.is_embed is None else self.is_embed.cumsum(dim=0)
150
151
152

    def get_num_embeds(self) -> int:
        if self.embeds_cumsum is None:
153
154
            return self.length

155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
        return int(self.embeds_cumsum[-1])

    def get_embeds_indices_in_range(
        self, start_idx: int, end_idx: int
    ) -> tuple[int, int]:
        """
        Returns the starting and ending indices of the embeddings of encoder outputs
        in the range of [start_idx, end_idx) in the placeholders.

        For example, given:
        PlaceholderRange(offset=2, length=5, is_embed=[False, True, False, True, True])

        If start_idx=3 and end_idx=5, the output is (1, 3) because we want to get
        the second and the third embeddings from the encoder output.
        """
        if self.embeds_cumsum is None:
            return start_idx, end_idx

        embeds_start_idx = (
            int(self.embeds_cumsum[start_idx - 1]) if start_idx > 0 else 0
        )
        embeds_end_idx = int(self.embeds_cumsum[end_idx - 1])

        return embeds_start_idx, embeds_end_idx
179

180
181
182
183
184
185
186
187
188
189
190
191
192
    def extract_embeds_range(self) -> list[tuple[int, int]]:
        """Extract the start and end indices of the embedded region in prompt.

        For example, given `PlaceholderRange(offset=2, length=5)` and
        `is_embed = [False, True, False, True, True]`, the output is
        `[(1 + offset, 1 + offset), (3 + offset, 4 + offset)]`.

        Returns:
            A tuple `(start, end)` representing the start and end
            indices (inclusive) of the embedded region.
            Returns full placeholder range if `is_embed` is `None`.
        """
        if self.is_embed is None:
193
            return [(self.offset, self.offset + self.length - 1)]
194
195
196
197
198
199
200
201
202
203
204

        mask_i = self.is_embed.int()
        starts = torch.nonzero(
            torch.diff(mask_i, prepend=mask_i.new_zeros(1)) == 1
        ).flatten()
        ends = torch.nonzero(
            torch.diff(mask_i, append=mask_i.new_zeros(1)) == -1
        ).flatten()
        ranges = torch.stack((starts, ends), dim=1) + self.offset
        return [tuple(x) for x in ranges.tolist()]

205
206
207
208
209
210
211
212
213
214
215
216
217
    def __eq__(self, other: object) -> bool:
        if not isinstance(other, self.__class__):
            return False
        if not (self.offset, self.length) == (other.offset, other.length):
            return False

        if self.is_embed is None:
            return other.is_embed is None
        if other.is_embed is None:
            return self.is_embed is None

        return nested_tensors_equal(self.is_embed, other.is_embed)

218

219
220
221
222
223
224
NestedTensors: TypeAlias = Union[
    list["NestedTensors"],
    list["torch.Tensor"],
    "torch.Tensor",
    tuple["torch.Tensor", ...],
]
225
226
227
228
"""
Uses a list instead of a tensor if the dimensions of each element do not match.
"""

229
230

def nested_tensors_equal(a: NestedTensors, b: NestedTensors) -> bool:
    """
    Equality check between
    [`NestedTensors`][vllm.multimodal.inputs.NestedTensors] objects.

    Two values are equal when they have the same kind (tensor, list or
    tuple) and, recursively, equal contents. A tensor never equals a
    container, and a list never equals a tuple. Anything that is not a
    tensor, list or tuple is compared with `==`.
    """
    # Tensors: both sides must be tensors with identical size and values.
    if isinstance(a, torch.Tensor) or isinstance(b, torch.Tensor):
        return (
            isinstance(a, torch.Tensor)
            and isinstance(b, torch.Tensor)
            and torch.equal(a, b)
        )

    # Containers: both sides must be the same kind of container, have the
    # same length (checked first because `zip` stops at the shorter one),
    # and be element-wise equal.
    for container_type in (list, tuple):
        if isinstance(a, container_type) or isinstance(b, container_type):
            return (
                isinstance(a, container_type)
                and isinstance(b, container_type)
                and len(a) == len(b)
                and all(nested_tensors_equal(a_, b_) for a_, b_ in zip(a, b))
            )

    # Both a and b are scalars
    return a == b


270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
def _nested_tensors_h2d(
    tensors: NestedTensors,
    device: torch.types.Device,
) -> NestedTensors:
    if device is None:
        return tensors

    return json_map_leaves(
        (
            lambda x: x.to(device=device, non_blocking=True)
            if isinstance(x, torch.Tensor)
            else x
        ),
        tensors,
    )


287
BatchedTensorInputs: TypeAlias = dict[str, NestedTensors]
"""
Maps each keyword argument name to its nested tensors, after batching via
[`MultiModalKwargsItems.get_data`][vllm.multimodal.inputs.MultiModalKwargsItems.get_data].
"""


294
295
296
297
298
def batched_tensors_equal(a: BatchedTensorInputs, b: BatchedTensorInputs) -> bool:
    """
    Equality check between
    [`BatchedTensorInputs`][vllm.multimodal.inputs.BatchedTensorInputs] objects.

    The two dictionaries are equal when they have exactly the same keys
    and the nested tensors stored under each key are equal.
    """
    # Compare the key sets in both directions; checking only the keys of
    # `a` would treat `b` as equal even if it carried extra entries.
    if a.keys() != b.keys():
        return False

    return all(nested_tensors_equal(a[k], b[k]) for k in a)
300
301


302
303
304
305
@dataclass
class MultiModalFeatureSpec:
    """
    A single multimodal input together with its processed data and metadata.

    Instances are used to track multimodal data through processing and
    caching. A request with several multimodal items carries one
    `MultiModalFeatureSpec` per item.
    """

    data: "MultiModalKwargsItem | None"
    """
    The multimodal data for this feature.

    May be `None` if the item is cached, so that IPC between the API server
    and engine core processes can be skipped.
    """

    modality: str
    """The input modality, e.g., `"image"`, `"audio"`, `"video"`."""

    identifier: str
    """The hash for caching encoder outputs (with LoRA prefix if applicable)."""

    mm_position: PlaceholderRange
    """
    Where the `modality` tokens of this item are located in the prompt,
    e.g., `PlaceholderRange(offset=2, length=336)`.
    """

    mm_hash: str | None = None
    """The hash for caching processor outputs (without LoRA prefix)."""

    @staticmethod
    def gather_kwargs(features: list["MultiModalFeatureSpec"], keys: set[str]):
        """
        Collect, per requested key, the field data of every feature.

        Features whose `data` is `None` are skipped, as are keys that a
        feature's item does not contain. Only keys with at least one
        collected value appear in the returned dictionary.
        """
        gathered = defaultdict[str, list[NestedTensors]](list)

        for feature in features:
            item = feature.data
            if item is None:
                continue

            for key in keys:
                if key not in item:
                    continue
                gathered[key].append(item[key].data)

        return dict(gathered)

348

349
@dataclass
class MultiModalFieldElem:
    """
    Represents a processed keyword argument to pass to a model for a
    [`MultiModalKwargsItem`][vllm.multimodal.inputs.MultiModalKwargsItem].
    """

    data: NestedTensors
    """
    The tensor data of this field in
    [`MultiModalKwargsItem`][vllm.multimodal.inputs.MultiModalKwargsItem],
    i.e. the value of the keyword argument to be passed to the model.

    It may be set to `None` if it is determined that the item is cached
    in `EngineCore`.
    """

    field: "BaseMultiModalField"
    """
    Defines how to combine the tensor data of this field with others
    in order to batch multi-modal items together for model inference.
    """

    def __eq__(self, other: object) -> bool:
        """
        Two elements are equal when their data is equal (both `None`, or
        equal nested tensors) and their fields are of exactly the same type.
        """
        if not isinstance(other, self.__class__):
            return False

        if self.data is None or other.data is None:
            # Missing data only matches missing data.
            data_equal = self.data is None and other.data is None
        else:
            data_equal = nested_tensors_equal(self.data, other.data)

        # Only the field *type* is compared, not the field's own attributes.
        return data_equal and type(self.field) is type(other.field)  # noqa: E721
384
385


386
@dataclass(frozen=True, kw_only=True)
class BaseMultiModalField(ABC):
    """
    Describes how the tensor data of one keyword argument is split into, and
    merged back from, per-item elements of
    [`MultiModalKwargsItems`][vllm.multimodal.inputs.MultiModalKwargsItems].
    """

    keep_on_cpu: bool = False
    """
    If `True`, then this field is excluded from being moved to the accelerator when
    [`group_and_batch_mm_items`][vllm.multimodal.utils.group_and_batch_mm_items]
    is called to batch the data.
    """

    def _field_factory(self):
        """
        Return a one-argument callable that wraps `data` in a
        `MultiModalFieldElem` bound to this field.
        """

        # Lets callers pass the data positionally.
        def factory(data: NestedTensors) -> MultiModalFieldElem:
            return MultiModalFieldElem(data=data, field=self)

        return factory

    @abstractmethod
    def build_elems(
        self,
        modality: str,
        key: str,
        data: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        """
        Split `data` into
        [`MultiModalFieldElem`][vllm.multimodal.inputs.MultiModalFieldElem]
        instances.

        This is the inverse of
        [`reduce_data`][vllm.multimodal.inputs.BaseMultiModalField.reduce_data].
        """
        raise NotImplementedError

    @abstractmethod
    def _reduce_data(
        self,
        batch: list[NestedTensors],
        *,
        pin_memory: bool,
    ) -> NestedTensors:
        """Merge the per-element data in `batch`; implemented by subclasses."""
        raise NotImplementedError

    def reduce_data(
        self,
        elems: list[MultiModalFieldElem],
        *,
        device: torch.types.Device = None,
        pin_memory: bool = False,
    ) -> NestedTensors:
        """
        Combine the data of several
        [`MultiModalFieldElem`][vllm.multimodal.inputs.MultiModalFieldElem]
        instances, then move the result to `device` (if given).

        This is the inverse of
        [`build_elems`][vllm.multimodal.inputs.BaseMultiModalField.build_elems].

        Raises:
            ValueError: If the elements do not all share one field type.
        """
        field_types = [type(item.field) for item in elems]
        if len(set(field_types)) > 1:
            raise ValueError(f"Cannot merge different {field_types=}")

        if self.keep_on_cpu:
            # CPU-only fields ignore the requested accelerator and pinning.
            if device is not None:
                device = "cpu"
            if pin_memory:
                pin_memory = False

        merged = self._reduce_data(
            [elem.data for elem in elems],
            pin_memory=pin_memory,
        )
        return _nested_tensors_h2d(merged, device=device)
462
463


464
@dataclass(frozen=True, kw_only=True)
class MultiModalBatchedField(BaseMultiModalField):
    """
    Info:
        [`MultiModalFieldConfig.batched`][vllm.multimodal.inputs.MultiModalFieldConfig.batched]
    """

    def build_elems(
        self,
        modality: str,
        key: str,
        data: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        """Create one element per entry along the first dimension of `data`."""
        make_elem = self._field_factory()
        return list(map(make_elem, data))

    def _reduce_data(
        self,
        batch: list[NestedTensors],
        *,
        pin_memory: bool,
    ) -> NestedTensors:
        """
        Stack the batch into one tensor when every entry is a tensor of the
        same shape; otherwise return the batch list unchanged.
        """
        if not batch or not is_list_of(batch, torch.Tensor, check="all"):
            return batch

        tensors = cast(list[torch.Tensor], batch)
        if len(tensors) == 1:
            # An optimization when the batch contains only one tensor:
            # - produce exactly same result as `torch.stack(batch)`
            # - will achieve zero-copy if the tensor is contiguous
            return tensors[0].unsqueeze(0).contiguous()

        head = tensors[0]
        if any(t.shape != head.shape for t in tensors):
            # Shapes differ, so the entries cannot be stacked.
            return tensors

        stacked = torch.empty(
            (len(tensors), *head.shape),
            dtype=head.dtype,
            device=head.device,
            pin_memory=pin_memory,
        )
        return torch.stack(tensors, out=stacked)


506
@dataclass(frozen=True, kw_only=True)
class MultiModalFlatField(BaseMultiModalField):
    """
    Info:
        [`MultiModalFieldConfig.flat`][vllm.multimodal.inputs.MultiModalFieldConfig.flat]
        [`MultiModalFieldConfig.flat_from_sizes`][vllm.multimodal.inputs.MultiModalFieldConfig.flat_from_sizes]
    """

    slices: Sequence[slice] | Sequence[Sequence[slice]]
    dim: int = 0

    def build_elems(
        self,
        modality: str,
        key: str,
        data: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        """Create one element per entry of `slices` by indexing `data` with it."""
        make_elem = self._field_factory()

        has_multi_slices = not is_list_of(self.slices, slice, check="all")
        if has_multi_slices:
            assert isinstance(data, torch.Tensor), (
                "torch.Tensor is required for multiple slices"
            )

        return [make_elem(data[cast(slice, s)]) for s in self.slices]

    def _reduce_data(
        self,
        batch: list[NestedTensors],
        *,
        pin_memory: bool,
    ) -> NestedTensors:
        """
        Concatenate the batch along `dim`.

        For a batch of tensors the result is a single tensor; when the
        other dimensions differ, the tensors are written into a
        zero-initialized output sized to the largest extent. A batch that
        is not made only of tensors is flattened by one level
        (requires `dim == 0`).
        """
        if batch and is_list_of(batch, torch.Tensor, check="all"):
            tensors = cast(list[torch.Tensor], batch)
            if len(tensors) == 1:
                # An optimization when the batch contains only one tensor:
                # - produce exactly same result as `torch.concat(batch)`
                # - will achieve zero-copy if the tensor is contiguous
                return tensors[0].contiguous()

            head = tensors[0]
            # Resolve a negative `dim` against the first tensor's rank.
            dim = self.dim + len(head.shape) if self.dim < 0 else self.dim

            def _shape_around_dim(tensor: torch.Tensor):
                return tensor.shape[:dim], tensor.shape[dim + 1 :]

            reference = _shape_around_dim(head)

            if all(_shape_around_dim(t) == reference for t in tensors):
                leading, trailing = reference
                concat_len = sum(t.shape[dim] for t in tensors)
                out = torch.empty(
                    (*leading, concat_len, *trailing),
                    dtype=head.dtype,
                    device=head.device,
                    pin_memory=pin_memory,
                )
                return torch.concat(tensors, dim=self.dim, out=out)

            # Variable-length case: non-concat dimensions differ
            # (e.g., Ultravox with different audio durations).
            # Use slice-assign approach (more efficient than padding).
            # See: https://github.com/vllm-project/vllm/issues/31658
            ndim = head.ndim

            # Output extent per dimension: the sum along the concat
            # dimension, the maximum along every other dimension.
            out_shape: list[int] = [
                sum(t.shape[d] for t in tensors)
                if d == dim
                else max(t.shape[d] for t in tensors)
                for d in range(ndim)
            ]

            # Zero-initialized so regions not covered by a tensor stay 0.
            out = torch.zeros(
                out_shape,
                dtype=head.dtype,
                device=head.device,
                pin_memory=pin_memory,
            )

            # Write each tensor at its running offset along the concat dim,
            # starting from index 0 along every other dimension.
            offset = 0
            for t in tensors:
                region = tuple(
                    slice(offset, offset + t.shape[d])
                    if d == dim
                    else slice(0, t.shape[d])
                    for d in range(ndim)
                )
                out[region] = t
                offset += t.shape[dim]

            return out

        assert self.dim == 0, "dim == 0 is required for nested list"
        return [e for elem in batch for e in elem]
605
606


607
@dataclass(frozen=True, kw_only=True)
class MultiModalSharedField(BaseMultiModalField):
    """
    Info:
        [`MultiModalFieldConfig.shared`][vllm.multimodal.inputs.MultiModalFieldConfig.shared]
    """

    batch_size: int

    def build_elems(
        self,
        modality: str,
        key: str,
        data: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        """
        Wrap `data` in a single element and repeat it `batch_size` times
        (every entry of the result is the same element object).
        """
        shared_elem = self._field_factory()(data)
        return [shared_elem] * self.batch_size

    def _reduce_data(
        self,
        batch: list[NestedTensors],
        *,
        pin_memory: bool,
    ) -> NestedTensors:
        """Return the first entry of the batch; `pin_memory` is not used."""
        return batch[0]


634
@dataclass(frozen=True)
class MultiModalFieldConfig:
    """
    Pairs a [`BaseMultiModalField`][vllm.multimodal.inputs.BaseMultiModalField]
    with the modality of the multi-modal items that use it.

    Instances are normally created through the static constructors
    `batched`, `flat`, `flat_from_sizes` and `shared`.
    """

    field: BaseMultiModalField
    modality: str

    @staticmethod
    def batched(modality: str, *, keep_on_cpu: bool = False):
        """
        Defines a field where an element in the batch is obtained by
        indexing into the first dimension of the underlying data.

        Args:
            modality: The modality of the multi-modal item that uses this
                keyword argument.
            keep_on_cpu: Whether to keep this field on the CPU for the model inputs.

        Example:

        ```
        Input:
            Data: [[AAAA]
                [BBBB]
                [CCCC]]

        Output:
            Element 1: [AAAA]
            Element 2: [BBBB]
            Element 3: [CCCC]
        ```
        """
        batched_field = MultiModalBatchedField(keep_on_cpu=keep_on_cpu)
        return MultiModalFieldConfig(field=batched_field, modality=modality)

    @staticmethod
    def flat(
        modality: str,
        slices: Sequence[slice] | Sequence[Sequence[slice]],
        dim: int = 0,
        *,
        keep_on_cpu: bool = False,
    ):
        """
        Defines a field where an element in the batch is obtained by
        slicing along dimension `dim` of the underlying data.

        Args:
            modality: The modality of the multi-modal item that uses this
                keyword argument.
            slices: For each multi-modal item, a slice (dim=0) or a tuple of
                slices (dim>0) that is used to extract the data corresponding
                to it.
            dim: The dimension to extract data, default to 0.
            keep_on_cpu: Whether to keep this field on the CPU for the model inputs.

        Example:

        ```
        Given:
            slices: [slice(0, 3), slice(3, 7), slice(7, 9)]

        Input:
            Data: [AAABBBBCC]

        Output:
            Element 1: [AAA]
            Element 2: [BBBB]
            Element 3: [CC]
        ```

        ```
        Given:
            slices: [
                (slice(None), slice(0, 3)),
                (slice(None), slice(3, 7)),
                (slice(None), slice(7, 9))]
            dim: 1

        Input:
            Data (shape 1 x 9): [[AAABBBBCC]]

        Output:
            Element 1: [[AAA]]
            Element 2: [[BBBB]]
            Element 3: [[CC]]
        ```
        """
        flat_field = MultiModalFlatField(
            slices=slices,
            dim=dim,
            keep_on_cpu=keep_on_cpu,
        )
        return MultiModalFieldConfig(field=flat_field, modality=modality)

    @staticmethod
    def flat_from_sizes(
        modality: str,
        size_per_item: "torch.Tensor",
        dim: int = 0,
        *,
        keep_on_cpu: bool = False,
    ):
        """
        Defines a field where an element in the batch is obtained by
        slicing along dimension `dim` of the underlying data, with the
        slice boundaries derived from consecutive item sizes.

        Args:
            modality: The modality of the multi-modal item that uses this
                keyword argument.
            size_per_item: For each multi-modal item, the size of the slice
                that is used to extract the data corresponding to it.
            dim: The dimension to slice, default to 0.
            keep_on_cpu: Whether to keep this field on the CPU for the model inputs.

        Raises:
            ValueError: If `size_per_item` is not 1-D.

        Example:

        ```
        Given:
            size_per_item: [3, 4, 2]

        Input:
            Data: [AAABBBBCC]

        Output:
            Element 1: [AAA]
            Element 2: [BBBB]
            Element 3: [CC]
        ```

        ```
        Given:
            size_per_item: [3, 4, 2]
            dim: 1

        Input:
            Data (shape 1 x 9): [[AAABBBBCC]]

        Output:
            Element 1: [[AAA]]
            Element 2: [[BBBB]]
            Element 3: [[CC]]
        ```

        Info:
            [`MultiModalFieldConfig.flat`][vllm.multimodal.inputs.MultiModalFieldConfig.flat]
        """
        if size_per_item.ndim != 1:
            raise ValueError(
                "size_per_item should be a 1-D tensor, "
                f"but found shape: {size_per_item.shape}"
            )

        # Running totals of the sizes give the boundary of each item.
        boundaries = [0, *accumulate(size_per_item)]
        # Take everything along the dimensions that precede `dim`.
        leading = (slice(None, None, None),) * dim
        slices = [
            leading + (slice(start, stop),)
            for start, stop in zip(boundaries, boundaries[1:])
        ]

        return MultiModalFieldConfig.flat(
            modality,
            slices,
            dim=dim,
            keep_on_cpu=keep_on_cpu,
        )

    @staticmethod
    def shared(
        modality: str,
        batch_size: int,
        *,
        keep_on_cpu: bool = False,
    ):
        """
        Defines a field where an element in the batch is obtained by
        taking the entirety of the underlying data.

        This means that the data is the same for each element in the batch.

        Args:
            modality: The modality of the multi-modal item that uses this
                keyword argument.
            batch_size: The number of multi-modal items which share this data.
            keep_on_cpu: Whether to keep this field on the CPU for the model inputs.

        Example:

        ```
        Given:
            batch_size: 4

        Input:
            Data: [XYZ]

        Output:
            Element 1: [XYZ]
            Element 2: [XYZ]
            Element 3: [XYZ]
            Element 4: [XYZ]
        ```
        """
        shared_field = MultiModalSharedField(
            batch_size=batch_size,
            keep_on_cpu=keep_on_cpu,
        )
        return MultiModalFieldConfig(field=shared_field, modality=modality)

    def build_elems(
        self,
        key: str,
        batch: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        """Delegate to `field.build_elems`, supplying this config's modality."""
        return self.field.build_elems(self.modality, key, batch)
853
854


855
856
class MultiModalKwargsItem(UserDict[str, MultiModalFieldElem]):
    """
    The processed keyword arguments to pass to the model for a single item in
    [`MultiModalDataItems`][vllm.multimodal.parse.MultiModalDataItems],
    keyed by keyword argument name.
    """

    @staticmethod
    def dummy(nbytes: int = 1):
        """
        Convenience constructor for testing.

        Builds an item with a single `"dummy"` entry holding an
        uninitialized `uint8` tensor of `nbytes` elements.
        """
        placeholder = torch.empty(nbytes, dtype=torch.uint8)
        shared_field = MultiModalSharedField(batch_size=1)
        mm_elem = MultiModalFieldElem(data=placeholder, field=shared_field)
        return MultiModalKwargsItem({"dummy": mm_elem})

    def get_data(self) -> dict[str, NestedTensors]:
        """Return a plain dictionary mapping each key to its element's data."""
        data_by_key: dict[str, NestedTensors] = {}
        for key, elem in self.items():
            data_by_key[key] = elem.data
        return data_by_key
873
874


875
876
877
# Item type held by `MultiModalKwargsItems`: either always present, or
# possibly `None`. Defaults to the always-present form.
_I = TypeVar(
    "_I",
    MultiModalKwargsItem,
    MultiModalKwargsItem | None,
    default=MultiModalKwargsItem,
)


class MultiModalKwargsItems(UserDict[str, Sequence[_I]]):
    """
    A dictionary of processed multi-modal inputs by modality.

    For example, given a processor that processes
    images into `pixel_values` and `image_grid_thw`,
    and audios into `input_audio_features`,
    a prompt with 2 images and 1 audio will be processed
    into a `MultiModalKwargsItems` with the following structure:

    ```python
    MultiModalKwargsItems(
        {
            "image": [
                # For the first image
                MultiModalKwargsItem({"pixel_values": ..., "image_grid_thw": ...}),
                # For the second image
                MultiModalKwargsItem({"pixel_values": ..., "image_grid_thw": ...}),
            ],
            "audio": [
                # For the first audio
                MultiModalKwargsItem({"input_audio_features": ...}),
            ],
        }
    )
    ```

    Unlike HF processing which returns all items
    in a single dictionary with batched keyword arguments,
    we split up the items because some of them may already be cached.
    Also, items from multiple requests may be batched together to improve throughput,
    using the logic defined by the
    [`BaseMultiModalField`][vllm.multimodal.inputs.BaseMultiModalField]
    for each keyword argument.
    """

    @staticmethod
    def from_hf_inputs(
        hf_inputs: "BatchFeature",
        config_by_key: Mapping[str, MultiModalFieldConfig],
    ):
        """
        Split batched HF processor outputs into per-item kwargs, grouped
        by the modality that each key's config declares.

        Raises:
            ValueError: If the keys of one modality yield different
                numbers of elements.
        """
        # NOTE: This skips fields in `hf_inputs` that are not in `config_by_key`
        # We assume that those fields are not used in vLLM
        elems_by_key = dict[str, Sequence[MultiModalFieldElem]]()
        keys_by_modality = defaultdict[str, set[str]](set)
        for key, config in config_by_key.items():
            batch = hf_inputs.get(key)
            if batch is None:
                continue

            elems = config.build_elems(key, batch)
            if len(elems) == 0:
                continue

            elems_by_key[key] = elems
            keys_by_modality[config.modality].add(key)

        items_by_modality = dict[str, list[MultiModalKwargsItem]]()
        for modality, keys in keys_by_modality.items():
            elems_in_modality = {k: elems_by_key[k] for k in keys}
            batch_sizes = {k: len(v) for k, v in elems_in_modality.items()}

            # Every key of a modality must describe the same number of items.
            if len(set(batch_sizes.values())) > 1:
                raise ValueError(
                    f"Cannot merge different batch sizes for {modality=}! "
                    f"Found: {batch_sizes=}"
                )

            num_items = next(iter(batch_sizes.values()))
            modality_items: list[MultiModalKwargsItem] = []
            for idx in range(num_items):
                item_elems = {k: v[idx] for k, v in elems_in_modality.items()}
                modality_items.append(MultiModalKwargsItem(item_elems))
            items_by_modality[modality] = modality_items

        return MultiModalKwargsItems(items_by_modality)

    def __getitem__(self, modality: str) -> Sequence[_I]:
        """
        Return the items of `modality`, raising a `KeyError` that lists
        the available modalities when it is missing.
        """
        if modality in self:
            return super().__getitem__(modality)  # type: ignore[return-value]

        raise KeyError(
            f"Modality {modality!r} not found. "
            f"Available modalities: {set(self.keys())}"
        )

    def require_data(self) -> "MultiModalKwargsItems[MultiModalKwargsItem]":
        """
        Return `self` after checking that no item is `None`.

        Raises:
            RuntimeError: On the first `None` item found.
        """
        for modality, items in self.items():
            for i, item in enumerate(items):
                if item is None:
                    raise RuntimeError(f"Found empty mm_items[{modality}][{i}]")

        return self  # type: ignore[return-value]

    def get_data(
        self,
        *,
        device: torch.types.Device = None,
        pin_memory: bool = False,
    ) -> BatchedTensorInputs:
        """
        Construct a dictionary of keyword arguments to pass to the model.

        Raises:
            RuntimeError: If an item is `None`, or if the items of some
                modality do not form exactly one batch.
        """
        from .utils import group_and_batch_mm_items

        items_by_modality = self.require_data()

        # Modalities without any items are left out entirely.
        batches_by_modality = {}
        for modality, items in items_by_modality.items():
            if len(items) == 0:
                continue

            grouped = group_and_batch_mm_items(
                items,
                device=device,
                pin_memory=pin_memory,
            )
            batches_by_modality[modality] = [data for _, data in grouped]

        out_data: BatchedTensorInputs = {}
        for batches in batches_by_modality.values():
            if len(batches) != 1:
                num_batches_by_modality = {
                    modality: len(modality_batches)
                    for modality, modality_batches in batches_by_modality.items()
                }

                raise RuntimeError(
                    f"Some modalities cannot be merged into a single batch "
                    f"({num_batches_by_modality=})"
                )

            out_data.update(batches[0])

        return out_data
1011

1012

1013
1014
1015
1016
# A `MultiModalKwargsItems` whose items are either all present
# (`MultiModalKwargsItem`) or possibly missing (`MultiModalKwargsItem | None`).
MultiModalKwargsOptionalItems: TypeAlias = (
    MultiModalKwargsItems[MultiModalKwargsItem]
    | MultiModalKwargsItems[MultiModalKwargsItem | None]
)