Unverified Commit a5b3118f authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

add consistency tests for prototype container transforms (#6525)

* add consistency tests for prototype container transforms

* fix RandomApply
parent 54dd0a59
......@@ -464,38 +464,18 @@ def test_automatic_coverage_deterministic():
)
@pytest.mark.parametrize(
("prototype_transform_cls", "legacy_transform_cls", "args_kwargs", "make_images_kwargs", "supports_pil"),
itertools.chain.from_iterable(config.parametrization() for config in CONSISTENCY_CONFIGS),
)
def test_consistency(prototype_transform_cls, legacy_transform_cls, args_kwargs, make_images_kwargs, supports_pil):
args, kwargs = args_kwargs
try:
legacy = legacy_transform_cls(*args, **kwargs)
except Exception as exc:
raise pytest.UsageError(
f"Initializing the legacy transform failed with the error above. "
f"Please correct the `ArgsKwargs({args_kwargs})` in the `ConsistencyConfig`."
) from exc
def check_consistency(prototype_transform, legacy_transform, images=None, supports_pil=True):
if images is None:
images = make_images(**DEFAULT_MAKE_IMAGES_KWARGS)
try:
prototype = prototype_transform_cls(*args, **kwargs)
except Exception as exc:
raise AssertionError(
"Initializing the prototype transform failed with the error above. "
"This means there is a consistency bug in the constructor."
) from exc
for image in images:
image_repr = f"[{tuple(image.shape)}, {str(image.dtype).rsplit('.')[-1]}]"
for image in make_images(**make_images_kwargs):
image_tensor = torch.Tensor(image)
image_pil = to_image_pil(image) if image.ndim == 3 and supports_pil else None
image_repr = f"[{tuple(image.shape)}, {str(image.dtype).rsplit('.')[-1]}]"
try:
torch.manual_seed(0)
output_legacy_tensor = legacy(image_tensor)
output_legacy_tensor = legacy_transform(image_tensor)
except Exception as exc:
raise pytest.UsageError(
f"Transforming a tensor image {image_repr} failed in the legacy transform with the "
......@@ -505,7 +485,7 @@ def test_consistency(prototype_transform_cls, legacy_transform_cls, args_kwargs,
try:
torch.manual_seed(0)
output_prototype_tensor = prototype(image_tensor)
output_prototype_tensor = prototype_transform(image_tensor)
except Exception as exc:
raise AssertionError(
f"Transforming a tensor image with shape {image_repr} failed in the prototype transform with "
......@@ -521,7 +501,7 @@ def test_consistency(prototype_transform_cls, legacy_transform_cls, args_kwargs,
try:
torch.manual_seed(0)
output_prototype_image = prototype(image)
output_prototype_image = prototype_transform(image)
except Exception as exc:
raise AssertionError(
f"Transforming a feature image with shape {image_repr} failed in the prototype transform with "
......@@ -535,10 +515,12 @@ def test_consistency(prototype_transform_cls, legacy_transform_cls, args_kwargs,
msg=lambda msg: f"Output for feature and tensor images is not equal: \n\n{msg}",
)
if image_pil is not None:
if image.ndim == 3 and supports_pil:
image_pil = to_image_pil(image)
try:
torch.manual_seed(0)
output_legacy_pil = legacy(image_pil)
output_legacy_pil = legacy_transform(image_pil)
except Exception as exc:
raise pytest.UsageError(
f"Transforming a PIL image with shape {image_repr} failed in the legacy transform with the "
......@@ -548,7 +530,7 @@ def test_consistency(prototype_transform_cls, legacy_transform_cls, args_kwargs,
try:
torch.manual_seed(0)
output_prototype_pil = prototype(image_pil)
output_prototype_pil = prototype_transform(image_pil)
except Exception as exc:
raise AssertionError(
f"Transforming a PIL image with shape {image_repr} failed in the prototype transform with "
......@@ -563,23 +545,116 @@ def test_consistency(prototype_transform_cls, legacy_transform_cls, args_kwargs,
)
@pytest.mark.parametrize(
("prototype_transform_cls", "legacy_transform_cls", "args_kwargs", "make_images_kwargs", "supports_pil"),
itertools.chain.from_iterable(config.parametrization() for config in CONSISTENCY_CONFIGS),
)
def test_consistency(prototype_transform_cls, legacy_transform_cls, args_kwargs, make_images_kwargs, supports_pil):
args, kwargs = args_kwargs
try:
legacy_transform = legacy_transform_cls(*args, **kwargs)
except Exception as exc:
raise pytest.UsageError(
f"Initializing the legacy transform failed with the error above. "
f"Please correct the `ArgsKwargs({args_kwargs})` in the `ConsistencyConfig`."
) from exc
try:
prototype_transform = prototype_transform_cls(*args, **kwargs)
except Exception as exc:
raise AssertionError(
"Initializing the prototype transform failed with the error above. "
"This means there is a consistency bug in the constructor."
) from exc
check_consistency(
prototype_transform, legacy_transform, images=make_images(**make_images_kwargs), supports_pil=supports_pil
)
class TestContainerTransforms:
"""
Since we are testing containers here, we also need some transforms to wrap. Thus, testing a container transform for
consistency automatically tests the wrapped transforms consistency.
Instead of complicated mocking or creating custom transforms just for these tests, here we use deterministic ones
that were already tested for consistency above.
"""
def test_compose(self):
prototype_transform = prototype_transforms.Compose(
[
prototype_transforms.Resize(256),
prototype_transforms.CenterCrop(224),
]
)
legacy_transform = legacy_transforms.Compose(
[
legacy_transforms.Resize(256),
legacy_transforms.CenterCrop(224),
]
)
check_consistency(prototype_transform, legacy_transform)
@pytest.mark.parametrize("p", [0, 0.1, 0.5, 0.9, 1])
def test_random_apply(self, p):
prototype_transform = prototype_transforms.RandomApply(
[
prototype_transforms.Resize(256),
legacy_transforms.CenterCrop(224),
],
p=p,
)
legacy_transform = legacy_transforms.RandomApply(
[
legacy_transforms.Resize(256),
legacy_transforms.CenterCrop(224),
],
p=p,
)
check_consistency(prototype_transform, legacy_transform)
# We can't test other values for `p` since the random parameter generation is different
@pytest.mark.parametrize("p", [(0, 1), (1, 0)])
def test_random_choice(self, p):
prototype_transform = prototype_transforms.RandomChoice(
[
prototype_transforms.Resize(256),
legacy_transforms.CenterCrop(224),
],
p=p,
)
legacy_transform = legacy_transforms.RandomChoice(
[
legacy_transforms.Resize(256),
legacy_transforms.CenterCrop(224),
],
p=p,
)
check_consistency(prototype_transform, legacy_transform)
class TestToTensorTransforms:
def test_pil_to_tensor(self):
prototype_transform = prototype_transforms.PILToTensor()
legacy_transform = legacy_transforms.PILToTensor()
for image in make_images(extra_dims=[()]):
image_pil = to_image_pil(image)
prototype_transform = prototype_transforms.PILToTensor()
legacy_transform = legacy_transforms.PILToTensor()
assert_equal(prototype_transform(image_pil), legacy_transform(image_pil))
def test_to_tensor(self):
prototype_transform = prototype_transforms.ToTensor()
legacy_transform = legacy_transforms.ToTensor()
for image in make_images(extra_dims=[()]):
image_pil = to_image_pil(image)
image_numpy = np.array(image_pil)
prototype_transform = prototype_transforms.ToTensor()
legacy_transform = legacy_transforms.ToTensor()
assert_equal(prototype_transform(image_pil), legacy_transform(image_pil))
assert_equal(prototype_transform(image_numpy), legacy_transform(image_numpy))
import warnings
from typing import Any, Callable, Dict, List, Optional, Sequence
from typing import Any, Callable, List, Optional, Sequence
import torch
from torchvision.prototype.transforms import Transform
from ._transform import _RandomApplyTransform
class Compose(Transform):
def __init__(self, transforms: Sequence[Callable]) -> None:
......@@ -21,16 +19,21 @@ class Compose(Transform):
return sample
class RandomApply(_RandomApplyTransform):
def __init__(self, transform: Transform, p: float = 0.5) -> None:
super().__init__(p=p)
self.transform = transform
class RandomApply(Compose):
def __init__(self, transforms: Sequence[Callable], p: float = 0.5) -> None:
super().__init__(transforms)
if not (0.0 <= p <= 1.0):
raise ValueError("`p` should be a floating point value in the interval [0.0, 1.0].")
self.p = p
def forward(self, *inputs: Any) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0]
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return self.transform(inpt)
if torch.rand(1) >= self.p:
return sample
def extra_repr(self) -> str:
return f"p={self.p}"
return super().forward(sample)
class RandomChoice(Transform):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment