second_fpn.py 3.38 KB
Newer Older
dingchang's avatar
dingchang committed
1
# Copyright (c) OpenMMLab. All rights reserved.
2
import numpy as np
zhangwenwei's avatar
zhangwenwei committed
3
import torch
4
5
from mmcv.cnn import build_conv_layer, build_norm_layer, build_upsample_layer
from mmcv.runner import BaseModule, auto_fp16
zhangwenwei's avatar
zhangwenwei committed
6
from torch import nn as nn
zhangwenwei's avatar
zhangwenwei committed
7

zhangwenwei's avatar
zhangwenwei committed
8
from mmdet.models import NECKS
zhangwenwei's avatar
zhangwenwei committed
9
10


11
@NECKS.register_module()
class SECONDFPN(BaseModule):
    """FPN used in SECOND/PointPillars/PartA2/MVXNet.

    Each input level gets its own block made of a resampling layer, a
    normalization layer and a ReLU. ``forward`` runs every level through
    its block and concatenates the results along dim 1.

    Args:
        in_channels (list[int]): Input channels of multi-scale feature maps.
        out_channels (list[int]): Output channels of feature maps.
        upsample_strides (list[int]): Strides used to upsample the
            feature maps. A stride greater than 1 (or equal to 1 while
            ``use_conv_for_no_stride`` is False) builds an upsample layer
            whose kernel size and stride both equal that value. Any other
            stride builds a conv layer whose kernel size and stride are
            ``round(1 / stride)``.
        norm_cfg (dict): Config dict of normalization layers.
        upsample_cfg (dict): Config dict of upsample layers.
        conv_cfg (dict): Config dict of conv layers.
        use_conv_for_no_stride (bool): Whether to use conv when stride is 1.
        init_cfg (dict | list[dict], optional): Initialization config
            forwarded to ``BaseModule``. When None, Kaiming init for
            ``ConvTranspose2d`` layers and constant (1.0) init for
            ``NaiveSyncBatchNorm2d`` layers are used.
    """

    def __init__(self,
                 in_channels=[128, 128, 256],
                 out_channels=[256, 256, 256],
                 upsample_strides=[1, 2, 4],
                 norm_cfg=dict(type='BN', eps=1e-3, momentum=0.01),
                 upsample_cfg=dict(type='deconv', bias=False),
                 conv_cfg=dict(type='Conv2d', bias=False),
                 use_conv_for_no_stride=False,
                 init_cfg=None):
        # For GroupNorm, pass norm_cfg as
        # dict(type='GN', num_groups=num_groups, eps=1e-3, affine=True)
        super().__init__(init_cfg=init_cfg)
        # One block is built per level, so the three lists must line up.
        assert len(out_channels) == len(upsample_strides) == len(in_channels)
        self.in_channels = in_channels
        self.out_channels = out_channels
        # Presumably consulted by the ``auto_fp16`` decorator on ``forward``.
        self.fp16_enabled = False

        self.deblocks = nn.ModuleList()
        for level, level_out in enumerate(out_channels):
            level_in = in_channels[level]
            factor = upsample_strides[level]

            keep_upsample_layer = factor > 1 or (
                factor == 1 and not use_conv_for_no_stride)
            if keep_upsample_layer:
                resampler = build_upsample_layer(
                    upsample_cfg,
                    in_channels=level_in,
                    out_channels=level_out,
                    kernel_size=factor,
                    stride=factor)
            else:
                # The stride is inverted here: a fractional stride turns
                # into the integer kernel size / stride of a regular conv.
                reduction = np.round(1 / factor).astype(np.int64)
                resampler = build_conv_layer(
                    conv_cfg,
                    in_channels=level_in,
                    out_channels=level_out,
                    kernel_size=reduction,
                    stride=reduction)

            # build_norm_layer returns (name, layer); only the layer is used.
            _, norm_layer = build_norm_layer(norm_cfg, level_out)
            self.deblocks.append(
                nn.Sequential(resampler, norm_layer, nn.ReLU(inplace=True)))

        if init_cfg is None:
            # Default initialization when the caller supplies none.
            self.init_cfg = [
                dict(type='Kaiming', layer='ConvTranspose2d'),
                dict(type='Constant', layer='NaiveSyncBatchNorm2d', val=1.0)
            ]

    @auto_fp16()
    def forward(self, x):
        """Forward function.

        Args:
            x (list[torch.Tensor] | tuple[torch.Tensor]): Multi-level
                feature maps, one 4D tensor in (N, C, H, W) shape per
                entry of ``in_channels``.

        Returns:
            list[torch.Tensor]: A single-element list holding the per-level
                block outputs concatenated along dim 1, or the lone block
                output when there is only one level.
        """
        assert len(x) == len(self.in_channels)
        resampled = [
            block(x[level]) for level, block in enumerate(self.deblocks)
        ]

        fused = (
            torch.cat(resampled, dim=1)
            if len(resampled) > 1 else resampled[0])
        return [fused]