Commit af55b977 authored by Kai Chen's avatar Kai Chen
Browse files

use dict to save multi-stage results

parent 0bb723a5
...@@ -142,7 +142,7 @@ train_cfg = dict( ...@@ -142,7 +142,7 @@ train_cfg = dict(
pos_weight=-1, pos_weight=-1,
debug=False) debug=False)
], ],
loss_weight=[1, 0.5, 0.4]) stage_loss_weights=[1, 0.5, 0.25])
test_cfg = dict( test_cfg = dict(
rpn=dict( rpn=dict(
nms_across_levels=False, nms_across_levels=False,
......
...@@ -128,7 +128,7 @@ train_cfg = dict( ...@@ -128,7 +128,7 @@ train_cfg = dict(
pos_weight=-1, pos_weight=-1,
debug=False) debug=False)
], ],
loss_weight=[1, 0.5, 0.4]) stage_loss_weights=[1, 0.5, 0.25])
test_cfg = dict( test_cfg = dict(
rpn=dict( rpn=dict(
nms_across_levels=False, nms_across_levels=False,
......
from __future__ import division
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -127,7 +129,7 @@ class CascadeRCNN(BaseDetector, RPNTestMixin): ...@@ -127,7 +129,7 @@ class CascadeRCNN(BaseDetector, RPNTestMixin):
for i in range(self.num_stages): for i in range(self.num_stages):
rcnn_train_cfg = self.train_cfg.rcnn[i] rcnn_train_cfg = self.train_cfg.rcnn[i]
lw = self.train_cfg.loss_weight[i] lw = self.train_cfg.stage_loss_weights[i]
# assign gts and sample proposals # assign gts and sample proposals
assign_results, sampling_results = multi_apply( assign_results, sampling_results = multi_apply(
...@@ -193,8 +195,8 @@ class CascadeRCNN(BaseDetector, RPNTestMixin): ...@@ -193,8 +195,8 @@ class CascadeRCNN(BaseDetector, RPNTestMixin):
scale_factor = img_meta[0]['scale_factor'] scale_factor = img_meta[0]['scale_factor']
# "ms" in variable names means multi-stage # "ms" in variable names means multi-stage
ms_bbox_result = [] ms_bbox_result = {}
ms_segm_result = [] ms_segm_result = {}
ms_scores = [] ms_scores = []
rcnn_test_cfg = self.test_cfg.rcnn rcnn_test_cfg = self.test_cfg.rcnn
...@@ -219,11 +221,11 @@ class CascadeRCNN(BaseDetector, RPNTestMixin): ...@@ -219,11 +221,11 @@ class CascadeRCNN(BaseDetector, RPNTestMixin):
nms_cfg=rcnn_test_cfg) nms_cfg=rcnn_test_cfg)
bbox_result = bbox2result(det_bboxes, det_labels, bbox_result = bbox2result(det_bboxes, det_labels,
bbox_head.num_classes) bbox_head.num_classes)
ms_bbox_result.append(bbox_result) ms_bbox_result['stage{}'.format(i)] = bbox_result
if self.with_mask: if self.with_mask:
mask_block = self.mask_blocks[i] mask_roi_extractor = self.mask_roi_extractor[i]
mask_head = self.mask_heads[i] mask_head = self.mask_head[i]
if det_bboxes.shape[0] == 0: if det_bboxes.shape[0] == 0:
segm_result = [ segm_result = [
[] for _ in range(mask_head.num_classes - 1) [] for _ in range(mask_head.num_classes - 1)
...@@ -232,20 +234,21 @@ class CascadeRCNN(BaseDetector, RPNTestMixin): ...@@ -232,20 +234,21 @@ class CascadeRCNN(BaseDetector, RPNTestMixin):
_bboxes = (det_bboxes[:, :4] * scale_factor _bboxes = (det_bboxes[:, :4] * scale_factor
if rescale else det_bboxes) if rescale else det_bboxes)
mask_rois = bbox2roi([_bboxes]) mask_rois = bbox2roi([_bboxes])
mask_feats = mask_block( mask_feats = mask_roi_extractor(
x[:len(mask_block.featmap_strides)], mask_rois) x[:len(mask_roi_extractor.featmap_strides)],
mask_rois)
mask_pred = mask_head(mask_feats) mask_pred = mask_head(mask_feats)
segm_result = mask_head.get_seg_masks( segm_result = mask_head.get_seg_masks(
mask_pred, _bboxes, det_labels, rcnn_test_cfg, mask_pred, _bboxes, det_labels, rcnn_test_cfg,
ori_shape, scale_factor, rescale) ori_shape, scale_factor, rescale)
ms_segm_result.append(segm_result) ms_segm_result['stage{}'.format(i)] = segm_result
if i < self.num_stages - 1: if i < self.num_stages - 1:
bbox_label = cls_score.argmax(dim=1) bbox_label = cls_score.argmax(dim=1)
rois = bbox_head.regress_by_class(rois, bbox_label, bbox_pred, rois = bbox_head.regress_by_class(rois, bbox_label, bbox_pred,
img_meta[0]) img_meta[0])
cls_score = sum(ms_scores) / float(len(ms_scores)) cls_score = sum(ms_scores) / len(ms_scores)
det_bboxes, det_labels = self.bbox_head[-1].get_det_bboxes( det_bboxes, det_labels = self.bbox_head[-1].get_det_bboxes(
rois, rois,
cls_score, cls_score,
...@@ -256,7 +259,7 @@ class CascadeRCNN(BaseDetector, RPNTestMixin): ...@@ -256,7 +259,7 @@ class CascadeRCNN(BaseDetector, RPNTestMixin):
nms_cfg=rcnn_test_cfg) nms_cfg=rcnn_test_cfg)
bbox_result = bbox2result(det_bboxes, det_labels, bbox_result = bbox2result(det_bboxes, det_labels,
self.bbox_head[-1].num_classes) self.bbox_head[-1].num_classes)
ms_bbox_result.append(bbox_result) ms_bbox_result['ensemble'] = bbox_result
if self.with_mask: if self.with_mask:
if det_bboxes.shape[0] == 0: if det_bboxes.shape[0] == 0:
...@@ -280,12 +283,12 @@ class CascadeRCNN(BaseDetector, RPNTestMixin): ...@@ -280,12 +283,12 @@ class CascadeRCNN(BaseDetector, RPNTestMixin):
segm_result = self.mask_head[-1].get_seg_masks( segm_result = self.mask_head[-1].get_seg_masks(
merged_masks, _bboxes, det_labels, rcnn_test_cfg, merged_masks, _bboxes, det_labels, rcnn_test_cfg,
ori_shape, scale_factor, rescale) ori_shape, scale_factor, rescale)
ms_segm_result.append(segm_result) ms_segm_result['ensemble'] = segm_result
if not self.test_cfg.keep_all_stages: if not self.test_cfg.keep_all_stages:
ms_bbox_result = ms_bbox_result[0] ms_bbox_result = ms_bbox_result['ensemble']
if self.with_mask: if self.with_mask:
ms_segm_result = ms_segm_result[0] ms_segm_result = ms_segm_result['ensemble']
if not self.with_mask: if not self.with_mask:
return ms_bbox_result return ms_bbox_result
...@@ -301,5 +304,9 @@ class CascadeRCNN(BaseDetector, RPNTestMixin): ...@@ -301,5 +304,9 @@ class CascadeRCNN(BaseDetector, RPNTestMixin):
ms_bbox_result, ms_segm_result = result ms_bbox_result, ms_segm_result = result
else: else:
ms_bbox_result = result ms_bbox_result = result
super(CascadeRCNN, self).show_result(data, ms_bbox_result[-1], if isinstance(ms_bbox_result, dict):
img_norm_cfg, **kwargs) bbox_result = ms_bbox_result['ensemble']
else:
bbox_result = ms_bbox_result
super(CascadeRCNN, self).show_result(data, bbox_result, img_norm_cfg,
**kwargs)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment