# test_llava_next.py
1
from typing import List, Optional, Tuple, Type, overload
2
3

import pytest
4
from transformers import AutoConfig, AutoModelForVision2Seq, AutoTokenizer
5

6
7
from vllm.multimodal.utils import rescale_image_size
from vllm.sequence import SampleLogprobs
8

9
10
from ..conftest import (IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner,
                        _ImageAssets)
11
from .utils import check_logprobs_close
12

13
# Apply the `vlm` marker to every test collected from this module.
pytestmark = pytest.mark.vlm

# Value passed to the vLLM runner as `limit_mm_per_prompt={"image": ...}`;
# the multi-image test in this file sends at most this many images in one
# prompt.
_LIMIT_IMAGE_PER_PROMPT = 4

# One prompt per image asset, keyed by asset name, written with the
# `[INST] ... [/INST]` template and a single `<image>` placeholder.
# `run_test` zips this with the image assets, so the order presumably
# follows the asset order -- confirm against `IMAGE_ASSETS.prompts`.
HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
    "stop_sign":
    "[INST] <image>\nWhat's the content of the image? [/INST]",
    "cherry_blossom":
    "[INST] <image>\nWhat is the season? [/INST]",
})

# Model identifiers used to parametrize every test below.
models = ["llava-hf/llava-v1.6-mistral-7b-hf"]
25

26

27
28
def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
                                         Optional[SampleLogprobs]],
                      model: str):
    """Sanitize vllm output to be comparable with hf output.

    Args:
        vllm_output: ``(token_ids, text, logprobs)`` for one vLLM generation.
        model: Model name or path used to load the config (for the image
            token id) and the tokenizer (for the EOS token).

    Returns:
        ``(hf_output_ids, hf_output_str, out_logprobs)`` where:

        * every run of consecutive image tokens in the ids is collapsed to
          its first token;
        * the single leading space of the text is removed, and the decoded
          EOS token is appended when the ids end with EOS;
        * the logprobs are passed through unchanged.

    Raises:
        AssertionError: If the vLLM text does not start with a space.
    """
    output_ids, output_str, out_logprobs = vllm_output

    config = AutoConfig.from_pretrained(model)
    image_token_id = config.image_token_index

    tokenizer = AutoTokenizer.from_pretrained(model)
    eos_token_id = tokenizer.eos_token_id

    # Keep a token unless it is an image token directly preceded by another
    # image token. The explicit `idx == 0` guard matters: without it,
    # `output_ids[idx - 1]` would wrap around to the *last* element and could
    # wrongly drop a leading image token.
    hf_output_ids = [
        token_id for idx, token_id in enumerate(output_ids)
        if token_id != image_token_id or idx == 0
        or output_ids[idx - 1] != image_token_id
    ]

    # `startswith` (rather than indexing) turns an empty string into a clean
    # assertion failure instead of an IndexError.
    assert output_str.startswith(" ")
    hf_output_str = output_str[1:]

    # Guard against an empty id list before looking at the last token.
    if hf_output_ids and hf_output_ids[-1] == eos_token_id:
        hf_output_str = hf_output_str + tokenizer.decode(eos_token_id)

    return hf_output_ids, hf_output_str, out_logprobs
50
51


52
@overload
def run_test(
    hf_runner: Type[HfRunner],
    vllm_runner: Type[VllmRunner],
    image_assets: _ImageAssets,
    model: str,
    *,
    size_factors: List[float],
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
    tensor_parallel_size: int,
    distributed_executor_backend: Optional[str] = None,
):
    """Typing overload: rescale every asset image by each of `size_factors`."""
    ...


@overload
def run_test(
    hf_runner: Type[HfRunner],
    vllm_runner: Type[VllmRunner],
    image_assets: _ImageAssets,
    model: str,
    *,
    sizes: List[Tuple[int, int]],
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
    tensor_parallel_size: int,
    distributed_executor_backend: Optional[str] = None,
):
    """Typing overload: resize every asset image to each entry of `sizes`."""
    ...


def run_test(
    hf_runner: Type[HfRunner],
    vllm_runner: Type[VllmRunner],
    image_assets: _ImageAssets,
    model: str,
    *,
    size_factors: Optional[List[float]] = None,
    sizes: Optional[List[Tuple[int, int]]] = None,
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
    tensor_parallel_size: int,
    distributed_executor_backend: Optional[str] = None,
):
    """Build per-image inputs from the assets and compare HF against vLLM.

    For each image asset (paired with its entry in ``HF_IMAGE_PROMPTS``) one
    batch is built that repeats the prompt once per requested variant of the
    image. The batches are then handed to ``_run_test``.

    Args:
        hf_runner: Runner class for the HuggingFace reference model.
        vllm_runner: Runner class for the vLLM model.
        image_assets: Image assets; each must expose ``pil_image``.
        model: Model name or path.
        size_factors: Scale factors applied to each image with
            ``rescale_image_size``. An empty list yields empty batches.
            Takes precedence over ``sizes`` when both are given.
        sizes: Target sizes passed to ``image.resize``; used only when
            ``size_factors`` is ``None``.
        dtype: Model dtype forwarded to both runners.
        max_tokens: Maximum number of tokens to generate.
        num_logprobs: Number of logprobs requested per token.
        tensor_parallel_size: Forwarded to the vLLM runner.
        distributed_executor_backend: Forwarded to the vLLM runner.

    Raises:
        ValueError: If neither ``size_factors`` nor ``sizes`` is provided.
    """
    images = [asset.pil_image for asset in image_assets]

    if size_factors is not None:
        inputs_per_image = [(
            [prompt for _ in size_factors],
            [rescale_image_size(image, factor) for factor in size_factors],
        ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]
    elif sizes is not None:
        inputs_per_image = [(
            [prompt for _ in sizes],
            [image.resize(size) for size in sizes],
        ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]
    else:
        raise ValueError("You must provide either `size_factors` or `sizes`")

    _run_test(hf_runner,
              vllm_runner,
              inputs_per_image,
              model,
              dtype=dtype,
              max_tokens=max_tokens,
              num_logprobs=num_logprobs,
              tensor_parallel_size=tensor_parallel_size,
              distributed_executor_backend=distributed_executor_backend)


def _run_test(
    hf_runner: Type[HfRunner],
    vllm_runner: Type[VllmRunner],
    inputs: List[Tuple[List[str], PromptImageInput]],
    model: str,
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
    tensor_parallel_size: int,
    distributed_executor_backend: Optional[str] = None,
):
    """Generate with vLLM and HF on the same inputs and compare logprobs.

    Args:
        hf_runner: Runner class for the HuggingFace reference model.
        vllm_runner: Runner class for the vLLM model.
        inputs: Batches of ``(prompts, images)``; each batch is generated
            with one call per runner.
        model: Model name or path.
        dtype: Model dtype forwarded to both runners.
        max_tokens: Maximum number of tokens to generate.
        num_logprobs: Number of logprobs requested per token.
        tensor_parallel_size: Forwarded to the vLLM runner.
        distributed_executor_backend: Forwarded to the vLLM runner.
    """
    # The vLLM runner context is fully exited before the HF runner is
    # created, so the two models are never held open at the same time here.

    # max_model_len should be greater than image_feature_size
    with vllm_runner(
            model,
            dtype=dtype,
            max_model_len=10240,
            tensor_parallel_size=tensor_parallel_size,
            distributed_executor_backend=distributed_executor_backend,
            enforce_eager=True,
            limit_mm_per_prompt={"image": _LIMIT_IMAGE_PER_PROMPT},
    ) as vllm_model:
        vllm_outputs_per_image = [
            vllm_model.generate_greedy_logprobs(prompts,
                                                max_tokens,
                                                num_logprobs=num_logprobs,
                                                images=images)
            for prompts, images in inputs
        ]

    with hf_runner(model, dtype=dtype,
                   auto_cls=AutoModelForVision2Seq) as hf_model:
        hf_outputs_per_image = [
            hf_model.generate_greedy_logprobs_limit(prompts,
                                                    max_tokens,
                                                    num_logprobs=num_logprobs,
                                                    images=images)
            for prompts, images in inputs
        ]

    for hf_outputs, vllm_outputs in zip(hf_outputs_per_image,
                                        vllm_outputs_per_image):
        # TODO: Check whether using original CLIPVisionModel can improve
        # consistency against HF
        check_logprobs_close(
            outputs_0_lst=hf_outputs,
            # vLLM outputs are first converted to the HF output format.
            outputs_1_lst=[
                vllm_to_hf_output(vllm_output, model)
                for vllm_output in vllm_outputs
            ],
            name_0="hf",
            name_1="vllm",
        )
177
178


179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize(
    "size_factors",
    [
        # No image
        [],
        # Single-scale
        [1.0],
        # Single-scale, batched
        [1.0, 1.0, 1.0],
        # Multi-scale
        [0.25, 0.5, 1.0],
    ],
)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
                dtype, max_tokens, num_logprobs) -> None:
    """Inference result should be the same between hf and vllm.

    All the image fixtures for the test are under tests/images.
    For huggingface runner, we provide the PIL images as input.
    For vllm runner, we provide MultiModalDataDict objects
    and corresponding MultiModalConfig as input.
    Note, the text input is also adjusted to abide by vllm contract.
    The text output is sanitized to be able to compare with hf.

    Each asset image is rescaled by every factor in ``size_factors``; the
    run uses ``tensor_parallel_size=1``.
    """
    run_test(
        hf_runner,
        vllm_runner,
        image_assets,
        model,
        size_factors=size_factors,
        dtype=dtype,
        max_tokens=max_tokens,
        num_logprobs=num_logprobs,
        tensor_parallel_size=1,
    )


220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize(
    "sizes",
    [[(1669, 2560), (2560, 1669), (183, 488), (488, 183)]],
)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models_fixed_sizes(hf_runner, vllm_runner, image_assets, model, sizes,
                            dtype, max_tokens, num_logprobs) -> None:
    """Compare hf and vllm on asset images resized to fixed sizes.

    Each asset image is resized to every entry of ``sizes`` (both
    orientations of a large and of a small non-square size) and the
    resulting batches are checked with ``tensor_parallel_size=1``.
    """
    run_test(
        hf_runner,
        vllm_runner,
        image_assets,
        model,
        sizes=sizes,
        dtype=dtype,
        max_tokens=max_tokens,
        num_logprobs=num_logprobs,
        tensor_parallel_size=1,
    )
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285


@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models_multiple_image_inputs(hf_runner, vllm_runner, image_assets,
                                      model, dtype, max_tokens,
                                      num_logprobs) -> None:
    """Compare hf and vllm on a batch whose prompts carry 1, 2 or 4 images.

    A single batch of four prompts is built: two prompts with two images,
    one with four images and one with a single (un-nested) image.
    """
    # Assumes the first two assets are the stop sign and the cherry blossom,
    # in that order.
    stop_sign = image_assets[0].pil_image
    cherry_blossom = image_assets[1].pil_image

    two_image_prompt = "[INST] <image><image>\nDescribe 2 images. [/INST]"
    prompts = [
        two_image_prompt,
        two_image_prompt,
        "[INST] <image><image><image><image>\nDescribe 4 images. [/INST]",
        "[INST] <image>\nWhat is the season? [/INST]",
    ]

    # One entry per prompt, in the same order as `prompts`.
    same_size_pair = [stop_sign, cherry_blossom]
    # Images with different sizes and aspect-ratios
    mixed_size_pair = [rescale_image_size(stop_sign, 0.1), stop_sign]
    four_images = [
        stop_sign,
        rescale_image_size(stop_sign, 0.25),
        cherry_blossom.resize((183, 488)),
        cherry_blossom.resize((488, 183)),
    ]
    images_per_prompt = [
        same_size_pair,
        mixed_size_pair,
        four_images,
        cherry_blossom,
    ]

    batch = (prompts, images_per_prompt)

    _run_test(
        hf_runner,
        vllm_runner,
        [batch],
        model,
        dtype=dtype,
        max_tokens=max_tokens,
        num_logprobs=num_logprobs,
        tensor_parallel_size=1,
    )