Unverified Commit f168b857 authored by Siqi Yan, committed by GitHub
Browse files

Unit Test for run_dp_sharded_vision_model (#19103)


Signed-off-by: Siqi Yan <siqi@meta.com>
Co-authored-by: Siqi Yan <siqi@meta.com>
parent da511d54
...@@ -9,12 +9,21 @@ from typing import TYPE_CHECKING, NamedTuple, Optional ...@@ -9,12 +9,21 @@ from typing import TYPE_CHECKING, NamedTuple, Optional
import numpy as np import numpy as np
import pytest import pytest
import torch
import torch.multiprocessing as mp
from PIL import Image, ImageChops from PIL import Image, ImageChops
from tests.utils import multi_gpu_test
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed.parallel_state import (init_distributed_environment,
initialize_model_parallel)
from vllm.multimodal.image import convert_image_mode from vllm.multimodal.image import convert_image_mode
from vllm.multimodal.inputs import PlaceholderRange from vllm.multimodal.inputs import PlaceholderRange
from vllm.multimodal.utils import (MediaConnector, from vllm.multimodal.utils import (MediaConnector,
merge_and_sort_multimodal_metadata) merge_and_sort_multimodal_metadata,
run_dp_sharded_vision_model)
from vllm.platforms import current_platform
from vllm.utils import get_open_port, update_environment_variables
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.multimodal.hasher import MultiModalHashDict from vllm.multimodal.hasher import MultiModalHashDict
...@@ -413,3 +422,90 @@ def test_merge_and_sort_multimodal_metadata_with_interleaving(): ...@@ -413,3 +422,90 @@ def test_merge_and_sort_multimodal_metadata_with_interleaving():
assert modalities == expected_modalities assert modalities == expected_modalities
assert ranges == expected_ranges assert ranges == expected_ranges
assert hashes == expected_hashes assert hashes == expected_hashes
class SimpleLinearModel(torch.nn.Module):
    """Tiny stand-in for a vision model: flatten the input, then one linear layer.

    The default ``input_dim`` of ``3 * 224 * 224`` corresponds to a flattened
    3x224x224 input; ``output_dim`` is the size of the produced feature vector.
    """

    def __init__(self, input_dim: int = 3 * 224 * 224, output_dim: int = 32):
        super().__init__()
        self.flatten = torch.nn.Flatten()
        self.linear = torch.nn.Linear(input_dim, output_dim)

    def forward(self, x: torch.Tensor):
        """Return the linear projection of ``x`` after flattening it."""
        return self.linear(self.flatten(x))
# Batch sizes: 1 = single image, 4 = small batch,
# 5 = odd batch size (for testing padding).
@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize("batch_size", [1, 4, 5])
def test_run_dp_sharded_vision_model(batch_size: int):
    """Spawn one worker process per rank and run the sharded-vs-direct check.

    Each spawned process executes ``run_dp_sharded_vision_model_vs_direct``;
    ``mp.spawn`` supplies the process index as the first positional argument,
    followed by the values in ``args``.
    """
    world_size = 2
    # A single port is picked here and shared by every spawned process.
    master_port = get_open_port()
    worker_args = (world_size, batch_size, master_port)

    mp.spawn(
        run_dp_sharded_vision_model_vs_direct,
        args=worker_args,
        nprocs=world_size,
    )
def run_dp_sharded_vision_model_vs_direct(local_rank: int, world_size: int,
                                          batch_size: int, master_port: int):
    """
    Worker body: check that run_dp_sharded_vision_model matches a direct
    forward pass of the same model on the same input.

    Args:
        local_rank: Index of this process; also selects the CUDA device.
        world_size: Number of processes taking part; used as the tensor
            model parallel size.
        batch_size: Number of 3x224x224 random inputs to generate.
        master_port: Port exported as MASTER_PORT for process-group setup.

    Raises:
        AssertionError: If the tensor-parallel world size, the output shape,
            or the output values do not match expectations.
    """
    # Seed before any tensor/model creation.
    # NOTE(review): presumably this seeds every RNG identically on each rank
    # so all ranks build the same input and weights - confirm against
    # current_platform.seed_everything.
    current_platform.seed_everything(0)

    # Bind this process to its own GPU and make it the default device so the
    # tensors and model created below are allocated there.
    device = torch.device(f"cuda:{local_rank}")
    torch.cuda.set_device(device)
    torch.set_default_device(device)

    # Export the rendezvous settings before initializing the distributed
    # environment.
    rank_str = str(local_rank)
    update_environment_variables({
        'RANK': rank_str,
        'LOCAL_RANK': rank_str,
        'WORLD_SIZE': str(world_size),
        'MASTER_ADDR': 'localhost',
        'MASTER_PORT': str(master_port),
    })

    init_distributed_environment()
    initialize_model_parallel(tensor_model_parallel_size=world_size)

    # Random input first, then the model, so RNG consumption order is fixed.
    image_input = torch.randn(batch_size, 3, 224, 224)
    vision_model = SimpleLinearModel()

    with torch.inference_mode():
        # Reference result: the whole batch through the model in one call.
        direct_output = vision_model(image_input)
        # Result under test: the same batch through the sharded helper.
        sharded_output = run_dp_sharded_vision_model(image_input, vision_model)

    # The parallel group must have been set up with the requested size.
    assert get_tensor_model_parallel_world_size() == world_size

    # Shapes must agree before values are compared.
    assert direct_output.shape == sharded_output.shape

    # Values must agree within a small tolerance.
    assert torch.allclose(direct_output, sharded_output, rtol=1e-5, atol=1e-5)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment