import torch
from mmcv.cnn import (build_norm_layer, build_upsample_layer, constant_init,
                      is_norm, kaiming_init)
from torch import nn as nn

from mmdet.models import NECKS


@NECKS.register_module()
class SECONDFPN(nn.Module):
    """FPN used in SECOND/PointPillars/PartA2/MVXNet.

    Every input level is passed through its own deblock
    (upsample layer -> norm layer -> ReLU). The resulting maps are then
    concatenated along dim 1 into a single output map.

    Args:
        in_channels (list[int]): Input channels of multi-scale feature maps.
            Defaults to ``[128, 128, 256]``.
        out_channels (list[int]): Output channels of feature maps.
            Defaults to ``[256, 256, 256]``.
        upsample_strides (list[int]): Strides used to upsample the feature
            maps. Each stride is also used as the kernel size of its
            upsample layer. Defaults to ``[1, 2, 4]``.
        norm_cfg (dict): Config dict of normalization layers. Defaults to
            ``dict(type='BN', eps=1e-3, momentum=0.01)``. For GroupNorm use
            e.g. ``dict(type='GN', num_groups=num_groups, eps=1e-3,
            affine=True)``.
        upsample_cfg (dict): Config dict of upsample layers. Defaults to
            ``dict(type='deconv', bias=False)``.
    """

    def __init__(self,
                 in_channels=None,
                 out_channels=None,
                 upsample_strides=None,
                 norm_cfg=None,
                 upsample_cfg=None):
        super(SECONDFPN, self).__init__()
        # Build defaults per call so instances never share (and cannot
        # accidentally mutate) a single default object.
        if in_channels is None:
            in_channels = [128, 128, 256]
        if out_channels is None:
            out_channels = [256, 256, 256]
        if upsample_strides is None:
            upsample_strides = [1, 2, 4]
        if norm_cfg is None:
            norm_cfg = dict(type='BN', eps=1e-3, momentum=0.01)
        if upsample_cfg is None:
            upsample_cfg = dict(type='deconv', bias=False)

        assert len(out_channels) == len(upsample_strides) == len(in_channels), (
            'in_channels, out_channels and upsample_strides must have the '
            'same length, got {}, {} and {}'.format(
                len(in_channels), len(out_channels), len(upsample_strides)))
        self.in_channels = in_channels
        self.out_channels = out_channels

        deblocks = []
        for in_channel, out_channel, stride in zip(in_channels, out_channels,
                                                   upsample_strides):
            # kernel_size == stride, so the upsampled windows tile the
            # output without overlap.
            upsample_layer = build_upsample_layer(
                upsample_cfg,
                in_channels=in_channel,
                out_channels=out_channel,
                kernel_size=stride,
                stride=stride)
            deblocks.append(
                nn.Sequential(upsample_layer,
                              build_norm_layer(norm_cfg, out_channel)[1],
                              nn.ReLU(inplace=True)))
        self.deblocks = nn.ModuleList(deblocks)

    def init_weights(self):
        """Initialize weights of FPN.

        Convolution and transposed-convolution layers receive Kaiming init.
        Normalization layers receive constant init with value 1.
        """
        for m in self.modules():
            # nn.ConvTranspose2d is NOT a subclass of nn.Conv2d, so it must
            # be listed explicitly. Otherwise the 'deconv' upsample layers
            # (nn.ConvTranspose2d in mmcv) would be skipped here.
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
                kaiming_init(m)
            elif is_norm(m):
                constant_init(m, 1)

    def forward(self, x):
        """Forward function.

        Args:
            x (list[torch.Tensor] | tuple[torch.Tensor]): Multi-level feature
                maps, one per entry of ``in_channels``. They are presumably
                4D (N, C, H, W) (TODO confirm). After upsampling, every level
                must share the same spatial size so the maps can be
                concatenated.

        Returns:
            list[torch.Tensor]: A single-element list. With several levels it
                holds the upsampled maps concatenated along dim 1; with one
                level it holds that level's upsampled map.
        """
        assert len(x) == len(self.in_channels), (
            'expected {} input levels, got {}'.format(
                len(self.in_channels), len(x)))
        ups = [deblock(feat) for deblock, feat in zip(self.deblocks, x)]
        out = torch.cat(ups, dim=1) if len(ups) > 1 else ups[0]
        return [out]