Unverified Commit 62da7d48 authored by vfdev, committed by GitHub

[proto] Speed improvements for adjust hue op (#6805)



* WIP

* Updated rgb2hsv and a bit of hsv2rgb

* Fix issue with batch of images

* Few improvements

* hsv2rgb improvements

* PR review

* another update

* Fix cuda issue with empty images
  (torch.aminmax is failing)

Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>
parent 3dd2e3d8
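
The page carries no timings for the claimed speedup. A minimal sketch of how one might measure the new kernel with torch.utils.benchmark; the batch size, image shape, and hue factor are arbitrary assumptions, not numbers from the PR:

import torch
import torch.utils.benchmark as benchmark

# Assumes adjust_hue_image_tensor (defined in the diff below) is importable.
image = torch.rand(4, 3, 256, 256)  # arbitrary batch of float RGB images

timer = benchmark.Timer(
    stmt="adjust_hue_image_tensor(image, 0.25)",
    globals={"adjust_hue_image_tensor": adjust_hue_image_tensor, "image": image},
)
print(timer.blocked_autorange())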
@@ -143,7 +143,104 @@ def adjust_sharpness(inpt: features.InputTypeJIT, sharpness_factor: float) -> fe
        return adjust_sharpness_image_pil(inpt, sharpness_factor=sharpness_factor)


# Replaces the previous alias: adjust_hue_image_tensor = _FT.adjust_hue
def _rgb_to_hsv(image: torch.Tensor) -> torch.Tensor:
    r, g, _ = image.unbind(dim=-3)

    # Implementation is based on
    # https://github.com/python-pillow/Pillow/blob/4174d4267616897df3746d315d5a2d0f82c656ee/src/libImaging/Convert.c#L330
    minc, maxc = torch.aminmax(image, dim=-3)

    # The algorithm erases the S and H channels where `maxc = minc`. This avoids NaN
    # from happening in the results, because
    # + the S channel has a division by `maxc`, which is zero only if `maxc = minc`
    # + the H channel has a division by `(maxc - minc)`.
    #
    # Instead of overwriting NaN afterwards, we just prevent it from occurring, so
    # we don't need to deal with it in case we save the NaN in a buffer in
    # backprop, if that is ever supported, but it doesn't hurt to do so.
    eqc = maxc == minc

    channels_range = maxc - minc
    # Since `eqc => channels_range = 0`, replacing the denominator with 1 where `eqc` holds is fine.
    ones = torch.ones_like(maxc)
    s = channels_range / torch.where(eqc, ones, maxc)

    # Note that `eqc => maxc = minc = r = g = b`. So the following calculation
    # of `h` would reduce to `bc - gc + 2 + rc - bc + 4 + rc - bc = 6`, so it
    # does not matter what values `rc`, `gc`, and `bc` take here, and thus
    # replacing the denominator with 1 where `eqc` holds is fine.
    channels_range_divisor = torch.where(eqc, ones, channels_range).unsqueeze_(dim=-3)
    rc, gc, bc = ((maxc.unsqueeze(dim=-3) - image) / channels_range_divisor).unbind(dim=-3)

    mask_maxc_neq_r = maxc != r
    mask_maxc_eq_g = maxc == g
    mask_maxc_neq_g = ~mask_maxc_eq_g

    hr = (bc - gc).mul_(~mask_maxc_neq_r)
    hg = (2.0 + rc).sub_(bc).mul_(mask_maxc_eq_g & mask_maxc_neq_r)
    hb = (4.0 + gc).sub_(rc).mul_(mask_maxc_neq_g & mask_maxc_neq_r)

    h = hr.add_(hg).add_(hb)
    h = h.div_(6.0).add_(1.0).fmod_(1.0)
    return torch.stack((h, s, maxc), dim=-3)
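
The `torch.where` guard above prevents NaNs at the source instead of patching them afterwards. A minimal standalone illustration of the same pattern, using made-up pixel values (not from the diff):

import torch

# Toy values: the first pixel is gray (max == min), which would otherwise
# produce 0/0 = NaN in the saturation channel.
maxc = torch.tensor([0.5, 0.8])
minc = torch.tensor([0.5, 0.2])
eqc = maxc == minc
channels_range = maxc - minc  # 0.0 for the gray pixel
s = channels_range / torch.where(eqc, torch.ones_like(maxc), maxc)
print(s)  # tensor([0.0000, 0.7500]) -- no NaN for the gray pixel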

def _hsv_to_rgb(img: torch.Tensor) -> torch.Tensor:
    h, s, v = img.unbind(dim=-3)
    h6 = h * 6
    i = torch.floor(h6)
    f = h6 - i
    i = i.to(dtype=torch.int32)

    p = (v * (1.0 - s)).clamp_(0.0, 1.0)
    q = (v * (1.0 - s * f)).clamp_(0.0, 1.0)
    t = (v * (1.0 - s * (1.0 - f))).clamp_(0.0, 1.0)
    i.remainder_(6)

    mask = i.unsqueeze(dim=-3) == torch.arange(6, device=i.device).view(-1, 1, 1)

    a1 = torch.stack((v, q, p, p, t, v), dim=-3)
    a2 = torch.stack((t, v, v, q, p, p), dim=-3)
    a3 = torch.stack((p, p, t, v, v, q), dim=-3)
    a4 = torch.stack((a1, a2, a3), dim=-4)

    return (a4.mul_(mask.to(dtype=img.dtype).unsqueeze(dim=-4))).sum(dim=-3)
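
Since the two helpers above should invert each other on valid inputs, a quick sanity check is to round-trip random images. A sketch; the shape and tolerance are guesses, not values from the PR:

import torch

# Random float RGB images in [0, 1]; batched, since the PR fixed batch handling.
img = torch.rand(2, 3, 32, 32)
restored = _hsv_to_rgb(_rgb_to_hsv(img))
# The conversions should invert each other up to floating-point error.
torch.testing.assert_close(restored, img, rtol=0, atol=1e-5)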

def adjust_hue_image_tensor(image: torch.Tensor, hue_factor: float) -> torch.Tensor:
    if not (-0.5 <= hue_factor <= 0.5):
        raise ValueError(f"hue_factor ({hue_factor}) is not in [-0.5, 0.5].")

    if not isinstance(image, torch.Tensor):
        raise TypeError("Input img should be Tensor image")

    c = get_num_channels_image_tensor(image)

    if c not in [1, 3]:
        raise TypeError(f"Input image tensor permitted channel values are {[1, 3]}, but found {c}")

    if c == 1:  # Match PIL behaviour
        return image

    if image.numel() == 0:
        # exit early on empty images (torch.aminmax fails on empty CUDA tensors)
        return image

    orig_dtype = image.dtype
    if image.dtype == torch.uint8:
        image = image / 255.0

    image = _rgb_to_hsv(image)
    h, s, v = image.unbind(dim=-3)
    h.add_(hue_factor).remainder_(1.0)
    image = torch.stack((h, s, v), dim=-3)
    image_hue_adj = _hsv_to_rgb(image)

    if orig_dtype == torch.uint8:
        image_hue_adj = image_hue_adj.mul_(255.0).to(dtype=orig_dtype)

    return image_hue_adj
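
A hedged usage sketch, checking the new kernel against the stable torchvision kernel it replaces as the backend for this op; the shape, hue factor, and expectation of only off-by-one rounding differences are assumptions:

import torch
import torchvision.transforms.functional as F

# A uint8 RGB image; values and shape are arbitrary.
image = torch.randint(0, 256, (3, 64, 64), dtype=torch.uint8)

out_new = adjust_hue_image_tensor(image, hue_factor=0.1)  # kernel from this diff
out_ref = F.adjust_hue(image, 0.1)                        # stable torchvision kernel
# The two should agree up to rounding differences on uint8 inputs.
print((out_new.int() - out_ref.int()).abs().max())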

adjust_hue_image_pil = _FP.adjust_hue