Commit 50ea596e authored by Sepehr Sameni's avatar Sepehr Sameni Committed by Francisco Massa
Browse files

make auxiliary heads in pretrained models optional (#828)

* add aux_logits support to inception

it is related to pytorch/pytorch#18668

* instantiate InceptionAux only when requested

it is related to pytorch/pytorch#18668

* revert googlenet

* support aux_logits in pretrained models

* return namedtuple when aux_logit is True
parent f566fac8
import warnings
from collections import namedtuple
import torch
import torch.nn as nn
import torch.nn.functional as F
......@@ -10,6 +12,8 @@ model_urls = {
'googlenet': 'https://download.pytorch.org/models/googlenet-1378be20.pth',
}
# Output container returned by GoogLeNet.forward during training when aux_logits
# is enabled: the main classifier logits plus the two auxiliary-head logits.
# NOTE(review): 'Ouputs' is a typo in the original name; it is kept as-is because
# other code in this file references _GoogLeNetOuputs by this exact spelling.
_GoogLeNetOuputs = namedtuple('GoogLeNetOuputs', ['logits', 'aux_logits2', 'aux_logits1'])
def googlenet(pretrained=False, **kwargs):
r"""GoogLeNet (Inception v1) model architecture from
......@@ -18,7 +22,7 @@ def googlenet(pretrained=False, **kwargs):
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
aux_logits (bool): If True, adds two auxiliary branches that can improve training.
Automatically set to False if 'pretrained' is True. Default: *True*
Default: *False* when pretrained is True otherwise *True*
transform_input (bool): If True, preprocesses the input according to the method with which it
was trained on ImageNet. Default: *False*
"""
......@@ -27,9 +31,16 @@ def googlenet(pretrained=False, **kwargs):
kwargs['transform_input'] = True
if 'aux_logits' not in kwargs:
kwargs['aux_logits'] = False
if kwargs['aux_logits']:
warnings.warn('auxiliary heads in the pretrained googlenet model are NOT pretrained, so make sure to train them')
original_aux_logits = kwargs['aux_logits']
kwargs['aux_logits'] = True
kwargs['init_weights'] = False
model = GoogLeNet(**kwargs)
model.load_state_dict(model_zoo.load_url(model_urls['googlenet']))
if not original_aux_logits:
model.aux_logits = False
del model.aux1, model.aux2
return model
return GoogLeNet(**kwargs)
......@@ -62,8 +73,9 @@ class GoogLeNet(nn.Module):
self.inception5a = Inception(832, 256, 160, 320, 32, 128, 128)
self.inception5b = Inception(832, 384, 192, 384, 48, 128, 128)
self.aux1 = InceptionAux(512, num_classes)
self.aux2 = InceptionAux(528, num_classes)
if aux_logits:
self.aux1 = InceptionAux(512, num_classes)
self.aux2 = InceptionAux(528, num_classes)
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.dropout = nn.Dropout(0.2)
......@@ -141,7 +153,7 @@ class GoogLeNet(nn.Module):
x = self.fc(x)
# N x 1000 (num_classes)
if self.training and self.aux_logits:
return aux1, aux2, x
return _GoogLeNetOuputs(x, aux2, aux1)
return x
......
from collections import namedtuple
import torch
import torch.nn as nn
import torch.nn.functional as F
......@@ -12,6 +13,8 @@ model_urls = {
'inception_v3_google': 'https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth',
}
# Output container returned by Inception3.forward during training when aux_logits
# is enabled: the main classifier logits plus the auxiliary-head logits.
# NOTE(review): 'Ouputs' is a typo in the original name; it is kept as-is because
# other code in this file references _InceptionOuputs by this exact spelling.
_InceptionOuputs = namedtuple('InceptionOuputs', ['logits', 'aux_logits'])
def inception_v3(pretrained=False, **kwargs):
r"""Inception v3 model architecture from
......@@ -23,14 +26,24 @@ def inception_v3(pretrained=False, **kwargs):
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
aux_logits (bool): If True, add an auxiliary branch that can improve training.
Default: *True*
transform_input (bool): If True, preprocesses the input according to the method with which it
was trained on ImageNet. Default: *False*
"""
if pretrained:
if 'transform_input' not in kwargs:
kwargs['transform_input'] = True
if 'aux_logits' in kwargs:
original_aux_logits = kwargs['aux_logits']
kwargs['aux_logits'] = True
else:
original_aux_logits = True
model = Inception3(**kwargs)
model.load_state_dict(model_zoo.load_url(model_urls['inception_v3_google']))
if not original_aux_logits:
model.aux_logits = False
del model.AuxLogits
return model
return Inception3(**kwargs)
......@@ -131,7 +144,7 @@ class Inception3(nn.Module):
x = self.fc(x)
# N x 1000 (num_classes)
if self.training and self.aux_logits:
return x, aux
return _InceptionOuputs(x, aux)
return x
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment