Commit 427c8902 authored by pangjm's avatar pangjm
Browse files

Add Faster R-CNN & Mask R-CNN training APIs and some related tests

parent 65642939
from .geometry import bbox_overlaps from .geometry import bbox_overlaps
from .sampling import (random_choice, bbox_assign, bbox_assign_via_overlaps, from .sampling import (random_choice, bbox_assign, bbox_assign_via_overlaps,
bbox_sampling, sample_positives, sample_negatives, bbox_sampling, sample_positives, sample_negatives)
sample_proposals)
from .transforms import (bbox_transform, bbox_transform_inv, bbox_flip, from .transforms import (bbox_transform, bbox_transform_inv, bbox_flip,
bbox_mapping, bbox_mapping_back, bbox2roi, roi2bbox, bbox_mapping, bbox_mapping_back, bbox2roi, roi2bbox,
bbox2result) bbox2result)
...@@ -12,5 +11,5 @@ __all__ = [ ...@@ -12,5 +11,5 @@ __all__ = [
'bbox_assign_via_overlaps', 'bbox_sampling', 'sample_positives', 'bbox_assign_via_overlaps', 'bbox_sampling', 'sample_positives',
'sample_negatives', 'bbox_transform', 'bbox_transform_inv', 'bbox_flip', 'sample_negatives', 'bbox_transform', 'bbox_transform_inv', 'bbox_flip',
'bbox_mapping', 'bbox_mapping_back', 'bbox2roi', 'roi2bbox', 'bbox2result', 'bbox_mapping', 'bbox_mapping_back', 'bbox2roi', 'roi2bbox', 'bbox2result',
'bbox_target', 'sample_proposals' 'bbox_target'
] ]
...@@ -58,7 +58,7 @@ def mask_cross_entropy(pred, target, label): ...@@ -58,7 +58,7 @@ def mask_cross_entropy(pred, target, label):
inds = torch.arange(0, num_rois, dtype=torch.long, device=pred.device) inds = torch.arange(0, num_rois, dtype=torch.long, device=pred.device)
pred_slice = pred[inds, label].squeeze(1) pred_slice = pred[inds, label].squeeze(1)
return F.binary_cross_entropy_with_logits( return F.binary_cross_entropy_with_logits(
pred_slice, target, reduction='sum')[None] pred_slice, target, reduction='elementwise_mean')[None]
def weighted_mask_cross_entropy(pred, target, weight, label): def weighted_mask_cross_entropy(pred, target, weight, label):
......
from .segms import (flip_segms, polys_to_mask, mask_to_bbox, from .segms import (flip_segms, polys_to_mask, mask_to_bbox,
polys_to_mask_wrt_box, polys_to_boxes, rle_mask_voting, polys_to_mask_wrt_box, polys_to_boxes, rle_mask_voting,
rle_mask_nms, rle_masks_to_boxes) rle_mask_nms, rle_masks_to_boxes)
from .utils import split_combined_gt_polys from .utils import split_combined_polys
from .mask_target import mask_target from .mask_target import mask_target
__all__ = [ __all__ = [
'flip_segms', 'polys_to_mask', 'mask_to_bbox', 'polys_to_mask_wrt_box', 'flip_segms', 'polys_to_mask', 'mask_to_bbox', 'polys_to_mask_wrt_box',
'polys_to_boxes', 'rle_mask_voting', 'rle_mask_nms', 'rle_masks_to_boxes', 'polys_to_boxes', 'rle_mask_voting', 'rle_mask_nms', 'rle_masks_to_boxes',
'split_combined_gt_polys', 'mask_target' 'split_combined_polys', 'mask_target'
] ]
...@@ -4,27 +4,31 @@ import numpy as np ...@@ -4,27 +4,31 @@ import numpy as np
from .segms import polys_to_mask_wrt_box from .segms import polys_to_mask_wrt_box
def mask_target(pos_proposals_list, pos_assigned_gt_inds_list, gt_polys_list, def mask_target(pos_proposals_list,
img_meta, cfg): pos_assigned_gt_inds_list,
gt_polys_list,
img_meta,
cfg):
cfg_list = [cfg for _ in range(len(pos_proposals_list))] cfg_list = [cfg for _ in range(len(pos_proposals_list))]
img_metas = [img_meta for _ in range(len(pos_proposals_list))]
mask_targets = map(mask_target_single, pos_proposals_list, mask_targets = map(mask_target_single, pos_proposals_list,
pos_assigned_gt_inds_list, gt_polys_list, img_metas, pos_assigned_gt_inds_list, gt_polys_list, img_meta,
cfg_list) cfg_list)
mask_targets = torch.cat(tuple(mask_targets), dim=0) mask_targets = torch.cat(tuple(mask_targets), dim=0)
return mask_targets return mask_targets
def mask_target_single(pos_proposals, pos_assigned_gt_inds, gt_polys, def mask_target_single(pos_proposals,
img_meta, cfg): pos_assigned_gt_inds,
gt_polys,
img_meta,
cfg):
mask_size = cfg.mask_size mask_size = cfg.mask_size
num_pos = pos_proposals.size(0) num_pos = pos_proposals.size(0)
mask_targets = pos_proposals.new_zeros((num_pos, mask_size, mask_size)) mask_targets = pos_proposals.new_zeros((num_pos, mask_size, mask_size))
if num_pos > 0: if num_pos > 0:
pos_proposals = pos_proposals.cpu().numpy() pos_proposals = pos_proposals.cpu().numpy()
pos_assigned_gt_inds = pos_assigned_gt_inds.cpu().numpy() pos_assigned_gt_inds = pos_assigned_gt_inds.cpu().numpy()
scale_factor = img_meta['scale_factor'][0].cpu().numpy() scale_factor = img_meta['scale_factor']
for i in range(num_pos): for i in range(num_pos):
bbox = pos_proposals[i, :] / scale_factor bbox = pos_proposals[i, :] / scale_factor
polys = gt_polys[pos_assigned_gt_inds[i]] polys = gt_polys[pos_assigned_gt_inds[i]]
......
import mmcv import mmcv
def split_combined_gt_polys(gt_polys, gt_poly_lens, num_polys_per_mask): def split_combined_polys(polys, poly_lens, polys_per_mask):
"""Split the combined 1-D polys into masks. """Split the combined 1-D polys into masks.
A mask is represented as a list of polys, and a poly is represented as A mask is represented as a list of polys, and a poly is represented as
...@@ -9,9 +9,9 @@ def split_combined_gt_polys(gt_polys, gt_poly_lens, num_polys_per_mask): ...@@ -9,9 +9,9 @@ def split_combined_gt_polys(gt_polys, gt_poly_lens, num_polys_per_mask):
tensor. Here we need to split the tensor into original representations. tensor. Here we need to split the tensor into original representations.
Args: Args:
gt_polys (list): a list (length = image num) of 1-D tensors polys (list): a list (length = image num) of 1-D tensors
gt_poly_lens (list): a list (length = image num) of poly length poly_lens (list): a list (length = image num) of poly length
num_polys_per_mask (list): a list (length = image num) of poly number polys_per_mask (list): a list (length = image num) of poly number
of each mask of each mask
Returns: Returns:
...@@ -19,13 +19,12 @@ def split_combined_gt_polys(gt_polys, gt_poly_lens, num_polys_per_mask): ...@@ -19,13 +19,12 @@ def split_combined_gt_polys(gt_polys, gt_poly_lens, num_polys_per_mask):
list (length = poly num) of numpy array list (length = poly num) of numpy array
""" """
mask_polys_list = [] mask_polys_list = []
for img_id in range(len(gt_polys)): for img_id in range(len(polys)):
gt_polys_single = gt_polys[img_id].cpu().numpy() polys_single = polys[img_id]
gt_polys_lens_single = gt_poly_lens[img_id].cpu().numpy().tolist() polys_lens_single = poly_lens[img_id].tolist()
num_polys_per_mask_single = num_polys_per_mask[ polys_per_mask_single = polys_per_mask[img_id].tolist()
img_id].cpu().numpy().tolist()
split_gt_polys = mmcv.slice_list(gt_polys_single, gt_polys_lens_single) split_polys = mmcv.slice_list(polys_single, polys_lens_single)
mask_polys = mmcv.slice_list(split_gt_polys, num_polys_per_mask_single) mask_polys = mmcv.slice_list(split_polys, polys_per_mask_single)
mask_polys_list.append(mask_polys) mask_polys_list.append(mask_polys)
return mask_polys_list return mask_polys_list
from .base import BaseDetector from .base import BaseDetector
from .rpn import RPN from .rpn import RPN
from .faster_rcnn import FasterRCNN
from .mask_rcnn import MaskRCNN
__all__ = ['BaseDetector', 'RPN'] __all__ = ['BaseDetector', 'RPN', 'FasterRCNN', 'MaskRCNN']
from .two_stage import TwoStageDetector
class FasterRCNN(TwoStageDetector):
    """Faster R-CNN detector.

    A thin configuration wrapper around ``TwoStageDetector``: it wires up
    the backbone/neck/RPN/bbox components but adds no mask branch.
    """

    def __init__(self,
                 backbone,
                 neck,
                 rpn_head,
                 bbox_roi_extractor,
                 bbox_head,
                 train_cfg,
                 test_cfg,
                 pretrained=None):
        # Collect every component config and delegate construction to the
        # generic two-stage base class.
        components = dict(
            backbone=backbone,
            neck=neck,
            rpn_head=rpn_head,
            bbox_roi_extractor=bbox_roi_extractor,
            bbox_head=bbox_head,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            pretrained=pretrained)
        super(FasterRCNN, self).__init__(**components)
from .two_stage import TwoStageDetector
class MaskRCNN(TwoStageDetector):
    """Mask R-CNN detector.

    A thin configuration wrapper around ``TwoStageDetector`` that, in
    addition to the bbox branch, wires up the mask RoI extractor and
    mask head.
    """

    def __init__(self,
                 backbone,
                 neck,
                 rpn_head,
                 bbox_roi_extractor,
                 bbox_head,
                 mask_roi_extractor,
                 mask_head,
                 train_cfg,
                 test_cfg,
                 pretrained=None):
        # Collect every component config and delegate construction to the
        # generic two-stage base class.
        components = dict(
            backbone=backbone,
            neck=neck,
            rpn_head=rpn_head,
            bbox_roi_extractor=bbox_roi_extractor,
            bbox_head=bbox_head,
            mask_roi_extractor=mask_roi_extractor,
            mask_head=mask_head,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            pretrained=pretrained)
        super(MaskRCNN, self).__init__(**components)
import torch import torch
import torch.nn as nn import torch.nn as nn
from .base import Detector from .base import BaseDetector
from .testing_mixins import RPNTestMixin, BBoxTestMixin from .test_mixins import RPNTestMixin, BBoxTestMixin, MaskTestMixin
from .. import builder from .. import builder
from mmdet.core import bbox2roi, bbox2result, sample_proposals from mmdet.core import bbox2roi, bbox2result, split_combined_polys, multi_apply
class TwoStageDetector(Detector, RPNTestMixin, BBoxTestMixin): class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
MaskTestMixin):
def __init__(self, def __init__(self,
backbone, backbone,
...@@ -15,13 +16,16 @@ class TwoStageDetector(Detector, RPNTestMixin, BBoxTestMixin): ...@@ -15,13 +16,16 @@ class TwoStageDetector(Detector, RPNTestMixin, BBoxTestMixin):
rpn_head=None, rpn_head=None,
bbox_roi_extractor=None, bbox_roi_extractor=None,
bbox_head=None, bbox_head=None,
mask_roi_extractor=None,
mask_head=None,
train_cfg=None, train_cfg=None,
test_cfg=None, test_cfg=None,
pretrained=None): pretrained=None):
super(Detector, self).__init__() super(TwoStageDetector, self).__init__()
self.backbone = builder.build_backbone(backbone) self.backbone = builder.build_backbone(backbone)
self.with_neck = True if neck is not None else False self.with_neck = True if neck is not None else False
assert self.with_neck, "TwoStageDetector must be implemented with FPN now."
if self.with_neck: if self.with_neck:
self.neck = builder.build_neck(neck) self.neck = builder.build_neck(neck)
...@@ -35,6 +39,12 @@ class TwoStageDetector(Detector, RPNTestMixin, BBoxTestMixin): ...@@ -35,6 +39,12 @@ class TwoStageDetector(Detector, RPNTestMixin, BBoxTestMixin):
bbox_roi_extractor) bbox_roi_extractor)
self.bbox_head = builder.build_bbox_head(bbox_head) self.bbox_head = builder.build_bbox_head(bbox_head)
self.with_mask = True if mask_head is not None else False
if self.with_mask:
self.mask_roi_extractor = builder.build_roi_extractor(
mask_roi_extractor)
self.mask_head = builder.build_mask_head(mask_head)
self.train_cfg = train_cfg self.train_cfg = train_cfg
self.test_cfg = test_cfg self.test_cfg = test_cfg
...@@ -68,6 +78,7 @@ class TwoStageDetector(Detector, RPNTestMixin, BBoxTestMixin): ...@@ -68,6 +78,7 @@ class TwoStageDetector(Detector, RPNTestMixin, BBoxTestMixin):
gt_bboxes, gt_bboxes,
gt_bboxes_ignore, gt_bboxes_ignore,
gt_labels, gt_labels,
gt_masks=None,
proposals=None): proposals=None):
losses = dict() losses = dict()
...@@ -80,54 +91,73 @@ class TwoStageDetector(Detector, RPNTestMixin, BBoxTestMixin): ...@@ -80,54 +91,73 @@ class TwoStageDetector(Detector, RPNTestMixin, BBoxTestMixin):
rpn_losses = self.rpn_head.loss(*rpn_loss_inputs) rpn_losses = self.rpn_head.loss(*rpn_loss_inputs)
losses.update(rpn_losses) losses.update(rpn_losses)
proposal_inputs = rpn_outs + (img_meta, self.self.test_cfg.rpn) proposal_inputs = rpn_outs + (img_meta, self.test_cfg.rpn)
proposal_list = self.rpn_head.get_proposals(*proposal_inputs) proposal_list = self.rpn_head.get_proposals(*proposal_inputs)
else: else:
proposal_list = proposals proposal_list = proposals
(pos_inds, neg_inds, pos_proposals, neg_proposals, if self.with_bbox:
pos_assigned_gt_inds, rcnn_train_cfg_list = [
pos_gt_bboxes, pos_gt_labels) = sample_proposals( self.train_cfg.rcnn for _ in range(len(proposal_list))
proposal_list, gt_bboxes, gt_bboxes_ignore, gt_labels, ]
self.train_cfg.rcnn) (pos_proposals, neg_proposals, pos_assigned_gt_inds, pos_gt_bboxes,
pos_gt_labels) = multi_apply(
labels, label_weights, bbox_targets, bbox_weights = \ self.bbox_roi_extractor.sample_proposals, proposal_list,
self.bbox_head.get_bbox_target( gt_bboxes, gt_bboxes_ignore, gt_labels, rcnn_train_cfg_list)
pos_proposals, neg_proposals, pos_gt_bboxes, pos_gt_labels, labels, label_weights, bbox_targets, bbox_weights = \
self.bbox_head.get_bbox_target(pos_proposals, neg_proposals,
pos_gt_bboxes, pos_gt_labels, self.train_cfg.rcnn)
rois = bbox2roi([
torch.cat([pos, neg], dim=0)
for pos, neg in zip(pos_proposals, neg_proposals)
])
# TODO: a more flexible way to configurate feat maps
roi_feats = self.bbox_roi_extractor(
x[:self.bbox_roi_extractor.num_inputs], rois)
cls_score, bbox_pred = self.bbox_head(roi_feats)
loss_bbox = self.bbox_head.loss(cls_score, bbox_pred, labels,
label_weights, bbox_targets,
bbox_weights)
losses.update(loss_bbox)
if self.with_mask:
gt_polys = split_combined_polys(**gt_masks)
mask_targets = self.mask_head.get_mask_target(
pos_proposals, pos_assigned_gt_inds, gt_polys, img_meta,
self.train_cfg.rcnn) self.train_cfg.rcnn)
pos_rois = bbox2roi(pos_proposals)
rois = bbox2roi([ mask_feats = self.mask_roi_extractor(
torch.cat([pos, neg], dim=0) x[:self.mask_roi_extractor.num_inputs], pos_rois)
for pos, neg in zip(pos_proposals, neg_proposals) mask_pred = self.mask_head(mask_feats)
]) loss_mask = self.mask_head.loss(mask_pred, mask_targets,
# TODO: a more flexible way to configurate feat maps torch.cat(pos_gt_labels))
roi_feats = self.bbox_roi_extractor( losses.update(loss_mask)
x[:self.bbox_roi_extractor.num_inputs], rois)
cls_score, bbox_pred = self.bbox_head(roi_feats)
loss_bbox = self.bbox_head.loss(cls_score, bbox_pred, labels,
label_weights, bbox_targets,
bbox_weights)
losses.update(loss_bbox)
return losses return losses
def simple_test(self, img, img_meta, proposals=None, rescale=False): def simple_test(self, img, img_meta, proposals=None, rescale=False):
"""Test without augmentation.""" """Test without augmentation."""
assert proposals == None, "Fast RCNN hasn't been implemented."
assert self.with_bbox, "Bbox head must be implemented."
x = self.extract_feat(img) x = self.extract_feat(img)
if proposals is None:
proposals = self.simple_test_rpn(x, img_meta) proposal_list = self.simple_test_rpn(
if self.with_bbox: x, img_meta, self.test_cfg.rpn) if proposals is None else proposals
# BUG proposals shape?
det_bboxes, det_labels = self.simple_test_bboxes( det_bboxes, det_labels = self.simple_test_bboxes(
x, img_meta, [proposals], rescale=rescale) x, img_meta, proposal_list, self.test_cfg.rcnn, rescale=rescale)
bbox_result = bbox2result(det_bboxes, det_labels, bbox_results = bbox2result(det_bboxes, det_labels,
self.bbox_head.num_classes) self.bbox_head.num_classes)
return bbox_result
if self.with_mask:
segm_results = self.simple_test_mask(
x, img_meta, det_bboxes, det_labels, rescale=rescale)
return bbox_results, segm_results
else: else:
proposals[:, :4] /= img_meta['scale_factor'].float() return bbox_results
return proposals.cpu().numpy()
def aug_test(self, imgs, img_metas, rescale=False): def aug_test(self, imgs, img_metas, rescale=False):
"""Test with augmentations. """Test with augmentations.
...@@ -135,15 +165,28 @@ class TwoStageDetector(Detector, RPNTestMixin, BBoxTestMixin): ...@@ -135,15 +165,28 @@ class TwoStageDetector(Detector, RPNTestMixin, BBoxTestMixin):
If rescale is False, then returned bboxes and masks will fit the scale If rescale is False, then returned bboxes and masks will fit the scale
of imgs[0]. of imgs[0].
""" """
proposals = self.aug_test_rpn( # recompute self.extract_feats(imgs) because of 'yield' and memory
self.extract_feats(imgs), img_metas, self.rpn_test_cfg) proposal_list = self.aug_test_rpn(
self.extract_feats(imgs), img_metas, self.test_cfg.rpn)
det_bboxes, det_labels = self.aug_test_bboxes( det_bboxes, det_labels = self.aug_test_bboxes(
self.extract_feats(imgs), img_metas, proposals, self.rcnn_test_cfg) self.extract_feats(imgs), img_metas, proposal_list,
self.test_cfg.rcnn)
if rescale: if rescale:
_det_bboxes = det_bboxes _det_bboxes = det_bboxes
else: else:
_det_bboxes = det_bboxes.clone() _det_bboxes = det_bboxes.clone()
_det_bboxes[:, :4] *= img_metas[0]['shape_scale'][0][-1] _det_bboxes[:, :4] *= img_metas[0][0]['scale_factor']
bbox_result = bbox2result(_det_bboxes, det_labels, bbox_results = bbox2result(_det_bboxes, det_labels,
self.bbox_head.num_classes) self.bbox_head.num_classes)
return bbox_result
# det_bboxes always keep the original scale
if self.with_mask:
segm_results = self.aug_test_mask(
self.extract_feats(imgs),
img_metas,
det_bboxes,
det_labels)
return bbox_results, segm_results
else:
return bbox_results
...@@ -93,11 +93,13 @@ class FCNMaskHead(nn.Module): ...@@ -93,11 +93,13 @@ class FCNMaskHead(nn.Module):
return mask_targets return mask_targets
def loss(self, mask_pred, mask_targets, labels): def loss(self, mask_pred, mask_targets, labels):
loss = dict()
loss_mask = mask_cross_entropy(mask_pred, mask_targets, labels) loss_mask = mask_cross_entropy(mask_pred, mask_targets, labels)
return loss_mask loss['loss_mask'] = loss_mask
return loss
def get_seg_masks(self, mask_pred, det_bboxes, det_labels, rcnn_test_cfg, def get_seg_masks(self, mask_pred, det_bboxes, det_labels, rcnn_test_cfg,
ori_scale): ori_shape):
"""Get segmentation masks from mask_pred and bboxes """Get segmentation masks from mask_pred and bboxes
Args: Args:
mask_pred (Tensor or ndarray): shape (n, #class+1, h, w). mask_pred (Tensor or ndarray): shape (n, #class+1, h, w).
...@@ -108,7 +110,7 @@ class FCNMaskHead(nn.Module): ...@@ -108,7 +110,7 @@ class FCNMaskHead(nn.Module):
det_labels (Tensor): shape (n, ) det_labels (Tensor): shape (n, )
img_shape (Tensor): shape (3, ) img_shape (Tensor): shape (3, )
rcnn_test_cfg (dict): rcnn testing config rcnn_test_cfg (dict): rcnn testing config
rescale (bool): whether rescale masks to original image size ori_shape: original image size
Returns: Returns:
list[list]: encoded masks list[list]: encoded masks
""" """
...@@ -118,8 +120,8 @@ class FCNMaskHead(nn.Module): ...@@ -118,8 +120,8 @@ class FCNMaskHead(nn.Module):
cls_segms = [[] for _ in range(self.num_classes - 1)] cls_segms = [[] for _ in range(self.num_classes - 1)]
bboxes = det_bboxes.cpu().numpy()[:, :4] bboxes = det_bboxes.cpu().numpy()[:, :4]
labels = det_labels.cpu().numpy() + 1 labels = det_labels.cpu().numpy() + 1
img_h = ori_scale[0] img_h = ori_shape[0]
img_w = ori_scale[1] img_w = ori_shape[1]
for i in range(bboxes.shape[0]): for i in range(bboxes.shape[0]):
bbox = bboxes[i, :].astype(int) bbox = bboxes[i, :].astype(int)
......
...@@ -4,6 +4,7 @@ import torch ...@@ -4,6 +4,7 @@ import torch
import torch.nn as nn import torch.nn as nn
from mmdet import ops from mmdet import ops
from mmdet.core import bbox_assign, bbox_sampling
class SingleLevelRoI(nn.Module): class SingleLevelRoI(nn.Module):
...@@ -51,6 +52,36 @@ class SingleLevelRoI(nn.Module): ...@@ -51,6 +52,36 @@ class SingleLevelRoI(nn.Module):
target_lvls = target_lvls.clamp(min=0, max=num_levels - 1).long() target_lvls = target_lvls.clamp(min=0, max=num_levels - 1).long()
return target_lvls return target_lvls
def sample_proposals(self, proposals, gt_bboxes, gt_crowds, gt_labels,
cfg):
proposals = proposals[:, :4]
assigned_gt_inds, assigned_labels, argmax_overlaps, max_overlaps = \
bbox_assign(proposals, gt_bboxes, gt_crowds, gt_labels,
cfg.pos_iou_thr, cfg.neg_iou_thr, cfg.pos_iou_thr, cfg.crowd_thr)
if cfg.add_gt_as_proposals:
proposals = torch.cat([gt_bboxes, proposals], dim=0)
gt_assign_self = torch.arange(
1,
len(gt_labels) + 1,
dtype=torch.long,
device=proposals.device)
assigned_gt_inds = torch.cat([gt_assign_self, assigned_gt_inds])
assigned_labels = torch.cat([gt_labels, assigned_labels])
pos_inds, neg_inds = bbox_sampling(
assigned_gt_inds, cfg.roi_batch_size, cfg.pos_fraction,
cfg.neg_pos_ub, cfg.pos_balance_sampling, max_overlaps,
cfg.neg_balance_thr)
pos_proposals = proposals[pos_inds]
neg_proposals = proposals[neg_inds]
pos_assigned_gt_inds = assigned_gt_inds[pos_inds] - 1
pos_gt_bboxes = gt_bboxes[pos_assigned_gt_inds, :]
pos_gt_labels = assigned_labels[pos_inds]
return (pos_proposals, neg_proposals, pos_assigned_gt_inds, pos_gt_bboxes, pos_gt_labels)
def forward(self, feats, rois): def forward(self, feats, rois):
"""Extract roi features with the roi layer. If multiple feature levels """Extract roi features with the roi layer. If multiple feature levels
are used, then rois are mapped to corresponding levels according to are used, then rois are mapped to corresponding levels according to
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment