base.py 19 KB
Newer Older
1
import sys
2
from abc import ABC, abstractmethod
3
from collections import UserDict, defaultdict
4
5
6
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Mapping,
                    NamedTuple, Optional, Tuple, Type, TypedDict, TypeVar,
                    Union, cast, final)
7

8
import numpy as np
9
10
11
12
import torch
import torch.types
from PIL import Image
from torch import nn
13
from typing_extensions import TypeAlias
14

15
from vllm.inputs import InputContext
16
from vllm.logger import init_logger
17
from vllm.utils import (JSONTree, get_allowed_kwarg_only_overrides, is_list_of,
18
                        json_map_leaves, resolve_mm_processor_kwargs)
19

20
21
22
23
if TYPE_CHECKING:
    from vllm.config import ModelConfig
    from vllm.sequence import SequenceGroupMetadata

24
25
logger = init_logger(__name__)

26
NestedTensors = Union[List["NestedTensors"], List[torch.Tensor], torch.Tensor]
27
"""
28
Uses a list instead of a tensor if the dimensions of each element do not match.
29
30
"""

31
BatchedTensorInputs: TypeAlias = Dict[str, NestedTensors]
32
33
34
"""
A dictionary containing nested tensors which have been batched via
:meth:`MultiModalInputs.batch`.
35
36
37
38
39
40
41
42
"""

if sys.version_info < (3, 9):
    # UserDict cannot be subscripted
    class _MultiModalInputsBase(UserDict):
        pass
else:

43
    class _MultiModalInputsBase(UserDict[str, NestedTensors]):
44
45
46
47
48
49
50
51
52
53
        pass


class MultiModalInputs(_MultiModalInputsBase):
    """
    A dictionary that represents the keyword arguments to
    :meth:`~torch.nn.Module.forward`.
    """

    @staticmethod
    def _try_stack(nested_tensors: NestedTensors) -> NestedTensors:
        """
        Recursively stacks lists of tensors when they all have the same shape.

        Leaves are first normalized to tensors: NumPy arrays are wrapped with
        :func:`torch.from_numpy` and Python ``int``/``float`` scalars with
        :func:`torch.tensor`. A list is then stacked into one tensor only if,
        after recursion, every element is a tensor and all of them share the
        same shape; otherwise the list of processed elements is returned.
        An empty list is returned unchanged, because :func:`torch.stack`
        cannot stack zero tensors.
        """
        if isinstance(nested_tensors, torch.Tensor):
            return nested_tensors

        if isinstance(nested_tensors, np.ndarray):
            return torch.from_numpy(nested_tensors)

        if isinstance(nested_tensors, (int, float)):
            return torch.tensor(nested_tensors)

        stacked = [MultiModalInputs._try_stack(t) for t in nested_tensors]

        if not stacked:
            # torch.stack() raises on an empty sequence; keep the empty list.
            return stacked

        if not is_list_of(stacked, torch.Tensor, check="all"):
            # Only tensors (not lists) can be stacked.
            return stacked

        tensors_ = cast(List[torch.Tensor], stacked)
        if any(t.shape != tensors_[0].shape for t in tensors_):
            # The tensors have incompatible shapes and can't be stacked.
            return tensors_

        return torch.stack(tensors_)

    @staticmethod
    def batch(inputs_list: List["MultiModalInputs"]) -> BatchedTensorInputs:
        """
        Batch multiple inputs together into a dictionary.

        The resulting dictionary contains every key that appears in any of
        the inputs. For each key, the values are collected in input order and
        passed through :meth:`_try_stack`: if they are all tensors sharing the
        same shape, the output value is a single batched tensor; otherwise,
        the output value is a list of the (recursively processed) values.

        Inputs are not required to share the same keys. For example,
        models that support multiple modalities (such as Qwen2-VL) return
        different keys per modality. A key that is missing from some inputs
        therefore collects fewer values than there are inputs.

        Returns an empty dictionary when ``inputs_list`` is empty.
        """
        if not inputs_list:
            return {}

        # For models that supports multiple modalities (e.g. Qwen2-VL),
        # different modalities will return different data keys,
        # so batch() should skip the same key check.
        item_lists: Dict[str, List[NestedTensors]] = defaultdict(list)
        for inputs in inputs_list:
            for k, v in inputs.items():
                item_lists[k].append(v)

        return {
            k: MultiModalInputs._try_stack(item_list)
            for k, item_list in item_lists.items()
        }

    @staticmethod
    def as_kwargs(
        batched_inputs: BatchedTensorInputs,
        *,
        device: torch.types.Device,
    ) -> BatchedTensorInputs:
        """
        Move every tensor leaf of ``batched_inputs`` to ``device``.

        Each leaf is transferred with ``x.to(device, non_blocking=True)``
        through :func:`~vllm.utils.json_map_leaves`. That helper presumably
        walks nested dicts/lists/tuples and preserves their structure
        (NOTE(review): confirm traversal rules in ``vllm.utils``). The result
        is intended to be passed as keyword arguments to the model's
        :meth:`~torch.nn.Module.forward`.
        """
        json_inputs = cast(JSONTree[torch.Tensor], batched_inputs)

        json_mapped = json_map_leaves(
            lambda x: x.to(device, non_blocking=True),
            json_inputs,
        )

        return cast(BatchedTensorInputs, json_mapped)
122
123


124
125
126
127
128
129
130
131
132
133
134
135
_T = TypeVar("_T")

MultiModalData: TypeAlias = Union[_T, List[_T]]
"""
Either a single data instance, or a list of data instances.

The number of data instances allowed per modality is restricted by
`--limit-mm-per-prompt`.
"""


@final
class MultiModalDataBuiltins(TypedDict, total=False):
    """Modality types that are predefined by vLLM."""

    image: MultiModalData[Image.Image]
    """The input image(s)."""

    audio: MultiModalData[Tuple[np.ndarray, Union[int, float]]]
    """The input audio item(s) and corresponding sampling rate(s)."""


MultiModalDataDict = Union[MultiModalDataBuiltins,
                           Mapping[str, MultiModalData[object]]]
"""
A dictionary containing an item (or a list of items) for each modality type
to input.

Note:
    This dictionary also accepts modality keys defined outside
    :class:`MultiModalDataBuiltins` as long as a customized plugin is registered
    through the :class:`~vllm.multimodal.MULTIMODAL_REGISTRY`.
    Read more on that :ref:`here <adding_multimodal_plugin>`.
"""
157

158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181

class PlaceholderRange(TypedDict):
    """
    Where the placeholder tokens of one multi-modal item sit in a prompt.

    Example:
        Prompt ``AAAA BBBB What is in these images?``, counting every
        character (including the space) as one token. The placeholders of
        images A and B are then:

            A: { "offset": 0, "length": 4 }
            B: { "offset": 5, "length": 4 }
    """

    #: Index of the first placeholder token in the prompt.
    offset: int

    #: Number of consecutive placeholder tokens.
    length: int


#: Placeholder ranges grouped by modality name, one range per item.
MultiModalPlaceholderDict = Mapping[str, List[PlaceholderRange]]

182
183
MultiModalInputMapper = Callable[[InputContext, MultiModalData[object]],
                                 MultiModalInputs]
"""
Return a dictionary to be passed as keyword arguments to
:meth:`~torch.nn.Module.forward`. This is similar in concept to tokenizers
and processors in HuggingFace Transformers.

The mapper receives the model's :class:`InputContext` and the data of a
single modality. :meth:`MultiModalPlugin.map_input` may additionally pass
``mm_processor_kwargs`` overrides to it as keyword arguments.

If the data is not supported, throw :exc:`TypeError`.
"""

MultiModalTokensCalc = Union[int, Callable[[InputContext], int]]
"""
Calculate the maximum number of multimodal tokens input to the language
model. This does not include tokens that correspond to the input text.

Either a fixed positive count, or a function of the model's
:class:`InputContext` that computes it.
:meth:`MultiModalPlugin.get_max_multimodal_tokens` may additionally pass
``mm_processor_kwargs`` overrides to that function as keyword arguments.
"""

# Type variable for model classes: the ``register_*`` decorators of
# MultiModalPlugin return the decorated class unchanged.
N = TypeVar("N", bound=Type[nn.Module])

200

201
class MultiModalPlugin(ABC):
    """
    Base class that defines data processing logic for a specific modality.

    In particular, we adopt a registry pattern to dispatch data processing
    according to the model being used (considering that different models may
    process the same data differently). This registry is in turn used by
    :class:`~MultiModalRegistry` which acts at a higher level
    (i.e., the modality of the data).

    See also:
        :ref:`adding_multimodal_plugin`
    """

    def __init__(self) -> None:
        # Input mapper per model class; filled by register_input_mapper().
        self._input_mappers: Dict[Type[nn.Module], MultiModalInputMapper] = {}
        # Fixed token count or calculator per model class; filled by
        # register_max_multimodal_tokens().
        self._max_mm_tokens: Dict[Type[nn.Module], MultiModalTokensCalc] = {}

    @abstractmethod
    def get_data_key(self) -> str:
        """
        Get the data key corresponding to the modality.
        """
        raise NotImplementedError

    @abstractmethod
    def _default_input_mapper(
        self,
        ctx: InputContext,
        data: MultiModalData[object],
        **mm_processor_kwargs,
    ) -> MultiModalInputs:
        """
        Return a dictionary to be passed as keyword arguments to
        :meth:`~torch.nn.Module.forward`. This is similar in concept to
        tokenizers and processors in HuggingFace Transformers.

        If the data is not supported, throw :exc:`TypeError`.
        """
        raise NotImplementedError

    def register_input_mapper(
        self,
        mapper: Optional[MultiModalInputMapper] = None,
    ) -> Callable[[N], N]:
        """
        Register an input mapper to a model class.

        When the model receives input data that matches the modality served by
        this plugin (see :meth:`get_data_key`), the provided function is
        invoked to transform the data into a dictionary of model inputs.

        If `None` is provided, then the default input mapper is used instead.

        Returns:
            A class decorator that records the mapper and returns the class
            unchanged. Registering a class twice logs a warning and replaces
            the previous mapper.

        See also:
            - :ref:`input_processing_pipeline`
            - :ref:`enabling_multimodal_inputs`
        """

        def wrapper(model_cls: N) -> N:
            if model_cls in self._input_mappers:
                logger.warning(
                    "Model class %s already has an input mapper "
                    "registered to %s. It is overwritten by the new one.",
                    model_cls, self)

            self._input_mappers[model_cls] = (
                mapper if mapper is not None else self._default_input_mapper)

            return model_cls

        return wrapper

    def map_input(self, model_config: "ModelConfig",
                  data: MultiModalData[object],
                  mm_processor_kwargs: Dict[str, Any]) -> MultiModalInputs:
        """
        Transform the data into a dictionary of model inputs using the
        input mapper registered for that model.

        The model is identified by ``model_config``.

        Raises:
            KeyError: If no input mapper is registered for the model class.
            TypeError: If the data type is not supported.

        See also:
            - :ref:`input_processing_pipeline`
            - :ref:`enabling_multimodal_inputs`
        """
        # Avoid circular import
        from vllm.model_executor.model_loader import get_model_architecture

        model_cls, _ = get_model_architecture(model_config)

        mapper = self._input_mappers.get(model_cls)
        if mapper is None:
            raise KeyError(f"No input mapper in {self} is registered for "
                           f"model class {model_cls.__name__}.")

        # Processor overrides are normally filtered by inspecting the
        # mapper's signature. The default mapper instead forwards its
        # **kwargs to a resource processor obtained through a HuggingFace
        # autoclass, so its signature can't be inspected that way. For it we
        # therefore allow var-kwargs rather than dropping the overrides.
        #
        # This should be safe in general due to the sanitation, since the
        # transformers resource should filter unused kwargs anyway.
        #
        # `==` (not `is`): each attribute access creates a new bound-method
        # object, and bound methods compare equal when they wrap the same
        # function on the same instance.
        uses_default_mapper = mapper == self._default_input_mapper
        mm_processor_kwargs = resolve_mm_processor_kwargs(
            model_config.mm_processor_kwargs,
            mm_processor_kwargs,
            callable=mapper,
            allow_var_kwargs=uses_default_mapper,
        )
        return mapper(InputContext(model_config), data, **mm_processor_kwargs)

    @abstractmethod
    def _default_max_multimodal_tokens(self, ctx: InputContext) -> int:
        """
        Calculate the maximum number of tokens, corresponding to a single
        instance of multimodal data, that are passed to the language model.
        """
        raise NotImplementedError

    def _validate_max_multimodal_tokens(self, max_mm_tokens: int) -> None:
        """Raise :exc:`ValueError` unless ``max_mm_tokens`` is at least 1."""
        if max_mm_tokens < 1:
            raise ValueError("You should set the number of tokens to a "
                             f"positive integer. Found: {max_mm_tokens}")

    def register_max_multimodal_tokens(
        self,
        max_mm_tokens: Optional[MultiModalTokensCalc] = None,
    ) -> Callable[[N], N]:
        """
        Register the maximum number of tokens, corresponding to a single
        instance of multimodal data, that are passed to the language model
        for a model class.

        If `None` is provided, then the default calculation is used instead.

        Returns:
            A class decorator that records the value and returns the class
            unchanged. A fixed integer is validated when the decorator is
            applied; :exc:`ValueError` is raised if it is less than 1.

        See also:
            :ref:`enabling_multimodal_inputs`
        """

        def wrapper(model_cls: N) -> N:
            if model_cls in self._max_mm_tokens:
                logger.warning(
                    "Model class %s already calculates maximum number of "
                    "tokens in %s. It is overwritten by the new one.",
                    model_cls, self)

            # Fixed counts are checked here. Callables can only be checked
            # once evaluated, in get_max_multimodal_tokens().
            if isinstance(max_mm_tokens, int):
                self._validate_max_multimodal_tokens(max_mm_tokens)

            self._max_mm_tokens[model_cls] = (
                max_mm_tokens if max_mm_tokens is not None else
                self._default_max_multimodal_tokens)

            return model_cls

        return wrapper

    def get_max_multimodal_tokens(self, model_config: "ModelConfig") -> int:
        """
        Get the maximum number of multi-modal tokens
        for profiling the memory usage of a model.

        If this registry is not applicable to the model, `0` is returned.

        The model is identified by ``model_config``.

        Raises:
            KeyError: If the model class has an input mapper registered but
                no maximum number of tokens.
            ValueError: If the resulting number of tokens is less than 1.

        See also:
            :ref:`enabling_multimodal_inputs`
        """
        # Avoid circular import
        from vllm.model_executor.model_loader import get_model_architecture

        model_cls, _ = get_model_architecture(model_config)

        # A model without an input mapper for this modality does not take
        # this modality, so it needs no token budget for it.
        if model_cls not in self._input_mappers:
            return 0

        max_mm_tokens = self._max_mm_tokens.get(model_cls)
        if max_mm_tokens is None:
            raise KeyError(f"No maximum number of multi-modal tokens is given "
                           f"for model class {model_cls.__name__} in {self}.")

        if callable(max_mm_tokens):
            # Presumably keeps only the overrides that the calculator accepts
            # as keyword-only arguments (see vllm.utils; confirm).
            mm_processor_kwargs = get_allowed_kwarg_only_overrides(
                max_mm_tokens, overrides=model_config.mm_processor_kwargs)
            max_mm_tokens = max_mm_tokens(InputContext(model_config),
                                          **mm_processor_kwargs)

        self._validate_max_multimodal_tokens(max_mm_tokens)

        return max_mm_tokens
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572


class MultiModalPlaceholderMap:
    """
    Relates multi-modal embeddings to their corresponding placeholders.
    """

    class IndexMap(NamedTuple):
        # Flattened indices into the multi-modal (source) embeddings.
        src: List[int]
        # Flattened indices into the destination embeddings; ``src[i]``
        # replaces ``dest[i]``.
        dest: List[int]

    src_ranges: List[range]
    """
    The indices of the multi-modal embeddings that will replace the
    corresponding placeholder embeddings pointed to by ``dest_ranges``.
    """

    src_len: int
    """
    The total number of flattened multi-modal embeddings.
    """

    dest_ranges: List[range]
    """
    The indices of the placeholder embeddings that will be replaced by the
    multimodal embeddings.
    """

    dest_len: int
    """
    The total number of embeddings in the destination tensor.
    """

    def __init__(self) -> None:
        self.src_ranges = []
        self.src_len = 0
        self.dest_ranges = []
        self.dest_len = 0

    @classmethod
    def from_seq_group(
        cls, seq_group: "SequenceGroupMetadata", positions: range
    ) -> Tuple[Optional[MultiModalDataDict], Dict[str,
                                                  "MultiModalPlaceholderMap"]]:
        """
        Returns the multi-modal items that intersect with the portion of a
        prompt (``seq_group``) represented by ``positions``, as well as a
        ``MultiModalPlaceholderMap`` that relates the multi-modal embedding
        vectors to their corresponding placeholders.

        Consider the following scenarios:

           Prompt: |AAAA BBBB What's in these images?|
        Positions: |.................................|

            images      = [A, B]
            src_ranges  = [(0, 4), (4, 8)]
            dest_ranges = [(0, 4), (5, 9)]

           Prompt: |AAAA BBBB What's in these images?|
        Positions: |  .....                          |

            images      = [A, B]
            src_ranges  = [(2, 4), (4, 6)]
            dest_ranges = [(0, 2), (3, 5)]

           Prompt: |AAAA BBBB What's in these images?|
        Positions: |     .........                   |

            images      = [B]
            src_ranges  = [(0, 4)]
            dest_ranges = [(0, 4)]

           Prompt: |AAAA BBBB What's in these images?|
        Positions: |          .......................|

            images      = []
            src_ranges  = []
            dest_ranges = []

        If the sequence group has no multi-modal data or no placeholders, its
        data is returned as-is together with an empty dict. Otherwise, a
        modality that has placeholders but no item intersecting
        ``positions`` is left out of the returned data. Modalities without
        placeholders are passed through unchanged.

        Raises:
            KeyError: If a modality has placeholders but no data.
            ValueError: If a modality's item and placeholder counts differ.
        """
        if (not seq_group.multi_modal_data
                or not seq_group.multi_modal_placeholders):
            return seq_group.multi_modal_data, {}

        # Shallow copy, so that popping/replacing modalities below does not
        # modify the sequence group's own data.
        mm_data = {**seq_group.multi_modal_data}
        placeholder_maps: Dict[str, MultiModalPlaceholderMap] = defaultdict(
            MultiModalPlaceholderMap)

        placeholders_by_modality = seq_group.multi_modal_placeholders
        for modality, placeholders in placeholders_by_modality.items():
            mm_items = mm_data.pop(modality)
            if not isinstance(mm_items, list):
                mm_items = [mm_items]

            if positions:
                intersecting_items = placeholder_maps[
                    modality].append_items_from_seq_group(
                        positions, mm_items, placeholders)

                if intersecting_items:
                    mm_data[modality] = intersecting_items

        return mm_data, placeholder_maps

    def append_items_from_seq_group(
            self, positions: range, multi_modal_items: List[_T],
            multi_modal_placeholders: List[PlaceholderRange]) -> List[_T]:
        """
        Adds the multi-modal items that intersect ``positions`` to this
        placeholder map and returns the intersecting items.

        Args:
            positions: The prompt positions being processed.
            multi_modal_items: The items of one modality, one per placeholder
                and in the same order as ``multi_modal_placeholders``.
            multi_modal_placeholders: The placeholder ranges of those items.

        Returns:
            The items whose placeholder overlaps ``positions``, in order.

        Raises:
            ValueError: If the numbers of items and placeholders differ.
        """
        intersecting_items = []

        if len(multi_modal_items) != len(multi_modal_placeholders):
            raise ValueError(
                "Multi-modal placeholders and items must have the same length."
            )
        for placeholder_dict, mm_item in zip(multi_modal_placeholders,
                                             multi_modal_items):
            placeholder = range(
                placeholder_dict["offset"],
                placeholder_dict["offset"] + placeholder_dict["length"])
            intersection = range(max(positions.start, placeholder.start),
                                 min(positions.stop, placeholder.stop))

            if not intersection:
                # Skip this multi-modal item.
                continue

            # Offset of the overlap relative to the start of ``positions``.
            token_embedding_range = range(intersection.start - positions.start,
                                          intersection.stop - positions.start)

            # Offset of the overlap within this item's embeddings, shifted by
            # the embeddings of the items already appended.
            multimodal_embedding_range = range(
                intersection.start - placeholder.start + self.src_len,
                intersection.stop - placeholder.start + self.src_len)

            intersecting_items.append(mm_item)
            self.dest_ranges.append(token_embedding_range)
            self.src_ranges.append(multimodal_embedding_range)
            # The whole placeholder length is counted even for a partial
            # overlap. This presumes the source holds every embedding of each
            # intersecting item.
            self.src_len += len(placeholder)

        self.dest_len += len(positions)
        return intersecting_items

    def extend(self, other: "MultiModalPlaceholderMap") -> None:
        """
        Adds the placeholders from another ``MultiModalPlaceholderMap`` to this
        instance based on the source and destination tensors being
        concatenated.

        ``other`` may be this instance itself.
        """
        # Build the shifted ranges before mutating anything. Extending a
        # list from a generator over that same list (``m.extend(m)``) would
        # otherwise never terminate.
        shifted_src = [
            range(self.src_len + r.start, self.src_len + r.stop)
            for r in other.src_ranges
        ]
        shifted_dest = [
            range(self.dest_len + r.start, self.dest_len + r.stop)
            for r in other.dest_ranges
        ]

        self.src_ranges.extend(shifted_src)
        self.src_len += other.src_len
        self.dest_ranges.extend(shifted_dest)
        self.dest_len += other.dest_len

    def index_map(self) -> "MultiModalPlaceholderMap.IndexMap":
        """
        Finalizes the placeholder map into lists of indices that can be used to
        index the source and destination tensors.

        Raises:
            ValueError: If the source and destination ranges cover different
                numbers of indices.
        """

        src_indices = [i for r in self.src_ranges for i in r]
        dest_indices = [i for r in self.dest_ranges for i in r]

        if len(src_indices) != len(dest_indices):
            raise ValueError(
                f"The number of source ({len(src_indices)}) and destination "
                f"indices ({len(dest_indices)}) must be the same.")

        return MultiModalPlaceholderMap.IndexMap(src=src_indices,
                                                 dest=dest_indices)