# test_llava_next.py (4.44 KB)
# "Newer Older": navigation residue from the web view this file was copied from.
1
2
import re
from typing import List, Optional, Tuple
3
4
5
6

import pytest
from transformers import AutoTokenizer

7
8
from vllm.multimodal.utils import rescale_image_size
from vllm.sequence import SampleLogprobs
9

10
from ..conftest import IMAGE_ASSETS
11
from .utils import check_logprobs_close
12

13
# Apply the ``vlm`` pytest marker to every test in this module.
pytestmark = pytest.mark.vlm

# Chat-style preface prepended to every prompt below.
_PREFACE = (
    "A chat between a curious human and an artificial intelligence assistant. "
    "The assistant gives helpful, detailed, and polite answers to the human's "
    "questions.")

# Prompts keyed by image-asset name; each contains exactly one ``<image>``
# placeholder. ``test_models`` zips the result with ``image_assets``, so
# presumably ``IMAGE_ASSETS.prompts`` orders them to match -- TODO confirm.
HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
    "stop_sign":
    f"{_PREFACE} USER: <image>\nWhat's the content of the image? ASSISTANT:",
    "cherry_blossom":
    f"{_PREFACE} USER: <image>\nWhat is the season? ASSISTANT:",
    "boardwalk":
    f"{_PREFACE} USER: <image>\nWhat's in this image? ASSISTANT:",
})

# Token id treated as the image placeholder when sanitizing vLLM output
# (presumably the id of ``<image>`` in this model's vocabulary -- TODO confirm).
IMAGE_TOKEN_ID = 32000
30
31


32
33
def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
                                         Optional[SampleLogprobs]],
                      model: str):
    """Sanitize vllm output to be comparable with hf output.

    Args:
        vllm_output: ``(output_ids, output_str, out_logprobs)`` tuple as
            returned by the vLLM runner.
        model: Model name/path used to load the matching tokenizer.

    Returns:
        ``(hf_output_ids, hf_output_str, out_logprobs)`` where:

        * each run of consecutive ``IMAGE_TOKEN_ID`` tokens in the ids is
          collapsed to its first token;
        * every run of the decoded image token is removed from the text,
          the (required) leading space is stripped, and the decoded EOS
          token is appended when the ids end with EOS;
        * ``out_logprobs`` is passed through unchanged.

    Raises:
        AssertionError: if the sanitized text does not start with a space.
    """
    output_ids, output_str, out_logprobs = vllm_output

    tokenizer = AutoTokenizer.from_pretrained(model)
    image_token_str = tokenizer.decode(IMAGE_TOKEN_ID)
    eos_token_id = tokenizer.eos_token_id

    # Keep only the first token of each run of IMAGE_TOKEN_IDs. ``idx == 0``
    # is checked explicitly: otherwise ``output_ids[-1]`` (the LAST token)
    # would be consulted for the first position and a leading image token
    # could be dropped by mistake.
    hf_output_ids = [
        token_id for idx, token_id in enumerate(output_ids)
        if token_id != IMAGE_TOKEN_ID or idx == 0
        or output_ids[idx - 1] != IMAGE_TOKEN_ID
    ]

    # Strip every run of the decoded image token from the text; escape it so
    # any regex metacharacters in the token string are matched literally.
    hf_output_str = re.sub(fr"(?:{re.escape(image_token_str)})+", "",
                           output_str)
    assert hf_output_str.startswith(" "), (
        f"expected sanitized vLLM output to start with a space, "
        f"got {hf_output_str!r}")
    hf_output_str = hf_output_str[1:]
    if hf_output_ids and hf_output_ids[-1] == eos_token_id:
        hf_output_str = hf_output_str + tokenizer.decode(eos_token_id)

    return hf_output_ids, hf_output_str, out_logprobs
54
55


56
@pytest.mark.parametrize("model", ["llava-hf/llava-v1.6-vicuna-7b-hf"])
@pytest.mark.parametrize(
    "size_factors",
    [
        # No image
        [],
        # Single-scale
        [1.0],
        # Single-scale, batched
        [1.0, 1.0, 1.0],
        # Multi-scale
        [0.25, 0.5, 1.0],
    ],
)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
                dtype, max_tokens, num_logprobs) -> None:
    """Inference result should be the same between hf and vllm.

    All the image fixtures for the test are under tests/images.

    Each image asset is paired with its prompt from ``HF_IMAGE_PROMPTS`` and
    rescaled by every factor in ``size_factors``. The resulting batch (the
    same prompt repeated once per factor, plus one rescaled PIL image per
    factor) is passed to both the HF and the vLLM runner via ``images=``.
    An empty ``size_factors`` produces empty batches.

    The vLLM outputs are sanitized with ``vllm_to_hf_output`` and their
    greedy logprobs are then compared against HF's with
    ``check_logprobs_close``.
    """
    images = [asset.pil_image for asset in image_assets]

    # One (prompts, images) batch per source image.
    inputs_per_image = [(
        [prompt for _ in size_factors],
        [rescale_image_size(image, factor) for factor in size_factors],
    ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]

    # max_model_len should be greater than image_feature_size
    with vllm_runner(model,
                     dtype=dtype,
                     max_model_len=4096,
                     enforce_eager=True) as vllm_model:
        vllm_outputs_per_image = [
            vllm_model.generate_greedy_logprobs(prompts,
                                                max_tokens,
                                                num_logprobs=num_logprobs,
                                                images=batch_images)
            for prompts, batch_images in inputs_per_image
        ]

    with hf_runner(model, dtype=dtype, is_vision_model=True) as hf_model:
        hf_outputs_per_image = [
            hf_model.generate_greedy_logprobs_limit(prompts,
                                                    max_tokens,
                                                    num_logprobs=num_logprobs,
                                                    images=batch_images)
            for prompts, batch_images in inputs_per_image
        ]

    for hf_outputs, vllm_outputs in zip(hf_outputs_per_image,
                                        vllm_outputs_per_image):
        # TODO: Check whether using original CLIPVisionModel can improve
        # consistency against HF
        check_logprobs_close(
            outputs_0_lst=hf_outputs,
            outputs_1_lst=[
                vllm_to_hf_output(vllm_output, model)
                for vllm_output in vllm_outputs
            ],
            name_0="hf",
            name_1="vllm",
        )