# htc_mask_head.py
from .fcn_mask_head import FCNMaskHead
from ..registry import HEADS
from ..utils import ConvModule


@HEADS.register_module
class HTCMaskHead(FCNMaskHead):
    """FCN mask head with an extra residual-feature input path.

    Adds one ``ConvModule`` (``conv_res``) on top of ``FCNMaskHead``. In
    ``forward`` an optional ``res_feat`` tensor is passed through
    ``conv_res`` and summed with the input before the inherited conv
    stack runs; the output of that stack can be returned alongside (or
    instead of) the mask logits.

    NOTE(review): ``conv_out_channels``, ``conv_cfg``, ``normalize``,
    ``with_bias``, ``convs``, ``upsample``, ``upsample_method``, ``relu``
    and ``conv_logits`` are not defined in this file; they are presumably
    set by ``FCNMaskHead.__init__`` -- confirm against the base class.
    """

    def __init__(self, *args, **kwargs):
        """Build the base head, then add the residual-feature conv.

        All positional and keyword arguments are forwarded unchanged to
        ``FCNMaskHead.__init__``.
        """
        super(HTCMaskHead, self).__init__(*args, **kwargs)
        # Input and output channels are both ``conv_out_channels`` so the
        # result can be added element-wise to ``x`` in ``forward``. The
        # third positional argument (1) is presumably the kernel size --
        # confirm against ``ConvModule``.
        self.conv_res = ConvModule(
            self.conv_out_channels,
            self.conv_out_channels,
            1,
            conv_cfg=self.conv_cfg,
            normalize=self.normalize,
            bias=self.with_bias)

    def init_weights(self):
        """Initialise the base head's weights, then those of ``conv_res``."""
        super(HTCMaskHead, self).init_weights()
        self.conv_res.init_weights()

    def forward(self, x, res_feat=None, return_logits=True, return_feat=True):
        """Run the mask head, optionally fusing a residual feature.

        Args:
            x: Input features for this head.
            res_feat: Optional features to fuse in. When not ``None`` they
                are passed through ``conv_res`` and added to ``x``.
            return_logits: If true, include the mask prediction in the
                output.
            return_feat: If true, include the features produced by the
                conv stack (taken before upsampling) in the output.

        Returns:
            ``[mask_pred, feat]`` when both flags are true; otherwise the
            single requested item, not wrapped in a list.

        Raises:
            IndexError: If both ``return_logits`` and ``return_feat`` are
                false, because the output list is empty when indexed.
        """
        if res_feat is not None:
            res_feat = self.conv_res(res_feat)
            x = x + res_feat
        for conv in self.convs:
            x = conv(x)
        # Keep the post-conv features; ``x`` is overwritten below when the
        # logits branch runs.
        res_feat = x
        outs = []
        if return_logits:
            x = self.upsample(x)
            # ReLU is applied only when the upsample method is 'deconv'.
            if self.upsample_method == 'deconv':
                x = self.relu(x)
            mask_pred = self.conv_logits(x)
            outs.append(mask_pred)
        if return_feat:
            outs.append(res_feat)
        return outs if len(outs) > 1 else outs[0]