# test_mixins.py
from mmdet.core import (bbox2roi, bbox_mapping, merge_aug_proposals,
                        merge_aug_bboxes, merge_aug_masks, multiclass_nms)


class RPNTestMixin(object):
    """Test-time RPN helpers for a detector.

    The host class must provide ``self.rpn_head``, a callable that also
    exposes a ``get_bboxes`` method.
    """

    def simple_test_rpn(self, x, img_meta, rpn_test_cfg):
        """Generate proposals for one (non-augmented) input.

        Args:
            x: Features passed unchanged to ``self.rpn_head``.
            img_meta: Image meta information, forwarded to
                ``self.rpn_head.get_bboxes``.
            rpn_test_cfg: RPN test config, forwarded to
                ``self.rpn_head.get_bboxes``.

        Returns:
            The value returned by ``self.rpn_head.get_bboxes``;
            ``aug_test_rpn`` iterates it as one entry per image.
        """
        # The head output is extended by tuple concatenation, so
        # ``self.rpn_head`` has to return a tuple.
        rpn_outs = self.rpn_head(x)
        proposal_inputs = rpn_outs + (img_meta, rpn_test_cfg)
        proposal_list = self.rpn_head.get_bboxes(*proposal_inputs)
        return proposal_list

    def aug_test_rpn(self, feats, img_metas, rpn_test_cfg):
        """Generate and merge proposals over several test-time augmentations.

        Args:
            feats: One feature input per augmentation.
            img_metas: One entry per augmentation; each entry is a sequence
                holding one meta per image (indexed ``[aug][image]``).
            rpn_test_cfg: RPN test config, forwarded to the RPN head and to
                ``merge_aug_proposals``.

        Returns:
            list: One merged proposal result per image.
        """
        imgs_per_gpu = len(img_metas[0])
        # aug_proposals is indexed [image][aug].
        aug_proposals = [[] for _ in range(imgs_per_gpu)]
        for x, img_meta in zip(feats, img_metas):
            proposal_list = self.simple_test_rpn(x, img_meta, rpn_test_cfg)
            for i, proposals in enumerate(proposal_list):
                aug_proposals[i].append(proposals)
        # ``img_metas`` is indexed [aug][image] while ``aug_proposals`` is
        # indexed [image][aug]. Regroup the metas per image so that every
        # image is merged together with its own per-augmentation metas.
        aug_img_metas = [
            [metas[i] for metas in img_metas] for i in range(imgs_per_gpu)
        ]
        # after merging, proposals will be rescaled to the original image size
        merged_proposals = [
            merge_aug_proposals(proposals, aug_img_meta, rpn_test_cfg)
            for proposals, aug_img_meta in zip(aug_proposals, aug_img_metas)
        ]
        return merged_proposals


class BBoxTestMixin(object):
    """Test-time bbox-head helpers for a detector.

    The host class must provide ``self.bbox_roi_extractor`` (with a
    ``featmap_strides`` attribute), ``self.bbox_head``,
    ``self.with_shared_head`` and, when that flag is true,
    ``self.shared_head``.
    """

    def simple_test_bboxes(self,
                           x,
                           img_meta,
                           proposals,
                           rcnn_test_cfg,
                           rescale=False):
        """Test only det bboxes without augmentation.

        Args:
            x: Sequence of feature levels; only the first
                ``len(self.bbox_roi_extractor.featmap_strides)`` are used.
            img_meta: Sequence of image metas; ``img_shape`` and
                ``scale_factor`` are read from the first entry only.
            proposals: Proposals converted to RoIs with ``bbox2roi``.
            rcnn_test_cfg: Passed to ``get_det_bboxes`` as ``cfg``.
            rescale (bool): Forwarded to ``get_det_bboxes``.

        Returns:
            tuple: ``(det_bboxes, det_labels)`` as returned by
            ``self.bbox_head.get_det_bboxes``.
        """
        rois = bbox2roi(proposals)
        roi_feats = self.bbox_roi_extractor(
            x[:len(self.bbox_roi_extractor.featmap_strides)], rois)
        if self.with_shared_head:
            roi_feats = self.shared_head(roi_feats)
        cls_score, bbox_pred = self.bbox_head(roi_feats)
        img_shape = img_meta[0]['img_shape']
        scale_factor = img_meta[0]['scale_factor']
        det_bboxes, det_labels = self.bbox_head.get_det_bboxes(
            rois,
            cls_score,
            bbox_pred,
            img_shape,
            scale_factor,
            rescale=rescale,
            cfg=rcnn_test_cfg)
        return det_bboxes, det_labels

    def aug_test_bboxes(self, feats, img_metas, proposal_list, rcnn_test_cfg):
        """Test det bboxes with test-time augmentation.

        Args:
            feats: One sequence of feature levels per augmentation.
            img_metas: One sequence of image metas per augmentation; only
                the first meta of each is read (``img_shape``,
                ``scale_factor``, ``flip``).
            proposal_list: Proposals per image; only ``proposal_list[0]`` is
                used and its first four columns are taken as box coordinates.
            rcnn_test_cfg: Config passed to ``merge_aug_bboxes``; its
                ``score_thr``, ``nms`` and ``max_per_img`` attributes are
                passed to ``multiclass_nms``.

        Returns:
            tuple: ``(det_bboxes, det_labels)`` as returned by
            ``multiclass_nms``.
        """
        aug_bboxes = []
        aug_scores = []
        for x, img_meta in zip(feats, img_metas):
            # only one image in the batch
            img_shape = img_meta[0]['img_shape']
            scale_factor = img_meta[0]['scale_factor']
            flip = img_meta[0]['flip']
            # TODO more flexible
            proposals = bbox_mapping(proposal_list[0][:, :4], img_shape,
                                     scale_factor, flip)
            rois = bbox2roi([proposals])
            # recompute feature maps to save GPU memory
            roi_feats = self.bbox_roi_extractor(
                x[:len(self.bbox_roi_extractor.featmap_strides)], rois)
            if self.with_shared_head:
                roi_feats = self.shared_head(roi_feats)
            cls_score, bbox_pred = self.bbox_head(roi_feats)
            # cfg=None and rescale=False here; score thresholding and NMS
            # are applied once below, on the merged results.
            bboxes, scores = self.bbox_head.get_det_bboxes(
                rois,
                cls_score,
                bbox_pred,
                img_shape,
                scale_factor,
                rescale=False,
                cfg=None)
            aug_bboxes.append(bboxes)
            aug_scores.append(scores)
        # after merging, bboxes will be rescaled to the original image size
        merged_bboxes, merged_scores = merge_aug_bboxes(
            aug_bboxes, aug_scores, img_metas, rcnn_test_cfg)
        det_bboxes, det_labels = multiclass_nms(
            merged_bboxes, merged_scores, rcnn_test_cfg.score_thr,
            rcnn_test_cfg.nms, rcnn_test_cfg.max_per_img)
        return det_bboxes, det_labels


class MaskTestMixin(object):
    """Test-time mask-head helpers for a detector.

    The host class must provide ``self.mask_roi_extractor`` (with a
    ``featmap_strides`` attribute), ``self.mask_head``, ``self.test_cfg``
    (with an ``rcnn`` attribute), ``self.with_shared_head`` and, when that
    flag is true, ``self.shared_head``.
    """

    def simple_test_mask(self,
                         x,
                         img_meta,
                         det_bboxes,
                         det_labels,
                         rescale=False):
        """Predict masks for detected boxes without augmentation.

        Args:
            x: Sequence of feature levels; only the first
                ``len(self.mask_roi_extractor.featmap_strides)`` are used.
            img_meta: Sequence of image metas; ``ori_shape`` and
                ``scale_factor`` are read from the first entry only.
            det_bboxes: Detected boxes; the first four columns are used as
                coordinates when ``rescale`` is true.
            det_labels: Labels forwarded to ``get_seg_masks``.
            rescale (bool): Whether ``det_bboxes`` are in the original image
                scale. If so they are multiplied by ``scale_factor`` before
                RoI extraction.

        Returns:
            The result of ``self.mask_head.get_seg_masks``, or a list of
            ``self.mask_head.num_classes - 1`` empty lists when there are no
            detections.
        """
        # image shape of the first image in the batch (only one)
        ori_shape = img_meta[0]['ori_shape']
        scale_factor = img_meta[0]['scale_factor']
        if det_bboxes.shape[0] == 0:
            segm_result = [[] for _ in range(self.mask_head.num_classes - 1)]
        else:
            # if det_bboxes is rescaled to the original image size, we need to
            # rescale it back to the testing scale to obtain RoIs.
            _bboxes = (det_bboxes[:, :4] * scale_factor
                       if rescale else det_bboxes)
            mask_rois = bbox2roi([_bboxes])
            mask_feats = self.mask_roi_extractor(
                x[:len(self.mask_roi_extractor.featmap_strides)], mask_rois)
            if self.with_shared_head:
                mask_feats = self.shared_head(mask_feats)
            mask_pred = self.mask_head(mask_feats)
            segm_result = self.mask_head.get_seg_masks(
                mask_pred, _bboxes, det_labels, self.test_cfg.rcnn, ori_shape,
                scale_factor, rescale)
        return segm_result

    def aug_test_mask(self, feats, img_metas, det_bboxes, det_labels):
        """Predict masks for detected boxes with test-time augmentation.

        Args:
            feats: One sequence of feature levels per augmentation.
            img_metas: One sequence of image metas per augmentation; only
                the first meta of each is read.
            det_bboxes: Detected boxes; the first four columns are mapped
                into each augmented view with ``bbox_mapping``.
            det_labels: Labels forwarded to ``get_seg_masks``.

        Returns:
            The result of ``self.mask_head.get_seg_masks`` on the merged
            masks, or a list of ``self.mask_head.num_classes - 1`` empty
            lists when there are no detections.
        """
        if det_bboxes.shape[0] == 0:
            segm_result = [[] for _ in range(self.mask_head.num_classes - 1)]
        else:
            aug_masks = []
            for x, img_meta in zip(feats, img_metas):
                img_shape = img_meta[0]['img_shape']
                scale_factor = img_meta[0]['scale_factor']
                flip = img_meta[0]['flip']
                _bboxes = bbox_mapping(det_bboxes[:, :4], img_shape,
                                       scale_factor, flip)
                mask_rois = bbox2roi([_bboxes])
                mask_feats = self.mask_roi_extractor(
                    x[:len(self.mask_roi_extractor.featmap_strides)],
                    mask_rois)
                if self.with_shared_head:
                    mask_feats = self.shared_head(mask_feats)
                mask_pred = self.mask_head(mask_feats)
                # convert to numpy array to save memory
                aug_masks.append(mask_pred.sigmoid().cpu().numpy())
            merged_masks = merge_aug_masks(aug_masks, img_metas,
                                           self.test_cfg.rcnn)

            # The merged masks are pasted using the original shape recorded
            # for the first augmentation, with no further rescaling.
            ori_shape = img_metas[0][0]['ori_shape']
            segm_result = self.mask_head.get_seg_masks(
                merged_masks,
                det_bboxes,
                det_labels,
                self.test_cfg.rcnn,
                ori_shape,
                scale_factor=1.0,
                rescale=False)
        return segm_result