mm_input_mapper.py 6.86 KB
Newer Older
1
from typing import Any, Dict, List, Optional
2

3
4
5
import PIL
from blake3 import blake3

6
from vllm.config import ModelConfig
7
8
from vllm.inputs import PromptType
from vllm.logger import init_logger
9
10
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalDataDict,
                             MultiModalKwargs, MultiModalRegistry)
11
from vllm.utils import LRUCache
12
13
14
15
16
17
18
19
20
21
22
23
24

# Module-level logger.
logger = init_logger(__name__)

# Multimodal (MM) preprocessor caching is split between a client and a server:
# the client runs in the frontend process (=P0) and the server runs in the
# core process (=P1).
#
# -- Client: Executes the MM mapper and caches its results. On a cache hit it
#            passes `None` instead of the cached input.
# -- Server: Caches the inputs it receives and substitutes its own cached copy
#            wherever it receives `None`.
#
# Because both caches are keyed by the same hashes and use the same size, they
# mirror each other, which avoids serializing "mm_inputs" (like pixel values)
# between the client (=P0) and server (=P1) processes on a cache hit.
25

26
27
28
29
# Capacity passed to the LRU caches of both MMInputMapperClient and
# MMInputMapperServer. Both sides must use the same cache size so that their
# contents stay mirrored.
# TODO: Tune the MM cache size
MM_CACHE_SIZE = 256
30

31
32

class MMInputMapperClient:
    """Client-side multimodal (MM) input mapper with an optional result cache.

    Runs the MM input mapper on image inputs (or reuses precomputed inputs)
    and, unless the preprocessor cache is disabled in the model config,
    stores every result in an LRU cache keyed by the input's hash. Inputs
    that are already cached are reported as ``None`` so that they do not have
    to be sent to the server.
    """

    def __init__(
        self,
        model_config: ModelConfig,
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
    ):
        """Create the input mapper and the (initially empty) cache.

        Args:
            model_config: Model configuration; its
                ``disable_mm_preprocessor_cache`` flag turns caching off.
            mm_registry: Registry used to create the input mapper and to
                initialize the per-prompt multimodal limits.
        """
        self.model_config = model_config
        self.mm_registry = mm_registry
        self.multi_modal_input_mapper = mm_registry.create_input_mapper(
            model_config)
        self.mm_registry.init_mm_limits_per_prompt(model_config)

        # Init cache
        self.use_cache = not model_config.disable_mm_preprocessor_cache
        self.mm_cache = LRUCache[str, MultiModalKwargs](MM_CACHE_SIZE)

        # DEBUG: Set to None to disable
        self.mm_debug_cache_hit_ratio_steps: Optional[int] = None
        self.mm_cache_hits = 0
        self.mm_cache_total = 0

    def cache_hit_ratio(self, steps: int) -> None:
        """Log the cache hit ratio once every ``steps`` processed inputs.

        Nothing is logged until at least one input has been counted.
        """
        if self.mm_cache_total > 0 and self.mm_cache_total % steps == 0:
            logger.debug("MMInputMapper: cache_hit_ratio = %.2f ",
                         self.mm_cache_hits / self.mm_cache_total)

    # TODO: Support modalities beyond image.
    def process_inputs(
        self,
        mm_data: MultiModalDataDict,
        mm_hashes: Optional[List[str]],
        mm_processor_kwargs: Optional[Dict[str, Any]],
        precomputed_mm_inputs: Optional[List[MultiModalKwargs]],
    ) -> List[Optional[MultiModalKwargs]]:
        """Map each image input to its MM kwargs, applying the cache.

        Args:
            mm_data: Multimodal data; only its ``"image"`` entry is read, and
                only when ``precomputed_mm_inputs`` is None. A single image
                is treated as a one-element list.
            mm_hashes: One hash per input. Required when caching is enabled;
                not read otherwise.
            mm_processor_kwargs: Extra keyword arguments forwarded to the
                input mapper.
            precomputed_mm_inputs: Already-processed inputs to reuse instead
                of running the input mapper (one per input).

        Returns:
            One entry per input, in order. An entry is ``None`` when the
            input was found in the cache (so it is not sent to the server);
            otherwise it is the processed input.

        Raises:
            AssertionError: If caching is enabled and ``mm_hashes`` is None
                or its length differs from the number of inputs.
            KeyError: If ``precomputed_mm_inputs`` is None and ``mm_data``
                has no ``"image"`` entry.
        """
        if precomputed_mm_inputs is None:
            image_inputs = mm_data["image"]
            if not isinstance(image_inputs, list):
                image_inputs = [image_inputs]
            num_inputs = len(image_inputs)
        else:
            num_inputs = len(precomputed_mm_inputs)

        # Sanity
        if self.use_cache:
            assert mm_hashes is not None
            assert num_inputs == len(mm_hashes)

        # Process each image input separately, so that later we can schedule
        # them in a fine-grained manner.
        # Apply caching (if enabled) and reuse precomputed inputs (if provided)
        ret_inputs: List[Optional[MultiModalKwargs]] = []
        for input_id in range(num_inputs):
            if self.mm_debug_cache_hit_ratio_steps is not None:
                self.cache_hit_ratio(self.mm_debug_cache_hit_ratio_steps)

            # Both stay None when caching is disabled, which forces the
            # "cache miss" path below.
            mm_hash: Optional[str] = None
            mm_input: Optional[MultiModalKwargs] = None
            if self.use_cache:
                assert mm_hashes is not None
                mm_hash = mm_hashes[input_id]
                mm_input = self.mm_cache.get(mm_hash)

            self.mm_cache_total += 1
            if mm_input is None:
                if precomputed_mm_inputs is not None:
                    # Reuse precomputed input (for merged preprocessor)
                    mm_input = precomputed_mm_inputs[input_id]
                else:
                    # Apply MM mapper
                    mm_input = self.multi_modal_input_mapper(
                        {"image": [image_inputs[input_id]]},
                        mm_processor_kwargs=mm_processor_kwargs,
                    )

                if self.use_cache:
                    # Add to cache
                    assert mm_hash is not None
                    self.mm_cache.put(mm_hash, mm_input)
            else:
                self.mm_cache_hits += 1
                mm_input = None  # Avoids sending mm_input to Server

            ret_inputs.append(mm_input)

        return ret_inputs
117
118
119
120


class MMInputMapperServer:
    """Server-side cache that restores inputs the client left out.

    Keeps an LRU cache of received inputs keyed by hash, and fills every
    ``None`` entry of an incoming list with the cached input for its hash.
    """

    def __init__(self, model_config):
        """Read the cache-enable flag from ``model_config``; build the cache.
        """
        self.use_cache = not model_config.disable_mm_preprocessor_cache
        self.mm_cache = LRUCache[str, MultiModalKwargs](MM_CACHE_SIZE)

    def process_inputs(
        self,
        mm_inputs: List[Optional[MultiModalKwargs]],
        mm_hashes: List[str],
    ) -> List[MultiModalKwargs]:
        """Return ``mm_inputs`` with every ``None`` entry filled from cache.

        Args:
            mm_inputs: Received inputs; ``None`` marks an entry that must be
                looked up in the cache.
            mm_hashes: Hash of each input, aligned with ``mm_inputs``.

        Returns:
            ``mm_inputs`` itself when caching is disabled; otherwise a new
            list holding, per position, the received input or the cached one.

        Raises:
            AssertionError: If the two lists differ in length, a hash is
                None, or a ``None`` input has no cache entry.
        """
        assert len(mm_inputs) == len(mm_hashes)

        if not self.use_cache:
            return mm_inputs

        cache = self.mm_cache
        resolved = []
        for received, key in zip(mm_inputs, mm_hashes):
            assert key is not None

            if received is not None:
                # Input was sent: remember it for later lookups.
                cache.put(key, received)
                resolved.append(received)
                continue

            # Input was withheld: it must already be cached here.
            cached = cache.get(key)
            assert cached is not None
            resolved.append(cached)

        return resolved


class MMHasher:
    """Computes blake3 hex digests for PIL image inputs."""

    def hash_dummy_mm_data(
            self,
            mm_data: Optional[MultiModalDataDict]) -> Optional[List[str]]:
        """Hash user-defined dummy multimodal data used for profiling.

        Args:
            mm_data: Multimodal data with an ``"image"`` entry, or None.

        Returns:
            None when ``mm_data`` is None; otherwise one hex digest per image.

        Raises:
            KeyError: If ``mm_data`` has no ``"image"`` entry.
            AssertionError: If the image entry is a dict without a PIL image
                under ``"raw_mm_data"``, or an input is not a PIL image.
        """
        if mm_data is None:
            return None

        image_inputs = mm_data['image']

        # This is a temporary workaround for models (e.g, Molmo) that
        # process multimodal data in the input processor (therefore
        # image_inputs is MultiModalKwargs instead of raw input format).
        # `raw_mm_data` with the original input format is expected
        # in this case.
        if isinstance(image_inputs, dict):
            assert "raw_mm_data" in image_inputs and isinstance(
                image_inputs["raw_mm_data"], PIL.Image.Image)
            # NOTE: `pop` also removes the entry from the caller's dict.
            image_inputs = image_inputs.pop("raw_mm_data")

        return self.hash_images(image_inputs)

    def hash_prompt_mm_data(self, prompt: PromptType) -> Optional[List[str]]:
        """Hash multimodal data in the user input prompt if they exist.

        Args:
            prompt: The user prompt.

        Returns:
            None when the prompt is a plain string or carries no
            ``"multi_modal_data"`` entry; otherwise one hex digest per image.

        Raises:
            KeyError: If the multimodal data has no ``"image"`` entry.
        """
        # On a plain string, `in` is a substring test, so a text prompt that
        # merely mentions "multi_modal_data" would be indexed like a mapping
        # below and raise TypeError. A string never carries multimodal data.
        if isinstance(prompt, str) or "multi_modal_data" not in prompt:
            return None

        mm_data = prompt["multi_modal_data"]
        image_inputs = mm_data["image"]

        return self.hash_images(image_inputs)

    def hash_images(self, image_inputs) -> Optional[List[str]]:
        """Hash PIL image objects to strings.

        Args:
            image_inputs: A PIL image or a non-empty list of PIL images.

        Returns:
            One blake3 hex digest per image, in input order. The digest
            covers the image mode, its size and its raw pixel bytes.

        Raises:
            AssertionError: If the list is empty or an item is not a PIL
                image.
        """
        if not isinstance(image_inputs, list):
            image_inputs = [image_inputs]
        assert len(image_inputs) > 0

        ret = []
        for image in image_inputs:
            assert isinstance(image, PIL.Image.Image)

            hasher = blake3()

            # Raw pixel bytes alone are ambiguous: images with the same
            # buffer but a different mode or width/height would otherwise
            # get the same digest. Mix mode and size into the hash first.
            width, height = image.size
            header = f"{image.mode}:{width}x{height}:"
            hasher.update(header.encode("utf-8"))

            # Hash image bytes
            hasher.update(image.tobytes())
            ret.append(hasher.hexdigest())

        return ret