"vscode:/vscode.git/clone" did not exist on "03566d8689ebf2b14b1db99f15043f50071f7719"
Commit 69b28578 authored by Dmitry Belenko's avatar Dmitry Belenko Committed by Francisco Massa
Browse files

Implementation of the MNASNet family of models (#829)

* Add initial mnasnet impl

* Remove all type hints, comply with PyTorch overall style

* Expose models

* Remove avgpool from features() and add separately

* Fix python3-only stuff, replace subclasses with functions

* fix __all__

* Fix typo

* Remove conditional dropout

* Make dropout functional

* Addressing @fmassa's feedback, round 1

* Replaced adaptive avgpool with mean on H and W to prevent collapsing the batch dimension

* Partially address feedback

* YAPF

* Removed redundant class vars

* Update urls to releases

* Add information to models.rst

* Replace init with kaiming_normal_ in fan-out mode

* Use load_state_dict_from_url
@@ -24,6 +24,7 @@ architectures for image classification:
- `ShuffleNet`_ v2
- `MobileNet`_ v2
- `ResNeXt`_
- `MNASNet`_
You can construct a model with random weights by calling its constructor:
@@ -40,6 +41,7 @@ You can construct a model with random weights by calling its constructor:
    shufflenet = models.shufflenet_v2_x1_0()
    mobilenet = models.mobilenet_v2()
    resnext50_32x4d = models.resnext50_32x4d()
mnasnet = models.mnasnet1_0()
We provide pre-trained models, using the PyTorch :mod:`torch.utils.model_zoo`.
These can be constructed by passing ``pretrained=True``:
@@ -57,6 +59,7 @@ These can be constructed by passing ``pretrained=True``:
    shufflenet = models.shufflenet_v2_x1_0(pretrained=True)
    mobilenet = models.mobilenet_v2(pretrained=True)
    resnext50_32x4d = models.resnext50_32x4d(pretrained=True)
mnasnet = models.mnasnet1_0(pretrained=True)
Instancing a pre-trained model will download its weights to a cache directory.
This directory can be set using the `TORCH_MODEL_ZOO` environment variable. See
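For example, the cache location can be overridden before the first download (a sketch; the directory shown is arbitrary):

    import os
    os.environ["TORCH_MODEL_ZOO"] = "/path/to/model_cache"  # hypothetical location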
@@ -111,6 +114,7 @@ ShuffleNet V2 30.64 11.68
MobileNet V2 28.12 9.71
ResNeXt-50-32x4d 22.38 6.30
ResNeXt-101-32x8d 20.69 5.47
MNASNet 1.0 26.49 8.456
================================ ============= =============
@@ -124,6 +128,7 @@ ResNeXt-101-32x8d 20.69 5.47
.. _ShuffleNet: https://arxiv.org/abs/1807.11164
.. _MobileNet: https://arxiv.org/abs/1801.04381
.. _ResNeXt: https://arxiv.org/abs/1611.05431
.. _MNASNet: https://arxiv.org/abs/1807.11626
.. currentmodule:: torchvision.models
@@ -197,6 +202,14 @@ ResNext
.. autofunction:: resnext50_32x4d
.. autofunction:: resnext101_32x8d
MNASNet
--------
.. autofunction:: mnasnet0_5
.. autofunction:: mnasnet0_75
.. autofunction:: mnasnet1_0
.. autofunction:: mnasnet1_3
Semantic Segmentation
=====================
@@ -6,6 +6,7 @@ from .inception import *
from .densenet import *
from .googlenet import *
from .mobilenet import *
from .mnasnet import *
from .shufflenetv2 import *
from . import segmentation
from . import detection
import math
import torch
import torch.nn as nn
from .utils import load_state_dict_from_url
__all__ = ['MNASNet', 'mnasnet0_5', 'mnasnet0_75', 'mnasnet1_0', 'mnasnet1_3']
_MODEL_URLS = {
    "mnasnet0_5":
    "https://github.com/1e100/mnasnet_trainer/releases/download/v0.1/mnasnet0.5_top1_67.592-7c6cb539b9.pth",
    "mnasnet0_75": None,
    "mnasnet1_0":
    "https://github.com/1e100/mnasnet_trainer/releases/download/v0.1/mnasnet1.0_top1_73.512-f206786ef8.pth",
    "mnasnet1_3": None
}
# Paper suggests 0.9997 momentum, for TensorFlow. The equivalent PyTorch
# momentum is 1.0 minus the TensorFlow value.
_BN_MOMENTUM = 1 - 0.9997
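# For reference (background, not part of this commit): TensorFlow updates
# running statistics as
#   running_mean = decay * running_mean + (1 - decay) * batch_mean
# with decay = 0.9997, whereas PyTorch's BatchNorm uses
#   running_mean = (1 - momentum) * running_mean + momentum * batch_mean,
# so momentum = 1 - 0.9997 = 3e-4 reproduces the same update.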
class _InvertedResidual(nn.Module):

    def __init__(self, in_ch, out_ch, kernel_size, stride, expansion_factor,
                 bn_momentum=0.1):
        super(_InvertedResidual, self).__init__()
        assert stride in [1, 2]
        assert kernel_size in [3, 5]
        mid_ch = in_ch * expansion_factor
        self.apply_residual = (in_ch == out_ch and stride == 1)
        self.layers = nn.Sequential(
            # Pointwise
            nn.Conv2d(in_ch, mid_ch, 1, bias=False),
            nn.BatchNorm2d(mid_ch, momentum=bn_momentum),
            nn.ReLU(inplace=True),
            # Depthwise
            nn.Conv2d(mid_ch, mid_ch, kernel_size, padding=kernel_size // 2,
                      stride=stride, groups=mid_ch, bias=False),
            nn.BatchNorm2d(mid_ch, momentum=bn_momentum),
            nn.ReLU(inplace=True),
            # Linear pointwise. Note that there's no activation.
            nn.Conv2d(mid_ch, out_ch, 1, bias=False),
            nn.BatchNorm2d(out_ch, momentum=bn_momentum))

    def forward(self, input):
        if self.apply_residual:
            return self.layers(input) + input
        else:
            return self.layers(input)
def _stack(in_ch, out_ch, kernel_size, stride, exp_factor, repeats,
           bn_momentum):
    """ Creates a stack of inverted residuals. """
    assert repeats >= 1
    # First one has no skip, because feature map size changes.
    first = _InvertedResidual(in_ch, out_ch, kernel_size, stride, exp_factor,
                              bn_momentum=bn_momentum)
    remaining = []
    for _ in range(1, repeats):
        remaining.append(
            _InvertedResidual(out_ch, out_ch, kernel_size, 1, exp_factor,
                              bn_momentum=bn_momentum))
    return nn.Sequential(first, *remaining)
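# For example (shapes assumed for illustration):
#   _stack(16, 24, 3, 2, 3, 3, _BN_MOMENTUM)
# builds three inverted residuals mapping (N, 16, H, W) to (N, 24, H/2, W/2);
# only the second and third blocks carry residual connections, since the first
# changes both channel count and spatial size.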
def _round_to_multiple_of(val, divisor, round_up_bias=0.9):
    """ Asymmetric rounding to make `val` divisible by `divisor`. With default
    bias, will round up, unless the number is no more than 10% greater than the
    smaller divisible value, e.g. (83, 8) -> 80, but (84, 8) -> 88. """
    assert 0.0 < round_up_bias < 1.0
    new_val = max(divisor, int(val + divisor / 2) // divisor * divisor)
    return new_val if new_val >= round_up_bias * val else new_val + divisor
def _scale_depths(depths, alpha):
    """ Scales tensor depths as in reference MobileNet code, prefers rounding
    up rather than down. """
    return [_round_to_multiple_of(depth * alpha, 8) for depth in depths]
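# For example, _scale_depths([24, 40, 80], 0.5) == [16, 24, 40]: each scaled
# depth is rounded to a multiple of 8, preferring the larger multiple.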
class MNASNet(torch.nn.Module):
    """ MNASNet, as described in https://arxiv.org/pdf/1807.11626.pdf.
    >>> model = MNASNet(alpha=1.0, num_classes=1000)
    >>> x = torch.rand(1, 3, 224, 224)
    >>> y = model(x)
    >>> y.dim()
    2
    >>> y.nelement()
    1000
    """
    def __init__(self, alpha, num_classes=1000, dropout=0.2):
        super(MNASNet, self).__init__()
        depths = _scale_depths([24, 40, 80, 96, 192, 320], alpha)
        layers = [
            # First layer: regular conv.
            nn.Conv2d(3, 32, 3, padding=1, stride=2, bias=False),
            nn.BatchNorm2d(32, momentum=_BN_MOMENTUM),
            nn.ReLU(inplace=True),
            # Depthwise separable, no skip.
            nn.Conv2d(32, 32, 3, padding=1, stride=1, groups=32, bias=False),
            nn.BatchNorm2d(32, momentum=_BN_MOMENTUM),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 16, 1, padding=0, stride=1, bias=False),
            nn.BatchNorm2d(16, momentum=_BN_MOMENTUM),
            # MNASNet blocks: stacks of inverted residuals.
            _stack(16, depths[0], 3, 2, 3, 3, _BN_MOMENTUM),
            _stack(depths[0], depths[1], 5, 2, 3, 3, _BN_MOMENTUM),
            _stack(depths[1], depths[2], 5, 2, 6, 3, _BN_MOMENTUM),
            _stack(depths[2], depths[3], 3, 1, 6, 2, _BN_MOMENTUM),
            _stack(depths[3], depths[4], 5, 2, 6, 4, _BN_MOMENTUM),
            _stack(depths[4], depths[5], 3, 1, 6, 1, _BN_MOMENTUM),
            # Final mapping to classifier input.
            nn.Conv2d(depths[5], 1280, 1, padding=0, stride=1, bias=False),
            nn.BatchNorm2d(1280, momentum=_BN_MOMENTUM),
            nn.ReLU(inplace=True),
        ]
        self.layers = nn.Sequential(*layers)
        self.classifier = nn.Sequential(nn.Dropout(p=dropout, inplace=True),
                                        nn.Linear(1280, num_classes))
        self._initialize_weights()
    def forward(self, x):
        x = self.layers(x)
        # Equivalent to global avgpool and removing H and W dimensions.
        x = x.mean([2, 3])
        return self.classifier(x)
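    # Note (equivalence for reference, not part of this commit): the mean over
    # dims [2, 3] above matches
    #   torch.nn.functional.adaptive_avg_pool2d(x, 1).reshape(x.shape[0], -1)
    # and keeps the batch dimension intact even when N == 1, which a bare
    # flatten after avgpool would collapse.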
    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out",
                                        nonlinearity="relu")
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                # Zero-mean normal with small std. Passing 0.01 as the sole
                # positional argument would set the mean, not the std.
                nn.init.normal_(m.weight, 0.0, 0.01)
                nn.init.zeros_(m.bias)
def _load_pretrained(model_name, model):
    if model_name not in _MODEL_URLS or _MODEL_URLS[model_name] is None:
        raise ValueError(
            "No checkpoint is available for model type {}".format(model_name))
    checkpoint_url = _MODEL_URLS[model_name]
    model.load_state_dict(load_state_dict_from_url(checkpoint_url))
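# For example, _load_pretrained("mnasnet1_0", model) fills `model` with the
# released ImageNet weights, while _load_pretrained("mnasnet0_75", model)
# raises ValueError because no checkpoint URL is registered for that depth.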
def mnasnet0_5(pretrained=False, **kwargs):
    """ MNASNet with depth multiplier of 0.5. """
    model = MNASNet(0.5, **kwargs)
    if pretrained:
        _load_pretrained("mnasnet0_5", model)
    return model


def mnasnet0_75(pretrained=False, **kwargs):
    """ MNASNet with depth multiplier of 0.75. """
    model = MNASNet(0.75, **kwargs)
    if pretrained:
        _load_pretrained("mnasnet0_75", model)
    return model


def mnasnet1_0(pretrained=False, **kwargs):
    """ MNASNet with depth multiplier of 1.0. """
    model = MNASNet(1.0, **kwargs)
    if pretrained:
        _load_pretrained("mnasnet1_0", model)
    return model


def mnasnet1_3(pretrained=False, **kwargs):
    """ MNASNet with depth multiplier of 1.3. """
    model = MNASNet(1.3, **kwargs)
    if pretrained:
        _load_pretrained("mnasnet1_3", model)
    return model
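As a quick sanity check of the new entry points (a sketch written against this commit, not part of the diff):

    import torch
    from torchvision import models

    model = models.mnasnet1_0()  # pass pretrained=True to download weights
    model.eval()
    with torch.no_grad():
        out = model(torch.rand(1, 3, 224, 224))
    assert out.shape == (1, 1000)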