Unverified Commit d4cd0bed authored by F-G Fernandez, committed by GitHub

Added annotation typing to mobilenet (#2862)

* style: Added annotation typing for mobilenet

* fix: Fixed type hinting of adaptive pooling

* refactor: Removed unnecessary import

* fix: Fixed constructor typing

* fix: Fixed list typing
parent 59c97420
from torch import nn
from torch import Tensor
from .utils import load_state_dict_from_url
from typing import Callable, Any, Optional, List

__all__ = ['MobileNetV2', 'mobilenet_v2']
@@ -10,7 +12,7 @@ model_urls = {
}
def _make_divisible(v: float, divisor: int, min_value: Optional[int] = None) -> int:
    """
    This function is taken from the original tf repo.
    It ensures that all layers have a channel number that is divisible by 8
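For reference, the rounding this helper performs looks like the sketch below. The body is collapsed in this diff, so this is an illustration of the usual torchvision behaviour, not the committed lines:

# Sketch (not part of the diff): round v to the nearest multiple of divisor,
# never going below min_value and never losing more than 10% of v.
def _make_divisible_sketch(v: float, divisor: int, min_value: Optional[int] = None) -> int:
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    if new_v < 0.9 * v:  # avoid rounding down by more than 10%
        new_v += divisor
    return new_v

# _make_divisible_sketch(32 * 0.75, 8) == 24; _make_divisible_sketch(10, 8) == 16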
@@ -31,7 +33,15 @@ def _make_divisible(v, divisor, min_value=None):
class ConvBNReLU(nn.Sequential):
    def __init__(
        self,
        in_planes: int,
        out_planes: int,
        kernel_size: int = 3,
        stride: int = 1,
        groups: int = 1,
        norm_layer: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        padding = (kernel_size - 1) // 2
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
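The Callable[..., nn.Module] annotation deliberately accepts any factory that returns a module, not just a module class. A hedged illustration (assuming, as in torchvision, that the constructor calls norm_layer with the channel count):

# Illustration, not part of the diff: a functools.partial also satisfies
# the Callable[..., nn.Module] annotation.
from functools import partial
conv = ConvBNReLU(3, 32, norm_layer=partial(nn.GroupNorm, 8))  # builds GroupNorm(8, 32)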
@@ -43,7 +53,14 @@ class ConvBNReLU(nn.Sequential):
class InvertedResidual(nn.Module):
    def __init__(
        self,
        inp: int,
        oup: int,
        stride: int,
        expand_ratio: int,
        norm_layer: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        super(InvertedResidual, self).__init__()
        self.stride = stride
        assert stride in [1, 2]
@@ -54,7 +71,7 @@ class InvertedResidual(nn.Module):
        hidden_dim = int(round(inp * expand_ratio))
        self.use_res_connect = self.stride == 1 and inp == oup

        layers: List[nn.Module] = []
        if expand_ratio != 1:
            # pw
            layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1, norm_layer=norm_layer))
@@ -67,7 +84,7 @@ class InvertedResidual(nn.Module):
        ])
        self.conv = nn.Sequential(*layers)

    def forward(self, x: Tensor) -> Tensor:
        if self.use_res_connect:
            return x + self.conv(x)
        else:
@@ -75,13 +92,15 @@ class InvertedResidual(nn.Module):
class MobileNetV2(nn.Module):
    def __init__(
        self,
        num_classes: int = 1000,
        width_mult: float = 1.0,
        inverted_residual_setting: Optional[List[List[int]]] = None,
        round_nearest: int = 8,
        block: Optional[Callable[..., nn.Module]] = None,
        norm_layer: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        """
        MobileNet V2 main class
@@ -126,7 +145,7 @@ class MobileNetV2(nn.Module):
        # building first layer
        input_channel = _make_divisible(input_channel * width_mult, round_nearest)
        self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
        features: List[nn.Module] = [ConvBNReLU(3, input_channel, stride=2, norm_layer=norm_layer)]
        # building inverted residual blocks
        for t, c, n, s in inverted_residual_setting:
            output_channel = _make_divisible(c * width_mult, round_nearest)
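The new Optional[List[List[int]]] annotation matches the row format unpacked above: each entry is [t, c, n, s], i.e. expansion factor, output channels, number of repeats, and stride of the first repeat. A minimal sketch of passing a custom (hypothetical) setting:

# Illustration, not part of the diff: a truncated two-stage network.
custom_setting: List[List[int]] = [
    # t, c, n, s
    [1, 16, 1, 1],
    [6, 24, 2, 2],
]
model = MobileNetV2(inverted_residual_setting=custom_setting)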
@@ -158,20 +177,20 @@ class MobileNetV2(nn.Module):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.zeros_(m.bias)

    def _forward_impl(self, x: Tensor) -> Tensor:
        # This exists since TorchScript doesn't support inheritance, so the superclass method
        # (this one) needs to have a name other than `forward` that can be accessed in a subclass
        x = self.features(x)
        # Cannot use "squeeze" as batch-size can be 1 => must use reshape with x.shape[0]
        x = nn.functional.adaptive_avg_pool2d(x, (1, 1)).reshape(x.shape[0], -1)
        x = self.classifier(x)
        return x
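The comment above is worth unpacking: with batch size 1 the pooled tensor has shape (1, C, 1, 1), and squeeze() would drop the batch dimension along with the spatial ones. A quick illustration (not part of the diff):

import torch

t = torch.randn(1, 1280, 1, 1)
t.squeeze().shape                # torch.Size([1280])    -- batch dim lost
t.reshape(t.shape[0], -1).shape  # torch.Size([1, 1280]) -- batch dim kept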
    def forward(self, x: Tensor) -> Tensor:
        return self._forward_impl(x)
def mobilenet_v2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MobileNetV2:
    """
    Constructs a MobileNetV2 architecture from
    `"MobileNetV2: Inverted Residuals and Linear Bottlenecks" <https://arxiv.org/abs/1801.04381>`_.