convfc_bbox_head.py 6.75 KB
Newer Older
pangjm's avatar
pangjm committed
1
2
import torch.nn as nn

Kai Chen's avatar
Kai Chen committed
3
from ..registry import HEADS
pangjm's avatar
pangjm committed
4
from ..utils import ConvModule
5
from .bbox_head import BBoxHead
pangjm's avatar
pangjm committed
6
7


Kai Chen's avatar
Kai Chen committed
8
@HEADS.register_module
class ConvFCBBoxHead(BBoxHead):
    r"""More general bbox head, with shared conv and fc layers and two optional
    separated branches.

                                /-> cls convs -> cls fcs -> cls
    shared convs -> shared fcs
                                \-> reg convs -> reg fcs -> reg

    Every conv is a 3x3 ``ConvModule`` (padding 1) producing
    ``conv_out_channels`` channels; every hidden fc is an ``nn.Linear``
    producing ``fc_out_channels`` features, followed by ReLU. Branch convs
    need spatial input, so they are only allowed when there are no shared
    fcs.

    Args:
        num_shared_convs (int): Number of convs shared by both branches.
        num_shared_fcs (int): Number of fcs shared by both branches.
        num_cls_convs (int): Number of convs in the classification branch.
        num_cls_fcs (int): Number of fcs in the classification branch.
        num_reg_convs (int): Number of convs in the regression branch.
        num_reg_fcs (int): Number of fcs in the regression branch.
        conv_out_channels (int): Output channels of each conv layer.
        fc_out_channels (int): Output features of each hidden fc layer.
        conv_cfg (dict, optional): Conv config passed to ``ConvModule``.
        norm_cfg (dict, optional): Norm config passed to ``ConvModule``.
        *args, **kwargs: Forwarded to :class:`BBoxHead`.
    """  # noqa: W605

    def __init__(self,
                 num_shared_convs=0,
                 num_shared_fcs=0,
                 num_cls_convs=0,
                 num_cls_fcs=0,
                 num_reg_convs=0,
                 num_reg_fcs=0,
                 conv_out_channels=256,
                 fc_out_channels=1024,
                 conv_cfg=None,
                 norm_cfg=None,
                 *args,
                 **kwargs):
        super(ConvFCBBoxHead, self).__init__(*args, **kwargs)
        assert (num_shared_convs + num_shared_fcs + num_cls_convs +
                num_cls_fcs + num_reg_convs + num_reg_fcs > 0), \
            'at least one conv or fc layer is required'
        if num_cls_convs > 0 or num_reg_convs > 0:
            # Branch convs operate on spatial maps, but shared fcs would
            # already have flattened the features.
            assert num_shared_fcs == 0, \
                'branch convs cannot be combined with shared fcs'
        if not self.with_cls:
            assert num_cls_convs == 0 and num_cls_fcs == 0, \
                'cls branch layers are given but with_cls is False'
        if not self.with_reg:
            assert num_reg_convs == 0 and num_reg_fcs == 0, \
                'reg branch layers are given but with_reg is False'
        self.num_shared_convs = num_shared_convs
        self.num_shared_fcs = num_shared_fcs
        self.num_cls_convs = num_cls_convs
        self.num_cls_fcs = num_cls_fcs
        self.num_reg_convs = num_reg_convs
        self.num_reg_fcs = num_reg_fcs
        self.conv_out_channels = conv_out_channels
        self.fc_out_channels = fc_out_channels
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg

        # NOTE: the construction order below (shared -> cls -> reg -> fc_cls
        # -> fc_reg) fixes parameter registration order and RNG consumption
        # during layer init; keep it stable.

        # add shared convs and fcs
        self.shared_convs, self.shared_fcs, last_layer_dim = \
            self._add_conv_fc_branch(
                self.num_shared_convs, self.num_shared_fcs, self.in_channels,
                True)
        self.shared_out_channels = last_layer_dim

        # add cls specific branch
        self.cls_convs, self.cls_fcs, self.cls_last_dim = \
            self._add_conv_fc_branch(
                self.num_cls_convs, self.num_cls_fcs, self.shared_out_channels)

        # add reg specific branch
        self.reg_convs, self.reg_fcs, self.reg_last_dim = \
            self._add_conv_fc_branch(
                self.num_reg_convs, self.num_reg_fcs, self.shared_out_channels)

        if self.num_shared_fcs == 0 and not self.with_avg_pool:
            # forward() flattens the un-pooled feature map, so a branch with
            # no fcs of its own hands channels * roi_feat_area features
            # (roi_feat_area presumably being the RoI's H * W) to the final
            # predictor.
            if self.num_cls_fcs == 0:
                self.cls_last_dim *= self.roi_feat_area
            if self.num_reg_fcs == 0:
                self.reg_last_dim *= self.roi_feat_area

        self.relu = nn.ReLU(inplace=True)
        # reconstruct fc_cls and fc_reg since input channels are changed
        if self.with_cls:
            self.fc_cls = nn.Linear(self.cls_last_dim, self.num_classes)
        if self.with_reg:
            # 4 outputs per box, either shared by all classes or per class.
            out_dim_reg = (4 if self.reg_class_agnostic else 4 *
                           self.num_classes)
            self.fc_reg = nn.Linear(self.reg_last_dim, out_dim_reg)

    def _add_conv_fc_branch(self,
                            num_branch_convs,
                            num_branch_fcs,
                            in_channels,
                            is_shared=False):
        """Add shared or separable branch

        convs -> avg pool (optional) -> fcs

        Args:
            num_branch_convs (int): Number of 3x3 convs in the branch.
            num_branch_fcs (int): Number of fcs in the branch.
            in_channels (int): Channels of the branch input.
            is_shared (bool): Whether this is the shared branch. A separate
                branch only flattens spatial dims into its first fc when
                there are no shared fcs.

        Returns:
            tuple: ``(branch_convs, branch_fcs, last_layer_dim)``. The first
            two are ``nn.ModuleList``s; ``last_layer_dim`` is the branch's
            output size (a channel count, not yet multiplied by
            ``roi_feat_area``, when the branch has no fcs).
        """
        last_layer_dim = in_channels
        # add branch specific conv layers
        branch_convs = nn.ModuleList()
        if num_branch_convs > 0:
            for i in range(num_branch_convs):
                conv_in_channels = (
                    last_layer_dim if i == 0 else self.conv_out_channels)
                branch_convs.append(
                    ConvModule(
                        conv_in_channels,
                        self.conv_out_channels,
                        3,
                        padding=1,
                        conv_cfg=self.conv_cfg,
                        norm_cfg=self.norm_cfg))
            last_layer_dim = self.conv_out_channels
        # add branch specific fc layers
        branch_fcs = nn.ModuleList()
        if num_branch_fcs > 0:
            # for shared branch, only consider self.with_avg_pool
            # for separated branches, also consider self.num_shared_fcs
            if (is_shared
                    or self.num_shared_fcs == 0) and not self.with_avg_pool:
                last_layer_dim *= self.roi_feat_area
            for i in range(num_branch_fcs):
                fc_in_channels = (
                    last_layer_dim if i == 0 else self.fc_out_channels)
                branch_fcs.append(
                    nn.Linear(fc_in_channels, self.fc_out_channels))
            last_layer_dim = self.fc_out_channels
        return branch_convs, branch_fcs, last_layer_dim

    def init_weights(self):
        """Init base-class layers, then Xavier-uniform init every fc weight in
        the shared/cls/reg stacks and zero its bias."""
        super(ConvFCBBoxHead, self).init_weights()
        for module_list in [self.shared_fcs, self.cls_fcs, self.reg_fcs]:
            for m in module_list.modules():
                if isinstance(m, nn.Linear):
                    nn.init.xavier_uniform_(m.weight)
                    nn.init.constant_(m.bias, 0)

    def _forward_branch(self, x, convs, fcs):
        """Run one separate branch: convs -> (avg pool) -> flatten -> fcs.

        Pooling and flattening only happen while ``x`` still has more than
        two dims, i.e. when it did not pass through shared fcs.
        """
        for conv in convs:
            x = conv(x)
        if x.dim() > 2:
            if self.with_avg_pool:
                x = self.avg_pool(x)
            x = x.flatten(1)
        for fc in fcs:
            x = self.relu(fc(x))
        return x

    def forward(self, x):
        """Compute classification scores and box regression outputs.

        Args:
            x (Tensor): RoI features; presumably (N, C, H, W) — TODO confirm.

        Returns:
            tuple: ``(cls_score, bbox_pred)``; an entry is ``None`` when the
            corresponding ``with_cls`` / ``with_reg`` flag is False.
        """
        # shared part (shared_convs is an empty ModuleList when unused)
        for conv in self.shared_convs:
            x = conv(x)

        if self.num_shared_fcs > 0:
            if self.with_avg_pool:
                x = self.avg_pool(x)
            x = x.flatten(1)
            for fc in self.shared_fcs:
                x = self.relu(fc(x))

        # separate branches
        x_cls = self._forward_branch(x, self.cls_convs, self.cls_fcs)
        x_reg = self._forward_branch(x, self.reg_convs, self.reg_fcs)

        cls_score = self.fc_cls(x_cls) if self.with_cls else None
        bbox_pred = self.fc_reg(x_reg) if self.with_reg else None
        return cls_score, bbox_pred


Kai Chen's avatar
Kai Chen committed
173
@HEADS.register_module
class SharedFCBBoxHead(ConvFCBBoxHead):
    """Bbox head built solely from a stack of fc layers shared by cls and reg.

    This is :class:`ConvFCBBoxHead` with no convs and no branch-specific
    layers, so the final ``fc_cls`` / ``fc_reg`` predictors sit directly on
    top of the shared fc stack.

    Args:
        num_fcs (int): Number of shared fc layers; must be at least 1.
        fc_out_channels (int): Output features of each shared fc layer.
        *args, **kwargs: Forwarded to :class:`ConvFCBBoxHead`.
    """

    def __init__(self, num_fcs=2, fc_out_channels=1024, *args, **kwargs):
        assert num_fcs >= 1
        # Switch off every layer group except the shared fc stack.
        super(SharedFCBBoxHead, self).__init__(
            *args,
            num_shared_convs=0,
            num_shared_fcs=num_fcs,
            num_cls_convs=0,
            num_cls_fcs=0,
            num_reg_convs=0,
            num_reg_fcs=0,
            fc_out_channels=fc_out_channels,
            **kwargs)