Unverified Commit 1f94320d authored by Philip Meier, committed by GitHub

port AA tests (#7927)


Co-authored-by: Nicolas Hug <contact@nicolas-hug.com>
parent d0e16b76
@@ -705,281 +705,6 @@ class TestToTensorTransforms:
        assert_equal(prototype_transform(image_numpy), legacy_transform(image_numpy))


class TestAATransforms:
    @pytest.mark.parametrize(
        "inpt",
        [
            torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8),
            PIL.Image.new("RGB", (256, 256), 123),
            tv_tensors.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)),
        ],
    )
    @pytest.mark.parametrize(
        "interpolation",
        [
            v2_transforms.InterpolationMode.NEAREST,
            v2_transforms.InterpolationMode.BILINEAR,
            PIL.Image.NEAREST,
        ],
    )
    def test_randaug(self, inpt, interpolation, mocker):
        t_ref = legacy_transforms.RandAugment(interpolation=interpolation, num_ops=1)
        t = v2_transforms.RandAugment(interpolation=interpolation, num_ops=1)

        le = len(t._AUGMENTATION_SPACE)
        keys = list(t._AUGMENTATION_SPACE.keys())
        randint_values = []
        for i in range(le):
            # Stable API, op_index random call
            randint_values.append(i)
            # Stable API, if signed there is another random call
            if t._AUGMENTATION_SPACE[keys[i]][1]:
                randint_values.append(0)
            # New API, _get_random_item
            randint_values.append(i)
        randint_values = iter(randint_values)

        mocker.patch("torch.randint", side_effect=lambda *arg, **kwargs: torch.tensor(next(randint_values)))
        mocker.patch("torch.rand", return_value=1.0)

        for i in range(le):
            expected_output = t_ref(inpt)
            output = t(inpt)

            assert_close(expected_output, output, atol=1, rtol=0.1)
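
A note on the mocking pattern used above: both the v1 and the v2 transform draw their operation index (and, for signed ops, the sign) through torch.randint, so feeding that function from a pre-computed iterator forces both implementations down the same branch and makes their outputs comparable. A minimal standalone sketch of the same trick, using unittest.mock directly and hypothetical draw values:

import torch
from unittest import mock

# Hypothetical pre-scripted draws: op index, sign, op index again.
scripted_draws = iter([2, 0, 2])

with mock.patch(
    "torch.randint", side_effect=lambda *args, **kwargs: torch.tensor(next(scripted_draws))
):
    # Every call now returns the next pre-scripted value instead of a random one.
    assert torch.randint(0, 10, ()).item() == 2
    assert torch.randint(0, 2, ()).item() == 0
    assert torch.randint(0, 10, ()).item() == 2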

    @pytest.mark.parametrize(
        "interpolation",
        [
            v2_transforms.InterpolationMode.NEAREST,
            v2_transforms.InterpolationMode.BILINEAR,
        ],
    )
    @pytest.mark.parametrize("fill", [None, 85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1], 1])
    def test_randaug_jit(self, interpolation, fill):
        inpt = torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)
        t_ref = legacy_transforms.RandAugment(interpolation=interpolation, num_ops=1, fill=fill)
        t = v2_transforms.RandAugment(interpolation=interpolation, num_ops=1, fill=fill)

        tt_ref = torch.jit.script(t_ref)
        tt = torch.jit.script(t)

        torch.manual_seed(12)
        expected_output = tt_ref(inpt)

        torch.manual_seed(12)
        scripted_output = tt(inpt)

        assert_equal(scripted_output, expected_output)

    @pytest.mark.parametrize(
        "inpt",
        [
            torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8),
            PIL.Image.new("RGB", (256, 256), 123),
            tv_tensors.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)),
        ],
    )
    @pytest.mark.parametrize(
        "interpolation",
        [
            v2_transforms.InterpolationMode.NEAREST,
            v2_transforms.InterpolationMode.BILINEAR,
            PIL.Image.NEAREST,
        ],
    )
    def test_trivial_aug(self, inpt, interpolation, mocker):
        t_ref = legacy_transforms.TrivialAugmentWide(interpolation=interpolation)
        t = v2_transforms.TrivialAugmentWide(interpolation=interpolation)

        le = len(t._AUGMENTATION_SPACE)
        keys = list(t._AUGMENTATION_SPACE.keys())
        randint_values = []
        for i in range(le):
            # Stable API, op_index random call
            randint_values.append(i)
            key = keys[i]
            # Stable API, random magnitude
            aug_op = t._AUGMENTATION_SPACE[key]
            magnitudes = aug_op[0](2, 0, 0)
            if magnitudes is not None:
                randint_values.append(5)
            # Stable API, if signed there is another random call
            if aug_op[1]:
                randint_values.append(0)
            # New API, _get_random_item
            randint_values.append(i)
            # New API, random magnitude
            if magnitudes is not None:
                randint_values.append(5)
        randint_values = iter(randint_values)

        mocker.patch("torch.randint", side_effect=lambda *arg, **kwargs: torch.tensor(next(randint_values)))
        mocker.patch("torch.rand", return_value=1.0)

        for _ in range(le):
            expected_output = t_ref(inpt)
            output = t(inpt)

            assert_close(expected_output, output, atol=1, rtol=0.1)

    @pytest.mark.parametrize(
        "interpolation",
        [
            v2_transforms.InterpolationMode.NEAREST,
            v2_transforms.InterpolationMode.BILINEAR,
        ],
    )
    @pytest.mark.parametrize("fill", [None, 85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1], 1])
    def test_trivial_aug_jit(self, interpolation, fill):
        inpt = torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)
        t_ref = legacy_transforms.TrivialAugmentWide(interpolation=interpolation, fill=fill)
        t = v2_transforms.TrivialAugmentWide(interpolation=interpolation, fill=fill)

        tt_ref = torch.jit.script(t_ref)
        tt = torch.jit.script(t)

        torch.manual_seed(12)
        expected_output = tt_ref(inpt)

        torch.manual_seed(12)
        scripted_output = tt(inpt)

        assert_equal(scripted_output, expected_output)

    @pytest.mark.parametrize(
        "inpt",
        [
            torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8),
            PIL.Image.new("RGB", (256, 256), 123),
            tv_tensors.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)),
        ],
    )
    @pytest.mark.parametrize(
        "interpolation",
        [
            v2_transforms.InterpolationMode.NEAREST,
            v2_transforms.InterpolationMode.BILINEAR,
            PIL.Image.NEAREST,
        ],
    )
    def test_augmix(self, inpt, interpolation, mocker):
        t_ref = legacy_transforms.AugMix(interpolation=interpolation, mixture_width=1, chain_depth=1)
        t_ref._sample_dirichlet = lambda t: t.softmax(dim=-1)
        t = v2_transforms.AugMix(interpolation=interpolation, mixture_width=1, chain_depth=1)
        t._sample_dirichlet = lambda t: t.softmax(dim=-1)

        le = len(t._AUGMENTATION_SPACE)
        keys = list(t._AUGMENTATION_SPACE.keys())
        randint_values = []
        for i in range(le):
            # Stable API, op_index random call
            randint_values.append(i)
            key = keys[i]
            # Stable API, random magnitude
            aug_op = t._AUGMENTATION_SPACE[key]
            magnitudes = aug_op[0](2, 0, 0)
            if magnitudes is not None:
                randint_values.append(5)
            # Stable API, if signed there is another random call
            if aug_op[1]:
                randint_values.append(0)
            # New API, _get_random_item
            randint_values.append(i)
            # New API, random magnitude
            if magnitudes is not None:
                randint_values.append(5)
        randint_values = iter(randint_values)

        mocker.patch("torch.randint", side_effect=lambda *arg, **kwargs: torch.tensor(next(randint_values)))
        mocker.patch("torch.rand", return_value=1.0)

        expected_output = t_ref(inpt)
        output = t(inpt)

        assert_equal(expected_output, output)
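
Besides the randint/rand mocking, AugMix has one extra source of randomness: the Dirichlet draw that produces the mixing weights. Overriding the instance attribute _sample_dirichlet with a softmax, as done above, makes that step deterministic while keeping the weights normalized. A small standalone sketch of the idea:

import torch

concentration = torch.ones(2, 3)  # arbitrary positive parameters

# A true Dirichlet draw differs from call to call (almost surely) ...
sample_a = torch.distributions.Dirichlet(concentration).sample()
sample_b = torch.distributions.Dirichlet(concentration).sample()
assert not torch.equal(sample_a, sample_b)

# ... while a softmax of the same parameters is deterministic and still sums to 1.
weights = concentration.softmax(dim=-1)
assert torch.equal(weights, concentration.softmax(dim=-1))
assert torch.allclose(weights.sum(dim=-1), torch.ones(2))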

    @pytest.mark.parametrize(
        "interpolation",
        [
            v2_transforms.InterpolationMode.NEAREST,
            v2_transforms.InterpolationMode.BILINEAR,
        ],
    )
    @pytest.mark.parametrize("fill", [None, 85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1], 1])
    def test_augmix_jit(self, interpolation, fill):
        inpt = torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)
        t_ref = legacy_transforms.AugMix(interpolation=interpolation, mixture_width=1, chain_depth=1, fill=fill)
        t = v2_transforms.AugMix(interpolation=interpolation, mixture_width=1, chain_depth=1, fill=fill)

        tt_ref = torch.jit.script(t_ref)
        tt = torch.jit.script(t)

        torch.manual_seed(12)
        expected_output = tt_ref(inpt)

        torch.manual_seed(12)
        scripted_output = tt(inpt)

        assert_equal(scripted_output, expected_output)

    @pytest.mark.parametrize(
        "inpt",
        [
            torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8),
            PIL.Image.new("RGB", (256, 256), 123),
            tv_tensors.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)),
        ],
    )
    @pytest.mark.parametrize(
        "interpolation",
        [
            v2_transforms.InterpolationMode.NEAREST,
            v2_transforms.InterpolationMode.BILINEAR,
            PIL.Image.NEAREST,
        ],
    )
    def test_aa(self, inpt, interpolation):
        aa_policy = legacy_transforms.AutoAugmentPolicy("imagenet")
        t_ref = legacy_transforms.AutoAugment(aa_policy, interpolation=interpolation)
        t = v2_transforms.AutoAugment(aa_policy, interpolation=interpolation)

        torch.manual_seed(12)
        expected_output = t_ref(inpt)

        torch.manual_seed(12)
        output = t(inpt)

        assert_equal(expected_output, output)

    @pytest.mark.parametrize(
        "interpolation",
        [
            v2_transforms.InterpolationMode.NEAREST,
            v2_transforms.InterpolationMode.BILINEAR,
        ],
    )
    def test_aa_jit(self, interpolation):
        inpt = torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)
        aa_policy = legacy_transforms.AutoAugmentPolicy("imagenet")
        t_ref = legacy_transforms.AutoAugment(aa_policy, interpolation=interpolation)
        t = v2_transforms.AutoAugment(aa_policy, interpolation=interpolation)

        tt_ref = torch.jit.script(t_ref)
        tt = torch.jit.script(t)

        torch.manual_seed(12)
        expected_output = tt_ref(inpt)

        torch.manual_seed(12)
        scripted_output = tt(inpt)

        assert_equal(scripted_output, expected_output)


def import_transforms_from_references(reference):
    HERE = Path(__file__).parent
    PROJECT_ROOT = HERE.parent


@@ -232,7 +232,7 @@ def _check_transform_v1_compatibility(transform, input, *, rtol, atol):
"""If the transform defines the ``_v1_transform_cls`` attribute, checks if the transform has a public, static
``get_params`` method that is the v1 equivalent, the output is close to v1, is scriptable, and the scripted version
can be called without error."""
if type(input) is not torch.Tensor or isinstance(input, PIL.Image.Image):
if not (type(input) is torch.Tensor or isinstance(input, PIL.Image.Image)):
return
v1_transform_cls = transform._v1_transform_cls
......@@ -250,7 +250,7 @@ def _check_transform_v1_compatibility(transform, input, *, rtol, atol):
with freeze_rng_state():
output_v1 = v1_transform(input)
assert_close(output_v2, output_v1, rtol=rtol, atol=atol)
assert_close(F.to_image(output_v2), F.to_image(output_v1), rtol=rtol, atol=atol)
if isinstance(input, PIL.Image.Image):
return
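
The two changed lines above tighten _check_transform_v1_compatibility: the guard now runs the check for plain tensors and PIL images and skips everything else (the old condition returned early for PIL images as well, so only plain tensors were ever checked), and the comparison converts both outputs with F.to_image so tensor and PIL results are compared in a common tensor representation. A small sketch of the corrected guard, with a hypothetical runs_v1_check helper and a tensor subclass standing in for the tv_tensors types:

import PIL.Image
import torch

def runs_v1_check(obj) -> bool:
    # Mirrors the new guard: proceed only for plain tensors and PIL images.
    return type(obj) is torch.Tensor or isinstance(obj, PIL.Image.Image)

class SubclassedTensor(torch.Tensor):  # stand-in for a tensor subclass like tv_tensors.Image
    pass

assert runs_v1_check(torch.rand(3, 8, 8))
assert runs_v1_check(PIL.Image.new("RGB", (8, 8)))
assert not runs_v1_check(torch.rand(3, 8, 8).as_subclass(SubclassedTensor))
assert not runs_v1_check("not an image")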

@@ -2772,7 +2772,10 @@ class TestErase:
    )
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_transform(self, make_input, device):
-        check_transform(transforms.RandomErasing(p=1), make_input(device=device))
+        input = make_input(device=device)
+
+        check_transform(
+            transforms.RandomErasing(p=1), input, check_v1_compatibility=not isinstance(input, PIL.Image.Image)
+        )

    def _reference_erase_image(self, image, *, i, j, h, w, v):
        mask = torch.zeros_like(image, dtype=torch.bool)

@@ -2898,3 +2901,111 @@ class TestGaussianBlur:
        else:
            assert sigma[0] <= params["sigma"][0] <= sigma[1]
            assert sigma[0] <= params["sigma"][1] <= sigma[1]


class TestAutoAugmentTransforms:
    # These transforms have a lot of branches in their `forward()` passes that are conditioned on random sampling.
    # It is typically very hard to test the effect of individual parameters without heavy mocking logic.
    # This class adds correctness tests for the kernels that are specific to those transforms. The remaining kernels,
    # e.g. rotate, are tested in their respective classes. The rest of the tests here are mostly smoke tests.

    def _reference_shear_translate(self, image, *, transform_id, magnitude, interpolation, fill):
        if isinstance(image, PIL.Image.Image):
            input = image
        else:
            input = F.to_pil_image(image)

        matrix = {
            "ShearX": (1, magnitude, 0, 0, 1, 0),
            "ShearY": (1, 0, 0, magnitude, 1, 0),
            "TranslateX": (1, 0, -int(magnitude), 0, 1, 0),
            "TranslateY": (1, 0, 0, 0, 1, -int(magnitude)),
        }[transform_id]

        output = input.transform(
            input.size, PIL.Image.AFFINE, matrix, resample=pil_modes_mapping[interpolation], fill=fill
        )

        if isinstance(image, PIL.Image.Image):
            return output
        else:
            return F.to_image(output)
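
The matrices above follow PIL's convention for Image.transform with PIL.Image.AFFINE: the 6-tuple (a, b, c, d, e, f) describes the inverse mapping, i.e. each output pixel (x, y) is sampled from the input position (a*x + b*y + c, d*x + e*y + f). That is why the TranslateX matrix carries -int(magnitude): sampling from x - magnitude shifts the content magnitude pixels to the right. A tiny sketch of that convention on a toy image:

import PIL.Image

# One bright pixel at x=0 on a 5x1 grayscale strip.
strip = PIL.Image.new("L", (5, 1))
strip.putpixel((0, 0), 255)

# (1, 0, -2, 0, 1, 0): output (x, y) samples input (x - 2, y), i.e. a shift right by 2.
shifted = strip.transform(strip.size, PIL.Image.AFFINE, (1, 0, -2, 0, 1, 0))
assert shifted.getpixel((2, 0)) == 255
assert shifted.getpixel((0, 0)) == 0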
@pytest.mark.parametrize("transform_id", ["ShearX", "ShearY", "TranslateX", "TranslateY"])
@pytest.mark.parametrize("magnitude", [0.3, -0.2, 0.0])
@pytest.mark.parametrize(
"interpolation", [transforms.InterpolationMode.NEAREST, transforms.InterpolationMode.BILINEAR]
)
@pytest.mark.parametrize("fill", CORRECTNESS_FILLS)
@pytest.mark.parametrize("input_type", ["Tensor", "PIL"])
def test_correctness_shear_translate(self, transform_id, magnitude, interpolation, fill, input_type):
# ShearX/Y and TranslateX/Y are the only ops that are native to the AA transforms. They are modeled after the
# reference implementation:
# https://github.com/tensorflow/models/blob/885fda091c46c59d6c7bb5c7e760935eacc229da/research/autoaugment/augmentation_transforms.py#L273-L362
# All other ops are checked in their respective dedicated tests.
image = make_image(dtype=torch.uint8, device="cpu")
if input_type == "PIL":
image = F.to_pil_image(image)
if "Translate" in transform_id:
# For TranslateX/Y magnitude is a value in pixels
magnitude *= min(F.get_size(image))
actual = transforms.AutoAugment()._apply_image_or_video_transform(
image,
transform_id=transform_id,
magnitude=magnitude,
interpolation=interpolation,
fill={type(image): fill},
)
expected = self._reference_shear_translate(
image, transform_id=transform_id, magnitude=magnitude, interpolation=interpolation, fill=fill
)
if input_type == "PIL":
actual, expected = F.to_image(actual), F.to_image(expected)
if "Shear" in transform_id and input_type == "Tensor":
mae = (actual.float() - expected.float()).abs().mean()
assert mae < (12 if interpolation is transforms.InterpolationMode.NEAREST else 5)
else:
assert_close(actual, expected, rtol=0, atol=1)

    @pytest.mark.parametrize(
        "transform",
        [transforms.AutoAugment(), transforms.RandAugment(), transforms.TrivialAugmentWide(), transforms.AugMix()],
    )
    @pytest.mark.parametrize("make_input", [make_image_tensor, make_image_pil, make_image, make_video])
    @pytest.mark.parametrize("dtype", [torch.uint8, torch.float32])
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_transform_smoke(self, transform, make_input, dtype, device):
        if make_input is make_image_pil and not (dtype is torch.uint8 and device == "cpu"):
            pytest.skip(
                "PIL image tests with parametrization other than dtype=torch.uint8 and device='cpu' "
                "will degenerate to that anyway."
            )

        input = make_input(dtype=dtype, device=device)

        with freeze_rng_state():
            # By default every test starts from the same random seed. This leads to minimal coverage of the sampling
            # that happens inside forward(). To avoid calling the transform multiple times to achieve higher coverage,
            # we build a reproducible random seed from the input type, dtype, and device.
            torch.manual_seed(hash((make_input, dtype, device)))

            # For v2, we changed the random sampling of the AA transforms. This makes it impossible to compare the v1
            # and v2 outputs without complicated mocking and monkeypatching. Thus, we skip the v1 compatibility checks
            # here and only check if we can script the v2 transform and subsequently call the result.
            check_transform(transform, input, check_v1_compatibility=False)

            if type(input) is torch.Tensor and dtype is torch.uint8:
                _script(transform)(input)
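
The seeding comment above describes a small but useful pattern: derive the manual seed from the parametrization so that each combination exercises a different sampled branch, and do it inside an RNG guard so the rest of the test session is unaffected (freeze_rng_state is the test-suite helper; torch.random.fork_rng is a stock way to get a similar guard). A minimal sketch with a hypothetical run_with_param_seed helper:

import torch

def run_with_param_seed(fn, *params):
    # Fork the global RNG so the seed below does not leak into other tests,
    # then derive a reproducible (per-process) seed from the parametrization itself.
    with torch.random.fork_rng():
        torch.manual_seed(hash(params) % (2**63))
        return fn()

# Different parametrizations see different, but repeatable, random draws.
a = run_with_param_seed(lambda: torch.rand(3), "uint8", "cpu")
a_again = run_with_param_seed(lambda: torch.rand(3), "uint8", "cpu")
assert torch.equal(a, a_again)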

    def test_auto_augment_policy_error(self):
        with pytest.raises(ValueError, match="provided policy"):
            transforms.AutoAugment(policy=None)

    @pytest.mark.parametrize("severity", [0, 11])
    def test_aug_mix_severity_error(self, severity):
        with pytest.raises(ValueError, match="severity must be between"):
            transforms.AugMix(severity=severity)