Commit fdeee889 authored by limm

release v1.6.1 of mmcv

parent df465820
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Union
import torch
from torch import nn
@@ -6,7 +8,7 @@ from ..utils import constant_init, kaiming_init
from .registry import PLUGIN_LAYERS
def last_zero_init(m):
def last_zero_init(m: Union[nn.Module, nn.Sequential]) -> None:
if isinstance(m, nn.Sequential):
constant_init(m[-1], val=0)
else:
@@ -34,11 +36,11 @@ class ContextBlock(nn.Module):
_abbr_ = 'context_block'
def __init__(self,
in_channels,
ratio,
pooling_type='att',
fusion_types=('channel_add', )):
super(ContextBlock, self).__init__()
in_channels: int,
ratio: float,
pooling_type: str = 'att',
fusion_types: tuple = ('channel_add', )):
super().__init__()
assert pooling_type in ['avg', 'att']
assert isinstance(fusion_types, (list, tuple))
valid_fusion_types = ['channel_add', 'channel_mul']
@@ -82,7 +84,7 @@ class ContextBlock(nn.Module):
if self.channel_mul_conv is not None:
last_zero_init(self.channel_mul_conv)
def spatial_pool(self, x):
def spatial_pool(self, x: torch.Tensor) -> torch.Tensor:
batch, channel, height, width = x.size()
if self.pooling_type == 'att':
input_x = x
@@ -108,7 +110,7 @@ class ContextBlock(nn.Module):
return context
def forward(self, x):
def forward(self, x: torch.Tensor) -> torch.Tensor:
# [N, C, 1, 1]
context = self.spatial_pool(x)
......
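A minimal usage sketch of the annotated `ContextBlock`; the shapes and the `mmcv.cnn` import path are illustrative assumptions, not part of this commit:

```python
import torch
from mmcv.cnn import ContextBlock

# Hypothetical feature map: batch 2, 64 channels, 32x32 spatial.
x = torch.rand(2, 64, 32, 32)
# ratio=1/4 bottlenecks the channel transform to 16 channels.
block = ContextBlock(in_channels=64, ratio=1. / 4.)
out = block(x)
assert out.shape == x.shape  # fusion is residual, so shape is preserved
```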
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, Optional
from torch import nn
from .registry import CONV_LAYERS
@@ -9,7 +11,7 @@ CONV_LAYERS.register_module('Conv3d', module=nn.Conv3d)
CONV_LAYERS.register_module('Conv', module=nn.Conv2d)
def build_conv_layer(cfg, *args, **kwargs):
def build_conv_layer(cfg: Optional[Dict], *args, **kwargs) -> nn.Module:
"""Build convolution layer.
Args:
@@ -35,7 +37,7 @@ def build_conv_layer(cfg, *args, **kwargs):
layer_type = cfg_.pop('type')
if layer_type not in CONV_LAYERS:
raise KeyError(f'Unrecognized norm type {layer_type}')
raise KeyError(f'Unrecognized layer type {layer_type}')
else:
conv_layer = CONV_LAYERS.get(layer_type)
......
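A short sketch of how `build_conv_layer` resolves its `cfg` argument (layer sizes are illustrative):

```python
import torch.nn as nn
from mmcv.cnn import build_conv_layer

# cfg=None falls back to a plain Conv2d.
conv = build_conv_layer(None, 3, 8, kernel_size=3, padding=1)
assert isinstance(conv, nn.Conv2d)

# An explicit type selects a layer registered in CONV_LAYERS.
conv3d = build_conv_layer(dict(type='Conv3d'), 3, 8, kernel_size=3)
assert isinstance(conv3d, nn.Conv3d)
```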
# Copyright (c) OpenMMLab. All rights reserved.
import math
from typing import Tuple, Union
import torch
from torch import nn
from torch.nn import functional as F
@@ -31,18 +33,18 @@ class Conv2dAdaptivePadding(nn.Conv2d):
"""
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
bias=True):
in_channels: int,
out_channels: int,
kernel_size: Union[int, Tuple[int, int]],
stride: Union[int, Tuple[int, int]] = 1,
padding: Union[int, Tuple[int, int]] = 0,
dilation: Union[int, Tuple[int, int]] = 1,
groups: int = 1,
bias: bool = True):
super().__init__(in_channels, out_channels, kernel_size, stride, 0,
dilation, groups, bias)
def forward(self, x):
def forward(self, x: torch.Tensor) -> torch.Tensor:
img_h, img_w = x.size()[-2:]
kernel_h, kernel_w = self.weight.size()[-2:]
stride_h, stride_w = self.stride
......
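A sketch of the adaptive-padding behavior: the output spatial size is ceil(input / stride) regardless of whether the input divides evenly. The direct bricks import path is an assumption:

```python
import torch
from mmcv.cnn.bricks.conv2d_adaptive_padding import Conv2dAdaptivePadding

conv = Conv2dAdaptivePadding(3, 8, kernel_size=3, stride=2)
out = conv(torch.rand(1, 3, 15, 17))
# ceil(15 / 2) = 8, ceil(17 / 2) = 9, as in TensorFlow "SAME" padding.
assert out.shape == (1, 8, 8, 9)
```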
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from typing import Dict, Optional, Tuple, Union
import torch
import torch.nn as nn
from mmcv.utils import _BatchNorm, _InstanceNorm
@@ -68,22 +70,22 @@ class ConvModule(nn.Module):
_abbr_ = 'conv_block'
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
bias='auto',
conv_cfg=None,
norm_cfg=None,
act_cfg=dict(type='ReLU'),
inplace=True,
with_spectral_norm=False,
padding_mode='zeros',
order=('conv', 'norm', 'act')):
super(ConvModule, self).__init__()
in_channels: int,
out_channels: int,
kernel_size: Union[int, Tuple[int, int]],
stride: Union[int, Tuple[int, int]] = 1,
padding: Union[int, Tuple[int, int]] = 0,
dilation: Union[int, Tuple[int, int]] = 1,
groups: int = 1,
bias: Union[bool, str] = 'auto',
conv_cfg: Optional[Dict] = None,
norm_cfg: Optional[Dict] = None,
act_cfg: Optional[Dict] = dict(type='ReLU'),
inplace: bool = True,
with_spectral_norm: bool = False,
padding_mode: str = 'zeros',
order: tuple = ('conv', 'norm', 'act')):
super().__init__()
assert conv_cfg is None or isinstance(conv_cfg, dict)
assert norm_cfg is None or isinstance(norm_cfg, dict)
assert act_cfg is None or isinstance(act_cfg, dict)
@@ -96,7 +98,7 @@ class ConvModule(nn.Module):
self.with_explicit_padding = padding_mode not in official_padding_mode
self.order = order
assert isinstance(self.order, tuple) and len(self.order) == 3
assert set(order) == set(['conv', 'norm', 'act'])
assert set(order) == {'conv', 'norm', 'act'}
self.with_norm = norm_cfg is not None
self.with_activation = act_cfg is not None
@@ -143,21 +145,22 @@ class ConvModule(nn.Module):
norm_channels = out_channels
else:
norm_channels = in_channels
self.norm_name, norm = build_norm_layer(norm_cfg, norm_channels)
self.norm_name, norm = build_norm_layer(
norm_cfg, norm_channels) # type: ignore
self.add_module(self.norm_name, norm)
if self.with_bias:
if isinstance(norm, (_BatchNorm, _InstanceNorm)):
warnings.warn(
'Unnecessary conv bias before batch/instance norm')
else:
self.norm_name = None
self.norm_name = None # type: ignore
# build activation layer
if self.with_activation:
act_cfg_ = act_cfg.copy()
act_cfg_ = act_cfg.copy() # type: ignore
# nn.Tanh has no 'inplace' argument
if act_cfg_['type'] not in [
'Tanh', 'PReLU', 'Sigmoid', 'HSigmoid', 'Swish'
'Tanh', 'PReLU', 'Sigmoid', 'HSigmoid', 'Swish', 'GELU'
]:
act_cfg_.setdefault('inplace', inplace)
self.activate = build_activation_layer(act_cfg_)
@@ -193,7 +196,10 @@ class ConvModule(nn.Module):
if self.with_norm:
constant_init(self.norm, 1, bias=0)
def forward(self, x, activate=True, norm=True):
def forward(self,
x: torch.Tensor,
activate: bool = True,
norm: bool = True) -> torch.Tensor:
for layer in self.order:
if layer == 'conv':
if self.with_explicit_padding:
......
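A usage sketch of `ConvModule` and its per-call `activate`/`norm` switches; configuration values here are illustrative:

```python
import torch
from mmcv.cnn import ConvModule

# conv -> BN -> ReLU, following the default ('conv', 'norm', 'act') order.
m = ConvModule(3, 16, 3, padding=1, norm_cfg=dict(type='BN'))
y = m(torch.rand(1, 3, 32, 32))

# Norm and activation can be skipped per call, e.g. for pre-act variants.
y_raw = m(torch.rand(1, 3, 32, 32), activate=False, norm=False)
```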
# Copyright (c) OpenMMLab. All rights reserved.
from collections import OrderedDict
from typing import Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
@@ -6,14 +9,14 @@ import torch.nn.functional as F
from .registry import CONV_LAYERS
def conv_ws_2d(input,
weight,
bias=None,
stride=1,
padding=0,
dilation=1,
groups=1,
eps=1e-5):
def conv_ws_2d(input: torch.Tensor,
weight: torch.Tensor,
bias: Optional[torch.Tensor] = None,
stride: Union[int, Tuple[int, int]] = 1,
padding: Union[int, Tuple[int, int]] = 0,
dilation: Union[int, Tuple[int, int]] = 1,
groups: int = 1,
eps: float = 1e-5) -> torch.Tensor:
c_in = weight.size(0)
weight_flat = weight.view(c_in, -1)
mean = weight_flat.mean(dim=1, keepdim=True).view(c_in, 1, 1, 1)
@@ -26,16 +29,16 @@ def conv_ws_2d(input,
class ConvWS2d(nn.Conv2d):
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
bias=True,
eps=1e-5):
super(ConvWS2d, self).__init__(
in_channels: int,
out_channels: int,
kernel_size: Union[int, Tuple[int, int]],
stride: Union[int, Tuple[int, int]] = 1,
padding: Union[int, Tuple[int, int]] = 0,
dilation: Union[int, Tuple[int, int]] = 1,
groups: int = 1,
bias: bool = True,
eps: float = 1e-5):
super().__init__(
in_channels,
out_channels,
kernel_size,
@@ -46,7 +49,7 @@ class ConvWS2d(nn.Conv2d):
bias=bias)
self.eps = eps
def forward(self, x):
def forward(self, x: torch.Tensor) -> torch.Tensor:
return conv_ws_2d(x, self.weight, self.bias, self.stride, self.padding,
self.dilation, self.groups, self.eps)
@@ -76,14 +79,14 @@ class ConvAWS2d(nn.Conv2d):
"""
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
bias=True):
in_channels: int,
out_channels: int,
kernel_size: Union[int, Tuple[int, int]],
stride: Union[int, Tuple[int, int]] = 1,
padding: Union[int, Tuple[int, int]] = 0,
dilation: Union[int, Tuple[int, int]] = 1,
groups: int = 1,
bias: bool = True):
super().__init__(
in_channels,
out_channels,
@@ -98,7 +101,7 @@ class ConvAWS2d(nn.Conv2d):
self.register_buffer('weight_beta',
torch.zeros(self.out_channels, 1, 1, 1))
def _get_weight(self, weight):
def _get_weight(self, weight: torch.Tensor) -> torch.Tensor:
weight_flat = weight.view(weight.size(0), -1)
mean = weight_flat.mean(dim=1).view(-1, 1, 1, 1)
std = torch.sqrt(weight_flat.var(dim=1) + 1e-5).view(-1, 1, 1, 1)
@@ -106,13 +109,16 @@ class ConvAWS2d(nn.Conv2d):
weight = self.weight_gamma * weight + self.weight_beta
return weight
def forward(self, x):
def forward(self, x: torch.Tensor) -> torch.Tensor:
weight = self._get_weight(self.weight)
return F.conv2d(x, weight, self.bias, self.stride, self.padding,
self.dilation, self.groups)
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs):
def _load_from_state_dict(self, state_dict: OrderedDict, prefix: str,
local_metadata: Dict, strict: bool,
missing_keys: List[str],
unexpected_keys: List[str],
error_msgs: List[str]) -> None:
"""Override default load function.
AWS overrides the function _load_from_state_dict to recover
@@ -124,7 +130,7 @@ class ConvAWS2d(nn.Conv2d):
"""
self.weight_gamma.data.fill_(-1)
local_missing_keys = []
local_missing_keys: List = []
super()._load_from_state_dict(state_dict, prefix, local_metadata,
strict, local_missing_keys,
unexpected_keys, error_msgs)
......
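A sketch contrasting the stored weight with the standardized weight that `ConvWS2d` actually convolves with; the `mmcv.cnn` import path is assumed:

```python
import torch
from mmcv.cnn import ConvWS2d

conv = ConvWS2d(3, 8, kernel_size=3, padding=1, eps=1e-5)
out = conv(torch.rand(1, 3, 16, 16))

# conv_ws_2d standardizes each output filter to zero mean and unit std
# on the fly; the stored conv.weight parameter itself is left untouched.
w = conv.weight.view(conv.weight.size(0), -1)
print(w.mean(dim=1))  # generally non-zero
```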
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, Optional, Tuple, Union
import torch
import torch.nn as nn
from .conv_module import ConvModule
@@ -46,27 +49,27 @@ class DepthwiseSeparableConvModule(nn.Module):
"""
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
norm_cfg=None,
act_cfg=dict(type='ReLU'),
dw_norm_cfg='default',
dw_act_cfg='default',
pw_norm_cfg='default',
pw_act_cfg='default',
in_channels: int,
out_channels: int,
kernel_size: Union[int, Tuple[int, int]],
stride: Union[int, Tuple[int, int]] = 1,
padding: Union[int, Tuple[int, int]] = 0,
dilation: Union[int, Tuple[int, int]] = 1,
norm_cfg: Optional[Dict] = None,
act_cfg: Dict = dict(type='ReLU'),
dw_norm_cfg: Union[Dict, str] = 'default',
dw_act_cfg: Union[Dict, str] = 'default',
pw_norm_cfg: Union[Dict, str] = 'default',
pw_act_cfg: Union[Dict, str] = 'default',
**kwargs):
super(DepthwiseSeparableConvModule, self).__init__()
super().__init__()
assert 'groups' not in kwargs, 'groups should not be specified'
# if norm/activation config of depthwise/pointwise ConvModule is not
# specified, use default config.
dw_norm_cfg = dw_norm_cfg if dw_norm_cfg != 'default' else norm_cfg
dw_norm_cfg = dw_norm_cfg if dw_norm_cfg != 'default' else norm_cfg # type: ignore # noqa E501
dw_act_cfg = dw_act_cfg if dw_act_cfg != 'default' else act_cfg
pw_norm_cfg = pw_norm_cfg if pw_norm_cfg != 'default' else norm_cfg
pw_norm_cfg = pw_norm_cfg if pw_norm_cfg != 'default' else norm_cfg # type: ignore # noqa E501
pw_act_cfg = pw_act_cfg if pw_act_cfg != 'default' else act_cfg
# depthwise convolution
@@ -78,19 +81,19 @@ class DepthwiseSeparableConvModule(nn.Module):
padding=padding,
dilation=dilation,
groups=in_channels,
norm_cfg=dw_norm_cfg,
act_cfg=dw_act_cfg,
norm_cfg=dw_norm_cfg, # type: ignore
act_cfg=dw_act_cfg, # type: ignore
**kwargs)
self.pointwise_conv = ConvModule(
in_channels,
out_channels,
1,
norm_cfg=pw_norm_cfg,
act_cfg=pw_act_cfg,
norm_cfg=pw_norm_cfg, # type: ignore
act_cfg=pw_act_cfg, # type: ignore
**kwargs)
def forward(self, x):
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.depthwise_conv(x)
x = self.pointwise_conv(x)
return x
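A sketch of the separable module; with the 'default' fallbacks above, both sub-convs inherit `norm_cfg`/`act_cfg` here (sizes illustrative):

```python
import torch
from mmcv.cnn import DepthwiseSeparableConvModule

# 3x3 depthwise conv (groups=in_channels) followed by a 1x1 pointwise conv.
m = DepthwiseSeparableConvModule(32, 64, 3, padding=1,
                                 norm_cfg=dict(type='BN'))
y = m(torch.rand(1, 32, 56, 56))
assert y.shape == (1, 64, 56, 56)
```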
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Any, Dict, Optional
import torch
import torch.nn as nn
@@ -6,7 +8,9 @@ from mmcv import build_from_cfg
from .registry import DROPOUT_LAYERS
def drop_path(x, drop_prob=0., training=False):
def drop_path(x: torch.Tensor,
drop_prob: float = 0.,
training: bool = False) -> torch.Tensor:
"""Drop paths (Stochastic Depth) per sample (when applied in main path of
residual blocks).
@@ -36,11 +40,11 @@ class DropPath(nn.Module):
drop_prob (float): Probability of the path to be zeroed. Default: 0.1
"""
def __init__(self, drop_prob=0.1):
super(DropPath, self).__init__()
def __init__(self, drop_prob: float = 0.1):
super().__init__()
self.drop_prob = drop_prob
def forward(self, x):
def forward(self, x: torch.Tensor) -> torch.Tensor:
return drop_path(x, self.drop_prob, self.training)
@@ -56,10 +60,10 @@ class Dropout(nn.Dropout):
inplace (bool): Do the operation inplace or not. Default: False.
"""
def __init__(self, drop_prob=0.5, inplace=False):
def __init__(self, drop_prob: float = 0.5, inplace: bool = False):
super().__init__(p=drop_prob, inplace=inplace)
def build_dropout(cfg, default_args=None):
def build_dropout(cfg: Dict, default_args: Optional[Dict] = None) -> Any:
"""Builder for drop out layers."""
return build_from_cfg(cfg, DROPOUT_LAYERS, default_args)
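A sketch of `DropPath` semantics in train versus eval mode; the bricks import path and tensor shape are assumptions:

```python
import torch
from mmcv.cnn.bricks import DropPath

dp = DropPath(drop_prob=0.1)
x = torch.rand(4, 197, 768)

dp.train()
y = dp(x)  # roughly 10% of samples zeroed, survivors scaled by 1 / 0.9

dp.eval()
assert torch.equal(dp(x), x)  # identity at inference time
```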
@@ -45,16 +45,16 @@ class GeneralizedAttention(nn.Module):
_abbr_ = 'gen_attention_block'
def __init__(self,
in_channels,
spatial_range=-1,
num_heads=9,
position_embedding_dim=-1,
position_magnitude=1,
kv_stride=2,
q_stride=1,
attention_type='1111'):
in_channels: int,
spatial_range: int = -1,
num_heads: int = 9,
position_embedding_dim: int = -1,
position_magnitude: int = 1,
kv_stride: int = 2,
q_stride: int = 1,
attention_type: str = '1111'):
super(GeneralizedAttention, self).__init__()
super().__init__()
# hard range means local range for non-local operation
self.position_embedding_dim = (
@@ -131,7 +131,7 @@ class GeneralizedAttention(nn.Module):
max_len_kv = int((max_len - 1.0) / self.kv_stride + 1)
local_constraint_map = np.ones(
(max_len, max_len, max_len_kv, max_len_kv), dtype=np.int)
(max_len, max_len, max_len_kv, max_len_kv), dtype=int)
for iy in range(max_len):
for ix in range(max_len):
local_constraint_map[
@@ -213,7 +213,7 @@ class GeneralizedAttention(nn.Module):
return embedding_x, embedding_y
def forward(self, x_input):
def forward(self, x_input: torch.Tensor) -> torch.Tensor:
num_heads = self.num_heads
# use empirical_attention
@@ -351,7 +351,7 @@ class GeneralizedAttention(nn.Module):
repeat(n, 1, 1, 1)
position_feat_x_reshape = position_feat_x.\
view(n, num_heads, w*w_kv, self.qk_embed_dim)
view(n, num_heads, w * w_kv, self.qk_embed_dim)
position_feat_y_reshape = position_feat_y.\
view(n, num_heads, h * h_kv, self.qk_embed_dim)
......
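A minimal forward-shape sketch for `GeneralizedAttention`; channel count and heads are illustrative (in_channels should be divisible by num_heads):

```python
import torch
from mmcv.cnn import GeneralizedAttention

attn = GeneralizedAttention(in_channels=64, num_heads=8)
out = attn(torch.rand(1, 64, 32, 32))
assert out.shape == (1, 64, 32, 32)  # the block is shape-preserving
```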
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
import torch
import torch.nn as nn
from .registry import ACTIVATION_LAYERS
@@ -8,11 +11,15 @@ from .registry import ACTIVATION_LAYERS
class HSigmoid(nn.Module):
"""Hard Sigmoid Module. Apply the hard sigmoid function:
Hsigmoid(x) = min(max((x + bias) / divisor, min_value), max_value)
Default: Hsigmoid(x) = min(max((x + 1) / 2, 0), 1)
Default: Hsigmoid(x) = min(max((x + 3) / 6, 0), 1)
Note:
In MMCV v1.4.4, we modified the default value of args to align with
PyTorch official.
Args:
bias (float): Bias of the input feature map. Default: 1.0.
divisor (float): Divisor of the input feature map. Default: 2.0.
bias (float): Bias of the input feature map. Default: 3.0.
divisor (float): Divisor of the input feature map. Default: 6.0.
min_value (float): Lower bound value. Default: 0.0.
max_value (float): Upper bound value. Default: 1.0.
@@ -20,15 +27,25 @@ class HSigmoid(nn.Module):
Tensor: The output tensor.
"""
def __init__(self, bias=1.0, divisor=2.0, min_value=0.0, max_value=1.0):
super(HSigmoid, self).__init__()
def __init__(self,
bias: float = 3.0,
divisor: float = 6.0,
min_value: float = 0.0,
max_value: float = 1.0):
super().__init__()
warnings.warn(
'In MMCV v1.4.4, we modified the default value of args to align '
'with PyTorch official. Previous Implementation: '
'Hsigmoid(x) = min(max((x + 1) / 2, 0), 1). '
'Current Implementation: '
'Hsigmoid(x) = min(max((x + 3) / 6, 0), 1).')
self.bias = bias
self.divisor = divisor
assert self.divisor != 0
self.min_value = min_value
self.max_value = max_value
def forward(self, x):
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = (x + self.bias) / self.divisor
return x.clamp_(self.min_value, self.max_value)
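A worked check of the new defaults (bias=3, divisor=6); the bricks import path is an assumption:

```python
import torch
from mmcv.cnn.bricks import HSigmoid

m = HSigmoid()  # min(max((x + 3) / 6, 0), 1), matching nn.Hardsigmoid
x = torch.tensor([-4., 0., 4.])
print(m(x))  # tensor([0.0000, 0.5000, 1.0000])
# (-4 + 3) / 6 < 0 -> clamped to 0; (0 + 3) / 6 = 0.5; (4 + 3) / 6 > 1 -> 1
```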
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.utils import TORCH_VERSION, digit_version
from .registry import ACTIVATION_LAYERS
@ACTIVATION_LAYERS.register_module()
class HSwish(nn.Module):
"""Hard Swish Module.
@@ -21,9 +22,18 @@ class HSwish(nn.Module):
Tensor: The output tensor.
"""
def __init__(self, inplace=False):
super(HSwish, self).__init__()
def __init__(self, inplace: bool = False):
super().__init__()
self.act = nn.ReLU6(inplace)
def forward(self, x):
def forward(self, x: torch.Tensor) -> torch.Tensor:
return x * self.act(x + 3) / 6
if (TORCH_VERSION == 'parrots'
or digit_version(TORCH_VERSION) < digit_version('1.7')):
# Hardswish is not supported when PyTorch version < 1.6.
# And Hardswish in PyTorch 1.6 does not support inplace.
ACTIVATION_LAYERS.register_module(module=HSwish)
else:
ACTIVATION_LAYERS.register_module(module=nn.Hardswish, name='HSwish')
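A sketch of the version gate above: either branch yields the same function x * relu6(x + 3) / 6, so callers can build by name regardless of the PyTorch version:

```python
import torch
from mmcv.cnn import build_activation_layer

# Resolves to mmcv's HSwish on PyTorch < 1.7 and nn.Hardswish otherwise.
act = build_activation_layer(dict(type='HSwish'))
x = torch.tensor([-4., 0., 4.])
print(act(x))  # tensor([-0., 0., 4.]): inputs below -3 are zeroed
```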
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABCMeta
from typing import Dict, Optional
import torch
import torch.nn as nn
@@ -33,14 +34,14 @@ class _NonLocalNd(nn.Module, metaclass=ABCMeta):
"""
def __init__(self,
in_channels,
reduction=2,
use_scale=True,
conv_cfg=None,
norm_cfg=None,
mode='embedded_gaussian',
in_channels: int,
reduction: int = 2,
use_scale: bool = True,
conv_cfg: Optional[Dict] = None,
norm_cfg: Optional[Dict] = None,
mode: str = 'embedded_gaussian',
**kwargs):
super(_NonLocalNd, self).__init__()
super().__init__()
self.in_channels = in_channels
self.reduction = reduction
self.use_scale = use_scale
@@ -61,7 +62,7 @@ class _NonLocalNd(nn.Module, metaclass=ABCMeta):
self.inter_channels,
kernel_size=1,
conv_cfg=conv_cfg,
act_cfg=None)
act_cfg=None) # type: ignore
self.conv_out = ConvModule(
self.inter_channels,
self.in_channels,
@@ -96,7 +97,7 @@ class _NonLocalNd(nn.Module, metaclass=ABCMeta):
self.init_weights(**kwargs)
def init_weights(self, std=0.01, zeros_init=True):
def init_weights(self, std: float = 0.01, zeros_init: bool = True) -> None:
if self.mode != 'gaussian':
for m in [self.g, self.theta, self.phi]:
normal_init(m.conv, std=std)
@@ -113,7 +114,8 @@ class _NonLocalNd(nn.Module, metaclass=ABCMeta):
else:
normal_init(self.conv_out.norm, std=std)
def gaussian(self, theta_x, phi_x):
def gaussian(self, theta_x: torch.Tensor,
phi_x: torch.Tensor) -> torch.Tensor:
# NonLocal1d pairwise_weight: [N, H, H]
# NonLocal2d pairwise_weight: [N, HxW, HxW]
# NonLocal3d pairwise_weight: [N, TxHxW, TxHxW]
@@ -121,7 +123,8 @@ class _NonLocalNd(nn.Module, metaclass=ABCMeta):
pairwise_weight = pairwise_weight.softmax(dim=-1)
return pairwise_weight
def embedded_gaussian(self, theta_x, phi_x):
def embedded_gaussian(self, theta_x: torch.Tensor,
phi_x: torch.Tensor) -> torch.Tensor:
# NonLocal1d pairwise_weight: [N, H, H]
# NonLocal2d pairwise_weight: [N, HxW, HxW]
# NonLocal3d pairwise_weight: [N, TxHxW, TxHxW]
@@ -132,7 +135,8 @@ class _NonLocalNd(nn.Module, metaclass=ABCMeta):
pairwise_weight = pairwise_weight.softmax(dim=-1)
return pairwise_weight
def dot_product(self, theta_x, phi_x):
def dot_product(self, theta_x: torch.Tensor,
phi_x: torch.Tensor) -> torch.Tensor:
# NonLocal1d pairwise_weight: [N, H, H]
# NonLocal2d pairwise_weight: [N, HxW, HxW]
# NonLocal3d pairwise_weight: [N, TxHxW, TxHxW]
@@ -140,7 +144,8 @@ class _NonLocalNd(nn.Module, metaclass=ABCMeta):
pairwise_weight /= pairwise_weight.shape[-1]
return pairwise_weight
def concatenation(self, theta_x, phi_x):
def concatenation(self, theta_x: torch.Tensor,
phi_x: torch.Tensor) -> torch.Tensor:
# NonLocal1d pairwise_weight: [N, H, H]
# NonLocal2d pairwise_weight: [N, HxW, HxW]
# NonLocal3d pairwise_weight: [N, TxHxW, TxHxW]
@@ -157,7 +162,7 @@ class _NonLocalNd(nn.Module, metaclass=ABCMeta):
return pairwise_weight
def forward(self, x):
def forward(self, x: torch.Tensor) -> torch.Tensor:
# Assume `reduction = 1`, then `inter_channels = C`
# or `inter_channels = C` when `mode="gaussian"`
@@ -224,12 +229,11 @@ class NonLocal1d(_NonLocalNd):
"""
def __init__(self,
in_channels,
sub_sample=False,
conv_cfg=dict(type='Conv1d'),
in_channels: int,
sub_sample: bool = False,
conv_cfg: Dict = dict(type='Conv1d'),
**kwargs):
super(NonLocal1d, self).__init__(
in_channels, conv_cfg=conv_cfg, **kwargs)
super().__init__(in_channels, conv_cfg=conv_cfg, **kwargs)
self.sub_sample = sub_sample
@@ -258,12 +262,11 @@ class NonLocal2d(_NonLocalNd):
_abbr_ = 'nonlocal_block'
def __init__(self,
in_channels,
sub_sample=False,
conv_cfg=dict(type='Conv2d'),
in_channels: int,
sub_sample: bool = False,
conv_cfg: Dict = dict(type='Conv2d'),
**kwargs):
super(NonLocal2d, self).__init__(
in_channels, conv_cfg=conv_cfg, **kwargs)
super().__init__(in_channels, conv_cfg=conv_cfg, **kwargs)
self.sub_sample = sub_sample
@@ -289,12 +292,11 @@ class NonLocal3d(_NonLocalNd):
"""
def __init__(self,
in_channels,
sub_sample=False,
conv_cfg=dict(type='Conv3d'),
in_channels: int,
sub_sample: bool = False,
conv_cfg: Dict = dict(type='Conv3d'),
**kwargs):
super(NonLocal3d, self).__init__(
in_channels, conv_cfg=conv_cfg, **kwargs)
super().__init__(in_channels, conv_cfg=conv_cfg, **kwargs)
self.sub_sample = sub_sample
if sub_sample:
......
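A usage sketch of `NonLocal2d` in its default embedded-Gaussian mode; sizes are illustrative:

```python
import torch
from mmcv.cnn import NonLocal2d

# reduction=2 halves inter_channels; use_scale rescales the dot product.
block = NonLocal2d(in_channels=64, reduction=2, mode='embedded_gaussian')
out = block(torch.rand(2, 64, 20, 20))
assert out.shape == (2, 64, 20, 20)  # residual connection preserves shape
```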
# Copyright (c) OpenMMLab. All rights reserved.
import inspect
from typing import Dict, Tuple, Union
import torch.nn as nn
@@ -69,7 +70,9 @@ def infer_abbr(class_type):
return 'norm_layer'
def build_norm_layer(cfg, num_features, postfix=''):
def build_norm_layer(cfg: Dict,
num_features: int,
postfix: Union[int, str] = '') -> Tuple[str, nn.Module]:
"""Build normalization layer.
Args:
@@ -83,9 +86,9 @@ def build_norm_layer(cfg, num_features, postfix=''):
to create named layer.
Returns:
(str, nn.Module): The first element is the layer name consisting of
abbreviation and postfix, e.g., bn1, gn. The second element is the
created norm layer.
tuple[str, nn.Module]: The first element is the layer name consisting
of abbreviation and postfix, e.g., bn1, gn. The second element is the
created norm layer.
"""
if not isinstance(cfg, dict):
raise TypeError('cfg must be a dict')
@@ -119,7 +122,8 @@ def build_norm_layer(cfg, num_features, postfix=''):
return name, layer
def is_norm(layer, exclude=None):
def is_norm(layer: nn.Module,
exclude: Union[type, tuple, None] = None) -> bool:
"""Check if a layer is a normalization layer.
Args:
......
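A sketch of the `(name, layer)` contract documented above; values are illustrative:

```python
import torch.nn as nn
from mmcv.cnn import build_norm_layer, is_norm

# The returned name combines the abbreviation and postfix, e.g. 'bn1',
# so callers can add_module() it under a deterministic attribute name.
name, layer = build_norm_layer(dict(type='BN'), num_features=64, postfix=1)
assert name == 'bn1' and isinstance(layer, nn.BatchNorm2d)
assert is_norm(layer)
```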
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict
import torch.nn as nn
from .registry import PADDING_LAYERS
@@ -8,11 +10,11 @@ PADDING_LAYERS.register_module('reflect', module=nn.ReflectionPad2d)
PADDING_LAYERS.register_module('replicate', module=nn.ReplicationPad2d)
def build_padding_layer(cfg, *args, **kwargs):
def build_padding_layer(cfg: Dict, *args, **kwargs) -> nn.Module:
"""Build padding layer.
Args:
cfg (None or dict): The padding layer config, which should contain:
cfg (dict): The padding layer config, which should contain:
- type (str): Layer type.
- layer args: Args needed to instantiate a padding layer.
......
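A sketch of `build_padding_layer` resolving a registered type; the positional argument is illustrative:

```python
import torch.nn as nn
from mmcv.cnn import build_padding_layer

# 'reflect' maps to nn.ReflectionPad2d via PADDING_LAYERS; the extra
# positional arg is forwarded to the layer constructor.
pad = build_padding_layer(dict(type='reflect'), 1)
assert isinstance(pad, nn.ReflectionPad2d)
```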
# Copyright (c) OpenMMLab. All rights reserved.
import inspect
import platform
from typing import Dict, Tuple, Union
import torch.nn as nn
from .registry import PLUGIN_LAYERS
if platform.system() == 'Windows':
import regex as re
import regex as re # type: ignore
else:
import re
import re # type: ignore
def infer_abbr(class_type):
def infer_abbr(class_type: type) -> str:
"""Infer abbreviation from the class name.
This method will infer the abbreviation to map class types to
@@ -47,25 +51,27 @@ def infer_abbr(class_type):
raise TypeError(
f'class_type must be a type, but got {type(class_type)}')
if hasattr(class_type, '_abbr_'):
return class_type._abbr_
return class_type._abbr_ # type: ignore
else:
return camel2snack(class_type.__name__)
def build_plugin_layer(cfg, postfix='', **kwargs):
def build_plugin_layer(cfg: Dict,
postfix: Union[int, str] = '',
**kwargs) -> Tuple[str, nn.Module]:
"""Build plugin layer.
Args:
cfg (None or dict): cfg should contain:
type (str): identify plugin layer type.
layer args: args needed to instantiate a plugin layer.
cfg (dict): cfg should contain:
- type (str): identify plugin layer type.
- layer args: args needed to instantiate a plugin layer.
postfix (int, str): appended into norm abbreviation to
create named layer. Default: ''.
Returns:
tuple[str, nn.Module]:
name (str): abbreviation + postfix
layer (nn.Module): created plugin layer
tuple[str, nn.Module]: The first one is the concatenation of
abbreviation and postfix. The second is the created plugin layer.
"""
if not isinstance(cfg, dict):
raise TypeError('cfg must be a dict')
......
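A sketch of `build_plugin_layer` and the abbreviation-plus-postfix naming rule; the plugin config is illustrative:

```python
from mmcv.cnn import build_plugin_layer

# ContextBlock declares _abbr_ = 'context_block', so the returned name
# is that abbreviation followed by the postfix.
name, layer = build_plugin_layer(
    dict(type='ContextBlock', in_channels=64, ratio=1. / 4.), postfix=1)
assert name == 'context_block1'
```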
@@ -13,9 +13,9 @@ class Scale(nn.Module):
scale (float): Initial value of scale factor. Default: 1.0
"""
def __init__(self, scale=1.0):
super(Scale, self).__init__()
def __init__(self, scale: float = 1.0):
super().__init__()
self.scale = nn.Parameter(torch.tensor(scale, dtype=torch.float))
def forward(self, x):
def forward(self, x: torch.Tensor) -> torch.Tensor:
return x * self.scale
@@ -19,7 +19,7 @@ class Swish(nn.Module):
"""
def __init__(self):
super(Swish, self).__init__()
super().__init__()
def forward(self, x):
def forward(self, x: torch.Tensor) -> torch.Tensor:
return x * torch.sigmoid(x)
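A tiny sketch of the two modules above; the input values are illustrative:

```python
import torch
from mmcv.cnn import Scale, Swish

scale = Scale(0.5)  # learnable scalar multiplier, initialized to 0.5
swish = Swish()     # x * sigmoid(x)
x = torch.tensor([1., 2.])
print(scale(x))  # tensor([0.5000, 1.0000], grad_fn=...)
print(swish(x))  # tensor([0.7311, 1.7616])
```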
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import math
import warnings
from typing import Sequence
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv import ConfigDict, deprecated_api_warning
from mmcv.cnn import Linear, build_activation_layer, build_norm_layer
from mmcv.cnn import (Linear, build_activation_layer, build_conv_layer,
build_norm_layer)
from mmcv.runner.base_module import BaseModule, ModuleList, Sequential
from mmcv.utils import build_from_cfg
from mmcv.utils import (ConfigDict, build_from_cfg, deprecated_api_warning,
to_2tuple)
from .drop import build_dropout
from .registry import (ATTENTION, FEEDFORWARD_NETWORK, POSITIONAL_ENCODING,
TRANSFORMER_LAYER, TRANSFORMER_LAYER_SEQUENCE)
# Avoid BC-breaking of importing MultiScaleDeformableAttention from this file
try:
from mmcv.ops.multi_scale_deform_attn import MultiScaleDeformableAttention # noqa F401
from mmcv.ops.multi_scale_deform_attn import \
MultiScaleDeformableAttention # noqa F401
warnings.warn(
ImportWarning(
'``MultiScaleDeformableAttention`` has been moved to '
@@ -55,6 +60,349 @@ def build_transformer_layer_sequence(cfg, default_args=None):
return build_from_cfg(cfg, TRANSFORMER_LAYER_SEQUENCE, default_args)
class AdaptivePadding(nn.Module):
"""Applies padding adaptively to the input.
This module pads the input so that it is fully covered by the
specified filter. It supports two modes, "same" and "corner". The
"same" mode matches the "SAME" padding mode in TensorFlow, padding
zeros evenly around the input, while the "corner" mode pads zeros to
the bottom right only.
Args:
kernel_size (int | tuple): Size of the kernel. Default: 1.
stride (int | tuple): Stride of the filter. Default: 1.
dilation (int | tuple): Spacing between kernel elements.
Default: 1.
padding (str): Support "same" and "corner", "corner" mode
would pad zero to bottom right, and "same" mode would
pad zero around input. Default: "corner".
Example:
>>> kernel_size = 16
>>> stride = 16
>>> dilation = 1
>>> input = torch.rand(1, 1, 15, 17)
>>> adap_pad = AdaptivePadding(
>>> kernel_size=kernel_size,
>>> stride=stride,
>>> dilation=dilation,
>>> padding="corner")
>>> out = adap_pad(input)
>>> assert (out.shape[2], out.shape[3]) == (16, 32)
>>> input = torch.rand(1, 1, 16, 17)
>>> out = adap_pad(input)
>>> assert (out.shape[2], out.shape[3]) == (16, 32)
"""
def __init__(self, kernel_size=1, stride=1, dilation=1, padding='corner'):
super().__init__()
assert padding in ('same', 'corner')
kernel_size = to_2tuple(kernel_size)
stride = to_2tuple(stride)
dilation = to_2tuple(dilation)
self.padding = padding
self.kernel_size = kernel_size
self.stride = stride
self.dilation = dilation
def get_pad_shape(self, input_shape):
"""Calculate the padding size of input.
Args:
input_shape (:obj:`torch.Size`): arranged as (H, W).
Returns:
Tuple[int]: The padding size along the
original H and W directions
"""
input_h, input_w = input_shape
kernel_h, kernel_w = self.kernel_size
stride_h, stride_w = self.stride
output_h = math.ceil(input_h / stride_h)
output_w = math.ceil(input_w / stride_w)
pad_h = max((output_h - 1) * stride_h +
(kernel_h - 1) * self.dilation[0] + 1 - input_h, 0)
pad_w = max((output_w - 1) * stride_w +
(kernel_w - 1) * self.dilation[1] + 1 - input_w, 0)
return pad_h, pad_w
def forward(self, x):
"""Add padding to `x`
Args:
x (Tensor): Input tensor has shape (B, C, H, W).
Returns:
Tensor: The tensor with adaptive padding
"""
pad_h, pad_w = self.get_pad_shape(x.size()[-2:])
if pad_h > 0 or pad_w > 0:
if self.padding == 'corner':
x = F.pad(x, [0, pad_w, 0, pad_h])
elif self.padding == 'same':
x = F.pad(x, [
pad_w // 2, pad_w - pad_w // 2, pad_h // 2,
pad_h - pad_h // 2
])
return x
class PatchEmbed(BaseModule):
"""Image to Patch Embedding.
We use a conv layer to implement PatchEmbed.
Args:
in_channels (int): The num of input channels. Default: 3
embed_dims (int): The dimensions of embedding. Default: 768
conv_type (str): The type of convolution
to generate patch embedding. Default: "Conv2d".
kernel_size (int): The kernel_size of embedding conv. Default: 16.
stride (int): The stride of the embedding conv.
Default: 16.
padding (int | tuple | string): The padding length of the
embedding conv. When it is a string, it specifies the mode
of adaptive padding; "same" and "corner" are supported.
Default: "corner".
dilation (int): The dilation rate of embedding conv. Default: 1.
bias (bool): Bias of embed conv. Default: True.
norm_cfg (dict, optional): Config dict for normalization layer.
Default: None.
input_size (int | tuple | None): The size of input, which will be
used to calculate the out size. Only works when `dynamic_size`
is False. Default: None.
init_cfg (`mmcv.ConfigDict`, optional): The Config for initialization.
Default: None.
"""
def __init__(self,
in_channels=3,
embed_dims=768,
conv_type='Conv2d',
kernel_size=16,
stride=16,
padding='corner',
dilation=1,
bias=True,
norm_cfg=None,
input_size=None,
init_cfg=None):
super().__init__(init_cfg=init_cfg)
self.embed_dims = embed_dims
if stride is None:
stride = kernel_size
kernel_size = to_2tuple(kernel_size)
stride = to_2tuple(stride)
dilation = to_2tuple(dilation)
if isinstance(padding, str):
self.adaptive_padding = AdaptivePadding(
kernel_size=kernel_size,
stride=stride,
dilation=dilation,
padding=padding)
# disable the padding of conv
padding = 0
else:
self.adaptive_padding = None
padding = to_2tuple(padding)
self.projection = build_conv_layer(
dict(type=conv_type),
in_channels=in_channels,
out_channels=embed_dims,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
bias=bias)
if norm_cfg is not None:
self.norm = build_norm_layer(norm_cfg, embed_dims)[1]
else:
self.norm = None
if input_size:
input_size = to_2tuple(input_size)
# `init_out_size` would be used outside to
# calculate the num_patches
# e.g. when `use_abs_pos_embed` outside
self.init_input_size = input_size
if self.adaptive_padding:
pad_h, pad_w = self.adaptive_padding.get_pad_shape(input_size)
input_h, input_w = input_size
input_h = input_h + pad_h
input_w = input_w + pad_w
input_size = (input_h, input_w)
# https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
h_out = (input_size[0] + 2 * padding[0] - dilation[0] *
(kernel_size[0] - 1) - 1) // stride[0] + 1
w_out = (input_size[1] + 2 * padding[1] - dilation[1] *
(kernel_size[1] - 1) - 1) // stride[1] + 1
self.init_out_size = (h_out, w_out)
else:
self.init_input_size = None
self.init_out_size = None
def forward(self, x):
"""
Args:
x (Tensor): Has shape (B, C, H, W). In most cases, C is 3.
Returns:
tuple: Contains merged results and its spatial shape.
- x (Tensor): Has shape (B, out_h * out_w, embed_dims)
- out_size (tuple[int]): Spatial shape of x, arranged as
(out_h, out_w).
"""
if self.adaptive_padding:
x = self.adaptive_padding(x)
x = self.projection(x)
out_size = (x.shape[2], x.shape[3])
x = x.flatten(2).transpose(1, 2)
if self.norm is not None:
x = self.norm(x)
return x, out_size
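A shape sketch for `PatchEmbed` with ViT-style settings; the sizes are illustrative assumptions:

```python
import torch
from mmcv.cnn.bricks.transformer import PatchEmbed

# 16x16 patches over a 224x224 image -> 14 * 14 = 196 tokens of dim 768.
embed = PatchEmbed(in_channels=3, embed_dims=768, kernel_size=16, stride=16)
tokens, out_size = embed(torch.rand(1, 3, 224, 224))
assert tokens.shape == (1, 196, 768) and out_size == (14, 14)
```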
class PatchMerging(BaseModule):
"""Merge patch feature map.
This layer groups the feature map by kernel_size and applies norm and
linear layers to the grouped feature map (used in Swin Transformer).
Our implementation uses `nn.Unfold` to merge patches, which is about
25% faster than the original implementation. However, pretrained
models need to be modified for compatibility.
Args:
in_channels (int): The num of input channels.
out_channels (int): The num of output channels.
kernel_size (int | tuple, optional): the kernel size in the unfold
layer. Defaults to 2.
stride (int | tuple, optional): the stride of the sliding blocks in
the unfold layer. Default: None, in which case it is set to
`kernel_size`.
padding (int | tuple | string): The padding length of the
unfold layer. When it is a string, it specifies the mode
of adaptive padding; "same" and "corner" are supported.
Default: "corner".
dilation (int | tuple, optional): dilation parameter in the unfold
layer. Default: 1.
bias (bool, optional): Whether to add bias in linear layer or not.
Defaults: False.
norm_cfg (dict, optional): Config dict for normalization layer.
Default: dict(type='LN').
init_cfg (dict, optional): The extra config for initialization.
Default: None.
"""
def __init__(self,
in_channels,
out_channels,
kernel_size=2,
stride=None,
padding='corner',
dilation=1,
bias=False,
norm_cfg=dict(type='LN'),
init_cfg=None):
super().__init__(init_cfg=init_cfg)
self.in_channels = in_channels
self.out_channels = out_channels
if stride:
stride = stride
else:
stride = kernel_size
kernel_size = to_2tuple(kernel_size)
stride = to_2tuple(stride)
dilation = to_2tuple(dilation)
if isinstance(padding, str):
self.adaptive_padding = AdaptivePadding(
kernel_size=kernel_size,
stride=stride,
dilation=dilation,
padding=padding)
# disable the padding of unfold
padding = 0
else:
self.adaptive_padding = None
padding = to_2tuple(padding)
self.sampler = nn.Unfold(
kernel_size=kernel_size,
dilation=dilation,
padding=padding,
stride=stride)
sample_dim = kernel_size[0] * kernel_size[1] * in_channels
if norm_cfg is not None:
self.norm = build_norm_layer(norm_cfg, sample_dim)[1]
else:
self.norm = None
self.reduction = nn.Linear(sample_dim, out_channels, bias=bias)
def forward(self, x, input_size):
"""
Args:
x (Tensor): Has shape (B, H*W, C_in).
input_size (tuple[int]): The spatial shape of x, arranged as (H, W).
Default: None.
Returns:
tuple: Contains merged results and its spatial shape.
- x (Tensor): Has shape (B, Merged_H * Merged_W, C_out)
- out_size (tuple[int]): Spatial shape of x, arranged as
(Merged_H, Merged_W).
"""
B, L, C = x.shape
assert isinstance(input_size, Sequence), \
f'Expect input_size is `Sequence` but get {input_size}'
H, W = input_size
assert L == H * W, 'input feature has wrong size'
x = x.view(B, H, W, C).permute([0, 3, 1, 2]) # B, C, H, W
if self.adaptive_padding:
x = self.adaptive_padding(x)
H, W = x.shape[-2:]
# Use nn.Unfold to merge patches. About 25% faster than the original
# method, but the pretrained model needs to be modified for
# compatibility. If kernel_size=2 and stride=2, x has shape
# (B, 4*C, H/2*W/2).
x = self.sampler(x)
out_h = (H + 2 * self.sampler.padding[0] - self.sampler.dilation[0] *
(self.sampler.kernel_size[0] - 1) -
1) // self.sampler.stride[0] + 1
out_w = (W + 2 * self.sampler.padding[1] - self.sampler.dilation[1] *
(self.sampler.kernel_size[1] - 1) -
1) // self.sampler.stride[1] + 1
output_size = (out_h, out_w)
x = x.transpose(1, 2) # B, H/2*W/2, 4*C
x = self.norm(x) if self.norm else x
x = self.reduction(x)
return x, output_size
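A shape sketch of the Swin-style 2x2 merge implemented above; channel counts are illustrative:

```python
import torch
from mmcv.cnn.bricks.transformer import PatchMerging

# kernel_size defaults to 2 and stride falls back to kernel_size,
# so H and W halve while the channel dim is projected 96 -> 192.
merge = PatchMerging(in_channels=96, out_channels=192)
x = torch.rand(1, 56 * 56, 96)
out, out_size = merge(x, input_size=(56, 56))
assert out.shape == (1, 28 * 28, 192) and out_size == (28, 28)
```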
@ATTENTION.register_module()
class MultiheadAttention(BaseModule):
"""A wrapper for ``torch.nn.MultiheadAttention``.
@@ -87,12 +435,13 @@ class MultiheadAttention(BaseModule):
init_cfg=None,
batch_first=False,
**kwargs):
super(MultiheadAttention, self).__init__(init_cfg)
super().__init__(init_cfg)
if 'dropout' in kwargs:
warnings.warn('The arguments `dropout` in MultiheadAttention '
'has been deprecated, now you can separately '
'set `attn_drop`(float), proj_drop(float), '
'and `dropout_layer`(dict) ')
warnings.warn(
'The arguments `dropout` in MultiheadAttention '
'has been deprecated, now you can separately '
'set `attn_drop`(float), proj_drop(float), '
'and `dropout_layer`(dict) ', DeprecationWarning)
attn_drop = kwargs['dropout']
dropout_layer['drop_prob'] = kwargs.pop('dropout')
@@ -154,9 +503,9 @@ class MultiheadAttention(BaseModule):
Returns:
Tensor: forwarded results with shape
[num_queries, bs, embed_dims]
if self.batch_first is False, else
[bs, num_queries, embed_dims].
"""
if key is None:
@@ -241,7 +590,7 @@ class FFN(BaseModule):
add_identity=True,
init_cfg=None,
**kwargs):
super(FFN, self).__init__(init_cfg)
super().__init__(init_cfg)
assert num_fcs >= 2, 'num_fcs should be no less ' \
f'than 2. got {num_fcs}.'
self.embed_dims = embed_dims
@@ -342,15 +691,15 @@ class BaseTransformerLayer(BaseModule):
f'The arguments `{ori_name}` in BaseTransformerLayer '
f'has been deprecated, now you should set `{new_name}` '
f'and other FFN related arguments '
f'to a dict named `ffn_cfgs`. ')
f'to a dict named `ffn_cfgs`. ', DeprecationWarning)
ffn_cfgs[new_name] = kwargs[ori_name]
super(BaseTransformerLayer, self).__init__(init_cfg)
super().__init__(init_cfg)
self.batch_first = batch_first
assert set(operation_order) & set(
['self_attn', 'norm', 'ffn', 'cross_attn']) == \
assert set(operation_order) & {
'self_attn', 'norm', 'ffn', 'cross_attn'} == \
set(operation_order), f'The operation_order of' \
f' {self.__class__.__name__} should ' \
f'contains all four operation type ' \
@@ -397,7 +746,7 @@ class BaseTransformerLayer(BaseModule):
assert len(ffn_cfgs) == num_ffns
for ffn_index in range(num_ffns):
if 'embed_dims' not in ffn_cfgs[ffn_index]:
ffn_cfgs['embed_dims'] = self.embed_dims
ffn_cfgs[ffn_index]['embed_dims'] = self.embed_dims
else:
assert ffn_cfgs[ffn_index]['embed_dims'] == self.embed_dims
self.ffns.append(
@@ -531,7 +880,7 @@ class TransformerLayerSequence(BaseModule):
"""
def __init__(self, transformerlayers=None, num_layers=None, init_cfg=None):
super(TransformerLayerSequence, self).__init__(init_cfg)
super().__init__(init_cfg)
if isinstance(transformerlayers, dict):
transformerlayers = [
copy.deepcopy(transformerlayers) for _ in range(num_layers)
......
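A minimal encoder-layer sketch wiring the bricks above together; all config values are illustrative assumptions:

```python
import torch
from mmcv.cnn.bricks.transformer import BaseTransformerLayer

layer = BaseTransformerLayer(
    attn_cfgs=dict(type='MultiheadAttention', embed_dims=256, num_heads=8),
    ffn_cfgs=dict(type='FFN', embed_dims=256, feedforward_channels=1024),
    operation_order=('self_attn', 'norm', 'ffn', 'norm'))

x = torch.rand(100, 2, 256)  # (num_queries, bs, embed_dims), batch_first=False
out = layer(x)
assert out.shape == (100, 2, 256)
```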
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict
import torch
import torch.nn as nn
import torch.nn.functional as F
@@ -24,9 +27,9 @@ class PixelShufflePack(nn.Module):
channels.
"""
def __init__(self, in_channels, out_channels, scale_factor,
upsample_kernel):
super(PixelShufflePack, self).__init__()
def __init__(self, in_channels: int, out_channels: int, scale_factor: int,
upsample_kernel: int):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.scale_factor = scale_factor
@@ -41,13 +44,13 @@ class PixelShufflePack(nn.Module):
def init_weights(self):
xavier_init(self.upsample_conv, distribution='uniform')
def forward(self, x):
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.upsample_conv(x)
x = F.pixel_shuffle(x, self.scale_factor)
return x
def build_upsample_layer(cfg, *args, **kwargs):
def build_upsample_layer(cfg: Dict, *args, **kwargs) -> nn.Module:
"""Build upsample layer.
Args:
@@ -55,7 +58,7 @@ def build_upsample_layer(cfg, *args, **kwargs):
- type (str): Layer type.
- scale_factor (int): Upsample ratio, which is not applicable to
deconv.
- layer args: Args needed to instantiate an upsample layer.
args (argument list): Arguments passed to the ``__init__``
method of the corresponding conv layer.
......
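A sketch of `build_upsample_layer` with the `PixelShufflePack` entry (registered under 'pixel_shuffle'); the channel counts are illustrative:

```python
import torch
from mmcv.cnn import build_upsample_layer

# PixelShufflePack convs to out_channels * scale_factor**2 channels,
# then rearranges them spatially with F.pixel_shuffle.
up = build_upsample_layer(
    dict(type='pixel_shuffle', in_channels=16, out_channels=16,
         scale_factor=2, upsample_kernel=3))
out = up(torch.rand(1, 16, 8, 8))
assert out.shape == (1, 16, 16, 16)
```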
@@ -21,19 +21,19 @@ else:
TORCH_VERSION = tuple(int(x) for x in torch.__version__.split('.')[:2])
def obsolete_torch_version(torch_version, version_threshold):
def obsolete_torch_version(torch_version, version_threshold) -> bool:
return torch_version == 'parrots' or torch_version <= version_threshold
class NewEmptyTensorOp(torch.autograd.Function):
@staticmethod
def forward(ctx, x, new_shape):
def forward(ctx, x: torch.Tensor, new_shape: tuple) -> torch.Tensor:
ctx.shape = x.shape
return x.new_empty(new_shape)
@staticmethod
def backward(ctx, grad):
def backward(ctx, grad: torch.Tensor) -> tuple:
shape = ctx.shape
return NewEmptyTensorOp.apply(grad, shape), None
@@ -41,7 +41,7 @@ class NewEmptyTensorOp(torch.autograd.Function):
@CONV_LAYERS.register_module('Conv', force=True)
class Conv2d(nn.Conv2d):
def forward(self, x):
def forward(self, x: torch.Tensor) -> torch.Tensor:
if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 4)):
out_shape = [x.shape[0], self.out_channels]
for i, k, p, s, d in zip(x.shape[-2:], self.kernel_size,
@@ -62,7 +62,7 @@ class Conv2d(nn.Conv2d):
@CONV_LAYERS.register_module('Conv3d', force=True)
class Conv3d(nn.Conv3d):
def forward(self, x):
def forward(self, x: torch.Tensor) -> torch.Tensor:
if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 4)):
out_shape = [x.shape[0], self.out_channels]
for i, k, p, s, d in zip(x.shape[-3:], self.kernel_size,
@@ -85,7 +85,7 @@ class Conv3d(nn.Conv3d):
@UPSAMPLE_LAYERS.register_module('deconv', force=True)
class ConvTranspose2d(nn.ConvTranspose2d):
def forward(self, x):
def forward(self, x: torch.Tensor) -> torch.Tensor:
if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 4)):
out_shape = [x.shape[0], self.out_channels]
for i, k, p, s, d, op in zip(x.shape[-2:], self.kernel_size,
@@ -108,7 +108,7 @@ class ConvTranspose2d(nn.ConvTranspose2d):
@UPSAMPLE_LAYERS.register_module('deconv3d', force=True)
class ConvTranspose3d(nn.ConvTranspose3d):
def forward(self, x):
def forward(self, x: torch.Tensor) -> torch.Tensor:
if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 4)):
out_shape = [x.shape[0], self.out_channels]
for i, k, p, s, d, op in zip(x.shape[-3:], self.kernel_size,
@@ -128,7 +128,7 @@ class ConvTranspose3d(nn.ConvTranspose3d):
class MaxPool2d(nn.MaxPool2d):
def forward(self, x):
def forward(self, x: torch.Tensor) -> torch.Tensor:
# PyTorch 1.9 does not support empty tensor inference yet
if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 9)):
out_shape = list(x.shape[:2])
@@ -146,7 +146,7 @@ class MaxPool2d(nn.MaxPool2d):
class MaxPool3d(nn.MaxPool3d):
def forward(self, x):
def forward(self, x: torch.Tensor) -> torch.Tensor:
# PyTorch 1.9 does not support empty tensor inference yet
if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 9)):
out_shape = list(x.shape[:2])
@@ -165,7 +165,7 @@ class MaxPool3d(nn.MaxPool3d):
class Linear(torch.nn.Linear):
def forward(self, x):
def forward(self, x: torch.Tensor) -> torch.Tensor:
# empty tensor forward of Linear layer is supported in PyTorch 1.6
if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 5)):
out_shape = [x.shape[0], self.out_features]
......
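A sketch of why these wrappers exist: zero-sized batches should flow through with analytically computed shapes on old PyTorch; layer sizes here are illustrative:

```python
import torch
from mmcv.cnn import Conv2d  # mmcv's empty-tensor-safe wrapper

conv = Conv2d(3, 8, kernel_size=3, padding=1)
empty = torch.rand(0, 3, 16, 16)  # zero-sized batch
out = conv(empty)
# On obsolete torch versions the wrapper skips the kernel and builds an
# empty tensor of the correct shape; newer torch handles this natively.
assert out.shape == (0, 8, 16, 16)
```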
# Copyright (c) OpenMMLab. All rights reserved.
import logging
from typing import Optional, Sequence, Tuple, Union
import torch.nn as nn
import torch.utils.checkpoint as cp
from torch import Tensor
from .utils import constant_init, kaiming_init
def conv3x3(in_planes, out_planes, stride=1, dilation=1):
def conv3x3(in_planes: int,
out_planes: int,
stride: int = 1,
dilation: int = 1):
"""3x3 convolution with padding."""
return nn.Conv2d(
in_planes,
@@ -23,14 +28,14 @@ class BasicBlock(nn.Module):
expansion = 1
def __init__(self,
inplanes,
planes,
stride=1,
dilation=1,
downsample=None,
style='pytorch',
with_cp=False):
super(BasicBlock, self).__init__()
inplanes: int,
planes: int,
stride: int = 1,
dilation: int = 1,
downsample: Optional[nn.Module] = None,
style: str = 'pytorch',
with_cp: bool = False):
super().__init__()
assert style in ['pytorch', 'caffe']
self.conv1 = conv3x3(inplanes, planes, stride, dilation)
self.bn1 = nn.BatchNorm2d(planes)
@@ -42,7 +47,7 @@ class BasicBlock(nn.Module):
self.dilation = dilation
assert not with_cp
def forward(self, x):
def forward(self, x: Tensor) -> Tensor:
residual = x
out = self.conv1(x)
@@ -65,19 +70,19 @@ class Bottleneck(nn.Module):
expansion = 4
def __init__(self,
inplanes,
planes,
stride=1,
dilation=1,
downsample=None,
style='pytorch',
with_cp=False):
inplanes: int,
planes: int,
stride: int = 1,
dilation: int = 1,
downsample: Optional[nn.Module] = None,
style: str = 'pytorch',
with_cp: bool = False):
"""Bottleneck block.
If style is "pytorch", the stride-two layer is the 3x3 conv layer, if
it is "caffe", the stride-two layer is the first 1x1 conv layer.
"""
super(Bottleneck, self).__init__()
super().__init__()
assert style in ['pytorch', 'caffe']
if style == 'pytorch':
conv1_stride = 1
@@ -107,7 +112,7 @@ class Bottleneck(nn.Module):
self.dilation = dilation
self.with_cp = with_cp
def forward(self, x):
def forward(self, x: Tensor) -> Tensor:
def _inner_forward(x):
residual = x
@@ -140,14 +145,14 @@ class Bottleneck(nn.Module):
return out
def make_res_layer(block,
inplanes,
planes,
blocks,
stride=1,
dilation=1,
style='pytorch',
with_cp=False):
def make_res_layer(block: nn.Module,
inplanes: int,
planes: int,
blocks: int,
stride: int = 1,
dilation: int = 1,
style: str = 'pytorch',
with_cp: bool = False) -> nn.Module:
downsample = None
if stride != 1 or inplanes != planes * block.expansion:
downsample = nn.Sequential(
@@ -208,22 +213,22 @@ class ResNet(nn.Module):
}
def __init__(self,
depth,
num_stages=4,
strides=(1, 2, 2, 2),
dilations=(1, 1, 1, 1),
out_indices=(0, 1, 2, 3),
style='pytorch',
frozen_stages=-1,
bn_eval=True,
bn_frozen=False,
with_cp=False):
super(ResNet, self).__init__()
depth: int,
num_stages: int = 4,
strides: Sequence[int] = (1, 2, 2, 2),
dilations: Sequence[int] = (1, 1, 1, 1),
out_indices: Sequence[int] = (0, 1, 2, 3),
style: str = 'pytorch',
frozen_stages: int = -1,
bn_eval: bool = True,
bn_frozen: bool = False,
with_cp: bool = False):
super().__init__()
if depth not in self.arch_settings:
raise KeyError(f'invalid depth {depth} for resnet')
assert num_stages >= 1 and num_stages <= 4
block, stage_blocks = self.arch_settings[depth]
stage_blocks = stage_blocks[:num_stages]
stage_blocks = stage_blocks[:num_stages] # type: ignore
assert len(strides) == len(dilations) == num_stages
assert max(out_indices) < num_stages
@@ -234,7 +239,7 @@ class ResNet(nn.Module):
self.bn_frozen = bn_frozen
self.with_cp = with_cp
self.inplanes = 64
self.inplanes: int = 64
self.conv1 = nn.Conv2d(
3, 64, kernel_size=7, stride=2, padding=3, bias=False)
self.bn1 = nn.BatchNorm2d(64)
@@ -255,14 +260,15 @@ class ResNet(nn.Module):
dilation=dilation,
style=self.style,
with_cp=with_cp)
self.inplanes = planes * block.expansion
self.inplanes = planes * block.expansion # type: ignore
layer_name = f'layer{i + 1}'
self.add_module(layer_name, res_layer)
self.res_layers.append(layer_name)
self.feat_dim = block.expansion * 64 * 2**(len(stage_blocks) - 1)
self.feat_dim = block.expansion * 64 * 2**( # type: ignore
len(stage_blocks) - 1)
def init_weights(self, pretrained=None):
def init_weights(self, pretrained: Optional[str] = None) -> None:
if isinstance(pretrained, str):
logger = logging.getLogger()
from ..runner import load_checkpoint
@@ -276,7 +282,7 @@ class ResNet(nn.Module):
else:
raise TypeError('pretrained must be a str or None')
def forward(self, x):
def forward(self, x: Tensor) -> Union[Tensor, Tuple[Tensor]]:
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
@@ -292,8 +298,8 @@ class ResNet(nn.Module):
else:
return tuple(outs)
def train(self, mode=True):
super(ResNet, self).train(mode)
def train(self, mode: bool = True) -> None:
super().train(mode)
if self.bn_eval:
for m in self.modules():
if isinstance(m, nn.BatchNorm2d):
......
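A usage sketch of the annotated `ResNet` backbone; the depth and input size are illustrative:

```python
import torch
from mmcv.cnn import ResNet

model = ResNet(depth=50, out_indices=(0, 1, 2, 3))
model.init_weights()  # random init; pass a checkpoint path to load weights
feats = model(torch.rand(1, 3, 224, 224))
# Stage strides 4/8/16/32 give 56/28/14/7 spatial sizes for a 224 input.
print([tuple(f.shape) for f in feats])
```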