# SPDX-License-Identifier: Apache-2.0

from typing import Optional

from vllm.envs import VLLM_MM_INPUT_CACHE_GIB
from vllm.multimodal import MultiModalKwargs
from vllm.multimodal.processing import ProcessingCache

# Multimodal preprocessing caching follows a client/server design: the client
# runs in the frontend process (P0) and the server runs in the core process
# (P1).
#
# -- Client:
#  - BaseMultiModalProcessor processes MultiModalData into MultiModalKwargs
#    with built-in caching, using mm_hash as the cache key.
#
# -- Server:
#  - MMInputCacheServer caches the MultiModalKwargs it receives.
#
# The client and server caches mirror each other. Because of this, when an
# mm_hash is found in the client cache, the "mm_inputs" (e.g. pixel values)
# do not need to be serialized and sent from the client (P0) to the server
# (P1). The server restores them from its own cache instead.
#
# To stay mirrored, the client and server must use the same cache size. That
# size is set by the environment variable VLLM_MM_INPUT_CACHE_GIB.


class MMInputCacheServer:
    """Server-side (P1) mirror of the client's multimodal input cache.

    Entries are keyed by ``mm_hash``. A ``None`` entry in the inputs passed to
    :meth:`get_and_update` means the client omitted that payload. The server
    then restores it from this cache. Any non-``None`` entry is stored so that
    the server-side cache keeps mirroring the client-side one.
    """

    def __init__(self, model_config):
        # Caching is on unless the model config disables the multimodal
        # preprocessor cache.
        self.use_cache = not model_config.disable_mm_preprocessor_cache
        # The size comes from VLLM_MM_INPUT_CACHE_GIB. The client must use the
        # same value so that the two caches stay mirrored.
        self.mm_cache = ProcessingCache.get_lru_cache(VLLM_MM_INPUT_CACHE_GIB,
                                                      MultiModalKwargs)

    def get_and_update(
        self,
        mm_inputs: list[Optional[MultiModalKwargs]],
        mm_hashes: list[str],
    ) -> list[MultiModalKwargs]:
        """Fill in omitted inputs from the cache and record received ones.

        Args:
            mm_inputs: One entry per multimodal item. ``None`` marks an item
                whose payload was omitted and must be restored from this
                cache. Any other value is stored under its hash.
            mm_hashes: The cache key of each item, aligned with ``mm_inputs``.

        Returns:
            The complete inputs, in the same order as ``mm_inputs``. When
            caching is disabled, ``mm_inputs`` itself is returned.

        Raises:
            ValueError: If the two lists differ in length. Also raised if an
                entry cannot be resolved: a ``None`` input while caching is
                disabled, or a ``None`` hash while caching is enabled.
            KeyError: If the cache lookup for a ``None`` input misses, which
                means the client and server caches have diverged. This
                assumes the LRU cache signals a miss with ``KeyError``
                (TODO confirm).
        """
        # This is an explicit check, not an ``assert``. Under ``python -O``
        # the assert would be stripped, and ``zip`` below would then silently
        # drop the trailing items of the longer list.
        if len(mm_inputs) != len(mm_hashes):
            raise ValueError(
                f"Got {len(mm_inputs)} mm_inputs but {len(mm_hashes)} "
                "mm_hashes; they must be the same length.")

        if not self.use_cache:
            # With caching off there is nothing to restore from, so every
            # payload must have been sent in full.
            if any(mm_input is None for mm_input in mm_inputs):
                raise ValueError(
                    "Received a None mm_input while the multimodal input "
                    "cache is disabled.")
            return mm_inputs

        full_mm_inputs: list[MultiModalKwargs] = []
        for mm_input, mm_hash in zip(mm_inputs, mm_hashes):
            if mm_hash is None:
                raise ValueError("mm_hash must not be None when the "
                                 "multimodal input cache is enabled.")
            if mm_input is None:
                # The client omitted the payload; restore it from the mirror.
                try:
                    mm_input = self.mm_cache[mm_hash]
                except KeyError as err:
                    raise KeyError(
                        f"mm_hash {mm_hash!r} not found in the server-side "
                        "multimodal input cache; the client and server "
                        "caches are out of sync.") from err
            else:
                # A full payload was sent; record it so the mirror stays in
                # step with the client's cache.
                self.mm_cache[mm_hash] = mm_input

            full_mm_inputs.append(mm_input)

        return full_mm_inputs