test_mixins.py 11 KB
Newer Older
1
2
3
import logging
import sys

Wenwei Zhang's avatar
Wenwei Zhang committed
4
5
import torch

6
7
from mmdet.core import (bbox2roi, bbox_mapping, merge_aug_bboxes,
                        merge_aug_masks, merge_aug_proposals, multiclass_nms)
Kai Chen's avatar
Kai Chen committed
8

9
10
11
12
13
logger = logging.getLogger(__name__)

# ``completed`` is only referenced from the ``async def`` helpers below, which
# are themselves only defined on Python >= 3.7, so import it under the same
# version guard.
if sys.version_info[:2] >= (3, 7):
    from mmdet.utils.contextmanagers import completed

class RPNTestMixin(object):
    """Test-time helpers for the RPN stage.

    The host class is expected to provide ``self.rpn_head``, which is called
    on the feature maps and exposes ``get_bboxes(*outs, img_meta, cfg)``.
    """

    if sys.version_info >= (3, 7):

        async def async_test_rpn(self, x, img_meta, rpn_test_cfg):
            """Async variant of :meth:`simple_test_rpn`.

            The RPN head forward pass runs inside the ``completed`` async
            context manager, with ``sleep_interval`` taken from
            ``rpn_test_cfg['async_sleep_interval']`` (default 0.025).

            Returns:
                list: Proposals as returned by ``rpn_head.get_bboxes``.
            """
            # Read with ``get`` rather than ``pop``: popping mutated the
            # caller's config (presumably the shared ``test_cfg.rpn``), so
            # the configured interval was lost after the first call. This
            # matches ``async_test_bboxes``, which also passes a cfg that
            # still contains the key on to the head.
            sleep_interval = rpn_test_cfg.get("async_sleep_interval", 0.025)
            async with completed(
                    __name__, "rpn_head_forward",
                    sleep_interval=sleep_interval):
                rpn_outs = self.rpn_head(x)

            proposal_inputs = rpn_outs + (img_meta, rpn_test_cfg)

            proposal_list = self.rpn_head.get_bboxes(*proposal_inputs)
            return proposal_list

    def simple_test_rpn(self, x, img_meta, rpn_test_cfg):
        """Run the RPN head on ``x`` and decode proposals.

        Returns:
            list: Proposals as returned by ``rpn_head.get_bboxes``.
        """
        rpn_outs = self.rpn_head(x)
        proposal_inputs = rpn_outs + (img_meta, rpn_test_cfg)
        proposal_list = self.rpn_head.get_bboxes(*proposal_inputs)
        return proposal_list

    def aug_test_rpn(self, feats, img_metas, rpn_test_cfg):
        """Run the RPN on every augmentation and merge proposals per image.

        Args:
            feats (list): Feature maps, one entry per augmentation.
            img_metas (list[list[dict]]): ``img_metas[aug][img]`` meta info.
            rpn_test_cfg: RPN test config, forwarded to the head and to
                ``merge_aug_proposals``.

        Returns:
            list: One merged proposal set per image; after merging,
                proposals are rescaled to the original image size.
        """
        imgs_per_gpu = len(img_metas[0])
        aug_proposals = [[] for _ in range(imgs_per_gpu)]
        for x, img_meta in zip(feats, img_metas):
            proposal_list = self.simple_test_rpn(x, img_meta, rpn_test_cfg)
            for i, proposals in enumerate(proposal_list):
                aug_proposals[i].append(proposals)

        # Transpose img_metas from [aug][img] to [img][aug] so that it lines
        # up with the per-image lists in aug_proposals.
        aug_img_metas = [[metas[i] for metas in img_metas]
                         for i in range(imgs_per_gpu)]

        # after merging, proposals will be rescaled to the original image size
        merged_proposals = [
            merge_aug_proposals(proposals, aug_img_meta, rpn_test_cfg)
            for proposals, aug_img_meta in zip(aug_proposals, aug_img_metas)
        ]
        return merged_proposals


class BBoxTestMixin(object):
    """Test-time helpers for the bbox branch of a two-stage detector.

    The host class is expected to provide ``bbox_roi_extractor``,
    ``bbox_head``, ``with_shared_head`` and, when that flag is set,
    ``shared_head``.
    """

    def _bbox_roi_feats(self, x, rois):
        """Pool RoI features for the bbox head (shared by all test paths).

        Only the first ``len(bbox_roi_extractor.featmap_strides)`` levels of
        ``x`` are fed to the extractor; the pooled features additionally go
        through ``shared_head`` when ``with_shared_head`` is set.
        """
        roi_feats = self.bbox_roi_extractor(
            x[:len(self.bbox_roi_extractor.featmap_strides)], rois)
        if self.with_shared_head:
            roi_feats = self.shared_head(roi_feats)
        return roi_feats

    if sys.version_info >= (3, 7):

        async def async_test_bboxes(self,
                                    x,
                                    img_meta,
                                    proposals,
                                    rcnn_test_cfg,
                                    rescale=False,
                                    bbox_semaphore=None,
                                    global_lock=None):
            """Async test only det bboxes without augmentation.

            The bbox head forward pass runs inside the ``completed`` async
            context manager, with ``sleep_interval`` taken from
            ``rcnn_test_cfg['async_sleep_interval']`` (default 0.017).
            ``bbox_semaphore`` and ``global_lock`` are accepted for API
            compatibility but are not used in this body.

            Returns:
                tuple: ``(det_bboxes, det_labels)`` from
                    ``bbox_head.get_det_bboxes``.
            """
            rois = bbox2roi(proposals)
            roi_feats = self._bbox_roi_feats(x, rois)
            sleep_interval = rcnn_test_cfg.get("async_sleep_interval", 0.017)

            async with completed(
                    __name__, "bbox_head_forward",
                    sleep_interval=sleep_interval):
                cls_score, bbox_pred = self.bbox_head(roi_feats)

            img_shape = img_meta[0]['img_shape']
            scale_factor = img_meta[0]['scale_factor']
            det_bboxes, det_labels = self.bbox_head.get_det_bboxes(
                rois,
                cls_score,
                bbox_pred,
                img_shape,
                scale_factor,
                rescale=rescale,
                cfg=rcnn_test_cfg)
            return det_bboxes, det_labels

    def simple_test_bboxes(self,
                           x,
                           img_meta,
                           proposals,
                           rcnn_test_cfg,
                           rescale=False):
        """Test only det bboxes without augmentation.

        Only the first image's meta (``img_meta[0]``) supplies
        ``img_shape`` and ``scale_factor``.

        Returns:
            tuple: ``(det_bboxes, det_labels)`` from
                ``bbox_head.get_det_bboxes``.
        """
        rois = bbox2roi(proposals)
        roi_feats = self._bbox_roi_feats(x, rois)
        cls_score, bbox_pred = self.bbox_head(roi_feats)
        img_shape = img_meta[0]['img_shape']
        scale_factor = img_meta[0]['scale_factor']
        det_bboxes, det_labels = self.bbox_head.get_det_bboxes(
            rois,
            cls_score,
            bbox_pred,
            img_shape,
            scale_factor,
            rescale=rescale,
            cfg=rcnn_test_cfg)
        return det_bboxes, det_labels

    def aug_test_bboxes(self, feats, img_metas, proposal_list, rcnn_test_cfg):
        """Test det bboxes with test-time augmentation.

        Args:
            feats (list): Feature maps, one entry per augmentation.
            img_metas (list[list[dict]]): ``img_metas[aug][img]`` meta info;
                only the first image of each augmentation is read.
            proposal_list (list): Proposals; only ``proposal_list[0]`` is
                used. Presumably in original-image coordinates, since they
                are mapped into each augmented frame with ``bbox_mapping``.
            rcnn_test_cfg: Config providing ``score_thr``, ``nms`` and
                ``max_per_img`` for the final NMS.

        Returns:
            tuple: ``(det_bboxes, det_labels)`` from ``multiclass_nms``.
        """
        aug_bboxes = []
        aug_scores = []
        for x, img_meta in zip(feats, img_metas):
            # only one image in the batch
            img_shape = img_meta[0]['img_shape']
            scale_factor = img_meta[0]['scale_factor']
            flip = img_meta[0]['flip']
            # TODO more flexible
            proposals = bbox_mapping(proposal_list[0][:, :4], img_shape,
                                     scale_factor, flip)
            rois = bbox2roi([proposals])
            # recompute feature maps to save GPU memory
            roi_feats = self._bbox_roi_feats(x, rois)
            cls_score, bbox_pred = self.bbox_head(roi_feats)
            # rescale=False / cfg=None: keep per-augmentation results as-is
            # (presumably un-NMSed); NMS runs once below after merging.
            bboxes, scores = self.bbox_head.get_det_bboxes(
                rois,
                cls_score,
                bbox_pred,
                img_shape,
                scale_factor,
                rescale=False,
                cfg=None)
            aug_bboxes.append(bboxes)
            aug_scores.append(scores)
        # after merging, bboxes will be rescaled to the original image size
        merged_bboxes, merged_scores = merge_aug_bboxes(
            aug_bboxes, aug_scores, img_metas, rcnn_test_cfg)
        det_bboxes, det_labels = multiclass_nms(merged_bboxes, merged_scores,
                                                rcnn_test_cfg.score_thr,
                                                rcnn_test_cfg.nms,
                                                rcnn_test_cfg.max_per_img)
        return det_bboxes, det_labels


class MaskTestMixin(object):
    """Test-time helpers for the mask branch of a two-stage detector.

    The host class is expected to provide ``mask_roi_extractor``,
    ``mask_head``, ``with_shared_head`` (plus ``shared_head`` when set) and
    ``test_cfg.rcnn``.
    """

    @staticmethod
    def _mask_roi_bboxes(det_bboxes, scale_factor, rescale):
        """Return ``(bboxes, scale_factor)`` to build mask RoIs from.

        If ``det_bboxes`` were rescaled to the original image size
        (``rescale=True``), they are scaled back to the testing scale by
        multiplying their first 4 columns by ``scale_factor``. A non-float
        ``scale_factor`` is first converted with ``torch.from_numpy`` and
        moved to ``det_bboxes.device``; the converted value is returned so
        callers hand the same object to ``get_seg_masks``. With
        ``rescale=False`` both inputs are returned unchanged.
        """
        if not rescale:
            return det_bboxes, scale_factor
        if not isinstance(scale_factor, float):
            scale_factor = torch.from_numpy(scale_factor).to(
                det_bboxes.device)
        return det_bboxes[:, :4] * scale_factor, scale_factor

    def _mask_roi_feats(self, x, mask_rois):
        """Pool RoI features for the mask head (shared by all test paths).

        Only the first ``len(mask_roi_extractor.featmap_strides)`` levels of
        ``x`` are used; the result goes through ``shared_head`` when
        ``with_shared_head`` is set.
        """
        mask_feats = self.mask_roi_extractor(
            x[:len(self.mask_roi_extractor.featmap_strides)], mask_rois)
        if self.with_shared_head:
            mask_feats = self.shared_head(mask_feats)
        return mask_feats

    if sys.version_info >= (3, 7):

        async def async_test_mask(self,
                                  x,
                                  img_meta,
                                  det_bboxes,
                                  det_labels,
                                  rescale=False,
                                  mask_test_cfg=None):
            """Async variant of :meth:`simple_test_mask`.

            The mask head forward pass runs inside the ``completed`` async
            context manager; ``sleep_interval`` comes from a truthy
            ``mask_test_cfg['async_sleep_interval']``, otherwise 0.035.

            Returns:
                list: ``num_classes - 1`` empty lists when there are no
                    detections, else the result of ``get_seg_masks``.
            """
            # image shape of the first image in the batch (only one)
            ori_shape = img_meta[0]['ori_shape']
            scale_factor = img_meta[0]['scale_factor']
            if det_bboxes.shape[0] == 0:
                segm_result = [[]
                               for _ in range(self.mask_head.num_classes - 1)]
            else:
                # Same numpy -> tensor scale-factor handling as
                # simple_test_mask; previously the async path multiplied
                # by the raw (possibly numpy) scale factor.
                _bboxes, scale_factor = self._mask_roi_bboxes(
                    det_bboxes, scale_factor, rescale)
                mask_rois = bbox2roi([_bboxes])
                mask_feats = self._mask_roi_feats(x, mask_rois)
                if mask_test_cfg and mask_test_cfg.get('async_sleep_interval'):
                    sleep_interval = mask_test_cfg['async_sleep_interval']
                else:
                    sleep_interval = 0.035
                async with completed(
                        __name__,
                        "mask_head_forward",
                        sleep_interval=sleep_interval):
                    mask_pred = self.mask_head(mask_feats)
                segm_result = self.mask_head.get_seg_masks(
                    mask_pred, _bboxes, det_labels, self.test_cfg.rcnn,
                    ori_shape, scale_factor, rescale)
            return segm_result

    def simple_test_mask(self,
                         x,
                         img_meta,
                         det_bboxes,
                         det_labels,
                         rescale=False):
        """Predict masks for ``det_bboxes`` without augmentation.

        Returns:
            list: ``num_classes - 1`` empty lists when there are no
                detections, else the result of ``mask_head.get_seg_masks``.
        """
        # image shape of the first image in the batch (only one)
        ori_shape = img_meta[0]['ori_shape']
        scale_factor = img_meta[0]['scale_factor']
        if det_bboxes.shape[0] == 0:
            segm_result = [[] for _ in range(self.mask_head.num_classes - 1)]
        else:
            # if det_bboxes is rescaled to the original image size, we need to
            # rescale it back to the testing scale to obtain RoIs.
            _bboxes, scale_factor = self._mask_roi_bboxes(
                det_bboxes, scale_factor, rescale)
            mask_rois = bbox2roi([_bboxes])
            mask_feats = self._mask_roi_feats(x, mask_rois)
            mask_pred = self.mask_head(mask_feats)
            segm_result = self.mask_head.get_seg_masks(mask_pred, _bboxes,
                                                       det_labels,
                                                       self.test_cfg.rcnn,
                                                       ori_shape, scale_factor,
                                                       rescale)
        return segm_result

    def aug_test_mask(self, feats, img_metas, det_bboxes, det_labels):
        """Predict masks with test-time augmentation.

        Mask probabilities (sigmoid of the head output) from every
        augmentation are merged with ``merge_aug_masks`` and then pasted at
        the original image shape of ``img_metas[0][0]``.

        Returns:
            list: ``num_classes - 1`` empty lists when there are no
                detections, else the result of ``mask_head.get_seg_masks``.
        """
        if det_bboxes.shape[0] == 0:
            segm_result = [[] for _ in range(self.mask_head.num_classes - 1)]
        else:
            aug_masks = []
            for x, img_meta in zip(feats, img_metas):
                img_shape = img_meta[0]['img_shape']
                scale_factor = img_meta[0]['scale_factor']
                flip = img_meta[0]['flip']
                _bboxes = bbox_mapping(det_bboxes[:, :4], img_shape,
                                       scale_factor, flip)
                mask_rois = bbox2roi([_bboxes])
                mask_feats = self._mask_roi_feats(x, mask_rois)
                mask_pred = self.mask_head(mask_feats)
                # convert to numpy array to save memory
                aug_masks.append(mask_pred.sigmoid().cpu().numpy())
            merged_masks = merge_aug_masks(aug_masks, img_metas,
                                           self.test_cfg.rcnn)

            ori_shape = img_metas[0][0]['ori_shape']
            segm_result = self.mask_head.get_seg_masks(
                merged_masks,
                det_bboxes,
                det_labels,
                self.test_cfg.rcnn,
                ori_shape,
                scale_factor=1.0,
                rescale=False)
        return segm_result