Unverified Commit 8154c92a authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

add test for signature consistency of prototype and legacy transforms (#6526)

* add test for signature consistency

* fix CenterCrop and Lambda

* add removed params

* cleanup
parent a5b3118f
import enum
import functools
import inspect
import itertools
import numpy as np
import PIL.Image
import pytest
......@@ -58,26 +58,20 @@ class ArgsKwargs:
class ConsistencyConfig:
def __init__(
self, prototype_cls, legacy_cls, transform_args_kwargs=None, make_images_kwargs=None, supports_pil=True
self,
prototype_cls,
legacy_cls,
args_kwargs,
make_images_kwargs=None,
supports_pil=True,
removed_params=(),
):
self.prototype_cls = prototype_cls
self.legacy_cls = legacy_cls
self.transform_args_kwargs = transform_args_kwargs or [((), dict())]
self.args_kwargs = args_kwargs
self.make_images_kwargs = make_images_kwargs or DEFAULT_MAKE_IMAGES_KWARGS
self.supports_pil = supports_pil
def parametrization(self):
return [
pytest.param(
self.prototype_cls,
self.legacy_cls,
args_kwargs,
self.make_images_kwargs,
self.supports_pil,
id=f"{self.legacy_cls.__name__}({args_kwargs})",
)
for args_kwargs in self.transform_args_kwargs
]
self.removed_params = removed_params
# These are here since both the prototype and legacy transform need to be constructed with the same random parameters
......@@ -93,6 +87,7 @@ CONSISTENCY_CONFIGS = [
],
supports_pil=False,
make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, dtypes=[torch.float]),
removed_params=["inplace"],
),
ConsistencyConfig(
prototype_transforms.Resize,
......@@ -319,6 +314,7 @@ CONSISTENCY_CONFIGS = [
ArgsKwargs(p=1, value="random"),
],
supports_pil=False,
removed_params=["inplace"],
),
ConsistencyConfig(
prototype_transforms.ColorJitter,
......@@ -379,6 +375,7 @@ CONSISTENCY_CONFIGS = [
ArgsKwargs(degrees=30.0, fill=(2, 3, 4)),
ArgsKwargs(degrees=30.0, center=(0, 0)),
],
removed_params=["fillcolor", "resample"],
),
ConsistencyConfig(
prototype_transforms.RandomCrop,
......@@ -423,12 +420,13 @@ CONSISTENCY_CONFIGS = [
ArgsKwargs(degrees=30.0, fill=1),
ArgsKwargs(degrees=30.0, fill=(1, 2, 3)),
],
removed_params=["resample"],
),
]
def test_automatic_coverage_deterministic():
legacy = {
def test_automatic_coverage():
available = {
name
for name, obj in legacy_transforms.__dict__.items()
if not name.startswith("_")
......@@ -454,9 +452,9 @@ def test_automatic_coverage_deterministic():
}
}
prototype = {config.legacy_cls.__name__ for config in CONSISTENCY_CONFIGS}
checked = {config.legacy_cls.__name__ for config in CONSISTENCY_CONFIGS}
missing = legacy - prototype
missing = available - checked
if missing:
raise AssertionError(
f"The prototype transformations {sequence_to_str(sorted(missing), separate_last='and ')} "
......@@ -464,7 +462,37 @@ def test_automatic_coverage_deterministic():
)
def check_consistency(prototype_transform, legacy_transform, images=None, supports_pil=True):
@pytest.mark.parametrize("config", CONSISTENCY_CONFIGS, ids=lambda config: config.legacy_cls.__name__)
def test_signature_consistency(config):
legacy_params = dict(inspect.signature(config.legacy_cls).parameters)
prototype_params = dict(inspect.signature(config.prototype_cls).parameters)
for param in config.removed_params:
legacy_params.pop(param, None)
missing = legacy_params.keys() - prototype_params.keys()
if missing:
raise AssertionError(
f"The prototype transform does not support the parameters "
f"{sequence_to_str(sorted(missing), separate_last='and ')}, but the legacy transform does. "
f"If that is intentional, e.g. pending deprecation, please add the parameters to the `removed_params` on "
f"the `ConsistencyConfig`."
)
extra = prototype_params.keys() - legacy_params.keys()
extra_without_default = {param for param in extra if prototype_params[param].default is not inspect.Parameter.empty}
if extra_without_default:
raise AssertionError(
f"The prototype transform requires the parameters {sequence_to_str(sorted(missing), separate_last='and ')}, "
f"but the legacy transform does not. Please add a default value."
)
for name, legacy_param in legacy_params.items():
prototype_param = prototype_params[name]
assert prototype_param.kind is legacy_param.kind
def check_call_consistency(prototype_transform, legacy_transform, images=None, supports_pil=True):
if images is None:
images = make_images(**DEFAULT_MAKE_IMAGES_KWARGS)
......@@ -546,14 +574,18 @@ def check_consistency(prototype_transform, legacy_transform, images=None, suppor
@pytest.mark.parametrize(
("prototype_transform_cls", "legacy_transform_cls", "args_kwargs", "make_images_kwargs", "supports_pil"),
itertools.chain.from_iterable(config.parametrization() for config in CONSISTENCY_CONFIGS),
("config", "args_kwargs"),
[
pytest.param(config, args_kwargs, id=f"{config.legacy_cls.__name__}({args_kwargs})")
for config in CONSISTENCY_CONFIGS
for args_kwargs in config.args_kwargs
],
)
def test_consistency(prototype_transform_cls, legacy_transform_cls, args_kwargs, make_images_kwargs, supports_pil):
def test_call_consistency(config, args_kwargs):
args, kwargs = args_kwargs
try:
legacy_transform = legacy_transform_cls(*args, **kwargs)
legacy_transform = config.legacy_cls(*args, **kwargs)
except Exception as exc:
raise pytest.UsageError(
f"Initializing the legacy transform failed with the error above. "
......@@ -561,15 +593,18 @@ def test_consistency(prototype_transform_cls, legacy_transform_cls, args_kwargs,
) from exc
try:
prototype_transform = prototype_transform_cls(*args, **kwargs)
prototype_transform = config.prototype_cls(*args, **kwargs)
except Exception as exc:
raise AssertionError(
"Initializing the prototype transform failed with the error above. "
"This means there is a consistency bug in the constructor."
) from exc
check_consistency(
prototype_transform, legacy_transform, images=make_images(**make_images_kwargs), supports_pil=supports_pil
check_call_consistency(
prototype_transform,
legacy_transform,
images=make_images(**config.make_images_kwargs),
supports_pil=config.supports_pil,
)
......@@ -596,7 +631,7 @@ class TestContainerTransforms:
]
)
check_consistency(prototype_transform, legacy_transform)
check_call_consistency(prototype_transform, legacy_transform)
@pytest.mark.parametrize("p", [0, 0.1, 0.5, 0.9, 1])
def test_random_apply(self, p):
......@@ -615,7 +650,7 @@ class TestContainerTransforms:
p=p,
)
check_consistency(prototype_transform, legacy_transform)
check_call_consistency(prototype_transform, legacy_transform)
# We can't test other values for `p` since the random parameter generation is different
@pytest.mark.parametrize("p", [(0, 1), (1, 0)])
......@@ -635,7 +670,7 @@ class TestContainerTransforms:
p=p,
)
check_consistency(prototype_transform, legacy_transform)
check_call_consistency(prototype_transform, legacy_transform)
class TestToTensorTransforms:
......
......@@ -58,12 +58,12 @@ class Resize(Transform):
class CenterCrop(Transform):
def __init__(self, output_size: List[int]):
def __init__(self, size: List[int]):
super().__init__()
self.output_size = output_size
self.size = size
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return F.center_crop(inpt, output_size=self.output_size)
return F.center_crop(inpt, output_size=self.size)
class RandomResizedCrop(Transform):
......
......@@ -17,20 +17,20 @@ class Identity(Transform):
class Lambda(Transform):
def __init__(self, fn: Callable[[Any], Any], *types: Type):
def __init__(self, lambd: Callable[[Any], Any], *types: Type):
super().__init__()
self.fn = fn
self.lambd = lambd
self.types = types or (object,)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
if isinstance(inpt, self.types):
return self.fn(inpt)
return self.lambd(inpt)
else:
return inpt
def extra_repr(self) -> str:
extras = []
name = getattr(self.fn, "__name__", None)
name = getattr(self.lambd, "__name__", None)
if name:
extras.append(name)
extras.append(f"types={[type.__name__ for type in self.types]}")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment