"git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "963d432c465324fa056c0bab96a8f5f8444f34e1"
Unverified Commit bb88c452 authored by Dragos Cristian, committed by GitHub
Browse files

adjust_hue now supports inputs of type Tensor (#2566)



* adjust_hue now supports inputs of type Tensor

* Added comparison between original adjust_hue and its Tensor and torch.jit.script versions.

* Added a few type checks related to adjust_hue in functional_tensor.py in hopes of making F_t.adjust_hue scriptable...but to no avail.

* Changed implementation of _rgb2hsv and removed useless type declaration according to PR's review.

* Handled the range of hue_factor in the assertions and temporarily increased the assertLess bound to make sure that no other test fails.

* Fixed some lint issues with CircleCI and added type hints in functional_pil.py as well.

* Corrected type hint mistakes.

* Followed PR review recommendations and added test for class interface with hue.

* Refactored test_functional_tensor.py to match vfdev-5's d016cab branch by simple copy/paste and added the test_adjust_hue and ColorJitter class interface test in the same style (class interface test was removed in vfdev-5's branch for some reason).

* Removed test_adjustments from test_transforms_tensor.py and moved the ColorJitter class interface test in test_transforms_tensor.py.

* Added cuda test cases for test_adjustments and tried to fix conflict.

* Updated tests
- adjust hue
- color jitter

* Fixes incompatible devices

* Increased tol for cuda tests

* Fixes potential issue with inplace op
- fixes irreproducible failing test on Travis CI

* Reverted fmod -> %
Co-authored-by: vfdev-5 <vfdev.5@gmail.com>
parent ac3ba944
...@@ -217,11 +217,9 @@ class Tester(TransformsTester): ...@@ -217,11 +217,9 @@ class Tester(TransformsTester):
with self.assertRaises(ValueError, msg="Padding can not be negative for symmetric padding_mode"): with self.assertRaises(ValueError, msg="Padding can not be negative for symmetric padding_mode"):
F_t.pad(tensor, (-2, -3), padding_mode="symmetric") F_t.pad(tensor, (-2, -3), padding_mode="symmetric")
def _test_adjust_fn(self, fn, fn_pil, fn_t, configs): def _test_adjust_fn(self, fn, fn_pil, fn_t, configs, tol=2.0 + 1e-10, agg_method="max"):
script_fn = torch.jit.script(fn) script_fn = torch.jit.script(fn)
torch.manual_seed(15) torch.manual_seed(15)
tensor, pil_img = self._create_data(26, 34, device=self.device) tensor, pil_img = self._create_data(26, 34, device=self.device)
for dt in [None, torch.float32, torch.float64]: for dt in [None, torch.float32, torch.float64]:
...@@ -230,7 +228,6 @@ class Tester(TransformsTester): ...@@ -230,7 +228,6 @@ class Tester(TransformsTester):
tensor = F.convert_image_dtype(tensor, dt) tensor = F.convert_image_dtype(tensor, dt)
for config in configs: for config in configs:
adjusted_tensor = fn_t(tensor, **config) adjusted_tensor = fn_t(tensor, **config)
adjusted_pil = fn_pil(pil_img, **config) adjusted_pil = fn_pil(pil_img, **config)
scripted_result = script_fn(tensor, **config) scripted_result = script_fn(tensor, **config)
...@@ -245,9 +242,12 @@ class Tester(TransformsTester): ...@@ -245,9 +242,12 @@ class Tester(TransformsTester):
# Check that max difference does not exceed 2 in [0, 255] range # Check that max difference does not exceed 2 in [0, 255] range
# Exact matching is not possible due to incompatibility convert_image_dtype and PIL results # Exact matching is not possible due to incompatibility convert_image_dtype and PIL results
tol = 2.0 + 1e-10 self.approxEqualTensorToPIL(rbg_tensor.float(), adjusted_pil, tol=tol, msg=msg, agg_method=agg_method)
self.approxEqualTensorToPIL(rbg_tensor.float(), adjusted_pil, tol, msg=msg, agg_method="max")
self.assertTrue(adjusted_tensor.allclose(scripted_result), msg=msg) atol = 1e-6
if adjusted_tensor.dtype == torch.uint8 and "cuda" in torch.device(self.device).type:
atol = 1.0
self.assertTrue(adjusted_tensor.allclose(scripted_result, atol=atol), msg=msg)
def test_adjust_brightness(self): def test_adjust_brightness(self):
self._test_adjust_fn( self._test_adjust_fn(
...@@ -273,6 +273,16 @@ class Tester(TransformsTester): ...@@ -273,6 +273,16 @@ class Tester(TransformsTester):
[{"saturation_factor": f} for f in [0.5, 0.75, 1.0, 1.5, 2.0]] [{"saturation_factor": f} for f in [0.5, 0.75, 1.0, 1.5, 2.0]]
) )
def test_adjust_hue(self):
self._test_adjust_fn(
F.adjust_hue,
F_pil.adjust_hue,
F_t.adjust_hue,
[{"hue_factor": f} for f in [-0.45, -0.25, 0.0, 0.25, 0.45]],
tol=0.1,
agg_method="mean"
)
def test_adjust_gamma(self): def test_adjust_gamma(self):
self._test_adjust_fn( self._test_adjust_fn(
F.adjust_gamma, F.adjust_gamma,
......
...@@ -60,24 +60,36 @@ class Tester(TransformsTester): ...@@ -60,24 +60,36 @@ class Tester(TransformsTester):
def test_color_jitter(self): def test_color_jitter(self):
tol = 1.0 + 1e-10 tol = 1.0 + 1e-10
for f in [0.1, 0.5, 1.0, 1.34]: for f in [0.1, 0.5, 1.0, 1.34, (0.3, 0.7), [0.4, 0.5]]:
meth_kwargs = {"brightness": f} meth_kwargs = {"brightness": f}
self._test_class_op( self._test_class_op(
"ColorJitter", meth_kwargs=meth_kwargs, test_exact_match=False, tol=tol, agg_method="max" "ColorJitter", meth_kwargs=meth_kwargs, test_exact_match=False, tol=tol, agg_method="max"
) )
for f in [0.2, 0.5, 1.0, 1.5]: for f in [0.2, 0.5, 1.0, 1.5, (0.3, 0.7), [0.4, 0.5]]:
meth_kwargs = {"contrast": f} meth_kwargs = {"contrast": f}
self._test_class_op( self._test_class_op(
"ColorJitter", meth_kwargs=meth_kwargs, test_exact_match=False, tol=tol, agg_method="max" "ColorJitter", meth_kwargs=meth_kwargs, test_exact_match=False, tol=tol, agg_method="max"
) )
for f in [0.5, 0.75, 1.0, 1.25]: for f in [0.5, 0.75, 1.0, 1.25, (0.3, 0.7), [0.3, 0.4]]:
meth_kwargs = {"saturation": f} meth_kwargs = {"saturation": f}
self._test_class_op( self._test_class_op(
"ColorJitter", meth_kwargs=meth_kwargs, test_exact_match=False, tol=tol, agg_method="max" "ColorJitter", meth_kwargs=meth_kwargs, test_exact_match=False, tol=tol, agg_method="max"
) )
for f in [0.2, 0.5, (-0.2, 0.3), [-0.4, 0.5]]:
meth_kwargs = {"hue": f}
self._test_class_op(
"ColorJitter", meth_kwargs=meth_kwargs, test_exact_match=False, tol=0.1, agg_method="mean"
)
# All 4 parameters together
meth_kwargs = {"brightness": 0.2, "contrast": 0.2, "saturation": 0.2, "hue": 0.2}
self._test_class_op(
"ColorJitter", meth_kwargs=meth_kwargs, test_exact_match=False, tol=0.1, agg_method="mean"
)
def test_pad(self): def test_pad(self):
# Test functional.pad (PIL and Tensor) with padding as single int # Test functional.pad (PIL and Tensor) with padding as single int
......
...@@ -736,7 +736,7 @@ def adjust_hue(img: Tensor, hue_factor: float) -> Tensor: ...@@ -736,7 +736,7 @@ def adjust_hue(img: Tensor, hue_factor: float) -> Tensor:
.. _Hue: https://en.wikipedia.org/wiki/Hue .. _Hue: https://en.wikipedia.org/wiki/Hue
Args: Args:
img (PIL Image): PIL Image to be adjusted. img (PIL Image or Tensor): Image to be adjusted.
hue_factor (float): How much to shift the hue channel. Should be in hue_factor (float): How much to shift the hue channel. Should be in
[-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in [-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in
HSV space in positive and negative direction respectively. HSV space in positive and negative direction respectively.
...@@ -744,12 +744,12 @@ def adjust_hue(img: Tensor, hue_factor: float) -> Tensor: ...@@ -744,12 +744,12 @@ def adjust_hue(img: Tensor, hue_factor: float) -> Tensor:
with complementary colors while 0 gives the original image. with complementary colors while 0 gives the original image.
Returns: Returns:
PIL Image: Hue adjusted image. PIL Image or Tensor: Hue adjusted image.
""" """
if not isinstance(img, torch.Tensor): if not isinstance(img, torch.Tensor):
return F_pil.adjust_hue(img, hue_factor) return F_pil.adjust_hue(img, hue_factor)
raise TypeError('img should be PIL Image. Got {}'.format(type(img))) return F_t.adjust_hue(img, hue_factor)
def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor: def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor:
......
...@@ -157,7 +157,7 @@ def adjust_contrast(img: Tensor, contrast_factor: float) -> Tensor: ...@@ -157,7 +157,7 @@ def adjust_contrast(img: Tensor, contrast_factor: float) -> Tensor:
return _blend(img, mean, contrast_factor) return _blend(img, mean, contrast_factor)
def adjust_hue(img, hue_factor): def adjust_hue(img: Tensor, hue_factor: float) -> Tensor:
"""Adjust hue of an image. """Adjust hue of an image.
The image hue is adjusted by converting the image to HSV and The image hue is adjusted by converting the image to HSV and
...@@ -185,8 +185,8 @@ def adjust_hue(img, hue_factor): ...@@ -185,8 +185,8 @@ def adjust_hue(img, hue_factor):
if not (-0.5 <= hue_factor <= 0.5): if not (-0.5 <= hue_factor <= 0.5):
raise ValueError('hue_factor ({}) is not in [-0.5, 0.5].'.format(hue_factor)) raise ValueError('hue_factor ({}) is not in [-0.5, 0.5].'.format(hue_factor))
if not _is_tensor_a_torch_image(img): if not (isinstance(img, torch.Tensor) and _is_tensor_a_torch_image(img)):
raise TypeError('tensor is not a torch image.') raise TypeError('img should be Tensor image. Got {}'.format(type(img)))
orig_dtype = img.dtype orig_dtype = img.dtype
if img.dtype == torch.uint8: if img.dtype == torch.uint8:
...@@ -194,8 +194,7 @@ def adjust_hue(img, hue_factor): ...@@ -194,8 +194,7 @@ def adjust_hue(img, hue_factor):
img = _rgb2hsv(img) img = _rgb2hsv(img)
h, s, v = img.unbind(0) h, s, v = img.unbind(0)
h += hue_factor h = (h + hue_factor) % 1.0
h = h % 1.0
img = torch.stack((h, s, v)) img = torch.stack((h, s, v))
img_hue_adj = _hsv2rgb(img) img_hue_adj = _hsv2rgb(img)
...@@ -408,6 +407,8 @@ def _blend(img1: Tensor, img2: Tensor, ratio: float) -> Tensor: ...@@ -408,6 +407,8 @@ def _blend(img1: Tensor, img2: Tensor, ratio: float) -> Tensor:
def _rgb2hsv(img): def _rgb2hsv(img):
r, g, b = img.unbind(0) r, g, b = img.unbind(0)
# Implementation is based on https://github.com/python-pillow/Pillow/blob/4174d4267616897df3746d315d5a2d0f82c656ee/
# src/libImaging/Convert.c#L330
maxc = torch.max(img, dim=0).values maxc = torch.max(img, dim=0).values
minc = torch.min(img, dim=0).values minc = torch.min(img, dim=0).values
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment