inputs.py 20.1 KB
Newer Older
1
from abc import ABC, abstractmethod
2
from collections import UserDict, defaultdict
3
4
5
6
from collections.abc import Mapping, Sequence
from dataclasses import dataclass
from typing import (Any, Literal, NamedTuple, TypedDict, TypeVar, Union, cast,
                    final)
7
8
9
10
11

import numpy as np
import torch
import torch.types
from PIL.Image import Image
12
13
from transformers import BatchFeature
from typing_extensions import NotRequired, TypeAlias, assert_never
14
15
16
17
18
19
20
21

from vllm.utils import JSONTree, is_list_of, json_map_leaves

_T = TypeVar("_T")

# yapf: disable
ImageItem: TypeAlias = Union[Image, np.ndarray, torch.Tensor]
"""
22
23
A :class:`transformers.image_utils.ImageInput` representing a single image
item, which can be passed to a HuggingFace :code:`ImageProcessor`.
24
25
26
"""

VideoItem: TypeAlias = Union[
    list[Image],
    np.ndarray,
    torch.Tensor,
    list[np.ndarray],
    list[torch.Tensor],
]
"""
A :class:`transformers.image_utils.VideoInput` representing a single video
item, which can be passed to a HuggingFace :code:`VideoProcessor`.
"""

AudioItem: TypeAlias = Union[
    np.ndarray,
40
41
42
43
    list[float],
    # `(audio, sampling_rate)`: If the audio's sampling rate is different
    # from that expected by the model, we need to resample it.
    tuple[np.ndarray, float],
44
45
]
"""
46
47
Represents a single audio
item, which can be passed to a HuggingFace :code:`AudioProcessor`.
48
49
50
"""
# yapf: enable

51
MultiModalData: TypeAlias = Union[_T, list[_T]]
"""
Either a single data item, or a list of data items.

The number of data items allowed per modality is restricted by
:code:`--limit-mm-per-prompt`.
"""


@final
class MultiModalDataBuiltins(TypedDict, total=False):
    """
    Type annotations for modality types predefined by vLLM.

    The class is declared with ``total=False``, so every key is optional.
    Each value is a :data:`MultiModalData`, i.e. either a single item or a
    list of items of the corresponding type.
    """

    image: MultiModalData[ImageItem]
    """The input image(s)."""

    video: MultiModalData[VideoItem]
    """The input video(s)."""

    audio: MultiModalData[AudioItem]
    """The input audio(s)."""


MultiModalDataDict: TypeAlias = Mapping[str, MultiModalData[Any]]
"""
A dictionary containing an entry for each modality type to input.

Note:
    This dictionary also accepts modality keys defined outside
    :class:`MultiModalDataBuiltins` as long as a customized plugin
    is registered through the :class:`~vllm.multimodal.MULTIMODAL_REGISTRY`.
82
    Read more on that :ref:`here <adding-multimodal-plugin>`.
83
84
85
"""


86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
class ImageSize(NamedTuple):
    """Dimensions of an image, stored in ``(width, height)`` order."""
    width: int
    height: int


class MultiModalDataItems(UserDict[str, list[Any]]):
    """
    A normalized form of :class:`MultiModalDataDict` in which the value
    stored for every modality is a list of data items.
    """

    @staticmethod
    def from_dict(data: MultiModalDataDict) -> "MultiModalDataItems":
        """
        Build a :class:`MultiModalDataItems` from a
        :class:`MultiModalDataDict`, wrapping each bare (non-list) item in a
        single-element list.
        """
        normalized = MultiModalDataItems()

        for modality, value in data.items():
            # TODO: Make a separate modality for embedding inputs
            # to avoid confusion
            if modality == "video":
                # Special case since even a single item can be a list
                if isinstance(value, torch.Tensor):
                    is_multi = True
                elif is_list_of(value, list):
                    is_multi = True
                else:
                    first = value[0]
                    is_multi = (isinstance(first, (np.ndarray, torch.Tensor))
                                and first.ndim == 4)
            elif modality in ("image", "audio"):
                is_multi = isinstance(value, (torch.Tensor, list))
            else:
                is_multi = isinstance(value, list)

            wrapped = value if is_multi else [value]
            normalized[modality] = wrapped  # type: ignore[index]

        return normalized

    # NOTE: When a modality (e.g. "image") has no entry, the getters below
    # return a fresh empty list that is NOT stored in this dictionary, so
    # appending to it has no effect. The getters are therefore annotated as
    # returning a read-only `Sequence` to discourage in-place updates.
    @property
    def images(self) -> Sequence[ImageItem]:
        """The items stored under ``"image"`` (empty if there are none)."""
        return self.get("image", [])

    @property
    def videos(self) -> Sequence[VideoItem]:
        """The items stored under ``"video"`` (empty if there are none)."""
        return self.get("video", [])

    @property
    def audios(self) -> Sequence[AudioItem]:
        """The items stored under ``"audio"`` (empty if there are none)."""
        return self.get("audio", [])

    def get_item_counts(self) -> Mapping[str, int]:
        """Return the number of items held for each modality."""
        counts: dict[str, int] = {}
        for modality, items in self.items():
            counts[modality] = len(items)
        return counts

    def has_embedding_inputs(self) -> bool:
        """Return whether any item of any modality is a ``torch.Tensor``."""
        for items in self.values():
            for item in items:
                if isinstance(item, torch.Tensor):
                    return True
        return False

    def get_image_size(self, item_idx: int) -> ImageSize:
        """
        Return the size of the image at ``item_idx``.

        For arrays and tensors, the shape must have exactly three dimensions;
        the last two are read as height and width respectively.
        """
        image = self.images[item_idx]

        if isinstance(image, (np.ndarray, torch.Tensor)):
            _, height, width = image.shape
            return ImageSize(width, height)
        if isinstance(image, Image):
            return ImageSize(*image.size)

        assert_never(image)

    def get_audio_with_sr(
        self,
        item_idx: int,
        *,
        default_sr: float,
    ) -> tuple[np.ndarray, float]:
        """
        Return the audio at ``item_idx`` as an ``(audio, sampling_rate)``
        pair, using ``default_sr`` when the item carries no sampling rate.
        """
        audio = self.audios[item_idx]

        if isinstance(audio, np.ndarray):
            return audio, default_sr
        if isinstance(audio, list):
            return np.array(audio), default_sr
        if isinstance(audio, tuple):
            # Already an `(audio, sampling_rate)` pair
            return audio

        assert_never(audio)

    def resample_audios(self, new_sr: float, *, drop_sr: bool = True) -> None:
        """
        Resample every audio item to ``new_sr`` and store the results back
        under the ``"audio"`` key.

        If :code:`drop_sr=True`, the audio items in this dictionary are updated
        to be NumPy arrays which implicitly means that their sampling rate is
        the same as the model's expected sampling rate; otherwise, they remain
        as :code:`(audio, new_sr)` tuples.
        """
        # Avoid circular import
        from .audio import resample_audio

        if not self.audios:
            return

        def _resample(item_idx: int):
            # Items without an explicit sampling rate are treated as `new_sr`
            waveform, orig_sr = self.get_audio_with_sr(item_idx,
                                                       default_sr=new_sr)
            waveform = resample_audio(waveform,
                                      orig_sr=orig_sr,
                                      target_sr=new_sr)
            return waveform if drop_sr else (waveform, new_sr)

        self["audio"] = [_resample(idx) for idx in range(len(self.audios))]


203
204
205
206
class PlaceholderRange(TypedDict):
    """
    Placeholder location information for multi-modal data.

    Example:

        Prompt: :code:`AAAA BBBB What is in these images?`

        Images A and B will have:

        .. code-block::

            A: { "offset": 0, "length": 4 }
            B: { "offset": 5, "length": 4 }
    """

    offset: int
    """The start index of the placeholder in the prompt."""

    length: int
    """The length of the placeholder."""


226
227
NestedTensors = Union[list["NestedTensors"], list[torch.Tensor], torch.Tensor,
                      tuple[torch.Tensor, ...]]
"""
Uses a list instead of a tensor if the dimensions of each element do not match.
"""

232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251

def nested_tensors_equal(a: "NestedTensors", b: "NestedTensors") -> bool:
    """
    Equality check between :data:`NestedTensors` objects.

    Two values are considered equal when they have the same nesting
    structure (same container kind and same length at every level) and every
    pair of corresponding tensors has the same shape and equal elements.

    Args:
        a: The first nested tensor structure.
        b: The second nested tensor structure.

    Returns:
        ``True`` if the two structures match, ``False`` otherwise.
    """
    a_is_tensor = isinstance(a, torch.Tensor)
    b_is_tensor = isinstance(b, torch.Tensor)
    if a_is_tensor or b_is_tensor:
        if not (a_is_tensor and b_is_tensor):
            return False

        # Compare shapes first: an element-wise `==` would otherwise broadcast
        # (e.g. shape (1,) against (3,)) or raise on incompatible shapes.
        return a.shape == b.shape and bool((a == b).all().item())

    for seq_type in (list, tuple):
        a_is_seq = isinstance(a, seq_type)
        b_is_seq = isinstance(b, seq_type)
        if a_is_seq or b_is_seq:
            # `zip` silently stops at the shorter input, so the lengths have
            # to be compared explicitly.
            return (a_is_seq and b_is_seq and len(a) == len(b) and all(
                nested_tensors_equal(a_, b_) for a_, b_ in zip(a, b)))

    # Both a and b are scalars
    return a == b


BatchedTensorInputs: TypeAlias = Mapping[str, NestedTensors]
"""
A dictionary containing nested tensors which have been batched via
:meth:`MultiModalKwargs.batch`.
"""


258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
@dataclass(frozen=True)
class MultiModalFieldItem:
    """
    Pairs a piece of data in :class:`MultiModalKwargs` with the field
    metadata describing it, for a single data item in
    :class:`MultiModalDataItems`.
    """
    field: "BaseMultiModalField"
    data: NestedTensors

    def __eq__(self, other: object) -> bool:
        """
        Two items are equal when their fields are equal and their data match
        according to :func:`nested_tensors_equal`.
        """
        if isinstance(other, self.__class__):
            if self.field == other.field:
                return nested_tensors_equal(self.data, other.data)

        return False


@dataclass(frozen=True)
class BaseMultiModalField(ABC):
    """
    Abstract base class for a field in :class:`MultiModalKwargs`.

    A field is identified by its ``key`` and the ``modality`` it belongs to.
    """
    key: str
    modality: str

    @abstractmethod
    def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
        """Combine the data of several items into one value."""
        raise NotImplementedError

    def _build_item(self, data: NestedTensors) -> MultiModalFieldItem:
        """Wrap ``data`` in a :class:`MultiModalFieldItem` tied to this field."""
        return MultiModalFieldItem(field=self, data=data)

    def reduce(self, batch: list[MultiModalFieldItem]) -> MultiModalFieldItem:
        """
        Merge multiple instances of :class:`MultiModalFieldItem` together.

        Raises:
            ValueError: If the items in ``batch`` do not all share one field.
        """
        fields = [entry.field for entry in batch]
        distinct_fields = set(fields)
        if len(distinct_fields) > 1:
            raise ValueError(f"Cannot merge different {fields=}")

        merged = self._reduce_data([entry.data for entry in batch])
        return self._build_item(merged)


@dataclass(frozen=True)
class MultiModalBatchedField(BaseMultiModalField):
    """
    A :class:`BaseMultiModalField` implementation where an item is obtained by
    directly indexing into the first dimension of the underlying data.
    """

    def build_items(self, batch: NestedTensors) -> list[MultiModalFieldItem]:
        """Create one item per element obtained by iterating ``batch``."""
        return list(map(self._build_item, batch))

    def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
        """
        Stack ``batch`` into a single tensor when it is a non-empty list of
        tensors that all share one shape; otherwise return it unchanged.
        """
        if len(batch) == 0:
            return batch
        if not is_list_of(batch, torch.Tensor, check="all"):
            return batch

        expected_shape = batch[0].shape
        if any(tensor.shape != expected_shape for tensor in batch):
            return batch

        return torch.stack(batch)


@dataclass(frozen=True)
class MultiModalFlatField(BaseMultiModalField):
    """
    A :class:`BaseMultiModalField` implementation where an item is obtained by
    slicing along the first dimension of the underlying data.
    """

    def build_items(
        self,
        batch: NestedTensors,
        slices: Sequence[slice],
    ) -> list[MultiModalFieldItem]:
        """Create one item for each slice of ``batch`` listed in ``slices``."""
        items = []
        for window in slices:
            items.append(self._build_item(batch[window]))
        return items

    def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
        """
        Concatenate ``batch`` into a single tensor when it is a non-empty list
        of tensors whose shapes agree after the first dimension; otherwise
        flatten it by one level into a plain list.
        """
        can_concat = False
        if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"):
            trailing_shape = batch[0].shape[1:]
            can_concat = all(tensor.shape[1:] == trailing_shape
                             for tensor in batch)

        if can_concat:
            return torch.concat(batch)

        flattened = []
        for chunk in batch:
            flattened.extend(chunk)
        return flattened


class MultiModalFieldConfig:
    """
    Describes which :class:`BaseMultiModalField` subclass to use for a key,
    together with the modality and any extra arguments that are forwarded to
    that field's ``build_items`` method.
    """

    @staticmethod
    def batched(modality: str):
        """Create a config backed by :class:`MultiModalBatchedField`."""
        config = MultiModalFieldConfig(
            field_cls=MultiModalBatchedField,
            modality=modality,
        )
        return config

    @staticmethod
    def flat(modality: str, slices: Sequence[slice]):
        """
        Create a config backed by :class:`MultiModalFlatField`, where
        ``slices`` selects the portion of the batch belonging to each item.
        """
        config = MultiModalFieldConfig(
            field_cls=MultiModalFlatField,
            modality=modality,
            slices=slices,
        )
        return config

    def __init__(
        self,
        field_cls: type[BaseMultiModalField],
        modality: str,
        **field_config: Any,
    ) -> None:
        super().__init__()

        self._modality = modality
        self._field_cls = field_cls
        # Extra keyword arguments passed on to `field.build_items`
        self._field_config = field_config

    def build_items(
        self,
        key: str,
        batch: NestedTensors,
    ) -> list[MultiModalFieldItem]:
        """
        Instantiate the configured field for ``key`` and split ``batch`` into
        per-item :class:`MultiModalFieldItem` objects.
        """
        extra_kwargs = self._field_config
        field = self._field_cls(key=key, modality=self._modality)
        return field.build_items(batch, **extra_kwargs)  # type: ignore


379
380
381
382
class MultiModalKwargs(UserDict[str, NestedTensors]):
    """
    A dictionary that represents the keyword arguments to
    :meth:`~torch.nn.Module.forward`.

    The metadata :code:`items_by_key` defines how to split batched keyword
    arguments corresponding to each data item in :class:`MultiModalDataItems`:

    - For a keyword argument, we can access the :code:`i` th item in the batch
      via :code:`items_by_key[key][i]`.
    - We can gather the keyword arguments belonging to a modality by finding
      the keys with items that belong to that modality, then accessing
      the :code:`i` th item in the batch for each such key.

    Example:

        .. code-block:: python

            items_by_key={
                "pixel_values": [a, b, c, d],  # "image" modality
                "image_grid_thw": [e, f, g, h],  # "image" modality
                "pixel_values_video": [i, j, k],  # "video" modality
                "video_grid_thw": [l, m, n],  # "video" modality
            }

        - The keyword arguments belonging to the first image are
          :code:`{"pixel_values": a, "image_grid_thw": e}`.
        - The keyword arguments belonging to the second video are
          :code:`{"pixel_values_video": j, "video_grid_thw": m}`.
    """

    @staticmethod
    def from_hf_inputs(
        hf_inputs: BatchFeature,
        config_by_key: Mapping[str, MultiModalFieldConfig],
        *,
        enable_sanity_checks: bool = False,
    ) -> "MultiModalKwargs":
        """
        Construct a new :class:`MultiModalKwargs` from HuggingFace processor
        outputs.

        Args:
            hf_inputs: The processor outputs to read batched values from.
            config_by_key: For each key to extract, the configuration that
                splits the batched value into per-item data.
            enable_sanity_checks: Forwarded to the constructor.
        """
        # NOTE: This skips fields in `hf_inputs` that are not in `config_by_key`
        # We assume that those fields are not used in vLLM
        items_by_key = {
            key: config.build_items(key, batch)
            for key, config in config_by_key.items()
            if (batch := hf_inputs.get(key)) is not None
        }

        return MultiModalKwargs.from_items_by_key(
            items_by_key,
            enable_sanity_checks=enable_sanity_checks,
        )

    @staticmethod
    def from_items_by_key(
        items_by_key: Mapping[str, list[MultiModalFieldItem]],
        *,
        enable_sanity_checks: bool = False,
    ) -> "MultiModalKwargs":
        """
        Construct a new :class:`MultiModalKwargs` by merging, for each key,
        its list of items into a single value.

        Keys whose item list is empty are skipped: there is no item to take
        the field from, so nothing can be merged for them.
        """
        non_empty = {
            key: items
            for key, items in items_by_key.items() if items
        }

        data = {
            key: items[0].field.reduce(items).data
            for key, items in non_empty.items()
        }

        return MultiModalKwargs(data,
                                items_by_key=non_empty,
                                enable_sanity_checks=enable_sanity_checks)

    def __init__(
        self,
        data: Mapping[str, NestedTensors],
        *,
        items_by_key: Union[Mapping[str, list[MultiModalFieldItem]],
                            None] = None,
        enable_sanity_checks: bool = False,
    ) -> None:
        """
        Args:
            data: The keyword arguments themselves.
            items_by_key: Per-key list of items describing how ``data`` splits
                into individual data items. Defaults to no metadata.
            enable_sanity_checks: If set, assert that all keys belonging to
                the same modality have the same number of items.
        """
        super().__init__(data)

        # Use `None` rather than a shared mutable `{}` as the default
        if items_by_key is None:
            items_by_key = {}

        # Shallow copy to avoid footgun in case a defaultdict is passed in
        self._items_by_key = dict(items_by_key)

        keys_by_modality = defaultdict[str, set[str]](set)
        for key, items in self._items_by_key.items():
            for item in items:
                keys_by_modality[item.field.modality].add(key)

        self._keys_by_modality = dict(keys_by_modality)

        if enable_sanity_checks:
            for modality, keys in self._keys_by_modality.items():
                items_in_modality = {k: self._items_by_key[k] for k in keys}
                batch_sizes = {k: len(v) for k, v in items_in_modality.items()}
                batch_size = next(iter(batch_sizes.values()), 0)
                assert all(bs == batch_size
                           for bs in batch_sizes.values()), dict(
                               modality=modality,
                               batch_sizes=batch_sizes,
                               items_by_key=items_by_key)

    @staticmethod
    def _try_stack(nested_tensors: NestedTensors) -> NestedTensors:
        """
        Stack the inner dimensions that have the same shape in
        a nested list of tensors.

        Thus, a dimension represented by a list means that the inner
        dimensions are different for each element along that dimension.
        """
        if isinstance(nested_tensors, torch.Tensor):
            return nested_tensors

        # TODO: Remove these once all models have been migrated
        if isinstance(nested_tensors, np.ndarray):
            return torch.from_numpy(nested_tensors)
        if isinstance(nested_tensors, (int, float)):
            return torch.tensor(nested_tensors)

        stacked = [MultiModalKwargs._try_stack(t) for t in nested_tensors]
        if not stacked:
            # Nothing to stack; `torch.stack` does not accept an empty list.
            return stacked
        if not is_list_of(stacked, torch.Tensor, check="all"):
            # Only tensors (not lists) can be stacked.
            return stacked

        tensors_ = cast(list[torch.Tensor], stacked)
        if any(t.shape != tensors_[0].shape for t in tensors_):
            # The tensors have incompatible shapes and can't be stacked.
            return tensors_

        return torch.stack(tensors_)

    @staticmethod
    def batch(inputs_list: list["MultiModalKwargs"]) -> BatchedTensorInputs:
        """
        Batch multiple inputs together into a dictionary.

        The resulting dictionary has the same keys as the inputs.
        If the corresponding value from each input is a tensor and they all
        share the same shape, the output value is a single batched tensor;
        otherwise, the output value is a list containing the original value
        from each input.
        """
        if not inputs_list:
            return {}

        # We need to consider the case where each item in the batch
        # contains different modalities (i.e. different keys).
        item_lists = defaultdict[str, list[NestedTensors]](list)

        for inputs in inputs_list:
            for k, v in inputs.items():
                item_lists[k].append(v)

        return {
            k: MultiModalKwargs._try_stack(item_list)
            for k, item_list in item_lists.items()
        }

    @staticmethod
    def as_kwargs(
        batched_inputs: BatchedTensorInputs,
        *,
        device: torch.types.Device,
    ) -> BatchedTensorInputs:
        """
        Return a copy of ``batched_inputs`` in which every leaf has been moved
        to ``device`` using a non-blocking transfer.
        """
        json_inputs = cast(JSONTree[torch.Tensor], batched_inputs)

        json_mapped = json_map_leaves(
            lambda x: x.to(device, non_blocking=True),
            json_inputs,
        )

        return cast(BatchedTensorInputs, json_mapped)

    def __eq__(self, other: object) -> bool:
        """
        Two instances are equal when their item metadata match, they have the
        same keys, and the values under each key match according to
        :func:`nested_tensors_equal`.
        """
        if not isinstance(other, self.__class__):
            return False
        if self._items_by_key != other._items_by_key:
            return False

        ks = self.keys()
        return (ks == other.keys()
                and all(nested_tensors_equal(self[k], other[k]) for k in ks))

    def get_item(self, key: str, item_index: int) -> MultiModalFieldItem:
        """Return the ``item_index`` th item recorded for ``key``."""
        return self._items_by_key[key][item_index]

    def get_items_by_modality(
        self,
        modality: str,
        item_index: int,
    ) -> Mapping[str, MultiModalFieldItem]:
        """
        Get the keyword arguments corresponding to an item identified by
        its modality and index.

        Raises:
            KeyError: If no item of ``modality`` was recorded.
        """
        keys_to_gather = self._keys_by_modality[modality]

        return {
            key: self.get_item(key, item_index)
            for key in keys_to_gather if key in self
        }

    @staticmethod
    def from_items_by_modality(
        items_by_modality: Mapping[str, list[Mapping[str,
                                                     MultiModalFieldItem]]],
        *,
        enable_sanity_checks: bool = False,
    ) -> "MultiModalKwargs":
        """
        Construct a new :class:`MultiModalKwargs` from multiple items returned
        by :meth:`get_items_by_modality`.
        """
        items_by_key = defaultdict[str, list[MultiModalFieldItem]](list)
        for modality_items in items_by_modality.values():
            for item_kwargs in modality_items:
                for k, v in item_kwargs.items():
                    items_by_key[k].append(v)

        return MultiModalKwargs.from_items_by_key(
            items_by_key,
            enable_sanity_checks=enable_sanity_checks,
        )

599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615

MultiModalPlaceholderDict = Mapping[str, Sequence[PlaceholderRange]]
"""
A dictionary containing placeholder ranges.
"""


class MultiModalInputsV2(TypedDict):
    """
    Represents the outputs of :class:`vllm.multimodal.MultiModalProcessor`,
    ready to be passed to vLLM internals.
    """

    type: Literal["multimodal"]
    """The type of inputs."""

    prompt: str
    """The processed prompt text."""

    prompt_token_ids: list[int]
    """The processed token IDs which includes placeholder tokens."""

    token_type_ids: NotRequired[list[int]]
    """The token type IDs of the prompt."""

    mm_kwargs: MultiModalKwargs
    """Keyword arguments to be directly passed to the model after batching."""

    mm_hashes: NotRequired[list[str]]
    """The hashes of the multi-modal data."""

    mm_placeholders: MultiModalPlaceholderDict
    """
    For each modality, information about the placeholder tokens in
    :code:`prompt_token_ids`.
    """