# test_phi3v.py
1
import os
2
import re
3
from typing import List, Optional, Tuple, Type, Union
4
5

import pytest
6
from PIL import Image
7
8
from transformers import AutoTokenizer

9
10
from vllm.multimodal.utils import rescale_image_size
from vllm.sequence import SampleLogprobs
11
from vllm.utils import is_cpu, is_hip
12

13
from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner
14
from .utils import check_logprobs_close
15

16
# Apply the `vlm` marker to every test in this module.
pytestmark = pytest.mark.vlm

# Per-asset prompts in the Phi-3 chat format; `<|image_1|>` is the
# placeholder for the single image attached to each prompt.
# NOTE(review): the tests zip these prompts with the `image_assets` fixture,
# so IMAGE_ASSETS.prompts(...) is assumed to return them in asset order --
# confirm in conftest.
HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
    "stop_sign":
    "<|user|>\n<|image_1|>\nWhat's the content of the image?<|end|>\n<|assistant|>\n",  # noqa: E501
    "cherry_blossom":
    "<|user|>\n<|image_1|>\nWhat is the season?<|end|>\n<|assistant|>\n",
})

# Prompt that references two images (`<|image_1|>` and `<|image_2|>`);
# used by the multi-image test.
HF_MULTIIMAGE_IMAGE_PROMPT = "<|user|>\n<|image_1|>\n<|image_2|>\nDescribe these images.<|end|>\n<|assistant|>\n"  # noqa: E501

# Checkpoints used by every test below via `parametrize("model", models)`.
models = ["microsoft/Phi-3.5-vision-instruct"]


# Tokenizers already loaded by `_get_tokenizer`, keyed by model name.
_TOKENIZER_CACHE = {}


def _get_tokenizer(model: str):
    """Return the HF tokenizer for ``model``, loading it at most once."""
    tokenizer = _TOKENIZER_CACHE.get(model)
    if tokenizer is None:
        tokenizer = AutoTokenizer.from_pretrained(model)
        _TOKENIZER_CACHE[model] = tokenizer
    return tokenizer


def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
                                         Optional[SampleLogprobs]],
                      model: str):
    """Sanitize vllm output to be comparable with hf output.

    Args:
        vllm_output: ``(token_ids, output_str, logprobs)``; the vLLM token
            ids are discarded and recomputed from the sanitized text.
        model: Model name whose tokenizer re-encodes the sanitized text.

    Returns:
        ``(hf_output_ids, hf_output_str, out_logprobs)``. The string has the
        image placeholder tokens and its leading space removed, and
        ``<|end|><|endoftext|>`` appended. The ids are the re-encoded
        sanitized text with the first (BOS) id dropped. ``out_logprobs`` is
        passed through unchanged.
    """
    _, output_str, out_logprobs = vllm_output

    # Remove runs of `<|image_N|>` placeholder tokens from the vLLM text.
    output_str_without_image = re.sub(r"(<\|image_\d+\|>)+", "", output_str)
    # The sanitized text is expected to begin with a single space, which is
    # dropped. startswith() also gives a clear failure on empty output
    # instead of an IndexError.
    assert output_str_without_image.startswith(" "), (
        f"expected a leading space in vLLM output: {output_str!r}")
    output_str_without_image = output_str_without_image[1:]

    hf_output_str = output_str_without_image + "<|end|><|endoftext|>"

    # Loading the tokenizer is expensive; reuse it across outputs.
    tokenizer = _get_tokenizer(model)
    hf_output_ids = tokenizer.encode(output_str_without_image)
    # encode() prepends a special token; id 1 is presumably the Phi-3 BOS
    # token. Drop it so only the generated text's ids remain.
    assert hf_output_ids[0] == 1
    hf_output_ids = hf_output_ids[1:]

    return hf_output_ids, hf_output_str, out_logprobs


# Model dtype used by every test: bfloat16 on the CPU backend, half
# precision otherwise.
target_dtype = "bfloat16" if is_cpu() else "half"

# ROCm Triton flash attention can run into shared-memory issues with these
# models, so disable it and use another attention backend in the meantime.
# This is set at import time, before any runner is created.
# FIXME (mattwong, gshtrasb, hongxiayan)
if is_hip():
    os.environ["VLLM_USE_TRITON_FLASH_ATTN"] = "0"


def run_test(
    hf_runner: Type[HfRunner],
    vllm_runner: Type[VllmRunner],
    inputs: List[Tuple[List[str], Union[List[Image.Image],
                                        List[List[Image.Image]]]]],
    model: str,
    *,
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
    mm_limit: int,
    tensor_parallel_size: int,
    distributed_executor_backend: Optional[str] = None,
):
    """Inference result should be the same between hf and vllm.

    All the image fixtures for the test are under tests/images.
    The same prompts and PIL images from ``inputs`` are given to both the
    vLLM runner and the HF runner. The vLLM text output is sanitized with
    ``vllm_to_hf_output`` so it can be compared with the HF output.

    Args:
        hf_runner: HF runner class, used as a context manager.
        vllm_runner: vLLM runner class, used as a context manager.
        inputs: One entry per test case; each entry is ``(prompts, images)``
            where ``images`` holds either one image or a list of images per
            prompt.
        model: Model name given to both runners.
        dtype: Model dtype given to both runners.
        max_tokens: Maximum number of tokens generated per prompt.
        num_logprobs: Number of top logprobs requested per generated token.
        mm_limit: Maximum number of images per prompt accepted by vLLM
            (``limit_mm_per_prompt={"image": mm_limit}``).
        tensor_parallel_size: Tensor-parallel size for the vLLM runner.
        distributed_executor_backend: Optional distributed executor backend
            for the vLLM runner.
    """
    # NOTE: take care of the order. run vLLM first, and then run HF.
    # vLLM needs a fresh new process without cuda initialization.
    # if we run HF first, the cuda initialization will be done and it
    # will hurt multiprocessing backend with fork method (the default method).

    # max_model_len should be greater than image_feature_size
    with vllm_runner(model,
                     max_model_len=4096,
                     max_num_seqs=1,
                     dtype=dtype,
                     limit_mm_per_prompt={"image": mm_limit},
                     tensor_parallel_size=tensor_parallel_size,
                     distributed_executor_backend=distributed_executor_backend,
                     enforce_eager=True) as vllm_model:
        vllm_outputs_per_case = [
            vllm_model.generate_greedy_logprobs(prompts,
                                                max_tokens,
                                                num_logprobs=num_logprobs,
                                                images=images)
            for prompts, images in inputs
        ]

    # use eager mode for hf runner, since phi3_v didn't work with flash_attn
    hf_model_kwargs = {"_attn_implementation": "eager"}
    with hf_runner(model, dtype=dtype,
                   model_kwargs=hf_model_kwargs) as hf_model:
        eos_token_id = hf_model.processor.tokenizer.eos_token_id
        hf_outputs_per_case = [
            hf_model.generate_greedy_logprobs_limit(prompts,
                                                    max_tokens,
                                                    num_logprobs=num_logprobs,
                                                    images=images,
                                                    eos_token_id=eos_token_id)
            for prompts, images in inputs
        ]

    # Both lists were built from `inputs`, so they pair up case by case.
    for hf_outputs, vllm_outputs in zip(hf_outputs_per_case,
                                        vllm_outputs_per_case):
        check_logprobs_close(
            outputs_0_lst=hf_outputs,
            outputs_1_lst=[
                vllm_to_hf_output(vllm_output, model)
                for vllm_output in vllm_outputs
            ],
            name_0="hf",
            name_1="vllm",
        )


# hf_runner uses _attn_implementation="eager", which leads to larger numerical
# differences from vLLM; the usual `num_logprobs=5` fails, so 10 is used.
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize(
    "size_factors",
    [
        # No image
        [],
        # Single-scale
        [1.0],
        # Single-scale, batched
        [1.0, 1.0, 1.0],
        # Multi-scale
        [0.25, 0.5, 1.0],
    ],
)
@pytest.mark.parametrize("dtype", [target_dtype])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [10])
def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
                dtype: str, max_tokens: int, num_logprobs: int) -> None:
    """Compare HF and vLLM outputs on single-image prompts.

    Each image asset forms one test case. Within that case there is one
    prompt per entry in ``size_factors``, paired with the image rescaled by
    that factor.
    """
    images = [asset.pil_image for asset in image_assets]

    inputs_per_image = [(
        [prompt for _ in size_factors],
        [rescale_image_size(image, factor) for factor in size_factors],
    ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]

    run_test(
        hf_runner,
        vllm_runner,
        inputs_per_image,
        model,
        dtype=dtype,
        max_tokens=max_tokens,
        num_logprobs=num_logprobs,
        mm_limit=1,
        tensor_parallel_size=1,
    )


@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize("dtype", [target_dtype])
def test_regression_7840(hf_runner, vllm_runner, image_assets, model,
                         dtype) -> None:
    """Regression test for #7840.

    Each image asset, at its original size, is a separate test case with a
    single prompt.
    """
    images = [asset.pil_image for asset in image_assets]

    inputs_regression_7840 = [
        ([prompt], [image]) for image, prompt in zip(images, HF_IMAGE_PROMPTS)
    ]

    run_test(
        hf_runner,
        vllm_runner,
        inputs_regression_7840,
        model,
        dtype=dtype,
        max_tokens=128,
        num_logprobs=10,
        mm_limit=1,
        tensor_parallel_size=1,
    )


@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize(
    "size_factors",
    [
        # No image
        [],
        # Single-scale
        [1.0],
        # Single-scale, batched
        [1.0, 1.0, 1.0],
        # Multi-scale
        [0.25, 0.5, 1.0],
    ],
)
@pytest.mark.parametrize("dtype", [target_dtype])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [10])
def test_multi_images_models(hf_runner, vllm_runner, image_assets, model,
                             size_factors, dtype: str, max_tokens: int,
                             num_logprobs: int) -> None:
    """Compare HF and vLLM outputs on prompts that reference several images.

    A single test case is built with one multi-image prompt per entry in
    ``size_factors``. Each prompt gets all image assets rescaled by that
    factor.

    NOTE(review): ``mm_limit=2`` and the two placeholders in
    ``HF_MULTIIMAGE_IMAGE_PROMPT`` assume ``image_assets`` holds exactly two
    assets -- confirm against the fixture.
    """
    images = [asset.pil_image for asset in image_assets]

    inputs_per_case = [
        ([HF_MULTIIMAGE_IMAGE_PROMPT for _ in size_factors],
         [[rescale_image_size(image, factor) for image in images]
          for factor in size_factors])
    ]

    run_test(
        hf_runner,
        vllm_runner,
        inputs_per_case,
        model,
        dtype=dtype,
        max_tokens=max_tokens,
        num_logprobs=num_logprobs,
        mm_limit=2,
        tensor_parallel_size=1,
    )