"tests/models/quantization/test_gguf.py" did not exist on "3c9817d297432c39c89fda666291de1f440838f0"
inputs.py 17 KB
Newer Older
1
from abc import ABC, abstractmethod
2
from collections import UserDict, defaultdict
3
4
from collections.abc import Mapping, Sequence
from dataclasses import dataclass
5
6
from typing import (TYPE_CHECKING, Any, Literal, Optional, TypedDict, TypeVar,
                    Union, cast, final)
7
8
9
10
11

import numpy as np
import torch
import torch.types
from PIL.Image import Image
12
from transformers import BatchFeature
13
from typing_extensions import NotRequired, TypeAlias
14

15
from vllm.utils import JSONTree, full_groupby, is_list_of, json_map_leaves
16

17
18
if TYPE_CHECKING:
    from .hasher import MultiModalHashDict
19
20
21

_T = TypeVar("_T")

HfImageItem: TypeAlias = Union[Image, np.ndarray, torch.Tensor]
"""
A :class:`transformers.image_utils.ImageInput` representing a single image
item, which can be passed to a HuggingFace :code:`ImageProcessor`.
"""

HfVideoItem: TypeAlias = Union[list[Image], np.ndarray, torch.Tensor,
                               list[np.ndarray], list[torch.Tensor]]
"""
A :class:`transformers.image_utils.VideoInput` representing a single video
item, which can be passed to a HuggingFace :code:`VideoProcessor`.
"""

HfAudioItem: TypeAlias = Union[list[float], np.ndarray, torch.Tensor]
"""
Represents a single audio item, which can be passed to a HuggingFace
:code:`AudioProcessor`.
"""

41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
ImageItem: TypeAlias = Union[HfImageItem, torch.Tensor]
"""
A :class:`transformers.image_utils.ImageInput` representing a single image
item, which can be passed to a HuggingFace :code:`ImageProcessor`.

Alternatively, a 3-D tensor or batch of 2-D tensors,
which are treated as image embeddings;
these are directly passed to the model without HF processing.
"""

VideoItem: TypeAlias = Union[HfVideoItem, torch.Tensor]
"""
A :class:`transformers.image_utils.VideoInput` representing a single video
item, which can be passed to a HuggingFace :code:`VideoProcessor`.

Alternatively, a 3-D tensor or batch of 2-D tensors,
which are treated as video embeddings;
these are directly passed to the model without HF processing.
"""

AudioItem: TypeAlias = Union[HfAudioItem, tuple[np.ndarray, float],
                             torch.Tensor]
"""
Represents a single audio item, which can be passed to a HuggingFace
:code:`AudioProcessor`.

Alternatively, a tuple `(audio, sampling_rate)`, where the sampling rate
is different from that expected by the model;
these are resampled to the model's sampling rate before being processed by HF.

Alternatively, a 3-D tensor or batch of 2-D tensors,
which are treated as audio embeddings;
these are directly passed to the model without HF processing.
"""

ModalityData: TypeAlias = Union[_T, list[_T]]
"""
Either a single data item, or a list of data items.

The number of data items allowed per modality is restricted by
:code:`--limit-mm-per-prompt`.
"""


@final
class MultiModalDataBuiltins(TypedDict, total=False):
    """Type annotations for modality types predefined by vLLM."""

    image: ModalityData[ImageItem]
    """The input image(s)."""

    video: ModalityData[VideoItem]
    """The input video(s)."""

    audio: ModalityData[AudioItem]
    """The input audio(s)."""


99
MultiModalDataDict: TypeAlias = Mapping[str, ModalityData[Any]]
"""
A dictionary containing an entry for each modality type to input.

The built-in modalities are defined by :class:`MultiModalDataBuiltins`.
"""


class PlaceholderRange(TypedDict):
    """
    Placeholder location information for multi-modal data.

    Example:

        Prompt: :code:`AAAA BBBB What is in these images?`

        Images A and B will have:

        .. code-block::

            A: { "offset": 0, "length": 4 }
            B: { "offset": 5, "length": 4 }
    """

    offset: int
    """The start index of the placeholder in the prompt."""

    length: int
    """The length of the placeholder."""


130
131
NestedTensors = Union[list["NestedTensors"], list[torch.Tensor], torch.Tensor,
                      tuple[torch.Tensor, ...]]
"""
Uses a list instead of a tensor if the dimensions of each element do not match.
"""


def nested_tensors_equal(a: NestedTensors, b: NestedTensors) -> bool:
    """
    Equality check between :data:`NestedTensors` objects.

    - Two tensors are equal only if they have the same shape and all their
      elements compare equal (no broadcasting is applied).
    - Two lists (or two tuples) are equal only if they have the same length
      and are element-wise equal, recursively.
    - A tensor, a list and a tuple are never equal to each other.
    - Anything else (e.g. scalars) is compared with ``==``.
    """
    if isinstance(a, torch.Tensor) or isinstance(b, torch.Tensor):
        if not (isinstance(a, torch.Tensor) and isinstance(b, torch.Tensor)):
            return False
        # Compare shapes first: `a == b` would otherwise broadcast (so a
        # shape-(1,) tensor could "equal" a shape-(3,) one) or raise on
        # non-broadcastable shapes.
        return a.shape == b.shape and bool((a == b).all().item())

    for seq_type in (list, tuple):
        if isinstance(a, seq_type) or isinstance(b, seq_type):
            # The explicit length check matters because `zip` stops at the
            # shorter sequence, which would make a prefix compare equal.
            return (isinstance(a, seq_type) and isinstance(b, seq_type)
                    and len(a) == len(b)
                    and all(nested_tensors_equal(a_, b_)
                            for a_, b_ in zip(a, b)))

    # Both a and b are scalars
    return a == b


BatchedTensorInputs: TypeAlias = Mapping[str, NestedTensors]
"""
A dictionary containing nested tensors which have been batched via
:meth:`MultiModalKwargs.batch`.
"""


162
@dataclass(frozen=True)
class MultiModalFieldElem:
    """Contains metadata and data of an item in :class:`MultiModalKwargs`."""

    # NOTE(review): with `frozen=True` (and the default `eq=True`),
    # `dataclass` also generates `__hash__` from `(field, data)`, while the
    # custom `__eq__` below compares `data` by value. PyTorch tensors hash by
    # identity, so two elements that compare equal may hash differently —
    # confirm elements are never used as set members or dict keys.

    # The field this element belongs to; its `key`/`modality` identify the
    # element and its `reduce` merges elements of the same field.
    field: "BaseMultiModalField"
    # The tensor data of this single element.
    data: NestedTensors

    def __eq__(self, other: object) -> bool:
        """Equal iff same class, equal fields and equal nested tensor data."""
        if not isinstance(other, self.__class__):
            return False

        return (self.field == other.field
                and nested_tensors_equal(self.data, other.data))


@dataclass(frozen=True)
class BaseMultiModalField(ABC):
    """Abstract base class for a field in :class:`MultiModalKwargs`."""
    # The keyword-argument name this field populates.
    key: str
    # The modality (e.g. "image") that this field belongs to.
    modality: str

    @abstractmethod
    def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
        """Merge the data of several elements into a single value."""
        raise NotImplementedError

    def _build_elem(self, data: NestedTensors) -> MultiModalFieldElem:
        """Wrap ``data`` in a :class:`MultiModalFieldElem` bound to this field."""
        return MultiModalFieldElem(self, data)

    def reduce(self, batch: list[MultiModalFieldElem]) -> MultiModalFieldElem:
        """
        Merge multiple instances of :class:`MultiModalFieldElem` together.

        Raises:
            ValueError: If the elements do not all belong to the same field.
        """
        # Fields are frozen dataclasses, hence hashable and comparable by
        # value (class, key and modality).
        fields = [item.field for item in batch]
        if len(set(fields)) > 1:
            raise ValueError(f"Cannot merge different {fields=}")

        data = self._reduce_data([item.data for item in batch])

        return self._build_elem(data)
198
199
200
201
202


@dataclass(frozen=True)
class MultiModalBatchedField(BaseMultiModalField):
    """
    A :class:`BaseMultiModalField` implementation where an element in the batch
    is obtained by indexing into the first dimension of the underlying data.
    """

    def build_elems(self, batch: NestedTensors) -> list[MultiModalFieldElem]:
        """Create one element per entry obtained by iterating over ``batch``."""
        return [self._build_elem(item) for item in batch]

    def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
        """
        Stack the elements along a new first dimension when they are all
        tensors of identical shape; otherwise return them unchanged as a list.
        """
        if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"):
            first_shape = batch[0].shape
            if all(elem.shape == first_shape for elem in batch):
                return torch.stack(batch)

        return batch


@dataclass(frozen=True)
class MultiModalFlatField(BaseMultiModalField):
    """
    A :class:`BaseMultiModalField` implementation where an element in the batch
    is obtained by slicing along the first dimension of the underlying data.
    """

    def build_elems(
        self,
        batch: NestedTensors,
        slices: Sequence[slice],
    ) -> list[MultiModalFieldElem]:
        """Create one element per slice, holding ``batch[slice_]``."""
        return [self._build_elem(batch[slice_]) for slice_ in slices]

    def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
        """
        Concatenate the elements along the first dimension when they are all
        tensors whose trailing dimensions match; otherwise flatten them one
        level into a single list.
        """
        if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"):
            first_shape = batch[0].shape
            # Only the trailing dimensions must agree; the first dimension is
            # the one being concatenated.
            if all(elem.shape[1:] == first_shape[1:] for elem in batch):
                return torch.concat(batch)

        return [e for elem in batch for e in elem]
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266


class MultiModalFieldConfig:
    """
    Describes how to turn one output of a HuggingFace processor into
    per-item :class:`MultiModalFieldElem` objects.

    It stores the field class to instantiate, the modality the field belongs
    to, and extra keyword arguments forwarded to that field's ``build_elems``.
    """

    @staticmethod
    def batched(modality: str):
        """Config for a :class:`MultiModalBatchedField` of ``modality``."""
        return MultiModalFieldConfig(
            field_cls=MultiModalBatchedField,
            modality=modality,
        )

    @staticmethod
    def flat(modality: str, slices: Sequence[slice]):
        """
        Config for a :class:`MultiModalFlatField` of ``modality``, where item
        ``i`` is ``data[slices[i]]``.
        """
        return MultiModalFieldConfig(
            field_cls=MultiModalFlatField,
            modality=modality,
            slices=slices,
        )

    def __init__(
        self,
        field_cls: type[BaseMultiModalField],
        modality: str,
        **field_config: Any,
    ) -> None:
        super().__init__()

        self.field_cls = field_cls
        self.modality = modality
        # Extra keyword arguments passed to `field_cls.build_elems`.
        self.field_config = field_config

    def build_elems(
        self,
        key: str,
        batch: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        """Instantiate the field for ``key`` and split ``batch`` into items."""
        field = self.field_cls(key=key, modality=self.modality)
        # `build_elems` is defined on the concrete subclasses only, not on
        # `BaseMultiModalField`, hence the type-checker suppression.
        return field.build_elems(batch, **self.field_config)  # type: ignore
278
279


280
281
282
283
class MultiModalKwargsItem(UserDict[str, MultiModalFieldElem]):
    """
    A collection of :class:`MultiModalFieldElem`
    corresponding to a data item in :class:`MultiModalDataItems`.

    The mapping is keyed by each element's ``field.key``.
    """

    @staticmethod
    def from_elems(elems: Sequence[MultiModalFieldElem]):
        """Build an item from its elements, keyed by ``elem.field.key``."""
        return MultiModalKwargsItem({elem.field.key: elem for elem in elems})

    @property
    def modality(self) -> str:
        """The single modality shared by all elements of this item."""
        modalities = {elem.field.modality for elem in self.data.values()}
        assert len(modalities) == 1, f"Found different modalities={modalities}"
        return next(iter(modalities))
295
296


297
298
# NOTE: UserDict is for V0 compatibility.
# V1 should access individual items via `get_item`.
class MultiModalKwargs(UserDict[str, NestedTensors]):
    """
    A dictionary that represents the keyword arguments to
    :meth:`~torch.nn.Module.forward`.

    The metadata :code:`items` enables us to obtain the keyword arguments
    corresponding to each data item in :class:`MultiModalDataItems`, via
    :meth:`get_item` and :meth:`get_items`.
    """

    @staticmethod
    def from_hf_inputs(
        hf_inputs: BatchFeature,
        config_by_key: Mapping[str, MultiModalFieldConfig],
    ):
        """
        Construct a new :class:`MultiModalKwargs` from HuggingFace processor
        outputs.

        For every key in ``config_by_key`` whose value in ``hf_inputs`` is not
        ``None``, the value is split into elements using that key's config.
        The elements are then grouped into one :class:`MultiModalKwargsItem`
        per item index within each modality.

        Raises:
            ValueError: If keys of the same modality yield different numbers
                of items.
        """
        # NOTE: This skips fields in `hf_inputs` that are not in `config_by_key`
        # We assume that those fields are not used in vLLM
        elems_by_key = dict[str, Sequence[MultiModalFieldElem]]()
        keys_by_modality = defaultdict[str, set[str]](set)
        for key, config in config_by_key.items():
            batch = hf_inputs.get(key)
            if batch is not None:
                elems = config.build_elems(key, batch)
                if len(elems) > 0:
                    elems_by_key[key] = elems
                    keys_by_modality[config.modality].add(key)

        items = list[MultiModalKwargsItem]()
        for modality, keys in keys_by_modality.items():
            elems_in_modality = {k: elems_by_key[k] for k in keys}
            batch_sizes = {k: len(v) for k, v in elems_in_modality.items()}

            if len(set(batch_sizes.values())) > 1:
                raise ValueError(
                    f"Cannot merge different batch sizes for {modality=}! "
                    f"Found: {batch_sizes=}")

            batch_size = next(iter(batch_sizes.values()))
            for item_idx in range(batch_size):
                elems = [v[item_idx] for v in elems_in_modality.values()]
                items.append(MultiModalKwargsItem.from_elems(elems))

        return MultiModalKwargs.from_items(items)

    @staticmethod
    def from_items(items: Sequence[MultiModalKwargsItem]):
        """Construct a new :class:`MultiModalKwargs` from multiple items."""
        elems_by_key = defaultdict[str, list[MultiModalFieldElem]](list)
        for item in items:
            for key, elem in item.items():
                elems_by_key[key].append(elem)

        # Each key is reduced using the field of its first element;
        # `reduce` rejects mixing elements from different fields.
        data = {
            key: elems[0].field.reduce(elems).data
            for key, elems in elems_by_key.items() if len(elems) > 0
        }

        return MultiModalKwargs(data, items=items)

    def __init__(
        self,
        data: Mapping[str, NestedTensors],
        *,
        items: Optional[Sequence[MultiModalKwargsItem]] = None,
    ) -> None:
        """
        Args:
            data: The keyword arguments to expose through the mapping
                interface.
            items: Optional per-item metadata; it is grouped by modality to
                back :meth:`get_item`, :meth:`get_items` and
                :meth:`get_item_count`.
        """
        super().__init__(data)

        items_by_modality = full_groupby(items or [], key=lambda x: x.modality)
        self._items_by_modality = dict(items_by_modality)

    @property
    def modalities(self):
        """The modalities for which items were provided."""
        return self._items_by_modality.keys()

    @staticmethod
    def _try_stack(nested_tensors: NestedTensors) -> NestedTensors:
        """
        Stack the inner dimensions that have the same shape in
        a nested list of tensors.

        Thus, a dimension represented by a list means that the inner
        dimensions are different for each element along that dimension.
        """
        if isinstance(nested_tensors, torch.Tensor):
            return nested_tensors

        # TODO: Remove these once all models have been migrated
        if isinstance(nested_tensors, np.ndarray):
            return torch.from_numpy(nested_tensors)
        if isinstance(nested_tensors, (int, float)):
            return torch.tensor(nested_tensors)

        stacked = [MultiModalKwargs._try_stack(t) for t in nested_tensors]
        if not stacked:
            # `torch.stack` rejects an empty sequence; keep it as a list.
            return stacked

        if not is_list_of(stacked, torch.Tensor, check="all"):
            # Only tensors (not lists) can be stacked.
            return stacked

        tensors_ = cast(list[torch.Tensor], stacked)
        if any(t.shape != tensors_[0].shape for t in tensors_):
            # The tensors have incompatible shapes and can't be stacked.
            return tensors_

        return torch.stack(tensors_)

    @staticmethod
    def batch(inputs_list: list["MultiModalKwargs"]) -> BatchedTensorInputs:
        """
        Batch multiple inputs together into a dictionary.

        The resulting dictionary has the same keys as the inputs.
        If the corresponding value from each input is a tensor and they all
        share the same shape, the output value is a single batched tensor;
        otherwise, the output value is a list containing the original value
        from each input.
        """
        if len(inputs_list) == 0:
            return {}

        # We need to consider the case where each item in the batch
        # contains different modalities (i.e. different keys).
        item_lists = defaultdict[str, list[NestedTensors]](list)

        for inputs in inputs_list:
            for k, v in inputs.items():
                item_lists[k].append(v)

        return {
            k: MultiModalKwargs._try_stack(item_list)
            for k, item_list in item_lists.items()
        }

    @staticmethod
    def as_kwargs(
        batched_inputs: BatchedTensorInputs,
        *,
        device: torch.types.Device,
    ) -> BatchedTensorInputs:
        """Move every tensor leaf of ``batched_inputs`` to ``device``."""
        json_inputs = cast(JSONTree[torch.Tensor], batched_inputs)

        json_mapped = json_map_leaves(
            lambda x: x.to(device, non_blocking=True),
            json_inputs,
        )

        return cast(BatchedTensorInputs, json_mapped)

    def __eq__(self, other: object) -> bool:
        """Equal iff same class, same per-modality items and equal data."""
        if not isinstance(other, self.__class__):
            return False
        if self._items_by_modality != other._items_by_modality:
            return False

        ks = self.keys()
        return (ks == other.keys()
                and all(nested_tensors_equal(self[k], other[k]) for k in ks))

    def _validate_modality(self, method_name: str, modality: str) -> None:
        """
        Raises:
            RuntimeError: If this instance was constructed without ``items``.
            KeyError: If ``modality`` has no items.
        """
        if not self._items_by_modality:
            raise RuntimeError(
                f"`{method_name}` is not supported when "
                "MultiModalKwargs is not initialized with `items`")

        if modality not in self._items_by_modality:
            available_modalities = set(self._items_by_modality.keys())
            raise KeyError(f"Modality {modality!r} not found. "
                           f"Available modalities: {available_modalities}")

    def get_item_count(self, modality: str) -> int:
        """Get the number of items belonging to a modality."""
        self._validate_modality("get_item_count", modality)
        return len(self._items_by_modality[modality])

    def get_item(self, modality: str, item_index: int) -> MultiModalKwargsItem:
        """
        Get the keyword arguments corresponding to an item identified by
        its modality and index.
        """
        self._validate_modality("get_item", modality)
        return self._items_by_modality[modality][item_index]

    def get_items(self, modality: str) -> Sequence[MultiModalKwargsItem]:
        """
        Get the keyword arguments corresponding to each item belonging to
        a modality.
        """
        self._validate_modality("get_items", modality)
        return self._items_by_modality[modality]
486

487
488
489

MultiModalPlaceholderDict = Mapping[str, Sequence[PlaceholderRange]]
"""
A dictionary containing placeholder ranges for each modality.
"""


494
class MultiModalInputs(TypedDict):
    """
    Represents the outputs of
    :class:`vllm.multimodal.processing.BaseMultiModalProcessor`,
    ready to be passed to vLLM internals.
    """

    type: Literal["multimodal"]
    """The type of inputs."""

    prompt: str
    """The processed prompt text."""

    prompt_token_ids: list[int]
    """The processed token IDs which includes placeholder tokens."""

    token_type_ids: NotRequired[list[int]]
    """The token type IDs of the prompt."""

    mm_kwargs: MultiModalKwargs
    """Keyword arguments to be directly passed to the model after batching."""

    mm_hashes: NotRequired[Optional["MultiModalHashDict"]]
    """The hashes of the multi-modal data."""

    mm_placeholders: MultiModalPlaceholderDict
    """
    For each modality, information about the placeholder tokens in
    :code:`prompt_token_ids`.
    """