test_common.py 14.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from collections.abc import Set as AbstractSet
5
6
7
8
9
10
11
from functools import partial

import numpy as np
import pytest
from PIL import Image

from vllm.config import ModelConfig
12
13
14
15
16
17
from vllm.config.multimodal import (
    AudioDummyOptions,
    BaseDummyOptions,
    ImageDummyOptions,
    VideoDummyOptions,
)
18
19
from vllm.inputs import MultiModalDataDict, MultiModalInput
from vllm.multimodal import MULTIMODAL_REGISTRY
20
from vllm.multimodal.cache import MultiModalProcessorOnlyCache
21
22
from vllm.multimodal.inputs import batched_tensors_equal
from vllm.multimodal.processing import BaseMultiModalProcessor, InputProcessingContext
23
from vllm.tokenizers import TokenizerLike, cached_tokenizer_from_config
24
from vllm.utils.mistral import is_mistral_tokenizer
25
26

from ....multimodal.utils import random_audio, random_image, random_video
27
28
29
30
31
from ...registry import (
    _MULTIMODAL_EXAMPLE_MODELS,
    _TRANSFORMERS_BACKEND_MODELS,
    HF_EXAMPLE_MODELS,
)
32
33


34
def add_video_metadata(mm_data: MultiModalDataDict) -> MultiModalDataDict:
    """
    Attach synthetic metadata to every video in ``mm_data``.

    Each video is replaced by a ``(frames, metadata)`` tuple. The metadata
    describes the clip as sampled at 2 fps with one index per frame, decoded
    with the ``"opencv"`` backend and with ``do_sample_frames`` enabled.
    A list of videos becomes a list of such tuples, and a single video becomes
    a single tuple. Data without a ``"video"`` key is returned unchanged.

    Note:
        ``mm_data`` is modified in place and also returned. Code that reads
        ``mm_data`` again afterwards sees the metadata-augmented videos.
    """
    fps = 2.0

    def create_metadata(frames: np.ndarray):
        num_frames = len(frames)
        return {
            "total_num_frames": num_frames,
            "fps": fps,
            "duration": num_frames / fps,
            "video_backend": "opencv",
            "frames_indices": list(range(num_frames)),
            "do_sample_frames": True,
        }

    # Ensure video metadata is included
    if "video" in mm_data:
        video = mm_data["video"]
        if isinstance(video, list):
            # multiple videos
            mm_data["video"] = [(vid, create_metadata(vid)) for vid in video]
        else:
            # single video
            mm_data["video"] = (video, create_metadata(video))
    return mm_data


62
63
64
65
66
67
68
69
70
71
72
73
74
def glmasr_patch_mm_data(mm_data: MultiModalDataDict) -> MultiModalDataDict:
    """
    Trim GLM-ASR audio inputs down to a single clip.

    GLM-ASR pairs text and audio one-to-one. When ``mm_data`` holds a list
    with more than one audio clip, only the first clip is kept. The dict is
    updated in place and then returned.
    """
    clips = mm_data.get("audio")
    if isinstance(clips, list) and len(clips) > 1:
        # Keep just the first clip so it lines up with the single text prompt
        mm_data["audio"] = clips[:1]
    return mm_data


75
76
77
78
79
80
81
82
# Keys of the processed multimodal kwargs that are excluded when comparing
# baseline and cached processor outputs, keyed by ``hf_config.model_type``.
_IGNORE_MM_KEYS: dict[str, set[str]] = {
    # In Ultravox, the audio_features can be different depending on padding
    # The slight difference should not be a problem though, since
    # attention_mask lets us ignore the difference.
    "ultravox": {"audio_features"},
}

# Hooks applied to the generated multimodal data before prompts are built
# from it, keyed by ``hf_config.model_type``. Each hook takes and returns the
# multimodal data dict.
MM_DATA_PATCHES = {
    "glmasr": glmasr_patch_mm_data,
}


def _iter_model_ids_to_test(model_arch_list: AbstractSet[str]):
    """
    Yield the model ids to exercise for each architecture in ``model_arch_list``.

    For every architecture, this yields its default model id first, followed
    by each extra model id whose extra type does not contain ``"fp"``.
    """
    for arch in model_arch_list:
        info = HF_EXAMPLE_MODELS.get_hf_info(arch)
        yield info.default
        # Extra types mentioning "fp" are quantized variants; testing them
        # again would be redundant.
        yield from (
            model_id
            for variant, model_id in info.extras.items()
            if "fp" not in variant
        )


def _get_model_ids_to_test(model_arch_list: AbstractSet[str]):
    """Collect ``_iter_model_ids_to_test(model_arch_list)`` into a list."""
    return [*_iter_model_ids_to_test(model_arch_list)]


def get_model_ids_to_test():
    """
    Return model ids for multimodal architectures not covered elsewhere.

    An architecture from ``_MULTIMODAL_EXAMPLE_MODELS`` is selected only if
    none of its model ids (default or extras) also appear in
    ``_TRANSFORMERS_BACKEND_MODELS``.
    """

    def all_ids(info):
        return (info.default, *info.extras.values())

    transformers_model_ids = {
        model_id
        for info in _TRANSFORMERS_BACKEND_MODELS.values()
        for model_id in all_ids(info)
    }
    vllm_only_archs = {
        arch
        for arch, info in _MULTIMODAL_EXAMPLE_MODELS.items()
        if transformers_model_ids.isdisjoint(all_ids(info))
    }
    return _get_model_ids_to_test(vllm_only_archs)


def get_text_token_prompts(
    processor: BaseMultiModalProcessor,
    mm_data: MultiModalDataDict,
) -> tuple[str | None, list[int]]:
    """
    Build a text prompt and a token prompt that match ``mm_data``.

    The prompt comes from the processor's dummy-input builder. It is asked for
    as many items per modality as ``mm_data`` contains after parsing.

    Before counting, ``mm_data`` is prepared in two ways:
    video metadata is attached via ``add_video_metadata`` when the processor's
    data parser sets ``video_needs_metadata``, and the ``MM_DATA_PATCHES``
    hook for the model type is applied. Both ``add_video_metadata`` and the
    GLM-ASR patch modify the dict in place. A caller that parses ``mm_data``
    afterwards therefore sees the same data this prompt was built for.

    Args:
        processor: Processor whose dummy inputs, tokenizer and config are used.
        mm_data: Multimodal data the prompt must contain placeholders for.

    Returns:
        ``(text_prompt, token_prompt)``. ``text_prompt`` is ``None`` when the
        dummy builder returned token ids directly. Otherwise ``token_prompt``
        is the tokenizer's encoding of ``text_prompt``.

    Raises:
        TypeError: If the dummy prompt is neither a ``list`` nor a ``str``.
    """
    dummy_inputs = processor.dummy_inputs
    tokenizer: TokenizerLike = processor.info.get_tokenizer()
    model_config = processor.info.ctx.model_config

    if processor.info.data_parser.video_needs_metadata:
        mm_data = add_video_metadata(mm_data)

    model_type = model_config.hf_config.model_type
    if model_type in MM_DATA_PATCHES:
        mm_data = MM_DATA_PATCHES[model_type](mm_data)

    parsed_data = processor.info.parse_mm_data(mm_data)
    mm_counts = {k: len(vs) for k, vs in parsed_data.items()}

    if is_mistral_tokenizer(tokenizer):
        inputs = dummy_inputs.get_dummy_processor_inputs(
            model_config.max_model_len,
            mm_counts,
            mm_options={},
            # Assume all Mistral models define this extra argument
            mm_data=mm_data,  # type: ignore[call-arg]
        )
    else:
        inputs = dummy_inputs.get_dummy_processor_inputs(
            model_config.max_model_len,
            mm_counts,
            mm_options={},
        )

    text_prompt: str | None
    token_prompt: list[int]
    if isinstance(inputs.prompt, list):
        # The dummy builder already produced token ids
        text_prompt = None
        token_prompt = inputs.prompt
    elif isinstance(inputs.prompt, str):
        text_prompt = inputs.prompt
        token_prompt = tokenizer.encode(
            text_prompt,
            **processor.info.get_default_tok_params().get_encode_kwargs(),
        )
    else:
        raise TypeError(type(inputs.prompt))

    return text_prompt, token_prompt


171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
def random_vision_chunk(
    rng: np.random.RandomState,
    min_wh: int,
    max_wh: int,
    min_frames: int,
    max_frames: int,
) -> dict:
    """
    Generate a random "vision chunk": either a single image or a short video.

    The frame count is drawn uniformly from ``[min_frames, max_frames]``. The
    square side length ``wh`` is drawn uniformly from ``[min_wh, max_wh]``.
    Both ranges are inclusive.

    Returns:
        ``{"type": "image", "image": ...}`` holding the output of
        ``random_image`` when one frame is drawn. Otherwise
        ``{"type": "video_chunk", "video_chunk": ...}`` holding a ``uint8``
        array of shape ``(num_frames, wh, wh, 3)``.
    """
    num_frames = rng.randint(min_frames, max_frames + 1)
    # Draw one side length per chunk: every frame of a video must share a
    # shape to be stacked into a single array. The draw happens before the
    # branch so the single-image path consumes the RNG exactly as before.
    wh = rng.randint(min_wh, max_wh + 1)
    if num_frames == 1:
        # Single image chunk
        image = random_image(rng, wh, wh + 1)
        return {"type": "image", "image": image}
    video_array = rng.randint(0, 256, size=(num_frames, wh, wh, 3), dtype=np.uint8)
    return {"type": "video_chunk", "video_chunk": video_array}


193
def _test_processing_correctness(
    model_id_or_arch: str,
    hit_rate: float,
    num_batches: int,
    simplify_rate: float,
):
    """
    Check that a cached processor agrees with an uncached one on random inputs.

    Two processors are built for the model: one without a cache and one with
    a ``MultiModalProcessorOnlyCache``. ``num_batches`` batches of random
    multimodal data are then generated, and each batch is checked with
    ``_test_processing_correctness_one``. The model-info checks run with
    ``on_fail="skip"``.

    Args:
        model_id_or_arch: A model id, or an architecture name known to
            ``HF_EXAMPLE_MODELS``. An architecture name is resolved to that
            architecture's default model id.
        hit_rate: Probability that each generated item is the fixed input
            from ``input_to_hit`` rather than a freshly randomized one.
        num_batches: Number of random batches to check.
        simplify_rate: Probability that a batch drops empty modalities and
            unwraps single-item lists into bare items.
    """
    if model_id_or_arch in HF_EXAMPLE_MODELS.get_supported_archs():
        # Use model architecture to get the default model id
        model_info = HF_EXAMPLE_MODELS.get_hf_info(model_id_or_arch)
        model_id = model_info.default
    else:
        model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id_or_arch)
        model_id = model_id_or_arch

    model_info.check_available_online(on_fail="skip")
    model_info.check_transformers_version(
        on_fail="skip",
        check_max_version=False,
        check_version_reason="vllm",
    )

    model_config = ModelConfig(
        model_id,
        tokenizer=model_info.tokenizer or model_id,
        tokenizer_mode=model_info.tokenizer_mode,
        revision=model_info.revision,
        trust_remote_code=model_info.trust_remote_code,
        hf_overrides=model_info.hf_overrides,
        skip_tokenizer_init=model_info.require_embed_inputs,
        enable_prompt_embeds=model_info.require_embed_inputs,
        enable_mm_embeds=model_info.require_embed_inputs,
        enforce_eager=model_info.enforce_eager,
        dtype=model_info.dtype,
    )
    # Ensure that the cache can fit all of the data
    # (set after because ModelConfig would set it to 0 for encoder-decoder models)
    model_config.multimodal_config.mm_processor_cache_gb = 2048

    model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)
    factories = model_cls._processor_factory
    ctx = InputProcessingContext(
        model_config,
        tokenizer=cached_tokenizer_from_config(model_config),
    )
    cache = MultiModalProcessorOnlyCache(model_config)

    processing_info = factories.info(ctx)
    supported_mm_limits = processing_info.get_supported_mm_limits()
    # Keep integer limits for local data generation; modalities without a
    # limit (None) get at most 3 items per batch.
    limit_mm_per_prompt_ints = {
        modality: 3 if limit is None else limit
        for modality, limit in supported_mm_limits.items()
    }

    # Assign normalized DummyOptions to the model config. Modalities without
    # a dedicated options class fall back to BaseDummyOptions.
    dummy_options_cls: dict[str, type[BaseDummyOptions]] = {
        "video": VideoDummyOptions,
        "image": ImageDummyOptions,
        "audio": AudioDummyOptions,
    }
    model_config.get_multimodal_config().limit_per_prompt = {
        modality: dummy_options_cls.get(modality, BaseDummyOptions)(count=count)
        for modality, count in limit_mm_per_prompt_ints.items()
    }

    baseline_processor = factories.build_processor(ctx, cache=None)
    cached_processor = factories.build_processor(ctx, cache=cache)

    rng = np.random.RandomState(0)

    # GLM-ASR requires a minimum audio length of 70ms
    min_audio_len = 512 if model_config.hf_config.model_type != "glmasr" else 1120

    # Fixed inputs, chosen with probability ``hit_rate``; reusing the same
    # objects across batches lets the cached processor see repeated items.
    input_to_hit = {
        "image": Image.new("RGB", size=(128, 128)),
        "video": np.zeros((4, 128, 128, 3), dtype=np.uint8),
        "audio": (np.zeros((min_audio_len,)), 16000),
        "vision_chunk": {"type": "image", "image": Image.new("RGB", size=(128, 128))},
    }
    # Generators for fresh random inputs
    input_factory = {
        "image": partial(random_image, rng, min_wh=128, max_wh=256),
        "video": partial(
            random_video, rng, min_frames=2, max_frames=16, min_wh=128, max_wh=256
        ),
        "audio": partial(
            random_audio,
            rng,
            min_len=min_audio_len,
            max_len=min_audio_len + 512,
            sr=16000,
        ),
        "vision_chunk": partial(
            random_vision_chunk, rng, min_wh=128, max_wh=256, min_frames=1, max_frames=1
        ),
    }

    for batch_idx in range(num_batches):
        # Between 0 and ``limit`` items per modality (inclusive)
        mm_data = {
            k: [
                (input_to_hit[k] if rng.rand() < hit_rate else input_factory[k]())
                for _ in range(rng.randint(limit + 1))
            ]
            for k, limit in limit_mm_per_prompt_ints.items()
        }

        # Drop unnecessary keys and test single -> multi conversion
        if rng.rand() < simplify_rate:
            for k in list(mm_data.keys()):
                if not mm_data[k]:
                    del mm_data[k]
                elif len(mm_data[k]) == 1:
                    mm_data[k] = mm_data[k][0]

        _test_processing_correctness_one(
            model_config,
            mm_data,
            baseline_processor,
            cached_processor,
            batch_idx,
        )


def _test_processing_correctness_one(
    model_config: ModelConfig,
    mm_data: MultiModalDataDict,
    baseline_processor: BaseMultiModalProcessor,
    cached_processor: BaseMultiModalProcessor,
    batch_idx: int,
):
    """
    Compare baseline and cached processor outputs for one batch of ``mm_data``.

    Both processors are run on the token prompt. When a text prompt is
    available, both are also run on it. Baseline and cached results must
    match for each prompt form, and text results must match token results.
    Keys listed in ``_IGNORE_MM_KEYS`` for the model type are not compared.

    Args:
        model_config: Config of the model under test (used for its model type).
        mm_data: Multimodal data for this batch.
        baseline_processor: Processor built without a cache.
        cached_processor: Processor built with a cache.
        batch_idx: Index of the batch, included in failure messages.
    """
    model_type = model_config.hf_config.model_type
    ignore_mm_keys = _IGNORE_MM_KEYS.get(model_type, set[str]())

    # get_text_token_prompts may patch ``mm_data`` in place, so ``mm_data``
    # is parsed only after it has run.
    text_prompt, token_prompt = get_text_token_prompts(baseline_processor, mm_data)
    mm_items = baseline_processor.info.parse_mm_data(mm_data)

    def run(processor: BaseMultiModalProcessor, prompt: str | list[int]):
        return processor(prompt, mm_items=mm_items, hf_processor_mm_kwargs={})

    baseline_tokenized_result = run(baseline_processor, token_prompt)
    cached_tokenized_result = run(cached_processor, token_prompt)

    _assert_inputs_equal(
        baseline_tokenized_result,
        cached_tokenized_result,
        ignore_mm_keys=ignore_mm_keys,
        msg=f"Failed ({batch_idx=}, {token_prompt=}, {mm_data=})",
    )

    if text_prompt is None:
        # The dummy builder produced token ids only; nothing more to compare
        return

    baseline_text_result = run(baseline_processor, text_prompt)
    cached_text_result = run(cached_processor, text_prompt)

    _assert_inputs_equal(
        baseline_text_result,
        cached_text_result,
        ignore_mm_keys=ignore_mm_keys,
        msg=f"Failed ({batch_idx=}, {text_prompt=}, {mm_data=})",
    )

    _assert_inputs_equal(
        baseline_text_result,
        baseline_tokenized_result,
        ignore_mm_keys=ignore_mm_keys,
        msg=f"Failed ({batch_idx=}, {text_prompt=}, {token_prompt=}, {mm_data=})",
    )

    _assert_inputs_equal(
        cached_text_result,
        cached_tokenized_result,
        ignore_mm_keys=ignore_mm_keys,
        msg=f"Failed ({batch_idx=}, {text_prompt=}, {token_prompt=}, {mm_data=})",
    )

382

383
@pytest.mark.parametrize("model_id", get_model_ids_to_test())
@pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0])
@pytest.mark.parametrize("num_batches", [32])
@pytest.mark.parametrize("simplify_rate", [1.0])
def test_processing_correctness(
    model_id: str,
    hit_rate: float,
    num_batches: int,
    simplify_rate: float,
):
    """
    Check cached vs. uncached multimodal processing for each model id.

    Models listed in ``skip_reasons`` are skipped with the given reason.
    All others are checked by ``_test_processing_correctness``.
    """
    qwen_vl_reason = (
        "Qwen-VL tokenizer requires downloading a font file from "
        "servers that often refuse connections in CI"
    )
    skip_reasons = {
        "google/gemma-3n-E2B-it": "Fix later",
        "OpenGVLab/InternVL2-2B": "Fix later",
        "jinaai/jina-reranker-m0": "Fix later",
        "Qwen/Qwen-VL": qwen_vl_reason,
        "Qwen/Qwen-VL-Chat": qwen_vl_reason,
        "mistralai/Voxtral-Mini-4B-Realtime-2602": (
            "Voxtral Realtime doesn't make use of any place-holder "
            "tokens and hence cannot pass the processing "
            "correctness test as is. Let's revisit adapting this "
            "test once more realtime models exist."
        ),
    }
    reason = skip_reasons.get(model_id)
    if reason is not None:
        pytest.skip(reason)

    _test_processing_correctness(
        model_id,
        hit_rate=hit_rate,
        num_batches=num_batches,
        simplify_rate=simplify_rate,
    )
418
419


420
def _assert_inputs_equal(
    a: MultiModalInput,
    b: MultiModalInput,
    *,
    ignore_mm_keys: set[str] | None = None,
    msg: str = "",
):
    """
    Assert that two processor results match.

    Every top-level entry except ``"prompt"`` and ``"mm_kwargs"`` must compare
    equal with ``==``. The ``"prompt"`` entry is never compared; presumably
    this is because results built from a text prompt are compared against
    results built from token ids. The data of ``"mm_kwargs"`` is compared
    with ``batched_tensors_equal`` after removing every key in
    ``ignore_mm_keys``.

    Args:
        a: First processor result.
        b: Second processor result.
        ignore_mm_keys: Keys of the ``mm_kwargs`` data to leave out of the
            comparison. ``None`` means no keys are ignored.
        msg: Message attached to a failing assertion.

    Raises:
        AssertionError: If the results differ.
    """
    if ignore_mm_keys is None:
        ignore_mm_keys = set()

    ignore_prompt_keys = ("prompt", "mm_kwargs")
    a_rest = {k: v for k, v in a.items() if k not in ignore_prompt_keys}
    b_rest = {k: v for k, v in b.items() if k not in ignore_prompt_keys}

    assert a_rest == b_rest, msg

    a_data = a["mm_kwargs"].get_data()
    b_data = b["mm_kwargs"].get_data()

    for key in ignore_mm_keys:
        # Missing keys are fine; pop with a default never raises
        a_data.pop(key, None)
        b_data.pop(key, None)

    assert batched_tensors_equal(a_data, b_data), msg