registry.py 12.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from collections.abc import Mapping
4
from dataclasses import dataclass
5
6
from multiprocessing.synchronize import Lock as LockType
from typing import TYPE_CHECKING, Generic, Literal, Protocol, TypeVar, cast
7

8
from vllm.config.multimodal import BaseDummyOptions
9
from vllm.config.observability import ObservabilityConfig
10
from vllm.logger import init_logger
11
from vllm.tokenizers import TokenizerLike, cached_tokenizer_from_config
12

13
14
15
16
17
18
19
20
21
from .cache import (
    BaseMultiModalProcessorCache,
    BaseMultiModalReceiverCache,
    MultiModalProcessorOnlyCache,
    MultiModalProcessorSenderCache,
    MultiModalReceiverCache,
    ShmObjectStoreReceiverCache,
    ShmObjectStoreSenderCache,
)
22
from .inputs import MultiModalInputs
23
from .processing import (
24
    BaseDummyInputsBuilder,
25
26
27
28
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    InputProcessingContext,
)
29

30
if TYPE_CHECKING:
31
    from vllm.config import ModelConfig, ObservabilityConfig, VllmConfig
32
    from vllm.model_executor.models.interfaces import SupportsMultiModal
33

34
35
logger = init_logger(__name__)

# Type variables for the processing-info objects produced/consumed by the
# factory protocols below. `_I_co` is covariant because
# `ProcessingInfoFactory` only ever returns it.
_I = TypeVar("_I", bound=BaseProcessingInfo)
_I_co = TypeVar("_I_co", bound=BaseProcessingInfo, covariant=True)

# Model class type accepted and returned by the decorator that
# `MultiModalRegistry.register_processor` produces.
N = TypeVar("N", bound=type["SupportsMultiModal"])
39
40


41
class ProcessingInfoFactory(Protocol[_I_co]):
    """
    Constructs a
    [`BaseProcessingInfo`][vllm.multimodal.processing.BaseProcessingInfo]
    instance from the context.
    """

    def __call__(
        self,
        ctx: InputProcessingContext,
    ) -> _I_co: ...
52
53


54
class DummyInputsBuilderFactory(Protocol[_I]):  # type: ignore[misc]
    """
    Constructs a
    [`BaseDummyInputsBuilder`][vllm.multimodal.processing.BaseDummyInputsBuilder]
    instance from the processing info.
    """

    def __call__(self, info: _I) -> BaseDummyInputsBuilder[_I]: ...
62
63


64
class MultiModalProcessorFactory(Protocol[_I]):  # type: ignore[misc]
    """
    Constructs a
    [`BaseMultiModalProcessor`][vllm.multimodal.processing.BaseMultiModalProcessor]
    instance from the processing info and the dummy inputs builder,
    optionally backed by a processor cache.
    """

    def __call__(
        self,
        info: _I,
        dummy_inputs: BaseDummyInputsBuilder[_I],
        *,
        cache: BaseMultiModalProcessorCache | None = None,
    ) -> BaseMultiModalProcessor[_I]: ...
78

79

80
81
82
83
84
85
86
87
88
89
@dataclass(frozen=True)
class _ProcessorFactories(Generic[_I]):
    """
    The factories attached to a model class by
    `MultiModalRegistry.register_processor`.
    """

    # Builds the processing info from an `InputProcessingContext`.
    info: ProcessingInfoFactory[_I]
    # Builds the processor from the info and the dummy inputs builder.
    processor: MultiModalProcessorFactory[_I]
    # Builds the dummy inputs builder from the processing info.
    dummy_inputs: DummyInputsBuilderFactory[_I]

    def build_processor(
        self,
        ctx: InputProcessingContext,
        *,
        cache: BaseMultiModalProcessorCache | None = None,
    ) -> BaseMultiModalProcessor[_I]:
        """
        Build a multi-modal processor by chaining the three factories:
        context -> processing info -> dummy inputs builder -> processor.

        Args:
            ctx: Context handed to the `info` factory.
            cache: Optional processor cache forwarded to the `processor`
                factory.

        Returns:
            The processor produced by the `processor` factory.
        """
        info = self.info(ctx)
        dummy_inputs_builder = self.dummy_inputs(info)
        return self.processor(info, dummy_inputs_builder, cache=cache)


97
98
class MultiModalRegistry:
    """
99
    A registry that dispatches data processing according to the model.
100
101
    """

102
103
104
    def _extract_mm_options(
        self,
        model_config: "ModelConfig",
    ) -> Mapping[str, BaseDummyOptions] | None:
        """
        Collect the per-modality dummy options configured on `model_config`.

        Returns:
            A mapping from modality name to its dummy options, or None when
            the model has no multimodal config or no modality listed in
            `limit_per_prompt` has dummy options set.
        """
        mm_config = model_config.multimodal_config
        if not mm_config:
            return None

        mm_options = {}
        for modality in mm_config.limit_per_prompt:
            dummy_opt = mm_config.get_dummy_options(modality)
            if dummy_opt is not None:
                mm_options[modality] = dummy_opt

        # An empty mapping is reported as "no options" (None).
        return mm_options or None

123
    def supports_multimodal_inputs(self, model_config: "ModelConfig") -> bool:
        """
        Report whether `model_config` should run with multimodal inputs.

        Returns False for non-multimodal models. For multimodal models whose
        supported modalities all have a per-prompt limit of 0, it also
        returns False (text-only mode). The exception is when
        `enable_mm_embeds` is set: the multimodal infrastructure is then
        still needed to consume pre-computed embeddings, so it returns True.
        In every other case it returns True.
        """
        if not model_config.is_multimodal_model:
            return False

        mm_config = model_config.get_multimodal_config()
        info = self._create_processing_info(model_config, tokenizer=None)

        any_modality_enabled = any(
            mm_config.get_limit_per_prompt(modality) != 0
            for modality in info.supported_mm_limits
        )
        if any_modality_enabled:
            return True

        # Every supported modality is disabled; pre-computed embeddings still
        # need the MM infrastructure even though the encoder won't run.
        if mm_config.enable_mm_embeds:
            return True

        logger.info_once(
            "All limits of multimodal modalities supported by the model "
            "are set to 0, running in text-only mode."
        )
        return False

154
155
    def register_processor(
        self,
        processor: MultiModalProcessorFactory[_I],
        *,
        info: ProcessingInfoFactory[_I],
        dummy_inputs: DummyInputsBuilderFactory[_I],
    ):
        """
        Register a multi-modal processor to a model class. The processor
        is constructed lazily, hence a factory method should be passed.

        When the model receives multi-modal data, the provided function is
        invoked to transform the data into a dictionary of model inputs.

        Args:
            processor: Factory that builds the multi-modal processor.
            info: Factory that builds the processing info from a context.
            dummy_inputs: Factory that builds the dummy inputs builder.

        Returns:
            A class decorator. It stores the factories on the decorated class
            as `_processor_factory` and returns the same class. A warning is
            logged if that class itself (not merely a base class) already
            had a registration.
        """

        def decorate(model_cls: N) -> N:
            # Only look at the class's own namespace, so that inheriting a
            # registration from a base class does not trigger the warning.
            if "_processor_factory" in vars(model_cls):
                logger.warning(
                    "Model class %s already has a multi-modal processor "
                    "registered to %s. It is overwritten by the new one.",
                    model_cls,
                    self,
                )

            model_cls._processor_factory = _ProcessorFactories(
                info=info,
                dummy_inputs=dummy_inputs,
                processor=processor,
            )
            return model_cls

        return decorate

188
    def _get_model_cls(self, model_config: "ModelConfig") -> "SupportsMultiModal":
        """
        Resolve the model class for `model_config` and check that it has a
        multi-modal processor registered.

        Raises:
            ValueError: If the resolved model class has no
                `_processor_factory`, i.e. `register_processor` was never
                applied to it or to any of its base classes.
        """
        # Avoid circular import
        from vllm.model_executor.model_loader import get_model_architecture

        model_cls, _ = get_model_architecture(model_config)
        # Explicit check rather than `assert`, so the validation (and its
        # descriptive message) survives `python -O`.
        if not hasattr(model_cls, "_processor_factory"):
            raise ValueError(
                f"Model class {model_cls} has no multi-modal processor "
                "registered; decorate it with "
                "`MultiModalRegistry.register_processor(...)`."
            )
        return cast("SupportsMultiModal", model_cls)
195

196
197
198
    def _create_processing_ctx(
        self,
        model_config: "ModelConfig",
        observability_config: "ObservabilityConfig | None" = None,
        tokenizer: TokenizerLike | None = None,
    ) -> InputProcessingContext:
        """
        Build an `InputProcessingContext` for `model_config`.

        When `tokenizer` is None, one is obtained with
        `cached_tokenizer_from_config(model_config)`.
        """
        resolved_tokenizer = (
            tokenizer
            if tokenizer is not None
            else cached_tokenizer_from_config(model_config)
        )
        return InputProcessingContext(
            model_config,
            resolved_tokenizer,
            observability_config=observability_config,
        )
208

209
210
    def _create_processing_info(
        self,
        model_config: "ModelConfig",
        observability_config: "ObservabilityConfig | None" = None,
        *,
        tokenizer: TokenizerLike | None = None,
    ) -> BaseProcessingInfo:
        """
        Build only the processing info for `model_config`. It uses the `info`
        factory registered on the resolved model class; no processor is
        constructed.
        """
        info_factory = self._get_model_cls(model_config)._processor_factory.info
        processing_ctx = self._create_processing_ctx(
            model_config, observability_config, tokenizer
        )
        return info_factory(processing_ctx)

221
222
    def create_processor(
        self,
        model_config: "ModelConfig",
        observability_config: "ObservabilityConfig | None" = None,
        *,
        tokenizer: TokenizerLike | None = None,
        cache: BaseMultiModalProcessorCache | None = None,
    ) -> BaseMultiModalProcessor[BaseProcessingInfo]:
        """
        Create a multi-modal processor for a specific model and tokenizer.

        Args:
            model_config: Config of the model whose registered factories are
                used.
            observability_config: Forwarded to the processing context.
            tokenizer: Tokenizer to use; when None, one is loaded from
                `model_config`.
            cache: Optional processor cache handed to the processor factory.

        Raises:
            ValueError: If `model_config` is not a multimodal model.
        """
        if not model_config.is_multimodal_model:
            raise ValueError(f"{model_config.model} is not a multimodal model")

        registered = self._get_model_cls(model_config)._processor_factory
        processing_ctx = self._create_processing_ctx(
            model_config, observability_config, tokenizer
        )
        return registered.build_processor(processing_ctx, cache=cache)
241

242
    def get_dummy_mm_inputs(
        self,
        model_config: "ModelConfig",
        mm_counts: Mapping[str, int],
        *,
        cache: BaseMultiModalProcessorCache | None = None,
        processor: BaseMultiModalProcessor | None = None,
    ) -> MultiModalInputs:
        """
        Create dummy data for profiling the memory usage of a model.

        The model is identified by `model_config`. The dummy inputs target
        `model_config.max_model_len` tokens. If the processed prompt comes out
        shorter, its `prompt_token_ids` list is padded in place with token
        id 0 up to that length. Longer prompts are left untouched.

        Args:
            model_config: Config of the model to profile.
            mm_counts: Per-modality counts passed to the dummy inputs builder.
            cache: Processor cache; only used when `processor` is None.
            processor: Pre-built processor to reuse; when None, one is built
                via `create_processor`.
        """
        target_len = model_config.max_model_len
        mm_processor = (
            processor
            if processor is not None
            else self.create_processor(model_config, cache=cache)
        )

        dummy = mm_processor.dummy_inputs.get_dummy_processor_inputs(
            seq_len=target_len,
            mm_counts=mm_counts,
            mm_options=self._extract_mm_options(model_config),
        )
        mm_inputs = mm_processor.apply(
            prompt=dummy.prompt,
            mm_items=dummy.mm_items,
            hf_processor_mm_kwargs=dummy.hf_processor_mm_kwargs,
            tokenization_kwargs=dummy.tokenization_kwargs,
        )

        token_ids = mm_inputs["prompt_token_ids"]
        shortfall = target_len - len(token_ids)
        if shortfall > 0:
            token_ids.extend([0] * shortfall)

        return mm_inputs
278

279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
    def _get_cache_type(
        self,
        vllm_config: "VllmConfig",
    ) -> Literal[None, "processor_only", "lru", "shm"]:
        """
        Decide which multi-modal processor cache, if any, to use.

        Returns:
            - None if the model does not accept multimodal inputs, or if the
              cache is disabled (`mm_processor_cache_gb <= 0`).
            - "processor_only" if the parallel setup does not support IPC
              caching (more than one API process, or data parallelism
              without an external load balancer).
            - Otherwise, the configured `mm_processor_cache_type`.
        """
        model_config = vllm_config.model_config
        if not self.supports_multimodal_inputs(model_config):
            return None

        # Check if the cache is disabled.
        mm_config = model_config.get_multimodal_config()
        if mm_config.mm_processor_cache_gb <= 0:
            return None

        # Check if IPC caching is supported.
        parallel_config = vllm_config.parallel_config
        is_ipc_supported = parallel_config._api_process_count == 1 and (
            parallel_config.data_parallel_size == 1
            or parallel_config.data_parallel_external_lb
        )

        if not is_ipc_supported:
            return "processor_only"

        # Reuse the multimodal config fetched above instead of re-fetching it.
        return mm_config.mm_processor_cache_type

    def processor_cache_from_config(
        self,
        vllm_config: "VllmConfig",
    ) -> BaseMultiModalProcessorCache | None:
        """Return a `BaseMultiModalProcessorCache`, if enabled.

        The concrete class follows the cache type chosen by `_get_cache_type`.
        "processor_only" gives a processor-only cache, "lru" a sender cache,
        and "shm" a shared-memory sender cache.
        """
        cache_type = self._get_cache_type(vllm_config)

        if cache_type is None:
            return None
        if cache_type == "processor_only":
            return MultiModalProcessorOnlyCache(vllm_config.model_config)
        if cache_type == "lru":
            return MultiModalProcessorSenderCache(vllm_config.model_config)
        if cache_type == "shm":
            return ShmObjectStoreSenderCache(vllm_config)

        raise ValueError(f"Unknown cache type: {cache_type!r}")

    def processor_only_cache_from_config(
        self,
        vllm_config: "VllmConfig",
    ) -> MultiModalProcessorOnlyCache | None:
        """Return a `MultiModalProcessorOnlyCache`, if enabled.

        Any enabled cache type yields a processor-only cache here. None is
        returned only when caching is disabled altogether.
        """
        caching_disabled = self._get_cache_type(vllm_config) is None
        return (
            None
            if caching_disabled
            else MultiModalProcessorOnlyCache(vllm_config.model_config)
        )

    def engine_receiver_cache_from_config(
        self,
        vllm_config: "VllmConfig",
    ) -> BaseMultiModalReceiverCache | None:
        """Return a `BaseMultiModalReceiverCache` for the engine process.

        Only the "lru" cache type needs an engine-side receiver. All other
        known types return None.
        """
        cache_type = self._get_cache_type(vllm_config)

        if cache_type == "lru":
            return MultiModalReceiverCache(vllm_config.model_config)
        if cache_type in (None, "processor_only", "shm"):
            return None

        raise ValueError(f"Unknown cache type: {cache_type!r}")

    def worker_receiver_cache_from_config(
        self,
        vllm_config: "VllmConfig",
        shared_worker_lock: LockType,
    ) -> BaseMultiModalReceiverCache | None:
        """Return a `BaseMultiModalReceiverCache` for the worker process.

        Only the "shm" cache type needs a worker-side receiver. It is built
        with `shared_worker_lock`. All other known types return None.
        """
        cache_type = self._get_cache_type(vllm_config)

        if cache_type == "shm":
            return ShmObjectStoreReceiverCache(vllm_config, shared_worker_lock)
        if cache_type in (None, "processor_only", "lru"):
            return None

        raise ValueError(f"Unknown cache type: {cache_type!r}")