"""Compare vLLM against HuggingFace outputs for LLaVA-family
vision-language models using greedy decoding with logprobs."""
from typing import List, Optional, Tuple, Type

import pytest
from transformers import (AutoConfig, AutoModelForVision2Seq, AutoTokenizer,
                          BatchEncoding)

from vllm.multimodal.utils import rescale_image_size
from vllm.sequence import SampleLogprobs
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE

from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets
from .utils import check_logprobs_close

# Tag every test in this module with the ``vlm`` marker.
pytestmark = pytest.mark.vlm

# Chat-style prompts (USER/ASSISTANT turns) keyed by image asset name;
# ``<image>`` is the placeholder for where the image appears in the prompt.
HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts(dict(
    stop_sign="USER: <image>\nWhat's the content of the image?\nASSISTANT:",
    cherry_blossom="USER: <image>\nWhat is the season?\nASSISTANT:",
))

# Model names parametrized into ``test_models``.
models = [
    "llava-hf/llava-1.5-7b-hf",
    # TODO: Get this model to produce meaningful output in vLLM
    # "TIGER-Lab/Mantis-8B-siglip-llama3",
]


def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
                                         Optional[SampleLogprobs]],
                      model: str):
    """Sanitize vllm output to be comparable with hf output.

    Three adjustments are made:

    * every image token that directly follows another image token is
      dropped, so a run of consecutive image tokens collapses to its first
      token (the first id has no predecessor and is always kept);
    * the leading space of the decoded text is removed (its presence is
      asserted);
    * if generation ended on the EOS token, the decoded EOS text is appended
      to the string (presumably because the HF runner's decoded text keeps
      it — confirm against ``HfRunner``).

    Args:
        vllm_output: ``(output_ids, output_str, out_logprobs)`` for a single
            prompt, as produced by the vLLM runner's
            ``generate_greedy_logprobs``.
        model: HF model name used to look up the image token id (from the
            model config) and the EOS token (from the tokenizer).

    Returns:
        ``(hf_output_ids, hf_output_str, out_logprobs)``; the logprobs are
        passed through unchanged.
    """
    output_ids, output_str, out_logprobs = vllm_output

    config = AutoConfig.from_pretrained(model)
    image_token_id = config.image_token_index

    tokenizer = AutoTokenizer.from_pretrained(model)
    eos_token_id = tokenizer.eos_token_id

    # Pair each token with its predecessor (``None`` for the first token) so
    # the first position never compares against ``output_ids[-1]``, which
    # would wrap around to the last token.
    hf_output_ids = [
        token_id
        for prev_id, token_id in zip([None, *output_ids], output_ids)
        if not (token_id == image_token_id and prev_id == image_token_id)
    ]

    # ``startswith`` also fails cleanly (AssertionError, not IndexError)
    # when the decoded text is empty.
    assert output_str.startswith(" "), (
        f"expected vLLM output text to start with a space: {output_str!r}")
    hf_output_str = output_str[1:]

    # Guard against an empty id list before peeking at the last token.
    if hf_output_ids and hf_output_ids[-1] == eos_token_id:
        hf_output_str = hf_output_str + tokenizer.decode(eos_token_id)

    return hf_output_ids, hf_output_str, out_logprobs


def run_test(
    hf_runner: Type[HfRunner],
    vllm_runner: Type[VllmRunner],
    image_assets: _ImageAssets,
    model: str,
    *,
    size_factors: List[float],
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
    tensor_parallel_size: int,
    distributed_executor_backend: Optional[str] = None,
):
    """Inference result should be the same between hf and vllm.

    Each image asset is paired with its prompt from ``HF_IMAGE_PROMPTS``.
    For every asset, one batch is built that repeats the prompt once per
    entry in ``size_factors``, with the image rescaled by that factor
    (``rescale_image_size``). Each batch is run through vLLM and then HF
    with greedy decoding. The per-token top-``num_logprobs`` results are
    compared with ``check_logprobs_close``, after the vLLM side is mapped
    through ``vllm_to_hf_output``.

    Both runners receive the same text prompts and the same PIL images
    (passed via the ``images=`` argument).
    """
    # NOTE: For local use; this isn't tested in CI yet (see the TODO next to
    # the commented-out Mantis entry in ``models``).
    if model.startswith("TIGER-Lab/Mantis"):
        from mantis.models.mllava import MLlavaProcessor

        torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype]
        mantis_processor = MLlavaProcessor.from_pretrained(
            model, torch_dtype=torch_dtype)
        assert isinstance(mantis_processor, MLlavaProcessor)
    else:
        mantis_processor = None

    pil_images = [asset.pil_image for asset in image_assets]

    # One (prompts, images) batch per asset: the asset's prompt repeated once
    # per size factor, next to the image rescaled by each factor.
    inputs_per_image = [(
        [prompt for _ in size_factors],
        [rescale_image_size(image, factor) for factor in size_factors],
    ) for image, prompt in zip(pil_images, HF_IMAGE_PROMPTS)]

    # NOTE: take care of the order. run vLLM first, and then run HF.
    # vLLM needs a fresh new process without cuda initialization.
    # if we run HF first, the cuda initialization will be done and it
    # will hurt multiprocessing backend with fork method (the default method).
    with vllm_runner(model,
                     dtype=dtype,
                     tensor_parallel_size=tensor_parallel_size,
                     distributed_executor_backend=distributed_executor_backend,
                     enforce_eager=True) as vllm_model:
        vllm_outputs_per_image = [
            vllm_model.generate_greedy_logprobs(prompts,
                                                max_tokens,
                                                num_logprobs=num_logprobs,
                                                images=batch_images)
            for prompts, batch_images in inputs_per_image
        ]

    if mantis_processor is not None:

        def process(hf_inputs: BatchEncoding):
            # Cast pixel values to the Mantis dtype; presumably applied by
            # the HF runner to processor outputs before generation.
            hf_inputs["pixel_values"] = hf_inputs["pixel_values"] \
                .to(torch_dtype)  # type: ignore
            return hf_inputs
    else:

        def process(hf_inputs: BatchEncoding):
            # Default: pass the HF inputs through untouched.
            return hf_inputs

    with hf_runner(model,
                   dtype=dtype,
                   postprocess_inputs=process,
                   auto_cls=AutoModelForVision2Seq) as hf_model:
        hf_outputs_per_image = [
            hf_model.generate_greedy_logprobs_limit(prompts,
                                                    max_tokens,
                                                    num_logprobs=num_logprobs,
                                                    images=batch_images)
            for prompts, batch_images in inputs_per_image
        ]

    for hf_outputs, vllm_outputs in zip(hf_outputs_per_image,
                                        vllm_outputs_per_image):
        # TODO: Check whether using original CLIPVisionModel can improve
        # consistency against HF
        check_logprobs_close(
            outputs_0_lst=hf_outputs,
            outputs_1_lst=[
                vllm_to_hf_output(vllm_output, model)
                for vllm_output in vllm_outputs
            ],
            name_0="hf",
            name_1="vllm",
        )


@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize(
    "size_factors",
    [
        # No image: zero size factors, so every batch is empty
        [],
        # One image at its original size
        [1.0],
        # The original-size image, batched three times
        [1.0, 1.0, 1.0],
        # The same image at three different scales
        [0.25, 0.5, 1.0],
    ],
)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
                dtype: str, max_tokens: int, num_logprobs: int) -> None:
    """Check vLLM against HF for each model / size-factor combination,
    running on a single GPU (``tensor_parallel_size=1``)."""
    run_test(
        hf_runner,
        vllm_runner,
        image_assets,
        model=model,
        tensor_parallel_size=1,
        size_factors=size_factors,
        dtype=dtype,
        num_logprobs=num_logprobs,
        max_tokens=max_tokens,
    )