Unverified commit 6efefa27 authored by Kai Chen, committed by GitHub

Merge pull request #20 from open-mmlab/dev

Initial public release
parents 2cf13281 54b54d88
from mmdet.core import (bbox2roi, bbox_mapping, merge_aug_proposals,
merge_aug_bboxes, merge_aug_masks, multiclass_nms)
class RPNTestMixin(object):
def simple_test_rpn(self, x, img_meta, rpn_test_cfg):
rpn_outs = self.rpn_head(x)
proposal_inputs = rpn_outs + (img_meta, rpn_test_cfg)
proposal_list = self.rpn_head.get_proposals(*proposal_inputs)
return proposal_list
def aug_test_rpn(self, feats, img_metas, rpn_test_cfg):
imgs_per_gpu = len(img_metas[0])
aug_proposals = [[] for _ in range(imgs_per_gpu)]
for x, img_meta in zip(feats, img_metas):
proposal_list = self.simple_test_rpn(x, img_meta, rpn_test_cfg)
for i, proposals in enumerate(proposal_list):
aug_proposals[i].append(proposals)
# after merging, proposals will be rescaled to the original image size
merged_proposals = [
merge_aug_proposals(proposals, img_meta, rpn_test_cfg)
for proposals, img_meta in zip(aug_proposals, img_metas)
]
return merged_proposals
class BBoxTestMixin(object):
def simple_test_bboxes(self,
x,
img_meta,
proposals,
rcnn_test_cfg,
rescale=False):
"""Test only det bboxes without augmentation."""
rois = bbox2roi(proposals)
roi_feats = self.bbox_roi_extractor(
x[:len(self.bbox_roi_extractor.featmap_strides)], rois)
cls_score, bbox_pred = self.bbox_head(roi_feats)
img_shape = img_meta[0]['img_shape']
scale_factor = img_meta[0]['scale_factor']
det_bboxes, det_labels = self.bbox_head.get_det_bboxes(
rois,
cls_score,
bbox_pred,
img_shape,
scale_factor,
rescale=rescale,
nms_cfg=rcnn_test_cfg)
return det_bboxes, det_labels
def aug_test_bboxes(self, feats, img_metas, proposal_list, rcnn_test_cfg):
aug_bboxes = []
aug_scores = []
for x, img_meta in zip(feats, img_metas):
# only one image in the batch
img_shape = img_meta[0]['img_shape']
scale_factor = img_meta[0]['scale_factor']
flip = img_meta[0]['flip']
# TODO more flexible
proposals = bbox_mapping(proposal_list[0][:, :4], img_shape,
scale_factor, flip)
rois = bbox2roi([proposals])
# recompute feature maps to save GPU memory
roi_feats = self.bbox_roi_extractor(
x[:len(self.bbox_roi_extractor.featmap_strides)], rois)
cls_score, bbox_pred = self.bbox_head(roi_feats)
bboxes, scores = self.bbox_head.get_det_bboxes(
rois,
cls_score,
bbox_pred,
img_shape,
scale_factor,
rescale=False,
nms_cfg=None)
aug_bboxes.append(bboxes)
aug_scores.append(scores)
# after merging, bboxes will be rescaled to the original image size
        merged_bboxes, merged_scores = merge_aug_bboxes(
            aug_bboxes, aug_scores, img_metas, rcnn_test_cfg)
        det_bboxes, det_labels = multiclass_nms(
            merged_bboxes, merged_scores, rcnn_test_cfg.score_thr,
            rcnn_test_cfg.nms_thr, rcnn_test_cfg.max_per_img)
return det_bboxes, det_labels
class MaskTestMixin(object):
def simple_test_mask(self,
x,
img_meta,
det_bboxes,
det_labels,
rescale=False):
        # original shape of the image in the batch (only one image per batch
        # during testing)
ori_shape = img_meta[0]['ori_shape']
scale_factor = img_meta[0]['scale_factor']
if det_bboxes.shape[0] == 0:
segm_result = [[] for _ in range(self.mask_head.num_classes - 1)]
else:
# if det_bboxes is rescaled to the original image size, we need to
# rescale it back to the testing scale to obtain RoIs.
_bboxes = (det_bboxes[:, :4] * scale_factor
if rescale else det_bboxes)
mask_rois = bbox2roi([_bboxes])
mask_feats = self.mask_roi_extractor(
x[:len(self.mask_roi_extractor.featmap_strides)], mask_rois)
mask_pred = self.mask_head(mask_feats)
segm_result = self.mask_head.get_seg_masks(
mask_pred, _bboxes, det_labels, self.test_cfg.rcnn, ori_shape,
scale_factor, rescale)
return segm_result
def aug_test_mask(self, feats, img_metas, det_bboxes, det_labels):
if det_bboxes.shape[0] == 0:
segm_result = [[] for _ in range(self.mask_head.num_classes - 1)]
else:
aug_masks = []
for x, img_meta in zip(feats, img_metas):
img_shape = img_meta[0]['img_shape']
scale_factor = img_meta[0]['scale_factor']
flip = img_meta[0]['flip']
_bboxes = bbox_mapping(det_bboxes[:, :4], img_shape,
scale_factor, flip)
mask_rois = bbox2roi([_bboxes])
mask_feats = self.mask_roi_extractor(
x[:len(self.mask_roi_extractor.featmap_strides)],
mask_rois)
mask_pred = self.mask_head(mask_feats)
# convert to numpy array to save memory
aug_masks.append(mask_pred.sigmoid().cpu().numpy())
merged_masks = merge_aug_masks(aug_masks, img_metas,
self.test_cfg.rcnn)
ori_shape = img_metas[0][0]['ori_shape']
segm_result = self.mask_head.get_seg_masks(
merged_masks,
det_bboxes,
det_labels,
self.test_cfg.rcnn,
ori_shape,
scale_factor=1.0,
rescale=False)
return segm_result
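# The helper below is an illustrative sketch of the flip half of
# `bbox_mapping` used in the TTA methods above: boxes are assumed to be
# scaled to the augmented image first, then mirrored horizontally when the
# augmentation flipped the image. It is a hypothetical reference, not the
# actual mmdet implementation.
def _bbox_flip_lr_sketch(bboxes, img_shape):
    """Hypothetical helper: mirror (x1, y1, x2, y2) boxes horizontally."""
    flipped = bboxes.clone()
    flipped[:, 0] = img_shape[1] - bboxes[:, 2] - 1
    flipped[:, 2] = img_shape[1] - bboxes[:, 0] - 1
    return flipped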
import torch
import torch.nn as nn
from .base import BaseDetector
from .test_mixins import RPNTestMixin, BBoxTestMixin, MaskTestMixin
from .. import builder
from mmdet.core import sample_bboxes, bbox2roi, bbox2result, multi_apply
class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
MaskTestMixin):
def __init__(self,
backbone,
neck=None,
rpn_head=None,
bbox_roi_extractor=None,
bbox_head=None,
mask_roi_extractor=None,
mask_head=None,
train_cfg=None,
test_cfg=None,
pretrained=None):
super(TwoStageDetector, self).__init__()
self.backbone = builder.build_backbone(backbone)
if neck is not None:
self.neck = builder.build_neck(neck)
else:
raise NotImplementedError
if rpn_head is not None:
self.rpn_head = builder.build_rpn_head(rpn_head)
if bbox_head is not None:
self.bbox_roi_extractor = builder.build_roi_extractor(
bbox_roi_extractor)
self.bbox_head = builder.build_bbox_head(bbox_head)
if mask_head is not None:
self.mask_roi_extractor = builder.build_roi_extractor(
mask_roi_extractor)
self.mask_head = builder.build_mask_head(mask_head)
self.train_cfg = train_cfg
self.test_cfg = test_cfg
self.init_weights(pretrained=pretrained)
@property
def with_rpn(self):
return hasattr(self, 'rpn_head') and self.rpn_head is not None
def init_weights(self, pretrained=None):
super(TwoStageDetector, self).init_weights(pretrained)
self.backbone.init_weights(pretrained=pretrained)
if self.with_neck:
if isinstance(self.neck, nn.Sequential):
for m in self.neck:
m.init_weights()
else:
self.neck.init_weights()
if self.with_rpn:
self.rpn_head.init_weights()
        if self.with_bbox:
            self.bbox_roi_extractor.init_weights()
            self.bbox_head.init_weights()
        if self.with_mask:
            self.mask_roi_extractor.init_weights()
            self.mask_head.init_weights()
def extract_feat(self, img):
x = self.backbone(img)
if self.with_neck:
x = self.neck(x)
return x
def forward_train(self,
img,
img_meta,
gt_bboxes,
gt_bboxes_ignore,
gt_labels,
gt_masks=None,
proposals=None):
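        # Two-stage training flow (summary of the steps below): extract
        # features, compute RPN losses and proposals, sample positive and
        # negative proposals per image, then compute bbox head (and
        # optionally mask head) losses on the sampled RoIs.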
losses = dict()
x = self.extract_feat(img)
if self.with_rpn:
rpn_outs = self.rpn_head(x)
rpn_loss_inputs = rpn_outs + (gt_bboxes, img_meta,
self.train_cfg.rpn)
rpn_losses = self.rpn_head.loss(*rpn_loss_inputs)
losses.update(rpn_losses)
proposal_inputs = rpn_outs + (img_meta, self.test_cfg.rpn)
proposal_list = self.rpn_head.get_proposals(*proposal_inputs)
else:
proposal_list = proposals
if self.with_bbox:
(pos_proposals, neg_proposals, pos_assigned_gt_inds, pos_gt_bboxes,
pos_gt_labels) = multi_apply(
sample_bboxes,
proposal_list,
gt_bboxes,
gt_bboxes_ignore,
gt_labels,
cfg=self.train_cfg.rcnn)
(labels, label_weights, bbox_targets,
bbox_weights) = self.bbox_head.get_bbox_target(
pos_proposals, neg_proposals, pos_gt_bboxes, pos_gt_labels,
self.train_cfg.rcnn)
rois = bbox2roi([
torch.cat([pos, neg], dim=0)
for pos, neg in zip(pos_proposals, neg_proposals)
])
            # TODO: a more flexible way to configure feature maps
roi_feats = self.bbox_roi_extractor(
x[:self.bbox_roi_extractor.num_inputs], rois)
cls_score, bbox_pred = self.bbox_head(roi_feats)
loss_bbox = self.bbox_head.loss(cls_score, bbox_pred, labels,
label_weights, bbox_targets,
bbox_weights)
losses.update(loss_bbox)
if self.with_mask:
mask_targets = self.mask_head.get_mask_target(
pos_proposals, pos_assigned_gt_inds, gt_masks,
self.train_cfg.rcnn)
pos_rois = bbox2roi(pos_proposals)
mask_feats = self.mask_roi_extractor(
x[:self.mask_roi_extractor.num_inputs], pos_rois)
mask_pred = self.mask_head(mask_feats)
loss_mask = self.mask_head.loss(mask_pred, mask_targets,
torch.cat(pos_gt_labels))
losses.update(loss_mask)
return losses
def simple_test(self, img, img_meta, proposals=None, rescale=False):
"""Test without augmentation."""
assert self.with_bbox, "Bbox head must be implemented."
x = self.extract_feat(img)
proposal_list = self.simple_test_rpn(
x, img_meta,
self.test_cfg.rpn) if proposals is None else proposals
det_bboxes, det_labels = self.simple_test_bboxes(
x, img_meta, proposal_list, self.test_cfg.rcnn, rescale=rescale)
bbox_results = bbox2result(det_bboxes, det_labels,
self.bbox_head.num_classes)
if not self.with_mask:
return bbox_results
else:
segm_results = self.simple_test_mask(
x, img_meta, det_bboxes, det_labels, rescale=rescale)
return bbox_results, segm_results
def aug_test(self, imgs, img_metas, rescale=False):
"""Test with augmentations.
If rescale is False, then returned bboxes and masks will fit the scale
of imgs[0].
"""
# recompute feats to save memory
proposal_list = self.aug_test_rpn(
self.extract_feats(imgs), img_metas, self.test_cfg.rpn)
det_bboxes, det_labels = self.aug_test_bboxes(
self.extract_feats(imgs), img_metas, proposal_list,
self.test_cfg.rcnn)
if rescale:
_det_bboxes = det_bboxes
else:
_det_bboxes = det_bboxes.clone()
_det_bboxes[:, :4] *= img_metas[0][0]['scale_factor']
bbox_results = bbox2result(_det_bboxes, det_labels,
self.bbox_head.num_classes)
# det_bboxes always keep the original scale
if self.with_mask:
segm_results = self.aug_test_mask(
self.extract_feats(imgs), img_metas, det_bboxes, det_labels)
return bbox_results, segm_results
else:
return bbox_results
from .fcn_mask_head import FCNMaskHead
__all__ = ['FCNMaskHead']
import mmcv
import numpy as np
import pycocotools.mask as mask_util
import torch
import torch.nn as nn
from ..utils import ConvModule
from mmdet.core import mask_cross_entropy, mask_target
class FCNMaskHead(nn.Module):
def __init__(self,
num_convs=4,
roi_feat_size=14,
in_channels=256,
conv_kernel_size=3,
conv_out_channels=256,
upsample_method='deconv',
upsample_ratio=2,
num_classes=81,
class_agnostic=False,
normalize=None):
super(FCNMaskHead, self).__init__()
        if upsample_method not in [None, 'deconv', 'nearest', 'bilinear']:
            raise ValueError(
                'Invalid upsample method {}, accepted methods are '
                'None, "deconv", "nearest" and "bilinear"'.format(
                    upsample_method))
self.num_convs = num_convs
        self.roi_feat_size = roi_feat_size  # WARN: unused, reserved for future use
self.in_channels = in_channels
self.conv_kernel_size = conv_kernel_size
self.conv_out_channels = conv_out_channels
self.upsample_method = upsample_method
self.upsample_ratio = upsample_ratio
self.num_classes = num_classes
self.class_agnostic = class_agnostic
self.normalize = normalize
self.with_bias = normalize is None
self.convs = nn.ModuleList()
for i in range(self.num_convs):
in_channels = (self.in_channels
if i == 0 else self.conv_out_channels)
padding = (self.conv_kernel_size - 1) // 2
self.convs.append(
ConvModule(
in_channels,
self.conv_out_channels,
                    self.conv_kernel_size,
padding=padding,
normalize=normalize,
bias=self.with_bias))
if self.upsample_method is None:
self.upsample = None
elif self.upsample_method == 'deconv':
self.upsample = nn.ConvTranspose2d(
self.conv_out_channels,
self.conv_out_channels,
self.upsample_ratio,
stride=self.upsample_ratio)
else:
self.upsample = nn.Upsample(
scale_factor=self.upsample_ratio, mode=self.upsample_method)
out_channels = 1 if self.class_agnostic else self.num_classes
self.conv_logits = nn.Conv2d(self.conv_out_channels, out_channels, 1)
self.relu = nn.ReLU(inplace=True)
self.debug_imgs = None
    def init_weights(self):
        for m in [self.upsample, self.conv_logits]:
            # nn.Upsample has no parameters; only init deconv/conv layers
            if m is None or not hasattr(m, 'weight'):
                continue
            nn.init.kaiming_normal_(
                m.weight, mode='fan_out', nonlinearity='relu')
            nn.init.constant_(m.bias, 0)
def forward(self, x):
for conv in self.convs:
x = conv(x)
if self.upsample is not None:
x = self.upsample(x)
if self.upsample_method == 'deconv':
x = self.relu(x)
mask_pred = self.conv_logits(x)
return mask_pred
def get_mask_target(self, pos_proposals, pos_assigned_gt_inds, gt_masks,
rcnn_train_cfg):
mask_targets = mask_target(pos_proposals, pos_assigned_gt_inds,
gt_masks, rcnn_train_cfg)
return mask_targets
def loss(self, mask_pred, mask_targets, labels):
loss = dict()
loss_mask = mask_cross_entropy(mask_pred, mask_targets, labels)
loss['loss_mask'] = loss_mask
return loss
def get_seg_masks(self, mask_pred, det_bboxes, det_labels, rcnn_test_cfg,
ori_shape, scale_factor, rescale):
"""Get segmentation masks from mask_pred and bboxes.
Args:
mask_pred (Tensor or ndarray): shape (n, #class+1, h, w).
For single-scale testing, mask_pred is the direct output of
model, whose type is Tensor, while for multi-scale testing,
it will be converted to numpy array outside of this method.
det_bboxes (Tensor): shape (n, 4/5)
det_labels (Tensor): shape (n, )
img_shape (Tensor): shape (3, )
rcnn_test_cfg (dict): rcnn testing config
ori_shape: original image size
Returns:
list[list]: encoded masks
"""
if isinstance(mask_pred, torch.Tensor):
mask_pred = mask_pred.sigmoid().cpu().numpy()
assert isinstance(mask_pred, np.ndarray)
cls_segms = [[] for _ in range(self.num_classes - 1)]
bboxes = det_bboxes.cpu().numpy()[:, :4]
labels = det_labels.cpu().numpy() + 1
if rescale:
img_h, img_w = ori_shape[:2]
else:
img_h = np.round(ori_shape[0] * scale_factor).astype(np.int32)
img_w = np.round(ori_shape[1] * scale_factor).astype(np.int32)
scale_factor = 1.0
for i in range(bboxes.shape[0]):
bbox = (bboxes[i, :] / scale_factor).astype(np.int32)
label = labels[i]
w = max(bbox[2] - bbox[0] + 1, 1)
h = max(bbox[3] - bbox[1] + 1, 1)
if not self.class_agnostic:
mask_pred_ = mask_pred[i, label, :, :]
else:
mask_pred_ = mask_pred[i, 0, :, :]
im_mask = np.zeros((img_h, img_w), dtype=np.uint8)
bbox_mask = mmcv.imresize(mask_pred_, (w, h))
bbox_mask = (bbox_mask > rcnn_test_cfg.mask_thr_binary).astype(
np.uint8)
im_mask[bbox[1]:bbox[1] + h, bbox[0]:bbox[0] + w] = bbox_mask
rle = mask_util.encode(
np.array(im_mask[:, :, np.newaxis], order='F'))[0]
cls_segms[label - 1].append(rle)
return cls_segms
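# Round-trip sanity check (illustrative; run via `python -m` from the package
# root since this file uses relative imports) for the RLE encoding used in
# get_seg_masks above: encode a dummy binary mask with pycocotools, then
# decode it back and verify the result.
if __name__ == '__main__':
    dummy = np.zeros((20, 20), dtype=np.uint8)
    dummy[5:15, 5:15] = 1
    rle = mask_util.encode(np.array(dummy[:, :, np.newaxis], order='F'))[0]
    assert (mask_util.decode(rle) == dummy).all()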
from .fpn import FPN
__all__ = ['FPN']
import torch.nn as nn
import torch.nn.functional as F
from ..utils import ConvModule
from ..utils import xavier_init
class FPN(nn.Module):
def __init__(self,
in_channels,
out_channels,
num_outs,
start_level=0,
end_level=-1,
add_extra_convs=False,
normalize=None,
activation=None):
super(FPN, self).__init__()
assert isinstance(in_channels, list)
self.in_channels = in_channels
self.out_channels = out_channels
self.num_ins = len(in_channels)
self.num_outs = num_outs
self.activation = activation
self.with_bias = normalize is None
if end_level == -1:
self.backbone_end_level = self.num_ins
assert num_outs >= self.num_ins - start_level
else:
# if end_level < inputs, no extra level is allowed
self.backbone_end_level = end_level
assert end_level <= len(in_channels)
assert num_outs == end_level - start_level
self.start_level = start_level
self.end_level = end_level
self.add_extra_convs = add_extra_convs
self.lateral_convs = nn.ModuleList()
self.fpn_convs = nn.ModuleList()
for i in range(self.start_level, self.backbone_end_level):
l_conv = ConvModule(
in_channels[i],
out_channels,
1,
normalize=normalize,
bias=self.with_bias,
activation=self.activation,
inplace=False)
fpn_conv = ConvModule(
out_channels,
out_channels,
3,
padding=1,
normalize=normalize,
bias=self.with_bias,
activation=self.activation,
inplace=False)
self.lateral_convs.append(l_conv)
self.fpn_convs.append(fpn_conv)
# add extra conv layers (e.g., RetinaNet)
extra_levels = num_outs - self.backbone_end_level + self.start_level
if add_extra_convs and extra_levels >= 1:
for i in range(extra_levels):
in_channels = (self.in_channels[self.backbone_end_level - 1]
if i == 0 else out_channels)
extra_fpn_conv = ConvModule(
in_channels,
out_channels,
3,
stride=2,
padding=1,
normalize=normalize,
bias=self.with_bias,
activation=self.activation,
inplace=False)
self.fpn_convs.append(extra_fpn_conv)
    # ConvModule already initializes its conv (MSRA) and norm layers by default
def init_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
xavier_init(m, distribution='uniform')
def forward(self, inputs):
assert len(inputs) == len(self.in_channels)
# build laterals
laterals = [
lateral_conv(inputs[i + self.start_level])
for i, lateral_conv in enumerate(self.lateral_convs)
]
# build top-down path
used_backbone_levels = len(laterals)
for i in range(used_backbone_levels - 1, 0, -1):
laterals[i - 1] += F.interpolate(
laterals[i], scale_factor=2, mode='nearest')
# build outputs
# part 1: from original levels
outs = [
self.fpn_convs[i](laterals[i]) for i in range(used_backbone_levels)
]
# part 2: add extra levels
if self.num_outs > len(outs):
# use max pool to get more levels on top of outputs
# (e.g., Faster R-CNN, Mask R-CNN)
if not self.add_extra_convs:
for i in range(self.num_outs - used_backbone_levels):
outs.append(F.max_pool2d(outs[-1], 1, stride=2))
# add conv layers on top of original feature maps (RetinaNet)
else:
orig = inputs[self.backbone_end_level - 1]
outs.append(self.fpn_convs[used_backbone_levels](orig))
for i in range(used_backbone_levels + 1, self.num_outs):
# BUG: we should add relu before each extra conv
outs.append(self.fpn_convs[i](outs[-1]))
return tuple(outs)
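# Shape sketch (illustrative; run via `python -m` from the package root since
# this file uses relative imports): a 4-level backbone pyramid in, 5 levels
# out, where the extra top level comes from the max-pool branch in forward()
# because add_extra_convs defaults to False.
if __name__ == '__main__':
    import torch
    in_channels = [256, 512, 1024, 2048]
    fpn = FPN(in_channels, out_channels=256, num_outs=5)
    inputs = [
        torch.rand(1, c, s, s)
        for c, s in zip(in_channels, [64, 32, 16, 8])
    ]
    for i, out in enumerate(fpn(inputs)):
        print('level {}: {}'.format(i, tuple(out.shape)))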
from .single_level import SingleRoIExtractor
__all__ = ['SingleRoIExtractor']
from __future__ import division
import torch
import torch.nn as nn
from mmdet import ops
class SingleRoIExtractor(nn.Module):
"""Extract RoI features from a single level feature map.
    If there are multiple input feature levels, each RoI is mapped to a level
according to its scale.
Args:
roi_layer (dict): Specify RoI layer type and arguments.
out_channels (int): Output channels of RoI layers.
        featmap_strides (list[int]): Strides of input feature maps.
finest_scale (int): Scale threshold of mapping to level 0.
"""
def __init__(self,
roi_layer,
out_channels,
featmap_strides,
finest_scale=56):
super(SingleRoIExtractor, self).__init__()
self.roi_layers = self.build_roi_layers(roi_layer, featmap_strides)
self.out_channels = out_channels
self.featmap_strides = featmap_strides
self.finest_scale = finest_scale
@property
def num_inputs(self):
"""int: Input feature map levels."""
return len(self.featmap_strides)
def init_weights(self):
pass
def build_roi_layers(self, layer_cfg, featmap_strides):
cfg = layer_cfg.copy()
layer_type = cfg.pop('type')
assert hasattr(ops, layer_type)
layer_cls = getattr(ops, layer_type)
roi_layers = nn.ModuleList(
[layer_cls(spatial_scale=1 / s, **cfg) for s in featmap_strides])
return roi_layers
def map_roi_levels(self, rois, num_levels):
"""Map rois to corresponding feature levels by scales.
- scale < finest_scale: level 0
- finest_scale <= scale < finest_scale * 2: level 1
- finest_scale * 2 <= scale < finest_scale * 4: level 2
- scale >= finest_scale * 4: level 3
Args:
rois (Tensor): Input RoIs, shape (k, 5).
num_levels (int): Total level number.
Returns:
Tensor: Level index (0-based) of each RoI, shape (k, )
"""
scale = torch.sqrt(
(rois[:, 3] - rois[:, 1] + 1) * (rois[:, 4] - rois[:, 2] + 1))
target_lvls = torch.floor(torch.log2(scale / self.finest_scale + 1e-6))
target_lvls = target_lvls.clamp(min=0, max=num_levels - 1).long()
return target_lvls
def forward(self, feats, rois):
if len(feats) == 1:
return self.roi_layers[0](feats[0], rois)
out_size = self.roi_layers[0].out_size
num_levels = len(feats)
target_lvls = self.map_roi_levels(rois, num_levels)
        # allocate on the same device/dtype as the input features
        roi_feats = feats[0].new_zeros(rois.size(0), self.out_channels,
                                       out_size, out_size)
for i in range(num_levels):
inds = target_lvls == i
if inds.any():
rois_ = rois[inds, :]
roi_feats_t = self.roi_layers[i](feats[i], rois_)
roi_feats[inds] += roi_feats_t
return roi_feats
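# Worked example (illustrative) of the scale-to-level mapping implemented in
# map_roi_levels above, with the default finest_scale=56. RoIs are
# (batch_idx, x1, y1, x2, y2); boxes of side 32, 112, 224 and 448 land on
# levels 0, 1, 2 and 3.
if __name__ == '__main__':
    rois = torch.tensor([[0., 0., 0., 31., 31.],
                         [0., 0., 0., 111., 111.],
                         [0., 0., 0., 223., 223.],
                         [0., 0., 0., 447., 447.]])
    scale = torch.sqrt(
        (rois[:, 3] - rois[:, 1] + 1) * (rois[:, 4] - rois[:, 2] + 1))
    lvls = torch.floor(torch.log2(scale / 56 + 1e-6)).clamp(0, 3).long()
    print(lvls.tolist())  # [0, 1, 2, 3]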
from .rpn_head import RPNHead
__all__ = ['RPNHead']
from __future__ import division
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmdet.core import (AnchorGenerator, anchor_target, delta2bbox,
multi_apply, weighted_cross_entropy, weighted_smoothl1,
weighted_binary_cross_entropy)
from mmdet.ops import nms
from ..utils import normal_init
class RPNHead(nn.Module):
"""Network head of RPN.
/ - rpn_cls (1x1 conv)
input - rpn_conv (3x3 conv) -
\ - rpn_reg (1x1 conv)
Args:
in_channels (int): Number of channels in the input feature map.
feat_channels (int): Number of channels for the RPN feature map.
anchor_scales (Iterable): Anchor scales.
anchor_ratios (Iterable): Anchor aspect ratios.
anchor_strides (Iterable): Anchor strides.
anchor_base_sizes (Iterable): Anchor base sizes.
target_means (Iterable): Mean values of regression targets.
target_stds (Iterable): Std values of regression targets.
use_sigmoid_cls (bool): Whether to use sigmoid loss for classification.
(softmax by default)
"""
def __init__(self,
in_channels,
feat_channels=256,
anchor_scales=[8, 16, 32],
anchor_ratios=[0.5, 1.0, 2.0],
anchor_strides=[4, 8, 16, 32, 64],
anchor_base_sizes=None,
target_means=(.0, .0, .0, .0),
target_stds=(1.0, 1.0, 1.0, 1.0),
use_sigmoid_cls=False):
super(RPNHead, self).__init__()
self.in_channels = in_channels
self.feat_channels = feat_channels
self.anchor_scales = anchor_scales
self.anchor_ratios = anchor_ratios
self.anchor_strides = anchor_strides
self.anchor_base_sizes = list(
anchor_strides) if anchor_base_sizes is None else anchor_base_sizes
self.target_means = target_means
self.target_stds = target_stds
self.use_sigmoid_cls = use_sigmoid_cls
self.anchor_generators = []
for anchor_base in self.anchor_base_sizes:
self.anchor_generators.append(
AnchorGenerator(anchor_base, anchor_scales, anchor_ratios))
self.rpn_conv = nn.Conv2d(in_channels, feat_channels, 3, padding=1)
self.relu = nn.ReLU(inplace=True)
self.num_anchors = len(self.anchor_ratios) * len(self.anchor_scales)
out_channels = (self.num_anchors
if self.use_sigmoid_cls else self.num_anchors * 2)
self.rpn_cls = nn.Conv2d(feat_channels, out_channels, 1)
self.rpn_reg = nn.Conv2d(feat_channels, self.num_anchors * 4, 1)
self.debug_imgs = None
def init_weights(self):
normal_init(self.rpn_conv, std=0.01)
normal_init(self.rpn_cls, std=0.01)
normal_init(self.rpn_reg, std=0.01)
def forward_single(self, x):
rpn_feat = self.relu(self.rpn_conv(x))
rpn_cls_score = self.rpn_cls(rpn_feat)
rpn_bbox_pred = self.rpn_reg(rpn_feat)
return rpn_cls_score, rpn_bbox_pred
def forward(self, feats):
return multi_apply(self.forward_single, feats)
def get_anchors(self, featmap_sizes, img_metas):
"""Get anchors according to feature map sizes.
Args:
featmap_sizes (list[tuple]): Multi-level feature map sizes.
img_metas (list[dict]): Image meta info.
Returns:
tuple: anchors of each image, valid flags of each image
"""
num_imgs = len(img_metas)
num_levels = len(featmap_sizes)
        # since feature map sizes of all images are the same, we compute
        # anchors only once
multi_level_anchors = []
for i in range(num_levels):
anchors = self.anchor_generators[i].grid_anchors(
featmap_sizes[i], self.anchor_strides[i])
multi_level_anchors.append(anchors)
anchor_list = [multi_level_anchors for _ in range(num_imgs)]
# for each image, we compute valid flags of multi level anchors
valid_flag_list = []
for img_id, img_meta in enumerate(img_metas):
multi_level_flags = []
for i in range(num_levels):
anchor_stride = self.anchor_strides[i]
feat_h, feat_w = featmap_sizes[i]
h, w, _ = img_meta['pad_shape']
valid_feat_h = min(int(np.ceil(h / anchor_stride)), feat_h)
valid_feat_w = min(int(np.ceil(w / anchor_stride)), feat_w)
flags = self.anchor_generators[i].valid_flags(
(feat_h, feat_w), (valid_feat_h, valid_feat_w))
multi_level_flags.append(flags)
valid_flag_list.append(multi_level_flags)
return anchor_list, valid_flag_list
def loss_single(self, rpn_cls_score, rpn_bbox_pred, labels, label_weights,
bbox_targets, bbox_weights, num_total_samples, cfg):
# classification loss
labels = labels.contiguous().view(-1)
label_weights = label_weights.contiguous().view(-1)
if self.use_sigmoid_cls:
rpn_cls_score = rpn_cls_score.permute(0, 2, 3,
1).contiguous().view(-1)
criterion = weighted_binary_cross_entropy
else:
rpn_cls_score = rpn_cls_score.permute(0, 2, 3,
1).contiguous().view(-1, 2)
criterion = weighted_cross_entropy
loss_cls = criterion(
rpn_cls_score, labels, label_weights, avg_factor=num_total_samples)
# regression loss
bbox_targets = bbox_targets.contiguous().view(-1, 4)
bbox_weights = bbox_weights.contiguous().view(-1, 4)
rpn_bbox_pred = rpn_bbox_pred.permute(0, 2, 3, 1).contiguous().view(
-1, 4)
loss_reg = weighted_smoothl1(
rpn_bbox_pred,
bbox_targets,
bbox_weights,
beta=cfg.smoothl1_beta,
avg_factor=num_total_samples)
return loss_cls, loss_reg
    def loss(self, rpn_cls_scores, rpn_bbox_preds, gt_bboxes, img_metas, cfg):
        featmap_sizes = [featmap.size()[-2:] for featmap in rpn_cls_scores]
        assert len(featmap_sizes) == len(self.anchor_generators)
        anchor_list, valid_flag_list = self.get_anchors(
            featmap_sizes, img_metas)
        cls_reg_targets = anchor_target(
            anchor_list, valid_flag_list, gt_bboxes, img_metas,
            self.target_means, self.target_stds, cfg)
if cls_reg_targets is None:
return None
(labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
num_total_samples) = cls_reg_targets
losses_cls, losses_reg = multi_apply(
self.loss_single,
rpn_cls_scores,
rpn_bbox_preds,
labels_list,
label_weights_list,
bbox_targets_list,
bbox_weights_list,
num_total_samples=num_total_samples,
cfg=cfg)
return dict(loss_rpn_cls=losses_cls, loss_rpn_reg=losses_reg)
def get_proposals(self, rpn_cls_scores, rpn_bbox_preds, img_meta, cfg):
num_imgs = len(img_meta)
featmap_sizes = [featmap.size()[-2:] for featmap in rpn_cls_scores]
mlvl_anchors = [
self.anchor_generators[idx].grid_anchors(featmap_sizes[idx],
self.anchor_strides[idx])
for idx in range(len(featmap_sizes))
]
proposal_list = []
for img_id in range(num_imgs):
rpn_cls_score_list = [
rpn_cls_scores[idx][img_id].detach()
for idx in range(len(rpn_cls_scores))
]
rpn_bbox_pred_list = [
rpn_bbox_preds[idx][img_id].detach()
for idx in range(len(rpn_bbox_preds))
]
assert len(rpn_cls_score_list) == len(rpn_bbox_pred_list)
proposals = self._get_proposals_single(
rpn_cls_score_list, rpn_bbox_pred_list, mlvl_anchors,
img_meta[img_id]['img_shape'], cfg)
proposal_list.append(proposals)
return proposal_list
def _get_proposals_single(self, rpn_cls_scores, rpn_bbox_preds,
mlvl_anchors, img_shape, cfg):
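        # Per-level pipeline (summary of the loop below): score anchors, keep
        # the top nms_pre, decode deltas into boxes within img_shape, drop
        # boxes smaller than min_bbox_size, then apply NMS within the level;
        # levels are merged and truncated to max_num at the end.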
mlvl_proposals = []
for idx in range(len(rpn_cls_scores)):
rpn_cls_score = rpn_cls_scores[idx]
rpn_bbox_pred = rpn_bbox_preds[idx]
assert rpn_cls_score.size()[-2:] == rpn_bbox_pred.size()[-2:]
anchors = mlvl_anchors[idx]
if self.use_sigmoid_cls:
rpn_cls_score = rpn_cls_score.permute(1, 2,
0).contiguous().view(-1)
rpn_cls_prob = rpn_cls_score.sigmoid()
scores = rpn_cls_prob
else:
rpn_cls_score = rpn_cls_score.permute(1, 2,
0).contiguous().view(
-1, 2)
rpn_cls_prob = F.softmax(rpn_cls_score, dim=1)
scores = rpn_cls_prob[:, 1]
rpn_bbox_pred = rpn_bbox_pred.permute(1, 2, 0).contiguous().view(
-1, 4)
_, order = scores.sort(0, descending=True)
if cfg.nms_pre > 0:
order = order[:cfg.nms_pre]
rpn_bbox_pred = rpn_bbox_pred[order, :]
anchors = anchors[order, :]
scores = scores[order]
proposals = delta2bbox(anchors, rpn_bbox_pred, self.target_means,
self.target_stds, img_shape)
w = proposals[:, 2] - proposals[:, 0] + 1
h = proposals[:, 3] - proposals[:, 1] + 1
valid_inds = torch.nonzero((w >= cfg.min_bbox_size) &
(h >= cfg.min_bbox_size)).squeeze()
proposals = proposals[valid_inds, :]
scores = scores[valid_inds]
proposals = torch.cat([proposals, scores.unsqueeze(-1)], dim=-1)
nms_keep = nms(proposals, cfg.nms_thr)[:cfg.nms_post]
proposals = proposals[nms_keep, :]
mlvl_proposals.append(proposals)
proposals = torch.cat(mlvl_proposals, 0)
if cfg.nms_across_levels:
nms_keep = nms(proposals, cfg.nms_thr)[:cfg.max_num]
proposals = proposals[nms_keep, :]
else:
scores = proposals[:, 4]
_, order = scores.sort(0, descending=True)
num = min(cfg.max_num, proposals.shape[0])
order = order[:num]
proposals = proposals[order, :]
return proposals
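# Shape sketch (illustrative; run via `python -m` from the package root since
# this file uses relative imports) of the score reshaping performed in
# _get_proposals_single above, with dummy numbers: 3 scales x 3 ratios = 9
# anchors per location on a 16x16 feature map.
if __name__ == '__main__':
    num_anchors, h, w = 9, 16, 16
    # softmax variant: 2 channels per anchor; foreground prob is column 1
    cls_score = torch.rand(num_anchors * 2, h, w)
    scores = F.softmax(
        cls_score.permute(1, 2, 0).contiguous().view(-1, 2), dim=1)[:, 1]
    # sigmoid variant: a single channel per anchor
    cls_score_sig = torch.rand(num_anchors, h, w)
    scores_sig = cls_score_sig.permute(1, 2, 0).contiguous().view(-1).sigmoid()
    assert scores.shape == scores_sig.shape == (num_anchors * h * w,)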
from .conv_module import ConvModule
from .norm import build_norm_layer
from .weight_init import xavier_init, normal_init, uniform_init, kaiming_init
__all__ = [
'ConvModule', 'build_norm_layer', 'xavier_init', 'normal_init',
'uniform_init', 'kaiming_init'
]
import warnings
import torch.nn as nn
from .norm import build_norm_layer
class ConvModule(nn.Module):
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
bias=True,
normalize=None,
activation='relu',
inplace=True,
activate_last=True):
super(ConvModule, self).__init__()
self.with_norm = normalize is not None
        self.with_activation = activation is not None
self.with_bias = bias
self.activation = activation
self.activate_last = activate_last
if self.with_norm and self.with_bias:
warnings.warn('ConvModule has norm and bias at the same time')
self.conv = nn.Conv2d(
in_channels,
out_channels,
kernel_size,
stride,
padding,
dilation,
groups,
bias=bias)
self.in_channels = self.conv.in_channels
self.out_channels = self.conv.out_channels
self.kernel_size = self.conv.kernel_size
self.stride = self.conv.stride
self.padding = self.conv.padding
self.dilation = self.conv.dilation
self.transposed = self.conv.transposed
self.output_padding = self.conv.output_padding
self.groups = self.conv.groups
        if self.with_norm:
            if self.activate_last:
                self.norm = build_norm_layer(normalize, out_channels)
            else:
                # normalize the input instead of the output
                self.norm = build_norm_layer(normalize, in_channels)
        if self.with_activation:
assert activation in ['relu'], 'Only ReLU supported.'
if self.activation == 'relu':
self.activate = nn.ReLU(inplace=inplace)
        # Use MSRA (Kaiming) initialization by default
self.init_weights()
def init_weights(self):
nonlinearity = 'relu' if self.activation is None else self.activation
nn.init.kaiming_normal_(
self.conv.weight, mode='fan_out', nonlinearity=nonlinearity)
if self.with_bias:
nn.init.constant_(self.conv.bias, 0)
if self.with_norm:
nn.init.constant_(self.norm.weight, 1)
nn.init.constant_(self.norm.bias, 0)
def forward(self, x, activate=True, norm=True):
if self.activate_last:
x = self.conv(x)
if norm and self.with_norm:
x = self.norm(x)
            if activate and self.with_activation:
x = self.activate(x)
else:
if norm and self.with_norm:
x = self.norm(x)
            if activate and self.with_activation:
x = self.activate(x)
x = self.conv(x)
return x
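# Usage sketch (illustrative; run via `python -m` from the package root since
# this file uses relative imports): a conv + BN + ReLU block. bias=False
# avoids the norm-with-bias warning above.
if __name__ == '__main__':
    import torch
    m = ConvModule(3, 16, 3, padding=1, bias=False, normalize=dict(type='BN'))
    y = m(torch.rand(2, 3, 32, 32))
    print(tuple(y.shape))  # (2, 16, 32, 32)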
import torch.nn as nn
norm_cfg = {'BN': nn.BatchNorm2d, 'SyncBN': None, 'GN': None}
def build_norm_layer(cfg, num_features):
assert isinstance(cfg, dict) and 'type' in cfg
cfg_ = cfg.copy()
cfg_.setdefault('eps', 1e-5)
layer_type = cfg_.pop('type')
if layer_type not in norm_cfg:
raise KeyError('Unrecognized norm type {}'.format(layer_type))
elif norm_cfg[layer_type] is None:
raise NotImplementedError
return norm_cfg[layer_type](num_features, **cfg_)
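# Usage sketch (illustrative): only 'BN' is registered with an implementation
# in norm_cfg above; 'SyncBN' and 'GN' currently raise NotImplementedError.
if __name__ == '__main__':
    bn = build_norm_layer(dict(type='BN'), 64)
    print(bn)  # BatchNorm2d(64, eps=1e-05, ...)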
import torch.nn as nn
def xavier_init(module, gain=1, bias=0, distribution='normal'):
assert distribution in ['uniform', 'normal']
if distribution == 'uniform':
nn.init.xavier_uniform_(module.weight, gain=gain)
else:
nn.init.xavier_normal_(module.weight, gain=gain)
    if hasattr(module, 'bias') and module.bias is not None:
        nn.init.constant_(module.bias, bias)
def normal_init(module, mean=0, std=1, bias=0):
nn.init.normal_(module.weight, mean, std)
    if hasattr(module, 'bias') and module.bias is not None:
        nn.init.constant_(module.bias, bias)
def uniform_init(module, a=0, b=1, bias=0):
nn.init.uniform_(module.weight, a, b)
    if hasattr(module, 'bias') and module.bias is not None:
        nn.init.constant_(module.bias, bias)
def kaiming_init(module,
mode='fan_out',
nonlinearity='relu',
bias=0,
distribution='normal'):
assert distribution in ['uniform', 'normal']
if distribution == 'uniform':
nn.init.kaiming_uniform_(
module.weight, mode=mode, nonlinearity=nonlinearity)
else:
nn.init.kaiming_normal_(
module.weight, mode=mode, nonlinearity=nonlinearity)
    if hasattr(module, 'bias') and module.bias is not None:
        nn.init.constant_(module.bias, bias)
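# Usage sketch (illustrative): apply the initializers above to a conv layer.
if __name__ == '__main__':
    conv = nn.Conv2d(3, 8, 3)
    xavier_init(conv, distribution='uniform')
    kaiming_init(conv, mode='fan_in', nonlinearity='relu')
    normal_init(conv, std=0.01)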
from .nms import nms, soft_nms
from .roi_align import RoIAlign, roi_align
from .roi_pool import RoIPool, roi_pool
__all__ = ['nms', 'soft_nms', 'RoIAlign', 'roi_align', 'RoIPool', 'roi_pool']
PYTHON ?= python
all:
echo "Compiling nms kernels..."
$(PYTHON) setup.py build_ext --inplace
clean:
rm -f *.so
from .nms_wrapper import nms, soft_nms
__all__ = ['nms', 'soft_nms']
# --------------------------------------------------------
# Fast R-CNN
# Copyright (c) 2015 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ross Girshick
# --------------------------------------------------------
import numpy as np
cimport numpy as np
cdef inline np.float32_t max(np.float32_t a, np.float32_t b):
return a if a >= b else b
cdef inline np.float32_t min(np.float32_t a, np.float32_t b):
return a if a <= b else b
def cpu_nms(np.ndarray[np.float32_t, ndim=2] dets, np.float thresh):
cdef np.ndarray[np.float32_t, ndim=1] x1 = dets[:, 0]
cdef np.ndarray[np.float32_t, ndim=1] y1 = dets[:, 1]
cdef np.ndarray[np.float32_t, ndim=1] x2 = dets[:, 2]
cdef np.ndarray[np.float32_t, ndim=1] y2 = dets[:, 3]
cdef np.ndarray[np.float32_t, ndim=1] scores = dets[:, 4]
cdef np.ndarray[np.float32_t, ndim=1] areas = (x2 - x1 + 1) * (y2 - y1 + 1)
cdef np.ndarray[np.int_t, ndim=1] order = scores.argsort()[::-1]
cdef int ndets = dets.shape[0]
cdef np.ndarray[np.int_t, ndim=1] suppressed = \
np.zeros((ndets), dtype=np.int)
# nominal indices
cdef int _i, _j
# sorted indices
cdef int i, j
# temp variables for box i's (the box currently under consideration)
cdef np.float32_t ix1, iy1, ix2, iy2, iarea
# variables for computing overlap with box j (lower scoring box)
cdef np.float32_t xx1, yy1, xx2, yy2
cdef np.float32_t w, h
cdef np.float32_t inter, ovr
keep = []
for _i in range(ndets):
i = order[_i]
if suppressed[i] == 1:
continue
keep.append(i)
ix1 = x1[i]
iy1 = y1[i]
ix2 = x2[i]
iy2 = y2[i]
iarea = areas[i]
for _j in range(_i + 1, ndets):
j = order[_j]
if suppressed[j] == 1:
continue
xx1 = max(ix1, x1[j])
yy1 = max(iy1, y1[j])
xx2 = min(ix2, x2[j])
yy2 = min(iy2, y2[j])
w = max(0.0, xx2 - xx1 + 1)
h = max(0.0, yy2 - yy1 + 1)
inter = w * h
ovr = inter / (iarea + areas[j] - inter)
if ovr >= thresh:
suppressed[j] = 1
return keep
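# Worked example (illustrative): with the +1 pixel convention above, boxes
# (0, 0, 9, 9) and (5, 5, 14, 14) each have area 10 * 10 = 100 and intersect
# on a 5 x 5 = 25 pixel region, so IoU = 25 / (100 + 100 - 25) ~= 0.143. The
# lower-scoring of the two is kept at thresh=0.5 but suppressed at
# thresh=0.1.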
# ----------------------------------------------------------
# Soft-NMS: Improving Object Detection With One Line of Code
# Copyright (c) University of Maryland, College Park
# Licensed under The MIT License [see LICENSE for details]
# Written by Navaneeth Bodla and Bharat Singh
# ----------------------------------------------------------
import numpy as np
cimport numpy as np
cdef inline np.float32_t max(np.float32_t a, np.float32_t b):
return a if a >= b else b
cdef inline np.float32_t min(np.float32_t a, np.float32_t b):
return a if a <= b else b
def cpu_soft_nms(
np.ndarray[float, ndim=2] boxes_in,
float sigma=0.5,
float Nt=0.3,
float threshold=0.001,
unsigned int method=0
):
boxes = boxes_in.copy()
cdef unsigned int N = boxes.shape[0]
cdef float iw, ih, box_area
cdef float ua
cdef int pos = 0
cdef float maxscore = 0
cdef int maxpos = 0
cdef float x1, x2, y1, y2, tx1, tx2, ty1, ty2, ts, area, weight, ov
inds = np.arange(N)
for i in range(N):
maxscore = boxes[i, 4]
maxpos = i
tx1 = boxes[i,0]
ty1 = boxes[i,1]
tx2 = boxes[i,2]
ty2 = boxes[i,3]
ts = boxes[i,4]
ti = inds[i]
pos = i + 1
# get max box
while pos < N:
if maxscore < boxes[pos, 4]:
maxscore = boxes[pos, 4]
maxpos = pos
pos = pos + 1
# add max box as a detection
boxes[i,0] = boxes[maxpos,0]
boxes[i,1] = boxes[maxpos,1]
boxes[i,2] = boxes[maxpos,2]
boxes[i,3] = boxes[maxpos,3]
boxes[i,4] = boxes[maxpos,4]
inds[i] = inds[maxpos]
# swap ith box with position of max box
boxes[maxpos,0] = tx1
boxes[maxpos,1] = ty1
boxes[maxpos,2] = tx2
boxes[maxpos,3] = ty2
boxes[maxpos,4] = ts
inds[maxpos] = ti
tx1 = boxes[i,0]
ty1 = boxes[i,1]
tx2 = boxes[i,2]
ty2 = boxes[i,3]
ts = boxes[i,4]
pos = i + 1
# NMS iterations, note that N changes if detection boxes fall below
# threshold
while pos < N:
x1 = boxes[pos, 0]
y1 = boxes[pos, 1]
x2 = boxes[pos, 2]
y2 = boxes[pos, 3]
s = boxes[pos, 4]
area = (x2 - x1 + 1) * (y2 - y1 + 1)
iw = (min(tx2, x2) - max(tx1, x1) + 1)
if iw > 0:
ih = (min(ty2, y2) - max(ty1, y1) + 1)
if ih > 0:
ua = float((tx2 - tx1 + 1) * (ty2 - ty1 + 1) + area - iw * ih)
ov = iw * ih / ua #iou between max box and detection box
if method == 1: # linear
if ov > Nt:
weight = 1 - ov
else:
weight = 1
elif method == 2: # gaussian
weight = np.exp(-(ov * ov)/sigma)
else: # original NMS
if ov > Nt:
weight = 0
else:
weight = 1
boxes[pos, 4] = weight*boxes[pos, 4]
            # if the box score falls below the threshold, discard it by
            # swapping it with the last box and shrinking N
if boxes[pos, 4] < threshold:
boxes[pos,0] = boxes[N-1, 0]
boxes[pos,1] = boxes[N-1, 1]
boxes[pos,2] = boxes[N-1, 2]
boxes[pos,3] = boxes[N-1, 3]
boxes[pos,4] = boxes[N-1, 4]
inds[pos] = inds[N-1]
N = N - 1
pos = pos - 1
pos = pos + 1
return boxes[:N], inds[:N]
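# Worked example (illustrative): for an overlap ov = 0.6 with Nt = 0.3 and
# sigma = 0.5, the three decay rules above rescale the box score by:
#   linear (method=1):   weight = 1 - 0.6 = 0.4
#   gaussian (method=2): weight = exp(-0.36 / 0.5) ~= 0.49
#   hard NMS (default):  weight = 0, so the rescored box falls below
#                        `threshold` and is discarded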