Unverified commit a8ec6fb3, authored by Kai Chen, committed by GitHub

Merge pull request #252 from hellock/anchor-head

Unify RPNHead and single stage heads with AnchorHead
Parents: d7d4a991, 2df1e0a0

mmdet/core/anchor/anchor_target.py
@@ -12,7 +12,7 @@ def anchor_target(anchor_list,
target_stds,
cfg,
gt_labels_list=None,
- cls_out_channels=1,
+ label_channels=1,
sampling=True,
unmap_outputs=True):
"""Compute regression and classification targets for anchors.
@@ -54,7 +54,7 @@ def anchor_target(anchor_list,
target_means=target_means,
target_stds=target_stds,
cfg=cfg,
- cls_out_channels=cls_out_channels,
+ label_channels=label_channels,
sampling=sampling,
unmap_outputs=unmap_outputs)
# no valid anchors
@@ -95,7 +95,7 @@ def anchor_target_single(flat_anchors,
target_means,
target_stds,
cfg,
- cls_out_channels=1,
+ label_channels=1,
sampling=True,
unmap_outputs=True):
inside_flags = anchor_inside_flags(flat_anchors, valid_flags,
@@ -147,9 +147,9 @@ def anchor_target_single(flat_anchors,
num_total_anchors = flat_anchors.size(0)
labels = unmap(labels, num_total_anchors, inside_flags)
label_weights = unmap(label_weights, num_total_anchors, inside_flags)
- if cls_out_channels > 1:
- labels, label_weights = expand_binary_labels(labels, label_weights,
- cls_out_channels)
+ if label_channels > 1:
+ labels, label_weights = expand_binary_labels(
+ labels, label_weights, label_channels)
bbox_targets = unmap(bbox_targets, num_total_anchors, inside_flags)
bbox_weights = unmap(bbox_weights, num_total_anchors, inside_flags)
@@ -157,14 +157,14 @@ def anchor_target_single(flat_anchors,
neg_inds)
- def expand_binary_labels(labels, label_weights, cls_out_channels):
+ def expand_binary_labels(labels, label_weights, label_channels):
bin_labels = labels.new_full(
- (labels.size(0), cls_out_channels), 0, dtype=torch.float32)
+ (labels.size(0), label_channels), 0, dtype=torch.float32)
inds = torch.nonzero(labels >= 1).squeeze()
if inds.numel() > 0:
bin_labels[inds, labels[inds] - 1] = 1
bin_label_weights = label_weights.view(-1, 1).expand(
- label_weights.size(0), cls_out_channels)
+ label_weights.size(0), label_channels)
return bin_labels, bin_label_weights
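
With the rename from cls_out_channels to label_channels, the expansion above is only triggered for multi-channel (sigmoid) classification. A minimal sketch of what expand_binary_labels does, reusing the function exactly as defined above (labels use 0 for background and 1..C for foreground):

import torch

def expand_binary_labels(labels, label_weights, label_channels):
    # one binary column per foreground class; background rows stay all-zero
    bin_labels = labels.new_full(
        (labels.size(0), label_channels), 0, dtype=torch.float32)
    inds = torch.nonzero(labels >= 1).squeeze()
    if inds.numel() > 0:
        bin_labels[inds, labels[inds] - 1] = 1
    bin_label_weights = label_weights.view(-1, 1).expand(
        label_weights.size(0), label_channels)
    return bin_labels, bin_label_weights

labels = torch.tensor([0, 2, 1])            # 0 = background, 1..C = classes
weights = torch.tensor([1.0, 1.0, 0.0])
bin_labels, bin_weights = expand_binary_labels(labels, weights, 3)
# bin_labels -> [[0, 0, 0], [0, 1, 0], [1, 0, 0]]; bin_weights has shape (3, 3)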

mmdet/models/__init__.py
from .detectors import (BaseDetector, TwoStageDetector, RPN, FastRCNN,
FasterRCNN, MaskRCNN)
- from .builder import (build_neck, build_rpn_head, build_roi_extractor,
+ from .builder import (build_neck, build_anchor_head, build_roi_extractor,
build_bbox_head, build_mask_head, build_detector)
__all__ = [
'BaseDetector', 'TwoStageDetector', 'RPN', 'FastRCNN', 'FasterRCNN',
- 'MaskRCNN', 'build_backbone', 'build_neck', 'build_rpn_head',
+ 'MaskRCNN', 'build_backbone', 'build_neck', 'build_anchor_head',
'build_roi_extractor', 'build_bbox_head', 'build_mask_head',
'build_detector'
]

mmdet/models/anchor_heads/__init__.py
+ from .anchor_head import AnchorHead
+ from .rpn_head import RPNHead
from .retina_head import RetinaHead
from .ssd_head import SSDHead
- __all__ = ['RetinaHead', 'SSDHead']
+ __all__ = ['AnchorHead', 'RPNHead', 'RetinaHead', 'SSDHead']
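
The point of the new package layout: AnchorHead owns the shared anchor machinery (get_anchors, anchor_target wiring, loss, get_bboxes), and the concrete heads mostly override layer construction and the per-level forward pass. A sketch of the resulting hierarchy (assuming an mmdet installation at this revision):

from mmdet.models.anchor_heads import AnchorHead, RPNHead, RetinaHead, SSDHead

# every concrete head inherits AnchorHead's target assignment and loss plumbing
for head_cls in (RPNHead, RetinaHead, SSDHead):
    assert issubclass(head_cls, AnchorHead)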

mmdet/models/anchor_heads/anchor_head.py
@@ -4,113 +4,81 @@ import numpy as np
import torch
import torch.nn as nn
- from mmdet.core import (AnchorGenerator, anchor_target, multi_apply,
- delta2bbox, weighted_smoothl1,
+ from mmdet.core import (AnchorGenerator, anchor_target, delta2bbox,
+ multi_apply, weighted_cross_entropy, weighted_smoothl1,
+ weighted_binary_cross_entropy,
weighted_sigmoid_focal_loss, multiclass_nms)
- from ..utils import normal_init, bias_init_with_prob
+ from ..utils import normal_init
- class RetinaHead(nn.Module):
- """Head of RetinaNet.
- / cls_convs - retina_cls (3x3 conv)
- input -
- \ reg_convs - retina_reg (3x3 conv)
+ class AnchorHead(nn.Module):
+ """Anchor-based head (RPN, RetinaNet, SSD, etc.).
Args:
in_channels (int): Number of channels in the input feature map.
num_classes (int): Class number (including background).
- stacked_convs (int): Number of convolutional layers added for cls and
- reg branch.
- feat_channels (int): Number of channels for the RPN feature map.
- scales_per_octave (int): Number of anchor scales per octave.
- octave_base_scale (int): Base octave scale. Anchor scales are computed
- as `s*2^(i/n)`, for i in [0, n-1], where s is `octave_base_scale`
- and n is `scales_per_octave`.
+ feat_channels (int): Number of channels of the feature map.
+ anchor_scales (Iterable): Anchor scales.
anchor_ratios (Iterable): Anchor aspect ratios.
anchor_strides (Iterable): Anchor strides.
anchor_base_sizes (Iterable): Anchor base sizes.
target_means (Iterable): Mean values of regression targets.
target_stds (Iterable): Std values of regression targets.
+ use_sigmoid_cls (bool): Whether to use sigmoid loss for classification.
+ (softmax by default)
+ use_focal_loss (bool): Whether to use focal loss for classification.
""" # noqa: W605
def __init__(self,
- in_channels,
num_classes,
- stacked_convs=4,
+ in_channels,
feat_channels=256,
- octave_base_scale=4,
- scales_per_octave=3,
+ anchor_scales=[8, 16, 32],
anchor_ratios=[0.5, 1.0, 2.0],
- anchor_strides=[8, 16, 32, 64, 128],
+ anchor_strides=[4, 8, 16, 32, 64],
anchor_base_sizes=None,
target_means=(.0, .0, .0, .0),
- target_stds=(1.0, 1.0, 1.0, 1.0)):
- super(RetinaHead, self).__init__()
+ target_stds=(1.0, 1.0, 1.0, 1.0),
+ use_sigmoid_cls=False,
+ use_focal_loss=False):
+ super(AnchorHead, self).__init__()
self.in_channels = in_channels
self.num_classes = num_classes
- self.octave_base_scale = octave_base_scale
- self.scales_per_octave = scales_per_octave
self.feat_channels = feat_channels
+ self.anchor_scales = anchor_scales
self.anchor_ratios = anchor_ratios
self.anchor_strides = anchor_strides
self.anchor_base_sizes = list(
anchor_strides) if anchor_base_sizes is None else anchor_base_sizes
self.target_means = target_means
self.target_stds = target_stds
+ self.use_sigmoid_cls = use_sigmoid_cls
+ self.use_focal_loss = use_focal_loss
self.anchor_generators = []
for anchor_base in self.anchor_base_sizes:
- octave_scales = np.array(
- [2**(i / scales_per_octave) for i in range(scales_per_octave)])
- anchor_scales = octave_scales * octave_base_scale
self.anchor_generators.append(
AnchorGenerator(anchor_base, anchor_scales, anchor_ratios))
- self.relu = nn.ReLU(inplace=True)
- self.num_anchors = int(
- len(self.anchor_ratios) * self.scales_per_octave)
- self.cls_out_channels = self.num_classes - 1
- self.bbox_pred_dim = 4
- self.stacked_convs = stacked_convs
- self.cls_convs = nn.ModuleList()
- self.reg_convs = nn.ModuleList()
- for i in range(self.stacked_convs):
- chn = in_channels if i == 0 else feat_channels
- self.cls_convs.append(
- nn.Conv2d(chn, feat_channels, 3, stride=1, padding=1))
- self.reg_convs.append(
- nn.Conv2d(chn, feat_channels, 3, stride=1, padding=1))
- self.retina_cls = nn.Conv2d(
- feat_channels,
- self.num_anchors * self.cls_out_channels,
- 3,
- stride=1,
- padding=1)
- self.retina_reg = nn.Conv2d(
- feat_channels,
- self.num_anchors * self.bbox_pred_dim,
- 3,
- stride=1,
- padding=1)
- self.debug_imgs = None
+ self.num_anchors = len(self.anchor_ratios) * len(self.anchor_scales)
+ if self.use_sigmoid_cls:
+ self.cls_out_channels = self.num_classes - 1
+ else:
+ self.cls_out_channels = self.num_classes
+ self._init_layers()
+ def _init_layers(self):
+ self.conv_cls = nn.Conv2d(self.feat_channels,
+ self.num_anchors * self.cls_out_channels, 1)
+ self.conv_reg = nn.Conv2d(self.feat_channels, self.num_anchors * 4, 1)
def init_weights(self):
- for m in self.cls_convs:
- normal_init(m, std=0.01)
- for m in self.reg_convs:
- normal_init(m, std=0.01)
- bias_cls = bias_init_with_prob(0.01)
- normal_init(self.retina_cls, std=0.01, bias=bias_cls)
- normal_init(self.retina_reg, std=0.01)
+ normal_init(self.conv_cls, std=0.01)
+ normal_init(self.conv_reg, std=0.01)
def forward_single(self, x):
- cls_feat = x
- reg_feat = x
- for cls_conv in self.cls_convs:
- cls_feat = self.relu(cls_conv(cls_feat))
- for reg_conv in self.reg_convs:
- reg_feat = self.relu(reg_conv(reg_feat))
- cls_score = self.retina_cls(cls_feat)
- bbox_pred = self.retina_reg(reg_feat)
+ cls_score = self.conv_cls(x)
+ bbox_pred = self.conv_reg(x)
return cls_score, bbox_pred
def forward(self, feats):
@@ -156,30 +124,47 @@ class RetinaHead(nn.Module):
return anchor_list, valid_flag_list
def loss_single(self, cls_score, bbox_pred, labels, label_weights,
- bbox_targets, bbox_weights, num_pos_samples, cfg):
+ bbox_targets, bbox_weights, num_total_samples, cfg):
# classification loss
- labels = labels.contiguous().view(-1, self.cls_out_channels)
- label_weights = label_weights.contiguous().view(
- -1, self.cls_out_channels)
- cls_score = cls_score.permute(0, 2, 3, 1).contiguous().view(
+ if self.use_sigmoid_cls:
+ labels = labels.reshape(-1, self.cls_out_channels)
+ label_weights = label_weights.reshape(-1, self.cls_out_channels)
+ else:
+ labels = labels.reshape(-1)
+ label_weights = label_weights.reshape(-1)
+ cls_score = cls_score.permute(0, 2, 3, 1).reshape(
-1, self.cls_out_channels)
- loss_cls = weighted_sigmoid_focal_loss(
- cls_score,
- labels,
- label_weights,
- cfg.gamma,
- cfg.alpha,
- avg_factor=num_pos_samples)
+ if self.use_sigmoid_cls:
+ if self.use_focal_loss:
+ cls_criterion = weighted_sigmoid_focal_loss
+ else:
+ cls_criterion = weighted_binary_cross_entropy
+ else:
+ if self.use_focal_loss:
+ raise NotImplementedError
+ else:
+ cls_criterion = weighted_cross_entropy
+ if self.use_focal_loss:
+ loss_cls = cls_criterion(
+ cls_score,
+ labels,
+ label_weights,
+ gamma=cfg.gamma,
+ alpha=cfg.alpha,
+ avg_factor=num_total_samples)
+ else:
+ loss_cls = cls_criterion(
+ cls_score, labels, label_weights, avg_factor=num_total_samples)
# regression loss
- bbox_targets = bbox_targets.contiguous().view(-1, 4)
- bbox_weights = bbox_weights.contiguous().view(-1, 4)
- bbox_pred = bbox_pred.permute(0, 2, 3, 1).contiguous().view(-1, 4)
+ bbox_targets = bbox_targets.reshape(-1, 4)
+ bbox_weights = bbox_weights.reshape(-1, 4)
+ bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4)
loss_reg = weighted_smoothl1(
bbox_pred,
bbox_targets,
bbox_weights,
beta=cfg.smoothl1_beta,
- avg_factor=num_pos_samples)
+ avg_factor=num_total_samples)
return loss_cls, loss_reg
def loss(self, cls_scores, bbox_preds, gt_bboxes, gt_labels, img_metas,
@@ -189,6 +174,8 @@ class RetinaHead(nn.Module):
anchor_list, valid_flag_list = self.get_anchors(
featmap_sizes, img_metas)
+ sampling = False if self.use_focal_loss else True
+ label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1
cls_reg_targets = anchor_target(
anchor_list,
valid_flag_list,
@@ -198,13 +185,14 @@ class RetinaHead(nn.Module):
self.target_stds,
cfg,
gt_labels_list=gt_labels,
- cls_out_channels=self.cls_out_channels,
- sampling=False)
+ label_channels=label_channels,
+ sampling=sampling)
if cls_reg_targets is None:
return None
(labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
num_total_pos, num_total_neg) = cls_reg_targets
+ num_total_samples = (num_total_pos if self.use_focal_loss else
+ num_total_pos + num_total_neg)
losses_cls, losses_reg = multi_apply(
self.loss_single,
cls_scores,
@@ -213,16 +201,12 @@ class RetinaHead(nn.Module):
label_weights_list,
bbox_targets_list,
bbox_weights_list,
- num_pos_samples=num_total_pos,
+ num_total_samples=num_total_samples,
cfg=cfg)
return dict(loss_cls=losses_cls, loss_reg=losses_reg)
- def get_det_bboxes(self,
- cls_scores,
- bbox_preds,
- img_metas,
- cfg,
- rescale=False):
+ def get_bboxes(self, cls_scores, bbox_preds, img_metas, cfg,
+ rescale=False):
assert len(cls_scores) == len(bbox_preds)
num_levels = len(cls_scores)
@@ -231,7 +215,6 @@ class RetinaHead(nn.Module):
self.anchor_strides[i])
for i in range(num_levels)
]
result_list = []
for img_id in range(len(img_metas)):
cls_score_list = [
@@ -242,46 +225,54 @@ class RetinaHead(nn.Module):
]
img_shape = img_metas[img_id]['img_shape']
scale_factor = img_metas[img_id]['scale_factor']
- results = self._get_det_bboxes_single(
- cls_score_list, bbox_pred_list, mlvl_anchors, img_shape,
- scale_factor, cfg, rescale)
- result_list.append(results)
+ proposals = self.get_bboxes_single(cls_score_list, bbox_pred_list,
+ mlvl_anchors, img_shape,
+ scale_factor, cfg, rescale)
+ result_list.append(proposals)
return result_list
- def _get_det_bboxes_single(self,
- cls_scores,
- bbox_preds,
- mlvl_anchors,
- img_shape,
- scale_factor,
- cfg,
- rescale=False):
+ def get_bboxes_single(self,
+ cls_scores,
+ bbox_preds,
+ mlvl_anchors,
+ img_shape,
+ scale_factor,
+ cfg,
+ rescale=False):
assert len(cls_scores) == len(bbox_preds) == len(mlvl_anchors)
- mlvl_proposals = []
+ mlvl_bboxes = []
mlvl_scores = []
for cls_score, bbox_pred, anchors in zip(cls_scores, bbox_preds,
mlvl_anchors):
assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
- cls_score = cls_score.permute(1, 2, 0).contiguous().view(
+ cls_score = cls_score.permute(1, 2, 0).reshape(
-1, self.cls_out_channels)
- scores = cls_score.sigmoid()
- bbox_pred = bbox_pred.permute(1, 2, 0).contiguous().view(-1, 4)
- proposals = delta2bbox(anchors, bbox_pred, self.target_means,
- self.target_stds, img_shape)
- if cfg.nms_pre > 0 and scores.shape[0] > cfg.nms_pre:
- maxscores, _ = scores.max(dim=1)
- _, topk_inds = maxscores.topk(cfg.nms_pre)
- proposals = proposals[topk_inds, :]
+ if self.use_sigmoid_cls:
+ scores = cls_score.sigmoid()
+ else:
+ scores = cls_score.softmax(-1)
+ bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4)
+ nms_pre = cfg.get('nms_pre', -1)
+ if nms_pre > 0 and scores.shape[0] > nms_pre:
+ if self.use_sigmoid_cls:
+ max_scores, _ = scores.max(dim=1)
+ else:
+ max_scores, _ = scores[:, 1:].max(dim=1)
+ _, topk_inds = max_scores.topk(nms_pre)
+ anchors = anchors[topk_inds, :]
+ bbox_pred = bbox_pred[topk_inds, :]
scores = scores[topk_inds, :]
- mlvl_proposals.append(proposals)
+ bboxes = delta2bbox(anchors, bbox_pred, self.target_means,
+ self.target_stds, img_shape)
+ mlvl_bboxes.append(bboxes)
mlvl_scores.append(scores)
- mlvl_proposals = torch.cat(mlvl_proposals)
+ mlvl_bboxes = torch.cat(mlvl_bboxes)
if rescale:
- mlvl_proposals /= scale_factor
+ mlvl_bboxes /= mlvl_bboxes.new_tensor(scale_factor)
mlvl_scores = torch.cat(mlvl_scores)
- padding = mlvl_scores.new_zeros(mlvl_scores.shape[0], 1)
- mlvl_scores = torch.cat([padding, mlvl_scores], dim=1)
- det_bboxes, det_labels = multiclass_nms(mlvl_proposals, mlvl_scores,
- cfg.score_thr, cfg.nms,
- cfg.max_per_img)
+ if self.use_sigmoid_cls:
+ padding = mlvl_scores.new_zeros(mlvl_scores.shape[0], 1)
+ mlvl_scores = torch.cat([padding, mlvl_scores], dim=1)
+ det_bboxes, det_labels = multiclass_nms(
+ mlvl_bboxes, mlvl_scores, cfg.score_thr, cfg.nms, cfg.max_per_img)
return det_bboxes, det_labels
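
Two details of the unified head are easy to miss. First, num_total_samples: with focal loss the classification loss is averaged over positive anchors only, otherwise over the sampled positives and negatives. Second, multiclass_nms expects column 0 to be the background class, so heads trained with sigmoid outputs (C foreground columns, no background) pad a zero column in front before NMS; softmax heads already produce a background column and skip it. A standalone sketch of the padding step:

import torch

mlvl_scores = torch.rand(5, 3)      # 5 boxes, C = 3 foreground class scores
padding = mlvl_scores.new_zeros(mlvl_scores.shape[0], 1)
mlvl_scores = torch.cat([padding, mlvl_scores], dim=1)
# shape (5, 4): column 0 is a dummy background score, classes shift to 1..C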

new file: mmdet/models/anchor_heads/retina_head.py
import numpy as np
import torch.nn as nn
from mmcv.cnn import normal_init
from .anchor_head import AnchorHead
from ..utils import bias_init_with_prob
class RetinaHead(AnchorHead):
def __init__(self,
num_classes,
in_channels,
stacked_convs=4,
octave_base_scale=4,
scales_per_octave=3,
**kwargs):
self.stacked_convs = stacked_convs
self.octave_base_scale = octave_base_scale
self.scales_per_octave = scales_per_octave
octave_scales = np.array(
[2**(i / scales_per_octave) for i in range(scales_per_octave)])
anchor_scales = octave_scales * octave_base_scale
super(RetinaHead, self).__init__(
num_classes,
in_channels,
anchor_scales=anchor_scales,
use_sigmoid_cls=True,
use_focal_loss=True,
**kwargs)
def _init_layers(self):
self.relu = nn.ReLU(inplace=True)
self.cls_convs = nn.ModuleList()
self.reg_convs = nn.ModuleList()
for i in range(self.stacked_convs):
chn = self.in_channels if i == 0 else self.feat_channels
self.cls_convs.append(
nn.Conv2d(chn, self.feat_channels, 3, stride=1, padding=1))
self.reg_convs.append(
nn.Conv2d(chn, self.feat_channels, 3, stride=1, padding=1))
self.retina_cls = nn.Conv2d(
self.feat_channels,
self.num_anchors * self.cls_out_channels,
3,
padding=1)
self.retina_reg = nn.Conv2d(
self.feat_channels, self.num_anchors * 4, 3, padding=1)
def init_weights(self):
for m in self.cls_convs:
normal_init(m, std=0.01)
for m in self.reg_convs:
normal_init(m, std=0.01)
bias_cls = bias_init_with_prob(0.01)
normal_init(self.retina_cls, std=0.01, bias=bias_cls)
normal_init(self.retina_reg, std=0.01)
def forward_single(self, x):
cls_feat = x
reg_feat = x
for cls_conv in self.cls_convs:
cls_feat = self.relu(cls_conv(cls_feat))
for reg_conv in self.reg_convs:
reg_feat = self.relu(reg_conv(reg_feat))
cls_score = self.retina_cls(cls_feat)
bbox_pred = self.retina_reg(reg_feat)
return cls_score, bbox_pred
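
RetinaHead now just precomputes its octave-spaced anchor scales and defers everything else to AnchorHead. The scales follow the docstring's formula s * 2^(i/n); for the defaults shown above (octave_base_scale=4, scales_per_octave=3) this gives roughly [4.0, 5.04, 6.35]:

import numpy as np

octave_base_scale, scales_per_octave = 4, 3
octave_scales = np.array(
    [2**(i / scales_per_octave) for i in range(scales_per_octave)])
anchor_scales = octave_scales * octave_base_scale
print(anchor_scales)    # [4.         5.03968...  6.34960...]

The classification bias init via bias_init_with_prob(0.01) follows the RetinaNet recipe of biasing the last conv so the initial foreground probability is about 0.01 (a bias of roughly -log((1 - p) / p)), which keeps the focal loss stable early in training.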

new file: mmdet/models/anchor_heads/rpn_head.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import normal_init
from mmdet.core import delta2bbox
from mmdet.ops import nms
from .anchor_head import AnchorHead
class RPNHead(AnchorHead):
def __init__(self, in_channels, **kwargs):
super(RPNHead, self).__init__(2, in_channels, **kwargs)
def _init_layers(self):
self.rpn_conv = nn.Conv2d(
self.in_channels, self.feat_channels, 3, padding=1)
self.rpn_cls = nn.Conv2d(self.feat_channels,
self.num_anchors * self.cls_out_channels, 1)
self.rpn_reg = nn.Conv2d(self.feat_channels, self.num_anchors * 4, 1)
def init_weights(self):
normal_init(self.rpn_conv, std=0.01)
normal_init(self.rpn_cls, std=0.01)
normal_init(self.rpn_reg, std=0.01)
def forward_single(self, x):
x = self.rpn_conv(x)
x = F.relu(x, inplace=True)
rpn_cls_score = self.rpn_cls(x)
rpn_bbox_pred = self.rpn_reg(x)
return rpn_cls_score, rpn_bbox_pred
def loss(self, cls_scores, bbox_preds, gt_bboxes, img_metas, cfg):
losses = super(RPNHead, self).loss(cls_scores, bbox_preds, gt_bboxes,
None, img_metas, cfg)
return dict(
loss_rpn_cls=losses['loss_cls'], loss_rpn_reg=losses['loss_reg'])
def get_bboxes_single(self,
cls_scores,
bbox_preds,
mlvl_anchors,
img_shape,
scale_factor,
cfg,
rescale=False):
mlvl_proposals = []
for idx in range(len(cls_scores)):
rpn_cls_score = cls_scores[idx]
rpn_bbox_pred = bbox_preds[idx]
assert rpn_cls_score.size()[-2:] == rpn_bbox_pred.size()[-2:]
anchors = mlvl_anchors[idx]
rpn_cls_score = rpn_cls_score.permute(1, 2, 0)
if self.use_sigmoid_cls:
rpn_cls_score = rpn_cls_score.reshape(-1)
scores = rpn_cls_score.sigmoid()
else:
rpn_cls_score = rpn_cls_score.reshape(-1, 2)
scores = rpn_cls_score.softmax(dim=1)[:, 1]
rpn_bbox_pred = rpn_bbox_pred.permute(1, 2, 0).reshape(-1, 4)
if cfg.nms_pre > 0 and scores.shape[0] > cfg.nms_pre:
_, topk_inds = scores.topk(cfg.nms_pre)
rpn_bbox_pred = rpn_bbox_pred[topk_inds, :]
anchors = anchors[topk_inds, :]
scores = scores[topk_inds]
proposals = delta2bbox(anchors, rpn_bbox_pred, self.target_means,
self.target_stds, img_shape)
if cfg.min_bbox_size > 0:
w = proposals[:, 2] - proposals[:, 0] + 1
h = proposals[:, 3] - proposals[:, 1] + 1
valid_inds = torch.nonzero((w >= cfg.min_bbox_size) &
(h >= cfg.min_bbox_size)).squeeze()
proposals = proposals[valid_inds, :]
scores = scores[valid_inds]
proposals = torch.cat([proposals, scores.unsqueeze(-1)], dim=-1)
proposals, _ = nms(proposals, cfg.nms_thr)
proposals = proposals[:cfg.nms_post, :]
mlvl_proposals.append(proposals)
proposals = torch.cat(mlvl_proposals, 0)
if cfg.nms_across_levels:
proposals, _ = nms(proposals, cfg.nms_thr)
proposals = proposals[:cfg.max_num, :]
else:
scores = proposals[:, 4]
num = min(cfg.max_num, proposals.shape[0])
_, topk_inds = scores.topk(num)
proposals = proposals[topk_inds, :]
return proposals
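
get_bboxes_single above handles two objectness parameterizations: one sigmoid channel per anchor, or two softmax channels with the foreground score in column 1. A shape-only sketch for a single feature level, assuming 3 anchors on a 2x2 map:

import torch

num_anchors, h, w = 3, 2, 2
sig_score = torch.randn(num_anchors, h, w)            # use_sigmoid_cls=True
scores_sig = sig_score.permute(1, 2, 0).reshape(-1).sigmoid()
soft_score = torch.randn(num_anchors * 2, h, w)       # use_sigmoid_cls=False
scores_soft = soft_score.permute(1, 2, 0).reshape(-1, 2).softmax(dim=1)[:, 1]
assert scores_sig.shape == scores_soft.shape == (h * w * num_anchors,)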

mmdet/models/anchor_heads/ssd_head.py
from __future__ import division
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import xavier_init
- from mmdet.core import (AnchorGenerator, anchor_target, multi_apply,
- delta2bbox, weighted_smoothl1, multiclass_nms)
+ from mmdet.core import (AnchorGenerator, anchor_target, weighted_smoothl1,
+ multi_apply)
+ from .anchor_head import AnchorHead
- class SSDHead(nn.Module):
+ class SSDHead(AnchorHead):
def __init__(self,
input_size=300,
- in_channels=(512, 1024, 512, 256, 256, 256),
num_classes=81,
+ in_channels=(512, 1024, 512, 256, 256, 256),
anchor_strides=(8, 16, 32, 64, 100, 300),
basesize_ratio_range=(0.1, 0.9),
anchor_ratios=([2], [2, 3], [2, 3], [2, 3], [2], [2]),
target_means=(.0, .0, .0, .0),
target_stds=(1.0, 1.0, 1.0, 1.0)):
- super(SSDHead, self).__init__()
- # construct head
- num_anchors = [len(ratios) * 2 + 2 for ratios in anchor_ratios]
- self.in_channels = in_channels
+ super(AnchorHead, self).__init__()
self.input_size = input_size
self.num_classes = num_classes
+ self.in_channels = in_channels
+ self.cls_out_channels = num_classes
+ num_anchors = [len(ratios) * 2 + 2 for ratios in anchor_ratios]
reg_convs = []
cls_convs = []
for i in range(len(in_channels)):
@@ -88,6 +87,8 @@ class SSDHead(nn.Module):
self.target_means = target_means
self.target_stds = target_stds
+ self.use_sigmoid_cls = False
+ self.use_focal_loss = False
def init_weights(self):
for m in self.modules():
@@ -103,68 +104,28 @@ class SSDHead(nn.Module):
bbox_preds.append(reg_conv(feat))
return cls_scores, bbox_preds
- def get_anchors(self, featmap_sizes, img_metas):
- """Get anchors according to feature map sizes.
- Args:
- featmap_sizes (list[tuple]): Multi-level feature map sizes.
- img_metas (list[dict]): Image meta info.
- Returns:
- tuple: anchors of each image, valid flags of each image
- """
- num_imgs = len(img_metas)
- num_levels = len(featmap_sizes)
- # since feature map sizes of all images are the same, we only compute
- # anchors for one time
- multi_level_anchors = []
- for i in range(num_levels):
- anchors = self.anchor_generators[i].grid_anchors(
- featmap_sizes[i], self.anchor_strides[i])
- multi_level_anchors.append(anchors)
- anchor_list = [multi_level_anchors for _ in range(num_imgs)]
- # for each image, we compute valid flags of multi level anchors
- valid_flag_list = []
- for img_id, img_meta in enumerate(img_metas):
- multi_level_flags = []
- for i in range(num_levels):
- anchor_stride = self.anchor_strides[i]
- feat_h, feat_w = featmap_sizes[i]
- h, w, _ = img_meta['pad_shape']
- valid_feat_h = min(int(np.ceil(h / anchor_stride)), feat_h)
- valid_feat_w = min(int(np.ceil(w / anchor_stride)), feat_w)
- flags = self.anchor_generators[i].valid_flags(
- (feat_h, feat_w), (valid_feat_h, valid_feat_w))
- multi_level_flags.append(flags)
- valid_flag_list.append(multi_level_flags)
- return anchor_list, valid_flag_list
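
The get_anchors method removed here is now inherited from AnchorHead unchanged. Its valid-flag step marks anchors whose grid cells fall outside the padded image as invalid; a minimal numeric sketch, assuming a 16-pixel stride, a 10x10 feature map, and a 130x150 pad_shape:

import numpy as np

anchor_stride, feat_h, feat_w = 16, 10, 10
h, w = 130, 150                                  # img_meta['pad_shape']
valid_feat_h = min(int(np.ceil(h / anchor_stride)), feat_h)  # ceil(8.125) -> 9
valid_feat_w = min(int(np.ceil(w / anchor_stride)), feat_w)  # ceil(9.375) -> 10
# grid rows beyond index 8 on this level carry no image content, so their
# anchors receive valid_flags == 0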
def loss_single(self, cls_score, bbox_pred, labels, label_weights,
- bbox_targets, bbox_weights, num_pos_samples, cfg):
+ bbox_targets, bbox_weights, num_total_samples, cfg):
loss_cls_all = F.cross_entropy(
cls_score, labels, reduction='none') * label_weights
- pos_label_inds = (labels > 0).nonzero().view(-1)
- neg_label_inds = (labels == 0).nonzero().view(-1)
- num_sample_pos = pos_label_inds.size(0)
- num_sample_neg = cfg.neg_pos_ratio * num_sample_pos
- if num_sample_neg > neg_label_inds.size(0):
- num_sample_neg = neg_label_inds.size(0)
- topk_loss_cls_neg, topk_loss_cls_neg_inds = \
- loss_cls_all[neg_label_inds].topk(num_sample_neg)
- loss_cls_pos = loss_cls_all[pos_label_inds].sum()
+ pos_inds = (labels > 0).nonzero().view(-1)
+ neg_inds = (labels == 0).nonzero().view(-1)
+ num_pos_samples = pos_inds.size(0)
+ num_neg_samples = cfg.neg_pos_ratio * num_pos_samples
+ if num_neg_samples > neg_inds.size(0):
+ num_neg_samples = neg_inds.size(0)
+ topk_loss_cls_neg, _ = loss_cls_all[neg_inds].topk(num_neg_samples)
+ loss_cls_pos = loss_cls_all[pos_inds].sum()
loss_cls_neg = topk_loss_cls_neg.sum()
- loss_cls = (loss_cls_pos + loss_cls_neg) / num_pos_samples
+ loss_cls = (loss_cls_pos + loss_cls_neg) / num_total_samples
loss_reg = weighted_smoothl1(
bbox_pred,
bbox_targets,
bbox_weights,
beta=cfg.smoothl1_beta,
- avg_factor=num_pos_samples)
+ avg_factor=num_total_samples)
return loss_cls[None], loss_reg
def loss(self, cls_scores, bbox_preds, gt_bboxes, gt_labels, img_metas,
@@ -193,14 +154,14 @@ class SSDHead(nn.Module):
num_images = len(img_metas)
all_cls_scores = torch.cat([
- s.permute(0, 2, 3, 1).contiguous().view(
+ s.permute(0, 2, 3, 1).reshape(
num_images, -1, self.cls_out_channels) for s in cls_scores
], 1)
all_labels = torch.cat(labels_list, -1).view(num_images, -1)
all_label_weights = torch.cat(label_weights_list, -1).view(
num_images, -1)
all_bbox_preds = torch.cat([
- b.permute(0, 2, 3, 1).contiguous().view(num_images, -1, 4)
+ b.permute(0, 2, 3, 1).reshape(num_images, -1, 4)
for b in bbox_preds
], -2)
all_bbox_targets = torch.cat(bbox_targets_list, -2).view(
@@ -216,68 +177,6 @@ class SSDHead(nn.Module):
all_label_weights,
all_bbox_targets,
all_bbox_weights,
- num_pos_samples=num_total_pos,
+ num_total_samples=num_total_pos,
cfg=cfg)
return dict(loss_cls=losses_cls, loss_reg=losses_reg)
- def get_det_bboxes(self,
- cls_scores,
- bbox_preds,
- img_metas,
- cfg,
- rescale=False):
- assert len(cls_scores) == len(bbox_preds)
- num_levels = len(cls_scores)
- mlvl_anchors = [
- self.anchor_generators[i].grid_anchors(cls_scores[i].size()[-2:],
- self.anchor_strides[i])
- for i in range(num_levels)
- ]
- result_list = []
- for img_id in range(len(img_metas)):
- cls_score_list = [
- cls_scores[i][img_id].detach() for i in range(num_levels)
- ]
- bbox_pred_list = [
- bbox_preds[i][img_id].detach() for i in range(num_levels)
- ]
- img_shape = img_metas[img_id]['img_shape']
- scale_factor = img_metas[img_id]['scale_factor']
- results = self._get_det_bboxes_single(
- cls_score_list, bbox_pred_list, mlvl_anchors, img_shape,
- scale_factor, cfg, rescale)
- result_list.append(results)
- return result_list
- def _get_det_bboxes_single(self,
- cls_scores,
- bbox_preds,
- mlvl_anchors,
- img_shape,
- scale_factor,
- cfg,
- rescale=False):
- assert len(cls_scores) == len(bbox_preds) == len(mlvl_anchors)
- mlvl_proposals = []
- mlvl_scores = []
- for cls_score, bbox_pred, anchors in zip(cls_scores, bbox_preds,
- mlvl_anchors):
- assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
- cls_score = cls_score.permute(1, 2, 0).contiguous().view(
- -1, self.cls_out_channels)
- scores = cls_score.softmax(-1)
- bbox_pred = bbox_pred.permute(1, 2, 0).contiguous().view(-1, 4)
- proposals = delta2bbox(anchors, bbox_pred, self.target_means,
- self.target_stds, img_shape)
- mlvl_proposals.append(proposals)
- mlvl_scores.append(scores)
- mlvl_proposals = torch.cat(mlvl_proposals)
- if rescale:
- mlvl_proposals /= mlvl_proposals.new_tensor(scale_factor)
- mlvl_scores = torch.cat(mlvl_scores)
- det_bboxes, det_labels = multiclass_nms(mlvl_proposals, mlvl_scores,
- cfg.score_thr, cfg.nms,
- cfg.max_per_img)
- return det_bboxes, det_labels
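
The rewritten loss_single keeps SSD's online hard negative mining under clearer names: all positives are kept, and the hardest negatives are selected by classification loss, capped at cfg.neg_pos_ratio times the positive count. A self-contained sketch with a ratio of 3:

import torch

loss_cls_all = torch.tensor([0.9, 0.1, 0.4, 0.8, 0.2, 0.7])  # per-anchor CE
labels = torch.tensor([2, 0, 0, 1, 0, 0])                    # 0 = background
pos_inds = (labels > 0).nonzero().view(-1)
neg_inds = (labels == 0).nonzero().view(-1)
num_neg_samples = min(3 * pos_inds.numel(), neg_inds.numel())
topk_loss_cls_neg, _ = loss_cls_all[neg_inds].topk(num_neg_samples)
# averaged over positives, matching num_total_samples=num_total_pos in loss()
loss_cls = (loss_cls_all[pos_inds].sum() + topk_loss_cls_neg.sum()) / pos_inds.numel()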

mmdet/models/builder.py
from mmcv.runner import obj_from_dict
from torch import nn
- from . import (backbones, necks, roi_extractors, rpn_heads, bbox_heads,
- mask_heads, single_stage_heads)
- __all__ = [
- 'build_backbone', 'build_neck', 'build_rpn_head', 'build_roi_extractor',
- 'build_bbox_head', 'build_mask_head', 'build_single_stage_head',
- 'build_detector'
- ]
+ from . import (backbones, necks, roi_extractors, anchor_heads, bbox_heads,
+ mask_heads)
def _build_module(cfg, parrent=None, default_args=None):
@@ -32,8 +26,8 @@ def build_neck(cfg):
return build(cfg, necks)
- def build_rpn_head(cfg):
- return build(cfg, rpn_heads)
+ def build_anchor_head(cfg):
+ return build(cfg, anchor_heads)
def build_roi_extractor(cfg):
@@ -48,10 +42,6 @@ def build_mask_head(cfg):
return build(cfg, mask_heads)
- def build_single_stage_head(cfg):
- return build(cfg, single_stage_heads)
def build_detector(cfg, train_cfg=None, test_cfg=None):
from . import detectors
return build(cfg, detectors, dict(train_cfg=train_cfg, test_cfg=test_cfg))
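
With build_rpn_head and build_single_stage_head folded into build_anchor_head, RPN configs and single-stage head configs are built the same way; the config's type field picks the class out of the anchor_heads module. A hypothetical config sketch (field values illustrative, not taken from this diff):

rpn_head_cfg = dict(
    type='RPNHead',            # resolved to anchor_heads.RPNHead
    in_channels=256,
    feat_channels=256,
    anchor_scales=[8],
    anchor_ratios=[0.5, 1.0, 2.0],
    anchor_strides=[4, 8, 16, 32, 64],
    use_sigmoid_cls=True)
# from mmdet.models import builder
# rpn_head = builder.build_anchor_head(rpn_head_cfg)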

mmdet/models/detectors/cascade_rcnn.py
@@ -37,7 +37,7 @@ class CascadeRCNN(BaseDetector, RPNTestMixin):
raise NotImplementedError
if rpn_head is not None:
- self.rpn_head = builder.build_rpn_head(rpn_head)
+ self.rpn_head = builder.build_anchor_head(rpn_head)
if bbox_head is not None:
self.bbox_roi_extractor = nn.ModuleList()
@@ -123,7 +123,7 @@ class CascadeRCNN(BaseDetector, RPNTestMixin):
losses.update(rpn_losses)
proposal_inputs = rpn_outs + (img_meta, self.test_cfg.rpn)
- proposal_list = self.rpn_head.get_proposals(*proposal_inputs)
+ proposal_list = self.rpn_head.get_bboxes(*proposal_inputs)
else:
proposal_list = proposals

mmdet/models/detectors/rpn.py
@@ -18,7 +18,7 @@ class RPN(BaseDetector, RPNTestMixin):
super(RPN, self).__init__()
self.backbone = builder.build_backbone(backbone)
self.neck = builder.build_neck(neck) if neck is not None else None
- self.rpn_head = builder.build_rpn_head(rpn_head)
+ self.rpn_head = builder.build_anchor_head(rpn_head)
self.train_cfg = train_cfg
self.test_cfg = test_cfg
self.init_weights(pretrained=pretrained)

mmdet/models/detectors/single_stage.py
@@ -18,7 +18,7 @@ class SingleStageDetector(BaseDetector):
self.backbone = builder.build_backbone(backbone)
if neck is not None:
self.neck = builder.build_neck(neck)
- self.bbox_head = builder.build_single_stage_head(bbox_head)
+ self.bbox_head = builder.build_anchor_head(bbox_head)
self.train_cfg = train_cfg
self.test_cfg = test_cfg
self.init_weights(pretrained=pretrained)
@@ -51,7 +51,7 @@ class SingleStageDetector(BaseDetector):
x = self.extract_feat(img)
outs = self.bbox_head(x)
bbox_inputs = outs + (img_meta, self.test_cfg, rescale)
- bbox_list = self.bbox_head.get_det_bboxes(*bbox_inputs)
+ bbox_list = self.bbox_head.get_bboxes(*bbox_inputs)
bbox_results = [
bbox2result(det_bboxes, det_labels, self.bbox_head.num_classes)
for det_bboxes, det_labels in bbox_list

mmdet/models/detectors/test_mixins.py
@@ -7,7 +7,7 @@ class RPNTestMixin(object):
def simple_test_rpn(self, x, img_meta, rpn_test_cfg):
rpn_outs = self.rpn_head(x)
proposal_inputs = rpn_outs + (img_meta, rpn_test_cfg)
- proposal_list = self.rpn_head.get_proposals(*proposal_inputs)
+ proposal_list = self.rpn_head.get_bboxes(*proposal_inputs)
return proposal_list
def aug_test_rpn(self, feats, img_metas, rpn_test_cfg):

mmdet/models/detectors/two_stage.py
@@ -30,7 +30,7 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
raise NotImplementedError
if rpn_head is not None:
- self.rpn_head = builder.build_rpn_head(rpn_head)
+ self.rpn_head = builder.build_anchor_head(rpn_head)
if bbox_head is not None:
self.bbox_roi_extractor = builder.build_roi_extractor(
@@ -96,7 +96,7 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
losses.update(rpn_losses)
proposal_inputs = rpn_outs + (img_meta, self.test_cfg.rpn)
- proposal_list = self.rpn_head.get_proposals(*proposal_inputs)
+ proposal_list = self.rpn_head.get_bboxes(*proposal_inputs)
else:
proposal_list = proposals
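
The renames above (get_proposals to get_bboxes, get_det_bboxes to get_bboxes) work because all detectors pass head outputs the same way: forward returns a tuple of per-level lists, test-time arguments are appended, and the combined tuple unpacks positionally. A toy sketch of the convention (the stand-in function and values are hypothetical):

def get_bboxes(cls_scores, bbox_preds, img_meta, cfg):   # stand-in only
    return ['proposals for each image']

rpn_outs = (['cls level 0'], ['reg level 0'])   # (cls_scores, bbox_preds)
proposal_inputs = rpn_outs + (['img meta'], dict(nms_pre=2000))
proposal_list = get_bboxes(*proposal_inputs)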

deleted file: mmdet/models/rpn_heads/__init__.py
- from .rpn_head import RPNHead
- __all__ = ['RPNHead']

deleted file: mmdet/models/rpn_heads/rpn_head.py
from __future__ import division
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmdet.core import (AnchorGenerator, anchor_target, delta2bbox,
multi_apply, weighted_cross_entropy, weighted_smoothl1,
weighted_binary_cross_entropy)
from mmdet.ops import nms
from ..utils import normal_init
class RPNHead(nn.Module):
"""Network head of RPN.
/ - rpn_cls (1x1 conv)
input - rpn_conv (3x3 conv) -
\ - rpn_reg (1x1 conv)
Args:
in_channels (int): Number of channels in the input feature map.
feat_channels (int): Number of channels for the RPN feature map.
anchor_scales (Iterable): Anchor scales.
anchor_ratios (Iterable): Anchor aspect ratios.
anchor_strides (Iterable): Anchor strides.
anchor_base_sizes (Iterable): Anchor base sizes.
target_means (Iterable): Mean values of regression targets.
target_stds (Iterable): Std values of regression targets.
use_sigmoid_cls (bool): Whether to use sigmoid loss for classification.
(softmax by default)
""" # noqa: W605
def __init__(self,
in_channels,
feat_channels=256,
anchor_scales=[8, 16, 32],
anchor_ratios=[0.5, 1.0, 2.0],
anchor_strides=[4, 8, 16, 32, 64],
anchor_base_sizes=None,
target_means=(.0, .0, .0, .0),
target_stds=(1.0, 1.0, 1.0, 1.0),
use_sigmoid_cls=False):
super(RPNHead, self).__init__()
self.in_channels = in_channels
self.feat_channels = feat_channels
self.anchor_scales = anchor_scales
self.anchor_ratios = anchor_ratios
self.anchor_strides = anchor_strides
self.anchor_base_sizes = list(
anchor_strides) if anchor_base_sizes is None else anchor_base_sizes
self.target_means = target_means
self.target_stds = target_stds
self.use_sigmoid_cls = use_sigmoid_cls
self.anchor_generators = []
for anchor_base in self.anchor_base_sizes:
self.anchor_generators.append(
AnchorGenerator(anchor_base, anchor_scales, anchor_ratios))
self.rpn_conv = nn.Conv2d(in_channels, feat_channels, 3, padding=1)
self.relu = nn.ReLU(inplace=True)
self.num_anchors = len(self.anchor_ratios) * len(self.anchor_scales)
out_channels = (self.num_anchors
if self.use_sigmoid_cls else self.num_anchors * 2)
self.rpn_cls = nn.Conv2d(feat_channels, out_channels, 1)
self.rpn_reg = nn.Conv2d(feat_channels, self.num_anchors * 4, 1)
self.debug_imgs = None
def init_weights(self):
normal_init(self.rpn_conv, std=0.01)
normal_init(self.rpn_cls, std=0.01)
normal_init(self.rpn_reg, std=0.01)
def forward_single(self, x):
rpn_feat = self.relu(self.rpn_conv(x))
rpn_cls_score = self.rpn_cls(rpn_feat)
rpn_bbox_pred = self.rpn_reg(rpn_feat)
return rpn_cls_score, rpn_bbox_pred
def forward(self, feats):
return multi_apply(self.forward_single, feats)
def get_anchors(self, featmap_sizes, img_metas):
"""Get anchors according to feature map sizes.
Args:
featmap_sizes (list[tuple]): Multi-level feature map sizes.
img_metas (list[dict]): Image meta info.
Returns:
tuple: anchors of each image, valid flags of each image
"""
num_imgs = len(img_metas)
num_levels = len(featmap_sizes)
# since feature map sizes of all images are the same, we only compute
# anchors for one time
multi_level_anchors = []
for i in range(num_levels):
anchors = self.anchor_generators[i].grid_anchors(
featmap_sizes[i], self.anchor_strides[i])
multi_level_anchors.append(anchors)
anchor_list = [multi_level_anchors for _ in range(num_imgs)]
# for each image, we compute valid flags of multi level anchors
valid_flag_list = []
for img_id, img_meta in enumerate(img_metas):
multi_level_flags = []
for i in range(num_levels):
anchor_stride = self.anchor_strides[i]
feat_h, feat_w = featmap_sizes[i]
h, w, _ = img_meta['pad_shape']
valid_feat_h = min(int(np.ceil(h / anchor_stride)), feat_h)
valid_feat_w = min(int(np.ceil(w / anchor_stride)), feat_w)
flags = self.anchor_generators[i].valid_flags(
(feat_h, feat_w), (valid_feat_h, valid_feat_w))
multi_level_flags.append(flags)
valid_flag_list.append(multi_level_flags)
return anchor_list, valid_flag_list
def loss_single(self, rpn_cls_score, rpn_bbox_pred, labels, label_weights,
bbox_targets, bbox_weights, num_total_samples, cfg):
# classification loss
labels = labels.contiguous().view(-1)
label_weights = label_weights.contiguous().view(-1)
if self.use_sigmoid_cls:
rpn_cls_score = rpn_cls_score.permute(0, 2, 3,
1).contiguous().view(-1)
criterion = weighted_binary_cross_entropy
else:
rpn_cls_score = rpn_cls_score.permute(0, 2, 3,
1).contiguous().view(-1, 2)
criterion = weighted_cross_entropy
loss_cls = criterion(
rpn_cls_score, labels, label_weights, avg_factor=num_total_samples)
# regression loss
bbox_targets = bbox_targets.contiguous().view(-1, 4)
bbox_weights = bbox_weights.contiguous().view(-1, 4)
rpn_bbox_pred = rpn_bbox_pred.permute(0, 2, 3, 1).contiguous().view(
-1, 4)
loss_reg = weighted_smoothl1(
rpn_bbox_pred,
bbox_targets,
bbox_weights,
beta=cfg.smoothl1_beta,
avg_factor=num_total_samples)
return loss_cls, loss_reg
def loss(self, rpn_cls_scores, rpn_bbox_preds, gt_bboxes, img_shapes, cfg):
featmap_sizes = [featmap.size()[-2:] for featmap in rpn_cls_scores]
assert len(featmap_sizes) == len(self.anchor_generators)
anchor_list, valid_flag_list = self.get_anchors(
featmap_sizes, img_shapes)
cls_reg_targets = anchor_target(
anchor_list, valid_flag_list, gt_bboxes, img_shapes,
self.target_means, self.target_stds, cfg)
if cls_reg_targets is None:
return None
(labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
num_total_pos, num_total_neg) = cls_reg_targets
losses_cls, losses_reg = multi_apply(
self.loss_single,
rpn_cls_scores,
rpn_bbox_preds,
labels_list,
label_weights_list,
bbox_targets_list,
bbox_weights_list,
num_total_samples=num_total_pos + num_total_neg,
cfg=cfg)
return dict(loss_rpn_cls=losses_cls, loss_rpn_reg=losses_reg)
def get_proposals(self, rpn_cls_scores, rpn_bbox_preds, img_meta, cfg):
num_imgs = len(img_meta)
featmap_sizes = [featmap.size()[-2:] for featmap in rpn_cls_scores]
mlvl_anchors = [
self.anchor_generators[idx].grid_anchors(featmap_sizes[idx],
self.anchor_strides[idx])
for idx in range(len(featmap_sizes))
]
proposal_list = []
for img_id in range(num_imgs):
rpn_cls_score_list = [
rpn_cls_scores[idx][img_id].detach()
for idx in range(len(rpn_cls_scores))
]
rpn_bbox_pred_list = [
rpn_bbox_preds[idx][img_id].detach()
for idx in range(len(rpn_bbox_preds))
]
assert len(rpn_cls_score_list) == len(rpn_bbox_pred_list)
proposals = self._get_proposals_single(
rpn_cls_score_list, rpn_bbox_pred_list, mlvl_anchors,
img_meta[img_id]['img_shape'], cfg)
proposal_list.append(proposals)
return proposal_list
def _get_proposals_single(self, rpn_cls_scores, rpn_bbox_preds,
mlvl_anchors, img_shape, cfg):
mlvl_proposals = []
for idx in range(len(rpn_cls_scores)):
rpn_cls_score = rpn_cls_scores[idx]
rpn_bbox_pred = rpn_bbox_preds[idx]
assert rpn_cls_score.size()[-2:] == rpn_bbox_pred.size()[-2:]
anchors = mlvl_anchors[idx]
if self.use_sigmoid_cls:
rpn_cls_score = rpn_cls_score.permute(1, 2,
0).contiguous().view(-1)
rpn_cls_prob = rpn_cls_score.sigmoid()
scores = rpn_cls_prob
else:
rpn_cls_score = rpn_cls_score.permute(1, 2,
0).contiguous().view(
-1, 2)
rpn_cls_prob = F.softmax(rpn_cls_score, dim=1)
scores = rpn_cls_prob[:, 1]
rpn_bbox_pred = rpn_bbox_pred.permute(1, 2, 0).contiguous().view(
-1, 4)
_, order = scores.sort(0, descending=True)
if cfg.nms_pre > 0:
order = order[:cfg.nms_pre]
rpn_bbox_pred = rpn_bbox_pred[order, :]
anchors = anchors[order, :]
scores = scores[order]
proposals = delta2bbox(anchors, rpn_bbox_pred, self.target_means,
self.target_stds, img_shape)
w = proposals[:, 2] - proposals[:, 0] + 1
h = proposals[:, 3] - proposals[:, 1] + 1
valid_inds = torch.nonzero((w >= cfg.min_bbox_size) &
(h >= cfg.min_bbox_size)).squeeze()
proposals = proposals[valid_inds, :]
scores = scores[valid_inds]
proposals = torch.cat([proposals, scores.unsqueeze(-1)], dim=-1)
proposals, _ = nms(proposals, cfg.nms_thr)
proposals = proposals[:cfg.nms_post, :]
mlvl_proposals.append(proposals)
proposals = torch.cat(mlvl_proposals, 0)
if cfg.nms_across_levels:
proposals, _ = nms(proposals, cfg.nms_thr)
proposals = proposals[:cfg.max_num, :]
else:
scores = proposals[:, 4]
_, order = scores.sort(0, descending=True)
num = min(cfg.max_num, proposals.shape[0])
order = order[:num]
proposals = proposals[order, :]
return proposals
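
Compared with this deleted implementation, the new RPNHead.get_bboxes_single replaces sort-then-slice with topk when selecting the nms_pre highest-scoring anchors; the two are equivalent for the returned scores. A quick check:

import torch

scores = torch.rand(1000)
nms_pre = 100
_, order = scores.sort(0, descending=True)   # old: sort, then slice
_, topk_inds = scores.topk(nms_pre)          # new: topk directly
assert torch.equal(scores[order[:nms_pre]], scores[topk_inds])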