convfc_bbox_head.py 6.78 KB
Newer Older
pangjm's avatar
pangjm committed
1
2
import torch.nn as nn

Kai Chen's avatar
Kai Chen committed
3
from ..registry import HEADS
pangjm's avatar
pangjm committed
4
from ..utils import ConvModule
5
from .bbox_head import BBoxHead
pangjm's avatar
pangjm committed
6
7


Kai Chen's avatar
Kai Chen committed
8
@HEADS.register_module
class ConvFCBBoxHead(BBoxHead):
    r"""More general bbox head, with shared conv and fc layers and two optional
    separated branches.

                                /-> cls convs -> cls fcs -> cls
    shared convs -> shared fcs
                                \-> reg convs -> reg fcs -> reg

    Args:
        num_shared_convs (int): Number of 3x3 convs shared by both branches.
        num_shared_fcs (int): Number of fc layers shared by both branches.
        num_cls_convs (int): Number of 3x3 convs in the classification branch.
        num_cls_fcs (int): Number of fc layers in the classification branch.
        num_reg_convs (int): Number of 3x3 convs in the regression branch.
        num_reg_fcs (int): Number of fc layers in the regression branch.
        conv_out_channels (int): Output channels of every conv layer.
        fc_out_channels (int): Output features of every fc layer.
        conv_cfg (dict | None): Conv config forwarded to ``ConvModule``.
        norm_cfg (dict | None): Norm config forwarded to ``ConvModule``.
        *args, **kwargs: Forwarded unchanged to ``BBoxHead.__init__``.
    """

    def __init__(self,
                 num_shared_convs=0,
                 num_shared_fcs=0,
                 num_cls_convs=0,
                 num_cls_fcs=0,
                 num_reg_convs=0,
                 num_reg_fcs=0,
                 conv_out_channels=256,
                 fc_out_channels=1024,
                 conv_cfg=None,
                 norm_cfg=None,
                 *args,
                 **kwargs):
        super(ConvFCBBoxHead, self).__init__(*args, **kwargs)
        assert (num_shared_convs + num_shared_fcs + num_cls_convs +
                num_cls_fcs + num_reg_convs + num_reg_fcs > 0), \
            'ConvFCBBoxHead needs at least one conv or fc layer'
        if num_cls_convs > 0 or num_reg_convs > 0:
            # Shared fcs flatten the feature map (see ``forward``), so no
            # branch-specific conv could run after them.
            assert num_shared_fcs == 0, \
                'branch convs cannot be combined with shared fcs'
        # with_cls / with_reg are set by BBoxHead.__init__ (not visible here).
        if not self.with_cls:
            assert num_cls_convs == 0 and num_cls_fcs == 0, \
                'cls branch layers were requested but with_cls is False'
        if not self.with_reg:
            assert num_reg_convs == 0 and num_reg_fcs == 0, \
                'reg branch layers were requested but with_reg is False'
        self.num_shared_convs = num_shared_convs
        self.num_shared_fcs = num_shared_fcs
        self.num_cls_convs = num_cls_convs
        self.num_cls_fcs = num_cls_fcs
        self.num_reg_convs = num_reg_convs
        self.num_reg_fcs = num_reg_fcs
        self.conv_out_channels = conv_out_channels
        self.fc_out_channels = fc_out_channels
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg

        # add shared convs and fcs
        self.shared_convs, self.shared_fcs, last_layer_dim = \
            self._add_conv_fc_branch(
                self.num_shared_convs, self.num_shared_fcs, self.in_channels,
                True)
        self.shared_out_channels = last_layer_dim

        # add cls specific branch
        self.cls_convs, self.cls_fcs, self.cls_last_dim = \
            self._add_conv_fc_branch(
                self.num_cls_convs, self.num_cls_fcs, self.shared_out_channels)

        # add reg specific branch
        self.reg_convs, self.reg_fcs, self.reg_last_dim = \
            self._add_conv_fc_branch(
                self.num_reg_convs, self.num_reg_fcs, self.shared_out_channels)

        if self.num_shared_fcs == 0 and not self.with_avg_pool:
            # Nothing pools or flattens before a branch that has no fcs, so
            # its final linear layer sees the flattened map: channels times
            # roi_feat_area (from BBoxHead; presumably H * W of the RoI
            # feature -- TODO confirm).
            if self.num_cls_fcs == 0:
                self.cls_last_dim *= self.roi_feat_area
            if self.num_reg_fcs == 0:
                self.reg_last_dim *= self.roi_feat_area

        self.relu = nn.ReLU(inplace=True)
        # reconstruct fc_cls and fc_reg since input channels are changed
        if self.with_cls:
            self.fc_cls = nn.Linear(self.cls_last_dim, self.num_classes)
        if self.with_reg:
            out_dim_reg = (4 if self.reg_class_agnostic else 4 *
                           self.num_classes)
            self.fc_reg = nn.Linear(self.reg_last_dim, out_dim_reg)

    def _add_conv_fc_branch(self,
                            num_branch_convs,
                            num_branch_fcs,
                            in_channels,
                            is_shared=False):
        """Add shared or separable branch.

        convs -> avg pool (optional) -> fcs

        Args:
            num_branch_convs (int): Number of 3x3 ``ConvModule`` layers.
            num_branch_fcs (int): Number of ``nn.Linear`` layers.
            in_channels (int): Channels of the feature entering the branch.
            is_shared (bool): Whether this is the shared branch. It changes
                when the first fc's input size is scaled by ``roi_feat_area``.

        Returns:
            tuple: ``(branch_convs, branch_fcs, last_layer_dim)``. The first
            two are ``nn.ModuleList``. ``last_layer_dim`` is the size produced
            by the branch's last layer, or ``in_channels`` if the branch is
            empty.
        """
        last_layer_dim = in_channels
        # add branch specific conv layers
        branch_convs = nn.ModuleList()
        if num_branch_convs > 0:
            for i in range(num_branch_convs):
                conv_in_channels = (
                    last_layer_dim if i == 0 else self.conv_out_channels)
                branch_convs.append(
                    ConvModule(
                        conv_in_channels,
                        self.conv_out_channels,
                        3,
                        padding=1,
                        conv_cfg=self.conv_cfg,
                        norm_cfg=self.norm_cfg))
            last_layer_dim = self.conv_out_channels
        # add branch specific fc layers
        branch_fcs = nn.ModuleList()
        if num_branch_fcs > 0:
            # for shared branch, only consider self.with_avg_pool
            # for separated branches, also consider self.num_shared_fcs
            # (if shared fcs exist, their output is already flat)
            if (is_shared
                    or self.num_shared_fcs == 0) and not self.with_avg_pool:
                last_layer_dim *= self.roi_feat_area
            for i in range(num_branch_fcs):
                fc_in_channels = (
                    last_layer_dim if i == 0 else self.fc_out_channels)
                branch_fcs.append(
                    nn.Linear(fc_in_channels, self.fc_out_channels))
            last_layer_dim = self.fc_out_channels
        return branch_convs, branch_fcs, last_layer_dim

    def init_weights(self):
        """Run BBoxHead's init, then Xavier-uniform init every fc layer of the
        shared / cls / reg branches with zero bias."""
        super(ConvFCBBoxHead, self).init_weights()
        for module_list in [self.shared_fcs, self.cls_fcs, self.reg_fcs]:
            for m in module_list.modules():
                if isinstance(m, nn.Linear):
                    nn.init.xavier_uniform_(m.weight)
                    nn.init.constant_(m.bias, 0)

    def _forward_branch(self, x, convs, fcs):
        """Run one separated branch: convs -> (avg pool) -> flatten -> fcs.

        Flattening only happens if ``x`` is still spatial (more than two
        dims), i.e. when no shared fcs already flattened it.
        """
        for conv in convs:
            x = conv(x)
        if x.dim() > 2:
            if self.with_avg_pool:
                x = self.avg_pool(x)
            x = x.view(x.size(0), -1)
        for fc in fcs:
            x = self.relu(fc(x))
        return x

    def forward(self, x):
        """Compute classification scores and bbox regression outputs.

        Args:
            x (Tensor): RoI features. Presumably shaped (N, C, H, W); TODO
                confirm against the RoI extractor.

        Returns:
            tuple: ``(cls_score, bbox_pred)``. Either entry is ``None`` when
            the corresponding ``with_cls`` / ``with_reg`` flag is off.
        """
        # shared part
        if self.num_shared_convs > 0:
            for conv in self.shared_convs:
                x = conv(x)

        if self.num_shared_fcs > 0:
            if self.with_avg_pool:
                x = self.avg_pool(x)
            x = x.view(x.size(0), -1)
            for fc in self.shared_fcs:
                x = self.relu(fc(x))

        # separate branches (both start from the same shared output)
        x_cls = self._forward_branch(x, self.cls_convs, self.cls_fcs)
        x_reg = self._forward_branch(x, self.reg_convs, self.reg_fcs)

        cls_score = self.fc_cls(x_cls) if self.with_cls else None
        bbox_pred = self.fc_reg(x_reg) if self.with_reg else None
        return cls_score, bbox_pred


Kai Chen's avatar
Kai Chen committed
171
@HEADS.register_module
class SharedFCBBoxHead(ConvFCBBoxHead):
    """Bbox head whose only trainable trunk is a stack of shared fc layers.

    All conv counts and all branch-specific fc counts are fixed to zero. Only
    the number of shared fcs and their width are configurable.

    Args:
        num_fcs (int): Number of shared fc layers; must be at least 1.
        fc_out_channels (int): Output features of each shared fc layer.
        *args, **kwargs: Forwarded to ``ConvFCBBoxHead.__init__``.
    """

    def __init__(self, num_fcs=2, fc_out_channels=1024, *args, **kwargs):
        assert num_fcs >= 1
        # NOTE(review): any extra positional ``args`` bind to the parent's
        # first parameter (``num_shared_convs``) and clash with the keyword
        # below, raising TypeError. In practice this head can only be built
        # with keyword arguments.
        super(SharedFCBBoxHead, self).__init__(
            *args,
            num_shared_convs=0,
            num_shared_fcs=num_fcs,
            num_cls_convs=0,
            num_cls_fcs=0,
            num_reg_convs=0,
            num_reg_fcs=0,
            fc_out_channels=fc_out_channels,
            **kwargs)