registry.py 11.9 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from collections.abc import Mapping
4
from dataclasses import dataclass
5
6
from multiprocessing.synchronize import Lock as LockType
from typing import TYPE_CHECKING, Generic, Literal, Protocol, TypeVar, cast
7

8
from vllm.config.multimodal import BaseDummyOptions
9
from vllm.config.observability import ObservabilityConfig
10
from vllm.logger import init_logger
11
from vllm.tokenizers import TokenizerLike, cached_tokenizer_from_config
12

13
14
15
16
17
18
19
20
21
from .cache import (
    BaseMultiModalProcessorCache,
    BaseMultiModalReceiverCache,
    MultiModalProcessorOnlyCache,
    MultiModalProcessorSenderCache,
    MultiModalReceiverCache,
    ShmObjectStoreReceiverCache,
    ShmObjectStoreSenderCache,
)
22
from .inputs import MultiModalInputs
23
from .processing import (
24
    BaseDummyInputsBuilder,
25
26
27
28
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    InputProcessingContext,
)
29

30
if TYPE_CHECKING:
31
    from vllm.config import ModelConfig, ObservabilityConfig, VllmConfig
32
    from vllm.model_executor.models.interfaces import SupportsMultiModal
33

34
35
# Module-level logger; used below for the duplicate-registration warning and
# the "running in text-only mode" notice.
logger = init_logger(__name__)

36
# Type variable for the model class passed through the `register_processor`
# decorator, so that decorating a class preserves its concrete type.
N = TypeVar("N", bound=type["SupportsMultiModal"])
37
38
# Processing-info type shared by the info / dummy-inputs / processor factories.
_I = TypeVar("_I", bound=BaseProcessingInfo)
# Covariant variant, used where the info type only appears as a return value
# (see `ProcessingInfoFactory.__call__`).
_I_co = TypeVar("_I_co", bound=BaseProcessingInfo, covariant=True)
39
40


41
class ProcessingInfoFactory(Protocol[_I_co]):
    """
    Constructs a
    [`BaseProcessingInfo`][vllm.multimodal.processing.BaseProcessingInfo]
    instance from the context.
    """

    def __call__(
        self,
        ctx: InputProcessingContext,
    ) -> _I_co: ...
52
53


54
class DummyInputsBuilderFactory(Protocol[_I]):  # type: ignore[misc]
    """
    Callable protocol that builds a
    [`BaseDummyInputsBuilder`][vllm.multimodal.processing.BaseDummyInputsBuilder]
    for a given processing info object.
    """

    def __call__(
        self,
        info: _I,
    ) -> BaseDummyInputsBuilder[_I]: ...
62
63


64
class MultiModalProcessorFactory(Protocol[_I]):  # type: ignore[misc]
    """
    Callable protocol that builds a
    [`BaseMultiModalProcessor`][vllm.multimodal.processing.BaseMultiModalProcessor]
    from a processing info object and its matching dummy inputs builder.

    An optional processor cache may be supplied via the keyword-only
    `cache` argument.
    """

    def __call__(
        self,
        info: _I,
        dummy_inputs: BaseDummyInputsBuilder[_I],
        *,
        cache: BaseMultiModalProcessorCache | None = None,
    ) -> BaseMultiModalProcessor[_I]: ...
78

79

80
81
82
83
84
85
86
87
88
89
@dataclass(frozen=True)
class _ProcessorFactories(Generic[_I]):
    """
    Immutable bundle of the three factories registered for a model class.

    The factories are chained by `build_processor`: the context yields a
    processing info object, the info yields a dummy inputs builder, and both
    are handed to the processor factory.
    """

    # Builds the processing info object from an `InputProcessingContext`.
    info: ProcessingInfoFactory[_I]
    # Builds the multi-modal processor from the info and dummy inputs builder.
    processor: MultiModalProcessorFactory[_I]
    # Builds the dummy inputs builder from the processing info object.
    dummy_inputs: DummyInputsBuilderFactory[_I]

    def build_processor(
        self,
        ctx: InputProcessingContext,
        *,
        cache: BaseMultiModalProcessorCache | None = None,
    ) -> BaseMultiModalProcessor[_I]:
        """
        Build a multi-modal processor for the given context.

        Args:
            ctx: Context passed to the `info` factory.
            cache: Optional processor cache, forwarded unchanged to the
                `processor` factory.

        Returns:
            The object produced by the `processor` factory.
        """
        info = self.info(ctx)
        dummy_inputs_builder = self.dummy_inputs(info)
        return self.processor(info, dummy_inputs_builder, cache=cache)


97
98
class MultiModalRegistry:
    """
99
    A registry that dispatches data processing according to the model.
100
101
    """

102
103
104
    def _extract_mm_options(
        self,
        model_config: "ModelConfig",
    ) -> Mapping[str, BaseDummyOptions] | None:
        """
        Collect the per-modality dummy options set on the model config.

        Every modality listed in `limit_per_prompt` is queried through
        `get_dummy_options`; modalities for which that returns None are
        skipped.

        Returns:
            A mapping from modality name to its dummy options, or None when
            the model has no multimodal config or no modality yields options.
        """
        mm_config = model_config.multimodal_config
        if not mm_config:
            return None

        configured = {
            modality: dummy_options
            for modality in mm_config.limit_per_prompt
            if (dummy_options := mm_config.get_dummy_options(modality)) is not None
        }

        # An empty mapping is reported as "nothing configured".
        return configured or None

123
    def supports_multimodal_inputs(self, model_config: "ModelConfig") -> bool:
        """
        Tell whether multimodal inputs are usable with this model config.

        Returns False when the model is not a multimodal model, or when the
        per-prompt limit of every modality the model supports is 0 (the model
        then effectively runs in text-only mode). Returns True otherwise.
        """
        if not model_config.is_multimodal_model:
            return False

        mm_config = model_config.get_multimodal_config()
        info = self._create_processing_info(model_config, tokenizer=None)

        # At least one supported modality with a non-zero limit is enough.
        has_enabled_modality = any(
            mm_config.get_limit_per_prompt(modality) != 0
            for modality in info.supported_mm_limits
        )
        if has_enabled_modality:
            return True

        logger.info_once(
            "All limits of multimodal modalities supported by the model "
            "are set to 0, running in text-only mode."
        )
        return False

149
150
    def register_processor(
        self,
        processor: MultiModalProcessorFactory[_I],
        *,
        info: ProcessingInfoFactory[_I],
        dummy_inputs: DummyInputsBuilderFactory[_I],
    ):
        """
        Return a class decorator that attaches processor factories to a model.

        Factories are registered rather than instances, so the actual
        processor is only constructed later, on demand. The decorator stores
        the bundle on the class as `_processor_factory` and returns the class
        unchanged otherwise.

        If the class itself (not merely a base class) already carries a
        `_processor_factory`, a warning is logged and the entry is replaced.
        """

        def wrapper(model_cls: N) -> N:
            # Look at the class' own namespace only, so an inherited
            # registration does not trigger the overwrite warning.
            already_registered = "_processor_factory" in model_cls.__dict__
            if already_registered:
                logger.warning(
                    "Model class %s already has a multi-modal processor "
                    "registered to %s. It is overwritten by the new one.",
                    model_cls,
                    self,
                )

            factories = _ProcessorFactories(
                info=info,
                dummy_inputs=dummy_inputs,
                processor=processor,
            )
            model_cls._processor_factory = factories
            return model_cls

        return wrapper

183
    def _get_model_cls(self, model_config: "ModelConfig") -> "SupportsMultiModal":
        """
        Resolve the model class for `model_config`.

        The resolved class is expected to have had a processor registered
        (i.e. to carry a `_processor_factory` attribute); this is asserted.
        """
        # Imported here rather than at module level to avoid a circular import.
        from vllm.model_executor.model_loader import get_model_architecture

        resolved = get_model_architecture(model_config)
        model_cls, _ = resolved

        assert hasattr(model_cls, "_processor_factory")
        return cast("SupportsMultiModal", model_cls)
190

191
192
193
    def _create_processing_ctx(
        self,
        model_config: "ModelConfig",
        observability_config: "ObservabilityConfig | None" = None,
        tokenizer: TokenizerLike | None = None,
    ) -> InputProcessingContext:
        """
        Build the `InputProcessingContext` handed to the registered factories.

        When no tokenizer is supplied, one is obtained from
        `cached_tokenizer_from_config(model_config)`.
        """
        resolved_tokenizer = (
            cached_tokenizer_from_config(model_config)
            if tokenizer is None
            else tokenizer
        )

        return InputProcessingContext(
            model_config,
            resolved_tokenizer,
            observability_config=observability_config,
        )
203

204
205
    def _create_processing_info(
        self,
        model_config: "ModelConfig",
        observability_config: "ObservabilityConfig | None" = None,
        *,
        tokenizer: TokenizerLike | None = None,
    ) -> BaseProcessingInfo:
        """
        Build the processing info object for the model described by
        `model_config`, using the `info` factory registered on its class.
        """
        # Resolve the model class first so a missing registration is reported
        # before any processing context is built.
        info_factory = self._get_model_cls(model_config)._processor_factory.info

        ctx = self._create_processing_ctx(
            model_config,
            observability_config=observability_config,
            tokenizer=tokenizer,
        )
        return info_factory(ctx)

216
217
    def create_processor(
        self,
        model_config: "ModelConfig",
        observability_config: "ObservabilityConfig | None" = None,
        *,
        tokenizer: TokenizerLike | None = None,
        cache: BaseMultiModalProcessorCache | None = None,
    ) -> BaseMultiModalProcessor[BaseProcessingInfo]:
        """
        Build the multi-modal processor registered for the given model.

        Args:
            model_config: Identifies the model whose processor is built.
            observability_config: Forwarded to the processing context.
            tokenizer: Tokenizer for the processing context; when None the
                context creation falls back to a tokenizer derived from
                `model_config`.
            cache: Optional processor cache forwarded to the processor
                factory.

        Raises:
            ValueError: If `model_config` does not describe a multimodal
                model.
        """
        if not model_config.is_multimodal_model:
            raise ValueError(f"{model_config.model} is not a multimodal model")

        factories = self._get_model_cls(model_config)._processor_factory
        ctx = self._create_processing_ctx(
            model_config,
            observability_config=observability_config,
            tokenizer=tokenizer,
        )

        return factories.build_processor(ctx, cache=cache)
236

237
    def get_dummy_mm_inputs(
        self,
        model_config: "ModelConfig",
        mm_counts: Mapping[str, int],
        *,
        cache: BaseMultiModalProcessorCache | None = None,
        processor: BaseMultiModalProcessor | None = None,
    ) -> MultiModalInputs:
        """
        Produce dummy multi-modal inputs for profiling a model's memory usage.

        The model is identified by `model_config`. If no `processor` is
        given, one is created via `create_processor` (using `cache`); when a
        `processor` is given, `cache` is not used here.

        The returned inputs have their `prompt_token_ids` padded in place
        with 0 up to `model_config.max_model_len` when shorter than that.
        """
        seq_len = model_config.max_model_len

        if processor is None:
            processor = self.create_processor(model_config, cache=cache)

        mm_options = self._extract_mm_options(model_config)
        dummy = processor.dummy_inputs.get_dummy_processor_inputs(
            seq_len=seq_len,
            mm_counts=mm_counts,
            mm_options=mm_options,
        )

        mm_inputs = processor.apply(
            prompt=dummy.prompt,
            mm_items=dummy.mm_items,
            hf_processor_mm_kwargs=dummy.hf_processor_mm_kwargs,
            tokenization_kwargs=dummy.tokenization_kwargs,
        )

        # Pad the prompt (in place) so that it spans the full sequence length.
        token_ids = mm_inputs["prompt_token_ids"]
        num_padding = seq_len - len(token_ids)
        if num_padding > 0:
            token_ids.extend([0] * num_padding)

        return mm_inputs
273

274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
    def _get_cache_type(
        self,
        vllm_config: "VllmConfig",
    ) -> Literal[None, "processor_only", "lru", "shm"]:
        """
        Decide which kind of multi-modal processor cache applies.

        Returns:
            - None when multimodal inputs are not supported for the model, or
              when `mm_processor_cache_gb` is not positive (cache disabled).
            - "processor_only" when IPC caching is not supported by the
              parallel configuration.
            - Otherwise the configured `mm_processor_cache_type`.
        """
        model_config = vllm_config.model_config
        if not self.supports_multimodal_inputs(model_config):
            return None

        # Check if the cache is disabled.
        mm_config = model_config.get_multimodal_config()
        if mm_config.mm_processor_cache_gb <= 0:
            return None

        # Check if IPC caching is supported: a single API process, and either
        # no data parallelism or data parallelism with an external LB.
        parallel_config = vllm_config.parallel_config
        is_ipc_supported = parallel_config._api_process_count == 1 and (
            parallel_config.data_parallel_size == 1
            or parallel_config.data_parallel_external_lb
        )

        if not is_ipc_supported:
            return "processor_only"

        # Reuse the multimodal config fetched above instead of looking it up
        # a second time.
        return mm_config.mm_processor_cache_type

    def processor_cache_from_config(
        self,
        vllm_config: "VllmConfig",
    ) -> BaseMultiModalProcessorCache | None:
        """
        Build the processor-side cache selected by the config, or return
        None when caching is not enabled.

        Raises:
            ValueError: If the resolved cache type is not recognised.
        """
        cache_type = self._get_cache_type(vllm_config)

        if cache_type is None:
            return None
        if cache_type == "processor_only":
            return MultiModalProcessorOnlyCache(vllm_config.model_config)
        if cache_type == "lru":
            return MultiModalProcessorSenderCache(vllm_config.model_config)
        if cache_type == "shm":
            return ShmObjectStoreSenderCache(vllm_config)

        raise ValueError(f"Unknown cache type: {cache_type!r}")

    def processor_only_cache_from_config(
        self,
        vllm_config: "VllmConfig",
    ) -> MultiModalProcessorOnlyCache | None:
        """
        Build a `MultiModalProcessorOnlyCache` whenever caching is enabled
        (whatever the resolved cache type), or return None otherwise.
        """
        if self._get_cache_type(vllm_config) is None:
            return None

        return MultiModalProcessorOnlyCache(vllm_config.model_config)

    def engine_receiver_cache_from_config(
        self,
        vllm_config: "VllmConfig",
    ) -> BaseMultiModalReceiverCache | None:
        """
        Build the receiver cache used in the engine process.

        Only the "lru" cache type has an engine-side receiver; for a disabled
        cache, "processor_only" and "shm", None is returned.

        Raises:
            ValueError: If the resolved cache type is not recognised.
        """
        cache_type = self._get_cache_type(vllm_config)

        if cache_type == "lru":
            return MultiModalReceiverCache(vllm_config.model_config)
        if cache_type in (None, "processor_only", "shm"):
            return None

        raise ValueError(f"Unknown cache type: {cache_type!r}")

    def worker_receiver_cache_from_config(
        self,
        vllm_config: "VllmConfig",
        shared_worker_lock: LockType,
    ) -> BaseMultiModalReceiverCache | None:
        """
        Build the receiver cache used in a worker process.

        Only the "shm" cache type has a worker-side receiver, which is
        constructed with `shared_worker_lock`; for a disabled cache,
        "processor_only" and "lru", None is returned.

        Raises:
            ValueError: If the resolved cache type is not recognised.
        """
        cache_type = self._get_cache_type(vllm_config)

        if cache_type == "shm":
            return ShmObjectStoreReceiverCache(vllm_config, shared_worker_lock)
        if cache_type in (None, "processor_only", "lru"):
            return None

        raise ValueError(f"Unknown cache type: {cache_type!r}")