You need to sign in or sign up before continuing.
Unverified Commit 66f29220 authored by Francisco Massa's avatar Francisco Massa Committed by GitHub
Browse files

Fix for AnchorGenerator when device switch happen (#1745)

* Fix AnchorGenerator if moving from one device to another

* Fixes for the test
parent 7bfbc81a
......@@ -226,6 +226,28 @@ class ModelTester(TestCase):
self.assertTrue("scores" in out[0])
self.assertTrue("labels" in out[0])
@unittest.skipIf(not torch.cuda.is_available(), 'needs GPU')
def test_fasterrcnn_switch_devices(self):
model = models.detection.fasterrcnn_resnet50_fpn(num_classes=50, pretrained_backbone=False)
model.cuda()
model.eval()
input_shape = (3, 300, 300)
x = torch.rand(input_shape, device='cuda')
model_input = [x]
out = model(model_input)
self.assertIs(model_input[0], x)
self.assertEqual(len(out), 1)
self.assertTrue("boxes" in out[0])
self.assertTrue("scores" in out[0])
self.assertTrue("labels" in out[0])
# now switch to cpu and make sure it works
model.cpu()
x = x.cpu()
out_cpu = model([x])
self.assertTrue("boxes" in out_cpu[0])
self.assertTrue("scores" in out_cpu[0])
self.assertTrue("labels" in out_cpu[0])
for model_name in get_available_classification_models():
# for-loop bodies don't define scopes, so we have to save the variables
......
......@@ -90,7 +90,12 @@ class AnchorGenerator(nn.Module):
def set_cell_anchors(self, dtype, device):
# type: (int, Device) -> None # noqa: F821
if self.cell_anchors is not None:
return
cell_anchors = self.cell_anchors
assert cell_anchors is not None
# suppose that all anchors have the same device
# which is a valid assumption in the current state of the codebase
if cell_anchors[0].device == device:
return
cell_anchors = [
self.generate_anchors(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment