Unverified Commit 964ce1e9 authored by Nicolas Hug's avatar Nicolas Hug Committed by GitHub
Browse files

Prepare test_transforms_tensor.py for porting to pytest (#3976)

parent 44fefe60
...@@ -24,32 +24,15 @@ from _assert_utils import assert_equal ...@@ -24,32 +24,15 @@ from _assert_utils import assert_equal
NEAREST, BILINEAR, BICUBIC = InterpolationMode.NEAREST, InterpolationMode.BILINEAR, InterpolationMode.BICUBIC NEAREST, BILINEAR, BICUBIC = InterpolationMode.NEAREST, InterpolationMode.BILINEAR, InterpolationMode.BICUBIC
class Tester(unittest.TestCase): def _test_transform_vs_scripted(transform, s_transform, tensor, msg=None):
def setUp(self):
self.device = "cpu"
def _test_functional_op(self, func, fn_kwargs, test_exact_match=True, **match_kwargs):
if fn_kwargs is None:
fn_kwargs = {}
f = getattr(F, func)
tensor, pil_img = _create_data(height=10, width=10, device=self.device)
transformed_tensor = f(tensor, **fn_kwargs)
transformed_pil_img = f(pil_img, **fn_kwargs)
if test_exact_match:
_assert_equal_tensor_to_pil(transformed_tensor, transformed_pil_img, **match_kwargs)
else:
_assert_approx_equal_tensor_to_pil(transformed_tensor, transformed_pil_img, **match_kwargs)
def _test_transform_vs_scripted(self, transform, s_transform, tensor, msg=None):
torch.manual_seed(12) torch.manual_seed(12)
out1 = transform(tensor) out1 = transform(tensor)
torch.manual_seed(12) torch.manual_seed(12)
out2 = s_transform(tensor) out2 = s_transform(tensor)
assert_equal(out1, out2, msg=msg) assert_equal(out1, out2, msg=msg)
def _test_transform_vs_scripted_on_batch(self, transform, s_transform, batch_tensors, msg=None):
def _test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors, msg=None):
torch.manual_seed(12) torch.manual_seed(12)
transformed_batch = transform(batch_tensors) transformed_batch = transform(batch_tensors)
...@@ -63,15 +46,28 @@ class Tester(unittest.TestCase): ...@@ -63,15 +46,28 @@ class Tester(unittest.TestCase):
s_transformed_batch = s_transform(batch_tensors) s_transformed_batch = s_transform(batch_tensors)
assert_equal(transformed_batch, s_transformed_batch, msg=msg) assert_equal(transformed_batch, s_transformed_batch, msg=msg)
def _test_class_op(self, method, meth_kwargs=None, test_exact_match=True, **match_kwargs):
if meth_kwargs is None: def _test_functional_op(f, device, fn_kwargs=None, test_exact_match=True, **match_kwargs):
meth_kwargs = {} fn_kwargs = fn_kwargs or {}
tensor, pil_img = _create_data(height=10, width=10, device=device)
transformed_tensor = f(tensor, **fn_kwargs)
transformed_pil_img = f(pil_img, **fn_kwargs)
if test_exact_match:
_assert_equal_tensor_to_pil(transformed_tensor, transformed_pil_img, **match_kwargs)
else:
_assert_approx_equal_tensor_to_pil(transformed_tensor, transformed_pil_img, **match_kwargs)
def _test_class_op(method, device, meth_kwargs=None, test_exact_match=True, **match_kwargs):
# TODO: change the name: it's not a method, it's a class.
meth_kwargs = meth_kwargs or {}
# test for class interface # test for class interface
f = getattr(T, method)(**meth_kwargs) f = method(**meth_kwargs)
scripted_fn = torch.jit.script(f) scripted_fn = torch.jit.script(f)
tensor, pil_img = _create_data(26, 34, device=self.device) tensor, pil_img = _create_data(26, 34, device=device)
# set seed to reproduce the same transformation for tensor and PIL image # set seed to reproduce the same transformation for tensor and PIL image
torch.manual_seed(12) torch.manual_seed(12)
transformed_tensor = f(tensor) transformed_tensor = f(tensor)
...@@ -86,83 +82,100 @@ class Tester(unittest.TestCase): ...@@ -86,83 +82,100 @@ class Tester(unittest.TestCase):
transformed_tensor_script = scripted_fn(tensor) transformed_tensor_script = scripted_fn(tensor)
assert_equal(transformed_tensor, transformed_tensor_script) assert_equal(transformed_tensor, transformed_tensor_script)
batch_tensors = _create_data_batch(height=23, width=34, channels=3, num_samples=4, device=self.device) batch_tensors = _create_data_batch(height=23, width=34, channels=3, num_samples=4, device=device)
self._test_transform_vs_scripted_on_batch(f, scripted_fn, batch_tensors) _test_transform_vs_scripted_on_batch(f, scripted_fn, batch_tensors)
with get_tmp_dir() as tmp_dir: with get_tmp_dir() as tmp_dir:
scripted_fn.save(os.path.join(tmp_dir, "t_{}.pt".format(method))) scripted_fn.save(os.path.join(tmp_dir, f"t_{method.__name__}.pt"))
def _test_op(func, method, device, fn_kwargs=None, meth_kwargs=None, test_exact_match=True, **match_kwargs):
_test_functional_op(func, device, fn_kwargs, test_exact_match=test_exact_match, **match_kwargs)
_test_class_op(method, device, meth_kwargs, test_exact_match=test_exact_match, **match_kwargs)
def _test_op(self, func, method, fn_kwargs=None, meth_kwargs=None, test_exact_match=True, **match_kwargs):
self._test_functional_op(func, fn_kwargs, test_exact_match=test_exact_match, **match_kwargs) class Tester(unittest.TestCase):
self._test_class_op(method, meth_kwargs, test_exact_match=test_exact_match, **match_kwargs)
def setUp(self):
self.device = "cpu"
def test_random_horizontal_flip(self): def test_random_horizontal_flip(self):
self._test_op('hflip', 'RandomHorizontalFlip') _test_op(F.hflip, T.RandomHorizontalFlip, device=self.device)
def test_random_vertical_flip(self): def test_random_vertical_flip(self):
self._test_op('vflip', 'RandomVerticalFlip') _test_op(F.vflip, T.RandomVerticalFlip, device=self.device)
def test_random_invert(self): def test_random_invert(self):
self._test_op('invert', 'RandomInvert') _test_op(F.invert, T.RandomInvert, device=self.device)
def test_random_posterize(self): def test_random_posterize(self):
fn_kwargs = meth_kwargs = {"bits": 4} fn_kwargs = meth_kwargs = {"bits": 4}
self._test_op( _test_op(
'posterize', 'RandomPosterize', fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs F.posterize, T.RandomPosterize, device=self.device, fn_kwargs=fn_kwargs,
meth_kwargs=meth_kwargs
) )
def test_random_solarize(self): def test_random_solarize(self):
fn_kwargs = meth_kwargs = {"threshold": 192.0} fn_kwargs = meth_kwargs = {"threshold": 192.0}
self._test_op( _test_op(
'solarize', 'RandomSolarize', fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs F.solarize, T.RandomSolarize, device=self.device, fn_kwargs=fn_kwargs,
meth_kwargs=meth_kwargs
) )
def test_random_adjust_sharpness(self): def test_random_adjust_sharpness(self):
fn_kwargs = meth_kwargs = {"sharpness_factor": 2.0} fn_kwargs = meth_kwargs = {"sharpness_factor": 2.0}
self._test_op( _test_op(
'adjust_sharpness', 'RandomAdjustSharpness', fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs F.adjust_sharpness, T.RandomAdjustSharpness, device=self.device, fn_kwargs=fn_kwargs,
meth_kwargs=meth_kwargs
) )
def test_random_autocontrast(self): def test_random_autocontrast(self):
# We check the max abs difference because on some (very rare) pixels, the actual value may be different # We check the max abs difference because on some (very rare) pixels, the actual value may be different
# between PIL and tensors due to floating approximations. # between PIL and tensors due to floating approximations.
self._test_op('autocontrast', 'RandomAutocontrast', test_exact_match=False, agg_method='max', _test_op(
tol=(1 + 1e-5), allowed_percentage_diff=.05) F.autocontrast, T.RandomAutocontrast, device=self.device, test_exact_match=False,
agg_method='max', tol=(1 + 1e-5), allowed_percentage_diff=.05
)
def test_random_equalize(self): def test_random_equalize(self):
self._test_op('equalize', 'RandomEqualize') _test_op(F.equalize, T.RandomEqualize, device=self.device)
def test_color_jitter(self): def test_color_jitter(self):
tol = 1.0 + 1e-10 tol = 1.0 + 1e-10
for f in [0.1, 0.5, 1.0, 1.34, (0.3, 0.7), [0.4, 0.5]]: for f in [0.1, 0.5, 1.0, 1.34, (0.3, 0.7), [0.4, 0.5]]:
meth_kwargs = {"brightness": f} meth_kwargs = {"brightness": f}
self._test_class_op( _test_class_op(
"ColorJitter", meth_kwargs=meth_kwargs, test_exact_match=False, tol=tol, agg_method="max" T.ColorJitter, meth_kwargs=meth_kwargs, test_exact_match=False, device=self.device,
tol=tol, agg_method="max"
) )
for f in [0.2, 0.5, 1.0, 1.5, (0.3, 0.7), [0.4, 0.5]]: for f in [0.2, 0.5, 1.0, 1.5, (0.3, 0.7), [0.4, 0.5]]:
meth_kwargs = {"contrast": f} meth_kwargs = {"contrast": f}
self._test_class_op( _test_class_op(
"ColorJitter", meth_kwargs=meth_kwargs, test_exact_match=False, tol=tol, agg_method="max" T.ColorJitter, meth_kwargs=meth_kwargs, test_exact_match=False, device=self.device,
tol=tol, agg_method="max"
) )
for f in [0.5, 0.75, 1.0, 1.25, (0.3, 0.7), [0.3, 0.4]]: for f in [0.5, 0.75, 1.0, 1.25, (0.3, 0.7), [0.3, 0.4]]:
meth_kwargs = {"saturation": f} meth_kwargs = {"saturation": f}
self._test_class_op( _test_class_op(
"ColorJitter", meth_kwargs=meth_kwargs, test_exact_match=False, tol=tol, agg_method="max" T.ColorJitter, meth_kwargs=meth_kwargs, test_exact_match=False, device=self.device,
tol=tol, agg_method="max"
) )
for f in [0.2, 0.5, (-0.2, 0.3), [-0.4, 0.5]]: for f in [0.2, 0.5, (-0.2, 0.3), [-0.4, 0.5]]:
meth_kwargs = {"hue": f} meth_kwargs = {"hue": f}
self._test_class_op( _test_class_op(
"ColorJitter", meth_kwargs=meth_kwargs, test_exact_match=False, tol=16.1, agg_method="max" T.ColorJitter, meth_kwargs=meth_kwargs, test_exact_match=False, device=self.device,
tol=16.1, agg_method="max"
) )
# All 4 parameters together # All 4 parameters together
meth_kwargs = {"brightness": 0.2, "contrast": 0.2, "saturation": 0.2, "hue": 0.2} meth_kwargs = {"brightness": 0.2, "contrast": 0.2, "saturation": 0.2, "hue": 0.2}
self._test_class_op( _test_class_op(
"ColorJitter", meth_kwargs=meth_kwargs, test_exact_match=False, tol=12.1, agg_method="max" T.ColorJitter, meth_kwargs=meth_kwargs, test_exact_match=False, device=self.device,
tol=12.1, agg_method="max"
) )
def test_pad(self): def test_pad(self):
...@@ -170,48 +183,49 @@ class Tester(unittest.TestCase): ...@@ -170,48 +183,49 @@ class Tester(unittest.TestCase):
fill = 127 if m == "constant" else 0 fill = 127 if m == "constant" else 0
for mul in [1, -1]: for mul in [1, -1]:
# Test functional.pad (PIL and Tensor) with padding as single int # Test functional.pad (PIL and Tensor) with padding as single int
self._test_functional_op( _test_functional_op(
"pad", fn_kwargs={"padding": mul * 2, "fill": fill, "padding_mode": m} F.pad, fn_kwargs={"padding": mul * 2, "fill": fill, "padding_mode": m},
device=self.device
) )
# Test functional.pad and transforms.Pad with padding as [int, ] # Test functional.pad and transforms.Pad with padding as [int, ]
fn_kwargs = meth_kwargs = {"padding": [mul * 2, ], "fill": fill, "padding_mode": m} fn_kwargs = meth_kwargs = {"padding": [mul * 2, ], "fill": fill, "padding_mode": m}
self._test_op( _test_op(
"pad", "Pad", fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs F.pad, T.Pad, device=self.device, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
) )
# Test functional.pad and transforms.Pad with padding as list # Test functional.pad and transforms.Pad with padding as list
fn_kwargs = meth_kwargs = {"padding": [mul * 4, 4], "fill": fill, "padding_mode": m} fn_kwargs = meth_kwargs = {"padding": [mul * 4, 4], "fill": fill, "padding_mode": m}
self._test_op( _test_op(
"pad", "Pad", fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs F.pad, T.Pad, device=self.device, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
) )
# Test functional.pad and transforms.Pad with padding as tuple # Test functional.pad and transforms.Pad with padding as tuple
fn_kwargs = meth_kwargs = {"padding": (mul * 2, 2, 2, mul * 2), "fill": fill, "padding_mode": m} fn_kwargs = meth_kwargs = {"padding": (mul * 2, 2, 2, mul * 2), "fill": fill, "padding_mode": m}
self._test_op( _test_op(
"pad", "Pad", fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs F.pad, T.Pad, device=self.device, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
) )
def test_crop(self): def test_crop(self):
fn_kwargs = {"top": 2, "left": 3, "height": 4, "width": 5} fn_kwargs = {"top": 2, "left": 3, "height": 4, "width": 5}
# Test transforms.RandomCrop with size and padding as tuple # Test transforms.RandomCrop with size and padding as tuple
meth_kwargs = {"size": (4, 5), "padding": (4, 4), "pad_if_needed": True, } meth_kwargs = {"size": (4, 5), "padding": (4, 4), "pad_if_needed": True, }
self._test_op( _test_op(
'crop', 'RandomCrop', fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs F.crop, T.RandomCrop, device=self.device, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
) )
# Test transforms.functional.crop including outside the image area # Test transforms.functional.crop including outside the image area
fn_kwargs = {"top": -2, "left": 3, "height": 4, "width": 5} # top fn_kwargs = {"top": -2, "left": 3, "height": 4, "width": 5} # top
self._test_functional_op('crop', fn_kwargs=fn_kwargs) _test_functional_op(F.crop, fn_kwargs=fn_kwargs, device=self.device)
fn_kwargs = {"top": 1, "left": -3, "height": 4, "width": 5} # left fn_kwargs = {"top": 1, "left": -3, "height": 4, "width": 5} # left
self._test_functional_op('crop', fn_kwargs=fn_kwargs) _test_functional_op(F.crop, fn_kwargs=fn_kwargs, device=self.device)
fn_kwargs = {"top": 7, "left": 3, "height": 4, "width": 5} # bottom fn_kwargs = {"top": 7, "left": 3, "height": 4, "width": 5} # bottom
self._test_functional_op('crop', fn_kwargs=fn_kwargs) _test_functional_op(F.crop, fn_kwargs=fn_kwargs, device=self.device)
fn_kwargs = {"top": 3, "left": 8, "height": 4, "width": 5} # right fn_kwargs = {"top": 3, "left": 8, "height": 4, "width": 5} # right
self._test_functional_op('crop', fn_kwargs=fn_kwargs) _test_functional_op(F.crop, fn_kwargs=fn_kwargs, device=self.device)
fn_kwargs = {"top": -3, "left": -3, "height": 15, "width": 15} # all fn_kwargs = {"top": -3, "left": -3, "height": 15, "width": 15} # all
self._test_functional_op('crop', fn_kwargs=fn_kwargs) _test_functional_op(F.crop, fn_kwargs=fn_kwargs, device=self.device)
sizes = [5, [5, ], [6, 6]] sizes = [5, [5, ], [6, 6]]
padding_configs = [ padding_configs = [
...@@ -226,18 +240,20 @@ class Tester(unittest.TestCase): ...@@ -226,18 +240,20 @@ class Tester(unittest.TestCase):
for padding_config in padding_configs: for padding_config in padding_configs:
config = dict(padding_config) config = dict(padding_config)
config["size"] = size config["size"] = size
self._test_class_op("RandomCrop", config) _test_class_op(T.RandomCrop, self.device, config)
def test_center_crop(self): def test_center_crop(self):
fn_kwargs = {"output_size": (4, 5)} fn_kwargs = {"output_size": (4, 5)}
meth_kwargs = {"size": (4, 5), } meth_kwargs = {"size": (4, 5), }
self._test_op( _test_op(
"center_crop", "CenterCrop", fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs F.center_crop, T.CenterCrop, device=self.device, fn_kwargs=fn_kwargs,
meth_kwargs=meth_kwargs
) )
fn_kwargs = {"output_size": (5,)} fn_kwargs = {"output_size": (5,)}
meth_kwargs = {"size": (5, )} meth_kwargs = {"size": (5, )}
self._test_op( _test_op(
"center_crop", "CenterCrop", fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs F.center_crop, T.CenterCrop, device=self.device, fn_kwargs=fn_kwargs,
meth_kwargs=meth_kwargs
) )
tensor = torch.randint(0, 256, (3, 10, 10), dtype=torch.uint8, device=self.device) tensor = torch.randint(0, 256, (3, 10, 10), dtype=torch.uint8, device=self.device)
# Test torchscript of transforms.CenterCrop with size as int # Test torchscript of transforms.CenterCrop with size as int
...@@ -378,8 +394,8 @@ class Tester(unittest.TestCase): ...@@ -378,8 +394,8 @@ class Tester(unittest.TestCase):
transform = T.Resize(size=script_size, interpolation=interpolation, max_size=max_size) transform = T.Resize(size=script_size, interpolation=interpolation, max_size=max_size)
s_transform = torch.jit.script(transform) s_transform = torch.jit.script(transform)
self._test_transform_vs_scripted(transform, s_transform, tensor) _test_transform_vs_scripted(transform, s_transform, tensor)
self._test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors) _test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
with get_tmp_dir() as tmp_dir: with get_tmp_dir() as tmp_dir:
s_transform.save(os.path.join(tmp_dir, "t_resize.pt")) s_transform.save(os.path.join(tmp_dir, "t_resize.pt"))
...@@ -396,8 +412,8 @@ class Tester(unittest.TestCase): ...@@ -396,8 +412,8 @@ class Tester(unittest.TestCase):
size=size, scale=scale, ratio=ratio, interpolation=interpolation size=size, scale=scale, ratio=ratio, interpolation=interpolation
) )
s_transform = torch.jit.script(transform) s_transform = torch.jit.script(transform)
self._test_transform_vs_scripted(transform, s_transform, tensor) _test_transform_vs_scripted(transform, s_transform, tensor)
self._test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors) _test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
with get_tmp_dir() as tmp_dir: with get_tmp_dir() as tmp_dir:
s_transform.save(os.path.join(tmp_dir, "t_resized_crop.pt")) s_transform.save(os.path.join(tmp_dir, "t_resized_crop.pt"))
...@@ -410,8 +426,8 @@ class Tester(unittest.TestCase): ...@@ -410,8 +426,8 @@ class Tester(unittest.TestCase):
transform = T.RandomAffine(**kwargs) transform = T.RandomAffine(**kwargs)
s_transform = torch.jit.script(transform) s_transform = torch.jit.script(transform)
self._test_transform_vs_scripted(transform, s_transform, tensor) _test_transform_vs_scripted(transform, s_transform, tensor)
self._test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors) _test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
return s_transform return s_transform
...@@ -449,8 +465,8 @@ class Tester(unittest.TestCase): ...@@ -449,8 +465,8 @@ class Tester(unittest.TestCase):
) )
s_transform = torch.jit.script(transform) s_transform = torch.jit.script(transform)
self._test_transform_vs_scripted(transform, s_transform, tensor) _test_transform_vs_scripted(transform, s_transform, tensor)
self._test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors) _test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
with get_tmp_dir() as tmp_dir: with get_tmp_dir() as tmp_dir:
s_transform.save(os.path.join(tmp_dir, "t_random_rotate.pt")) s_transform.save(os.path.join(tmp_dir, "t_random_rotate.pt"))
...@@ -469,8 +485,8 @@ class Tester(unittest.TestCase): ...@@ -469,8 +485,8 @@ class Tester(unittest.TestCase):
) )
s_transform = torch.jit.script(transform) s_transform = torch.jit.script(transform)
self._test_transform_vs_scripted(transform, s_transform, tensor) _test_transform_vs_scripted(transform, s_transform, tensor)
self._test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors) _test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
with get_tmp_dir() as tmp_dir: with get_tmp_dir() as tmp_dir:
s_transform.save(os.path.join(tmp_dir, "t_perspective.pt")) s_transform.save(os.path.join(tmp_dir, "t_perspective.pt"))
...@@ -479,18 +495,21 @@ class Tester(unittest.TestCase): ...@@ -479,18 +495,21 @@ class Tester(unittest.TestCase):
meth_kwargs = {"num_output_channels": 1} meth_kwargs = {"num_output_channels": 1}
tol = 1.0 + 1e-10 tol = 1.0 + 1e-10
self._test_class_op( _test_class_op(
"Grayscale", meth_kwargs=meth_kwargs, test_exact_match=False, tol=tol, agg_method="max" T.Grayscale, meth_kwargs=meth_kwargs, test_exact_match=False, device=self.device,
tol=tol, agg_method="max"
) )
meth_kwargs = {"num_output_channels": 3} meth_kwargs = {"num_output_channels": 3}
self._test_class_op( _test_class_op(
"Grayscale", meth_kwargs=meth_kwargs, test_exact_match=False, tol=tol, agg_method="max" T.Grayscale, meth_kwargs=meth_kwargs, test_exact_match=False, device=self.device,
tol=tol, agg_method="max"
) )
meth_kwargs = {} meth_kwargs = {}
self._test_class_op( _test_class_op(
"RandomGrayscale", meth_kwargs=meth_kwargs, test_exact_match=False, tol=tol, agg_method="max" T.RandomGrayscale, meth_kwargs=meth_kwargs, test_exact_match=False, device=self.device,
tol=tol, agg_method="max"
) )
def test_normalize(self): def test_normalize(self):
...@@ -505,8 +524,8 @@ class Tester(unittest.TestCase): ...@@ -505,8 +524,8 @@ class Tester(unittest.TestCase):
# test for class interface # test for class interface
scripted_fn = torch.jit.script(fn) scripted_fn = torch.jit.script(fn)
self._test_transform_vs_scripted(fn, scripted_fn, tensor) _test_transform_vs_scripted(fn, scripted_fn, tensor)
self._test_transform_vs_scripted_on_batch(fn, scripted_fn, batch_tensors) _test_transform_vs_scripted_on_batch(fn, scripted_fn, batch_tensors)
with get_tmp_dir() as tmp_dir: with get_tmp_dir() as tmp_dir:
scripted_fn.save(os.path.join(tmp_dir, "t_norm.pt")) scripted_fn.save(os.path.join(tmp_dir, "t_norm.pt"))
...@@ -522,7 +541,7 @@ class Tester(unittest.TestCase): ...@@ -522,7 +541,7 @@ class Tester(unittest.TestCase):
fn = T.LinearTransformation(matrix, mean_vector) fn = T.LinearTransformation(matrix, mean_vector)
scripted_fn = torch.jit.script(fn) scripted_fn = torch.jit.script(fn)
self._test_transform_vs_scripted(fn, scripted_fn, tensor) _test_transform_vs_scripted(fn, scripted_fn, tensor)
batch_tensors = torch.rand(4, c, h, w, device=self.device) batch_tensors = torch.rand(4, c, h, w, device=self.device)
# We skip some tests from _test_transform_vs_scripted_on_batch as # We skip some tests from _test_transform_vs_scripted_on_batch as
...@@ -590,34 +609,34 @@ class Tester(unittest.TestCase): ...@@ -590,34 +609,34 @@ class Tester(unittest.TestCase):
def test_gaussian_blur(self): def test_gaussian_blur(self):
tol = 1.0 + 1e-10 tol = 1.0 + 1e-10
self._test_class_op( _test_class_op(
"GaussianBlur", meth_kwargs={"kernel_size": 3, "sigma": 0.75}, T.GaussianBlur, meth_kwargs={"kernel_size": 3, "sigma": 0.75},
test_exact_match=False, agg_method="max", tol=tol test_exact_match=False, device=self.device, agg_method="max", tol=tol
) )
self._test_class_op( _test_class_op(
"GaussianBlur", meth_kwargs={"kernel_size": 23, "sigma": [0.1, 2.0]}, T.GaussianBlur, meth_kwargs={"kernel_size": 23, "sigma": [0.1, 2.0]},
test_exact_match=False, agg_method="max", tol=tol test_exact_match=False, device=self.device, agg_method="max", tol=tol
) )
self._test_class_op( _test_class_op(
"GaussianBlur", meth_kwargs={"kernel_size": 23, "sigma": (0.1, 2.0)}, T.GaussianBlur, meth_kwargs={"kernel_size": 23, "sigma": (0.1, 2.0)},
test_exact_match=False, agg_method="max", tol=tol test_exact_match=False, device=self.device, agg_method="max", tol=tol
) )
self._test_class_op( _test_class_op(
"GaussianBlur", meth_kwargs={"kernel_size": [3, 3], "sigma": (1.0, 1.0)}, T.GaussianBlur, meth_kwargs={"kernel_size": [3, 3], "sigma": (1.0, 1.0)},
test_exact_match=False, agg_method="max", tol=tol test_exact_match=False, device=self.device, agg_method="max", tol=tol
) )
self._test_class_op( _test_class_op(
"GaussianBlur", meth_kwargs={"kernel_size": (3, 3), "sigma": (0.1, 2.0)}, T.GaussianBlur, meth_kwargs={"kernel_size": (3, 3), "sigma": (0.1, 2.0)},
test_exact_match=False, agg_method="max", tol=tol test_exact_match=False, device=self.device, agg_method="max", tol=tol
) )
self._test_class_op( _test_class_op(
"GaussianBlur", meth_kwargs={"kernel_size": [23], "sigma": 0.75}, T.GaussianBlur, meth_kwargs={"kernel_size": [23], "sigma": 0.75},
test_exact_match=False, agg_method="max", tol=tol test_exact_match=False, device=self.device, agg_method="max", tol=tol
) )
def test_random_erasing(self): def test_random_erasing(self):
...@@ -641,8 +660,8 @@ class Tester(unittest.TestCase): ...@@ -641,8 +660,8 @@ class Tester(unittest.TestCase):
for config in test_configs: for config in test_configs:
fn = T.RandomErasing(**config) fn = T.RandomErasing(**config)
scripted_fn = torch.jit.script(fn) scripted_fn = torch.jit.script(fn)
self._test_transform_vs_scripted(fn, scripted_fn, tensor) _test_transform_vs_scripted(fn, scripted_fn, tensor)
self._test_transform_vs_scripted_on_batch(fn, scripted_fn, batch_tensors) _test_transform_vs_scripted_on_batch(fn, scripted_fn, batch_tensors)
with get_tmp_dir() as tmp_dir: with get_tmp_dir() as tmp_dir:
scripted_fn.save(os.path.join(tmp_dir, "t_random_erasing.pt")) scripted_fn.save(os.path.join(tmp_dir, "t_random_erasing.pt"))
...@@ -662,13 +681,13 @@ class Tester(unittest.TestCase): ...@@ -662,13 +681,13 @@ class Tester(unittest.TestCase):
if (in_dtype == torch.float32 and out_dtype in (torch.int32, torch.int64)) or \ if (in_dtype == torch.float32 and out_dtype in (torch.int32, torch.int64)) or \
(in_dtype == torch.float64 and out_dtype == torch.int64): (in_dtype == torch.float64 and out_dtype == torch.int64):
with self.assertRaisesRegex(RuntimeError, r"cannot be performed safely"): with self.assertRaisesRegex(RuntimeError, r"cannot be performed safely"):
self._test_transform_vs_scripted(fn, scripted_fn, in_tensor) _test_transform_vs_scripted(fn, scripted_fn, in_tensor)
with self.assertRaisesRegex(RuntimeError, r"cannot be performed safely"): with self.assertRaisesRegex(RuntimeError, r"cannot be performed safely"):
self._test_transform_vs_scripted_on_batch(fn, scripted_fn, in_batch_tensors) _test_transform_vs_scripted_on_batch(fn, scripted_fn, in_batch_tensors)
continue continue
self._test_transform_vs_scripted(fn, scripted_fn, in_tensor) _test_transform_vs_scripted(fn, scripted_fn, in_tensor)
self._test_transform_vs_scripted_on_batch(fn, scripted_fn, in_batch_tensors) _test_transform_vs_scripted_on_batch(fn, scripted_fn, in_batch_tensors)
with get_tmp_dir() as tmp_dir: with get_tmp_dir() as tmp_dir:
scripted_fn.save(os.path.join(tmp_dir, "t_convert_dtype.pt")) scripted_fn.save(os.path.join(tmp_dir, "t_convert_dtype.pt"))
...@@ -683,8 +702,8 @@ class Tester(unittest.TestCase): ...@@ -683,8 +702,8 @@ class Tester(unittest.TestCase):
transform = T.AutoAugment(policy=policy, fill=fill) transform = T.AutoAugment(policy=policy, fill=fill)
s_transform = torch.jit.script(transform) s_transform = torch.jit.script(transform)
for _ in range(25): for _ in range(25):
self._test_transform_vs_scripted(transform, s_transform, tensor) _test_transform_vs_scripted(transform, s_transform, tensor)
self._test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors) _test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
if s_transform is not None: if s_transform is not None:
with get_tmp_dir() as tmp_dir: with get_tmp_dir() as tmp_dir:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment