Unverified Commit 0563b18c authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

Expand prototype transforms consistency tests to all deterministic transformations (#6518)

* add test for coverage

* add tests for remaining deterministic transforms

* fix Lambda
parent 97e0ea9c
import enum
import functools import functools
import itertools import itertools
import numpy as np
import PIL.Image import PIL.Image
import pytest import pytest
...@@ -8,6 +11,7 @@ import torch ...@@ -8,6 +11,7 @@ import torch
from test_prototype_transforms_functional import make_images from test_prototype_transforms_functional import make_images
from torch.testing._comparison import assert_equal as _assert_equal, TensorLikePair from torch.testing._comparison import assert_equal as _assert_equal, TensorLikePair
from torchvision import transforms as legacy_transforms from torchvision import transforms as legacy_transforms
from torchvision._utils import sequence_to_str
from torchvision.prototype import features, transforms as prototype_transforms from torchvision.prototype import features, transforms as prototype_transforms
from torchvision.prototype.transforms.functional import to_image_pil, to_image_tensor from torchvision.prototype.transforms.functional import to_image_pil, to_image_tensor
...@@ -70,7 +74,7 @@ class ConsistencyConfig: ...@@ -70,7 +74,7 @@ class ConsistencyConfig:
args_kwargs, args_kwargs,
self.make_images_kwargs, self.make_images_kwargs,
self.supports_pil, self.supports_pil,
id=f"{self.prototype_cls.__name__}({args_kwargs})", id=f"{self.legacy_cls.__name__}({args_kwargs})",
) )
for args_kwargs in self.transform_args_kwargs for args_kwargs in self.transform_args_kwargs
] ]
...@@ -174,9 +178,83 @@ CONSISTENCY_CONFIGS = [ ...@@ -174,9 +178,83 @@ CONSISTENCY_CONFIGS = [
DEFAULT_MAKE_IMAGES_KWARGS, color_spaces=[features.ColorSpace.RGB, features.ColorSpace.GRAY] DEFAULT_MAKE_IMAGES_KWARGS, color_spaces=[features.ColorSpace.RGB, features.ColorSpace.GRAY]
), ),
), ),
ConsistencyConfig(
prototype_transforms.ConvertImageDtype,
legacy_transforms.ConvertImageDtype,
[
ArgsKwargs(torch.float16),
ArgsKwargs(torch.bfloat16),
ArgsKwargs(torch.float32),
ArgsKwargs(torch.float64),
ArgsKwargs(torch.uint8),
],
supports_pil=False,
),
ConsistencyConfig(
prototype_transforms.ToPILImage,
legacy_transforms.ToPILImage,
[ArgsKwargs()],
make_images_kwargs=dict(
color_spaces=[
features.ColorSpace.GRAY,
features.ColorSpace.GRAY_ALPHA,
features.ColorSpace.RGB,
features.ColorSpace.RGB_ALPHA,
],
extra_dims=[()],
),
supports_pil=False,
),
ConsistencyConfig(
prototype_transforms.Lambda,
legacy_transforms.Lambda,
[
ArgsKwargs(lambda image: image / 2),
],
# Technically, this also supports PIL, but it is overkill to write a function here that supports tensor and PIL
# images given that the transform does nothing but call it anyway.
supports_pil=False,
),
] ]
def test_automatic_coverage_deterministic():
legacy = {
name
for name, obj in legacy_transforms.__dict__.items()
if not name.startswith("_")
and isinstance(obj, type)
and not issubclass(obj, enum.Enum)
and name
not in {
"Compose",
# This framework is based on the assumption that the input image can always be a tensor and optionally a
# PIL image. The transforms below require a non-tensor input and thus have to be tested manually.
"PILToTensor",
"ToTensor",
}
}
# filter out random transformations
legacy = {name for name in legacy if "Random" not in name} - {
"AugMix",
"TrivialAugmentWide",
"GaussianBlur",
"RandAugment",
"AutoAugment",
"ColorJitter",
"ElasticTransform",
}
prototype = {config.legacy_cls.__name__ for config in CONSISTENCY_CONFIGS}
missing = legacy - prototype
if missing:
raise AssertionError(
f"The prototype transformations {sequence_to_str(sorted(missing), separate_last='and ')} "
f"are not checked for consistency although a legacy counterpart exists."
)
@pytest.mark.parametrize( @pytest.mark.parametrize(
("prototype_transform_cls", "legacy_transform_cls", "args_kwargs", "make_images_kwargs", "supports_pil"), ("prototype_transform_cls", "legacy_transform_cls", "args_kwargs", "make_images_kwargs", "supports_pil"),
itertools.chain.from_iterable(config.parametrization() for config in CONSISTENCY_CONFIGS), itertools.chain.from_iterable(config.parametrization() for config in CONSISTENCY_CONFIGS),
...@@ -204,11 +282,13 @@ def test_consistency(prototype_transform_cls, legacy_transform_cls, args_kwargs, ...@@ -204,11 +282,13 @@ def test_consistency(prototype_transform_cls, legacy_transform_cls, args_kwargs,
image_tensor = torch.Tensor(image) image_tensor = torch.Tensor(image)
image_pil = to_image_pil(image) if image.ndim == 3 and supports_pil else None image_pil = to_image_pil(image) if image.ndim == 3 and supports_pil else None
image_repr = f"[{tuple(image.shape)}, {str(image.dtype).rsplit('.')[-1]}]"
try: try:
output_legacy_tensor = legacy(image_tensor) output_legacy_tensor = legacy(image_tensor)
except Exception as exc: except Exception as exc:
raise pytest.UsageError( raise pytest.UsageError(
f"Transforming a tensor image with shape {tuple(image.shape)} failed in the legacy transform with the " f"Transforming a tensor image {image_repr} failed in the legacy transform with the "
f"error above. This means that you need to specify the parameters passed to `make_images` through the " f"error above. This means that you need to specify the parameters passed to `make_images` through the "
"`make_images_kwargs` of the `ConsistencyConfig`." "`make_images_kwargs` of the `ConsistencyConfig`."
) from exc ) from exc
...@@ -217,7 +297,7 @@ def test_consistency(prototype_transform_cls, legacy_transform_cls, args_kwargs, ...@@ -217,7 +297,7 @@ def test_consistency(prototype_transform_cls, legacy_transform_cls, args_kwargs,
output_prototype_tensor = prototype(image_tensor) output_prototype_tensor = prototype(image_tensor)
except Exception as exc: except Exception as exc:
raise AssertionError( raise AssertionError(
f"Transforming a tensor image with shape {tuple(image.shape)} failed in the prototype transform with " f"Transforming a tensor image with shape {image_repr} failed in the prototype transform with "
f"the error above. This means there is a consistency bug either in `_get_params` or in the " f"the error above. This means there is a consistency bug either in `_get_params` or in the "
f"`is_simple_tensor` path in `_transform`." f"`is_simple_tensor` path in `_transform`."
) from exc ) from exc
...@@ -232,7 +312,7 @@ def test_consistency(prototype_transform_cls, legacy_transform_cls, args_kwargs, ...@@ -232,7 +312,7 @@ def test_consistency(prototype_transform_cls, legacy_transform_cls, args_kwargs,
output_prototype_image = prototype(image) output_prototype_image = prototype(image)
except Exception as exc: except Exception as exc:
raise AssertionError( raise AssertionError(
f"Transforming a feature image with shape {tuple(image.shape)} failed in the prototype transform with " f"Transforming a feature image with shape {image_repr} failed in the prototype transform with "
f"the error above. This means there is a consistency bug either in `_get_params` or in the " f"the error above. This means there is a consistency bug either in `_get_params` or in the "
f"`features.Image` path in `_transform`." f"`features.Image` path in `_transform`."
) from exc ) from exc
...@@ -248,7 +328,7 @@ def test_consistency(prototype_transform_cls, legacy_transform_cls, args_kwargs, ...@@ -248,7 +328,7 @@ def test_consistency(prototype_transform_cls, legacy_transform_cls, args_kwargs,
output_legacy_pil = legacy(image_pil) output_legacy_pil = legacy(image_pil)
except Exception as exc: except Exception as exc:
raise pytest.UsageError( raise pytest.UsageError(
f"Transforming a PIL image with shape {tuple(image.shape)} failed in the legacy transform with the " f"Transforming a PIL image with shape {image_repr} failed in the legacy transform with the "
f"error above. If this transform does not support PIL images, set `supports_pil=False` on the " f"error above. If this transform does not support PIL images, set `supports_pil=False` on the "
"`ConsistencyConfig`. " "`ConsistencyConfig`. "
) from exc ) from exc
...@@ -257,7 +337,7 @@ def test_consistency(prototype_transform_cls, legacy_transform_cls, args_kwargs, ...@@ -257,7 +337,7 @@ def test_consistency(prototype_transform_cls, legacy_transform_cls, args_kwargs,
output_prototype_pil = prototype(image_pil) output_prototype_pil = prototype(image_pil)
except Exception as exc: except Exception as exc:
raise AssertionError( raise AssertionError(
f"Transforming a PIL image with shape {tuple(image.shape)} failed in the prototype transform with " f"Transforming a PIL image with shape {image_repr} failed in the prototype transform with "
f"the error above. This means there is a consistency bug either in `_get_params` or in the " f"the error above. This means there is a consistency bug either in `_get_params` or in the "
f"`PIL.Image.Image` path in `_transform`." f"`PIL.Image.Image` path in `_transform`."
) from exc ) from exc
...@@ -267,3 +347,25 @@ def test_consistency(prototype_transform_cls, legacy_transform_cls, args_kwargs, ...@@ -267,3 +347,25 @@ def test_consistency(prototype_transform_cls, legacy_transform_cls, args_kwargs,
output_legacy_pil, output_legacy_pil,
msg=lambda msg: f"PIL image consistency check failed with: \n\n{msg}", msg=lambda msg: f"PIL image consistency check failed with: \n\n{msg}",
) )
class TestToTensorTransforms:
def test_pil_to_tensor(self):
for image in make_images(extra_dims=[()]):
image_pil = to_image_pil(image)
prototype_transform = prototype_transforms.PILToTensor()
legacy_transform = legacy_transforms.PILToTensor()
assert_equal(prototype_transform(image_pil), legacy_transform(image_pil))
def test_to_tensor(self):
for image in make_images(extra_dims=[()]):
image_pil = to_image_pil(image)
image_numpy = np.array(image_pil)
prototype_transform = prototype_transforms.ToTensor()
legacy_transform = legacy_transforms.ToTensor()
assert_equal(prototype_transform(image_pil), legacy_transform(image_pil))
assert_equal(prototype_transform(image_numpy), legacy_transform(image_numpy))
...@@ -23,7 +23,7 @@ class Lambda(Transform): ...@@ -23,7 +23,7 @@ class Lambda(Transform):
self.types = types or (object,) self.types = types or (object,)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
if type(inpt) in self.types: if isinstance(inpt, self.types):
return self.fn(inpt) return self.fn(inpt)
else: else:
return inpt return inpt
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment