Unverified Commit 182f80df authored by Nicolas Hug's avatar Nicolas Hug Committed by GitHub
Browse files

Finish porting test_functional_tensor.py to pytest (#3990)

parent a629a9b2
import itertools import itertools
import os import os
import unittest
import colorsys import colorsys
import math import math
...@@ -31,103 +30,8 @@ from typing import Dict, List, Sequence, Tuple ...@@ -31,103 +30,8 @@ from typing import Dict, List, Sequence, Tuple
NEAREST, BILINEAR, BICUBIC = InterpolationMode.NEAREST, InterpolationMode.BILINEAR, InterpolationMode.BICUBIC NEAREST, BILINEAR, BICUBIC = InterpolationMode.NEAREST, InterpolationMode.BILINEAR, InterpolationMode.BICUBIC
class Tester(unittest.TestCase): @needs_cuda
def test_scale_channel():
def setUp(self):
self.device = "cpu"
def _test_rotate_all_options(self, tensor, pil_img, scripted_rotate, centers):
img_size = pil_img.size
dt = tensor.dtype
for r in [NEAREST, ]:
for a in range(-180, 180, 17):
for e in [True, False]:
for c in centers:
for f in [None, [0, 0, 0], (1, 2, 3), [255, 255, 255], [1, ], (2.0, )]:
f_pil = int(f[0]) if f is not None and len(f) == 1 else f
out_pil_img = F.rotate(pil_img, angle=a, interpolation=r, expand=e, center=c, fill=f_pil)
out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))
for fn in [F.rotate, scripted_rotate]:
out_tensor = fn(tensor, angle=a, interpolation=r, expand=e, center=c, fill=f).cpu()
if out_tensor.dtype != torch.uint8:
out_tensor = out_tensor.to(torch.uint8)
self.assertEqual(
out_tensor.shape,
out_pil_tensor.shape,
msg="{}: {} vs {}".format(
(img_size, r, dt, a, e, c), out_tensor.shape, out_pil_tensor.shape
))
num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0
ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2]
# Tolerance : less than 3% of different pixels
self.assertLess(
ratio_diff_pixels,
0.03,
msg="{}: {}\n{} vs \n{}".format(
(img_size, r, dt, a, e, c, f),
ratio_diff_pixels,
out_tensor[0, :7, :7],
out_pil_tensor[0, :7, :7]
)
)
def test_rotate(self):
# Tests on square image
scripted_rotate = torch.jit.script(F.rotate)
data = [_create_data(26, 26, device=self.device), _create_data(32, 26, device=self.device)]
for tensor, pil_img in data:
img_size = pil_img.size
centers = [
None,
(int(img_size[0] * 0.3), int(img_size[0] * 0.4)),
[int(img_size[0] * 0.5), int(img_size[0] * 0.6)]
]
for dt in [None, torch.float32, torch.float64, torch.float16]:
if dt == torch.float16 and torch.device(self.device).type == "cpu":
# skip float16 on CPU case
continue
if dt is not None:
tensor = tensor.to(dtype=dt)
self._test_rotate_all_options(tensor, pil_img, scripted_rotate, centers)
batch_tensors = _create_data_batch(26, 36, num_samples=4, device=self.device)
if dt is not None:
batch_tensors = batch_tensors.to(dtype=dt)
center = (20, 22)
_test_fn_on_batch(
batch_tensors, F.rotate, angle=32, interpolation=NEAREST, expand=True, center=center
)
tensor, pil_img = data[0]
# assert deprecation warning and non-BC
with self.assertWarnsRegex(UserWarning, r"Argument resample is deprecated and will be removed"):
res1 = F.rotate(tensor, 45, resample=2)
res2 = F.rotate(tensor, 45, interpolation=BILINEAR)
assert_equal(res1, res2)
# assert changed type warning
with self.assertWarnsRegex(UserWarning, r"Argument interpolation should be of type InterpolationMode"):
res1 = F.rotate(tensor, 45, interpolation=2)
res2 = F.rotate(tensor, 45, interpolation=BILINEAR)
assert_equal(res1, res2)
@unittest.skipIf(not torch.cuda.is_available(), reason="Skip if no CUDA device")
class CUDATester(Tester):
def setUp(self):
self.device = "cuda"
def test_scale_channel(self):
"""Make sure that _scale_channel gives the same results on CPU and GPU as """Make sure that _scale_channel gives the same results on CPU and GPU as
histc or bincount are used depending on the device. histc or bincount are used depending on the device.
""" """
...@@ -1271,4 +1175,4 @@ def test_ten_crop(device): ...@@ -1271,4 +1175,4 @@ def test_ten_crop(device):
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() pytest.main([__file__])
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment