Unverified Commit 544a4070 authored by vfdev's avatar vfdev Committed by GitHub
Browse files

[proto] Added consistency tests for detection transforms (#6566)



* [proto] Added consistency tests for detection transforms

* Updated tests according to the review

* More updates
Co-authored-by: default avatarPhilip Meier <github.pmeier@posteo.de>
parent e0e95380
import enum
import inspect
from importlib.machinery import SourceFileLoader
from pathlib import Path
import numpy as np
import PIL.Image
import pytest
import torch
from prototype_common_utils import ArgsKwargs, assert_equal, make_images
from prototype_common_utils import (
ArgsKwargs,
assert_equal,
make_bounding_box,
make_detection_mask,
make_image,
make_images,
make_label,
)
from torchvision import transforms as legacy_transforms
from torchvision._utils import sequence_to_str
from torchvision.prototype import features, transforms as prototype_transforms
......@@ -840,3 +850,80 @@ class TestAATransforms:
output = t(inpt)
assert_equal(expected_output, output)
# Import reference detection transforms here for consistency checks
# torchvision/references/detection/transforms.py
ref_det_filepath = Path(__file__).parent.parent / "references" / "detection" / "transforms.py"
det_transforms = SourceFileLoader(ref_det_filepath.stem, ref_det_filepath.as_posix()).load_module()
class TestRefDetTransforms:
def make_datapoints(self, with_mask=True):
size = (600, 800)
num_objects = 22
pil_image = to_image_pil(make_image(size=size, color_space=features.ColorSpace.RGB))
target = {
"boxes": make_bounding_box(image_size=size, format="XYXY", extra_dims=(num_objects,), dtype=torch.float),
"labels": make_label(extra_dims=(num_objects,), categories=80),
}
if with_mask:
target["masks"] = make_detection_mask(size=size, num_objects=num_objects, dtype=torch.long)
yield (pil_image, target)
tensor_image = torch.randint(0, 256, size=(3, *size), dtype=torch.uint8)
target = {
"boxes": make_bounding_box(image_size=size, format="XYXY", extra_dims=(num_objects,), dtype=torch.float),
"labels": make_label(extra_dims=(num_objects,), categories=80),
}
if with_mask:
target["masks"] = make_detection_mask(size=size, num_objects=num_objects, dtype=torch.long)
yield (tensor_image, target)
feature_image = features.Image(torch.randint(0, 256, size=(3, *size), dtype=torch.uint8))
target = {
"boxes": make_bounding_box(image_size=size, format="XYXY", extra_dims=(num_objects,), dtype=torch.float),
"labels": make_label(extra_dims=(num_objects,), categories=80),
}
if with_mask:
target["masks"] = make_detection_mask(size=size, num_objects=num_objects, dtype=torch.long)
yield (feature_image, target)
@pytest.mark.parametrize(
"t_ref, t, data_kwargs",
[
(det_transforms.RandomHorizontalFlip(p=1.0), prototype_transforms.RandomHorizontalFlip(p=1.0), {}),
(det_transforms.RandomIoUCrop(), prototype_transforms.RandomIoUCrop(), {"with_mask": False}),
(det_transforms.RandomZoomOut(), prototype_transforms.RandomZoomOut(), {"with_mask": False}),
(det_transforms.ScaleJitter((1024, 1024)), prototype_transforms.ScaleJitter((1024, 1024)), {}),
(
det_transforms.FixedSizeCrop((1024, 1024), fill=0),
prototype_transforms.FixedSizeCrop((1024, 1024), fill=0),
{},
),
(
det_transforms.RandomShortestSize(
min_size=(480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), max_size=1333
),
prototype_transforms.RandomShortestSize(
min_size=(480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), max_size=1333
),
{},
),
],
)
def test_transform(self, t_ref, t, data_kwargs):
for dp in self.make_datapoints(**data_kwargs):
# We should use prototype transform first as reference transform performs inplace target update
torch.manual_seed(12)
output = t(dp)
torch.manual_seed(12)
expected_output = t_ref(*dp)
assert_equal(expected_output, output)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment