registry.py 15.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from collections.abc import Mapping
4
from dataclasses import dataclass
5
6
from multiprocessing.synchronize import Lock as LockType
from typing import TYPE_CHECKING, Generic, Literal, Protocol, TypeVar, cast
7

8
from vllm.config.multimodal import BaseDummyOptions
9
from vllm.config.observability import ObservabilityConfig
10
from vllm.logger import init_logger
11
from vllm.tokenizers import TokenizerLike, cached_tokenizer_from_config
12

13
14
15
16
17
18
19
20
21
from .cache import (
    BaseMultiModalProcessorCache,
    BaseMultiModalReceiverCache,
    MultiModalProcessorOnlyCache,
    MultiModalProcessorSenderCache,
    MultiModalReceiverCache,
    ShmObjectStoreReceiverCache,
    ShmObjectStoreSenderCache,
)
22
from .inputs import MultiModalInputs
23
from .processing import (
24
    BaseDummyInputsBuilder,
25
26
27
28
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    InputProcessingContext,
)
29

30
if TYPE_CHECKING:
31
    from vllm.config import ModelConfig, ObservabilityConfig, VllmConfig
32
    from vllm.model_executor.models.interfaces import SupportsMultiModal
33

34
35
logger = init_logger(__name__)

# Type of the model class accepted and returned unchanged by the decorator
# produced by `MultiModalRegistry.register_processor`.
N = TypeVar("N", bound=type["SupportsMultiModal"])

# Processing-info type shared by the info / dummy-inputs / processor
# factories that are registered together for one model class.
_I = TypeVar("_I", bound=BaseProcessingInfo)
# Covariant variant, used by `ProcessingInfoFactory`, which only returns it.
_I_co = TypeVar("_I_co", bound=BaseProcessingInfo, covariant=True)
39
40


41
class ProcessingInfoFactory(Protocol[_I_co]):
    """
    Constructs a
    [`BaseProcessingInfo`][vllm.multimodal.processing.BaseProcessingInfo]
    instance from the context.
    """

    def __call__(
        self,
        ctx: InputProcessingContext,
    ) -> _I_co: ...
52
53


54
class DummyInputsBuilderFactory(Protocol[_I]):  # type: ignore[misc]
    """
    Constructs a
    [`BaseDummyInputsBuilder`][vllm.multimodal.processing.BaseDummyInputsBuilder]
    instance from the processing info.
    """

    def __call__(self, info: _I) -> BaseDummyInputsBuilder[_I]: ...
62
63


64
class MultiModalProcessorFactory(Protocol[_I]):  # type: ignore[misc]
    """
    Constructs a
    [`BaseMultiModalProcessor`][vllm.multimodal.processing.BaseMultiModalProcessor]
    instance from the processing info and the dummy inputs builder,
    optionally backed by a processor cache.
    """

    def __call__(
        self,
        info: _I,
        dummy_inputs: BaseDummyInputsBuilder[_I],
        *,
        cache: BaseMultiModalProcessorCache | None = None,
    ) -> BaseMultiModalProcessor[_I]: ...
78

79

80
81
82
83
84
85
86
87
88
89
@dataclass(frozen=True)
class _ProcessorFactories(Generic[_I]):
    """
    The info, processor and dummy-inputs factories registered for one model
    class, kept together so the processor can be built lazily.
    """

    info: ProcessingInfoFactory[_I]
    processor: MultiModalProcessorFactory[_I]
    dummy_inputs: DummyInputsBuilderFactory[_I]

    def build_processor(
        self,
        ctx: InputProcessingContext,
        *,
        cache: BaseMultiModalProcessorCache | None = None,
    ) -> BaseMultiModalProcessor[_I]:
        """
        Build a multi-modal processor for `ctx`.

        The processing info is created from `ctx` first and then used to build
        the dummy inputs builder. Both are handed to the processor factory
        together with the optional `cache`.
        """
        info = self.info(ctx)
        dummy_inputs_builder = self.dummy_inputs(info)
        return self.processor(info, dummy_inputs_builder, cache=cache)


97
98
class MultiModalRegistry:
    """
    A registry that dispatches data processing according to the model.
    """

    def _extract_mm_options(
        self,
        model_config: "ModelConfig",
    ) -> Mapping[str, BaseDummyOptions] | None:
        """
        Extract multimodal dummy options from model config.

        Returns None if no configurable options are found, otherwise returns
        a mapping of modality names to their dummy options.
        """
        mm_config = model_config.multimodal_config
        if not mm_config:
            return None

        # Only modalities listed in `limit_per_prompt` that actually have
        # dummy options configured are kept.
        mm_options = {
            m: opt
            for m in mm_config.limit_per_prompt
            if (opt := mm_config.get_dummy_options(m)) is not None
        }

        return mm_options or None

    def supports_multimodal_inputs(self, model_config: "ModelConfig") -> bool:
        """
        Checks if the model supports multimodal inputs.

        Returns True if the model is multimodal with any non-zero supported
        modalities, otherwise returns False, effectively running in
        text-only mode.
        """
        if not model_config.is_multimodal_model:
            return False

        # With tokenizer=None, `_create_processing_ctx` loads the tokenizer
        # from `model_config`.
        info = self._create_processing_info(model_config, tokenizer=None)
        supported_modalities = info.get_supported_mm_limits()

        mm_config = model_config.get_multimodal_config()

        # Check if all supported modalities have limit == 0
        if all(
            mm_config.get_limit_per_prompt(modality) == 0
            for modality in supported_modalities
        ):
            logger.info_once(
                "All limits of multimodal modalities supported by the model "
                "are set to 0, running in text-only mode."
            )
            return False

        return True

    def get_max_tokens_per_item_by_modality(
        self,
        model_config: "ModelConfig",
        *,
        cache: BaseMultiModalProcessorCache | None = None,
        profiler_limits: Mapping[str, int] | None = None,
        observability_config: ObservabilityConfig | None = None,
    ) -> Mapping[str, int]:
        """
        Get the maximum number of tokens per data item from each modality based
        on underlying model configuration.

        Only modalities with a positive limit in `profiler_limits` (which
        defaults to the processor's `allowed_mm_limits`) are reported. Returns
        an empty mapping for models that are not multimodal.
        """
        if not model_config.is_multimodal_model:
            return {}

        processor = self.create_processor(
            model_config, observability_config, cache=cache
        )

        if profiler_limits is None:
            profiler_limits = processor.allowed_mm_limits

        # Profile a single item for every enabled modality.
        mm_counts = {
            modality: 1 for modality, limit in profiler_limits.items() if limit > 0
        }

        # Use the model's own estimate when it provides one.
        max_tokens_per_item = processor.info.get_mm_max_tokens_per_item(
            seq_len=model_config.max_model_len,
            mm_counts=mm_counts,
        )
        if max_tokens_per_item is not None:
            return {
                modality: max_tokens
                for modality, max_tokens in max_tokens_per_item.items()
                if mm_counts.get(modality, 0) > 0
            }

        # Otherwise run the processor on dummy inputs and count the embeddings
        # covered by the resulting placeholders.
        mm_inputs = self.get_dummy_mm_inputs(
            model_config,
            mm_counts=mm_counts,
            processor=processor,
        )

        return {
            modality: sum(item.get_num_embeds for item in placeholders)
            for modality, placeholders in mm_inputs["mm_placeholders"].items()
        }

    def get_mm_limits_per_prompt(
        self,
        model_config: "ModelConfig",
        *,
        cache: BaseMultiModalProcessorCache | None = None,
        observability_config: ObservabilityConfig | None = None,
    ) -> Mapping[str, int]:
        """
        Get the maximum number of multi-modal input instances for each modality
        that are allowed per prompt for a model class.

        Returns an empty mapping for models that are not multimodal.
        """
        if not model_config.is_multimodal_model:
            return {}

        processor = self.create_processor(
            model_config, observability_config, cache=cache
        )
        return processor.allowed_mm_limits

    def register_processor(
        self,
        processor: MultiModalProcessorFactory[_I],
        *,
        info: ProcessingInfoFactory[_I],
        dummy_inputs: DummyInputsBuilderFactory[_I],
    ):
        """
        Register a multi-modal processor to a model class. The processor
        is constructed lazily, hence a factory method should be passed.

        When the model receives multi-modal data, the provided function is
        invoked to transform the data into a dictionary of model inputs.

        Returns a class decorator that stores the factories on the decorated
        class as `_processor_factory` and returns the class itself.
        """

        def wrapper(model_cls: N) -> N:
            # Only the class's own namespace is checked, so a factory
            # inherited from a base class is replaced without a warning.
            if "_processor_factory" in model_cls.__dict__:
                logger.warning(
                    "Model class %s already has a multi-modal processor "
                    "registered to %s. It is overwritten by the new one.",
                    model_cls,
                    self,
                )

            model_cls._processor_factory = _ProcessorFactories(
                info=info,
                dummy_inputs=dummy_inputs,
                processor=processor,
            )

            return model_cls

        return wrapper

    def _get_model_cls(self, model_config: "ModelConfig") -> "SupportsMultiModal":
        """Resolve the model class and check it has a registered processor."""
        # Avoid circular import
        from vllm.model_executor.model_loader import get_model_architecture

        model_cls, _ = get_model_architecture(model_config)
        assert hasattr(model_cls, "_processor_factory")
        return cast("SupportsMultiModal", model_cls)

    def _create_processing_ctx(
        self,
        model_config: "ModelConfig",
        observability_config: "ObservabilityConfig | None" = None,
        tokenizer: TokenizerLike | None = None,
    ) -> InputProcessingContext:
        """
        Build the processing context for `model_config`.

        If no tokenizer is given, one is obtained via
        `cached_tokenizer_from_config(model_config)`.
        """
        if tokenizer is None:
            tokenizer = cached_tokenizer_from_config(model_config)

        return InputProcessingContext(
            model_config, tokenizer, observability_config=observability_config
        )

    def _create_processing_info(
        self,
        model_config: "ModelConfig",
        observability_config: "ObservabilityConfig | None" = None,
        *,
        tokenizer: TokenizerLike | None = None,
    ) -> BaseProcessingInfo:
        """Create only the processing info registered for the model class."""
        model_cls = self._get_model_cls(model_config)
        factories = model_cls._processor_factory
        ctx = self._create_processing_ctx(model_config, observability_config, tokenizer)
        return factories.info(ctx)

    def create_processor(
        self,
        model_config: "ModelConfig",
        observability_config: "ObservabilityConfig | None" = None,
        *,
        tokenizer: TokenizerLike | None = None,
        cache: BaseMultiModalProcessorCache | None = None,
    ) -> BaseMultiModalProcessor[BaseProcessingInfo]:
        """
        Create a multi-modal processor for a specific model and tokenizer.

        Raises:
            ValueError: If `model_config` does not describe a multimodal model.
        """
        if not model_config.is_multimodal_model:
            raise ValueError(f"{model_config.model} is not a multimodal model")

        model_cls = self._get_model_cls(model_config)
        factories = model_cls._processor_factory

        ctx = self._create_processing_ctx(model_config, observability_config, tokenizer)

        return factories.build_processor(ctx, cache=cache)

    def get_dummy_mm_inputs(
        self,
        model_config: "ModelConfig",
        mm_counts: Mapping[str, int] | None = None,
        *,
        cache: BaseMultiModalProcessorCache | None = None,
        observability_config: ObservabilityConfig | None = None,
        processor: BaseMultiModalProcessor | None = None,
    ) -> MultiModalInputs:
        """
        Create dummy data for profiling the memory usage of a model.

        The model is identified by `model_config`. An existing `processor` may
        be passed to avoid building a new one; `mm_counts` defaults to the
        processor's `allowed_mm_limits`. The returned prompt token ids are
        padded with zeros up to `model_config.max_model_len`.
        """
        seq_len = model_config.max_model_len

        if processor is None:
            processor = self.create_processor(
                model_config, observability_config, cache=cache
            )
        if mm_counts is None:
            mm_counts = processor.allowed_mm_limits

        processor_inputs = processor.dummy_inputs.get_dummy_processor_inputs(
            seq_len=seq_len,
            mm_counts=mm_counts,
            mm_options=self._extract_mm_options(model_config),
        )
        mm_inputs = processor.apply(
            prompt=processor_inputs.prompt,
            mm_data=processor_inputs.mm_data,
            hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs,
            tokenization_kwargs=processor_inputs.tokenization_kwargs,
        )

        # Pad the prompt (in place) with token id 0 up to `seq_len`.
        prompt_token_ids = mm_inputs["prompt_token_ids"]
        total_len = len(prompt_token_ids)
        if total_len < seq_len:
            prompt_token_ids.extend([0] * (seq_len - total_len))

        return mm_inputs

    def get_encdec_max_encoder_len(self, model_config: "ModelConfig") -> int:
        """
        Get the maximum length of the encoder input for encoder-decoder models.

        Returns 0 for models that are not encoder-decoder, or when no
        per-modality token budget is available.
        """
        if not model_config.is_encoder_decoder:
            return 0
        max_tokens = self.get_max_tokens_per_item_by_modality(model_config)
        if not max_tokens:
            # TODO - this function assumes encoder-decoder models are
            # multimodal. This will need to change when adding support for more
            # than whisper.
            return 0
        assert len(max_tokens) == 1, (
            "Encoder-decoder models are expected "
            "to implement the multimodal interface with at most one modality."
        )

        # Exactly one modality remains; its budget is the encoder length.
        return next(iter(max_tokens.values()))

    def _get_cache_type(
        self,
        vllm_config: "VllmConfig",
    ) -> Literal[None, "processor_only", "lru", "shm"]:
        """
        Decide which kind of multi-modal processor cache to use.

        Returns None when the model does not accept multimodal inputs or when
        `mm_processor_cache_gb` is not positive. Returns "processor_only" when
        the cache cannot be shared over IPC. Otherwise returns the configured
        `mm_processor_cache_type`.
        """
        model_config = vllm_config.model_config
        if not self.supports_multimodal_inputs(model_config):
            return None

        # Check if the cache is disabled.
        mm_config = model_config.get_multimodal_config()
        if mm_config.mm_processor_cache_gb <= 0:
            return None

        # Check if IPC caching is supported: it requires a single API process
        # and either no data parallelism or an external load balancer.
        parallel_config = vllm_config.parallel_config
        is_ipc_supported = parallel_config._api_process_count == 1 and (
            parallel_config.data_parallel_size == 1
            or parallel_config.data_parallel_external_lb
        )

        if not is_ipc_supported:
            return "processor_only"

        return mm_config.mm_processor_cache_type

    def processor_cache_from_config(
        self,
        vllm_config: "VllmConfig",
    ) -> BaseMultiModalProcessorCache | None:
        """
        Return a `BaseMultiModalProcessorCache`, if enabled.

        The cache type maps as follows:

        - "processor_only" maps to `MultiModalProcessorOnlyCache`.
        - "lru" maps to `MultiModalProcessorSenderCache`.
        - "shm" maps to `ShmObjectStoreSenderCache`.
        """
        cache_type = self._get_cache_type(vllm_config)
        if cache_type is None:
            return None
        elif cache_type == "processor_only":
            return MultiModalProcessorOnlyCache(vllm_config.model_config)
        elif cache_type == "lru":
            return MultiModalProcessorSenderCache(vllm_config.model_config)
        elif cache_type == "shm":
            return ShmObjectStoreSenderCache(vllm_config)
        else:
            raise ValueError(f"Unknown cache type: {cache_type!r}")

    def processor_only_cache_from_config(
        self,
        vllm_config: "VllmConfig",
    ) -> MultiModalProcessorOnlyCache | None:
        """
        Return a `MultiModalProcessorOnlyCache`, if enabled.

        A processor-only cache is returned for every enabled cache type.
        """
        cache_type = self._get_cache_type(vllm_config)
        if cache_type is None:
            return None

        return MultiModalProcessorOnlyCache(vllm_config.model_config)

    def engine_receiver_cache_from_config(
        self,
        vllm_config: "VllmConfig",
    ) -> BaseMultiModalReceiverCache | None:
        """
        Return a `BaseMultiModalReceiverCache` for the engine process.

        Only the "lru" cache type has an engine-side receiver.
        """
        cache_type = self._get_cache_type(vllm_config)
        if cache_type in (None, "processor_only", "shm"):
            return None
        elif cache_type == "lru":
            return MultiModalReceiverCache(vllm_config.model_config)
        else:
            raise ValueError(f"Unknown cache type: {cache_type!r}")

    def worker_receiver_cache_from_config(
        self,
        vllm_config: "VllmConfig",
        shared_worker_lock: LockType,
    ) -> BaseMultiModalReceiverCache | None:
        """
        Return a `BaseMultiModalReceiverCache` for the worker process.

        Only the "shm" cache type has a worker-side receiver, which is
        constructed with `shared_worker_lock`.
        """
        cache_type = self._get_cache_type(vllm_config)
        if cache_type in (None, "processor_only", "lru"):
            return None
        elif cache_type == "shm":
            return ShmObjectStoreReceiverCache(vllm_config, shared_worker_lock)
        else:
            raise ValueError(f"Unknown cache type: {cache_type!r}")