Unverified Commit 5d5d425d authored by Francisco Massa's avatar Francisco Massa Committed by GitHub
Browse files

Fix anchor dtype in AnchorGenerator (#1341)

* Make AnchorGenerator support half precision

* Add test for fasterrcnn with double

* convert gt_boxes to right dtype
parent d7e88fb2
......@@ -146,6 +146,20 @@ class Tester(unittest.TestCase):
out = model(x)
self.assertEqual(out.shape[-1], 1000)
def test_fasterrcnn_double(self):
model = models.detection.fasterrcnn_resnet50_fpn(num_classes=50, pretrained_backbone=False)
model.double()
model.eval()
input_shape = (3, 300, 300)
x = torch.rand(input_shape, dtype=torch.float64)
model_input = [x]
out = model(model_input)
self.assertIs(model_input[0], x)
self.assertEqual(len(out), 1)
self.assertTrue("boxes" in out[0])
self.assertTrue("scores" in out[0])
self.assertTrue("labels" in out[0])
for model_name in get_available_classification_models():
# for-loop bodies don't define scopes, so we have to save the variables
......
......@@ -444,7 +444,8 @@ class RoIHeads(torch.nn.Module):
def select_training_samples(self, proposals, targets):
self.check_targets(targets)
gt_boxes = [t["boxes"] for t in targets]
dtype = proposals[0].dtype
gt_boxes = [t["boxes"].to(dtype) for t in targets]
gt_labels = [t["labels"] for t in targets]
# append ground-truth bboxes to propos
......
......@@ -49,9 +49,9 @@ class AnchorGenerator(nn.Module):
self._cache = {}
@staticmethod
def generate_anchors(scales, aspect_ratios, device="cpu"):
scales = torch.as_tensor(scales, dtype=torch.float32, device=device)
aspect_ratios = torch.as_tensor(aspect_ratios, dtype=torch.float32, device=device)
def generate_anchors(scales, aspect_ratios, dtype=torch.float32, device="cpu"):
scales = torch.as_tensor(scales, dtype=dtype, device=device)
aspect_ratios = torch.as_tensor(aspect_ratios, dtype=dtype, device=device)
h_ratios = torch.sqrt(aspect_ratios)
w_ratios = 1 / h_ratios
......@@ -61,13 +61,14 @@ class AnchorGenerator(nn.Module):
base_anchors = torch.stack([-ws, -hs, ws, hs], dim=1) / 2
return base_anchors.round()
def set_cell_anchors(self, device):
def set_cell_anchors(self, dtype, device):
if self.cell_anchors is not None:
return self.cell_anchors
cell_anchors = [
self.generate_anchors(
sizes,
aspect_ratios,
dtype,
device
)
for sizes, aspect_ratios in zip(self.sizes, self.aspect_ratios)
......@@ -114,7 +115,8 @@ class AnchorGenerator(nn.Module):
grid_sizes = tuple([feature_map.shape[-2:] for feature_map in feature_maps])
image_size = image_list.tensors.shape[-2:]
strides = tuple((image_size[0] / g[0], image_size[1] / g[1]) for g in grid_sizes)
self.set_cell_anchors(feature_maps[0].device)
dtype, device = feature_maps[0].dtype, feature_maps[0].device
self.set_cell_anchors(dtype, device)
anchors_over_all_feature_maps = self.cached_grid_anchors(grid_sizes, strides)
anchors = []
for i, (image_height, image_width) in enumerate(image_list.image_sizes):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment