# inputs.py
1
from abc import ABC, abstractmethod
2
from collections import UserDict, defaultdict
3
4
from collections.abc import Mapping, Sequence
from dataclasses import dataclass
5
6
from typing import (Any, Literal, Optional, TypedDict, TypeVar, Union, cast,
                    final)
7
8
9
10
11

import numpy as np
import torch
import torch.types
from PIL.Image import Image
12
from transformers import BatchFeature
13
from typing_extensions import NotRequired, TypeAlias
14

15
from vllm.utils import JSONTree, full_groupby, is_list_of, json_map_leaves
16
17
18

# Generic item type; used below to parameterize :data:`ModalityData`.
_T = TypeVar("_T")

HfImageItem: TypeAlias = Union[Image, np.ndarray, torch.Tensor]
"""
A :class:`transformers.image_utils.ImageInput` representing a single image
item, which can be passed to a HuggingFace :code:`ImageProcessor`.
"""

HfVideoItem: TypeAlias = Union[list[Image], np.ndarray, torch.Tensor,
                               list[np.ndarray], list[torch.Tensor]]
"""
A :class:`transformers.image_utils.VideoInput` representing a single video
item, which can be passed to a HuggingFace :code:`VideoProcessor`.
"""

HfAudioItem: TypeAlias = Union[list[float], np.ndarray, torch.Tensor]
"""
Represents a single audio item, which can be passed to a HuggingFace
:code:`AudioProcessor`.
"""

38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
ImageItem: TypeAlias = Union[HfImageItem, torch.Tensor]
"""
A :class:`transformers.image_utils.ImageInput` representing a single image
item, which can be passed to a HuggingFace :code:`ImageProcessor`.

Alternatively, a 3-D tensor or batch of 2-D tensors,
which are treated as image embeddings;
these are directly passed to the model without HF processing.
"""

VideoItem: TypeAlias = Union[HfVideoItem, torch.Tensor]
"""
A :class:`transformers.image_utils.VideoInput` representing a single video
item, which can be passed to a HuggingFace :code:`VideoProcessor`.

Alternatively, a 3-D tensor or batch of 2-D tensors,
which are treated as video embeddings;
these are directly passed to the model without HF processing.
"""

AudioItem: TypeAlias = Union[HfAudioItem, tuple[np.ndarray, float],
                             torch.Tensor]
"""
Represents a single audio item, which can be passed to a HuggingFace
:code:`AudioProcessor`.

Alternatively, a tuple `(audio, sampling_rate)`, where the sampling rate
is different from that expected by the model;
these are resampled to the model's sampling rate before being processed by HF.

Alternatively, a 3-D tensor or batch of 2-D tensors,
which are treated as audio embeddings;
these are directly passed to the model without HF processing.
"""

ModalityData: TypeAlias = Union[_T, list[_T]]
"""
Either a single data item, or a list of data items.

The number of data items allowed per modality is restricted by
:code:`--limit-mm-per-prompt`.
"""


@final
class MultiModalDataBuiltins(TypedDict, total=False):
    """
    Type annotations for modality types predefined by vLLM.

    The dictionary is declared with :code:`total=False`, so every key is
    optional.
    """

    image: ModalityData[ImageItem]
    """The input image(s)."""

    video: ModalityData[VideoItem]
    """The input video(s)."""

    audio: ModalityData[AudioItem]
    """The input audio(s)."""


96
MultiModalDataDict: TypeAlias = Mapping[str, ModalityData[Any]]
"""
A dictionary containing an entry for each modality type to input.

Note:
    This dictionary also accepts modality keys defined outside
    :class:`MultiModalDataBuiltins` as long as a customized plugin
    is registered through the :class:`~vllm.multimodal.MULTIMODAL_REGISTRY`.
    Read more on that :ref:`here <adding-multimodal-plugin>`.
"""


class PlaceholderRange(TypedDict):
    """
    Placeholder location information for multi-modal data.

    Example:

        Prompt: :code:`AAAA BBBB What is in these images?`

        Images A and B will have:

        .. code-block::

            A: { "offset": 0, "length": 4 }
            B: { "offset": 5, "length": 4 }
    """

    offset: int
    """The start index of the placeholder in the prompt."""

    length: int
    """The length of the placeholder."""


131
132
NestedTensors = Union[list["NestedTensors"], list[torch.Tensor], torch.Tensor,
                      tuple[torch.Tensor, ...]]
"""
Uses a list instead of a tensor if the dimensions of each element do not match.
"""

137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156

def nested_tensors_equal(a: NestedTensors, b: NestedTensors) -> bool:
    """
    Equality check between :data:`NestedTensors` objects.

    Two values are equal when they have the same nesting structure
    (tensor vs. list vs. tuple, with the same number of elements at every
    level) and every pair of corresponding tensors has the same shape and
    the same element values.

    Args:
        a: The first nested structure.
        b: The second nested structure.

    Returns:
        ``True`` if both structures match as described above.
    """
    a_is_tensor = isinstance(a, torch.Tensor)
    b_is_tensor = isinstance(b, torch.Tensor)
    if a_is_tensor or b_is_tensor:
        if not (a_is_tensor and b_is_tensor):
            return False

        # Compare shapes first: a bare `a == b` broadcasts, so tensors of
        # different (but broadcastable) shapes could be reported as equal,
        # and non-broadcastable shapes would raise instead of returning False.
        return a.shape == b.shape and bool((a == b).all().item())

    for seq_type in (list, tuple):
        if isinstance(a, seq_type) or isinstance(b, seq_type):
            if not (isinstance(a, seq_type) and isinstance(b, seq_type)):
                return False

            # `zip` stops at the shorter input, so the lengths must be
            # checked explicitly; otherwise a prefix would compare equal.
            if len(a) != len(b):
                return False

            return all(nested_tensors_equal(x, y) for x, y in zip(a, b))

    # Both a and b are scalars
    return a == b


BatchedTensorInputs: TypeAlias = Mapping[str, NestedTensors]
"""
A dictionary containing nested tensors which have been batched via
:meth:`MultiModalKwargs.batch`.
"""


163
@dataclass(frozen=True)
class MultiModalFieldElem:
    """Contains metadata and data of an item in :class:`MultiModalKwargs`."""

    # The field that produced this element; supplies the key and modality.
    field: "BaseMultiModalField"
    # The (possibly nested) tensor data carried by this element.
    data: NestedTensors

    def __eq__(self, other: object) -> bool:
        """
        Compare two elements by field and by tensor contents.

        Tensor data cannot be compared with a plain ``==`` (the result is
        itself a tensor), so the data is compared with
        :func:`nested_tensors_equal` instead of the dataclass default.
        """
        if not isinstance(other, self.__class__):
            return False

        return (self.field == other.field
                and nested_tensors_equal(self.data, other.data))


@dataclass(frozen=True)
class BaseMultiModalField(ABC):
    """
    Abstract base class for a field in :class:`MultiModalKwargs`.

    Subclasses implement :meth:`_reduce_data` to define how the data of
    several elements is merged into one.
    """
    # Key under which elements of this field are stored.
    key: str
    # Name of the modality this field belongs to.
    modality: str

    @abstractmethod
    def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
        """Merge the data of multiple elements into a single value."""
        raise NotImplementedError

    def _build_elem(self, data: NestedTensors) -> MultiModalFieldElem:
        """Wrap :code:`data` in an element that refers back to this field."""
        return MultiModalFieldElem(self, data)

    def reduce(self, batch: list[MultiModalFieldElem]) -> MultiModalFieldElem:
        """
        Merge multiple instances of :class:`MultiModalFieldElem` together.

        Raises:
            ValueError: If the elements do not all share the same field.
        """
        fields = [item.field for item in batch]
        if len(set(fields)) > 1:
            raise ValueError(f"Cannot merge different {fields=}")

        data = self._reduce_data([item.data for item in batch])

        return self._build_elem(data)


@dataclass(frozen=True)
class MultiModalBatchedField(BaseMultiModalField):
    """
    A :class:`BaseMultiModalField` implementation where an element in the batch
    is obtained by indexing into the first dimension of the underlying data.
    """

    def build_elems(self, batch: NestedTensors) -> list[MultiModalFieldElem]:
        """Create one element per entry yielded by iterating :code:`batch`."""
        return [self._build_elem(item) for item in batch]

    def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
        """
        Stack the entries into one tensor when they are all tensors of the
        same shape; otherwise return the list unchanged.
        """
        if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"):
            first_shape = batch[0].shape
            if all(elem.shape == first_shape for elem in batch):
                return torch.stack(batch)

        return batch


@dataclass(frozen=True)
class MultiModalFlatField(BaseMultiModalField):
    """
    A :class:`BaseMultiModalField` implementation where an element in the batch
    is obtained by slicing along the first dimension of the underlying data.
    """

    def build_elems(
        self,
        batch: NestedTensors,
        slices: Sequence[slice],
    ) -> list[MultiModalFieldElem]:
        """Create one element per slice, each holding :code:`batch[slice]`."""
        return [self._build_elem(batch[slice_]) for slice_ in slices]

    def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
        """
        Concatenate the entries along the first dimension when they are all
        tensors whose remaining dimensions match; otherwise flatten the
        entries by one level into a single list.
        """
        if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"):
            first_shape = batch[0].shape
            # Only the first dimension may differ for a concatenation.
            if all(elem.shape[1:] == first_shape[1:] for elem in batch):
                return torch.concat(batch)

        return [e for elem in batch for e in elem]


class MultiModalFieldConfig:
    """
    Describes how to build the elements of one field: which field class to
    instantiate, the modality it belongs to, and any extra keyword arguments
    forwarded to the field's :code:`build_elems`.
    """

    @staticmethod
    def batched(modality: str) -> "MultiModalFieldConfig":
        """Create a config that uses :class:`MultiModalBatchedField`."""
        return MultiModalFieldConfig(
            field_cls=MultiModalBatchedField,
            modality=modality,
        )

    @staticmethod
    def flat(modality: str,
             slices: Sequence[slice]) -> "MultiModalFieldConfig":
        """
        Create a config that uses :class:`MultiModalFlatField`, forwarding
        :code:`slices` to its :code:`build_elems`.
        """
        return MultiModalFieldConfig(
            field_cls=MultiModalFlatField,
            modality=modality,
            slices=slices,
        )

    def __init__(
        self,
        field_cls: type[BaseMultiModalField],
        modality: str,
        **field_config: Any,
    ) -> None:
        """
        Args:
            field_cls: The field class to instantiate in :meth:`build_elems`.
            modality: The modality passed to the field.
            **field_config: Extra keyword arguments passed to the field's
                :code:`build_elems`.
        """
        super().__init__()

        self.field_cls = field_cls
        self.modality = modality
        self.field_config = field_config

    def build_elems(
        self,
        key: str,
        batch: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        """
        Instantiate the field for :code:`key` and let it split
        :code:`batch` into elements.
        """
        field = self.field_cls(key=key, modality=self.modality)
        # `build_elems` is not declared on the base class and its signature
        # differs between field classes, hence the ignore.
        return field.build_elems(batch, **self.field_config)  # type: ignore


281
282
283
284
class MultiModalKwargsItem(UserDict[str, MultiModalFieldElem]):
    """
    A collection of :class:`MultiModalFieldElem`
    corresponding to a data item in :class:`MultiModalDataItems`.
    """

    @staticmethod
    def from_elems(elems: Sequence[MultiModalFieldElem]):
        """
        Build an item from elements, keyed by each element's field key.

        If several elements share a key, the last one wins.
        """
        return MultiModalKwargsItem({elem.field.key: elem for elem in elems})

    @property
    def modality(self) -> str:
        """
        The single modality shared by all elements of this item.

        Fails an assertion unless the elements resolve to exactly one
        modality (this includes the case of an empty item).
        """
        modalities = {elem.field.modality for elem in self.data.values()}
        assert len(modalities) == 1, f"Found different modalities={modalities}"
        return next(iter(modalities))


298
299
300
301
302
303
# NOTE: UserDict is for V0 compatibility.
# V1 should access individual items via `get_item`.
class MultiModalKwargs(UserDict[str, NestedTensors]):
    """
    A dictionary that represents the keyword arguments to
    :meth:`~torch.nn.Module.forward`.

    The metadata :code:`items` enables us to obtain the keyword arguments
    corresponding to each data item in :class:`MultiModalDataItems`, via
    :meth:`get_item` and :meth:`get_items`.
    """

    @staticmethod
    def from_hf_inputs(
        hf_inputs: BatchFeature,
        config_by_key: Mapping[str, MultiModalFieldConfig],
    ):
        """
        Construct a new :class:`MultiModalKwargs` from HF processor outputs.

        For every key in :code:`config_by_key` that is present in
        :code:`hf_inputs`, the configured field splits the value into
        elements. Keys are then grouped by modality, and the i-th elements
        of all keys in a modality form the i-th item of that modality.

        Raises:
            ValueError: If the keys of one modality produce different
                numbers of elements.
        """
        # NOTE: This skips fields in `hf_inputs` that are not in `config_by_key`
        # We assume that those fields are not used in vLLM
        elems_by_key = dict[str, Sequence[MultiModalFieldElem]]()
        keys_by_modality = defaultdict[str, set[str]](set)
        for key, config in config_by_key.items():
            batch = hf_inputs.get(key)
            if batch is not None:
                elems = config.build_elems(key, batch)
                # Keys that yield no elements are dropped entirely.
                if len(elems) > 0:
                    elems_by_key[key] = elems
                    keys_by_modality[config.modality].add(key)

        items = list[MultiModalKwargsItem]()
        for modality, keys in keys_by_modality.items():
            elems_in_modality = {k: elems_by_key[k] for k in keys}
            batch_sizes = {k: len(v) for k, v in elems_in_modality.items()}

            if len(set(batch_sizes.values())) > 1:
                raise ValueError(
                    f"Cannot merge different batch sizes for {modality=}! "
                    f"Found: {batch_sizes=}")

            batch_size = next(iter(batch_sizes.values()))
            for item_idx in range(batch_size):
                elems = [v[item_idx] for v in elems_in_modality.values()]
                items.append(MultiModalKwargsItem.from_elems(elems))

        return MultiModalKwargs.from_items(items)

    @staticmethod
    def from_items(items: Sequence[MultiModalKwargsItem]):
        """Construct a new :class:`MultiModalKwargs` from multiple items."""
        elems_by_key = defaultdict[str, list[MultiModalFieldElem]](list)
        for item in items:
            for key, elem in item.items():
                elems_by_key[key].append(elem)

        # Every list holds at least one element because a key is only
        # created when an element is appended to it.
        data = {
            key: elems[0].field.reduce(elems).data
            for key, elems in elems_by_key.items()
        }

        return MultiModalKwargs(data, items=items)

    def __init__(
        self,
        data: Mapping[str, NestedTensors],
        *,
        items: Optional[Sequence[MultiModalKwargsItem]] = None,
    ) -> None:
        """
        Args:
            data: The keyword arguments exposed through the dict interface.
            items: Optional per-item metadata, grouped by each item's
                modality. Without it, the item accessors
                (:meth:`get_item` etc.) raise :exc:`RuntimeError`.
        """
        super().__init__(data)

        # NOTE(review): `full_groupby` is assumed to yield
        # (modality, items) pairs - confirm against `vllm.utils`.
        items_by_modality = full_groupby(items or [], key=lambda x: x.modality)
        self._items_by_modality = dict(items_by_modality)

    @property
    def modalities(self):
        """The modalities for which item metadata is available."""
        return self._items_by_modality.keys()

    @staticmethod
    def _try_stack(nested_tensors: NestedTensors) -> NestedTensors:
        """
        Stack the inner dimensions that have the same shape in
        a nested list of tensors.

        Thus, a dimension represented by a list means that the inner
        dimensions are different for each element along that dimension.
        """
        if isinstance(nested_tensors, torch.Tensor):
            return nested_tensors

        # TODO: Remove these once all models have been migrated
        if isinstance(nested_tensors, np.ndarray):
            return torch.from_numpy(nested_tensors)
        if isinstance(nested_tensors, (int, float)):
            return torch.tensor(nested_tensors)

        stacked = [MultiModalKwargs._try_stack(t) for t in nested_tensors]
        if not stacked:
            # Nothing to stack; `torch.stack` rejects an empty sequence.
            return stacked

        if not is_list_of(stacked, torch.Tensor, check="all"):
            # Only tensors (not lists) can be stacked.
            return stacked

        tensors_ = cast(list[torch.Tensor], stacked)
        if any(t.shape != tensors_[0].shape for t in tensors_):
            # The tensors have incompatible shapes and can't be stacked.
            return tensors_

        return torch.stack(tensors_)

    @staticmethod
    def batch(inputs_list: list["MultiModalKwargs"]) -> BatchedTensorInputs:
        """
        Batch multiple inputs together into a dictionary.

        The resulting dictionary has the same keys as the inputs.
        If the corresponding value from each input is a tensor and they all
        share the same shape, the output value is a single batched tensor;
        otherwise, the output value is a list containing the original value
        from each input.
        """
        if len(inputs_list) == 0:
            return {}

        # We need to consider the case where each item in the batch
        # contains different modalities (i.e. different keys).
        item_lists = defaultdict[str, list[NestedTensors]](list)

        for inputs in inputs_list:
            for k, v in inputs.items():
                item_lists[k].append(v)

        return {
            k: MultiModalKwargs._try_stack(item_list)
            for k, item_list in item_lists.items()
        }

    @staticmethod
    def as_kwargs(
        batched_inputs: BatchedTensorInputs,
        *,
        device: torch.types.Device,
    ) -> BatchedTensorInputs:
        """
        Move the batched inputs to :code:`device` using non-blocking
        transfers.

        NOTE(review): relies on `json_map_leaves` applying the function to
        each leaf tensor and preserving the structure - confirm against
        `vllm.utils`.
        """
        json_inputs = cast(JSONTree[torch.Tensor], batched_inputs)

        json_mapped = json_map_leaves(
            lambda x: x.to(device, non_blocking=True),
            json_inputs,
        )

        return cast(BatchedTensorInputs, json_mapped)

    def __eq__(self, other: object) -> bool:
        """
        Compare item metadata, keys and tensor contents.

        Values are compared with :func:`nested_tensors_equal` because a
        plain ``==`` on tensors does not produce a boolean.
        """
        if not isinstance(other, self.__class__):
            return False
        if self._items_by_modality != other._items_by_modality:
            return False

        ks = self.keys()
        return (ks == other.keys()
                and all(nested_tensors_equal(self[k], other[k]) for k in ks))

    def _validate_modality(self, method_name: str, modality: str) -> None:
        """
        Ensure item metadata exists and contains :code:`modality`.

        Raises:
            RuntimeError: If no item metadata is stored.
            KeyError: If :code:`modality` has no items.
        """
        if not self._items_by_modality:
            raise RuntimeError(
                f"`{method_name}` is not supported when "
                "MultiModalKwargs is not initialized with `items`")

        if modality not in self._items_by_modality:
            available_modalities = set(self._items_by_modality.keys())
            raise KeyError(f"Modality {modality!r} not found. "
                           f"Available modalities: {available_modalities}")

    def get_item_count(self, modality: str) -> int:
        """Get the number of items belonging to a modality."""
        self._validate_modality("get_item_count", modality)
        return len(self._items_by_modality[modality])

    def get_item(self, modality: str, item_index: int) -> MultiModalKwargsItem:
        """
        Get the keyword arguments corresponding to an item identified by
        its modality and index.
        """
        self._validate_modality("get_item", modality)
        return self._items_by_modality[modality][item_index]

    def get_items(self, modality: str) -> Sequence[MultiModalKwargsItem]:
        """
        Get the keyword arguments corresponding to each item belonging to
        a modality.
        """
        self._validate_modality("get_items", modality)
        return self._items_by_modality[modality]

488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504

# Maps a string key to a sequence of :class:`PlaceholderRange`.
# Used as the type of `MultiModalInputsV2.mm_placeholders`, which is
# documented as holding placeholder information for each modality.
MultiModalPlaceholderDict = Mapping[str, Sequence[PlaceholderRange]]
"""
A dictionary containing placeholder ranges.
"""


class MultiModalInputsV2(TypedDict):
    """
    Represents the outputs of :class:`vllm.multimodal.MultiModalProcessor`,
    ready to be passed to vLLM internals.
    """

    type: Literal["multimodal"]
    """The type of inputs."""

    prompt: str
    """The processed prompt text."""

    prompt_token_ids: list[int]
    """The processed token IDs which includes placeholder tokens."""

    token_type_ids: NotRequired[list[int]]
    """The token type IDs of the prompt."""

    mm_kwargs: MultiModalKwargs
    """Keyword arguments to be directly passed to the model after batching."""

    mm_hashes: NotRequired[list[str]]
    """The hashes of the multi-modal data."""

    mm_placeholders: MultiModalPlaceholderDict
    """
    For each modality, information about the placeholder tokens in
    :code:`prompt_token_ids`.
    """