# test_common.py
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from collections.abc import Set as AbstractSet
5
6
7
8
9
10
11
from functools import partial

import numpy as np
import pytest
from PIL import Image

from vllm.config import ModelConfig
12
13
14
15
16
17
from vllm.config.multimodal import (
    AudioDummyOptions,
    BaseDummyOptions,
    ImageDummyOptions,
    VideoDummyOptions,
)
18
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict
19
from vllm.multimodal.cache import MultiModalProcessorOnlyCache
20
from vllm.multimodal.inputs import MultiModalInputs, batched_tensors_equal
21
22
23
24
from vllm.multimodal.processing import (
    BaseMultiModalProcessor,
    InputProcessingContext,
)
25
from vllm.tokenizers import TokenizerLike, cached_tokenizer_from_config
26
from vllm.utils.mistral import is_mistral_tokenizer
27
28

from ....multimodal.utils import random_audio, random_image, random_video
29
30
31
32
33
from ...registry import (
    _MULTIMODAL_EXAMPLE_MODELS,
    _TRANSFORMERS_BACKEND_MODELS,
    HF_EXAMPLE_MODELS,
)
34
35


36
def add_video_metadata(mm_data: MultiModalDataDict) -> MultiModalDataDict:
    """
    Attach synthetic metadata to every video in ``mm_data``.

    The ``"video"`` entry is rewritten in place on the passed-in dict:
    a list of videos becomes a list of ``(frames, metadata)`` tuples, and a
    single video becomes one such tuple. Data without a ``"video"`` key is
    left untouched. The same dict object is returned.

    The metadata is synthetic: a fixed ``fps`` of 2.0, ``video_backend``
    ``"opencv"``, ``do_sample_frames=True`` and every frame index listed.
    """

    def _with_metadata(frames: np.ndarray):
        n = len(frames)
        metadata = {
            "total_num_frames": n,
            "fps": 2.0,
            "duration": n / 2.0,
            "video_backend": "opencv",
            "frames_indices": list(range(n)),
            "do_sample_frames": True,
        }
        return frames, metadata

    if "video" not in mm_data:
        return mm_data

    videos = mm_data["video"]
    mm_data["video"] = (
        [_with_metadata(v) for v in videos]
        if isinstance(videos, list)
        else _with_metadata(videos)
    )
    return mm_data


64
65
66
67
68
69
70
71
72
73
74
75
76
def glmasr_patch_mm_data(mm_data: MultiModalDataDict) -> MultiModalDataDict:
    """
    Adapt randomly generated multimodal data for the GLM-ASR model.

    GLM-ASR requires text and audio to match 1:1, so when ``mm_data["audio"]``
    is a list with more than one clip, only the first clip is kept. The dict
    is modified in place and also returned.
    """
    audios = mm_data.get("audio")
    if isinstance(audios, list) and len(audios) > 1:
        # Limit to single audio to match text requirement
        mm_data["audio"] = audios[:1]
    return mm_data


77
78
79
80
81
82
83
84
_IGNORE_MM_KEYS = {
    # In Ultravox, the audio_features can be different depending on padding
    # The slight difference should not be a problem though, since
    # attention_mask lets us ignore the difference.
    "ultravox": {"audio_features"},
}

# Per-model hooks, keyed by ``hf_config.model_type``, that get_text_token_prompts
# applies to the multimodal data before building the dummy prompt.
MM_DATA_PATCHES = dict(
    glmasr=glmasr_patch_mm_data,
)


def _iter_model_ids_to_test(model_arch_list: AbstractSet[str]):
    """
    Lazily yield model ids for each architecture in ``model_arch_list``.

    For every architecture the default model id comes first, followed by its
    extra model ids, skipping extras whose key contains ``"fp"``.
    """
    for arch in model_arch_list:
        info = HF_EXAMPLE_MODELS.get_hf_info(arch)
        yield info.default
        # Redundant to test quantized models
        yield from (
            model_id for kind, model_id in info.extras.items() if "fp" not in kind
        )


def _get_model_ids_to_test(model_arch_list: AbstractSet[str]):
    """Return the ids from ``_iter_model_ids_to_test`` as a list."""
    return [*_iter_model_ids_to_test(model_arch_list)]


def get_model_ids_to_test():
    """
    Return the model ids to run the processing-correctness test on.

    An architecture from ``_MULTIMODAL_EXAMPLE_MODELS`` is included only if
    none of its model ids (default or extras) also appears among the
    ``_TRANSFORMERS_BACKEND_MODELS`` example models.

    The selected architectures keep the iteration order of
    ``_MULTIMODAL_EXAMPLE_MODELS`` instead of being collected into a ``set``.
    The result feeds ``pytest.mark.parametrize``, and set iteration order over
    strings depends on the per-process hash seed. Test collection order would
    then differ between runs and between pytest-xdist workers.
    """
    transformers_model_ids = {
        model_id
        for info in _TRANSFORMERS_BACKEND_MODELS.values()
        for model_id in (info.default, *info.extras.values())
    }
    # Dict keys preserve insertion order and are still an AbstractSet.
    vllm_only_archs = {
        arch: None
        for arch, info in _MULTIMODAL_EXAMPLE_MODELS.items()
        if transformers_model_ids.isdisjoint((info.default, *info.extras.values()))
    }.keys()

    return _get_model_ids_to_test(vllm_only_archs)


def get_text_token_prompts(
    processor: BaseMultiModalProcessor,
    mm_data: MultiModalDataDict,
):
    """
    Build the dummy prompt matching ``mm_data`` in text and token-id form.

    ``mm_data`` may be modified in place. Video metadata is attached when the
    processor's data parser needs it, and any ``MM_DATA_PATCHES`` entry for
    the model type is applied. Callers that parse ``mm_data`` afterwards
    therefore see the patched data.

    Returns:
        ``(text_prompt, token_prompt)``; ``text_prompt`` is ``None`` when the
        dummy inputs already provide the prompt as a list of token ids.

    Raises:
        TypeError: if the dummy prompt is neither a ``str`` nor a ``list``.
    """
    dummy_inputs = processor.dummy_inputs
    info = processor.info
    tokenizer: TokenizerLike = info.get_tokenizer()
    model_config = info.ctx.model_config

    if info.data_parser.video_needs_metadata:
        mm_data = add_video_metadata(mm_data)

    patch_fn = MM_DATA_PATCHES.get(model_config.hf_config.model_type)
    if patch_fn is not None:
        mm_data = patch_fn(mm_data)

    mm_counts = {
        modality: len(items)
        for modality, items in info.parse_mm_data(mm_data).items()
    }

    extra_kwargs = {}
    if is_mistral_tokenizer(tokenizer):
        # Assume all Mistral models define this extra argument
        extra_kwargs["mm_data"] = mm_data

    inputs = dummy_inputs.get_dummy_processor_inputs(
        model_config.max_model_len,
        mm_counts,
        mm_options={},
        **extra_kwargs,
    )

    prompt = inputs.prompt
    if isinstance(prompt, str):
        encode_kwargs = info.get_default_tok_params().get_encode_kwargs()
        return prompt, tokenizer.encode(prompt, **encode_kwargs)
    if isinstance(prompt, list):
        # Already token ids: there is no text form to compare against.
        return None, prompt

    raise TypeError(type(prompt))


173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
def random_vision_chunk(
    rng: np.random.RandomState,
    min_wh: int,
    max_wh: int,
    min_frames: int,
    max_frames: int,
) -> dict:
    """
    Generate a random ``vision_chunk`` item.

    The number of frames is drawn from ``[min_frames, max_frames]`` and a side
    length ``wh`` from ``[min_wh, max_wh]`` (both inclusive).

    * One frame: returns ``{"type": "image", "image": ...}``, where the image
      comes from ``random_image(rng, wh, wh + 1)``.
    * Several frames: returns ``{"type": "video_chunk", "video_chunk": arr}``,
      where ``arr`` is a ``uint8`` array of shape ``(num_frames, wh, wh, 3)``.

    Raises:
        ValueError: if the drawn number of frames is smaller than 1.
    """
    num_frames = rng.randint(min_frames, max_frames + 1)
    if num_frames < 1:
        raise ValueError(
            f"A vision chunk needs at least one frame, got {num_frames}"
        )

    # A single side length for the whole chunk. All frames of a video must
    # share one spatial size, otherwise they cannot form one array. (Drawing a
    # size per frame made np.stack fail whenever the sizes differed.)
    wh = rng.randint(min_wh, max_wh + 1)
    if num_frames == 1:
        # Single image chunk
        image = random_image(rng, wh, wh + 1)
        return {"type": "image", "image": image}

    video_array = rng.randint(0, 256, size=(num_frames, wh, wh, 3), dtype=np.uint8)
    return {"type": "video_chunk", "video_chunk": video_array}


195
def _test_processing_correctness(
    model_id_or_arch: str,
    hit_rate: float,
    num_batches: int,
    simplify_rate: float,
):
    """
    Check that the cached and uncached processors agree on random inputs.

    ``model_id_or_arch`` is either an architecture supported by
    ``HF_EXAMPLE_MODELS`` (its default model id is used) or a model id.
    The model-info availability and Transformers-version checks run with
    ``on_fail="skip"``.

    For each of ``num_batches`` batches, every supported modality receives a
    random number of items, up to its limit. Each item is a fixed input from
    ``input_to_hit`` with probability ``hit_rate``, and a freshly generated
    random input otherwise. With probability ``simplify_rate`` the batch is
    then simplified: empty modalities are dropped and single-item lists are
    unwrapped.
    """
    if model_id_or_arch in HF_EXAMPLE_MODELS.get_supported_archs():
        # Use model architecture to get the default model id
        model_info = HF_EXAMPLE_MODELS.get_hf_info(model_id_or_arch)
        model_id = model_info.default
    else:
        model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id_or_arch)
        model_id = model_id_or_arch

    model_info.check_available_online(on_fail="skip")
    model_info.check_transformers_version(
        on_fail="skip",
        check_max_version=False,
        check_version_reason="vllm",
    )

    embeds_only = model_info.require_embed_inputs
    model_config = ModelConfig(
        model_id,
        tokenizer=model_info.tokenizer or model_id,
        tokenizer_mode=model_info.tokenizer_mode,
        revision=model_info.revision,
        trust_remote_code=model_info.trust_remote_code,
        hf_overrides=model_info.hf_overrides,
        skip_tokenizer_init=embeds_only,
        enable_prompt_embeds=embeds_only,
        enable_mm_embeds=embeds_only,
        enforce_eager=model_info.enforce_eager,
        dtype=model_info.dtype,
    )
    # Ensure that the cache can fit all of the data
    # (set after because ModelConfig would set it to 0 for encoder-decoder models)
    model_config.multimodal_config.mm_processor_cache_gb = 2048

    factories = MULTIMODAL_REGISTRY._get_model_cls(model_config)._processor_factory
    ctx = InputProcessingContext(
        model_config,
        tokenizer=cached_tokenizer_from_config(model_config),
    )
    cache = MultiModalProcessorOnlyCache(model_config)

    processing_info = factories.info(ctx)
    # Keep integer limits for local data generation (unbounded -> 3 items)
    limit_mm_per_prompt_ints = {
        modality: 3 if limit is None else limit
        for modality, limit in processing_info.get_supported_mm_limits().items()
    }

    # Assign normalized DummyOptions to the model config
    options_cls_by_modality = {
        "video": VideoDummyOptions,
        "image": ImageDummyOptions,
        "audio": AudioDummyOptions,
    }
    model_config.get_multimodal_config().limit_per_prompt = {
        modality: options_cls_by_modality.get(modality, BaseDummyOptions)(count=count)
        for modality, count in limit_mm_per_prompt_ints.items()
    }

    baseline_processor = factories.build_processor(ctx, cache=None)
    cached_processor = factories.build_processor(ctx, cache=cache)

    rng = np.random.RandomState(0)

    # GLM-ASR requires a minimum audio length of 70ms (1120 samples at 16 kHz)
    is_glmasr = model_config.hf_config.model_type == "glmasr"
    min_audio_len = 1120 if is_glmasr else 512

    input_to_hit = {
        "image": Image.new("RGB", size=(128, 128)),
        "video": np.zeros((4, 128, 128, 3), dtype=np.uint8),
        "audio": (np.zeros((min_audio_len,)), 16000),
        "vision_chunk": {"type": "image", "image": Image.new("RGB", size=(128, 128))},
    }
    input_factory = {
        "image": partial(random_image, rng, min_wh=128, max_wh=256),
        "video": partial(
            random_video, rng, min_frames=2, max_frames=16, min_wh=128, max_wh=256
        ),
        "audio": partial(
            random_audio,
            rng,
            min_len=min_audio_len,
            max_len=min_audio_len + 512,
            sr=16000,
        ),
        "vision_chunk": partial(
            random_vision_chunk, rng, min_wh=128, max_wh=256, min_frames=1, max_frames=1
        ),
    }

    for batch_idx in range(num_batches):
        mm_data = {}
        for modality, limit in limit_mm_per_prompt_ints.items():
            # Draw the item count first, then one hit/miss draw per item.
            num_items = rng.randint(limit + 1)
            mm_data[modality] = [
                input_to_hit[modality]
                if rng.rand() < hit_rate
                else input_factory[modality]()
                for _ in range(num_items)
            ]

        # Drop unnecessary keys and test single -> multi conversion
        if rng.rand() < simplify_rate:
            mm_data = {
                modality: items[0] if len(items) == 1 else items
                for modality, items in mm_data.items()
                if items
            }

        _test_processing_correctness_one(
            model_config,
            mm_data,
            baseline_processor,
            cached_processor,
            batch_idx,
        )


def _test_processing_correctness_one(
    model_config: ModelConfig,
    mm_data: MultiModalDataDict,
    baseline_processor: BaseMultiModalProcessor,
    cached_processor: BaseMultiModalProcessor,
    batch_idx: int,
):
    """
    Process one batch with the uncached and the cached processor and compare.

    The token-id prompt is run through both processors, and the outputs must
    match. When a text prompt is also available, it is run through both
    processors too. The text results must match each other and the matching
    token-id results. Keys in ``_IGNORE_MM_KEYS`` for the model type are
    excluded from the multimodal comparison.
    """
    model_type = model_config.hf_config.model_type

    # get_text_token_prompts may patch ``mm_data`` in place (e.g. attaching
    # video metadata), so ``mm_items`` is parsed after it and sees that data.
    text_prompt, token_prompt = get_text_token_prompts(baseline_processor, mm_data)
    mm_items = baseline_processor.info.parse_mm_data(mm_data)
    ignore_mm_keys = _IGNORE_MM_KEYS.get(model_type, set[str]())

    def process(processor, prompt):
        return processor(prompt, mm_items=mm_items, hf_processor_mm_kwargs={})

    baseline_tokenized_result = process(baseline_processor, token_prompt)
    cached_tokenized_result = process(cached_processor, token_prompt)
    _assert_inputs_equal(
        baseline_tokenized_result,
        cached_tokenized_result,
        ignore_mm_keys=ignore_mm_keys,
        msg=f"Failed ({batch_idx=}, {token_prompt=}, {mm_data=})",
    )

    if text_prompt is None:
        # Only a token-id prompt exists; nothing further to compare.
        return

    baseline_text_result = process(baseline_processor, text_prompt)
    cached_text_result = process(cached_processor, text_prompt)

    text_msg = f"Failed ({batch_idx=}, {text_prompt=}, {mm_data=})"
    cross_msg = f"Failed ({batch_idx=}, {text_prompt=}, {token_prompt=}, {mm_data=})"
    for lhs, rhs, msg in (
        (baseline_text_result, cached_text_result, text_msg),
        (baseline_text_result, baseline_tokenized_result, cross_msg),
        (cached_text_result, cached_tokenized_result, cross_msg),
    ):
        _assert_inputs_equal(lhs, rhs, ignore_mm_keys=ignore_mm_keys, msg=msg)

384

385
@pytest.mark.parametrize("model_id", get_model_ids_to_test())
@pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0])
@pytest.mark.parametrize("num_batches", [32])
@pytest.mark.parametrize("simplify_rate", [1.0])
def test_processing_correctness(
    model_id: str,
    hit_rate: float,
    num_batches: int,
    simplify_rate: float,
):
    """
    Cached and uncached processing must agree for each model id returned by
    ``get_model_ids_to_test``, except the known-broken ones skipped below.
    """
    qwen_vl_reason = (
        "Qwen-VL tokenizer requires downloading a font file from "
        "servers that often refuse connections in CI"
    )
    voxtral_realtime_reason = (
        "Voxtral Realtime doesn't make use of any place-holder "
        "tokens and hence cannot pass the processing "
        "correctness test as is. Let's revisit adapting this "
        "test once more realtime models exist."
    )
    skip_reasons = {
        "google/gemma-3n-E2B-it": "Fix later",
        "OpenGVLab/InternVL2-2B": "Fix later",
        "jinaai/jina-reranker-m0": "Fix later",
        "Qwen/Qwen-VL": qwen_vl_reason,
        "Qwen/Qwen-VL-Chat": qwen_vl_reason,
        "mistralai/Voxtral-Mini-4B-Realtime-2602": voxtral_realtime_reason,
    }
    reason = skip_reasons.get(model_id)
    if reason is not None:
        pytest.skip(reason)

    _test_processing_correctness(
        model_id,
        hit_rate=hit_rate,
        num_batches=num_batches,
        simplify_rate=simplify_rate,
    )
420
421


422
def _assert_inputs_equal(
    a: MultiModalInputs,
    b: MultiModalInputs,
    *,
    ignore_mm_keys: set[str] | None = None,
    msg: str = "",
):
    """
    Assert that two processed inputs are equal.

    All top-level entries except ``"prompt"`` and ``"mm_kwargs"`` must compare
    equal with ``==``. The multimodal kwargs are then compared with
    ``batched_tensors_equal`` after dropping every key in ``ignore_mm_keys``
    from the data returned by ``get_data()``. ``msg`` is attached to whichever
    assertion fails.
    """
    skipped_top_level = {"prompt", "mm_kwargs"}

    def rest(inputs):
        return {k: v for k, v in inputs.items() if k not in skipped_top_level}

    assert rest(a) == rest(b), msg

    a_data = a["mm_kwargs"].get_data()
    b_data = b["mm_kwargs"].get_data()
    for key in ignore_mm_keys or ():
        a_data.pop(key, None)
        b_data.pop(key, None)

    assert batched_tensors_equal(a_data, b_data), msg