# base.py
1
from abc import ABC, abstractmethod
2
from collections import defaultdict
3
4
from pathlib import Path
from typing import (TYPE_CHECKING, Any, Callable, Generic, NamedTuple,
5
6
                    Optional, Sequence, Tuple, Type, TypeVar, Union)

7
from torch import nn
8

9
from vllm.inputs import InputContext
10
from vllm.logger import init_logger
11
from vllm.utils import (ClassRegistry, get_allowed_kwarg_only_overrides,
12
                        resolve_mm_processor_kwargs)
13

14
15
16
17
if TYPE_CHECKING:
    from vllm.config import ModelConfig
    from vllm.sequence import SequenceGroupMetadata

18
19
from .inputs import (MultiModalData, MultiModalDataDict, MultiModalKwargs,
                     PlaceholderRange)
20

21
# Module-level logger, created through vLLM's ``init_logger`` helper.
logger = init_logger(__name__)

MultiModalInputMapper = Callable[[InputContext, MultiModalData[object]],
                                 MultiModalKwargs]
"""
Return a dictionary to be passed as keyword arguments to
:meth:`~torch.nn.Module.forward`. This is similar in concept to tokenizers
and processors in HuggingFace Transformers.

If the data is not supported, throw :exc:`TypeError`.
"""

MultiModalTokensCalc = Union[int, Callable[[InputContext], int]]
"""
Calculate the maximum number of multimodal tokens input to the language
model. This does not include tokens that correspond to the input text.
"""

# Generic item type used by the helpers defined further down in this module.
_T = TypeVar("_T")

# Type of a model class; lets the registration decorators return the class
# they receive with its precise type intact.
N = TypeVar("N", bound=Type[nn.Module])

42

43
class MultiModalPlugin(ABC):
    """
    Base class that defines data processing logic for a specific modality.

    In particular, we adopt a registry pattern to dispatch data processing
    according to the model being used (considering that different models may
    process the same data differently). This registry is in turn used by
    :class:`~MultiModalRegistry` which acts at a higher level
    (i.e., the modality of the data).

    See also:
        :ref:`adding-multimodal-plugin`
    """

    def __init__(self) -> None:
        # Per-model-class registries: one for the input mapper and one for
        # the maximum multimodal token count (or the callable computing it).
        self._input_mappers = ClassRegistry[nn.Module, MultiModalInputMapper]()
        self._max_mm_tokens = ClassRegistry[nn.Module, MultiModalTokensCalc]()

    @abstractmethod
    def get_data_key(self) -> str:
        """
        Get the data key corresponding to the modality.
        """
        raise NotImplementedError

    @abstractmethod
    def _default_input_mapper(
        self,
        ctx: InputContext,
        data: MultiModalData[Any],
        **mm_processor_kwargs,
    ) -> MultiModalKwargs:
        """
        Return a dictionary to be passed as keyword arguments to
        :meth:`~torch.nn.Module.forward`. This is similar in concept to
        tokenizers and processors in HuggingFace Transformers.

        If the data is not supported, throw :exc:`TypeError`.
        """
        raise NotImplementedError

    def register_input_mapper(
        self,
        mapper: Optional[MultiModalInputMapper] = None,
    ):
        """
        Register an input mapper to a model class.

        When the model receives input data that matches the modality served by
        this plugin (see :meth:`get_data_key`), the provided function is
        invoked to transform the data into a dictionary of model inputs.

        If `None` is provided, then the default input mapper is used instead.

        Returns:
            A class decorator that records the mapper for the decorated model
            class and returns that class unchanged.

        See also:
            - :ref:`input-processing-pipeline`
            - :ref:`enabling-multimodal-inputs`
        """

        def wrapper(model_cls: N) -> N:
            # ``strict=True`` is forwarded to the registry's ``contains``;
            # a hit means this registration replaces an earlier one.
            if self._input_mappers.contains(model_cls, strict=True):
                logger.warning(
                    "Model class %s already has an input mapper "
                    "registered to %s. It is overwritten by the new one.",
                    model_cls,
                    self,
                )

            # Explicit ``None`` check so that only an omitted mapper falls
            # back to the default (rather than any falsy object).
            self._input_mappers[model_cls] = (
                mapper if mapper is not None else self._default_input_mapper)

            return model_cls

        return wrapper

    def map_input(
        self,
        model_config: "ModelConfig",
        data: MultiModalData[Any],
        mm_processor_kwargs: Optional[dict[str, Any]],
    ) -> MultiModalKwargs:
        """
        Transform the data into a dictionary of model inputs using the
        input mapper registered for that model.

        The model is identified by ``model_config``.

        Raises:
            KeyError: If no input mapper is registered for the model class.
            TypeError: If the data type is not supported.

        See also:
            - :ref:`input-processing-pipeline`
            - :ref:`enabling-multimodal-inputs`
        """

        # Avoid circular import
        from vllm.model_executor.model_loader import get_model_architecture

        model_cls, _ = get_model_architecture(model_config)

        mapper = self._input_mappers.get(model_cls)

        if mapper is None:
            raise KeyError(f"No input mapper in {self} is registered for "
                           f"model class {model_cls.__name__}.")

        if mm_processor_kwargs is None:
            mm_processor_kwargs = {}

        # In the case of the default mapper, we have to get resource
        # processor through its HuggingFace autoclass; since this goes
        # through **kwargs, we can't inspect it the same way, so we allow
        # drop mm_processor_kwargs based on signature inspection
        # if we're using the default mapper.
        #
        # This should be safe in general due to the sanitation, since the
        # transformers resource should filter unused kwargs anyway.
        uses_default_mapper = mapper == self._default_input_mapper
        mm_processor_kwargs = resolve_mm_processor_kwargs(
            model_config.mm_processor_kwargs,
            mm_processor_kwargs,
            callable=mapper,
            allow_var_kwargs=uses_default_mapper,
        )
        return mapper(InputContext(model_config), data, **mm_processor_kwargs)

    @abstractmethod
    def _default_max_multimodal_tokens(self, ctx: InputContext) -> int:
        """
        Calculate the maximum number of tokens, corresponding to a single
        instance of multimodal data, that are passed to the language model.
        """
        raise NotImplementedError

    def _validate_max_multimodal_tokens(self, max_mm_tokens: int):
        """
        Check that ``max_mm_tokens`` is at least 1.

        Raises:
            ValueError: If ``max_mm_tokens`` is smaller than 1.
        """
        if max_mm_tokens < 1:
            raise ValueError("You should set the number of tokens to a "
                             f"positive integer. Found: {max_mm_tokens}")

    def register_max_multimodal_tokens(
        self,
        max_mm_tokens: Optional[MultiModalTokensCalc] = None,
    ):
        """
        Register the maximum number of tokens, corresponding to a single
        instance of multimodal data, that are passed to the language model
        for a model class.

        If `None` is provided, then the default calculation is used instead.

        Returns:
            A class decorator that records the value (or callable) for the
            decorated model class and returns that class unchanged.

        Raises:
            ValueError: From the returned decorator, if an integer smaller
                than 1 was provided.

        See also:
            :ref:`enabling-multimodal-inputs`
        """

        def wrapper(model_cls: N) -> N:
            if self._max_mm_tokens.contains(model_cls, strict=True):
                logger.warning(
                    "Model class %s already calculates maximum number of "
                    "tokens in %s. It is overwritten by the new one.",
                    model_cls,
                    self,
                )

            # A fixed integer can be validated right away; a callable is
            # validated when :meth:`get_max_multimodal_tokens` evaluates it.
            if isinstance(max_mm_tokens, int):
                self._validate_max_multimodal_tokens(max_mm_tokens)

            # Explicit ``None`` check so that only an omitted value falls
            # back to the default calculation.
            self._max_mm_tokens[model_cls] = (
                max_mm_tokens if max_mm_tokens is not None else
                self._default_max_multimodal_tokens)

            return model_cls

        return wrapper

    def get_max_multimodal_tokens(self, model_config: "ModelConfig") -> int:
        """
        Get the maximum number of multi-modal tokens
        for profiling the memory usage of a model.

        If this registry is not applicable to the model, `0` is returned.

        The model is identified by ``model_config``.

        Raises:
            ValueError: If the registered (or computed) value is smaller
                than 1.

        See also:
            :ref:`enabling-multimodal-inputs`
        """
        # Avoid circular import
        from vllm.model_executor.model_loader import get_model_architecture
        from vllm.model_executor.models import supports_multimodal

        model_cls, _ = get_model_architecture(model_config)

        if not supports_multimodal(model_cls):
            return 0

        max_mm_tokens = self._max_mm_tokens.get(model_cls)
        if max_mm_tokens is None:
            return 0

        if callable(max_mm_tokens):
            # Only forward the configured processor kwargs that the helper
            # allows as overrides for this particular callable.
            mm_processor_kwargs = get_allowed_kwarg_only_overrides(
                max_mm_tokens, overrides=model_config.mm_processor_kwargs)
            max_mm_tokens = max_mm_tokens(InputContext(model_config),
                                          **mm_processor_kwargs)

        self._validate_max_multimodal_tokens(max_mm_tokens)

        return max_mm_tokens
250
251
252
253
254
255
256
257


class MultiModalPlaceholderMap:
    """
    Relates multi-modal embeddings to their corresponding placeholders.
    """

    class IndexMap(NamedTuple):
        """Flattened source/destination index lists (see ``index_map``)."""
        src: list[int]
        dest: list[int]

    src_ranges: list[range]
    """
    The indices of the multi-modal embeddings that will replace the
    corresponding placeholder embeddings pointed to by ``dest_ranges``.
    """

    src_len: int
    """
    The total number of flattened multi-modal embeddings.
    """

    dest_ranges: list[range]
    """
    The indices of the placeholder embeddings that will be replaced by the
    multimodal embeddings.
    """

    dest_len: int
    """
    The total number of embeddings in the destination tensor.
    """

    def __init__(self):
        self.src_ranges = []
        self.src_len = 0
        self.dest_ranges = []
        self.dest_len = 0

    @classmethod
    def from_seq_group(
        cls, seq_group: "SequenceGroupMetadata", positions: range
    ) -> Tuple[Optional[MultiModalDataDict], dict[str,
                                                  "MultiModalPlaceholderMap"]]:
        """
        Returns the multi-modal items that intersect with the portion of a
        prompt (``seq_group``) represented by ``positions``, as well as a
        ``MultiModalPlaceholderMap`` that relates the multi-modal embedding
        vectors to their corresponding placeholders.

        Consider the following scenarios:

           Prompt: |AAAA BBBB What's in these images?|
        Positions: |.................................|

            images      = [A, B]
            src_ranges  = [(0, 4), (4, 8)]
            dest_ranges = [(0, 4), (5, 9)]

           Prompt: |AAAA BBBB What's in these images?|
        Positions: |  .....                          |

            images      = [A, B]
            src_ranges  = [(2, 4), (4, 6)]
            dest_ranges = [(0, 2), (3, 5)]

           Prompt: |AAAA BBBB What's in these images?|
        Positions: |     .........                   |

            images      = [B]
            src_ranges  = [(0, 4)]
            dest_ranges = [(0, 4)]

           Prompt: |AAAA BBBB What's in these images?|
        Positions: |          .......................|

            images      = []
            src_ranges  = []
            dest_ranges = []
        """
        seq_mm_data = seq_group.multi_modal_data
        seq_mm_placeholders = seq_group.multi_modal_placeholders

        # Nothing to relate: hand the data back as-is with no maps.
        if not seq_mm_data or not seq_mm_placeholders:
            return seq_mm_data, {}

        # For merged processor, we directly use mm_kwargs as mm_data
        if isinstance(seq_mm_data, MultiModalKwargs):
            placeholder_maps = dict[str, MultiModalPlaceholderMap]()

            for modality, placeholders in seq_mm_placeholders.items():
                placeholder_map = MultiModalPlaceholderMap()

                if positions:
                    placeholder_map.append_items_from_seq_group(
                        positions,
                        # Dummy, since we don't care about intersecting items
                        [None] * len(placeholders),
                        placeholders,
                    )

                placeholder_maps[modality] = placeholder_map

            return seq_mm_data, placeholder_maps

        # Shallow copy so popping modalities below does not mutate the
        # mapping owned by ``seq_group``.
        mm_data = {**seq_mm_data}
        placeholder_maps = defaultdict[str, MultiModalPlaceholderMap](
            MultiModalPlaceholderMap)

        for modality, placeholders in seq_mm_placeholders.items():
            mm_items = mm_data.pop(modality)
            if not isinstance(mm_items, list):
                mm_items = [mm_items]

            if positions:
                intersecting_items = placeholder_maps[modality] \
                    .append_items_from_seq_group(
                        positions,
                        mm_items,
                        placeholders,
                    )

                # Only modalities with at least one intersecting item are
                # put back into the returned data.
                if intersecting_items:
                    mm_data[modality] = intersecting_items

        return mm_data, placeholder_maps

    def append_items_from_seq_group(
        self,
        positions: range,
        multi_modal_items: list[_T],
        multi_modal_placeholders: Sequence[PlaceholderRange],
    ) -> list[_T]:
        """
        Adds the multi-modal items that intersect ``positions`` to this
        placeholder map and returns the intersecting items.

        Raises:
            ValueError: If the numbers of items and placeholders differ.
        """
        intersecting_items = []

        if len(multi_modal_items) != len(multi_modal_placeholders):
            raise ValueError(
                "Multi-modal placeholders and items must have the same length."
            )
        for placeholder_dict, mm_item in zip(multi_modal_placeholders,
                                             multi_modal_items):
            placeholder = range(
                placeholder_dict["offset"],
                placeholder_dict["offset"] + placeholder_dict["length"],
            )
            intersection = range(
                max(positions.start, placeholder.start),
                min(positions.stop, placeholder.stop),
            )

            if not intersection:
                # Skip this multi-modal item.
                continue

            # Destination indices are relative to the start of ``positions``.
            token_embedding_range = range(
                intersection.start - positions.start,
                intersection.stop - positions.start,
            )

            # Source indices are relative to the start of the placeholder,
            # shifted by the embeddings already accounted for in this map.
            multimodal_embedding_range = range(
                intersection.start - placeholder.start + self.src_len,
                intersection.stop - placeholder.start + self.src_len,
            )

            intersecting_items.append(mm_item)
            self.dest_ranges.append(token_embedding_range)
            self.src_ranges.append(multimodal_embedding_range)
            # The full placeholder length is added even when only part of
            # it intersects ``positions``.
            self.src_len += len(placeholder)

        self.dest_len += len(positions)
        return intersecting_items

    def extend(self, other: "MultiModalPlaceholderMap"):
        """
        Adds the placeholders from another ``MultiModalPlaceholderMap`` to this
        instance based on the source and destination tensors being
        concatenated.
        """

        # Ranges from ``other`` are shifted by this map's current lengths,
        # which are then grown by ``other``'s lengths.
        self.src_ranges.extend(
            range(self.src_len + r.start, self.src_len + r.stop)
            for r in other.src_ranges)
        self.src_len += other.src_len
        self.dest_ranges.extend(
            range(self.dest_len + r.start, self.dest_len + r.stop)
            for r in other.dest_ranges)
        self.dest_len += other.dest_len

    def index_map(self) -> "MultiModalPlaceholderMap.IndexMap":
        """
        Finalizes the placeholder map into lists of indices that can be used to
        index the source and destination tensors.

        Raises:
            ValueError: If the flattened source and destination index lists
                have different lengths.
        """

        src_indices = [i for r in self.src_ranges for i in r]
        dest_indices = [i for r in self.dest_ranges for i in r]

        if len(src_indices) != len(dest_indices):
            raise ValueError(
                f"The number of source ({len(src_indices)}) and destination "
                f"indices ({len(dest_indices)}) must be the same.")

        return MultiModalPlaceholderMap.IndexMap(src=src_indices,
                                                 dest=dest_indices)
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476


class MediaIO(ABC, Generic[_T]):
    """
    Abstract interface, generic over the result type ``_T``, declaring three
    loaders that each turn a different input form into a ``_T``.

    Every method here is abstract; the behaviour is defined by subclasses.
    """

    @abstractmethod
    def load_bytes(self, data: bytes) -> _T:
        """
        Build a ``_T`` from raw bytes.
        """
        raise NotImplementedError

    @abstractmethod
    def load_base64(self, media_type: str, data: str) -> _T:
        """
        Build a ``_T`` from ``data`` given its ``media_type``.

        List of media types:
        https://www.iana.org/assignments/media-types/media-types.xhtml

        NOTE(review): ``data`` is presumably the base64-encoded payload;
        confirm against the concrete implementations.
        """
        raise NotImplementedError

    @abstractmethod
    def load_file(self, filepath: Path) -> _T:
        """
        Build a ``_T`` from the file at ``filepath``.
        """
        raise NotImplementedError