Unverified Commit 5f616a2b authored by vfdev's avatar vfdev Committed by GitHub
Browse files

Refactor adjust ops tests (#2595)

* [WIP] Unify ops Grayscale and RandomGrayscale

* Unified inputs for grayscale op and transforms
- deprecated F.to_grayscale in favor of F.rgb_to_grayscale

* Fixes bug with fp input

* Rewritten adjust_* tests
- split test_adjustments into 3 separate tests
- unified testing approach with test_adjust_gamma

* Added ColorJitter tests

* Relaxed tolerance for functional adjust-* tests

* Removed wrong merge and commented code
parent ab590a4a
...@@ -111,64 +111,6 @@ class Tester(TransformsTester): ...@@ -111,64 +111,6 @@ class Tester(TransformsTester):
self.assertLess(max_diff, 1e-5) self.assertLess(max_diff, 1e-5)
def test_adjustments(self):
script_adjust_brightness = torch.jit.script(F_t.adjust_brightness)
script_adjust_contrast = torch.jit.script(F_t.adjust_contrast)
script_adjust_saturation = torch.jit.script(F_t.adjust_saturation)
fns = ((F.adjust_brightness, F_t.adjust_brightness, script_adjust_brightness),
(F.adjust_contrast, F_t.adjust_contrast, script_adjust_contrast),
(F.adjust_saturation, F_t.adjust_saturation, script_adjust_saturation))
for _ in range(20):
channels = 3
dims = torch.randint(1, 50, (2,))
shape = (channels, dims[0], dims[1])
if torch.randint(0, 2, (1,)) == 0:
img = torch.rand(*shape, dtype=torch.float, device=self.device)
else:
img = torch.randint(0, 256, shape, dtype=torch.uint8, device=self.device)
factor = 3 * torch.rand(1).item()
img_clone = img.clone()
for f, ft, sft in fns:
ft_img = ft(img, factor).cpu()
sft_img = sft(img, factor).cpu()
if not img.dtype.is_floating_point:
ft_img = ft_img.to(torch.float) / 255
sft_img = sft_img.to(torch.float) / 255
img_pil = transforms.ToPILImage()(img)
f_img_pil = f(img_pil, factor)
f_img = transforms.ToTensor()(f_img_pil)
# F uses uint8 and F_t uses float, so there is a small
# difference in values caused by (at most 5) truncations.
max_diff = (ft_img - f_img).abs().max()
max_diff_scripted = (sft_img - f_img).abs().max()
self.assertLess(max_diff, 5 / 255 + 1e-5)
self.assertLess(max_diff_scripted, 5 / 255 + 1e-5)
self.assertTrue(torch.equal(img, img_clone))
# test for class interface
f = transforms.ColorJitter(brightness=factor)
scripted_fn = torch.jit.script(f)
scripted_fn(img)
f = transforms.ColorJitter(contrast=factor)
scripted_fn = torch.jit.script(f)
scripted_fn(img)
f = transforms.ColorJitter(saturation=factor)
scripted_fn = torch.jit.script(f)
scripted_fn(img)
f = transforms.ColorJitter(brightness=1)
scripted_fn = torch.jit.script(f)
scripted_fn(img)
def test_rgb_to_grayscale(self): def test_rgb_to_grayscale(self):
script_rgb_to_grayscale = torch.jit.script(F.rgb_to_grayscale) script_rgb_to_grayscale = torch.jit.script(F.rgb_to_grayscale)
...@@ -267,32 +209,69 @@ class Tester(TransformsTester): ...@@ -267,32 +209,69 @@ class Tester(TransformsTester):
with self.assertRaises(ValueError, msg="Padding can not be negative for symmetric padding_mode"): with self.assertRaises(ValueError, msg="Padding can not be negative for symmetric padding_mode"):
F_t.pad(tensor, (-2, -3), padding_mode="symmetric") F_t.pad(tensor, (-2, -3), padding_mode="symmetric")
def test_adjust_gamma(self): def _test_adjust_fn(self, fn, fn_pil, fn_t, configs):
script_fn = torch.jit.script(F.adjust_gamma) script_fn = torch.jit.script(fn)
tensor, pil_img = self._create_data(26, 36, device=self.device)
for dt in [torch.float64, torch.float32, None]: torch.manual_seed(15)
tensor, pil_img = self._create_data(26, 34, device=self.device)
for dt in [None, torch.float32, torch.float64]:
if dt is not None: if dt is not None:
tensor = F.convert_image_dtype(tensor, dt) tensor = F.convert_image_dtype(tensor, dt)
gammas = [0.8, 1.0, 1.2] for config in configs:
gains = [0.7, 1.0, 1.3]
for gamma, gain in zip(gammas, gains):
adjusted_tensor = F.adjust_gamma(tensor, gamma, gain) adjusted_tensor = fn_t(tensor, **config)
adjusted_pil = F.adjust_gamma(pil_img, gamma, gain) adjusted_pil = fn_pil(pil_img, **config)
scripted_result = script_fn(tensor, gamma, gain) scripted_result = script_fn(tensor, **config)
self.assertEqual(adjusted_tensor.dtype, scripted_result.dtype) msg = "{}, {}".format(dt, config)
self.assertEqual(adjusted_tensor.size()[1:], adjusted_pil.size[::-1]) self.assertEqual(adjusted_tensor.dtype, scripted_result.dtype, msg=msg)
self.assertEqual(adjusted_tensor.size()[1:], adjusted_pil.size[::-1], msg=msg)
rbg_tensor = adjusted_tensor rbg_tensor = adjusted_tensor
if adjusted_tensor.dtype != torch.uint8: if adjusted_tensor.dtype != torch.uint8:
rbg_tensor = F.convert_image_dtype(adjusted_tensor, torch.uint8) rbg_tensor = F.convert_image_dtype(adjusted_tensor, torch.uint8)
self.compareTensorToPIL(rbg_tensor, adjusted_pil) # Check that max difference does not exceed 2 in [0, 255] range
# Exact matching is not possible due to incompatibility convert_image_dtype and PIL results
tol = 2.0 + 1e-10
self.approxEqualTensorToPIL(rbg_tensor.float(), adjusted_pil, tol, msg=msg, agg_method="max")
self.assertTrue(adjusted_tensor.allclose(scripted_result), msg=msg)
def test_adjust_brightness(self):
self._test_adjust_fn(
F.adjust_brightness,
F_pil.adjust_brightness,
F_t.adjust_brightness,
[{"brightness_factor": f} for f in [0.1, 0.5, 1.0, 1.34, 2.5]]
)
def test_adjust_contrast(self):
self._test_adjust_fn(
F.adjust_contrast,
F_pil.adjust_contrast,
F_t.adjust_contrast,
[{"contrast_factor": f} for f in [0.2, 0.5, 1.0, 1.5, 2.0]]
)
self.assertTrue(adjusted_tensor.allclose(scripted_result)) def test_adjust_saturation(self):
self._test_adjust_fn(
F.adjust_saturation,
F_pil.adjust_saturation,
F_t.adjust_saturation,
[{"saturation_factor": f} for f in [0.5, 0.75, 1.0, 1.5, 2.0]]
)
def test_adjust_gamma(self):
self._test_adjust_fn(
F.adjust_gamma,
F_pil.adjust_gamma,
F_t.adjust_gamma,
[{"gamma": g1, "gain": g2} for g1, g2 in zip([0.8, 1.0, 1.2], [0.7, 1.0, 1.3])]
)
def test_resize(self): def test_resize(self):
script_fn = torch.jit.script(F_t.resize) script_fn = torch.jit.script(F_t.resize)
......
...@@ -28,7 +28,7 @@ class Tester(TransformsTester): ...@@ -28,7 +28,7 @@ class Tester(TransformsTester):
if meth_kwargs is None: if meth_kwargs is None:
meth_kwargs = {} meth_kwargs = {}
tensor, pil_img = self._create_data(height=10, width=10, device=self.device) tensor, pil_img = self._create_data(26, 34, device=self.device)
# test for class interface # test for class interface
f = getattr(T, method)(**meth_kwargs) f = getattr(T, method)(**meth_kwargs)
scripted_fn = torch.jit.script(f) scripted_fn = torch.jit.script(f)
...@@ -57,31 +57,26 @@ class Tester(TransformsTester): ...@@ -57,31 +57,26 @@ class Tester(TransformsTester):
def test_random_vertical_flip(self): def test_random_vertical_flip(self):
self._test_op('vflip', 'RandomVerticalFlip') self._test_op('vflip', 'RandomVerticalFlip')
def test_adjustments(self): def test_color_jitter(self):
fns = ['adjust_brightness', 'adjust_contrast', 'adjust_saturation']
for _ in range(20):
factor = 3 * torch.rand(1).item()
tensor, _ = self._create_data(device=self.device)
pil_img = T.ToPILImage()(tensor)
for func in fns: tol = 1.0 + 1e-10
adjusted_tensor = getattr(F, func)(tensor, factor) for f in [0.1, 0.5, 1.0, 1.34]:
adjusted_pil_img = getattr(F, func)(pil_img, factor) meth_kwargs = {"brightness": f}
self._test_class_op(
"ColorJitter", meth_kwargs=meth_kwargs, test_exact_match=False, tol=tol, agg_method="max"
)
adjusted_pil_tensor = T.ToTensor()(adjusted_pil_img).to(self.device) for f in [0.2, 0.5, 1.0, 1.5]:
scripted_fn = torch.jit.script(getattr(F, func)) meth_kwargs = {"contrast": f}
adjusted_tensor_script = scripted_fn(tensor, factor) self._test_class_op(
"ColorJitter", meth_kwargs=meth_kwargs, test_exact_match=False, tol=tol, agg_method="max"
if not tensor.dtype.is_floating_point: )
adjusted_tensor = adjusted_tensor.to(torch.float) / 255
adjusted_tensor_script = adjusted_tensor_script.to(torch.float) / 255 for f in [0.5, 0.75, 1.0, 1.25]:
meth_kwargs = {"saturation": f}
# F uses uint8 and F_t uses float, so there is a small self._test_class_op(
# difference in values caused by (at most 5) truncations. "ColorJitter", meth_kwargs=meth_kwargs, test_exact_match=False, tol=tol, agg_method="max"
max_diff = (adjusted_tensor - adjusted_pil_tensor).abs().max() )
max_diff_scripted = (adjusted_tensor - adjusted_tensor_script).abs().max()
self.assertLess(max_diff, 5 / 255 + 1e-5)
self.assertLess(max_diff_scripted, 5 / 255 + 1e-5)
def test_pad(self): def test_pad(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment