fcn_mask_head.py 6.81 KB
Newer Older
Kai Chen's avatar
Kai Chen committed
1
2
3
4
5
6
import mmcv
import numpy as np
import pycocotools.mask as mask_util
import torch
import torch.nn as nn

7
from mmdet.core import auto_fp16, force_fp32, mask_target
Jiangmiao Pang's avatar
Jiangmiao Pang committed
8
from ..builder import build_loss
Kai Chen's avatar
Kai Chen committed
9
from ..registry import HEADS
pangjm's avatar
pangjm committed
10
from ..utils import ConvModule
Kai Chen's avatar
Kai Chen committed
11
12


Kai Chen's avatar
Kai Chen committed
13
@HEADS.register_module
class FCNMaskHead(nn.Module):
    """Fully convolutional mask head.

    The head applies ``num_convs`` :class:`ConvModule` layers, an optional
    upsampling stage and a final 1x1 convolution that produces the mask
    logits.

    Args:
        num_convs (int): Number of ConvModule layers before upsampling.
        roi_feat_size (int): Stored on the instance but not used (reserved).
        in_channels (int): Number of channels of the input feature map.
        conv_kernel_size (int): Kernel size of every ConvModule. The padding
            is ``(conv_kernel_size - 1) // 2``.
        conv_out_channels (int): Output channels of every ConvModule and of
            the deconv layer.
        upsample_method (str | None): One of None, "deconv", "nearest" or
            "bilinear". None disables the upsampling stage.
        upsample_ratio (int): Kernel size and stride of the deconv layer, or
            the scale factor handed to ``nn.Upsample``.
        num_classes (int): Number of output channels when the head is not
            class agnostic. ``get_seg_masks`` returns ``num_classes - 1``
            result lists, so channel 0 is presumably background
            (NOTE(review): confirm against the dataset convention).
        class_agnostic (bool): If True, a single mask channel is predicted.
        conv_cfg (dict | None): Forwarded to every ConvModule.
        norm_cfg (dict | None): Forwarded to every ConvModule.
        loss_mask (dict): Config handed to ``build_loss``.

    Raises:
        ValueError: If ``upsample_method`` is not one of the accepted values.
    """

    def __init__(self,
                 num_convs=4,
                 roi_feat_size=14,
                 in_channels=256,
                 conv_kernel_size=3,
                 conv_out_channels=256,
                 upsample_method='deconv',
                 upsample_ratio=2,
                 num_classes=81,
                 class_agnostic=False,
                 conv_cfg=None,
                 norm_cfg=None,
                 loss_mask=dict(
                     type='CrossEntropyLoss', use_mask=True, loss_weight=1.0)):
        super(FCNMaskHead, self).__init__()
        if upsample_method not in [None, 'deconv', 'nearest', 'bilinear']:
            raise ValueError(
                'Invalid upsample method {}, accepted methods '
                'are "deconv", "nearest", "bilinear"'.format(upsample_method))
        self.num_convs = num_convs
        self.roi_feat_size = roi_feat_size  # WARN: not used and reserved
        self.in_channels = in_channels
        self.conv_kernel_size = conv_kernel_size
        self.conv_out_channels = conv_out_channels
        self.upsample_method = upsample_method
        self.upsample_ratio = upsample_ratio
        self.num_classes = num_classes
        self.class_agnostic = class_agnostic
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.fp16_enabled = False
        self.loss_mask = build_loss(loss_mask)

        # "same" padding for odd kernel sizes; identical for every layer.
        padding = (self.conv_kernel_size - 1) // 2
        self.convs = nn.ModuleList()
        for i in range(self.num_convs):
            # Only the first layer sees the raw input width.
            conv_in_channels = (
                self.in_channels if i == 0 else self.conv_out_channels)
            self.convs.append(
                ConvModule(
                    conv_in_channels,
                    self.conv_out_channels,
                    self.conv_kernel_size,
                    padding=padding,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg))

        # Without any conv layer the upsample stage consumes the input
        # feature map directly.
        upsample_in_channels = (
            self.conv_out_channels if self.num_convs > 0 else self.in_channels)
        if self.upsample_method is None:
            self.upsample = None
        elif self.upsample_method == 'deconv':
            self.upsample = nn.ConvTranspose2d(
                upsample_in_channels,
                self.conv_out_channels,
                self.upsample_ratio,
                stride=self.upsample_ratio)
        else:
            self.upsample = nn.Upsample(
                scale_factor=self.upsample_ratio, mode=self.upsample_method)

        out_channels = 1 if self.class_agnostic else self.num_classes
        # Only the deconv changes the channel count; interpolation (or no
        # upsampling at all) keeps the width of its input.
        logits_in_channel = (
            self.conv_out_channels
            if self.upsample_method == 'deconv' else upsample_in_channels)
        self.conv_logits = nn.Conv2d(logits_in_channel, out_channels, 1)
        self.relu = nn.ReLU(inplace=True)
        self.debug_imgs = None

    def init_weights(self):
        """Initialize the upsample layer (if learnable) and the logits conv.

        Weights use Kaiming-normal initialization (``fan_out``, ``relu``) and
        biases are set to zero. Modules without parameters, i.e.
        ``nn.Upsample`` for the "nearest" and "bilinear" methods, are skipped.
        """
        for m in [self.upsample, self.conv_logits]:
            if m is None:
                continue
            # nn.Upsample has neither ``weight`` nor ``bias``; touching them
            # would raise AttributeError.
            if getattr(m, 'weight', None) is None:
                continue
            nn.init.kaiming_normal_(
                m.weight, mode='fan_out', nonlinearity='relu')
            if getattr(m, 'bias', None) is not None:
                nn.init.constant_(m.bias, 0)

    @auto_fp16()
    def forward(self, x):
        """Compute mask logits for the given features.

        Args:
            x (Tensor): Input features with ``in_channels`` channels.

        Returns:
            Tensor: Mask logits with 1 channel (class agnostic) or
                ``num_classes`` channels.
        """
        for conv in self.convs:
            x = conv(x)
        if self.upsample is not None:
            x = self.upsample(x)
            # Interpolation needs no non-linearity, the learned deconv does.
            if self.upsample_method == 'deconv':
                x = self.relu(x)
        mask_pred = self.conv_logits(x)
        return mask_pred

    def get_target(self, sampling_results, gt_masks, rcnn_train_cfg):
        """Build mask targets for the positive proposals.

        Args:
            sampling_results (list): Per-image sampling results; their
                ``pos_bboxes`` and ``pos_assigned_gt_inds`` are read.
            gt_masks: Ground-truth masks, forwarded to ``mask_target``.
            rcnn_train_cfg: Training config, forwarded to ``mask_target``.

        Returns:
            The value returned by ``mask_target``.
        """
        pos_proposals = [res.pos_bboxes for res in sampling_results]
        pos_assigned_gt_inds = [
            res.pos_assigned_gt_inds for res in sampling_results
        ]
        mask_targets = mask_target(pos_proposals, pos_assigned_gt_inds,
                                   gt_masks, rcnn_train_cfg)
        return mask_targets

    @force_fp32(apply_to=('mask_pred', ))
    def loss(self, mask_pred, mask_targets, labels):
        """Compute the mask loss.

        Args:
            mask_pred (Tensor): Predicted mask logits.
            mask_targets (Tensor): Targets produced by ``get_target``.
            labels (Tensor): Class label of every prediction.

        Returns:
            dict: ``{'loss_mask': <loss value>}``.
        """
        loss = dict()
        if self.class_agnostic:
            # A class-agnostic head has a single channel, so every sample is
            # scored against channel 0.
            loss_mask = self.loss_mask(mask_pred, mask_targets,
                                       torch.zeros_like(labels))
        else:
            loss_mask = self.loss_mask(mask_pred, mask_targets, labels)
        loss['loss_mask'] = loss_mask
        return loss

    def get_seg_masks(self, mask_pred, det_bboxes, det_labels, rcnn_test_cfg,
                      ori_shape, scale_factor, rescale):
        """Get segmentation masks from mask_pred and bboxes.

        Args:
            mask_pred (Tensor or ndarray): shape (n, #class+1, h, w).
                For single-scale testing, mask_pred is the direct output of
                model, whose type is Tensor, while for multi-scale testing,
                it will be converted to numpy array outside of this method.
            det_bboxes (Tensor): shape (n, 4/5)
            det_labels (Tensor): shape (n, )
            rcnn_test_cfg (dict): rcnn testing config; its
                ``mask_thr_binary`` attribute is the binarization threshold.
            ori_shape: original image size; the first two entries are read
                as (height, width).
            scale_factor (float): scale between the original image and the
                image the boxes refer to.
            rescale (bool): if True, boxes are divided by ``scale_factor``
                and masks are pasted onto a canvas of ``ori_shape``;
                otherwise the canvas is ``ori_shape * scale_factor`` and
                the boxes are used unchanged.

        Returns:
            list[list]: encoded masks, one list of RLEs per class
                (``num_classes - 1`` lists in total).
        """
        if isinstance(mask_pred, torch.Tensor):
            mask_pred = mask_pred.sigmoid().cpu().numpy()
        assert isinstance(mask_pred, np.ndarray)
        # when enabling mixed precision training, mask_pred may be float16
        # numpy array
        mask_pred = mask_pred.astype(np.float32)

        cls_segms = [[] for _ in range(self.num_classes - 1)]
        bboxes = det_bboxes.cpu().numpy()[:, :4]
        # Shift labels by one so they index the mask channels past channel 0.
        labels = det_labels.cpu().numpy() + 1

        if rescale:
            img_h, img_w = ori_shape[:2]
        else:
            img_h = np.round(ori_shape[0] * scale_factor).astype(np.int32)
            img_w = np.round(ori_shape[1] * scale_factor).astype(np.int32)
            # Boxes already live in the scaled image; do not rescale them.
            scale_factor = 1.0

        for i in range(bboxes.shape[0]):
            bbox = (bboxes[i, :] / scale_factor).astype(np.int32)
            label = labels[i]
            # Degenerate boxes still get a mask of at least 1x1 pixels.
            w = max(bbox[2] - bbox[0] + 1, 1)
            h = max(bbox[3] - bbox[1] + 1, 1)

            if not self.class_agnostic:
                mask_pred_ = mask_pred[i, label, :, :]
            else:
                mask_pred_ = mask_pred[i, 0, :, :]
            im_mask = np.zeros((img_h, img_w), dtype=np.uint8)

            bbox_mask = mmcv.imresize(mask_pred_, (w, h))
            bbox_mask = (bbox_mask > rcnn_test_cfg.mask_thr_binary).astype(
                np.uint8)
            # NOTE(review): assumes the box lies fully inside the canvas;
            # otherwise the slice is shorter than bbox_mask.
            im_mask[bbox[1]:bbox[1] + h, bbox[0]:bbox[0] + w] = bbox_mask
            # pycocotools expects a Fortran-ordered (h, w, 1) uint8 array.
            rle = mask_util.encode(
                np.array(im_mask[:, :, np.newaxis], order='F'))[0]
            cls_segms[label - 1].append(rle)

        return cls_segms