from mmdet.core import (bbox2roi, bbox_mapping, merge_aug_proposals,
                        merge_aug_bboxes, merge_aug_masks, multiclass_nms)


class RPNTestMixin(object):
    """Test-time proposal generation for detectors that own ``rpn_head``."""

    def simple_test_rpn(self, x, img_meta, rpn_test_cfg):
        """Generate proposals for one batch without augmentation.

        Args:
            x: Features passed directly to ``self.rpn_head``.
            img_meta: Image meta information, forwarded unchanged to
                ``self.rpn_head.get_proposals``.
            rpn_test_cfg: RPN test config, forwarded unchanged to
                ``self.rpn_head.get_proposals``.

        Returns:
            The result of ``self.rpn_head.get_proposals``. ``aug_test_rpn``
            iterates over it as one entry per image.
        """
        # The head outputs must be a tuple so the extra arguments can be
        # appended before unpacking into ``get_proposals``.
        rpn_outs = self.rpn_head(x)
        proposal_inputs = rpn_outs + (img_meta, rpn_test_cfg)
        return self.rpn_head.get_proposals(*proposal_inputs)

    def aug_test_rpn(self, feats, img_metas, rpn_test_cfg):
        """Generate and merge proposals over several augmentations.

        Args:
            feats: One feature entry per augmentation.
            img_metas: One entry per augmentation; each entry holds one meta
                per image (``len(img_metas[0])`` is the images-per-GPU count).
            rpn_test_cfg: RPN test config, forwarded to the RPN head and to
                ``merge_aug_proposals``.

        Returns:
            list: One merged proposal set per image.
        """
        imgs_per_gpu = len(img_metas[0])

        # Collect proposals as aug_proposals[img_idx][aug_idx].
        aug_proposals = [[] for _ in range(imgs_per_gpu)]
        for x, img_meta in zip(feats, img_metas):
            proposal_list = self.simple_test_rpn(x, img_meta, rpn_test_cfg)
            for i, proposals in enumerate(proposal_list):
                aug_proposals[i].append(proposals)

        # ``img_metas`` arrives as [aug_idx][img_idx]; transpose it so each
        # image's proposals are merged with that image's own per-augmentation
        # metas rather than with the metas of an unrelated augmentation.
        aug_img_metas = [
            [metas_per_aug[i] for metas_per_aug in img_metas]
            for i in range(imgs_per_gpu)
        ]

        # after merging, proposals will be rescaled to the original image size
        merged_proposals = [
            merge_aug_proposals(proposals, aug_img_meta, rpn_test_cfg)
            for proposals, aug_img_meta in zip(aug_proposals, aug_img_metas)
        ]
        return merged_proposals


class BBoxTestMixin(object):
    """Test-time box prediction for detectors with a bbox RoI head."""

    def simple_test_bboxes(self,
                           x,
                           img_meta,
                           proposals,
                           rcnn_test_cfg,
                           rescale=False):
        """Test only det bboxes without augmentation.

        Args:
            x: Multi-level features; only the first
                ``len(self.bbox_roi_extractor.featmap_strides)`` levels are
                given to the RoI extractor.
            img_meta: Sequence of meta dicts; only ``img_meta[0]`` is read
                (keys ``'img_shape'`` and ``'scale_factor'``).
            proposals: Proposals passed as-is to ``bbox2roi``.
            rcnn_test_cfg: Passed to the bbox head as ``nms_cfg``.
            rescale (bool): Forwarded to ``self.bbox_head.get_det_bboxes``.

        Returns:
            tuple: ``(det_bboxes, det_labels)`` as returned by
            ``self.bbox_head.get_det_bboxes``.
        """
        rois = bbox2roi(proposals)
        roi_feats = self.bbox_roi_extractor(
            x[:len(self.bbox_roi_extractor.featmap_strides)], rois)
        cls_score, bbox_pred = self.bbox_head(roi_feats)
        img_shape = img_meta[0]['img_shape']
        scale_factor = img_meta[0]['scale_factor']
        det_bboxes, det_labels = self.bbox_head.get_det_bboxes(
            rois,
            cls_score,
            bbox_pred,
            img_shape,
            scale_factor,
            rescale=rescale,
            nms_cfg=rcnn_test_cfg)
        return det_bboxes, det_labels

    def aug_test_bboxes(self, feats, img_metas, proposals, rcnn_test_cfg):
        """Predict boxes on every augmentation, then merge and run NMS.

        Args:
            feats: One multi-level feature entry per augmentation.
            img_metas: One entry per augmentation; only the first meta of
                each entry is read (``'img_shape'``, ``'scale_factor'``,
                ``'flip'``), i.e. a single image per batch is assumed.
            proposals: 2-D proposals for that image; the first four columns
                are used as box coordinates.
            rcnn_test_cfg: Test config used for merging and for NMS
                (``score_thr``, ``nms_thr``, ``max_per_img``).

        Returns:
            tuple: ``(det_bboxes, det_labels)`` from ``multiclass_nms``.
        """
        aug_bboxes = []
        aug_scores = []
        for x, img_meta in zip(feats, img_metas):
            # only one image in the batch
            img_shape = img_meta[0]['img_shape']
            scale_factor = img_meta[0]['scale_factor']
            flip = img_meta[0]['flip']
            # Map into a fresh local so every augmentation starts from the
            # caller's proposals instead of the previous iteration's
            # already-mapped boxes.
            mapped_proposals = bbox_mapping(proposals[:, :4], img_shape,
                                            scale_factor, flip)
            rois = bbox2roi([mapped_proposals])
            roi_feats = self.bbox_roi_extractor(
                x[:len(self.bbox_roi_extractor.featmap_strides)], rois)
            cls_score, bbox_pred = self.bbox_head(roi_feats)
            # ``scale_factor`` is passed in the same position as in
            # ``simple_test_bboxes`` so ``rescale`` / ``nms_cfg`` bind to the
            # intended parameters. No rescaling or NMS happens per
            # augmentation; both are deferred to the merge step below.
            bboxes, scores = self.bbox_head.get_det_bboxes(
                rois,
                cls_score,
                bbox_pred,
                img_shape,
                scale_factor,
                rescale=False,
                nms_cfg=None)
            aug_bboxes.append(bboxes)
            aug_scores.append(scores)

        # after merging, bboxes will be rescaled to the original image size
        merged_bboxes, merged_scores = merge_aug_bboxes(
            aug_bboxes, aug_scores, img_metas, rcnn_test_cfg)
        det_bboxes, det_labels = multiclass_nms(
            merged_bboxes, merged_scores, rcnn_test_cfg.score_thr,
            rcnn_test_cfg.nms_thr, rcnn_test_cfg.max_per_img)
        return det_bboxes, det_labels


class MaskTestMixin(object):
    """Test-time mask prediction for detectors with a mask RoI head.

    Reads ``self.rcnn_test_cfg``, ``self.mask_roi_extractor`` and
    ``self.mask_head`` from the host class.
    """

    def simple_test_mask(self,
                         x,
                         img_meta,
                         det_bboxes,
                         det_labels,
                         rescale=False):
        """Predict masks for detected boxes without augmentation.

        Args:
            x: Multi-level features; only the first
                ``len(self.mask_roi_extractor.featmap_strides)`` levels are
                given to the RoI extractor.
            img_meta: Sequence of meta dicts; only ``img_meta[0]`` is read.
            det_bboxes: 2-D detections; the first four columns are box
                coordinates.
            det_labels: Labels forwarded to ``get_seg_masks``.
            rescale (bool): Whether ``det_bboxes`` are already in the
                original image scale.

        Returns:
            The result of ``self.mask_head.get_seg_masks``, or a list of
            ``num_classes - 1`` empty lists when there are no detections.
        """
        # image shape of the first image in the batch (only one)
        img_shape = img_meta[0]['img_shape']
        scale_factor = img_meta[0]['scale_factor']
        if det_bboxes.shape[0] == 0:
            segm_result = [[] for _ in range(self.mask_head.num_classes - 1)]
        else:
            # if det_bboxes is rescaled to the original image size, we need to
            # rescale it back to the testing scale to obtain RoIs.
            _bboxes = (det_bboxes[:, :4] * scale_factor
                       if rescale else det_bboxes)
            mask_rois = bbox2roi([_bboxes])
            mask_feats = self.mask_roi_extractor(
                x[:len(self.mask_roi_extractor.featmap_strides)], mask_rois)
            mask_pred = self.mask_head(mask_feats)
            segm_result = self.mask_head.get_seg_masks(
                mask_pred, det_bboxes, det_labels, img_shape,
                self.rcnn_test_cfg, rescale)
        return segm_result

    def aug_test_mask(self,
                      feats,
                      img_metas,
                      det_bboxes,
                      det_labels,
                      rescale=False):
        """Predict masks on every augmentation and merge them.

        Args:
            feats: One multi-level feature entry per augmentation.
            img_metas: One entry per augmentation; only the first meta of
                each entry is read (single image per batch assumed).
            det_bboxes: 2-D detections; the first four columns are box
                coordinates. Must support ``clone()`` when ``rescale`` is
                false.
            det_labels: Labels forwarded to ``get_seg_masks``.
            rescale (bool): When false, the boxes handed to
                ``get_seg_masks`` are multiplied by the first augmentation's
                ``scale_factor``; when true they are passed through as-is.

        Returns:
            The result of ``self.mask_head.get_seg_masks``, or a list of
            ``num_classes - 1`` empty lists when there are no detections.
        """
        # Nothing to segment: return one empty list per class slot.
        if det_bboxes.shape[0] == 0:
            return [[] for _ in range(self.mask_head.num_classes - 1)]

        first_meta = img_metas[0][0]
        if rescale:
            _det_bboxes = det_bboxes
        else:
            # Work on a copy so the caller's detections are not scaled
            # in place.
            _det_bboxes = det_bboxes.clone()
            _det_bboxes[:, :4] *= first_meta['scale_factor']

        aug_masks = []
        for x, img_meta in zip(feats, img_metas):
            img_shape = img_meta[0]['img_shape']
            scale_factor = img_meta[0]['scale_factor']
            flip = img_meta[0]['flip']
            _bboxes = bbox_mapping(det_bboxes[:, :4], img_shape, scale_factor,
                                   flip)
            mask_rois = bbox2roi([_bboxes])
            mask_feats = self.mask_roi_extractor(
                x[:len(self.mask_roi_extractor.featmap_strides)], mask_rois)
            mask_pred = self.mask_head(mask_feats)
            # convert to numpy array to save memory
            aug_masks.append(mask_pred.sigmoid().cpu().numpy())
        merged_masks = merge_aug_masks(aug_masks, img_metas,
                                       self.rcnn_test_cfg)

        # ``img_metas[0]`` is a sequence of meta dicts (it is indexed with
        # ``[0]`` above), so the shape is read from its first dict, matching
        # the ``img_shape`` argument passed in ``simple_test_mask``.
        # NOTE(review): confirm the first augmentation's ``img_shape`` is the
        # shape ``get_seg_masks`` expects here.
        segm_result = self.mask_head.get_seg_masks(
            merged_masks, _det_bboxes, det_labels, first_meta['img_shape'],
            self.rcnn_test_cfg, rescale)
        return segm_result