"vscode:/vscode.git/clone" did not exist on "dc5ede718f7876cb0aeba338ec5da2bd6db94ed2"
Unverified Commit 74a1efcd authored by vfdev's avatar vfdev Committed by GitHub
Browse files

Fixed reparametrization for shear X/Y in autoaugment ops (#5384)

* Added ref tests for shear X/Y

* Added PIL tests and fixed tan(level) difference

* Updated tests

* Fixed reparam for shear X/Y in autoaugment

* Fixed arc_level -> level as atan is applied internally

* Fixed links
parent 60327755
......@@ -187,7 +187,7 @@ def _assert_approx_equal_tensor_to_pil(
tensor = tensor.to(torch.float)
pil_tensor = pil_tensor.to(torch.float)
err = getattr(torch, agg_method)(torch.abs(tensor - pil_tensor)).item()
assert err < tol
assert err < tol, f"{err} vs {tol}"
def _test_fn_on_batch(batch_tensors, fn, scripted_fn_atol=1e-8, **fn_kwargs):
......
......@@ -14,9 +14,11 @@ from common_utils import (
cpu_and_gpu,
assert_equal,
)
from PIL import Image
from torchvision import transforms as T
from torchvision.transforms import InterpolationMode
from torchvision.transforms import functional as F
from torchvision.transforms.autoaugment import _apply_op
NEAREST, BILINEAR, BICUBIC = InterpolationMode.NEAREST, InterpolationMode.BILINEAR, InterpolationMode.BICUBIC
......@@ -725,6 +727,48 @@ def test_autoaugment_save(augmentation, tmpdir):
s_transform.save(os.path.join(tmpdir, "t_autoaugment.pt"))
@pytest.mark.parametrize("interpolation", [F.InterpolationMode.NEAREST, F.InterpolationMode.BILINEAR])
@pytest.mark.parametrize("mode", ["X", "Y"])
def test_autoaugment__op_apply_shear(interpolation, mode):
    """Compare ``_apply_op`` ShearX/ShearY output with a PIL reference shear.

    The reference follows the official CIFAR10 autoaugment implementation:
    https://github.com/tensorflow/models/blob/885fda091c46c59d6c7bb5c7e760935eacc229da/research/autoaugment/augmentation_transforms.py#L273-L290
    """
    image_size = 32
    level = 0.3
    op_name = f"Shear{mode}"

    def reference_shear(image, factor, axis, resample):
        # The shear factor is the second affine coefficient for X and the
        # fourth one for Y; all other coefficients describe the identity.
        if axis == "X":
            coeffs = (1, factor, 0, 0, 1, 0)
        elif axis == "Y":
            coeffs = (1, 0, 0, factor, 1, 0)
        return image.transform((image_size, image_size), Image.AFFINE, coeffs, resample=resample)

    t_img, pil_img = _create_data(image_size, image_size)

    # Translate torchvision's interpolation mode into PIL's resample flag.
    pil_resample = {
        F.InterpolationMode.NEAREST: Image.NEAREST,
        F.InterpolationMode.BILINEAR: Image.BILINEAR,
    }[interpolation]

    expected_out = reference_shear(pil_img, level, axis=mode, resample=pil_resample)

    # PIL input: the op output must match the reference exactly.
    pil_out = _apply_op(pil_img, op_name=op_name, magnitude=level, interpolation=interpolation, fill=0)
    assert pil_out == expected_out

    # Tensor input is only checked for non-bilinear interpolation: with
    # bilinear mode the affine results are not exactly the same between
    # tensors and PIL images (MAE is around 1.40, max abs error can be
    # 163 or 170).
    if interpolation != F.InterpolationMode.BILINEAR:
        tensor_out = _apply_op(t_img, op_name=op_name, magnitude=level, interpolation=interpolation, fill=0)
        _assert_approx_equal_tensor_to_pil(tensor_out, expected_out)
@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize(
"config",
......
......@@ -14,23 +14,31 @@ def _apply_op(
img: Tensor, op_name: str, magnitude: float, interpolation: InterpolationMode, fill: Optional[List[float]]
):
if op_name == "ShearX":
# the shear angle passed to F.affine must be arctan(magnitude), so that tan(angle) == magnitude
# official autoaug: (1, level, 0, 0, 1, 0)
# https://github.com/tensorflow/models/blob/dd02069717128186b88afa8d857ce57d17957f03/research/autoaugment/augmentation_transforms.py#L290
# compared to
# torchvision: (1, tan(level), 0, 0, 1, 0)
# https://github.com/pytorch/vision/blob/0c2373d0bba3499e95776e7936e207d8a1676e65/torchvision/transforms/functional.py#L976
img = F.affine(
img,
angle=0.0,
translate=[0, 0],
scale=1.0,
shear=[math.degrees(magnitude), 0.0],
shear=[math.degrees(math.atan(magnitude)), 0.0],
interpolation=interpolation,
fill=fill,
center=[0, 0],
)
elif op_name == "ShearY":
# the shear angle passed to F.affine must be arctan(magnitude)
# See above
img = F.affine(
img,
angle=0.0,
translate=[0, 0],
scale=1.0,
shear=[0.0, math.degrees(magnitude)],
shear=[0.0, math.degrees(math.atan(magnitude))],
interpolation=interpolation,
fill=fill,
center=[0, 0],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment