# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
import base64
import mimetypes
6
7
import os
from tempfile import NamedTemporaryFile, TemporaryDirectory
8
from typing import TYPE_CHECKING, NamedTuple
9
10
11

import numpy as np
import pytest
12
13
import torch
import torch.multiprocessing as mp
14
from PIL import Image, ImageChops
15

16
17
18
19
from tests.utils import multi_gpu_test
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed.parallel_state import (init_distributed_environment,
                                             initialize_model_parallel)
20
from vllm.multimodal.image import convert_image_mode
21
from vllm.multimodal.inputs import PlaceholderRange
22
from vllm.multimodal.utils import (MediaConnector, argsort_mm_positions,
23
24
25
                                   run_dp_sharded_vision_model)
from vllm.platforms import current_platform
from vllm.utils import get_open_port, update_environment_variables
26

27
28
29
if TYPE_CHECKING:
    from vllm.multimodal.inputs import MultiModalPlaceholderDict

30
31
32
33
34
35
36
37
# Test different image extensions (JPG/PNG) and formats (gray/RGB/RGBA)
TEST_IMAGE_URLS = [
    "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg",
    "https://upload.wikimedia.org/wikipedia/commons/f/fa/Grayscale_8bits_palette_sample_image.png",
    "https://upload.wikimedia.org/wikipedia/commons/thumb/9/91/Venn_diagram_rgb.svg/1280px-Venn_diagram_rgb.svg.png",
    "https://upload.wikimedia.org/wikipedia/commons/0/0b/RGBA_comp.png",
]

# Test different video containers (MP4/AVI)
TEST_VIDEO_URLS = [
    "https://www.bogotobogo.com/python/OpenCV_Python/images/mean_shift_tracking/slow_traffic_small.mp4",
    "https://github.com/opencv/opencv/raw/refs/tags/4.12.0/samples/data/vtest.avi",
]

@pytest.fixture(scope="module")
def url_images() -> dict[str, Image.Image]:
    """Download every entry of ``TEST_IMAGE_URLS`` once per test module.

    Returns:
        Mapping from image URL to the PIL image fetched from it.
    """
    connector = MediaConnector()

    return {
        image_url: connector.fetch_image(image_url)
        for image_url in TEST_IMAGE_URLS
    }


def get_supported_suffixes() -> tuple[str, ...]:
    """Return the image file suffixes exercised by the base64 tests.

    Returns:
        The suffixes listed for GPT-4 with Vision, followed by the extra
        suffixes this project additionally supports.
    """
    # We should at least test the file types mentioned in GPT-4 with Vision
    openai_supported_suffixes = ('.png', '.jpeg', '.jpg', '.webp', '.gif')

    # Additional file types that are supported by us
    extra_supported_suffixes = ('.bmp', '.tiff')

    return openai_supported_suffixes + extra_supported_suffixes


def _image_equals(a: Image.Image, b: Image.Image) -> bool:
    """Return True if ``b``, converted to ``a``'s mode, has ``a``'s pixels.

    ``np.array_equal`` is used instead of ``(x == y).all()`` so that images
    of different sizes compare as unequal rather than being broadcast
    against each other (or raising), and so a plain ``bool`` is returned.
    """
    return bool(
        np.array_equal(np.asarray(a), np.asarray(convert_image_mode(b,
                                                                    a.mode))))


@pytest.mark.asyncio
@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
async def test_fetch_image_http(image_url: str):
    """Sync and async HTTP fetching of the same URL must yield equal pixels."""
    connector = MediaConnector()

    image_sync = connector.fetch_image(image_url)
    image_async = await connector.fetch_image_async(image_url)

    assert _image_equals(image_sync, image_async)


@pytest.mark.asyncio
@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
@pytest.mark.parametrize("suffix", get_supported_suffixes())
async def test_fetch_image_base64(url_images: dict[str, Image.Image],
                                  image_url: str, suffix: str):
    """Round-trip each test image through a base64 ``data:`` URL.

    The source image is re-saved in the format implied by ``suffix`` and
    embedded in a data URL. When the re-saved file still matches the
    source, the decoded data URL must match it too. In every case, the sync
    and async fetch paths must decode to the same pixels.
    """
    connector = MediaConnector()
    url_image = url_images[image_url]

    # Prefer Pillow's own format -> MIME mapping; fall back to the stdlib
    # registry for suffixes Pillow does not map.
    try:
        mime_type = Image.MIME[Image.registered_extensions()[suffix]]
    except KeyError:
        try:
            mime_type = mimetypes.types_map[suffix]
        except KeyError:
            pytest.skip('No MIME type')

    with NamedTemporaryFile(suffix=suffix) as f:
        try:
            url_image.save(f.name)
        except Exception as e:
            # Guard on e.args so an argument-less exception is re-raised
            # as-is instead of being masked by an IndexError.
            if e.args and e.args[0] == 'cannot write mode RGBA as JPEG':
                pytest.skip('Conversion not supported')

            raise

        # The image was written through f.name, so reading from the still
        # unread handle `f` returns the saved file's bytes.
        base64_image = base64.b64encode(f.read()).decode("utf-8")
        data_url = f"data:{mime_type};base64,{base64_image}"

        data_image_sync = connector.fetch_image(data_url)

        # Close the decoded file image deterministically once compared.
        with Image.open(f) as saved_image:
            saved_matches_source = _image_equals(url_image, saved_image)
        if saved_matches_source:
            assert _image_equals(url_image, data_image_sync)
        # Otherwise the format is lossy; decoding without error is enough.

        data_image_async = await connector.fetch_image_async(data_url)
        assert _image_equals(data_image_sync, data_image_async)


@pytest.mark.asyncio
@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
async def test_fetch_image_local_files(image_url: str):
    """Load images from ``file://`` URLs and check local-path restrictions.

    A connector created with ``allowed_local_media_path`` must load a file
    inside that directory (sync and async agreeing) and reject a path that
    climbs out of it with ``..``. A connector without it must refuse local
    files entirely.
    """
    connector = MediaConnector()
    filename = os.path.basename(image_url)

    with TemporaryDirectory() as temp_dir:
        local_connector = MediaConnector(allowed_local_media_path=temp_dir)

        origin_image = connector.fetch_image(image_url)
        # Presumably quality=100 and the original ICC profile keep the
        # re-saved file as close as possible to the downloaded image.
        origin_image.save(os.path.join(temp_dir, filename),
                          quality=100,
                          icc_profile=origin_image.info.get('icc_profile'))

        inside_url = f"file://{temp_dir}/{filename}"
        # Escapes temp_dir via "..", so it is expected to be rejected.
        outside_url = f"file://{temp_dir}/../{filename}"

        image_async = await local_connector.fetch_image_async(inside_url)
        image_sync = local_connector.fetch_image(inside_url)

        # Check that the images are equal
        assert not ImageChops.difference(image_sync, image_async).getbbox()

        with pytest.raises(ValueError, match="must be a subpath"):
            await local_connector.fetch_image_async(outside_url)
        with pytest.raises(RuntimeError, match="Cannot load local files"):
            await connector.fetch_image_async(outside_url)

        with pytest.raises(ValueError, match="must be a subpath"):
            local_connector.fetch_image(outside_url)
        with pytest.raises(RuntimeError, match="Cannot load local files"):
            connector.fetch_image(outside_url)


@pytest.mark.asyncio
async def test_fetch_image_error_conversion():
    """Undecodable image payloads must surface as ``ValueError``.

    The data URL carries valid base64 that is not an image, so PIL cannot
    identify it. Both the async and the sync fetch API are expected to
    convert that failure into ``ValueError``.
    """
    # base64 of b"hello_vllm_community\n" -- deliberately not a PNG.
    bad_data_url = "data:image/png;base64,aGVsbG9fdmxsbV9jb21tdW5pdHkK"
    media_connector = MediaConnector()

    # PIL.UnidentifiedImageError should be converted to ValueError
    with pytest.raises(ValueError):
        await media_connector.fetch_image_async(bad_data_url)

    with pytest.raises(ValueError):
        media_connector.fetch_image(bad_data_url)


@pytest.mark.asyncio
@pytest.mark.parametrize("video_url", TEST_VIDEO_URLS)
@pytest.mark.parametrize("num_frames", [-1, 32, 1800])
async def test_fetch_video_http(video_url: str, num_frames: int):
    """Sync and async video fetching must return identical frames/metadata.

    ``num_frames`` is forwarded to the video loader through
    ``media_io_kwargs``. The values -1 and 1800 presumably mean "all frames"
    and "more frames than the clip has" -- TODO confirm against the loader.
    """
    connector = MediaConnector(
        media_io_kwargs={"video": {
            "num_frames": num_frames,
        }})

    video_sync, metadata_sync = connector.fetch_video(video_url)
    video_async, metadata_async = await connector.fetch_video_async(video_url)
    assert np.array_equal(video_sync, video_async)
    assert metadata_sync == metadata_async


# Used for `test_argsort_mm_positions`.
class TestCase(NamedTuple):
    """One input/expectation pair for ``argsort_mm_positions``."""

    # The name starts with "Test", so pytest would try to collect this class
    # and warn that it has a __new__ constructor; opt out explicitly.
    __test__ = False

    # Placeholder ranges per modality, passed to argsort_mm_positions.
    mm_positions: "MultiModalPlaceholderDict"
    # Expected (modality, index within that modality) pairs, in order.
    expected_modality_idxs: list[tuple[str, int]]


def test_argsort_mm_positions():
    """``argsort_mm_positions`` orders items across modalities by offset.

    Each case gives placeholder ranges per modality, possibly out of order
    within a modality. It also gives the expected ``(modality, index)``
    sequence, which in every case below follows ascending placeholder
    offset.
    """

    test_cases = [
        # Single modality
        ## Internally sorted
        TestCase(
            mm_positions={
                "image": [
                    PlaceholderRange(offset=0, length=2),
                    PlaceholderRange(offset=3, length=2),
                ]
            },
            expected_modality_idxs=[
                ("image", 0),
                ("image", 1),
            ],
        ),
        ## Internally unsorted
        TestCase(
            mm_positions={
                "image": [
                    PlaceholderRange(offset=3, length=2),
                    PlaceholderRange(offset=0, length=2),
                ]
            },
            expected_modality_idxs=[
                ("image", 1),
                ("image", 0),
            ],
        ),

        # Two modalities
        ## Internally sorted
        TestCase(
            mm_positions={
                "image": [
                    PlaceholderRange(offset=7, length=4),
                    PlaceholderRange(offset=11, length=5),
                ],
                "audio": [
                    PlaceholderRange(offset=0, length=2),
                    PlaceholderRange(offset=2, length=3),
                ]
            },
            expected_modality_idxs=[
                ("audio", 0),
                ("audio", 1),
                ("image", 0),
                ("image", 1),
            ],
        ),
        ## Interleaved, internally sorted
        TestCase(
            mm_positions={
                "image": [
                    PlaceholderRange(offset=0, length=4),
                    PlaceholderRange(offset=8, length=2),
                ],
                "audio": [
                    PlaceholderRange(offset=5, length=2),
                    PlaceholderRange(offset=11, length=4),
                ]
            },
            expected_modality_idxs=[
                ("image", 0),
                ("audio", 0),
                ("image", 1),
                ("audio", 1),
            ],
        ),
        ## Interleaved, internally unsorted
        TestCase(
            mm_positions={
                "image": [
                    PlaceholderRange(offset=8, length=2),
                    PlaceholderRange(offset=0, length=4),
                ],
                "audio": [
                    PlaceholderRange(offset=11, length=4),
                    PlaceholderRange(offset=5, length=2),
                ]
            },
            expected_modality_idxs=[
                ("image", 1),
                ("audio", 1),
                ("image", 0),
                ("audio", 0),
            ],
        ),

        # Three modalities
        ## Internally sorted
        TestCase(
            mm_positions={
                "image": [
                    PlaceholderRange(offset=15, length=7),
                    PlaceholderRange(offset=22, length=8),
                ],
                "audio": [
                    PlaceholderRange(offset=0, length=2),
                ],
                "video": [
                    PlaceholderRange(offset=3, length=4),
                    PlaceholderRange(offset=7, length=5),
                    PlaceholderRange(offset=12, length=6),
                ]
            },
            expected_modality_idxs=[
                ("audio", 0),
                ("video", 0),
                ("video", 1),
                ("video", 2),
                ("image", 0),
                ("image", 1),
            ],
        ),
        ## Interleaved, internally sorted
        TestCase(
            mm_positions={
                "image": [
                    PlaceholderRange(offset=0, length=2),
                    PlaceholderRange(offset=2, length=3),
                    PlaceholderRange(offset=20, length=4),
                ],
                "audio": [
                    PlaceholderRange(offset=5, length=2),
                ],
                "video": [
                    PlaceholderRange(offset=8, length=5),
                ]
            },
            expected_modality_idxs=[
                ("image", 0),
                ("image", 1),
                ("audio", 0),
                ("video", 0),
                ("image", 2),
            ],
        ),
        ## Interleaved, internally unsorted
        TestCase(
            mm_positions={
                "image": [
                    PlaceholderRange(offset=0, length=2),
                    PlaceholderRange(offset=20, length=4),
                    PlaceholderRange(offset=2, length=3),
                ],
                "audio": [
                    PlaceholderRange(offset=5, length=2),
                ],
                "video": [
                    PlaceholderRange(offset=8, length=5),
                ]
            },
            expected_modality_idxs=[
                ("image", 0),
                ("image", 2),
                ("audio", 0),
                ("video", 0),
                ("image", 1),
            ],
        ),
    ]

    for mm_positions, expected_modality_idxs in test_cases:
        modality_idxs = argsort_mm_positions(mm_positions)

        # Include the input in the failure message to identify the case.
        assert modality_idxs == expected_modality_idxs, mm_positions


class SimpleLinearModel(torch.nn.Module):
    """Toy vision model: flatten each sample, then apply one linear layer.

    Args:
        input_dim: Feature count per sample after flattening. The default
            equals 3 * 224 * 224, the size of one 3x224x224 image.
        output_dim: Number of output features per sample.
    """

    def __init__(self, input_dim: int = 3 * 224 * 224, output_dim: int = 32):
        super().__init__()
        # Keep the batch dimension; collapse everything after it.
        self.flatten = torch.nn.Flatten(start_dim=1, end_dim=-1)
        self.linear = torch.nn.Linear(in_features=input_dim,
                                      out_features=output_dim)

    def forward(self, x: torch.Tensor):
        return self.linear(self.flatten(x))


@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize(
    "batch_size",
    [
        1,  # Single image
        4,  # Small batch
        5,  # Odd batch size (for testing padding)
    ],
)
def test_run_dp_sharded_vision_model(batch_size: int):
    """Compare sharded vs. direct vision-model execution across 2 GPUs.

    Spawns one process per rank. ``mp.spawn`` passes each process its index
    as the first argument of ``run_dp_sharded_vision_model_vs_direct``,
    followed by ``args``.
    """
    num_ranks = 2
    master_port = get_open_port()
    mp.spawn(
        run_dp_sharded_vision_model_vs_direct,
        args=(num_ranks, batch_size, master_port),
        nprocs=num_ranks,
    )


def run_dp_sharded_vision_model_vs_direct(local_rank: int, world_size: int,
                                          batch_size: int, master_port: int):
    """
    Per-rank worker for ``test_run_dp_sharded_vision_model``.

    Binds this process to ``cuda:{local_rank}``, sets the RANK / LOCAL_RANK /
    WORLD_SIZE / MASTER_ADDR / MASTER_PORT environment variables, and
    initialises a tensor-parallel group of size ``world_size``. It then
    asserts that ``run_dp_sharded_vision_model`` matches a direct forward
    pass of the same model over the whole batch.
    """
    # Same seed on every rank, so every rank should draw the same input
    # batch and the same model weights.
    current_platform.seed_everything(0)

    cuda_device = torch.device(f"cuda:{local_rank}")
    torch.cuda.set_device(cuda_device)
    torch.set_default_device(cuda_device)

    rank_str = str(local_rank)
    dist_env = {
        'RANK': rank_str,
        'LOCAL_RANK': rank_str,
        'WORLD_SIZE': str(world_size),
        'MASTER_ADDR': 'localhost',
        'MASTER_PORT': str(master_port),
    }
    update_environment_variables(dist_env)

    init_distributed_environment()
    initialize_model_parallel(tensor_model_parallel_size=world_size)

    # Created on cuda_device because of set_default_device above.
    pixel_batch = torch.randn(batch_size, 3, 224, 224)
    model = SimpleLinearModel()

    with torch.inference_mode():
        reference = model(pixel_batch)
        sharded = run_dp_sharded_vision_model(pixel_batch, model)

    assert get_tensor_model_parallel_world_size() == world_size
    assert reference.shape == sharded.shape
    # Expected to be identical; compare with a small tolerance anyway.
    assert torch.allclose(reference, sharded, rtol=1e-5, atol=1e-5)