Unverified Commit 54da5db4 authored by vikramtankasali's avatar vikramtankasali Committed by GitHub
Browse files

Adjust hue accepts torch tensor (#2300)



* Adjust hue

* Adjust hue acceps torch.tensor uint8
Co-authored-by: default avatarVikram Mukunda Rao Tankasali <vikramtankasali@devvm765.lla0.facebook.com>
parent 747f406a
...@@ -6,6 +6,7 @@ import torchvision.transforms.functional as F ...@@ -6,6 +6,7 @@ import torchvision.transforms.functional as F
import numpy as np import numpy as np
import unittest import unittest
import random import random
import colorsys
from torch.jit.annotations import Optional, List, BroadcastingList2, Tuple from torch.jit.annotations import Optional, List, BroadcastingList2, Tuple
...@@ -56,6 +57,45 @@ class Tester(unittest.TestCase): ...@@ -56,6 +57,45 @@ class Tester(unittest.TestCase):
cropped_img_script = script_crop(img_tensor, top, left, height, width) cropped_img_script = script_crop(img_tensor, top, left, height, width)
self.assertTrue(torch.equal(img_cropped, cropped_img_script)) self.assertTrue(torch.equal(img_cropped, cropped_img_script))
def test_hsv2rgb(self):
shape = (3, 100, 150)
for _ in range(20):
img = torch.rand(*shape, dtype=torch.float)
ft_img = F_t._hsv2rgb(img).permute(1, 2, 0).flatten(0, 1)
h, s, v, = img.unbind(0)
h = h.flatten().numpy()
s = s.flatten().numpy()
v = v.flatten().numpy()
rgb = []
for h1, s1, v1 in zip(h, s, v):
rgb.append(colorsys.hsv_to_rgb(h1, s1, v1))
colorsys_img = torch.tensor(rgb, dtype=torch.float32)
max_diff = (ft_img - colorsys_img).abs().max()
self.assertLess(max_diff, 1e-5)
def test_rgb2hsv(self):
shape = (3, 150, 100)
for _ in range(20):
img = torch.rand(*shape, dtype=torch.float)
ft_hsv_img = F_t._rgb2hsv(img).permute(1, 2, 0).flatten(0, 1)
r, g, b, = img.unbind(0)
r = r.flatten().numpy()
g = g.flatten().numpy()
b = b.flatten().numpy()
hsv = []
for r1, g1, b1 in zip(r, g, b):
hsv.append(colorsys.rgb_to_hsv(r1, g1, b1))
colorsys_img = torch.tensor(hsv, dtype=torch.float32)
max_diff = (colorsys_img - ft_hsv_img).abs().max()
self.assertLess(max_diff, 1e-5)
def test_adjustments(self): def test_adjustments(self):
script_adjust_brightness = torch.jit.script(F_t.adjust_brightness) script_adjust_brightness = torch.jit.script(F_t.adjust_brightness)
script_adjust_contrast = torch.jit.script(F_t.adjust_contrast) script_adjust_contrast = torch.jit.script(F_t.adjust_contrast)
......
...@@ -118,6 +118,54 @@ def adjust_contrast(img, contrast_factor): ...@@ -118,6 +118,54 @@ def adjust_contrast(img, contrast_factor):
return _blend(img, mean, contrast_factor) return _blend(img, mean, contrast_factor)
def adjust_hue(img, hue_factor):
"""Adjust hue of an image.
The image hue is adjusted by converting the image to HSV and
cyclically shifting the intensities in the hue channel (H).
The image is then converted back to original image mode.
`hue_factor` is the amount of shift in H channel and must be in the
interval `[-0.5, 0.5]`.
See `Hue`_ for more details.
.. _Hue: https://en.wikipedia.org/wiki/Hue
Args:
img (Tensor): Image to be adjusted. Image type is either uint8 or float.
hue_factor (float): How much to shift the hue channel. Should be in
[-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in
HSV space in positive and negative direction respectively.
0 means no shift. Therefore, both -0.5 and 0.5 will give an image
with complementary colors while 0 gives the original image.
Returns:
Tensor: Hue adjusted image.
"""
if not(-0.5 <= hue_factor <= 0.5):
raise ValueError('hue_factor ({}) is not in [-0.5, 0.5].'.format(hue_factor))
if not _is_tensor_a_torch_image(img):
raise TypeError('tensor is not a torch image.')
orig_dtype = img.dtype
if img.dtype == torch.uint8:
img = img.to(dtype=torch.float32) / 255.0
img = _rgb2hsv(img)
h, s, v = img.unbind(0)
h += hue_factor
h = h % 1.0
img = torch.stack((h, s, v))
img_hue_adj = _hsv2rgb(img)
if orig_dtype == torch.uint8:
img_hue_adj = (img_hue_adj * 255.0).to(dtype=orig_dtype)
return img_hue_adj
def adjust_saturation(img, saturation_factor): def adjust_saturation(img, saturation_factor):
# type: (Tensor, float) -> Tensor # type: (Tensor, float) -> Tensor
"""Adjust color saturation of an RGB image. """Adjust color saturation of an RGB image.
...@@ -235,3 +283,47 @@ def _blend(img1, img2, ratio): ...@@ -235,3 +283,47 @@ def _blend(img1, img2, ratio):
# type: (Tensor, Tensor, float) -> Tensor # type: (Tensor, Tensor, float) -> Tensor
bound = 1 if img1.dtype in [torch.half, torch.float32, torch.float64] else 255 bound = 1 if img1.dtype in [torch.half, torch.float32, torch.float64] else 255
return (ratio * img1 + (1 - ratio) * img2).clamp(0, bound).to(img1.dtype) return (ratio * img1 + (1 - ratio) * img2).clamp(0, bound).to(img1.dtype)
def _rgb2hsv(img):
r, g, b = img.unbind(0)
maxc, _ = torch.max(img, dim=0)
minc, _ = torch.min(img, dim=0)
cr = maxc - minc
s = cr / maxc
rc = (maxc - r) / cr
gc = (maxc - g) / cr
bc = (maxc - b) / cr
t = (maxc != minc)
s = t * s
hr = (maxc == r) * (bc - gc)
hg = ((maxc == g) & (maxc != r)) * (2.0 + rc - bc)
hb = ((maxc != g) & (maxc != r)) * (4.0 + gc - rc)
h = (hr + hg + hb)
h = t * h
h = torch.fmod((h / 6.0 + 1.0), 1.0)
return torch.stack((h, s, maxc))
def _hsv2rgb(img):
h, s, v = img.unbind(0)
i = torch.floor(h * 6.0)
f = (h * 6.0) - i
i = i.to(dtype=torch.int32)
p = torch.clamp((v * (1.0 - s)), 0.0, 1.0)
q = torch.clamp((v * (1.0 - s * f)), 0.0, 1.0)
t = torch.clamp((v * (1.0 - s * (1.0 - f))), 0.0, 1.0)
i = i % 6
mask = i == torch.arange(6)[:, None, None]
a1 = torch.stack((v, q, p, p, t, v))
a2 = torch.stack((t, v, v, q, p, p))
a3 = torch.stack((p, p, t, v, v, q))
a4 = torch.stack((a1, a2, a3))
return torch.einsum("ijk, xijk -> xjk", mask.to(dtype=img.dtype), a4)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment