Unverified Commit 4c95bb6e authored by Zhiqiang Wang's avatar Zhiqiang Wang Committed by GitHub
Browse files

Port test_rotate in test_functional_tensor.py to pytest (#3983)

parent 964ce1e9
......@@ -191,90 +191,6 @@ class Tester(unittest.TestCase):
for transformed_batch, s_transformed_batch in zip(tuple_transformed_batches, s_tuple_transformed_batches):
assert_equal(transformed_batch, s_transformed_batch)
def _test_rotate_all_options(self, tensor, pil_img, scripted_rotate, centers):
img_size = pil_img.size
dt = tensor.dtype
for r in [NEAREST, ]:
for a in range(-180, 180, 17):
for e in [True, False]:
for c in centers:
for f in [None, [0, 0, 0], (1, 2, 3), [255, 255, 255], [1, ], (2.0, )]:
f_pil = int(f[0]) if f is not None and len(f) == 1 else f
out_pil_img = F.rotate(pil_img, angle=a, interpolation=r, expand=e, center=c, fill=f_pil)
out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))
for fn in [F.rotate, scripted_rotate]:
out_tensor = fn(tensor, angle=a, interpolation=r, expand=e, center=c, fill=f).cpu()
if out_tensor.dtype != torch.uint8:
out_tensor = out_tensor.to(torch.uint8)
self.assertEqual(
out_tensor.shape,
out_pil_tensor.shape,
msg="{}: {} vs {}".format(
(img_size, r, dt, a, e, c), out_tensor.shape, out_pil_tensor.shape
))
num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0
ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2]
# Tolerance : less than 3% of different pixels
self.assertLess(
ratio_diff_pixels,
0.03,
msg="{}: {}\n{} vs \n{}".format(
(img_size, r, dt, a, e, c, f),
ratio_diff_pixels,
out_tensor[0, :7, :7],
out_pil_tensor[0, :7, :7]
)
)
def test_rotate(self):
# Tests on square image
scripted_rotate = torch.jit.script(F.rotate)
data = [_create_data(26, 26, device=self.device), _create_data(32, 26, device=self.device)]
for tensor, pil_img in data:
img_size = pil_img.size
centers = [
None,
(int(img_size[0] * 0.3), int(img_size[0] * 0.4)),
[int(img_size[0] * 0.5), int(img_size[0] * 0.6)]
]
for dt in [None, torch.float32, torch.float64, torch.float16]:
if dt == torch.float16 and torch.device(self.device).type == "cpu":
# skip float16 on CPU case
continue
if dt is not None:
tensor = tensor.to(dtype=dt)
self._test_rotate_all_options(tensor, pil_img, scripted_rotate, centers)
batch_tensors = _create_data_batch(26, 36, num_samples=4, device=self.device)
if dt is not None:
batch_tensors = batch_tensors.to(dtype=dt)
center = (20, 22)
_test_fn_on_batch(
batch_tensors, F.rotate, angle=32, interpolation=NEAREST, expand=True, center=center
)
tensor, pil_img = data[0]
# assert deprecation warning and non-BC
with self.assertWarnsRegex(UserWarning, r"Argument resample is deprecated and will be removed"):
res1 = F.rotate(tensor, 45, resample=2)
res2 = F.rotate(tensor, 45, interpolation=BILINEAR)
assert_equal(res1, res2)
# assert changed type warning
with self.assertWarnsRegex(UserWarning, r"Argument interpolation should be of type InterpolationMode"):
res1 = F.rotate(tensor, 45, interpolation=2)
res2 = F.rotate(tensor, 45, interpolation=BILINEAR)
assert_equal(res1, res2)
@unittest.skipIf(not torch.cuda.is_available(), reason="Skip if no CUDA device")
class CUDATester(Tester):
......@@ -295,6 +211,88 @@ class CUDATester(Tester):
assert_equal(scaled_cpu, scaled_cuda.to('cpu'))
class TestRotate:
ALL_DTYPES = [None, torch.float32, torch.float64, torch.float16]
scripted_rotate = torch.jit.script(F.rotate)
IMG_W = 26
@pytest.mark.parametrize('device', cpu_and_gpu())
@pytest.mark.parametrize('height, width', [(26, IMG_W), (32, IMG_W)])
@pytest.mark.parametrize('center', [
None,
(int(IMG_W * 0.3), int(IMG_W * 0.4)),
[int(IMG_W * 0.5), int(IMG_W * 0.6)],
])
@pytest.mark.parametrize('dt', ALL_DTYPES)
@pytest.mark.parametrize('angle', range(-180, 180, 17))
@pytest.mark.parametrize('expand', [True, False])
@pytest.mark.parametrize('fill', [None, [0, 0, 0], (1, 2, 3), [255, 255, 255], [1, ], (2.0, )])
@pytest.mark.parametrize('fn', [F.rotate, scripted_rotate])
def test_rotate(self, device, height, width, center, dt, angle, expand, fill, fn):
tensor, pil_img = _create_data(height, width, device=device)
if dt == torch.float16 and torch.device(device).type == "cpu":
# skip float16 on CPU case
return
if dt is not None:
tensor = tensor.to(dtype=dt)
f_pil = int(fill[0]) if fill is not None and len(fill) == 1 else fill
out_pil_img = F.rotate(pil_img, angle=angle, interpolation=NEAREST, expand=expand, center=center, fill=f_pil)
out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))
out_tensor = fn(tensor, angle=angle, interpolation=NEAREST, expand=expand, center=center, fill=fill).cpu()
if out_tensor.dtype != torch.uint8:
out_tensor = out_tensor.to(torch.uint8)
assert out_tensor.shape == out_pil_tensor.shape, (
f"{(height, width, NEAREST, dt, angle, expand, center)}: "
f"{out_tensor.shape} vs {out_pil_tensor.shape}")
num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0
ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2]
# Tolerance : less than 3% of different pixels
assert ratio_diff_pixels < 0.03, (
f"{(height, width, NEAREST, dt, angle, expand, center, fill)}: "
f"{ratio_diff_pixels}\n{out_tensor[0, :7, :7]} vs \n"
f"{out_pil_tensor[0, :7, :7]}")
@pytest.mark.parametrize('device', cpu_and_gpu())
@pytest.mark.parametrize('dt', ALL_DTYPES)
def test_rotate_batch(self, device, dt):
if dt == torch.float16 and device == "cpu":
# skip float16 on CPU case
return
batch_tensors = _create_data_batch(26, 36, num_samples=4, device=device)
if dt is not None:
batch_tensors = batch_tensors.to(dtype=dt)
center = (20, 22)
_test_fn_on_batch(
batch_tensors, F.rotate, angle=32, interpolation=NEAREST, expand=True, center=center
)
def test_rotate_deprecation_resample(self):
tensor, _ = _create_data(26, 26)
# assert deprecation warning and non-BC
with pytest.warns(UserWarning, match=r"Argument resample is deprecated and will be removed"):
res1 = F.rotate(tensor, 45, resample=2)
res2 = F.rotate(tensor, 45, interpolation=BILINEAR)
assert_equal(res1, res2)
def test_rotate_interpolation_type(self):
tensor, _ = _create_data(26, 26)
# assert changed type warning
with pytest.warns(UserWarning, match=r"Argument interpolation should be of type InterpolationMode"):
res1 = F.rotate(tensor, 45, interpolation=2)
res2 = F.rotate(tensor, 45, interpolation=BILINEAR)
assert_equal(res1, res2)
class TestAffine:
ALL_DTYPES = [None, torch.float32, torch.float64, torch.float16]
......@@ -473,7 +471,7 @@ class TestAffine:
# Tolerance : less than 5% (cpu), 6% (cuda) of different pixels
tol = 0.06 if device == "cuda" else 0.05
assert ratio_diff_pixels < tol, "{}: {}\n{} vs \n{}".format(
(i, a, t, s, sh, f), ratio_diff_pixels, out_tensor[0, :7, :7], out_pil_tensor[0, :7, :7]
(NEAREST, a, t, s, sh, f), ratio_diff_pixels, out_tensor[0, :7, :7], out_pil_tensor[0, :7, :7]
)
@pytest.mark.parametrize('device', cpu_and_gpu())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment