# test_phi4mm.py
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4

import os
5
from collections.abc import Sequence
6

7
import librosa
8
import pytest
9
import regex as re
10
11
12
from huggingface_hub import snapshot_download
from transformers import AutoTokenizer

13
from vllm.assets.image import ImageAsset
14
from vllm.logprobs import SampleLogprobs
15
from vllm.lora.request import LoRARequest
16
from vllm.multimodal.image import convert_image_mode, rescale_image_size
17
18
from vllm.platforms import current_platform

19
20
21
22
23
24
25
from ....conftest import (
    IMAGE_ASSETS,
    HfRunner,
    PromptAudioInput,
    PromptImageInput,
    VllmRunner,
)
26
27
28
from ....utils import large_gpu_test
from ...utils import check_logprobs_close

29
30
31
32
33
34
35
36
37
# Single-image prompts, one per image asset key. Each is a user turn holding
# an <|image_1|> placeholder and a question, followed by the opening of the
# assistant turn.
HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts(
    {
        "stop_sign": "<|user|>\n<|image_1|>\nWhat's the content of the image?<|end|>\n<|assistant|>\n",  # noqa: E501
        "cherry_blossom": "<|user|>\n<|image_1|>\nPlease infer the season with reason in details.<|end|>\n<|assistant|>\n",  # noqa: E501
    }
)
# Prompt with two image placeholders (<|image_1|>, <|image_2|>); used by
# test_multi_images_models.
HF_MULTIIMAGE_IMAGE_PROMPT = (
    "<|user|>\n<|image_1|>\n<|image_2|>\nDescribe these images.<|end|>\n<|assistant|>\n"  # noqa: E501
)
38
39
40
41
42

# Local snapshot directory of the model repository, resolved at import time.
model_path = snapshot_download("microsoft/Phi-4-multimodal-instruct")

# The vision-lora and speech-lora adapters live alongside the base model
# weights, so the LoRA weight directory has to be spelled out by hand.
vision_lora_path = os.path.join(model_path, "vision-lora")

# Spoken example question stored under the snapshot's "examples" directory.
_examples_dir = os.path.join(model_path, "examples")
speech_question = os.path.join(_examples_dir, "what_is_shown_in_this_image.wav")

# Values for the "model" parametrization of every test in this module.
models = [model_path]


49
def vllm_to_hf_output(
    vllm_output: tuple[list[int], str, SampleLogprobs | None], model: str
) -> tuple[list[int], str, SampleLogprobs | None]:
    """Sanitize vllm output to be comparable with hf output.

    Args:
        vllm_output: ``(token_ids, text, logprobs)`` triple. The token ids
            are discarded and recomputed from the sanitized text.
        model: Name or path passed to ``AutoTokenizer.from_pretrained``.

    Returns:
        ``(hf_output_ids, hf_output_str, out_logprobs)``. The ids are the
        tokenization of the sanitized text without its first token, the
        string is the sanitized text followed by ``<|end|><|endoftext|>``,
        and the logprobs are passed through unchanged.

    Raises:
        AssertionError: If the sanitized text does not begin with a single
            space, or the tokenized text does not begin with token id 1.
    """
    _, output_str, out_logprobs = vllm_output

    # Remove every run of image placeholder tokens from the generated text.
    output_str_without_image = re.sub(r"(<\|image_\d+\|>)+", "", output_str)
    # startswith() instead of indexing [0]: an empty generation then fails
    # with a readable assertion rather than an IndexError.
    assert output_str_without_image.startswith(" "), (
        "expected sanitized vLLM output to start with a space, "
        f"got {output_str_without_image!r}"
    )
    output_str_without_image = output_str_without_image[1:]

    hf_output_str = output_str_without_image + "<|end|><|endoftext|>"

    tokenizer = AutoTokenizer.from_pretrained(model)
    hf_output_ids = tokenizer.encode(output_str_without_image)
    # The first id must be 1 (presumably the BOS token -- TODO confirm);
    # slicing keeps an empty encoding an assertion failure, not IndexError.
    assert hf_output_ids[:1] == [1], (
        f"expected encoded output to start with token id 1, got {hf_output_ids[:1]}"
    )
    hf_output_ids = hf_output_ids[1:]

    return hf_output_ids, hf_output_str, out_logprobs


# dtype used for the "dtype" parametrization of every test in this module.
target_dtype = "half"

# FIXME (mattwong, gshtrasb, hongxiayan)
# On ROCm, Triton FA can run into shared memory issues with these models,
# so other backends are used in the meantime.
if current_platform.is_rocm():
    os.environ.update(VLLM_USE_TRITON_FLASH_ATTN="0")


def run_test(
    hf_runner: type[HfRunner],
    vllm_runner: type[VllmRunner],
    inputs: Sequence[tuple[list[str], PromptImageInput, PromptAudioInput | None]],
    model: str,
    *,
    max_model_len: int,
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
    mm_limit: int,
    tensor_parallel_size: int,
    distributed_executor_backend: str | None = None,
) -> None:
    """Inference result should be the same between hf and vllm.

    All the image fixtures for the test are from IMAGE_ASSETS.
    For huggingface runner, we provide the PIL images as input.
    For vllm runner, we provide MultiModalDataDict objects
    and corresponding MultiModalConfig as input.
    Note, the text input is also adjusted to abide by vllm contract.
    The text output is sanitized to be able to compare with hf.

    Args:
        hf_runner: Context-manager factory for the HuggingFace model.
        vllm_runner: Context-manager factory for the vLLM model.
        inputs: One ``(prompts, images, audios)`` tuple per test case;
            ``audios`` may be ``None``.
        model: Model name or path handed to both runners.
        max_model_len: Forwarded to the vLLM runner.
        dtype: Forwarded to both runners.
        max_tokens: Maximum number of tokens generated per prompt.
        num_logprobs: Number of logprobs requested per generated token.
        mm_limit: Per-prompt image limit for vLLM
            (``limit_mm_per_prompt={"image": mm_limit}``).
        tensor_parallel_size: Forwarded to the vLLM runner.
        distributed_executor_backend: Forwarded to the vLLM runner.
    """
    # NOTE: take care of the order. run vLLM first, and then run HF.
    # vLLM needs a fresh new process without cuda initialization.
    # if we run HF first, the cuda initialization will be done and it
    # will hurt multiprocessing backend with fork method (the default method).
    # max_model_len should be greater than image_feature_size
    with vllm_runner(
        model,
        runner="generate",
        max_model_len=max_model_len,
        max_num_seqs=2,
        dtype=dtype,
        limit_mm_per_prompt={"image": mm_limit},
        tensor_parallel_size=tensor_parallel_size,
        distributed_executor_backend=distributed_executor_backend,
        enable_lora=True,
        max_lora_rank=320,
        gpu_memory_utilization=0.8,  # set to 0.8 to avoid OOM in CI
        enforce_eager=True,
    ) as vllm_model:
        # Every vLLM request in this helper goes through the vision LoRA.
        lora_request = LoRARequest("vision", 1, vision_lora_path)
        vllm_outputs_per_case = [
            vllm_model.generate_greedy_logprobs(
                prompts,
                max_tokens,
                num_logprobs=num_logprobs,
                images=images,
                audios=audios,
                lora_request=lora_request,
            )
            for prompts, images, audios in inputs
        ]

    # This error occurs inside `get_peft_model`
    # FIXME: https://huggingface.co/microsoft/Phi-4-multimodal-instruct/discussions/75
    # pytest.skip() raises, so everything below is currently not executed;
    # only the vLLM generation above is exercised until the FIXME is resolved.
    pytest.skip("HF impl is not compatible with current transformers")

    hf_model_kwargs = {"_attn_implementation": "sdpa"}
    with hf_runner(model, dtype=dtype, model_kwargs=hf_model_kwargs) as hf_model:
        hf_processor = hf_model.processor
        eos_token_id = hf_processor.tokenizer.eos_token_id

        def patch_hf_processor(
            *args, text="", images=None, audio=None, sampling_rate=None, **kwargs
        ):
            # Adapt the (audio, sampling_rate) keyword pair to the `audios`
            # list of (audio, sampling_rate) tuples passed to hf_processor.
            audios = None
            if audio is not None and sampling_rate is not None:
                audios = [(audio, sampling_rate)]
            return hf_processor(
                *args, text=text, images=images, audios=audios, **kwargs
            )

        hf_model.processor = patch_hf_processor

        hf_outputs_per_case = [
            hf_model.generate_greedy_logprobs_limit(
                prompts,
                max_tokens,
                num_logprobs=num_logprobs,
                images=images,
                audios=audios,
                eos_token_id=eos_token_id,
                num_logits_to_keep=0,
            )
            for prompts, images, audios in inputs
        ]

    # strict=True: both lists are built from `inputs`, so a length mismatch
    # would be a bug and must not be silently truncated away.
    for hf_outputs, vllm_outputs in zip(
        hf_outputs_per_case, vllm_outputs_per_case, strict=True
    ):
        check_logprobs_close(
            outputs_0_lst=hf_outputs,
            outputs_1_lst=vllm_outputs,
            name_0="hf",
            name_1="vllm",
        )


@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize(
    "size_factors",
    [
        # No image
        [],
        # Single-scale
        [1.0],
        # Single-scale, batched
        [1.0, 1.0, 1.0],
        # Multi-scale
        [0.25, 0.5, 1.0],
    ],
)
@pytest.mark.parametrize("dtype", [target_dtype])
@pytest.mark.parametrize("max_model_len", [12800])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [10])
def test_models(
    hf_runner,
    vllm_runner,
    image_assets,
    model,
    size_factors,
    dtype: str,
    max_model_len: int,
    max_tokens: int,
    num_logprobs: int,
) -> None:
    """Run the single-image prompts through ``run_test`` at several scales.

    For each image asset, its prompt is repeated once per entry of
    ``size_factors`` and paired with the image rescaled by that factor.
    An empty ``size_factors`` yields empty prompt and image lists.
    """
    images = [asset.pil_image for asset in image_assets]

    # One (prompts, images, audios) case per image asset; no audio is used.
    inputs_per_image = [
        (
            [prompt for _ in size_factors],
            [rescale_image_size(image, factor) for factor in size_factors],
            None,
        )
        for image, prompt in zip(images, HF_IMAGE_PROMPTS)
    ]

    run_test(
        hf_runner,
        vllm_runner,
        inputs_per_image,
        model,
        dtype=dtype,
        max_model_len=max_model_len,
        max_tokens=max_tokens,
        num_logprobs=num_logprobs,
        mm_limit=1,
        tensor_parallel_size=1,
    )


@large_gpu_test(min_gb=48)
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize(
    "size_factors",
    [
        # No image
        # [],
        # Single-scale
        [1.0],
        # Single-scale, batched
        [1.0, 1.0, 1.0],
        # Multi-scale
        [0.25, 0.5, 1.0],
    ],
)
@pytest.mark.parametrize("dtype", [target_dtype])
@pytest.mark.parametrize("max_model_len", [25600])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [10])
def test_multi_images_models(
    hf_runner,
    vllm_runner,
    image_assets,
    model,
    size_factors,
    dtype: str,
    max_model_len: int,
    max_tokens: int,
    num_logprobs: int,
) -> None:
    """Run the multi-image prompt through ``run_test`` at several scales.

    A single case is built: the multi-image prompt is repeated once per
    entry of ``size_factors``, and each repetition receives every image
    asset rescaled by that factor.
    """
    images = [asset.pil_image for asset in image_assets]

    # A single (prompts, images, audios) case; each prompt gets a list of
    # images, and no audio is used.
    inputs_per_case = [
        (
            [HF_MULTIIMAGE_IMAGE_PROMPT for _ in size_factors],
            [
                [rescale_image_size(image, factor) for image in images]
                for factor in size_factors
            ],
            None,
        ),
    ]

    run_test(
        hf_runner,
        vllm_runner,
        inputs_per_case,
        model,
        dtype=dtype,
        max_model_len=max_model_len,
        max_tokens=max_tokens,
        num_logprobs=num_logprobs,
        mm_limit=2,
        tensor_parallel_size=1,
    )
285
286
287
288


@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize("dtype", [target_dtype])
@pytest.mark.parametrize("max_model_len", [12800])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [10])
def test_vision_speech_models(
    hf_runner,
    vllm_runner,
    model,
    dtype: str,
    max_model_len: int,
    max_tokens: int,
    num_logprobs: int,
) -> None:
    """Run one combined image + audio prompt through ``run_test``.

    The image is the ``cherry_blossom`` asset converted to RGB, and the
    audio is the example spoken question at ``speech_question``.
    """
    # use the example speech question so that the model outputs are reasonable
    # librosa.load returns a (waveform, sampling_rate) pair; sr=None keeps
    # the file's native sampling rate instead of resampling.
    audio = librosa.load(speech_question, sr=None)
    image = convert_image_mode(ImageAsset("cherry_blossom").pil_image, "RGB")

    # A single (prompts, images, audios) case with one prompt that carries
    # both an image placeholder and an audio placeholder.
    inputs_vision_speech = [
        (
            ["<|user|><|image_1|><|audio_1|><|end|><|assistant|>"],
            [image],
            [audio],
        ),
    ]

    run_test(
        hf_runner,
        vllm_runner,
        inputs_vision_speech,
        model,
        dtype=dtype,
        max_model_len=max_model_len,
        max_tokens=max_tokens,
        num_logprobs=num_logprobs,
        mm_limit=1,
        tensor_parallel_size=1,
    )