"src/vscode:/vscode.git/clone" did not exist on "5e12d5c6916872744696b5fd2d4dba9185ae98b3"
Commit e79caddf authored by pedrofreire's avatar pedrofreire Committed by Francisco Massa
Browse files

Add adjustment operations for RGB Tensor Images. (#1525)

* Add adjustment operations for RGB Tensor Images.

Right now, we have operations on PIL images, but we want to have a version of the opeartions that act directly on Tensor images.

Here, we add such operations for adjust_brightness, adjust_contrast and adjust_saturation.

In PIL, those functions are implemented by generating an degenerate image from the first, and then interpolating them together.
- https://github.com/python-pillow/Pillow/blob/master/src/PIL/ImageEnhance.py
- https://github.com/python-pillow/Pillow/blob/master/src/libImaging/Blend.c

A few caveats:
* Since PIL operates on uint8, and the tensor operations might be on float, we can get slightly different values because of int truncation.
* We assume here the images are RGB; in particular, to handle an alpha channel, we need to check whether it is present, in which case we copy it to the final image.

* Keep dtype and use broadcast in adjust operations

- We make our operations have input.dtype == output.dtype, at the cost of
adding a few type checks and branches.

- By using Tensor broadcast, we can simplify the calls to _blend.

* Use is_floating_point to check dtype.

* Remove unpacking in tuple

It seems Python 2 does not support this type of unpacking, so it broke
Python 2 builds. This should fix it.

* Add from __future__ import division for Python 2
parent 9e27356f
from __future__ import division
import torch
import torchvision.transforms as transforms
import torchvision.transforms.functional_tensor as F_t
......@@ -36,6 +37,37 @@ class Tester(unittest.TestCase):
self.assertTrue(torch.equal(img_cropped, (img_cropped_GT * 255).to(torch.uint8)),
"functional_tensor crop not working")
def test_adjustments(self):
fns = ((F.adjust_brightness, F_t.adjust_brightness),
(F.adjust_contrast, F_t.adjust_contrast),
(F.adjust_saturation, F_t.adjust_saturation))
for _ in range(20):
channels = 3
dims = torch.randint(1, 50, (2,))
shape = (channels, dims[0], dims[1])
if torch.randint(0, 2, (1,)) == 0:
img = torch.rand(*shape, dtype=torch.float)
else:
img = torch.randint(0, 256, shape, dtype=torch.uint8)
factor = 3 * torch.rand(1)
for f, ft in fns:
ft_img = ft(img, factor)
if not img.dtype.is_floating_point:
ft_img = ft_img.to(torch.float) / 255
img_pil = transforms.ToPILImage()(img)
f_img_pil = f(img_pil, factor)
f_img = transforms.ToTensor()(f_img_pil)
# F uses uint8 and F_t uses float, so there is a small
# difference in values caused by (at most 5) truncations.
max_diff = (ft_img - f_img).abs().max()
self.assertLess(max_diff, 5 / 255 + 1e-5)
if __name__ == '__main__':
unittest.main()
......@@ -48,3 +48,69 @@ def crop(img, top, left, height, width):
raise TypeError('tensor is not a torch image.')
return img[..., top:top + height, left:left + width]
def adjust_brightness(img, brightness_factor):
"""Adjust brightness of an RGB image.
Args:
img (Tensor): Image to be adjusted.
brightness_factor (float): How much to adjust the brightness. Can be
any non negative number. 0 gives a black image, 1 gives the
original image while 2 increases the brightness by a factor of 2.
Returns:
Tensor: Brightness adjusted image.
"""
if not F._is_tensor_image(img):
raise TypeError('tensor is not a torch image.')
return _blend(img, 0, brightness_factor)
def adjust_contrast(img, contrast_factor):
"""Adjust contrast of an RGB image.
Args:
img (Tensor): Image to be adjusted.
contrast_factor (float): How much to adjust the contrast. Can be any
non negative number. 0 gives a solid gray image, 1 gives the
original image while 2 increases the contrast by a factor of 2.
Returns:
Tensor: Contrast adjusted image.
"""
if not F._is_tensor_image(img):
raise TypeError('tensor is not a torch image.')
mean = torch.mean(_rgb_to_grayscale(img).to(torch.float))
return _blend(img, mean, contrast_factor)
def adjust_saturation(img, saturation_factor):
"""Adjust color saturation of an RGB image.
Args:
img (Tensor): Image to be adjusted.
saturation_factor (float): How much to adjust the saturation. 0 will
give a black and white image, 1 will give the original image while
2 will enhance the saturation by a factor of 2.
Returns:
Tensor: Saturation adjusted image.
"""
if not F._is_tensor_image(img):
raise TypeError('tensor is not a torch image.')
return _blend(img, _rgb_to_grayscale(img), saturation_factor)
def _blend(img1, img2, ratio):
bound = 1 if img1.dtype.is_floating_point else 255
return (ratio * img1 + (1 - ratio) * img2).clamp(0, bound).to(img1.dtype)
def _rgb_to_grayscale(img):
# ITU-R 601-2 luma transform, as used in PIL.
return (0.2989 * img[0] + 0.5870 * img[1] + 0.1140 * img[2]).to(img.dtype)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment