test_processor_kwargs.py 14.9 KB
Newer Older
1
from array import array
2
from typing import Callable, Dict, Mapping, Optional
3
4
5
6
7
from unittest.mock import patch

import pytest
import torch

8
from vllm.inputs import (DecoderOnlyInputs, DummyData, InputContext,
9
                         InputRegistry, ProcessorInputs, token_inputs)
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
from vllm.multimodal import MultiModalRegistry
from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData

from ..models.utils import build_model_context

# Model used by the fast tests, where the specific model is irrelevant
DUMMY_MODEL_ID: str = "facebook/opt-125m"
# Model used by the tests that require multimodal support
MULTIMODAL_MODEL_ID: str = "microsoft/Phi-3.5-vision-instruct"

# mm_processor_kwargs overrides are exercised by mocking every callable that
# consumes them, then checking that a value passed through the processor
# kwargs is what ends up driving the result (e.g. the sequence length).
DEFAULT_NUM_CROPS: int = 4
NUM_CROPS_OVERRIDE: int = 16


# Mocks for all of the places that we use the mm_processor_kwargs
# to override values in different callables
@pytest.fixture
def use_processor_mock():
    """Patch the registry's model input processor lookup with a test double.

    While the fixture is active,
    ``InputRegistry._get_model_input_processor`` is replaced by a mock whose
    return value is ``custom_processor``. That callable accepts ``num_crops``
    as a keyword-only argument and echoes it back through the
    ``mm_processor_kwargs`` of the inputs it returns, so tests can observe
    which value was forwarded.
    """

    def custom_processor(ctx: InputContext,
                         inputs: DecoderOnlyInputs,
                         *,
                         num_crops=DEFAULT_NUM_CROPS):
        # For testing purposes, we don't worry about the prompt
        return token_inputs(prompt_token_ids=[],
                            mm_processor_kwargs={"num_crops": num_crops})

    with patch("vllm.inputs.registry.InputRegistry._get_model_input_processor",
               return_value=custom_processor):
        yield


@pytest.fixture
def use_dummy_data_mock():
    """Patch the registry's default dummy data factory with a test double.

    While the fixture is active,
    ``InputRegistry._default_dummy_data_factory`` is replaced by
    ``custom_dummy_data_factory``. The replacement ignores ``seq_len`` and
    ``mm_counts`` and builds sequence data holding ``num_crops`` zero tokens,
    so the length of the resulting prompt reveals which ``num_crops`` value
    was forwarded.
    """

    def custom_dummy_data_factory(self,
                                  ctx: InputContext,
                                  seq_len: int,
                                  mm_counts: Mapping[str, int],
                                  *,
                                  num_crops=DEFAULT_NUM_CROPS):
        # One zero token per crop: the token count encodes num_crops
        seq_data = SequenceData(
            array(VLLM_TOKEN_ID_ARRAY_TYPE, [0] * num_crops))
        return DummyData(seq_data, None)

    with patch(
            "vllm.inputs.registry.InputRegistry._default_dummy_data_factory",
            custom_dummy_data_factory):
        yield


def mm_model_cls():
    """Return the ``Phi3VForCausalLM`` model class.

    The import happens at call time rather than at module import time to
    avoid a CUDA reinitialization error.
    """
    from vllm.model_executor.models import phi3v

    return phi3v.Phi3VForCausalLM


# Callables whose signatures match the max token calculation and the input
# mapper respectively, each extended with a keyword-only num_crops override.
def get_num_crops(ctx, *, num_crops=DEFAULT_NUM_CROPS):
    """Return ``num_crops`` unchanged so tests can see which value was used."""
    return num_crops


def custom_mapper(ctx, data, *, num_crops=DEFAULT_NUM_CROPS):
    """Return zero-filled pixel values whose second dim is ``num_crops + 1``.

    The input ``data`` is ignored; only the shape of the returned tensor,
    ``(1, num_crops + 1, 3, 336, 336)``, carries information.
    """
    return {"pixel_values": torch.zeros(size=(1, num_crops + 1, 3, 336, 336))}


80
### Tests for default processor logic & mm_processor_kwargs wrapping
81
82
83
84
85
def test_default_processor_is_a_noop():
    """Ensure that by default, there is no processor override."""
    dummy_registry = InputRegistry()
    ctx = build_model_context(DUMMY_MODEL_ID)
    processor = dummy_registry.create_input_processor(ctx.model_config)
    proc_inputs = token_inputs(prompt_token_ids=[], prompt="")
    proc_outputs = processor(inputs=proc_inputs)
    # Identity, not equality: the very same object must come back
    assert proc_inputs is proc_outputs


91
92
def _get_num_crops_info(init_num_crops: int, inference_num_crops: int):
    """Get the init / inference kwargs and expected num_crops for this test."""
93
94
95
    # If we have a value for num_crops, pass the override value and make
    # sure we get that value as a return-value from out mock processor,
    # otherwise fall back to the default value
96
97
    init_kwargs = None if init_num_crops is None else {
        "num_crops": init_num_crops
98
    }
99
100
101
102
103
104
105
106
107
108
109
110
    inference_kwargs = None if inference_num_crops is None else {
        "num_crops": inference_num_crops
    }
    if inference_num_crops is not None:
        expected_seq_count = inference_num_crops
    elif init_num_crops is not None:
        expected_seq_count = init_num_crops
    else:
        expected_seq_count = DEFAULT_NUM_CROPS
    return init_kwargs, inference_kwargs, expected_seq_count


111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
def _get_processed_num_crops(
    processor: Callable[[ProcessorInputs], ProcessorInputs],
    inference_kwargs: Optional[Dict[str, int]],
) -> int:
    """Run ``processor`` on empty token inputs and return its ``num_crops``.

    The inputs carry ``inference_kwargs`` as their ``mm_processor_kwargs``.
    The processed result must be of type ``"token"`` and contain
    ``mm_processor_kwargs``; the ``num_crops`` entry of those is returned.
    """
    raw_inputs = token_inputs(prompt_token_ids=[],
                              prompt="",
                              mm_processor_kwargs=inference_kwargs)
    result = processor(raw_inputs)

    assert "type" in result
    assert result["type"] == "token"
    assert "mm_processor_kwargs" in result
    echoed_kwargs = result["mm_processor_kwargs"]
    return echoed_kwargs["num_crops"]


126
127
128
129
130
131
132
133
134
135
136
137
@pytest.mark.parametrize("init_num_crops,inference_num_crops", [
    (None, None),
    (NUM_CROPS_OVERRIDE, None),
    (DEFAULT_NUM_CROPS, NUM_CROPS_OVERRIDE),
])
def test_input_processor_kwargs(use_processor_mock, init_num_crops,
                                inference_num_crops):
    """Ensure input processors can use processor kwargs."""
    dummy_registry = InputRegistry()

    init_kwargs, inference_kwargs, expected_seq_count = _get_num_crops_info(
        init_num_crops, inference_num_crops)

    # Init-time kwargs go into the model context; inference-time kwargs are
    # attached to the inputs handed to the processor.
    ctx = build_model_context(DUMMY_MODEL_ID, mm_processor_kwargs=init_kwargs)
    processor = dummy_registry.create_input_processor(ctx.model_config)
    num_crops_val = _get_processed_num_crops(processor, inference_kwargs)

    assert num_crops_val == expected_seq_count
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161


@pytest.mark.parametrize(
    "mm_processor_kwargs",
    [
        # Not part of the signature
        {
            "does_not_exist": 100
        },
        # Part of the signature, not keyword only
        {
            "ctx": "something bad"
        }
    ])
def test_processor_with_sad_kwarg_overrides(use_processor_mock,
                                            mm_processor_kwargs):
    """Ensure that input processors filter out invalid mm_processor_kwargs"""
    dummy_registry = InputRegistry()
    # Should filter out the init time kwargs
    ctx = build_model_context(DUMMY_MODEL_ID,
                              mm_processor_kwargs=mm_processor_kwargs)

    processor = dummy_registry.create_input_processor(ctx.model_config)
    # Should filter out the inference time kwargs
    num_crops_val = _get_processed_num_crops(processor, mm_processor_kwargs)
    # With every override filtered, the mock falls back to its default
    assert num_crops_val == DEFAULT_NUM_CROPS


### Test overrides for the dummy data
@pytest.mark.parametrize("num_crops", [None, NUM_CROPS_OVERRIDE])
def test_dummy_data_kwarg_overrides(use_dummy_data_mock, num_crops):
    """Ensure dummy data factories can use processor kwargs."""
    mm_processor_kwargs = None if num_crops is None else {
        "num_crops": num_crops
    }
    expected_seq_count = DEFAULT_NUM_CROPS if num_crops is None else num_crops
    dummy_registry = InputRegistry()
    ctx = build_model_context(DUMMY_MODEL_ID,
                              mm_processor_kwargs=mm_processor_kwargs)
    mm_registry = MultiModalRegistry()
    mm_registry.init_mm_limits_per_prompt(ctx.model_config)

    # NOTE: seq_len is thrown away here since this will leverage the
    # default dummy data factory that we have patched in, whose seq
    # len is solely dependent on the value of the mm_processor_kwargs.
    dummy_data = dummy_registry.dummy_data_for_profiling(
        ctx.model_config, seq_len=-1, mm_registry=mm_registry)
    assert len(dummy_data.seq_data.prompt_token_ids) == expected_seq_count
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217


@pytest.mark.parametrize(
    "mm_processor_kwargs",
    [
        # Not part of the signature
        {
            "does_not_exist": 100
        },
        # Part of the signature, not keyword only
        {
            "ctx": "something bad"
        }
    ])
def test_dummy_data_with_sad_kwarg_overrides(use_dummy_data_mock,
                                             mm_processor_kwargs):
    """Ensure the dummy data factory filters out invalid mm_processor_kwargs"""
    dummy_registry = InputRegistry()
    ctx = build_model_context(DUMMY_MODEL_ID,
                              mm_processor_kwargs=mm_processor_kwargs)
    mm_registry = MultiModalRegistry()
    mm_registry.init_mm_limits_per_prompt(ctx.model_config)

    # NOTE: seq_len is thrown away here since this will leverage the
    # default dummy data factory that we have patched in, whose seq
    # len is solely dependent on the value of the mm_processor_kwargs.
    dummy_data = dummy_registry.dummy_data_for_profiling(
        ctx.model_config, seq_len=-1, mm_registry=mm_registry)
    # With every override filtered, the factory falls back to its default
    assert len(dummy_data.seq_data.prompt_token_ids) == DEFAULT_NUM_CROPS
221
222
223
224
225
226
227
228
229
230
231
232


### Test overrides for the max token count per multimodal instance
@pytest.mark.parametrize("num_crops", [None, NUM_CROPS_OVERRIDE])
def test_max_tokens_kwarg_overrides(num_crops):
    """Ensure max token calcs can use processor kwargs."""
    mm_processor_kwargs = None if num_crops is None else {
        "num_crops": num_crops
    }
    expected_seq_count = DEFAULT_NUM_CROPS if num_crops is None else num_crops

    ctx = build_model_context(MULTIMODAL_MODEL_ID,
                              task="generate",
                              trust_remote_code=True,
                              mm_processor_kwargs=mm_processor_kwargs,
                              limit_mm_per_prompt={"image": 1})

    mm_registry = MultiModalRegistry()
    mm_registry.init_mm_limits_per_prompt(ctx.model_config)
    # Patch the image registry for phi3v with our callable that is compatible
    # with overrides, then ensure that calling the method correctly echoes
    # our num_crops value back from the mm_processor_kwargs.
    with patch.object(
            mm_registry._get_plugin("image"),
            "_max_mm_tokens",
            {mm_model_cls(): get_num_crops},
    ):
        max_multimodal_tokens = mm_registry.get_max_multimodal_tokens(
            ctx.model_config)

    assert expected_seq_count == max_multimodal_tokens


@pytest.mark.parametrize(
    "mm_processor_kwargs",
    [
        # Not part of the signature
        {
            "does_not_exist": 100
        },
        # Part of the signature, not keyword only
        {
            "ctx": "something bad"
        }
    ])
def test_max_tokens_with_sad_kwarg_overrides(mm_processor_kwargs):
    """Ensure that max token calcs filter out invalid mm_processor_kwargs"""
    ctx = build_model_context(MULTIMODAL_MODEL_ID,
                              task="generate",
                              trust_remote_code=True,
                              mm_processor_kwargs=mm_processor_kwargs,
                              limit_mm_per_prompt={"image": 1})

    mm_registry = MultiModalRegistry()
    mm_registry.init_mm_limits_per_prompt(ctx.model_config)

    # Similar to before, but since these kwargs get filtered,
    # we always get our default value back.
    with patch.object(
            mm_registry._get_plugin("image"),
            "_max_mm_tokens",
            {mm_model_cls(): get_num_crops},
    ):
        max_multimodal_tokens = mm_registry.get_max_multimodal_tokens(
            ctx.model_config)

    assert max_multimodal_tokens == DEFAULT_NUM_CROPS


### Test overrides for the mapper
@pytest.mark.parametrize("num_crops", [DEFAULT_NUM_CROPS, NUM_CROPS_OVERRIDE])
def test_default_mapper_with_processor_kwargs(image_assets, num_crops):
    """Ensure that the mapper processor kwargs can fall back to HF models."""
    # NOTE - we don't validate bad inputs for the default mapper, because it's
    # through the automodel interface in transformers, so we can't easily
    # inspect what kwargs are or are not allowed.
    ctx = build_model_context(MULTIMODAL_MODEL_ID,
                              task="generate",
                              trust_remote_code=True,
                              mm_processor_kwargs={"num_crops": num_crops},
                              limit_mm_per_prompt={"image": 1})

    mm_registry = MultiModalRegistry()
    mm_registry.init_mm_limits_per_prompt(ctx.model_config)

    image = image_assets[0].pil_image
    mm_inputs = {"image": image}

    mapped_inputs = mm_registry.map_input(ctx.model_config, mm_inputs)
    # Phi3v pixel vals should have shape: [batch, num_crops+1, 3, 336, 336]
    assert mapped_inputs["pixel_values"].shape[1] == num_crops + 1


314
315
316
317
318
319
320
@pytest.mark.parametrize("init_num_crops,inference_num_crops", [
    (None, None),
    (NUM_CROPS_OVERRIDE, None),
    (DEFAULT_NUM_CROPS, NUM_CROPS_OVERRIDE),
])
def test_custom_mapper_kwarg_overrides(image_assets, init_num_crops,
                                       inference_num_crops):
    """Ensure custom mappers can use processor kwargs."""
    init_kwargs, inference_kwargs, expected_seq_count = _get_num_crops_info(
        init_num_crops, inference_num_crops)

    ctx = build_model_context(MULTIMODAL_MODEL_ID,
                              task="generate",
                              trust_remote_code=True,
                              mm_processor_kwargs=init_kwargs,
                              limit_mm_per_prompt={"image": 1})

    mm_registry = MultiModalRegistry()
    mm_registry.init_mm_limits_per_prompt(ctx.model_config)
    image = image_assets[0].pil_image
    mm_inputs = {"image": image}

    # Register our mapper, which is compatible with overrides, for phi3v on
    # the image plugin, then ensure that calling the method correctly echoes
    # our num_crops value back from the mm_processor_kwargs.
    mm_registry._get_plugin("image").register_input_mapper(custom_mapper)(
        mm_model_cls())
    mapped_inputs = mm_registry.map_input(
        ctx.model_config, mm_inputs, mm_processor_kwargs=inference_kwargs)

    # custom_mapper returns pixel values with num_crops + 1 in dim 1
    assert mapped_inputs["pixel_values"].shape[1] == expected_seq_count + 1


@pytest.mark.parametrize(
    "mm_processor_kwargs",
    [
        # Not part of the signature
        {
            "does_not_exist": 100
        },
        # Part of the signature, not keyword only
        {
            "ctx": "something bad"
        }
    ])
def test_custom_mapper_with_sad_kwarg_overrides(image_assets,
                                                mm_processor_kwargs):
    """Ensure that custom mappers filter out invalid mm_processor_kwargs"""
    # Should filter out the init time kwargs
    ctx = build_model_context(MULTIMODAL_MODEL_ID,
                              task="generate",
                              trust_remote_code=True,
                              mm_processor_kwargs=mm_processor_kwargs,
                              limit_mm_per_prompt={"image": 1})

    mm_registry = MultiModalRegistry()
    mm_registry.init_mm_limits_per_prompt(ctx.model_config)
    image = image_assets[0].pil_image
    mm_inputs = {"image": image}

    # Register our mapper, which is compatible with overrides, for phi3v on
    # the image plugin; since these kwargs get filtered, the mapper should
    # fall back to its default num_crops.
    mm_registry._get_plugin("image").register_input_mapper(custom_mapper)(
        mm_model_cls())
    # Should filter out the inference time kwargs
    mapped_inputs = mm_registry.map_input(
        ctx.model_config, mm_inputs, mm_processor_kwargs=mm_processor_kwargs)

    assert mapped_inputs["pixel_values"].shape[1] == DEFAULT_NUM_CROPS + 1