Unverified Commit bb88c452 authored by Dragos Cristian's avatar Dragos Cristian Committed by GitHub
Browse files

adjust_hue now supports inputs of type Tensor (#2566)



* adjust_hue now supports inputs of type Tensor

* Added comparison between original adjust_hue and its Tensor and torch.jit.script versions.

* Added a few type checkings related to adjust_hue in functional_tensor.py in hopes to make F_t.adjust_hue scriptable...but to no avail.

* Changed implementation of _rgb2hsv and removed useless type declaration according to PR's review.

* Handled the range of hue_factor in the assertions and temporarily increased the assertLess bound to make sure that no other test fails.

* Fixed some lint issues with CircleCI and added type hints in functional_pil.py as well.

* Corrected type hint mistakes.

* Followed PR review recommendations and added test for class interface with hue.

* Refactored test_functional_tensor.py to match vfdev-5's d016cab branch by simple copy/paste and added the test_adjust_hue and ColorJitter class interface test in the same style (class interface test was removed in vfdev-5's branch for some reason).

* Removed test_adjustments from test_transforms_tensor.py and moved the ColorJitter class interface test in test_transforms_tensor.py.

* Added cuda test cases for test_adjustments and tried to fix conflict.

* Updated tests
- adjust hue
- color jitter

* Fixes incompatible devices

* Increased tol for cuda tests

* Fixes potential issue with inplace op
- fixes irreproducible failing test on Travis CI

* Reverted fmod -> %
Co-authored-by: default avatarvfdev-5 <vfdev.5@gmail.com>
parent ac3ba944
......@@ -217,11 +217,9 @@ class Tester(TransformsTester):
with self.assertRaises(ValueError, msg="Padding can not be negative for symmetric padding_mode"):
F_t.pad(tensor, (-2, -3), padding_mode="symmetric")
def _test_adjust_fn(self, fn, fn_pil, fn_t, configs):
def _test_adjust_fn(self, fn, fn_pil, fn_t, configs, tol=2.0 + 1e-10, agg_method="max"):
script_fn = torch.jit.script(fn)
torch.manual_seed(15)
tensor, pil_img = self._create_data(26, 34, device=self.device)
for dt in [None, torch.float32, torch.float64]:
......@@ -230,7 +228,6 @@ class Tester(TransformsTester):
tensor = F.convert_image_dtype(tensor, dt)
for config in configs:
adjusted_tensor = fn_t(tensor, **config)
adjusted_pil = fn_pil(pil_img, **config)
scripted_result = script_fn(tensor, **config)
......@@ -245,9 +242,12 @@ class Tester(TransformsTester):
# Check that max difference does not exceed 2 in [0, 255] range
# Exact matching is not possible due to incompatibility convert_image_dtype and PIL results
tol = 2.0 + 1e-10
self.approxEqualTensorToPIL(rbg_tensor.float(), adjusted_pil, tol, msg=msg, agg_method="max")
self.assertTrue(adjusted_tensor.allclose(scripted_result), msg=msg)
self.approxEqualTensorToPIL(rbg_tensor.float(), adjusted_pil, tol=tol, msg=msg, agg_method=agg_method)
atol = 1e-6
if adjusted_tensor.dtype == torch.uint8 and "cuda" in torch.device(self.device).type:
atol = 1.0
self.assertTrue(adjusted_tensor.allclose(scripted_result, atol=atol), msg=msg)
def test_adjust_brightness(self):
self._test_adjust_fn(
......@@ -273,6 +273,16 @@ class Tester(TransformsTester):
[{"saturation_factor": f} for f in [0.5, 0.75, 1.0, 1.5, 2.0]]
)
def test_adjust_hue(self):
self._test_adjust_fn(
F.adjust_hue,
F_pil.adjust_hue,
F_t.adjust_hue,
[{"hue_factor": f} for f in [-0.45, -0.25, 0.0, 0.25, 0.45]],
tol=0.1,
agg_method="mean"
)
def test_adjust_gamma(self):
self._test_adjust_fn(
F.adjust_gamma,
......
......@@ -60,24 +60,36 @@ class Tester(TransformsTester):
def test_color_jitter(self):
tol = 1.0 + 1e-10
for f in [0.1, 0.5, 1.0, 1.34]:
for f in [0.1, 0.5, 1.0, 1.34, (0.3, 0.7), [0.4, 0.5]]:
meth_kwargs = {"brightness": f}
self._test_class_op(
"ColorJitter", meth_kwargs=meth_kwargs, test_exact_match=False, tol=tol, agg_method="max"
)
for f in [0.2, 0.5, 1.0, 1.5]:
for f in [0.2, 0.5, 1.0, 1.5, (0.3, 0.7), [0.4, 0.5]]:
meth_kwargs = {"contrast": f}
self._test_class_op(
"ColorJitter", meth_kwargs=meth_kwargs, test_exact_match=False, tol=tol, agg_method="max"
)
for f in [0.5, 0.75, 1.0, 1.25]:
for f in [0.5, 0.75, 1.0, 1.25, (0.3, 0.7), [0.3, 0.4]]:
meth_kwargs = {"saturation": f}
self._test_class_op(
"ColorJitter", meth_kwargs=meth_kwargs, test_exact_match=False, tol=tol, agg_method="max"
)
for f in [0.2, 0.5, (-0.2, 0.3), [-0.4, 0.5]]:
meth_kwargs = {"hue": f}
self._test_class_op(
"ColorJitter", meth_kwargs=meth_kwargs, test_exact_match=False, tol=0.1, agg_method="mean"
)
# All 4 parameters together
meth_kwargs = {"brightness": 0.2, "contrast": 0.2, "saturation": 0.2, "hue": 0.2}
self._test_class_op(
"ColorJitter", meth_kwargs=meth_kwargs, test_exact_match=False, tol=0.1, agg_method="mean"
)
def test_pad(self):
# Test functional.pad (PIL and Tensor) with padding as single int
......
......@@ -736,7 +736,7 @@ def adjust_hue(img: Tensor, hue_factor: float) -> Tensor:
.. _Hue: https://en.wikipedia.org/wiki/Hue
Args:
img (PIL Image): PIL Image to be adjusted.
img (PIL Image or Tensor): Image to be adjusted.
hue_factor (float): How much to shift the hue channel. Should be in
[-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in
HSV space in positive and negative direction respectively.
......@@ -744,12 +744,12 @@ def adjust_hue(img: Tensor, hue_factor: float) -> Tensor:
with complementary colors while 0 gives the original image.
Returns:
PIL Image: Hue adjusted image.
PIL Image or Tensor: Hue adjusted image.
"""
if not isinstance(img, torch.Tensor):
return F_pil.adjust_hue(img, hue_factor)
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
return F_t.adjust_hue(img, hue_factor)
def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor:
......
......@@ -157,7 +157,7 @@ def adjust_contrast(img: Tensor, contrast_factor: float) -> Tensor:
return _blend(img, mean, contrast_factor)
def adjust_hue(img, hue_factor):
def adjust_hue(img: Tensor, hue_factor: float) -> Tensor:
"""Adjust hue of an image.
The image hue is adjusted by converting the image to HSV and
......@@ -185,8 +185,8 @@ def adjust_hue(img, hue_factor):
if not (-0.5 <= hue_factor <= 0.5):
raise ValueError('hue_factor ({}) is not in [-0.5, 0.5].'.format(hue_factor))
if not _is_tensor_a_torch_image(img):
raise TypeError('tensor is not a torch image.')
if not (isinstance(img, torch.Tensor) and _is_tensor_a_torch_image(img)):
raise TypeError('img should be Tensor image. Got {}'.format(type(img)))
orig_dtype = img.dtype
if img.dtype == torch.uint8:
......@@ -194,8 +194,7 @@ def adjust_hue(img, hue_factor):
img = _rgb2hsv(img)
h, s, v = img.unbind(0)
h += hue_factor
h = h % 1.0
h = (h + hue_factor) % 1.0
img = torch.stack((h, s, v))
img_hue_adj = _hsv2rgb(img)
......@@ -408,6 +407,8 @@ def _blend(img1: Tensor, img2: Tensor, ratio: float) -> Tensor:
def _rgb2hsv(img):
r, g, b = img.unbind(0)
# Implementation is based on https://github.com/python-pillow/Pillow/blob/4174d4267616897df3746d315d5a2d0f82c656ee/
# src/libImaging/Convert.c#L330
maxc = torch.max(img, dim=0).values
minc = torch.min(img, dim=0).values
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment