Unverified Commit 65591f14 authored by F-G Fernandez's avatar F-G Fernandez Committed by GitHub
Browse files

Added annotation typing to squeezenet (#2865)

* style: Added annotation typing for squeezenet

* feat: Added typing for kwargs
parent 67e78798
...@@ -2,6 +2,7 @@ import torch ...@@ -2,6 +2,7 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.init as init import torch.nn.init as init
from .utils import load_state_dict_from_url from .utils import load_state_dict_from_url
from typing import Any
__all__ = ['SqueezeNet', 'squeezenet1_0', 'squeezenet1_1'] __all__ = ['SqueezeNet', 'squeezenet1_0', 'squeezenet1_1']
...@@ -13,8 +14,13 @@ model_urls = { ...@@ -13,8 +14,13 @@ model_urls = {
class Fire(nn.Module): class Fire(nn.Module):
def __init__(self, inplanes, squeeze_planes, def __init__(
expand1x1_planes, expand3x3_planes): self,
inplanes: int,
squeeze_planes: int,
expand1x1_planes: int,
expand3x3_planes: int
):
super(Fire, self).__init__() super(Fire, self).__init__()
self.inplanes = inplanes self.inplanes = inplanes
self.squeeze = nn.Conv2d(inplanes, squeeze_planes, kernel_size=1) self.squeeze = nn.Conv2d(inplanes, squeeze_planes, kernel_size=1)
...@@ -26,7 +32,7 @@ class Fire(nn.Module): ...@@ -26,7 +32,7 @@ class Fire(nn.Module):
kernel_size=3, padding=1) kernel_size=3, padding=1)
self.expand3x3_activation = nn.ReLU(inplace=True) self.expand3x3_activation = nn.ReLU(inplace=True)
def forward(self, x): def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.squeeze_activation(self.squeeze(x)) x = self.squeeze_activation(self.squeeze(x))
return torch.cat([ return torch.cat([
self.expand1x1_activation(self.expand1x1(x)), self.expand1x1_activation(self.expand1x1(x)),
...@@ -36,7 +42,11 @@ class Fire(nn.Module): ...@@ -36,7 +42,11 @@ class Fire(nn.Module):
class SqueezeNet(nn.Module): class SqueezeNet(nn.Module):
def __init__(self, version='1_0', num_classes=1000): def __init__(
self,
version: str = '1_0',
num_classes: int = 1000
):
super(SqueezeNet, self).__init__() super(SqueezeNet, self).__init__()
self.num_classes = num_classes self.num_classes = num_classes
if version == '1_0': if version == '1_0':
...@@ -96,13 +106,13 @@ class SqueezeNet(nn.Module): ...@@ -96,13 +106,13 @@ class SqueezeNet(nn.Module):
if m.bias is not None: if m.bias is not None:
init.constant_(m.bias, 0) init.constant_(m.bias, 0)
def forward(self, x): def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.features(x) x = self.features(x)
x = self.classifier(x) x = self.classifier(x)
return torch.flatten(x, 1) return torch.flatten(x, 1)
def _squeezenet(version, pretrained, progress, **kwargs): def _squeezenet(version: str, pretrained: bool, progress: bool, **kwargs: Any) -> SqueezeNet:
model = SqueezeNet(version, **kwargs) model = SqueezeNet(version, **kwargs)
if pretrained: if pretrained:
arch = 'squeezenet' + version arch = 'squeezenet' + version
...@@ -112,7 +122,7 @@ def _squeezenet(version, pretrained, progress, **kwargs): ...@@ -112,7 +122,7 @@ def _squeezenet(version, pretrained, progress, **kwargs):
return model return model
def squeezenet1_0(pretrained=False, progress=True, **kwargs): def squeezenet1_0(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> SqueezeNet:
r"""SqueezeNet model architecture from the `"SqueezeNet: AlexNet-level r"""SqueezeNet model architecture from the `"SqueezeNet: AlexNet-level
accuracy with 50x fewer parameters and <0.5MB model size" accuracy with 50x fewer parameters and <0.5MB model size"
<https://arxiv.org/abs/1602.07360>`_ paper. <https://arxiv.org/abs/1602.07360>`_ paper.
...@@ -124,7 +134,7 @@ def squeezenet1_0(pretrained=False, progress=True, **kwargs): ...@@ -124,7 +134,7 @@ def squeezenet1_0(pretrained=False, progress=True, **kwargs):
return _squeezenet('1_0', pretrained, progress, **kwargs) return _squeezenet('1_0', pretrained, progress, **kwargs)
def squeezenet1_1(pretrained=False, progress=True, **kwargs): def squeezenet1_1(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> SqueezeNet:
r"""SqueezeNet 1.1 model from the `official SqueezeNet repo r"""SqueezeNet 1.1 model from the `official SqueezeNet repo
<https://github.com/DeepScale/SqueezeNet/tree/master/SqueezeNet_v1.1>`_. <https://github.com/DeepScale/SqueezeNet/tree/master/SqueezeNet_v1.1>`_.
SqueezeNet 1.1 has 2.4x less computation and slightly fewer parameters SqueezeNet 1.1 has 2.4x less computation and slightly fewer parameters
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment