Unverified Commit 32bccc53 authored by Nicolas Hug's avatar Nicolas Hug Committed by GitHub
Browse files

Port _test_adjust_fn to pytest (#3845)

parent 0fece1f7
...@@ -324,85 +324,6 @@ class Tester(TransformsTester): ...@@ -324,85 +324,6 @@ class Tester(TransformsTester):
self._test_fn_on_batch(batch_tensors, F.pad, padding=script_pad, **kwargs) self._test_fn_on_batch(batch_tensors, F.pad, padding=script_pad, **kwargs)
def _test_adjust_fn(self, fn, fn_pil, fn_t, configs, tol=2.0 + 1e-10, agg_method="max",
dts=(None, torch.float32, torch.float64)):
script_fn = torch.jit.script(fn)
torch.manual_seed(15)
tensor, pil_img = self._create_data(26, 34, device=self.device)
batch_tensors = self._create_data_batch(16, 18, num_samples=4, device=self.device)
for dt in dts:
if dt is not None:
tensor = F.convert_image_dtype(tensor, dt)
batch_tensors = F.convert_image_dtype(batch_tensors, dt)
for config in configs:
adjusted_tensor = fn_t(tensor, **config)
adjusted_pil = fn_pil(pil_img, **config)
scripted_result = script_fn(tensor, **config)
msg = "{}, {}".format(dt, config)
self.assertEqual(adjusted_tensor.dtype, scripted_result.dtype, msg=msg)
self.assertEqual(adjusted_tensor.size()[1:], adjusted_pil.size[::-1], msg=msg)
rbg_tensor = adjusted_tensor
if adjusted_tensor.dtype != torch.uint8:
rbg_tensor = F.convert_image_dtype(adjusted_tensor, torch.uint8)
# Check that max difference does not exceed 2 in [0, 255] range
# Exact matching is not possible due to incompatibility convert_image_dtype and PIL results
self.approxEqualTensorToPIL(rbg_tensor.float(), adjusted_pil, tol=tol, msg=msg, agg_method=agg_method)
atol = 1e-6
if adjusted_tensor.dtype == torch.uint8 and "cuda" in torch.device(self.device).type:
atol = 1.0
self.assertTrue(adjusted_tensor.allclose(scripted_result, atol=atol), msg=msg)
self._test_fn_on_batch(batch_tensors, fn, scripted_fn_atol=atol, **config)
def test_adjust_brightness(self):
self._test_adjust_fn(
F.adjust_brightness,
F_pil.adjust_brightness,
F_t.adjust_brightness,
[{"brightness_factor": f} for f in [0.1, 0.5, 1.0, 1.34, 2.5]]
)
def test_adjust_contrast(self):
self._test_adjust_fn(
F.adjust_contrast,
F_pil.adjust_contrast,
F_t.adjust_contrast,
[{"contrast_factor": f} for f in [0.2, 0.5, 1.0, 1.5, 2.0]]
)
def test_adjust_saturation(self):
self._test_adjust_fn(
F.adjust_saturation,
F_pil.adjust_saturation,
F_t.adjust_saturation,
[{"saturation_factor": f} for f in [0.5, 0.75, 1.0, 1.5, 2.0]]
)
def test_adjust_hue(self):
self._test_adjust_fn(
F.adjust_hue,
F_pil.adjust_hue,
F_t.adjust_hue,
[{"hue_factor": f} for f in [-0.45, -0.25, 0.0, 0.25, 0.45]],
tol=16.1,
agg_method="max"
)
def test_adjust_gamma(self):
self._test_adjust_fn(
F.adjust_gamma,
F_pil.adjust_gamma,
F_t.adjust_gamma,
[{"gamma": g1, "gain": g2} for g1, g2 in zip([0.8, 1.0, 1.2], [0.7, 1.0, 1.3])]
)
def test_resize(self): def test_resize(self):
script_fn = torch.jit.script(F.resize) script_fn = torch.jit.script(F.resize)
tensor, pil_img = self._create_data(26, 36, device=self.device) tensor, pil_img = self._create_data(26, 36, device=self.device)
...@@ -833,77 +754,6 @@ class Tester(TransformsTester): ...@@ -833,77 +754,6 @@ class Tester(TransformsTester):
msg="{}, {}".format(ksize, sigma) msg="{}, {}".format(ksize, sigma)
) )
def test_invert(self):
self._test_adjust_fn(
F.invert,
F_pil.invert,
F_t.invert,
[{}],
tol=1.0,
agg_method="max"
)
def test_posterize(self):
self._test_adjust_fn(
F.posterize,
F_pil.posterize,
F_t.posterize,
[{"bits": bits} for bits in range(0, 8)],
tol=1.0,
agg_method="max",
dts=(None,)
)
def test_solarize(self):
self._test_adjust_fn(
F.solarize,
F_pil.solarize,
F_t.solarize,
[{"threshold": threshold} for threshold in [0, 64, 128, 192, 255]],
tol=1.0,
agg_method="max",
dts=(None,)
)
self._test_adjust_fn(
F.solarize,
lambda img, threshold: F_pil.solarize(img, 255 * threshold),
F_t.solarize,
[{"threshold": threshold} for threshold in [0.0, 0.25, 0.5, 0.75, 1.0]],
tol=1.0,
agg_method="max",
dts=(torch.float32, torch.float64)
)
def test_adjust_sharpness(self):
self._test_adjust_fn(
F.adjust_sharpness,
F_pil.adjust_sharpness,
F_t.adjust_sharpness,
[{"sharpness_factor": f} for f in [0.2, 0.5, 1.0, 1.5, 2.0]]
)
def test_autocontrast(self):
self._test_adjust_fn(
F.autocontrast,
F_pil.autocontrast,
F_t.autocontrast,
[{}],
tol=1.0,
agg_method="max"
)
def test_equalize(self):
torch.set_deterministic(False)
self._test_adjust_fn(
F.equalize,
F_pil.equalize,
F_t.equalize,
[{}],
tol=1.0,
agg_method="max",
dts=(None,)
)
@unittest.skipIf(not torch.cuda.is_available(), reason="Skip if no CUDA device") @unittest.skipIf(not torch.cuda.is_available(), reason="Skip if no CUDA device")
class CUDATester(Tester): class CUDATester(Tester):
...@@ -1074,5 +924,219 @@ def test_resize_antialias(device, dt, size, interpolation, tester): ...@@ -1074,5 +924,219 @@ def test_resize_antialias(device, dt, size, interpolation, tester):
tester.assertTrue(resized_tensor.equal(resize_result), msg=f"{size}, {interpolation}, {dt}") tester.assertTrue(resized_tensor.equal(resize_result), msg=f"{size}, {interpolation}, {dt}")
def check_functional_vs_PIL_vs_scripted(fn, fn_pil, fn_t, config, device, dtype, tol=2.0 + 1e-10, agg_method="max"):
tester = Tester()
script_fn = torch.jit.script(fn)
torch.manual_seed(15)
tensor, pil_img = tester._create_data(26, 34, device=device)
batch_tensors = tester._create_data_batch(16, 18, num_samples=4, device=device)
if dtype is not None:
tensor = F.convert_image_dtype(tensor, dtype)
batch_tensors = F.convert_image_dtype(batch_tensors, dtype)
out_fn_t = fn_t(tensor, **config)
out_pil = fn_pil(pil_img, **config)
out_scripted = script_fn(tensor, **config)
assert out_fn_t.dtype == out_scripted.dtype
assert out_fn_t.size()[1:] == out_pil.size[::-1]
rbg_tensor = out_fn_t
if out_fn_t.dtype != torch.uint8:
rbg_tensor = F.convert_image_dtype(out_fn_t, torch.uint8)
# Check that max difference does not exceed 2 in [0, 255] range
# Exact matching is not possible due to incompatibility convert_image_dtype and PIL results
tester.approxEqualTensorToPIL(rbg_tensor.float(), out_pil, tol=tol, agg_method=agg_method)
atol = 1e-6
if out_fn_t.dtype == torch.uint8 and "cuda" in torch.device(device).type:
atol = 1.0
assert out_fn_t.allclose(out_scripted, atol=atol)
# FIXME: fn will be scripted again in _test_fn_on_batch. We could avoid that.
tester._test_fn_on_batch(batch_tensors, fn, scripted_fn_atol=atol, **config)
@pytest.mark.parametrize('device', cpu_and_gpu())
@pytest.mark.parametrize('dtype', (None, torch.float32, torch.float64))
@pytest.mark.parametrize('config', [{"brightness_factor": f} for f in (0.1, 0.5, 1.0, 1.34, 2.5)])
def test_adjust_brightness(device, dtype, config):
check_functional_vs_PIL_vs_scripted(
F.adjust_brightness,
F_pil.adjust_brightness,
F_t.adjust_brightness,
config,
device,
dtype,
)
@pytest.mark.parametrize('device', cpu_and_gpu())
@pytest.mark.parametrize('dtype', (None, torch.float32, torch.float64))
def test_invert(device, dtype):
check_functional_vs_PIL_vs_scripted(
F.invert,
F_pil.invert,
F_t.invert,
{},
device,
dtype,
tol=1.0,
agg_method="max"
)
@pytest.mark.parametrize('device', cpu_and_gpu())
@pytest.mark.parametrize('config', [{"bits": bits} for bits in range(0, 8)])
def test_posterize(device, config):
check_functional_vs_PIL_vs_scripted(
F.posterize,
F_pil.posterize,
F_t.posterize,
config,
device,
dtype=None,
tol=1.0,
agg_method="max",
)
@pytest.mark.parametrize('device', cpu_and_gpu())
@pytest.mark.parametrize('config', [{"threshold": threshold} for threshold in [0, 64, 128, 192, 255]])
def test_solarize1(device, config):
check_functional_vs_PIL_vs_scripted(
F.solarize,
F_pil.solarize,
F_t.solarize,
config,
device,
dtype=None,
tol=1.0,
agg_method="max",
)
@pytest.mark.parametrize('device', cpu_and_gpu())
@pytest.mark.parametrize('dtype', (torch.float32, torch.float64))
@pytest.mark.parametrize('config', [{"threshold": threshold} for threshold in [0.0, 0.25, 0.5, 0.75, 1.0]])
def test_solarize2(device, dtype, config):
check_functional_vs_PIL_vs_scripted(
F.solarize,
lambda img, threshold: F_pil.solarize(img, 255 * threshold),
F_t.solarize,
config,
device,
dtype,
tol=1.0,
agg_method="max",
)
@pytest.mark.parametrize('device', cpu_and_gpu())
@pytest.mark.parametrize('dtype', (None, torch.float32, torch.float64))
@pytest.mark.parametrize('config', [{"sharpness_factor": f} for f in [0.2, 0.5, 1.0, 1.5, 2.0]])
def test_adjust_sharpness(device, dtype, config):
check_functional_vs_PIL_vs_scripted(
F.adjust_sharpness,
F_pil.adjust_sharpness,
F_t.adjust_sharpness,
config,
device,
dtype,
)
@pytest.mark.parametrize('device', cpu_and_gpu())
@pytest.mark.parametrize('dtype', (None, torch.float32, torch.float64))
def test_autocontrast(device, dtype):
check_functional_vs_PIL_vs_scripted(
F.autocontrast,
F_pil.autocontrast,
F_t.autocontrast,
{},
device,
dtype,
tol=1.0,
agg_method="max"
)
@pytest.mark.parametrize('device', cpu_and_gpu())
def test_equalize(device):
torch.set_deterministic(False)
check_functional_vs_PIL_vs_scripted(
F.equalize,
F_pil.equalize,
F_t.equalize,
{},
device,
dtype=None,
tol=1.0,
agg_method="max",
)
@pytest.mark.parametrize('device', cpu_and_gpu())
@pytest.mark.parametrize('dtype', (None, torch.float32, torch.float64))
@pytest.mark.parametrize('config', [{"contrast_factor": f} for f in [0.2, 0.5, 1.0, 1.5, 2.0]])
def test_adjust_contrast(device, dtype, config):
check_functional_vs_PIL_vs_scripted(
F.adjust_contrast,
F_pil.adjust_contrast,
F_t.adjust_contrast,
config,
device,
dtype
)
@pytest.mark.parametrize('device', cpu_and_gpu())
@pytest.mark.parametrize('dtype', (None, torch.float32, torch.float64))
@pytest.mark.parametrize('config', [{"saturation_factor": f} for f in [0.5, 0.75, 1.0, 1.5, 2.0]])
def test_adjust_saturation(device, dtype, config):
check_functional_vs_PIL_vs_scripted(
F.adjust_saturation,
F_pil.adjust_saturation,
F_t.adjust_saturation,
config,
device,
dtype
)
@pytest.mark.parametrize('device', cpu_and_gpu())
@pytest.mark.parametrize('dtype', (None, torch.float32, torch.float64))
@pytest.mark.parametrize('config', [{"hue_factor": f} for f in [-0.45, -0.25, 0.0, 0.25, 0.45]])
def test_adjust_hue(device, dtype, config):
check_functional_vs_PIL_vs_scripted(
F.adjust_hue,
F_pil.adjust_hue,
F_t.adjust_hue,
config,
device,
dtype,
tol=16.1,
agg_method="max"
)
@pytest.mark.parametrize('device', cpu_and_gpu())
@pytest.mark.parametrize('dtype', (None, torch.float32, torch.float64))
@pytest.mark.parametrize('config', [{"gamma": g1, "gain": g2} for g1, g2 in zip([0.8, 1.0, 1.2], [0.7, 1.0, 1.3])])
def test_adjust_gamma(device, dtype, config):
check_functional_vs_PIL_vs_scripted(
F.adjust_gamma,
F_pil.adjust_gamma,
F_t.adjust_gamma,
config,
device,
dtype,
)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment