Unverified Commit 8154c92a authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

add test for signature consistency of prototype and legacy transforms (#6526)

* add test for signature consistency

* fix CenterCrop and Lambda

* add removed params

* cleanup
parent a5b3118f
import enum import enum
import functools import functools
import inspect
import itertools import itertools
import numpy as np import numpy as np
import PIL.Image import PIL.Image
import pytest import pytest
...@@ -58,26 +58,20 @@ class ArgsKwargs: ...@@ -58,26 +58,20 @@ class ArgsKwargs:
class ConsistencyConfig: class ConsistencyConfig:
def __init__( def __init__(
self, prototype_cls, legacy_cls, transform_args_kwargs=None, make_images_kwargs=None, supports_pil=True self,
prototype_cls,
legacy_cls,
args_kwargs,
make_images_kwargs=None,
supports_pil=True,
removed_params=(),
): ):
self.prototype_cls = prototype_cls self.prototype_cls = prototype_cls
self.legacy_cls = legacy_cls self.legacy_cls = legacy_cls
self.transform_args_kwargs = transform_args_kwargs or [((), dict())] self.args_kwargs = args_kwargs
self.make_images_kwargs = make_images_kwargs or DEFAULT_MAKE_IMAGES_KWARGS self.make_images_kwargs = make_images_kwargs or DEFAULT_MAKE_IMAGES_KWARGS
self.supports_pil = supports_pil self.supports_pil = supports_pil
self.removed_params = removed_params
def parametrization(self):
return [
pytest.param(
self.prototype_cls,
self.legacy_cls,
args_kwargs,
self.make_images_kwargs,
self.supports_pil,
id=f"{self.legacy_cls.__name__}({args_kwargs})",
)
for args_kwargs in self.transform_args_kwargs
]
# These are here since both the prototype and legacy transform need to be constructed with the same random parameters # These are here since both the prototype and legacy transform need to be constructed with the same random parameters
...@@ -93,6 +87,7 @@ CONSISTENCY_CONFIGS = [ ...@@ -93,6 +87,7 @@ CONSISTENCY_CONFIGS = [
], ],
supports_pil=False, supports_pil=False,
make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, dtypes=[torch.float]), make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, dtypes=[torch.float]),
removed_params=["inplace"],
), ),
ConsistencyConfig( ConsistencyConfig(
prototype_transforms.Resize, prototype_transforms.Resize,
...@@ -319,6 +314,7 @@ CONSISTENCY_CONFIGS = [ ...@@ -319,6 +314,7 @@ CONSISTENCY_CONFIGS = [
ArgsKwargs(p=1, value="random"), ArgsKwargs(p=1, value="random"),
], ],
supports_pil=False, supports_pil=False,
removed_params=["inplace"],
), ),
ConsistencyConfig( ConsistencyConfig(
prototype_transforms.ColorJitter, prototype_transforms.ColorJitter,
...@@ -379,6 +375,7 @@ CONSISTENCY_CONFIGS = [ ...@@ -379,6 +375,7 @@ CONSISTENCY_CONFIGS = [
ArgsKwargs(degrees=30.0, fill=(2, 3, 4)), ArgsKwargs(degrees=30.0, fill=(2, 3, 4)),
ArgsKwargs(degrees=30.0, center=(0, 0)), ArgsKwargs(degrees=30.0, center=(0, 0)),
], ],
removed_params=["fillcolor", "resample"],
), ),
ConsistencyConfig( ConsistencyConfig(
prototype_transforms.RandomCrop, prototype_transforms.RandomCrop,
...@@ -423,12 +420,13 @@ CONSISTENCY_CONFIGS = [ ...@@ -423,12 +420,13 @@ CONSISTENCY_CONFIGS = [
ArgsKwargs(degrees=30.0, fill=1), ArgsKwargs(degrees=30.0, fill=1),
ArgsKwargs(degrees=30.0, fill=(1, 2, 3)), ArgsKwargs(degrees=30.0, fill=(1, 2, 3)),
], ],
removed_params=["resample"],
), ),
] ]
def test_automatic_coverage_deterministic(): def test_automatic_coverage():
legacy = { available = {
name name
for name, obj in legacy_transforms.__dict__.items() for name, obj in legacy_transforms.__dict__.items()
if not name.startswith("_") if not name.startswith("_")
...@@ -454,9 +452,9 @@ def test_automatic_coverage_deterministic(): ...@@ -454,9 +452,9 @@ def test_automatic_coverage_deterministic():
} }
} }
prototype = {config.legacy_cls.__name__ for config in CONSISTENCY_CONFIGS} checked = {config.legacy_cls.__name__ for config in CONSISTENCY_CONFIGS}
missing = legacy - prototype missing = available - checked
if missing: if missing:
raise AssertionError( raise AssertionError(
f"The prototype transformations {sequence_to_str(sorted(missing), separate_last='and ')} " f"The prototype transformations {sequence_to_str(sorted(missing), separate_last='and ')} "
...@@ -464,7 +462,37 @@ def test_automatic_coverage_deterministic(): ...@@ -464,7 +462,37 @@ def test_automatic_coverage_deterministic():
) )
def check_consistency(prototype_transform, legacy_transform, images=None, supports_pil=True): @pytest.mark.parametrize("config", CONSISTENCY_CONFIGS, ids=lambda config: config.legacy_cls.__name__)
def test_signature_consistency(config):
legacy_params = dict(inspect.signature(config.legacy_cls).parameters)
prototype_params = dict(inspect.signature(config.prototype_cls).parameters)
for param in config.removed_params:
legacy_params.pop(param, None)
missing = legacy_params.keys() - prototype_params.keys()
if missing:
raise AssertionError(
f"The prototype transform does not support the parameters "
f"{sequence_to_str(sorted(missing), separate_last='and ')}, but the legacy transform does. "
f"If that is intentional, e.g. pending deprecation, please add the parameters to the `removed_params` on "
f"the `ConsistencyConfig`."
)
extra = prototype_params.keys() - legacy_params.keys()
extra_without_default = {param for param in extra if prototype_params[param].default is not inspect.Parameter.empty}
if extra_without_default:
raise AssertionError(
f"The prototype transform requires the parameters {sequence_to_str(sorted(missing), separate_last='and ')}, "
f"but the legacy transform does not. Please add a default value."
)
for name, legacy_param in legacy_params.items():
prototype_param = prototype_params[name]
assert prototype_param.kind is legacy_param.kind
def check_call_consistency(prototype_transform, legacy_transform, images=None, supports_pil=True):
if images is None: if images is None:
images = make_images(**DEFAULT_MAKE_IMAGES_KWARGS) images = make_images(**DEFAULT_MAKE_IMAGES_KWARGS)
...@@ -546,14 +574,18 @@ def check_consistency(prototype_transform, legacy_transform, images=None, suppor ...@@ -546,14 +574,18 @@ def check_consistency(prototype_transform, legacy_transform, images=None, suppor
@pytest.mark.parametrize( @pytest.mark.parametrize(
("prototype_transform_cls", "legacy_transform_cls", "args_kwargs", "make_images_kwargs", "supports_pil"), ("config", "args_kwargs"),
itertools.chain.from_iterable(config.parametrization() for config in CONSISTENCY_CONFIGS), [
pytest.param(config, args_kwargs, id=f"{config.legacy_cls.__name__}({args_kwargs})")
for config in CONSISTENCY_CONFIGS
for args_kwargs in config.args_kwargs
],
) )
def test_consistency(prototype_transform_cls, legacy_transform_cls, args_kwargs, make_images_kwargs, supports_pil): def test_call_consistency(config, args_kwargs):
args, kwargs = args_kwargs args, kwargs = args_kwargs
try: try:
legacy_transform = legacy_transform_cls(*args, **kwargs) legacy_transform = config.legacy_cls(*args, **kwargs)
except Exception as exc: except Exception as exc:
raise pytest.UsageError( raise pytest.UsageError(
f"Initializing the legacy transform failed with the error above. " f"Initializing the legacy transform failed with the error above. "
...@@ -561,15 +593,18 @@ def test_consistency(prototype_transform_cls, legacy_transform_cls, args_kwargs, ...@@ -561,15 +593,18 @@ def test_consistency(prototype_transform_cls, legacy_transform_cls, args_kwargs,
) from exc ) from exc
try: try:
prototype_transform = prototype_transform_cls(*args, **kwargs) prototype_transform = config.prototype_cls(*args, **kwargs)
except Exception as exc: except Exception as exc:
raise AssertionError( raise AssertionError(
"Initializing the prototype transform failed with the error above. " "Initializing the prototype transform failed with the error above. "
"This means there is a consistency bug in the constructor." "This means there is a consistency bug in the constructor."
) from exc ) from exc
check_consistency( check_call_consistency(
prototype_transform, legacy_transform, images=make_images(**make_images_kwargs), supports_pil=supports_pil prototype_transform,
legacy_transform,
images=make_images(**config.make_images_kwargs),
supports_pil=config.supports_pil,
) )
...@@ -596,7 +631,7 @@ class TestContainerTransforms: ...@@ -596,7 +631,7 @@ class TestContainerTransforms:
] ]
) )
check_consistency(prototype_transform, legacy_transform) check_call_consistency(prototype_transform, legacy_transform)
@pytest.mark.parametrize("p", [0, 0.1, 0.5, 0.9, 1]) @pytest.mark.parametrize("p", [0, 0.1, 0.5, 0.9, 1])
def test_random_apply(self, p): def test_random_apply(self, p):
...@@ -615,7 +650,7 @@ class TestContainerTransforms: ...@@ -615,7 +650,7 @@ class TestContainerTransforms:
p=p, p=p,
) )
check_consistency(prototype_transform, legacy_transform) check_call_consistency(prototype_transform, legacy_transform)
# We can't test other values for `p` since the random parameter generation is different # We can't test other values for `p` since the random parameter generation is different
@pytest.mark.parametrize("p", [(0, 1), (1, 0)]) @pytest.mark.parametrize("p", [(0, 1), (1, 0)])
...@@ -635,7 +670,7 @@ class TestContainerTransforms: ...@@ -635,7 +670,7 @@ class TestContainerTransforms:
p=p, p=p,
) )
check_consistency(prototype_transform, legacy_transform) check_call_consistency(prototype_transform, legacy_transform)
class TestToTensorTransforms: class TestToTensorTransforms:
......
...@@ -58,12 +58,12 @@ class Resize(Transform): ...@@ -58,12 +58,12 @@ class Resize(Transform):
class CenterCrop(Transform): class CenterCrop(Transform):
def __init__(self, output_size: List[int]): def __init__(self, size: List[int]):
super().__init__() super().__init__()
self.output_size = output_size self.size = size
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return F.center_crop(inpt, output_size=self.output_size) return F.center_crop(inpt, output_size=self.size)
class RandomResizedCrop(Transform): class RandomResizedCrop(Transform):
......
...@@ -17,20 +17,20 @@ class Identity(Transform): ...@@ -17,20 +17,20 @@ class Identity(Transform):
class Lambda(Transform): class Lambda(Transform):
def __init__(self, fn: Callable[[Any], Any], *types: Type): def __init__(self, lambd: Callable[[Any], Any], *types: Type):
super().__init__() super().__init__()
self.fn = fn self.lambd = lambd
self.types = types or (object,) self.types = types or (object,)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
if isinstance(inpt, self.types): if isinstance(inpt, self.types):
return self.fn(inpt) return self.lambd(inpt)
else: else:
return inpt return inpt
def extra_repr(self) -> str: def extra_repr(self) -> str:
extras = [] extras = []
name = getattr(self.fn, "__name__", None) name = getattr(self.lambd, "__name__", None)
if name: if name:
extras.append(name) extras.append(name)
extras.append(f"types={[type.__name__ for type in self.types]}") extras.append(f"types={[type.__name__ for type in self.types]}")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment