Unverified Commit a629a9b2 authored by Vivek Kumar's avatar Vivek Kumar Committed by GitHub
Browse files

Port some tests to pytest in test_functional_tensor.py (#3988)

parent 7fb4ef57
...@@ -36,160 +36,89 @@ class Tester(unittest.TestCase): ...@@ -36,160 +36,89 @@ class Tester(unittest.TestCase):
def setUp(self): def setUp(self):
self.device = "cpu" self.device = "cpu"
def test_hsv2rgb(self): def _test_rotate_all_options(self, tensor, pil_img, scripted_rotate, centers):
scripted_fn = torch.jit.script(F_t._hsv2rgb) img_size = pil_img.size
shape = (3, 100, 150) dt = tensor.dtype
for _ in range(10): for r in [NEAREST, ]:
hsv_img = torch.rand(*shape, dtype=torch.float, device=self.device) for a in range(-180, 180, 17):
rgb_img = F_t._hsv2rgb(hsv_img) for e in [True, False]:
ft_img = rgb_img.permute(1, 2, 0).flatten(0, 1) for c in centers:
for f in [None, [0, 0, 0], (1, 2, 3), [255, 255, 255], [1, ], (2.0, )]:
h, s, v, = hsv_img.unbind(0) f_pil = int(f[0]) if f is not None and len(f) == 1 else f
h = h.flatten().cpu().numpy() out_pil_img = F.rotate(pil_img, angle=a, interpolation=r, expand=e, center=c, fill=f_pil)
s = s.flatten().cpu().numpy() out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))
v = v.flatten().cpu().numpy() for fn in [F.rotate, scripted_rotate]:
out_tensor = fn(tensor, angle=a, interpolation=r, expand=e, center=c, fill=f).cpu()
rgb = []
for h1, s1, v1 in zip(h, s, v): if out_tensor.dtype != torch.uint8:
rgb.append(colorsys.hsv_to_rgb(h1, s1, v1)) out_tensor = out_tensor.to(torch.uint8)
colorsys_img = torch.tensor(rgb, dtype=torch.float32, device=self.device)
torch.testing.assert_close(ft_img, colorsys_img, rtol=0.0, atol=1e-5) self.assertEqual(
out_tensor.shape,
s_rgb_img = scripted_fn(hsv_img) out_pil_tensor.shape,
torch.testing.assert_close(rgb_img, s_rgb_img) msg="{}: {} vs {}".format(
(img_size, r, dt, a, e, c), out_tensor.shape, out_pil_tensor.shape
batch_tensors = _create_data_batch(120, 100, num_samples=4, device=self.device).float() ))
_test_fn_on_batch(batch_tensors, F_t._hsv2rgb)
num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0
def test_rgb2hsv(self): ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2]
scripted_fn = torch.jit.script(F_t._rgb2hsv) # Tolerance : less than 3% of different pixels
shape = (3, 150, 100) self.assertLess(
for _ in range(10): ratio_diff_pixels,
rgb_img = torch.rand(*shape, dtype=torch.float, device=self.device) 0.03,
hsv_img = F_t._rgb2hsv(rgb_img) msg="{}: {}\n{} vs \n{}".format(
ft_hsv_img = hsv_img.permute(1, 2, 0).flatten(0, 1) (img_size, r, dt, a, e, c, f),
ratio_diff_pixels,
r, g, b, = rgb_img.unbind(dim=-3) out_tensor[0, :7, :7],
r = r.flatten().cpu().numpy() out_pil_tensor[0, :7, :7]
g = g.flatten().cpu().numpy() )
b = b.flatten().cpu().numpy() )
hsv = [] def test_rotate(self):
for r1, g1, b1 in zip(r, g, b): # Tests on square image
hsv.append(colorsys.rgb_to_hsv(r1, g1, b1)) scripted_rotate = torch.jit.script(F.rotate)
colorsys_img = torch.tensor(hsv, dtype=torch.float32, device=self.device) data = [_create_data(26, 26, device=self.device), _create_data(32, 26, device=self.device)]
for tensor, pil_img in data:
ft_hsv_img_h, ft_hsv_img_sv = torch.split(ft_hsv_img, [1, 2], dim=1)
colorsys_img_h, colorsys_img_sv = torch.split(colorsys_img, [1, 2], dim=1) img_size = pil_img.size
centers = [
max_diff_h = ((colorsys_img_h * 2 * math.pi).sin() - (ft_hsv_img_h * 2 * math.pi).sin()).abs().max() None,
max_diff_sv = (colorsys_img_sv - ft_hsv_img_sv).abs().max() (int(img_size[0] * 0.3), int(img_size[0] * 0.4)),
max_diff = max(max_diff_h, max_diff_sv) [int(img_size[0] * 0.5), int(img_size[0] * 0.6)]
self.assertLess(max_diff, 1e-5) ]
s_hsv_img = scripted_fn(rgb_img) for dt in [None, torch.float32, torch.float64, torch.float16]:
torch.testing.assert_close(hsv_img, s_hsv_img, rtol=1e-5, atol=1e-7)
if dt == torch.float16 and torch.device(self.device).type == "cpu":
batch_tensors = _create_data_batch(120, 100, num_samples=4, device=self.device).float() # skip float16 on CPU case
_test_fn_on_batch(batch_tensors, F_t._rgb2hsv) continue
def test_rgb_to_grayscale(self): if dt is not None:
script_rgb_to_grayscale = torch.jit.script(F.rgb_to_grayscale) tensor = tensor.to(dtype=dt)
img_tensor, pil_img = _create_data(32, 34, device=self.device) self._test_rotate_all_options(tensor, pil_img, scripted_rotate, centers)
for num_output_channels in (3, 1): batch_tensors = _create_data_batch(26, 36, num_samples=4, device=self.device)
gray_pil_image = F.rgb_to_grayscale(pil_img, num_output_channels=num_output_channels) if dt is not None:
gray_tensor = F.rgb_to_grayscale(img_tensor, num_output_channels=num_output_channels) batch_tensors = batch_tensors.to(dtype=dt)
_assert_approx_equal_tensor_to_pil(gray_tensor.float(), gray_pil_image, tol=1.0 + 1e-10, agg_method="max") center = (20, 22)
_test_fn_on_batch(
s_gray_tensor = script_rgb_to_grayscale(img_tensor, num_output_channels=num_output_channels) batch_tensors, F.rotate, angle=32, interpolation=NEAREST, expand=True, center=center
assert_equal(s_gray_tensor, gray_tensor) )
tensor, pil_img = data[0]
batch_tensors = _create_data_batch(16, 18, num_samples=4, device=self.device) # assert deprecation warning and non-BC
_test_fn_on_batch(batch_tensors, F.rgb_to_grayscale, num_output_channels=num_output_channels) with self.assertWarnsRegex(UserWarning, r"Argument resample is deprecated and will be removed"):
res1 = F.rotate(tensor, 45, resample=2)
def test_center_crop(self): res2 = F.rotate(tensor, 45, interpolation=BILINEAR)
script_center_crop = torch.jit.script(F.center_crop) assert_equal(res1, res2)
img_tensor, pil_img = _create_data(32, 34, device=self.device)
cropped_pil_image = F.center_crop(pil_img, [10, 11])
cropped_tensor = F.center_crop(img_tensor, [10, 11])
_assert_equal_tensor_to_pil(cropped_tensor, cropped_pil_image)
cropped_tensor = script_center_crop(img_tensor, [10, 11])
_assert_equal_tensor_to_pil(cropped_tensor, cropped_pil_image)
batch_tensors = _create_data_batch(16, 18, num_samples=4, device=self.device)
_test_fn_on_batch(batch_tensors, F.center_crop, output_size=[10, 11])
def test_five_crop(self):
script_five_crop = torch.jit.script(F.five_crop)
img_tensor, pil_img = _create_data(32, 34, device=self.device)
cropped_pil_images = F.five_crop(pil_img, [10, 11])
cropped_tensors = F.five_crop(img_tensor, [10, 11])
for i in range(5):
_assert_equal_tensor_to_pil(cropped_tensors[i], cropped_pil_images[i])
cropped_tensors = script_five_crop(img_tensor, [10, 11])
for i in range(5):
_assert_equal_tensor_to_pil(cropped_tensors[i], cropped_pil_images[i])
batch_tensors = _create_data_batch(16, 18, num_samples=4, device=self.device)
tuple_transformed_batches = F.five_crop(batch_tensors, [10, 11])
for i in range(len(batch_tensors)):
img_tensor = batch_tensors[i, ...]
tuple_transformed_imgs = F.five_crop(img_tensor, [10, 11])
self.assertEqual(len(tuple_transformed_imgs), len(tuple_transformed_batches))
for j in range(len(tuple_transformed_imgs)):
true_transformed_img = tuple_transformed_imgs[j]
transformed_img = tuple_transformed_batches[j][i, ...]
assert_equal(true_transformed_img, transformed_img)
# scriptable function test
s_tuple_transformed_batches = script_five_crop(batch_tensors, [10, 11])
for transformed_batch, s_transformed_batch in zip(tuple_transformed_batches, s_tuple_transformed_batches):
assert_equal(transformed_batch, s_transformed_batch)
def test_ten_crop(self):
script_ten_crop = torch.jit.script(F.ten_crop)
img_tensor, pil_img = _create_data(32, 34, device=self.device)
cropped_pil_images = F.ten_crop(pil_img, [10, 11])
cropped_tensors = F.ten_crop(img_tensor, [10, 11])
for i in range(10):
_assert_equal_tensor_to_pil(cropped_tensors[i], cropped_pil_images[i])
cropped_tensors = script_ten_crop(img_tensor, [10, 11])
for i in range(10):
_assert_equal_tensor_to_pil(cropped_tensors[i], cropped_pil_images[i])
batch_tensors = _create_data_batch(16, 18, num_samples=4, device=self.device)
tuple_transformed_batches = F.ten_crop(batch_tensors, [10, 11])
for i in range(len(batch_tensors)):
img_tensor = batch_tensors[i, ...]
tuple_transformed_imgs = F.ten_crop(img_tensor, [10, 11])
self.assertEqual(len(tuple_transformed_imgs), len(tuple_transformed_batches))
for j in range(len(tuple_transformed_imgs)):
true_transformed_img = tuple_transformed_imgs[j]
transformed_img = tuple_transformed_batches[j][i, ...]
assert_equal(true_transformed_img, transformed_img)
# scriptable function test # assert changed type warning
s_tuple_transformed_batches = script_ten_crop(batch_tensors, [10, 11]) with self.assertWarnsRegex(UserWarning, r"Argument interpolation should be of type InterpolationMode"):
for transformed_batch, s_transformed_batch in zip(tuple_transformed_batches, s_tuple_transformed_batches): res1 = F.rotate(tensor, 45, interpolation=2)
assert_equal(transformed_batch, s_transformed_batch) res2 = F.rotate(tensor, 45, interpolation=BILINEAR)
assert_equal(res1, res2)
@unittest.skipIf(not torch.cuda.is_available(), reason="Skip if no CUDA device") @unittest.skipIf(not torch.cuda.is_available(), reason="Skip if no CUDA device")
...@@ -1174,5 +1103,172 @@ def test_gaussian_blur(device, image_size, dt, ksize, sigma, fn): ...@@ -1174,5 +1103,172 @@ def test_gaussian_blur(device, image_size, dt, ksize, sigma, fn):
) )
@pytest.mark.parametrize('device', cpu_and_gpu())
def test_hsv2rgb(device):
scripted_fn = torch.jit.script(F_t._hsv2rgb)
shape = (3, 100, 150)
for _ in range(10):
hsv_img = torch.rand(*shape, dtype=torch.float, device=device)
rgb_img = F_t._hsv2rgb(hsv_img)
ft_img = rgb_img.permute(1, 2, 0).flatten(0, 1)
h, s, v, = hsv_img.unbind(0)
h = h.flatten().cpu().numpy()
s = s.flatten().cpu().numpy()
v = v.flatten().cpu().numpy()
rgb = []
for h1, s1, v1 in zip(h, s, v):
rgb.append(colorsys.hsv_to_rgb(h1, s1, v1))
colorsys_img = torch.tensor(rgb, dtype=torch.float32, device=device)
torch.testing.assert_close(ft_img, colorsys_img, rtol=0.0, atol=1e-5)
s_rgb_img = scripted_fn(hsv_img)
torch.testing.assert_close(rgb_img, s_rgb_img)
batch_tensors = _create_data_batch(120, 100, num_samples=4, device=device).float()
_test_fn_on_batch(batch_tensors, F_t._hsv2rgb)
@pytest.mark.parametrize('device', cpu_and_gpu())
def test_rgb2hsv(device):
scripted_fn = torch.jit.script(F_t._rgb2hsv)
shape = (3, 150, 100)
for _ in range(10):
rgb_img = torch.rand(*shape, dtype=torch.float, device=device)
hsv_img = F_t._rgb2hsv(rgb_img)
ft_hsv_img = hsv_img.permute(1, 2, 0).flatten(0, 1)
r, g, b, = rgb_img.unbind(dim=-3)
r = r.flatten().cpu().numpy()
g = g.flatten().cpu().numpy()
b = b.flatten().cpu().numpy()
hsv = []
for r1, g1, b1 in zip(r, g, b):
hsv.append(colorsys.rgb_to_hsv(r1, g1, b1))
colorsys_img = torch.tensor(hsv, dtype=torch.float32, device=device)
ft_hsv_img_h, ft_hsv_img_sv = torch.split(ft_hsv_img, [1, 2], dim=1)
colorsys_img_h, colorsys_img_sv = torch.split(colorsys_img, [1, 2], dim=1)
max_diff_h = ((colorsys_img_h * 2 * math.pi).sin() - (ft_hsv_img_h * 2 * math.pi).sin()).abs().max()
max_diff_sv = (colorsys_img_sv - ft_hsv_img_sv).abs().max()
max_diff = max(max_diff_h, max_diff_sv)
assert max_diff < 1e-5
s_hsv_img = scripted_fn(rgb_img)
torch.testing.assert_close(hsv_img, s_hsv_img, rtol=1e-5, atol=1e-7)
batch_tensors = _create_data_batch(120, 100, num_samples=4, device=device).float()
_test_fn_on_batch(batch_tensors, F_t._rgb2hsv)
@pytest.mark.parametrize('device', cpu_and_gpu())
@pytest.mark.parametrize('num_output_channels', (3, 1))
def test_rgb_to_grayscale(device, num_output_channels):
script_rgb_to_grayscale = torch.jit.script(F.rgb_to_grayscale)
img_tensor, pil_img = _create_data(32, 34, device=device)
gray_pil_image = F.rgb_to_grayscale(pil_img, num_output_channels=num_output_channels)
gray_tensor = F.rgb_to_grayscale(img_tensor, num_output_channels=num_output_channels)
_assert_approx_equal_tensor_to_pil(gray_tensor.float(), gray_pil_image, tol=1.0 + 1e-10, agg_method="max")
s_gray_tensor = script_rgb_to_grayscale(img_tensor, num_output_channels=num_output_channels)
assert_equal(s_gray_tensor, gray_tensor)
batch_tensors = _create_data_batch(16, 18, num_samples=4, device=device)
_test_fn_on_batch(batch_tensors, F.rgb_to_grayscale, num_output_channels=num_output_channels)
@pytest.mark.parametrize('device', cpu_and_gpu())
def test_center_crop(device):
script_center_crop = torch.jit.script(F.center_crop)
img_tensor, pil_img = _create_data(32, 34, device=device)
cropped_pil_image = F.center_crop(pil_img, [10, 11])
cropped_tensor = F.center_crop(img_tensor, [10, 11])
_assert_equal_tensor_to_pil(cropped_tensor, cropped_pil_image)
cropped_tensor = script_center_crop(img_tensor, [10, 11])
_assert_equal_tensor_to_pil(cropped_tensor, cropped_pil_image)
batch_tensors = _create_data_batch(16, 18, num_samples=4, device=device)
_test_fn_on_batch(batch_tensors, F.center_crop, output_size=[10, 11])
@pytest.mark.parametrize('device', cpu_and_gpu())
def test_five_crop(device):
script_five_crop = torch.jit.script(F.five_crop)
img_tensor, pil_img = _create_data(32, 34, device=device)
cropped_pil_images = F.five_crop(pil_img, [10, 11])
cropped_tensors = F.five_crop(img_tensor, [10, 11])
for i in range(5):
_assert_equal_tensor_to_pil(cropped_tensors[i], cropped_pil_images[i])
cropped_tensors = script_five_crop(img_tensor, [10, 11])
for i in range(5):
_assert_equal_tensor_to_pil(cropped_tensors[i], cropped_pil_images[i])
batch_tensors = _create_data_batch(16, 18, num_samples=4, device=device)
tuple_transformed_batches = F.five_crop(batch_tensors, [10, 11])
for i in range(len(batch_tensors)):
img_tensor = batch_tensors[i, ...]
tuple_transformed_imgs = F.five_crop(img_tensor, [10, 11])
assert len(tuple_transformed_imgs) == len(tuple_transformed_batches)
for j in range(len(tuple_transformed_imgs)):
true_transformed_img = tuple_transformed_imgs[j]
transformed_img = tuple_transformed_batches[j][i, ...]
assert_equal(true_transformed_img, transformed_img)
# scriptable function test
s_tuple_transformed_batches = script_five_crop(batch_tensors, [10, 11])
for transformed_batch, s_transformed_batch in zip(tuple_transformed_batches, s_tuple_transformed_batches):
assert_equal(transformed_batch, s_transformed_batch)
@pytest.mark.parametrize('device', cpu_and_gpu())
def test_ten_crop(device):
script_ten_crop = torch.jit.script(F.ten_crop)
img_tensor, pil_img = _create_data(32, 34, device=device)
cropped_pil_images = F.ten_crop(pil_img, [10, 11])
cropped_tensors = F.ten_crop(img_tensor, [10, 11])
for i in range(10):
_assert_equal_tensor_to_pil(cropped_tensors[i], cropped_pil_images[i])
cropped_tensors = script_ten_crop(img_tensor, [10, 11])
for i in range(10):
_assert_equal_tensor_to_pil(cropped_tensors[i], cropped_pil_images[i])
batch_tensors = _create_data_batch(16, 18, num_samples=4, device=device)
tuple_transformed_batches = F.ten_crop(batch_tensors, [10, 11])
for i in range(len(batch_tensors)):
img_tensor = batch_tensors[i, ...]
tuple_transformed_imgs = F.ten_crop(img_tensor, [10, 11])
assert len(tuple_transformed_imgs) == len(tuple_transformed_batches)
for j in range(len(tuple_transformed_imgs)):
true_transformed_img = tuple_transformed_imgs[j]
transformed_img = tuple_transformed_batches[j][i, ...]
assert_equal(true_transformed_img, transformed_img)
# scriptable function test
s_tuple_transformed_batches = script_ten_crop(batch_tensors, [10, 11])
for transformed_batch, s_transformed_batch in zip(tuple_transformed_batches, s_tuple_transformed_batches):
assert_equal(transformed_batch, s_transformed_batch)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment