fcn_mask_head.py 6.14 KB
Newer Older
Kai Chen's avatar
Kai Chen committed
1
2
3
4
5
6
import mmcv
import numpy as np
import pycocotools.mask as mask_util
import torch
import torch.nn as nn

Kai Chen's avatar
Kai Chen committed
7
from ..registry import HEADS
pangjm's avatar
pangjm committed
8
9
from ..utils import ConvModule
from mmdet.core import mask_cross_entropy, mask_target
Kai Chen's avatar
Kai Chen committed
10
11


Kai Chen's avatar
Kai Chen committed
12
@HEADS.register_module
class FCNMaskHead(nn.Module):
    """Fully convolutional mask head.

    Runs ``num_convs`` stacked ConvModules over RoI features, optionally
    upsamples the result, and predicts per-pixel mask logits with a 1x1 conv.

    Args:
        num_convs (int): Number of stacked ConvModules.
        roi_feat_size (int): Stored only; not used by this class.
        in_channels (int): Channels of the input RoI features.
        conv_kernel_size (int): Kernel size of every ConvModule. Padding is
            ``(conv_kernel_size - 1) // 2``, so odd kernel sizes preserve the
            spatial size (assuming ConvModule's stride is 1).
        conv_out_channels (int): Output channels of each ConvModule and of
            the deconv upsampling layer.
        upsample_method (str | None): One of None, 'deconv', 'nearest',
            'bilinear'. None disables upsampling.
        upsample_ratio (int): Upsampling factor; for 'deconv' it is also the
            transposed conv's kernel size and stride.
        num_classes (int): Number of mask output channels when not
            class-agnostic. Channel 0 appears to be reserved for background
            (``get_seg_masks`` shifts labels by +1 and returns
            ``num_classes - 1`` per-class lists).
        class_agnostic (bool): If True, predict a single mask channel that
            is shared by all classes.
        normalize: Normalization config forwarded to ConvModule. When it is
            None the convs use a bias; otherwise they are built without one.
    """

    def __init__(self,
                 num_convs=4,
                 roi_feat_size=14,
                 in_channels=256,
                 conv_kernel_size=3,
                 conv_out_channels=256,
                 upsample_method='deconv',
                 upsample_ratio=2,
                 num_classes=81,
                 class_agnostic=False,
                 normalize=None):
        super(FCNMaskHead, self).__init__()
        if upsample_method not in [None, 'deconv', 'nearest', 'bilinear']:
            raise ValueError(
                'Invalid upsample method {}, accepted methods '
                'are "deconv", "nearest", "bilinear"'.format(upsample_method))
        self.num_convs = num_convs
        self.roi_feat_size = roi_feat_size  # WARN: not used and reserved
        self.in_channels = in_channels
        self.conv_kernel_size = conv_kernel_size
        self.conv_out_channels = conv_out_channels
        self.upsample_method = upsample_method
        self.upsample_ratio = upsample_ratio
        self.num_classes = num_classes
        self.class_agnostic = class_agnostic
        self.normalize = normalize
        # A normalization layer after the conv makes a conv bias redundant.
        self.with_bias = normalize is None

        self.convs = nn.ModuleList()
        # "Same" padding for the configured kernel size.
        padding = (self.conv_kernel_size - 1) // 2
        for i in range(self.num_convs):
            in_channels = (self.in_channels
                           if i == 0 else self.conv_out_channels)
            self.convs.append(
                ConvModule(
                    in_channels,
                    self.conv_out_channels,
                    # Must match the kernel size the padding was derived from;
                    # a hard-coded 3 ignored ``conv_kernel_size``.
                    self.conv_kernel_size,
                    padding=padding,
                    normalize=normalize,
                    bias=self.with_bias))
        if self.upsample_method is None:
            self.upsample = None
        elif self.upsample_method == 'deconv':
            self.upsample = nn.ConvTranspose2d(
                self.conv_out_channels,
                self.conv_out_channels,
                self.upsample_ratio,
                stride=self.upsample_ratio)
        else:
            self.upsample = nn.Upsample(
                scale_factor=self.upsample_ratio, mode=self.upsample_method)

        out_channels = 1 if self.class_agnostic else self.num_classes
        self.conv_logits = nn.Conv2d(self.conv_out_channels, out_channels, 1)
        self.relu = nn.ReLU(inplace=True)
        self.debug_imgs = None

    def init_weights(self):
        """Kaiming-init the upsampling layer (if learnable) and the logits conv.

        Biases are set to zero. ``self.convs`` is not touched here.
        """
        for m in [self.upsample, self.conv_logits]:
            # Skip a disabled upsampler and parameter-free ones such as
            # nn.Upsample ('nearest' / 'bilinear'), which have no ``weight``.
            if m is None or getattr(m, 'weight', None) is None:
                continue
            nn.init.kaiming_normal_(
                m.weight, mode='fan_out', nonlinearity='relu')
            if getattr(m, 'bias', None) is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Predict mask logits for a batch of RoI features.

        Args:
            x (Tensor): RoI features with ``in_channels`` channels
                (presumably shaped (n, in_channels, h, w) — TODO confirm).

        Returns:
            Tensor: Raw mask logits (no sigmoid applied) with 1 channel when
            class-agnostic, else ``num_classes`` channels. Spatial size is
            scaled by ``upsample_ratio`` when upsampling is enabled.
        """
        for conv in self.convs:
            x = conv(x)
        if self.upsample is not None:
            x = self.upsample(x)
            # Only the learnable deconv is followed by a non-linearity.
            if self.upsample_method == 'deconv':
                x = self.relu(x)
        mask_pred = self.conv_logits(x)
        return mask_pred

    def get_target(self, sampling_results, gt_masks, rcnn_train_cfg):
        """Build mask targets for the positive proposals of each image.

        Args:
            sampling_results (list): Per-image sampling results, each
                exposing ``pos_bboxes`` and ``pos_assigned_gt_inds``.
            gt_masks (list): Ground-truth masks, forwarded to ``mask_target``.
            rcnn_train_cfg: R-CNN training config, forwarded to
                ``mask_target``.

        Returns:
            Whatever ``mask_target`` returns (presumably per-proposal binary
            mask targets — confirm against mmdet.core.mask_target).
        """
        pos_proposals = [res.pos_bboxes for res in sampling_results]
        pos_assigned_gt_inds = [
            res.pos_assigned_gt_inds for res in sampling_results
        ]
        mask_targets = mask_target(pos_proposals, pos_assigned_gt_inds,
                                   gt_masks, rcnn_train_cfg)
        return mask_targets

    def loss(self, mask_pred, mask_targets, labels):
        """Compute the mask loss.

        Args:
            mask_pred (Tensor): Logits from ``forward``.
            mask_targets (Tensor): Targets from ``get_target``.
            labels (Tensor): Class label of each positive RoI; selects the
                predicted channel. Class-agnostic heads always use channel 0.

        Returns:
            dict: ``{'loss_mask': <loss from mask_cross_entropy>}``.
        """
        loss = dict()
        if self.class_agnostic:
            # Single output channel: index it with 0 for every RoI.
            loss_mask = mask_cross_entropy(mask_pred, mask_targets,
                                           torch.zeros_like(labels))
        else:
            loss_mask = mask_cross_entropy(mask_pred, mask_targets, labels)
        loss['loss_mask'] = loss_mask
        return loss

    @staticmethod
    def _paste_mask(im_mask, bbox_mask, x0, y0):
        """Paste ``bbox_mask`` into ``im_mask`` (in place) at top-left (x0, y0).

        Any part of ``bbox_mask`` that falls outside ``im_mask`` is clipped.
        For a box fully inside the canvas this equals the plain slice
        assignment ``im_mask[y0:y0 + h, x0:x0 + w] = bbox_mask``.
        """
        img_h, img_w = im_mask.shape[:2]
        mask_h, mask_w = bbox_mask.shape[:2]
        x0, y0 = int(x0), int(y0)
        dst_x0, dst_y0 = max(x0, 0), max(y0, 0)
        dst_x1 = min(x0 + mask_w, img_w)
        dst_y1 = min(y0 + mask_h, img_h)
        if dst_x1 <= dst_x0 or dst_y1 <= dst_y0:
            return  # box lies entirely outside the canvas
        im_mask[dst_y0:dst_y1, dst_x0:dst_x1] = bbox_mask[
            dst_y0 - y0:dst_y1 - y0, dst_x0 - x0:dst_x1 - x0]

    def get_seg_masks(self, mask_pred, det_bboxes, det_labels, rcnn_test_cfg,
                      ori_shape, scale_factor, rescale):
        """Get segmentation masks from mask_pred and bboxes.

        Args:
            mask_pred (Tensor or ndarray): shape (n, num_classes, h, w), or
                (n, 1, h, w) when class-agnostic. For single-scale testing it
                is the raw model output (Tensor, logits; sigmoid is applied
                here). For multi-scale testing it is converted to a numpy
                array of probabilities outside of this method.
            det_bboxes (Tensor): shape (n, 4/5); the first four columns are
                used as (x1, y1, x2, y2).
            det_labels (Tensor): shape (n, ), labels without the background
                offset (shifted by +1 to index ``mask_pred`` channels).
            rcnn_test_cfg: rcnn testing config; ``mask_thr_binary`` is the
                probability threshold for binarizing masks.
            ori_shape (tuple): original image shape, (height, width, ...).
            scale_factor (float): factor between the original image and the
                image the boxes are expressed in.
            rescale (bool): if True, boxes are divided by ``scale_factor`` and
                masks are drawn on an ``ori_shape`` canvas; otherwise boxes
                are used as-is on a canvas of ``ori_shape * scale_factor``.

        Returns:
            list[list]: ``num_classes - 1`` lists of encoded (RLE) masks,
            one list per foreground class.
        """
        if isinstance(mask_pred, torch.Tensor):
            mask_pred = mask_pred.sigmoid().cpu().numpy()
        assert isinstance(mask_pred, np.ndarray)

        cls_segms = [[] for _ in range(self.num_classes - 1)]
        bboxes = det_bboxes.cpu().numpy()[:, :4]
        labels = det_labels.cpu().numpy() + 1

        if rescale:
            img_h, img_w = ori_shape[:2]
        else:
            img_h = np.round(ori_shape[0] * scale_factor).astype(np.int32)
            img_w = np.round(ori_shape[1] * scale_factor).astype(np.int32)
            scale_factor = 1.0

        for i in range(bboxes.shape[0]):
            bbox = (bboxes[i, :] / scale_factor).astype(np.int32)
            label = labels[i]
            # Inclusive pixel extents, at least one pixel each.
            w = max(bbox[2] - bbox[0] + 1, 1)
            h = max(bbox[3] - bbox[1] + 1, 1)

            channel = 0 if self.class_agnostic else label
            mask_pred_ = mask_pred[i, channel, :, :]

            bbox_mask = mmcv.imresize(mask_pred_, (w, h))
            bbox_mask = (bbox_mask > rcnn_test_cfg.mask_thr_binary).astype(
                np.uint8)
            # Each detection is encoded as a full-canvas mask.
            im_mask = np.zeros((img_h, img_w), dtype=np.uint8)
            # Clipped paste: boxes may extend past the canvas after int
            # truncation and the +1 extent, which would break plain slicing.
            self._paste_mask(im_mask, bbox_mask, bbox[0], bbox[1])

            # pycocotools expects a Fortran-ordered (h, w, N) uint8 array.
            rle = mask_util.encode(
                np.array(im_mask[:, :, np.newaxis], order='F'))[0]
            cls_segms[label - 1].append(rle)

        return cls_segms