Unverified commit 7046e56f, authored by Philip Meier and committed by GitHub
Browse files

add segmentation reference consistency tests (#6591)



* add segmentation reference consistency tests

* fall back to smoke tests for resize

* add test for RandomCrop
Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>
parent 0a946d5b
......@@ -69,7 +69,8 @@ class PILImagePair(TensorLikePair):
def _process_inputs(self, actual, expected, *, id, allow_subclasses):
actual, expected = [
to_image_tensor(input) if not isinstance(input, torch.Tensor) else input for input in [actual, expected]
to_image_tensor(input) if not isinstance(input, torch.Tensor) else features.Image(input)
for input in [actual, expected]
]
# This broadcast is needed, because `features.Mask`'s can have a 2D shape, but converting the equivalent PIL
# image to a tensor adds a singleton leading dimension.
......
import enum
import inspect
import random
from collections import defaultdict
from importlib.machinery import SourceFileLoader
from pathlib import Path
......@@ -16,13 +18,15 @@ from prototype_common_utils import (
make_image,
make_images,
make_label,
make_segmentation_mask,
)
from torchvision import transforms as legacy_transforms
from torchvision._utils import sequence_to_str
from torchvision.prototype import features, transforms as prototype_transforms
from torchvision.prototype.transforms import functional as F
from torchvision.prototype.transforms._utils import query_chw
from torchvision.prototype.transforms.functional import to_image_pil
DEFAULT_MAKE_IMAGES_KWARGS = dict(color_spaces=[features.ColorSpace.RGB], extra_dims=[(4,)])
......@@ -852,10 +856,12 @@ class TestAATransforms:
assert_equal(expected_output, output)
# Import reference detection transforms here for consistency checks
# torchvision/references/detection/transforms.py
ref_det_filepath = Path(__file__).parent.parent / "references" / "detection" / "transforms.py"
det_transforms = SourceFileLoader(ref_det_filepath.stem, ref_det_filepath.as_posix()).load_module()
def import_transforms_from_references(reference, *, references_root=None):
    """Load ``<references_root>/<reference>/transforms.py`` as a fresh, standalone module.

    Every reference ships a file called ``transforms.py``. The deprecated
    ``SourceFileLoader.load_module()`` reuses an existing ``sys.modules`` entry of the same name,
    so loading a second reference would re-execute its source inside the module object of the
    first one. Building the module from a spec yields an independent module object per call and
    leaves ``sys.modules`` untouched.

    Args:
        reference: Name of the sub-directory below the references folder, e.g. ``"detection"``.
        references_root: Directory that contains the reference sub-directories. Defaults to the
            ``references`` folder two levels above this file.

    Returns:
        The executed module object; its ``__name__`` is the file stem (``"transforms"``).

    Raises:
        FileNotFoundError: If the ``transforms.py`` file does not exist.
    """
    # Imported locally so that the function does not depend on additional file-level imports.
    import importlib.util

    if references_root is None:
        references_root = Path(__file__).parent.parent / "references"
    filepath = Path(references_root) / reference / "transforms.py"

    spec = importlib.util.spec_from_file_location(filepath.stem, filepath.as_posix())
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module
det_transforms = import_transforms_from_references("detection")
class TestRefDetTransforms:
......@@ -873,7 +879,7 @@ class TestRefDetTransforms:
yield (pil_image, target)
tensor_image = torch.randint(0, 256, size=(3, *size), dtype=torch.uint8)
tensor_image = torch.Tensor(make_image(size=size, color_space=features.ColorSpace.RGB))
target = {
"boxes": make_bounding_box(image_size=size, format="XYXY", extra_dims=(num_objects,), dtype=torch.float),
"labels": make_label(extra_dims=(num_objects,), categories=80),
......@@ -883,7 +889,7 @@ class TestRefDetTransforms:
yield (tensor_image, target)
feature_image = features.Image(torch.randint(0, 256, size=(3, *size), dtype=torch.uint8))
feature_image = make_image(size=size, color_space=features.ColorSpace.RGB)
target = {
"boxes": make_bounding_box(image_size=size, format="XYXY", extra_dims=(num_objects,), dtype=torch.float),
"labels": make_label(extra_dims=(num_objects,), categories=80),
......@@ -927,3 +933,165 @@ class TestRefDetTransforms:
expected_output = t_ref(*dp)
assert_equal(expected_output, output)
seg_transforms = import_transforms_from_references("segmentation")
# We need this transform for two reasons:
# 1. transforms.RandomCrop pads images and masks of insufficient size differently than its namesake in the
#    segmentation references. Thus, we cannot use it with `pad_if_needed=True`.
# 2. transforms.Pad only supports a fixed padding, but the segmentation datasets don't have a fixed image size.
class PadIfSmaller(prototype_transforms.Transform):
    """Pad a sample so that neither its height nor its width is smaller than ``size``.

    The padding list is built as ``[0, 0, missing_width, missing_height]`` and handed to ``F.pad``
    (left, top, right, bottom in torchvision's ``pad`` convention). Inputs that are already large
    enough are returned unchanged.

    Args:
        size: Minimum height and width of the output.
        fill: Fill specification; normalized once with ``_setup_fill_arg`` and looked up by the
            type of each input when padding is applied.
    """

    def __init__(self, size, fill=0):
        super().__init__()
        self.size = size
        self.fill = prototype_transforms._geometry._setup_fill_arg(fill)

    def _get_params(self, sample):
        """Compute the padding for ``sample`` and whether any padding is required at all."""
        _, height, width = query_chw(sample)
        missing_width = max(self.size - width, 0)
        missing_height = max(self.size - height, 0)
        padding = [0, 0, missing_width, missing_height]
        return {"padding": padding, "needs_padding": any(padding)}

    def _transform(self, inpt, params):
        """Pad ``inpt`` with the fill registered for its type, or pass it through untouched."""
        if params["needs_padding"]:
            # The fill is selected by the exact type of the input, e.g. images vs. masks.
            fill = F._geometry._convert_fill_arg(self.fill[type(inpt)])
            return F.pad(inpt, padding=params["padding"], fill=fill)
        return inpt
class TestRefSegTransforms:
    """Consistency checks of the prototype transforms against the segmentation reference transforms."""

    def make_datapoints(self, supports_pil=True, image_dtype=torch.uint8):
        """Yield ``(dp, dp_ref)`` pairs for the prototype and the reference transform.

        ``dp`` holds the image in one of several container types (PIL image if ``supports_pil``,
        ``torch.Tensor`` and the unmodified feature image) together with the feature mask.
        ``dp_ref`` holds the image as PIL image (or ``torch.Tensor`` if PIL is not supported) and
        the mask as PIL image.
        """
        spatial_size = (256, 640)
        category_count = 21

        converters = [to_image_pil] if supports_pil else []
        converters += [torch.Tensor, lambda x: x]

        for convert in converters:
            # A fresh image / mask pair is created for every converter: image first, then mask.
            feature_image = make_image(size=spatial_size, color_space=features.ColorSpace.RGB, dtype=image_dtype)
            feature_mask = make_segmentation_mask(size=spatial_size, num_categories=category_count, dtype=torch.uint8)

            dp = (convert(feature_image), feature_mask)

            if supports_pil:
                reference_image = to_image_pil(feature_image)
            else:
                reference_image = torch.Tensor(feature_image)
            dp_ref = (reference_image, to_image_pil(feature_mask))

            yield dp, dp_ref

    def set_seed(self, seed=12):
        """Seed both random number sources used by the transforms under test."""
        for seed_fn in (torch.manual_seed, random.seed):
            seed_fn(seed)

    def check(self, t, t_ref, data_kwargs=None):
        """Run ``t`` and ``t_ref`` under the same seed on every datapoint and compare the outputs."""
        kwargs = data_kwargs or dict()
        for dp, dp_ref in self.make_datapoints(**kwargs):
            self.set_seed()
            actual = t(dp)

            self.set_seed()
            expected = t_ref(*dp_ref)

            assert_equal(actual, expected)

    @pytest.mark.parametrize(
        ("t_ref", "t", "data_kwargs"),
        [
            (
                seg_transforms.RandomHorizontalFlip(flip_prob=1.0),
                prototype_transforms.RandomHorizontalFlip(p=1.0),
                dict(),
            ),
            (
                seg_transforms.RandomHorizontalFlip(flip_prob=0.0),
                prototype_transforms.RandomHorizontalFlip(p=0.0),
                dict(),
            ),
            (
                seg_transforms.RandomCrop(size=480),
                prototype_transforms.Compose(
                    [
                        PadIfSmaller(size=480, fill=defaultdict(lambda: 0, {features.Mask: 255})),
                        prototype_transforms.RandomCrop(size=480),
                    ]
                ),
                dict(),
            ),
            (
                seg_transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
                prototype_transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
                dict(supports_pil=False, image_dtype=torch.float),
            ),
        ],
    )
    def test_common(self, t_ref, t, data_kwargs):
        self.check(t, t_ref, data_kwargs)

    def check_resize(self, mocker, t_ref, t):
        """Smoke test: both transforms must call their ``resize`` with matching arguments.

        The resize functions are mocked, so no output is compared. For every datapoint each mock
        has to be called exactly twice, receive the very same objects that were passed in as first
        positional argument, and the second positional argument of the prototype call has to equal
        the reference one wrapped in a list.
        """
        mock = mocker.patch("torchvision.prototype.transforms._geometry.F.resize")
        mock_ref = mocker.patch("torchvision.transforms.functional.resize")

        def first_positional_args(resize_mock):
            return [call[0][0] for call in resize_mock.call_args_list]

        for dp, dp_ref in self.make_datapoints():
            mock.reset_mock()
            mock_ref.reset_mock()

            self.set_seed()
            t(dp)
            assert mock.call_count == 2
            for actual, expected in zip(first_positional_args(mock), dp):
                assert actual is expected

            self.set_seed()
            t_ref(*dp_ref)
            assert mock_ref.call_count == 2
            for actual, expected in zip(first_positional_args(mock_ref), dp_ref):
                assert actual is expected

            for call, call_ref in zip(mock.call_args_list, mock_ref.call_args_list):
                second_arg = call[0][1]
                second_arg_ref = call_ref[0][1]
                assert second_arg == [second_arg_ref]

    def test_random_resize_train(self, mocker):
        base_size = 520
        min_size, max_size = base_size // 2, base_size * 2

        original_randint = torch.randint

        def patched_randint(a, b, *other_args, **kwargs):
            # Only a plain call with `()` as the single extra argument is redirected to
            # `random.randint`; every other call is forwarded to the real `torch.randint`.
            delegate = kwargs or len(other_args) > 1 or other_args[0] != ()
            if not delegate:
                return random.randint(a, b)

            return original_randint(a, b, *other_args, **kwargs)

        # torch.randint is redirected to random.randint on the prototype side, because the
        # reference module is not imported normally and thus cannot be patched by name.
        t = prototype_transforms.RandomResize(min_size=min_size, max_size=max_size, antialias=True)
        mocker.patch(
            "torchvision.prototype.transforms._geometry.torch.randint",
            new=patched_randint,
        )

        t_ref = seg_transforms.RandomResize(min_size=min_size, max_size=max_size)

        self.check_resize(mocker, t_ref, t)

    def test_random_resize_eval(self, mocker):
        torch.manual_seed(0)
        base_size = 520

        # With min_size == max_size the reference RandomResize is compared against a fixed Resize.
        t = prototype_transforms.Resize(size=base_size, antialias=True)
        t_ref = seg_transforms.RandomResize(min_size=base_size, max_size=base_size)

        self.check_resize(mocker, t_ref, t)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment