base.py 15.6 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
from abc import ABC, abstractmethod
4
from collections import defaultdict
5
6
from pathlib import Path
from typing import (TYPE_CHECKING, Any, Callable, Generic, NamedTuple,
7
8
                    Optional, Sequence, Tuple, Type, TypeVar, Union)

9
from torch import nn
10

11
from vllm.inputs import InputContext
12
from vllm.logger import init_logger
13
from vllm.utils import (ClassRegistry, get_allowed_kwarg_only_overrides,
14
                        resolve_mm_processor_kwargs)
15

16
17
18
19
if TYPE_CHECKING:
    from vllm.config import ModelConfig
    from vllm.sequence import SequenceGroupMetadata

20
from .inputs import (ModalityData, MultiModalDataDict, MultiModalKwargs,
21
                     PlaceholderRange)
22

23
logger = init_logger(__name__)

MultiModalInputMapper = Callable[
    [InputContext, ModalityData[object]],
    MultiModalKwargs,
]
"""
Return a dictionary to be passed as keyword arguments to
:meth:`~torch.nn.Module.forward`. This is similar in concept to tokenizers
and processors in HuggingFace Transformers.

If the data is not supported, throw :exc:`TypeError`.
"""

MultiModalTokensCalc = Union[
    int,
    Callable[[InputContext], int],
]
"""
Calculate the maximum number of multimodal tokens input to the language
model. This does not include tokens that correspond to the input text.
"""

# Generic placeholder for an arbitrary item type.
_T = TypeVar("_T")

# Type variable bound to ``Type[nn.Module]``; used to annotate the class
# decorators below so that they return the class they were given.
N = TypeVar("N", bound=Type[nn.Module])

44

45
class MultiModalPlugin(ABC):
    """
    Base class that defines data processing logic for a specific modality.

    A registry pattern is used to dispatch data processing according to the
    model class in use, because different models may process the same data
    differently. This registry is in turn used by
    :class:`~MultiModalRegistry`, which acts at a higher level
    (i.e., the modality of the data).
    """

    def __init__(self) -> None:
        # One registry per concern, both keyed by model class.
        self._input_mappers = ClassRegistry[nn.Module, MultiModalInputMapper]()
        self._max_mm_tokens = ClassRegistry[nn.Module, MultiModalTokensCalc]()

    @abstractmethod
    def get_data_key(self) -> str:
        """
        Get the data key corresponding to the modality.
        """
        raise NotImplementedError

    @abstractmethod
    def _default_input_mapper(
        self,
        ctx: InputContext,
        data: ModalityData[Any],
        **mm_processor_kwargs,
    ) -> MultiModalKwargs:
        """
        Return a dictionary to be passed as keyword arguments to
        :meth:`~torch.nn.Module.forward`. This is similar in concept to
        tokenizers and processors in HuggingFace Transformers.

        If the data is not supported, throw :exc:`TypeError`.
        """
        raise NotImplementedError

    def register_input_mapper(
        self,
        mapper: Optional[MultiModalInputMapper] = None,
    ):
        """
        Return a class decorator that registers an input mapper for the
        decorated model class.

        When the model receives input data that matches the modality served by
        this plugin (see :meth:`get_data_key`), the registered function is
        invoked to transform the data into a dictionary of model inputs.

        If `None` is provided, then the default input mapper is used instead.
        A warning is logged when the model class already has an entry, and
        the existing entry is replaced.
        """

        def wrapper(model_cls: N) -> N:
            registry = self._input_mappers

            if registry.contains(model_cls, strict=True):
                logger.warning(
                    "Model class %s already has an input mapper "
                    "registered to %s. It is overwritten by the new one.",
                    model_cls,
                    self,
                )

            # A falsy ``mapper`` selects the plugin's default mapper.
            selected = mapper if mapper else self._default_input_mapper
            registry[model_cls] = selected

            return model_cls

        return wrapper

    def map_input(
        self,
        model_config: "ModelConfig",
        data: ModalityData[Any],
        mm_processor_kwargs: Optional[dict[str, Any]],
    ) -> MultiModalKwargs:
        """
        Transform the data into a dictionary of model inputs using the
        input mapper registered for that model.

        The model is identified by ``model_config``.

        Raises:
            KeyError: If no input mapper is registered for the model class.
            TypeError: If the data type is not supported.
        """
        # Avoid circular import
        from vllm.model_executor.model_loader import get_model_architecture

        model_cls, _arch = get_model_architecture(model_config)

        mapper = self._input_mappers.get(model_cls)
        if mapper is None:
            raise KeyError(f"No input mapper in {self} is registered for "
                           f"model class {model_cls.__name__}.")

        request_kwargs = ({} if mm_processor_kwargs is None else
                          mm_processor_kwargs)

        # The default mapper obtains its resource processor through the
        # HuggingFace autoclass and receives the overrides via **kwargs, so
        # its signature cannot be inspected like a custom mapper's. Var-kwargs
        # are therefore only allowed when the default mapper is in use.
        #
        # This should be safe in general due to the sanitation, since the
        # transformers resource should filter unused kwargs anyway.
        uses_default_mapper = mapper == self._default_input_mapper
        resolved_kwargs = resolve_mm_processor_kwargs(
            model_config.mm_processor_kwargs,
            request_kwargs,
            callable=mapper,
            allow_var_kwargs=uses_default_mapper,
        )

        ctx = InputContext(model_config)
        return mapper(ctx, data, **resolved_kwargs)

    @abstractmethod
    def _default_max_multimodal_tokens(self, ctx: InputContext) -> int:
        """
        Calculate the maximum number of tokens, corresponding to a single
        instance of multimodal data, that are passed to the language model.
        """
        raise NotImplementedError

    def _validate_max_multimodal_tokens(self, max_mm_tokens: int):
        """
        Raise :exc:`ValueError` if ``max_mm_tokens`` is smaller than 1.
        """
        if max_mm_tokens < 1:
            message = ("You should set the number of tokens to a "
                       f"positive integer. Found: {max_mm_tokens}")
            raise ValueError(message)

    def register_max_multimodal_tokens(
        self,
        max_mm_tokens: Optional[MultiModalTokensCalc] = None,
    ):
        """
        Return a class decorator that registers, for the decorated model
        class, the maximum number of tokens, corresponding to a single
        instance of multimodal data, that are passed to the language model.

        If `None` is provided, then the default calculation is used instead.
        An integer value is validated when the decorator is applied.
        """

        def wrapper(model_cls: N) -> N:
            registry = self._max_mm_tokens

            if registry.contains(model_cls, strict=True):
                logger.warning(
                    "Model class %s already calculates maximum number of "
                    "tokens in %s. It is overwritten by the new one.",
                    model_cls,
                    self,
                )

            # Constants can be checked right away; callables are only
            # checked once they are evaluated.
            if isinstance(max_mm_tokens, int):
                self._validate_max_multimodal_tokens(max_mm_tokens)

            selected = (max_mm_tokens if max_mm_tokens else
                        self._default_max_multimodal_tokens)
            registry[model_cls] = selected

            return model_cls

        return wrapper

    def get_max_multimodal_tokens(self, model_config: "ModelConfig") -> int:
        """
        Get the maximum number of multi-modal tokens
        for profiling the memory usage of a model.

        If this registry is not applicable to the model, `0` is returned.

        The model is identified by ``model_config``.
        """
        # Avoid circular import
        from vllm.model_executor.model_loader import get_model_architecture
        from vllm.model_executor.models import supports_multimodal

        model_cls, _arch = get_model_architecture(model_config)
        if not supports_multimodal(model_cls):
            return 0

        registered = self._max_mm_tokens.get(model_cls)
        if registered is None:
            return 0

        if callable(registered):
            # Only forward the overrides accepted by the registered callable.
            overrides = get_allowed_kwarg_only_overrides(
                registered, overrides=model_config.mm_processor_kwargs)
            num_tokens = registered(InputContext(model_config), **overrides)
        else:
            num_tokens = registered

        self._validate_max_multimodal_tokens(num_tokens)

        return num_tokens
235
236
237
238
239
240
241
242


class MultiModalPlaceholderMap:
    """
    Relates multi-modal embeddings to their corresponding placeholders.
    """

    class IndexMap(NamedTuple):
        src: list[int]
        dest: list[int]

    src_ranges: list[range]
    """
    The indices of the multi-modal embeddings that will replace the
    corresponding placeholder embeddings pointed to by ``dest_ranges``.
    """

    src_len: int
    """
    The total number of flattened multi-modal embeddings.
    """

    dest_ranges: list[range]
    """
    The indices of the placeholder embeddings that will be replaced by the
    multimodal embeddings.
    """

    dest_len: int
    """
    The total number of embeddings in the destination tensor.
    """

    def __init__(self):
        self.src_ranges = []
        self.src_len = 0
        self.dest_ranges = []
        self.dest_len = 0

    @classmethod
    def from_seq_group(
        cls, seq_group: "SequenceGroupMetadata", positions: range
    ) -> Tuple[Optional[MultiModalDataDict], dict[str,
                                                  "MultiModalPlaceholderMap"]]:
        """
        Returns the multi-modal items that intersect with the portion of a
        prompt (``seq_group``) represented by ``positions``, as well as a
        ``MultiModalPlaceholderMap`` that relates the multi-modal embedding
        vectors to their corresponding placeholders.

        Examples:

        .. code-block::

            Prompt:    |AAAA BBBB What's in these images?|
            Positions: |.................................|

                images      = [A, B]
                src_ranges  = [(0, 4), (4, 8)]
                dest_ranges = [(0, 4), (5, 9)]

            Prompt:    |AAAA BBBB What's in these images?|
            Positions: |  .....                          |

                images      = [A, B]
                src_ranges  = [(2, 4), (4, 6)]
                dest_ranges = [(0, 2), (3, 5)]

            Prompt:    |AAAA BBBB What's in these images?|
            Positions: |     .........                   |

                images      = [B]
                src_ranges  = [(0, 4)]
                dest_ranges = [(0, 4)]

            Prompt:    |AAAA BBBB What's in these images?|
            Positions: |          .......................|

                images      = []
                src_ranges  = []
                dest_ranges = []
        """
        seq_mm_data = seq_group.multi_modal_data
        seq_mm_placeholders = seq_group.multi_modal_placeholders

        # Nothing to relate if either the data or the placeholders are missing.
        if not seq_mm_data or not seq_mm_placeholders:
            return seq_mm_data, {}

        # For merged processor, we directly use mm_kwargs as mm_data
        if isinstance(seq_mm_data, MultiModalKwargs):
            placeholder_maps = dict[str, MultiModalPlaceholderMap]()

            for modality, placeholders in seq_mm_placeholders.items():
                placeholder_map = MultiModalPlaceholderMap()

                if positions:
                    placeholder_map.append_items_from_seq_group(
                        positions,
                        # Dummy, since we don't care about intersecting items
                        [None] * len(placeholders),
                        placeholders,
                    )

                placeholder_maps[modality] = placeholder_map

            return seq_mm_data, placeholder_maps

        # Work on a shallow copy so the sequence group's own data is untouched.
        mm_data = {**seq_mm_data}
        placeholder_maps = defaultdict[str, MultiModalPlaceholderMap](
            MultiModalPlaceholderMap)

        for modality, placeholders in seq_mm_placeholders.items():
            # The modality is removed here and only re-added below when at
            # least one of its items intersects ``positions``.
            mm_items = mm_data.pop(modality)
            if not isinstance(mm_items, list):
                mm_items = [mm_items]

            if positions:
                intersecting_items = placeholder_maps[
                    modality].append_items_from_seq_group(
                        positions,
                        mm_items,
                        placeholders,
                    )

                if intersecting_items:
                    mm_data[modality] = intersecting_items

        return mm_data, placeholder_maps

    def append_items_from_seq_group(
        self,
        positions: range,
        multi_modal_items: list[_T],
        multi_modal_placeholders: Sequence[PlaceholderRange],
    ) -> list[_T]:
        """
        Adds the multi-modal items that intersect ``positions`` to this
        placeholder map and returns the intersecting items.

        ``dest_len`` grows by ``len(positions)`` on every call, while
        ``src_len`` grows by the full placeholder length of each
        intersecting item.

        Raises:
            ValueError: If the items and placeholders differ in length.
        """
        intersecting_items = []

        if len(multi_modal_items) != len(multi_modal_placeholders):
            raise ValueError(
                "Multi-modal placeholders and items must have the same length."
            )
        for placeholder_dict, mm_item in zip(multi_modal_placeholders,
                                             multi_modal_items):
            placeholder = range(
                placeholder_dict["offset"],
                placeholder_dict["offset"] + placeholder_dict["length"],
            )
            intersection = range(
                max(positions.start, placeholder.start),
                min(positions.stop, placeholder.stop),
            )

            if not intersection:
                # Skip this multi-modal item.
                continue

            # Indices relative to the start of ``positions``.
            token_embedding_range = range(
                intersection.start - positions.start,
                intersection.stop - positions.start,
            )

            # Indices relative to the start of this item's placeholder,
            # shifted past the items that were already added.
            multimodal_embedding_range = range(
                intersection.start - placeholder.start + self.src_len,
                intersection.stop - placeholder.start + self.src_len,
            )

            intersecting_items.append(mm_item)
            self.dest_ranges.append(token_embedding_range)
            self.src_ranges.append(multimodal_embedding_range)
            self.src_len += len(placeholder)

        self.dest_len += len(positions)
        return intersecting_items

    def extend(self, other: "MultiModalPlaceholderMap"):
        """
        Adds the placeholders from another ``MultiModalPlaceholderMap`` to this
        instance based on the source and destination tensors being
        concatenated.

        ``other`` may be this very instance: the shifted ranges are built
        into a list before they are appended.
        """
        # Materialize first. Passing a lazy generator over
        # ``other.src_ranges`` to ``list.extend`` never terminates when
        # ``other is self``, because the list grows while it is iterated.
        shifted_src = [
            range(self.src_len + r.start, self.src_len + r.stop)
            for r in other.src_ranges
        ]
        self.src_ranges.extend(shifted_src)
        self.src_len += other.src_len

        shifted_dest = [
            range(self.dest_len + r.start, self.dest_len + r.stop)
            for r in other.dest_ranges
        ]
        self.dest_ranges.extend(shifted_dest)
        self.dest_len += other.dest_len

    def index_map(self) -> "IndexMap":
        """
        Finalizes the placeholder map into lists of indices that can be used to
        index the source and destination tensors.

        Raises:
            ValueError: If the flattened source and destination index lists
                differ in length.
        """

        src_indices = [i for r in self.src_ranges for i in r]
        dest_indices = [i for r in self.dest_ranges for i in r]

        if len(src_indices) != len(dest_indices):
            raise ValueError(
                f"The number of source ({len(src_indices)}) and destination "
                f"indices ({len(dest_indices)}) must be the same.")

        return MultiModalPlaceholderMap.IndexMap(src=src_indices,
                                                 dest=dest_indices)
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463


class MediaIO(ABC, Generic[_T]):
    """
    Abstract interface declaring three loaders that each return a ``_T``:
    from raw bytes, from a base64 string, and from a file path.
    """

    @abstractmethod
    def load_bytes(self, data: bytes) -> _T:
        """
        Load from ``data`` given as raw bytes.
        """
        raise NotImplementedError()

    @abstractmethod
    def load_base64(self, media_type: str, data: str) -> _T:
        """
        Load from ``data`` given as a base64 string of type ``media_type``.

        List of media types:
        https://www.iana.org/assignments/media-types/media-types.xhtml
        """
        raise NotImplementedError()

    @abstractmethod
    def load_file(self, filepath: Path) -> _T:
        """
        Load from the file located at ``filepath``.
        """
        raise NotImplementedError()