Commit 48ce34da authored by ChaimZhu's avatar ChaimZhu Committed by Tai-Wang

[Feature] add smoke backbone neck (#939)

* add SMOKE detector and its backbone and neck

* typo fix

* fix typo

* add docstring

* fix typo

* fix comments

* fix comments

* fix comments

* fix typo

* fix typo

* fix

* fix typo

* fix docstring

* refine feature

* fix typo

* use BaseModule in Neck
parent 7c27cd75
# Copyright (c) OpenMMLab. All rights reserved.
from mmdet.models.backbones import SSDVGG, HRNet, ResNet, ResNetV1d, ResNeXt
from .dgcnn import DGCNNBackbone
from .dla import DLANet
from .multi_backbone import MultiBackbone
from .nostem_regnet import NoStemRegNet
from .pointnet2_sa_msg import PointNet2SAMSG
@@ -10,5 +11,5 @@ from .second import SECOND
__all__ = [
'ResNet', 'ResNetV1d', 'ResNeXt', 'SSDVGG', 'HRNet', 'NoStemRegNet',
'SECOND', 'DGCNNBackbone', 'PointNet2SASSG', 'PointNet2SAMSG',
'MultiBackbone', 'DLANet'
]
import torch
import warnings
from mmcv.cnn import build_conv_layer, build_norm_layer
from mmcv.runner import BaseModule
from torch import nn
from mmdet.models.builder import BACKBONES
def dla_build_norm_layer(cfg, num_features):
"""Build normalization layer specially designed for DLANet.
Args:
cfg (dict): The norm layer config, which should contain:
- type (str): Layer type.
- layer args: Args needed to instantiate a norm layer.
- requires_grad (bool, optional): Whether to stop gradient updates.
num_features (int): Number of input channels.
Returns:
tuple[str, nn.Module]: The norm layer name and the created norm
layer, as returned by ``build_norm_layer`` in mmcv.
"""
cfg_ = cfg.copy()
if cfg_['type'] == 'GN':
if num_features % 32 == 0:
return build_norm_layer(cfg_, num_features)
else:
assert 'num_groups' in cfg_
cfg_['num_groups'] = cfg_['num_groups'] // 2
return build_norm_layer(cfg_, num_features)
else:
return build_norm_layer(cfg_, num_features)
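# A minimal sketch of the helper above (channel counts are illustrative):
# when the channel number is not divisible by 32, the GN group number is
# halved before the layer is built.
gn_cfg = dict(type='GN', num_groups=32)
_, gn64 = dla_build_norm_layer(gn_cfg, 64)  # 64 % 32 == 0 -> GroupNorm(32, 64)
_, gn16 = dla_build_norm_layer(gn_cfg, 16)  # 16 % 32 != 0 -> GroupNorm(16, 16)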
class BasicBlock(BaseModule):
"""BasicBlock in DLANet.
Args:
in_channels (int): Input feature channel.
out_channels (int): Output feature channel.
norm_cfg (dict): Dictionary to construct and config
norm layer.
conv_cfg (dict): Dictionary to construct and config
conv layer.
stride (int, optional): Conv stride. Default: 1.
dilation (int, optional): Conv dilation. Default: 1.
init_cfg (dict, optional): Initialization config.
Default: None.
"""
def __init__(self,
in_channels,
out_channels,
norm_cfg,
conv_cfg,
stride=1,
dilation=1,
init_cfg=None):
super(BasicBlock, self).__init__(init_cfg)
self.conv1 = build_conv_layer(
conv_cfg,
in_channels,
out_channels,
3,
stride=stride,
padding=dilation,
dilation=dilation,
bias=False)
self.norm1 = dla_build_norm_layer(norm_cfg, out_channels)[1]
self.relu = nn.ReLU(inplace=True)
self.conv2 = build_conv_layer(
conv_cfg,
out_channels,
out_channels,
3,
stride=1,
padding=dilation,
dilation=dilation,
bias=False)
self.norm2 = dla_build_norm_layer(norm_cfg, out_channels)[1]
self.stride = stride
def forward(self, x, identity=None):
"""Forward function."""
if identity is None:
identity = x
out = self.conv1(x)
out = self.norm1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.norm2(out)
out += identity
out = self.relu(out)
return out
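# Illustrative sanity check of the residual block above, assuming plain BN
# and the default Conv2d (values are arbitrary): with matching channels and
# stride 1, the input itself serves as the identity and the shape is kept.
block = BasicBlock(64, 64, dict(type='BN'), None)
out = block(torch.randn(1, 64, 16, 16))  # out.shape == (1, 64, 16, 16)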
class Root(BaseModule):
"""Root in DLANet.
Args:
in_channels (int): Input feature channel.
out_channels (int): Output feature channel.
norm_cfg (dict): Dictionary to construct and config
norm layer.
conv_cfg (dict): Dictionary to construct and config
conv layer.
kernel_size (int): Size of convolution kernel.
add_identity (bool): Whether to add identity in root.
init_cfg (dict, optional): Initialization config.
Default: None.
"""
def __init__(self,
in_channels,
out_channels,
norm_cfg,
conv_cfg,
kernel_size,
add_identity,
init_cfg=None):
super(Root, self).__init__(init_cfg)
self.conv = build_conv_layer(
conv_cfg,
in_channels,
out_channels,
1,
stride=1,
padding=(kernel_size - 1) // 2,
bias=False)
self.norm = dla_build_norm_layer(norm_cfg, out_channels)[1]
self.relu = nn.ReLU(inplace=True)
self.add_identity = add_identity
def forward(self, feat_list):
"""Forward function.
Args:
feat_list (list[torch.Tensor]): Output features from
multiple layers.
"""
children = feat_list
x = self.conv(torch.cat(feat_list, 1))
x = self.norm(x)
if self.add_identity:
x += children[0]
x = self.relu(x)
return x
class Tree(BaseModule):
"""Tree in DLANet.
Args:
levels (int): The level of the tree.
block (nn.Module): The block module in tree.
in_channels (int): Input feature channel.
out_channels (int): Output feature channel.
norm_cfg (dict): Dictionary to construct and config
norm layer.
conv_cfg (dict): Dictionary to construct and config
conv layer.
stride (int, optional): Convolution stride.
Default: 1.
level_root (bool, optional): Whether this tree belongs to the
root layer. Default: False.
root_dim (int, optional): Root input feature channel. If None,
it defaults to 2 * out_channels; in_channels is further added
when level_root is True. Default: None.
root_kernel_size (int, optional): Size of root
convolution kernel. Default: 1.
dilation (int, optional): Conv dilation. Default: 1.
add_identity (bool, optional): Whether to add
identity in root. Default: False.
init_cfg (dict, optional): Initialization config.
Default: None.
"""
def __init__(self,
levels,
block,
in_channels,
out_channels,
norm_cfg,
conv_cfg,
stride=1,
level_root=False,
root_dim=None,
root_kernel_size=1,
dilation=1,
add_identity=False,
init_cfg=None):
super(Tree, self).__init__(init_cfg)
if root_dim is None:
root_dim = 2 * out_channels
if level_root:
root_dim += in_channels
if levels == 1:
self.root = Root(root_dim, out_channels, norm_cfg, conv_cfg,
root_kernel_size, add_identity)
self.tree1 = block(
in_channels,
out_channels,
norm_cfg,
conv_cfg,
stride,
dilation=dilation)
self.tree2 = block(
out_channels,
out_channels,
norm_cfg,
conv_cfg,
1,
dilation=dilation)
else:
self.tree1 = Tree(
levels - 1,
block,
in_channels,
out_channels,
norm_cfg,
conv_cfg,
stride,
root_dim=None,
root_kernel_size=root_kernel_size,
dilation=dilation,
add_identity=add_identity)
self.tree2 = Tree(
levels - 1,
block,
out_channels,
out_channels,
norm_cfg,
conv_cfg,
root_dim=root_dim + out_channels,
root_kernel_size=root_kernel_size,
dilation=dilation,
add_identity=add_identity)
self.level_root = level_root
self.root_dim = root_dim
self.downsample = None
self.project = None
self.levels = levels
if stride > 1:
self.downsample = nn.MaxPool2d(stride, stride=stride)
if in_channels != out_channels:
self.project = nn.Sequential(
build_conv_layer(
conv_cfg,
in_channels,
out_channels,
1,
stride=1,
bias=False),
dla_build_norm_layer(norm_cfg, out_channels)[1])
def forward(self, x, identity=None, children=None):
children = [] if children is None else children
bottom = self.downsample(x) if self.downsample else x
identity = self.project(bottom) if self.project else bottom
if self.level_root:
children.append(bottom)
x1 = self.tree1(x, identity)
if self.levels == 1:
x2 = self.tree2(x1)
feat_list = [x2, x1] + children
x = self.root(feat_list)
else:
children.append(x1)
x = self.tree2(x1, children=children)
return x
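# Illustrative level-1 tree, roughly how the stride-2 stages of DLA-34 are
# assembled (channel numbers here are examples only): two BasicBlocks plus a
# Root that fuses their outputs.
tree = Tree(
    levels=1,
    block=BasicBlock,
    in_channels=64,
    out_channels=128,
    norm_cfg=dict(type='BN'),
    conv_cfg=None,
    stride=2)
out = tree(torch.randn(1, 64, 32, 32))  # aggregated output: (1, 128, 16, 16)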
@BACKBONES.register_module()
class DLANet(BaseModule):
r"""`DLA backbone <https://arxiv.org/abs/1707.06484>`_.
Args:
depth (int): Depth of DLA. Default: 34.
in_channels (int, optional): Number of input image channels.
Default: 3.
out_indices (tuple[int], optional): Indices of the output feature
levels. Default: (0, 1, 2, 3, 4, 5).
frozen_stages (int, optional): Stages to be frozen (stop gradient
updates and set to eval mode). -1 means no stage is frozen.
Default: -1.
norm_cfg (dict, optional): Dictionary to construct and config
norm layer. Default: None.
conv_cfg (dict, optional): Dictionary to construct and config
conv layer. Default: None.
layer_with_level_root (list[bool], optional): Whether to apply
level_root in each DLA layer; this is only used for
tree levels. Default: (False, True, True, True).
with_identity_root (bool, optional): Whether to add identity
in root layer. Default: False.
pretrained (str, optional): Model pretrained path.
Default: None.
init_cfg (dict or list[dict], optional): Initialization
config dict. Default: None.
"""
arch_settings = {
34: (BasicBlock, (1, 1, 1, 2, 2, 1), (16, 32, 64, 128, 256, 512)),
}
def __init__(self,
depth,
in_channels=3,
out_indices=(0, 1, 2, 3, 4, 5),
frozen_stages=-1,
norm_cfg=None,
conv_cfg=None,
layer_with_level_root=(False, True, True, True),
with_identity_root=False,
pretrained=None,
init_cfg=None):
super(DLANet, self).__init__(init_cfg)
if depth not in self.arch_settings:
raise KeyError(f'invalid depth {depth} for DLA')
assert not (init_cfg and pretrained), \
'init_cfg and pretrained cannot be specified at the same time'
if isinstance(pretrained, str):
warnings.warn('DeprecationWarning: pretrained is deprecated, '
'please use "init_cfg" instead')
self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
elif pretrained is None:
if init_cfg is None:
self.init_cfg = [
dict(type='Kaiming', layer='Conv2d'),
dict(
type='Constant',
val=1,
layer=['_BatchNorm', 'GroupNorm'])
]
block, levels, channels = self.arch_settings[depth]
self.channels = channels
self.num_levels = len(levels)
self.frozen_stages = frozen_stages
self.out_indices = out_indices
assert max(out_indices) < self.num_levels
self.base_layer = nn.Sequential(
build_conv_layer(
conv_cfg,
in_channels,
channels[0],
7,
stride=1,
padding=3,
bias=False),
dla_build_norm_layer(norm_cfg, channels[0])[1],
nn.ReLU(inplace=True))
# The first two levels of DLANet are plain conv levels built by
# _make_conv_level, while the remaining levels are Tree modules.
for i in range(2):
level_layer = self._make_conv_level(
channels[0],
channels[i],
levels[i],
norm_cfg,
conv_cfg,
stride=i + 1)
layer_name = f'level{i}'
self.add_module(layer_name, level_layer)
for i in range(2, self.num_levels):
dla_layer = Tree(
levels[i],
block,
channels[i - 1],
channels[i],
norm_cfg,
conv_cfg,
2,
level_root=layer_with_level_root[i - 2],
add_identity=with_identity_root)
layer_name = f'level{i}'
self.add_module(layer_name, dla_layer)
self._freeze_stages()
def _make_conv_level(self,
in_channels,
out_channels,
num_convs,
norm_cfg,
conv_cfg,
stride=1,
dilation=1):
"""Conv modules.
Args:
in_channels (int): Input feature channel.
out_channels (int): Output feature channel.
num_convs (int): Number of Conv module.
norm_cfg (dict): Dictionary to construct and config
norm layer.
conv_cfg (dict): Dictionary to construct and config
conv layer.
stride (int, optional): Conv stride. Default: 1.
dilation (int, optional): Conv dilation. Default: 1.
"""
modules = []
for i in range(num_convs):
modules.extend([
build_conv_layer(
conv_cfg,
in_channels,
out_channels,
3,
stride=stride if i == 0 else 1,
padding=dilation,
bias=False,
dilation=dilation),
dla_build_norm_layer(norm_cfg, out_channels)[1],
nn.ReLU(inplace=True)
])
in_channels = out_channels
return nn.Sequential(*modules)
def _freeze_stages(self):
if self.frozen_stages >= 0:
self.base_layer.eval()
for param in self.base_layer.parameters():
param.requires_grad = False
for i in range(2):
m = getattr(self, f'level{i}')
m.eval()
for param in m.parameters():
param.requires_grad = False
for i in range(1, self.frozen_stages + 1):
m = getattr(self, f'level{i+1}')
m.eval()
for param in m.parameters():
param.requires_grad = False
def forward(self, x):
outs = []
x = self.base_layer(x)
for i in range(self.num_levels):
x = getattr(self, 'level{}'.format(i))(x)
if i in self.out_indices:
outs.append(x)
return tuple(outs)
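# Rough usage sketch of the backbone above, instantiated directly rather than
# through the registry (BN is assumed here; the unit tests below use GN).
backbone = DLANet(depth=34, in_channels=3, norm_cfg=dict(type='BN'))
backbone.init_weights()
feats = backbone(torch.randn(1, 3, 64, 64))
# Six levels with strides 1, 2, 4, 8, 16, 32 and channels 16, 32, 64, 128,
# 256, 512, i.e. shapes (1, 16, 64, 64), (1, 32, 32, 32), ..., (1, 512, 2, 2).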
# Copyright (c) OpenMMLab. All rights reserved.
from mmdet.models.necks.fpn import FPN
from .dla_neck import DLANeck
from .imvoxel_neck import OutdoorImVoxelNeck
from .second_fpn import SECONDFPN
__all__ = ['FPN', 'SECONDFPN', 'OutdoorImVoxelNeck', 'DLANeck']
import math
import numpy as np
from mmcv.cnn import ConvModule, build_conv_layer
from mmcv.runner import BaseModule
from torch import nn as nn
from mmdet.models.builder import NECKS
def fill_up_weights(up):
"""Simulated bilinear upsampling kernel.
Args:
up (nn.Module): ConvTranspose2d module.
"""
w = up.weight.data
f = math.ceil(w.size(2) / 2)
c = (2 * f - 1 - f % 2) / (2. * f)
for i in range(w.size(2)):
for j in range(w.size(3)):
w[0, 0, i, j] = \
(1 - math.fabs(i / f - c)) * (1 - math.fabs(j / f - c))
for c in range(1, w.size(0)):
w[c, 0, :, :] = w[0, 0, :, :]
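# Small sanity sketch of fill_up_weights (kernel size and channel count are
# arbitrary): the 4x4 kernel becomes the outer product of
# [0.25, 0.75, 0.75, 0.25] with itself, and all channels share identical
# weights, which amounts to depthwise bilinear upsampling.
up = nn.ConvTranspose2d(16, 16, 4, stride=2, padding=1, groups=16, bias=False)
fill_up_weights(up)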
class IDAUpsample(BaseModule):
"""Iterative Deep Aggregation (IDA) Upsampling module to upsample features
of different scales to a similar scale.
Args:
out_channels (int): Number of output channels of this IDA module;
the projection and node convs use DCNv2 when ``use_dcn`` is True.
in_channels (List[int]): List of input channels of multi-scale
feature maps.
kernel_sizes (List[int]): List of size of the convolving
kernel of different scales.
norm_cfg (dict, optional): Config dict for normalization layer.
Default: None.
use_dcn (bool, optional): If True, use DCNv2. Default: True.
"""
def __init__(
self,
out_channels,
in_channels,
kernel_sizes,
norm_cfg=None,
use_dcn=True,
init_cfg=None,
):
super(IDAUpsample, self).__init__(init_cfg)
self.use_dcn = use_dcn
self.projs = nn.ModuleList()
self.ups = nn.ModuleList()
self.nodes = nn.ModuleList()
for i in range(1, len(in_channels)):
in_channel = in_channels[i]
up_kernel_size = int(kernel_sizes[i])
proj = ConvModule(
in_channel,
out_channels,
3,
padding=1,
bias=True,
conv_cfg=dict(type='DCNv2') if self.use_dcn else None,
norm_cfg=norm_cfg)
node = ConvModule(
out_channels,
out_channels,
3,
padding=1,
bias=True,
conv_cfg=dict(type='DCNv2') if self.use_dcn else None,
norm_cfg=norm_cfg)
up = build_conv_layer(
dict(type='deconv'),
out_channels,
out_channels,
up_kernel_size * 2,
stride=up_kernel_size,
padding=up_kernel_size // 2,
output_padding=0,
groups=out_channels,
bias=False)
self.projs.append(proj)
self.ups.append(up)
self.nodes.append(node)
def forward(self, mlvl_features, start_level, end_level):
"""Forward function.
Args:
mlvl_features (list[torch.Tensor]): Features from multiple layers.
start_level (int): Start layer for feature upsampling.
end_level (int): End layer for feature upsampling.
"""
for i in range(start_level, end_level - 1):
upsample = self.ups[i - start_level]
project = self.projs[i - start_level]
mlvl_features[i + 1] = upsample(project(mlvl_features[i + 1]))
node = self.nodes[i - start_level]
mlvl_features[i + 1] = node(mlvl_features[i + 1] +
mlvl_features[i])
class DLAUpsample(BaseModule):
"""Deep Layer Aggregation (DLA) Upsampling module for different scales
feature extraction, upsampling and fusion, It consists of groups of
IDAupsample modules.
Args:
start_level (int): The start layer.
channels (List[int]): List of input channels of multi-scale
feature maps.
scales (list[int]): List of scales of different layers' features.
in_channels (list[int], optional): List of input channels of
different scales. If None, ``channels`` is used. Default: None.
norm_cfg (dict, optional): Config dict for normalization layer.
Default: None.
use_dcn (bool, optional): Whether to use dcn in IDAup module.
Default: True.
"""
def __init__(self,
start_level,
channels,
scales,
in_channels=None,
norm_cfg=None,
use_dcn=True,
init_cfg=None):
super(DLAUpsample, self).__init__(init_cfg)
self.start_level = start_level
if in_channels is None:
in_channels = channels
self.channels = channels
channels = list(channels)
scales = np.array(scales, dtype=int)
for i in range(len(channels) - 1):
j = -i - 2
setattr(
self, 'ida_{}'.format(i),
IDAUpsample(channels[j], in_channels[j:],
scales[j:] // scales[j], norm_cfg, use_dcn))
scales[j + 1:] = scales[j]
in_channels[j + 1:] = [channels[j] for _ in channels[j + 1:]]
def forward(self, mlvl_features):
"""Forward function.
Args:
mlvl_features (list[torch.Tensor]): Features from multi-scale
layers.
Returns:
list[torch.Tensor]: Upsampled features of different layers.
"""
outs = [mlvl_features[-1]]
for i in range(len(mlvl_features) - self.start_level - 1):
ida = getattr(self, 'ida_{}'.format(i))
ida(mlvl_features, len(mlvl_features) - i - 2, len(mlvl_features))
outs.insert(0, mlvl_features[-1])
return outs
@NECKS.register_module()
class DLANeck(BaseModule):
"""DLA Neck.
Args:
in_channels (list[int], optional): List of input channels of
multi-scale feature maps.
Default: [16, 32, 64, 128, 256, 512].
start_level (int, optional): The scale level where upsampling
starts. Default: 2.
end_level (int, optional): The scale level where upsampling
ends. Default: 5.
norm_cfg (dict, optional): Config dict for normalization
layer. Default: None.
use_dcn (bool, optional): Whether to use dcn in IDAup module.
Default: True.
"""
def __init__(self,
in_channels=[16, 32, 64, 128, 256, 512],
start_level=2,
end_level=5,
norm_cfg=None,
use_dcn=True,
init_cfg=None):
super(DLANeck, self).__init__(init_cfg)
self.start_level = start_level
self.end_level = end_level
scales = [2**i for i in range(len(in_channels[self.start_level:]))]
self.dla_up = DLAUpsample(
start_level=self.start_level,
channels=in_channels[self.start_level:],
scales=scales,
norm_cfg=norm_cfg,
use_dcn=use_dcn)
self.ida_up = IDAUpsample(
in_channels[self.start_level],
in_channels[self.start_level:self.end_level],
[2**i for i in range(self.end_level - self.start_level)], norm_cfg,
use_dcn)
def forward(self, x):
mlvl_features = [x[i] for i in range(len(x))]
mlvl_features = self.dla_up(mlvl_features)
outs = []
for i in range(self.end_level - self.start_level):
outs.append(mlvl_features[i].clone())
self.ida_up(outs, 0, len(outs))
return outs[-1]
def init_weights(self):
for m in self.modules():
if isinstance(m, nn.ConvTranspose2d):
# In order to be consistent with the source code,
# reset the ConvTranspose2d initialization parameters
m.reset_parameters()
# Simulated bilinear upsampling kernel
fill_up_weights(m)
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Conv2d):
# In order to be consistent with the source code,
# reset the Conv2d initialization parameters
m.reset_parameters()
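# Illustrative end-to-end sketch of how the new backbone and neck fit together
# (builder imports and shapes follow the unit tests below; a SMOKE config
# would wire them up in a similar way). use_dcn=False keeps it CPU-only.
import torch
from mmdet3d.models import build_backbone, build_neck
backbone = build_backbone(
    dict(type='DLANet', depth=34, norm_cfg=dict(type='GN', num_groups=32)))
neck = build_neck(
    dict(
        type='DLANeck',
        in_channels=[16, 32, 64, 128, 256, 512],
        start_level=2,
        end_level=5,
        norm_cfg=dict(type='GN', num_groups=32),
        use_dcn=False))
backbone.init_weights()
neck.init_weights()
feats = backbone(torch.randn(4, 3, 32, 32))  # six levels, strides 1 to 32
out = neck(feats)  # single feature map of shape (4, 64, 8, 8)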
@@ -330,3 +330,26 @@ def test_dgcnn_gf():
assert gf_points[2].shape == torch.Size([1, 100, 64])
assert gf_points[3].shape == torch.Size([1, 100, 64])
assert fa_points.shape == torch.Size([1, 100, 1216])
def test_dla_net():
# test DLANet used in SMOKE
# test a DLA-34 config with GN norm layers
cfg = dict(
type='DLANet',
depth=34,
in_channels=3,
norm_cfg=dict(type='GN', num_groups=32))
img = torch.randn((4, 3, 32, 32))
self = build_backbone(cfg)
self.init_weights()
results = self(img)
assert len(results) == 6
assert results[0].shape == torch.Size([4, 16, 32, 32])
assert results[1].shape == torch.Size([4, 32, 16, 16])
assert results[2].shape == torch.Size([4, 64, 8, 8])
assert results[3].shape == torch.Size([4, 128, 4, 4])
assert results[4].shape == torch.Size([4, 256, 2, 2])
assert results[5].shape == torch.Size([4, 512, 1, 1])
@@ -57,3 +57,45 @@ def test_imvoxel_neck():
inputs = torch.rand([1, 64, 216, 248, 12], device='cuda')
outputs = neck(inputs)
assert outputs[0].shape == (1, 256, 248, 216)
def test_dla_neck():
s = 32
in_channels = [16, 32, 64, 128, 256, 512]
feat_sizes = [s // 2**i for i in range(6)] # [32, 16, 8, 4, 2, 1]
if torch.cuda.is_available():
# Test DLA Neck with DCNv2 on GPU
neck_cfg = dict(
type='DLANeck',
in_channels=[16, 32, 64, 128, 256, 512],
start_level=2,
end_level=5,
norm_cfg=dict(type='GN', num_groups=32))
neck = build_neck(neck_cfg)
neck.init_weights()
neck.cuda()
feats = [
torch.rand(4, in_channels[i], feat_sizes[i], feat_sizes[i]).cuda()
for i in range(len(in_channels))
]
outputs = neck(feats)
assert outputs.shape == (4, 64, 8, 8)
else:
# Test DLA Neck without DCNv2 on CPU
neck_cfg = dict(
type='DLANeck',
in_channels=[16, 32, 64, 128, 256, 512],
start_level=2,
end_level=5,
norm_cfg=dict(type='GN', num_groups=32),
use_dcn=False)
neck = build_neck(neck_cfg)
neck.init_weights()
feats = [
torch.rand(4, in_channels[i], feat_sizes[i], feat_sizes[i])
for i in range(len(in_channels))
]
outputs = neck(feats)
assert outputs.shape == (4, 64, 8, 8)