# test_chameleon.py
from typing import List, Optional, Type

import pytest
from transformers import AutoModelForVision2Seq, BatchEncoding

from vllm.multimodal.utils import rescale_image_size
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets
from .utils import check_outputs_equal

# Tag every test collected from this module with the ``vlm`` marker.
pytestmark = pytest.mark.vlm

# Per-asset prompts, keyed by image-asset name. ``run_test`` zips the result
# with the image fixtures, so its iteration order is presumably aligned with
# ``image_assets`` — confirm against ``IMAGE_ASSETS.prompts``.
HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts(
    dict(
        stop_sign=("USER: <image>\nWhat's the content of the image?\n"
                   "ASSISTANT:"),
        cherry_blossom="USER: <image>\nWhat is the season?\nASSISTANT:",
    ))

# Checkpoints exercised by ``test_models``.
models = ["facebook/chameleon-7b"]


def run_test(
    hf_runner: Type[HfRunner],
    vllm_runner: Type[VllmRunner],
    image_assets: _ImageAssets,
    model: str,
    *,
    size_factors: List[float],
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
    tensor_parallel_size: int,
    distributed_executor_backend: Optional[str] = None,
):
    """Check that HF and vLLM produce the same greedy output for ``model``.

    Each image in ``image_assets`` is paired with its prompt from
    ``HF_IMAGE_PROMPTS``. The image is rescaled by every factor in
    ``size_factors``, and the prompt is repeated once per factor. Each
    image's batch goes through vLLM first and then through HF. Both runners
    receive the same prompts and the same PIL images.

    Logprobs are requested from both runners but are not compared: the HF
    logprobs include the image tokens, while vLLM's do not. Only the first
    two fields of every output are checked for equality.

    Args:
        hf_runner: Factory for the HuggingFace reference runner.
        vllm_runner: Factory for the vLLM runner under test.
        image_assets: Image fixtures; ``.pil_image`` is read from each one.
        model: Model name passed to both runners.
        size_factors: Rescale factors applied to every image. An empty list
            gives each runner empty prompt lists, so nothing is generated.
        dtype: Key into ``STR_DTYPE_TO_TORCH_DTYPE``; also forwarded to
            both runners.
        max_tokens: Forwarded to both runners' greedy generation calls.
        num_logprobs: Forwarded to both runners' greedy generation calls.
        tensor_parallel_size: Forwarded to the vLLM runner.
        distributed_executor_backend: Optional backend name forwarded to
            the vLLM runner.
    """
    torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype]

    images = [asset.pil_image for asset in image_assets]

    # One (prompts, images) batch per source image: the prompt repeated once
    # per size factor, alongside the image rescaled by that factor.
    inputs_per_image = [(
        [prompt for _ in size_factors],
        [rescale_image_size(image, factor) for factor in size_factors],
    ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]

    # The vLLM context is closed before the HF model is created (presumably
    # so vLLM releases its resources first — confirm).
    with vllm_runner(model,
                     max_model_len=4096,
                     dtype=dtype,
                     tensor_parallel_size=tensor_parallel_size,
                     distributed_executor_backend=distributed_executor_backend,
                     enforce_eager=True) as vllm_model:
        vllm_outputs_per_image = [
            vllm_model.generate_greedy_logprobs(prompts,
                                                max_tokens,
                                                num_logprobs=num_logprobs,
                                                images=batch_images)
            for prompts, batch_images in inputs_per_image
        ]

    def process(hf_inputs: BatchEncoding):
        # Cast pixel values to the requested dtype; presumably the HF
        # processor emits a different dtype than the loaded model — confirm.
        hf_inputs["pixel_values"] = (
            hf_inputs["pixel_values"].to(torch_dtype))  # type: ignore
        return hf_inputs

    with hf_runner(model,
                   dtype=dtype,
                   postprocess_inputs=process,
                   auto_cls=AutoModelForVision2Seq) as hf_model:
        hf_outputs_per_image = [
            hf_model.generate_greedy_logprobs_limit(prompts,
                                                    max_tokens,
                                                    num_logprobs=num_logprobs,
                                                    images=batch_images)
            for prompts, batch_images in inputs_per_image
        ]

    for hf_outputs, vllm_outputs in zip(hf_outputs_per_image,
                                        vllm_outputs_per_image):
        # HF logprobs include image tokens, unlike vLLM's, so only the first
        # two fields of each output (presumably token ids and text) are
        # compared.
        check_outputs_equal(
            outputs_0_lst=[outputs[:2] for outputs in hf_outputs],
            outputs_1_lst=[outputs[:2] for outputs in vllm_outputs],
            name_0="hf",
            name_1="vllm",
        )


@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize(
    "size_factors",
    [
        # No image. NOTE(review): with no size factors, run_test builds empty
        # prompt lists for every image, so neither runner generates anything.
        [],
        # Single-scale
        [1.0],
        # Single-scale, batched
        [1.0, 1.0, 1.0],
        # Multi-scale
        [0.25, 0.5, 1.0],
    ],
)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [8])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
                dtype, max_tokens, num_logprobs) -> None:
    """Compare HF and vLLM outputs for every parametrized configuration.

    ``hf_runner``, ``vllm_runner`` and ``image_assets`` are pytest fixtures
    defined outside this file (presumably in ``conftest``). The vLLM runner
    is created with ``tensor_parallel_size=1``.
    """
    run_test(
        hf_runner,
        vllm_runner,
        image_assets,
        model,
        size_factors=size_factors,
        dtype=dtype,
        max_tokens=max_tokens,
        num_logprobs=num_logprobs,
        tensor_parallel_size=1,
    )