"docs/getting_started/installation/gpu.cuda.inc.md" did not exist on "dbe7f07001955d6ba745f297203fee0aa0fbc5cf"
registry.py 11.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from collections.abc import Mapping
4
from dataclasses import dataclass
5
6
from multiprocessing.synchronize import Lock as LockType
from typing import TYPE_CHECKING, Generic, Literal, Protocol, TypeVar, cast
7

8
from vllm.config.observability import ObservabilityConfig
9
from vllm.logger import init_logger
10
from vllm.tokenizers import TokenizerLike, cached_tokenizer_from_config
11

12
13
14
15
16
17
18
19
20
from .cache import (
    BaseMultiModalProcessorCache,
    BaseMultiModalReceiverCache,
    MultiModalProcessorOnlyCache,
    MultiModalProcessorSenderCache,
    MultiModalReceiverCache,
    ShmObjectStoreReceiverCache,
    ShmObjectStoreSenderCache,
)
21
from .inputs import MultiModalInputs
22
from .processing import (
23
    BaseDummyInputsBuilder,
24
25
26
27
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    InputProcessingContext,
)
28

29
if TYPE_CHECKING:
30
    from vllm.config import ModelConfig, ObservabilityConfig, VllmConfig
31
    from vllm.model_executor.models.interfaces import SupportsMultiModal
32

33
34
logger = init_logger(__name__)

# Type of the model class passed through the `register_processor` decorator;
# the decorator returns the very same class it receives.
N = TypeVar("N", bound=type["SupportsMultiModal"])

# Processing-info type shared by the factories registered for one model.
# The covariant variant is used where the type only appears as a return value.
_I = TypeVar("_I", bound=BaseProcessingInfo)
_I_co = TypeVar("_I_co", bound=BaseProcessingInfo, covariant=True)
38
39


40
class ProcessingInfoFactory(Protocol[_I_co]):
    """
    Constructs a
    [`BaseProcessingInfo`][vllm.multimodal.processing.BaseProcessingInfo]
    instance from the context.
    """

    def __call__(
        self,
        ctx: InputProcessingContext,
    ) -> _I_co: ...
51
52


53
class DummyInputsBuilderFactory(Protocol[_I]):  # type: ignore[misc]
    """
    Constructs a
    [`BaseDummyInputsBuilder`][vllm.multimodal.processing.BaseDummyInputsBuilder]
    instance from the processing info.
    """

    def __call__(self, info: _I) -> BaseDummyInputsBuilder[_I]: ...
61
62


63
class MultiModalProcessorFactory(Protocol[_I]):  # type: ignore[misc]
    """
    Constructs a
    [`BaseMultiModalProcessor`][vllm.multimodal.processing.BaseMultiModalProcessor]
    instance from the processing info and the dummy inputs builder,
    optionally backed by a processor cache.
    """

    def __call__(
        self,
        info: _I,
        dummy_inputs: BaseDummyInputsBuilder[_I],
        *,
        cache: BaseMultiModalProcessorCache | None = None,
    ) -> BaseMultiModalProcessor[_I]: ...
77

78

79
80
81
82
83
84
85
86
87
88
@dataclass(frozen=True)
class _ProcessorFactories(Generic[_I]):
    """
    Immutable bundle of the three factories registered for a model class.

    The factories are chained by `build_processor`: the processing info is
    built from the context, the dummy inputs builder from the info, and the
    processor from both.
    """

    # Builds the processing info from an `InputProcessingContext`.
    info: ProcessingInfoFactory[_I]
    # Builds the processor from the info and the dummy inputs builder.
    processor: MultiModalProcessorFactory[_I]
    # Builds the dummy inputs builder from the info.
    dummy_inputs: DummyInputsBuilderFactory[_I]

    def build_processor(
        self,
        ctx: InputProcessingContext,
        *,
        cache: BaseMultiModalProcessorCache | None = None,
    ) -> BaseMultiModalProcessor[_I]:
        """
        Run the factories in dependency order and return the processor.

        Args:
            ctx: Context handed to the `info` factory.
            cache: Optional processor cache forwarded to the `processor`
                factory.

        Returns:
            The processor created by the `processor` factory.
        """
        info = self.info(ctx)
        dummy_inputs_builder = self.dummy_inputs(info)
        return self.processor(info, dummy_inputs_builder, cache=cache)


96
97
class MultiModalRegistry:
    """
98
    A registry that dispatches data processing according to the model.
99
100
    """

101
    def supports_multimodal_inputs(self, model_config: "ModelConfig") -> bool:
        """
        Checks if the model supports multimodal inputs.

        Returns True if the model is multimodal with any non-zero supported
        modalities, otherwise returns False, effectively running in
        text-only mode.

        As an exception, True is also returned when every supported modality
        has a limit of 0 but `enable_mm_embeds` is set on the multimodal
        config, because pre-computed embeddings still have to be processed.
        """
        if not model_config.is_multimodal_model:
            return False

        mm_config = model_config.get_multimodal_config()
        info = self._create_processing_info(model_config, tokenizer=None)

        # Check if all supported modalities have limit == 0
        # (this is vacuously true when the model reports no modalities).
        if all(
            mm_config.get_limit_per_prompt(modality) == 0
            for modality in info.supported_mm_limits
        ):
            # If enable_mm_embeds is True, we still need MM infrastructure
            # to process pre-computed embeddings even though encoder won't run
            if mm_config.enable_mm_embeds:
                return True

            logger.info_once(
                "All limits of multimodal modalities supported by the model "
                "are set to 0, running in text-only mode."
            )
            return False

        return True

132
133
    def register_processor(
        self,
        processor: MultiModalProcessorFactory[_I],
        *,
        info: ProcessingInfoFactory[_I],
        dummy_inputs: DummyInputsBuilderFactory[_I],
    ):
        """
        Register a multi-modal processor to a model class. The processor
        is constructed lazily, hence a factory method should be passed.

        When the model receives multi-modal data, the provided function is
        invoked to transform the data into a dictionary of model inputs.

        Args:
            processor: Factory that builds the processor.
            info: Factory that builds the processing info.
            dummy_inputs: Factory that builds the dummy inputs builder.

        Returns:
            A class decorator that stores the factories on the decorated
            model class as `_processor_factory` and returns that same class.
        """

        def wrapper(model_cls: N) -> N:
            # Only look at the class' own namespace so that a registration
            # inherited from a base class does not trigger the warning.
            if "_processor_factory" in model_cls.__dict__:
                logger.warning(
                    "Model class %s already has a multi-modal processor "
                    "registered to %s. It is overwritten by the new one.",
                    model_cls,
                    self,
                )

            model_cls._processor_factory = _ProcessorFactories(
                info=info,
                dummy_inputs=dummy_inputs,
                processor=processor,
            )

            return model_cls

        return wrapper

166
    def _get_model_cls(self, model_config: "ModelConfig") -> "SupportsMultiModal":
        """
        Resolve the model class for `model_config`.

        Raises:
            AssertionError: If the resolved class has no `_processor_factory`
                attribute, i.e. no processor was registered for it.
        """
        # Avoid circular import
        from vllm.model_executor.model_loader import get_model_architecture

        model_cls, _ = get_model_architecture(model_config)
        assert hasattr(model_cls, "_processor_factory"), (
            f"{model_cls} has no multi-modal processor registered"
        )
        return cast("SupportsMultiModal", model_cls)
173

174
175
176
    def _create_processing_ctx(
        self,
        model_config: "ModelConfig",
        observability_config: "ObservabilityConfig | None" = None,
        tokenizer: TokenizerLike | None = None,
    ) -> InputProcessingContext:
        """
        Build the `InputProcessingContext` for `model_config`.

        Args:
            model_config: Configuration of the model being processed.
            observability_config: Optional observability settings forwarded
                to the context.
            tokenizer: Tokenizer to use. When None, it is obtained from
                `cached_tokenizer_from_config(model_config)`.
        """
        if tokenizer is None:
            tokenizer = cached_tokenizer_from_config(model_config)

        return InputProcessingContext(
            model_config, tokenizer, observability_config=observability_config
        )
186

187
188
    def _create_processing_info(
        self,
        model_config: "ModelConfig",
        observability_config: "ObservabilityConfig | None" = None,
        *,
        tokenizer: TokenizerLike | None = None,
    ) -> BaseProcessingInfo:
        """
        Create the processing info for a specific model and tokenizer
        using the `info` factory registered on the model class.
        """
        model_cls = self._get_model_cls(model_config)
        factories = model_cls._processor_factory
        ctx = self._create_processing_ctx(model_config, observability_config, tokenizer)
        return factories.info(ctx)

199
200
    def create_processor(
        self,
        model_config: "ModelConfig",
        observability_config: "ObservabilityConfig | None" = None,
        *,
        tokenizer: TokenizerLike | None = None,
        cache: BaseMultiModalProcessorCache | None = None,
    ) -> BaseMultiModalProcessor[BaseProcessingInfo]:
        """
        Create a multi-modal processor for a specific model and tokenizer.

        Args:
            model_config: Configuration of the model being processed.
            observability_config: Optional observability settings forwarded
                to the processing context.
            tokenizer: Tokenizer forwarded to the processing context.
            cache: Optional processor cache handed to the processor factory.

        Raises:
            ValueError: If `model_config` does not describe a multimodal
                model.
        """
        if not model_config.is_multimodal_model:
            raise ValueError(f"{model_config.model} is not a multimodal model")

        model_cls = self._get_model_cls(model_config)
        factories = model_cls._processor_factory

        ctx = self._create_processing_ctx(model_config, observability_config, tokenizer)

        return factories.build_processor(ctx, cache=cache)
219

220
    def get_dummy_mm_inputs(
        self,
        model_config: "ModelConfig",
        mm_counts: Mapping[str, int],
        *,
        cache: BaseMultiModalProcessorCache | None = None,
        processor: BaseMultiModalProcessor | None = None,
    ) -> MultiModalInputs:
        """
        Create dummy data for profiling the memory usage of a model.

        The model is identified by `model_config`.

        Args:
            model_config: Configuration of the model being profiled.
            mm_counts: Mapping from modality name to an item count, passed
                through to the dummy inputs builder.
            cache: Processor cache used only when a new processor has to be
                created, i.e. it is ignored if `processor` is given.
            processor: Existing processor to reuse instead of creating one.

        Returns:
            The processed dummy inputs, whose `prompt_token_ids` are padded
            in place with 0 up to `model_config.max_model_len`.
        """
        seq_len = model_config.max_model_len

        if processor is None:
            processor = self.create_processor(model_config, cache=cache)

        mm_config = model_config.get_multimodal_config()
        processor_inputs = processor.dummy_inputs.get_dummy_processor_inputs(
            seq_len=seq_len,
            mm_counts=mm_counts,
            mm_options=mm_config.limit_per_prompt,
        )
        mm_inputs = processor.apply(
            prompt=processor_inputs.prompt,
            mm_items=processor_inputs.mm_items,
            hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs,
            tokenization_kwargs=processor_inputs.tokenization_kwargs,
        )

        # Pad the prompt so that it always spans the full model length.
        prompt_token_ids = mm_inputs["prompt_token_ids"]
        total_len = len(prompt_token_ids)
        if total_len < seq_len:
            prompt_token_ids.extend([0] * (seq_len - total_len))

        return mm_inputs
257

258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
    def _get_cache_type(
        self,
        vllm_config: "VllmConfig",
    ) -> Literal[None, "processor_only", "lru", "shm"]:
        model_config = vllm_config.model_config
        if not self.supports_multimodal_inputs(model_config):
            return None

        # Check if the cache is disabled.
        mm_config = model_config.get_multimodal_config()
        if mm_config.mm_processor_cache_gb <= 0:
            return None

        # Check if IPC caching is supported.
        parallel_config = vllm_config.parallel_config
        is_ipc_supported = parallel_config._api_process_count == 1 and (
            parallel_config.data_parallel_size == 1
            or parallel_config.data_parallel_external_lb
        )

        if not is_ipc_supported:
            return "processor_only"

        mm_config = model_config.get_multimodal_config()
        return mm_config.mm_processor_cache_type

    def processor_cache_from_config(
        self,
        vllm_config: "VllmConfig",
    ) -> BaseMultiModalProcessorCache | None:
        """Return a `BaseMultiModalProcessorCache`, if enabled."""
        cache_type = self._get_cache_type(vllm_config)
        if cache_type is None:
            return None
        elif cache_type == "processor_only":
            return MultiModalProcessorOnlyCache(vllm_config.model_config)
        elif cache_type == "lru":
            return MultiModalProcessorSenderCache(vllm_config.model_config)
        elif cache_type == "shm":
            return ShmObjectStoreSenderCache(vllm_config)
        else:
            raise ValueError(f"Unknown cache type: {cache_type!r}")

    def processor_only_cache_from_config(
        self,
        vllm_config: "VllmConfig",
    ) -> MultiModalProcessorOnlyCache | None:
        """
        Build a `MultiModalProcessorOnlyCache` whenever caching is enabled,
        regardless of which cache type is configured; None otherwise.
        """
        caching_disabled = self._get_cache_type(vllm_config) is None
        if caching_disabled:
            return None

        model_config = vllm_config.model_config
        return MultiModalProcessorOnlyCache(model_config)

    def engine_receiver_cache_from_config(
        self,
        vllm_config: "VllmConfig",
    ) -> BaseMultiModalReceiverCache | None:
        """
        Build the receiver cache used by the engine process.

        Only the "lru" cache type has an engine-side receiver; the other
        known types yield None and anything else raises `ValueError`.
        """
        types_without_engine_cache = (None, "processor_only", "shm")

        cache_type = self._get_cache_type(vllm_config)
        if cache_type == "lru":
            return MultiModalReceiverCache(vllm_config.model_config)
        if cache_type in types_without_engine_cache:
            return None

        raise ValueError(f"Unknown cache type: {cache_type!r}")

    def worker_receiver_cache_from_config(
        self,
        vllm_config: "VllmConfig",
        shared_worker_lock: LockType,
    ) -> BaseMultiModalReceiverCache | None:
        """
        Build the receiver cache used by a worker process.

        Only the "shm" cache type has a worker-side receiver, which is
        created with `shared_worker_lock`; the other known types yield None
        and anything else raises `ValueError`.
        """
        types_without_worker_cache = (None, "processor_only", "lru")

        cache_type = self._get_cache_type(vllm_config)
        if cache_type == "shm":
            return ShmObjectStoreReceiverCache(vllm_config, shared_worker_lock)
        if cache_type in types_without_worker_cache:
            return None

        raise ValueError(f"Unknown cache type: {cache_type!r}")