"...resnet50_tensorflow.git" did not exist on "2b21ab966b1d9eacaac7a75a15dc420da168ca74"
Commit 0401cccd authored by Kai Chen

refactor rpn target computing

parent b71a1210
import torch

from ..bbox_ops import bbox_assign, bbox_transform, bbox_sampling
from ..utils import multi_apply

def anchor_target(anchor_list, valid_flag_list, gt_bboxes_list, img_metas,
                  target_means, target_stds, cfg):
    """Compute regression and classification targets for anchors.

    Args:
        anchor_list (list[list]): Multi level anchors of each image.
        valid_flag_list (list[list]): Multi level valid flags of each image.
        gt_bboxes_list (list[Tensor]): Ground truth bboxes of each image.
        img_metas (list[dict]): Meta info of each image.
        target_means (Iterable): Mean value of regression targets.
        target_stds (Iterable): Std value of regression targets.
        cfg (dict): RPN train configs.

    Returns:
        tuple
    """
    num_imgs = len(img_metas)
    assert len(anchor_list) == len(valid_flag_list) == num_imgs

    # anchor number of multi levels
    num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]]
    # concat all level anchors and flags to a single tensor
    for i in range(num_imgs):
        assert len(anchor_list[i]) == len(valid_flag_list[i])
        anchor_list[i] = torch.cat(anchor_list[i])
        valid_flag_list[i] = torch.cat(valid_flag_list[i])

    # compute targets for each image
    means_replicas = [target_means for _ in range(num_imgs)]
    stds_replicas = [target_stds for _ in range(num_imgs)]
    cfg_replicas = [cfg for _ in range(num_imgs)]
    (all_labels, all_label_weights, all_bbox_targets,
     all_bbox_weights, pos_inds_list, neg_inds_list) = multi_apply(
         anchor_target_single, anchor_list, valid_flag_list, gt_bboxes_list,
         img_metas, means_replicas, stds_replicas, cfg_replicas)
    # no valid anchors
    if any([labels is None for labels in all_labels]):
        return None
    # sampled anchors of all images
    num_total_samples = sum([
        max(pos_inds.numel() + neg_inds.numel(), 1)
        for pos_inds, neg_inds in zip(pos_inds_list, neg_inds_list)
    ])
    # split targets to a list w.r.t. multiple levels
    labels_list = images_to_levels(all_labels, num_level_anchors)
    label_weights_list = images_to_levels(all_label_weights, num_level_anchors)
    bbox_targets_list = images_to_levels(all_bbox_targets, num_level_anchors)
    bbox_weights_list = images_to_levels(all_bbox_weights, num_level_anchors)
    return (labels_list, label_weights_list, bbox_targets_list,
            bbox_weights_list, num_total_samples)
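For reference, multi_apply (imported from ..utils above) applies a function to each image's arguments and regroups the per-image result tuples into tuples of per-image lists. A minimal sketch of that pattern, consistent with how it is called here (the sketch name is hypothetical; the actual helper in the repo may differ in details):

from functools import partial

def _multi_apply_sketch(func, *args, **kwargs):
    # bind shared kwargs once, then call func per image
    pfunc = partial(func, **kwargs) if kwargs else func
    map_results = map(pfunc, *args)
    # transpose [(labels_0, weights_0, ...), (labels_1, weights_1, ...)]
    # into ([labels_0, labels_1], [weights_0, weights_1], ...)
    return tuple(map(list, zip(*map_results)))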

def images_to_levels(target, num_level_anchors):
    """Convert targets by image to targets by feature level.

    [target_img0, target_img1] -> [target_level0, target_level1, ...]
    """
    target = torch.stack(target, 0)
    level_targets = []
    start = 0
    for n in num_level_anchors:
        end = start + n
        level_targets.append(target[:, start:end].squeeze(0))
        start = end
    return level_targets
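A toy run of images_to_levels, with made-up labels for two images whose 6 anchors split into levels of 4 and 2:

import torch

all_labels = [torch.tensor([1, 0, 0, 1, 0, 1]),
              torch.tensor([0, 0, 1, 0, 1, 0])]
level_labels = images_to_levels(all_labels, [4, 2])
# level_labels[0] has shape (2, 4): level 0 targets for both images
# level_labels[1] has shape (2, 2): level 1 targets for both images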

def anchor_target_single(flat_anchors, valid_flags, gt_bboxes, img_meta,
                         target_means, target_stds, cfg):
    inside_flags = anchor_inside_flags(flat_anchors, valid_flags,
                                       img_meta['img_shape'][:2],
                                       cfg.allowed_border)
    if not inside_flags.any():
        return (None, ) * 6
    # assign gt and sample anchors
    anchors = flat_anchors[inside_flags, :]
    assigned_gt_inds, argmax_overlaps, max_overlaps = bbox_assign(
        anchors,
        gt_bboxes,
@@ -120,7 +112,7 @@ def anchor_target_single(flat_anchors, valid_flags, gt_bboxes, img_meta,
    label_weights[neg_inds] = 1.0

    # map up to original set of anchors
    num_total_anchors = flat_anchors.size(0)
    labels = unmap(labels, num_total_anchors, inside_flags)
    label_weights = unmap(label_weights, num_total_anchors, inside_flags)
    bbox_targets = unmap(bbox_targets, num_total_anchors, inside_flags)
@@ -130,27 +122,20 @@ def anchor_target_single(flat_anchors, valid_flags, gt_bboxes, img_meta,
            neg_inds)

def anchor_inside_flags(flat_anchors, valid_flags, img_shape,
                        allowed_border=0):
    img_h, img_w = img_shape[:2]
    if allowed_border >= 0:
        inside_flags = valid_flags & \
            (flat_anchors[:, 0] >= -allowed_border) & \
            (flat_anchors[:, 1] >= -allowed_border) & \
            (flat_anchors[:, 2] < img_w + allowed_border) & \
            (flat_anchors[:, 3] < img_h + allowed_border)
    else:
        inside_flags = valid_flags
    return inside_flags
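A small sanity check with toy values (bool flags are used here for the sketch; the code of this era used byte tensors): an anchor survives only if its whole box lies within the image plus the allowed border:

import torch

flat_anchors = torch.tensor([[-10., -10., 20., 20.],     # crosses the top-left edge
                             [ 50.,  50., 90., 90.],     # fully inside
                             [300., 300., 340., 340.]])  # extends past 320x320
valid_flags = torch.ones(3, dtype=torch.bool)
anchor_inside_flags(flat_anchors, valid_flags, (320, 320), allowed_border=0)
# -> tensor([False, True, False])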

def unmap(data, count, inds, fill=0):
    """ Unmap a subset of item (data) back to the original set of items (of
    size count) """
...
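The body of unmap is collapsed in this view; a plausible minimal implementation matching the docstring (an assumption for illustration, not necessarily the exact code in this commit) scatters the kept subset back into a full-size tensor:

def _unmap_sketch(data, count, inds, fill=0):
    # allocate a full-size tensor filled with `fill`, then scatter the
    # kept rows (data) back to their original positions (inds)
    if data.dim() == 1:
        ret = data.new_full((count, ), fill)
    else:
        ret = data.new_full((count, ) + data.size()[1:], fill)
    ret[inds] = data
    return ret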
@@ -212,7 +212,7 @@ class CocoDataset(Dataset):
        # apply transforms
        flip = True if np.random.rand() < self.flip_ratio else False
        img_scale = random_scale(self.img_scales)  # sample a scale
        img, img_shape, pad_shape, scale_factor = self.img_transform(
            img, img_scale, flip)
        if self.proposals is not None:
            proposals = self.bbox_transform(proposals, img_shape,
@@ -232,6 +232,7 @@ class CocoDataset(Dataset):
        img_meta = dict(
            ori_shape=ori_shape,
            img_shape=img_shape,
            pad_shape=pad_shape,
            scale_factor=scale_factor,
            flip=flip)
@@ -260,12 +261,13 @@ class CocoDataset(Dataset):
                     if self.proposals is not None else None)

        def prepare_single(img, scale, flip, proposal=None):
            _img, img_shape, pad_shape, scale_factor = self.img_transform(
                img, scale, flip)
            _img = to_tensor(_img)
            _img_meta = dict(
                ori_shape=(img_info['height'], img_info['width'], 3),
                img_shape=img_shape,
                pad_shape=pad_shape,
                scale_factor=scale_factor,
                flip=flip)
            if proposal is not None:
...
@@ -36,8 +36,11 @@ class ImageTransform(object):
            img = mmcv.imflip(img)
        if self.size_divisor is not None:
            img = mmcv.impad_to_multiple(img, self.size_divisor)
            pad_shape = img.shape
        else:
            pad_shape = img_shape
        img = img.transpose(2, 0, 1)
        return img, img_shape, pad_shape, scale_factor
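To see why pad_shape can differ from img_shape, a quick check with mmcv (shapes assumed for illustration):

import numpy as np
import mmcv

img = np.zeros((600, 800, 3), dtype=np.uint8)
padded = mmcv.impad_to_multiple(img, 32)  # pad so H and W are multiples of 32
print(img.shape, padded.shape)  # (600, 800, 3) (608, 800, 3)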
def bbox_flip(bboxes, img_shape):
...
@@ -6,18 +6,35 @@ import torch.nn as nn
import torch.nn.functional as F
from mmdet.core import (AnchorGenerator, anchor_target, bbox_transform_inv,
                        multi_apply, weighted_cross_entropy, weighted_smoothl1,
                        weighted_binary_cross_entropy)
from mmdet.ops import nms
from ..utils import normal_init

class RPNHead(nn.Module):
"""Network head of RPN.
/ - rpn_cls (1x1 conv)
input - rpn_conv (3x3 conv) -
\ - rpn_reg (1x1 conv)
Args:
in_channels (int): Number of channels in the input feature map.
feat_channels (int): Number of channels for the RPN feature map.
anchor_scales (Iterable): Anchor scales.
anchor_ratios (Iterable): Anchor aspect ratios.
anchor_strides (Iterable): Anchor strides.
anchor_base_sizes (Iterable): Anchor base sizes.
target_means (Iterable): Mean values of regression targets.
target_stds (Iterable): Std values of regression targets.
use_sigmoid_cls (bool): Whether to use sigmoid loss for classification.
(softmax by default)
"""
    def __init__(self,
                 in_channels,
                 feat_channels=256,
                 anchor_scales=[8, 16, 32],
                 anchor_ratios=[0.5, 1.0, 2.0],
                 anchor_strides=[4, 8, 16, 32, 64],
@@ -28,7 +45,6 @@ class RPNHead(nn.Module):
        super(RPNHead, self).__init__()
        self.in_channels = in_channels
        self.feat_channels = feat_channels
        self.anchor_scales = anchor_scales
        self.anchor_ratios = anchor_ratios
        self.anchor_strides = anchor_strides
@@ -66,38 +82,42 @@ class RPNHead(nn.Module):
        return multi_apply(self.forward_single, feats)

    def get_anchors(self, featmap_sizes, img_metas):
        """Get anchors according to feature map sizes.

        Args:
            featmap_sizes (list[tuple]): Multi-level feature map sizes.
            img_metas (list[dict]): Image meta info.

        Returns:
            tuple: anchors of each image, valid flags of each image
        """
        num_imgs = len(img_metas)
        num_levels = len(featmap_sizes)

        # since feature map sizes of all images are the same, we only compute
        # anchors for one time
        multi_level_anchors = []
        for i in range(num_levels):
            anchors = self.anchor_generators[i].grid_anchors(
                featmap_sizes[i], self.anchor_strides[i])
            multi_level_anchors.append(anchors)
        anchor_list = [multi_level_anchors for _ in range(num_imgs)]

        # for each image, we compute valid flags of multi level anchors
        valid_flag_list = []
        for img_id, img_meta in enumerate(img_metas):
            multi_level_flags = []
            for i in range(num_levels):
                anchor_stride = self.anchor_strides[i]
                feat_h, feat_w = featmap_sizes[i]
                h, w, _ = img_meta['pad_shape']
                valid_feat_h = min(int(np.ceil(h / anchor_stride)), feat_h)
                valid_feat_w = min(int(np.ceil(w / anchor_stride)), feat_w)
                flags = self.anchor_generators[i].valid_flags(
                    (feat_h, feat_w), (valid_feat_h, valid_feat_w))
                multi_level_flags.append(flags)
            valid_flag_list.append(multi_level_flags)
        return anchor_list, valid_flag_list
    def loss_single(self, rpn_cls_score, rpn_bbox_pred, labels, label_weights,
@@ -135,7 +155,7 @@ class RPNHead(nn.Module):
        anchor_list, valid_flag_list = self.get_anchors(
            featmap_sizes, img_shapes)
        cls_reg_targets = anchor_target(
            anchor_list, valid_flag_list, gt_bboxes, img_shapes,
            self.target_means, self.target_stds, cfg)
        if cls_reg_targets is None:
            return None
...
@@ -18,7 +18,6 @@ model = dict(
        type='RPNHead',
        in_channels=256,
        feat_channels=256,
        anchor_scales=[8],
        anchor_ratios=[0.5, 1.0, 2.0],
        anchor_strides=[4, 8, 16, 32, 64],
...
@@ -18,7 +18,6 @@ model = dict(
        type='RPNHead',
        in_channels=256,
        feat_channels=256,
        anchor_scales=[8],
        anchor_ratios=[0.5, 1.0, 2.0],
        anchor_strides=[4, 8, 16, 32, 64],
...
@@ -18,7 +18,6 @@ model = dict(
        type='RPNHead',
        in_channels=256,
        feat_channels=256,
        anchor_scales=[8],
        anchor_ratios=[0.5, 1.0, 2.0],
        anchor_strides=[4, 8, 16, 32, 64],
@@ -104,5 +103,5 @@ dist_params = dict(backend='gloo')
log_level = 'INFO'
work_dir = './work_dirs/fpn_rpn_r50_1x'
load_from = None
resume_from = None
workflow = [('train', 1)]
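Since coarsest_stride is gone from both the head and the configs, building the head from the config section above reduces to a plain constructor call (a hypothetical direct instantiation that bypasses the config system):

rpn_head = RPNHead(
    in_channels=256,
    feat_channels=256,
    anchor_scales=[8],
    anchor_ratios=[0.5, 1.0, 2.0],
    anchor_strides=[4, 8, 16, 32, 64])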