Unverified Commit a8ec6fb3 authored by Kai Chen, committed by GitHub

Merge pull request #252 from hellock/anchor-head

Unify RPNHead and single stage heads with AnchorHead

parents d7d4a991 2df1e0a0
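This change set makes AnchorHead the single base class for every anchor-based head. A minimal sketch of the resulting hierarchy, assuming only the anchor_heads package introduced in the diffs below:

# Illustrative only; the real definitions appear in the diffs below.
from mmdet.models.anchor_heads import AnchorHead, RPNHead, RetinaHead, SSDHead

assert issubclass(RPNHead, AnchorHead)
assert issubclass(RetinaHead, AnchorHead)
assert issubclass(SSDHead, AnchorHead)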
@@ -12,7 +12,7 @@ def anchor_target(anchor_list,
                   target_stds,
                   cfg,
                   gt_labels_list=None,
-                  cls_out_channels=1,
+                  label_channels=1,
                   sampling=True,
                   unmap_outputs=True):
     """Compute regression and classification targets for anchors.
@@ -54,7 +54,7 @@ def anchor_target(anchor_list,
         target_means=target_means,
         target_stds=target_stds,
         cfg=cfg,
-        cls_out_channels=cls_out_channels,
+        label_channels=label_channels,
         sampling=sampling,
         unmap_outputs=unmap_outputs)
     # no valid anchors
@@ -95,7 +95,7 @@ def anchor_target_single(flat_anchors,
                          target_means,
                          target_stds,
                          cfg,
-                         cls_out_channels=1,
+                         label_channels=1,
                          sampling=True,
                          unmap_outputs=True):
     inside_flags = anchor_inside_flags(flat_anchors, valid_flags,
@@ -147,9 +147,9 @@ def anchor_target_single(flat_anchors,
         num_total_anchors = flat_anchors.size(0)
         labels = unmap(labels, num_total_anchors, inside_flags)
         label_weights = unmap(label_weights, num_total_anchors, inside_flags)
-        if cls_out_channels > 1:
-            labels, label_weights = expand_binary_labels(labels, label_weights,
-                                                         cls_out_channels)
+        if label_channels > 1:
+            labels, label_weights = expand_binary_labels(
+                labels, label_weights, label_channels)
         bbox_targets = unmap(bbox_targets, num_total_anchors, inside_flags)
         bbox_weights = unmap(bbox_weights, num_total_anchors, inside_flags)
@@ -157,14 +157,14 @@ def anchor_target_single(flat_anchors,
             neg_inds)


-def expand_binary_labels(labels, label_weights, cls_out_channels):
+def expand_binary_labels(labels, label_weights, label_channels):
     bin_labels = labels.new_full(
-        (labels.size(0), cls_out_channels), 0, dtype=torch.float32)
+        (labels.size(0), label_channels), 0, dtype=torch.float32)
     inds = torch.nonzero(labels >= 1).squeeze()
     if inds.numel() > 0:
         bin_labels[inds, labels[inds] - 1] = 1
     bin_label_weights = label_weights.view(-1, 1).expand(
-        label_weights.size(0), cls_out_channels)
+        label_weights.size(0), label_channels)
     return bin_labels, bin_label_weights
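To see what the renamed expand_binary_labels does, here is a standalone run of the same logic with illustrative values (a sketch, not part of the commit): integer labels in [1, C] become one-hot rows of width label_channels, and background rows (label 0) stay all zero.

import torch

labels = torch.tensor([0, 2, 1])
label_weights = torch.tensor([1.0, 1.0, 0.0])
label_channels = 3

bin_labels = labels.new_full(
    (labels.size(0), label_channels), 0, dtype=torch.float32)
inds = torch.nonzero(labels >= 1).squeeze()
if inds.numel() > 0:
    bin_labels[inds, labels[inds] - 1] = 1
bin_label_weights = label_weights.view(-1, 1).expand(
    label_weights.size(0), label_channels)
# bin_labels        -> [[0, 0, 0], [0, 1, 0], [1, 0, 0]]
# bin_label_weights -> each row repeats its anchor's weight label_channels times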
...
 from .detectors import (BaseDetector, TwoStageDetector, RPN, FastRCNN,
                         FasterRCNN, MaskRCNN)
-from .builder import (build_neck, build_rpn_head, build_roi_extractor,
+from .builder import (build_neck, build_anchor_head, build_roi_extractor,
                       build_bbox_head, build_mask_head, build_detector)

 __all__ = [
     'BaseDetector', 'TwoStageDetector', 'RPN', 'FastRCNN', 'FasterRCNN',
-    'MaskRCNN', 'build_backbone', 'build_neck', 'build_rpn_head',
+    'MaskRCNN', 'build_backbone', 'build_neck', 'build_anchor_head',
     'build_roi_extractor', 'build_bbox_head', 'build_mask_head',
     'build_detector'
 ]
+from .anchor_head import AnchorHead
+from .rpn_head import RPNHead
 from .retina_head import RetinaHead
 from .ssd_head import SSDHead

-__all__ = ['RetinaHead', 'SSDHead']
+__all__ = ['AnchorHead', 'RPNHead', 'RetinaHead', 'SSDHead']
@@ -4,113 +4,81 @@ import numpy as np
 import torch
 import torch.nn as nn

-from mmdet.core import (AnchorGenerator, anchor_target, multi_apply,
-                        delta2bbox, weighted_smoothl1,
+from mmdet.core import (AnchorGenerator, anchor_target, delta2bbox,
+                        multi_apply, weighted_cross_entropy, weighted_smoothl1,
+                        weighted_binary_cross_entropy,
                         weighted_sigmoid_focal_loss, multiclass_nms)
-from ..utils import normal_init, bias_init_with_prob
+from ..utils import normal_init


-class RetinaHead(nn.Module):
-    """Head of RetinaNet.
-
-              / cls_convs - retina_cls (3x3 conv)
-      input -
-              \ reg_convs - retina_reg (3x3 conv)
+class AnchorHead(nn.Module):
+    """Anchor-based head (RPN, RetinaNet, SSD, etc.).

     Args:
         in_channels (int): Number of channels in the input feature map.
-        num_classes (int): Class number (including background).
-        stacked_convs (int): Number of convolutional layers added for cls and
-            reg branch.
-        feat_channels (int): Number of channels for the RPN feature map.
-        scales_per_octave (int): Number of anchor scales per octave.
-        octave_base_scale (int): Base octave scale. Anchor scales are computed
-            as `s*2^(i/n)`, for i in [0, n-1], where s is `octave_base_scale`
-            and n is `scales_per_octave`.
+        feat_channels (int): Number of channels of the feature map.
+        anchor_scales (Iterable): Anchor scales.
         anchor_ratios (Iterable): Anchor aspect ratios.
         anchor_strides (Iterable): Anchor strides.
+        anchor_base_sizes (Iterable): Anchor base sizes.
         target_means (Iterable): Mean values of regression targets.
         target_stds (Iterable): Std values of regression targets.
+        use_sigmoid_cls (bool): Whether to use sigmoid loss for classification.
+            (softmax by default)
+        use_focal_loss (bool): Whether to use focal loss for classification.
     """  # noqa: W605

     def __init__(self,
-                 in_channels,
                  num_classes,
-                 stacked_convs=4,
+                 in_channels,
                  feat_channels=256,
-                 octave_base_scale=4,
-                 scales_per_octave=3,
+                 anchor_scales=[8, 16, 32],
                  anchor_ratios=[0.5, 1.0, 2.0],
-                 anchor_strides=[8, 16, 32, 64, 128],
+                 anchor_strides=[4, 8, 16, 32, 64],
                  anchor_base_sizes=None,
                  target_means=(.0, .0, .0, .0),
-                 target_stds=(1.0, 1.0, 1.0, 1.0)):
-        super(RetinaHead, self).__init__()
+                 target_stds=(1.0, 1.0, 1.0, 1.0),
+                 use_sigmoid_cls=False,
+                 use_focal_loss=False):
+        super(AnchorHead, self).__init__()
         self.in_channels = in_channels
         self.num_classes = num_classes
-        self.octave_base_scale = octave_base_scale
-        self.scales_per_octave = scales_per_octave
+        self.feat_channels = feat_channels
+        self.anchor_scales = anchor_scales
         self.anchor_ratios = anchor_ratios
         self.anchor_strides = anchor_strides
         self.anchor_base_sizes = list(
             anchor_strides) if anchor_base_sizes is None else anchor_base_sizes
         self.target_means = target_means
         self.target_stds = target_stds
+        self.use_sigmoid_cls = use_sigmoid_cls
+        self.use_focal_loss = use_focal_loss

         self.anchor_generators = []
         for anchor_base in self.anchor_base_sizes:
-            octave_scales = np.array(
-                [2**(i / scales_per_octave) for i in range(scales_per_octave)])
-            anchor_scales = octave_scales * octave_base_scale
             self.anchor_generators.append(
                 AnchorGenerator(anchor_base, anchor_scales, anchor_ratios))
-        self.relu = nn.ReLU(inplace=True)
-        self.num_anchors = int(
-            len(self.anchor_ratios) * self.scales_per_octave)
-        self.cls_out_channels = self.num_classes - 1
-        self.bbox_pred_dim = 4
-        self.stacked_convs = stacked_convs
-        self.cls_convs = nn.ModuleList()
-        self.reg_convs = nn.ModuleList()
-        for i in range(self.stacked_convs):
-            chn = in_channels if i == 0 else feat_channels
-            self.cls_convs.append(
-                nn.Conv2d(chn, feat_channels, 3, stride=1, padding=1))
-            self.reg_convs.append(
-                nn.Conv2d(chn, feat_channels, 3, stride=1, padding=1))
-        self.retina_cls = nn.Conv2d(
-            feat_channels,
-            self.num_anchors * self.cls_out_channels,
-            3,
-            stride=1,
-            padding=1)
-        self.retina_reg = nn.Conv2d(
-            feat_channels,
-            self.num_anchors * self.bbox_pred_dim,
-            3,
-            stride=1,
-            padding=1)
-        self.debug_imgs = None
+
+        self.num_anchors = len(self.anchor_ratios) * len(self.anchor_scales)
+        if self.use_sigmoid_cls:
+            self.cls_out_channels = self.num_classes - 1
+        else:
+            self.cls_out_channels = self.num_classes
+
+        self._init_layers()
+
+    def _init_layers(self):
+        self.conv_cls = nn.Conv2d(self.feat_channels,
+                                  self.num_anchors * self.cls_out_channels, 1)
+        self.conv_reg = nn.Conv2d(self.feat_channels, self.num_anchors * 4, 1)

     def init_weights(self):
-        for m in self.cls_convs:
-            normal_init(m, std=0.01)
-        for m in self.reg_convs:
-            normal_init(m, std=0.01)
-        bias_cls = bias_init_with_prob(0.01)
-        normal_init(self.retina_cls, std=0.01, bias=bias_cls)
-        normal_init(self.retina_reg, std=0.01)
+        normal_init(self.conv_cls, std=0.01)
+        normal_init(self.conv_reg, std=0.01)

     def forward_single(self, x):
-        cls_feat = x
-        reg_feat = x
-        for cls_conv in self.cls_convs:
-            cls_feat = self.relu(cls_conv(cls_feat))
-        for reg_conv in self.reg_convs:
-            reg_feat = self.relu(reg_conv(reg_feat))
-        cls_score = self.retina_cls(cls_feat)
-        bbox_pred = self.retina_reg(reg_feat)
+        cls_score = self.conv_cls(x)
+        bbox_pred = self.conv_reg(x)
         return cls_score, bbox_pred

     def forward(self, feats):
@@ -156,30 +124,47 @@ class RetinaHead(nn.Module):
         return anchor_list, valid_flag_list

     def loss_single(self, cls_score, bbox_pred, labels, label_weights,
-                    bbox_targets, bbox_weights, num_pos_samples, cfg):
+                    bbox_targets, bbox_weights, num_total_samples, cfg):
         # classification loss
-        labels = labels.contiguous().view(-1, self.cls_out_channels)
-        label_weights = label_weights.contiguous().view(
-            -1, self.cls_out_channels)
-        cls_score = cls_score.permute(0, 2, 3, 1).contiguous().view(
+        if self.use_sigmoid_cls:
+            labels = labels.reshape(-1, self.cls_out_channels)
+            label_weights = label_weights.reshape(-1, self.cls_out_channels)
+        else:
+            labels = labels.reshape(-1)
+            label_weights = label_weights.reshape(-1)
+        cls_score = cls_score.permute(0, 2, 3, 1).reshape(
             -1, self.cls_out_channels)
-        loss_cls = weighted_sigmoid_focal_loss(
-            cls_score,
-            labels,
-            label_weights,
-            cfg.gamma,
-            cfg.alpha,
-            avg_factor=num_pos_samples)
+        if self.use_sigmoid_cls:
+            if self.use_focal_loss:
+                cls_criterion = weighted_sigmoid_focal_loss
+            else:
+                cls_criterion = weighted_binary_cross_entropy
+        else:
+            if self.use_focal_loss:
+                raise NotImplementedError
+            else:
+                cls_criterion = weighted_cross_entropy
+        if self.use_focal_loss:
+            loss_cls = cls_criterion(
+                cls_score,
+                labels,
+                label_weights,
+                gamma=cfg.gamma,
+                alpha=cfg.alpha,
+                avg_factor=num_total_samples)
+        else:
+            loss_cls = cls_criterion(
+                cls_score, labels, label_weights, avg_factor=num_total_samples)
         # regression loss
-        bbox_targets = bbox_targets.contiguous().view(-1, 4)
-        bbox_weights = bbox_weights.contiguous().view(-1, 4)
-        bbox_pred = bbox_pred.permute(0, 2, 3, 1).contiguous().view(-1, 4)
+        bbox_targets = bbox_targets.reshape(-1, 4)
+        bbox_weights = bbox_weights.reshape(-1, 4)
+        bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4)
         loss_reg = weighted_smoothl1(
             bbox_pred,
             bbox_targets,
             bbox_weights,
             beta=cfg.smoothl1_beta,
-            avg_factor=num_pos_samples)
+            avg_factor=num_total_samples)
         return loss_cls, loss_reg

     def loss(self, cls_scores, bbox_preds, gt_bboxes, gt_labels, img_metas,
@@ -189,6 +174,8 @@ class RetinaHead(nn.Module):
         anchor_list, valid_flag_list = self.get_anchors(
             featmap_sizes, img_metas)
+        sampling = False if self.use_focal_loss else True
+        label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1
         cls_reg_targets = anchor_target(
             anchor_list,
             valid_flag_list,
@@ -198,13 +185,14 @@ class RetinaHead(nn.Module):
             self.target_stds,
             cfg,
             gt_labels_list=gt_labels,
-            cls_out_channels=self.cls_out_channels,
-            sampling=False)
+            label_channels=label_channels,
+            sampling=sampling)
         if cls_reg_targets is None:
             return None
         (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
          num_total_pos, num_total_neg) = cls_reg_targets
+        num_total_samples = (num_total_pos if self.use_focal_loss else
+                             num_total_pos + num_total_neg)
         losses_cls, losses_reg = multi_apply(
             self.loss_single,
             cls_scores,
@@ -213,16 +201,12 @@ class RetinaHead(nn.Module):
             label_weights_list,
             bbox_targets_list,
             bbox_weights_list,
-            num_pos_samples=num_total_pos,
+            num_total_samples=num_total_samples,
             cfg=cfg)
         return dict(loss_cls=losses_cls, loss_reg=losses_reg)

-    def get_det_bboxes(self,
-                       cls_scores,
-                       bbox_preds,
-                       img_metas,
-                       cfg,
-                       rescale=False):
+    def get_bboxes(self, cls_scores, bbox_preds, img_metas, cfg,
+                   rescale=False):
         assert len(cls_scores) == len(bbox_preds)
         num_levels = len(cls_scores)
@@ -231,7 +215,6 @@ class RetinaHead(nn.Module):
                 self.anchor_strides[i])
             for i in range(num_levels)
         ]
-
         result_list = []
         for img_id in range(len(img_metas)):
             cls_score_list = [
@@ -242,46 +225,54 @@ class RetinaHead(nn.Module):
             ]
             img_shape = img_metas[img_id]['img_shape']
             scale_factor = img_metas[img_id]['scale_factor']
-            results = self._get_det_bboxes_single(
-                cls_score_list, bbox_pred_list, mlvl_anchors, img_shape,
-                scale_factor, cfg, rescale)
-            result_list.append(results)
+            proposals = self.get_bboxes_single(cls_score_list, bbox_pred_list,
+                                               mlvl_anchors, img_shape,
+                                               scale_factor, cfg, rescale)
+            result_list.append(proposals)
         return result_list

-    def _get_det_bboxes_single(self,
-                               cls_scores,
-                               bbox_preds,
-                               mlvl_anchors,
-                               img_shape,
-                               scale_factor,
-                               cfg,
-                               rescale=False):
+    def get_bboxes_single(self,
+                          cls_scores,
+                          bbox_preds,
+                          mlvl_anchors,
+                          img_shape,
+                          scale_factor,
+                          cfg,
+                          rescale=False):
         assert len(cls_scores) == len(bbox_preds) == len(mlvl_anchors)
-        mlvl_proposals = []
+        mlvl_bboxes = []
         mlvl_scores = []
         for cls_score, bbox_pred, anchors in zip(cls_scores, bbox_preds,
                                                  mlvl_anchors):
             assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
-            cls_score = cls_score.permute(1, 2, 0).contiguous().view(
+            cls_score = cls_score.permute(1, 2, 0).reshape(
                 -1, self.cls_out_channels)
-            scores = cls_score.sigmoid()
-            bbox_pred = bbox_pred.permute(1, 2, 0).contiguous().view(-1, 4)
-            proposals = delta2bbox(anchors, bbox_pred, self.target_means,
-                                   self.target_stds, img_shape)
-            if cfg.nms_pre > 0 and scores.shape[0] > cfg.nms_pre:
-                maxscores, _ = scores.max(dim=1)
-                _, topk_inds = maxscores.topk(cfg.nms_pre)
-                proposals = proposals[topk_inds, :]
+            if self.use_sigmoid_cls:
+                scores = cls_score.sigmoid()
+            else:
+                scores = cls_score.softmax(-1)
+            bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4)
+            nms_pre = cfg.get('nms_pre', -1)
+            if nms_pre > 0 and scores.shape[0] > nms_pre:
+                if self.use_sigmoid_cls:
+                    max_scores, _ = scores.max(dim=1)
+                else:
+                    max_scores, _ = scores[:, 1:].max(dim=1)
+                _, topk_inds = max_scores.topk(nms_pre)
+                anchors = anchors[topk_inds, :]
+                bbox_pred = bbox_pred[topk_inds, :]
                 scores = scores[topk_inds, :]
-            mlvl_proposals.append(proposals)
+            bboxes = delta2bbox(anchors, bbox_pred, self.target_means,
+                                self.target_stds, img_shape)
+            mlvl_bboxes.append(bboxes)
             mlvl_scores.append(scores)
-        mlvl_proposals = torch.cat(mlvl_proposals)
+        mlvl_bboxes = torch.cat(mlvl_bboxes)
         if rescale:
-            mlvl_proposals /= scale_factor
+            mlvl_bboxes /= mlvl_bboxes.new_tensor(scale_factor)
         mlvl_scores = torch.cat(mlvl_scores)
-        padding = mlvl_scores.new_zeros(mlvl_scores.shape[0], 1)
-        mlvl_scores = torch.cat([padding, mlvl_scores], dim=1)
-        det_bboxes, det_labels = multiclass_nms(mlvl_proposals, mlvl_scores,
-                                                cfg.score_thr, cfg.nms,
-                                                cfg.max_per_img)
+        if self.use_sigmoid_cls:
+            padding = mlvl_scores.new_zeros(mlvl_scores.shape[0], 1)
+            mlvl_scores = torch.cat([padding, mlvl_scores], dim=1)
+        det_bboxes, det_labels = multiclass_nms(
+            mlvl_bboxes, mlvl_scores, cfg.score_thr, cfg.nms, cfg.max_per_img)
         return det_bboxes, det_labels
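A quick illustration of the cls_out_channels logic introduced above (num_classes counts the background class, per the docstring):

num_classes = 81  # e.g. COCO: 80 object classes + background
use_sigmoid_cls = True
cls_out_channels = num_classes - 1 if use_sigmoid_cls else num_classes
# sigmoid: 80 channels, background is implicit;
# softmax: 81 channels, index 0 is background.

This is also why get_bboxes_single pads a zero background column onto mlvl_scores before multiclass_nms when use_sigmoid_cls is set.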
import numpy as np
import torch.nn as nn
from mmcv.cnn import normal_init
from .anchor_head import AnchorHead
from ..utils import bias_init_with_prob
class RetinaHead(AnchorHead):
def __init__(self,
num_classes,
in_channels,
stacked_convs=4,
octave_base_scale=4,
scales_per_octave=3,
**kwargs):
self.stacked_convs = stacked_convs
self.octave_base_scale = octave_base_scale
self.scales_per_octave = scales_per_octave
octave_scales = np.array(
[2**(i / scales_per_octave) for i in range(scales_per_octave)])
anchor_scales = octave_scales * octave_base_scale
super(RetinaHead, self).__init__(
num_classes,
in_channels,
anchor_scales=anchor_scales,
use_sigmoid_cls=True,
use_focal_loss=True,
**kwargs)
def _init_layers(self):
self.relu = nn.ReLU(inplace=True)
self.cls_convs = nn.ModuleList()
self.reg_convs = nn.ModuleList()
for i in range(self.stacked_convs):
chn = self.in_channels if i == 0 else self.feat_channels
self.cls_convs.append(
nn.Conv2d(chn, self.feat_channels, 3, stride=1, padding=1))
self.reg_convs.append(
nn.Conv2d(chn, self.feat_channels, 3, stride=1, padding=1))
self.retina_cls = nn.Conv2d(
self.feat_channels,
self.num_anchors * self.cls_out_channels,
3,
padding=1)
self.retina_reg = nn.Conv2d(
self.feat_channels, self.num_anchors * 4, 3, padding=1)
def init_weights(self):
for m in self.cls_convs:
normal_init(m, std=0.01)
for m in self.reg_convs:
normal_init(m, std=0.01)
bias_cls = bias_init_with_prob(0.01)
normal_init(self.retina_cls, std=0.01, bias=bias_cls)
normal_init(self.retina_reg, std=0.01)
def forward_single(self, x):
cls_feat = x
reg_feat = x
for cls_conv in self.cls_convs:
cls_feat = self.relu(cls_conv(cls_feat))
for reg_conv in self.reg_convs:
reg_feat = self.relu(reg_conv(reg_feat))
cls_score = self.retina_cls(cls_feat)
bbox_pred = self.retina_reg(reg_feat)
return cls_score, bbox_pred
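RetinaHead now reduces its octave parameters to plain anchor_scales before delegating to AnchorHead. With the defaults above, the s * 2**(i / n) formula gives:

import numpy as np

octave_base_scale = 4   # s
scales_per_octave = 3   # n
octave_scales = np.array(
    [2**(i / scales_per_octave) for i in range(scales_per_octave)])
anchor_scales = octave_scales * octave_base_scale
print(anchor_scales)  # approximately [4.0, 5.04, 6.35]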
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import normal_init
from mmdet.core import delta2bbox
from mmdet.ops import nms
from .anchor_head import AnchorHead
class RPNHead(AnchorHead):
def __init__(self, in_channels, **kwargs):
super(RPNHead, self).__init__(2, in_channels, **kwargs)
def _init_layers(self):
self.rpn_conv = nn.Conv2d(
self.in_channels, self.feat_channels, 3, padding=1)
self.rpn_cls = nn.Conv2d(self.feat_channels,
self.num_anchors * self.cls_out_channels, 1)
self.rpn_reg = nn.Conv2d(self.feat_channels, self.num_anchors * 4, 1)
def init_weights(self):
normal_init(self.rpn_conv, std=0.01)
normal_init(self.rpn_cls, std=0.01)
normal_init(self.rpn_reg, std=0.01)
def forward_single(self, x):
x = self.rpn_conv(x)
x = F.relu(x, inplace=True)
rpn_cls_score = self.rpn_cls(x)
rpn_bbox_pred = self.rpn_reg(x)
return rpn_cls_score, rpn_bbox_pred
def loss(self, cls_scores, bbox_preds, gt_bboxes, img_metas, cfg):
losses = super(RPNHead, self).loss(cls_scores, bbox_preds, gt_bboxes,
None, img_metas, cfg)
return dict(
loss_rpn_cls=losses['loss_cls'], loss_rpn_reg=losses['loss_reg'])
def get_bboxes_single(self,
cls_scores,
bbox_preds,
mlvl_anchors,
img_shape,
scale_factor,
cfg,
rescale=False):
mlvl_proposals = []
for idx in range(len(cls_scores)):
rpn_cls_score = cls_scores[idx]
rpn_bbox_pred = bbox_preds[idx]
assert rpn_cls_score.size()[-2:] == rpn_bbox_pred.size()[-2:]
anchors = mlvl_anchors[idx]
rpn_cls_score = rpn_cls_score.permute(1, 2, 0)
if self.use_sigmoid_cls:
rpn_cls_score = rpn_cls_score.reshape(-1)
scores = rpn_cls_score.sigmoid()
else:
rpn_cls_score = rpn_cls_score.reshape(-1, 2)
scores = rpn_cls_score.softmax(dim=1)[:, 1]
rpn_bbox_pred = rpn_bbox_pred.permute(1, 2, 0).reshape(-1, 4)
if cfg.nms_pre > 0 and scores.shape[0] > cfg.nms_pre:
_, topk_inds = scores.topk(cfg.nms_pre)
rpn_bbox_pred = rpn_bbox_pred[topk_inds, :]
anchors = anchors[topk_inds, :]
scores = scores[topk_inds]
proposals = delta2bbox(anchors, rpn_bbox_pred, self.target_means,
self.target_stds, img_shape)
if cfg.min_bbox_size > 0:
w = proposals[:, 2] - proposals[:, 0] + 1
h = proposals[:, 3] - proposals[:, 1] + 1
valid_inds = torch.nonzero((w >= cfg.min_bbox_size) &
(h >= cfg.min_bbox_size)).squeeze()
proposals = proposals[valid_inds, :]
scores = scores[valid_inds]
proposals = torch.cat([proposals, scores.unsqueeze(-1)], dim=-1)
proposals, _ = nms(proposals, cfg.nms_thr)
proposals = proposals[:cfg.nms_post, :]
mlvl_proposals.append(proposals)
proposals = torch.cat(mlvl_proposals, 0)
if cfg.nms_across_levels:
proposals, _ = nms(proposals, cfg.nms_thr)
proposals = proposals[:cfg.max_num, :]
else:
scores = proposals[:, 4]
num = min(cfg.max_num, proposals.shape[0])
_, topk_inds = scores.topk(num)
proposals = proposals[topk_inds, :]
return proposals
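For reference, RPNHead.get_bboxes_single above reads nms_pre, nms_post, max_num, nms_thr, min_bbox_size, and nms_across_levels from the test config. A hedged example with illustrative values, assuming mmcv's Config wrapper for attribute-style access:

from mmcv import Config

# Values are illustrative, not taken from this commit.
rpn_test_cfg = Config(
    dict(
        nms_across_levels=False,
        nms_pre=2000,
        nms_post=2000,
        max_num=2000,
        nms_thr=0.7,
        min_bbox_size=0))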
-from __future__ import division
 import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from mmcv.cnn import xavier_init

-from mmdet.core import (AnchorGenerator, anchor_target, multi_apply,
-                        delta2bbox, weighted_smoothl1, multiclass_nms)
+from mmdet.core import (AnchorGenerator, anchor_target, weighted_smoothl1,
+                        multi_apply)
+from .anchor_head import AnchorHead


-class SSDHead(nn.Module):
+class SSDHead(AnchorHead):

     def __init__(self,
                  input_size=300,
-                 in_channels=(512, 1024, 512, 256, 256, 256),
                  num_classes=81,
+                 in_channels=(512, 1024, 512, 256, 256, 256),
                  anchor_strides=(8, 16, 32, 64, 100, 300),
                  basesize_ratio_range=(0.1, 0.9),
                  anchor_ratios=([2], [2, 3], [2, 3], [2, 3], [2], [2]),
                  target_means=(.0, .0, .0, .0),
                  target_stds=(1.0, 1.0, 1.0, 1.0)):
-        super(SSDHead, self).__init__()
-        # construct head
-        num_anchors = [len(ratios) * 2 + 2 for ratios in anchor_ratios]
-        self.in_channels = in_channels
+        super(AnchorHead, self).__init__()
+        self.input_size = input_size
         self.num_classes = num_classes
+        self.in_channels = in_channels
         self.cls_out_channels = num_classes
+        num_anchors = [len(ratios) * 2 + 2 for ratios in anchor_ratios]
         reg_convs = []
         cls_convs = []
         for i in range(len(in_channels)):
@@ -88,6 +87,8 @@ class SSDHead(nn.Module):
         self.target_means = target_means
         self.target_stds = target_stds
+        self.use_sigmoid_cls = False
+        self.use_focal_loss = False

     def init_weights(self):
         for m in self.modules():
@@ -103,68 +104,28 @@ class SSDHead(nn.Module):
                 bbox_preds.append(reg_conv(feat))
         return cls_scores, bbox_preds

-    def get_anchors(self, featmap_sizes, img_metas):
-        """Get anchors according to feature map sizes.
-
-        Args:
-            featmap_sizes (list[tuple]): Multi-level feature map sizes.
-            img_metas (list[dict]): Image meta info.
-
-        Returns:
-            tuple: anchors of each image, valid flags of each image
-        """
-        num_imgs = len(img_metas)
-        num_levels = len(featmap_sizes)
-        # since feature map sizes of all images are the same, we only compute
-        # anchors for one time
-        multi_level_anchors = []
-        for i in range(num_levels):
-            anchors = self.anchor_generators[i].grid_anchors(
-                featmap_sizes[i], self.anchor_strides[i])
-            multi_level_anchors.append(anchors)
-        anchor_list = [multi_level_anchors for _ in range(num_imgs)]
-        # for each image, we compute valid flags of multi level anchors
-        valid_flag_list = []
-        for img_id, img_meta in enumerate(img_metas):
-            multi_level_flags = []
-            for i in range(num_levels):
-                anchor_stride = self.anchor_strides[i]
-                feat_h, feat_w = featmap_sizes[i]
-                h, w, _ = img_meta['pad_shape']
-                valid_feat_h = min(int(np.ceil(h / anchor_stride)), feat_h)
-                valid_feat_w = min(int(np.ceil(w / anchor_stride)), feat_w)
-                flags = self.anchor_generators[i].valid_flags(
-                    (feat_h, feat_w), (valid_feat_h, valid_feat_w))
-                multi_level_flags.append(flags)
-            valid_flag_list.append(multi_level_flags)
-        return anchor_list, valid_flag_list
-
     def loss_single(self, cls_score, bbox_pred, labels, label_weights,
-                    bbox_targets, bbox_weights, num_pos_samples, cfg):
+                    bbox_targets, bbox_weights, num_total_samples, cfg):
         loss_cls_all = F.cross_entropy(
             cls_score, labels, reduction='none') * label_weights
-        pos_label_inds = (labels > 0).nonzero().view(-1)
-        neg_label_inds = (labels == 0).nonzero().view(-1)
-        num_sample_pos = pos_label_inds.size(0)
-        num_sample_neg = cfg.neg_pos_ratio * num_sample_pos
-        if num_sample_neg > neg_label_inds.size(0):
-            num_sample_neg = neg_label_inds.size(0)
-        topk_loss_cls_neg, topk_loss_cls_neg_inds = \
-            loss_cls_all[neg_label_inds].topk(num_sample_neg)
-        loss_cls_pos = loss_cls_all[pos_label_inds].sum()
+        pos_inds = (labels > 0).nonzero().view(-1)
+        neg_inds = (labels == 0).nonzero().view(-1)
+        num_pos_samples = pos_inds.size(0)
+        num_neg_samples = cfg.neg_pos_ratio * num_pos_samples
+        if num_neg_samples > neg_inds.size(0):
+            num_neg_samples = neg_inds.size(0)
+        topk_loss_cls_neg, _ = loss_cls_all[neg_inds].topk(num_neg_samples)
+        loss_cls_pos = loss_cls_all[pos_inds].sum()
         loss_cls_neg = topk_loss_cls_neg.sum()
-        loss_cls = (loss_cls_pos + loss_cls_neg) / num_pos_samples
+        loss_cls = (loss_cls_pos + loss_cls_neg) / num_total_samples
         loss_reg = weighted_smoothl1(
             bbox_pred,
             bbox_targets,
             bbox_weights,
             beta=cfg.smoothl1_beta,
-            avg_factor=num_pos_samples)
+            avg_factor=num_total_samples)
         return loss_cls[None], loss_reg

     def loss(self, cls_scores, bbox_preds, gt_bboxes, gt_labels, img_metas,
@@ -193,14 +154,14 @@ class SSDHead(nn.Module):
         num_images = len(img_metas)
         all_cls_scores = torch.cat([
-            s.permute(0, 2, 3, 1).contiguous().view(
+            s.permute(0, 2, 3, 1).reshape(
                 num_images, -1, self.cls_out_channels) for s in cls_scores
         ], 1)
         all_labels = torch.cat(labels_list, -1).view(num_images, -1)
         all_label_weights = torch.cat(label_weights_list, -1).view(
             num_images, -1)
         all_bbox_preds = torch.cat([
-            b.permute(0, 2, 3, 1).contiguous().view(num_images, -1, 4)
+            b.permute(0, 2, 3, 1).reshape(num_images, -1, 4)
             for b in bbox_preds
         ], -2)
         all_bbox_targets = torch.cat(bbox_targets_list, -2).view(
@@ -216,68 +177,6 @@ class SSDHead(nn.Module):
             all_label_weights,
             all_bbox_targets,
             all_bbox_weights,
-            num_pos_samples=num_total_pos,
+            num_total_samples=num_total_pos,
             cfg=cfg)
         return dict(loss_cls=losses_cls, loss_reg=losses_reg)
-
-    def get_det_bboxes(self,
-                       cls_scores,
-                       bbox_preds,
-                       img_metas,
-                       cfg,
-                       rescale=False):
-        assert len(cls_scores) == len(bbox_preds)
-        num_levels = len(cls_scores)
-
-        mlvl_anchors = [
-            self.anchor_generators[i].grid_anchors(cls_scores[i].size()[-2:],
-                                                   self.anchor_strides[i])
-            for i in range(num_levels)
-        ]
-        result_list = []
-        for img_id in range(len(img_metas)):
-            cls_score_list = [
-                cls_scores[i][img_id].detach() for i in range(num_levels)
-            ]
-            bbox_pred_list = [
-                bbox_preds[i][img_id].detach() for i in range(num_levels)
-            ]
-            img_shape = img_metas[img_id]['img_shape']
-            scale_factor = img_metas[img_id]['scale_factor']
-            results = self._get_det_bboxes_single(
-                cls_score_list, bbox_pred_list, mlvl_anchors, img_shape,
-                scale_factor, cfg, rescale)
-            result_list.append(results)
-        return result_list
-
-    def _get_det_bboxes_single(self,
-                               cls_scores,
-                               bbox_preds,
-                               mlvl_anchors,
-                               img_shape,
-                               scale_factor,
-                               cfg,
-                               rescale=False):
-        assert len(cls_scores) == len(bbox_preds) == len(mlvl_anchors)
-        mlvl_proposals = []
-        mlvl_scores = []
-        for cls_score, bbox_pred, anchors in zip(cls_scores, bbox_preds,
-                                                 mlvl_anchors):
-            assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
-            cls_score = cls_score.permute(1, 2, 0).contiguous().view(
-                -1, self.cls_out_channels)
-            scores = cls_score.softmax(-1)
-            bbox_pred = bbox_pred.permute(1, 2, 0).contiguous().view(-1, 4)
-            proposals = delta2bbox(anchors, bbox_pred, self.target_means,
-                                   self.target_stds, img_shape)
-            mlvl_proposals.append(proposals)
-            mlvl_scores.append(scores)
-        mlvl_proposals = torch.cat(mlvl_proposals)
-        if rescale:
-            mlvl_proposals /= mlvl_proposals.new_tensor(scale_factor)
-        mlvl_scores = torch.cat(mlvl_scores)
-        det_bboxes, det_labels = multiclass_nms(mlvl_proposals, mlvl_scores,
-                                                cfg.score_thr, cfg.nms,
-                                                cfg.max_per_img)
-        return det_bboxes, det_labels
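The hard negative mining in SSDHead.loss_single keeps at most neg_pos_ratio negatives per positive, picking the highest-loss ones. A standalone illustration of that selection with made-up numbers (a sketch, not part of the commit):

import torch

neg_pos_ratio = 3
loss_cls_all = torch.tensor([0.2, 1.5, 0.1, 0.9, 0.4])
labels = torch.tensor([1, 0, 0, 0, 0])  # one positive, four negatives
pos_inds = (labels > 0).nonzero().view(-1)
neg_inds = (labels == 0).nonzero().view(-1)
num_neg_samples = min(neg_pos_ratio * pos_inds.size(0), neg_inds.size(0))
topk_loss_cls_neg, _ = loss_cls_all[neg_inds].topk(num_neg_samples)
# -> the three hardest negatives: 1.5, 0.9, 0.4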
 from mmcv.runner import obj_from_dict
 from torch import nn

-from . import (backbones, necks, roi_extractors, rpn_heads, bbox_heads,
-               mask_heads, single_stage_heads)
-
-__all__ = [
-    'build_backbone', 'build_neck', 'build_rpn_head', 'build_roi_extractor',
-    'build_bbox_head', 'build_mask_head', 'build_single_stage_head',
-    'build_detector'
-]
+from . import (backbones, necks, roi_extractors, anchor_heads, bbox_heads,
+               mask_heads)


 def _build_module(cfg, parrent=None, default_args=None):
@@ -32,8 +26,8 @@ def build_neck(cfg):
     return build(cfg, necks)


-def build_rpn_head(cfg):
-    return build(cfg, rpn_heads)
+def build_anchor_head(cfg):
+    return build(cfg, anchor_heads)


 def build_roi_extractor(cfg):
@@ -48,10 +42,6 @@ def build_mask_head(cfg):
     return build(cfg, mask_heads)


-def build_single_stage_head(cfg):
-    return build(cfg, single_stage_heads)
-
-
 def build_detector(cfg, train_cfg=None, test_cfg=None):
     from . import detectors
     return build(cfg, detectors, dict(train_cfg=train_cfg, test_cfg=test_cfg))
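A hedged usage sketch of the renamed builder entry point (config values are illustrative): the builder resolves the type name against the anchor_heads module and instantiates it.

from mmdet.models import build_anchor_head

rpn_head = build_anchor_head(
    dict(
        type='RPNHead',
        in_channels=256,
        feat_channels=256,
        anchor_scales=[8],
        anchor_ratios=[0.5, 1.0, 2.0],
        anchor_strides=[4, 8, 16, 32, 64]))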
@@ -37,7 +37,7 @@ class CascadeRCNN(BaseDetector, RPNTestMixin):
             raise NotImplementedError

         if rpn_head is not None:
-            self.rpn_head = builder.build_rpn_head(rpn_head)
+            self.rpn_head = builder.build_anchor_head(rpn_head)

         if bbox_head is not None:
             self.bbox_roi_extractor = nn.ModuleList()
@@ -123,7 +123,7 @@ class CascadeRCNN(BaseDetector, RPNTestMixin):
             losses.update(rpn_losses)

             proposal_inputs = rpn_outs + (img_meta, self.test_cfg.rpn)
-            proposal_list = self.rpn_head.get_proposals(*proposal_inputs)
+            proposal_list = self.rpn_head.get_bboxes(*proposal_inputs)
         else:
             proposal_list = proposals
...
@@ -18,7 +18,7 @@ class RPN(BaseDetector, RPNTestMixin):
         super(RPN, self).__init__()
         self.backbone = builder.build_backbone(backbone)
         self.neck = builder.build_neck(neck) if neck is not None else None
-        self.rpn_head = builder.build_rpn_head(rpn_head)
+        self.rpn_head = builder.build_anchor_head(rpn_head)
         self.train_cfg = train_cfg
         self.test_cfg = test_cfg
         self.init_weights(pretrained=pretrained)
...
@@ -18,7 +18,7 @@ class SingleStageDetector(BaseDetector):
         self.backbone = builder.build_backbone(backbone)
         if neck is not None:
             self.neck = builder.build_neck(neck)
-        self.bbox_head = builder.build_single_stage_head(bbox_head)
+        self.bbox_head = builder.build_anchor_head(bbox_head)
         self.train_cfg = train_cfg
         self.test_cfg = test_cfg
         self.init_weights(pretrained=pretrained)
@@ -51,7 +51,7 @@ class SingleStageDetector(BaseDetector):
         x = self.extract_feat(img)
         outs = self.bbox_head(x)
         bbox_inputs = outs + (img_meta, self.test_cfg, rescale)
-        bbox_list = self.bbox_head.get_det_bboxes(*bbox_inputs)
+        bbox_list = self.bbox_head.get_bboxes(*bbox_inputs)
         bbox_results = [
             bbox2result(det_bboxes, det_labels, self.bbox_head.num_classes)
             for det_bboxes, det_labels in bbox_list
...
@@ -7,7 +7,7 @@ class RPNTestMixin(object):

     def simple_test_rpn(self, x, img_meta, rpn_test_cfg):
         rpn_outs = self.rpn_head(x)
         proposal_inputs = rpn_outs + (img_meta, rpn_test_cfg)
-        proposal_list = self.rpn_head.get_proposals(*proposal_inputs)
+        proposal_list = self.rpn_head.get_bboxes(*proposal_inputs)
         return proposal_list

     def aug_test_rpn(self, feats, img_metas, rpn_test_cfg):
...
@@ -30,7 +30,7 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
             raise NotImplementedError

         if rpn_head is not None:
-            self.rpn_head = builder.build_rpn_head(rpn_head)
+            self.rpn_head = builder.build_anchor_head(rpn_head)

         if bbox_head is not None:
             self.bbox_roi_extractor = builder.build_roi_extractor(
@@ -96,7 +96,7 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
             losses.update(rpn_losses)

             proposal_inputs = rpn_outs + (img_meta, self.test_cfg.rpn)
-            proposal_list = self.rpn_head.get_proposals(*proposal_inputs)
+            proposal_list = self.rpn_head.get_bboxes(*proposal_inputs)
         else:
             proposal_list = proposals
...
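After this rename, all four detectors call the heads through the same API. A condensed sketch of the shared pattern, with names as in the diffs above (self stands for a detector instance):

rpn_outs = self.rpn_head(x)  # (cls_scores, bbox_preds)
proposal_inputs = rpn_outs + (img_meta, self.test_cfg.rpn)
proposal_list = self.rpn_head.get_bboxes(*proposal_inputs)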
from .rpn_head import RPNHead
__all__ = ['RPNHead']
from __future__ import division
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmdet.core import (AnchorGenerator, anchor_target, delta2bbox,
multi_apply, weighted_cross_entropy, weighted_smoothl1,
weighted_binary_cross_entropy)
from mmdet.ops import nms
from ..utils import normal_init
class RPNHead(nn.Module):
"""Network head of RPN.
/ - rpn_cls (1x1 conv)
input - rpn_conv (3x3 conv) -
\ - rpn_reg (1x1 conv)
Args:
in_channels (int): Number of channels in the input feature map.
feat_channels (int): Number of channels for the RPN feature map.
anchor_scales (Iterable): Anchor scales.
anchor_ratios (Iterable): Anchor aspect ratios.
anchor_strides (Iterable): Anchor strides.
anchor_base_sizes (Iterable): Anchor base sizes.
target_means (Iterable): Mean values of regression targets.
target_stds (Iterable): Std values of regression targets.
use_sigmoid_cls (bool): Whether to use sigmoid loss for classification.
(softmax by default)
""" # noqa: W605
def __init__(self,
in_channels,
feat_channels=256,
anchor_scales=[8, 16, 32],
anchor_ratios=[0.5, 1.0, 2.0],
anchor_strides=[4, 8, 16, 32, 64],
anchor_base_sizes=None,
target_means=(.0, .0, .0, .0),
target_stds=(1.0, 1.0, 1.0, 1.0),
use_sigmoid_cls=False):
super(RPNHead, self).__init__()
self.in_channels = in_channels
self.feat_channels = feat_channels
self.anchor_scales = anchor_scales
self.anchor_ratios = anchor_ratios
self.anchor_strides = anchor_strides
self.anchor_base_sizes = list(
anchor_strides) if anchor_base_sizes is None else anchor_base_sizes
self.target_means = target_means
self.target_stds = target_stds
self.use_sigmoid_cls = use_sigmoid_cls
self.anchor_generators = []
for anchor_base in self.anchor_base_sizes:
self.anchor_generators.append(
AnchorGenerator(anchor_base, anchor_scales, anchor_ratios))
self.rpn_conv = nn.Conv2d(in_channels, feat_channels, 3, padding=1)
self.relu = nn.ReLU(inplace=True)
self.num_anchors = len(self.anchor_ratios) * len(self.anchor_scales)
out_channels = (self.num_anchors
if self.use_sigmoid_cls else self.num_anchors * 2)
self.rpn_cls = nn.Conv2d(feat_channels, out_channels, 1)
self.rpn_reg = nn.Conv2d(feat_channels, self.num_anchors * 4, 1)
self.debug_imgs = None
def init_weights(self):
normal_init(self.rpn_conv, std=0.01)
normal_init(self.rpn_cls, std=0.01)
normal_init(self.rpn_reg, std=0.01)
def forward_single(self, x):
rpn_feat = self.relu(self.rpn_conv(x))
rpn_cls_score = self.rpn_cls(rpn_feat)
rpn_bbox_pred = self.rpn_reg(rpn_feat)
return rpn_cls_score, rpn_bbox_pred
def forward(self, feats):
return multi_apply(self.forward_single, feats)
def get_anchors(self, featmap_sizes, img_metas):
"""Get anchors according to feature map sizes.
Args:
featmap_sizes (list[tuple]): Multi-level feature map sizes.
img_metas (list[dict]): Image meta info.
Returns:
tuple: anchors of each image, valid flags of each image
"""
num_imgs = len(img_metas)
num_levels = len(featmap_sizes)
# since feature map sizes of all images are the same, we only compute
# anchors for one time
multi_level_anchors = []
for i in range(num_levels):
anchors = self.anchor_generators[i].grid_anchors(
featmap_sizes[i], self.anchor_strides[i])
multi_level_anchors.append(anchors)
anchor_list = [multi_level_anchors for _ in range(num_imgs)]
# for each image, we compute valid flags of multi level anchors
valid_flag_list = []
for img_id, img_meta in enumerate(img_metas):
multi_level_flags = []
for i in range(num_levels):
anchor_stride = self.anchor_strides[i]
feat_h, feat_w = featmap_sizes[i]
h, w, _ = img_meta['pad_shape']
valid_feat_h = min(int(np.ceil(h / anchor_stride)), feat_h)
valid_feat_w = min(int(np.ceil(w / anchor_stride)), feat_w)
flags = self.anchor_generators[i].valid_flags(
(feat_h, feat_w), (valid_feat_h, valid_feat_w))
multi_level_flags.append(flags)
valid_flag_list.append(multi_level_flags)
return anchor_list, valid_flag_list
def loss_single(self, rpn_cls_score, rpn_bbox_pred, labels, label_weights,
bbox_targets, bbox_weights, num_total_samples, cfg):
# classification loss
labels = labels.contiguous().view(-1)
label_weights = label_weights.contiguous().view(-1)
if self.use_sigmoid_cls:
rpn_cls_score = rpn_cls_score.permute(0, 2, 3,
1).contiguous().view(-1)
criterion = weighted_binary_cross_entropy
else:
rpn_cls_score = rpn_cls_score.permute(0, 2, 3,
1).contiguous().view(-1, 2)
criterion = weighted_cross_entropy
loss_cls = criterion(
rpn_cls_score, labels, label_weights, avg_factor=num_total_samples)
# regression loss
bbox_targets = bbox_targets.contiguous().view(-1, 4)
bbox_weights = bbox_weights.contiguous().view(-1, 4)
rpn_bbox_pred = rpn_bbox_pred.permute(0, 2, 3, 1).contiguous().view(
-1, 4)
loss_reg = weighted_smoothl1(
rpn_bbox_pred,
bbox_targets,
bbox_weights,
beta=cfg.smoothl1_beta,
avg_factor=num_total_samples)
return loss_cls, loss_reg
def loss(self, rpn_cls_scores, rpn_bbox_preds, gt_bboxes, img_shapes, cfg):
featmap_sizes = [featmap.size()[-2:] for featmap in rpn_cls_scores]
assert len(featmap_sizes) == len(self.anchor_generators)
anchor_list, valid_flag_list = self.get_anchors(
featmap_sizes, img_shapes)
cls_reg_targets = anchor_target(
anchor_list, valid_flag_list, gt_bboxes, img_shapes,
self.target_means, self.target_stds, cfg)
if cls_reg_targets is None:
return None
(labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
num_total_pos, num_total_neg) = cls_reg_targets
losses_cls, losses_reg = multi_apply(
self.loss_single,
rpn_cls_scores,
rpn_bbox_preds,
labels_list,
label_weights_list,
bbox_targets_list,
bbox_weights_list,
num_total_samples=num_total_pos + num_total_neg,
cfg=cfg)
return dict(loss_rpn_cls=losses_cls, loss_rpn_reg=losses_reg)
def get_proposals(self, rpn_cls_scores, rpn_bbox_preds, img_meta, cfg):
num_imgs = len(img_meta)
featmap_sizes = [featmap.size()[-2:] for featmap in rpn_cls_scores]
mlvl_anchors = [
self.anchor_generators[idx].grid_anchors(featmap_sizes[idx],
self.anchor_strides[idx])
for idx in range(len(featmap_sizes))
]
proposal_list = []
for img_id in range(num_imgs):
rpn_cls_score_list = [
rpn_cls_scores[idx][img_id].detach()
for idx in range(len(rpn_cls_scores))
]
rpn_bbox_pred_list = [
rpn_bbox_preds[idx][img_id].detach()
for idx in range(len(rpn_bbox_preds))
]
assert len(rpn_cls_score_list) == len(rpn_bbox_pred_list)
proposals = self._get_proposals_single(
rpn_cls_score_list, rpn_bbox_pred_list, mlvl_anchors,
img_meta[img_id]['img_shape'], cfg)
proposal_list.append(proposals)
return proposal_list
def _get_proposals_single(self, rpn_cls_scores, rpn_bbox_preds,
mlvl_anchors, img_shape, cfg):
mlvl_proposals = []
for idx in range(len(rpn_cls_scores)):
rpn_cls_score = rpn_cls_scores[idx]
rpn_bbox_pred = rpn_bbox_preds[idx]
assert rpn_cls_score.size()[-2:] == rpn_bbox_pred.size()[-2:]
anchors = mlvl_anchors[idx]
if self.use_sigmoid_cls:
rpn_cls_score = rpn_cls_score.permute(1, 2,
0).contiguous().view(-1)
rpn_cls_prob = rpn_cls_score.sigmoid()
scores = rpn_cls_prob
else:
rpn_cls_score = rpn_cls_score.permute(1, 2,
0).contiguous().view(
-1, 2)
rpn_cls_prob = F.softmax(rpn_cls_score, dim=1)
scores = rpn_cls_prob[:, 1]
rpn_bbox_pred = rpn_bbox_pred.permute(1, 2, 0).contiguous().view(
-1, 4)
_, order = scores.sort(0, descending=True)
if cfg.nms_pre > 0:
order = order[:cfg.nms_pre]
rpn_bbox_pred = rpn_bbox_pred[order, :]
anchors = anchors[order, :]
scores = scores[order]
proposals = delta2bbox(anchors, rpn_bbox_pred, self.target_means,
self.target_stds, img_shape)
w = proposals[:, 2] - proposals[:, 0] + 1
h = proposals[:, 3] - proposals[:, 1] + 1
valid_inds = torch.nonzero((w >= cfg.min_bbox_size) &
(h >= cfg.min_bbox_size)).squeeze()
proposals = proposals[valid_inds, :]
scores = scores[valid_inds]
proposals = torch.cat([proposals, scores.unsqueeze(-1)], dim=-1)
proposals, _ = nms(proposals, cfg.nms_thr)
proposals = proposals[:cfg.nms_post, :]
mlvl_proposals.append(proposals)
proposals = torch.cat(mlvl_proposals, 0)
if cfg.nms_across_levels:
proposals, _ = nms(proposals, cfg.nms_thr)
proposals = proposals[:cfg.max_num, :]
else:
scores = proposals[:, 4]
_, order = scores.sort(0, descending=True)
num = min(cfg.max_num, proposals.shape[0])
order = order[:num]
proposals = proposals[order, :]
return proposals