"requirements/test/cuda.in" did not exist on "e34d130c1613dbabc708cd5f059506c887ac81b4"
registry.py 12.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
import threading
from collections import defaultdict
5
from collections.abc import Mapping
6
from dataclasses import dataclass
7
8
from multiprocessing.synchronize import Lock as LockType
from typing import TYPE_CHECKING, Generic, Literal, Protocol, TypeVar, cast
9

10
from vllm.inputs import MultiModalInput
11
from vllm.logger import init_logger
12
from vllm.tokenizers import TokenizerLike, cached_tokenizer_from_config
13

14
15
16
17
18
19
20
21
22
from .cache import (
    BaseMultiModalProcessorCache,
    BaseMultiModalReceiverCache,
    MultiModalProcessorOnlyCache,
    MultiModalProcessorSenderCache,
    MultiModalReceiverCache,
    ShmObjectStoreReceiverCache,
    ShmObjectStoreSenderCache,
)
23
from .processing import (
24
    BaseDummyInputsBuilder,
25
26
27
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    InputProcessingContext,
28
    TimingContext,
29
)
30

31
if TYPE_CHECKING:
32
    from vllm.config import ModelConfig, ObservabilityConfig, VllmConfig
33
    from vllm.model_executor.models.interfaces import SupportsMultiModal
34

35
36
# Bound to the model class passed to `MultiModalRegistry.register_processor`,
# so the returned decorator hands back exactly the type it was given.
N = TypeVar("N", bound=type["SupportsMultiModal"])

# Type parameters for the processing-info object threaded through the factory
# protocols below. `_I_co` is the covariant flavour, used by
# `ProcessingInfoFactory`, whose `__call__` only returns this type.
_I = TypeVar("_I", bound=BaseProcessingInfo)
_I_co = TypeVar("_I_co", bound=BaseProcessingInfo, covariant=True)

logger = init_logger(__name__)
40
41


42
class ProcessingInfoFactory(Protocol[_I_co]):
    """
    Constructs a
    [`BaseProcessingInfo`][vllm.multimodal.processing.BaseProcessingInfo]
    instance from the context.

    The returned info object is later passed to a
    `DummyInputsBuilderFactory` and a `MultiModalProcessorFactory`.
    """

    def __call__(
        self,
        ctx: InputProcessingContext,
    ) -> _I_co: ...
53
54


55
class DummyInputsBuilderFactory(Protocol[_I]):  # type: ignore[misc]
    """
    Constructs a
    [`BaseDummyInputsBuilder`][vllm.multimodal.processing.BaseDummyInputsBuilder]
    instance from the processing info.
    """

    def __call__(self, info: _I) -> BaseDummyInputsBuilder[_I]: ...
63
64


65
class MultiModalProcessorFactory(Protocol[_I]):  # type: ignore[misc]
    """
    Constructs a
    [`BaseMultiModalProcessor`][vllm.multimodal.processing.BaseMultiModalProcessor]
    instance from the processing info and the dummy-inputs builder,
    optionally backed by a processor cache.
    """

    def __call__(
        self,
        info: _I,
        dummy_inputs: BaseDummyInputsBuilder[_I],
        *,
        cache: BaseMultiModalProcessorCache | None = None,
    ) -> BaseMultiModalProcessor[_I]: ...
79

80

81
82
83
84
85
86
87
88
89
90
@dataclass(frozen=True)
class _ProcessorFactories(Generic[_I]):
    """
    The three factories registered for a model class through
    `MultiModalRegistry.register_processor`, stored together so the
    processor can be built lazily.
    """

    info: ProcessingInfoFactory[_I]
    processor: MultiModalProcessorFactory[_I]
    dummy_inputs: DummyInputsBuilderFactory[_I]

    def build_processor(
        self,
        ctx: InputProcessingContext,
        *,
        cache: BaseMultiModalProcessorCache | None = None,
    ) -> BaseMultiModalProcessor[_I]:
        """
        Build a multi-modal processor for `ctx`.

        The processing info is created first. The dummy-inputs builder is
        derived from that info, and the processor receives both (plus the
        optional `cache`).
        """
        info = self.info(ctx)
        dummy_inputs_builder = self.dummy_inputs(info)
        return self.processor(info, dummy_inputs_builder, cache=cache)


98
99
class MultiModalRegistry:
    """
    A registry that dispatches data processing according to the model.

    Processor factories are attached to model classes via
    `register_processor`. The remaining methods look those factories up from
    a `ModelConfig` to build processing info, processors, dummy profiling
    inputs, and processor/receiver caches.
    """

    def supports_multimodal_inputs(self, model_config: "ModelConfig") -> bool:
        """
        Checks if the model supports multimodal inputs.

        Returns True if the model is multimodal and either at least one of
        its supported modalities has a non-zero per-prompt limit, or
        `enable_mm_embeds` is set (pre-computed embeddings still need the
        multimodal infrastructure). Otherwise returns False, effectively
        running in text-only mode.
        """
        if not model_config.is_multimodal_model:
            return False

        mm_config = model_config.get_multimodal_config()
        info = self._create_processing_info(model_config, tokenizer=None)

        # Check if all supported modalities have limit == 0
        if all(
            mm_config.get_limit_per_prompt(modality) == 0
            for modality in info.supported_mm_limits
        ):
            # If enable_mm_embeds is True, we still need MM infrastructure
            # to process pre-computed embeddings even though encoder won't run
            if mm_config.enable_mm_embeds:
                return True

            logger.info_once(
                "All limits of multimodal modalities supported by the model "
                "are set to 0, running in text-only mode."
            )
            return False

        return True

    def register_processor(
        self,
        processor: MultiModalProcessorFactory[_I],
        *,
        info: ProcessingInfoFactory[_I],
        dummy_inputs: DummyInputsBuilderFactory[_I],
    ):
        """
        Register a multi-modal processor to a model class. The processor
        is constructed lazily, hence a factory method should be passed.

        When the model receives multi-modal data, the provided function is
        invoked to transform the data into a dictionary of model inputs.

        Returns a class decorator. The decorator stores the three factories
        on the decorated class as `_processor_factory` and returns the same
        class. Registering the same class twice logs a warning and replaces
        the earlier factories.
        """

        def wrapper(model_cls: N) -> N:
            # Checking the class's own `__dict__` (rather than `hasattr`)
            # ignores factories inherited from a base class. Only a
            # re-registration of this exact class triggers the warning.
            if "_processor_factory" in model_cls.__dict__:
                logger.warning(
                    "Model class %s already has a multi-modal processor "
                    "registered to %s. It is overwritten by the new one.",
                    model_cls,
                    self,
                )

            model_cls._processor_factory = _ProcessorFactories(
                info=info,
                dummy_inputs=dummy_inputs,
                processor=processor,
            )

            return model_cls

        return wrapper

    def _get_model_cls(self, model_config: "ModelConfig") -> "SupportsMultiModal":
        """Resolve the model class for `model_config` and check that it has
        a registered multi-modal processor."""
        # Avoid circular import
        from vllm.model_executor.model_loader import get_model_architecture

        model_cls, _ = get_model_architecture(model_config)
        assert hasattr(model_cls, "_processor_factory"), (
            f"{model_cls} has no multi-modal processor registered; register "
            "one with `MultiModalRegistry.register_processor`."
        )
        return cast("SupportsMultiModal", model_cls)

    def _create_processing_ctx(
        self,
        model_config: "ModelConfig",
        tokenizer: TokenizerLike | None = None,
    ) -> InputProcessingContext:
        """Build an `InputProcessingContext`.

        If `tokenizer` is None, the tokenizer is obtained from the model
        config.
        """
        if tokenizer is None:
            tokenizer = cached_tokenizer_from_config(model_config)

        return InputProcessingContext(model_config, tokenizer)

    def _create_processing_info(
        self,
        model_config: "ModelConfig",
        tokenizer: TokenizerLike | None = None,
    ) -> BaseProcessingInfo:
        """Build the processing info via the model's registered info factory."""
        model_cls = self._get_model_cls(model_config)
        factories = model_cls._processor_factory
        ctx = self._create_processing_ctx(model_config, tokenizer)

        return factories.info(ctx)

    def get_processing_info(self, model_config: "ModelConfig") -> BaseProcessingInfo:
        """Return the processing info for the model, using the tokenizer
        obtained from `model_config`."""
        return self._create_processing_info(model_config, tokenizer=None)

    def create_processor(
        self,
        model_config: "ModelConfig",
        *,
        tokenizer: TokenizerLike | None = None,
        cache: BaseMultiModalProcessorCache | None = None,
    ) -> BaseMultiModalProcessor[BaseProcessingInfo]:
        """
        Create a multi-modal processor for a specific model and tokenizer.

        Raises:
            ValueError: If `model_config` is not a multimodal model.
        """
        if not model_config.is_multimodal_model:
            raise ValueError(f"{model_config.model} is not a multimodal model")

        model_cls = self._get_model_cls(model_config)
        factories = model_cls._processor_factory

        ctx = self._create_processing_ctx(model_config, tokenizer)

        return factories.build_processor(ctx, cache=cache)

    def get_dummy_mm_inputs(
        self,
        model_config: "ModelConfig",
        mm_counts: Mapping[str, int],
        *,
        cache: BaseMultiModalProcessorCache | None = None,
        processor: BaseMultiModalProcessor | None = None,
    ) -> MultiModalInput:
        """
        Create dummy data for profiling the memory usage of a model.

        The model is identified by `model_config`. If `processor` is not
        given, one is created with the optional `cache`. The resulting
        `prompt_token_ids` are right-padded with 0 up to
        `model_config.max_model_len`. Longer prompts are left as-is.
        """
        seq_len = model_config.max_model_len

        if processor is None:
            processor = self.create_processor(model_config, cache=cache)

        mm_config = model_config.get_multimodal_config()
        processor_inputs = processor.dummy_inputs.get_dummy_processor_inputs(
            seq_len=seq_len,
            mm_counts=mm_counts,
            mm_options=mm_config.limit_per_prompt,
        )
        # Timing is irrelevant for profiling inputs.
        mm_inputs = processor.apply(
            processor_inputs,
            timing_ctx=TimingContext(enabled=False),
        )

        # Pad in place so the returned inputs span the full sequence length.
        prompt_token_ids = mm_inputs["prompt_token_ids"]
        total_len = len(prompt_token_ids)
        if total_len < seq_len:
            prompt_token_ids.extend([0] * (seq_len - total_len))

        return mm_inputs

    def _get_cache_type(
        self,
        vllm_config: "VllmConfig",
    ) -> Literal[None, "processor_only", "lru", "shm"]:
        """
        Decide which multi-modal processor cache to use.

        Returns:
            None if the model does not accept multimodal inputs or the
            cache size (`mm_processor_cache_gb`) is <= 0. `"processor_only"`
            if IPC caching is unsupported, i.e. there is more than one API
            process, or data parallelism without external load balancing.
            Otherwise the configured `mm_processor_cache_type`.
        """
        model_config = vllm_config.model_config
        if not self.supports_multimodal_inputs(model_config):
            return None

        # Check if the cache is disabled.
        mm_config = model_config.get_multimodal_config()
        if mm_config.mm_processor_cache_gb <= 0:
            return None

        # Check if IPC caching is supported.
        parallel_config = vllm_config.parallel_config
        is_ipc_supported = parallel_config._api_process_count == 1 and (
            parallel_config.data_parallel_size == 1
            or parallel_config.data_parallel_external_lb
        )

        if not is_ipc_supported:
            return "processor_only"

        return mm_config.mm_processor_cache_type

    def processor_cache_from_config(
        self,
        vllm_config: "VllmConfig",
    ) -> BaseMultiModalProcessorCache | None:
        """Return a `BaseMultiModalProcessorCache`, if enabled.

        The concrete class depends on the cache type chosen by
        `_get_cache_type`.
        """
        cache_type = self._get_cache_type(vllm_config)
        if cache_type is None:
            return None
        elif cache_type == "processor_only":
            return MultiModalProcessorOnlyCache(vllm_config.model_config)
        elif cache_type == "lru":
            return MultiModalProcessorSenderCache(vllm_config.model_config)
        elif cache_type == "shm":
            return ShmObjectStoreSenderCache(vllm_config)
        else:
            raise ValueError(f"Unknown cache type: {cache_type!r}")

    def processor_only_cache_from_config(
        self,
        vllm_config: "VllmConfig",
    ) -> MultiModalProcessorOnlyCache | None:
        """Return a `MultiModalProcessorOnlyCache`, if enabled.

        Any enabled cache type yields a processor-only cache here.
        """
        cache_type = self._get_cache_type(vllm_config)
        if cache_type is None:
            return None

        return MultiModalProcessorOnlyCache(vllm_config.model_config)

    def engine_receiver_cache_from_config(
        self,
        vllm_config: "VllmConfig",
    ) -> BaseMultiModalReceiverCache | None:
        """Return a `BaseMultiModalReceiverCache` for the engine process.

        Only the `"lru"` cache type has an engine-side receiver.
        """
        cache_type = self._get_cache_type(vllm_config)
        if cache_type in (None, "processor_only", "shm"):
            return None
        elif cache_type == "lru":
            return MultiModalReceiverCache(vllm_config.model_config)
        else:
            raise ValueError(f"Unknown cache type: {cache_type!r}")

    def worker_receiver_cache_from_config(
        self,
        vllm_config: "VllmConfig",
        shared_worker_lock: LockType,
    ) -> BaseMultiModalReceiverCache | None:
        """Return a `BaseMultiModalReceiverCache` for the worker process.

        Only the `"shm"` cache type has a worker-side receiver. It is
        given `shared_worker_lock`.
        """
        cache_type = self._get_cache_type(vllm_config)
        if cache_type in (None, "processor_only", "lru"):
            return None
        elif cache_type == "shm":
            return ShmObjectStoreReceiverCache(vllm_config, shared_worker_lock)
        else:
            raise ValueError(f"Unknown cache type: {cache_type!r}")
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365


class MultiModalTimingRegistry:
    """
    Per-request `TimingContext` store for multimodal processing.

    Contexts are only tracked when the observability config has
    `enable_mm_processor_stats` set. Otherwise every lookup returns a fresh
    disabled context and no stats are reported. Access to the tracked
    contexts is guarded by a `threading.Lock`.
    """

    def __init__(self, observability_config: "ObservabilityConfig | None") -> None:
        super().__init__()

        enabled = bool(
            observability_config
            and observability_config.enable_mm_processor_stats
        )
        self._enabled = enabled
        if enabled:
            self._lock = threading.Lock()
            # A context is created on first lookup of each request id.
            self._ctx_by_request_id = defaultdict[str, TimingContext](TimingContext)

    def get(self, request_id: str) -> TimingContext:
        """Return the timing context for `request_id`.

        When tracking is disabled, a throwaway disabled context is returned.
        """
        if self._enabled:
            with self._lock:
                return self._ctx_by_request_id[request_id]
        return TimingContext(enabled=False)

    def stat(self) -> dict[str, dict[str, float]]:
        """Return per-request stats collected so far and forget them.

        Returns an empty dict when tracking is disabled.
        """
        if not self._enabled:
            return {}

        with self._lock:
            snapshot = {
                rid: timing.get_stats_dict()
                for rid, timing in self._ctx_by_request_id.items()
            }
            self._ctx_by_request_id.clear()
        return snapshot