Commit b2b69dc6 authored by zhangwenwei's avatar zhangwenwei
Browse files

Refactor SECOND Backbone

parent 868c5fab
from functools import partial
import torch.nn as nn import torch.nn as nn
from mmcv.cnn import build_norm_layer from mmcv.cnn import build_conv_layer, build_norm_layer
from mmcv.runner import load_checkpoint from mmcv.runner import load_checkpoint
from mmdet.models import BACKBONES from mmdet.models import BACKBONES
class Empty(nn.Module):
def __init__(self, *args, **kwargs):
super(Empty, self).__init__()
def forward(self, *args, **kwargs):
if len(args) == 1:
return args[0]
elif len(args) == 0:
return None
return args
@BACKBONES.register_module() @BACKBONES.register_module()
class SECOND(nn.Module): class SECOND(nn.Module):
"""Compare with RPN, RPNV2 support arbitrary number of stage. """Backbone network for SECOND/PointPillars/MVXNet
Args:
in_channels (int): Input channels
out_channels (list[int]): Output channels for multi-scale feature maps
layer_nums (list[int]): Number of layers in each stage
layer_strides (list[int]): Strides of each stage
norm_cfg (dict): Config dict of normalization layers
conv_cfg (dict): Config dict of convolutional layers
""" """
def __init__(self, def __init__(self,
in_channels=128, in_channels=128,
out_channels=[128, 128, 256],
layer_nums=[3, 5, 5], layer_nums=[3, 5, 5],
layer_strides=[2, 2, 2], layer_strides=[2, 2, 2],
num_filters=[128, 128, 256], norm_cfg=dict(type='BN', eps=1e-3, momentum=0.01),
norm_cfg=dict(type='BN', eps=1e-3, momentum=0.01)): conv_cfg=dict(type='Conv2d', bias=False)):
super(SECOND, self).__init__() super(SECOND, self).__init__()
assert len(layer_strides) == len(layer_nums) assert len(layer_strides) == len(layer_nums)
assert len(num_filters) == len(layer_nums) assert len(out_channels) == len(layer_nums)
if norm_cfg is not None:
Conv2d = partial(nn.Conv2d, bias=False)
else:
Conv2d = partial(nn.Conv2d, bias=True)
in_filters = [in_channels, *num_filters[:-1]] in_filters = [in_channels, *out_channels[:-1]]
# note that when stride > 1, conv2d with same padding isn't # note that when stride > 1, conv2d with same padding isn't
# equal to pad-conv2d. we should use pad-conv2d. # equal to pad-conv2d. we should use pad-conv2d.
blocks = [] blocks = []
for i, layer_num in enumerate(layer_nums): for i, layer_num in enumerate(layer_nums):
norm_layer = (
build_norm_layer(norm_cfg, num_filters[i])[1]
if norm_cfg is not None else Empty)
block = [ block = [
nn.ZeroPad2d(1), build_conv_layer(
Conv2d( conv_cfg,
in_filters[i], num_filters[i], 3, stride=layer_strides[i]), in_filters[i],
norm_layer, out_channels[i],
3,
stride=layer_strides[i],
padding=1),
build_norm_layer(norm_cfg, out_channels[i])[1],
nn.ReLU(inplace=True), nn.ReLU(inplace=True),
] ]
for j in range(layer_num): for j in range(layer_num):
norm_layer = (
build_norm_layer(norm_cfg, num_filters[i])[1]
if norm_cfg is not None else Empty)
block.append( block.append(
Conv2d(num_filters[i], num_filters[i], 3, padding=1)) build_conv_layer(
block.append(norm_layer) conv_cfg,
out_channels[i],
out_channels[i],
3,
padding=1))
block.append(build_norm_layer(norm_cfg, out_channels[i])[1])
block.append(nn.ReLU(inplace=True)) block.append(nn.ReLU(inplace=True))
block = nn.Sequential(*block) block = nn.Sequential(*block)
...@@ -71,6 +62,8 @@ class SECOND(nn.Module): ...@@ -71,6 +62,8 @@ class SECOND(nn.Module):
self.blocks = nn.ModuleList(blocks) self.blocks = nn.ModuleList(blocks)
def init_weights(self, pretrained=None): def init_weights(self, pretrained=None):
# Do not initialize the conv layers
# to follow the original implementation
if isinstance(pretrained, str): if isinstance(pretrained, str):
from mmdet3d.utils import get_root_logger from mmdet3d.utils import get_root_logger
logger = get_root_logger() logger = get_root_logger()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment