Unverified Commit 55477476 authored by vfdev, committed by GitHub

Improved functional tensor geom transforms to work on floatX dtype (#2661)

* Improved functional tensor geom transforms to work on floatX dtype
- Fixes #2600
- Added tests
- Refactored test_affine

* Removed float16/cpu case
parent 6662b30a
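The recurring pattern in this patch: instead of always building sampling grids in torch.float32, each transform derives a working dtype from the input image. A minimal sketch of the rule (the helper name _grid_dtype is illustrative; the diff below inlines the expression):

    import torch

    def _grid_dtype(img: torch.Tensor) -> torch.dtype:
        # Floating images (float16/float32/float64) keep their own dtype so
        # the sampling grid matches them exactly; integer images such as
        # uint8 fall back to float32 and are cast temporarily.
        return img.dtype if torch.is_floating_point(img) else torch.float32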
@@ -183,7 +183,12 @@ class Tester(TransformsTester):
script_fn = torch.jit.script(F_t.pad)
tensor, pil_img = self._create_data(7, 8, device=self.device)
- for dt in [None, torch.float32, torch.float64]:
+ for dt in [None, torch.float32, torch.float64, torch.float16]:
+ if dt == torch.float16 and torch.device(self.device).type == "cpu":
+ # skip float16 on CPU case
+ continue
if dt is not None:
# This is a trivial cast to float of uint8 data to test all cases
tensor = tensor.to(dt)
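The float16 cases are skipped on CPU because, at the time of this patch, the CPU kernels behind these transforms (e.g. grid_sample and the interpolation paths) lack torch.float16 support, so half precision is only exercised on CUDA. A sketch of the guard repeated in every test loop below, assuming self.device is a device string such as "cpu" or "cuda:0" (skip_half_on_cpu is a hypothetical helper mirroring the inline check):

    from typing import Optional
    import torch

    def skip_half_on_cpu(device: str, dt: Optional[torch.dtype]) -> bool:
        # Half-precision runs are limited to GPU devices.
        return dt == torch.float16 and torch.device(device).type == "cpu"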
@@ -295,7 +300,12 @@ class Tester(TransformsTester):
script_fn = torch.jit.script(F_t.resize)
tensor, pil_img = self._create_data(26, 36, device=self.device)
- for dt in [None, torch.float32, torch.float64]:
+ for dt in [None, torch.float32, torch.float64, torch.float16]:
+ if dt == torch.float16 and torch.device(self.device).type == "cpu":
+ # skip float16 on CPU case
+ continue
if dt is not None:
# This is a trivial cast to float of uint8 data to test all cases
tensor = tensor.to(dt)
@@ -346,15 +356,10 @@ class Tester(TransformsTester):
msg="{} vs {}".format(expected_out_tensor[0, :10, :10], out_tensor[0, :10, :10])
)
- def test_affine(self):
- # Tests on square and rectangular images
- scripted_affine = torch.jit.script(F.affine)
- data = [self._create_data(26, 26, device=self.device), self._create_data(32, 26, device=self.device)]
- for tensor, pil_img in data:
+ def _test_affine_identity_map(self, tensor, scripted_affine):
# 1) identity map
out_tensor = F.affine(tensor, angle=0, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0)
self.assertTrue(
tensor.equal(out_tensor), msg="{} vs {}".format(out_tensor[0, :5, :5], tensor[0, :5, :5])
)
@@ -363,7 +368,7 @@ class Tester(TransformsTester):
tensor.equal(out_tensor), msg="{} vs {}".format(out_tensor[0, :5, :5], tensor[0, :5, :5])
)
- if pil_img.size[0] == pil_img.size[1]:
+ def _test_affine_square_rotations(self, tensor, pil_img, scripted_affine):
# 2) Test rotation
test_configs = [
(90, torch.rot90(tensor, k=1, dims=(-1, -2))),
@@ -375,7 +380,6 @@ class Tester(TransformsTester):
(180, torch.rot90(tensor, k=2, dims=(-1, -2))),
]
for a, true_tensor in test_configs:
out_pil_img = F.affine(
pil_img, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0
)
@@ -390,20 +394,22 @@ class Tester(TransformsTester):
true_tensor.equal(out_tensor),
msg="{}\n{} vs \n{}".format(a, out_tensor[0, :5, :5], true_tensor[0, :5, :5])
)
else:
true_tensor = out_tensor
- num_diff_pixels = (true_tensor != out_pil_tensor).sum().item() / 3.0
- ratio_diff_pixels = num_diff_pixels / true_tensor.shape[-1] / true_tensor.shape[-2]
+ if out_tensor.dtype != torch.uint8:
+ out_tensor = out_tensor.to(torch.uint8)
+ num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0
+ ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2]
# Tolerance : less than 6% of different pixels
self.assertLess(
ratio_diff_pixels,
0.06,
msg="{}\n{} vs \n{}".format(
- ratio_diff_pixels, true_tensor[0, :7, :7], out_pil_tensor[0, :7, :7]
+ ratio_diff_pixels, out_tensor[0, :7, :7], out_pil_tensor[0, :7, :7]
)
)
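Float outputs can no longer be compared bit-exactly against PIL's uint8 results, so the checks above first cast back to uint8 and then bound the fraction of mismatching pixels. A self-contained sketch of that metric for a pair of 3xHxW images:

    import torch

    def diff_pixel_ratio(out_tensor: torch.Tensor, out_pil_tensor: torch.Tensor) -> float:
        # Quantize float results to uint8 so both sides are compared on
        # the same 8-bit scale.
        if out_tensor.dtype != torch.uint8:
            out_tensor = out_tensor.to(torch.uint8)
        # Count per-channel mismatches, average over the 3 channels, then
        # normalize by the number of pixels (H * W).
        num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0
        return num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2]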
- else:
+ def _test_affine_rect_rotations(self, tensor, pil_img, scripted_affine):
test_configs = [
90, 45, 15, -30, -60, -120
]
@@ -419,6 +425,9 @@ class Tester(TransformsTester):
tensor, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0
).cpu()
+ if out_tensor.dtype != torch.uint8:
+ out_tensor = out_tensor.to(torch.uint8)
num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0
ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2]
# Tolerance : less than 3% of different pixels
@@ -430,6 +439,7 @@ class Tester(TransformsTester):
)
)
+ def _test_affine_translations(self, tensor, pil_img, scripted_affine):
# 3) Test translation
test_configs = [
[10, 12], (-12, -13)
@@ -441,9 +451,13 @@ class Tester(TransformsTester):
for fn in [F.affine, scripted_affine]:
out_tensor = fn(tensor, angle=0, translate=t, scale=1.0, shear=[0.0, 0.0], resample=0)
+ if out_tensor.dtype != torch.uint8:
+ out_tensor = out_tensor.to(torch.uint8)
self.compareTensorToPIL(out_tensor, out_pil_img)
- # 3) Test rotation + translation + scale + shear
+ def _test_affine_all_ops(self, tensor, pil_img, scripted_affine):
+ # 4) Test rotation + translation + scale + shear
test_configs = [
(45, [5, 6], 1.0, [0.0, 0.0]),
(33, (5, -4), 1.0, [0.0, 0.0]),
@@ -463,6 +477,10 @@ class Tester(TransformsTester):
for fn in [F.affine, scripted_affine]:
out_tensor = fn(tensor, angle=a, translate=t, scale=s, shear=sh, resample=r).cpu()
+ if out_tensor.dtype != torch.uint8:
+ out_tensor = out_tensor.to(torch.uint8)
num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0
ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2]
# Tolerance : less than 5% (cpu), 6% (cuda) of different pixels
@@ -475,6 +493,30 @@ class Tester(TransformsTester):
)
)
+ def test_affine(self):
+ # Tests on square and rectangular images
+ scripted_affine = torch.jit.script(F.affine)
+ data = [self._create_data(26, 26, device=self.device), self._create_data(32, 26, device=self.device)]
+ for tensor, pil_img in data:
+ for dt in [None, torch.float32, torch.float64, torch.float16]:
+ if dt == torch.float16 and torch.device(self.device).type == "cpu":
+ # skip float16 on CPU case
+ continue
+ if dt is not None:
+ tensor = tensor.to(dtype=dt)
+ self._test_affine_identity_map(tensor, scripted_affine)
+ if pil_img.size[0] == pil_img.size[1]:
+ self._test_affine_square_rotations(tensor, pil_img, scripted_affine)
+ else:
+ self._test_affine_rect_rotations(tensor, pil_img, scripted_affine)
+ self._test_affine_translations(tensor, pil_img, scripted_affine)
+ # self._test_affine_all_ops(tensor, pil_img, scripted_affine)
def test_rotate(self):
# Tests on square image
scripted_rotate = torch.jit.script(F.rotate)
@@ -489,6 +531,15 @@ class Tester(TransformsTester):
[int(img_size[0] * 0.5), int(img_size[0] * 0.6)]
]
+ for dt in [None, torch.float32, torch.float64, torch.float16]:
+ if dt == torch.float16 and torch.device(self.device).type == "cpu":
+ # skip float16 on CPU case
+ continue
+ if dt is not None:
+ tensor = tensor.to(dtype=dt)
for r in [0, ]:
for a in range(-180, 180, 17):
for e in [True, False]:
@@ -499,21 +550,24 @@ class Tester(TransformsTester):
for fn in [F.rotate, scripted_rotate]:
out_tensor = fn(tensor, angle=a, resample=r, expand=e, center=c).cpu()
+ if out_tensor.dtype != torch.uint8:
+ out_tensor = out_tensor.to(torch.uint8)
self.assertEqual(
out_tensor.shape,
out_pil_tensor.shape,
msg="{}: {} vs {}".format(
- (img_size, r, a, e, c), out_tensor.shape, out_pil_tensor.shape
+ (img_size, r, dt, a, e, c), out_tensor.shape, out_pil_tensor.shape
)
)
num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0
ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2]
- # Tolerance : less than 2% of different pixels
+ # Tolerance : less than 3% of different pixels
self.assertLess(
ratio_diff_pixels,
- 0.02,
+ 0.03,
msg="{}: {}\n{} vs \n{}".format(
- (img_size, r, a, e, c),
+ (img_size, r, dt, a, e, c),
ratio_diff_pixels,
out_tensor[0, :7, :7],
out_pil_tensor[0, :7, :7]
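The rotation tolerance is relaxed from 2% to 3% because the float paths, once quantized back to uint8, disagree with PIL on slightly more boundary pixels than the pure uint8 path did: for the 26x26 test image this allows roughly 0.03 * 26 * 26 ≈ 20 differing pixels per comparison instead of roughly 13.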
@@ -525,10 +579,10 @@ class Tester(TransformsTester):
from torchvision.transforms import RandomPerspective
data = [self._create_data(26, 34, device=self.device), self._create_data(26, 26, device=self.device)]
- for tensor, pil_img in data:
scripted_tranform = torch.jit.script(F.perspective)
+ for tensor, pil_img in data:
test_configs = [
[[[0, 0], [33, 0], [33, 25], [0, 25]], [[3, 2], [32, 3], [30, 24], [2, 25]]],
[[[3, 2], [32, 3], [30, 24], [2, 25]], [[0, 0], [33, 0], [33, 25], [0, 25]]],
@@ -539,6 +593,15 @@ class Tester(TransformsTester):
RandomPerspective.get_params(pil_img.size[0], pil_img.size[1], i / n) for i in range(n)
]
+ for dt in [None, torch.float32, torch.float64, torch.float16]:
+ if dt == torch.float16 and torch.device(self.device).type == "cpu":
+ # skip float16 on CPU case
+ continue
+ if dt is not None:
+ tensor = tensor.to(dtype=dt)
for r in [0, ]:
for spoints, epoints in test_configs:
out_pil_img = F.perspective(pil_img, startpoints=spoints, endpoints=epoints, interpolation=r)
@@ -547,6 +610,9 @@ class Tester(TransformsTester):
for fn in [F.perspective, scripted_tranform]:
out_tensor = fn(tensor, startpoints=spoints, endpoints=epoints, interpolation=r).cpu()
+ if out_tensor.dtype != torch.uint8:
+ out_tensor = out_tensor.to(torch.uint8)
num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0
ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2]
# Tolerance : less than 5% of different pixels
@@ -554,7 +620,7 @@ class Tester(TransformsTester):
ratio_diff_pixels,
0.05,
msg="{}: {}\n{} vs \n{}".format(
- (r, spoints, epoints),
+ (r, dt, spoints, epoints),
ratio_diff_pixels,
out_tensor[0, :7, :7],
out_pil_tensor[0, :7, :7]
@@ -718,9 +718,9 @@ def _apply_grid_transform(img: Tensor, grid: Tensor, mode: str) -> Tensor:
out_dtype = img.dtype
need_cast = False
- if img.dtype not in (torch.float32, torch.float64):
+ if out_dtype != grid.dtype:
need_cast = True
- img = img.to(torch.float32)
+ img = img.to(grid)
img = grid_sample(img, grid, mode=mode, padding_mode="zeros", align_corners=False)
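With the image cast to the grid's dtype, grid_sample runs natively in float16/float32/float64 and only integer inputs pay a round trip. A condensed sketch of the resulting pattern, assuming NCHW input (the real function also unsqueezes 3D inputs, and the cast-back step lies outside the shown hunk, so the rounding below is an assumption):

    import torch
    from torch.nn.functional import grid_sample

    def apply_grid_transform(img: torch.Tensor, grid: torch.Tensor, mode: str) -> torch.Tensor:
        out_dtype = img.dtype
        need_cast = out_dtype != grid.dtype
        if need_cast:
            # Tensor.to(grid) matches the grid's dtype and device in one call.
            img = img.to(grid)
        img = grid_sample(img, grid, mode=mode, padding_mode="zeros", align_corners=False)
        if need_cast:
            # Round before casting back so integer dtypes do not truncate.
            img = torch.round(img).to(out_dtype)
        return img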
@@ -777,7 +777,8 @@ def affine(
_assert_grid_transform_inputs(img, matrix, resample, fillcolor, _interpolation_modes)
- theta = torch.tensor(matrix, dtype=torch.float, device=img.device).reshape(1, 2, 3)
+ dtype = img.dtype if torch.is_floating_point(img) else torch.float32
+ theta = torch.tensor(matrix, dtype=dtype, device=img.device).reshape(1, 2, 3)
shape = img.shape
# grid will be generated on the same device as theta and img
grid = _gen_affine_grid(theta, w=shape[-1], h=shape[-2], ow=shape[-1], oh=shape[-2])
@@ -842,7 +843,8 @@ def rotate(
_assert_grid_transform_inputs(img, matrix, resample, fill, _interpolation_modes)
w, h = img.shape[-1], img.shape[-2]
ow, oh = _compute_output_size(matrix, w, h) if expand else (w, h)
- theta = torch.tensor(matrix, dtype=torch.float, device=img.device).reshape(1, 2, 3)
+ dtype = img.dtype if torch.is_floating_point(img) else torch.float32
+ theta = torch.tensor(matrix, dtype=dtype, device=img.device).reshape(1, 2, 3)
# grid will be generated on the same device as theta and img
grid = _gen_affine_grid(theta, w=w, h=h, ow=ow, oh=oh)
mode = _interpolation_modes[resample]
@@ -850,7 +852,7 @@
return _apply_grid_transform(img, grid, mode)
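For rotate with expand=True, _compute_output_size enlarges the canvas so the rotated image is not clipped. The library derives the size from the affine matrix; the equivalent closed-form bounding-box arithmetic, as an illustration only:

    import math
    from typing import Tuple

    def expanded_size(w: int, h: int, angle_deg: float) -> Tuple[int, int]:
        # Axis-aligned bounding box of a w x h rectangle rotated by angle_deg.
        a = math.radians(angle_deg)
        ow = int(abs(w * math.cos(a)) + abs(h * math.sin(a)) + 0.5)
        oh = int(abs(w * math.sin(a)) + abs(h * math.cos(a)) + 0.5)
        return ow, oh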
- def _perspective_grid(coeffs: List[float], ow: int, oh: int, device: torch.device):
+ def _perspective_grid(coeffs: List[float], ow: int, oh: int, dtype: torch.dtype, device: torch.device):
# https://github.com/python-pillow/Pillow/blob/4634eafe3c695a014267eefdce830b4a825beed7/
# src/libImaging/Geometry.c#L394
@@ -858,23 +860,22 @@ def _perspective_grid(coeffs: List[float], ow: int, oh: int, device: torch.device):
# x_out = (coeffs[0] * x + coeffs[1] * y + coeffs[2]) / (coeffs[6] * x + coeffs[7] * y + 1)
# y_out = (coeffs[3] * x + coeffs[4] * y + coeffs[5]) / (coeffs[6] * x + coeffs[7] * y + 1)
#
theta1 = torch.tensor([[
[coeffs[0], coeffs[1], coeffs[2]],
[coeffs[3], coeffs[4], coeffs[5]]
- ]], dtype=torch.float, device=device)
+ ]], dtype=dtype, device=device)
theta2 = torch.tensor([[
[coeffs[6], coeffs[7], 1.0],
[coeffs[6], coeffs[7], 1.0]
- ]], dtype=torch.float, device=device)
+ ]], dtype=dtype, device=device)
d = 0.5
- base_grid = torch.empty(1, oh, ow, 3, dtype=torch.float, device=device)
+ base_grid = torch.empty(1, oh, ow, 3, dtype=dtype, device=device)
base_grid[..., 0].copy_(torch.linspace(d, ow * 1.0 + d - 1.0, steps=ow))
base_grid[..., 1].copy_(torch.linspace(d, oh * 1.0 + d - 1.0, steps=oh).unsqueeze_(-1))
base_grid[..., 2].fill_(1)
- rescaled_theta1 = theta1.transpose(1, 2) / torch.tensor([0.5 * ow, 0.5 * oh], dtype=torch.float, device=device)
+ rescaled_theta1 = theta1.transpose(1, 2) / torch.tensor([0.5 * ow, 0.5 * oh], dtype=dtype, device=device)
output_grid1 = base_grid.view(1, oh * ow, 3).bmm(rescaled_theta1)
output_grid2 = base_grid.view(1, oh * ow, 3).bmm(theta2.transpose(1, 2))
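theta1 carries the two numerator rows of the homography quoted in the comment above, while theta2 duplicates the shared denominator for both coordinates, so the two batched matmuls evaluate the mapping below at every grid point. A standalone sketch for a single pixel, with illustrative (made-up) coefficients:

    # Homography from the Pillow reference above:
    # x_out = (c0*x + c1*y + c2) / (c6*x + c7*y + 1)
    # y_out = (c3*x + c4*y + c5) / (c6*x + c7*y + 1)
    c = [1.0, 0.1, 2.0, 0.0, 1.0, -3.0, 0.001, 0.0]  # illustrative coefficients
    x, y = 10.5, 4.5
    den = c[6] * x + c[7] * y + 1.0
    x_out = (c[0] * x + c[1] * y + c[2]) / den
    y_out = (c[3] * x + c[4] * y + c[5]) / den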
@@ -915,7 +916,8 @@
)
ow, oh = img.shape[-1], img.shape[-2]
- grid = _perspective_grid(perspective_coeffs, ow=ow, oh=oh, device=img.device)
+ dtype = img.dtype if torch.is_floating_point(img) else torch.float32
+ grid = _perspective_grid(perspective_coeffs, ow=ow, oh=oh, dtype=dtype, device=img.device)
mode = _interpolation_modes[interpolation]
return _apply_grid_transform(img, grid, mode)
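End to end, the patched transforms accept floating-point tensors directly. A hedged usage sketch against the functional API of this era (resample still takes PIL integer codes, 0 being nearest):

    import torch
    import torchvision.transforms.functional as F

    img = torch.randint(0, 256, (3, 32, 32), dtype=torch.uint8)

    # uint8 path: the grid is built in float32 and the result is cast back.
    out_u8 = F.affine(img, angle=30, translate=[2, 3], scale=1.1, shear=[0.0, 0.0], resample=0)

    # float path: the whole pipeline stays in the image's own dtype.
    out_f64 = F.affine(img.to(torch.float64), angle=30, translate=[2, 3], scale=1.1, shear=[0.0, 0.0], resample=0)

    # float16 needs CUDA at the time of this patch (see the skipped CPU cases).
    if torch.cuda.is_available():
        out_f16 = F.rotate(img.to("cuda", torch.float16), angle=45.0, resample=0)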