Commit d44273b4 authored by Marat Dukhan's avatar Marat Dukhan Committed by Adam Paszke
Browse files

SqueezeNet 1.0 and 1.1 models (#49)

* Add SqueezeNet 1.0 and 1.1 models
* Selectively avoid inplace in SqueezeNet
* Use Glorot uniform initialization in SqueezeNet
* Make all ReLU in SqueezeNet in-place
* Add pretrained SqueezeNet 1.0 and 1.1
* Minor fixes in SqueezeNet models
parent 98f59ebc
...@@ -4,6 +4,7 @@ architectures: ...@@ -4,6 +4,7 @@ architectures:
- `AlexNet`_ - `AlexNet`_
- `VGG`_ - `VGG`_
- `ResNet`_ - `ResNet`_
- `SqueezeNet`_
You can construct a model with random weights by calling its constructor: You can construct a model with random weights by calling its constructor:
...@@ -12,6 +13,7 @@ You can construct a model with random weights by calling its constructor: ...@@ -12,6 +13,7 @@ You can construct a model with random weights by calling its constructor:
import torchvision.models as models import torchvision.models as models
resnet18 = models.resnet18() resnet18 = models.resnet18()
alexnet = models.alexnet() alexnet = models.alexnet()
squeezenet = models.squeezenet1_0()
We provide pre-trained models for the ResNet variants and AlexNet, using the We provide pre-trained models for the ResNet variants and AlexNet, using the
PyTorch :mod:`torch.utils.model_zoo`. These can constructed by passing PyTorch :mod:`torch.utils.model_zoo`. These can constructed by passing
...@@ -26,8 +28,10 @@ PyTorch :mod:`torch.utils.model_zoo`. These can constructed by passing ...@@ -26,8 +28,10 @@ PyTorch :mod:`torch.utils.model_zoo`. These can constructed by passing
.. _AlexNet: https://arxiv.org/abs/1404.5997 .. _AlexNet: https://arxiv.org/abs/1404.5997
.. _VGG: https://arxiv.org/abs/1409.1556 .. _VGG: https://arxiv.org/abs/1409.1556
.. _ResNet: https://arxiv.org/abs/1512.03385 .. _ResNet: https://arxiv.org/abs/1512.03385
.. _SqueezeNet: https://arxiv.org/abs/1602.07360
""" """
from .alexnet import * from .alexnet import *
from .resnet import * from .resnet import *
from .vgg import * from .vgg import *
from .squeezenet import *
import math
import torch
import torch.nn as nn
import torch.utils.model_zoo as model_zoo
__all__ = ['SqueezeNet', 'squeezenet1_0', 'squeezenet1_1']
model_urls = {
'squeezenet1_0': 'https://s3.amazonaws.com/pytorch/models/squeezenet1_0-a815701f.pth',
'squeezenet1_1': 'https://s3.amazonaws.com/pytorch/models/squeezenet1_1-f364aa15.pth',
}
class Fire(nn.Module):
def __init__(self, inplanes, squeeze_planes,
expand1x1_planes, expand3x3_planes):
super(Fire, self).__init__()
self.inplanes = inplanes
self.squeeze = nn.Conv2d(inplanes, squeeze_planes, kernel_size=1)
self.squeeze_activation = nn.ReLU(inplace=True)
self.expand1x1 = nn.Conv2d(squeeze_planes, expand1x1_planes,
kernel_size=1)
self.expand1x1_activation = nn.ReLU(inplace=True)
self.expand3x3 = nn.Conv2d(squeeze_planes, expand3x3_planes,
kernel_size=3, padding=1)
self.expand3x3_activation = nn.ReLU(inplace=True)
def forward(self, x):
x = self.squeeze_activation(self.squeeze(x))
return torch.cat([
self.expand1x1_activation(self.expand1x1(x)),
self.expand3x3_activation(self.expand3x3(x))
], 1)
class SqueezeNet(nn.Module):
def __init__(self, version=1.0, num_classes=1000):
super(SqueezeNet, self).__init__()
if version not in [1.0, 1.1]:
raise ValueError("Unsupported SqueezeNet version {version}:"
"1.0 or 1.1 expected".format(version=version))
self.num_classes = num_classes
if version == 1.0:
self.features = nn.Sequential(
nn.Conv2d(3, 96, kernel_size=7, stride=2),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
Fire(96, 16, 64, 64),
Fire(128, 16, 64, 64),
Fire(128, 32, 128, 128),
nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
Fire(256, 32, 128, 128),
Fire(256, 48, 192, 192),
Fire(384, 48, 192, 192),
Fire(384, 64, 256, 256),
nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
Fire(512, 64, 256, 256),
)
else:
self.features = nn.Sequential(
nn.Conv2d(3, 64, kernel_size=3, stride=2),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
Fire(64, 16, 64, 64),
Fire(128, 16, 64, 64),
nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
Fire(128, 32, 128, 128),
Fire(256, 32, 128, 128),
nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
Fire(256, 48, 192, 192),
Fire(384, 48, 192, 192),
Fire(384, 64, 256, 256),
Fire(512, 64, 256, 256),
)
# Final convolution is initialized differently form the rest
final_conv = nn.Conv2d(512, num_classes, kernel_size=1)
self.classifier = nn.Sequential(
nn.Dropout(p=0.5),
final_conv,
nn.ReLU(inplace=True),
nn.AvgPool2d(13)
)
for m in self.modules():
if isinstance(m, nn.Conv2d):
gain = 2.0
if m is final_conv:
m.weight.data.normal_(0, 0.01)
else:
fan_in = m.kernel_size[0] * m.kernel_size[1] * m.in_channels
u = math.sqrt(3.0 * gain / fan_in)
m.weight.data.uniform_(-u, u)
if m.bias is not None:
m.bias.data.zero_()
def forward(self, x):
x = self.features(x)
x = self.classifier(x)
return x.view(x.size(0), self.num_classes)
def squeezenet1_0(pretrained=False):
r"""SqueezeNet model architecture from the `"SqueezeNet: AlexNet-level
accuracy with 50x fewer parameters and <0.5MB model size"
<https://arxiv.org/abs/1602.07360>`_ paper.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = SqueezeNet(version=1.0)
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['squeezenet1_0']))
return model
def squeezenet1_1(pretrained=False):
r"""SqueezeNet 1.1 model from the `official SqueezeNet repo
<https://github.com/DeepScale/SqueezeNet/tree/master/SqueezeNet_v1.1>`_.
SqueezeNet 1.1 has 2.4x less computation and slightly fewer parameters
than SqueezeNet 1.0, without sacrificing accuracy.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = SqueezeNet(version=1.1)
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['squeezenet1_1']))
return model
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment