Commit 364698b6 authored by Cao Yuhang's avatar Cao Yuhang Committed by Kai Chen
Browse files

Remove keep all stage code in HTC and Cascade RCNN (#1806)

* Remove keep all stage code

* remove keep_all_stage in config
parent 4357697a
......@@ -193,8 +193,7 @@ test_cfg = dict(
score_thr=0.001,
nms=dict(type='nms', iou_thr=0.5),
max_per_img=100,
mask_thr_binary=0.5),
keep_all_stages=False)
mask_thr_binary=0.5))
# dataset settings
dataset_type = 'CocoDataset'
data_root = 'data/coco/'
......
......@@ -193,8 +193,7 @@ test_cfg = dict(
score_thr=0.001,
nms=dict(type='nms', iou_thr=0.5),
max_per_img=100,
mask_thr_binary=0.5),
keep_all_stages=False)
mask_thr_binary=0.5))
# dataset settings
dataset_type = 'CocoDataset'
data_root = 'data/coco/'
......
......@@ -339,41 +339,6 @@ class CascadeRCNN(BaseDetector, RPNTestMixin):
cls_score, bbox_pred = bbox_head(bbox_feats)
ms_scores.append(cls_score)
if self.test_cfg.keep_all_stages:
det_bboxes, det_labels = bbox_head.get_det_bboxes(
rois,
cls_score,
bbox_pred,
img_shape,
scale_factor,
rescale=rescale,
cfg=rcnn_test_cfg)
bbox_result = bbox2result(det_bboxes, det_labels,
bbox_head.num_classes)
ms_bbox_result['stage{}'.format(i)] = bbox_result
if self.with_mask:
mask_roi_extractor = self.mask_roi_extractor[i]
mask_head = self.mask_head[i]
if det_bboxes.shape[0] == 0:
mask_classes = mask_head.num_classes - 1
segm_result = [[] for _ in range(mask_classes)]
else:
_bboxes = (
det_bboxes[:, :4] *
scale_factor if rescale else det_bboxes)
mask_rois = bbox2roi([_bboxes])
mask_feats = mask_roi_extractor(
x[:len(mask_roi_extractor.featmap_strides)],
mask_rois)
if self.with_shared_head:
mask_feats = self.shared_head(mask_feats, i)
mask_pred = mask_head(mask_feats)
segm_result = mask_head.get_seg_masks(
mask_pred, _bboxes, det_labels, rcnn_test_cfg,
ori_shape, scale_factor, rescale)
ms_segm_result['stage{}'.format(i)] = segm_result
if i < self.num_stages - 1:
bbox_label = cls_score.argmax(dim=1)
rois = bbox_head.regress_by_class(rois, bbox_label, bbox_pred,
......@@ -425,20 +390,10 @@ class CascadeRCNN(BaseDetector, RPNTestMixin):
ori_shape, scale_factor, rescale)
ms_segm_result['ensemble'] = segm_result
if not self.test_cfg.keep_all_stages:
if self.with_mask:
results = (ms_bbox_result['ensemble'],
ms_segm_result['ensemble'])
else:
results = ms_bbox_result['ensemble']
if self.with_mask:
results = (ms_bbox_result['ensemble'], ms_segm_result['ensemble'])
else:
if self.with_mask:
results = {
stage: (ms_bbox_result[stage], ms_segm_result[stage])
for stage in ms_bbox_result
}
else:
results = ms_bbox_result
results = ms_bbox_result['ensemble']
return results
......
......@@ -334,35 +334,6 @@ class HybridTaskCascade(CascadeRCNN):
i, x, rois, semantic_feat=semantic_feat)
ms_scores.append(cls_score)
if self.test_cfg.keep_all_stages:
det_bboxes, det_labels = bbox_head.get_det_bboxes(
rois,
cls_score,
bbox_pred,
img_shape,
scale_factor,
rescale=rescale,
cfg=rcnn_test_cfg)
bbox_result = bbox2result(det_bboxes, det_labels,
bbox_head.num_classes)
ms_bbox_result['stage{}'.format(i)] = bbox_result
if self.with_mask:
mask_head = self.mask_head[i]
if det_bboxes.shape[0] == 0:
mask_classes = mask_head.num_classes - 1
segm_result = [[] for _ in range(mask_classes)]
else:
_bboxes = (
det_bboxes[:, :4] *
scale_factor if rescale else det_bboxes)
mask_pred = self._mask_forward_test(
i, x, _bboxes, semantic_feat=semantic_feat)
segm_result = mask_head.get_seg_masks(
mask_pred, _bboxes, det_labels, rcnn_test_cfg,
ori_shape, scale_factor, rescale)
ms_segm_result['stage{}'.format(i)] = segm_result
if i < self.num_stages - 1:
bbox_label = cls_score.argmax(dim=1)
rois = bbox_head.regress_by_class(rois, bbox_label, bbox_pred,
......@@ -415,20 +386,10 @@ class HybridTaskCascade(CascadeRCNN):
ori_shape, scale_factor, rescale)
ms_segm_result['ensemble'] = segm_result
if not self.test_cfg.keep_all_stages:
if self.with_mask:
results = (ms_bbox_result['ensemble'],
ms_segm_result['ensemble'])
else:
results = ms_bbox_result['ensemble']
if self.with_mask:
results = (ms_bbox_result['ensemble'], ms_segm_result['ensemble'])
else:
if self.with_mask:
results = {
stage: (ms_bbox_result[stage], ms_segm_result[stage])
for stage in ms_bbox_result
}
else:
results = ms_bbox_result
results = ms_bbox_result['ensemble']
return results
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment