"vscode:/vscode.git/clone" did not exist on "5fa8ae041cef2b5f5587d4eb076dbaeb5bf992f6"
Unverified commit a8d84961, authored by F-G Fernandez and committed by GitHub

Added annotation typing to densenet (#2860)

* style: Added annotation typing for densenet

* fix: Fixed import

* refactor: Removed unnecessary import

* fix: Fixed constructor typing

* chore: Updated mypy.ini

* fix: Fixed tuple typing

* style: Ignored some mypy errors

* style: Fixed typing

* fix: Added missing constructor typing
parent f655e6a7
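The changes below follow one pattern throughout: signatures that were previously typed only through `# type:` comments (or not typed at all) gain inline PEP 484 annotations, and the per-module mypy exclusion for densenet is removed so the file is actually type-checked. A minimal sketch of that conversion, using a hypothetical helper rather than code from this commit:

```python
from typing import List

import torch
from torch import Tensor


# Before: the comment-based annotation style used in the old densenet.py.
def concat_features_old(inputs):
    # type: (List[Tensor]) -> Tensor
    return torch.cat(inputs, 1)


# After: the same signature with inline annotations, mirroring the style
# applied across densenet.py in this commit.
def concat_features_new(inputs: List[Tensor]) -> Tensor:
    return torch.cat(inputs, 1)
```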
mypy.ini
@@ -16,10 +16,6 @@ ignore_errors = True
 ignore_errors = True
-[mypy-torchvision.models.densenet.*]
-ignore_errors = True
 [mypy-torchvision.models.quantization.*]
 ignore_errors = True
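With the `[mypy-torchvision.models.densenet.*]` ignore section gone, mypy now reports errors for this module instead of skipping it. A hedged sketch of checking that locally through mypy's Python API (the paths are assumptions about the repository layout, not part of this commit):

```python
# Requires mypy to be installed; mypy.api.run returns (stdout, stderr, exit_status).
from mypy import api

stdout, stderr, exit_status = api.run(
    ["--config-file", "mypy.ini", "torchvision/models/densenet.py"]
)
print(stdout or stderr)
print("exit status:", exit_status)
```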
torchvision/models/densenet.py
@@ -6,7 +6,7 @@ import torch.utils.checkpoint as cp
from collections import OrderedDict
from .utils import load_state_dict_from_url
from torch import Tensor
-from torch.jit.annotations import List
+from typing import Any, List, Tuple
__all__ = ['DenseNet', 'densenet121', 'densenet169', 'densenet201', 'densenet161']
@@ -20,56 +20,64 @@ model_urls = {
class _DenseLayer(nn.Module):
-    def __init__(self, num_input_features, growth_rate, bn_size, drop_rate, memory_efficient=False):
+    def __init__(
+        self,
+        num_input_features: int,
+        growth_rate: int,
+        bn_size: int,
+        drop_rate: float,
+        memory_efficient: bool = False
+    ) -> None:
super(_DenseLayer, self).__init__()
-        self.add_module('norm1', nn.BatchNorm2d(num_input_features)),
-        self.add_module('relu1', nn.ReLU(inplace=True)),
+        self.norm1: nn.BatchNorm2d
+        self.add_module('norm1', nn.BatchNorm2d(num_input_features))
+        self.relu1: nn.ReLU
+        self.add_module('relu1', nn.ReLU(inplace=True))
+        self.conv1: nn.Conv2d
self.add_module('conv1', nn.Conv2d(num_input_features, bn_size *
growth_rate, kernel_size=1, stride=1,
-                                           bias=False)),
-        self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate)),
-        self.add_module('relu2', nn.ReLU(inplace=True)),
+                                           bias=False))
+        self.norm2: nn.BatchNorm2d
+        self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate))
+        self.relu2: nn.ReLU
+        self.add_module('relu2', nn.ReLU(inplace=True))
+        self.conv2: nn.Conv2d
self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate,
kernel_size=3, stride=1, padding=1,
-                                           bias=False)),
+                                           bias=False))
self.drop_rate = float(drop_rate)
self.memory_efficient = memory_efficient
-    def bn_function(self, inputs):
-        # type: (List[Tensor]) -> Tensor
+    def bn_function(self, inputs: List[Tensor]) -> Tensor:
concated_features = torch.cat(inputs, 1)
bottleneck_output = self.conv1(self.relu1(self.norm1(concated_features))) # noqa: T484
return bottleneck_output
# todo: rewrite when torchscript supports any
-    def any_requires_grad(self, input):
-        # type: (List[Tensor]) -> bool
+    def any_requires_grad(self, input: List[Tensor]) -> bool:
for tensor in input:
if tensor.requires_grad:
return True
return False
@torch.jit.unused # noqa: T484
-    def call_checkpoint_bottleneck(self, input):
-        # type: (List[Tensor]) -> Tensor
+    def call_checkpoint_bottleneck(self, input: List[Tensor]) -> Tensor:
def closure(*inputs):
return self.bn_function(inputs)
return cp.checkpoint(closure, *input)
@torch.jit._overload_method # noqa: F811
-    def forward(self, input):
-        # type: (List[Tensor]) -> (Tensor)
+    def forward(self, input: List[Tensor]) -> Tensor:
pass
-    @torch.jit._overload_method  # noqa: F811
-    def forward(self, input):
-        # type: (Tensor) -> (Tensor)
+    @torch.jit._overload_method  # type: ignore[no-redef]  # noqa: F811
+    def forward(self, input: Tensor) -> Tensor:
pass
# torchscript does not yet support *args, so we overload method
# allowing it to take either a List[Tensor] or single Tensor
-    def forward(self, input):  # noqa: F811
+    def forward(self, input: Tensor) -> Tensor:  # type: ignore[no-redef]  # noqa: F811
if isinstance(input, Tensor):
prev_features = [input]
else:
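The comment above explains why `_DenseLayer.forward` is declared twice under `@torch.jit._overload_method`: TorchScript does not support `*args`, so typed overload stubs plus a single implementation stand in for it. A self-contained sketch of the same pattern on a toy module (illustration only, not code from this commit):

```python
from typing import List

import torch
from torch import Tensor, nn


class EitherInput(nn.Module):
    """Toy module that accepts either a Tensor or a List[Tensor]."""

    @torch.jit._overload_method  # noqa: F811
    def forward(self, x: List[Tensor]) -> Tensor:
        pass

    @torch.jit._overload_method  # type: ignore[no-redef]  # noqa: F811
    def forward(self, x: Tensor) -> Tensor:
        pass

    def forward(self, x):  # type: ignore[no-redef]  # noqa: F811
        # The implementation normalizes both cases to a list, exactly as
        # _DenseLayer.forward does above.
        if isinstance(x, Tensor):
            x = [x]
        return torch.cat(x, 1)
```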
@@ -93,7 +101,15 @@ class _DenseLayer(nn.Module):
class _DenseBlock(nn.ModuleDict):
_version = 2
-    def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate, memory_efficient=False):
+    def __init__(
+        self,
+        num_layers: int,
+        num_input_features: int,
+        bn_size: int,
+        growth_rate: int,
+        drop_rate: float,
+        memory_efficient: bool = False
+    ) -> None:
super(_DenseBlock, self).__init__()
for i in range(num_layers):
layer = _DenseLayer(
@@ -105,7 +121,7 @@ class _DenseBlock(nn.ModuleDict):
)
self.add_module('denselayer%d' % (i + 1), layer)
-    def forward(self, init_features):
+    def forward(self, init_features: Tensor) -> Tensor:  # type: ignore[override]
features = [init_features]
for name, layer in self.items():
new_features = layer(features)
@@ -114,7 +130,7 @@ class _DenseBlock(nn.ModuleDict):
class _Transition(nn.Sequential):
-    def __init__(self, num_input_features, num_output_features):
+    def __init__(self, num_input_features: int, num_output_features: int) -> None:
super(_Transition, self).__init__()
self.add_module('norm', nn.BatchNorm2d(num_input_features))
self.add_module('relu', nn.ReLU(inplace=True))
@@ -139,8 +155,16 @@ class DenseNet(nn.Module):
but slower. Default: *False*. See `"paper" <https://arxiv.org/pdf/1707.06990.pdf>`_
"""
-    def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16),
-                 num_init_features=64, bn_size=4, drop_rate=0, num_classes=1000, memory_efficient=False):
+    def __init__(
+        self,
+        growth_rate: int = 32,
+        block_config: Tuple[int, int, int, int] = (6, 12, 24, 16),
+        num_init_features: int = 64,
+        bn_size: int = 4,
+        drop_rate: float = 0,
+        num_classes: int = 1000,
+        memory_efficient: bool = False
+    ) -> None:
super(DenseNet, self).__init__()
@@ -188,7 +212,7 @@ class DenseNet(nn.Module):
elif isinstance(m, nn.Linear):
nn.init.constant_(m.bias, 0)
-    def forward(self, x):
+    def forward(self, x: Tensor) -> Tensor:
features = self.features(x)
out = F.relu(features, inplace=True)
out = F.adaptive_avg_pool2d(out, (1, 1))
@@ -197,7 +221,7 @@ class DenseNet(nn.Module):
return out
-def _load_state_dict(model, model_url, progress):
+def _load_state_dict(model: nn.Module, model_url: str, progress: bool) -> None:
# '.'s are no longer allowed in module names, but previous _DenseLayer
# has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
# They are also in the checkpoints in model_urls. This pattern is used
@@ -215,15 +239,22 @@ def _load_state_dict(model, model_url, progress):
model.load_state_dict(state_dict)
-def _densenet(arch, growth_rate, block_config, num_init_features, pretrained, progress,
-              **kwargs):
+def _densenet(
+    arch: str,
+    growth_rate: int,
+    block_config: Tuple[int, int, int, int],
+    num_init_features: int,
+    pretrained: bool,
+    progress: bool,
+    **kwargs: Any
+) -> DenseNet:
model = DenseNet(growth_rate, block_config, num_init_features, **kwargs)
if pretrained:
_load_state_dict(model, model_urls[arch], progress)
return model
-def densenet121(pretrained=False, progress=True, **kwargs):
+def densenet121(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> DenseNet:
r"""Densenet-121 model from
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_
@@ -237,7 +268,7 @@ def densenet121(pretrained=False, progress=True, **kwargs):
**kwargs)
-def densenet161(pretrained=False, progress=True, **kwargs):
+def densenet161(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> DenseNet:
r"""Densenet-161 model from
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_
@@ -251,7 +282,7 @@ def densenet161(pretrained=False, progress=True, **kwargs):
**kwargs)
-def densenet169(pretrained=False, progress=True, **kwargs):
+def densenet169(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> DenseNet:
r"""Densenet-169 model from
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_
@@ -265,7 +296,7 @@ def densenet169(pretrained=False, progress=True, **kwargs):
**kwargs)
-def densenet201(pretrained=False, progress=True, **kwargs):
+def densenet201(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> DenseNet:
r"""Densenet-201 model from
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_
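After this change, the public factory functions carry fully annotated signatures, so type checkers and IDEs can infer that they return `DenseNet`. A small usage sketch (not part of the commit):

```python
import torch
from torchvision.models import densenet121

# Annotated signature:
# densenet121(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> DenseNet
model = densenet121(pretrained=False)
model.eval()

with torch.no_grad():
    out = model(torch.randn(1, 3, 224, 224))  # DenseNet.forward(x: Tensor) -> Tensor

print(out.shape)  # torch.Size([1, 1000])
```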