Unverified Commit 5d5d425d authored by Francisco Massa's avatar Francisco Massa Committed by GitHub
Browse files

Fix anchor dtype in AnchorGenerator (#1341)

* Make AnchorGenerator support half precision

* Add test for fasterrcnn with double

* convert gt_boxes to right dtype
parent d7e88fb2
...@@ -146,6 +146,20 @@ class Tester(unittest.TestCase): ...@@ -146,6 +146,20 @@ class Tester(unittest.TestCase):
out = model(x) out = model(x)
self.assertEqual(out.shape[-1], 1000) self.assertEqual(out.shape[-1], 1000)
def test_fasterrcnn_double(self):
    """Run Faster R-CNN end-to-end in float64 and sanity-check the output.

    Verifies that a model converted with ``model.double()`` accepts a
    ``torch.float64`` image and produces one prediction dict containing
    the standard detection keys.
    """
    # Untrained detector (no pretrained backbone download), converted to
    # double precision and put in inference mode.
    model = models.detection.fasterrcnn_resnet50_fpn(num_classes=50, pretrained_backbone=False)
    model.double()
    model.eval()

    image = torch.rand((3, 300, 300), dtype=torch.float64)
    batch = [image]
    predictions = model(batch)

    # The forward pass must not replace the caller's tensor inside the
    # input list (the model works on internal copies).
    self.assertIs(batch[0], image)
    # One input image -> exactly one prediction dict with the usual keys.
    self.assertEqual(len(predictions), 1)
    for key in ("boxes", "scores", "labels"):
        self.assertTrue(key in predictions[0])
for model_name in get_available_classification_models(): for model_name in get_available_classification_models():
# for-loop bodies don't define scopes, so we have to save the variables # for-loop bodies don't define scopes, so we have to save the variables
......
...@@ -444,7 +444,8 @@ class RoIHeads(torch.nn.Module): ...@@ -444,7 +444,8 @@ class RoIHeads(torch.nn.Module):
def select_training_samples(self, proposals, targets): def select_training_samples(self, proposals, targets):
self.check_targets(targets) self.check_targets(targets)
gt_boxes = [t["boxes"] for t in targets] dtype = proposals[0].dtype
gt_boxes = [t["boxes"].to(dtype) for t in targets]
gt_labels = [t["labels"] for t in targets] gt_labels = [t["labels"] for t in targets]
# append ground-truth bboxes to propos # append ground-truth bboxes to propos
......
...@@ -49,9 +49,9 @@ class AnchorGenerator(nn.Module): ...@@ -49,9 +49,9 @@ class AnchorGenerator(nn.Module):
self._cache = {} self._cache = {}
@staticmethod @staticmethod
def generate_anchors(scales, aspect_ratios, device="cpu"): def generate_anchors(scales, aspect_ratios, dtype=torch.float32, device="cpu"):
scales = torch.as_tensor(scales, dtype=torch.float32, device=device) scales = torch.as_tensor(scales, dtype=dtype, device=device)
aspect_ratios = torch.as_tensor(aspect_ratios, dtype=torch.float32, device=device) aspect_ratios = torch.as_tensor(aspect_ratios, dtype=dtype, device=device)
h_ratios = torch.sqrt(aspect_ratios) h_ratios = torch.sqrt(aspect_ratios)
w_ratios = 1 / h_ratios w_ratios = 1 / h_ratios
...@@ -61,13 +61,14 @@ class AnchorGenerator(nn.Module): ...@@ -61,13 +61,14 @@ class AnchorGenerator(nn.Module):
base_anchors = torch.stack([-ws, -hs, ws, hs], dim=1) / 2 base_anchors = torch.stack([-ws, -hs, ws, hs], dim=1) / 2
return base_anchors.round() return base_anchors.round()
def set_cell_anchors(self, device): def set_cell_anchors(self, dtype, device):
if self.cell_anchors is not None: if self.cell_anchors is not None:
return self.cell_anchors return self.cell_anchors
cell_anchors = [ cell_anchors = [
self.generate_anchors( self.generate_anchors(
sizes, sizes,
aspect_ratios, aspect_ratios,
dtype,
device device
) )
for sizes, aspect_ratios in zip(self.sizes, self.aspect_ratios) for sizes, aspect_ratios in zip(self.sizes, self.aspect_ratios)
...@@ -114,7 +115,8 @@ class AnchorGenerator(nn.Module): ...@@ -114,7 +115,8 @@ class AnchorGenerator(nn.Module):
grid_sizes = tuple([feature_map.shape[-2:] for feature_map in feature_maps]) grid_sizes = tuple([feature_map.shape[-2:] for feature_map in feature_maps])
image_size = image_list.tensors.shape[-2:] image_size = image_list.tensors.shape[-2:]
strides = tuple((image_size[0] / g[0], image_size[1] / g[1]) for g in grid_sizes) strides = tuple((image_size[0] / g[0], image_size[1] / g[1]) for g in grid_sizes)
self.set_cell_anchors(feature_maps[0].device) dtype, device = feature_maps[0].dtype, feature_maps[0].device
self.set_cell_anchors(dtype, device)
anchors_over_all_feature_maps = self.cached_grid_anchors(grid_sizes, strides) anchors_over_all_feature_maps = self.cached_grid_anchors(grid_sizes, strides)
anchors = [] anchors = []
for i, (image_height, image_width) in enumerate(image_list.image_sizes): for i, (image_height, image_width) in enumerate(image_list.image_sizes):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment