# registry.py 11.6 KB
# Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from collections.abc import Mapping
4
from dataclasses import dataclass
5
from typing import TYPE_CHECKING, Generic, Optional, Protocol, TypeVar
6

7
8
import torch.nn as nn

9
from vllm.logger import init_logger
10
11
from vllm.transformers_utils.tokenizer import (AnyTokenizer,
                                               cached_tokenizer_from_config)
12
from vllm.utils import ClassRegistry
13

14
from .cache import BaseMultiModalProcessorCache
15
16
from .processing import (BaseMultiModalProcessor, BaseProcessingInfo,
                         InputProcessingContext)
17
18
from .profiling import (BaseDummyInputsBuilder, DummyDecoderData,
                        DummyEncoderData, MultiModalProfiler)
19

20
21
22
if TYPE_CHECKING:
    from vllm.config import ModelConfig

23
24
logger = init_logger(__name__)

# Type variable for the model classes decorated by
# `MultiModalRegistry.register_processor`; the decorator returns the class
# it was given, so the bound is the class object rather than an instance.
N = TypeVar("N", bound=type[nn.Module])

# Type variables for processing-info classes. The covariant variant is for
# protocols that only *return* the info object.
_I = TypeVar("_I", bound=BaseProcessingInfo)
_I_co = TypeVar("_I_co", bound=BaseProcessingInfo, covariant=True)
28
29


30
class ProcessingInfoFactory(Protocol[_I_co]):
    """
    Constructs a
    [`BaseProcessingInfo`][vllm.multimodal.processing.BaseProcessingInfo]
    instance from the context.
    """

    def __call__(
        self,
        ctx: InputProcessingContext,
    ) -> _I_co:
        ...


44
class DummyInputsBuilderFactory(Protocol[_I]):  # type: ignore[misc]
    """
    Constructs a
    [`BaseDummyInputsBuilder`][vllm.multimodal.profiling.BaseDummyInputsBuilder]
    instance from the processing info.
    """

    def __call__(self, info: _I) -> BaseDummyInputsBuilder[_I]:
        ...


class MultiModalProcessorFactory(Protocol[_I]):
    """
    Constructs a
    [`BaseMultiModalProcessor`][vllm.multimodal.processing.BaseMultiModalProcessor]
    instance from the processing info and the dummy inputs builder,
    optionally attaching a processor cache.
    """

    def __call__(
        self,
        info: _I,
        dummy_inputs: BaseDummyInputsBuilder[_I],
        *,
        cache: Optional[BaseMultiModalProcessorCache] = None,
    ) -> BaseMultiModalProcessor[_I]:
        ...
70

71

72
73
74
75
76
77
78
79
80
81
@dataclass(frozen=True)
class _ProcessorFactories(Generic[_I]):
    """
    The three factories registered for one model class, bundled so that a
    processor can be assembled from an `InputProcessingContext` in one call.
    """

    # Builds the processing info from the context.
    info: ProcessingInfoFactory[_I]
    # Builds the processor from the info and the dummy inputs builder.
    processor: MultiModalProcessorFactory[_I]
    # Builds the dummy inputs builder from the info.
    dummy_inputs: DummyInputsBuilderFactory[_I]

    def build_processor(
        self,
        ctx: InputProcessingContext,
        *,
        cache: Optional[BaseMultiModalProcessorCache] = None,
    ) -> BaseMultiModalProcessor[_I]:
        """
        Chain the factories: context -> info -> dummy inputs builder ->
        processor. `cache` is forwarded unchanged to the processor factory.
        """
        info = self.info(ctx)
        dummy_inputs_builder = self.dummy_inputs(info)
        return self.processor(info, dummy_inputs_builder, cache=cache)


89
90
class MultiModalRegistry:
    """
    A registry that dispatches data processing according to the model.
    """

    def __init__(self) -> None:
        # Maps a model class to the factories registered for it via
        # `register_processor`.
        self._processor_factories = ClassRegistry[nn.Module,
                                                  _ProcessorFactories]()

    def supports_multimodal_inputs(self, model_config: "ModelConfig") -> bool:
        """
        Checks if the model supports multimodal inputs.
        Returns True if the model is multimodal with any non-zero supported
        modalities, otherwise returns False, effectively running in
        text-only mode.
        """
        if not model_config.is_multimodal_model:
            return False

        info = self._create_processing_info(model_config, tokenizer=None)
        supported_modalities = info.get_supported_mm_limits()

        mm_config = model_config.get_multimodal_config()

        # Check if all supported modalities have limit == 0.
        # NOTE: `all()` over an empty iterable is True, so a model that
        # reports no supported modalities is also treated as text-only.
        if all(
                mm_config.get_limit_per_prompt(modality) == 0
                for modality in supported_modalities):
            logger.info_once(
                "All limits of multimodal modalities supported by the model "
                "are set to 0, running in text-only mode.")
            return False

        return True

    def get_max_tokens_per_item_by_modality(
        self,
        model_config: "ModelConfig",
        *,
        cache: Optional[BaseMultiModalProcessorCache] = None,
    ) -> Mapping[str, int]:
        """
        Get the maximum number of tokens per data item from each modality based
        on underlying model configuration.

        Returns an empty mapping for non-multimodal models. Modalities whose
        per-prompt limit is 0 are not profiled.
        """
        if not model_config.is_multimodal_model:
            return {}

        processor = self.create_processor(model_config, cache=cache)
        profiler = MultiModalProfiler(processor)

        seq_len = model_config.max_model_len
        # Reuse the profiler built above instead of going through
        # `get_mm_limits_per_prompt`, which would construct a second
        # processor/profiler pair only to make this same call.
        mm_limits = profiler.get_mm_limits()

        # Profile a single item per enabled modality.
        return profiler.get_mm_max_contiguous_tokens(
            seq_len,
            {
                modality: 1
                for modality, limit in mm_limits.items() if limit > 0
            },
        )

    def get_max_tokens_per_item_by_nonzero_modality(
        self,
        model_config: "ModelConfig",
        *,
        cache: Optional[BaseMultiModalProcessorCache] = None,
    ) -> Mapping[str, int]:
        """
        Get the maximum number of tokens per data item from each modality based
        on underlying model configuration, excluding modalities that user
        explicitly disabled via `limit_mm_per_prompt`.

        Note:
            This is currently directly used only in V1 for profiling the memory
            usage of a model.
        """
        mm_limits = self.get_mm_limits_per_prompt(model_config, cache=cache)
        max_tokens_per_item = self.get_max_tokens_per_item_by_modality(
            model_config,
            cache=cache,
        )

        return {
            key: max_tokens_per_mm_item
            for key, max_tokens_per_mm_item in max_tokens_per_item.items()
            if mm_limits[key] > 0
        }

    def get_mm_limits_per_prompt(
        self,
        model_config: "ModelConfig",
        *,
        cache: Optional[BaseMultiModalProcessorCache] = None,
    ) -> Mapping[str, int]:
        """
        Get the maximum number of multi-modal input instances for each modality
        that are allowed per prompt for a model class.

        Returns an empty mapping for non-multimodal models.
        """
        if not model_config.is_multimodal_model:
            return {}

        processor = self.create_processor(model_config, cache=cache)
        profiler = MultiModalProfiler(processor)
        return profiler.get_mm_limits()

    def register_processor(
        self,
        processor: MultiModalProcessorFactory[_I],
        *,
        info: ProcessingInfoFactory[_I],
        dummy_inputs: DummyInputsBuilderFactory[_I],
    ):
        """
        Register a multi-modal processor to a model class. The processor
        is constructed lazily, hence a factory method should be passed.

        When the model receives multi-modal data, the provided function is
        invoked to transform the data into a dictionary of model inputs.

        Returns a class decorator: it records the factories for the
        decorated model class and returns that class unchanged. Registering
        a class that already has an entry logs a warning and overwrites it.
        """

        def wrapper(model_cls: N) -> N:
            if self._processor_factories.contains(model_cls, strict=True):
                logger.warning(
                    "Model class %s already has a multi-modal processor "
                    "registered to %s. It is overwritten by the new one.",
                    model_cls, self)

            self._processor_factories[model_cls] = _ProcessorFactories(
                info=info,
                dummy_inputs=dummy_inputs,
                processor=processor,
            )

            return model_cls

        return wrapper

    def _get_model_cls(self, model_config: "ModelConfig"):
        """Resolve the model class for `model_config`."""
        # Avoid circular import
        from vllm.model_executor.model_loader import get_model_architecture

        model_cls, _ = get_model_architecture(model_config)
        return model_cls

    def _create_processing_ctx(
        self,
        model_config: "ModelConfig",
        tokenizer: Optional[AnyTokenizer] = None,
    ) -> InputProcessingContext:
        """
        Build the processing context for `model_config`.

        If no tokenizer is given and `model_config.skip_tokenizer_init` is
        false, the tokenizer is obtained from
        `cached_tokenizer_from_config`; otherwise the context is created
        with the tokenizer as passed (possibly `None`).
        """
        if tokenizer is None and not model_config.skip_tokenizer_init:
            tokenizer = cached_tokenizer_from_config(model_config)
        return InputProcessingContext(model_config, tokenizer)

    def _create_processing_info(
        self,
        model_config: "ModelConfig",
        *,
        tokenizer: Optional[AnyTokenizer] = None,
    ) -> BaseProcessingInfo:
        """
        Build only the processing info for the model, without constructing
        the dummy inputs builder or the processor.
        """
        model_cls = self._get_model_cls(model_config)
        factories = self._processor_factories[model_cls]
        ctx = self._create_processing_ctx(model_config, tokenizer)
        return factories.info(ctx)

    def create_processor(
        self,
        model_config: "ModelConfig",
        *,
        tokenizer: Optional[AnyTokenizer] = None,
        cache: Optional[BaseMultiModalProcessorCache] = None,
    ) -> BaseMultiModalProcessor[BaseProcessingInfo]:
        """
        Create a multi-modal processor for a specific model and tokenizer.

        Raises:
            ValueError: If `model_config` does not describe a multimodal
                model.
        """
        if not model_config.is_multimodal_model:
            raise ValueError(f"{model_config.model} is not a multimodal model")

        model_cls = self._get_model_cls(model_config)
        factories = self._processor_factories[model_cls]

        ctx = self._create_processing_ctx(model_config, tokenizer)

        return factories.build_processor(ctx, cache=cache)

    def get_decoder_dummy_data(
        self,
        model_config: "ModelConfig",
        seq_len: int,
        mm_counts: Optional[Mapping[str, int]] = None,
        *,
        cache: Optional[BaseMultiModalProcessorCache] = None,
    ) -> DummyDecoderData:
        """
        Create dummy data for profiling the memory usage of a model.

        The model is identified by ``model_config``.

        Raises:
            AssertionError: If fewer than `seq_len` dummy prompt tokens were
                produced.
        """
        processor = self.create_processor(model_config, cache=cache)
        profiler = MultiModalProfiler(processor)
        dummy_data = profiler.get_decoder_dummy_data(seq_len, mm_counts)

        # Having more tokens is over-conservative but otherwise fine
        token_ids = dummy_data.prompt_token_ids
        if len(token_ids) < seq_len:
            raise AssertionError(
                f"Expected at least {seq_len} dummy tokens for profiling, "
                f"but found {len(token_ids)} tokens instead.")

        return dummy_data

    def get_encoder_dummy_data(
        self,
        model_config: "ModelConfig",
        seq_len: int,
        mm_counts: Optional[Mapping[str, int]] = None,
        *,
        cache: Optional[BaseMultiModalProcessorCache] = None,
    ) -> DummyEncoderData:
        """
        Create dummy data for profiling the memory usage of a model.

        The model is identified by ``model_config``.

        Unlike the decoder variant, producing fewer than `seq_len` tokens
        only logs a warning (once) and still returns the data.
        """
        processor = self.create_processor(model_config, cache=cache)
        profiler = MultiModalProfiler(processor)
        dummy_data = profiler.get_encoder_dummy_data(seq_len, mm_counts)

        # Having more tokens is over-conservative but otherwise fine
        token_ids = dummy_data.prompt_token_ids
        if len(token_ids) < seq_len:
            logger.warning_once(
                "Expected at least %d dummy encoder tokens for profiling, "
                "but found %d tokens instead.",
                seq_len,
                len(token_ids),
            )

        return dummy_data

    def get_encdec_max_encoder_len(self, model_config: "ModelConfig") -> int:
        """
        Get the maximum length of the encoder input for encoder-decoder models.

        Returns 0 when the model is not an encoder-decoder model or when no
        modality with a non-zero limit is reported.
        """
        if not model_config.is_encoder_decoder:
            return 0

        max_tokens = self.get_max_tokens_per_item_by_nonzero_modality(
            model_config)
        if not max_tokens:
            # TODO - this function assumes encoder-decoder models are
            # multimodal. This will need to change when adding support for more
            # than whisper.
            return 0

        # Implicit string concatenation keeps the message on one logical
        # line; a backslash continuation inside the literal would embed the
        # next line's indentation in the message.
        assert len(max_tokens) == 1, (
            "Encoder-decoder models are expected "
            "to implement the multimodal interface with at most one modality.")

        return next(iter(max_tokens.values()))