# test_phi3v.py: HF vs. vLLM output comparison tests for Phi-3.5-vision.
1
import os
2
import re
3
from typing import List, Optional, Tuple, Type
4
5

import pytest
6
from PIL import Image
7
8
from transformers import AutoTokenizer

9
10
from vllm.multimodal.utils import rescale_image_size
from vllm.sequence import SampleLogprobs
11
from vllm.utils import is_cpu, is_hip
12

13
from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner
14
from .utils import check_logprobs_close
15

16
pytestmark = pytest.mark.vlm

# Chat-formatted prompts keyed by image asset name; each prompt references
# a single image via the "<|image_1|>" placeholder.
HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
    "stop_sign":
    "<|user|>\n<|image_1|>\nWhat's the content of the image?<|end|>\n<|assistant|>\n",  # noqa: E501
    "cherry_blossom":
    "<|user|>\n<|image_1|>\nWhat is the season?<|end|>\n<|assistant|>\n",
})

# Model checkpoints exercised by the tests below.
models = ["microsoft/Phi-3.5-vision-instruct"]


28
29
# Tokenizers keyed by model name, so each one is loaded from disk only once
# instead of once per compared output.
_TOKENIZER_CACHE = {}


def _get_tokenizer(model: str):
    """Return the HF tokenizer for ``model``, loading it on first use."""
    tokenizer = _TOKENIZER_CACHE.get(model)
    if tokenizer is None:
        tokenizer = AutoTokenizer.from_pretrained(model)
        _TOKENIZER_CACHE[model] = tokenizer
    return tokenizer


def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
                                         Optional[SampleLogprobs]],
                      model: str):
    """Sanitize vllm output to be comparable with hf output.

    Args:
        vllm_output: ``(token_ids, output_str, logprobs)`` from the vLLM
            runner. The token ids are discarded and re-derived from the text.
        model: Model name whose tokenizer is used to re-encode the text.

    Returns:
        ``(hf_output_ids, hf_output_str, out_logprobs)``:
        - the image placeholder tokens and the single leading space are
          removed from the text;
        - ``"<|end|><|endoftext|>"`` is appended to the string form;
        - the ids are the re-encoded text without the leading BOS token;
        - the logprobs are passed through unchanged.
    """
    _, output_str, out_logprobs = vllm_output

    # Remove runs of image placeholder tokens such as "<|image_1|>".
    output_str_without_image = re.sub(r"(<\|image_\d+\|>)+", "", output_str)
    # The vLLM text is expected to start with one space, which is dropped.
    # startswith() avoids an IndexError on empty output.
    assert output_str_without_image.startswith(" "), (
        f"expected a leading space in vLLM output: "
        f"{output_str_without_image!r}")
    output_str_without_image = output_str_without_image[1:]

    hf_output_str = output_str_without_image + "<|end|><|endoftext|>"

    tokenizer = _get_tokenizer(model)
    hf_output_ids = tokenizer.encode(output_str_without_image)
    # encode() is expected to prepend a BOS token with id 1. Strip it so the
    # ids line up with the generated tokens.
    assert hf_output_ids and hf_output_ids[0] == 1, (
        f"expected BOS id 1 at start of {hf_output_ids!r}")
    hf_output_ids = hf_output_ids[1:]

    return hf_output_ids, hf_output_str, out_logprobs


# Half precision by default; bfloat16 when running on the CPU backend.
target_dtype = "bfloat16" if is_cpu() else "half"

# ROCm Triton FA can run into shared memory issues with these models,
# use other backends in the meantime
# FIXME (mattwong, gshtrasb, hongxiayan)
if is_hip():
    # Turn off the Triton flash-attention path for vLLM on ROCm.
    os.environ.update(VLLM_USE_TRITON_FLASH_ATTN="0")

58

59
60
61
def run_test(
    hf_runner: Type[HfRunner],
    vllm_runner: Type[VllmRunner],
    images: List[Image.Image],
    model: str,
    *,
    size_factors: List[float],
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
    tensor_parallel_size: int,
    distributed_executor_backend: Optional[str] = None,
):
    """Inference result should be the same between hf and vllm.

    Batch construction:
        Each image is paired, in order, with a prompt from HF_IMAGE_PROMPTS.
        ``zip`` stops at the shorter sequence. For every pair, one batch is
        built with one copy of the prompt per entry in ``size_factors``. Each
        copy gets the image passed through ``rescale_image_size`` with that
        factor. An empty ``size_factors`` yields empty batches.

    Runners:
        The same prompts and rescaled images are given to both the vLLM
        runner and the HF runner via ``images=``.

    Comparison:
        The vLLM outputs are sanitized with ``vllm_to_hf_output``. They are
        then compared against the HF outputs with ``check_logprobs_close``.

    Args:
        hf_runner: Factory for the HuggingFace runner context manager.
        vllm_runner: Factory for the vLLM runner context manager.
        images: Source images, one per prompt in HF_IMAGE_PROMPTS.
        model: Model name passed to both runners.
        size_factors: Scale factors applied to each image to form a batch.
        dtype: Model dtype passed to both runners.
        max_tokens: Maximum number of tokens to generate.
        num_logprobs: Number of top logprobs to request and compare.
        tensor_parallel_size: Tensor-parallel degree for vLLM.
        distributed_executor_backend: Optional vLLM executor backend.
    """
    inputs_per_image = [(
        [prompt for _ in size_factors],
        [
            rescale_image_size(image, factor, transpose=idx)
            for idx, factor in enumerate(size_factors)
        ],
    ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]

    # NOTE: take care of the order. run vLLM first, and then run HF.
    # vLLM needs a fresh new process without cuda initialization.
    # if we run HF first, the cuda initialization will be done and it
    # will hurt multiprocessing backend with fork method (the default method).

    # max_model_len should be greater than image_feature_size
    with vllm_runner(model,
                     max_model_len=4096,
                     max_num_seqs=1,
                     dtype=dtype,
                     tensor_parallel_size=tensor_parallel_size,
                     distributed_executor_backend=distributed_executor_backend,
                     enforce_eager=True) as vllm_model:
        vllm_outputs_per_image = [
            vllm_model.generate_greedy_logprobs(prompts,
                                                max_tokens,
                                                num_logprobs=num_logprobs,
                                                images=batch_images)
            for prompts, batch_images in inputs_per_image
        ]

    # use eager mode for hf runner, since phi3_v didn't work with flash_attn
    hf_model_kwargs = {"_attn_implementation": "eager"}
    with hf_runner(model, dtype=dtype,
                   model_kwargs=hf_model_kwargs) as hf_model:
        eos_token_id = hf_model.processor.tokenizer.eos_token_id
        hf_outputs_per_image = [
            hf_model.generate_greedy_logprobs_limit(prompts,
                                                    max_tokens,
                                                    num_logprobs=num_logprobs,
                                                    images=batch_images,
                                                    eos_token_id=eos_token_id)
            for prompts, batch_images in inputs_per_image
        ]

    for hf_outputs, vllm_outputs in zip(hf_outputs_per_image,
                                        vllm_outputs_per_image):
        check_logprobs_close(
            outputs_0_lst=hf_outputs,
            outputs_1_lst=[
                vllm_to_hf_output(vllm_output, model)
                for vllm_output in vllm_outputs
            ],
            name_0="hf",
            name_1="vllm",
        )


# Since we use _attn_implementation="eager" for hf_runner, the numerical
# differences between HF and vLLM are larger: the basic `num_logprobs=5`
# fails to pass, so 10 logprobs are compared instead.
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize(
    "size_factors",
    [
        # Empty size_factors: no prompts or images are built per asset
        [],
        # Single-scale
        [1.0],
        # Single-scale, batched
        [1.0, 1.0, 1.0],
        # Multi-scale
        [0.25, 0.5, 1.0],
    ],
)
@pytest.mark.parametrize("dtype", [target_dtype])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [10])
def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
                dtype: str, max_tokens: int, num_logprobs: int) -> None:
    """Compare HF and vLLM greedy outputs on every image asset.

    Each asset's PIL image is expanded into a batch according to
    ``size_factors`` (see ``run_test``), on a single GPU.
    """
    run_test(
        hf_runner,
        vllm_runner,
        [asset.pil_image for asset in image_assets],
        model,
        size_factors=size_factors,
        dtype=dtype,
        max_tokens=max_tokens,
        num_logprobs=num_logprobs,
        tensor_parallel_size=1,
    )


@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize("dtype", [target_dtype])
def test_regression_7840(hf_runner, vllm_runner, image_assets, model,
                         dtype) -> None:
    """Regression test for #7840.

    Runs the HF-vs-vLLM comparison on the first image asset, resized to
    (465, 226), at a single scale with 128 tokens and 10 logprobs.
    """
    resized_image = image_assets[0].pil_image.resize((465, 226))
    run_test(
        hf_runner,
        vllm_runner,
        [resized_image],
        model,
        tensor_parallel_size=1,
        num_logprobs=10,
        max_tokens=128,
        dtype=dtype,
        size_factors=[1.0],
    )