test_llava.py 5.09 KB
Newer Older
1
from typing import List, Optional, Tuple, Type
2
3
4
5

import pytest
from transformers import AutoTokenizer

6
7
from vllm.multimodal.utils import rescale_image_size
from vllm.sequence import SampleLogprobs
8

9
from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets
10
from .utils import check_logprobs_close
11

12
# Apply the `vlm` pytest marker to every test collected from this module.
pytestmark = pytest.mark.vlm
13

14
15
# Prompt text per image asset key, handed to IMAGE_ASSETS.prompts().
# run_test() zips the result with the image assets, so prompts and images
# are paired by position.
HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts(
    {
        "stop_sign":
        "USER: <image>\nWhat's the content of the image?\nASSISTANT:",
        "cherry_blossom":
        "USER: <image>\nWhat is the season?\nASSISTANT:",
        "boardwalk":
        "USER: <image>\nWhat's in this image?\nASSISTANT:",
    }
)
22

23
# Token id that vllm_to_hf_output() treats as the image token: consecutive
# occurrences in vLLM's output ids are collapsed before comparing with HF.
IMAGE_TOKEN_ID = 32000
24

25
# Model identifiers that test_models() is parametrized over.
models = ["llava-hf/llava-1.5-7b-hf"]
26
27


28
29
def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
                                         Optional[SampleLogprobs]],
                      model: str):
    """Sanitize vllm output to be comparable with hf output.

    Args:
        vllm_output: ``(output_ids, output_str, out_logprobs)`` as produced
            by the vLLM runner for a single prompt.
        model: Model name or path used to load the tokenizer (needed for
            the EOS token id and its decoded text).

    Returns:
        A ``(hf_output_ids, hf_output_str, out_logprobs)`` tuple where
        every run of consecutive ``IMAGE_TOKEN_ID`` tokens has been reduced
        to its first token, the leading space has been stripped from the
        text, and the decoded EOS token is appended to the text when the
        ids end with EOS. The logprobs are passed through unchanged.

    Raises:
        AssertionError: if ``output_str`` does not start with a space.
    """
    output_ids, output_str, out_logprobs = vllm_output

    tokenizer = AutoTokenizer.from_pretrained(model)
    eos_token_id = tokenizer.eos_token_id

    # Keep a token unless it is an image token directly preceded by another
    # image token. The first token has no predecessor, so it is always kept;
    # without the explicit ``idx == 0`` check, ``output_ids[idx - 1]`` would
    # wrap around to the *last* element and could wrongly drop it.
    hf_output_ids = [
        token_id for idx, token_id in enumerate(output_ids)
        if (token_id != IMAGE_TOKEN_ID or idx == 0
            or output_ids[idx - 1] != IMAGE_TOKEN_ID)
    ]

    # startswith() also covers the empty string, turning what used to be an
    # IndexError into a readable assertion failure.
    assert output_str.startswith(" "), (
        f"expected vLLM output to start with a space, got {output_str!r}")
    hf_output_str = output_str[1:]
    if hf_output_ids[-1] == eos_token_id:
        hf_output_str = hf_output_str + tokenizer.decode(eos_token_id)

    return hf_output_ids, hf_output_str, out_logprobs
48
49


50
51
52
53
def run_test(
    hf_runner: Type[HfRunner],
    vllm_runner: Type[VllmRunner],
    image_assets: _ImageAssets,
    model: str,
    *,
    size_factors: List[float],
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
    tensor_parallel_size: int,
    distributed_executor_backend: Optional[str] = None,
):
    """Check that vLLM and HF give close greedy logprobs for ``model``.

    Each image asset is paired by position with a prompt from
    ``HF_IMAGE_PROMPTS`` and rescaled once per entry of ``size_factors``,
    producing one ``(prompts, images)`` batch per asset. Every batch is
    generated with the vLLM runner first and with the HF runner afterwards.
    The vLLM outputs are passed through ``vllm_to_hf_output`` and then
    compared against the HF outputs with ``check_logprobs_close``.

    An empty ``size_factors`` produces empty batches for every asset.
    """
    pil_images = [asset.pil_image for asset in image_assets]

    # One (prompts, images) batch per asset: the same prompt repeated for
    # each size factor, next to the image rescaled by that factor.
    batches = []
    for pil_image, prompt in zip(pil_images, HF_IMAGE_PROMPTS):
        repeated_prompts = [prompt] * len(size_factors)
        scaled_images = [
            rescale_image_size(pil_image, factor) for factor in size_factors
        ]
        batches.append((repeated_prompts, scaled_images))

    # NOTE: the order matters -- vLLM must run before HF.
    # vLLM needs a fresh process in which CUDA has not been initialized yet;
    # running HF first would initialize CUDA and hurt the multiprocessing
    # backend when it uses the fork start method (the default method).
    with vllm_runner(model,
                     dtype=dtype,
                     tensor_parallel_size=tensor_parallel_size,
                     distributed_executor_backend=distributed_executor_backend,
                     enforce_eager=True) as vllm_model:
        vllm_results = []
        for batch_prompts, batch_images in batches:
            vllm_results.append(
                vllm_model.generate_greedy_logprobs(
                    batch_prompts,
                    max_tokens,
                    num_logprobs=num_logprobs,
                    images=batch_images,
                ))

    with hf_runner(model, dtype=dtype, is_vision_model=True) as hf_model:
        hf_results = []
        for batch_prompts, batch_images in batches:
            hf_results.append(
                hf_model.generate_greedy_logprobs_limit(
                    batch_prompts,
                    max_tokens,
                    num_logprobs=num_logprobs,
                    images=batch_images,
                ))

    for hf_batch, vllm_batch in zip(hf_results, vllm_results):
        # TODO: Check whether using original CLIPVisionModel can improve
        # consistency against HF
        sanitized_vllm_batch = [
            vllm_to_hf_output(single_output, model)
            for single_output in vllm_batch
        ]
        check_logprobs_close(
            outputs_0_lst=hf_batch,
            outputs_1_lst=sanitized_vllm_batch,
            name_0="hf",
            name_1="vllm",
        )
120
121


122
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize(
    "size_factors",
    [
        [],  # no image
        [1.0],  # single scale
        [1.0, 1.0, 1.0],  # single scale, batched
        [0.25, 0.5, 1.0],  # multiple scales
    ],
)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
                dtype: str, max_tokens: int, num_logprobs: int) -> None:
    """Run the HF-vs-vLLM comparison with tensor_parallel_size fixed to 1.

    All parametrized values and fixtures are forwarded to ``run_test``
    unchanged.
    """
    run_test(
        hf_runner,
        vllm_runner,
        image_assets,
        model,
        tensor_parallel_size=1,
        size_factors=size_factors,
        dtype=dtype,
        max_tokens=max_tokens,
        num_logprobs=num_logprobs,
    )