# base.py
1
from abc import ABC, abstractmethod
from collections import UserDict, defaultdict
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Mapping,
                    NamedTuple, Optional, Tuple, Type, TypedDict, TypeVar,
                    Union, cast, final)

import numpy as np
import torch
import torch.types
from PIL import Image
from torch import nn
from typing_extensions import TypeAlias

from vllm.inputs import InputContext
from vllm.logger import init_logger
from vllm.utils import (JSONTree, get_allowed_kwarg_only_overrides, is_list_of,
                        json_map_leaves, resolve_mm_processor_kwargs)
18

19
20
21
22
if TYPE_CHECKING:
    from vllm.config import ModelConfig
    from vllm.sequence import SequenceGroupMetadata

23
24
# Logger for this module, obtained from vLLM's ``init_logger`` helper.
logger = init_logger(__name__)

25
NestedTensors = Union[List["NestedTensors"], List[torch.Tensor], torch.Tensor]
"""
Uses a list instead of a tensor if the dimensions of each element do not match.
"""

30
BatchedTensorInputs: TypeAlias = Dict[str, NestedTensors]
"""
A dictionary containing nested tensors which have been batched via
:meth:`MultiModalInputs.batch`.
"""


37
38
class _MultiModalInputsBase(UserDict[str, NestedTensors]):
    """
    A :class:`~collections.UserDict` parameterised with ``str`` keys and
    ``NestedTensors`` values; serves as the base of ``MultiModalInputs``.
    """
    pass
39
40
41
42
43
44
45
46
47


class MultiModalInputs(_MultiModalInputsBase):
    """
    A dictionary that represents the keyword arguments to
    :meth:`~torch.nn.Module.forward`.
    """

    @staticmethod
    def _try_stack(nested_tensors: NestedTensors) -> NestedTensors:
        """
        Recursively stacks lists of tensors when they all have the same shape.

        NumPy arrays and Python ``int``/``float`` scalars are converted to
        tensors first. A list is returned unchanged (apart from its elements
        being converted) when it is empty, when any element is itself a list,
        or when the element shapes differ.
        """
        if isinstance(nested_tensors, torch.Tensor):
            return nested_tensors

        if isinstance(nested_tensors, np.ndarray):
            return torch.from_numpy(nested_tensors)

        if isinstance(nested_tensors, (int, float)):
            return torch.tensor(nested_tensors)

        stacked = [MultiModalInputs._try_stack(t) for t in nested_tensors]
        if not stacked:
            # Nothing to stack; torch.stack() rejects an empty sequence.
            return stacked

        if not is_list_of(stacked, torch.Tensor, check="all"):
            # Only tensors (not lists) can be stacked.
            return stacked

        tensors_ = cast(List[torch.Tensor], stacked)
        first_shape = tensors_[0].shape
        if any(t.shape != first_shape for t in tensors_):
            # The tensors have incompatible shapes and can't be stacked.
            return tensors_

        return torch.stack(tensors_)

    @staticmethod
    def batch(inputs_list: List["MultiModalInputs"]) -> BatchedTensorInputs:
        """
        Batch multiple inputs together into a dictionary.

        The resulting dictionary has the same keys as the inputs.
        If the corresponding value from each input is a tensor and they all
        share the same shape, the output value is a single batched tensor;
        otherwise, the output value is a list containing the original value
        from each input.

        An empty ``inputs_list`` yields an empty dictionary.
        """
        if not inputs_list:
            return {}

        item_lists: Dict[str, List[NestedTensors]] = defaultdict(list)

        for inputs in inputs_list:
            # For models that support multiple modalities (e.g. Qwen2-VL),
            # different modalities will return different data keys,
            # so batch() should skip the same key check.
            for k, v in inputs.items():
                item_lists[k].append(v)

        return {
            k: MultiModalInputs._try_stack(item_list)
            for k, item_list in item_lists.items()
        }

    @staticmethod
    def as_kwargs(
        batched_inputs: BatchedTensorInputs,
        *,
        device: torch.types.Device,
    ) -> BatchedTensorInputs:
        """
        Return ``batched_inputs`` with its tensors moved to ``device``.

        ``x.to(device, non_blocking=True)`` is applied to the tensors of
        ``batched_inputs`` through :func:`json_map_leaves`.
        """
        json_inputs = cast(JSONTree[torch.Tensor], batched_inputs)

        json_mapped = json_map_leaves(
            lambda x: x.to(device, non_blocking=True),
            json_inputs,
        )

        return cast(BatchedTensorInputs, json_mapped)
116
117


118
119
120
121
122
123
124
125
126
127
128
129
# Type variable standing for the type of a single data instance.
_T = TypeVar("_T")

MultiModalData: TypeAlias = Union[_T, List[_T]]
"""
Either a single data instance, or a list of data instances.

The number of data instances allowed per modality is restricted by
`--limit-mm-per-prompt`.
"""


@final
class MultiModalDataBuiltins(TypedDict, total=False):
    """Modality types that are predefined by vLLM."""

    image: MultiModalData[Image.Image]
    """The input image(s)."""

    audio: MultiModalData[Tuple[np.ndarray, Union[int, float]]]
    """The input audio item(s) and corresponding sampling rate(s)."""

    video: MultiModalData[Tuple[np.ndarray]]
    """The input video(s)."""

142

143
144
MultiModalDataDict = Union[MultiModalDataBuiltins,
                           Mapping[str, MultiModalData[object]]]
"""
A dictionary containing an item for each modality type to input.

Note:
    This dictionary also accepts modality keys defined outside
    :class:`MultiModalDataBuiltins` as long as a customized plugin is registered
    through the :class:`~vllm.multimodal.MULTIMODAL_REGISTRY`.
    Read more on that :ref:`here <adding_multimodal_plugin>`.
"""
154

155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178

class PlaceholderRange(TypedDict):
    """
    Placeholder location information for multi-modal data.

    A placeholder covers the half-open index range
    ``[offset, offset + length)`` of the prompt.

    For example:
        Prompt: AAAA BBBB What is in these images?
        Images A and B will have:
            A: { "offset": 0, "length": 4 }
            B: { "offset": 5, "length": 4 }
    """

    offset: int
    """The start index of the placeholder in the prompt."""

    length: int
    """The length of the placeholder."""


MultiModalPlaceholderDict = Mapping[str, List[PlaceholderRange]]
"""
A dictionary containing placeholder ranges.
"""

MultiModalInputMapper = Callable[[InputContext, MultiModalData[object]],
                                 MultiModalInputs]
"""
Return a dictionary to be passed as keyword arguments to
:meth:`~torch.nn.Module.forward`. This is similar in concept to tokenizers
and processors in HuggingFace Transformers.

If the data is not supported, throw :exc:`TypeError`.
"""

MultiModalTokensCalc = Union[int, Callable[[InputContext], int]]
"""
Calculate the maximum number of multimodal tokens input to the language
model. This does not include tokens that correspond to the input text.
"""

# Type variable for a model class, used by the registration decorators so
# that they return the same class type they receive.
N = TypeVar("N", bound=Type[nn.Module])

197

198
class MultiModalPlugin(ABC):
    """
    Base class that defines data processing logic for a specific modality.

    In particular, we adopt a registry pattern to dispatch data processing
    according to the model being used (considering that different models may
    process the same data differently). This registry is in turn used by
    :class:`~MultiModalRegistry` which acts at a higher level
    (i.e., the modality of the data).

    See also:
        :ref:`adding_multimodal_plugin`
    """

    def __init__(self) -> None:
        # Per-model-class registries filled in by the ``register_*``
        # decorators below.
        self._input_mappers: Dict[Type[nn.Module], MultiModalInputMapper] = {}
        self._max_mm_tokens: Dict[Type[nn.Module], MultiModalTokensCalc] = {}

    @abstractmethod
    def get_data_key(self) -> str:
        """
        Get the data key corresponding to the modality.
        """
        raise NotImplementedError

    @abstractmethod
    def _default_input_mapper(
        self,
        ctx: InputContext,
        data: MultiModalData[object],
        **mm_processor_kwargs,
    ) -> MultiModalInputs:
        """
        Return a dictionary to be passed as keyword arguments to
        :meth:`~torch.nn.Module.forward`. This is similar in concept to
        tokenizers and processors in HuggingFace Transformers.

        If the data is not supported, throw :exc:`TypeError`.
        """
        raise NotImplementedError

    def register_input_mapper(
        self,
        mapper: Optional[MultiModalInputMapper] = None,
    ) -> Callable[[N], N]:
        """
        Register an input mapper to a model class.

        When the model receives input data that matches the modality served by
        this plugin (see :meth:`get_data_key`), the provided function is
        invoked to transform the data into a dictionary of model inputs.

        If `None` is provided, then the default input mapper is used instead.

        Returns:
            A class decorator that records the mapper for the decorated model
            class and returns that class unchanged. Registering a class that
            already has a mapper logs a warning and overwrites the old entry.

        See also:
            - :ref:`input_processing_pipeline`
            - :ref:`enabling_multimodal_inputs`
        """

        def wrapper(model_cls: N) -> N:
            if model_cls in self._input_mappers:
                logger.warning(
                    "Model class %s already has an input mapper "
                    "registered to %s. It is overwritten by the new one.",
                    model_cls,
                    self,
                )

            # Only ``None`` selects the default, as documented above.
            self._input_mappers[model_cls] = (self._default_input_mapper
                                              if mapper is None else mapper)

            return model_cls

        return wrapper

    def map_input(
        self,
        model_config: "ModelConfig",
        data: MultiModalData[object],
        mm_processor_kwargs: Dict[str, Any],
    ) -> MultiModalInputs:
        """
        Transform the data into a dictionary of model inputs using the
        input mapper registered for that model.

        The model is identified by ``model_config``.

        Raises:
            KeyError: If no input mapper is registered for the model class.
            TypeError: If the data type is not supported.

        See also:
            - :ref:`input_processing_pipeline`
            - :ref:`enabling_multimodal_inputs`
        """
        # Avoid circular import
        from vllm.model_executor.model_loader import get_model_architecture

        model_cls, _ = get_model_architecture(model_config)

        mapper = self._input_mappers.get(model_cls)

        if mapper is None:
            raise KeyError(f"No input mapper in {self} is registered for "
                           f"model class {model_cls.__name__}.")

        # In the case of the default mapper, we have to get resource
        # processor through its HuggingFace autoclass; since this goes
        # through **kwargs, we can't inspect it the same way, so we allow
        # drop mm_processor_kwargs based on signature inspection
        # if we're using the default mapper.
        #
        # This should be safe in general due to the sanitation, since the
        # transformers resource should filter unused kwargs anyway.
        uses_default_mapper = mapper == self._default_input_mapper
        mm_processor_kwargs = resolve_mm_processor_kwargs(
            model_config.mm_processor_kwargs,
            mm_processor_kwargs,
            callable=mapper,
            allow_var_kwargs=uses_default_mapper,
        )
        return mapper(InputContext(model_config), data, **mm_processor_kwargs)

    @abstractmethod
    def _default_max_multimodal_tokens(self, ctx: InputContext) -> int:
        """
        Calculate the maximum number of tokens, corresponding to a single
        instance of multimodal data, that are passed to the language model.
        """
        raise NotImplementedError

    def _validate_max_multimodal_tokens(self, max_mm_tokens: int) -> None:
        """
        Ensure that ``max_mm_tokens`` is at least 1.

        Raises:
            ValueError: If ``max_mm_tokens`` is smaller than 1.
        """
        if max_mm_tokens < 1:
            raise ValueError("You should set the number of tokens to a "
                             f"positive integer. Found: {max_mm_tokens}")

    def register_max_multimodal_tokens(
        self,
        max_mm_tokens: Optional[MultiModalTokensCalc] = None,
    ) -> Callable[[N], N]:
        """
        Register the maximum number of tokens, corresponding to a single
        instance of multimodal data, that are passed to the language model
        for a model class.

        If `None` is provided, then the default calculation is used instead.

        Returns:
            A class decorator that records the value (or callable) for the
            decorated model class and returns that class unchanged. An integer
            value is validated when the decorator is applied, so a
            non-positive integer raises :exc:`ValueError` at that point.

        See also:
            :ref:`enabling_multimodal_inputs`
        """

        def wrapper(model_cls: N) -> N:
            if model_cls in self._max_mm_tokens:
                logger.warning(
                    "Model class %s already calculates maximum number of "
                    "tokens in %s. It is overwritten by the new one.",
                    model_cls,
                    self,
                )

            if isinstance(max_mm_tokens, int):
                self._validate_max_multimodal_tokens(max_mm_tokens)

            # Only ``None`` selects the default, as documented above.
            self._max_mm_tokens[model_cls] = (
                self._default_max_multimodal_tokens
                if max_mm_tokens is None else max_mm_tokens)

            return model_cls

        return wrapper

    def get_max_multimodal_tokens(self, model_config: "ModelConfig") -> int:
        """
        Get the maximum number of multi-modal tokens
        for profiling the memory usage of a model.

        If this registry is not applicable to the model, `0` is returned.

        The model is identified by ``model_config``.

        Raises:
            KeyError: If the model class has an input mapper but no maximum
                number of multi-modal tokens registered.
            ValueError: If the resolved number of tokens is smaller than 1.

        See also:
            :ref:`enabling_multimodal_inputs`
        """
        # Avoid circular import
        from vllm.model_executor.model_loader import get_model_architecture

        model_cls, _ = get_model_architecture(model_config)

        if model_cls not in self._input_mappers:
            return 0

        max_mm_tokens = self._max_mm_tokens.get(model_cls)
        if max_mm_tokens is None:
            raise KeyError(f"No maximum number of multi-modal tokens is given "
                           f"for model class {model_cls.__name__} in {self}.")

        if callable(max_mm_tokens):
            mm_processor_kwargs = get_allowed_kwarg_only_overrides(
                max_mm_tokens, overrides=model_config.mm_processor_kwargs)
            max_mm_tokens = max_mm_tokens(InputContext(model_config),
                                          **mm_processor_kwargs)

        self._validate_max_multimodal_tokens(max_mm_tokens)

        return max_mm_tokens
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488


class MultiModalPlaceholderMap:
    """
    Relates multi-modal embeddings to their corresponding placeholders.
    """

    class IndexMap(NamedTuple):
        src: List[int]
        dest: List[int]

    src_ranges: List[range]
    """
    The indices of the multi-modal embeddings that will replace the
    corresponding placeholder embeddings pointed to by ``dest_ranges``.
    """

    src_len: int
    """
    The total number of flattened multi-modal embeddings.
    """

    dest_ranges: List[range]
    """
    The indices of the placeholder embeddings that will be replaced by the
    multimodal embeddings.
    """

    dest_len: int
    """
    The total number of embeddings in the destination tensor.
    """

    def __init__(self):
        self.src_ranges = []
        self.src_len = 0
        self.dest_ranges = []
        self.dest_len = 0

    @classmethod
    def from_seq_group(
        cls, seq_group: "SequenceGroupMetadata", positions: range
    ) -> Tuple[Optional["MultiModalDataDict"],
               Dict[str, "MultiModalPlaceholderMap"]]:
        """
        Returns the multi-modal items that intersect with the portion of a
        prompt (``seq_group``) represented by ``positions``, as well as a
        ``MultiModalPlaceholderMap`` that relates the multi-modal embedding
        vectors to their corresponding placeholders.

        If ``seq_group`` has no multi-modal data or no placeholders, its
        multi-modal data is returned as-is together with an empty mapping.
        Modalities that have placeholders but no item intersecting
        ``positions`` are left out of the returned data.

        Consider the following scenarios:

           Prompt: |AAAA BBBB What's in these images?|
        Positions: |.................................|

            images      = [A, B]
            src_ranges  = [(0, 4), (4, 8)]
            dest_ranges = [(0, 4), (5, 9)]

           Prompt: |AAAA BBBB What's in these images?|
        Positions: |  .....                          |

            images      = [A, B]
            src_ranges  = [(2, 4), (4, 6)]
            dest_ranges = [(0, 2), (3, 5)]

           Prompt: |AAAA BBBB What's in these images?|
        Positions: |     .........                   |

            images      = [B]
            src_ranges  = [(0, 4)]
            dest_ranges = [(0, 4)]

           Prompt: |AAAA BBBB What's in these images?|
        Positions: |          .......................|

            images      = []
            src_ranges  = []
            dest_ranges = []
        """
        if (not seq_group.multi_modal_data
                or not seq_group.multi_modal_placeholders):
            return seq_group.multi_modal_data, {}

        # Shallow copy so that the sequence group's own mapping is not mutated.
        mm_data = {**seq_group.multi_modal_data}
        placeholder_maps: Dict[str, MultiModalPlaceholderMap] = defaultdict(
            MultiModalPlaceholderMap)

        for (
                modality,
                placeholders,
        ) in seq_group.multi_modal_placeholders.items():
            # Removed here and only re-added below when at least one item
            # intersects ``positions``.
            mm_items = mm_data.pop(modality)
            if not isinstance(mm_items, list):
                mm_items = [mm_items]

            if positions:
                intersecting_items = placeholder_maps[
                    modality].append_items_from_seq_group(
                        positions, mm_items, placeholders)

                if intersecting_items:
                    mm_data[modality] = intersecting_items

        return mm_data, placeholder_maps

    def append_items_from_seq_group(
        self,
        positions: range,
        multi_modal_items: List["_T"],
        multi_modal_placeholders: List["PlaceholderRange"],
    ) -> List["_T"]:
        """
        Adds the multi-modal items that intersect ``positions`` to this
        placeholder map and returns the intersecting items.

        Raises:
            ValueError: If the numbers of items and placeholders differ.
        """
        intersecting_items = []

        if len(multi_modal_items) != len(multi_modal_placeholders):
            raise ValueError(
                "Multi-modal placeholders and items must have the same length."
            )
        for placeholder_dict, mm_item in zip(multi_modal_placeholders,
                                             multi_modal_items):
            placeholder = range(
                placeholder_dict["offset"],
                placeholder_dict["offset"] + placeholder_dict["length"],
            )
            intersection = range(
                max(positions.start, placeholder.start),
                min(positions.stop, placeholder.stop),
            )

            if not intersection:
                # Skip this multi-modal item.
                continue

            # Destination indices are relative to the start of ``positions``.
            token_embedding_range = range(
                intersection.start - positions.start,
                intersection.stop - positions.start,
            )

            # Source indices are relative to the start of this item's
            # placeholder, shifted by the embeddings already accounted for.
            multimodal_embedding_range = range(
                intersection.start - placeholder.start + self.src_len,
                intersection.stop - placeholder.start + self.src_len,
            )

            intersecting_items.append(mm_item)
            self.dest_ranges.append(token_embedding_range)
            self.src_ranges.append(multimodal_embedding_range)
            # The full placeholder length is added even when the item only
            # partially intersects ``positions``.
            self.src_len += len(placeholder)

        self.dest_len += len(positions)
        return intersecting_items

    def extend(self, other: "MultiModalPlaceholderMap"):
        """
        Adds the placeholders from another ``MultiModalPlaceholderMap`` to this
        instance based on the source and destination tensors being
        concatenated.
        """

        self.src_ranges.extend(
            range(self.src_len + r.start, self.src_len + r.stop)
            for r in other.src_ranges)
        self.src_len += other.src_len
        self.dest_ranges.extend(
            range(self.dest_len + r.start, self.dest_len + r.stop)
            for r in other.dest_ranges)
        self.dest_len += other.dest_len

    def index_map(self) -> "MultiModalPlaceholderMap.IndexMap":
        """
        Finalizes the placeholder map into lists of indices that can be used to
        index the source and destination tensors.

        Raises:
            ValueError: If the numbers of source and destination indices
                differ.
        """

        src_indices = [i for r in self.src_ranges for i in r]
        dest_indices = [i for r in self.dest_ranges for i in r]

        if len(src_indices) != len(dest_indices):
            raise ValueError(
                f"The number of source ({len(src_indices)}) and destination "
                f"indices ({len(dest_indices)}) must be the same.")

        return MultiModalPlaceholderMap.IndexMap(src=src_indices,
                                                 dest=dest_indices)