test_internvl.py 7.24 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
3
4
5
6
7
8
9
10
"""Tests for InternVL's multimodal preprocessing kwargs."""
from typing import Callable, Optional

import pytest
from transformers import AutoTokenizer

from vllm.inputs import InputContext, token_inputs
from vllm.multimodal import MultiModalRegistry

11
12
from ....conftest import _ImageAssets
from ...utils import build_model_context
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207

models = ["OpenGVLab/InternVL2-2B"]


# Wrap lazy imports to avoid initializing CUDA during test collection
@pytest.fixture()
def input_processor_for_internvl():
    from vllm.model_executor.models.internvl import InternVLInputPipeline

    pipeline = InternVLInputPipeline('<img>', '</img>', '<IMG_CONTEXT>')
    return pipeline.input_processor


@pytest.fixture()
def dummy_data_for_internvl():
    from vllm.model_executor.models.internvl import InternVLInputPipeline

    pipeline = InternVLInputPipeline('<img>', '</img>', '<IMG_CONTEXT>')
    return pipeline.dummy_data


@pytest.fixture()
def get_max_internvl_image_tokens():
    from vllm.model_executor.models.internvl import (
        get_max_internvl_image_tokens)
    return get_max_internvl_image_tokens


@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize("max_dynamic_patch", [1, 4])
@pytest.mark.parametrize("dynamic_image_size", [True, False, None])
def test_input_mapper_override(
    model: str,
    image_assets: _ImageAssets,
    max_dynamic_patch: int,
    dynamic_image_size: Optional[bool],
):
    mm_processor_kwargs = {
        "max_dynamic_patch": max_dynamic_patch,
    }
    if dynamic_image_size is not None:
        mm_processor_kwargs["dynamic_image_size"] = dynamic_image_size

    expected_num_patches = max_dynamic_patch + 1 if max_dynamic_patch > 1 else 1
    if dynamic_image_size is False:
        expected_num_patches = 1

    ctx = build_model_context(
        model_name=model,
        tokenizer_name=model,
        trust_remote_code=True,
        mm_processor_kwargs=mm_processor_kwargs,
    )

    mm_registry = MultiModalRegistry()
    mm_registry.init_mm_limits_per_prompt(ctx.model_config)

    image = image_assets[0].pil_image.resize((448 * 2, 448 * 2))
    vllm_result = mm_registry.map_input(
        ctx.model_config,
        {"image": image},
    )
    assert vllm_result["pixel_values"].size(1) == expected_num_patches


@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize("max_dynamic_patch", [1, 4, None])
@pytest.mark.parametrize("dynamic_image_size", [True, False, None])
def test_max_tokens_override(
    get_max_internvl_image_tokens: Callable,
    model: str,
    max_dynamic_patch: Optional[int],
    dynamic_image_size: Optional[bool],
):
    """Ensure get_max_internvl_image_tokens handles mm_processor_kwargs."""
    ctx = build_model_context(
        model_name=model,
        tokenizer_name=model,
        trust_remote_code=True,
        mm_processor_kwargs=None,
    )

    if max_dynamic_patch is None:
        max_dynamic_patch = ctx.get_hf_config().max_dynamic_patch
    expected_num_patches = max_dynamic_patch + 1 if max_dynamic_patch > 1 else 1
    if dynamic_image_size is False:
        expected_num_patches = 1
    expected_max_tokens = 256 * expected_num_patches

    actual_max_tokens = get_max_internvl_image_tokens(
        ctx=InputContext(ctx.model_config),
        max_dynamic_patch=max_dynamic_patch,
        dynamic_image_size=dynamic_image_size,
    )
    assert expected_max_tokens == actual_max_tokens


@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize("num_imgs", [1, 2])
@pytest.mark.parametrize("max_dynamic_patch", [1, 4, None])
@pytest.mark.parametrize("dynamic_image_size", [True, False, None])
def test_dummy_data_override(
    dummy_data_for_internvl: Callable,
    model: str,
    num_imgs: int,
    max_dynamic_patch: Optional[int],
    dynamic_image_size: Optional[bool],
):
    """Ensure dummy_data_for_internvl handles kwargs properly."""
    # Same as the previous test - don't initialize mm_processor_kwargs
    # in this test and assume that the kwargs will be correctly expanded by
    # the partial when calling the dummy data func.
    ctx = build_model_context(
        model_name=model,
        tokenizer_name=model,
        trust_remote_code=True,
        mm_processor_kwargs=None,
    )

    if max_dynamic_patch is None:
        max_dynamic_patch = ctx.get_hf_config().max_dynamic_patch
    expected_num_patches = max_dynamic_patch + 1 if max_dynamic_patch > 1 else 1
    if dynamic_image_size is False:
        expected_num_patches = 1
    expected_max_tokens = 256 * expected_num_patches

    dummy_data = dummy_data_for_internvl(
        ctx=ctx,
        seq_len=8192,  # Should be bigger than num_imgs * toks_per_img
        mm_counts={"image": num_imgs},
        max_dynamic_patch=max_dynamic_patch,
        dynamic_image_size=dynamic_image_size,
    )
    sequence_data = dummy_data.seq_data

    tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True)
    image_token_id = tokenizer.encode('<IMG_CONTEXT>',
                                      add_special_tokens=False)[0]

    # Ensure we have the right number of placeholders per size
    img_tok_count = sequence_data.get_token_ids().count(image_token_id)
    assert img_tok_count == expected_max_tokens * num_imgs


@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize("max_dynamic_patch", [1, 4])
@pytest.mark.parametrize("dynamic_image_size", [True, False, None])
@pytest.mark.parametrize("num_imgs", [1, 2])
def test_input_processor_override(
    input_processor_for_internvl: Callable,
    image_assets: _ImageAssets,
    model: str,
    num_imgs: int,
    max_dynamic_patch: int,
    dynamic_image_size: Optional[bool],
):
    """Ensure input_processor_for_internvl handles kwargs properly."""
    # Same as the previous test - don't initialize mm_processor_kwargs
    # in this test and assume that the kwargs will be correctly expanded by
    # the partial when calling the custom input processor.
    expected_num_patches = max_dynamic_patch + 1 if max_dynamic_patch > 1 else 1
    if dynamic_image_size is False:
        expected_num_patches = 1

    ctx = build_model_context(
        model_name=model,
        tokenizer_name=model,
        trust_remote_code=True,
        mm_processor_kwargs=None,
    )
    expected_toks_per_img = 256 * expected_num_patches

    # Build the image str / prompt based on the number of images we pass
    tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True)
    placeholders = "<image>" if num_imgs == 1 else "\n".join(
        f"Image-{i}: <image>\n" for i in range(1, num_imgs + 1))
    prompt = placeholders
    images = [image_assets[0].pil_image.resize((448 * 2, 448 * 2))] * num_imgs

    inputs = token_inputs(prompt_token_ids=tokenizer.encode(prompt),
                          prompt=prompt,
                          multi_modal_data={"image": images})

    processed_inputs = input_processor_for_internvl(
        ctx,
        inputs,
        max_dynamic_patch=max_dynamic_patch,
        dynamic_image_size=dynamic_image_size,
    )

    # Ensure we have the right number of placeholders per num_crops size
    image_token_id = tokenizer.encode('<IMG_CONTEXT>',
                                      add_special_tokens=False)[0]
    img_tok_count = processed_inputs["prompt_token_ids"].count(image_token_id)
    assert img_tok_count == expected_toks_per_img * num_imgs