Unverified Commit de140c19 authored by eellison's avatar eellison Committed by GitHub
Browse files

[JIT] fix googlenet no aux logits (#1949)



* fix googlenet no aux logits

* small fix
Co-authored-by: default avatareellison <eellison@fb.com>
parent b6f28ec1
......@@ -226,6 +226,10 @@ class ModelTester(TestCase):
self.assertTrue("scores" in out[0])
self.assertTrue("labels" in out[0])
def test_googlenet_eval(self):
m = torch.jit.script(models.googlenet(pretrained=True).eval())
self.checkModule(m, "googlenet", torch.rand(1, 3, 224, 224))
@unittest.skipIf(not torch.cuda.is_available(), 'needs GPU')
def test_fasterrcnn_switch_devices(self):
model = models.detection.fasterrcnn_resnet50_fpn(num_classes=50, pretrained_backbone=False)
......
......@@ -54,7 +54,8 @@ def googlenet(pretrained=False, progress=True, **kwargs):
model.load_state_dict(state_dict)
if not original_aux_logits:
model.aux_logits = False
del model.aux1, model.aux2
model.aux1 = None
model.aux2 = None
return model
return GoogLeNet(**kwargs)
......@@ -99,6 +100,9 @@ class GoogLeNet(nn.Module):
if aux_logits:
self.aux1 = inception_aux_block(512, num_classes)
self.aux2 = inception_aux_block(528, num_classes)
else:
self.aux1 = None
self.aux2 = None
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.dropout = nn.Dropout(0.2)
......@@ -151,11 +155,10 @@ class GoogLeNet(nn.Module):
# N x 480 x 14 x 14
x = self.inception4a(x)
# N x 512 x 14 x 14
aux_defined = self.training and self.aux_logits
if aux_defined:
aux1 = torch.jit.annotate(Optional[Tensor], None)
if self.aux1 is not None:
if self.training:
aux1 = self.aux1(x)
else:
aux1 = None
x = self.inception4b(x)
# N x 512 x 14 x 14
......@@ -163,10 +166,10 @@ class GoogLeNet(nn.Module):
# N x 512 x 14 x 14
x = self.inception4d(x)
# N x 528 x 14 x 14
if aux_defined:
aux2 = torch.jit.annotate(Optional[Tensor], None)
if self.aux2 is not None:
if self.training:
aux2 = self.aux2(x)
else:
aux2 = None
x = self.inception4e(x)
# N x 832 x 14 x 14
......@@ -208,7 +211,6 @@ class GoogLeNet(nn.Module):
class Inception(nn.Module):
__constants__ = ['branch2', 'branch3', 'branch4']
def __init__(self, in_channels, ch1x1, ch3x3red, ch3x3, ch5x5red, ch5x5, pool_proj,
conv_block=None):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment