Unverified Commit 67e78798 authored by F-G Fernandez, committed by GitHub

Added annotation typing to mnasnet (#2856)

* style: Added annotation typing for mnasnet

* refactor: Removed unnecessary import
parent aa4cf039
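Editor's note: the sketch below is not part of the commit; it illustrates what the added annotations buy downstream users. It assumes torchvision is installed and uses the `mnasnet1_0` factory typed in this diff.

import torch
from torchvision.models import mnasnet1_0

model = mnasnet1_0(pretrained=False)   # -> MNASNet, per the new return annotation
x = torch.randn(1, 3, 224, 224)
y = model(x)                           # forward(x: Tensor) -> Tensor

# A static checker such as mypy can now reject ill-typed calls, e.g.
# mnasnet1_0(pretrained="yes"), since `pretrained` is annotated as bool.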
import warnings

import torch
+from torch import Tensor
import torch.nn as nn

from .utils import load_state_dict_from_url
+from typing import Any, Dict, List

__all__ = ['MNASNet', 'mnasnet0_5', 'mnasnet0_75', 'mnasnet1_0', 'mnasnet1_3']
@@ -22,8 +24,15 @@ _BN_MOMENTUM = 1 - 0.9997

class _InvertedResidual(nn.Module):
-    def __init__(self, in_ch, out_ch, kernel_size, stride, expansion_factor,
-                 bn_momentum=0.1):
+    def __init__(
+        self,
+        in_ch: int,
+        out_ch: int,
+        kernel_size: int,
+        stride: int,
+        expansion_factor: int,
+        bn_momentum: float = 0.1
+    ):
        super(_InvertedResidual, self).__init__()
        assert stride in [1, 2]
        assert kernel_size in [3, 5]
@@ -43,15 +52,15 @@ class _InvertedResidual(nn.Module):
            nn.Conv2d(mid_ch, out_ch, 1, bias=False),
            nn.BatchNorm2d(out_ch, momentum=bn_momentum))

-    def forward(self, input):
+    def forward(self, input: Tensor) -> Tensor:
        if self.apply_residual:
            return self.layers(input) + input
        else:
            return self.layers(input)

-def _stack(in_ch, out_ch, kernel_size, stride, exp_factor, repeats,
-           bn_momentum):
+def _stack(in_ch: int, out_ch: int, kernel_size: int, stride: int, exp_factor: int, repeats: int,
+           bn_momentum: float) -> nn.Sequential:
    """ Creates a stack of inverted residuals. """
    assert repeats >= 1
    # First one has no skip, because feature map size changes.
@@ -65,7 +74,7 @@ def _stack(in_ch, out_ch, kernel_size, stride, exp_factor, repeats,
    return nn.Sequential(first, *remaining)
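As an illustration (editor's sketch, not part of the diff), what `_stack` builds for made-up parameter values:

# A stack of three inverted residuals: only the first block changes
# stride and width, so only the remaining repeats use the residual path.
stack = _stack(in_ch=16, out_ch=24, kernel_size=3, stride=2,
               exp_factor=3, repeats=3, bn_momentum=0.01)
# stack[0]: stride 2, 16 -> 24 channels (apply_residual is False)
# stack[1], stack[2]: stride 1, 24 -> 24 channels (apply_residual is True)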

-def _round_to_multiple_of(val, divisor, round_up_bias=0.9):
+def _round_to_multiple_of(val: float, divisor: int, round_up_bias: float = 0.9) -> int:
    """ Asymmetric rounding to make `val` divisible by `divisor`. With default
    bias, will round up, unless the number is no more than 10% greater than the
    smaller divisible value, i.e. (83, 8) -> 80, but (84, 8) -> 88. """
@@ -74,7 +83,7 @@ def _round_to_multiple_of(val, divisor, round_up_bias=0.9):
    return new_val if new_val >= round_up_bias * val else new_val + divisor
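A worked illustration (editor's sketch, not part of the diff); the first line of the body is elided above and is reconstructed here from the torchvision source:

def _round_to_multiple_of(val: float, divisor: int, round_up_bias: float = 0.9) -> int:
    # Round to the nearest multiple of `divisor` (ties round up).
    new_val = max(divisor, int(val + divisor / 2) // divisor * divisor)
    # Keep a round-down only while it stays within roughly 10% of `val`.
    return new_val if new_val >= round_up_bias * val else new_val + divisor

assert _round_to_multiple_of(83, 8) == 80   # 80 >= 0.9 * 83, round-down kept
assert _round_to_multiple_of(84, 8) == 88   # nearest multiple of 8 is 88
assert _round_to_multiple_of(11, 8) == 16   # 8 < 0.9 * 11, bias forces round-up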

-def _get_depths(alpha):
+def _get_depths(alpha: float) -> List[int]:
    """ Scales tensor depths as in reference MobileNet code, prefers rounding up
    rather than down. """
    depths = [32, 16, 24, 40, 80, 96, 192, 320]
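For illustration (editor's sketch, not part of the diff), assuming the elided return line applies `_round_to_multiple_of(depth * alpha, 8)` to each base depth:

# Depths for the 0.5 multiplier: each entry is scaled by alpha, then rounded
# to a multiple of 8 with the asymmetric bias shown above.
assert _get_depths(0.5) == [16, 8, 16, 24, 40, 48, 96, 160]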
@@ -95,7 +104,12 @@ class MNASNet(torch.nn.Module):
    # Version 2 adds depth scaling in the initial stages of the network.
    _version = 2

-    def __init__(self, alpha, num_classes=1000, dropout=0.2):
+    def __init__(
+        self,
+        alpha: float,
+        num_classes: int = 1000,
+        dropout: float = 0.2
+    ):
        super(MNASNet, self).__init__()
        assert alpha > 0.0
        self.alpha = alpha
@@ -130,13 +144,13 @@ class MNASNet(torch.nn.Module):
            nn.Linear(1280, num_classes))
        self._initialize_weights()
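A minimal construction sketch (editor's addition, not part of the diff); the argument values are illustrative:

from torchvision.models.mnasnet import MNASNet

# alpha scales the channel depths via _get_depths; 1.0 matches mnasnet1_0.
net = MNASNet(alpha=1.0, num_classes=10, dropout=0.2)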
-    def forward(self, x):
+    def forward(self, x: Tensor) -> Tensor:
        x = self.layers(x)
        # Equivalent to global avgpool and removing H and W dimensions.
        x = x.mean([2, 3])
        return self.classifier(x)
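To illustrate the comment above (editor's sketch, not part of the diff), the mean over dims 2 and 3 matches global average pooling plus flatten:

import torch
import torch.nn.functional as F

x = torch.randn(2, 1280, 7, 7)
a = x.mean([2, 3])                            # shape (2, 1280)
b = F.adaptive_avg_pool2d(x, 1).flatten(1)    # shape (2, 1280)
assert torch.allclose(a, b, atol=1e-6)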
-    def _initialize_weights(self):
+    def _initialize_weights(self) -> None:
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out",
@@ -151,8 +165,8 @@ class MNASNet(torch.nn.Module):
                                         nonlinearity="sigmoid")
                nn.init.zeros_(m.bias)

-    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
-                              missing_keys, unexpected_keys, error_msgs):
+    def _load_from_state_dict(self, state_dict: Dict, prefix: str, local_metadata: Dict, strict: bool,
+                              missing_keys: List[str], unexpected_keys: List[str], error_msgs: List[str]) -> None:
        version = local_metadata.get("version", None)
        assert version in [1, 2]
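Background (editor's sketch, not part of the diff): `nn.Module` records each submodule's `_version` in the state dict's metadata when saving, which is what `local_metadata` exposes at load time.

from torchvision.models.mnasnet import MNASNet

sd = MNASNet(alpha=1.0).state_dict()
# The root module's metadata carries the class's `_version` (2 here).
print(sd._metadata[""]["version"])   # -> 2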
@@ -192,7 +206,7 @@ class MNASNet(torch.nn.Module):
            unexpected_keys, error_msgs)

-def _load_pretrained(model_name, model, progress):
+def _load_pretrained(model_name: str, model: nn.Module, progress: bool) -> None:
    if model_name not in _MODEL_URLS or _MODEL_URLS[model_name] is None:
        raise ValueError(
            "No checkpoint is available for model type {}".format(model_name))
@@ -201,7 +215,7 @@ def _load_pretrained(model_name, model, progress):
        load_state_dict_from_url(checkpoint_url, progress=progress))
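A usage note (editor's addition): at the time of this commit only the 0.5 and 1.0 multipliers appear to have registered checkpoint URLs, so a hypothetical call like the one below fails:

# Raises ValueError("No checkpoint is available for model type mnasnet0_75")
# because _MODEL_URLS["mnasnet0_75"] is None.
# mnasnet0_75(pretrained=True)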

-def mnasnet0_5(pretrained=False, progress=True, **kwargs):
+def mnasnet0_5(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MNASNet:
    """MNASNet with depth multiplier of 0.5 from
    `"MnasNet: Platform-Aware Neural Architecture Search for Mobile"
    <https://arxiv.org/pdf/1807.11626.pdf>`_.
@@ -215,7 +229,7 @@ def mnasnet0_5(pretrained=False, progress=True, **kwargs):
    return model
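A hypothetical end-to-end usage sketch (editor's addition, not part of the diff), assuming torchvision is installed:

import torch
from torchvision.models import mnasnet0_5

model = mnasnet0_5(pretrained=False)   # pretrained=True would download weights
model.eval()
with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))
print(logits.shape)                    # torch.Size([1, 1000])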

-def mnasnet0_75(pretrained=False, progress=True, **kwargs):
+def mnasnet0_75(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MNASNet:
    """MNASNet with depth multiplier of 0.75 from
    `"MnasNet: Platform-Aware Neural Architecture Search for Mobile"
    <https://arxiv.org/pdf/1807.11626.pdf>`_.
@@ -229,7 +243,7 @@ def mnasnet0_75(pretrained=False, progress=True, **kwargs):
    return model

-def mnasnet1_0(pretrained=False, progress=True, **kwargs):
+def mnasnet1_0(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MNASNet:
    """MNASNet with depth multiplier of 1.0 from
    `"MnasNet: Platform-Aware Neural Architecture Search for Mobile"
    <https://arxiv.org/pdf/1807.11626.pdf>`_.
@@ -243,7 +257,7 @@ def mnasnet1_0(pretrained=False, progress=True, **kwargs):
    return model

-def mnasnet1_3(pretrained=False, progress=True, **kwargs):
+def mnasnet1_3(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MNASNet:
    """MNASNet with depth multiplier of 1.3 from
    `"MnasNet: Platform-Aware Neural Architecture Search for Mobile"
    <https://arxiv.org/pdf/1807.11626.pdf>`_.