Commit efd14a44 authored by Baumgartner, Michael's avatar Baumgartner, Michael
Browse files

pt updates

parent 6c286a84
...@@ -146,7 +146,7 @@ class AnchorGenerator2D(torch.nn.Module): ...@@ -146,7 +146,7 @@ class AnchorGenerator2D(torch.nn.Module):
shifts_x = torch.arange(0, size0, dtype=torch.float, device=device) * stride0 shifts_x = torch.arange(0, size0, dtype=torch.float, device=device) * stride0
shifts_y = torch.arange(0, size1, dtype=torch.float, device=device) * stride1 shifts_y = torch.arange(0, size1, dtype=torch.float, device=device) * stride1
shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x) shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x, indexing="ij")
shift_x = shift_x.reshape(-1) shift_x = shift_x.reshape(-1)
shift_y = shift_y.reshape(-1) shift_y = shift_y.reshape(-1)
shifts = torch.stack((shift_x, shift_y, shift_x, shift_y), dim=1) shifts = torch.stack((shift_x, shift_y, shift_x, shift_y), dim=1)
...@@ -361,7 +361,7 @@ class AnchorGenerator3D(AnchorGenerator2D): ...@@ -361,7 +361,7 @@ class AnchorGenerator3D(AnchorGenerator2D):
shifts_y = torch.arange(0, size1, dtype=dtype, device=device) * stride1 shifts_y = torch.arange(0, size1, dtype=dtype, device=device) * stride1
shifts_z = torch.arange(0, size2, dtype=dtype, device=device) * stride2 shifts_z = torch.arange(0, size2, dtype=dtype, device=device) * stride2
shift_x, shift_y, shift_z = torch.meshgrid(shifts_x, shifts_y, shifts_z) shift_x, shift_y, shift_z = torch.meshgrid(shifts_x, shifts_y, shifts_z, indexing="ij")
shift_x = shift_x.reshape(-1) shift_x = shift_x.reshape(-1)
shift_y = shift_y.reshape(-1) shift_y = shift_y.reshape(-1)
shift_z = shift_z.reshape(-1) shift_z = shift_z.reshape(-1)
......
...@@ -364,7 +364,7 @@ class BaseRetinaNet(AbstractModel): ...@@ -364,7 +364,7 @@ class BaseRetinaNet(AbstractModel):
keep_idxs = probs > self.score_thresh keep_idxs = probs > self.score_thresh
probs, idx = probs[keep_idxs], idx[keep_idxs] probs, idx = probs[keep_idxs], idx[keep_idxs]
anchor_idxs = idx // self.num_foreground_classes anchor_idxs = torch.div(idx, self.num_foreground_classes, rounding_mode="floor")
labels = idx % self.num_foreground_classes labels = idx % self.num_foreground_classes
boxes = boxes[anchor_idxs] boxes = boxes[anchor_idxs]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment