Unverified Commit d481f2d8 authored by Brian Vaughan's avatar Brian Vaughan Committed by GitHub
Browse files

Add torchscriptable adjust_gamma transform (#2459)

* add torchscriptable adjust_gamma transform

https://github.com/pytorch/vision/issues/1375



* changes based on code-review

* Apply suggested change to add type hint

Required by mypy, even thought technically incorrect due to possible Image parameter. torchscript doesn't support a union based type hint.
Co-authored-by: default avatarvfdev <vfdev.5@gmail.com>
Co-authored-by: default avatarvfdev <vfdev.5@gmail.com>
parent ab73b448
...@@ -24,6 +24,8 @@ class Tester(unittest.TestCase): ...@@ -24,6 +24,8 @@ class Tester(unittest.TestCase):
def compareTensorToPIL(self, tensor, pil_image, msg=None): def compareTensorToPIL(self, tensor, pil_image, msg=None):
pil_tensor = torch.as_tensor(np.array(pil_image).transpose((2, 0, 1))) pil_tensor = torch.as_tensor(np.array(pil_image).transpose((2, 0, 1)))
if msg is None:
msg = "tensor:\n{} \ndid not equal PIL tensor:\n{}".format(tensor, pil_tensor)
self.assertTrue(tensor.equal(pil_tensor), msg) self.assertTrue(tensor.equal(pil_tensor), msg)
def approxEqualTensorToPIL(self, tensor, pil_image, tol=1e-5, msg=None): def approxEqualTensorToPIL(self, tensor, pil_image, tol=1e-5, msg=None):
...@@ -300,6 +302,33 @@ class Tester(unittest.TestCase): ...@@ -300,6 +302,33 @@ class Tester(unittest.TestCase):
with self.assertRaises(ValueError, msg="Padding can not be negative for symmetric padding_mode"): with self.assertRaises(ValueError, msg="Padding can not be negative for symmetric padding_mode"):
F_t.pad(tensor, (-2, -3), padding_mode="symmetric") F_t.pad(tensor, (-2, -3), padding_mode="symmetric")
def test_adjust_gamma(self):
script_fn = torch.jit.script(F_t.adjust_gamma)
tensor, pil_img = self._create_data(26, 36)
for dt in [torch.float64, torch.float32, None]:
if dt is not None:
tensor = F.convert_image_dtype(tensor, dt)
gammas = [0.8, 1.0, 1.2]
gains = [0.7, 1.0, 1.3]
for gamma, gain in zip(gammas, gains):
adjusted_tensor = F_t.adjust_gamma(tensor, gamma, gain)
adjusted_pil = F_pil.adjust_gamma(pil_img, gamma, gain)
scripted_result = script_fn(tensor, gamma, gain)
self.assertEqual(adjusted_tensor.dtype, scripted_result.dtype)
self.assertEqual(adjusted_tensor.size()[1:], adjusted_pil.size[::-1])
rbg_tensor = adjusted_tensor
if adjusted_tensor.dtype != torch.uint8:
rbg_tensor = F.convert_image_dtype(adjusted_tensor, torch.uint8)
self.compareTensorToPIL(rbg_tensor, adjusted_pil)
self.assertTrue(adjusted_tensor.equal(scripted_result))
def test_resize(self): def test_resize(self):
script_fn = torch.jit.script(F_t.resize) script_fn = torch.jit.script(F_t.resize)
tensor, pil_img = self._create_data(26, 36) tensor, pil_img = self._create_data(26, 36)
......
...@@ -1179,14 +1179,14 @@ class Tester(unittest.TestCase): ...@@ -1179,14 +1179,14 @@ class Tester(unittest.TestCase):
# test 1 # test 1
y_pil = F.adjust_gamma(x_pil, 0.5) y_pil = F.adjust_gamma(x_pil, 0.5)
y_np = np.array(y_pil) y_np = np.array(y_pil)
y_ans = [0, 35, 57, 117, 185, 240, 97, 45, 244, 151, 255, 15] y_ans = [0, 35, 57, 117, 186, 241, 97, 45, 245, 152, 255, 16]
y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape) y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
self.assertTrue(np.allclose(y_np, y_ans)) self.assertTrue(np.allclose(y_np, y_ans))
# test 2 # test 2
y_pil = F.adjust_gamma(x_pil, 2) y_pil = F.adjust_gamma(x_pil, 2)
y_np = np.array(y_pil) y_np = np.array(y_pil)
y_ans = [0, 0, 0, 11, 71, 200, 5, 0, 214, 31, 255, 0] y_ans = [0, 0, 0, 11, 71, 201, 5, 0, 215, 31, 255, 0]
y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape) y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
self.assertTrue(np.allclose(y_np, y_ans)) self.assertTrue(np.allclose(y_np, y_ans))
......
...@@ -160,8 +160,14 @@ def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float) - ...@@ -160,8 +160,14 @@ def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float) -
msg = f"The cast from {image.dtype} to {dtype} cannot be performed safely." msg = f"The cast from {image.dtype} to {dtype} cannot be performed safely."
raise RuntimeError(msg) raise RuntimeError(msg)
# https://github.com/pytorch/vision/pull/2078#issuecomment-612045321
# For data in the range 0-1, (float * 255).to(uint) is only 255
# when float is exactly 1.0.
# `max + 1 - epsilon` provides more evenly distributed mapping of
# ranges of floats to ints.
eps = 1e-3 eps = 1e-3
return image.mul(torch.iinfo(dtype).max + 1 - eps).to(dtype) result = image.mul(torch.iinfo(dtype).max + 1 - eps)
return result.to(dtype)
else: else:
# int to float # int to float
if dtype.is_floating_point: if dtype.is_floating_point:
...@@ -722,7 +728,7 @@ def adjust_hue(img: Tensor, hue_factor: float) -> Tensor: ...@@ -722,7 +728,7 @@ def adjust_hue(img: Tensor, hue_factor: float) -> Tensor:
raise TypeError('img should be PIL Image. Got {}'.format(type(img))) raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
def adjust_gamma(img, gamma, gain=1): def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor:
r"""Perform gamma correction on an image. r"""Perform gamma correction on an image.
Also known as Power Law Transform. Intensities in RGB mode are adjusted Also known as Power Law Transform. Intensities in RGB mode are adjusted
...@@ -736,26 +742,18 @@ def adjust_gamma(img, gamma, gain=1): ...@@ -736,26 +742,18 @@ def adjust_gamma(img, gamma, gain=1):
.. _Gamma Correction: https://en.wikipedia.org/wiki/Gamma_correction .. _Gamma Correction: https://en.wikipedia.org/wiki/Gamma_correction
Args: Args:
img (PIL Image): PIL Image to be adjusted. img (PIL Image or Tensor): PIL Image to be adjusted.
gamma (float): Non negative real number, same as :math:`\gamma` in the equation. gamma (float): Non negative real number, same as :math:`\gamma` in the equation.
gamma larger than 1 make the shadows darker, gamma larger than 1 make the shadows darker,
while gamma smaller than 1 make dark regions lighter. while gamma smaller than 1 make dark regions lighter.
gain (float): The constant multiplier. gain (float): The constant multiplier.
Returns:
PIL Image or Tensor: Gamma correction adjusted image.
""" """
if not F_pil._is_pil_image(img): if not isinstance(img, torch.Tensor):
raise TypeError('img should be PIL Image. Got {}'.format(type(img))) return F_pil.adjust_gamma(img, gamma, gain)
if gamma < 0:
raise ValueError('Gamma should be a non-negative real number')
input_mode = img.mode
img = img.convert('RGB')
gamma_map = [255 * gain * pow(ele / 255., gamma) for ele in range(256)] * 3
img = img.point(gamma_map) # use PIL's point-function to accelerate this part
img = img.convert(input_mode) return F_t.adjust_gamma(img, gamma, gain)
return img
def rotate(img, angle, resample=False, expand=False, center=None, fill=None): def rotate(img, angle, resample=False, expand=False, center=None, fill=None):
......
...@@ -165,6 +165,42 @@ def adjust_hue(img, hue_factor): ...@@ -165,6 +165,42 @@ def adjust_hue(img, hue_factor):
return img return img
@torch.jit.unused
def adjust_gamma(img, gamma, gain=1):
r"""Perform gamma correction on an image.
Also known as Power Law Transform. Intensities in RGB mode are adjusted
based on the following equation:
.. math::
I_{\text{out}} = 255 \times \text{gain} \times \left(\frac{I_{\text{in}}}{255}\right)^{\gamma}
See `Gamma Correction`_ for more details.
.. _Gamma Correction: https://en.wikipedia.org/wiki/Gamma_correction
Args:
img (PIL Image): PIL Image to be adjusted.
gamma (float): Non negative real number, same as :math:`\gamma` in the equation.
gamma larger than 1 make the shadows darker,
while gamma smaller than 1 make dark regions lighter.
gain (float): The constant multiplier.
"""
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
if gamma < 0:
raise ValueError('Gamma should be a non-negative real number')
input_mode = img.mode
img = img.convert('RGB')
gamma_map = [(255 + 1 - 1e-3) * gain * pow(ele / 255., gamma) for ele in range(256)] * 3
img = img.point(gamma_map) # use PIL's point-function to accelerate this part
img = img.convert(input_mode)
return img
@torch.jit.unused @torch.jit.unused
def pad(img, padding, fill=0, padding_mode="constant"): def pad(img, padding, fill=0, padding_mode="constant"):
r"""Pad the given PIL.Image on all sides with the given "pad" value. r"""Pad the given PIL.Image on all sides with the given "pad" value.
......
...@@ -198,6 +198,47 @@ def adjust_saturation(img: Tensor, saturation_factor: float) -> Tensor: ...@@ -198,6 +198,47 @@ def adjust_saturation(img: Tensor, saturation_factor: float) -> Tensor:
return _blend(img, rgb_to_grayscale(img), saturation_factor) return _blend(img, rgb_to_grayscale(img), saturation_factor)
def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor:
r"""Adjust gamma of an RGB image.
Also known as Power Law Transform. Intensities in RGB mode are adjusted
based on the following equation:
.. math::
`I_{\text{out}} = 255 \times \text{gain} \times \left(\frac{I_{\text{in}}}{255}\right)^{\gamma}`
See `Gamma Correction`_ for more details.
.. _Gamma Correction: https://en.wikipedia.org/wiki/Gamma_correction
Args:
img (Tensor): Tensor of RBG values to be adjusted.
gamma (float): Non negative real number, same as :math:`\gamma` in the equation.
gamma larger than 1 make the shadows darker,
while gamma smaller than 1 make dark regions lighter.
gain (float): The constant multiplier.
"""
if not isinstance(img, torch.Tensor):
raise TypeError('img should be a Tensor. Got {}'.format(type(img)))
if gamma < 0:
raise ValueError('Gamma should be a non-negative real number')
result = img
dtype = img.dtype
if not torch.is_floating_point(img):
result = result / 255.0
result = (gain * result ** gamma).clamp(0, 1)
if result.dtype != dtype:
eps = 1e-3
result = (255 + 1.0 - eps) * result
result = result.to(dtype)
return result
def center_crop(img: Tensor, output_size: BroadcastingList2[int]) -> Tensor: def center_crop(img: Tensor, output_size: BroadcastingList2[int]) -> Tensor:
"""Crop the Image Tensor and resize it to desired size. """Crop the Image Tensor and resize it to desired size.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment