Commit 20e75c22 authored by Kai Chen's avatar Kai Chen
Browse files

use anchor_target in RetinaNet

parent 45af4242
from .anchor import * # noqa: F401, F403 from .anchor import * # noqa: F401, F403
from .bbox_ops import * # noqa: F401, F403 from .bbox_ops import * # noqa: F401, F403
from .mask_ops import * # noqa: F401, F403 from .mask_ops import * # noqa: F401, F403
from .targets import * # noqa: F401, F403
from .losses import * # noqa: F401, F403 from .losses import * # noqa: F401, F403
from .eval import * # noqa: F401, F403 from .eval import * # noqa: F401, F403
from .parallel import * # noqa: F401, F403 from .parallel import * # noqa: F401, F403
......
...@@ -4,8 +4,16 @@ from ..bbox_ops import bbox_assign, bbox2delta, bbox_sampling ...@@ -4,8 +4,16 @@ from ..bbox_ops import bbox_assign, bbox2delta, bbox_sampling
from ..utils import multi_apply from ..utils import multi_apply
def anchor_target(anchor_list, valid_flag_list, gt_bboxes_list, img_metas, def anchor_target(anchor_list,
target_means, target_stds, cfg): valid_flag_list,
gt_bboxes_list,
img_metas,
target_means,
target_stds,
cfg,
gt_labels_list=None,
cls_out_channels=1,
sampling=True):
"""Compute regression and classification targets for anchors. """Compute regression and classification targets for anchors.
Args: Args:
...@@ -32,28 +40,34 @@ def anchor_target(anchor_list, valid_flag_list, gt_bboxes_list, img_metas, ...@@ -32,28 +40,34 @@ def anchor_target(anchor_list, valid_flag_list, gt_bboxes_list, img_metas,
valid_flag_list[i] = torch.cat(valid_flag_list[i]) valid_flag_list[i] = torch.cat(valid_flag_list[i])
# compute targets for each image # compute targets for each image
means_replicas = [target_means for _ in range(num_imgs)] if gt_labels_list is None:
stds_replicas = [target_stds for _ in range(num_imgs)] gt_labels_list = [None for _ in range(num_imgs)]
cfg_replicas = [cfg for _ in range(num_imgs)] (all_labels, all_label_weights, all_bbox_targets, all_bbox_weights,
(all_labels, all_label_weights, all_bbox_targets, pos_inds_list, neg_inds_list) = multi_apply(
all_bbox_weights, pos_inds_list, neg_inds_list) = multi_apply( anchor_target_single,
anchor_target_single, anchor_list, valid_flag_list, gt_bboxes_list, anchor_list,
img_metas, means_replicas, stds_replicas, cfg_replicas) valid_flag_list,
gt_bboxes_list,
gt_labels_list,
img_metas,
target_means=target_means,
target_stds=target_stds,
cfg=cfg,
cls_out_channels=cls_out_channels,
sampling=sampling)
# no valid anchors # no valid anchors
if any([labels is None for labels in all_labels]): if any([labels is None for labels in all_labels]):
return None return None
# sampled anchors of all images # sampled anchors of all images
num_total_samples = sum([ num_total_pos = sum([max(inds.numel(), 1) for inds in pos_inds_list])
max(pos_inds.numel() + neg_inds.numel(), 1) num_total_neg = sum([max(inds.numel(), 1) for inds in neg_inds_list])
for pos_inds, neg_inds in zip(pos_inds_list, neg_inds_list)
])
# split targets to a list w.r.t. multiple levels # split targets to a list w.r.t. multiple levels
labels_list = images_to_levels(all_labels, num_level_anchors) labels_list = images_to_levels(all_labels, num_level_anchors)
label_weights_list = images_to_levels(all_label_weights, num_level_anchors) label_weights_list = images_to_levels(all_label_weights, num_level_anchors)
bbox_targets_list = images_to_levels(all_bbox_targets, num_level_anchors) bbox_targets_list = images_to_levels(all_bbox_targets, num_level_anchors)
bbox_weights_list = images_to_levels(all_bbox_weights, num_level_anchors) bbox_weights_list = images_to_levels(all_bbox_weights, num_level_anchors)
return (labels_list, label_weights_list, bbox_targets_list, return (labels_list, label_weights_list, bbox_targets_list,
bbox_weights_list, num_total_samples) bbox_weights_list, num_total_pos, num_total_neg)
def images_to_levels(target, num_level_anchors): def images_to_levels(target, num_level_anchors):
...@@ -71,8 +85,16 @@ def images_to_levels(target, num_level_anchors): ...@@ -71,8 +85,16 @@ def images_to_levels(target, num_level_anchors):
return level_targets return level_targets
def anchor_target_single(flat_anchors, valid_flags, gt_bboxes, img_meta, def anchor_target_single(flat_anchors,
target_means, target_stds, cfg): valid_flags,
gt_bboxes,
gt_labels,
img_meta,
target_means,
target_stds,
cfg,
cls_out_channels=1,
sampling=True):
inside_flags = anchor_inside_flags(flat_anchors, valid_flags, inside_flags = anchor_inside_flags(flat_anchors, valid_flags,
img_meta['img_shape'][:2], img_meta['img_shape'][:2],
cfg.allowed_border) cfg.allowed_border)
...@@ -86,10 +108,14 @@ def anchor_target_single(flat_anchors, valid_flags, gt_bboxes, img_meta, ...@@ -86,10 +108,14 @@ def anchor_target_single(flat_anchors, valid_flags, gt_bboxes, img_meta,
pos_iou_thr=cfg.pos_iou_thr, pos_iou_thr=cfg.pos_iou_thr,
neg_iou_thr=cfg.neg_iou_thr, neg_iou_thr=cfg.neg_iou_thr,
min_pos_iou=cfg.min_pos_iou) min_pos_iou=cfg.min_pos_iou)
pos_inds, neg_inds = bbox_sampling(assigned_gt_inds, cfg.anchor_batch_size, if sampling:
cfg.pos_fraction, cfg.neg_pos_ub, pos_inds, neg_inds = bbox_sampling(
cfg.pos_balance_sampling, max_overlaps, assigned_gt_inds, cfg.anchor_batch_size, cfg.pos_fraction,
cfg.neg_balance_thr) cfg.neg_pos_ub, cfg.pos_balance_sampling, max_overlaps,
cfg.neg_balance_thr)
else:
pos_inds = torch.nonzero(assigned_gt_inds > 0).squeeze(-1).unique()
neg_inds = torch.nonzero(assigned_gt_inds == 0).squeeze(-1).unique()
bbox_targets = torch.zeros_like(anchors) bbox_targets = torch.zeros_like(anchors)
bbox_weights = torch.zeros_like(anchors) bbox_weights = torch.zeros_like(anchors)
...@@ -103,7 +129,10 @@ def anchor_target_single(flat_anchors, valid_flags, gt_bboxes, img_meta, ...@@ -103,7 +129,10 @@ def anchor_target_single(flat_anchors, valid_flags, gt_bboxes, img_meta,
target_stds) target_stds)
bbox_targets[pos_inds, :] = pos_bbox_targets bbox_targets[pos_inds, :] = pos_bbox_targets
bbox_weights[pos_inds, :] = 1.0 bbox_weights[pos_inds, :] = 1.0
labels[pos_inds] = 1 if gt_labels is None:
labels[pos_inds] = 1
else:
labels[pos_inds] = gt_labels[assigned_gt_inds[pos_inds] - 1]
if cfg.pos_weight <= 0: if cfg.pos_weight <= 0:
label_weights[pos_inds] = 1.0 label_weights[pos_inds] = 1.0
else: else:
...@@ -115,6 +144,9 @@ def anchor_target_single(flat_anchors, valid_flags, gt_bboxes, img_meta, ...@@ -115,6 +144,9 @@ def anchor_target_single(flat_anchors, valid_flags, gt_bboxes, img_meta,
num_total_anchors = flat_anchors.size(0) num_total_anchors = flat_anchors.size(0)
labels = unmap(labels, num_total_anchors, inside_flags) labels = unmap(labels, num_total_anchors, inside_flags)
label_weights = unmap(label_weights, num_total_anchors, inside_flags) label_weights = unmap(label_weights, num_total_anchors, inside_flags)
if cls_out_channels > 1:
labels, label_weights = expand_binary_labels(labels, label_weights,
cls_out_channels)
bbox_targets = unmap(bbox_targets, num_total_anchors, inside_flags) bbox_targets = unmap(bbox_targets, num_total_anchors, inside_flags)
bbox_weights = unmap(bbox_weights, num_total_anchors, inside_flags) bbox_weights = unmap(bbox_weights, num_total_anchors, inside_flags)
...@@ -122,6 +154,17 @@ def anchor_target_single(flat_anchors, valid_flags, gt_bboxes, img_meta, ...@@ -122,6 +154,17 @@ def anchor_target_single(flat_anchors, valid_flags, gt_bboxes, img_meta,
neg_inds) neg_inds)
def expand_binary_labels(labels, label_weights, cls_out_channels):
    """Expand index labels into one-hot binary classification targets.

    Labels use 0 for background and k >= 1 for foreground class k; each
    foreground label becomes a 1 at column ``label - 1`` of a
    (num_anchors, cls_out_channels) float32 tensor, while background rows
    stay all-zero. Label weights are broadcast across the class dimension.
    """
    bin_labels = labels.new_full(
        (labels.size(0), cls_out_channels), 0, dtype=torch.float32)
    # indices of foreground anchors (label >= 1)
    inds = torch.nonzero(labels >= 1).squeeze()
    if inds.numel() > 0:
        # class k (1-based) maps to column k - 1
        bin_labels[inds, labels[inds] - 1] = 1
    # replicate each anchor's scalar weight over every class channel
    bin_label_weights = label_weights.view(-1, 1).expand(
        label_weights.size(0), cls_out_channels)
    return bin_labels, bin_label_weights
def anchor_inside_flags(flat_anchors, valid_flags, img_shape, def anchor_inside_flags(flat_anchors, valid_flags, img_shape,
allowed_border=0): allowed_border=0):
img_h, img_w = img_shape[:2] img_h, img_w = img_shape[:2]
......
from .retina_target import retina_target
import torch
from ..bbox_ops import bbox_assign, bbox2delta
from ..utils import multi_apply
def retina_target(anchor_list, valid_flag_list, gt_bboxes_list, gt_labels_list,
                  img_metas, target_means, target_stds, cls_out_channels, cfg):
    """Compute regression and classification targets for anchors.

    Note: ``anchor_list`` and ``valid_flag_list`` are modified in place —
    the per-level tensors of each image are concatenated into one tensor.

    Args:
        anchor_list (list[list]): Multi level anchors of each image.
        valid_flag_list (list[list]): Multi level valid flags of each image.
        gt_bboxes_list (list[Tensor]): Ground truth bboxes of each image.
        gt_labels_list (list[Tensor]): Ground truth labels of each image.
        img_metas (list[dict]): Meta info of each image.
        target_means (Iterable): Mean value of regression targets.
        target_stds (Iterable): Std value of regression targets.
        cls_out_channels (int): Number of classification output channels.
        cfg (dict): RPN train configs.

    Returns:
        tuple or None: Per-level labels, label weights, bbox targets and
            bbox weights plus the total positive-sample count, or None if
            any image has no valid anchors.
    """
    num_imgs = len(img_metas)
    assert len(anchor_list) == len(valid_flag_list) == num_imgs
    # anchor number of multi levels
    num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]]
    # concat all level anchors and flags to a single tensor
    for i in range(num_imgs):
        assert len(anchor_list[i]) == len(valid_flag_list[i])
        anchor_list[i] = torch.cat(anchor_list[i])
        valid_flag_list[i] = torch.cat(valid_flag_list[i])
    # compute targets for each image
    (all_labels, all_label_weights, all_bbox_targets, all_bbox_weights,
     pos_inds_list, neg_inds_list) = multi_apply(
         retina_target_single,
         anchor_list,
         valid_flag_list,
         gt_bboxes_list,
         gt_labels_list,
         img_metas,
         target_means=target_means,
         target_stds=target_stds,
         cls_out_channels=cls_out_channels,
         cfg=cfg)
    # no valid anchors
    if any([labels is None for labels in all_labels]):
        return None
    # sampled anchors of all images; clamp each image's count to 1 so the
    # downstream normalization factor can never be zero
    num_pos_samples = sum(
        [max(pos_inds.numel(), 1) for pos_inds in pos_inds_list])
    # split targets to a list w.r.t. multiple levels
    labels_list = images_to_levels(all_labels, num_level_anchors)
    label_weights_list = images_to_levels(all_label_weights, num_level_anchors)
    bbox_targets_list = images_to_levels(all_bbox_targets, num_level_anchors)
    bbox_weights_list = images_to_levels(all_bbox_weights, num_level_anchors)
    return (labels_list, label_weights_list, bbox_targets_list,
            bbox_weights_list, num_pos_samples)
def images_to_levels(target, num_level_anchors):
    """Convert targets by image to targets by feature level.

    [target_img0, target_img1] -> [target_level0, target_level1, ...]
    Each level target has shape (num_imgs, level_size, ...); for a single
    image the leading dimension is squeezed away.
    """
    stacked = torch.stack(target, 0)
    level_targets = []
    start = 0
    for level_size in num_level_anchors:
        stop = start + level_size
        level_targets.append(stacked[:, start:stop].squeeze(0))
        start = stop
    return level_targets
def retina_target_single(flat_anchors, valid_flags, gt_bboxes, gt_labels,
                         img_meta, target_means, target_stds, cls_out_channels,
                         cfg):
    """Compute classification and regression targets for one image.

    Anchors outside the (border-extended) image are dropped before
    assignment, then results are mapped back to the full anchor set via
    ``unmap``. Returns (labels, label_weights, bbox_targets, bbox_weights,
    pos_inds, neg_inds), or six ``None``s when no anchor is inside.
    """
    inside_flags = anchor_inside_flags(flat_anchors, valid_flags,
                                       img_meta['img_shape'][:2],
                                       cfg.allowed_border)
    if not inside_flags.any():
        return (None, ) * 6
    # assign gt and sample anchors
    anchors = flat_anchors[inside_flags, :]
    assigned_gt_inds, argmax_overlaps, max_overlaps = bbox_assign(
        anchors,
        gt_bboxes,
        pos_iou_thr=cfg.pos_iou_thr,
        neg_iou_thr=cfg.neg_iou_thr,
        min_pos_iou=cfg.min_pos_iou)
    # assigned_gt_inds: 0 = negative, k >= 1 = matched to gt box k-1
    # (ignored anchors are neither > 0 nor == 0 and stay unweighted)
    pos_inds = torch.nonzero(assigned_gt_inds > 0)
    neg_inds = torch.nonzero(assigned_gt_inds == 0)
    bbox_targets = torch.zeros_like(anchors)
    bbox_weights = torch.zeros_like(anchors)
    labels = torch.zeros_like(assigned_gt_inds)
    label_weights = torch.zeros_like(assigned_gt_inds, dtype=anchors.dtype)
    if len(pos_inds) > 0:
        # flatten the (n, 1) nonzero output; this reassigned form is what
        # gets returned below
        pos_inds = pos_inds.squeeze(1).unique()
        pos_anchors = anchors[pos_inds, :]
        # assigned_gt_inds is 1-based, hence the -1 when indexing gts
        pos_gt_bbox = gt_bboxes[assigned_gt_inds[pos_inds] - 1, :]
        pos_bbox_targets = bbox2delta(pos_anchors, pos_gt_bbox, target_means,
                                      target_stds)
        bbox_targets[pos_inds, :] = pos_bbox_targets
        bbox_weights[pos_inds, :] = 1.0
        labels[pos_inds] = gt_labels[assigned_gt_inds[pos_inds] - 1]
        if cfg.pos_weight <= 0:
            # non-positive pos_weight means "weight positives like negatives"
            label_weights[pos_inds] = 1.0
        else:
            label_weights[pos_inds] = cfg.pos_weight
    if len(neg_inds) > 0:
        neg_inds = neg_inds.squeeze(1).unique()
        label_weights[neg_inds] = 1.0
    # map up to original set of anchors
    num_total_anchors = flat_anchors.size(0)
    labels = unmap(labels, num_total_anchors, inside_flags)
    label_weights = unmap(label_weights, num_total_anchors, inside_flags)
    # one-hot expansion for the sigmoid/focal-loss classification head
    labels, label_weights = expand_binary_labels(labels, label_weights,
                                                 cls_out_channels)
    bbox_targets = unmap(bbox_targets, num_total_anchors, inside_flags)
    bbox_weights = unmap(bbox_weights, num_total_anchors, inside_flags)
    return (labels, label_weights, bbox_targets, bbox_weights, pos_inds,
            neg_inds)
def expand_binary_labels(labels, label_weights, cls_out_channels):
    """Expand index labels into one-hot binary classification targets.

    Labels use 0 for background and k >= 1 for foreground class k; each
    foreground row gets a single 1 at column ``label - 1`` of a
    (num_anchors, cls_out_channels) float32 tensor, background rows stay
    all-zero. Label weights are broadcast across the class dimension.
    """
    num_anchors = labels.size(0)
    bin_labels = labels.new_full((num_anchors, cls_out_channels),
                                 0,
                                 dtype=torch.float32)
    fg_inds = torch.nonzero(labels >= 1).squeeze()
    if fg_inds.numel() > 0:
        # class k (1-based) maps to column k - 1
        bin_labels[fg_inds, labels[fg_inds] - 1] = 1
    bin_label_weights = label_weights.view(-1, 1).expand(
        num_anchors, cls_out_channels)
    return bin_labels, bin_label_weights
def anchor_inside_flags(flat_anchors, valid_flags, img_shape,
                        allowed_border=0):
    """Return flags marking valid anchors inside the bordered image.

    An anchor counts as inside when all four coordinates (x1, y1, x2, y2)
    fall within the image extended by ``allowed_border`` pixels on every
    side. A negative ``allowed_border`` disables the geometric check and
    returns ``valid_flags`` unchanged.
    """
    img_h, img_w = img_shape[:2]
    if allowed_border < 0:
        return valid_flags
    inside = valid_flags & \
        (flat_anchors[:, 0] >= -allowed_border) & \
        (flat_anchors[:, 1] >= -allowed_border) & \
        (flat_anchors[:, 2] < img_w + allowed_border) & \
        (flat_anchors[:, 3] < img_h + allowed_border)
    return inside
def unmap(data, count, inds, fill=0):
    """Unmap a subset of items back to the original set of items.

    Builds a tensor of leading size ``count`` filled with ``fill``, then
    scatters the rows of ``data`` into the positions selected by ``inds``.
    """
    if data.dim() > 1:
        full = data.new_full((count, ) + data.size()[1:], fill)
        full[inds, :] = data
    else:
        full = data.new_full((count, ), fill)
        full[inds] = data
    return full
...@@ -160,7 +160,7 @@ class RPNHead(nn.Module): ...@@ -160,7 +160,7 @@ class RPNHead(nn.Module):
if cls_reg_targets is None: if cls_reg_targets is None:
return None return None
(labels_list, label_weights_list, bbox_targets_list, bbox_weights_list, (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
num_total_samples) = cls_reg_targets num_total_pos, num_total_neg) = cls_reg_targets
losses_cls, losses_reg = multi_apply( losses_cls, losses_reg = multi_apply(
self.loss_single, self.loss_single,
rpn_cls_scores, rpn_cls_scores,
...@@ -169,7 +169,7 @@ class RPNHead(nn.Module): ...@@ -169,7 +169,7 @@ class RPNHead(nn.Module):
label_weights_list, label_weights_list,
bbox_targets_list, bbox_targets_list,
bbox_weights_list, bbox_weights_list,
num_total_samples=num_total_samples, num_total_samples=num_total_pos + num_total_neg,
cfg=cfg) cfg=cfg)
return dict(loss_rpn_cls=losses_cls, loss_rpn_reg=losses_reg) return dict(loss_rpn_cls=losses_cls, loss_rpn_reg=losses_reg)
......
...@@ -4,9 +4,9 @@ import numpy as np ...@@ -4,9 +4,9 @@ import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
from mmdet.core import (AnchorGenerator, multi_apply, delta2bbox, from mmdet.core import (AnchorGenerator, anchor_target, multi_apply,
weighted_smoothl1, weighted_sigmoid_focal_loss, delta2bbox, weighted_smoothl1,
multiclass_nms, retina_target) weighted_sigmoid_focal_loss, multiclass_nms)
from ..utils import normal_init, bias_init_with_prob from ..utils import normal_init, bias_init_with_prob
...@@ -172,20 +172,28 @@ class RetinaHead(nn.Module): ...@@ -172,20 +172,28 @@ class RetinaHead(nn.Module):
avg_factor=num_pos_samples) avg_factor=num_pos_samples)
return loss_cls, loss_reg return loss_cls, loss_reg
def loss(self, cls_scores, bbox_preds, gt_bboxes, gt_labels, img_shapes, def loss(self, cls_scores, bbox_preds, gt_bboxes, gt_labels, img_metas,
cfg): cfg):
featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores] featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
assert len(featmap_sizes) == len(self.anchor_generators) assert len(featmap_sizes) == len(self.anchor_generators)
anchor_list, valid_flag_list = self.get_anchors( anchor_list, valid_flag_list = self.get_anchors(
featmap_sizes, img_shapes) featmap_sizes, img_metas)
cls_reg_targets = retina_target( cls_reg_targets = anchor_target(
anchor_list, valid_flag_list, gt_bboxes, gt_labels, img_shapes, anchor_list,
self.target_means, self.target_stds, self.cls_out_channels, cfg) valid_flag_list,
gt_bboxes,
img_metas,
self.target_means,
self.target_stds,
cfg,
gt_labels_list=gt_labels,
cls_out_channels=self.cls_out_channels,
sampling=False)
if cls_reg_targets is None: if cls_reg_targets is None:
return None return None
(labels_list, label_weights_list, bbox_targets_list, bbox_weights_list, (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
num_pos_samples) = cls_reg_targets num_total_pos, num_total_neg) = cls_reg_targets
losses_cls, losses_reg = multi_apply( losses_cls, losses_reg = multi_apply(
self.loss_single, self.loss_single,
...@@ -195,7 +203,7 @@ class RetinaHead(nn.Module): ...@@ -195,7 +203,7 @@ class RetinaHead(nn.Module):
label_weights_list, label_weights_list,
bbox_targets_list, bbox_targets_list,
bbox_weights_list, bbox_weights_list,
num_pos_samples=num_pos_samples, num_pos_samples=num_total_pos,
cfg=cfg) cfg=cfg)
return dict(loss_cls=losses_cls, loss_reg=losses_reg) return dict(loss_cls=losses_cls, loss_reg=losses_reg)
......
from .conv_module import ConvModule from .conv_module import ConvModule
from .norm import build_norm_layer from .norm import build_norm_layer
from .weight_init import xavier_init, normal_init, uniform_init, kaiming_init from .weight_init import (xavier_init, normal_init, uniform_init, kaiming_init,
bias_init_with_prob)
__all__ = [ __all__ = [
'ConvModule', 'build_norm_layer', 'xavier_init', 'normal_init', 'ConvModule', 'build_norm_layer', 'xavier_init', 'normal_init',
'uniform_init', 'kaiming_init' 'uniform_init', 'kaiming_init', 'bias_init_with_prob'
] ]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment