"docs/vscode:/vscode.git/clone" did not exist on "c2d4a3b5c7bb6a8367c00f7c797bf87f4b2fcef9"
Unverified Commit a5b3118f authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

add consistency tests for prototype container transforms (#6525)

* add consistency tests for prototype container transforms

* fix RandomApply
parent 54dd0a59
...@@ -464,38 +464,18 @@ def test_automatic_coverage_deterministic(): ...@@ -464,38 +464,18 @@ def test_automatic_coverage_deterministic():
) )
@pytest.mark.parametrize( def check_consistency(prototype_transform, legacy_transform, images=None, supports_pil=True):
("prototype_transform_cls", "legacy_transform_cls", "args_kwargs", "make_images_kwargs", "supports_pil"), if images is None:
itertools.chain.from_iterable(config.parametrization() for config in CONSISTENCY_CONFIGS), images = make_images(**DEFAULT_MAKE_IMAGES_KWARGS)
)
def test_consistency(prototype_transform_cls, legacy_transform_cls, args_kwargs, make_images_kwargs, supports_pil):
args, kwargs = args_kwargs
try:
legacy = legacy_transform_cls(*args, **kwargs)
except Exception as exc:
raise pytest.UsageError(
f"Initializing the legacy transform failed with the error above. "
f"Please correct the `ArgsKwargs({args_kwargs})` in the `ConsistencyConfig`."
) from exc
try: for image in images:
prototype = prototype_transform_cls(*args, **kwargs) image_repr = f"[{tuple(image.shape)}, {str(image.dtype).rsplit('.')[-1]}]"
except Exception as exc:
raise AssertionError(
"Initializing the prototype transform failed with the error above. "
"This means there is a consistency bug in the constructor."
) from exc
for image in make_images(**make_images_kwargs):
image_tensor = torch.Tensor(image) image_tensor = torch.Tensor(image)
image_pil = to_image_pil(image) if image.ndim == 3 and supports_pil else None
image_repr = f"[{tuple(image.shape)}, {str(image.dtype).rsplit('.')[-1]}]"
try: try:
torch.manual_seed(0) torch.manual_seed(0)
output_legacy_tensor = legacy(image_tensor) output_legacy_tensor = legacy_transform(image_tensor)
except Exception as exc: except Exception as exc:
raise pytest.UsageError( raise pytest.UsageError(
f"Transforming a tensor image {image_repr} failed in the legacy transform with the " f"Transforming a tensor image {image_repr} failed in the legacy transform with the "
...@@ -505,7 +485,7 @@ def test_consistency(prototype_transform_cls, legacy_transform_cls, args_kwargs, ...@@ -505,7 +485,7 @@ def test_consistency(prototype_transform_cls, legacy_transform_cls, args_kwargs,
try: try:
torch.manual_seed(0) torch.manual_seed(0)
output_prototype_tensor = prototype(image_tensor) output_prototype_tensor = prototype_transform(image_tensor)
except Exception as exc: except Exception as exc:
raise AssertionError( raise AssertionError(
f"Transforming a tensor image with shape {image_repr} failed in the prototype transform with " f"Transforming a tensor image with shape {image_repr} failed in the prototype transform with "
...@@ -521,7 +501,7 @@ def test_consistency(prototype_transform_cls, legacy_transform_cls, args_kwargs, ...@@ -521,7 +501,7 @@ def test_consistency(prototype_transform_cls, legacy_transform_cls, args_kwargs,
try: try:
torch.manual_seed(0) torch.manual_seed(0)
output_prototype_image = prototype(image) output_prototype_image = prototype_transform(image)
except Exception as exc: except Exception as exc:
raise AssertionError( raise AssertionError(
f"Transforming a feature image with shape {image_repr} failed in the prototype transform with " f"Transforming a feature image with shape {image_repr} failed in the prototype transform with "
...@@ -535,10 +515,12 @@ def test_consistency(prototype_transform_cls, legacy_transform_cls, args_kwargs, ...@@ -535,10 +515,12 @@ def test_consistency(prototype_transform_cls, legacy_transform_cls, args_kwargs,
msg=lambda msg: f"Output for feature and tensor images is not equal: \n\n{msg}", msg=lambda msg: f"Output for feature and tensor images is not equal: \n\n{msg}",
) )
if image_pil is not None: if image.ndim == 3 and supports_pil:
image_pil = to_image_pil(image)
try: try:
torch.manual_seed(0) torch.manual_seed(0)
output_legacy_pil = legacy(image_pil) output_legacy_pil = legacy_transform(image_pil)
except Exception as exc: except Exception as exc:
raise pytest.UsageError( raise pytest.UsageError(
f"Transforming a PIL image with shape {image_repr} failed in the legacy transform with the " f"Transforming a PIL image with shape {image_repr} failed in the legacy transform with the "
...@@ -548,7 +530,7 @@ def test_consistency(prototype_transform_cls, legacy_transform_cls, args_kwargs, ...@@ -548,7 +530,7 @@ def test_consistency(prototype_transform_cls, legacy_transform_cls, args_kwargs,
try: try:
torch.manual_seed(0) torch.manual_seed(0)
output_prototype_pil = prototype(image_pil) output_prototype_pil = prototype_transform(image_pil)
except Exception as exc: except Exception as exc:
raise AssertionError( raise AssertionError(
f"Transforming a PIL image with shape {image_repr} failed in the prototype transform with " f"Transforming a PIL image with shape {image_repr} failed in the prototype transform with "
...@@ -563,23 +545,116 @@ def test_consistency(prototype_transform_cls, legacy_transform_cls, args_kwargs, ...@@ -563,23 +545,116 @@ def test_consistency(prototype_transform_cls, legacy_transform_cls, args_kwargs,
) )
@pytest.mark.parametrize(
("prototype_transform_cls", "legacy_transform_cls", "args_kwargs", "make_images_kwargs", "supports_pil"),
itertools.chain.from_iterable(config.parametrization() for config in CONSISTENCY_CONFIGS),
)
def test_consistency(prototype_transform_cls, legacy_transform_cls, args_kwargs, make_images_kwargs, supports_pil):
args, kwargs = args_kwargs
try:
legacy_transform = legacy_transform_cls(*args, **kwargs)
except Exception as exc:
raise pytest.UsageError(
f"Initializing the legacy transform failed with the error above. "
f"Please correct the `ArgsKwargs({args_kwargs})` in the `ConsistencyConfig`."
) from exc
try:
prototype_transform = prototype_transform_cls(*args, **kwargs)
except Exception as exc:
raise AssertionError(
"Initializing the prototype transform failed with the error above. "
"This means there is a consistency bug in the constructor."
) from exc
check_consistency(
prototype_transform, legacy_transform, images=make_images(**make_images_kwargs), supports_pil=supports_pil
)
class TestContainerTransforms:
"""
Since we are testing containers here, we also need some transforms to wrap. Thus, testing a container transform for
consistency automatically tests the wrapped transforms consistency.
Instead of complicated mocking or creating custom transforms just for these tests, here we use deterministic ones
that were already tested for consistency above.
"""
def test_compose(self):
prototype_transform = prototype_transforms.Compose(
[
prototype_transforms.Resize(256),
prototype_transforms.CenterCrop(224),
]
)
legacy_transform = legacy_transforms.Compose(
[
legacy_transforms.Resize(256),
legacy_transforms.CenterCrop(224),
]
)
check_consistency(prototype_transform, legacy_transform)
@pytest.mark.parametrize("p", [0, 0.1, 0.5, 0.9, 1])
def test_random_apply(self, p):
prototype_transform = prototype_transforms.RandomApply(
[
prototype_transforms.Resize(256),
legacy_transforms.CenterCrop(224),
],
p=p,
)
legacy_transform = legacy_transforms.RandomApply(
[
legacy_transforms.Resize(256),
legacy_transforms.CenterCrop(224),
],
p=p,
)
check_consistency(prototype_transform, legacy_transform)
# We can't test other values for `p` since the random parameter generation is different
@pytest.mark.parametrize("p", [(0, 1), (1, 0)])
def test_random_choice(self, p):
prototype_transform = prototype_transforms.RandomChoice(
[
prototype_transforms.Resize(256),
legacy_transforms.CenterCrop(224),
],
p=p,
)
legacy_transform = legacy_transforms.RandomChoice(
[
legacy_transforms.Resize(256),
legacy_transforms.CenterCrop(224),
],
p=p,
)
check_consistency(prototype_transform, legacy_transform)
class TestToTensorTransforms: class TestToTensorTransforms:
def test_pil_to_tensor(self): def test_pil_to_tensor(self):
for image in make_images(extra_dims=[()]):
image_pil = to_image_pil(image)
prototype_transform = prototype_transforms.PILToTensor() prototype_transform = prototype_transforms.PILToTensor()
legacy_transform = legacy_transforms.PILToTensor() legacy_transform = legacy_transforms.PILToTensor()
for image in make_images(extra_dims=[()]):
image_pil = to_image_pil(image)
assert_equal(prototype_transform(image_pil), legacy_transform(image_pil)) assert_equal(prototype_transform(image_pil), legacy_transform(image_pil))
def test_to_tensor(self): def test_to_tensor(self):
prototype_transform = prototype_transforms.ToTensor()
legacy_transform = legacy_transforms.ToTensor()
for image in make_images(extra_dims=[()]): for image in make_images(extra_dims=[()]):
image_pil = to_image_pil(image) image_pil = to_image_pil(image)
image_numpy = np.array(image_pil) image_numpy = np.array(image_pil)
prototype_transform = prototype_transforms.ToTensor()
legacy_transform = legacy_transforms.ToTensor()
assert_equal(prototype_transform(image_pil), legacy_transform(image_pil)) assert_equal(prototype_transform(image_pil), legacy_transform(image_pil))
assert_equal(prototype_transform(image_numpy), legacy_transform(image_numpy)) assert_equal(prototype_transform(image_numpy), legacy_transform(image_numpy))
import warnings import warnings
from typing import Any, Callable, Dict, List, Optional, Sequence from typing import Any, Callable, List, Optional, Sequence
import torch import torch
from torchvision.prototype.transforms import Transform from torchvision.prototype.transforms import Transform
from ._transform import _RandomApplyTransform
class Compose(Transform): class Compose(Transform):
def __init__(self, transforms: Sequence[Callable]) -> None: def __init__(self, transforms: Sequence[Callable]) -> None:
...@@ -21,16 +19,21 @@ class Compose(Transform): ...@@ -21,16 +19,21 @@ class Compose(Transform):
return sample return sample
class RandomApply(_RandomApplyTransform): class RandomApply(Compose):
def __init__(self, transform: Transform, p: float = 0.5) -> None: def __init__(self, transforms: Sequence[Callable], p: float = 0.5) -> None:
super().__init__(p=p) super().__init__(transforms)
self.transform = transform
if not (0.0 <= p <= 1.0):
raise ValueError("`p` should be a floating point value in the interval [0.0, 1.0].")
self.p = p
def forward(self, *inputs: Any) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0]
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: if torch.rand(1) >= self.p:
return self.transform(inpt) return sample
def extra_repr(self) -> str: return super().forward(sample)
return f"p={self.p}"
class RandomChoice(Transform): class RandomChoice(Transform):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment