Unverified Commit 0563b18c authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

Expand prototype transforms consistency tests to all deterministic transformations (#6518)

* add test for coverage

* add tests for remaining deterministic transforms

* fix Lambda
parent 97e0ea9c
import enum
import functools
import itertools
import numpy as np
import PIL.Image
import pytest
......@@ -8,6 +11,7 @@ import torch
from test_prototype_transforms_functional import make_images
from torch.testing._comparison import assert_equal as _assert_equal, TensorLikePair
from torchvision import transforms as legacy_transforms
from torchvision._utils import sequence_to_str
from torchvision.prototype import features, transforms as prototype_transforms
from torchvision.prototype.transforms.functional import to_image_pil, to_image_tensor
......@@ -70,7 +74,7 @@ class ConsistencyConfig:
args_kwargs,
self.make_images_kwargs,
self.supports_pil,
id=f"{self.prototype_cls.__name__}({args_kwargs})",
id=f"{self.legacy_cls.__name__}({args_kwargs})",
)
for args_kwargs in self.transform_args_kwargs
]
......@@ -174,9 +178,83 @@ CONSISTENCY_CONFIGS = [
DEFAULT_MAKE_IMAGES_KWARGS, color_spaces=[features.ColorSpace.RGB, features.ColorSpace.GRAY]
),
),
ConsistencyConfig(
prototype_transforms.ConvertImageDtype,
legacy_transforms.ConvertImageDtype,
[
ArgsKwargs(torch.float16),
ArgsKwargs(torch.bfloat16),
ArgsKwargs(torch.float32),
ArgsKwargs(torch.float64),
ArgsKwargs(torch.uint8),
],
supports_pil=False,
),
ConsistencyConfig(
prototype_transforms.ToPILImage,
legacy_transforms.ToPILImage,
[ArgsKwargs()],
make_images_kwargs=dict(
color_spaces=[
features.ColorSpace.GRAY,
features.ColorSpace.GRAY_ALPHA,
features.ColorSpace.RGB,
features.ColorSpace.RGB_ALPHA,
],
extra_dims=[()],
),
supports_pil=False,
),
ConsistencyConfig(
prototype_transforms.Lambda,
legacy_transforms.Lambda,
[
ArgsKwargs(lambda image: image / 2),
],
# Technically, this also supports PIL, but it is overkill to write a function here that supports tensor and PIL
# images given that the transform does nothing but call it anyway.
supports_pil=False,
),
]
def test_automatic_coverage_deterministic():
legacy = {
name
for name, obj in legacy_transforms.__dict__.items()
if not name.startswith("_")
and isinstance(obj, type)
and not issubclass(obj, enum.Enum)
and name
not in {
"Compose",
# This framework is based on the assumption that the input image can always be a tensor and optionally a
# PIL image. The transforms below require a non-tensor input and thus have to be tested manually.
"PILToTensor",
"ToTensor",
}
}
# filter out random transformations
legacy = {name for name in legacy if "Random" not in name} - {
"AugMix",
"TrivialAugmentWide",
"GaussianBlur",
"RandAugment",
"AutoAugment",
"ColorJitter",
"ElasticTransform",
}
prototype = {config.legacy_cls.__name__ for config in CONSISTENCY_CONFIGS}
missing = legacy - prototype
if missing:
raise AssertionError(
f"The prototype transformations {sequence_to_str(sorted(missing), separate_last='and ')} "
f"are not checked for consistency although a legacy counterpart exists."
)
@pytest.mark.parametrize(
("prototype_transform_cls", "legacy_transform_cls", "args_kwargs", "make_images_kwargs", "supports_pil"),
itertools.chain.from_iterable(config.parametrization() for config in CONSISTENCY_CONFIGS),
......@@ -204,11 +282,13 @@ def test_consistency(prototype_transform_cls, legacy_transform_cls, args_kwargs,
image_tensor = torch.Tensor(image)
image_pil = to_image_pil(image) if image.ndim == 3 and supports_pil else None
image_repr = f"[{tuple(image.shape)}, {str(image.dtype).rsplit('.')[-1]}]"
try:
output_legacy_tensor = legacy(image_tensor)
except Exception as exc:
raise pytest.UsageError(
f"Transforming a tensor image with shape {tuple(image.shape)} failed in the legacy transform with the "
f"Transforming a tensor image {image_repr} failed in the legacy transform with the "
f"error above. This means that you need to specify the parameters passed to `make_images` through the "
"`make_images_kwargs` of the `ConsistencyConfig`."
) from exc
......@@ -217,7 +297,7 @@ def test_consistency(prototype_transform_cls, legacy_transform_cls, args_kwargs,
output_prototype_tensor = prototype(image_tensor)
except Exception as exc:
raise AssertionError(
f"Transforming a tensor image with shape {tuple(image.shape)} failed in the prototype transform with "
f"Transforming a tensor image with shape {image_repr} failed in the prototype transform with "
f"the error above. This means there is a consistency bug either in `_get_params` or in the "
f"`is_simple_tensor` path in `_transform`."
) from exc
......@@ -232,7 +312,7 @@ def test_consistency(prototype_transform_cls, legacy_transform_cls, args_kwargs,
output_prototype_image = prototype(image)
except Exception as exc:
raise AssertionError(
f"Transforming a feature image with shape {tuple(image.shape)} failed in the prototype transform with "
f"Transforming a feature image with shape {image_repr} failed in the prototype transform with "
f"the error above. This means there is a consistency bug either in `_get_params` or in the "
f"`features.Image` path in `_transform`."
) from exc
......@@ -248,7 +328,7 @@ def test_consistency(prototype_transform_cls, legacy_transform_cls, args_kwargs,
output_legacy_pil = legacy(image_pil)
except Exception as exc:
raise pytest.UsageError(
f"Transforming a PIL image with shape {tuple(image.shape)} failed in the legacy transform with the "
f"Transforming a PIL image with shape {image_repr} failed in the legacy transform with the "
f"error above. If this transform does not support PIL images, set `supports_pil=False` on the "
"`ConsistencyConfig`. "
) from exc
......@@ -257,7 +337,7 @@ def test_consistency(prototype_transform_cls, legacy_transform_cls, args_kwargs,
output_prototype_pil = prototype(image_pil)
except Exception as exc:
raise AssertionError(
f"Transforming a PIL image with shape {tuple(image.shape)} failed in the prototype transform with "
f"Transforming a PIL image with shape {image_repr} failed in the prototype transform with "
f"the error above. This means there is a consistency bug either in `_get_params` or in the "
f"`PIL.Image.Image` path in `_transform`."
) from exc
......@@ -267,3 +347,25 @@ def test_consistency(prototype_transform_cls, legacy_transform_cls, args_kwargs,
output_legacy_pil,
msg=lambda msg: f"PIL image consistency check failed with: \n\n{msg}",
)
class TestToTensorTransforms:
def test_pil_to_tensor(self):
for image in make_images(extra_dims=[()]):
image_pil = to_image_pil(image)
prototype_transform = prototype_transforms.PILToTensor()
legacy_transform = legacy_transforms.PILToTensor()
assert_equal(prototype_transform(image_pil), legacy_transform(image_pil))
def test_to_tensor(self):
for image in make_images(extra_dims=[()]):
image_pil = to_image_pil(image)
image_numpy = np.array(image_pil)
prototype_transform = prototype_transforms.ToTensor()
legacy_transform = legacy_transforms.ToTensor()
assert_equal(prototype_transform(image_pil), legacy_transform(image_pil))
assert_equal(prototype_transform(image_numpy), legacy_transform(image_numpy))
......@@ -23,7 +23,7 @@ class Lambda(Transform):
self.types = types or (object,)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
if type(inpt) in self.types:
if isinstance(inpt, self.types):
return self.fn(inpt)
else:
return inpt
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment