"vscode:/vscode.git/clone" did not exist on "516d988d3cd474b3bed2d4df72bd6fb2e4a78ebb"
Unverified Commit 31d53367 authored by Nicolas Hug's avatar Nicolas Hug Committed by GitHub
Browse files

Remove TransformsTester (#3946)

parent d1f1a544
......@@ -325,54 +325,6 @@ def freeze_rng_state():
torch.set_rng_state(rng_state)
class TransformsTester(unittest.TestCase):
def _create_data(self, height=3, width=3, channels=3, device="cpu"):
tensor = torch.randint(0, 256, (channels, height, width), dtype=torch.uint8, device=device)
pil_img = Image.fromarray(tensor.permute(1, 2, 0).contiguous().cpu().numpy())
return tensor, pil_img
def _create_data_batch(self, height=3, width=3, channels=3, num_samples=4, device="cpu"):
batch_tensor = torch.randint(
0, 256,
(num_samples, channels, height, width),
dtype=torch.uint8,
device=device
)
return batch_tensor
def compareTensorToPIL(self, tensor, pil_image, msg=None):
np_pil_image = np.array(pil_image)
if np_pil_image.ndim == 2:
np_pil_image = np_pil_image[:, :, None]
pil_tensor = torch.as_tensor(np_pil_image.transpose((2, 0, 1)))
if msg is None:
msg = "tensor:\n{} \ndid not equal PIL tensor:\n{}".format(tensor, pil_tensor)
assert_equal(tensor.cpu(), pil_tensor, check_stride=False, msg=msg)
def approxEqualTensorToPIL(self, tensor, pil_image, tol=1e-5, msg=None, agg_method="mean",
allowed_percentage_diff=None):
np_pil_image = np.array(pil_image)
if np_pil_image.ndim == 2:
np_pil_image = np_pil_image[:, :, None]
pil_tensor = torch.as_tensor(np_pil_image.transpose((2, 0, 1))).to(tensor)
if allowed_percentage_diff is not None:
# Assert that less than a given %age of pixels are different
self.assertTrue(
(tensor != pil_tensor).to(torch.float).mean() <= allowed_percentage_diff
)
# error value can be mean absolute error, max abs error
# Convert to float to avoid underflow when computing absolute difference
tensor = tensor.to(torch.float)
pil_tensor = pil_tensor.to(torch.float)
err = getattr(torch, agg_method)(torch.abs(tensor - pil_tensor)).item()
self.assertTrue(
err < tol,
msg="{}: err={}, tol={}: \n{}\nvs\n{}".format(msg, err, tol, tensor[0, :10, :10], pil_tensor[0, :10, :10])
)
def cycle_over(objs):
for idx, obj in enumerate(objs):
yield obj, objs[:idx] + objs[idx + 1:]
......@@ -457,3 +409,65 @@ def cpu_only(test_func):
return pytest.mark.dont_collect(test_func)
else:
return test_func
def _create_data(height=3, width=3, channels=3, device="cpu"):
# TODO: When all relevant tests are ported to pytest, turn this into a module-level fixture
tensor = torch.randint(0, 256, (channels, height, width), dtype=torch.uint8, device=device)
pil_img = Image.fromarray(tensor.permute(1, 2, 0).contiguous().cpu().numpy())
return tensor, pil_img
def _create_data_batch(height=3, width=3, channels=3, num_samples=4, device="cpu"):
# TODO: When all relevant tests are ported to pytest, turn this into a module-level fixture
batch_tensor = torch.randint(
0, 256,
(num_samples, channels, height, width),
dtype=torch.uint8,
device=device
)
return batch_tensor
def _assert_equal_tensor_to_pil(tensor, pil_image, msg=None):
np_pil_image = np.array(pil_image)
if np_pil_image.ndim == 2:
np_pil_image = np_pil_image[:, :, None]
pil_tensor = torch.as_tensor(np_pil_image.transpose((2, 0, 1)))
if msg is None:
msg = "tensor:\n{} \ndid not equal PIL tensor:\n{}".format(tensor, pil_tensor)
assert_equal(tensor.cpu(), pil_tensor, check_stride=False, msg=msg)
def _assert_approx_equal_tensor_to_pil(tensor, pil_image, tol=1e-5, msg=None, agg_method="mean",
allowed_percentage_diff=None):
# TODO: we could just merge this into _assert_equal_tensor_to_pil
np_pil_image = np.array(pil_image)
if np_pil_image.ndim == 2:
np_pil_image = np_pil_image[:, :, None]
pil_tensor = torch.as_tensor(np_pil_image.transpose((2, 0, 1))).to(tensor)
if allowed_percentage_diff is not None:
# Assert that less than a given %age of pixels are different
assert (tensor != pil_tensor).to(torch.float).mean() <= allowed_percentage_diff
# error value can be mean absolute error, max abs error
# Convert to float to avoid underflow when computing absolute difference
tensor = tensor.to(torch.float)
pil_tensor = pil_tensor.to(torch.float)
err = getattr(torch, agg_method)(torch.abs(tensor - pil_tensor)).item()
assert err < tol
def _test_fn_on_batch(batch_tensors, fn, scripted_fn_atol=1e-8, **fn_kwargs):
transformed_batch = fn(batch_tensors, **fn_kwargs)
for i in range(len(batch_tensors)):
img_tensor = batch_tensors[i, ...]
transformed_img = fn(img_tensor, **fn_kwargs)
assert_equal(transformed_img, transformed_batch[i, ...])
if scripted_fn_atol >= 0:
scripted_fn = torch.jit.script(fn)
# scriptable function test
s_transformed_batch = scripted_fn(batch_tensors, **fn_kwargs)
torch.testing.assert_close(transformed_batch, s_transformed_batch, rtol=1e-5, atol=scripted_fn_atol)
......@@ -14,7 +14,15 @@ import torchvision.transforms.functional as F
import torchvision.transforms as T
from torchvision.transforms import InterpolationMode
from common_utils import TransformsTester, cpu_and_gpu, needs_cuda
from common_utils import (
cpu_and_gpu,
needs_cuda,
_create_data,
_create_data_batch,
_assert_equal_tensor_to_pil,
_assert_approx_equal_tensor_to_pil,
_test_fn_on_batch,
)
from _assert_utils import assert_equal
from typing import Dict, List, Sequence, Tuple
......@@ -23,31 +31,11 @@ from typing import Dict, List, Sequence, Tuple
NEAREST, BILINEAR, BICUBIC = InterpolationMode.NEAREST, InterpolationMode.BILINEAR, InterpolationMode.BICUBIC
@pytest.fixture(scope='module')
def tester():
# instanciation of the Tester class used for equality assertions and other utilities
# TODO: remove this eventually when we don't need the class anymore
return Tester()
class Tester(TransformsTester):
class Tester(unittest.TestCase):
def setUp(self):
self.device = "cpu"
def _test_fn_on_batch(self, batch_tensors, fn, scripted_fn_atol=1e-8, **fn_kwargs):
transformed_batch = fn(batch_tensors, **fn_kwargs)
for i in range(len(batch_tensors)):
img_tensor = batch_tensors[i, ...]
transformed_img = fn(img_tensor, **fn_kwargs)
assert_equal(transformed_img, transformed_batch[i, ...])
if scripted_fn_atol >= 0:
scripted_fn = torch.jit.script(fn)
# scriptable function test
s_transformed_batch = scripted_fn(batch_tensors, **fn_kwargs)
torch.testing.assert_close(transformed_batch, s_transformed_batch, rtol=1e-5, atol=scripted_fn_atol)
def test_assert_image_tensor(self):
shape = (100,)
tensor = torch.rand(*shape, dtype=torch.float, device=self.device)
......@@ -73,37 +61,37 @@ class Tester(TransformsTester):
def test_vflip(self):
script_vflip = torch.jit.script(F.vflip)
img_tensor, pil_img = self._create_data(16, 18, device=self.device)
img_tensor, pil_img = _create_data(16, 18, device=self.device)
vflipped_img = F.vflip(img_tensor)
vflipped_pil_img = F.vflip(pil_img)
self.compareTensorToPIL(vflipped_img, vflipped_pil_img)
_assert_equal_tensor_to_pil(vflipped_img, vflipped_pil_img)
# scriptable function test
vflipped_img_script = script_vflip(img_tensor)
assert_equal(vflipped_img, vflipped_img_script)
batch_tensors = self._create_data_batch(16, 18, num_samples=4, device=self.device)
self._test_fn_on_batch(batch_tensors, F.vflip)
batch_tensors = _create_data_batch(16, 18, num_samples=4, device=self.device)
_test_fn_on_batch(batch_tensors, F.vflip)
def test_hflip(self):
script_hflip = torch.jit.script(F.hflip)
img_tensor, pil_img = self._create_data(16, 18, device=self.device)
img_tensor, pil_img = _create_data(16, 18, device=self.device)
hflipped_img = F.hflip(img_tensor)
hflipped_pil_img = F.hflip(pil_img)
self.compareTensorToPIL(hflipped_img, hflipped_pil_img)
_assert_equal_tensor_to_pil(hflipped_img, hflipped_pil_img)
# scriptable function test
hflipped_img_script = script_hflip(img_tensor)
assert_equal(hflipped_img, hflipped_img_script)
batch_tensors = self._create_data_batch(16, 18, num_samples=4, device=self.device)
self._test_fn_on_batch(batch_tensors, F.hflip)
batch_tensors = _create_data_batch(16, 18, num_samples=4, device=self.device)
_test_fn_on_batch(batch_tensors, F.hflip)
def test_crop(self):
script_crop = torch.jit.script(F.crop)
img_tensor, pil_img = self._create_data(16, 18, device=self.device)
img_tensor, pil_img = _create_data(16, 18, device=self.device)
test_configs = [
(1, 2, 4, 5), # crop inside top-left corner
......@@ -116,13 +104,13 @@ class Tester(TransformsTester):
pil_img_cropped = F.crop(pil_img, top, left, height, width)
img_tensor_cropped = F.crop(img_tensor, top, left, height, width)
self.compareTensorToPIL(img_tensor_cropped, pil_img_cropped)
_assert_equal_tensor_to_pil(img_tensor_cropped, pil_img_cropped)
img_tensor_cropped = script_crop(img_tensor, top, left, height, width)
self.compareTensorToPIL(img_tensor_cropped, pil_img_cropped)
_assert_equal_tensor_to_pil(img_tensor_cropped, pil_img_cropped)
batch_tensors = self._create_data_batch(16, 18, num_samples=4, device=self.device)
self._test_fn_on_batch(batch_tensors, F.crop, top=top, left=left, height=height, width=width)
batch_tensors = _create_data_batch(16, 18, num_samples=4, device=self.device)
_test_fn_on_batch(batch_tensors, F.crop, top=top, left=left, height=height, width=width)
def test_hsv2rgb(self):
scripted_fn = torch.jit.script(F_t._hsv2rgb)
......@@ -146,8 +134,8 @@ class Tester(TransformsTester):
s_rgb_img = scripted_fn(hsv_img)
torch.testing.assert_close(rgb_img, s_rgb_img)
batch_tensors = self._create_data_batch(120, 100, num_samples=4, device=self.device).float()
self._test_fn_on_batch(batch_tensors, F_t._hsv2rgb)
batch_tensors = _create_data_batch(120, 100, num_samples=4, device=self.device).float()
_test_fn_on_batch(batch_tensors, F_t._hsv2rgb)
def test_rgb2hsv(self):
scripted_fn = torch.jit.script(F_t._rgb2hsv)
......@@ -179,58 +167,58 @@ class Tester(TransformsTester):
s_hsv_img = scripted_fn(rgb_img)
torch.testing.assert_close(hsv_img, s_hsv_img, rtol=1e-5, atol=1e-7)
batch_tensors = self._create_data_batch(120, 100, num_samples=4, device=self.device).float()
self._test_fn_on_batch(batch_tensors, F_t._rgb2hsv)
batch_tensors = _create_data_batch(120, 100, num_samples=4, device=self.device).float()
_test_fn_on_batch(batch_tensors, F_t._rgb2hsv)
def test_rgb_to_grayscale(self):
script_rgb_to_grayscale = torch.jit.script(F.rgb_to_grayscale)
img_tensor, pil_img = self._create_data(32, 34, device=self.device)
img_tensor, pil_img = _create_data(32, 34, device=self.device)
for num_output_channels in (3, 1):
gray_pil_image = F.rgb_to_grayscale(pil_img, num_output_channels=num_output_channels)
gray_tensor = F.rgb_to_grayscale(img_tensor, num_output_channels=num_output_channels)
self.approxEqualTensorToPIL(gray_tensor.float(), gray_pil_image, tol=1.0 + 1e-10, agg_method="max")
_assert_approx_equal_tensor_to_pil(gray_tensor.float(), gray_pil_image, tol=1.0 + 1e-10, agg_method="max")
s_gray_tensor = script_rgb_to_grayscale(img_tensor, num_output_channels=num_output_channels)
assert_equal(s_gray_tensor, gray_tensor)
batch_tensors = self._create_data_batch(16, 18, num_samples=4, device=self.device)
self._test_fn_on_batch(batch_tensors, F.rgb_to_grayscale, num_output_channels=num_output_channels)
batch_tensors = _create_data_batch(16, 18, num_samples=4, device=self.device)
_test_fn_on_batch(batch_tensors, F.rgb_to_grayscale, num_output_channels=num_output_channels)
def test_center_crop(self):
script_center_crop = torch.jit.script(F.center_crop)
img_tensor, pil_img = self._create_data(32, 34, device=self.device)
img_tensor, pil_img = _create_data(32, 34, device=self.device)
cropped_pil_image = F.center_crop(pil_img, [10, 11])
cropped_tensor = F.center_crop(img_tensor, [10, 11])
self.compareTensorToPIL(cropped_tensor, cropped_pil_image)
_assert_equal_tensor_to_pil(cropped_tensor, cropped_pil_image)
cropped_tensor = script_center_crop(img_tensor, [10, 11])
self.compareTensorToPIL(cropped_tensor, cropped_pil_image)
_assert_equal_tensor_to_pil(cropped_tensor, cropped_pil_image)
batch_tensors = self._create_data_batch(16, 18, num_samples=4, device=self.device)
self._test_fn_on_batch(batch_tensors, F.center_crop, output_size=[10, 11])
batch_tensors = _create_data_batch(16, 18, num_samples=4, device=self.device)
_test_fn_on_batch(batch_tensors, F.center_crop, output_size=[10, 11])
def test_five_crop(self):
script_five_crop = torch.jit.script(F.five_crop)
img_tensor, pil_img = self._create_data(32, 34, device=self.device)
img_tensor, pil_img = _create_data(32, 34, device=self.device)
cropped_pil_images = F.five_crop(pil_img, [10, 11])
cropped_tensors = F.five_crop(img_tensor, [10, 11])
for i in range(5):
self.compareTensorToPIL(cropped_tensors[i], cropped_pil_images[i])
_assert_equal_tensor_to_pil(cropped_tensors[i], cropped_pil_images[i])
cropped_tensors = script_five_crop(img_tensor, [10, 11])
for i in range(5):
self.compareTensorToPIL(cropped_tensors[i], cropped_pil_images[i])
_assert_equal_tensor_to_pil(cropped_tensors[i], cropped_pil_images[i])
batch_tensors = self._create_data_batch(16, 18, num_samples=4, device=self.device)
batch_tensors = _create_data_batch(16, 18, num_samples=4, device=self.device)
tuple_transformed_batches = F.five_crop(batch_tensors, [10, 11])
for i in range(len(batch_tensors)):
img_tensor = batch_tensors[i, ...]
......@@ -250,19 +238,19 @@ class Tester(TransformsTester):
def test_ten_crop(self):
script_ten_crop = torch.jit.script(F.ten_crop)
img_tensor, pil_img = self._create_data(32, 34, device=self.device)
img_tensor, pil_img = _create_data(32, 34, device=self.device)
cropped_pil_images = F.ten_crop(pil_img, [10, 11])
cropped_tensors = F.ten_crop(img_tensor, [10, 11])
for i in range(10):
self.compareTensorToPIL(cropped_tensors[i], cropped_pil_images[i])
_assert_equal_tensor_to_pil(cropped_tensors[i], cropped_pil_images[i])
cropped_tensors = script_ten_crop(img_tensor, [10, 11])
for i in range(10):
self.compareTensorToPIL(cropped_tensors[i], cropped_pil_images[i])
_assert_equal_tensor_to_pil(cropped_tensors[i], cropped_pil_images[i])
batch_tensors = self._create_data_batch(16, 18, num_samples=4, device=self.device)
batch_tensors = _create_data_batch(16, 18, num_samples=4, device=self.device)
tuple_transformed_batches = F.ten_crop(batch_tensors, [10, 11])
for i in range(len(batch_tensors)):
img_tensor = batch_tensors[i, ...]
......@@ -281,8 +269,8 @@ class Tester(TransformsTester):
def test_pad(self):
script_fn = torch.jit.script(F.pad)
tensor, pil_img = self._create_data(7, 8, device=self.device)
batch_tensors = self._create_data_batch(16, 18, num_samples=4, device=self.device)
tensor, pil_img = _create_data(7, 8, device=self.device)
batch_tensors = _create_data_batch(16, 18, num_samples=4, device=self.device)
for dt in [None, torch.float32, torch.float64, torch.float16]:
......@@ -313,7 +301,7 @@ class Tester(TransformsTester):
if pad_tensor_8b.dtype != torch.uint8:
pad_tensor_8b = pad_tensor_8b.to(torch.uint8)
self.compareTensorToPIL(pad_tensor_8b, pad_pil_img, msg="{}, {}".format(pad, kwargs))
_assert_equal_tensor_to_pil(pad_tensor_8b, pad_pil_img, msg="{}, {}".format(pad, kwargs))
if isinstance(pad, int):
script_pad = [pad, ]
......@@ -322,19 +310,19 @@ class Tester(TransformsTester):
pad_tensor_script = script_fn(tensor, script_pad, **kwargs)
assert_equal(pad_tensor, pad_tensor_script, msg="{}, {}".format(pad, kwargs))
self._test_fn_on_batch(batch_tensors, F.pad, padding=script_pad, **kwargs)
_test_fn_on_batch(batch_tensors, F.pad, padding=script_pad, **kwargs)
def test_resized_crop(self):
# test values of F.resized_crop in several cases:
# 1) resize to the same size, crop to the same size => should be identity
tensor, _ = self._create_data(26, 36, device=self.device)
tensor, _ = _create_data(26, 36, device=self.device)
for mode in [NEAREST, BILINEAR, BICUBIC]:
out_tensor = F.resized_crop(tensor, top=0, left=0, height=26, width=36, size=[26, 36], interpolation=mode)
assert_equal(tensor, out_tensor, msg="{} vs {}".format(out_tensor[0, :5, :5], tensor[0, :5, :5]))
# 2) resize by half and crop a TL corner
tensor, _ = self._create_data(26, 36, device=self.device)
tensor, _ = _create_data(26, 36, device=self.device)
out_tensor = F.resized_crop(tensor, top=0, left=0, height=20, width=30, size=[10, 15], interpolation=NEAREST)
expected_out_tensor = tensor[:, :20:2, :30:2]
assert_equal(
......@@ -344,8 +332,8 @@ class Tester(TransformsTester):
msg="{} vs {}".format(expected_out_tensor[0, :10, :10], out_tensor[0, :10, :10]),
)
batch_tensors = self._create_data_batch(26, 36, num_samples=4, device=self.device)
self._test_fn_on_batch(
batch_tensors = _create_data_batch(26, 36, num_samples=4, device=self.device)
_test_fn_on_batch(
batch_tensors, F.resized_crop, top=1, left=2, height=20, width=30, size=[10, 15], interpolation=NEAREST
)
......@@ -447,7 +435,7 @@ class Tester(TransformsTester):
if out_tensor.dtype != torch.uint8:
out_tensor = out_tensor.to(torch.uint8)
self.compareTensorToPIL(out_tensor, out_pil_img)
_assert_equal_tensor_to_pil(out_tensor, out_pil_img)
def _test_affine_all_ops(self, tensor, pil_img, scripted_affine):
# 4) Test rotation + translation + scale + share
......@@ -491,7 +479,7 @@ class Tester(TransformsTester):
# Tests on square and rectangular images
scripted_affine = torch.jit.script(F.affine)
data = [self._create_data(26, 26, device=self.device), self._create_data(32, 26, device=self.device)]
data = [_create_data(26, 26, device=self.device), _create_data(32, 26, device=self.device)]
for tensor, pil_img in data:
for dt in [None, torch.float32, torch.float64, torch.float16]:
......@@ -511,11 +499,11 @@ class Tester(TransformsTester):
self._test_affine_translations(tensor, pil_img, scripted_affine)
self._test_affine_all_ops(tensor, pil_img, scripted_affine)
batch_tensors = self._create_data_batch(26, 36, num_samples=4, device=self.device)
batch_tensors = _create_data_batch(26, 36, num_samples=4, device=self.device)
if dt is not None:
batch_tensors = batch_tensors.to(dtype=dt)
self._test_fn_on_batch(
_test_fn_on_batch(
batch_tensors, F.affine, angle=-43, translate=[-3, 4], scale=1.2, shear=[4.0, 5.0]
)
......@@ -580,7 +568,7 @@ class Tester(TransformsTester):
# Tests on square image
scripted_rotate = torch.jit.script(F.rotate)
data = [self._create_data(26, 26, device=self.device), self._create_data(32, 26, device=self.device)]
data = [_create_data(26, 26, device=self.device), _create_data(32, 26, device=self.device)]
for tensor, pil_img in data:
img_size = pil_img.size
......@@ -601,12 +589,12 @@ class Tester(TransformsTester):
self._test_rotate_all_options(tensor, pil_img, scripted_rotate, centers)
batch_tensors = self._create_data_batch(26, 36, num_samples=4, device=self.device)
batch_tensors = _create_data_batch(26, 36, num_samples=4, device=self.device)
if dt is not None:
batch_tensors = batch_tensors.to(dtype=dt)
center = (20, 22)
self._test_fn_on_batch(
_test_fn_on_batch(
batch_tensors, F.rotate, angle=32, interpolation=NEAREST, expand=True, center=center
)
tensor, pil_img = data[0]
......@@ -735,7 +723,7 @@ def _get_data_dims_and_points_for_perspective():
@pytest.mark.parametrize('dt', [None, torch.float32, torch.float64, torch.float16])
@pytest.mark.parametrize('fill', (None, [0, 0, 0], [1, 2, 3], [255, 255, 255], [1, ], (2.0, )))
@pytest.mark.parametrize('fn', [F.perspective, torch.jit.script(F.perspective)])
def test_perspective_pil_vs_tensor(device, dims_and_points, dt, fill, fn, tester):
def test_perspective_pil_vs_tensor(device, dims_and_points, dt, fill, fn):
if dt == torch.float16 and device == "cpu":
# skip float16 on CPU case
......@@ -743,7 +731,7 @@ def test_perspective_pil_vs_tensor(device, dims_and_points, dt, fill, fn, tester
data_dims, (spoints, epoints) = dims_and_points
tensor, pil_img = tester._create_data(*data_dims, device=device)
tensor, pil_img = _create_data(*data_dims, device=device)
if dt is not None:
tensor = tensor.to(dtype=dt)
......@@ -766,7 +754,7 @@ def test_perspective_pil_vs_tensor(device, dims_and_points, dt, fill, fn, tester
@pytest.mark.parametrize('device', cpu_and_gpu())
@pytest.mark.parametrize('dims_and_points', _get_data_dims_and_points_for_perspective())
@pytest.mark.parametrize('dt', [None, torch.float32, torch.float64, torch.float16])
def test_perspective_batch(device, dims_and_points, dt, tester):
def test_perspective_batch(device, dims_and_points, dt):
if dt == torch.float16 and device == "cpu":
# skip float16 on CPU case
......@@ -774,28 +762,28 @@ def test_perspective_batch(device, dims_and_points, dt, tester):
data_dims, (spoints, epoints) = dims_and_points
batch_tensors = tester._create_data_batch(*data_dims, num_samples=4, device=device)
batch_tensors = _create_data_batch(*data_dims, num_samples=4, device=device)
if dt is not None:
batch_tensors = batch_tensors.to(dtype=dt)
# Ignore the equivalence between scripted and regular function on float16 cuda. The pixels at
# the border may be entirely different due to small rounding errors.
scripted_fn_atol = -1 if (dt == torch.float16 and device == "cuda") else 1e-8
tester._test_fn_on_batch(
_test_fn_on_batch(
batch_tensors, F.perspective, scripted_fn_atol=scripted_fn_atol,
startpoints=spoints, endpoints=epoints, interpolation=NEAREST
)
def test_perspective_interpolation_warning(tester):
def test_perspective_interpolation_warning():
# assert changed type warning
spoints = [[0, 0], [33, 0], [33, 25], [0, 25]]
epoints = [[3, 2], [32, 3], [30, 24], [2, 25]]
tensor = torch.randint(0, 256, (3, 26, 26))
with tester.assertWarnsRegex(UserWarning, r"Argument interpolation should be of type InterpolationMode"):
with pytest.warns(UserWarning, match="Argument interpolation should be of type InterpolationMode"):
res1 = F.perspective(tensor, startpoints=spoints, endpoints=epoints, interpolation=2)
res2 = F.perspective(tensor, startpoints=spoints, endpoints=epoints, interpolation=BILINEAR)
tester.assertTrue(res1.equal(res2))
assert_equal(res1, res2)
@pytest.mark.parametrize('device', cpu_and_gpu())
......@@ -803,7 +791,7 @@ def test_perspective_interpolation_warning(tester):
@pytest.mark.parametrize('size', [32, 26, [32, ], [32, 32], (32, 32), [26, 35]])
@pytest.mark.parametrize('max_size', [None, 34, 40, 1000])
@pytest.mark.parametrize('interpolation', [BILINEAR, BICUBIC, NEAREST])
def test_resize(device, dt, size, max_size, interpolation, tester):
def test_resize(device, dt, size, max_size, interpolation):
if dt == torch.float16 and device == "cpu":
# skip float16 on CPU case
......@@ -814,8 +802,8 @@ def test_resize(device, dt, size, max_size, interpolation, tester):
torch.manual_seed(12)
script_fn = torch.jit.script(F.resize)
tensor, pil_img = tester._create_data(26, 36, device=device)
batch_tensors = tester._create_data_batch(16, 18, num_samples=4, device=device)
tensor, pil_img = _create_data(26, 36, device=device)
batch_tensors = _create_data_batch(16, 18, num_samples=4, device=device)
if dt is not None:
# This is a trivial cast to float of uint8 data to test all cases
......@@ -837,7 +825,7 @@ def test_resize(device, dt, size, max_size, interpolation, tester):
resized_tensor_f = resized_tensor_f.to(torch.float)
# Pay attention to high tolerance for MAE
tester.approxEqualTensorToPIL(resized_tensor_f, resized_pil_img, tol=8.0)
_assert_approx_equal_tensor_to_pil(resized_tensor_f, resized_pil_img, tol=8.0)
if isinstance(size, int):
script_size = [size, ]
......@@ -849,15 +837,15 @@ def test_resize(device, dt, size, max_size, interpolation, tester):
)
assert_equal(resized_tensor, resize_result)
tester._test_fn_on_batch(
_test_fn_on_batch(
batch_tensors, F.resize, size=script_size, interpolation=interpolation, max_size=max_size
)
@pytest.mark.parametrize('device', cpu_and_gpu())
def test_resize_asserts(device, tester):
def test_resize_asserts(device):
tensor, pil_img = tester._create_data(26, 36, device=device)
tensor, pil_img = _create_data(26, 36, device=device)
# assert changed type warning
with pytest.warns(UserWarning, match=r"Argument interpolation should be of type InterpolationMode"):
......@@ -878,7 +866,7 @@ def test_resize_asserts(device, tester):
@pytest.mark.parametrize('dt', [None, torch.float32, torch.float64, torch.float16])
@pytest.mark.parametrize('size', [[96, 72], [96, 420], [420, 72]])
@pytest.mark.parametrize('interpolation', [BILINEAR, BICUBIC])
def test_resize_antialias(device, dt, size, interpolation, tester):
def test_resize_antialias(device, dt, size, interpolation):
if dt == torch.float16 and device == "cpu":
# skip float16 on CPU case
......@@ -886,7 +874,7 @@ def test_resize_antialias(device, dt, size, interpolation, tester):
torch.manual_seed(12)
script_fn = torch.jit.script(F.resize)
tensor, pil_img = tester._create_data(320, 290, device=device)
tensor, pil_img = _create_data(320, 290, device=device)
if dt is not None:
# This is a trivial cast to float of uint8 data to test all cases
......@@ -895,17 +883,14 @@ def test_resize_antialias(device, dt, size, interpolation, tester):
resized_tensor = F.resize(tensor, size=size, interpolation=interpolation, antialias=True)
resized_pil_img = F.resize(pil_img, size=size, interpolation=interpolation)
tester.assertEqual(
resized_tensor.size()[1:], resized_pil_img.size[::-1],
msg=f"{size}, {interpolation}, {dt}"
)
assert resized_tensor.size()[1:] == resized_pil_img.size[::-1]
resized_tensor_f = resized_tensor
# we need to cast to uint8 to compare with PIL image
if resized_tensor_f.dtype == torch.uint8:
resized_tensor_f = resized_tensor_f.to(torch.float)
tester.approxEqualTensorToPIL(
_assert_approx_equal_tensor_to_pil(
resized_tensor_f, resized_pil_img, tol=0.5, msg=f"{size}, {interpolation}, {dt}"
)
......@@ -917,7 +902,7 @@ def test_resize_antialias(device, dt, size, interpolation, tester):
# match PIL implementation.
accepted_tol = 15.0
tester.approxEqualTensorToPIL(
_assert_approx_equal_tensor_to_pil(
resized_tensor_f, resized_pil_img, tol=accepted_tol, agg_method="max",
msg=f"{size}, {interpolation}, {dt}"
)
......@@ -928,17 +913,17 @@ def test_resize_antialias(device, dt, size, interpolation, tester):
script_size = size
resize_result = script_fn(tensor, size=script_size, interpolation=interpolation, antialias=True)
tester.assertTrue(resized_tensor.equal(resize_result), msg=f"{size}, {interpolation}, {dt}")
assert_equal(resized_tensor, resize_result)
@needs_cuda
@pytest.mark.parametrize('interpolation', [BILINEAR, BICUBIC])
def test_assert_resize_antialias(interpolation, tester):
def test_assert_resize_antialias(interpolation):
# Checks implementation on very large scales
# and catch TORCH_CHECK inside interpolate_aa_kernels.cu
torch.manual_seed(12)
tensor, pil_img = tester._create_data(1000, 1000, device="cuda")
tensor, pil_img = _create_data(1000, 1000, device="cuda")
with pytest.raises(RuntimeError, match=r"Max supported scale factor is"):
F.resize(tensor, size=(5, 5), interpolation=interpolation, antialias=True)
......@@ -946,12 +931,10 @@ def test_assert_resize_antialias(interpolation, tester):
def check_functional_vs_PIL_vs_scripted(fn, fn_pil, fn_t, config, device, dtype, tol=2.0 + 1e-10, agg_method="max"):
tester = Tester()
script_fn = torch.jit.script(fn)
torch.manual_seed(15)
tensor, pil_img = tester._create_data(26, 34, device=device)
batch_tensors = tester._create_data_batch(16, 18, num_samples=4, device=device)
tensor, pil_img = _create_data(26, 34, device=device)
batch_tensors = _create_data_batch(16, 18, num_samples=4, device=device)
if dtype is not None:
tensor = F.convert_image_dtype(tensor, dtype)
......@@ -970,7 +953,7 @@ def check_functional_vs_PIL_vs_scripted(fn, fn_pil, fn_t, config, device, dtype,
# Check that max difference does not exceed 2 in [0, 255] range
# Exact matching is not possible due to incompatibility convert_image_dtype and PIL results
tester.approxEqualTensorToPIL(rbg_tensor.float(), out_pil, tol=tol, agg_method=agg_method)
_assert_approx_equal_tensor_to_pil(rbg_tensor.float(), out_pil, tol=tol, agg_method=agg_method)
atol = 1e-6
if out_fn_t.dtype == torch.uint8 and "cuda" in torch.device(device).type:
......@@ -978,7 +961,7 @@ def check_functional_vs_PIL_vs_scripted(fn, fn_pil, fn_t, config, device, dtype,
assert out_fn_t.allclose(out_scripted, atol=atol)
# FIXME: fn will be scripted again in _test_fn_on_batch. We could avoid that.
tester._test_fn_on_batch(batch_tensors, fn, scripted_fn_atol=atol, **config)
_test_fn_on_batch(batch_tensors, fn, scripted_fn_atol=atol, **config)
@pytest.mark.parametrize('device', cpu_and_gpu())
......
......@@ -9,14 +9,22 @@ import numpy as np
import unittest
from typing import Sequence
from common_utils import TransformsTester, get_tmp_dir, int_dtypes, float_dtypes
from common_utils import (
get_tmp_dir,
int_dtypes,
float_dtypes,
_create_data,
_create_data_batch,
_assert_equal_tensor_to_pil,
_assert_approx_equal_tensor_to_pil,
)
from _assert_utils import assert_equal
NEAREST, BILINEAR, BICUBIC = InterpolationMode.NEAREST, InterpolationMode.BILINEAR, InterpolationMode.BICUBIC
class Tester(TransformsTester):
class Tester(unittest.TestCase):
def setUp(self):
self.device = "cpu"
......@@ -26,13 +34,13 @@ class Tester(TransformsTester):
fn_kwargs = {}
f = getattr(F, func)
tensor, pil_img = self._create_data(height=10, width=10, device=self.device)
tensor, pil_img = _create_data(height=10, width=10, device=self.device)
transformed_tensor = f(tensor, **fn_kwargs)
transformed_pil_img = f(pil_img, **fn_kwargs)
if test_exact_match:
self.compareTensorToPIL(transformed_tensor, transformed_pil_img, **match_kwargs)
_assert_equal_tensor_to_pil(transformed_tensor, transformed_pil_img, **match_kwargs)
else:
self.approxEqualTensorToPIL(transformed_tensor, transformed_pil_img, **match_kwargs)
_assert_approx_equal_tensor_to_pil(transformed_tensor, transformed_pil_img, **match_kwargs)
def _test_transform_vs_scripted(self, transform, s_transform, tensor, msg=None):
torch.manual_seed(12)
......@@ -63,22 +71,22 @@ class Tester(TransformsTester):
f = getattr(T, method)(**meth_kwargs)
scripted_fn = torch.jit.script(f)
tensor, pil_img = self._create_data(26, 34, device=self.device)
tensor, pil_img = _create_data(26, 34, device=self.device)
# set seed to reproduce the same transformation for tensor and PIL image
torch.manual_seed(12)
transformed_tensor = f(tensor)
torch.manual_seed(12)
transformed_pil_img = f(pil_img)
if test_exact_match:
self.compareTensorToPIL(transformed_tensor, transformed_pil_img, **match_kwargs)
_assert_equal_tensor_to_pil(transformed_tensor, transformed_pil_img, **match_kwargs)
else:
self.approxEqualTensorToPIL(transformed_tensor.float(), transformed_pil_img, **match_kwargs)
_assert_approx_equal_tensor_to_pil(transformed_tensor.float(), transformed_pil_img, **match_kwargs)
torch.manual_seed(12)
transformed_tensor_script = scripted_fn(tensor)
assert_equal(transformed_tensor, transformed_tensor_script)
batch_tensors = self._create_data_batch(height=23, width=34, channels=3, num_samples=4, device=self.device)
batch_tensors = _create_data_batch(height=23, width=34, channels=3, num_samples=4, device=self.device)
self._test_transform_vs_scripted_on_batch(f, scripted_fn, batch_tensors)
with get_tmp_dir() as tmp_dir:
......@@ -259,13 +267,13 @@ class Tester(TransformsTester):
fn = getattr(F, func)
scripted_fn = torch.jit.script(fn)
tensor, pil_img = self._create_data(height=20, width=20, device=self.device)
tensor, pil_img = _create_data(height=20, width=20, device=self.device)
transformed_t_list = fn(tensor, **fn_kwargs)
transformed_p_list = fn(pil_img, **fn_kwargs)
self.assertEqual(len(transformed_t_list), len(transformed_p_list))
self.assertEqual(len(transformed_t_list), out_length)
for transformed_tensor, transformed_pil_img in zip(transformed_t_list, transformed_p_list):
self.compareTensorToPIL(transformed_tensor, transformed_pil_img)
_assert_equal_tensor_to_pil(transformed_tensor, transformed_pil_img)
transformed_t_list_script = scripted_fn(tensor.detach().clone(), **fn_kwargs)
self.assertEqual(len(transformed_t_list), len(transformed_t_list_script))
......@@ -284,7 +292,7 @@ class Tester(TransformsTester):
self.assertEqual(len(output), len(transformed_t_list_script))
# test on batch of tensors
batch_tensors = self._create_data_batch(height=23, width=34, channels=3, num_samples=4, device=self.device)
batch_tensors = _create_data_batch(height=23, width=34, channels=3, num_samples=4, device=self.device)
torch.manual_seed(12)
transformed_batch_list = fn(batch_tensors)
......@@ -350,7 +358,7 @@ class Tester(TransformsTester):
self.assertEqual(y.shape[1], 38)
self.assertEqual(y.shape[2], int(38 * 46 / 32))
tensor, _ = self._create_data(height=34, width=36, device=self.device)
tensor, _ = _create_data(height=34, width=36, device=self.device)
batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=self.device)
for dt in [None, torch.float32, torch.float64]:
......@@ -487,7 +495,7 @@ class Tester(TransformsTester):
def test_normalize(self):
fn = T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
tensor, _ = self._create_data(26, 34, device=self.device)
tensor, _ = _create_data(26, 34, device=self.device)
with self.assertRaisesRegex(TypeError, r"Input tensor should be a float tensor"):
fn(tensor)
......@@ -506,7 +514,7 @@ class Tester(TransformsTester):
def test_linear_transformation(self):
c, h, w = 3, 24, 32
tensor, _ = self._create_data(h, w, channels=c, device=self.device)
tensor, _ = _create_data(h, w, channels=c, device=self.device)
matrix = torch.rand(c * h * w, c * h * w, device=self.device)
mean_vector = torch.rand(c * h * w, device=self.device)
......@@ -529,7 +537,7 @@ class Tester(TransformsTester):
scripted_fn.save(os.path.join(tmp_dir, "t_norm.pt"))
def test_compose(self):
tensor, _ = self._create_data(26, 34, device=self.device)
tensor, _ = _create_data(26, 34, device=self.device)
tensor = tensor.to(dtype=torch.float32) / 255.0
transforms = T.Compose([
......@@ -552,7 +560,7 @@ class Tester(TransformsTester):
torch.jit.script(t)
def test_random_apply(self):
tensor, _ = self._create_data(26, 34, device=self.device)
tensor, _ = _create_data(26, 34, device=self.device)
tensor = tensor.to(dtype=torch.float32) / 255.0
transforms = T.RandomApply([
......@@ -620,7 +628,7 @@ class Tester(TransformsTester):
with self.assertRaises(ValueError, msg="If value is a sequence, it should have either a single value or 3"):
random_erasing(img)
tensor, _ = self._create_data(24, 32, channels=3, device=self.device)
tensor, _ = _create_data(24, 32, channels=3, device=self.device)
batch_tensors = torch.rand(4, 3, 44, 56, device=self.device)
test_configs = [
......@@ -640,7 +648,7 @@ class Tester(TransformsTester):
scripted_fn.save(os.path.join(tmp_dir, "t_random_erasing.pt"))
def test_convert_image_dtype(self):
tensor, _ = self._create_data(26, 34, device=self.device)
tensor, _ = _create_data(26, 34, device=self.device)
batch_tensors = torch.rand(4, 3, 44, 56, device=self.device)
for in_dtype in int_dtypes() + float_dtypes():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment