Unverified Commit d481f2d8 authored by Brian Vaughan's avatar Brian Vaughan Committed by GitHub
Browse files

Add torchscriptable adjust_gamma transform (#2459)

* add torchscriptable adjust_gamma transform

https://github.com/pytorch/vision/issues/1375



* changes based on code-review

* Apply suggested change to add type hint

Required by mypy, even thought technically incorrect due to possible Image parameter. torchscript doesn't support a union based type hint.
Co-authored-by: default avatarvfdev <vfdev.5@gmail.com>
Co-authored-by: default avatarvfdev <vfdev.5@gmail.com>
parent ab73b448
......@@ -24,6 +24,8 @@ class Tester(unittest.TestCase):
def compareTensorToPIL(self, tensor, pil_image, msg=None):
pil_tensor = torch.as_tensor(np.array(pil_image).transpose((2, 0, 1)))
if msg is None:
msg = "tensor:\n{} \ndid not equal PIL tensor:\n{}".format(tensor, pil_tensor)
self.assertTrue(tensor.equal(pil_tensor), msg)
def approxEqualTensorToPIL(self, tensor, pil_image, tol=1e-5, msg=None):
......@@ -300,6 +302,33 @@ class Tester(unittest.TestCase):
with self.assertRaises(ValueError, msg="Padding can not be negative for symmetric padding_mode"):
F_t.pad(tensor, (-2, -3), padding_mode="symmetric")
def test_adjust_gamma(self):
script_fn = torch.jit.script(F_t.adjust_gamma)
tensor, pil_img = self._create_data(26, 36)
for dt in [torch.float64, torch.float32, None]:
if dt is not None:
tensor = F.convert_image_dtype(tensor, dt)
gammas = [0.8, 1.0, 1.2]
gains = [0.7, 1.0, 1.3]
for gamma, gain in zip(gammas, gains):
adjusted_tensor = F_t.adjust_gamma(tensor, gamma, gain)
adjusted_pil = F_pil.adjust_gamma(pil_img, gamma, gain)
scripted_result = script_fn(tensor, gamma, gain)
self.assertEqual(adjusted_tensor.dtype, scripted_result.dtype)
self.assertEqual(adjusted_tensor.size()[1:], adjusted_pil.size[::-1])
rbg_tensor = adjusted_tensor
if adjusted_tensor.dtype != torch.uint8:
rbg_tensor = F.convert_image_dtype(adjusted_tensor, torch.uint8)
self.compareTensorToPIL(rbg_tensor, adjusted_pil)
self.assertTrue(adjusted_tensor.equal(scripted_result))
def test_resize(self):
script_fn = torch.jit.script(F_t.resize)
tensor, pil_img = self._create_data(26, 36)
......
......@@ -1179,14 +1179,14 @@ class Tester(unittest.TestCase):
# test 1
y_pil = F.adjust_gamma(x_pil, 0.5)
y_np = np.array(y_pil)
y_ans = [0, 35, 57, 117, 185, 240, 97, 45, 244, 151, 255, 15]
y_ans = [0, 35, 57, 117, 186, 241, 97, 45, 245, 152, 255, 16]
y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
self.assertTrue(np.allclose(y_np, y_ans))
# test 2
y_pil = F.adjust_gamma(x_pil, 2)
y_np = np.array(y_pil)
y_ans = [0, 0, 0, 11, 71, 200, 5, 0, 214, 31, 255, 0]
y_ans = [0, 0, 0, 11, 71, 201, 5, 0, 215, 31, 255, 0]
y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
self.assertTrue(np.allclose(y_np, y_ans))
......
......@@ -160,8 +160,14 @@ def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float) -
msg = f"The cast from {image.dtype} to {dtype} cannot be performed safely."
raise RuntimeError(msg)
# https://github.com/pytorch/vision/pull/2078#issuecomment-612045321
# For data in the range 0-1, (float * 255).to(uint) is only 255
# when float is exactly 1.0.
# `max + 1 - epsilon` provides more evenly distributed mapping of
# ranges of floats to ints.
eps = 1e-3
return image.mul(torch.iinfo(dtype).max + 1 - eps).to(dtype)
result = image.mul(torch.iinfo(dtype).max + 1 - eps)
return result.to(dtype)
else:
# int to float
if dtype.is_floating_point:
......@@ -722,7 +728,7 @@ def adjust_hue(img: Tensor, hue_factor: float) -> Tensor:
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
def adjust_gamma(img, gamma, gain=1):
def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor:
r"""Perform gamma correction on an image.
Also known as Power Law Transform. Intensities in RGB mode are adjusted
......@@ -736,26 +742,18 @@ def adjust_gamma(img, gamma, gain=1):
.. _Gamma Correction: https://en.wikipedia.org/wiki/Gamma_correction
Args:
img (PIL Image): PIL Image to be adjusted.
img (PIL Image or Tensor): PIL Image to be adjusted.
gamma (float): Non negative real number, same as :math:`\gamma` in the equation.
gamma larger than 1 make the shadows darker,
while gamma smaller than 1 make dark regions lighter.
gain (float): The constant multiplier.
Returns:
PIL Image or Tensor: Gamma correction adjusted image.
"""
if not F_pil._is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
if gamma < 0:
raise ValueError('Gamma should be a non-negative real number')
input_mode = img.mode
img = img.convert('RGB')
gamma_map = [255 * gain * pow(ele / 255., gamma) for ele in range(256)] * 3
img = img.point(gamma_map) # use PIL's point-function to accelerate this part
if not isinstance(img, torch.Tensor):
return F_pil.adjust_gamma(img, gamma, gain)
img = img.convert(input_mode)
return img
return F_t.adjust_gamma(img, gamma, gain)
def rotate(img, angle, resample=False, expand=False, center=None, fill=None):
......
......@@ -165,6 +165,42 @@ def adjust_hue(img, hue_factor):
return img
@torch.jit.unused
def adjust_gamma(img, gamma, gain=1):
r"""Perform gamma correction on an image.
Also known as Power Law Transform. Intensities in RGB mode are adjusted
based on the following equation:
.. math::
I_{\text{out}} = 255 \times \text{gain} \times \left(\frac{I_{\text{in}}}{255}\right)^{\gamma}
See `Gamma Correction`_ for more details.
.. _Gamma Correction: https://en.wikipedia.org/wiki/Gamma_correction
Args:
img (PIL Image): PIL Image to be adjusted.
gamma (float): Non negative real number, same as :math:`\gamma` in the equation.
gamma larger than 1 make the shadows darker,
while gamma smaller than 1 make dark regions lighter.
gain (float): The constant multiplier.
"""
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
if gamma < 0:
raise ValueError('Gamma should be a non-negative real number')
input_mode = img.mode
img = img.convert('RGB')
gamma_map = [(255 + 1 - 1e-3) * gain * pow(ele / 255., gamma) for ele in range(256)] * 3
img = img.point(gamma_map) # use PIL's point-function to accelerate this part
img = img.convert(input_mode)
return img
@torch.jit.unused
def pad(img, padding, fill=0, padding_mode="constant"):
r"""Pad the given PIL.Image on all sides with the given "pad" value.
......
......@@ -198,6 +198,47 @@ def adjust_saturation(img: Tensor, saturation_factor: float) -> Tensor:
return _blend(img, rgb_to_grayscale(img), saturation_factor)
def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor:
r"""Adjust gamma of an RGB image.
Also known as Power Law Transform. Intensities in RGB mode are adjusted
based on the following equation:
.. math::
`I_{\text{out}} = 255 \times \text{gain} \times \left(\frac{I_{\text{in}}}{255}\right)^{\gamma}`
See `Gamma Correction`_ for more details.
.. _Gamma Correction: https://en.wikipedia.org/wiki/Gamma_correction
Args:
img (Tensor): Tensor of RBG values to be adjusted.
gamma (float): Non negative real number, same as :math:`\gamma` in the equation.
gamma larger than 1 make the shadows darker,
while gamma smaller than 1 make dark regions lighter.
gain (float): The constant multiplier.
"""
if not isinstance(img, torch.Tensor):
raise TypeError('img should be a Tensor. Got {}'.format(type(img)))
if gamma < 0:
raise ValueError('Gamma should be a non-negative real number')
result = img
dtype = img.dtype
if not torch.is_floating_point(img):
result = result / 255.0
result = (gain * result ** gamma).clamp(0, 1)
if result.dtype != dtype:
eps = 1e-3
result = (255 + 1.0 - eps) * result
result = result.to(dtype)
return result
def center_crop(img: Tensor, output_size: BroadcastingList2[int]) -> Tensor:
"""Crop the Image Tensor and resize it to desired size.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment