# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""Tests for H2OVL's multimodal preprocessing kwargs."""
4

5
from collections.abc import Mapping
6
7

import pytest
8
9
from PIL import Image
from transformers import PretrainedConfig
10
11
12

from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.image import rescale_image_size
13
from vllm.multimodal.processing import BaseMultiModalProcessor
14

15
from ....conftest import ImageTestAssets
16
17
18
from ...utils import build_model_context


19
20
21
22
23
24
25
def _get_expected_num_patches(
    config: PretrainedConfig,
    image: Image.Image,
    num_imgs: int,
    min_num: int,
    max_num: int,
) -> int:
    """Predict how many image patches H2OVL should produce for ``image``.

    The count is derived from the same tiling helpers the model module
    exposes (``calculate_h2ovl_targets`` / ``get_h2ovl_target_ratios``), so
    the test can compare it against the processor's actual output.

    When the prompt contains exactly one image and ``config.use_msac`` is
    set, two tiling passes are combined: the second pass requests at least 3
    tiles and is guided by the aspect ratio chosen in the first pass.
    Otherwise a single pass bounded by ``min_num``/``max_num`` is used. In
    either case, one thumbnail patch is added to a pass whenever
    ``config.use_thumbnail`` is set and that pass yields more than one block.

    Args:
        config: HF config of the model. Reads ``use_msac``, ``use_thumbnail``
            and ``vision_config.image_size``.
        image: The input image. Only its ``size`` is used.
        num_imgs: Number of images in the same prompt.
        min_num: Lower bound passed to ``get_h2ovl_target_ratios`` on the
            single-pass path (the MSAC path uses fixed lower bounds 1 and 3).
        max_num: Upper bound passed to ``get_h2ovl_target_ratios`` on every
            pass.

    Returns:
        The expected number of patches for this image.
    """
    # Imported locally, presumably so the model module is only loaded when
    # the test actually runs.
    from vllm.model_executor.models.h2ovl import (
        calculate_h2ovl_targets,
        get_h2ovl_target_ratios,
    )

    width, height = image.size
    image_size = config.vision_config.image_size

    def _tile(min_blocks: int, prior_aspect_ratio):
        # Runs one tiling pass without a thumbnail. The thumbnail is counted
        # separately by _add_thumbnail() below.
        return calculate_h2ovl_targets(
            orig_width=width,
            orig_height=height,
            target_ratios=get_h2ovl_target_ratios(
                min_num=min_blocks,
                max_num=max_num,
                prior_aspect_ratio=prior_aspect_ratio,
            ),
            image_size=image_size,
            use_thumbnail=False,
        )

    def _add_thumbnail(num_blocks: int) -> int:
        # A thumbnail is only appended when the pass actually split the image.
        if config.use_thumbnail and num_blocks > 1:
            return num_blocks + 1
        return num_blocks

    if num_imgs == 1 and config.use_msac:
        # First pass: between 1 and max_num tiles, no prior aspect ratio.
        blocks1, _, _, aspect_ratio = _tile(1, None)
        # Second pass: at least 3 tiles, guided by the first pass's ratio.
        blocks2, _, _, _ = _tile(3, aspect_ratio)
        # Total is the sum of both passes minus one overlapping patch.
        # NOTE(review): keep this -1 in sync with the h2ovl MSAC
        # preprocessing implementation.
        return _add_thumbnail(blocks1) + _add_thumbnail(blocks2) - 1

    blocks, _, _, _ = _tile(min_num, None)
    return _add_thumbnail(blocks)


def _run_check(
    processor: BaseMultiModalProcessor,
    images: list[Image.Image],
    min_num: int,
    max_num: int,
    mm_processor_kwargs: Mapping[str, object],
) -> None:
    """Run ``processor`` on ``images`` and check the resulting patch counts.

    Builds a prompt with one ``<image>`` tag per image and processes it. It
    then asserts that the following both match the total patch count
    predicted by ``_get_expected_num_patches``:

    - the number of ``<IMG_CONTEXT>`` tokens in the prompt token ids, and
    - the leading dimension of ``pixel_values_flat``.

    Args:
        processor: The multimodal processor under test.
        images: Images placed together in a single prompt.
        min_num: Lower patch bound used to compute the expectation.
        max_num: Upper patch bound used to compute the expectation.
        mm_processor_kwargs: Forwarded to the processor as
            ``hf_processor_mm_kwargs``.
    """
    tokenizer = processor.info.get_tokenizer()
    config = processor.info.get_hf_config()

    prompt = "<image>" * len(images)
    mm_data = {"image": images}

    # The expectation for each image depends on how many images share the
    # prompt, so len(images) is passed for every image.
    total_expected_num_patches = sum(
        _get_expected_num_patches(config, image, len(images), min_num, max_num)
        for image in images
    )

    processed_inputs = processor(
        prompt,
        mm_items=processor.info.parse_mm_data(mm_data),
        hf_processor_mm_kwargs=mm_processor_kwargs,
    )

    # Ensure we have the right number of placeholders per patch.
    # NOTE(review): 256 is assumed to be the per-patch <IMG_CONTEXT> token
    # count for the checkpoints tested here. Confirm against the model config.
    image_token_id = tokenizer.convert_tokens_to_ids("<IMG_CONTEXT>")
    img_tok_count = processed_inputs["prompt_token_ids"].count(image_token_id)
    mm_kwargs_data = processed_inputs["mm_kwargs"].get_data()
    pixel_shape = mm_kwargs_data["pixel_values_flat"].shape

    assert img_tok_count == 256 * total_expected_num_patches, (
        f"expected {256 * total_expected_num_patches} image tokens "
        f"({total_expected_num_patches} patches), got {img_tok_count}"
    )
    assert pixel_shape[0] == total_expected_num_patches, (
        f"expected {total_expected_num_patches} patches in "
        f"pixel_values_flat, got shape {tuple(pixel_shape)}"
    )


124
125
126
127
128
129
130
@pytest.mark.parametrize(
    "model_id",
    [
        "h2oai/h2ovl-mississippi-800m",
        "h2oai/h2ovl-mississippi-2b",
    ],
)
@pytest.mark.parametrize(
    "size_factors",
    [
        # Single-scale
        [1.0],
        # Single-scale, batched
        [1.0, 1.0, 1.0],
        # Multi-scale, downscaled
        [0.25, 0.5, 1.0],
        # Multi-scale, upscaled
        [4.0, 2.0, 1.0],
    ],
)
@pytest.mark.parametrize(
    ("min_dynamic_patch", "max_dynamic_patch"),
    [(1, 1), (1, 2), (1, 4), (1, 8), (2, 4), (4, 8)],
)
@pytest.mark.parametrize("dynamic_image_size", [True, False])
@pytest.mark.parametrize("kwargs_on_init", [True, False])
def test_processor_override(
    model_id: str,
    image_assets: ImageTestAssets,
    size_factors: list[float],
    min_dynamic_patch: int,
    max_dynamic_patch: int,
    dynamic_image_size: bool | None,
    kwargs_on_init: bool,
):
    """Check that H2OVL honors dynamic-patch kwargs.

    The same kwargs are supplied either at model-context construction
    (``kwargs_on_init=True``) or per processor call (``kwargs_on_init=False``).
    In both cases the processor output must match the expected patch count.
    """
    mm_processor_kwargs = {
        "min_dynamic_patch": min_dynamic_patch,
        "max_dynamic_patch": max_dynamic_patch,
        "dynamic_image_size": dynamic_image_size,
    }

    ctx = build_model_context(
        model_id,
        mm_processor_kwargs=mm_processor_kwargs if kwargs_on_init else None,
        limit_mm_per_prompt={"image": len(size_factors)},
    )
    processor = MULTIMODAL_REGISTRY.create_processor(ctx.model_config)
    # Kwargs are passed exactly once: either at init above or per call here.
    hf_processor_mm_kwargs = {} if kwargs_on_init else mm_processor_kwargs

    # With dynamic sizing disabled, the expectation uses 1 for both bounds.
    min_num = min_dynamic_patch if dynamic_image_size else 1
    max_num = max_dynamic_patch if dynamic_image_size else 1

    _run_check(
        processor,
        [rescale_image_size(image_assets[0].pil_image, f) for f in size_factors],
        min_num,
        max_num,
        hf_processor_mm_kwargs,
    )