# registry.py 12.8 KB
# Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from collections.abc import Mapping
4
from dataclasses import dataclass
5
from typing import TYPE_CHECKING, Generic, Optional, Protocol, TypeVar
6

7
8
import torch.nn as nn

9
from vllm.config.multimodal import BaseDummyOptions
10
from vllm.logger import init_logger
11
from vllm.transformers_utils.tokenizer import AnyTokenizer, cached_tokenizer_from_config
12
from vllm.utils import ClassRegistry
13

14
from .cache import BaseMultiModalProcessorCache
15
16
17
18
19
20
21
22
23
24
25
from .processing import (
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    InputProcessingContext,
)
from .profiling import (
    BaseDummyInputsBuilder,
    DummyDecoderData,
    DummyEncoderData,
    MultiModalProfiler,
)
26

27
28
29
if TYPE_CHECKING:
    from vllm.config import ModelConfig

30
31
# Module-level logger for the multi-modal registry.
logger = init_logger(__name__)

# Processing-info type variables: `_I` is invariant and is used by the
# factory protocols that both accept and produce `_I`-typed objects;
# `_I_co` is covariant because `ProcessingInfoFactory` only produces it.
_I = TypeVar("_I", bound=BaseProcessingInfo)
_I_co = TypeVar("_I_co", covariant=True, bound=BaseProcessingInfo)

# Model-class type variable for the `register_processor` decorator, which
# hands back the very class it was given.
N = TypeVar("N", bound=type[nn.Module])
35
36


37
class ProcessingInfoFactory(Protocol[_I_co]):
    """
    Constructs a
    [`BaseProcessingInfo`][vllm.multimodal.processing.BaseProcessingInfo]
    instance from the context.
    """

    def __call__(
        self,
        ctx: InputProcessingContext,
    ) -> _I_co: ...
48
49


50
class DummyInputsBuilderFactory(Protocol[_I]):  # type: ignore[misc]
    """
    Constructs a
    [`BaseDummyInputsBuilder`][vllm.multimodal.profiling.BaseDummyInputsBuilder]
    instance from the processing info.
    """

    def __call__(self, info: _I) -> BaseDummyInputsBuilder[_I]: ...
58
59


60
class MultiModalProcessorFactory(Protocol[_I]):  # type: ignore[misc]
    """
    Constructs a
    [`BaseMultiModalProcessor`][vllm.multimodal.processing.BaseMultiModalProcessor]
    instance from the processing info and the dummy inputs builder,
    optionally attaching a processor cache.
    """

    def __call__(
        self,
        info: _I,
        dummy_inputs: BaseDummyInputsBuilder[_I],
        *,
        cache: Optional[BaseMultiModalProcessorCache] = None,
    ) -> BaseMultiModalProcessor[_I]: ...
74

75

76
77
78
79
80
81
82
83
84
85
@dataclass(frozen=True)
class _ProcessorFactories(Generic[_I]):
    """
    The three factories registered for a model class, which together build
    its multi-modal processor.
    """

    # Builds the processing info from an `InputProcessingContext`.
    info: ProcessingInfoFactory[_I]
    # Builds the processor from the info and the dummy inputs builder.
    processor: MultiModalProcessorFactory[_I]
    # Builds the dummy inputs builder from the info.
    dummy_inputs: DummyInputsBuilderFactory[_I]

    def build_processor(
        self,
        ctx: InputProcessingContext,
        *,
        cache: Optional[BaseMultiModalProcessorCache] = None,
    ) -> BaseMultiModalProcessor[_I]:
        """
        Run the factories in dependency order and return the processor.

        Args:
            ctx: Processing context passed to the info factory.
            cache: Optional processor cache forwarded to the processor factory.

        Returns:
            The processor produced by `self.processor`.
        """
        # The info is built first because both other factories consume it.
        info = self.info(ctx)
        dummy_inputs_builder = self.dummy_inputs(info)
        return self.processor(info, dummy_inputs_builder, cache=cache)


93
94
class MultiModalRegistry:
    """
    A registry that dispatches data processing according to the model.
    """

    def __init__(self) -> None:
        # Model class -> factories used to lazily build its multi-modal
        # processor (populated via `register_processor`).
        self._processor_factories = ClassRegistry[nn.Module, _ProcessorFactories]()

    def _extract_mm_options(
        self,
        model_config: "ModelConfig",
    ) -> Optional[Mapping[str, BaseDummyOptions]]:
        """
        Extract multimodal dummy options from model config.

        Returns None if no configurable options are found, otherwise returns
        a mapping of modality names to their dummy options.
        """
        mm_config = model_config.multimodal_config
        if not mm_config:
            return None

        mm_options = {
            m: opt
            for m in mm_config.limit_per_prompt
            if (opt := mm_config.get_dummy_options(m)) is not None
        }

        # An empty mapping is reported as None.
        return mm_options or None

    def supports_multimodal_inputs(self, model_config: "ModelConfig") -> bool:
        """
        Checks if the model supports multimodal inputs.

        Returns True if the model is multimodal with any non-zero supported
        modalities, otherwise returns False, effectively running in
        text-only mode.
        """
        if not model_config.is_multimodal_model:
            return False

        info = self._create_processing_info(model_config, tokenizer=None)
        supported_modalities = info.get_supported_mm_limits()

        mm_config = model_config.get_multimodal_config()

        # Check if all supported modalities have limit == 0
        if all(
            mm_config.get_limit_per_prompt(modality) == 0
            for modality in supported_modalities
        ):
            logger.info_once(
                "All limits of multimodal modalities supported by the model "
                "are set to 0, running in text-only mode."
            )
            return False

        return True

    def get_max_tokens_per_item_by_modality(
        self,
        model_config: "ModelConfig",
        *,
        cache: Optional[BaseMultiModalProcessorCache] = None,
    ) -> Mapping[str, int]:
        """
        Get the maximum number of tokens per data item from each modality based
        on underlying model configuration.

        Only modalities whose per-prompt limit is greater than zero are
        profiled. Returns an empty mapping for non-multimodal models.
        """
        if not model_config.is_multimodal_model:
            return {}

        processor = self.create_processor(model_config, cache=cache)
        profiler: MultiModalProfiler = MultiModalProfiler(processor)

        seq_len = model_config.max_model_len
        # Read the limits from the profiler built above rather than calling
        # `get_mm_limits_per_prompt`, which would construct a second,
        # identical processor.
        mm_limits = profiler.get_mm_limits()

        return profiler.get_mm_max_contiguous_tokens(
            seq_len,
            {modality: 1 for modality, limit in mm_limits.items() if limit > 0},
        )

    def get_max_tokens_per_item_by_nonzero_modality(
        self,
        model_config: "ModelConfig",
        *,
        cache: Optional[BaseMultiModalProcessorCache] = None,
    ) -> Mapping[str, int]:
        """
        Get the maximum number of tokens per data item from each modality based
        on underlying model configuration, excluding modalities that user
        explicitly disabled via `limit_mm_per_prompt`.

        Note:
            This is currently directly used only in V1 for profiling the memory
            usage of a model.
        """
        mm_limits = self.get_mm_limits_per_prompt(model_config, cache=cache)
        max_tokens_per_item = self.get_max_tokens_per_item_by_modality(
            model_config,
            cache=cache,
        )

        # Modalities absent from `mm_limits` are treated as disabled instead
        # of raising KeyError.
        return {
            key: max_tokens_per_mm_item
            for key, max_tokens_per_mm_item in max_tokens_per_item.items()
            if mm_limits.get(key, 0) > 0
        }

    def get_mm_limits_per_prompt(
        self,
        model_config: "ModelConfig",
        *,
        cache: Optional[BaseMultiModalProcessorCache] = None,
    ) -> Mapping[str, int]:
        """
        Get the maximum number of multi-modal input instances for each modality
        that are allowed per prompt for a model class.

        Returns an empty mapping for non-multimodal models.
        """
        if not model_config.is_multimodal_model:
            return {}

        processor = self.create_processor(model_config, cache=cache)
        profiler: MultiModalProfiler = MultiModalProfiler(processor)
        return profiler.get_mm_limits()

    def register_processor(
        self,
        processor: MultiModalProcessorFactory[_I],
        *,
        info: ProcessingInfoFactory[_I],
        dummy_inputs: DummyInputsBuilderFactory[_I],
    ):
        """
        Register a multi-modal processor to a model class. The processor
        is constructed lazily, hence a factory method should be passed.

        When the model receives multi-modal data, the provided function is
        invoked to transform the data into a dictionary of model inputs.

        Args:
            processor: Factory that builds the processor from the processing
                info and the dummy inputs builder.
            info: Factory that builds the processing info from the context.
            dummy_inputs: Factory that builds the dummy inputs builder from
                the processing info.

        Returns:
            A class decorator that records the factories for the decorated
            model class and returns that class unchanged.
        """

        def wrapper(model_cls: N) -> N:
            # `strict=True`: only warn when this exact class (not a base
            # class) already has factories registered.
            if self._processor_factories.contains(model_cls, strict=True):
                logger.warning(
                    "Model class %s already has a multi-modal processor "
                    "registered to %s. It is overwritten by the new one.",
                    model_cls,
                    self,
                )

            self._processor_factories[model_cls] = _ProcessorFactories(
                info=info,
                dummy_inputs=dummy_inputs,
                processor=processor,
            )

            return model_cls

        return wrapper

    def _get_model_cls(self, model_config: "ModelConfig"):
        """Resolve the model class for `model_config`."""
        # Avoid circular import
        from vllm.model_executor.model_loader import get_model_architecture

        model_cls, _ = get_model_architecture(model_config)
        return model_cls

    def _create_processing_ctx(
        self,
        model_config: "ModelConfig",
        tokenizer: Optional[AnyTokenizer] = None,
    ) -> InputProcessingContext:
        """
        Build an `InputProcessingContext`.

        If no tokenizer is given, one is loaded through
        `cached_tokenizer_from_config` unless `skip_tokenizer_init` is set.
        """
        if tokenizer is None and not model_config.skip_tokenizer_init:
            tokenizer = cached_tokenizer_from_config(model_config)
        return InputProcessingContext(model_config, tokenizer)

    def _create_processing_info(
        self,
        model_config: "ModelConfig",
        *,
        tokenizer: Optional[AnyTokenizer] = None,
    ) -> BaseProcessingInfo:
        """Build only the processing info registered for the model's class."""
        model_cls = self._get_model_cls(model_config)
        factories = self._processor_factories[model_cls]
        ctx = self._create_processing_ctx(model_config, tokenizer)
        return factories.info(ctx)

    def create_processor(
        self,
        model_config: "ModelConfig",
        *,
        tokenizer: Optional[AnyTokenizer] = None,
        cache: Optional[BaseMultiModalProcessorCache] = None,
    ) -> BaseMultiModalProcessor[BaseProcessingInfo]:
        """
        Create a multi-modal processor for a specific model and tokenizer.

        Raises:
            ValueError: If the model is not a multimodal model.
        """
        if not model_config.is_multimodal_model:
            raise ValueError(f"{model_config.model} is not a multimodal model")

        model_cls = self._get_model_cls(model_config)
        factories = self._processor_factories[model_cls]

        ctx = self._create_processing_ctx(model_config, tokenizer)

        return factories.build_processor(ctx, cache=cache)

    def get_decoder_dummy_data(
        self,
        model_config: "ModelConfig",
        seq_len: int,
        mm_counts: Optional[Mapping[str, int]] = None,
        *,
        cache: Optional[BaseMultiModalProcessorCache] = None,
    ) -> DummyDecoderData:
        """
        Create dummy data for profiling the memory usage of a model.

        The model is identified by ``model_config``.

        Raises:
            AssertionError: If fewer than ``seq_len`` dummy tokens are produced.
        """
        processor = self.create_processor(model_config, cache=cache)
        profiler: MultiModalProfiler = MultiModalProfiler(processor)

        # Extract configurable options from multimodal config.
        # Only include modalities that use advanced option types so legacy
        # count-only behavior remains unchanged.
        mm_options = self._extract_mm_options(model_config)

        dummy_data = profiler.get_decoder_dummy_data(seq_len, mm_counts, mm_options)

        # Having more tokens is over-conservative but otherwise fine
        token_ids = dummy_data.prompt_token_ids
        if len(token_ids) < seq_len:
            raise AssertionError(
                f"Expected at least {seq_len} dummy tokens for profiling, "
                f"but found {len(token_ids)} tokens instead."
            )

        return dummy_data

    def get_encoder_dummy_data(
        self,
        model_config: "ModelConfig",
        seq_len: int,
        mm_counts: Optional[Mapping[str, int]] = None,
        *,
        cache: Optional[BaseMultiModalProcessorCache] = None,
    ) -> DummyEncoderData:
        """
        Create dummy data for profiling the memory usage of a model.

        The model is identified by ``model_config``. Unlike the decoder
        variant, a shortfall of dummy tokens only logs a warning.
        """
        processor = self.create_processor(model_config, cache=cache)
        profiler: MultiModalProfiler = MultiModalProfiler(processor)

        # Extract configurable options from multimodal config.
        # Only include modalities that use advanced option types so legacy
        # count-only behavior remains unchanged.
        mm_options = self._extract_mm_options(model_config)

        dummy_data = profiler.get_encoder_dummy_data(seq_len, mm_counts, mm_options)

        # Having more tokens is over-conservative but otherwise fine
        token_ids = dummy_data.prompt_token_ids
        if len(token_ids) < seq_len:
            logger.warning_once(
                "Expected at least %d dummy encoder tokens for profiling, but found %d tokens instead.",  # noqa: E501
                seq_len,
                len(token_ids),
            )

        return dummy_data

    def get_encdec_max_encoder_len(self, model_config: "ModelConfig") -> int:
        """
        Get the maximum length of the encoder input for encoder-decoder models.

        Returns 0 for models that are not encoder-decoder, and for
        encoder-decoder models with no enabled multi-modal modality.

        Raises:
            AssertionError: If more than one modality is enabled.
        """
        if not model_config.is_encoder_decoder:
            return 0
        max_tokens = self.get_max_tokens_per_item_by_nonzero_modality(model_config)
        if not max_tokens:
            # TODO - this function assumes encoder-decoder models are
            # multimodal. This will need to change when adding support for more
            # than whisper.
            return 0

        # Checked explicitly (not via `assert`) so the check also runs under
        # `python -O`, matching `get_decoder_dummy_data`.
        if len(max_tokens) != 1:
            raise AssertionError(
                "Encoder-decoder models are expected to implement the "
                "multimodal interface with at most one modality."
            )

        # Exactly one modality remains; return its per-item token budget.
        return next(iter(max_tokens.values()))