# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable, Mapping
from dataclasses import dataclass
from typing import TYPE_CHECKING, Generic, Optional, Protocol, TypeVar

import torch.nn as nn

from vllm.config.multimodal import BaseDummyOptions
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import (AnyTokenizer,
                                               cached_tokenizer_from_config)
from vllm.utils import ClassRegistry

from .cache import BaseMultiModalProcessorCache
from .processing import (BaseMultiModalProcessor, BaseProcessingInfo,
                         InputProcessingContext)
from .profiling import (BaseDummyInputsBuilder, DummyDecoderData,
                        DummyEncoderData, MultiModalProfiler)

if TYPE_CHECKING:
    from vllm.config import ModelConfig

logger = init_logger(__name__)

# Type of a model class; lets a class decorator hand back exactly the class
# it was given.
N = TypeVar("N", bound=type[nn.Module])

# Processing-info type that ties together the factories of one model.
_I = TypeVar("_I", bound=BaseProcessingInfo)
# Covariant counterpart of `_I`, for protocols that only return the info.
_I_co = TypeVar("_I_co", bound=BaseProcessingInfo, covariant=True)


class ProcessingInfoFactory(Protocol[_I_co]):
    """
    Constructs a
    [`BaseProcessingInfo`][vllm.multimodal.processing.BaseProcessingInfo]
    instance from the context.
    """

    def __call__(
        self,
        ctx: InputProcessingContext,
    ) -> _I_co:
        ...


class DummyInputsBuilderFactory(Protocol[_I]):  # type: ignore[misc]
    """
    Constructs a
    [`BaseDummyInputsBuilder`][vllm.multimodal.profiling.BaseDummyInputsBuilder]
    instance from the processing info.
    """

    def __call__(self, info: _I) -> BaseDummyInputsBuilder[_I]:
        ...


class MultiModalProcessorFactory(Protocol[_I]):  # type: ignore[misc]
    """
    Constructs a
    [`BaseMultiModalProcessor`][vllm.multimodal.processing.BaseMultiModalProcessor]
    instance from the processing info and its dummy inputs builder,
    optionally using the given processor cache.
    """

    def __call__(
        self,
        info: _I,
        dummy_inputs: BaseDummyInputsBuilder[_I],
        *,
        cache: Optional[BaseMultiModalProcessorCache] = None,
    ) -> BaseMultiModalProcessor[_I]:
        ...

@dataclass(frozen=True)
class _ProcessorFactories(Generic[_I]):
    """The set of factories registered for a single model class."""

    # Builds the processing info from an `InputProcessingContext`.
    info: ProcessingInfoFactory[_I]
    # Builds the processor from the info and the dummy inputs builder.
    processor: MultiModalProcessorFactory[_I]
    # Builds the dummy inputs builder from the info.
    dummy_inputs: DummyInputsBuilderFactory[_I]

    def build_processor(
        self,
        ctx: InputProcessingContext,
        *,
        cache: Optional[BaseMultiModalProcessorCache] = None,
    ) -> BaseMultiModalProcessor[_I]:
        """
        Chain the three factories to build a processor for `ctx`.

        The info is created first because both the dummy inputs builder and
        the processor are constructed from it; `cache` is forwarded to the
        processor factory unchanged.
        """
        info = self.info(ctx)
        dummy_inputs_builder = self.dummy_inputs(info)
        return self.processor(info, dummy_inputs_builder, cache=cache)


class MultiModalRegistry:
    """
    A registry that dispatches data processing according to the model.
    """

    def __init__(self) -> None:
        # Maps a model class to the factories registered for it through
        # `register_processor`.
        self._processor_factories = ClassRegistry[nn.Module,
                                                  _ProcessorFactories]()

    def _extract_mm_options(
        self,
        model_config: "ModelConfig",
    ) -> Optional[Mapping[str, BaseDummyOptions]]:
        """
        Extract multimodal dummy options from model config.

        Returns None if no configurable options are found, otherwise returns
        a mapping of modality names to their dummy options.
        """
        mm_config = model_config.multimodal_config
        if not mm_config:
            return None

        # Keep only the modalities for which dummy options are available.
        mm_options = {
            modality: options
            for modality in mm_config.limit_per_prompt
            if (options := mm_config.get_dummy_options(modality)) is not None
        }

        return mm_options or None

    def supports_multimodal_inputs(self, model_config: "ModelConfig") -> bool:
        """
        Checks if the model supports multimodal inputs.
        Returns True if the model is multimodal with any non-zero supported
        modalities, otherwise returns False, effectively running in
        text-only mode.
        """
        if not model_config.is_multimodal_model:
            return False

        info = self._create_processing_info(model_config, tokenizer=None)
        supported_modalities = info.get_supported_mm_limits()

        mm_config = model_config.get_multimodal_config()

        # Check if all supported modalities have limit == 0
        if all(
                mm_config.get_limit_per_prompt(modality) == 0
                for modality in supported_modalities):
            logger.info_once(
                "All limits of multimodal modalities supported by the model "
                "are set to 0, running in text-only mode.")
            return False

        return True

    def get_max_tokens_per_item_by_modality(
        self,
        model_config: "ModelConfig",
        *,
        cache: Optional[BaseMultiModalProcessorCache] = None,
    ) -> Mapping[str, int]:
        """
        Get the maximum number of tokens per data item from each modality based
        on underlying model configuration.

        Returns an empty mapping for models that are not multimodal.
        """
        if not model_config.is_multimodal_model:
            return {}

        processor = self.create_processor(model_config, cache=cache)
        profiler: MultiModalProfiler = MultiModalProfiler(processor)

        # Ask this profiler for the per-prompt limits directly rather than
        # going through `get_mm_limits_per_prompt`, which would build a second
        # processor and profiler only to make the same call.
        mm_limits = profiler.get_mm_limits()

        # Profile exactly one item of every modality that is not disabled.
        return profiler.get_mm_max_contiguous_tokens(
            model_config.max_model_len,
            {
                modality: 1
                for modality, limit in mm_limits.items() if limit > 0
            },
        )

    def get_max_tokens_per_item_by_nonzero_modality(
        self,
        model_config: "ModelConfig",
        *,
        cache: Optional[BaseMultiModalProcessorCache] = None,
    ) -> Mapping[str, int]:
        """
        Get the maximum number of tokens per data item from each modality based
        on underlying model configuration, excluding modalities that user
        explicitly disabled via `limit_mm_per_prompt`.

        Note:
            This is currently directly used only in V1 for profiling the memory
            usage of a model.
        """
        mm_limits = self.get_mm_limits_per_prompt(model_config, cache=cache)
        max_tokens_per_item = self.get_max_tokens_per_item_by_modality(
            model_config,
            cache=cache,
        )

        return {
            key: max_tokens_per_mm_item
            for key, max_tokens_per_mm_item in max_tokens_per_item.items()
            if mm_limits[key] > 0
        }

    def get_mm_limits_per_prompt(
        self,
        model_config: "ModelConfig",
        *,
        cache: Optional[BaseMultiModalProcessorCache] = None,
    ) -> Mapping[str, int]:
        """
        Get the maximum number of multi-modal input instances for each modality
        that are allowed per prompt for a model class.

        Returns an empty mapping for models that are not multimodal.
        """
        if not model_config.is_multimodal_model:
            return {}

        processor = self.create_processor(model_config, cache=cache)
        profiler: MultiModalProfiler = MultiModalProfiler(processor)
        return profiler.get_mm_limits()

    def register_processor(
        self,
        processor: MultiModalProcessorFactory[_I],
        *,
        info: ProcessingInfoFactory[_I],
        dummy_inputs: DummyInputsBuilderFactory[_I],
    ) -> Callable[[N], N]:
        """
        Register a multi-modal processor to a model class. The processor
        is constructed lazily, hence a factory method should be passed.

        When the model receives multi-modal data, the provided function is
        invoked to transform the data into a dictionary of model inputs.

        Returns a class decorator that stores the three factories for the
        decorated model class and returns that class unchanged. Registering
        a class that already has an entry logs a warning and replaces it.
        """

        def wrapper(model_cls: N) -> N:
            if self._processor_factories.contains(model_cls, strict=True):
                logger.warning(
                    "Model class %s already has a multi-modal processor "
                    "registered to %s. It is overwritten by the new one.",
                    model_cls, self)

            self._processor_factories[model_cls] = _ProcessorFactories(
                info=info,
                dummy_inputs=dummy_inputs,
                processor=processor,
            )

            return model_cls

        return wrapper

    def _get_model_cls(self, model_config: "ModelConfig"):
        """Resolve the model class that `model_config` refers to."""
        # Avoid circular import
        from vllm.model_executor.model_loader import get_model_architecture

        model_cls, _ = get_model_architecture(model_config)
        return model_cls

    def _create_processing_ctx(
        self,
        model_config: "ModelConfig",
        tokenizer: Optional[AnyTokenizer] = None,
    ) -> InputProcessingContext:
        """
        Build the processing context for `model_config`.

        When no tokenizer is passed and the config does not skip tokenizer
        initialization, the cached tokenizer for the config is used.
        """
        if tokenizer is None and not model_config.skip_tokenizer_init:
            tokenizer = cached_tokenizer_from_config(model_config)
        return InputProcessingContext(model_config, tokenizer)

    def _create_processing_info(
        self,
        model_config: "ModelConfig",
        *,
        tokenizer: Optional[AnyTokenizer] = None,
    ) -> BaseProcessingInfo:
        """
        Build only the processing info registered for the model, without
        constructing the dummy inputs builder or the processor.
        """
        model_cls = self._get_model_cls(model_config)
        factories = self._processor_factories[model_cls]
        ctx = self._create_processing_ctx(model_config, tokenizer)
        return factories.info(ctx)

    def create_processor(
        self,
        model_config: "ModelConfig",
        *,
        tokenizer: Optional[AnyTokenizer] = None,
        cache: Optional[BaseMultiModalProcessorCache] = None,
    ) -> BaseMultiModalProcessor[BaseProcessingInfo]:
        """
        Create a multi-modal processor for a specific model and tokenizer.

        Raises:
            ValueError: If the model is not a multimodal model.
        """
        if not model_config.is_multimodal_model:
            raise ValueError(f"{model_config.model} is not a multimodal model")

        model_cls = self._get_model_cls(model_config)
        factories = self._processor_factories[model_cls]

        ctx = self._create_processing_ctx(model_config, tokenizer)

        return factories.build_processor(ctx, cache=cache)

    def get_decoder_dummy_data(
        self,
        model_config: "ModelConfig",
        seq_len: int,
        mm_counts: Optional[Mapping[str, int]] = None,
        *,
        cache: Optional[BaseMultiModalProcessorCache] = None,
    ) -> DummyDecoderData:
        """
        Create dummy data for profiling the memory usage of a model.

        The model is identified by ``model_config``.

        Raises:
            AssertionError: If fewer than ``seq_len`` dummy tokens were
                produced.
        """
        processor = self.create_processor(model_config, cache=cache)
        profiler: MultiModalProfiler = MultiModalProfiler(processor)

        # Extract configurable options from multimodal config.
        # Only include modalities that use advanced option types so legacy
        # count-only behavior remains unchanged.
        mm_options = self._extract_mm_options(model_config)

        dummy_data = profiler.get_decoder_dummy_data(seq_len, mm_counts,
                                                     mm_options)

        # Having more tokens is over-conservative but otherwise fine
        token_ids = dummy_data.prompt_token_ids
        if len(token_ids) < seq_len:
            raise AssertionError(
                f"Expected at least {seq_len} dummy tokens for profiling, "
                f"but found {len(token_ids)} tokens instead.")

        return dummy_data

    def get_encoder_dummy_data(
        self,
        model_config: "ModelConfig",
        seq_len: int,
        mm_counts: Optional[Mapping[str, int]] = None,
        *,
        cache: Optional[BaseMultiModalProcessorCache] = None,
    ) -> DummyEncoderData:
        """
        Create dummy data for profiling the memory usage of a model.

        The model is identified by ``model_config``. Unlike the decoder
        variant, producing fewer than ``seq_len`` tokens only logs a warning.
        """
        processor = self.create_processor(model_config, cache=cache)
        profiler: MultiModalProfiler = MultiModalProfiler(processor)

        # Extract configurable options from multimodal config.
        # Only include modalities that use advanced option types so legacy
        # count-only behavior remains unchanged.
        mm_options = self._extract_mm_options(model_config)

        dummy_data = profiler.get_encoder_dummy_data(seq_len, mm_counts,
                                                     mm_options)

        # Having more tokens is over-conservative but otherwise fine
        token_ids = dummy_data.prompt_token_ids
        if len(token_ids) < seq_len:
            logger.warning_once(
                "Expected at least %d dummy encoder tokens for profiling, but found %d tokens instead.",  # noqa: E501
                seq_len,
                len(token_ids),
            )

        return dummy_data

    def get_encdec_max_encoder_len(self, model_config: "ModelConfig") -> int:
        """
        Get the maximum length of the encoder input for encoder-decoder models.

        Returns 0 when the model is not an encoder-decoder model or when no
        modality is enabled.

        Raises:
            AssertionError: If more than one modality is enabled.
        """
        if not model_config.is_encoder_decoder:
            return 0

        max_tokens = self.get_max_tokens_per_item_by_nonzero_modality(
            model_config)
        if not max_tokens:
            # TODO - this function assumes encoder-decoder models are
            # multimodal. This will need to change when adding support for more
            # than whisper.
            return 0

        # Raised explicitly (as in `get_decoder_dummy_data`) so that the check
        # is not stripped when Python runs with `-O`.
        if len(max_tokens) != 1:
            raise AssertionError(
                "Encoder-decoder models are expected to implement the "
                "multimodal interface with at most one modality.")

        # Exactly one entry is left, so its value is the encoder length.
        return next(iter(max_tokens.values()))