registry.py 15.1 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from collections.abc import Mapping
4
from dataclasses import dataclass
5
6
from multiprocessing.synchronize import Lock as LockType
from typing import TYPE_CHECKING, Generic, Literal, Protocol, TypeVar, cast
7

8
from vllm.config.multimodal import BaseDummyOptions
9
from vllm.config.observability import ObservabilityConfig
10
from vllm.logger import init_logger
11
from vllm.tokenizers import TokenizerLike, cached_tokenizer_from_config
12

13
14
15
16
17
18
19
20
21
from .cache import (
    BaseMultiModalProcessorCache,
    BaseMultiModalReceiverCache,
    MultiModalProcessorOnlyCache,
    MultiModalProcessorSenderCache,
    MultiModalReceiverCache,
    ShmObjectStoreReceiverCache,
    ShmObjectStoreSenderCache,
)
22
from .inputs import MultiModalInputs
23
from .processing import (
24
    BaseDummyInputsBuilder,
25
26
27
28
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    InputProcessingContext,
)
29

30
if TYPE_CHECKING:
31
    from vllm.config import ModelConfig, ObservabilityConfig, VllmConfig
32
    from vllm.model_executor.models.interfaces import SupportsMultiModal
33

34
35
logger = init_logger(__name__)

# Type parameters for the factory protocols. `_I` is the invariant
# processing-info type; `_I_co` is its covariant counterpart, used by
# `ProcessingInfoFactory` where the info type only appears as a return value.
_I = TypeVar("_I", bound=BaseProcessingInfo)
_I_co = TypeVar("_I_co", bound=BaseProcessingInfo, covariant=True)

# Model-class type accepted and returned by the decorator that
# `MultiModalRegistry.register_processor` produces.
N = TypeVar("N", bound=type["SupportsMultiModal"])
39
40


41
class ProcessingInfoFactory(Protocol[_I_co]):
    """
    Constructs a
    [`BaseProcessingInfo`][vllm.multimodal.processing.BaseProcessingInfo]
    instance from the context.

    The returned info object is then passed to the dummy inputs builder
    factory and the processor factory.
    """

    def __call__(
        self,
        ctx: InputProcessingContext,
    ) -> _I_co: ...
52
53


54
class DummyInputsBuilderFactory(Protocol[_I]):  # type: ignore[misc]
    """
    Constructs a
    [`BaseDummyInputsBuilder`][vllm.multimodal.processing.BaseDummyInputsBuilder]
    instance from the processing info.
    """

    def __call__(self, info: _I) -> BaseDummyInputsBuilder[_I]: ...
62
63


64
class MultiModalProcessorFactory(Protocol[_I]):  # type: ignore[misc]
    """
    Constructs a
    [`BaseMultiModalProcessor`][vllm.multimodal.processing.BaseMultiModalProcessor]
    instance from the processing info and the dummy inputs builder,
    optionally backed by a processor cache.
    """

    def __call__(
        self,
        info: _I,
        dummy_inputs: BaseDummyInputsBuilder[_I],
        *,
        cache: BaseMultiModalProcessorCache | None = None,
    ) -> BaseMultiModalProcessor[_I]: ...
78

79

80
81
82
83
84
85
86
87
88
89
@dataclass(frozen=True)
class _ProcessorFactories(Generic[_I]):
    """
    Immutable bundle of the three factories needed to build a multi-modal
    processor for a model class.

    `MultiModalRegistry.register_processor` stores an instance of this class
    on the model class as `_processor_factory`.
    """

    # Builds the processing info from an `InputProcessingContext`.
    info: ProcessingInfoFactory[_I]
    # Builds the processor from the info and the dummy inputs builder.
    processor: MultiModalProcessorFactory[_I]
    # Builds the dummy inputs builder from the info.
    dummy_inputs: DummyInputsBuilderFactory[_I]

    def build_processor(
        self,
        ctx: InputProcessingContext,
        *,
        cache: BaseMultiModalProcessorCache | None = None,
    ) -> BaseMultiModalProcessor[_I]:
        """
        Instantiate the info, the dummy inputs builder and the processor,
        in that order, since each depends on the previous one.

        Args:
            ctx: Context passed to the processing-info factory.
            cache: Optional processor cache forwarded to the processor factory.

        Returns:
            The newly constructed multi-modal processor.
        """
        info = self.info(ctx)
        dummy_inputs_builder = self.dummy_inputs(info)
        return self.processor(info, dummy_inputs_builder, cache=cache)


97
98
class MultiModalRegistry:
    """
    A registry that dispatches data processing according to the model.

    Processor factories are attached to model classes with
    `register_processor` and are instantiated lazily from a `ModelConfig`
    whenever a processor, processing info, or dummy inputs are requested.
    """

    def _extract_mm_options(
        self,
        model_config: "ModelConfig",
    ) -> Mapping[str, BaseDummyOptions] | None:
        """
        Extract multimodal dummy options from model config.

        Returns None if no configurable options are found, otherwise returns
        a mapping of modality names to their dummy options.
        """
        mm_config = model_config.multimodal_config
        if not mm_config:
            return None

        mm_options = {
            modality: opt
            for modality in mm_config.limit_per_prompt
            if (opt := mm_config.get_dummy_options(modality)) is not None
        }

        return mm_options or None

    def supports_multimodal_inputs(self, model_config: "ModelConfig") -> bool:
        """
        Checks if the model supports multimodal inputs.

        Returns True if the model is multimodal with any non-zero supported
        modalities, otherwise returns False, effectively running in
        text-only mode.
        """
        if not model_config.is_multimodal_model:
            return False

        info = self._create_processing_info(model_config, tokenizer=None)
        supported_modalities = info.get_supported_mm_limits()

        mm_config = model_config.get_multimodal_config()

        # Check if all supported modalities have limit == 0
        if all(
            mm_config.get_limit_per_prompt(modality) == 0
            for modality in supported_modalities
        ):
            logger.info_once(
                "All limits of multimodal modalities supported by the model "
                "are set to 0, running in text-only mode."
            )
            return False

        return True

    def get_max_tokens_per_item_by_modality(
        self,
        model_config: "ModelConfig",
        *,
        cache: BaseMultiModalProcessorCache | None = None,
        profiler_limits: Mapping[str, int] | None = None,
        observability_config: ObservabilityConfig | None = None,
    ) -> Mapping[str, int]:
        """
        Get the maximum number of tokens per data item from each modality based
        on underlying model configuration.

        Args:
            model_config: Configuration of the model.
            cache: Optional processor cache passed to `create_processor`.
            profiler_limits: Per-modality item limits. Only modalities with a
                limit greater than 0 are considered. Defaults to the
                processor's `allowed_mm_limits`.
            observability_config: Optional observability settings forwarded
                to the processing context.

        Returns:
            A mapping from modality to its maximum tokens per item; an empty
            mapping if the model is not multimodal.
        """
        if not model_config.is_multimodal_model:
            return {}

        processor = self.create_processor(
            model_config, observability_config, cache=cache
        )

        if profiler_limits is None:
            profiler_limits = processor.info.allowed_mm_limits

        # One item per enabled modality is enough to measure per-item size.
        mm_counts = {
            modality: 1 for modality, limit in profiler_limits.items() if limit > 0
        }

        # Use the maximum reported by the processing info when it provides one.
        max_tokens_per_item = processor.info.get_mm_max_tokens_per_item(
            seq_len=model_config.max_model_len,
            mm_counts=mm_counts,
        )
        if max_tokens_per_item is not None:
            return {
                modality: max_tokens
                for modality, max_tokens in max_tokens_per_item.items()
                if mm_counts.get(modality, 0) > 0
            }

        # Otherwise run the processor on dummy inputs and count the embeddings
        # in each modality's placeholders.
        mm_inputs = self.get_dummy_mm_inputs(
            model_config,
            mm_counts=mm_counts,
            processor=processor,
        )

        # NOTE(review): `get_num_embeds` is read as an attribute (not called);
        # presumably a property on the placeholder type — verify.
        return {
            modality: sum(item.get_num_embeds for item in placeholders)
            for modality, placeholders in mm_inputs["mm_placeholders"].items()
        }

    def get_mm_limits_per_prompt(
        self,
        model_config: "ModelConfig",
        *,
        observability_config: ObservabilityConfig | None = None,
    ) -> Mapping[str, int]:
        """
        Get the maximum number of multi-modal input instances for each modality
        that are allowed per prompt for a model class.

        Returns an empty mapping if the model is not multimodal.
        """
        if not model_config.is_multimodal_model:
            return {}

        info = self._create_processing_info(model_config, observability_config)
        return info.allowed_mm_limits

    def register_processor(
        self,
        processor: MultiModalProcessorFactory[_I],
        *,
        info: ProcessingInfoFactory[_I],
        dummy_inputs: DummyInputsBuilderFactory[_I],
    ):
        """
        Register a multi-modal processor to a model class. The processor
        is constructed lazily, hence a factory method should be passed.

        When the model receives multi-modal data, the provided function is
        invoked to transform the data into a dictionary of model inputs.

        Args:
            processor: Factory building the processor from info and dummy
                inputs builder.
            info: Factory building the processing info from a context.
            dummy_inputs: Factory building the dummy inputs builder from info.

        Returns:
            A class decorator that stores the factories on the model class
            (replacing, with a warning, any factories already defined directly
            on that class) and returns the class unchanged otherwise.
        """

        def wrapper(model_cls: N) -> N:
            # Only check the class's own __dict__ so that subclasses can
            # override a processor inherited from a parent without a warning.
            if "_processor_factory" in model_cls.__dict__:
                logger.warning(
                    "Model class %s already has a multi-modal processor "
                    "registered to %s. It is overwritten by the new one.",
                    model_cls,
                    self,
                )

            model_cls._processor_factory = _ProcessorFactories(
                info=info,
                dummy_inputs=dummy_inputs,
                processor=processor,
            )

            return model_cls

        return wrapper

    def _get_model_cls(self, model_config: "ModelConfig") -> "SupportsMultiModal":
        """
        Resolve the model class for `model_config` and check that a
        multi-modal processor has been registered on it.
        """
        # Avoid circular import
        from vllm.model_executor.model_loader import get_model_architecture

        model_cls, _ = get_model_architecture(model_config)
        assert hasattr(model_cls, "_processor_factory"), (
            f"Model class {model_cls} has no multi-modal processor registered; "
            "use `MultiModalRegistry.register_processor` to register one."
        )
        return cast("SupportsMultiModal", model_cls)

    def _create_processing_ctx(
        self,
        model_config: "ModelConfig",
        observability_config: "ObservabilityConfig | None" = None,
        tokenizer: TokenizerLike | None = None,
    ) -> InputProcessingContext:
        """
        Build an `InputProcessingContext`, loading the tokenizer from the
        model config (via `cached_tokenizer_from_config`) when none is given.
        """
        if tokenizer is None:
            tokenizer = cached_tokenizer_from_config(model_config)

        return InputProcessingContext(
            model_config, tokenizer, observability_config=observability_config
        )

    def _create_processing_info(
        self,
        model_config: "ModelConfig",
        observability_config: "ObservabilityConfig | None" = None,
        *,
        tokenizer: TokenizerLike | None = None,
    ) -> BaseProcessingInfo:
        """
        Build only the processing info for the model, without constructing the
        full processor.
        """
        model_cls = self._get_model_cls(model_config)
        factories = model_cls._processor_factory
        ctx = self._create_processing_ctx(model_config, observability_config, tokenizer)
        return factories.info(ctx)

    def create_processor(
        self,
        model_config: "ModelConfig",
        observability_config: "ObservabilityConfig | None" = None,
        *,
        tokenizer: TokenizerLike | None = None,
        cache: BaseMultiModalProcessorCache | None = None,
    ) -> BaseMultiModalProcessor[BaseProcessingInfo]:
        """
        Create a multi-modal processor for a specific model and tokenizer.

        Raises:
            ValueError: If the model is not a multimodal model.
        """
        if not model_config.is_multimodal_model:
            raise ValueError(f"{model_config.model} is not a multimodal model")

        model_cls = self._get_model_cls(model_config)
        factories = model_cls._processor_factory

        ctx = self._create_processing_ctx(model_config, observability_config, tokenizer)

        return factories.build_processor(ctx, cache=cache)

    def get_dummy_mm_inputs(
        self,
        model_config: "ModelConfig",
        mm_counts: Mapping[str, int] | None = None,
        *,
        cache: BaseMultiModalProcessorCache | None = None,
        observability_config: ObservabilityConfig | None = None,
        processor: BaseMultiModalProcessor | None = None,
    ) -> MultiModalInputs:
        """
        Create dummy data for profiling the memory usage of a model.

        The model is identified by `model_config`.

        Args:
            model_config: Configuration of the model.
            mm_counts: Number of dummy items per modality. Defaults to the
                processor's `allowed_mm_limits`.
            cache: Processor cache; only used when `processor` is not given.
            observability_config: Only used when `processor` is not given.
            processor: A pre-built processor to reuse instead of creating one.

        Returns:
            The processed dummy inputs. `prompt_token_ids` is padded in place
            with token id 0 up to `model_config.max_model_len`.
        """
        seq_len = model_config.max_model_len

        if processor is None:
            processor = self.create_processor(
                model_config, observability_config, cache=cache
            )
        if mm_counts is None:
            mm_counts = processor.info.allowed_mm_limits

        processor_inputs = processor.dummy_inputs.get_dummy_processor_inputs(
            seq_len=seq_len,
            mm_counts=mm_counts,
            mm_options=self._extract_mm_options(model_config),
        )
        mm_inputs = processor.apply(
            prompt=processor_inputs.prompt,
            mm_items=processor_inputs.mm_items,
            hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs,
            tokenization_kwargs=processor_inputs.tokenization_kwargs,
        )

        # Pad the prompt to the full sequence length.
        prompt_token_ids = mm_inputs["prompt_token_ids"]
        total_len = len(prompt_token_ids)
        if total_len < seq_len:
            prompt_token_ids.extend([0] * (seq_len - total_len))

        return mm_inputs

    def get_encdec_max_encoder_len(self, model_config: "ModelConfig") -> int:
        """
        Get the maximum length of the encoder input for encoder-decoder models.

        Returns 0 for models that are not encoder-decoder, or that report no
        multi-modal token budget.
        """
        if not model_config.is_encoder_decoder:
            return 0
        max_tokens = self.get_max_tokens_per_item_by_modality(model_config)
        if not max_tokens:
            # TODO - this function assumes encoder-decoder models are
            # multimodal. This will need to change when adding support for more
            # than whisper.
            return 0
        assert len(max_tokens) == 1, (
            "Encoder-decoder models are expected "
            "to implement the multimodal interface with at most one modality."
        )

        first_modality = next(iter(max_tokens))
        return max_tokens[first_modality]

    def _get_cache_type(
        self,
        vllm_config: "VllmConfig",
    ) -> Literal[None, "processor_only", "lru", "shm"]:
        """
        Decide which multi-modal processor cache, if any, should be used.

        Returns:
            `None` if the model does not accept multi-modal inputs or
            `mm_processor_cache_gb` is not positive; `"processor_only"` if IPC
            caching is not supported by the parallel setup; otherwise the
            configured `mm_processor_cache_type`.
        """
        model_config = vllm_config.model_config
        if not self.supports_multimodal_inputs(model_config):
            return None

        # Check if the cache is disabled.
        mm_config = model_config.get_multimodal_config()
        if mm_config.mm_processor_cache_gb <= 0:
            return None

        # IPC caching requires a single API process and either no data
        # parallelism or an external data-parallel load balancer.
        parallel_config = vllm_config.parallel_config
        is_ipc_supported = parallel_config._api_process_count == 1 and (
            parallel_config.data_parallel_size == 1
            or parallel_config.data_parallel_external_lb
        )

        if not is_ipc_supported:
            return "processor_only"

        return mm_config.mm_processor_cache_type

    def processor_cache_from_config(
        self,
        vllm_config: "VllmConfig",
    ) -> BaseMultiModalProcessorCache | None:
        """Return a `BaseMultiModalProcessorCache`, if enabled.

        Maps `"processor_only"` to `MultiModalProcessorOnlyCache`, `"lru"` to
        `MultiModalProcessorSenderCache` and `"shm"` to
        `ShmObjectStoreSenderCache`.
        """
        cache_type = self._get_cache_type(vllm_config)
        if cache_type is None:
            return None
        elif cache_type == "processor_only":
            return MultiModalProcessorOnlyCache(vllm_config.model_config)
        elif cache_type == "lru":
            return MultiModalProcessorSenderCache(vllm_config.model_config)
        elif cache_type == "shm":
            return ShmObjectStoreSenderCache(vllm_config)
        else:
            raise ValueError(f"Unknown cache type: {cache_type!r}")

    def processor_only_cache_from_config(
        self,
        vllm_config: "VllmConfig",
    ) -> MultiModalProcessorOnlyCache | None:
        """Return a `MultiModalProcessorOnlyCache`, if enabled.

        Any enabled cache type yields a processor-only cache here.
        """
        cache_type = self._get_cache_type(vllm_config)
        if cache_type is None:
            return None

        return MultiModalProcessorOnlyCache(vllm_config.model_config)

    def engine_receiver_cache_from_config(
        self,
        vllm_config: "VllmConfig",
    ) -> BaseMultiModalReceiverCache | None:
        """Return a `BaseMultiModalReceiverCache` for the engine process.

        Only the `"lru"` cache type has an engine-side receiver cache.
        """
        cache_type = self._get_cache_type(vllm_config)
        if cache_type in (None, "processor_only", "shm"):
            return None
        elif cache_type == "lru":
            return MultiModalReceiverCache(vllm_config.model_config)
        else:
            raise ValueError(f"Unknown cache type: {cache_type!r}")

    def worker_receiver_cache_from_config(
        self,
        vllm_config: "VllmConfig",
        shared_worker_lock: LockType,
    ) -> BaseMultiModalReceiverCache | None:
        """Return a `BaseMultiModalReceiverCache` for the worker process.

        Only the `"shm"` cache type has a worker-side receiver cache, which
        is given `shared_worker_lock`.
        """
        cache_type = self._get_cache_type(vllm_config)
        if cache_type in (None, "processor_only", "lru"):
            return None
        elif cache_type == "shm":
            return ShmObjectStoreReceiverCache(vllm_config, shared_worker_lock)
        else:
            raise ValueError(f"Unknown cache type: {cache_type!r}")