Unverified Commit b96d381c authored by Nicolas Hug's avatar Nicolas Hug Committed by GitHub
Browse files

Use torch.testing.assert_close in test_functional_tensor (#3876)


Co-authored-by: default avatarPhilip Meier <github.pmeier@posteo.de>
parent 963d432c
......@@ -15,6 +15,7 @@ import torchvision.transforms as T
from torchvision.transforms import InterpolationMode
from common_utils import TransformsTester, cpu_and_gpu, needs_cuda
from _assert_utils import assert_equal
from typing import Dict, List, Sequence, Tuple
......@@ -39,13 +40,13 @@ class Tester(TransformsTester):
for i in range(len(batch_tensors)):
img_tensor = batch_tensors[i, ...]
transformed_img = fn(img_tensor, **fn_kwargs)
self.assertTrue(transformed_img.equal(transformed_batch[i, ...]))
assert_equal(transformed_img, transformed_batch[i, ...])
if scripted_fn_atol >= 0:
scripted_fn = torch.jit.script(fn)
# scriptable function test
s_transformed_batch = scripted_fn(batch_tensors, **fn_kwargs)
self.assertTrue(transformed_batch.allclose(s_transformed_batch, atol=scripted_fn_atol))
torch.testing.assert_close(transformed_batch, s_transformed_batch, rtol=1e-5, atol=scripted_fn_atol)
def test_assert_image_tensor(self):
shape = (100,)
......@@ -79,7 +80,7 @@ class Tester(TransformsTester):
# scriptable function test
vflipped_img_script = script_vflip(img_tensor)
self.assertTrue(vflipped_img.equal(vflipped_img_script))
assert_equal(vflipped_img, vflipped_img_script)
batch_tensors = self._create_data_batch(16, 18, num_samples=4, device=self.device)
self._test_fn_on_batch(batch_tensors, F.vflip)
......@@ -94,7 +95,7 @@ class Tester(TransformsTester):
# scriptable function test
hflipped_img_script = script_hflip(img_tensor)
self.assertTrue(hflipped_img.equal(hflipped_img_script))
assert_equal(hflipped_img, hflipped_img_script)
batch_tensors = self._create_data_batch(16, 18, num_samples=4, device=self.device)
self._test_fn_on_batch(batch_tensors, F.hflip)
......@@ -140,11 +141,10 @@ class Tester(TransformsTester):
for h1, s1, v1 in zip(h, s, v):
rgb.append(colorsys.hsv_to_rgb(h1, s1, v1))
colorsys_img = torch.tensor(rgb, dtype=torch.float32, device=self.device)
max_diff = (ft_img - colorsys_img).abs().max()
self.assertLess(max_diff, 1e-5)
torch.testing.assert_close(ft_img, colorsys_img, rtol=0.0, atol=1e-5)
s_rgb_img = scripted_fn(hsv_img)
self.assertTrue(rgb_img.allclose(s_rgb_img))
torch.testing.assert_close(rgb_img, s_rgb_img)
batch_tensors = self._create_data_batch(120, 100, num_samples=4, device=self.device).float()
self._test_fn_on_batch(batch_tensors, F_t._hsv2rgb)
......@@ -177,7 +177,7 @@ class Tester(TransformsTester):
self.assertLess(max_diff, 1e-5)
s_hsv_img = scripted_fn(rgb_img)
self.assertTrue(hsv_img.allclose(s_hsv_img, atol=1e-7))
torch.testing.assert_close(hsv_img, s_hsv_img, rtol=1e-5, atol=1e-7)
batch_tensors = self._create_data_batch(120, 100, num_samples=4, device=self.device).float()
self._test_fn_on_batch(batch_tensors, F_t._rgb2hsv)
......@@ -194,7 +194,7 @@ class Tester(TransformsTester):
self.approxEqualTensorToPIL(gray_tensor.float(), gray_pil_image, tol=1.0 + 1e-10, agg_method="max")
s_gray_tensor = script_rgb_to_grayscale(img_tensor, num_output_channels=num_output_channels)
self.assertTrue(s_gray_tensor.equal(gray_tensor))
assert_equal(s_gray_tensor, gray_tensor)
batch_tensors = self._create_data_batch(16, 18, num_samples=4, device=self.device)
self._test_fn_on_batch(batch_tensors, F.rgb_to_grayscale, num_output_channels=num_output_channels)
......@@ -240,12 +240,12 @@ class Tester(TransformsTester):
for j in range(len(tuple_transformed_imgs)):
true_transformed_img = tuple_transformed_imgs[j]
transformed_img = tuple_transformed_batches[j][i, ...]
self.assertTrue(true_transformed_img.equal(transformed_img))
assert_equal(true_transformed_img, transformed_img)
# scriptable function test
s_tuple_transformed_batches = script_five_crop(batch_tensors, [10, 11])
for transformed_batch, s_transformed_batch in zip(tuple_transformed_batches, s_tuple_transformed_batches):
self.assertTrue(transformed_batch.equal(s_transformed_batch))
assert_equal(transformed_batch, s_transformed_batch)
def test_ten_crop(self):
script_ten_crop = torch.jit.script(F.ten_crop)
......@@ -272,12 +272,12 @@ class Tester(TransformsTester):
for j in range(len(tuple_transformed_imgs)):
true_transformed_img = tuple_transformed_imgs[j]
transformed_img = tuple_transformed_batches[j][i, ...]
self.assertTrue(true_transformed_img.equal(transformed_img))
assert_equal(true_transformed_img, transformed_img)
# scriptable function test
s_tuple_transformed_batches = script_ten_crop(batch_tensors, [10, 11])
for transformed_batch, s_transformed_batch in zip(tuple_transformed_batches, s_tuple_transformed_batches):
self.assertTrue(transformed_batch.equal(s_transformed_batch))
assert_equal(transformed_batch, s_transformed_batch)
def test_pad(self):
script_fn = torch.jit.script(F.pad)
......@@ -320,7 +320,7 @@ class Tester(TransformsTester):
else:
script_pad = pad
pad_tensor_script = script_fn(tensor, script_pad, **kwargs)
self.assertTrue(pad_tensor.equal(pad_tensor_script), msg="{}, {}".format(pad, kwargs))
assert_equal(pad_tensor, pad_tensor_script, msg="{}, {}".format(pad, kwargs))
self._test_fn_on_batch(batch_tensors, F.pad, padding=script_pad, **kwargs)
......@@ -348,9 +348,10 @@ class Tester(TransformsTester):
resized_tensor = F.resize(tensor, size=size, interpolation=interpolation, max_size=max_size)
resized_pil_img = F.resize(pil_img, size=size, interpolation=interpolation, max_size=max_size)
self.assertEqual(
resized_tensor.size()[1:], resized_pil_img.size[::-1],
msg="{}, {}".format(size, interpolation)
assert_equal(
resized_tensor.size()[1:],
resized_pil_img.size[::-1],
msg="{}, {}".format(size, interpolation),
)
if interpolation not in [NEAREST, ]:
......@@ -374,7 +375,7 @@ class Tester(TransformsTester):
resize_result = script_fn(tensor, size=script_size, interpolation=interpolation,
max_size=max_size)
self.assertTrue(resized_tensor.equal(resize_result), msg="{}, {}".format(size, interpolation))
assert_equal(resized_tensor, resize_result, msg="{}, {}".format(size, interpolation))
self._test_fn_on_batch(
batch_tensors, F.resize, size=script_size, interpolation=interpolation, max_size=max_size
......@@ -384,7 +385,7 @@ class Tester(TransformsTester):
with self.assertWarnsRegex(UserWarning, r"Argument interpolation should be of type InterpolationMode"):
res1 = F.resize(tensor, size=32, interpolation=2)
res2 = F.resize(tensor, size=32, interpolation=BILINEAR)
self.assertTrue(res1.equal(res2))
assert_equal(res1, res2)
for img in (tensor, pil_img):
exp_msg = "max_size should only be passed if size specifies the length of the smaller edge"
......@@ -400,15 +401,17 @@ class Tester(TransformsTester):
for mode in [NEAREST, BILINEAR, BICUBIC]:
out_tensor = F.resized_crop(tensor, top=0, left=0, height=26, width=36, size=[26, 36], interpolation=mode)
self.assertTrue(tensor.equal(out_tensor), msg="{} vs {}".format(out_tensor[0, :5, :5], tensor[0, :5, :5]))
assert_equal(tensor, out_tensor, msg="{} vs {}".format(out_tensor[0, :5, :5], tensor[0, :5, :5]))
# 2) resize by half and crop a TL corner
tensor, _ = self._create_data(26, 36, device=self.device)
out_tensor = F.resized_crop(tensor, top=0, left=0, height=20, width=30, size=[10, 15], interpolation=NEAREST)
expected_out_tensor = tensor[:, :20:2, :30:2]
self.assertTrue(
expected_out_tensor.equal(out_tensor),
msg="{} vs {}".format(expected_out_tensor[0, :10, :10], out_tensor[0, :10, :10])
assert_equal(
expected_out_tensor,
out_tensor,
check_stride=False,
msg="{} vs {}".format(expected_out_tensor[0, :10, :10], out_tensor[0, :10, :10]),
)
batch_tensors = self._create_data_batch(26, 36, num_samples=4, device=self.device)
......@@ -420,15 +423,11 @@ class Tester(TransformsTester):
# 1) identity map
out_tensor = F.affine(tensor, angle=0, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=NEAREST)
self.assertTrue(
tensor.equal(out_tensor), msg="{} vs {}".format(out_tensor[0, :5, :5], tensor[0, :5, :5])
)
assert_equal(tensor, out_tensor, msg="{} vs {}".format(out_tensor[0, :5, :5], tensor[0, :5, :5]))
out_tensor = scripted_affine(
tensor, angle=0, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=NEAREST
)
self.assertTrue(
tensor.equal(out_tensor), msg="{} vs {}".format(out_tensor[0, :5, :5], tensor[0, :5, :5])
)
assert_equal(tensor, out_tensor, msg="{} vs {}".format(out_tensor[0, :5, :5], tensor[0, :5, :5]))
def _test_affine_square_rotations(self, tensor, pil_img, scripted_affine):
# 2) Test rotation
......@@ -452,9 +451,11 @@ class Tester(TransformsTester):
tensor, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=NEAREST
)
if true_tensor is not None:
self.assertTrue(
true_tensor.equal(out_tensor),
msg="{}\n{} vs \n{}".format(a, out_tensor[0, :5, :5], true_tensor[0, :5, :5])
assert_equal(
true_tensor,
out_tensor,
msg="{}\n{} vs \n{}".format(a, out_tensor[0, :5, :5], true_tensor[0, :5, :5]),
check_stride=False,
)
if out_tensor.dtype != torch.uint8:
......@@ -593,18 +594,19 @@ class Tester(TransformsTester):
with self.assertWarnsRegex(UserWarning, r"Argument resample is deprecated and will be removed"):
res1 = F.affine(tensor, 45, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=2)
res2 = F.affine(tensor, 45, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=BILINEAR)
self.assertTrue(res1.equal(res2))
assert_equal(res1, res2)
# assert changed type warning
with self.assertWarnsRegex(UserWarning, r"Argument interpolation should be of type InterpolationMode"):
res1 = F.affine(tensor, 45, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=2)
res2 = F.affine(tensor, 45, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=BILINEAR)
self.assertTrue(res1.equal(res2))
assert_equal(res1, res2)
with self.assertWarnsRegex(UserWarning, r"Argument fillcolor is deprecated and will be removed"):
res1 = F.affine(pil_img, 45, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], fillcolor=10)
res2 = F.affine(pil_img, 45, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], fill=10)
self.assertEqual(res1, res2)
# we convert the PIL images to numpy as assert_equal doesn't work on PIL images.
assert_equal(np.asarray(res1), np.asarray(res2))
def _test_rotate_all_options(self, tensor, pil_img, scripted_rotate, centers):
img_size = pil_img.size
......@@ -682,13 +684,13 @@ class Tester(TransformsTester):
with self.assertWarnsRegex(UserWarning, r"Argument resample is deprecated and will be removed"):
res1 = F.rotate(tensor, 45, resample=2)
res2 = F.rotate(tensor, 45, interpolation=BILINEAR)
self.assertTrue(res1.equal(res2))
assert_equal(res1, res2)
# assert changed type warning
with self.assertWarnsRegex(UserWarning, r"Argument interpolation should be of type InterpolationMode"):
res1 = F.rotate(tensor, 45, interpolation=2)
res2 = F.rotate(tensor, 45, interpolation=BILINEAR)
self.assertTrue(res1.equal(res2))
assert_equal(res1, res2)
def test_gaussian_blur(self):
small_image_tensor = torch.from_numpy(
......@@ -747,10 +749,8 @@ class Tester(TransformsTester):
for fn in [F.gaussian_blur, scripted_transform]:
out = fn(tensor, kernel_size=ksize, sigma=sigma)
self.assertEqual(true_out.shape, out.shape, msg="{}, {}".format(ksize, sigma))
self.assertLessEqual(
torch.max(true_out.float() - out.float()),
1.0,
torch.testing.assert_close(
out, true_out, rtol=0.0, atol=1.0, check_stride=False,
msg="{}, {}".format(ksize, sigma)
)
......@@ -771,7 +771,7 @@ class CUDATester(Tester):
img_chan = torch.randint(0, 256, size=size).to('cpu')
scaled_cpu = F_t._scale_channel(img_chan)
scaled_cuda = F_t._scale_channel(img_chan.to('cuda'))
self.assertTrue(scaled_cpu.equal(scaled_cuda.to('cpu')))
assert_equal(scaled_cpu, scaled_cuda.to('cpu'))
def _get_data_dims_and_points_for_perspective():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment