Unverified Commit 97e0ea9c authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

Expand prototype transforms consistency tests (#6516)

* improve test framework and add more consistency configs

* improve error messages

* increase coverage for Resize

* improve comparison
parent 0d69e35c
import functools
import itertools
import PIL.Image
import pytest
import torch.testing
import torch
from test_prototype_transforms_functional import make_images
from torch.testing._comparison import assert_equal as _assert_equal, TensorLikePair
from torchvision import transforms as legacy_transforms
from torchvision.prototype import features, transforms as prototype_transforms
from torchvision.prototype.transforms.functional import to_image_pil, to_image_tensor
class ImagePair(TensorLikePair):
def _process_inputs(self, actual, expected, *, id, allow_subclasses):
return super()._process_inputs(
*[to_image_tensor(input) if isinstance(input, PIL.Image.Image) else input for input in [actual, expected]],
id=id,
allow_subclasses=allow_subclasses,
)
assert_equal = functools.partial(_assert_equal, pair_types=[ImagePair], rtol=0, atol=0)
DEFAULT_MAKE_IMAGES_KWARGS = dict(color_spaces=[features.ColorSpace.RGB], extra_dims=[(4,)])
......@@ -20,10 +37,17 @@ class ArgsKwargs:
yield self.kwargs
def __str__(self):
def short_repr(obj, max=20):
repr_ = repr(obj)
if len(repr_) <= max:
return repr_
return f"{repr_[:max//2]}...{repr_[-(max//2-3):]}"
return ", ".join(
itertools.chain(
[repr(arg) for arg in self.args],
[f"{param}={repr(kwarg)}" for param, kwarg in self.kwargs.items()],
[short_repr(arg) for arg in self.args],
[f"{param}={short_repr(kwarg)}" for param, kwarg in self.kwargs.items()],
)
)
......@@ -52,6 +76,10 @@ class ConsistencyConfig:
]
# These are here since both the prototype and legacy transform need to be constructed with the same random parameters
LINEAR_TRANSFORMATION_MEAN = torch.rand(36)
LINEAR_TRANSFORMATION_MATRIX = torch.rand([LINEAR_TRANSFORMATION_MEAN.numel()] * 2)
CONSISTENCY_CONFIGS = [
ConsistencyConfig(
prototype_transforms.Normalize,
......@@ -68,7 +96,18 @@ CONSISTENCY_CONFIGS = [
[
ArgsKwargs(32),
ArgsKwargs((32, 29)),
ArgsKwargs((31, 28), interpolation=prototype_transforms.InterpolationMode.BICUBIC),
ArgsKwargs((31, 28), interpolation=prototype_transforms.InterpolationMode.NEAREST),
ArgsKwargs((33, 26), interpolation=prototype_transforms.InterpolationMode.BICUBIC),
# FIXME: these are currently failing, since the new transform only supports the enum. The int input is
# already deprecated and scheduled to be removed in 0.15. Should we support ints on the prototype
# transform? I guess it depends if we roll out before 0.15 or not.
# ArgsKwargs((30, 27), interpolation=0),
# ArgsKwargs((35, 29), interpolation=2),
# ArgsKwargs((34, 25), interpolation=3),
ArgsKwargs(31, max_size=32),
ArgsKwargs(30, max_size=100),
ArgsKwargs((29, 32), antialias=False),
ArgsKwargs((28, 31), antialias=True),
],
),
ConsistencyConfig(
......@@ -79,6 +118,62 @@ CONSISTENCY_CONFIGS = [
ArgsKwargs((18, 13)),
],
),
ConsistencyConfig(
prototype_transforms.FiveCrop,
legacy_transforms.FiveCrop,
[
ArgsKwargs(18),
ArgsKwargs((18, 13)),
],
make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, sizes=[(20, 19)]),
),
ConsistencyConfig(
prototype_transforms.TenCrop,
legacy_transforms.TenCrop,
[
ArgsKwargs(18),
ArgsKwargs((18, 13)),
],
make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, sizes=[(20, 19)]),
),
ConsistencyConfig(
prototype_transforms.Pad,
legacy_transforms.Pad,
[
ArgsKwargs(3),
ArgsKwargs([3]),
ArgsKwargs([2, 3]),
ArgsKwargs([3, 2, 1, 4]),
ArgsKwargs(5, fill=1, padding_mode="constant"),
ArgsKwargs(5, padding_mode="edge"),
ArgsKwargs(5, padding_mode="reflect"),
ArgsKwargs(5, padding_mode="symmetric"),
],
),
ConsistencyConfig(
prototype_transforms.LinearTransformation,
legacy_transforms.LinearTransformation,
[
ArgsKwargs(LINEAR_TRANSFORMATION_MATRIX, LINEAR_TRANSFORMATION_MEAN),
],
# Make sure that the product of the height, width and number of channels matches the number of elements in
# `LINEAR_TRANSFORMATION_MEAN`. For example 2 * 6 * 3 == 4 * 3 * 3 == 36.
make_images_kwargs=dict(
DEFAULT_MAKE_IMAGES_KWARGS, sizes=[(2, 6), (4, 3)], color_spaces=[features.ColorSpace.RGB]
),
supports_pil=False,
),
ConsistencyConfig(
prototype_transforms.Grayscale,
legacy_transforms.Grayscale,
[
ArgsKwargs(num_output_channels=1),
ArgsKwargs(num_output_channels=3),
],
make_images_kwargs=dict(
DEFAULT_MAKE_IMAGES_KWARGS, color_spaces=[features.ColorSpace.RGB, features.ColorSpace.GRAY]
),
),
]
......@@ -113,8 +208,8 @@ def test_consistency(prototype_transform_cls, legacy_transform_cls, args_kwargs,
output_legacy_tensor = legacy(image_tensor)
except Exception as exc:
raise pytest.UsageError(
f"Transforming a tensor image with shape {tuple(image.shape)} failed with the error above. "
"This means that you need to specify the parameters passed to `make_images` through the "
f"Transforming a tensor image with shape {tuple(image.shape)} failed in the legacy transform with the "
f"error above. This means that you need to specify the parameters passed to `make_images` through the "
"`make_images_kwargs` of the `ConsistencyConfig`."
) from exc
......@@ -122,16 +217,14 @@ def test_consistency(prototype_transform_cls, legacy_transform_cls, args_kwargs,
output_prototype_tensor = prototype(image_tensor)
except Exception as exc:
raise AssertionError(
f"Transforming a tensor image with shape {tuple(image.shape)} failed with the error above. "
f"This means there is a consistency bug either in `_get_params` "
f"or in the `is_simple_tensor` path in `_transform`."
f"Transforming a tensor image with shape {tuple(image.shape)} failed in the prototype transform with "
f"the error above. This means there is a consistency bug either in `_get_params` or in the "
f"`is_simple_tensor` path in `_transform`."
) from exc
torch.testing.assert_close(
assert_equal(
output_prototype_tensor,
output_legacy_tensor,
atol=0,
rtol=0,
msg=lambda msg: f"Tensor image consistency check failed with: \n\n{msg}",
)
......@@ -139,24 +232,38 @@ def test_consistency(prototype_transform_cls, legacy_transform_cls, args_kwargs,
output_prototype_image = prototype(image)
except Exception as exc:
raise AssertionError(
f"Transforming a feature image with shape {tuple(image.shape)} failed with the error above. "
f"This means there is a consistency bug either in `_get_params` "
f"or in the `features.Image` path in `_transform`."
f"Transforming a feature image with shape {tuple(image.shape)} failed in the prototype transform with "
f"the error above. This means there is a consistency bug either in `_get_params` or in the "
f"`features.Image` path in `_transform`."
) from exc
torch.testing.assert_close(
torch.Tensor(output_prototype_image),
assert_equal(
output_prototype_image,
output_prototype_tensor,
atol=0,
rtol=0,
msg=lambda msg: f"Output for feature and tensor images is not equal: \n\n{msg}",
)
if image_pil is not None:
torch.testing.assert_close(
to_image_tensor(prototype(image_pil)),
to_image_tensor(legacy(image_pil)),
atol=0,
rtol=0,
try:
output_legacy_pil = legacy(image_pil)
except Exception as exc:
raise pytest.UsageError(
f"Transforming a PIL image with shape {tuple(image.shape)} failed in the legacy transform with the "
f"error above. If this transform does not support PIL images, set `supports_pil=False` on the "
"`ConsistencyConfig`. "
) from exc
try:
output_prototype_pil = prototype(image_pil)
except Exception as exc:
raise AssertionError(
f"Transforming a PIL image with shape {tuple(image.shape)} failed in the prototype transform with "
f"the error above. This means there is a consistency bug either in `_get_params` or in the "
f"`PIL.Image.Image` path in `_transform`."
) from exc
assert_equal(
output_prototype_pil,
output_legacy_pil,
msg=lambda msg: f"PIL image consistency check failed with: \n\n{msg}",
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment