Commit c04f261a authored by dongchy920's avatar dongchy920
Browse files

InstruceBLIP

parents
Pipeline #1594 canceled with stages
import os.path as osp
from .builder import DATASETS
from .custom import CustomDataset
@DATASETS.register_module()
class PascalVOCDataset(CustomDataset):
    """Pascal VOC segmentation dataset.

    Images are looked up with a ``.jpg`` suffix and segmentation maps with a
    ``.png`` suffix; both suffixes are fixed by this class.

    Args:
        split (str): Split txt file for Pascal VOC.
        **kwargs: Forwarded unchanged to ``CustomDataset.__init__``.
    """

    # 21 labels: background plus the 20 object categories.
    CLASSES = (
        'background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle',
        'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog',
        'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa',
        'train', 'tvmonitor',
    )

    # One RGB triple per entry of CLASSES, in the same order.
    PALETTE = [
        [0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0], [0, 0, 128],
        [128, 0, 128], [0, 128, 128], [128, 128, 128], [64, 0, 0],
        [192, 0, 0], [64, 128, 0], [192, 128, 0], [64, 0, 128],
        [192, 0, 128], [64, 128, 128], [192, 128, 128], [0, 64, 0],
        [128, 64, 0], [0, 192, 0], [128, 192, 0], [0, 64, 128],
    ]

    def __init__(self, split, **kwargs):
        super().__init__(
            split=split, img_suffix='.jpg', seg_map_suffix='.png', **kwargs)
        # Fail early when the image directory is missing or no split is set.
        assert osp.exists(self.img_dir) and self.split is not None
from .backbones import * # noqa: F401,F403
from .builder import (BACKBONES, HEADS, LOSSES, SEGMENTORS, build_backbone,
build_head, build_loss, build_segmentor)
from .decode_heads import * # noqa: F401,F403
from .losses import * # noqa: F401,F403
from .necks import * # noqa: F401,F403
from .segmentors import * # noqa: F401,F403
# Public names of this package: the registries and builder helpers imported
# from .builder above. The star-imported submodules are not re-exported here.
__all__ = [
'BACKBONES', 'HEADS', 'LOSSES', 'SEGMENTORS', 'build_backbone',
'build_head', 'build_loss', 'build_segmentor'
]
from .cgnet import CGNet
# from .fast_scnn import FastSCNN
from .hrnet import HRNet
from .mobilenet_v2 import MobileNetV2
from .mobilenet_v3 import MobileNetV3
from .resnest import ResNeSt
from .resnet import ResNet, ResNetV1c, ResNetV1d
from .resnext import ResNeXt
from .unet import UNet
from .vit import VisionTransformer
from .uniformer import UniFormer
# Backbone classes exported by this package. FastSCNN is absent because its
# import above is commented out.
__all__ = [
'ResNet', 'ResNetV1c', 'ResNetV1d', 'ResNeXt', 'HRNet',
'ResNeSt', 'MobileNetV2', 'UNet', 'CGNet', 'MobileNetV3',
'VisionTransformer', 'UniFormer'
]
import torch
import torch.nn as nn
import torch.utils.checkpoint as cp
from annotator.uniformer.mmcv.cnn import (ConvModule, build_conv_layer, build_norm_layer,
constant_init, kaiming_init)
from annotator.uniformer.mmcv.runner import load_checkpoint
from annotator.uniformer.mmcv.utils.parrots_wrapper import _BatchNorm
from annotator.uniformer.mmseg.utils import get_root_logger
from ..builder import BACKBONES
class GlobalContextExtractor(nn.Module):
    """Channel-wise gating module used by CGNet blocks.

    The input is average-pooled to one value per channel, passed through a
    two-layer bottleneck MLP ending in a sigmoid, and the resulting
    per-channel weights are multiplied back onto the input.

    Args:
        channel (int): Number of input feature channels.
        reduction (int): Bottleneck ratio of the MLP. Must satisfy
            ``1 <= reduction <= channel``. Default: 16.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save
            some memory while slowing down the training speed.
            Default: False.
    """

    def __init__(self, channel, reduction=16, with_cp=False):
        super().__init__()
        self.channel = channel
        self.reduction = reduction
        assert reduction >= 1 and channel >= reduction
        self.with_cp = with_cp

        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        hidden = channel // reduction
        self.fc = nn.Sequential(
            nn.Linear(channel, hidden),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, channel),
            nn.Sigmoid(),
        )

    def _reweight(self, feat):
        """Scale ``feat`` by its learned per-channel gate."""
        batch, chans = feat.size()[:2]
        pooled = self.avg_pool(feat).view(batch, chans)
        gate = self.fc(pooled).view(batch, chans, 1, 1)
        return feat * gate

    def forward(self, x):
        """Return ``x`` multiplied by its per-channel gate."""
        # Checkpointing only pays off when gradients flow through x.
        if self.with_cp and x.requires_grad:
            return cp.checkpoint(self._reweight, x)
        return self._reweight(x)
class ContextGuidedBlock(nn.Module):
    """Context Guided Block for CGNet.

    The block is built from four parts: a local feature extractor
    (``f_loc``), a surrounding-context extractor (``f_sur``, a dilated
    depthwise conv), a joint feature stage (concat + norm + PReLU) and a
    global context extractor (``f_glo``) that refines the joint feature.

    Args:
        in_channels (int): Number of input feature channels.
        out_channels (int): Number of output feature channels.
        dilation (int): Dilation rate for surrounding context extractor.
            Default: 2.
        reduction (int): Reduction for global context extractor. Default: 16.
        skip_connect (bool): Add input to output or not. Ignored (treated as
            False) when ``downsample`` is True. Default: True.
        downsample (bool): Downsample the input to 1/2 or not. Default: False.
        conv_cfg (dict): Config dict for convolution layer.
            Default: None, which means using conv2d.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN', requires_grad=True).
        act_cfg (dict): Config dict for activation layer. The dict is never
            modified; a private copy receives ``num_parameters``.
            Default: dict(type='PReLU').
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 dilation=2,
                 reduction=16,
                 skip_connect=True,
                 downsample=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN', requires_grad=True),
                 act_cfg=dict(type='PReLU'),
                 with_cp=False):
        super(ContextGuidedBlock, self).__init__()
        self.with_cp = with_cp
        self.downsample = downsample

        # Without downsampling the two branches each carry half of the
        # output channels so that their concatenation has out_channels.
        channels = out_channels if downsample else out_channels // 2
        if act_cfg is not None and act_cfg.get('type') == 'PReLU':
            # Copy before writing: act_cfg may be the shared default dict or
            # a dict owned by the caller, neither of which may be mutated.
            act_cfg = act_cfg.copy()
            act_cfg['num_parameters'] = channels
        kernel_size = 3 if downsample else 1
        stride = 2 if downsample else 1
        padding = (kernel_size - 1) // 2

        self.conv1x1 = ConvModule(
            in_channels,
            channels,
            kernel_size,
            stride,
            padding,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)

        # Depthwise 3x3 conv: local features.
        self.f_loc = build_conv_layer(
            conv_cfg,
            channels,
            channels,
            kernel_size=3,
            padding=1,
            groups=channels,
            bias=False)
        # Depthwise dilated 3x3 conv: surrounding context.
        self.f_sur = build_conv_layer(
            conv_cfg,
            channels,
            channels,
            kernel_size=3,
            padding=dilation,
            groups=channels,
            dilation=dilation,
            bias=False)

        self.bn = build_norm_layer(norm_cfg, 2 * channels)[1]
        self.activate = nn.PReLU(2 * channels)

        if downsample:
            # Project the concatenated 2 * channels back to out_channels.
            self.bottleneck = build_conv_layer(
                conv_cfg,
                2 * channels,
                out_channels,
                kernel_size=1,
                bias=False)

        self.skip_connect = skip_connect and not downsample
        self.f_glo = GlobalContextExtractor(out_channels, reduction, with_cp)

    def forward(self, x):
        """Run the block, optionally under gradient checkpointing."""

        def _inner_forward(x):
            out = self.conv1x1(x)
            loc = self.f_loc(out)
            sur = self.f_sur(out)

            joi_feat = torch.cat([loc, sur], 1)  # the joint feature
            joi_feat = self.bn(joi_feat)
            joi_feat = self.activate(joi_feat)
            if self.downsample:
                joi_feat = self.bottleneck(joi_feat)  # channel = out_channels
            # f_glo is employed to refine the joint feature
            out = self.f_glo(joi_feat)

            if self.skip_connect:
                return x + out
            return out

        if self.with_cp and x.requires_grad:
            out = cp.checkpoint(_inner_forward, x)
        else:
            out = _inner_forward(x)

        return out
class InputInjection(nn.Module):
    """Downsampling module for CGNet.

    Applies ``num_downsampling`` successive 3x3 average poolings, each with
    stride 2 and padding 1.

    Args:
        num_downsampling (int): How many pooling stages to chain.
    """

    def __init__(self, num_downsampling):
        super().__init__()
        stages = [
            nn.AvgPool2d(3, stride=2, padding=1)
            for _ in range(num_downsampling)
        ]
        self.pool = nn.ModuleList(stages)

    def forward(self, x):
        """Pass ``x`` through every pooling stage in order."""
        out = x
        for stage in self.pool:
            out = stage(out)
        return out
@BACKBONES.register_module()
class CGNet(nn.Module):
    """CGNet backbone.

    A Light-weight Context Guided Network for Semantic Segmentation
    arXiv: https://arxiv.org/abs/1811.08201

    Args:
        in_channels (int): Number of input image channels. Normally 3.
        num_channels (tuple[int]): Numbers of feature channels at each stages.
            Default: (32, 64, 128).
        num_blocks (tuple[int]): Numbers of CG blocks at stage 1 and stage 2.
            Default: (3, 21).
        dilations (tuple[int]): Dilation rate for surrounding context
            extractors at stage 1 and stage 2. Default: (2, 4).
        reductions (tuple[int]): Reductions for global context extractors at
            stage 1 and stage 2. Default: (8, 16).
        conv_cfg (dict): Config dict for convolution layer.
            Default: None, which means using conv2d.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN', requires_grad=True).
        act_cfg (dict): Config dict for activation layer. The passed dict is
            copied; the caller's object is never modified.
            Default: dict(type='PReLU').
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Default: False.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.
    """

    def __init__(self,
                 in_channels=3,
                 num_channels=(32, 64, 128),
                 num_blocks=(3, 21),
                 dilations=(2, 4),
                 reductions=(8, 16),
                 conv_cfg=None,
                 norm_cfg=dict(type='BN', requires_grad=True),
                 act_cfg=dict(type='PReLU'),
                 norm_eval=False,
                 with_cp=False):
        super(CGNet, self).__init__()
        self.in_channels = in_channels
        self.num_channels = num_channels
        assert isinstance(self.num_channels, tuple) and len(
            self.num_channels) == 3
        self.num_blocks = num_blocks
        assert isinstance(self.num_blocks, tuple) and len(self.num_blocks) == 2
        self.dilations = dilations
        assert isinstance(self.dilations, tuple) and len(self.dilations) == 2
        self.reductions = reductions
        assert isinstance(self.reductions, tuple) and len(self.reductions) == 2
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        # Keep a private copy: 'num_parameters' is written below, and the
        # argument may be the shared default dict or a dict the caller reuses.
        if act_cfg is not None:
            act_cfg = act_cfg.copy()
        self.act_cfg = act_cfg
        if act_cfg is not None and act_cfg.get('type') == 'PReLU':
            act_cfg['num_parameters'] = num_channels[0]
        self.norm_eval = norm_eval
        self.with_cp = with_cp

        # Stem: three 3x3 convs, only the first one strided.
        cur_channels = in_channels
        self.stem = nn.ModuleList()
        for i in range(3):
            self.stem.append(
                ConvModule(
                    cur_channels,
                    num_channels[0],
                    3,
                    2 if i == 0 else 1,
                    padding=1,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg))
            cur_channels = num_channels[0]

        self.inject_2x = InputInjection(1)  # down-sample for Input, factor=2
        self.inject_4x = InputInjection(2)  # down-sample for Input, factor=4

        # The stem output is concatenated with the 2x-downsampled input.
        cur_channels += in_channels
        self.norm_prelu_0 = nn.Sequential(
            build_norm_layer(norm_cfg, cur_channels)[1],
            nn.PReLU(cur_channels))

        # stage 1
        self.level1 = nn.ModuleList()
        for i in range(num_blocks[0]):
            self.level1.append(
                ContextGuidedBlock(
                    cur_channels if i == 0 else num_channels[1],
                    num_channels[1],
                    dilations[0],
                    reductions[0],
                    downsample=(i == 0),
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg,
                    with_cp=with_cp))  # CG block

        # Last block output + first block output + 4x-downsampled input.
        cur_channels = 2 * num_channels[1] + in_channels
        self.norm_prelu_1 = nn.Sequential(
            build_norm_layer(norm_cfg, cur_channels)[1],
            nn.PReLU(cur_channels))

        # stage 2
        self.level2 = nn.ModuleList()
        for i in range(num_blocks[1]):
            self.level2.append(
                ContextGuidedBlock(
                    cur_channels if i == 0 else num_channels[2],
                    num_channels[2],
                    dilations[1],
                    reductions[1],
                    downsample=(i == 0),
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg,
                    with_cp=with_cp))  # CG block

        # First block output + last block output.
        cur_channels = 2 * num_channels[2]
        self.norm_prelu_2 = nn.Sequential(
            build_norm_layer(norm_cfg, cur_channels)[1],
            nn.PReLU(cur_channels))

    def forward(self, x):
        """Return the list of three stage outputs (stage 0, 1 and 2)."""
        output = []

        # stage 0
        inp_2x = self.inject_2x(x)
        inp_4x = self.inject_4x(x)
        for layer in self.stem:
            x = layer(x)
        x = self.norm_prelu_0(torch.cat([x, inp_2x], 1))
        output.append(x)

        # stage 1
        for i, layer in enumerate(self.level1):
            x = layer(x)
            if i == 0:
                down1 = x
        x = self.norm_prelu_1(torch.cat([x, down1, inp_4x], 1))
        output.append(x)

        # stage 2
        for i, layer in enumerate(self.level2):
            x = layer(x)
            if i == 0:
                down2 = x
        x = self.norm_prelu_2(torch.cat([down2, x], 1))
        output.append(x)

        return output

    def init_weights(self, pretrained=None):
        """Initialize the weights in backbone.

        Args:
            pretrained (str, optional): Path to pre-trained weights.
                Defaults to None.

        Raises:
            TypeError: If ``pretrained`` is neither a str nor None.
        """
        if isinstance(pretrained, str):
            logger = get_root_logger()
            load_checkpoint(self, pretrained, strict=False, logger=logger)
        elif pretrained is None:
            for m in self.modules():
                if isinstance(m, (nn.Conv2d, nn.Linear)):
                    kaiming_init(m)
                elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
                    constant_init(m, 1)
                elif isinstance(m, nn.PReLU):
                    constant_init(m, 0)
        else:
            raise TypeError('pretrained must be a str or None')

    def train(self, mode=True):
        """Switch to training mode while keeping the normalization layers
        frozen when ``norm_eval`` is set."""
        super(CGNet, self).train(mode)
        if mode and self.norm_eval:
            for m in self.modules():
                # trick: eval have effect on BatchNorm only
                if isinstance(m, _BatchNorm):
                    m.eval()
import torch
import torch.nn as nn
from annotator.uniformer.mmcv.cnn import (ConvModule, DepthwiseSeparableConvModule, constant_init,
kaiming_init)
from torch.nn.modules.batchnorm import _BatchNorm
from annotator.uniformer.mmseg.models.decode_heads.psp_head import PPM
from annotator.uniformer.mmseg.ops import resize
from ..builder import BACKBONES
from ..utils.inverted_residual import InvertedResidual
class LearningToDownsample(nn.Module):
    """Learning to downsample module.

    One strided 3x3 ``ConvModule`` followed by two strided depthwise
    separable convolutions; each of the three layers uses stride 2.

    Args:
        in_channels (int): Number of input channels.
        dw_channels (tuple[int]): Output channels of the first conv and of
            the first depthwise separable conv, in that order.
        out_channels (int): Number of output channels of the whole
            'learning to downsample' module.
        conv_cfg (dict | None): Config of conv layers. Default: None
        norm_cfg (dict | None): Config of norm layers. Default:
            dict(type='BN')
        act_cfg (dict): Config of activation layers. Default:
            dict(type='ReLU')
    """

    def __init__(self,
                 in_channels,
                 dw_channels,
                 out_channels,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU')):
        super().__init__()
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg

        first_width = dw_channels[0]
        second_width = dw_channels[1]

        # Note: this first conv is built without explicit padding.
        self.conv = ConvModule(
            in_channels,
            first_width,
            3,
            stride=2,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

        # Both separable convs share the same geometry and norm config.
        separable_kwargs = dict(
            kernel_size=3, stride=2, padding=1, norm_cfg=self.norm_cfg)
        self.dsconv1 = DepthwiseSeparableConvModule(
            first_width, second_width, **separable_kwargs)
        self.dsconv2 = DepthwiseSeparableConvModule(
            second_width, out_channels, **separable_kwargs)

    def forward(self, x):
        """Apply the three downsampling layers in sequence."""
        return self.dsconv2(self.dsconv1(self.conv(x)))
class GlobalFeatureExtractor(nn.Module):
    """Global feature extractor module.

    Three groups of ``InvertedResidual`` blocks, a pyramid pooling module
    (``PPM``) and a final 1x1 ``ConvModule`` that fuses the last group's
    output with the PPM outputs.

    Args:
        in_channels (int): Number of input channels of the GFE module.
            Default: 64
        block_channels (tuple[int]): Output channels of each of the three
            Inverted Residual groups. Default: (64, 96, 128)
        out_channels(int): Number of output channels of the GFE module.
            Default: 128
        expand_ratio (int): Adjusts number of channels of the hidden layer
            in InvertedResidual by this amount. Default: 6
        num_blocks (tuple[int]): How many Inverted Residual modules each
            group contains. Default: (3, 3, 3)
        strides (tuple[int]): Stride of the first module of each group.
            Default: (2, 2, 1)
        pool_scales (tuple[int]): Pooling scales handed to PPM.
            Default: (1, 2, 3, 6)
        conv_cfg (dict | None): Config of conv layers. Default: None
        norm_cfg (dict | None): Config of norm layers. Default:
            dict(type='BN')
        act_cfg (dict): Config of activation layers. Default:
            dict(type='ReLU')
        align_corners (bool): align_corners argument of F.interpolate.
            Default: False
    """

    def __init__(self,
                 in_channels=64,
                 block_channels=(64, 96, 128),
                 out_channels=128,
                 expand_ratio=6,
                 num_blocks=(3, 3, 3),
                 strides=(2, 2, 1),
                 pool_scales=(1, 2, 3, 6),
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 align_corners=False):
        super().__init__()
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        assert len(block_channels) == len(num_blocks) == 3

        self.bottleneck1 = self._make_layer(
            in_channels,
            block_channels[0],
            blocks=num_blocks[0],
            stride=strides[0],
            expand_ratio=expand_ratio)
        self.bottleneck2 = self._make_layer(
            block_channels[0],
            block_channels[1],
            blocks=num_blocks[1],
            stride=strides[1],
            expand_ratio=expand_ratio)
        self.bottleneck3 = self._make_layer(
            block_channels[1],
            block_channels[2],
            blocks=num_blocks[2],
            stride=strides[2],
            expand_ratio=expand_ratio)

        last_width = block_channels[2]
        layer_cfgs = dict(
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        self.ppm = PPM(
            pool_scales,
            last_width,
            last_width // 4,
            align_corners=align_corners,
            **layer_cfgs)
        # The fusion conv is sized for twice the last group's width.
        self.out = ConvModule(last_width * 2, out_channels, 1, **layer_cfgs)

    def _make_layer(self,
                    in_channels,
                    out_channels,
                    blocks,
                    stride=1,
                    expand_ratio=6):
        """Build one group: a possibly strided block, then stride-1 blocks."""
        group = [
            InvertedResidual(
                in_channels,
                out_channels,
                stride,
                expand_ratio,
                norm_cfg=self.norm_cfg)
        ]
        group.extend(
            InvertedResidual(
                out_channels,
                out_channels,
                1,
                expand_ratio,
                norm_cfg=self.norm_cfg) for _ in range(1, blocks))
        return nn.Sequential(*group)

    def forward(self, x):
        """Run the three groups, then fuse their output with the PPM maps."""
        feat = self.bottleneck3(self.bottleneck2(self.bottleneck1(x)))
        fused = torch.cat([feat, *self.ppm(feat)], dim=1)
        return self.out(fused)
class FeatureFusionModule(nn.Module):
    """Feature fusion module.

    The lower-resolution feature is resized to the higher-resolution
    feature's spatial size, both are projected to ``out_channels`` and the
    two are summed and passed through a ReLU.

    Args:
        higher_in_channels (int): Number of input channels of the
            higher-resolution branch.
        lower_in_channels (int): Number of input channels of the
            lower-resolution branch.
        out_channels (int): Number of output channels.
        conv_cfg (dict | None): Config of conv layers. Default: None
        norm_cfg (dict | None): Config of norm layers. Default:
            dict(type='BN')
        act_cfg (dict): Config of activation layers. Default:
            dict(type='ReLU')
        align_corners (bool): align_corners argument of F.interpolate.
            Default: False
    """

    def __init__(self,
                 higher_in_channels,
                 lower_in_channels,
                 out_channels,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 align_corners=False):
        super().__init__()
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.align_corners = align_corners

        conv_and_norm = dict(conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg)
        # Despite its name this is built as a 1x1 ConvModule with activation.
        self.dwconv = ConvModule(
            lower_in_channels,
            out_channels,
            1,
            act_cfg=self.act_cfg,
            **conv_and_norm)
        # The two projections carry no activation; ReLU follows the sum.
        self.conv_lower_res = ConvModule(
            out_channels, out_channels, 1, act_cfg=None, **conv_and_norm)
        self.conv_higher_res = ConvModule(
            higher_in_channels, out_channels, 1, act_cfg=None,
            **conv_and_norm)
        self.relu = nn.ReLU(True)

    def forward(self, higher_res_feature, lower_res_feature):
        """Fuse the two branches into one ``out_channels`` feature map."""
        target_size = higher_res_feature.size()[2:]
        upsampled = resize(
            lower_res_feature,
            size=target_size,
            mode='bilinear',
            align_corners=self.align_corners)
        low = self.conv_lower_res(self.dwconv(upsampled))
        high = self.conv_higher_res(higher_res_feature)
        return self.relu(high + low)
@BACKBONES.register_module()
class FastSCNN(nn.Module):
    """Fast-SCNN Backbone.

    Chains a Learning-To-Downsample (LTD) module, a Global Feature Extractor
    (GFE) and a Feature Fusion Module (FFM).

    Args:
        in_channels (int): Number of input image channels. Default: 3.
        downsample_dw_channels (tuple[int]): Output channels of the first
            and second conv layers in the LTD module. Default: (32, 48).
        global_in_channels (int): Input channels of the GFE; also used as
            the LTD output channels. Default: 64.
        global_block_channels (tuple[int]): Output channels of each
            bottleneck group in the GFE. Default: (64, 96, 128).
        global_block_strides (tuple[int]): Strides (downsampling factors)
            of each bottleneck group in the GFE. Default: (2, 2, 1).
        global_out_channels (int): Number of output channels of GFE.
            Default: 128.
        higher_in_channels (int): Input channels of the higher resolution
            branch in FFM. Must equal global_in_channels. Default: 64.
        lower_in_channels (int): Input channels of the lower resolution
            branch in FFM. Must equal global_out_channels. Default: 128.
        fusion_out_channels (int): Number of output channels of FFM.
            Default: 128.
        out_indices (tuple): Indices into
            [higher_res_features, lower_res_features, fusion_output]
            selecting what ``forward`` returns. Default: (0, 1, 2).
        conv_cfg (dict | None): Config of conv layers. Default: None
        norm_cfg (dict | None): Config of norm layers. Default:
            dict(type='BN')
        act_cfg (dict): Config of activation layers. Default:
            dict(type='ReLU')
        align_corners (bool): align_corners argument of F.interpolate.
            Default: False

    Raises:
        AssertionError: If the GFE and FFM channel counts do not line up.
    """

    def __init__(self,
                 in_channels=3,
                 downsample_dw_channels=(32, 48),
                 global_in_channels=64,
                 global_block_channels=(64, 96, 128),
                 global_block_strides=(2, 2, 1),
                 global_out_channels=128,
                 higher_in_channels=64,
                 lower_in_channels=128,
                 fusion_out_channels=128,
                 out_indices=(0, 1, 2),
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 align_corners=False):
        super().__init__()

        # The FFM consumes the LTD output and the GFE output directly, so
        # their channel counts have to agree.
        if global_in_channels != higher_in_channels:
            raise AssertionError('Global Input Channels must be the same '
                                 'with Higher Input Channels!')
        if global_out_channels != lower_in_channels:
            raise AssertionError('Global Output Channels must be the same '
                                 'with Lower Input Channels!')

        self.in_channels = in_channels
        self.downsample_dw_channels1 = downsample_dw_channels[0]
        self.downsample_dw_channels2 = downsample_dw_channels[1]
        self.global_in_channels = global_in_channels
        self.global_block_channels = global_block_channels
        self.global_block_strides = global_block_strides
        self.global_out_channels = global_out_channels
        self.higher_in_channels = higher_in_channels
        self.lower_in_channels = lower_in_channels
        self.fusion_out_channels = fusion_out_channels
        self.out_indices = out_indices
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.align_corners = align_corners

        layer_cfgs = dict(
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        self.learning_to_downsample = LearningToDownsample(
            in_channels,
            downsample_dw_channels,
            global_in_channels,
            **layer_cfgs)
        self.global_feature_extractor = GlobalFeatureExtractor(
            global_in_channels,
            global_block_channels,
            global_out_channels,
            strides=self.global_block_strides,
            align_corners=self.align_corners,
            **layer_cfgs)
        self.feature_fusion = FeatureFusionModule(
            higher_in_channels,
            lower_in_channels,
            fusion_out_channels,
            align_corners=self.align_corners,
            **layer_cfgs)

    def init_weights(self, pretrained=None):
        """Initialise conv and norm layers.

        ``pretrained`` is accepted for interface compatibility but is not
        read by this method.
        """
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                kaiming_init(module)
            elif isinstance(module, (_BatchNorm, nn.GroupNorm)):
                constant_init(module, 1)

    def forward(self, x):
        """Return the feature maps selected by ``out_indices`` as a tuple."""
        higher_res_features = self.learning_to_downsample(x)
        lower_res_features = self.global_feature_extractor(higher_res_features)
        fusion_output = self.feature_fusion(higher_res_features,
                                            lower_res_features)

        candidates = [higher_res_features, lower_res_features, fusion_output]
        return tuple(candidates[i] for i in self.out_indices)
import torch.nn as nn
from annotator.uniformer.mmcv.cnn import (build_conv_layer, build_norm_layer, constant_init,
kaiming_init)
from annotator.uniformer.mmcv.runner import load_checkpoint
from annotator.uniformer.mmcv.utils.parrots_wrapper import _BatchNorm
from annotator.uniformer.mmseg.ops import Upsample, resize
from annotator.uniformer.mmseg.utils import get_root_logger
from ..builder import BACKBONES
from .resnet import BasicBlock, Bottleneck
class HRModule(nn.Module):
    """High-Resolution Module for HRNet.

    Holds one stack of residual blocks per branch plus the fuse layers that
    exchange information between the branches.

    Args:
        num_branches (int): Number of parallel branches.
        blocks (type): Residual block class used to build every branch.
        num_blocks (Sequence[int]): Number of blocks per branch.
        in_channels (list[int]): Input channels per branch. Updated in place
            to the branch output channels while the branches are built.
        num_channels (Sequence[int]): Base channel count per branch.
        multiscale_output (bool): Produce one output per branch when True,
            otherwise only the first output. Default: True.
        with_cp (bool): Forwarded to every block. Default: False.
        conv_cfg (dict | None): Config of conv layers. Default: None.
        norm_cfg (dict): Config of norm layers.
            Default: dict(type='BN', requires_grad=True).
    """

    def __init__(self,
                 num_branches,
                 blocks,
                 num_blocks,
                 in_channels,
                 num_channels,
                 multiscale_output=True,
                 with_cp=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN', requires_grad=True)):
        super().__init__()
        self._check_branches(num_branches, num_blocks, in_channels,
                             num_channels)

        self.in_channels = in_channels
        self.num_branches = num_branches

        self.multiscale_output = multiscale_output
        self.norm_cfg = norm_cfg
        self.conv_cfg = conv_cfg
        self.with_cp = with_cp
        self.branches = self._make_branches(num_branches, blocks, num_blocks,
                                            num_channels)
        self.fuse_layers = self._make_fuse_layers()
        self.relu = nn.ReLU(inplace=False)

    def _check_branches(self, num_branches, num_blocks, in_channels,
                        num_channels):
        """Check branches configuration.

        Raises:
            ValueError: If any per-branch sequence has a length different
                from ``num_branches``.
        """
        per_branch = (
            ('NUM_BLOCKS', num_blocks),
            ('NUM_CHANNELS', num_channels),
            ('NUM_INCHANNELS', in_channels),
        )
        for label, sizes in per_branch:
            if num_branches != len(sizes):
                raise ValueError(
                    f'NUM_BRANCHES({num_branches}) <> {label}({len(sizes)})')

    def _make_one_branch(self,
                         branch_index,
                         block,
                         num_blocks,
                         num_channels,
                         stride=1):
        """Build one branch as a sequence of ``block`` instances."""
        planes = num_channels[branch_index]
        expanded = planes * block.expansion

        # A projection shortcut is needed whenever the first block changes
        # the resolution or the channel count.
        shortcut = None
        if stride != 1 or self.in_channels[branch_index] != expanded:
            shortcut = nn.Sequential(
                build_conv_layer(
                    self.conv_cfg,
                    self.in_channels[branch_index],
                    expanded,
                    kernel_size=1,
                    stride=stride,
                    bias=False),
                build_norm_layer(self.norm_cfg, expanded)[1])

        shared = dict(
            with_cp=self.with_cp,
            norm_cfg=self.norm_cfg,
            conv_cfg=self.conv_cfg)
        layers = [
            block(
                self.in_channels[branch_index],
                planes,
                stride,
                downsample=shortcut,
                **shared)
        ]
        # Record the branch's new width; the remaining blocks read it.
        self.in_channels[branch_index] = expanded
        layers.extend(
            block(self.in_channels[branch_index], planes, **shared)
            for _ in range(1, num_blocks[branch_index]))

        return nn.Sequential(*layers)

    def _make_branches(self, num_branches, block, num_blocks, num_channels):
        """Build multiple branch."""
        return nn.ModuleList([
            self._make_one_branch(index, block, num_blocks, num_channels)
            for index in range(num_branches)
        ])

    def _make_fuse_layers(self):
        """Build fuse layer.

        Returns None for a single branch. Otherwise entry ``[dst][src]``
        converts branch ``src`` to the shape of branch ``dst`` (None on the
        diagonal).
        """
        if self.num_branches == 1:
            return None

        widths = self.in_channels
        n_in = self.num_branches
        n_out = n_in if self.multiscale_output else 1

        rows = []
        for dst in range(n_out):
            row = []
            for src in range(n_in):
                if src == dst:
                    row.append(None)
                elif src > dst:
                    # Higher-index source: 1x1 conv then bilinear upsample.
                    row.append(
                        nn.Sequential(
                            build_conv_layer(
                                self.conv_cfg,
                                widths[src],
                                widths[dst],
                                kernel_size=1,
                                stride=1,
                                padding=0,
                                bias=False),
                            build_norm_layer(self.norm_cfg, widths[dst])[1],
                            # we set align_corners=False for HRNet
                            Upsample(
                                scale_factor=2**(src - dst),
                                mode='bilinear',
                                align_corners=False)))
                else:
                    # Lower-index source: a chain of strided 3x3 convs. Only
                    # the last one changes the width and it has no ReLU.
                    steps = dst - src
                    chain = []
                    for step in range(steps):
                        is_last = step == steps - 1
                        out_width = widths[dst] if is_last else widths[src]
                        stage = [
                            build_conv_layer(
                                self.conv_cfg,
                                widths[src],
                                out_width,
                                kernel_size=3,
                                stride=2,
                                padding=1,
                                bias=False),
                            build_norm_layer(self.norm_cfg, out_width)[1],
                        ]
                        if not is_last:
                            stage.append(nn.ReLU(inplace=False))
                        chain.append(nn.Sequential(*stage))
                    row.append(nn.Sequential(*chain))
            rows.append(nn.ModuleList(row))

        return nn.ModuleList(rows)

    def forward(self, x):
        """Forward function.

        ``x`` is a list with one tensor per branch; its entries are replaced
        in place by the branch outputs before fusion.
        """
        if self.num_branches == 1:
            return [self.branches[0](x[0])]

        for index in range(self.num_branches):
            x[index] = self.branches[index](x[index])

        x_fuse = []
        for dst, row in enumerate(self.fuse_layers):
            y = 0
            for src in range(self.num_branches):
                if src == dst:
                    y += x[src]
                elif src > dst:
                    y = y + resize(
                        row[src](x[src]),
                        size=x[dst].shape[2:],
                        mode='bilinear',
                        align_corners=False)
                else:
                    y += row[src](x[src])
            x_fuse.append(self.relu(y))
        return x_fuse
@BACKBONES.register_module()
class HRNet(nn.Module):
"""HRNet backbone.
High-Resolution Representations for Labeling Pixels and Regions
arXiv: https://arxiv.org/abs/1904.04514
Args:
extra (dict): detailed configuration for each stage of HRNet.
in_channels (int): Number of input image channels. Normally 3.
conv_cfg (dict): dictionary to construct and config conv layer.
norm_cfg (dict): dictionary to construct and config norm layer.
norm_eval (bool): Whether to set norm layers to eval mode, namely,
freeze running stats (mean and var). Note: Effect on Batch Norm
and its variants only.
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed.
zero_init_residual (bool): whether to use zero init for last norm layer
in resblocks to let them behave as identity.
Example:
>>> from annotator.uniformer.mmseg.models import HRNet
>>> import torch
>>> extra = dict(
>>> stage1=dict(
>>> num_modules=1,
>>> num_branches=1,
>>> block='BOTTLENECK',
>>> num_blocks=(4, ),
>>> num_channels=(64, )),
>>> stage2=dict(
>>> num_modules=1,
>>> num_branches=2,
>>> block='BASIC',
>>> num_blocks=(4, 4),
>>> num_channels=(32, 64)),
>>> stage3=dict(
>>> num_modules=4,
>>> num_branches=3,
>>> block='BASIC',
>>> num_blocks=(4, 4, 4),
>>> num_channels=(32, 64, 128)),
>>> stage4=dict(
>>> num_modules=3,
>>> num_branches=4,
>>> block='BASIC',
>>> num_blocks=(4, 4, 4, 4),
>>> num_channels=(32, 64, 128, 256)))
>>> self = HRNet(extra, in_channels=1)
>>> self.eval()
>>> inputs = torch.rand(1, 1, 32, 32)
>>> level_outputs = self.forward(inputs)
>>> for level_out in level_outputs:
... print(tuple(level_out.shape))
(1, 32, 8, 8)
(1, 64, 4, 4)
(1, 128, 2, 2)
(1, 256, 1, 1)
"""
blocks_dict = {'BASIC': BasicBlock, 'BOTTLENECK': Bottleneck}
def __init__(self,
             extra,
             in_channels=3,
             conv_cfg=None,
             norm_cfg=dict(type='BN', requires_grad=True),
             norm_eval=False,
             with_cp=False,
             zero_init_residual=False):
    """Build the stem, stage 1 and the transition/stage pairs 2 to 4.

    ``extra`` must provide the keys 'stage1' to 'stage4'; each stage dict is
    read for 'num_channels' and 'block' (stage 1 also for 'num_blocks').
    """
    super(HRNet, self).__init__()
    self.extra = extra
    self.conv_cfg = conv_cfg
    self.norm_cfg = norm_cfg
    self.norm_eval = norm_eval
    self.with_cp = with_cp
    self.zero_init_residual = zero_init_residual

    # stem net: two strided 3x3 convs, each followed by its own norm layer
    self.norm1_name, norm1 = build_norm_layer(self.norm_cfg, 64, postfix=1)
    self.norm2_name, norm2 = build_norm_layer(self.norm_cfg, 64, postfix=2)

    stem_conv = dict(kernel_size=3, stride=2, padding=1, bias=False)
    self.conv1 = build_conv_layer(
        self.conv_cfg, in_channels, 64, **stem_conv)
    self.add_module(self.norm1_name, norm1)
    self.conv2 = build_conv_layer(self.conv_cfg, 64, 64, **stem_conv)
    self.add_module(self.norm2_name, norm2)
    self.relu = nn.ReLU(inplace=True)

    # stage 1: a single branch built directly from residual blocks
    self.stage1_cfg = self.extra['stage1']
    stage1_width = self.stage1_cfg['num_channels'][0]
    stage1_block_type = self.stage1_cfg['block']
    stage1_depth = self.stage1_cfg['num_blocks'][0]

    stage1_block = self.blocks_dict[stage1_block_type]
    stage1_out_channels = stage1_width * stage1_block.expansion
    self.layer1 = self._make_layer(stage1_block, 64, stage1_width,
                                   stage1_depth)

    # stages 2-4: a transition from the previous stage's channels, then the
    # stage itself. Attributes are registered in the order transition1,
    # stage2, transition2, stage3, transition3, stage4.
    pre_stage_channels = [stage1_out_channels]
    for stage_idx in (2, 3, 4):
        stage_cfg = self.extra[f'stage{stage_idx}']
        setattr(self, f'stage{stage_idx}_cfg', stage_cfg)
        base_widths = stage_cfg['num_channels']
        block = self.blocks_dict[stage_cfg['block']]
        widths = [width * block.expansion for width in base_widths]

        transition = self._make_transition_layer(pre_stage_channels, widths)
        setattr(self, f'transition{stage_idx - 1}', transition)
        stage, pre_stage_channels = self._make_stage(stage_cfg, widths)
        setattr(self, f'stage{stage_idx}', stage)
@property
def norm1(self):
    """nn.Module: the normalization layer registered under ``norm1_name``."""
    layer_name = self.norm1_name
    return getattr(self, layer_name)
@property
def norm2(self):
    """nn.Module: the normalization layer registered under ``norm2_name``."""
    layer_name = self.norm2_name
    return getattr(self, layer_name)
def _make_transition_layer(self, num_channels_pre_layer,
                           num_channels_cur_layer):
    """Build the per-branch adapters between two consecutive stages.

    Args:
        num_channels_pre_layer (list[int]): Channel width of every branch
            produced by the previous stage.
        num_channels_cur_layer (list[int]): Channel width of every branch
            expected by the current stage.

    Returns:
        nn.ModuleList: One entry per current branch. An entry is ``None``
            when an existing branch already has the right width, a
            3x3 conv-norm-ReLU when only the width differs, or a chain of
            stride-2 conv-norm-ReLU units for a newly created branch.
    """
    pre_count = len(num_channels_pre_layer)
    adapters = []
    for branch, cur_width in enumerate(num_channels_cur_layer):
        if branch < pre_count:
            pre_width = num_channels_pre_layer[branch]
            if cur_width == pre_width:
                # Same width: the branch is passed through untouched.
                adapters.append(None)
                continue
            adapters.append(
                nn.Sequential(
                    build_conv_layer(
                        self.conv_cfg,
                        pre_width,
                        cur_width,
                        kernel_size=3,
                        stride=1,
                        padding=1,
                        bias=False),
                    build_norm_layer(self.norm_cfg, cur_width)[1],
                    nn.ReLU(inplace=True)))
            continue

        # New branch: derived from the last previous branch through
        # (branch - pre_count + 1) stride-2 convolutions; only the final
        # one switches to the target width.
        final_step = branch - pre_count
        source_width = num_channels_pre_layer[-1]
        reducers = []
        for step in range(final_step + 1):
            target_width = cur_width if step == final_step else source_width
            reducers.append(
                nn.Sequential(
                    build_conv_layer(
                        self.conv_cfg,
                        source_width,
                        target_width,
                        kernel_size=3,
                        stride=2,
                        padding=1,
                        bias=False),
                    build_norm_layer(self.norm_cfg, target_width)[1],
                    nn.ReLU(inplace=True)))
        adapters.append(nn.Sequential(*reducers))
    return nn.ModuleList(adapters)
def _make_layer(self, block, inplanes, planes, blocks, stride=1):
"""Make each layer."""
downsample = None
if stride != 1 or inplanes != planes * block.expansion:
downsample = nn.Sequential(
build_conv_layer(
self.conv_cfg,
inplanes,
planes * block.expansion,
kernel_size=1,
stride=stride,
bias=False),
build_norm_layer(self.norm_cfg, planes * block.expansion)[1])
layers = []
layers.append(
block(
inplanes,
planes,
stride,
downsample=downsample,
with_cp=self.with_cp,
norm_cfg=self.norm_cfg,
conv_cfg=self.conv_cfg))
inplanes = planes * block.expansion
for i in range(1, blocks):
layers.append(
block(
inplanes,
planes,
with_cp=self.with_cp,
norm_cfg=self.norm_cfg,
conv_cfg=self.conv_cfg))
return nn.Sequential(*layers)
def _make_stage(self, layer_config, in_channels, multiscale_output=True):
    """Build one stage as a sequence of ``HRModule`` instances.

    Args:
        layer_config (dict): Stage config providing ``num_modules``,
            ``num_branches``, ``num_blocks``, ``num_channels`` and
            ``block``.
        in_channels (list[int]): Input channels of each branch.
        multiscale_output (bool): When False, only the last module of the
            stage is built with multi-scale output disabled.
            Default: True.

    Returns:
        tuple: ``(nn.Sequential of the modules, in_channels)``; the channel
            list is returned as received.
    """
    num_modules = layer_config['num_modules']
    num_branches = layer_config['num_branches']
    num_blocks = layer_config['num_blocks']
    num_channels = layer_config['num_channels']
    block = self.blocks_dict[layer_config['block']]

    last_index = num_modules - 1
    # multiscale_output only influences the last module; every earlier
    # module always keeps multi-scale output enabled.
    hr_modules = [
        HRModule(
            num_branches,
            block,
            num_blocks,
            in_channels,
            num_channels,
            bool(multiscale_output) or index != last_index,
            with_cp=self.with_cp,
            norm_cfg=self.norm_cfg,
            conv_cfg=self.conv_cfg) for index in range(num_modules)
    ]
    return nn.Sequential(*hr_modules), in_channels
def init_weights(self, pretrained=None):
    """Initialize the weights in backbone.

    Args:
        pretrained (str, optional): Path to pre-trained weights. When
            None, conv layers get Kaiming init and norm layers are set to
            1. Defaults to None.

    Raises:
        TypeError: If ``pretrained`` is neither a str nor None.
    """
    if pretrained is None:
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                kaiming_init(module)
            elif isinstance(module, (_BatchNorm, nn.GroupNorm)):
                constant_init(module, 1)

        if self.zero_init_residual:
            # Zero the last norm of each residual block.
            for module in self.modules():
                if isinstance(module, Bottleneck):
                    constant_init(module.norm3, 0)
                elif isinstance(module, BasicBlock):
                    constant_init(module.norm2, 0)
        return

    if not isinstance(pretrained, str):
        raise TypeError('pretrained must be a str or None')

    logger = get_root_logger()
    load_checkpoint(self, pretrained, strict=False, logger=logger)
def forward(self, x):
    """Forward function.

    Runs the two-conv stem and ``layer1``, then stages 2-4. Before each
    stage the branch inputs are assembled from the matching transition
    list: a ``None`` transition forwards the existing branch, any other
    transition is applied to ``x`` (stage 2) or to the last output of the
    previous stage (stages 3 and 4).

    Returns:
        The output of ``self.stage4``.
    """
    # Stem: conv-norm-relu twice, followed by the first residual layer.
    x = self.relu(self.norm1(self.conv1(x)))
    x = self.relu(self.norm2(self.conv2(x)))
    x = self.layer1(x)

    stage2_inputs = [
        x if self.transition1[i] is None else self.transition1[i](x)
        for i in range(self.stage2_cfg['num_branches'])
    ]
    y_list = self.stage2(stage2_inputs)

    stage3_inputs = [
        y_list[i] if self.transition2[i] is None else self.transition2[i](
            y_list[-1]) for i in range(self.stage3_cfg['num_branches'])
    ]
    y_list = self.stage3(stage3_inputs)

    stage4_inputs = [
        y_list[i] if self.transition3[i] is None else self.transition3[i](
            y_list[-1]) for i in range(self.stage4_cfg['num_branches'])
    ]
    y_list = self.stage4(stage4_inputs)

    return y_list
def train(self, mode=True):
    """Switch the model mode while keeping the normalization layers frozen.

    When ``mode`` is truthy and ``self.norm_eval`` is set, every
    ``_BatchNorm`` submodule is put back into eval mode after the regular
    mode switch.
    """
    super(HRNet, self).train(mode)
    if not (mode and self.norm_eval):
        return
    for module in self.modules():
        # trick: eval have effect on BatchNorm only
        if isinstance(module, _BatchNorm):
            module.eval()
import logging
import torch.nn as nn
from annotator.uniformer.mmcv.cnn import ConvModule, constant_init, kaiming_init
from annotator.uniformer.mmcv.runner import load_checkpoint
from torch.nn.modules.batchnorm import _BatchNorm
from ..builder import BACKBONES
from ..utils import InvertedResidual, make_divisible
@BACKBONES.register_module()
class MobileNetV2(nn.Module):
    """MobileNetV2 backbone.

    Args:
        widen_factor (float): Width multiplier, multiply number of
            channels in each layer by this amount. Default: 1.0.
        strides (Sequence[int], optional): Strides of the first block of each
            layer. Default: (1, 2, 2, 2, 1, 2, 1).
        dilations (Sequence[int]): Dilation of each layer.
            Default: (1, 1, 1, 1, 1, 1, 1).
        out_indices (None or Sequence[int]): Output from which stages.
            Every index must be in range(0, 7). Default: (1, 2, 4, 6).
        frozen_stages (int): Stages to be frozen (all param fixed).
            Default: -1, which means not freezing any parameters.
        conv_cfg (dict): Config dict for convolution layer.
            Default: None, which means using conv2d.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='ReLU6').
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Default: False.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.

    Raises:
        ValueError: If an item of ``out_indices`` is outside range(0, 7) or
            ``frozen_stages`` is outside range(-1, 7).
    """

    # Parameters to build layers. 3 parameters are needed to construct a
    # layer, from left to right: expand_ratio, channel, num_blocks.
    arch_settings = [[1, 16, 1], [6, 24, 2], [6, 32, 3], [6, 64, 4],
                     [6, 96, 3], [6, 160, 3], [6, 320, 1]]

    def __init__(self,
                 widen_factor=1.,
                 strides=(1, 2, 2, 2, 1, 2, 1),
                 dilations=(1, 1, 1, 1, 1, 1, 1),
                 out_indices=(1, 2, 4, 6),
                 frozen_stages=-1,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU6'),
                 norm_eval=False,
                 with_cp=False):
        super(MobileNetV2, self).__init__()
        self.widen_factor = widen_factor
        self.strides = strides
        self.dilations = dilations
        assert len(strides) == len(dilations) == len(self.arch_settings)
        for index in out_indices:
            if index not in range(0, 7):
                # The message now states the range that is actually checked.
                raise ValueError('the item in out_indices must in '
                                 f'range(0, 7). But received {index}')

        if frozen_stages not in range(-1, 7):
            raise ValueError('frozen_stages must be in range(-1, 7). '
                             f'But received {frozen_stages}')
        self.out_indices = out_indices
        self.frozen_stages = frozen_stages
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.norm_eval = norm_eval
        self.with_cp = with_cp

        # Running input width for the next layer; updated by make_layer().
        self.in_channels = make_divisible(32 * widen_factor, 8)

        self.conv1 = ConvModule(
            in_channels=3,
            out_channels=self.in_channels,
            kernel_size=3,
            stride=2,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

        # Names of the registered layers ('layer1' ... 'layer7'), in order.
        self.layers = []

        for i, layer_cfg in enumerate(self.arch_settings):
            expand_ratio, channel, num_blocks = layer_cfg
            stride = self.strides[i]
            dilation = self.dilations[i]
            out_channels = make_divisible(channel * widen_factor, 8)
            inverted_res_layer = self.make_layer(
                out_channels=out_channels,
                num_blocks=num_blocks,
                stride=stride,
                dilation=dilation,
                expand_ratio=expand_ratio)
            layer_name = f'layer{i + 1}'
            self.add_module(layer_name, inverted_res_layer)
            self.layers.append(layer_name)

    def make_layer(self, out_channels, num_blocks, stride, dilation,
                   expand_ratio):
        """Stack InvertedResidual blocks to build a layer for MobileNetV2.

        Args:
            out_channels (int): out_channels of block.
            num_blocks (int): Number of blocks.
            stride (int): Stride of the first block.
            dilation (int): Dilation of the first block.
            expand_ratio (int): Expand the number of channels of the
                hidden layer in InvertedResidual by this ratio.

        Returns:
            nn.Sequential: The stacked blocks. ``self.in_channels`` is
                advanced to ``out_channels`` as a side effect.
        """
        layers = []
        for i in range(num_blocks):
            layers.append(
                InvertedResidual(
                    self.in_channels,
                    out_channels,
                    stride if i == 0 else 1,
                    expand_ratio=expand_ratio,
                    dilation=dilation if i == 0 else 1,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    act_cfg=self.act_cfg,
                    with_cp=self.with_cp))
            self.in_channels = out_channels

        return nn.Sequential(*layers)

    def init_weights(self, pretrained=None):
        """Load a checkpoint or apply the default initialization.

        Args:
            pretrained (str, optional): Path to pre-trained weights. When
                None, conv layers get Kaiming init and norm layers are set
                to 1. Defaults to None.

        Raises:
            TypeError: If ``pretrained`` is neither a str nor None.
        """
        if isinstance(pretrained, str):
            logger = logging.getLogger()
            load_checkpoint(self, pretrained, strict=False, logger=logger)
        elif pretrained is None:
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    kaiming_init(m)
                elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
                    constant_init(m, 1)
        else:
            raise TypeError('pretrained must be a str or None')

    def forward(self, x):
        """Run the backbone and collect the outputs listed in out_indices.

        Returns:
            A single feature map when exactly one output was collected,
            otherwise a tuple of feature maps.
        """
        x = self.conv1(x)

        outs = []
        for i, layer_name in enumerate(self.layers):
            layer = getattr(self, layer_name)
            x = layer(x)
            if i in self.out_indices:
                outs.append(x)

        if len(outs) == 1:
            return outs[0]
        else:
            return tuple(outs)

    def _freeze_stages(self):
        """Disable gradients for conv1 and the first ``frozen_stages`` layers.

        The frozen ``layer{i}`` modules are also switched to eval mode.
        """
        if self.frozen_stages >= 0:
            for param in self.conv1.parameters():
                param.requires_grad = False
        for i in range(1, self.frozen_stages + 1):
            layer = getattr(self, f'layer{i}')
            layer.eval()
            for param in layer.parameters():
                param.requires_grad = False

    def train(self, mode=True):
        """Switch mode, then re-apply stage freezing and norm_eval."""
        super(MobileNetV2, self).train(mode)
        self._freeze_stages()
        if mode and self.norm_eval:
            for m in self.modules():
                if isinstance(m, _BatchNorm):
                    m.eval()
import logging
import annotator.uniformer.mmcv as mmcv
import torch.nn as nn
from annotator.uniformer.mmcv.cnn import ConvModule, constant_init, kaiming_init
from annotator.uniformer.mmcv.cnn.bricks import Conv2dAdaptivePadding
from annotator.uniformer.mmcv.runner import load_checkpoint
from torch.nn.modules.batchnorm import _BatchNorm
from ..builder import BACKBONES
from ..utils import InvertedResidualV3 as InvertedResidual
@BACKBONES.register_module()
class MobileNetV3(nn.Module):
    """MobileNetV3 backbone.

    This backbone is the improved implementation of `Searching for MobileNetV3
    <https://ieeexplore.ieee.org/document/9008835>`_.

    Args:
        arch (str): Architecture of mobilnetv3, from {'small', 'large'}.
            Default: 'small'.
        conv_cfg (dict): Config dict for convolution layer.
            Default: None, which means using conv2d.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
        out_indices (tuple[int]): Output from which layer.
            Default: (0, 1, 12).
        frozen_stages (int): Stages to be frozen (all param fixed).
            Default: -1, which means not freezing any parameters.
        reduction_factor (int): Positive integer by which the mid and out
            channels of the trailing blocks (index >= 12 for 'large',
            >= 8 for 'small') are floor-divided. Default: 1.
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Default: False.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save
            some memory while slowing down the training speed.
            Default: False.
    """
    # Parameters to build each block:
    #     [kernel size, mid channels, out channels, with_se, act type, stride]
    arch_settings = {
        'small': [[3, 16, 16, True, 'ReLU', 2],  # block0 layer1 os=4
                  [3, 72, 24, False, 'ReLU', 2],  # block1 layer2 os=8
                  [3, 88, 24, False, 'ReLU', 1],
                  [5, 96, 40, True, 'HSwish', 2],  # block2 layer4 os=16
                  [5, 240, 40, True, 'HSwish', 1],
                  [5, 240, 40, True, 'HSwish', 1],
                  [5, 120, 48, True, 'HSwish', 1],  # block3 layer7 os=16
                  [5, 144, 48, True, 'HSwish', 1],
                  [5, 288, 96, True, 'HSwish', 2],  # block4 layer9 os=32
                  [5, 576, 96, True, 'HSwish', 1],
                  [5, 576, 96, True, 'HSwish', 1]],
        'large': [[3, 16, 16, False, 'ReLU', 1],  # block0 layer1 os=2
                  [3, 64, 24, False, 'ReLU', 2],  # block1 layer2 os=4
                  [3, 72, 24, False, 'ReLU', 1],
                  [5, 72, 40, True, 'ReLU', 2],  # block2 layer4 os=8
                  [5, 120, 40, True, 'ReLU', 1],
                  [5, 120, 40, True, 'ReLU', 1],
                  [3, 240, 80, False, 'HSwish', 2],  # block3 layer7 os=16
                  [3, 200, 80, False, 'HSwish', 1],
                  [3, 184, 80, False, 'HSwish', 1],
                  [3, 184, 80, False, 'HSwish', 1],
                  [3, 480, 112, True, 'HSwish', 1],  # block4 layer11 os=16
                  [3, 672, 112, True, 'HSwish', 1],
                  [5, 672, 160, True, 'HSwish', 2],  # block5 layer13 os=32
                  [5, 960, 160, True, 'HSwish', 1],
                  [5, 960, 160, True, 'HSwish', 1]]
    }  # yapf: disable

    def __init__(self,
                 arch='small',
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 out_indices=(0, 1, 12),
                 frozen_stages=-1,
                 reduction_factor=1,
                 norm_eval=False,
                 with_cp=False):
        super(MobileNetV3, self).__init__()
        assert arch in self.arch_settings
        assert isinstance(reduction_factor, int) and reduction_factor > 0
        assert mmcv.is_tuple_of(out_indices, int)

        # layer0 + one layer per arch setting + the final 1x1 conv layer.
        valid_layer_ids = range(0, len(self.arch_settings[arch]) + 2)
        for index in out_indices:
            if index not in valid_layer_ids:
                raise ValueError(
                    'the item in out_indices must in '
                    f'range(0, {len(self.arch_settings[arch])+2}). '
                    f'But received {index}')

        if frozen_stages not in range(-1, len(self.arch_settings[arch]) + 2):
            raise ValueError('frozen_stages must be in range(-1, '
                             f'{len(self.arch_settings[arch])+2}). '
                             f'But received {frozen_stages}')

        self.arch = arch
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.out_indices = out_indices
        self.frozen_stages = frozen_stages
        self.reduction_factor = reduction_factor
        self.norm_eval = norm_eval
        self.with_cp = with_cp
        self.layers = self._make_layer()

    def _make_layer(self):
        """Register every layer on the module and return their names.

        Returns:
            list[str]: Layer names in execution order ('layer0', ...).
        """
        layer_names = []

        # build the first layer (layer0)
        in_channels = 16
        stem = ConvModule(
            in_channels=3,
            out_channels=in_channels,
            kernel_size=3,
            stride=2,
            padding=1,
            conv_cfg=dict(type='Conv2dAdaptivePadding'),
            norm_cfg=self.norm_cfg,
            act_cfg=dict(type='HSwish'))
        self.add_module('layer0', stem)
        layer_names.append('layer0')

        layer_setting = self.arch_settings[self.arch]
        for i, params in enumerate(layer_setting):
            (kernel_size, mid_channels, out_channels, with_se, act,
             stride) = params

            # Only the trailing blocks are shrunk by reduction_factor.
            in_reduced_tail = (self.arch == 'large' and i >= 12) or \
                (self.arch == 'small' and i >= 8)
            if in_reduced_tail:
                mid_channels = mid_channels // self.reduction_factor
                out_channels = out_channels // self.reduction_factor

            se_cfg = None
            if with_se:
                se_cfg = dict(
                    channels=mid_channels,
                    ratio=4,
                    act_cfg=(dict(type='ReLU'),
                             dict(type='HSigmoid', bias=3.0, divisor=6.0)))

            block = InvertedResidual(
                in_channels=in_channels,
                out_channels=out_channels,
                mid_channels=mid_channels,
                kernel_size=kernel_size,
                stride=stride,
                se_cfg=se_cfg,
                with_expand_conv=(in_channels != mid_channels),
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=dict(type=act),
                with_cp=self.with_cp)
            in_channels = out_channels
            layer_name = 'layer{}'.format(i + 1)
            self.add_module(layer_name, block)
            layer_names.append(layer_name)

        # build the last layer
        # block5 layer12 os=32 for small model
        # block6 layer16 os=32 for large model
        head = ConvModule(
            in_channels=in_channels,
            out_channels=576 if self.arch == 'small' else 960,
            kernel_size=1,
            stride=1,
            dilation=4,
            padding=0,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=dict(type='HSwish'))
        layer_name = 'layer{}'.format(len(layer_setting) + 1)
        self.add_module(layer_name, head)
        layer_names.append(layer_name)

        # next, convert backbone MobileNetV3 to a semantic segmentation version
        # Per architecture: the two layers whose depthwise stride is reset
        # to 1, the first layer index that gets dilated, and the index from
        # which dilation 4 replaces dilation 2.
        if self.arch == 'small':
            destrided, dilate_start, dilate4_start = (self.layer4,
                                                      self.layer9), 4, 9
        else:
            destrided, dilate_start, dilate4_start = (self.layer7,
                                                      self.layer13), 7, 13

        for strided_layer in destrided:
            strided_layer.depthwise_conv.conv.stride = (1, 1)

        for i in range(dilate_start, len(layer_names)):
            layer = getattr(self, layer_names[i])
            if isinstance(layer, InvertedResidual):
                modified_module = layer.depthwise_conv.conv
            else:
                modified_module = layer.conv

            pad = 2 if i < dilate4_start else 4
            modified_module.dilation = (pad, pad)

            if not isinstance(modified_module, Conv2dAdaptivePadding):
                # Adjust padding
                pad *= (modified_module.kernel_size[0] - 1) // 2
                modified_module.padding = (pad, pad)

        return layer_names

    def init_weights(self, pretrained=None):
        """Load a checkpoint or apply the default initialization.

        Args:
            pretrained (str, optional): Path to pre-trained weights. When
                None, conv layers get Kaiming init and BatchNorm2d layers
                are set to 1. Defaults to None.

        Raises:
            TypeError: If ``pretrained`` is neither a str nor None.
        """
        if pretrained is None:
            for module in self.modules():
                if isinstance(module, nn.Conv2d):
                    kaiming_init(module)
                elif isinstance(module, nn.BatchNorm2d):
                    constant_init(module, 1)
            return

        if not isinstance(pretrained, str):
            raise TypeError('pretrained must be a str or None')

        logger = logging.getLogger()
        load_checkpoint(self, pretrained, strict=False, logger=logger)

    def forward(self, x):
        """Run all layers and return the outputs listed in out_indices.

        Returns:
            list: Feature maps of the selected layers, in layer order.
        """
        outs = []
        for index, layer_name in enumerate(self.layers):
            x = getattr(self, layer_name)(x)
            if index in self.out_indices:
                outs.append(x)
        return outs

    def _freeze_stages(self):
        """Put layers 0..frozen_stages in eval mode and disable their grads."""
        for stage in range(self.frozen_stages + 1):
            frozen = getattr(self, f'layer{stage}')
            frozen.eval()
            for param in frozen.parameters():
                param.requires_grad = False

    def train(self, mode=True):
        """Switch mode, then re-apply stage freezing and norm_eval."""
        super(MobileNetV3, self).train(mode)
        self._freeze_stages()
        if not (mode and self.norm_eval):
            return
        for module in self.modules():
            if isinstance(module, _BatchNorm):
                module.eval()
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as cp
from annotator.uniformer.mmcv.cnn import build_conv_layer, build_norm_layer
from ..builder import BACKBONES
from ..utils import ResLayer
from .resnet import Bottleneck as _Bottleneck
from .resnet import ResNetV1d
class RSoftmax(nn.Module):
    """Radix Softmax module in ``SplitAttentionConv2d``.

    For ``radix > 1`` the input is viewed as (batch, groups, radix, -1),
    softmax-normalized across the radix axis and flattened back to
    (batch, -1). For ``radix == 1`` an element-wise sigmoid is applied.

    Args:
        radix (int): Radix of input.
        groups (int): Groups of input.
    """

    def __init__(self, radix, groups):
        super().__init__()
        self.radix = radix
        self.groups = groups

    def forward(self, x):
        """Normalize the attention logits across the radix splits."""
        if self.radix <= 1:
            return torch.sigmoid(x)

        batch = x.size(0)
        # Bring the radix axis to dim 1 so the softmax runs across splits.
        per_radix = x.view(batch, self.groups, self.radix, -1).transpose(1, 2)
        weights = F.softmax(per_radix, dim=1)
        return weights.reshape(batch, -1)
class SplitAttentionConv2d(nn.Module):
    """Split-Attention Conv2d in ResNeSt.

    Args:
        in_channels (int): Same as nn.Conv2d.
        channels (int): Number of output channels. The internal conv
            produces ``channels * radix`` channels, which the attention
            step reduces back to ``channels``.
        kernel_size (int | tuple[int]): Same as nn.Conv2d.
        stride (int | tuple[int]): Same as nn.Conv2d.
        padding (int | tuple[int]): Same as nn.Conv2d.
        dilation (int | tuple[int]): Same as nn.Conv2d.
        groups (int): Same as nn.Conv2d.
        radix (int): Radix of SpltAtConv2d. Default: 2
        reduction_factor (int): Reduction factor of inter_channels. Default: 4.
        conv_cfg (dict): Config dict for convolution layer. Default: None,
            which means using conv2d.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
        dcn (dict): Config dict for DCN. Default: None.
    """

    def __init__(self,
                 in_channels,
                 channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 radix=2,
                 reduction_factor=4,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 dcn=None):
        super(SplitAttentionConv2d, self).__init__()
        # Width of the attention bottleneck, never smaller than 32.
        inter_channels = max(in_channels * radix // reduction_factor, 32)
        self.radix = radix
        self.groups = groups
        self.channels = channels
        self.with_dcn = dcn is not None
        self.dcn = dcn
        fallback_on_stride = False
        if self.with_dcn:
            # NOTE(review): pop() removes the key from the dict object the
            # caller passed in; confirm callers do not reuse that dict and
            # expect the key to still be present.
            fallback_on_stride = self.dcn.pop('fallback_on_stride', False)
        if self.with_dcn and not fallback_on_stride:
            assert conv_cfg is None, 'conv_cfg must be None for DCN'
            conv_cfg = dcn
        self.conv = build_conv_layer(
            conv_cfg,
            in_channels,
            channels * radix,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups * radix,
            bias=False)
        self.norm0_name, norm0 = build_norm_layer(
            norm_cfg, channels * radix, postfix=0)
        self.add_module(self.norm0_name, norm0)
        self.relu = nn.ReLU(inplace=True)
        self.fc1 = build_conv_layer(
            None, channels, inter_channels, 1, groups=self.groups)
        self.norm1_name, norm1 = build_norm_layer(
            norm_cfg, inter_channels, postfix=1)
        self.add_module(self.norm1_name, norm1)
        self.fc2 = build_conv_layer(
            None, inter_channels, channels * radix, 1, groups=self.groups)
        self.rsoftmax = RSoftmax(radix, groups)

    @property
    def norm0(self):
        """nn.Module: the normalization layer named "norm0" """
        return getattr(self, self.norm0_name)

    @property
    def norm1(self):
        """nn.Module: the normalization layer named "norm1" """
        return getattr(self, self.norm1_name)

    def forward(self, x):
        """Apply the conv, then re-weight its radix splits with attention."""
        x = self.conv(x)
        x = self.norm0(x)
        x = self.relu(x)

        batch = x.size(0)
        if self.radix > 1:
            # Expose the radix splits on dim 1 and sum them for pooling.
            splits = x.view(batch, self.radix, -1, *x.shape[2:])
            gap = splits.sum(dim=1)
        else:
            gap = x
        gap = F.adaptive_avg_pool2d(gap, 1)
        gap = self.fc1(gap)

        gap = self.norm1(gap)
        gap = self.relu(gap)

        atten = self.fc2(gap)
        atten = self.rsoftmax(atten).view(batch, -1, 1, 1)

        if self.radix > 1:
            # Weight each split by its attention and merge them again.
            attens = atten.view(batch, self.radix, -1, *atten.shape[2:])
            out = torch.sum(attens * splits, dim=1)
        else:
            out = atten * x
        return out.contiguous()
class Bottleneck(_Bottleneck):
    """Bottleneck block for ResNeSt.

    The base block is built first; conv1, conv2 and conv3 are then replaced
    so that conv2 becomes a ``SplitAttentionConv2d`` and the base norm2 is
    removed.

    Args:
        inplanes (int): Input planes of this block.
        planes (int): Middle planes of this block.
        groups (int): Groups of conv2. Default: 1.
        base_width (int): Base width used with ``base_channels`` to scale
            the conv2 width when ``groups > 1``. Default: 4.
        base_channels (int): Base channel count for that width scaling.
            Default: 64.
        radix (int): Radix of SpltAtConv2d. Default: 2
        reduction_factor (int): Reduction factor of inter_channels in
            SplitAttentionConv2d. Default: 4.
        avg_down_stride (bool): Whether to use average pool for stride in
            Bottleneck. Default: True.
        kwargs (dict): Key word arguments for base class.
    """
    expansion = 4

    def __init__(self,
                 inplanes,
                 planes,
                 groups=1,
                 base_width=4,
                 base_channels=64,
                 radix=2,
                 reduction_factor=4,
                 avg_down_stride=True,
                 **kwargs):
        """Bottleneck block for ResNeSt."""
        super(Bottleneck, self).__init__(inplanes, planes, **kwargs)

        # Width of the middle (split-attention) convolution.
        width = self.planes if groups == 1 else math.floor(
            self.planes * (base_width / base_channels)) * groups

        # Average pooling only replaces the stride when conv2 is strided.
        self.avg_down_stride = avg_down_stride and self.conv2_stride > 1

        expanded_planes = self.planes * self.expansion
        self.norm1_name, norm1 = build_norm_layer(
            self.norm_cfg, width, postfix=1)
        self.norm3_name, norm3 = build_norm_layer(
            self.norm_cfg, expanded_planes, postfix=3)

        self.conv1 = build_conv_layer(
            self.conv_cfg,
            self.inplanes,
            width,
            kernel_size=1,
            stride=self.conv1_stride,
            bias=False)
        self.add_module(self.norm1_name, norm1)
        self.with_modulated_dcn = False

        conv2_stride = 1 if self.avg_down_stride else self.conv2_stride
        self.conv2 = SplitAttentionConv2d(
            width,
            width,
            kernel_size=3,
            stride=conv2_stride,
            padding=self.dilation,
            dilation=self.dilation,
            groups=groups,
            radix=radix,
            reduction_factor=reduction_factor,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            dcn=self.dcn)
        # The base-class norm2 is dropped; forward() does not apply it.
        delattr(self, self.norm2_name)

        if self.avg_down_stride:
            self.avd_layer = nn.AvgPool2d(3, self.conv2_stride, padding=1)

        self.conv3 = build_conv_layer(
            self.conv_cfg,
            width,
            expanded_planes,
            kernel_size=1,
            bias=False)
        self.add_module(self.norm3_name, norm3)

    def forward(self, x):
        """Forward function: residual branch, shortcut add, final ReLU."""

        def _residual_branch(inputs):
            shortcut = inputs

            feat = self.relu(self.norm1(self.conv1(inputs)))
            if self.with_plugins:
                feat = self.forward_plugin(feat,
                                           self.after_conv1_plugin_names)

            feat = self.conv2(feat)
            if self.avg_down_stride:
                feat = self.avd_layer(feat)
            if self.with_plugins:
                feat = self.forward_plugin(feat,
                                           self.after_conv2_plugin_names)

            feat = self.norm3(self.conv3(feat))
            if self.with_plugins:
                feat = self.forward_plugin(feat,
                                           self.after_conv3_plugin_names)

            if self.downsample is not None:
                shortcut = self.downsample(inputs)

            feat += shortcut
            return feat

        use_checkpoint = self.with_cp and x.requires_grad
        if use_checkpoint:
            out = cp.checkpoint(_residual_branch, x)
        else:
            out = _residual_branch(x)

        return self.relu(out)
@BACKBONES.register_module()
class ResNeSt(ResNetV1d):
    """ResNeSt backbone.

    Args:
        groups (int): Number of groups of Bottleneck. Default: 1
        base_width (int): Base width of Bottleneck. Default: 4
        radix (int): Radix of SpltAtConv2d. Default: 2
        reduction_factor (int): Reduction factor of inter_channels in
            SplitAttentionConv2d. Default: 4.
        avg_down_stride (bool): Whether to use average pool for stride in
            Bottleneck. Default: True.
        kwargs (dict): Keyword arguments for ResNet.
    """

    # depth -> (block class, number of blocks in each stage)
    arch_settings = {
        50: (Bottleneck, (3, 4, 6, 3)),
        101: (Bottleneck, (3, 4, 23, 3)),
        152: (Bottleneck, (3, 8, 36, 3)),
        200: (Bottleneck, (3, 24, 36, 3))
    }

    def __init__(self,
                 groups=1,
                 base_width=4,
                 radix=2,
                 reduction_factor=4,
                 avg_down_stride=True,
                 **kwargs):
        # These are set before the base constructor runs because
        # make_res_layer() reads them.
        self.groups = groups
        self.base_width = base_width
        self.radix = radix
        self.reduction_factor = reduction_factor
        self.avg_down_stride = avg_down_stride
        super(ResNeSt, self).__init__(**kwargs)

    def make_res_layer(self, **kwargs):
        """Pack all blocks in a stage into a ``ResLayer``."""
        split_attention_cfg = dict(
            groups=self.groups,
            base_width=self.base_width,
            base_channels=self.base_channels,
            radix=self.radix,
            reduction_factor=self.reduction_factor,
            avg_down_stride=self.avg_down_stride)
        return ResLayer(**split_attention_cfg, **kwargs)
import torch.nn as nn
import torch.utils.checkpoint as cp
from annotator.uniformer.mmcv.cnn import (build_conv_layer, build_norm_layer, build_plugin_layer,
constant_init, kaiming_init)
from annotator.uniformer.mmcv.runner import load_checkpoint
from annotator.uniformer.mmcv.utils.parrots_wrapper import _BatchNorm
from annotator.uniformer.mmseg.utils import get_root_logger
from ..builder import BACKBONES
from ..utils import ResLayer
class BasicBlock(nn.Module):
    """Basic block for ResNet.

    Two 3x3 conv + norm units with a residual shortcut. ``dcn`` and
    ``plugins`` are accepted for signature parity but must be None.
    """

    expansion = 1

    def __init__(self,
                 inplanes,
                 planes,
                 stride=1,
                 dilation=1,
                 downsample=None,
                 style='pytorch',
                 with_cp=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 dcn=None,
                 plugins=None):
        super(BasicBlock, self).__init__()
        assert dcn is None, 'Not implemented yet.'
        assert plugins is None, 'Not implemented yet.'

        # Plain (non-module) settings.
        self.stride = stride
        self.dilation = dilation
        self.with_cp = with_cp

        self.norm1_name, norm1 = build_norm_layer(norm_cfg, planes, postfix=1)
        self.norm2_name, norm2 = build_norm_layer(norm_cfg, planes, postfix=2)

        # Only the first conv carries the stride and dilation.
        self.conv1 = build_conv_layer(
            conv_cfg,
            inplanes,
            planes,
            3,
            stride=stride,
            padding=dilation,
            dilation=dilation,
            bias=False)
        self.add_module(self.norm1_name, norm1)
        self.conv2 = build_conv_layer(
            conv_cfg, planes, planes, 3, padding=1, bias=False)
        self.add_module(self.norm2_name, norm2)

        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

    @property
    def norm1(self):
        """nn.Module: normalization layer after the first convolution layer"""
        return getattr(self, self.norm1_name)

    @property
    def norm2(self):
        """nn.Module: normalization layer after the second convolution layer"""
        return getattr(self, self.norm2_name)

    def forward(self, x):
        """Forward function."""

        def _residual_branch(inputs):
            shortcut = inputs

            feat = self.relu(self.norm1(self.conv1(inputs)))
            feat = self.norm2(self.conv2(feat))

            if self.downsample is not None:
                shortcut = self.downsample(inputs)

            feat += shortcut
            return feat

        use_checkpoint = self.with_cp and x.requires_grad
        if use_checkpoint:
            out = cp.checkpoint(_residual_branch, x)
        else:
            out = _residual_branch(x)

        return self.relu(out)
class Bottleneck(nn.Module):
    """Bottleneck block for ResNet.

    If style is "pytorch", the stride-two layer is the 3x3 conv layer, if it is
    "caffe", the stride-two layer is the first 1x1 conv layer.

    Args:
        inplanes (int): Input channels of the block.
        planes (int): Channels of the two inner convs; the block outputs
            ``planes * expansion`` channels.
        stride (int): Stride of the block, placed according to ``style``.
            Default: 1.
        dilation (int): Dilation (and padding) of the 3x3 conv. Default: 1.
        downsample (nn.Module, optional): Module applied to the input to
            produce the shortcut. Default: None.
        style (str): 'pytorch' or 'caffe'. Default: 'pytorch'.
        with_cp (bool): Use checkpoint in forward. Default: False.
        conv_cfg (dict, optional): Config for conv layers. Default: None.
        norm_cfg (dict): Config for norm layers. Default: dict(type='BN').
        dcn (dict, optional): Config for DCN used as conv2. Default: None.
        plugins (list[dict], optional): Plugin configs, each with 'cfg' and
            'position' ('after_conv1', 'after_conv2' or 'after_conv3').
            Default: None.
    """

    expansion = 4

    def __init__(self,
                 inplanes,
                 planes,
                 stride=1,
                 dilation=1,
                 downsample=None,
                 style='pytorch',
                 with_cp=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 dcn=None,
                 plugins=None):
        super(Bottleneck, self).__init__()
        assert style in ['pytorch', 'caffe']
        assert dcn is None or isinstance(dcn, dict)
        assert plugins is None or isinstance(plugins, list)
        if plugins is not None:
            allowed_position = ['after_conv1', 'after_conv2', 'after_conv3']
            assert all(p['position'] in allowed_position for p in plugins)

        self.inplanes = inplanes
        self.planes = planes
        self.stride = stride
        self.dilation = dilation
        self.style = style
        self.with_cp = with_cp
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.dcn = dcn
        self.with_dcn = dcn is not None
        self.plugins = plugins
        self.with_plugins = plugins is not None

        if self.with_plugins:
            # collect plugins for conv1/conv2/conv3
            self.after_conv1_plugins = [
                plugin['cfg'] for plugin in plugins
                if plugin['position'] == 'after_conv1'
            ]
            self.after_conv2_plugins = [
                plugin['cfg'] for plugin in plugins
                if plugin['position'] == 'after_conv2'
            ]
            self.after_conv3_plugins = [
                plugin['cfg'] for plugin in plugins
                if plugin['position'] == 'after_conv3'
            ]

        if self.style == 'pytorch':
            self.conv1_stride = 1
            self.conv2_stride = stride
        else:
            self.conv1_stride = stride
            self.conv2_stride = 1

        self.norm1_name, norm1 = build_norm_layer(norm_cfg, planes, postfix=1)
        self.norm2_name, norm2 = build_norm_layer(norm_cfg, planes, postfix=2)
        self.norm3_name, norm3 = build_norm_layer(
            norm_cfg, planes * self.expansion, postfix=3)

        self.conv1 = build_conv_layer(
            conv_cfg,
            inplanes,
            planes,
            kernel_size=1,
            stride=self.conv1_stride,
            bias=False)
        self.add_module(self.norm1_name, norm1)
        fallback_on_stride = False
        if self.with_dcn:
            # NOTE(review): pop() removes the key from the dict object the
            # caller passed in; if that dict is shared between blocks,
            # later blocks would see the default False. Confirm with callers.
            fallback_on_stride = dcn.pop('fallback_on_stride', False)
        if not self.with_dcn or fallback_on_stride:
            self.conv2 = build_conv_layer(
                conv_cfg,
                planes,
                planes,
                kernel_size=3,
                stride=self.conv2_stride,
                padding=dilation,
                dilation=dilation,
                bias=False)
        else:
            assert self.conv_cfg is None, 'conv_cfg must be None for DCN'
            self.conv2 = build_conv_layer(
                dcn,
                planes,
                planes,
                kernel_size=3,
                stride=self.conv2_stride,
                padding=dilation,
                dilation=dilation,
                bias=False)

        self.add_module(self.norm2_name, norm2)
        self.conv3 = build_conv_layer(
            conv_cfg,
            planes,
            planes * self.expansion,
            kernel_size=1,
            bias=False)
        self.add_module(self.norm3_name, norm3)

        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

        if self.with_plugins:
            self.after_conv1_plugin_names = self.make_block_plugins(
                planes, self.after_conv1_plugins)
            self.after_conv2_plugin_names = self.make_block_plugins(
                planes, self.after_conv2_plugins)
            self.after_conv3_plugin_names = self.make_block_plugins(
                planes * self.expansion, self.after_conv3_plugins)

    def make_block_plugins(self, in_channels, plugins):
        """make plugins for block.

        Args:
            in_channels (int): Input channels of plugin.
            plugins (list[dict]): List of plugins cfg to build.

        Returns:
            list[str]: List of the names of plugin.
        """
        assert isinstance(plugins, list)
        plugin_names = []
        for plugin in plugins:
            # Work on a copy so popping 'postfix' leaves the cfg intact.
            plugin = plugin.copy()
            name, layer = build_plugin_layer(
                plugin,
                in_channels=in_channels,
                postfix=plugin.pop('postfix', ''))
            assert not hasattr(self, name), f'duplicate plugin {name}'
            self.add_module(name, layer)
            plugin_names.append(name)
        return plugin_names

    def forward_plugin(self, x, plugin_names):
        """Forward function for plugins.

        The plugins are applied one after another: each plugin receives
        the output of the previous one, so every plugin registered at a
        position contributes to the result.

        Args:
            x: Input fed to the first plugin.
            plugin_names (list[str]): Attribute names of the plugins, in
                application order.

        Returns:
            The output of the last plugin, or ``x`` when the list is empty.
        """
        out = x
        for name in plugin_names:
            out = getattr(self, name)(out)
        return out

    @property
    def norm1(self):
        """nn.Module: normalization layer after the first convolution layer"""
        return getattr(self, self.norm1_name)

    @property
    def norm2(self):
        """nn.Module: normalization layer after the second convolution layer"""
        return getattr(self, self.norm2_name)

    @property
    def norm3(self):
        """nn.Module: normalization layer after the third convolution layer"""
        return getattr(self, self.norm3_name)

    def forward(self, x):
        """Forward function."""

        def _inner_forward(x):
            identity = x

            out = self.conv1(x)
            out = self.norm1(out)
            out = self.relu(out)

            if self.with_plugins:
                out = self.forward_plugin(out, self.after_conv1_plugin_names)

            out = self.conv2(out)
            out = self.norm2(out)
            out = self.relu(out)

            if self.with_plugins:
                out = self.forward_plugin(out, self.after_conv2_plugin_names)

            out = self.conv3(out)
            out = self.norm3(out)

            if self.with_plugins:
                out = self.forward_plugin(out, self.after_conv3_plugin_names)

            if self.downsample is not None:
                identity = self.downsample(x)

            out += identity

            return out

        # Checkpointing only applies when the input requires gradients.
        if self.with_cp and x.requires_grad:
            out = cp.checkpoint(_inner_forward, x)
        else:
            out = _inner_forward(x)

        out = self.relu(out)

        return out
@BACKBONES.register_module()
class ResNet(nn.Module):
    """ResNet backbone.

    Args:
        depth (int): Depth of resnet, from {18, 34, 50, 101, 152}.
        in_channels (int): Number of input image channels. Default: 3.
        stem_channels (int): Number of stem channels. Default: 64.
        base_channels (int): Number of base channels of res layer. Default: 64.
        num_stages (int): Resnet stages, normally 4.
        strides (Sequence[int]): Strides of the first block of each stage.
        dilations (Sequence[int]): Dilation of each stage.
        out_indices (Sequence[int]): Output from which stages.
        style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
            layer is the 3x3 conv layer, otherwise the stride-two layer is
            the first 1x1 conv layer.
        deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv
        avg_down (bool): Use AvgPool instead of stride conv when
            downsampling in the bottleneck.
        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
            -1 means not freezing any parameters.
        conv_cfg (dict | None): Config passed to ``build_conv_layer`` for the
            stem convolutions and forwarded to every res layer.
        norm_cfg (dict): Dictionary to construct and config norm layer.
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only.
        dcn (dict | None): DCN config forwarded to the stages enabled in
            ``stage_with_dcn``. Default: None.
        stage_with_dcn (Sequence[bool]): Per-stage switch deciding whether
            ``dcn`` is forwarded to that stage. Its length must equal
            ``num_stages`` when ``dcn`` is given.
        plugins (list[dict]): List of plugins for stages, each dict contains:

            - cfg (dict, required): Cfg dict to build plugin.

            - position (str, required): Position inside block to insert plugin,
            options: 'after_conv1', 'after_conv2', 'after_conv3'.

            - stages (tuple[bool], optional): Stages to apply plugin, length
            should be same as 'num_stages'
        multi_grid (Sequence[int]|None): Multi grid dilation rates of last
            stage. Default: None
        contract_dilation (bool): Whether contract first dilation of each layer
            Default: False
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed.
        zero_init_residual (bool): Whether to use zero init for last norm layer
            in resblocks to let them behave as identity.

    Example:
        >>> from annotator.uniformer.mmseg.models import ResNet
        >>> import torch
        >>> self = ResNet(depth=18)
        >>> _ = self.eval()
        >>> inputs = torch.rand(1, 3, 32, 32)
        >>> level_outputs = self.forward(inputs)
        >>> for level_out in level_outputs:
        ...     print(tuple(level_out.shape))
        (1, 64, 8, 8)
        (1, 128, 4, 4)
        (1, 256, 2, 2)
        (1, 512, 1, 1)
    """

    # depth -> (block class, number of blocks in each of the four stages)
    arch_settings = {
        18: (BasicBlock, (2, 2, 2, 2)),
        34: (BasicBlock, (3, 4, 6, 3)),
        50: (Bottleneck, (3, 4, 6, 3)),
        101: (Bottleneck, (3, 4, 23, 3)),
        152: (Bottleneck, (3, 8, 36, 3))
    }

    def __init__(self,
                 depth,
                 in_channels=3,
                 stem_channels=64,
                 base_channels=64,
                 num_stages=4,
                 strides=(1, 2, 2, 2),
                 dilations=(1, 1, 1, 1),
                 out_indices=(0, 1, 2, 3),
                 style='pytorch',
                 deep_stem=False,
                 avg_down=False,
                 frozen_stages=-1,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN', requires_grad=True),
                 norm_eval=False,
                 dcn=None,
                 stage_with_dcn=(False, False, False, False),
                 plugins=None,
                 multi_grid=None,
                 contract_dilation=False,
                 with_cp=False,
                 zero_init_residual=True):
        super(ResNet, self).__init__()
        if depth not in self.arch_settings:
            raise KeyError(f'invalid depth {depth} for resnet')
        self.depth = depth
        self.stem_channels = stem_channels
        self.base_channels = base_channels
        self.num_stages = num_stages
        assert num_stages >= 1 and num_stages <= 4
        self.strides = strides
        self.dilations = dilations
        assert len(strides) == len(dilations) == num_stages
        self.out_indices = out_indices
        assert max(out_indices) < num_stages
        self.style = style
        self.deep_stem = deep_stem
        self.avg_down = avg_down
        self.frozen_stages = frozen_stages
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.with_cp = with_cp
        self.norm_eval = norm_eval
        self.dcn = dcn
        self.stage_with_dcn = stage_with_dcn
        if dcn is not None:
            assert len(stage_with_dcn) == num_stages
        self.plugins = plugins
        self.multi_grid = multi_grid
        self.contract_dilation = contract_dilation
        self.zero_init_residual = zero_init_residual
        self.block, stage_blocks = self.arch_settings[depth]
        self.stage_blocks = stage_blocks[:num_stages]
        # Running input width of the next stage; starts at the stem output.
        self.inplanes = stem_channels

        self._make_stem_layer(in_channels, stem_channels)

        self.res_layers = []
        for i, num_blocks in enumerate(self.stage_blocks):
            stride = strides[i]
            dilation = dilations[i]
            dcn = self.dcn if self.stage_with_dcn[i] else None
            if plugins is not None:
                stage_plugins = self.make_stage_plugins(plugins, i)
            else:
                stage_plugins = None
            # multi grid is applied to last layer only
            stage_multi_grid = multi_grid if i == len(
                self.stage_blocks) - 1 else None
            planes = base_channels * 2**i
            res_layer = self.make_res_layer(
                block=self.block,
                inplanes=self.inplanes,
                planes=planes,
                num_blocks=num_blocks,
                stride=stride,
                dilation=dilation,
                style=self.style,
                avg_down=self.avg_down,
                with_cp=with_cp,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                dcn=dcn,
                plugins=stage_plugins,
                multi_grid=stage_multi_grid,
                contract_dilation=contract_dilation)
            self.inplanes = planes * self.block.expansion
            layer_name = f'layer{i+1}'
            self.add_module(layer_name, res_layer)
            self.res_layers.append(layer_name)

        self._freeze_stages()

        # Channel count produced by the last built stage.
        self.feat_dim = self.block.expansion * base_channels * 2**(
            len(self.stage_blocks) - 1)

    def make_stage_plugins(self, plugins, stage_idx):
        """make plugins for ResNet 'stage_idx'th stage .

        Currently we support to insert 'context_block',
        'empirical_attention_block', 'nonlocal_block' into the backbone like
        ResNet/ResNeXt. They could be inserted after conv1/conv2/conv3 of
        Bottleneck.

        An example of plugins format could be :
        >>> plugins=[
        ...     dict(cfg=dict(type='xxx', arg1='xxx'),
        ...          stages=(False, True, True, True),
        ...          position='after_conv2'),
        ...     dict(cfg=dict(type='yyy'),
        ...          stages=(True, True, True, True),
        ...          position='after_conv3'),
        ...     dict(cfg=dict(type='zzz', postfix='1'),
        ...          stages=(True, True, True, True),
        ...          position='after_conv3'),
        ...     dict(cfg=dict(type='zzz', postfix='2'),
        ...          stages=(True, True, True, True),
        ...          position='after_conv3')
        ... ]
        >>> self = ResNet(depth=18)
        >>> stage_plugins = self.make_stage_plugins(plugins, 0)
        >>> assert len(stage_plugins) == 3

        Suppose 'stage_idx=0', the structure of blocks in the stage would be:
            conv1-> conv2->conv3->yyy->zzz1->zzz2
        Suppose 'stage_idx=1', the structure of blocks in the stage would be:
            conv1-> conv2->xxx->conv3->yyy->zzz1->zzz2

        If stages is missing, the plugin would be applied to all stages.

        Args:
            plugins (list[dict]): List of plugins cfg to build. The postfix is
                required if multiple same type plugins are inserted.
            stage_idx (int): Index of stage to build

        Returns:
            list[dict]: Plugins for current stage
        """
        stage_plugins = []
        for plugin in plugins:
            # Work on a shallow copy so popping 'stages' leaves the caller's
            # config intact for the following stages.
            plugin = plugin.copy()
            stages = plugin.pop('stages', None)
            assert stages is None or len(stages) == self.num_stages
            # whether to insert plugin into current stage
            if stages is None or stages[stage_idx]:
                stage_plugins.append(plugin)

        return stage_plugins

    def make_res_layer(self, **kwargs):
        """Pack all blocks in a stage into a ``ResLayer``."""
        return ResLayer(**kwargs)

    @property
    def norm1(self):
        """nn.Module: the normalization layer named "norm1" """
        return getattr(self, self.norm1_name)

    def _make_stem_layer(self, in_channels, stem_channels):
        """Make stem layer for ResNet."""
        if self.deep_stem:
            # Three 3x3 convs; only the first one has stride 2.
            self.stem = nn.Sequential(
                build_conv_layer(
                    self.conv_cfg,
                    in_channels,
                    stem_channels // 2,
                    kernel_size=3,
                    stride=2,
                    padding=1,
                    bias=False),
                build_norm_layer(self.norm_cfg, stem_channels // 2)[1],
                nn.ReLU(inplace=True),
                build_conv_layer(
                    self.conv_cfg,
                    stem_channels // 2,
                    stem_channels // 2,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    bias=False),
                build_norm_layer(self.norm_cfg, stem_channels // 2)[1],
                nn.ReLU(inplace=True),
                build_conv_layer(
                    self.conv_cfg,
                    stem_channels // 2,
                    stem_channels,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    bias=False),
                build_norm_layer(self.norm_cfg, stem_channels)[1],
                nn.ReLU(inplace=True))
        else:
            # Single 7x7 stride-2 conv followed by a norm layer and ReLU.
            self.conv1 = build_conv_layer(
                self.conv_cfg,
                in_channels,
                stem_channels,
                kernel_size=7,
                stride=2,
                padding=3,
                bias=False)
            self.norm1_name, norm1 = build_norm_layer(
                self.norm_cfg, stem_channels, postfix=1)
            self.add_module(self.norm1_name, norm1)
            self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

    def _freeze_stages(self):
        """Freeze stages param and norm stats."""
        if self.frozen_stages >= 0:
            # frozen_stages >= 0 always freezes the stem.
            if self.deep_stem:
                self.stem.eval()
                for param in self.stem.parameters():
                    param.requires_grad = False
            else:
                self.norm1.eval()
                for m in [self.conv1, self.norm1]:
                    for param in m.parameters():
                        param.requires_grad = False

        # Stages are named layer1..layerN, hence the 1-based range.
        for i in range(1, self.frozen_stages + 1):
            m = getattr(self, f'layer{i}')
            m.eval()
            for param in m.parameters():
                param.requires_grad = False

    def init_weights(self, pretrained=None):
        """Initialize the weights in backbone.

        Args:
            pretrained (str, optional): Path to pre-trained weights.
                Defaults to None.

        Raises:
            TypeError: If ``pretrained`` is neither a str nor None.
        """
        if isinstance(pretrained, str):
            logger = get_root_logger()
            load_checkpoint(self, pretrained, strict=False, logger=logger)
        elif pretrained is None:
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    kaiming_init(m)
                elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
                    constant_init(m, 1)

            if self.dcn is not None:
                for m in self.modules():
                    if isinstance(m, Bottleneck) and hasattr(
                            m, 'conv2_offset'):
                        constant_init(m.conv2_offset, 0)

            if self.zero_init_residual:
                # Zero the last norm of each block so the residual branch
                # starts out contributing nothing.
                for m in self.modules():
                    if isinstance(m, Bottleneck):
                        constant_init(m.norm3, 0)
                    elif isinstance(m, BasicBlock):
                        constant_init(m.norm2, 0)
        else:
            raise TypeError('pretrained must be a str or None')

    def forward(self, x):
        """Forward function.

        Returns:
            tuple: Outputs of the stages listed in ``out_indices``.
        """
        if self.deep_stem:
            x = self.stem(x)
        else:
            x = self.conv1(x)
            x = self.norm1(x)
            x = self.relu(x)
        x = self.maxpool(x)
        outs = []
        for i, layer_name in enumerate(self.res_layers):
            res_layer = getattr(self, layer_name)
            x = res_layer(x)
            if i in self.out_indices:
                outs.append(x)
        return tuple(outs)

    def train(self, mode=True):
        """Convert the model into training mode while keep normalization layer
        freezed.

        Returns:
            ResNet: ``self``, as ``torch.nn.Module.train`` does, so that
            ``model.train()`` and ``model.eval()`` can be chained or
            assigned.
        """
        super(ResNet, self).train(mode)
        # Module.train() flipped every submodule; re-apply the frozen state.
        self._freeze_stages()
        if mode and self.norm_eval:
            for m in self.modules():
                # trick: eval have effect on BatchNorm only
                if isinstance(m, _BatchNorm):
                    m.eval()
        return self
@BACKBONES.register_module()
class ResNetV1c(ResNet):
    """ResNetV1c variant described in [1]_.

    Identical to the default ResNet (ResNetV1b) except for the input stem,
    where the single 7x7 conv is replaced by three 3x3 convs.

    References:
        .. [1] https://arxiv.org/pdf/1812.01187.pdf
    """

    def __init__(self, **kwargs):
        # Deep stem on, average-pool downsampling off; everything else is
        # taken from the caller.
        variant_flags = dict(deep_stem=True, avg_down=False)
        super().__init__(**variant_flags, **kwargs)
@BACKBONES.register_module()
class ResNetV1d(ResNet):
    """ResNetV1d variant.

    On top of the default ResNet (ResNetV1b), the 7x7 conv of the input stem
    is replaced by three 3x3 convs, and in the downsampling block a 2x2
    avg_pool with stride 2 is placed before the conv, whose own stride
    becomes 1.
    """

    def __init__(self, **kwargs):
        # Deep stem and average-pool downsampling are both switched on.
        variant_flags = dict(deep_stem=True, avg_down=True)
        super().__init__(**variant_flags, **kwargs)
import math
from annotator.uniformer.mmcv.cnn import build_conv_layer, build_norm_layer
from ..builder import BACKBONES
from ..utils import ResLayer
from .resnet import Bottleneck as _Bottleneck
from .resnet import ResNet
class Bottleneck(_Bottleneck):
    """Bottleneck block for ResNeXt.

    With style "pytorch" the stride-two layer is the 3x3 conv layer; with
    "caffe" it is the first 1x1 conv layer.

    The parent constructor runs first; conv1/conv2/conv3 and the three norm
    layers are then rebuilt here with the grouped width.

    Args:
        inplanes (int): Forwarded to the parent block.
        planes (int): Forwarded to the parent block.
        groups (int): Groups of the 3x3 conv. Default: 1.
        base_width (int): Width per group, relative to ``base_channels``.
            Default: 4.
        base_channels (int): Reference channel count used to scale
            ``base_width``. Default: 64.
        **kwargs: Remaining arguments of the parent block.
    """

    def __init__(self,
                 inplanes,
                 planes,
                 groups=1,
                 base_width=4,
                 base_channels=64,
                 **kwargs):
        super().__init__(inplanes, planes, **kwargs)

        # Width of the 1x1 -> 3x3 -> 1x1 middle section.
        if groups == 1:
            mid_channels = self.planes
        else:
            per_group = math.floor(self.planes * (base_width / base_channels))
            mid_channels = per_group * groups
        out_channels = self.planes * self.expansion

        self.norm1_name, norm1 = build_norm_layer(
            self.norm_cfg, mid_channels, postfix=1)
        self.norm2_name, norm2 = build_norm_layer(
            self.norm_cfg, mid_channels, postfix=2)
        self.norm3_name, norm3 = build_norm_layer(
            self.norm_cfg, out_channels, postfix=3)

        self.conv1 = build_conv_layer(
            self.conv_cfg,
            self.inplanes,
            mid_channels,
            kernel_size=1,
            stride=self.conv1_stride,
            bias=False)
        self.add_module(self.norm1_name, norm1)

        fallback_on_stride = False
        self.with_modulated_dcn = False
        if self.with_dcn:
            # NOTE(review): pop() mutates self.dcn in place; whether that
            # dict is shared with other blocks depends on the parent class.
            fallback_on_stride = self.dcn.pop('fallback_on_stride', False)

        # Both conv2 variants share every argument except the config.
        conv2_kwargs = dict(
            kernel_size=3,
            stride=self.conv2_stride,
            padding=self.dilation,
            dilation=self.dilation,
            groups=groups,
            bias=False)
        use_plain_conv = (not self.with_dcn) or fallback_on_stride
        if use_plain_conv:
            self.conv2 = build_conv_layer(
                self.conv_cfg, mid_channels, mid_channels, **conv2_kwargs)
        else:
            assert self.conv_cfg is None, 'conv_cfg must be None for DCN'
            self.conv2 = build_conv_layer(
                self.dcn, mid_channels, mid_channels, **conv2_kwargs)
        self.add_module(self.norm2_name, norm2)

        self.conv3 = build_conv_layer(
            self.conv_cfg,
            mid_channels,
            out_channels,
            kernel_size=1,
            bias=False)
        self.add_module(self.norm3_name, norm3)
@BACKBONES.register_module()
class ResNeXt(ResNet):
    """ResNeXt backbone.

    Args:
        depth (int): Depth of the network, one of the keys of
            ``arch_settings``: {50, 101, 152}.
        in_channels (int): Number of input image channels. Normally 3.
        num_stages (int): Resnet stages, normally 4.
        groups (int): Group of resnext.
        base_width (int): Base width of resnext.
        strides (Sequence[int]): Strides of the first block of each stage.
        dilations (Sequence[int]): Dilation of each stage.
        out_indices (Sequence[int]): Output from which stages.
        style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
            layer is the 3x3 conv layer, otherwise the stride-two layer is
            the first 1x1 conv layer.
        frozen_stages (int): Stages to be frozen (all param fixed). -1 means
            not freezing any parameters.
        norm_cfg (dict): dictionary to construct and config norm layer.
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed.
        zero_init_residual (bool): whether to use zero init for last norm layer
            in resblocks to let them behave as identity.

    Example:
        >>> from annotator.uniformer.mmseg.models import ResNeXt
        >>> import torch
        >>> self = ResNeXt(depth=50)
        >>> _ = self.eval()
        >>> inputs = torch.rand(1, 3, 32, 32)
        >>> level_outputs = self.forward(inputs)
        >>> for level_out in level_outputs:
        ...     print(tuple(level_out.shape))
        (1, 256, 8, 8)
        (1, 512, 4, 4)
        (1, 1024, 2, 2)
        (1, 2048, 1, 1)
    """

    # depth -> (block class, number of blocks per stage)
    arch_settings = {
        50: (Bottleneck, (3, 4, 6, 3)),
        101: (Bottleneck, (3, 4, 23, 3)),
        152: (Bottleneck, (3, 8, 36, 3))
    }

    def __init__(self, groups=1, base_width=4, **kwargs):
        # Must be set before the parent constructor runs, because that
        # constructor calls make_res_layer(), which reads both attributes.
        self.groups = groups
        self.base_width = base_width
        super().__init__(**kwargs)

    def make_res_layer(self, **kwargs):
        """Pack all blocks in a stage into a ``ResLayer``"""
        resnext_args = dict(
            groups=self.groups,
            base_width=self.base_width,
            base_channels=self.base_channels)
        return ResLayer(**resnext_args, **kwargs)
import torch.nn as nn
import torch.utils.checkpoint as cp
from annotator.uniformer.mmcv.cnn import (UPSAMPLE_LAYERS, ConvModule, build_activation_layer,
build_norm_layer, constant_init, kaiming_init)
from annotator.uniformer.mmcv.runner import load_checkpoint
from annotator.uniformer.mmcv.utils.parrots_wrapper import _BatchNorm
from annotator.uniformer.mmseg.utils import get_root_logger
from ..builder import BACKBONES
from ..utils import UpConvBlock
class BasicConvBlock(nn.Module):
    """Basic convolutional block for UNet.

    A plain stack of ``num_convs`` 3x3 ``ConvModule`` layers.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        num_convs (int): Number of convolutional layers. Default: 2.
        stride (int): Stride of the first convolutional layer; all following
            layers use stride 1. Options are 1 or 2. Default: 1.
        dilation (int): Dilation rate (and padding) of every layer except
            the first one, which always uses dilation 1 and padding 1.
            Default: 1.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.
        conv_cfg (dict | None): Config dict for convolution layer.
            Default: None.
        norm_cfg (dict | None): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict | None): Config dict for activation layer in ConvModule.
            Default: dict(type='ReLU').
        dcn (bool): Use deformable convolution in convolutional layer or not.
            Must be None (not implemented). Default: None.
        plugins (dict): plugins for convolutional layers. Must be None
            (not implemented). Default: None.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 num_convs=2,
                 stride=1,
                 dilation=1,
                 with_cp=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 dcn=None,
                 plugins=None):
        super().__init__()
        assert dcn is None, 'Not implemented yet.'
        assert plugins is None, 'Not implemented yet.'

        self.with_cp = with_cp

        def _build(index):
            # Only the first layer changes channels / strides; only the
            # later layers are dilated.
            leading = index == 0
            return ConvModule(
                in_channels=in_channels if leading else out_channels,
                out_channels=out_channels,
                kernel_size=3,
                stride=stride if leading else 1,
                dilation=1 if leading else dilation,
                padding=1 if leading else dilation,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg)

        self.convs = nn.Sequential(*[_build(k) for k in range(num_convs)])

    def forward(self, x):
        """Forward function."""
        if self.with_cp and x.requires_grad:
            return cp.checkpoint(self.convs, x)
        return self.convs(x)
@UPSAMPLE_LAYERS.register_module()
class DeconvModule(nn.Module):
    """Deconvolution upsample module in decoder for UNet (2X upsample).

    Upsamples the feature map with a transposed convolution followed by a
    normalization layer and an activation.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.
        norm_cfg (dict | None): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict | None): Config dict for activation layer in ConvModule.
            Default: dict(type='ReLU').
        kernel_size (int): Kernel size of the convolutional layer. Default: 4.
        scale_factor (int): Used as the stride of the transposed convolution.
            ``kernel_size - scale_factor`` must be non-negative and even.
            Default: 2.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 with_cp=False,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 *,
                 kernel_size=4,
                 scale_factor=2):
        super().__init__()

        margin = kernel_size - scale_factor
        assert margin >= 0 and margin % 2 == 0, (
            f'kernel_size should be greater than or equal to scale_factor '
            f'and (kernel_size - scale_factor) should be even numbers, '
            f'while the kernel size is {kernel_size} and scale_factor is '
            f'{scale_factor}.')

        self.with_cp = with_cp
        deconv = nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=scale_factor,
            # Half of the kernel/stride margin on each side.
            padding=margin // 2)

        # build_norm_layer returns (name, layer); only the layer is needed.
        _, norm = build_norm_layer(norm_cfg, out_channels)
        activate = build_activation_layer(act_cfg)
        self.deconv_upsamping = nn.Sequential(deconv, norm, activate)

    def forward(self, x):
        """Forward function."""
        if self.with_cp and x.requires_grad:
            return cp.checkpoint(self.deconv_upsamping, x)
        return self.deconv_upsamping(x)
@UPSAMPLE_LAYERS.register_module()
class InterpConv(nn.Module):
    """Interpolation upsample module in decoder for UNet.

    Combines one interpolation upsample layer with one convolutional layer.
    The order is upsample -> conv when ``conv_first=False`` and
    conv -> upsample when ``conv_first=True``.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.
        norm_cfg (dict | None): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict | None): Config dict for activation layer in ConvModule.
            Default: dict(type='ReLU').
        conv_cfg (dict | None): Config dict for convolution layer.
            Default: None.
        conv_first (bool): Whether convolutional layer or interpolation
            upsample layer first. Default: False. It means interpolation
            upsample layer followed by one convolutional layer.
        kernel_size (int): Kernel size of the convolutional layer. Default: 1.
        stride (int): Stride of the convolutional layer. Default: 1.
        padding (int): Padding of the convolutional layer. Default: 0.
        upsample_cfg (dict): Interpolation config of the upsample layer.
            Default: dict(
                scale_factor=2, mode='bilinear', align_corners=False).
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 with_cp=False,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 *,
                 conv_cfg=None,
                 conv_first=False,
                 kernel_size=1,
                 stride=1,
                 padding=0,
                 upsample_cfg=dict(
                     scale_factor=2, mode='bilinear', align_corners=False)):
        super().__init__()

        self.with_cp = with_cp
        conv = ConvModule(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)
        upsample = nn.Upsample(**upsample_cfg)

        ordered = (conv, upsample) if conv_first else (upsample, conv)
        self.interp_upsample = nn.Sequential(*ordered)

    def forward(self, x):
        """Forward function."""
        if self.with_cp and x.requires_grad:
            return cp.checkpoint(self.interp_upsample, x)
        return self.interp_upsample(x)
@BACKBONES.register_module()
class UNet(nn.Module):
    """UNet backbone.
    U-Net: Convolutional Networks for Biomedical Image Segmentation.
    https://arxiv.org/pdf/1505.04597.pdf

    Args:
        in_channels (int): Number of input image channels. Default: 3.
        base_channels (int): Number of base channels of each stage.
            The output channels of the first stage. Default: 64.
        num_stages (int): Number of stages in encoder, normally 5. Default: 5.
        strides (Sequence[int 1 | 2]): Strides of each stage in encoder.
            len(strides) is equal to num_stages. Normally the stride of the
            first stage in encoder is 1. If strides[i]=2, it uses stride
            convolution to downsample in the correspondence encoder stage.
            Default: (1, 1, 1, 1, 1).
        enc_num_convs (Sequence[int]): Number of convolutional layers in the
            convolution block of the correspondence encoder stage.
            Default: (2, 2, 2, 2, 2).
        dec_num_convs (Sequence[int]): Number of convolutional layers in the
            convolution block of the correspondence decoder stage.
            Default: (2, 2, 2, 2).
        downsamples (Sequence[int]): Whether use MaxPool to downsample the
            feature map after the first stage of encoder
            (stages: [1, num_stages)). If the correspondence encoder stage use
            stride convolution (strides[i]=2), it will never use MaxPool to
            downsample, even downsamples[i-1]=True.
            Default: (True, True, True, True).
        enc_dilations (Sequence[int]): Dilation rate of each stage in encoder.
            Default: (1, 1, 1, 1, 1).
        dec_dilations (Sequence[int]): Dilation rate of each stage in decoder.
            Default: (1, 1, 1, 1).
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.
        conv_cfg (dict | None): Config dict for convolution layer.
            Default: None.
        norm_cfg (dict | None): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict | None): Config dict for activation layer in ConvModule.
            Default: dict(type='ReLU').
        upsample_cfg (dict): The upsample config of the upsample module in
            decoder. Default: dict(type='InterpConv').
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Default: False.
        dcn (bool): Use deformable convolution in convolutional layer or not.
            Default: None.
        plugins (dict): plugins for convolutional layers. Default: None.

    Notice:
        The input image size should be divisible by the whole downsample rate
        of the encoder. More detail of the whole downsample rate can be found
        in UNet._check_input_divisible.
    """

    def __init__(self,
                 in_channels=3,
                 base_channels=64,
                 num_stages=5,
                 strides=(1, 1, 1, 1, 1),
                 enc_num_convs=(2, 2, 2, 2, 2),
                 dec_num_convs=(2, 2, 2, 2),
                 downsamples=(True, True, True, True),
                 enc_dilations=(1, 1, 1, 1, 1),
                 dec_dilations=(1, 1, 1, 1),
                 with_cp=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 upsample_cfg=dict(type='InterpConv'),
                 norm_eval=False,
                 dcn=None,
                 plugins=None):
        super(UNet, self).__init__()
        assert dcn is None, 'Not implemented yet.'
        assert plugins is None, 'Not implemented yet.'
        assert len(strides) == num_stages, \
            'The length of strides should be equal to num_stages, '\
            f'while the strides is {strides}, the length of '\
            f'strides is {len(strides)}, and the num_stages is '\
            f'{num_stages}.'
        assert len(enc_num_convs) == num_stages, \
            'The length of enc_num_convs should be equal to num_stages, '\
            f'while the enc_num_convs is {enc_num_convs}, the length of '\
            f'enc_num_convs is {len(enc_num_convs)}, and the num_stages is '\
            f'{num_stages}.'
        assert len(dec_num_convs) == (num_stages-1), \
            'The length of dec_num_convs should be equal to (num_stages-1), '\
            f'while the dec_num_convs is {dec_num_convs}, the length of '\
            f'dec_num_convs is {len(dec_num_convs)}, and the num_stages is '\
            f'{num_stages}.'
        assert len(downsamples) == (num_stages-1), \
            'The length of downsamples should be equal to (num_stages-1), '\
            f'while the downsamples is {downsamples}, the length of '\
            f'downsamples is {len(downsamples)}, and the num_stages is '\
            f'{num_stages}.'
        assert len(enc_dilations) == num_stages, \
            'The length of enc_dilations should be equal to num_stages, '\
            f'while the enc_dilations is {enc_dilations}, the length of '\
            f'enc_dilations is {len(enc_dilations)}, and the num_stages is '\
            f'{num_stages}.'
        assert len(dec_dilations) == (num_stages-1), \
            'The length of dec_dilations should be equal to (num_stages-1), '\
            f'while the dec_dilations is {dec_dilations}, the length of '\
            f'dec_dilations is {len(dec_dilations)}, and the num_stages is '\
            f'{num_stages}.'
        self.num_stages = num_stages
        self.strides = strides
        self.downsamples = downsamples
        self.norm_eval = norm_eval
        self.base_channels = base_channels

        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()

        for i in range(num_stages):
            enc_conv_block = []
            if i != 0:
                # MaxPool is only used when this stage does not already
                # downsample through a strided convolution.
                if strides[i] == 1 and downsamples[i - 1]:
                    enc_conv_block.append(nn.MaxPool2d(kernel_size=2))
                # The matching decoder stage upsamples only if this encoder
                # stage reduced the resolution in either way.
                upsample = (strides[i] != 1 or downsamples[i - 1])
                self.decoder.append(
                    UpConvBlock(
                        conv_block=BasicConvBlock,
                        in_channels=base_channels * 2**i,
                        skip_channels=base_channels * 2**(i - 1),
                        out_channels=base_channels * 2**(i - 1),
                        num_convs=dec_num_convs[i - 1],
                        stride=1,
                        dilation=dec_dilations[i - 1],
                        with_cp=with_cp,
                        conv_cfg=conv_cfg,
                        norm_cfg=norm_cfg,
                        act_cfg=act_cfg,
                        upsample_cfg=upsample_cfg if upsample else None,
                        dcn=None,
                        plugins=None))

            enc_conv_block.append(
                BasicConvBlock(
                    in_channels=in_channels,
                    out_channels=base_channels * 2**i,
                    num_convs=enc_num_convs[i],
                    stride=strides[i],
                    dilation=enc_dilations[i],
                    with_cp=with_cp,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg,
                    dcn=None,
                    plugins=None))
            self.encoder.append((nn.Sequential(*enc_conv_block)))
            # The next encoder stage consumes this stage's output channels.
            in_channels = base_channels * 2**i

    def forward(self, x):
        """Run the encoder, then the decoder from the deepest stage upwards.

        Returns:
            list: The last encoder output followed by the output of each
            decoder stage, deepest first.
        """
        self._check_input_divisible(x)
        enc_outs = []
        for enc in self.encoder:
            x = enc(x)
            enc_outs.append(x)
        dec_outs = [x]
        # decoder[i] fuses the running feature with the skip enc_outs[i].
        for i in reversed(range(len(self.decoder))):
            x = self.decoder[i](enc_outs[i], x)
            dec_outs.append(x)

        return dec_outs

    def train(self, mode=True):
        """Convert the model into training mode while keep normalization layer
        freezed.

        Returns:
            UNet: ``self``, as ``torch.nn.Module.train`` does, so that
            ``model.train()`` and ``model.eval()`` can be chained or
            assigned.
        """
        super(UNet, self).train(mode)
        if mode and self.norm_eval:
            for m in self.modules():
                # trick: eval have effect on BatchNorm only
                if isinstance(m, _BatchNorm):
                    m.eval()
        return self

    def _check_input_divisible(self, x):
        """Assert that the input height and width are divisible by the whole
        downsample rate of the encoder (a factor of 2 for every stage after
        the first that has stride 2 or uses MaxPool)."""
        h, w = x.shape[-2:]
        whole_downsample_rate = 1
        for i in range(1, self.num_stages):
            if self.strides[i] == 2 or self.downsamples[i - 1]:
                whole_downsample_rate *= 2
        assert (h % whole_downsample_rate == 0) \
            and (w % whole_downsample_rate == 0),\
            f'The input image size {(h, w)} should be divisible by the whole '\
            f'downsample rate {whole_downsample_rate}, when num_stages is '\
            f'{self.num_stages}, strides is {self.strides}, and downsamples '\
            f'is {self.downsamples}.'

    def init_weights(self, pretrained=None):
        """Initialize the weights in backbone.

        Args:
            pretrained (str, optional): Path to pre-trained weights.
                Defaults to None.

        Raises:
            TypeError: If ``pretrained`` is neither a str nor None.
        """
        if isinstance(pretrained, str):
            logger = get_root_logger()
            load_checkpoint(self, pretrained, strict=False, logger=logger)
        elif pretrained is None:
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    kaiming_init(m)
                elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
                    constant_init(m, 1)
        else:
            raise TypeError('pretrained must be a str or None')
# --------------------------------------------------------
# UniFormer
# Copyright (c) 2022 SenseTime X-Lab
# Licensed under The MIT License [see LICENSE for details]
# Written by Kunchang Li
# --------------------------------------------------------
from collections import OrderedDict
import math
from functools import partial
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
import numpy as np
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from annotator.uniformer.mmcv_custom import load_checkpoint
from annotator.uniformer.mmseg.utils import get_root_logger
from ..builder import BACKBONES
class Mlp(nn.Module):
    """Feed-forward block: Linear -> activation -> dropout -> Linear -> dropout.

    ``hidden_features`` and ``out_features`` fall back to ``in_features``
    when they are not given (or are otherwise falsy). The same dropout
    module is applied after both linear layers.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_dim = out_features if out_features else in_features
        hidden_dim = hidden_features if hidden_features else in_features
        self.fc1 = nn.Linear(in_features, hidden_dim)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_dim, out_dim)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
class CMlp(nn.Module):
    """Convolutional feed-forward block built from two 1x1 ``Conv2d`` layers.

    Order: conv -> activation -> dropout -> conv -> dropout.
    ``hidden_features`` and ``out_features`` fall back to ``in_features``
    when they are not given (or are otherwise falsy).
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_dim = out_features if out_features else in_features
        hidden_dim = hidden_features if hidden_features else in_features
        self.fc1 = nn.Conv2d(in_features, hidden_dim, 1)
        self.act = act_layer()
        self.fc2 = nn.Conv2d(hidden_dim, out_dim, 1)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
class CBlock(nn.Module):
    """Convolutional block operating on (B, C, H, W) feature maps.

    Three residual steps: a depthwise 3x3 positional conv, a
    BN -> 1x1 conv -> depthwise 5x5 conv -> 1x1 conv branch, and a
    BN -> ``CMlp`` branch. The last two go through ``drop_path``.

    ``num_heads``, ``qkv_bias``, ``qk_scale``, ``attn_drop`` and
    ``norm_layer`` are accepted but not used by this block.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.pos_embed = nn.Conv2d(dim, dim, 3, padding=1, groups=dim)
        self.norm1 = nn.BatchNorm2d(dim)
        self.conv1 = nn.Conv2d(dim, dim, 1)
        self.conv2 = nn.Conv2d(dim, dim, 1)
        self.attn = nn.Conv2d(dim, dim, 5, padding=2, groups=dim)
        # Stochastic depth on the residual branches; a no-op when the rate
        # is zero.
        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        self.norm2 = nn.BatchNorm2d(dim)
        hidden_width = int(dim * mlp_ratio)
        self.mlp = CMlp(in_features=dim, hidden_features=hidden_width, act_layer=act_layer, drop=drop)

    def forward(self, x):
        x = x + self.pos_embed(x)

        local = self.conv1(self.norm1(x))
        local = self.conv2(self.attn(local))
        x = x + self.drop_path(local)

        ffn = self.mlp(self.norm2(x))
        x = x + self.drop_path(ffn)
        return x
class Attention(nn.Module):
    """Multi-head self-attention over a (B, N, C) token sequence.

    A single linear layer produces q, k and v. The attention weights are
    ``softmax(q @ k^T * scale)``, with ``scale`` defaulting to
    ``head_dim ** -0.5`` unless ``qk_scale`` is given (and truthy).
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        per_head = dim // num_heads
        # NOTE: qk_scale can be set manually to stay compatible with weights
        # trained under a different scale.
        self.scale = qk_scale or per_head ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        batch, tokens, channels = x.shape
        heads = self.num_heads

        # (B, N, 3C) -> (3, B, heads, N, C // heads)
        packed = self.qkv(x).reshape(batch, tokens, 3, heads, channels // heads)
        packed = packed.permute(2, 0, 3, 1, 4)
        # Index instead of tuple-unpacking the tensor (torchscript-friendly).
        query, key, value = packed[0], packed[1], packed[2]

        weights = (query @ key.transpose(-2, -1)) * self.scale
        weights = self.attn_drop(weights.softmax(dim=-1))

        # Merge the heads back into the channel dimension.
        merged = (weights @ value).transpose(1, 2).reshape(batch, tokens, channels)
        return self.proj_drop(self.proj(merged))
class SABlock(nn.Module):
    """Self-attention block operating on 4-D feature maps.

    Adds a depthwise 3x3 positional conv, flattens the two trailing
    dimensions into a token axis, applies pre-norm attention and pre-norm
    MLP residual branches (both through ``drop_path``), and restores the
    input shape.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.pos_embed = nn.Conv2d(dim, dim, 3, padding=1, groups=dim)
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
            attn_drop=attn_drop, proj_drop=drop)
        # Stochastic depth on the residual branches; a no-op when the rate
        # is zero.
        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        self.norm2 = norm_layer(dim)
        hidden_width = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=hidden_width, act_layer=act_layer, drop=drop)

    def forward(self, x):
        x = x + self.pos_embed(x)
        batch, channels, height, width = x.shape

        # (B, C, H, W) -> (B, H*W, C)
        tokens = x.flatten(2).transpose(1, 2)
        tokens = tokens + self.drop_path(self.attn(self.norm1(tokens)))
        tokens = tokens + self.drop_path(self.mlp(self.norm2(tokens)))

        # (B, H*W, C) -> (B, C, H, W)
        return tokens.transpose(1, 2).reshape(batch, channels, height, width)
def window_partition(x, window_size):
    """Cut a channel-last feature map into non-overlapping square windows.

    Args:
        x: tensor of shape (B, H, W, C); H and W must be multiples of
            ``window_size``.
        window_size (int): side length of each window.

    Returns:
        windows: tensor of shape (num_windows*B, window_size, window_size, C),
        ordered by batch, then window row, then window column.
    """
    batch, height, width, channels = x.shape
    rows = height // window_size
    cols = width // window_size
    tiled = x.view(batch, rows, window_size, cols, window_size, channels)
    # Bring the two window-grid axes next to each other before flattening.
    tiled = tiled.permute(0, 1, 3, 2, 4, 5).contiguous()
    return tiled.view(-1, window_size, window_size, channels)
def window_reverse(windows, window_size, H, W):
    """Reassemble windows into a channel-last feature map.

    Args:
        windows: tensor of shape (num_windows*B, window_size, window_size, C),
            ordered by batch, then window row, then window column.
        window_size (int): side length of each window.
        H (int): height of the reassembled map (multiple of ``window_size``).
        W (int): width of the reassembled map (multiple of ``window_size``).

    Returns:
        x: tensor of shape (B, H, W, C).
    """
    rows = H // window_size
    cols = W // window_size
    # Exact integer arithmetic: the batch size is a count, so avoid the
    # round trip through floating-point division.
    B = windows.shape[0] // (rows * cols)
    x = windows.view(B, rows, cols, window_size, window_size, -1)
    # Interleave grid and in-window axes back into (H, W).
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
class SABlock_Windows(nn.Module):
    """Self-attention block that attends inside local square windows.

    The feature map is moved to channel-last layout, normalised, zero-padded
    at the right/bottom up to a multiple of ``window_size``, split into
    windows, attended per window, stitched back together and cropped to the
    original size. An MLP residual update follows.

    Args:
        dim (int): channel count of the feature map.
        num_heads (int): number of attention heads.
        window_size (int): side length of the attention windows.
        mlp_ratio (float): hidden width of the MLP as a multiple of ``dim``.
        qkv_bias (bool): whether the q/k/v projection has a bias.
        qk_scale (float | None): override for the attention logit scale.
        drop (float): dropout rate for the attention projection and the MLP.
        attn_drop (float): dropout rate on the attention weights.
        drop_path (float): stochastic-depth rate; ``0`` disables it.
        act_layer: activation class handed to ``Mlp``.
        norm_layer: callable building a normalisation layer from ``dim``.
    """

    def __init__(self, dim, num_heads, window_size=14, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.window_size = window_size
        self.pos_embed = nn.Conv2d(dim, dim, 3, padding=1, groups=dim)
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop)
        # Stochastic depth is only instantiated when a positive rate is given.
        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=drop)

    def forward(self, x):
        """Return a tensor with the same (B, C, H, W) shape as ``x``."""
        ws = self.window_size
        x = x + self.pos_embed(x)
        x = x.permute(0, 2, 3, 1)  # (B, C, H, W) -> (B, H, W, C)
        B, H, W, C = x.shape
        shortcut = x
        x = self.norm1(x)

        # Pad only on the right and bottom so H and W become multiples of ws.
        pad_r = (ws - W % ws) % ws
        pad_b = (ws - H % ws) % ws
        x = F.pad(x, (0, 0, 0, pad_r, 0, pad_b))
        Hp, Wp = x.shape[1], x.shape[2]

        # Attention is computed independently inside every window.
        tokens = window_partition(x, ws).view(-1, ws * ws, C)
        attended = self.attn(tokens).view(-1, ws, ws, C)
        x = window_reverse(attended, ws, Hp, Wp)  # (B, Hp, Wp, C)

        # Crop away the rows/columns that were introduced by the padding.
        if pad_r > 0 or pad_b > 0:
            x = x[:, :H, :W, :].contiguous()

        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x.permute(0, 3, 1, 2).reshape(B, C, H, W)
class PatchEmbed(nn.Module):
    """Image to patch embedding that keeps the 2-D layout.

    A strided convolution (kernel == stride == ``patch_size``) projects the
    input to ``embed_dim`` channels; LayerNorm is applied over the channel
    dimension and the result is returned as a (B, embed_dim, H', W') map.

    Args:
        img_size (int | tuple): nominal input size; only used to compute
            ``num_patches``.
        patch_size (int | tuple): kernel size and stride of the projection.
        in_chans (int): number of input channels.
        embed_dim (int): number of output channels.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
        self.norm = nn.LayerNorm(embed_dim)
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        """Project ``x`` to patches and normalise; output is (B, embed_dim, H', W')."""
        x = self.proj(x)
        # Only the post-projection shape is needed; the input shape is never
        # used, so it is not unpacked.
        B, _, H, W = x.shape
        # LayerNorm acts on the last dim, so go channel-last for the norm.
        x = self.norm(x.flatten(2).transpose(1, 2))
        return x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
@BACKBONES.register_module()
class UniFormer(nn.Module):
    """UniFormer backbone producing one feature map per stage.

    The network has four stages. Each stage is a ``PatchEmbed`` followed by a
    stack of blocks: ``CBlock`` in stages 1 and 2, ``SABlock`` and/or
    ``SABlock_Windows`` in stage 3 (selected by ``windows`` / ``hybrid``) and
    ``SABlock`` in stage 4. ``forward`` returns a tuple with the four
    normalised stage outputs in channel-first layout.
    """

    def __init__(self, layers=[3, 4, 8, 3], img_size=224, in_chans=3, num_classes=80, embed_dim=[64, 128, 320, 512],
                 head_dim=64, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=partial(nn.LayerNorm, eps=1e-6),
                 pretrained_path=None, use_checkpoint=False, checkpoint_num=[0, 0, 0, 0],
                 windows=False, hybrid=False, window_size=14):
        """
        Args:
            layers (list): number of blocks in each of the four stages
            img_size (int, tuple): input image size
            in_chans (int): number of input channels
            num_classes (int): number of classes; stored and used by
                ``reset_classifier`` only
            embed_dim (list): embedding dimension of each stage
            head_dim (int): dimension of attention heads
            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
            qkv_bias (bool): enable bias for qkv if True
            qk_scale (float): override default qk scale of head_dim ** -0.5 if set
            representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
            drop_rate (float): dropout rate
            attn_drop_rate (float): attention dropout rate
            drop_path_rate (float): stochastic depth rate
            norm_layer (nn.Module): normalization layer
            pretrained_path (str): path of pretrained model
            use_checkpoint (bool): whether use checkpoint
            checkpoint_num (list): per stage, how many leading blocks run
                under activation checkpointing
            windows (bool): use window attention for every stage-3 block
            hybrid (bool): use window attention in stage 3 except for every
                fourth block, which attends globally
            window_size (int): window size for window attention in stage 3
        """
        super().__init__()
        self.num_classes = num_classes
        self.use_checkpoint = use_checkpoint
        self.checkpoint_num = checkpoint_num
        self.windows = windows
        print(f'Use Checkpoint: {self.use_checkpoint}')
        print(f'Checkpoint Number: {self.checkpoint_num}')
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)

        def _scaled_size(factor):
            # img_size is documented as int or tuple; plain ``//`` only works
            # for the int form, so scale tuples element-wise.
            if isinstance(img_size, int):
                return img_size // factor
            return tuple(side // factor for side in img_size)

        self.patch_embed1 = PatchEmbed(
            img_size=img_size, patch_size=4, in_chans=in_chans, embed_dim=embed_dim[0])
        self.patch_embed2 = PatchEmbed(
            img_size=_scaled_size(4), patch_size=2, in_chans=embed_dim[0], embed_dim=embed_dim[1])
        self.patch_embed3 = PatchEmbed(
            img_size=_scaled_size(8), patch_size=2, in_chans=embed_dim[1], embed_dim=embed_dim[2])
        self.patch_embed4 = PatchEmbed(
            img_size=_scaled_size(16), patch_size=2, in_chans=embed_dim[2], embed_dim=embed_dim[3])

        self.pos_drop = nn.Dropout(p=drop_rate)
        # Stochastic depth decay rule: the rate grows linearly over all blocks.
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(layers))]
        num_heads = [dim // head_dim for dim in embed_dim]
        # Index into ``dpr`` of the first block of every stage.
        first = [0, layers[0], layers[0] + layers[1], layers[0] + layers[1] + layers[2]]
        shared = dict(mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                      drop=drop_rate, attn_drop=attn_drop_rate, norm_layer=norm_layer)

        self.blocks1 = nn.ModuleList([
            CBlock(dim=embed_dim[0], num_heads=num_heads[0], drop_path=dpr[first[0] + i], **shared)
            for i in range(layers[0])])
        self.norm1 = norm_layer(embed_dim[0])

        self.blocks2 = nn.ModuleList([
            CBlock(dim=embed_dim[1], num_heads=num_heads[1], drop_path=dpr[first[1] + i], **shared)
            for i in range(layers[1])])
        self.norm2 = norm_layer(embed_dim[1])

        # Decide per stage-3 block whether it attends globally or in windows.
        if self.windows:
            print('Use local window for all blocks in stage3')
            attends_globally = [False] * layers[2]
        elif hybrid:
            print('Use hybrid window for blocks in stage3')
            attends_globally = [(i + 1) % 4 == 0 for i in range(layers[2])]
        else:
            print('Use global window for all blocks in stage3')
            attends_globally = [True] * layers[2]
        stage3 = []
        for i, is_global in enumerate(attends_globally):
            if is_global:
                stage3.append(SABlock(
                    dim=embed_dim[2], num_heads=num_heads[2], drop_path=dpr[first[2] + i], **shared))
            else:
                stage3.append(SABlock_Windows(
                    dim=embed_dim[2], num_heads=num_heads[2], window_size=window_size,
                    drop_path=dpr[first[2] + i], **shared))
        self.blocks3 = nn.ModuleList(stage3)
        self.norm3 = norm_layer(embed_dim[2])

        self.blocks4 = nn.ModuleList([
            SABlock(dim=embed_dim[3], num_heads=num_heads[3], drop_path=dpr[first[3] + i], **shared)
            for i in range(layers[3])])
        self.norm4 = norm_layer(embed_dim[3])

        # Representation layer (not applied in forward()).
        if representation_size:
            self.num_features = representation_size
            # embed_dim is a per-stage list; the layer sits on the last stage.
            self.pre_logits = nn.Sequential(OrderedDict([
                ('fc', nn.Linear(embed_dim[3], representation_size)),
                ('act', nn.Tanh())
            ]))
        else:
            self.pre_logits = nn.Identity()

        self.apply(self._init_weights)
        self.init_weights(pretrained=pretrained_path)

    def init_weights(self, pretrained):
        """Load weights non-strictly from ``pretrained`` when it is a path string; otherwise do nothing."""
        if isinstance(pretrained, str):
            logger = get_root_logger()
            load_checkpoint(self, pretrained, map_location='cpu', strict=False, logger=logger)
            print(f'Load pretrained model from {pretrained}')

    def _init_weights(self, m):
        """Per-module initialiser used with ``self.apply``: Linear and LayerNorm only."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Return the parameter names excluded from weight decay."""
        return {'pos_embed', 'cls_token'}

    def get_classifier(self):
        """Return the head installed by ``reset_classifier``, or ``None`` if none was installed.

        The constructor does not create a head, so an unconditional
        ``self.head`` access would raise ``AttributeError``.
        """
        return getattr(self, 'head', None)

    def reset_classifier(self, num_classes, global_pool=''):
        """Install a fresh linear head on the last-stage width (identity when ``num_classes <= 0``)."""
        self.num_classes = num_classes
        # self.embed_dim is a per-stage list; nn.Linear needs a single int.
        self.head = nn.Linear(self.embed_dim[3], num_classes) if num_classes > 0 else nn.Identity()

    def forward_features(self, x):
        """Run the four stages and return a tuple with each stage's normalised output."""
        stages = (
            (self.patch_embed1, self.blocks1, self.norm1),
            (self.patch_embed2, self.blocks2, self.norm2),
            (self.patch_embed3, self.blocks3, self.norm3),
            (self.patch_embed4, self.blocks4, self.norm4),
        )
        out = []
        for stage_idx, (embed, blocks, norm) in enumerate(stages):
            x = embed(x)
            if stage_idx == 0:
                # Dropout is applied once, right after the first embedding.
                x = self.pos_drop(x)
            for i, blk in enumerate(blocks):
                if self.use_checkpoint and i < self.checkpoint_num[stage_idx]:
                    x = checkpoint.checkpoint(blk, x)
                else:
                    x = blk(x)
            # The norm acts on the last dim: go channel-last, normalise, and
            # return to channel-first. The un-normalised x feeds the next stage.
            x_out = norm(x.permute(0, 2, 3, 1))
            out.append(x_out.permute(0, 3, 1, 2).contiguous())
        return tuple(out)

    def forward(self, x):
        """Return the tuple of stage feature maps for input ``x``."""
        return self.forward_features(x)
"""Modified from
https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py.
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as cp
from annotator.uniformer.mmcv.cnn import (Conv2d, Linear, build_activation_layer, build_norm_layer,
constant_init, kaiming_init, normal_init)
from annotator.uniformer.mmcv.runner import _load_checkpoint
from annotator.uniformer.mmcv.utils.parrots_wrapper import _BatchNorm
from annotator.uniformer.mmseg.utils import get_root_logger
from ..builder import BACKBONES
from ..utils import DropPath, trunc_normal_
class Mlp(nn.Module):
    """Two-layer feed-forward network used inside an encoder block.

    Args:
        in_features (int): Input dimension of the first fully connected
            layer.
        hidden_features (int): Output dimension of the first fully connected
            layer. Falls back to ``in_features`` when falsy.
        out_features (int): Output dimension of the second fully connected
            layer. Falls back to ``in_features`` when falsy.
        act_cfg (dict): Config dict for the activation layer.
            Default: dict(type='GELU').
        drop (float): Rate of the dropout layer, which is applied after the
            activation and again after the second layer. Default: 0.
    """

    def __init__(self,
                 in_features,
                 hidden_features=None,
                 out_features=None,
                 act_cfg=dict(type='GELU'),
                 drop=0.):
        super(Mlp, self).__init__()
        if not hidden_features:
            hidden_features = in_features
        if not out_features:
            out_features = in_features
        self.fc1 = Linear(in_features, hidden_features)
        self.act = build_activation_layer(act_cfg)
        self.fc2 = Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        """Apply fc1 -> act -> drop -> fc2 -> drop."""
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
class Attention(nn.Module):
    """Multi-head self-attention layer for the encoder block.

    Args:
        dim (int): Dimension of the input vectors.
        num_heads (int): Number of parallel attention heads.
        qkv_bias (bool): Enable bias for the fused qkv projection if True.
            Default: False.
        qk_scale (float): Override the default qk scale of
            ``head_dim ** -0.5`` if set (a falsy value keeps the default).
        attn_drop (float): Drop rate for the attention weights. Default: 0.
        proj_drop (float): Drop rate for the projected output. Default: 0.
    """

    def __init__(self,
                 dim,
                 num_heads=8,
                 qkv_bias=False,
                 qk_scale=None,
                 attn_drop=0.,
                 proj_drop=0.):
        super(Attention, self).__init__()
        self.num_heads = num_heads
        self.scale = qk_scale or (dim // num_heads)**-0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        """Attend over the tokens of ``x`` (shape ``(b, n, c)``); same shape out."""
        batch, tokens, channels = x.shape
        per_head = channels // self.num_heads
        # (b, n, 3c) -> (3, b, heads, n, per_head)
        packed = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, per_head)
        packed = packed.permute(2, 0, 3, 1, 4)
        query, key, value = packed[0], packed[1], packed[2]
        logits = (query @ key.transpose(-2, -1)) * self.scale
        weights = self.attn_drop(logits.softmax(dim=-1))
        merged = (weights @ value).transpose(1, 2).reshape(batch, tokens, channels)
        return self.proj_drop(self.proj(merged))
class Block(nn.Module):
    """Pre-norm transformer encoder block with two residual connections.

    Args:
        dim (int): The feature dimension.
        num_heads (int): Number of parallel attention heads.
        mlp_ratio (int): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool): Enable bias for the qkv projection. Default: False.
        qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
        drop (float): Drop rate inside the mlp. Default: 0.
        attn_drop (float): Drop rate for attention weights. Default: 0.
        proj_drop (float): Drop rate for the attention output projection.
            Default: 0.
        drop_path (float): Stochastic-depth rate applied to both residual
            branches. Default: 0.
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='GELU').
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='LN', eps=1e-6).
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.
    """

    def __init__(self,
                 dim,
                 num_heads,
                 mlp_ratio=4,
                 qkv_bias=False,
                 qk_scale=None,
                 drop=0.,
                 attn_drop=0.,
                 proj_drop=0.,
                 drop_path=0.,
                 act_cfg=dict(type='GELU'),
                 norm_cfg=dict(type='LN', eps=1e-6),
                 with_cp=False):
        super(Block, self).__init__()
        self.with_cp = with_cp
        # build_norm_layer returns a (name, layer) pair; only the layer is kept.
        self.norm1 = build_norm_layer(norm_cfg, dim)[1]
        self.attn = Attention(dim, num_heads, qkv_bias, qk_scale, attn_drop,
                              proj_drop)
        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        self.norm2 = build_norm_layer(norm_cfg, dim)[1]
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_cfg=act_cfg,
            drop=drop)

    def forward(self, x):
        """Apply the attention and mlp residual updates, optionally under checkpointing."""

        def _residual_updates(tokens):
            tokens = tokens + self.drop_path(self.attn(self.norm1(tokens)))
            return tokens + self.drop_path(self.mlp(self.norm2(tokens)))

        # Checkpointing is only meaningful when gradients flow through x.
        if self.with_cp and x.requires_grad:
            return cp.checkpoint(_residual_updates, x)
        return _residual_updates(x)
class PatchEmbed(nn.Module):
    """Image to patch embedding producing a token sequence.

    Args:
        img_size (int | tuple): Input image size; an int is expanded to a
            square. Default: 224.
        patch_size (int): Width and height of a patch. Default: 16.
        in_channels (int): Input channels for images. Default: 3.
        embed_dim (int): The embedding dimension. Default: 768.

    Raises:
        TypeError: If ``img_size`` is neither an int nor a tuple.
    """

    def __init__(self,
                 img_size=224,
                 patch_size=16,
                 in_channels=3,
                 embed_dim=768):
        super(PatchEmbed, self).__init__()
        if isinstance(img_size, int):
            img_size = (img_size, img_size)
        elif not isinstance(img_size, tuple):
            raise TypeError('img_size must be type of int or tuple')
        self.img_size = img_size
        height, width = img_size
        self.patch_size = (patch_size, patch_size)
        self.num_patches = (height // patch_size) * (width // patch_size)
        # Non-overlapping patches: kernel size equals stride.
        self.proj = Conv2d(
            in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        """Project ``x`` and flatten the spatial dims into a token axis: (B, L, embed_dim)."""
        patches = self.proj(x)
        return patches.flatten(2).transpose(1, 2)
@BACKBONES.register_module()
class VisionTransformer(nn.Module):
    """Vision transformer backbone.

    A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for
    Image Recognition at Scale` - https://arxiv.org/abs/2010.11929

    Args:
        img_size (tuple): input image size. Default: (224, 224).
        patch_size (int): patch size. Default: 16.
        in_channels (int): number of input channels. Default: 3.
        embed_dim (int): embedding dimension. Default: 768.
        depth (int): depth of transformer. Default: 12.
        num_heads (int): number of attention heads. Default: 12.
        mlp_ratio (int): ratio of mlp hidden dim to embedding dim.
            Default: 4.
        out_indices (list | tuple | int): Indices of the blocks whose output
            is returned. Default: 11.
        qkv_bias (bool): enable bias for qkv if True. Default: True.
        qk_scale (float): override default qk scale of head_dim ** -0.5 if set.
        drop_rate (float): dropout rate, used after the position embedding
            and inside every block's mlp. Default: 0.
        attn_drop_rate (float): attention dropout rate. Default: 0.
        drop_path_rate (float): Maximum rate of DropPath; the per-block rate
            grows linearly from 0 to this value. Default: 0.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='LN', eps=1e-6, requires_grad=True).
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='GELU').
        norm_eval (bool): Whether ``train()`` keeps the ``nn.LayerNorm``
            modules in eval mode. Default: False.
        final_norm (bool): Whether to add an additional layer to normalize
            the output of the last block. Default: False.
        interpolate_mode (str): Interpolation mode used when resizing the
            position embedding. Default: bicubic.
        with_cls_token (bool): If False, the class token is removed from the
            sequence before it enters the blocks. Default: True.
        with_cp (bool): Use checkpoint or not. Using checkpoint
            will save some memory while slowing down the training speed.
            Default: False.
    """

    def __init__(self,
                 img_size=(224, 224),
                 patch_size=16,
                 in_channels=3,
                 embed_dim=768,
                 depth=12,
                 num_heads=12,
                 mlp_ratio=4,
                 out_indices=11,
                 qkv_bias=True,
                 qk_scale=None,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 norm_cfg=dict(type='LN', eps=1e-6, requires_grad=True),
                 act_cfg=dict(type='GELU'),
                 norm_eval=False,
                 final_norm=False,
                 with_cls_token=True,
                 interpolate_mode='bicubic',
                 with_cp=False):
        super(VisionTransformer, self).__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.features = self.embed_dim = embed_dim
        self.patch_embed = PatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_channels=in_channels,
            embed_dim=embed_dim)

        self.with_cls_token = with_cls_token
        self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
        # One extra position for the class token.
        self.pos_embed = nn.Parameter(
            torch.zeros(1, self.patch_embed.num_patches + 1, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)

        if isinstance(out_indices, int):
            self.out_indices = [out_indices]
        elif isinstance(out_indices, (list, tuple)):
            self.out_indices = out_indices
        else:
            raise TypeError('out_indices must be type of int, list or tuple')

        # Stochastic depth decay rule.
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
        self.blocks = nn.ModuleList([
            Block(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                # The decay schedule is a DropPath rate, not an mlp dropout
                # rate: route it to drop_path and use drop_rate for dropout.
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[i],
                act_cfg=act_cfg,
                norm_cfg=norm_cfg,
                with_cp=with_cp) for i in range(depth)
        ])

        self.interpolate_mode = interpolate_mode
        self.final_norm = final_norm
        if final_norm:
            _, self.norm = build_norm_layer(norm_cfg, embed_dim)

        self.norm_eval = norm_eval
        self.with_cp = with_cp

    def init_weights(self, pretrained=None):
        """Load a checkpoint or initialise from scratch.

        Args:
            pretrained (str | None): checkpoint path. With a path, the state
                dict is loaded non-strictly and ``pos_embed`` is resized when
                its shape differs. With ``None``, parameters are initialised
                in place.

        Raises:
            TypeError: If ``pretrained`` is neither a str nor None.
        """
        if isinstance(pretrained, str):
            logger = get_root_logger()
            checkpoint = _load_checkpoint(pretrained, logger=logger)
            state_dict = checkpoint.get('state_dict', checkpoint) \
                if 'state_dict' in checkpoint else checkpoint

            if 'pos_embed' in state_dict.keys():
                if self.pos_embed.shape != state_dict['pos_embed'].shape:
                    logger.info(
                        msg='Resize the pos_embed shape from '
                        f'{state_dict["pos_embed"].shape} to '
                        f'{self.pos_embed.shape}')
                    h, w = self.img_size
                    # The checkpoint grid is assumed square; -1 drops the
                    # class-token position.
                    pos_size = int(
                        math.sqrt(state_dict['pos_embed'].shape[1] - 1))
                    state_dict['pos_embed'] = self.resize_pos_embed(
                        state_dict['pos_embed'], (h, w), (pos_size, pos_size),
                        self.patch_size, self.interpolate_mode)

            self.load_state_dict(state_dict, False)

        elif pretrained is None:
            # We only implement the 'jax_impl' initialization implemented at
            # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L353  # noqa: E501
            trunc_normal_(self.pos_embed, std=.02)
            trunc_normal_(self.cls_token, std=.02)
            # Tensors are initialised through torch.nn.init directly.
            # NOTE(review): the mmcv ``*_init`` helpers take a module as first
            # argument (per the mmcv docs — confirm), so handing them
            # ``m.bias`` / ``m.weight`` did not initialise anything.
            for n, m in self.named_modules():
                if isinstance(m, Linear):
                    trunc_normal_(m.weight, std=.02)
                    if m.bias is not None:
                        if 'mlp' in n:
                            nn.init.normal_(m.bias, std=1e-6)
                        else:
                            nn.init.constant_(m.bias, 0)
                elif isinstance(m, Conv2d):
                    nn.init.kaiming_normal_(m.weight, mode='fan_in')
                    if m.bias is not None:
                        nn.init.constant_(m.bias, 0)
                elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)):
                    if m.bias is not None:
                        nn.init.constant_(m.bias, 0)
                    if m.weight is not None:
                        nn.init.constant_(m.weight, 1.0)
        else:
            raise TypeError('pretrained must be a str or None')

    def _pos_embeding(self, img, patched_img, pos_embed):
        """Add the position embedding to the patch tokens.

        The embedding is resized first when its length does not match the
        number of tokens, i.e. when the input size differs from ``img_size``.

        Args:
            img (torch.Tensor): The input image tensor, of shape
                [B, C, H, W].
            patched_img (torch.Tensor): The token sequence, of shape
                [B, L1, C].
            pos_embed (torch.Tensor): The position embedding, of shape
                [B, L2, C].

        Return:
            torch.Tensor: ``pos_drop(patched_img + pos_embed)``.

        Raises:
            ValueError: If ``pos_embed`` needs resizing but its length does
                not correspond to the configured ``img_size``.
        """
        assert patched_img.ndim == 3 and pos_embed.ndim == 3, \
            'the shapes of patched_img and pos_embed must be [B, L, C]'
        x_len, pos_len = patched_img.shape[1], pos_embed.shape[1]
        if x_len != pos_len:
            pos_h = self.img_size[0] // self.patch_size
            pos_w = self.img_size[1] // self.patch_size
            # Only an embedding laid out for img_size (+1 class token) can be
            # reshaped back into a grid.
            if pos_len != pos_h * pos_w + 1:
                raise ValueError(
                    'Unexpected shape of pos_embed, got {}.'.format(
                        pos_embed.shape))
            pos_embed = self.resize_pos_embed(pos_embed, img.shape[2:],
                                              (pos_h, pos_w), self.patch_size,
                                              self.interpolate_mode)
        return self.pos_drop(patched_img + pos_embed)

    @staticmethod
    def resize_pos_embed(pos_embed, input_shpae, pos_shape, patch_size, mode):
        """Resize pos_embed weights by interpolating the patch positions.

        The first position (class token) is kept as is; the last
        ``pos_h * pos_w`` positions are reshaped to a grid, interpolated and
        flattened again.

        Args:
            pos_embed (torch.Tensor): pos_embed weights of shape [B, L, C].
            input_shpae (tuple): Tuple for (input_h, input_w). The misspelt
                name is kept for keyword-argument compatibility.
            pos_shape (tuple): Tuple for (pos_h, pos_w).
            patch_size (int): Patch size.
            mode (str): Interpolation mode passed to ``F.interpolate``.

        Return:
            torch.Tensor: The resized pos_embed of shape [B, L_new, C]
        """
        assert pos_embed.ndim == 3, 'shape of pos_embed must be [B, L, C]'
        input_h, input_w = input_shpae
        pos_h, pos_w = pos_shape
        cls_token_weight = pos_embed[:, 0]
        pos_embed_weight = pos_embed[:, (-1 * pos_h * pos_w):]
        pos_embed_weight = pos_embed_weight.reshape(
            1, pos_h, pos_w, pos_embed.shape[2]).permute(0, 3, 1, 2)
        pos_embed_weight = F.interpolate(
            pos_embed_weight,
            size=[input_h // patch_size, input_w // patch_size],
            align_corners=False,
            mode=mode)
        cls_token_weight = cls_token_weight.unsqueeze(1)
        pos_embed_weight = torch.flatten(pos_embed_weight, 2).transpose(1, 2)
        return torch.cat((cls_token_weight, pos_embed_weight), dim=1)

    def forward(self, inputs):
        """Return a tuple with one [B, C, H', W'] map per index in ``out_indices``."""
        B = inputs.shape[0]

        x = self.patch_embed(inputs)

        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = self._pos_embeding(inputs, x, self.pos_embed)

        if not self.with_cls_token:
            # Remove class token for transformer input
            x = x[:, 1:]

        grid_h = inputs.shape[2] // self.patch_size
        grid_w = inputs.shape[3] // self.patch_size
        last = len(self.blocks) - 1
        outs = []
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if i == last and self.final_norm:
                x = self.norm(x)
            if i in self.out_indices:
                # Remove class token (if present) and reshape the tokens to a
                # feature map for the decoder head.
                out = x[:, 1:] if self.with_cls_token else x
                B, _, C = out.shape
                out = out.reshape(B, grid_h, grid_w, C).permute(0, 3, 1, 2)
                outs.append(out)

        return tuple(outs)

    def train(self, mode=True):
        """Switch train/eval mode; with ``norm_eval`` keep LayerNorm modules in eval mode."""
        super(VisionTransformer, self).train(mode)
        if mode and self.norm_eval:
            for m in self.modules():
                if isinstance(m, nn.LayerNorm):
                    m.eval()
import warnings
from annotator.uniformer.mmcv.cnn import MODELS as MMCV_MODELS
from annotator.uniformer.mmcv.utils import Registry
MODELS = Registry('models', parent=MMCV_MODELS)

# Every component kind shares the single MODELS registry; the separate names
# are aliases of the same object.
BACKBONES = NECKS = HEADS = LOSSES = SEGMENTORS = MODELS
def build_backbone(cfg):
    """Build a backbone from ``cfg`` via the ``BACKBONES`` registry."""
    backbone = BACKBONES.build(cfg)
    return backbone
def build_neck(cfg):
    """Build a neck from ``cfg`` via the ``NECKS`` registry."""
    neck = NECKS.build(cfg)
    return neck
def build_head(cfg):
    """Build a head from ``cfg`` via the ``HEADS`` registry."""
    head = HEADS.build(cfg)
    return head
def build_loss(cfg):
    """Build a loss from ``cfg`` via the ``LOSSES`` registry."""
    loss = LOSSES.build(cfg)
    return loss
def build_segmentor(cfg, train_cfg=None, test_cfg=None):
    """Build a segmentor from ``cfg`` via the ``SEGMENTORS`` registry.

    ``train_cfg`` and ``test_cfg`` are the deprecated way of supplying these
    settings: passing either one emits a ``UserWarning``, and passing one
    that is also present in ``cfg`` fails an assertion. Both are forwarded to
    the registry as default arguments.
    """
    outer_cfgs = (('train_cfg', train_cfg), ('test_cfg', test_cfg))
    if any(value is not None for _, value in outer_cfgs):
        warnings.warn(
            'train_cfg and test_cfg is deprecated, '
            'please specify them in model', UserWarning)
    # A setting may come from the model config or the outer argument, not both.
    for key, outer_value in outer_cfgs:
        assert cfg.get(key) is None or outer_value is None, \
            f'{key} specified in both outer field and model field '
    defaults = dict(train_cfg=train_cfg, test_cfg=test_cfg)
    return SEGMENTORS.build(cfg, default_args=defaults)
from .ann_head import ANNHead
from .apc_head import APCHead
from .aspp_head import ASPPHead
from .cc_head import CCHead
from .da_head import DAHead
from .dm_head import DMHead
from .dnl_head import DNLHead
from .ema_head import EMAHead
from .enc_head import EncHead
from .fcn_head import FCNHead
from .fpn_head import FPNHead
from .gc_head import GCHead
from .lraspp_head import LRASPPHead
from .nl_head import NLHead
from .ocr_head import OCRHead
# from .point_head import PointHead
from .psa_head import PSAHead
from .psp_head import PSPHead
from .sep_aspp_head import DepthwiseSeparableASPPHead
from .sep_fcn_head import DepthwiseSeparableFCNHead
from .uper_head import UPerHead
# Public decode heads re-exported by this package, one entry per line.
__all__ = [
    'FCNHead',
    'PSPHead',
    'ASPPHead',
    'PSAHead',
    'NLHead',
    'GCHead',
    'CCHead',
    'UPerHead',
    'DepthwiseSeparableASPPHead',
    'ANNHead',
    'DAHead',
    'OCRHead',
    'EncHead',
    'DepthwiseSeparableFCNHead',
    'FPNHead',
    'EMAHead',
    'DNLHead',
    'APCHead',
    'DMHead',
    'LRASPPHead',
]
import torch
import torch.nn as nn
from annotator.uniformer.mmcv.cnn import ConvModule
from ..builder import HEADS
from ..utils import SelfAttentionBlock as _SelfAttentionBlock
from .decode_head import BaseDecodeHead
class PPMConcat(nn.ModuleList):
    """Pyramid Pooling Module that only concatenates the pooled features.

    One ``nn.AdaptiveAvgPool2d`` is created per pooling scale.

    Args:
        pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
            Module.
    """

    def __init__(self, pool_scales=(1, 3, 6, 8)):
        pools = [nn.AdaptiveAvgPool2d(scale) for scale in pool_scales]
        super(PPMConcat, self).__init__(pools)

    def forward(self, feats):
        """Pool ``feats`` at every scale, flatten each result spatially and
        concatenate along the last axis: (B, C, sum of scale**2)."""
        batch, channels = feats.shape[:2]
        flattened = [pool(feats).view(batch, channels, -1) for pool in self]
        return torch.cat(flattened, dim=2)
class SelfAttentionBlock(_SelfAttentionBlock):
    """Self-attention block configured for ANN.

    The key feature is down-sampled with a ``PPMConcat`` pyramid and the
    query feature with max pooling when ``query_scale > 1``; the remaining
    options of the base block are fixed here.

    Args:
        low_in_channels (int): Input channels of lower level feature,
            which is the key feature for self-attention.
        high_in_channels (int): Input channels of higher level feature,
            which is the query feature for self-attention.
        channels (int): Output channels of key/query transform.
        out_channels (int): Output channels.
        share_key_query (bool): Whether share projection weight between key
            and query projection.
        query_scale (int): The scale of query feature map.
        key_pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
            Module of key feature.
        conv_cfg (dict|None): Config of conv layers.
        norm_cfg (dict|None): Config of norm layers.
        act_cfg (dict|None): Config of activation layers.
    """

    def __init__(self, low_in_channels, high_in_channels, channels,
                 out_channels, share_key_query, query_scale, key_pool_scales,
                 conv_cfg, norm_cfg, act_cfg):
        key_downsample = PPMConcat(key_pool_scales)
        # A scale of 1 (or less) leaves the query feature at full resolution.
        query_downsample = (
            nn.MaxPool2d(kernel_size=query_scale) if query_scale > 1 else None)
        super(SelfAttentionBlock, self).__init__(
            key_in_channels=low_in_channels,
            query_in_channels=high_in_channels,
            channels=channels,
            out_channels=out_channels,
            share_key_query=share_key_query,
            query_downsample=query_downsample,
            key_downsample=key_downsample,
            key_query_num_convs=1,
            key_query_norm=True,
            value_out_num_convs=1,
            value_out_norm=False,
            matmul_norm=True,
            with_out=True,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)
class AFNB(nn.Module):
    """Asymmetric Fusion Non-local Block(AFNB)

    One ``SelfAttentionBlock`` is built per query scale (without key/query
    weight sharing). Their outputs are summed, concatenated with the high
    level feature and fused by a 1x1 ``ConvModule`` without activation.

    Args:
        low_in_channels (int): Input channels of lower level feature,
            which is the key feature for self-attention.
        high_in_channels (int): Input channels of higher level feature,
            which is the query feature for self-attention.
        channels (int): Output channels of key/query transform.
        out_channels (int): Output channels.
        query_scales (tuple[int]): The scales of query feature map.
        key_pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
            Module of key feature.
        conv_cfg (dict|None): Config of conv layers.
        norm_cfg (dict|None): Config of norm layers.
        act_cfg (dict|None): Config of activation layers.
    """

    def __init__(self, low_in_channels, high_in_channels, channels,
                 out_channels, query_scales, key_pool_scales, conv_cfg,
                 norm_cfg, act_cfg):
        super(AFNB, self).__init__()
        self.stages = nn.ModuleList()
        for scale in query_scales:
            stage = SelfAttentionBlock(
                low_in_channels=low_in_channels,
                high_in_channels=high_in_channels,
                channels=channels,
                out_channels=out_channels,
                share_key_query=False,
                query_scale=scale,
                key_pool_scales=key_pool_scales,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg)
            self.stages.append(stage)
        # Input is the attention context concatenated with the high feature.
        fused_channels = out_channels + high_in_channels
        self.bottleneck = ConvModule(
            fused_channels,
            out_channels,
            1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=None)

    def forward(self, low_feats, high_feats):
        """Forward function."""
        per_scale = [stage(high_feats, low_feats) for stage in self.stages]
        context = torch.stack(per_scale, dim=0).sum(dim=0)
        fused = torch.cat([context, high_feats], 1)
        return self.bottleneck(fused)
class APNB(nn.Module):
    """Asymmetric Pyramid Non-local Block (APNB)

    One ``SelfAttentionBlock`` is built per query scale, with key and query
    sharing their projection and both taken from the same feature. The
    outputs are summed, concatenated with the input feature and fused by a
    1x1 ``ConvModule``.

    Args:
        in_channels (int): Input channels of key/query feature,
            which is the key feature for self-attention.
        channels (int): Output channels of key/query transform.
        out_channels (int): Output channels.
        query_scales (tuple[int]): The scales of query feature map.
        key_pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
            Module of key feature.
        conv_cfg (dict|None): Config of conv layers.
        norm_cfg (dict|None): Config of norm layers.
        act_cfg (dict|None): Config of activation layers.
    """

    def __init__(self, in_channels, channels, out_channels, query_scales,
                 key_pool_scales, conv_cfg, norm_cfg, act_cfg):
        super(APNB, self).__init__()
        self.stages = nn.ModuleList()
        for query_scale in query_scales:
            self.stages.append(
                SelfAttentionBlock(
                    low_in_channels=in_channels,
                    high_in_channels=in_channels,
                    channels=channels,
                    out_channels=out_channels,
                    share_key_query=True,
                    query_scale=query_scale,
                    key_pool_scales=key_pool_scales,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg))
        # forward() concatenates the context (out_channels wide) with the
        # input (in_channels wide). This equals 2 * in_channels only when the
        # two widths coincide, so size the conv from the actual widths, as
        # AFNB does.
        self.bottleneck = ConvModule(
            out_channels + in_channels,
            out_channels,
            1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)

    def forward(self, feats):
        """Forward function."""
        priors = [stage(feats, feats) for stage in self.stages]
        context = torch.stack(priors, dim=0).sum(dim=0)
        output = self.bottleneck(torch.cat([context, feats], 1))
        return output
@HEADS.register_module()
class ANNHead(BaseDecodeHead):
    """Decode head of Asymmetric Non-local Neural Networks (ANNNet).

    Implements the head described in `ANNNet
    <https://arxiv.org/abs/1908.07678>`_: an AFNB fuses two input feature
    levels, a 3x3 bottleneck conv reduces the result, and an APNB refines
    it before classification.

    Args:
        project_channels (int): Projection channels for Nonlocal.
        query_scales (tuple[int]): The scales of query feature map.
            Default: (1,)
        key_pool_scales (tuple[int]): The pooling scales of key feature map.
            Default: (1, 3, 6, 8).
    """

    def __init__(self,
                 project_channels,
                 query_scales=(1, ),
                 key_pool_scales=(1, 3, 6, 8),
                 **kwargs):
        super(ANNHead, self).__init__(
            input_transform='multiple_select', **kwargs)
        # Exactly two feature levels are expected: (low, high).
        assert len(self.in_channels) == 2
        low_in_channels, high_in_channels = self.in_channels
        self.project_channels = project_channels

        # Arguments shared verbatim by the fusion and context blocks.
        nonlocal_kwargs = dict(
            channels=project_channels,
            query_scales=query_scales,
            key_pool_scales=key_pool_scales,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

        self.fusion = AFNB(
            low_in_channels=low_in_channels,
            high_in_channels=high_in_channels,
            out_channels=high_in_channels,
            **nonlocal_kwargs)
        self.bottleneck = ConvModule(
            high_in_channels,
            self.channels,
            3,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        self.context = APNB(
            in_channels=self.channels,
            out_channels=self.channels,
            **nonlocal_kwargs)

    def forward(self, inputs):
        """Run fusion, dropout, bottleneck, context and classification."""
        low_feats, high_feats = self._transform_inputs(inputs)
        fused = self.dropout(self.fusion(low_feats, high_feats))
        refined = self.context(self.bottleneck(fused))
        return self.cls_seg(refined)
import torch
import torch.nn as nn
import torch.nn.functional as F
from annotator.uniformer.mmcv.cnn import ConvModule
from annotator.uniformer.mmseg.ops import resize
from ..builder import HEADS
from .decode_head import BaseDecodeHead
class ACM(nn.Module):
    """Adaptive Context Module used in APCNet.

    The forward pass pools the input to ``pool_scale x pool_scale`` region
    features, predicts a per-pixel affinity to each region, aggregates the
    region features with that affinity and adds the result to the reduced
    input as a residual.

    Args:
        pool_scale (int): Pooling scale used in Adaptive Context
            Module to extract region features.
        fusion (bool): Add one conv to fuse residual feature.
        in_channels (int): Input channels.
        channels (int): Channels after modules, before conv_seg.
        conv_cfg (dict | None): Config of conv layers.
        norm_cfg (dict | None): Config of norm layers.
        act_cfg (dict): Config of activation layers.
    """

    def __init__(self, pool_scale, fusion, in_channels, channels, conv_cfg,
                 norm_cfg, act_cfg):
        super(ACM, self).__init__()
        self.pool_scale = pool_scale
        self.fusion = fusion
        self.in_channels = in_channels
        self.channels = channels
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        # Channel reduction applied to the pooled region features.
        self.pooled_redu_conv = ConvModule(
            self.in_channels,
            self.channels,
            1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        # Channel reduction applied to the full-resolution input.
        self.input_redu_conv = ConvModule(
            self.in_channels,
            self.channels,
            1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        # Transforms the globally pooled (1x1) feature.
        self.global_info = ConvModule(
            self.channels,
            self.channels,
            1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        # Predicts one affinity value per pooled region for every pixel.
        self.gla = nn.Conv2d(self.channels, self.pool_scale**2, 1, 1, 0)
        self.residual_conv = ConvModule(
            self.channels,
            self.channels,
            1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        if self.fusion:
            self.fusion_conv = ConvModule(
                self.channels,
                self.channels,
                1,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg)

    def forward(self, x):
        """Forward function.

        Args:
            x (Tensor): Input feature map with ``in_channels`` channels.

        Returns:
            Tensor: Feature map of shape [batch_size, channels, h, w].
        """
        pooled_x = F.adaptive_avg_pool2d(x, self.pool_scale)
        # [batch_size, channels, h, w]
        x = self.input_redu_conv(x)
        # [batch_size, channels, pool_scale, pool_scale]
        pooled_x = self.pooled_redu_conv(pooled_x)
        batch_size = x.size(0)
        # [batch_size, pool_scale * pool_scale, channels]
        pooled_x = pooled_x.view(batch_size, self.channels,
                                 -1).permute(0, 2, 1).contiguous()
        # Global descriptor of the reduced input, resized back to h x w so
        # it can be added to every spatial position.
        global_context = resize(
            self.global_info(F.adaptive_avg_pool2d(x, 1)), size=x.shape[2:])
        # [batch_size, h * w, pool_scale * pool_scale]
        affinity_matrix = self.gla(x + global_context).permute(
            0, 2, 3, 1).reshape(batch_size, -1, self.pool_scale**2)
        # torch.sigmoid replaces the deprecated F.sigmoid alias; the result
        # is numerically identical.
        affinity_matrix = torch.sigmoid(affinity_matrix)
        # [batch_size, h * w, channels]
        z_out = torch.matmul(affinity_matrix, pooled_x)
        # [batch_size, channels, h * w]
        z_out = z_out.permute(0, 2, 1).contiguous()
        # [batch_size, channels, h, w]
        z_out = z_out.view(batch_size, self.channels, x.size(2), x.size(3))
        z_out = self.residual_conv(z_out)
        z_out = F.relu(z_out + x)
        if self.fusion:
            z_out = self.fusion_conv(z_out)
        return z_out
@HEADS.register_module()
class APCHead(BaseDecodeHead):
    """Decode head of the Adaptive Pyramid Context Network.

    This head is the implementation of
    `APCNet <https://openaccess.thecvf.com/content_CVPR_2019/papers/\
    He_Adaptive_Pyramid_Context_Network_for_Semantic_Segmentation_\
    CVPR_2019_paper.pdf>`_.

    Args:
        pool_scales (tuple[int]): Pooling scales used in Adaptive Context
            Module. Default: (1, 2, 3, 6).
        fusion (bool): Add one conv to fuse residual feature.
    """

    def __init__(self, pool_scales=(1, 2, 3, 6), fusion=True, **kwargs):
        super(APCHead, self).__init__(**kwargs)
        assert isinstance(pool_scales, (list, tuple))
        self.pool_scales = pool_scales
        self.fusion = fusion
        # One Adaptive Context Module per pooling scale, in scale order.
        self.acm_modules = nn.ModuleList([
            ACM(scale,
                self.fusion,
                self.in_channels,
                self.channels,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg) for scale in self.pool_scales
        ])
        # Input width: the raw feature plus one ``channels``-wide branch
        # for every pooling scale.
        concat_channels = self.in_channels + len(pool_scales) * self.channels
        self.bottleneck = ConvModule(
            concat_channels,
            self.channels,
            3,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

    def forward(self, inputs):
        """Concatenate the input with every ACM branch and classify."""
        x = self._transform_inputs(inputs)
        branches = [x] + [acm(x) for acm in self.acm_modules]
        fused = self.bottleneck(torch.cat(branches, dim=1))
        return self.cls_seg(fused)
import torch
import torch.nn as nn
from annotator.uniformer.mmcv.cnn import ConvModule
from annotator.uniformer.mmseg.ops import resize
from ..builder import HEADS
from .decode_head import BaseDecodeHead
class ASPPModule(nn.ModuleList):
    """Atrous Spatial Pyramid Pooling (ASPP) Module.

    A ``ModuleList`` holding one ``ConvModule`` per dilation rate.

    Args:
        dilations (tuple[int]): Dilation rate of each layer.
        in_channels (int): Input channels.
        channels (int): Channels after modules, before conv_seg.
        conv_cfg (dict|None): Config of conv layers.
        norm_cfg (dict|None): Config of norm layers.
        act_cfg (dict): Config of activation layers.
    """

    def __init__(self, dilations, in_channels, channels, conv_cfg, norm_cfg,
                 act_cfg):
        super(ASPPModule, self).__init__()
        self.dilations = dilations
        self.in_channels = in_channels
        self.channels = channels
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        for rate in dilations:
            # A dilation of 1 gets a 1x1 conv without padding; any other
            # rate gets a 3x3 conv padded by the rate itself.
            if rate == 1:
                kernel_size, padding = 1, 0
            else:
                kernel_size, padding = 3, rate
            branch = ConvModule(
                self.in_channels,
                self.channels,
                kernel_size,
                dilation=rate,
                padding=padding,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg)
            self.append(branch)

    def forward(self, x):
        """Return the list of outputs of every branch applied to ``x``."""
        return [branch(x) for branch in self]
@HEADS.register_module()
class ASPPHead(BaseDecodeHead):
    """Decode head of DeepLabV3.

    This head is the implementation of `DeepLabV3
    <https://arxiv.org/abs/1706.05587>`_ from "Rethinking Atrous Convolution
    for Semantic Image Segmentation".

    Args:
        dilations (tuple[int]): Dilation rates for ASPP module.
            Default: (1, 6, 12, 18).
    """

    def __init__(self, dilations=(1, 6, 12, 18), **kwargs):
        super(ASPPHead, self).__init__(**kwargs)
        assert isinstance(dilations, (list, tuple))
        self.dilations = dilations
        # Image-level branch: global average pooling followed by a conv.
        pooled_conv = ConvModule(
            self.in_channels,
            self.channels,
            1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        self.image_pool = nn.Sequential(nn.AdaptiveAvgPool2d(1), pooled_conv)
        self.aspp_modules = ASPPModule(
            dilations,
            self.in_channels,
            self.channels,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        # One ``channels``-wide map per dilation plus the image-level one.
        num_branches = len(dilations) + 1
        self.bottleneck = ConvModule(
            num_branches * self.channels,
            self.channels,
            3,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

    def forward(self, inputs):
        """Concatenate image-level and atrous branches, then classify."""
        x = self._transform_inputs(inputs)
        # The pooled branch is resized back to the spatial size of ``x``.
        image_level = resize(
            self.image_pool(x),
            size=x.size()[2:],
            mode='bilinear',
            align_corners=self.align_corners)
        branches = [image_level]
        branches.extend(self.aspp_modules(x))
        fused = self.bottleneck(torch.cat(branches, dim=1))
        return self.cls_seg(fused)
from abc import ABCMeta, abstractmethod
from .decode_head import BaseDecodeHead
class BaseCascadeDecodeHead(BaseDecodeHead, metaclass=ABCMeta):
    """Base class for cascade decode head used in
    :class:`CascadeEncoderDecoder`.

    Unlike a plain decode head, ``forward`` also receives the output of
    the previous decode head.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    @abstractmethod
    def forward(self, inputs, prev_output):
        """Placeholder of forward function."""

    def forward_train(self, inputs, prev_output, img_metas, gt_semantic_seg,
                      train_cfg):
        """Compute the training losses of this head.

        Args:
            inputs (list[Tensor]): List of multi-level img features.
            prev_output (Tensor): The output of previous decode head.
            img_metas (list[dict]): List of image info dict where each dict
                has: 'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                For details on the values of these keys see
                `mmseg/datasets/pipelines/formatting.py:Collect`.
                Not read by this method.
            gt_semantic_seg (Tensor): Semantic segmentation masks
                used if the architecture supports semantic segmentation task.
            train_cfg (dict): The training config. Not read by this method.

        Returns:
            dict[str, Tensor]: a dictionary of loss components
        """
        logits = self.forward(inputs, prev_output)
        return self.losses(logits, gt_semantic_seg)

    def forward_test(self, inputs, prev_output, img_metas, test_cfg):
        """Run the head in test mode.

        Args:
            inputs (list[Tensor]): List of multi-level img features.
            prev_output (Tensor): The output of previous decode head.
            img_metas (list[dict]): List of image info dict where each dict
                has: 'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                For details on the values of these keys see
                `mmseg/datasets/pipelines/formatting.py:Collect`.
                Not read by this method.
            test_cfg (dict): The testing config. Not read by this method.

        Returns:
            Tensor: Output segmentation map.
        """
        seg_logits = self.forward(inputs, prev_output)
        return seg_logits
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment