Unverified Commit 31d53367 authored by Nicolas Hug's avatar Nicolas Hug Committed by GitHub
Browse files

Remove TransformsTester (#3946)

parent d1f1a544
......@@ -325,54 +325,6 @@ def freeze_rng_state():
torch.set_rng_state(rng_state)
class TransformsTester(unittest.TestCase):
def _create_data(self, height=3, width=3, channels=3, device="cpu"):
tensor = torch.randint(0, 256, (channels, height, width), dtype=torch.uint8, device=device)
pil_img = Image.fromarray(tensor.permute(1, 2, 0).contiguous().cpu().numpy())
return tensor, pil_img
def _create_data_batch(self, height=3, width=3, channels=3, num_samples=4, device="cpu"):
batch_tensor = torch.randint(
0, 256,
(num_samples, channels, height, width),
dtype=torch.uint8,
device=device
)
return batch_tensor
def compareTensorToPIL(self, tensor, pil_image, msg=None):
np_pil_image = np.array(pil_image)
if np_pil_image.ndim == 2:
np_pil_image = np_pil_image[:, :, None]
pil_tensor = torch.as_tensor(np_pil_image.transpose((2, 0, 1)))
if msg is None:
msg = "tensor:\n{} \ndid not equal PIL tensor:\n{}".format(tensor, pil_tensor)
assert_equal(tensor.cpu(), pil_tensor, check_stride=False, msg=msg)
def approxEqualTensorToPIL(self, tensor, pil_image, tol=1e-5, msg=None, agg_method="mean",
allowed_percentage_diff=None):
np_pil_image = np.array(pil_image)
if np_pil_image.ndim == 2:
np_pil_image = np_pil_image[:, :, None]
pil_tensor = torch.as_tensor(np_pil_image.transpose((2, 0, 1))).to(tensor)
if allowed_percentage_diff is not None:
# Assert that less than a given %age of pixels are different
self.assertTrue(
(tensor != pil_tensor).to(torch.float).mean() <= allowed_percentage_diff
)
# error value can be mean absolute error, max abs error
# Convert to float to avoid underflow when computing absolute difference
tensor = tensor.to(torch.float)
pil_tensor = pil_tensor.to(torch.float)
err = getattr(torch, agg_method)(torch.abs(tensor - pil_tensor)).item()
self.assertTrue(
err < tol,
msg="{}: err={}, tol={}: \n{}\nvs\n{}".format(msg, err, tol, tensor[0, :10, :10], pil_tensor[0, :10, :10])
)
def cycle_over(objs):
for idx, obj in enumerate(objs):
yield obj, objs[:idx] + objs[idx + 1:]
......@@ -457,3 +409,65 @@ def cpu_only(test_func):
return pytest.mark.dont_collect(test_func)
else:
return test_func
def _create_data(height=3, width=3, channels=3, device="cpu"):
# TODO: When all relevant tests are ported to pytest, turn this into a module-level fixture
tensor = torch.randint(0, 256, (channels, height, width), dtype=torch.uint8, device=device)
pil_img = Image.fromarray(tensor.permute(1, 2, 0).contiguous().cpu().numpy())
return tensor, pil_img
def _create_data_batch(height=3, width=3, channels=3, num_samples=4, device="cpu"):
# TODO: When all relevant tests are ported to pytest, turn this into a module-level fixture
batch_tensor = torch.randint(
0, 256,
(num_samples, channels, height, width),
dtype=torch.uint8,
device=device
)
return batch_tensor
def _assert_equal_tensor_to_pil(tensor, pil_image, msg=None):
np_pil_image = np.array(pil_image)
if np_pil_image.ndim == 2:
np_pil_image = np_pil_image[:, :, None]
pil_tensor = torch.as_tensor(np_pil_image.transpose((2, 0, 1)))
if msg is None:
msg = "tensor:\n{} \ndid not equal PIL tensor:\n{}".format(tensor, pil_tensor)
assert_equal(tensor.cpu(), pil_tensor, check_stride=False, msg=msg)
def _assert_approx_equal_tensor_to_pil(tensor, pil_image, tol=1e-5, msg=None, agg_method="mean",
allowed_percentage_diff=None):
# TODO: we could just merge this into _assert_equal_tensor_to_pil
np_pil_image = np.array(pil_image)
if np_pil_image.ndim == 2:
np_pil_image = np_pil_image[:, :, None]
pil_tensor = torch.as_tensor(np_pil_image.transpose((2, 0, 1))).to(tensor)
if allowed_percentage_diff is not None:
# Assert that less than a given %age of pixels are different
assert (tensor != pil_tensor).to(torch.float).mean() <= allowed_percentage_diff
# error value can be mean absolute error, max abs error
# Convert to float to avoid underflow when computing absolute difference
tensor = tensor.to(torch.float)
pil_tensor = pil_tensor.to(torch.float)
err = getattr(torch, agg_method)(torch.abs(tensor - pil_tensor)).item()
assert err < tol
def _test_fn_on_batch(batch_tensors, fn, scripted_fn_atol=1e-8, **fn_kwargs):
transformed_batch = fn(batch_tensors, **fn_kwargs)
for i in range(len(batch_tensors)):
img_tensor = batch_tensors[i, ...]
transformed_img = fn(img_tensor, **fn_kwargs)
assert_equal(transformed_img, transformed_batch[i, ...])
if scripted_fn_atol >= 0:
scripted_fn = torch.jit.script(fn)
# scriptable function test
s_transformed_batch = scripted_fn(batch_tensors, **fn_kwargs)
torch.testing.assert_close(transformed_batch, s_transformed_batch, rtol=1e-5, atol=scripted_fn_atol)
This diff is collapsed.
......@@ -9,14 +9,22 @@ import numpy as np
import unittest
from typing import Sequence
from common_utils import TransformsTester, get_tmp_dir, int_dtypes, float_dtypes
from common_utils import (
get_tmp_dir,
int_dtypes,
float_dtypes,
_create_data,
_create_data_batch,
_assert_equal_tensor_to_pil,
_assert_approx_equal_tensor_to_pil,
)
from _assert_utils import assert_equal
NEAREST, BILINEAR, BICUBIC = InterpolationMode.NEAREST, InterpolationMode.BILINEAR, InterpolationMode.BICUBIC
class Tester(TransformsTester):
class Tester(unittest.TestCase):
def setUp(self):
self.device = "cpu"
......@@ -26,13 +34,13 @@ class Tester(TransformsTester):
fn_kwargs = {}
f = getattr(F, func)
tensor, pil_img = self._create_data(height=10, width=10, device=self.device)
tensor, pil_img = _create_data(height=10, width=10, device=self.device)
transformed_tensor = f(tensor, **fn_kwargs)
transformed_pil_img = f(pil_img, **fn_kwargs)
if test_exact_match:
self.compareTensorToPIL(transformed_tensor, transformed_pil_img, **match_kwargs)
_assert_equal_tensor_to_pil(transformed_tensor, transformed_pil_img, **match_kwargs)
else:
self.approxEqualTensorToPIL(transformed_tensor, transformed_pil_img, **match_kwargs)
_assert_approx_equal_tensor_to_pil(transformed_tensor, transformed_pil_img, **match_kwargs)
def _test_transform_vs_scripted(self, transform, s_transform, tensor, msg=None):
torch.manual_seed(12)
......@@ -63,22 +71,22 @@ class Tester(TransformsTester):
f = getattr(T, method)(**meth_kwargs)
scripted_fn = torch.jit.script(f)
tensor, pil_img = self._create_data(26, 34, device=self.device)
tensor, pil_img = _create_data(26, 34, device=self.device)
# set seed to reproduce the same transformation for tensor and PIL image
torch.manual_seed(12)
transformed_tensor = f(tensor)
torch.manual_seed(12)
transformed_pil_img = f(pil_img)
if test_exact_match:
self.compareTensorToPIL(transformed_tensor, transformed_pil_img, **match_kwargs)
_assert_equal_tensor_to_pil(transformed_tensor, transformed_pil_img, **match_kwargs)
else:
self.approxEqualTensorToPIL(transformed_tensor.float(), transformed_pil_img, **match_kwargs)
_assert_approx_equal_tensor_to_pil(transformed_tensor.float(), transformed_pil_img, **match_kwargs)
torch.manual_seed(12)
transformed_tensor_script = scripted_fn(tensor)
assert_equal(transformed_tensor, transformed_tensor_script)
batch_tensors = self._create_data_batch(height=23, width=34, channels=3, num_samples=4, device=self.device)
batch_tensors = _create_data_batch(height=23, width=34, channels=3, num_samples=4, device=self.device)
self._test_transform_vs_scripted_on_batch(f, scripted_fn, batch_tensors)
with get_tmp_dir() as tmp_dir:
......@@ -259,13 +267,13 @@ class Tester(TransformsTester):
fn = getattr(F, func)
scripted_fn = torch.jit.script(fn)
tensor, pil_img = self._create_data(height=20, width=20, device=self.device)
tensor, pil_img = _create_data(height=20, width=20, device=self.device)
transformed_t_list = fn(tensor, **fn_kwargs)
transformed_p_list = fn(pil_img, **fn_kwargs)
self.assertEqual(len(transformed_t_list), len(transformed_p_list))
self.assertEqual(len(transformed_t_list), out_length)
for transformed_tensor, transformed_pil_img in zip(transformed_t_list, transformed_p_list):
self.compareTensorToPIL(transformed_tensor, transformed_pil_img)
_assert_equal_tensor_to_pil(transformed_tensor, transformed_pil_img)
transformed_t_list_script = scripted_fn(tensor.detach().clone(), **fn_kwargs)
self.assertEqual(len(transformed_t_list), len(transformed_t_list_script))
......@@ -284,7 +292,7 @@ class Tester(TransformsTester):
self.assertEqual(len(output), len(transformed_t_list_script))
# test on batch of tensors
batch_tensors = self._create_data_batch(height=23, width=34, channels=3, num_samples=4, device=self.device)
batch_tensors = _create_data_batch(height=23, width=34, channels=3, num_samples=4, device=self.device)
torch.manual_seed(12)
transformed_batch_list = fn(batch_tensors)
......@@ -350,7 +358,7 @@ class Tester(TransformsTester):
self.assertEqual(y.shape[1], 38)
self.assertEqual(y.shape[2], int(38 * 46 / 32))
tensor, _ = self._create_data(height=34, width=36, device=self.device)
tensor, _ = _create_data(height=34, width=36, device=self.device)
batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=self.device)
for dt in [None, torch.float32, torch.float64]:
......@@ -487,7 +495,7 @@ class Tester(TransformsTester):
def test_normalize(self):
fn = T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
tensor, _ = self._create_data(26, 34, device=self.device)
tensor, _ = _create_data(26, 34, device=self.device)
with self.assertRaisesRegex(TypeError, r"Input tensor should be a float tensor"):
fn(tensor)
......@@ -506,7 +514,7 @@ class Tester(TransformsTester):
def test_linear_transformation(self):
c, h, w = 3, 24, 32
tensor, _ = self._create_data(h, w, channels=c, device=self.device)
tensor, _ = _create_data(h, w, channels=c, device=self.device)
matrix = torch.rand(c * h * w, c * h * w, device=self.device)
mean_vector = torch.rand(c * h * w, device=self.device)
......@@ -529,7 +537,7 @@ class Tester(TransformsTester):
scripted_fn.save(os.path.join(tmp_dir, "t_norm.pt"))
def test_compose(self):
tensor, _ = self._create_data(26, 34, device=self.device)
tensor, _ = _create_data(26, 34, device=self.device)
tensor = tensor.to(dtype=torch.float32) / 255.0
transforms = T.Compose([
......@@ -552,7 +560,7 @@ class Tester(TransformsTester):
torch.jit.script(t)
def test_random_apply(self):
tensor, _ = self._create_data(26, 34, device=self.device)
tensor, _ = _create_data(26, 34, device=self.device)
tensor = tensor.to(dtype=torch.float32) / 255.0
transforms = T.RandomApply([
......@@ -620,7 +628,7 @@ class Tester(TransformsTester):
with self.assertRaises(ValueError, msg="If value is a sequence, it should have either a single value or 3"):
random_erasing(img)
tensor, _ = self._create_data(24, 32, channels=3, device=self.device)
tensor, _ = _create_data(24, 32, channels=3, device=self.device)
batch_tensors = torch.rand(4, 3, 44, 56, device=self.device)
test_configs = [
......@@ -640,7 +648,7 @@ class Tester(TransformsTester):
scripted_fn.save(os.path.join(tmp_dir, "t_random_erasing.pt"))
def test_convert_image_dtype(self):
tensor, _ = self._create_data(26, 34, device=self.device)
tensor, _ = _create_data(26, 34, device=self.device)
batch_tensors = torch.rand(4, 3, 44, 56, device=self.device)
for in_dtype in int_dtypes() + float_dtypes():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment