from __future__ import division

import torch
import torch.nn as nn

from .base import BaseDetector
from .test_mixins import RPNTestMixin
from .. import builder
from mmdet.core import (assign_and_sample, bbox2roi, bbox2result, multi_apply,
                        merge_aug_masks)


class CascadeRCNN(BaseDetector, RPNTestMixin):
    """Detector with ``num_stages`` cascaded bbox (and optional mask) heads.

    The model is assembled from config objects through ``builder``: a
    backbone, a mandatory neck, an optional RPN head, and one RoI
    extractor / head pair per cascade stage.
    """

    def __init__(self,
                 num_stages,
                 backbone,
                 neck=None,
                 rpn_head=None,
                 bbox_roi_extractor=None,
                 bbox_head=None,
                 mask_roi_extractor=None,
                 mask_head=None,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None):
        """Build all sub-modules and initialise their weights.

        Args:
            num_stages: Number of cascade stages.
            backbone: Config passed to ``builder.build_backbone``.
            neck: Config passed to ``builder.build_neck``. Required.
            rpn_head: Optional config passed to ``builder.build_rpn_head``.
            bbox_roi_extractor: A single config (reused for every stage) or
                a list with one config per stage. Required.
            bbox_head: A single config or a per-stage list. Required.
            mask_roi_extractor: A single config or a per-stage list; only
                used when ``mask_head`` is given.
            mask_head: Optional single config or per-stage list.
            train_cfg: Training config, stored as ``self.train_cfg``.
            test_cfg: Testing config, stored as ``self.test_cfg``.
            pretrained: Forwarded to ``init_weights``.

        Raises:
            AssertionError: If ``bbox_roi_extractor`` or ``bbox_head`` is
                missing, or a per-stage list has the wrong length.
            NotImplementedError: If ``neck`` is None.
        """
        assert bbox_roi_extractor is not None
        assert bbox_head is not None
        super(CascadeRCNN, self).__init__()

        self.num_stages = num_stages
        self.backbone = builder.build_backbone(backbone)

        if neck is not None:
            self.neck = builder.build_neck(neck)
        else:
            # Running without a neck is not supported by this detector.
            raise NotImplementedError

        if rpn_head is not None:
            self.rpn_head = builder.build_rpn_head(rpn_head)

        if bbox_head is not None:
            bbox_roi_extractor = self._expand_to_stages(
                bbox_roi_extractor, num_stages)
            bbox_head = self._expand_to_stages(bbox_head, num_stages)
            assert len(bbox_roi_extractor) == len(bbox_head) == \
                self.num_stages
            self.bbox_roi_extractor = nn.ModuleList()
            self.bbox_head = nn.ModuleList()
            for roi_extractor, head in zip(bbox_roi_extractor, bbox_head):
                self.bbox_roi_extractor.append(
                    builder.build_roi_extractor(roi_extractor))
                self.bbox_head.append(builder.build_bbox_head(head))

        if mask_head is not None:
            mask_roi_extractor = self._expand_to_stages(
                mask_roi_extractor, num_stages)
            mask_head = self._expand_to_stages(mask_head, num_stages)
            assert len(mask_roi_extractor) == len(mask_head) == \
                self.num_stages
            self.mask_roi_extractor = nn.ModuleList()
            self.mask_head = nn.ModuleList()
            for roi_extractor, head in zip(mask_roi_extractor, mask_head):
                self.mask_roi_extractor.append(
                    builder.build_roi_extractor(roi_extractor))
                self.mask_head.append(builder.build_mask_head(head))

        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

        self.init_weights(pretrained=pretrained)

    @staticmethod
    def _expand_to_stages(cfg, num_stages):
        """Return ``cfg`` unchanged if it is a list, else repeat it per stage."""
        if isinstance(cfg, list):
            return cfg
        return [cfg for _ in range(num_stages)]

    @property
    def with_rpn(self):
        """bool: Whether an RPN head was built for this detector."""
        return hasattr(self, 'rpn_head') and self.rpn_head is not None

    def init_weights(self, pretrained=None):
        """Initialise the weights of every sub-module.

        Args:
            pretrained: Forwarded to the parent class and to the backbone's
                ``init_weights``.
        """
        super(CascadeRCNN, self).init_weights(pretrained)
        self.backbone.init_weights(pretrained=pretrained)
        if self.with_neck:
            if isinstance(self.neck, nn.Sequential):
                # A sequential neck is initialised module by module.
                for m in self.neck:
                    m.init_weights()
            else:
                self.neck.init_weights()
        if self.with_rpn:
            self.rpn_head.init_weights()
        for i in range(self.num_stages):
            if self.with_bbox:
                self.bbox_roi_extractor[i].init_weights()
                self.bbox_head[i].init_weights()
            if self.with_mask:
                self.mask_roi_extractor[i].init_weights()
                self.mask_head[i].init_weights()

    def extract_feat(self, img):
        """Run the backbone, then the neck when present, on ``img``."""
        x = self.backbone(img)
        if self.with_neck:
            x = self.neck(x)
        return x

    def forward_train(self,
                      img,
                      img_meta,
                      gt_bboxes,
                      gt_bboxes_ignore,
                      gt_labels,
                      gt_masks=None,
                      proposals=None):
        """Compute the training losses of the RPN and of every stage.

        Args:
            img: Input passed to ``extract_feat``.
            img_meta: Image meta information forwarded to the RPN head and
                to ``refine_bboxes``.
            gt_bboxes: Ground-truth boxes, one entry per image.
            gt_bboxes_ignore: Boxes to ignore during assignment.
            gt_labels: Ground-truth labels, one entry per image.
            gt_masks: Ground-truth masks; used only when mask heads exist.
            proposals: Proposals used when the detector has no RPN head.

        Returns:
            dict: RPN losses (if any) plus per-stage entries keyed
            ``'s{stage}.{name}'``.
        """
        x = self.extract_feat(img)

        losses = dict()

        if self.with_rpn:
            rpn_outs = self.rpn_head(x)
            rpn_loss_inputs = rpn_outs + (gt_bboxes, img_meta,
                                          self.train_cfg.rpn)
            rpn_losses = self.rpn_head.loss(*rpn_loss_inputs)
            losses.update(rpn_losses)

            # NOTE(review): proposals are generated with ``test_cfg.rpn``
            # even though this is the training path -- confirm intended.
            proposal_inputs = rpn_outs + (img_meta, self.test_cfg.rpn)
            proposal_list = self.rpn_head.get_proposals(*proposal_inputs)
        else:
            proposal_list = proposals

        for i in range(self.num_stages):
            rcnn_train_cfg = self.train_cfg.rcnn[i]
            lw = self.train_cfg.stage_loss_weights[i]

            # assign gts and sample proposals
            _, sampling_results = multi_apply(
                assign_and_sample,
                proposal_list,
                gt_bboxes,
                gt_bboxes_ignore,
                gt_labels,
                cfg=rcnn_train_cfg)

            # bbox head forward and loss
            bbox_roi_extractor = self.bbox_roi_extractor[i]
            bbox_head = self.bbox_head[i]

            rois = bbox2roi([res.bboxes for res in sampling_results])
            bbox_feats = bbox_roi_extractor(
                x[:bbox_roi_extractor.num_inputs], rois)
            cls_score, bbox_pred = bbox_head(bbox_feats)

            bbox_targets = bbox_head.get_target(sampling_results, gt_bboxes,
                                                gt_labels, rcnn_train_cfg)
            loss_bbox = bbox_head.loss(cls_score, bbox_pred, *bbox_targets)
            for name, value in loss_bbox.items():
                # Only entries whose name contains 'loss' are scaled by the
                # stage weight; anything else is stored as returned.
                losses['s{}.{}'.format(i, name)] = (
                    value * lw if 'loss' in name else value)

            # mask head forward and loss
            if self.with_mask:
                mask_roi_extractor = self.mask_roi_extractor[i]
                mask_head = self.mask_head[i]
                pos_rois = bbox2roi(
                    [res.pos_bboxes for res in sampling_results])
                mask_feats = mask_roi_extractor(
                    x[:mask_roi_extractor.num_inputs], pos_rois)
                mask_pred = mask_head(mask_feats)
                mask_targets = mask_head.get_target(sampling_results,
                                                    gt_masks, rcnn_train_cfg)
                pos_labels = torch.cat(
                    [res.pos_gt_labels for res in sampling_results])
                loss_mask = mask_head.loss(mask_pred, mask_targets,
                                           pos_labels)
                for name, value in loss_mask.items():
                    losses['s{}.{}'.format(i, name)] = (
                        value * lw if 'loss' in name else value)

            # refine bboxes: the next stage trains on this stage's output
            if i < self.num_stages - 1:
                pos_is_gts = [res.pos_is_gt for res in sampling_results]
                roi_labels = bbox_targets[0]  # bbox_targets is a tuple
                with torch.no_grad():
                    proposal_list = bbox_head.refine_bboxes(
                        rois, roi_labels, bbox_pred, pos_is_gts, img_meta)

        return losses

    def simple_test(self, img, img_meta, proposals=None, rescale=False):
        """Run inference without test-time augmentation.

        Only ``img_meta[0]`` is read for the image shape, original shape
        and scale factor.

        Args:
            img: Input passed to ``extract_feat``.
            img_meta: Image meta information.
            proposals: Precomputed proposals; when None they come from
                ``simple_test_rpn``.
            rescale: Forwarded to the heads' post-processing.

        Returns:
            With ``test_cfg.keep_all_stages`` false: the ensemble bbox
            result, or an ``(bbox_result, segm_result)`` tuple when mask
            heads exist. With it true: a dict keyed ``'stage{i}'`` and
            ``'ensemble'`` whose values are bbox results, or
            ``(bbox_result, segm_result)`` tuples when mask heads exist.
        """
        x = self.extract_feat(img)
        proposal_list = self.simple_test_rpn(
            x, img_meta, self.test_cfg.rpn) if proposals is None else proposals

        img_shape = img_meta[0]['img_shape']
        ori_shape = img_meta[0]['ori_shape']
        scale_factor = img_meta[0]['scale_factor']

        # "ms" in variable names means multi-stage
        ms_bbox_result = {}
        ms_segm_result = {}
        ms_scores = []
        rcnn_test_cfg = self.test_cfg.rcnn

        rois = bbox2roi(proposal_list)
        for i in range(self.num_stages):
            bbox_roi_extractor = self.bbox_roi_extractor[i]
            bbox_head = self.bbox_head[i]

            bbox_feats = bbox_roi_extractor(
                x[:len(bbox_roi_extractor.featmap_strides)], rois)
            cls_score, bbox_pred = bbox_head(bbox_feats)
            ms_scores.append(cls_score)

            if self.test_cfg.keep_all_stages:
                # Also record the detections of this individual stage.
                det_bboxes, det_labels = bbox_head.get_det_bboxes(
                    rois,
                    cls_score,
                    bbox_pred,
                    img_shape,
                    scale_factor,
                    rescale=rescale,
                    nms_cfg=rcnn_test_cfg)
                bbox_result = bbox2result(det_bboxes, det_labels,
                                          bbox_head.num_classes)
                ms_bbox_result['stage{}'.format(i)] = bbox_result

                if self.with_mask:
                    mask_roi_extractor = self.mask_roi_extractor[i]
                    mask_head = self.mask_head[i]
                    if det_bboxes.shape[0] == 0:
                        # No detections: one empty list per class slot.
                        segm_result = [
                            [] for _ in range(mask_head.num_classes - 1)
                        ]
                    else:
                        _bboxes = (det_bboxes[:, :4] * scale_factor
                                   if rescale else det_bboxes)
                        mask_rois = bbox2roi([_bboxes])
                        mask_feats = mask_roi_extractor(
                            x[:len(mask_roi_extractor.featmap_strides)],
                            mask_rois)
                        mask_pred = mask_head(mask_feats)
                        segm_result = mask_head.get_seg_masks(
                            mask_pred, _bboxes, det_labels, rcnn_test_cfg,
                            ori_shape, scale_factor, rescale)
                    ms_segm_result['stage{}'.format(i)] = segm_result

            if i < self.num_stages - 1:
                # Feed the boxes regressed for the top-scoring class into
                # the next stage.
                bbox_label = cls_score.argmax(dim=1)
                rois = bbox_head.regress_by_class(rois, bbox_label, bbox_pred,
                                                  img_meta[0])

        # Ensemble: average the classification scores of all stages and
        # decode them with the last stage's rois and bbox predictions.
        cls_score = sum(ms_scores) / self.num_stages
        det_bboxes, det_labels = self.bbox_head[-1].get_det_bboxes(
            rois,
            cls_score,
            bbox_pred,
            img_shape,
            scale_factor,
            rescale=rescale,
            nms_cfg=rcnn_test_cfg)
        bbox_result = bbox2result(det_bboxes, det_labels,
                                  self.bbox_head[-1].num_classes)
        ms_bbox_result['ensemble'] = bbox_result

        if self.with_mask:
            if det_bboxes.shape[0] == 0:
                segm_result = [
                    [] for _ in range(self.mask_head[-1].num_classes - 1)
                ]
            else:
                _bboxes = (det_bboxes[:, :4] * scale_factor
                           if rescale else det_bboxes)
                mask_rois = bbox2roi([_bboxes])
                # Every stage's mask head predicts on the final boxes; the
                # per-stage predictions are then merged.
                aug_masks = []
                for stage in range(self.num_stages):
                    mask_roi_extractor = self.mask_roi_extractor[stage]
                    mask_feats = mask_roi_extractor(
                        x[:len(mask_roi_extractor.featmap_strides)],
                        mask_rois)
                    mask_pred = self.mask_head[stage](mask_feats)
                    aug_masks.append(mask_pred.sigmoid().cpu().numpy())
                merged_masks = merge_aug_masks(aug_masks,
                                               [img_meta] * self.num_stages,
                                               self.test_cfg.rcnn)
                segm_result = self.mask_head[-1].get_seg_masks(
                    merged_masks, _bboxes, det_labels, rcnn_test_cfg,
                    ori_shape, scale_factor, rescale)
            ms_segm_result['ensemble'] = segm_result

        if not self.test_cfg.keep_all_stages:
            if self.with_mask:
                results = (ms_bbox_result['ensemble'],
                           ms_segm_result['ensemble'])
            else:
                results = ms_bbox_result['ensemble']
        else:
            if self.with_mask:
                results = {
                    stage: (ms_bbox_result[stage], ms_segm_result[stage])
                    for stage in ms_bbox_result
                }
            else:
                results = ms_bbox_result

        return results

    def aug_test(self, img, img_meta, proposals=None, rescale=False):
        """Test-time augmentation is not implemented for this detector."""
        raise NotImplementedError

    def show_result(self, data, result, img_norm_cfg, **kwargs):
        """Visualise the ensemble bbox result through the parent class.

        Args:
            data: Forwarded unchanged to the parent ``show_result``.
            result: Any value ``simple_test`` can return: a bbox result, a
                ``(bbox_result, segm_result)`` tuple, or the per-stage dict
                produced when ``test_cfg.keep_all_stages`` is set.
            img_norm_cfg: Forwarded unchanged to the parent ``show_result``.
            **kwargs: Forwarded unchanged to the parent ``show_result``.
        """
        # TODO: show segmentation masks
        if isinstance(result, dict):
            # With keep_all_stages, simple_test returns a dict keyed by
            # stage name. Pick the ensemble entry *before* any tuple
            # unpacking: unpacking the dict itself would iterate its keys
            # and fail (or silently bind key strings).
            result = result['ensemble']

        if self.with_mask:
            ms_bbox_result, _ = result
        else:
            ms_bbox_result = result

        # Kept for callers that pass a dict of bbox results inside a tuple.
        if isinstance(ms_bbox_result, dict):
            bbox_result = ms_bbox_result['ensemble']
        else:
            bbox_result = ms_bbox_result
        super(CascadeRCNN, self).show_result(data, bbox_result, img_norm_cfg,
                                             **kwargs)