test_utils.py 6.87 KB
Newer Older
1
2
import base64
import mimetypes
3
4
import os
from tempfile import NamedTemporaryFile, TemporaryDirectory
5
6
7
8
from typing import Dict, Tuple

import numpy as np
import pytest
9
from PIL import Image, ImageChops
10
from transformers import AutoConfig, AutoTokenizer
11

12
from vllm.multimodal.utils import (MediaConnector,
13
                                   repeat_and_pad_placeholder_tokens)
14
15
16
17
18
19
20
21
22
23

# Test different image extensions (JPG/PNG) and formats (gray/RGB/RGBA)
# These remote URLs are what the image-fetching tests in this module are
# parametrized over, and what the `url_images` fixture fetches, so running
# those tests presumably requires network access to upload.wikimedia.org.
TEST_IMAGE_URLS = [
    "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg",
    "https://upload.wikimedia.org/wikipedia/commons/f/fa/Grayscale_8bits_palette_sample_image.png",
    "https://upload.wikimedia.org/wikipedia/commons/thumb/9/91/Venn_diagram_rgb.svg/1280px-Venn_diagram_rgb.svg.png",
    "https://upload.wikimedia.org/wikipedia/commons/0/0b/RGBA_comp.png",
]


24
25
@pytest.fixture(scope="module")
def url_images() -> Dict[str, Image.Image]:
    """Fetch every entry of ``TEST_IMAGE_URLS`` once per test module.

    Returns:
        A mapping from each image URL to the image that
        ``MediaConnector.fetch_image`` returned for it.
    """
    media_connector = MediaConnector()

    fetched: Dict[str, Image.Image] = {}
    for url in TEST_IMAGE_URLS:
        fetched[url] = media_connector.fetch_image(url)
    return fetched
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47


def get_supported_suffixes() -> Tuple[str, ...]:
    """Return the image file suffixes that the base64 tests run against.

    Returns:
        The suffixes mentioned in GPT-4 with Vision, followed by the
        additional suffixes supported by us.
    """
    # We should at least test the file types mentioned in GPT-4 with Vision
    openai_suffixes = ('.png', '.jpeg', '.jpg', '.webp', '.gif')
    # Additional file types that are supported by us
    extra_suffixes = ('.bmp', '.tiff')

    return (*openai_suffixes, *extra_suffixes)


def _image_equals(a: Image.Image, b: Image.Image) -> bool:
    """Return whether two images contain identical pixel data.

    Args:
        a: Reference image; its mode is the one the comparison is done in.
        b: Image to compare; it is converted to ``a.mode`` first.

    Returns:
        ``True`` if both pixel arrays have the same shape and values,
        ``False`` otherwise.
    """
    # np.array_equal returns a plain bool and reports differently-shaped
    # arrays as unequal, whereas elementwise `==` followed by `.all()`
    # would broadcast compatible shapes or fail on incompatible ones.
    return np.array_equal(np.asarray(a), np.asarray(b.convert(a.mode)))


48
@pytest.mark.asyncio
@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
async def test_fetch_image_http(image_url: str):
    """Check that the sync and async fetch of one URL give equal images."""
    media_connector = MediaConnector()

    fetched_sync = media_connector.fetch_image(image_url)
    fetched_async = await media_connector.fetch_image_async(image_url)

    assert _image_equals(fetched_sync, fetched_async)


58
@pytest.mark.asyncio
@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
@pytest.mark.parametrize("suffix", get_supported_suffixes())
async def test_fetch_image_base64(url_images: Dict[str, Image.Image],
                                  image_url: str, suffix: str):
    """Check that base64 ``data:`` URLs load the same way sync and async.

    The already-fetched image for ``image_url`` is saved to a temporary
    file with the given ``suffix``, encoded as a base64 data URL, and then
    loaded back through both the sync and the async connector methods.

    Args:
        url_images: Fixture mapping each test URL to its fetched image.
        image_url: Key selecting the image under test.
        suffix: File suffix that determines the on-disk image format.
    """
    connector = MediaConnector()
    url_image = url_images[image_url]

    # Prefer the MIME type PIL associates with the suffix; fall back to the
    # standard library table, and skip when neither knows the suffix.
    try:
        mime_type = Image.MIME[Image.registered_extensions()[suffix]]
    except KeyError:
        try:
            mime_type = mimetypes.types_map[suffix]
        except KeyError:
            pytest.skip('No MIME type')

    with NamedTemporaryFile(suffix=suffix) as f:
        try:
            url_image.save(f.name)
        except Exception as e:
            # Check that args is non-empty first, so an exception raised
            # without arguments is re-raised as-is rather than being
            # masked by an IndexError.
            if e.args and e.args[0] == 'cannot write mode RGBA as JPEG':
                pytest.skip('Conversion not supported')

            raise

        base64_image = base64.b64encode(f.read()).decode("utf-8")
        data_url = f"data:{mime_type};base64,{base64_image}"

        data_image_sync = connector.fetch_image(data_url)
        # Only compare against the source image when the saved file
        # round-trips exactly; for a lossy format it is enough that the
        # image could be opened.
        if _image_equals(url_image, Image.open(f)):
            assert _image_equals(url_image, data_image_sync)

        data_image_async = await connector.fetch_image_async(data_url)
        assert _image_equals(data_image_sync, data_image_async)
94
95


96
97
98
@pytest.mark.asyncio
@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
async def test_fetch_image_local_files(image_url: str):
    """Check ``file://`` loading and the allowed-local-path restrictions.

    The image is fetched from ``image_url``, saved into a temporary
    directory, and read back through a connector whose
    ``allowed_local_media_path`` is that directory. Paths that leave the
    directory, and any local path on a connector created without an
    allowed path, are expected to raise.
    """
    remote_connector = MediaConnector()

    with TemporaryDirectory() as temp_dir:
        local_connector = MediaConnector(allowed_local_media_path=temp_dir)

        file_name = os.path.basename(image_url)
        inside_url = f"file://{temp_dir}/{file_name}"
        outside_url = f"file://{temp_dir}/../{file_name}"

        origin_image = remote_connector.fetch_image(image_url)
        origin_image.save(os.path.join(temp_dir, file_name),
                          quality=100,
                          icc_profile=origin_image.info.get('icc_profile'))

        image_async = await local_connector.fetch_image_async(inside_url)
        image_sync = local_connector.fetch_image(inside_url)
        # Check that the images are equal
        assert not ImageChops.difference(image_sync, image_async).getbbox()

        # Async API: both connectors reject the path outside temp_dir
        with pytest.raises(ValueError, match="must be a subpath"):
            await local_connector.fetch_image_async(outside_url)
        with pytest.raises(RuntimeError, match="Cannot load local files"):
            await remote_connector.fetch_image_async(outside_url)

        # Sync API: same expectations as the async API
        with pytest.raises(ValueError, match="must be a subpath"):
            local_connector.fetch_image(outside_url)
        with pytest.raises(RuntimeError, match="Cannot load local files"):
            remote_connector.fetch_image(outside_url)
129
130


131
132
133
134
135
136
137
138
@pytest.mark.parametrize("model", ["llava-hf/llava-v1.6-mistral-7b-hf"])
def test_repeat_and_pad_placeholder_tokens(model):
    """Check ``repeat_and_pad_placeholder_tokens`` on a table of prompts.

    For each case the returned prompt, token ids and placeholder ranges
    are compared with the expected values, using the tokenizer and the
    image token index loaded for ``model``.
    """
    config = AutoConfig.from_pretrained(model)
    image_token_id = config.image_token_index

    tokenizer = AutoTokenizer.from_pretrained(model)

    # Each case is: (prompt, repeat_count, expected_prompt,
    #                expected_token_ids, expected_ranges)
    cases = [
        (
            "<image>",
            2,
            "<image><image>",
            [32000, 32000],
            [{"offset": 0, "length": 2}],
        ),
        (
            "<image><image>",
            2,
            "<image><image><image>",
            [32000, 32000, 32000],
            [{"offset": 0, "length": 2}],
        ),
        (
            "<image><image>",
            [3, 2],
            "<image><image><image><image><image>",
            [32000, 32000, 32000, 32000, 32000],
            [{"offset": 0, "length": 3}, {"offset": 3, "length": 2}],
        ),
        (
            "Image:<image>Image:<image>!",
            [3, 2],
            "Image:<image><image><image>Image:<image><image>!",
            [9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918],
            [{"offset": 2, "length": 3}, {"offset": 7, "length": 2}],
        ),
        (
            "<image>",
            [3, 2],
            "<image><image><image>",
            [32000, 32000, 32000],
            [{"offset": 0, "length": 3}],
        ),
    ]  # yapf: disable

    for text, repeats, want_prompt, want_token_ids, want_ranges in cases:
        encoded = tokenizer.encode(text, add_special_tokens=False)
        got_prompt, got_token_ids, got_ranges = (
            repeat_and_pad_placeholder_tokens(
                tokenizer=tokenizer,
                prompt=text,
                prompt_token_ids=encoded,
                placeholder_token_id=image_token_id,
                repeat_count=repeats,
            ))
        assert got_prompt == want_prompt
        assert got_token_ids == want_token_ids
        assert got_ranges == want_ranges