Unverified Commit 144f0980 authored by vfdev's avatar vfdev Committed by GitHub
Browse files

Fixes device mismatch issue while building docs (#5428)

parent 8bf46d4e
......@@ -449,8 +449,9 @@ def _normalized_flow_to_image(normalized_flow: torch.Tensor) -> torch.Tensor:
"""
N, _, H, W = normalized_flow.shape
flow_image = torch.zeros((N, 3, H, W), dtype=torch.uint8)
colorwheel = _make_colorwheel() # shape [55x3]
device = normalized_flow.device
flow_image = torch.zeros((N, 3, H, W), dtype=torch.uint8, device=device)
colorwheel = _make_colorwheel().to(device) # shape [55x3]
num_cols = colorwheel.shape[0]
norm = torch.sum(normalized_flow ** 2, dim=1).sqrt()
a = torch.atan2(-normalized_flow[:, 1, :, :], -normalized_flow[:, 0, :, :]) / torch.pi
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment