from .fcn_mask_head import FCNMaskHead
from ..registry import HEADS
from ..utils import ConvModule


@HEADS.register_module
class HTCMaskHead(FCNMaskHead):
    """FCN mask head with an extra residual-feature input/output path.

    On top of ``FCNMaskHead`` this head owns ``conv_res``, a ``ConvModule``
    that maps ``conv_out_channels`` to ``conv_out_channels`` (kernel size
    argument ``1``). ``forward`` can fuse an incoming feature map through
    that module and can hand back the feature map produced by its own conv
    stack.
    """

    def __init__(self, *args, **kwargs):
        """Build the parent head, then add the residual-feature conv.

        All positional and keyword arguments are forwarded unchanged to
        ``FCNMaskHead.__init__``.
        """
        super(HTCMaskHead, self).__init__(*args, **kwargs)
        # Reuses the parent's channel / normalization / bias settings so the
        # extra conv is configured like the rest of the head.
        self.conv_res = ConvModule(
            self.conv_out_channels,
            self.conv_out_channels,
            1,
            normalize=self.normalize,
            bias=self.with_bias)

    def init_weights(self):
        """Initialize the parent's weights and then those of ``conv_res``."""
        super(HTCMaskHead, self).init_weights()
        self.conv_res.init_weights()

    def forward(self, x, res_feat=None, return_logits=True, return_feat=True):
        """Run the head, optionally fusing and/or returning features.

        Args:
            x: Input feature map fed to the conv stack.
            res_feat: Optional feature map. When given, it is passed through
                ``conv_res`` and added element-wise to ``x`` before the conv
                stack runs.
            return_logits: When true, include the mask prediction
                (upsample -> optional ReLU -> ``conv_logits``) in the output.
            return_feat: When true, include the output of the conv stack
                (taken before upsampling) in the output.

        Returns:
            ``[mask_pred, feat]`` when both flags are true; otherwise the
            single requested value, not wrapped in a list.

        Raises:
            ValueError: If both ``return_logits`` and ``return_feat`` are
                false, since there would be nothing to return.
        """
        # Fail fast with a clear message. Previously this case ran the whole
        # conv stack and then died on ``outs[0]`` with a bare IndexError.
        if not (return_logits or return_feat):
            raise ValueError(
                'At least one of return_logits / return_feat must be True')

        if res_feat is not None:
            res_feat = self.conv_res(res_feat)
            x = x + res_feat

        for conv in self.convs:
            x = conv(x)
        # Snapshot the conv-stack output before ``x`` is overwritten by the
        # upsampled tensor below; this is what ``return_feat`` hands back.
        res_feat = x

        outs = []
        if return_logits:
            # NOTE(review): assumes the parent always sets a callable
            # ``self.upsample``; confirm it can never be None here.
            x = self.upsample(x)
            if self.upsample_method == 'deconv':
                x = self.relu(x)
            mask_pred = self.conv_logits(x)
            outs.append(mask_pred)
        if return_feat:
            outs.append(res_feat)
        return outs if len(outs) > 1 else outs[0]