"vscode:/vscode.git/clone" did not exist on "9a0f3bdbe530f4d90e27cf9c6f5cc506e2b44c03"
mm_input_cache.py 4.54 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
from collections.abc import Sequence
from typing import TYPE_CHECKING, Optional
5

6
from vllm.multimodal import MultiModalRegistry
7
from vllm.multimodal.cache import MultiModalCache, MultiModalCacheItemMetadata
8
9
from vllm.multimodal.inputs import MultiModalKwargsItem
from vllm.utils import is_list_of
10

11
12
13
14
if TYPE_CHECKING:
    from vllm.config import ModelConfig

# The idea of multimodal input caching is based on having a client and
# a server, where the client executes in the frontend process (=P0) and the
# server in the core process (=P1).
#
# -- P0:
#  - BaseMultiModalProcessor calls MultiModalHasher to get the `mm_hash` of
#    each input multi-modal item (e.g. image),
#  - BaseMultiModalProcessor processes the input items into `mm_kwargs`,
#    which are MultiModalKwargsItem instances that each correspond to an
#    input multi-modal item.
#  - MultiModalInputCacheClient accepts the `mm_kwargs` and corresponding
#    `mm_hash` for each item. It stores each `mm_hash` as a key together with
#    the size of the corresponding `mm_kwargs`, but not the `mm_kwargs`
#    themselves, to avoid taking up additional memory in P0.
#  - The `mm_hash` is always sent to P1.
#  - The corresponding `mm_kwargs` are only sent to P1 if they are not cached
#    in MultiModalInputCacheServer.
#
# -- P1:
#  - If the `mm_hash` is cached (i.e. `mm_kwargs` are not sent from P0),
#    MultiModalInputCacheServer retrieves the corresponding `mm_kwargs`.
#  - If the `mm_hash` is not cached (i.e. `mm_kwargs` are sent from P0),
#    MultiModalInputCacheServer stores `mm_kwargs` under the key `mm_hash`.
#  - Either way, the `mm_hash` and corresponding `mm_kwargs` are sent to
#    the engine for model execution.
#
# Both Client and Server must perform cache update and eviction based on the
# same item size. This ensures that the keys of MultiModalInputCacheClient
# and MultiModalInputCacheServer are mirrored, allowing us to determine in P0
# whether a key is cached in MultiModalInputCacheServer by querying
# MultiModalInputCacheClient without having to communicate with P1.
45
46


47
48
class MultiModalInputCacheClient:
    """Used by P0 to check whether multi-modal kwargs are cached in P1.

    Entries are keyed by ``mm_hash``. Each value is the result of
    ``MultiModalCacheItemMetadata.wraps(item)`` rather than the item itself,
    so P0 tracks sizes without holding the kwargs in memory. The keys are
    meant to mirror those of ``MultiModalInputCacheServer``.
    """

    def __init__(self, model_config: "ModelConfig",
                 mm_registry: MultiModalRegistry) -> None:
        super().__init__()

        self.enabled = mm_registry.enable_mm_input_cache(model_config)

        capacity_gb = model_config.get_mm_input_cache_gb()
        self.mm_cache = MultiModalCache.get_lru_cache(
            capacity_gb,
            MultiModalCacheItemMetadata,
        )

    def get_and_update(
        self,
        mm_kwargs: Sequence[MultiModalKwargsItem],
        mm_hashes: list[str],
    ) -> list[Optional[MultiModalKwargsItem]]:
        """Decide which multi-modal items must be shipped to P1.

        Args:
            mm_kwargs: One processed item per input multi-modal item.
            mm_hashes: The hash of each item, positionally aligned with
                ``mm_kwargs``.

        Returns:
            A list aligned with the inputs. An entry is ``None`` when the
            hash is already present in this cache, so P1 can resolve it
            from its own cache. Otherwise the entry is the original item,
            which is also recorded here. When caching is disabled, every
            item is returned unchanged.
        """
        if not self.enabled:
            return list(mm_kwargs)

        assert len(mm_kwargs) == len(mm_hashes)

        result: list[Optional[MultiModalKwargsItem]] = []
        for item, item_hash in zip(mm_kwargs, mm_hashes):
            # A lookup via ``get`` (rather than ``in``) is kept on purpose:
            # presumably it refreshes LRU recency just like the server's
            # ``self.mm_cache[mm_hash]`` access — confirm against
            # MultiModalCache.get_lru_cache.
            already_cached = self.mm_cache.get(item_hash) is not None
            if already_cached:
                result.append(None)
                continue

            self.mm_cache[item_hash] = MultiModalCacheItemMetadata.wraps(item)
            result.append(item)

        return result

    def reset(self) -> None:
        """Remove every entry from the client-side cache."""
        self.mm_cache.clear()


class MultiModalInputCacheServer:
    """Used by P1 to avoid requiring past multi-modal kwargs from P0.

    Stores full ``MultiModalKwargsItem`` objects keyed by ``mm_hash``, so
    that items P0 omitted (sent as ``None``) can be restored locally.
    """

    def __init__(self, model_config: "ModelConfig",
                 mm_registry: MultiModalRegistry) -> None:
        super().__init__()

        self.enabled = mm_registry.enable_mm_input_cache(model_config)

        capacity_gb = model_config.get_mm_input_cache_gb()
        self.mm_cache = MultiModalCache.get_lru_cache(
            capacity_gb,
            MultiModalKwargsItem,
        )

    def get_and_update(
        self,
        mm_kwargs: Sequence[Optional[MultiModalKwargsItem]],
        mm_hashes: list[str],
    ) -> list[MultiModalKwargsItem]:
        """Resolve the items received from P0 into concrete kwargs.

        Args:
            mm_kwargs: Items aligned with ``mm_hashes``. ``None`` marks an
                item that P0 expects to already be cached here.
            mm_hashes: The hash of each item.

        Returns:
            The fully populated list of items, in input order. Items that
            were sent are stored under their hash; ``None`` entries are
            replaced by the cached item (a missing key raises ``KeyError``).
        """
        if not self.enabled:
            passthrough = list(mm_kwargs)
            # With caching off, P0 never omits items, so no None can appear.
            assert is_list_of(passthrough, MultiModalKwargsItem)
            return passthrough

        assert len(mm_kwargs) == len(mm_hashes)

        resolved: list[MultiModalKwargsItem] = []
        for item, item_hash in zip(mm_kwargs, mm_hashes):
            if item is not None:
                # New to this cache: remember it for future requests.
                self.mm_cache[item_hash] = item
                resolved.append(item)
            else:
                # P0's mirrored cache says we already hold this item.
                resolved.append(self.mm_cache[item_hash])

        return resolved

    def reset(self) -> None:
        """Remove every entry from the server-side cache."""
        self.mm_cache.clear()