Unverified Commit 6c44ceb5 authored by Nicolas Granger's avatar Nicolas Granger Committed by GitHub
Browse files

Replace stack/mask/reduce by indexing in _hsv2rgb (#7754)


Co-authored-by: default avatarvfdev <vfdev.5@gmail.com>
parent f244e27e
......@@ -317,14 +317,20 @@ def _hsv_to_rgb(img: torch.Tensor) -> torch.Tensor:
p = one_minus_s.mul_(v).clamp_(0.0, 1.0)
i.remainder_(6)
mask = i.unsqueeze(dim=-3) == torch.arange(6, device=i.device).view(-1, 1, 1)
vpqt = torch.stack((v, p, q, t), dim=-3)
a1 = torch.stack((v, q, p, p, t, v), dim=-3)
a2 = torch.stack((t, v, v, q, p, p), dim=-3)
a3 = torch.stack((p, p, t, v, v, q), dim=-3)
a4 = torch.stack((a1, a2, a3), dim=-4)
# vpqt -> rgb mapping based on i
select = torch.tensor([[0, 2, 1, 1, 3, 0], [3, 0, 0, 2, 1, 1], [1, 1, 3, 0, 0, 2]], dtype=torch.long)
select = select.to(device=img.device, non_blocking=True)
return (a4.mul_(mask.unsqueeze(dim=-4))).sum(dim=-3)
select = select[:, i]
if select.ndim > 3:
# if input.shape is (B, ..., C, H, W) then
# select.shape is (C, B, ..., H, W)
# thus we move C axis to get (B, ..., C, H, W)
select = select.moveaxis(0, -3)
return vpqt.gather(-3, select)
@_register_kernel_internal(adjust_hue, torch.Tensor)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment