Commit efd14a44 authored by Baumgartner, Michael's avatar Baumgartner, Michael
Browse files

pt updates

parent 6c286a84
......@@ -146,7 +146,7 @@ class AnchorGenerator2D(torch.nn.Module):
shifts_x = torch.arange(0, size0, dtype=torch.float, device=device) * stride0
shifts_y = torch.arange(0, size1, dtype=torch.float, device=device) * stride1
shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x)
shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x, indexing="ij")
shift_x = shift_x.reshape(-1)
shift_y = shift_y.reshape(-1)
shifts = torch.stack((shift_x, shift_y, shift_x, shift_y), dim=1)
......@@ -361,7 +361,7 @@ class AnchorGenerator3D(AnchorGenerator2D):
shifts_y = torch.arange(0, size1, dtype=dtype, device=device) * stride1
shifts_z = torch.arange(0, size2, dtype=dtype, device=device) * stride2
shift_x, shift_y, shift_z = torch.meshgrid(shifts_x, shifts_y, shifts_z)
shift_x, shift_y, shift_z = torch.meshgrid(shifts_x, shifts_y, shifts_z, indexing="ij")
shift_x = shift_x.reshape(-1)
shift_y = shift_y.reshape(-1)
shift_z = shift_z.reshape(-1)
......
......@@ -364,7 +364,7 @@ class BaseRetinaNet(AbstractModel):
keep_idxs = probs > self.score_thresh
probs, idx = probs[keep_idxs], idx[keep_idxs]
anchor_idxs = idx // self.num_foreground_classes
anchor_idxs = torch.div(idx, self.num_foreground_classes, rounding_mode="floor")
labels = idx % self.num_foreground_classes
boxes = boxes[anchor_idxs]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment