Commit 6f3c5f1c authored by limm

support v1.4.0

parent 6f674c7e
# Copyright (c) OpenMMLab. All rights reserved.
-from collections import OrderedDict
-from typing import Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
-from mmengine.registry import MODELS
+from .registry import CONV_LAYERS


-def conv_ws_2d(input: torch.Tensor,
-               weight: torch.Tensor,
-               bias: Optional[torch.Tensor] = None,
-               stride: Union[int, Tuple[int, int]] = 1,
-               padding: Union[int, Tuple[int, int]] = 0,
-               dilation: Union[int, Tuple[int, int]] = 1,
-               groups: int = 1,
-               eps: float = 1e-5) -> torch.Tensor:
+def conv_ws_2d(input,
+               weight,
+               bias=None,
+               stride=1,
+               padding=0,
+               dilation=1,
+               groups=1,
+               eps=1e-5):
    c_in = weight.size(0)
    weight_flat = weight.view(c_in, -1)
    mean = weight_flat.mean(dim=1, keepdim=True).view(c_in, 1, 1, 1)
@@ -24,20 +22,20 @@ def conv_ws_2d(input: torch.Tensor,
    return F.conv2d(input, weight, bias, stride, padding, dilation, groups)
-@MODELS.register_module('ConvWS')
+@CONV_LAYERS.register_module('ConvWS')
class ConvWS2d(nn.Conv2d):

    def __init__(self,
-                 in_channels: int,
-                 out_channels: int,
-                 kernel_size: Union[int, Tuple[int, int]],
-                 stride: Union[int, Tuple[int, int]] = 1,
-                 padding: Union[int, Tuple[int, int]] = 0,
-                 dilation: Union[int, Tuple[int, int]] = 1,
-                 groups: int = 1,
-                 bias: bool = True,
-                 eps: float = 1e-5):
-        super().__init__(
+                 in_channels,
+                 out_channels,
+                 kernel_size,
+                 stride=1,
+                 padding=0,
+                 dilation=1,
+                 groups=1,
+                 bias=True,
+                 eps=1e-5):
+        super(ConvWS2d, self).__init__(
            in_channels,
            out_channels,
            kernel_size,
@@ -48,12 +46,12 @@ class ConvWS2d(nn.Conv2d):
            bias=bias)
        self.eps = eps

-    def forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward(self, x):
        return conv_ws_2d(x, self.weight, self.bias, self.stride, self.padding,
                          self.dilation, self.groups, self.eps)
-@MODELS.register_module(name='ConvAWS')
+@CONV_LAYERS.register_module(name='ConvAWS')
class ConvAWS2d(nn.Conv2d):
    """AWS (Adaptive Weight Standardization)
@@ -78,14 +76,14 @@ class ConvAWS2d(nn.Conv2d):
    """

    def __init__(self,
-                 in_channels: int,
-                 out_channels: int,
-                 kernel_size: Union[int, Tuple[int, int]],
-                 stride: Union[int, Tuple[int, int]] = 1,
-                 padding: Union[int, Tuple[int, int]] = 0,
-                 dilation: Union[int, Tuple[int, int]] = 1,
-                 groups: int = 1,
-                 bias: bool = True):
+                 in_channels,
+                 out_channels,
+                 kernel_size,
+                 stride=1,
+                 padding=0,
+                 dilation=1,
+                 groups=1,
+                 bias=True):
        super().__init__(
            in_channels,
            out_channels,
@@ -100,7 +98,7 @@ class ConvAWS2d(nn.Conv2d):
        self.register_buffer('weight_beta',
                             torch.zeros(self.out_channels, 1, 1, 1))

-    def _get_weight(self, weight: torch.Tensor) -> torch.Tensor:
+    def _get_weight(self, weight):
        weight_flat = weight.view(weight.size(0), -1)
        mean = weight_flat.mean(dim=1).view(-1, 1, 1, 1)
        std = torch.sqrt(weight_flat.var(dim=1) + 1e-5).view(-1, 1, 1, 1)
@@ -108,16 +106,13 @@ class ConvAWS2d(nn.Conv2d):
        weight = self.weight_gamma * weight + self.weight_beta
        return weight

-    def forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward(self, x):
        weight = self._get_weight(self.weight)
        return F.conv2d(x, weight, self.bias, self.stride, self.padding,
                        self.dilation, self.groups)

-    def _load_from_state_dict(self, state_dict: OrderedDict, prefix: str,
-                              local_metadata: Dict, strict: bool,
-                              missing_keys: List[str],
-                              unexpected_keys: List[str],
-                              error_msgs: List[str]) -> None:
+    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
+                              missing_keys, unexpected_keys, error_msgs):
        """Override default load function.

        AWS overrides the function _load_from_state_dict to recover
@@ -129,7 +124,7 @@ class ConvAWS2d(nn.Conv2d):
        """
        self.weight_gamma.data.fill_(-1)
-        local_missing_keys: List = []
+        local_missing_keys = []
        super()._load_from_state_dict(state_dict, prefix, local_metadata,
                                      strict, local_missing_keys,
                                      unexpected_keys, error_msgs)
......
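For reference, the weight standardization that conv_ws_2d applies can be reproduced in a few lines of plain PyTorch. This is an illustrative sketch with our own names, not the library code itself:

import torch
import torch.nn.functional as F

def standardize_weight(weight, eps=1e-5):
    # Normalize each output filter of a (C_out, C_in, kH, kW) kernel
    # to zero mean and unit std before the ordinary convolution.
    c_out = weight.size(0)
    flat = weight.view(c_out, -1)
    mean = flat.mean(dim=1).view(c_out, 1, 1, 1)
    std = flat.std(dim=1).view(c_out, 1, 1, 1)
    return (weight - mean) / (std + eps)

x = torch.randn(1, 3, 32, 32)
w = torch.randn(8, 3, 3, 3)
out = F.conv2d(x, standardize_weight(w), padding=1)  # shape (1, 8, 32, 32)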
# Copyright (c) OpenMMLab. All rights reserved.
-from typing import Dict, Optional, Tuple, Union
-
-import torch
import torch.nn as nn

from .conv_module import ConvModule
@@ -49,27 +46,27 @@ class DepthwiseSeparableConvModule(nn.Module):
    """

    def __init__(self,
-                 in_channels: int,
-                 out_channels: int,
-                 kernel_size: Union[int, Tuple[int, int]],
-                 stride: Union[int, Tuple[int, int]] = 1,
-                 padding: Union[int, Tuple[int, int]] = 0,
-                 dilation: Union[int, Tuple[int, int]] = 1,
-                 norm_cfg: Optional[Dict] = None,
-                 act_cfg: Dict = dict(type='ReLU'),
-                 dw_norm_cfg: Union[Dict, str] = 'default',
-                 dw_act_cfg: Union[Dict, str] = 'default',
-                 pw_norm_cfg: Union[Dict, str] = 'default',
-                 pw_act_cfg: Union[Dict, str] = 'default',
+                 in_channels,
+                 out_channels,
+                 kernel_size,
+                 stride=1,
+                 padding=0,
+                 dilation=1,
+                 norm_cfg=None,
+                 act_cfg=dict(type='ReLU'),
+                 dw_norm_cfg='default',
+                 dw_act_cfg='default',
+                 pw_norm_cfg='default',
+                 pw_act_cfg='default',
                 **kwargs):
-        super().__init__()
+        super(DepthwiseSeparableConvModule, self).__init__()
        assert 'groups' not in kwargs, 'groups should not be specified'

        # if norm/activation config of depthwise/pointwise ConvModule is not
        # specified, use default config.
-        dw_norm_cfg = dw_norm_cfg if dw_norm_cfg != 'default' else norm_cfg  # type: ignore # noqa E501
+        dw_norm_cfg = dw_norm_cfg if dw_norm_cfg != 'default' else norm_cfg
        dw_act_cfg = dw_act_cfg if dw_act_cfg != 'default' else act_cfg
-        pw_norm_cfg = pw_norm_cfg if pw_norm_cfg != 'default' else norm_cfg  # type: ignore # noqa E501
+        pw_norm_cfg = pw_norm_cfg if pw_norm_cfg != 'default' else norm_cfg
        pw_act_cfg = pw_act_cfg if pw_act_cfg != 'default' else act_cfg

        # depthwise convolution
@@ -81,19 +78,19 @@ class DepthwiseSeparableConvModule(nn.Module):
            padding=padding,
            dilation=dilation,
            groups=in_channels,
-            norm_cfg=dw_norm_cfg,  # type: ignore
-            act_cfg=dw_act_cfg,  # type: ignore
+            norm_cfg=dw_norm_cfg,
+            act_cfg=dw_act_cfg,
            **kwargs)

        self.pointwise_conv = ConvModule(
            in_channels,
            out_channels,
            1,
-            norm_cfg=pw_norm_cfg,  # type: ignore
-            act_cfg=pw_act_cfg,  # type: ignore
+            norm_cfg=pw_norm_cfg,
+            act_cfg=pw_act_cfg,
            **kwargs)

-    def forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward(self, x):
        x = self.depthwise_conv(x)
        x = self.pointwise_conv(x)
        return x
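As an aside, the depthwise/pointwise factorization built above is easy to sanity-check with plain torch modules. A rough parameter-count comparison, using illustrative values:

import torch.nn as nn

in_ch, out_ch, k = 64, 128, 3
depthwise = nn.Conv2d(in_ch, in_ch, k, padding=1, groups=in_ch)  # one k x k filter per channel
pointwise = nn.Conv2d(in_ch, out_ch, 1)                          # 1 x 1 channel mixing
full = nn.Conv2d(in_ch, out_ch, k, padding=1)

count = lambda m: sum(p.numel() for p in m.parameters())
print(count(depthwise) + count(pointwise), count(full))  # 8960 vs 73856, roughly 8x fewer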
# Copyright (c) OpenMMLab. All rights reserved.
-from typing import Any, Dict, Optional
-
import torch
import torch.nn as nn
-from mmengine.registry import MODELS
+from mmcv import build_from_cfg
+from .registry import DROPOUT_LAYERS


-def drop_path(x: torch.Tensor,
-              drop_prob: float = 0.,
-              training: bool = False) -> torch.Tensor:
+def drop_path(x, drop_prob=0., training=False):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of
    residual blocks).
@@ -26,7 +24,7 @@ def drop_path(x: torch.Tensor,
    return output

-@MODELS.register_module()
+@DROPOUT_LAYERS.register_module()
class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of
    residual blocks).
@@ -38,15 +36,15 @@ class DropPath(nn.Module):
        drop_prob (float): Probability of the path to be zeroed. Default: 0.1
    """

-    def __init__(self, drop_prob: float = 0.1):
-        super().__init__()
+    def __init__(self, drop_prob=0.1):
+        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

-    def forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)

-@MODELS.register_module()
+@DROPOUT_LAYERS.register_module()
class Dropout(nn.Dropout):
    """A wrapper for ``torch.nn.Dropout``, We rename the ``p`` of
    ``torch.nn.Dropout`` to ``drop_prob`` so as to be consistent with
@@ -58,10 +56,10 @@ class Dropout(nn.Dropout):
        inplace (bool): Do the operation inplace or not. Default: False.
    """

-    def __init__(self, drop_prob: float = 0.5, inplace: bool = False):
+    def __init__(self, drop_prob=0.5, inplace=False):
        super().__init__(p=drop_prob, inplace=inplace)

-def build_dropout(cfg: Dict, default_args: Optional[Dict] = None) -> Any:
+def build_dropout(cfg, default_args=None):
    """Builder for drop out layers."""
-    return MODELS.build(cfg, default_args=default_args)
+    return build_from_cfg(cfg, DROPOUT_LAYERS, default_args)
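For orientation, drop_path implements stochastic depth: during training each sample's residual branch is zeroed with probability drop_prob and survivors are rescaled by 1 / keep_prob. A minimal equivalent sketch (ours, not the library code):

import torch

def drop_path_sketch(x, drop_prob=0.2, training=True):
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    # one Bernoulli draw per sample, broadcast across all remaining dims
    shape = (x.shape[0], ) + (1, ) * (x.ndim - 1)
    mask = x.new_empty(shape).bernoulli_(keep_prob)
    return x * mask / keep_prob  # rescaling keeps the expected value unchanged

out = drop_path_sketch(torch.randn(4, 8, 16, 16))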
@@ -5,16 +5,17 @@ import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
-from mmengine.model import kaiming_init
-from mmengine.registry import MODELS
+
+from ..utils import kaiming_init
+from .registry import PLUGIN_LAYERS


-@MODELS.register_module()
+@PLUGIN_LAYERS.register_module()
class GeneralizedAttention(nn.Module):
    """GeneralizedAttention module.

    See 'An Empirical Study of Spatial Attention Mechanisms in Deep Networks'
-    (https://arxiv.org/abs/1904.05873) for details.
+    (https://arxiv.org/abs/1711.07971) for details.

    Args:
        in_channels (int): Channels of the input feature map.
@@ -44,16 +45,16 @@ class GeneralizedAttention(nn.Module):
    _abbr_ = 'gen_attention_block'

    def __init__(self,
-                 in_channels: int,
-                 spatial_range: int = -1,
-                 num_heads: int = 9,
-                 position_embedding_dim: int = -1,
-                 position_magnitude: int = 1,
-                 kv_stride: int = 2,
-                 q_stride: int = 1,
-                 attention_type: str = '1111'):
-        super().__init__()
+                 in_channels,
+                 spatial_range=-1,
+                 num_heads=9,
+                 position_embedding_dim=-1,
+                 position_magnitude=1,
+                 kv_stride=2,
+                 q_stride=1,
+                 attention_type='1111'):
+        super(GeneralizedAttention, self).__init__()

        # hard range means local range for non-local operation
        self.position_embedding_dim = (
@@ -130,7 +131,7 @@ class GeneralizedAttention(nn.Module):
            max_len_kv = int((max_len - 1.0) / self.kv_stride + 1)
            local_constraint_map = np.ones(
-                (max_len, max_len, max_len_kv, max_len_kv), dtype=int)
+                (max_len, max_len, max_len_kv, max_len_kv), dtype=np.int)
            for iy in range(max_len):
                for ix in range(max_len):
                    local_constraint_map[
@@ -212,7 +213,7 @@ class GeneralizedAttention(nn.Module):
        return embedding_x, embedding_y

-    def forward(self, x_input: torch.Tensor) -> torch.Tensor:
+    def forward(self, x_input):
        num_heads = self.num_heads

        # use empirical_attention
@@ -350,7 +351,7 @@ class GeneralizedAttention(nn.Module):
                repeat(n, 1, 1, 1)
            position_feat_x_reshape = position_feat_x.\
-                view(n, num_heads, w * w_kv, self.qk_embed_dim)
+                view(n, num_heads, w*w_kv, self.qk_embed_dim)
            position_feat_y_reshape = position_feat_y.\
                view(n, num_heads, h * h_kv, self.qk_embed_dim)
......
# Copyright (c) OpenMMLab. All rights reserved.
-import warnings
-
-import torch
import torch.nn as nn
-from mmengine.registry import MODELS
+from .registry import ACTIVATION_LAYERS


-@MODELS.register_module()
+@ACTIVATION_LAYERS.register_module()
class HSigmoid(nn.Module):
    """Hard Sigmoid Module. Apply the hard sigmoid function:
    Hsigmoid(x) = min(max((x + bias) / divisor, min_value), max_value)
-    Default: Hsigmoid(x) = min(max((x + 3) / 6, 0), 1)
-
-    Note:
-        In MMCV v1.4.4, we modified the default value of args to align with
-        PyTorch official.
+    Default: Hsigmoid(x) = min(max((x + 1) / 2, 0), 1)

    Args:
-        bias (float): Bias of the input feature map. Default: 3.0.
-        divisor (float): Divisor of the input feature map. Default: 6.0.
+        bias (float): Bias of the input feature map. Default: 1.0.
+        divisor (float): Divisor of the input feature map. Default: 2.0.
        min_value (float): Lower bound value. Default: 0.0.
        max_value (float): Upper bound value. Default: 1.0.
@@ -26,25 +20,15 @@ class HSigmoid(nn.Module):
        Tensor: The output tensor.
    """

-    def __init__(self,
-                 bias: float = 3.0,
-                 divisor: float = 6.0,
-                 min_value: float = 0.0,
-                 max_value: float = 1.0):
-        super().__init__()
-        warnings.warn(
-            'In MMCV v1.4.4, we modified the default value of args to align '
-            'with PyTorch official. Previous Implementation: '
-            'Hsigmoid(x) = min(max((x + 1) / 2, 0), 1). '
-            'Current Implementation: '
-            'Hsigmoid(x) = min(max((x + 3) / 6, 0), 1).')
+    def __init__(self, bias=1.0, divisor=2.0, min_value=0.0, max_value=1.0):
+        super(HSigmoid, self).__init__()
        self.bias = bias
        self.divisor = divisor
        assert self.divisor != 0
        self.min_value = min_value
        self.max_value = max_value

-    def forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward(self, x):
        x = (x + self.bias) / self.divisor

        return x.clamp_(self.min_value, self.max_value)
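For intuition about the two defaults this hunk toggles between: (x + 1) / 2 is the older MMCV default being restored here, while (x + 3) / 6 matches torch.nn.functional.hardsigmoid. A quick comparison in plain PyTorch:

import torch
import torch.nn.functional as F

x = torch.linspace(-4, 4, 9)
hsig_old = ((x + 1) / 2).clamp(0, 1)  # default restored by this commit
hsig_new = ((x + 3) / 6).clamp(0, 1)  # PyTorch-aligned default
print(torch.allclose(hsig_new, F.hardsigmoid(x)))  # True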
# Copyright (c) OpenMMLab. All rights reserved.
-import torch
import torch.nn as nn
-from mmengine.registry import MODELS
-from mmengine.utils import digit_version
-from mmengine.utils.dl_utils import TORCH_VERSION
+from .registry import ACTIVATION_LAYERS


+@ACTIVATION_LAYERS.register_module()
class HSwish(nn.Module):
    """Hard Swish Module.
@@ -22,18 +21,9 @@ class HSwish(nn.Module):
        Tensor: The output tensor.
    """

-    def __init__(self, inplace: bool = False):
-        super().__init__()
+    def __init__(self, inplace=False):
+        super(HSwish, self).__init__()
        self.act = nn.ReLU6(inplace)

-    def forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward(self, x):
        return x * self.act(x + 3) / 6
-
-
-if (TORCH_VERSION == 'parrots'
-        or digit_version(TORCH_VERSION) < digit_version('1.7')):
-    # Hardswish is not supported when PyTorch version < 1.6.
-    # And Hardswish in PyTorch 1.6 does not support inplace.
-    MODELS.register_module(module=HSwish)
-else:
-    MODELS.register_module(module=nn.Hardswish, name='HSwish')
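As a quick correspondence check (plain PyTorch only), the formula x * ReLU6(x + 3) / 6 used by this module matches torch.nn.Hardswish, which is why the removed tail registered one or the other depending on the PyTorch version:

import torch
import torch.nn as nn

x = torch.randn(10)
manual = x * nn.ReLU6()(x + 3) / 6
print(torch.allclose(manual, nn.Hardswish()(x)))  # True on PyTorch >= 1.6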
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABCMeta
-from typing import Dict, Optional

import torch
import torch.nn as nn
-from mmengine.model import constant_init, normal_init
-from mmengine.registry import MODELS
+from ..utils import constant_init, normal_init
from .conv_module import ConvModule
+from .registry import PLUGIN_LAYERS


class _NonLocalNd(nn.Module, metaclass=ABCMeta):
@@ -34,14 +33,14 @@ class _NonLocalNd(nn.Module, metaclass=ABCMeta):
    """

    def __init__(self,
-                 in_channels: int,
-                 reduction: int = 2,
-                 use_scale: bool = True,
-                 conv_cfg: Optional[Dict] = None,
-                 norm_cfg: Optional[Dict] = None,
-                 mode: str = 'embedded_gaussian',
+                 in_channels,
+                 reduction=2,
+                 use_scale=True,
+                 conv_cfg=None,
+                 norm_cfg=None,
+                 mode='embedded_gaussian',
                 **kwargs):
-        super().__init__()
+        super(_NonLocalNd, self).__init__()
        self.in_channels = in_channels
        self.reduction = reduction
        self.use_scale = use_scale
@@ -62,7 +61,7 @@ class _NonLocalNd(nn.Module, metaclass=ABCMeta):
            self.inter_channels,
            kernel_size=1,
            conv_cfg=conv_cfg,
-            act_cfg=None)  # type: ignore
+            act_cfg=None)
        self.conv_out = ConvModule(
            self.inter_channels,
            self.in_channels,
@@ -97,7 +96,7 @@ class _NonLocalNd(nn.Module, metaclass=ABCMeta):
        self.init_weights(**kwargs)

-    def init_weights(self, std: float = 0.01, zeros_init: bool = True) -> None:
+    def init_weights(self, std=0.01, zeros_init=True):
        if self.mode != 'gaussian':
            for m in [self.g, self.theta, self.phi]:
                normal_init(m.conv, std=std)
@@ -114,8 +113,7 @@ class _NonLocalNd(nn.Module, metaclass=ABCMeta):
        else:
            normal_init(self.conv_out.norm, std=std)

-    def gaussian(self, theta_x: torch.Tensor,
-                 phi_x: torch.Tensor) -> torch.Tensor:
+    def gaussian(self, theta_x, phi_x):
        # NonLocal1d pairwise_weight: [N, H, H]
        # NonLocal2d pairwise_weight: [N, HxW, HxW]
        # NonLocal3d pairwise_weight: [N, TxHxW, TxHxW]
@@ -123,8 +121,7 @@ class _NonLocalNd(nn.Module, metaclass=ABCMeta):
        pairwise_weight = pairwise_weight.softmax(dim=-1)
        return pairwise_weight

-    def embedded_gaussian(self, theta_x: torch.Tensor,
-                          phi_x: torch.Tensor) -> torch.Tensor:
+    def embedded_gaussian(self, theta_x, phi_x):
        # NonLocal1d pairwise_weight: [N, H, H]
        # NonLocal2d pairwise_weight: [N, HxW, HxW]
        # NonLocal3d pairwise_weight: [N, TxHxW, TxHxW]
@@ -135,8 +132,7 @@ class _NonLocalNd(nn.Module, metaclass=ABCMeta):
        pairwise_weight = pairwise_weight.softmax(dim=-1)
        return pairwise_weight

-    def dot_product(self, theta_x: torch.Tensor,
-                    phi_x: torch.Tensor) -> torch.Tensor:
+    def dot_product(self, theta_x, phi_x):
        # NonLocal1d pairwise_weight: [N, H, H]
        # NonLocal2d pairwise_weight: [N, HxW, HxW]
        # NonLocal3d pairwise_weight: [N, TxHxW, TxHxW]
@@ -144,8 +140,7 @@ class _NonLocalNd(nn.Module, metaclass=ABCMeta):
        pairwise_weight /= pairwise_weight.shape[-1]
        return pairwise_weight

-    def concatenation(self, theta_x: torch.Tensor,
-                      phi_x: torch.Tensor) -> torch.Tensor:
+    def concatenation(self, theta_x, phi_x):
        # NonLocal1d pairwise_weight: [N, H, H]
        # NonLocal2d pairwise_weight: [N, HxW, HxW]
        # NonLocal3d pairwise_weight: [N, TxHxW, TxHxW]
@@ -162,7 +157,7 @@ class _NonLocalNd(nn.Module, metaclass=ABCMeta):
        return pairwise_weight

-    def forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward(self, x):
        # Assume `reduction = 1`, then `inter_channels = C`
        # or `inter_channels = C` when `mode="gaussian"`
@@ -229,11 +224,12 @@ class NonLocal1d(_NonLocalNd):
    """

    def __init__(self,
-                 in_channels: int,
-                 sub_sample: bool = False,
-                 conv_cfg: Dict = dict(type='Conv1d'),
+                 in_channels,
+                 sub_sample=False,
+                 conv_cfg=dict(type='Conv1d'),
                 **kwargs):
-        super().__init__(in_channels, conv_cfg=conv_cfg, **kwargs)
+        super(NonLocal1d, self).__init__(
+            in_channels, conv_cfg=conv_cfg, **kwargs)

        self.sub_sample = sub_sample
@@ -246,7 +242,7 @@ class NonLocal1d(_NonLocalNd):
            self.phi = max_pool_layer

-@MODELS.register_module()
+@PLUGIN_LAYERS.register_module()
class NonLocal2d(_NonLocalNd):
    """2D Non-local module.
@@ -262,11 +258,12 @@ class NonLocal2d(_NonLocalNd):
    _abbr_ = 'nonlocal_block'

    def __init__(self,
-                 in_channels: int,
-                 sub_sample: bool = False,
-                 conv_cfg: Dict = dict(type='Conv2d'),
+                 in_channels,
+                 sub_sample=False,
+                 conv_cfg=dict(type='Conv2d'),
                 **kwargs):
-        super().__init__(in_channels, conv_cfg=conv_cfg, **kwargs)
+        super(NonLocal2d, self).__init__(
+            in_channels, conv_cfg=conv_cfg, **kwargs)

        self.sub_sample = sub_sample
@@ -292,11 +289,12 @@ class NonLocal3d(_NonLocalNd):
    """

    def __init__(self,
-                 in_channels: int,
-                 sub_sample: bool = False,
-                 conv_cfg: Dict = dict(type='Conv3d'),
+                 in_channels,
+                 sub_sample=False,
+                 conv_cfg=dict(type='Conv3d'),
                 **kwargs):
-        super().__init__(in_channels, conv_cfg=conv_cfg, **kwargs)
+        super(NonLocal3d, self).__init__(
+            in_channels, conv_cfg=conv_cfg, **kwargs)

        self.sub_sample = sub_sample
        if sub_sample:
......
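For reference, the embedded-Gaussian mode computed by these modules is a scaled dot-product attention map over flattened spatial positions. An illustrative shape sketch (the sizes are made up):

import torch

n, c, hw = 2, 16, 64               # batch, inter_channels, H*W
theta_x = torch.randn(n, hw, c)    # queries: [N, HxW, C]
phi_x = torch.randn(n, c, hw)      # keys:    [N, C, HxW]

pairwise_weight = torch.matmul(theta_x, phi_x)               # [N, HxW, HxW]
pairwise_weight = pairwise_weight / theta_x.shape[-1]**0.5   # the use_scale branch
pairwise_weight = pairwise_weight.softmax(dim=-1)            # each row sums to 1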
# Copyright (c) OpenMMLab. All rights reserved.
import inspect
-from typing import Dict, Tuple, Union

import torch.nn as nn
-from mmengine.registry import MODELS
-from mmengine.utils import is_tuple_of
-from mmengine.utils.dl_utils.parrots_wrapper import (SyncBatchNorm, _BatchNorm,
-                                                     _InstanceNorm)

-MODELS.register_module('BN', module=nn.BatchNorm2d)
-MODELS.register_module('BN1d', module=nn.BatchNorm1d)
-MODELS.register_module('BN2d', module=nn.BatchNorm2d)
-MODELS.register_module('BN3d', module=nn.BatchNorm3d)
-MODELS.register_module('SyncBN', module=SyncBatchNorm)
-MODELS.register_module('GN', module=nn.GroupNorm)
-MODELS.register_module('LN', module=nn.LayerNorm)
-MODELS.register_module('IN', module=nn.InstanceNorm2d)
-MODELS.register_module('IN1d', module=nn.InstanceNorm1d)
-MODELS.register_module('IN2d', module=nn.InstanceNorm2d)
-MODELS.register_module('IN3d', module=nn.InstanceNorm3d)
+from mmcv.utils import is_tuple_of
+from mmcv.utils.parrots_wrapper import SyncBatchNorm, _BatchNorm, _InstanceNorm
+from .registry import NORM_LAYERS
+
+NORM_LAYERS.register_module('BN', module=nn.BatchNorm2d)
+NORM_LAYERS.register_module('BN1d', module=nn.BatchNorm1d)
+NORM_LAYERS.register_module('BN2d', module=nn.BatchNorm2d)
+NORM_LAYERS.register_module('BN3d', module=nn.BatchNorm3d)
+NORM_LAYERS.register_module('SyncBN', module=SyncBatchNorm)
+NORM_LAYERS.register_module('GN', module=nn.GroupNorm)
+NORM_LAYERS.register_module('LN', module=nn.LayerNorm)
+NORM_LAYERS.register_module('IN', module=nn.InstanceNorm2d)
+NORM_LAYERS.register_module('IN1d', module=nn.InstanceNorm1d)
+NORM_LAYERS.register_module('IN2d', module=nn.InstanceNorm2d)
+NORM_LAYERS.register_module('IN3d', module=nn.InstanceNorm3d)


def infer_abbr(class_type):
@@ -70,9 +69,7 @@ def infer_abbr(class_type):
    return 'norm_layer'

-def build_norm_layer(cfg: Dict,
-                     num_features: int,
-                     postfix: Union[int, str] = '') -> Tuple[str, nn.Module]:
+def build_norm_layer(cfg, num_features, postfix=''):
    """Build normalization layer.

    Args:
@@ -86,9 +83,9 @@ def build_norm_layer(cfg: Dict,
            to create named layer.

    Returns:
-        tuple[str, nn.Module]: The first element is the layer name consisting
-            of abbreviation and postfix, e.g., bn1, gn. The second element is
-            the created norm layer.
+        (str, nn.Module): The first element is the layer name consisting of
+            abbreviation and postfix, e.g., bn1, gn. The second element is the
+            created norm layer.
    """
    if not isinstance(cfg, dict):
        raise TypeError('cfg must be a dict')
@@ -97,15 +94,10 @@
    cfg_ = cfg.copy()

    layer_type = cfg_.pop('type')
+    if layer_type not in NORM_LAYERS:
+        raise KeyError(f'Unrecognized norm type {layer_type}')

-    # Switch registry to the target scope. If `norm_layer` cannot be found
-    # in the registry, fallback to search `norm_layer` in the
-    # mmengine.MODELS.
-    with MODELS.switch_scope_and_registry(None) as registry:
-        norm_layer = registry.get(layer_type)
-    if norm_layer is None:
-        raise KeyError(f'Cannot find {norm_layer} in registry under scope '
-                       f'name {registry.scope}')
+    norm_layer = NORM_LAYERS.get(layer_type)
    abbr = infer_abbr(norm_layer)

    assert isinstance(postfix, (int, str))
@@ -127,8 +119,7 @@ def build_norm_layer(cfg: Dict,
    return name, layer

-def is_norm(layer: nn.Module,
-            exclude: Union[type, tuple, None] = None) -> bool:
+def is_norm(layer, exclude=None):
    """Check if a layer is a normalization layer.

    Args:
......
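For context, build_norm_layer is normally driven by a config dict whose type key selects the registered class and whose remaining keys become constructor kwargs. A small usage sketch:

from mmcv.cnn import build_norm_layer

name, layer = build_norm_layer(dict(type='BN', momentum=0.05), 64, postfix=1)
print(name)   # 'bn1' (abbreviation inferred from the class, plus the postfix)
# layer is an nn.BatchNorm2d(64) built with momentum=0.05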
# Copyright (c) OpenMMLab. All rights reserved.
-from typing import Dict
-
import torch.nn as nn
-from mmengine.registry import MODELS

-MODELS.register_module('zero', module=nn.ZeroPad2d)
-MODELS.register_module('reflect', module=nn.ReflectionPad2d)
-MODELS.register_module('replicate', module=nn.ReplicationPad2d)
+from .registry import PADDING_LAYERS
+
+PADDING_LAYERS.register_module('zero', module=nn.ZeroPad2d)
+PADDING_LAYERS.register_module('reflect', module=nn.ReflectionPad2d)
+PADDING_LAYERS.register_module('replicate', module=nn.ReplicationPad2d)


-def build_padding_layer(cfg: Dict, *args, **kwargs) -> nn.Module:
+def build_padding_layer(cfg, *args, **kwargs):
    """Build padding layer.

    Args:
-        cfg (dict): The padding layer config, which should contain:
+        cfg (None or dict): The padding layer config, which should contain:
            - type (str): Layer type.
            - layer args: Args needed to instantiate a padding layer.
@@ -27,15 +26,11 @@ def build_padding_layer(cfg: Dict, *args, **kwargs) -> nn.Module:
    cfg_ = cfg.copy()
    padding_type = cfg_.pop('type')
+    if padding_type not in PADDING_LAYERS:
+        raise KeyError(f'Unrecognized padding type {padding_type}.')
+    else:
+        padding_layer = PADDING_LAYERS.get(padding_type)

-    # Switch registry to the target scope. If `padding_layer` cannot be found
-    # in the registry, fallback to search `padding_layer` in the
-    # mmengine.MODELS.
-    with MODELS.switch_scope_and_registry(None) as registry:
-        padding_layer = registry.get(padding_type)
-    if padding_layer is None:
-        raise KeyError(f'Cannot find {padding_layer} in registry under scope '
-                       f'name {registry.scope}')
    layer = padding_layer(*args, **kwargs, **cfg_)

    return layer
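Usage follows the same pattern as the other builders; positional args are forwarded to the layer constructor. A sketch, assuming the registrations above:

from mmcv.cnn import build_padding_layer

pad = build_padding_layer(dict(type='reflect'), 2)  # -> nn.ReflectionPad2d(2)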
# Copyright (c) OpenMMLab. All rights reserved.
import inspect
import platform
-from typing import Dict, Tuple, Union

-import torch.nn as nn
-from mmengine.registry import MODELS
+from .registry import PLUGIN_LAYERS

if platform.system() == 'Windows':
-    import regex as re  # type: ignore
+    import regex as re
else:
-    import re  # type: ignore
+    import re

-def infer_abbr(class_type: type) -> str:
+def infer_abbr(class_type):
    """Infer abbreviation from the class name.

    This method will infer the abbreviation to map class types to
@@ -50,27 +47,25 @@ def infer_abbr(class_type: type) -> str:
        raise TypeError(
            f'class_type must be a type, but got {type(class_type)}')
    if hasattr(class_type, '_abbr_'):
-        return class_type._abbr_  # type: ignore
+        return class_type._abbr_
    else:
        return camel2snack(class_type.__name__)

-def build_plugin_layer(cfg: Dict,
-                       postfix: Union[int, str] = '',
-                       **kwargs) -> Tuple[str, nn.Module]:
+def build_plugin_layer(cfg, postfix='', **kwargs):
    """Build plugin layer.

    Args:
-        cfg (dict): cfg should contain:
-
-            - type (str): identify plugin layer type.
-            - layer args: args needed to instantiate a plugin layer.
+        cfg (None or dict): cfg should contain:
+            type (str): identify plugin layer type.
+            layer args: args needed to instantiate a plugin layer.
        postfix (int, str): appended into norm abbreviation to
            create named layer. Default: ''.

    Returns:
-        tuple[str, nn.Module]: The first one is the concatenation of
-            abbreviation and postfix. The second is the created plugin layer.
+        tuple[str, nn.Module]:
+            name (str): abbreviation + postfix
+            layer (nn.Module): created plugin layer
    """
    if not isinstance(cfg, dict):
        raise TypeError('cfg must be a dict')
@@ -79,15 +74,10 @@ def build_plugin_layer(cfg: Dict,
    cfg_ = cfg.copy()

    layer_type = cfg_.pop('type')
+    if layer_type not in PLUGIN_LAYERS:
+        raise KeyError(f'Unrecognized plugin type {layer_type}')

-    # Switch registry to the target scope. If `plugin_layer` cannot be found
-    # in the registry, fallback to search `plugin_layer` in the
-    # mmengine.MODELS.
-    with MODELS.switch_scope_and_registry(None) as registry:
-        plugin_layer = registry.get(layer_type)
-    if plugin_layer is None:
-        raise KeyError(f'Cannot find {plugin_layer} in registry under scope '
-                       f'name {registry.scope}')
+    plugin_layer = PLUGIN_LAYERS.get(layer_type)
    abbr = infer_abbr(plugin_layer)

    assert isinstance(postfix, (int, str))
......
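For context, build_plugin_layer returns the inferred abbreviation plus postfix alongside the instance. A usage sketch, assuming NonLocal2d is registered as a plugin above with _abbr_ = 'nonlocal_block':

from mmcv.cnn import build_plugin_layer

name, layer = build_plugin_layer(
    dict(type='NonLocal2d', in_channels=16), postfix='_1')
print(name)  # 'nonlocal_block_1'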
+# Copyright (c) OpenMMLab. All rights reserved.
+from mmcv.utils import Registry
+
+CONV_LAYERS = Registry('conv layer')
+NORM_LAYERS = Registry('norm layer')
+ACTIVATION_LAYERS = Registry('activation layer')
+PADDING_LAYERS = Registry('padding layer')
+UPSAMPLE_LAYERS = Registry('upsample layer')
+PLUGIN_LAYERS = Registry('plugin layer')
+
+DROPOUT_LAYERS = Registry('drop out layers')
+POSITIONAL_ENCODING = Registry('position encoding')
+ATTENTION = Registry('attention')
+FEEDFORWARD_NETWORK = Registry('feed-forward Network')
+TRANSFORMER_LAYER = Registry('transformerLayer')
+TRANSFORMER_LAYER_SEQUENCE = Registry('transformer-layers sequence')
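These Registry objects follow the standard mmcv 1.x pattern: classes register under a string key and are instantiated from config dicts via build_from_cfg. A minimal self-contained sketch with a throwaway registry:

from mmcv.utils import Registry, build_from_cfg

DUMMY_LAYERS = Registry('dummy layer')

@DUMMY_LAYERS.register_module()
class Foo:
    def __init__(self, depth):
        self.depth = depth

foo = build_from_cfg(dict(type='Foo', depth=50), DUMMY_LAYERS)
print(foo.depth)  # 50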
@@ -13,45 +13,9 @@ class Scale(nn.Module):
        scale (float): Initial value of scale factor. Default: 1.0
    """

-    def __init__(self, scale: float = 1.0):
-        super().__init__()
+    def __init__(self, scale=1.0):
+        super(Scale, self).__init__()
        self.scale = nn.Parameter(torch.tensor(scale, dtype=torch.float))

-    def forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward(self, x):
        return x * self.scale
-
-
-class LayerScale(nn.Module):
-    """LayerScale layer.
-
-    Args:
-        dim (int): Dimension of input features.
-        inplace (bool): Whether performs operation in-place.
-            Default: `False`.
-        data_format (str): The input data format, could be 'channels_last'
-            or 'channels_first', representing (B, C, H, W) and
-            (B, N, C) format data respectively. Default: 'channels_last'.
-        scale (float): Initial value of scale factor. Default: 1.0
-    """
-
-    def __init__(self,
-                 dim: int,
-                 inplace: bool = False,
-                 data_format: str = 'channels_last',
-                 scale: float = 1e-5):
-        super().__init__()
-        assert data_format in ('channels_last', 'channels_first'), \
-            "'data_format' could only be channels_last or channels_first."
-        self.inplace = inplace
-        self.data_format = data_format
-        self.weight = nn.Parameter(torch.ones(dim) * scale)
-
-    def forward(self, x) -> torch.Tensor:
-        if self.data_format == 'channels_first':
-            shape = tuple((1, -1, *(1 for _ in range(x.dim() - 2))))
-        else:
-            shape = tuple((*(1 for _ in range(x.dim() - 1)), -1))
-        if self.inplace:
-            return x.mul_(self.weight.view(*shape))
-        else:
-            return x * self.weight.view(*shape)
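For reference, the removed LayerScale multiplies features by a small learnable per-channel weight, and its broadcasting logic reduces to a reshape of that weight. An illustrative snippet for 'channels_last' data (not the class itself):

import torch

dim = 8
weight = torch.ones(dim) * 1e-5     # the learnable per-channel scale
x = torch.randn(2, 16, dim)         # (B, N, C), i.e. channels_last
out = x * weight.view(1, 1, dim)    # same as weight.view(*(1,) * (x.dim() - 1), -1)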
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
-from mmengine.registry import MODELS
+from .registry import ACTIVATION_LAYERS


-@MODELS.register_module()
+@ACTIVATION_LAYERS.register_module()
class Swish(nn.Module):
    """Swish Module.
@@ -18,7 +19,7 @@ class Swish(nn.Module):
    """

    def __init__(self):
-        super().__init__()
+        super(Swish, self).__init__()

-    def forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward(self, x):
        return x * torch.sigmoid(x)
# Copyright (c) OpenMMLab. All rights reserved.
import copy
-import math
import warnings
-from typing import Sequence

import torch
import torch.nn as nn
-import torch.nn.functional as F
-from mmengine.config import ConfigDict
-from mmengine.model import BaseModule, ModuleList, Sequential
-from mmengine.registry import MODELS
-from mmengine.utils import deprecated_api_warning, to_2tuple

-from mmcv.cnn import (Linear, build_activation_layer, build_conv_layer,
-                      build_norm_layer)
+from mmcv import ConfigDict, deprecated_api_warning
+from mmcv.cnn import Linear, build_activation_layer, build_norm_layer
+from mmcv.runner.base_module import BaseModule, ModuleList, Sequential
+from mmcv.utils import build_from_cfg
from .drop import build_dropout
-from .scale import LayerScale
+from .registry import (ATTENTION, FEEDFORWARD_NETWORK, POSITIONAL_ENCODING,
+                       TRANSFORMER_LAYER, TRANSFORMER_LAYER_SEQUENCE)

# Avoid BC-breaking of importing MultiScaleDeformableAttention from this file
try:
-    from mmcv.ops.multi_scale_deform_attn import \
-        MultiScaleDeformableAttention  # noqa F401
+    from mmcv.ops.multi_scale_deform_attn import MultiScaleDeformableAttention  # noqa F401
    warnings.warn(
        ImportWarning(
            '``MultiScaleDeformableAttention`` has been moved to '
@@ -32,379 +27,35 @@ try:
except ImportError:
    warnings.warn('Fail to import ``MultiScaleDeformableAttention`` from '
                  '``mmcv.ops.multi_scale_deform_attn``, '
-                  'You should install ``mmcv`` rather than ``mmcv-lite`` '
-                  'if you need this module. ')
+                  'You should install ``mmcv-full`` if you need this module. ')


def build_positional_encoding(cfg, default_args=None):
    """Builder for Position Encoding."""
-    return MODELS.build(cfg, default_args=default_args)
+    return build_from_cfg(cfg, POSITIONAL_ENCODING, default_args)


def build_attention(cfg, default_args=None):
    """Builder for attention."""
-    return MODELS.build(cfg, default_args=default_args)
+    return build_from_cfg(cfg, ATTENTION, default_args)


def build_feedforward_network(cfg, default_args=None):
    """Builder for feed-forward network (FFN)."""
-    return MODELS.build(cfg, default_args=default_args)
+    return build_from_cfg(cfg, FEEDFORWARD_NETWORK, default_args)


def build_transformer_layer(cfg, default_args=None):
    """Builder for transformer layer."""
-    return MODELS.build(cfg, default_args=default_args)
+    return build_from_cfg(cfg, TRANSFORMER_LAYER, default_args)


def build_transformer_layer_sequence(cfg, default_args=None):
    """Builder for transformer encoder and transformer decoder."""
-    return MODELS.build(cfg, default_args=default_args)
+    return build_from_cfg(cfg, TRANSFORMER_LAYER_SEQUENCE, default_args)
-class AdaptivePadding(nn.Module):
-    """Applies padding adaptively to the input.
-
-    This module can make input get fully covered by filter
-    you specified. It support two modes "same" and "corner". The
-    "same" mode is same with "SAME" padding mode in TensorFlow, pad
-    zero around input. The "corner" mode would pad zero
-    to bottom right.
-
-    Args:
-        kernel_size (int | tuple): Size of the kernel. Default: 1.
-        stride (int | tuple): Stride of the filter. Default: 1.
-        dilation (int | tuple): Spacing between kernel elements.
-            Default: 1.
-        padding (str): Support "same" and "corner", "corner" mode
-            would pad zero to bottom right, and "same" mode would
-            pad zero around input. Default: "corner".
-
-    Example:
-        >>> kernel_size = 16
-        >>> stride = 16
-        >>> dilation = 1
-        >>> input = torch.rand(1, 1, 15, 17)
-        >>> adap_pad = AdaptivePadding(
-        >>>     kernel_size=kernel_size,
-        >>>     stride=stride,
-        >>>     dilation=dilation,
-        >>>     padding="corner")
-        >>> out = adap_pad(input)
-        >>> assert (out.shape[2], out.shape[3]) == (16, 32)
-        >>> input = torch.rand(1, 1, 16, 17)
-        >>> out = adap_pad(input)
-        >>> assert (out.shape[2], out.shape[3]) == (16, 32)
-    """
-
-    def __init__(self, kernel_size=1, stride=1, dilation=1, padding='corner'):
-        super().__init__()
-        assert padding in ('same', 'corner')
-
-        kernel_size = to_2tuple(kernel_size)
-        stride = to_2tuple(stride)
-        dilation = to_2tuple(dilation)
-
-        self.padding = padding
-        self.kernel_size = kernel_size
-        self.stride = stride
-        self.dilation = dilation
-
-    def get_pad_shape(self, input_shape):
-        """Calculate the padding size of input.
-
-        Args:
-            input_shape (:obj:`torch.Size`): arrange as (H, W).
-
-        Returns:
-            Tuple[int]: The padding size along the
-            original H and W directions
-        """
-        input_h, input_w = input_shape
-        kernel_h, kernel_w = self.kernel_size
-        stride_h, stride_w = self.stride
-        output_h = math.ceil(input_h / stride_h)
-        output_w = math.ceil(input_w / stride_w)
-        pad_h = max((output_h - 1) * stride_h +
-                    (kernel_h - 1) * self.dilation[0] + 1 - input_h, 0)
-        pad_w = max((output_w - 1) * stride_w +
-                    (kernel_w - 1) * self.dilation[1] + 1 - input_w, 0)
-        return pad_h, pad_w
-
-    def forward(self, x):
-        """Add padding to `x`
-
-        Args:
-            x (Tensor): Input tensor has shape (B, C, H, W).
-
-        Returns:
-            Tensor: The tensor with adaptive padding
-        """
-        pad_h, pad_w = self.get_pad_shape(x.size()[-2:])
-        if pad_h > 0 or pad_w > 0:
-            if self.padding == 'corner':
-                x = F.pad(x, [0, pad_w, 0, pad_h])
-            elif self.padding == 'same':
-                x = F.pad(x, [
-                    pad_w // 2, pad_w - pad_w // 2, pad_h // 2,
-                    pad_h - pad_h // 2
-                ])
-        return x
-
-
-class PatchEmbed(BaseModule):
"""Image to Patch Embedding.
We use a conv layer to implement PatchEmbed.
Args:
in_channels (int): The num of input channels. Default: 3
embed_dims (int): The dimensions of embedding. Default: 768
conv_type (str): The type of convolution
to generate patch embedding. Default: "Conv2d".
kernel_size (int): The kernel_size of embedding conv. Default: 16.
stride (int): The slide stride of embedding conv.
Default: 16.
padding (int | tuple | string): The padding length of
embedding conv. When it is a string, it means the mode
of adaptive padding, support "same" and "corner" now.
Default: "corner".
dilation (int): The dilation rate of embedding conv. Default: 1.
bias (bool): Bias of embed conv. Default: True.
norm_cfg (dict, optional): Config dict for normalization layer.
Default: None.
input_size (int | tuple | None): The size of input, which will be
used to calculate the out size. Only works when `dynamic_size`
is False. Default: None.
init_cfg (`mmcv.ConfigDict`, optional): The Config for initialization.
Default: None.
"""
def __init__(self,
in_channels=3,
embed_dims=768,
conv_type='Conv2d',
kernel_size=16,
stride=16,
padding='corner',
dilation=1,
bias=True,
norm_cfg=None,
input_size=None,
init_cfg=None):
super().__init__(init_cfg=init_cfg)
self.embed_dims = embed_dims
if stride is None:
stride = kernel_size
kernel_size = to_2tuple(kernel_size)
stride = to_2tuple(stride)
dilation = to_2tuple(dilation)
if isinstance(padding, str):
self.adaptive_padding = AdaptivePadding(
kernel_size=kernel_size,
stride=stride,
dilation=dilation,
padding=padding)
# disable the padding of conv
padding = 0
else:
self.adaptive_padding = None
padding = to_2tuple(padding)
self.projection = build_conv_layer(
dict(type=conv_type),
in_channels=in_channels,
out_channels=embed_dims,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
bias=bias)
if norm_cfg is not None:
self.norm = build_norm_layer(norm_cfg, embed_dims)[1]
else:
self.norm = None
if input_size:
input_size = to_2tuple(input_size)
# `init_out_size` would be used outside to
# calculate the num_patches
# e.g. when `use_abs_pos_embed` outside
self.init_input_size = input_size
if self.adaptive_padding:
pad_h, pad_w = self.adaptive_padding.get_pad_shape(input_size)
input_h, input_w = input_size
input_h = input_h + pad_h
input_w = input_w + pad_w
input_size = (input_h, input_w)
# https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
h_out = (input_size[0] + 2 * padding[0] - dilation[0] *
(kernel_size[0] - 1) - 1) // stride[0] + 1
w_out = (input_size[1] + 2 * padding[1] - dilation[1] *
(kernel_size[1] - 1) - 1) // stride[1] + 1
self.init_out_size = (h_out, w_out)
else:
self.init_input_size = None
self.init_out_size = None
def forward(self, x):
"""
Args:
x (Tensor): Has shape (B, C, H, W). In most case, C is 3.
Returns:
tuple: Contains merged results and its spatial shape.
- x (Tensor): Has shape (B, out_h * out_w, embed_dims)
- out_size (tuple[int]): Spatial shape of x, arrange as
(out_h, out_w).
"""
if self.adaptive_padding:
x = self.adaptive_padding(x)
x = self.projection(x)
out_size = (x.shape[2], x.shape[3])
x = x.flatten(2).transpose(1, 2)
if self.norm is not None:
x = self.norm(x)
return x, out_size
class PatchMerging(BaseModule):
"""Merge patch feature map.
This layer groups feature map by kernel_size, and applies norm and linear
layers to the grouped feature map ((used in Swin Transformer)).
Our implementation uses `nn.Unfold` to
merge patches, which is about 25% faster than the original
implementation. However, we need to modify pretrained
models for compatibility.
Args:
in_channels (int): The num of input channels.
to gets fully covered by filter and stride you specified.
out_channels (int): The num of output channels.
kernel_size (int | tuple, optional): the kernel size in the unfold
layer. Defaults to 2.
stride (int | tuple, optional): the stride of the sliding blocks in the
unfold layer. Default: None. (Would be set as `kernel_size`)
padding (int | tuple | string ): The padding length of
embedding conv. When it is a string, it means the mode
of adaptive padding, support "same" and "corner" now.
Default: "corner".
dilation (int | tuple, optional): dilation parameter in the unfold
layer. Default: 1.
bias (bool, optional): Whether to add bias in linear layer or not.
Defaults: False.
norm_cfg (dict, optional): Config dict for normalization layer.
Default: dict(type='LN').
init_cfg (dict, optional): The extra config for initialization.
Default: None.
"""
def __init__(self,
in_channels,
out_channels,
kernel_size=2,
stride=None,
padding='corner',
dilation=1,
bias=False,
norm_cfg=dict(type='LN'),
init_cfg=None):
super().__init__(init_cfg=init_cfg)
self.in_channels = in_channels
self.out_channels = out_channels
if stride:
stride = stride
else:
stride = kernel_size
kernel_size = to_2tuple(kernel_size)
stride = to_2tuple(stride)
dilation = to_2tuple(dilation)
if isinstance(padding, str):
self.adaptive_padding = AdaptivePadding(
kernel_size=kernel_size,
stride=stride,
dilation=dilation,
padding=padding)
# disable the padding of unfold
padding = 0
else:
self.adaptive_padding = None
padding = to_2tuple(padding)
self.sampler = nn.Unfold(
kernel_size=kernel_size,
dilation=dilation,
padding=padding,
stride=stride)
sample_dim = kernel_size[0] * kernel_size[1] * in_channels
if norm_cfg is not None:
self.norm = build_norm_layer(norm_cfg, sample_dim)[1]
else:
self.norm = None
self.reduction = nn.Linear(sample_dim, out_channels, bias=bias)
def forward(self, x, input_size):
"""
Args:
x (Tensor): Has shape (B, H*W, C_in).
input_size (tuple[int]): The spatial shape of x, arrange as (H, W).
Default: None.
Returns:
tuple: Contains merged results and its spatial shape.
- x (Tensor): Has shape (B, Merged_H * Merged_W, C_out)
- out_size (tuple[int]): Spatial shape of x, arrange as
(Merged_H, Merged_W).
"""
B, L, C = x.shape
assert isinstance(input_size, Sequence), f'Expect ' \
f'input_size is ' \
f'`Sequence` ' \
f'but get {input_size}'
H, W = input_size
assert L == H * W, 'input feature has wrong size'
x = x.view(B, H, W, C).permute([0, 3, 1, 2]) # B, C, H, W
if self.adaptive_padding:
x = self.adaptive_padding(x)
H, W = x.shape[-2:]
# Use nn.Unfold to merge patch. About 25% faster than original method,
# but need to modify pretrained model for compatibility
# if kernel_size=2 and stride=2, x should has shape (B, 4*C, H/2*W/2)
x = self.sampler(x)
out_h = (H + 2 * self.sampler.padding[0] - self.sampler.dilation[0] *
(self.sampler.kernel_size[0] - 1) -
1) // self.sampler.stride[0] + 1
out_w = (W + 2 * self.sampler.padding[1] - self.sampler.dilation[1] *
(self.sampler.kernel_size[1] - 1) -
1) // self.sampler.stride[1] + 1
output_size = (out_h, out_w)
x = x.transpose(1, 2) # B, H/2*W/2, 4*C
x = self.norm(x) if self.norm else x
x = self.reduction(x)
return x, output_size
@MODELS.register_module()
class MultiheadAttention(BaseModule): class MultiheadAttention(BaseModule):
"""A wrapper for ``torch.nn.MultiheadAttention``. """A wrapper for ``torch.nn.MultiheadAttention``.
...@@ -436,13 +87,12 @@ class MultiheadAttention(BaseModule): ...@@ -436,13 +87,12 @@ class MultiheadAttention(BaseModule):
init_cfg=None, init_cfg=None,
batch_first=False, batch_first=False,
**kwargs): **kwargs):
super().__init__(init_cfg) super(MultiheadAttention, self).__init__(init_cfg)
if 'dropout' in kwargs: if 'dropout' in kwargs:
warnings.warn( warnings.warn('The arguments `dropout` in MultiheadAttention '
'The arguments `dropout` in MultiheadAttention ' 'has been deprecated, now you can separately '
'has been deprecated, now you can separately ' 'set `attn_drop`(float), proj_drop(float), '
'set `attn_drop`(float), proj_drop(float), ' 'and `dropout_layer`(dict) ')
'and `dropout_layer`(dict) ', DeprecationWarning)
attn_drop = kwargs['dropout'] attn_drop = kwargs['dropout']
dropout_layer['drop_prob'] = kwargs.pop('dropout') dropout_layer['drop_prob'] = kwargs.pop('dropout')
...@@ -504,9 +154,9 @@ class MultiheadAttention(BaseModule): ...@@ -504,9 +154,9 @@ class MultiheadAttention(BaseModule):
Returns: Returns:
Tensor: forwarded results with shape Tensor: forwarded results with shape
[num_queries, bs, embed_dims] [num_queries, bs, embed_dims]
if self.batch_first is False, else if self.batch_first is False, else
[bs, num_queries embed_dims]. [bs, num_queries embed_dims].
""" """
if key is None: if key is None:
...@@ -552,7 +202,7 @@ class MultiheadAttention(BaseModule): ...@@ -552,7 +202,7 @@ class MultiheadAttention(BaseModule):
return identity + self.dropout_layer(self.proj_drop(out)) return identity + self.dropout_layer(self.proj_drop(out))
@MODELS.register_module() @FEEDFORWARD_NETWORK.register_module()
class FFN(BaseModule): class FFN(BaseModule):
"""Implements feed-forward networks (FFNs) with identity connection. """Implements feed-forward networks (FFNs) with identity connection.
...@@ -573,8 +223,6 @@ class FFN(BaseModule): ...@@ -573,8 +223,6 @@ class FFN(BaseModule):
when adding the shortcut. when adding the shortcut.
init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization. init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
Default: None. Default: None.
layer_scale_init_value (float): Initial value of scale factor in
LayerScale. Default: 1.0
""" """
@deprecated_api_warning( @deprecated_api_warning(
...@@ -592,21 +240,23 @@ class FFN(BaseModule): ...@@ -592,21 +240,23 @@ class FFN(BaseModule):
dropout_layer=None, dropout_layer=None,
add_identity=True, add_identity=True,
init_cfg=None, init_cfg=None,
layer_scale_init_value=0.): **kwargs):
super().__init__(init_cfg) super(FFN, self).__init__(init_cfg)
assert num_fcs >= 2, 'num_fcs should be no less ' \ assert num_fcs >= 2, 'num_fcs should be no less ' \
f'than 2. got {num_fcs}.' f'than 2. got {num_fcs}.'
self.embed_dims = embed_dims self.embed_dims = embed_dims
self.feedforward_channels = feedforward_channels self.feedforward_channels = feedforward_channels
self.num_fcs = num_fcs self.num_fcs = num_fcs
self.act_cfg = act_cfg
self.activate = build_activation_layer(act_cfg)
layers = [] layers = []
in_channels = embed_dims in_channels = embed_dims
for _ in range(num_fcs - 1): for _ in range(num_fcs - 1):
layers.append( layers.append(
Sequential( Sequential(
Linear(in_channels, feedforward_channels), Linear(in_channels, feedforward_channels), self.activate,
build_activation_layer(act_cfg), nn.Dropout(ffn_drop))) nn.Dropout(ffn_drop)))
in_channels = feedforward_channels in_channels = feedforward_channels
layers.append(Linear(feedforward_channels, embed_dims)) layers.append(Linear(feedforward_channels, embed_dims))
layers.append(nn.Dropout(ffn_drop)) layers.append(nn.Dropout(ffn_drop))
...@@ -615,11 +265,6 @@ class FFN(BaseModule): ...@@ -615,11 +265,6 @@ class FFN(BaseModule):
dropout_layer) if dropout_layer else torch.nn.Identity() dropout_layer) if dropout_layer else torch.nn.Identity()
self.add_identity = add_identity self.add_identity = add_identity
if layer_scale_init_value > 0:
self.gamma2 = LayerScale(embed_dims, scale=layer_scale_init_value)
else:
self.gamma2 = nn.Identity()
@deprecated_api_warning({'residual': 'identity'}, cls_name='FFN') @deprecated_api_warning({'residual': 'identity'}, cls_name='FFN')
def forward(self, x, identity=None): def forward(self, x, identity=None):
"""Forward function for `FFN`. """Forward function for `FFN`.
...@@ -627,7 +272,6 @@ class FFN(BaseModule): ...@@ -627,7 +272,6 @@ class FFN(BaseModule):
The function would add x to the output tensor if residue is None. The function would add x to the output tensor if residue is None.
""" """
out = self.layers(x) out = self.layers(x)
out = self.gamma2(out)
if not self.add_identity: if not self.add_identity:
return self.dropout_layer(out) return self.dropout_layer(out)
if identity is None: if identity is None:
...@@ -635,7 +279,7 @@ class FFN(BaseModule): ...@@ -635,7 +279,7 @@ class FFN(BaseModule):
return identity + self.dropout_layer(out) return identity + self.dropout_layer(out)
@MODELS.register_module() @TRANSFORMER_LAYER.register_module()
class BaseTransformerLayer(BaseModule): class BaseTransformerLayer(BaseModule):
"""Base `TransformerLayer` for vision transformer. """Base `TransformerLayer` for vision transformer.
...@@ -698,15 +342,15 @@ class BaseTransformerLayer(BaseModule): ...@@ -698,15 +342,15 @@ class BaseTransformerLayer(BaseModule):
f'The arguments `{ori_name}` in BaseTransformerLayer ' f'The arguments `{ori_name}` in BaseTransformerLayer '
f'has been deprecated, now you should set `{new_name}` ' f'has been deprecated, now you should set `{new_name}` '
f'and other FFN related arguments ' f'and other FFN related arguments '
f'to a dict named `ffn_cfgs`. ', DeprecationWarning) f'to a dict named `ffn_cfgs`. ')
ffn_cfgs[new_name] = kwargs[ori_name] ffn_cfgs[new_name] = kwargs[ori_name]
super().__init__(init_cfg) super(BaseTransformerLayer, self).__init__(init_cfg)
self.batch_first = batch_first self.batch_first = batch_first
        assert set(operation_order) & set(
            ['self_attn', 'norm', 'ffn', 'cross_attn']) == \
            set(operation_order), f'The operation_order of' \
            f' {self.__class__.__name__} should ' \
            f'contain all four operation types ' \
@@ -753,7 +397,7 @@ class BaseTransformerLayer(BaseModule):
        assert len(ffn_cfgs) == num_ffns
        for ffn_index in range(num_ffns):
            if 'embed_dims' not in ffn_cfgs[ffn_index]:
                ffn_cfgs[ffn_index]['embed_dims'] = self.embed_dims
            else:
                assert ffn_cfgs[ffn_index]['embed_dims'] == self.embed_dims
            self.ffns.append(
@@ -866,7 +510,7 @@ class BaseTransformerLayer(BaseModule):
        return query
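
# A config sketch for the layer above (hypothetical sizes): the assert in
# __init__ only admits entries drawn from {'self_attn', 'norm', 'ffn',
# 'cross_attn'}, so a decoder-style layer is specified as:
#
#     decoder_layer = dict(
#         type='BaseTransformerLayer',
#         attn_cfgs=dict(type='MultiheadAttention', embed_dims=256,
#                        num_heads=8),
#         ffn_cfgs=dict(embed_dims=256, feedforward_channels=1024),
#         operation_order=('self_attn', 'norm', 'cross_attn', 'norm',
#                          'ffn', 'norm'))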
@TRANSFORMER_LAYER_SEQUENCE.register_module()
class TransformerLayerSequence(BaseModule):
    """Base class for TransformerEncoder and TransformerDecoder in vision
    transformer.
@@ -887,7 +531,7 @@ class TransformerLayerSequence(BaseModule):
    """

    def __init__(self, transformerlayers=None, num_layers=None, init_cfg=None):
        super(TransformerLayerSequence, self).__init__(init_cfg)
        if isinstance(transformerlayers, dict):
            transformerlayers = [
                copy.deepcopy(transformerlayers) for _ in range(num_layers)
...
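
# A usage sketch for TransformerLayerSequence (hypothetical sizes; assumes
# mmcv-1.4-style registries and import paths). A single dict passed as
# `transformerlayers` is deep-copied `num_layers` times, so the layers share
# a config but not parameters:
from mmcv.cnn.bricks.transformer import TransformerLayerSequence

encoder = TransformerLayerSequence(
    transformerlayers=dict(
        type='BaseTransformerLayer',
        attn_cfgs=dict(type='MultiheadAttention', embed_dims=256,
                       num_heads=8),
        ffn_cfgs=dict(embed_dims=256, feedforward_channels=1024),
        operation_order=('self_attn', 'norm', 'ffn', 'norm')),
    num_layers=6)
assert len(encoder.layers) == 6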
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
import torch.nn.functional as F

from ..utils import xavier_init
from .registry import UPSAMPLE_LAYERS

UPSAMPLE_LAYERS.register_module('nearest', module=nn.Upsample)
UPSAMPLE_LAYERS.register_module('bilinear', module=nn.Upsample)
@UPSAMPLE_LAYERS.register_module(name='pixel_shuffle')
class PixelShufflePack(nn.Module):
    """Pixel Shuffle upsample layer.
@@ -26,9 +24,9 @@ class PixelShufflePack(nn.Module):
            channels.
    """

    def __init__(self, in_channels, out_channels, scale_factor,
                 upsample_kernel):
        super(PixelShufflePack, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.scale_factor = scale_factor
@@ -43,13 +41,13 @@ class PixelShufflePack(nn.Module):
    def init_weights(self):
        xavier_init(self.upsample_conv, distribution='uniform')

    def forward(self, x):
        x = self.upsample_conv(x)
        x = F.pixel_shuffle(x, self.scale_factor)
        return x
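
# Shape sketch for PixelShufflePack (illustrative values): the conv first
# expands channels by scale_factor ** 2, then F.pixel_shuffle trades those
# channels for spatial resolution:
#
#     up = PixelShufflePack(in_channels=64, out_channels=32, scale_factor=2,
#                           upsample_kernel=3)
#     x = torch.rand(1, 64, 16, 16)
#     assert up(x).shape == (1, 32, 32, 32)  # H and W doubled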
def build_upsample_layer(cfg, *args, **kwargs):
    """Build upsample layer.

    Args:
@@ -57,7 +55,7 @@ def build_upsample_layer(cfg, *args, **kwargs):
            - type (str): Layer type.
            - scale_factor (int): Upsample ratio, which is not applicable to
              deconv.
            - layer args: Args needed to instantiate an upsample layer.
        args (argument list): Arguments passed to the ``__init__``
            method of the corresponding upsample layer.
@@ -75,15 +73,11 @@ def build_upsample_layer(cfg, *args, **kwargs):
    cfg_ = cfg.copy()

    layer_type = cfg_.pop('type')
    if layer_type not in UPSAMPLE_LAYERS:
        raise KeyError(f'Unrecognized upsample type {layer_type}')
    else:
        upsample = UPSAMPLE_LAYERS.get(layer_type)

    if upsample is nn.Upsample:
        cfg_['mode'] = layer_type
    layer = upsample(*args, **kwargs, **cfg_)
...
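
# A usage sketch for build_upsample_layer (assumes the mmcv-1.4-style import
# path). For 'nearest' and 'bilinear' the registered module is nn.Upsample,
# so the type name doubles as the interpolation mode; 'deconv' builds a
# ConvTranspose2d instead:
import torch
from mmcv.cnn import build_upsample_layer

up = build_upsample_layer(dict(type='bilinear', scale_factor=2))
x = torch.rand(1, 8, 16, 16)
assert up(x).shape == (1, 8, 32, 32)

deconv = build_upsample_layer(
    dict(type='deconv', in_channels=8, out_channels=8, kernel_size=2,
         stride=2))
assert deconv(x).shape == (1, 8, 32, 32)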
@@ -9,9 +9,10 @@ import math
import torch
import torch.nn as nn
from torch.nn.modules.utils import _pair, _triple

from .registry import CONV_LAYERS, UPSAMPLE_LAYERS

if torch.__version__ == 'parrots':
    TORCH_VERSION = torch.__version__
else:
@@ -20,27 +21,27 @@ else:
    TORCH_VERSION = tuple(int(x) for x in torch.__version__.split('.')[:2])
def obsolete_torch_version(torch_version, version_threshold):
    return torch_version == 'parrots' or torch_version <= version_threshold


class NewEmptyTensorOp(torch.autograd.Function):

    @staticmethod
    def forward(ctx, x, new_shape):
        ctx.shape = x.shape
        return x.new_empty(new_shape)

    @staticmethod
    def backward(ctx, grad):
        shape = ctx.shape
        return NewEmptyTensorOp.apply(grad, shape), None
@CONV_LAYERS.register_module('Conv', force=True)
class Conv2d(nn.Conv2d):

    def forward(self, x):
        if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 4)):
            out_shape = [x.shape[0], self.out_channels]
            for i, k, p, s, d in zip(x.shape[-2:], self.kernel_size,
@@ -58,10 +59,10 @@ class Conv2d(nn.Conv2d):
        return super().forward(x)
@CONV_LAYERS.register_module('Conv3d', force=True)
class Conv3d(nn.Conv3d):

    def forward(self, x):
        if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 4)):
            out_shape = [x.shape[0], self.out_channels]
            for i, k, p, s, d in zip(x.shape[-3:], self.kernel_size,
@@ -79,11 +80,12 @@ class Conv3d(nn.Conv3d):
        return super().forward(x)
@CONV_LAYERS.register_module()
@CONV_LAYERS.register_module('deconv')
@UPSAMPLE_LAYERS.register_module('deconv', force=True)
class ConvTranspose2d(nn.ConvTranspose2d):

    def forward(self, x):
        if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 4)):
            out_shape = [x.shape[0], self.out_channels]
            for i, k, p, s, d, op in zip(x.shape[-2:], self.kernel_size,
@@ -101,11 +103,12 @@ class ConvTranspose2d(nn.ConvTranspose2d):
        return super().forward(x)
@CONV_LAYERS.register_module()
@CONV_LAYERS.register_module('deconv3d')
@UPSAMPLE_LAYERS.register_module('deconv3d', force=True)
class ConvTranspose3d(nn.ConvTranspose3d):

    def forward(self, x):
        if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 4)):
            out_shape = [x.shape[0], self.out_channels]
            for i, k, p, s, d, op in zip(x.shape[-3:], self.kernel_size,
@@ -125,7 +128,7 @@ class ConvTranspose3d(nn.ConvTranspose3d):
class MaxPool2d(nn.MaxPool2d):

    def forward(self, x):
        # PyTorch 1.9 does not support empty tensor inference yet
        if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 9)):
            out_shape = list(x.shape[:2])
@@ -143,7 +146,7 @@ class MaxPool2d(nn.MaxPool2d):
class MaxPool3d(nn.MaxPool3d):

    def forward(self, x):
        # PyTorch 1.9 does not support empty tensor inference yet
        if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 9)):
            out_shape = list(x.shape[:2])
@@ -162,7 +165,7 @@ class MaxPool3d(nn.MaxPool3d):
class Linear(torch.nn.Linear):

    def forward(self, x):
        # empty tensor forward of Linear layer is supported in PyTorch 1.6
        if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 5)):
            out_shape = [x.shape[0], self.out_features]
...
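
# Empty-tensor sketch for the wrappers above (assumes the mmcv-1.4-style
# import path). On old PyTorch the wrappers compute the output shape by hand
# and return an empty tensor through NewEmptyTensorOp, keeping autograd
# intact; on newer PyTorch they fall through to the stock forward with the
# same result:
import torch
from mmcv.cnn import Conv2d  # the wrapper registered as 'Conv' above

conv = Conv2d(16, 32, kernel_size=3, padding=1)
empty = torch.rand(0, 16, 14, 14)  # e.g. RoI features of zero boxes
assert conv(empty).shape == (0, 32, 14, 14)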
# Copyright (c) OpenMMLab. All rights reserved.
from ..runner import Sequential
from ..utils import Registry, build_from_cfg


def build_model_from_cfg(cfg, registry, default_args=None):
    """Build a PyTorch model from config dict(s). Different from
    ``build_from_cfg``, if cfg is a list, a ``nn.Sequential`` will be built.

    Args:
        cfg (dict, list[dict]): The config of modules. It is either a config
            dict or a list of config dicts. If cfg is a list, the built
            modules will be wrapped with ``nn.Sequential``.
        registry (:obj:`Registry`): A registry the module belongs to.
        default_args (dict, optional): Default arguments to build the module.
            Defaults to None.

    Returns:
        nn.Module: A built nn module.
    """
    if isinstance(cfg, list):
        modules = [
            build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg
        ]
        return Sequential(*modules)
    else:
        return build_from_cfg(cfg, registry, default_args)


MODELS = Registry('model', build_func=build_model_from_cfg)
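
# A usage sketch for the registry above (`ToyHead` is hypothetical; the
# import path for MODELS is assumed): a list of configs yields a Sequential
# via build_model_from_cfg, a single dict yields the module itself:
import torch.nn as nn
from mmcv.cnn import MODELS


@MODELS.register_module()
class ToyHead(nn.Module):

    def __init__(self, channels):
        super().__init__()
        self.fc = nn.Linear(channels, channels)

    def forward(self, x):
        return self.fc(x)


seq = MODELS.build(
    [dict(type='ToyHead', channels=8),
     dict(type='ToyHead', channels=8)])
assert isinstance(seq, nn.Sequential)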
# Copyright (c) OpenMMLab. All rights reserved.
import logging

import torch.nn as nn
import torch.utils.checkpoint as cp

from .utils import constant_init, kaiming_init


def conv3x3(in_planes, out_planes, stride=1, dilation=1):
    """3x3 convolution with padding."""
    return nn.Conv2d(
        in_planes,
@@ -28,14 +23,14 @@ class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self,
                 inplanes,
                 planes,
                 stride=1,
                 dilation=1,
                 downsample=None,
                 style='pytorch',
                 with_cp=False):
        super(BasicBlock, self).__init__()
        assert style in ['pytorch', 'caffe']
        self.conv1 = conv3x3(inplanes, planes, stride, dilation)
        self.bn1 = nn.BatchNorm2d(planes)
@@ -47,7 +42,7 @@ class BasicBlock(nn.Module):
        self.dilation = dilation
        assert not with_cp

    def forward(self, x):
        residual = x

        out = self.conv1(x)
@@ -70,19 +65,19 @@ class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self,
                 inplanes,
                 planes,
                 stride=1,
                 dilation=1,
                 downsample=None,
                 style='pytorch',
                 with_cp=False):
        """Bottleneck block.

        If style is "pytorch", the stride-two layer is the 3x3 conv layer, if
        it is "caffe", the stride-two layer is the first 1x1 conv layer.
        """
        super(Bottleneck, self).__init__()
        assert style in ['pytorch', 'caffe']
        if style == 'pytorch':
            conv1_stride = 1
@@ -112,7 +107,7 @@ class Bottleneck(nn.Module):
        self.dilation = dilation
        self.with_cp = with_cp

    def forward(self, x):

        def _inner_forward(x):
            residual = x
@@ -145,14 +140,14 @@ class Bottleneck(nn.Module):
        return out
def make_res_layer(block,
                   inplanes,
                   planes,
                   blocks,
                   stride=1,
                   dilation=1,
                   style='pytorch',
                   with_cp=False):
    downsample = None
    if stride != 1 or inplanes != planes * block.expansion:
        downsample = nn.Sequential(
@@ -213,22 +208,22 @@ class ResNet(nn.Module):
    }

    def __init__(self,
                 depth,
                 num_stages=4,
                 strides=(1, 2, 2, 2),
                 dilations=(1, 1, 1, 1),
                 out_indices=(0, 1, 2, 3),
                 style='pytorch',
                 frozen_stages=-1,
                 bn_eval=True,
                 bn_frozen=False,
                 with_cp=False):
        super(ResNet, self).__init__()
        if depth not in self.arch_settings:
            raise KeyError(f'invalid depth {depth} for resnet')
        assert num_stages >= 1 and num_stages <= 4
        block, stage_blocks = self.arch_settings[depth]
        stage_blocks = stage_blocks[:num_stages]
        assert len(strides) == len(dilations) == num_stages
        assert max(out_indices) < num_stages
@@ -239,7 +234,7 @@ class ResNet(nn.Module):
        self.bn_frozen = bn_frozen
        self.with_cp = with_cp

        self.inplanes = 64
        self.conv1 = nn.Conv2d(
            3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
@@ -260,17 +255,17 @@ class ResNet(nn.Module):
                dilation=dilation,
                style=self.style,
                with_cp=with_cp)
            self.inplanes = planes * block.expansion
            layer_name = f'layer{i + 1}'
            self.add_module(layer_name, res_layer)
            self.res_layers.append(layer_name)

        self.feat_dim = block.expansion * 64 * 2**(len(stage_blocks) - 1)
    def init_weights(self, pretrained=None):
        if isinstance(pretrained, str):
            logger = logging.getLogger()
            from ..runner import load_checkpoint
            load_checkpoint(self, pretrained, strict=False, logger=logger)
        elif pretrained is None:
            for m in self.modules():
@@ -281,7 +276,7 @@ class ResNet(nn.Module):
        else:
            raise TypeError('pretrained must be a str or None')
    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
@@ -297,8 +292,8 @@ class ResNet(nn.Module):
        else:
            return tuple(outs)

    def train(self, mode=True):
        super(ResNet, self).train(mode)
        if self.bn_eval:
            for m in self.modules():
                if isinstance(m, nn.BatchNorm2d):
...
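
# A usage sketch for the ResNet above (illustrative input size): with the
# default out_indices=(0, 1, 2, 3) the forward returns one feature map per
# stage, at strides 4, 8, 16 and 32 relative to the input:
import torch

model = ResNet(depth=18)
model.init_weights()
feats = model(torch.rand(1, 3, 224, 224))
assert [f.shape[-1] for f in feats] == [56, 28, 14, 7]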
# Copyright (c) OpenMMLab. All rights reserved.
from .operator import BaseConvRFSearchOp, Conv2dRFSearchOp
from .search import RFSearchHook

__all__ = ['BaseConvRFSearchOp', 'Conv2dRFSearchOp', 'RFSearchHook']
# Copyright (c) OpenMMLab. All rights reserved.
import copy

import numpy as np
import torch
import torch.nn as nn
from mmengine.logging import print_log
from mmengine.model import BaseModule
from torch import Tensor

from .utils import expand_rates, get_single_padding


class BaseConvRFSearchOp(BaseModule):
    """Base class of ConvRFSearchOp.

    Args:
        op_layer (nn.Module): pytorch module, e.g., Conv2d
        global_config (dict): config dict.
    """

    def __init__(self, op_layer: nn.Module, global_config: dict):
        super().__init__()
        self.op_layer = op_layer
        self.global_config = global_config

    def normlize(self, weights: nn.Parameter) -> nn.Parameter:
        """Normalize weights.

        Args:
            weights (nn.Parameter): Weights to be normalized.

        Returns:
            nn.Parameter: Normalized weights.
        """
        abs_weights = torch.abs(weights)
        normalized_weights = abs_weights / torch.sum(abs_weights)
        return normalized_weights
class Conv2dRFSearchOp(BaseConvRFSearchOp):
    """Enable Conv2d with receptive field searching ability.

    Args:
        op_layer (nn.Module): pytorch module, e.g., Conv2d
        global_config (dict): config dict. Defaults to None.
            By default this must include:

            - "init_alphas": The value for initializing weights of each branch.
            - "num_branches": The controller of the size of
              search space (the number of branches).
            - "exp_rate": The controller of the sparsity of search space.
            - "mmin": The minimum dilation rate.
            - "mmax": The maximum dilation rate.

            Extra keys may exist, but are used by RFSearchHook, e.g., "step",
            "max_step", "search_interval", and "skip_layer".
        verbose (bool): Determines whether to print rf-next related logging
            messages. Defaults to True.
    """
    def __init__(self,
                 op_layer: nn.Module,
                 global_config: dict,
                 verbose: bool = True):
        super().__init__(op_layer, global_config)
        assert global_config is not None, 'global_config is None'
        self.num_branches = global_config['num_branches']
        assert self.num_branches in [2, 3]
        self.verbose = verbose
        init_dilation = op_layer.dilation
        self.dilation_rates = expand_rates(init_dilation, global_config)
        if self.op_layer.kernel_size[
                0] == 1 or self.op_layer.kernel_size[0] % 2 == 0:
            self.dilation_rates = [(op_layer.dilation[0], r[1])
                                   for r in self.dilation_rates]
        if self.op_layer.kernel_size[
                1] == 1 or self.op_layer.kernel_size[1] % 2 == 0:
            self.dilation_rates = [(r[0], op_layer.dilation[1])
                                   for r in self.dilation_rates]

        self.branch_weights = nn.Parameter(torch.Tensor(self.num_branches))
        if self.verbose:
            print_log(f'Expand as {self.dilation_rates}', 'current')
        nn.init.constant_(self.branch_weights, global_config['init_alphas'])
    def forward(self, input: Tensor) -> Tensor:
        norm_w = self.normlize(self.branch_weights[:len(self.dilation_rates)])
        if len(self.dilation_rates) == 1:
            outputs = [
                nn.functional.conv2d(
                    input,
                    weight=self.op_layer.weight,
                    bias=self.op_layer.bias,
                    stride=self.op_layer.stride,
                    padding=self.get_padding(self.dilation_rates[0]),
                    dilation=self.dilation_rates[0],
                    groups=self.op_layer.groups,
                )
            ]
        else:
            outputs = [
                nn.functional.conv2d(
                    input,
                    weight=self.op_layer.weight,
                    bias=self.op_layer.bias,
                    stride=self.op_layer.stride,
                    padding=self.get_padding(r),
                    dilation=r,
                    groups=self.op_layer.groups,
                ) * norm_w[i] for i, r in enumerate(self.dilation_rates)
            ]
        output = outputs[0]
        for i in range(1, len(self.dilation_rates)):
            output += outputs[i]
        return output
    def estimate_rates(self) -> None:
        """Estimate new dilation rate based on trained branch_weights."""
        norm_w = self.normlize(self.branch_weights[:len(self.dilation_rates)])
        if self.verbose:
            print_log(
                'Estimate dilation {} with weight {}.'.format(
                    self.dilation_rates,
                    norm_w.detach().cpu().numpy().tolist()), 'current')

        sum0, sum1, w_sum = 0, 0, 0
        for i in range(len(self.dilation_rates)):
            sum0 += norm_w[i].item() * self.dilation_rates[i][0]
            sum1 += norm_w[i].item() * self.dilation_rates[i][1]
            w_sum += norm_w[i].item()
        estimated = [
            np.clip(
                int(round(sum0 / w_sum)), self.global_config['mmin'],
                self.global_config['mmax']).item(),
            np.clip(
                int(round(sum1 / w_sum)), self.global_config['mmin'],
                self.global_config['mmax']).item()
        ]
        self.op_layer.dilation = tuple(estimated)
        self.op_layer.padding = self.get_padding(self.op_layer.dilation)
        self.dilation_rates = [tuple(estimated)]
        if self.verbose:
            print_log(f'Estimate as {tuple(estimated)}', 'current')
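
    # A worked instance of the estimate above (hypothetical numbers): with
    # dilation_rates = [(1, 1), (2, 2), (3, 3)] and normalized weights
    # [0.2, 0.5, 0.3], sum0 = sum1 = 0.2 * 1 + 0.5 * 2 + 0.3 * 3 = 2.1 and
    # w_sum = 1.0, so round(2.1 / 1.0) = 2 and the layer is fixed to
    # dilation (2, 2), clipped into [mmin, mmax].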
    def expand_rates(self) -> None:
        """Expand dilation rate."""
        dilation = self.op_layer.dilation
        dilation_rates = expand_rates(dilation, self.global_config)
        if self.op_layer.kernel_size[
                0] == 1 or self.op_layer.kernel_size[0] % 2 == 0:
            dilation_rates = [(dilation[0], r[1]) for r in dilation_rates]
        if self.op_layer.kernel_size[
                1] == 1 or self.op_layer.kernel_size[1] % 2 == 0:
            dilation_rates = [(r[0], dilation[1]) for r in dilation_rates]

        self.dilation_rates = copy.deepcopy(dilation_rates)
        if self.verbose:
            print_log(f'Expand as {self.dilation_rates}', 'current')
        nn.init.constant_(self.branch_weights,
                          self.global_config['init_alphas'])

    def get_padding(self, dilation) -> tuple:
        padding = (get_single_padding(self.op_layer.kernel_size[0],
                                      self.op_layer.stride[0], dilation[0]),
                   get_single_padding(self.op_layer.kernel_size[1],
                                      self.op_layer.stride[1], dilation[1]))
        return padding
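
# A usage sketch (assumed config values; the keys are the ones consumed in
# __init__ above). RFSearchHook normally wraps layers automatically; done by
# hand it looks like this:
import torch
import torch.nn as nn

config = dict(init_alphas=0.01, num_branches=3, exp_rate=0.5, mmin=1,
              mmax=24)
conv = nn.Conv2d(16, 16, kernel_size=3, padding=1)
op = Conv2dRFSearchOp(conv, config)

out = op(torch.rand(1, 16, 32, 32))  # weighted sum over dilation branches
assert out.shape == (1, 16, 32, 32)
op.estimate_rates()                  # fixes conv.dilation to the estimate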