test_utils.py 17.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import asyncio
5
6
import base64
import mimetypes
7
8
import os
from tempfile import NamedTemporaryFile, TemporaryDirectory
9
10
11

import numpy as np
import pytest
12
import torch
13
from PIL import Image, ImageChops
14

15
from vllm.multimodal.image import convert_image_mode
16
from vllm.multimodal.inputs import PlaceholderRange
17
18
from vllm.multimodal.media import MediaConnector
from vllm.multimodal.utils import argsort_mm_positions
19
20

# Test different image extensions (JPG/PNG) and formats (gray/RGB/RGBA).
# Each entry is an asset name; the trailing comment records the public URL
# the asset corresponds to.
TEST_IMAGE_ASSETS = [
    "2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg",  # "https://vllm-public-assets.s3.us-west-2.amazonaws.com/vision_model_images/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
    "Grayscale_8bits_palette_sample_image.png",  # "https://vllm-public-assets.s3.us-west-2.amazonaws.com/vision_model_images/Grayscale_8bits_palette_sample_image.png"
    "1280px-Venn_diagram_rgb.svg.png",  # "https://vllm-public-assets.s3.us-west-2.amazonaws.com/vision_model_images/1280px-Venn_diagram_rgb.svg.png"
    "RGBA_comp.png",  # "https://vllm-public-assets.s3.us-west-2.amazonaws.com/vision_model_images/RGBA_comp.png"
]

28
29
# Remote videos fetched over HTTP(S) by the video tests in this module.
TEST_VIDEO_URLS = [
    "https://www.bogotobogo.com/python/OpenCV_Python/images/mean_shift_tracking/slow_traffic_small.mp4",
    "https://github.com/opencv/opencv/raw/refs/tags/4.12.0/samples/data/vtest.avi",
]

33

34
@pytest.fixture(scope="module")
def url_images(local_asset_server) -> dict[str, Image.Image]:
    """Load every test image asset once per module.

    Returns:
        A mapping from each name in ``TEST_IMAGE_ASSETS`` to the object
        returned by ``local_asset_server.get_image_asset`` for that name.
    """
    return {
        asset_name: local_asset_server.get_image_asset(asset_name)
        for asset_name in TEST_IMAGE_ASSETS
    }
40
41


42
def get_supported_suffixes() -> tuple[str, ...]:
    """Return the image file suffixes exercised by the base64 fetch test.

    Returns:
        The OpenAI-listed suffixes followed by the extra suffixes supported
        by this project, each including the leading dot.
    """
    # We should at least test the file types mentioned in GPT-4 with Vision
    OPENAI_SUPPORTED_SUFFIXES = (".png", ".jpeg", ".jpg", ".webp", ".gif")

    # Additional file types that are supported by us
    EXTRA_SUPPORTED_SUFFIXES = (".bmp", ".tiff")

    return OPENAI_SUPPORTED_SUFFIXES + EXTRA_SUPPORTED_SUFFIXES


def _image_equals(a: Image.Image, b: Image.Image) -> bool:
    """Return whether two images hold identical pixel data.

    ``b`` is first converted to the mode of ``a`` so that the comparison is
    made on arrays of the same layout.

    ``np.array_equal`` is used rather than ``(x == y).all()`` so that images
    of different sizes compare as unequal instead of broadcasting or raising,
    and so that a plain ``bool`` is returned.
    """
    return np.array_equal(np.asarray(a), np.asarray(convert_image_mode(b, a.mode)))
54
55


56
@pytest.mark.asyncio
@pytest.mark.parametrize("image_url", TEST_IMAGE_ASSETS, indirect=True)
async def test_fetch_image_http(image_url: str):
    """Sync and async image fetches of the same URL must yield equal images.

    ``image_url`` is parametrized indirectly, so each asset name is resolved
    by an ``image_url`` fixture that is defined outside this module's view.
    """
    connector = MediaConnector()

    image_sync = connector.fetch_image(image_url)
    image_async = await connector.fetch_image_async(image_url)
    assert _image_equals(image_sync, image_async)


66
@pytest.mark.asyncio
@pytest.mark.parametrize("raw_image_url", TEST_IMAGE_ASSETS)
@pytest.mark.parametrize("suffix", get_supported_suffixes())
async def test_fetch_image_base64(
    url_images: dict[str, Image.Image], raw_image_url: str, suffix: str
):
    """Fetch each asset through a base64 ``data:`` URL in each file format.

    The asset is saved to a temporary file with the given suffix, encoded as
    a data URL, and fetched both synchronously and asynchronously. The test
    is skipped when no MIME type is known for the suffix or when the image
    mode cannot be written in that format.
    """
    connector = MediaConnector(
        # Domain restriction should not apply to data URLs.
        allowed_media_domains=[
            "www.bogotobogo.com",
            "github.com",
        ]
    )
    url_image = url_images[raw_image_url]

    # Prefer the MIME type registered with PIL; fall back to ``mimetypes``.
    try:
        mime_type = Image.MIME[Image.registered_extensions()[suffix]]
    except KeyError:
        try:
            mime_type = mimetypes.types_map[suffix]
        except KeyError:
            pytest.skip("No MIME type")

    with NamedTemporaryFile(suffix=suffix) as f:
        try:
            url_image.save(f.name)
        except Exception as e:
            # Check ``e.args`` is non-empty first so that an exception raised
            # without arguments is re-raised as-is rather than being replaced
            # by an IndexError.
            if e.args and e.args[0] == "cannot write mode RGBA as JPEG":
                pytest.skip("Conversion not supported")

            raise

        base64_image = base64.b64encode(f.read()).decode("utf-8")
        data_url = f"data:{mime_type};base64,{base64_image}"

        data_image_sync = connector.fetch_image(data_url)
        # Only compare against the source when saving was lossless; for a
        # lossy format it is enough that the image could be opened.
        if _image_equals(url_image, Image.open(f)):
            assert _image_equals(url_image, data_image_sync)

        data_image_async = await connector.fetch_image_async(data_url)
        assert _image_equals(data_image_sync, data_image_async)
109
110


111
@pytest.mark.asyncio
@pytest.mark.parametrize("image_url", TEST_IMAGE_ASSETS, indirect=True)
async def test_fetch_image_local_files(image_url: str):
    """Check ``file://`` fetching and the local-path restrictions.

    A connector with ``allowed_local_media_path`` set must load a file from
    inside that directory identically via the sync and async paths, and must
    reject a URL that escapes it with ``ValueError``. A connector without an
    allowed path must reject local files with ``RuntimeError``.
    """
    connector = MediaConnector()

    with TemporaryDirectory() as temp_dir:
        local_connector = MediaConnector(allowed_local_media_path=temp_dir)

        origin_image = connector.fetch_image(image_url)
        file_name = os.path.basename(image_url)
        origin_image.save(
            os.path.join(temp_dir, file_name),
            quality=100,
            icc_profile=origin_image.info.get("icc_profile"),
        )

        allowed_url = f"file://{temp_dir}/{file_name}"
        image_async = await local_connector.fetch_image_async(allowed_url)
        image_sync = local_connector.fetch_image(allowed_url)
        # Check that the images are equal
        assert not ImageChops.difference(image_sync, image_async).getbbox()

        # The ``..`` segment points outside the allowed directory.
        escaping_url = f"file://{temp_dir}/../{file_name}"

        with pytest.raises(ValueError, match="must be a subpath"):
            await local_connector.fetch_image_async(escaping_url)
        with pytest.raises(RuntimeError, match="Cannot load local files"):
            await connector.fetch_image_async(escaping_url)

        with pytest.raises(ValueError, match="must be a subpath"):
            local_connector.fetch_image(escaping_url)
        with pytest.raises(RuntimeError, match="Cannot load local files"):
            connector.fetch_image(escaping_url)
150
151


152
@pytest.mark.asyncio
@pytest.mark.parametrize("image_url", [TEST_IMAGE_ASSETS[0]], indirect=True)
async def test_fetch_image_local_files_with_space_in_name(image_url: str):
    """A local file whose name contains spaces must be fetchable.

    The image is saved under a file name with spaces and loaded through a
    ``file://`` URL via both the sync and async paths. A
    ``FileNotFoundError`` from either fetch fails the test explicitly.
    """
    connector = MediaConnector()

    with TemporaryDirectory() as temp_dir:
        local_connector = MediaConnector(allowed_local_media_path=temp_dir)

        origin_image = connector.fetch_image(image_url)
        filename = "file name with space.jpg"
        origin_image.save(
            os.path.join(temp_dir, filename),
            quality=100,
            icc_profile=origin_image.info.get("icc_profile"),
        )

        # The URL deliberately keeps the raw (unescaped) spaces.
        file_url = f"file://{temp_dir}/{filename}"
        try:
            image_async = await local_connector.fetch_image_async(file_url)
            image_sync = local_connector.fetch_image(file_url)
        except FileNotFoundError as e:
            pytest.fail(f"Failed to fetch image with space in name: {e}")

        # Check that the images are equal
        assert not ImageChops.difference(image_sync, image_async).getbbox()


179
180
181
182
183
184
185
186
187
188
189
190
191
@pytest.mark.asyncio
async def test_fetch_image_error_conversion():
    """An undecodable image payload must surface as ``ValueError``.

    The data URL claims ``image/png`` but its base64 payload is plain text,
    so it cannot be identified as an image.
    """
    invalid_data_url = "data:image/png;base64,aGVsbG9fdmxsbV9jb21tdW5pdHkK"
    media_connector = MediaConnector()

    # PIL.UnidentifiedImageError should be converted to ValueError, on the
    # async path ...
    with pytest.raises(ValueError):
        await media_connector.fetch_image_async(invalid_data_url)

    # ... and on the sync path alike.
    with pytest.raises(ValueError):
        media_connector.fetch_image(invalid_data_url)


192
@pytest.mark.flaky(reruns=3, reruns_delay=5)
@pytest.mark.asyncio
@pytest.mark.parametrize("video_url", TEST_VIDEO_URLS)
@pytest.mark.parametrize("num_frames", [-1, 32, 1800])
async def test_fetch_video_http(video_url: str, num_frames: int):
    """Sync and async video fetches must return equal data and metadata.

    ``num_frames`` is passed to the connector through
    ``media_io_kwargs["video"]``. The test is skipped, not failed, when a
    fetch times out, since that reflects network flakiness.
    """
    connector = MediaConnector(
        media_io_kwargs={
            "video": {
                "num_frames": num_frames,
            }
        }
    )

    try:
        video_sync, metadata_sync = connector.fetch_video(video_url)
        video_async, metadata_async = await connector.fetch_video_async(video_url)
    except (TimeoutError, asyncio.TimeoutError) as e:
        pytest.skip(f"Timeout fetching video (CI network flakiness): {e}")

    assert np.array_equal(video_sync, video_async)
    assert metadata_sync == metadata_async
213
214


215
216
217
218
219
@pytest.mark.asyncio
@pytest.mark.parametrize("video_url", TEST_VIDEO_URLS)
@pytest.mark.parametrize("max_duration", [1, 60, 1800])
@pytest.mark.parametrize("requested_fps", [2, 24])
async def test_fetch_video_http_with_dynamic_loader(
    video_url: str,
    max_duration: int,
    requested_fps: int,
    monkeypatch: pytest.MonkeyPatch,
):
    """Fetch videos with ``VLLM_VIDEO_LOADER_BACKEND=opencv_dynamic``.

    With the environment variable set for the duration of the test, the sync
    and async fetches must return equal data and metadata, and the metadata
    must report ``"opencv_dynamic"`` as the video backend.
    """
    # The context manager restores the environment when the block exits.
    with monkeypatch.context() as m:
        m.setenv("VLLM_VIDEO_LOADER_BACKEND", "opencv_dynamic")
        connector = MediaConnector(
            media_io_kwargs={
                "video": {
                    "max_duration": max_duration,
                    "requested_fps": requested_fps,
                }
            }
        )

        video_sync, metadata_sync = connector.fetch_video(video_url)
        video_async, metadata_async = await connector.fetch_video_async(video_url)

        assert np.array_equal(video_sync, video_async)
        assert metadata_sync == metadata_async
        assert metadata_sync["video_backend"] == "opencv_dynamic"


244
245
246
@pytest.mark.parametrize(
    "case",
    [
        # Single modality
        ## Internally sorted
        dict(
            mm_positions={
                "image": [
                    PlaceholderRange(offset=0, length=2),
                    PlaceholderRange(offset=3, length=2),
                ]
            },
            expected_modality_idxs=[
                ("image", 0),
                ("image", 1),
            ],
        ),
        ## Internally unsorted
        dict(
            mm_positions={
                "image": [
                    PlaceholderRange(offset=3, length=2),
                    PlaceholderRange(offset=0, length=2),
                ]
            },
            expected_modality_idxs=[
                ("image", 1),
                ("image", 0),
            ],
        ),
        # Two modalities
        ## Internally sorted
        dict(
            mm_positions={
                "image": [
                    PlaceholderRange(offset=7, length=4),
                    PlaceholderRange(offset=11, length=5),
                ],
                "audio": [
                    PlaceholderRange(offset=0, length=2),
                    PlaceholderRange(offset=2, length=3),
                ],
            },
            expected_modality_idxs=[
                ("audio", 0),
                ("audio", 1),
                ("image", 0),
                ("image", 1),
            ],
        ),
        ## Interleaved, internally sorted
        dict(
            mm_positions={
                "image": [
                    PlaceholderRange(offset=0, length=4),
                    PlaceholderRange(offset=8, length=2),
                ],
                "audio": [
                    PlaceholderRange(offset=5, length=2),
                    PlaceholderRange(offset=11, length=4),
                ],
            },
            expected_modality_idxs=[
                ("image", 0),
                ("audio", 0),
                ("image", 1),
                ("audio", 1),
            ],
        ),
        ## Interleaved, internally unsorted
        dict(
            mm_positions={
                "image": [
                    PlaceholderRange(offset=8, length=2),
                    PlaceholderRange(offset=0, length=4),
                ],
                "audio": [
                    PlaceholderRange(offset=11, length=4),
                    PlaceholderRange(offset=5, length=2),
                ],
            },
            expected_modality_idxs=[
                ("image", 1),
                ("audio", 1),
                ("image", 0),
                ("audio", 0),
            ],
        ),
        # Three modalities
        ## Internally sorted
        dict(
            mm_positions={
                "image": [
                    PlaceholderRange(offset=15, length=7),
                    PlaceholderRange(offset=22, length=8),
                ],
                "audio": [
                    PlaceholderRange(offset=0, length=2),
                ],
                "video": [
                    PlaceholderRange(offset=3, length=4),
                    PlaceholderRange(offset=7, length=5),
                    PlaceholderRange(offset=12, length=6),
                ],
            },
            expected_modality_idxs=[
                ("audio", 0),
                ("video", 0),
                ("video", 1),
                ("video", 2),
                ("image", 0),
                ("image", 1),
            ],
        ),
        ## Interleaved, internally sorted
        dict(
            mm_positions={
                "image": [
                    PlaceholderRange(offset=0, length=2),
                    PlaceholderRange(offset=2, length=3),
                    PlaceholderRange(offset=20, length=4),
                ],
                "audio": [
                    PlaceholderRange(offset=5, length=2),
                ],
                "video": [
                    PlaceholderRange(offset=8, length=5),
                ],
            },
            expected_modality_idxs=[
                ("image", 0),
                ("image", 1),
                ("audio", 0),
                ("video", 0),
                ("image", 2),
            ],
        ),
        ## Interleaved, internally unsorted
        dict(
            mm_positions={
                "image": [
                    PlaceholderRange(offset=0, length=2),
                    PlaceholderRange(offset=20, length=4),
                    PlaceholderRange(offset=2, length=3),
                ],
                "audio": [
                    PlaceholderRange(offset=5, length=2),
                ],
                "video": [
                    PlaceholderRange(offset=8, length=5),
                ],
            },
            expected_modality_idxs=[
                ("image", 0),
                ("image", 2),
                ("audio", 0),
                ("video", 0),
                ("image", 1),
            ],
        ),
    ],
)
def test_argsort_mm_positions(case):
    """Check the ordering returned by ``argsort_mm_positions``.

    Each case maps modality names to lists of placeholder ranges and gives
    the expected list of ``(modality, index within that modality's list)``
    pairs. In every case here the expected order is by ascending
    placeholder offset across all modalities.
    """
    mm_positions = case["mm_positions"]
    expected_modality_idxs = case["expected_modality_idxs"]

    modality_idxs = argsort_mm_positions(mm_positions)

    assert modality_idxs == expected_modality_idxs
413
414


415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
@pytest.mark.parametrize(
    "is_embed,expected",
    [
        (None, 5),
        (torch.tensor([True, True, True, True, True]), 5),
        (torch.tensor([False, False, False, False, False]), 0),
        (torch.tensor([True, False, True, False, True]), 3),
        (torch.tensor([True]), 1),
    ],
)
def test_placeholder_range_get_num_embeds(is_embed, expected):
    """Check ``PlaceholderRange.get_num_embeds`` against each mask.

    The cases expect the count of ``True`` entries in ``is_embed``, or the
    full range length when no mask is given.
    """
    # Without a mask the range is given a fixed length of 5.
    range_length = 5 if is_embed is None else len(is_embed)
    placeholder = PlaceholderRange(
        offset=0,
        length=range_length,
        is_embed=is_embed,
    )

    num_embeds = placeholder.get_num_embeds
    assert num_embeds == expected


@pytest.mark.parametrize(
    "is_embed,expected",
    [
        (None, None),
        (
            torch.tensor([False, True, False, True, True]),
            torch.tensor([0, 1, 1, 2, 3]),
        ),
        (torch.tensor([True, True, True]), torch.tensor([1, 2, 3])),
    ],
)
def test_placeholder_range_embeds_cumsum(is_embed, expected):
    """Check ``PlaceholderRange.embeds_cumsum`` against each mask.

    The cases expect ``None`` when no mask is given and otherwise a running
    count of the ``True`` entries in ``is_embed``. Repeated access must
    return the very same object.
    """
    # Without a mask the range is given a fixed length of 5.
    range_length = 5 if is_embed is None else len(is_embed)
    placeholder = PlaceholderRange(
        offset=0,
        length=range_length,
        is_embed=is_embed,
    )

    first_access = placeholder.embeds_cumsum
    if expected is None:
        assert first_access is None
    else:
        assert torch.equal(first_access, expected)
        # cached_property should return the same object on repeated access
        assert placeholder.embeds_cumsum is first_access


@pytest.mark.parametrize(
    "is_embed,start_idx,end_idx,expected",
    [
        (None, 2, 4, (2, 4)),
        (
            torch.tensor([False, True, False, True, True]),
            3,
            5,
            (1, 3),
        ),
        (
            torch.tensor([False, True, False, True, True]),
            0,
            2,
            (0, 1),
        ),
        (
            torch.tensor([True, False, True, False]),
            2,
            2,
            (1, 1),
        ),
    ],
)
def test_placeholder_range_get_embeds_indices_in_range(
    is_embed, start_idx, end_idx, expected
):
    """Check ``PlaceholderRange.get_embeds_indices_in_range`` per case.

    Each case supplies a mask (or ``None``), a start and end index, and the
    pair the method is expected to return for them.
    """
    # Without a mask the range is given a fixed length of 5.
    range_length = 5 if is_embed is None else len(is_embed)
    placeholder = PlaceholderRange(
        offset=0,
        length=range_length,
        is_embed=is_embed,
    )

    embed_indices = placeholder.get_embeds_indices_in_range(start_idx, end_idx)
    assert embed_indices == expected


@pytest.mark.parametrize(
    "offset,is_embed,expected",
    [
        (0, None, [(0, 4)]),
        (
            2,
            torch.tensor([False, True, False, True, True]),
            [(3, 3), (5, 6)],
        ),
        (0, torch.tensor([True, True, True, True]), [(0, 3)]),
        (0, torch.tensor([False, False, False, False]), []),
    ],
)
def test_placeholder_range_extract_embeds_range(offset, is_embed, expected):
    """Check ``PlaceholderRange.extract_embeds_range`` per case.

    In the cases here the expected value is a list of inclusive
    ``(start, end)`` pairs, shifted by ``offset``, one per contiguous run of
    ``True`` entries in the mask; it is empty when the mask has none.
    """
    # Without a mask the range is given a fixed length of 5.
    range_length = 5 if is_embed is None else len(is_embed)
    placeholder = PlaceholderRange(
        offset=offset,
        length=range_length,
        is_embed=is_embed,
    )

    embed_ranges = placeholder.extract_embeds_range()
    assert embed_ranges == expected


506
507
508
509
510
@pytest.mark.asyncio
@pytest.mark.parametrize("video_url", TEST_VIDEO_URLS)
@pytest.mark.parametrize("num_frames", [-1, 32, 1800])
async def test_allowed_media_domains(video_url: str, num_frames: int):
    """Check that ``allowed_media_domains`` admits and rejects by host.

    Videos hosted on the listed domains must be fetched identically by the
    sync and async paths, while a URL on a host outside the list must raise
    ``ValueError`` from both.
    """
    connector = MediaConnector(
        media_io_kwargs={
            "video": {
                "num_frames": num_frames,
            }
        },
        allowed_media_domains=[
            "www.bogotobogo.com",
            "github.com",
        ],
    )

    video_sync, metadata_sync = connector.fetch_video(video_url)
    video_async, metadata_async = await connector.fetch_video_async(video_url)
    assert np.array_equal(video_sync, video_async)
    assert metadata_sync == metadata_async

    disallowed_url = "https://upload.wikimedia.org/wikipedia/commons/4/47/PNG_transparency_demonstration_1.png"
    # The results are deliberately not unpacked: a failed tuple unpacking
    # also raises ValueError and could satisfy ``pytest.raises`` even if the
    # connector did not reject the URL.
    with pytest.raises(ValueError):
        connector.fetch_video(disallowed_url)

    with pytest.raises(ValueError):
        await connector.fetch_video_async(disallowed_url)