mm_input_mapper.py 5.99 KB
Newer Older
1
from typing import Any, Dict, List, Optional, Tuple
2

3
4
5
import PIL
from blake3 import blake3

6
from vllm.config import ModelConfig
7
8
from vllm.inputs import PromptType
from vllm.logger import init_logger
9
10
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalDataDict,
                             MultiModalKwargs, MultiModalRegistry)
11
12
13
14
15
16
17
18
19
20
21
22
23
24
from vllm.v1.utils import LRUDictCache

logger = init_logger(__name__)

# The idea of MM preprocessor caching is based on having a client and a server,
# where the client executes in the frontend process (=P0) and the server in the
# core process (=P1).
#
# -- Client: Executes the MM mapper and caches the results.
# -- Server: Caches the results it receives from the client.
#
# Both caches are kept as mirrors of each other. This lets the client skip
# sending "mm_inputs" (like pixel values) that the server already holds,
# avoiding their serialization between client (=P0) and server (=P1).

# Both Client and Server must use the same cache size
# (to perform mirrored caching)
# TODO: Tune the MM cache size
MM_CACHE_SIZE = 256


class MMInputMapperClient:
    """Frontend-side (P0) multimodal input mapper with hash-keyed caching.

    Runs the registry's multimodal input mapper on each image of a request.
    When per-image hashes are supplied, the mapped ``MultiModalKwargs`` are
    cached by hash. On a cache hit the returned input is ``None``, so the
    value is not sent to the server again.
    """

    def __init__(
        self,
        model_config: ModelConfig,
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
    ):
        self.model_config = model_config
        self.mm_registry = mm_registry
        self.multi_modal_input_mapper = mm_registry.create_input_mapper(
            model_config)
        self.mm_registry.init_mm_limits_per_prompt(model_config)

        # Same size as the server-side cache (required for mirrored caching).
        self.mm_cache = LRUDictCache[str, MultiModalKwargs](MM_CACHE_SIZE)

        # DEBUG: Set to a step count to periodically log the cache hit ratio;
        # None disables the logging.
        self.mm_debug_cache_hit_ratio_steps = None
        self.mm_cache_hits = 0
        # Number of cache lookups performed (only counted when hashing is on).
        self.mm_cache_total = 0

    def cache_hit_ratio(self, steps):
        """Debug-log the cache hit ratio every ``steps`` cache lookups."""
        if self.mm_cache_total > 0 and self.mm_cache_total % steps == 0:
            logger.debug("MMInputMapper: cache_hit_ratio = %.2f ",
                         self.mm_cache_hits / self.mm_cache_total)

    # TODO: Support modalities beyond image.
    def process_inputs(
        self,
        mm_data: MultiModalDataDict,
        mm_hashes: Optional[List[str]],
        mm_processor_kwargs: Optional[Dict[str, Any]],
        precomputed_mm_inputs: Optional[List[MultiModalKwargs]],
    ) -> Tuple[List[Optional[MultiModalKwargs]], Optional[List[str]]]:
        """Map each image of a request to ``MultiModalKwargs``.

        Args:
            mm_data: Multimodal data. Only its ``"image"`` entry (a single
                item or a list) is read, and only when
                ``precomputed_mm_inputs`` is None.
            mm_hashes: One hash per image, or None to disable caching.
            mm_processor_kwargs: Extra kwargs forwarded to the input mapper.
            precomputed_mm_inputs: Already-mapped inputs (for the merged
                preprocessor); when given, the mapper is not run.

        Returns:
            ``(inputs, hashes)``:
            - ``inputs[i]`` is None when the i-th item was a cache hit.
            - ``hashes`` is None when ``mm_hashes`` is None; otherwise it
              is the hashes in input order.
        """
        if precomputed_mm_inputs is None:
            image_inputs = mm_data["image"]
            if not isinstance(image_inputs, list):
                image_inputs = [image_inputs]
            num_inputs = len(image_inputs)
        else:
            num_inputs = len(precomputed_mm_inputs)

        # Hash-based caching is enabled only when hashes are provided.
        if mm_hashes is not None:
            assert num_inputs == len(mm_hashes), (
                f"num_inputs = {num_inputs} len(mm_hashes) = {len(mm_hashes)}")

        # Process each image input separately, so that later we can schedule
        # them in a fine-grained manner.
        # Apply caching (if enabled) and reuse precomputed inputs (if provided)
        ret_hashes: Optional[List[str]] = (
            [] if mm_hashes is not None else None)
        ret_inputs: List[Optional[MultiModalKwargs]] = []
        for input_id in range(num_inputs):
            if self.mm_debug_cache_hit_ratio_steps is not None:
                self.cache_hit_ratio(self.mm_debug_cache_hit_ratio_steps)

            mm_hash: Optional[str] = None
            mm_input: Optional[MultiModalKwargs] = None
            if mm_hashes is not None:
                mm_hash = mm_hashes[input_id]
                mm_input = self.mm_cache.get(mm_hash)
                # Only real cache lookups count towards the hit ratio.
                self.mm_cache_total += 1

            if mm_input is None:
                if precomputed_mm_inputs is not None:
                    # Reuse precomputed input (for merged preprocessor)
                    mm_input = precomputed_mm_inputs[input_id]
                else:
                    # Apply MM mapper
                    mm_input = self.multi_modal_input_mapper(
                        {"image": [image_inputs[input_id]]},
                        mm_processor_kwargs=mm_processor_kwargs,
                    )

                if mm_hashes is not None:
                    # Add to cache; the server stores every non-None input
                    # it receives under the same hash.
                    assert mm_hash is not None
                    self.mm_cache.put(mm_hash, mm_input)
            else:
                self.mm_cache_hits += 1
                mm_input = None  # Avoids sending mm_input to Server

            if ret_hashes is not None:
                assert mm_hash is not None
                ret_hashes.append(mm_hash)
            ret_inputs.append(mm_input)

        return ret_inputs, ret_hashes


class MMInputMapperServer:
    """Core-side (P1) counterpart of the client's multimodal input cache.

    Inputs the client omitted (sent as None because they hit the client's
    cache) are looked up here by hash. Every input that is received is
    cached, keeping this cache a mirror of the client's.
    """

    def __init__(self):
        # Same size as the client-side cache (required for mirrored caching).
        self.mm_cache = LRUDictCache[str, MultiModalKwargs](MM_CACHE_SIZE)

    def process_inputs(
        self,
        mm_inputs: List[Optional[MultiModalKwargs]],
        mm_hashes: List[str],
    ) -> List[MultiModalKwargs]:
        """Resolve the client's (possibly None) inputs into full inputs.

        Args:
            mm_inputs: One entry per item. None means the value must be
                taken from this cache using the matching hash.
            mm_hashes: One hash per item, aligned with ``mm_inputs``.

        Returns:
            The full inputs, in the same order as ``mm_inputs``.

        Raises:
            AssertionError: if the two lists differ in length, or a None
                input's hash is missing from the cache (the client and
                server caches are out of sync).
        """
        assert len(mm_inputs) == len(mm_hashes), (
            f"len(mm_inputs) = {len(mm_inputs)} "
            f"len(mm_hashes) = {len(mm_hashes)}")

        full_mm_inputs: List[MultiModalKwargs] = []
        for mm_input, mm_hash in zip(mm_inputs, mm_hashes):
            assert mm_hash is not None, "mm_hash must not be None"
            if mm_input is None:
                # The client had this input cached and did not resend it.
                mm_input = self.mm_cache.get(mm_hash)
                assert mm_input is not None, (
                    f"mm_input for hash {mm_hash} is missing from the server "
                    "cache (client/server MM caches are out of sync)")
            else:
                self.mm_cache.put(mm_hash, mm_input)

            full_mm_inputs.append(mm_input)

        return full_mm_inputs


class MMHasher:
    """Computes per-image content hashes used as MM cache keys."""

    def __init__(self):
        pass

    def hash(self, prompt: PromptType) -> Optional[List[str]]:
        """Return one blake3 hex digest per image in ``prompt``.

        Returns None when the prompt carries no ``multi_modal_data``. This
        includes plain string prompts.

        Each digest covers the image mode, size and raw pixel bytes, so
        images with identical pixel buffers but different shapes or modes
        get distinct keys.
        """
        # A plain text prompt is a str: `in` would be a substring test and
        # indexing it with a str key would raise TypeError.
        if isinstance(prompt, str) or "multi_modal_data" not in prompt:
            return None

        # `import PIL` alone does not guarantee the `PIL.Image` submodule is
        # loaded, so import it explicitly.
        from PIL import Image

        mm_data = prompt["multi_modal_data"]
        image_inputs = mm_data["image"]
        if not isinstance(image_inputs, list):
            image_inputs = [image_inputs]
        assert len(image_inputs) > 0

        ret = []
        for image in image_inputs:
            assert isinstance(image, Image.Image)

            hasher = blake3()
            # tobytes() holds only raw pixel data, so mix in mode and size to
            # avoid collisions such as a 2x8 vs an 8x2 RGB image.
            hasher.update(f"{image.mode}:{image.size}".encode())
            hasher.update(image.tobytes())
            ret.append(hasher.hexdigest())

        return ret