"...text-generation-inference.git" did not exist on "c2d4a3b5c7bb6a8367c00f7c797bf87f4b2fcef9"
Unverified Commit 9fa8000d authored by Nicolas Hug, committed by GitHub
Browse files

Add support for flow batches in flow_to_image (#5308)

parent 8e874ff8
......@@ -317,29 +317,42 @@ def test_draw_keypoints_errors():
utils.draw_keypoints(image=img, keypoints=invalid_keypoints)
@pytest.mark.parametrize("batch", (True, False))
def test_flow_to_image(batch):
    """Check ``utils.flow_to_image`` against the stored reference image.

    A synthetic flow field of shape (2, h, w) is built from a centred pixel
    grid. When ``batch`` is true the same flow is stacked twice to form a
    (2, 2, h, w) batch, and the reference image is stacked the same way.
    """
    h, w = 100, 100
    flow = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
    # Reverse the (row, col) grids so channel 0 varies along the width.
    flow = torch.stack(flow[::-1], dim=0).float()
    flow[0] -= h / 2
    flow[1] -= w / 2

    if batch:
        flow = torch.stack([flow, flow])

    img = utils.flow_to_image(flow)
    # The conditional must be parenthesized: unparenthesized, the statement
    # parses as `assert (img.shape == (2, 3, h, w)) if batch else (3, h, w)`,
    # which asserts a non-empty (always truthy) tuple when batch is False and
    # therefore never checks the unbatched shape.
    assert img.shape == ((2, 3, h, w) if batch else (3, h, w))

    path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "expected_flow.pt")
    expected_img = torch.load(path, map_location="cpu")
    if batch:
        expected_img = torch.stack([expected_img, expected_img])

    assert_equal(expected_img, img)
@pytest.mark.parametrize(
    "input_flow, match",
    [
        # Float tensors whose shape is neither (2, H, W) nor (N, 2, H, W).
        *(
            (torch.zeros(bad_shape, dtype=torch.float), "Input flow should have shape")
            for bad_shape in ((3, 10, 10), (5, 3, 10, 10), (2, 10), (5, 2, 10))
        ),
        # Acceptable shape but a non-float dtype.
        (torch.zeros((2, 10, 30), dtype=torch.int), "Flow should be of dtype torch.float"),
    ],
)
def test_flow_to_image_errors(input_flow, match):
    """Each invalid flow must make ``utils.flow_to_image`` raise ``ValueError``.

    ``match`` is the pattern the error message is expected to contain.
    """
    with pytest.raises(ValueError, match=match):
        utils.flow_to_image(flow=input_flow)
if __name__ == "__main__":
......
......@@ -397,42 +397,51 @@ def flow_to_image(flow: torch.Tensor) -> torch.Tensor:
Converts a flow to an RGB image.
Args:
flow (Tensor): Flow of shape (2, H, W) and dtype torch.float.
flow (Tensor): Flow of shape (N, 2, H, W) or (2, H, W) and dtype torch.float.
Returns:
img (Tensor(3, H, W)): Image Tensor of dtype uint8 where each color corresponds to a given flow direction.
img (Tensor): Image Tensor of dtype uint8 where each color corresponds
to a given flow direction. Shape is (N, 3, H, W) or (3, H, W) depending on the input.
"""
if flow.dtype != torch.float:
raise ValueError(f"Flow should be of dtype torch.float, got {flow.dtype}.")
if flow.ndim != 3 or flow.size(0) != 2:
raise ValueError(f"Input flow should have shape (2, H, W), got {flow.shape}.")
orig_shape = flow.shape
if flow.ndim == 3:
flow = flow[None] # Add batch dim
max_norm = torch.sum(flow ** 2, dim=0).sqrt().max()
if flow.ndim != 4 or flow.shape[1] != 2:
raise ValueError(f"Input flow should have shape (2, H, W) or (N, 2, H, W), got {orig_shape}.")
max_norm = torch.sum(flow ** 2, dim=1).sqrt().max()
epsilon = torch.finfo((flow).dtype).eps
normalized_flow = flow / (max_norm + epsilon)
return _normalized_flow_to_image(normalized_flow)
img = _normalized_flow_to_image(normalized_flow)
if len(orig_shape) == 3:
img = img[0] # Remove batch dim
return img
@torch.no_grad()
def _normalized_flow_to_image(normalized_flow: torch.Tensor) -> torch.Tensor:
"""
Converts a normalized flow to an RGB image.
Converts a batch of normalized flow to an RGB image.
Args:
normalized_flow (torch.Tensor): Normalized flow tensor of shape (2, H, W)
normalized_flow (torch.Tensor): Normalized flow tensor of shape (N, 2, H, W)
Returns:
img (Tensor(3, H, W)): Flow visualization image of dtype uint8.
img (Tensor(N, 3, H, W)): Flow visualization image of dtype uint8.
"""
_, H, W = normalized_flow.shape
flow_image = torch.zeros((3, H, W), dtype=torch.uint8)
N, _, H, W = normalized_flow.shape
flow_image = torch.zeros((N, 3, H, W), dtype=torch.uint8)
colorwheel = _make_colorwheel() # shape [55x3]
num_cols = colorwheel.shape[0]
norm = torch.sum(normalized_flow ** 2, dim=0).sqrt()
a = torch.atan2(-normalized_flow[1], -normalized_flow[0]) / torch.pi
norm = torch.sum(normalized_flow ** 2, dim=1).sqrt()
a = torch.atan2(-normalized_flow[:, 1, :, :], -normalized_flow[:, 0, :, :]) / torch.pi
fk = (a + 1) / 2 * (num_cols - 1)
k0 = torch.floor(fk).to(torch.long)
k1 = k0 + 1
......@@ -445,7 +454,7 @@ def _normalized_flow_to_image(normalized_flow: torch.Tensor) -> torch.Tensor:
col1 = tmp[k1] / 255.0
col = (1 - f) * col0 + f * col1
col = 1 - norm * (1 - col)
flow_image[c, :, :] = torch.floor(255 * col)
flow_image[:, c, :, :] = torch.floor(255 * col)
return flow_image
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment