"...backends/trtllm/performance_sweeps/benchmark_agg.slurm" did not exist on "a7badb855c8ebb3ef6c4295e6e5144034897fb7b"
test_utils.py 26.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import base64
5
import math
6
import mimetypes
7
8
import os
from tempfile import NamedTemporaryFile, TemporaryDirectory
9
from typing import TYPE_CHECKING, NamedTuple
10
11
12

import numpy as np
import pytest
zhuwenwen's avatar
zhuwenwen committed
13
import os
14
15
import torch
import torch.multiprocessing as mp
16
from PIL import Image, ImageChops
17

18
19
20
21
from tests.utils import multi_gpu_test
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed.parallel_state import (init_distributed_environment,
                                             initialize_model_parallel)
22
from vllm.multimodal.image import convert_image_mode
23
from vllm.multimodal.inputs import PlaceholderRange
24
from vllm.multimodal.utils import (MediaConnector, argsort_mm_positions,
25
26
                                   get_load_balance_assignment,
                                   run_dp_sharded_mrope_vision_model,
27
28
29
                                   run_dp_sharded_vision_model)
from vllm.platforms import current_platform
from vllm.utils import get_open_port, update_environment_variables
zhuwenwen's avatar
zhuwenwen committed
30
from ..utils import models_path_prefix, urls_port
31

32
33
34
if TYPE_CHECKING:
    from vllm.multimodal.inputs import MultiModalPlaceholderDict

# Test different image extensions (JPG/PNG) and formats (gray/RGB/RGBA).
# NOTE(review): these are served from localhost on `urls_port`, presumably a
# local mirror of the upstream test images -- confirm the server is running.
TEST_IMAGE_URLS = [
    f"http://localhost:{urls_port}/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg",
    f"http://localhost:{urls_port}/Grayscale_8bits_palette_sample_image.png",
    f"http://localhost:{urls_port}/Venn_diagram_rgb.svg/1280px-Venn_diagram_rgb.svg.png",
    f"http://localhost:{urls_port}/RGBA_comp.png",
]

# Remote video files (MP4 and AVI) used by `test_fetch_video_http`.
TEST_VIDEO_URLS = [
    "https://www.bogotobogo.com/python/OpenCV_Python/images/mean_shift_tracking/slow_traffic_small.mp4",
    "https://github.com/opencv/opencv/raw/refs/tags/4.12.0/samples/data/vtest.avi",
]

@pytest.fixture(scope="module")
def url_images() -> dict[str, Image.Image]:
    """Fetch every image in ``TEST_IMAGE_URLS`` once per test module.

    Returns:
        Mapping from each image URL to its decoded PIL image.
    """
    connector = MediaConnector()

    return {
        image_url: connector.fetch_image(image_url)
        for image_url in TEST_IMAGE_URLS
    }


def get_supported_suffixes() -> tuple[str, ...]:
    """Return the image file suffixes exercised by the base64 fetch test.

    Returns:
        The OpenAI-documented vision formats followed by the extra formats
        this project supports.
    """
    # We should at least test the file types mentioned in GPT-4 with Vision
    openai_supported_suffixes = ('.png', '.jpeg', '.jpg', '.webp', '.gif')

    # Additional file types that are supported by us
    extra_supported_suffixes = ('.bmp', '.tiff')

    return openai_supported_suffixes + extra_supported_suffixes


def _image_equals(a: Image.Image, b: Image.Image) -> bool:
    """Return True if ``a`` and ``b`` contain identical pixel data.

    ``b`` is first converted to ``a``'s mode, so images that differ only in
    mode (e.g. palette vs. RGB) can still compare equal. ``np.array_equal``
    compares shapes first, so images of different sizes compare unequal
    instead of raising or broadcasting, and the result is a plain ``bool``.
    """
    return np.array_equal(np.asarray(a),
                          np.asarray(convert_image_mode(b, a.mode)))


@pytest.mark.asyncio
@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
async def test_fetch_image_http(image_url: str):
    """Sync and async HTTP image fetches must decode to the same pixels."""
    connector = MediaConnector()

    image_sync = connector.fetch_image(image_url)
    image_async = await connector.fetch_image_async(image_url)

    assert _image_equals(image_sync, image_async)


@pytest.mark.asyncio
@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
@pytest.mark.parametrize("suffix", get_supported_suffixes())
async def test_fetch_image_base64(url_images: dict[str, Image.Image],
                                  image_url: str, suffix: str):
    """Round-trip each test image through a base64 ``data:`` URL.

    The image is re-encoded in the format implied by ``suffix``, embedded in
    a data URL and fetched back both synchronously and asynchronously. The
    test is skipped when no MIME type is known for the suffix or when PIL
    cannot write the image's mode in that format.
    """
    connector = MediaConnector()
    url_image = url_images[image_url]

    try:
        mime_type = Image.MIME[Image.registered_extensions()[suffix]]
    except KeyError:
        try:
            mime_type = mimetypes.types_map[suffix]
        except KeyError:
            pytest.skip('No MIME type')

    with NamedTemporaryFile(suffix=suffix) as f:
        try:
            url_image.save(f.name)
        except Exception as e:
            # Some (mode, format) pairs cannot be encoded, e.g. RGBA -> JPEG.
            # Guard on `e.args`: not every exception carries a message, and
            # an IndexError here would mask the real failure.
            if e.args and e.args[0] == 'cannot write mode RGBA as JPEG':
                pytest.skip('Conversion not supported')

            raise

        base64_image = base64.b64encode(f.read()).decode("utf-8")
        data_url = f"data:{mime_type};base64,{base64_image}"

        data_image_sync = connector.fetch_image(data_url)
        # Only a lossless re-encode can be expected to reproduce the source
        # pixels; for lossy formats we only check that the data URL decodes.
        if _image_equals(url_image, Image.open(f)):
            assert _image_equals(url_image, data_image_sync)

        data_image_async = await connector.fetch_image_async(data_url)
        assert _image_equals(data_image_sync, data_image_async)


@pytest.mark.asyncio
@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
async def test_fetch_image_local_files(image_url: str):
    """Fetch images from ``file://`` URLs and check local-path restrictions.

    A connector created with ``allowed_local_media_path`` may read files in
    that directory but must reject paths that escape it (``ValueError``). A
    connector without it must refuse local files (``RuntimeError``). Both
    the sync and async fetch paths are checked.
    """
    connector = MediaConnector()

    with TemporaryDirectory() as temp_dir:
        local_connector = MediaConnector(allowed_local_media_path=temp_dir)
        filename = os.path.basename(image_url)

        origin_image = connector.fetch_image(image_url)
        origin_image.save(os.path.join(temp_dir, filename),
                          quality=100,
                          icc_profile=origin_image.info.get('icc_profile'))

        local_url = f"file://{temp_dir}/{filename}"
        # Same file name, but one directory above the allowed root.
        escaping_url = f"file://{temp_dir}/../{filename}"

        image_async = await local_connector.fetch_image_async(local_url)
        image_sync = local_connector.fetch_image(local_url)
        # Check that the images are equal
        assert not ImageChops.difference(image_sync, image_async).getbbox()

        with pytest.raises(ValueError, match="must be a subpath"):
            await local_connector.fetch_image_async(escaping_url)
        with pytest.raises(RuntimeError, match="Cannot load local files"):
            await connector.fetch_image_async(escaping_url)

        with pytest.raises(ValueError, match="must be a subpath"):
            local_connector.fetch_image(escaping_url)
        with pytest.raises(RuntimeError, match="Cannot load local files"):
            connector.fetch_image(escaping_url)


@pytest.mark.asyncio
async def test_fetch_image_local_files_with_space_in_name():
    """A local image whose file name contains spaces must be fetchable.

    The first test image is saved under a name with spaces and fetched via a
    ``file://`` URL with both the async and the sync API; the two results
    must be pixel-identical.
    """
    connector = MediaConnector()
    source_url = TEST_IMAGE_URLS[0]

    with TemporaryDirectory() as temp_dir:
        local_connector = MediaConnector(allowed_local_media_path=temp_dir)

        filename = "file name with space.jpg"
        source_image = connector.fetch_image(source_url)
        source_image.save(os.path.join(temp_dir, filename),
                          quality=100,
                          icc_profile=source_image.info.get('icc_profile'))

        file_url = f"file://{temp_dir}/{filename}"
        try:
            image_async = await local_connector.fetch_image_async(file_url)
            image_sync = local_connector.fetch_image(file_url)
        except FileNotFoundError as err:
            pytest.fail(f"Failed to fetch image with space in name: {err}")

        # Both fetch paths must yield identical pixels.
        assert not ImageChops.difference(image_sync, image_async).getbbox()


@pytest.mark.asyncio
async def test_fetch_image_error_conversion():
    """Undecodable image data must surface as ``ValueError``.

    The payload is valid base64 (it decodes to the text
    ``hello_vllm_community``) but is not a PNG, so image decoding fails on
    both the async and the sync fetch path.
    """
    broken_img = "data:image/png;base64,aGVsbG9fdmxsbV9jb21tdW5pdHkK"
    connector = MediaConnector()

    # PIL.UnidentifiedImageError should be converted to ValueError
    with pytest.raises(ValueError):
        await connector.fetch_image_async(broken_img)
    with pytest.raises(ValueError):
        connector.fetch_image(broken_img)


@pytest.mark.asyncio
@pytest.mark.parametrize("video_url", TEST_VIDEO_URLS)
@pytest.mark.parametrize("num_frames", [-1, 32, 1800])
async def test_fetch_video_http(video_url: str, num_frames: int):
    """Sync and async video fetches must return identical frames and metadata.

    ``num_frames`` is forwarded to the video loader through
    ``media_io_kwargs``.
    """
    connector = MediaConnector(
        media_io_kwargs={"video": {
            "num_frames": num_frames,
        }})

    video_sync, metadata_sync = connector.fetch_video(video_url)
    video_async, metadata_async = await connector.fetch_video_async(video_url)
    assert np.array_equal(video_sync, video_async)
    assert metadata_sync == metadata_async


# Used for `test_argsort_mm_positions`.
class TestCase(NamedTuple):
    """One input/expectation pair for ``argsort_mm_positions``."""

    # Placeholder ranges keyed by modality name (e.g. "image", "audio").
    mm_positions: "MultiModalPlaceholderDict"
    # Expected (modality, index-within-modality) pairs in sorted order.
    expected_modality_idxs: list[tuple[str, int]]

    # The class name matches pytest's `Test*` collection pattern; opt out so
    # pytest does not try to collect this NamedTuple (which would only emit
    # a "has a __new__ constructor" collection warning).
    __test__ = False


def test_argsort_mm_positions():
    """``argsort_mm_positions`` orders multimodal items by placeholder offset.

    Each case lists placeholder ranges per modality (possibly unsorted and
    interleaved across modalities) together with the expected
    ``(modality, index)`` sequence in ascending-offset order.
    """
    test_cases = [
        # Single modality
        ## Internally sorted
        TestCase(
            mm_positions={
                "image": [
                    PlaceholderRange(offset=0, length=2),
                    PlaceholderRange(offset=3, length=2),
                ]
            },
            expected_modality_idxs=[
                ("image", 0),
                ("image", 1),
            ],
        ),
        ## Internally unsorted
        TestCase(
            mm_positions={
                "image": [
                    PlaceholderRange(offset=3, length=2),
                    PlaceholderRange(offset=0, length=2),
                ]
            },
            expected_modality_idxs=[
                ("image", 1),
                ("image", 0),
            ],
        ),

        # Two modalities
        ## Internally sorted
        TestCase(
            mm_positions={
                "image": [
                    PlaceholderRange(offset=7, length=4),
                    PlaceholderRange(offset=11, length=5),
                ],
                "audio": [
                    PlaceholderRange(offset=0, length=2),
                    PlaceholderRange(offset=2, length=3),
                ]
            },
            expected_modality_idxs=[
                ("audio", 0),
                ("audio", 1),
                ("image", 0),
                ("image", 1),
            ],
        ),
        ## Interleaved, internally sorted
        TestCase(
            mm_positions={
                "image": [
                    PlaceholderRange(offset=0, length=4),
                    PlaceholderRange(offset=8, length=2),
                ],
                "audio": [
                    PlaceholderRange(offset=5, length=2),
                    PlaceholderRange(offset=11, length=4),
                ]
            },
            expected_modality_idxs=[
                ("image", 0),
                ("audio", 0),
                ("image", 1),
                ("audio", 1),
            ],
        ),
        ## Interleaved, internally unsorted
        TestCase(
            mm_positions={
                "image": [
                    PlaceholderRange(offset=8, length=2),
                    PlaceholderRange(offset=0, length=4),
                ],
                "audio": [
                    PlaceholderRange(offset=11, length=4),
                    PlaceholderRange(offset=5, length=2),
                ]
            },
            expected_modality_idxs=[
                ("image", 1),
                ("audio", 1),
                ("image", 0),
                ("audio", 0),
            ],
        ),

        # Three modalities
        ## Internally sorted
        TestCase(
            mm_positions={
                "image": [
                    PlaceholderRange(offset=15, length=7),
                    PlaceholderRange(offset=22, length=8),
                ],
                "audio": [
                    PlaceholderRange(offset=0, length=2),
                ],
                "video": [
                    PlaceholderRange(offset=3, length=4),
                    PlaceholderRange(offset=7, length=5),
                    PlaceholderRange(offset=12, length=6),
                ]
            },
            expected_modality_idxs=[
                ("audio", 0),
                ("video", 0),
                ("video", 1),
                ("video", 2),
                ("image", 0),
                ("image", 1),
            ],
        ),
        ## Interleaved, internally sorted
        TestCase(
            mm_positions={
                "image": [
                    PlaceholderRange(offset=0, length=2),
                    PlaceholderRange(offset=2, length=3),
                    PlaceholderRange(offset=20, length=4),
                ],
                "audio": [
                    PlaceholderRange(offset=5, length=2),
                ],
                "video": [
                    PlaceholderRange(offset=8, length=5),
                ]
            },
            expected_modality_idxs=[
                ("image", 0),
                ("image", 1),
                ("audio", 0),
                ("video", 0),
                ("image", 2),
            ],
        ),
        ## Interleaved, internally unsorted
        TestCase(
            mm_positions={
                "image": [
                    PlaceholderRange(offset=0, length=2),
                    PlaceholderRange(offset=20, length=4),
                    PlaceholderRange(offset=2, length=3),
                ],
                "audio": [
                    PlaceholderRange(offset=5, length=2),
                ],
                "video": [
                    PlaceholderRange(offset=8, length=5),
                ]
            },
            expected_modality_idxs=[
                ("image", 0),
                ("image", 2),
                ("audio", 0),
                ("video", 0),
                ("image", 1),
            ],
        ),
    ]

    for mm_positions, expected_modality_idxs in test_cases:
        modality_idxs = argsort_mm_positions(mm_positions)

        assert modality_idxs == expected_modality_idxs


class SimpleLinearModel(torch.nn.Module):
    """Flatten-then-project toy "vision" model used by the DP sharding test.

    Args:
        input_dim: Number of features per sample after flattening; the
            default matches a ``(3, 224, 224)`` input.
        output_dim: Width of the produced embedding.
    """

    def __init__(self, input_dim: int = 3 * 224 * 224, output_dim: int = 32):
        super().__init__()
        self.flatten = torch.nn.Flatten()
        self.linear = torch.nn.Linear(input_dim, output_dim)

    def forward(self, x: torch.Tensor):
        # Collapse every non-batch dimension, then apply the projection.
        return self.linear(self.flatten(x))


@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize(
    "batch_size",
    [
        1,  # Single image
        4,  # Small batch
        5,  # Odd batch size (for testing padding)
    ],
)
def test_run_dp_sharded_vision_model(batch_size: int):
    """Spawn one worker per rank comparing sharded vs. direct execution."""
    world_size = 2
    master_port = get_open_port()
    # Launch processes
    mp.spawn(
        run_dp_sharded_vision_model_vs_direct,
        args=(world_size, batch_size, master_port),
        nprocs=world_size,
    )


def run_dp_sharded_vision_model_vs_direct(local_rank: int, world_size: int,
                                          batch_size: int, master_port: int):
    """
    Test that run_dp_sharded_vision_model produces the same results as
    calling the model directly.

    Args:
        local_rank: Index of this process (passed by ``mp.spawn``); also
            used as the device index and distributed rank.
        world_size: Number of processes / tensor-parallel ranks.
        batch_size: Number of random images in the input batch.
        master_port: Rendezvous port on localhost.
    """
    # Set random seed for reproducibility. Every rank uses the same seed, so
    # presumably all ranks draw the same input and model weights.
    current_platform.seed_everything(0)

    device = f"{current_platform.device_name}:{local_rank}"
    current_platform.set_device(device)
    torch.set_default_device(device)

    update_environment_variables({
        'RANK': str(local_rank),
        'LOCAL_RANK': str(local_rank),
        'WORLD_SIZE': str(world_size),
        'MASTER_ADDR': 'localhost',
        'MASTER_PORT': str(master_port),
    })

    # initialize distributed
    init_distributed_environment()
    initialize_model_parallel(tensor_model_parallel_size=world_size)

    # Create a test input tensor
    image_input = torch.randn(batch_size, 3, 224, 224)

    # Create a simple linear model
    vision_model = SimpleLinearModel()

    with torch.inference_mode():
        # Run the model directly on the full input
        direct_output = vision_model(image_input)
        # Run the model through the sharded function
        sharded_output = run_dp_sharded_vision_model(image_input, vision_model)

    # Check that the world size is setup correctly
    assert get_tensor_model_parallel_world_size() == world_size

    # Check that the outputs have the same shape
    assert direct_output.shape == sharded_output.shape

    # Check that the outputs are close (they should be identical)
    assert torch.allclose(direct_output, sharded_output, rtol=1e-5, atol=1e-5)


@pytest.mark.parametrize(
    "sizes,num_gpus,expected_shuffle_indices,expected_gpu_sample_counts,"
    "expected_grouped_sizes_per_gpu,test_description",
    [
        # Empty input
        ([], 2, [], [0, 0], [0, 0], "empty input"),

        # Fewer samples than GPUs
        ([100, 200], 4, [1, 0], [1, 1, 0, 0], [200, 100, 0, 0],
         "fewer samples than GPUs"),

        # Single GPU
        ([100, 200, 300], 1, [2, 1, 0], [3], [600], "single GPU"),

        # Balanced assignment
        ([100, 100, 100, 100], 2, [0, 2, 1, 3], [2, 2], [200, 200],
         "balanced assignment"),

        # Unbalanced sizes - this one is trickier since the algorithm is greedy
        ([1000, 100, 200, 50], 2, [0, 2, 1, 3], [1, 3], [1000, 350],
         "unbalanced sizes"),
    ],
)
def test_get_load_balance_assignment_cases(sizes, num_gpus,
                                           expected_shuffle_indices,
                                           expected_gpu_sample_counts,
                                           expected_grouped_sizes_per_gpu,
                                           test_description):
    """Test get_load_balance_assignment with various input cases.

    ``test_description`` only labels the case; it is not asserted on.
    """
    shuffle_indices, gpu_sample_counts, grouped_sizes_per_gpu = (
        get_load_balance_assignment(sizes, num_gpus=num_gpus))

    # Structural invariants shared by every case
    assert len(shuffle_indices) == len(sizes)
    assert len(gpu_sample_counts) == num_gpus
    assert len(grouped_sizes_per_gpu) == num_gpus
    assert sum(gpu_sample_counts) == len(sizes)

    # Exact expected assignment for this case
    assert shuffle_indices == expected_shuffle_indices
    assert gpu_sample_counts == expected_gpu_sample_counts
    assert grouped_sizes_per_gpu == expected_grouped_sizes_per_gpu


class SimpleMRopeVisionModel(torch.nn.Module):
    """A simple vision model for testing mrope functionality.

    Each patch row is projected by a linear layer. Spatial merging is then
    simulated by averaging consecutive groups of ``spatial_merge_size ** 2``
    patches within each image.

    Args:
        spatial_merge_size: Side length of the merge window; each output row
            averages ``spatial_merge_size ** 2`` patch embeddings.
        out_hidden_size: Width of the output embeddings.
        in_features: Width of each input patch row. Defaults to 768, the
            value that used to be hard-coded.
    """

    def __init__(self,
                 spatial_merge_size: int = 2,
                 out_hidden_size: int = 64,
                 in_features: int = 768):
        super().__init__()
        self.spatial_merge_size = spatial_merge_size
        self.out_hidden_size = out_hidden_size
        self.linear = torch.nn.Linear(in_features, out_hidden_size)

    def forward(self, pixel_values: torch.Tensor,
                grid_thw_list: list[list[int]]):
        """Simple forward pass that simulates spatial merging.

        Args:
            pixel_values: Patch rows of all images concatenated along dim 0;
                image ``i`` owns the next ``prod(grid_thw_list[i])`` rows.
            grid_thw_list: ``[t, h, w]`` per image.

        Returns:
            Tensor of shape ``(N, out_hidden_size)`` where ``N`` is the sum
            over images of ``prod(grid_thw) // spatial_merge_size ** 2``.
            Trailing patches that do not fill a whole merge group are dropped,
            and an image with fewer patches than one group contributes no rows.
        """
        # Apply linear transformation
        embeddings = self.linear(pixel_values)

        merge_factor = self.spatial_merge_size * self.spatial_merge_size

        merged_embeddings = []
        start_idx = 0
        for grid_thw in grid_thw_list:
            num_patches = math.prod(grid_thw)
            end_idx = start_idx + num_patches

            # Get patches for this image
            image_patches = embeddings[start_idx:end_idx]

            merged_patches = num_patches // merge_factor
            if merged_patches > 0:
                # Average each consecutive group of `merge_factor` patches.
                reshaped = image_patches[:merged_patches * merge_factor].view(
                    merged_patches, merge_factor, -1)
                merged_embeddings.append(reshaped.mean(dim=1))

            start_idx = end_idx

        if merged_embeddings:
            return torch.cat(merged_embeddings, dim=0)
        return torch.empty((0, self.out_hidden_size),
                           device=pixel_values.device,
                           dtype=pixel_values.dtype)


@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize(
    "batch_size",
    [
        1,  # Single image
        3,  # Small batch
        5,  # Odd batch size (for testing padding)
    ],
)
def test_run_dp_sharded_mrope_vision_model(batch_size: int):
    """Spawn two ranks comparing sharded and direct mrope vision runs."""
    world_size = 2
    spawn_args = (world_size, batch_size, get_open_port())
    # Launch processes
    mp.spawn(
        run_dp_sharded_mrope_vision_model_vs_direct,
        args=spawn_args,
        nprocs=world_size,
    )


def run_dp_sharded_mrope_vision_model_vs_direct(local_rank: int,
                                                world_size: int,
                                                batch_size: int,
                                                master_port: int):
    """
    Test that run_dp_sharded_mrope_vision_model produces the same results as
    calling the model directly.

    Every rank runs the sharded call; only rank 0 also computes the direct
    reference and compares the two.
    """
    # Same seed on every rank, set before any input or weight is drawn.
    current_platform.seed_everything(0)
    device = f"{current_platform.device_name}:{local_rank}"
    current_platform.set_device(device)
    torch.set_default_device(device)

    update_environment_variables({
        'RANK': str(local_rank),
        'LOCAL_RANK': str(local_rank),
        'WORLD_SIZE': str(world_size),
        'MASTER_ADDR': 'localhost',
        'MASTER_PORT': str(master_port),
    })

    init_distributed_environment()
    initialize_model_parallel(tensor_model_parallel_size=world_size)

    # Image i is a (1, 4 + i, 4 + i) grid so sizes vary across the batch.
    grid_thw_list = [[1, 4 + i, 4 + i] for i in range(batch_size)]
    # Random 768-wide patch rows, drawn one image at a time in batch order.
    pixel_values = torch.cat(
        [torch.randn(math.prod(thw), 768) for thw in grid_thw_list], dim=0)

    vision_model = SimpleMRopeVisionModel()
    is_rank_zero = local_rank == 0

    if is_rank_zero:
        with torch.inference_mode():
            direct_output = vision_model(pixel_values, grid_thw_list)

    with torch.inference_mode():
        per_image_outputs = run_dp_sharded_mrope_vision_model(
            vision_model, pixel_values, grid_thw_list)
        sharded_output = torch.cat(per_image_outputs, dim=0)

    assert get_tensor_model_parallel_world_size() == world_size

    if is_rank_zero:
        # Same shape, and numerically identical up to tolerance.
        assert direct_output.shape == sharded_output.shape
        assert torch.allclose(direct_output,
                              sharded_output,
                              rtol=1e-5,
                              atol=1e-5)


@multi_gpu_test(num_gpus=2)
def test_run_dp_sharded_mrope_vision_model_empty_input():
    """Spawn two ranks that each shard an empty batch."""
    world_size = 2
    port = get_open_port()
    mp.spawn(
        run_dp_sharded_mrope_vision_model_empty_input_worker,
        args=(world_size, port),
        nprocs=world_size,
    )


def run_dp_sharded_mrope_vision_model_empty_input_worker(
        local_rank: int, world_size: int, master_port: int):
    """Test run_dp_sharded_mrope_vision_model with empty input.

    With zero images the sharded call must return an empty result.
    """
    device = f"{current_platform.device_name}:{local_rank}"
    current_platform.set_device(device)
    torch.set_default_device(device)

    dist_env = {
        'RANK': str(local_rank),
        'LOCAL_RANK': str(local_rank),
        'WORLD_SIZE': str(world_size),
        'MASTER_ADDR': 'localhost',
        'MASTER_PORT': str(master_port),
    }
    update_environment_variables(dist_env)

    init_distributed_environment()
    initialize_model_parallel(tensor_model_parallel_size=world_size)

    # No patch rows (768 features wide) and no image grids.
    pixel_values = torch.empty((0, 768))
    grid_thw_list: list[list[int]] = []
    vision_model = SimpleMRopeVisionModel()

    with torch.inference_mode():
        result = run_dp_sharded_mrope_vision_model(vision_model, pixel_values,
                                                   grid_thw_list)

    assert len(result) == 0


@multi_gpu_test(num_gpus=4)
def test_run_dp_sharded_mrope_vision_model_uneven_load():
    """Spawn four ranks that shard a few differently sized images."""
    world_size = 4
    port = get_open_port()
    mp.spawn(
        run_dp_sharded_mrope_vision_model_uneven_load_worker,
        args=(world_size, port),
        nprocs=world_size,
    )


def run_dp_sharded_mrope_vision_model_uneven_load_worker(
        local_rank: int, world_size: int, master_port: int):
    """Test run_dp_sharded_mrope_vision_model with uneven load distribution.

    Three images of very different sizes are sharded across ``world_size``
    ranks. There must be exactly one output per image, each with
    ``prod(grid_thw) // spatial_merge_size ** 2`` rows of width
    ``out_hidden_size``.
    """
    # Set up distributed environment
    current_platform.seed_everything(123)
    device = f"{current_platform.device_name}:{local_rank}"
    current_platform.set_device(device)
    torch.set_default_device(device)

    update_environment_variables({
        'RANK': str(local_rank),
        'LOCAL_RANK': str(local_rank),
        'WORLD_SIZE': str(world_size),
        'MASTER_ADDR': 'localhost',
        'MASTER_PORT': str(master_port),
    })

    init_distributed_environment()
    initialize_model_parallel(tensor_model_parallel_size=world_size)

    # Create images with very different sizes
    grid_thw_list = [
        [1, 2, 2],  # Small: 4 patches
        [1, 8, 8],  # Large: 64 patches
        [1, 3, 3],  # Medium: 9 patches
    ]

    pixel_values_list = []
    for grid_thw in grid_thw_list:
        num_patches = math.prod(grid_thw)
        image_pixels = torch.randn(num_patches, 768)
        pixel_values_list.append(image_pixels)

    pixel_values = torch.cat(pixel_values_list, dim=0)
    vision_model = SimpleMRopeVisionModel()

    # Should handle uneven distribution without errors
    with torch.inference_mode():
        output_tuple = run_dp_sharded_mrope_vision_model(
            vision_model, pixel_values, grid_thw_list)

    # Verify output shape is reasonable
    merge_factor = vision_model.spatial_merge_size**2
    expected_output_patches = [
        math.prod(grid_thw) // merge_factor for grid_thw in grid_thw_list
    ]

    # One output per input image. Without this check, a short result would
    # let the per-image loop below pass vacuously.
    assert len(output_tuple) == len(grid_thw_list)
    for output, expected_patches in zip(output_tuple, expected_output_patches):
        assert output.shape[0] == expected_patches
        assert output.shape[1] == vision_model.out_hidden_size


@pytest.mark.parametrize("spatial_merge_size", [2, 4])
def test_simple_mrope_vision_model_spatial_merge(spatial_merge_size: int):
    """Test SimpleMRopeVisionModel with different spatial merge sizes."""
    device = current_platform.device_type

    grid_thw_list = [[1, 4, 4], [1, 6, 6]]  # Two images
    pixel_values_list = []

    for grid_thw in grid_thw_list:
        num_patches = math.prod(grid_thw)
        image_pixels = torch.randn(num_patches, 768, device=device)
        pixel_values_list.append(image_pixels)

    pixel_values = torch.cat(pixel_values_list, dim=0)
    vision_model = SimpleMRopeVisionModel(
        spatial_merge_size=spatial_merge_size).to(device)

    with torch.inference_mode():
        output = vision_model(pixel_values, grid_thw_list)

    # The model merges each image separately and drops leftover patches that
    # do not fill a whole merge group. The expected row count is therefore
    # the sum of per-image floor divisions. floor(total / merge_factor) can
    # over-count when the per-image remainders add up to a full group.
    merge_factor = spatial_merge_size**2
    expected_output_patches = sum(
        math.prod(grid_thw) // merge_factor for grid_thw in grid_thw_list)

    assert output.shape[0] == expected_output_patches
    assert output.shape[1] == vision_model.out_hidden_size