# fcn_mask_head.py -- defines FCNMaskHead.
import mmcv
import numpy as np
import pycocotools.mask as mask_util
import torch
import torch.nn as nn

from ..utils import ConvModule
from mmdet.core import mask_cross_entropy, mask_target


class FCNMaskHead(nn.Module):
    """Fully-convolutional mask head.

    A stack of ``num_convs`` conv layers, an optional upsampling layer and a
    1x1 conv producing per-class (or class-agnostic) mask logits.

    Args:
        num_convs (int): number of conv layers before upsampling.
        roi_feat_size (int): RoI feature size; stored but not used.
        in_channels (int): channels of the input features.
        conv_kernel_size (int): kernel size of each conv layer. Padding is
            ``(k - 1) // 2``, so spatial size is preserved for odd ``k``.
        conv_out_channels (int): output channels of each conv layer.
        upsample_method (str or None): one of None, 'deconv', 'nearest',
            'bilinear'.
        upsample_ratio (int): upsampling scale factor (also the deconv
            kernel size and stride).
        num_classes (int): number of mask channels when not class-agnostic;
            ``get_seg_masks`` returns ``num_classes - 1`` class lists and
            shifts detection labels by +1 to index mask channels.
        class_agnostic (bool): predict a single mask channel for all classes.
        normalize: normalization config forwarded to ``ConvModule``
            (presumably a dict; None disables it). When set, the conv layers
            are built without bias.
    """

    def __init__(self,
                 num_convs=4,
                 roi_feat_size=14,
                 in_channels=256,
                 conv_kernel_size=3,
                 conv_out_channels=256,
                 upsample_method='deconv',
                 upsample_ratio=2,
                 num_classes=81,
                 class_agnostic=False,
                 normalize=None):
        super(FCNMaskHead, self).__init__()
        if upsample_method not in [None, 'deconv', 'nearest', 'bilinear']:
            raise ValueError(
                'Invalid upsample method {}, accepted methods '
                'are "deconv", "nearest", "bilinear"'.format(upsample_method))
        self.num_convs = num_convs
        self.roi_feat_size = roi_feat_size  # WARN: not used and reserved
        self.in_channels = in_channels
        self.conv_kernel_size = conv_kernel_size
        self.conv_out_channels = conv_out_channels
        self.upsample_method = upsample_method
        self.upsample_ratio = upsample_ratio
        self.num_classes = num_classes
        self.class_agnostic = class_agnostic
        self.normalize = normalize
        # A normalization layer after the conv makes a conv bias redundant.
        self.with_bias = normalize is None

        self.convs = nn.ModuleList()
        padding = (self.conv_kernel_size - 1) // 2
        for i in range(self.num_convs):
            in_channels = (self.in_channels
                           if i == 0 else self.conv_out_channels)
            self.convs.append(
                ConvModule(
                    in_channels,
                    self.conv_out_channels,
                    # Use the configured kernel size; a hard-coded 3 here
                    # would disagree with the padding computed above.
                    self.conv_kernel_size,
                    padding=padding,
                    normalize=normalize,
                    bias=self.with_bias))
        if self.upsample_method is None:
            self.upsample = None
        elif self.upsample_method == 'deconv':
            self.upsample = nn.ConvTranspose2d(
                self.conv_out_channels,
                self.conv_out_channels,
                self.upsample_ratio,
                stride=self.upsample_ratio)
        else:
            self.upsample = nn.Upsample(
                scale_factor=self.upsample_ratio, mode=self.upsample_method)

        out_channels = 1 if self.class_agnostic else self.num_classes
        self.conv_logits = nn.Conv2d(self.conv_out_channels, out_channels, 1)
        self.relu = nn.ReLU(inplace=True)
        self.debug_imgs = None

    def init_weights(self):
        """Kaiming-initialize the upsampling deconv and the logits conv.

        Biases are set to zero. Parameter-free upsampling layers
        (``nn.Upsample``) are skipped. The layers in ``self.convs`` are not
        touched here.
        """
        for m in [self.upsample, self.conv_logits]:
            # nn.Upsample has no weight/bias, so initializing it would raise.
            if m is None or not hasattr(m, 'weight'):
                continue
            nn.init.kaiming_normal_(
                m.weight, mode='fan_out', nonlinearity='relu')
            nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Compute mask logits.

        Args:
            x (Tensor): input features with ``in_channels`` channels
                (presumably (n, C, h, w) RoI features -- confirm with caller).

        Returns:
            Tensor: mask logits with ``num_classes`` channels (1 if
            class-agnostic), upsampled by ``upsample_ratio`` when an
            upsampling layer is configured.
        """
        for conv in self.convs:
            x = conv(x)
        if self.upsample is not None:
            x = self.upsample(x)
            # Only the learnable deconv is followed by a non-linearity.
            if self.upsample_method == 'deconv':
                x = self.relu(x)
        mask_pred = self.conv_logits(x)
        return mask_pred

    def get_mask_target(self, pos_proposals, pos_assigned_gt_inds, gt_masks,
                        img_meta, rcnn_train_cfg):
        """Build mask training targets; thin wrapper around ``mask_target``."""
        mask_targets = mask_target(pos_proposals, pos_assigned_gt_inds,
                                   gt_masks, img_meta, rcnn_train_cfg)
        return mask_targets

    def loss(self, mask_pred, mask_targets, labels):
        """Compute the mask loss.

        Returns:
            dict: ``{'loss_mask': mask_cross_entropy(mask_pred, mask_targets,
            labels)}``.
        """
        loss = dict()
        loss_mask = mask_cross_entropy(mask_pred, mask_targets, labels)
        loss['loss_mask'] = loss_mask
        return loss

    def get_seg_masks(self, mask_pred, det_bboxes, det_labels, rcnn_test_cfg,
                      ori_shape):
        """Get segmentation masks from mask_pred and bboxes.

        Args:
            mask_pred (Tensor or ndarray): shape (n, #class+1, h, w).
                For single-scale testing, mask_pred is the direct output of
                model, whose type is Tensor, while for multi-scale testing,
                it will be converted to numpy array outside of this method.
                A Tensor is passed through sigmoid here; an ndarray is used
                as-is.
            det_bboxes (Tensor): shape (n, 4/5); only the first 4 columns
                are used.
            det_labels (Tensor): shape (n, ); shifted by +1 to index the
                mask channels.
            rcnn_test_cfg: rcnn testing config; ``mask_thr_binary`` is read
                as an attribute and used as the binarization threshold.
            ori_shape: original image size; only the first two entries
                (h, w) are used.

        Returns:
            list[list]: RLE-encoded masks, one list per foreground class
            (``num_classes - 1`` lists).
        """
        if isinstance(mask_pred, torch.Tensor):
            mask_pred = mask_pred.sigmoid().cpu().numpy()
        assert isinstance(mask_pred, np.ndarray)

        cls_segms = [[] for _ in range(self.num_classes - 1)]
        mask_size = mask_pred.shape[-1]
        bboxes = det_bboxes.cpu().numpy()[:, :4]
        labels = det_labels.cpu().numpy() + 1
        img_h = ori_shape[0]
        img_w = ori_shape[1]

        # The mask is pasted inside a 1-pixel zero border, so each box is
        # enlarged by the same ratio to keep mask and box aligned.
        scale = (mask_size + 2.0) / mask_size
        bboxes = np.round(self._bbox_scaling(bboxes, scale)).astype(np.int32)
        padded_mask = np.zeros(
            (mask_size + 2, mask_size + 2), dtype=np.float32)

        for i in range(bboxes.shape[0]):
            bbox = bboxes[i, :].astype(int)
            label = labels[i]
            w = max(bbox[2] - bbox[0] + 1, 1)
            h = max(bbox[3] - bbox[1] + 1, 1)

            # Only the interior is overwritten; the zero border is shared
            # across iterations.
            channel = 0 if self.class_agnostic else label
            padded_mask[1:-1, 1:-1] = mask_pred[i, channel, :, :]
            mask = mmcv.imresize(padded_mask, (w, h))
            mask = np.array(
                mask > rcnn_test_cfg.mask_thr_binary, dtype=np.uint8)
            im_mask = np.zeros((img_h, img_w), dtype=np.uint8)

            x0 = max(bbox[0], 0)
            x1 = min(bbox[2] + 1, img_w)
            y0 = max(bbox[1], 0)
            y1 = min(bbox[3] + 1, img_h)
            # Degenerate boxes or boxes fully outside the image have nothing
            # to paste. Slicing with the resulting negative offsets would
            # wrap around and raise a shape mismatch.
            if x1 > x0 and y1 > y0:
                im_mask[y0:y1, x0:x1] = mask[(y0 - bbox[1]):(y1 - bbox[1]),
                                             (x0 - bbox[0]):(x1 - bbox[0])]

            rle = mask_util.encode(
                np.array(im_mask[:, :, np.newaxis], order='F'))[0]
            cls_segms[label - 1].append(rle)
        return cls_segms

    def _bbox_scaling(self, bboxes, scale, clip_shape=None):
        """Scale bboxes about their centers and optionally clip them.

        Args:
            bboxes (ndarray): shape (..., 4), (x1, y1, x2, y2) with inclusive
                pixel coordinates (width is x2 - x1 + 1).
            scale (float): scaling factor.
            clip_shape (None or tuple): (h, w). When given, x coordinates are
                clipped to [0, w - 1] and y coordinates to [0, h - 1].

        Returns:
            ndarray: scaled (and optionally clipped) bboxes; the input is
            never modified.
        """
        if float(scale) == 1.0:
            scaled_bboxes = bboxes.copy()
        else:
            w = bboxes[..., 2] - bboxes[..., 0] + 1
            h = bboxes[..., 3] - bboxes[..., 1] + 1
            dw = (w * (scale - 1)) * 0.5
            dh = (h * (scale - 1)) * 0.5
            scaled_bboxes = bboxes + np.stack((-dw, -dh, dw, dh), axis=-1)
        if clip_shape is None:
            return scaled_bboxes
        # Clip inline: the `bbox_clip` helper this used to call is not
        # defined or imported in this module.
        scaled_bboxes[..., 0::2] = np.clip(scaled_bboxes[..., 0::2], 0,
                                           clip_shape[1] - 1)
        scaled_bboxes[..., 1::2] = np.clip(scaled_bboxes[..., 1::2], 0,
                                           clip_shape[0] - 1)
        return scaled_bboxes