Unverified commit 39702993, authored by vfdev and committed by GitHub
Browse files

Adapted functional tensor tests on CPU/CUDA (#2569)



* Adapted almost all functional tensor tests on CPU/CUDA
- fixed bug with transforms using generated grid
- remains *_crop, blocked by #2568
- TODO: test_adjustments

* Apply suggestions from code review
Co-authored-by: Francisco Massa <fvsmassa@gmail.com>

* Fixed issues according to review

* Split tests into two: cpu and cuda

* Updated test_adjustments to run on CPU and CUDA
Co-authored-by: Francisco Massa <fvsmassa@gmail.com>
parent b16914be
......@@ -17,16 +17,16 @@ import torchvision.transforms.functional as F
class Tester(unittest.TestCase):
def _create_data(self, height=3, width=3, channels=3):
tensor = torch.randint(0, 255, (channels, height, width), dtype=torch.uint8)
pil_img = Image.fromarray(tensor.permute(1, 2, 0).contiguous().numpy())
def _create_data(self, height=3, width=3, channels=3, device="cpu"):
tensor = torch.randint(0, 255, (channels, height, width), dtype=torch.uint8, device=device)
pil_img = Image.fromarray(tensor.permute(1, 2, 0).contiguous().cpu().numpy())
return tensor, pil_img
def compareTensorToPIL(self, tensor, pil_image, msg=None):
pil_tensor = torch.as_tensor(np.array(pil_image).transpose((2, 0, 1)))
if msg is None:
msg = "tensor:\n{} \ndid not equal PIL tensor:\n{}".format(tensor, pil_tensor)
self.assertTrue(tensor.equal(pil_tensor), msg)
self.assertTrue(tensor.cpu().equal(pil_tensor), msg)
def approxEqualTensorToPIL(self, tensor, pil_image, tol=1e-5, msg=None):
pil_tensor = torch.as_tensor(np.array(pil_image).transpose((2, 0, 1))).to(tensor)
......@@ -36,9 +36,9 @@ class Tester(unittest.TestCase):
msg="{}: mae={}, tol={}: \n{}\nvs\n{}".format(msg, mae, tol, tensor[0, :10, :10], pil_tensor[0, :10, :10])
)
def test_vflip(self):
def _test_vflip(self, device):
script_vflip = torch.jit.script(F_t.vflip)
img_tensor = torch.randn(3, 16, 16)
img_tensor = torch.randn(3, 16, 16, device=device)
img_tensor_clone = img_tensor.clone()
vflipped_img = F_t.vflip(img_tensor)
vflipped_img_again = F_t.vflip(vflipped_img)
......@@ -49,9 +49,16 @@ class Tester(unittest.TestCase):
vflipped_img_script = script_vflip(img_tensor)
self.assertTrue(torch.equal(vflipped_img, vflipped_img_script))
def test_hflip(self):
def test_vflip_cpu(self):
self._test_vflip("cpu")
@unittest.skipIf(not torch.cuda.is_available(), reason="Skip if no CUDA device")
def test_vflip_cuda(self):
self._test_vflip("cuda")
def _test_hflip(self, device):
script_hflip = torch.jit.script(F_t.hflip)
img_tensor = torch.randn(3, 16, 16)
img_tensor = torch.randn(3, 16, 16, device=device)
img_tensor_clone = img_tensor.clone()
hflipped_img = F_t.hflip(img_tensor)
hflipped_img_again = F_t.hflip(hflipped_img)
......@@ -62,10 +69,17 @@ class Tester(unittest.TestCase):
hflipped_img_script = script_hflip(img_tensor)
self.assertTrue(torch.equal(hflipped_img, hflipped_img_script))
def test_crop(self):
script_crop = torch.jit.script(F_t.crop)
def test_hflip_cpu(self):
self._test_hflip("cpu")
@unittest.skipIf(not torch.cuda.is_available(), reason="Skip if no CUDA device")
def test_hflip_cuda(self):
self._test_hflip("cuda")
def _test_crop(self, device):
script_crop = torch.jit.script(F.crop)
img_tensor, pil_img = self._create_data(16, 18)
img_tensor, pil_img = self._create_data(16, 18, device=device)
test_configs = [
(1, 2, 4, 5), # crop inside top-left corner
......@@ -83,6 +97,13 @@ class Tester(unittest.TestCase):
img_tensor_cropped = script_crop(img_tensor, top, left, height, width)
self.compareTensorToPIL(img_tensor_cropped, pil_img_cropped)
def test_crop_cpu(self):
self._test_crop("cpu")
@unittest.skipIf(not torch.cuda.is_available(), reason="Skip if no CUDA device")
def test_crop_cuda(self):
self._test_crop("cuda")
def test_hsv2rgb(self):
shape = (3, 100, 150)
for _ in range(20):
......@@ -128,7 +149,7 @@ class Tester(unittest.TestCase):
self.assertLess(max_diff, 1e-5)
def test_adjustments(self):
def _test_adjustments(self, device):
script_adjust_brightness = torch.jit.script(F_t.adjust_brightness)
script_adjust_contrast = torch.jit.script(F_t.adjust_contrast)
script_adjust_saturation = torch.jit.script(F_t.adjust_saturation)
......@@ -143,16 +164,16 @@ class Tester(unittest.TestCase):
shape = (channels, dims[0], dims[1])
if torch.randint(0, 2, (1,)) == 0:
img = torch.rand(*shape, dtype=torch.float)
img = torch.rand(*shape, dtype=torch.float, device=device)
else:
img = torch.randint(0, 256, shape, dtype=torch.uint8)
img = torch.randint(0, 256, shape, dtype=torch.uint8, device=device)
factor = 3 * torch.rand(1)
factor = 3 * torch.rand(1).item()
img_clone = img.clone()
for f, ft, sft in fns:
ft_img = ft(img, factor)
sft_img = sft(img, factor)
ft_img = ft(img, factor).cpu()
sft_img = sft(img, factor).cpu()
if not img.dtype.is_floating_point:
ft_img = ft_img.to(torch.float) / 255
sft_img = sft_img.to(torch.float) / 255
......@@ -170,15 +191,15 @@ class Tester(unittest.TestCase):
self.assertTrue(torch.equal(img, img_clone))
# test for class interface
f = transforms.ColorJitter(brightness=factor.item())
f = transforms.ColorJitter(brightness=factor)
scripted_fn = torch.jit.script(f)
scripted_fn(img)
f = transforms.ColorJitter(contrast=factor.item())
f = transforms.ColorJitter(contrast=factor)
scripted_fn = torch.jit.script(f)
scripted_fn(img)
f = transforms.ColorJitter(saturation=factor.item())
f = transforms.ColorJitter(saturation=factor)
scripted_fn = torch.jit.script(f)
scripted_fn(img)
......@@ -186,6 +207,13 @@ class Tester(unittest.TestCase):
scripted_fn = torch.jit.script(f)
scripted_fn(img)
def test_adjustments(self):
self._test_adjustments("cpu")
@unittest.skipIf(not torch.cuda.is_available(), reason="Skip if no CUDA device")
def test_adjustments_cuda(self):
self._test_adjustments("cuda")
def test_rgb_to_grayscale(self):
script_rgb_to_grayscale = torch.jit.script(F_t.rgb_to_grayscale)
img_tensor = torch.randint(0, 255, (3, 16, 16), dtype=torch.uint8)
......@@ -199,10 +227,10 @@ class Tester(unittest.TestCase):
grayscale_script = script_rgb_to_grayscale(img_tensor).to(int)
self.assertTrue(torch.equal(grayscale_script, grayscale_tensor))
def test_center_crop(self):
def _test_center_crop(self, device):
script_center_crop = torch.jit.script(F.center_crop)
img_tensor, pil_img = self._create_data(32, 34)
img_tensor, pil_img = self._create_data(32, 34, device=device)
cropped_pil_image = F.center_crop(pil_img, [10, 11])
......@@ -212,10 +240,17 @@ class Tester(unittest.TestCase):
cropped_tensor = script_center_crop(img_tensor, [10, 11])
self.compareTensorToPIL(cropped_tensor, cropped_pil_image)
def test_five_crop(self):
def test_center_crop(self):
self._test_center_crop("cpu")
@unittest.skipIf(not torch.cuda.is_available(), reason="Skip if no CUDA device")
def test_center_crop_cuda(self):
self._test_center_crop("cuda")
def _test_five_crop(self, device):
script_five_crop = torch.jit.script(F.five_crop)
img_tensor, pil_img = self._create_data(32, 34)
img_tensor, pil_img = self._create_data(32, 34, device=device)
cropped_pil_images = F.five_crop(pil_img, [10, 11])
......@@ -227,10 +262,17 @@ class Tester(unittest.TestCase):
for i in range(5):
self.compareTensorToPIL(cropped_tensors[i], cropped_pil_images[i])
def test_ten_crop(self):
def test_five_crop(self):
self._test_five_crop("cpu")
@unittest.skipIf(not torch.cuda.is_available(), reason="Skip if no CUDA device")
def test_five_crop_cuda(self):
self._test_five_crop("cuda")
def _test_ten_crop(self, device):
script_ten_crop = torch.jit.script(F.ten_crop)
img_tensor, pil_img = self._create_data(32, 34)
img_tensor, pil_img = self._create_data(32, 34, device=device)
cropped_pil_images = F.ten_crop(pil_img, [10, 11])
......@@ -242,9 +284,16 @@ class Tester(unittest.TestCase):
for i in range(10):
self.compareTensorToPIL(cropped_tensors[i], cropped_pil_images[i])
def test_pad(self):
def test_ten_crop(self):
self._test_ten_crop("cpu")
@unittest.skipIf(not torch.cuda.is_available(), reason="Skip if no CUDA device")
def test_ten_crop_cuda(self):
self._test_ten_crop("cuda")
def _test_pad(self, device):
script_fn = torch.jit.script(F_t.pad)
tensor, pil_img = self._create_data(7, 8)
tensor, pil_img = self._create_data(7, 8, device=device)
for dt in [None, torch.float32, torch.float64]:
if dt is not None:
......@@ -280,9 +329,16 @@ class Tester(unittest.TestCase):
with self.assertRaises(ValueError, msg="Padding can not be negative for symmetric padding_mode"):
F_t.pad(tensor, (-2, -3), padding_mode="symmetric")
def test_adjust_gamma(self):
script_fn = torch.jit.script(F_t.adjust_gamma)
tensor, pil_img = self._create_data(26, 36)
def test_pad_cpu(self):
self._test_pad("cpu")
@unittest.skipIf(not torch.cuda.is_available(), reason="Skip if no CUDA device")
def test_pad_cuda(self):
self._test_pad("cuda")
def _test_adjust_gamma(self, device):
script_fn = torch.jit.script(F.adjust_gamma)
tensor, pil_img = self._create_data(26, 36, device=device)
for dt in [torch.float64, torch.float32, None]:
......@@ -293,8 +349,8 @@ class Tester(unittest.TestCase):
gains = [0.7, 1.0, 1.3]
for gamma, gain in zip(gammas, gains):
adjusted_tensor = F_t.adjust_gamma(tensor, gamma, gain)
adjusted_pil = F_pil.adjust_gamma(pil_img, gamma, gain)
adjusted_tensor = F.adjust_gamma(tensor, gamma, gain)
adjusted_pil = F.adjust_gamma(pil_img, gamma, gain)
scripted_result = script_fn(tensor, gamma, gain)
self.assertEqual(adjusted_tensor.dtype, scripted_result.dtype)
self.assertEqual(adjusted_tensor.size()[1:], adjusted_pil.size[::-1])
......@@ -305,11 +361,18 @@ class Tester(unittest.TestCase):
self.compareTensorToPIL(rbg_tensor, adjusted_pil)
self.assertTrue(adjusted_tensor.equal(scripted_result))
self.assertTrue(adjusted_tensor.allclose(scripted_result))
def test_adjust_gamma_cpu(self):
self._test_adjust_gamma("cpu")
@unittest.skipIf(not torch.cuda.is_available(), reason="Skip if no CUDA device")
def test_adjust_gamma_cuda(self):
self._test_adjust_gamma("cuda")
def test_resize(self):
def _test_resize(self, device):
script_fn = torch.jit.script(F_t.resize)
tensor, pil_img = self._create_data(26, 36)
tensor, pil_img = self._create_data(26, 36, device=device)
for dt in [None, torch.float32, torch.float64]:
if dt is not None:
......@@ -345,16 +408,23 @@ class Tester(unittest.TestCase):
resize_result = script_fn(tensor, size=script_size, interpolation=interpolation)
self.assertTrue(resized_tensor.equal(resize_result), msg="{}, {}".format(size, interpolation))
def test_resized_crop(self):
def test_resize_cpu(self):
self._test_resize("cpu")
@unittest.skipIf(not torch.cuda.is_available(), reason="Skip if no CUDA device")
def test_resize_cuda(self):
self._test_resize("cuda")
def _test_resized_crop(self, device):
# test values of F.resized_crop in several cases:
# 1) resize to the same size, crop to the same size => should be identity
tensor, _ = self._create_data(26, 36)
tensor, _ = self._create_data(26, 36, device=device)
for i in [0, 2, 3]:
out_tensor = F.resized_crop(tensor, top=0, left=0, height=26, width=36, size=[26, 36], interpolation=i)
self.assertTrue(tensor.equal(out_tensor), msg="{} vs {}".format(out_tensor[0, :5, :5], tensor[0, :5, :5]))
# 2) resize by half and crop a TL corner
tensor, _ = self._create_data(26, 36)
tensor, _ = self._create_data(26, 36, device=device)
out_tensor = F.resized_crop(tensor, top=0, left=0, height=20, width=30, size=[10, 15], interpolation=0)
expected_out_tensor = tensor[:, :20:2, :30:2]
self.assertTrue(
......@@ -362,11 +432,18 @@ class Tester(unittest.TestCase):
msg="{} vs {}".format(expected_out_tensor[0, :10, :10], out_tensor[0, :10, :10])
)
def test_affine(self):
def test_resized_crop_cpu(self):
self._test_resized_crop("cpu")
@unittest.skipIf(not torch.cuda.is_available(), reason="Skip if no CUDA device")
def test_resized_crop_cuda(self):
self._test_resized_crop("cuda")
def _test_affine(self, device):
# Tests on square and rectangular images
scripted_affine = torch.jit.script(F.affine)
for tensor, pil_img in [self._create_data(26, 26), self._create_data(32, 26)]:
for tensor, pil_img in [self._create_data(26, 26, device=device), self._create_data(32, 26, device=device)]:
# 1) identity map
out_tensor = F.affine(tensor, angle=0, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0)
......@@ -390,8 +467,16 @@ class Tester(unittest.TestCase):
(180, torch.rot90(tensor, k=2, dims=(-1, -2))),
]
for a, true_tensor in test_configs:
out_pil_img = F.affine(
pil_img, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0
)
out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1))).to(device)
for fn in [F.affine, scripted_affine]:
out_tensor = fn(tensor, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0)
out_tensor = fn(
tensor, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0
)
if true_tensor is not None:
self.assertTrue(
true_tensor.equal(out_tensor),
......@@ -400,11 +485,6 @@ class Tester(unittest.TestCase):
else:
true_tensor = out_tensor
out_pil_img = F.affine(
pil_img, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0
)
out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))
num_diff_pixels = (true_tensor != out_pil_tensor).sum().item() / 3.0
ratio_diff_pixels = num_diff_pixels / true_tensor.shape[-1] / true_tensor.shape[-2]
# Tolerance : less than 6% of different pixels
......@@ -420,12 +500,16 @@ class Tester(unittest.TestCase):
90, 45, 15, -30, -60, -120
]
for a in test_configs:
out_pil_img = F.affine(
pil_img, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0
)
out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))
for fn in [F.affine, scripted_affine]:
out_tensor = fn(tensor, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0)
out_pil_img = F.affine(
pil_img, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0
)
out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))
out_tensor = fn(
tensor, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0
).cpu()
num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0
ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2]
......@@ -443,9 +527,12 @@ class Tester(unittest.TestCase):
[10, 12], (-12, -13)
]
for t in test_configs:
out_pil_img = F.affine(pil_img, angle=0, translate=t, scale=1.0, shear=[0.0, 0.0], resample=0)
for fn in [F.affine, scripted_affine]:
out_tensor = fn(tensor, angle=0, translate=t, scale=1.0, shear=[0.0, 0.0], resample=0)
out_pil_img = F.affine(pil_img, angle=0, translate=t, scale=1.0, shear=[0.0, 0.0], resample=0)
self.compareTensorToPIL(out_tensor, out_pil_img)
# 3) Test rotation + translation + scale + shear
......@@ -467,23 +554,31 @@ class Tester(unittest.TestCase):
out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))
for fn in [F.affine, scripted_affine]:
out_tensor = fn(tensor, angle=a, translate=t, scale=s, shear=sh, resample=r)
out_tensor = fn(tensor, angle=a, translate=t, scale=s, shear=sh, resample=r).cpu()
num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0
ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2]
# Tolerance : less than 5% of different pixels
# Tolerance : less than 5% (cpu), 6% (cuda) of different pixels
tol = 0.06 if device == "cuda" else 0.05
self.assertLess(
ratio_diff_pixels,
0.05,
tol,
msg="{}: {}\n{} vs \n{}".format(
(r, a, t, s, sh), ratio_diff_pixels, out_tensor[0, :7, :7], out_pil_tensor[0, :7, :7]
)
)
def test_rotate(self):
def test_affine_cpu(self):
self._test_affine("cpu")
@unittest.skipIf(not torch.cuda.is_available(), reason="Skip if no CUDA device")
def test_affine_cuda(self):
self._test_affine("cuda")
def _test_rotate(self, device):
# Tests on square and rectangular images
scripted_rotate = torch.jit.script(F.rotate)
for tensor, pil_img in [self._create_data(26, 26), self._create_data(32, 26)]:
for tensor, pil_img in [self._create_data(26, 26, device=device), self._create_data(32, 26, device=device)]:
img_size = pil_img.size
centers = [
......@@ -500,7 +595,7 @@ class Tester(unittest.TestCase):
out_pil_img = F.rotate(pil_img, angle=a, resample=r, expand=e, center=c)
out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))
for fn in [F.rotate, scripted_rotate]:
out_tensor = fn(tensor, angle=a, resample=r, expand=e, center=c)
out_tensor = fn(tensor, angle=a, resample=r, expand=e, center=c).cpu()
self.assertEqual(
out_tensor.shape,
......@@ -523,11 +618,18 @@ class Tester(unittest.TestCase):
)
)
def test_perspective(self):
def test_rotate_cpu(self):
self._test_rotate("cpu")
@unittest.skipIf(not torch.cuda.is_available(), reason="Skip if no CUDA device")
def test_rotate_cuda(self):
self._test_rotate("cuda")
def _test_perspective(self, device):
from torchvision.transforms import RandomPerspective
for tensor, pil_img in [self._create_data(26, 34), self._create_data(26, 26)]:
for tensor, pil_img in [self._create_data(26, 34, device=device), self._create_data(26, 26, device=device)]:
scripted_tranform = torch.jit.script(F.perspective)
......@@ -547,7 +649,7 @@ class Tester(unittest.TestCase):
out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))
for fn in [F.perspective, scripted_tranform]:
out_tensor = fn(tensor, startpoints=spoints, endpoints=epoints, interpolation=r)
out_tensor = fn(tensor, startpoints=spoints, endpoints=epoints, interpolation=r).cpu()
num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0
ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2]
......@@ -563,6 +665,13 @@ class Tester(unittest.TestCase):
)
)
def test_perspective_cpu(self):
self._test_perspective("cpu")
@unittest.skipIf(not torch.cuda.is_available(), reason="Skip if no CUDA device")
def test_perspective_cuda(self):
self._test_perspective("cuda")
if __name__ == '__main__':
unittest.main()
......@@ -223,10 +223,10 @@ def to_pil_image(pic, mode=None):
pic = np.expand_dims(pic, 2)
npimg = pic
if isinstance(pic, torch.FloatTensor) and mode != 'F':
pic = pic.mul(255).byte()
if isinstance(pic, torch.Tensor):
npimg = np.transpose(pic.numpy(), (1, 2, 0))
if pic.is_floating_point() and mode != 'F':
pic = pic.mul(255).byte()
npimg = np.transpose(pic.cpu().numpy(), (1, 2, 0))
if not isinstance(npimg, np.ndarray):
raise TypeError('Input pic must be a torch.Tensor or NumPy ndarray, ' +
......
......@@ -3,7 +3,7 @@ from typing import Optional, Dict, Tuple
import torch
from torch import Tensor
from torch.nn.functional import affine_grid, grid_sample
from torch.nn.functional import grid_sample
from torch.jit.annotations import List, BroadcastingList2
......@@ -714,12 +714,13 @@ def _gen_affine_grid(
# 2) we can normalize by another image size, so that it covers the "expand" option of PIL.Image.rotate
d = 0.5
base_grid = torch.empty(1, oh, ow, 3)
base_grid = torch.empty(1, oh, ow, 3, dtype=theta.dtype, device=theta.device)
base_grid[..., 0].copy_(torch.linspace(-ow * 0.5 + d, ow * 0.5 + d - 1, steps=ow))
base_grid[..., 1].copy_(torch.linspace(-oh * 0.5 + d, oh * 0.5 + d - 1, steps=oh).unsqueeze_(-1))
base_grid[..., 2].fill_(1)
output_grid = base_grid.view(1, oh * ow, 3).bmm(theta.transpose(1, 2) / torch.tensor([0.5 * w, 0.5 * h]))
rescaled_theta = theta.transpose(1, 2) / torch.tensor([0.5 * w, 0.5 * h], dtype=theta.dtype, device=theta.device)
output_grid = base_grid.view(1, oh * ow, 3).bmm(rescaled_theta)
return output_grid.view(1, oh, ow, 2)
......@@ -746,14 +747,15 @@ def affine(
_assert_grid_transform_inputs(img, matrix, resample, fillcolor, _interpolation_modes)
theta = torch.tensor(matrix, dtype=torch.float).reshape(1, 2, 3)
theta = torch.tensor(matrix, dtype=torch.float, device=img.device).reshape(1, 2, 3)
shape = img.shape
# grid will be generated on the same device as theta and img
grid = _gen_affine_grid(theta, w=shape[-1], h=shape[-2], ow=shape[-1], oh=shape[-2])
mode = _interpolation_modes[resample]
return _apply_grid_transform(img, grid, mode)
def _compute_output_size(theta: Tensor, w: int, h: int) -> Tuple[int, int]:
def _compute_output_size(matrix: List[float], w: int, h: int) -> Tuple[int, int]:
# Inspired by the PIL implementation:
# https://github.com/python-pillow/Pillow/blob/11de3318867e4398057373ee9f12dcb33db7335c/src/PIL/Image.py#L2054
......@@ -765,6 +767,7 @@ def _compute_output_size(theta: Tensor, w: int, h: int) -> Tuple[int, int]:
[0.5 * w, 0.5 * h, 1.0],
[0.5 * w, -0.5 * h, 1.0],
])
theta = torch.tensor(matrix, dtype=torch.float).reshape(1, 2, 3)
new_pts = pts.view(1, 4, 3).bmm(theta.transpose(1, 2)).view(4, 2)
min_vals, _ = new_pts.min(dim=0)
max_vals, _ = new_pts.max(dim=0)
......@@ -807,16 +810,17 @@ def rotate(
}
_assert_grid_transform_inputs(img, matrix, resample, fill, _interpolation_modes)
theta = torch.tensor(matrix).reshape(1, 2, 3)
w, h = img.shape[-1], img.shape[-2]
ow, oh = _compute_output_size(theta, w, h) if expand else (w, h)
ow, oh = _compute_output_size(matrix, w, h) if expand else (w, h)
theta = torch.tensor(matrix, dtype=torch.float, device=img.device).reshape(1, 2, 3)
# grid will be generated on the same device as theta and img
grid = _gen_affine_grid(theta, w=w, h=h, ow=ow, oh=oh)
mode = _interpolation_modes[resample]
return _apply_grid_transform(img, grid, mode)
def _perspective_grid(coeffs: List[float], ow: int, oh: int):
def _perspective_grid(coeffs: List[float], ow: int, oh: int, device: torch.device):
# https://github.com/python-pillow/Pillow/blob/4634eafe3c695a014267eefdce830b4a825beed7/
# src/libImaging/Geometry.c#L394
......@@ -828,19 +832,20 @@ def _perspective_grid(coeffs: List[float], ow: int, oh: int):
theta1 = torch.tensor([[
[coeffs[0], coeffs[1], coeffs[2]],
[coeffs[3], coeffs[4], coeffs[5]]
]])
]], dtype=torch.float, device=device)
theta2 = torch.tensor([[
[coeffs[6], coeffs[7], 1.0],
[coeffs[6], coeffs[7], 1.0]
]])
]], dtype=torch.float, device=device)
d = 0.5
base_grid = torch.empty(1, oh, ow, 3)
base_grid = torch.empty(1, oh, ow, 3, dtype=torch.float, device=device)
base_grid[..., 0].copy_(torch.linspace(d, ow * 1.0 + d - 1.0, steps=ow))
base_grid[..., 1].copy_(torch.linspace(d, oh * 1.0 + d - 1.0, steps=oh).unsqueeze_(-1))
base_grid[..., 2].fill_(1)
output_grid1 = base_grid.view(1, oh * ow, 3).bmm(theta1.transpose(1, 2) / torch.tensor([0.5 * ow, 0.5 * oh]))
rescaled_theta1 = theta1.transpose(1, 2) / torch.tensor([0.5 * ow, 0.5 * oh], dtype=torch.float, device=device)
output_grid1 = base_grid.view(1, oh * ow, 3).bmm(rescaled_theta1)
output_grid2 = base_grid.view(1, oh * ow, 3).bmm(theta2.transpose(1, 2))
output_grid = output_grid1 / output_grid2 - 1.0
......@@ -880,7 +885,7 @@ def perspective(
)
ow, oh = img.shape[-1], img.shape[-2]
grid = _perspective_grid(perspective_coeffs, ow=ow, oh=oh)
grid = _perspective_grid(perspective_coeffs, ow=ow, oh=oh, device=img.device)
mode = _interpolation_modes[interpolation]
return _apply_grid_transform(img, grid, mode)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment