# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from collections.abc import Sequence
from typing import Optional
5

6
from vllm.envs import VLLM_MM_INPUT_CACHE_GIB
7
from vllm.multimodal import MultiModalKwargs
8
from vllm.multimodal.processing import ProcessingCache
9
from vllm.utils import is_list_of
10

11
12
# The idea of multimodal preprocessing caching is based on having a client and
# a server, where the client executes in the frontend process (=P0) and the
# server in the core process (=P1).
#
# -- Client:
#  - BaseMultiModalProcessor to process MultiModalData into MultiModalKwargs
#    with built-in caching functionality, with mm_hash as its identifier.
#  - MirroredProcessingCache to keep track of the cached entries and
#    determine whether to send the MultiModalKwargs to P1.
#
# -- Server:
#  - MirroredProcessingCache to store the MultiModalKwargs from P0.
#
# The caching for both client and server is mirrored, and this allows us
# to avoid the serialization of "mm_inputs" (like pixel values) between
# client (=P0) and server (=P1) processes if the mm_hash is found in the client
# cache.

# Both Client and Server must use the same cache size
# (to perform mirrored caching). This cache size is set by the environment
# variable VLLM_MM_INPUT_CACHE_GIB.
32

33

34
class MirroredProcessingCache:
    """Cache of ``MultiModalKwargs`` keyed by ``mm_hash``, mirrored on P0/P1.

    The same class is used on both sides of the process boundary: the
    frontend process (P0) calls :meth:`get_and_update_p0` and the core
    process (P1) calls :meth:`get_and_update_p1`. Both methods apply the same
    insertions to ``self.mm_cache`` for the same sequence of hashes, so an
    entry that P0 replaces with ``None`` can be looked up again on P1.
    """

    def __init__(self, model_config):
        """Create the cache and decide whether caching is enabled.

        Args:
            model_config: Object exposing a ``multimodal_config`` attribute.
                Caching is disabled only when that attribute is not ``None``
                and its ``disable_mm_preprocessor_cache`` flag is truthy.
        """
        mm_config = model_config.multimodal_config
        disable_mm_preprocessor_cache = (
            mm_config is not None and mm_config.disable_mm_preprocessor_cache)
        self.use_cache = not disable_mm_preprocessor_cache
        # The cache object is created even when caching is disabled so that
        # reset() always has something to clear.
        self.mm_cache = ProcessingCache.get_lru_cache(VLLM_MM_INPUT_CACHE_GIB,
                                                      MultiModalKwargs)

    def get_and_update_p0(
        self,
        mm_inputs: Sequence[MultiModalKwargs],
        mm_hashes: list[str],
    ) -> Sequence[Optional[MultiModalKwargs]]:
        """Client side (P0): drop inputs that are already cached.

        Args:
            mm_inputs: Processed multimodal inputs, one per hash.
            mm_hashes: Identifiers of ``mm_inputs``; must have the same
                length.

        Returns:
            ``mm_inputs`` unchanged when caching is disabled. Otherwise a new
            list in which every input whose hash is already in the cache is
            replaced by ``None``; all other inputs are stored in the cache
            and passed through.
        """
        assert len(mm_inputs) == len(mm_hashes)

        if not self.use_cache:
            assert is_list_of(mm_inputs, MultiModalKwargs)
            return mm_inputs

        full_mm_inputs = list[Optional[MultiModalKwargs]]()
        for mm_input, mm_hash in zip(mm_inputs, mm_hashes):
            if self.mm_cache.get(mm_hash) is not None:
                # Cache hit: send a placeholder instead of the payload.
                mm_input = None
            else:
                self.mm_cache[mm_hash] = mm_input

            full_mm_inputs.append(mm_input)

        return full_mm_inputs

    def get_and_update_p1(
        self,
        mm_inputs: Sequence[Optional[MultiModalKwargs]],
        mm_hashes: list[str],
    ) -> Sequence[MultiModalKwargs]:
        """Server side (P1): restore inputs that P0 replaced with ``None``.

        Args:
            mm_inputs: Inputs as produced by :meth:`get_and_update_p0`;
                ``None`` marks an entry expected to be in the cache.
            mm_hashes: Identifiers of ``mm_inputs``; must have the same
                length.

        Returns:
            ``mm_inputs`` unchanged when caching is disabled. Otherwise a new
            list with every ``None`` replaced by the cached value; inputs
            that were sent in full are stored in the cache and passed
            through.

        Raises:
            KeyError: If an entry is ``None`` but its hash is not in the
                cache.
        """
        assert len(mm_inputs) == len(mm_hashes)

        if not self.use_cache:
            assert is_list_of(mm_inputs, MultiModalKwargs)
            return mm_inputs

        full_mm_inputs = list[MultiModalKwargs]()
        for mm_input, mm_hash in zip(mm_inputs, mm_hashes):
            if mm_input is None:
                try:
                    mm_input = self.mm_cache[mm_hash]
                except KeyError:
                    # Same exception type as before, but say what went wrong
                    # instead of surfacing a bare hash.
                    raise KeyError(
                        f"mm_hash {mm_hash!r} was sent without its "
                        "MultiModalKwargs but is not present in the cache"
                    ) from None
            else:
                self.mm_cache[mm_hash] = mm_input

            full_mm_inputs.append(mm_input)

        return full_mm_inputs

    def reset(self) -> bool:
        """Remove every entry from the cache and return ``True``."""
        self.mm_cache.clear()

        return True