"torchvision/vscode:/vscode.git/clone" did not exist on "97e21c1094b37ff356841e37fa9d17ab34527919"
Unverified Commit 964ce1e9 authored by Nicolas Hug's avatar Nicolas Hug Committed by GitHub
Browse files

Prepare test_transforms_tensor.py for porting to pytest (#3976)

parent 44fefe60
......@@ -24,145 +24,158 @@ from _assert_utils import assert_equal
NEAREST, BILINEAR, BICUBIC = InterpolationMode.NEAREST, InterpolationMode.BILINEAR, InterpolationMode.BICUBIC
class Tester(unittest.TestCase):
def setUp(self):
self.device = "cpu"
def _test_transform_vs_scripted(transform, s_transform, tensor, msg=None):
torch.manual_seed(12)
out1 = transform(tensor)
torch.manual_seed(12)
out2 = s_transform(tensor)
assert_equal(out1, out2, msg=msg)
def _test_functional_op(self, func, fn_kwargs, test_exact_match=True, **match_kwargs):
if fn_kwargs is None:
fn_kwargs = {}
f = getattr(F, func)
tensor, pil_img = _create_data(height=10, width=10, device=self.device)
transformed_tensor = f(tensor, **fn_kwargs)
transformed_pil_img = f(pil_img, **fn_kwargs)
if test_exact_match:
_assert_equal_tensor_to_pil(transformed_tensor, transformed_pil_img, **match_kwargs)
else:
_assert_approx_equal_tensor_to_pil(transformed_tensor, transformed_pil_img, **match_kwargs)
def _test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors, msg=None):
torch.manual_seed(12)
transformed_batch = transform(batch_tensors)
def _test_transform_vs_scripted(self, transform, s_transform, tensor, msg=None):
for i in range(len(batch_tensors)):
img_tensor = batch_tensors[i, ...]
torch.manual_seed(12)
out1 = transform(tensor)
torch.manual_seed(12)
out2 = s_transform(tensor)
assert_equal(out1, out2, msg=msg)
transformed_img = transform(img_tensor)
assert_equal(transformed_img, transformed_batch[i, ...], msg=msg)
def _test_transform_vs_scripted_on_batch(self, transform, s_transform, batch_tensors, msg=None):
torch.manual_seed(12)
transformed_batch = transform(batch_tensors)
torch.manual_seed(12)
s_transformed_batch = s_transform(batch_tensors)
assert_equal(transformed_batch, s_transformed_batch, msg=msg)
for i in range(len(batch_tensors)):
img_tensor = batch_tensors[i, ...]
torch.manual_seed(12)
transformed_img = transform(img_tensor)
assert_equal(transformed_img, transformed_batch[i, ...], msg=msg)
torch.manual_seed(12)
s_transformed_batch = s_transform(batch_tensors)
assert_equal(transformed_batch, s_transformed_batch, msg=msg)
def _test_functional_op(f, device, fn_kwargs=None, test_exact_match=True, **match_kwargs):
fn_kwargs = fn_kwargs or {}
def _test_class_op(self, method, meth_kwargs=None, test_exact_match=True, **match_kwargs):
if meth_kwargs is None:
meth_kwargs = {}
tensor, pil_img = _create_data(height=10, width=10, device=device)
transformed_tensor = f(tensor, **fn_kwargs)
transformed_pil_img = f(pil_img, **fn_kwargs)
if test_exact_match:
_assert_equal_tensor_to_pil(transformed_tensor, transformed_pil_img, **match_kwargs)
else:
_assert_approx_equal_tensor_to_pil(transformed_tensor, transformed_pil_img, **match_kwargs)
# test for class interface
f = getattr(T, method)(**meth_kwargs)
scripted_fn = torch.jit.script(f)
tensor, pil_img = _create_data(26, 34, device=self.device)
# set seed to reproduce the same transformation for tensor and PIL image
torch.manual_seed(12)
transformed_tensor = f(tensor)
torch.manual_seed(12)
transformed_pil_img = f(pil_img)
if test_exact_match:
_assert_equal_tensor_to_pil(transformed_tensor, transformed_pil_img, **match_kwargs)
else:
_assert_approx_equal_tensor_to_pil(transformed_tensor.float(), transformed_pil_img, **match_kwargs)
def _test_class_op(method, device, meth_kwargs=None, test_exact_match=True, **match_kwargs):
# TODO: change the name: it's not a method, it's a class.
meth_kwargs = meth_kwargs or {}
torch.manual_seed(12)
transformed_tensor_script = scripted_fn(tensor)
assert_equal(transformed_tensor, transformed_tensor_script)
# test for class interface
f = method(**meth_kwargs)
scripted_fn = torch.jit.script(f)
batch_tensors = _create_data_batch(height=23, width=34, channels=3, num_samples=4, device=self.device)
self._test_transform_vs_scripted_on_batch(f, scripted_fn, batch_tensors)
tensor, pil_img = _create_data(26, 34, device=device)
# set seed to reproduce the same transformation for tensor and PIL image
torch.manual_seed(12)
transformed_tensor = f(tensor)
torch.manual_seed(12)
transformed_pil_img = f(pil_img)
if test_exact_match:
_assert_equal_tensor_to_pil(transformed_tensor, transformed_pil_img, **match_kwargs)
else:
_assert_approx_equal_tensor_to_pil(transformed_tensor.float(), transformed_pil_img, **match_kwargs)
with get_tmp_dir() as tmp_dir:
scripted_fn.save(os.path.join(tmp_dir, "t_{}.pt".format(method)))
torch.manual_seed(12)
transformed_tensor_script = scripted_fn(tensor)
assert_equal(transformed_tensor, transformed_tensor_script)
batch_tensors = _create_data_batch(height=23, width=34, channels=3, num_samples=4, device=device)
_test_transform_vs_scripted_on_batch(f, scripted_fn, batch_tensors)
with get_tmp_dir() as tmp_dir:
scripted_fn.save(os.path.join(tmp_dir, f"t_{method.__name__}.pt"))
def _test_op(self, func, method, fn_kwargs=None, meth_kwargs=None, test_exact_match=True, **match_kwargs):
self._test_functional_op(func, fn_kwargs, test_exact_match=test_exact_match, **match_kwargs)
self._test_class_op(method, meth_kwargs, test_exact_match=test_exact_match, **match_kwargs)
def _test_op(func, method, device, fn_kwargs=None, meth_kwargs=None, test_exact_match=True, **match_kwargs):
_test_functional_op(func, device, fn_kwargs, test_exact_match=test_exact_match, **match_kwargs)
_test_class_op(method, device, meth_kwargs, test_exact_match=test_exact_match, **match_kwargs)
class Tester(unittest.TestCase):
def setUp(self):
self.device = "cpu"
def test_random_horizontal_flip(self):
self._test_op('hflip', 'RandomHorizontalFlip')
_test_op(F.hflip, T.RandomHorizontalFlip, device=self.device)
def test_random_vertical_flip(self):
self._test_op('vflip', 'RandomVerticalFlip')
_test_op(F.vflip, T.RandomVerticalFlip, device=self.device)
def test_random_invert(self):
self._test_op('invert', 'RandomInvert')
_test_op(F.invert, T.RandomInvert, device=self.device)
def test_random_posterize(self):
fn_kwargs = meth_kwargs = {"bits": 4}
self._test_op(
'posterize', 'RandomPosterize', fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
_test_op(
F.posterize, T.RandomPosterize, device=self.device, fn_kwargs=fn_kwargs,
meth_kwargs=meth_kwargs
)
def test_random_solarize(self):
fn_kwargs = meth_kwargs = {"threshold": 192.0}
self._test_op(
'solarize', 'RandomSolarize', fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
_test_op(
F.solarize, T.RandomSolarize, device=self.device, fn_kwargs=fn_kwargs,
meth_kwargs=meth_kwargs
)
def test_random_adjust_sharpness(self):
fn_kwargs = meth_kwargs = {"sharpness_factor": 2.0}
self._test_op(
'adjust_sharpness', 'RandomAdjustSharpness', fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
_test_op(
F.adjust_sharpness, T.RandomAdjustSharpness, device=self.device, fn_kwargs=fn_kwargs,
meth_kwargs=meth_kwargs
)
def test_random_autocontrast(self):
# We check the max abs difference because on some (very rare) pixels, the actual value may be different
# between PIL and tensors due to floating approximations.
self._test_op('autocontrast', 'RandomAutocontrast', test_exact_match=False, agg_method='max',
tol=(1 + 1e-5), allowed_percentage_diff=.05)
_test_op(
F.autocontrast, T.RandomAutocontrast, device=self.device, test_exact_match=False,
agg_method='max', tol=(1 + 1e-5), allowed_percentage_diff=.05
)
def test_random_equalize(self):
self._test_op('equalize', 'RandomEqualize')
_test_op(F.equalize, T.RandomEqualize, device=self.device)
def test_color_jitter(self):
tol = 1.0 + 1e-10
for f in [0.1, 0.5, 1.0, 1.34, (0.3, 0.7), [0.4, 0.5]]:
meth_kwargs = {"brightness": f}
self._test_class_op(
"ColorJitter", meth_kwargs=meth_kwargs, test_exact_match=False, tol=tol, agg_method="max"
_test_class_op(
T.ColorJitter, meth_kwargs=meth_kwargs, test_exact_match=False, device=self.device,
tol=tol, agg_method="max"
)
for f in [0.2, 0.5, 1.0, 1.5, (0.3, 0.7), [0.4, 0.5]]:
meth_kwargs = {"contrast": f}
self._test_class_op(
"ColorJitter", meth_kwargs=meth_kwargs, test_exact_match=False, tol=tol, agg_method="max"
_test_class_op(
T.ColorJitter, meth_kwargs=meth_kwargs, test_exact_match=False, device=self.device,
tol=tol, agg_method="max"
)
for f in [0.5, 0.75, 1.0, 1.25, (0.3, 0.7), [0.3, 0.4]]:
meth_kwargs = {"saturation": f}
self._test_class_op(
"ColorJitter", meth_kwargs=meth_kwargs, test_exact_match=False, tol=tol, agg_method="max"
_test_class_op(
T.ColorJitter, meth_kwargs=meth_kwargs, test_exact_match=False, device=self.device,
tol=tol, agg_method="max"
)
for f in [0.2, 0.5, (-0.2, 0.3), [-0.4, 0.5]]:
meth_kwargs = {"hue": f}
self._test_class_op(
"ColorJitter", meth_kwargs=meth_kwargs, test_exact_match=False, tol=16.1, agg_method="max"
_test_class_op(
T.ColorJitter, meth_kwargs=meth_kwargs, test_exact_match=False, device=self.device,
tol=16.1, agg_method="max"
)
# All 4 parameters together
meth_kwargs = {"brightness": 0.2, "contrast": 0.2, "saturation": 0.2, "hue": 0.2}
self._test_class_op(
"ColorJitter", meth_kwargs=meth_kwargs, test_exact_match=False, tol=12.1, agg_method="max"
_test_class_op(
T.ColorJitter, meth_kwargs=meth_kwargs, test_exact_match=False, device=self.device,
tol=12.1, agg_method="max"
)
def test_pad(self):
......@@ -170,48 +183,49 @@ class Tester(unittest.TestCase):
fill = 127 if m == "constant" else 0
for mul in [1, -1]:
# Test functional.pad (PIL and Tensor) with padding as single int
self._test_functional_op(
"pad", fn_kwargs={"padding": mul * 2, "fill": fill, "padding_mode": m}
_test_functional_op(
F.pad, fn_kwargs={"padding": mul * 2, "fill": fill, "padding_mode": m},
device=self.device
)
# Test functional.pad and transforms.Pad with padding as [int, ]
fn_kwargs = meth_kwargs = {"padding": [mul * 2, ], "fill": fill, "padding_mode": m}
self._test_op(
"pad", "Pad", fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
_test_op(
F.pad, T.Pad, device=self.device, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
)
# Test functional.pad and transforms.Pad with padding as list
fn_kwargs = meth_kwargs = {"padding": [mul * 4, 4], "fill": fill, "padding_mode": m}
self._test_op(
"pad", "Pad", fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
_test_op(
F.pad, T.Pad, device=self.device, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
)
# Test functional.pad and transforms.Pad with padding as tuple
fn_kwargs = meth_kwargs = {"padding": (mul * 2, 2, 2, mul * 2), "fill": fill, "padding_mode": m}
self._test_op(
"pad", "Pad", fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
_test_op(
F.pad, T.Pad, device=self.device, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
)
def test_crop(self):
fn_kwargs = {"top": 2, "left": 3, "height": 4, "width": 5}
# Test transforms.RandomCrop with size and padding as tuple
meth_kwargs = {"size": (4, 5), "padding": (4, 4), "pad_if_needed": True, }
self._test_op(
'crop', 'RandomCrop', fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
_test_op(
F.crop, T.RandomCrop, device=self.device, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
)
# Test transforms.functional.crop including outside the image area
fn_kwargs = {"top": -2, "left": 3, "height": 4, "width": 5} # top
self._test_functional_op('crop', fn_kwargs=fn_kwargs)
_test_functional_op(F.crop, fn_kwargs=fn_kwargs, device=self.device)
fn_kwargs = {"top": 1, "left": -3, "height": 4, "width": 5} # left
self._test_functional_op('crop', fn_kwargs=fn_kwargs)
_test_functional_op(F.crop, fn_kwargs=fn_kwargs, device=self.device)
fn_kwargs = {"top": 7, "left": 3, "height": 4, "width": 5} # bottom
self._test_functional_op('crop', fn_kwargs=fn_kwargs)
_test_functional_op(F.crop, fn_kwargs=fn_kwargs, device=self.device)
fn_kwargs = {"top": 3, "left": 8, "height": 4, "width": 5} # right
self._test_functional_op('crop', fn_kwargs=fn_kwargs)
_test_functional_op(F.crop, fn_kwargs=fn_kwargs, device=self.device)
fn_kwargs = {"top": -3, "left": -3, "height": 15, "width": 15} # all
self._test_functional_op('crop', fn_kwargs=fn_kwargs)
_test_functional_op(F.crop, fn_kwargs=fn_kwargs, device=self.device)
sizes = [5, [5, ], [6, 6]]
padding_configs = [
......@@ -226,18 +240,20 @@ class Tester(unittest.TestCase):
for padding_config in padding_configs:
config = dict(padding_config)
config["size"] = size
self._test_class_op("RandomCrop", config)
_test_class_op(T.RandomCrop, self.device, config)
def test_center_crop(self):
fn_kwargs = {"output_size": (4, 5)}
meth_kwargs = {"size": (4, 5), }
self._test_op(
"center_crop", "CenterCrop", fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
_test_op(
F.center_crop, T.CenterCrop, device=self.device, fn_kwargs=fn_kwargs,
meth_kwargs=meth_kwargs
)
fn_kwargs = {"output_size": (5,)}
meth_kwargs = {"size": (5, )}
self._test_op(
"center_crop", "CenterCrop", fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
_test_op(
F.center_crop, T.CenterCrop, device=self.device, fn_kwargs=fn_kwargs,
meth_kwargs=meth_kwargs
)
tensor = torch.randint(0, 256, (3, 10, 10), dtype=torch.uint8, device=self.device)
# Test torchscript of transforms.CenterCrop with size as int
......@@ -378,8 +394,8 @@ class Tester(unittest.TestCase):
transform = T.Resize(size=script_size, interpolation=interpolation, max_size=max_size)
s_transform = torch.jit.script(transform)
self._test_transform_vs_scripted(transform, s_transform, tensor)
self._test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
_test_transform_vs_scripted(transform, s_transform, tensor)
_test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
with get_tmp_dir() as tmp_dir:
s_transform.save(os.path.join(tmp_dir, "t_resize.pt"))
......@@ -396,8 +412,8 @@ class Tester(unittest.TestCase):
size=size, scale=scale, ratio=ratio, interpolation=interpolation
)
s_transform = torch.jit.script(transform)
self._test_transform_vs_scripted(transform, s_transform, tensor)
self._test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
_test_transform_vs_scripted(transform, s_transform, tensor)
_test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
with get_tmp_dir() as tmp_dir:
s_transform.save(os.path.join(tmp_dir, "t_resized_crop.pt"))
......@@ -410,8 +426,8 @@ class Tester(unittest.TestCase):
transform = T.RandomAffine(**kwargs)
s_transform = torch.jit.script(transform)
self._test_transform_vs_scripted(transform, s_transform, tensor)
self._test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
_test_transform_vs_scripted(transform, s_transform, tensor)
_test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
return s_transform
......@@ -449,8 +465,8 @@ class Tester(unittest.TestCase):
)
s_transform = torch.jit.script(transform)
self._test_transform_vs_scripted(transform, s_transform, tensor)
self._test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
_test_transform_vs_scripted(transform, s_transform, tensor)
_test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
with get_tmp_dir() as tmp_dir:
s_transform.save(os.path.join(tmp_dir, "t_random_rotate.pt"))
......@@ -469,8 +485,8 @@ class Tester(unittest.TestCase):
)
s_transform = torch.jit.script(transform)
self._test_transform_vs_scripted(transform, s_transform, tensor)
self._test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
_test_transform_vs_scripted(transform, s_transform, tensor)
_test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
with get_tmp_dir() as tmp_dir:
s_transform.save(os.path.join(tmp_dir, "t_perspective.pt"))
......@@ -479,18 +495,21 @@ class Tester(unittest.TestCase):
meth_kwargs = {"num_output_channels": 1}
tol = 1.0 + 1e-10
self._test_class_op(
"Grayscale", meth_kwargs=meth_kwargs, test_exact_match=False, tol=tol, agg_method="max"
_test_class_op(
T.Grayscale, meth_kwargs=meth_kwargs, test_exact_match=False, device=self.device,
tol=tol, agg_method="max"
)
meth_kwargs = {"num_output_channels": 3}
self._test_class_op(
"Grayscale", meth_kwargs=meth_kwargs, test_exact_match=False, tol=tol, agg_method="max"
_test_class_op(
T.Grayscale, meth_kwargs=meth_kwargs, test_exact_match=False, device=self.device,
tol=tol, agg_method="max"
)
meth_kwargs = {}
self._test_class_op(
"RandomGrayscale", meth_kwargs=meth_kwargs, test_exact_match=False, tol=tol, agg_method="max"
_test_class_op(
T.RandomGrayscale, meth_kwargs=meth_kwargs, test_exact_match=False, device=self.device,
tol=tol, agg_method="max"
)
def test_normalize(self):
......@@ -505,8 +524,8 @@ class Tester(unittest.TestCase):
# test for class interface
scripted_fn = torch.jit.script(fn)
self._test_transform_vs_scripted(fn, scripted_fn, tensor)
self._test_transform_vs_scripted_on_batch(fn, scripted_fn, batch_tensors)
_test_transform_vs_scripted(fn, scripted_fn, tensor)
_test_transform_vs_scripted_on_batch(fn, scripted_fn, batch_tensors)
with get_tmp_dir() as tmp_dir:
scripted_fn.save(os.path.join(tmp_dir, "t_norm.pt"))
......@@ -522,7 +541,7 @@ class Tester(unittest.TestCase):
fn = T.LinearTransformation(matrix, mean_vector)
scripted_fn = torch.jit.script(fn)
self._test_transform_vs_scripted(fn, scripted_fn, tensor)
_test_transform_vs_scripted(fn, scripted_fn, tensor)
batch_tensors = torch.rand(4, c, h, w, device=self.device)
# We skip some tests from _test_transform_vs_scripted_on_batch as
......@@ -590,34 +609,34 @@ class Tester(unittest.TestCase):
def test_gaussian_blur(self):
tol = 1.0 + 1e-10
self._test_class_op(
"GaussianBlur", meth_kwargs={"kernel_size": 3, "sigma": 0.75},
test_exact_match=False, agg_method="max", tol=tol
_test_class_op(
T.GaussianBlur, meth_kwargs={"kernel_size": 3, "sigma": 0.75},
test_exact_match=False, device=self.device, agg_method="max", tol=tol
)
self._test_class_op(
"GaussianBlur", meth_kwargs={"kernel_size": 23, "sigma": [0.1, 2.0]},
test_exact_match=False, agg_method="max", tol=tol
_test_class_op(
T.GaussianBlur, meth_kwargs={"kernel_size": 23, "sigma": [0.1, 2.0]},
test_exact_match=False, device=self.device, agg_method="max", tol=tol
)
self._test_class_op(
"GaussianBlur", meth_kwargs={"kernel_size": 23, "sigma": (0.1, 2.0)},
test_exact_match=False, agg_method="max", tol=tol
_test_class_op(
T.GaussianBlur, meth_kwargs={"kernel_size": 23, "sigma": (0.1, 2.0)},
test_exact_match=False, device=self.device, agg_method="max", tol=tol
)
self._test_class_op(
"GaussianBlur", meth_kwargs={"kernel_size": [3, 3], "sigma": (1.0, 1.0)},
test_exact_match=False, agg_method="max", tol=tol
_test_class_op(
T.GaussianBlur, meth_kwargs={"kernel_size": [3, 3], "sigma": (1.0, 1.0)},
test_exact_match=False, device=self.device, agg_method="max", tol=tol
)
self._test_class_op(
"GaussianBlur", meth_kwargs={"kernel_size": (3, 3), "sigma": (0.1, 2.0)},
test_exact_match=False, agg_method="max", tol=tol
_test_class_op(
T.GaussianBlur, meth_kwargs={"kernel_size": (3, 3), "sigma": (0.1, 2.0)},
test_exact_match=False, device=self.device, agg_method="max", tol=tol
)
self._test_class_op(
"GaussianBlur", meth_kwargs={"kernel_size": [23], "sigma": 0.75},
test_exact_match=False, agg_method="max", tol=tol
_test_class_op(
T.GaussianBlur, meth_kwargs={"kernel_size": [23], "sigma": 0.75},
test_exact_match=False, device=self.device, agg_method="max", tol=tol
)
def test_random_erasing(self):
......@@ -641,8 +660,8 @@ class Tester(unittest.TestCase):
for config in test_configs:
fn = T.RandomErasing(**config)
scripted_fn = torch.jit.script(fn)
self._test_transform_vs_scripted(fn, scripted_fn, tensor)
self._test_transform_vs_scripted_on_batch(fn, scripted_fn, batch_tensors)
_test_transform_vs_scripted(fn, scripted_fn, tensor)
_test_transform_vs_scripted_on_batch(fn, scripted_fn, batch_tensors)
with get_tmp_dir() as tmp_dir:
scripted_fn.save(os.path.join(tmp_dir, "t_random_erasing.pt"))
......@@ -662,13 +681,13 @@ class Tester(unittest.TestCase):
if (in_dtype == torch.float32 and out_dtype in (torch.int32, torch.int64)) or \
(in_dtype == torch.float64 and out_dtype == torch.int64):
with self.assertRaisesRegex(RuntimeError, r"cannot be performed safely"):
self._test_transform_vs_scripted(fn, scripted_fn, in_tensor)
_test_transform_vs_scripted(fn, scripted_fn, in_tensor)
with self.assertRaisesRegex(RuntimeError, r"cannot be performed safely"):
self._test_transform_vs_scripted_on_batch(fn, scripted_fn, in_batch_tensors)
_test_transform_vs_scripted_on_batch(fn, scripted_fn, in_batch_tensors)
continue
self._test_transform_vs_scripted(fn, scripted_fn, in_tensor)
self._test_transform_vs_scripted_on_batch(fn, scripted_fn, in_batch_tensors)
_test_transform_vs_scripted(fn, scripted_fn, in_tensor)
_test_transform_vs_scripted_on_batch(fn, scripted_fn, in_batch_tensors)
with get_tmp_dir() as tmp_dir:
scripted_fn.save(os.path.join(tmp_dir, "t_convert_dtype.pt"))
......@@ -683,8 +702,8 @@ class Tester(unittest.TestCase):
transform = T.AutoAugment(policy=policy, fill=fill)
s_transform = torch.jit.script(transform)
for _ in range(25):
self._test_transform_vs_scripted(transform, s_transform, tensor)
self._test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
_test_transform_vs_scripted(transform, s_transform, tensor)
_test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
if s_transform is not None:
with get_tmp_dir() as tmp_dir:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment