Unverified Commit fae53217 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

Enable v1 vs. v2 consistency in refactored tests (#7882)

parent 47cd5ea8
import enum
import importlib.machinery
import importlib.util
import inspect
......@@ -83,35 +82,6 @@ CONSISTENCY_CONFIGS = [
supports_pil=False,
make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, dtypes=[torch.float]),
),
ConsistencyConfig(
v2_transforms.Resize,
legacy_transforms.Resize,
[
NotScriptableArgsKwargs(32),
ArgsKwargs([32]),
ArgsKwargs((32, 29)),
ArgsKwargs((31, 28), interpolation=v2_transforms.InterpolationMode.NEAREST),
ArgsKwargs((30, 27), interpolation=PIL.Image.NEAREST),
ArgsKwargs((35, 29), interpolation=PIL.Image.BILINEAR),
NotScriptableArgsKwargs(31, max_size=32),
ArgsKwargs([31], max_size=32),
NotScriptableArgsKwargs(30, max_size=100),
ArgsKwargs([31], max_size=32),
ArgsKwargs((29, 32), antialias=False),
ArgsKwargs((28, 31), antialias=True),
],
# atol=1 due to Resize v2 is using native uint8 interpolate path for bilinear and nearest modes
closeness_kwargs=dict(rtol=0, atol=1),
),
ConsistencyConfig(
v2_transforms.Resize,
legacy_transforms.Resize,
[
ArgsKwargs((33, 26), interpolation=v2_transforms.InterpolationMode.BICUBIC, antialias=True),
ArgsKwargs((34, 25), interpolation=PIL.Image.BICUBIC, antialias=True),
],
closeness_kwargs=dict(rtol=0, atol=21),
),
ConsistencyConfig(
v2_transforms.CenterCrop,
legacy_transforms.CenterCrop,
......@@ -187,20 +157,6 @@ CONSISTENCY_CONFIGS = [
# Use default tolerances of `torch.testing.assert_close`
closeness_kwargs=dict(rtol=None, atol=None),
),
ConsistencyConfig(
v2_transforms.ConvertImageDtype,
legacy_transforms.ConvertImageDtype,
[
ArgsKwargs(torch.float16),
ArgsKwargs(torch.bfloat16),
ArgsKwargs(torch.float32),
ArgsKwargs(torch.float64),
ArgsKwargs(torch.uint8),
],
supports_pil=False,
# Use default tolerances of `torch.testing.assert_close`
closeness_kwargs=dict(rtol=None, atol=None),
),
ConsistencyConfig(
v2_transforms.ToPILImage,
legacy_transforms.ToPILImage,
......@@ -226,22 +182,6 @@ CONSISTENCY_CONFIGS = [
# images given that the transform does nothing but call it anyway.
supports_pil=False,
),
ConsistencyConfig(
v2_transforms.RandomHorizontalFlip,
legacy_transforms.RandomHorizontalFlip,
[
ArgsKwargs(p=0),
ArgsKwargs(p=1),
],
),
ConsistencyConfig(
v2_transforms.RandomVerticalFlip,
legacy_transforms.RandomVerticalFlip,
[
ArgsKwargs(p=0),
ArgsKwargs(p=1),
],
),
ConsistencyConfig(
v2_transforms.RandomEqualize,
legacy_transforms.RandomEqualize,
......@@ -367,30 +307,6 @@ CONSISTENCY_CONFIGS = [
],
closeness_kwargs={"atol": 1e-5, "rtol": 1e-5},
),
*[
ConsistencyConfig(
v2_transforms.ElasticTransform,
legacy_transforms.ElasticTransform,
[
ArgsKwargs(),
ArgsKwargs(alpha=20.0),
ArgsKwargs(alpha=(15.3, 27.2)),
ArgsKwargs(sigma=3.0),
ArgsKwargs(sigma=(2.5, 3.9)),
ArgsKwargs(interpolation=v2_transforms.InterpolationMode.NEAREST),
ArgsKwargs(interpolation=v2_transforms.InterpolationMode.BICUBIC),
ArgsKwargs(interpolation=PIL.Image.NEAREST),
ArgsKwargs(interpolation=PIL.Image.BICUBIC),
ArgsKwargs(fill=1),
],
# ElasticTransform needs larger images to avoid the needed internal padding being larger than the actual image
make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, sizes=[(163, 163), (72, 333), (313, 95)], dtypes=[dt]),
# We updated gaussian blur kernel generation with a faster and numerically more stable version
# This brings float32 accumulation visible in elastic transform -> we need to relax consistency tolerance
closeness_kwargs=ckw,
)
for dt, ckw in [(torch.uint8, {"rtol": 1e-1, "atol": 1}), (torch.float32, {"rtol": 1e-2, "atol": 1e-3})]
],
ConsistencyConfig(
v2_transforms.GaussianBlur,
legacy_transforms.GaussianBlur,
......@@ -402,26 +318,6 @@ CONSISTENCY_CONFIGS = [
],
closeness_kwargs={"rtol": 1e-5, "atol": 1e-5},
),
ConsistencyConfig(
v2_transforms.RandomAffine,
legacy_transforms.RandomAffine,
[
ArgsKwargs(degrees=30.0),
ArgsKwargs(degrees=(-20.0, 10.0)),
ArgsKwargs(degrees=0.0, translate=(0.4, 0.6)),
ArgsKwargs(degrees=0.0, scale=(0.3, 0.8)),
ArgsKwargs(degrees=0.0, shear=13),
ArgsKwargs(degrees=0.0, shear=(8, 17)),
ArgsKwargs(degrees=0.0, shear=(4, 5, 4, 13)),
ArgsKwargs(degrees=(-20.0, 10.0), translate=(0.4, 0.6), scale=(0.3, 0.8), shear=(4, 5, 4, 13)),
ArgsKwargs(degrees=30.0, interpolation=v2_transforms.InterpolationMode.NEAREST),
ArgsKwargs(degrees=30.0, interpolation=PIL.Image.NEAREST),
ArgsKwargs(degrees=30.0, fill=1),
ArgsKwargs(degrees=30.0, fill=(2, 3, 4)),
ArgsKwargs(degrees=30.0, center=(0, 0)),
],
removed_params=["fillcolor", "resample"],
),
ConsistencyConfig(
v2_transforms.RandomCrop,
legacy_transforms.RandomCrop,
......@@ -456,21 +352,6 @@ CONSISTENCY_CONFIGS = [
],
closeness_kwargs={"atol": None, "rtol": None},
),
ConsistencyConfig(
v2_transforms.RandomRotation,
legacy_transforms.RandomRotation,
[
ArgsKwargs(degrees=30.0),
ArgsKwargs(degrees=(-20.0, 10.0)),
ArgsKwargs(degrees=30.0, interpolation=v2_transforms.InterpolationMode.BILINEAR),
ArgsKwargs(degrees=30.0, interpolation=PIL.Image.BILINEAR),
ArgsKwargs(degrees=30.0, expand=True),
ArgsKwargs(degrees=30.0, center=(0, 0)),
ArgsKwargs(degrees=30.0, fill=1),
ArgsKwargs(degrees=30.0, fill=(1, 2, 3)),
],
removed_params=["resample"],
),
ConsistencyConfig(
v2_transforms.PILToTensor,
legacy_transforms.PILToTensor,
......@@ -514,23 +395,6 @@ CONSISTENCY_CONFIGS = [
]
def test_automatic_coverage():
available = {
name
for name, obj in legacy_transforms.__dict__.items()
if not name.startswith("_") and isinstance(obj, type) and not issubclass(obj, enum.Enum)
}
checked = {config.legacy_cls.__name__ for config in CONSISTENCY_CONFIGS}
missing = available - checked
if missing:
raise AssertionError(
f"The prototype transformations {sequence_to_str(sorted(missing), separate_last='and ')} "
f"are not checked for consistency although a legacy counterpart exists."
)
@pytest.mark.parametrize("config", CONSISTENCY_CONFIGS, ids=lambda config: config.legacy_cls.__name__)
def test_signature_consistency(config):
legacy_params = dict(inspect.signature(config.legacy_cls).parameters)
......@@ -708,15 +572,9 @@ get_params_parametrization = pytest.mark.parametrize(
(v2_transforms.RandomResizedCrop, ArgsKwargs(make_image(), scale=[0.3, 0.7], ratio=[0.5, 1.5])),
(v2_transforms.RandomErasing, ArgsKwargs(make_image(), scale=(0.3, 0.7), ratio=(0.5, 1.5))),
(v2_transforms.ColorJitter, ArgsKwargs(brightness=None, contrast=None, saturation=None, hue=None)),
(v2_transforms.ElasticTransform, ArgsKwargs(alpha=[15.3, 27.2], sigma=[2.5, 3.9], size=[17, 31])),
(v2_transforms.GaussianBlur, ArgsKwargs(0.3, 1.4)),
(
v2_transforms.RandomAffine,
ArgsKwargs(degrees=[-20.0, 10.0], translate=None, scale_ranges=None, shears=None, img_size=[15, 29]),
),
(v2_transforms.RandomCrop, ArgsKwargs(make_image(size=(61, 47)), output_size=(19, 25))),
(v2_transforms.RandomPerspective, ArgsKwargs(23, 17, 0.5)),
(v2_transforms.RandomRotation, ArgsKwargs(degrees=[-20.0, 10.0])),
(v2_transforms.AutoAugment, ArgsKwargs(5)),
]
],
......
......@@ -228,26 +228,37 @@ def check_functional_kernel_signature_match(functional, *, kernel, input_type):
assert functional_param == kernel_param
def _check_transform_v1_compatibility(transform, input):
def _check_transform_v1_compatibility(transform, input, rtol, atol):
"""If the transform defines the ``_v1_transform_cls`` attribute, checks if the transform has a public, static
``get_params`` method, is scriptable, and the scripted version can be called without error."""
if transform._v1_transform_cls is None:
``get_params`` method that is the v1 equivalent, the output is close to v1, is scriptable, and the scripted version
can be called without error."""
if type(input) is not torch.Tensor or isinstance(input, PIL.Image.Image):
return
if type(input) is not torch.Tensor:
v1_transform_cls = transform._v1_transform_cls
if v1_transform_cls is None:
return
if hasattr(transform._v1_transform_cls, "get_params"):
assert type(transform).get_params is transform._v1_transform_cls.get_params
if hasattr(v1_transform_cls, "get_params"):
assert type(transform).get_params is v1_transform_cls.get_params
scripted_transform = _script(transform)
with ignore_jit_no_profile_information_warning():
scripted_transform(input)
v1_transform = v1_transform_cls(**transform._extract_params_for_v1_transform())
with freeze_rng_state():
output_v2 = transform(input)
with freeze_rng_state():
output_v1 = v1_transform(input)
assert_close(output_v2, output_v1, rtol=rtol, atol=atol)
if isinstance(input, PIL.Image.Image):
return
_script(v1_transform)(input)
def check_transform(transform_cls, input, *args, **kwargs):
transform = transform_cls(*args, **kwargs)
def check_transform(transform, input, check_v1_compatibility=True):
pickle.loads(pickle.dumps(transform))
output = transform(input)
......@@ -256,7 +267,8 @@ def check_transform(transform_cls, input, *args, **kwargs):
if isinstance(input, datapoints.BoundingBoxes):
assert output.format == input.format
_check_transform_v1_compatibility(transform, input)
if check_v1_compatibility:
_check_transform_v1_compatibility(transform, input, **_to_tolerances(check_v1_compatibility))
def transform_cls_to_functional(transform_cls, **transform_specific_kwargs):
......@@ -524,7 +536,12 @@ class TestResize:
],
)
def test_transform(self, size, device, make_input):
check_transform(transforms.Resize, make_input(self.INPUT_SIZE, device=device), size=size, antialias=True)
check_transform(
transforms.Resize(size=size, antialias=True),
make_input(self.INPUT_SIZE, device=device),
# atol=1 due to Resize v2 is using native uint8 interpolate path for bilinear and nearest modes
check_v1_compatibility=dict(rtol=0, atol=1),
)
def _check_output_size(self, input, output, *, size, max_size):
assert tuple(F.get_size(output)) == self._compute_output_size(
......@@ -848,7 +865,7 @@ class TestHorizontalFlip:
)
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_transform(self, make_input, device):
check_transform(transforms.RandomHorizontalFlip, make_input(device=device), p=1)
check_transform(transforms.RandomHorizontalFlip(p=1), make_input(device=device))
@pytest.mark.parametrize(
"fn", [F.horizontal_flip, transform_cls_to_functional(transforms.RandomHorizontalFlip, p=1)]
......@@ -1026,7 +1043,7 @@ class TestAffine:
def test_transform(self, make_input, device):
input = make_input(device=device)
check_transform(transforms.RandomAffine, input, **self._CORRECTNESS_TRANSFORM_AFFINE_RANGES)
check_transform(transforms.RandomAffine(**self._CORRECTNESS_TRANSFORM_AFFINE_RANGES), input)
@pytest.mark.parametrize("angle", _CORRECTNESS_AFFINE_KWARGS["angle"])
@pytest.mark.parametrize("translate", _CORRECTNESS_AFFINE_KWARGS["translate"])
......@@ -1313,7 +1330,7 @@ class TestVerticalFlip:
)
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_transform(self, make_input, device):
check_transform(transforms.RandomVerticalFlip, make_input(device=device), p=1)
check_transform(transforms.RandomVerticalFlip(p=1), make_input(device=device))
@pytest.mark.parametrize("fn", [F.vertical_flip, transform_cls_to_functional(transforms.RandomVerticalFlip, p=1)])
def test_image_correctness(self, fn):
......@@ -1464,7 +1481,7 @@ class TestRotate:
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_transform(self, make_input, device):
check_transform(
transforms.RandomRotation, make_input(device=device), **self._CORRECTNESS_TRANSFORM_AFFINE_RANGES
transforms.RandomRotation(**self._CORRECTNESS_TRANSFORM_AFFINE_RANGES), make_input(device=device)
)
@pytest.mark.parametrize("angle", _CORRECTNESS_AFFINE_KWARGS["angle"])
......@@ -1726,7 +1743,7 @@ class TestToDtype:
input = make_input(dtype=input_dtype, device=device)
if as_dict:
output_dtype = {type(input): output_dtype}
check_transform(transforms.ToDtype, input, dtype=output_dtype, scale=scale)
check_transform(transforms.ToDtype(dtype=output_dtype, scale=scale), input)
def reference_convert_dtype_image_tensor(self, image, dtype=torch.float, scale=False):
input_dtype = image.dtype
......@@ -2415,7 +2432,12 @@ class TestElastic:
@pytest.mark.parametrize("size", [(163, 163), (72, 333), (313, 95)])
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_transform(self, make_input, size, device):
check_transform(transforms.ElasticTransform, make_input(size, device=device))
check_transform(
transforms.ElasticTransform(),
make_input(size, device=device),
# We updated gaussian blur kernel generation with a faster and numerically more stable version
check_v1_compatibility=dict(rtol=0, atol=1),
)
class TestToPureTensor:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment