Commit a4372e17 authored by huchen

Merge branch 'hepj_work' into 'main'

Add conformer code

See merge request dcutoolkit/deeplearing/dlexamples_new!7
parents 7f99c1c3 142dcf29
import torch.nn as nn
import torch.utils.checkpoint as cp
from mmcv.cnn import build_conv_layer, build_norm_layer, constant_init
from ..builder import BACKBONES
from .resnet import Bottleneck as _Bottleneck
from .resnet import ResNet
class Bottleneck(_Bottleneck):
r"""Bottleneck for the ResNet backbone in `DetectoRS
<https://arxiv.org/pdf/2006.02334.pdf>`_.
This bottleneck allows the users to specify whether to use
SAC (Switchable Atrous Convolution) and RFP (Recursive Feature Pyramid).
Args:
inplanes (int): The number of input channels.
planes (int): The number of output channels before expansion.
rfp_inplanes (int, optional): The number of channels from RFP.
Default: None. If specified, an additional conv layer will be
added for ``rfp_feat``. Otherwise, the structure is the same as
base class.
sac (dict, optional): Dictionary to construct SAC. Default: None.
"""
expansion = 4
def __init__(self,
inplanes,
planes,
rfp_inplanes=None,
sac=None,
**kwargs):
super(Bottleneck, self).__init__(inplanes, planes, **kwargs)
assert sac is None or isinstance(sac, dict)
self.sac = sac
self.with_sac = sac is not None
if self.with_sac:
self.conv2 = build_conv_layer(
self.sac,
planes,
planes,
kernel_size=3,
stride=self.conv2_stride,
padding=self.dilation,
dilation=self.dilation,
bias=False)
self.rfp_inplanes = rfp_inplanes
if self.rfp_inplanes:
self.rfp_conv = build_conv_layer(
None,
self.rfp_inplanes,
planes * self.expansion,
1,
stride=1,
bias=True)
self.init_weights()
def init_weights(self):
"""Initialize the weights."""
if self.rfp_inplanes:
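            # Zero init makes the RFP feedback a no-op at the start of
            # training: rfp_conv(rfp_feat) contributes zeros until its
            # weights are learned, so the block initially behaves like a
            # plain bottleneck.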
constant_init(self.rfp_conv, 0)
def rfp_forward(self, x, rfp_feat):
"""The forward function that also takes the RFP features as input."""
def _inner_forward(x):
identity = x
out = self.conv1(x)
out = self.norm1(out)
out = self.relu(out)
if self.with_plugins:
out = self.forward_plugin(out, self.after_conv1_plugin_names)
out = self.conv2(out)
out = self.norm2(out)
out = self.relu(out)
if self.with_plugins:
out = self.forward_plugin(out, self.after_conv2_plugin_names)
out = self.conv3(out)
out = self.norm3(out)
if self.with_plugins:
out = self.forward_plugin(out, self.after_conv3_plugin_names)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
return out
if self.with_cp and x.requires_grad:
out = cp.checkpoint(_inner_forward, x)
else:
out = _inner_forward(x)
if self.rfp_inplanes:
rfp_feat = self.rfp_conv(rfp_feat)
out = out + rfp_feat
out = self.relu(out)
return out
class ResLayer(nn.Sequential):
"""ResLayer to build ResNet style backbone for RPF in detectoRS.
The difference between this module and base class is that we pass
``rfp_inplanes`` to the first block.
Args:
block (nn.Module): block used to build ResLayer.
inplanes (int): inplanes of block.
planes (int): planes of block.
num_blocks (int): number of blocks.
stride (int): stride of the first block. Default: 1
avg_down (bool): Use AvgPool instead of stride conv when
downsampling in the bottleneck. Default: False
conv_cfg (dict): dictionary to construct and config conv layer.
Default: None
norm_cfg (dict): dictionary to construct and config norm layer.
Default: dict(type='BN')
downsample_first (bool): Downsample at the first block or last block.
False for Hourglass, True for ResNet. Default: True
rfp_inplanes (int, optional): The number of channels from RFP.
Default: None. If specified, an additional conv layer will be
added for ``rfp_feat``. Otherwise, the structure is the same as
base class.
"""
def __init__(self,
block,
inplanes,
planes,
num_blocks,
stride=1,
avg_down=False,
conv_cfg=None,
norm_cfg=dict(type='BN'),
downsample_first=True,
rfp_inplanes=None,
**kwargs):
self.block = block
        assert downsample_first, f'downsample_first={downsample_first} is ' \
'not supported in DetectoRS'
downsample = None
if stride != 1 or inplanes != planes * block.expansion:
downsample = []
conv_stride = stride
if avg_down and stride != 1:
conv_stride = 1
downsample.append(
nn.AvgPool2d(
kernel_size=stride,
stride=stride,
ceil_mode=True,
count_include_pad=False))
downsample.extend([
build_conv_layer(
conv_cfg,
inplanes,
planes * block.expansion,
kernel_size=1,
stride=conv_stride,
bias=False),
build_norm_layer(norm_cfg, planes * block.expansion)[1]
])
downsample = nn.Sequential(*downsample)
layers = []
layers.append(
block(
inplanes=inplanes,
planes=planes,
stride=stride,
downsample=downsample,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
rfp_inplanes=rfp_inplanes,
**kwargs))
inplanes = planes * block.expansion
for _ in range(1, num_blocks):
layers.append(
block(
inplanes=inplanes,
planes=planes,
stride=1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
**kwargs))
super(ResLayer, self).__init__(*layers)
@BACKBONES.register_module()
class DetectoRS_ResNet(ResNet):
"""ResNet backbone for DetectoRS.
Args:
sac (dict, optional): Dictionary to construct SAC (Switchable Atrous
Convolution). Default: None.
stage_with_sac (list): Which stage to use sac. Default: (False, False,
False, False).
rfp_inplanes (int, optional): The number of channels from RFP.
Default: None. If specified, an additional conv layer will be
added for ``rfp_feat``. Otherwise, the structure is the same as
base class.
output_img (bool): If ``True``, the input image will be inserted into
the starting position of output. Default: False.
pretrained (str, optional): The pretrained model to load.
"""
arch_settings = {
50: (Bottleneck, (3, 4, 6, 3)),
101: (Bottleneck, (3, 4, 23, 3)),
152: (Bottleneck, (3, 8, 36, 3))
}
def __init__(self,
sac=None,
stage_with_sac=(False, False, False, False),
rfp_inplanes=None,
output_img=False,
pretrained=None,
**kwargs):
self.sac = sac
self.stage_with_sac = stage_with_sac
self.rfp_inplanes = rfp_inplanes
self.output_img = output_img
self.pretrained = pretrained
super(DetectoRS_ResNet, self).__init__(**kwargs)
self.inplanes = self.stem_channels
self.res_layers = []
for i, num_blocks in enumerate(self.stage_blocks):
stride = self.strides[i]
dilation = self.dilations[i]
dcn = self.dcn if self.stage_with_dcn[i] else None
sac = self.sac if self.stage_with_sac[i] else None
if self.plugins is not None:
stage_plugins = self.make_stage_plugins(self.plugins, i)
else:
stage_plugins = None
planes = self.base_channels * 2**i
res_layer = self.make_res_layer(
block=self.block,
inplanes=self.inplanes,
planes=planes,
num_blocks=num_blocks,
stride=stride,
dilation=dilation,
style=self.style,
avg_down=self.avg_down,
with_cp=self.with_cp,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
dcn=dcn,
sac=sac,
rfp_inplanes=rfp_inplanes if i > 0 else None,
plugins=stage_plugins)
self.inplanes = planes * self.block.expansion
layer_name = f'layer{i + 1}'
self.add_module(layer_name, res_layer)
self.res_layers.append(layer_name)
self._freeze_stages()
def make_res_layer(self, **kwargs):
"""Pack all blocks in a stage into a ``ResLayer`` for DetectoRS."""
return ResLayer(**kwargs)
def forward(self, x):
"""Forward function."""
outs = list(super(DetectoRS_ResNet, self).forward(x))
if self.output_img:
outs.insert(0, x)
return tuple(outs)
def rfp_forward(self, x, rfp_feats):
"""Forward function for RFP."""
if self.deep_stem:
x = self.stem(x)
else:
x = self.conv1(x)
x = self.norm1(x)
x = self.relu(x)
x = self.maxpool(x)
outs = []
for i, layer_name in enumerate(self.res_layers):
res_layer = getattr(self, layer_name)
rfp_feat = rfp_feats[i] if i > 0 else None
for layer in res_layer:
x = layer.rfp_forward(x, rfp_feat)
if i in self.out_indices:
outs.append(x)
return tuple(outs)
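# Hedged usage sketch (not part of this merge request): the SAC and RFP
# values below are illustrative assumptions; 'SAC' is the conv type that
# DetectoRS registers elsewhere in mmdet.
# >>> from mmdet.models import DetectoRS_ResNet
# >>> import torch
# >>> self = DetectoRS_ResNet(
# >>>     depth=50,
# >>>     sac=dict(type='SAC', use_deform=False),
# >>>     stage_with_sac=(False, True, True, True),
# >>>     rfp_inplanes=256,
# >>>     output_img=True)
# >>> self.eval()
# >>> outs = self.forward(torch.rand(1, 3, 64, 64))
# >>> len(outs)  # output_img=True prepends the raw image to 4 stage outs
# 5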
import math
from mmcv.cnn import build_conv_layer, build_norm_layer
from ..builder import BACKBONES
from .detectors_resnet import Bottleneck as _Bottleneck
from .detectors_resnet import DetectoRS_ResNet
class Bottleneck(_Bottleneck):
expansion = 4
def __init__(self,
inplanes,
planes,
groups=1,
base_width=4,
base_channels=64,
**kwargs):
"""Bottleneck block for ResNeXt.
If style is "pytorch", the stride-two layer is the 3x3 conv layer, if
it is "caffe", the stride-two layer is the first 1x1 conv layer.
"""
super(Bottleneck, self).__init__(inplanes, planes, **kwargs)
if groups == 1:
width = self.planes
else:
width = math.floor(self.planes *
(base_width / base_channels)) * groups
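        # Worked example of the width rule above (illustrative numbers):
        # planes=256, groups=32, base_width=4, base_channels=64 gives
        # width = floor(256 * 4 / 64) * 32 = 16 * 32 = 512.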
self.norm1_name, norm1 = build_norm_layer(
self.norm_cfg, width, postfix=1)
self.norm2_name, norm2 = build_norm_layer(
self.norm_cfg, width, postfix=2)
self.norm3_name, norm3 = build_norm_layer(
self.norm_cfg, self.planes * self.expansion, postfix=3)
self.conv1 = build_conv_layer(
self.conv_cfg,
self.inplanes,
width,
kernel_size=1,
stride=self.conv1_stride,
bias=False)
self.add_module(self.norm1_name, norm1)
fallback_on_stride = False
self.with_modulated_dcn = False
if self.with_dcn:
fallback_on_stride = self.dcn.pop('fallback_on_stride', False)
if self.with_sac:
self.conv2 = build_conv_layer(
self.sac,
width,
width,
kernel_size=3,
stride=self.conv2_stride,
padding=self.dilation,
dilation=self.dilation,
groups=groups,
bias=False)
elif not self.with_dcn or fallback_on_stride:
self.conv2 = build_conv_layer(
self.conv_cfg,
width,
width,
kernel_size=3,
stride=self.conv2_stride,
padding=self.dilation,
dilation=self.dilation,
groups=groups,
bias=False)
else:
assert self.conv_cfg is None, 'conv_cfg must be None for DCN'
self.conv2 = build_conv_layer(
self.dcn,
width,
width,
kernel_size=3,
stride=self.conv2_stride,
padding=self.dilation,
dilation=self.dilation,
groups=groups,
bias=False)
self.add_module(self.norm2_name, norm2)
self.conv3 = build_conv_layer(
self.conv_cfg,
width,
self.planes * self.expansion,
kernel_size=1,
bias=False)
self.add_module(self.norm3_name, norm3)
@BACKBONES.register_module()
class DetectoRS_ResNeXt(DetectoRS_ResNet):
"""ResNeXt backbone for DetectoRS.
Args:
groups (int): The number of groups in ResNeXt.
base_width (int): The base width of ResNeXt.
"""
arch_settings = {
50: (Bottleneck, (3, 4, 6, 3)),
101: (Bottleneck, (3, 4, 23, 3)),
152: (Bottleneck, (3, 8, 36, 3))
}
def __init__(self, groups=1, base_width=4, **kwargs):
self.groups = groups
self.base_width = base_width
super(DetectoRS_ResNeXt, self).__init__(**kwargs)
def make_res_layer(self, **kwargs):
return super().make_res_layer(
groups=self.groups,
base_width=self.base_width,
base_channels=self.base_channels,
**kwargs)
import torch.nn as nn
from mmcv.cnn import ConvModule
from ..builder import BACKBONES
from ..utils import ResLayer
from .resnet import BasicBlock
class HourglassModule(nn.Module):
"""Hourglass Module for HourglassNet backbone.
Generate module recursively and use BasicBlock as the base unit.
Args:
depth (int): Depth of current HourglassModule.
stage_channels (list[int]): Feature channels of sub-modules in current
and follow-up HourglassModule.
stage_blocks (list[int]): Number of sub-modules stacked in current and
follow-up HourglassModule.
norm_cfg (dict): Dictionary to construct and config norm layer.
"""
def __init__(self,
depth,
stage_channels,
stage_blocks,
norm_cfg=dict(type='BN', requires_grad=True)):
super(HourglassModule, self).__init__()
self.depth = depth
cur_block = stage_blocks[0]
next_block = stage_blocks[1]
cur_channel = stage_channels[0]
next_channel = stage_channels[1]
self.up1 = ResLayer(
BasicBlock, cur_channel, cur_channel, cur_block, norm_cfg=norm_cfg)
self.low1 = ResLayer(
BasicBlock,
cur_channel,
next_channel,
cur_block,
stride=2,
norm_cfg=norm_cfg)
if self.depth > 1:
self.low2 = HourglassModule(depth - 1, stage_channels[1:],
stage_blocks[1:])
else:
self.low2 = ResLayer(
BasicBlock,
next_channel,
next_channel,
next_block,
norm_cfg=norm_cfg)
self.low3 = ResLayer(
BasicBlock,
next_channel,
cur_channel,
cur_block,
norm_cfg=norm_cfg,
downsample_first=False)
self.up2 = nn.Upsample(scale_factor=2)
def forward(self, x):
"""Forward function."""
up1 = self.up1(x)
low1 = self.low1(x)
low2 = self.low2(low1)
low3 = self.low3(low2)
up2 = self.up2(low3)
return up1 + up2
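# Shape note for HourglassModule.forward (an observation, assuming even
# spatial sizes): low1 halves the resolution with stride 2, the low2/low3
# path keeps it, and up2 doubles it back, so up1 + up2 adds tensors of the
# same shape whenever every intermediate size is even.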
@BACKBONES.register_module()
class HourglassNet(nn.Module):
"""HourglassNet backbone.
Stacked Hourglass Networks for Human Pose Estimation.
More details can be found in the `paper
<https://arxiv.org/abs/1603.06937>`_ .
Args:
downsample_times (int): Downsample times in a HourglassModule.
num_stacks (int): Number of HourglassModule modules stacked,
1 for Hourglass-52, 2 for Hourglass-104.
stage_channels (list[int]): Feature channel of each sub-module in a
HourglassModule.
stage_blocks (list[int]): Number of sub-modules stacked in a
HourglassModule.
feat_channel (int): Feature channel of conv after a HourglassModule.
norm_cfg (dict): Dictionary to construct and config norm layer.
Example:
>>> from mmdet.models import HourglassNet
>>> import torch
>>> self = HourglassNet()
>>> self.eval()
>>> inputs = torch.rand(1, 3, 511, 511)
>>> level_outputs = self.forward(inputs)
>>> for level_output in level_outputs:
... print(tuple(level_output.shape))
(1, 256, 128, 128)
(1, 256, 128, 128)
"""
def __init__(self,
downsample_times=5,
num_stacks=2,
stage_channels=(256, 256, 384, 384, 384, 512),
stage_blocks=(2, 2, 2, 2, 2, 4),
feat_channel=256,
norm_cfg=dict(type='BN', requires_grad=True)):
super(HourglassNet, self).__init__()
self.num_stacks = num_stacks
assert self.num_stacks >= 1
assert len(stage_channels) == len(stage_blocks)
assert len(stage_channels) > downsample_times
cur_channel = stage_channels[0]
self.stem = nn.Sequential(
ConvModule(3, 128, 7, padding=3, stride=2, norm_cfg=norm_cfg),
ResLayer(BasicBlock, 128, 256, 1, stride=2, norm_cfg=norm_cfg))
self.hourglass_modules = nn.ModuleList([
HourglassModule(downsample_times, stage_channels, stage_blocks)
for _ in range(num_stacks)
])
self.inters = ResLayer(
BasicBlock,
cur_channel,
cur_channel,
num_stacks - 1,
norm_cfg=norm_cfg)
self.conv1x1s = nn.ModuleList([
ConvModule(
cur_channel, cur_channel, 1, norm_cfg=norm_cfg, act_cfg=None)
for _ in range(num_stacks - 1)
])
self.out_convs = nn.ModuleList([
ConvModule(
cur_channel, feat_channel, 3, padding=1, norm_cfg=norm_cfg)
for _ in range(num_stacks)
])
self.remap_convs = nn.ModuleList([
ConvModule(
feat_channel, cur_channel, 1, norm_cfg=norm_cfg, act_cfg=None)
for _ in range(num_stacks - 1)
])
self.relu = nn.ReLU(inplace=True)
def init_weights(self, pretrained=None):
"""Init module weights.
        We do nothing in this function because all modules we used
        (ConvModule, BasicBlock, etc.) have default initialization, and
currently we don't provide pretrained model of HourglassNet.
Detector's __init__() will call backbone's init_weights() with
pretrained as input, so we keep this function.
"""
# Training Centripetal Model needs to reset parameters for Conv2d
for m in self.modules():
if isinstance(m, nn.Conv2d):
m.reset_parameters()
def forward(self, x):
"""Forward function."""
inter_feat = self.stem(x)
out_feats = []
for ind in range(self.num_stacks):
single_hourglass = self.hourglass_modules[ind]
out_conv = self.out_convs[ind]
hourglass_feat = single_hourglass(inter_feat)
out_feat = out_conv(hourglass_feat)
out_feats.append(out_feat)
if ind < self.num_stacks - 1:
inter_feat = self.conv1x1s[ind](
inter_feat) + self.remap_convs[ind](
out_feat)
inter_feat = self.inters[ind](self.relu(inter_feat))
return out_feats
import torch.nn as nn
from mmcv.cnn import (build_conv_layer, build_norm_layer, constant_init,
kaiming_init)
from mmcv.runner import load_checkpoint
from torch.nn.modules.batchnorm import _BatchNorm
from mmdet.utils import get_root_logger
from ..builder import BACKBONES
from .resnet import BasicBlock, Bottleneck
class HRModule(nn.Module):
"""High-Resolution Module for HRNet.
    In this module, every branch has 4 BasicBlocks/Bottlenecks. Fusion and
    exchange across branches are performed in this module.
"""
def __init__(self,
num_branches,
blocks,
num_blocks,
in_channels,
num_channels,
multiscale_output=True,
with_cp=False,
conv_cfg=None,
norm_cfg=dict(type='BN')):
super(HRModule, self).__init__()
self._check_branches(num_branches, num_blocks, in_channels,
num_channels)
self.in_channels = in_channels
self.num_branches = num_branches
self.multiscale_output = multiscale_output
self.norm_cfg = norm_cfg
self.conv_cfg = conv_cfg
self.with_cp = with_cp
self.branches = self._make_branches(num_branches, blocks, num_blocks,
num_channels)
self.fuse_layers = self._make_fuse_layers()
self.relu = nn.ReLU(inplace=False)
def _check_branches(self, num_branches, num_blocks, in_channels,
num_channels):
if num_branches != len(num_blocks):
error_msg = f'NUM_BRANCHES({num_branches}) ' \
f'!= NUM_BLOCKS({len(num_blocks)})'
raise ValueError(error_msg)
if num_branches != len(num_channels):
error_msg = f'NUM_BRANCHES({num_branches}) ' \
f'!= NUM_CHANNELS({len(num_channels)})'
raise ValueError(error_msg)
if num_branches != len(in_channels):
error_msg = f'NUM_BRANCHES({num_branches}) ' \
f'!= NUM_INCHANNELS({len(in_channels)})'
raise ValueError(error_msg)
def _make_one_branch(self,
branch_index,
block,
num_blocks,
num_channels,
stride=1):
downsample = None
if stride != 1 or \
self.in_channels[branch_index] != \
num_channels[branch_index] * block.expansion:
downsample = nn.Sequential(
build_conv_layer(
self.conv_cfg,
self.in_channels[branch_index],
num_channels[branch_index] * block.expansion,
kernel_size=1,
stride=stride,
bias=False),
build_norm_layer(self.norm_cfg, num_channels[branch_index] *
block.expansion)[1])
layers = []
layers.append(
block(
self.in_channels[branch_index],
num_channels[branch_index],
stride,
downsample=downsample,
with_cp=self.with_cp,
norm_cfg=self.norm_cfg,
conv_cfg=self.conv_cfg))
self.in_channels[branch_index] = \
num_channels[branch_index] * block.expansion
for i in range(1, num_blocks[branch_index]):
layers.append(
block(
self.in_channels[branch_index],
num_channels[branch_index],
with_cp=self.with_cp,
norm_cfg=self.norm_cfg,
conv_cfg=self.conv_cfg))
return nn.Sequential(*layers)
def _make_branches(self, num_branches, block, num_blocks, num_channels):
branches = []
for i in range(num_branches):
branches.append(
self._make_one_branch(i, block, num_blocks, num_channels))
return nn.ModuleList(branches)
def _make_fuse_layers(self):
if self.num_branches == 1:
return None
num_branches = self.num_branches
in_channels = self.in_channels
fuse_layers = []
num_out_branches = num_branches if self.multiscale_output else 1
for i in range(num_out_branches):
fuse_layer = []
for j in range(num_branches):
if j > i:
fuse_layer.append(
nn.Sequential(
build_conv_layer(
self.conv_cfg,
in_channels[j],
in_channels[i],
kernel_size=1,
stride=1,
padding=0,
bias=False),
build_norm_layer(self.norm_cfg, in_channels[i])[1],
nn.Upsample(
scale_factor=2**(j - i), mode='nearest')))
elif j == i:
fuse_layer.append(None)
else:
conv_downsamples = []
for k in range(i - j):
if k == i - j - 1:
conv_downsamples.append(
nn.Sequential(
build_conv_layer(
self.conv_cfg,
in_channels[j],
in_channels[i],
kernel_size=3,
stride=2,
padding=1,
bias=False),
build_norm_layer(self.norm_cfg,
in_channels[i])[1]))
else:
conv_downsamples.append(
nn.Sequential(
build_conv_layer(
self.conv_cfg,
in_channels[j],
in_channels[j],
kernel_size=3,
stride=2,
padding=1,
bias=False),
build_norm_layer(self.norm_cfg,
in_channels[j])[1],
nn.ReLU(inplace=False)))
fuse_layer.append(nn.Sequential(*conv_downsamples))
fuse_layers.append(nn.ModuleList(fuse_layer))
return nn.ModuleList(fuse_layers)
def forward(self, x):
"""Forward function."""
if self.num_branches == 1:
return [self.branches[0](x[0])]
for i in range(self.num_branches):
x[i] = self.branches[i](x[i])
x_fuse = []
for i in range(len(self.fuse_layers)):
y = 0
for j in range(self.num_branches):
if i == j:
y += x[j]
else:
y += self.fuse_layers[i][j](x[j])
x_fuse.append(self.relu(y))
return x_fuse
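# Fusion sketch for HRModule (illustrative, two branches): output 0 is
# relu(x[0] + upsample(conv1x1(x[1]))) and output 1 is
# relu(conv3x3_s2(x[0]) + x[1]); lower-resolution branches are matched
# with a 1x1 conv plus nearest upsampling, higher-resolution ones with
# chains of stride-2 3x3 convs.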
@BACKBONES.register_module()
class HRNet(nn.Module):
"""HRNet backbone.
High-Resolution Representations for Labeling Pixels and Regions
arXiv: https://arxiv.org/abs/1904.04514
Args:
extra (dict): detailed configuration for each stage of HRNet.
in_channels (int): Number of input image channels. Default: 3.
conv_cfg (dict): dictionary to construct and config conv layer.
norm_cfg (dict): dictionary to construct and config norm layer.
norm_eval (bool): Whether to set norm layers to eval mode, namely,
freeze running stats (mean and var). Note: Effect on Batch Norm
and its variants only.
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed.
zero_init_residual (bool): whether to use zero init for last norm layer
in resblocks to let them behave as identity.
Example:
>>> from mmdet.models import HRNet
>>> import torch
>>> extra = dict(
>>> stage1=dict(
>>> num_modules=1,
>>> num_branches=1,
>>> block='BOTTLENECK',
>>> num_blocks=(4, ),
>>> num_channels=(64, )),
>>> stage2=dict(
>>> num_modules=1,
>>> num_branches=2,
>>> block='BASIC',
>>> num_blocks=(4, 4),
>>> num_channels=(32, 64)),
>>> stage3=dict(
>>> num_modules=4,
>>> num_branches=3,
>>> block='BASIC',
>>> num_blocks=(4, 4, 4),
>>> num_channels=(32, 64, 128)),
>>> stage4=dict(
>>> num_modules=3,
>>> num_branches=4,
>>> block='BASIC',
>>> num_blocks=(4, 4, 4, 4),
>>> num_channels=(32, 64, 128, 256)))
>>> self = HRNet(extra, in_channels=1)
>>> self.eval()
>>> inputs = torch.rand(1, 1, 32, 32)
>>> level_outputs = self.forward(inputs)
>>> for level_out in level_outputs:
... print(tuple(level_out.shape))
(1, 32, 8, 8)
(1, 64, 4, 4)
(1, 128, 2, 2)
(1, 256, 1, 1)
"""
blocks_dict = {'BASIC': BasicBlock, 'BOTTLENECK': Bottleneck}
def __init__(self,
extra,
in_channels=3,
conv_cfg=None,
norm_cfg=dict(type='BN'),
norm_eval=True,
with_cp=False,
zero_init_residual=False):
super(HRNet, self).__init__()
self.extra = extra
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
self.norm_eval = norm_eval
self.with_cp = with_cp
self.zero_init_residual = zero_init_residual
# stem net
self.norm1_name, norm1 = build_norm_layer(self.norm_cfg, 64, postfix=1)
self.norm2_name, norm2 = build_norm_layer(self.norm_cfg, 64, postfix=2)
self.conv1 = build_conv_layer(
self.conv_cfg,
in_channels,
64,
kernel_size=3,
stride=2,
padding=1,
bias=False)
self.add_module(self.norm1_name, norm1)
self.conv2 = build_conv_layer(
self.conv_cfg,
64,
64,
kernel_size=3,
stride=2,
padding=1,
bias=False)
self.add_module(self.norm2_name, norm2)
self.relu = nn.ReLU(inplace=True)
# stage 1
self.stage1_cfg = self.extra['stage1']
num_channels = self.stage1_cfg['num_channels'][0]
block_type = self.stage1_cfg['block']
num_blocks = self.stage1_cfg['num_blocks'][0]
block = self.blocks_dict[block_type]
stage1_out_channels = num_channels * block.expansion
self.layer1 = self._make_layer(block, 64, num_channels, num_blocks)
# stage 2
self.stage2_cfg = self.extra['stage2']
num_channels = self.stage2_cfg['num_channels']
block_type = self.stage2_cfg['block']
block = self.blocks_dict[block_type]
num_channels = [channel * block.expansion for channel in num_channels]
self.transition1 = self._make_transition_layer([stage1_out_channels],
num_channels)
self.stage2, pre_stage_channels = self._make_stage(
self.stage2_cfg, num_channels)
# stage 3
self.stage3_cfg = self.extra['stage3']
num_channels = self.stage3_cfg['num_channels']
block_type = self.stage3_cfg['block']
block = self.blocks_dict[block_type]
num_channels = [channel * block.expansion for channel in num_channels]
self.transition2 = self._make_transition_layer(pre_stage_channels,
num_channels)
self.stage3, pre_stage_channels = self._make_stage(
self.stage3_cfg, num_channels)
# stage 4
self.stage4_cfg = self.extra['stage4']
num_channels = self.stage4_cfg['num_channels']
block_type = self.stage4_cfg['block']
block = self.blocks_dict[block_type]
num_channels = [channel * block.expansion for channel in num_channels]
self.transition3 = self._make_transition_layer(pre_stage_channels,
num_channels)
self.stage4, pre_stage_channels = self._make_stage(
self.stage4_cfg, num_channels)
@property
def norm1(self):
"""nn.Module: the normalization layer named "norm1" """
return getattr(self, self.norm1_name)
@property
def norm2(self):
"""nn.Module: the normalization layer named "norm2" """
return getattr(self, self.norm2_name)
def _make_transition_layer(self, num_channels_pre_layer,
num_channels_cur_layer):
num_branches_cur = len(num_channels_cur_layer)
num_branches_pre = len(num_channels_pre_layer)
transition_layers = []
for i in range(num_branches_cur):
if i < num_branches_pre:
if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
transition_layers.append(
nn.Sequential(
build_conv_layer(
self.conv_cfg,
num_channels_pre_layer[i],
num_channels_cur_layer[i],
kernel_size=3,
stride=1,
padding=1,
bias=False),
build_norm_layer(self.norm_cfg,
num_channels_cur_layer[i])[1],
nn.ReLU(inplace=True)))
else:
transition_layers.append(None)
else:
conv_downsamples = []
for j in range(i + 1 - num_branches_pre):
in_channels = num_channels_pre_layer[-1]
out_channels = num_channels_cur_layer[i] \
if j == i - num_branches_pre else in_channels
conv_downsamples.append(
nn.Sequential(
build_conv_layer(
self.conv_cfg,
in_channels,
out_channels,
kernel_size=3,
stride=2,
padding=1,
bias=False),
build_norm_layer(self.norm_cfg, out_channels)[1],
nn.ReLU(inplace=True)))
transition_layers.append(nn.Sequential(*conv_downsamples))
return nn.ModuleList(transition_layers)
def _make_layer(self, block, inplanes, planes, blocks, stride=1):
downsample = None
if stride != 1 or inplanes != planes * block.expansion:
downsample = nn.Sequential(
build_conv_layer(
self.conv_cfg,
inplanes,
planes * block.expansion,
kernel_size=1,
stride=stride,
bias=False),
build_norm_layer(self.norm_cfg, planes * block.expansion)[1])
layers = []
layers.append(
block(
inplanes,
planes,
stride,
downsample=downsample,
with_cp=self.with_cp,
norm_cfg=self.norm_cfg,
conv_cfg=self.conv_cfg))
inplanes = planes * block.expansion
for i in range(1, blocks):
layers.append(
block(
inplanes,
planes,
with_cp=self.with_cp,
norm_cfg=self.norm_cfg,
conv_cfg=self.conv_cfg))
return nn.Sequential(*layers)
def _make_stage(self, layer_config, in_channels, multiscale_output=True):
num_modules = layer_config['num_modules']
num_branches = layer_config['num_branches']
num_blocks = layer_config['num_blocks']
num_channels = layer_config['num_channels']
block = self.blocks_dict[layer_config['block']]
hr_modules = []
for i in range(num_modules):
            # multiscale_output is only used for the last module
if not multiscale_output and i == num_modules - 1:
reset_multiscale_output = False
else:
reset_multiscale_output = True
hr_modules.append(
HRModule(
num_branches,
block,
num_blocks,
in_channels,
num_channels,
reset_multiscale_output,
with_cp=self.with_cp,
norm_cfg=self.norm_cfg,
conv_cfg=self.conv_cfg))
return nn.Sequential(*hr_modules), in_channels
def init_weights(self, pretrained=None):
"""Initialize the weights in backbone.
Args:
pretrained (str, optional): Path to pre-trained weights.
Defaults to None.
"""
if isinstance(pretrained, str):
logger = get_root_logger()
load_checkpoint(self, pretrained, strict=False, logger=logger)
elif pretrained is None:
for m in self.modules():
if isinstance(m, nn.Conv2d):
kaiming_init(m)
elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
constant_init(m, 1)
if self.zero_init_residual:
for m in self.modules():
if isinstance(m, Bottleneck):
constant_init(m.norm3, 0)
elif isinstance(m, BasicBlock):
constant_init(m.norm2, 0)
else:
raise TypeError('pretrained must be a str or None')
def forward(self, x):
"""Forward function."""
x = self.conv1(x)
x = self.norm1(x)
x = self.relu(x)
x = self.conv2(x)
x = self.norm2(x)
x = self.relu(x)
x = self.layer1(x)
x_list = []
for i in range(self.stage2_cfg['num_branches']):
if self.transition1[i] is not None:
x_list.append(self.transition1[i](x))
else:
x_list.append(x)
y_list = self.stage2(x_list)
x_list = []
for i in range(self.stage3_cfg['num_branches']):
if self.transition2[i] is not None:
x_list.append(self.transition2[i](y_list[-1]))
else:
x_list.append(y_list[i])
y_list = self.stage3(x_list)
x_list = []
for i in range(self.stage4_cfg['num_branches']):
if self.transition3[i] is not None:
x_list.append(self.transition3[i](y_list[-1]))
else:
x_list.append(y_list[i])
y_list = self.stage4(x_list)
return y_list
def train(self, mode=True):
"""Convert the model into training mode whill keeping the normalization
layer freezed."""
super(HRNet, self).train(mode)
if mode and self.norm_eval:
for m in self.modules():
                # trick: eval() has an effect on BatchNorm only
if isinstance(m, _BatchNorm):
m.eval()
import numpy as np
import torch.nn as nn
from mmcv.cnn import build_conv_layer, build_norm_layer
from ..builder import BACKBONES
from .resnet import ResNet
from .resnext import Bottleneck
@BACKBONES.register_module()
class RegNet(ResNet):
"""RegNet backbone.
More details can be found in `paper <https://arxiv.org/abs/2003.13678>`_ .
Args:
arch (dict): The parameter of RegNets.
- w0 (int): initial width
- wa (float): slope of width
- wm (float): quantization parameter to quantize the width
- depth (int): depth of the backbone
- group_w (int): width of group
            - bot_mul (float): bottleneck ratio, i.e. expansion of bottleneck.
strides (Sequence[int]): Strides of the first block of each stage.
base_channels (int): Base channels after stem layer.
in_channels (int): Number of input image channels. Default: 3.
dilations (Sequence[int]): Dilation of each stage.
out_indices (Sequence[int]): Output from which stages.
style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
layer is the 3x3 conv layer, otherwise the stride-two layer is
the first 1x1 conv layer.
frozen_stages (int): Stages to be frozen (all param fixed). -1 means
not freezing any parameters.
norm_cfg (dict): dictionary to construct and config norm layer.
norm_eval (bool): Whether to set norm layers to eval mode, namely,
freeze running stats (mean and var). Note: Effect on Batch Norm
and its variants only.
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed.
zero_init_residual (bool): whether to use zero init for last norm layer
in resblocks to let them behave as identity.
Example:
>>> from mmdet.models import RegNet
>>> import torch
        >>> self = RegNet(
        >>>     arch=dict(
        >>>         w0=88,
        >>>         wa=26.31,
        >>>         wm=2.25,
        >>>         group_w=48,
        >>>         depth=25,
        >>>         bot_mul=1.0))
>>> self.eval()
>>> inputs = torch.rand(1, 3, 32, 32)
>>> level_outputs = self.forward(inputs)
>>> for level_out in level_outputs:
... print(tuple(level_out.shape))
(1, 96, 8, 8)
(1, 192, 4, 4)
(1, 432, 2, 2)
(1, 1008, 1, 1)
"""
arch_settings = {
'regnetx_400mf':
dict(w0=24, wa=24.48, wm=2.54, group_w=16, depth=22, bot_mul=1.0),
'regnetx_800mf':
dict(w0=56, wa=35.73, wm=2.28, group_w=16, depth=16, bot_mul=1.0),
'regnetx_1.6gf':
dict(w0=80, wa=34.01, wm=2.25, group_w=24, depth=18, bot_mul=1.0),
'regnetx_3.2gf':
dict(w0=88, wa=26.31, wm=2.25, group_w=48, depth=25, bot_mul=1.0),
'regnetx_4.0gf':
dict(w0=96, wa=38.65, wm=2.43, group_w=40, depth=23, bot_mul=1.0),
'regnetx_6.4gf':
dict(w0=184, wa=60.83, wm=2.07, group_w=56, depth=17, bot_mul=1.0),
'regnetx_8.0gf':
dict(w0=80, wa=49.56, wm=2.88, group_w=120, depth=23, bot_mul=1.0),
'regnetx_12gf':
dict(w0=168, wa=73.36, wm=2.37, group_w=112, depth=19, bot_mul=1.0),
}
def __init__(self,
arch,
in_channels=3,
stem_channels=32,
base_channels=32,
strides=(2, 2, 2, 2),
dilations=(1, 1, 1, 1),
out_indices=(0, 1, 2, 3),
style='pytorch',
deep_stem=False,
avg_down=False,
frozen_stages=-1,
conv_cfg=None,
norm_cfg=dict(type='BN', requires_grad=True),
norm_eval=True,
dcn=None,
stage_with_dcn=(False, False, False, False),
plugins=None,
with_cp=False,
zero_init_residual=True):
super(ResNet, self).__init__()
# Generate RegNet parameters first
if isinstance(arch, str):
assert arch in self.arch_settings, \
f'"arch": "{arch}" is not one of the' \
' arch_settings'
arch = self.arch_settings[arch]
elif not isinstance(arch, dict):
raise ValueError('Expect "arch" to be either a string '
f'or a dict, got {type(arch)}')
widths, num_stages = self.generate_regnet(
arch['w0'],
arch['wa'],
arch['wm'],
arch['depth'],
)
# Convert to per stage format
stage_widths, stage_blocks = self.get_stages_from_blocks(widths)
# Generate group widths and bot muls
group_widths = [arch['group_w'] for _ in range(num_stages)]
self.bottleneck_ratio = [arch['bot_mul'] for _ in range(num_stages)]
# Adjust the compatibility of stage_widths and group_widths
stage_widths, group_widths = self.adjust_width_group(
stage_widths, self.bottleneck_ratio, group_widths)
# Group params by stage
self.stage_widths = stage_widths
self.group_widths = group_widths
self.depth = sum(stage_blocks)
self.stem_channels = stem_channels
self.base_channels = base_channels
self.num_stages = num_stages
assert num_stages >= 1 and num_stages <= 4
self.strides = strides
self.dilations = dilations
assert len(strides) == len(dilations) == num_stages
self.out_indices = out_indices
assert max(out_indices) < num_stages
self.style = style
self.deep_stem = deep_stem
self.avg_down = avg_down
self.frozen_stages = frozen_stages
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
self.with_cp = with_cp
self.norm_eval = norm_eval
self.dcn = dcn
self.stage_with_dcn = stage_with_dcn
if dcn is not None:
assert len(stage_with_dcn) == num_stages
self.plugins = plugins
self.zero_init_residual = zero_init_residual
self.block = Bottleneck
expansion_bak = self.block.expansion
self.block.expansion = 1
self.stage_blocks = stage_blocks[:num_stages]
self._make_stem_layer(in_channels, stem_channels)
self.inplanes = stem_channels
self.res_layers = []
for i, num_blocks in enumerate(self.stage_blocks):
stride = self.strides[i]
dilation = self.dilations[i]
group_width = self.group_widths[i]
width = int(round(self.stage_widths[i] * self.bottleneck_ratio[i]))
stage_groups = width // group_width
dcn = self.dcn if self.stage_with_dcn[i] else None
if self.plugins is not None:
stage_plugins = self.make_stage_plugins(self.plugins, i)
else:
stage_plugins = None
res_layer = self.make_res_layer(
block=self.block,
inplanes=self.inplanes,
planes=self.stage_widths[i],
num_blocks=num_blocks,
stride=stride,
dilation=dilation,
style=self.style,
avg_down=self.avg_down,
with_cp=self.with_cp,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
dcn=dcn,
plugins=stage_plugins,
groups=stage_groups,
base_width=group_width,
base_channels=self.stage_widths[i])
self.inplanes = self.stage_widths[i]
layer_name = f'layer{i + 1}'
self.add_module(layer_name, res_layer)
self.res_layers.append(layer_name)
self._freeze_stages()
self.feat_dim = stage_widths[-1]
self.block.expansion = expansion_bak
def _make_stem_layer(self, in_channels, base_channels):
self.conv1 = build_conv_layer(
self.conv_cfg,
in_channels,
base_channels,
kernel_size=3,
stride=2,
padding=1,
bias=False)
self.norm1_name, norm1 = build_norm_layer(
self.norm_cfg, base_channels, postfix=1)
self.add_module(self.norm1_name, norm1)
self.relu = nn.ReLU(inplace=True)
def generate_regnet(self,
initial_width,
width_slope,
width_parameter,
depth,
divisor=8):
"""Generates per block width from RegNet parameters.
Args:
initial_width ([int]): Initial width of the backbone
width_slope ([float]): Slope of the quantized linear function
width_parameter ([int]): Parameter used to quantize the width.
depth ([int]): Depth of the backbone.
divisor (int, optional): The divisor of channels. Defaults to 8.
        Returns:
            list, int: return a list of per-block widths and the number \
                of stages
"""
assert width_slope >= 0
assert initial_width > 0
assert width_parameter > 1
assert initial_width % divisor == 0
widths_cont = np.arange(depth) * width_slope + initial_width
ks = np.round(
np.log(widths_cont / initial_width) / np.log(width_parameter))
widths = initial_width * np.power(width_parameter, ks)
widths = np.round(np.divide(widths, divisor)) * divisor
num_stages = len(np.unique(widths))
widths, widths_cont = widths.astype(int).tolist(), widths_cont.tolist()
return widths, num_stages
@staticmethod
def quantize_float(number, divisor):
"""Converts a float to closest non-zero int divisible by divior.
Args:
number (int): Original number to be quantized.
divisor (int): Divisor used to quantize the number.
Returns:
            int: quantized number that is divisible by divisor.
"""
return int(round(number / divisor) * divisor)
def adjust_width_group(self, widths, bottleneck_ratio, groups):
"""Adjusts the compatibility of widths and groups.
Args:
widths (list[int]): Width of each stage.
bottleneck_ratio (float): Bottleneck ratio.
groups (int): number of groups in each stage
Returns:
tuple(list): The adjusted widths and groups of each stage.
"""
bottleneck_width = [
int(w * b) for w, b in zip(widths, bottleneck_ratio)
]
groups = [min(g, w_bot) for g, w_bot in zip(groups, bottleneck_width)]
bottleneck_width = [
self.quantize_float(w_bot, g)
for w_bot, g in zip(bottleneck_width, groups)
]
widths = [
int(w_bot / b)
for w_bot, b in zip(bottleneck_width, bottleneck_ratio)
]
return widths, groups
def get_stages_from_blocks(self, widths):
"""Gets widths/stage_blocks of network at each stage.
Args:
widths (list[int]): Width in each stage.
Returns:
tuple(list): width and depth of each stage
"""
width_diff = [
width != width_prev
for width, width_prev in zip(widths + [0], [0] + widths)
]
stage_widths = [
width for width, diff in zip(widths, width_diff[:-1]) if diff
]
stage_blocks = np.diff([
depth for depth, diff in zip(range(len(width_diff)), width_diff)
if diff
]).tolist()
return stage_widths, stage_blocks
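    # Worked example for get_stages_from_blocks (illustrative): for
    # widths = [24, 24, 56, 56, 56], width_diff marks indices 0, 2 and 5,
    # so stage_widths = [24, 56] and stage_blocks = diff([0, 2, 5]) = [2, 3].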
def forward(self, x):
"""Forward function."""
x = self.conv1(x)
x = self.norm1(x)
x = self.relu(x)
outs = []
for i, layer_name in enumerate(self.res_layers):
res_layer = getattr(self, layer_name)
x = res_layer(x)
if i in self.out_indices:
outs.append(x)
return tuple(outs)
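# Hedged numeric sketch (not from this merge request): re-deriving the
# per-block widths of the regnetx_400mf entry straight from the formulas
# in generate_regnet, using only numpy.
# >>> import numpy as np
# >>> w0, wa, wm, depth, divisor = 24, 24.48, 2.54, 22, 8
# >>> widths_cont = np.arange(depth) * wa + w0
# >>> ks = np.round(np.log(widths_cont / w0) / np.log(wm))
# >>> widths = np.round(w0 * np.power(wm, ks) / divisor) * divisor
# >>> # runs of equal quantized widths form stages
# >>> len(np.unique(widths))  # number of stages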
import math
import torch
import torch.nn as nn
import torch.utils.checkpoint as cp
from mmcv.cnn import (build_conv_layer, build_norm_layer, constant_init,
kaiming_init)
from mmcv.runner import load_checkpoint
from torch.nn.modules.batchnorm import _BatchNorm
from mmdet.utils import get_root_logger
from ..builder import BACKBONES
from .resnet import Bottleneck as _Bottleneck
from .resnet import ResNet
class Bottle2neck(_Bottleneck):
expansion = 4
def __init__(self,
inplanes,
planes,
scales=4,
base_width=26,
base_channels=64,
stage_type='normal',
**kwargs):
"""Bottle2neck block for Res2Net.
If style is "pytorch", the stride-two layer is the 3x3 conv layer, if
it is "caffe", the stride-two layer is the first 1x1 conv layer.
"""
super(Bottle2neck, self).__init__(inplanes, planes, **kwargs)
assert scales > 1, 'Res2Net degenerates to ResNet when scales = 1.'
width = int(math.floor(self.planes * (base_width / base_channels)))
self.norm1_name, norm1 = build_norm_layer(
self.norm_cfg, width * scales, postfix=1)
self.norm3_name, norm3 = build_norm_layer(
self.norm_cfg, self.planes * self.expansion, postfix=3)
self.conv1 = build_conv_layer(
self.conv_cfg,
self.inplanes,
width * scales,
kernel_size=1,
stride=self.conv1_stride,
bias=False)
self.add_module(self.norm1_name, norm1)
if stage_type == 'stage' and self.conv2_stride != 1:
self.pool = nn.AvgPool2d(
kernel_size=3, stride=self.conv2_stride, padding=1)
convs = []
bns = []
fallback_on_stride = False
if self.with_dcn:
fallback_on_stride = self.dcn.pop('fallback_on_stride', False)
if not self.with_dcn or fallback_on_stride:
for i in range(scales - 1):
convs.append(
build_conv_layer(
self.conv_cfg,
width,
width,
kernel_size=3,
stride=self.conv2_stride,
padding=self.dilation,
dilation=self.dilation,
bias=False))
bns.append(
build_norm_layer(self.norm_cfg, width, postfix=i + 1)[1])
self.convs = nn.ModuleList(convs)
self.bns = nn.ModuleList(bns)
else:
assert self.conv_cfg is None, 'conv_cfg must be None for DCN'
for i in range(scales - 1):
convs.append(
build_conv_layer(
self.dcn,
width,
width,
kernel_size=3,
stride=self.conv2_stride,
padding=self.dilation,
dilation=self.dilation,
bias=False))
bns.append(
build_norm_layer(self.norm_cfg, width, postfix=i + 1)[1])
self.convs = nn.ModuleList(convs)
self.bns = nn.ModuleList(bns)
self.conv3 = build_conv_layer(
self.conv_cfg,
width * scales,
self.planes * self.expansion,
kernel_size=1,
bias=False)
self.add_module(self.norm3_name, norm3)
self.stage_type = stage_type
self.scales = scales
self.width = width
delattr(self, 'conv2')
delattr(self, self.norm2_name)
def forward(self, x):
"""Forward function."""
def _inner_forward(x):
identity = x
out = self.conv1(x)
out = self.norm1(out)
out = self.relu(out)
if self.with_plugins:
out = self.forward_plugin(out, self.after_conv1_plugin_names)
spx = torch.split(out, self.width, 1)
sp = self.convs[0](spx[0].contiguous())
sp = self.relu(self.bns[0](sp))
out = sp
for i in range(1, self.scales - 1):
if self.stage_type == 'stage':
sp = spx[i]
else:
sp = sp + spx[i]
sp = self.convs[i](sp.contiguous())
sp = self.relu(self.bns[i](sp))
out = torch.cat((out, sp), 1)
if self.stage_type == 'normal' or self.conv2_stride == 1:
out = torch.cat((out, spx[self.scales - 1]), 1)
elif self.stage_type == 'stage':
out = torch.cat((out, self.pool(spx[self.scales - 1])), 1)
if self.with_plugins:
out = self.forward_plugin(out, self.after_conv2_plugin_names)
out = self.conv3(out)
out = self.norm3(out)
if self.with_plugins:
out = self.forward_plugin(out, self.after_conv3_plugin_names)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
return out
if self.with_cp and x.requires_grad:
out = cp.checkpoint(_inner_forward, x)
else:
out = _inner_forward(x)
out = self.relu(out)
return out
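# Data-flow sketch for Bottle2neck.forward (an observation from the code
# above, with scales=4): conv1 expands to 4 * width channels that are
# split into 4 chunks; chunk 0 goes through convs[0], chunks 1-2 each add
# the previous branch output before their own 3x3 conv (the 'normal'
# stage_type), and the last chunk is concatenated unchanged, or
# average-pooled when stage_type='stage' uses a stride.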
class Res2Layer(nn.Sequential):
"""Res2Layer to build Res2Net style backbone.
Args:
block (nn.Module): block used to build ResLayer.
inplanes (int): inplanes of block.
planes (int): planes of block.
num_blocks (int): number of blocks.
stride (int): stride of the first block. Default: 1
avg_down (bool): Use AvgPool instead of stride conv when
downsampling in the bottle2neck. Default: False
conv_cfg (dict): dictionary to construct and config conv layer.
Default: None
norm_cfg (dict): dictionary to construct and config norm layer.
Default: dict(type='BN')
scales (int): Scales used in Res2Net. Default: 4
base_width (int): Basic width of each scale. Default: 26
"""
def __init__(self,
block,
inplanes,
planes,
num_blocks,
stride=1,
avg_down=True,
conv_cfg=None,
norm_cfg=dict(type='BN'),
scales=4,
base_width=26,
**kwargs):
self.block = block
downsample = None
if stride != 1 or inplanes != planes * block.expansion:
downsample = nn.Sequential(
nn.AvgPool2d(
kernel_size=stride,
stride=stride,
ceil_mode=True,
count_include_pad=False),
build_conv_layer(
conv_cfg,
inplanes,
planes * block.expansion,
kernel_size=1,
stride=1,
bias=False),
build_norm_layer(norm_cfg, planes * block.expansion)[1],
)
layers = []
layers.append(
block(
inplanes=inplanes,
planes=planes,
stride=stride,
downsample=downsample,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
scales=scales,
base_width=base_width,
stage_type='stage',
**kwargs))
inplanes = planes * block.expansion
for i in range(1, num_blocks):
layers.append(
block(
inplanes=inplanes,
planes=planes,
stride=1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
scales=scales,
base_width=base_width,
**kwargs))
super(Res2Layer, self).__init__(*layers)
@BACKBONES.register_module()
class Res2Net(ResNet):
"""Res2Net backbone.
Args:
scales (int): Scales used in Res2Net. Default: 4
base_width (int): Basic width of each scale. Default: 26
depth (int): Depth of res2net, from {50, 101, 152}.
in_channels (int): Number of input image channels. Default: 3.
num_stages (int): Res2net stages. Default: 4.
strides (Sequence[int]): Strides of the first block of each stage.
dilations (Sequence[int]): Dilation of each stage.
out_indices (Sequence[int]): Output from which stages.
style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
layer is the 3x3 conv layer, otherwise the stride-two layer is
the first 1x1 conv layer.
deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv
avg_down (bool): Use AvgPool instead of stride conv when
downsampling in the bottle2neck.
frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
-1 means not freezing any parameters.
norm_cfg (dict): Dictionary to construct and config norm layer.
norm_eval (bool): Whether to set norm layers to eval mode, namely,
freeze running stats (mean and var). Note: Effect on Batch Norm
and its variants only.
plugins (list[dict]): List of plugins for stages, each dict contains:
- cfg (dict, required): Cfg dict to build plugin.
- position (str, required): Position inside block to insert
plugin, options are 'after_conv1', 'after_conv2', 'after_conv3'.
- stages (tuple[bool], optional): Stages to apply plugin, length
should be same as 'num_stages'.
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed.
zero_init_residual (bool): Whether to use zero init for last norm layer
in resblocks to let them behave as identity.
Example:
>>> from mmdet.models import Res2Net
>>> import torch
>>> self = Res2Net(depth=50, scales=4, base_width=26)
>>> self.eval()
>>> inputs = torch.rand(1, 3, 32, 32)
>>> level_outputs = self.forward(inputs)
>>> for level_out in level_outputs:
... print(tuple(level_out.shape))
(1, 256, 8, 8)
(1, 512, 4, 4)
(1, 1024, 2, 2)
(1, 2048, 1, 1)
"""
arch_settings = {
50: (Bottle2neck, (3, 4, 6, 3)),
101: (Bottle2neck, (3, 4, 23, 3)),
152: (Bottle2neck, (3, 8, 36, 3))
}
def __init__(self,
scales=4,
base_width=26,
style='pytorch',
deep_stem=True,
avg_down=True,
**kwargs):
self.scales = scales
self.base_width = base_width
super(Res2Net, self).__init__(
style='pytorch', deep_stem=True, avg_down=True, **kwargs)
def make_res_layer(self, **kwargs):
return Res2Layer(
scales=self.scales,
base_width=self.base_width,
base_channels=self.base_channels,
**kwargs)
def init_weights(self, pretrained=None):
"""Initialize the weights in backbone.
Args:
pretrained (str, optional): Path to pre-trained weights.
Defaults to None.
"""
if isinstance(pretrained, str):
logger = get_root_logger()
load_checkpoint(self, pretrained, strict=False, logger=logger)
elif pretrained is None:
for m in self.modules():
if isinstance(m, nn.Conv2d):
kaiming_init(m)
elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
constant_init(m, 1)
if self.dcn is not None:
for m in self.modules():
if isinstance(m, Bottle2neck):
# dcn in Res2Net bottle2neck is in ModuleList
for n in m.convs:
if hasattr(n, 'conv_offset'):
constant_init(n.conv_offset, 0)
if self.zero_init_residual:
for m in self.modules():
if isinstance(m, Bottle2neck):
constant_init(m.norm3, 0)
else:
raise TypeError('pretrained must be a str or None')
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as cp
from mmcv.cnn import build_conv_layer, build_norm_layer
from ..builder import BACKBONES
from ..utils import ResLayer
from .resnet import Bottleneck as _Bottleneck
from .resnet import ResNetV1d
class RSoftmax(nn.Module):
"""Radix Softmax module in ``SplitAttentionConv2d``.
Args:
radix (int): Radix of input.
groups (int): Groups of input.
"""
def __init__(self, radix, groups):
super().__init__()
self.radix = radix
self.groups = groups
def forward(self, x):
batch = x.size(0)
if self.radix > 1:
x = x.view(batch, self.groups, self.radix, -1).transpose(1, 2)
x = F.softmax(x, dim=1)
x = x.reshape(batch, -1)
else:
x = torch.sigmoid(x)
return x
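# Shape sketch for RSoftmax (illustrative numbers): with radix=2, groups=4
# and input (B, 2 * 4 * c), view gives (B, 4, 2, c), transpose gives
# (B, 2, 4, c), softmax normalizes over the radix axis (dim=1), and
# reshape flattens back to (B, 2 * 4 * c).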
class SplitAttentionConv2d(nn.Module):
"""Split-Attention Conv2d in ResNeSt.
Args:
in_channels (int): Number of channels in the input feature map.
channels (int): Number of intermediate channels.
kernel_size (int | tuple[int]): Size of the convolution kernel.
stride (int | tuple[int]): Stride of the convolution.
        padding (int | tuple[int]): Zero-padding added to both sides of
            the input.
        dilation (int | tuple[int]): Spacing between kernel elements.
        groups (int): Number of blocked connections from input channels to
            output channels. Same as ``groups`` in nn.Conv2d.
        radix (int): Radix of SplitAttentionConv2d. Default: 2
reduction_factor (int): Reduction factor of inter_channels. Default: 4.
conv_cfg (dict): Config dict for convolution layer. Default: None,
which means using conv2d.
norm_cfg (dict): Config dict for normalization layer. Default: None.
dcn (dict): Config dict for DCN. Default: None.
"""
def __init__(self,
in_channels,
channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
radix=2,
reduction_factor=4,
conv_cfg=None,
norm_cfg=dict(type='BN'),
dcn=None):
super(SplitAttentionConv2d, self).__init__()
inter_channels = max(in_channels * radix // reduction_factor, 32)
self.radix = radix
self.groups = groups
self.channels = channels
self.with_dcn = dcn is not None
self.dcn = dcn
fallback_on_stride = False
if self.with_dcn:
fallback_on_stride = self.dcn.pop('fallback_on_stride', False)
if self.with_dcn and not fallback_on_stride:
assert conv_cfg is None, 'conv_cfg must be None for DCN'
conv_cfg = dcn
self.conv = build_conv_layer(
conv_cfg,
in_channels,
channels * radix,
kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups * radix,
bias=False)
        # To be consistent with the original implementation, the norm
        # postfix starts from 0
self.norm0_name, norm0 = build_norm_layer(
norm_cfg, channels * radix, postfix=0)
self.add_module(self.norm0_name, norm0)
self.relu = nn.ReLU(inplace=True)
self.fc1 = build_conv_layer(
None, channels, inter_channels, 1, groups=self.groups)
self.norm1_name, norm1 = build_norm_layer(
norm_cfg, inter_channels, postfix=1)
self.add_module(self.norm1_name, norm1)
self.fc2 = build_conv_layer(
None, inter_channels, channels * radix, 1, groups=self.groups)
self.rsoftmax = RSoftmax(radix, groups)
@property
def norm0(self):
"""nn.Module: the normalization layer named "norm0" """
return getattr(self, self.norm0_name)
@property
def norm1(self):
"""nn.Module: the normalization layer named "norm1" """
return getattr(self, self.norm1_name)
def forward(self, x):
x = self.conv(x)
x = self.norm0(x)
x = self.relu(x)
        batch = x.size(0)
if self.radix > 1:
splits = x.view(batch, self.radix, -1, *x.shape[2:])
gap = splits.sum(dim=1)
else:
gap = x
gap = F.adaptive_avg_pool2d(gap, 1)
gap = self.fc1(gap)
gap = self.norm1(gap)
gap = self.relu(gap)
atten = self.fc2(gap)
atten = self.rsoftmax(atten).view(batch, -1, 1, 1)
if self.radix > 1:
attens = atten.view(batch, self.radix, -1, *atten.shape[2:])
out = torch.sum(attens * splits, dim=1)
else:
out = atten * x
return out.contiguous()
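# Hedged usage sketch for SplitAttentionConv2d (the sizes are illustrative
# assumptions; the output shape follows from the code above, since the
# radix splits are summed back down to `channels`):
# >>> conv = SplitAttentionConv2d(64, 64, kernel_size=3, padding=1, radix=2)
# >>> out = conv(torch.rand(2, 64, 56, 56))
# >>> tuple(out.shape)
# (2, 64, 56, 56)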
class Bottleneck(_Bottleneck):
"""Bottleneck block for ResNeSt.
Args:
        inplanes (int): Input planes of this block.
planes (int): Middle planes of this block.
groups (int): Groups of conv2.
base_width (int): Base of width in terms of base channels. Default: 4.
base_channels (int): Base of channels for calculating width.
Default: 64.
        radix (int): Radix of SplitAttentionConv2d. Default: 2
reduction_factor (int): Reduction factor of inter_channels in
SplitAttentionConv2d. Default: 4.
avg_down_stride (bool): Whether to use average pool for stride in
Bottleneck. Default: True.
kwargs (dict): Key word arguments for base class.
"""
expansion = 4
def __init__(self,
inplanes,
planes,
groups=1,
base_width=4,
base_channels=64,
radix=2,
reduction_factor=4,
avg_down_stride=True,
**kwargs):
"""Bottleneck block for ResNeSt."""
super(Bottleneck, self).__init__(inplanes, planes, **kwargs)
if groups == 1:
width = self.planes
else:
width = math.floor(self.planes *
(base_width / base_channels)) * groups
self.avg_down_stride = avg_down_stride and self.conv2_stride > 1
self.norm1_name, norm1 = build_norm_layer(
self.norm_cfg, width, postfix=1)
self.norm3_name, norm3 = build_norm_layer(
self.norm_cfg, self.planes * self.expansion, postfix=3)
self.conv1 = build_conv_layer(
self.conv_cfg,
self.inplanes,
width,
kernel_size=1,
stride=self.conv1_stride,
bias=False)
self.add_module(self.norm1_name, norm1)
self.with_modulated_dcn = False
self.conv2 = SplitAttentionConv2d(
width,
width,
kernel_size=3,
stride=1 if self.avg_down_stride else self.conv2_stride,
padding=self.dilation,
dilation=self.dilation,
groups=groups,
radix=radix,
reduction_factor=reduction_factor,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
dcn=self.dcn)
delattr(self, self.norm2_name)
if self.avg_down_stride:
self.avd_layer = nn.AvgPool2d(3, self.conv2_stride, padding=1)
self.conv3 = build_conv_layer(
self.conv_cfg,
width,
self.planes * self.expansion,
kernel_size=1,
bias=False)
self.add_module(self.norm3_name, norm3)
def forward(self, x):
def _inner_forward(x):
identity = x
out = self.conv1(x)
out = self.norm1(out)
out = self.relu(out)
if self.with_plugins:
out = self.forward_plugin(out, self.after_conv1_plugin_names)
out = self.conv2(out)
if self.avg_down_stride:
out = self.avd_layer(out)
if self.with_plugins:
out = self.forward_plugin(out, self.after_conv2_plugin_names)
out = self.conv3(out)
out = self.norm3(out)
if self.with_plugins:
out = self.forward_plugin(out, self.after_conv3_plugin_names)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
return out
if self.with_cp and x.requires_grad:
out = cp.checkpoint(_inner_forward, x)
else:
out = _inner_forward(x)
out = self.relu(out)
return out
@BACKBONES.register_module()
class ResNeSt(ResNetV1d):
"""ResNeSt backbone.
Args:
groups (int): Number of groups of Bottleneck. Default: 1
base_width (int): Base width of Bottleneck. Default: 4
radix (int): Radix of SplitAttentionConv2d. Default: 2
reduction_factor (int): Reduction factor of inter_channels in
SplitAttentionConv2d. Default: 4.
avg_down_stride (bool): Whether to use average pool for stride in
Bottleneck. Default: True.
kwargs (dict): Keyword arguments for ResNet.
"""
arch_settings = {
50: (Bottleneck, (3, 4, 6, 3)),
101: (Bottleneck, (3, 4, 23, 3)),
152: (Bottleneck, (3, 8, 36, 3)),
200: (Bottleneck, (3, 24, 36, 3))
}
def __init__(self,
groups=1,
base_width=4,
radix=2,
reduction_factor=4,
avg_down_stride=True,
**kwargs):
self.groups = groups
self.base_width = base_width
self.radix = radix
self.reduction_factor = reduction_factor
self.avg_down_stride = avg_down_stride
super(ResNeSt, self).__init__(**kwargs)
def make_res_layer(self, **kwargs):
"""Pack all blocks in a stage into a ``ResLayer``."""
return ResLayer(
groups=self.groups,
base_width=self.base_width,
base_channels=self.base_channels,
radix=self.radix,
reduction_factor=self.reduction_factor,
avg_down_stride=self.avg_down_stride,
**kwargs)
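# Hedged usage sketch (not from this merge request; depth and input size
# are illustrative):
# >>> from mmdet.models import ResNeSt
# >>> import torch
# >>> self = ResNeSt(depth=50, radix=2, reduction_factor=4)
# >>> self.eval()
# >>> outs = self.forward(torch.rand(1, 3, 64, 64))
# >>> # ResNetV1d's deep stem downsamples by 4 before the first stage
# >>> [tuple(o.shape) for o in outs]
# [(1, 256, 16, 16), (1, 512, 8, 8), (1, 1024, 4, 4), (1, 2048, 2, 2)]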
import torch.nn as nn
import torch.utils.checkpoint as cp
from mmcv.cnn import (build_conv_layer, build_norm_layer, build_plugin_layer,
constant_init, kaiming_init)
from mmcv.runner import load_checkpoint
from torch.nn.modules.batchnorm import _BatchNorm
from mmdet.utils import get_root_logger
from ..builder import BACKBONES
from ..utils import ResLayer
class BasicBlock(nn.Module):
expansion = 1
def __init__(self,
inplanes,
planes,
stride=1,
dilation=1,
downsample=None,
style='pytorch',
with_cp=False,
conv_cfg=None,
norm_cfg=dict(type='BN'),
dcn=None,
plugins=None):
super(BasicBlock, self).__init__()
assert dcn is None, 'Not implemented yet.'
assert plugins is None, 'Not implemented yet.'
self.norm1_name, norm1 = build_norm_layer(norm_cfg, planes, postfix=1)
self.norm2_name, norm2 = build_norm_layer(norm_cfg, planes, postfix=2)
self.conv1 = build_conv_layer(
conv_cfg,
inplanes,
planes,
3,
stride=stride,
padding=dilation,
dilation=dilation,
bias=False)
self.add_module(self.norm1_name, norm1)
self.conv2 = build_conv_layer(
conv_cfg, planes, planes, 3, padding=1, bias=False)
self.add_module(self.norm2_name, norm2)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.stride = stride
self.dilation = dilation
self.with_cp = with_cp
@property
def norm1(self):
"""nn.Module: normalization layer after the first convolution layer"""
return getattr(self, self.norm1_name)
@property
def norm2(self):
"""nn.Module: normalization layer after the second convolution layer"""
return getattr(self, self.norm2_name)
def forward(self, x):
"""Forward function."""
def _inner_forward(x):
identity = x
out = self.conv1(x)
out = self.norm1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.norm2(out)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
return out
if self.with_cp and x.requires_grad:
out = cp.checkpoint(_inner_forward, x)
else:
out = _inner_forward(x)
out = self.relu(out)
return out
class Bottleneck(nn.Module):
expansion = 4
def __init__(self,
inplanes,
planes,
stride=1,
dilation=1,
downsample=None,
style='pytorch',
with_cp=False,
conv_cfg=None,
norm_cfg=dict(type='BN'),
dcn=None,
plugins=None):
"""Bottleneck block for ResNet.
If style is "pytorch", the stride-two layer is the 3x3 conv layer, if
it is "caffe", the stride-two layer is the first 1x1 conv layer.
"""
super(Bottleneck, self).__init__()
assert style in ['pytorch', 'caffe']
assert dcn is None or isinstance(dcn, dict)
assert plugins is None or isinstance(plugins, list)
if plugins is not None:
allowed_position = ['after_conv1', 'after_conv2', 'after_conv3']
assert all(p['position'] in allowed_position for p in plugins)
self.inplanes = inplanes
self.planes = planes
self.stride = stride
self.dilation = dilation
self.style = style
self.with_cp = with_cp
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
self.dcn = dcn
self.with_dcn = dcn is not None
self.plugins = plugins
self.with_plugins = plugins is not None
if self.with_plugins:
# collect plugins for conv1/conv2/conv3
self.after_conv1_plugins = [
plugin['cfg'] for plugin in plugins
if plugin['position'] == 'after_conv1'
]
self.after_conv2_plugins = [
plugin['cfg'] for plugin in plugins
if plugin['position'] == 'after_conv2'
]
self.after_conv3_plugins = [
plugin['cfg'] for plugin in plugins
if plugin['position'] == 'after_conv3'
]
if self.style == 'pytorch':
self.conv1_stride = 1
self.conv2_stride = stride
else:
self.conv1_stride = stride
self.conv2_stride = 1
self.norm1_name, norm1 = build_norm_layer(norm_cfg, planes, postfix=1)
self.norm2_name, norm2 = build_norm_layer(norm_cfg, planes, postfix=2)
self.norm3_name, norm3 = build_norm_layer(
norm_cfg, planes * self.expansion, postfix=3)
self.conv1 = build_conv_layer(
conv_cfg,
inplanes,
planes,
kernel_size=1,
stride=self.conv1_stride,
bias=False)
self.add_module(self.norm1_name, norm1)
fallback_on_stride = False
if self.with_dcn:
fallback_on_stride = dcn.pop('fallback_on_stride', False)
if not self.with_dcn or fallback_on_stride:
self.conv2 = build_conv_layer(
conv_cfg,
planes,
planes,
kernel_size=3,
stride=self.conv2_stride,
padding=dilation,
dilation=dilation,
bias=False)
else:
assert self.conv_cfg is None, 'conv_cfg must be None for DCN'
self.conv2 = build_conv_layer(
dcn,
planes,
planes,
kernel_size=3,
stride=self.conv2_stride,
padding=dilation,
dilation=dilation,
bias=False)
self.add_module(self.norm2_name, norm2)
self.conv3 = build_conv_layer(
conv_cfg,
planes,
planes * self.expansion,
kernel_size=1,
bias=False)
self.add_module(self.norm3_name, norm3)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
if self.with_plugins:
self.after_conv1_plugin_names = self.make_block_plugins(
planes, self.after_conv1_plugins)
self.after_conv2_plugin_names = self.make_block_plugins(
planes, self.after_conv2_plugins)
self.after_conv3_plugin_names = self.make_block_plugins(
planes * self.expansion, self.after_conv3_plugins)
def make_block_plugins(self, in_channels, plugins):
"""make plugins for block.
Args:
in_channels (int): Input channels of plugin.
plugins (list[dict]): List of plugins cfg to build.
Returns:
list[str]: List of the names of plugin.
"""
assert isinstance(plugins, list)
plugin_names = []
for plugin in plugins:
plugin = plugin.copy()
name, layer = build_plugin_layer(
plugin,
in_channels=in_channels,
postfix=plugin.pop('postfix', ''))
assert not hasattr(self, name), f'duplicate plugin {name}'
self.add_module(name, layer)
plugin_names.append(name)
return plugin_names
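    # Naming sketch (illustrative comment, not part of the original code):
    # a cfg such as dict(type='ContextBlock', ratio=1. / 16, postfix='1')
    # typically becomes a submodule named ``context_block1``; the recorded
    # names are then replayed in insertion order by ``forward_plugin`` below.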
def forward_plugin(self, x, plugin_names):
out = x
for name in plugin_names:
            out = getattr(self, name)(out)
return out
@property
def norm1(self):
"""nn.Module: normalization layer after the first convolution layer"""
return getattr(self, self.norm1_name)
@property
def norm2(self):
"""nn.Module: normalization layer after the second convolution layer"""
return getattr(self, self.norm2_name)
@property
def norm3(self):
"""nn.Module: normalization layer after the third convolution layer"""
return getattr(self, self.norm3_name)
def forward(self, x):
"""Forward function."""
def _inner_forward(x):
identity = x
out = self.conv1(x)
out = self.norm1(out)
out = self.relu(out)
if self.with_plugins:
out = self.forward_plugin(out, self.after_conv1_plugin_names)
out = self.conv2(out)
out = self.norm2(out)
out = self.relu(out)
if self.with_plugins:
out = self.forward_plugin(out, self.after_conv2_plugin_names)
out = self.conv3(out)
out = self.norm3(out)
if self.with_plugins:
out = self.forward_plugin(out, self.after_conv3_plugin_names)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
return out
if self.with_cp and x.requires_grad:
out = cp.checkpoint(_inner_forward, x)
else:
out = _inner_forward(x)
out = self.relu(out)
return out
@BACKBONES.register_module()
class ResNet(nn.Module):
"""ResNet backbone.
Args:
depth (int): Depth of resnet, from {18, 34, 50, 101, 152}.
stem_channels (int | None): Number of stem channels. If not specified,
it will be the same as `base_channels`. Default: None.
base_channels (int): Number of base channels of res layer. Default: 64.
in_channels (int): Number of input image channels. Default: 3.
num_stages (int): Resnet stages. Default: 4.
strides (Sequence[int]): Strides of the first block of each stage.
dilations (Sequence[int]): Dilation of each stage.
out_indices (Sequence[int]): Output from which stages.
style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
layer is the 3x3 conv layer, otherwise the stride-two layer is
the first 1x1 conv layer.
        deep_stem (bool): Replace 7x7 conv in input stem with three 3x3 convs.
avg_down (bool): Use AvgPool instead of stride conv when
downsampling in the bottleneck.
frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
-1 means not freezing any parameters.
norm_cfg (dict): Dictionary to construct and config norm layer.
norm_eval (bool): Whether to set norm layers to eval mode, namely,
freeze running stats (mean and var). Note: Effect on Batch Norm
and its variants only.
plugins (list[dict]): List of plugins for stages, each dict contains:
- cfg (dict, required): Cfg dict to build plugin.
- position (str, required): Position inside block to insert
plugin, options are 'after_conv1', 'after_conv2', 'after_conv3'.
- stages (tuple[bool], optional): Stages to apply plugin, length
should be same as 'num_stages'.
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed.
zero_init_residual (bool): Whether to use zero init for last norm layer
in resblocks to let them behave as identity.
Example:
>>> from mmdet.models import ResNet
>>> import torch
>>> self = ResNet(depth=18)
>>> self.eval()
>>> inputs = torch.rand(1, 3, 32, 32)
>>> level_outputs = self.forward(inputs)
>>> for level_out in level_outputs:
... print(tuple(level_out.shape))
(1, 64, 8, 8)
(1, 128, 4, 4)
(1, 256, 2, 2)
(1, 512, 1, 1)
"""
arch_settings = {
18: (BasicBlock, (2, 2, 2, 2)),
34: (BasicBlock, (3, 4, 6, 3)),
50: (Bottleneck, (3, 4, 6, 3)),
101: (Bottleneck, (3, 4, 23, 3)),
152: (Bottleneck, (3, 8, 36, 3))
}
def __init__(self,
depth,
in_channels=3,
stem_channels=None,
base_channels=64,
num_stages=4,
strides=(1, 2, 2, 2),
dilations=(1, 1, 1, 1),
out_indices=(0, 1, 2, 3),
style='pytorch',
deep_stem=False,
avg_down=False,
frozen_stages=-1,
conv_cfg=None,
norm_cfg=dict(type='BN', requires_grad=True),
norm_eval=True,
dcn=None,
stage_with_dcn=(False, False, False, False),
plugins=None,
with_cp=False,
zero_init_residual=True):
super(ResNet, self).__init__()
if depth not in self.arch_settings:
raise KeyError(f'invalid depth {depth} for resnet')
self.depth = depth
if stem_channels is None:
stem_channels = base_channels
self.stem_channels = stem_channels
self.base_channels = base_channels
self.num_stages = num_stages
assert num_stages >= 1 and num_stages <= 4
self.strides = strides
self.dilations = dilations
assert len(strides) == len(dilations) == num_stages
self.out_indices = out_indices
assert max(out_indices) < num_stages
self.style = style
self.deep_stem = deep_stem
self.avg_down = avg_down
self.frozen_stages = frozen_stages
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
self.with_cp = with_cp
self.norm_eval = norm_eval
self.dcn = dcn
self.stage_with_dcn = stage_with_dcn
if dcn is not None:
assert len(stage_with_dcn) == num_stages
self.plugins = plugins
self.zero_init_residual = zero_init_residual
self.block, stage_blocks = self.arch_settings[depth]
self.stage_blocks = stage_blocks[:num_stages]
self.inplanes = stem_channels
self._make_stem_layer(in_channels, stem_channels)
self.res_layers = []
for i, num_blocks in enumerate(self.stage_blocks):
stride = strides[i]
dilation = dilations[i]
dcn = self.dcn if self.stage_with_dcn[i] else None
if plugins is not None:
stage_plugins = self.make_stage_plugins(plugins, i)
else:
stage_plugins = None
planes = base_channels * 2**i
res_layer = self.make_res_layer(
block=self.block,
inplanes=self.inplanes,
planes=planes,
num_blocks=num_blocks,
stride=stride,
dilation=dilation,
style=self.style,
avg_down=self.avg_down,
with_cp=with_cp,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
dcn=dcn,
plugins=stage_plugins)
self.inplanes = planes * self.block.expansion
layer_name = f'layer{i + 1}'
self.add_module(layer_name, res_layer)
self.res_layers.append(layer_name)
self._freeze_stages()
self.feat_dim = self.block.expansion * base_channels * 2**(
len(self.stage_blocks) - 1)
def make_stage_plugins(self, plugins, stage_idx):
"""Make plugins for ResNet ``stage_idx`` th stage.
Currently we support to insert ``context_block``,
``empirical_attention_block``, ``nonlocal_block`` into the backbone
like ResNet/ResNeXt. They could be inserted after conv1/conv2/conv3 of
Bottleneck.
An example of plugins format could be:
Examples:
>>> plugins=[
... dict(cfg=dict(type='xxx', arg1='xxx'),
... stages=(False, True, True, True),
... position='after_conv2'),
... dict(cfg=dict(type='yyy'),
... stages=(True, True, True, True),
... position='after_conv3'),
... dict(cfg=dict(type='zzz', postfix='1'),
... stages=(True, True, True, True),
... position='after_conv3'),
... dict(cfg=dict(type='zzz', postfix='2'),
... stages=(True, True, True, True),
... position='after_conv3')
... ]
>>> self = ResNet(depth=18)
>>> stage_plugins = self.make_stage_plugins(plugins, 0)
>>> assert len(stage_plugins) == 3
        Suppose ``stage_idx=0``, the structure of blocks in the stage would be:
        .. code-block:: none
            conv1 -> conv2 -> conv3 -> yyy -> zzz1 -> zzz2
        Suppose ``stage_idx=1``, the structure of blocks in the stage would be:
        .. code-block:: none
            conv1 -> conv2 -> xxx -> conv3 -> yyy -> zzz1 -> zzz2
If stages is missing, the plugin would be applied to all stages.
Args:
plugins (list[dict]): List of plugins cfg to build. The postfix is
required if multiple same type plugins are inserted.
stage_idx (int): Index of stage to build
Returns:
list[dict]: Plugins for current stage
"""
stage_plugins = []
for plugin in plugins:
plugin = plugin.copy()
stages = plugin.pop('stages', None)
assert stages is None or len(stages) == self.num_stages
# whether to insert plugin into current stage
if stages is None or stages[stage_idx]:
stage_plugins.append(plugin)
return stage_plugins
def make_res_layer(self, **kwargs):
"""Pack all blocks in a stage into a ``ResLayer``."""
return ResLayer(**kwargs)
@property
def norm1(self):
"""nn.Module: the normalization layer named "norm1" """
return getattr(self, self.norm1_name)
def _make_stem_layer(self, in_channels, stem_channels):
if self.deep_stem:
self.stem = nn.Sequential(
build_conv_layer(
self.conv_cfg,
in_channels,
stem_channels // 2,
kernel_size=3,
stride=2,
padding=1,
bias=False),
build_norm_layer(self.norm_cfg, stem_channels // 2)[1],
nn.ReLU(inplace=True),
build_conv_layer(
self.conv_cfg,
stem_channels // 2,
stem_channels // 2,
kernel_size=3,
stride=1,
padding=1,
bias=False),
build_norm_layer(self.norm_cfg, stem_channels // 2)[1],
nn.ReLU(inplace=True),
build_conv_layer(
self.conv_cfg,
stem_channels // 2,
stem_channels,
kernel_size=3,
stride=1,
padding=1,
bias=False),
build_norm_layer(self.norm_cfg, stem_channels)[1],
nn.ReLU(inplace=True))
else:
self.conv1 = build_conv_layer(
self.conv_cfg,
in_channels,
stem_channels,
kernel_size=7,
stride=2,
padding=3,
bias=False)
self.norm1_name, norm1 = build_norm_layer(
self.norm_cfg, stem_channels, postfix=1)
self.add_module(self.norm1_name, norm1)
self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
def _freeze_stages(self):
if self.frozen_stages >= 0:
if self.deep_stem:
self.stem.eval()
for param in self.stem.parameters():
param.requires_grad = False
else:
self.norm1.eval()
for m in [self.conv1, self.norm1]:
for param in m.parameters():
param.requires_grad = False
for i in range(1, self.frozen_stages + 1):
m = getattr(self, f'layer{i}')
m.eval()
for param in m.parameters():
param.requires_grad = False
def init_weights(self, pretrained=None):
"""Initialize the weights in backbone.
Args:
pretrained (str, optional): Path to pre-trained weights.
Defaults to None.
"""
if isinstance(pretrained, str):
logger = get_root_logger()
load_checkpoint(self, pretrained, strict=False, logger=logger)
elif pretrained is None:
for m in self.modules():
if isinstance(m, nn.Conv2d):
kaiming_init(m)
elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
constant_init(m, 1)
if self.dcn is not None:
for m in self.modules():
if isinstance(m, Bottleneck) and hasattr(
m.conv2, 'conv_offset'):
constant_init(m.conv2.conv_offset, 0)
if self.zero_init_residual:
for m in self.modules():
if isinstance(m, Bottleneck):
constant_init(m.norm3, 0)
elif isinstance(m, BasicBlock):
constant_init(m.norm2, 0)
else:
raise TypeError('pretrained must be a str or None')
def forward(self, x):
"""Forward function."""
if self.deep_stem:
x = self.stem(x)
else:
x = self.conv1(x)
x = self.norm1(x)
x = self.relu(x)
x = self.maxpool(x)
outs = []
for i, layer_name in enumerate(self.res_layers):
res_layer = getattr(self, layer_name)
x = res_layer(x)
if i in self.out_indices:
outs.append(x)
return tuple(outs)
def train(self, mode=True):
"""Convert the model into training mode while keep normalization layer
freezed."""
super(ResNet, self).train(mode)
self._freeze_stages()
if mode and self.norm_eval:
for m in self.modules():
                # trick: eval() has an effect on BatchNorm layers only
if isinstance(m, _BatchNorm):
m.eval()
@BACKBONES.register_module()
class ResNetV1d(ResNet):
r"""ResNetV1d variant described in `Bag of Tricks
<https://arxiv.org/pdf/1812.01187.pdf>`_.
    Compared with the default ResNet (ResNetV1b), ResNetV1d replaces the 7x7
    conv in the input stem with three 3x3 convs. And in the downsampling
    block, a 2x2 avg_pool with stride 2 is added before conv, whose stride is
    changed to 1.
"""
def __init__(self, **kwargs):
super(ResNetV1d, self).__init__(
deep_stem=True, avg_down=True, **kwargs)
import math
from mmcv.cnn import build_conv_layer, build_norm_layer
from ..builder import BACKBONES
from ..utils import ResLayer
from .resnet import Bottleneck as _Bottleneck
from .resnet import ResNet
class Bottleneck(_Bottleneck):
expansion = 4
def __init__(self,
inplanes,
planes,
groups=1,
base_width=4,
base_channels=64,
**kwargs):
"""Bottleneck block for ResNeXt.
If style is "pytorch", the stride-two layer is the 3x3 conv layer, if
it is "caffe", the stride-two layer is the first 1x1 conv layer.
"""
super(Bottleneck, self).__init__(inplanes, planes, **kwargs)
if groups == 1:
width = self.planes
else:
width = math.floor(self.planes *
(base_width / base_channels)) * groups
self.norm1_name, norm1 = build_norm_layer(
self.norm_cfg, width, postfix=1)
self.norm2_name, norm2 = build_norm_layer(
self.norm_cfg, width, postfix=2)
self.norm3_name, norm3 = build_norm_layer(
self.norm_cfg, self.planes * self.expansion, postfix=3)
self.conv1 = build_conv_layer(
self.conv_cfg,
self.inplanes,
width,
kernel_size=1,
stride=self.conv1_stride,
bias=False)
self.add_module(self.norm1_name, norm1)
fallback_on_stride = False
self.with_modulated_dcn = False
if self.with_dcn:
fallback_on_stride = self.dcn.pop('fallback_on_stride', False)
if not self.with_dcn or fallback_on_stride:
self.conv2 = build_conv_layer(
self.conv_cfg,
width,
width,
kernel_size=3,
stride=self.conv2_stride,
padding=self.dilation,
dilation=self.dilation,
groups=groups,
bias=False)
else:
assert self.conv_cfg is None, 'conv_cfg must be None for DCN'
self.conv2 = build_conv_layer(
self.dcn,
width,
width,
kernel_size=3,
stride=self.conv2_stride,
padding=self.dilation,
dilation=self.dilation,
groups=groups,
bias=False)
self.add_module(self.norm2_name, norm2)
self.conv3 = build_conv_layer(
self.conv_cfg,
width,
self.planes * self.expansion,
kernel_size=1,
bias=False)
self.add_module(self.norm3_name, norm3)
if self.with_plugins:
self._del_block_plugins(self.after_conv1_plugin_names +
self.after_conv2_plugin_names +
self.after_conv3_plugin_names)
self.after_conv1_plugin_names = self.make_block_plugins(
width, self.after_conv1_plugins)
self.after_conv2_plugin_names = self.make_block_plugins(
width, self.after_conv2_plugins)
self.after_conv3_plugin_names = self.make_block_plugins(
self.planes * self.expansion, self.after_conv3_plugins)
def _del_block_plugins(self, plugin_names):
"""delete plugins for block if exist.
Args:
plugin_names (list[str]): List of plugins name to delete.
"""
assert isinstance(plugin_names, list)
for plugin_name in plugin_names:
del self._modules[plugin_name]
@BACKBONES.register_module()
class ResNeXt(ResNet):
"""ResNeXt backbone.
Args:
        depth (int): Depth of resnet, from {50, 101, 152}.
in_channels (int): Number of input image channels. Default: 3.
num_stages (int): Resnet stages. Default: 4.
groups (int): Group of resnext.
base_width (int): Base width of resnext.
strides (Sequence[int]): Strides of the first block of each stage.
dilations (Sequence[int]): Dilation of each stage.
out_indices (Sequence[int]): Output from which stages.
style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
layer is the 3x3 conv layer, otherwise the stride-two layer is
the first 1x1 conv layer.
frozen_stages (int): Stages to be frozen (all param fixed). -1 means
not freezing any parameters.
norm_cfg (dict): dictionary to construct and config norm layer.
norm_eval (bool): Whether to set norm layers to eval mode, namely,
freeze running stats (mean and var). Note: Effect on Batch Norm
and its variants only.
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed.
zero_init_residual (bool): whether to use zero init for last norm layer
in resblocks to let them behave as identity.
"""
arch_settings = {
50: (Bottleneck, (3, 4, 6, 3)),
101: (Bottleneck, (3, 4, 23, 3)),
152: (Bottleneck, (3, 8, 36, 3))
}
def __init__(self, groups=1, base_width=4, **kwargs):
self.groups = groups
self.base_width = base_width
super(ResNeXt, self).__init__(**kwargs)
def make_res_layer(self, **kwargs):
"""Pack all blocks in a stage into a ``ResLayer``"""
return ResLayer(
groups=self.groups,
base_width=self.base_width,
base_channels=self.base_channels,
**kwargs)
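# Width sketch (an illustrative example, not part of the original file): for
# ResNeXt-50 32x4d, each bottleneck's 3x3 conv width is
# floor(planes * base_width / base_channels) * groups, e.g. the first stage
# uses floor(64 * 4 / 64) * 32 = 128 grouped channels.
if __name__ == '__main__':
    import torch
    model = ResNeXt(depth=50, groups=32, base_width=4)
    model.eval()
    with torch.no_grad():
        outs = model(torch.rand(1, 3, 64, 64))
    print([tuple(out.shape) for out in outs])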
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import VGG, constant_init, kaiming_init, normal_init, xavier_init
from mmcv.runner import load_checkpoint
from mmdet.utils import get_root_logger
from ..builder import BACKBONES
@BACKBONES.register_module()
class SSDVGG(VGG):
"""VGG Backbone network for single-shot-detection.
Args:
input_size (int): width and height of input, from {300, 512}.
depth (int): Depth of vgg, from {11, 13, 16, 19}.
out_indices (Sequence[int]): Output from which stages.
Example:
>>> self = SSDVGG(input_size=300, depth=11)
>>> self.eval()
>>> inputs = torch.rand(1, 3, 300, 300)
>>> level_outputs = self.forward(inputs)
>>> for level_out in level_outputs:
... print(tuple(level_out.shape))
(1, 1024, 19, 19)
(1, 512, 10, 10)
(1, 256, 5, 5)
(1, 256, 3, 3)
(1, 256, 1, 1)
"""
extra_setting = {
300: (256, 'S', 512, 128, 'S', 256, 128, 256, 128, 256),
512: (256, 'S', 512, 128, 'S', 256, 128, 'S', 256, 128, 'S', 256, 128),
}
def __init__(self,
input_size,
depth,
with_last_pool=False,
ceil_mode=True,
out_indices=(3, 4),
out_feature_indices=(22, 34),
l2_norm_scale=20.):
# TODO: in_channels for mmcv.VGG
super(SSDVGG, self).__init__(
depth,
with_last_pool=with_last_pool,
ceil_mode=ceil_mode,
out_indices=out_indices)
assert input_size in (300, 512)
self.input_size = input_size
self.features.add_module(
str(len(self.features)),
nn.MaxPool2d(kernel_size=3, stride=1, padding=1))
self.features.add_module(
str(len(self.features)),
nn.Conv2d(512, 1024, kernel_size=3, padding=6, dilation=6))
self.features.add_module(
str(len(self.features)), nn.ReLU(inplace=True))
self.features.add_module(
str(len(self.features)), nn.Conv2d(1024, 1024, kernel_size=1))
self.features.add_module(
str(len(self.features)), nn.ReLU(inplace=True))
self.out_feature_indices = out_feature_indices
self.inplanes = 1024
self.extra = self._make_extra_layers(self.extra_setting[input_size])
self.l2_norm = L2Norm(
self.features[out_feature_indices[0] - 1].out_channels,
l2_norm_scale)
def init_weights(self, pretrained=None):
"""Initialize the weights in backbone.
Args:
pretrained (str, optional): Path to pre-trained weights.
Defaults to None.
"""
if isinstance(pretrained, str):
logger = get_root_logger()
load_checkpoint(self, pretrained, strict=False, logger=logger)
elif pretrained is None:
for m in self.features.modules():
if isinstance(m, nn.Conv2d):
kaiming_init(m)
elif isinstance(m, nn.BatchNorm2d):
constant_init(m, 1)
elif isinstance(m, nn.Linear):
normal_init(m, std=0.01)
else:
raise TypeError('pretrained must be a str or None')
for m in self.extra.modules():
if isinstance(m, nn.Conv2d):
xavier_init(m, distribution='uniform')
constant_init(self.l2_norm, self.l2_norm.scale)
def forward(self, x):
"""Forward function."""
outs = []
for i, layer in enumerate(self.features):
x = layer(x)
if i in self.out_feature_indices:
outs.append(x)
for i, layer in enumerate(self.extra):
x = F.relu(layer(x), inplace=True)
if i % 2 == 1:
outs.append(x)
outs[0] = self.l2_norm(outs[0])
if len(outs) == 1:
return outs[0]
else:
return tuple(outs)
def _make_extra_layers(self, outplanes):
layers = []
kernel_sizes = (1, 3)
num_layers = 0
outplane = None
for i in range(len(outplanes)):
if self.inplanes == 'S':
self.inplanes = outplane
continue
k = kernel_sizes[num_layers % 2]
if outplanes[i] == 'S':
outplane = outplanes[i + 1]
conv = nn.Conv2d(
self.inplanes, outplane, k, stride=2, padding=1)
else:
outplane = outplanes[i]
conv = nn.Conv2d(
self.inplanes, outplane, k, stride=1, padding=0)
layers.append(conv)
self.inplanes = outplanes[i]
num_layers += 1
if self.input_size == 512:
layers.append(nn.Conv2d(self.inplanes, 256, 4, padding=1))
return nn.Sequential(*layers)
class L2Norm(nn.Module):
def __init__(self, n_dims, scale=20., eps=1e-10):
"""L2 normalization layer.
Args:
n_dims (int): Number of dimensions to be normalized
scale (float, optional): Defaults to 20..
eps (float, optional): Used to avoid division by zero.
Defaults to 1e-10.
"""
super(L2Norm, self).__init__()
self.n_dims = n_dims
self.weight = nn.Parameter(torch.Tensor(self.n_dims))
self.eps = eps
self.scale = scale
def forward(self, x):
"""Forward function."""
        # the normalization layer is computed in FP32 during FP16 training
x_float = x.float()
norm = x_float.pow(2).sum(1, keepdim=True).sqrt() + self.eps
return (self.weight[None, :, None, None].float().expand_as(x_float) *
x_float / norm).type_as(x)
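# Behavior sketch for L2Norm (an illustrative example, not part of the
# original file): each spatial position is rescaled to unit L2 norm over the
# channel dimension, then multiplied by a learnable per-channel weight that
# SSDVGG.init_weights sets to ``scale`` via constant_init.
if __name__ == '__main__':
    l2_norm = L2Norm(n_dims=512, scale=20.)
    nn.init.constant_(l2_norm.weight, l2_norm.scale)
    x = torch.rand(2, 512, 38, 38)
    y = l2_norm(x)
    # the channel-wise norm at every position is now ~scale
    print(y.pow(2).sum(1).sqrt().mean())  # ~20.0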
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as cp
from mmcv.cnn import build_conv_layer, build_norm_layer, kaiming_init
from torch.nn.modules.utils import _pair
from mmdet.models.backbones.resnet import Bottleneck, ResNet
from mmdet.models.builder import BACKBONES
class TridentConv(nn.Module):
"""Trident Convolution Module.
Args:
in_channels (int): Number of channels in input.
out_channels (int): Number of channels in output.
kernel_size (int): Size of convolution kernel.
stride (int, optional): Convolution stride. Default: 1.
        trident_dilations (tuple[int, int, int], optional): Dilations of the
            different trident branches. Default: (1, 2, 3).
test_branch_idx (int, optional): In inference, all 3 branches will
be used if `test_branch_idx==-1`, otherwise only branch with
index `test_branch_idx` will be used. Default: 1.
bias (bool, optional): Whether to use bias in convolution or not.
Default: False.
"""
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride=1,
trident_dilations=(1, 2, 3),
test_branch_idx=1,
bias=False):
super(TridentConv, self).__init__()
self.num_branch = len(trident_dilations)
self.with_bias = bias
self.test_branch_idx = test_branch_idx
self.stride = _pair(stride)
self.kernel_size = _pair(kernel_size)
self.paddings = _pair(trident_dilations)
self.dilations = trident_dilations
self.in_channels = in_channels
self.out_channels = out_channels
self.bias = bias
self.weight = nn.Parameter(
torch.Tensor(out_channels, in_channels, *self.kernel_size))
if bias:
self.bias = nn.Parameter(torch.Tensor(out_channels))
else:
self.bias = None
self.init_weights()
def init_weights(self):
kaiming_init(self, distribution='uniform', mode='fan_in')
def extra_repr(self):
tmpstr = f'in_channels={self.in_channels}'
tmpstr += f', out_channels={self.out_channels}'
tmpstr += f', kernel_size={self.kernel_size}'
tmpstr += f', num_branch={self.num_branch}'
tmpstr += f', test_branch_idx={self.test_branch_idx}'
tmpstr += f', stride={self.stride}'
tmpstr += f', paddings={self.paddings}'
tmpstr += f', dilations={self.dilations}'
tmpstr += f', bias={self.bias}'
return tmpstr
def forward(self, inputs):
if self.training or self.test_branch_idx == -1:
outputs = [
F.conv2d(input, self.weight, self.bias, self.stride, padding,
dilation) for input, dilation, padding in zip(
inputs, self.dilations, self.paddings)
]
else:
assert len(inputs) == 1
outputs = [
F.conv2d(inputs[0], self.weight, self.bias, self.stride,
self.paddings[self.test_branch_idx],
self.dilations[self.test_branch_idx])
]
return outputs
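# Branch sketch (an illustrative example, not part of the original file): in
# training, the same 3x3 weight is applied to every branch with a different
# dilation; padding == dilation keeps the spatial size unchanged.
if __name__ == '__main__':
    conv = TridentConv(8, 8, kernel_size=3, trident_dilations=(1, 2, 3))
    conv.train()
    branches = [torch.rand(1, 8, 16, 16) for _ in range(3)]
    outs = conv(branches)
    print([tuple(out.shape) for out in outs])  # three (1, 8, 16, 16) tensors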
# Since TridentNet is defined over ResNet50 and ResNet101, here we
# only support TridentBottleneckBlock.
class TridentBottleneck(Bottleneck):
"""BottleBlock for TridentResNet.
Args:
        trident_dilations (tuple[int, int, int]): Dilations of the different
            trident branches.
test_branch_idx (int): In inference, all 3 branches will be used
if `test_branch_idx==-1`, otherwise only branch with index
`test_branch_idx` will be used.
concat_output (bool): Whether to concat the output list to a Tensor.
`True` only in the last Block.
"""
def __init__(self, trident_dilations, test_branch_idx, concat_output,
**kwargs):
super(TridentBottleneck, self).__init__(**kwargs)
self.trident_dilations = trident_dilations
self.num_branch = len(trident_dilations)
self.concat_output = concat_output
self.test_branch_idx = test_branch_idx
self.conv2 = TridentConv(
self.planes,
self.planes,
kernel_size=3,
stride=self.conv2_stride,
bias=False,
trident_dilations=self.trident_dilations,
test_branch_idx=test_branch_idx)
def forward(self, x):
def _inner_forward(x):
num_branch = (
self.num_branch
if self.training or self.test_branch_idx == -1 else 1)
identity = x
if not isinstance(x, list):
x = (x, ) * num_branch
identity = x
if self.downsample is not None:
identity = [self.downsample(b) for b in x]
out = [self.conv1(b) for b in x]
out = [self.norm1(b) for b in out]
out = [self.relu(b) for b in out]
if self.with_plugins:
for k in range(len(out)):
out[k] = self.forward_plugin(out[k],
self.after_conv1_plugin_names)
out = self.conv2(out)
out = [self.norm2(b) for b in out]
out = [self.relu(b) for b in out]
if self.with_plugins:
for k in range(len(out)):
out[k] = self.forward_plugin(out[k],
self.after_conv2_plugin_names)
out = [self.conv3(b) for b in out]
out = [self.norm3(b) for b in out]
if self.with_plugins:
for k in range(len(out)):
out[k] = self.forward_plugin(out[k],
self.after_conv3_plugin_names)
out = [
out_b + identity_b for out_b, identity_b in zip(out, identity)
]
return out
if self.with_cp and x.requires_grad:
out = cp.checkpoint(_inner_forward, x)
else:
out = _inner_forward(x)
out = [self.relu(b) for b in out]
if self.concat_output:
out = torch.cat(out, dim=0)
return out
def make_trident_res_layer(block,
inplanes,
planes,
num_blocks,
stride=1,
trident_dilations=(1, 2, 3),
style='pytorch',
with_cp=False,
conv_cfg=None,
norm_cfg=dict(type='BN'),
dcn=None,
plugins=None,
test_branch_idx=-1):
"""Build Trident Res Layers."""
downsample = None
if stride != 1 or inplanes != planes * block.expansion:
downsample = []
conv_stride = stride
downsample.extend([
build_conv_layer(
conv_cfg,
inplanes,
planes * block.expansion,
kernel_size=1,
stride=conv_stride,
bias=False),
build_norm_layer(norm_cfg, planes * block.expansion)[1]
])
downsample = nn.Sequential(*downsample)
layers = []
for i in range(num_blocks):
layers.append(
block(
inplanes=inplanes,
planes=planes,
stride=stride if i == 0 else 1,
trident_dilations=trident_dilations,
downsample=downsample if i == 0 else None,
style=style,
with_cp=with_cp,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
dcn=dcn,
plugins=plugins,
test_branch_idx=test_branch_idx,
                concat_output=(i == num_blocks - 1)))
inplanes = planes * block.expansion
return nn.Sequential(*layers)
@BACKBONES.register_module()
class TridentResNet(ResNet):
"""The stem layer, stage 1 and stage 2 in Trident ResNet are identical to
ResNet, while in stage 3, Trident BottleBlock is utilized to replace the
normal BottleBlock to yield trident output. Different branch shares the
convolution weight but uses different dilations to achieve multi-scale
output.
/ stage3(b0) \
x - stem - stage1 - stage2 - stage3(b1) - output
\ stage3(b2) /
Args:
depth (int): Depth of resnet, from {50, 101, 152}.
num_branch (int): Number of branches in TridentNet.
test_branch_idx (int): In inference, all 3 branches will be used
if `test_branch_idx==-1`, otherwise only branch with index
`test_branch_idx` will be used.
trident_dilations (tuple[int]): Dilations of different trident branch.
len(trident_dilations) should be equal to num_branch.
""" # noqa
def __init__(self, depth, num_branch, test_branch_idx, trident_dilations,
**kwargs):
assert num_branch == len(trident_dilations)
assert depth in (50, 101, 152)
super(TridentResNet, self).__init__(depth, **kwargs)
assert self.num_stages == 3
self.test_branch_idx = test_branch_idx
self.num_branch = num_branch
last_stage_idx = self.num_stages - 1
stride = self.strides[last_stage_idx]
dilation = trident_dilations
dcn = self.dcn if self.stage_with_dcn[last_stage_idx] else None
if self.plugins is not None:
stage_plugins = self.make_stage_plugins(self.plugins,
last_stage_idx)
else:
stage_plugins = None
planes = self.base_channels * 2**last_stage_idx
res_layer = make_trident_res_layer(
TridentBottleneck,
inplanes=(self.block.expansion * self.base_channels *
2**(last_stage_idx - 1)),
planes=planes,
num_blocks=self.stage_blocks[last_stage_idx],
stride=stride,
trident_dilations=dilation,
style=self.style,
with_cp=self.with_cp,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
dcn=dcn,
plugins=stage_plugins,
test_branch_idx=self.test_branch_idx)
layer_name = f'layer{last_stage_idx + 1}'
self.__setattr__(layer_name, res_layer)
self.res_layers.pop(last_stage_idx)
self.res_layers.insert(last_stage_idx, layer_name)
self._freeze_stages()
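# Usage sketch (an illustrative example, not part of the original file):
# TridentResNet must be built with ``num_stages=3``; in training mode the
# last stage concatenates the branches along the batch axis.
if __name__ == '__main__':
    model = TridentResNet(
        depth=50,
        num_branch=3,
        test_branch_idx=1,
        trident_dilations=(1, 2, 3),
        num_stages=3,
        strides=(1, 2, 2),
        dilations=(1, 1, 1),
        out_indices=(2, ))
    model.train()
    with torch.no_grad():
        outs = model(torch.rand(1, 3, 64, 64))
    print(tuple(outs[0].shape))  # batch axis becomes num_branch * N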
import warnings
from mmcv.utils import Registry, build_from_cfg
from torch import nn
BACKBONES = Registry('backbone')
NECKS = Registry('neck')
ROI_EXTRACTORS = Registry('roi_extractor')
SHARED_HEADS = Registry('shared_head')
HEADS = Registry('head')
LOSSES = Registry('loss')
DETECTORS = Registry('detector')
def build(cfg, registry, default_args=None):
"""Build a module.
Args:
        cfg (dict, list[dict]): The config of modules; it is either a dict
            or a list of config dicts.
registry (:obj:`Registry`): A registry the module belongs to.
default_args (dict, optional): Default arguments to build the module.
Defaults to None.
Returns:
nn.Module: A built nn module.
"""
if isinstance(cfg, list):
modules = [
build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg
]
return nn.Sequential(*modules)
else:
return build_from_cfg(cfg, registry, default_args)
def build_backbone(cfg):
"""Build backbone."""
return build(cfg, BACKBONES)
def build_neck(cfg):
"""Build neck."""
return build(cfg, NECKS)
def build_roi_extractor(cfg):
"""Build roi extractor."""
return build(cfg, ROI_EXTRACTORS)
def build_shared_head(cfg):
"""Build shared head."""
return build(cfg, SHARED_HEADS)
def build_head(cfg):
"""Build head."""
return build(cfg, HEADS)
def build_loss(cfg):
"""Build loss."""
return build(cfg, LOSSES)
def build_detector(cfg, train_cfg=None, test_cfg=None):
"""Build detector."""
if train_cfg is not None or test_cfg is not None:
warnings.warn(
            'train_cfg and test_cfg are deprecated, '
            'please specify them in the model config', UserWarning)
assert cfg.get('train_cfg') is None or train_cfg is None, \
'train_cfg specified in both outer field and model field '
assert cfg.get('test_cfg') is None or test_cfg is None, \
'test_cfg specified in both outer field and model field '
return build(cfg, DETECTORS, dict(train_cfg=train_cfg, test_cfg=test_cfg))
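# Usage sketch (an illustrative example, not part of the original file): any
# registered module can be built from a config dict, and a list of configs
# yields an nn.Sequential. Importing mmdet.models first is assumed so that
# the @register_module decorators have run.
if __name__ == '__main__':
    import mmdet.models  # noqa: F401, triggers module registration
    backbone = build_backbone(dict(type='ResNet', depth=18))
    print(type(backbone).__name__)  # ResNet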
from .anchor_free_head import AnchorFreeHead
from .anchor_head import AnchorHead
from .atss_head import ATSSHead
from .cascade_rpn_head import CascadeRPNHead, StageCascadeRPNHead
from .centripetal_head import CentripetalHead
from .corner_head import CornerHead
from .embedding_rpn_head import EmbeddingRPNHead
from .fcos_head import FCOSHead
from .fovea_head import FoveaHead
from .free_anchor_retina_head import FreeAnchorRetinaHead
from .fsaf_head import FSAFHead
from .ga_retina_head import GARetinaHead
from .ga_rpn_head import GARPNHead
from .gfl_head import GFLHead
from .guided_anchor_head import FeatureAdaption, GuidedAnchorHead
from .nasfcos_head import NASFCOSHead
from .paa_head import PAAHead
from .pisa_retinanet_head import PISARetinaHead
from .pisa_ssd_head import PISASSDHead
from .reppoints_head import RepPointsHead
from .retina_head import RetinaHead
from .retina_sepbn_head import RetinaSepBNHead
from .rpn_head import RPNHead
from .sabl_retina_head import SABLRetinaHead
from .ssd_head import SSDHead
from .transformer_head import TransformerHead
from .vfnet_head import VFNetHead
from .yolact_head import YOLACTHead, YOLACTProtonet, YOLACTSegmHead
from .yolo_head import YOLOV3Head
__all__ = [
'AnchorFreeHead', 'AnchorHead', 'GuidedAnchorHead', 'FeatureAdaption',
'RPNHead', 'GARPNHead', 'RetinaHead', 'RetinaSepBNHead', 'GARetinaHead',
'SSDHead', 'FCOSHead', 'RepPointsHead', 'FoveaHead',
'FreeAnchorRetinaHead', 'ATSSHead', 'FSAFHead', 'NASFCOSHead',
'PISARetinaHead', 'PISASSDHead', 'GFLHead', 'CornerHead', 'YOLACTHead',
'YOLACTSegmHead', 'YOLACTProtonet', 'YOLOV3Head', 'PAAHead',
'SABLRetinaHead', 'CentripetalHead', 'VFNetHead', 'TransformerHead',
'StageCascadeRPNHead', 'CascadeRPNHead', 'EmbeddingRPNHead'
]
from abc import abstractmethod
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule, bias_init_with_prob, normal_init
from mmcv.runner import force_fp32
from mmdet.core import multi_apply
from ..builder import HEADS, build_loss
from .base_dense_head import BaseDenseHead
from .dense_test_mixins import BBoxTestMixin
@HEADS.register_module()
class AnchorFreeHead(BaseDenseHead, BBoxTestMixin):
"""Anchor-free head (FCOS, Fovea, RepPoints, etc.).
Args:
num_classes (int): Number of categories excluding the background
category.
in_channels (int): Number of channels in the input feature map.
feat_channels (int): Number of hidden channels. Used in child classes.
stacked_convs (int): Number of stacking convs of the head.
strides (tuple): Downsample factor of each feature map.
dcn_on_last_conv (bool): If true, use dcn in the last layer of
towers. Default: False.
conv_bias (bool | str): If specified as `auto`, it will be decided by
the norm_cfg. Bias of conv will be set as True if `norm_cfg` is
None, otherwise False. Default: "auto".
loss_cls (dict): Config of classification loss.
loss_bbox (dict): Config of localization loss.
conv_cfg (dict): Config dict for convolution layer. Default: None.
norm_cfg (dict): Config dict for normalization layer. Default: None.
train_cfg (dict): Training config of anchor head.
test_cfg (dict): Testing config of anchor head.
""" # noqa: W605
_version = 1
def __init__(self,
num_classes,
in_channels,
feat_channels=256,
stacked_convs=4,
strides=(4, 8, 16, 32, 64),
dcn_on_last_conv=False,
conv_bias='auto',
loss_cls=dict(
type='FocalLoss',
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
loss_weight=1.0),
loss_bbox=dict(type='IoULoss', loss_weight=1.0),
conv_cfg=None,
norm_cfg=None,
train_cfg=None,
test_cfg=None):
super(AnchorFreeHead, self).__init__()
self.num_classes = num_classes
self.cls_out_channels = num_classes
self.in_channels = in_channels
self.feat_channels = feat_channels
self.stacked_convs = stacked_convs
self.strides = strides
self.dcn_on_last_conv = dcn_on_last_conv
assert conv_bias == 'auto' or isinstance(conv_bias, bool)
self.conv_bias = conv_bias
self.loss_cls = build_loss(loss_cls)
self.loss_bbox = build_loss(loss_bbox)
self.train_cfg = train_cfg
self.test_cfg = test_cfg
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
self.fp16_enabled = False
self._init_layers()
def _init_layers(self):
"""Initialize layers of the head."""
self._init_cls_convs()
self._init_reg_convs()
self._init_predictor()
def _init_cls_convs(self):
"""Initialize classification conv layers of the head."""
self.cls_convs = nn.ModuleList()
for i in range(self.stacked_convs):
chn = self.in_channels if i == 0 else self.feat_channels
if self.dcn_on_last_conv and i == self.stacked_convs - 1:
conv_cfg = dict(type='DCNv2')
else:
conv_cfg = self.conv_cfg
self.cls_convs.append(
ConvModule(
chn,
self.feat_channels,
3,
stride=1,
padding=1,
conv_cfg=conv_cfg,
norm_cfg=self.norm_cfg,
bias=self.conv_bias))
def _init_reg_convs(self):
"""Initialize bbox regression conv layers of the head."""
self.reg_convs = nn.ModuleList()
for i in range(self.stacked_convs):
chn = self.in_channels if i == 0 else self.feat_channels
if self.dcn_on_last_conv and i == self.stacked_convs - 1:
conv_cfg = dict(type='DCNv2')
else:
conv_cfg = self.conv_cfg
self.reg_convs.append(
ConvModule(
chn,
self.feat_channels,
3,
stride=1,
padding=1,
conv_cfg=conv_cfg,
norm_cfg=self.norm_cfg,
bias=self.conv_bias))
def _init_predictor(self):
"""Initialize predictor layers of the head."""
self.conv_cls = nn.Conv2d(
self.feat_channels, self.cls_out_channels, 3, padding=1)
self.conv_reg = nn.Conv2d(self.feat_channels, 4, 3, padding=1)
def init_weights(self):
"""Initialize weights of the head."""
for m in self.cls_convs:
if isinstance(m.conv, nn.Conv2d):
normal_init(m.conv, std=0.01)
for m in self.reg_convs:
if isinstance(m.conv, nn.Conv2d):
normal_init(m.conv, std=0.01)
bias_cls = bias_init_with_prob(0.01)
normal_init(self.conv_cls, std=0.01, bias=bias_cls)
normal_init(self.conv_reg, std=0.01)
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs):
"""Hack some keys of the model state dict so that can load checkpoints
of previous version."""
version = local_metadata.get('version', None)
if version is None:
# the key is different in early versions
# for example, 'fcos_cls' become 'conv_cls' now
bbox_head_keys = [
k for k in state_dict.keys() if k.startswith(prefix)
]
ori_predictor_keys = []
new_predictor_keys = []
# e.g. 'fcos_cls' or 'fcos_reg'
for key in bbox_head_keys:
ori_predictor_keys.append(key)
key = key.split('.')
conv_name = None
if key[1].endswith('cls'):
conv_name = 'conv_cls'
elif key[1].endswith('reg'):
conv_name = 'conv_reg'
elif key[1].endswith('centerness'):
conv_name = 'conv_centerness'
                else:
                    # keys that do not match any predictor keep conv_name as
                    # None and are dropped from the rename list below
                    pass
if conv_name is not None:
key[1] = conv_name
new_predictor_keys.append('.'.join(key))
else:
ori_predictor_keys.pop(-1)
for i in range(len(new_predictor_keys)):
state_dict[new_predictor_keys[i]] = state_dict.pop(
ori_predictor_keys[i])
super()._load_from_state_dict(state_dict, prefix, local_metadata,
strict, missing_keys, unexpected_keys,
error_msgs)
def forward(self, feats):
"""Forward features from the upstream network.
Args:
feats (tuple[Tensor]): Features from the upstream network, each is
a 4D-tensor.
Returns:
tuple: Usually contain classification scores and bbox predictions.
cls_scores (list[Tensor]): Box scores for each scale level,
each is a 4D-tensor, the channel number is
num_points * num_classes.
bbox_preds (list[Tensor]): Box energies / deltas for each scale
level, each is a 4D-tensor, the channel number is
num_points * 4.
"""
return multi_apply(self.forward_single, feats)[:2]
def forward_single(self, x):
"""Forward features of a single scale levle.
Args:
x (Tensor): FPN feature maps of the specified stride.
        Returns:
            tuple: Scores for each class and bbox predictions, plus the
                features after the classification and regression conv layers.
                Some models, such as FCOS, need these features.
        """
cls_feat = x
reg_feat = x
for cls_layer in self.cls_convs:
cls_feat = cls_layer(cls_feat)
cls_score = self.conv_cls(cls_feat)
for reg_layer in self.reg_convs:
reg_feat = reg_layer(reg_feat)
bbox_pred = self.conv_reg(reg_feat)
return cls_score, bbox_pred, cls_feat, reg_feat
@abstractmethod
@force_fp32(apply_to=('cls_scores', 'bbox_preds'))
def loss(self,
cls_scores,
bbox_preds,
gt_bboxes,
gt_labels,
img_metas,
gt_bboxes_ignore=None):
"""Compute loss of the head.
Args:
cls_scores (list[Tensor]): Box scores for each scale level,
each is a 4D-tensor, the channel number is
num_points * num_classes.
bbox_preds (list[Tensor]): Box energies / deltas for each scale
level, each is a 4D-tensor, the channel number is
num_points * 4.
gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
gt_labels (list[Tensor]): class indices corresponding to each box
img_metas (list[dict]): Meta information of each image, e.g.,
image size, scaling factor, etc.
gt_bboxes_ignore (None | list[Tensor]): specify which bounding
boxes can be ignored when computing the loss.
"""
raise NotImplementedError
@abstractmethod
@force_fp32(apply_to=('cls_scores', 'bbox_preds'))
def get_bboxes(self,
cls_scores,
bbox_preds,
img_metas,
cfg=None,
rescale=None):
"""Transform network output for a batch into bbox predictions.
Args:
cls_scores (list[Tensor]): Box scores for each scale level
Has shape (N, num_points * num_classes, H, W)
bbox_preds (list[Tensor]): Box energies / deltas for each scale
level with shape (N, num_points * 4, H, W)
img_metas (list[dict]): Meta information of each image, e.g.,
image size, scaling factor, etc.
cfg (mmcv.Config): Test / postprocessing configuration,
if None, test_cfg would be used
rescale (bool): If True, return boxes in original image space
"""
raise NotImplementedError
@abstractmethod
def get_targets(self, points, gt_bboxes_list, gt_labels_list):
"""Compute regression, classification and centerss targets for points
in multiple images.
Args:
points (list[Tensor]): Points of each fpn level, each has shape
(num_points, 2).
gt_bboxes_list (list[Tensor]): Ground truth bboxes of each image,
each has shape (num_gt, 4).
gt_labels_list (list[Tensor]): Ground truth labels of each box,
each has shape (num_gt,).
"""
raise NotImplementedError
def _get_points_single(self,
featmap_size,
stride,
dtype,
device,
flatten=False):
"""Get points of a single scale level."""
h, w = featmap_size
x_range = torch.arange(w, dtype=dtype, device=device)
y_range = torch.arange(h, dtype=dtype, device=device)
y, x = torch.meshgrid(y_range, x_range)
if flatten:
y = y.flatten()
x = x.flatten()
return y, x
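    # Point layout sketch (illustrative comment, not part of the original
    # code): for featmap_size=(2, 3), ``torch.meshgrid`` yields
    #   y = [[0, 0, 0], [1, 1, 1]] and x = [[0, 1, 2], [0, 1, 2]];
    # subclasses scale these grid indices by ``stride`` to image coordinates.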
def get_points(self, featmap_sizes, dtype, device, flatten=False):
"""Get points according to feature map sizes.
Args:
featmap_sizes (list[tuple]): Multi-level feature map sizes.
dtype (torch.dtype): Type of points.
device (torch.device): Device of points.
        Returns:
            list: Points of each feature map level.
"""
mlvl_points = []
for i in range(len(featmap_sizes)):
mlvl_points.append(
self._get_points_single(featmap_sizes[i], self.strides[i],
dtype, device, flatten))
return mlvl_points
def aug_test(self, feats, img_metas, rescale=False):
"""Test function with test time augmentation.
Args:
feats (list[Tensor]): the outer list indicates test-time
augmentations and inner Tensor should have a shape NxCxHxW,
which contains features for all images in the batch.
img_metas (list[list[dict]]): the outer list indicates test-time
augs (multiscale, flip, etc.) and the inner list indicates
images in a batch. each dict has image information.
rescale (bool, optional): Whether to rescale the results.
Defaults to False.
Returns:
list[ndarray]: bbox results of each class
"""
return self.aug_test_bboxes(feats, img_metas, rescale=rescale)
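# Forward sketch (an illustrative example, not part of the original file):
# loss/get_bboxes are abstract here, but the shared towers can be probed
# directly; each level yields per-point class scores and 4 box energies.
if __name__ == '__main__':
    head = AnchorFreeHead(num_classes=80, in_channels=256)
    head.init_weights()
    feats = tuple(torch.rand(1, 256, s, s) for s in (64, 32, 16, 8, 4))
    cls_scores, bbox_preds = head(feats)
    print([tuple(s.shape) for s in cls_scores])  # channel dim is 80
    print([tuple(b.shape) for b in bbox_preds])  # channel dim is 4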
import torch
import torch.nn as nn
from mmcv.cnn import normal_init
from mmcv.runner import force_fp32
from mmdet.core import (anchor_inside_flags, build_anchor_generator,
build_assigner, build_bbox_coder, build_sampler,
images_to_levels, multi_apply, multiclass_nms, unmap)
from ..builder import HEADS, build_loss
from .base_dense_head import BaseDenseHead
from .dense_test_mixins import BBoxTestMixin
@HEADS.register_module()
class AnchorHead(BaseDenseHead, BBoxTestMixin):
"""Anchor-based head (RPN, RetinaNet, SSD, etc.).
Args:
num_classes (int): Number of categories excluding the background
category.
in_channels (int): Number of channels in the input feature map.
feat_channels (int): Number of hidden channels. Used in child classes.
anchor_generator (dict): Config dict for anchor generator
bbox_coder (dict): Config of bounding box coder.
reg_decoded_bbox (bool): If true, the regression loss would be
applied directly on decoded bounding boxes, converting both
the predicted boxes and regression targets to absolute
coordinates format. Default False. It should be `True` when
using `IoULoss`, `GIoULoss`, or `DIoULoss` in the bbox head.
loss_cls (dict): Config of classification loss.
loss_bbox (dict): Config of localization loss.
train_cfg (dict): Training config of anchor head.
test_cfg (dict): Testing config of anchor head.
""" # noqa: W605
def __init__(self,
num_classes,
in_channels,
feat_channels=256,
anchor_generator=dict(
type='AnchorGenerator',
scales=[8, 16, 32],
ratios=[0.5, 1.0, 2.0],
strides=[4, 8, 16, 32, 64]),
bbox_coder=dict(
type='DeltaXYWHBBoxCoder',
clip_border=True,
target_means=(.0, .0, .0, .0),
target_stds=(1.0, 1.0, 1.0, 1.0)),
reg_decoded_bbox=False,
loss_cls=dict(
type='CrossEntropyLoss',
use_sigmoid=True,
loss_weight=1.0),
loss_bbox=dict(
type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0),
train_cfg=None,
test_cfg=None):
super(AnchorHead, self).__init__()
self.in_channels = in_channels
self.num_classes = num_classes
self.feat_channels = feat_channels
self.use_sigmoid_cls = loss_cls.get('use_sigmoid', False)
# TODO better way to determine whether sample or not
self.sampling = loss_cls['type'] not in [
'FocalLoss', 'GHMC', 'QualityFocalLoss'
]
if self.use_sigmoid_cls:
self.cls_out_channels = num_classes
else:
self.cls_out_channels = num_classes + 1
if self.cls_out_channels <= 0:
raise ValueError(f'num_classes={num_classes} is too small')
self.reg_decoded_bbox = reg_decoded_bbox
self.bbox_coder = build_bbox_coder(bbox_coder)
self.loss_cls = build_loss(loss_cls)
self.loss_bbox = build_loss(loss_bbox)
self.train_cfg = train_cfg
self.test_cfg = test_cfg
if self.train_cfg:
self.assigner = build_assigner(self.train_cfg.assigner)
# use PseudoSampler when sampling is False
if self.sampling and hasattr(self.train_cfg, 'sampler'):
sampler_cfg = self.train_cfg.sampler
else:
sampler_cfg = dict(type='PseudoSampler')
self.sampler = build_sampler(sampler_cfg, context=self)
self.fp16_enabled = False
self.anchor_generator = build_anchor_generator(anchor_generator)
# usually the numbers of anchors for each level are the same
# except SSD detectors
self.num_anchors = self.anchor_generator.num_base_anchors[0]
self._init_layers()
def _init_layers(self):
"""Initialize layers of the head."""
self.conv_cls = nn.Conv2d(self.in_channels,
self.num_anchors * self.cls_out_channels, 1)
self.conv_reg = nn.Conv2d(self.in_channels, self.num_anchors * 4, 1)
def init_weights(self):
"""Initialize weights of the head."""
normal_init(self.conv_cls, std=0.01)
normal_init(self.conv_reg, std=0.01)
def forward_single(self, x):
"""Forward feature of a single scale level.
Args:
x (Tensor): Features of a single scale level.
Returns:
tuple:
cls_score (Tensor): Cls scores for a single scale level \
the channels number is num_anchors * num_classes.
bbox_pred (Tensor): Box energies / deltas for a single scale \
level, the channels number is num_anchors * 4.
"""
cls_score = self.conv_cls(x)
bbox_pred = self.conv_reg(x)
return cls_score, bbox_pred
def forward(self, feats):
"""Forward features from the upstream network.
Args:
feats (tuple[Tensor]): Features from the upstream network, each is
a 4D-tensor.
Returns:
tuple: A tuple of classification scores and bbox prediction.
- cls_scores (list[Tensor]): Classification scores for all \
scale levels, each is a 4D-tensor, the channels number \
is num_anchors * num_classes.
- bbox_preds (list[Tensor]): Box energies / deltas for all \
scale levels, each is a 4D-tensor, the channels number \
is num_anchors * 4.
"""
return multi_apply(self.forward_single, feats)
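    # Shape sketch (illustrative comment, not part of the original code):
    # with the default anchor generator (3 scales x 3 ratios = 9 anchors per
    # location) and num_classes=80 with sigmoid cls, forward on a single
    # (1, 256, 32, 32) level returns cls_scores[0] of shape
    # (1, 9 * 80, 32, 32) and bbox_preds[0] of shape (1, 9 * 4, 32, 32).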
def get_anchors(self, featmap_sizes, img_metas, device='cuda'):
"""Get anchors according to feature map sizes.
Args:
featmap_sizes (list[tuple]): Multi-level feature map sizes.
img_metas (list[dict]): Image meta info.
device (torch.device | str): Device for returned tensors
Returns:
tuple:
anchor_list (list[Tensor]): Anchors of each image.
valid_flag_list (list[Tensor]): Valid flags of each image.
"""
num_imgs = len(img_metas)
        # since the feature map sizes of all images are the same, we compute
        # anchors only once
multi_level_anchors = self.anchor_generator.grid_anchors(
featmap_sizes, device)
anchor_list = [multi_level_anchors for _ in range(num_imgs)]
# for each image, we compute valid flags of multi level anchors
valid_flag_list = []
for img_id, img_meta in enumerate(img_metas):
multi_level_flags = self.anchor_generator.valid_flags(
featmap_sizes, img_meta['pad_shape'], device)
valid_flag_list.append(multi_level_flags)
return anchor_list, valid_flag_list
def _get_targets_single(self,
flat_anchors,
valid_flags,
gt_bboxes,
gt_bboxes_ignore,
gt_labels,
img_meta,
label_channels=1,
unmap_outputs=True):
"""Compute regression and classification targets for anchors in a
single image.
Args:
            flat_anchors (Tensor): Multi-level anchors of the image, which are
                concatenated into a single tensor of shape (num_anchors, 4).
            valid_flags (Tensor): Multi-level valid flags of the image,
                which are concatenated into a single tensor of
                shape (num_anchors,).
            gt_bboxes (Tensor): Ground truth bboxes of the image,
                shape (num_gts, 4).
            gt_bboxes_ignore (Tensor): Ground truth bboxes to be
                ignored, shape (num_ignored_gts, 4).
            gt_labels (Tensor): Ground truth labels of each box,
                shape (num_gts,).
            img_meta (dict): Meta info of the image.
            label_channels (int): Channel of label.
unmap_outputs (bool): Whether to map outputs back to the original
set of anchors.
Returns:
            tuple:
                labels (Tensor): Labels of all anchors in the image.
                label_weights (Tensor): Label weights of all anchors.
                bbox_targets (Tensor): BBox targets of all anchors.
                bbox_weights (Tensor): BBox weights of all anchors.
                pos_inds (Tensor): Indices of positive anchors.
                neg_inds (Tensor): Indices of negative anchors.
                sampling_result (:obj:`SamplingResult`): Sampling result.
"""
inside_flags = anchor_inside_flags(flat_anchors, valid_flags,
img_meta['img_shape'][:2],
self.train_cfg.allowed_border)
if not inside_flags.any():
return (None, ) * 7
# assign gt and sample anchors
anchors = flat_anchors[inside_flags, :]
assign_result = self.assigner.assign(
anchors, gt_bboxes, gt_bboxes_ignore,
None if self.sampling else gt_labels)
sampling_result = self.sampler.sample(assign_result, anchors,
gt_bboxes)
num_valid_anchors = anchors.shape[0]
bbox_targets = torch.zeros_like(anchors)
bbox_weights = torch.zeros_like(anchors)
labels = anchors.new_full((num_valid_anchors, ),
self.num_classes,
dtype=torch.long)
label_weights = anchors.new_zeros(num_valid_anchors, dtype=torch.float)
pos_inds = sampling_result.pos_inds
neg_inds = sampling_result.neg_inds
if len(pos_inds) > 0:
if not self.reg_decoded_bbox:
pos_bbox_targets = self.bbox_coder.encode(
sampling_result.pos_bboxes, sampling_result.pos_gt_bboxes)
else:
pos_bbox_targets = sampling_result.pos_gt_bboxes
bbox_targets[pos_inds, :] = pos_bbox_targets
bbox_weights[pos_inds, :] = 1.0
if gt_labels is None:
# Only rpn gives gt_labels as None
# Foreground is the first class since v2.5.0
labels[pos_inds] = 0
else:
labels[pos_inds] = gt_labels[
sampling_result.pos_assigned_gt_inds]
if self.train_cfg.pos_weight <= 0:
label_weights[pos_inds] = 1.0
else:
label_weights[pos_inds] = self.train_cfg.pos_weight
if len(neg_inds) > 0:
label_weights[neg_inds] = 1.0
# map up to original set of anchors
if unmap_outputs:
num_total_anchors = flat_anchors.size(0)
labels = unmap(
labels, num_total_anchors, inside_flags,
fill=self.num_classes) # fill bg label
label_weights = unmap(label_weights, num_total_anchors,
inside_flags)
bbox_targets = unmap(bbox_targets, num_total_anchors, inside_flags)
bbox_weights = unmap(bbox_weights, num_total_anchors, inside_flags)
return (labels, label_weights, bbox_targets, bbox_weights, pos_inds,
neg_inds, sampling_result)
def get_targets(self,
anchor_list,
valid_flag_list,
gt_bboxes_list,
img_metas,
gt_bboxes_ignore_list=None,
gt_labels_list=None,
label_channels=1,
unmap_outputs=True,
return_sampling_results=False):
"""Compute regression and classification targets for anchors in
multiple images.
Args:
anchor_list (list[list[Tensor]]): Multi level anchors of each
image. The outer list indicates images, and the inner list
corresponds to feature levels of the image. Each element of
the inner list is a tensor of shape (num_anchors, 4).
valid_flag_list (list[list[Tensor]]): Multi level valid flags of
each image. The outer list indicates images, and the inner list
corresponds to feature levels of the image. Each element of
the inner list is a tensor of shape (num_anchors, )
gt_bboxes_list (list[Tensor]): Ground truth bboxes of each image.
img_metas (list[dict]): Meta info of each image.
gt_bboxes_ignore_list (list[Tensor]): Ground truth bboxes to be
ignored.
gt_labels_list (list[Tensor]): Ground truth labels of each box.
label_channels (int): Channel of label.
unmap_outputs (bool): Whether to map outputs back to the original
set of anchors.
Returns:
tuple: Usually returns a tuple containing learning targets.
- labels_list (list[Tensor]): Labels of each level.
- label_weights_list (list[Tensor]): Label weights of each \
level.
- bbox_targets_list (list[Tensor]): BBox targets of each level.
- bbox_weights_list (list[Tensor]): BBox weights of each level.
- num_total_pos (int): Number of positive samples in all \
images.
- num_total_neg (int): Number of negative samples in all \
images.
            additional_returns: This function enables user-defined returns
                from `self._get_targets_single`. These returns are currently
                refined to properties at each feature map (i.e. having HxW
                dimension). The results will be concatenated at the end.
        """
num_imgs = len(img_metas)
assert len(anchor_list) == len(valid_flag_list) == num_imgs
# anchor number of multi levels
num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]]
# concat all level anchors to a single tensor
concat_anchor_list = []
concat_valid_flag_list = []
for i in range(num_imgs):
assert len(anchor_list[i]) == len(valid_flag_list[i])
concat_anchor_list.append(torch.cat(anchor_list[i]))
concat_valid_flag_list.append(torch.cat(valid_flag_list[i]))
# compute targets for each image
if gt_bboxes_ignore_list is None:
gt_bboxes_ignore_list = [None for _ in range(num_imgs)]
if gt_labels_list is None:
gt_labels_list = [None for _ in range(num_imgs)]
results = multi_apply(
self._get_targets_single,
concat_anchor_list,
concat_valid_flag_list,
gt_bboxes_list,
gt_bboxes_ignore_list,
gt_labels_list,
img_metas,
label_channels=label_channels,
unmap_outputs=unmap_outputs)
(all_labels, all_label_weights, all_bbox_targets, all_bbox_weights,
pos_inds_list, neg_inds_list, sampling_results_list) = results[:7]
rest_results = list(results[7:]) # user-added return values
# no valid anchors
if any([labels is None for labels in all_labels]):
return None
# sampled anchors of all images
num_total_pos = sum([max(inds.numel(), 1) for inds in pos_inds_list])
num_total_neg = sum([max(inds.numel(), 1) for inds in neg_inds_list])
# split targets to a list w.r.t. multiple levels
labels_list = images_to_levels(all_labels, num_level_anchors)
label_weights_list = images_to_levels(all_label_weights,
num_level_anchors)
bbox_targets_list = images_to_levels(all_bbox_targets,
num_level_anchors)
bbox_weights_list = images_to_levels(all_bbox_weights,
num_level_anchors)
res = (labels_list, label_weights_list, bbox_targets_list,
bbox_weights_list, num_total_pos, num_total_neg)
if return_sampling_results:
res = res + (sampling_results_list, )
for i, r in enumerate(rest_results): # user-added return values
rest_results[i] = images_to_levels(r, num_level_anchors)
return res + tuple(rest_results)
def loss_single(self, cls_score, bbox_pred, anchors, labels, label_weights,
bbox_targets, bbox_weights, num_total_samples):
"""Compute loss of a single scale level.
Args:
cls_score (Tensor): Box scores for each scale level
Has shape (N, num_anchors * num_classes, H, W).
bbox_pred (Tensor): Box energies / deltas for each scale
level with shape (N, num_anchors * 4, H, W).
anchors (Tensor): Box reference for each scale level with shape
(N, num_total_anchors, 4).
labels (Tensor): Labels of each anchors with shape
(N, num_total_anchors).
label_weights (Tensor): Label weights of each anchor with shape
(N, num_total_anchors)
            bbox_targets (Tensor): BBox regression targets of each anchor with
shape (N, num_total_anchors, 4).
bbox_weights (Tensor): BBox regression loss weights of each anchor
with shape (N, num_total_anchors, 4).
            num_total_samples (int): If sampling, the total number of samples
                equals the number of total anchors; otherwise, it is the number
                of positive anchors.
Returns:
dict[str, Tensor]: A dictionary of loss components.
"""
# classification loss
labels = labels.reshape(-1)
label_weights = label_weights.reshape(-1)
cls_score = cls_score.permute(0, 2, 3,
1).reshape(-1, self.cls_out_channels)
loss_cls = self.loss_cls(
cls_score, labels, label_weights, avg_factor=num_total_samples)
# regression loss
bbox_targets = bbox_targets.reshape(-1, 4)
bbox_weights = bbox_weights.reshape(-1, 4)
bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4)
if self.reg_decoded_bbox:
# When the regression loss (e.g. `IouLoss`, `GIouLoss`)
# is applied directly on the decoded bounding boxes, it
# decodes the already encoded coordinates to absolute format.
anchors = anchors.reshape(-1, 4)
bbox_pred = self.bbox_coder.decode(anchors, bbox_pred)
loss_bbox = self.loss_bbox(
bbox_pred,
bbox_targets,
bbox_weights,
avg_factor=num_total_samples)
return loss_cls, loss_bbox
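    # --- Illustrative sketch (not part of mmdet): why ``loss_single`` permutes
    # before flattening. A (N, A*C, H, W) score map must become (N*H*W*A, C) so
    # each row lines up with one anchor's label; a bare ``reshape(-1, C)``
    # without the permute would interleave anchors and classes incorrectly.
    #
    #     import torch
    #
    #     N, A, C, H, W = 2, 3, 4, 5, 5
    #     cls_map = torch.randn(N, A * C, H, W)
    #     flat = cls_map.permute(0, 2, 3, 1).reshape(-1, C)
    #     assert flat.shape == (N * H * W * A, C)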
@force_fp32(apply_to=('cls_scores', 'bbox_preds'))
def loss(self,
cls_scores,
bbox_preds,
gt_bboxes,
gt_labels,
img_metas,
gt_bboxes_ignore=None):
"""Compute losses of the head.
Args:
cls_scores (list[Tensor]): Box scores for each scale level
Has shape (N, num_anchors * num_classes, H, W)
bbox_preds (list[Tensor]): Box energies / deltas for each scale
level with shape (N, num_anchors * 4, H, W)
gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
gt_labels (list[Tensor]): class indices corresponding to each box
img_metas (list[dict]): Meta information of each image, e.g.,
image size, scaling factor, etc.
gt_bboxes_ignore (None | list[Tensor]): specify which bounding
boxes can be ignored when computing the loss. Default: None
Returns:
dict[str, Tensor]: A dictionary of loss components.
"""
featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
assert len(featmap_sizes) == self.anchor_generator.num_levels
device = cls_scores[0].device
anchor_list, valid_flag_list = self.get_anchors(
featmap_sizes, img_metas, device=device)
label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1
cls_reg_targets = self.get_targets(
anchor_list,
valid_flag_list,
gt_bboxes,
img_metas,
gt_bboxes_ignore_list=gt_bboxes_ignore,
gt_labels_list=gt_labels,
label_channels=label_channels)
if cls_reg_targets is None:
return None
(labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
num_total_pos, num_total_neg) = cls_reg_targets
num_total_samples = (
num_total_pos + num_total_neg if self.sampling else num_total_pos)
# anchor number of multi levels
num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]]
# concat all level anchors and flags to a single tensor
concat_anchor_list = []
for i in range(len(anchor_list)):
concat_anchor_list.append(torch.cat(anchor_list[i]))
all_anchor_list = images_to_levels(concat_anchor_list,
num_level_anchors)
losses_cls, losses_bbox = multi_apply(
self.loss_single,
cls_scores,
bbox_preds,
all_anchor_list,
labels_list,
label_weights_list,
bbox_targets_list,
bbox_weights_list,
num_total_samples=num_total_samples)
return dict(loss_cls=losses_cls, loss_bbox=losses_bbox)
@force_fp32(apply_to=('cls_scores', 'bbox_preds'))
def get_bboxes(self,
cls_scores,
bbox_preds,
img_metas,
cfg=None,
rescale=False,
with_nms=True):
"""Transform network output for a batch into bbox predictions.
Args:
cls_scores (list[Tensor]): Box scores for each scale level
Has shape (N, num_anchors * num_classes, H, W)
bbox_preds (list[Tensor]): Box energies / deltas for each scale
level with shape (N, num_anchors * 4, H, W)
img_metas (list[dict]): Meta information of each image, e.g.,
image size, scaling factor, etc.
cfg (mmcv.Config | None): Test / postprocessing configuration,
if None, test_cfg would be used
rescale (bool): If True, return boxes in original image space.
Default: False.
with_nms (bool): If True, do nms before return boxes.
Default: True.
Returns:
list[tuple[Tensor, Tensor]]: Each item in result_list is 2-tuple.
The first item is an (n, 5) tensor, where the first 4 columns
are bounding box positions (tl_x, tl_y, br_x, br_y) and the
5-th column is a score between 0 and 1. The second item is a
                (n,) tensor where each item is the predicted class label of the
corresponding box.
Example:
>>> import mmcv
>>> self = AnchorHead(
>>> num_classes=9,
>>> in_channels=1,
>>> anchor_generator=dict(
>>> type='AnchorGenerator',
>>> scales=[8],
>>> ratios=[0.5, 1.0, 2.0],
>>> strides=[4,]))
>>> img_metas = [{'img_shape': (32, 32, 3), 'scale_factor': 1}]
>>> cfg = mmcv.Config(dict(
>>> score_thr=0.00,
>>> nms=dict(type='nms', iou_thr=1.0),
>>> max_per_img=10))
>>> feat = torch.rand(1, 1, 3, 3)
>>> cls_score, bbox_pred = self.forward_single(feat)
>>> # note the input lists are over different levels, not images
>>> cls_scores, bbox_preds = [cls_score], [bbox_pred]
>>> result_list = self.get_bboxes(cls_scores, bbox_preds,
>>> img_metas, cfg)
>>> det_bboxes, det_labels = result_list[0]
>>> assert len(result_list) == 1
>>> assert det_bboxes.shape[1] == 5
>>> assert len(det_bboxes) == len(det_labels) == cfg.max_per_img
"""
assert len(cls_scores) == len(bbox_preds)
num_levels = len(cls_scores)
device = cls_scores[0].device
featmap_sizes = [cls_scores[i].shape[-2:] for i in range(num_levels)]
mlvl_anchors = self.anchor_generator.grid_anchors(
featmap_sizes, device=device)
result_list = []
for img_id in range(len(img_metas)):
cls_score_list = [
cls_scores[i][img_id].detach() for i in range(num_levels)
]
bbox_pred_list = [
bbox_preds[i][img_id].detach() for i in range(num_levels)
]
img_shape = img_metas[img_id]['img_shape']
scale_factor = img_metas[img_id]['scale_factor']
if with_nms:
# some heads don't support with_nms argument
proposals = self._get_bboxes_single(cls_score_list,
bbox_pred_list,
mlvl_anchors, img_shape,
scale_factor, cfg, rescale)
else:
proposals = self._get_bboxes_single(cls_score_list,
bbox_pred_list,
mlvl_anchors, img_shape,
scale_factor, cfg, rescale,
with_nms)
result_list.append(proposals)
return result_list
def _get_bboxes_single(self,
cls_score_list,
bbox_pred_list,
mlvl_anchors,
img_shape,
scale_factor,
cfg,
rescale=False,
with_nms=True):
"""Transform outputs for a single batch item into bbox predictions.
Args:
cls_score_list (list[Tensor]): Box scores for a single scale level
Has shape (num_anchors * num_classes, H, W).
bbox_pred_list (list[Tensor]): Box energies / deltas for a single
scale level with shape (num_anchors * 4, H, W).
mlvl_anchors (list[Tensor]): Box reference for a single scale level
with shape (num_total_anchors, 4).
img_shape (tuple[int]): Shape of the input image,
(height, width, 3).
            scale_factor (ndarray): Scale factor of the image, arranged as
(w_scale, h_scale, w_scale, h_scale).
cfg (mmcv.Config): Test / postprocessing configuration,
if None, test_cfg would be used.
rescale (bool): If True, return boxes in original image space.
Default: False.
with_nms (bool): If True, do nms before return boxes.
Default: True.
Returns:
Tensor: Labeled boxes in shape (n, 5), where the first 4 columns
are bounding box positions (tl_x, tl_y, br_x, br_y) and the
5-th column is a score between 0 and 1.
"""
cfg = self.test_cfg if cfg is None else cfg
assert len(cls_score_list) == len(bbox_pred_list) == len(mlvl_anchors)
mlvl_bboxes = []
mlvl_scores = []
for cls_score, bbox_pred, anchors in zip(cls_score_list,
bbox_pred_list, mlvl_anchors):
assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
cls_score = cls_score.permute(1, 2,
0).reshape(-1, self.cls_out_channels)
if self.use_sigmoid_cls:
scores = cls_score.sigmoid()
else:
scores = cls_score.softmax(-1)
bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4)
nms_pre = cfg.get('nms_pre', -1)
if nms_pre > 0 and scores.shape[0] > nms_pre:
# Get maximum scores for foreground classes.
if self.use_sigmoid_cls:
max_scores, _ = scores.max(dim=1)
else:
# remind that we set FG labels to [0, num_class-1]
# since mmdet v2.0
# BG cat_id: num_class
max_scores, _ = scores[:, :-1].max(dim=1)
_, topk_inds = max_scores.topk(nms_pre)
anchors = anchors[topk_inds, :]
bbox_pred = bbox_pred[topk_inds, :]
scores = scores[topk_inds, :]
bboxes = self.bbox_coder.decode(
anchors, bbox_pred, max_shape=img_shape)
mlvl_bboxes.append(bboxes)
mlvl_scores.append(scores)
mlvl_bboxes = torch.cat(mlvl_bboxes)
if rescale:
mlvl_bboxes /= mlvl_bboxes.new_tensor(scale_factor)
mlvl_scores = torch.cat(mlvl_scores)
if self.use_sigmoid_cls:
# Add a dummy background class to the backend when using sigmoid
# remind that we set FG labels to [0, num_class-1] since mmdet v2.0
# BG cat_id: num_class
padding = mlvl_scores.new_zeros(mlvl_scores.shape[0], 1)
mlvl_scores = torch.cat([mlvl_scores, padding], dim=1)
if with_nms:
det_bboxes, det_labels = multiclass_nms(mlvl_bboxes, mlvl_scores,
cfg.score_thr, cfg.nms,
cfg.max_per_img)
return det_bboxes, det_labels
else:
return mlvl_bboxes, mlvl_scores
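    # --- Illustrative sketch (not part of mmdet): the ``nms_pre`` pre-filter
    # above keeps only the top-k anchors per level by their best foreground
    # score, bounding the cost of decoding and NMS. Toy version for the
    # sigmoid case (no background column):
    #
    #     import torch
    #
    #     scores = torch.rand(100, 20)       # 100 anchors, 20 fg classes
    #     nms_pre = 10
    #     max_scores, _ = scores.max(dim=1)  # best class score per anchor
    #     _, topk_inds = max_scores.topk(nms_pre)
    #     kept = scores[topk_inds]           # anchors/bbox_preds are indexed alike
    #     assert kept.shape == (nms_pre, 20)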
def aug_test(self, feats, img_metas, rescale=False):
"""Test function with test time augmentation.
Args:
feats (list[Tensor]): the outer list indicates test-time
augmentations and inner Tensor should have a shape NxCxHxW,
which contains features for all images in the batch.
img_metas (list[list[dict]]): the outer list indicates test-time
augs (multiscale, flip, etc.) and the inner list indicates
images in a batch. each dict has image information.
rescale (bool, optional): Whether to rescale the results.
Defaults to False.
Returns:
list[ndarray]: bbox results of each class
"""
return self.aug_test_bboxes(feats, img_metas, rescale=rescale)
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule, Scale, bias_init_with_prob, normal_init
from mmcv.runner import force_fp32
from mmdet.core import (anchor_inside_flags, build_assigner, build_sampler,
images_to_levels, multi_apply, multiclass_nms,
reduce_mean, unmap)
from ..builder import HEADS, build_loss
from .anchor_head import AnchorHead
EPS = 1e-12
@HEADS.register_module()
class ATSSHead(AnchorHead):
"""Bridging the Gap Between Anchor-based and Anchor-free Detection via
Adaptive Training Sample Selection.
    ATSS head structure is similar to FCOS; however, ATSS uses anchor boxes
    and assigns labels via Adaptive Training Sample Selection instead of max IoU.
https://arxiv.org/abs/1912.02424
"""
def __init__(self,
num_classes,
in_channels,
stacked_convs=4,
conv_cfg=None,
norm_cfg=dict(type='GN', num_groups=32, requires_grad=True),
loss_centerness=dict(
type='CrossEntropyLoss',
use_sigmoid=True,
loss_weight=1.0),
**kwargs):
self.stacked_convs = stacked_convs
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
super(ATSSHead, self).__init__(num_classes, in_channels, **kwargs)
self.sampling = False
if self.train_cfg:
self.assigner = build_assigner(self.train_cfg.assigner)
# SSD sampling=False so use PseudoSampler
sampler_cfg = dict(type='PseudoSampler')
self.sampler = build_sampler(sampler_cfg, context=self)
self.loss_centerness = build_loss(loss_centerness)
def _init_layers(self):
"""Initialize layers of the head."""
self.relu = nn.ReLU(inplace=True)
self.cls_convs = nn.ModuleList()
self.reg_convs = nn.ModuleList()
for i in range(self.stacked_convs):
chn = self.in_channels if i == 0 else self.feat_channels
self.cls_convs.append(
ConvModule(
chn,
self.feat_channels,
3,
stride=1,
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg))
self.reg_convs.append(
ConvModule(
chn,
self.feat_channels,
3,
stride=1,
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg))
self.atss_cls = nn.Conv2d(
self.feat_channels,
self.num_anchors * self.cls_out_channels,
3,
padding=1)
self.atss_reg = nn.Conv2d(
self.feat_channels, self.num_anchors * 4, 3, padding=1)
self.atss_centerness = nn.Conv2d(
self.feat_channels, self.num_anchors * 1, 3, padding=1)
self.scales = nn.ModuleList(
[Scale(1.0) for _ in self.anchor_generator.strides])
def init_weights(self):
"""Initialize weights of the head."""
for m in self.cls_convs:
normal_init(m.conv, std=0.01)
for m in self.reg_convs:
normal_init(m.conv, std=0.01)
bias_cls = bias_init_with_prob(0.01)
normal_init(self.atss_cls, std=0.01, bias=bias_cls)
normal_init(self.atss_reg, std=0.01)
normal_init(self.atss_centerness, std=0.01)
def forward(self, feats):
"""Forward features from the upstream network.
Args:
feats (tuple[Tensor]): Features from the upstream network, each is
a 4D-tensor.
Returns:
tuple: Usually a tuple of classification scores and bbox prediction
cls_scores (list[Tensor]): Classification scores for all scale
levels, each is a 4D-tensor, the channels number is
num_anchors * num_classes.
bbox_preds (list[Tensor]): Box energies / deltas for all scale
levels, each is a 4D-tensor, the channels number is
num_anchors * 4.
"""
return multi_apply(self.forward_single, feats, self.scales)
def forward_single(self, x, scale):
"""Forward feature of a single scale level.
Args:
x (Tensor): Features of a single scale level.
scale (:obj: `mmcv.cnn.Scale`): Learnable scale module to resize
the bbox prediction.
Returns:
tuple:
cls_score (Tensor): Cls scores for a single scale level
the channels number is num_anchors * num_classes.
bbox_pred (Tensor): Box energies / deltas for a single scale
level, the channels number is num_anchors * 4.
centerness (Tensor): Centerness for a single scale level, the
channel number is (N, num_anchors * 1, H, W).
"""
cls_feat = x
reg_feat = x
for cls_conv in self.cls_convs:
cls_feat = cls_conv(cls_feat)
for reg_conv in self.reg_convs:
reg_feat = reg_conv(reg_feat)
cls_score = self.atss_cls(cls_feat)
# we just follow atss, not apply exp in bbox_pred
bbox_pred = scale(self.atss_reg(reg_feat)).float()
centerness = self.atss_centerness(reg_feat)
return cls_score, bbox_pred, centerness
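    # --- Illustrative sketch (not part of mmdet): ``mmcv.cnn.Scale`` used above
    # is just one learnable scalar per FPN level multiplied onto the regression
    # map, letting each level learn its own magnitude. A minimal equivalent:
    #
    #     import torch
    #     import torch.nn as nn
    #
    #     class ScaleSketch(nn.Module):
    #         def __init__(self, scale=1.0):
    #             super().__init__()
    #             self.scale = nn.Parameter(torch.tensor(scale))
    #
    #         def forward(self, x):
    #             return x * self.scale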
def loss_single(self, anchors, cls_score, bbox_pred, centerness, labels,
label_weights, bbox_targets, num_total_samples):
"""Compute loss of a single scale level.
Args:
cls_score (Tensor): Box scores for each scale level
Has shape (N, num_anchors * num_classes, H, W).
bbox_pred (Tensor): Box energies / deltas for each scale
level with shape (N, num_anchors * 4, H, W).
anchors (Tensor): Box reference for each scale level with shape
(N, num_total_anchors, 4).
labels (Tensor): Labels of each anchors with shape
(N, num_total_anchors).
label_weights (Tensor): Label weights of each anchor with shape
(N, num_total_anchors)
            bbox_targets (Tensor): BBox regression targets of each anchor with
shape (N, num_total_anchors, 4).
            num_total_samples (int): Number of positive samples, reduced over
                all GPUs.
Returns:
dict[str, Tensor]: A dictionary of loss components.
"""
anchors = anchors.reshape(-1, 4)
cls_score = cls_score.permute(0, 2, 3, 1).reshape(
-1, self.cls_out_channels).contiguous()
bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4)
centerness = centerness.permute(0, 2, 3, 1).reshape(-1)
bbox_targets = bbox_targets.reshape(-1, 4)
labels = labels.reshape(-1)
label_weights = label_weights.reshape(-1)
# classification loss
loss_cls = self.loss_cls(
cls_score, labels, label_weights, avg_factor=num_total_samples)
# FG cat_id: [0, num_classes -1], BG cat_id: num_classes
bg_class_ind = self.num_classes
pos_inds = ((labels >= 0)
& (labels < bg_class_ind)).nonzero().squeeze(1)
if len(pos_inds) > 0:
pos_bbox_targets = bbox_targets[pos_inds]
pos_bbox_pred = bbox_pred[pos_inds]
pos_anchors = anchors[pos_inds]
pos_centerness = centerness[pos_inds]
centerness_targets = self.centerness_target(
pos_anchors, pos_bbox_targets)
pos_decode_bbox_pred = self.bbox_coder.decode(
pos_anchors, pos_bbox_pred)
pos_decode_bbox_targets = self.bbox_coder.decode(
pos_anchors, pos_bbox_targets)
# regression loss
loss_bbox = self.loss_bbox(
pos_decode_bbox_pred,
pos_decode_bbox_targets,
weight=centerness_targets,
avg_factor=1.0)
# centerness loss
loss_centerness = self.loss_centerness(
pos_centerness,
centerness_targets,
avg_factor=num_total_samples)
else:
loss_bbox = bbox_pred.sum() * 0
loss_centerness = centerness.sum() * 0
centerness_targets = bbox_targets.new_tensor(0.)
return loss_cls, loss_bbox, loss_centerness, centerness_targets.sum()
@force_fp32(apply_to=('cls_scores', 'bbox_preds', 'centernesses'))
def loss(self,
cls_scores,
bbox_preds,
centernesses,
gt_bboxes,
gt_labels,
img_metas,
gt_bboxes_ignore=None):
"""Compute losses of the head.
Args:
cls_scores (list[Tensor]): Box scores for each scale level
Has shape (N, num_anchors * num_classes, H, W)
bbox_preds (list[Tensor]): Box energies / deltas for each scale
level with shape (N, num_anchors * 4, H, W)
centernesses (list[Tensor]): Centerness for each scale
level with shape (N, num_anchors * 1, H, W)
gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
gt_labels (list[Tensor]): class indices corresponding to each box
img_metas (list[dict]): Meta information of each image, e.g.,
image size, scaling factor, etc.
gt_bboxes_ignore (list[Tensor] | None): specify which bounding
boxes can be ignored when computing the loss.
Returns:
dict[str, Tensor]: A dictionary of loss components.
"""
featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
assert len(featmap_sizes) == self.anchor_generator.num_levels
device = cls_scores[0].device
anchor_list, valid_flag_list = self.get_anchors(
featmap_sizes, img_metas, device=device)
label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1
cls_reg_targets = self.get_targets(
anchor_list,
valid_flag_list,
gt_bboxes,
img_metas,
gt_bboxes_ignore_list=gt_bboxes_ignore,
gt_labels_list=gt_labels,
label_channels=label_channels)
if cls_reg_targets is None:
return None
(anchor_list, labels_list, label_weights_list, bbox_targets_list,
bbox_weights_list, num_total_pos, num_total_neg) = cls_reg_targets
num_total_samples = reduce_mean(
torch.tensor(num_total_pos, dtype=torch.float,
device=device)).item()
num_total_samples = max(num_total_samples, 1.0)
losses_cls, losses_bbox, loss_centerness,\
bbox_avg_factor = multi_apply(
self.loss_single,
anchor_list,
cls_scores,
bbox_preds,
centernesses,
labels_list,
label_weights_list,
bbox_targets_list,
num_total_samples=num_total_samples)
bbox_avg_factor = sum(bbox_avg_factor)
bbox_avg_factor = reduce_mean(bbox_avg_factor).item()
if bbox_avg_factor < EPS:
bbox_avg_factor = 1
losses_bbox = list(map(lambda x: x / bbox_avg_factor, losses_bbox))
return dict(
loss_cls=losses_cls,
loss_bbox=losses_bbox,
loss_centerness=loss_centerness)
def centerness_target(self, anchors, bbox_targets):
# only calculate pos centerness targets, otherwise there may be nan
gts = self.bbox_coder.decode(anchors, bbox_targets)
anchors_cx = (anchors[:, 2] + anchors[:, 0]) / 2
anchors_cy = (anchors[:, 3] + anchors[:, 1]) / 2
l_ = anchors_cx - gts[:, 0]
t_ = anchors_cy - gts[:, 1]
r_ = gts[:, 2] - anchors_cx
b_ = gts[:, 3] - anchors_cy
left_right = torch.stack([l_, r_], dim=1)
top_bottom = torch.stack([t_, b_], dim=1)
centerness = torch.sqrt(
(left_right.min(dim=-1)[0] / left_right.max(dim=-1)[0]) *
(top_bottom.min(dim=-1)[0] / top_bottom.max(dim=-1)[0]))
assert not torch.isnan(centerness).any()
return centerness
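    # --- Illustrative sketch (not part of mmdet): ``centerness_target`` scores
    # how centrally an anchor sits inside its assigned gt box,
    #     centerness = sqrt((min(l, r) / max(l, r)) * (min(t, b) / max(t, b))),
    # which is 1.0 at the box centre and approaches 0 at a border. Toy numbers:
    #
    #     import math
    #
    #     l_, r_, t_, b_ = 2.0, 6.0, 4.0, 4.0  # centre-to-side distances
    #     centerness = math.sqrt((min(l_, r_) / max(l_, r_)) *
    #                            (min(t_, b_) / max(t_, b_)))
    #     assert abs(centerness - math.sqrt(2.0 / 6.0)) < 1e-6  # ~0.577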
@force_fp32(apply_to=('cls_scores', 'bbox_preds', 'centernesses'))
def get_bboxes(self,
cls_scores,
bbox_preds,
centernesses,
img_metas,
cfg=None,
rescale=False,
with_nms=True):
"""Transform network output for a batch into bbox predictions.
Args:
cls_scores (list[Tensor]): Box scores for each scale level
with shape (N, num_anchors * num_classes, H, W).
bbox_preds (list[Tensor]): Box energies / deltas for each scale
level with shape (N, num_anchors * 4, H, W).
centernesses (list[Tensor]): Centerness for each scale level with
shape (N, num_anchors * 1, H, W).
img_metas (list[dict]): Meta information of each image, e.g.,
image size, scaling factor, etc.
cfg (mmcv.Config | None): Test / postprocessing configuration,
if None, test_cfg would be used. Default: None.
rescale (bool): If True, return boxes in original image space.
Default: False.
with_nms (bool): If True, do nms before return boxes.
Default: True.
Returns:
list[tuple[Tensor, Tensor]]: Each item in result_list is 2-tuple.
The first item is an (n, 5) tensor, where the first 4 columns
are bounding box positions (tl_x, tl_y, br_x, br_y) and the
5-th column is a score between 0 and 1. The second item is a
(n,) tensor where each item is the predicted class label of the
corresponding box.
"""
cfg = self.test_cfg if cfg is None else cfg
assert len(cls_scores) == len(bbox_preds)
num_levels = len(cls_scores)
device = cls_scores[0].device
featmap_sizes = [cls_scores[i].shape[-2:] for i in range(num_levels)]
mlvl_anchors = self.anchor_generator.grid_anchors(
featmap_sizes, device=device)
result_list = []
for img_id in range(len(img_metas)):
cls_score_list = [
cls_scores[i][img_id].detach() for i in range(num_levels)
]
bbox_pred_list = [
bbox_preds[i][img_id].detach() for i in range(num_levels)
]
centerness_pred_list = [
centernesses[i][img_id].detach() for i in range(num_levels)
]
img_shape = img_metas[img_id]['img_shape']
scale_factor = img_metas[img_id]['scale_factor']
proposals = self._get_bboxes_single(cls_score_list, bbox_pred_list,
centerness_pred_list,
mlvl_anchors, img_shape,
scale_factor, cfg, rescale,
with_nms)
result_list.append(proposals)
return result_list
def _get_bboxes_single(self,
cls_scores,
bbox_preds,
centernesses,
mlvl_anchors,
img_shape,
scale_factor,
cfg,
rescale=False,
with_nms=True):
"""Transform outputs for a single batch item into labeled boxes.
Args:
cls_scores (list[Tensor]): Box scores for a single scale level
with shape (num_anchors * num_classes, H, W).
bbox_preds (list[Tensor]): Box energies / deltas for a single
scale level with shape (num_anchors * 4, H, W).
centernesses (list[Tensor]): Centerness for a single scale level
with shape (num_anchors * 1, H, W).
mlvl_anchors (list[Tensor]): Box reference for a single scale level
with shape (num_total_anchors, 4).
img_shape (tuple[int]): Shape of the input image,
(height, width, 3).
            scale_factor (ndarray): Scale factor of the image, arranged as
(w_scale, h_scale, w_scale, h_scale).
cfg (mmcv.Config | None): Test / postprocessing configuration,
if None, test_cfg would be used.
rescale (bool): If True, return boxes in original image space.
Default: False.
with_nms (bool): If True, do nms before return boxes.
Default: True.
Returns:
tuple(Tensor):
det_bboxes (Tensor): BBox predictions in shape (n, 5), where
the first 4 columns are bounding box positions
(tl_x, tl_y, br_x, br_y) and the 5-th column is a score
between 0 and 1.
det_labels (Tensor): A (n,) tensor where each item is the
predicted class label of the corresponding box.
"""
assert len(cls_scores) == len(bbox_preds) == len(mlvl_anchors)
mlvl_bboxes = []
mlvl_scores = []
mlvl_centerness = []
for cls_score, bbox_pred, centerness, anchors in zip(
cls_scores, bbox_preds, centernesses, mlvl_anchors):
assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
scores = cls_score.permute(1, 2, 0).reshape(
-1, self.cls_out_channels).sigmoid()
bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4)
centerness = centerness.permute(1, 2, 0).reshape(-1).sigmoid()
nms_pre = cfg.get('nms_pre', -1)
if nms_pre > 0 and scores.shape[0] > nms_pre:
max_scores, _ = (scores * centerness[:, None]).max(dim=1)
_, topk_inds = max_scores.topk(nms_pre)
anchors = anchors[topk_inds, :]
bbox_pred = bbox_pred[topk_inds, :]
scores = scores[topk_inds, :]
centerness = centerness[topk_inds]
bboxes = self.bbox_coder.decode(
anchors, bbox_pred, max_shape=img_shape)
mlvl_bboxes.append(bboxes)
mlvl_scores.append(scores)
mlvl_centerness.append(centerness)
mlvl_bboxes = torch.cat(mlvl_bboxes)
if rescale:
mlvl_bboxes /= mlvl_bboxes.new_tensor(scale_factor)
mlvl_scores = torch.cat(mlvl_scores)
# Add a dummy background class to the backend when using sigmoid
# remind that we set FG labels to [0, num_class-1] since mmdet v2.0
# BG cat_id: num_class
padding = mlvl_scores.new_zeros(mlvl_scores.shape[0], 1)
mlvl_scores = torch.cat([mlvl_scores, padding], dim=1)
mlvl_centerness = torch.cat(mlvl_centerness)
if with_nms:
det_bboxes, det_labels = multiclass_nms(
mlvl_bboxes,
mlvl_scores,
cfg.score_thr,
cfg.nms,
cfg.max_per_img,
score_factors=mlvl_centerness)
return det_bboxes, det_labels
else:
return mlvl_bboxes, mlvl_scores, mlvl_centerness
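    # --- Illustrative sketch (not part of mmdet): passing
    # ``score_factors=mlvl_centerness`` makes ``multiclass_nms`` rank boxes by
    # cls_score * centerness, down-weighting boxes whose centres sit far from
    # their object. The weighting itself reduces to:
    #
    #     import torch
    #
    #     scores = torch.tensor([[0.9, 0.1], [0.8, 0.2]])
    #     centerness = torch.tensor([0.5, 1.0])
    #     weighted = scores * centerness[:, None]
    #     assert torch.allclose(weighted[0], torch.tensor([0.45, 0.05]))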
def get_targets(self,
anchor_list,
valid_flag_list,
gt_bboxes_list,
img_metas,
gt_bboxes_ignore_list=None,
gt_labels_list=None,
label_channels=1,
unmap_outputs=True):
"""Get targets for ATSS head.
This method is almost the same as `AnchorHead.get_targets()`. Besides
returning the targets as the parent method does, it also returns the
anchors as the first element of the returned tuple.
"""
num_imgs = len(img_metas)
assert len(anchor_list) == len(valid_flag_list) == num_imgs
# anchor number of multi levels
num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]]
num_level_anchors_list = [num_level_anchors] * num_imgs
# concat all level anchors and flags to a single tensor
for i in range(num_imgs):
assert len(anchor_list[i]) == len(valid_flag_list[i])
anchor_list[i] = torch.cat(anchor_list[i])
valid_flag_list[i] = torch.cat(valid_flag_list[i])
# compute targets for each image
if gt_bboxes_ignore_list is None:
gt_bboxes_ignore_list = [None for _ in range(num_imgs)]
if gt_labels_list is None:
gt_labels_list = [None for _ in range(num_imgs)]
(all_anchors, all_labels, all_label_weights, all_bbox_targets,
all_bbox_weights, pos_inds_list, neg_inds_list) = multi_apply(
self._get_target_single,
anchor_list,
valid_flag_list,
num_level_anchors_list,
gt_bboxes_list,
gt_bboxes_ignore_list,
gt_labels_list,
img_metas,
label_channels=label_channels,
unmap_outputs=unmap_outputs)
# no valid anchors
if any([labels is None for labels in all_labels]):
return None
# sampled anchors of all images
num_total_pos = sum([max(inds.numel(), 1) for inds in pos_inds_list])
num_total_neg = sum([max(inds.numel(), 1) for inds in neg_inds_list])
# split targets to a list w.r.t. multiple levels
anchors_list = images_to_levels(all_anchors, num_level_anchors)
labels_list = images_to_levels(all_labels, num_level_anchors)
label_weights_list = images_to_levels(all_label_weights,
num_level_anchors)
bbox_targets_list = images_to_levels(all_bbox_targets,
num_level_anchors)
bbox_weights_list = images_to_levels(all_bbox_weights,
num_level_anchors)
return (anchors_list, labels_list, label_weights_list,
bbox_targets_list, bbox_weights_list, num_total_pos,
num_total_neg)
def _get_target_single(self,
flat_anchors,
valid_flags,
num_level_anchors,
gt_bboxes,
gt_bboxes_ignore,
gt_labels,
img_meta,
label_channels=1,
unmap_outputs=True):
"""Compute regression, classification targets for anchors in a single
image.
Args:
flat_anchors (Tensor): Multi-level anchors of the image, which are
                concatenated into a single tensor of shape (num_anchors, 4).
valid_flags (Tensor): Multi level valid flags of the image,
which are concatenated into a single tensor of
shape (num_anchors,).
            num_level_anchors (Tensor): Number of anchors of each scale level.
gt_bboxes (Tensor): Ground truth bboxes of the image,
shape (num_gts, 4).
gt_bboxes_ignore (Tensor): Ground truth bboxes to be
ignored, shape (num_ignored_gts, 4).
gt_labels (Tensor): Ground truth labels of each box,
shape (num_gts,).
img_meta (dict): Meta info of the image.
label_channels (int): Channel of label.
unmap_outputs (bool): Whether to map outputs back to the original
set of anchors.
Returns:
tuple: N is the number of total anchors in the image.
labels (Tensor): Labels of all anchors in the image with shape
(N,).
                label_weights (Tensor): Label weights of all anchors in the
image with shape (N,).
bbox_targets (Tensor): BBox targets of all anchors in the
image with shape (N, 4).
bbox_weights (Tensor): BBox weights of all anchors in the
image with shape (N, 4)
                pos_inds (Tensor): Indices of positive anchors with shape
(num_pos,).
neg_inds (Tensor): Indices of negative anchor with shape
(num_neg,).
"""
inside_flags = anchor_inside_flags(flat_anchors, valid_flags,
img_meta['img_shape'][:2],
self.train_cfg.allowed_border)
if not inside_flags.any():
return (None, ) * 7
# assign gt and sample anchors
anchors = flat_anchors[inside_flags, :]
num_level_anchors_inside = self.get_num_level_anchors_inside(
num_level_anchors, inside_flags)
assign_result = self.assigner.assign(anchors, num_level_anchors_inside,
gt_bboxes, gt_bboxes_ignore,
gt_labels)
sampling_result = self.sampler.sample(assign_result, anchors,
gt_bboxes)
num_valid_anchors = anchors.shape[0]
bbox_targets = torch.zeros_like(anchors)
bbox_weights = torch.zeros_like(anchors)
labels = anchors.new_full((num_valid_anchors, ),
self.num_classes,
dtype=torch.long)
label_weights = anchors.new_zeros(num_valid_anchors, dtype=torch.float)
pos_inds = sampling_result.pos_inds
neg_inds = sampling_result.neg_inds
if len(pos_inds) > 0:
if hasattr(self, 'bbox_coder'):
pos_bbox_targets = self.bbox_coder.encode(
sampling_result.pos_bboxes, sampling_result.pos_gt_bboxes)
else:
# used in VFNetHead
pos_bbox_targets = sampling_result.pos_gt_bboxes
bbox_targets[pos_inds, :] = pos_bbox_targets
bbox_weights[pos_inds, :] = 1.0
if gt_labels is None:
# Only rpn gives gt_labels as None
# Foreground is the first class since v2.5.0
labels[pos_inds] = 0
else:
labels[pos_inds] = gt_labels[
sampling_result.pos_assigned_gt_inds]
if self.train_cfg.pos_weight <= 0:
label_weights[pos_inds] = 1.0
else:
label_weights[pos_inds] = self.train_cfg.pos_weight
if len(neg_inds) > 0:
label_weights[neg_inds] = 1.0
# map up to original set of anchors
if unmap_outputs:
num_total_anchors = flat_anchors.size(0)
anchors = unmap(anchors, num_total_anchors, inside_flags)
labels = unmap(
labels, num_total_anchors, inside_flags, fill=self.num_classes)
label_weights = unmap(label_weights, num_total_anchors,
inside_flags)
bbox_targets = unmap(bbox_targets, num_total_anchors, inside_flags)
bbox_weights = unmap(bbox_weights, num_total_anchors, inside_flags)
return (anchors, labels, label_weights, bbox_targets, bbox_weights,
pos_inds, neg_inds)
def get_num_level_anchors_inside(self, num_level_anchors, inside_flags):
split_inside_flags = torch.split(inside_flags, num_level_anchors)
num_level_anchors_inside = [
int(flags.sum()) for flags in split_inside_flags
]
return num_level_anchors_inside
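    # --- Illustrative sketch (not part of mmdet): the helper above just counts,
    # per level, how many anchors survived the inside-image check:
    #
    #     import torch
    #
    #     inside_flags = torch.tensor([1, 1, 0, 1, 0, 0], dtype=torch.bool)
    #     num_level_anchors = [4, 2]  # 4 anchors on level 0, 2 on level 1
    #     split = torch.split(inside_flags, num_level_anchors)
    #     assert [int(f.sum()) for f in split] == [3, 0]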
from abc import ABCMeta, abstractmethod
import torch.nn as nn
class BaseDenseHead(nn.Module, metaclass=ABCMeta):
"""Base class for DenseHeads."""
def __init__(self):
super(BaseDenseHead, self).__init__()
@abstractmethod
def loss(self, **kwargs):
"""Compute losses of the head."""
pass
@abstractmethod
def get_bboxes(self, **kwargs):
"""Transform network output for a batch into bbox predictions."""
pass
def forward_train(self,
x,
img_metas,
gt_bboxes,
gt_labels=None,
gt_bboxes_ignore=None,
proposal_cfg=None,
**kwargs):
"""
Args:
x (list[Tensor]): Features from FPN.
img_metas (list[dict]): Meta information of each image, e.g.,
image size, scaling factor, etc.
gt_bboxes (Tensor): Ground truth bboxes of the image,
shape (num_gts, 4).
gt_labels (Tensor): Ground truth labels of each box,
shape (num_gts,).
gt_bboxes_ignore (Tensor): Ground truth bboxes to be
ignored, shape (num_ignored_gts, 4).
proposal_cfg (mmcv.Config): Test / postprocessing configuration,
if None, test_cfg would be used
Returns:
tuple:
losses: (dict[str, Tensor]): A dictionary of loss components.
proposal_list (list[Tensor]): Proposals of each image.
"""
outs = self(x)
if gt_labels is None:
loss_inputs = outs + (gt_bboxes, img_metas)
else:
loss_inputs = outs + (gt_bboxes, gt_labels, img_metas)
losses = self.loss(*loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)
if proposal_cfg is None:
return losses
else:
proposal_list = self.get_bboxes(*outs, img_metas, cfg=proposal_cfg)
return losses, proposal_list
from __future__ import division
import torch
import torch.nn as nn
from mmcv.cnn import normal_init
from mmcv.ops import DeformConv2d
from mmdet.core import (RegionAssigner, build_assigner, build_sampler,
images_to_levels, multi_apply)
from ..builder import HEADS, build_head
from .base_dense_head import BaseDenseHead
from .rpn_head import RPNHead
class AdaptiveConv(nn.Module):
"""AdaptiveConv used to adapt the sampling location with the anchors.
Args:
in_channels (int): Number of channels in the input image
out_channels (int): Number of channels produced by the convolution
kernel_size (int or tuple): Size of the conv kernel. Default: 3
stride (int or tuple, optional): Stride of the convolution. Default: 1
padding (int or tuple, optional): Zero-padding added to both sides of
the input. Default: 1
dilation (int or tuple, optional): Spacing between kernel elements.
Default: 3
groups (int, optional): Number of blocked connections from input
channels to output channels. Default: 1
bias (bool, optional): If set True, adds a learnable bias to the
output. Default: False.
type (str, optional): Type of adaptive conv, can be either 'offset'
(arbitrary anchors) or 'dilation' (uniform anchor).
Default: 'dilation'.
"""
def __init__(self,
in_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1,
dilation=3,
groups=1,
bias=False,
type='dilation'):
super(AdaptiveConv, self).__init__()
assert type in ['offset', 'dilation']
self.adapt_type = type
        assert kernel_size == 3, 'Adaptive conv only supports kernel_size 3'
if self.adapt_type == 'offset':
            assert stride == 1 and padding == 1 and groups == 1, \
                'Adaptive conv offset mode only supports ' \
                'padding: 1, stride: 1, groups: 1'
self.conv = DeformConv2d(
in_channels,
out_channels,
kernel_size,
padding=padding,
stride=stride,
groups=groups,
bias=bias)
else:
self.conv = nn.Conv2d(
in_channels,
out_channels,
kernel_size,
padding=dilation,
dilation=dilation)
def init_weights(self):
"""Init weights."""
normal_init(self.conv, std=0.01)
def forward(self, x, offset):
"""Forward function."""
if self.adapt_type == 'offset':
N, _, H, W = x.shape
assert offset is not None
assert H * W == offset.shape[1]
# reshape [N, NA, 18] to (N, 18, H, W)
offset = offset.permute(0, 2, 1).reshape(N, -1, H, W)
offset = offset.contiguous()
x = self.conv(x, offset)
else:
assert offset is None
x = self.conv(x)
return x
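    # --- Illustrative sketch (not part of mmdet): in 'dilation' mode,
    # AdaptiveConv is an ordinary dilated 3x3 conv; padding == dilation keeps
    # the spatial size unchanged. Quick shape check with the default dilation=3:
    #
    #     import torch
    #     import torch.nn as nn
    #
    #     conv = nn.Conv2d(8, 8, kernel_size=3, padding=3, dilation=3)
    #     x = torch.randn(1, 8, 16, 16)
    #     assert conv(x).shape == x.shape  # 16x16 preserved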
@HEADS.register_module()
class StageCascadeRPNHead(RPNHead):
"""Stage of CascadeRPNHead.
Args:
in_channels (int): Number of channels in the input feature map.
anchor_generator (dict): anchor generator config.
adapt_cfg (dict): adaptation config.
        bridged_feature (bool, optional): whether to update the rpn feature.
            Default: False.
        with_cls (bool, optional): whether to use the classification branch.
            Default: True.
        sampling (bool, optional): whether to use sampling. Default: True.
"""
def __init__(self,
in_channels,
anchor_generator=dict(
type='AnchorGenerator',
scales=[8],
ratios=[1.0],
strides=[4, 8, 16, 32, 64]),
adapt_cfg=dict(type='dilation', dilation=3),
bridged_feature=False,
with_cls=True,
sampling=True,
**kwargs):
self.with_cls = with_cls
self.anchor_strides = anchor_generator['strides']
self.anchor_scales = anchor_generator['scales']
self.bridged_feature = bridged_feature
self.adapt_cfg = adapt_cfg
super(StageCascadeRPNHead, self).__init__(
in_channels, anchor_generator=anchor_generator, **kwargs)
# override sampling and sampler
self.sampling = sampling
if self.train_cfg:
self.assigner = build_assigner(self.train_cfg.assigner)
# use PseudoSampler when sampling is False
if self.sampling and hasattr(self.train_cfg, 'sampler'):
sampler_cfg = self.train_cfg.sampler
else:
sampler_cfg = dict(type='PseudoSampler')
self.sampler = build_sampler(sampler_cfg, context=self)
def _init_layers(self):
"""Init layers of a CascadeRPN stage."""
self.rpn_conv = AdaptiveConv(self.in_channels, self.feat_channels,
**self.adapt_cfg)
if self.with_cls:
self.rpn_cls = nn.Conv2d(self.feat_channels,
self.num_anchors * self.cls_out_channels,
1)
self.rpn_reg = nn.Conv2d(self.feat_channels, self.num_anchors * 4, 1)
self.relu = nn.ReLU(inplace=True)
def init_weights(self):
"""Init weights of a CascadeRPN stage."""
self.rpn_conv.init_weights()
normal_init(self.rpn_reg, std=0.01)
if self.with_cls:
normal_init(self.rpn_cls, std=0.01)
def forward_single(self, x, offset):
"""Forward function of single scale."""
bridged_x = x
x = self.relu(self.rpn_conv(x, offset))
if self.bridged_feature:
bridged_x = x # update feature
cls_score = self.rpn_cls(x) if self.with_cls else None
bbox_pred = self.rpn_reg(x)
return bridged_x, cls_score, bbox_pred
def forward(self, feats, offset_list=None):
"""Forward function."""
if offset_list is None:
offset_list = [None for _ in range(len(feats))]
return multi_apply(self.forward_single, feats, offset_list)
def _region_targets_single(self,
anchors,
valid_flags,
gt_bboxes,
gt_bboxes_ignore,
gt_labels,
img_meta,
featmap_sizes,
label_channels=1):
"""Get anchor targets based on region for single level."""
assign_result = self.assigner.assign(
anchors,
valid_flags,
gt_bboxes,
img_meta,
featmap_sizes,
self.anchor_scales[0],
self.anchor_strides,
gt_bboxes_ignore=gt_bboxes_ignore,
gt_labels=None,
allowed_border=self.train_cfg.allowed_border)
flat_anchors = torch.cat(anchors)
sampling_result = self.sampler.sample(assign_result, flat_anchors,
gt_bboxes)
num_anchors = flat_anchors.shape[0]
bbox_targets = torch.zeros_like(flat_anchors)
bbox_weights = torch.zeros_like(flat_anchors)
labels = flat_anchors.new_zeros(num_anchors, dtype=torch.long)
label_weights = flat_anchors.new_zeros(num_anchors, dtype=torch.float)
pos_inds = sampling_result.pos_inds
neg_inds = sampling_result.neg_inds
if len(pos_inds) > 0:
if not self.reg_decoded_bbox:
pos_bbox_targets = self.bbox_coder.encode(
sampling_result.pos_bboxes, sampling_result.pos_gt_bboxes)
else:
pos_bbox_targets = sampling_result.pos_gt_bboxes
bbox_targets[pos_inds, :] = pos_bbox_targets
bbox_weights[pos_inds, :] = 1.0
if gt_labels is None:
labels[pos_inds] = 1
else:
labels[pos_inds] = gt_labels[
sampling_result.pos_assigned_gt_inds]
if self.train_cfg.pos_weight <= 0:
label_weights[pos_inds] = 1.0
else:
label_weights[pos_inds] = self.train_cfg.pos_weight
if len(neg_inds) > 0:
label_weights[neg_inds] = 1.0
return (labels, label_weights, bbox_targets, bbox_weights, pos_inds,
neg_inds)
def region_targets(self,
anchor_list,
valid_flag_list,
gt_bboxes_list,
img_metas,
featmap_sizes,
gt_bboxes_ignore_list=None,
gt_labels_list=None,
label_channels=1,
unmap_outputs=True):
"""See :func:`StageCascadeRPNHead.get_targets`."""
num_imgs = len(img_metas)
assert len(anchor_list) == len(valid_flag_list) == num_imgs
# anchor number of multi levels
num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]]
# compute targets for each image
if gt_bboxes_ignore_list is None:
gt_bboxes_ignore_list = [None for _ in range(num_imgs)]
if gt_labels_list is None:
gt_labels_list = [None for _ in range(num_imgs)]
(all_labels, all_label_weights, all_bbox_targets, all_bbox_weights,
pos_inds_list, neg_inds_list) = multi_apply(
self._region_targets_single,
anchor_list,
valid_flag_list,
gt_bboxes_list,
gt_bboxes_ignore_list,
gt_labels_list,
img_metas,
featmap_sizes=featmap_sizes,
label_channels=label_channels)
# no valid anchors
if any([labels is None for labels in all_labels]):
return None
# sampled anchors of all images
num_total_pos = sum([max(inds.numel(), 1) for inds in pos_inds_list])
num_total_neg = sum([max(inds.numel(), 1) for inds in neg_inds_list])
# split targets to a list w.r.t. multiple levels
labels_list = images_to_levels(all_labels, num_level_anchors)
label_weights_list = images_to_levels(all_label_weights,
num_level_anchors)
bbox_targets_list = images_to_levels(all_bbox_targets,
num_level_anchors)
bbox_weights_list = images_to_levels(all_bbox_weights,
num_level_anchors)
return (labels_list, label_weights_list, bbox_targets_list,
bbox_weights_list, num_total_pos, num_total_neg)
def get_targets(self,
anchor_list,
valid_flag_list,
gt_bboxes,
img_metas,
featmap_sizes,
gt_bboxes_ignore=None,
label_channels=1):
"""Compute regression and classification targets for anchors.
Args:
anchor_list (list[list]): Multi level anchors of each image.
valid_flag_list (list[list]): Multi level valid flags of each
image.
gt_bboxes (list[Tensor]): Ground truth bboxes of each image.
img_metas (list[dict]): Meta info of each image.
            featmap_sizes (list[Tensor]): Feature map size of each level.
            gt_bboxes_ignore (list[Tensor]): Ignore bboxes of each image.
label_channels (int): Channel of label.
Returns:
cls_reg_targets (tuple)
"""
if isinstance(self.assigner, RegionAssigner):
cls_reg_targets = self.region_targets(
anchor_list,
valid_flag_list,
gt_bboxes,
img_metas,
featmap_sizes,
gt_bboxes_ignore_list=gt_bboxes_ignore,
label_channels=label_channels)
else:
cls_reg_targets = super(StageCascadeRPNHead, self).get_targets(
anchor_list,
valid_flag_list,
gt_bboxes,
img_metas,
gt_bboxes_ignore_list=gt_bboxes_ignore,
label_channels=label_channels)
return cls_reg_targets
def anchor_offset(self, anchor_list, anchor_strides, featmap_sizes):
""" Get offest for deformable conv based on anchor shape
NOTE: currently support deformable kernel_size=3 and dilation=1
Args:
anchor_list (list[list[tensor])): [NI, NLVL, NA, 4] list of
multi-level anchors
anchor_strides (list[int]): anchor stride of each level
Returns:
offset_list (list[tensor]): [NLVL, NA, 2, 18]: offset of DeformConv
kernel.
"""
def _shape_offset(anchors, stride, ks=3, dilation=1):
# currently support kernel_size=3 and dilation=1
assert ks == 3 and dilation == 1
pad = (ks - 1) // 2
idx = torch.arange(-pad, pad + 1, dtype=dtype, device=device)
yy, xx = torch.meshgrid(idx, idx) # return order matters
xx = xx.reshape(-1)
yy = yy.reshape(-1)
w = (anchors[:, 2] - anchors[:, 0]) / stride
h = (anchors[:, 3] - anchors[:, 1]) / stride
w = w / (ks - 1) - dilation
h = h / (ks - 1) - dilation
offset_x = w[:, None] * xx # (NA, ks**2)
offset_y = h[:, None] * yy # (NA, ks**2)
return offset_x, offset_y
def _ctr_offset(anchors, stride, featmap_size):
feat_h, feat_w = featmap_size
assert len(anchors) == feat_h * feat_w
x = (anchors[:, 0] + anchors[:, 2]) * 0.5
y = (anchors[:, 1] + anchors[:, 3]) * 0.5
# compute centers on feature map
x = x / stride
y = y / stride
# compute predefine centers
xx = torch.arange(0, feat_w, device=anchors.device)
yy = torch.arange(0, feat_h, device=anchors.device)
yy, xx = torch.meshgrid(yy, xx)
xx = xx.reshape(-1).type_as(x)
yy = yy.reshape(-1).type_as(y)
offset_x = x - xx # (NA, )
offset_y = y - yy # (NA, )
return offset_x, offset_y
num_imgs = len(anchor_list)
num_lvls = len(anchor_list[0])
dtype = anchor_list[0][0].dtype
device = anchor_list[0][0].device
num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]]
offset_list = []
for i in range(num_imgs):
mlvl_offset = []
for lvl in range(num_lvls):
c_offset_x, c_offset_y = _ctr_offset(anchor_list[i][lvl],
anchor_strides[lvl],
featmap_sizes[lvl])
s_offset_x, s_offset_y = _shape_offset(anchor_list[i][lvl],
anchor_strides[lvl])
# offset = ctr_offset + shape_offset
offset_x = s_offset_x + c_offset_x[:, None]
offset_y = s_offset_y + c_offset_y[:, None]
                # offset order (y0, x0, y1, x1, ..., y8, x8)
offset = torch.stack([offset_y, offset_x], dim=-1)
offset = offset.reshape(offset.size(0), -1) # [NA, 2*ks**2]
mlvl_offset.append(offset)
offset_list.append(torch.cat(mlvl_offset)) # [totalNA, 2*ks**2]
offset_list = images_to_levels(offset_list, num_level_anchors)
return offset_list
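    # --- Illustrative sketch (not part of mmdet): the centre offset computed in
    # ``_ctr_offset`` is (anchor centre in feature-map cells) minus (the cell's
    # default centre index). Toy numbers for one anchor on a stride-8 level,
    # assuming it lives in grid cell x=1:
    #
    #     x1, y1, x2, y2 = 12.0, 4.0, 28.0, 20.0  # anchor in image coords
    #     stride = 8
    #     cx = (x1 + x2) * 0.5 / stride           # 2.5 cells
    #     grid_x = 1                              # predefined centre of the cell
    #     offset_x = cx - grid_x                  # 1.5-cell shift for DeformConv
    #     assert offset_x == 1.5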
def loss_single(self, cls_score, bbox_pred, anchors, labels, label_weights,
bbox_targets, bbox_weights, num_total_samples):
"""Loss function on single scale."""
# classification loss
if self.with_cls:
labels = labels.reshape(-1)
label_weights = label_weights.reshape(-1)
cls_score = cls_score.permute(0, 2, 3,
1).reshape(-1, self.cls_out_channels)
loss_cls = self.loss_cls(
cls_score, labels, label_weights, avg_factor=num_total_samples)
# regression loss
bbox_targets = bbox_targets.reshape(-1, 4)
bbox_weights = bbox_weights.reshape(-1, 4)
bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4)
if self.reg_decoded_bbox:
# When the regression loss (e.g. `IouLoss`, `GIouLoss`)
# is applied directly on the decoded bounding boxes, it
# decodes the already encoded coordinates to absolute format.
anchors = anchors.reshape(-1, 4)
bbox_pred = self.bbox_coder.decode(anchors, bbox_pred)
loss_reg = self.loss_bbox(
bbox_pred,
bbox_targets,
bbox_weights,
avg_factor=num_total_samples)
if self.with_cls:
return loss_cls, loss_reg
return None, loss_reg
def loss(self,
anchor_list,
valid_flag_list,
cls_scores,
bbox_preds,
gt_bboxes,
img_metas,
gt_bboxes_ignore=None):
"""Compute losses of the head.
Args:
anchor_list (list[list]): Multi level anchors of each image.
cls_scores (list[Tensor]): Box scores for each scale level
Has shape (N, num_anchors * num_classes, H, W)
bbox_preds (list[Tensor]): Box energies / deltas for each scale
level with shape (N, num_anchors * 4, H, W)
gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
img_metas (list[dict]): Meta information of each image, e.g.,
image size, scaling factor, etc.
gt_bboxes_ignore (None | list[Tensor]): specify which bounding
boxes can be ignored when computing the loss. Default: None
Returns:
dict[str, Tensor]: A dictionary of loss components.
"""
featmap_sizes = [featmap.size()[-2:] for featmap in bbox_preds]
label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1
cls_reg_targets = self.get_targets(
anchor_list,
valid_flag_list,
gt_bboxes,
img_metas,
featmap_sizes,
gt_bboxes_ignore=gt_bboxes_ignore,
label_channels=label_channels)
if cls_reg_targets is None:
return None
(labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
num_total_pos, num_total_neg) = cls_reg_targets
if self.sampling:
num_total_samples = num_total_pos + num_total_neg
else:
# 200 is hard-coded average factor,
            # which follows guided anchoring.
num_total_samples = sum([label.numel()
for label in labels_list]) / 200.0
# change per image, per level anchor_list to per_level, per_image
mlvl_anchor_list = list(zip(*anchor_list))
# concat mlvl_anchor_list
mlvl_anchor_list = [
torch.cat(anchors, dim=0) for anchors in mlvl_anchor_list
]
losses = multi_apply(
self.loss_single,
cls_scores,
bbox_preds,
mlvl_anchor_list,
labels_list,
label_weights_list,
bbox_targets_list,
bbox_weights_list,
num_total_samples=num_total_samples)
if self.with_cls:
return dict(loss_rpn_cls=losses[0], loss_rpn_reg=losses[1])
return dict(loss_rpn_reg=losses[1])
def get_bboxes(self,
anchor_list,
cls_scores,
bbox_preds,
img_metas,
cfg,
rescale=False):
"""Get proposal predict."""
assert len(cls_scores) == len(bbox_preds)
num_levels = len(cls_scores)
result_list = []
for img_id in range(len(img_metas)):
cls_score_list = [
cls_scores[i][img_id].detach() for i in range(num_levels)
]
bbox_pred_list = [
bbox_preds[i][img_id].detach() for i in range(num_levels)
]
img_shape = img_metas[img_id]['img_shape']
scale_factor = img_metas[img_id]['scale_factor']
proposals = self._get_bboxes_single(cls_score_list, bbox_pred_list,
anchor_list[img_id], img_shape,
scale_factor, cfg, rescale)
result_list.append(proposals)
return result_list
def refine_bboxes(self, anchor_list, bbox_preds, img_metas):
"""Refine bboxes through stages."""
num_levels = len(bbox_preds)
new_anchor_list = []
for img_id in range(len(img_metas)):
mlvl_anchors = []
for i in range(num_levels):
bbox_pred = bbox_preds[i][img_id].detach()
bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4)
img_shape = img_metas[img_id]['img_shape']
bboxes = self.bbox_coder.decode(anchor_list[img_id][i],
bbox_pred, img_shape)
mlvl_anchors.append(bboxes)
new_anchor_list.append(mlvl_anchors)
return new_anchor_list
@HEADS.register_module()
class CascadeRPNHead(BaseDenseHead):
"""The CascadeRPNHead will predict more accurate region proposals, which is
required for two-stage detectors (such as Fast/Faster R-CNN). CascadeRPN
consists of a sequence of RPNStage to progressively improve the accuracy of
the detected proposals.
More details can be found in ``https://arxiv.org/abs/1909.06720``.
Args:
num_stages (int): number of CascadeRPN stages.
stages (list[dict]): list of configs to build the stages.
train_cfg (list[dict]): list of configs at training time each stage.
test_cfg (dict): config at testing time.
"""
def __init__(self, num_stages, stages, train_cfg, test_cfg):
super(CascadeRPNHead, self).__init__()
assert num_stages == len(stages)
self.num_stages = num_stages
self.stages = nn.ModuleList()
for i in range(len(stages)):
train_cfg_i = train_cfg[i] if train_cfg is not None else None
stages[i].update(train_cfg=train_cfg_i)
stages[i].update(test_cfg=test_cfg)
self.stages.append(build_head(stages[i]))
self.train_cfg = train_cfg
self.test_cfg = test_cfg
def init_weights(self):
"""Init weight of CascadeRPN."""
for i in range(self.num_stages):
self.stages[i].init_weights()
def loss(self):
"""loss() is implemented in StageCascadeRPNHead."""
pass
def get_bboxes(self):
"""get_bboxes() is implemented in StageCascadeRPNHead."""
pass
def forward_train(self,
x,
img_metas,
gt_bboxes,
gt_labels=None,
gt_bboxes_ignore=None,
proposal_cfg=None):
"""Forward train function."""
assert gt_labels is None, 'RPN does not require gt_labels'
featmap_sizes = [featmap.size()[-2:] for featmap in x]
device = x[0].device
anchor_list, valid_flag_list = self.stages[0].get_anchors(
featmap_sizes, img_metas, device=device)
losses = dict()
for i in range(self.num_stages):
stage = self.stages[i]
if stage.adapt_cfg['type'] == 'offset':
offset_list = stage.anchor_offset(anchor_list,
stage.anchor_strides,
featmap_sizes)
else:
offset_list = None
x, cls_score, bbox_pred = stage(x, offset_list)
rpn_loss_inputs = (anchor_list, valid_flag_list, cls_score,
bbox_pred, gt_bboxes, img_metas)
stage_loss = stage.loss(*rpn_loss_inputs)
for name, value in stage_loss.items():
losses['s{}.{}'.format(i, name)] = value
# refine boxes
if i < self.num_stages - 1:
anchor_list = stage.refine_bboxes(anchor_list, bbox_pred,
img_metas)
if proposal_cfg is None:
return losses
else:
proposal_list = self.stages[-1].get_bboxes(anchor_list, cls_score,
bbox_pred, img_metas,
self.test_cfg)
return losses, proposal_list
def simple_test_rpn(self, x, img_metas):
"""Simple forward test function."""
featmap_sizes = [featmap.size()[-2:] for featmap in x]
device = x[0].device
anchor_list, _ = self.stages[0].get_anchors(
featmap_sizes, img_metas, device=device)
for i in range(self.num_stages):
stage = self.stages[i]
if stage.adapt_cfg['type'] == 'offset':
offset_list = stage.anchor_offset(anchor_list,
stage.anchor_strides,
featmap_sizes)
else:
offset_list = None
x, cls_score, bbox_pred = stage(x, offset_list)
if i < self.num_stages - 1:
anchor_list = stage.refine_bboxes(anchor_list, bbox_pred,
img_metas)
proposal_list = self.stages[-1].get_bboxes(anchor_list, cls_score,
bbox_pred, img_metas,
self.test_cfg)
return proposal_list
def aug_test_rpn(self, x, img_metas):
"""Augmented forward test function."""
raise NotImplementedError
import torch.nn as nn
from mmcv.cnn import ConvModule, normal_init
from mmcv.ops import DeformConv2d
from mmdet.core import multi_apply
from ..builder import HEADS, build_loss
from .corner_head import CornerHead
@HEADS.register_module()
class CentripetalHead(CornerHead):
"""Head of CentripetalNet: Pursuing High-quality Keypoint Pairs for Object
Detection.
CentripetalHead inherits from :class:`CornerHead`. It removes the
embedding branch and adds guiding shift and centripetal shift branches.
More details can be found in the `paper
<https://arxiv.org/abs/2003.09119>`_ .
Args:
num_classes (int): Number of categories excluding the background
category.
in_channels (int): Number of channels in the input feature map.
num_feat_levels (int): Levels of feature from the previous module. 2
for HourglassNet-104 and 1 for HourglassNet-52. HourglassNet-104
outputs the final feature and intermediate supervision feature and
HourglassNet-52 only outputs the final feature. Default: 2.
corner_emb_channels (int): Channel of embedding vector. Default: 1.
train_cfg (dict | None): Training config. Useless in CornerHead,
but we keep this variable for SingleStageDetector. Default: None.
test_cfg (dict | None): Testing config of CornerHead. Default: None.
loss_heatmap (dict | None): Config of corner heatmap loss. Default:
GaussianFocalLoss.
loss_embedding (dict | None): Config of corner embedding loss. Default:
AssociativeEmbeddingLoss.
loss_offset (dict | None): Config of corner offset loss. Default:
SmoothL1Loss.
loss_guiding_shift (dict): Config of guiding shift loss. Default:
SmoothL1Loss.
loss_centripetal_shift (dict): Config of centripetal shift loss.
Default: SmoothL1Loss.
"""
def __init__(self,
*args,
centripetal_shift_channels=2,
guiding_shift_channels=2,
feat_adaption_conv_kernel=3,
loss_guiding_shift=dict(
type='SmoothL1Loss', beta=1.0, loss_weight=0.05),
loss_centripetal_shift=dict(
type='SmoothL1Loss', beta=1.0, loss_weight=1),
**kwargs):
        assert centripetal_shift_channels == 2, (
            'CentripetalHead only supports centripetal_shift_channels == 2')
        self.centripetal_shift_channels = centripetal_shift_channels
        assert guiding_shift_channels == 2, (
            'CentripetalHead only supports guiding_shift_channels == 2')
self.guiding_shift_channels = guiding_shift_channels
self.feat_adaption_conv_kernel = feat_adaption_conv_kernel
super(CentripetalHead, self).__init__(*args, **kwargs)
self.loss_guiding_shift = build_loss(loss_guiding_shift)
self.loss_centripetal_shift = build_loss(loss_centripetal_shift)
def _init_centripetal_layers(self):
"""Initialize centripetal layers.
Including feature adaption deform convs (feat_adaption), deform offset
prediction convs (dcn_off), guiding shift (guiding_shift) and
        centripetal shift (centripetal_shift). Each branch has two parts:
prefix `tl_` for top-left and `br_` for bottom-right.
"""
self.tl_feat_adaption = nn.ModuleList()
self.br_feat_adaption = nn.ModuleList()
self.tl_dcn_offset = nn.ModuleList()
self.br_dcn_offset = nn.ModuleList()
self.tl_guiding_shift = nn.ModuleList()
self.br_guiding_shift = nn.ModuleList()
self.tl_centripetal_shift = nn.ModuleList()
self.br_centripetal_shift = nn.ModuleList()
for _ in range(self.num_feat_levels):
self.tl_feat_adaption.append(
DeformConv2d(self.in_channels, self.in_channels,
self.feat_adaption_conv_kernel, 1, 1))
self.br_feat_adaption.append(
DeformConv2d(self.in_channels, self.in_channels,
self.feat_adaption_conv_kernel, 1, 1))
self.tl_guiding_shift.append(
self._make_layers(
out_channels=self.guiding_shift_channels,
in_channels=self.in_channels))
self.br_guiding_shift.append(
self._make_layers(
out_channels=self.guiding_shift_channels,
in_channels=self.in_channels))
self.tl_dcn_offset.append(
ConvModule(
self.guiding_shift_channels,
self.feat_adaption_conv_kernel**2 *
self.guiding_shift_channels,
1,
bias=False,
act_cfg=None))
self.br_dcn_offset.append(
ConvModule(
self.guiding_shift_channels,
self.feat_adaption_conv_kernel**2 *
self.guiding_shift_channels,
1,
bias=False,
act_cfg=None))
self.tl_centripetal_shift.append(
self._make_layers(
out_channels=self.centripetal_shift_channels,
in_channels=self.in_channels))
self.br_centripetal_shift.append(
self._make_layers(
out_channels=self.centripetal_shift_channels,
in_channels=self.in_channels))
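    # --- Illustrative sketch (not part of mmdet): the 1x1 ConvModule above maps
    # the 2-channel guiding shift to the DeformConv2d offset field, which needs
    # kernel**2 * 2 channels (one (dy, dx) pair per kernel tap):
    #
    #     feat_adaption_conv_kernel = 3
    #     guiding_shift_channels = 2
    #     assert feat_adaption_conv_kernel ** 2 * guiding_shift_channels == 18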
def _init_layers(self):
"""Initialize layers for CentripetalHead.
Including two parts: CornerHead layers and CentripetalHead layers
"""
super()._init_layers() # using _init_layers in CornerHead
self._init_centripetal_layers()
def init_weights(self):
"""Initialize weights of the head."""
super().init_weights()
for i in range(self.num_feat_levels):
normal_init(self.tl_feat_adaption[i], std=0.01)
normal_init(self.br_feat_adaption[i], std=0.01)
normal_init(self.tl_dcn_offset[i].conv, std=0.1)
normal_init(self.br_dcn_offset[i].conv, std=0.1)
_ = [x.conv.reset_parameters() for x in self.tl_guiding_shift[i]]
_ = [x.conv.reset_parameters() for x in self.br_guiding_shift[i]]
_ = [
x.conv.reset_parameters() for x in self.tl_centripetal_shift[i]
]
_ = [
x.conv.reset_parameters() for x in self.br_centripetal_shift[i]
]
def forward_single(self, x, lvl_ind):
"""Forward feature of a single level.
Args:
x (Tensor): Feature of a single level.
lvl_ind (int): Level index of current feature.
Returns:
tuple[Tensor]: A tuple of CentripetalHead's output for current
feature level. Containing the following Tensors:
- tl_heat (Tensor): Predicted top-left corner heatmap.
- br_heat (Tensor): Predicted bottom-right corner heatmap.
- tl_off (Tensor): Predicted top-left offset heatmap.
- br_off (Tensor): Predicted bottom-right offset heatmap.
- tl_guiding_shift (Tensor): Predicted top-left guiding shift
heatmap.
- br_guiding_shift (Tensor): Predicted bottom-right guiding
shift heatmap.
- tl_centripetal_shift (Tensor): Predicted top-left centripetal
shift heatmap.
- br_centripetal_shift (Tensor): Predicted bottom-right
centripetal shift heatmap.
"""
tl_heat, br_heat, _, _, tl_off, br_off, tl_pool, br_pool = super(
).forward_single(
x, lvl_ind, return_pool=True)
tl_guiding_shift = self.tl_guiding_shift[lvl_ind](tl_pool)
br_guiding_shift = self.br_guiding_shift[lvl_ind](br_pool)
tl_dcn_offset = self.tl_dcn_offset[lvl_ind](tl_guiding_shift.detach())
br_dcn_offset = self.br_dcn_offset[lvl_ind](br_guiding_shift.detach())
tl_feat_adaption = self.tl_feat_adaption[lvl_ind](tl_pool,
tl_dcn_offset)
br_feat_adaption = self.br_feat_adaption[lvl_ind](br_pool,
br_dcn_offset)
tl_centripetal_shift = self.tl_centripetal_shift[lvl_ind](
tl_feat_adaption)
br_centripetal_shift = self.br_centripetal_shift[lvl_ind](
br_feat_adaption)
result_list = [
tl_heat, br_heat, tl_off, br_off, tl_guiding_shift,
br_guiding_shift, tl_centripetal_shift, br_centripetal_shift
]
return result_list
def loss(self,
tl_heats,
br_heats,
tl_offs,
br_offs,
tl_guiding_shifts,
br_guiding_shifts,
tl_centripetal_shifts,
br_centripetal_shifts,
gt_bboxes,
gt_labels,
img_metas,
gt_bboxes_ignore=None):
"""Compute losses of the head.
Args:
tl_heats (list[Tensor]): Top-left corner heatmaps for each level
with shape (N, num_classes, H, W).
br_heats (list[Tensor]): Bottom-right corner heatmaps for each
level with shape (N, num_classes, H, W).
tl_offs (list[Tensor]): Top-left corner offsets for each level
with shape (N, corner_offset_channels, H, W).
br_offs (list[Tensor]): Bottom-right corner offsets for each level
with shape (N, corner_offset_channels, H, W).
tl_guiding_shifts (list[Tensor]): Top-left guiding shifts for each
level with shape (N, guiding_shift_channels, H, W).
br_guiding_shifts (list[Tensor]): Bottom-right guiding shifts for
each level with shape (N, guiding_shift_channels, H, W).
tl_centripetal_shifts (list[Tensor]): Top-left centripetal shifts
for each level with shape (N, centripetal_shift_channels, H,
W).
br_centripetal_shifts (list[Tensor]): Bottom-right centripetal
shifts for each level with shape (N,
centripetal_shift_channels, H, W).
gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
shape (num_gts, 4) in [left, top, right, bottom] format.
gt_labels (list[Tensor]): Class indices corresponding to each box.
img_metas (list[dict]): Meta information of each image, e.g.,
image size, scaling factor, etc.
gt_bboxes_ignore (list[Tensor] | None): Specify which bounding
boxes can be ignored when computing the loss.
Returns:
dict[str, Tensor]: A dictionary of loss components. Containing the
following losses:
- det_loss (list[Tensor]): Corner keypoint losses of all
feature levels.
- off_loss (list[Tensor]): Corner offset losses of all feature
levels.
- guiding_loss (list[Tensor]): Guiding shift losses of all
feature levels.
- centripetal_loss (list[Tensor]): Centripetal shift losses of
all feature levels.
"""
targets = self.get_targets(
gt_bboxes,
gt_labels,
tl_heats[-1].shape,
img_metas[0]['pad_shape'],
with_corner_emb=self.with_corner_emb,
with_guiding_shift=True,
with_centripetal_shift=True)
mlvl_targets = [targets for _ in range(self.num_feat_levels)]
[det_losses, off_losses, guiding_losses, centripetal_losses
] = multi_apply(self.loss_single, tl_heats, br_heats, tl_offs,
br_offs, tl_guiding_shifts, br_guiding_shifts,
tl_centripetal_shifts, br_centripetal_shifts,
mlvl_targets)
loss_dict = dict(
det_loss=det_losses,
off_loss=off_losses,
guiding_loss=guiding_losses,
centripetal_loss=centripetal_losses)
return loss_dict
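# Hedged illustration (not in the original file): `multi_apply` from
# mmdet.core maps `loss_single` over the per-level lists above and
# transposes the per-level result tuples into tuples of lists, roughly:
#
#   def multi_apply_sketch(func, *args):
#       results = map(func, *args)          # one call per feature level
#       return tuple(map(list, zip(*results)))
#
# so `det_losses` etc. each end up as a list with one Tensor per level.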
def loss_single(self, tl_hmp, br_hmp, tl_off, br_off, tl_guiding_shift,
br_guiding_shift, tl_centripetal_shift,
br_centripetal_shift, targets):
"""Compute losses for single level.
Args:
tl_hmp (Tensor): Top-left corner heatmap for current level with
shape (N, num_classes, H, W).
br_hmp (Tensor): Bottom-right corner heatmap for current level with
shape (N, num_classes, H, W).
tl_off (Tensor): Top-left corner offset for current level with
shape (N, corner_offset_channels, H, W).
br_off (Tensor): Bottom-right corner offset for current level with
shape (N, corner_offset_channels, H, W).
tl_guiding_shift (Tensor): Top-left guiding shift for current level
with shape (N, guiding_shift_channels, H, W).
br_guiding_shift (Tensor): Bottom-right guiding shift for current
level with shape (N, guiding_shift_channels, H, W).
tl_centripetal_shift (Tensor): Top-left centripetal shift for
current level with shape (N, centripetal_shift_channels, H, W).
br_centripetal_shift (Tensor): Bottom-right centripetal shift for
current level with shape (N, centripetal_shift_channels, H, W).
targets (dict): Corner target generated by `get_targets`.
Returns:
tuple[torch.Tensor]: Losses of the head's different branches
containing the following losses:
- det_loss (Tensor): Corner keypoint loss.
- off_loss (Tensor): Corner offset loss.
- guiding_loss (Tensor): Guiding shift loss.
- centripetal_loss (Tensor): Centripetal shift loss.
"""
targets['corner_embedding'] = None
det_loss, _, _, off_loss = super().loss_single(tl_hmp, br_hmp, None,
None, tl_off, br_off,
targets)
gt_tl_guiding_shift = targets['topleft_guiding_shift']
gt_br_guiding_shift = targets['bottomright_guiding_shift']
gt_tl_centripetal_shift = targets['topleft_centripetal_shift']
gt_br_centripetal_shift = targets['bottomright_centripetal_shift']
gt_tl_heatmap = targets['topleft_heatmap']
gt_br_heatmap = targets['bottomright_heatmap']
# We only compute the guiding and centripetal shift losses at real
# corner positions. A real corner has value 1 in the ground-truth
# heatmap. The mask is computed in class-agnostic mode and its shape
# is batch * 1 * height * width.
tl_mask = gt_tl_heatmap.eq(1).sum(1).gt(0).unsqueeze(1).type_as(
gt_tl_heatmap)
br_mask = gt_br_heatmap.eq(1).sum(1).gt(0).unsqueeze(1).type_as(
gt_br_heatmap)
# Guiding shift loss
tl_guiding_loss = self.loss_guiding_shift(
tl_guiding_shift,
gt_tl_guiding_shift,
tl_mask,
avg_factor=tl_mask.sum())
br_guiding_loss = self.loss_guiding_shift(
br_guiding_shift,
gt_br_guiding_shift,
br_mask,
avg_factor=br_mask.sum())
guiding_loss = (tl_guiding_loss + br_guiding_loss) / 2.0
# Centripetal shift loss
tl_centripetal_loss = self.loss_centripetal_shift(
tl_centripetal_shift,
gt_tl_centripetal_shift,
tl_mask,
avg_factor=tl_mask.sum())
br_centripetal_loss = self.loss_centripetal_shift(
br_centripetal_shift,
gt_br_centripetal_shift,
br_mask,
avg_factor=br_mask.sum())
centripetal_loss = (tl_centripetal_loss + br_centripetal_loss) / 2.0
return det_loss, off_loss, guiding_loss, centripetal_loss
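# Hedged mini-example (not in the original file) of the class-agnostic
# corner mask built above; the tensor sizes are made up:
#
#   import torch
#   gt_heat = torch.zeros(2, 3, 4, 4)   # (N, num_classes, H, W)
#   gt_heat[0, 1, 2, 2] = 1.0           # a single real corner
#   mask = gt_heat.eq(1).sum(1).gt(0).unsqueeze(1).type_as(gt_heat)
#   # mask.shape == (2, 1, 4, 4); mask.sum() == 1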
def get_bboxes(self,
tl_heats,
br_heats,
tl_offs,
br_offs,
tl_guiding_shifts,
br_guiding_shifts,
tl_centripetal_shifts,
br_centripetal_shifts,
img_metas,
rescale=False,
with_nms=True):
"""Transform network output for a batch into bbox predictions.
Args:
tl_heats (list[Tensor]): Top-left corner heatmaps for each level
with shape (N, num_classes, H, W).
br_heats (list[Tensor]): Bottom-right corner heatmaps for each
level with shape (N, num_classes, H, W).
tl_offs (list[Tensor]): Top-left corner offsets for each level
with shape (N, corner_offset_channels, H, W).
br_offs (list[Tensor]): Bottom-right corner offsets for each level
with shape (N, corner_offset_channels, H, W).
tl_guiding_shifts (list[Tensor]): Top-left guiding shifts for each
level with shape (N, guiding_shift_channels, H, W). Useless in
this function, we keep this arg because it's the raw output
from CentripetalHead.
br_guiding_shifts (list[Tensor]): Bottom-right guiding shifts for
each level with shape (N, guiding_shift_channels, H, W).
Useless in this function, we keep this arg because it's the
raw output from CentripetalHead.
tl_centripetal_shifts (list[Tensor]): Top-left centripetal shifts
for each level with shape (N, centripetal_shift_channels, H,
W).
br_centripetal_shifts (list[Tensor]): Bottom-right centripetal
shifts for each level with shape (N,
centripetal_shift_channels, H, W).
img_metas (list[dict]): Meta information of each image, e.g.,
image size, scaling factor, etc.
rescale (bool): If True, return boxes in original image space.
Default: False.
with_nms (bool): If True, do nms before return boxes.
Default: True.
"""
assert tl_heats[-1].shape[0] == br_heats[-1].shape[0] == len(img_metas)
result_list = []
for img_id in range(len(img_metas)):
result_list.append(
self._get_bboxes_single(
tl_heats[-1][img_id:img_id + 1, :],
br_heats[-1][img_id:img_id + 1, :],
tl_offs[-1][img_id:img_id + 1, :],
br_offs[-1][img_id:img_id + 1, :],
img_metas[img_id],
tl_emb=None,
br_emb=None,
tl_centripetal_shift=tl_centripetal_shifts[-1][
img_id:img_id + 1, :],
br_centripetal_shift=br_centripetal_shifts[-1][
img_id:img_id + 1, :],
rescale=rescale,
with_nms=with_nms))
return result_list
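# ---------------------------------------------------------------------
# Hedged usage sketch (not part of the original files): the centripetal
# shift targets are stored as log(center - corner), so decoding applies
# exp(), as `decode_heatmap` further below does. A plain-torch example
# with made-up numbers:
#
#   from math import log
#   import torch
#   shift = torch.tensor([log(3.0), log(2.0)])  # encoded (dx, dy)
#   decoded = shift.exp()                       # recovers (3.0, 2.0)
#   assert torch.allclose(decoded, torch.tensor([3.0, 2.0]))
# ---------------------------------------------------------------------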
from math import ceil, log
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule, bias_init_with_prob
from mmcv.ops import CornerPool, batched_nms
from mmdet.core import multi_apply
from ..builder import HEADS, build_loss
from ..utils import gaussian_radius, gen_gaussian_target
from .base_dense_head import BaseDenseHead
class BiCornerPool(nn.Module):
"""Bidirectional Corner Pooling Module (TopLeft, BottomRight, etc.)
Args:
in_channels (int): Input channels of module.
out_channels (int): Output channels of module.
feat_channels (int): Feature channels of module.
directions (list[str]): Directions of two CornerPools.
norm_cfg (dict): Dictionary to construct and config norm layer.
"""
def __init__(self,
in_channels,
directions,
feat_channels=128,
out_channels=128,
norm_cfg=dict(type='BN', requires_grad=True)):
super(BiCornerPool, self).__init__()
self.direction1_conv = ConvModule(
in_channels, feat_channels, 3, padding=1, norm_cfg=norm_cfg)
self.direction2_conv = ConvModule(
in_channels, feat_channels, 3, padding=1, norm_cfg=norm_cfg)
self.aftpool_conv = ConvModule(
feat_channels,
out_channels,
3,
padding=1,
norm_cfg=norm_cfg,
act_cfg=None)
self.conv1 = ConvModule(
in_channels, out_channels, 1, norm_cfg=norm_cfg, act_cfg=None)
self.conv2 = ConvModule(
in_channels, out_channels, 3, padding=1, norm_cfg=norm_cfg)
self.direction1_pool = CornerPool(directions[0])
self.direction2_pool = CornerPool(directions[1])
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
"""Forward features from the upstream network.
Args:
x (tensor): Input feature of BiCornerPool.
Returns:
conv2 (tensor): Output feature of BiCornerPool.
"""
direction1_conv = self.direction1_conv(x)
direction2_conv = self.direction2_conv(x)
direction1_feat = self.direction1_pool(direction1_conv)
direction2_feat = self.direction2_pool(direction2_conv)
aftpool_conv = self.aftpool_conv(direction1_feat + direction2_feat)
conv1 = self.conv1(x)
relu = self.relu(aftpool_conv + conv1)
conv2 = self.conv2(relu)
return conv2
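# Hedged sketch (not in the original file): a pure-PyTorch equivalent of
# what CornerPool('top') computes, assuming the CornerNet semantics that
# each output pixel is the max of its column from that row downwards:
#
#   import torch
#   def top_pool_sketch(x):  # x: (N, C, H, W)
#       # cummax from the bottom row upwards == max over rows >= i
#       return x.flip(2).cummax(dim=2).values.flip(2)
#
# The other directions follow by flipping or transposing the scan axis.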
@HEADS.register_module()
class CornerHead(BaseDenseHead):
"""Head of CornerNet: Detecting Objects as Paired Keypoints.
Code is modified from the `official github repo
<https://github.com/princeton-vl/CornerNet/blob/master/models/py_utils/
kp.py#L73>`_ .
More details can be found in the `paper
<https://arxiv.org/abs/1808.01244>`_ .
Args:
num_classes (int): Number of categories excluding the background
category.
in_channels (int): Number of channels in the input feature map.
num_feat_levels (int): Levels of feature from the previous module. 2
for HourglassNet-104 and 1 for HourglassNet-52. Because
HourglassNet-104 outputs the final feature and intermediate
supervision feature and HourglassNet-52 only outputs the final
feature. Default: 2.
corner_emb_channels (int): Channel of embedding vector. Default: 1.
train_cfg (dict | None): Training config. Useless in CornerHead,
but we keep this variable for SingleStageDetector. Default: None.
test_cfg (dict | None): Testing config of CornerHead. Default: None.
loss_heatmap (dict | None): Config of corner heatmap loss. Default:
GaussianFocalLoss.
loss_embedding (dict | None): Config of corner embedding loss. Default:
AssociativeEmbeddingLoss.
loss_offset (dict | None): Config of corner offset loss. Default:
SmoothL1Loss.
"""
def __init__(self,
num_classes,
in_channels,
num_feat_levels=2,
corner_emb_channels=1,
train_cfg=None,
test_cfg=None,
loss_heatmap=dict(
type='GaussianFocalLoss',
alpha=2.0,
gamma=4.0,
loss_weight=1),
loss_embedding=dict(
type='AssociativeEmbeddingLoss',
pull_weight=0.25,
push_weight=0.25),
loss_offset=dict(
type='SmoothL1Loss', beta=1.0, loss_weight=1)):
super(CornerHead, self).__init__()
self.num_classes = num_classes
self.in_channels = in_channels
self.corner_emb_channels = corner_emb_channels
self.with_corner_emb = self.corner_emb_channels > 0
self.corner_offset_channels = 2
self.num_feat_levels = num_feat_levels
self.loss_heatmap = build_loss(
loss_heatmap) if loss_heatmap is not None else None
self.loss_embedding = build_loss(
loss_embedding) if loss_embedding is not None else None
self.loss_offset = build_loss(
loss_offset) if loss_offset is not None else None
self.train_cfg = train_cfg
self.test_cfg = test_cfg
self._init_layers()
def _make_layers(self, out_channels, in_channels=256, feat_channels=256):
"""Initialize conv sequential for CornerHead."""
return nn.Sequential(
ConvModule(in_channels, feat_channels, 3, padding=1),
ConvModule(
feat_channels, out_channels, 1, norm_cfg=None, act_cfg=None))
def _init_corner_kpt_layers(self):
"""Initialize corner keypoint layers.
Including corner heatmap branch and corner offset branch. Each branch
has two parts: prefix `tl_` for top-left and `br_` for bottom-right.
"""
self.tl_pool, self.br_pool = nn.ModuleList(), nn.ModuleList()
self.tl_heat, self.br_heat = nn.ModuleList(), nn.ModuleList()
self.tl_off, self.br_off = nn.ModuleList(), nn.ModuleList()
for _ in range(self.num_feat_levels):
self.tl_pool.append(
BiCornerPool(
self.in_channels, ['top', 'left'],
out_channels=self.in_channels))
self.br_pool.append(
BiCornerPool(
self.in_channels, ['bottom', 'right'],
out_channels=self.in_channels))
self.tl_heat.append(
self._make_layers(
out_channels=self.num_classes,
in_channels=self.in_channels))
self.br_heat.append(
self._make_layers(
out_channels=self.num_classes,
in_channels=self.in_channels))
self.tl_off.append(
self._make_layers(
out_channels=self.corner_offset_channels,
in_channels=self.in_channels))
self.br_off.append(
self._make_layers(
out_channels=self.corner_offset_channels,
in_channels=self.in_channels))
def _init_corner_emb_layers(self):
"""Initialize corner embedding layers.
Only include corner embedding branch with two parts: prefix `tl_` for
top-left and `br_` for bottom-right.
"""
self.tl_emb, self.br_emb = nn.ModuleList(), nn.ModuleList()
for _ in range(self.num_feat_levels):
self.tl_emb.append(
self._make_layers(
out_channels=self.corner_emb_channels,
in_channels=self.in_channels))
self.br_emb.append(
self._make_layers(
out_channels=self.corner_emb_channels,
in_channels=self.in_channels))
def _init_layers(self):
"""Initialize layers for CornerHead.
Including two parts: corner keypoint layers and corner embedding layers
"""
self._init_corner_kpt_layers()
if self.with_corner_emb:
self._init_corner_emb_layers()
def init_weights(self):
"""Initialize weights of the head."""
bias_init = bias_init_with_prob(0.1)
for i in range(self.num_feat_levels):
# The initialization of parameters differs between nn.Conv2d
# and ConvModule. Our experiments show that using the original
# initialization of nn.Conv2d increases the final mAP by about 0.2%
self.tl_heat[i][-1].conv.reset_parameters()
self.tl_heat[i][-1].conv.bias.data.fill_(bias_init)
self.br_heat[i][-1].conv.reset_parameters()
self.br_heat[i][-1].conv.bias.data.fill_(bias_init)
self.tl_off[i][-1].conv.reset_parameters()
self.br_off[i][-1].conv.reset_parameters()
if self.with_corner_emb:
self.tl_emb[i][-1].conv.reset_parameters()
self.br_emb[i][-1].conv.reset_parameters()
def forward(self, feats):
"""Forward features from the upstream network.
Args:
feats (tuple[Tensor]): Features from the upstream network, each is
a 4D-tensor.
Returns:
tuple: Usually a tuple of corner heatmaps, offset heatmaps and
embedding heatmaps.
- tl_heats (list[Tensor]): Top-left corner heatmaps for all
levels, each is a 4D-tensor, the channels number is
num_classes.
- br_heats (list[Tensor]): Bottom-right corner heatmaps for all
levels, each is a 4D-tensor, the channels number is
num_classes.
- tl_embs (list[Tensor] | list[None]): Top-left embedding
heatmaps for all levels, each is a 4D-tensor or None.
If not None, the channels number is corner_emb_channels.
- br_embs (list[Tensor] | list[None]): Bottom-right embedding
heatmaps for all levels, each is a 4D-tensor or None.
If not None, the channels number is corner_emb_channels.
- tl_offs (list[Tensor]): Top-left offset heatmaps for all
levels, each is a 4D-tensor. The channels number is
corner_offset_channels.
- br_offs (list[Tensor]): Bottom-right offset heatmaps for all
levels, each is a 4D-tensor. The channels number is
corner_offset_channels.
"""
lvl_ind = list(range(self.num_feat_levels))
return multi_apply(self.forward_single, feats, lvl_ind)
def forward_single(self, x, lvl_ind, return_pool=False):
"""Forward feature of a single level.
Args:
x (Tensor): Feature of a single level.
lvl_ind (int): Level index of current feature.
return_pool (bool): Return corner pool feature or not.
Returns:
tuple[Tensor]: A tuple of CornerHead's output for current feature
level. Containing the following Tensors:
- tl_heat (Tensor): Predicted top-left corner heatmap.
- br_heat (Tensor): Predicted bottom-right corner heatmap.
- tl_emb (Tensor | None): Predicted top-left embedding heatmap.
None for `self.with_corner_emb == False`.
- br_emb (Tensor | None): Predicted bottom-right embedding
heatmap. None for `self.with_corner_emb == False`.
- tl_off (Tensor): Predicted top-left offset heatmap.
- br_off (Tensor): Predicted bottom-right offset heatmap.
- tl_pool (Tensor): Top-left corner pool feature. Only returned
when ``return_pool`` is True.
- br_pool (Tensor): Bottom-right corner pool feature. Only
returned when ``return_pool`` is True.
"""
tl_pool = self.tl_pool[lvl_ind](x)
tl_heat = self.tl_heat[lvl_ind](tl_pool)
br_pool = self.br_pool[lvl_ind](x)
br_heat = self.br_heat[lvl_ind](br_pool)
tl_emb, br_emb = None, None
if self.with_corner_emb:
tl_emb = self.tl_emb[lvl_ind](tl_pool)
br_emb = self.br_emb[lvl_ind](br_pool)
tl_off = self.tl_off[lvl_ind](tl_pool)
br_off = self.br_off[lvl_ind](br_pool)
result_list = [tl_heat, br_heat, tl_emb, br_emb, tl_off, br_off]
if return_pool:
result_list.append(tl_pool)
result_list.append(br_pool)
return result_list
def get_targets(self,
gt_bboxes,
gt_labels,
feat_shape,
img_shape,
with_corner_emb=False,
with_guiding_shift=False,
with_centripetal_shift=False):
"""Generate corner targets.
Including corner heatmap, corner offset.
Optional: corner embedding, corner guiding shift, centripetal shift.
For CornerNet, we generate corner heatmap, corner offset and corner
embedding from this function.
For CentripetalNet, we generate corner heatmap, corner offset, guiding
shift and centripetal shift from this function.
Args:
gt_bboxes (list[Tensor]): Ground truth bboxes of each image, each
has shape (num_gt, 4).
gt_labels (list[Tensor]): Ground truth labels of each box, each has
shape (num_gt,).
feat_shape (list[int]): Shape of output feature,
[batch, channel, height, width].
img_shape (list[int]): Shape of input image,
[height, width, channel].
with_corner_emb (bool): Generate corner embedding target or not.
Default: False.
with_guiding_shift (bool): Generate guiding shift target or not.
Default: False.
with_centripetal_shift (bool): Generate centripetal shift target or
not. Default: False.
Returns:
dict: Ground truth of corner heatmap, corner offset, corner
embedding, guiding shift and centripetal shift. Containing the
following keys:
- topleft_heatmap (Tensor): Ground truth top-left corner
heatmap.
- bottomright_heatmap (Tensor): Ground truth bottom-right
corner heatmap.
- topleft_offset (Tensor): Ground truth top-left corner offset.
- bottomright_offset (Tensor): Ground truth bottom-right corner
offset.
- corner_embedding (list[list[list[int]]]): Ground truth corner
embedding. Only included when ``with_corner_emb`` is True.
- topleft_guiding_shift (Tensor): Ground truth top-left corner
guiding shift. Only included when ``with_guiding_shift`` is
True.
- bottomright_guiding_shift (Tensor): Ground truth bottom-right
corner guiding shift. Only included when ``with_guiding_shift``
is True.
- topleft_centripetal_shift (Tensor): Ground truth top-left
corner centripetal shift. Only included when
``with_centripetal_shift`` is True.
- bottomright_centripetal_shift (Tensor): Ground truth
bottom-right corner centripetal shift. Only included when
``with_centripetal_shift`` is True.
"""
batch_size, _, height, width = feat_shape
img_h, img_w = img_shape[:2]
width_ratio = float(width / img_w)
height_ratio = float(height / img_h)
gt_tl_heatmap = gt_bboxes[-1].new_zeros(
[batch_size, self.num_classes, height, width])
gt_br_heatmap = gt_bboxes[-1].new_zeros(
[batch_size, self.num_classes, height, width])
gt_tl_offset = gt_bboxes[-1].new_zeros([batch_size, 2, height, width])
gt_br_offset = gt_bboxes[-1].new_zeros([batch_size, 2, height, width])
if with_corner_emb:
match = []
# Guiding shift is a kind of offset, from center to corner
if with_guiding_shift:
gt_tl_guiding_shift = gt_bboxes[-1].new_zeros(
[batch_size, 2, height, width])
gt_br_guiding_shift = gt_bboxes[-1].new_zeros(
[batch_size, 2, height, width])
# Centripetal shift is also a kind of offset, from center to corner
# and normalized by log.
if with_centripetal_shift:
gt_tl_centripetal_shift = gt_bboxes[-1].new_zeros(
[batch_size, 2, height, width])
gt_br_centripetal_shift = gt_bboxes[-1].new_zeros(
[batch_size, 2, height, width])
for batch_id in range(batch_size):
# Ground truth of corner embedding per image is a list of coordinate pairs
corner_match = []
for box_id in range(len(gt_labels[batch_id])):
left, top, right, bottom = gt_bboxes[batch_id][box_id]
center_x = (left + right) / 2.0
center_y = (top + bottom) / 2.0
label = gt_labels[batch_id][box_id]
# Use coords in the feature level to generate ground truth
scale_left = left * width_ratio
scale_right = right * width_ratio
scale_top = top * height_ratio
scale_bottom = bottom * height_ratio
scale_center_x = center_x * width_ratio
scale_center_y = center_y * height_ratio
# Int coords on feature map/ground truth tensor
left_idx = int(min(scale_left, width - 1))
right_idx = int(min(scale_right, width - 1))
top_idx = int(min(scale_top, height - 1))
bottom_idx = int(min(scale_bottom, height - 1))
# Generate gaussian heatmap
scale_box_width = ceil(scale_right - scale_left)
scale_box_height = ceil(scale_bottom - scale_top)
radius = gaussian_radius((scale_box_height, scale_box_width),
min_overlap=0.3)
radius = max(0, int(radius))
gt_tl_heatmap[batch_id, label] = gen_gaussian_target(
gt_tl_heatmap[batch_id, label], [left_idx, top_idx],
radius)
gt_br_heatmap[batch_id, label] = gen_gaussian_target(
gt_br_heatmap[batch_id, label], [right_idx, bottom_idx],
radius)
# Generate corner offset
left_offset = scale_left - left_idx
top_offset = scale_top - top_idx
right_offset = scale_right - right_idx
bottom_offset = scale_bottom - bottom_idx
gt_tl_offset[batch_id, 0, top_idx, left_idx] = left_offset
gt_tl_offset[batch_id, 1, top_idx, left_idx] = top_offset
gt_br_offset[batch_id, 0, bottom_idx, right_idx] = right_offset
gt_br_offset[batch_id, 1, bottom_idx,
right_idx] = bottom_offset
# Generate corner embedding
if with_corner_emb:
corner_match.append([[top_idx, left_idx],
[bottom_idx, right_idx]])
# Generate guiding shift
if with_guiding_shift:
gt_tl_guiding_shift[batch_id, 0, top_idx,
left_idx] = scale_center_x - left_idx
gt_tl_guiding_shift[batch_id, 1, top_idx,
left_idx] = scale_center_y - top_idx
gt_br_guiding_shift[batch_id, 0, bottom_idx,
right_idx] = right_idx - scale_center_x
gt_br_guiding_shift[
batch_id, 1, bottom_idx,
right_idx] = bottom_idx - scale_center_y
# Generate centripetal shift
if with_centripetal_shift:
gt_tl_centripetal_shift[batch_id, 0, top_idx,
left_idx] = log(scale_center_x -
scale_left)
gt_tl_centripetal_shift[batch_id, 1, top_idx,
left_idx] = log(scale_center_y -
scale_top)
gt_br_centripetal_shift[batch_id, 0, bottom_idx,
right_idx] = log(scale_right -
scale_center_x)
gt_br_centripetal_shift[batch_id, 1, bottom_idx,
right_idx] = log(scale_bottom -
scale_center_y)
if with_corner_emb:
match.append(corner_match)
target_result = dict(
topleft_heatmap=gt_tl_heatmap,
topleft_offset=gt_tl_offset,
bottomright_heatmap=gt_br_heatmap,
bottomright_offset=gt_br_offset)
if with_corner_emb:
target_result.update(corner_embedding=match)
if with_guiding_shift:
target_result.update(
topleft_guiding_shift=gt_tl_guiding_shift,
bottomright_guiding_shift=gt_br_guiding_shift)
if with_centripetal_shift:
target_result.update(
topleft_centripetal_shift=gt_tl_centripetal_shift,
bottomright_centripetal_shift=gt_br_centripetal_shift)
return target_result
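# Hedged worked example (not in the original file) of the image-to-
# feature coordinate mapping used above, with made-up sizes: for an
# input padded to 512x512 and a 128x128 feature map, width_ratio =
# height_ratio = 128 / 512 = 0.25, so a corner at image x = 100.0 lands
# at scale_left = 25.0 and is drawn at integer index left_idx = 25, with
# an offset target of 25.0 - 25 = 0.0 stored in gt_tl_offset.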
def loss(self,
tl_heats,
br_heats,
tl_embs,
br_embs,
tl_offs,
br_offs,
gt_bboxes,
gt_labels,
img_metas,
gt_bboxes_ignore=None):
"""Compute losses of the head.
Args:
tl_heats (list[Tensor]): Top-left corner heatmaps for each level
with shape (N, num_classes, H, W).
br_heats (list[Tensor]): Bottom-right corner heatmaps for each
level with shape (N, num_classes, H, W).
tl_embs (list[Tensor]): Top-left corner embeddings for each level
with shape (N, corner_emb_channels, H, W).
br_embs (list[Tensor]): Bottom-right corner embeddings for each
level with shape (N, corner_emb_channels, H, W).
tl_offs (list[Tensor]): Top-left corner offsets for each level
with shape (N, corner_offset_channels, H, W).
br_offs (list[Tensor]): Bottom-right corner offsets for each level
with shape (N, corner_offset_channels, H, W).
gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
shape (num_gts, 4) in [left, top, right, bottom] format.
gt_labels (list[Tensor]): Class indices corresponding to each box.
img_metas (list[dict]): Meta information of each image, e.g.,
image size, scaling factor, etc.
gt_bboxes_ignore (list[Tensor] | None): Specify which bounding
boxes can be ignored when computing the loss.
Returns:
dict[str, Tensor]: A dictionary of loss components. Containing the
following losses:
- det_loss (list[Tensor]): Corner keypoint losses of all
feature levels.
- pull_loss (list[Tensor]): Part one of AssociativeEmbedding
losses of all feature levels.
- push_loss (list[Tensor]): Part two of AssociativeEmbedding
losses of all feature levels.
- off_loss (list[Tensor]): Corner offset losses of all feature
levels.
"""
targets = self.get_targets(
gt_bboxes,
gt_labels,
tl_heats[-1].shape,
img_metas[0]['pad_shape'],
with_corner_emb=self.with_corner_emb)
mlvl_targets = [targets for _ in range(self.num_feat_levels)]
det_losses, pull_losses, push_losses, off_losses = multi_apply(
self.loss_single, tl_heats, br_heats, tl_embs, br_embs, tl_offs,
br_offs, mlvl_targets)
loss_dict = dict(det_loss=det_losses, off_loss=off_losses)
if self.with_corner_emb:
loss_dict.update(pull_loss=pull_losses, push_loss=push_losses)
return loss_dict
def loss_single(self, tl_hmp, br_hmp, tl_emb, br_emb, tl_off, br_off,
targets):
"""Compute losses for single level.
Args:
tl_hmp (Tensor): Top-left corner heatmap for current level with
shape (N, num_classes, H, W).
br_hmp (Tensor): Bottom-right corner heatmap for current level with
shape (N, num_classes, H, W).
tl_emb (Tensor): Top-left corner embedding for current level with
shape (N, corner_emb_channels, H, W).
br_emb (Tensor): Bottom-right corner embedding for current level
with shape (N, corner_emb_channels, H, W).
tl_off (Tensor): Top-left corner offset for current level with
shape (N, corner_offset_channels, H, W).
br_off (Tensor): Bottom-right corner offset for current level with
shape (N, corner_offset_channels, H, W).
targets (dict): Corner target generated by `get_targets`.
Returns:
tuple[torch.Tensor]: Losses of the head's different branches
containing the following losses:
- det_loss (Tensor): Corner keypoint loss.
- pull_loss (Tensor): Part one of AssociativeEmbedding loss.
- push_loss (Tensor): Part two of AssociativeEmbedding loss.
- off_loss (Tensor): Corner offset loss.
"""
gt_tl_hmp = targets['topleft_heatmap']
gt_br_hmp = targets['bottomright_heatmap']
gt_tl_off = targets['topleft_offset']
gt_br_off = targets['bottomright_offset']
gt_embedding = targets['corner_embedding']
# Detection loss
tl_det_loss = self.loss_heatmap(
tl_hmp.sigmoid(),
gt_tl_hmp,
avg_factor=max(1,
gt_tl_hmp.eq(1).sum()))
br_det_loss = self.loss_heatmap(
br_hmp.sigmoid(),
gt_br_hmp,
avg_factor=max(1,
gt_br_hmp.eq(1).sum()))
det_loss = (tl_det_loss + br_det_loss) / 2.0
# AssociativeEmbedding loss
if self.with_corner_emb and self.loss_embedding is not None:
pull_loss, push_loss = self.loss_embedding(tl_emb, br_emb,
gt_embedding)
else:
pull_loss, push_loss = None, None
# Offset loss
# We only compute the offset loss at real corner positions.
# A real corner has value 1 in the ground-truth heatmap.
# The mask is computed in class-agnostic mode and its shape is
# batch * 1 * height * width.
tl_off_mask = gt_tl_hmp.eq(1).sum(1).gt(0).unsqueeze(1).type_as(
gt_tl_hmp)
br_off_mask = gt_br_hmp.eq(1).sum(1).gt(0).unsqueeze(1).type_as(
gt_br_hmp)
tl_off_loss = self.loss_offset(
tl_off,
gt_tl_off,
tl_off_mask,
avg_factor=max(1, tl_off_mask.sum()))
br_off_loss = self.loss_offset(
br_off,
gt_br_off,
br_off_mask,
avg_factor=max(1, br_off_mask.sum()))
off_loss = (tl_off_loss + br_off_loss) / 2.0
return det_loss, pull_loss, push_loss, off_loss
def get_bboxes(self,
tl_heats,
br_heats,
tl_embs,
br_embs,
tl_offs,
br_offs,
img_metas,
rescale=False,
with_nms=True):
"""Transform network output for a batch into bbox predictions.
Args:
tl_heats (list[Tensor]): Top-left corner heatmaps for each level
with shape (N, num_classes, H, W).
br_heats (list[Tensor]): Bottom-right corner heatmaps for each
level with shape (N, num_classes, H, W).
tl_embs (list[Tensor]): Top-left corner embeddings for each level
with shape (N, corner_emb_channels, H, W).
br_embs (list[Tensor]): Bottom-right corner embeddings for each
level with shape (N, corner_emb_channels, H, W).
tl_offs (list[Tensor]): Top-left corner offsets for each level
with shape (N, corner_offset_channels, H, W).
br_offs (list[Tensor]): Bottom-right corner offsets for each level
with shape (N, corner_offset_channels, H, W).
img_metas (list[dict]): Meta information of each image, e.g.,
image size, scaling factor, etc.
rescale (bool): If True, return boxes in original image space.
Default: False.
with_nms (bool): If True, do nms before return boxes.
Default: True.
"""
assert tl_heats[-1].shape[0] == br_heats[-1].shape[0] == len(img_metas)
result_list = []
for img_id in range(len(img_metas)):
result_list.append(
self._get_bboxes_single(
tl_heats[-1][img_id:img_id + 1, :],
br_heats[-1][img_id:img_id + 1, :],
tl_offs[-1][img_id:img_id + 1, :],
br_offs[-1][img_id:img_id + 1, :],
img_metas[img_id],
tl_emb=tl_embs[-1][img_id:img_id + 1, :],
br_emb=br_embs[-1][img_id:img_id + 1, :],
rescale=rescale,
with_nms=with_nms))
return result_list
def _get_bboxes_single(self,
tl_heat,
br_heat,
tl_off,
br_off,
img_meta,
tl_emb=None,
br_emb=None,
tl_centripetal_shift=None,
br_centripetal_shift=None,
rescale=False,
with_nms=True):
"""Transform outputs for a single batch item into bbox predictions.
Args:
tl_heat (Tensor): Top-left corner heatmap for current level with
shape (N, num_classes, H, W).
br_heat (Tensor): Bottom-right corner heatmap for current level
with shape (N, num_classes, H, W).
tl_off (Tensor): Top-left corner offset for current level with
shape (N, corner_offset_channels, H, W).
br_off (Tensor): Bottom-right corner offset for current level with
shape (N, corner_offset_channels, H, W).
img_meta (dict): Meta information of current image, e.g.,
image size, scaling factor, etc.
tl_emb (Tensor): Top-left corner embedding for current level with
shape (N, corner_emb_channels, H, W).
br_emb (Tensor): Bottom-right corner embedding for current level
with shape (N, corner_emb_channels, H, W).
tl_centripetal_shift: Top-left corner's centripetal shift for
current level with shape (N, 2, H, W).
br_centripetal_shift: Bottom-right corner's centripetal shift for
current level with shape (N, 2, H, W).
rescale (bool): If True, return boxes in original image space.
Default: False.
with_nms (bool): If True, do nms before return boxes.
Default: True.
"""
if isinstance(img_meta, (list, tuple)):
img_meta = img_meta[0]
batch_bboxes, batch_scores, batch_clses = self.decode_heatmap(
tl_heat=tl_heat.sigmoid(),
br_heat=br_heat.sigmoid(),
tl_off=tl_off,
br_off=br_off,
tl_emb=tl_emb,
br_emb=br_emb,
tl_centripetal_shift=tl_centripetal_shift,
br_centripetal_shift=br_centripetal_shift,
img_meta=img_meta,
k=self.test_cfg.corner_topk,
kernel=self.test_cfg.local_maximum_kernel,
distance_threshold=self.test_cfg.distance_threshold)
if rescale:
batch_bboxes /= batch_bboxes.new_tensor(img_meta['scale_factor'])
bboxes = batch_bboxes.view([-1, 4])
scores = batch_scores.view([-1, 1])
clses = batch_clses.view([-1, 1])
idx = scores.argsort(dim=0, descending=True)
bboxes = bboxes[idx].view([-1, 4])
scores = scores[idx].view(-1)
clses = clses[idx].view(-1)
detections = torch.cat([bboxes, scores.unsqueeze(-1)], -1)
keepinds = (detections[:, -1] > -0.1)
detections = detections[keepinds]
labels = clses[keepinds]
if with_nms:
detections, labels = self._bboxes_nms(detections, labels,
self.test_cfg)
return detections, labels
def _bboxes_nms(self, bboxes, labels, cfg):
if labels.numel() == 0:
return bboxes, labels
out_bboxes, keep = batched_nms(bboxes[:, :4], bboxes[:, -1], labels,
cfg.nms_cfg)
out_labels = labels[keep]
if len(out_bboxes) > 0:
idx = torch.argsort(out_bboxes[:, -1], descending=True)
idx = idx[:cfg.max_per_img]
out_bboxes = out_bboxes[idx]
out_labels = out_labels[idx]
return out_bboxes, out_labels
def _gather_feat(self, feat, ind, mask=None):
"""Gather feature according to index.
Args:
feat (Tensor): Target feature map.
ind (Tensor): Target coord index.
mask (Tensor | None): Mask of featuremap. Default: None.
Returns:
feat (Tensor): Gathered feature.
"""
dim = feat.size(2)
ind = ind.unsqueeze(2).repeat(1, 1, dim)
feat = feat.gather(1, ind)
if mask is not None:
mask = mask.unsqueeze(2).expand_as(feat)
feat = feat[mask]
feat = feat.view(-1, dim)
return feat
def _local_maximum(self, heat, kernel=3):
"""Extract local maximum pixel with given kernal.
Args:
heat (Tensor): Target heatmap.
kernel (int): Kernel size of max pooling. Default: 3.
Returns:
heat (Tensor): A heatmap where local maximum pixels retain their
own values and all other positions are 0.
"""
pad = (kernel - 1) // 2
hmax = F.max_pool2d(heat, kernel, stride=1, padding=pad)
keep = (hmax == heat).float()
return heat * keep
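# Hedged mini-example (not in the original file): max pooling with
# stride 1 and same-size padding leaves a pixel unchanged only when it
# is the maximum of its kernel neighbourhood, so `heat * keep` zeroes
# everything except local peaks:
#
#   import torch
#   h = torch.tensor([[1., 3., 1.],
#                     [2., 5., 2.],
#                     [1., 2., 1.]]).view(1, 1, 3, 3)
#   # _local_maximum(h) keeps only the 5 at the centre.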
def _transpose_and_gather_feat(self, feat, ind):
"""Transpose and gather feature according to index.
Args:
feat (Tensor): Target feature map.
ind (Tensor): Target coord index.
Returns:
feat (Tensor): Transposed and gathered feature.
"""
feat = feat.permute(0, 2, 3, 1).contiguous()
feat = feat.view(feat.size(0), -1, feat.size(3))
feat = self._gather_feat(feat, ind)
return feat
def _topk(self, scores, k=20):
"""Get top k positions from heatmap.
Args:
scores (Tensor): Target heatmap with shape
[batch, num_classes, height, width].
k (int): Target number. Default: 20.
Returns:
tuple[torch.Tensor]: Scores, indexes, categories and coords of
topk keypoint. Containing following Tensors:
- topk_scores (Tensor): Max scores of each topk keypoint.
- topk_inds (Tensor): Indexes of each topk keypoint.
- topk_clses (Tensor): Categories of each topk keypoint.
- topk_ys (Tensor): Y-coord of each topk keypoint.
- topk_xs (Tensor): X-coord of each topk keypoint.
"""
batch, _, height, width = scores.size()
topk_scores, topk_inds = torch.topk(scores.view(batch, -1), k)
topk_clses = topk_inds // (height * width)
topk_inds = topk_inds % (height * width)
topk_ys = topk_inds // width
topk_xs = (topk_inds % width).int().float()
return topk_scores, topk_inds, topk_clses, topk_ys, topk_xs
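# Hedged mini-example (not in the original file) of decoding a flat
# top-k index with made-up sizes: for num_classes=3, height=4, width=5,
# a flat index of 27 decodes to class 27 // 20 = 1 with remainder 7,
# hence y = 7 // 5 = 1 and x = 7 % 5 = 2.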
def decode_heatmap(self,
tl_heat,
br_heat,
tl_off,
br_off,
tl_emb=None,
br_emb=None,
tl_centripetal_shift=None,
br_centripetal_shift=None,
img_meta=None,
k=100,
kernel=3,
distance_threshold=0.5,
num_dets=1000):
"""Transform outputs for a single batch item into raw bbox predictions.
Args:
tl_heat (Tensor): Top-left corner heatmap for current level with
shape (N, num_classes, H, W).
br_heat (Tensor): Bottom-right corner heatmap for current level
with shape (N, num_classes, H, W).
tl_off (Tensor): Top-left corner offset for current level with
shape (N, corner_offset_channels, H, W).
br_off (Tensor): Bottom-right corner offset for current level with
shape (N, corner_offset_channels, H, W).
tl_emb (Tensor | None): Top-left corner embedding for current
level with shape (N, corner_emb_channels, H, W).
br_emb (Tensor | None): Bottom-right corner embedding for current
level with shape (N, corner_emb_channels, H, W).
tl_centripetal_shift (Tensor | None): Top-left centripetal shift
for current level with shape (N, 2, H, W).
br_centripetal_shift (Tensor | None): Bottom-right centripetal
shift for current level with shape (N, 2, H, W).
img_meta (dict): Meta information of current image, e.g.,
image size, scaling factor, etc.
k (int): Get top k corner keypoints from heatmap.
kernel (int): Max pooling kernel for extracting local maximum pixels.
distance_threshold (float): Distance threshold. Top-left and
bottom-right corner keypoints with feature distance less than
the threshold will be regarded as keypoints from the same object.
num_dets (int): Num of raw boxes before doing nms.
Returns:
tuple[torch.Tensor]: Decoded output of CornerHead, containing the
following Tensors:
- bboxes (Tensor): Coords of each box.
- scores (Tensor): Scores of each box.
- clses (Tensor): Categories of each box.
"""
with_embedding = tl_emb is not None and br_emb is not None
with_centripetal_shift = (
tl_centripetal_shift is not None
and br_centripetal_shift is not None)
assert with_embedding + with_centripetal_shift == 1
batch, _, height, width = tl_heat.size()
inp_h, inp_w, _ = img_meta['pad_shape']
# perform nms on heatmaps
tl_heat = self._local_maximum(tl_heat, kernel=kernel)
br_heat = self._local_maximum(br_heat, kernel=kernel)
tl_scores, tl_inds, tl_clses, tl_ys, tl_xs = self._topk(tl_heat, k=k)
br_scores, br_inds, br_clses, br_ys, br_xs = self._topk(br_heat, k=k)
# We use repeat instead of expand here because expand is a
# shallow-copy function and can therefore cause unexpected testing
# results. Using expand decreases mAP by about 10% during testing
# compared to repeat.
tl_ys = tl_ys.view(batch, k, 1).repeat(1, 1, k)
tl_xs = tl_xs.view(batch, k, 1).repeat(1, 1, k)
br_ys = br_ys.view(batch, 1, k).repeat(1, k, 1)
br_xs = br_xs.view(batch, 1, k).repeat(1, k, 1)
tl_off = self._transpose_and_gather_feat(tl_off, tl_inds)
tl_off = tl_off.view(batch, k, 1, 2)
br_off = self._transpose_and_gather_feat(br_off, br_inds)
br_off = br_off.view(batch, 1, k, 2)
tl_xs = tl_xs + tl_off[..., 0]
tl_ys = tl_ys + tl_off[..., 1]
br_xs = br_xs + br_off[..., 0]
br_ys = br_ys + br_off[..., 1]
if with_centripetal_shift:
tl_centripetal_shift = self._transpose_and_gather_feat(
tl_centripetal_shift, tl_inds).view(batch, k, 1, 2).exp()
br_centripetal_shift = self._transpose_and_gather_feat(
br_centripetal_shift, br_inds).view(batch, 1, k, 2).exp()
tl_ctxs = tl_xs + tl_centripetal_shift[..., 0]
tl_ctys = tl_ys + tl_centripetal_shift[..., 1]
br_ctxs = br_xs - br_centripetal_shift[..., 0]
br_ctys = br_ys - br_centripetal_shift[..., 1]
# all possible boxes based on top k corners (ignoring class)
tl_xs *= (inp_w / width)
tl_ys *= (inp_h / height)
br_xs *= (inp_w / width)
br_ys *= (inp_h / height)
if with_centripetal_shift:
tl_ctxs *= (inp_w / width)
tl_ctys *= (inp_h / height)
br_ctxs *= (inp_w / width)
br_ctys *= (inp_h / height)
x_off = img_meta['border'][2]
y_off = img_meta['border'][0]
tl_xs -= x_off
tl_ys -= y_off
br_xs -= x_off
br_ys -= y_off
tl_xs *= tl_xs.gt(0.0).type_as(tl_xs)
tl_ys *= tl_ys.gt(0.0).type_as(tl_ys)
br_xs *= br_xs.gt(0.0).type_as(br_xs)
br_ys *= br_ys.gt(0.0).type_as(br_ys)
bboxes = torch.stack((tl_xs, tl_ys, br_xs, br_ys), dim=3)
area_bboxes = ((br_xs - tl_xs) * (br_ys - tl_ys)).abs()
if with_centripetal_shift:
tl_ctxs -= x_off
tl_ctys -= y_off
br_ctxs -= x_off
br_ctys -= y_off
tl_ctxs *= tl_ctxs.gt(0.0).type_as(tl_ctxs)
tl_ctys *= tl_ctys.gt(0.0).type_as(tl_ctys)
br_ctxs *= br_ctxs.gt(0.0).type_as(br_ctxs)
br_ctys *= br_ctys.gt(0.0).type_as(br_ctys)
ct_bboxes = torch.stack((tl_ctxs, tl_ctys, br_ctxs, br_ctys),
dim=3)
area_ct_bboxes = ((br_ctxs - tl_ctxs) * (br_ctys - tl_ctys)).abs()
rcentral = torch.zeros_like(ct_bboxes)
# magic nums from paper section 4.1
mu = torch.ones_like(area_bboxes) / 2.4
mu[area_bboxes > 3500] = 1 / 2.1 # large bboxes have a smaller mu
bboxes_center_x = (bboxes[..., 0] + bboxes[..., 2]) / 2
bboxes_center_y = (bboxes[..., 1] + bboxes[..., 3]) / 2
rcentral[..., 0] = bboxes_center_x - mu * (bboxes[..., 2] -
bboxes[..., 0]) / 2
rcentral[..., 1] = bboxes_center_y - mu * (bboxes[..., 3] -
bboxes[..., 1]) / 2
rcentral[..., 2] = bboxes_center_x + mu * (bboxes[..., 2] -
bboxes[..., 0]) / 2
rcentral[..., 3] = bboxes_center_y + mu * (bboxes[..., 3] -
bboxes[..., 1]) / 2
area_rcentral = ((rcentral[..., 2] - rcentral[..., 0]) *
(rcentral[..., 3] - rcentral[..., 1])).abs()
dists = area_ct_bboxes / area_rcentral
tl_ctx_inds = (ct_bboxes[..., 0] <= rcentral[..., 0]) | (
ct_bboxes[..., 0] >= rcentral[..., 2])
tl_cty_inds = (ct_bboxes[..., 1] <= rcentral[..., 1]) | (
ct_bboxes[..., 1] >= rcentral[..., 3])
br_ctx_inds = (ct_bboxes[..., 2] <= rcentral[..., 0]) | (
ct_bboxes[..., 2] >= rcentral[..., 2])
br_cty_inds = (ct_bboxes[..., 3] <= rcentral[..., 1]) | (
ct_bboxes[..., 3] >= rcentral[..., 3])
if with_embedding:
tl_emb = self._transpose_and_gather_feat(tl_emb, tl_inds)
tl_emb = tl_emb.view(batch, k, 1)
br_emb = self._transpose_and_gather_feat(br_emb, br_inds)
br_emb = br_emb.view(batch, 1, k)
dists = torch.abs(tl_emb - br_emb)
tl_scores = tl_scores.view(batch, k, 1).repeat(1, 1, k)
br_scores = br_scores.view(batch, 1, k).repeat(1, k, 1)
scores = (tl_scores + br_scores) / 2 # scores for all possible boxes
# tl and br should have same class
tl_clses = tl_clses.view(batch, k, 1).repeat(1, 1, k)
br_clses = br_clses.view(batch, 1, k).repeat(1, k, 1)
cls_inds = (tl_clses != br_clses)
# reject boxes based on distances
dist_inds = dists > distance_threshold
# reject boxes based on widths and heights
width_inds = (br_xs <= tl_xs)
height_inds = (br_ys <= tl_ys)
scores[cls_inds] = -1
scores[width_inds] = -1
scores[height_inds] = -1
scores[dist_inds] = -1
if with_centripetal_shift:
scores[tl_ctx_inds] = -1
scores[tl_cty_inds] = -1
scores[br_ctx_inds] = -1
scores[br_cty_inds] = -1
scores = scores.view(batch, -1)
scores, inds = torch.topk(scores, num_dets)
scores = scores.unsqueeze(2)
bboxes = bboxes.view(batch, -1, 4)
bboxes = self._gather_feat(bboxes, inds)
clses = tl_clses.contiguous().view(batch, -1, 1)
clses = self._gather_feat(clses, inds).float()
return bboxes, scores, clses
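# ---------------------------------------------------------------------
# Hedged usage sketch (not part of the original file): `decode_heatmap`
# pairs every top-left corner with every bottom-right corner by
# broadcasting k corners into a (batch, k, k) grid, then filters the
# pairs. A plain-torch illustration with a made-up k:
#
#   import torch
#   batch, k = 1, 3
#   tl_xs = torch.tensor([[1., 4., 7.]]).view(batch, k, 1).repeat(1, 1, k)
#   br_xs = torch.tensor([[2., 3., 9.]]).view(batch, 1, k).repeat(1, k, 1)
#   width_inds = br_xs <= tl_xs  # invalid: right edge left of left edge
#   # width_inds[0, i, j] is True when br_x[j] <= tl_x[i],
#   # e.g. the pair (tl=4, br=2) is rejected.
# ---------------------------------------------------------------------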