test_common.py 12.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from functools import partial
5
from typing import Optional, Union
6
7

import numpy as np
zhuwenwen's avatar
zhuwenwen committed
8
import os
9
import pytest
10
11
12
from mistral_common.protocol.instruct.messages import (ImageChunk, TextChunk,
                                                       UserMessage)
from mistral_common.protocol.instruct.request import ChatCompletionRequest
13
14
15
16
from PIL import Image

from vllm.config import ModelConfig
from vllm.inputs import InputProcessingContext
17
18
19
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict
from vllm.multimodal.inputs import MultiModalInputs
from vllm.multimodal.processing import BaseMultiModalProcessor, ProcessingCache
20
21
22
from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
                                               cached_tokenizer_from_config,
                                               encode_tokens)
23
24

from ....multimodal.utils import random_audio, random_image, random_video
25
from ...registry import HF_EXAMPLE_MODELS
zhuwenwen's avatar
zhuwenwen committed
26
from ....utils import models_path_prefix
27
28


def glm4_1v_patch_mm_data(mm_data: MultiModalDataDict) -> MultiModalDataDict:
    """
    Patch the multimodal data for GLM4.1V model.

    Attaches synthetic video metadata to the ``"video"`` entry of
    ``mm_data``, turning each video into a ``(video, metadata)`` tuple.
    Each clip is treated as one second long, so ``fps`` equals its frame
    count.

    The ``"video"`` entry may hold a single video or, when the caller did not
    unwrap single-item lists, a list of videos. In the list case every video
    gets its own metadata.
    NOTE(review): this assumes individual videos are array-like (as produced
    by this module's input factories) rather than plain lists of frames.

    ``mm_data`` is modified in place and also returned.
    """

    def _with_metadata(video):
        # Frame count of this particular video (not the number of videos).
        num_frames = len(video)
        return (video, {
            "total_num_frames": num_frames,
            "fps": num_frames,
            "duration": 1,
            "video_backend": "opencv"
        })

    # Ensure video metadata is included
    if "video" in mm_data:
        video = mm_data["video"]
        if isinstance(video, list):
            mm_data["video"] = [_with_metadata(v) for v in video]
        else:
            mm_data["video"] = _with_metadata(video)
    return mm_data


def _test_processing_correctness(
    model_id: str,
    hit_rate: float,
    num_batches: int,
    simplify_rate: float,
):
    """Fuzz a model's multimodal processor with and without a cache.

    Builds a baseline processor (``cache=None``) and a cached processor for
    ``model_id``. It then feeds ``num_batches`` randomly generated batches to
    both via ``_test_processing_correctness_one``, which checks that their
    outputs agree.

    Args:
        model_id: Model name/path used to look up the example model info and
            to build the ``ModelConfig``.
        hit_rate: Probability that each multimodal item is the fixed input
            from ``input_to_hit`` rather than a freshly generated random one.
        num_batches: Number of random batches to generate and check.
        simplify_rate: Probability that a batch's ``mm_data`` is simplified.
            Simplifying drops empty modalities and unwraps single-item lists,
            to exercise single -> multi item conversion.
    """
    model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id)
    model_info.check_available_online(on_fail="skip")
    model_info.check_transformers_version(on_fail="skip")

    model_config = ModelConfig(
        model_id,
        task="auto",
        tokenizer=model_info.tokenizer or model_id,
        tokenizer_mode=model_info.tokenizer_mode,
        trust_remote_code=model_info.trust_remote_code,
        seed=0,
        dtype="auto",
        revision=None,
        hf_overrides=model_info.hf_overrides,
    )

    model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)
    factories = MULTIMODAL_REGISTRY._processor_factories[model_cls]
    ctx = InputProcessingContext(
        model_config,
        tokenizer=cached_tokenizer_from_config(model_config),
    )
    # Ensure that it can fit all of the data
    cache = ProcessingCache(capacity_gb=2048)

    processing_info = factories.info(ctx)
    supported_mm_limits = processing_info.get_supported_mm_limits()
    # Modalities without an explicit limit (None) are capped at 3 items.
    limit_mm_per_prompt = {
        modality: 3 if limit is None else limit
        for modality, limit in supported_mm_limits.items()
    }

    model_config.get_multimodal_config().limit_per_prompt = limit_mm_per_prompt

    baseline_processor = factories.build_processor(ctx, cache=None)
    cached_processor = factories.build_processor(ctx, cache=cache)
    dummy_inputs = baseline_processor.dummy_inputs
    tokenizer = baseline_processor.info.get_tokenizer()

    # Fixed seed so every run generates the same sequence of batches.
    rng = np.random.RandomState(0)

    # Identical inputs reused across batches (chosen with probability
    # ``hit_rate``), so the same item is processed repeatedly.
    input_to_hit = {
        "image": Image.new("RGB", size=(128, 128)),
        "video": np.zeros((4, 128, 128, 3), dtype=np.uint8),
        "audio": (np.zeros((512, )), 16000),
    }
    # Generators for fresh random inputs of each modality.
    input_factory = {
        "image":
        partial(random_image, rng, min_wh=128, max_wh=256),
        "video":
        partial(random_video,
                rng,
                min_frames=2,
                max_frames=8,
                min_wh=128,
                max_wh=256),
        "audio":
        partial(random_audio, rng, min_len=512, max_len=1024, sr=16000),
    }

    for batch_idx in range(num_batches):
        # Between 0 and ``limit`` items per modality.
        mm_data = {
            k:
            [(input_to_hit[k] if rng.rand() < hit_rate else input_factory[k]())
             for _ in range(rng.randint(limit + 1))]
            for k, limit in limit_mm_per_prompt.items()
        }

        mm_counts = {k: len(vs) for k, vs in mm_data.items()}

        # Mistral chat outputs tokens directly, rather than text prompts
        if isinstance(tokenizer, MistralTokenizer):
            images = mm_data.get("image", [])
            request = ChatCompletionRequest(messages=[
                UserMessage(content=[
                    TextChunk(text=""),
                    *(ImageChunk(image=image) for image in images),
                ]),
            ])
            res = tokenizer.mistral.encode_chat_completion(request)
            prompt = res.tokens
        else:
            prompt = dummy_inputs.get_dummy_processor_inputs(
                model_config.max_model_len,
                mm_counts,
            ).prompt

        # Drop unnecessary keys and test single -> multi conversion
        if rng.rand() < simplify_rate:
            for k in list(mm_data.keys()):
                if not mm_data[k]:
                    del mm_data[k]
                elif len(mm_data[k]) == 1:
                    mm_data[k] = mm_data[k][0]

        _test_processing_correctness_one(
            model_config,
            tokenizer,
            prompt,
            mm_data,
            baseline_processor,
            cached_processor,
            batch_idx,
        )


# For some multimodal models, the tokenizer always adds a bos_token at the
# beginning of the prompt by default, which makes the HF processor output
# incorrect token ids. For these models we pass `add_special_tokens=False`
# so that the bos_token is left for the processor to add.
# Keys are HF `model_type` values; models not listed here get `None` from
# `.get()`.
_ADD_SPECIAL_TOKENS_OVERRIDES = {
    "mllama": False,
    "ovis": False,
    "paligemma": False,
    "ultravox": False,
    "whisper": False,
}

# Per-`model_type` keys of `mm_kwargs` that are excluded from the equality
# checks.
_IGNORE_MM_KEYS = {
    # In Ultravox, the audio_features can differ depending on padding.
    # The slight difference should not be a problem, though, since the
    # attention_mask lets us ignore it.
    "ultravox": {"audio_features"},
}

# Per-`model_type` hooks that rewrite `mm_data` before it is processed.
# Each hook takes the multimodal data dict and returns the dict to use.
MM_DATA_PATCHES = {
    # GLM4.1V requires video metadata to be included in the input
    "glm4v": glm4_1v_patch_mm_data,
}


def _test_processing_correctness_one(
    model_config: ModelConfig,
    tokenizer: AnyTokenizer,
    prompt: Union[str, list[int]],
    mm_data: MultiModalDataDict,
    baseline_processor: BaseMultiModalProcessor,
    cached_processor: BaseMultiModalProcessor,
    batch_idx: int,
):
    """Check one (prompt, mm_data) pair against both processors.

    The token-id form of the prompt is processed by the baseline and cached
    processors, and their outputs must match. If ``prompt`` is a string, the
    text form is also processed by both. The text outputs must then match
    each other and the corresponding token-id outputs.

    Args:
        model_config: Config whose HF ``model_type`` selects per-model
            overrides (ignored keys, special-token handling, data patches).
        tokenizer: Tokenizer used to encode a text prompt into token ids.
        prompt: Text prompt, or token ids (e.g. from Mistral chat encoding).
        mm_data: Multimodal data passed to both processors.
        baseline_processor: Processor built without a cache.
        cached_processor: Processor built with a processing cache.
        batch_idx: Batch index, included in failure messages.
    """
    model_type = model_config.hf_config.model_type
    ignore_mm_keys = _IGNORE_MM_KEYS.get(model_type, set[str]())

    if model_type in MM_DATA_PATCHES:
        mm_data = MM_DATA_PATCHES[model_type](mm_data)

    if isinstance(prompt, str):
        text_prompt = prompt
        token_prompt = encode_tokens(
            tokenizer,
            prompt,
            add_special_tokens=_ADD_SPECIAL_TOKENS_OVERRIDES.get(model_type),
        )
    else:
        # Mistral does not support decode_tokens with skip_special_tokens=False,
        # so for token prompts there is no text form and the text-prompt
        # checks below are skipped.
        text_prompt = None
        token_prompt = prompt

    baseline_tokenized_result = baseline_processor.apply(
        token_prompt,
        mm_data=mm_data,
        hf_processor_mm_kwargs={},
    )

    cached_tokenized_result = cached_processor.apply(
        token_prompt,
        mm_data=mm_data,
        hf_processor_mm_kwargs={},
    )

    _assert_inputs_equal(
        baseline_tokenized_result,
        cached_tokenized_result,
        ignore_mm_keys=ignore_mm_keys,
        msg=f"Failed ({batch_idx=}, {token_prompt=}, {mm_data=})",
    )

    if text_prompt is not None:
        baseline_text_result = baseline_processor.apply(
            text_prompt,
            mm_data=mm_data,
            hf_processor_mm_kwargs={},
        )
        cached_text_result = cached_processor.apply(
            text_prompt,
            mm_data=mm_data,
            hf_processor_mm_kwargs={},
        )

        _assert_inputs_equal(
            baseline_text_result,
            cached_text_result,
            ignore_mm_keys=ignore_mm_keys,
            msg=f"Failed ({batch_idx=}, {text_prompt=}, {mm_data=})",
        )

        # Text and token-id prompts must produce the same processed inputs.
        _assert_inputs_equal(
            baseline_text_result,
            baseline_tokenized_result,
            ignore_mm_keys=ignore_mm_keys,
            msg=f"Failed ({batch_idx=}, {text_prompt=}, "
            f"{token_prompt=}, {mm_data=})",
        )

        _assert_inputs_equal(
            cached_text_result,
            cached_tokenized_result,
            ignore_mm_keys=ignore_mm_keys,
            msg=f"Failed ({batch_idx=}, {text_prompt=}, "
            f"{token_prompt=}, {mm_data=})",
        )


# yapf: disable
@pytest.mark.parametrize("model_id", [
    os.path.join(models_path_prefix, "rhymes-ai/Aria"),
    os.path.join(models_path_prefix, "CohereForAI/aya-vision-8b"),
    os.path.join(models_path_prefix, "Salesforce/blip2-opt-2.7b"),
    os.path.join(models_path_prefix, "facebook/chameleon-7b"),
    os.path.join(models_path_prefix, "deepseek-ai/deepseek-vl2-tiny"),
    os.path.join(models_path_prefix, "microsoft/Florence-2-base"),
    os.path.join(models_path_prefix, "adept/fuyu-8b"),
    os.path.join(models_path_prefix, "google/gemma-3-4b-it"),
    os.path.join(models_path_prefix, "THUDM/glm-4v-9b"),
    os.path.join(models_path_prefix, "THUDM/GLM-4.1V-9B-Thinking"),
    os.path.join(models_path_prefix, "ibm-granite/granite-speech-3.3-2b"),
    os.path.join(models_path_prefix, "h2oai/h2ovl-mississippi-800m"),
    os.path.join(models_path_prefix, "OpenGVLab/InternVL2-1B"),
    os.path.join(models_path_prefix, "OpenGVLab/InternVL3-1B"),
    os.path.join(models_path_prefix, "HuggingFaceM4/Idefics3-8B-Llama3"),
    os.path.join(models_path_prefix, "HuggingFaceTB/SmolVLM2-2.2B-Instruct"),
    os.path.join(models_path_prefix, "moonshotai/Kimi-VL-A3B-Instruct"),
    os.path.join(models_path_prefix, "meta-llama/Llama-4-Scout-17B-16E-Instruct"),
    os.path.join(models_path_prefix, "llava-hf/llava-1.5-7b-hf"),
    os.path.join(models_path_prefix, "llava-hf/llava-v1.6-mistral-7b-hf"),
    os.path.join(models_path_prefix, "llava-hf/LLaVA-NeXT-Video-7B-hf"),
    os.path.join(models_path_prefix, "llava-hf/llava-onevision-qwen2-0.5b-ov-hf"),
    os.path.join(models_path_prefix, "meta-llama/Llama-3.2-11B-Vision-Instruct"),
    os.path.join(models_path_prefix, "TIGER-Lab/Mantis-8B-siglip-llama3"),
    os.path.join(models_path_prefix, "openbmb/MiniCPM-Llama3-V-2_5"),
    os.path.join(models_path_prefix, "openbmb/MiniCPM-o-2_6"),
    os.path.join(models_path_prefix, "openbmb/MiniCPM-V-2_6"),
    os.path.join(models_path_prefix, "MiniMaxAI/MiniMax-VL-01"),
    os.path.join(models_path_prefix, "allenai/Molmo-7B-D-0924"),
    os.path.join(models_path_prefix, "allenai/Molmo-7B-O-0924"),
    os.path.join(models_path_prefix, "nvidia/NVLM-D-72B"),
    os.path.join(models_path_prefix, "nvidia/Llama-3.1-Nemotron-Nano-VL-8B-V1"),
    os.path.join(models_path_prefix, "AIDC-AI/Ovis1.6-Gemma2-9B"),
    os.path.join(models_path_prefix, "AIDC-AI/Ovis1.6-Llama3.2-3B"),
    os.path.join(models_path_prefix, "AIDC-AI/Ovis2-1B"),
    os.path.join(models_path_prefix, "google/paligemma-3b-mix-224"),
    os.path.join(models_path_prefix, "google/paligemma2-3b-ft-docci-448"),
    os.path.join(models_path_prefix, "microsoft/Phi-3.5-vision-instruct"),
    os.path.join(models_path_prefix, "microsoft/Phi-4-multimodal-instruct"),
    os.path.join(models_path_prefix, "mistralai/Pixtral-12B-2409"),
    os.path.join(models_path_prefix, "mistral-community/pixtral-12b"),
    os.path.join(models_path_prefix, "Qwen/Qwen-VL-Chat"),
    os.path.join(models_path_prefix, "Qwen/Qwen2-VL-2B-Instruct"),
    os.path.join(models_path_prefix, "Qwen/Qwen2.5-VL-3B-Instruct"),
    os.path.join(models_path_prefix, "Qwen/Qwen2-Audio-7B-Instruct"),
    os.path.join(models_path_prefix, "Qwen/Qwen2.5-Omni-3B"),
    os.path.join(models_path_prefix, "Skywork/Skywork-R1V-38B"),
    os.path.join(models_path_prefix, "fixie-ai/ultravox-v0_5-llama-3_2-1b"),
    os.path.join(models_path_prefix, "openai/whisper-large-v3"),
    os.path.join(models_path_prefix, "omni-research/Tarsier-7b"),
    os.path.join(models_path_prefix, "omni-research/Tarsier2-Recap-7b"),
])
@pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0])
@pytest.mark.parametrize("num_batches", [32])
@pytest.mark.parametrize("simplify_rate", [1.0])
# yapf: enable
def test_processing_correctness(
    model_id: str,
    hit_rate: float,
    num_batches: int,
    simplify_rate: float,
):
    """Check that cached and uncached multimodal processing agree.

    Runs for every listed model (resolved under ``models_path_prefix``) and
    every combination of the other parametrized arguments.
    """
    _test_processing_correctness(
        model_id,
        hit_rate=hit_rate,
        num_batches=num_batches,
        simplify_rate=simplify_rate,
    )


def _assert_inputs_equal(
    a: MultiModalInputs,
    b: MultiModalInputs,
    *,
    ignore_mm_keys: Optional[set[str]] = None,
    msg: str = "",
):
    """Assert that two processed inputs are equal.

    Args:
        a: First processed input; must contain ``"mm_kwargs"``.
        b: Second processed input; must contain ``"mm_kwargs"``.
        ignore_mm_keys: Keys to drop from both ``mm_kwargs`` before
            comparing. Missing keys are ignored.
        msg: Message attached to any assertion failure.

    Raises:
        AssertionError: If either input lacks ``"mm_kwargs"`` or the inputs
            differ after the ignored keys are removed.

    Note: ignored keys are popped from ``a["mm_kwargs"]`` and
    ``b["mm_kwargs"]`` in place, so both arguments are mutated.
    """
    if ignore_mm_keys is None:
        ignore_mm_keys = set()

    assert "mm_kwargs" in a and "mm_kwargs" in b, msg

    for key in ignore_mm_keys:
        a["mm_kwargs"].pop(key, None)
        b["mm_kwargs"].pop(key, None)

    assert a == b, msg