Commit e6196137 authored by Philip Meier, committed by Francisco Massa

Added progress flag to model getters (#875)

* added progress flag to model getters

* flake8

* bug fix

* backward compatibility
parent 830df55d
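For reference, the new keyword is accepted by every getter touched in this commit; a minimal usage sketch, assuming a torchvision build that includes this change:

from torchvision import models

# Download the pretrained weights without printing a progress bar to stderr;
# progress=True (the default) keeps the previous behaviour of showing one.
model = models.resnet18(pretrained=True, progress=False)
model.eval()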
torchvision/models/alexnet.py

 import torch.nn as nn
-import torch.utils.model_zoo as model_zoo
+from .utils import load_state_dict_from_url

 __all__ = ['AlexNet', 'alexnet']

@@ -48,14 +48,17 @@ class AlexNet(nn.Module):
         return x


-def alexnet(pretrained=False, **kwargs):
+def alexnet(pretrained=False, progress=True, **kwargs):
     r"""AlexNet model architecture from the
     `"One weird trick..." <https://arxiv.org/abs/1404.5997>`_ paper.

     Args:
         pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
     """
     model = AlexNet(**kwargs)
     if pretrained:
-        model.load_state_dict(model_zoo.load_url(model_urls['alexnet']))
+        state_dict = load_state_dict_from_url(model_urls['alexnet'],
+                                              progress=progress)
+        model.load_state_dict(state_dict)
     return model
torchvision/models/densenet.py

@@ -2,12 +2,11 @@ import re
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-import torch.utils.model_zoo as model_zoo
+from .utils import load_state_dict_from_url
 from collections import OrderedDict

 __all__ = ['DenseNet', 'densenet121', 'densenet169', 'densenet201', 'densenet161']

 model_urls = {
     'densenet121': 'https://download.pytorch.org/models/densenet121-a639ec97.pth',
     'densenet169': 'https://download.pytorch.org/models/densenet169-b2777c0a.pth',
@@ -22,17 +21,20 @@ class _DenseLayer(nn.Sequential):
         self.add_module('norm1', nn.BatchNorm2d(num_input_features)),
         self.add_module('relu1', nn.ReLU(inplace=True)),
         self.add_module('conv1', nn.Conv2d(num_input_features, bn_size *
-                        growth_rate, kernel_size=1, stride=1, bias=False)),
+                                           growth_rate, kernel_size=1, stride=1,
+                                           bias=False)),
         self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate)),
         self.add_module('relu2', nn.ReLU(inplace=True)),
         self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate,
-                        kernel_size=3, stride=1, padding=1, bias=False)),
+                                           kernel_size=3, stride=1, padding=1,
+                                           bias=False)),
         self.drop_rate = drop_rate

     def forward(self, x):
         new_features = super(_DenseLayer, self).forward(x)
         if self.drop_rate > 0:
-            new_features = F.dropout(new_features, p=self.drop_rate, training=self.training)
+            new_features = F.dropout(new_features, p=self.drop_rate,
+                                     training=self.training)
         return torch.cat([x, new_features], 1)

@@ -40,7 +42,8 @@ class _DenseBlock(nn.Sequential):
     def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate):
         super(_DenseBlock, self).__init__()
         for i in range(num_layers):
-            layer = _DenseLayer(num_input_features + i * growth_rate, growth_rate, bn_size, drop_rate)
+            layer = _DenseLayer(num_input_features + i * growth_rate, growth_rate,
+                                bn_size, drop_rate)
             self.add_module('denselayer%d' % (i + 1), layer)

@@ -75,7 +78,8 @@ class DenseNet(nn.Module):
         # First convolution
         self.features = nn.Sequential(OrderedDict([
-            ('conv0', nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)),
+            ('conv0', nn.Conv2d(3, num_init_features, kernel_size=7, stride=2,
+                                padding=3, bias=False)),
             ('norm0', nn.BatchNorm2d(num_init_features)),
             ('relu0', nn.ReLU(inplace=True)),
             ('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)),
@@ -85,11 +89,13 @@ class DenseNet(nn.Module):
         num_features = num_init_features
         for i, num_layers in enumerate(block_config):
             block = _DenseBlock(num_layers=num_layers, num_input_features=num_features,
-                                bn_size=bn_size, growth_rate=growth_rate, drop_rate=drop_rate)
+                                bn_size=bn_size, growth_rate=growth_rate,
+                                drop_rate=drop_rate)
             self.features.add_module('denseblock%d' % (i + 1), block)
             num_features = num_features + num_layers * growth_rate
             if i != len(block_config) - 1:
-                trans = _Transition(num_input_features=num_features, num_output_features=num_features // 2)
+                trans = _Transition(num_input_features=num_features,
+                                    num_output_features=num_features // 2)
                 self.features.add_module('transition%d' % (i + 1), trans)
                 num_features = num_features // 2

@@ -117,14 +123,15 @@ class DenseNet(nn.Module):
         return out


-def _load_state_dict(model, model_url):
+def _load_state_dict(model, model_url, progress):
     # '.'s are no longer allowed in module names, but previous _DenseLayer
     # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
     # They are also in the checkpoints in model_urls. This pattern is used
     # to find such keys.
     pattern = re.compile(
         r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
-    state_dict = model_zoo.load_url(model_url)
+    state_dict = load_state_dict_from_url(model_url, progress=progress)
     for key in list(state_dict.keys()):
         res = pattern.match(key)
         if res:
@@ -134,57 +141,57 @@ def _load_state_dict(model, model_url):
     model.load_state_dict(state_dict)


-def densenet121(pretrained=False, **kwargs):
+def _densenet(arch, growth_rate, block_config, num_init_features, pretrained, progress,
+              **kwargs):
+    model = DenseNet(growth_rate, block_config, num_init_features, **kwargs)
+    if pretrained:
+        _load_state_dict(model, model_urls[arch], progress)
+    return model
+
+
+def densenet121(pretrained=False, progress=True, **kwargs):
     r"""Densenet-121 model from
     `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_

     Args:
         pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
     """
-    model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 24, 16),
-                     **kwargs)
-    if pretrained:
-        _load_state_dict(model, model_urls['densenet121'])
-    return model
+    return _densenet('densenet121', 32, (6, 12, 24, 16), 64, pretrained, progress,
+                     **kwargs)


-def densenet169(pretrained=False, **kwargs):
-    r"""Densenet-169 model from
+def densenet161(pretrained=False, progress=True, **kwargs):
+    r"""Densenet-161 model from
     `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_

     Args:
         pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
     """
-    model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 32, 32),
-                     **kwargs)
-    if pretrained:
-        _load_state_dict(model, model_urls['densenet169'])
-    return model
+    return _densenet('densenet161', 48, (6, 12, 36, 24), 96, pretrained, progress,
+                     **kwargs)


-def densenet201(pretrained=False, **kwargs):
-    r"""Densenet-201 model from
+def densenet169(pretrained=False, progress=True, **kwargs):
+    r"""Densenet-169 model from
     `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_

     Args:
         pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
     """
-    model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 48, 32),
-                     **kwargs)
-    if pretrained:
-        _load_state_dict(model, model_urls['densenet201'])
-    return model
+    return _densenet('densenet169', 32, (6, 12, 32, 32), 64, pretrained, progress,
+                     **kwargs)


-def densenet161(pretrained=False, **kwargs):
-    r"""Densenet-161 model from
+def densenet201(pretrained=False, progress=True, **kwargs):
+    r"""Densenet-201 model from
     `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_

     Args:
         pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
     """
-    model = DenseNet(num_init_features=96, growth_rate=48, block_config=(6, 12, 36, 24),
-                     **kwargs)
-    if pretrained:
-        _load_state_dict(model, model_urls['densenet161'])
-    return model
+    return _densenet('densenet201', 32, (6, 12, 48, 32), 64, pretrained, progress,
+                     **kwargs)
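As a side note, the regex in _load_state_dict rewrites legacy checkpoint keys such as 'norm.1' into the dot-free form the current module names use. A small standalone sketch of that substitution (illustrative only, not part of the patch; the example key is hypothetical but follows the checkpoint naming described in the code comments):

import re

# Same pattern as in _load_state_dict above.
pattern = re.compile(
    r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')

key = 'features.denseblock1.denselayer1.norm.1.weight'  # legacy key with an extra dot
res = pattern.match(key)
if res:
    # Drop the dot between 'norm' and '1', as the loader does.
    new_key = res.group(1) + res.group(2)
    print(new_key)  # features.denseblock1.denselayer1.norm1.weight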
torchvision/models/googlenet.py

@@ -3,7 +3,7 @@ from collections import namedtuple
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from torch.utils import model_zoo
+from .utils import load_state_dict_from_url

 __all__ = ['GoogLeNet', 'googlenet']

@@ -15,12 +15,13 @@ model_urls = {
 _GoogLeNetOuputs = namedtuple('GoogLeNetOuputs', ['logits', 'aux_logits2', 'aux_logits1'])


-def googlenet(pretrained=False, **kwargs):
+def googlenet(pretrained=False, progress=True, **kwargs):
     r"""GoogLeNet (Inception v1) model architecture from
     `"Going Deeper with Convolutions" <http://arxiv.org/abs/1409.4842>`_.

     Args:
         pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
         aux_logits (bool): If True, adds two auxiliary branches that can improve training.
             Default: *False* when pretrained is True otherwise *True*
         transform_input (bool): If True, preprocesses the input according to the method with which it
@@ -38,7 +39,9 @@ def googlenet(pretrained=False, **kwargs):
         kwargs['aux_logits'] = True
         kwargs['init_weights'] = False
         model = GoogLeNet(**kwargs)
-        model.load_state_dict(model_zoo.load_url(model_urls['googlenet']))
+        state_dict = load_state_dict_from_url(model_urls['googlenet'],
+                                              progress=progress)
+        model.load_state_dict(state_dict)
         if not original_aux_logits:
             model.aux_logits = False
             del model.aux1, model.aux2
torchvision/models/inception.py

@@ -2,7 +2,7 @@ from collections import namedtuple
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-import torch.utils.model_zoo as model_zoo
+from .utils import load_state_dict_from_url

 __all__ = ['Inception3', 'inception_v3']

@@ -16,7 +16,7 @@ model_urls = {
 _InceptionOuputs = namedtuple('InceptionOuputs', ['logits', 'aux_logits'])


-def inception_v3(pretrained=False, **kwargs):
+def inception_v3(pretrained=False, progress=True, **kwargs):
     r"""Inception v3 model architecture from
     `"Rethinking the Inception Architecture for Computer Vision" <http://arxiv.org/abs/1512.00567>`_.
@@ -26,6 +26,7 @@ def inception_v3(pretrained=False, **kwargs):
     Args:
         pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
         aux_logits (bool): If True, add an auxiliary branch that can improve training.
             Default: *True*
         transform_input (bool): If True, preprocesses the input according to the method with which it
@@ -40,7 +41,9 @@ def inception_v3(pretrained=False, **kwargs):
         else:
             original_aux_logits = True
         model = Inception3(**kwargs)
-        model.load_state_dict(model_zoo.load_url(model_urls['inception_v3_google']))
+        state_dict = load_state_dict_from_url(model_urls['inception_v3_google'],
+                                              progress=progress)
+        model.load_state_dict(state_dict)
         if not original_aux_logits:
             model.aux_logits = False
             del model.AuxLogits
torchvision/models/resnet.py

 import torch.nn as nn
-import torch.utils.model_zoo as model_zoo
+from .utils import load_state_dict_from_url

 __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',

@@ -204,75 +204,79 @@ class ResNet(nn.Module):
         return x


-def resnet18(pretrained=False, **kwargs):
+def _resnet(arch, inplanes, planes, pretrained, progress, **kwargs):
+    model = ResNet(inplanes, planes, **kwargs)
+    if pretrained:
+        state_dict = load_state_dict_from_url(model_urls[arch],
+                                              progress=progress)
+        model.load_state_dict(state_dict)
+    return model
+
+
+def resnet18(pretrained=False, progress=True, **kwargs):
     """Constructs a ResNet-18 model.

     Args:
         pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
     """
-    model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
-    if pretrained:
-        model.load_state_dict(model_zoo.load_url(model_urls['resnet18']))
-    return model
+    return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
+                   **kwargs)


-def resnet34(pretrained=False, **kwargs):
+def resnet34(pretrained=False, progress=True, **kwargs):
     """Constructs a ResNet-34 model.

     Args:
         pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
     """
-    model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
-    if pretrained:
-        model.load_state_dict(model_zoo.load_url(model_urls['resnet34']))
-    return model
+    return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
+                   **kwargs)


-def resnet50(pretrained=False, **kwargs):
+def resnet50(pretrained=False, progress=True, **kwargs):
     """Constructs a ResNet-50 model.

     Args:
         pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
     """
-    model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
-    if pretrained:
-        model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
-    return model
+    return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
+                   **kwargs)


-def resnet101(pretrained=False, **kwargs):
+def resnet101(pretrained=False, progress=True, **kwargs):
     """Constructs a ResNet-101 model.

     Args:
         pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
     """
-    model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
-    if pretrained:
-        model.load_state_dict(model_zoo.load_url(model_urls['resnet101']))
-    return model
+    return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
+                   **kwargs)


-def resnet152(pretrained=False, **kwargs):
+def resnet152(pretrained=False, progress=True, **kwargs):
     """Constructs a ResNet-152 model.

     Args:
         pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
     """
-    model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
-    if pretrained:
-        model.load_state_dict(model_zoo.load_url(model_urls['resnet152']))
-    return model
+    return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
+                   **kwargs)


-def resnext50_32x4d(pretrained=False, **kwargs):
-    model = ResNet(Bottleneck, [3, 4, 6, 3], groups=32, width_per_group=4, **kwargs)
-    # if pretrained:
-    #     model.load_state_dict(model_zoo.load_url(model_urls['resnext50_32x4d']))
-    return model
+def resnext50_32x4d(**kwargs):
+    kwargs['groups'] = 32
+    kwargs['width_per_group'] = 4
+    return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
+                   pretrained=False, progress=True, **kwargs)


-def resnext101_32x8d(pretrained=False, **kwargs):
-    model = ResNet(Bottleneck, [3, 4, 23, 3], groups=32, width_per_group=8, **kwargs)
-    # if pretrained:
-    #     model.load_state_dict(model_zoo.load_url(model_urls['resnext101_32x8d']))
-    return model
+def resnext101_32x8d(**kwargs):
+    kwargs['groups'] = 32
+    kwargs['width_per_group'] = 8
+    return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
+                   pretrained=False, progress=True, **kwargs)
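The ResNeXt getters keep pretrained hard-wired to False because no checkpoint URLs ship with this change (the old code already had the load call commented out), so calling them is equivalent to building the grouped ResNet directly. A hedged equivalence sketch using the names from this diff:

from torchvision.models.resnet import ResNet, Bottleneck, resnext50_32x4d

# Both calls build the same architecture after this patch; neither downloads
# weights, since no ResNeXt checkpoints are published yet.
model_a = resnext50_32x4d(num_classes=1000)
model_b = ResNet(Bottleneck, [3, 4, 6, 3], groups=32,
                 width_per_group=4, num_classes=1000)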
torchvision/models/squeezenet.py

 import torch
 import torch.nn as nn
 import torch.nn.init as init
-import torch.utils.model_zoo as model_zoo
+from .utils import load_state_dict_from_url

 __all__ = ['SqueezeNet', 'squeezenet1_0', 'squeezenet1_1']

 model_urls = {
     'squeezenet1_0': 'https://download.pytorch.org/models/squeezenet1_0-a815701f.pth',
     'squeezenet1_1': 'https://download.pytorch.org/models/squeezenet1_1-f364aa15.pth',
@@ -38,13 +36,10 @@ class Fire(nn.Module):
 class SqueezeNet(nn.Module):

-    def __init__(self, version=1.0, num_classes=1000):
+    def __init__(self, version='1_0', num_classes=1000):
         super(SqueezeNet, self).__init__()
-        if version not in [1.0, 1.1]:
-            raise ValueError("Unsupported SqueezeNet version {version}:"
-                             "1.0 or 1.1 expected".format(version=version))
         self.num_classes = num_classes
-        if version == 1.0:
+        if version == '1_0':
             self.features = nn.Sequential(
                 nn.Conv2d(3, 96, kernel_size=7, stride=2),
                 nn.ReLU(inplace=True),
@@ -60,7 +55,7 @@ class SqueezeNet(nn.Module):
                 nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
                 Fire(512, 64, 256, 256),
             )
-        else:
+        elif version == '1_1':
             self.features = nn.Sequential(
                 nn.Conv2d(3, 64, kernel_size=3, stride=2),
                 nn.ReLU(inplace=True),
@@ -76,6 +71,13 @@ class SqueezeNet(nn.Module):
                 Fire(384, 64, 256, 256),
                 Fire(512, 64, 256, 256),
             )
+        else:
+            # FIXME: Is this needed? SqueezeNet should only be called from the
+            # FIXME: squeezenet1_x() functions
+            # FIXME: This checking is not done for the other models
+            raise ValueError("Unsupported SqueezeNet version {version}:"
+                             "1_0 or 1_1 expected".format(version=version))

         # Final convolution is initialized differently from the rest
         final_conv = nn.Conv2d(512, self.num_classes, kernel_size=1)
         self.classifier = nn.Sequential(
@@ -100,21 +102,29 @@ class SqueezeNet(nn.Module):
         return x.view(x.size(0), self.num_classes)


-def squeezenet1_0(pretrained=False, **kwargs):
+def _squeezenet(version, pretrained, progress, **kwargs):
+    model = SqueezeNet(version, **kwargs)
+    if pretrained:
+        arch = 'squeezenet' + version
+        state_dict = load_state_dict_from_url(model_urls[arch],
+                                              progress=progress)
+        model.load_state_dict(state_dict)
+    return model
+
+
+def squeezenet1_0(pretrained=False, progress=True, **kwargs):
     r"""SqueezeNet model architecture from the `"SqueezeNet: AlexNet-level
     accuracy with 50x fewer parameters and <0.5MB model size"
     <https://arxiv.org/abs/1602.07360>`_ paper.

     Args:
         pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
     """
-    model = SqueezeNet(version=1.0, **kwargs)
-    if pretrained:
-        model.load_state_dict(model_zoo.load_url(model_urls['squeezenet1_0']))
-    return model
+    return _squeezenet('1_0', pretrained, progress, **kwargs)


-def squeezenet1_1(pretrained=False, **kwargs):
+def squeezenet1_1(pretrained=False, progress=True, **kwargs):
     r"""SqueezeNet 1.1 model from the `official SqueezeNet repo
     <https://github.com/DeepScale/SqueezeNet/tree/master/SqueezeNet_v1.1>`_.
     SqueezeNet 1.1 has 2.4x less computation and slightly fewer parameters
@@ -122,8 +132,6 @@ def squeezenet1_1(pretrained=False, **kwargs):

     Args:
         pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
     """
-    model = SqueezeNet(version=1.1, **kwargs)
-    if pretrained:
-        model.load_state_dict(model_zoo.load_url(model_urls['squeezenet1_1']))
-    return model
+    return _squeezenet('1_1', pretrained, progress, **kwargs)
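Because the version argument is now a string key into the two feature stacks, the old float spelling falls into the new else branch and raises. A quick illustrative sketch (direct SqueezeNet construction is shown only to demonstrate the check; normal callers should use the squeezenet1_x() getters):

from torchvision.models.squeezenet import SqueezeNet

model = SqueezeNet(version='1_1')   # builds the 1.1 feature stack

try:
    SqueezeNet(version=1.1)         # old float spelling is no longer accepted
except ValueError as err:
    print(err)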
torchvision/models/utils.py (new file)

+try:
+    from torch.hub import load_state_dict_from_url
+except ImportError:
+    from torch.utils.model_zoo import load_url as load_state_dict_from_url
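This new module is the backward-compatibility shim referenced in the commit message: recent PyTorch builds expose torch.hub.load_state_dict_from_url, while older ones only provide torch.utils.model_zoo.load_url, and the model files import a single name either way. A small direct-usage sketch, assuming the module path above and using a checkpoint URL that appears in this diff:

from torchvision.models.utils import load_state_dict_from_url

# Fetch (and cache) a checkpoint; progress=False suppresses the stderr bar.
state_dict = load_state_dict_from_url(
    'https://download.pytorch.org/models/densenet121-a639ec97.pth',
    progress=False)
print(sorted(state_dict.keys())[:3])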
torchvision/models/vgg.py

 import torch.nn as nn
-import torch.utils.model_zoo as model_zoo
+from .utils import load_state_dict_from_url

 __all__ = [

@@ -75,7 +75,7 @@ def make_layers(cfg, batch_norm=False):
     return nn.Sequential(*layers)


-cfg = {
+cfgs = {
     'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
     'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
     'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
@@ -83,113 +83,92 @@ cfg = {
 }


-def vgg11(pretrained=False, **kwargs):
+def _vgg(arch, cfg, batch_norm, pretrained, progress, **kwargs):
+    if pretrained:
+        kwargs['init_weights'] = False
+    model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs)
+    if pretrained:
+        state_dict = load_state_dict_from_url(model_urls[arch],
+                                              progress=progress)
+        model.load_state_dict(state_dict)
+    return model
+
+
+def vgg11(pretrained=False, progress=True, **kwargs):
     """VGG 11-layer model (configuration "A")

     Args:
         pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
     """
-    if pretrained:
-        kwargs['init_weights'] = False
-    model = VGG(make_layers(cfg['A']), **kwargs)
-    if pretrained:
-        model.load_state_dict(model_zoo.load_url(model_urls['vgg11']))
-    return model
+    return _vgg('vgg11', 'A', False, pretrained, progress, **kwargs)


-def vgg11_bn(pretrained=False, **kwargs):
+def vgg11_bn(pretrained=False, progress=True, **kwargs):
     """VGG 11-layer model (configuration "A") with batch normalization

     Args:
         pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
     """
-    if pretrained:
-        kwargs['init_weights'] = False
-    model = VGG(make_layers(cfg['A'], batch_norm=True), **kwargs)
-    if pretrained:
-        model.load_state_dict(model_zoo.load_url(model_urls['vgg11_bn']))
-    return model
+    return _vgg('vgg11_bn', 'A', True, pretrained, progress, **kwargs)


-def vgg13(pretrained=False, **kwargs):
+def vgg13(pretrained=False, progress=True, **kwargs):
     """VGG 13-layer model (configuration "B")

     Args:
         pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
     """
-    if pretrained:
-        kwargs['init_weights'] = False
-    model = VGG(make_layers(cfg['B']), **kwargs)
-    if pretrained:
-        model.load_state_dict(model_zoo.load_url(model_urls['vgg13']))
-    return model
+    return _vgg('vgg13', 'B', False, pretrained, progress, **kwargs)


-def vgg13_bn(pretrained=False, **kwargs):
+def vgg13_bn(pretrained=False, progress=True, **kwargs):
     """VGG 13-layer model (configuration "B") with batch normalization

     Args:
         pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
     """
-    if pretrained:
-        kwargs['init_weights'] = False
-    model = VGG(make_layers(cfg['B'], batch_norm=True), **kwargs)
-    if pretrained:
-        model.load_state_dict(model_zoo.load_url(model_urls['vgg13_bn']))
-    return model
+    return _vgg('vgg13_bn', 'B', True, pretrained, progress, **kwargs)


-def vgg16(pretrained=False, **kwargs):
+def vgg16(pretrained=False, progress=True, **kwargs):
     """VGG 16-layer model (configuration "D")

     Args:
         pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
     """
-    if pretrained:
-        kwargs['init_weights'] = False
-    model = VGG(make_layers(cfg['D']), **kwargs)
-    if pretrained:
-        model.load_state_dict(model_zoo.load_url(model_urls['vgg16']))
-    return model
+    return _vgg('vgg16', 'D', False, pretrained, progress, **kwargs)


-def vgg16_bn(pretrained=False, **kwargs):
+def vgg16_bn(pretrained=False, progress=True, **kwargs):
     """VGG 16-layer model (configuration "D") with batch normalization

     Args:
         pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
     """
-    if pretrained:
-        kwargs['init_weights'] = False
-    model = VGG(make_layers(cfg['D'], batch_norm=True), **kwargs)
-    if pretrained:
-        model.load_state_dict(model_zoo.load_url(model_urls['vgg16_bn']))
-    return model
+    return _vgg('vgg16_bn', 'D', True, pretrained, progress, **kwargs)


-def vgg19(pretrained=False, **kwargs):
+def vgg19(pretrained=False, progress=True, **kwargs):
     """VGG 19-layer model (configuration "E")

     Args:
         pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
     """
-    if pretrained:
-        kwargs['init_weights'] = False
-    model = VGG(make_layers(cfg['E']), **kwargs)
-    if pretrained:
-        model.load_state_dict(model_zoo.load_url(model_urls['vgg19']))
-    return model
+    return _vgg('vgg19', 'E', False, pretrained, progress, **kwargs)


-def vgg19_bn(pretrained=False, **kwargs):
+def vgg19_bn(pretrained=False, progress=True, **kwargs):
     """VGG 19-layer model (configuration 'E') with batch normalization

     Args:
         pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
     """
-    if pretrained:
-        kwargs['init_weights'] = False
-    model = VGG(make_layers(cfg['E'], batch_norm=True), **kwargs)
-    if pretrained:
-        model.load_state_dict(model_zoo.load_url(model_urls['vgg19_bn']))
-    return model
+    return _vgg('vgg19_bn', 'E', True, pretrained, progress, **kwargs)