Commit 20e75c22 authored by Kai Chen's avatar Kai Chen
Browse files

use anchor_target in RetinaNet

parent 45af4242
from .anchor import * # noqa: F401, F403
from .bbox_ops import * # noqa: F401, F403
from .mask_ops import * # noqa: F401, F403
from .targets import * # noqa: F401, F403
from .losses import * # noqa: F401, F403
from .eval import * # noqa: F401, F403
from .parallel import * # noqa: F401, F403
......
......@@ -4,8 +4,16 @@ from ..bbox_ops import bbox_assign, bbox2delta, bbox_sampling
from ..utils import multi_apply
def anchor_target(anchor_list, valid_flag_list, gt_bboxes_list, img_metas,
target_means, target_stds, cfg):
def anchor_target(anchor_list,
valid_flag_list,
gt_bboxes_list,
img_metas,
target_means,
target_stds,
cfg,
gt_labels_list=None,
cls_out_channels=1,
sampling=True):
"""Compute regression and classification targets for anchors.
Args:
......@@ -32,28 +40,34 @@ def anchor_target(anchor_list, valid_flag_list, gt_bboxes_list, img_metas,
valid_flag_list[i] = torch.cat(valid_flag_list[i])
# compute targets for each image
means_replicas = [target_means for _ in range(num_imgs)]
stds_replicas = [target_stds for _ in range(num_imgs)]
cfg_replicas = [cfg for _ in range(num_imgs)]
(all_labels, all_label_weights, all_bbox_targets,
all_bbox_weights, pos_inds_list, neg_inds_list) = multi_apply(
anchor_target_single, anchor_list, valid_flag_list, gt_bboxes_list,
img_metas, means_replicas, stds_replicas, cfg_replicas)
if gt_labels_list is None:
gt_labels_list = [None for _ in range(num_imgs)]
(all_labels, all_label_weights, all_bbox_targets, all_bbox_weights,
pos_inds_list, neg_inds_list) = multi_apply(
anchor_target_single,
anchor_list,
valid_flag_list,
gt_bboxes_list,
gt_labels_list,
img_metas,
target_means=target_means,
target_stds=target_stds,
cfg=cfg,
cls_out_channels=cls_out_channels,
sampling=sampling)
# no valid anchors
if any([labels is None for labels in all_labels]):
return None
# sampled anchors of all images
num_total_samples = sum([
max(pos_inds.numel() + neg_inds.numel(), 1)
for pos_inds, neg_inds in zip(pos_inds_list, neg_inds_list)
])
num_total_pos = sum([max(inds.numel(), 1) for inds in pos_inds_list])
num_total_neg = sum([max(inds.numel(), 1) for inds in neg_inds_list])
# split targets to a list w.r.t. multiple levels
labels_list = images_to_levels(all_labels, num_level_anchors)
label_weights_list = images_to_levels(all_label_weights, num_level_anchors)
bbox_targets_list = images_to_levels(all_bbox_targets, num_level_anchors)
bbox_weights_list = images_to_levels(all_bbox_weights, num_level_anchors)
return (labels_list, label_weights_list, bbox_targets_list,
bbox_weights_list, num_total_samples)
bbox_weights_list, num_total_pos, num_total_neg)
def images_to_levels(target, num_level_anchors):
......@@ -71,8 +85,16 @@ def images_to_levels(target, num_level_anchors):
return level_targets
def anchor_target_single(flat_anchors, valid_flags, gt_bboxes, img_meta,
target_means, target_stds, cfg):
def anchor_target_single(flat_anchors,
valid_flags,
gt_bboxes,
gt_labels,
img_meta,
target_means,
target_stds,
cfg,
cls_out_channels=1,
sampling=True):
inside_flags = anchor_inside_flags(flat_anchors, valid_flags,
img_meta['img_shape'][:2],
cfg.allowed_border)
......@@ -86,10 +108,14 @@ def anchor_target_single(flat_anchors, valid_flags, gt_bboxes, img_meta,
pos_iou_thr=cfg.pos_iou_thr,
neg_iou_thr=cfg.neg_iou_thr,
min_pos_iou=cfg.min_pos_iou)
pos_inds, neg_inds = bbox_sampling(assigned_gt_inds, cfg.anchor_batch_size,
cfg.pos_fraction, cfg.neg_pos_ub,
cfg.pos_balance_sampling, max_overlaps,
cfg.neg_balance_thr)
if sampling:
pos_inds, neg_inds = bbox_sampling(
assigned_gt_inds, cfg.anchor_batch_size, cfg.pos_fraction,
cfg.neg_pos_ub, cfg.pos_balance_sampling, max_overlaps,
cfg.neg_balance_thr)
else:
pos_inds = torch.nonzero(assigned_gt_inds > 0).squeeze(-1).unique()
neg_inds = torch.nonzero(assigned_gt_inds == 0).squeeze(-1).unique()
bbox_targets = torch.zeros_like(anchors)
bbox_weights = torch.zeros_like(anchors)
......@@ -103,7 +129,10 @@ def anchor_target_single(flat_anchors, valid_flags, gt_bboxes, img_meta,
target_stds)
bbox_targets[pos_inds, :] = pos_bbox_targets
bbox_weights[pos_inds, :] = 1.0
labels[pos_inds] = 1
if gt_labels is None:
labels[pos_inds] = 1
else:
labels[pos_inds] = gt_labels[assigned_gt_inds[pos_inds] - 1]
if cfg.pos_weight <= 0:
label_weights[pos_inds] = 1.0
else:
......@@ -115,6 +144,9 @@ def anchor_target_single(flat_anchors, valid_flags, gt_bboxes, img_meta,
num_total_anchors = flat_anchors.size(0)
labels = unmap(labels, num_total_anchors, inside_flags)
label_weights = unmap(label_weights, num_total_anchors, inside_flags)
if cls_out_channels > 1:
labels, label_weights = expand_binary_labels(labels, label_weights,
cls_out_channels)
bbox_targets = unmap(bbox_targets, num_total_anchors, inside_flags)
bbox_weights = unmap(bbox_weights, num_total_anchors, inside_flags)
......@@ -122,6 +154,17 @@ def anchor_target_single(flat_anchors, valid_flags, gt_bboxes, img_meta,
neg_inds)
def expand_binary_labels(labels, label_weights, cls_out_channels):
    """Expand 1-based integer class labels into one-hot binary targets.

    Args:
        labels (Tensor): shape (n,); 0 means background/negative, values
            1..cls_out_channels are foreground classes.
        label_weights (Tensor): shape (n,), per-anchor loss weight.
        cls_out_channels (int): number of foreground classes.

    Returns:
        tuple(Tensor, Tensor): one-hot labels of shape
        (n, cls_out_channels) as float32, and the weights broadcast to the
        same shape (an expand() view, not a copy).
    """
    bin_labels = labels.new_full(
        (labels.size(0), cls_out_channels), 0, dtype=torch.float32)
    # squeeze(-1) (not squeeze()) keeps the index 1-D even when exactly one
    # positive anchor exists; squeeze() would yield a 0-d tensor there.
    inds = torch.nonzero(labels >= 1).squeeze(-1)
    if inds.numel() > 0:
        bin_labels[inds, labels[inds] - 1] = 1
    bin_label_weights = label_weights.view(-1, 1).expand(
        label_weights.size(0), cls_out_channels)
    return bin_labels, bin_label_weights
def anchor_inside_flags(flat_anchors, valid_flags, img_shape,
allowed_border=0):
img_h, img_w = img_shape[:2]
......
from .retina_target import retina_target
import torch
from ..bbox_ops import bbox_assign, bbox2delta
from ..utils import multi_apply
def retina_target(anchor_list, valid_flag_list, gt_bboxes_list, gt_labels_list,
                  img_metas, target_means, target_stds, cls_out_channels, cfg):
    """Compute regression and classification targets for anchors.

    Note:
        ``anchor_list`` and ``valid_flag_list`` are modified in place: the
        per-level tensors of each image are concatenated into one tensor.

    Args:
        anchor_list (list[list]): Multi level anchors of each image.
        valid_flag_list (list[list]): Multi level valid flags of each image.
        gt_bboxes_list (list[Tensor]): Ground truth bboxes of each image.
        gt_labels_list (list[Tensor]): Ground truth labels of each image.
        img_metas (list[dict]): Meta info of each image.
        target_means (Iterable): Mean value of regression targets.
        target_stds (Iterable): Std value of regression targets.
        cls_out_channels (int): Number of classification output channels.
        cfg (dict): RetinaNet train configs.

    Returns:
        tuple: (labels_list, label_weights_list, bbox_targets_list,
            bbox_weights_list, num_pos_samples), the lists arranged by
            feature level, or None if any image has no valid anchors.
    """
    num_imgs = len(img_metas)
    assert len(anchor_list) == len(valid_flag_list) == num_imgs

    # anchor number of multi levels
    num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]]

    # concat all level anchors and flags to a single tensor per image
    for i in range(num_imgs):
        assert len(anchor_list[i]) == len(valid_flag_list[i])
        anchor_list[i] = torch.cat(anchor_list[i])
        valid_flag_list[i] = torch.cat(valid_flag_list[i])

    # compute targets for each image
    (all_labels, all_label_weights, all_bbox_targets, all_bbox_weights,
     pos_inds_list, neg_inds_list) = multi_apply(
         retina_target_single,
         anchor_list,
         valid_flag_list,
         gt_bboxes_list,
         gt_labels_list,
         img_metas,
         target_means=target_means,
         target_stds=target_stds,
         cls_out_channels=cls_out_channels,
         cfg=cfg)
    # some image had no anchors inside its borders
    if any(labels is None for labels in all_labels):
        return None
    # total positives across images; at least 1 per image so the value is
    # safe to use as a loss normalizer (avoids division by zero)
    num_pos_samples = sum(
        max(pos_inds.numel(), 1) for pos_inds in pos_inds_list)
    # split targets to a list w.r.t. multiple levels
    labels_list = images_to_levels(all_labels, num_level_anchors)
    label_weights_list = images_to_levels(all_label_weights, num_level_anchors)
    bbox_targets_list = images_to_levels(all_bbox_targets, num_level_anchors)
    bbox_weights_list = images_to_levels(all_bbox_weights, num_level_anchors)
    return (labels_list, label_weights_list, bbox_targets_list,
            bbox_weights_list, num_pos_samples)
def images_to_levels(target, num_level_anchors):
    """Convert targets by image to targets by feature level.

    [target_img0, target_img1] -> [target_level0, target_level1, ...]
    """
    stacked = torch.stack(target, 0)
    level_targets = []
    end = 0
    for num in num_level_anchors:
        begin, end = end, end + num
        # squeeze(0) drops the image dimension only for a single-image batch
        level_targets.append(stacked[:, begin:end].squeeze(0))
    return level_targets
def retina_target_single(flat_anchors, valid_flags, gt_bboxes, gt_labels,
                         img_meta, target_means, target_stds, cls_out_channels,
                         cfg):
    """Compute cls/reg targets for all anchors of a single image.

    Unlike RPN-style targeting, every assigned anchor is used (no
    positive/negative sampling).

    Args:
        flat_anchors (Tensor): all anchors of the image, concatenated over
            levels, shape (num_anchors, 4).
        valid_flags (Tensor): per-anchor validity flags, shape (num_anchors,).
        gt_bboxes (Tensor): ground-truth boxes.
        gt_labels (Tensor): ground-truth class labels (1-based, per gt box).
        img_meta (dict): must contain 'img_shape'.
        target_means (Iterable): mean for bbox2delta encoding.
        target_stds (Iterable): std for bbox2delta encoding.
        cls_out_channels (int): number of foreground classes for one-hot
            label expansion.
        cfg: train config with allowed_border, IoU thresholds, pos_weight.

    Returns:
        tuple: (labels, label_weights, bbox_targets, bbox_weights,
            pos_inds, neg_inds); (None,) * 6 if no anchor is inside the
            allowed image border.
    """
    inside_flags = anchor_inside_flags(flat_anchors, valid_flags,
                                       img_meta['img_shape'][:2],
                                       cfg.allowed_border)
    if not inside_flags.any():
        # no usable anchors at all: caller treats this image as invalid
        return (None, ) * 6
    # assign gt and sample anchors
    anchors = flat_anchors[inside_flags, :]
    assigned_gt_inds, argmax_overlaps, max_overlaps = bbox_assign(
        anchors,
        gt_bboxes,
        pos_iou_thr=cfg.pos_iou_thr,
        neg_iou_thr=cfg.neg_iou_thr,
        min_pos_iou=cfg.min_pos_iou)
    # assigned_gt_inds convention: 0 = negative, k > 0 = matched to gt k-1
    pos_inds = torch.nonzero(assigned_gt_inds > 0)
    neg_inds = torch.nonzero(assigned_gt_inds == 0)
    bbox_targets = torch.zeros_like(anchors)
    bbox_weights = torch.zeros_like(anchors)
    labels = torch.zeros_like(assigned_gt_inds)
    label_weights = torch.zeros_like(assigned_gt_inds, dtype=anchors.dtype)
    if len(pos_inds) > 0:
        # nonzero() returns shape (N, 1); flatten to 1-D indices
        pos_inds = pos_inds.squeeze(1).unique()
        pos_anchors = anchors[pos_inds, :]
        # -1 converts the 1-based assignment index back to a gt row index
        pos_gt_bbox = gt_bboxes[assigned_gt_inds[pos_inds] - 1, :]
        pos_bbox_targets = bbox2delta(pos_anchors, pos_gt_bbox, target_means,
                                      target_stds)
        bbox_targets[pos_inds, :] = pos_bbox_targets
        bbox_weights[pos_inds, :] = 1.0
        labels[pos_inds] = gt_labels[assigned_gt_inds[pos_inds] - 1]
        # cfg.pos_weight <= 0 means "no special up-weighting of positives"
        if cfg.pos_weight <= 0:
            label_weights[pos_inds] = 1.0
        else:
            label_weights[pos_inds] = cfg.pos_weight
    if len(neg_inds) > 0:
        neg_inds = neg_inds.squeeze(1).unique()
        label_weights[neg_inds] = 1.0
    # map up to original set of anchors (positions outside inside_flags get 0)
    num_total_anchors = flat_anchors.size(0)
    labels = unmap(labels, num_total_anchors, inside_flags)
    label_weights = unmap(label_weights, num_total_anchors, inside_flags)
    labels, label_weights = expand_binary_labels(labels, label_weights,
                                                 cls_out_channels)
    bbox_targets = unmap(bbox_targets, num_total_anchors, inside_flags)
    bbox_weights = unmap(bbox_weights, num_total_anchors, inside_flags)
    return (labels, label_weights, bbox_targets, bbox_weights, pos_inds,
            neg_inds)
def expand_binary_labels(labels, label_weights, cls_out_channels):
    """Expand 1-based integer class labels into one-hot binary targets.

    Args:
        labels (Tensor): shape (n,); 0 means background/negative, values
            1..cls_out_channels are foreground classes.
        label_weights (Tensor): shape (n,), per-anchor loss weight.
        cls_out_channels (int): number of foreground classes.

    Returns:
        tuple(Tensor, Tensor): one-hot labels of shape
        (n, cls_out_channels) as float32, and the weights broadcast to the
        same shape (an expand() view, not a copy).
    """
    bin_labels = labels.new_full(
        (labels.size(0), cls_out_channels), 0, dtype=torch.float32)
    # squeeze(-1) (not squeeze()) keeps the index 1-D even when exactly one
    # positive anchor exists; squeeze() would yield a 0-d tensor there.
    inds = torch.nonzero(labels >= 1).squeeze(-1)
    if inds.numel() > 0:
        bin_labels[inds, labels[inds] - 1] = 1
    bin_label_weights = label_weights.view(-1, 1).expand(
        label_weights.size(0), cls_out_channels)
    return bin_labels, bin_label_weights
def anchor_inside_flags(flat_anchors, valid_flags, img_shape,
                        allowed_border=0):
    """Return flags marking anchors that are valid and lie within the image
    (expanded by `allowed_border` pixels on every side).

    A negative `allowed_border` disables the border check entirely and
    returns `valid_flags` unchanged.
    """
    img_h, img_w = img_shape[:2]
    if allowed_border < 0:
        return valid_flags
    # start from validity, then AND in each border condition; the first &
    # produces a new tensor, so valid_flags itself is never mutated
    inside = valid_flags & (flat_anchors[:, 0] >= -allowed_border)
    inside = inside & (flat_anchors[:, 1] >= -allowed_border)
    inside = inside & (flat_anchors[:, 2] < img_w + allowed_border)
    inside = inside & (flat_anchors[:, 3] < img_h + allowed_border)
    return inside
def unmap(data, count, inds, fill=0):
    """Scatter a subset of items (`data`) back into a tensor covering the
    original set of `count` items; positions not in `inds` get `fill`."""
    if data.dim() == 1:
        out = data.new_full((count, ), fill)
        out[inds] = data
        return out
    # multi-dimensional: keep all trailing dimensions of `data`
    out = data.new_full((count, ) + data.size()[1:], fill)
    out[inds, :] = data
    return out
......@@ -160,7 +160,7 @@ class RPNHead(nn.Module):
if cls_reg_targets is None:
return None
(labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
num_total_samples) = cls_reg_targets
num_total_pos, num_total_neg) = cls_reg_targets
losses_cls, losses_reg = multi_apply(
self.loss_single,
rpn_cls_scores,
......@@ -169,7 +169,7 @@ class RPNHead(nn.Module):
label_weights_list,
bbox_targets_list,
bbox_weights_list,
num_total_samples=num_total_samples,
num_total_samples=num_total_pos + num_total_neg,
cfg=cfg)
return dict(loss_rpn_cls=losses_cls, loss_rpn_reg=losses_reg)
......
......@@ -4,9 +4,9 @@ import numpy as np
import torch
import torch.nn as nn
from mmdet.core import (AnchorGenerator, multi_apply, delta2bbox,
weighted_smoothl1, weighted_sigmoid_focal_loss,
multiclass_nms, retina_target)
from mmdet.core import (AnchorGenerator, anchor_target, multi_apply,
delta2bbox, weighted_smoothl1,
weighted_sigmoid_focal_loss, multiclass_nms)
from ..utils import normal_init, bias_init_with_prob
......@@ -172,20 +172,28 @@ class RetinaHead(nn.Module):
avg_factor=num_pos_samples)
return loss_cls, loss_reg
def loss(self, cls_scores, bbox_preds, gt_bboxes, gt_labels, img_shapes,
def loss(self, cls_scores, bbox_preds, gt_bboxes, gt_labels, img_metas,
cfg):
featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
assert len(featmap_sizes) == len(self.anchor_generators)
anchor_list, valid_flag_list = self.get_anchors(
featmap_sizes, img_shapes)
cls_reg_targets = retina_target(
anchor_list, valid_flag_list, gt_bboxes, gt_labels, img_shapes,
self.target_means, self.target_stds, self.cls_out_channels, cfg)
featmap_sizes, img_metas)
cls_reg_targets = anchor_target(
anchor_list,
valid_flag_list,
gt_bboxes,
img_metas,
self.target_means,
self.target_stds,
cfg,
gt_labels_list=gt_labels,
cls_out_channels=self.cls_out_channels,
sampling=False)
if cls_reg_targets is None:
return None
(labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
num_pos_samples) = cls_reg_targets
num_total_pos, num_total_neg) = cls_reg_targets
losses_cls, losses_reg = multi_apply(
self.loss_single,
......@@ -195,7 +203,7 @@ class RetinaHead(nn.Module):
label_weights_list,
bbox_targets_list,
bbox_weights_list,
num_pos_samples=num_pos_samples,
num_pos_samples=num_total_pos,
cfg=cfg)
return dict(loss_cls=losses_cls, loss_reg=losses_reg)
......
from .conv_module import ConvModule
from .norm import build_norm_layer
from .weight_init import xavier_init, normal_init, uniform_init, kaiming_init
from .weight_init import (xavier_init, normal_init, uniform_init, kaiming_init,
bias_init_with_prob)
__all__ = [
'ConvModule', 'build_norm_layer', 'xavier_init', 'normal_init',
'uniform_init', 'kaiming_init'
'uniform_init', 'kaiming_init', 'bias_init_with_prob'
]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment