Unverified commit 7d343fd2, authored by Kai Chen, committed by GitHub

Merge pull request #8 from hellock/dev

API cleaning and code refactoring (WIP)
parents 0e0b9246 630687f4
import mmcv

from mmdet.core import tensor2imgs, bbox_mapping
from .base import BaseDetector
from .test_mixins import RPNTestMixin
from .. import builder


class RPN(BaseDetector, RPNTestMixin):

    def __init__(self,
                 backbone,
                 neck,
                 rpn_head,
                 train_cfg,
                 test_cfg,
                 pretrained=None):
        super(RPN, self).__init__()
        self.backbone = builder.build_backbone(backbone)
        self.neck = builder.build_neck(neck) if neck is not None else None
        self.rpn_head = builder.build_rpn_head(rpn_head)
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg
        self.init_weights(pretrained=pretrained)

    def init_weights(self, pretrained=None):
        super(RPN, self).init_weights(pretrained)
        self.backbone.init_weights(pretrained=pretrained)
        if self.with_neck:
            self.neck.init_weights()
        self.rpn_head.init_weights()

    def extract_feat(self, img):
        x = self.backbone(img)
        if self.with_neck:
            x = self.neck(x)
        return x

    def forward_train(self, img, img_meta, gt_bboxes=None):
        if self.train_cfg.rpn.get('debug', False):
            self.rpn_head.debug_imgs = tensor2imgs(img)

        x = self.extract_feat(img)
        rpn_outs = self.rpn_head(x)

        rpn_loss_inputs = rpn_outs + (gt_bboxes, img_meta, self.train_cfg.rpn)
        losses = self.rpn_head.loss(*rpn_loss_inputs)
        return losses

    def simple_test(self, img, img_meta, rescale=False):
        x = self.extract_feat(img)
        proposal_list = self.simple_test_rpn(x, img_meta, self.test_cfg.rpn)
        if rescale:
            for proposals, meta in zip(proposal_list, img_meta):
                proposals[:, :4] /= meta['scale_factor']
        # TODO: remove this restriction
        return proposal_list[0].cpu().numpy()

    def aug_test(self, imgs, img_metas, rescale=False):
        proposal_list = self.aug_test_rpn(
            self.extract_feats(imgs), img_metas, self.test_cfg.rpn)
        if not rescale:
            for proposals, img_meta in zip(proposal_list, img_metas[0]):
                img_shape = img_meta['img_shape']
                scale_factor = img_meta['scale_factor']
                flip = img_meta['flip']
                proposals[:, :4] = bbox_mapping(proposals[:, :4], img_shape,
                                                scale_factor, flip)
        # TODO: remove this restriction
        return proposal_list[0].cpu().numpy()

    def show_result(self, data, result, img_norm_cfg):
        """Show RPN proposals on the image.

        Although we assume batch size is 1, this method supports arbitrary
        batch size.
        """
        img_tensor = data['img'][0]
        img_metas = data['img_meta'][0].data[0]
        imgs = tensor2imgs(img_tensor, **img_norm_cfg)
        assert len(imgs) == len(img_metas)
        for img, img_meta in zip(imgs, img_metas):
            h, w, _ = img_meta['img_shape']
            img_show = img[:h, :w, :]
            mmcv.imshow_bboxes(img_show, result, top_k=20)
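Note on `bbox_mapping` (used by `aug_test` above): it projects boxes from the original image frame into an augmented view by composing a rescale with an optional horizontal flip. The sketch below illustrates that contract under those assumptions; it is not the mmdet.core implementation.

import torch


def bbox_mapping_sketch(bboxes, img_shape, scale_factor, flip):
    """Map (n, 4) [x1, y1, x2, y2] boxes into an augmented view.

    img_shape is (h, w, c) of the augmented image, used only for flipping.
    """
    new_bboxes = bboxes * scale_factor  # rescale to the augmented resolution
    if flip:
        flipped = new_bboxes.clone()
        w = img_shape[1]
        # mirror x coordinates; x1 and x2 swap roles after flipping
        flipped[:, 0] = w - new_bboxes[:, 2] - 1
        flipped[:, 2] = w - new_bboxes[:, 0] - 1
        new_bboxes = flipped
    return new_bboxes


boxes = torch.tensor([[10., 20., 50., 60.]])
print(bbox_mapping_sketch(boxes, (400, 300, 3), 2.0, flip=True))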
from mmdet.core import (bbox2roi, bbox_mapping, merge_aug_proposals,
                        merge_aug_bboxes, merge_aug_masks, multiclass_nms)


class RPNTestMixin(object):

    def simple_test_rpn(self, x, img_meta, rpn_test_cfg):
        rpn_outs = self.rpn_head(x)
        proposal_inputs = rpn_outs + (img_meta, rpn_test_cfg)
        proposal_list = self.rpn_head.get_proposals(*proposal_inputs)
        return proposal_list

    def aug_test_rpn(self, feats, img_metas, rpn_test_cfg):
        imgs_per_gpu = len(img_metas[0])
        aug_proposals = [[] for _ in range(imgs_per_gpu)]
        for x, img_meta in zip(feats, img_metas):
            proposal_list = self.simple_test_rpn(x, img_meta, rpn_test_cfg)
            for i, proposals in enumerate(proposal_list):
                aug_proposals[i].append(proposals)
        # after merging, proposals will be rescaled to the original image size
        merged_proposals = [
            merge_aug_proposals(proposals, img_meta, rpn_test_cfg)
            for proposals, img_meta in zip(aug_proposals, img_metas)
        ]
        return merged_proposals


class BBoxTestMixin(object):

    def simple_test_bboxes(self,
                           x,
                           img_meta,
                           proposals,
                           rcnn_test_cfg,
                           rescale=False):
        """Test only det bboxes without augmentation."""
        rois = bbox2roi(proposals)
        roi_feats = self.bbox_roi_extractor(
            x[:len(self.bbox_roi_extractor.featmap_strides)], rois)
        cls_score, bbox_pred = self.bbox_head(roi_feats)
        img_shape = img_meta[0]['img_shape']
        scale_factor = img_meta[0]['scale_factor']
        det_bboxes, det_labels = self.bbox_head.get_det_bboxes(
            rois,
            cls_score,
            bbox_pred,
            img_shape,
            scale_factor,
            rescale=rescale,
            nms_cfg=rcnn_test_cfg)
        return det_bboxes, det_labels

    def aug_test_bboxes(self, feats, img_metas, proposal_list, rcnn_test_cfg):
        aug_bboxes = []
        aug_scores = []
        for x, img_meta in zip(feats, img_metas):
            # only one image in the batch
            img_shape = img_meta[0]['img_shape']
            scale_factor = img_meta[0]['scale_factor']
            flip = img_meta[0]['flip']
            # TODO more flexible
            proposals = bbox_mapping(proposal_list[0][:, :4], img_shape,
                                     scale_factor, flip)
            rois = bbox2roi([proposals])
            # recompute feature maps to save GPU memory
            roi_feats = self.bbox_roi_extractor(
                x[:len(self.bbox_roi_extractor.featmap_strides)], rois)
            cls_score, bbox_pred = self.bbox_head(roi_feats)
            bboxes, scores = self.bbox_head.get_det_bboxes(
                rois,
                cls_score,
                bbox_pred,
                img_shape,
                scale_factor,
                rescale=False,
                nms_cfg=None)
            aug_bboxes.append(bboxes)
            aug_scores.append(scores)
        # after merging, bboxes will be rescaled to the original image size
        merged_bboxes, merged_scores = merge_aug_bboxes(
            aug_bboxes, aug_scores, img_metas, self.test_cfg.rcnn)
        det_bboxes, det_labels = multiclass_nms(
            merged_bboxes, merged_scores, self.test_cfg.rcnn.score_thr,
            self.test_cfg.rcnn.nms_thr, self.test_cfg.rcnn.max_per_img)
        return det_bboxes, det_labels


class MaskTestMixin(object):

    def simple_test_mask(self,
                         x,
                         img_meta,
                         det_bboxes,
                         det_labels,
                         rescale=False):
        # image shape of the first image in the batch (only one)
        ori_shape = img_meta[0]['ori_shape']
        scale_factor = img_meta[0]['scale_factor']
        if det_bboxes.shape[0] == 0:
            segm_result = [[] for _ in range(self.mask_head.num_classes - 1)]
        else:
            # if det_bboxes is rescaled to the original image size, we need to
            # rescale it back to the testing scale to obtain RoIs.
            _bboxes = (det_bboxes[:, :4] * scale_factor
                       if rescale else det_bboxes)
            mask_rois = bbox2roi([_bboxes])
            mask_feats = self.mask_roi_extractor(
                x[:len(self.mask_roi_extractor.featmap_strides)], mask_rois)
            mask_pred = self.mask_head(mask_feats)
            segm_result = self.mask_head.get_seg_masks(
                mask_pred, _bboxes, det_labels, self.test_cfg.rcnn, ori_shape,
                scale_factor, rescale)
        return segm_result

    def aug_test_mask(self, feats, img_metas, det_bboxes, det_labels):
        if det_bboxes.shape[0] == 0:
            segm_result = [[] for _ in range(self.mask_head.num_classes - 1)]
        else:
            aug_masks = []
            for x, img_meta in zip(feats, img_metas):
                img_shape = img_meta[0]['img_shape']
                scale_factor = img_meta[0]['scale_factor']
                flip = img_meta[0]['flip']
                _bboxes = bbox_mapping(det_bboxes[:, :4], img_shape,
                                       scale_factor, flip)
                mask_rois = bbox2roi([_bboxes])
                mask_feats = self.mask_roi_extractor(
                    x[:len(self.mask_roi_extractor.featmap_strides)],
                    mask_rois)
                mask_pred = self.mask_head(mask_feats)
                # convert to numpy array to save memory
                aug_masks.append(mask_pred.sigmoid().cpu().numpy())
            merged_masks = merge_aug_masks(aug_masks, img_metas,
                                           self.test_cfg.rcnn)

            ori_shape = img_metas[0][0]['ori_shape']
            # det_bboxes are already at the original scale, so masks are
            # pasted without further rescaling
            segm_result = self.mask_head.get_seg_masks(
                merged_masks, det_bboxes, det_labels, self.test_cfg.rcnn,
                ori_shape, scale_factor=1.0, rescale=False)
        return segm_result
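`aug_test_mask` collects `mask_pred.sigmoid().cpu().numpy()` per augmented view and defers to `merge_aug_masks`. A plausible minimal version of that merge, assuming flipped views are un-flipped and the probability maps are then averaged (illustrative only; the real helper also reads its merge behavior from `test_cfg.rcnn`):

import numpy as np


def merge_aug_masks_sketch(aug_masks, img_metas):
    """Average mask probabilities over augmented views.

    aug_masks: list of (n, num_classes, h, w) arrays of sigmoid outputs.
    img_metas: per-view meta, read here only for the 'flip' flag.
    """
    recovered = []
    for mask, img_meta in zip(aug_masks, img_metas):
        if img_meta[0]['flip']:
            mask = mask[:, :, :, ::-1]  # undo horizontal flip
        recovered.append(mask)
    return np.mean(recovered, axis=0)


a = np.random.rand(2, 3, 28, 28)
b = np.random.rand(2, 3, 28, 28)
metas = [[{'flip': False}], [{'flip': True}]]
print(merge_aug_masks_sketch([a, b], metas).shape)  # (2, 3, 28, 28)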
import torch
import torch.nn as nn

from .base import BaseDetector
from .test_mixins import RPNTestMixin, BBoxTestMixin, MaskTestMixin
from .. import builder
from mmdet.core import sample_bboxes, bbox2roi, bbox2result, multi_apply


class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
                       MaskTestMixin):

    def __init__(self,
                 backbone,
                 neck=None,
                 rpn_head=None,
                 bbox_roi_extractor=None,
                 bbox_head=None,
                 mask_roi_extractor=None,
                 mask_head=None,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None):
        super(TwoStageDetector, self).__init__()
        self.backbone = builder.build_backbone(backbone)

        if neck is not None:
            self.neck = builder.build_neck(neck)
        else:
            raise NotImplementedError

        if rpn_head is not None:
            self.rpn_head = builder.build_rpn_head(rpn_head)

        if bbox_head is not None:
            self.bbox_roi_extractor = builder.build_roi_extractor(
                bbox_roi_extractor)
            self.bbox_head = builder.build_bbox_head(bbox_head)

        if mask_head is not None:
            self.mask_roi_extractor = builder.build_roi_extractor(
                mask_roi_extractor)
            self.mask_head = builder.build_mask_head(mask_head)

        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

        self.init_weights(pretrained=pretrained)

    @property
    def with_rpn(self):
        return hasattr(self, 'rpn_head') and self.rpn_head is not None

    def init_weights(self, pretrained=None):
        super(TwoStageDetector, self).init_weights(pretrained)
        self.backbone.init_weights(pretrained=pretrained)
        if self.with_neck:
            if isinstance(self.neck, nn.Sequential):
                for m in self.neck:
                    m.init_weights()
            else:
                self.neck.init_weights()
        if self.with_rpn:
            self.rpn_head.init_weights()
        if self.with_bbox:
            self.bbox_roi_extractor.init_weights()
            self.bbox_head.init_weights()

    def extract_feat(self, img):
        x = self.backbone(img)
        if self.with_neck:
            x = self.neck(x)
        return x

    def forward_train(self,
                      img,
                      img_meta,
                      gt_bboxes,
                      gt_bboxes_ignore,
                      gt_labels,
                      gt_masks=None,
                      proposals=None):
        losses = dict()

        x = self.extract_feat(img)

        if self.with_rpn:
            rpn_outs = self.rpn_head(x)
            rpn_loss_inputs = rpn_outs + (gt_bboxes, img_meta,
                                          self.train_cfg.rpn)
            rpn_losses = self.rpn_head.loss(*rpn_loss_inputs)
            losses.update(rpn_losses)

            proposal_inputs = rpn_outs + (img_meta, self.test_cfg.rpn)
            proposal_list = self.rpn_head.get_proposals(*proposal_inputs)
        else:
            proposal_list = proposals

        if self.with_bbox:
            (pos_proposals, neg_proposals, pos_assigned_gt_inds, pos_gt_bboxes,
             pos_gt_labels) = multi_apply(
                 sample_bboxes,
                 proposal_list,
                 gt_bboxes,
                 gt_bboxes_ignore,
                 gt_labels,
                 cfg=self.train_cfg.rcnn)
            (labels, label_weights, bbox_targets,
             bbox_weights) = self.bbox_head.get_bbox_target(
                 pos_proposals, neg_proposals, pos_gt_bboxes, pos_gt_labels,
                 self.train_cfg.rcnn)

            rois = bbox2roi([
                torch.cat([pos, neg], dim=0)
                for pos, neg in zip(pos_proposals, neg_proposals)
            ])
            # TODO: a more flexible way to configure feat maps
            roi_feats = self.bbox_roi_extractor(
                x[:self.bbox_roi_extractor.num_inputs], rois)
            cls_score, bbox_pred = self.bbox_head(roi_feats)

            loss_bbox = self.bbox_head.loss(cls_score, bbox_pred, labels,
                                            label_weights, bbox_targets,
                                            bbox_weights)
            losses.update(loss_bbox)

        if self.with_mask:
            mask_targets = self.mask_head.get_mask_target(
                pos_proposals, pos_assigned_gt_inds, gt_masks,
                self.train_cfg.rcnn)
            pos_rois = bbox2roi(pos_proposals)
            mask_feats = self.mask_roi_extractor(
                x[:self.mask_roi_extractor.num_inputs], pos_rois)
            mask_pred = self.mask_head(mask_feats)
            loss_mask = self.mask_head.loss(mask_pred, mask_targets,
                                            torch.cat(pos_gt_labels))
            losses.update(loss_mask)

        return losses

    def simple_test(self, img, img_meta, proposals=None, rescale=False):
        """Test without augmentation."""
        assert proposals is None, "Fast RCNN hasn't been implemented."
        assert self.with_bbox, "Bbox head must be implemented."

        x = self.extract_feat(img)

        proposal_list = self.simple_test_rpn(
            x, img_meta, self.test_cfg.rpn) if proposals is None else proposals

        det_bboxes, det_labels = self.simple_test_bboxes(
            x, img_meta, proposal_list, self.test_cfg.rcnn, rescale=rescale)
        bbox_results = bbox2result(det_bboxes, det_labels,
                                   self.bbox_head.num_classes)

        if not self.with_mask:
            return bbox_results
        else:
            segm_results = self.simple_test_mask(
                x, img_meta, det_bboxes, det_labels, rescale=rescale)
            return bbox_results, segm_results

    def aug_test(self, imgs, img_metas, rescale=False):
        """Test with augmentations.

        If rescale is False, then returned bboxes and masks will fit the scale
        of imgs[0].
        """
        # recompute feats to save memory
        proposal_list = self.aug_test_rpn(
            self.extract_feats(imgs), img_metas, self.test_cfg.rpn)
        det_bboxes, det_labels = self.aug_test_bboxes(
            self.extract_feats(imgs), img_metas, proposal_list,
            self.test_cfg.rcnn)

        if rescale:
            _det_bboxes = det_bboxes
        else:
            _det_bboxes = det_bboxes.clone()
            _det_bboxes[:, :4] *= img_metas[0][0]['scale_factor']
        bbox_results = bbox2result(_det_bboxes, det_labels,
                                   self.bbox_head.num_classes)

        # det_bboxes always keep the original scale
        if self.with_mask:
            segm_results = self.aug_test_mask(
                self.extract_feats(imgs), img_metas, det_bboxes, det_labels)
            return bbox_results, segm_results
        else:
            return bbox_results
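Both the training path and the test mixins funnel boxes through `bbox2roi`, which prefixes each box with its image index so a whole batch can share one RoI layer call. A minimal sketch of that convention (illustrative, not the mmdet.core source):

import torch


def bbox2roi_sketch(bbox_list):
    """Concatenate per-image boxes into (n, 5) rois:
    [batch_idx, x1, y1, x2, y2]."""
    rois_list = []
    for img_id, bboxes in enumerate(bbox_list):
        img_inds = bboxes.new_full((bboxes.size(0), 1), img_id)
        rois_list.append(torch.cat([img_inds, bboxes[:, :4]], dim=-1))
    return torch.cat(rois_list, dim=0)


props = [torch.rand(3, 4), torch.rand(2, 4)]
print(bbox2roi_sketch(props).shape)  # torch.Size([5, 5])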
@@ -87,18 +87,21 @@ class FCNMaskHead(nn.Module):
         return mask_pred

     def get_mask_target(self, pos_proposals, pos_assigned_gt_inds, gt_masks,
-                        img_meta, rcnn_train_cfg):
+                        rcnn_train_cfg):
         mask_targets = mask_target(pos_proposals, pos_assigned_gt_inds,
-                                   gt_masks, img_meta, rcnn_train_cfg)
+                                   gt_masks, rcnn_train_cfg)
         return mask_targets

     def loss(self, mask_pred, mask_targets, labels):
+        loss = dict()
         loss_mask = mask_cross_entropy(mask_pred, mask_targets, labels)
-        return loss_mask
+        loss['loss_mask'] = loss_mask
+        return loss

     def get_seg_masks(self, mask_pred, det_bboxes, det_labels, rcnn_test_cfg,
-                      ori_scale):
-        """Get segmentation masks from mask_pred and bboxes
+                      ori_shape, scale_factor, rescale):
+        """Get segmentation masks from mask_pred and bboxes.

         Args:
             mask_pred (Tensor or ndarray): shape (n, #class+1, h, w).
                 For single-scale testing, mask_pred is the direct output of
@@ -108,40 +111,44 @@ class FCNMaskHead(nn.Module):
             det_labels (Tensor): shape (n, )
             img_shape (Tensor): shape (3, )
             rcnn_test_cfg (dict): rcnn testing config
-            rescale (bool): whether rescale masks to original image size
+            ori_shape: original image size

         Returns:
             list[list]: encoded masks
         """
         if isinstance(mask_pred, torch.Tensor):
             mask_pred = mask_pred.sigmoid().cpu().numpy()
         assert isinstance(mask_pred, np.ndarray)

         cls_segms = [[] for _ in range(self.num_classes - 1)]
         bboxes = det_bboxes.cpu().numpy()[:, :4]
         labels = det_labels.cpu().numpy() + 1
-        img_h = ori_scale[0]
-        img_w = ori_scale[1]
+
+        if rescale:
+            img_h, img_w = ori_shape[:2]
+        else:
+            img_h = np.round(ori_shape[0] * scale_factor).astype(np.int32)
+            img_w = np.round(ori_shape[1] * scale_factor).astype(np.int32)
+            scale_factor = 1.0

         for i in range(bboxes.shape[0]):
-            bbox = bboxes[i, :].astype(int)
+            bbox = (bboxes[i, :] / scale_factor).astype(np.int32)
             label = labels[i]
-            w = bbox[2] - bbox[0] + 1
-            h = bbox[3] - bbox[1] + 1
-            w = max(w, 1)
-            h = max(h, 1)
+            w = max(bbox[2] - bbox[0] + 1, 1)
+            h = max(bbox[3] - bbox[1] + 1, 1)

             if not self.class_agnostic:
                 mask_pred_ = mask_pred[i, label, :, :]
             else:
                 mask_pred_ = mask_pred[i, 0, :, :]
+            im_mask = np.zeros((img_h, img_w), dtype=np.uint8)

-            im_mask = np.zeros((img_h, img_w), dtype=np.float32)
-            im_mask[bbox[1]:bbox[1] + h, bbox[0]:bbox[0] + w] = mmcv.imresize(
-                mask_pred_, (w, h))
-            # im_mask = cv2.resize(im_mask, (img_w, img_h))
-            im_mask = np.array(
-                im_mask > rcnn_test_cfg.mask_thr_binary, dtype=np.uint8)
+            bbox_mask = mmcv.imresize(mask_pred_, (w, h))
+            bbox_mask = (bbox_mask > rcnn_test_cfg.mask_thr_binary).astype(
+                np.uint8)
+            im_mask[bbox[1]:bbox[1] + h, bbox[0]:bbox[0] + w] = bbox_mask
             rle = mask_util.encode(
                 np.array(im_mask[:, :, np.newaxis], order='F'))[0]
             cls_segms[label - 1].append(rle)
         return cls_segms
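The rewritten `get_seg_masks` pastes each fixed-resolution mask prediction into a full-image canvas: resize the probability map to the box size, binarize with `mask_thr_binary`, write it at the box offset, then RLE-encode. A stripped-down numpy illustration of the pasting step (hypothetical helper; `cv2.resize` stands in for `mmcv.imresize`, and the box is assumed to lie inside the image):

import cv2
import numpy as np


def paste_mask(mask_pred, bbox, img_h, img_w, thr=0.5):
    """Paste one (28, 28) probability map into an (img_h, img_w) canvas."""
    x1, y1, x2, y2 = bbox.astype(np.int32)
    w = max(x2 - x1 + 1, 1)
    h = max(y2 - y1 + 1, 1)
    im_mask = np.zeros((img_h, img_w), dtype=np.uint8)
    bbox_mask = cv2.resize(mask_pred, (w, h))  # resize crop to box size
    im_mask[y1:y1 + h, x1:x1 + w] = (bbox_mask > thr).astype(np.uint8)
    return im_mask


mask = np.random.rand(28, 28).astype(np.float32)
print(paste_mask(mask, np.array([30, 40, 90, 120]), 200, 200).sum())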
@@ -101,7 +101,7 @@ class FPN(nn.Module):
         # build top-down path
         used_backbone_levels = len(laterals)
         for i in range(used_backbone_levels - 1, 0, -1):
-            laterals[i - 1] += F.upsample(
+            laterals[i - 1] += F.interpolate(
                 laterals[i], scale_factor=2, mode='nearest')

         # build outputs
@@ -111,7 +111,8 @@ class FPN(nn.Module):
         ]
         # part 2: add extra levels
         if self.num_outs > len(outs):
-            # use max pool to get more levels on top of outputs (Faster R-CNN, Mask R-CNN)
+            # use max pool to get more levels on top of outputs
+            # (e.g., Faster R-CNN, Mask R-CNN)
             if not self.add_extra_convs:
                 for i in range(self.num_outs - used_backbone_levels):
                     outs.append(F.max_pool2d(outs[-1], 1, stride=2))
...
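The `F.upsample` to `F.interpolate` change is a deprecation fix; the pyramid arithmetic is unchanged. A self-contained sketch of the two steps this hunk touches, nearest-neighbor upsampling of coarser laterals plus stride-2 max pooling for extra levels (shapes are invented for illustration):

import torch
import torch.nn.functional as F

# laterals from fine to coarse, as FPN.forward builds them
laterals = [torch.rand(1, 256, 64, 64), torch.rand(1, 256, 32, 32),
            torch.rand(1, 256, 16, 16)]

# top-down path: add each upsampled coarse level into the next finer one
for i in range(len(laterals) - 1, 0, -1):
    laterals[i - 1] = laterals[i - 1] + F.interpolate(
        laterals[i], scale_factor=2, mode='nearest')

outs = list(laterals)
# extra level via parameter-free max pooling (kernel 1, stride 2)
outs.append(F.max_pool2d(outs[-1], 1, stride=2))
print([tuple(o.shape[-2:]) for o in outs])
# [(64, 64), (32, 32), (16, 16), (8, 8)]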
-from .single_level import SingleLevelRoI
+from .single_level import SingleRoIExtractor

-__all__ = ['SingleLevelRoI']
+__all__ = ['SingleRoIExtractor']
@@ -6,16 +6,25 @@
 import torch.nn as nn

 from mmdet import ops


-class SingleLevelRoI(nn.Module):
-    """Extract RoI features from a single level feature map. Each RoI is
-    mapped to a level according to its scale."""
+class SingleRoIExtractor(nn.Module):
+    """Extract RoI features from a single level feature map.
+
+    If there are multiple input feature levels, each RoI is mapped to a level
+    according to its scale.
+
+    Args:
+        roi_layer (dict): Specify RoI layer type and arguments.
+        out_channels (int): Output channels of RoI layers.
+        featmap_strides (int): Strides of input feature maps.
+        finest_scale (int): Scale threshold of mapping to level 0.
+    """

     def __init__(self,
                  roi_layer,
                  out_channels,
                  featmap_strides,
                  finest_scale=56):
-        super(SingleLevelRoI, self).__init__()
+        super(SingleRoIExtractor, self).__init__()
         self.roi_layers = self.build_roi_layers(roi_layer, featmap_strides)
         self.out_channels = out_channels
         self.featmap_strides = featmap_strides
@@ -23,6 +32,7 @@ class SingleLevelRoI(nn.Module):

     @property
     def num_inputs(self):
+        """int: Input feature map levels."""
         return len(self.featmap_strides)

     def init_weights(self):
@@ -38,12 +48,19 @@ class SingleLevelRoI(nn.Module):
         return roi_layers

     def map_roi_levels(self, rois, num_levels):
-        """Map rois to corresponding feature levels (0-based) by scales.
+        """Map rois to corresponding feature levels by scales.

-        scale < finest_scale: level 0
-        finest_scale <= scale < finest_scale * 2: level 1
-        finest_scale * 2 <= scale < finest_scale * 4: level 2
-        scale >= finest_scale * 4: level 3
+        - scale < finest_scale: level 0
+        - finest_scale <= scale < finest_scale * 2: level 1
+        - finest_scale * 2 <= scale < finest_scale * 4: level 2
+        - scale >= finest_scale * 4: level 3
+
+        Args:
+            rois (Tensor): Input RoIs, shape (k, 5).
+            num_levels (int): Total level number.
+
+        Returns:
+            Tensor: Level index (0-based) of each RoI, shape (k, )
         """
         scale = torch.sqrt(
             (rois[:, 3] - rois[:, 1] + 1) * (rois[:, 4] - rois[:, 2] + 1))
@@ -52,10 +69,6 @@ class SingleLevelRoI(nn.Module):
         return target_lvls

     def forward(self, feats, rois):
-        """Extract roi features with the roi layer. If multiple feature levels
-        are used, then rois are mapped to corresponding levels according to
-        their scales.
-        """
         if len(feats) == 1:
             return self.roi_layers[0](feats[0], rois)
...
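The thresholds documented in `map_roi_levels` amount to level = floor(log2(scale / finest_scale)) + 1, clamped to the valid range, with scale = sqrt(w * h). A hedged reimplementation of just that rule (the shipped method consumes (k, 5) rois; plain widths and heights are used here):

import torch


def map_roi_levels_sketch(ws, hs, num_levels, finest_scale=56):
    """Assign each RoI to a pyramid level from its scale sqrt(w * h)."""
    scale = torch.sqrt(ws * hs)
    # floor(log2(scale / finest_scale)) + 1, clamped to [0, num_levels - 1]
    lvls = torch.floor(torch.log2(scale / finest_scale)) + 1
    return lvls.clamp(min=0, max=num_levels - 1).long()


ws = torch.tensor([30., 60., 130., 500.])
hs = torch.tensor([30., 60., 130., 500.])
print(map_roi_levels_sketch(ws, hs, num_levels=4))  # tensor([0, 1, 2, 3])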
@@ -5,20 +5,36 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from mmdet.core import (AnchorGenerator, anchor_target, bbox_transform_inv,
-                        weighted_cross_entropy, weighted_smoothl1,
+from mmdet.core import (AnchorGenerator, anchor_target, delta2bbox,
+                        multi_apply, weighted_cross_entropy, weighted_smoothl1,
                         weighted_binary_cross_entropy)
 from mmdet.ops import nms
-from ..utils import multi_apply
 from ..utils import normal_init


 class RPNHead(nn.Module):
+    """Network head of RPN.
+
+                                  / - rpn_cls (1x1 conv)
+    input - rpn_conv (3x3 conv) -
+                                  \ - rpn_reg (1x1 conv)
+
+    Args:
+        in_channels (int): Number of channels in the input feature map.
+        feat_channels (int): Number of channels for the RPN feature map.
+        anchor_scales (Iterable): Anchor scales.
+        anchor_ratios (Iterable): Anchor aspect ratios.
+        anchor_strides (Iterable): Anchor strides.
+        anchor_base_sizes (Iterable): Anchor base sizes.
+        target_means (Iterable): Mean values of regression targets.
+        target_stds (Iterable): Std values of regression targets.
+        use_sigmoid_cls (bool): Whether to use sigmoid loss for classification.
+            (softmax by default)
+    """

     def __init__(self,
                  in_channels,
-                 feat_channels=512,
-                 coarsest_stride=32,
+                 feat_channels=256,
                  anchor_scales=[8, 16, 32],
                  anchor_ratios=[0.5, 1.0, 2.0],
                  anchor_strides=[4, 8, 16, 32, 64],
@@ -29,7 +45,6 @@ class RPNHead(nn.Module):
         super(RPNHead, self).__init__()
         self.in_channels = in_channels
         self.feat_channels = feat_channels
-        self.coarsest_stride = coarsest_stride
         self.anchor_scales = anchor_scales
         self.anchor_ratios = anchor_ratios
         self.anchor_strides = anchor_strides
@@ -66,63 +81,63 @@ class RPNHead(nn.Module):
     def forward(self, feats):
         return multi_apply(self.forward_single, feats)

-    def get_anchors(self, featmap_sizes, img_shapes):
-        """Get anchors given a list of feature map sizes, and get valid flags
-        at the same time. (Extra padding regions should be marked as invalid)
+    def get_anchors(self, featmap_sizes, img_metas):
+        """Get anchors according to feature map sizes.
+
+        Args:
+            featmap_sizes (list[tuple]): Multi-level feature map sizes.
+            img_metas (list[dict]): Image meta info.
+
+        Returns:
+            tuple: anchors of each image, valid flags of each image
         """
-        # calculate actual image shapes
-        padded_img_shapes = []
-        for img_shape in img_shapes:
-            h, w = img_shape[:2]
-            padded_h = int(
-                np.ceil(h / self.coarsest_stride) * self.coarsest_stride)
-            padded_w = int(
-                np.ceil(w / self.coarsest_stride) * self.coarsest_stride)
-            padded_img_shapes.append((padded_h, padded_w))
-        # generate anchors for different feature levels
-        # len = feature levels
-        anchor_list = []
-        # len = imgs per gpu
-        valid_flag_list = [[] for _ in range(len(img_shapes))]
-        for i in range(len(featmap_sizes)):
-            anchor_stride = self.anchor_strides[i]
-            anchors = self.anchor_generators[i].grid_anchors(
-                featmap_sizes[i], anchor_stride)
-            anchor_list.append(anchors)
-            # for each image in this feature level, get valid flags
-            featmap_size = featmap_sizes[i]
-            for img_id, (h, w) in enumerate(padded_img_shapes):
-                valid_feat_h = min(
-                    int(np.ceil(h / anchor_stride)), featmap_size[0])
-                valid_feat_w = min(
-                    int(np.ceil(w / anchor_stride)), featmap_size[1])
-                flags = self.anchor_generators[i].valid_flags(
-                    featmap_size, (valid_feat_h, valid_feat_w))
-                valid_flag_list[img_id].append(flags)
+        num_imgs = len(img_metas)
+        num_levels = len(featmap_sizes)
+
+        # since feature map sizes of all images are the same, we only compute
+        # anchors for one time
+        multi_level_anchors = []
+        for i in range(num_levels):
+            anchors = self.anchor_generators[i].grid_anchors(
+                featmap_sizes[i], self.anchor_strides[i])
+            multi_level_anchors.append(anchors)
+        anchor_list = [multi_level_anchors for _ in range(num_imgs)]
+
+        # for each image, we compute valid flags of multi level anchors
+        valid_flag_list = []
+        for img_id, img_meta in enumerate(img_metas):
+            multi_level_flags = []
+            for i in range(num_levels):
+                anchor_stride = self.anchor_strides[i]
+                feat_h, feat_w = featmap_sizes[i]
+                h, w, _ = img_meta['pad_shape']
+                valid_feat_h = min(int(np.ceil(h / anchor_stride)), feat_h)
+                valid_feat_w = min(int(np.ceil(w / anchor_stride)), feat_w)
+                flags = self.anchor_generators[i].valid_flags(
+                    (feat_h, feat_w), (valid_feat_h, valid_feat_w))
+                multi_level_flags.append(flags)
+            valid_flag_list.append(multi_level_flags)
+
         return anchor_list, valid_flag_list

     def loss_single(self, rpn_cls_score, rpn_bbox_pred, labels, label_weights,
                     bbox_targets, bbox_weights, num_total_samples, cfg):
+        # classification loss
         labels = labels.contiguous().view(-1)
         label_weights = label_weights.contiguous().view(-1)
-        bbox_targets = bbox_targets.contiguous().view(-1, 4)
-        bbox_weights = bbox_weights.contiguous().view(-1, 4)
         if self.use_sigmoid_cls:
             rpn_cls_score = rpn_cls_score.permute(0, 2, 3,
                                                   1).contiguous().view(-1)
-            loss_cls = weighted_binary_cross_entropy(
-                rpn_cls_score,
-                labels,
-                label_weights,
-                ave_factor=num_total_samples)
+            criterion = weighted_binary_cross_entropy
         else:
             rpn_cls_score = rpn_cls_score.permute(0, 2, 3,
                                                   1).contiguous().view(-1, 2)
-            loss_cls = weighted_cross_entropy(
-                rpn_cls_score,
-                labels,
-                label_weights,
-                ave_factor=num_total_samples)
+            criterion = weighted_cross_entropy
+        loss_cls = criterion(
+            rpn_cls_score, labels, label_weights, avg_factor=num_total_samples)
+        # regression loss
+        bbox_targets = bbox_targets.contiguous().view(-1, 4)
+        bbox_weights = bbox_weights.contiguous().view(-1, 4)
         rpn_bbox_pred = rpn_bbox_pred.permute(0, 2, 3, 1).contiguous().view(
             -1, 4)
         loss_reg = weighted_smoothl1(
@@ -130,7 +145,7 @@ class RPNHead(nn.Module):
             bbox_targets,
             bbox_weights,
             beta=cfg.smoothl1_beta,
-            ave_factor=num_total_samples)
+            avg_factor=num_total_samples)
         return loss_cls, loss_reg

     def loss(self, rpn_cls_scores, rpn_bbox_preds, gt_bboxes, img_shapes, cfg):
@@ -140,7 +155,7 @@ class RPNHead(nn.Module):
         anchor_list, valid_flag_list = self.get_anchors(
             featmap_sizes, img_shapes)
         cls_reg_targets = anchor_target(
-            anchor_list, valid_flag_list, featmap_sizes, gt_bboxes, img_shapes,
+            anchor_list, valid_flag_list, gt_bboxes, img_shapes,
             self.target_means, self.target_stds, cfg)
         if cls_reg_targets is None:
             return None
@@ -158,8 +173,8 @@ class RPNHead(nn.Module):
             cfg=cfg)
         return dict(loss_rpn_cls=losses_cls, loss_rpn_reg=losses_reg)

-    def get_proposals(self, rpn_cls_scores, rpn_bbox_preds, img_shapes, cfg):
-        img_per_gpu = len(img_shapes)
+    def get_proposals(self, rpn_cls_scores, rpn_bbox_preds, img_meta, cfg):
+        num_imgs = len(img_meta)
         featmap_sizes = [featmap.size()[-2:] for featmap in rpn_cls_scores]
         mlvl_anchors = [
             self.anchor_generators[idx].grid_anchors(featmap_sizes[idx],
@@ -167,7 +182,7 @@ class RPNHead(nn.Module):
             for idx in range(len(featmap_sizes))
         ]
         proposal_list = []
-        for img_id in range(img_per_gpu):
+        for img_id in range(num_imgs):
             rpn_cls_score_list = [
                 rpn_cls_scores[idx][img_id].detach()
                 for idx in range(len(rpn_cls_scores))
@@ -177,10 +192,9 @@ class RPNHead(nn.Module):
                 for idx in range(len(rpn_bbox_preds))
             ]
             assert len(rpn_cls_score_list) == len(rpn_bbox_pred_list)
-            img_shape = img_shapes[img_id]
             proposals = self._get_proposals_single(
                 rpn_cls_score_list, rpn_bbox_pred_list, mlvl_anchors,
-                img_shape, cfg)
+                img_meta[img_id]['img_shape'], cfg)
             proposal_list.append(proposals)
         return proposal_list

@@ -195,7 +209,7 @@ class RPNHead(nn.Module):
         if self.use_sigmoid_cls:
             rpn_cls_score = rpn_cls_score.permute(1, 2,
                                                   0).contiguous().view(-1)
-            rpn_cls_prob = F.sigmoid(rpn_cls_score)
+            rpn_cls_prob = rpn_cls_score.sigmoid()
             scores = rpn_cls_prob
         else:
             rpn_cls_score = rpn_cls_score.permute(1, 2,
@@ -211,9 +225,8 @@ class RPNHead(nn.Module):
             rpn_bbox_pred = rpn_bbox_pred[order, :]
             anchors = anchors[order, :]
             scores = scores[order]
-            proposals = bbox_transform_inv(anchors, rpn_bbox_pred,
-                                           self.target_means, self.target_stds,
-                                           img_shape)
+            proposals = delta2bbox(anchors, rpn_bbox_pred, self.target_means,
+                                   self.target_stds, img_shape)
             w = proposals[:, 2] - proposals[:, 0] + 1
             h = proposals[:, 3] - proposals[:, 1] + 1
             valid_inds = torch.nonzero((w >= cfg.min_bbox_size) &
...
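`delta2bbox` (the renamed `bbox_transform_inv`) decodes regression output back into boxes: denormalize the deltas with `target_means`/`target_stds`, shift the anchor center by (dx, dy) times its size, scale the size by exp(dw)/exp(dh), and clip to `img_shape`. A minimal decoder in that spirit (illustrative; the shipped function also handles per-class deltas and clamps dw/dh):

import torch


def delta2bbox_sketch(anchors, deltas, means=(0., 0., 0., 0.),
                      stds=(1., 1., 1., 1.), img_shape=None):
    """Decode (dx, dy, dw, dh) deltas against (x1, y1, x2, y2) anchors."""
    means = deltas.new_tensor(means)
    stds = deltas.new_tensor(stds)
    dx, dy, dw, dh = (deltas * stds + means).unbind(dim=-1)

    aw = anchors[:, 2] - anchors[:, 0] + 1
    ah = anchors[:, 3] - anchors[:, 1] + 1
    ax = anchors[:, 0] + 0.5 * (aw - 1)
    ay = anchors[:, 1] + 0.5 * (ah - 1)

    cx, cy = ax + dx * aw, ay + dy * ah  # shifted center
    w, h = aw * dw.exp(), ah * dh.exp()  # rescaled size

    x1, y1 = cx - 0.5 * (w - 1), cy - 0.5 * (h - 1)
    x2, y2 = cx + 0.5 * (w - 1), cy + 0.5 * (h - 1)
    bboxes = torch.stack([x1, y1, x2, y2], dim=-1)
    if img_shape is not None:  # clip to the (padded) image boundary
        bboxes[:, 0::2] = bboxes[:, 0::2].clamp(0, img_shape[1] - 1)
        bboxes[:, 1::2] = bboxes[:, 1::2].clamp(0, img_shape[0] - 1)
    return bboxes


anchors = torch.tensor([[0., 0., 15., 15.]])
deltas = torch.tensor([[0.1, 0.1, 0.2, 0.2]])
print(delta2bbox_sketch(anchors, deltas, img_shape=(100, 100)))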
from .parallel import MMDataParallel, MMDistributedDataParallel
 from .nms import nms, soft_nms
 from .roi_align import RoIAlign, roi_align
 from .roi_pool import RoIPool, roi_pool
+
+__all__ = ['nms', 'soft_nms', 'RoIAlign', 'roi_align', 'RoIPool', 'roi_pool']
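For reference, the contract of the `nms` op exported here (and relaxed by `soft_nms`, which decays scores instead of discarding boxes): greedily keep the highest-scoring box and drop everything overlapping it beyond an IoU threshold. A slow but readable pure-PyTorch sketch, not the compiled kernel:

import torch


def nms_sketch(dets, iou_thr):
    """Greedy NMS over (n, 5) [x1, y1, x2, y2, score]; returns kept indices."""
    x1, y1, x2, y2, scores = dets.unbind(dim=1)
    areas = (x2 - x1 + 1) * (y2 - y1 + 1)
    order = scores.argsort(descending=True)
    keep = []
    while order.numel() > 0:
        i = order[0].item()
        keep.append(i)
        if order.numel() == 1:
            break
        # IoU of the top box with all remaining boxes
        xx1 = torch.max(x1[i], x1[order[1:]])
        yy1 = torch.max(y1[i], y1[order[1:]])
        xx2 = torch.min(x2[i], x2[order[1:]])
        yy2 = torch.min(y2[i], y2[order[1:]])
        inter = (xx2 - xx1 + 1).clamp(min=0) * (yy2 - yy1 + 1).clamp(min=0)
        iou = inter / (areas[i] + areas[order[1:]] - inter)
        order = order[1:][iou <= iou_thr]
    return keep


dets = torch.tensor([[0., 0., 10., 10., 0.9],
                     [1., 1., 11., 11., 0.8],
                     [50., 50., 60., 60., 0.7]])
print(nms_sketch(dets, 0.5))  # [0, 2]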