"examples/tensorflow/dgi/gcn.py" did not exist on "0a78dbe12a90845e1010adcf76f71f04f7386bd1"
Unverified Commit 4990aae6 authored by Kai Chen's avatar Kai Chen Committed by GitHub
Browse files

Merge pull request #44 from hellock/dev

Add BBoxAssigner and BBoxSampler components for better modular usage
parents f8fa51c9 bac11303
...@@ -43,17 +43,19 @@ model = dict( ...@@ -43,17 +43,19 @@ model = dict(
# model training and testing settings # model training and testing settings
train_cfg = dict( train_cfg = dict(
rcnn=dict( rcnn=dict(
mask_size=28, assigner=dict(
pos_iou_thr=0.5, pos_iou_thr=0.5,
neg_iou_thr=0.5, neg_iou_thr=0.5,
crowd_thr=1.1, min_pos_iou=0.5,
roi_batch_size=512, ignore_iof_thr=-1),
add_gt_as_proposals=True, sampler=dict(
num=512,
pos_fraction=0.25, pos_fraction=0.25,
neg_pos_ub=-1,
add_gt_as_proposals=True,
pos_balance_sampling=False, pos_balance_sampling=False,
neg_pos_ub=512, neg_balance_thr=0),
neg_balance_thr=0, mask_size=28,
min_pos_iou=0.5,
pos_weight=-1, pos_weight=-1,
debug=False)) debug=False))
test_cfg = dict( test_cfg = dict(
......
...@@ -32,16 +32,18 @@ model = dict( ...@@ -32,16 +32,18 @@ model = dict(
# model training and testing settings # model training and testing settings
train_cfg = dict( train_cfg = dict(
rcnn=dict( rcnn=dict(
assigner=dict(
pos_iou_thr=0.5, pos_iou_thr=0.5,
neg_iou_thr=0.5, neg_iou_thr=0.5,
crowd_thr=1.1, min_pos_iou=0.5,
roi_batch_size=512, ignore_iof_thr=-1),
add_gt_as_proposals=True, sampler=dict(
num=512,
pos_fraction=0.25, pos_fraction=0.25,
neg_pos_ub=-1,
add_gt_as_proposals=True,
pos_balance_sampling=False, pos_balance_sampling=False,
neg_pos_ub=512, neg_balance_thr=0),
neg_balance_thr=0,
min_pos_iou=0.5,
pos_weight=-1, pos_weight=-1,
debug=False)) debug=False))
test_cfg = dict(rcnn=dict(score_thr=0.05, max_per_img=100, nms_thr=0.5)) test_cfg = dict(rcnn=dict(score_thr=0.05, max_per_img=100, nms_thr=0.5))
......
...@@ -42,30 +42,35 @@ model = dict( ...@@ -42,30 +42,35 @@ model = dict(
# model training and testing settings # model training and testing settings
train_cfg = dict( train_cfg = dict(
rpn=dict( rpn=dict(
pos_fraction=0.5, assigner=dict(
pos_balance_sampling=False,
neg_pos_ub=256,
allowed_border=0,
crowd_thr=1.1,
anchor_batch_size=256,
pos_iou_thr=0.7, pos_iou_thr=0.7,
neg_iou_thr=0.3, neg_iou_thr=0.3,
neg_balance_thr=0,
min_pos_iou=0.3, min_pos_iou=0.3,
ignore_iof_thr=-1),
sampler=dict(
num=256,
pos_fraction=0.5,
neg_pos_ub=-1,
add_gt_as_proposals=False,
pos_balance_sampling=False,
neg_balance_thr=0),
allowed_border=0,
pos_weight=-1, pos_weight=-1,
smoothl1_beta=1 / 9.0, smoothl1_beta=1 / 9.0,
debug=False), debug=False),
rcnn=dict( rcnn=dict(
assigner=dict(
pos_iou_thr=0.5, pos_iou_thr=0.5,
neg_iou_thr=0.5, neg_iou_thr=0.5,
crowd_thr=1.1, min_pos_iou=0.5,
roi_batch_size=512, ignore_iof_thr=-1),
add_gt_as_proposals=True, sampler=dict(
num=512,
pos_fraction=0.25, pos_fraction=0.25,
neg_pos_ub=-1,
add_gt_as_proposals=True,
pos_balance_sampling=False, pos_balance_sampling=False,
neg_pos_ub=512, neg_balance_thr=0),
neg_balance_thr=0,
min_pos_iou=0.5,
pos_weight=-1, pos_weight=-1,
debug=False)) debug=False))
test_cfg = dict( test_cfg = dict(
......
...@@ -53,31 +53,36 @@ model = dict( ...@@ -53,31 +53,36 @@ model = dict(
# model training and testing settings # model training and testing settings
train_cfg = dict( train_cfg = dict(
rpn=dict( rpn=dict(
pos_fraction=0.5, assigner=dict(
pos_balance_sampling=False,
neg_pos_ub=256,
allowed_border=0,
crowd_thr=1.1,
anchor_batch_size=256,
pos_iou_thr=0.7, pos_iou_thr=0.7,
neg_iou_thr=0.3, neg_iou_thr=0.3,
neg_balance_thr=0,
min_pos_iou=0.3, min_pos_iou=0.3,
ignore_iof_thr=-1),
sampler=dict(
num=256,
pos_fraction=0.5,
neg_pos_ub=-1,
add_gt_as_proposals=False,
pos_balance_sampling=False,
neg_balance_thr=0),
allowed_border=0,
pos_weight=-1, pos_weight=-1,
smoothl1_beta=1 / 9.0, smoothl1_beta=1 / 9.0,
debug=False), debug=False),
rcnn=dict( rcnn=dict(
mask_size=28, assigner=dict(
pos_iou_thr=0.5, pos_iou_thr=0.5,
neg_iou_thr=0.5, neg_iou_thr=0.5,
crowd_thr=1.1, min_pos_iou=0.5,
roi_batch_size=512, ignore_iof_thr=-1),
add_gt_as_proposals=True, sampler=dict(
num=512,
pos_fraction=0.25, pos_fraction=0.25,
neg_pos_ub=-1,
add_gt_as_proposals=True,
pos_balance_sampling=False, pos_balance_sampling=False,
neg_pos_ub=512, neg_balance_thr=0),
neg_balance_thr=0, mask_size=28,
min_pos_iou=0.5,
pos_weight=-1, pos_weight=-1,
debug=False)) debug=False))
test_cfg = dict( test_cfg = dict(
......
...@@ -27,16 +27,19 @@ model = dict( ...@@ -27,16 +27,19 @@ model = dict(
# model training and testing settings # model training and testing settings
train_cfg = dict( train_cfg = dict(
rpn=dict( rpn=dict(
pos_fraction=0.5, assigner=dict(
pos_balance_sampling=False,
neg_pos_ub=256,
allowed_border=0,
crowd_thr=1.1,
anchor_batch_size=256,
pos_iou_thr=0.7, pos_iou_thr=0.7,
neg_iou_thr=0.3, neg_iou_thr=0.3,
neg_balance_thr=0,
min_pos_iou=0.3, min_pos_iou=0.3,
ignore_iof_thr=-1),
sampler=dict(
num=256,
pos_fraction=0.5,
neg_pos_ub=-1,
add_gt_as_proposals=False,
pos_balance_sampling=False,
neg_balance_thr=0),
allowed_border=0,
pos_weight=-1, pos_weight=-1,
smoothl1_beta=1 / 9.0, smoothl1_beta=1 / 9.0,
debug=False)) debug=False))
......
import torch import torch
from ..bbox import bbox_assign, bbox2delta, bbox_sampling from ..bbox import assign_and_sample, bbox2delta
from ..utils import multi_apply from ..utils import multi_apply
...@@ -80,27 +80,20 @@ def anchor_target_single(flat_anchors, valid_flags, gt_bboxes, img_meta, ...@@ -80,27 +80,20 @@ def anchor_target_single(flat_anchors, valid_flags, gt_bboxes, img_meta,
return (None, ) * 6 return (None, ) * 6
# assign gt and sample anchors # assign gt and sample anchors
anchors = flat_anchors[inside_flags, :] anchors = flat_anchors[inside_flags, :]
assigned_gt_inds, argmax_overlaps, max_overlaps = bbox_assign( _, sampling_result = assign_and_sample(anchors, gt_bboxes, None, None, cfg)
anchors,
gt_bboxes,
pos_iou_thr=cfg.pos_iou_thr,
neg_iou_thr=cfg.neg_iou_thr,
min_pos_iou=cfg.min_pos_iou)
pos_inds, neg_inds = bbox_sampling(assigned_gt_inds, cfg.anchor_batch_size,
cfg.pos_fraction, cfg.neg_pos_ub,
cfg.pos_balance_sampling, max_overlaps,
cfg.neg_balance_thr)
num_valid_anchors = anchors.shape[0]
bbox_targets = torch.zeros_like(anchors) bbox_targets = torch.zeros_like(anchors)
bbox_weights = torch.zeros_like(anchors) bbox_weights = torch.zeros_like(anchors)
labels = torch.zeros_like(assigned_gt_inds) labels = anchors.new_zeros((num_valid_anchors, ))
label_weights = torch.zeros_like(assigned_gt_inds, dtype=anchors.dtype) label_weights = anchors.new_zeros((num_valid_anchors, ))
pos_inds = sampling_result.pos_inds
neg_inds = sampling_result.neg_inds
if len(pos_inds) > 0: if len(pos_inds) > 0:
pos_anchors = anchors[pos_inds, :] pos_bbox_targets = bbox2delta(sampling_result.pos_bboxes,
pos_gt_bbox = gt_bboxes[assigned_gt_inds[pos_inds] - 1, :] sampling_result.pos_gt_bboxes,
pos_bbox_targets = bbox2delta(pos_anchors, pos_gt_bbox, target_means, target_means, target_stds)
target_stds)
bbox_targets[pos_inds, :] = pos_bbox_targets bbox_targets[pos_inds, :] = pos_bbox_targets
bbox_weights[pos_inds, :] = 1.0 bbox_weights[pos_inds, :] = 1.0
labels[pos_inds] = 1 labels[pos_inds] = 1
......
from .geometry import bbox_overlaps from .geometry import bbox_overlaps
from .sampling import (random_choice, bbox_assign, bbox_assign_wrt_overlaps, from .assignment import BBoxAssigner, AssignResult
bbox_sampling, bbox_sampling_pos, bbox_sampling_neg, from .sampling import (BBoxSampler, SamplingResult, assign_and_sample,
sample_bboxes) random_choice)
from .transforms import (bbox2delta, delta2bbox, bbox_flip, bbox_mapping, from .transforms import (bbox2delta, delta2bbox, bbox_flip, bbox_mapping,
bbox_mapping_back, bbox2roi, roi2bbox, bbox2result) bbox_mapping_back, bbox2roi, roi2bbox, bbox2result)
from .bbox_target import bbox_target from .bbox_target import bbox_target
__all__ = [ __all__ = [
'bbox_overlaps', 'random_choice', 'bbox_assign', 'bbox_overlaps', 'BBoxAssigner', 'AssignResult', 'BBoxSampler',
'bbox_assign_wrt_overlaps', 'bbox_sampling', 'bbox_sampling_pos', 'SamplingResult', 'assign_and_sample', 'random_choice', 'bbox2delta',
'bbox_sampling_neg', 'sample_bboxes', 'bbox2delta', 'delta2bbox', 'delta2bbox', 'bbox_flip', 'bbox_mapping', 'bbox_mapping_back', 'bbox2roi',
'bbox_flip', 'bbox_mapping', 'bbox_mapping_back', 'bbox2roi', 'roi2bbox', 'roi2bbox', 'bbox2result', 'bbox_target'
'bbox2result', 'bbox_target'
] ]
import torch
from .geometry import bbox_overlaps
class BBoxAssigner(object):
    """Assign a corresponding gt bbox or background to each bbox.

    Each proposal will be assigned `-1`, `0`, or a positive integer
    indicating the ground truth index.

    - -1: don't care
    - 0: negative sample, no assigned gt
    - positive integer: positive sample, index (1-based) of assigned gt

    Args:
        pos_iou_thr (float): IoU threshold for positive bboxes.
        neg_iou_thr (float or tuple): IoU threshold for negative bboxes.
        min_pos_iou (float): Minimum iou for a bbox to be considered as a
            positive bbox. For RPN, it is usually set as 0.3, for Fast R-CNN,
            it is usually set as pos_iou_thr
        ignore_iof_thr (float): IoF threshold for ignoring bboxes (if
            `gt_bboxes_ignore` is specified). Negative values mean not
            ignoring any bboxes.
    """

    def __init__(self,
                 pos_iou_thr,
                 neg_iou_thr,
                 min_pos_iou=.0,
                 ignore_iof_thr=-1):
        self.pos_iou_thr = pos_iou_thr
        self.neg_iou_thr = neg_iou_thr
        self.min_pos_iou = min_pos_iou
        self.ignore_iof_thr = ignore_iof_thr

    def assign(self, bboxes, gt_bboxes, gt_bboxes_ignore=None, gt_labels=None):
        """Assign gt to bboxes.

        This method assigns a gt bbox to every bbox (proposal/anchor), each
        bbox will be assigned with -1, 0, or a positive number. -1 means
        don't care, 0 means negative sample, positive number is the index
        (1-based) of assigned gt.

        The assignment is done in following steps, the order matters.

        1. assign every bbox to -1
        2. assign proposals whose iou with all gts < neg_iou_thr to 0
        3. for each bbox, if the iou with its nearest gt >= pos_iou_thr,
           assign it to that bbox
        4. for each gt bbox, assign its nearest proposals (may be more than
           one) to itself

        Args:
            bboxes (Tensor): Bounding boxes to be assigned, shape(n, 4).
                Only the first four columns are used.
            gt_bboxes (Tensor): Groundtruth boxes, shape (k, 4).
            gt_bboxes_ignore (Tensor, optional): Ground truth bboxes that are
                labelled as `ignored`, e.g., crowd boxes in COCO.
            gt_labels (Tensor, optional): Label of gt_bboxes, shape (k, ).

        Returns:
            :obj:`AssignResult`: The assign result.

        Raises:
            ValueError: If `bboxes` or `gt_bboxes` has no rows.
        """
        if bboxes.shape[0] == 0 or gt_bboxes.shape[0] == 0:
            raise ValueError('No gt or bboxes')
        bboxes = bboxes[:, :4]
        overlaps = bbox_overlaps(bboxes, gt_bboxes)

        if (self.ignore_iof_thr > 0) and (gt_bboxes_ignore is not None) and (
                gt_bboxes_ignore.numel() > 0):
            ignore_overlaps = bbox_overlaps(
                bboxes, gt_bboxes_ignore, mode='iof')
            self._mark_ignored(overlaps, ignore_overlaps)

        return self.assign_wrt_overlaps(overlaps, gt_labels)

    def _mark_ignored(self, overlaps, ignore_overlaps):
        """Set to -1 (in place) the overlap rows of bboxes to be ignored.

        A bbox is ignored when its maximum overlap with any ignored gt bbox
        is strictly greater than `ignore_iof_thr`.

        Args:
            overlaps (Tensor): Overlaps between n bboxes and k gt_bboxes,
                shape (n, k). Modified in place.
            ignore_overlaps (Tensor): Overlaps between the same n bboxes and
                m ignored gt bboxes, shape (n, m).

        Returns:
            Tensor: `overlaps`, for convenience.
        """
        ignore_max_overlaps, _ = ignore_overlaps.max(dim=1)
        # Row-wise boolean mask. The previous nonzero().squeeze() followed by
        # `[:, 0]` indexing failed because squeeze() drops the second dim.
        overlaps[ignore_max_overlaps > self.ignore_iof_thr, :] = -1
        return overlaps

    def assign_wrt_overlaps(self, overlaps, gt_labels=None):
        """Assign w.r.t. the overlaps of bboxes with gts.

        Args:
            overlaps (Tensor): Overlaps between n bboxes and k gt_bboxes,
                shape(n, k).
            gt_labels (Tensor, optional): Labels of k gt_bboxes, shape (k, ).

        Returns:
            :obj:`AssignResult`: The assign result.

        Raises:
            ValueError: If `overlaps` is empty.
        """
        if overlaps.numel() == 0:
            raise ValueError('No gt or proposals')

        num_bboxes, num_gts = overlaps.size(0), overlaps.size(1)

        # 1. assign -1 by default
        assigned_gt_inds = overlaps.new_full(
            (num_bboxes, ), -1, dtype=torch.long)

        assert overlaps.size() == (num_bboxes, num_gts)
        # for each bbox, the max iou over all gts and which gt achieves it
        max_overlaps, argmax_overlaps = overlaps.max(dim=1)
        # for each gt, the max iou over all bboxes
        gt_max_overlaps, _ = overlaps.max(dim=0)

        # 2. assign negative: below the negative threshold (or inside the
        # [low, high) range when a tuple is given). Rows marked -1 (ignored)
        # fail the `>= 0` test and therefore stay at -1.
        # Any other type of neg_iou_thr leaves no bbox assigned as negative.
        if isinstance(self.neg_iou_thr, float):
            assigned_gt_inds[(max_overlaps >= 0)
                             & (max_overlaps < self.neg_iou_thr)] = 0
        elif isinstance(self.neg_iou_thr, tuple):
            assert len(self.neg_iou_thr) == 2
            assigned_gt_inds[(max_overlaps >= self.neg_iou_thr[0])
                             & (max_overlaps < self.neg_iou_thr[1])] = 0

        # 3. assign positive: above positive IoU threshold
        pos_inds = max_overlaps >= self.pos_iou_thr
        assigned_gt_inds[pos_inds] = argmax_overlaps[pos_inds] + 1

        # 4. assign fg: for each gt, proposals with highest IoU
        # (ties are all assigned; later gts overwrite earlier ones)
        for i in range(num_gts):
            if gt_max_overlaps[i] >= self.min_pos_iou:
                assigned_gt_inds[overlaps[:, i] == gt_max_overlaps[i]] = i + 1

        if gt_labels is not None:
            # label 0 for every non-positive bbox, gt label for positives
            assigned_labels = assigned_gt_inds.new_zeros((num_bboxes, ))
            pos_mask = assigned_gt_inds > 0
            assigned_labels[pos_mask] = gt_labels[
                assigned_gt_inds[pos_mask] - 1]
        else:
            assigned_labels = None

        return AssignResult(
            num_gts, assigned_gt_inds, max_overlaps, labels=assigned_labels)
class AssignResult(object):
    """Container for the result of assigning gt bboxes to candidate bboxes.

    Args:
        num_gts (int): Number of gt bboxes the candidates were matched
            against.
        gt_inds (Tensor): Assigned gt index for each bbox.
        max_overlaps (Tensor): One overlap value per bbox, aligned with
            `gt_inds`.
        labels (Tensor, optional): Assigned label for each bbox, or None
            when no gt labels were supplied.
    """

    def __init__(self, num_gts, gt_inds, max_overlaps, labels=None):
        self.num_gts = num_gts
        self.gt_inds = gt_inds
        self.max_overlaps = max_overlaps
        self.labels = labels

    def add_gt_(self, gt_labels):
        """Prepend one entry per gt bbox to the result, in place.

        The i-th prepended entry gets gt index `i` (1-based), overlap 1 and,
        when labels are tracked, label `gt_labels[i - 1]`. `num_gts` is not
        changed.

        Args:
            gt_labels (Tensor or None): Labels of the gt bboxes, shape
                (num_gts, ). Only read when this result carries labels, so
                it may be None otherwise.
        """
        # Size and place the new indices from this result's own state so the
        # call also works when no labels are available and so gt_inds and
        # max_overlaps always grow by the same amount.
        self_inds = torch.arange(
            1,
            self.num_gts + 1,
            dtype=torch.long,
            device=self.gt_inds.device)
        self.gt_inds = torch.cat([self_inds, self.gt_inds])
        self.max_overlaps = torch.cat(
            [self.max_overlaps.new_ones(self.num_gts), self.max_overlaps])
        if self.labels is not None:
            self.labels = torch.cat([gt_labels, self.labels])
...@@ -4,23 +4,23 @@ from .transforms import bbox2delta ...@@ -4,23 +4,23 @@ from .transforms import bbox2delta
from ..utils import multi_apply from ..utils import multi_apply
def bbox_target(pos_proposals_list, def bbox_target(pos_bboxes_list,
neg_proposals_list, neg_bboxes_list,
pos_gt_bboxes_list, pos_gt_bboxes_list,
pos_gt_labels_list, pos_gt_labels_list,
cfg, cfg,
reg_num_classes=1, reg_classes=1,
target_means=[.0, .0, .0, .0], target_means=[.0, .0, .0, .0],
target_stds=[1.0, 1.0, 1.0, 1.0], target_stds=[1.0, 1.0, 1.0, 1.0],
concat=True): concat=True):
labels, label_weights, bbox_targets, bbox_weights = multi_apply( labels, label_weights, bbox_targets, bbox_weights = multi_apply(
proposal_target_single, bbox_target_single,
pos_proposals_list, pos_bboxes_list,
neg_proposals_list, neg_bboxes_list,
pos_gt_bboxes_list, pos_gt_bboxes_list,
pos_gt_labels_list, pos_gt_labels_list,
cfg=cfg, cfg=cfg,
reg_num_classes=reg_num_classes, reg_classes=reg_classes,
target_means=target_means, target_means=target_means,
target_stds=target_stds) target_stds=target_stds)
...@@ -32,34 +32,34 @@ def bbox_target(pos_proposals_list, ...@@ -32,34 +32,34 @@ def bbox_target(pos_proposals_list,
return labels, label_weights, bbox_targets, bbox_weights return labels, label_weights, bbox_targets, bbox_weights
def proposal_target_single(pos_proposals, def bbox_target_single(pos_bboxes,
neg_proposals, neg_bboxes,
pos_gt_bboxes, pos_gt_bboxes,
pos_gt_labels, pos_gt_labels,
cfg, cfg,
reg_num_classes=1, reg_classes=1,
target_means=[.0, .0, .0, .0], target_means=[.0, .0, .0, .0],
target_stds=[1.0, 1.0, 1.0, 1.0]): target_stds=[1.0, 1.0, 1.0, 1.0]):
num_pos = pos_proposals.size(0) num_pos = pos_bboxes.size(0)
num_neg = neg_proposals.size(0) num_neg = neg_bboxes.size(0)
num_samples = num_pos + num_neg num_samples = num_pos + num_neg
labels = pos_proposals.new_zeros(num_samples, dtype=torch.long) labels = pos_bboxes.new_zeros(num_samples, dtype=torch.long)
label_weights = pos_proposals.new_zeros(num_samples) label_weights = pos_bboxes.new_zeros(num_samples)
bbox_targets = pos_proposals.new_zeros(num_samples, 4) bbox_targets = pos_bboxes.new_zeros(num_samples, 4)
bbox_weights = pos_proposals.new_zeros(num_samples, 4) bbox_weights = pos_bboxes.new_zeros(num_samples, 4)
if num_pos > 0: if num_pos > 0:
labels[:num_pos] = pos_gt_labels labels[:num_pos] = pos_gt_labels
pos_weight = 1.0 if cfg.pos_weight <= 0 else cfg.pos_weight pos_weight = 1.0 if cfg.pos_weight <= 0 else cfg.pos_weight
label_weights[:num_pos] = pos_weight label_weights[:num_pos] = pos_weight
pos_bbox_targets = bbox2delta(pos_proposals, pos_gt_bboxes, pos_bbox_targets = bbox2delta(pos_bboxes, pos_gt_bboxes, target_means,
target_means, target_stds) target_stds)
bbox_targets[:num_pos, :] = pos_bbox_targets bbox_targets[:num_pos, :] = pos_bbox_targets
bbox_weights[:num_pos, :] = 1 bbox_weights[:num_pos, :] = 1
if num_neg > 0: if num_neg > 0:
label_weights[-num_neg:] = 1.0 label_weights[-num_neg:] = 1.0
if reg_num_classes > 1: if reg_classes > 1:
bbox_targets, bbox_weights = expand_target(bbox_targets, bbox_weights, bbox_targets, bbox_weights = expand_target(bbox_targets, bbox_weights,
labels, reg_num_classes) labels, reg_classes)
return labels, label_weights, bbox_targets, bbox_weights return labels, label_weights, bbox_targets, bbox_weights
......
import numpy as np import numpy as np
import torch import torch
from .geometry import bbox_overlaps from .assignment import BBoxAssigner
def random_choice(gallery, num): def random_choice(gallery, num):
...@@ -21,158 +21,68 @@ def random_choice(gallery, num): ...@@ -21,158 +21,68 @@ def random_choice(gallery, num):
return gallery[rand_inds] return gallery[rand_inds]
def bbox_assign(proposals, def assign_and_sample(bboxes, gt_bboxes, gt_bboxes_ignore, gt_labels, cfg):
gt_bboxes, bbox_assigner = BBoxAssigner(**cfg.assigner)
gt_bboxes_ignore=None, bbox_sampler = BBoxSampler(**cfg.sampler)
gt_labels=None, assign_result = bbox_assigner.assign(bboxes, gt_bboxes, gt_bboxes_ignore,
pos_iou_thr=0.5, gt_labels)
neg_iou_thr=0.5, sampling_result = bbox_sampler.sample(assign_result, bboxes, gt_bboxes,
min_pos_iou=.0, gt_labels)
crowd_thr=-1): return assign_result, sampling_result
"""Assign a corresponding gt bbox or background to each proposal/anchor.
Each proposals will be assigned with `-1`, `0`, or a positive integer.
- -1: don't care class BBoxSampler(object):
- 0: negative sample, no assigned gt """Sample positive and negative bboxes given assigned results.
- positive integer: positive sample, index (1-based) of assigned gt
If `gt_bboxes_ignore` is specified, bboxes which have iof (intersection
over foreground) with `gt_bboxes_ignore` above `crowd_thr` will be ignored.
Args:
proposals (Tensor): Proposals or RPN anchors, shape (n, 4).
gt_bboxes (Tensor): Ground truth bboxes, shape (k, 4).
gt_bboxes_ignore (Tensor, optional): shape(m, 4).
gt_labels (Tensor, optional): shape (k, ).
pos_iou_thr (float): IoU threshold for positive bboxes.
neg_iou_thr (float or tuple): IoU threshold for negative bboxes.
min_pos_iou (float): Minimum iou for a bbox to be considered as a
positive bbox. For RPN, it is usually set as 0.3, for Fast R-CNN,
it is usually set as pos_iou_thr
crowd_thr (float): IoF threshold for ignoring bboxes. Negative value
for not ignoring any bboxes.
Returns:
tuple: (assigned_gt_inds, argmax_overlaps, max_overlaps), shape (n, )
"""
# calculate overlaps between the proposals and the gt boxes
overlaps = bbox_overlaps(proposals, gt_bboxes)
if overlaps.numel() == 0:
raise ValueError('No gt bbox or proposals')
# ignore proposals according to crowd bboxes
if (crowd_thr > 0) and (gt_bboxes_ignore is
not None) and (gt_bboxes_ignore.numel() > 0):
crowd_overlaps = bbox_overlaps(proposals, gt_bboxes_ignore, mode='iof')
crowd_max_overlaps, _ = crowd_overlaps.max(dim=1)
crowd_bboxes_inds = torch.nonzero(
crowd_max_overlaps > crowd_thr).long()
if crowd_bboxes_inds.numel() > 0:
overlaps[crowd_bboxes_inds, :] = -1
return bbox_assign_wrt_overlaps(overlaps, gt_labels, pos_iou_thr,
neg_iou_thr, min_pos_iou)
def bbox_assign_wrt_overlaps(overlaps,
gt_labels=None,
pos_iou_thr=0.5,
neg_iou_thr=0.5,
min_pos_iou=.0):
"""Assign a corresponding gt bbox or background to each proposal/anchor.
This method assign a gt bbox to every proposal, each proposals will be
assigned with -1, 0, or a positive number. -1 means don't care, 0 means
negative sample, positive number is the index (1-based) of assigned gt.
The assignment is done in following steps, the order matters:
1. assign every anchor to -1
2. assign proposals whose iou with all gts < neg_iou_thr to 0
3. for each anchor, if the iou with its nearest gt >= pos_iou_thr,
assign it to that bbox
4. for each gt bbox, assign its nearest proposals(may be more than one)
to itself
Args: Args:
overlaps (Tensor): Overlaps between n proposals and k gt_bboxes, pos_fraction (float): Positive sample fraction.
shape(n, k). neg_pos_ub (float): Negative/Positive upper bound.
gt_labels (Tensor, optional): Labels of k gt_bboxes, shape (k, ). pos_balance_sampling (bool): Whether to sample positive samples around
pos_iou_thr (float): IoU threshold for positive bboxes. each gt bbox evenly.
neg_iou_thr (float or tuple): IoU threshold for negative bboxes. neg_balance_thr (float, optional): IoU threshold for simple/hard
min_pos_iou (float): Minimum IoU for a bbox to be considered as a negative balance sampling.
positive bbox. This argument only affects the 4th step. neg_hard_fraction (float, optional): Fraction of hard negative samples
for negative balance sampling.
Returns:
tuple: (assigned_gt_inds, [assigned_labels], argmax_overlaps,
max_overlaps), shape (n, )
""" """
num_bboxes, num_gts = overlaps.size(0), overlaps.size(1)
# 1. assign -1 by default
assigned_gt_inds = overlaps.new(num_bboxes).long().fill_(-1)
if overlaps.numel() == 0:
raise ValueError('No gt bbox or proposals')
assert overlaps.size() == (num_bboxes, num_gts)
# for each anchor, which gt best overlaps with it
# for each anchor, the max iou of all gts
max_overlaps, argmax_overlaps = overlaps.max(dim=1)
# for each gt, which anchor best overlaps with it
# for each gt, the max iou of all proposals
gt_max_overlaps, gt_argmax_overlaps = overlaps.max(dim=0)
# 2. assign negative: below
if isinstance(neg_iou_thr, float):
assigned_gt_inds[(max_overlaps >= 0)
& (max_overlaps < neg_iou_thr)] = 0
elif isinstance(neg_iou_thr, tuple):
assert len(neg_iou_thr) == 2
assigned_gt_inds[(max_overlaps >= neg_iou_thr[0])
& (max_overlaps < neg_iou_thr[1])] = 0
# 3. assign positive: above positive IoU threshold def __init__(self,
pos_inds = max_overlaps >= pos_iou_thr num,
assigned_gt_inds[pos_inds] = argmax_overlaps[pos_inds] + 1 pos_fraction,
neg_pos_ub=-1,
# 4. assign fg: for each gt, proposals with highest IoU add_gt_as_proposals=True,
for i in range(num_gts): pos_balance_sampling=False,
if gt_max_overlaps[i] >= min_pos_iou: neg_balance_thr=0,
assigned_gt_inds[overlaps[:, i] == gt_max_overlaps[i]] = i + 1 neg_hard_fraction=0.5):
self.num = num
if gt_labels is None: self.pos_fraction = pos_fraction
return assigned_gt_inds, argmax_overlaps, max_overlaps self.neg_pos_ub = neg_pos_ub
else: self.add_gt_as_proposals = add_gt_as_proposals
assigned_labels = assigned_gt_inds.new(num_bboxes).fill_(0) self.pos_balance_sampling = pos_balance_sampling
pos_inds = torch.nonzero(assigned_gt_inds > 0).squeeze() self.neg_balance_thr = neg_balance_thr
if pos_inds.numel() > 0: self.neg_hard_fraction = neg_hard_fraction
assigned_labels[pos_inds] = gt_labels[assigned_gt_inds[pos_inds] -
1] def _sample_pos(self, assign_result, num_expected):
return assigned_gt_inds, assigned_labels, argmax_overlaps, max_overlaps
def bbox_sampling_pos(assigned_gt_inds, num_expected, balance_sampling=True):
"""Balance sampling for positive bboxes/anchors. """Balance sampling for positive bboxes/anchors.
1. calculate average positive num for each gt: num_per_gt 1. calculate average positive num for each gt: num_per_gt
2. sample at most num_per_gt positives for each gt 2. sample at most num_per_gt positives for each gt
3. random sampling from rest anchors if not enough fg 3. random sampling from rest anchors if not enough fg
""" """
pos_inds = torch.nonzero(assigned_gt_inds > 0) pos_inds = torch.nonzero(assign_result.gt_inds > 0)
if pos_inds.numel() != 0: if pos_inds.numel() != 0:
pos_inds = pos_inds.squeeze(1) pos_inds = pos_inds.squeeze(1)
if pos_inds.numel() <= num_expected: if pos_inds.numel() <= num_expected:
return pos_inds return pos_inds
elif not balance_sampling: elif not self.pos_balance_sampling:
return random_choice(pos_inds, num_expected) return random_choice(pos_inds, num_expected)
else: else:
unique_gt_inds = torch.unique(assigned_gt_inds[pos_inds].cpu()) unique_gt_inds = torch.unique(
assign_result.gt_inds[pos_inds].cpu())
num_gts = len(unique_gt_inds) num_gts = len(unique_gt_inds)
num_per_gt = int(round(num_expected / float(num_gts)) + 1) num_per_gt = int(round(num_expected / float(num_gts)) + 1)
sampled_inds = [] sampled_inds = []
for i in unique_gt_inds: for i in unique_gt_inds:
inds = torch.nonzero(assigned_gt_inds == i.item()) inds = torch.nonzero(assign_result.gt_inds == i.item())
if inds.numel() != 0: if inds.numel() != 0:
inds = inds.squeeze(1) inds = inds.squeeze(1)
else: else:
...@@ -188,56 +98,53 @@ def bbox_sampling_pos(assigned_gt_inds, num_expected, balance_sampling=True): ...@@ -188,56 +98,53 @@ def bbox_sampling_pos(assigned_gt_inds, num_expected, balance_sampling=True):
if len(extra_inds) > num_extra: if len(extra_inds) > num_extra:
extra_inds = random_choice(extra_inds, num_extra) extra_inds = random_choice(extra_inds, num_extra)
extra_inds = torch.from_numpy(extra_inds).to( extra_inds = torch.from_numpy(extra_inds).to(
assigned_gt_inds.device).long() assign_result.gt_inds.device).long()
sampled_inds = torch.cat([sampled_inds, extra_inds]) sampled_inds = torch.cat([sampled_inds, extra_inds])
elif len(sampled_inds) > num_expected: elif len(sampled_inds) > num_expected:
sampled_inds = random_choice(sampled_inds, num_expected) sampled_inds = random_choice(sampled_inds, num_expected)
return sampled_inds return sampled_inds
def _sample_neg(self, assign_result, num_expected):
def bbox_sampling_neg(assigned_gt_inds,
num_expected,
max_overlaps=None,
balance_thr=0,
hard_fraction=0.5):
"""Balance sampling for negative bboxes/anchors. """Balance sampling for negative bboxes/anchors.
Negative samples are split into 2 set: hard (balance_thr <= iou < Negative samples are split into 2 set: hard (balance_thr <= iou <
neg_iou_thr) and easy(iou < balance_thr). The sampling ratio is controlled neg_iou_thr) and easy (iou < balance_thr). The sampling ratio is
by `hard_fraction`. controlled by `hard_fraction`.
""" """
neg_inds = torch.nonzero(assigned_gt_inds == 0) neg_inds = torch.nonzero(assign_result.gt_inds == 0)
if neg_inds.numel() != 0: if neg_inds.numel() != 0:
neg_inds = neg_inds.squeeze(1) neg_inds = neg_inds.squeeze(1)
if len(neg_inds) <= num_expected: if len(neg_inds) <= num_expected:
return neg_inds return neg_inds
elif balance_thr <= 0: elif self.neg_balance_thr <= 0:
# uniform sampling among all negative samples # uniform sampling among all negative samples
return random_choice(neg_inds, num_expected) return random_choice(neg_inds, num_expected)
else: else:
assert max_overlaps is not None max_overlaps = assign_result.max_overlaps.cpu().numpy()
max_overlaps = max_overlaps.cpu().numpy()
# balance sampling for negative samples # balance sampling for negative samples
neg_set = set(neg_inds.cpu().numpy()) neg_set = set(neg_inds.cpu().numpy())
easy_set = set( easy_set = set(
np.where( np.where(
np.logical_and(max_overlaps >= 0, np.logical_and(max_overlaps >= 0,
max_overlaps < balance_thr))[0]) max_overlaps < self.neg_balance_thr))[0])
hard_set = set(np.where(max_overlaps >= balance_thr)[0]) hard_set = set(np.where(max_overlaps >= self.neg_balance_thr)[0])
easy_neg_inds = list(easy_set & neg_set) easy_neg_inds = list(easy_set & neg_set)
hard_neg_inds = list(hard_set & neg_set) hard_neg_inds = list(hard_set & neg_set)
num_expected_hard = int(num_expected * hard_fraction) num_expected_hard = int(num_expected * self.neg_hard_fraction)
if len(hard_neg_inds) > num_expected_hard: if len(hard_neg_inds) > num_expected_hard:
sampled_hard_inds = random_choice(hard_neg_inds, num_expected_hard) sampled_hard_inds = random_choice(hard_neg_inds,
num_expected_hard)
else: else:
sampled_hard_inds = np.array(hard_neg_inds, dtype=np.int) sampled_hard_inds = np.array(hard_neg_inds, dtype=np.int)
num_expected_easy = num_expected - len(sampled_hard_inds) num_expected_easy = num_expected - len(sampled_hard_inds)
if len(easy_neg_inds) > num_expected_easy: if len(easy_neg_inds) > num_expected_easy:
sampled_easy_inds = random_choice(easy_neg_inds, num_expected_easy) sampled_easy_inds = random_choice(easy_neg_inds,
num_expected_easy)
else: else:
sampled_easy_inds = np.array(easy_neg_inds, dtype=np.int) sampled_easy_inds = np.array(easy_neg_inds, dtype=np.int)
sampled_inds = np.concatenate((sampled_easy_inds, sampled_hard_inds)) sampled_inds = np.concatenate((sampled_easy_inds,
sampled_hard_inds))
if len(sampled_inds) < num_expected: if len(sampled_inds) < num_expected:
num_extra = num_expected - len(sampled_inds) num_extra = num_expected - len(sampled_inds)
extra_inds = np.array(list(neg_set - set(sampled_inds))) extra_inds = np.array(list(neg_set - set(sampled_inds)))
...@@ -245,99 +152,76 @@ def bbox_sampling_neg(assigned_gt_inds, ...@@ -245,99 +152,76 @@ def bbox_sampling_neg(assigned_gt_inds,
extra_inds = random_choice(extra_inds, num_extra) extra_inds = random_choice(extra_inds, num_extra)
sampled_inds = np.concatenate((sampled_inds, extra_inds)) sampled_inds = np.concatenate((sampled_inds, extra_inds))
sampled_inds = torch.from_numpy(sampled_inds).long().to( sampled_inds = torch.from_numpy(sampled_inds).long().to(
assigned_gt_inds.device) assign_result.gt_inds.device)
return sampled_inds return sampled_inds
def sample(self, assign_result, bboxes, gt_bboxes, gt_labels=None):
def bbox_sampling(assigned_gt_inds,
num_expected,
pos_fraction,
neg_pos_ub,
pos_balance_sampling=True,
max_overlaps=None,
neg_balance_thr=0,
neg_hard_fraction=0.5):
"""Sample positive and negative bboxes given assigned results.
Args:
assigned_gt_inds (Tensor): Assigned gt indices for each bbox.
num_expected (int): Expected total samples (pos and neg).
pos_fraction (float): Positive sample fraction.
neg_pos_ub (float): Negative/Positive upper bound.
pos_balance_sampling(bool): Whether to sample positive samples around
each gt bbox evenly.
max_overlaps (Tensor, optional): For each bbox, the max IoU of all gts.
Used for negative balance sampling only.
neg_balance_thr (float, optional): IoU threshold for simple/hard
negative balance sampling.
neg_hard_fraction (float, optional): Fraction of hard negative samples
for negative balance sampling.
Returns:
tuple[Tensor]: positive bbox indices, negative bbox indices.
"""
num_expected_pos = int(num_expected * pos_fraction)
pos_inds = bbox_sampling_pos(assigned_gt_inds, num_expected_pos,
pos_balance_sampling)
# We found that sampled indices have duplicated items occasionally.
# (mab be a bug of PyTorch)
pos_inds = pos_inds.unique()
num_sampled_pos = pos_inds.numel()
num_neg_max = int(
neg_pos_ub *
num_sampled_pos) if num_sampled_pos > 0 else int(neg_pos_ub)
num_expected_neg = min(num_neg_max, num_expected - num_sampled_pos)
neg_inds = bbox_sampling_neg(assigned_gt_inds, num_expected_neg,
max_overlaps, neg_balance_thr,
neg_hard_fraction)
neg_inds = neg_inds.unique()
return pos_inds, neg_inds
def sample_bboxes(bboxes, gt_bboxes, gt_bboxes_ignore, gt_labels, cfg):
"""Sample positive and negative bboxes. """Sample positive and negative bboxes.
This is a simple implementation of bbox sampling given candidates and This is a simple implementation of bbox sampling given candidates,
ground truth bboxes, which includes 3 steps. assigning results and ground truth bboxes.
1. Assign gt to each bbox. 1. Assign gt to each bbox.
2. Add gt bboxes to the sampling pool (optional). 2. Add gt bboxes to the sampling pool (optional).
3. Perform positive and negative sampling. 3. Perform positive and negative sampling.
Args: Args:
assign_result (:obj:`AssignResult`): Bbox assigning results.
bboxes (Tensor): Boxes to be sampled from. bboxes (Tensor): Boxes to be sampled from.
gt_bboxes (Tensor): Ground truth bboxes. gt_bboxes (Tensor): Ground truth bboxes.
gt_bboxes_ignore (Tensor): Ignored ground truth bboxes. In MS COCO, gt_labels (Tensor, optional): Class labels of ground truth bboxes.
`crowd` bboxes are considered as ignored.
gt_labels (Tensor): Class labels of ground truth bboxes.
cfg (dict): Sampling configs.
Returns: Returns:
tuple[Tensor]: pos_bboxes, neg_bboxes, pos_assigned_gt_inds, :obj:`SamplingResult`: Sampling result.
pos_gt_bboxes, pos_gt_labels
""" """
bboxes = bboxes[:, :4] bboxes = bboxes[:, :4]
assigned_gt_inds, assigned_labels, argmax_overlaps, max_overlaps = \
bbox_assign(bboxes, gt_bboxes, gt_bboxes_ignore, gt_labels,
cfg.pos_iou_thr, cfg.neg_iou_thr, cfg.min_pos_iou,
cfg.crowd_thr)
if cfg.add_gt_as_proposals: gt_flags = bboxes.new_zeros((bboxes.shape[0], ), dtype=torch.uint8)
if self.add_gt_as_proposals:
bboxes = torch.cat([gt_bboxes, bboxes], dim=0) bboxes = torch.cat([gt_bboxes, bboxes], dim=0)
gt_assign_self = torch.arange( assign_result.add_gt_(gt_labels)
1, len(gt_labels) + 1, dtype=torch.long, device=bboxes.device) gt_flags = torch.cat([
assigned_gt_inds = torch.cat([gt_assign_self, assigned_gt_inds]) bboxes.new_ones((gt_bboxes.shape[0], ), dtype=torch.uint8),
assigned_labels = torch.cat([gt_labels, assigned_labels]) gt_flags
])
num_expected_pos = int(self.num * self.pos_fraction)
pos_inds = self._sample_pos(assign_result, num_expected_pos)
# We found that sampled indices have duplicated items occasionally.
# (may be a bug of PyTorch)
pos_inds = pos_inds.unique()
num_sampled_pos = pos_inds.numel()
num_expected_neg = self.num - num_sampled_pos
if self.neg_pos_ub >= 0:
num_neg_max = int(self.neg_pos_ub *
num_sampled_pos) if num_sampled_pos > 0 else int(
self.neg_pos_ub)
num_expected_neg = min(num_neg_max, num_expected_neg)
neg_inds = self._sample_neg(assign_result, num_expected_neg)
neg_inds = neg_inds.unique()
pos_inds, neg_inds = bbox_sampling( return SamplingResult(pos_inds, neg_inds, bboxes, gt_bboxes,
assigned_gt_inds, cfg.roi_batch_size, cfg.pos_fraction, cfg.neg_pos_ub, assign_result, gt_flags)
cfg.pos_balance_sampling, max_overlaps, cfg.neg_balance_thr)
pos_bboxes = bboxes[pos_inds]
neg_bboxes = bboxes[neg_inds]
pos_assigned_gt_inds = assigned_gt_inds[pos_inds] - 1
pos_gt_bboxes = gt_bboxes[pos_assigned_gt_inds, :]
pos_gt_labels = assigned_labels[pos_inds]
return (pos_bboxes, neg_bboxes, pos_assigned_gt_inds, pos_gt_bboxes, class SamplingResult(object):
pos_gt_labels)
def __init__(self, pos_inds, neg_inds, bboxes, gt_bboxes, assign_result,
             gt_flags):
    """Collect the sampled boxes together with their matched ground truth.

    Args:
        pos_inds: Indices into ``bboxes`` of the positive samples.
        neg_inds: Indices into ``bboxes`` of the negative samples.
        bboxes: Candidate boxes that were sampled from.
        gt_bboxes: Ground truth boxes; its first dimension gives
            ``num_gts``.
        assign_result: Object exposing ``gt_inds`` and ``labels``
            (``labels`` may be None).
        gt_flags: Per-box flags, indexed by ``pos_inds`` to fill
            ``pos_is_gt``.
    """
    self.pos_inds = pos_inds
    self.neg_inds = neg_inds

    # Slice the candidate pool into its positive / negative parts.
    self.pos_bboxes = bboxes[pos_inds]
    self.neg_bboxes = bboxes[neg_inds]
    self.pos_is_gt = gt_flags[pos_inds]

    self.num_gts = gt_bboxes.shape[0]
    # The stored gt indices are shifted down by one before being used to
    # index gt_bboxes (presumably they are 1-based -- confirm against the
    # assigner).
    matched_gt = assign_result.gt_inds[pos_inds] - 1
    self.pos_assigned_gt_inds = matched_gt
    self.pos_gt_bboxes = gt_bboxes[matched_gt, :]

    # Labels are optional on the assignment result.
    labels = assign_result.labels
    self.pos_gt_labels = None if labels is None else labels[pos_inds]
@property
def bboxes(self):
    """Return the positive boxes followed by the negative boxes."""
    sampled = [self.pos_bboxes, self.neg_bboxes]
    return torch.cat(sampled, dim=0)
...@@ -215,7 +215,7 @@ class CocoDataset(Dataset): ...@@ -215,7 +215,7 @@ class CocoDataset(Dataset):
'proposals should have shapes (n, 4) or (n, 5), ' 'proposals should have shapes (n, 4) or (n, 5), '
'but found {}'.format(proposals.shape)) 'but found {}'.format(proposals.shape))
if proposals.shape[1] == 5: if proposals.shape[1] == 5:
scores = proposals[:, 4] scores = proposals[:, 4, None]
proposals = proposals[:, :4] proposals = proposals[:, :4]
else: else:
scores = None scores = None
...@@ -237,8 +237,8 @@ class CocoDataset(Dataset): ...@@ -237,8 +237,8 @@ class CocoDataset(Dataset):
if self.proposals is not None: if self.proposals is not None:
proposals = self.bbox_transform(proposals, img_shape, proposals = self.bbox_transform(proposals, img_shape,
scale_factor, flip) scale_factor, flip)
proposals = np.hstack([proposals, scores[:, None] proposals = np.hstack(
]) if scores is not None else proposals [proposals, scores]) if scores is not None else proposals
gt_bboxes = self.bbox_transform(gt_bboxes, img_shape, scale_factor, gt_bboxes = self.bbox_transform(gt_bboxes, img_shape, scale_factor,
flip) flip)
gt_bboxes_ignore = self.bbox_transform(gt_bboxes_ignore, img_shape, gt_bboxes_ignore = self.bbox_transform(gt_bboxes_ignore, img_shape,
...@@ -295,14 +295,14 @@ class CocoDataset(Dataset): ...@@ -295,14 +295,14 @@ class CocoDataset(Dataset):
flip=flip) flip=flip)
if proposal is not None: if proposal is not None:
if proposal.shape[1] == 5: if proposal.shape[1] == 5:
score = proposal[:, 4] score = proposal[:, 4, None]
proposal = proposal[:, :4] proposal = proposal[:, :4]
else: else:
score = None score = None
_proposal = self.bbox_transform(proposal, img_shape, _proposal = self.bbox_transform(proposal, img_shape,
scale_factor, flip) scale_factor, flip)
_proposal = np.hstack([_proposal, score[:, None] _proposal = np.hstack(
]) if score is not None else _proposal [_proposal, score]) if score is not None else _proposal
_proposal = to_tensor(_proposal) _proposal = to_tensor(_proposal)
else: else:
_proposal = None _proposal = None
......
...@@ -59,16 +59,20 @@ class BBoxHead(nn.Module): ...@@ -59,16 +59,20 @@ class BBoxHead(nn.Module):
bbox_pred = self.fc_reg(x) if self.with_reg else None bbox_pred = self.fc_reg(x) if self.with_reg else None
return cls_score, bbox_pred return cls_score, bbox_pred
def get_bbox_target(self, pos_proposals, neg_proposals, pos_gt_bboxes, def get_target(self, sampling_results, gt_bboxes, gt_labels,
pos_gt_labels, rcnn_train_cfg): rcnn_train_cfg):
reg_num_classes = 1 if self.reg_class_agnostic else self.num_classes pos_proposals = [res.pos_bboxes for res in sampling_results]
neg_proposals = [res.neg_bboxes for res in sampling_results]
pos_gt_bboxes = [res.pos_gt_bboxes for res in sampling_results]
pos_gt_labels = [res.pos_gt_labels for res in sampling_results]
reg_classes = 1 if self.reg_class_agnostic else self.num_classes
cls_reg_targets = bbox_target( cls_reg_targets = bbox_target(
pos_proposals, pos_proposals,
neg_proposals, neg_proposals,
pos_gt_bboxes, pos_gt_bboxes,
pos_gt_labels, pos_gt_labels,
rcnn_train_cfg, rcnn_train_cfg,
reg_num_classes, reg_classes,
target_means=self.target_means, target_means=self.target_means,
target_stds=self.target_stds) target_stds=self.target_stds)
return cls_reg_targets return cls_reg_targets
......
...@@ -4,7 +4,7 @@ import torch.nn as nn ...@@ -4,7 +4,7 @@ import torch.nn as nn
from .base import BaseDetector from .base import BaseDetector
from .test_mixins import RPNTestMixin, BBoxTestMixin, MaskTestMixin from .test_mixins import RPNTestMixin, BBoxTestMixin, MaskTestMixin
from .. import builder from .. import builder
from mmdet.core import sample_bboxes, bbox2roi, bbox2result, multi_apply from mmdet.core import (assign_and_sample, bbox2roi, bbox2result, multi_apply)
class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin, class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
...@@ -80,10 +80,11 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin, ...@@ -80,10 +80,11 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
gt_labels, gt_labels,
gt_masks=None, gt_masks=None,
proposals=None): proposals=None):
losses = dict()
x = self.extract_feat(img) x = self.extract_feat(img)
losses = dict()
# RPN forward and loss
if self.with_rpn: if self.with_rpn:
rpn_outs = self.rpn_head(x) rpn_outs = self.rpn_head(x)
rpn_loss_inputs = rpn_outs + (gt_bboxes, img_meta, rpn_loss_inputs = rpn_outs + (gt_bboxes, img_meta,
...@@ -96,44 +97,43 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin, ...@@ -96,44 +97,43 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
else: else:
proposal_list = proposals proposal_list = proposals
if self.with_bbox: # assign gts and sample proposals
(pos_proposals, neg_proposals, pos_assigned_gt_inds, pos_gt_bboxes, if self.with_bbox or self.with_mask:
pos_gt_labels) = multi_apply( assign_results, sampling_results = multi_apply(
sample_bboxes, assign_and_sample,
proposal_list, proposal_list,
gt_bboxes, gt_bboxes,
gt_bboxes_ignore, gt_bboxes_ignore,
gt_labels, gt_labels,
cfg=self.train_cfg.rcnn) cfg=self.train_cfg.rcnn)
(labels, label_weights, bbox_targets,
bbox_weights) = self.bbox_head.get_bbox_target( # bbox head forward and loss
pos_proposals, neg_proposals, pos_gt_bboxes, pos_gt_labels, if self.with_bbox:
self.train_cfg.rcnn) rois = bbox2roi([res.bboxes for res in sampling_results])
# TODO: a more flexible way to decide which feature maps to use
rois = bbox2roi([ bbox_feats = self.bbox_roi_extractor(
torch.cat([pos, neg], dim=0)
for pos, neg in zip(pos_proposals, neg_proposals)
])
# TODO: a more flexible way to configurate feat maps
roi_feats = self.bbox_roi_extractor(
x[:self.bbox_roi_extractor.num_inputs], rois) x[:self.bbox_roi_extractor.num_inputs], rois)
cls_score, bbox_pred = self.bbox_head(roi_feats) cls_score, bbox_pred = self.bbox_head(bbox_feats)
loss_bbox = self.bbox_head.loss(cls_score, bbox_pred, labels, bbox_targets = self.bbox_head.get_target(
label_weights, bbox_targets, sampling_results, gt_bboxes, gt_labels, self.train_cfg.rcnn)
bbox_weights) loss_bbox = self.bbox_head.loss(cls_score, bbox_pred,
*bbox_targets)
losses.update(loss_bbox) losses.update(loss_bbox)
# mask head forward and loss
if self.with_mask: if self.with_mask:
mask_targets = self.mask_head.get_mask_target( pos_rois = bbox2roi([res.pos_bboxes for res in sampling_results])
pos_proposals, pos_assigned_gt_inds, gt_masks,
self.train_cfg.rcnn)
pos_rois = bbox2roi(pos_proposals)
mask_feats = self.mask_roi_extractor( mask_feats = self.mask_roi_extractor(
x[:self.mask_roi_extractor.num_inputs], pos_rois) x[:self.mask_roi_extractor.num_inputs], pos_rois)
mask_pred = self.mask_head(mask_feats) mask_pred = self.mask_head(mask_feats)
mask_targets = self.mask_head.get_target(
sampling_results, gt_masks, self.train_cfg.rcnn)
pos_labels = torch.cat(
[res.pos_gt_labels for res in sampling_results])
loss_mask = self.mask_head.loss(mask_pred, mask_targets, loss_mask = self.mask_head.loss(mask_pred, mask_targets,
torch.cat(pos_gt_labels)) pos_labels)
losses.update(loss_mask) losses.update(loss_mask)
return losses return losses
...@@ -145,8 +145,7 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin, ...@@ -145,8 +145,7 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
x = self.extract_feat(img) x = self.extract_feat(img)
proposal_list = self.simple_test_rpn( proposal_list = self.simple_test_rpn(
x, img_meta, x, img_meta, self.test_cfg.rpn) if proposals is None else proposals
self.test_cfg.rpn) if proposals is None else proposals
det_bboxes, det_labels = self.simple_test_bboxes( det_bboxes, det_labels = self.simple_test_bboxes(
x, img_meta, proposal_list, self.test_cfg.rcnn, rescale=rescale) x, img_meta, proposal_list, self.test_cfg.rcnn, rescale=rescale)
......
...@@ -86,8 +86,11 @@ class FCNMaskHead(nn.Module): ...@@ -86,8 +86,11 @@ class FCNMaskHead(nn.Module):
mask_pred = self.conv_logits(x) mask_pred = self.conv_logits(x)
return mask_pred return mask_pred
def get_mask_target(self, pos_proposals, pos_assigned_gt_inds, gt_masks, def get_target(self, sampling_results, gt_masks, rcnn_train_cfg):
rcnn_train_cfg): pos_proposals = [res.pos_bboxes for res in sampling_results]
pos_assigned_gt_inds = [
res.pos_assigned_gt_inds for res in sampling_results
]
mask_targets = mask_target(pos_proposals, pos_assigned_gt_inds, mask_targets = mask_target(pos_proposals, pos_assigned_gt_inds,
gt_masks, rcnn_train_cfg) gt_masks, rcnn_train_cfg)
return mask_targets return mask_targets
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment