test_llava_next.py 6.7 KB
Newer Older
1
from typing import List, Optional, Tuple, Type, overload
2
3

import pytest
4
from transformers import AutoTokenizer
5

6
7
from vllm.multimodal.utils import rescale_image_size
from vllm.sequence import SampleLogprobs
8

9
from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets
10
from .utils import check_logprobs_close
11

12
# Apply the ``vlm`` marker to every test in this module.
pytestmark = pytest.mark.vlm

# Chat preface prepended to every prompt below.
_PREFACE = (
    "A chat between a curious human and an artificial intelligence assistant. "
    "The assistant gives helpful, detailed, and polite answers to the human's "
    "questions.")

# Prompt text keyed by image-asset name; each prompt has one ``<image>`` slot.
_PROMPTS_BY_ASSET = {
    "stop_sign":
    f"{_PREFACE} USER: <image>\nWhat's the content of the image? ASSISTANT:",
    "cherry_blossom":
    f"{_PREFACE} USER: <image>\nWhat is the season? ASSISTANT:",
}
HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts(_PROMPTS_BY_ASSET)

# Token id collapsed by ``vllm_to_hf_output``; presumably the id of the
# ``<image>`` placeholder in this model's tokenizer -- TODO confirm.
IMAGE_TOKEN_ID = 32000

# Checkpoints exercised by the parametrized tests.
models = ["llava-hf/llava-v1.6-vicuna-7b-hf"]

30

31
32
def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
                                         Optional[SampleLogprobs]],
                      model: str):
    """Sanitize vllm output to be comparable with hf output.

    Args:
        vllm_output: ``(output_ids, output_str, out_logprobs)`` triple from
            the vLLM runner.
        model: Model name/path; its tokenizer supplies the EOS token.

    Returns:
        ``(hf_output_ids, hf_output_str, out_logprobs)`` where:

        * runs of consecutive ``IMAGE_TOKEN_ID`` are collapsed to a single
          token (NOTE(review): presumably vLLM repeats the image placeholder
          where HF keeps one -- confirm),
        * the single leading space of ``output_str`` is removed,
        * the decoded EOS token is appended to the text when the ids end
          with EOS,
        * ``out_logprobs`` is passed through unchanged.

    Raises:
        AssertionError: If ``output_str`` does not start with a space.
    """
    output_ids, output_str, out_logprobs = vllm_output

    tokenizer = AutoTokenizer.from_pretrained(model)
    eos_token_id = tokenizer.eos_token_id

    # Drop an image token only when it directly follows another image token.
    # The first token has no predecessor and is always kept; indexing
    # ``output_ids[idx - 1]`` at idx == 0 would wrap around to the LAST token.
    hf_output_ids = [
        token_id for idx, token_id in enumerate(output_ids)
        if token_id != IMAGE_TOKEN_ID or idx == 0
        or output_ids[idx - 1] != IMAGE_TOKEN_ID
    ]

    # vLLM text is expected to carry one leading space; presumably absent
    # from the HF text, so it is stripped before comparison.
    assert output_str.startswith(" "), (
        f"Expected vLLM output to start with a space, got {output_str!r}")
    hf_output_str = output_str[1:]

    # Guard against an empty id list so the [-1] lookup cannot raise.
    if hf_output_ids and hf_output_ids[-1] == eos_token_id:
        hf_output_str = hf_output_str + tokenizer.decode(eos_token_id)

    return hf_output_ids, hf_output_str, out_logprobs
51
52


53
@overload
def run_test(
    hf_runner: Type[HfRunner],
    vllm_runner: Type[VllmRunner],
    image_assets: _ImageAssets,
    model: str,
    *,
    size_factors: List[float],
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
    tensor_parallel_size: int,
    distributed_executor_backend: Optional[str] = None,
):
    """Typing overload: each image is rescaled by every entry of
    ``size_factors``."""


@overload
def run_test(
    hf_runner: Type[HfRunner],
    vllm_runner: Type[VllmRunner],
    image_assets: _ImageAssets,
    model: str,
    *,
    sizes: List[Tuple[int, int]],
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
    tensor_parallel_size: int,
    distributed_executor_backend: Optional[str] = None,
):
    """Typing overload: each image is resized to every entry of ``sizes``."""


def run_test(
    hf_runner: Type[HfRunner],
    vllm_runner: Type[VllmRunner],
    image_assets: _ImageAssets,
    model: str,
    *,
    size_factors: Optional[List[float]] = None,
    sizes: Optional[List[Tuple[int, int]]] = None,
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
    tensor_parallel_size: int,
    distributed_executor_backend: Optional[str] = None,
):
    """Compare greedy generations and top logprobs between HF and vLLM.

    Each image asset is zipped with its prompt from ``HF_IMAGE_PROMPTS``
    (NOTE(review): assumes both are in the same asset order -- confirm) and
    expanded into one batch per image: one copy of the prompt per entry of
    ``size_factors`` (image rescaled with ``rescale_image_size``) or of
    ``sizes`` (image resized with PIL ``resize``). Both runners generate the
    same batches. The vLLM outputs are sanitized with ``vllm_to_hf_output``
    and then compared against the HF outputs with ``check_logprobs_close``.

    Exactly one of ``size_factors`` and ``sizes`` must be given.

    Raises:
        ValueError: If neither or both of ``size_factors`` and ``sizes``
            are provided.
    """
    # The overloads declare these as mutually exclusive; previously passing
    # both silently ignored ``sizes``.
    if size_factors is not None and sizes is not None:
        raise ValueError(
            "Provide only one of `size_factors` or `sizes`, not both")

    source_images = [asset.pil_image for asset in image_assets]

    if size_factors is not None:
        inputs_per_image = [(
            [prompt for _ in size_factors],
            [rescale_image_size(image, factor) for factor in size_factors],
        ) for image, prompt in zip(source_images, HF_IMAGE_PROMPTS)]
    elif sizes is not None:
        inputs_per_image = [(
            [prompt for _ in sizes],
            [image.resize(size) for size in sizes],
        ) for image, prompt in zip(source_images, HF_IMAGE_PROMPTS)]
    else:
        raise ValueError("You must provide either `size_factors` or `sizes`")

    # max_model_len should be greater than image_feature_size
    with vllm_runner(model,
                     dtype=dtype,
                     max_model_len=4096,
                     tensor_parallel_size=tensor_parallel_size,
                     distributed_executor_backend=distributed_executor_backend,
                     enforce_eager=True) as vllm_model:
        vllm_outputs_per_image = [
            vllm_model.generate_greedy_logprobs(prompts,
                                                max_tokens,
                                                num_logprobs=num_logprobs,
                                                images=batch_images)
            for prompts, batch_images in inputs_per_image
        ]

    with hf_runner(model, dtype=dtype, is_vision_model=True) as hf_model:
        hf_outputs_per_image = [
            hf_model.generate_greedy_logprobs_limit(prompts,
                                                    max_tokens,
                                                    num_logprobs=num_logprobs,
                                                    images=batch_images)
            for prompts, batch_images in inputs_per_image
        ]

    for hf_outputs, vllm_outputs in zip(hf_outputs_per_image,
                                        vllm_outputs_per_image):
        # TODO: Check whether using original CLIPVisionModel can improve
        # consistency against HF
        check_logprobs_close(
            outputs_0_lst=hf_outputs,
            outputs_1_lst=[
                vllm_to_hf_output(vllm_output, model)
                for vllm_output in vllm_outputs
            ],
            name_0="hf",
            name_1="vllm",
        )
153
154


155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize(
    "size_factors",
    [
        [],  # No image
        [1.0],  # Single-scale
        [1.0, 1.0, 1.0],  # Single-scale, batched
        [0.25, 0.5, 1.0],  # Multi-scale
    ],
)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
                dtype, max_tokens, num_logprobs) -> None:
    """HF and vLLM should produce matching outputs for rescaled images.

    For every image asset, a batch is built with one prompt per entry of
    ``size_factors``, each image rescaled by that factor (an empty list
    gives an empty batch). Greedy generations of up to ``max_tokens`` tokens
    and their top-``num_logprobs`` logprobs from both backends are compared
    after the vLLM output is sanitized into the HF format. Runs without
    tensor parallelism (``tensor_parallel_size=1``).
    """
    run_test(
        hf_runner=hf_runner,
        vllm_runner=vllm_runner,
        image_assets=image_assets,
        model=model,
        dtype=dtype,
        size_factors=size_factors,
        num_logprobs=num_logprobs,
        max_tokens=max_tokens,
        tensor_parallel_size=1,
    )


196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize(
    "sizes",
    [
        # Large and small sizes, each in both orientations.
        [(1669, 2560), (2560, 1669), (183, 488), (488, 183)],
    ],
)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models_fixed_sizes(hf_runner, vllm_runner, image_assets, model, sizes,
                            dtype, max_tokens, num_logprobs) -> None:
    """HF and vLLM should produce matching outputs for fixed-size images.

    Like ``test_models``, but every image is resized to each exact entry of
    ``sizes`` (passed to PIL ``Image.resize``, i.e. ``(width, height)``)
    instead of being rescaled by a factor.
    """
    run_test(
        hf_runner=hf_runner,
        vllm_runner=vllm_runner,
        image_assets=image_assets,
        model=model,
        dtype=dtype,
        sizes=sizes,
        num_logprobs=num_logprobs,
        max_tokens=max_tokens,
        tensor_parallel_size=1,
    )