Commit 323bad26 authored by zhangwenwei's avatar zhangwenwei
Browse files

Merge branch 'refacor-sec-backbone' into 'master'

Refactor SECOND Backbone

See merge request open-mmlab/mmdet.3d!26
parents 868c5fab 945edf21
......@@ -16,7 +16,7 @@ before_script:
.linting_template: &linting_template_def
stage: linting
script:
- pip install flake8 yapf isort
- pip install flake8==3.7.9 yapf isort
- flake8 .
- isort -rc --check-only --diff mmdet3d/ tools/ tests/
- yapf -r -d mmdet3d/ tools/ tests/ configs/
......
......@@ -57,7 +57,7 @@ model = dict(
in_channels=256,
layer_nums=[5, 5],
layer_strides=[1, 2],
num_filters=[128, 256],
out_channels=[128, 256],
),
pts_neck=dict(
type='SECONDFPN',
......
......@@ -28,7 +28,7 @@ model = dict(
in_channels=64,
layer_nums=[3, 5, 5],
layer_strides=[2, 2, 2],
num_filters=[64, 128, 256],
out_channels=[64, 128, 256],
),
neck=dict(
type='SECONDFPN',
......
......@@ -26,7 +26,7 @@ model = dict(
in_channels=256,
layer_nums=[5, 5],
layer_strides=[1, 2],
num_filters=[128, 256],
out_channels=[128, 256],
),
neck=dict(
type='SECONDFPN',
......
......@@ -26,7 +26,7 @@ model = dict(
in_channels=256,
layer_nums=[5, 5],
layer_strides=[1, 2],
num_filters=[128, 256],
out_channels=[128, 256],
),
neck=dict(
type='SECONDFPN',
......
......@@ -22,7 +22,7 @@ model = dict(
in_channels=256,
layer_nums=[5, 5],
layer_strides=[1, 2],
num_filters=[128, 256]),
out_channels=[128, 256]),
neck=dict(
type='SECONDFPN',
in_channels=[128, 256],
......
......@@ -27,7 +27,7 @@ model = dict(
in_channels=64,
layer_nums=[3, 5, 5],
layer_strides=[2, 2, 2],
num_filters=[64, 128, 256],
out_channels=[64, 128, 256],
),
neck=dict(
type='SECONDFPN',
......
......@@ -26,7 +26,7 @@ model = dict(
in_channels=256,
layer_nums=[5, 5],
layer_strides=[1, 2],
num_filters=[128, 256],
out_channels=[128, 256],
),
neck=dict(
type='SECONDFPN',
......
......@@ -34,7 +34,7 @@ model = dict(
norm_cfg=dict(type='naiveSyncBN2d', eps=1e-3, momentum=0.01),
layer_nums=[3, 5, 5],
layer_strides=[2, 2, 2],
num_filters=[64, 128, 256],
out_channels=[64, 128, 256],
),
pts_neck=dict(
type='SECONDFPN',
......
from functools import partial
import torch.nn as nn
from mmcv.cnn import build_norm_layer
from mmcv.cnn import build_conv_layer, build_norm_layer
from mmcv.runner import load_checkpoint
from mmdet.models import BACKBONES
class Empty(nn.Module):
def __init__(self, *args, **kwargs):
super(Empty, self).__init__()
def forward(self, *args, **kwargs):
if len(args) == 1:
return args[0]
elif len(args) == 0:
return None
return args
@BACKBONES.register_module()
class SECOND(nn.Module):
"""Compare with RPN, RPNV2 support arbitrary number of stage.
"""Backbone network for SECOND/PointPillars/PartA2/MVXNet
Args:
in_channels (int): Input channels
out_channels (list[int]): Output channels for multi-scale feature maps
layer_nums (list[int]): Number of layers in each stage
layer_strides (list[int]): Strides of each stage
norm_cfg (dict): Config dict of normalization layers
conv_cfg (dict): Config dict of convolutional layers
"""
def __init__(self,
in_channels=128,
out_channels=[128, 128, 256],
layer_nums=[3, 5, 5],
layer_strides=[2, 2, 2],
num_filters=[128, 128, 256],
norm_cfg=dict(type='BN', eps=1e-3, momentum=0.01)):
norm_cfg=dict(type='BN', eps=1e-3, momentum=0.01),
conv_cfg=dict(type='Conv2d', bias=False)):
super(SECOND, self).__init__()
assert len(layer_strides) == len(layer_nums)
assert len(num_filters) == len(layer_nums)
if norm_cfg is not None:
Conv2d = partial(nn.Conv2d, bias=False)
else:
Conv2d = partial(nn.Conv2d, bias=True)
assert len(out_channels) == len(layer_nums)
in_filters = [in_channels, *num_filters[:-1]]
in_filters = [in_channels, *out_channels[:-1]]
# note that when stride > 1, conv2d with same padding isn't
# equal to pad-conv2d. we should use pad-conv2d.
blocks = []
for i, layer_num in enumerate(layer_nums):
norm_layer = (
build_norm_layer(norm_cfg, num_filters[i])[1]
if norm_cfg is not None else Empty)
block = [
nn.ZeroPad2d(1),
Conv2d(
in_filters[i], num_filters[i], 3, stride=layer_strides[i]),
norm_layer,
build_conv_layer(
conv_cfg,
in_filters[i],
out_channels[i],
3,
stride=layer_strides[i],
padding=1),
build_norm_layer(norm_cfg, out_channels[i])[1],
nn.ReLU(inplace=True),
]
for j in range(layer_num):
norm_layer = (
build_norm_layer(norm_cfg, num_filters[i])[1]
if norm_cfg is not None else Empty)
block.append(
Conv2d(num_filters[i], num_filters[i], 3, padding=1))
block.append(norm_layer)
build_conv_layer(
conv_cfg,
out_channels[i],
out_channels[i],
3,
padding=1))
block.append(build_norm_layer(norm_cfg, out_channels[i])[1])
block.append(nn.ReLU(inplace=True))
block = nn.Sequential(*block)
......@@ -71,6 +62,8 @@ class SECOND(nn.Module):
self.blocks = nn.ModuleList(blocks)
def init_weights(self, pretrained=None):
# Do not initialize the conv layers
# to follow the original implementation
if isinstance(pretrained, str):
from mmdet3d.utils import get_root_logger
logger = get_root_logger()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment