Commit 0401cccd authored by Kai Chen

refactor rpn target computing

parent b71a1210
import torch
import numpy as np

from ..bbox_ops import bbox_assign, bbox_transform, bbox_sampling
from ..utils import multi_apply
def anchor_target(anchor_list, valid_flag_list, gt_bboxes_list, img_metas,
                  target_means, target_stds, cfg):
    """Compute regression and classification targets for anchors.

    Args:
        anchor_list (list[list]): Multi-level anchors of each image.
        valid_flag_list (list[list]): Multi-level valid flags of each image.
        gt_bboxes_list (list[Tensor]): Ground truth bboxes of each image.
        img_metas (list[dict]): Meta info of each image.
        target_means (Iterable): Mean value of regression targets.
        target_stds (Iterable): Std value of regression targets.
        cfg (dict): RPN train configs.

    Returns:
        tuple
    """
    num_imgs = len(img_metas)
    assert len(anchor_list) == len(valid_flag_list) == num_imgs
    # anchor number of multi levels
    num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]]
    # concat all level anchors and flags to a single tensor
    for i in range(num_imgs):
        assert len(anchor_list[i]) == len(valid_flag_list[i])
        anchor_list[i] = torch.cat(anchor_list[i])
        valid_flag_list[i] = torch.cat(valid_flag_list[i])
    # compute targets for each image
    means_replicas = [target_means for _ in range(num_imgs)]
    stds_replicas = [target_stds for _ in range(num_imgs)]
    cfg_replicas = [cfg for _ in range(num_imgs)]
    (all_labels, all_label_weights, all_bbox_targets, all_bbox_weights,
     pos_inds_list, neg_inds_list) = multi_apply(
         anchor_target_single, anchor_list, valid_flag_list, gt_bboxes_list,
         img_metas, means_replicas, stds_replicas, cfg_replicas)
    # no valid anchors
    if any([labels is None for labels in all_labels]):
        return None
    # sampled anchors of all images
    num_total_samples = sum([
        max(pos_inds.numel() + neg_inds.numel(), 1)
        for pos_inds, neg_inds in zip(pos_inds_list, neg_inds_list)
    ])
    # split targets to a list w.r.t. multiple levels
    labels_list = images_to_levels(all_labels, num_level_anchors)
    label_weights_list = images_to_levels(all_label_weights, num_level_anchors)
    bbox_targets_list = images_to_levels(all_bbox_targets, num_level_anchors)
    bbox_weights_list = images_to_levels(all_bbox_weights, num_level_anchors)
    return (labels_list, label_weights_list, bbox_targets_list,
            bbox_weights_list, num_total_samples)
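
# Illustration only: the nesting anchor_target expects, with assumed sizes
# (2 images, 2 levels; none of these numbers come from this commit).
#
#     anchor_list = [
#         [torch.rand(100, 4), torch.rand(25, 4)],  # image 0, levels 0 and 1
#         [torch.rand(100, 4), torch.rand(25, 4)],  # image 1, same grid
#     ]
#     valid_flag_list = [
#         [torch.ones(100, dtype=torch.uint8), torch.ones(25, dtype=torch.uint8)],
#         [torch.ones(100, dtype=torch.uint8), torch.ones(25, dtype=torch.uint8)],
#     ]
#
# Levels are concatenated per image, targets are computed per image via
# multi_apply(anchor_target_single, ...), and the stacked results are
# regrouped by level with images_to_levels.
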
def images_to_levels(target, num_level_anchors):
    """Convert targets by image to targets by feature level.

    [target_img0, target_img1] -> [target_level0, target_level1, ...]
    """
    target = torch.stack(target, 0)
    level_targets = []
    start = 0
    for n in num_level_anchors:
        end = start + n
        level_targets.append(target[:, start:end].squeeze(0))
        start = end
    return level_targets
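
# Worked example (assumed sizes): two images, levels of 100 and 25 anchors.
#
#     all_labels = [torch.zeros(125), torch.zeros(125)]  # one flat tensor per image
#     labels_list = images_to_levels(all_labels, [100, 25])
#     # labels_list[0].shape == (2, 100), labels_list[1].shape == (2, 25)
#
# i.e. each output entry is one level, batched over the image dimension.
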
def anchor_target_single(flat_anchors, valid_flags, gt_bboxes, img_meta,
                         target_means, target_stds, cfg):
    inside_flags = anchor_inside_flags(flat_anchors, valid_flags,
                                       img_meta['img_shape'][:2],
                                       cfg.allowed_border)
    if not inside_flags.any():
        return (None, ) * 6
    # assign gt and sample anchors
    anchors = flat_anchors[inside_flags, :]
assigned_gt_inds, argmax_overlaps, max_overlaps = bbox_assign(
anchors,
gt_bboxes,
@@ -120,7 +112,7 @@ def anchor_target_single(all_anchors, inside_flags, gt_bboxes, target_means,
label_weights[neg_inds] = 1.0
    # map up to original set of anchors
    num_total_anchors = flat_anchors.size(0)
labels = unmap(labels, num_total_anchors, inside_flags)
label_weights = unmap(label_weights, num_total_anchors, inside_flags)
bbox_targets = unmap(bbox_targets, num_total_anchors, inside_flags)
@@ -130,27 +122,20 @@ def anchor_target_single(all_anchors, inside_flags, gt_bboxes, target_means,
neg_inds)
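
# For reference, a minimal sketch of the map-then-transpose idea behind the
# multi_apply helper imported from ..utils (an assumed implementation shown
# for illustration, not the committed one):
def _multi_apply_sketch(func, *args, **kwargs):
    from functools import partial
    pfunc = partial(func, **kwargs) if kwargs else func
    map_results = map(pfunc, *args)  # one result tuple per image
    # transpose per-image result tuples into per-field lists
    return tuple(map(list, zip(*map_results)))
# e.g. _multi_apply_sketch(lambda a, b: (a + b, a * b), [1, 2], [3, 4])
# returns ([4, 6], [3, 8])
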
def anchor_inside_flags(flat_anchors, valid_flags, img_shape,
                        allowed_border=0):
    img_h, img_w = img_shape[:2]
    if allowed_border >= 0:
        inside_flags = valid_flags & \
            (flat_anchors[:, 0] >= -allowed_border) & \
            (flat_anchors[:, 1] >= -allowed_border) & \
            (flat_anchors[:, 2] < img_w + allowed_border) & \
            (flat_anchors[:, 3] < img_h + allowed_border)
    else:
        inside_flags = valid_flags
    return inside_flags
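
# Illustration with made-up boxes: for a 100x100 image and allowed_border=0,
#
#     flat_anchors = torch.tensor([[10., 10., 50., 50.],
#                                  [90., 90., 120., 120.]])
#     valid_flags = torch.ones(2, dtype=torch.uint8)
#     anchor_inside_flags(flat_anchors, valid_flags, (100, 100), 0)
#     # -> tensor([1, 0], dtype=torch.uint8)
#
# The second anchor's x2/y2 exceed img_w/img_h, so it is filtered out before
# assignment and sampling.
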
def unique(tensor):
if tensor.is_cuda:
u_tensor = np.unique(tensor.cpu().numpy())
return tensor.new_tensor(u_tensor)
else:
return torch.unique(tensor)
def unmap(data, count, inds, fill=0):
    """Unmap a subset of items (data) back to the original set of items (of
    size count)."""
......
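
# The unmap body is collapsed in this diff; a plausible implementation is
# sketched below as an assumption (scatter the kept subset back into a
# fill-initialized tensor of length count), not the committed code.
def _unmap_sketch(data, count, inds, fill=0):
    if data.dim() == 1:
        ret = data.new_full((count, ), fill)
        ret[inds] = data
    else:
        ret = data.new_full((count, ) + data.size()[1:], fill)
        ret[inds, :] = data
    return ret
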
@@ -212,7 +212,7 @@ class CocoDataset(Dataset):
        # apply transforms
        flip = True if np.random.rand() < self.flip_ratio else False
        img_scale = random_scale(self.img_scales)  # sample a scale
        img, img_shape, pad_shape, scale_factor = self.img_transform(
            img, img_scale, flip)
if self.proposals is not None:
proposals = self.bbox_transform(proposals, img_shape,
@@ -232,6 +232,7 @@ class CocoDataset(Dataset):
img_meta = dict(
ori_shape=ori_shape,
img_shape=img_shape,
pad_shape=pad_shape,
scale_factor=scale_factor,
flip=flip)
@@ -260,12 +261,13 @@ class CocoDataset(Dataset):
if self.proposals is not None else None)
        def prepare_single(img, scale, flip, proposal=None):
            _img, img_shape, pad_shape, scale_factor = self.img_transform(
                img, scale, flip)
_img = to_tensor(_img)
_img_meta = dict(
ori_shape=(img_info['height'], img_info['width'], 3),
img_shape=img_shape,
pad_shape=pad_shape,
scale_factor=scale_factor,
flip=flip)
if proposal is not None:
......
@@ -36,8 +36,11 @@ class ImageTransform(object):
            img = mmcv.imflip(img)
        if self.size_divisor is not None:
            img = mmcv.impad_to_multiple(img, self.size_divisor)
            pad_shape = img.shape
        else:
            pad_shape = img_shape
        img = img.transpose(2, 0, 1)
        return img, img_shape, pad_shape, scale_factor
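
# Illustration (assumed helper, not part of mmcv): the pad_shape produced by
# impad_to_multiple can be predicted by rounding each side up to the next
# multiple of size_divisor.
def _pad_to_multiple_shape_sketch(h, w, divisor=32):
    pad_h = (h + divisor - 1) // divisor * divisor
    pad_w = (w + divisor - 1) // divisor * divisor
    return pad_h, pad_w
# e.g. _pad_to_multiple_shape_sketch(800, 1201) -> (800, 1216)
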
def bbox_flip(bboxes, img_shape):
......
@@ -6,18 +6,35 @@ import torch.nn as nn
import torch.nn.functional as F
from mmdet.core import (AnchorGenerator, anchor_target, bbox_transform_inv,
                        multi_apply, weighted_cross_entropy,
                        weighted_smoothl1, weighted_binary_cross_entropy)
from mmdet.ops import nms
from ..utils import normal_init
class RPNHead(nn.Module):
    """Network head of RPN.

                                  / - rpn_cls (1x1 conv)
    input - rpn_conv (3x3 conv) -
                                  \ - rpn_reg (1x1 conv)

    Args:
        in_channels (int): Number of channels in the input feature map.
        feat_channels (int): Number of channels for the RPN feature map.
        anchor_scales (Iterable): Anchor scales.
        anchor_ratios (Iterable): Anchor aspect ratios.
        anchor_strides (Iterable): Anchor strides.
        anchor_base_sizes (Iterable): Anchor base sizes.
        target_means (Iterable): Mean values of regression targets.
        target_stds (Iterable): Std values of regression targets.
        use_sigmoid_cls (bool): Whether to use sigmoid loss for
            classification (softmax is used by default).
    """
    def __init__(self,
                 in_channels,
                 feat_channels=256,
                 anchor_scales=[8, 16, 32],
                 anchor_ratios=[0.5, 1.0, 2.0],
                 anchor_strides=[4, 8, 16, 32, 64],
@@ -28,7 +45,6 @@ class RPNHead(nn.Module):
        super(RPNHead, self).__init__()
        self.in_channels = in_channels
        self.feat_channels = feat_channels
        self.anchor_scales = anchor_scales
        self.anchor_ratios = anchor_ratios
        self.anchor_strides = anchor_strides
@@ -66,38 +82,42 @@ class RPNHead(nn.Module):
return multi_apply(self.forward_single, feats)
    def get_anchors(self, featmap_sizes, img_metas):
        """Get anchors according to feature map sizes.

        Args:
            featmap_sizes (list[tuple]): Multi-level feature map sizes.
            img_metas (list[dict]): Image meta info.

        Returns:
            tuple: anchors of each image, valid flags of each image
        """
        num_imgs = len(img_metas)
        num_levels = len(featmap_sizes)

        # since feature map sizes of all images are the same, we only compute
        # anchors once
        multi_level_anchors = []
        for i in range(num_levels):
            anchors = self.anchor_generators[i].grid_anchors(
                featmap_sizes[i], self.anchor_strides[i])
            multi_level_anchors.append(anchors)
        anchor_list = [multi_level_anchors for _ in range(num_imgs)]

        # for each image, compute valid flags of multi-level anchors
        valid_flag_list = []
        for img_id, img_meta in enumerate(img_metas):
            multi_level_flags = []
            for i in range(num_levels):
                anchor_stride = self.anchor_strides[i]
                feat_h, feat_w = featmap_sizes[i]
                h, w, _ = img_meta['pad_shape']
                valid_feat_h = min(int(np.ceil(h / anchor_stride)), feat_h)
                valid_feat_w = min(int(np.ceil(w / anchor_stride)), feat_w)
                flags = self.anchor_generators[i].valid_flags(
                    (feat_h, feat_w), (valid_feat_h, valid_feat_w))
                multi_level_flags.append(flags)
            valid_flag_list.append(multi_level_flags)
        return anchor_list, valid_flag_list
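
    # Worked example (assumed numbers): if the batch tensor is padded to
    # 800x1216, the stride-32 level has a 25x38 grid. For an image whose own
    # pad_shape is (608, 800, 3):
    #     valid_feat_h = min(ceil(608 / 32), 25) = 19
    #     valid_feat_w = min(ceil(800 / 32), 38) = 25
    # so only the top-left 19x25 sub-grid (times anchors per location) is
    # marked valid; positions on the batch-padding margin are filtered out.
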
def loss_single(self, rpn_cls_score, rpn_bbox_pred, labels, label_weights,
@@ -135,7 +155,7 @@ class RPNHead(nn.Module):
        anchor_list, valid_flag_list = self.get_anchors(
            featmap_sizes, img_shapes)
        cls_reg_targets = anchor_target(
            anchor_list, valid_flag_list, gt_bboxes, img_shapes,
            self.target_means, self.target_stds, cfg)
        if cls_reg_targets is None:
            return None
......
......@@ -18,7 +18,6 @@ model = dict(
type='RPNHead',
in_channels=256,
feat_channels=256,
anchor_scales=[8],
anchor_ratios=[0.5, 1.0, 2.0],
anchor_strides=[4, 8, 16, 32, 64],
......
......@@ -18,7 +18,6 @@ model = dict(
type='RPNHead',
in_channels=256,
feat_channels=256,
anchor_scales=[8],
anchor_ratios=[0.5, 1.0, 2.0],
anchor_strides=[4, 8, 16, 32, 64],
......
......@@ -18,7 +18,6 @@ model = dict(
type='RPNHead',
in_channels=256,
feat_channels=256,
anchor_scales=[8],
anchor_ratios=[0.5, 1.0, 2.0],
anchor_strides=[4, 8, 16, 32, 64],
@@ -104,5 +103,5 @@ dist_params = dict(backend='gloo')
log_level = 'INFO'
work_dir = './work_dirs/fpn_rpn_r50_1x'
load_from = None
resume_from = None
workflow = [('train', 1)]