"tests/vscode:/vscode.git/clone" did not exist on "27f6561a89294532c92ad94ec492cf00a3940c25"
Unverified Commit f51bcf50 authored by ShawnHu's avatar ShawnHu Committed by GitHub
Browse files

Add type hints in mmcv/cnn/resnet.py (#1982)

* Add type hints in resnet.py

* using lint while commit

* using lint while commit

* using lint while commit

* reslove typehints

* add pre-defined type hints

* Add type hints for other methods in mmcv/cnn/resnet.py

* Fix type hints
parent aea2bb28
# Copyright (c) OpenMMLab. All rights reserved.
import logging
from typing import Optional, Sequence, Tuple, Union
import torch.nn as nn
import torch.utils.checkpoint as cp
from torch import Tensor
from .utils import constant_init, kaiming_init
def conv3x3(in_planes, out_planes, stride=1, dilation=1):
def conv3x3(in_planes: int,
out_planes: int,
stride: int = 1,
dilation: int = 1):
"""3x3 convolution with padding."""
return nn.Conv2d(
in_planes,
......@@ -23,13 +28,13 @@ class BasicBlock(nn.Module):
expansion = 1
def __init__(self,
inplanes,
planes,
stride=1,
dilation=1,
downsample=None,
style='pytorch',
with_cp=False):
inplanes: int,
planes: int,
stride: int = 1,
dilation: int = 1,
downsample: Optional[nn.Module] = None,
style: str = 'pytorch',
with_cp: bool = False):
super().__init__()
assert style in ['pytorch', 'caffe']
self.conv1 = conv3x3(inplanes, planes, stride, dilation)
......@@ -42,7 +47,7 @@ class BasicBlock(nn.Module):
self.dilation = dilation
assert not with_cp
def forward(self, x):
def forward(self, x: Tensor) -> Tensor:
residual = x
out = self.conv1(x)
......@@ -65,13 +70,13 @@ class Bottleneck(nn.Module):
expansion = 4
def __init__(self,
inplanes,
planes,
stride=1,
dilation=1,
downsample=None,
style='pytorch',
with_cp=False):
inplanes: int,
planes: int,
stride: int = 1,
dilation: int = 1,
downsample: Optional[nn.Module] = None,
style: str = 'pytorch',
with_cp: bool = False):
"""Bottleneck block.
If style is "pytorch", the stride-two layer is the 3x3 conv layer, if
......@@ -107,7 +112,7 @@ class Bottleneck(nn.Module):
self.dilation = dilation
self.with_cp = with_cp
def forward(self, x):
def forward(self, x: Tensor) -> Tensor:
def _inner_forward(x):
residual = x
......@@ -140,14 +145,14 @@ class Bottleneck(nn.Module):
return out
def make_res_layer(block,
inplanes,
planes,
blocks,
stride=1,
dilation=1,
style='pytorch',
with_cp=False):
def make_res_layer(block: nn.Module,
inplanes: int,
planes: int,
blocks: int,
stride: int = 1,
dilation: int = 1,
style: str = 'pytorch',
with_cp: bool = False) -> nn.Module:
downsample = None
if stride != 1 or inplanes != planes * block.expansion:
downsample = nn.Sequential(
......@@ -208,22 +213,22 @@ class ResNet(nn.Module):
}
def __init__(self,
depth,
num_stages=4,
strides=(1, 2, 2, 2),
dilations=(1, 1, 1, 1),
out_indices=(0, 1, 2, 3),
style='pytorch',
frozen_stages=-1,
bn_eval=True,
bn_frozen=False,
with_cp=False):
depth: int,
num_stages: int = 4,
strides: Sequence[int] = (1, 2, 2, 2),
dilations: Sequence[int] = (1, 1, 1, 1),
out_indices: Sequence[int] = (0, 1, 2, 3),
style: str = 'pytorch',
frozen_stages: int = -1,
bn_eval: bool = True,
bn_frozen: bool = False,
with_cp: bool = False):
super().__init__()
if depth not in self.arch_settings:
raise KeyError(f'invalid depth {depth} for resnet')
assert num_stages >= 1 and num_stages <= 4
block, stage_blocks = self.arch_settings[depth]
stage_blocks = stage_blocks[:num_stages]
stage_blocks = stage_blocks[:num_stages] # type: ignore
assert len(strides) == len(dilations) == num_stages
assert max(out_indices) < num_stages
......@@ -234,7 +239,7 @@ class ResNet(nn.Module):
self.bn_frozen = bn_frozen
self.with_cp = with_cp
self.inplanes = 64
self.inplanes: int = 64
self.conv1 = nn.Conv2d(
3, 64, kernel_size=7, stride=2, padding=3, bias=False)
self.bn1 = nn.BatchNorm2d(64)
......@@ -255,14 +260,15 @@ class ResNet(nn.Module):
dilation=dilation,
style=self.style,
with_cp=with_cp)
self.inplanes = planes * block.expansion
self.inplanes = planes * block.expansion # type: ignore
layer_name = f'layer{i + 1}'
self.add_module(layer_name, res_layer)
self.res_layers.append(layer_name)
self.feat_dim = block.expansion * 64 * 2**(len(stage_blocks) - 1)
self.feat_dim = block.expansion * 64 * 2**( # type: ignore
len(stage_blocks) - 1)
def init_weights(self, pretrained=None):
def init_weights(self, pretrained: Optional[str] = None) -> None:
if isinstance(pretrained, str):
logger = logging.getLogger()
from ..runner import load_checkpoint
......@@ -276,7 +282,7 @@ class ResNet(nn.Module):
else:
raise TypeError('pretrained must be a str or None')
def forward(self, x):
def forward(self, x: Tensor) -> Union[Tensor, Tuple[Tensor]]:
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
......@@ -292,7 +298,7 @@ class ResNet(nn.Module):
else:
return tuple(outs)
def train(self, mode=True):
def train(self, mode: bool = True) -> None:
super().train(mode)
if self.bn_eval:
for m in self.modules():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment