# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import itertools
from functools import partial

7
8
import pytest
from PIL import Image
9
from pqdm.threads import pqdm
10

11
from vllm.multimodal import MULTIMODAL_REGISTRY
12
from vllm.multimodal.parse import ImageSize
13
from vllm.multimodal.processing import BaseMultiModalProcessor
14

15
from ...utils import build_model_context
16
17


18
19
20
21
22
23
24
def _validate_image_max_tokens_one(
    processor: BaseMultiModalProcessor,
    max_tokens: int,
    failed_size_excs: list[tuple[ImageSize, Exception]],
    image_size: ImageSize,
) -> None:
    """Check that one image size does not exceed the image-token budget.

    Nothing is raised from here: every failure is appended to
    ``failed_size_excs`` as an ``(image_size, exception)`` pair so the caller
    can report all failing sizes at once.

    Args:
        processor: Processor whose ``info.get_num_image_tokens`` is queried.
        max_tokens: Upper bound the computed token count must not exceed.
        failed_size_excs: Shared output list that collects failures.
        image_size: The size to check; its ``width`` and ``height`` are used.
    """
    try:
        # The token-count query lives inside the ``try`` on purpose: this
        # function is fanned out through pqdm and the caller discards pqdm's
        # return value, so an exception escaping here could go unreported
        # (pqdm's default exception behaviour is to return worker exceptions
        # instead of raising them -- see the pqdm docs).
        feature_size = processor.info.get_num_image_tokens(
            image_width=image_size.width,
            image_height=image_size.height,
        )
        assert feature_size <= max_tokens, f"{feature_size} <= {max_tokens}"
    except Exception as exc:
        failed_size_excs.append((image_size, exc))


35
36
37
@pytest.mark.skip(
    "This test takes around 5 minutes to run. Comment this out to run it manually."
)
@pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"])
def test_processor_max_tokens(model_id):
    """Check that no scanned image size exceeds the reported max image tokens.

    One size per distinct aspect ratio in [1, 2] is generated, with width and
    height taken from ``range(32, 4096)``. Each size is checked against
    ``processor.info.get_max_image_tokens()``. All failing sizes are gathered
    and reported together in a single ``AssertionError``.
    """
    ctx = build_model_context(
        model_id,
        mm_processor_kwargs=None,
        limit_mm_per_prompt={"image": 1},
    )
    processor = MULTIMODAL_REGISTRY.create_processor(ctx.renderer_config)
    info = processor.info

    seen_aspect_ratios = set[float]()
    image_sizes = list[ImageSize]()

    # The aspect ratio of the grid layout is between 1 and 2
    # NOTE: Assumes that feature size calculation is the same if we
    # swap the width and height of the image
    for w, h in itertools.product(range(32, 4096), repeat=2):
        aspect_ratio = w / h
        if 1 <= aspect_ratio <= 2 and aspect_ratio not in seen_aspect_ratios:
            image_sizes.append(ImageSize(w, h))
            seen_aspect_ratios.add(aspect_ratio)

    failed_size_excs = list[tuple[ImageSize, Exception]]()

    validate_one = partial(
        _validate_image_max_tokens_one,
        processor,
        info.get_max_image_tokens(),  # type: ignore
        failed_size_excs,
    )
    pqdm(image_sizes, validate_one, n_jobs=8, desc="Validating image sizes")

    if failed_size_excs:
        # Put the header on its own line so it is not fused with the first
        # failing entry.
        msg = "Found failing image sizes:\n" + "\n========\n".join(
            f"[{size}]\n{exc}" for size, exc in failed_size_excs
        )
        raise AssertionError(msg)


77
def _validate_image_prompt_replacements_one(
    processor: BaseMultiModalProcessor,
    num_imgs: int,
    failed_size_excs: list[tuple[ImageSize, Exception]],
    image_size: ImageSize,
) -> None:
    """Run the processor on ``num_imgs`` blank images of one size.

    The prompt is ``"<image>"`` repeated ``num_imgs`` times. Any exception
    raised while processing or checking the placeholders is appended to
    ``failed_size_excs`` as an ``(image_size, exception)`` pair instead of
    propagating.
    """
    prompt = "<image>" * num_imgs
    blank_image = Image.new("RGB", size=image_size)
    mm_data = {"image": [blank_image] * num_imgs}

    try:
        # The processor will throw an error if there is a mismatch
        # in the prompt replacements
        outputs = processor.apply(prompt, mm_data, {})

        placeholders = outputs["mm_placeholders"]["image"]
        assert len(placeholders) == num_imgs

        leading = placeholders[0]

        # NOTE: There is a BOS token, so the first placeholder starts at
        # index 1 and the BOS token is excluded from the length check.
        assert leading.offset == 1
        assert (
            leading.length
            == (len(outputs["prompt_token_ids"]) - 1) // num_imgs
        )
    except Exception as exc:
        failed_size_excs.append((image_size, exc))


def _test_image_prompt_replacements(
    processor: BaseMultiModalProcessor,
    *,
    num_imgs: int,
    image_sizes: list[ImageSize],
) -> None:
    """
    Ensure LlavaNextMultiModalProcessor
    handles prompt replacement properly for input images.

    Each entry of ``image_sizes`` is validated through pqdm with 8 jobs.
    Failures are collected per size and reported together.

    Raises:
        AssertionError: If at least one image size failed validation; the
            message lists every failing size with its exception.
    """
    failed_size_excs = list[tuple[ImageSize, Exception]]()

    validate_one = partial(
        _validate_image_prompt_replacements_one,
        processor,
        num_imgs,
        failed_size_excs,
    )
    pqdm(image_sizes, validate_one, n_jobs=8, desc="Validating image sizes")

    if failed_size_excs:
        # Put the header on its own line so it is not fused with the first
        # failing entry.
        msg = "Found failing image sizes:\n" + "\n========\n".join(
            f"[{size}]\n{exc}" for size, exc in failed_size_excs
        )
        raise AssertionError(msg)


135
136
@pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"])
@pytest.mark.parametrize("num_imgs", [1, 2])
def test_processor_prompt_replacements_regression(model_id, num_imgs):
    """Check prompt replacement on a fixed list of image sizes.

    Every listed (width, height) pair is tested in both orientations.
    """
    ctx = build_model_context(
        model_id,
        mm_processor_kwargs=None,
        limit_mm_per_prompt={"image": num_imgs},
    )
    processor = MULTIMODAL_REGISTRY.create_processor(ctx.renderer_config)

    dimension_pairs = [
        (171, 152),
        (184, 161),
        (198, 176),
        (333, 296),
        (369, 328),
        (488, 183),
        (2560, 1669),
    ]

    # For each pair, add the size as listed and then its transpose.
    image_sizes = []
    for first, second in dimension_pairs:
        image_sizes.append(ImageSize(first, second))
        image_sizes.append(ImageSize(second, first))

    _test_image_prompt_replacements(
        processor,
        num_imgs=num_imgs,
        image_sizes=image_sizes,
    )
163
164


165
166
167
@pytest.mark.skip(
    "This test takes around 2 hours to run. Comment this out to run it manually."
)
@pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"])
@pytest.mark.parametrize("num_imgs", [1])
def test_processor_prompt_replacements_all(model_id, num_imgs):
    """Check prompt replacement over a sweep of image sizes.

    One size is kept per distinct aspect ratio in [1, 2], with width and
    height taken from ``range(64, 1024)``.
    """
    ctx = build_model_context(
        model_id,
        mm_processor_kwargs=None,
        limit_mm_per_prompt={"image": num_imgs},
    )
    processor = MULTIMODAL_REGISTRY.create_processor(ctx.renderer_config)

    visited_ratios = set[float]()
    image_sizes = list[ImageSize]()

    # The aspect ratio of the grid layout is between 1 and 2
    # NOTE: Assumes that feature size calculation is the same if we
    # swap the width and height of the image
    for width in range(64, 1024):
        for height in range(64, 1024):
            ratio = width / height
            if ratio < 1 or ratio > 2:
                continue
            if ratio in visited_ratios:
                continue
            image_sizes.append(ImageSize(width, height))
            visited_ratios.add(ratio)

    _test_image_prompt_replacements(
        processor,
        num_imgs=num_imgs,
        image_sizes=image_sizes,
    )