cache.py 22.6 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import operator
4
import sys
5
6
from abc import ABC, abstractmethod
from collections.abc import Mapping, Sequence
7
from multiprocessing.synchronize import Lock as LockType
8
from typing import TYPE_CHECKING, Generic, TypeAlias, TypeVar, cast
9
10

import torch
11
from typing_extensions import override
12

13
import vllm.envs as envs
14
from vllm.distributed.device_communicators.shm_object_storage import (
15
16
17
18
    MsgpackSerde,
    SingleWriterShmObjectStorage,
    SingleWriterShmRingBuffer,
)
19
from vllm.logger import init_logger
20
from vllm.utils import GiB_bytes, MiB_bytes
21
from vllm.utils.cache import CacheInfo, LRUCache
22
from vllm.utils.jsontree import json_count_leaves, json_map_leaves, json_reduce_leaves
23

24
25
26
27
28
29
30
31
32
from .inputs import (
    MultiModalBatchedField,
    MultiModalFeatureSpec,
    MultiModalFieldElem,
    MultiModalKwargs,
    MultiModalKwargsItem,
    MultiModalKwargsItems,
    NestedTensors,
)
33

34
35
36
37
38
39
if TYPE_CHECKING:
    from vllm.config import ModelConfig, VllmConfig

    from .processing import ResolvedPromptUpdate
    from .registry import MultiModalRegistry

# Module-level logger shared by all caches in this file.
logger = init_logger(__name__)


class MultiModalProcessorCacheItem:
    """
    Full entry kept by `MultiModalProcessorOnlyCache`.

    It holds both the processed tensor data and the prompt updates, so a cache
    hit on P0 can return everything without reprocessing the item.

    Args:
        item: The processed tensor data corresponding to a multi-modal item.
        prompt_updates: The prompt updates corresponding to `item`.
    """

    def __init__(
        self,
        item: MultiModalKwargsItem,
        prompt_updates: Sequence["ResolvedPromptUpdate"],
    ) -> None:
        super().__init__()

        self.item, self.prompt_updates = item, prompt_updates


class MultiModalProcessorCacheItemMetadata:
    """
    Lightweight entry kept by `MultiModalProcessorSenderCache` on P0.

    Args:
        item: The processed tensor data corresponding to a multi-modal item.
            P1 already holds the tensor data. P0 therefore retains only its
            size, which is still required so that P0 applies the same
            eviction policy as P1.
        prompt_updates: The prompt updates corresponding to `item`.
            They stay on P0 because, for some models, they depend on the
            processed tensor data (which is cached on P1).
    """

    def __init__(
        self,
        item: MultiModalKwargsItem,
        prompt_updates: Sequence["ResolvedPromptUpdate"],
    ) -> None:
        super().__init__()

        # Measured with the same helper that sizes entries of the P1 LRU
        # cache, so both caches account for this item identically.
        size_in_bytes = MultiModalCache.get_item_size(item)
        self.item_size = size_in_bytes
        self.prompt_updates = prompt_updates


# Union of the value kinds that `MultiModalCache.get_item_size` can measure.
# `_V` (bounded by it) is the value type of caches returned by
# `MultiModalCache.get_lru_cache`.
MultiModalCacheValue: TypeAlias = (
    MultiModalProcessorCacheItem
    | MultiModalProcessorCacheItemMetadata
    | MultiModalKwargsItems
    | MultiModalKwargsItem
    | MultiModalKwargs
    | Mapping[str, NestedTensors]
)

_V = TypeVar("_V", bound=MultiModalCacheValue)


class MultiModalCache:
    """Size accounting and LRU-cache construction for multi-modal cache values."""

    @classmethod
    def get_leaf_size(cls, leaf: object) -> int:
        """
        Return the approximate size in bytes of a single leaf value.

        - Cache entry wrappers are unwrapped, or report their precomputed size.
        - Multi-modal containers are measured recursively through `.data`.
        - Tensors report `nbytes`.
        - Everything else falls back to `sys.getsizeof`, which is shallow.
        """
        if isinstance(leaf, MultiModalProcessorCacheItem):
            return cls.get_leaf_size(leaf.item)
        if isinstance(leaf, MultiModalProcessorCacheItemMetadata):
            # P0 only stores the size computed when the entry was created.
            return leaf.item_size

        # These are not subclasses of dict
        if isinstance(
            leaf,
            (
                MultiModalKwargs,
                MultiModalKwargsItems,
                MultiModalKwargsItem,
                MultiModalFieldElem,
            ),
        ):
            return cls.get_item_size(leaf.data)  # type: ignore

        # sys.getsizeof doesn't work for tensors
        if isinstance(leaf, torch.Tensor):
            return leaf.nbytes

        return sys.getsizeof(leaf)

    @classmethod
    def get_item_size(
        cls,
        value: MultiModalCacheValue,
        *,
        debug: bool = False,
    ) -> int:
        """
        Return the approximate size in bytes of a (possibly nested) value.

        Args:
            value: The value to measure. Each leaf is sized by `get_leaf_size`.
            debug: If `True`, log the computed size and the number of leaves.

        Returns:
            The sum of the sizes of all leaves in `value`.
        """
        size = json_reduce_leaves(
            operator.add, json_map_leaves(cls.get_leaf_size, value)
        )

        if debug:
            leaf_count = cls.get_item_complexity(value)
            logger.debug(
                "Calculated size of %s to be %.2f GiB (%d leaves)",
                type(value),
                size / GiB_bytes,
                leaf_count,
            )

        return size

    @classmethod
    def get_item_complexity(cls, value: MultiModalCacheValue) -> int:
        """
        Get the number of leaf elements in a multi-modal cache value.

        This provides a measure of structural complexity that can be useful
        for debugging cache performance and understanding data patterns.

        Args:
            value: The multi-modal cache value to analyze.

        Returns:
            The number of leaf elements in the nested structure.
        """
        return json_count_leaves(value)

    @classmethod
    def get_lru_cache(
        cls,
        capacity_gb: float,
        value_type: type[_V],
        *,
        debug: bool = False,
    ) -> LRUCache[str, _V]:
        """
        Create an LRU cache whose capacity is measured with `get_item_size`.

        Args:
            capacity_gb: Capacity of the cache in GiB.
            value_type: The type of the cached values. It is not used at
                runtime and only serves to type the returned cache.
            debug: Forwarded to `get_item_size` for every size computation.

        Returns:
            An empty `LRUCache` keyed by multi-modal hash.
        """
        return LRUCache(
            GiB_bytes * capacity_gb,
            getsizeof=lambda x: cls.get_item_size(x, debug=debug),
        )


_I = TypeVar("_I", contravariant=True)
_O = TypeVar("_O", covariant=True)


class BaseMultiModalCache(ABC, Generic[_I, _O]):
    """
    Abstract base class to read/write multi-modal items from cache.

    The idea of multi-modal caching is based on having a client and server
    where the client executes in the frontend process (=P0) and
    the server in the core process (=P1). The data flow is as follows:

    ```
                  is_cached() x N    get_and_update()
    P0: From API -----------------> -----------------> To P1

                 get_and_update()
    P1: From P0 -----------------> To model
    ```

    `is_cached()` can be called any number of times in P0. However,
    `get_and_update()` must be called in P0 and P1 one after another
    so that their cache eviction order remains the same.

    This ensures that the keys in P0 and P1 caches are mirrored,
    allowing us to determine whether a key is cached in P1 by looking
    up the P0 cache, without having to communicate with P1.
    """

    @abstractmethod
    def get_and_update_item(
        self,
        mm_item: _I,
        mm_hash: str,
    ) -> _O:
        """
        Possibly update a multi-modal item based on whether it is
        in the underlying cache.

        This update is done out-of-place and updates the cache eviction order.

        Args:
            mm_item: The multi-modal item to update.
            mm_hash: The hash of `mm_item`.

        Returns:
            The updated multi-modal item.
        """
        raise NotImplementedError

    def get_and_update(
        self,
        mm_items: Sequence[_I],
        mm_hashes: list[str],
    ) -> list[_O]:
        """
        Possibly update a sequence of multi-modal items based on whether they
        are in the underlying cache.

        This update is done out-of-place and updates the cache eviction order.
        Items are processed in order, which keeps the eviction order
        identical between P0 and P1.

        Args:
            mm_items: The multi-modal items to update.
            mm_hashes: The hash of each item in `mm_items`. Must have the same
                length as `mm_items`; this is checked with an assertion.

        Returns:
            A new list of updated multi-modal items.
        """
        assert len(mm_items) == len(mm_hashes)

        return [
            self.get_and_update_item(mm_item, mm_hash)
            for mm_item, mm_hash in zip(mm_items, mm_hashes)
        ]

    @abstractmethod
    def clear_cache(self) -> None:
        """Clear the underlying cache."""
        raise NotImplementedError


# Input to `BaseMultiModalProcessorCache.get_and_update_item`: the processed
# item together with its prompt updates, or `None` when the caller expects the
# item to already be cached.
MultiModalProcessorCacheInItem: TypeAlias = (
    tuple[MultiModalKwargsItem, Sequence["ResolvedPromptUpdate"]] | None
)

# Output of `BaseMultiModalProcessorCache.get_and_update_item`. The item part
# may be `None` (e.g. when its data does not need to be sent to P1 again),
# but the prompt updates are always present.
MultiModalProcessorCacheOutItem: TypeAlias = tuple[
    MultiModalKwargsItem | None, Sequence["ResolvedPromptUpdate"]
]


class BaseMultiModalProcessorCache(
    BaseMultiModalCache[MultiModalProcessorCacheInItem, MultiModalProcessorCacheOutItem]
):
    """The required interface for caches on P0."""

    @abstractmethod
    def is_cached_item(self, mm_hash: str) -> bool:
        """
        Check whether a multi-modal item is
        in the underlying cache.

        This **DOES NOT** update the cache eviction order.

        Args:
            mm_hash: The hash of the item to check.

        Returns:
            `True` if the item is cached, otherwise `False`.
        """
        raise NotImplementedError

    def is_cached(self, mm_hashes: list[str]) -> list[bool]:
        """
        Check whether a sequence of multi-modal items are
        in the underlying cache.

        This **DOES NOT** update the cache eviction order.

        Args:
            mm_hashes: The hash of each item to check.

        Returns:
            For each item, `True` if the item is cached, otherwise `False`.
        """
        return [self.is_cached_item(mm_hash) for mm_hash in mm_hashes]

    @abstractmethod
    def make_stats(self, *, delta: bool = False) -> CacheInfo:
        """
        Get the multi-modal cache stats.

        Args:
            delta: If `True`, return only the stats accumulated since the
                previous call with `delta=True` (resetting that baseline),
                instead of the cumulative totals.

        Returns:
            The current multi-modal caching stats.
        """
        raise NotImplementedError


class MultiModalProcessorOnlyCache(BaseMultiModalProcessorCache):
    """
    The cache which is used on P0 when IPC caching is disabled.

    How to update each item:

    - If the item is in the cache, replace the input with the cached item.
    - If the item is not in the cache, store that item (which includes
      tensor data and metadata) into the cache, and return the input.
    """

    def __init__(self, model_config: "ModelConfig") -> None:
        super().__init__()

        mm_config = model_config.get_multimodal_config()

        self._cache = MultiModalCache.get_lru_cache(
            mm_config.mm_processor_cache_gb,
            MultiModalProcessorCacheItem,
        )

    @override
    def is_cached_item(self, mm_hash: str) -> bool:
        return mm_hash in self._cache

    @override
    def get_and_update_item(
        self,
        mm_item: MultiModalProcessorCacheInItem,
        mm_hash: str,
    ) -> MultiModalProcessorCacheOutItem:
        # Cache hit: the stored entry supersedes whatever was passed in.
        if (cached_item := self._cache.get(mm_hash)) is not None:
            return cached_item.item, cached_item.prompt_updates

        assert mm_item is not None, f"Expected a cached item for {mm_hash=}"

        self._cache[mm_hash] = MultiModalProcessorCacheItem(*mm_item)

        return mm_item

    @override
    def clear_cache(self) -> None:
        self._cache.clear()

    @override
    def make_stats(self, *, delta: bool = False) -> CacheInfo:
        return self._cache.stat(delta=delta)


class MultiModalProcessorSenderCache(BaseMultiModalProcessorCache):
    """
    The cache which is used on P0 when IPC caching is enabled.

    How to update each item:

    - If the item is already in the cache, clear the input to avoid
      unnecessary IPC.

    - If the item is not in the cache, store the metadata of that item so
      that the eviction policy remains the same as the cache on P1,
      and return the input.
      By only storing the metadata, we avoid keeping the data itself in
      memory inside P0.
    """

    def __init__(self, model_config: "ModelConfig") -> None:
        super().__init__()

        mm_config = model_config.get_multimodal_config()

        # Same capacity and sizing function as the P1 receiver cache, so the
        # two caches evict the same keys in the same order.
        self._cache = MultiModalCache.get_lru_cache(
            mm_config.mm_processor_cache_gb,
            MultiModalProcessorCacheItemMetadata,
        )

    @override
    def is_cached_item(self, mm_hash: str) -> bool:
        return mm_hash in self._cache

    @override
    def get_and_update_item(
        self,
        mm_item: MultiModalProcessorCacheInItem,
        mm_hash: str,
    ) -> MultiModalProcessorCacheOutItem:
        # Cache hit: P1 already has the data, so only the prompt updates
        # (kept on P0) are returned.
        if (cached_item := self._cache.get(mm_hash)) is not None:
            return None, cached_item.prompt_updates

        assert mm_item is not None, f"Expected a cached item for {mm_hash=}"

        self._cache[mm_hash] = MultiModalProcessorCacheItemMetadata(*mm_item)

        return mm_item

    @override
    def clear_cache(self) -> None:
        self._cache.clear()

    @override
    def make_stats(self, *, delta: bool = False) -> CacheInfo:
        return self._cache.stat(delta=delta)


class ShmObjectStoreSenderCache(BaseMultiModalProcessorCache):
    """
    The cache which is used on P0 when IPC caching is enabled and
    `mm_processor_cache_type == "shm"`.

    How to update each item:

    - If the item is already in the cache, replace the input with a small
      item that only references its location in shared memory, to avoid
      unnecessary IPC.

    - If the item is not in the cache, store the data in shared memory and
      likewise return an address item. If storing fails (e.g. the object is
      too large or the buffer is full), return the input unchanged.
    """

    def __init__(self, vllm_config: "VllmConfig") -> None:
        super().__init__()

        self.world_size = vllm_config.parallel_config.world_size
        mm_config = vllm_config.model_config.get_multimodal_config()

        ring_buffer = SingleWriterShmRingBuffer(
            data_buffer_size=int(mm_config.mm_processor_cache_gb * GiB_bytes),
            name=envs.VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME,
            create=True,  # sender is the writer
        )
        self._shm_cache = SingleWriterShmObjectStorage(
            max_object_size=mm_config.mm_shm_cache_max_object_size_mb * MiB_bytes,
            n_readers=self.world_size,
            ring_buffer=ring_buffer,
            serde_class=MsgpackSerde,
        )
        # cache (prompt_updates, modality) for P0 only; the tensor data itself
        # lives in shared memory.
        self._p0_cache: dict[str, tuple[Sequence["ResolvedPromptUpdate"], str]] = {}

        # Counters backing `make_stats`.
        self._hits = 0
        self._total = 0
        self._last_info = CacheInfo(hits=0, total=0)

    def _stat(self, *, delta: bool = False) -> CacheInfo:
        """Return cumulative stats, or the change since the last delta query."""
        info = CacheInfo(hits=self._hits, total=self._total)

        if delta:
            info_delta = info - self._last_info
            self._last_info = info
            info = info_delta

        return info

    @override
    def is_cached_item(self, mm_hash: str) -> bool:
        return self._shm_cache.is_cached(mm_hash)

    @override
    def get_and_update_item(
        self,
        mm_item: MultiModalProcessorCacheInItem,
        mm_hash: str,
    ) -> MultiModalProcessorCacheOutItem:
        if self._shm_cache.is_cached(mm_hash):
            self._hits += 1
            self._total += 1

            address, monotonic_id = self._shm_cache.get_cached(mm_hash)
            prompt_updates, modality = self._p0_cache[mm_hash]
            return self.address_as_item(address, monotonic_id, modality), prompt_updates

        assert mm_item is not None, f"Expected a cached item for {mm_hash=}"

        self._total += 1

        item, prompt_updates = mm_item
        try:
            address, monotonic_id = self._shm_cache.put(mm_hash, item)
        except (ValueError, MemoryError) as e:
            # put may fail if the object is too large or
            # the cache is full.
            # In this case we log the error and keep the original mm_input.
            logger.debug("Failed to cache mm_input with hash %s: %s", mm_hash, e)
            return mm_item

        # Try to remove dangling items if p0 cache is too large.
        if len(self._p0_cache) >= 2 * len(self._shm_cache.key_index):
            self.remove_dangling_items()
        self._p0_cache[mm_hash] = prompt_updates, item.modality

        address_item = self.address_as_item(address, monotonic_id, item.modality)
        return address_item, prompt_updates

    @override
    def clear_cache(self) -> None:
        self._shm_cache.clear()
        self._p0_cache.clear()

        self._hits = 0
        self._total = 0
        self._last_info = CacheInfo(hits=0, total=0)

    @override
    def make_stats(self, *, delta: bool = False) -> CacheInfo:
        return self._stat(delta=delta)

    def remove_dangling_items(self) -> None:
        """Remove items that are no longer in the shared memory cache."""
        # Entries can leave shared memory independently of `_p0_cache`, so
        # the P0-side metadata for those hashes must be pruned here.
        cached_hashes = self._shm_cache.key_index.keys()
        dangling_hashes = set(self._p0_cache.keys()) - cached_hashes
        for mm_hash in dangling_hashes:
            del self._p0_cache[mm_hash]

    def address_as_item(
        self, address: int, monotonic_id: int, modality: str
    ) -> MultiModalKwargsItem:
        """
        Wrap a shared-memory location into a `MultiModalKwargsItem`.

        The returned item carries `address` and `monotonic_id` fields instead
        of tensor data. `ShmObjectStoreReceiverCache` reads these fields to
        fetch the object from shared memory.
        """
        addr_elem = MultiModalFieldElem(
            modality=modality,
            key="address",
            data=address,
            field=MultiModalBatchedField(),
        )
        id_elem = MultiModalFieldElem(
            modality=modality,
            key="monotonic_id",
            data=monotonic_id,
            field=MultiModalBatchedField(),
        )
        mm_item = MultiModalKwargsItem.from_elems([addr_elem, id_elem])
        return mm_item


def _enable_processor_cache(
    model_config: "ModelConfig",
    mm_registry: "MultiModalRegistry",
) -> bool:
    if not mm_registry.supports_multimodal_inputs(model_config):
        return False

    mm_config = model_config.get_multimodal_config()
    return mm_config.mm_processor_cache_gb > 0


def _enable_ipc_cache(vllm_config: "VllmConfig") -> bool:
    parallel_config = vllm_config.parallel_config
556
557
558
559
    supports_ipc_cache = (
        parallel_config._api_process_count == 1
        and parallel_config.data_parallel_size == 1
    ) or parallel_config.data_parallel_external_lb
560
561
562
563

    return supports_ipc_cache


564
565
566
567
568
569
570
571
572
573
574
def _enable_mm_input_shm_cache(vllm_config: "VllmConfig") -> bool:
    """
    Decide whether multi-modal inputs go through the shared-memory object store.

    The shared-memory cache is a variant of IPC caching. It therefore needs
    IPC caching to be possible, and it must be selected explicitly via the
    processor cache type.
    """
    if _enable_ipc_cache(vllm_config):
        multimodal_cfg = vllm_config.model_config.get_multimodal_config()
        return multimodal_cfg.mm_processor_cache_type == "shm"

    return False


def processor_cache_from_config(
    vllm_config: "VllmConfig",
    mm_registry: "MultiModalRegistry",
) -> BaseMultiModalProcessorCache | None:
    """
    Return a `BaseMultiModalProcessorCache` for P0, if enabled.

    - `None` if the processor cache is disabled.
    - `MultiModalProcessorOnlyCache` if IPC caching is not possible.
    - `MultiModalProcessorSenderCache` for IPC caching without shared memory.
    - `ShmObjectStoreSenderCache` for shared-memory IPC caching.
    """
    model_config = vllm_config.model_config

    if not _enable_processor_cache(model_config, mm_registry):
        return None

    if not _enable_ipc_cache(vllm_config):
        return MultiModalProcessorOnlyCache(model_config)

    if not _enable_mm_input_shm_cache(vllm_config):
        return MultiModalProcessorSenderCache(model_config)
    return ShmObjectStoreSenderCache(vllm_config)


def processor_only_cache_from_config(
    model_config: "ModelConfig",
    mm_registry: "MultiModalRegistry",
) -> "MultiModalProcessorOnlyCache | None":
    """
    Return a `MultiModalProcessorOnlyCache`, if enabled.

    Unlike `processor_cache_from_config`, this never selects an IPC-based
    cache. It returns `None` only when the processor cache itself is disabled.
    """
    if not _enable_processor_cache(model_config, mm_registry):
        return None

    return MultiModalProcessorOnlyCache(model_config)


class BaseMultiModalReceiverCache(
    BaseMultiModalCache[MultiModalKwargsItem | None, MultiModalKwargsItem]
):
    """The required interface for caches on P1."""

    def get_and_update_features(
        self,
        mm_features: list["MultiModalFeatureSpec"],
    ) -> list["MultiModalFeatureSpec"]:
        """
        Resolve the multi-modal data of each feature through the cache.

        Each feature's `data` is replaced by the result of
        `get_and_update_item`, keyed by the feature's `identifier`.

        Args:
            mm_features: The features to update. They are mutated in place.

        Returns:
            The same list object that was passed in.
        """
        for feature in mm_features:
            feature.data = self.get_and_update_item(feature.data, feature.identifier)
        return mm_features


class MultiModalReceiverCache(BaseMultiModalReceiverCache):
    """
    The cache which is used on P1 when IPC caching is enabled
    (and the shared-memory cache is not selected).

    How to update each item:

    - If the item is in the cache, replace the input with the cached item.
    - If the item is not in the cache, store that item (which includes tensor
      data) into the cache, and return the input.
    """

    def __init__(self, model_config: "ModelConfig") -> None:
        super().__init__()

        mm_config = model_config.get_multimodal_config()

        # Same capacity and sizing function as the P0 sender cache, so the
        # two caches evict the same keys in the same order.
        self._cache = MultiModalCache.get_lru_cache(
            mm_config.mm_processor_cache_gb,
            MultiModalKwargsItem,
        )

    @override
    def get_and_update_item(
        self,
        mm_item: MultiModalKwargsItem | None,
        mm_hash: str,
    ) -> MultiModalKwargsItem:
        if (cached_item := self._cache.get(mm_hash)) is not None:
            return cached_item

        # P0 only omits the data when it believes P1 already has it.
        assert mm_item is not None, f"Expected a cached item for {mm_hash=}"

        self._cache[mm_hash] = mm_item
        return mm_item

    @override
    def clear_cache(self) -> None:
        self._cache.clear()


class ShmObjectStoreReceiverCache(BaseMultiModalReceiverCache):
    """
    The cache which is used on P1 Worker Process when IPC caching is enabled.

    How to update each item:

    - If the item has an address, replace the input with the cached item.
    - If not, return the input.
    """

    def __init__(
        self,
        vllm_config: "VllmConfig",
        shared_worker_lock: LockType,
    ) -> None:
        super().__init__()

        self.world_size = vllm_config.parallel_config.world_size
        mm_config = vllm_config.model_config.get_multimodal_config()

        ring_buffer = SingleWriterShmRingBuffer(
            data_buffer_size=int(mm_config.mm_processor_cache_gb * GiB_bytes),
            name=envs.VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME,
            create=False,  # Server is a reader
        )
        self._shm_cache = SingleWriterShmObjectStorage(
            max_object_size=mm_config.mm_shm_cache_max_object_size_mb * MiB_bytes,
            n_readers=self.world_size,
            ring_buffer=ring_buffer,
            serde_class=MsgpackSerde,
            reader_lock=shared_worker_lock,
        )

    @override
    def get_and_update_item(
        self,
        mm_item: MultiModalKwargsItem | None,
        mm_hash: str,
    ) -> MultiModalKwargsItem:
        assert mm_item is not None, f"Expected an address item for {mm_hash=}"
        # Items produced by the shm sender carry "address"/"monotonic_id"
        # fields instead of tensor data; resolve them from shared memory.
        if "address" in mm_item:
            address = cast(int, mm_item["address"].data)
            monotonic_id = cast(int, mm_item["monotonic_id"].data)
            return self._shm_cache.get(address, monotonic_id)

        # The sender could not store this item, so it was sent inline.
        return mm_item

    @override
    def clear_cache(self) -> None:
        self._shm_cache.clear()


def engine_receiver_cache_from_config(
    vllm_config: "VllmConfig",
    mm_registry: "MultiModalRegistry",
) -> BaseMultiModalReceiverCache | None:
    """
    This is used in the engine process.

    Return a `BaseMultiModalReceiverCache` only when IPC caching is enabled
    and the shared-memory cache is not selected (i.e.
    `mm_processor_cache_type != "shm"`). In the shm case the receiver cache
    is created in the worker processes instead (see
    `worker_receiver_cache_from_config`).
    """
    model_config = vllm_config.model_config

    if not _enable_processor_cache(model_config, mm_registry):
        return None

    if not _enable_ipc_cache(vllm_config):
        return None

    if not _enable_mm_input_shm_cache(vllm_config):
        return MultiModalReceiverCache(model_config)

    return None


def worker_receiver_cache_from_config(
    vllm_config: "VllmConfig",
    mm_registry: "MultiModalRegistry",
    shared_worker_lock: LockType,
) -> BaseMultiModalReceiverCache | None:
    """
    This is used in the worker process.

    Return a `BaseMultiModalReceiverCache` only when IPC caching is enabled and
    mm_processor_cache_type=="shm".

    Args:
        vllm_config: The engine configuration.
        mm_registry: Used to check whether the model accepts multi-modal input.
        shared_worker_lock: Lock passed to the shared-memory reader as
            `reader_lock`.
    """
    model_config = vllm_config.model_config

    if not _enable_processor_cache(model_config, mm_registry):
        return None

    if not _enable_ipc_cache(vllm_config):
        return None

    if not _enable_mm_input_shm_cache(vllm_config):
        return None

    return ShmObjectStoreReceiverCache(vllm_config, shared_worker_lock)