"gallery/transforms/plot_transforms_illustrations.py" did not exist on "849112914f57ec272c52f3d5083aeab9dc966bf9"
Unverified Commit f9e31a6d authored by F-G Fernandez's avatar F-G Fernandez Committed by GitHub
Browse files

Added annotation typing to vgg (#2861)

* style: Added annotation typing for vgg

* fix: Fixed annotation typing

* refactor: Removed un-necessary import

* fix: Added missing annotation for kwargs

* fix: Fixed constructor typing

* refactor: Refactored typing to minize changes

* refactor: Refactored typing cast

* fix: Fixed module list typing
parent d559ad87
import torch import torch
import torch.nn as nn import torch.nn as nn
from .utils import load_state_dict_from_url from .utils import load_state_dict_from_url
from typing import Union, List, Dict, Any, cast
__all__ = [ __all__ = [
...@@ -23,7 +24,12 @@ model_urls = { ...@@ -23,7 +24,12 @@ model_urls = {
class VGG(nn.Module): class VGG(nn.Module):
def __init__(self, features, num_classes=1000, init_weights=True): def __init__(
self,
features: nn.Module,
num_classes: int = 1000,
init_weights: bool = True
) -> None:
super(VGG, self).__init__() super(VGG, self).__init__()
self.features = features self.features = features
self.avgpool = nn.AdaptiveAvgPool2d((7, 7)) self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
...@@ -39,14 +45,14 @@ class VGG(nn.Module): ...@@ -39,14 +45,14 @@ class VGG(nn.Module):
if init_weights: if init_weights:
self._initialize_weights() self._initialize_weights()
def forward(self, x): def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.features(x) x = self.features(x)
x = self.avgpool(x) x = self.avgpool(x)
x = torch.flatten(x, 1) x = torch.flatten(x, 1)
x = self.classifier(x) x = self.classifier(x)
return x return x
def _initialize_weights(self): def _initialize_weights(self) -> None:
for m in self.modules(): for m in self.modules():
if isinstance(m, nn.Conv2d): if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
...@@ -60,13 +66,14 @@ class VGG(nn.Module): ...@@ -60,13 +66,14 @@ class VGG(nn.Module):
nn.init.constant_(m.bias, 0) nn.init.constant_(m.bias, 0)
def make_layers(cfg, batch_norm=False): def make_layers(cfg: List[Union[str, int]], batch_norm: bool = False) -> nn.Sequential:
layers = [] layers: List[nn.Module] = []
in_channels = 3 in_channels = 3
for v in cfg: for v in cfg:
if v == 'M': if v == 'M':
layers += [nn.MaxPool2d(kernel_size=2, stride=2)] layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
else: else:
v = cast(int, v)
conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
if batch_norm: if batch_norm:
layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
...@@ -76,7 +83,7 @@ def make_layers(cfg, batch_norm=False): ...@@ -76,7 +83,7 @@ def make_layers(cfg, batch_norm=False):
return nn.Sequential(*layers) return nn.Sequential(*layers)
cfgs = { cfgs: Dict[str, List[Union[str, int]]] = {
'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
...@@ -84,7 +91,7 @@ cfgs = { ...@@ -84,7 +91,7 @@ cfgs = {
} }
def _vgg(arch, cfg, batch_norm, pretrained, progress, **kwargs): def _vgg(arch: str, cfg: str, batch_norm: bool, pretrained: bool, progress: bool, **kwargs: Any) -> VGG:
if pretrained: if pretrained:
kwargs['init_weights'] = False kwargs['init_weights'] = False
model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs) model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs)
...@@ -95,7 +102,7 @@ def _vgg(arch, cfg, batch_norm, pretrained, progress, **kwargs): ...@@ -95,7 +102,7 @@ def _vgg(arch, cfg, batch_norm, pretrained, progress, **kwargs):
return model return model
def vgg11(pretrained=False, progress=True, **kwargs): def vgg11(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
r"""VGG 11-layer model (configuration "A") from r"""VGG 11-layer model (configuration "A") from
`"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_ `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_
...@@ -106,7 +113,7 @@ def vgg11(pretrained=False, progress=True, **kwargs): ...@@ -106,7 +113,7 @@ def vgg11(pretrained=False, progress=True, **kwargs):
return _vgg('vgg11', 'A', False, pretrained, progress, **kwargs) return _vgg('vgg11', 'A', False, pretrained, progress, **kwargs)
def vgg11_bn(pretrained=False, progress=True, **kwargs): def vgg11_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
r"""VGG 11-layer model (configuration "A") with batch normalization r"""VGG 11-layer model (configuration "A") with batch normalization
`"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_ `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_
...@@ -117,7 +124,7 @@ def vgg11_bn(pretrained=False, progress=True, **kwargs): ...@@ -117,7 +124,7 @@ def vgg11_bn(pretrained=False, progress=True, **kwargs):
return _vgg('vgg11_bn', 'A', True, pretrained, progress, **kwargs) return _vgg('vgg11_bn', 'A', True, pretrained, progress, **kwargs)
def vgg13(pretrained=False, progress=True, **kwargs): def vgg13(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
r"""VGG 13-layer model (configuration "B") r"""VGG 13-layer model (configuration "B")
`"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_ `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_
...@@ -128,7 +135,7 @@ def vgg13(pretrained=False, progress=True, **kwargs): ...@@ -128,7 +135,7 @@ def vgg13(pretrained=False, progress=True, **kwargs):
return _vgg('vgg13', 'B', False, pretrained, progress, **kwargs) return _vgg('vgg13', 'B', False, pretrained, progress, **kwargs)
def vgg13_bn(pretrained=False, progress=True, **kwargs): def vgg13_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
r"""VGG 13-layer model (configuration "B") with batch normalization r"""VGG 13-layer model (configuration "B") with batch normalization
`"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_ `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_
...@@ -139,7 +146,7 @@ def vgg13_bn(pretrained=False, progress=True, **kwargs): ...@@ -139,7 +146,7 @@ def vgg13_bn(pretrained=False, progress=True, **kwargs):
return _vgg('vgg13_bn', 'B', True, pretrained, progress, **kwargs) return _vgg('vgg13_bn', 'B', True, pretrained, progress, **kwargs)
def vgg16(pretrained=False, progress=True, **kwargs): def vgg16(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
r"""VGG 16-layer model (configuration "D") r"""VGG 16-layer model (configuration "D")
`"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_ `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_
...@@ -150,7 +157,7 @@ def vgg16(pretrained=False, progress=True, **kwargs): ...@@ -150,7 +157,7 @@ def vgg16(pretrained=False, progress=True, **kwargs):
return _vgg('vgg16', 'D', False, pretrained, progress, **kwargs) return _vgg('vgg16', 'D', False, pretrained, progress, **kwargs)
def vgg16_bn(pretrained=False, progress=True, **kwargs): def vgg16_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
r"""VGG 16-layer model (configuration "D") with batch normalization r"""VGG 16-layer model (configuration "D") with batch normalization
`"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_ `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_
...@@ -161,7 +168,7 @@ def vgg16_bn(pretrained=False, progress=True, **kwargs): ...@@ -161,7 +168,7 @@ def vgg16_bn(pretrained=False, progress=True, **kwargs):
return _vgg('vgg16_bn', 'D', True, pretrained, progress, **kwargs) return _vgg('vgg16_bn', 'D', True, pretrained, progress, **kwargs)
def vgg19(pretrained=False, progress=True, **kwargs): def vgg19(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
r"""VGG 19-layer model (configuration "E") r"""VGG 19-layer model (configuration "E")
`"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_ `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_
...@@ -172,7 +179,7 @@ def vgg19(pretrained=False, progress=True, **kwargs): ...@@ -172,7 +179,7 @@ def vgg19(pretrained=False, progress=True, **kwargs):
return _vgg('vgg19', 'E', False, pretrained, progress, **kwargs) return _vgg('vgg19', 'E', False, pretrained, progress, **kwargs)
def vgg19_bn(pretrained=False, progress=True, **kwargs): def vgg19_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
r"""VGG 19-layer model (configuration 'E') with batch normalization r"""VGG 19-layer model (configuration 'E') with batch normalization
`"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_ `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment