Unverified Commit d4cd0bed authored by F-G Fernandez's avatar F-G Fernandez Committed by GitHub
Browse files

Added annotation typing to mobilenet (#2862)

* style: Added annotation typing for mmobilenet

* fix: Fixed type hinting of adaptive pooling

* refactor: Removed un-necessary import

* fix: Fixed constructor typing

* fix: Fixed list typing
parent 59c97420
from torch import nn
from torch import Tensor
from .utils import load_state_dict_from_url
from typing import Callable, Any, Optional, List
__all__ = ['MobileNetV2', 'mobilenet_v2']
......@@ -10,7 +12,7 @@ model_urls = {
}
def _make_divisible(v, divisor, min_value=None):
def _make_divisible(v: float, divisor: int, min_value: Optional[int] = None) -> int:
"""
This function is taken from the original tf repo.
It ensures that all layers have a channel number that is divisible by 8
......@@ -31,7 +33,15 @@ def _make_divisible(v, divisor, min_value=None):
class ConvBNReLU(nn.Sequential):
def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1, norm_layer=None):
def __init__(
self,
in_planes: int,
out_planes: int,
kernel_size: int = 3,
stride: int = 1,
groups: int = 1,
norm_layer: Optional[Callable[..., nn.Module]] = None
) -> None:
padding = (kernel_size - 1) // 2
if norm_layer is None:
norm_layer = nn.BatchNorm2d
......@@ -43,7 +53,14 @@ class ConvBNReLU(nn.Sequential):
class InvertedResidual(nn.Module):
def __init__(self, inp, oup, stride, expand_ratio, norm_layer=None):
def __init__(
self,
inp: int,
oup: int,
stride: int,
expand_ratio: int,
norm_layer: Optional[Callable[..., nn.Module]] = None
) -> None:
super(InvertedResidual, self).__init__()
self.stride = stride
assert stride in [1, 2]
......@@ -54,7 +71,7 @@ class InvertedResidual(nn.Module):
hidden_dim = int(round(inp * expand_ratio))
self.use_res_connect = self.stride == 1 and inp == oup
layers = []
layers: List[nn.Module] = []
if expand_ratio != 1:
# pw
layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1, norm_layer=norm_layer))
......@@ -67,7 +84,7 @@ class InvertedResidual(nn.Module):
])
self.conv = nn.Sequential(*layers)
def forward(self, x):
def forward(self, x: Tensor) -> Tensor:
if self.use_res_connect:
return x + self.conv(x)
else:
......@@ -75,13 +92,15 @@ class InvertedResidual(nn.Module):
class MobileNetV2(nn.Module):
def __init__(self,
num_classes=1000,
width_mult=1.0,
inverted_residual_setting=None,
round_nearest=8,
block=None,
norm_layer=None):
def __init__(
self,
num_classes: int = 1000,
width_mult: float = 1.0,
inverted_residual_setting: Optional[List[List[int]]] = None,
round_nearest: int = 8,
block: Optional[Callable[..., nn.Module]] = None,
norm_layer: Optional[Callable[..., nn.Module]] = None
) -> None:
"""
MobileNet V2 main class
......@@ -126,7 +145,7 @@ class MobileNetV2(nn.Module):
# building first layer
input_channel = _make_divisible(input_channel * width_mult, round_nearest)
self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
features = [ConvBNReLU(3, input_channel, stride=2, norm_layer=norm_layer)]
features: List[nn.Module] = [ConvBNReLU(3, input_channel, stride=2, norm_layer=norm_layer)]
# building inverted residual blocks
for t, c, n, s in inverted_residual_setting:
output_channel = _make_divisible(c * width_mult, round_nearest)
......@@ -158,20 +177,20 @@ class MobileNetV2(nn.Module):
nn.init.normal_(m.weight, 0, 0.01)
nn.init.zeros_(m.bias)
def _forward_impl(self, x):
def _forward_impl(self, x: Tensor) -> Tensor:
# This exists since TorchScript doesn't support inheritance, so the superclass method
# (this one) needs to have a name other than `forward` that can be accessed in a subclass
x = self.features(x)
# Cannot use "squeeze" as batch-size can be 1 => must use reshape with x.shape[0]
x = nn.functional.adaptive_avg_pool2d(x, 1).reshape(x.shape[0], -1)
x = nn.functional.adaptive_avg_pool2d(x, (1, 1)).reshape(x.shape[0], -1)
x = self.classifier(x)
return x
def forward(self, x):
def forward(self, x: Tensor) -> Tensor:
return self._forward_impl(x)
def mobilenet_v2(pretrained=False, progress=True, **kwargs):
def mobilenet_v2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MobileNetV2:
"""
Constructs a MobileNetV2 architecture from
`"MobileNetV2: Inverted Residuals and Linear Bottlenecks" <https://arxiv.org/abs/1801.04381>`_.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment