Commit 8909ff43 authored by Ankit Jha's avatar Ankit Jha Committed by Francisco Massa
Browse files

[WIP] Add Scriptable Transform: Grayscale (#1505)

* Add Scriptable Transform: Grayscale

* add scriptable transforms: rgb_to_grayscale

* add scriptable transform: rgb_to_grayscale

* add scriptable transform: rgb_to_grayscale

* add scriptable transform: rgb_to_grayscale

* update code: rgb_to_grayscale

* add test: rgb_to_grayscale

* update parameters: rgb_to_grayscale

* add scriptable transform: rgb_to_grayscale

* update rgb_to_grayscale

* update rgb_to_grayscale
parent cd174844
...@@ -3,6 +3,7 @@ import torch ...@@ -3,6 +3,7 @@ import torch
import torchvision.transforms as transforms import torchvision.transforms as transforms
import torchvision.transforms.functional_tensor as F_t import torchvision.transforms.functional_tensor as F_t
import torchvision.transforms.functional as F import torchvision.transforms.functional as F
import numpy as np
import unittest import unittest
import random import random
...@@ -68,6 +69,13 @@ class Tester(unittest.TestCase): ...@@ -68,6 +69,13 @@ class Tester(unittest.TestCase):
max_diff = (ft_img - f_img).abs().max() max_diff = (ft_img - f_img).abs().max()
self.assertLess(max_diff, 5 / 255 + 1e-5) self.assertLess(max_diff, 5 / 255 + 1e-5)
def test_rgb_to_grayscale(self):
    """Tensor-based rgb_to_grayscale should match the PIL-backed reference within 1 level."""
    # Random uint8 RGB image in CHW layout.
    rgb = torch.randint(0, 255, (3, 16, 16), dtype=torch.uint8)
    # Candidate: pure-tensor conversion; widen to int for an exact difference.
    gray_from_tensor = F_t.rgb_to_grayscale(rgb).to(int)
    # Reference: round-trip through PIL's grayscale conversion.
    pil_gray = F.to_grayscale(F.to_pil_image(rgb))
    gray_from_pil = torch.tensor(np.array(pil_gray)).to(int)
    # Rounding differences between the two paths may be at most one level.
    worst = (gray_from_tensor - gray_from_pil).abs().max()
    self.assertLess(worst, 1.0001)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -26,7 +26,6 @@ def hflip(img_tensor): ...@@ -26,7 +26,6 @@ def hflip(img_tensor):
Returns: Returns:
Tensor: Horizontally flipped image Tensor. Tensor: Horizontally flipped image Tensor.
""" """
if not F._is_tensor_image(img_tensor): if not F._is_tensor_image(img_tensor):
raise TypeError('tensor is not a torch image.') raise TypeError('tensor is not a torch image.')
...@@ -35,12 +34,14 @@ def hflip(img_tensor): ...@@ -35,12 +34,14 @@ def hflip(img_tensor):
def crop(img, top, left, height, width): def crop(img, top, left, height, width):
"""Crop the given Image Tensor. """Crop the given Image Tensor.
Args: Args:
img (Tensor): Image to be cropped in the form [C, H, W]. (0,0) denotes the top left corner of the image. img (Tensor): Image to be cropped in the form [C, H, W]. (0,0) denotes the top left corner of the image.
top (int): Vertical component of the top left corner of the crop box. top (int): Vertical component of the top left corner of the crop box.
left (int): Horizontal component of the top left corner of the crop box. left (int): Horizontal component of the top left corner of the crop box.
height (int): Height of the crop box. height (int): Height of the crop box.
width (int): Width of the crop box. width (int): Width of the crop box.
Returns: Returns:
Tensor: Cropped image. Tensor: Cropped image.
""" """
...@@ -50,6 +51,24 @@ def crop(img, top, left, height, width): ...@@ -50,6 +51,24 @@ def crop(img, top, left, height, width):
return img[..., top:top + height, left:left + width] return img[..., top:top + height, left:left + width]
def rgb_to_grayscale(img):
    """Convert the given RGB Image Tensor to Grayscale.
    For RGB to Grayscale conversion, ITU-R 601-2 luma transform is performed which
    is L = R * 0.2989 + G * 0.5870 + B * 0.1140
    Args:
        img (Tensor): Image to be converted to Grayscale in the form [C, H, W].
    Returns:
        Tensor: Grayscale image.
    Raises:
        TypeError: If the leading (channel) dimension is not 3.
    """
    if img.shape[0] != 3:
        raise TypeError('Input Image does not contain 3 Channels')
    # Split out the channel planes, take the weighted luma sum,
    # then cast back to the input dtype (truncating for integer dtypes).
    red, green, blue = img[0], img[1], img[2]
    luma = 0.2989 * red + 0.5870 * green + 0.1140 * blue
    return luma.to(img.dtype)
def adjust_brightness(img, brightness_factor): def adjust_brightness(img, brightness_factor):
"""Adjust brightness of an RGB image. """Adjust brightness of an RGB image.
...@@ -83,7 +102,7 @@ def adjust_contrast(img, contrast_factor): ...@@ -83,7 +102,7 @@ def adjust_contrast(img, contrast_factor):
if not F._is_tensor_image(img): if not F._is_tensor_image(img):
raise TypeError('tensor is not a torch image.') raise TypeError('tensor is not a torch image.')
mean = torch.mean(_rgb_to_grayscale(img).to(torch.float)) mean = torch.mean(rgb_to_grayscale(img).to(torch.float))
return _blend(img, mean, contrast_factor) return _blend(img, mean, contrast_factor)
...@@ -103,14 +122,9 @@ def adjust_saturation(img, saturation_factor): ...@@ -103,14 +122,9 @@ def adjust_saturation(img, saturation_factor):
if not F._is_tensor_image(img): if not F._is_tensor_image(img):
raise TypeError('tensor is not a torch image.') raise TypeError('tensor is not a torch image.')
return _blend(img, _rgb_to_grayscale(img), saturation_factor) return _blend(img, rgb_to_grayscale(img), saturation_factor)
def _blend(img1, img2, ratio): def _blend(img1, img2, ratio):
bound = 1 if img1.dtype.is_floating_point else 255 bound = 1 if img1.dtype.is_floating_point else 255
return (ratio * img1 + (1 - ratio) * img2).clamp(0, bound).to(img1.dtype) return (ratio * img1 + (1 - ratio) * img2).clamp(0, bound).to(img1.dtype)
def _rgb_to_grayscale(img):
# ITU-R 601-2 luma transform, as used in PIL.
return (0.2989 * img[0] + 0.5870 * img[1] + 0.1140 * img[2]).to(img.dtype)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment