# test_llava.py
1
from typing import List, Optional, Tuple, Type, overload
2
3

import pytest
4
5
from transformers import (AutoConfig, AutoModelForVision2Seq, AutoTokenizer,
                          BatchEncoding)
6

7
8
from vllm.multimodal.utils import rescale_image_size
from vllm.sequence import SampleLogprobs
9
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
10

11
12
from ..conftest import (IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner,
                        _ImageAssets)
13
from .utils import check_logprobs_close
14

15
pytestmark = pytest.mark.vlm

# Maximum number of images vLLM accepts in a single prompt; forwarded to the
# vLLM runner as ``limit_mm_per_prompt={"image": ...}`` in ``_run_test``.
_LIMIT_IMAGE_PER_PROMPT = 4

# One HF-format prompt per image asset, keyed by asset name.
HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
    "stop_sign":
    "USER: <image>\nWhat's the content of the image?\nASSISTANT:",
    "cherry_blossom":
    "USER: <image>\nWhat is the season?\nASSISTANT:",
})

models = [
    "llava-hf/llava-1.5-7b-hf",
    # TODO: Get this model to produce meaningful output in vLLM
    # "TIGER-Lab/Mantis-8B-siglip-llama3",
]
31
32


33
34
def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
                                         Optional[SampleLogprobs]],
                      model: str):
    """Sanitize vllm output to be comparable with hf output.

    Three adjustments are made:

    * every run of consecutive image tokens in the vLLM token ids is collapsed
      to a single image token (presumably because vLLM expands the
      ``<image>`` placeholder while HF does not -- TODO confirm);
    * the decoded text must start with a space, which is stripped;
    * if generation ended on the EOS token, the decoded EOS text is appended
      to the string.

    Args:
        vllm_output: ``(output_ids, output_str, out_logprobs)`` as returned by
            the vLLM runner's ``generate_greedy_logprobs``.
        model: Model name used to load the config (for
            ``image_token_index``) and the tokenizer (for the EOS token).

    Returns:
        ``(hf_output_ids, hf_output_str, out_logprobs)``; the logprobs are
        passed through unchanged.
    """
    output_ids, output_str, out_logprobs = vllm_output

    config = AutoConfig.from_pretrained(model)
    image_token_id = config.image_token_index

    tokenizer = AutoTokenizer.from_pretrained(model)
    eos_token_id = tokenizer.eos_token_id

    # Keep a token unless it is an image token directly preceded by another
    # image token. The ``idx == 0`` check matters: without it ``idx - 1``
    # wraps to the last token and could wrongly drop a leading image token.
    hf_output_ids = [
        token_id for idx, token_id in enumerate(output_ids)
        if token_id != image_token_id or idx == 0
        or output_ids[idx - 1] != image_token_id
    ]

    assert output_str.startswith(" "), (
        f"Expected vLLM output text to start with a space: {output_str!r}")
    hf_output_str = output_str[1:]

    # Guard against an empty id list before peeking at the last token.
    if hf_output_ids and hf_output_ids[-1] == eos_token_id:
        hf_output_str = hf_output_str + tokenizer.decode(eos_token_id)

    return hf_output_ids, hf_output_str, out_logprobs
56
57


58
@overload
def run_test(
    hf_runner: Type[HfRunner],
    vllm_runner: Type[VllmRunner],
    image_assets: _ImageAssets,
    model: str,
    *,
    size_factors: List[float],
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
    tensor_parallel_size: int,
    distributed_executor_backend: Optional[str] = None,
):
    """Overload: rescale every image asset by each factor in
    ``size_factors``."""
    ...


@overload
def run_test(
    hf_runner: Type[HfRunner],
    vllm_runner: Type[VllmRunner],
    image_assets: _ImageAssets,
    model: str,
    *,
    sizes: List[Tuple[int, int]],
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
    tensor_parallel_size: int,
    distributed_executor_backend: Optional[str] = None,
):
    """Overload: resize every image asset to each size in ``sizes``
    (each size is passed to ``image.resize``)."""
    ...


def run_test(
    hf_runner: Type[HfRunner],
    vllm_runner: Type[VllmRunner],
    image_assets: _ImageAssets,
    model: str,
    *,
    size_factors: Optional[List[float]] = None,
    sizes: Optional[List[Tuple[int, int]]] = None,
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
    tensor_parallel_size: int,
    distributed_executor_backend: Optional[str] = None,
):
    """Build per-image inputs from ``image_assets`` and compare HF vs vLLM.

    Exactly one of ``size_factors`` or ``sizes`` must be given. Each image
    asset is paired with its prompt from ``HF_IMAGE_PROMPTS`` and turned into
    one batch containing a request per entry. With ``size_factors`` the image
    is rescaled via ``rescale_image_size``. With ``sizes`` it is resized via
    ``image.resize``. An empty list yields an empty batch for every asset.

    Raises:
        ValueError: if neither or both of ``size_factors`` and ``sizes``
            are provided.
    """
    # The overloads declare these as mutually exclusive; previously ``sizes``
    # was silently ignored when both were passed.
    if size_factors is not None and sizes is not None:
        raise ValueError(
            "Provide only one of `size_factors` or `sizes`, not both")

    images = [asset.pil_image for asset in image_assets]

    if size_factors is not None:
        inputs_per_image = [(
            [prompt for _ in size_factors],
            [rescale_image_size(image, factor) for factor in size_factors],
        ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]
    elif sizes is not None:
        inputs_per_image = [(
            [prompt for _ in sizes],
            [image.resize(size) for size in sizes],
        ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]
    else:
        raise ValueError("You must provide either `size_factors` or `sizes`")

    _run_test(hf_runner,
              vllm_runner,
              inputs_per_image,
              model,
              dtype=dtype,
              max_tokens=max_tokens,
              num_logprobs=num_logprobs,
              tensor_parallel_size=tensor_parallel_size,
              distributed_executor_backend=distributed_executor_backend)


def _run_test(
    hf_runner: Type[HfRunner],
    vllm_runner: Type[VllmRunner],
    inputs: List[Tuple[List[str], PromptImageInput]],
    model: str,
    *,
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
    tensor_parallel_size: int,
    distributed_executor_backend: Optional[str] = None,
):
    """Inference result should be the same between hf and vllm.

    Each ``(prompts, images)`` pair in ``inputs`` is generated greedily as one
    batch, first by vLLM and then by HF, both given the same prompts and
    images. Each vLLM output is passed through ``vllm_to_hf_output`` so that
    its token ids and text are comparable with HF's. The two sets of outputs
    are then compared batch by batch with ``check_logprobs_close``.
    """
    # NOTE: For local use; this isn't tested in CI yet (see the TODO on the
    # ``models`` list)
    if model.startswith("TIGER-Lab/Mantis"):
        from mantis.models.mllava import MLlavaProcessor

        torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype]
        mantis_processor = MLlavaProcessor.from_pretrained(
            model, torch_dtype=torch_dtype)
        assert isinstance(mantis_processor, MLlavaProcessor)
    else:
        mantis_processor = None

    # NOTE: take care of the order. run vLLM first, and then run HF.
    # vLLM needs a fresh new process without cuda initialization.
    # if we run HF first, the cuda initialization will be done and it
    # will hurt multiprocessing backend with fork method (the default method).

    # max_model_len should be greater than image_feature_size
    with vllm_runner(model,
                     dtype=dtype,
                     max_model_len=4096,
                     tensor_parallel_size=tensor_parallel_size,
                     distributed_executor_backend=distributed_executor_backend,
                     enforce_eager=True,
                     limit_mm_per_prompt={"image": _LIMIT_IMAGE_PER_PROMPT
                                          }) as vllm_model:
        vllm_outputs_per_image = [
            vllm_model.generate_greedy_logprobs(prompts,
                                                max_tokens,
                                                num_logprobs=num_logprobs,
                                                images=images)
            for prompts, images in inputs
        ]

    if mantis_processor is not None:

        def process(hf_inputs: BatchEncoding):
            # Mantis path: cast pixel values to the requested model dtype
            # (``torch_dtype`` is always bound on this branch).
            hf_inputs["pixel_values"] = hf_inputs["pixel_values"] \
                .to(torch_dtype)  # type: ignore
            return hf_inputs
    else:

        def process(hf_inputs: BatchEncoding):
            return hf_inputs

    with hf_runner(model,
                   dtype=dtype,
                   postprocess_inputs=process,
                   auto_cls=AutoModelForVision2Seq) as hf_model:
        hf_outputs_per_image = [
            hf_model.generate_greedy_logprobs_limit(prompts,
                                                    max_tokens,
                                                    num_logprobs=num_logprobs,
                                                    images=images)
            for prompts, images in inputs
        ]

    for hf_outputs, vllm_outputs in zip(hf_outputs_per_image,
                                        vllm_outputs_per_image):
        # TODO: Check whether using original CLIPVisionModel can improve
        # consistency against HF
        check_logprobs_close(
            outputs_0_lst=hf_outputs,
            outputs_1_lst=[
                vllm_to_hf_output(vllm_output, model)
                for vllm_output in vllm_outputs
            ],
            name_0="hf",
            name_1="vllm",
        )
222
223


224
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize(
    "size_factors",
    [
        # No image (an empty batch for every image asset)
        [],
        # Single-scale
        [1.0],
        # Single-scale, batched
        [1.0, 1.0, 1.0],
        # Multi-scale
        [0.25, 0.5, 1.0],
    ],
)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
                dtype: str, max_tokens: int, num_logprobs: int) -> None:
    """Compare HF and vLLM on each image asset rescaled by ``size_factors``."""
    run_test(
        hf_runner,
        vllm_runner,
        image_assets,
        model,
        size_factors=size_factors,
        dtype=dtype,
        max_tokens=max_tokens,
        num_logprobs=num_logprobs,
        tensor_parallel_size=1,
    )
254
255


256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models_multiple_image_inputs(hf_runner, vllm_runner, image_assets,
                                      model, dtype, max_tokens,
                                      num_logprobs) -> None:
    """Compare HF and vLLM on one batch whose prompts carry 2, 2, 4 and 1
    images, including images of differing sizes and aspect ratios."""
    stop_sign, cherry_blossom = (
        image_assets[0].pil_image,
        image_assets[1].pil_image,
    )

    two_image_prompt = "USER: <image><image>\nDescribe 2 images.\nASSISTANT:"
    four_image_prompt = (
        "USER: <image><image><image><image>\nDescribe 4 images.\nASSISTANT:")
    single_image_prompt = "USER: <image>\nWhat is the season?\nASSISTANT:"

    batch_prompts = [
        two_image_prompt,
        two_image_prompt,
        four_image_prompt,
        single_image_prompt,
    ]
    batch_images = [
        # Two distinct images at their original sizes
        [stop_sign, cherry_blossom],
        # Images with different sizes and aspect-ratios
        [rescale_image_size(stop_sign, 0.1), stop_sign],
        [
            stop_sign,
            rescale_image_size(stop_sign, 0.25),
            cherry_blossom.resize((183, 488)),
            cherry_blossom.resize((488, 183)),
        ],
        # A bare image rather than a list of images
        cherry_blossom,
    ]

    _run_test(
        hf_runner,
        vllm_runner,
        [(batch_prompts, batch_images)],
        model,
        dtype=dtype,
        max_tokens=max_tokens,
        num_logprobs=num_logprobs,
        tensor_parallel_size=1,
    )


301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
@pytest.mark.parametrize("model", models)
def test_context_length_too_short(vllm_runner, image_assets, model):
    """An image prompt must be rejected when ``max_model_len`` is too small.

    The runner is built with ``max_model_len=128``, and a ValueError matching
    "too long to fit into the model" is expected while creating or using it.
    """
    pil_images = [asset.pil_image for asset in image_assets]

    with pytest.raises(ValueError, match="too long to fit into the model"):
        runner = vllm_runner(
            model,
            max_model_len=128,  # LLaVA has a feature size of 576
            enforce_eager=True,
        )

        with runner:
            runner.generate_greedy(
                [HF_IMAGE_PROMPTS[0]],
                max_tokens=1,
                images=[pil_images[0]],
            )