registry.py 11 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
from collections.abc import Mapping
3
from dataclasses import dataclass
4
from typing import TYPE_CHECKING, Generic, Optional, Protocol, TypeVar
5

6
import torch.nn as nn
7
from typing_extensions import deprecated
8

9
from vllm.envs import VLLM_MM_INPUT_CACHE_GIB
10
from vllm.inputs import InputProcessingContext
11
from vllm.logger import init_logger
12
13
from vllm.transformers_utils.tokenizer import (AnyTokenizer,
                                               cached_tokenizer_from_config)
14
from vllm.utils import ClassRegistry
15

16
17
from .processing import (BaseMultiModalProcessor, BaseProcessingInfo,
                         ProcessingCache)
18
19
from .profiling import (BaseDummyInputsBuilder, DummyDecoderData,
                        DummyEncoderData, MultiModalProfiler)
20

21
22
23
if TYPE_CHECKING:
    from vllm.config import ModelConfig

24
25
logger = init_logger(__name__)

# Processing-info type shared by all factories registered for one model.
_I = TypeVar("_I", bound=BaseProcessingInfo)
# Covariant counterpart of `_I`, used by `ProcessingInfoFactory`, whose type
# variable only appears in the return position of `__call__`.
_I_co = TypeVar("_I_co", bound=BaseProcessingInfo, covariant=True)

# A model class; lets the `register_processor` decorator hand back the exact
# class type it was applied to.
N = TypeVar("N", bound=type[nn.Module])
29
30


31
class ProcessingInfoFactory(Protocol[_I_co]):
    """
    Constructs a {class}`BaseProcessingInfo` instance from the context.

    This is the first factory invoked when a processor is built: its result
    is then handed to the dummy-inputs and processor factories.
    """

    def __call__(
        self,
        ctx: InputProcessingContext,
    ) -> _I_co:
        """
        Args:
            ctx: The input processing context of the model.

        Returns:
            The processing info built from ``ctx``.
        """
        ...


class DummyInputsBuilderFactory(Protocol[_I]):
    """
    Constructs a {class}`BaseDummyInputsBuilder` instance.

    The builder is created from the processing info of the model (the object
    produced by a {class}`ProcessingInfoFactory`), which gives it access to
    the underlying context.
    """

    def __call__(
        self,
        info: _I,
    ) -> BaseDummyInputsBuilder[_I]:
        """Return a dummy-inputs builder bound to ``info``."""


class MultiModalProcessorFactory(Protocol[_I]):
    """
    Constructs a {class}`BaseMultiModalProcessor` instance from the
    processing info and the dummy-inputs builder of a model.
    """

    def __call__(
        self,
        info: _I,
        dummy_inputs: BaseDummyInputsBuilder[_I],
        *,
        cache: Optional[ProcessingCache] = None,
    ) -> BaseMultiModalProcessor[_I]:
        """
        Args:
            info: The processing info of the model.
            dummy_inputs: The builder of dummy inputs for profiling.
            cache: Optional processing cache for the processor to use;
                ``None`` is passed when caching is disabled.

        Returns:
            The multi-modal processor.
        """
        ...
61

62

63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
@dataclass(frozen=True)
class _ProcessorFactories(Generic[_I]):
    """
    The factories registered for a single model class.

    Together they build a {class}`BaseMultiModalProcessor`: the processing
    info is created from the context, the dummy-inputs builder from the
    info, and finally the processor from both.
    """

    # Creates the processing info from an `InputProcessingContext`.
    info: ProcessingInfoFactory[_I]
    # Creates the processor from the info and the dummy-inputs builder.
    processor: MultiModalProcessorFactory[_I]
    # Creates the dummy-inputs builder from the info.
    dummy_inputs: DummyInputsBuilderFactory[_I]

    def build_processor(
        self,
        ctx: InputProcessingContext,
        *,
        cache: Optional[ProcessingCache] = None,
    ) -> BaseMultiModalProcessor[_I]:
        """
        Build the multi-modal processor for ``ctx``.

        Args:
            ctx: The input processing context of the model.
            cache: Processing cache forwarded to the processor factory, or
                ``None`` to build the processor without a cache.

        Returns:
            The processor produced by the registered processor factory.
        """
        info = self.info(ctx)
        dummy_inputs_builder = self.dummy_inputs(info)
        return self.processor(info, dummy_inputs_builder, cache=cache)


80
81
class MultiModalRegistry:
    """
    A registry that dispatches data processing according to the model.

    Model classes register their processing factories via
    {meth}`register_processor`; {meth}`create_processor` later looks up the
    factories for a model configuration and builds the processor lazily.
    """

    def __init__(self) -> None:
        # Model class -> factories used to build its multi-modal processor.
        self._processor_factories = ClassRegistry[nn.Module,
                                                  _ProcessorFactories]()

        # Single cache handed to every processor created with caching enabled.
        self._processing_cache = ProcessingCache(VLLM_MM_INPUT_CACHE_GIB)

    def reset_processor_cache(self) -> bool:
        """
        Reset the multi-modal processing cache.

        Returns:
            Always ``True`` once the cache has been reset.
        """
        self._processing_cache.reset()

        return True  # Success

    @deprecated("Legacy input processor/mapper pipeline has been removed. "
                "Please update your model runner to use "
                "`seq_group_metadata.multi_modal_data` directly without "
                "further processing.")
    def create_input_mapper(self, model_config: "ModelConfig"):
        # Identity mapper, kept only for backward compatibility.
        return lambda data, mm_processor_kwargs: data

    def _create_profiler(
        self,
        model_config: "ModelConfig",
    ) -> MultiModalProfiler:
        """
        Build a processor for ``model_config`` (with the processing cache
        enabled) and wrap it in a {class}`MultiModalProfiler`.
        """
        processor = self.create_processor(model_config, disable_cache=False)
        return MultiModalProfiler(processor)

    def get_max_tokens_per_item_by_modality(
        self,
        model_config: "ModelConfig",
    ) -> Mapping[str, int]:
        """
        Get the maximum number of tokens per data item from each modality based
        on underlying model configuration.

        Modalities whose per-prompt limit is zero are not queried. Returns an
        empty mapping for models that are not multi-modal.
        """
        if not model_config.is_multimodal_model:
            return {}

        profiler = self._create_profiler(model_config)

        # Read the limits from the same profiler rather than calling
        # `get_mm_limits_per_prompt`, which would build a second, identical
        # processor just to obtain them.
        mm_limits = profiler.get_mm_limits()
        seq_len = model_config.max_model_len

        # A single item per enabled modality suffices to measure the
        # per-item maximum.
        return profiler.get_mm_max_tokens(
            seq_len,
            {
                modality: 1
                for modality, limit in mm_limits.items() if limit > 0
            },
        )

    def get_max_tokens_per_item_by_nonzero_modality(
        self,
        model_config: "ModelConfig",
    ) -> Mapping[str, int]:
        """
        Get the maximum number of tokens per data item from each modality based
        on underlying model configuration, excluding modalities that user
        explicitly disabled via `limit_mm_per_prompt`.

        Note:
            This is currently directly used only in V1 for profiling the memory
            usage of a model.
        """
        mm_limits = self.get_mm_limits_per_prompt(model_config)
        max_tokens_per_item = self.get_max_tokens_per_item_by_modality(
            model_config)

        # A modality without a known limit is treated as disabled.
        return {
            key: max_tokens_per_mm_item
            for key, max_tokens_per_mm_item in max_tokens_per_item.items()
            if mm_limits.get(key, 0) > 0
        }

    def get_max_tokens_by_modality(
        self,
        model_config: "ModelConfig",
    ) -> Mapping[str, int]:
        """
        Get the maximum number of tokens from each modality
        for profiling the memory usage of a model.

        For each modality this is the maximum number of tokens per item
        (see {meth}`get_max_tokens_per_item_by_modality`) multiplied by the
        number of items allowed per prompt
        (see {meth}`get_mm_limits_per_prompt`).
        """
        mm_limits = self.get_mm_limits_per_prompt(model_config)
        max_tokens_per_item = self.get_max_tokens_per_item_by_modality(
            model_config)

        # A modality without a known limit contributes no tokens.
        return {
            key: mm_limits.get(key, 0) * max_tokens_per_mm_item
            for key, max_tokens_per_mm_item in max_tokens_per_item.items()
        }

    def get_max_multimodal_tokens(self, model_config: "ModelConfig") -> int:
        """
        Get the maximum number of multi-modal tokens
        for profiling the memory usage of a model.

        This is the sum over all modalities of
        {meth}`get_max_tokens_by_modality`.
        """
        return sum(self.get_max_tokens_by_modality(model_config).values())

    @deprecated("Legacy input processor/mapper pipeline has been removed. "
                "Please update your model runner to use "
                "`seq_group_metadata.multi_modal_data` directly without "
                "further processing.")
    def init_mm_limits_per_prompt(
        self,
        model_config: "ModelConfig",
    ) -> None:
        # No-op, kept only for backward compatibility.
        pass

    def get_mm_limits_per_prompt(
        self,
        model_config: "ModelConfig",
    ) -> Mapping[str, int]:
        """
        Get the maximum number of multi-modal input instances for each modality
        that are allowed per prompt for a model class.

        Returns an empty mapping for models that are not multi-modal.
        """
        if not model_config.is_multimodal_model:
            return {}

        return self._create_profiler(model_config).get_mm_limits()

    def register_processor(
        self,
        processor: MultiModalProcessorFactory[_I],
        *,
        info: ProcessingInfoFactory[_I],
        dummy_inputs: DummyInputsBuilderFactory[_I],
    ):
        """
        Register a multi-modal processor to a model class. The processor
        is constructed lazily, hence a factory method should be passed.

        When the model receives multi-modal data, the provided function is
        invoked to transform the data into a dictionary of model inputs.

        Returns a decorator that records the factories for the decorated
        model class (overwriting, with a warning, factories already
        registered for that class) and returns the class unchanged.

        :::{seealso}
        {ref}`mm-processing`
        :::
        """

        def wrapper(model_cls: N) -> N:
            # NOTE(review): `strict=True` presumably limits the check to this
            # exact class rather than its base classes — confirm against
            # `ClassRegistry.contains`.
            if self._processor_factories.contains(model_cls, strict=True):
                logger.warning(
                    "Model class %s already has a multi-modal processor "
                    "registered to %s. It is overwritten by the new one.",
                    model_cls, self)

            self._processor_factories[model_cls] = _ProcessorFactories(
                info=info,
                dummy_inputs=dummy_inputs,
                processor=processor,
            )

            return model_cls

        return wrapper

    def _get_model_cls(self, model_config: "ModelConfig"):
        """Resolve the model class (architecture) for ``model_config``."""
        # Avoid circular import
        from vllm.model_executor.model_loader import get_model_architecture

        model_cls, _ = get_model_architecture(model_config)
        return model_cls

    @deprecated("Legacy input processor/mapper pipeline has been removed. "
                "Please update your model runner to use "
                "`seq_group_metadata.multi_modal_data` directly without "
                "further processing.")
    def has_processor(self, model_config: "ModelConfig") -> bool:
        return True

    def create_processor(
        self,
        model_config: "ModelConfig",
        *,
        tokenizer: Optional[AnyTokenizer] = None,
        disable_cache: Optional[bool] = None,
    ) -> BaseMultiModalProcessor[BaseProcessingInfo]:
        """
        Create a multi-modal processor for a specific model and tokenizer.

        Args:
            model_config: Configuration of a multi-modal model.
            tokenizer: Tokenizer to use; defaults to the cached tokenizer
                derived from ``model_config``.
            disable_cache: Whether to build the processor without the shared
                processing cache; defaults to the model's
                ``disable_mm_preprocessor_cache`` setting.

        Raises:
            ValueError: If ``model_config`` is not a multi-modal model.

        :::{seealso}
        {ref}`mm-processing`
        :::
        """
        if not model_config.is_multimodal_model:
            raise ValueError(f"{model_config.model} is not a multimodal model")

        if tokenizer is None:
            tokenizer = cached_tokenizer_from_config(model_config)
        if disable_cache is None:
            mm_config = model_config.get_multimodal_config()
            disable_cache = mm_config.disable_mm_preprocessor_cache

        model_cls = self._get_model_cls(model_config)
        factories = self._processor_factories[model_cls]

        ctx = InputProcessingContext(model_config, tokenizer)
        cache = None if disable_cache else self._processing_cache

        return factories.build_processor(ctx, cache=cache)

    def get_decoder_dummy_data(
        self,
        model_config: "ModelConfig",
        seq_len: int,
        mm_counts: Optional[Mapping[str, int]] = None,
    ) -> DummyDecoderData:
        """
        Create dummy data for profiling the memory usage of a model.

        The model is identified by ``model_config``.

        Raises:
            AssertionError: If fewer than ``seq_len`` dummy tokens are
                produced.
        """
        profiler = self._create_profiler(model_config)
        dummy_data = profiler.get_decoder_dummy_data(seq_len, mm_counts)

        # Having more tokens is over-conservative but otherwise fine.
        # Raised explicitly (not via `assert`) so the check also runs
        # under `python -O`.
        token_ids = dummy_data.prompt_token_ids
        if len(token_ids) < seq_len:
            raise AssertionError(
                f"Expected at least {seq_len} dummy tokens for profiling, "
                f"but found {len(token_ids)} tokens instead.")

        return dummy_data

    def get_encoder_dummy_data(
        self,
        model_config: "ModelConfig",
        seq_len: int,
        mm_counts: Optional[Mapping[str, int]] = None,
    ) -> DummyEncoderData:
        """
        Create dummy data for profiling the memory usage of a model.

        The model is identified by ``model_config``.
        """
        profiler = self._create_profiler(model_config)
        dummy_data = profiler.get_encoder_dummy_data(seq_len, mm_counts)

        # Having more tokens is over-conservative but otherwise fine.
        # Unlike the decoder path, a shortfall here is only logged.
        token_ids = dummy_data.prompt_token_ids
        if len(token_ids) < seq_len:
            logger.warning_once(
                "Expected at least %d dummy encoder tokens for profiling, "
                "but found %d tokens instead.",
                seq_len,
                len(token_ids),
            )

        return dummy_data