# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from abc import ABC, abstractmethod
from collections import UserDict, defaultdict
from collections.abc import Mapping, Sequence
from dataclasses import dataclass
from functools import cached_property, partial
from itertools import accumulate
from typing import (
    TYPE_CHECKING,
    Any,
    Literal,
    TypeAlias,
    TypedDict,
    Union,
    cast,
)

import numpy as np
from PIL.Image import Image
from typing_extensions import TypeVar

from vllm.utils.collection_utils import is_list_of
from vllm.utils.import_utils import LazyLoader
from vllm.utils.jsontree import json_map_leaves

from .media import MediaWithBytes

if TYPE_CHECKING:
    import torch
    import torch.types
    from transformers.feature_extraction_utils import BatchFeature
else:
    torch = LazyLoader("torch", globals(), "torch")
36

37

38
HfImageItem: TypeAlias = Union["Image", np.ndarray, "torch.Tensor"]
"""
A `transformers.image_utils.ImageInput` representing a single image
item, which can be passed to a HuggingFace `ImageProcessor`.
"""

HfVideoItem: TypeAlias = Union[
    list["Image"], np.ndarray, "torch.Tensor", list[np.ndarray], list["torch.Tensor"]
]
"""
A `transformers.image_utils.VideoInput` representing a single video
item, which can be passed to a HuggingFace `VideoProcessor`.
"""

HfAudioItem: TypeAlias = Union[list[float], np.ndarray, "torch.Tensor"]
"""
Represents a single audio
item, which can be passed to a HuggingFace `AudioProcessor`.
"""

ImageItem: TypeAlias = Union[HfImageItem, "torch.Tensor", MediaWithBytes[HfImageItem]]
"""
A `transformers.image_utils.ImageInput` representing a single image
item, which can be passed to a HuggingFace `ImageProcessor`.

Alternatively, a 3-D tensor or batch of 2-D tensors,
which are treated as image embeddings;
these are directly passed to the model without HF processing.
"""

VideoItem: TypeAlias = Union[
    HfVideoItem, "torch.Tensor", tuple[HfVideoItem, dict[str, Any]]
]
"""
A `transformers.video_utils.VideoInput` representing a single video item.
This can be passed to a HuggingFace `VideoProcessor`
with `transformers.video_utils.VideoMetadata`.

Alternatively, a 3-D tensor or batch of 2-D tensors,
which are treated as video embeddings;
these are directly passed to the model without HF processing.
"""

AudioItem: TypeAlias = Union[HfAudioItem, tuple[np.ndarray, float], "torch.Tensor"]
"""
Represents a single audio
item, which can be passed to a HuggingFace `AudioProcessor`.

Alternatively, a tuple `(audio, sampling_rate)`, where the sampling rate
is different from that expected by the model;
these are resampled to the model's sampling rate before being processed by HF.

Alternatively, a 3-D tensor or batch of 2-D tensors,
which are treated as audio embeddings;
these are directly passed to the model without HF processing.
"""


class VisionChunkImage(TypedDict):
    """Represents an image wrapped as a vision chunk."""

    # Discriminator tag; always the literal "image" for this chunk kind.
    type: Literal["image"]
    # The image payload carried by this chunk.
    image: Image
    # UUID string for this chunk, or None when no UUID is set.
    uuid: str | None


class VisionChunkVideo(TypedDict):
    """Represents a video chunk with metadata."""

    # Discriminator tag; always the literal "video_chunk" for this chunk kind.
    type: Literal["video_chunk"]
    # The contents of this chunk, as a list of images.
    video_chunk: list[Image]
    # UUID string for this chunk, or None when no UUID is set.
    uuid: str | None
    # Text attached to this chunk.
    # NOTE(review): the exact role of this prompt is not visible here; confirm
    # against the code that produces these chunks.
    prompt: str
    # Integer video index; presumably identifies the source video this chunk
    # was cut from (TODO confirm).
    video_idx: int


114
VisionChunk: TypeAlias = VisionChunkImage | VisionChunkVideo
"""A vision chunk is either an image or a video chunk."""


118
119
@dataclass(frozen=True)
class PlaceholderRange:
120
121
122
    """
    Placeholder location information for multi-modal data.

123
124
    Example:

125
    Prompt: `AAAA BBBB What is in these images?`
126

127
    Images A and B will have:
128

129
130
131
132
    ```
    A: PlaceholderRange(offset=0, length=4)
    B: PlaceholderRange(offset=5, length=4)
    ```
133
134
135
136
137
138
139
140
    """

    offset: int
    """The start index of the placeholder in the prompt."""

    length: int
    """The length of the placeholder."""

141
    is_embed: "torch.Tensor | None" = None
142
143
144
145
146
    """
    A boolean mask of shape `(length,)` indicating which positions
    between `offset` and `offset + length` to assign embeddings to.
    """

147
    @cached_property
148
149
150
    def embeds_cumsum(self) -> list[int] | None:
        # python list so python indexing avoids torch C++ overhead/conversions/deallocs
        return None if self.is_embed is None else self.is_embed.cumsum(dim=0).tolist()
151
152
153

    def get_num_embeds(self) -> int:
        if self.embeds_cumsum is None:
154
155
            return self.length

156
        return self.embeds_cumsum[-1] if self.embeds_cumsum else 0
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173

    def get_embeds_indices_in_range(
        self, start_idx: int, end_idx: int
    ) -> tuple[int, int]:
        """
        Returns the starting and ending indices of the embeddings of encoder outputs
        in the range of [start_idx, end_idx) in the placeholders.

        For example, given:
        PlaceholderRange(offset=2, length=5, is_embed=[False, True, False, True, True])

        If start_idx=3 and end_idx=5, the output is (1, 3) because we want to get
        the second and the third embeddings from the encoder output.
        """
        if self.embeds_cumsum is None:
            return start_idx, end_idx

174
175
        embeds_start_idx = self.embeds_cumsum[start_idx - 1] if start_idx > 0 else 0
        embeds_end_idx = self.embeds_cumsum[end_idx - 1] if end_idx > 0 else 0
176
177

        return embeds_start_idx, embeds_end_idx
178

179
180
181
182
183
184
185
186
187
188
189
190
191
    def extract_embeds_range(self) -> list[tuple[int, int]]:
        """Extract the start and end indices of the embedded region in prompt.

        For example, given `PlaceholderRange(offset=2, length=5)` and
        `is_embed = [False, True, False, True, True]`, the output is
        `[(1 + offset, 1 + offset), (3 + offset, 4 + offset)]`.

        Returns:
            A tuple `(start, end)` representing the start and end
            indices (inclusive) of the embedded region.
            Returns full placeholder range if `is_embed` is `None`.
        """
        if self.is_embed is None:
192
            return [(self.offset, self.offset + self.length - 1)]
193
194
195
196
197
198
199
200
201
202
203

        mask_i = self.is_embed.int()
        starts = torch.nonzero(
            torch.diff(mask_i, prepend=mask_i.new_zeros(1)) == 1
        ).flatten()
        ends = torch.nonzero(
            torch.diff(mask_i, append=mask_i.new_zeros(1)) == -1
        ).flatten()
        ranges = torch.stack((starts, ends), dim=1) + self.offset
        return [tuple(x) for x in ranges.tolist()]

204
205
206
207
208
209
210
211
212
213
214
215
216
    def __eq__(self, other: object) -> bool:
        if not isinstance(other, self.__class__):
            return False
        if not (self.offset, self.length) == (other.offset, other.length):
            return False

        if self.is_embed is None:
            return other.is_embed is None
        if other.is_embed is None:
            return self.is_embed is None

        return nested_tensors_equal(self.is_embed, other.is_embed)

217

218
219
220
221
222
223
NestedTensors: TypeAlias = Union[
    list["NestedTensors"],
    list["torch.Tensor"],
    "torch.Tensor",
    tuple["torch.Tensor", ...],
]
"""
Uses a list instead of a tensor if the dimensions of each element do not match.
"""

228
229

def nested_tensors_equal(a: NestedTensors, b: NestedTensors) -> bool:
    """
    Equality check between
    [`NestedTensors`][vllm.multimodal.inputs.NestedTensors] objects.

    Tensors are compared with `torch.equal`. Lists and tuples are compared
    element-wise and must have the same container type and length.
    Anything else falls back to `==`.
    """
    a_is_tensor = isinstance(a, torch.Tensor)
    b_is_tensor = isinstance(b, torch.Tensor)
    if a_is_tensor or b_is_tensor:
        # A tensor never equals a non-tensor.
        return a_is_tensor and b_is_tensor and torch.equal(a, b)

    # Lists are checked before tuples; a list never equals a tuple.
    for container_type in (list, tuple):
        if isinstance(a, container_type) or isinstance(b, container_type):
            return (
                isinstance(a, container_type)
                and isinstance(b, container_type)
                and len(a) == len(b)
                and all(nested_tensors_equal(a_, b_) for a_, b_ in zip(a, b))
            )

    # Both a and b are scalars
    return a == b


269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
def _nested_tensors_h2d(
    tensors: NestedTensors,
    device: torch.types.Device,
) -> NestedTensors:
    if device is None:
        return tensors

    return json_map_leaves(
        (
            lambda x: x.to(device=device, non_blocking=True)
            if isinstance(x, torch.Tensor)
            else x
        ),
        tensors,
    )


286
BatchedTensorInputs: TypeAlias = dict[str, NestedTensors]
"""
A dictionary containing nested tensors which have been batched via
[`MultiModalKwargsItems.get_data`][vllm.multimodal.inputs.MultiModalKwargsItems.get_data].
"""


293
294
295
296
297
def batched_tensors_equal(a: BatchedTensorInputs, b: BatchedTensorInputs) -> bool:
    """
    Equality check between
    [`BatchedTensorInputs`][vllm.multimodal.inputs.BatchedTensorInputs] objects.

    Two inputs are equal when they have the same keys and the nested
    tensors stored under each key are equal.
    """
    # Compare the key sets first: checking only the keys of `a` would make
    # the result asymmetric and ignore extra entries in `b`.
    if a.keys() != b.keys():
        return False

    return all(nested_tensors_equal(a[k], b[k]) for k in a)
299
300


301
302
303
304
@dataclass
class MultiModalFeatureSpec:
    """
    Represents a single multimodal input with its processed data and metadata.

    Used to track multimodal data through processing and caching.
    A request containing multiple multimodal items will have one
    `MultiModalFeatureSpec` per item.
    """

    data: "MultiModalKwargsItem | None"
    """
    Represents multimodal data for this feature.

    Can be `None` if the item is cached, to skip IPC between API server
    and engine core processes.
    """

    modality: str
    """The input modality, e.g., `"image"`, `"audio"`, `"video"`."""

    identifier: str
    """The hash for caching encoder outputs (with LoRA prefix if applicable)."""

    mm_position: PlaceholderRange
    """
    The location of the `modality` tokens corresponding to this item
    in the prompt, e.g., `PlaceholderRange(offset=2, length=336)`.
    """

    mm_hash: str | None = None
    """The hash for caching processor outputs (without LoRA prefix)."""

    @staticmethod
    def gather_kwargs(
        features: list["MultiModalFeatureSpec"], keys: set[str]
    ) -> dict[str, list[NestedTensors]]:
        """
        Collect, per requested key, the field data of every feature that has it.

        Features whose `data` is `None` are skipped, and keys that no feature
        provides are absent from the result.

        Args:
            features: The features to read from.
            keys: The keyword-argument names to gather.

        Returns:
            A plain dict mapping each found key to the list of `.data`
            values, in feature order.
        """
        kwargs: defaultdict[str, list[NestedTensors]] = defaultdict(list)

        for feature in features:
            item = feature.data
            if item is None:
                continue

            for key in keys:
                if key in item:
                    kwargs[key].append(item[key].data)

        return dict(kwargs)

347

348
@dataclass
class MultiModalFieldElem:
    """
    Represents a processed keyword argument to pass to a model for a
    [`MultiModalKwargsItem`][vllm.multimodal.inputs.MultiModalKwargsItem].
    """

    data: NestedTensors
    """
    The tensor data of this field in
    [`MultiModalKwargsItem`][vllm.multimodal.inputs.MultiModalKwargsItem],
    i.e. the value of the keyword argument to be passed to the model.

    It may be set to `None` if it is determined that the item is cached
    in `EngineCore`.
    """

    field: "BaseMultiModalField"
    """
    Defines how to combine the tensor data of this field with others
    in order to batch multi-modal items together for model inference.
    """

    def __eq__(self, other: object) -> bool:
        """
        Two elements are equal when their data is equal (or both `None`)
        and their fields are of exactly the same type.
        """
        if not isinstance(other, self.__class__):
            return False

        if self.data is None or other.data is None:
            data_equal = self.data is None and other.data is None
        else:
            data_equal = nested_tensors_equal(self.data, other.data)

        # Only the field *type* is compared, not the field's own attributes.
        return data_equal and type(self.field) is type(other.field)  # noqa: E721
383
384


385
@dataclass(frozen=True, kw_only=True)
class BaseMultiModalField(ABC):
    """
    Defines how to interpret tensor data belonging to a keyword argument for
    [`MultiModalKwargsItems`][vllm.multimodal.inputs.MultiModalKwargsItems],
    and vice versa.
    """

    keep_on_cpu: bool = False
    """
    If `True`, then this field is excluded from being moved to the accelerator when
    [`group_and_batch_mm_items`][vllm.multimodal.utils.group_and_batch_mm_items]
    is called to batch the data.
    """

    def _field_factory(self):
        """
        Return a callable that wraps data in a
        [`MultiModalFieldElem`][vllm.multimodal.inputs.MultiModalFieldElem]
        bound to this field.
        """
        f = partial(MultiModalFieldElem, field=self)

        # Allow passing data as positional argument
        def factory(data: NestedTensors) -> MultiModalFieldElem:
            return f(data=data)

        return factory

    @abstractmethod
    def build_elems(
        self,
        modality: str,
        key: str,
        data: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        """
        Construct
        [`MultiModalFieldElem`][vllm.multimodal.inputs.MultiModalFieldElem]
        instances to represent the provided data.

        This is the inverse of
        [`reduce_data`][vllm.multimodal.inputs.BaseMultiModalField.reduce_data].
        """
        raise NotImplementedError

    @abstractmethod
    def _reduce_data(
        self,
        batch: list[NestedTensors],
        *,
        pin_memory: bool,
    ) -> NestedTensors:
        """
        Merge the raw per-element data in `batch` into one value.

        Called by `reduce_data`, which moves the result to the target
        device afterwards.
        """
        raise NotImplementedError

    def reduce_data(
        self,
        elems: list[MultiModalFieldElem],
        *,
        device: torch.types.Device = None,
        pin_memory: bool = False,
    ) -> NestedTensors:
        """
        Merge the data from multiple instances of
        [`MultiModalFieldElem`][vllm.multimodal.inputs.MultiModalFieldElem].

        This is the inverse of
        [`build_elems`][vllm.multimodal.inputs.BaseMultiModalField.build_elems].

        Args:
            elems: The elements to merge; all must use the same field type.
            device: Where to move the merged tensors; `None` leaves them
                where they are.
            pin_memory: Passed through to `_reduce_data`.

        Raises:
            ValueError: If `elems` mix different field types.
        """
        field_types = [type(item.field) for item in elems]
        if len(set(field_types)) > 1:
            raise ValueError(f"Cannot merge different {field_types=}")

        # Fields flagged `keep_on_cpu` never go to the accelerator and do
        # not use pinned memory.
        if device is not None and self.keep_on_cpu:
            device = "cpu"
        if pin_memory and self.keep_on_cpu:
            pin_memory = False

        batch = [elem.data for elem in elems]
        out = self._reduce_data(batch, pin_memory=pin_memory)
        return _nested_tensors_h2d(out, device=device)
461
462


463
@dataclass(frozen=True, kw_only=True)
class MultiModalBatchedField(BaseMultiModalField):
    """
    Info:
        [`MultiModalFieldConfig.batched`][vllm.multimodal.inputs.MultiModalFieldConfig.batched]
    """

    def build_elems(
        self,
        modality: str,
        key: str,
        data: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        """Create one element per entry obtained by iterating over `data`."""
        field_factory = self._field_factory()
        return [field_factory(item) for item in data]

    def _reduce_data(
        self,
        batch: list[NestedTensors],
        *,
        pin_memory: bool,
    ) -> NestedTensors:
        """
        Stack same-shaped tensors into one tensor with a new leading dimension.

        If the batch is empty, contains non-tensors, or the tensor shapes
        differ, the batch list is returned unchanged.
        """
        if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"):
            batch = cast(list[torch.Tensor], batch)
            if len(batch) == 1:
                # An optimization when `batch` contains only one tensor:
                # - produce exactly same result as `torch.stack(batch)`
                # - will achieve zero-copy if the tensor is contiguous
                return batch[0].unsqueeze(0).contiguous()

            first_shape = batch[0].shape
            if all(elem.shape == first_shape for elem in batch):
                # Pre-allocate the output so `pin_memory` can be honoured.
                out = torch.empty(
                    (len(batch), *first_shape),
                    dtype=batch[0].dtype,
                    device=batch[0].device,
                    pin_memory=pin_memory,
                )
                return torch.stack(batch, out=out)

        return batch


505
@dataclass(frozen=True, kw_only=True)
class MultiModalFlatField(BaseMultiModalField):
    """
    Info:
        [`MultiModalFieldConfig.flat`][vllm.multimodal.inputs.MultiModalFieldConfig.flat]
        [`MultiModalFieldConfig.flat_from_sizes`][vllm.multimodal.inputs.MultiModalFieldConfig.flat_from_sizes]
    """

    slices: Sequence[slice] | Sequence[Sequence[slice]]
    dim: int = 0

    def build_elems(
        self,
        modality: str,
        key: str,
        data: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        """Create one element per entry of `self.slices` by indexing `data`."""
        field_factory = self._field_factory()
        if not is_list_of(self.slices, slice, check="all"):
            # Entries that are not plain slices need a tensor to index into.
            assert isinstance(data, torch.Tensor), (
                "torch.Tensor is required for multiple slices"
            )
        return [field_factory(data[cast(slice, s)]) for s in self.slices]

    def _reduce_data(
        self,
        batch: list[NestedTensors],
        *,
        pin_memory: bool,
    ) -> NestedTensors:
        """
        Concatenate the batch along `self.dim`.

        Tensors whose other dimensions agree are concatenated directly.
        Tensors whose other dimensions differ are written into a
        zero-initialized tensor sized to the largest extent of each
        non-concat dimension. A batch that is empty or contains non-tensors
        is flattened by one list level, which requires `dim == 0`.
        """
        if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"):
            batch = cast(list[torch.Tensor], batch)
            if len(batch) == 1:
                # An optimization when `batch` contains only one tensor:
                # - produce exactly same result as `torch.concat(batch)`
                # - will achieve zero-copy if the tensor is contiguous
                return batch[0].contiguous()

            # Normalize a negative `dim` against the first tensor's rank.
            dim = self.dim + (self.dim < 0) * len(batch[0].shape)

            def _shape_before_after(tensor: torch.Tensor):
                return tensor.shape[:dim], tensor.shape[dim + 1 :]

            first_shape = _shape_before_after(batch[0])

            if all(_shape_before_after(elem) == first_shape for elem in batch):
                shape_before, shape_after = first_shape
                shape_concat = sum(item.shape[dim] for item in batch)
                out = torch.empty(
                    (*shape_before, shape_concat, *shape_after),
                    dtype=batch[0].dtype,
                    device=batch[0].device,
                    pin_memory=pin_memory,
                )
                return torch.concat(batch, dim=self.dim, out=out)

            # Variable-length case: non-concat dimensions differ
            # (e.g., Ultravox with different audio durations).
            # Use slice-assign approach (more efficient than padding).
            # See: https://github.com/vllm-project/vllm/issues/31658

            ndim = batch[0].ndim

            # Step 1: Compute output shape
            # - Non-concat dims: take max across batch
            # - Concat dim: sum across batch
            max_sizes: list[int] = [
                sum(t.shape[d] for t in batch)
                if d == dim
                else max(t.shape[d] for t in batch)
                for d in range(ndim)
            ]

            # Step 2: Create zero-initialized output tensor
            out = torch.zeros(
                max_sizes,
                dtype=batch[0].dtype,
                device=batch[0].device,
                pin_memory=pin_memory,
            )

            # Step 3: Slice-assign each tensor to its proper position
            concat_offset = 0
            for tensor in batch:
                concat_end = concat_offset + tensor.shape[dim]
                # Named `index` so it is not confused with the `slices` field.
                index = tuple(
                    slice(concat_offset, concat_end)
                    if d == dim
                    else slice(0, tensor.shape[d])
                    for d in range(ndim)
                )
                out[index] = tensor
                concat_offset = concat_end

            return out

        assert self.dim == 0, "dim == 0 is required for nested list"
        return [e for elem in batch for e in elem]
604
605


606
@dataclass(frozen=True, kw_only=True)
class MultiModalSharedField(BaseMultiModalField):
    """
    Info:
        [`MultiModalFieldConfig.shared`][vllm.multimodal.inputs.MultiModalFieldConfig.shared]
    """

    batch_size: int

    def build_elems(
        self,
        modality: str,
        key: str,
        data: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        """
        Return `batch_size` references to one element wrapping all of `data`.

        The list repeats the same element object; no copies are made.
        """
        field_factory = self._field_factory()
        return [field_factory(data)] * self.batch_size

    def _reduce_data(
        self,
        batch: list[NestedTensors],
        *,
        pin_memory: bool,
    ) -> NestedTensors:
        """Return the first entry of `batch`; the others are not inspected."""
        return batch[0]


633
@dataclass(frozen=True)
class MultiModalFieldConfig:
    """
    Pairs a [`BaseMultiModalField`][vllm.multimodal.inputs.BaseMultiModalField]
    with the modality that uses it.

    Instances are created through the static factory methods
    `batched`, `flat`, `flat_from_sizes` and `shared`.
    """

    @staticmethod
    def batched(modality: str, *, keep_on_cpu: bool = False):
        """
        Defines a field where an element in the batch is obtained by
        indexing into the first dimension of the underlying data.

        Args:
            modality: The modality of the multi-modal item that uses this
                keyword argument.
            keep_on_cpu: Whether to keep this field on the CPU for the model inputs.

        Example:

        ```
        Input:
            Data: [[AAAA]
                [BBBB]
                [CCCC]]

        Output:
            Element 1: [AAAA]
            Element 2: [BBBB]
            Element 3: [CCCC]
        ```
        """
        return MultiModalFieldConfig(
            field=MultiModalBatchedField(keep_on_cpu=keep_on_cpu),
            modality=modality,
        )

    @staticmethod
    def flat(
        modality: str,
        slices: Sequence[slice] | Sequence[Sequence[slice]],
        dim: int = 0,
        *,
        keep_on_cpu: bool = False,
    ):
        """
        Defines a field where an element in the batch is obtained by
        slicing along dimension `dim` of the underlying data.

        Args:
            modality: The modality of the multi-modal item that uses this
                keyword argument.
            slices: For each multi-modal item, a slice (dim=0) or a tuple of
                slices (dim>0) that is used to extract the data corresponding
                to it.
            dim: The dimension to extract data, default to 0.
            keep_on_cpu: Whether to keep this field on the CPU for the model inputs.

        Example:

        ```
        Given:
            slices: [slice(0, 3), slice(3, 7), slice(7, 9)]

        Input:
            Data: [AAABBBBCC]

        Output:
            Element 1: [AAA]
            Element 2: [BBBB]
            Element 3: [CC]
        ```

        ```
        Given:
            slices: [
                (slice(None), slice(0, 3)),
                (slice(None), slice(3, 7)),
                (slice(None), slice(7, 9))]
            dim: 1

        Input:
            Data: [[A],[A],[A],[B],[B],[B],[B],[C],[C]]

        Output:
            Element 1: [[A],[A],[A]]
            Element 2: [[B],[B],[B],[B]]
            Element 3: [[C],[C]]
        ```
        """
        return MultiModalFieldConfig(
            field=MultiModalFlatField(
                slices=slices,
                dim=dim,
                keep_on_cpu=keep_on_cpu,
            ),
            modality=modality,
        )

    @staticmethod
    def flat_from_sizes(
        modality: str,
        size_per_item: "torch.Tensor",
        dim: int = 0,
        *,
        keep_on_cpu: bool = False,
    ):
        """
        Defines a field where an element in the batch is obtained by
        slicing along dimension `dim` of the underlying data.

        Args:
            modality: The modality of the multi-modal item that uses this
                keyword argument.
            size_per_item: For each multi-modal item, the size of the slice
                that is used to extract the data corresponding to it.
            dim: The dimension to slice, default to 0.
            keep_on_cpu: Whether to keep this field on the CPU for the model inputs.

        Raises:
            ValueError: If `size_per_item` is not a 1-D tensor.

        Example:

        ```
        Given:
            size_per_item: [3, 4, 2]

        Input:
            Data: [AAABBBBCC]

        Output:
            Element 1: [AAA]
            Element 2: [BBBB]
            Element 3: [CC]
        ```

        ```
        Given:
            size_per_item: [3, 4, 2]
            dim: 1

        Input:
            Data: [[A],[A],[A],[B],[B],[B],[B],[C],[C]]

        Output:
            Element 1: [[A],[A],[A]]
            Element 2: [[B],[B],[B],[B]]
            Element 3: [[C],[C]]
        ```

        Info:
            [`MultiModalFieldConfig.flat`][vllm.multimodal.inputs.MultiModalFieldConfig.flat]
        """

        if size_per_item.ndim != 1:
            raise ValueError(
                "size_per_item should be a 1-D tensor, "
                f"but found shape: {size_per_item.shape}"
            )

        # Convert to Python numbers once so the slice bounds are plain
        # values rather than 0-d tensors produced by iterating the tensor.
        slice_idxs = [0, *accumulate(size_per_item.tolist())]
        slices = [
            (slice(None, None, None),) * dim + (slice(start, stop),)
            for start, stop in zip(slice_idxs, slice_idxs[1:])
        ]

        return MultiModalFieldConfig.flat(
            modality,
            slices,
            dim=dim,
            keep_on_cpu=keep_on_cpu,
        )

    @staticmethod
    def shared(
        modality: str,
        batch_size: int,
        *,
        keep_on_cpu: bool = False,
    ):
        """
        Defines a field where an element in the batch is obtained by
        taking the entirety of the underlying data.

        This means that the data is the same for each element in the batch.

        Args:
            modality: The modality of the multi-modal item that uses this
                keyword argument.
            batch_size: The number of multi-modal items which share this data.
            keep_on_cpu: Whether to keep this field on the CPU for the model inputs.

        Example:

        ```
        Given:
            batch_size: 4

        Input:
            Data: [XYZ]

        Output:
            Element 1: [XYZ]
            Element 2: [XYZ]
            Element 3: [XYZ]
            Element 4: [XYZ]
        ```
        """
        return MultiModalFieldConfig(
            field=MultiModalSharedField(
                batch_size=batch_size,
                keep_on_cpu=keep_on_cpu,
            ),
            modality=modality,
        )

    field: BaseMultiModalField
    modality: str

    def build_elems(
        self,
        key: str,
        batch: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        """Delegate to `self.field.build_elems` using this config's modality."""
        return self.field.build_elems(self.modality, key, batch)
852
853


854
855
class MultiModalKwargsItem(UserDict[str, MultiModalFieldElem]):
    """
    A dictionary of processed keyword arguments to pass to the model,
    corresponding to a single item in
    [`MultiModalDataItems`][vllm.multimodal.parse.MultiModalDataItems].
    """

    @staticmethod
    def dummy(nbytes: int = 1):
        """
        Convenience constructor for testing.

        Builds an item with a single `"dummy"` key holding an uninitialized
        `uint8` tensor of `nbytes` elements in a shared field.
        """
        mm_elem = MultiModalFieldElem(
            data=torch.empty(nbytes, dtype=torch.uint8),
            field=MultiModalSharedField(batch_size=1),
        )
        return MultiModalKwargsItem({"dummy": mm_elem})

    def get_data(self) -> dict[str, NestedTensors]:
        """Return a plain dict mapping each key to its element's `data`."""
        return {key: elem.data for key, elem in self.items()}
872
873


874
875
876
# Item type held by `MultiModalKwargsItems`: either always-present items or
# items that may be `None`; defaults to always-present.
_I = TypeVar(
    "_I",
    MultiModalKwargsItem,
    MultiModalKwargsItem | None,
    default=MultiModalKwargsItem,
)


class MultiModalKwargsItems(UserDict[str, Sequence[_I]]):
    """
    A dictionary of processed multi-modal inputs by modality.

    For example, given a processor that processes
    images into `pixel_values` and `image_grid_thw`,
    and audios into `input_audio_features`,
    a prompt with 2 images and 1 audio will be processed
    into a `MultiModalKwargsItems` with the following structure:

    ```python
    MultiModalKwargsItems(
        {
            "image": [
                # For the first image
                MultiModalKwargsItem({"pixel_values": ..., "image_grid_thw": ...}),
                # For the second image
                MultiModalKwargsItem({"pixel_values": ..., "image_grid_thw": ...}),
            ],
            "audio": [
                # For the first audio
                MultiModalKwargsItem({"input_audio_features": ...}),
            ],
        }
    )
    ```

    Unlike HF processing which returns all items
    in a single dictionary with batched keyword arguments,
    we split up the items because some of them may already be cached.
    Also, items from multiple requests may be batched together to improve throughput,
    using the logic defined by the
    [`BaseMultiModalField`][vllm.multimodal.inputs.BaseMultiModalField]
    for each keyword argument.
    """

    @staticmethod
    def from_hf_inputs(
        hf_inputs: "BatchFeature",
        config_by_key: Mapping[str, MultiModalFieldConfig],
    ):
        """
        Split batched HF processor outputs into per-item kwargs, by modality.

        Args:
            hf_inputs: The processor output to read keyword arguments from.
            config_by_key: For each keyword argument, how to split it into
                per-item elements and which modality it belongs to.

        Raises:
            ValueError: If the keys of one modality yield different
                numbers of elements.
        """
        # NOTE: This skips fields in `hf_inputs` that are not in `config_by_key`
        # We assume that those fields are not used in vLLM
        elems_by_key = dict[str, Sequence[MultiModalFieldElem]]()
        keys_by_modality = defaultdict[str, set[str]](set)
        for key, config in config_by_key.items():
            batch = hf_inputs.get(key)
            if batch is not None:
                elems = config.build_elems(key, batch)
                # Keys that produce no elements are dropped entirely.
                if len(elems) > 0:
                    elems_by_key[key] = elems
                    keys_by_modality[config.modality].add(key)

        items_by_modality = dict[str, list[MultiModalKwargsItem]]()
        for modality, keys in keys_by_modality.items():
            elems_in_modality = {k: elems_by_key[k] for k in keys}
            batch_sizes = {k: len(v) for k, v in elems_in_modality.items()}

            if len(set(batch_sizes.values())) > 1:
                raise ValueError(
                    f"Cannot merge different batch sizes for {modality=}! "
                    f"Found: {batch_sizes=}"
                )

            batch_size = next(iter(batch_sizes.values()))
            # Item i collects the i-th element of every key in this modality.
            items_by_modality[modality] = [
                MultiModalKwargsItem({k: v[i] for k, v in elems_in_modality.items()})
                for i in range(batch_size)
            ]

        return MultiModalKwargsItems(items_by_modality)

    def __getitem__(self, modality: str) -> Sequence[_I]:
        """
        Return the items of `modality`.

        Raises:
            KeyError: If the modality is absent; the message lists the
                available modalities.
        """
        if modality not in self:
            raise KeyError(
                f"Modality {modality!r} not found. "
                f"Available modalities: {set(self.keys())}"
            )

        return super().__getitem__(modality)  # type: ignore[return-value]

    def require_data(self) -> "MultiModalKwargsItems[MultiModalKwargsItem]":
        """
        Check that no item is `None` and return `self`.

        Raises:
            RuntimeError: If any item in any modality is `None`.
        """
        for modality, items in self.items():
            for i, item in enumerate(items):
                if item is None:
                    raise RuntimeError(f"Found empty mm_items[{modality}][{i}]")

        return self  # type: ignore[return-value]

    def get_data(
        self,
        *,
        device: torch.types.Device = None,
        pin_memory: bool = False,
    ) -> BatchedTensorInputs:
        """Construct a dictionary of keyword arguments to pass to the model.

        Args:
            device: Passed through to `group_and_batch_mm_items`.
            pin_memory: Passed through to `group_and_batch_mm_items`.

        Raises:
            RuntimeError: If an item is `None`, or if the items of a
                modality do not batch into exactly one group.
        """
        # Imported here rather than at module level.
        # NOTE(review): presumably to avoid a circular import; confirm.
        from .utils import group_and_batch_mm_items

        items_by_modality = self.require_data()
        # Modalities with no items are skipped.
        batches_by_modality = {
            modality: [
                data
                for _, data in group_and_batch_mm_items(
                    items,
                    device=device,
                    pin_memory=pin_memory,
                )
            ]
            for modality, items in items_by_modality.items()
            if len(items) > 0
        }

        out_data: BatchedTensorInputs = {}
        for batches in batches_by_modality.values():
            if len(batches) != 1:
                num_batches_by_modality = {
                    modality: len(modality_batches)
                    for modality, modality_batches in batches_by_modality.items()
                }

                raise RuntimeError(
                    f"Some modalities cannot be merged into a single batch "
                    f"({num_batches_by_modality=})"
                )

            out_data.update(batches[0])

        return out_data
1010

1011

1012
1013
1014
1015
# Either a `MultiModalKwargsItems` whose entries are all present, or one
# whose entries may be `None`.
MultiModalKwargsOptionalItems: TypeAlias = (
    MultiModalKwargsItems[MultiModalKwargsItem]
    | MultiModalKwargsItems[MultiModalKwargsItem | None]
)