fcn_mask_head.py 6.88 KB
Newer Older
Kai Chen's avatar
Kai Chen committed
1
2
3
4
5
import mmcv
import numpy as np
import pycocotools.mask as mask_util
import torch
import torch.nn as nn
6
from torch.nn.modules.utils import _pair
Kai Chen's avatar
Kai Chen committed
7

8
from mmdet.core import auto_fp16, force_fp32, mask_target
Jiangmiao Pang's avatar
Jiangmiao Pang committed
9
from ..builder import build_loss
Kai Chen's avatar
Kai Chen committed
10
from ..registry import HEADS
pangjm's avatar
pangjm committed
11
from ..utils import ConvModule
Kai Chen's avatar
Kai Chen committed
12
13


Kai Chen's avatar
Kai Chen committed
14
@HEADS.register_module
class FCNMaskHead(nn.Module):
    """FCN-style mask head used in Mask R-CNN-like detectors.

    A stack of ``num_convs`` conv layers, an optional upsampling layer
    (deconv / nearest / bilinear) and a final 1x1 conv producing one mask
    logit map per class (or a single shared map if ``class_agnostic``).

    Args:
        num_convs (int): Number of conv layers before upsampling.
        roi_feat_size (int | tuple[int, int]): Reserved and currently unused.
        in_channels (int): Number of input feature channels.
        conv_kernel_size (int): Kernel size of the conv layers.
        conv_out_channels (int): Output channels of each conv layer.
        upsample_method (str | None): One of ``None``, ``'deconv'``,
            ``'nearest'``, ``'bilinear'``.
        upsample_ratio (int): Scale factor of the upsampling layer.
        num_classes (int): Number of classes including background.
        class_agnostic (bool): If True, predict a single mask shared by all
            classes instead of one mask per class.
        conv_cfg (dict | None): Config dict for convolution layers.
        norm_cfg (dict | None): Config dict for normalization layers.
        loss_mask (dict): Config dict of the mask loss.
    """

    def __init__(self,
                 num_convs=4,
                 roi_feat_size=14,
                 in_channels=256,
                 conv_kernel_size=3,
                 conv_out_channels=256,
                 upsample_method='deconv',
                 upsample_ratio=2,
                 num_classes=81,
                 class_agnostic=False,
                 conv_cfg=None,
                 norm_cfg=None,
                 loss_mask=dict(
                     type='CrossEntropyLoss', use_mask=True, loss_weight=1.0)):
        super(FCNMaskHead, self).__init__()
        if upsample_method not in [None, 'deconv', 'nearest', 'bilinear']:
            raise ValueError(
                'Invalid upsample method {}, accepted methods '
                'are "deconv", "nearest", "bilinear"'.format(upsample_method))
        self.num_convs = num_convs
        # WARN: roi_feat_size is reserved and not used
        self.roi_feat_size = _pair(roi_feat_size)
        self.in_channels = in_channels
        self.conv_kernel_size = conv_kernel_size
        self.conv_out_channels = conv_out_channels
        self.upsample_method = upsample_method
        self.upsample_ratio = upsample_ratio
        self.num_classes = num_classes
        self.class_agnostic = class_agnostic
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.fp16_enabled = False
        self.loss_mask = build_loss(loss_mask)

        self.convs = nn.ModuleList()
        for i in range(self.num_convs):
            # first conv takes the RoI feature channels, the rest chain on
            # conv_out_channels
            in_channels = (
                self.in_channels if i == 0 else self.conv_out_channels)
            padding = (self.conv_kernel_size - 1) // 2
            self.convs.append(
                ConvModule(
                    in_channels,
                    self.conv_out_channels,
                    self.conv_kernel_size,
                    padding=padding,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg))
        upsample_in_channels = (
            self.conv_out_channels if self.num_convs > 0 else in_channels)
        if self.upsample_method is None:
            self.upsample = None
        elif self.upsample_method == 'deconv':
            self.upsample = nn.ConvTranspose2d(
                upsample_in_channels,
                self.conv_out_channels,
                self.upsample_ratio,
                stride=self.upsample_ratio)
        else:
            self.upsample = nn.Upsample(
                scale_factor=self.upsample_ratio, mode=self.upsample_method)

        out_channels = 1 if self.class_agnostic else self.num_classes
        # only a deconv changes the channel count before the logits conv
        logits_in_channel = (
            self.conv_out_channels
            if self.upsample_method == 'deconv' else upsample_in_channels)
        self.conv_logits = nn.Conv2d(logits_in_channel, out_channels, 1)
        self.relu = nn.ReLU(inplace=True)
        self.debug_imgs = None

    def init_weights(self):
        """Initialize the upsample/logits layers with Kaiming init.

        ``nn.Upsample`` (nearest/bilinear) carries no parameters, so only
        modules that actually have a ``weight`` are initialized; previously
        this raised ``AttributeError`` for non-deconv upsampling.
        """
        for m in [self.upsample, self.conv_logits]:
            if m is None or not hasattr(m, 'weight'):
                continue
            nn.init.kaiming_normal_(
                m.weight, mode='fan_out', nonlinearity='relu')
            nn.init.constant_(m.bias, 0)

    @auto_fp16()
    def forward(self, x):
        """Forward pass: convs -> optional upsample -> 1x1 logits conv.

        Args:
            x (Tensor): RoI features, shape (n, in_channels, h, w).

        Returns:
            Tensor: mask logits of shape (n, num_classes, h', w') or
                (n, 1, h', w') if class-agnostic.
        """
        for conv in self.convs:
            x = conv(x)
        if self.upsample is not None:
            x = self.upsample(x)
            if self.upsample_method == 'deconv':
                # deconv is a learned layer and is followed by a ReLU;
                # interpolation modes are not
                x = self.relu(x)
        mask_pred = self.conv_logits(x)
        return mask_pred

    def get_target(self, sampling_results, gt_masks, rcnn_train_cfg):
        """Compute mask targets for the positive proposals of each image.

        Args:
            sampling_results (list): per-image sampling results providing
                ``pos_bboxes`` and ``pos_assigned_gt_inds``.
            gt_masks (list): per-image ground-truth masks.
            rcnn_train_cfg (dict): rcnn training config.

        Returns:
            Tensor: mask targets produced by :func:`mask_target`.
        """
        pos_proposals = [res.pos_bboxes for res in sampling_results]
        pos_assigned_gt_inds = [
            res.pos_assigned_gt_inds for res in sampling_results
        ]
        mask_targets = mask_target(pos_proposals, pos_assigned_gt_inds,
                                   gt_masks, rcnn_train_cfg)
        return mask_targets

    @force_fp32(apply_to=('mask_pred', ))
    def loss(self, mask_pred, mask_targets, labels):
        """Compute the mask loss.

        In the class-agnostic case every sample is scored against channel 0,
        hence the ``zeros_like(labels)`` class index.

        Returns:
            dict: single-entry dict with key ``'loss_mask'``.
        """
        loss = dict()
        if self.class_agnostic:
            loss_mask = self.loss_mask(mask_pred, mask_targets,
                                       torch.zeros_like(labels))
        else:
            loss_mask = self.loss_mask(mask_pred, mask_targets, labels)
        loss['loss_mask'] = loss_mask
        return loss

    def get_seg_masks(self, mask_pred, det_bboxes, det_labels, rcnn_test_cfg,
                      ori_shape, scale_factor, rescale):
        """Get segmentation masks from mask_pred and bboxes.

        Args:
            mask_pred (Tensor or ndarray): shape (n, #class+1, h, w).
                For single-scale testing, mask_pred is the direct output of
                model, whose type is Tensor, while for multi-scale testing,
                it will be converted to numpy array outside of this method.
            det_bboxes (Tensor): shape (n, 4/5)
            det_labels (Tensor): shape (n, )
            rcnn_test_cfg (dict): rcnn testing config, providing
                ``mask_thr_binary``
            ori_shape: original image size
            scale_factor (float): scale between the original image and the
                network input; boxes are divided by it when ``rescale`` is
                False
            rescale (bool): if True, paste masks on the original-size
                canvas, otherwise on the rescaled one

        Returns:
            list[list]: encoded masks, indexed by class (background
                excluded)
        """
        if isinstance(mask_pred, torch.Tensor):
            mask_pred = mask_pred.sigmoid().cpu().numpy()
        assert isinstance(mask_pred, np.ndarray)
        # when enabling mixed precision training, mask_pred may be float16
        # numpy array
        mask_pred = mask_pred.astype(np.float32)

        cls_segms = [[] for _ in range(self.num_classes - 1)]
        bboxes = det_bboxes.cpu().numpy()[:, :4]
        # shift labels by one: channel 0 of mask_pred is the background
        labels = det_labels.cpu().numpy() + 1

        if rescale:
            img_h, img_w = ori_shape[:2]
        else:
            img_h = np.round(ori_shape[0] * scale_factor).astype(np.int32)
            img_w = np.round(ori_shape[1] * scale_factor).astype(np.int32)
            scale_factor = 1.0

        for i in range(bboxes.shape[0]):
            bbox = (bboxes[i, :] / scale_factor).astype(np.int32)
            label = labels[i]
            # clamp the box to the image canvas; without this a box whose
            # edge rounds onto/past the border makes the numpy slice below
            # silently shrink and the (h, w) assignment fail with a shape
            # mismatch (and a negative coordinate would wrap around).
            # For in-bounds boxes this is identical to the original
            # w = bbox[2] - bbox[0] + 1, h = bbox[3] - bbox[1] + 1.
            x0 = min(max(bbox[0], 0), max(img_w - 1, 0))
            y0 = min(max(bbox[1], 0), max(img_h - 1, 0))
            w = max(min(bbox[2] + 1, img_w) - x0, 1)
            h = max(min(bbox[3] + 1, img_h) - y0, 1)

            if not self.class_agnostic:
                mask_pred_ = mask_pred[i, label, :, :]
            else:
                mask_pred_ = mask_pred[i, 0, :, :]
            im_mask = np.zeros((img_h, img_w), dtype=np.uint8)

            bbox_mask = mmcv.imresize(mask_pred_, (w, h))
            bbox_mask = (bbox_mask > rcnn_test_cfg.mask_thr_binary).astype(
                np.uint8)
            im_mask[y0:y0 + h, x0:x0 + w] = bbox_mask
            # RLE-encode the full-image binary mask (COCO format expects
            # Fortran-ordered arrays)
            rle = mask_util.encode(
                np.array(im_mask[:, :, np.newaxis], order='F'))[0]
            cls_segms[label - 1].append(rle)

        return cls_segms