fpn.py 5.03 KB
Newer Older
Kai Chen's avatar
Kai Chen committed
1
2
import torch.nn as nn
import torch.nn.functional as F
ThangVu's avatar
ThangVu committed
3
from mmcv.cnn import xavier_init
Kai Chen's avatar
Kai Chen committed
4

Cao Yuhang's avatar
Cao Yuhang committed
5
from mmdet.core import auto_fp16
Kai Chen's avatar
Kai Chen committed
6
from ..registry import NECKS
7
from ..utils import ConvModule
Kai Chen's avatar
Kai Chen committed
8
9


Kai Chen's avatar
Kai Chen committed
10
@NECKS.register_module
class FPN(nn.Module):
    """Feature Pyramid Network neck.

    Builds a top-down pathway with lateral connections over a contiguous
    range of backbone feature maps ``[start_level, backbone_end_level)`` and
    optionally appends extra, lower-resolution levels on top of it.

    Args:
        in_channels (list[int]): Number of channels of each backbone input,
            in the same order as the ``inputs`` passed to :meth:`forward`.
        out_channels (int): Number of channels of every output level.
        num_outs (int): Number of output levels returned by :meth:`forward`.
        start_level (int): Index of the first backbone input that is used.
        end_level (int): Exclusive index of the last backbone input that is
            used. ``-1`` means "up to and including the last input". Any
            other value forbids extra levels (``num_outs`` must then equal
            ``end_level - start_level``).
        add_extra_convs (bool): If True, extra levels are produced by
            stride-2 3x3 ``ConvModule``s. Otherwise they are produced by
            ``F.max_pool2d`` with kernel 1 and stride 2.
        extra_convs_on_inputs (bool): If True, the first extra conv reads the
            last used backbone input. Otherwise it reads the last FPN output.
        relu_before_extra_convs (bool): If True, apply ``F.relu`` to the input
            of every extra conv after the first one.
        conv_cfg (dict, optional): Passed through to every ``ConvModule``.
        norm_cfg (dict, optional): Passed through to every ``ConvModule``.
        activation (optional): Passed through to every ``ConvModule``.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 num_outs,
                 start_level=0,
                 end_level=-1,
                 add_extra_convs=False,
                 extra_convs_on_inputs=True,
                 relu_before_extra_convs=False,
                 conv_cfg=None,
                 norm_cfg=None,
                 activation=None):
        super(FPN, self).__init__()
        assert isinstance(in_channels, list)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_ins = len(in_channels)
        self.num_outs = num_outs
        self.activation = activation
        self.relu_before_extra_convs = relu_before_extra_convs
        # NOTE(review): presumably switched on externally for mixed-precision
        # training and read by the @auto_fp16 decorator on forward().
        self.fp16_enabled = False

        if end_level == -1:
            self.backbone_end_level = self.num_ins
            assert num_outs >= self.num_ins - start_level
        else:
            # if end_level < inputs, no extra level is allowed
            self.backbone_end_level = end_level
            assert end_level <= len(in_channels)
            assert num_outs == end_level - start_level
        self.start_level = start_level
        self.end_level = end_level
        self.add_extra_convs = add_extra_convs
        self.extra_convs_on_inputs = extra_convs_on_inputs

        self.lateral_convs = nn.ModuleList()
        self.fpn_convs = nn.ModuleList()

        # One 1x1 lateral conv and one 3x3 output conv per used backbone level.
        for i in range(self.start_level, self.backbone_end_level):
            l_conv = ConvModule(
                in_channels[i],
                out_channels,
                1,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                activation=self.activation,
                inplace=False)
            fpn_conv = ConvModule(
                out_channels,
                out_channels,
                3,
                padding=1,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                activation=self.activation,
                inplace=False)

            self.lateral_convs.append(l_conv)
            self.fpn_convs.append(fpn_conv)

        # add extra conv layers (e.g., RetinaNet)
        extra_levels = num_outs - self.backbone_end_level + self.start_level
        if add_extra_convs and extra_levels >= 1:
            for i in range(extra_levels):
                # Only the first extra conv may read a raw backbone feature
                # map; every later one consumes a previous FPN output.
                if i == 0 and self.extra_convs_on_inputs:
                    extra_in_channels = self.in_channels[
                        self.backbone_end_level - 1]
                else:
                    extra_in_channels = out_channels
                extra_fpn_conv = ConvModule(
                    extra_in_channels,
                    out_channels,
                    3,
                    stride=2,
                    padding=1,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    activation=self.activation,
                    inplace=False)
                self.fpn_convs.append(extra_fpn_conv)

    def init_weights(self):
        """Apply Xavier-uniform init to every ``nn.Conv2d`` submodule.

        Other submodules (e.g. normalisation layers) are not touched here.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                xavier_init(m, distribution='uniform')

    @auto_fp16()
    def forward(self, inputs):
        """Run the FPN.

        Args:
            inputs (Sequence[Tensor]): One feature map per entry of
                ``in_channels``. Presumably (N, C, H, W), ordered from the
                finest to the coarsest level.

        Returns:
            tuple[Tensor]: ``num_outs`` feature maps with ``out_channels``
                channels each.
        """
        assert len(inputs) == len(self.in_channels)

        # build laterals
        laterals = [
            lateral_conv(inputs[i + self.start_level])
            for i, lateral_conv in enumerate(self.lateral_convs)
        ]

        # build top-down path
        used_backbone_levels = len(laterals)
        for i in range(used_backbone_levels - 1, 0, -1):
            # Upsample to the exact spatial size of the finer level instead of
            # by a fixed factor of 2. Odd-sized backbone maps (e.g. 25 -> 13
            # after a stride-2 conv) would otherwise give 26 != 25 and the
            # addition would fail. For exact 2x sizes the result is identical.
            prev_shape = laterals[i - 1].shape[2:]
            # Out-of-place add: the lateral may be the saved output of an
            # activation (e.g. ReLU) that autograd still needs for backward.
            laterals[i - 1] = laterals[i - 1] + F.interpolate(
                laterals[i], size=prev_shape, mode='nearest')

        # build outputs
        # part 1: from original levels
        outs = [
            self.fpn_convs[i](laterals[i]) for i in range(used_backbone_levels)
        ]
        # part 2: add extra levels
        if self.num_outs > len(outs):
            # use max pool to get more levels on top of outputs
            # (e.g., Faster R-CNN, Mask R-CNN)
            if not self.add_extra_convs:
                for i in range(self.num_outs - used_backbone_levels):
                    outs.append(F.max_pool2d(outs[-1], 1, stride=2))
            # add conv layers on top of original feature maps (RetinaNet)
            else:
                if self.extra_convs_on_inputs:
                    orig = inputs[self.backbone_end_level - 1]
                    outs.append(self.fpn_convs[used_backbone_levels](orig))
                else:
                    outs.append(self.fpn_convs[used_backbone_levels](outs[-1]))
                for i in range(used_backbone_levels + 1, self.num_outs):
                    if self.relu_before_extra_convs:
                        outs.append(self.fpn_convs[i](F.relu(outs[-1])))
                    else:
                        outs.append(self.fpn_convs[i](outs[-1]))
        return tuple(outs)