# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Mapping
from typing import TYPE_CHECKING
5

6
from vllm.multimodal import MultiModalRegistry
7
from vllm.multimodal.cache import MultiModalCache, MultiModalCacheItemMetadata
8
from vllm.multimodal.inputs import MultiModalKwargsItem, NestedTensors
9

10
11
12
13
if TYPE_CHECKING:
    from vllm.config import ModelConfig

# The idea of multimodal input caching is based on having a client and
# a server, where the client executes in the frontend process (=P0) and the
# server in the core process (=P1).
#
# -- P0:
#  - BaseMultiModalProcessor calls MultiModalHasher to get the `mm_hash` of
#    each input multi-modal item (e.g. image).
#  - BaseMultiModalProcessor processes the input items into `mm_kwargs`,
#    which are MultiModalKwargsItem instances that each correspond to an
#    input multi-modal item.
#  - MultiModalInputCacheClient accepts the `mm_kwargs` and corresponding
#    `mm_hash` for each item. It stores each `mm_hash` as a key together
#    with the size of the corresponding `mm_kwargs` (but not the `mm_kwargs`
#    themselves), to avoid taking up additional memory in P0.
#  - The `mm_hash` is always sent to P1.
#  - The corresponding `mm_kwargs` are only sent to P1 if they are not cached
#    in MultiModalInputCacheServer.
#
# -- P1:
#  - If the `mm_hash` is cached (i.e. `mm_kwargs` are not sent from P0),
#    MultiModalInputCacheServer retrieves the corresponding `mm_kwargs`.
#  - If the `mm_hash` is not cached (i.e. `mm_kwargs` are sent from P0),
#    MultiModalInputCacheServer stores `mm_kwargs` under the key `mm_hash`.
#  - Either way, the `mm_hash` and corresponding `mm_kwargs` are sent to
#    the engine for model execution.
#
# Both Client and Server must perform cache update and eviction based on the
# same item size. This ensures that the keys of MultiModalInputCacheClient
# and MultiModalInputCacheServer are mirrored, allowing us to determine in P0
# whether a key is cached in MultiModalInputCacheServer by querying
# MultiModalInputCacheClient without having to communicate with P1.


class MultiModalInputCacheClient:
    """Frontend-side (P0) half of the multi-modal input cache.

    For every ``mm_hash`` it has seen, this cache keeps only a
    ``MultiModalCacheItemMetadata`` entry built from the item's data, never
    the data itself. The server (P1) is designed to apply the same LRU
    update/eviction to its own cache. A hit here therefore means P1 already
    holds the data, and the data only needs to be shipped to P1 on a miss.
    """

    def __init__(self, model_config: "ModelConfig",
                 mm_registry: MultiModalRegistry) -> None:
        super().__init__()

        # The registry decides, per model, whether input caching is used.
        self.enabled = mm_registry.enable_mm_input_cache(model_config)

        # LRU cache of metadata entries, sized by the model's configured
        # input-cache budget (in GB, going by the getter's name).
        capacity_gb = model_config.get_mm_input_cache_gb()
        self.mm_cache = MultiModalCache.get_lru_cache(
            capacity_gb,
            MultiModalCacheItemMetadata,
        )

    def get_and_update(
        self,
        mm_kwargs: list[MultiModalKwargsItem],
        mm_hashes: list[str],
    ) -> list[MultiModalKwargsItem]:
        """Decide, item by item, whether the kwargs data must go to P1.

        Args:
            mm_kwargs: One ``MultiModalKwargsItem`` per input multi-modal
                item.
            mm_hashes: The ``mm_hash`` of each item, aligned positionally
                with ``mm_kwargs``.

        Returns:
            If caching is disabled, ``mm_kwargs`` itself (the same list
            object). Otherwise, a new list with one entry per item:

            - Items whose hash is already cached become
              ``item.without_data()``.
            - Every other item is passed through unchanged, after a
              metadata entry for its data is stored under its hash.

        Raises:
            AssertionError: If caching is enabled and the two lists differ
                in length.
        """
        if not self.enabled:
            return mm_kwargs

        assert len(mm_kwargs) == len(mm_hashes)

        cache = self.mm_cache
        result: list[MultiModalKwargsItem] = []
        for item, item_hash in zip(mm_kwargs, mm_hashes):
            # The server touches its cache in the same order (lookup on a
            # hit, insert on a miss) so the two caches stay mirrored.
            if cache.get(item_hash) is not None:
                result.append(item.without_data())
                continue

            cache[item_hash] = MultiModalCacheItemMetadata.wraps(
                item.require_data())
            result.append(item)

        return result

    def reset(self) -> None:
        """Drop every cached metadata entry.

        NOTE(review): to keep keys mirrored, this presumably has to be paired
        with a reset of the server-side cache; verify against callers.
        """
        self.mm_cache.clear()


class MultiModalInputCacheServer:
    """Core-side (P1) half of the multi-modal input cache.

    Holds the actual kwargs data for each ``mm_hash``. P0 can then omit the
    data for items it has already sent, provided this cache and the
    client-side cache go through the same sequence of updates and
    evictions.
    """

    def __init__(self, model_config: "ModelConfig",
                 mm_registry: MultiModalRegistry) -> None:
        super().__init__()

        # Same registry check the client-side cache performs, so both ends
        # agree on whether caching is active.
        self.enabled = mm_registry.enable_mm_input_cache(model_config)

        # Same capacity budget as the client-side cache. Here the cached
        # values are the item data itself.
        capacity_gb = model_config.get_mm_input_cache_gb()
        self.mm_cache = MultiModalCache.get_lru_cache(
            capacity_gb,
            Mapping[str, NestedTensors],
        )

    def get_and_update(
        self,
        mm_kwargs: list[MultiModalKwargsItem],
        mm_hashes: list[str],
    ) -> list[MultiModalKwargsItem]:
        """Fill in omitted kwargs data from the cache and record new data.

        Args:
            mm_kwargs: One ``MultiModalKwargsItem`` per input multi-modal
                item, some of which may arrive without data.
            mm_hashes: The ``mm_hash`` of each item, aligned positionally
                with ``mm_kwargs``.

        Returns:
            If caching is disabled, ``mm_kwargs`` itself (the same list
            object). Otherwise, a new list with one entry per item:

            - Items that arrived without data (``get_data()`` is ``None``)
              become ``item.with_data(<data cached under its hash>)``.
            - Items that arrived with data are passed through unchanged,
              after that data is stored under their hash.

        Raises:
            AssertionError: If caching is enabled and the two lists differ
                in length.

        NOTE(review): if an item arrives without data and its hash is not
        cached here, indexing the cache fails (presumably with
        ``KeyError``). That would mean the two caches have diverged.
        """
        if not self.enabled:
            return mm_kwargs

        assert len(mm_kwargs) == len(mm_hashes)

        cache = self.mm_cache
        restored: list[MultiModalKwargsItem] = []
        for item, item_hash in zip(mm_kwargs, mm_hashes):
            payload = item.get_data()
            if payload is not None:
                # P0 shipped the data, so remember it for later requests.
                cache[item_hash] = payload
                restored.append(item)
            else:
                # P0 omitted the data because it expects it to be cached.
                restored.append(item.with_data(cache[item_hash]))

        return restored

    def reset(self) -> None:
        """Drop every cached kwargs entry."""
        self.mm_cache.clear()