Unverified commit b96d381c authored by Nicolas Hug, committed by GitHub

Use torch.testing.assert_close in test_functional_tensor (#3876)


Co-authored-by: Philip Meier <github.pmeier@posteo.de>
parent 963d432c
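
This commit swaps unittest-style boolean checks (`self.assertTrue(a.equal(b))`, `self.assertLess(...)`) for `torch.testing.assert_close` and the repo's `assert_equal` helper. Below is a minimal, self-contained sketch of the pattern being adopted; the example tensors are invented for illustration and are not part of the diff.

import torch

a = torch.tensor([1.0, 2.0, 3.0])
b = torch.tensor([1.0, 2.0, 3.0 + 1e-7])

# Old pattern: a bare boolean assertion, which fails without any diagnostics.
assert a.allclose(b, rtol=1e-5, atol=1e-6)

# New pattern: assert_close raises an AssertionError that reports how many
# elements mismatch and the greatest absolute and relative differences.
torch.testing.assert_close(a, b, rtol=1e-5, atol=1e-6)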
@@ -15,6 +15,7 @@ import torchvision.transforms as T
 from torchvision.transforms import InterpolationMode

 from common_utils import TransformsTester, cpu_and_gpu, needs_cuda
+from _assert_utils import assert_equal

 from typing import Dict, List, Sequence, Tuple
@@ -39,13 +40,13 @@ class Tester(TransformsTester):
         for i in range(len(batch_tensors)):
             img_tensor = batch_tensors[i, ...]
             transformed_img = fn(img_tensor, **fn_kwargs)
-            self.assertTrue(transformed_img.equal(transformed_batch[i, ...]))
+            assert_equal(transformed_img, transformed_batch[i, ...])

         if scripted_fn_atol >= 0:
             scripted_fn = torch.jit.script(fn)
             # scriptable function test
             s_transformed_batch = scripted_fn(batch_tensors, **fn_kwargs)
-            self.assertTrue(transformed_batch.allclose(s_transformed_batch, atol=scripted_fn_atol))
+            torch.testing.assert_close(transformed_batch, s_transformed_batch, rtol=1e-5, atol=scripted_fn_atol)

     def test_assert_image_tensor(self):
         shape = (100,)
@@ -79,7 +80,7 @@ class Tester(TransformsTester):
         # scriptable function test
         vflipped_img_script = script_vflip(img_tensor)
-        self.assertTrue(vflipped_img.equal(vflipped_img_script))
+        assert_equal(vflipped_img, vflipped_img_script)

         batch_tensors = self._create_data_batch(16, 18, num_samples=4, device=self.device)
         self._test_fn_on_batch(batch_tensors, F.vflip)
@@ -94,7 +95,7 @@ class Tester(TransformsTester):
         # scriptable function test
         hflipped_img_script = script_hflip(img_tensor)
-        self.assertTrue(hflipped_img.equal(hflipped_img_script))
+        assert_equal(hflipped_img, hflipped_img_script)

         batch_tensors = self._create_data_batch(16, 18, num_samples=4, device=self.device)
         self._test_fn_on_batch(batch_tensors, F.hflip)
@@ -140,11 +141,10 @@ class Tester(TransformsTester):
             for h1, s1, v1 in zip(h, s, v):
                 rgb.append(colorsys.hsv_to_rgb(h1, s1, v1))
             colorsys_img = torch.tensor(rgb, dtype=torch.float32, device=self.device)
-            max_diff = (ft_img - colorsys_img).abs().max()
-            self.assertLess(max_diff, 1e-5)
+            torch.testing.assert_close(ft_img, colorsys_img, rtol=0.0, atol=1e-5)

             s_rgb_img = scripted_fn(hsv_img)
-            self.assertTrue(rgb_img.allclose(s_rgb_img))
+            torch.testing.assert_close(rgb_img, s_rgb_img)

         batch_tensors = self._create_data_batch(120, 100, num_samples=4, device=self.device).float()
         self._test_fn_on_batch(batch_tensors, F_t._hsv2rgb)
@@ -177,7 +177,7 @@ class Tester(TransformsTester):
             self.assertLess(max_diff, 1e-5)

             s_hsv_img = scripted_fn(rgb_img)
-            self.assertTrue(hsv_img.allclose(s_hsv_img, atol=1e-7))
+            torch.testing.assert_close(hsv_img, s_hsv_img, rtol=1e-5, atol=1e-7)

         batch_tensors = self._create_data_batch(120, 100, num_samples=4, device=self.device).float()
         self._test_fn_on_batch(batch_tensors, F_t._rgb2hsv)
@@ -194,7 +194,7 @@ class Tester(TransformsTester):
             self.approxEqualTensorToPIL(gray_tensor.float(), gray_pil_image, tol=1.0 + 1e-10, agg_method="max")

             s_gray_tensor = script_rgb_to_grayscale(img_tensor, num_output_channels=num_output_channels)
-            self.assertTrue(s_gray_tensor.equal(gray_tensor))
+            assert_equal(s_gray_tensor, gray_tensor)

             batch_tensors = self._create_data_batch(16, 18, num_samples=4, device=self.device)
             self._test_fn_on_batch(batch_tensors, F.rgb_to_grayscale, num_output_channels=num_output_channels)
@@ -240,12 +240,12 @@ class Tester(TransformsTester):
             for j in range(len(tuple_transformed_imgs)):
                 true_transformed_img = tuple_transformed_imgs[j]
                 transformed_img = tuple_transformed_batches[j][i, ...]
-                self.assertTrue(true_transformed_img.equal(transformed_img))
+                assert_equal(true_transformed_img, transformed_img)

         # scriptable function test
         s_tuple_transformed_batches = script_five_crop(batch_tensors, [10, 11])
         for transformed_batch, s_transformed_batch in zip(tuple_transformed_batches, s_tuple_transformed_batches):
-            self.assertTrue(transformed_batch.equal(s_transformed_batch))
+            assert_equal(transformed_batch, s_transformed_batch)

     def test_ten_crop(self):
         script_ten_crop = torch.jit.script(F.ten_crop)
@@ -272,12 +272,12 @@ class Tester(TransformsTester):
             for j in range(len(tuple_transformed_imgs)):
                 true_transformed_img = tuple_transformed_imgs[j]
                 transformed_img = tuple_transformed_batches[j][i, ...]
-                self.assertTrue(true_transformed_img.equal(transformed_img))
+                assert_equal(true_transformed_img, transformed_img)

         # scriptable function test
         s_tuple_transformed_batches = script_ten_crop(batch_tensors, [10, 11])
         for transformed_batch, s_transformed_batch in zip(tuple_transformed_batches, s_tuple_transformed_batches):
-            self.assertTrue(transformed_batch.equal(s_transformed_batch))
+            assert_equal(transformed_batch, s_transformed_batch)

     def test_pad(self):
         script_fn = torch.jit.script(F.pad)
@@ -320,7 +320,7 @@ class Tester(TransformsTester):
             else:
                 script_pad = pad
             pad_tensor_script = script_fn(tensor, script_pad, **kwargs)
-            self.assertTrue(pad_tensor.equal(pad_tensor_script), msg="{}, {}".format(pad, kwargs))
+            assert_equal(pad_tensor, pad_tensor_script, msg="{}, {}".format(pad, kwargs))

             self._test_fn_on_batch(batch_tensors, F.pad, padding=script_pad, **kwargs)
@@ -348,9 +348,10 @@ class Tester(TransformsTester):
                         resized_tensor = F.resize(tensor, size=size, interpolation=interpolation, max_size=max_size)
                         resized_pil_img = F.resize(pil_img, size=size, interpolation=interpolation, max_size=max_size)

-                        self.assertEqual(
-                            resized_tensor.size()[1:], resized_pil_img.size[::-1],
-                            msg="{}, {}".format(size, interpolation)
+                        assert_equal(
+                            resized_tensor.size()[1:],
+                            resized_pil_img.size[::-1],
+                            msg="{}, {}".format(size, interpolation),
                         )

                         if interpolation not in [NEAREST, ]:
@@ -374,7 +375,7 @@ class Tester(TransformsTester):
                     resize_result = script_fn(tensor, size=script_size, interpolation=interpolation,
                                               max_size=max_size)
-                    self.assertTrue(resized_tensor.equal(resize_result), msg="{}, {}".format(size, interpolation))
+                    assert_equal(resized_tensor, resize_result, msg="{}, {}".format(size, interpolation))

                     self._test_fn_on_batch(
                         batch_tensors, F.resize, size=script_size, interpolation=interpolation, max_size=max_size
@@ -384,7 +385,7 @@ class Tester(TransformsTester):
         with self.assertWarnsRegex(UserWarning, r"Argument interpolation should be of type InterpolationMode"):
             res1 = F.resize(tensor, size=32, interpolation=2)
             res2 = F.resize(tensor, size=32, interpolation=BILINEAR)
-            self.assertTrue(res1.equal(res2))
+            assert_equal(res1, res2)

         for img in (tensor, pil_img):
             exp_msg = "max_size should only be passed if size specifies the length of the smaller edge"
@@ -400,15 +401,17 @@ class Tester(TransformsTester):
         for mode in [NEAREST, BILINEAR, BICUBIC]:
             out_tensor = F.resized_crop(tensor, top=0, left=0, height=26, width=36, size=[26, 36], interpolation=mode)
-            self.assertTrue(tensor.equal(out_tensor), msg="{} vs {}".format(out_tensor[0, :5, :5], tensor[0, :5, :5]))
+            assert_equal(tensor, out_tensor, msg="{} vs {}".format(out_tensor[0, :5, :5], tensor[0, :5, :5]))

         # 2) resize by half and crop a TL corner
         tensor, _ = self._create_data(26, 36, device=self.device)
         out_tensor = F.resized_crop(tensor, top=0, left=0, height=20, width=30, size=[10, 15], interpolation=NEAREST)
         expected_out_tensor = tensor[:, :20:2, :30:2]
-        self.assertTrue(
-            expected_out_tensor.equal(out_tensor),
-            msg="{} vs {}".format(expected_out_tensor[0, :10, :10], out_tensor[0, :10, :10])
+        assert_equal(
+            expected_out_tensor,
+            out_tensor,
+            check_stride=False,
+            msg="{} vs {}".format(expected_out_tensor[0, :10, :10], out_tensor[0, :10, :10]),
         )

         batch_tensors = self._create_data_batch(26, 36, num_samples=4, device=self.device)
@@ -420,15 +423,11 @@ class Tester(TransformsTester):
         # 1) identity map
         out_tensor = F.affine(tensor, angle=0, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=NEAREST)

-        self.assertTrue(
-            tensor.equal(out_tensor), msg="{} vs {}".format(out_tensor[0, :5, :5], tensor[0, :5, :5])
-        )
+        assert_equal(tensor, out_tensor, msg="{} vs {}".format(out_tensor[0, :5, :5], tensor[0, :5, :5]))

         out_tensor = scripted_affine(
             tensor, angle=0, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=NEAREST
         )
-        self.assertTrue(
-            tensor.equal(out_tensor), msg="{} vs {}".format(out_tensor[0, :5, :5], tensor[0, :5, :5])
-        )
+        assert_equal(tensor, out_tensor, msg="{} vs {}".format(out_tensor[0, :5, :5], tensor[0, :5, :5]))

     def _test_affine_square_rotations(self, tensor, pil_img, scripted_affine):
         # 2) Test rotation
@@ -452,9 +451,11 @@ class Tester(TransformsTester):
                 tensor, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=NEAREST
             )
             if true_tensor is not None:
-                self.assertTrue(
-                    true_tensor.equal(out_tensor),
-                    msg="{}\n{} vs \n{}".format(a, out_tensor[0, :5, :5], true_tensor[0, :5, :5])
+                assert_equal(
+                    true_tensor,
+                    out_tensor,
+                    msg="{}\n{} vs \n{}".format(a, out_tensor[0, :5, :5], true_tensor[0, :5, :5]),
+                    check_stride=False,
                 )

             if out_tensor.dtype != torch.uint8:
@@ -593,18 +594,19 @@ class Tester(TransformsTester):
         with self.assertWarnsRegex(UserWarning, r"Argument resample is deprecated and will be removed"):
             res1 = F.affine(tensor, 45, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=2)
             res2 = F.affine(tensor, 45, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=BILINEAR)
-            self.assertTrue(res1.equal(res2))
+            assert_equal(res1, res2)

         # assert changed type warning
         with self.assertWarnsRegex(UserWarning, r"Argument interpolation should be of type InterpolationMode"):
             res1 = F.affine(tensor, 45, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=2)
             res2 = F.affine(tensor, 45, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=BILINEAR)
-            self.assertTrue(res1.equal(res2))
+            assert_equal(res1, res2)

         with self.assertWarnsRegex(UserWarning, r"Argument fillcolor is deprecated and will be removed"):
             res1 = F.affine(pil_img, 45, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], fillcolor=10)
             res2 = F.affine(pil_img, 45, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], fill=10)
-            self.assertEqual(res1, res2)
+            # we convert the PIL images to numpy as assert_equal doesn't work on PIL images.
+            assert_equal(np.asarray(res1), np.asarray(res2))

     def _test_rotate_all_options(self, tensor, pil_img, scripted_rotate, centers):
         img_size = pil_img.size
@@ -682,13 +684,13 @@ class Tester(TransformsTester):
         with self.assertWarnsRegex(UserWarning, r"Argument resample is deprecated and will be removed"):
             res1 = F.rotate(tensor, 45, resample=2)
             res2 = F.rotate(tensor, 45, interpolation=BILINEAR)
-            self.assertTrue(res1.equal(res2))
+            assert_equal(res1, res2)

         # assert changed type warning
         with self.assertWarnsRegex(UserWarning, r"Argument interpolation should be of type InterpolationMode"):
             res1 = F.rotate(tensor, 45, interpolation=2)
             res2 = F.rotate(tensor, 45, interpolation=BILINEAR)
-            self.assertTrue(res1.equal(res2))
+            assert_equal(res1, res2)

     def test_gaussian_blur(self):
         small_image_tensor = torch.from_numpy(
@@ -747,10 +749,8 @@ class Tester(TransformsTester):
                 for fn in [F.gaussian_blur, scripted_transform]:
                     out = fn(tensor, kernel_size=ksize, sigma=sigma)
-                    self.assertEqual(true_out.shape, out.shape, msg="{}, {}".format(ksize, sigma))
-                    self.assertLessEqual(
-                        torch.max(true_out.float() - out.float()),
-                        1.0,
+                    torch.testing.assert_close(
+                        out, true_out, rtol=0.0, atol=1.0, check_stride=False,
                         msg="{}, {}".format(ksize, sigma)
                     )
@@ -771,7 +771,7 @@ class CUDATester(Tester):
             img_chan = torch.randint(0, 256, size=size).to('cpu')
             scaled_cpu = F_t._scale_channel(img_chan)
             scaled_cuda = F_t._scale_channel(img_chan.to('cuda'))
-            self.assertTrue(scaled_cpu.equal(scaled_cuda.to('cpu')))
+            assert_equal(scaled_cpu, scaled_cuda.to('cpu'))


 def _get_data_dims_and_points_for_perspective():
...
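
The `assert_equal` helper imported from `_assert_utils` is not shown in this diff. Judging from how it is called above (two tensors plus optional `msg` and `check_stride` keyword arguments), it is likely a thin zero-tolerance wrapper around `torch.testing.assert_close`, roughly as sketched below; the exact implementation may differ.

# Hypothetical reconstruction of the helper in test/_assert_utils.py.
import functools

import torch

# rtol=0 and atol=0 turn assert_close into an exact-equality check (like
# Tensor.equal) while keeping its detailed mismatch report; extra keyword
# arguments such as msg= or check_stride= are forwarded unchanged.
assert_equal = functools.partial(torch.testing.assert_close, rtol=0, atol=0)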