# test_internvl.py
1
# SPDX-License-Identifier: Apache-2.0
2
"""Tests for InternVL's multimodal preprocessing kwargs."""
3
from typing import Mapping, Optional
4
5

import pytest
6
7
from PIL import Image
from transformers import PretrainedConfig
8

9
from vllm.multimodal import MULTIMODAL_REGISTRY
10
11
12
from vllm.multimodal.image import rescale_image_size
from vllm.multimodal.processing import BaseMultiModalProcessor
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
13

14
15
from ....conftest import _ImageAssets
from ...utils import build_model_context
16
17


18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
def _get_expected_num_patches(
    config: PretrainedConfig,
    image: Image.Image,
    num_imgs: int,
    min_num: int,
    max_num: int,
):
    """Return the tile count the InternVL processor is expected to produce.

    The tile grid is computed by vLLM's own ``calculate_internvl_targets``
    for the image's (width, height), using the ratios returned by
    ``get_internvl_target_ratios(min_num, max_num)`` and the vision tower's
    ``image_size``. When ``config.use_thumbnail`` is truthy and the image
    was split into more than one tile, one extra thumbnail tile is counted.

    Args:
        config: HF config of the model; ``vision_config.image_size`` and
            ``use_thumbnail`` are read.
        image: The PIL image whose size determines the tiling.
        num_imgs: Currently unused; kept for call-site compatibility.
        min_num: Lower bound handed to ``get_internvl_target_ratios``.
        max_num: Upper bound handed to ``get_internvl_target_ratios``.
    """
    # Imported lazily so collecting this test module does not pull in the
    # model implementation.
    from vllm.model_executor.models.internvl import (
        calculate_internvl_targets, get_internvl_target_ratios)

    orig_width, orig_height = image.size
    ratios = get_internvl_target_ratios(min_num, max_num)

    # Thumbnail handling is done below rather than inside the helper, so it
    # is asked for the bare tile count here.
    num_blocks, _, _ = calculate_internvl_targets(
        orig_width=orig_width,
        orig_height=orig_height,
        target_ratios=ratios,
        image_size=config.vision_config.image_size,
        use_thumbnail=False,
    )

    # A thumbnail is only added on top of an image that was actually split.
    wants_thumbnail = config.use_thumbnail and num_blocks > 1
    return num_blocks + 1 if wants_thumbnail else num_blocks


def _run_check(
    processor: BaseMultiModalProcessor,
    images: list[Image.Image],
    min_num: int,
    max_num: int,
    mm_processor_kwargs: Mapping[str, object],
):
    """Run ``processor`` on ``images`` and verify tile / placeholder counts.

    The prompt holds one ``<image>`` per image. The expected total tile count
    is the sum of ``_get_expected_num_patches`` over all images. Two things
    are asserted:

    * the prompt token ids contain 256 ``<IMG_CONTEXT>`` tokens per tile;
    * the leading dimension of ``pixel_values_flat`` equals the tile count.

    ``mm_processor_kwargs`` is forwarded unchanged to ``processor.apply``.
    """
    tokenizer = processor.info.get_tokenizer()
    config = processor.info.get_hf_config()

    num_images = len(images)

    expected_tiles = 0
    for img in images:
        expected_tiles += _get_expected_num_patches(
            config, img, num_images, min_num, max_num)

    outputs = processor.apply("<image>" * num_images, {"image": images},
                              mm_processor_kwargs)

    # Count the placeholder tokens the processor expanded each <image> into.
    context_token_id = tokenizer.convert_tokens_to_ids("<IMG_CONTEXT>")
    num_context_tokens = outputs["prompt_token_ids"].count(context_token_id)
    flat_pixels_shape = outputs["mm_kwargs"]["pixel_values_flat"].shape

    # NOTE(review): assumed per-tile token count for the checkpoint under
    # test; confirm it still holds before parametrizing other models.
    tokens_per_tile = 256

    assert num_context_tokens == tokens_per_tile * expected_tiles
    assert flat_pixels_shape[0] == expected_tiles


@pytest.mark.parametrize("model_id", ["OpenGVLab/InternVL2-2B"])
@pytest.mark.parametrize(
    "size_factors",
    [
        # Single-scale
        [1.0],
        # Single-scale, batched
        [1.0, 1.0, 1.0],
        # Multi-scale
        [0.25, 0.5, 1.0],
        [4.0, 2.0, 1.0],
    ],
)
@pytest.mark.parametrize(
    ("min_dynamic_patch", "max_dynamic_patch"),
    [(1, 1), (1, 2), (1, 4), (1, 8), (2, 4), (4, 8)],
)
@pytest.mark.parametrize("dynamic_image_size", [True, False])
@pytest.mark.parametrize("kwargs_on_init", [True, False])
def test_processor_override(
    model_id: str,
    image_assets: _ImageAssets,
    size_factors: list[float],
    min_dynamic_patch: int,
    max_dynamic_patch: int,
    dynamic_image_size: Optional[bool],
    kwargs_on_init: bool,
):
    """Check that InternVL's dynamic-tiling kwargs are honoured.

    The kwargs ``min_dynamic_patch``, ``max_dynamic_patch`` and
    ``dynamic_image_size`` are supplied in one of two ways. When
    ``kwargs_on_init`` is true, they go into the model context. Otherwise
    they are passed to ``processor.apply``. In both cases the same expected
    tile and ``<IMG_CONTEXT>`` token counts are checked, using copies of the
    first image asset rescaled by each entry of ``size_factors``.
    """
    mm_processor_kwargs = {
        "min_dynamic_patch": min_dynamic_patch,
        "max_dynamic_patch": max_dynamic_patch,
        "dynamic_image_size": dynamic_image_size,
    }

    ctx = build_model_context(
        model_name=model_id,
        tokenizer_name=model_id,
        trust_remote_code=True,
        mm_processor_kwargs=mm_processor_kwargs if kwargs_on_init else None,
        limit_mm_per_prompt={"image": len(size_factors)},
    )
    tokenizer = cached_tokenizer_from_config(ctx.model_config)
    processor = MULTIMODAL_REGISTRY.create_processor(
        ctx.model_config,
        tokenizer=tokenizer,
    )
    # Kwargs given at init time are not repeated per call, so each route
    # is exercised on its own.
    hf_processor_mm_kwargs = {} if kwargs_on_init else mm_processor_kwargs

    # With dynamic tiling disabled the expectation is computed from a
    # (1, 1) range, regardless of the requested min/max patch counts.
    min_num = min_dynamic_patch if dynamic_image_size else 1
    max_num = max_dynamic_patch if dynamic_image_size else 1

    images = [
        rescale_image_size(image_assets[0].pil_image, factor)
        for factor in size_factors
    ]

    _run_check(
        processor,
        images,
        min_num,
        max_num,
        hf_processor_mm_kwargs,
    )