Commit 70700512 authored by Kai Chen

use AnchorHead to unify rpn head and single stage heads

parent 1b9f9b88
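In short, the RPN head and the single-stage heads now share one base class. A minimal usage sketch of the unified interface (constructor signatures are taken from the diff below; the channel and class counts are illustrative):

# Both head families subclass AnchorHead after this commit; the import path
# follows the new anchor_heads package introduced below.
from mmdet.models.anchor_heads import AnchorHead, RPNHead, RetinaHead

rpn_head = RPNHead(in_channels=256)  # RPN is a 2-class (object vs. not) head
retina_head = RetinaHead(num_classes=81, in_channels=256)  # focal-loss head
assert isinstance(rpn_head, AnchorHead)
assert isinstance(retina_head, AnchorHead)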
models/__init__.py:

 from .detectors import (BaseDetector, TwoStageDetector, RPN, FastRCNN,
                         FasterRCNN, MaskRCNN)
-from .builder import (build_neck, build_rpn_head, build_roi_extractor,
+from .builder import (build_neck, build_anchor_head, build_roi_extractor,
                       build_bbox_head, build_mask_head, build_detector)

 __all__ = [
     'BaseDetector', 'TwoStageDetector', 'RPN', 'FastRCNN', 'FasterRCNN',
-    'MaskRCNN', 'build_backbone', 'build_neck', 'build_rpn_head',
+    'MaskRCNN', 'build_backbone', 'build_neck', 'build_anchor_head',
     'build_roi_extractor', 'build_bbox_head', 'build_mask_head',
     'build_detector'
 ]
anchor_heads/__init__.py:

+from .anchor_head import AnchorHead
+from .rpn_head import RPNHead
 from .retina_head import RetinaHead
 from .ssd_head import SSDHead

-__all__ = ['RetinaHead', 'SSDHead']
+__all__ = ['AnchorHead', 'RPNHead', 'RetinaHead', 'SSDHead']
anchor_heads/anchor_head.py (generalized from the former RetinaHead implementation):

@@ -4,114 +4,85 @@
 import numpy as np
 import torch
 import torch.nn as nn

-from mmdet.core import (AnchorGenerator, anchor_target, multi_apply,
-                        delta2bbox, weighted_smoothl1,
-                        weighted_sigmoid_focal_loss, multiclass_nms)
-from ..utils import normal_init, bias_init_with_prob
+from mmdet.core import (AnchorGenerator, anchor_target, delta2bbox,
+                        multi_apply, weighted_cross_entropy, weighted_smoothl1,
+                        weighted_binary_cross_entropy,
+                        weighted_sigmoid_focal_loss, multiclass_nms)
+from ..utils import normal_init


-class RetinaHead(nn.Module):
-    """Head of RetinaNet.
+class AnchorHead(nn.Module):
+    """Anchor-based head (RPN, RetinaNet, SSD, etc.).

-              / cls_convs - retina_cls (3x3 conv)
-    input -
-              \ reg_convs - retina_reg (3x3 conv)
+                                  / - conv_cls (1x1 conv)
+    input - rpn_conv (3x3 conv) -
+                                  \ - conv_reg (1x1 conv)

     Args:
         in_channels (int): Number of channels in the input feature map.
-        num_classes (int): Class number (including background).
-        stacked_convs (int): Number of convolutional layers added for cls and
-            reg branch.
         feat_channels (int): Number of channels for the RPN feature map.
-        scales_per_octave (int): Number of anchor scales per octave.
-        octave_base_scale (int): Base octave scale. Anchor scales are computed
-            as `s*2^(i/n)`, for i in [0, n-1], where s is `octave_base_scale`
-            and n is `scales_per_octave`.
+        anchor_scales (Iterable): Anchor scales.
         anchor_ratios (Iterable): Anchor aspect ratios.
         anchor_strides (Iterable): Anchor strides.
+        anchor_base_sizes (Iterable): Anchor base sizes.
         target_means (Iterable): Mean values of regression targets.
         target_stds (Iterable): Std values of regression targets.
+        use_sigmoid_cls (bool): Whether to use sigmoid loss for classification.
+            (softmax by default)
     """ # noqa: W605

     def __init__(self,
-                 in_channels,
                  num_classes,
-                 stacked_convs=4,
+                 in_channels,
                  feat_channels=256,
-                 octave_base_scale=4,
-                 scales_per_octave=3,
+                 anchor_scales=[8, 16, 32],
                  anchor_ratios=[0.5, 1.0, 2.0],
-                 anchor_strides=[8, 16, 32, 64, 128],
+                 anchor_strides=[4, 8, 16, 32, 64],
                  anchor_base_sizes=None,
                  target_means=(.0, .0, .0, .0),
-                 target_stds=(1.0, 1.0, 1.0, 1.0)):
-        super(RetinaHead, self).__init__()
+                 target_stds=(1.0, 1.0, 1.0, 1.0),
+                 use_sigmoid_cls=False,
+                 use_focal_loss=False):
+        super(AnchorHead, self).__init__()
         self.in_channels = in_channels
         self.num_classes = num_classes
-        self.octave_base_scale = octave_base_scale
-        self.scales_per_octave = scales_per_octave
+        self.feat_channels = feat_channels
+        self.anchor_scales = anchor_scales
         self.anchor_ratios = anchor_ratios
         self.anchor_strides = anchor_strides
         self.anchor_base_sizes = list(
             anchor_strides) if anchor_base_sizes is None else anchor_base_sizes
         self.target_means = target_means
         self.target_stds = target_stds
+        self.use_sigmoid_cls = use_sigmoid_cls
+        self.use_focal_loss = use_focal_loss

         self.anchor_generators = []
         for anchor_base in self.anchor_base_sizes:
-            octave_scales = np.array(
-                [2**(i / scales_per_octave) for i in range(scales_per_octave)])
-            anchor_scales = octave_scales * octave_base_scale
             self.anchor_generators.append(
                 AnchorGenerator(anchor_base, anchor_scales, anchor_ratios))
-        self.relu = nn.ReLU(inplace=True)
-        self.num_anchors = int(
-            len(self.anchor_ratios) * self.scales_per_octave)
-        self.cls_out_channels = self.num_classes - 1
-        self.bbox_pred_dim = 4
-        self.stacked_convs = stacked_convs
-        self.cls_convs = nn.ModuleList()
-        self.reg_convs = nn.ModuleList()
-        for i in range(self.stacked_convs):
-            chn = in_channels if i == 0 else feat_channels
-            self.cls_convs.append(
-                nn.Conv2d(chn, feat_channels, 3, stride=1, padding=1))
-            self.reg_convs.append(
-                nn.Conv2d(chn, feat_channels, 3, stride=1, padding=1))
-        self.retina_cls = nn.Conv2d(
-            feat_channels,
-            self.num_anchors * self.cls_out_channels,
-            3,
-            stride=1,
-            padding=1)
-        self.retina_reg = nn.Conv2d(
-            feat_channels,
-            self.num_anchors * self.bbox_pred_dim,
-            3,
-            stride=1,
-            padding=1)
-        self.debug_imgs = None
+        self.num_anchors = len(self.anchor_ratios) * len(self.anchor_scales)
+        if self.use_sigmoid_cls:
+            self.cls_out_channels = self.num_classes - 1
+        else:
+            self.cls_out_channels = self.num_classes
+
+        self._init_layers()
+
+    def _init_layers(self):
+        self.conv_cls = nn.Conv2d(self.feat_channels,
+                                  self.num_anchors * self.cls_out_channels, 1)
+        self.conv_reg = nn.Conv2d(self.feat_channels, self.num_anchors * 4, 1)

     def init_weights(self):
-        for m in self.cls_convs:
-            normal_init(m, std=0.01)
-        for m in self.reg_convs:
-            normal_init(m, std=0.01)
-        bias_cls = bias_init_with_prob(0.01)
-        normal_init(self.retina_cls, std=0.01, bias=bias_cls)
-        normal_init(self.retina_reg, std=0.01)
+        normal_init(self.conv_cls, std=0.01)
+        normal_init(self.conv_reg, std=0.01)

     def forward_single(self, x):
-        cls_feat = x
-        reg_feat = x
-        for cls_conv in self.cls_convs:
-            cls_feat = self.relu(cls_conv(cls_feat))
-        for reg_conv in self.reg_convs:
-            reg_feat = self.relu(reg_conv(reg_feat))
-        cls_score = self.retina_cls(cls_feat)
-        bbox_pred = self.retina_reg(reg_feat)
-        return cls_score, bbox_pred
+        rpn_cls_score = self.conv_cls(x)
+        rpn_bbox_pred = self.conv_reg(x)
+        return rpn_cls_score, rpn_bbox_pred

     def forward(self, feats):
         return multi_apply(self.forward_single, feats)
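A shape sketch of the unified forward pass (batch size and feature sizes are hypothetical; with the defaults above, num_anchors = 3 scales x 3 ratios = 9 and, for softmax heads, cls_out_channels = num_classes = 81):

import torch
from mmdet.models.anchor_heads import AnchorHead

head = AnchorHead(num_classes=81, in_channels=256)
# three pyramid levels of a 256x256 input at strides 4/8/16 (illustrative)
feats = [torch.randn(2, 256, s, s) for s in (64, 32, 16)]
cls_scores, bbox_preds = head(feats)  # multi_apply returns per-level lists
# cls_scores[0].shape == (2, 9 * 81, 64, 64)
# bbox_preds[0].shape == (2, 9 * 4, 64, 64)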
@@ -156,20 +127,34 @@ class RetinaHead(nn.Module):
         return anchor_list, valid_flag_list

     def loss_single(self, cls_score, bbox_pred, labels, label_weights,
-                    bbox_targets, bbox_weights, num_pos_samples, cfg):
+                    bbox_targets, bbox_weights, num_total_samples, cfg):
         # classification loss
         labels = labels.contiguous().view(-1, self.cls_out_channels)
         label_weights = label_weights.contiguous().view(
             -1, self.cls_out_channels)
         cls_score = cls_score.permute(0, 2, 3, 1).contiguous().view(
             -1, self.cls_out_channels)
-        loss_cls = weighted_sigmoid_focal_loss(
-            cls_score,
-            labels,
-            label_weights,
-            cfg.gamma,
-            cfg.alpha,
-            avg_factor=num_pos_samples)
+        if self.use_sigmoid_cls:
+            if self.use_focal_loss:
+                cls_criterion = weighted_sigmoid_focal_loss
+            else:
+                cls_criterion = weighted_binary_cross_entropy
+        else:
+            if self.use_focal_loss:
+                raise NotImplementedError
+            else:
+                cls_criterion = weighted_cross_entropy
+        if self.use_focal_loss:
+            loss_cls = cls_criterion(
+                cls_score,
+                labels,
+                label_weights,
+                gamma=cfg.gamma,
+                alpha=cfg.alpha,
+                avg_factor=num_total_samples)
+        else:
+            loss_cls = cls_criterion(
+                cls_score, labels, label_weights, avg_factor=num_total_samples)
         # regression loss
         bbox_targets = bbox_targets.contiguous().view(-1, 4)
         bbox_weights = bbox_weights.contiguous().view(-1, 4)
@@ -179,7 +164,7 @@ class RetinaHead(nn.Module):
             bbox_targets,
             bbox_weights,
             beta=cfg.smoothl1_beta,
-            avg_factor=num_pos_samples)
+            avg_factor=num_total_samples)
         return loss_cls, loss_reg
     def loss(self, cls_scores, bbox_preds, gt_bboxes, gt_labels, img_metas,
@@ -189,6 +174,7 @@ class RetinaHead(nn.Module):
         anchor_list, valid_flag_list = self.get_anchors(
             featmap_sizes, img_metas)
+        sampling = False if self.use_focal_loss else True
         cls_reg_targets = anchor_target(
             anchor_list,
             valid_flag_list,
@@ -199,12 +185,13 @@ class RetinaHead(nn.Module):
             cfg,
             gt_labels_list=gt_labels,
             cls_out_channels=self.cls_out_channels,
-            sampling=False)
+            sampling=sampling)
         if cls_reg_targets is None:
             return None
         (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
          num_total_pos, num_total_neg) = cls_reg_targets
+        num_total_samples = (num_total_pos if self.use_focal_loss else
+                             num_total_pos + num_total_neg)
         losses_cls, losses_reg = multi_apply(
             self.loss_single,
             cls_scores,
@@ -213,16 +200,12 @@ class RetinaHead(nn.Module):
             label_weights_list,
             bbox_targets_list,
             bbox_weights_list,
-            num_pos_samples=num_total_pos,
+            num_total_samples=num_total_samples,
             cfg=cfg)
-        return dict(loss_cls=losses_cls, loss_reg=losses_reg)
+        return dict(loss_rpn_cls=losses_cls, loss_rpn_reg=losses_reg)
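The normalization rule introduced here, in isolation (counts are made up): focal-loss heads average over positive anchors only, while sampling-based heads average over all sampled anchors:

num_total_pos, num_total_neg = 120, 392  # hypothetical target counts
use_focal_loss = True
num_total_samples = (num_total_pos if use_focal_loss else
                     num_total_pos + num_total_neg)  # 120 here, 512 otherwise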
-    def get_det_bboxes(self,
-                       cls_scores,
-                       bbox_preds,
-                       img_metas,
-                       cfg,
-                       rescale=False):
+    def get_bboxes(self, cls_scores, bbox_preds, img_metas, cfg,
+                   rescale=False):
         assert len(cls_scores) == len(bbox_preds)
         num_levels = len(cls_scores)
@@ -231,7 +214,6 @@ class RetinaHead(nn.Module):
                 self.anchor_strides[i])
             for i in range(num_levels)
         ]
         result_list = []
         for img_id in range(len(img_metas)):
             cls_score_list = [
@@ -242,46 +224,54 @@ class RetinaHead(nn.Module):
             ]
             img_shape = img_metas[img_id]['img_shape']
             scale_factor = img_metas[img_id]['scale_factor']
-            results = self._get_det_bboxes_single(
-                cls_score_list, bbox_pred_list, mlvl_anchors, img_shape,
-                scale_factor, cfg, rescale)
-            result_list.append(results)
+            proposals = self.get_bboxes_single(cls_score_list, bbox_pred_list,
+                                               mlvl_anchors, img_shape,
+                                               scale_factor, cfg, rescale)
+            result_list.append(proposals)
         return result_list

-    def _get_det_bboxes_single(self,
-                               cls_scores,
-                               bbox_preds,
-                               mlvl_anchors,
-                               img_shape,
-                               scale_factor,
-                               cfg,
-                               rescale=False):
+    def get_bboxes_single(self,
+                          cls_scores,
+                          bbox_preds,
+                          mlvl_anchors,
+                          img_shape,
+                          scale_factor,
+                          cfg,
+                          rescale=False):
         assert len(cls_scores) == len(bbox_preds) == len(mlvl_anchors)
-        mlvl_proposals = []
+        mlvl_bboxes = []
         mlvl_scores = []
         for cls_score, bbox_pred, anchors in zip(cls_scores, bbox_preds,
                                                  mlvl_anchors):
             assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
-            cls_score = cls_score.permute(1, 2, 0).contiguous().view(
-                -1, self.cls_out_channels)
-            scores = cls_score.sigmoid()
-            bbox_pred = bbox_pred.permute(1, 2, 0).contiguous().view(-1, 4)
-            proposals = delta2bbox(anchors, bbox_pred, self.target_means,
-                                   self.target_stds, img_shape)
-            if cfg.nms_pre > 0 and scores.shape[0] > cfg.nms_pre:
-                maxscores, _ = scores.max(dim=1)
-                _, topk_inds = maxscores.topk(cfg.nms_pre)
-                proposals = proposals[topk_inds, :]
+            cls_score = cls_score.permute(1, 2, 0).reshape(
+                -1, self.cls_out_channels)
+            if self.use_sigmoid_cls:
+                scores = cls_score.sigmoid()
+            else:
+                scores = cls_score.softmax(-1)
+            bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4)
+            nms_pre = cfg.get('nms_pre', -1)
+            if nms_pre > 0 and scores.shape[0] > nms_pre:
+                if self.use_sigmoid_cls:
+                    max_scores, _ = scores.max(dim=1)
+                else:
+                    max_scores, _ = scores[:, 1:].max(dim=1)
+                _, topk_inds = max_scores.topk(nms_pre)
+                anchors = anchors[topk_inds, :]
+                bbox_pred = bbox_pred[topk_inds, :]
                 scores = scores[topk_inds, :]
-            mlvl_proposals.append(proposals)
+            bboxes = delta2bbox(anchors, bbox_pred, self.target_means,
+                                self.target_stds, img_shape)
+            mlvl_bboxes.append(bboxes)
             mlvl_scores.append(scores)
-        mlvl_proposals = torch.cat(mlvl_proposals)
+        mlvl_bboxes = torch.cat(mlvl_bboxes)
         if rescale:
-            mlvl_proposals /= scale_factor
+            mlvl_bboxes /= mlvl_bboxes.new_tensor(scale_factor)
         mlvl_scores = torch.cat(mlvl_scores)
-        padding = mlvl_scores.new_zeros(mlvl_scores.shape[0], 1)
-        mlvl_scores = torch.cat([padding, mlvl_scores], dim=1)
-        det_bboxes, det_labels = multiclass_nms(mlvl_proposals, mlvl_scores,
-                                                cfg.score_thr, cfg.nms,
-                                                cfg.max_per_img)
+        if self.use_sigmoid_cls:
+            padding = mlvl_scores.new_zeros(mlvl_scores.shape[0], 1)
+            mlvl_scores = torch.cat([padding, mlvl_scores], dim=1)
+        det_bboxes, det_labels = multiclass_nms(
+            mlvl_bboxes, mlvl_scores, cfg.score_thr, cfg.nms, cfg.max_per_img)
         return det_bboxes, det_labels
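Why get_bboxes_single pads a background column for sigmoid heads: sigmoid classification predicts num_classes - 1 foreground scores, while multiclass_nms expects the background class at column 0. A standalone illustration (shapes are hypothetical):

import torch

mlvl_scores = torch.rand(1000, 80)  # sigmoid head: 80 foreground classes
padding = mlvl_scores.new_zeros(mlvl_scores.shape[0], 1)
mlvl_scores = torch.cat([padding, mlvl_scores], dim=1)  # (1000, 81)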
anchor_heads/retina_head.py (new file):

import numpy as np
import torch.nn as nn
from mmcv.cnn import normal_init

from .anchor_head import AnchorHead
from ..utils import bias_init_with_prob


class RetinaHead(AnchorHead):

    def __init__(self,
                 num_classes,
                 in_channels,
                 stacked_convs=4,
                 octave_base_scale=4,
                 scales_per_octave=3,
                 **kwargs):
        self.stacked_convs = stacked_convs
        self.octave_base_scale = octave_base_scale
        self.scales_per_octave = scales_per_octave
        octave_scales = np.array(
            [2**(i / scales_per_octave) for i in range(scales_per_octave)])
        anchor_scales = octave_scales * octave_base_scale
        super(RetinaHead, self).__init__(
            num_classes,
            in_channels,
            anchor_scales=anchor_scales,
            use_sigmoid_cls=True,
            use_focal_loss=True,
            **kwargs)

    def _init_layers(self):
        self.relu = nn.ReLU(inplace=True)
        self.cls_convs = nn.ModuleList()
        self.reg_convs = nn.ModuleList()
        for i in range(self.stacked_convs):
            chn = self.in_channels if i == 0 else self.feat_channels
            self.cls_convs.append(
                nn.Conv2d(chn, self.feat_channels, 3, stride=1, padding=1))
            self.reg_convs.append(
                nn.Conv2d(chn, self.feat_channels, 3, stride=1, padding=1))
        self.retina_cls = nn.Conv2d(
            self.feat_channels,
            self.num_anchors * self.cls_out_channels,
            3,
            padding=1)
        self.retina_reg = nn.Conv2d(
            self.feat_channels, self.num_anchors * 4, 3, padding=1)

    def init_weights(self):
        for m in self.cls_convs:
            normal_init(m, std=0.01)
        for m in self.reg_convs:
            normal_init(m, std=0.01)
        bias_cls = bias_init_with_prob(0.01)
        normal_init(self.retina_cls, std=0.01, bias=bias_cls)
        normal_init(self.retina_reg, std=0.01)

    def forward_single(self, x):
        cls_feat = x
        reg_feat = x
        for cls_conv in self.cls_convs:
            cls_feat = self.relu(cls_conv(cls_feat))
        for reg_conv in self.reg_convs:
            reg_feat = self.relu(reg_conv(reg_feat))
        cls_score = self.retina_cls(cls_feat)
        bbox_pred = self.retina_reg(reg_feat)
        return cls_score, bbox_pred
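The anchor scales RetinaHead passes to the base class follow s * 2^(i/n); a quick standalone check with the defaults above:

import numpy as np

octave_base_scale, scales_per_octave = 4, 3
octave_scales = np.array(
    [2**(i / scales_per_octave) for i in range(scales_per_octave)])
anchor_scales = octave_scales * octave_base_scale
# array([4.0, 5.0397, 6.3496]) -> three scales per pyramid level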
anchor_heads/rpn_head.py (new file):

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import normal_init

from mmdet.core import delta2bbox
from mmdet.ops import nms
from .anchor_head import AnchorHead


class RPNHead(AnchorHead):

    def __init__(self, in_channels, **kwargs):
        super(RPNHead, self).__init__(2, in_channels, **kwargs)

    def _init_layers(self):
        self.rpn_conv = nn.Conv2d(
            self.in_channels, self.feat_channels, 3, padding=1)
        self.rpn_cls = nn.Conv2d(self.feat_channels,
                                 self.num_anchors * self.cls_out_channels, 1)
        self.rpn_reg = nn.Conv2d(self.feat_channels, self.num_anchors * 4, 1)

    def init_weights(self):
        normal_init(self.rpn_conv, std=0.01)
        normal_init(self.rpn_cls, std=0.01)
        normal_init(self.rpn_reg, std=0.01)

    def forward_single(self, x):
        x = self.rpn_conv(x)
        x = F.relu(x, inplace=True)
        rpn_cls_score = self.rpn_cls(x)
        rpn_bbox_pred = self.rpn_reg(x)
        return rpn_cls_score, rpn_bbox_pred

    def loss(self, cls_scores, bbox_preds, gt_bboxes, img_metas, cfg):
        return super(RPNHead, self).loss(cls_scores, bbox_preds, gt_bboxes,
                                         None, img_metas, cfg)

    def get_bboxes_single(self,
                          cls_scores,
                          bbox_preds,
                          mlvl_anchors,
                          img_shape,
                          scale_factor,
                          cfg,
                          rescale=False):
        mlvl_proposals = []
        for idx in range(len(cls_scores)):
            rpn_cls_score = cls_scores[idx]
            rpn_bbox_pred = bbox_preds[idx]
            assert rpn_cls_score.size()[-2:] == rpn_bbox_pred.size()[-2:]
            anchors = mlvl_anchors[idx]
            rpn_cls_score = rpn_cls_score.permute(1, 2, 0)
            if self.use_sigmoid_cls:
                rpn_cls_score = rpn_cls_score.reshape(-1)
                scores = rpn_cls_score.sigmoid()
            else:
                rpn_cls_score = rpn_cls_score.reshape(-1, 2)
                scores = rpn_cls_score.softmax(dim=1)[:, 1]
            rpn_bbox_pred = rpn_bbox_pred.permute(1, 2, 0).reshape(-1, 4)
            if cfg.nms_pre > 0 and scores.shape[0] > cfg.nms_pre:
                _, topk_inds = scores.topk(cfg.nms_pre)
                rpn_bbox_pred = rpn_bbox_pred[topk_inds, :]
                anchors = anchors[topk_inds, :]
                scores = scores[topk_inds]
            proposals = delta2bbox(anchors, rpn_bbox_pred, self.target_means,
                                   self.target_stds, img_shape)
            if cfg.min_bbox_size > 0:
                w = proposals[:, 2] - proposals[:, 0] + 1
                h = proposals[:, 3] - proposals[:, 1] + 1
                valid_inds = torch.nonzero((w >= cfg.min_bbox_size) &
                                           (h >= cfg.min_bbox_size)).squeeze()
                proposals = proposals[valid_inds, :]
                scores = scores[valid_inds]
            proposals = torch.cat([proposals, scores.unsqueeze(-1)], dim=-1)
            proposals, _ = nms(proposals, cfg.nms_thr)
            proposals = proposals[:cfg.nms_post, :]
            mlvl_proposals.append(proposals)
        proposals = torch.cat(mlvl_proposals, 0)
        if cfg.nms_across_levels:
            proposals, _ = nms(proposals, cfg.nms_thr)
            proposals = proposals[:cfg.max_num, :]
        else:
            scores = proposals[:, 4]
            num = min(cfg.max_num, proposals.shape[0])
            _, topk_inds = scores.topk(num)
            proposals = proposals[topk_inds, :]
        return proposals
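A hypothetical proposal-stage test config covering every cfg attribute read by RPNHead.get_bboxes_single above; the values are typical RPN settings, not taken from this commit (cfg is attribute-accessed, so a plain dict is wrapped in mmcv's Config):

from mmcv import Config

rpn_test_cfg = Config(
    dict(
        nms_across_levels=False,  # NMS per level, then keep the global top-k
        nms_pre=2000,             # top-scoring anchors kept before NMS per level
        nms_post=2000,            # proposals kept after per-level NMS
        max_num=2000,             # final number of proposals per image
        nms_thr=0.7,              # IoU threshold for NMS
        min_bbox_size=0))         # filter degenerate boxes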
anchor_heads/ssd_head.py:

-from __future__ import division
 import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from mmcv.cnn import xavier_init

-from mmdet.core import (AnchorGenerator, anchor_target, multi_apply,
-                        delta2bbox, weighted_smoothl1, multiclass_nms)
+from mmdet.core import (AnchorGenerator, anchor_target, weighted_smoothl1,
+                        multi_apply)
+from .anchor_head import AnchorHead


-class SSDHead(nn.Module):
+class SSDHead(AnchorHead):

     def __init__(self,
                  input_size=300,
-                 in_channels=(512, 1024, 512, 256, 256, 256),
                  num_classes=81,
+                 in_channels=(512, 1024, 512, 256, 256, 256),
                  anchor_strides=(8, 16, 32, 64, 100, 300),
                  basesize_ratio_range=(0.1, 0.9),
                  anchor_ratios=([2], [2, 3], [2, 3], [2, 3], [2], [2]),
                  target_means=(.0, .0, .0, .0),
                  target_stds=(1.0, 1.0, 1.0, 1.0)):
-        super(SSDHead, self).__init__()
-        # construct head
-        num_anchors = [len(ratios) * 2 + 2 for ratios in anchor_ratios]
-        self.in_channels = in_channels
+        super(AnchorHead, self).__init__()
+        self.input_size = input_size
         self.num_classes = num_classes
+        self.in_channels = in_channels
         self.cls_out_channels = num_classes
+        num_anchors = [len(ratios) * 2 + 2 for ratios in anchor_ratios]
         reg_convs = []
         cls_convs = []
         for i in range(len(in_channels)):
@@ -88,6 +87,8 @@ class SSDHead(nn.Module):
         self.target_means = target_means
         self.target_stds = target_stds
+        self.use_sigmoid_cls = False
+        self.use_focal_loss = False

     def init_weights(self):
         for m in self.modules():
@@ -103,68 +104,28 @@ class SSDHead(nn.Module):
             bbox_preds.append(reg_conv(feat))
         return cls_scores, bbox_preds

-    def get_anchors(self, featmap_sizes, img_metas):
-        """Get anchors according to feature map sizes.
-
-        Args:
-            featmap_sizes (list[tuple]): Multi-level feature map sizes.
-            img_metas (list[dict]): Image meta info.
-
-        Returns:
-            tuple: anchors of each image, valid flags of each image
-        """
-        num_imgs = len(img_metas)
-        num_levels = len(featmap_sizes)
-
-        # since feature map sizes of all images are the same, we only compute
-        # anchors for one time
-        multi_level_anchors = []
-        for i in range(num_levels):
-            anchors = self.anchor_generators[i].grid_anchors(
-                featmap_sizes[i], self.anchor_strides[i])
-            multi_level_anchors.append(anchors)
-        anchor_list = [multi_level_anchors for _ in range(num_imgs)]
-
-        # for each image, we compute valid flags of multi level anchors
-        valid_flag_list = []
-        for img_id, img_meta in enumerate(img_metas):
-            multi_level_flags = []
-            for i in range(num_levels):
-                anchor_stride = self.anchor_strides[i]
-                feat_h, feat_w = featmap_sizes[i]
-                h, w, _ = img_meta['pad_shape']
-                valid_feat_h = min(int(np.ceil(h / anchor_stride)), feat_h)
-                valid_feat_w = min(int(np.ceil(w / anchor_stride)), feat_w)
-                flags = self.anchor_generators[i].valid_flags(
-                    (feat_h, feat_w), (valid_feat_h, valid_feat_w))
-                multi_level_flags.append(flags)
-            valid_flag_list.append(multi_level_flags)
-        return anchor_list, valid_flag_list
-
     def loss_single(self, cls_score, bbox_pred, labels, label_weights,
-                    bbox_targets, bbox_weights, num_pos_samples, cfg):
+                    bbox_targets, bbox_weights, num_total_samples, cfg):
         loss_cls_all = F.cross_entropy(
             cls_score, labels, reduction='none') * label_weights
-        pos_label_inds = (labels > 0).nonzero().view(-1)
-        neg_label_inds = (labels == 0).nonzero().view(-1)
-        num_sample_pos = pos_label_inds.size(0)
-        num_sample_neg = cfg.neg_pos_ratio * num_sample_pos
-        if num_sample_neg > neg_label_inds.size(0):
-            num_sample_neg = neg_label_inds.size(0)
-        topk_loss_cls_neg, topk_loss_cls_neg_inds = \
-            loss_cls_all[neg_label_inds].topk(num_sample_neg)
-        loss_cls_pos = loss_cls_all[pos_label_inds].sum()
+        pos_inds = (labels > 0).nonzero().view(-1)
+        neg_inds = (labels == 0).nonzero().view(-1)
+
+        num_pos_samples = pos_inds.size(0)
+        num_neg_samples = cfg.neg_pos_ratio * num_pos_samples
+        if num_neg_samples > neg_inds.size(0):
+            num_neg_samples = neg_inds.size(0)
+        topk_loss_cls_neg, _ = loss_cls_all[neg_inds].topk(num_neg_samples)
+        loss_cls_pos = loss_cls_all[pos_inds].sum()
         loss_cls_neg = topk_loss_cls_neg.sum()
-        loss_cls = (loss_cls_pos + loss_cls_neg) / num_pos_samples
+        loss_cls = (loss_cls_pos + loss_cls_neg) / num_total_samples

         loss_reg = weighted_smoothl1(
             bbox_pred,
             bbox_targets,
             bbox_weights,
             beta=cfg.smoothl1_beta,
-            avg_factor=num_pos_samples)
+            avg_factor=num_total_samples)
         return loss_cls[None], loss_reg

     def loss(self, cls_scores, bbox_preds, gt_bboxes, gt_labels, img_metas,
@@ -193,14 +154,14 @@ class SSDHead(nn.Module):
         num_images = len(img_metas)
         all_cls_scores = torch.cat([
-            s.permute(0, 2, 3, 1).contiguous().view(
+            s.permute(0, 2, 3, 1).reshape(
                 num_images, -1, self.cls_out_channels) for s in cls_scores
         ], 1)
         all_labels = torch.cat(labels_list, -1).view(num_images, -1)
         all_label_weights = torch.cat(label_weights_list, -1).view(
             num_images, -1)
         all_bbox_preds = torch.cat([
-            b.permute(0, 2, 3, 1).contiguous().view(num_images, -1, 4)
+            b.permute(0, 2, 3, 1).reshape(num_images, -1, 4)
             for b in bbox_preds
         ], -2)
         all_bbox_targets = torch.cat(bbox_targets_list, -2).view(
@@ -216,68 +177,6 @@ class SSDHead(nn.Module):
             all_label_weights,
             all_bbox_targets,
             all_bbox_weights,
-            num_pos_samples=num_total_pos,
+            num_total_samples=num_total_pos,
             cfg=cfg)
         return dict(loss_cls=losses_cls, loss_reg=losses_reg)
-
-    def get_det_bboxes(self,
-                       cls_scores,
-                       bbox_preds,
-                       img_metas,
-                       cfg,
-                       rescale=False):
-        assert len(cls_scores) == len(bbox_preds)
-        num_levels = len(cls_scores)
-
-        mlvl_anchors = [
-            self.anchor_generators[i].grid_anchors(cls_scores[i].size()[-2:],
-                                                   self.anchor_strides[i])
-            for i in range(num_levels)
-        ]
-        result_list = []
-        for img_id in range(len(img_metas)):
-            cls_score_list = [
-                cls_scores[i][img_id].detach() for i in range(num_levels)
-            ]
-            bbox_pred_list = [
-                bbox_preds[i][img_id].detach() for i in range(num_levels)
-            ]
-            img_shape = img_metas[img_id]['img_shape']
-            scale_factor = img_metas[img_id]['scale_factor']
-            results = self._get_det_bboxes_single(
-                cls_score_list, bbox_pred_list, mlvl_anchors, img_shape,
-                scale_factor, cfg, rescale)
-            result_list.append(results)
-        return result_list
-
-    def _get_det_bboxes_single(self,
-                               cls_scores,
-                               bbox_preds,
-                               mlvl_anchors,
-                               img_shape,
-                               scale_factor,
-                               cfg,
-                               rescale=False):
-        assert len(cls_scores) == len(bbox_preds) == len(mlvl_anchors)
-        mlvl_proposals = []
-        mlvl_scores = []
-        for cls_score, bbox_pred, anchors in zip(cls_scores, bbox_preds,
-                                                 mlvl_anchors):
-            assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
-            cls_score = cls_score.permute(1, 2, 0).contiguous().view(
-                -1, self.cls_out_channels)
-            scores = cls_score.softmax(-1)
-            bbox_pred = bbox_pred.permute(1, 2, 0).contiguous().view(-1, 4)
-            proposals = delta2bbox(anchors, bbox_pred, self.target_means,
-                                   self.target_stds, img_shape)
-            mlvl_proposals.append(proposals)
-            mlvl_scores.append(scores)
-        mlvl_proposals = torch.cat(mlvl_proposals)
-        if rescale:
-            mlvl_proposals /= mlvl_proposals.new_tensor(scale_factor)
-        mlvl_scores = torch.cat(mlvl_scores)
-        det_bboxes, det_labels = multiclass_nms(mlvl_proposals, mlvl_scores,
-                                                cfg.score_thr, cfg.nms,
-                                                cfg.max_per_img)
-        return det_bboxes, det_labels
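SSDHead.loss_single keeps at most neg_pos_ratio hard negatives per positive by ranking the per-anchor classification losses; a standalone illustration with made-up numbers:

import torch

loss_cls_all = torch.tensor([0.9, 0.1, 0.4, 0.8, 0.05, 0.3])
labels = torch.tensor([1, 0, 0, 2, 0, 0])  # 2 positives, 4 negatives
neg_pos_ratio = 3
pos_inds = (labels > 0).nonzero().view(-1)
neg_inds = (labels == 0).nonzero().view(-1)
num_neg = min(neg_pos_ratio * pos_inds.numel(), neg_inds.numel())  # 4
topk_loss_neg, _ = loss_cls_all[neg_inds].topk(num_neg)  # hardest negatives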
builder.py:

 from mmcv.runner import obj_from_dict
 from torch import nn

-from . import (backbones, necks, roi_extractors, rpn_heads, bbox_heads,
-               mask_heads, single_stage_heads)
-
-__all__ = [
-    'build_backbone', 'build_neck', 'build_rpn_head', 'build_roi_extractor',
-    'build_bbox_head', 'build_mask_head', 'build_single_stage_head',
-    'build_detector'
-]
+from . import (backbones, necks, roi_extractors, anchor_heads, bbox_heads,
+               mask_heads)


 def _build_module(cfg, parrent=None, default_args=None):
@@ -32,8 +26,8 @@ def build_neck(cfg):
     return build(cfg, necks)


-def build_rpn_head(cfg):
-    return build(cfg, rpn_heads)
+def build_anchor_head(cfg):
+    return build(cfg, anchor_heads)


 def build_roi_extractor(cfg):
@@ -48,10 +42,6 @@ def build_mask_head(cfg):
     return build(cfg, mask_heads)


-def build_single_stage_head(cfg):
-    return build(cfg, single_stage_heads)
-
-
 def build_detector(cfg, train_cfg=None, test_cfg=None):
     from . import detectors
     return build(cfg, detectors, dict(train_cfg=train_cfg, test_cfg=test_cfg))
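The renamed entry point builds any registered anchor-based head from a config dict (via obj_from_dict); a hypothetical RPN head config, with keys mirroring RPNHead's constructor arguments:

from mmdet.models import build_anchor_head

rpn_head_cfg = dict(
    type='RPNHead',  # class name resolved in mmdet.models.anchor_heads
    in_channels=256,
    feat_channels=256,
    anchor_scales=[8],
    anchor_ratios=[0.5, 1.0, 2.0],
    anchor_strides=[4, 8, 16, 32, 64],
    use_sigmoid_cls=True)
rpn_head = build_anchor_head(rpn_head_cfg)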
detectors (CascadeRCNN):

@@ -37,7 +37,7 @@ class CascadeRCNN(BaseDetector, RPNTestMixin):
             raise NotImplementedError

         if rpn_head is not None:
-            self.rpn_head = builder.build_rpn_head(rpn_head)
+            self.rpn_head = builder.build_anchor_head(rpn_head)

         if bbox_head is not None:
             self.bbox_roi_extractor = nn.ModuleList()
@@ -123,7 +123,7 @@ class CascadeRCNN(BaseDetector, RPNTestMixin):
             losses.update(rpn_losses)

             proposal_inputs = rpn_outs + (img_meta, self.test_cfg.rpn)
-            proposal_list = self.rpn_head.get_proposals(*proposal_inputs)
+            proposal_list = self.rpn_head.get_bboxes(*proposal_inputs)
         else:
             proposal_list = proposals
detectors (RPN):

@@ -18,7 +18,7 @@ class RPN(BaseDetector, RPNTestMixin):
         super(RPN, self).__init__()
         self.backbone = builder.build_backbone(backbone)
         self.neck = builder.build_neck(neck) if neck is not None else None
-        self.rpn_head = builder.build_rpn_head(rpn_head)
+        self.rpn_head = builder.build_anchor_head(rpn_head)
         self.train_cfg = train_cfg
         self.test_cfg = test_cfg
         self.init_weights(pretrained=pretrained)
detectors (SingleStageDetector):

@@ -18,7 +18,7 @@ class SingleStageDetector(BaseDetector):
         self.backbone = builder.build_backbone(backbone)
         if neck is not None:
             self.neck = builder.build_neck(neck)
-        self.bbox_head = builder.build_single_stage_head(bbox_head)
+        self.bbox_head = builder.build_anchor_head(bbox_head)
         self.train_cfg = train_cfg
         self.test_cfg = test_cfg
         self.init_weights(pretrained=pretrained)
@@ -51,7 +51,7 @@ class SingleStageDetector(BaseDetector):
         x = self.extract_feat(img)
         outs = self.bbox_head(x)
         bbox_inputs = outs + (img_meta, self.test_cfg, rescale)
-        bbox_list = self.bbox_head.get_det_bboxes(*bbox_inputs)
+        bbox_list = self.bbox_head.get_bboxes(*bbox_inputs)
         bbox_results = [
             bbox2result(det_bboxes, det_labels, self.bbox_head.num_classes)
             for det_bboxes, det_labels in bbox_list
detectors (RPNTestMixin):

@@ -7,7 +7,7 @@ class RPNTestMixin(object):

     def simple_test_rpn(self, x, img_meta, rpn_test_cfg):
         rpn_outs = self.rpn_head(x)
         proposal_inputs = rpn_outs + (img_meta, rpn_test_cfg)
-        proposal_list = self.rpn_head.get_proposals(*proposal_inputs)
+        proposal_list = self.rpn_head.get_bboxes(*proposal_inputs)
         return proposal_list

     def aug_test_rpn(self, feats, img_metas, rpn_test_cfg):
detectors (TwoStageDetector):

@@ -30,7 +30,7 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
             raise NotImplementedError

         if rpn_head is not None:
-            self.rpn_head = builder.build_rpn_head(rpn_head)
+            self.rpn_head = builder.build_anchor_head(rpn_head)

         if bbox_head is not None:
             self.bbox_roi_extractor = builder.build_roi_extractor(
@@ -96,7 +96,7 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
             losses.update(rpn_losses)

             proposal_inputs = rpn_outs + (img_meta, self.test_cfg.rpn)
-            proposal_list = self.rpn_head.get_proposals(*proposal_inputs)
+            proposal_list = self.rpn_head.get_bboxes(*proposal_inputs)
         else:
             proposal_list = proposals
rpn_heads/__init__.py (deleted file):

from .rpn_head import RPNHead

__all__ = ['RPNHead']
rpn_heads/rpn_head.py (deleted file):

from __future__ import division

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from mmdet.core import (AnchorGenerator, anchor_target, delta2bbox,
                        multi_apply, weighted_cross_entropy, weighted_smoothl1,
                        weighted_binary_cross_entropy)
from mmdet.ops import nms
from ..utils import normal_init


class RPNHead(nn.Module):
    """Network head of RPN.

                                  / - rpn_cls (1x1 conv)
    input - rpn_conv (3x3 conv) -
                                  \ - rpn_reg (1x1 conv)

    Args:
        in_channels (int): Number of channels in the input feature map.
        feat_channels (int): Number of channels for the RPN feature map.
        anchor_scales (Iterable): Anchor scales.
        anchor_ratios (Iterable): Anchor aspect ratios.
        anchor_strides (Iterable): Anchor strides.
        anchor_base_sizes (Iterable): Anchor base sizes.
        target_means (Iterable): Mean values of regression targets.
        target_stds (Iterable): Std values of regression targets.
        use_sigmoid_cls (bool): Whether to use sigmoid loss for classification.
            (softmax by default)
    """ # noqa: W605

    def __init__(self,
                 in_channels,
                 feat_channels=256,
                 anchor_scales=[8, 16, 32],
                 anchor_ratios=[0.5, 1.0, 2.0],
                 anchor_strides=[4, 8, 16, 32, 64],
                 anchor_base_sizes=None,
                 target_means=(.0, .0, .0, .0),
                 target_stds=(1.0, 1.0, 1.0, 1.0),
                 use_sigmoid_cls=False):
        super(RPNHead, self).__init__()
        self.in_channels = in_channels
        self.feat_channels = feat_channels
        self.anchor_scales = anchor_scales
        self.anchor_ratios = anchor_ratios
        self.anchor_strides = anchor_strides
        self.anchor_base_sizes = list(
            anchor_strides) if anchor_base_sizes is None else anchor_base_sizes
        self.target_means = target_means
        self.target_stds = target_stds
        self.use_sigmoid_cls = use_sigmoid_cls

        self.anchor_generators = []
        for anchor_base in self.anchor_base_sizes:
            self.anchor_generators.append(
                AnchorGenerator(anchor_base, anchor_scales, anchor_ratios))
        self.rpn_conv = nn.Conv2d(in_channels, feat_channels, 3, padding=1)
        self.relu = nn.ReLU(inplace=True)
        self.num_anchors = len(self.anchor_ratios) * len(self.anchor_scales)
        out_channels = (self.num_anchors
                        if self.use_sigmoid_cls else self.num_anchors * 2)
        self.rpn_cls = nn.Conv2d(feat_channels, out_channels, 1)
        self.rpn_reg = nn.Conv2d(feat_channels, self.num_anchors * 4, 1)
        self.debug_imgs = None

    def init_weights(self):
        normal_init(self.rpn_conv, std=0.01)
        normal_init(self.rpn_cls, std=0.01)
        normal_init(self.rpn_reg, std=0.01)

    def forward_single(self, x):
        rpn_feat = self.relu(self.rpn_conv(x))
        rpn_cls_score = self.rpn_cls(rpn_feat)
        rpn_bbox_pred = self.rpn_reg(rpn_feat)
        return rpn_cls_score, rpn_bbox_pred

    def forward(self, feats):
        return multi_apply(self.forward_single, feats)

    def get_anchors(self, featmap_sizes, img_metas):
        """Get anchors according to feature map sizes.

        Args:
            featmap_sizes (list[tuple]): Multi-level feature map sizes.
            img_metas (list[dict]): Image meta info.

        Returns:
            tuple: anchors of each image, valid flags of each image
        """
        num_imgs = len(img_metas)
        num_levels = len(featmap_sizes)

        # since feature map sizes of all images are the same, we only compute
        # anchors for one time
        multi_level_anchors = []
        for i in range(num_levels):
            anchors = self.anchor_generators[i].grid_anchors(
                featmap_sizes[i], self.anchor_strides[i])
            multi_level_anchors.append(anchors)
        anchor_list = [multi_level_anchors for _ in range(num_imgs)]

        # for each image, we compute valid flags of multi level anchors
        valid_flag_list = []
        for img_id, img_meta in enumerate(img_metas):
            multi_level_flags = []
            for i in range(num_levels):
                anchor_stride = self.anchor_strides[i]
                feat_h, feat_w = featmap_sizes[i]
                h, w, _ = img_meta['pad_shape']
                valid_feat_h = min(int(np.ceil(h / anchor_stride)), feat_h)
                valid_feat_w = min(int(np.ceil(w / anchor_stride)), feat_w)
                flags = self.anchor_generators[i].valid_flags(
                    (feat_h, feat_w), (valid_feat_h, valid_feat_w))
                multi_level_flags.append(flags)
            valid_flag_list.append(multi_level_flags)
        return anchor_list, valid_flag_list

    def loss_single(self, rpn_cls_score, rpn_bbox_pred, labels, label_weights,
                    bbox_targets, bbox_weights, num_total_samples, cfg):
        # classification loss
        labels = labels.contiguous().view(-1)
        label_weights = label_weights.contiguous().view(-1)
        if self.use_sigmoid_cls:
            rpn_cls_score = rpn_cls_score.permute(0, 2, 3,
                                                  1).contiguous().view(-1)
            criterion = weighted_binary_cross_entropy
        else:
            rpn_cls_score = rpn_cls_score.permute(0, 2, 3,
                                                  1).contiguous().view(-1, 2)
            criterion = weighted_cross_entropy
        loss_cls = criterion(
            rpn_cls_score, labels, label_weights, avg_factor=num_total_samples)
        # regression loss
        bbox_targets = bbox_targets.contiguous().view(-1, 4)
        bbox_weights = bbox_weights.contiguous().view(-1, 4)
        rpn_bbox_pred = rpn_bbox_pred.permute(0, 2, 3, 1).contiguous().view(
            -1, 4)
        loss_reg = weighted_smoothl1(
            rpn_bbox_pred,
            bbox_targets,
            bbox_weights,
            beta=cfg.smoothl1_beta,
            avg_factor=num_total_samples)
        return loss_cls, loss_reg

    def loss(self, rpn_cls_scores, rpn_bbox_preds, gt_bboxes, img_shapes, cfg):
        featmap_sizes = [featmap.size()[-2:] for featmap in rpn_cls_scores]
        assert len(featmap_sizes) == len(self.anchor_generators)

        anchor_list, valid_flag_list = self.get_anchors(
            featmap_sizes, img_shapes)
        cls_reg_targets = anchor_target(
            anchor_list, valid_flag_list, gt_bboxes, img_shapes,
            self.target_means, self.target_stds, cfg)
        if cls_reg_targets is None:
            return None
        (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
         num_total_pos, num_total_neg) = cls_reg_targets
        losses_cls, losses_reg = multi_apply(
            self.loss_single,
            rpn_cls_scores,
            rpn_bbox_preds,
            labels_list,
            label_weights_list,
            bbox_targets_list,
            bbox_weights_list,
            num_total_samples=num_total_pos + num_total_neg,
            cfg=cfg)
        return dict(loss_rpn_cls=losses_cls, loss_rpn_reg=losses_reg)

    def get_proposals(self, rpn_cls_scores, rpn_bbox_preds, img_meta, cfg):
        num_imgs = len(img_meta)
        featmap_sizes = [featmap.size()[-2:] for featmap in rpn_cls_scores]
        mlvl_anchors = [
            self.anchor_generators[idx].grid_anchors(featmap_sizes[idx],
                                                     self.anchor_strides[idx])
            for idx in range(len(featmap_sizes))
        ]
        proposal_list = []
        for img_id in range(num_imgs):
            rpn_cls_score_list = [
                rpn_cls_scores[idx][img_id].detach()
                for idx in range(len(rpn_cls_scores))
            ]
            rpn_bbox_pred_list = [
                rpn_bbox_preds[idx][img_id].detach()
                for idx in range(len(rpn_bbox_preds))
            ]
            assert len(rpn_cls_score_list) == len(rpn_bbox_pred_list)
            proposals = self._get_proposals_single(
                rpn_cls_score_list, rpn_bbox_pred_list, mlvl_anchors,
                img_meta[img_id]['img_shape'], cfg)
            proposal_list.append(proposals)
        return proposal_list

    def _get_proposals_single(self, rpn_cls_scores, rpn_bbox_preds,
                              mlvl_anchors, img_shape, cfg):
        mlvl_proposals = []
        for idx in range(len(rpn_cls_scores)):
            rpn_cls_score = rpn_cls_scores[idx]
            rpn_bbox_pred = rpn_bbox_preds[idx]
            assert rpn_cls_score.size()[-2:] == rpn_bbox_pred.size()[-2:]
            anchors = mlvl_anchors[idx]
            if self.use_sigmoid_cls:
                rpn_cls_score = rpn_cls_score.permute(1, 2,
                                                      0).contiguous().view(-1)
                rpn_cls_prob = rpn_cls_score.sigmoid()
                scores = rpn_cls_prob
            else:
                rpn_cls_score = rpn_cls_score.permute(1, 2,
                                                      0).contiguous().view(
                                                          -1, 2)
                rpn_cls_prob = F.softmax(rpn_cls_score, dim=1)
                scores = rpn_cls_prob[:, 1]
            rpn_bbox_pred = rpn_bbox_pred.permute(1, 2, 0).contiguous().view(
                -1, 4)
            _, order = scores.sort(0, descending=True)
            if cfg.nms_pre > 0:
                order = order[:cfg.nms_pre]
                rpn_bbox_pred = rpn_bbox_pred[order, :]
                anchors = anchors[order, :]
                scores = scores[order]
            proposals = delta2bbox(anchors, rpn_bbox_pred, self.target_means,
                                   self.target_stds, img_shape)
            w = proposals[:, 2] - proposals[:, 0] + 1
            h = proposals[:, 3] - proposals[:, 1] + 1
            valid_inds = torch.nonzero((w >= cfg.min_bbox_size) &
                                       (h >= cfg.min_bbox_size)).squeeze()
            proposals = proposals[valid_inds, :]
            scores = scores[valid_inds]
            proposals = torch.cat([proposals, scores.unsqueeze(-1)], dim=-1)
            proposals, _ = nms(proposals, cfg.nms_thr)
            proposals = proposals[:cfg.nms_post, :]
            mlvl_proposals.append(proposals)
        proposals = torch.cat(mlvl_proposals, 0)
        if cfg.nms_across_levels:
            proposals, _ = nms(proposals, cfg.nms_thr)
            proposals = proposals[:cfg.max_num, :]
        else:
            scores = proposals[:, 4]
            _, order = scores.sort(0, descending=True)
            num = min(cfg.max_num, proposals.shape[0])
            order = order[:num]
            proposals = proposals[order, :]
        return proposals