Unverified Commit a21ed3af authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

Make Inception V3 torch-scriptable (#2976)

* Making quantized inception torchscriptable.

* Adding a test.

* Fix mypy warning.
parent d6bc625a
......@@ -339,6 +339,18 @@ class ModelTester(TestCase):
self.assertFalse(any(isinstance(x, nn.BatchNorm2d) for x in model.modules()))
self.assertTrue(any(isinstance(x, nn.GroupNorm) for x in model.modules()))
def test_inceptionv3_eval(self):
    """Scripting an eval-mode Inception3 (aux classifier removed) must succeed."""
    # Construct Inception3 directly instead of models.inception_v3(pretrained=True)
    # so the test does not download pretrained weights.
    model = models.Inception3(
        transform_input=True,
        aux_logits=True,
        init_weights=False,
    )
    # Strip the auxiliary classifier, matching what inception_v3() does when
    # the caller did not request aux logits on a pretrained model.
    model.aux_logits = False
    model.AuxLogits = None
    scripted = torch.jit.script(model.eval())
    self.checkModule(scripted, "inception_v3", torch.rand(1, 3, 299, 299))
def test_fasterrcnn_double(self):
model = models.detection.fasterrcnn_resnet50_fpn(num_classes=50, pretrained_backbone=False)
model.double()
......
......@@ -55,7 +55,7 @@ def inception_v3(pretrained: bool = False, progress: bool = True, **kwargs: Any)
model.load_state_dict(state_dict)
if not original_aux_logits:
model.aux_logits = False
del model.AuxLogits
model.AuxLogits = None
return model
return Inception3(**kwargs)
......@@ -108,6 +108,7 @@ class Inception3(nn.Module):
self.Mixed_6c = inception_c(768, channels_7x7=160)
self.Mixed_6d = inception_c(768, channels_7x7=160)
self.Mixed_6e = inception_c(768, channels_7x7=192)
self.AuxLogits: Optional[nn.Module] = None
if aux_logits:
self.AuxLogits = inception_aux(768, num_classes)
self.Mixed_7a = inception_d(768)
......@@ -170,11 +171,10 @@ class Inception3(nn.Module):
# N x 768 x 17 x 17
x = self.Mixed_6e(x)
# N x 768 x 17 x 17
aux_defined = self.training and self.aux_logits
if aux_defined:
aux = self.AuxLogits(x)
else:
aux = None
aux = torch.jit.annotate(Optional[Tensor], None)
if self.AuxLogits is not None:
if self.training:
aux = self.AuxLogits(x)
# N x 768 x 17 x 17
x = self.Mixed_7a(x)
# N x 1280 x 8 x 8
......
......@@ -67,7 +67,7 @@ def inception_v3(pretrained=False, progress=True, quantize=False, **kwargs):
if quantize:
if not original_aux_logits:
model.aux_logits = False
del model.AuxLogits
model.AuxLogits = None
model_url = quant_model_urls['inception_v3_google' + '_' + backend]
else:
model_url = inception_module.model_urls['inception_v3_google']
......@@ -80,7 +80,7 @@ def inception_v3(pretrained=False, progress=True, quantize=False, **kwargs):
if not quantize:
if not original_aux_logits:
model.aux_logits = False
del model.AuxLogits
model.AuxLogits = None
return model
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment