# base_conv_bbox_head.py
from mmcv.cnn import ConvModule
from mmcv.cnn.bricks import build_conv_layer
from mmcv.runner import BaseModule
from torch import nn as nn

from mmdet.models.builder import HEADS


@HEADS.register_module()
class BaseConvBboxHead(BaseModule):
    r"""More general bbox head, with shared conv layers and two optional
    separated branches.

    .. code-block:: none

                     /-> cls convs -> cls_score
        shared convs
                     \-> reg convs -> bbox_pred

    Every conv in the shared / cls / reg stacks is a ``ConvModule`` with
    ``kernel_size=1``; the two final prediction layers are plain conv layers
    built with ``build_conv_layer`` (also ``kernel_size=1``).

    Args:
        in_channels (int): Number of channels of the input features.
            Must be greater than 0.
        shared_conv_channels (tuple[int]): Output channels of each layer in
            the shared conv stack. Empty means no shared convs.
        cls_conv_channels (tuple[int]): Output channels of each layer in the
            classification-specific conv stack. Empty means the final
            classification layer is applied directly.
        num_cls_out_channels (int): Output channels of the final
            classification layer. Must be greater than 0.
        reg_conv_channels (tuple[int]): Output channels of each layer in the
            regression-specific conv stack. Empty means the final regression
            layer is applied directly.
        num_reg_out_channels (int): Output channels of the final regression
            layer. Must be greater than 0.
        conv_cfg (dict): Conv layer config, forwarded to ``ConvModule`` and
            ``build_conv_layer``.
        norm_cfg (dict): Norm layer config, forwarded to ``ConvModule``.
        act_cfg (dict): Activation config, forwarded to ``ConvModule``.
        bias (bool | str): Bias setting, forwarded to ``ConvModule``.
        init_cfg (dict, optional): Initialization config, forwarded to the
            base class constructor.
    """

    # NOTE: the dict defaults below are shared between calls; this class only
    # stores and forwards them and never mutates them in place.
    def __init__(self,
                 in_channels=0,
                 shared_conv_channels=(),
                 cls_conv_channels=(),
                 num_cls_out_channels=0,
                 reg_conv_channels=(),
                 num_reg_out_channels=0,
                 conv_cfg=dict(type='Conv1d'),
                 norm_cfg=dict(type='BN1d'),
                 act_cfg=dict(type='ReLU'),
                 bias='auto',
                 init_cfg=None,
                 *args,
                 **kwargs):
        super(BaseConvBboxHead, self).__init__(
            *args, init_cfg=init_cfg, **kwargs)
        assert in_channels > 0, 'in_channels must be positive'
        assert num_cls_out_channels > 0, \
            'num_cls_out_channels must be positive'
        assert num_reg_out_channels > 0, \
            'num_reg_out_channels must be positive'
        self.in_channels = in_channels
        self.shared_conv_channels = shared_conv_channels
        self.cls_conv_channels = cls_conv_channels
        self.num_cls_out_channels = num_cls_out_channels
        self.reg_conv_channels = reg_conv_channels
        self.num_reg_out_channels = num_reg_out_channels
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.bias = bias

        # Sub-modules are registered in the order shared_convs, cls_convs,
        # conv_cls, reg_convs, conv_reg; keep it so parameter / state_dict
        # ordering stays the same.

        # add shared convs
        if len(self.shared_conv_channels) > 0:
            self.shared_convs = self._add_conv_branch(
                self.in_channels, self.shared_conv_channels)
            out_channels = self.shared_conv_channels[-1]
        else:
            # without shared convs both branches consume the raw input
            out_channels = self.in_channels

        # add cls specific branch
        prev_channel = out_channels
        if len(self.cls_conv_channels) > 0:
            self.cls_convs = self._add_conv_branch(prev_channel,
                                                   self.cls_conv_channels)
            prev_channel = self.cls_conv_channels[-1]

        self.conv_cls = build_conv_layer(
            conv_cfg,
            in_channels=prev_channel,
            out_channels=num_cls_out_channels,
            kernel_size=1)

        # add reg specific branch (restarts from the shared output width)
        prev_channel = out_channels
        if len(self.reg_conv_channels) > 0:
            self.reg_convs = self._add_conv_branch(prev_channel,
                                                   self.reg_conv_channels)
            prev_channel = self.reg_conv_channels[-1]

        self.conv_reg = build_conv_layer(
            conv_cfg,
            in_channels=prev_channel,
            out_channels=num_reg_out_channels,
            kernel_size=1)

    def _add_conv_branch(self, in_channels, conv_channels):
        """Build a stack of 1x1 ``ConvModule`` layers.

        Used for the shared stack as well as the cls / reg specific stacks.

        Args:
            in_channels (int): Number of channels fed into the first layer.
            conv_channels (tuple[int]): Output channels of each layer, in
                order.

        Returns:
            nn.Sequential: The layers, registered as ``layer0``, ``layer1``,
            ... in order.
        """
        conv_spec = [in_channels] + list(conv_channels)
        conv_layers = nn.Sequential()
        # consecutive (input, output) channel pairs define one layer each
        for i, (c_in, c_out) in enumerate(zip(conv_spec[:-1], conv_spec[1:])):
            conv_layers.add_module(
                f'layer{i}',
                ConvModule(
                    c_in,
                    c_out,
                    kernel_size=1,
                    padding=0,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    act_cfg=self.act_cfg,
                    bias=self.bias,
                    inplace=True))
        return conv_layers

    def forward(self, feats):
        """Forward.

        Args:
            feats (Tensor): Input features. The layout must suit the
                configured conv type (presumably (B, C, N) for the default
                ``Conv1d`` - TODO confirm against callers).

        Returns:
            Tensor: Class scores predictions
            Tensor: Regression predictions
        """
        # shared part; start from the raw input so that a head configured
        # without shared convs still has a defined feature to branch from
        x = feats
        if len(self.shared_conv_channels) > 0:
            x = self.shared_convs(x)

        # separate branches
        x_cls = x
        x_reg = x

        if len(self.cls_conv_channels) > 0:
            x_cls = self.cls_convs(x_cls)
        cls_score = self.conv_cls(x_cls)

        if len(self.reg_conv_channels) > 0:
            x_reg = self.reg_convs(x_reg)
        bbox_pred = self.conv_reg(x_reg)

        return cls_score, bbox_pred