Unverified Commit ebbead24 authored by Aditya Oke's avatar Aditya Oke Committed by GitHub
Browse files

Adds utility to convert optical flow to an image (#5134)



* Start stuff

* Start adding some implementation

* Add implementation

* Use torch for colorwheel

* Add simple test

* Fix small numpy rename

* Adapt the changes

* Fix stuff

* Add suggestions

* Minor fixes for float

* Remove idx logic

* Simplify

* Fix test

* Update code and add test

* Add expected flow asset

* Fix expected flow path

* Doc nits
Co-authored-by: default avatarNicolas Hug <nicolashug@fb.com>
Co-authored-by: default avatarNicolas Hug <contact@nicolas-hug.com>
parent dd9080ca
...@@ -15,5 +15,6 @@ vizualization <sphx_glr_auto_examples_plot_visualization_utils.py>`. ...@@ -15,5 +15,6 @@ vizualization <sphx_glr_auto_examples_plot_visualization_utils.py>`.
draw_bounding_boxes draw_bounding_boxes
draw_segmentation_masks draw_segmentation_masks
draw_keypoints draw_keypoints
flow_to_image
make_grid make_grid
save_image save_image
...@@ -317,5 +317,30 @@ def test_draw_keypoints_errors(): ...@@ -317,5 +317,30 @@ def test_draw_keypoints_errors():
utils.draw_keypoints(image=img, keypoints=invalid_keypoints) utils.draw_keypoints(image=img, keypoints=invalid_keypoints)
def test_flow_to_image():
    # Build a synthetic (2, H, W) flow field whose vectors point away from the
    # image center, then compare the rendering against a stored reference image.
    height, width = 100, 100
    grid_y, grid_x = torch.meshgrid(torch.arange(height), torch.arange(width), indexing="ij")
    flow = torch.stack((grid_x, grid_y), dim=0).float()
    flow[0] -= height / 2
    flow[1] -= width / 2

    result = utils.flow_to_image(flow)

    asset_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "expected_flow.pt")
    expected = torch.load(asset_path, map_location="cpu")
    assert_equal(expected, result)
def test_flow_to_image_errors():
    # Wrong channel count, wrong rank, and wrong dtype must each be rejected.
    bad_channels = torch.full((3, 10, 10), 0, dtype=torch.float)
    bad_rank = torch.full((2, 10), 0, dtype=torch.float)
    bad_dtype = torch.full((2, 10, 30), 0, dtype=torch.int)

    for bad_shape in (bad_channels, bad_rank):
        with pytest.raises(ValueError, match="Input flow should have shape"):
            utils.flow_to_image(flow=bad_shape)

    with pytest.raises(ValueError, match="Flow should be of dtype torch.float"):
        utils.flow_to_image(flow=bad_dtype)
if __name__ == "__main__": if __name__ == "__main__":
pytest.main([__file__]) pytest.main([__file__])
...@@ -8,7 +8,14 @@ import numpy as np ...@@ -8,7 +8,14 @@ import numpy as np
import torch import torch
from PIL import Image, ImageColor, ImageDraw, ImageFont from PIL import Image, ImageColor, ImageDraw, ImageFont
__all__ = ["make_grid", "save_image", "draw_bounding_boxes", "draw_segmentation_masks", "draw_keypoints"] __all__ = [
"make_grid",
"save_image",
"draw_bounding_boxes",
"draw_segmentation_masks",
"draw_keypoints",
"flow_to_image",
]
@torch.no_grad() @torch.no_grad()
...@@ -382,6 +389,113 @@ def draw_keypoints( ...@@ -382,6 +389,113 @@ def draw_keypoints(
return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1).to(dtype=torch.uint8) return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1).to(dtype=torch.uint8)
# Flow visualization code adapted from https://github.com/tomrunia/OpticalFlow_Visualization
@torch.no_grad()
def flow_to_image(flow: torch.Tensor) -> torch.Tensor:
    """
    Converts a flow to an RGB image.

    Args:
        flow (Tensor): Flow of shape (2, H, W) and dtype torch.float.

    Returns:
        img (Tensor(3, H, W)): Image Tensor of dtype uint8 where each color corresponds to a given flow direction.
    """
    # Validate dtype before shape so callers get the most specific error first.
    if flow.dtype != torch.float:
        raise ValueError(f"Flow should be of dtype torch.float, got {flow.dtype}.")
    if flow.ndim != 3 or flow.size(0) != 2:
        raise ValueError(f"Input flow should have shape (2, H, W), got {flow.shape}.")

    # Scale every vector by the largest magnitude in the field so the whole
    # flow fits in the unit disk; eps guards against division by zero on an
    # all-zero flow.
    largest_norm = (flow * flow).sum(dim=0).sqrt().max()
    eps = torch.finfo(flow.dtype).eps
    return _normalized_flow_to_image(flow / (largest_norm + eps))
@torch.no_grad()
def _normalized_flow_to_image(normalized_flow: torch.Tensor) -> torch.Tensor:
    """
    Converts a normalized flow to an RGB image.

    Args:
        normalized_flow (torch.Tensor): Normalized flow tensor of shape (2, H, W)
            whose per-pixel magnitude is assumed to be <= 1.
    Returns:
        img (Tensor(3, H, W)): Flow visualization image of dtype uint8, on the
            same device as the input.
    """
    _, H, W = normalized_flow.shape
    device = normalized_flow.device
    # Allocate the output and the colorwheel on the input's device: with the
    # previous CPU-only allocation, a CUDA input failed when indexing the
    # wheel with CUDA indices and when assigning into the output below.
    flow_image = torch.zeros((3, H, W), dtype=torch.uint8, device=device)
    colorwheel = _make_colorwheel().to(device)  # shape [55x3]
    num_cols = colorwheel.shape[0]
    # Per-pixel magnitude and angle; the angle (mapped from [-pi, pi] to
    # [0, 1]) selects a position on the colorwheel.
    norm = torch.sum(normalized_flow ** 2, dim=0).sqrt()
    a = torch.atan2(-normalized_flow[1], -normalized_flow[0]) / torch.pi
    fk = (a + 1) / 2 * (num_cols - 1)
    # k0/k1 are the two neighboring wheel entries; f is the interpolation weight.
    k0 = torch.floor(fk).to(torch.long)
    k1 = k0 + 1
    k1[k1 == num_cols] = 0  # wrap around the wheel
    f = fk - k0
    for c in range(colorwheel.shape[1]):
        tmp = colorwheel[:, c]
        col0 = tmp[k0] / 255.0
        col1 = tmp[k1] / 255.0
        col = (1 - f) * col0 + f * col1
        # Fade toward white as the magnitude shrinks (norm == 0 -> white).
        col = 1 - norm * (1 - col)
        flow_image[c, :, :] = torch.floor(255 * col)
    return flow_image
def _make_colorwheel() -> torch.Tensor:
"""
Generates a color wheel for optical flow visualization as presented in:
Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf.
Returns:
colorwheel (Tensor[55, 3]): Colorwheel Tensor.
"""
RY = 15
YG = 6
GC = 4
CB = 11
BM = 13
MR = 6
ncols = RY + YG + GC + CB + BM + MR
colorwheel = torch.zeros((ncols, 3))
col = 0
# RY
colorwheel[0:RY, 0] = 255
colorwheel[0:RY, 1] = torch.floor(255 * torch.arange(0, RY) / RY)
col = col + RY
# YG
colorwheel[col : col + YG, 0] = 255 - torch.floor(255 * torch.arange(0, YG) / YG)
colorwheel[col : col + YG, 1] = 255
col = col + YG
# GC
colorwheel[col : col + GC, 1] = 255
colorwheel[col : col + GC, 2] = torch.floor(255 * torch.arange(0, GC) / GC)
col = col + GC
# CB
colorwheel[col : col + CB, 1] = 255 - torch.floor(255 * torch.arange(CB) / CB)
colorwheel[col : col + CB, 2] = 255
col = col + CB
# BM
colorwheel[col : col + BM, 2] = 255
colorwheel[col : col + BM, 0] = torch.floor(255 * torch.arange(0, BM) / BM)
col = col + BM
# MR
colorwheel[col : col + MR, 2] = 255 - torch.floor(255 * torch.arange(MR) / MR)
colorwheel[col : col + MR, 0] = 255
return colorwheel
def _generate_color_palette(num_masks: int): def _generate_color_palette(num_masks: int):
palette = torch.tensor([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1]) palette = torch.tensor([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1])
return [tuple((i * palette) % 255) for i in range(num_masks)] return [tuple((i * palette) % 255) for i in range(num_masks)]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment