# "vllm/vscode:/vscode.git/clone" did not exist on "2d940766aedf46428de8d3d2268d1f56ad75cfcc"
# base.py 15.9 KB
# Newer Older
from abc import ABC, abstractmethod
from collections import defaultdict
from pathlib import Path
from typing import (TYPE_CHECKING, Any, Callable, Generic, NamedTuple,
                    Optional, Sequence, Tuple, Type, TypeVar, Union)

from torch import nn

from vllm.inputs import InputContext
from vllm.logger import init_logger
from vllm.utils import (ClassRegistry, get_allowed_kwarg_only_overrides,
                        resolve_mm_processor_kwargs)

if TYPE_CHECKING:
    from vllm.config import ModelConfig
    from vllm.sequence import SequenceGroupMetadata

from .inputs import (ModalityData, MultiModalDataDict, MultiModalKwargs,
                     PlaceholderRange)

logger = init_logger(__name__)

MultiModalInputMapper = Callable[[InputContext, ModalityData[object]],
                                 MultiModalKwargs]
"""
A callable that takes an :class:`InputContext` and the data for one modality
and returns a dictionary to be passed as keyword arguments to
:meth:`~torch.nn.Module.forward`. This is similar in concept to tokenizers
and processors in HuggingFace Transformers.

:meth:`MultiModalPlugin.map_input` additionally forwards any resolved
``mm_processor_kwargs`` to the mapper as keyword arguments.

If the data is not supported, throw :exc:`TypeError`.
"""

MultiModalTokensCalc = Union[int, Callable[[InputContext], int]]
"""
Either a fixed count, or a callable that calculates from an
:class:`InputContext`, the maximum number of multimodal tokens input to the
language model. This does not include tokens that correspond to the input
text. :class:`MultiModalPlugin` rejects counts smaller than 1.
"""

# Generic element/result type (multi-modal items, loaded media).
_T = TypeVar("_T")
# A model class; the registration decorators return it unchanged.
N = TypeVar("N", bound=Type[nn.Module])


class MultiModalPlugin(ABC):
    """
    Base class that defines data processing logic for a specific modality.

    In particular, we adopt a registry pattern to dispatch data processing
    according to the model being used (considering that different models may
    process the same data differently). This registry is in turn used by
    :class:`~MultiModalRegistry` which acts at a higher level
    (i.e., the modality of the data).

    Subclasses must implement :meth:`get_data_key`,
    :meth:`_default_input_mapper` and :meth:`_default_max_multimodal_tokens`.
    """

    def __init__(self) -> None:
        # Model class -> input mapper, filled by `register_input_mapper`.
        self._input_mappers = ClassRegistry[nn.Module, MultiModalInputMapper]()
        # Model class -> fixed token count or callable computing it, filled
        # by `register_max_multimodal_tokens`.
        self._max_mm_tokens = ClassRegistry[nn.Module, MultiModalTokensCalc]()

    @abstractmethod
    def get_data_key(self) -> str:
        """
        Get the data key corresponding to the modality.
        """
        raise NotImplementedError

    @abstractmethod
    def _default_input_mapper(
        self,
        ctx: InputContext,
        data: ModalityData[Any],
        **mm_processor_kwargs,
    ) -> MultiModalKwargs:
        """
        Return a dictionary to be passed as keyword arguments to
        :meth:`~torch.nn.Module.forward`. This is similar in concept to
        tokenizers and processors in HuggingFace Transformers.

        The ``mm_processor_kwargs`` resolved by :meth:`map_input` are
        forwarded here as keyword arguments.

        If the data is not supported, throw :exc:`TypeError`.
        """
        raise NotImplementedError

    def register_input_mapper(
        self,
        mapper: Optional[MultiModalInputMapper] = None,
    ):
        """
        Register an input mapper to a model class.

        When the model receives input data that matches the modality served by
        this plugin (see :meth:`get_data_key`), the provided function is
        invoked to transform the data into a dictionary of model inputs.

        If `None` is provided, then the default input mapper is used instead.

        Returns a class decorator that records the mapper for the decorated
        model class and returns that class unchanged. If the class already has
        a mapper (per ``ClassRegistry.contains(..., strict=True)``), a warning
        is logged; the new mapper replaces the old one either way.

        See also:
            - :ref:`input-processing-pipeline`
            - :ref:`enabling-multimodal-inputs`
        """

        def wrapper(model_cls: N) -> N:
            if self._input_mappers.contains(model_cls, strict=True):
                logger.warning(
                    "Model class %s already has an input mapper "
                    "registered to %s. It is overwritten by the new one.",
                    model_cls,
                    self,
                )

            self._input_mappers[model_cls] = (mapper
                                              or self._default_input_mapper)

            return model_cls

        return wrapper

    def map_input(
        self,
        model_config: "ModelConfig",
        data: ModalityData[Any],
        mm_processor_kwargs: Optional[dict[str, Any]],
    ) -> MultiModalKwargs:
        """
        Transform the data into a dictionary of model inputs using the
        input mapper registered for that model.

        The model is identified by ``model_config``.

        ``mm_processor_kwargs`` (``None`` is treated as empty) is combined
        with ``model_config.mm_processor_kwargs`` via
        ``resolve_mm_processor_kwargs``. The result is passed to the mapper
        as keyword arguments.

        Raises:
            KeyError: If no input mapper is registered for the model class.
            TypeError: If the data type is not supported.

        See also:
            - :ref:`input-processing-pipeline`
            - :ref:`enabling-multimodal-inputs`
        """
        # Avoid circular import
        from vllm.model_executor.model_loader import get_model_architecture

        model_cls, _ = get_model_architecture(model_config)

        mapper = self._input_mappers.get(model_cls)

        if mapper is None:
            raise KeyError(f"No input mapper in {self} is registered for "
                           f"model class {model_cls.__name__}.")

        if mm_processor_kwargs is None:
            mm_processor_kwargs = {}

        # Custom mappers have their kwargs filtered by signature inspection.
        # The default mapper instead obtains the resource processor through
        # its HuggingFace autoclass and receives kwargs via
        # **mm_processor_kwargs, so its signature can't be inspected the same
        # way. For it we pass allow_var_kwargs=True, so that kwargs are not
        # dropped just because they only match the **kwargs catch-all.
        # NOTE(review): exact semantics live in resolve_mm_processor_kwargs.
        #
        # This should be safe in general due to the sanitation, since the
        # transformers resource should filter unused kwargs anyway.
        uses_default_mapper = mapper == self._default_input_mapper
        mm_processor_kwargs = resolve_mm_processor_kwargs(
            model_config.mm_processor_kwargs,
            mm_processor_kwargs,
            callable=mapper,
            allow_var_kwargs=uses_default_mapper,
        )
        return mapper(InputContext(model_config), data, **mm_processor_kwargs)

    @abstractmethod
    def _default_max_multimodal_tokens(self, ctx: InputContext) -> int:
        """
        Calculate the maximum number of tokens, corresponding to a single
        instance of multimodal data, that are passed to the language model.
        """
        raise NotImplementedError

    def _validate_max_multimodal_tokens(self, max_mm_tokens: int):
        """Raise :exc:`ValueError` unless ``max_mm_tokens`` is at least 1."""
        if max_mm_tokens < 1:
            raise ValueError("You should set the number of tokens to a "
                             f"positive integer. Found: {max_mm_tokens}")

    def register_max_multimodal_tokens(
        self,
        max_mm_tokens: Optional[MultiModalTokensCalc] = None,
    ):
        """
        Register the maximum number of tokens, corresponding to a single
        instance of multimodal data, that are passed to the language model
        for a model class.

        If `None` is provided, then the default calculation is used instead.

        An ``int`` value is validated (it must be at least 1) when the
        returned decorator is applied. A callable is only evaluated and
        validated later, in :meth:`get_max_multimodal_tokens`.

        See also:
            :ref:`enabling-multimodal-inputs`
        """

        def wrapper(model_cls: N) -> N:
            if self._max_mm_tokens.contains(model_cls, strict=True):
                logger.warning(
                    "Model class %s already calculates maximum number of "
                    "tokens in %s. It is overwritten by the new one.",
                    model_cls,
                    self,
                )

            if isinstance(max_mm_tokens, int):
                self._validate_max_multimodal_tokens(max_mm_tokens)

            self._max_mm_tokens[model_cls] = (
                max_mm_tokens or self._default_max_multimodal_tokens)

            return model_cls

        return wrapper

    def get_max_multimodal_tokens(self, model_config: "ModelConfig") -> int:
        """
        Get the maximum number of multi-modal tokens
        for profiling the memory usage of a model.

        If this registry is not applicable to the model, `0` is returned.
        This covers both a model that does not support multimodal input and a
        model with nothing registered in this plugin.

        The model is identified by ``model_config``.

        Raises:
            ValueError: If the registered or computed count is smaller than 1.

        See also:
            :ref:`enabling-multimodal-inputs`
        """
        # Avoid circular import
        from vllm.model_executor.model_loader import get_model_architecture
        from vllm.model_executor.models import supports_multimodal

        model_cls, _ = get_model_architecture(model_config)

        if not supports_multimodal(model_cls):
            return 0

        max_mm_tokens = self._max_mm_tokens.get(model_cls)
        if max_mm_tokens is None:
            return 0

        if callable(max_mm_tokens):
            # Forward the config's mm_processor_kwargs overrides, as filtered
            # against the callable by get_allowed_kwarg_only_overrides
            # (presumably keeping only its keyword-only params).
            mm_processor_kwargs = get_allowed_kwarg_only_overrides(
                max_mm_tokens, overrides=model_config.mm_processor_kwargs)
            max_mm_tokens = max_mm_tokens(InputContext(model_config),
                                          **mm_processor_kwargs)

        self._validate_max_multimodal_tokens(max_mm_tokens)

        return max_mm_tokens


class MultiModalPlaceholderMap:
    """
    Relates multi-modal embeddings to their corresponding placeholders.
    """

    class IndexMap(NamedTuple):
        # Flattened indices into the multi-modal embeddings (source).
        src: list[int]
        # Flattened indices into the placeholder embeddings (destination).
        dest: list[int]

    src_ranges: list[range]
    """
    The indices of the multi-modal embeddings that will replace the
    corresponding placeholder embeddings pointed to by ``dest_ranges``.
    """

    src_len: int
    """
    The total number of flattened multi-modal embeddings.
    """

    dest_ranges: list[range]
    """
    The indices of the placeholder embeddings that will be replaced by the
    multimodal embeddings.
    """

    dest_len: int
    """
    The total number of embeddings in the destination tensor.
    """

    def __init__(self):
        self.src_ranges = []
        self.src_len = 0
        self.dest_ranges = []
        self.dest_len = 0

    @classmethod
    def from_seq_group(
        cls, seq_group: "SequenceGroupMetadata", positions: range
    ) -> Tuple[Optional[MultiModalDataDict], dict[str,
                                                  "MultiModalPlaceholderMap"]]:
        """
        Returns the multi-modal items that intersect with the portion of a
        prompt (``seq_group``) represented by ``positions``, as well as a
        ``MultiModalPlaceholderMap`` that relates the multi-modal embedding
        vectors to their corresponding placeholders.

        If ``seq_group`` has no multi-modal data or no placeholders, its data
        is returned as-is together with an empty mapping. If the data is
        already a :class:`MultiModalKwargs` (merged processor), it is returned
        unchanged and only the per-modality placeholder maps are computed.

        Examples:

        .. code-block::

            Prompt:    |AAAA BBBB What's in these images?|
            Positions: |.................................|

                images      = [A, B]
                src_ranges  = [(0, 4), (4, 8)]
                dest_ranges = [(0, 4), (5, 9)]

            Prompt:    |AAAA BBBB What's in these images?|
            Positions: |  .....                          |

                images      = [A, B]
                src_ranges  = [(2, 4), (4, 6)]
                dest_ranges = [(0, 2), (3, 5)]

            Prompt:    |AAAA BBBB What's in these images?|
            Positions: |     .........                   |

                images      = [B]
                src_ranges  = [(0, 4)]
                dest_ranges = [(0, 4)]

            Prompt:    |AAAA BBBB What's in these images?|
            Positions: |          .......................|

                images      = []
                src_ranges  = []
                dest_ranges = []
        """
        seq_mm_data = seq_group.multi_modal_data
        seq_mm_placeholders = seq_group.multi_modal_placeholders

        if not seq_mm_data or not seq_mm_placeholders:
            return seq_mm_data, {}

        # For merged processor, we directly use mm_kwargs as mm_data
        if isinstance(seq_mm_data, MultiModalKwargs):
            placeholder_maps = dict[str, MultiModalPlaceholderMap]()

            for modality, placeholders in seq_mm_placeholders.items():
                placeholder_map = MultiModalPlaceholderMap()

                if positions:
                    placeholder_map.append_items_from_seq_group(
                        positions,
                        # Dummy, since we don't care about intersecting items
                        [None] * len(placeholders),
                        placeholders,
                    )

                placeholder_maps[modality] = placeholder_map

            return seq_mm_data, placeholder_maps

        mm_data = {**seq_mm_data}
        placeholder_maps = defaultdict[str, MultiModalPlaceholderMap](
            MultiModalPlaceholderMap)

        for modality, placeholders in seq_mm_placeholders.items():
            # Every modality with placeholders is removed from the returned
            # data and re-added only with the items overlapping `positions`.
            mm_items = mm_data.pop(modality)
            if not isinstance(mm_items, list):
                mm_items = [mm_items]

            if positions:
                intersecting_items = placeholder_maps[modality] \
                    .append_items_from_seq_group(
                        positions,
                        mm_items,
                        placeholders,
                    )

                if intersecting_items:
                    mm_data[modality] = intersecting_items

        return mm_data, placeholder_maps

    def append_items_from_seq_group(
        self,
        positions: range,
        multi_modal_items: list[_T],
        multi_modal_placeholders: Sequence[PlaceholderRange],
    ) -> list[_T]:
        """
        Adds the multi-modal items that intersect ``positions`` to this
        placeholder map and returns the intersecting items.

        ``multi_modal_items[i]`` corresponds to
        ``multi_modal_placeholders[i]``. Its ``"offset"`` and ``"length"``
        give the span of prompt positions that item occupies.

        For each intersecting item, the overlap is recorded in two places:

        - ``dest_ranges``, relative to ``positions.start``.
        - ``src_ranges``, relative to the item's own embeddings and shifted
          by the current ``src_len``.

        ``src_len`` then grows by the item's full placeholder length.
        ``dest_len`` always grows by ``len(positions)``.

        Raises:
            ValueError: If items and placeholders differ in length.
        """
        intersecting_items = []

        if len(multi_modal_items) != len(multi_modal_placeholders):
            raise ValueError(
                "Multi-modal placeholders and items must have the same length."
            )
        for placeholder_dict, mm_item in zip(multi_modal_placeholders,
                                             multi_modal_items):
            placeholder = range(
                placeholder_dict["offset"],
                placeholder_dict["offset"] + placeholder_dict["length"],
            )
            # An empty range (start >= stop) means no overlap.
            intersection = range(
                max(positions.start, placeholder.start),
                min(positions.stop, placeholder.stop),
            )

            if not intersection:
                # Skip this multi-modal item.
                continue

            # Where the overlap lands among the embeddings for `positions`.
            token_embedding_range = range(
                intersection.start - positions.start,
                intersection.stop - positions.start,
            )

            # Which of this item's embeddings cover the overlap, offset past
            # the embeddings of previously appended items.
            multimodal_embedding_range = range(
                intersection.start - placeholder.start + self.src_len,
                intersection.stop - placeholder.start + self.src_len,
            )

            intersecting_items.append(mm_item)
            self.dest_ranges.append(token_embedding_range)
            self.src_ranges.append(multimodal_embedding_range)
            # The whole item is part of the source, even if only partially
            # intersecting.
            self.src_len += len(placeholder)

        self.dest_len += len(positions)
        return intersecting_items

    def extend(self, other: "MultiModalPlaceholderMap"):
        """
        Adds the placeholders from another ``MultiModalPlaceholderMap`` to this
        instance based on the source and destination tensors being
        concatenated.

        ``other``'s ranges are shifted by this map's current ``src_len`` /
        ``dest_len`` before the lengths are summed.
        """

        self.src_ranges.extend(
            range(self.src_len + r.start, self.src_len + r.stop)
            for r in other.src_ranges)
        self.src_len += other.src_len
        self.dest_ranges.extend(
            range(self.dest_len + r.start, self.dest_len + r.stop)
            for r in other.dest_ranges)
        self.dest_len += other.dest_len

    def index_map(self) -> "IndexMap":
        """
        Finalizes the placeholder map into lists of indices that can be used to
        index the source and destination tensors.

        Raises:
            ValueError: If the flattened source and destination index lists
                differ in length.
        """

        src_indices = [i for r in self.src_ranges for i in r]
        dest_indices = [i for r in self.dest_ranges for i in r]

        if len(src_indices) != len(dest_indices):
            raise ValueError(
                f"The number of source ({len(src_indices)}) and destination "
                f"indices ({len(dest_indices)}) must be the same.")

        return MultiModalPlaceholderMap.IndexMap(src=src_indices,
                                                 dest=dest_indices)


class MediaIO(ABC, Generic[_T]):
    """
    Abstract loader producing an object of type ``_T`` from some media.

    Concrete subclasses implement one loading method per input form:
    raw bytes, a base64 string tagged with a media type, or a file path.
    """

    @abstractmethod
    def load_bytes(self, data: bytes) -> _T:
        """Load media from the raw bytes ``data``."""
        raise NotImplementedError

    @abstractmethod
    def load_base64(self, media_type: str, data: str) -> _T:
        """
        Load media from the base64 string ``data``.

        ``media_type`` names the kind of media; for the list of media types see
        https://www.iana.org/assignments/media-types/media-types.xhtml
        """
        raise NotImplementedError

    @abstractmethod
    def load_file(self, filepath: Path) -> _T:
        """Load media from the file located at ``filepath``."""
        raise NotImplementedError