Unverified Commit 66f29220 authored by Francisco Massa's avatar Francisco Massa Committed by GitHub
Browse files

Fix for AnchorGenerator when device switch happen (#1745)

* Fix AnchorGenerator if moving from one device to another

* Fixes for the test
parent 7bfbc81a
......@@ -226,6 +226,28 @@ class ModelTester(TestCase):
self.assertTrue("scores" in out[0])
self.assertTrue("labels" in out[0])
@unittest.skipIf(not torch.cuda.is_available(), 'needs GPU')
def test_fasterrcnn_switch_devices(self):
model = models.detection.fasterrcnn_resnet50_fpn(num_classes=50, pretrained_backbone=False)
model.cuda()
model.eval()
input_shape = (3, 300, 300)
x = torch.rand(input_shape, device='cuda')
model_input = [x]
out = model(model_input)
self.assertIs(model_input[0], x)
self.assertEqual(len(out), 1)
self.assertTrue("boxes" in out[0])
self.assertTrue("scores" in out[0])
self.assertTrue("labels" in out[0])
# now switch to cpu and make sure it works
model.cpu()
x = x.cpu()
out_cpu = model([x])
self.assertTrue("boxes" in out_cpu[0])
self.assertTrue("scores" in out_cpu[0])
self.assertTrue("labels" in out_cpu[0])
for model_name in get_available_classification_models():
# for-loop bodies don't define scopes, so we have to save the variables
......
......@@ -90,6 +90,11 @@ class AnchorGenerator(nn.Module):
def set_cell_anchors(self, dtype, device):
# type: (int, Device) -> None # noqa: F821
if self.cell_anchors is not None:
cell_anchors = self.cell_anchors
assert cell_anchors is not None
# suppose that all anchors have the same device
# which is a valid assumption in the current state of the codebase
if cell_anchors[0].device == device:
return
cell_anchors = [
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment