Unverified Commit 544a4070 authored by vfdev's avatar vfdev Committed by GitHub
Browse files

[proto] Added consistency tests for detection transforms (#6566)



* [proto] Added consistency tests for detection transforms

* Updated tests according to the review

* More updates
Co-authored-by: default avatarPhilip Meier <github.pmeier@posteo.de>
parent e0e95380
import enum import enum
import inspect import inspect
from importlib.machinery import SourceFileLoader
from pathlib import Path
import numpy as np import numpy as np
import PIL.Image import PIL.Image
import pytest import pytest
import torch import torch
from prototype_common_utils import ArgsKwargs, assert_equal, make_images from prototype_common_utils import (
ArgsKwargs,
assert_equal,
make_bounding_box,
make_detection_mask,
make_image,
make_images,
make_label,
)
from torchvision import transforms as legacy_transforms from torchvision import transforms as legacy_transforms
from torchvision._utils import sequence_to_str from torchvision._utils import sequence_to_str
from torchvision.prototype import features, transforms as prototype_transforms from torchvision.prototype import features, transforms as prototype_transforms
...@@ -840,3 +850,80 @@ class TestAATransforms: ...@@ -840,3 +850,80 @@ class TestAATransforms:
output = t(inpt) output = t(inpt)
assert_equal(expected_output, output) assert_equal(expected_output, output)
# Import reference detection transforms here for consistency checks
# torchvision/references/detection/transforms.py
ref_det_filepath = Path(__file__).parent.parent / "references" / "detection" / "transforms.py"
det_transforms = SourceFileLoader(ref_det_filepath.stem, ref_det_filepath.as_posix()).load_module()
class TestRefDetTransforms:
def make_datapoints(self, with_mask=True):
size = (600, 800)
num_objects = 22
pil_image = to_image_pil(make_image(size=size, color_space=features.ColorSpace.RGB))
target = {
"boxes": make_bounding_box(image_size=size, format="XYXY", extra_dims=(num_objects,), dtype=torch.float),
"labels": make_label(extra_dims=(num_objects,), categories=80),
}
if with_mask:
target["masks"] = make_detection_mask(size=size, num_objects=num_objects, dtype=torch.long)
yield (pil_image, target)
tensor_image = torch.randint(0, 256, size=(3, *size), dtype=torch.uint8)
target = {
"boxes": make_bounding_box(image_size=size, format="XYXY", extra_dims=(num_objects,), dtype=torch.float),
"labels": make_label(extra_dims=(num_objects,), categories=80),
}
if with_mask:
target["masks"] = make_detection_mask(size=size, num_objects=num_objects, dtype=torch.long)
yield (tensor_image, target)
feature_image = features.Image(torch.randint(0, 256, size=(3, *size), dtype=torch.uint8))
target = {
"boxes": make_bounding_box(image_size=size, format="XYXY", extra_dims=(num_objects,), dtype=torch.float),
"labels": make_label(extra_dims=(num_objects,), categories=80),
}
if with_mask:
target["masks"] = make_detection_mask(size=size, num_objects=num_objects, dtype=torch.long)
yield (feature_image, target)
@pytest.mark.parametrize(
"t_ref, t, data_kwargs",
[
(det_transforms.RandomHorizontalFlip(p=1.0), prototype_transforms.RandomHorizontalFlip(p=1.0), {}),
(det_transforms.RandomIoUCrop(), prototype_transforms.RandomIoUCrop(), {"with_mask": False}),
(det_transforms.RandomZoomOut(), prototype_transforms.RandomZoomOut(), {"with_mask": False}),
(det_transforms.ScaleJitter((1024, 1024)), prototype_transforms.ScaleJitter((1024, 1024)), {}),
(
det_transforms.FixedSizeCrop((1024, 1024), fill=0),
prototype_transforms.FixedSizeCrop((1024, 1024), fill=0),
{},
),
(
det_transforms.RandomShortestSize(
min_size=(480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), max_size=1333
),
prototype_transforms.RandomShortestSize(
min_size=(480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), max_size=1333
),
{},
),
],
)
def test_transform(self, t_ref, t, data_kwargs):
for dp in self.make_datapoints(**data_kwargs):
# We should use prototype transform first as reference transform performs inplace target update
torch.manual_seed(12)
output = t(dp)
torch.manual_seed(12)
expected_output = t_ref(*dp)
assert_equal(expected_output, output)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment