Unverified Commit a629a9b2 authored by Vivek Kumar's avatar Vivek Kumar Committed by GitHub
Browse files

Port some tests to pytest in test_functional_tensor.py (#3988)

parent 7fb4ef57
......@@ -36,160 +36,89 @@ class Tester(unittest.TestCase):
def setUp(self):
self.device = "cpu"
def test_hsv2rgb(self):
scripted_fn = torch.jit.script(F_t._hsv2rgb)
shape = (3, 100, 150)
for _ in range(10):
hsv_img = torch.rand(*shape, dtype=torch.float, device=self.device)
rgb_img = F_t._hsv2rgb(hsv_img)
ft_img = rgb_img.permute(1, 2, 0).flatten(0, 1)
h, s, v, = hsv_img.unbind(0)
h = h.flatten().cpu().numpy()
s = s.flatten().cpu().numpy()
v = v.flatten().cpu().numpy()
rgb = []
for h1, s1, v1 in zip(h, s, v):
rgb.append(colorsys.hsv_to_rgb(h1, s1, v1))
colorsys_img = torch.tensor(rgb, dtype=torch.float32, device=self.device)
torch.testing.assert_close(ft_img, colorsys_img, rtol=0.0, atol=1e-5)
s_rgb_img = scripted_fn(hsv_img)
torch.testing.assert_close(rgb_img, s_rgb_img)
batch_tensors = _create_data_batch(120, 100, num_samples=4, device=self.device).float()
_test_fn_on_batch(batch_tensors, F_t._hsv2rgb)
def test_rgb2hsv(self):
scripted_fn = torch.jit.script(F_t._rgb2hsv)
shape = (3, 150, 100)
for _ in range(10):
rgb_img = torch.rand(*shape, dtype=torch.float, device=self.device)
hsv_img = F_t._rgb2hsv(rgb_img)
ft_hsv_img = hsv_img.permute(1, 2, 0).flatten(0, 1)
r, g, b, = rgb_img.unbind(dim=-3)
r = r.flatten().cpu().numpy()
g = g.flatten().cpu().numpy()
b = b.flatten().cpu().numpy()
hsv = []
for r1, g1, b1 in zip(r, g, b):
hsv.append(colorsys.rgb_to_hsv(r1, g1, b1))
colorsys_img = torch.tensor(hsv, dtype=torch.float32, device=self.device)
ft_hsv_img_h, ft_hsv_img_sv = torch.split(ft_hsv_img, [1, 2], dim=1)
colorsys_img_h, colorsys_img_sv = torch.split(colorsys_img, [1, 2], dim=1)
max_diff_h = ((colorsys_img_h * 2 * math.pi).sin() - (ft_hsv_img_h * 2 * math.pi).sin()).abs().max()
max_diff_sv = (colorsys_img_sv - ft_hsv_img_sv).abs().max()
max_diff = max(max_diff_h, max_diff_sv)
self.assertLess(max_diff, 1e-5)
s_hsv_img = scripted_fn(rgb_img)
torch.testing.assert_close(hsv_img, s_hsv_img, rtol=1e-5, atol=1e-7)
batch_tensors = _create_data_batch(120, 100, num_samples=4, device=self.device).float()
_test_fn_on_batch(batch_tensors, F_t._rgb2hsv)
def test_rgb_to_grayscale(self):
script_rgb_to_grayscale = torch.jit.script(F.rgb_to_grayscale)
img_tensor, pil_img = _create_data(32, 34, device=self.device)
for num_output_channels in (3, 1):
gray_pil_image = F.rgb_to_grayscale(pil_img, num_output_channels=num_output_channels)
gray_tensor = F.rgb_to_grayscale(img_tensor, num_output_channels=num_output_channels)
_assert_approx_equal_tensor_to_pil(gray_tensor.float(), gray_pil_image, tol=1.0 + 1e-10, agg_method="max")
s_gray_tensor = script_rgb_to_grayscale(img_tensor, num_output_channels=num_output_channels)
assert_equal(s_gray_tensor, gray_tensor)
batch_tensors = _create_data_batch(16, 18, num_samples=4, device=self.device)
_test_fn_on_batch(batch_tensors, F.rgb_to_grayscale, num_output_channels=num_output_channels)
def test_center_crop(self):
script_center_crop = torch.jit.script(F.center_crop)
img_tensor, pil_img = _create_data(32, 34, device=self.device)
cropped_pil_image = F.center_crop(pil_img, [10, 11])
cropped_tensor = F.center_crop(img_tensor, [10, 11])
_assert_equal_tensor_to_pil(cropped_tensor, cropped_pil_image)
cropped_tensor = script_center_crop(img_tensor, [10, 11])
_assert_equal_tensor_to_pil(cropped_tensor, cropped_pil_image)
batch_tensors = _create_data_batch(16, 18, num_samples=4, device=self.device)
_test_fn_on_batch(batch_tensors, F.center_crop, output_size=[10, 11])
def test_five_crop(self):
script_five_crop = torch.jit.script(F.five_crop)
img_tensor, pil_img = _create_data(32, 34, device=self.device)
cropped_pil_images = F.five_crop(pil_img, [10, 11])
cropped_tensors = F.five_crop(img_tensor, [10, 11])
for i in range(5):
_assert_equal_tensor_to_pil(cropped_tensors[i], cropped_pil_images[i])
cropped_tensors = script_five_crop(img_tensor, [10, 11])
for i in range(5):
_assert_equal_tensor_to_pil(cropped_tensors[i], cropped_pil_images[i])
batch_tensors = _create_data_batch(16, 18, num_samples=4, device=self.device)
tuple_transformed_batches = F.five_crop(batch_tensors, [10, 11])
for i in range(len(batch_tensors)):
img_tensor = batch_tensors[i, ...]
tuple_transformed_imgs = F.five_crop(img_tensor, [10, 11])
self.assertEqual(len(tuple_transformed_imgs), len(tuple_transformed_batches))
for j in range(len(tuple_transformed_imgs)):
true_transformed_img = tuple_transformed_imgs[j]
transformed_img = tuple_transformed_batches[j][i, ...]
assert_equal(true_transformed_img, transformed_img)
# scriptable function test
s_tuple_transformed_batches = script_five_crop(batch_tensors, [10, 11])
for transformed_batch, s_transformed_batch in zip(tuple_transformed_batches, s_tuple_transformed_batches):
assert_equal(transformed_batch, s_transformed_batch)
def test_ten_crop(self):
script_ten_crop = torch.jit.script(F.ten_crop)
img_tensor, pil_img = _create_data(32, 34, device=self.device)
cropped_pil_images = F.ten_crop(pil_img, [10, 11])
cropped_tensors = F.ten_crop(img_tensor, [10, 11])
for i in range(10):
_assert_equal_tensor_to_pil(cropped_tensors[i], cropped_pil_images[i])
cropped_tensors = script_ten_crop(img_tensor, [10, 11])
for i in range(10):
_assert_equal_tensor_to_pil(cropped_tensors[i], cropped_pil_images[i])
batch_tensors = _create_data_batch(16, 18, num_samples=4, device=self.device)
tuple_transformed_batches = F.ten_crop(batch_tensors, [10, 11])
for i in range(len(batch_tensors)):
img_tensor = batch_tensors[i, ...]
tuple_transformed_imgs = F.ten_crop(img_tensor, [10, 11])
self.assertEqual(len(tuple_transformed_imgs), len(tuple_transformed_batches))
for j in range(len(tuple_transformed_imgs)):
true_transformed_img = tuple_transformed_imgs[j]
transformed_img = tuple_transformed_batches[j][i, ...]
assert_equal(true_transformed_img, transformed_img)
def _test_rotate_all_options(self, tensor, pil_img, scripted_rotate, centers):
img_size = pil_img.size
dt = tensor.dtype
for r in [NEAREST, ]:
for a in range(-180, 180, 17):
for e in [True, False]:
for c in centers:
for f in [None, [0, 0, 0], (1, 2, 3), [255, 255, 255], [1, ], (2.0, )]:
f_pil = int(f[0]) if f is not None and len(f) == 1 else f
out_pil_img = F.rotate(pil_img, angle=a, interpolation=r, expand=e, center=c, fill=f_pil)
out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))
for fn in [F.rotate, scripted_rotate]:
out_tensor = fn(tensor, angle=a, interpolation=r, expand=e, center=c, fill=f).cpu()
if out_tensor.dtype != torch.uint8:
out_tensor = out_tensor.to(torch.uint8)
self.assertEqual(
out_tensor.shape,
out_pil_tensor.shape,
msg="{}: {} vs {}".format(
(img_size, r, dt, a, e, c), out_tensor.shape, out_pil_tensor.shape
))
num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0
ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2]
# Tolerance : less than 3% of different pixels
self.assertLess(
ratio_diff_pixels,
0.03,
msg="{}: {}\n{} vs \n{}".format(
(img_size, r, dt, a, e, c, f),
ratio_diff_pixels,
out_tensor[0, :7, :7],
out_pil_tensor[0, :7, :7]
)
)
def test_rotate(self):
# Tests on square image
scripted_rotate = torch.jit.script(F.rotate)
data = [_create_data(26, 26, device=self.device), _create_data(32, 26, device=self.device)]
for tensor, pil_img in data:
img_size = pil_img.size
centers = [
None,
(int(img_size[0] * 0.3), int(img_size[0] * 0.4)),
[int(img_size[0] * 0.5), int(img_size[0] * 0.6)]
]
for dt in [None, torch.float32, torch.float64, torch.float16]:
if dt == torch.float16 and torch.device(self.device).type == "cpu":
# skip float16 on CPU case
continue
if dt is not None:
tensor = tensor.to(dtype=dt)
self._test_rotate_all_options(tensor, pil_img, scripted_rotate, centers)
batch_tensors = _create_data_batch(26, 36, num_samples=4, device=self.device)
if dt is not None:
batch_tensors = batch_tensors.to(dtype=dt)
center = (20, 22)
_test_fn_on_batch(
batch_tensors, F.rotate, angle=32, interpolation=NEAREST, expand=True, center=center
)
tensor, pil_img = data[0]
# assert deprecation warning and non-BC
with self.assertWarnsRegex(UserWarning, r"Argument resample is deprecated and will be removed"):
res1 = F.rotate(tensor, 45, resample=2)
res2 = F.rotate(tensor, 45, interpolation=BILINEAR)
assert_equal(res1, res2)
# scriptable function test
s_tuple_transformed_batches = script_ten_crop(batch_tensors, [10, 11])
for transformed_batch, s_transformed_batch in zip(tuple_transformed_batches, s_tuple_transformed_batches):
assert_equal(transformed_batch, s_transformed_batch)
# assert changed type warning
with self.assertWarnsRegex(UserWarning, r"Argument interpolation should be of type InterpolationMode"):
res1 = F.rotate(tensor, 45, interpolation=2)
res2 = F.rotate(tensor, 45, interpolation=BILINEAR)
assert_equal(res1, res2)
@unittest.skipIf(not torch.cuda.is_available(), reason="Skip if no CUDA device")
......@@ -1174,5 +1103,172 @@ def test_gaussian_blur(device, image_size, dt, ksize, sigma, fn):
)
@pytest.mark.parametrize('device', cpu_and_gpu())
def test_hsv2rgb(device):
scripted_fn = torch.jit.script(F_t._hsv2rgb)
shape = (3, 100, 150)
for _ in range(10):
hsv_img = torch.rand(*shape, dtype=torch.float, device=device)
rgb_img = F_t._hsv2rgb(hsv_img)
ft_img = rgb_img.permute(1, 2, 0).flatten(0, 1)
h, s, v, = hsv_img.unbind(0)
h = h.flatten().cpu().numpy()
s = s.flatten().cpu().numpy()
v = v.flatten().cpu().numpy()
rgb = []
for h1, s1, v1 in zip(h, s, v):
rgb.append(colorsys.hsv_to_rgb(h1, s1, v1))
colorsys_img = torch.tensor(rgb, dtype=torch.float32, device=device)
torch.testing.assert_close(ft_img, colorsys_img, rtol=0.0, atol=1e-5)
s_rgb_img = scripted_fn(hsv_img)
torch.testing.assert_close(rgb_img, s_rgb_img)
batch_tensors = _create_data_batch(120, 100, num_samples=4, device=device).float()
_test_fn_on_batch(batch_tensors, F_t._hsv2rgb)
@pytest.mark.parametrize('device', cpu_and_gpu())
def test_rgb2hsv(device):
scripted_fn = torch.jit.script(F_t._rgb2hsv)
shape = (3, 150, 100)
for _ in range(10):
rgb_img = torch.rand(*shape, dtype=torch.float, device=device)
hsv_img = F_t._rgb2hsv(rgb_img)
ft_hsv_img = hsv_img.permute(1, 2, 0).flatten(0, 1)
r, g, b, = rgb_img.unbind(dim=-3)
r = r.flatten().cpu().numpy()
g = g.flatten().cpu().numpy()
b = b.flatten().cpu().numpy()
hsv = []
for r1, g1, b1 in zip(r, g, b):
hsv.append(colorsys.rgb_to_hsv(r1, g1, b1))
colorsys_img = torch.tensor(hsv, dtype=torch.float32, device=device)
ft_hsv_img_h, ft_hsv_img_sv = torch.split(ft_hsv_img, [1, 2], dim=1)
colorsys_img_h, colorsys_img_sv = torch.split(colorsys_img, [1, 2], dim=1)
max_diff_h = ((colorsys_img_h * 2 * math.pi).sin() - (ft_hsv_img_h * 2 * math.pi).sin()).abs().max()
max_diff_sv = (colorsys_img_sv - ft_hsv_img_sv).abs().max()
max_diff = max(max_diff_h, max_diff_sv)
assert max_diff < 1e-5
s_hsv_img = scripted_fn(rgb_img)
torch.testing.assert_close(hsv_img, s_hsv_img, rtol=1e-5, atol=1e-7)
batch_tensors = _create_data_batch(120, 100, num_samples=4, device=device).float()
_test_fn_on_batch(batch_tensors, F_t._rgb2hsv)
@pytest.mark.parametrize('device', cpu_and_gpu())
@pytest.mark.parametrize('num_output_channels', (3, 1))
def test_rgb_to_grayscale(device, num_output_channels):
script_rgb_to_grayscale = torch.jit.script(F.rgb_to_grayscale)
img_tensor, pil_img = _create_data(32, 34, device=device)
gray_pil_image = F.rgb_to_grayscale(pil_img, num_output_channels=num_output_channels)
gray_tensor = F.rgb_to_grayscale(img_tensor, num_output_channels=num_output_channels)
_assert_approx_equal_tensor_to_pil(gray_tensor.float(), gray_pil_image, tol=1.0 + 1e-10, agg_method="max")
s_gray_tensor = script_rgb_to_grayscale(img_tensor, num_output_channels=num_output_channels)
assert_equal(s_gray_tensor, gray_tensor)
batch_tensors = _create_data_batch(16, 18, num_samples=4, device=device)
_test_fn_on_batch(batch_tensors, F.rgb_to_grayscale, num_output_channels=num_output_channels)
@pytest.mark.parametrize('device', cpu_and_gpu())
def test_center_crop(device):
script_center_crop = torch.jit.script(F.center_crop)
img_tensor, pil_img = _create_data(32, 34, device=device)
cropped_pil_image = F.center_crop(pil_img, [10, 11])
cropped_tensor = F.center_crop(img_tensor, [10, 11])
_assert_equal_tensor_to_pil(cropped_tensor, cropped_pil_image)
cropped_tensor = script_center_crop(img_tensor, [10, 11])
_assert_equal_tensor_to_pil(cropped_tensor, cropped_pil_image)
batch_tensors = _create_data_batch(16, 18, num_samples=4, device=device)
_test_fn_on_batch(batch_tensors, F.center_crop, output_size=[10, 11])
@pytest.mark.parametrize('device', cpu_and_gpu())
def test_five_crop(device):
script_five_crop = torch.jit.script(F.five_crop)
img_tensor, pil_img = _create_data(32, 34, device=device)
cropped_pil_images = F.five_crop(pil_img, [10, 11])
cropped_tensors = F.five_crop(img_tensor, [10, 11])
for i in range(5):
_assert_equal_tensor_to_pil(cropped_tensors[i], cropped_pil_images[i])
cropped_tensors = script_five_crop(img_tensor, [10, 11])
for i in range(5):
_assert_equal_tensor_to_pil(cropped_tensors[i], cropped_pil_images[i])
batch_tensors = _create_data_batch(16, 18, num_samples=4, device=device)
tuple_transformed_batches = F.five_crop(batch_tensors, [10, 11])
for i in range(len(batch_tensors)):
img_tensor = batch_tensors[i, ...]
tuple_transformed_imgs = F.five_crop(img_tensor, [10, 11])
assert len(tuple_transformed_imgs) == len(tuple_transformed_batches)
for j in range(len(tuple_transformed_imgs)):
true_transformed_img = tuple_transformed_imgs[j]
transformed_img = tuple_transformed_batches[j][i, ...]
assert_equal(true_transformed_img, transformed_img)
# scriptable function test
s_tuple_transformed_batches = script_five_crop(batch_tensors, [10, 11])
for transformed_batch, s_transformed_batch in zip(tuple_transformed_batches, s_tuple_transformed_batches):
assert_equal(transformed_batch, s_transformed_batch)
@pytest.mark.parametrize('device', cpu_and_gpu())
def test_ten_crop(device):
script_ten_crop = torch.jit.script(F.ten_crop)
img_tensor, pil_img = _create_data(32, 34, device=device)
cropped_pil_images = F.ten_crop(pil_img, [10, 11])
cropped_tensors = F.ten_crop(img_tensor, [10, 11])
for i in range(10):
_assert_equal_tensor_to_pil(cropped_tensors[i], cropped_pil_images[i])
cropped_tensors = script_ten_crop(img_tensor, [10, 11])
for i in range(10):
_assert_equal_tensor_to_pil(cropped_tensors[i], cropped_pil_images[i])
batch_tensors = _create_data_batch(16, 18, num_samples=4, device=device)
tuple_transformed_batches = F.ten_crop(batch_tensors, [10, 11])
for i in range(len(batch_tensors)):
img_tensor = batch_tensors[i, ...]
tuple_transformed_imgs = F.ten_crop(img_tensor, [10, 11])
assert len(tuple_transformed_imgs) == len(tuple_transformed_batches)
for j in range(len(tuple_transformed_imgs)):
true_transformed_img = tuple_transformed_imgs[j]
transformed_img = tuple_transformed_batches[j][i, ...]
assert_equal(true_transformed_img, transformed_img)
# scriptable function test
s_tuple_transformed_batches = script_ten_crop(batch_tensors, [10, 11])
for transformed_batch, s_transformed_batch in zip(tuple_transformed_batches, s_tuple_transformed_batches):
assert_equal(transformed_batch, s_transformed_batch)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment