Commit 50ea596e authored by Sepehr Sameni, committed by Francisco Massa

make auxiliary heads in pretrained models optional (#828)

* add aux_logits support to inception

it is related to pytorch/pytorch#18668

* instantiate InceptionAux only when requested

it is related to pytorch/pytorch#18668

* revert googlenet

* support aux_logits in pretrained models

* return namedtuple when aux_logits is True
parent f566fac8
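For reference, a minimal sketch of how the changed loader behaves after this commit (assuming a torchvision build that includes it; the 224x224 input and random data below are only illustrative):

import torch
from torchvision import models

# Pretrained GoogLeNet now drops the two auxiliary heads by default
# (they are deleted after the checkpoint weights are loaded).
net = models.googlenet(pretrained=True)

# Opting back in keeps aux1/aux2, but the loader warns that they are
# NOT pretrained and must be trained before they are useful.
net_aux = models.googlenet(pretrained=True, aux_logits=True)
net_aux.train()
out = net_aux(torch.randn(2, 3, 224, 224))
# In training mode with aux_logits=True the forward pass now returns a
# namedtuple: out.logits, out.aux_logits2, out.aux_logits1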
torchvision/models/googlenet.py
+import warnings
+from collections import namedtuple
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -10,6 +12,8 @@ model_urls = {
     'googlenet': 'https://download.pytorch.org/models/googlenet-1378be20.pth',
 }
+_GoogLeNetOuputs = namedtuple('GoogLeNetOuputs', ['logits', 'aux_logits2', 'aux_logits1'])
 def googlenet(pretrained=False, **kwargs):
     r"""GoogLeNet (Inception v1) model architecture from
@@ -18,7 +22,7 @@ def googlenet(pretrained=False, **kwargs):
     Args:
         pretrained (bool): If True, returns a model pre-trained on ImageNet
         aux_logits (bool): If True, adds two auxiliary branches that can improve training.
-            Automatically set to False if 'pretrained' is True. Default: *True*
+            Default: *False* when pretrained is True otherwise *True*
         transform_input (bool): If True, preprocesses the input according to the method with which it
             was trained on ImageNet. Default: *False*
     """
@@ -27,9 +31,16 @@ def googlenet(pretrained=False, **kwargs):
             kwargs['transform_input'] = True
         if 'aux_logits' not in kwargs:
             kwargs['aux_logits'] = False
+        if kwargs['aux_logits']:
+            warnings.warn('auxiliary heads in the pretrained googlenet model are NOT pretrained, so make sure to train them')
+        original_aux_logits = kwargs['aux_logits']
+        kwargs['aux_logits'] = True
         kwargs['init_weights'] = False
         model = GoogLeNet(**kwargs)
         model.load_state_dict(model_zoo.load_url(model_urls['googlenet']))
+        if not original_aux_logits:
+            model.aux_logits = False
+            del model.aux1, model.aux2
         return model
     return GoogLeNet(**kwargs)
@@ -62,8 +73,9 @@ class GoogLeNet(nn.Module):
         self.inception5a = Inception(832, 256, 160, 320, 32, 128, 128)
         self.inception5b = Inception(832, 384, 192, 384, 48, 128, 128)
-        self.aux1 = InceptionAux(512, num_classes)
-        self.aux2 = InceptionAux(528, num_classes)
+        if aux_logits:
+            self.aux1 = InceptionAux(512, num_classes)
+            self.aux2 = InceptionAux(528, num_classes)
         self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
         self.dropout = nn.Dropout(0.2)
@@ -141,7 +153,7 @@ class GoogLeNet(nn.Module):
         x = self.fc(x)
         # N x 1000 (num_classes)
         if self.training and self.aux_logits:
-            return aux1, aux2, x
+            return _GoogLeNetOuputs(x, aux2, aux1)
         return x
...
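Because the training-mode forward pass now returns a namedtuple, the auxiliary losses can be combined explicitly by field name. A short sketch (the 0.3 weights follow the original GoogLeNet paper; the model instance, data, and labels here are purely illustrative):

import torch
import torch.nn.functional as F
from torchvision.models import GoogLeNet

model = GoogLeNet(aux_logits=True)      # fresh model with both auxiliary heads
model.train()

images = torch.randn(2, 3, 224, 224)    # dummy batch
targets = torch.tensor([1, 5])          # dummy labels

out = model(images)                     # _GoogLeNetOuputs(logits, aux_logits2, aux_logits1)
loss = (F.cross_entropy(out.logits, targets)
        + 0.3 * F.cross_entropy(out.aux_logits2, targets)
        + 0.3 * F.cross_entropy(out.aux_logits1, targets))
loss.backward()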
torchvision/models/inception.py
+from collections import namedtuple
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -12,6 +13,8 @@ model_urls = {
     'inception_v3_google': 'https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth',
 }
+_InceptionOuputs = namedtuple('InceptionOuputs', ['logits', 'aux_logits'])
 def inception_v3(pretrained=False, **kwargs):
     r"""Inception v3 model architecture from
@@ -23,14 +26,24 @@ def inception_v3(pretrained=False, **kwargs):
     Args:
         pretrained (bool): If True, returns a model pre-trained on ImageNet
+        aux_logits (bool): If True, add an auxiliary branch that can improve training.
+            Default: *True*
         transform_input (bool): If True, preprocesses the input according to the method with which it
             was trained on ImageNet. Default: *False*
     """
     if pretrained:
         if 'transform_input' not in kwargs:
             kwargs['transform_input'] = True
+        if 'aux_logits' in kwargs:
+            original_aux_logits = kwargs['aux_logits']
+            kwargs['aux_logits'] = True
+        else:
+            original_aux_logits = True
         model = Inception3(**kwargs)
         model.load_state_dict(model_zoo.load_url(model_urls['inception_v3_google']))
+        if not original_aux_logits:
+            model.aux_logits = False
+            del model.AuxLogits
         return model
     return Inception3(**kwargs)
@@ -131,7 +144,7 @@ class Inception3(nn.Module):
         x = self.fc(x)
         # N x 1000 (num_classes)
         if self.training and self.aux_logits:
-            return x, aux
+            return _InceptionOuputs(x, aux)
         return x
...
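The same opt-out now applies to Inception v3. A sketch of both paths (again assuming a build with this commit; 299x299 is the standard Inception input size and the tensors are dummy data):

import torch
from torchvision import models

# aux_logits=False deletes model.AuxLogits after the pretrained weights are
# loaded, so the model always returns a plain tensor.
net = models.inception_v3(pretrained=True, aux_logits=False)
net.eval()
with torch.no_grad():
    logits = net(torch.randn(1, 3, 299, 299))

# With the default aux_logits=True, training mode returns the new namedtuple.
net_aux = models.inception_v3(pretrained=True)
net_aux.train()
out = net_aux(torch.randn(2, 3, 299, 299))
# out.logits and out.aux_logits are both available here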