Unverified Commit 7046e56f authored by Philip Meier, committed by GitHub

add segmentation reference consistency tests (#6591)



* add segmentation reference consistency tests

* fall back to smoke tests for resize

* add test for RandomCrop
Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>
parent 0a946d5b
@@ -69,7 +69,8 @@ class PILImagePair(TensorLikePair):
     def _process_inputs(self, actual, expected, *, id, allow_subclasses):
         actual, expected = [
-            to_image_tensor(input) if not isinstance(input, torch.Tensor) else input for input in [actual, expected]
+            to_image_tensor(input) if not isinstance(input, torch.Tensor) else features.Image(input)
+            for input in [actual, expected]
         ]
         # This broadcast is needed, because `features.Mask`'s can have a 2D shape, but converting the equivalent PIL
         # image to a tensor adds a singleton leading dimension.
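The change above wraps tensor inputs in features.Image, presumably so both comparison operands travel through the same image path. The broadcast mentioned in the comment bridges a real shape mismatch; a standalone sketch of it (not part of the commit):

import torch

mask = torch.zeros(16, 16, dtype=torch.uint8)         # a features.Mask may be 2D: (H, W)
from_pil = torch.zeros(1, 16, 16, dtype=torch.uint8)  # its PIL round trip becomes (1, H, W)
assert torch.broadcast_tensors(mask, from_pil)[0].shape == (1, 16, 16)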
@@ -1,4 +1,6 @@
 import enum
 import inspect
+import random
+from collections import defaultdict
 from importlib.machinery import SourceFileLoader
 from pathlib import Path
@@ -16,13 +18,15 @@ from prototype_common_utils import (
     make_image,
     make_images,
     make_label,
+    make_segmentation_mask,
 )
 from torchvision import transforms as legacy_transforms
 from torchvision._utils import sequence_to_str
 from torchvision.prototype import features, transforms as prototype_transforms
+from torchvision.prototype.transforms import functional as F
+from torchvision.prototype.transforms._utils import query_chw
 from torchvision.prototype.transforms.functional import to_image_pil

 DEFAULT_MAKE_IMAGES_KWARGS = dict(color_spaces=[features.ColorSpace.RGB], extra_dims=[(4,)])
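make_segmentation_mask, newly imported here, is a test helper from prototype_common_utils. Roughly, it draws a random category index per pixel; a hedged sketch of the idea (the real helper supports more options):

import torch

def make_segmentation_mask_sketch(size=(256, 640), num_categories=21, dtype=torch.uint8):
    # every pixel holds a category index in [0, num_categories)
    return torch.randint(0, num_categories, size, dtype=dtype)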
@@ -852,10 +856,12 @@ class TestAATransforms:
         assert_equal(expected_output, output)


-# Import reference detection transforms here for consistency checks
-# torchvision/references/detection/transforms.py
-ref_det_filepath = Path(__file__).parent.parent / "references" / "detection" / "transforms.py"
-det_transforms = SourceFileLoader(ref_det_filepath.stem, ref_det_filepath.as_posix()).load_module()
+def import_transforms_from_references(reference):
+    ref_det_filepath = Path(__file__).parent.parent / "references" / reference / "transforms.py"
+    return SourceFileLoader(ref_det_filepath.stem, ref_det_filepath.as_posix()).load_module()
+
+
+det_transforms = import_transforms_from_references("detection")


 class TestRefDetTransforms:
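Factoring the loader into import_transforms_from_references lets the segmentation references be loaded the same way as the detection ones. For illustration, SourceFileLoader imports a module from a file path that is not on sys.path; the path below is a hypothetical stand-in:

from importlib.machinery import SourceFileLoader

# Load /tmp/example_transforms.py as a module object named "example_transforms".
module = SourceFileLoader("example_transforms", "/tmp/example_transforms.py").load_module()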
@@ -873,7 +879,7 @@ class TestRefDetTransforms:

         yield (pil_image, target)

-        tensor_image = torch.randint(0, 256, size=(3, *size), dtype=torch.uint8)
+        tensor_image = torch.Tensor(make_image(size=size, color_space=features.ColorSpace.RGB))
         target = {
             "boxes": make_bounding_box(image_size=size, format="XYXY", extra_dims=(num_objects,), dtype=torch.float),
             "labels": make_label(extra_dims=(num_objects,), categories=80),
@@ -883,7 +889,7 @@ class TestRefDetTransforms:

         yield (tensor_image, target)

-        feature_image = features.Image(torch.randint(0, 256, size=(3, *size), dtype=torch.uint8))
+        feature_image = make_image(size=size, color_space=features.ColorSpace.RGB)
         target = {
             "boxes": make_bounding_box(image_size=size, format="XYXY", extra_dims=(num_objects,), dtype=torch.float),
             "labels": make_label(extra_dims=(num_objects,), categories=80),
@@ -927,3 +933,165 @@ class TestRefDetTransforms:
            expected_output = t_ref(*dp)

            assert_equal(expected_output, output)
seg_transforms = import_transforms_from_references("segmentation")


# We need this transform for two reasons:
# 1. transforms.RandomCrop uses a different scheme to pad images and masks of insufficient size than its name
#    counterpart in the detection references. Thus, we cannot use it with `pad_if_needed=True`
# 2. transforms.Pad only supports a fixed padding, but the segmentation datasets don't have a fixed image size.
class PadIfSmaller(prototype_transforms.Transform):
    def __init__(self, size, fill=0):
        super().__init__()
        self.size = size
        self.fill = prototype_transforms._geometry._setup_fill_arg(fill)

    def _get_params(self, sample):
        _, height, width = query_chw(sample)
        padding = [0, 0, max(self.size - width, 0), max(self.size - height, 0)]
        needs_padding = any(padding)
        return dict(padding=padding, needs_padding=needs_padding)

    def _transform(self, inpt, params):
        if not params["needs_padding"]:
            return inpt

        fill = self.fill[type(inpt)]
        fill = F._geometry._convert_fill_arg(fill)

        return F.pad(inpt, padding=params["padding"], fill=fill)
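As a concrete check of the parameter computation (a standalone sketch, not part of the commit), with the test's 256x640 inputs and size=480 only the height gets padded:

size, height, width = 480, 256, 640
padding = [0, 0, max(size - width, 0), max(size - height, 0)]  # [left, top, right, bottom]
assert padding == [0, 0, 0, 224]  # width already exceeds 480; 224 rows are added at the bottom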
class TestRefSegTransforms:
    def make_datapoints(self, supports_pil=True, image_dtype=torch.uint8):
        size = (256, 640)
        num_categories = 21

        conv_fns = []
        if supports_pil:
            conv_fns.append(to_image_pil)
        conv_fns.extend([torch.Tensor, lambda x: x])

        for conv_fn in conv_fns:
            feature_image = make_image(size=size, color_space=features.ColorSpace.RGB, dtype=image_dtype)
            feature_mask = make_segmentation_mask(size=size, num_categories=num_categories, dtype=torch.uint8)

            dp = (conv_fn(feature_image), feature_mask)
            dp_ref = (
                to_image_pil(feature_image) if supports_pil else torch.Tensor(feature_image),
                to_image_pil(feature_mask),
            )

            yield dp, dp_ref

    def set_seed(self, seed=12):
        torch.manual_seed(seed)
        random.seed(seed)
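    # Seeding both RNGs is required: the prototype transforms draw their randomness from torch,
    # while the reference transforms use the stdlib `random` module.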
    def check(self, t, t_ref, data_kwargs=None):
        for dp, dp_ref in self.make_datapoints(**data_kwargs or dict()):

            self.set_seed()
            output = t(dp)

            self.set_seed()
            expected_output = t_ref(*dp_ref)

            assert_equal(output, expected_output)
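    # `t(dp)` passes the (image, mask) pair as one nested sample, as the prototype API expects,
    # while the reference transform takes image and target as separate positional arguments.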
    @pytest.mark.parametrize(
        ("t_ref", "t", "data_kwargs"),
        [
            (
                seg_transforms.RandomHorizontalFlip(flip_prob=1.0),
                prototype_transforms.RandomHorizontalFlip(p=1.0),
                dict(),
            ),
            (
                seg_transforms.RandomHorizontalFlip(flip_prob=0.0),
                prototype_transforms.RandomHorizontalFlip(p=0.0),
                dict(),
            ),
            (
                seg_transforms.RandomCrop(size=480),
                prototype_transforms.Compose(
                    [
                        PadIfSmaller(size=480, fill=defaultdict(lambda: 0, {features.Mask: 255})),
                        prototype_transforms.RandomCrop(size=480),
                    ]
                ),
                dict(),
            ),
            (
                seg_transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
                prototype_transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
                dict(supports_pil=False, image_dtype=torch.float),
            ),
        ],
    )
    def test_common(self, t_ref, t, data_kwargs):
        self.check(t, t_ref, data_kwargs)
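    # The reference RandomCrop pads undersized images with 0 and masks with 255 (the ignore index);
    # the `defaultdict` fill in the RandomCrop case above mirrors that behavior for the prototype pipeline.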
    def check_resize(self, mocker, t_ref, t):
        mock = mocker.patch("torchvision.prototype.transforms._geometry.F.resize")
        mock_ref = mocker.patch("torchvision.transforms.functional.resize")

        for dp, dp_ref in self.make_datapoints():
            mock.reset_mock()
            mock_ref.reset_mock()

            self.set_seed()
            t(dp)
            assert mock.call_count == 2
            assert all(
                actual is expected
                for actual, expected in zip([call_args[0][0] for call_args in mock.call_args_list], dp)
            )

            self.set_seed()
            t_ref(*dp_ref)
            assert mock_ref.call_count == 2
            assert all(
                actual is expected
                for actual, expected in zip([call_args[0][0] for call_args in mock_ref.call_args_list], dp_ref)
            )

            for args_kwargs, args_kwargs_ref in zip(mock.call_args_list, mock_ref.call_args_list):
                assert args_kwargs[0][1] == [args_kwargs_ref[0][1]]
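    # The last assertion compares the `size` argument of each resize call: the prototype transform
    # passes it as a one-element list, while the reference passes a bare int, hence the list wrapping.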
    def test_random_resize_train(self, mocker):
        base_size = 520
        min_size = base_size // 2
        max_size = base_size * 2

        randint = torch.randint

        def patched_randint(a, b, *other_args, **kwargs):
            if kwargs or len(other_args) > 1 or other_args[0] != ():
                return randint(a, b, *other_args, **kwargs)

            return random.randint(a, b)

        # We patch torch.randint -> random.randint here, because the reference module is loaded through
        # SourceFileLoader and thus cannot be patched like a normally imported module. Only the scalar
        # draw torch.randint(a, b, ()) is rerouted; all other calls fall through to the real torch.randint.
        t = prototype_transforms.RandomResize(min_size=min_size, max_size=max_size, antialias=True)
        mocker.patch(
            "torchvision.prototype.transforms._geometry.torch.randint",
            new=patched_randint,
        )

        t_ref = seg_transforms.RandomResize(min_size=min_size, max_size=max_size)

        self.check_resize(mocker, t_ref, t)
    def test_random_resize_eval(self, mocker):
        torch.manual_seed(0)
        base_size = 520

        t = prototype_transforms.Resize(size=base_size, antialias=True)
        t_ref = seg_transforms.RandomResize(min_size=base_size, max_size=base_size)

        self.check_resize(mocker, t_ref, t)
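Taken together, these tests reduce to one seeded A/B pattern. A distilled standalone sketch (assert_consistent, transform_new, transform_ref, and sample are illustrative names, not part of the commit):

import random

import torch


def assert_consistent(transform_new, transform_ref, sample, seed=12):
    # Run both pipelines from identical RNG states and compare the results.
    torch.manual_seed(seed)
    random.seed(seed)
    output = transform_new(sample)

    torch.manual_seed(seed)
    random.seed(seed)
    expected = transform_ref(*sample)

    torch.testing.assert_close(output, expected)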