Unverified Commit 67e78798 authored by F-G Fernandez's avatar F-G Fernandez Committed by GitHub
Browse files

Added annotation typing to mnasnet (#2856)

* style: Added annotation typing for mnasnet

* refactor: Removed un-necessary import
parent aa4cf039
import warnings import warnings
import torch import torch
from torch import Tensor
import torch.nn as nn import torch.nn as nn
from .utils import load_state_dict_from_url from .utils import load_state_dict_from_url
from typing import Any, Dict, List
__all__ = ['MNASNet', 'mnasnet0_5', 'mnasnet0_75', 'mnasnet1_0', 'mnasnet1_3'] __all__ = ['MNASNet', 'mnasnet0_5', 'mnasnet0_75', 'mnasnet1_0', 'mnasnet1_3']
...@@ -22,8 +24,15 @@ _BN_MOMENTUM = 1 - 0.9997 ...@@ -22,8 +24,15 @@ _BN_MOMENTUM = 1 - 0.9997
class _InvertedResidual(nn.Module): class _InvertedResidual(nn.Module):
def __init__(self, in_ch, out_ch, kernel_size, stride, expansion_factor, def __init__(
bn_momentum=0.1): self,
in_ch: int,
out_ch: int,
kernel_size: int,
stride: int,
expansion_factor: int,
bn_momentum: float = 0.1
):
super(_InvertedResidual, self).__init__() super(_InvertedResidual, self).__init__()
assert stride in [1, 2] assert stride in [1, 2]
assert kernel_size in [3, 5] assert kernel_size in [3, 5]
...@@ -43,15 +52,15 @@ class _InvertedResidual(nn.Module): ...@@ -43,15 +52,15 @@ class _InvertedResidual(nn.Module):
nn.Conv2d(mid_ch, out_ch, 1, bias=False), nn.Conv2d(mid_ch, out_ch, 1, bias=False),
nn.BatchNorm2d(out_ch, momentum=bn_momentum)) nn.BatchNorm2d(out_ch, momentum=bn_momentum))
def forward(self, input): def forward(self, input: Tensor) -> Tensor:
if self.apply_residual: if self.apply_residual:
return self.layers(input) + input return self.layers(input) + input
else: else:
return self.layers(input) return self.layers(input)
def _stack(in_ch, out_ch, kernel_size, stride, exp_factor, repeats, def _stack(in_ch: int, out_ch: int, kernel_size: int, stride: int, exp_factor: int, repeats: int,
bn_momentum): bn_momentum: float) -> nn.Sequential:
""" Creates a stack of inverted residuals. """ """ Creates a stack of inverted residuals. """
assert repeats >= 1 assert repeats >= 1
# First one has no skip, because feature map size changes. # First one has no skip, because feature map size changes.
...@@ -65,7 +74,7 @@ def _stack(in_ch, out_ch, kernel_size, stride, exp_factor, repeats, ...@@ -65,7 +74,7 @@ def _stack(in_ch, out_ch, kernel_size, stride, exp_factor, repeats,
return nn.Sequential(first, *remaining) return nn.Sequential(first, *remaining)
def _round_to_multiple_of(val, divisor, round_up_bias=0.9): def _round_to_multiple_of(val: float, divisor: int, round_up_bias: float = 0.9) -> int:
""" Asymmetric rounding to make `val` divisible by `divisor`. With default """ Asymmetric rounding to make `val` divisible by `divisor`. With default
bias, will round up, unless the number is no more than 10% greater than the bias, will round up, unless the number is no more than 10% greater than the
smaller divisible value, i.e. (83, 8) -> 80, but (84, 8) -> 88. """ smaller divisible value, i.e. (83, 8) -> 80, but (84, 8) -> 88. """
...@@ -74,7 +83,7 @@ def _round_to_multiple_of(val, divisor, round_up_bias=0.9): ...@@ -74,7 +83,7 @@ def _round_to_multiple_of(val, divisor, round_up_bias=0.9):
return new_val if new_val >= round_up_bias * val else new_val + divisor return new_val if new_val >= round_up_bias * val else new_val + divisor
def _get_depths(alpha): def _get_depths(alpha: float) -> List[int]:
""" Scales tensor depths as in reference MobileNet code, prefers rouding up """ Scales tensor depths as in reference MobileNet code, prefers rouding up
rather than down. """ rather than down. """
depths = [32, 16, 24, 40, 80, 96, 192, 320] depths = [32, 16, 24, 40, 80, 96, 192, 320]
...@@ -95,7 +104,12 @@ class MNASNet(torch.nn.Module): ...@@ -95,7 +104,12 @@ class MNASNet(torch.nn.Module):
# Version 2 adds depth scaling in the initial stages of the network. # Version 2 adds depth scaling in the initial stages of the network.
_version = 2 _version = 2
def __init__(self, alpha, num_classes=1000, dropout=0.2): def __init__(
self,
alpha: float,
num_classes: int = 1000,
dropout: float = 0.2
):
super(MNASNet, self).__init__() super(MNASNet, self).__init__()
assert alpha > 0.0 assert alpha > 0.0
self.alpha = alpha self.alpha = alpha
...@@ -130,13 +144,13 @@ class MNASNet(torch.nn.Module): ...@@ -130,13 +144,13 @@ class MNASNet(torch.nn.Module):
nn.Linear(1280, num_classes)) nn.Linear(1280, num_classes))
self._initialize_weights() self._initialize_weights()
def forward(self, x): def forward(self, x: Tensor) -> Tensor:
x = self.layers(x) x = self.layers(x)
# Equivalent to global avgpool and removing H and W dimensions. # Equivalent to global avgpool and removing H and W dimensions.
x = x.mean([2, 3]) x = x.mean([2, 3])
return self.classifier(x) return self.classifier(x)
def _initialize_weights(self): def _initialize_weights(self) -> None:
for m in self.modules(): for m in self.modules():
if isinstance(m, nn.Conv2d): if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode="fan_out", nn.init.kaiming_normal_(m.weight, mode="fan_out",
...@@ -151,8 +165,8 @@ class MNASNet(torch.nn.Module): ...@@ -151,8 +165,8 @@ class MNASNet(torch.nn.Module):
nonlinearity="sigmoid") nonlinearity="sigmoid")
nn.init.zeros_(m.bias) nn.init.zeros_(m.bias)
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, def _load_from_state_dict(self, state_dict: Dict, prefix: str, local_metadata: Dict, strict: bool,
missing_keys, unexpected_keys, error_msgs): missing_keys: List[str], unexpected_keys: List[str], error_msgs: List[str]) -> None:
version = local_metadata.get("version", None) version = local_metadata.get("version", None)
assert version in [1, 2] assert version in [1, 2]
...@@ -192,7 +206,7 @@ class MNASNet(torch.nn.Module): ...@@ -192,7 +206,7 @@ class MNASNet(torch.nn.Module):
unexpected_keys, error_msgs) unexpected_keys, error_msgs)
def _load_pretrained(model_name, model, progress): def _load_pretrained(model_name: str, model: nn.Module, progress: bool) -> None:
if model_name not in _MODEL_URLS or _MODEL_URLS[model_name] is None: if model_name not in _MODEL_URLS or _MODEL_URLS[model_name] is None:
raise ValueError( raise ValueError(
"No checkpoint is available for model type {}".format(model_name)) "No checkpoint is available for model type {}".format(model_name))
...@@ -201,7 +215,7 @@ def _load_pretrained(model_name, model, progress): ...@@ -201,7 +215,7 @@ def _load_pretrained(model_name, model, progress):
load_state_dict_from_url(checkpoint_url, progress=progress)) load_state_dict_from_url(checkpoint_url, progress=progress))
def mnasnet0_5(pretrained=False, progress=True, **kwargs): def mnasnet0_5(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MNASNet:
"""MNASNet with depth multiplier of 0.5 from """MNASNet with depth multiplier of 0.5 from
`"MnasNet: Platform-Aware Neural Architecture Search for Mobile" `"MnasNet: Platform-Aware Neural Architecture Search for Mobile"
<https://arxiv.org/pdf/1807.11626.pdf>`_. <https://arxiv.org/pdf/1807.11626.pdf>`_.
...@@ -215,7 +229,7 @@ def mnasnet0_5(pretrained=False, progress=True, **kwargs): ...@@ -215,7 +229,7 @@ def mnasnet0_5(pretrained=False, progress=True, **kwargs):
return model return model
def mnasnet0_75(pretrained=False, progress=True, **kwargs): def mnasnet0_75(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MNASNet:
"""MNASNet with depth multiplier of 0.75 from """MNASNet with depth multiplier of 0.75 from
`"MnasNet: Platform-Aware Neural Architecture Search for Mobile" `"MnasNet: Platform-Aware Neural Architecture Search for Mobile"
<https://arxiv.org/pdf/1807.11626.pdf>`_. <https://arxiv.org/pdf/1807.11626.pdf>`_.
...@@ -229,7 +243,7 @@ def mnasnet0_75(pretrained=False, progress=True, **kwargs): ...@@ -229,7 +243,7 @@ def mnasnet0_75(pretrained=False, progress=True, **kwargs):
return model return model
def mnasnet1_0(pretrained=False, progress=True, **kwargs): def mnasnet1_0(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MNASNet:
"""MNASNet with depth multiplier of 1.0 from """MNASNet with depth multiplier of 1.0 from
`"MnasNet: Platform-Aware Neural Architecture Search for Mobile" `"MnasNet: Platform-Aware Neural Architecture Search for Mobile"
<https://arxiv.org/pdf/1807.11626.pdf>`_. <https://arxiv.org/pdf/1807.11626.pdf>`_.
...@@ -243,7 +257,7 @@ def mnasnet1_0(pretrained=False, progress=True, **kwargs): ...@@ -243,7 +257,7 @@ def mnasnet1_0(pretrained=False, progress=True, **kwargs):
return model return model
def mnasnet1_3(pretrained=False, progress=True, **kwargs): def mnasnet1_3(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MNASNet:
"""MNASNet with depth multiplier of 1.3 from """MNASNet with depth multiplier of 1.3 from
`"MnasNet: Platform-Aware Neural Architecture Search for Mobile" `"MnasNet: Platform-Aware Neural Architecture Search for Mobile"
<https://arxiv.org/pdf/1807.11626.pdf>`_. <https://arxiv.org/pdf/1807.11626.pdf>`_.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment