# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Sequence
4
from typing import TYPE_CHECKING, Optional
5

6
from vllm.multimodal import MultiModalKwargs, MultiModalRegistry
7
from vllm.multimodal.cache import MultiModalCache, MultiModalCacheItemMetadata
8
from vllm.utils import is_list_of
9

10
11
12
13
if TYPE_CHECKING:
    from vllm.config import ModelConfig

# The idea of multimodal input caching is based on having a client and
# a server, where the client executes in the frontend process (=P0) and the
# server executes in the core process (=P1).
#
# -- P0:
#  - BaseMultiModalProcessor calls MultiModalHasher to get the `mm_hash` of
#    each input multi-modal item (e.g. an image).
#  - BaseMultiModalProcessor processes the input items into `mm_inputs`,
#    which are MultiModalKwargsItem instances that each correspond to an
#    input multi-modal item.
#  - MultiModalInputCacheClient accepts the `mm_inputs` and the corresponding
#    `mm_hash` for each item. It stores each `mm_hash` as a key together with
#    the size of the corresponding `mm_inputs`, but not the `mm_inputs`
#    themselves, to avoid taking up additional memory in P0.
#  - The `mm_hash` is always sent to P1.
#  - The corresponding `mm_inputs` are only sent to P1 if they are not cached
#    in MultiModalInputCacheServer.
#
# -- P1:
#  - If the `mm_hash` is cached (i.e. `mm_inputs` are not sent from P0),
#    MultiModalInputCacheServer retrieves the corresponding `mm_inputs`.
#  - If the `mm_hash` is not cached (i.e. `mm_inputs` are sent from P0),
#    MultiModalInputCacheServer stores `mm_inputs` under the key `mm_hash`.
#  - Either way, the `mm_hash` and corresponding `mm_inputs` are sent to
#    the engine for model execution.
#
# Both Client and Server must perform cache updates and evictions based on the
# same item size, and must process the items in the same order. This ensures
# that the keys of MultiModalInputCacheClient and MultiModalInputCacheServer
# are mirrored, allowing us to determine in P0 whether a key is cached in
# MultiModalInputCacheServer by querying MultiModalInputCacheClient without
# having to communicate with P1.


class MultiModalInputCacheClient:
    """Used by P0 to check whether multi-modal kwargs are cached in P1.

    Only per-item metadata (built with ``MultiModalCacheItemMetadata.wraps``),
    not the kwargs themselves, is stored under each ``mm_hash``. This keeps
    the key set in sync with ``MultiModalInputCacheServer`` without holding
    a second copy of the payload in P0.
    """

    def __init__(self, model_config: "ModelConfig",
                 mm_registry: MultiModalRegistry) -> None:
        super().__init__()

        # When disabled, `get_and_update` passes inputs through unchanged.
        self.enabled = mm_registry.enable_mm_input_cache(model_config)

        # LRU cache whose capacity comes from the model config. Its values
        # are metadata objects rather than full MultiModalKwargs.
        self.mm_cache = MultiModalCache.get_lru_cache(
            model_config.get_mm_input_cache_gb(),
            MultiModalCacheItemMetadata,
        )

    def get_and_update(
        self,
        mm_inputs: Sequence[MultiModalKwargs],
        mm_hashes: list[str],
    ) -> Sequence[Optional[MultiModalKwargs]]:
        """Blank out inputs already cached in P1 and record the new ones.

        Args:
            mm_inputs: Processed multi-modal kwargs, one per input item.
            mm_hashes: Hash of each item, aligned index-by-index with
                ``mm_inputs``.

        Returns:
            A sequence aligned with ``mm_inputs``. An entry is ``None`` when
            its ``mm_hash`` is already cached, so it need not be sent to P1;
            otherwise it is the original input. When caching is disabled,
            ``mm_inputs`` itself is returned.
        """
        assert len(mm_inputs) == len(mm_hashes)

        if not self.enabled:
            assert is_list_of(mm_inputs, MultiModalKwargs)
            return mm_inputs

        full_mm_inputs = list[Optional[MultiModalKwargs]]()
        for mm_input, mm_hash in zip(mm_inputs, mm_hashes):
            if self.mm_cache.get(mm_hash) is not None:
                # Hit: the server already holds this item, so don't resend it.
                mm_input = None
            else:
                # Miss: record only the metadata (the server keeps the full
                # kwargs) so both caches evict based on the same item size.
                self.mm_cache[mm_hash] = \
                    MultiModalCacheItemMetadata.wraps(mm_input)

            full_mm_inputs.append(mm_input)

        return full_mm_inputs

    def reset(self) -> None:
        """Remove every entry from the cache."""
        self.mm_cache.clear()


class MultiModalInputCacheServer:
    """Used by P1 to avoid requiring past multi-modal kwargs from P0.

    Stores the full ``MultiModalKwargs`` under each ``mm_hash``, so that
    items P0 omitted (received as ``None``) can be restored from the cache.
    """

    def __init__(self, model_config: "ModelConfig",
                 mm_registry: MultiModalRegistry) -> None:
        super().__init__()

        # When disabled, `get_and_update` passes inputs through unchanged.
        self.enabled = mm_registry.enable_mm_input_cache(model_config)

        # LRU cache with the same capacity setting as the P0 client. Its
        # values are the full MultiModalKwargs.
        self.mm_cache = MultiModalCache.get_lru_cache(
            model_config.get_mm_input_cache_gb(),
            MultiModalKwargs,
        )

    def get_and_update(
        self,
        mm_inputs: Sequence[Optional[MultiModalKwargs]],
        mm_hashes: list[str],
    ) -> Sequence[MultiModalKwargs]:
        """Restore omitted inputs from the cache and store newly sent ones.

        Args:
            mm_inputs: Multi-modal kwargs received from P0, one per item.
                An entry is ``None`` when P0 determined that it is cached.
            mm_hashes: Hash of each item, aligned index-by-index with
                ``mm_inputs``.

        Returns:
            A sequence aligned with ``mm_inputs`` in which every ``None``
            has been replaced by the cached kwargs for its ``mm_hash``.
            When caching is disabled, ``mm_inputs`` itself is returned.
        """
        assert len(mm_inputs) == len(mm_hashes)

        if not self.enabled:
            assert is_list_of(mm_inputs, MultiModalKwargs)
            return mm_inputs

        full_mm_inputs = list[MultiModalKwargs]()
        for mm_input, mm_hash in zip(mm_inputs, mm_hashes):
            if mm_input is None:
                # P0 omitted this item because it believes it is cached here.
                # NOTE(review): a missing key here would mean the client and
                # server caches are out of sync; presumably the cache raises
                # KeyError in that case - confirm against MultiModalCache.
                mm_input = self.mm_cache[mm_hash]
            else:
                # Newly sent item: store it so later requests can omit it.
                self.mm_cache[mm_hash] = mm_input

            full_mm_inputs.append(mm_input)

        return full_mm_inputs

    def reset(self) -> None:
        """Remove every entry from the cache."""
        self.mm_cache.clear()