Unverified Commit 31d53367 authored by Nicolas Hug's avatar Nicolas Hug Committed by GitHub
Browse files

Remove TransformsTester (#3946)

parent d1f1a544
...@@ -325,54 +325,6 @@ def freeze_rng_state(): ...@@ -325,54 +325,6 @@ def freeze_rng_state():
torch.set_rng_state(rng_state) torch.set_rng_state(rng_state)
class TransformsTester(unittest.TestCase):
def _create_data(self, height=3, width=3, channels=3, device="cpu"):
tensor = torch.randint(0, 256, (channels, height, width), dtype=torch.uint8, device=device)
pil_img = Image.fromarray(tensor.permute(1, 2, 0).contiguous().cpu().numpy())
return tensor, pil_img
def _create_data_batch(self, height=3, width=3, channels=3, num_samples=4, device="cpu"):
batch_tensor = torch.randint(
0, 256,
(num_samples, channels, height, width),
dtype=torch.uint8,
device=device
)
return batch_tensor
def compareTensorToPIL(self, tensor, pil_image, msg=None):
np_pil_image = np.array(pil_image)
if np_pil_image.ndim == 2:
np_pil_image = np_pil_image[:, :, None]
pil_tensor = torch.as_tensor(np_pil_image.transpose((2, 0, 1)))
if msg is None:
msg = "tensor:\n{} \ndid not equal PIL tensor:\n{}".format(tensor, pil_tensor)
assert_equal(tensor.cpu(), pil_tensor, check_stride=False, msg=msg)
def approxEqualTensorToPIL(self, tensor, pil_image, tol=1e-5, msg=None, agg_method="mean",
allowed_percentage_diff=None):
np_pil_image = np.array(pil_image)
if np_pil_image.ndim == 2:
np_pil_image = np_pil_image[:, :, None]
pil_tensor = torch.as_tensor(np_pil_image.transpose((2, 0, 1))).to(tensor)
if allowed_percentage_diff is not None:
# Assert that less than a given %age of pixels are different
self.assertTrue(
(tensor != pil_tensor).to(torch.float).mean() <= allowed_percentage_diff
)
# error value can be mean absolute error, max abs error
# Convert to float to avoid underflow when computing absolute difference
tensor = tensor.to(torch.float)
pil_tensor = pil_tensor.to(torch.float)
err = getattr(torch, agg_method)(torch.abs(tensor - pil_tensor)).item()
self.assertTrue(
err < tol,
msg="{}: err={}, tol={}: \n{}\nvs\n{}".format(msg, err, tol, tensor[0, :10, :10], pil_tensor[0, :10, :10])
)
def cycle_over(objs): def cycle_over(objs):
for idx, obj in enumerate(objs): for idx, obj in enumerate(objs):
yield obj, objs[:idx] + objs[idx + 1:] yield obj, objs[:idx] + objs[idx + 1:]
...@@ -457,3 +409,65 @@ def cpu_only(test_func): ...@@ -457,3 +409,65 @@ def cpu_only(test_func):
return pytest.mark.dont_collect(test_func) return pytest.mark.dont_collect(test_func)
else: else:
return test_func return test_func
def _create_data(height=3, width=3, channels=3, device="cpu"):
# TODO: When all relevant tests are ported to pytest, turn this into a module-level fixture
tensor = torch.randint(0, 256, (channels, height, width), dtype=torch.uint8, device=device)
pil_img = Image.fromarray(tensor.permute(1, 2, 0).contiguous().cpu().numpy())
return tensor, pil_img
def _create_data_batch(height=3, width=3, channels=3, num_samples=4, device="cpu"):
# TODO: When all relevant tests are ported to pytest, turn this into a module-level fixture
batch_tensor = torch.randint(
0, 256,
(num_samples, channels, height, width),
dtype=torch.uint8,
device=device
)
return batch_tensor
def _assert_equal_tensor_to_pil(tensor, pil_image, msg=None):
np_pil_image = np.array(pil_image)
if np_pil_image.ndim == 2:
np_pil_image = np_pil_image[:, :, None]
pil_tensor = torch.as_tensor(np_pil_image.transpose((2, 0, 1)))
if msg is None:
msg = "tensor:\n{} \ndid not equal PIL tensor:\n{}".format(tensor, pil_tensor)
assert_equal(tensor.cpu(), pil_tensor, check_stride=False, msg=msg)
def _assert_approx_equal_tensor_to_pil(tensor, pil_image, tol=1e-5, msg=None, agg_method="mean",
allowed_percentage_diff=None):
# TODO: we could just merge this into _assert_equal_tensor_to_pil
np_pil_image = np.array(pil_image)
if np_pil_image.ndim == 2:
np_pil_image = np_pil_image[:, :, None]
pil_tensor = torch.as_tensor(np_pil_image.transpose((2, 0, 1))).to(tensor)
if allowed_percentage_diff is not None:
# Assert that less than a given %age of pixels are different
assert (tensor != pil_tensor).to(torch.float).mean() <= allowed_percentage_diff
# error value can be mean absolute error, max abs error
# Convert to float to avoid underflow when computing absolute difference
tensor = tensor.to(torch.float)
pil_tensor = pil_tensor.to(torch.float)
err = getattr(torch, agg_method)(torch.abs(tensor - pil_tensor)).item()
assert err < tol
def _test_fn_on_batch(batch_tensors, fn, scripted_fn_atol=1e-8, **fn_kwargs):
transformed_batch = fn(batch_tensors, **fn_kwargs)
for i in range(len(batch_tensors)):
img_tensor = batch_tensors[i, ...]
transformed_img = fn(img_tensor, **fn_kwargs)
assert_equal(transformed_img, transformed_batch[i, ...])
if scripted_fn_atol >= 0:
scripted_fn = torch.jit.script(fn)
# scriptable function test
s_transformed_batch = scripted_fn(batch_tensors, **fn_kwargs)
torch.testing.assert_close(transformed_batch, s_transformed_batch, rtol=1e-5, atol=scripted_fn_atol)
...@@ -14,7 +14,15 @@ import torchvision.transforms.functional as F ...@@ -14,7 +14,15 @@ import torchvision.transforms.functional as F
import torchvision.transforms as T import torchvision.transforms as T
from torchvision.transforms import InterpolationMode from torchvision.transforms import InterpolationMode
from common_utils import TransformsTester, cpu_and_gpu, needs_cuda from common_utils import (
cpu_and_gpu,
needs_cuda,
_create_data,
_create_data_batch,
_assert_equal_tensor_to_pil,
_assert_approx_equal_tensor_to_pil,
_test_fn_on_batch,
)
from _assert_utils import assert_equal from _assert_utils import assert_equal
from typing import Dict, List, Sequence, Tuple from typing import Dict, List, Sequence, Tuple
...@@ -23,31 +31,11 @@ from typing import Dict, List, Sequence, Tuple ...@@ -23,31 +31,11 @@ from typing import Dict, List, Sequence, Tuple
NEAREST, BILINEAR, BICUBIC = InterpolationMode.NEAREST, InterpolationMode.BILINEAR, InterpolationMode.BICUBIC NEAREST, BILINEAR, BICUBIC = InterpolationMode.NEAREST, InterpolationMode.BILINEAR, InterpolationMode.BICUBIC
@pytest.fixture(scope='module') class Tester(unittest.TestCase):
def tester():
# instanciation of the Tester class used for equality assertions and other utilities
# TODO: remove this eventually when we don't need the class anymore
return Tester()
class Tester(TransformsTester):
def setUp(self): def setUp(self):
self.device = "cpu" self.device = "cpu"
def _test_fn_on_batch(self, batch_tensors, fn, scripted_fn_atol=1e-8, **fn_kwargs):
transformed_batch = fn(batch_tensors, **fn_kwargs)
for i in range(len(batch_tensors)):
img_tensor = batch_tensors[i, ...]
transformed_img = fn(img_tensor, **fn_kwargs)
assert_equal(transformed_img, transformed_batch[i, ...])
if scripted_fn_atol >= 0:
scripted_fn = torch.jit.script(fn)
# scriptable function test
s_transformed_batch = scripted_fn(batch_tensors, **fn_kwargs)
torch.testing.assert_close(transformed_batch, s_transformed_batch, rtol=1e-5, atol=scripted_fn_atol)
def test_assert_image_tensor(self): def test_assert_image_tensor(self):
shape = (100,) shape = (100,)
tensor = torch.rand(*shape, dtype=torch.float, device=self.device) tensor = torch.rand(*shape, dtype=torch.float, device=self.device)
...@@ -73,37 +61,37 @@ class Tester(TransformsTester): ...@@ -73,37 +61,37 @@ class Tester(TransformsTester):
def test_vflip(self): def test_vflip(self):
script_vflip = torch.jit.script(F.vflip) script_vflip = torch.jit.script(F.vflip)
img_tensor, pil_img = self._create_data(16, 18, device=self.device) img_tensor, pil_img = _create_data(16, 18, device=self.device)
vflipped_img = F.vflip(img_tensor) vflipped_img = F.vflip(img_tensor)
vflipped_pil_img = F.vflip(pil_img) vflipped_pil_img = F.vflip(pil_img)
self.compareTensorToPIL(vflipped_img, vflipped_pil_img) _assert_equal_tensor_to_pil(vflipped_img, vflipped_pil_img)
# scriptable function test # scriptable function test
vflipped_img_script = script_vflip(img_tensor) vflipped_img_script = script_vflip(img_tensor)
assert_equal(vflipped_img, vflipped_img_script) assert_equal(vflipped_img, vflipped_img_script)
batch_tensors = self._create_data_batch(16, 18, num_samples=4, device=self.device) batch_tensors = _create_data_batch(16, 18, num_samples=4, device=self.device)
self._test_fn_on_batch(batch_tensors, F.vflip) _test_fn_on_batch(batch_tensors, F.vflip)
def test_hflip(self): def test_hflip(self):
script_hflip = torch.jit.script(F.hflip) script_hflip = torch.jit.script(F.hflip)
img_tensor, pil_img = self._create_data(16, 18, device=self.device) img_tensor, pil_img = _create_data(16, 18, device=self.device)
hflipped_img = F.hflip(img_tensor) hflipped_img = F.hflip(img_tensor)
hflipped_pil_img = F.hflip(pil_img) hflipped_pil_img = F.hflip(pil_img)
self.compareTensorToPIL(hflipped_img, hflipped_pil_img) _assert_equal_tensor_to_pil(hflipped_img, hflipped_pil_img)
# scriptable function test # scriptable function test
hflipped_img_script = script_hflip(img_tensor) hflipped_img_script = script_hflip(img_tensor)
assert_equal(hflipped_img, hflipped_img_script) assert_equal(hflipped_img, hflipped_img_script)
batch_tensors = self._create_data_batch(16, 18, num_samples=4, device=self.device) batch_tensors = _create_data_batch(16, 18, num_samples=4, device=self.device)
self._test_fn_on_batch(batch_tensors, F.hflip) _test_fn_on_batch(batch_tensors, F.hflip)
def test_crop(self): def test_crop(self):
script_crop = torch.jit.script(F.crop) script_crop = torch.jit.script(F.crop)
img_tensor, pil_img = self._create_data(16, 18, device=self.device) img_tensor, pil_img = _create_data(16, 18, device=self.device)
test_configs = [ test_configs = [
(1, 2, 4, 5), # crop inside top-left corner (1, 2, 4, 5), # crop inside top-left corner
...@@ -116,13 +104,13 @@ class Tester(TransformsTester): ...@@ -116,13 +104,13 @@ class Tester(TransformsTester):
pil_img_cropped = F.crop(pil_img, top, left, height, width) pil_img_cropped = F.crop(pil_img, top, left, height, width)
img_tensor_cropped = F.crop(img_tensor, top, left, height, width) img_tensor_cropped = F.crop(img_tensor, top, left, height, width)
self.compareTensorToPIL(img_tensor_cropped, pil_img_cropped) _assert_equal_tensor_to_pil(img_tensor_cropped, pil_img_cropped)
img_tensor_cropped = script_crop(img_tensor, top, left, height, width) img_tensor_cropped = script_crop(img_tensor, top, left, height, width)
self.compareTensorToPIL(img_tensor_cropped, pil_img_cropped) _assert_equal_tensor_to_pil(img_tensor_cropped, pil_img_cropped)
batch_tensors = self._create_data_batch(16, 18, num_samples=4, device=self.device) batch_tensors = _create_data_batch(16, 18, num_samples=4, device=self.device)
self._test_fn_on_batch(batch_tensors, F.crop, top=top, left=left, height=height, width=width) _test_fn_on_batch(batch_tensors, F.crop, top=top, left=left, height=height, width=width)
def test_hsv2rgb(self): def test_hsv2rgb(self):
scripted_fn = torch.jit.script(F_t._hsv2rgb) scripted_fn = torch.jit.script(F_t._hsv2rgb)
...@@ -146,8 +134,8 @@ class Tester(TransformsTester): ...@@ -146,8 +134,8 @@ class Tester(TransformsTester):
s_rgb_img = scripted_fn(hsv_img) s_rgb_img = scripted_fn(hsv_img)
torch.testing.assert_close(rgb_img, s_rgb_img) torch.testing.assert_close(rgb_img, s_rgb_img)
batch_tensors = self._create_data_batch(120, 100, num_samples=4, device=self.device).float() batch_tensors = _create_data_batch(120, 100, num_samples=4, device=self.device).float()
self._test_fn_on_batch(batch_tensors, F_t._hsv2rgb) _test_fn_on_batch(batch_tensors, F_t._hsv2rgb)
def test_rgb2hsv(self): def test_rgb2hsv(self):
scripted_fn = torch.jit.script(F_t._rgb2hsv) scripted_fn = torch.jit.script(F_t._rgb2hsv)
...@@ -179,58 +167,58 @@ class Tester(TransformsTester): ...@@ -179,58 +167,58 @@ class Tester(TransformsTester):
s_hsv_img = scripted_fn(rgb_img) s_hsv_img = scripted_fn(rgb_img)
torch.testing.assert_close(hsv_img, s_hsv_img, rtol=1e-5, atol=1e-7) torch.testing.assert_close(hsv_img, s_hsv_img, rtol=1e-5, atol=1e-7)
batch_tensors = self._create_data_batch(120, 100, num_samples=4, device=self.device).float() batch_tensors = _create_data_batch(120, 100, num_samples=4, device=self.device).float()
self._test_fn_on_batch(batch_tensors, F_t._rgb2hsv) _test_fn_on_batch(batch_tensors, F_t._rgb2hsv)
def test_rgb_to_grayscale(self): def test_rgb_to_grayscale(self):
script_rgb_to_grayscale = torch.jit.script(F.rgb_to_grayscale) script_rgb_to_grayscale = torch.jit.script(F.rgb_to_grayscale)
img_tensor, pil_img = self._create_data(32, 34, device=self.device) img_tensor, pil_img = _create_data(32, 34, device=self.device)
for num_output_channels in (3, 1): for num_output_channels in (3, 1):
gray_pil_image = F.rgb_to_grayscale(pil_img, num_output_channels=num_output_channels) gray_pil_image = F.rgb_to_grayscale(pil_img, num_output_channels=num_output_channels)
gray_tensor = F.rgb_to_grayscale(img_tensor, num_output_channels=num_output_channels) gray_tensor = F.rgb_to_grayscale(img_tensor, num_output_channels=num_output_channels)
self.approxEqualTensorToPIL(gray_tensor.float(), gray_pil_image, tol=1.0 + 1e-10, agg_method="max") _assert_approx_equal_tensor_to_pil(gray_tensor.float(), gray_pil_image, tol=1.0 + 1e-10, agg_method="max")
s_gray_tensor = script_rgb_to_grayscale(img_tensor, num_output_channels=num_output_channels) s_gray_tensor = script_rgb_to_grayscale(img_tensor, num_output_channels=num_output_channels)
assert_equal(s_gray_tensor, gray_tensor) assert_equal(s_gray_tensor, gray_tensor)
batch_tensors = self._create_data_batch(16, 18, num_samples=4, device=self.device) batch_tensors = _create_data_batch(16, 18, num_samples=4, device=self.device)
self._test_fn_on_batch(batch_tensors, F.rgb_to_grayscale, num_output_channels=num_output_channels) _test_fn_on_batch(batch_tensors, F.rgb_to_grayscale, num_output_channels=num_output_channels)
def test_center_crop(self): def test_center_crop(self):
script_center_crop = torch.jit.script(F.center_crop) script_center_crop = torch.jit.script(F.center_crop)
img_tensor, pil_img = self._create_data(32, 34, device=self.device) img_tensor, pil_img = _create_data(32, 34, device=self.device)
cropped_pil_image = F.center_crop(pil_img, [10, 11]) cropped_pil_image = F.center_crop(pil_img, [10, 11])
cropped_tensor = F.center_crop(img_tensor, [10, 11]) cropped_tensor = F.center_crop(img_tensor, [10, 11])
self.compareTensorToPIL(cropped_tensor, cropped_pil_image) _assert_equal_tensor_to_pil(cropped_tensor, cropped_pil_image)
cropped_tensor = script_center_crop(img_tensor, [10, 11]) cropped_tensor = script_center_crop(img_tensor, [10, 11])
self.compareTensorToPIL(cropped_tensor, cropped_pil_image) _assert_equal_tensor_to_pil(cropped_tensor, cropped_pil_image)
batch_tensors = self._create_data_batch(16, 18, num_samples=4, device=self.device) batch_tensors = _create_data_batch(16, 18, num_samples=4, device=self.device)
self._test_fn_on_batch(batch_tensors, F.center_crop, output_size=[10, 11]) _test_fn_on_batch(batch_tensors, F.center_crop, output_size=[10, 11])
def test_five_crop(self): def test_five_crop(self):
script_five_crop = torch.jit.script(F.five_crop) script_five_crop = torch.jit.script(F.five_crop)
img_tensor, pil_img = self._create_data(32, 34, device=self.device) img_tensor, pil_img = _create_data(32, 34, device=self.device)
cropped_pil_images = F.five_crop(pil_img, [10, 11]) cropped_pil_images = F.five_crop(pil_img, [10, 11])
cropped_tensors = F.five_crop(img_tensor, [10, 11]) cropped_tensors = F.five_crop(img_tensor, [10, 11])
for i in range(5): for i in range(5):
self.compareTensorToPIL(cropped_tensors[i], cropped_pil_images[i]) _assert_equal_tensor_to_pil(cropped_tensors[i], cropped_pil_images[i])
cropped_tensors = script_five_crop(img_tensor, [10, 11]) cropped_tensors = script_five_crop(img_tensor, [10, 11])
for i in range(5): for i in range(5):
self.compareTensorToPIL(cropped_tensors[i], cropped_pil_images[i]) _assert_equal_tensor_to_pil(cropped_tensors[i], cropped_pil_images[i])
batch_tensors = self._create_data_batch(16, 18, num_samples=4, device=self.device) batch_tensors = _create_data_batch(16, 18, num_samples=4, device=self.device)
tuple_transformed_batches = F.five_crop(batch_tensors, [10, 11]) tuple_transformed_batches = F.five_crop(batch_tensors, [10, 11])
for i in range(len(batch_tensors)): for i in range(len(batch_tensors)):
img_tensor = batch_tensors[i, ...] img_tensor = batch_tensors[i, ...]
...@@ -250,19 +238,19 @@ class Tester(TransformsTester): ...@@ -250,19 +238,19 @@ class Tester(TransformsTester):
def test_ten_crop(self): def test_ten_crop(self):
script_ten_crop = torch.jit.script(F.ten_crop) script_ten_crop = torch.jit.script(F.ten_crop)
img_tensor, pil_img = self._create_data(32, 34, device=self.device) img_tensor, pil_img = _create_data(32, 34, device=self.device)
cropped_pil_images = F.ten_crop(pil_img, [10, 11]) cropped_pil_images = F.ten_crop(pil_img, [10, 11])
cropped_tensors = F.ten_crop(img_tensor, [10, 11]) cropped_tensors = F.ten_crop(img_tensor, [10, 11])
for i in range(10): for i in range(10):
self.compareTensorToPIL(cropped_tensors[i], cropped_pil_images[i]) _assert_equal_tensor_to_pil(cropped_tensors[i], cropped_pil_images[i])
cropped_tensors = script_ten_crop(img_tensor, [10, 11]) cropped_tensors = script_ten_crop(img_tensor, [10, 11])
for i in range(10): for i in range(10):
self.compareTensorToPIL(cropped_tensors[i], cropped_pil_images[i]) _assert_equal_tensor_to_pil(cropped_tensors[i], cropped_pil_images[i])
batch_tensors = self._create_data_batch(16, 18, num_samples=4, device=self.device) batch_tensors = _create_data_batch(16, 18, num_samples=4, device=self.device)
tuple_transformed_batches = F.ten_crop(batch_tensors, [10, 11]) tuple_transformed_batches = F.ten_crop(batch_tensors, [10, 11])
for i in range(len(batch_tensors)): for i in range(len(batch_tensors)):
img_tensor = batch_tensors[i, ...] img_tensor = batch_tensors[i, ...]
...@@ -281,8 +269,8 @@ class Tester(TransformsTester): ...@@ -281,8 +269,8 @@ class Tester(TransformsTester):
def test_pad(self): def test_pad(self):
script_fn = torch.jit.script(F.pad) script_fn = torch.jit.script(F.pad)
tensor, pil_img = self._create_data(7, 8, device=self.device) tensor, pil_img = _create_data(7, 8, device=self.device)
batch_tensors = self._create_data_batch(16, 18, num_samples=4, device=self.device) batch_tensors = _create_data_batch(16, 18, num_samples=4, device=self.device)
for dt in [None, torch.float32, torch.float64, torch.float16]: for dt in [None, torch.float32, torch.float64, torch.float16]:
...@@ -313,7 +301,7 @@ class Tester(TransformsTester): ...@@ -313,7 +301,7 @@ class Tester(TransformsTester):
if pad_tensor_8b.dtype != torch.uint8: if pad_tensor_8b.dtype != torch.uint8:
pad_tensor_8b = pad_tensor_8b.to(torch.uint8) pad_tensor_8b = pad_tensor_8b.to(torch.uint8)
self.compareTensorToPIL(pad_tensor_8b, pad_pil_img, msg="{}, {}".format(pad, kwargs)) _assert_equal_tensor_to_pil(pad_tensor_8b, pad_pil_img, msg="{}, {}".format(pad, kwargs))
if isinstance(pad, int): if isinstance(pad, int):
script_pad = [pad, ] script_pad = [pad, ]
...@@ -322,19 +310,19 @@ class Tester(TransformsTester): ...@@ -322,19 +310,19 @@ class Tester(TransformsTester):
pad_tensor_script = script_fn(tensor, script_pad, **kwargs) pad_tensor_script = script_fn(tensor, script_pad, **kwargs)
assert_equal(pad_tensor, pad_tensor_script, msg="{}, {}".format(pad, kwargs)) assert_equal(pad_tensor, pad_tensor_script, msg="{}, {}".format(pad, kwargs))
self._test_fn_on_batch(batch_tensors, F.pad, padding=script_pad, **kwargs) _test_fn_on_batch(batch_tensors, F.pad, padding=script_pad, **kwargs)
def test_resized_crop(self): def test_resized_crop(self):
# test values of F.resized_crop in several cases: # test values of F.resized_crop in several cases:
# 1) resize to the same size, crop to the same size => should be identity # 1) resize to the same size, crop to the same size => should be identity
tensor, _ = self._create_data(26, 36, device=self.device) tensor, _ = _create_data(26, 36, device=self.device)
for mode in [NEAREST, BILINEAR, BICUBIC]: for mode in [NEAREST, BILINEAR, BICUBIC]:
out_tensor = F.resized_crop(tensor, top=0, left=0, height=26, width=36, size=[26, 36], interpolation=mode) out_tensor = F.resized_crop(tensor, top=0, left=0, height=26, width=36, size=[26, 36], interpolation=mode)
assert_equal(tensor, out_tensor, msg="{} vs {}".format(out_tensor[0, :5, :5], tensor[0, :5, :5])) assert_equal(tensor, out_tensor, msg="{} vs {}".format(out_tensor[0, :5, :5], tensor[0, :5, :5]))
# 2) resize by half and crop a TL corner # 2) resize by half and crop a TL corner
tensor, _ = self._create_data(26, 36, device=self.device) tensor, _ = _create_data(26, 36, device=self.device)
out_tensor = F.resized_crop(tensor, top=0, left=0, height=20, width=30, size=[10, 15], interpolation=NEAREST) out_tensor = F.resized_crop(tensor, top=0, left=0, height=20, width=30, size=[10, 15], interpolation=NEAREST)
expected_out_tensor = tensor[:, :20:2, :30:2] expected_out_tensor = tensor[:, :20:2, :30:2]
assert_equal( assert_equal(
...@@ -344,8 +332,8 @@ class Tester(TransformsTester): ...@@ -344,8 +332,8 @@ class Tester(TransformsTester):
msg="{} vs {}".format(expected_out_tensor[0, :10, :10], out_tensor[0, :10, :10]), msg="{} vs {}".format(expected_out_tensor[0, :10, :10], out_tensor[0, :10, :10]),
) )
batch_tensors = self._create_data_batch(26, 36, num_samples=4, device=self.device) batch_tensors = _create_data_batch(26, 36, num_samples=4, device=self.device)
self._test_fn_on_batch( _test_fn_on_batch(
batch_tensors, F.resized_crop, top=1, left=2, height=20, width=30, size=[10, 15], interpolation=NEAREST batch_tensors, F.resized_crop, top=1, left=2, height=20, width=30, size=[10, 15], interpolation=NEAREST
) )
...@@ -447,7 +435,7 @@ class Tester(TransformsTester): ...@@ -447,7 +435,7 @@ class Tester(TransformsTester):
if out_tensor.dtype != torch.uint8: if out_tensor.dtype != torch.uint8:
out_tensor = out_tensor.to(torch.uint8) out_tensor = out_tensor.to(torch.uint8)
self.compareTensorToPIL(out_tensor, out_pil_img) _assert_equal_tensor_to_pil(out_tensor, out_pil_img)
def _test_affine_all_ops(self, tensor, pil_img, scripted_affine): def _test_affine_all_ops(self, tensor, pil_img, scripted_affine):
# 4) Test rotation + translation + scale + share # 4) Test rotation + translation + scale + share
...@@ -491,7 +479,7 @@ class Tester(TransformsTester): ...@@ -491,7 +479,7 @@ class Tester(TransformsTester):
# Tests on square and rectangular images # Tests on square and rectangular images
scripted_affine = torch.jit.script(F.affine) scripted_affine = torch.jit.script(F.affine)
data = [self._create_data(26, 26, device=self.device), self._create_data(32, 26, device=self.device)] data = [_create_data(26, 26, device=self.device), _create_data(32, 26, device=self.device)]
for tensor, pil_img in data: for tensor, pil_img in data:
for dt in [None, torch.float32, torch.float64, torch.float16]: for dt in [None, torch.float32, torch.float64, torch.float16]:
...@@ -511,11 +499,11 @@ class Tester(TransformsTester): ...@@ -511,11 +499,11 @@ class Tester(TransformsTester):
self._test_affine_translations(tensor, pil_img, scripted_affine) self._test_affine_translations(tensor, pil_img, scripted_affine)
self._test_affine_all_ops(tensor, pil_img, scripted_affine) self._test_affine_all_ops(tensor, pil_img, scripted_affine)
batch_tensors = self._create_data_batch(26, 36, num_samples=4, device=self.device) batch_tensors = _create_data_batch(26, 36, num_samples=4, device=self.device)
if dt is not None: if dt is not None:
batch_tensors = batch_tensors.to(dtype=dt) batch_tensors = batch_tensors.to(dtype=dt)
self._test_fn_on_batch( _test_fn_on_batch(
batch_tensors, F.affine, angle=-43, translate=[-3, 4], scale=1.2, shear=[4.0, 5.0] batch_tensors, F.affine, angle=-43, translate=[-3, 4], scale=1.2, shear=[4.0, 5.0]
) )
...@@ -580,7 +568,7 @@ class Tester(TransformsTester): ...@@ -580,7 +568,7 @@ class Tester(TransformsTester):
# Tests on square image # Tests on square image
scripted_rotate = torch.jit.script(F.rotate) scripted_rotate = torch.jit.script(F.rotate)
data = [self._create_data(26, 26, device=self.device), self._create_data(32, 26, device=self.device)] data = [_create_data(26, 26, device=self.device), _create_data(32, 26, device=self.device)]
for tensor, pil_img in data: for tensor, pil_img in data:
img_size = pil_img.size img_size = pil_img.size
...@@ -601,12 +589,12 @@ class Tester(TransformsTester): ...@@ -601,12 +589,12 @@ class Tester(TransformsTester):
self._test_rotate_all_options(tensor, pil_img, scripted_rotate, centers) self._test_rotate_all_options(tensor, pil_img, scripted_rotate, centers)
batch_tensors = self._create_data_batch(26, 36, num_samples=4, device=self.device) batch_tensors = _create_data_batch(26, 36, num_samples=4, device=self.device)
if dt is not None: if dt is not None:
batch_tensors = batch_tensors.to(dtype=dt) batch_tensors = batch_tensors.to(dtype=dt)
center = (20, 22) center = (20, 22)
self._test_fn_on_batch( _test_fn_on_batch(
batch_tensors, F.rotate, angle=32, interpolation=NEAREST, expand=True, center=center batch_tensors, F.rotate, angle=32, interpolation=NEAREST, expand=True, center=center
) )
tensor, pil_img = data[0] tensor, pil_img = data[0]
...@@ -735,7 +723,7 @@ def _get_data_dims_and_points_for_perspective(): ...@@ -735,7 +723,7 @@ def _get_data_dims_and_points_for_perspective():
@pytest.mark.parametrize('dt', [None, torch.float32, torch.float64, torch.float16]) @pytest.mark.parametrize('dt', [None, torch.float32, torch.float64, torch.float16])
@pytest.mark.parametrize('fill', (None, [0, 0, 0], [1, 2, 3], [255, 255, 255], [1, ], (2.0, ))) @pytest.mark.parametrize('fill', (None, [0, 0, 0], [1, 2, 3], [255, 255, 255], [1, ], (2.0, )))
@pytest.mark.parametrize('fn', [F.perspective, torch.jit.script(F.perspective)]) @pytest.mark.parametrize('fn', [F.perspective, torch.jit.script(F.perspective)])
def test_perspective_pil_vs_tensor(device, dims_and_points, dt, fill, fn, tester): def test_perspective_pil_vs_tensor(device, dims_and_points, dt, fill, fn):
if dt == torch.float16 and device == "cpu": if dt == torch.float16 and device == "cpu":
# skip float16 on CPU case # skip float16 on CPU case
...@@ -743,7 +731,7 @@ def test_perspective_pil_vs_tensor(device, dims_and_points, dt, fill, fn, tester ...@@ -743,7 +731,7 @@ def test_perspective_pil_vs_tensor(device, dims_and_points, dt, fill, fn, tester
data_dims, (spoints, epoints) = dims_and_points data_dims, (spoints, epoints) = dims_and_points
tensor, pil_img = tester._create_data(*data_dims, device=device) tensor, pil_img = _create_data(*data_dims, device=device)
if dt is not None: if dt is not None:
tensor = tensor.to(dtype=dt) tensor = tensor.to(dtype=dt)
...@@ -766,7 +754,7 @@ def test_perspective_pil_vs_tensor(device, dims_and_points, dt, fill, fn, tester ...@@ -766,7 +754,7 @@ def test_perspective_pil_vs_tensor(device, dims_and_points, dt, fill, fn, tester
@pytest.mark.parametrize('device', cpu_and_gpu()) @pytest.mark.parametrize('device', cpu_and_gpu())
@pytest.mark.parametrize('dims_and_points', _get_data_dims_and_points_for_perspective()) @pytest.mark.parametrize('dims_and_points', _get_data_dims_and_points_for_perspective())
@pytest.mark.parametrize('dt', [None, torch.float32, torch.float64, torch.float16]) @pytest.mark.parametrize('dt', [None, torch.float32, torch.float64, torch.float16])
def test_perspective_batch(device, dims_and_points, dt, tester): def test_perspective_batch(device, dims_and_points, dt):
if dt == torch.float16 and device == "cpu": if dt == torch.float16 and device == "cpu":
# skip float16 on CPU case # skip float16 on CPU case
...@@ -774,28 +762,28 @@ def test_perspective_batch(device, dims_and_points, dt, tester): ...@@ -774,28 +762,28 @@ def test_perspective_batch(device, dims_and_points, dt, tester):
data_dims, (spoints, epoints) = dims_and_points data_dims, (spoints, epoints) = dims_and_points
batch_tensors = tester._create_data_batch(*data_dims, num_samples=4, device=device) batch_tensors = _create_data_batch(*data_dims, num_samples=4, device=device)
if dt is not None: if dt is not None:
batch_tensors = batch_tensors.to(dtype=dt) batch_tensors = batch_tensors.to(dtype=dt)
# Ignore the equivalence between scripted and regular function on float16 cuda. The pixels at # Ignore the equivalence between scripted and regular function on float16 cuda. The pixels at
# the border may be entirely different due to small rounding errors. # the border may be entirely different due to small rounding errors.
scripted_fn_atol = -1 if (dt == torch.float16 and device == "cuda") else 1e-8 scripted_fn_atol = -1 if (dt == torch.float16 and device == "cuda") else 1e-8
tester._test_fn_on_batch( _test_fn_on_batch(
batch_tensors, F.perspective, scripted_fn_atol=scripted_fn_atol, batch_tensors, F.perspective, scripted_fn_atol=scripted_fn_atol,
startpoints=spoints, endpoints=epoints, interpolation=NEAREST startpoints=spoints, endpoints=epoints, interpolation=NEAREST
) )
def test_perspective_interpolation_warning(tester): def test_perspective_interpolation_warning():
# assert changed type warning # assert changed type warning
spoints = [[0, 0], [33, 0], [33, 25], [0, 25]] spoints = [[0, 0], [33, 0], [33, 25], [0, 25]]
epoints = [[3, 2], [32, 3], [30, 24], [2, 25]] epoints = [[3, 2], [32, 3], [30, 24], [2, 25]]
tensor = torch.randint(0, 256, (3, 26, 26)) tensor = torch.randint(0, 256, (3, 26, 26))
with tester.assertWarnsRegex(UserWarning, r"Argument interpolation should be of type InterpolationMode"): with pytest.warns(UserWarning, match="Argument interpolation should be of type InterpolationMode"):
res1 = F.perspective(tensor, startpoints=spoints, endpoints=epoints, interpolation=2) res1 = F.perspective(tensor, startpoints=spoints, endpoints=epoints, interpolation=2)
res2 = F.perspective(tensor, startpoints=spoints, endpoints=epoints, interpolation=BILINEAR) res2 = F.perspective(tensor, startpoints=spoints, endpoints=epoints, interpolation=BILINEAR)
tester.assertTrue(res1.equal(res2)) assert_equal(res1, res2)
@pytest.mark.parametrize('device', cpu_and_gpu()) @pytest.mark.parametrize('device', cpu_and_gpu())
...@@ -803,7 +791,7 @@ def test_perspective_interpolation_warning(tester): ...@@ -803,7 +791,7 @@ def test_perspective_interpolation_warning(tester):
@pytest.mark.parametrize('size', [32, 26, [32, ], [32, 32], (32, 32), [26, 35]]) @pytest.mark.parametrize('size', [32, 26, [32, ], [32, 32], (32, 32), [26, 35]])
@pytest.mark.parametrize('max_size', [None, 34, 40, 1000]) @pytest.mark.parametrize('max_size', [None, 34, 40, 1000])
@pytest.mark.parametrize('interpolation', [BILINEAR, BICUBIC, NEAREST]) @pytest.mark.parametrize('interpolation', [BILINEAR, BICUBIC, NEAREST])
def test_resize(device, dt, size, max_size, interpolation, tester): def test_resize(device, dt, size, max_size, interpolation):
if dt == torch.float16 and device == "cpu": if dt == torch.float16 and device == "cpu":
# skip float16 on CPU case # skip float16 on CPU case
...@@ -814,8 +802,8 @@ def test_resize(device, dt, size, max_size, interpolation, tester): ...@@ -814,8 +802,8 @@ def test_resize(device, dt, size, max_size, interpolation, tester):
torch.manual_seed(12) torch.manual_seed(12)
script_fn = torch.jit.script(F.resize) script_fn = torch.jit.script(F.resize)
tensor, pil_img = tester._create_data(26, 36, device=device) tensor, pil_img = _create_data(26, 36, device=device)
batch_tensors = tester._create_data_batch(16, 18, num_samples=4, device=device) batch_tensors = _create_data_batch(16, 18, num_samples=4, device=device)
if dt is not None: if dt is not None:
# This is a trivial cast to float of uint8 data to test all cases # This is a trivial cast to float of uint8 data to test all cases
...@@ -837,7 +825,7 @@ def test_resize(device, dt, size, max_size, interpolation, tester): ...@@ -837,7 +825,7 @@ def test_resize(device, dt, size, max_size, interpolation, tester):
resized_tensor_f = resized_tensor_f.to(torch.float) resized_tensor_f = resized_tensor_f.to(torch.float)
# Pay attention to high tolerance for MAE # Pay attention to high tolerance for MAE
tester.approxEqualTensorToPIL(resized_tensor_f, resized_pil_img, tol=8.0) _assert_approx_equal_tensor_to_pil(resized_tensor_f, resized_pil_img, tol=8.0)
if isinstance(size, int): if isinstance(size, int):
script_size = [size, ] script_size = [size, ]
...@@ -849,15 +837,15 @@ def test_resize(device, dt, size, max_size, interpolation, tester): ...@@ -849,15 +837,15 @@ def test_resize(device, dt, size, max_size, interpolation, tester):
) )
assert_equal(resized_tensor, resize_result) assert_equal(resized_tensor, resize_result)
tester._test_fn_on_batch( _test_fn_on_batch(
batch_tensors, F.resize, size=script_size, interpolation=interpolation, max_size=max_size batch_tensors, F.resize, size=script_size, interpolation=interpolation, max_size=max_size
) )
@pytest.mark.parametrize('device', cpu_and_gpu()) @pytest.mark.parametrize('device', cpu_and_gpu())
def test_resize_asserts(device, tester): def test_resize_asserts(device):
tensor, pil_img = tester._create_data(26, 36, device=device) tensor, pil_img = _create_data(26, 36, device=device)
# assert changed type warning # assert changed type warning
with pytest.warns(UserWarning, match=r"Argument interpolation should be of type InterpolationMode"): with pytest.warns(UserWarning, match=r"Argument interpolation should be of type InterpolationMode"):
...@@ -878,7 +866,7 @@ def test_resize_asserts(device, tester): ...@@ -878,7 +866,7 @@ def test_resize_asserts(device, tester):
@pytest.mark.parametrize('dt', [None, torch.float32, torch.float64, torch.float16]) @pytest.mark.parametrize('dt', [None, torch.float32, torch.float64, torch.float16])
@pytest.mark.parametrize('size', [[96, 72], [96, 420], [420, 72]]) @pytest.mark.parametrize('size', [[96, 72], [96, 420], [420, 72]])
@pytest.mark.parametrize('interpolation', [BILINEAR, BICUBIC]) @pytest.mark.parametrize('interpolation', [BILINEAR, BICUBIC])
def test_resize_antialias(device, dt, size, interpolation, tester): def test_resize_antialias(device, dt, size, interpolation):
if dt == torch.float16 and device == "cpu": if dt == torch.float16 and device == "cpu":
# skip float16 on CPU case # skip float16 on CPU case
...@@ -886,7 +874,7 @@ def test_resize_antialias(device, dt, size, interpolation, tester): ...@@ -886,7 +874,7 @@ def test_resize_antialias(device, dt, size, interpolation, tester):
torch.manual_seed(12) torch.manual_seed(12)
script_fn = torch.jit.script(F.resize) script_fn = torch.jit.script(F.resize)
tensor, pil_img = tester._create_data(320, 290, device=device) tensor, pil_img = _create_data(320, 290, device=device)
if dt is not None: if dt is not None:
# This is a trivial cast to float of uint8 data to test all cases # This is a trivial cast to float of uint8 data to test all cases
...@@ -895,17 +883,14 @@ def test_resize_antialias(device, dt, size, interpolation, tester): ...@@ -895,17 +883,14 @@ def test_resize_antialias(device, dt, size, interpolation, tester):
resized_tensor = F.resize(tensor, size=size, interpolation=interpolation, antialias=True) resized_tensor = F.resize(tensor, size=size, interpolation=interpolation, antialias=True)
resized_pil_img = F.resize(pil_img, size=size, interpolation=interpolation) resized_pil_img = F.resize(pil_img, size=size, interpolation=interpolation)
tester.assertEqual( assert resized_tensor.size()[1:] == resized_pil_img.size[::-1]
resized_tensor.size()[1:], resized_pil_img.size[::-1],
msg=f"{size}, {interpolation}, {dt}"
)
resized_tensor_f = resized_tensor resized_tensor_f = resized_tensor
# we need to cast to uint8 to compare with PIL image # we need to cast to uint8 to compare with PIL image
if resized_tensor_f.dtype == torch.uint8: if resized_tensor_f.dtype == torch.uint8:
resized_tensor_f = resized_tensor_f.to(torch.float) resized_tensor_f = resized_tensor_f.to(torch.float)
tester.approxEqualTensorToPIL( _assert_approx_equal_tensor_to_pil(
resized_tensor_f, resized_pil_img, tol=0.5, msg=f"{size}, {interpolation}, {dt}" resized_tensor_f, resized_pil_img, tol=0.5, msg=f"{size}, {interpolation}, {dt}"
) )
...@@ -917,7 +902,7 @@ def test_resize_antialias(device, dt, size, interpolation, tester): ...@@ -917,7 +902,7 @@ def test_resize_antialias(device, dt, size, interpolation, tester):
# match PIL implementation. # match PIL implementation.
accepted_tol = 15.0 accepted_tol = 15.0
tester.approxEqualTensorToPIL( _assert_approx_equal_tensor_to_pil(
resized_tensor_f, resized_pil_img, tol=accepted_tol, agg_method="max", resized_tensor_f, resized_pil_img, tol=accepted_tol, agg_method="max",
msg=f"{size}, {interpolation}, {dt}" msg=f"{size}, {interpolation}, {dt}"
) )
...@@ -928,17 +913,17 @@ def test_resize_antialias(device, dt, size, interpolation, tester): ...@@ -928,17 +913,17 @@ def test_resize_antialias(device, dt, size, interpolation, tester):
script_size = size script_size = size
resize_result = script_fn(tensor, size=script_size, interpolation=interpolation, antialias=True) resize_result = script_fn(tensor, size=script_size, interpolation=interpolation, antialias=True)
tester.assertTrue(resized_tensor.equal(resize_result), msg=f"{size}, {interpolation}, {dt}") assert_equal(resized_tensor, resize_result)
@needs_cuda @needs_cuda
@pytest.mark.parametrize('interpolation', [BILINEAR, BICUBIC]) @pytest.mark.parametrize('interpolation', [BILINEAR, BICUBIC])
def test_assert_resize_antialias(interpolation, tester): def test_assert_resize_antialias(interpolation):
# Checks implementation on very large scales # Checks implementation on very large scales
# and catch TORCH_CHECK inside interpolate_aa_kernels.cu # and catch TORCH_CHECK inside interpolate_aa_kernels.cu
torch.manual_seed(12) torch.manual_seed(12)
tensor, pil_img = tester._create_data(1000, 1000, device="cuda") tensor, pil_img = _create_data(1000, 1000, device="cuda")
with pytest.raises(RuntimeError, match=r"Max supported scale factor is"): with pytest.raises(RuntimeError, match=r"Max supported scale factor is"):
F.resize(tensor, size=(5, 5), interpolation=interpolation, antialias=True) F.resize(tensor, size=(5, 5), interpolation=interpolation, antialias=True)
...@@ -946,12 +931,10 @@ def test_assert_resize_antialias(interpolation, tester): ...@@ -946,12 +931,10 @@ def test_assert_resize_antialias(interpolation, tester):
def check_functional_vs_PIL_vs_scripted(fn, fn_pil, fn_t, config, device, dtype, tol=2.0 + 1e-10, agg_method="max"): def check_functional_vs_PIL_vs_scripted(fn, fn_pil, fn_t, config, device, dtype, tol=2.0 + 1e-10, agg_method="max"):
tester = Tester()
script_fn = torch.jit.script(fn) script_fn = torch.jit.script(fn)
torch.manual_seed(15) torch.manual_seed(15)
tensor, pil_img = tester._create_data(26, 34, device=device) tensor, pil_img = _create_data(26, 34, device=device)
batch_tensors = tester._create_data_batch(16, 18, num_samples=4, device=device) batch_tensors = _create_data_batch(16, 18, num_samples=4, device=device)
if dtype is not None: if dtype is not None:
tensor = F.convert_image_dtype(tensor, dtype) tensor = F.convert_image_dtype(tensor, dtype)
...@@ -970,7 +953,7 @@ def check_functional_vs_PIL_vs_scripted(fn, fn_pil, fn_t, config, device, dtype, ...@@ -970,7 +953,7 @@ def check_functional_vs_PIL_vs_scripted(fn, fn_pil, fn_t, config, device, dtype,
# Check that max difference does not exceed 2 in [0, 255] range # Check that max difference does not exceed 2 in [0, 255] range
# Exact matching is not possible due to incompatibility convert_image_dtype and PIL results # Exact matching is not possible due to incompatibility convert_image_dtype and PIL results
tester.approxEqualTensorToPIL(rbg_tensor.float(), out_pil, tol=tol, agg_method=agg_method) _assert_approx_equal_tensor_to_pil(rbg_tensor.float(), out_pil, tol=tol, agg_method=agg_method)
atol = 1e-6 atol = 1e-6
if out_fn_t.dtype == torch.uint8 and "cuda" in torch.device(device).type: if out_fn_t.dtype == torch.uint8 and "cuda" in torch.device(device).type:
...@@ -978,7 +961,7 @@ def check_functional_vs_PIL_vs_scripted(fn, fn_pil, fn_t, config, device, dtype, ...@@ -978,7 +961,7 @@ def check_functional_vs_PIL_vs_scripted(fn, fn_pil, fn_t, config, device, dtype,
assert out_fn_t.allclose(out_scripted, atol=atol) assert out_fn_t.allclose(out_scripted, atol=atol)
# FIXME: fn will be scripted again in _test_fn_on_batch. We could avoid that. # FIXME: fn will be scripted again in _test_fn_on_batch. We could avoid that.
tester._test_fn_on_batch(batch_tensors, fn, scripted_fn_atol=atol, **config) _test_fn_on_batch(batch_tensors, fn, scripted_fn_atol=atol, **config)
@pytest.mark.parametrize('device', cpu_and_gpu()) @pytest.mark.parametrize('device', cpu_and_gpu())
......
...@@ -9,14 +9,22 @@ import numpy as np ...@@ -9,14 +9,22 @@ import numpy as np
import unittest import unittest
from typing import Sequence from typing import Sequence
from common_utils import TransformsTester, get_tmp_dir, int_dtypes, float_dtypes from common_utils import (
get_tmp_dir,
int_dtypes,
float_dtypes,
_create_data,
_create_data_batch,
_assert_equal_tensor_to_pil,
_assert_approx_equal_tensor_to_pil,
)
from _assert_utils import assert_equal from _assert_utils import assert_equal
NEAREST, BILINEAR, BICUBIC = InterpolationMode.NEAREST, InterpolationMode.BILINEAR, InterpolationMode.BICUBIC NEAREST, BILINEAR, BICUBIC = InterpolationMode.NEAREST, InterpolationMode.BILINEAR, InterpolationMode.BICUBIC
class Tester(TransformsTester): class Tester(unittest.TestCase):
def setUp(self): def setUp(self):
self.device = "cpu" self.device = "cpu"
...@@ -26,13 +34,13 @@ class Tester(TransformsTester): ...@@ -26,13 +34,13 @@ class Tester(TransformsTester):
fn_kwargs = {} fn_kwargs = {}
f = getattr(F, func) f = getattr(F, func)
tensor, pil_img = self._create_data(height=10, width=10, device=self.device) tensor, pil_img = _create_data(height=10, width=10, device=self.device)
transformed_tensor = f(tensor, **fn_kwargs) transformed_tensor = f(tensor, **fn_kwargs)
transformed_pil_img = f(pil_img, **fn_kwargs) transformed_pil_img = f(pil_img, **fn_kwargs)
if test_exact_match: if test_exact_match:
self.compareTensorToPIL(transformed_tensor, transformed_pil_img, **match_kwargs) _assert_equal_tensor_to_pil(transformed_tensor, transformed_pil_img, **match_kwargs)
else: else:
self.approxEqualTensorToPIL(transformed_tensor, transformed_pil_img, **match_kwargs) _assert_approx_equal_tensor_to_pil(transformed_tensor, transformed_pil_img, **match_kwargs)
def _test_transform_vs_scripted(self, transform, s_transform, tensor, msg=None): def _test_transform_vs_scripted(self, transform, s_transform, tensor, msg=None):
torch.manual_seed(12) torch.manual_seed(12)
...@@ -63,22 +71,22 @@ class Tester(TransformsTester): ...@@ -63,22 +71,22 @@ class Tester(TransformsTester):
f = getattr(T, method)(**meth_kwargs) f = getattr(T, method)(**meth_kwargs)
scripted_fn = torch.jit.script(f) scripted_fn = torch.jit.script(f)
tensor, pil_img = self._create_data(26, 34, device=self.device) tensor, pil_img = _create_data(26, 34, device=self.device)
# set seed to reproduce the same transformation for tensor and PIL image # set seed to reproduce the same transformation for tensor and PIL image
torch.manual_seed(12) torch.manual_seed(12)
transformed_tensor = f(tensor) transformed_tensor = f(tensor)
torch.manual_seed(12) torch.manual_seed(12)
transformed_pil_img = f(pil_img) transformed_pil_img = f(pil_img)
if test_exact_match: if test_exact_match:
self.compareTensorToPIL(transformed_tensor, transformed_pil_img, **match_kwargs) _assert_equal_tensor_to_pil(transformed_tensor, transformed_pil_img, **match_kwargs)
else: else:
self.approxEqualTensorToPIL(transformed_tensor.float(), transformed_pil_img, **match_kwargs) _assert_approx_equal_tensor_to_pil(transformed_tensor.float(), transformed_pil_img, **match_kwargs)
torch.manual_seed(12) torch.manual_seed(12)
transformed_tensor_script = scripted_fn(tensor) transformed_tensor_script = scripted_fn(tensor)
assert_equal(transformed_tensor, transformed_tensor_script) assert_equal(transformed_tensor, transformed_tensor_script)
batch_tensors = self._create_data_batch(height=23, width=34, channels=3, num_samples=4, device=self.device) batch_tensors = _create_data_batch(height=23, width=34, channels=3, num_samples=4, device=self.device)
self._test_transform_vs_scripted_on_batch(f, scripted_fn, batch_tensors) self._test_transform_vs_scripted_on_batch(f, scripted_fn, batch_tensors)
with get_tmp_dir() as tmp_dir: with get_tmp_dir() as tmp_dir:
...@@ -259,13 +267,13 @@ class Tester(TransformsTester): ...@@ -259,13 +267,13 @@ class Tester(TransformsTester):
fn = getattr(F, func) fn = getattr(F, func)
scripted_fn = torch.jit.script(fn) scripted_fn = torch.jit.script(fn)
tensor, pil_img = self._create_data(height=20, width=20, device=self.device) tensor, pil_img = _create_data(height=20, width=20, device=self.device)
transformed_t_list = fn(tensor, **fn_kwargs) transformed_t_list = fn(tensor, **fn_kwargs)
transformed_p_list = fn(pil_img, **fn_kwargs) transformed_p_list = fn(pil_img, **fn_kwargs)
self.assertEqual(len(transformed_t_list), len(transformed_p_list)) self.assertEqual(len(transformed_t_list), len(transformed_p_list))
self.assertEqual(len(transformed_t_list), out_length) self.assertEqual(len(transformed_t_list), out_length)
for transformed_tensor, transformed_pil_img in zip(transformed_t_list, transformed_p_list): for transformed_tensor, transformed_pil_img in zip(transformed_t_list, transformed_p_list):
self.compareTensorToPIL(transformed_tensor, transformed_pil_img) _assert_equal_tensor_to_pil(transformed_tensor, transformed_pil_img)
transformed_t_list_script = scripted_fn(tensor.detach().clone(), **fn_kwargs) transformed_t_list_script = scripted_fn(tensor.detach().clone(), **fn_kwargs)
self.assertEqual(len(transformed_t_list), len(transformed_t_list_script)) self.assertEqual(len(transformed_t_list), len(transformed_t_list_script))
...@@ -284,7 +292,7 @@ class Tester(TransformsTester): ...@@ -284,7 +292,7 @@ class Tester(TransformsTester):
self.assertEqual(len(output), len(transformed_t_list_script)) self.assertEqual(len(output), len(transformed_t_list_script))
# test on batch of tensors # test on batch of tensors
batch_tensors = self._create_data_batch(height=23, width=34, channels=3, num_samples=4, device=self.device) batch_tensors = _create_data_batch(height=23, width=34, channels=3, num_samples=4, device=self.device)
torch.manual_seed(12) torch.manual_seed(12)
transformed_batch_list = fn(batch_tensors) transformed_batch_list = fn(batch_tensors)
...@@ -350,7 +358,7 @@ class Tester(TransformsTester): ...@@ -350,7 +358,7 @@ class Tester(TransformsTester):
self.assertEqual(y.shape[1], 38) self.assertEqual(y.shape[1], 38)
self.assertEqual(y.shape[2], int(38 * 46 / 32)) self.assertEqual(y.shape[2], int(38 * 46 / 32))
tensor, _ = self._create_data(height=34, width=36, device=self.device) tensor, _ = _create_data(height=34, width=36, device=self.device)
batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=self.device) batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=self.device)
for dt in [None, torch.float32, torch.float64]: for dt in [None, torch.float32, torch.float64]:
...@@ -487,7 +495,7 @@ class Tester(TransformsTester): ...@@ -487,7 +495,7 @@ class Tester(TransformsTester):
def test_normalize(self): def test_normalize(self):
fn = T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) fn = T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
tensor, _ = self._create_data(26, 34, device=self.device) tensor, _ = _create_data(26, 34, device=self.device)
with self.assertRaisesRegex(TypeError, r"Input tensor should be a float tensor"): with self.assertRaisesRegex(TypeError, r"Input tensor should be a float tensor"):
fn(tensor) fn(tensor)
...@@ -506,7 +514,7 @@ class Tester(TransformsTester): ...@@ -506,7 +514,7 @@ class Tester(TransformsTester):
def test_linear_transformation(self): def test_linear_transformation(self):
c, h, w = 3, 24, 32 c, h, w = 3, 24, 32
tensor, _ = self._create_data(h, w, channels=c, device=self.device) tensor, _ = _create_data(h, w, channels=c, device=self.device)
matrix = torch.rand(c * h * w, c * h * w, device=self.device) matrix = torch.rand(c * h * w, c * h * w, device=self.device)
mean_vector = torch.rand(c * h * w, device=self.device) mean_vector = torch.rand(c * h * w, device=self.device)
...@@ -529,7 +537,7 @@ class Tester(TransformsTester): ...@@ -529,7 +537,7 @@ class Tester(TransformsTester):
scripted_fn.save(os.path.join(tmp_dir, "t_norm.pt")) scripted_fn.save(os.path.join(tmp_dir, "t_norm.pt"))
def test_compose(self): def test_compose(self):
tensor, _ = self._create_data(26, 34, device=self.device) tensor, _ = _create_data(26, 34, device=self.device)
tensor = tensor.to(dtype=torch.float32) / 255.0 tensor = tensor.to(dtype=torch.float32) / 255.0
transforms = T.Compose([ transforms = T.Compose([
...@@ -552,7 +560,7 @@ class Tester(TransformsTester): ...@@ -552,7 +560,7 @@ class Tester(TransformsTester):
torch.jit.script(t) torch.jit.script(t)
def test_random_apply(self): def test_random_apply(self):
tensor, _ = self._create_data(26, 34, device=self.device) tensor, _ = _create_data(26, 34, device=self.device)
tensor = tensor.to(dtype=torch.float32) / 255.0 tensor = tensor.to(dtype=torch.float32) / 255.0
transforms = T.RandomApply([ transforms = T.RandomApply([
...@@ -620,7 +628,7 @@ class Tester(TransformsTester): ...@@ -620,7 +628,7 @@ class Tester(TransformsTester):
with self.assertRaises(ValueError, msg="If value is a sequence, it should have either a single value or 3"): with self.assertRaises(ValueError, msg="If value is a sequence, it should have either a single value or 3"):
random_erasing(img) random_erasing(img)
tensor, _ = self._create_data(24, 32, channels=3, device=self.device) tensor, _ = _create_data(24, 32, channels=3, device=self.device)
batch_tensors = torch.rand(4, 3, 44, 56, device=self.device) batch_tensors = torch.rand(4, 3, 44, 56, device=self.device)
test_configs = [ test_configs = [
...@@ -640,7 +648,7 @@ class Tester(TransformsTester): ...@@ -640,7 +648,7 @@ class Tester(TransformsTester):
scripted_fn.save(os.path.join(tmp_dir, "t_random_erasing.pt")) scripted_fn.save(os.path.join(tmp_dir, "t_random_erasing.pt"))
def test_convert_image_dtype(self): def test_convert_image_dtype(self):
tensor, _ = self._create_data(26, 34, device=self.device) tensor, _ = _create_data(26, 34, device=self.device)
batch_tensors = torch.rand(4, 3, 44, 56, device=self.device) batch_tensors = torch.rand(4, 3, 44, 56, device=self.device)
for in_dtype in int_dtypes() + float_dtypes(): for in_dtype in int_dtypes() + float_dtypes():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment