Commit 3337fa69 authored by zhangwenwei's avatar zhangwenwei
Browse files

Refactor SECOND FPN

parent 868c5fab
from functools import partial import copy
import torch import torch
import torch.nn as nn import torch.nn as nn
from mmcv.cnn import build_norm_layer, constant_init, kaiming_init from mmcv.cnn import (build_norm_layer, build_upsample_layer, constant_init,
from torch.nn import Sequential is_norm, kaiming_init)
from torch.nn.modules.batchnorm import _BatchNorm
from mmdet.models import NECKS from mmdet.models import NECKS
from .. import builder from .. import builder
...@@ -12,33 +11,41 @@ from .. import builder ...@@ -12,33 +11,41 @@ from .. import builder
@NECKS.register_module() @NECKS.register_module()
class SECONDFPN(nn.Module): class SECONDFPN(nn.Module):
"""Compare with RPN, RPNV2 support arbitrary number of stage. """FPN used in SECOND/PointPillars
Args:
in_channels (list[int]): Input channels of multi-scale feature maps
out_channels (list[int]): Output channels of feature maps
upsample_strides (list[int]): Strides used to upsample the feature maps
norm_cfg (dict): Config dict of normalization layers
upsample_cfg (dict): Config dict of upsample layers
""" """
def __init__(self, def __init__(self,
use_norm=True,
in_channels=[128, 128, 256], in_channels=[128, 128, 256],
out_channels=[256, 256, 256],
upsample_strides=[1, 2, 4], upsample_strides=[1, 2, 4],
num_upsample_filters=[256, 256, 256], norm_cfg=dict(type='BN', eps=1e-3, momentum=0.01),
norm_cfg=dict(type='BN', eps=1e-3, momentum=0.01)): upsample_cfg=dict(type='deconv', bias=False)):
# if for GroupNorm, # if for GroupNorm,
# cfg is dict(type='GN', num_groups=num_groups, eps=1e-3, affine=True) # cfg is dict(type='GN', num_groups=num_groups, eps=1e-3, affine=True)
super(SECONDFPN, self).__init__() super(SECONDFPN, self).__init__()
assert len(num_upsample_filters) == len(upsample_strides) assert len(out_channels) == len(upsample_strides) == len(in_channels)
self.in_channels = in_channels self.in_channels = in_channels
self.out_channels = out_channels
ConvTranspose2d = partial(nn.ConvTranspose2d, bias=False)
deblocks = [] deblocks = []
for i, out_channel in enumerate(out_channels):
for i, num_upsample_filter in enumerate(num_upsample_filters): norm_layer = build_norm_layer(norm_cfg, out_channel)[1]
norm_layer = build_norm_layer(norm_cfg, num_upsample_filter)[1] upsample_cfg_ = copy.deepcopy(upsample_cfg)
deblock = Sequential( upsample_cfg_.update(
ConvTranspose2d( in_channels=in_channels[i],
in_channels[i], out_channels=out_channel,
num_upsample_filter, padding=upsample_strides[i],
upsample_strides[i], stride=upsample_strides[i])
stride=upsample_strides[i]), upsample_layer = build_upsample_layer(upsample_cfg_)
deblock = nn.Sequential(
upsample_layer,
norm_layer, norm_layer,
nn.ReLU(inplace=True), nn.ReLU(inplace=True),
) )
...@@ -49,7 +56,7 @@ class SECONDFPN(nn.Module): ...@@ -49,7 +56,7 @@ class SECONDFPN(nn.Module):
for m in self.modules(): for m in self.modules():
if isinstance(m, nn.Conv2d): if isinstance(m, nn.Conv2d):
kaiming_init(m) kaiming_init(m)
elif isinstance(m, (_BatchNorm, nn.GroupNorm)): elif is_norm(m):
constant_init(m, 1) constant_init(m, 1)
def forward(self, x): def forward(self, x):
...@@ -65,30 +72,34 @@ class SECONDFPN(nn.Module): ...@@ -65,30 +72,34 @@ class SECONDFPN(nn.Module):
@NECKS.register_module() @NECKS.register_module()
class SECONDFusionFPN(SECONDFPN): class SECONDFusionFPN(SECONDFPN):
"""Compare with RPN, RPNV2 support arbitrary number of stage. """FPN used in multi-modality SECOND/PointPillars
Args:
in_channels (list[int]): Input channels of multi-scale feature maps
out_channels (list[int]): Output channels of feature maps
upsample_strides (list[int]): Strides used to upsample the feature maps
norm_cfg (dict): Config dict of normalization layers
upsample_cfg (dict): Config dict of upsample layers
downsample_rates (list[int]): The downsample rate of feature map in
comparison to the original voxelization input
fusion_layer (dict): Config dict of fusion layers
""" """
def __init__(self, def __init__(self,
use_norm=True,
in_channels=[128, 128, 256], in_channels=[128, 128, 256],
out_channels=[256, 256, 256],
upsample_strides=[1, 2, 4], upsample_strides=[1, 2, 4],
num_upsample_filters=[256, 256, 256],
norm_cfg=dict(type='BN', eps=1e-3, momentum=0.01), norm_cfg=dict(type='BN', eps=1e-3, momentum=0.01),
down_sample_rate=[40, 8, 8], upsample_cfg=dict(type='deconv', bias=False),
fusion_layer=None, downsample_rates=[40, 8, 8],
cat_points=False): fusion_layer=None):
super(SECONDFusionFPN, self).__init__( super(SECONDFusionFPN,
use_norm, self).__init__(in_channels, out_channels, upsample_strides,
in_channels, norm_cfg, upsample_cfg)
upsample_strides,
num_upsample_filters,
norm_cfg,
)
self.fusion_layer = None self.fusion_layer = None
if fusion_layer is not None: if fusion_layer is not None:
self.fusion_layer = builder.build_fusion_layer(fusion_layer) self.fusion_layer = builder.build_fusion_layer(fusion_layer)
self.cat_points = cat_points self.downsample_rates = downsample_rates
self.down_sample_rate = down_sample_rate
def forward(self, def forward(self,
x, x,
...@@ -107,11 +118,11 @@ class SECONDFusionFPN(SECONDFPN): ...@@ -107,11 +118,11 @@ class SECONDFusionFPN(SECONDFPN):
downsample_pts_coors = torch.zeros_like(coors) downsample_pts_coors = torch.zeros_like(coors)
downsample_pts_coors[:, 0] = coors[:, 0] downsample_pts_coors[:, 0] = coors[:, 0]
downsample_pts_coors[:, 1] = ( downsample_pts_coors[:, 1] = (
coors[:, 1] / self.down_sample_rate[0]) coors[:, 1] / self.downsample_rates[0])
downsample_pts_coors[:, 2] = ( downsample_pts_coors[:, 2] = (
coors[:, 2] / self.down_sample_rate[1]) coors[:, 2] / self.downsample_rates[1])
downsample_pts_coors[:, 3] = ( downsample_pts_coors[:, 3] = (
coors[:, 3] / self.down_sample_rate[2]) coors[:, 3] / self.downsample_rates[2])
# fusion for each point # fusion for each point
out = self.fusion_layer(img_feats, points, out, out = self.fusion_layer(img_feats, points, out,
downsample_pts_coors, img_meta) downsample_pts_coors, img_meta)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment