"official/vision/modeling/retinanet_model_test.py" did not exist on "f7754fe5b8e8c10092d4155715ab2395580d4901"
fcn_mask_head.py 6.62 KB
Newer Older
Kai Chen's avatar
Kai Chen committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
import mmcv
import numpy as np
import pycocotools.mask as mask_util
import torch
import torch.nn as nn
import torch.utils.checkpoint as cp

from ..common import ConvModule
from mmdet.core import mask_target, mask_cross_entropy


class FCNMaskHead(nn.Module):
    """Fully convolutional head that predicts mask logits from (RoI) features.

    The head is a stack of ``num_convs`` ConvModules, an optional upsampling
    layer and a 1x1 conv that produces per-class (or class-agnostic) logits.

    Args:
        num_convs (int): number of ConvModules applied before upsampling.
        roi_feat_size (int): stored only; not used by this class.
        in_channels (int): channels of the input feature map.
        conv_kernel_size (int): kernel size of every ConvModule. Padding is
            ``(k - 1) // 2``, so odd kernel sizes preserve spatial size.
        conv_out_channels (int): output channels of every ConvModule.
        upsample_method (str | None): None, 'deconv', 'nearest' or
            'bilinear'.
        upsample_ratio (int): upsampling factor (also deconv kernel/stride).
        num_classes (int): number of logit channels when not class-agnostic.
            ``get_seg_masks`` treats channel 0 as background.
        class_agnostic (bool): predict a single mask channel for all classes.
        with_cp (bool): run the middle convs under gradient checkpointing.
        normalize: forwarded to ConvModule; conv bias is disabled when set.
    """

    def __init__(self,
                 num_convs=4,
                 roi_feat_size=14,
                 in_channels=256,
                 conv_kernel_size=3,
                 conv_out_channels=256,
                 upsample_method='deconv',
                 upsample_ratio=2,
                 num_classes=81,
                 class_agnostic=False,
                 with_cp=False,
                 normalize=None):
        super(FCNMaskHead, self).__init__()
        if upsample_method not in [None, 'deconv', 'nearest', 'bilinear']:
            raise ValueError(
                'Invalid upsample method {}, accepted methods '
                'are "deconv", "nearest", "bilinear"'.format(upsample_method))
        self.num_convs = num_convs
        self.roi_feat_size = roi_feat_size  # WARN: not used and reserved
        self.in_channels = in_channels
        self.conv_kernel_size = conv_kernel_size
        self.conv_out_channels = conv_out_channels
        self.upsample_method = upsample_method
        self.upsample_ratio = upsample_ratio
        self.num_classes = num_classes
        self.class_agnostic = class_agnostic
        self.normalize = normalize
        # A normalization layer after the conv makes its bias redundant.
        self.with_bias = normalize is None
        self.with_cp = with_cp

        self.convs = nn.ModuleList()
        # "Same" padding for odd kernel sizes keeps the spatial size.
        padding = (self.conv_kernel_size - 1) // 2
        for i in range(self.num_convs):
            conv_in_channels = (self.in_channels
                                if i == 0 else self.conv_out_channels)
            self.convs.append(
                ConvModule(
                    conv_in_channels,
                    self.conv_out_channels,
                    self.conv_kernel_size,
                    padding=padding,
                    normalize=normalize,
                    bias=self.with_bias))
        if self.upsample_method is None:
            self.upsample = None
        elif self.upsample_method == 'deconv':
            self.upsample = nn.ConvTranspose2d(
                self.conv_out_channels,
                self.conv_out_channels,
                self.upsample_ratio,
                stride=self.upsample_ratio)
        else:
            self.upsample = nn.Upsample(
                scale_factor=self.upsample_ratio, mode=self.upsample_method)

        out_channels = 1 if self.class_agnostic else self.num_classes
        self.conv_logits = nn.Conv2d(self.conv_out_channels, out_channels, 1)
        self.relu = nn.ReLU(inplace=True)
        self.debug_imgs = None

    def init_weights(self):
        """Kaiming-init the upsample and logits layers and zero their biases.

        A missing upsample layer (``upsample_method=None``) and parameter-free
        layers such as ``nn.Upsample`` ('nearest'/'bilinear') are skipped.
        """
        for m in [self.upsample, self.conv_logits]:
            # nn.Upsample has no learnable weight; accessing m.weight on it
            # would raise AttributeError.
            if m is None or getattr(m, 'weight', None) is None:
                continue
            nn.init.kaiming_normal_(
                m.weight, mode='fan_out', nonlinearity='relu')
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def convs_forward(self, x):
        """Apply the conv stack to ``x``.

        The first and last convs always run normally. The middle convs run
        under ``torch.utils.checkpoint`` when ``with_cp`` is set and ``x``
        requires grad, which recomputes activations in backward instead of
        storing them. With ``num_convs == 0`` the input is returned unchanged.
        """

        def m_lvl_convs_forward(x):
            # Middle convs only; the first and last are applied outside.
            for conv in self.convs[1:-1]:
                x = conv(x)
            return x

        if self.num_convs > 0:
            x = self.convs[0](x)
            if self.num_convs > 1:
                if self.with_cp and x.requires_grad:
                    x = cp.checkpoint(m_lvl_convs_forward, x)
                else:
                    x = m_lvl_convs_forward(x)
                x = self.convs[-1](x)
        return x

    def forward(self, x):
        """Return mask logits for features ``x``.

        The output has ``1`` channel if ``class_agnostic`` else
        ``num_classes`` channels, upsampled by ``upsample_ratio`` when an
        upsample layer is configured.
        """
        x = self.convs_forward(x)
        if self.upsample is not None:
            x = self.upsample(x)
            # Only the learnable deconv is followed by a non-linearity.
            if self.upsample_method == 'deconv':
                x = self.relu(x)
        mask_pred = self.conv_logits(x)
        return mask_pred

    def mask_target(self, pos_proposals, pos_assigned_gt_inds, gt_masks,
                    img_shapes, rcnn_train_cfg):
        """Thin wrapper around ``mmdet.core.mask_target``."""
        mask_targets = mask_target(pos_proposals, pos_assigned_gt_inds,
                                   gt_masks, img_shapes, rcnn_train_cfg)
        return mask_targets

    def loss(self, mask_pred, mask_targets, labels):
        """Thin wrapper around ``mmdet.core.mask_cross_entropy``."""
        loss_mask = mask_cross_entropy(mask_pred, mask_targets, labels)
        return loss_mask

    def get_seg_masks(self,
                      mask_pred,
                      det_bboxes,
                      det_labels,
                      img_shape,
                      rcnn_test_cfg,
                      ori_scale,
                      rescale=True):
        """Get segmentation masks from mask_pred and bboxes.

        Args:
            mask_pred (Tensor or ndarray): shape (n, #class+1, h, w), or
                (n, 1, h, w) when ``class_agnostic``. A Tensor is treated as
                logits and passed through sigmoid. An ndarray (e.g. merged
                multi-scale results) is used as-is, i.e. assumed to already
                hold probabilities.
            det_bboxes (Tensor): shape (n, 4/5); only the first four columns
                (x1, y1, x2, y2, inclusive) are used.
            det_labels (Tensor): shape (n, ); 0-based indices into the
                returned per-class list (mask channel ``label + 1``).
            img_shape (Tensor): its last element is read as the scale factor
                applied to ``det_bboxes``.
            rcnn_test_cfg: rcnn testing config; ``mask_thr_binary`` is the
                binarization threshold (read via attribute access).
            ori_scale (dict): original image size with ``'height'`` and
                ``'width'`` entries.
            rescale (bool): if True, divide boxes by the scale factor and
                paste masks onto an original-size canvas. Otherwise keep the
                boxes as-is and use the original size times the scale factor.
        Returns:
            list[list]: for each of the ``num_classes - 1`` classes, a list of
                RLE-encoded masks (``pycocotools.mask.encode`` output).
        """
        if isinstance(mask_pred, torch.Tensor):
            mask_pred = mask_pred.sigmoid().cpu().numpy()
        assert isinstance(mask_pred, np.ndarray)
        cls_segms = [[] for _ in range(self.num_classes - 1)]
        bboxes = det_bboxes.cpu().numpy()[:, :4]
        # Shift by one so that mask channel 0 (background) is skipped.
        labels = det_labels.cpu().numpy() + 1
        scale_factor = img_shape[-1] if rescale else 1.0
        if rescale:
            img_h = int(ori_scale['height'])
            img_w = int(ori_scale['width'])
        else:
            scale = img_shape[-1].item()
            img_h = int(np.round(ori_scale['height'].item() * scale))
            img_w = int(np.round(ori_scale['width'].item() * scale))

        for i in range(bboxes.shape[0]):
            bbox = (bboxes[i, :] / float(scale_factor)).astype(int)
            label = labels[i]
            x0, y0 = int(bbox[0]), int(bbox[1])
            # Box coordinates are inclusive, hence the +1.
            w = max(int(bbox[2]) - x0 + 1, 1)
            h = max(int(bbox[3]) - y0 + 1, 1)

            if not self.class_agnostic:
                mask_pred_ = mask_pred[i, label, :, :]
            else:
                mask_pred_ = mask_pred[i, 0, :, :]
            bbox_mask = mmcv.resize(mask_pred_, (w, h))

            im_mask = np.zeros((img_h, img_w), dtype=np.float32)
            # Clip the paste region to the canvas. A box that touches or
            # crosses the border (or starts at a negative coordinate) would
            # otherwise give mismatched slice shapes and raise.
            px0, py0 = max(x0, 0), max(y0, 0)
            px1, py1 = min(x0 + w, img_w), min(y0 + h, img_h)
            if px1 > px0 and py1 > py0:
                im_mask[py0:py1, px0:px1] = bbox_mask[py0 - y0:py1 - y0,
                                                      px0 - x0:px1 - x0]
            im_mask = np.array(
                im_mask > rcnn_test_cfg.mask_thr_binary, dtype=np.uint8)
            # pycocotools expects a Fortran-ordered (h, w, n) uint8 array.
            rle = mask_util.encode(
                np.array(im_mask[:, :, np.newaxis], order='F'))[0]
            cls_segms[label - 1].append(rle)
        return cls_segms