Unverified Commit 97e0ea9c authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

Expand prototype transforms consistency tests (#6516)

* improve test framework and add more consistency configs

* improve error messages

* increase coverage for Resize

* improve comparison
parent 0d69e35c
import functools
import itertools import itertools
import PIL.Image
import pytest import pytest
import torch.testing
import torch
from test_prototype_transforms_functional import make_images from test_prototype_transforms_functional import make_images
from torch.testing._comparison import assert_equal as _assert_equal, TensorLikePair
from torchvision import transforms as legacy_transforms from torchvision import transforms as legacy_transforms
from torchvision.prototype import features, transforms as prototype_transforms from torchvision.prototype import features, transforms as prototype_transforms
from torchvision.prototype.transforms.functional import to_image_pil, to_image_tensor from torchvision.prototype.transforms.functional import to_image_pil, to_image_tensor
class ImagePair(TensorLikePair):
def _process_inputs(self, actual, expected, *, id, allow_subclasses):
return super()._process_inputs(
*[to_image_tensor(input) if isinstance(input, PIL.Image.Image) else input for input in [actual, expected]],
id=id,
allow_subclasses=allow_subclasses,
)
assert_equal = functools.partial(_assert_equal, pair_types=[ImagePair], rtol=0, atol=0)
DEFAULT_MAKE_IMAGES_KWARGS = dict(color_spaces=[features.ColorSpace.RGB], extra_dims=[(4,)]) DEFAULT_MAKE_IMAGES_KWARGS = dict(color_spaces=[features.ColorSpace.RGB], extra_dims=[(4,)])
...@@ -20,10 +37,17 @@ class ArgsKwargs: ...@@ -20,10 +37,17 @@ class ArgsKwargs:
yield self.kwargs yield self.kwargs
def __str__(self): def __str__(self):
def short_repr(obj, max=20):
repr_ = repr(obj)
if len(repr_) <= max:
return repr_
return f"{repr_[:max//2]}...{repr_[-(max//2-3):]}"
return ", ".join( return ", ".join(
itertools.chain( itertools.chain(
[repr(arg) for arg in self.args], [short_repr(arg) for arg in self.args],
[f"{param}={repr(kwarg)}" for param, kwarg in self.kwargs.items()], [f"{param}={short_repr(kwarg)}" for param, kwarg in self.kwargs.items()],
) )
) )
...@@ -52,6 +76,10 @@ class ConsistencyConfig: ...@@ -52,6 +76,10 @@ class ConsistencyConfig:
] ]
# These are here since both the prototype and legacy transform need to be constructed with the same random parameters
LINEAR_TRANSFORMATION_MEAN = torch.rand(36)
LINEAR_TRANSFORMATION_MATRIX = torch.rand([LINEAR_TRANSFORMATION_MEAN.numel()] * 2)
CONSISTENCY_CONFIGS = [ CONSISTENCY_CONFIGS = [
ConsistencyConfig( ConsistencyConfig(
prototype_transforms.Normalize, prototype_transforms.Normalize,
...@@ -68,7 +96,18 @@ CONSISTENCY_CONFIGS = [ ...@@ -68,7 +96,18 @@ CONSISTENCY_CONFIGS = [
[ [
ArgsKwargs(32), ArgsKwargs(32),
ArgsKwargs((32, 29)), ArgsKwargs((32, 29)),
ArgsKwargs((31, 28), interpolation=prototype_transforms.InterpolationMode.BICUBIC), ArgsKwargs((31, 28), interpolation=prototype_transforms.InterpolationMode.NEAREST),
ArgsKwargs((33, 26), interpolation=prototype_transforms.InterpolationMode.BICUBIC),
# FIXME: these are currently failing, since the new transform only supports the enum. The int input is
# already deprecated and scheduled to be removed in 0.15. Should we support ints on the prototype
# transform? I guess it depends if we roll out before 0.15 or not.
# ArgsKwargs((30, 27), interpolation=0),
# ArgsKwargs((35, 29), interpolation=2),
# ArgsKwargs((34, 25), interpolation=3),
ArgsKwargs(31, max_size=32),
ArgsKwargs(30, max_size=100),
ArgsKwargs((29, 32), antialias=False),
ArgsKwargs((28, 31), antialias=True),
], ],
), ),
ConsistencyConfig( ConsistencyConfig(
...@@ -79,6 +118,62 @@ CONSISTENCY_CONFIGS = [ ...@@ -79,6 +118,62 @@ CONSISTENCY_CONFIGS = [
ArgsKwargs((18, 13)), ArgsKwargs((18, 13)),
], ],
), ),
ConsistencyConfig(
prototype_transforms.FiveCrop,
legacy_transforms.FiveCrop,
[
ArgsKwargs(18),
ArgsKwargs((18, 13)),
],
make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, sizes=[(20, 19)]),
),
ConsistencyConfig(
prototype_transforms.TenCrop,
legacy_transforms.TenCrop,
[
ArgsKwargs(18),
ArgsKwargs((18, 13)),
],
make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, sizes=[(20, 19)]),
),
ConsistencyConfig(
prototype_transforms.Pad,
legacy_transforms.Pad,
[
ArgsKwargs(3),
ArgsKwargs([3]),
ArgsKwargs([2, 3]),
ArgsKwargs([3, 2, 1, 4]),
ArgsKwargs(5, fill=1, padding_mode="constant"),
ArgsKwargs(5, padding_mode="edge"),
ArgsKwargs(5, padding_mode="reflect"),
ArgsKwargs(5, padding_mode="symmetric"),
],
),
ConsistencyConfig(
prototype_transforms.LinearTransformation,
legacy_transforms.LinearTransformation,
[
ArgsKwargs(LINEAR_TRANSFORMATION_MATRIX, LINEAR_TRANSFORMATION_MEAN),
],
# Make sure that the product of the height, width and number of channels matches the number of elements in
# `LINEAR_TRANSFORMATION_MEAN`. For example 2 * 6 * 3 == 4 * 3 * 3 == 36.
make_images_kwargs=dict(
DEFAULT_MAKE_IMAGES_KWARGS, sizes=[(2, 6), (4, 3)], color_spaces=[features.ColorSpace.RGB]
),
supports_pil=False,
),
ConsistencyConfig(
prototype_transforms.Grayscale,
legacy_transforms.Grayscale,
[
ArgsKwargs(num_output_channels=1),
ArgsKwargs(num_output_channels=3),
],
make_images_kwargs=dict(
DEFAULT_MAKE_IMAGES_KWARGS, color_spaces=[features.ColorSpace.RGB, features.ColorSpace.GRAY]
),
),
] ]
...@@ -113,8 +208,8 @@ def test_consistency(prototype_transform_cls, legacy_transform_cls, args_kwargs, ...@@ -113,8 +208,8 @@ def test_consistency(prototype_transform_cls, legacy_transform_cls, args_kwargs,
output_legacy_tensor = legacy(image_tensor) output_legacy_tensor = legacy(image_tensor)
except Exception as exc: except Exception as exc:
raise pytest.UsageError( raise pytest.UsageError(
f"Transforming a tensor image with shape {tuple(image.shape)} failed with the error above. " f"Transforming a tensor image with shape {tuple(image.shape)} failed in the legacy transform with the "
"This means that you need to specify the parameters passed to `make_images` through the " f"error above. This means that you need to specify the parameters passed to `make_images` through the "
"`make_images_kwargs` of the `ConsistencyConfig`." "`make_images_kwargs` of the `ConsistencyConfig`."
) from exc ) from exc
...@@ -122,16 +217,14 @@ def test_consistency(prototype_transform_cls, legacy_transform_cls, args_kwargs, ...@@ -122,16 +217,14 @@ def test_consistency(prototype_transform_cls, legacy_transform_cls, args_kwargs,
output_prototype_tensor = prototype(image_tensor) output_prototype_tensor = prototype(image_tensor)
except Exception as exc: except Exception as exc:
raise AssertionError( raise AssertionError(
f"Transforming a tensor image with shape {tuple(image.shape)} failed with the error above. " f"Transforming a tensor image with shape {tuple(image.shape)} failed in the prototype transform with "
f"This means there is a consistency bug either in `_get_params` " f"the error above. This means there is a consistency bug either in `_get_params` or in the "
f"or in the `is_simple_tensor` path in `_transform`." f"`is_simple_tensor` path in `_transform`."
) from exc ) from exc
torch.testing.assert_close( assert_equal(
output_prototype_tensor, output_prototype_tensor,
output_legacy_tensor, output_legacy_tensor,
atol=0,
rtol=0,
msg=lambda msg: f"Tensor image consistency check failed with: \n\n{msg}", msg=lambda msg: f"Tensor image consistency check failed with: \n\n{msg}",
) )
...@@ -139,24 +232,38 @@ def test_consistency(prototype_transform_cls, legacy_transform_cls, args_kwargs, ...@@ -139,24 +232,38 @@ def test_consistency(prototype_transform_cls, legacy_transform_cls, args_kwargs,
output_prototype_image = prototype(image) output_prototype_image = prototype(image)
except Exception as exc: except Exception as exc:
raise AssertionError( raise AssertionError(
f"Transforming a feature image with shape {tuple(image.shape)} failed with the error above. " f"Transforming a feature image with shape {tuple(image.shape)} failed in the prototype transform with "
f"This means there is a consistency bug either in `_get_params` " f"the error above. This means there is a consistency bug either in `_get_params` or in the "
f"or in the `features.Image` path in `_transform`." f"`features.Image` path in `_transform`."
) from exc ) from exc
torch.testing.assert_close( assert_equal(
torch.Tensor(output_prototype_image), output_prototype_image,
output_prototype_tensor, output_prototype_tensor,
atol=0,
rtol=0,
msg=lambda msg: f"Output for feature and tensor images is not equal: \n\n{msg}", msg=lambda msg: f"Output for feature and tensor images is not equal: \n\n{msg}",
) )
if image_pil is not None: if image_pil is not None:
torch.testing.assert_close( try:
to_image_tensor(prototype(image_pil)), output_legacy_pil = legacy(image_pil)
to_image_tensor(legacy(image_pil)), except Exception as exc:
atol=0, raise pytest.UsageError(
rtol=0, f"Transforming a PIL image with shape {tuple(image.shape)} failed in the legacy transform with the "
f"error above. If this transform does not support PIL images, set `supports_pil=False` on the "
"`ConsistencyConfig`. "
) from exc
try:
output_prototype_pil = prototype(image_pil)
except Exception as exc:
raise AssertionError(
f"Transforming a PIL image with shape {tuple(image.shape)} failed in the prototype transform with "
f"the error above. This means there is a consistency bug either in `_get_params` or in the "
f"`PIL.Image.Image` path in `_transform`."
) from exc
assert_equal(
output_prototype_pil,
output_legacy_pil,
msg=lambda msg: f"PIL image consistency check failed with: \n\n{msg}", msg=lambda msg: f"PIL image consistency check failed with: \n\n{msg}",
) )
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment