Unverified Commit bd70a780 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

test all transform signatures for consistency (#6527)

* check signatures of all available transformations

* fix signatures of RandAugment and TrivialAugmentWide

* move AA consistency tests to correct module
parent 8154c92a
......@@ -1665,205 +1665,3 @@ class TestLabelToOneHot:
assert isinstance(ohe_labels, features.OneHotLabel)
assert ohe_labels.shape == (4, 3)
assert ohe_labels.categories == labels.categories == categories
class TestAPIConsistency:
@pytest.mark.parametrize("antialias", [True, False])
@pytest.mark.parametrize(
"inpt",
[
torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8),
PIL.Image.new("RGB", (256, 256), 123),
features.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)),
],
)
def test_random_resized_crop(self, antialias, inpt):
from torchvision.transforms import transforms as ref_transforms
size = 224
t_ref = ref_transforms.RandomResizedCrop(size, antialias=antialias)
t = transforms.RandomResizedCrop(size, antialias=antialias)
torch.manual_seed(12)
expected_output = t_ref(inpt)
torch.manual_seed(12)
output = t(inpt)
if isinstance(inpt, PIL.Image.Image):
expected_output = pil_to_tensor(expected_output)
output = pil_to_tensor(output)
torch.testing.assert_close(expected_output, output)
@pytest.mark.parametrize(
"inpt",
[
torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8),
PIL.Image.new("RGB", (256, 256), 123),
features.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)),
],
)
@pytest.mark.parametrize("interpolation", [InterpolationMode.NEAREST, InterpolationMode.BILINEAR])
def test_randaug(self, inpt, interpolation, mocker):
from torchvision.transforms import autoaugment as ref_transforms
t_ref = ref_transforms.RandAugment(interpolation=interpolation, num_ops=1)
t = transforms.RandAugment(interpolation=interpolation, num_ops=1)
le = len(t._AUGMENTATION_SPACE)
keys = list(t._AUGMENTATION_SPACE.keys())
randint_values = []
for i in range(le):
# Stable API, op_index random call
randint_values.append(i)
# Stable API, if signed there is another random call
if t._AUGMENTATION_SPACE[keys[i]][1]:
randint_values.append(0)
# New API, _get_random_item
randint_values.append(i)
randint_values = iter(randint_values)
mocker.patch("torch.randint", side_effect=lambda *arg, **kwargs: torch.tensor(next(randint_values)))
mocker.patch("torch.rand", return_value=1.0)
for i in range(le):
expected_output = t_ref(inpt)
output = t(inpt)
if isinstance(inpt, PIL.Image.Image):
expected_output = pil_to_tensor(expected_output)
output = pil_to_tensor(output)
torch.testing.assert_close(expected_output, output)
@pytest.mark.parametrize(
"inpt",
[
torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8),
PIL.Image.new("RGB", (256, 256), 123),
features.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)),
],
)
@pytest.mark.parametrize("interpolation", [InterpolationMode.NEAREST, InterpolationMode.BILINEAR])
def test_trivial_aug(self, inpt, interpolation, mocker):
from torchvision.transforms import autoaugment as ref_transforms
t_ref = ref_transforms.TrivialAugmentWide(interpolation=interpolation)
t = transforms.TrivialAugmentWide(interpolation=interpolation)
le = len(t._AUGMENTATION_SPACE)
keys = list(t._AUGMENTATION_SPACE.keys())
randint_values = []
for i in range(le):
# Stable API, op_index random call
randint_values.append(i)
key = keys[i]
# Stable API, random magnitude
aug_op = t._AUGMENTATION_SPACE[key]
magnitudes = aug_op[0](2, 0, 0)
if magnitudes is not None:
randint_values.append(5)
# Stable API, if signed there is another random call
if aug_op[1]:
randint_values.append(0)
# New API, _get_random_item
randint_values.append(i)
# New API, random magnitude
if magnitudes is not None:
randint_values.append(5)
randint_values = iter(randint_values)
mocker.patch("torch.randint", side_effect=lambda *arg, **kwargs: torch.tensor(next(randint_values)))
mocker.patch("torch.rand", return_value=1.0)
for _ in range(le):
expected_output = t_ref(inpt)
output = t(inpt)
if isinstance(inpt, PIL.Image.Image):
expected_output = pil_to_tensor(expected_output)
output = pil_to_tensor(output)
torch.testing.assert_close(expected_output, output)
@pytest.mark.parametrize(
"inpt",
[
torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8),
PIL.Image.new("RGB", (256, 256), 123),
features.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)),
],
)
@pytest.mark.parametrize("interpolation", [InterpolationMode.NEAREST, InterpolationMode.BILINEAR])
def test_augmix(self, inpt, interpolation, mocker):
from torchvision.transforms import autoaugment as ref_transforms
t_ref = ref_transforms.AugMix(interpolation=interpolation, mixture_width=1, chain_depth=1)
t_ref._sample_dirichlet = lambda t: t.softmax(dim=-1)
t = transforms.AugMix(interpolation=interpolation, mixture_width=1, chain_depth=1)
t._sample_dirichlet = lambda t: t.softmax(dim=-1)
le = len(t._AUGMENTATION_SPACE)
keys = list(t._AUGMENTATION_SPACE.keys())
randint_values = []
for i in range(le):
# Stable API, op_index random call
randint_values.append(i)
key = keys[i]
# Stable API, random magnitude
aug_op = t._AUGMENTATION_SPACE[key]
magnitudes = aug_op[0](2, 0, 0)
if magnitudes is not None:
randint_values.append(5)
# Stable API, if signed there is another random call
if aug_op[1]:
randint_values.append(0)
# New API, _get_random_item
randint_values.append(i)
# New API, random magnitude
if magnitudes is not None:
randint_values.append(5)
randint_values = iter(randint_values)
mocker.patch("torch.randint", side_effect=lambda *arg, **kwargs: torch.tensor(next(randint_values)))
mocker.patch("torch.rand", return_value=1.0)
expected_output = t_ref(inpt)
output = t(inpt)
if isinstance(inpt, PIL.Image.Image):
expected_output = pil_to_tensor(expected_output)
output = pil_to_tensor(output)
torch.testing.assert_close(expected_output, output)
@pytest.mark.parametrize(
"inpt",
[
torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8),
PIL.Image.new("RGB", (256, 256), 123),
features.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)),
],
)
@pytest.mark.parametrize("interpolation", [InterpolationMode.NEAREST, InterpolationMode.BILINEAR])
def test_aa(self, inpt, interpolation):
from torchvision.transforms import autoaugment as ref_transforms
aa_policy = ref_transforms.AutoAugmentPolicy("imagenet")
t_ref = ref_transforms.AutoAugment(aa_policy, interpolation=interpolation)
t = transforms.AutoAugment(aa_policy, interpolation=interpolation)
torch.manual_seed(12)
expected_output = t_ref(inpt)
torch.manual_seed(12)
output = t(inpt)
if isinstance(inpt, PIL.Image.Image):
expected_output = pil_to_tensor(expected_output)
output = pil_to_tensor(output)
torch.testing.assert_close(expected_output, output)
......@@ -61,7 +61,8 @@ class ConsistencyConfig:
self,
prototype_cls,
legacy_cls,
args_kwargs,
# If no args_kwargs is passed, only the signature will be checked
args_kwargs=(),
make_images_kwargs=None,
supports_pil=True,
removed_params=(),
......@@ -422,6 +423,46 @@ CONSISTENCY_CONFIGS = [
],
removed_params=["resample"],
),
ConsistencyConfig(
prototype_transforms.PILToTensor,
legacy_transforms.PILToTensor,
),
ConsistencyConfig(
prototype_transforms.ToTensor,
legacy_transforms.ToTensor,
),
ConsistencyConfig(
prototype_transforms.Compose,
legacy_transforms.Compose,
),
ConsistencyConfig(
prototype_transforms.RandomApply,
legacy_transforms.RandomApply,
),
ConsistencyConfig(
prototype_transforms.RandomChoice,
legacy_transforms.RandomChoice,
),
ConsistencyConfig(
prototype_transforms.RandomOrder,
legacy_transforms.RandomOrder,
),
ConsistencyConfig(
prototype_transforms.AugMix,
legacy_transforms.AugMix,
),
ConsistencyConfig(
prototype_transforms.AutoAugment,
legacy_transforms.AutoAugment,
),
ConsistencyConfig(
prototype_transforms.RandAugment,
legacy_transforms.RandAugment,
),
ConsistencyConfig(
prototype_transforms.TrivialAugmentWide,
legacy_transforms.TrivialAugmentWide,
),
]
......@@ -429,27 +470,7 @@ def test_automatic_coverage():
available = {
name
for name, obj in legacy_transforms.__dict__.items()
if not name.startswith("_")
and isinstance(obj, type)
and not issubclass(obj, enum.Enum)
and name
not in {
# This framework is based on the assumption that the input image can always be a tensor and optionally a
# PIL image, but the transforms below require a non-tensor input.
"PILToTensor",
"ToTensor",
# Transform containers cannot be tested without other tranforms
"Compose",
"RandomApply",
"RandomChoice",
"RandomOrder",
# If the random parameter generation in the legacy and prototype transform is the same, setting the seed
# should be sufficient. In that case, the transforms below should be tested automatically.
"AugMix",
"AutoAugment",
"RandAugment",
"TrivialAugmentWide",
}
if not name.startswith("_") and isinstance(obj, type) and not issubclass(obj, enum.Enum)
}
checked = {config.legacy_cls.__name__ for config in CONSISTENCY_CONFIGS}
......@@ -480,16 +501,22 @@ def test_signature_consistency(config):
)
extra = prototype_params.keys() - legacy_params.keys()
extra_without_default = {param for param in extra if prototype_params[param].default is not inspect.Parameter.empty}
extra_without_default = {
param
for param in extra
if prototype_params[param].default is inspect.Parameter.empty
and prototype_params[param].kind not in {inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD}
}
if extra_without_default:
raise AssertionError(
f"The prototype transform requires the parameters {sequence_to_str(sorted(missing), separate_last='and ')}, "
f"but the legacy transform does not. Please add a default value."
f"The prototype transform requires the parameters "
f"{sequence_to_str(sorted(extra_without_default), separate_last='and ')}, but the legacy transform does "
f"not. Please add a default value."
)
for name, legacy_param in legacy_params.items():
prototype_param = prototype_params[name]
assert prototype_param.kind is legacy_param.kind
legacy_kinds = {name: param.kind for name, param in legacy_params.items()}
prototype_kinds = {name: prototype_params[name].kind for name in legacy_kinds.keys()}
assert prototype_kinds == legacy_kinds
def check_call_consistency(prototype_transform, legacy_transform, images=None, supports_pil=True):
......@@ -693,3 +720,165 @@ class TestToTensorTransforms:
assert_equal(prototype_transform(image_pil), legacy_transform(image_pil))
assert_equal(prototype_transform(image_numpy), legacy_transform(image_numpy))
class TestAATransforms:
@pytest.mark.parametrize(
"inpt",
[
torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8),
PIL.Image.new("RGB", (256, 256), 123),
features.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)),
],
)
@pytest.mark.parametrize(
"interpolation",
[prototype_transforms.InterpolationMode.NEAREST, prototype_transforms.InterpolationMode.BILINEAR],
)
def test_randaug(self, inpt, interpolation, mocker):
t_ref = legacy_transforms.RandAugment(interpolation=interpolation, num_ops=1)
t = prototype_transforms.RandAugment(interpolation=interpolation, num_ops=1)
le = len(t._AUGMENTATION_SPACE)
keys = list(t._AUGMENTATION_SPACE.keys())
randint_values = []
for i in range(le):
# Stable API, op_index random call
randint_values.append(i)
# Stable API, if signed there is another random call
if t._AUGMENTATION_SPACE[keys[i]][1]:
randint_values.append(0)
# New API, _get_random_item
randint_values.append(i)
randint_values = iter(randint_values)
mocker.patch("torch.randint", side_effect=lambda *arg, **kwargs: torch.tensor(next(randint_values)))
mocker.patch("torch.rand", return_value=1.0)
for i in range(le):
expected_output = t_ref(inpt)
output = t(inpt)
assert_equal(expected_output, output)
@pytest.mark.parametrize(
"inpt",
[
torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8),
PIL.Image.new("RGB", (256, 256), 123),
features.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)),
],
)
@pytest.mark.parametrize(
"interpolation",
[prototype_transforms.InterpolationMode.NEAREST, prototype_transforms.InterpolationMode.BILINEAR],
)
def test_trivial_aug(self, inpt, interpolation, mocker):
t_ref = legacy_transforms.TrivialAugmentWide(interpolation=interpolation)
t = prototype_transforms.TrivialAugmentWide(interpolation=interpolation)
le = len(t._AUGMENTATION_SPACE)
keys = list(t._AUGMENTATION_SPACE.keys())
randint_values = []
for i in range(le):
# Stable API, op_index random call
randint_values.append(i)
key = keys[i]
# Stable API, random magnitude
aug_op = t._AUGMENTATION_SPACE[key]
magnitudes = aug_op[0](2, 0, 0)
if magnitudes is not None:
randint_values.append(5)
# Stable API, if signed there is another random call
if aug_op[1]:
randint_values.append(0)
# New API, _get_random_item
randint_values.append(i)
# New API, random magnitude
if magnitudes is not None:
randint_values.append(5)
randint_values = iter(randint_values)
mocker.patch("torch.randint", side_effect=lambda *arg, **kwargs: torch.tensor(next(randint_values)))
mocker.patch("torch.rand", return_value=1.0)
for _ in range(le):
expected_output = t_ref(inpt)
output = t(inpt)
assert_equal(expected_output, output)
@pytest.mark.parametrize(
"inpt",
[
torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8),
PIL.Image.new("RGB", (256, 256), 123),
features.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)),
],
)
@pytest.mark.parametrize(
"interpolation",
[prototype_transforms.InterpolationMode.NEAREST, prototype_transforms.InterpolationMode.BILINEAR],
)
def test_augmix(self, inpt, interpolation, mocker):
t_ref = legacy_transforms.AugMix(interpolation=interpolation, mixture_width=1, chain_depth=1)
t_ref._sample_dirichlet = lambda t: t.softmax(dim=-1)
t = prototype_transforms.AugMix(interpolation=interpolation, mixture_width=1, chain_depth=1)
t._sample_dirichlet = lambda t: t.softmax(dim=-1)
le = len(t._AUGMENTATION_SPACE)
keys = list(t._AUGMENTATION_SPACE.keys())
randint_values = []
for i in range(le):
# Stable API, op_index random call
randint_values.append(i)
key = keys[i]
# Stable API, random magnitude
aug_op = t._AUGMENTATION_SPACE[key]
magnitudes = aug_op[0](2, 0, 0)
if magnitudes is not None:
randint_values.append(5)
# Stable API, if signed there is another random call
if aug_op[1]:
randint_values.append(0)
# New API, _get_random_item
randint_values.append(i)
# New API, random magnitude
if magnitudes is not None:
randint_values.append(5)
randint_values = iter(randint_values)
mocker.patch("torch.randint", side_effect=lambda *arg, **kwargs: torch.tensor(next(randint_values)))
mocker.patch("torch.rand", return_value=1.0)
expected_output = t_ref(inpt)
output = t(inpt)
assert_equal(expected_output, output)
@pytest.mark.parametrize(
"inpt",
[
torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8),
PIL.Image.new("RGB", (256, 256), 123),
features.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)),
],
)
@pytest.mark.parametrize(
"interpolation",
[prototype_transforms.InterpolationMode.NEAREST, prototype_transforms.InterpolationMode.BILINEAR],
)
def test_aa(self, inpt, interpolation):
aa_policy = legacy_transforms.AutoAugmentPolicy("imagenet")
t_ref = legacy_transforms.AutoAugment(aa_policy, interpolation=interpolation)
t = prototype_transforms.AutoAugment(aa_policy, interpolation=interpolation)
torch.manual_seed(12)
expected_output = t_ref(inpt)
torch.manual_seed(12)
output = t(inpt)
assert_equal(expected_output, output)
......@@ -343,7 +343,6 @@ class RandAugment(_AutoAugmentBase):
def __init__(
self,
*,
num_ops: int = 2,
magnitude: int = 9,
num_magnitude_bins: int = 31,
......@@ -402,7 +401,6 @@ class TrivialAugmentWide(_AutoAugmentBase):
def __init__(
self,
*,
num_magnitude_bins: int = 31,
interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: Union[int, float, Sequence[int], Sequence[float]] = 0,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment