test_utils.py 28.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import base64
5
import math
6
import mimetypes
7
8
import os
from tempfile import NamedTemporaryFile, TemporaryDirectory
9
from typing import TYPE_CHECKING, NamedTuple
10
11
12

import numpy as np
import pytest
13
14
import torch
import torch.multiprocessing as mp
15
from PIL import Image, ImageChops
16

17
18
19
20
from tests.utils import multi_gpu_test
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed.parallel_state import (init_distributed_environment,
                                             initialize_model_parallel)
21
from vllm.multimodal.image import convert_image_mode
22
from vllm.multimodal.inputs import PlaceholderRange
23
from vllm.multimodal.utils import (MediaConnector, argsort_mm_positions,
24
25
                                   get_load_balance_assignment,
                                   run_dp_sharded_mrope_vision_model,
26
27
28
                                   run_dp_sharded_vision_model)
from vllm.platforms import current_platform
from vllm.utils import get_open_port, update_environment_variables
29

30
31
32
if TYPE_CHECKING:
    from vllm.multimodal.inputs import MultiModalPlaceholderDict

33
# Test different image extensions (JPG/PNG) and formats (gray/RGB/RGBA).
# Entries are asset file names; the trailing comment on each entry records
# the URL the asset presumably originates from.
TEST_IMAGE_ASSETS = [
    "2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg",  # "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
    "Grayscale_8bits_palette_sample_image.png",  # "https://upload.wikimedia.org/wikipedia/commons/f/fa/Grayscale_8bits_palette_sample_image.png",
    "1280px-Venn_diagram_rgb.svg.png",  # "https://upload.wikimedia.org/wikipedia/commons/thumb/9/91/Venn_diagram_rgb.svg/1280px-Venn_diagram_rgb.svg.png",
    "RGBA_comp.png",  # "https://upload.wikimedia.org/wikipedia/commons/0/0b/RGBA_comp.png",
]

41
42
# Remote videos (MP4 and AVI) used by the HTTP video-fetch tests.
TEST_VIDEO_URLS = [
    "https://www.bogotobogo.com/python/OpenCV_Python/images/mean_shift_tracking/slow_traffic_small.mp4",
    "https://github.com/opencv/opencv/raw/refs/tags/4.12.0/samples/data/vtest.avi",
]

46

47
@pytest.fixture(scope="module")
def url_images(local_asset_server) -> dict[str, Image.Image]:
    """Map every name in ``TEST_IMAGE_ASSETS`` to its loaded image.

    The images come from ``local_asset_server.get_image_asset`` (a fixture
    defined outside this file) and are built once per test module.
    """
    return {
        image_url: local_asset_server.get_image_asset(image_url)
        for image_url in TEST_IMAGE_ASSETS
    }
54
55


56
def get_supported_suffixes() -> tuple[str, ...]:
    """Return the image file suffixes exercised by the base64 fetch test.

    The formats listed for GPT-4 with Vision come first, followed by the
    extra formats supported here.
    """
    # We should at least test the file types mentioned in GPT-4 with Vision
    OPENAI_SUPPORTED_SUFFIXES = ('.png', '.jpeg', '.jpg', '.webp', '.gif')

    # Additional file types that are supported by us
    EXTRA_SUPPORTED_SUFFIXES = ('.bmp', '.tiff')

    return OPENAI_SUPPORTED_SUFFIXES + EXTRA_SUPPORTED_SUFFIXES


def _image_equals(a: Image.Image, b: Image.Image) -> bool:
    """Return whether ``a`` and ``b`` hold identical pixel data.

    ``b`` is converted to ``a``'s mode first (via ``convert_image_mode``) so
    images stored in different modes can be compared.

    Uses ``np.array_equal`` rather than elementwise ``==`` followed by
    ``.all()``, for two reasons:
    - It returns ``False`` for arrays of different shapes (for example
      images of different sizes); recent NumPy raises there instead.
    - It yields a plain ``bool``, as annotated.
    """
    converted = convert_image_mode(b, a.mode)
    return bool(np.array_equal(np.asarray(a), np.asarray(converted)))


70
@pytest.mark.asyncio
@pytest.mark.parametrize("image_url", TEST_IMAGE_ASSETS, indirect=True)
async def test_fetch_image_http(image_url: str):
    """Sync and async fetches of the same image URL must decode identically.

    ``image_url`` is resolved indirectly through a fixture of that name
    (defined outside this file) from each entry of ``TEST_IMAGE_ASSETS``.
    """
    connector = MediaConnector()

    image_sync = connector.fetch_image(image_url)
    image_async = await connector.fetch_image_async(image_url)

    assert _image_equals(image_sync, image_async)


80
@pytest.mark.asyncio
@pytest.mark.parametrize("raw_image_url", TEST_IMAGE_ASSETS)
@pytest.mark.parametrize("suffix", get_supported_suffixes())
async def test_fetch_image_base64(url_images: dict[str, Image.Image],
                                  raw_image_url: str, suffix: str):
    """Round-trip each asset through a base64 ``data:`` URL.

    Steps:
    - Re-encode the image in the format implied by ``suffix``.
    - Embed the bytes in a data URL.
    - Fetch it both synchronously and asynchronously.

    Pixel equality with the source is required only when the re-encoded
    file itself reproduces the source (lossless). The sync and async results
    must always agree.
    """
    connector = MediaConnector()
    url_image = url_images[raw_image_url]

    # Prefer PIL's MIME mapping for the extension, fall back to the stdlib
    # table, and skip formats that neither knows about.
    try:
        mime_type = Image.MIME[Image.registered_extensions()[suffix]]
    except KeyError:
        try:
            mime_type = mimetypes.types_map[suffix]
        except KeyError:
            pytest.skip('No MIME type')

    with NamedTemporaryFile(suffix=suffix) as f:
        try:
            url_image.save(f.name)
        except Exception as e:
            # Check ``e.args`` first so an argument-less exception is
            # re-raised as-is instead of being masked by an IndexError.
            if e.args and e.args[0] == 'cannot write mode RGBA as JPEG':
                pytest.skip('Conversion not supported')

            raise

        base64_image = base64.b64encode(f.read()).decode("utf-8")
        data_url = f"data:{mime_type};base64,{base64_image}"

        data_image_sync = connector.fetch_image(data_url)
        # For lossy formats it is enough that the data URL can be decoded.
        if _image_equals(url_image, Image.open(f)):
            assert _image_equals(url_image, data_image_sync)

        data_image_async = await connector.fetch_image_async(data_url)
        assert _image_equals(data_image_sync, data_image_async)


118
@pytest.mark.asyncio
@pytest.mark.parametrize("image_url", TEST_IMAGE_ASSETS, indirect=True)
async def test_fetch_image_local_files(image_url: str):
    """Check local ``file://`` loading and its access restrictions.

    - Files under the allowed directory load identically via sync and async.
    - A path escaping the allowed directory is rejected.
    - A connector without an allowed local path refuses local files.
    """
    connector = MediaConnector()

    with TemporaryDirectory() as temp_dir:
        local_connector = MediaConnector(allowed_local_media_path=temp_dir)
        filename = os.path.basename(image_url)

        origin_image = connector.fetch_image(image_url)
        origin_image.save(os.path.join(temp_dir, filename),
                          quality=100,
                          icc_profile=origin_image.info.get('icc_profile'))

        local_url = f"file://{temp_dir}/{filename}"
        # Points outside ``temp_dir`` via "..", so it must be rejected.
        outside_url = f"file://{temp_dir}/../{filename}"

        image_async = await local_connector.fetch_image_async(local_url)
        image_sync = local_connector.fetch_image(local_url)
        # Check that the images are equal
        assert not ImageChops.difference(image_sync, image_async).getbbox()

        with pytest.raises(ValueError, match="must be a subpath"):
            await local_connector.fetch_image_async(outside_url)
        with pytest.raises(RuntimeError, match="Cannot load local files"):
            await connector.fetch_image_async(outside_url)

        with pytest.raises(ValueError, match="must be a subpath"):
            local_connector.fetch_image(outside_url)
        with pytest.raises(RuntimeError, match="Cannot load local files"):
            connector.fetch_image(outside_url)


153
@pytest.mark.asyncio
@pytest.mark.parametrize("image_url", [TEST_IMAGE_ASSETS[0]], indirect=True)
async def test_fetch_image_local_files_with_space_in_name(image_url: str):
    """A local file whose name contains spaces must still be loadable.

    The ``file://`` URL is built from the raw (not percent-encoded) name.
    """
    connector = MediaConnector()

    with TemporaryDirectory() as temp_dir:
        local_connector = MediaConnector(allowed_local_media_path=temp_dir)

        origin_image = connector.fetch_image(image_url)
        filename = "file name with space.jpg"
        origin_image.save(os.path.join(temp_dir, filename),
                          quality=100,
                          icc_profile=origin_image.info.get('icc_profile'))

        local_url = f"file://{temp_dir}/{filename}"
        try:
            image_async = await local_connector.fetch_image_async(local_url)
            image_sync = local_connector.fetch_image(local_url)
        except FileNotFoundError as e:
            pytest.fail(f"Failed to fetch image with space in name: {e}")
        # Check that the images are equal
        assert not ImageChops.difference(image_sync, image_async).getbbox()


179
180
181
182
183
184
185
186
187
188
189
190
191
@pytest.mark.asyncio
async def test_fetch_image_error_conversion():
    """Undecodable image data must surface as ``ValueError``.

    The payload is valid base64 but not an image, so PIL fails to identify
    it (``PIL.UnidentifiedImageError``). Both fetch paths are expected to
    convert that failure into ``ValueError``.
    """
    media_connector = MediaConnector()
    bad_data_url = "data:image/png;base64,aGVsbG9fdmxsbV9jb21tdW5pdHkK"

    with pytest.raises(ValueError):
        await media_connector.fetch_image_async(bad_data_url)

    with pytest.raises(ValueError):
        media_connector.fetch_image(bad_data_url)


192
193
194
195
@pytest.mark.asyncio
@pytest.mark.parametrize("video_url", TEST_VIDEO_URLS)
@pytest.mark.parametrize("num_frames", [-1, 32, 1800])
async def test_fetch_video_http(video_url: str, num_frames: int):
    """Sync and async HTTP video fetches must return equal frames and metadata.

    ``num_frames`` is forwarded to the video loader through
    ``media_io_kwargs``. The meaning of ``-1`` is defined by the loader and
    is not checked here.
    """
    connector = MediaConnector(
        media_io_kwargs={"video": {
            "num_frames": num_frames,
        }})

    video_sync, metadata_sync = connector.fetch_video(video_url)
    video_async, metadata_async = await connector.fetch_video_async(video_url)
    assert np.array_equal(video_sync, video_async)
    assert metadata_sync == metadata_async


207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
@pytest.mark.asyncio
@pytest.mark.parametrize("video_url", TEST_VIDEO_URLS)
@pytest.mark.parametrize("max_duration", [1, 60, 1800])
@pytest.mark.parametrize("requested_fps", [2, 24])
async def test_fetch_video_http_with_dynamic_loader(
        video_url: str, max_duration: int, requested_fps: int,
        monkeypatch: pytest.MonkeyPatch):
    """Exercise the ``opencv_dynamic`` video loader backend over HTTP.

    With ``VLLM_VIDEO_LOADER_BACKEND`` patched to ``opencv_dynamic``:
    - sync and async fetches must agree on frames and metadata;
    - the metadata must report that backend.
    """
    video_options = {
        "max_duration": max_duration,
        "requested_fps": requested_fps,
    }
    with monkeypatch.context() as patcher:
        patcher.setenv("VLLM_VIDEO_LOADER_BACKEND", "opencv_dynamic")
        connector = MediaConnector(media_io_kwargs={"video": video_options})

        frames_sync, meta_sync = connector.fetch_video(video_url)
        frames_async, meta_async = await connector.fetch_video_async(
            video_url)

        assert np.array_equal(frames_sync, frames_async)
        assert meta_sync == meta_async
        assert meta_sync["video_backend"] == "opencv_dynamic"


233
# Used for `test_argsort_mm_positions`.
class TestCase(NamedTuple):
    """One ``argsort_mm_positions`` scenario: the placeholder ranges per
    modality and the expected ``(modality, index)`` ordering."""

    # This is a plain data holder, not a test class. Without this flag,
    # pytest tries to collect it (the name matches ``Test*``) and emits a
    # collection warning about its ``__new__`` constructor.
    __test__ = False

    mm_positions: "MultiModalPlaceholderDict"
    expected_modality_idxs: list[tuple[str, int]]


239
def test_argsort_mm_positions():
    """``argsort_mm_positions`` orders items across modalities by offset.

    Each case gives the placeholder ranges per modality, possibly unsorted
    within a modality. It also gives the expected ``(modality, index)``
    pairs, where ``index`` is the item's position in that modality's input
    list.
    """

    test_cases = [
        # Single modality
        ## Internally sorted
        TestCase(
            mm_positions={
                "image": [
                    PlaceholderRange(offset=0, length=2),
                    PlaceholderRange(offset=3, length=2),
                ]
            },
            expected_modality_idxs=[
                ("image", 0),
                ("image", 1),
            ],
        ),
        ## Internally unsorted
        TestCase(
            mm_positions={
                "image": [
                    PlaceholderRange(offset=3, length=2),
                    PlaceholderRange(offset=0, length=2),
                ]
            },
            expected_modality_idxs=[
                ("image", 1),
                ("image", 0),
            ],
        ),

        # Two modalities
        ## Internally sorted
        TestCase(
            mm_positions={
                "image": [
                    PlaceholderRange(offset=7, length=4),
                    PlaceholderRange(offset=11, length=5),
                ],
                "audio": [
                    PlaceholderRange(offset=0, length=2),
                    PlaceholderRange(offset=2, length=3),
                ]
            },
            expected_modality_idxs=[
                ("audio", 0),
                ("audio", 1),
                ("image", 0),
                ("image", 1),
            ],
        ),
        ## Interleaved, internally sorted
        TestCase(
            mm_positions={
                "image": [
                    PlaceholderRange(offset=0, length=4),
                    PlaceholderRange(offset=8, length=2),
                ],
                "audio": [
                    PlaceholderRange(offset=5, length=2),
                    PlaceholderRange(offset=11, length=4),
                ]
            },
            expected_modality_idxs=[
                ("image", 0),
                ("audio", 0),
                ("image", 1),
                ("audio", 1),
            ],
        ),
        ## Interleaved, internally unsorted
        TestCase(
            mm_positions={
                "image": [
                    PlaceholderRange(offset=8, length=2),
                    PlaceholderRange(offset=0, length=4),
                ],
                "audio": [
                    PlaceholderRange(offset=11, length=4),
                    PlaceholderRange(offset=5, length=2),
                ]
            },
            expected_modality_idxs=[
                ("image", 1),
                ("audio", 1),
                ("image", 0),
                ("audio", 0),
            ],
        ),

        # Three modalities
        ## Internally sorted
        TestCase(
            mm_positions={
                "image": [
                    PlaceholderRange(offset=15, length=7),
                    PlaceholderRange(offset=22, length=8),
                ],
                "audio": [
                    PlaceholderRange(offset=0, length=2),
                ],
                "video": [
                    PlaceholderRange(offset=3, length=4),
                    PlaceholderRange(offset=7, length=5),
                    PlaceholderRange(offset=12, length=6),
                ]
            },
            expected_modality_idxs=[
                ("audio", 0),
                ("video", 0),
                ("video", 1),
                ("video", 2),
                ("image", 0),
                ("image", 1),
            ],
        ),
        ## Interleaved, internally sorted
        TestCase(
            mm_positions={
                "image": [
                    PlaceholderRange(offset=0, length=2),
                    PlaceholderRange(offset=2, length=3),
                    PlaceholderRange(offset=20, length=4),
                ],
                "audio": [
                    PlaceholderRange(offset=5, length=2),
                ],
                "video": [
                    PlaceholderRange(offset=8, length=5),
                ]
            },
            expected_modality_idxs=[
                ("image", 0),
                ("image", 1),
                ("audio", 0),
                ("video", 0),
                ("image", 2),
            ],
        ),
        ## Interleaved, internally unsorted
        TestCase(
            mm_positions={
                "image": [
                    PlaceholderRange(offset=0, length=2),
                    PlaceholderRange(offset=20, length=4),
                    PlaceholderRange(offset=2, length=3),
                ],
                "audio": [
                    PlaceholderRange(offset=5, length=2),
                ],
                "video": [
                    PlaceholderRange(offset=8, length=5),
                ]
            },
            expected_modality_idxs=[
                ("image", 0),
                ("image", 2),
                ("audio", 0),
                ("video", 0),
                ("image", 1),
            ],
        ),
    ]

    for case_idx, case in enumerate(test_cases):
        modality_idxs = argsort_mm_positions(case.mm_positions)

        # Label failures with the case index so the failing scenario is
        # identifiable among the cases above.
        assert modality_idxs == case.expected_modality_idxs, (
            f"test case {case_idx}")


class SimpleLinearModel(torch.nn.Module):
    """Toy vision model: flatten each sample, then apply one dense layer.

    Args:
        input_dim: Feature count per sample after flattening; the default
            corresponds to a 3x224x224 input.
        output_dim: Width of the projected output.
    """

    def __init__(self, input_dim: int = 3 * 224 * 224, output_dim: int = 32):
        super().__init__()
        self.flatten = torch.nn.Flatten()
        self.linear = torch.nn.Linear(input_dim, output_dim)

    def forward(self, x: torch.Tensor):
        # Flatten all non-batch dims, then project to ``output_dim``.
        return self.linear(self.flatten(x))


@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize(
    "batch_size",
    [
        1,  # Single image
        4,  # Small batch
        5,  # Odd batch size (for testing padding)
    ],
)
def test_run_dp_sharded_vision_model(batch_size: int):
    """Spawn one process per rank to compare sharded and direct execution."""
    world_size = 2
    worker_args = (world_size, batch_size, get_open_port())
    mp.spawn(run_dp_sharded_vision_model_vs_direct,
             args=worker_args,
             nprocs=world_size)


def run_dp_sharded_vision_model_vs_direct(local_rank: int, world_size: int,
                                          batch_size: int, master_port: int):
    """
    Test that run_dp_sharded_vision_model produces the same results as
    calling the model directly.

    Runs in one process per rank, started by ``mp.spawn``, which passes the
    process index as ``local_rank``. Every rank seeds identically, builds
    the same input and model, and compares the direct output with the
    sharded output.

    Args:
        local_rank: Rank of this worker; also selects the device.
        world_size: Number of workers and tensor-parallel size.
        batch_size: Number of images in the random input batch.
        master_port: Port used for the distributed rendezvous.
    """
    # Set random seed for reproducibility
    current_platform.seed_everything(0)

    device = f"{current_platform.device_name}:{local_rank}"
    current_platform.set_device(device)
    torch.set_default_device(device)

    update_environment_variables({
        'RANK': str(local_rank),
        'LOCAL_RANK': str(local_rank),
        'WORLD_SIZE': str(world_size),
        'MASTER_ADDR': 'localhost',
        'MASTER_PORT': str(master_port),
    })

    # initialize distributed
    init_distributed_environment()
    initialize_model_parallel(tensor_model_parallel_size=world_size)

    # Create a test input tensor
    image_input = torch.randn(batch_size, 3, 224, 224)

    # Create a simple linear model
    vision_model = SimpleLinearModel()

    # Run the model directly on the full input
    with torch.inference_mode():
        direct_output = vision_model(image_input)

    # Run the model through the sharded function
    with torch.inference_mode():
        sharded_output = run_dp_sharded_vision_model(image_input, vision_model)

    # Check that the world size is set up correctly
    assert get_tensor_model_parallel_world_size() == world_size

    # Check that the outputs have the same shape
    assert direct_output.shape == sharded_output.shape

    # Check that the outputs are close (they should be identical)
    assert torch.allclose(direct_output, sharded_output, rtol=1e-5, atol=1e-5)


@pytest.mark.parametrize(
    "sizes,num_gpus,expected_shuffle_indices,expected_gpu_sample_counts,"
    "expected_grouped_sizes_per_gpu,test_description",
    [
        # Empty input
        ([], 2, [], [0, 0], [0, 0], "empty input"),

        # Fewer samples than GPUs
        ([100, 200], 4, [1, 0], [1, 1, 0, 0], [200, 100, 0, 0],
         "fewer samples than GPUs"),

        # Single GPU
        ([100, 200, 300], 1, [2, 1, 0], [3], [600], "single GPU"),

        # Balanced assignment
        ([100, 100, 100, 100], 2, [0, 2, 1, 3], [2, 2], [200, 200],
         "balanced assignment"),

        # Unbalanced sizes - this one is trickier since the algorithm is greedy
        ([1000, 100, 200, 50], 2, [0, 2, 1, 3], [1, 3], [1000, 350],
         "unbalanced sizes"),
    ],
)
def test_get_load_balance_assignment_cases(sizes, num_gpus,
                                           expected_shuffle_indices,
                                           expected_gpu_sample_counts,
                                           expected_grouped_sizes_per_gpu,
                                           test_description):
    """Test get_load_balance_assignment with various input cases.

    The result is unpacked as ``(shuffle_indices, gpu_sample_counts,
    grouped_sizes_per_gpu)``. Each case must satisfy the length and sum
    invariants and match the expected values exactly. Failures are labelled
    with ``test_description``.
    """
    result = get_load_balance_assignment(sizes, num_gpus=num_gpus)
    shuffle_indices, gpu_sample_counts, grouped_sizes_per_gpu = result

    # Common assertions for all cases
    assert len(shuffle_indices) == len(sizes), test_description
    assert len(gpu_sample_counts) == num_gpus, test_description
    assert len(grouped_sizes_per_gpu) == num_gpus, test_description
    assert sum(gpu_sample_counts) == len(sizes), test_description

    assert shuffle_indices == expected_shuffle_indices, test_description

    assert gpu_sample_counts == expected_gpu_sample_counts, test_description
    assert grouped_sizes_per_gpu == expected_grouped_sizes_per_gpu, (
        test_description)


class SimpleMRopeVisionModel(torch.nn.Module):
    """Toy vision encoder for exercising mrope data-parallel sharding.

    Every patch row is projected from 768 to ``out_hidden_size`` features.
    Each image's rows are then averaged in consecutive groups of
    ``spatial_merge_size ** 2`` to mimic spatial merging.

    Args:
        spatial_merge_size: Side of the merge window; each output row
            averages ``spatial_merge_size ** 2`` consecutive patch rows.
        out_hidden_size: Width of the projected (and returned) embeddings.
    """

    def __init__(self, spatial_merge_size: int = 2, out_hidden_size: int = 64):
        super().__init__()
        self.spatial_merge_size = spatial_merge_size
        self.out_hidden_size = out_hidden_size
        self.linear = torch.nn.Linear(768, out_hidden_size)

    def forward(self, pixel_values: torch.Tensor,
                grid_thw_list: list[list[int]]):
        """Project all patch rows, then merge each image's rows.

        Args:
            pixel_values: Patch rows of all images concatenated along dim 0,
                in ``grid_thw_list`` order. The trailing dimension must be
                768 to match ``self.linear``.
            grid_thw_list: Per-image ``[t, h, w]``; image ``i`` owns
                ``t * h * w`` consecutive rows of ``pixel_values``.

        Returns:
            Merged embeddings of all images concatenated along dim 0, with
            ``out_hidden_size`` columns.
            - Rows of an image that do not fill a whole group are dropped.
            - If no image yields a group, the result is an empty
              ``(0, out_hidden_size)`` tensor.
        """
        projected = self.linear(pixel_values)
        group_size = self.spatial_merge_size * self.spatial_merge_size

        merged_per_image = []
        row = 0
        for thw in grid_thw_list:
            n_rows = math.prod(thw)
            n_groups = n_rows // group_size
            if n_groups > 0:
                window = projected[row:row + n_rows][:n_groups * group_size]
                merged_per_image.append(
                    window.view(n_groups, group_size, -1).mean(dim=1))
            row += n_rows

        if not merged_per_image:
            # Keep the feature width so callers can still concatenate.
            return torch.empty((0, self.out_hidden_size),
                               device=pixel_values.device,
                               dtype=pixel_values.dtype)
        return torch.cat(merged_per_image, dim=0)


@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize(
    "batch_size",
    [
        1,  # Single image
        3,  # Small batch
        5,  # Odd batch size (for testing padding)
    ],
)
def test_run_dp_sharded_mrope_vision_model(batch_size: int):
    """Spawn one process per rank to compare mrope sharding with a direct run."""
    world_size = 2
    port = get_open_port()
    mp.spawn(run_dp_sharded_mrope_vision_model_vs_direct,
             args=(world_size, batch_size, port),
             nprocs=world_size)


def run_dp_sharded_mrope_vision_model_vs_direct(local_rank: int,
                                                world_size: int,
                                                batch_size: int,
                                                master_port: int):
    """
    Test that run_dp_sharded_mrope_vision_model produces the same results as
    calling the model directly.

    Runs in one process per rank. Every rank seeds identically and builds
    the same ``batch_size`` images with grids ``[1, 4 + i, 4 + i]``. Only
    rank 0 computes the direct reference and compares it with the
    concatenated sharded outputs.

    Args:
        local_rank: Rank of this worker; also selects the device.
        world_size: Number of workers and tensor-parallel size.
        batch_size: Number of images in the test batch.
        master_port: Port used for the distributed rendezvous.
    """
    # Set random seed for reproducibility
    current_platform.seed_everything(0)
    device = f"{current_platform.device_name}:{local_rank}"
    current_platform.set_device(device)
    torch.set_default_device(device)

    update_environment_variables({
        'RANK': str(local_rank),
        'LOCAL_RANK': str(local_rank),
        'WORLD_SIZE': str(world_size),
        'MASTER_ADDR': 'localhost',
        'MASTER_PORT': str(master_port),
    })

    # initialize distributed
    init_distributed_environment()
    initialize_model_parallel(tensor_model_parallel_size=world_size)

    # Create test data
    grid_thw_list = []
    pixel_values_list = []

    for i in range(batch_size):
        # Varying image sizes for better testing
        t, h, w = 1, 4 + i, 4 + i
        grid_thw_list.append([t, h, w])

        num_patches = t * h * w
        # Create random pixel values for this image
        image_pixels = torch.randn(num_patches, 768)
        pixel_values_list.append(image_pixels)

    # Concatenate all pixel values
    pixel_values = torch.cat(pixel_values_list, dim=0)

    # Create a simple mrope vision model
    vision_model = SimpleMRopeVisionModel()

    # Run the model directly on the full input (only on rank 0)
    if local_rank == 0:
        with torch.inference_mode():
            direct_output = vision_model(pixel_values, grid_thw_list)

    # Run the model through the sharded function; unlike the direct
    # reference, this runs on every rank.
    with torch.inference_mode():
        sharded_output = run_dp_sharded_mrope_vision_model(vision_model,
                                                           pixel_values,
                                                           grid_thw_list,
                                                           rope_type="rope_3d")
        # The helper returns a sequence of tensors; join them so they can be
        # compared with the single direct output.
        sharded_output = torch.cat(sharded_output, dim=0)

    # Check that the world size is set up correctly
    assert get_tensor_model_parallel_world_size() == world_size

    # Compare outputs (only on rank 0)
    if local_rank == 0:
        # Check that the outputs have the same shape
        assert direct_output.shape == sharded_output.shape
        # Check that the outputs are close (they should be identical)
        assert torch.allclose(direct_output,
                              sharded_output,
                              rtol=1e-5,
                              atol=1e-5)


@multi_gpu_test(num_gpus=2)
def test_run_dp_sharded_mrope_vision_model_empty_input():
    """Spawn two ranks that push an empty batch through the mrope helper."""
    world_size = 2
    spawn_args = (world_size, get_open_port())
    mp.spawn(run_dp_sharded_mrope_vision_model_empty_input_worker,
             args=spawn_args,
             nprocs=world_size)


def run_dp_sharded_mrope_vision_model_empty_input_worker(
        local_rank: int, world_size: int, master_port: int):
    """Test run_dp_sharded_mrope_vision_model with empty input.

    With zero images (an empty ``pixel_values`` and no grids), the sharded
    helper must return an empty result on every rank instead of failing.

    Args:
        local_rank: Rank of this worker; also selects the device.
        world_size: Number of workers and tensor-parallel size.
        master_port: Port used for the distributed rendezvous.
    """
    # Set up distributed environment
    device = f"{current_platform.device_name}:{local_rank}"
    current_platform.set_device(device)
    torch.set_default_device(device)

    update_environment_variables({
        'RANK': str(local_rank),
        'LOCAL_RANK': str(local_rank),
        'WORLD_SIZE': str(world_size),
        'MASTER_ADDR': 'localhost',
        'MASTER_PORT': str(master_port),
    })

    init_distributed_environment()
    initialize_model_parallel(tensor_model_parallel_size=world_size)

    # Create empty inputs
    pixel_values = torch.empty((0, 768))
    grid_thw_list: list[list[int]] = []

    vision_model = SimpleMRopeVisionModel()

    # Should handle empty input gracefully
    with torch.inference_mode():
        output = run_dp_sharded_mrope_vision_model(vision_model,
                                                   pixel_values,
                                                   grid_thw_list,
                                                   rope_type="rope_3d")

    assert len(output) == 0


@multi_gpu_test(num_gpus=4)
def test_run_dp_sharded_mrope_vision_model_uneven_load():
    """Spawn four ranks that shard three very differently sized images."""
    world_size = 4
    port = get_open_port()
    mp.spawn(run_dp_sharded_mrope_vision_model_uneven_load_worker,
             args=(world_size, port),
             nprocs=world_size)


def run_dp_sharded_mrope_vision_model_uneven_load_worker(
        local_rank: int, world_size: int, master_port: int):
    """Test run_dp_sharded_mrope_vision_model with uneven load distribution.

    Three images with very different patch counts are sharded across
    ``world_size`` ranks (4 in the spawning test, so there are more ranks
    than images). The helper must return one merged embedding per image,
    with the expected row count and width.

    Args:
        local_rank: Rank of this worker; also selects the device.
        world_size: Number of workers and tensor-parallel size.
        master_port: Port used for the distributed rendezvous.
    """
    # Set up distributed environment
    current_platform.seed_everything(123)
    device = f"{current_platform.device_name}:{local_rank}"
    current_platform.set_device(device)
    torch.set_default_device(device)

    update_environment_variables({
        'RANK': str(local_rank),
        'LOCAL_RANK': str(local_rank),
        'WORLD_SIZE': str(world_size),
        'MASTER_ADDR': 'localhost',
        'MASTER_PORT': str(master_port),
    })

    init_distributed_environment()
    initialize_model_parallel(tensor_model_parallel_size=world_size)

    # Create images with very different sizes
    grid_thw_list = [
        [1, 2, 2],  # Small: 4 patches
        [1, 8, 8],  # Large: 64 patches
        [1, 3, 3],  # Medium: 9 patches
    ]

    pixel_values = torch.cat(
        [torch.randn(math.prod(grid_thw), 768) for grid_thw in grid_thw_list],
        dim=0)
    vision_model = SimpleMRopeVisionModel()

    # Should handle uneven distribution without errors
    with torch.inference_mode():
        output_tuple = run_dp_sharded_mrope_vision_model(vision_model,
                                                         pixel_values,
                                                         grid_thw_list,
                                                         rope_type="rope_3d")

    # Verify output shape is reasonable
    merge_factor = vision_model.spatial_merge_size**2
    expected_output_patches = [
        math.prod(grid_thw) // merge_factor for grid_thw in grid_thw_list
    ]

    # Without this check, a result with missing images would pass the
    # per-image loop below unnoticed.
    assert len(output_tuple) == len(expected_output_patches)
    for i, output in enumerate(output_tuple):
        assert output.shape[0] == expected_output_patches[i]
        assert output.shape[1] == vision_model.out_hidden_size


@pytest.mark.parametrize("spatial_merge_size", [2, 4])
def test_simple_mrope_vision_model_spatial_merge(spatial_merge_size: int):
    """Test SimpleMRopeVisionModel with different spatial merge sizes.

    The model merges each image on its own and drops each image's leftover
    rows. The expected row count is therefore the sum of per-image floor
    divisions. The floor of the total is only guaranteed to match when
    every image divides evenly.
    """
    device = current_platform.device_type

    grid_thw_list = [[1, 4, 4], [1, 6, 6]]  # Two images
    pixel_values_list = []

    for grid_thw in grid_thw_list:
        num_patches = math.prod(grid_thw)
        image_pixels = torch.randn(num_patches, 768, device=device)
        pixel_values_list.append(image_pixels)

    pixel_values = torch.cat(pixel_values_list, dim=0)
    vision_model = SimpleMRopeVisionModel(
        spatial_merge_size=spatial_merge_size).to(device)

    with torch.inference_mode():
        output = vision_model(pixel_values, grid_thw_list)

    # Verify output dimensions based on per-image spatial merging
    merge_factor = spatial_merge_size**2
    expected_output_patches = sum(
        math.prod(grid_thw) // merge_factor for grid_thw in grid_thw_list)

    assert output.shape[0] == expected_output_patches
    assert output.shape[1] == vision_model.out_hidden_size