# SPDX-License-Identifier: Apache-2.0

import copy
import os
from functools import partial
from typing import Optional, Union

import numpy as np
import pytest
from mistral_common.protocol.instruct.messages import (ImageChunk, TextChunk,
                                                       UserMessage)
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from PIL import Image
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast

from vllm.config import ModelConfig
from vllm.inputs import InputProcessingContext
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict
from vllm.multimodal.inputs import MultiModalInputs
from vllm.multimodal.processing import BaseMultiModalProcessor, ProcessingCache
from vllm.transformers_utils.tokenizer import (MistralTokenizer,
                                               cached_tokenizer_from_config)

from ....multimodal.utils import random_audio, random_image, random_video
from ...registry import HF_EXAMPLE_MODELS
from ....utils import models_path_prefix


def _test_processing_correctness(
    model_id: str,
    hit_rate: float,
    num_batches: int,
    simplify_rate: float,
    ignore_mm_keys: Optional[set[str]] = None,
):
    """Check that cached and uncached multimodal processing agree.

    Builds one processor without a cache and one with a cache for
    ``model_id``, generates ``num_batches`` random multimodal inputs, and
    hands each batch to the Mistral- or HF-specific checker depending on
    the tokenizer type.

    Args:
        model_id: Model identifier looked up in ``HF_EXAMPLE_MODELS`` and
            passed to ``ModelConfig``.
        hit_rate: Probability that a generated item is the fixed
            per-modality input rather than a freshly randomized one.
        num_batches: Number of random batches to process.
        simplify_rate: Probability, per batch, of dropping empty modalities
            and unwrapping single-item lists before processing.
        ignore_mm_keys: Forwarded to the per-tokenizer checker, which
            excludes these ``mm_kwargs`` keys from the comparison.
    """
    model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id)
    model_info.check_available_online(on_fail="skip")
    model_info.check_transformers_version(on_fail="skip")

    model_config = ModelConfig(
        model_id,
        task="auto",
        tokenizer=model_info.tokenizer or model_id,
        tokenizer_mode=model_info.tokenizer_mode,
        trust_remote_code=model_info.trust_remote_code,
        seed=0,
        dtype="float16",
        revision=None,
        hf_overrides=model_info.hf_overrides,
    )

    model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)
    factories = MULTIMODAL_REGISTRY._processor_factories[model_cls]
    ctx = InputProcessingContext(
        model_config,
        tokenizer=cached_tokenizer_from_config(model_config),
    )
    # Ensure that it can fit all of the data
    cache = ProcessingCache(capacity_gb=2048)

    processing_info = factories.info(ctx)
    supported_mm_limits = processing_info.get_supported_mm_limits()
    # Modalities that report no limit are capped at 3 items per prompt
    limit_mm_per_prompt = {
        modality: 3 if limit is None else limit
        for modality, limit in supported_mm_limits.items()
    }

    model_config.get_multimodal_config().limit_per_prompt = limit_mm_per_prompt

    baseline_processor = factories.build_processor(ctx, cache=None)
    cached_processor = factories.build_processor(ctx, cache=cache)
    dummy_inputs = baseline_processor.dummy_inputs
    tokenizer = baseline_processor.info.get_tokenizer()

    # Fixed seed so the generated batches are reproducible across runs
    rng = np.random.RandomState(0)

    # Fixed items that are reused across batches (selected with
    # probability `hit_rate`)
    input_to_hit = {
        "image": Image.new("RGB", size=(128, 128)),
        "video": np.zeros((4, 128, 128, 3), dtype=np.uint8),
        "audio": (np.zeros((512, )), 16000),
    }
    # Factories producing a new random item on every call
    input_factory = {
        "image":
        partial(random_image, rng, min_wh=128, max_wh=256),
        "video":
        partial(random_video,
                rng,
                min_frames=2,
                max_frames=8,
                min_wh=128,
                max_wh=256),
        "audio":
        partial(random_audio, rng, min_len=512, max_len=1024, sr=16000),
    }

    for batch_idx in range(num_batches):
        # For each modality, draw between 0 and `limit` items (inclusive)
        mm_data = {
            k:
            [(input_to_hit[k] if rng.rand() < hit_rate else input_factory[k]())
             for _ in range(rng.randint(limit + 1))]
            for k, limit in limit_mm_per_prompt.items()
        }

        mm_counts = {k: len(vs) for k, vs in mm_data.items()}
        prompt = dummy_inputs.get_dummy_processor_inputs(
            model_config.max_model_len,
            mm_counts,
        ).prompt_text

        # Drop unnecessary keys and test single -> multi conversion
        if rng.rand() < simplify_rate:
            for k in list(mm_data.keys()):
                if not mm_data[k]:
                    del mm_data[k]
                elif len(mm_data[k]) == 1:
                    mm_data[k] = mm_data[k][0]

        if isinstance(tokenizer, MistralTokenizer):
            _test_processing_correctness_mistral(
                model_config,
                tokenizer,
                prompt,
                mm_data,
                baseline_processor,
                cached_processor,
                batch_idx,
                ignore_mm_keys=ignore_mm_keys,
            )
        else:
            _test_processing_correctness_hf(
                model_config,
                tokenizer,
                prompt,
                mm_data,
                baseline_processor,
                cached_processor,
                batch_idx,
                ignore_mm_keys=ignore_mm_keys,
            )


def _test_processing_correctness_hf(
    model_config: ModelConfig,
    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
    prompt: str,
    mm_data: MultiModalDataDict,
    baseline_processor: BaseMultiModalProcessor,
    cached_processor: BaseMultiModalProcessor,
    batch_idx: int,
    ignore_mm_keys: Optional[set[str]] = None,
):
    """Compare processor outputs for one batch using an HF tokenizer.

    Three comparisons are made via ``_assert_inputs_equal``:

    1. baseline vs. cached processor on the text prompt,
    2. baseline processor on the text prompt vs. the tokenized prompt,
    3. cached processor on the text prompt vs. the tokenized prompt.

    Args:
        model_config: Used to read ``hf_config.model_type``.
        tokenizer: Tokenizer used to encode ``prompt`` into token ids.
        prompt: Text prompt for this batch.
        mm_data: Multimodal items for this batch.
        baseline_processor: Processor built without a cache.
        cached_processor: Processor built with a cache.
        batch_idx: Batch index, included in the failure message.
        ignore_mm_keys: ``mm_kwargs`` keys excluded from each comparison.
    """
    if model_config.hf_config.model_type in ("mllama", "whisper", "ultravox"):
        # For some multimodal models, tokenizer will always add bos_token
        # at the beginning of prompt by default, causing hf_processor outputs
        # incorrect token ids. So we need use `add_special_tokens=False` here
        # to leave bos_token to be added by the processor.
        token_prompt = tokenizer.encode(prompt, add_special_tokens=False)
    else:
        token_prompt = tokenizer.encode(prompt)

    baseline_result = baseline_processor.apply(
        prompt,
        mm_data=mm_data,
        hf_processor_mm_kwargs={},
    )
    cached_result = cached_processor.apply(
        prompt,
        mm_data=mm_data,
        hf_processor_mm_kwargs={},
    )

    # Render the failure message once; the repr of `mm_data` can be large,
    # and all three assertions below report the same context.
    failure_msg = f"Failed ({batch_idx=}, {prompt=}, {mm_data=})"

    _assert_inputs_equal(
        baseline_result,
        cached_result,
        ignore_mm_keys=ignore_mm_keys,
        msg=failure_msg,
    )

    baseline_tokenized_result = baseline_processor.apply(
        token_prompt,
        mm_data=mm_data,
        hf_processor_mm_kwargs={},
    )

    _assert_inputs_equal(
        baseline_result,
        baseline_tokenized_result,
        ignore_mm_keys=ignore_mm_keys,
        msg=failure_msg,
    )

    cached_tokenized_result = cached_processor.apply(
        token_prompt,
        mm_data=mm_data,
        hf_processor_mm_kwargs={},
    )

    _assert_inputs_equal(
        cached_result,
        cached_tokenized_result,
        ignore_mm_keys=ignore_mm_keys,
        msg=failure_msg,
    )


def _test_processing_correctness_mistral(
    model_config: ModelConfig,
    tokenizer: MistralTokenizer,
    prompt: str,
    mm_data: MultiModalDataDict,
    baseline_processor: BaseMultiModalProcessor,
    cached_processor: BaseMultiModalProcessor,
    batch_idx: int,
    ignore_mm_keys: Optional[set[str]] = None,
):
    """Compare processor outputs for one batch using a Mistral tokenizer.

    The prompt and any images in ``mm_data`` are encoded into token ids
    through a chat-completion request, and the baseline and cached
    processors are then compared on that tokenized prompt.

    Args:
        model_config: Not read by this function; accepted so the signature
            matches ``_test_processing_correctness_hf``.
        tokenizer: Mistral tokenizer used to encode the chat request.
        prompt: Text prompt for this batch.
        mm_data: Multimodal items for this batch; only the ``"image"``
            entry is placed into the chat request.
        baseline_processor: Processor built without a cache.
        cached_processor: Processor built with a cache.
        batch_idx: Batch index, included in the failure message.
        ignore_mm_keys: ``mm_kwargs`` keys excluded from the comparison.
    """
    images = mm_data.get("image", [])
    # A single image may have been unwrapped from its list by the caller
    if not isinstance(images, list):
        images = [images]

    request = ChatCompletionRequest(messages=[
        UserMessage(content=[
            TextChunk(text=prompt),
            *(ImageChunk(image=image) for image in images),
        ]),
    ])
    res = tokenizer.mistral.encode_chat_completion(request)
    token_prompt = res.tokens

    # Mistral chat outputs tokens directly, rather than text prompts
    baseline_tokenized_result = baseline_processor.apply(
        token_prompt,
        mm_data=mm_data,
        hf_processor_mm_kwargs={},
    )
    cached_tokenized_result = cached_processor.apply(
        token_prompt,
        mm_data=mm_data,
        hf_processor_mm_kwargs={},
    )

    _assert_inputs_equal(
        baseline_tokenized_result,
        cached_tokenized_result,
        ignore_mm_keys=ignore_mm_keys,
        msg=f"Failed ({batch_idx=}, {prompt=}, {mm_data=})",
    )


# yapf: disable
@pytest.mark.parametrize("model_id", [
    os.path.join(models_path_prefix, "rhymes-ai/Aria"),
    os.path.join(models_path_prefix, "CohereForAI/aya-vision-8b"),
    os.path.join(models_path_prefix, "Salesforce/blip2-opt-2.7b"),
    os.path.join(models_path_prefix, "facebook/chameleon-7b"),
    os.path.join(models_path_prefix, "deepseek-ai/deepseek-vl2-tiny"),
    os.path.join(models_path_prefix, "microsoft/Florence-2-base"),
    os.path.join(models_path_prefix, "adept/fuyu-8b"),
    os.path.join(models_path_prefix, "google/gemma-3-4b-it"),
    os.path.join(models_path_prefix, "THUDM/glm-4v-9b"),
    os.path.join(models_path_prefix, "ibm-granite/granite-speech-3.3-8b"),
    os.path.join(models_path_prefix, "h2oai/h2ovl-mississippi-800m"),
    os.path.join(models_path_prefix, "OpenGVLab/InternVL2-1B"),
    os.path.join(models_path_prefix, "HuggingFaceM4/Idefics3-8B-Llama3"),
    os.path.join(models_path_prefix, "HuggingFaceTB/SmolVLM2-2.2B-Instruct"),
    os.path.join(models_path_prefix, "moonshotai/Kimi-VL-A3B-Instruct"),
    os.path.join(models_path_prefix, "meta-llama/Llama-4-Scout-17B-16E-Instruct"),
    os.path.join(models_path_prefix, "llava-hf/llava-1.5-7b-hf"),
    os.path.join(models_path_prefix, "llava-hf/llava-v1.6-mistral-7b-hf"),
    os.path.join(models_path_prefix, "llava-hf/LLaVA-NeXT-Video-7B-hf"),
    os.path.join(models_path_prefix, "llava-hf/llava-onevision-qwen2-0.5b-ov-hf"),
    os.path.join(models_path_prefix, "meta-llama/Llama-3.2-11B-Vision-Instruct"),
    os.path.join(models_path_prefix, "TIGER-Lab/Mantis-8B-siglip-llama3"),
    os.path.join(models_path_prefix, "openbmb/MiniCPM-Llama3-V-2_5"),
    os.path.join(models_path_prefix, "openbmb/MiniCPM-o-2_6"),
    os.path.join(models_path_prefix, "openbmb/MiniCPM-V-2_6"),
    os.path.join(models_path_prefix, "allenai/Molmo-7B-D-0924"),
    os.path.join(models_path_prefix, "allenai/Molmo-7B-O-0924"),
    os.path.join(models_path_prefix, "nvidia/NVLM-D-72B"),
    os.path.join(models_path_prefix, "google/paligemma-3b-mix-224"),
    os.path.join(models_path_prefix, "google/paligemma2-3b-ft-docci-448"),
    os.path.join(models_path_prefix, "microsoft/Phi-4-multimodal-instruct"),
    os.path.join(models_path_prefix, "mistralai/Pixtral-12B-2409"),
    os.path.join(models_path_prefix, "mistral-community/pixtral-12b"),
    os.path.join(models_path_prefix, "Qwen/Qwen-VL-Chat"),
    os.path.join(models_path_prefix, "Qwen/Qwen2-VL-2B-Instruct"),
    os.path.join(models_path_prefix, "Qwen/Qwen2.5-VL-3B-Instruct"),
    os.path.join(models_path_prefix, "Qwen/Qwen2-Audio-7B-Instruct"),
    os.path.join(models_path_prefix, "Qwen/Qwen2.5-Omni-7B"),
    os.path.join(models_path_prefix, "Skywork/Skywork-R1V-38B"),
    os.path.join(models_path_prefix, "fixie-ai/ultravox-v0_5-llama-3_2-1b"),
    os.path.join(models_path_prefix, "openai/whisper-large-v3"),
])
@pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0])
@pytest.mark.parametrize("num_batches", [32])
@pytest.mark.parametrize("simplify_rate", [1.0])
# yapf: enable
def test_processing_correctness(
    model_id: str,
    hit_rate: float,
    num_batches: int,
    simplify_rate: float,
):
    """Run the processing correctness check for each listed model.

    Each model path is built by joining ``models_path_prefix`` with the
    model name. Every model appears once in the list, so no
    model/hit-rate combination is executed twice.
    """
    ignore_mm_keys = None
    if 'ultravox' in model_id:
        # In Ultravox, the audio_features can be different depending on padding
        # The slight difference should not be a problem though, since
        # attention_mask lets us ignore the difference.
        ignore_mm_keys = {"audio_features"}

    _test_processing_correctness(
        model_id,
        hit_rate=hit_rate,
        num_batches=num_batches,
        simplify_rate=simplify_rate,
        ignore_mm_keys=ignore_mm_keys,
    )


# yapf: disable
@pytest.mark.parametrize("model_id", [os.path.join(models_path_prefix, "microsoft/Phi-3.5-vision-instruct")])
@pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0])
@pytest.mark.parametrize("num_batches", [32])
@pytest.mark.parametrize("simplify_rate", [1.0])
# yapf: enable
def test_processing_correctness_phi3v(
    model_id: str,
    hit_rate: float,
    num_batches: int,
    simplify_rate: float,
):
    """Run the processing correctness check for Phi-3.5-vision.

    Kept separate from ``test_processing_correctness`` because the image
    processor is loaded explicitly before the check runs (see the HACK
    note below).
    """
    # HACK - this is an attempted workaround for the following bug
    # https://github.com/huggingface/transformers/issues/34307
    from transformers import AutoImageProcessor  # noqa: F401
    from transformers import AutoProcessor  # noqa: F401

    AutoImageProcessor.from_pretrained(model_id, trust_remote_code=True)

    _test_processing_correctness(
        model_id,
        hit_rate=hit_rate,
        num_batches=num_batches,
        simplify_rate=simplify_rate,
    )


346
def _assert_inputs_equal(
347
348
    a: MultiModalInputs,
    b: MultiModalInputs,
349
350
351
    *,
    ignore_mm_keys: Optional[set[str]] = None,
    msg: str = "",
352
):
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
    if ignore_mm_keys is None:
        ignore_mm_keys = set()

    if msg is None:
        assert "mm_kwargs" in a and "mm_kwargs" in b
    else:
        assert "mm_kwargs" in a and "mm_kwargs" in b, msg

    for key in ignore_mm_keys:
        a["mm_kwargs"].pop(key, None)
        b["mm_kwargs"].pop(key, None)

    if msg is None:
        assert a == b
    else:
zhuwenwen's avatar
zhuwenwen committed
368
        assert a == b, msg