# SPDX-License-Identifier: Apache-2.0
"""Tests for passing ``mm_processor_kwargs`` through vLLM's input pipeline.

Each place that consumes ``mm_processor_kwargs`` (input processors, dummy
data factories, max multimodal token calculations and input mappers) is
exercised with valid overrides and with invalid ones that must be filtered.
"""

from array import array
from typing import Callable, Dict, Mapping, Optional
from unittest.mock import patch

import pytest
import torch

from vllm.inputs import (DecoderOnlyInputs, DummyData, InputContext,
                         InputRegistry, ProcessorInputs, token_inputs)
from vllm.multimodal import MultiModalRegistry
from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData

from ..models.utils import build_model_context

# Used for fast tests where the model doesn't matter
DUMMY_MODEL_ID = "facebook/opt-125m"
# Used for tests that need a multimodal model
MULTIMODAL_MODEL_ID = "OpenGVLab/InternVL2-2B"

# For mm_processor_kwargs - we test overrides by defining mocks for each place
# it is used, and ensuring that we can pass an override value through the
# processor kwargs to receive the intended result for things like sequence
# length etc.
DEFAULT_MAX_DYNAMIC_PATCH = 6
MAX_DYNAMIC_PATCH_OVERRIDE = 4


# Mocks for all of the places that we use the mm_processor_kwargs
# to override values in different callables
@pytest.fixture
def use_processor_mock():
    """Patch the model input processor lookup to return a mock processor.

    ``InputRegistry._get_model_input_processor`` is replaced so that it
    returns ``custom_processor``. That processor echoes the
    ``max_dynamic_patch`` value it was called with back through the
    ``mm_processor_kwargs`` of its result.
    """

    def custom_processor(ctx: InputContext,
                         inputs: DecoderOnlyInputs,
                         *,
                         max_dynamic_patch=DEFAULT_MAX_DYNAMIC_PATCH):
        # For testing purposes, we don't worry about the prompt; only the
        # resolved keyword-only override is reported back to the caller.
        return token_inputs(
            prompt_token_ids=[],
            mm_processor_kwargs={"max_dynamic_patch": max_dynamic_patch})

    with patch("vllm.inputs.registry.InputRegistry._get_model_input_processor",
               return_value=custom_processor):
        yield


@pytest.fixture
def use_dummy_data_mock():
    """Patch InputRegistry's default dummy data factory with a mock factory.

    The replacement ignores ``seq_len`` and ``mm_counts``. It returns dummy
    data whose token sequence has exactly ``max_dynamic_patch`` zero tokens,
    so tests can read the resolved override back from that length.
    """

    def custom_dummy_data_factory(self,
                                  ctx: InputContext,
                                  seq_len: int,
                                  mm_counts: Mapping[str, int],
                                  *,
                                  max_dynamic_patch=DEFAULT_MAX_DYNAMIC_PATCH):
        seq_data = SequenceData(
            array(VLLM_TOKEN_ID_ARRAY_TYPE, [0] * max_dynamic_patch))
        return DummyData(seq_data, None)

    # The plain function is patched onto the class attribute, so it is bound
    # as a method and therefore receives ``self``.
    with patch(
            "vllm.inputs.registry.InputRegistry._default_dummy_data_factory",
            custom_dummy_data_factory):
        yield


# Lazy import to avoid CUDA reinitialization error
def mm_model_cls():
    """Return the InternVL model class used by the multimodal tests.

    The import is deferred until call time because importing the model module
    eagerly can trigger a CUDA reinitialization error.
    """
    from vllm.model_executor.models.internvl import InternVLChatModel

    return InternVLChatModel


# Callables whose signatures match the max token calc and the input mapper,
# plus an extra keyword-only ``max_dynamic_patch`` that can be overridden via
# mm_processor_kwargs.
def get_max_dynamic_patch(ctx, *,
                          max_dynamic_patch=DEFAULT_MAX_DYNAMIC_PATCH):
    """Max token calculator that echoes the resolved ``max_dynamic_patch``."""
    return max_dynamic_patch


def custom_mapper(ctx, data, *, max_dynamic_patch=DEFAULT_MAX_DYNAMIC_PATCH):
    """Input mapper returning zero pixel values shaped by the override.

    Dimension 1 of ``pixel_values`` is ``max_dynamic_patch + 1``, so tests can
    read the resolved override back from the mapped shape.
    """
    return {
        "pixel_values":
        torch.zeros(size=(1, max_dynamic_patch + 1, 3, 448, 448))
    }


### Tests for default processor logic & mm_processor_kwargs wrapping
def test_default_processor_is_a_noop():
    """Ensure that by default, there is no processor override."""
    dummy_registry = InputRegistry()
    ctx = build_model_context(DUMMY_MODEL_ID)
    processor = dummy_registry.create_input_processor(ctx.model_config)
    proc_inputs = token_inputs(prompt_token_ids=[], prompt="")
    proc_outputs = processor(inputs=proc_inputs)
    # The default processor must hand back the very same object, untouched.
    assert proc_inputs is proc_outputs


94
95
96
97
def _get_max_dynamic_patch_info(init_max_dynamic_patch: int,
                                inference_max_dynamic_patch: int):
    """Get the init / inference kwargs and expected max_dynamic_patch."""
    # If we have a value for max_dynamic_patch, pass the override value and make
98
99
    # sure we get that value as a return-value from out mock processor,
    # otherwise fall back to the default value
100
101
    init_kwargs = None if init_max_dynamic_patch is None else {
        "max_dynamic_patch": init_max_dynamic_patch
102
    }
103
104
    inference_kwargs = None if inference_max_dynamic_patch is None else {
        "max_dynamic_patch": inference_max_dynamic_patch
105
    }
106
107
108
109
    if inference_max_dynamic_patch is not None:
        expected_seq_count = inference_max_dynamic_patch
    elif init_max_dynamic_patch is not None:
        expected_seq_count = init_max_dynamic_patch
110
    else:
111
        expected_seq_count = DEFAULT_MAX_DYNAMIC_PATCH
112
113
114
    return init_kwargs, inference_kwargs, expected_seq_count


def _get_processed_max_dynamic_patch(
    processor: Callable[[ProcessorInputs], ProcessorInputs],
    inference_kwargs: Optional[Dict[str, int]],
) -> int:
    """Run ``processor`` on empty token inputs and return max_dynamic_patch.

    ``inference_kwargs`` is forwarded as the inputs' ``mm_processor_kwargs``.
    The returned value is read from the ``mm_processor_kwargs`` of the
    processed (token-type) inputs.
    """
    processed_inputs = processor(
        token_inputs(prompt_token_ids=[],
                     prompt="",
                     mm_processor_kwargs=inference_kwargs))

    assert "type" in processed_inputs
    assert processed_inputs["type"] == "token"
    assert "mm_processor_kwargs" in processed_inputs
    return processed_inputs["mm_processor_kwargs"]["max_dynamic_patch"]


@pytest.mark.parametrize(
    "init_max_dynamic_patch,inference_max_dynamic_patch", [
        (None, None),
        (MAX_DYNAMIC_PATCH_OVERRIDE, None),
        (DEFAULT_MAX_DYNAMIC_PATCH, MAX_DYNAMIC_PATCH_OVERRIDE),
    ])
def test_input_processor_kwargs(use_processor_mock, init_max_dynamic_patch,
                                inference_max_dynamic_patch):
    """Ensure input processors can use processor kwargs."""
    dummy_registry = InputRegistry()

    (init_kwargs, inference_kwargs,
     expected_seq_count) = _get_max_dynamic_patch_info(
         init_max_dynamic_patch, inference_max_dynamic_patch)

    ctx = build_model_context(DUMMY_MODEL_ID, mm_processor_kwargs=init_kwargs)
    processor = dummy_registry.create_input_processor(ctx.model_config)
    max_dynamic_patch_val = _get_processed_max_dynamic_patch(
        processor, inference_kwargs)

    assert max_dynamic_patch_val == expected_seq_count


@pytest.mark.parametrize(
    "mm_processor_kwargs",
    [
        # Not part of the signature
        {
            "does_not_exist": 100
        },
        # Part of the signature, not keyword only
        {
            "ctx": "something bad"
        }
    ])
def test_processor_with_sad_kwarg_overrides(use_processor_mock,
                                            mm_processor_kwargs):
    """Ensure that input processors filter out invalid mm_processor_kwargs"""
    dummy_registry = InputRegistry()
    # Should filter out the init time kwargs
    ctx = build_model_context(DUMMY_MODEL_ID,
                              mm_processor_kwargs=mm_processor_kwargs)

    processor = dummy_registry.create_input_processor(ctx.model_config)
    # Should filter out the inference time kwargs
    max_dynamic_patch_val = _get_processed_max_dynamic_patch(
        processor, mm_processor_kwargs)
    assert max_dynamic_patch_val == DEFAULT_MAX_DYNAMIC_PATCH


### Test overrides for the dummy data
@pytest.mark.parametrize("max_dynamic_patch",
                         [None, MAX_DYNAMIC_PATCH_OVERRIDE])
def test_dummy_data_kwarg_overrides(use_dummy_data_mock, max_dynamic_patch):
    """Ensure dummy data factories can use processor kwargs."""
    mm_processor_kwargs = None if max_dynamic_patch is None else {
        "max_dynamic_patch": max_dynamic_patch
    }
    expected_seq_count = (DEFAULT_MAX_DYNAMIC_PATCH
                          if max_dynamic_patch is None else max_dynamic_patch)

    dummy_registry = InputRegistry()
    ctx = build_model_context(DUMMY_MODEL_ID,
                              mm_processor_kwargs=mm_processor_kwargs)
    mm_registry = MultiModalRegistry()
    mm_registry.init_mm_limits_per_prompt(ctx.model_config)

    # NOTE: seq_len is thrown away here since this will leverage the
    # default dummy data factory that we have patched in, whose seq
    # len is solely dependent on the value of the mm_processor_kwargs.
    dummy_data = dummy_registry.dummy_data_for_profiling(
        ctx.model_config, seq_len=-1, mm_registry=mm_registry)
    assert len(dummy_data.seq_data.prompt_token_ids) == expected_seq_count


@pytest.mark.parametrize(
    "mm_processor_kwargs",
    [
        # Not part of the signature
        {
            "does_not_exist": 100
        },
        # Part of the signature, not keyword only
        {
            "ctx": "something bad"
        }
    ])
def test_dummy_data_with_sad_kwarg_overrides(use_dummy_data_mock,
                                             mm_processor_kwargs):
    """Ensure the dummy data factory filters out invalid mm_processor_kwargs"""
    dummy_registry = InputRegistry()
    ctx = build_model_context(DUMMY_MODEL_ID,
                              mm_processor_kwargs=mm_processor_kwargs)
    mm_registry = MultiModalRegistry()
    mm_registry.init_mm_limits_per_prompt(ctx.model_config)

    # NOTE: seq_len is thrown away here since this will leverage the
    # default dummy data factory that we have patched in, whose seq
    # len is solely dependent on the value of the mm_processor_kwargs.
    dummy_data = dummy_registry.dummy_data_for_profiling(
        ctx.model_config, seq_len=-1, mm_registry=mm_registry)
    assert len(
        dummy_data.seq_data.prompt_token_ids) == DEFAULT_MAX_DYNAMIC_PATCH


### Test overrides for the max token count per multimodal instance
@pytest.mark.parametrize("max_dynamic_patch",
                         [None, MAX_DYNAMIC_PATCH_OVERRIDE])
def test_max_tokens_kwarg_overrides(max_dynamic_patch):
    """Ensure max token calcs can use processor kwargs."""
    mm_processor_kwargs = None if max_dynamic_patch is None else {
        "max_dynamic_patch": max_dynamic_patch
    }
    expected_seq_count = (DEFAULT_MAX_DYNAMIC_PATCH
                          if max_dynamic_patch is None else max_dynamic_patch)

    ctx = build_model_context(MULTIMODAL_MODEL_ID,
                              task="generate",
                              trust_remote_code=True,
                              mm_processor_kwargs=mm_processor_kwargs,
                              limit_mm_per_prompt={"image": 1})

    mm_registry = MultiModalRegistry()
    mm_registry.init_mm_limits_per_prompt(ctx.model_config)
    # Patch the image plugin's max-token mapping for the InternVL model class
    # with our callable that is compatible with overrides, then ensure that
    # calling the method correctly echoes our max_dynamic_patch value back
    # from the mm_processor_kwargs.
    with patch.object(
            mm_registry._get_plugin("image"),
            "_max_mm_tokens",
            {mm_model_cls(): get_max_dynamic_patch},
    ):
        max_multimodal_tokens = mm_registry.get_max_multimodal_tokens(
            ctx.model_config)

    assert expected_seq_count == max_multimodal_tokens


@pytest.mark.parametrize(
    "mm_processor_kwargs",
    [
        # Not part of the signature
        {
            "does_not_exist": 100
        },
        # Part of the signature, not keyword only
        {
            "ctx": "something bad"
        }
    ])
def test_max_tokens_with_sad_kwarg_overrides(mm_processor_kwargs):
    """Ensure that max token calcs filters out invalid mm_processor_kwargs"""
    ctx = build_model_context(MULTIMODAL_MODEL_ID,
                              task="generate",
                              trust_remote_code=True,
                              mm_processor_kwargs=mm_processor_kwargs,
                              limit_mm_per_prompt={"image": 1})

    mm_registry = MultiModalRegistry()
    mm_registry.init_mm_limits_per_prompt(ctx.model_config)

    # Similar to before, but since these kwargs get filtered,
    # we always get our default value back.
    with patch.object(
            mm_registry._get_plugin("image"),
            "_max_mm_tokens",
            {mm_model_cls(): get_max_dynamic_patch},
    ):
        max_multimodal_tokens = mm_registry.get_max_multimodal_tokens(
            ctx.model_config)

    assert max_multimodal_tokens == DEFAULT_MAX_DYNAMIC_PATCH


### Test overrides for the mapper
@pytest.mark.parametrize(
    "max_dynamic_patch",
    [DEFAULT_MAX_DYNAMIC_PATCH, MAX_DYNAMIC_PATCH_OVERRIDE])
def test_default_mapper_with_processor_kwargs(image_assets, max_dynamic_patch):
    """Ensure that the mapper processor kwargs can fall back to HF models."""
    # NOTE - we don't validate bad inputs for the default mapper, because it's
    # through the automodel interface in transformers, so we can't easily
    # inspect what kwargs are or are not allowed.
    ctx = build_model_context(
        MULTIMODAL_MODEL_ID,
        task="generate",
        trust_remote_code=True,
        mm_processor_kwargs={"max_dynamic_patch": max_dynamic_patch},
        limit_mm_per_prompt={"image": 1})

    mm_registry = MultiModalRegistry()
    mm_registry.init_mm_limits_per_prompt(ctx.model_config)

    image = image_assets[0].pil_image
    mm_inputs = {"image": image}

    mapped_inputs = mm_registry.map_input(ctx.model_config, mm_inputs)

    # pixel vals should have shape: [batch, max_dynamic_patch+1, ...]
    assert mapped_inputs["pixel_values"].shape[1] == max_dynamic_patch + 1


@pytest.mark.parametrize(
    "init_max_dynamic_patch,inference_max_dynamic_patch", [
        (None, None),
        (MAX_DYNAMIC_PATCH_OVERRIDE, None),
        (DEFAULT_MAX_DYNAMIC_PATCH, MAX_DYNAMIC_PATCH_OVERRIDE),
    ])
def test_custom_mapper_kwarg_overrides(image_assets, init_max_dynamic_patch,
                                       inference_max_dynamic_patch):
    """Ensure custom mappers can use processor kwargs."""
    (init_kwargs, inference_kwargs,
     expected_seq_count) = _get_max_dynamic_patch_info(
         init_max_dynamic_patch, inference_max_dynamic_patch)

    ctx = build_model_context(MULTIMODAL_MODEL_ID,
                              task="generate",
                              trust_remote_code=True,
                              mm_processor_kwargs=init_kwargs,
                              limit_mm_per_prompt={"image": 1})

    mm_registry = MultiModalRegistry()
    mm_registry.init_mm_limits_per_prompt(ctx.model_config)
    image = image_assets[0].pil_image
    mm_inputs = {"image": image}

    # Register our override-compatible mapper for the InternVL model class,
    # then ensure the mapped pixel values reflect the max_dynamic_patch value
    # resolved from the init / inference mm_processor_kwargs.
    # NOTE(review): this registration is never undone; if the image plugin is
    # shared across MultiModalRegistry instances it may leak into later
    # tests - confirm.
    mm_registry._get_plugin("image").register_input_mapper(custom_mapper)(
        mm_model_cls())
    mapped_inputs = mm_registry.map_input(ctx.model_config, mm_inputs,
                                          inference_kwargs)

    # The custom mapper emits max_dynamic_patch + 1 entries along dim 1.
    assert mapped_inputs["pixel_values"].shape[1] == expected_seq_count + 1


@pytest.mark.parametrize(
    "mm_processor_kwargs",
    [
        # Not part of the signature
        {
            "does_not_exist": 100
        },
        # Part of the signature, not keyword only
        {
            "ctx": "something bad"
        }
    ])
def test_custom_mapper_with_sad_kwarg_overrides(image_assets,
                                                mm_processor_kwargs):
    """Ensure that custom mappers filters out invalid mm_processor_kwargs"""
    # Should filter out the init time kwargs
    ctx = build_model_context(MULTIMODAL_MODEL_ID,
                              task="generate",
                              trust_remote_code=True,
                              mm_processor_kwargs=mm_processor_kwargs,
                              limit_mm_per_prompt={"image": 1})

    mm_registry = MultiModalRegistry()
    mm_registry.init_mm_limits_per_prompt(ctx.model_config)
    image = image_assets[0].pil_image
    mm_inputs = {"image": image}

    # Register our override-compatible mapper for the InternVL model class;
    # since the invalid kwargs get filtered, the mapped pixel values should
    # reflect the default max_dynamic_patch.
    # NOTE(review): this registration is never undone; if the image plugin is
    # shared across MultiModalRegistry instances it may leak into later
    # tests - confirm.
    mm_registry._get_plugin("image").register_input_mapper(custom_mapper)(
        mm_model_cls())
    # Should filter out the inference time kwargs
    mapped_inputs = mm_registry.map_input(
        ctx.model_config, mm_inputs, mm_processor_kwargs=mm_processor_kwargs)

    assert mapped_inputs["pixel_values"].shape[1] == (
        DEFAULT_MAX_DYNAMIC_PATCH + 1)