Unverified Commit 97b05a89 authored by Nicolas Hug's avatar Nicolas Hug Committed by GitHub
Browse files

Use torch.testing.assert_close in test_transforms_tensor.py (#3885)


Co-authored-by: default avatarPhilip Meier <github.pmeier@posteo.de>
parent 55150bfb
......@@ -10,6 +10,7 @@ import unittest
from typing import Sequence
from common_utils import TransformsTester, get_tmp_dir, int_dtypes, float_dtypes
from _assert_utils import assert_equal
NEAREST, BILINEAR, BICUBIC = InterpolationMode.NEAREST, InterpolationMode.BILINEAR, InterpolationMode.BICUBIC
......@@ -38,7 +39,7 @@ class Tester(TransformsTester):
out1 = transform(tensor)
torch.manual_seed(12)
out2 = s_transform(tensor)
self.assertTrue(out1.equal(out2), msg=msg)
assert_equal(out1, out2, msg=msg)
def _test_transform_vs_scripted_on_batch(self, transform, s_transform, batch_tensors, msg=None):
torch.manual_seed(12)
......@@ -48,11 +49,11 @@ class Tester(TransformsTester):
img_tensor = batch_tensors[i, ...]
torch.manual_seed(12)
transformed_img = transform(img_tensor)
self.assertTrue(transformed_img.equal(transformed_batch[i, ...]), msg=msg)
assert_equal(transformed_img, transformed_batch[i, ...], msg=msg)
torch.manual_seed(12)
s_transformed_batch = s_transform(batch_tensors)
self.assertTrue(transformed_batch.equal(s_transformed_batch), msg=msg)
assert_equal(transformed_batch, s_transformed_batch, msg=msg)
def _test_class_op(self, method, meth_kwargs=None, test_exact_match=True, **match_kwargs):
if meth_kwargs is None:
......@@ -75,7 +76,7 @@ class Tester(TransformsTester):
torch.manual_seed(12)
transformed_tensor_script = scripted_fn(tensor)
self.assertTrue(transformed_tensor.equal(transformed_tensor_script))
assert_equal(transformed_tensor, transformed_tensor_script)
batch_tensors = self._create_data_batch(height=23, width=34, channels=3, num_samples=4, device=self.device)
self._test_transform_vs_scripted_on_batch(f, scripted_fn, batch_tensors)
......@@ -270,8 +271,11 @@ class Tester(TransformsTester):
self.assertEqual(len(transformed_t_list), len(transformed_t_list_script))
self.assertEqual(len(transformed_t_list_script), out_length)
for transformed_tensor, transformed_tensor_script in zip(transformed_t_list, transformed_t_list_script):
self.assertTrue(transformed_tensor.equal(transformed_tensor_script),
msg="{} vs {}".format(transformed_tensor, transformed_tensor_script))
assert_equal(
transformed_tensor,
transformed_tensor_script,
msg="{} vs {}".format(transformed_tensor, transformed_tensor_script),
)
# test for class interface
fn = getattr(T, method)(**meth_kwargs)
......@@ -289,8 +293,11 @@ class Tester(TransformsTester):
torch.manual_seed(12)
transformed_img_list = fn(img_tensor)
for transformed_img, transformed_batch in zip(transformed_img_list, transformed_batch_list):
self.assertTrue(transformed_img.equal(transformed_batch[i, ...]),
msg="{} vs {}".format(transformed_img, transformed_batch[i, ...]))
assert_equal(
transformed_img,
transformed_batch[i, ...],
msg="{} vs {}".format(transformed_img, transformed_batch[i, ...]),
)
with get_tmp_dir() as tmp_dir:
scripted_fn.save(os.path.join(tmp_dir, "t_op_list_{}.pt".format(method)))
......@@ -505,7 +512,7 @@ class Tester(TransformsTester):
transformed_batch = fn(batch_tensors)
torch.manual_seed(12)
s_transformed_batch = scripted_fn(batch_tensors)
self.assertTrue(transformed_batch.equal(s_transformed_batch))
assert_equal(transformed_batch, s_transformed_batch)
with get_tmp_dir() as tmp_dir:
scripted_fn.save(os.path.join(tmp_dir, "t_norm.pt"))
......@@ -525,7 +532,7 @@ class Tester(TransformsTester):
transformed_tensor = transforms(tensor)
torch.manual_seed(12)
transformed_tensor_script = scripted_fn(tensor)
self.assertTrue(transformed_tensor.equal(transformed_tensor_script), msg="{}".format(transforms))
assert_equal(transformed_tensor, transformed_tensor_script, msg="{}".format(transforms))
t = T.Compose([
lambda x: x,
......@@ -551,7 +558,7 @@ class Tester(TransformsTester):
transformed_tensor = transforms(tensor)
torch.manual_seed(12)
transformed_tensor_script = scripted_fn(tensor)
self.assertTrue(transformed_tensor.equal(transformed_tensor_script), msg="{}".format(transforms))
assert_equal(transformed_tensor, transformed_tensor_script, msg="{}".format(transforms))
if torch.device(self.device).type == "cpu":
# Can't check this twice, otherwise
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment