Unverified Commit 015eb46b authored by Nicolas Hug's avatar Nicolas Hug Committed by GitHub
Browse files

Pass indexing param to meshgrid to avoid warning (#4645)

parent 29dcf767
......@@ -23,7 +23,7 @@ VIDEO_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "
def _create_video_frames(num_frames, height, width):
y, x = torch.meshgrid(torch.linspace(-2, 2, height), torch.linspace(-2, 2, width))
y, x = torch.meshgrid(torch.linspace(-2, 2, height), torch.linspace(-2, 2, width), indexing="ij")
data = []
for i in range(num_frames):
xc = float(i) / num_frames
......
......@@ -104,7 +104,7 @@ class AnchorGenerator(nn.Module):
# For output anchor, compute [x_center, y_center, x_center, y_center]
shifts_x = torch.arange(0, grid_width, dtype=torch.int32, device=device) * stride_width
shifts_y = torch.arange(0, grid_height, dtype=torch.int32, device=device) * stride_height
shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x)
shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x, indexing="ij")
shift_x = shift_x.reshape(-1)
shift_y = shift_y.reshape(-1)
shifts = torch.stack((shift_x, shift_y, shift_x, shift_y), dim=1)
......@@ -222,7 +222,7 @@ class DefaultBoxGenerator(nn.Module):
shifts_x = ((torch.arange(0, f_k[1]) + 0.5) / x_f_k).to(dtype=dtype)
shifts_y = ((torch.arange(0, f_k[0]) + 0.5) / y_f_k).to(dtype=dtype)
shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x)
shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x, indexing="ij")
shift_x = shift_x.reshape(-1)
shift_y = shift_y.reshape(-1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment