Unverified Commit 97b05a89 authored by Nicolas Hug's avatar Nicolas Hug Committed by GitHub
Browse files

Use torch.testing.assert_close in test_transforms_tensor.py (#3885)


Co-authored-by: default avatarPhilip Meier <github.pmeier@posteo.de>
parent 55150bfb
...@@ -10,6 +10,7 @@ import unittest ...@@ -10,6 +10,7 @@ import unittest
from typing import Sequence from typing import Sequence
from common_utils import TransformsTester, get_tmp_dir, int_dtypes, float_dtypes from common_utils import TransformsTester, get_tmp_dir, int_dtypes, float_dtypes
from _assert_utils import assert_equal
NEAREST, BILINEAR, BICUBIC = InterpolationMode.NEAREST, InterpolationMode.BILINEAR, InterpolationMode.BICUBIC NEAREST, BILINEAR, BICUBIC = InterpolationMode.NEAREST, InterpolationMode.BILINEAR, InterpolationMode.BICUBIC
...@@ -38,7 +39,7 @@ class Tester(TransformsTester): ...@@ -38,7 +39,7 @@ class Tester(TransformsTester):
out1 = transform(tensor) out1 = transform(tensor)
torch.manual_seed(12) torch.manual_seed(12)
out2 = s_transform(tensor) out2 = s_transform(tensor)
self.assertTrue(out1.equal(out2), msg=msg) assert_equal(out1, out2, msg=msg)
def _test_transform_vs_scripted_on_batch(self, transform, s_transform, batch_tensors, msg=None): def _test_transform_vs_scripted_on_batch(self, transform, s_transform, batch_tensors, msg=None):
torch.manual_seed(12) torch.manual_seed(12)
...@@ -48,11 +49,11 @@ class Tester(TransformsTester): ...@@ -48,11 +49,11 @@ class Tester(TransformsTester):
img_tensor = batch_tensors[i, ...] img_tensor = batch_tensors[i, ...]
torch.manual_seed(12) torch.manual_seed(12)
transformed_img = transform(img_tensor) transformed_img = transform(img_tensor)
self.assertTrue(transformed_img.equal(transformed_batch[i, ...]), msg=msg) assert_equal(transformed_img, transformed_batch[i, ...], msg=msg)
torch.manual_seed(12) torch.manual_seed(12)
s_transformed_batch = s_transform(batch_tensors) s_transformed_batch = s_transform(batch_tensors)
self.assertTrue(transformed_batch.equal(s_transformed_batch), msg=msg) assert_equal(transformed_batch, s_transformed_batch, msg=msg)
def _test_class_op(self, method, meth_kwargs=None, test_exact_match=True, **match_kwargs): def _test_class_op(self, method, meth_kwargs=None, test_exact_match=True, **match_kwargs):
if meth_kwargs is None: if meth_kwargs is None:
...@@ -75,7 +76,7 @@ class Tester(TransformsTester): ...@@ -75,7 +76,7 @@ class Tester(TransformsTester):
torch.manual_seed(12) torch.manual_seed(12)
transformed_tensor_script = scripted_fn(tensor) transformed_tensor_script = scripted_fn(tensor)
self.assertTrue(transformed_tensor.equal(transformed_tensor_script)) assert_equal(transformed_tensor, transformed_tensor_script)
batch_tensors = self._create_data_batch(height=23, width=34, channels=3, num_samples=4, device=self.device) batch_tensors = self._create_data_batch(height=23, width=34, channels=3, num_samples=4, device=self.device)
self._test_transform_vs_scripted_on_batch(f, scripted_fn, batch_tensors) self._test_transform_vs_scripted_on_batch(f, scripted_fn, batch_tensors)
...@@ -270,8 +271,11 @@ class Tester(TransformsTester): ...@@ -270,8 +271,11 @@ class Tester(TransformsTester):
self.assertEqual(len(transformed_t_list), len(transformed_t_list_script)) self.assertEqual(len(transformed_t_list), len(transformed_t_list_script))
self.assertEqual(len(transformed_t_list_script), out_length) self.assertEqual(len(transformed_t_list_script), out_length)
for transformed_tensor, transformed_tensor_script in zip(transformed_t_list, transformed_t_list_script): for transformed_tensor, transformed_tensor_script in zip(transformed_t_list, transformed_t_list_script):
self.assertTrue(transformed_tensor.equal(transformed_tensor_script), assert_equal(
msg="{} vs {}".format(transformed_tensor, transformed_tensor_script)) transformed_tensor,
transformed_tensor_script,
msg="{} vs {}".format(transformed_tensor, transformed_tensor_script),
)
# test for class interface # test for class interface
fn = getattr(T, method)(**meth_kwargs) fn = getattr(T, method)(**meth_kwargs)
...@@ -289,8 +293,11 @@ class Tester(TransformsTester): ...@@ -289,8 +293,11 @@ class Tester(TransformsTester):
torch.manual_seed(12) torch.manual_seed(12)
transformed_img_list = fn(img_tensor) transformed_img_list = fn(img_tensor)
for transformed_img, transformed_batch in zip(transformed_img_list, transformed_batch_list): for transformed_img, transformed_batch in zip(transformed_img_list, transformed_batch_list):
self.assertTrue(transformed_img.equal(transformed_batch[i, ...]), assert_equal(
msg="{} vs {}".format(transformed_img, transformed_batch[i, ...])) transformed_img,
transformed_batch[i, ...],
msg="{} vs {}".format(transformed_img, transformed_batch[i, ...]),
)
with get_tmp_dir() as tmp_dir: with get_tmp_dir() as tmp_dir:
scripted_fn.save(os.path.join(tmp_dir, "t_op_list_{}.pt".format(method))) scripted_fn.save(os.path.join(tmp_dir, "t_op_list_{}.pt".format(method)))
...@@ -505,7 +512,7 @@ class Tester(TransformsTester): ...@@ -505,7 +512,7 @@ class Tester(TransformsTester):
transformed_batch = fn(batch_tensors) transformed_batch = fn(batch_tensors)
torch.manual_seed(12) torch.manual_seed(12)
s_transformed_batch = scripted_fn(batch_tensors) s_transformed_batch = scripted_fn(batch_tensors)
self.assertTrue(transformed_batch.equal(s_transformed_batch)) assert_equal(transformed_batch, s_transformed_batch)
with get_tmp_dir() as tmp_dir: with get_tmp_dir() as tmp_dir:
scripted_fn.save(os.path.join(tmp_dir, "t_norm.pt")) scripted_fn.save(os.path.join(tmp_dir, "t_norm.pt"))
...@@ -525,7 +532,7 @@ class Tester(TransformsTester): ...@@ -525,7 +532,7 @@ class Tester(TransformsTester):
transformed_tensor = transforms(tensor) transformed_tensor = transforms(tensor)
torch.manual_seed(12) torch.manual_seed(12)
transformed_tensor_script = scripted_fn(tensor) transformed_tensor_script = scripted_fn(tensor)
self.assertTrue(transformed_tensor.equal(transformed_tensor_script), msg="{}".format(transforms)) assert_equal(transformed_tensor, transformed_tensor_script, msg="{}".format(transforms))
t = T.Compose([ t = T.Compose([
lambda x: x, lambda x: x,
...@@ -551,7 +558,7 @@ class Tester(TransformsTester): ...@@ -551,7 +558,7 @@ class Tester(TransformsTester):
transformed_tensor = transforms(tensor) transformed_tensor = transforms(tensor)
torch.manual_seed(12) torch.manual_seed(12)
transformed_tensor_script = scripted_fn(tensor) transformed_tensor_script = scripted_fn(tensor)
self.assertTrue(transformed_tensor.equal(transformed_tensor_script), msg="{}".format(transforms)) assert_equal(transformed_tensor, transformed_tensor_script, msg="{}".format(transforms))
if torch.device(self.device).type == "cpu": if torch.device(self.device).type == "cpu":
# Can't check this twice, otherwise # Can't check this twice, otherwise
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment