Unverified Commit 9fa8000d authored by Nicolas Hug's avatar Nicolas Hug Committed by GitHub
Browse files

Add support for flow batches in flow_to_image (#5308)

parent 8e874ff8
...@@ -317,29 +317,42 @@ def test_draw_keypoints_errors(): ...@@ -317,29 +317,42 @@ def test_draw_keypoints_errors():
utils.draw_keypoints(image=img, keypoints=invalid_keypoints) utils.draw_keypoints(image=img, keypoints=invalid_keypoints)
def test_flow_to_image(): @pytest.mark.parametrize("batch", (True, False))
def test_flow_to_image(batch):
h, w = 100, 100 h, w = 100, 100
flow = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij") flow = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
flow = torch.stack(flow[::-1], dim=0).float() flow = torch.stack(flow[::-1], dim=0).float()
flow[0] -= h / 2 flow[0] -= h / 2
flow[1] -= w / 2 flow[1] -= w / 2
if batch:
flow = torch.stack([flow, flow])
img = utils.flow_to_image(flow) img = utils.flow_to_image(flow)
assert img.shape == (2, 3, h, w) if batch else (3, h, w)
path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "expected_flow.pt") path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "expected_flow.pt")
expected_img = torch.load(path, map_location="cpu") expected_img = torch.load(path, map_location="cpu")
assert_equal(expected_img, img)
if batch:
expected_img = torch.stack([expected_img, expected_img])
assert_equal(expected_img, img)
def test_flow_to_image_errors():
wrong_flow1 = torch.full((3, 10, 10), 0, dtype=torch.float)
wrong_flow2 = torch.full((2, 10), 0, dtype=torch.float)
wrong_flow3 = torch.full((2, 10, 30), 0, dtype=torch.int)
with pytest.raises(ValueError, match="Input flow should have shape"): @pytest.mark.parametrize(
utils.flow_to_image(flow=wrong_flow1) "input_flow, match",
with pytest.raises(ValueError, match="Input flow should have shape"): (
utils.flow_to_image(flow=wrong_flow2) (torch.full((3, 10, 10), 0, dtype=torch.float), "Input flow should have shape"),
with pytest.raises(ValueError, match="Flow should be of dtype torch.float"): (torch.full((5, 3, 10, 10), 0, dtype=torch.float), "Input flow should have shape"),
utils.flow_to_image(flow=wrong_flow3) (torch.full((2, 10), 0, dtype=torch.float), "Input flow should have shape"),
(torch.full((5, 2, 10), 0, dtype=torch.float), "Input flow should have shape"),
(torch.full((2, 10, 30), 0, dtype=torch.int), "Flow should be of dtype torch.float"),
),
)
def test_flow_to_image_errors(input_flow, match):
with pytest.raises(ValueError, match=match):
utils.flow_to_image(flow=input_flow)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -397,42 +397,51 @@ def flow_to_image(flow: torch.Tensor) -> torch.Tensor: ...@@ -397,42 +397,51 @@ def flow_to_image(flow: torch.Tensor) -> torch.Tensor:
Converts a flow to an RGB image. Converts a flow to an RGB image.
Args: Args:
flow (Tensor): Flow of shape (2, H, W) and dtype torch.float. flow (Tensor): Flow of shape (N, 2, H, W) or (2, H, W) and dtype torch.float.
Returns: Returns:
img (Tensor(3, H, W)): Image Tensor of dtype uint8 where each color corresponds to a given flow direction. img (Tensor): Image Tensor of dtype uint8 where each color corresponds
to a given flow direction. Shape is (N, 3, H, W) or (3, H, W) depending on the input.
""" """
if flow.dtype != torch.float: if flow.dtype != torch.float:
raise ValueError(f"Flow should be of dtype torch.float, got {flow.dtype}.") raise ValueError(f"Flow should be of dtype torch.float, got {flow.dtype}.")
if flow.ndim != 3 or flow.size(0) != 2: orig_shape = flow.shape
raise ValueError(f"Input flow should have shape (2, H, W), got {flow.shape}.") if flow.ndim == 3:
flow = flow[None] # Add batch dim
max_norm = torch.sum(flow ** 2, dim=0).sqrt().max() if flow.ndim != 4 or flow.shape[1] != 2:
raise ValueError(f"Input flow should have shape (2, H, W) or (N, 2, H, W), got {orig_shape}.")
max_norm = torch.sum(flow ** 2, dim=1).sqrt().max()
epsilon = torch.finfo((flow).dtype).eps epsilon = torch.finfo((flow).dtype).eps
normalized_flow = flow / (max_norm + epsilon) normalized_flow = flow / (max_norm + epsilon)
return _normalized_flow_to_image(normalized_flow) img = _normalized_flow_to_image(normalized_flow)
if len(orig_shape) == 3:
img = img[0] # Remove batch dim
return img
@torch.no_grad() @torch.no_grad()
def _normalized_flow_to_image(normalized_flow: torch.Tensor) -> torch.Tensor: def _normalized_flow_to_image(normalized_flow: torch.Tensor) -> torch.Tensor:
""" """
Converts a normalized flow to an RGB image. Converts a batch of normalized flow to an RGB image.
Args: Args:
normalized_flow (torch.Tensor): Normalized flow tensor of shape (2, H, W) normalized_flow (torch.Tensor): Normalized flow tensor of shape (N, 2, H, W)
Returns: Returns:
img (Tensor(3, H, W)): Flow visualization image of dtype uint8. img (Tensor(N, 3, H, W)): Flow visualization image of dtype uint8.
""" """
_, H, W = normalized_flow.shape N, _, H, W = normalized_flow.shape
flow_image = torch.zeros((3, H, W), dtype=torch.uint8) flow_image = torch.zeros((N, 3, H, W), dtype=torch.uint8)
colorwheel = _make_colorwheel() # shape [55x3] colorwheel = _make_colorwheel() # shape [55x3]
num_cols = colorwheel.shape[0] num_cols = colorwheel.shape[0]
norm = torch.sum(normalized_flow ** 2, dim=0).sqrt() norm = torch.sum(normalized_flow ** 2, dim=1).sqrt()
a = torch.atan2(-normalized_flow[1], -normalized_flow[0]) / torch.pi a = torch.atan2(-normalized_flow[:, 1, :, :], -normalized_flow[:, 0, :, :]) / torch.pi
fk = (a + 1) / 2 * (num_cols - 1) fk = (a + 1) / 2 * (num_cols - 1)
k0 = torch.floor(fk).to(torch.long) k0 = torch.floor(fk).to(torch.long)
k1 = k0 + 1 k1 = k0 + 1
...@@ -445,7 +454,7 @@ def _normalized_flow_to_image(normalized_flow: torch.Tensor) -> torch.Tensor: ...@@ -445,7 +454,7 @@ def _normalized_flow_to_image(normalized_flow: torch.Tensor) -> torch.Tensor:
col1 = tmp[k1] / 255.0 col1 = tmp[k1] / 255.0
col = (1 - f) * col0 + f * col1 col = (1 - f) * col0 + f * col1
col = 1 - norm * (1 - col) col = 1 - norm * (1 - col)
flow_image[c, :, :] = torch.floor(255 * col) flow_image[:, c, :, :] = torch.floor(255 * col)
return flow_image return flow_image
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment