Commit 367e8514 authored by Dmitry Belenko's avatar Dmitry Belenko Committed by Francisco Massa
Browse files

Bugfix for MNASNet (#1224)

* Add initial mnasnet impl

* Remove all type hints, comply with PyTorch overall style

* Expose models

* Remove avgpool from features() and add separately

* Fix python3-only stuff, replace subclasses with functions

* fix __all__

* Fix typo

* Remove conditional dropout

* Make dropout functional

* Addressing @fmassa's feedback, round 1

* Replaced adaptive avgpool with mean on H and W to prevent collapsing the batch dimension

* Partially address feedback

* YAPF

* Removed redundant class vars

* Update urls to releases

* Add information to models.rst

* Replace init with kaiming_normal_ in fan-out mode

* Use load_state_dict_from_url

* Fix depth scaling on first 2 layers

* Restore initialization

* Match reference implementation initialization for dense layer

* Meant to use Kaiming

* Remove spurious relu

* Point to the newest 0.5 checkpoint

* Latest pretrained checkpoint

* Restore 1.0 checkpoint

* YAPF

* Implement backwards compat as suggested by Soumith

* Update checkpoint URL

* Move warnings up

* Record a couple more function parameters

* Update comment

* Set the correct version such that if the BC-patched model is saved, it could be reloaded with BC patching again

* Set a member var, not class var

* Update mnasnet.py

Remove unused member var as per review.

* Update the path to weights
parent 3394c0f5
import math import math
import warnings
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -8,7 +9,7 @@ __all__ = ['MNASNet', 'mnasnet0_5', 'mnasnet0_75', 'mnasnet1_0', 'mnasnet1_3'] ...@@ -8,7 +9,7 @@ __all__ = ['MNASNet', 'mnasnet0_5', 'mnasnet0_75', 'mnasnet1_0', 'mnasnet1_3']
_MODEL_URLS = { _MODEL_URLS = {
"mnasnet0_5": "mnasnet0_5":
"https://download.pytorch.org/models/mnasnet0.5_top1_67.592-7c6cb539b9.pth", "https://download.pytorch.org/models/mnasnet0.5_top1_67.823-3ffadce67e.pth",
"mnasnet0_75": None, "mnasnet0_75": None,
"mnasnet1_0": "mnasnet1_0":
"https://download.pytorch.org/models/mnasnet1.0_top1_73.512-f206786ef8.pth", "https://download.pytorch.org/models/mnasnet1.0_top1_73.512-f206786ef8.pth",
...@@ -74,14 +75,16 @@ def _round_to_multiple_of(val, divisor, round_up_bias=0.9): ...@@ -74,14 +75,16 @@ def _round_to_multiple_of(val, divisor, round_up_bias=0.9):
return new_val if new_val >= round_up_bias * val else new_val + divisor return new_val if new_val >= round_up_bias * val else new_val + divisor
def _get_depths(alpha):
    """Return the MNASNet per-stage channel counts scaled by ``alpha``.

    Mirrors the reference MobileNet depth scaling: every base width is
    multiplied by ``alpha`` and snapped to a multiple of 8, preferring to
    round up rather than down.
    """
    scaled = []
    for base in (32, 16, 24, 40, 80, 96, 192, 320):
        scaled.append(_round_to_multiple_of(base * alpha, 8))
    return scaled
class MNASNet(torch.nn.Module): class MNASNet(torch.nn.Module):
""" MNASNet, as described in https://arxiv.org/pdf/1807.11626.pdf. """ MNASNet, as described in https://arxiv.org/pdf/1807.11626.pdf. This
implements the B1 variant of the model.
>>> model = MNASNet(1000, 1.0) >>> model = MNASNet(1000, 1.0)
>>> x = torch.rand(1, 3, 224, 224) >>> x = torch.rand(1, 3, 224, 224)
>>> y = model(x) >>> y = model(x)
...@@ -90,30 +93,36 @@ class MNASNet(torch.nn.Module): ...@@ -90,30 +93,36 @@ class MNASNet(torch.nn.Module):
>>> y.nelement() >>> y.nelement()
1000 1000
""" """
# Version 2 adds depth scaling in the initial stages of the network.
_version = 2
def __init__(self, alpha, num_classes=1000, dropout=0.2): def __init__(self, alpha, num_classes=1000, dropout=0.2):
super(MNASNet, self).__init__() super(MNASNet, self).__init__()
depths = _scale_depths([24, 40, 80, 96, 192, 320], alpha) assert alpha > 0.0
self.alpha = alpha
self.num_classes = num_classes
depths = _get_depths(alpha)
layers = [ layers = [
# First layer: regular conv. # First layer: regular conv.
nn.Conv2d(3, 32, 3, padding=1, stride=2, bias=False), nn.Conv2d(3, depths[0], 3, padding=1, stride=2, bias=False),
nn.BatchNorm2d(32, momentum=_BN_MOMENTUM), nn.BatchNorm2d(depths[0], momentum=_BN_MOMENTUM),
nn.ReLU(inplace=True), nn.ReLU(inplace=True),
# Depthwise separable, no skip. # Depthwise separable, no skip.
nn.Conv2d(32, 32, 3, padding=1, stride=1, groups=32, bias=False), nn.Conv2d(depths[0], depths[0], 3, padding=1, stride=1,
nn.BatchNorm2d(32, momentum=_BN_MOMENTUM), groups=depths[0], bias=False),
nn.BatchNorm2d(depths[0], momentum=_BN_MOMENTUM),
nn.ReLU(inplace=True), nn.ReLU(inplace=True),
nn.Conv2d(32, 16, 1, padding=0, stride=1, bias=False), nn.Conv2d(depths[0], depths[1], 1, padding=0, stride=1, bias=False),
nn.BatchNorm2d(16, momentum=_BN_MOMENTUM), nn.BatchNorm2d(depths[1], momentum=_BN_MOMENTUM),
# MNASNet blocks: stacks of inverted residuals. # MNASNet blocks: stacks of inverted residuals.
_stack(16, depths[0], 3, 2, 3, 3, _BN_MOMENTUM), _stack(depths[1], depths[2], 3, 2, 3, 3, _BN_MOMENTUM),
_stack(depths[0], depths[1], 5, 2, 3, 3, _BN_MOMENTUM), _stack(depths[2], depths[3], 5, 2, 3, 3, _BN_MOMENTUM),
_stack(depths[1], depths[2], 5, 2, 6, 3, _BN_MOMENTUM), _stack(depths[3], depths[4], 5, 2, 6, 3, _BN_MOMENTUM),
_stack(depths[2], depths[3], 3, 1, 6, 2, _BN_MOMENTUM), _stack(depths[4], depths[5], 3, 1, 6, 2, _BN_MOMENTUM),
_stack(depths[3], depths[4], 5, 2, 6, 4, _BN_MOMENTUM), _stack(depths[5], depths[6], 5, 2, 6, 4, _BN_MOMENTUM),
_stack(depths[4], depths[5], 3, 1, 6, 1, _BN_MOMENTUM), _stack(depths[6], depths[7], 3, 1, 6, 1, _BN_MOMENTUM),
# Final mapping to classifier input. # Final mapping to classifier input.
nn.Conv2d(depths[5], 1280, 1, padding=0, stride=1, bias=False), nn.Conv2d(depths[7], 1280, 1, padding=0, stride=1, bias=False),
nn.BatchNorm2d(1280, momentum=_BN_MOMENTUM), nn.BatchNorm2d(1280, momentum=_BN_MOMENTUM),
nn.ReLU(inplace=True), nn.ReLU(inplace=True),
] ]
...@@ -139,16 +148,58 @@ class MNASNet(torch.nn.Module): ...@@ -139,16 +148,58 @@ class MNASNet(torch.nn.Module):
nn.init.ones_(m.weight) nn.init.ones_(m.weight)
nn.init.zeros_(m.bias) nn.init.zeros_(m.bias)
elif isinstance(m, nn.Linear): elif isinstance(m, nn.Linear):
nn.init.normal_(m.weight, 0.01) nn.init.kaiming_uniform_(m.weight, mode="fan_out",
nonlinearity="sigmoid")
nn.init.zeros_(m.bias) nn.init.zeros_(m.bias)
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                          missing_keys, unexpected_keys, error_msgs):
    """Load weights, patching this module in place for v1 checkpoints.

    Version 1 of MNASNet used a fixed-size (32/16 channel) stem regardless
    of ``alpha``; version 2 scales the stem depths as well. If a v1
    checkpoint is being loaded into a model whose ``alpha != 1.0``, the
    stem layers are rebuilt to their v1 shapes so the old weights still
    fit, and ``self._version`` is downgraded so a re-save round-trips.
    """
    version = local_metadata.get("version", None)
    # Only checkpoint versions 1 and 2 are known; anything else is a bug.
    assert version in [1, 2]

    if version == 1 and not self.alpha == 1.0:
        # In the initial version of the model (v1), stem was fixed-size.
        # All other layer configurations were the same. This will patch
        # the model so that it's identical to v1. Model with alpha 1.0 is
        # unaffected.
        depths = _get_depths(self.alpha)
        v1_stem = [
            nn.Conv2d(3, 32, 3, padding=1, stride=2, bias=False),
            nn.BatchNorm2d(32, momentum=_BN_MOMENTUM),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 32, 3, padding=1, stride=1, groups=32,
                      bias=False),
            nn.BatchNorm2d(32, momentum=_BN_MOMENTUM),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 16, 1, padding=0, stride=1, bias=False),
            nn.BatchNorm2d(16, momentum=_BN_MOMENTUM),
            # Only the first stack's input width differs in v1; later
            # stacks were already alpha-scaled.
            _stack(16, depths[2], 3, 2, 3, 3, _BN_MOMENTUM),
        ]
        for idx, layer in enumerate(v1_stem):
            self.layers[idx] = layer

        # The model is now identical to v1, and must be saved as such.
        self._version = 1
        warnings.warn(
            "A new version of MNASNet model has been implemented. "
            "Your checkpoint was saved using the previous version. "
            "This checkpoint will load and work as before, but "
            "you may want to upgrade by training a newer model or "
            "transfer learning from an updated ImageNet checkpoint.",
            UserWarning)

    # Delegate the actual tensor copying to the base implementation.
    super(MNASNet, self)._load_from_state_dict(
        state_dict, prefix, local_metadata, strict, missing_keys,
        unexpected_keys, error_msgs)
def _load_pretrained(model_name, model, progress):
    """Populate ``model`` with the pretrained weights published for
    ``model_name``.

    Raises:
        ValueError: if no checkpoint URL is registered for ``model_name``.
    """
    checkpoint_url = _MODEL_URLS.get(model_name)
    if checkpoint_url is None:
        raise ValueError(
            "No checkpoint is available for model type {}".format(model_name))
    model.load_state_dict(
        load_state_dict_from_url(checkpoint_url, progress=progress))
def mnasnet0_5(pretrained=False, progress=True, **kwargs): def mnasnet0_5(pretrained=False, progress=True, **kwargs):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment