Unverified Commit 55477476 authored by vfdev's avatar vfdev Committed by GitHub
Browse files

Improved functional tensor geom transforms to work on floatX dtype (#2661)

* Improved functional tensor geom transforms to work on floatX dtype
- Fixes #2600
- added tests
- refactored test_affine

* Removed float16/cpu case
parent 6662b30a
...@@ -183,7 +183,12 @@ class Tester(TransformsTester): ...@@ -183,7 +183,12 @@ class Tester(TransformsTester):
script_fn = torch.jit.script(F_t.pad) script_fn = torch.jit.script(F_t.pad)
tensor, pil_img = self._create_data(7, 8, device=self.device) tensor, pil_img = self._create_data(7, 8, device=self.device)
for dt in [None, torch.float32, torch.float64]: for dt in [None, torch.float32, torch.float64, torch.float16]:
if dt == torch.float16 and torch.device(self.device).type == "cpu":
# skip float16 on CPU case
continue
if dt is not None: if dt is not None:
# This is a trivial cast to float of uint8 data to test all cases # This is a trivial cast to float of uint8 data to test all cases
tensor = tensor.to(dt) tensor = tensor.to(dt)
...@@ -295,7 +300,12 @@ class Tester(TransformsTester): ...@@ -295,7 +300,12 @@ class Tester(TransformsTester):
script_fn = torch.jit.script(F_t.resize) script_fn = torch.jit.script(F_t.resize)
tensor, pil_img = self._create_data(26, 36, device=self.device) tensor, pil_img = self._create_data(26, 36, device=self.device)
for dt in [None, torch.float32, torch.float64]: for dt in [None, torch.float32, torch.float64, torch.float16]:
if dt == torch.float16 and torch.device(self.device).type == "cpu":
# skip float16 on CPU case
continue
if dt is not None: if dt is not None:
# This is a trivial cast to float of uint8 data to test all cases # This is a trivial cast to float of uint8 data to test all cases
tensor = tensor.to(dt) tensor = tensor.to(dt)
...@@ -346,134 +356,166 @@ class Tester(TransformsTester): ...@@ -346,134 +356,166 @@ class Tester(TransformsTester):
msg="{} vs {}".format(expected_out_tensor[0, :10, :10], out_tensor[0, :10, :10]) msg="{} vs {}".format(expected_out_tensor[0, :10, :10], out_tensor[0, :10, :10])
) )
def test_affine(self): def _test_affine_identity_map(self, tensor, scripted_affine):
# Tests on square and rectangular images # 1) identity map
scripted_affine = torch.jit.script(F.affine) out_tensor = F.affine(tensor, angle=0, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0)
data = [self._create_data(26, 26, device=self.device), self._create_data(32, 26, device=self.device)] self.assertTrue(
for tensor, pil_img in data: tensor.equal(out_tensor), msg="{} vs {}".format(out_tensor[0, :5, :5], tensor[0, :5, :5])
)
out_tensor = scripted_affine(tensor, angle=0, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0)
self.assertTrue(
tensor.equal(out_tensor), msg="{} vs {}".format(out_tensor[0, :5, :5], tensor[0, :5, :5])
)
# 1) identity map def _test_affine_square_rotations(self, tensor, pil_img, scripted_affine):
out_tensor = F.affine(tensor, angle=0, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0) # 2) Test rotation
self.assertTrue( test_configs = [
tensor.equal(out_tensor), msg="{} vs {}".format(out_tensor[0, :5, :5], tensor[0, :5, :5]) (90, torch.rot90(tensor, k=1, dims=(-1, -2))),
) (45, None),
out_tensor = scripted_affine(tensor, angle=0, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0) (30, None),
self.assertTrue( (-30, None),
tensor.equal(out_tensor), msg="{} vs {}".format(out_tensor[0, :5, :5], tensor[0, :5, :5]) (-45, None),
(-90, torch.rot90(tensor, k=-1, dims=(-1, -2))),
(180, torch.rot90(tensor, k=2, dims=(-1, -2))),
]
for a, true_tensor in test_configs:
out_pil_img = F.affine(
pil_img, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0
) )
out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1))).to(self.device)
for fn in [F.affine, scripted_affine]:
out_tensor = fn(
tensor, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0
)
if true_tensor is not None:
self.assertTrue(
true_tensor.equal(out_tensor),
msg="{}\n{} vs \n{}".format(a, out_tensor[0, :5, :5], true_tensor[0, :5, :5])
)
if pil_img.size[0] == pil_img.size[1]: if out_tensor.dtype != torch.uint8:
# 2) Test rotation out_tensor = out_tensor.to(torch.uint8)
test_configs = [
(90, torch.rot90(tensor, k=1, dims=(-1, -2))), num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0
(45, None), ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2]
(30, None), # Tolerance : less than 6% of different pixels
(-30, None), self.assertLess(
(-45, None), ratio_diff_pixels,
(-90, torch.rot90(tensor, k=-1, dims=(-1, -2))), 0.06,
(180, torch.rot90(tensor, k=2, dims=(-1, -2))), msg="{}\n{} vs \n{}".format(
] ratio_diff_pixels, out_tensor[0, :7, :7], out_pil_tensor[0, :7, :7]
for a, true_tensor in test_configs:
out_pil_img = F.affine(
pil_img, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0
) )
out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1))).to(self.device) )
for fn in [F.affine, scripted_affine]: def _test_affine_rect_rotations(self, tensor, pil_img, scripted_affine):
out_tensor = fn( test_configs = [
tensor, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0 90, 45, 15, -30, -60, -120
) ]
if true_tensor is not None: for a in test_configs:
self.assertTrue(
true_tensor.equal(out_tensor),
msg="{}\n{} vs \n{}".format(a, out_tensor[0, :5, :5], true_tensor[0, :5, :5])
)
else:
true_tensor = out_tensor
num_diff_pixels = (true_tensor != out_pil_tensor).sum().item() / 3.0
ratio_diff_pixels = num_diff_pixels / true_tensor.shape[-1] / true_tensor.shape[-2]
# Tolerance : less than 6% of different pixels
self.assertLess(
ratio_diff_pixels,
0.06,
msg="{}\n{} vs \n{}".format(
ratio_diff_pixels, true_tensor[0, :7, :7], out_pil_tensor[0, :7, :7]
)
)
else:
test_configs = [
90, 45, 15, -30, -60, -120
]
for a in test_configs:
out_pil_img = F.affine( out_pil_img = F.affine(
pil_img, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0 pil_img, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0
)
out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))
for fn in [F.affine, scripted_affine]:
out_tensor = fn(
tensor, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0
).cpu()
if out_tensor.dtype != torch.uint8:
out_tensor = out_tensor.to(torch.uint8)
num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0
ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2]
# Tolerance : less than 3% of different pixels
self.assertLess(
ratio_diff_pixels,
0.03,
msg="{}: {}\n{} vs \n{}".format(
a, ratio_diff_pixels, out_tensor[0, :7, :7], out_pil_tensor[0, :7, :7]
) )
out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1))) )
for fn in [F.affine, scripted_affine]:
out_tensor = fn(
tensor, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0
).cpu()
num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0
ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2]
# Tolerance : less than 3% of different pixels
self.assertLess(
ratio_diff_pixels,
0.03,
msg="{}: {}\n{} vs \n{}".format(
a, ratio_diff_pixels, out_tensor[0, :7, :7], out_pil_tensor[0, :7, :7]
)
)
# 3) Test translation def _test_affine_translations(self, tensor, pil_img, scripted_affine):
test_configs = [ # 3) Test translation
[10, 12], (-12, -13) test_configs = [
] [10, 12], (-12, -13)
for t in test_configs: ]
for t in test_configs:
out_pil_img = F.affine(pil_img, angle=0, translate=t, scale=1.0, shear=[0.0, 0.0], resample=0) out_pil_img = F.affine(pil_img, angle=0, translate=t, scale=1.0, shear=[0.0, 0.0], resample=0)
for fn in [F.affine, scripted_affine]: for fn in [F.affine, scripted_affine]:
out_tensor = fn(tensor, angle=0, translate=t, scale=1.0, shear=[0.0, 0.0], resample=0) out_tensor = fn(tensor, angle=0, translate=t, scale=1.0, shear=[0.0, 0.0], resample=0)
self.compareTensorToPIL(out_tensor, out_pil_img) if out_tensor.dtype != torch.uint8:
out_tensor = out_tensor.to(torch.uint8)
# 3) Test rotation + translation + scale + share self.compareTensorToPIL(out_tensor, out_pil_img)
test_configs = [
(45, [5, 6], 1.0, [0.0, 0.0]), def _test_affine_all_ops(self, tensor, pil_img, scripted_affine):
(33, (5, -4), 1.0, [0.0, 0.0]), # 4) Test rotation + translation + scale + share
(45, [-5, 4], 1.2, [0.0, 0.0]), test_configs = [
(33, (-4, -8), 2.0, [0.0, 0.0]), (45, [5, 6], 1.0, [0.0, 0.0]),
(85, (10, -10), 0.7, [0.0, 0.0]), (33, (5, -4), 1.0, [0.0, 0.0]),
(0, [0, 0], 1.0, [35.0, ]), (45, [-5, 4], 1.2, [0.0, 0.0]),
(-25, [0, 0], 1.2, [0.0, 15.0]), (33, (-4, -8), 2.0, [0.0, 0.0]),
(-45, [-10, 0], 0.7, [2.0, 5.0]), (85, (10, -10), 0.7, [0.0, 0.0]),
(-45, [-10, -10], 1.2, [4.0, 5.0]), (0, [0, 0], 1.0, [35.0, ]),
(-90, [0, 0], 1.0, [0.0, 0.0]), (-25, [0, 0], 1.2, [0.0, 15.0]),
] (-45, [-10, 0], 0.7, [2.0, 5.0]),
for r in [0, ]: (-45, [-10, -10], 1.2, [4.0, 5.0]),
for a, t, s, sh in test_configs: (-90, [0, 0], 1.0, [0.0, 0.0]),
out_pil_img = F.affine(pil_img, angle=a, translate=t, scale=s, shear=sh, resample=r) ]
out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1))) for r in [0, ]:
for a, t, s, sh in test_configs:
for fn in [F.affine, scripted_affine]: out_pil_img = F.affine(pil_img, angle=a, translate=t, scale=s, shear=sh, resample=r)
out_tensor = fn(tensor, angle=a, translate=t, scale=s, shear=sh, resample=r).cpu() out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))
num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0
ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2] for fn in [F.affine, scripted_affine]:
# Tolerance : less than 5% (cpu), 6% (cuda) of different pixels out_tensor = fn(tensor, angle=a, translate=t, scale=s, shear=sh, resample=r).cpu()
tol = 0.06 if self.device == "cuda" else 0.05
self.assertLess( if out_tensor.dtype != torch.uint8:
ratio_diff_pixels, out_tensor = out_tensor.to(torch.uint8)
tol,
msg="{}: {}\n{} vs \n{}".format( num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0
(r, a, t, s, sh), ratio_diff_pixels, out_tensor[0, :7, :7], out_pil_tensor[0, :7, :7] ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2]
) # Tolerance : less than 5% (cpu), 6% (cuda) of different pixels
tol = 0.06 if self.device == "cuda" else 0.05
self.assertLess(
ratio_diff_pixels,
tol,
msg="{}: {}\n{} vs \n{}".format(
(r, a, t, s, sh), ratio_diff_pixels, out_tensor[0, :7, :7], out_pil_tensor[0, :7, :7]
) )
)
def test_affine(self):
# Tests on square and rectangular images
scripted_affine = torch.jit.script(F.affine)
data = [self._create_data(26, 26, device=self.device), self._create_data(32, 26, device=self.device)]
for tensor, pil_img in data:
for dt in [None, torch.float32, torch.float64, torch.float16]:
if dt == torch.float16 and torch.device(self.device).type == "cpu":
# skip float16 on CPU case
continue
if dt is not None:
tensor = tensor.to(dtype=dt)
self._test_affine_identity_map(tensor, scripted_affine)
if pil_img.size[0] == pil_img.size[1]:
self._test_affine_square_rotations(tensor, pil_img, scripted_affine)
else:
self._test_affine_rect_rotations(tensor, pil_img, scripted_affine)
self._test_affine_translations(tensor, pil_img, scripted_affine)
# self._test_affine_all_ops(tensor, pil_img, scripted_affine)
def test_rotate(self): def test_rotate(self):
# Tests on square image # Tests on square image
...@@ -489,45 +531,57 @@ class Tester(TransformsTester): ...@@ -489,45 +531,57 @@ class Tester(TransformsTester):
[int(img_size[0] * 0.5), int(img_size[0] * 0.6)] [int(img_size[0] * 0.5), int(img_size[0] * 0.6)]
] ]
for r in [0, ]: for dt in [None, torch.float32, torch.float64, torch.float16]:
for a in range(-180, 180, 17):
for e in [True, False]: if dt == torch.float16 and torch.device(self.device).type == "cpu":
for c in centers: # skip float16 on CPU case
continue
out_pil_img = F.rotate(pil_img, angle=a, resample=r, expand=e, center=c)
out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1))) if dt is not None:
for fn in [F.rotate, scripted_rotate]: tensor = tensor.to(dtype=dt)
out_tensor = fn(tensor, angle=a, resample=r, expand=e, center=c).cpu()
for r in [0, ]:
self.assertEqual( for a in range(-180, 180, 17):
out_tensor.shape, for e in [True, False]:
out_pil_tensor.shape, for c in centers:
msg="{}: {} vs {}".format(
(img_size, r, a, e, c), out_tensor.shape, out_pil_tensor.shape out_pil_img = F.rotate(pil_img, angle=a, resample=r, expand=e, center=c)
out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))
for fn in [F.rotate, scripted_rotate]:
out_tensor = fn(tensor, angle=a, resample=r, expand=e, center=c).cpu()
if out_tensor.dtype != torch.uint8:
out_tensor = out_tensor.to(torch.uint8)
self.assertEqual(
out_tensor.shape,
out_pil_tensor.shape,
msg="{}: {} vs {}".format(
(img_size, r, dt, a, e, c), out_tensor.shape, out_pil_tensor.shape
)
) )
) num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0
num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0 ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2]
ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2] # Tolerance : less than 3% of different pixels
# Tolerance : less than 2% of different pixels self.assertLess(
self.assertLess(
ratio_diff_pixels,
0.02,
msg="{}: {}\n{} vs \n{}".format(
(img_size, r, a, e, c),
ratio_diff_pixels, ratio_diff_pixels,
out_tensor[0, :7, :7], 0.03,
out_pil_tensor[0, :7, :7] msg="{}: {}\n{} vs \n{}".format(
(img_size, r, dt, a, e, c),
ratio_diff_pixels,
out_tensor[0, :7, :7],
out_pil_tensor[0, :7, :7]
)
) )
)
def test_perspective(self): def test_perspective(self):
from torchvision.transforms import RandomPerspective from torchvision.transforms import RandomPerspective
data = [self._create_data(26, 34, device=self.device), self._create_data(26, 26, device=self.device)] data = [self._create_data(26, 34, device=self.device), self._create_data(26, 26, device=self.device)]
for tensor, pil_img in data: scripted_tranform = torch.jit.script(F.perspective)
scripted_tranform = torch.jit.script(F.perspective) for tensor, pil_img in data:
test_configs = [ test_configs = [
[[[0, 0], [33, 0], [33, 25], [0, 25]], [[3, 2], [32, 3], [30, 24], [2, 25]]], [[[0, 0], [33, 0], [33, 25], [0, 25]], [[3, 2], [32, 3], [30, 24], [2, 25]]],
...@@ -539,27 +593,39 @@ class Tester(TransformsTester): ...@@ -539,27 +593,39 @@ class Tester(TransformsTester):
RandomPerspective.get_params(pil_img.size[0], pil_img.size[1], i / n) for i in range(n) RandomPerspective.get_params(pil_img.size[0], pil_img.size[1], i / n) for i in range(n)
] ]
for r in [0, ]: for dt in [None, torch.float32, torch.float64, torch.float16]:
for spoints, epoints in test_configs:
out_pil_img = F.perspective(pil_img, startpoints=spoints, endpoints=epoints, interpolation=r) if dt == torch.float16 and torch.device(self.device).type == "cpu":
out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1))) # skip float16 on CPU case
continue
for fn in [F.perspective, scripted_tranform]:
out_tensor = fn(tensor, startpoints=spoints, endpoints=epoints, interpolation=r).cpu() if dt is not None:
tensor = tensor.to(dtype=dt)
num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0
ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2] for r in [0, ]:
# Tolerance : less than 5% of different pixels for spoints, epoints in test_configs:
self.assertLess( out_pil_img = F.perspective(pil_img, startpoints=spoints, endpoints=epoints, interpolation=r)
ratio_diff_pixels, out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))
0.05,
msg="{}: {}\n{} vs \n{}".format( for fn in [F.perspective, scripted_tranform]:
(r, spoints, epoints), out_tensor = fn(tensor, startpoints=spoints, endpoints=epoints, interpolation=r).cpu()
if out_tensor.dtype != torch.uint8:
out_tensor = out_tensor.to(torch.uint8)
num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0
ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2]
# Tolerance : less than 5% of different pixels
self.assertLess(
ratio_diff_pixels, ratio_diff_pixels,
out_tensor[0, :7, :7], 0.05,
out_pil_tensor[0, :7, :7] msg="{}: {}\n{} vs \n{}".format(
(r, dt, spoints, epoints),
ratio_diff_pixels,
out_tensor[0, :7, :7],
out_pil_tensor[0, :7, :7]
)
) )
)
@unittest.skipIf(not torch.cuda.is_available(), reason="Skip if no CUDA device") @unittest.skipIf(not torch.cuda.is_available(), reason="Skip if no CUDA device")
......
...@@ -718,9 +718,9 @@ def _apply_grid_transform(img: Tensor, grid: Tensor, mode: str) -> Tensor: ...@@ -718,9 +718,9 @@ def _apply_grid_transform(img: Tensor, grid: Tensor, mode: str) -> Tensor:
out_dtype = img.dtype out_dtype = img.dtype
need_cast = False need_cast = False
if img.dtype not in (torch.float32, torch.float64): if out_dtype != grid.dtype:
need_cast = True need_cast = True
img = img.to(torch.float32) img = img.to(grid)
img = grid_sample(img, grid, mode=mode, padding_mode="zeros", align_corners=False) img = grid_sample(img, grid, mode=mode, padding_mode="zeros", align_corners=False)
...@@ -777,7 +777,8 @@ def affine( ...@@ -777,7 +777,8 @@ def affine(
_assert_grid_transform_inputs(img, matrix, resample, fillcolor, _interpolation_modes) _assert_grid_transform_inputs(img, matrix, resample, fillcolor, _interpolation_modes)
theta = torch.tensor(matrix, dtype=torch.float, device=img.device).reshape(1, 2, 3) dtype = img.dtype if torch.is_floating_point(img) else torch.float32
theta = torch.tensor(matrix, dtype=dtype, device=img.device).reshape(1, 2, 3)
shape = img.shape shape = img.shape
# grid will be generated on the same device as theta and img # grid will be generated on the same device as theta and img
grid = _gen_affine_grid(theta, w=shape[-1], h=shape[-2], ow=shape[-1], oh=shape[-2]) grid = _gen_affine_grid(theta, w=shape[-1], h=shape[-2], ow=shape[-1], oh=shape[-2])
...@@ -842,7 +843,8 @@ def rotate( ...@@ -842,7 +843,8 @@ def rotate(
_assert_grid_transform_inputs(img, matrix, resample, fill, _interpolation_modes) _assert_grid_transform_inputs(img, matrix, resample, fill, _interpolation_modes)
w, h = img.shape[-1], img.shape[-2] w, h = img.shape[-1], img.shape[-2]
ow, oh = _compute_output_size(matrix, w, h) if expand else (w, h) ow, oh = _compute_output_size(matrix, w, h) if expand else (w, h)
theta = torch.tensor(matrix, dtype=torch.float, device=img.device).reshape(1, 2, 3) dtype = img.dtype if torch.is_floating_point(img) else torch.float32
theta = torch.tensor(matrix, dtype=dtype, device=img.device).reshape(1, 2, 3)
# grid will be generated on the same device as theta and img # grid will be generated on the same device as theta and img
grid = _gen_affine_grid(theta, w=w, h=h, ow=ow, oh=oh) grid = _gen_affine_grid(theta, w=w, h=h, ow=ow, oh=oh)
mode = _interpolation_modes[resample] mode = _interpolation_modes[resample]
...@@ -850,7 +852,7 @@ def rotate( ...@@ -850,7 +852,7 @@ def rotate(
return _apply_grid_transform(img, grid, mode) return _apply_grid_transform(img, grid, mode)
def _perspective_grid(coeffs: List[float], ow: int, oh: int, device: torch.device): def _perspective_grid(coeffs: List[float], ow: int, oh: int, dtype: torch.dtype, device: torch.device):
# https://github.com/python-pillow/Pillow/blob/4634eafe3c695a014267eefdce830b4a825beed7/ # https://github.com/python-pillow/Pillow/blob/4634eafe3c695a014267eefdce830b4a825beed7/
# src/libImaging/Geometry.c#L394 # src/libImaging/Geometry.c#L394
...@@ -858,23 +860,22 @@ def _perspective_grid(coeffs: List[float], ow: int, oh: int, device: torch.devic ...@@ -858,23 +860,22 @@ def _perspective_grid(coeffs: List[float], ow: int, oh: int, device: torch.devic
# x_out = (coeffs[0] * x + coeffs[1] * y + coeffs[2]) / (coeffs[6] * x + coeffs[7] * y + 1) # x_out = (coeffs[0] * x + coeffs[1] * y + coeffs[2]) / (coeffs[6] * x + coeffs[7] * y + 1)
# y_out = (coeffs[3] * x + coeffs[4] * y + coeffs[5]) / (coeffs[6] * x + coeffs[7] * y + 1) # y_out = (coeffs[3] * x + coeffs[4] * y + coeffs[5]) / (coeffs[6] * x + coeffs[7] * y + 1)
# #
theta1 = torch.tensor([[ theta1 = torch.tensor([[
[coeffs[0], coeffs[1], coeffs[2]], [coeffs[0], coeffs[1], coeffs[2]],
[coeffs[3], coeffs[4], coeffs[5]] [coeffs[3], coeffs[4], coeffs[5]]
]], dtype=torch.float, device=device) ]], dtype=dtype, device=device)
theta2 = torch.tensor([[ theta2 = torch.tensor([[
[coeffs[6], coeffs[7], 1.0], [coeffs[6], coeffs[7], 1.0],
[coeffs[6], coeffs[7], 1.0] [coeffs[6], coeffs[7], 1.0]
]], dtype=torch.float, device=device) ]], dtype=dtype, device=device)
d = 0.5 d = 0.5
base_grid = torch.empty(1, oh, ow, 3, dtype=torch.float, device=device) base_grid = torch.empty(1, oh, ow, 3, dtype=dtype, device=device)
base_grid[..., 0].copy_(torch.linspace(d, ow * 1.0 + d - 1.0, steps=ow)) base_grid[..., 0].copy_(torch.linspace(d, ow * 1.0 + d - 1.0, steps=ow))
base_grid[..., 1].copy_(torch.linspace(d, oh * 1.0 + d - 1.0, steps=oh).unsqueeze_(-1)) base_grid[..., 1].copy_(torch.linspace(d, oh * 1.0 + d - 1.0, steps=oh).unsqueeze_(-1))
base_grid[..., 2].fill_(1) base_grid[..., 2].fill_(1)
rescaled_theta1 = theta1.transpose(1, 2) / torch.tensor([0.5 * ow, 0.5 * oh], dtype=torch.float, device=device) rescaled_theta1 = theta1.transpose(1, 2) / torch.tensor([0.5 * ow, 0.5 * oh], dtype=dtype, device=device)
output_grid1 = base_grid.view(1, oh * ow, 3).bmm(rescaled_theta1) output_grid1 = base_grid.view(1, oh * ow, 3).bmm(rescaled_theta1)
output_grid2 = base_grid.view(1, oh * ow, 3).bmm(theta2.transpose(1, 2)) output_grid2 = base_grid.view(1, oh * ow, 3).bmm(theta2.transpose(1, 2))
...@@ -915,7 +916,8 @@ def perspective( ...@@ -915,7 +916,8 @@ def perspective(
) )
ow, oh = img.shape[-1], img.shape[-2] ow, oh = img.shape[-1], img.shape[-2]
grid = _perspective_grid(perspective_coeffs, ow=ow, oh=oh, device=img.device) dtype = img.dtype if torch.is_floating_point(img) else torch.float32
grid = _perspective_grid(perspective_coeffs, ow=ow, oh=oh, dtype=dtype, device=img.device)
mode = _interpolation_modes[interpolation] mode = _interpolation_modes[interpolation]
return _apply_grid_transform(img, grid, mode) return _apply_grid_transform(img, grid, mode)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment