Unverified Commit d72e9064 authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

[prototype] Speed up `adjust_hue_image_tensor` (#6938)

* Performance optimization on adjust_hue_image_tensor

* handle ints

* Inplace logical ops

* Remove unnecessary casting.

* Fix linter.
parent 70edf96d
......@@ -208,11 +208,10 @@ def _rgb_to_hsv(image: torch.Tensor) -> torch.Tensor:
mask_maxc_neq_r = maxc != r
mask_maxc_eq_g = maxc == g
mask_maxc_neq_g = ~mask_maxc_eq_g
hr = (bc - gc).mul_(~mask_maxc_neq_r)
hg = (2.0 + rc).sub_(bc).mul_(mask_maxc_eq_g & mask_maxc_neq_r)
hb = (4.0 + gc).sub_(rc).mul_(mask_maxc_neq_g & mask_maxc_neq_r)
hg = rc.add(2.0).sub_(bc).mul_(mask_maxc_eq_g & mask_maxc_neq_r)
hr = bc.sub_(gc).mul_(~mask_maxc_neq_r)
hb = gc.add_(4.0).sub_(rc).mul_(mask_maxc_neq_r.logical_and_(mask_maxc_eq_g.logical_not_()))
h = hr.add_(hg).add_(hb)
h = h.mul_(1.0 / 6.0).add_(1.0).fmod_(1.0)
......@@ -221,14 +220,16 @@ def _rgb_to_hsv(image: torch.Tensor) -> torch.Tensor:
def _hsv_to_rgb(img: torch.Tensor) -> torch.Tensor:
h, s, v = img.unbind(dim=-3)
h6 = h * 6
h6 = h.mul(6)
i = torch.floor(h6)
f = h6 - i
f = h6.sub_(i)
i = i.to(dtype=torch.int32)
p = (v * (1.0 - s)).clamp_(0.0, 1.0)
q = (v * (1.0 - s * f)).clamp_(0.0, 1.0)
t = (v * (1.0 - s * (1.0 - f))).clamp_(0.0, 1.0)
sxf = s * f
one_minus_s = 1.0 - s
q = (1.0 - sxf).mul_(v).clamp_(0.0, 1.0)
t = sxf.add_(one_minus_s).mul_(v).clamp_(0.0, 1.0)
p = one_minus_s.mul_(v).clamp_(0.0, 1.0)
i.remainder_(6)
mask = i.unsqueeze(dim=-3) == torch.arange(6, device=i.device).view(-1, 1, 1)
......@@ -238,7 +239,7 @@ def _hsv_to_rgb(img: torch.Tensor) -> torch.Tensor:
a3 = torch.stack((p, p, t, v, v, q), dim=-3)
a4 = torch.stack((a1, a2, a3), dim=-4)
return (a4.mul_(mask.to(dtype=img.dtype).unsqueeze(dim=-4))).sum(dim=-3)
return (a4.mul_(mask.unsqueeze(dim=-4))).sum(dim=-3)
def adjust_hue_image_tensor(image: torch.Tensor, hue_factor: float) -> torch.Tensor:
......
......@@ -164,6 +164,7 @@ def convert_format_bounding_box(
if new_format == old_format:
return bounding_box
# TODO: Add _xywh_to_cxcywh and _cxcywh_to_xywh to improve performance
if old_format == BoundingBoxFormat.XYWH:
bounding_box = _xywh_to_xyxy(bounding_box, inplace)
elif old_format == BoundingBoxFormat.CXCYWH:
......
import unittest.mock
from typing import Any, Dict, Tuple, Union
import numpy as np
......@@ -20,6 +19,8 @@ def decode_image_with_pil(encoded_image: torch.Tensor) -> features.Image:
@torch.jit.unused
def decode_video_with_av(encoded_video: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]:
import unittest.mock
with unittest.mock.patch("torchvision.io.video.os.path.exists", return_value=True):
return read_video(ReadOnlyTensorBuffer(encoded_video)) # type: ignore[arg-type]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment