# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import threading
from collections import defaultdict
from collections.abc import Mapping
from dataclasses import dataclass
from multiprocessing.synchronize import Lock as LockType
from typing import TYPE_CHECKING, Generic, Literal, Protocol, TypeVar, cast

from vllm.inputs import MultiModalInput
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike, cached_tokenizer_from_config

from .cache import (
    BaseMultiModalProcessorCache,
    BaseMultiModalReceiverCache,
    MultiModalProcessorOnlyCache,
    MultiModalProcessorSenderCache,
    MultiModalReceiverCache,
    ShmObjectStoreReceiverCache,
    ShmObjectStoreSenderCache,
)
from .processing import (
    BaseDummyInputsBuilder,
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    InputProcessingContext,
    TimingContext,
)

# These names are only needed for annotations, so they are imported solely
# for the type checker.
if TYPE_CHECKING:
    from vllm.config import ModelConfig, ObservabilityConfig, VllmConfig
    from vllm.model_executor.models.interfaces import SupportsMultiModal
34

35
36
logger = init_logger(__name__)

# Type variable for model classes accepted by the registration decorator.
N = TypeVar("N", bound=type["SupportsMultiModal"])

# Invariant / covariant type variables over the processing-info type that
# ties an info factory, a dummy-inputs builder and a processor together.
_I = TypeVar("_I", bound=BaseProcessingInfo)
_I_co = TypeVar("_I_co", bound=BaseProcessingInfo, covariant=True)
40
41


42
class ProcessingInfoFactory(Protocol[_I_co]):
    """
    Constructs a
    [`BaseProcessingInfo`][vllm.multimodal.processing.BaseProcessingInfo]
    instance from the context.
    """

    def __call__(
        self,
        ctx: InputProcessingContext,
    ) -> _I_co: ...
53
54


55
class DummyInputsBuilderFactory(Protocol[_I]):  # type: ignore[misc]
    """
    Constructs a
    [`BaseDummyInputsBuilder`][vllm.multimodal.processing.BaseDummyInputsBuilder]
    instance from the processing info.
    """

    def __call__(self, info: _I) -> BaseDummyInputsBuilder[_I]: ...
63
64


65
class MultiModalProcessorFactory(Protocol[_I]):  # type: ignore[misc]
    """
    Constructs a
    [`BaseMultiModalProcessor`][vllm.multimodal.processing.BaseMultiModalProcessor]
    instance from the processing info and the dummy inputs builder,
    optionally attaching a processor cache.
    """

    def __call__(
        self,
        info: _I,
        dummy_inputs: BaseDummyInputsBuilder[_I],
        *,
        cache: BaseMultiModalProcessorCache | None = None,
    ) -> BaseMultiModalProcessor[_I]: ...
79

80

81
82
83
84
85
86
87
88
89
90
@dataclass(frozen=True)
class _ProcessorFactories(Generic[_I]):
    """
    Bundle of the three factories registered for a model class.

    All three factories share the same processing-info type `_I`.
    """

    # Builds the processing info from an `InputProcessingContext`.
    info: ProcessingInfoFactory[_I]
    # Builds the processor from the info and the dummy inputs builder.
    processor: MultiModalProcessorFactory[_I]
    # Builds the dummy inputs builder from the info.
    dummy_inputs: DummyInputsBuilderFactory[_I]

    def build_processor(
        self,
        ctx: InputProcessingContext,
        *,
        cache: BaseMultiModalProcessorCache | None = None,
    ) -> BaseMultiModalProcessor[_I]:
        """
        Run the factories in dependency order to build a processor.

        Args:
            ctx: Context passed to the info factory.
            cache: Optional processor cache forwarded to the processor
                factory.

        Returns:
            The object returned by the processor factory.
        """
        info = self.info(ctx)
        dummy_inputs_builder = self.dummy_inputs(info)
        return self.processor(info, dummy_inputs_builder, cache=cache)


98
99
class MultiModalRegistry:
    """
    A registry that dispatches data processing according to the model.
    """

    def supports_multimodal_inputs(self, model_config: "ModelConfig") -> bool:
        """
        Checks if the model supports multimodal inputs.

        Returns True if the model is multimodal with any non-zero supported
        modalities, otherwise returns False, effectively running in
        text-only mode.
        """
        if not model_config.is_multimodal_model:
            return False

        mm_config = model_config.get_multimodal_config()
        info = self._create_processing_info(model_config, tokenizer=None)

        # Check if all supported modalities have limit == 0
        if all(
            mm_config.get_limit_per_prompt(modality) == 0
            for modality in info.supported_mm_limits
        ):
            # If enable_mm_embeds is True, we still need MM infrastructure
            # to process pre-computed embeddings even though encoder won't run
            if mm_config.enable_mm_embeds:
                return True

            logger.info_once(
                "All limits of multimodal modalities supported by the model "
                "are set to 0, running in text-only mode."
            )
            return False

        return True

    def register_processor(
        self,
        processor: MultiModalProcessorFactory[_I],
        *,
        info: ProcessingInfoFactory[_I],
        dummy_inputs: DummyInputsBuilderFactory[_I],
    ):
        """
        Register a multi-modal processor to a model class. The processor
        is constructed lazily, hence a factory method should be passed.

        When the model receives multi-modal data, the provided function is
        invoked to transform the data into a dictionary of model inputs.

        Returns:
            A class decorator that stores the factories on the decorated
            model class as `_processor_factory` and returns the class.
        """

        def wrapper(model_cls: N) -> N:
            # Look only at the class' own `__dict__` so that a factory
            # inherited from a base class does not trigger the warning.
            if "_processor_factory" in model_cls.__dict__:
                logger.warning(
                    "Model class %s already has a multi-modal processor "
                    "registered to %s. It is overwritten by the new one.",
                    model_cls,
                    self,
                )

            model_cls._processor_factory = _ProcessorFactories(
                info=info,
                dummy_inputs=dummy_inputs,
                processor=processor,
            )

            return model_cls

        return wrapper

    def _get_model_cls(self, model_config: "ModelConfig") -> "SupportsMultiModal":
        """Resolve the model class for `model_config` and check that a
        processor factory has been registered on it."""
        # Avoid circular import
        from vllm.model_executor.model_loader import get_model_architecture

        model_cls, _ = get_model_architecture(model_config)
        assert hasattr(model_cls, "_processor_factory"), (
            f"{model_cls} has no registered multi-modal processor"
        )
        return cast("SupportsMultiModal", model_cls)

    def _create_processing_ctx(
        self,
        model_config: "ModelConfig",
        tokenizer: TokenizerLike | None = None,
    ) -> InputProcessingContext:
        """Build an `InputProcessingContext`, loading the tokenizer from
        `model_config` when none is supplied."""
        if tokenizer is None:
            tokenizer = cached_tokenizer_from_config(model_config)

        return InputProcessingContext(model_config, tokenizer)

    def _create_processing_info(
        self,
        model_config: "ModelConfig",
        tokenizer: TokenizerLike | None = None,
    ) -> BaseProcessingInfo:
        """Build only the processing info for the model, without building
        the full processor."""
        model_cls = self._get_model_cls(model_config)
        factories = model_cls._processor_factory
        ctx = self._create_processing_ctx(model_config, tokenizer)
        return factories.info(ctx)

    def create_processor(
        self,
        model_config: "ModelConfig",
        *,
        tokenizer: TokenizerLike | None = None,
        cache: BaseMultiModalProcessorCache | None = None,
    ) -> BaseMultiModalProcessor[BaseProcessingInfo]:
        """
        Create a multi-modal processor for a specific model and tokenizer.

        Raises:
            ValueError: If `model_config` does not describe a multimodal
                model.
        """
        if not model_config.is_multimodal_model:
            raise ValueError(f"{model_config.model} is not a multimodal model")

        model_cls = self._get_model_cls(model_config)
        factories = model_cls._processor_factory

        ctx = self._create_processing_ctx(model_config, tokenizer)

        return factories.build_processor(ctx, cache=cache)

    def get_dummy_mm_inputs(
        self,
        model_config: "ModelConfig",
        mm_counts: Mapping[str, int],
        *,
        cache: BaseMultiModalProcessorCache | None = None,
        processor: BaseMultiModalProcessor | None = None,
    ) -> MultiModalInput:
        """
        Create dummy data for profiling the memory usage of a model.

        The model is identified by `model_config`.

        Args:
            model_config: Supplies `max_model_len` and the multimodal
                config.
            mm_counts: Number of dummy items to generate per modality.
            cache: Processor cache used only when a processor has to be
                created here.
            processor: Existing processor to reuse; created from
                `model_config` when omitted.

        Returns:
            The processed dummy inputs, with `prompt_token_ids` padded
            with zeros up to `max_model_len` when shorter.
        """
        seq_len = model_config.max_model_len

        if processor is None:
            processor = self.create_processor(model_config, cache=cache)

        mm_config = model_config.get_multimodal_config()
        processor_inputs = processor.dummy_inputs.get_dummy_processor_inputs(
            seq_len=seq_len,
            mm_counts=mm_counts,
            mm_options=mm_config.limit_per_prompt,
        )
        mm_inputs = processor.apply(
            processor_inputs,
            timing_ctx=TimingContext(enabled=False),
        )

        # Pad in place so the returned inputs span the full sequence length.
        prompt_token_ids = mm_inputs["prompt_token_ids"]
        total_len = len(prompt_token_ids)
        if total_len < seq_len:
            prompt_token_ids.extend([0] * (seq_len - total_len))

        return mm_inputs

    def _get_cache_type(
        self,
        vllm_config: "VllmConfig",
    ) -> Literal[None, "processor_only", "lru", "shm"]:
        """
        Decide which processor cache flavour applies to `vllm_config`.

        Returns:
            `None` when the model does not accept multimodal inputs or the
            cache size is not positive; `"processor_only"` when IPC caching
            is not supported by the parallel setup; otherwise the configured
            `mm_processor_cache_type`.
        """
        model_config = vllm_config.model_config
        if not self.supports_multimodal_inputs(model_config):
            return None

        # Check if the cache is disabled.
        mm_config = model_config.get_multimodal_config()
        if mm_config.mm_processor_cache_gb <= 0:
            return None

        # Check if IPC caching is supported.
        parallel_config = vllm_config.parallel_config
        is_ipc_supported = parallel_config._api_process_count == 1 and (
            parallel_config.data_parallel_size == 1
            or parallel_config.data_parallel_external_lb
        )

        if not is_ipc_supported:
            return "processor_only"

        return mm_config.mm_processor_cache_type

    def processor_cache_from_config(
        self,
        vllm_config: "VllmConfig",
    ) -> BaseMultiModalProcessorCache | None:
        """Return a `BaseMultiModalProcessorCache`, if enabled."""
        cache_type = self._get_cache_type(vllm_config)
        if cache_type is None:
            return None
        elif cache_type == "processor_only":
            return MultiModalProcessorOnlyCache(vllm_config.model_config)
        elif cache_type == "lru":
            return MultiModalProcessorSenderCache(vllm_config.model_config)
        elif cache_type == "shm":
            return ShmObjectStoreSenderCache(vllm_config)
        else:
            raise ValueError(f"Unknown cache type: {cache_type!r}")

    def processor_only_cache_from_config(
        self,
        vllm_config: "VllmConfig",
    ) -> MultiModalProcessorOnlyCache | None:
        """Return a `MultiModalProcessorOnlyCache`, if enabled."""
        cache_type = self._get_cache_type(vllm_config)
        if cache_type is None:
            return None

        return MultiModalProcessorOnlyCache(vllm_config.model_config)

    def engine_receiver_cache_from_config(
        self,
        vllm_config: "VllmConfig",
    ) -> BaseMultiModalReceiverCache | None:
        """Return a `BaseMultiModalReceiverCache` for the engine process."""
        cache_type = self._get_cache_type(vllm_config)
        if cache_type in (None, "processor_only", "shm"):
            return None
        elif cache_type == "lru":
            return MultiModalReceiverCache(vllm_config.model_config)
        else:
            raise ValueError(f"Unknown cache type: {cache_type!r}")

    def worker_receiver_cache_from_config(
        self,
        vllm_config: "VllmConfig",
        shared_worker_lock: LockType,
    ) -> BaseMultiModalReceiverCache | None:
        """Return a `BaseMultiModalReceiverCache` for the worker process."""
        cache_type = self._get_cache_type(vllm_config)
        if cache_type in (None, "processor_only", "lru"):
            return None
        elif cache_type == "shm":
            return ShmObjectStoreReceiverCache(vllm_config, shared_worker_lock)
        else:
            raise ValueError(f"Unknown cache type: {cache_type!r}")
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362


class MultiModalTimingRegistry:
    """
    Per-request store of `TimingContext` objects.

    The store is active only when the observability config is given and its
    `enable_mm_processor_stats` flag is truthy; otherwise every lookup
    yields a disabled context and no statistics are reported.
    """

    def __init__(self, observability_config: "ObservabilityConfig | None") -> None:
        super().__init__()

        stats_requested = bool(
            observability_config and observability_config.enable_mm_processor_stats
        )
        if stats_requested:
            # The lock and the per-request map exist only in enabled mode.
            self._lock = threading.Lock()
            self._ctx_by_request_id = defaultdict[str, TimingContext](TimingContext)
        self._enabled = stats_requested

    def get(self, request_id: str) -> TimingContext:
        """Return the timing context for `request_id`, creating it on first
        access; return a fresh disabled context when stats are off."""
        if self._enabled:
            with self._lock:
                return self._ctx_by_request_id[request_id]
        return TimingContext(enabled=False)

    def stat(self) -> dict[str, dict[str, float]]:
        """Return the stats dict of every tracked request and then forget
        all of them; return an empty dict when stats are off."""
        if not self._enabled:
            return {}

        with self._lock:
            collected: dict[str, dict[str, float]] = {}
            for tracked_id, timing_ctx in self._ctx_by_request_id.items():
                collected[tracked_id] = timing_ctx.get_stats_dict()
            self._ctx_by_request_id.clear()
        return collected