# test_h2ovl.py 5.42 KB
# Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0
"""Tests for H2OVL's multimodal preprocessing kwargs."""
3
4
from collections.abc import Mapping
from typing import Optional
5
6

import pytest
7
8
from PIL import Image
from transformers import PretrainedConfig
9
10
11

from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.image import rescale_image_size
12
13
from vllm.multimodal.processing import BaseMultiModalProcessor
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
14
15
16
17
18

from ....conftest import _ImageAssets
from ...utils import build_model_context


19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
def _get_expected_num_patches(
    config: PretrainedConfig,
    image: Image.Image,
    num_imgs: int,
    min_num: int,
    max_num: int,
):
    """Compute how many patches the H2OVL processor should emit for `image`.

    Args:
        config: HF config; `use_msac`, `use_thumbnail` and
            `vision_config.image_size` are read from it.
        image: Image whose `size` (width, height) drives the tiling.
        num_imgs: Number of images in the prompt; the two-pass (MSAC)
            computation is only used when this is exactly 1.
        min_num: Lower bound passed to the target-ratio helper on the
            single-pass path.
        max_num: Upper bound passed to the target-ratio helper on every path.

    Returns:
        The expected patch count, including a thumbnail patch where
        `config.use_thumbnail` is set and more than one block is produced.
    """
    from vllm.model_executor.models.h2ovl import (calculate_h2ovl_targets,
                                                  get_h2ovl_target_ratios)

    img_w, img_h = image.size

    def _count_blocks(ratios):
        # The thumbnail is never requested here; callers add it themselves.
        n_blocks, _, _, ratio = calculate_h2ovl_targets(
            orig_width=img_w,
            orig_height=img_h,
            target_ratios=ratios,
            image_size=config.vision_config.image_size,
            use_thumbnail=False,
        )
        return n_blocks, ratio

    if num_imgs == 1 and config.use_msac:
        # Pass 1: minimum of one block, no prior aspect ratio.
        first_pass, first_ratio = _count_blocks(
            get_h2ovl_target_ratios(
                min_num=1,
                max_num=max_num,
                prior_aspect_ratio=None,
            ))

        # Pass 2: minimum of three blocks, seeded with the aspect ratio
        # returned by pass 1.
        second_pass, _ = _count_blocks(
            get_h2ovl_target_ratios(
                min_num=3,
                max_num=max_num,
                prior_aspect_ratio=first_ratio,
            ))

        # Each pass gains a thumbnail only when it yields more than one block.
        if config.use_thumbnail:
            first_pass += int(first_pass > 1)
            second_pass += int(second_pass > 1)

        # Sum of both passes minus one for the overlap between them
        # (presumably the thumbnail shared by the passes -- TODO confirm).
        return first_pass + second_pass - 1

    # Single pass using the caller-supplied bounds.
    single_pass, _ = _count_blocks(
        get_h2ovl_target_ratios(
            min_num,
            max_num,
            prior_aspect_ratio=None,
        ))

    wants_thumbnail = config.use_thumbnail and single_pass > 1
    return single_pass + 1 if wants_thumbnail else single_pass


def _run_check(
    processor: BaseMultiModalProcessor,
    images: list[Image.Image],
    min_num: int,
    max_num: int,
    mm_processor_kwargs: Mapping[str, object],
):
    """Run `processor` on `images` and verify the resulting patch counts.

    The expected total is the sum of `_get_expected_num_patches` over all
    images. Two things are asserted on the processed output: the prompt
    contains 256 `<IMG_CONTEXT>` tokens per expected patch, and the first
    dimension of `pixel_values_flat` equals the expected patch total.

    Args:
        processor: Multimodal processor under test.
        images: Images placed in the prompt, one `<image>` tag each.
        min_num: Lower patch bound used for the expectation.
        max_num: Upper patch bound used for the expectation.
        mm_processor_kwargs: Per-call kwargs forwarded to `processor.apply`.
    """
    tokenizer = processor.info.get_tokenizer()
    hf_config = processor.info.get_hf_config()

    num_images = len(images)

    expected_patches = 0
    for img in images:
        expected_patches += _get_expected_num_patches(hf_config, img,
                                                      num_images, min_num,
                                                      max_num)

    prompt = "<image>" * num_images
    outputs = processor.apply(prompt, {"image": images}, mm_processor_kwargs)

    # Count the image-context placeholders that ended up in the prompt.
    context_token_id = tokenizer.convert_tokens_to_ids("<IMG_CONTEXT>")
    num_context_tokens = outputs["prompt_token_ids"].count(context_token_id)
    flat_shape = outputs["mm_kwargs"]["pixel_values_flat"].shape

    assert num_context_tokens == 256 * expected_patches
    assert flat_shape[0] == expected_patches

117
118
119
120
121
122
123
124
125
126
127
128
129
@pytest.mark.parametrize("model_id", [
    "h2oai/h2ovl-mississippi-800m",
    "h2oai/h2ovl-mississippi-2b",
])
@pytest.mark.parametrize(
    "size_factors",
    [
        # Single-scale
        [1.0],
        # Single-scale, batched
        [1.0, 1.0, 1.0],
        # Multi-scale
        [0.25, 0.5, 1.0],
        [4.0, 2.0, 1.0],
    ],
)
@pytest.mark.parametrize(
    ("min_dynamic_patch", "max_dynamic_patch"),
    [(1, 1), (1, 2), (1, 4), (1, 8), (2, 4), (4, 8)],
)
@pytest.mark.parametrize("dynamic_image_size", [True, False])
@pytest.mark.parametrize("kwargs_on_init", [True, False])
def test_processor_override(
    model_id: str,
    image_assets: _ImageAssets,
    size_factors: list[float],
    min_dynamic_patch: int,
    max_dynamic_patch: int,
    dynamic_image_size: Optional[bool],
    kwargs_on_init: bool,
):
    """Check that H2OVL preprocessing kwargs take effect.

    The same kwargs are supplied either when the model context is built
    (`kwargs_on_init=True`) or per call to the processor
    (`kwargs_on_init=False`); in both cases the processed output must match
    the patch counts computed by `_run_check`.

    Args:
        model_id: Model repository to load.
        image_assets: Fixture providing test images; only the first is used.
        size_factors: Scale factors applied to the first image asset, one
            rescaled image per factor.
        min_dynamic_patch: Value for the `min_dynamic_patch` kwarg.
        max_dynamic_patch: Value for the `max_dynamic_patch` kwarg.
        dynamic_image_size: Value for the `dynamic_image_size` kwarg.
        kwargs_on_init: Whether the kwargs go to the model context instead
            of the per-call processor kwargs.
    """
    mm_processor_kwargs = {
        "min_dynamic_patch": min_dynamic_patch,
        "max_dynamic_patch": max_dynamic_patch,
        "dynamic_image_size": dynamic_image_size,
    }

    ctx = build_model_context(
        model_name=model_id,
        tokenizer_name=model_id,
        trust_remote_code=True,
        mm_processor_kwargs=mm_processor_kwargs if kwargs_on_init else None,
        limit_mm_per_prompt={"image": len(size_factors)},
    )
    tokenizer = cached_tokenizer_from_config(ctx.model_config)
    processor = MULTIMODAL_REGISTRY.create_processor(
        ctx.model_config,
        tokenizer=tokenizer,
    )
    # Kwargs already given at init time must not be repeated per call.
    hf_processor_mm_kwargs = {} if kwargs_on_init else mm_processor_kwargs

    # With dynamic image sizing off, the expectation uses bounds of 1.
    min_num = min_dynamic_patch if dynamic_image_size else 1
    max_num = max_dynamic_patch if dynamic_image_size else 1

    _run_check(
        processor,
        [
            rescale_image_size(image_assets[0].pil_image, f)
            for f in size_factors
        ],
        min_num,
        max_num,
        hf_processor_mm_kwargs,
    )