# test_phi3v.py
1
import os
2
import re
3
from typing import List, Optional, Tuple, Type
4
5
6
7

import pytest
from transformers import AutoTokenizer

8
9
from vllm.multimodal.utils import rescale_image_size
from vllm.sequence import SampleLogprobs
10
from vllm.utils import is_cpu, is_hip
11

12
from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets
13
from .utils import check_logprobs_close
14

15
# Every test in this module carries the `vlm` marker.
pytestmark = pytest.mark.vlm

# One prompt per image asset, keyed by asset name. Each prompt uses the
# <|user|> ... <|end|> <|assistant|> turn markup with an <|image_1|>
# placeholder for the asset's image.
HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts(
    dict(
        stop_sign=
        "<|user|>\n<|image_1|>\nWhat's the content of the image?<|end|>\n<|assistant|>\n",  # noqa: E501
        cherry_blossom=
        "<|user|>\n<|image_1|>\nWhat is the season?<|end|>\n<|assistant|>\n",
    ))

# Checkpoints exercised by `test_models`.
models = ["microsoft/Phi-3-vision-128k-instruct"]
25
26


27
28
def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
                                         Optional[SampleLogprobs]],
                      model: str):
    """Sanitize vllm output to be comparable with hf output.

    Args:
        vllm_output: ``(token_ids, output_str, logprobs)`` triple for one
            request. The token ids are ignored; ids are re-derived from the
            sanitized text below.
        model: Model name or path handed to ``AutoTokenizer.from_pretrained``
            to re-encode the sanitized text.

    Returns:
        ``(hf_output_ids, hf_output_str, out_logprobs)``, where:
        - ``hf_output_ids`` is the re-encoded text with its leading token
          (expected to be id 1) dropped.
        - ``hf_output_str`` is the sanitized text with
          ``"<|end|><|endoftext|>"`` appended.
        - ``out_logprobs`` is passed through unchanged.

    Raises:
        AssertionError: If the sanitized text does not start with a space
            (including when it is empty), or if the first encoded token id
            is not 1.
    """
    _, output_str, out_logprobs = vllm_output

    # Remove every run of image placeholder tokens (<|image_1|>, ...) from
    # the generated text.
    output_str_without_image = re.sub(r"(<\|image_\d+\|>)+", "", output_str)
    # The text is expected to begin with one space, which is stripped.
    # `startswith` (instead of indexing [0]) makes an empty generation fail
    # with a readable assertion rather than an IndexError.
    assert output_str_without_image.startswith(" "), (
        "expected vLLM output to start with a space, got "
        f"{output_str_without_image!r}")
    output_str_without_image = output_str_without_image[1:]

    # NOTE(review): presumably the HF runner's decoded text retains these
    # end-of-turn / end-of-text markers, so they are appended to match.
    hf_output_str = output_str_without_image + "<|end|><|endoftext|>"

    tokenizer = AutoTokenizer.from_pretrained(model)
    hf_output_ids = tokenizer.encode(output_str_without_image)
    # `encode` is expected to prepend a special token with id 1 (presumably
    # BOS); drop it so the ids line up with the generated tokens. Slicing
    # keeps the check safe even if the encoding comes back empty.
    assert hf_output_ids[:1] == [1], (
        f"expected first token id 1, got {hf_output_ids[:1]}")
    hf_output_ids = hf_output_ids[1:]

    return hf_output_ids, hf_output_str, out_logprobs
45
46
47
48
49
50


# Half precision everywhere except the CPU backend, which uses bfloat16.
target_dtype = "bfloat16" if is_cpu() else "half"

# On ROCm the Triton flash-attention path can run out of shared memory with
# these models; turn it off so another attention backend is used for now.
# FIXME (mattwong, gshtrasb, hongxiayan)
if is_hip():
    os.environ["VLLM_USE_TRITON_FLASH_ATTN"] = "0"

57

58
59
60
61
def run_test(
    hf_runner: Type[HfRunner],
    vllm_runner: Type[VllmRunner],
    image_assets: _ImageAssets,
    model: str,
    *,
    size_factors: List[float],
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
    tensor_parallel_size: int,
    distributed_executor_backend: Optional[str] = None,
):
    """Check that HF and vLLM agree on greedy outputs and top logprobs.

    For every image asset, one batch is built with one request per entry of
    ``size_factors``. Each request pairs the asset's prompt from
    ``HF_IMAGE_PROMPTS`` with the asset's PIL image rescaled by that factor
    via ``rescale_image_size``. The identical prompts and PIL images are fed
    to both runners. The vLLM outputs are converted with
    ``vllm_to_hf_output`` and then compared against the HF outputs with
    ``check_logprobs_close``.
    """
    pil_images = [asset.pil_image for asset in image_assets]

    # One (prompts, images) batch per asset; an empty `size_factors`
    # therefore yields empty batches.
    batches = []
    for pil_image, prompt in zip(pil_images, HF_IMAGE_PROMPTS):
        batch_prompts = [prompt] * len(size_factors)
        batch_images = [
            rescale_image_size(pil_image, factor) for factor in size_factors
        ]
        batches.append((batch_prompts, batch_images))

    # Ordering matters: vLLM runs first because it needs a fresh process in
    # which CUDA has not been initialized. Running HF first would initialize
    # CUDA and break the multiprocessing backend when it uses the default
    # fork start method.

    # max_model_len should be greater than image_feature_size
    with vllm_runner(model,
                     max_model_len=4096,
                     max_num_seqs=1,
                     dtype=dtype,
                     tensor_parallel_size=tensor_parallel_size,
                     distributed_executor_backend=distributed_executor_backend,
                     enforce_eager=True) as vllm_model:
        vllm_outputs_per_image = []
        for batch_prompts, batch_images in batches:
            vllm_outputs_per_image.append(
                vllm_model.generate_greedy_logprobs(batch_prompts,
                                                    max_tokens,
                                                    num_logprobs=num_logprobs,
                                                    images=batch_images))

    # The HF side runs with eager attention because phi3_v did not work
    # with flash_attn.
    with hf_runner(model,
                   dtype=dtype,
                   model_kwargs={"_attn_implementation": "eager"}) as hf_model:
        stop_token_id = hf_model.processor.tokenizer.eos_token_id
        hf_outputs_per_image = []
        for batch_prompts, batch_images in batches:
            hf_outputs_per_image.append(
                hf_model.generate_greedy_logprobs_limit(
                    batch_prompts,
                    max_tokens,
                    num_logprobs=num_logprobs,
                    images=batch_images,
                    eos_token_id=stop_token_id))

    for hf_outputs, vllm_outputs in zip(hf_outputs_per_image,
                                        vllm_outputs_per_image):
        converted = [vllm_to_hf_output(out, model) for out in vllm_outputs]
        check_logprobs_close(outputs_0_lst=hf_outputs,
                             outputs_1_lst=converted,
                             name_0="hf",
                             name_1="vllm")


# hf_runner is configured with _attn_implementation="eager", which widens
# the numerical gap against vLLM. The usual `logprobs=5` comparison does
# not pass, so num_logprobs=10 is used below.
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize(
    "size_factors",
    [
        # Empty batch: no requests are issued at all
        [],
        # One request at the original image size
        [1.0],
        # Three identical requests batched together
        [1.0, 1.0, 1.0],
        # Three requests at different image scales
        [0.25, 0.5, 1.0],
    ],
)
@pytest.mark.parametrize("dtype", [target_dtype])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [10])
def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
                dtype: str, max_tokens: int, num_logprobs: int) -> None:
    """Run the HF-vs-vLLM comparison for one parameter combination,
    with ``tensor_parallel_size=1``."""
    run_test(
        hf_runner=hf_runner,
        vllm_runner=vllm_runner,
        image_assets=image_assets,
        model=model,
        size_factors=size_factors,
        dtype=dtype,
        max_tokens=max_tokens,
        num_logprobs=num_logprobs,
        tensor_parallel_size=1,
    )