Unverified Commit 015eb46b authored by Nicolas Hug's avatar Nicolas Hug Committed by GitHub
Browse files

Pass indexing param to meshgrid to avoid warning (#4645)

parent 29dcf767
...@@ -23,7 +23,7 @@ VIDEO_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", " ...@@ -23,7 +23,7 @@ VIDEO_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "
def _create_video_frames(num_frames, height, width): def _create_video_frames(num_frames, height, width):
y, x = torch.meshgrid(torch.linspace(-2, 2, height), torch.linspace(-2, 2, width)) y, x = torch.meshgrid(torch.linspace(-2, 2, height), torch.linspace(-2, 2, width), indexing="ij")
data = [] data = []
for i in range(num_frames): for i in range(num_frames):
xc = float(i) / num_frames xc = float(i) / num_frames
......
...@@ -104,7 +104,7 @@ class AnchorGenerator(nn.Module): ...@@ -104,7 +104,7 @@ class AnchorGenerator(nn.Module):
# For output anchor, compute [x_center, y_center, x_center, y_center] # For output anchor, compute [x_center, y_center, x_center, y_center]
shifts_x = torch.arange(0, grid_width, dtype=torch.int32, device=device) * stride_width shifts_x = torch.arange(0, grid_width, dtype=torch.int32, device=device) * stride_width
shifts_y = torch.arange(0, grid_height, dtype=torch.int32, device=device) * stride_height shifts_y = torch.arange(0, grid_height, dtype=torch.int32, device=device) * stride_height
shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x) shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x, indexing="ij")
shift_x = shift_x.reshape(-1) shift_x = shift_x.reshape(-1)
shift_y = shift_y.reshape(-1) shift_y = shift_y.reshape(-1)
shifts = torch.stack((shift_x, shift_y, shift_x, shift_y), dim=1) shifts = torch.stack((shift_x, shift_y, shift_x, shift_y), dim=1)
...@@ -222,7 +222,7 @@ class DefaultBoxGenerator(nn.Module): ...@@ -222,7 +222,7 @@ class DefaultBoxGenerator(nn.Module):
shifts_x = ((torch.arange(0, f_k[1]) + 0.5) / x_f_k).to(dtype=dtype) shifts_x = ((torch.arange(0, f_k[1]) + 0.5) / x_f_k).to(dtype=dtype)
shifts_y = ((torch.arange(0, f_k[0]) + 0.5) / y_f_k).to(dtype=dtype) shifts_y = ((torch.arange(0, f_k[0]) + 0.5) / y_f_k).to(dtype=dtype)
shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x) shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x, indexing="ij")
shift_x = shift_x.reshape(-1) shift_x = shift_x.reshape(-1)
shift_y = shift_y.reshape(-1) shift_y = shift_y.reshape(-1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment