Commit 367e8514 authored by Dmitry Belenko, committed by Francisco Massa

Bugfix for MNASNet (#1224)

* Add initial mnasnet impl

* Remove all type hints, comply with PyTorch overall style

* Expose models

* Remove avgpool from features() and add separately

* Fix python3-only stuff, replace subclasses with functions

* fix __all__

* Fix typo

* Remove conditional dropout

* Make dropout functional

* Addressing @fmassa's feedback, round 1

* Replaced adaptive avgpool with mean on H and W to prevent collapsing the batch dimension

* Partially address feedback

* YAPF

* Removed redundant class vars

* Update urls to releases

* Add information to models.rst

* Replace init with kaiming_normal_ in fan-out mode

* Use load_state_dict_from_url

* Fix depth scaling on first 2 layers

* Restore initialization

* Match reference implementation initialization for dense layer

* Meant to use Kaiming

* Remove spurious relu

* Point to the newest 0.5 checkpoint

* Latest pretrained checkpoint

* Restore 1.0 checkpoint

* YAPF

* Implement backwards compat as suggested by Soumith

* Update checkpoint URL

* Move warnings up

* Record a couple more function parameters

* Update comment

* Set the correct version such that if the BC-patched model is saved, it could be reloaded with BC patching again

* Set a member var, not class var

* Update mnasnet.py

Remove unused member var as per review.

* Update the path to weights
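Taken together, these changes make alpha scale the first two stages of the network (previously hard-coded to 32 and 16 channels) and attach a version to checkpoints so that weights saved before the fix still load. A minimal sketch of the resulting behavior, assuming a torchvision build that includes this patch; the checkpoint filename below is hypothetical:

    import torch
    import torchvision.models as models

    # Only alpha != 1.0 models are affected by the fix; mnasnet1_0 is not.
    model = models.mnasnet0_5()

    # Loading a state dict saved by the pre-fix (v1) code rebuilds the stem
    # in its old fixed-size layout and emits a UserWarning suggesting an
    # upgrade. "mnasnet0_5_old.pth" is a hypothetical v1 checkpoint path.
    state_dict = torch.load("mnasnet0_5_old.pth")
    model.load_state_dict(state_dict)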
torchvision/models/mnasnet.py
 import math
+import warnings
 import torch
 import torch.nn as nn
@@ -8,7 +9,7 @@ __all__ = ['MNASNet', 'mnasnet0_5', 'mnasnet0_75', 'mnasnet1_0', 'mnasnet1_3']
 _MODEL_URLS = {
     "mnasnet0_5":
-    "https://download.pytorch.org/models/mnasnet0.5_top1_67.592-7c6cb539b9.pth",
+    "https://download.pytorch.org/models/mnasnet0.5_top1_67.823-3ffadce67e.pth",
     "mnasnet0_75": None,
     "mnasnet1_0":
     "https://download.pytorch.org/models/mnasnet1.0_top1_73.512-f206786ef8.pth",
@@ -74,14 +75,16 @@ def _round_to_multiple_of(val, divisor, round_up_bias=0.9):
     return new_val if new_val >= round_up_bias * val else new_val + divisor


-def _scale_depths(depths, alpha):
+def _get_depths(alpha):
     """ Scales tensor depths as in reference MobileNet code, prefers rounding up
     rather than down. """
+    depths = [32, 16, 24, 40, 80, 96, 192, 320]
     return [_round_to_multiple_of(depth * alpha, 8) for depth in depths]


 class MNASNet(torch.nn.Module):
-    """ MNASNet, as described in https://arxiv.org/pdf/1807.11626.pdf.
+    """ MNASNet, as described in https://arxiv.org/pdf/1807.11626.pdf. This
+    implements the B1 variant of the model.
     >>> model = MNASNet(alpha=1.0, num_classes=1000)
     >>> x = torch.rand(1, 3, 224, 224)
     >>> y = model(x)
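The rewritten helper owns the base depth table instead of taking it as an argument, which is what lets the stem scale with alpha. For concreteness, a self-contained sketch of what it computes; the body of _round_to_multiple_of is reproduced from the surrounding file (it sits just above the visible hunk), so treat it as context rather than part of this diff:

    def _round_to_multiple_of(val, divisor, round_up_bias=0.9):
        # Round to the nearest multiple of divisor, but bump up whenever
        # plain rounding would land below 90% of the unrounded value.
        new_val = max(divisor, int(val + divisor / 2) // divisor * divisor)
        return new_val if new_val >= round_up_bias * val else new_val + divisor

    def _get_depths(alpha):
        depths = [32, 16, 24, 40, 80, 96, 192, 320]
        return [_round_to_multiple_of(depth * alpha, 8) for depth in depths]

    print(_get_depths(1.0))  # [32, 16, 24, 40, 80, 96, 192, 320] (unchanged)
    print(_get_depths(0.5))  # [16, 8, 16, 24, 40, 48, 96, 160]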
@@ -90,30 +93,36 @@ class MNASNet(torch.nn.Module):
     >>> y.nelement()
     1000
     """
+    # Version 2 adds depth scaling in the initial stages of the network.
+    _version = 2

     def __init__(self, alpha, num_classes=1000, dropout=0.2):
         super(MNASNet, self).__init__()
-        depths = _scale_depths([24, 40, 80, 96, 192, 320], alpha)
+        assert alpha > 0.0
+        self.alpha = alpha
+        self.num_classes = num_classes
+        depths = _get_depths(alpha)
         layers = [
             # First layer: regular conv.
-            nn.Conv2d(3, 32, 3, padding=1, stride=2, bias=False),
-            nn.BatchNorm2d(32, momentum=_BN_MOMENTUM),
+            nn.Conv2d(3, depths[0], 3, padding=1, stride=2, bias=False),
+            nn.BatchNorm2d(depths[0], momentum=_BN_MOMENTUM),
             nn.ReLU(inplace=True),
             # Depthwise separable, no skip.
-            nn.Conv2d(32, 32, 3, padding=1, stride=1, groups=32, bias=False),
-            nn.BatchNorm2d(32, momentum=_BN_MOMENTUM),
+            nn.Conv2d(depths[0], depths[0], 3, padding=1, stride=1,
+                      groups=depths[0], bias=False),
+            nn.BatchNorm2d(depths[0], momentum=_BN_MOMENTUM),
             nn.ReLU(inplace=True),
-            nn.Conv2d(32, 16, 1, padding=0, stride=1, bias=False),
-            nn.BatchNorm2d(16, momentum=_BN_MOMENTUM),
+            nn.Conv2d(depths[0], depths[1], 1, padding=0, stride=1, bias=False),
+            nn.BatchNorm2d(depths[1], momentum=_BN_MOMENTUM),
             # MNASNet blocks: stacks of inverted residuals.
-            _stack(16, depths[0], 3, 2, 3, 3, _BN_MOMENTUM),
-            _stack(depths[0], depths[1], 5, 2, 3, 3, _BN_MOMENTUM),
-            _stack(depths[1], depths[2], 5, 2, 6, 3, _BN_MOMENTUM),
-            _stack(depths[2], depths[3], 3, 1, 6, 2, _BN_MOMENTUM),
-            _stack(depths[3], depths[4], 5, 2, 6, 4, _BN_MOMENTUM),
-            _stack(depths[4], depths[5], 3, 1, 6, 1, _BN_MOMENTUM),
+            _stack(depths[1], depths[2], 3, 2, 3, 3, _BN_MOMENTUM),
+            _stack(depths[2], depths[3], 5, 2, 3, 3, _BN_MOMENTUM),
+            _stack(depths[3], depths[4], 5, 2, 6, 3, _BN_MOMENTUM),
+            _stack(depths[4], depths[5], 3, 1, 6, 2, _BN_MOMENTUM),
+            _stack(depths[5], depths[6], 5, 2, 6, 4, _BN_MOMENTUM),
+            _stack(depths[6], depths[7], 3, 1, 6, 1, _BN_MOMENTUM),
             # Final mapping to classifier input.
-            nn.Conv2d(depths[5], 1280, 1, padding=0, stride=1, bias=False),
+            nn.Conv2d(depths[7], 1280, 1, padding=0, stride=1, bias=False),
             nn.BatchNorm2d(1280, momentum=_BN_MOMENTUM),
             nn.ReLU(inplace=True),
         ]
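A quick smoke test of the scaled architecture (a sketch for illustration, not part of the commit), assuming a torchvision build with this patch:

    import torch
    from torchvision.models.mnasnet import MNASNet

    # With the fix, alpha scales the first two stages along with the stacks.
    model = MNASNet(alpha=0.5)
    y = model(torch.rand(1, 3, 224, 224))
    print(y.shape)  # torch.Size([1, 1000])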
@@ -139,16 +148,58 @@ class MNASNet(torch.nn.Module):
                 nn.init.ones_(m.weight)
                 nn.init.zeros_(m.bias)
             elif isinstance(m, nn.Linear):
-                nn.init.normal_(m.weight, 0.01)
+                nn.init.kaiming_uniform_(m.weight, mode="fan_out",
+                                         nonlinearity="sigmoid")
                 nn.init.zeros_(m.bias)

+    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
+                              missing_keys, unexpected_keys, error_msgs):
+        version = local_metadata.get("version", None)
+        assert version in [1, 2]
+
+        if version == 1 and not self.alpha == 1.0:
+            # In the initial version of the model (v1), stem was fixed-size.
+            # All other layer configurations were the same. This will patch
+            # the model so that it's identical to v1. Model with alpha 1.0 is
+            # unaffected.
+            depths = _get_depths(self.alpha)
+            v1_stem = [
+                nn.Conv2d(3, 32, 3, padding=1, stride=2, bias=False),
+                nn.BatchNorm2d(32, momentum=_BN_MOMENTUM),
+                nn.ReLU(inplace=True),
+                nn.Conv2d(32, 32, 3, padding=1, stride=1, groups=32,
+                          bias=False),
+                nn.BatchNorm2d(32, momentum=_BN_MOMENTUM),
+                nn.ReLU(inplace=True),
+                nn.Conv2d(32, 16, 1, padding=0, stride=1, bias=False),
+                nn.BatchNorm2d(16, momentum=_BN_MOMENTUM),
+                _stack(16, depths[2], 3, 2, 3, 3, _BN_MOMENTUM),
+            ]
+            for idx, layer in enumerate(v1_stem):
+                self.layers[idx] = layer
+
+            # The model is now identical to v1, and must be saved as such.
+            self._version = 1
+            warnings.warn(
+                "A new version of MNASNet model has been implemented. "
+                "Your checkpoint was saved using the previous version. "
+                "This checkpoint will load and work as before, but "
+                "you may want to upgrade by training a newer model or "
+                "transfer learning from an updated ImageNet checkpoint.",
+                UserWarning)
+
+        super(MNASNet, self)._load_from_state_dict(
+            state_dict, prefix, local_metadata, strict, missing_keys,
+            unexpected_keys, error_msgs)
+

 def _load_pretrained(model_name, model, progress):
     if model_name not in _MODEL_URLS or _MODEL_URLS[model_name] is None:
         raise ValueError(
             "No checkpoint is available for model type {}".format(model_name))
     checkpoint_url = _MODEL_URLS[model_name]
-    model.load_state_dict(load_state_dict_from_url(checkpoint_url, progress=progress))
+    model.load_state_dict(
+        load_state_dict_from_url(checkpoint_url, progress=progress))


 def mnasnet0_5(pretrained=False, progress=True, **kwargs):
......
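As a usage sketch (not part of the diff): the public constructors forward to _load_pretrained, which downloads the updated checkpoint or raises for model types without released weights:

    import torchvision.models as models

    # Downloads and loads the updated mnasnet0_5 weights (top-1 67.823).
    model = models.mnasnet0_5(pretrained=True)

    # mnasnet0_75 has no released checkpoint (its URL is None above).
    try:
        models.mnasnet0_75(pretrained=True)
    except ValueError as err:
        print(err)  # No checkpoint is available for model type mnasnet0_75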