Unverified Commit 025b71d8 authored by vfdev's avatar vfdev Committed by GitHub
Browse files

Fixes F.affine and F.rotate to support rectangular tensor images (#2553)

* Added code for F_t.rotate with test
- updated F.affine tests

* Rotate test tolerance to 2%

* Fixes failing test

* Optimized _expanded_affine_grid with a single matmul op

* Recoded _compute_output_size

* [WIP] recoded F_t.rotate internal methods

* [WIP] Fixed F.affine to support rectangular images

* Recoded _gen_affine_grid to optimized version ~ affine_grid
- Fixes flake8

* [WIP] Use _gen_affine_grid for affine and rotate

* Fixed tests on square / rectangular images for affine and rotate ops

* Removed redefinition of F.rotate
- due to bad merge
parent 76662528
......@@ -385,134 +385,165 @@ class Tester(unittest.TestCase):
)
def test_affine(self):
# Tests on square image
tensor, pil_img = self._create_data(26, 26)
# Tests on square and rectangular images
scripted_affine = torch.jit.script(F.affine)
# 1) identity map
out_tensor = F.affine(tensor, angle=0, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0)
self.assertTrue(
tensor.equal(out_tensor), msg="{} vs {}".format(out_tensor[0, :5, :5], tensor[0, :5, :5])
)
out_tensor = scripted_affine(tensor, angle=0, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0)
self.assertTrue(
tensor.equal(out_tensor), msg="{} vs {}".format(out_tensor[0, :5, :5], tensor[0, :5, :5])
)
# 2) Test rotation
test_configs = [
(90, torch.rot90(tensor, k=1, dims=(-1, -2))),
(45, None),
(30, None),
(-30, None),
(-45, None),
(-90, torch.rot90(tensor, k=-1, dims=(-1, -2))),
(180, torch.rot90(tensor, k=2, dims=(-1, -2))),
]
for a, true_tensor in test_configs:
for fn in [F.affine, scripted_affine]:
out_tensor = fn(tensor, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0)
if true_tensor is not None:
self.assertTrue(
true_tensor.equal(out_tensor),
msg="{}\n{} vs \n{}".format(a, out_tensor[0, :5, :5], true_tensor[0, :5, :5])
)
else:
true_tensor = out_tensor
out_pil_img = F.affine(pil_img, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0)
out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))
num_diff_pixels = (true_tensor != out_pil_tensor).sum().item() / 3.0
ratio_diff_pixels = num_diff_pixels / true_tensor.shape[-1] / true_tensor.shape[-2]
# Tolerance : less than 6% of different pixels
self.assertLess(
ratio_diff_pixels,
0.06,
msg="{}\n{} vs \n{}".format(
ratio_diff_pixels, true_tensor[0, :7, :7], out_pil_tensor[0, :7, :7]
)
)
# 3) Test translation
test_configs = [
[10, 12], (-12, -13)
]
for t in test_configs:
for fn in [F.affine, scripted_affine]:
out_tensor = fn(tensor, angle=0, translate=t, scale=1.0, shear=[0.0, 0.0], resample=0)
out_pil_img = F.affine(pil_img, angle=0, translate=t, scale=1.0, shear=[0.0, 0.0], resample=0)
self.compareTensorToPIL(out_tensor, out_pil_img)
# 3) Test rotation + translation + scale + share
test_configs = [
(45, [5, 6], 1.0, [0.0, 0.0]),
(33, (5, -4), 1.0, [0.0, 0.0]),
(45, [-5, 4], 1.2, [0.0, 0.0]),
(33, (-4, -8), 2.0, [0.0, 0.0]),
(85, (10, -10), 0.7, [0.0, 0.0]),
(0, [0, 0], 1.0, [35.0, ]),
(25, [0, 0], 1.2, [0.0, 15.0]),
(45, [-10, 0], 0.7, [2.0, 5.0]),
(45, [-10, -10], 1.2, [4.0, 5.0]),
]
for r in [0, ]:
for a, t, s, sh in test_configs:
out_pil_img = F.affine(pil_img, angle=a, translate=t, scale=s, shear=sh, resample=r)
out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))
for tensor, pil_img in [self._create_data(26, 26), self._create_data(32, 26)]:
# 1) identity map
out_tensor = F.affine(tensor, angle=0, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0)
self.assertTrue(
tensor.equal(out_tensor), msg="{} vs {}".format(out_tensor[0, :5, :5], tensor[0, :5, :5])
)
out_tensor = scripted_affine(tensor, angle=0, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0)
self.assertTrue(
tensor.equal(out_tensor), msg="{} vs {}".format(out_tensor[0, :5, :5], tensor[0, :5, :5])
)
if pil_img.size[0] == pil_img.size[1]:
# 2) Test rotation
test_configs = [
(90, torch.rot90(tensor, k=1, dims=(-1, -2))),
(45, None),
(30, None),
(-30, None),
(-45, None),
(-90, torch.rot90(tensor, k=-1, dims=(-1, -2))),
(180, torch.rot90(tensor, k=2, dims=(-1, -2))),
]
for a, true_tensor in test_configs:
for fn in [F.affine, scripted_affine]:
out_tensor = fn(tensor, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0)
if true_tensor is not None:
self.assertTrue(
true_tensor.equal(out_tensor),
msg="{}\n{} vs \n{}".format(a, out_tensor[0, :5, :5], true_tensor[0, :5, :5])
)
else:
true_tensor = out_tensor
out_pil_img = F.affine(
pil_img, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0
)
out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))
num_diff_pixels = (true_tensor != out_pil_tensor).sum().item() / 3.0
ratio_diff_pixels = num_diff_pixels / true_tensor.shape[-1] / true_tensor.shape[-2]
# Tolerance : less than 6% of different pixels
self.assertLess(
ratio_diff_pixels,
0.06,
msg="{}\n{} vs \n{}".format(
ratio_diff_pixels, true_tensor[0, :7, :7], out_pil_tensor[0, :7, :7]
)
)
else:
test_configs = [
90, 45, 15, -30, -60, -120
]
for a in test_configs:
for fn in [F.affine, scripted_affine]:
out_tensor = fn(tensor, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0)
out_pil_img = F.affine(
pil_img, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0
)
out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))
num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0
ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2]
# Tolerance : less than 3% of different pixels
self.assertLess(
ratio_diff_pixels,
0.03,
msg="{}: {}\n{} vs \n{}".format(
a, ratio_diff_pixels, out_tensor[0, :7, :7], out_pil_tensor[0, :7, :7]
)
)
# 3) Test translation
test_configs = [
[10, 12], (-12, -13)
]
for t in test_configs:
for fn in [F.affine, scripted_affine]:
out_tensor = fn(tensor, angle=a, translate=t, scale=s, shear=sh, resample=r)
num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0
ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2]
# Tolerance : less than 5% of different pixels
self.assertLess(
ratio_diff_pixels,
0.05,
msg="{}: {}\n{} vs \n{}".format(
(r, a, t, s, sh), ratio_diff_pixels, out_tensor[0, :7, :7], out_pil_tensor[0, :7, :7]
out_tensor = fn(tensor, angle=0, translate=t, scale=1.0, shear=[0.0, 0.0], resample=0)
out_pil_img = F.affine(pil_img, angle=0, translate=t, scale=1.0, shear=[0.0, 0.0], resample=0)
self.compareTensorToPIL(out_tensor, out_pil_img)
# 3) Test rotation + translation + scale + share
test_configs = [
(45, [5, 6], 1.0, [0.0, 0.0]),
(33, (5, -4), 1.0, [0.0, 0.0]),
(45, [-5, 4], 1.2, [0.0, 0.0]),
(33, (-4, -8), 2.0, [0.0, 0.0]),
(85, (10, -10), 0.7, [0.0, 0.0]),
(0, [0, 0], 1.0, [35.0, ]),
(-25, [0, 0], 1.2, [0.0, 15.0]),
(-45, [-10, 0], 0.7, [2.0, 5.0]),
(-45, [-10, -10], 1.2, [4.0, 5.0]),
(-90, [0, 0], 1.0, [0.0, 0.0]),
]
for r in [0, ]:
for a, t, s, sh in test_configs:
out_pil_img = F.affine(pil_img, angle=a, translate=t, scale=s, shear=sh, resample=r)
out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))
for fn in [F.affine, scripted_affine]:
out_tensor = fn(tensor, angle=a, translate=t, scale=s, shear=sh, resample=r)
num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0
ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2]
# Tolerance : less than 5% of different pixels
self.assertLess(
ratio_diff_pixels,
0.05,
msg="{}: {}\n{} vs \n{}".format(
(r, a, t, s, sh), ratio_diff_pixels, out_tensor[0, :7, :7], out_pil_tensor[0, :7, :7]
)
)
)
def test_rotate(self):
# Tests on square image
tensor, pil_img = self._create_data(26, 26)
scripted_rotate = torch.jit.script(F.rotate)
img_size = pil_img.size
centers = [
None,
(int(img_size[0] * 0.3), int(img_size[0] * 0.4)),
[int(img_size[0] * 0.5), int(img_size[0] * 0.6)]
]
for r in [0, ]:
for a in range(-120, 120, 23):
for e in [True, False]:
for c in centers:
out_pil_img = F.rotate(pil_img, angle=a, resample=r, expand=e, center=c)
out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))
for fn in [F.rotate, scripted_rotate]:
out_tensor = fn(tensor, angle=a, resample=r, expand=e, center=c)
self.assertEqual(
out_tensor.shape,
out_pil_tensor.shape,
msg="{}: {} vs {}".format(
(r, a, e, c), out_tensor.shape, out_pil_tensor.shape
for tensor, pil_img in [self._create_data(26, 26), self._create_data(32, 26)]:
img_size = pil_img.size
centers = [
None,
(int(img_size[0] * 0.3), int(img_size[0] * 0.4)),
[int(img_size[0] * 0.5), int(img_size[0] * 0.6)]
]
for r in [0, ]:
for a in range(-180, 180, 17):
for e in [True, False]:
for c in centers:
out_pil_img = F.rotate(pil_img, angle=a, resample=r, expand=e, center=c)
out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))
for fn in [F.rotate, scripted_rotate]:
out_tensor = fn(tensor, angle=a, resample=r, expand=e, center=c)
self.assertEqual(
out_tensor.shape,
out_pil_tensor.shape,
msg="{}: {} vs {}".format(
(img_size, r, a, e, c), out_tensor.shape, out_pil_tensor.shape
)
)
)
num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0
ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2]
# Tolerance : less than 2% of different pixels
self.assertLess(
ratio_diff_pixels,
0.02,
msg="{}: {}\n{} vs \n{}".format(
(r, a, e, c), ratio_diff_pixels, out_tensor[0, :7, :7], out_pil_tensor[0, :7, :7]
num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0
ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2]
# Tolerance : less than 2% of different pixels
self.assertLess(
ratio_diff_pixels,
0.02,
msg="{}: {}\n{} vs \n{}".format(
(img_size, r, a, e, c),
ratio_diff_pixels,
out_tensor[0, :7, :7],
out_pil_tensor[0, :7, :7]
)
)
)
if __name__ == '__main__':
......
......@@ -848,8 +848,9 @@ def rotate(
center_f = [0.0, 0.0]
if center is not None:
img_size = _get_image_size(img)
# Center is normalized to [-1, +1]
center_f = [2.0 * t / s - 1.0 for s, t in zip(img_size, center)]
# Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center.
center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, img_size)]
# due to current incoherence of rotation angle direction between affine and rotate implementations
# we need to set -angle.
matrix = _get_inverse_affine_matrix(center_f, -angle, [0.0, 0.0], 1.0, [0.0, 0.0])
......@@ -926,10 +927,8 @@ def affine(
return F_pil.affine(img, matrix=matrix, resample=resample, fillcolor=fillcolor)
# we need to rescale translate by image size / 2 as its values can be between -1 and 1
translate = [2.0 * t / s for s, t in zip(img_size, translate)]
matrix = _get_inverse_affine_matrix([0.0, 0.0], angle, translate, scale, shear)
translate_f = [1.0 * t for t in translate]
matrix = _get_inverse_affine_matrix([0.0, 0.0], angle, translate_f, scale, shear)
return F_t.affine(img, matrix=matrix, resample=resample, fillcolor=fillcolor)
......
......@@ -663,6 +663,25 @@ def _apply_grid_transform(img: Tensor, grid: Tensor, mode: str) -> Tensor:
return img
def _gen_affine_grid(
theta: Tensor, w: int, h: int, ow: int, oh: int,
) -> Tensor:
# https://github.com/pytorch/pytorch/blob/74b65c32be68b15dc7c9e8bb62459efbfbde33d8/aten/src/ATen/native/
# AffineGridGenerator.cpp#L18
# Difference with AffineGridGenerator is that:
# 1) we normalize grid values after applying theta
# 2) we can normalize by other image size, such that it covers "extend" option like in PIL.Image.rotate
d = 0.5
base_grid = torch.empty(1, oh, ow, 3)
base_grid[..., 0].copy_(torch.linspace(-ow * 0.5 + d, ow * 0.5 + d - 1, steps=ow))
base_grid[..., 1].copy_(torch.linspace(-oh * 0.5 + d, oh * 0.5 + d - 1, steps=oh).unsqueeze_(-1))
base_grid[..., 2].fill_(1)
output_grid = base_grid.view(1, oh * ow, 3).bmm(theta.transpose(1, 2) / torch.tensor([0.5 * w, 0.5 * h]))
return output_grid.view(1, oh, ow, 2)
def affine(
img: Tensor, matrix: List[float], resample: int = 0, fillcolor: Optional[int] = None
) -> Tensor:
......@@ -688,44 +707,33 @@ def affine(
theta = torch.tensor(matrix, dtype=torch.float).reshape(1, 2, 3)
shape = img.shape
grid = affine_grid(theta, size=(1, shape[-3], shape[-2], shape[-1]), align_corners=False)
grid = _gen_affine_grid(theta, w=shape[-1], h=shape[-2], ow=shape[-1], oh=shape[-2])
mode = _interpolation_modes[resample]
return _apply_grid_transform(img, grid, mode)
def _compute_output_size(theta: Tensor, w: int, h: int) -> Tuple[int, int]:
# Inspired of PIL implementation:
# https://github.com/python-pillow/Pillow/blob/11de3318867e4398057373ee9f12dcb33db7335c/src/PIL/Image.py#L2054
# pts are Top-Left, Top-Right, Bottom-Left, Bottom-Right points.
# we need to normalize coordinates according to
# [0, s] is mapped [-1, +1] as theta translation parameters are normalized like that
pts = torch.tensor([
[-1.0, -1.0, 1.0],
[-1.0, 1.0, 1.0],
[1.0, 1.0, 1.0],
[1.0, -1.0, 1.0],
[-0.5 * w, -0.5 * h, 1.0],
[-0.5 * w, 0.5 * h, 1.0],
[0.5 * w, 0.5 * h, 1.0],
[0.5 * w, -0.5 * h, 1.0],
])
# denormalize back to w, h:
new_pts = (torch.matmul(pts, theta.t()) + 1.0) * torch.tensor([w, h]) / 2.0
new_pts = pts.view(1, 4, 3).bmm(theta.transpose(1, 2)).view(4, 2)
min_vals, _ = new_pts.min(dim=0)
max_vals, _ = new_pts.max(dim=0)
size = torch.ceil(max_vals) - torch.floor(min_vals)
return int(size[0]), int(size[1])
def _expanded_affine_grid(theta: Tensor, w: int, h: int, expand: bool = False) -> Tensor:
if expand:
ow, oh = _compute_output_size(theta, w, h)
else:
ow, oh = w, h
d = 0.5 # if not align_corners
x = (torch.arange(ow) + d - ow * 0.5) / (0.5 * w)
y = (torch.arange(oh) + d - oh * 0.5) / (0.5 * h)
y, x = torch.meshgrid(y, x)
pts = torch.stack([x, y, torch.ones_like(x)], dim=-1)
output_grid = torch.matmul(pts, theta.t())
return output_grid.unsqueeze(dim=0)
# Truncate precision to 1e-4 to avoid ceil of Xe-15 to 1.0
tol = 1e-4
cmax = torch.ceil((max_vals / tol).trunc_() * tol)
cmin = torch.floor((min_vals / tol).trunc_() * tol)
size = cmax - cmin
return int(size[0]), int(size[1])
def rotate(
......@@ -736,6 +744,7 @@ def rotate(
Args:
img (Tensor): image to be rotated.
matrix (list of floats): list of 6 float values representing inverse matrix for rotation transformation.
Translation part (``matrix[2]`` and ``matrix[5]``) should be in pixel coordinates.
resample (int, optional): An optional resampling filter. Default is nearest (=0). Other supported values:
bilinear(=2).
expand (bool, optional): Optional expansion flag.
......@@ -757,10 +766,10 @@ def rotate(
}
_assert_grid_transform_inputs(img, matrix, resample, fill, _interpolation_modes)
theta = torch.tensor(matrix).reshape(2, 3)
shape = img.shape
grid = _expanded_affine_grid(theta, shape[-1], shape[-2], expand=expand)
theta = torch.tensor(matrix).reshape(1, 2, 3)
w, h = img.shape[-1], img.shape[-2]
ow, oh = _compute_output_size(theta, w, h) if expand else (w, h)
grid = _gen_affine_grid(theta, w=w, h=h, ow=ow, oh=oh)
mode = _interpolation_modes[resample]
return _apply_grid_transform(img, grid, mode)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment