Commit bac11303 authored by Kai Chen
Browse files

add BBoxAssigner and BBoxSampler

parent f3768bcd
...@@ -43,17 +43,19 @@ model = dict( ...@@ -43,17 +43,19 @@ model = dict(
# model training and testing settings # model training and testing settings
train_cfg = dict( train_cfg = dict(
rcnn=dict( rcnn=dict(
assigner=dict(
pos_iou_thr=0.5,
neg_iou_thr=0.5,
min_pos_iou=0.5,
ignore_iof_thr=-1),
sampler=dict(
num=512,
pos_fraction=0.25,
neg_pos_ub=-1,
add_gt_as_proposals=True,
pos_balance_sampling=False,
neg_balance_thr=0),
mask_size=28, mask_size=28,
pos_iou_thr=0.5,
neg_iou_thr=0.5,
crowd_thr=1.1,
roi_batch_size=512,
add_gt_as_proposals=True,
pos_fraction=0.25,
pos_balance_sampling=False,
neg_pos_ub=512,
neg_balance_thr=0,
min_pos_iou=0.5,
pos_weight=-1, pos_weight=-1,
debug=False)) debug=False))
test_cfg = dict( test_cfg = dict(
......
...@@ -32,16 +32,18 @@ model = dict( ...@@ -32,16 +32,18 @@ model = dict(
# model training and testing settings # model training and testing settings
train_cfg = dict( train_cfg = dict(
rcnn=dict( rcnn=dict(
pos_iou_thr=0.5, assigner=dict(
neg_iou_thr=0.5, pos_iou_thr=0.5,
crowd_thr=1.1, neg_iou_thr=0.5,
roi_batch_size=512, min_pos_iou=0.5,
add_gt_as_proposals=True, ignore_iof_thr=-1),
pos_fraction=0.25, sampler=dict(
pos_balance_sampling=False, num=512,
neg_pos_ub=512, pos_fraction=0.25,
neg_balance_thr=0, neg_pos_ub=-1,
min_pos_iou=0.5, add_gt_as_proposals=True,
pos_balance_sampling=False,
neg_balance_thr=0),
pos_weight=-1, pos_weight=-1,
debug=False)) debug=False))
test_cfg = dict(rcnn=dict(score_thr=0.05, max_per_img=100, nms_thr=0.5)) test_cfg = dict(rcnn=dict(score_thr=0.05, max_per_img=100, nms_thr=0.5))
......
...@@ -42,30 +42,35 @@ model = dict( ...@@ -42,30 +42,35 @@ model = dict(
# model training and testing settings # model training and testing settings
train_cfg = dict( train_cfg = dict(
rpn=dict( rpn=dict(
pos_fraction=0.5, assigner=dict(
pos_balance_sampling=False, pos_iou_thr=0.7,
neg_pos_ub=256, neg_iou_thr=0.3,
min_pos_iou=0.3,
ignore_iof_thr=-1),
sampler=dict(
num=256,
pos_fraction=0.5,
neg_pos_ub=-1,
add_gt_as_proposals=False,
pos_balance_sampling=False,
neg_balance_thr=0),
allowed_border=0, allowed_border=0,
crowd_thr=1.1,
anchor_batch_size=256,
pos_iou_thr=0.7,
neg_iou_thr=0.3,
neg_balance_thr=0,
min_pos_iou=0.3,
pos_weight=-1, pos_weight=-1,
smoothl1_beta=1 / 9.0, smoothl1_beta=1 / 9.0,
debug=False), debug=False),
rcnn=dict( rcnn=dict(
pos_iou_thr=0.5, assigner=dict(
neg_iou_thr=0.5, pos_iou_thr=0.5,
crowd_thr=1.1, neg_iou_thr=0.5,
roi_batch_size=512, min_pos_iou=0.5,
add_gt_as_proposals=True, ignore_iof_thr=-1),
pos_fraction=0.25, sampler=dict(
pos_balance_sampling=False, num=512,
neg_pos_ub=512, pos_fraction=0.25,
neg_balance_thr=0, neg_pos_ub=-1,
min_pos_iou=0.5, add_gt_as_proposals=True,
pos_balance_sampling=False,
neg_balance_thr=0),
pos_weight=-1, pos_weight=-1,
debug=False)) debug=False))
test_cfg = dict( test_cfg = dict(
......
...@@ -53,31 +53,36 @@ model = dict( ...@@ -53,31 +53,36 @@ model = dict(
# model training and testing settings # model training and testing settings
train_cfg = dict( train_cfg = dict(
rpn=dict( rpn=dict(
pos_fraction=0.5, assigner=dict(
pos_balance_sampling=False, pos_iou_thr=0.7,
neg_pos_ub=256, neg_iou_thr=0.3,
min_pos_iou=0.3,
ignore_iof_thr=-1),
sampler=dict(
num=256,
pos_fraction=0.5,
neg_pos_ub=-1,
add_gt_as_proposals=False,
pos_balance_sampling=False,
neg_balance_thr=0),
allowed_border=0, allowed_border=0,
crowd_thr=1.1,
anchor_batch_size=256,
pos_iou_thr=0.7,
neg_iou_thr=0.3,
neg_balance_thr=0,
min_pos_iou=0.3,
pos_weight=-1, pos_weight=-1,
smoothl1_beta=1 / 9.0, smoothl1_beta=1 / 9.0,
debug=False), debug=False),
rcnn=dict( rcnn=dict(
assigner=dict(
pos_iou_thr=0.5,
neg_iou_thr=0.5,
min_pos_iou=0.5,
ignore_iof_thr=-1),
sampler=dict(
num=512,
pos_fraction=0.25,
neg_pos_ub=-1,
add_gt_as_proposals=True,
pos_balance_sampling=False,
neg_balance_thr=0),
mask_size=28, mask_size=28,
pos_iou_thr=0.5,
neg_iou_thr=0.5,
crowd_thr=1.1,
roi_batch_size=512,
add_gt_as_proposals=True,
pos_fraction=0.25,
pos_balance_sampling=False,
neg_pos_ub=512,
neg_balance_thr=0,
min_pos_iou=0.5,
pos_weight=-1, pos_weight=-1,
debug=False)) debug=False))
test_cfg = dict( test_cfg = dict(
......
...@@ -27,16 +27,19 @@ model = dict( ...@@ -27,16 +27,19 @@ model = dict(
# model training and testing settings # model training and testing settings
train_cfg = dict( train_cfg = dict(
rpn=dict( rpn=dict(
pos_fraction=0.5, assigner=dict(
pos_balance_sampling=False, pos_iou_thr=0.7,
neg_pos_ub=256, neg_iou_thr=0.3,
min_pos_iou=0.3,
ignore_iof_thr=-1),
sampler=dict(
num=256,
pos_fraction=0.5,
neg_pos_ub=-1,
add_gt_as_proposals=False,
pos_balance_sampling=False,
neg_balance_thr=0),
allowed_border=0, allowed_border=0,
crowd_thr=1.1,
anchor_batch_size=256,
pos_iou_thr=0.7,
neg_iou_thr=0.3,
neg_balance_thr=0,
min_pos_iou=0.3,
pos_weight=-1, pos_weight=-1,
smoothl1_beta=1 / 9.0, smoothl1_beta=1 / 9.0,
debug=False)) debug=False))
......
import torch import torch
from ..bbox import bbox_assign, bbox2delta, bbox_sampling from ..bbox import assign_and_sample, bbox2delta
from ..utils import multi_apply from ..utils import multi_apply
...@@ -80,27 +80,20 @@ def anchor_target_single(flat_anchors, valid_flags, gt_bboxes, img_meta, ...@@ -80,27 +80,20 @@ def anchor_target_single(flat_anchors, valid_flags, gt_bboxes, img_meta,
return (None, ) * 6 return (None, ) * 6
# assign gt and sample anchors # assign gt and sample anchors
anchors = flat_anchors[inside_flags, :] anchors = flat_anchors[inside_flags, :]
assigned_gt_inds, argmax_overlaps, max_overlaps = bbox_assign( _, sampling_result = assign_and_sample(anchors, gt_bboxes, None, None, cfg)
anchors,
gt_bboxes,
pos_iou_thr=cfg.pos_iou_thr,
neg_iou_thr=cfg.neg_iou_thr,
min_pos_iou=cfg.min_pos_iou)
pos_inds, neg_inds = bbox_sampling(assigned_gt_inds, cfg.anchor_batch_size,
cfg.pos_fraction, cfg.neg_pos_ub,
cfg.pos_balance_sampling, max_overlaps,
cfg.neg_balance_thr)
num_valid_anchors = anchors.shape[0]
bbox_targets = torch.zeros_like(anchors) bbox_targets = torch.zeros_like(anchors)
bbox_weights = torch.zeros_like(anchors) bbox_weights = torch.zeros_like(anchors)
labels = torch.zeros_like(assigned_gt_inds) labels = anchors.new_zeros((num_valid_anchors, ))
label_weights = torch.zeros_like(assigned_gt_inds, dtype=anchors.dtype) label_weights = anchors.new_zeros((num_valid_anchors, ))
pos_inds = sampling_result.pos_inds
neg_inds = sampling_result.neg_inds
if len(pos_inds) > 0: if len(pos_inds) > 0:
pos_anchors = anchors[pos_inds, :] pos_bbox_targets = bbox2delta(sampling_result.pos_bboxes,
pos_gt_bbox = gt_bboxes[assigned_gt_inds[pos_inds] - 1, :] sampling_result.pos_gt_bboxes,
pos_bbox_targets = bbox2delta(pos_anchors, pos_gt_bbox, target_means, target_means, target_stds)
target_stds)
bbox_targets[pos_inds, :] = pos_bbox_targets bbox_targets[pos_inds, :] = pos_bbox_targets
bbox_weights[pos_inds, :] = 1.0 bbox_weights[pos_inds, :] = 1.0
labels[pos_inds] = 1 labels[pos_inds] = 1
......
from .geometry import bbox_overlaps from .geometry import bbox_overlaps
from .sampling import (random_choice, bbox_assign, bbox_assign_wrt_overlaps, from .assignment import BBoxAssigner, AssignResult
bbox_sampling, bbox_sampling_pos, bbox_sampling_neg, from .sampling import (BBoxSampler, SamplingResult, assign_and_sample,
sample_bboxes) random_choice)
from .transforms import (bbox2delta, delta2bbox, bbox_flip, bbox_mapping, from .transforms import (bbox2delta, delta2bbox, bbox_flip, bbox_mapping,
bbox_mapping_back, bbox2roi, roi2bbox, bbox2result) bbox_mapping_back, bbox2roi, roi2bbox, bbox2result)
from .bbox_target import bbox_target from .bbox_target import bbox_target
__all__ = [ __all__ = [
'bbox_overlaps', 'random_choice', 'bbox_assign', 'bbox_overlaps', 'BBoxAssigner', 'AssignResult', 'BBoxSampler',
'bbox_assign_wrt_overlaps', 'bbox_sampling', 'bbox_sampling_pos', 'SamplingResult', 'assign_and_sample', 'random_choice', 'bbox2delta',
'bbox_sampling_neg', 'sample_bboxes', 'bbox2delta', 'delta2bbox', 'delta2bbox', 'bbox_flip', 'bbox_mapping', 'bbox_mapping_back', 'bbox2roi',
'bbox_flip', 'bbox_mapping', 'bbox_mapping_back', 'bbox2roi', 'roi2bbox', 'roi2bbox', 'bbox2result', 'bbox_target'
'bbox2result', 'bbox_target'
] ]
import torch
from .geometry import bbox_overlaps
class BBoxAssigner(object):
    """Assign a corresponding gt bbox or background to each bbox.

    Each bbox is assigned `-1`, `0`, or a positive integer indicating the
    ground truth index.

    - -1: don't care
    - 0: negative sample, no assigned gt
    - positive integer: positive sample, index (1-based) of assigned gt

    Args:
        pos_iou_thr (float): IoU threshold for positive bboxes.
        neg_iou_thr (float or tuple): IoU threshold for negative bboxes.
            A float `t` marks bboxes whose max IoU lies in `[0, t)` as
            negative; a 2-tuple `(lo, hi)` marks those in `[lo, hi)`.
        min_pos_iou (float): Minimum iou for a bbox to be considered as a
            positive bbox. For RPN, it is usually set as 0.3, for Fast R-CNN,
            it is usually set as pos_iou_thr.
        ignore_iof_thr (float): IoF threshold for ignoring bboxes (if
            `gt_bboxes_ignore` is specified). Values <= 0 mean not ignoring
            any bboxes.
    """

    def __init__(self,
                 pos_iou_thr,
                 neg_iou_thr,
                 min_pos_iou=.0,
                 ignore_iof_thr=-1):
        self.pos_iou_thr = pos_iou_thr
        self.neg_iou_thr = neg_iou_thr
        self.min_pos_iou = min_pos_iou
        self.ignore_iof_thr = ignore_iof_thr

    def assign(self, bboxes, gt_bboxes, gt_bboxes_ignore=None, gt_labels=None):
        """Assign gt to bboxes.

        This method assigns a gt bbox to every bbox (proposal/anchor); each
        bbox is assigned -1, 0, or a positive number. -1 means don't care,
        0 means negative sample, a positive number is the index (1-based) of
        the assigned gt.

        The assignment is done in the following steps, the order matters.

        1. assign every bbox to -1
        2. assign proposals whose iou with all gts < neg_iou_thr to 0
        3. for each bbox, if the iou with its nearest gt >= pos_iou_thr,
           assign it to that gt
        4. for each gt bbox, assign its nearest proposals (may be more than
           one) to itself

        Args:
            bboxes (Tensor): Bounding boxes to be assigned, shape(n, 4).
                Only the first four columns are used.
            gt_bboxes (Tensor): Groundtruth boxes, shape (k, 4).
            gt_bboxes_ignore (Tensor, optional): Ground truth bboxes that are
                labelled as `ignored`, e.g., crowd boxes in COCO.
            gt_labels (Tensor, optional): Label of gt_bboxes, shape (k, ).

        Returns:
            :obj:`AssignResult`: The assign result.

        Raises:
            ValueError: If there are no bboxes or no gt bboxes.
        """
        if bboxes.shape[0] == 0 or gt_bboxes.shape[0] == 0:
            raise ValueError('No gt or bboxes')
        bboxes = bboxes[:, :4]
        overlaps = bbox_overlaps(bboxes, gt_bboxes)

        if (self.ignore_iof_thr > 0) and (gt_bboxes_ignore is not None) and (
                gt_bboxes_ignore.numel() > 0):
            ignore_overlaps = bbox_overlaps(
                bboxes, gt_bboxes_ignore, mode='iof')
            ignore_max_overlaps, _ = ignore_overlaps.max(dim=1)
            # Mask rows with the boolean condition directly. The previous
            # `nonzero(...).squeeze()` produced a 1-D index tensor that was
            # then indexed with `[:, 0]`, which raises IndexError.
            # An overlap of -1 keeps the bbox out of steps 2 and 3 below,
            # so it stays at -1 (don't care) unless step 4 claims it.
            overlaps[ignore_max_overlaps > self.ignore_iof_thr, :] = -1

        return self.assign_wrt_overlaps(overlaps, gt_labels)

    def assign_wrt_overlaps(self, overlaps, gt_labels=None):
        """Assign w.r.t. the overlaps of bboxes with gts.

        Args:
            overlaps (Tensor): Overlaps between n bboxes and k gt_bboxes,
                shape(n, k).
            gt_labels (Tensor, optional): Labels of k gt_bboxes, shape (k, ).

        Returns:
            :obj:`AssignResult`: The assign result.

        Raises:
            ValueError: If `overlaps` has no elements.
        """
        if overlaps.numel() == 0:
            raise ValueError('No gt or proposals')

        num_bboxes, num_gts = overlaps.size(0), overlaps.size(1)

        # 1. assign -1 by default
        assigned_gt_inds = overlaps.new_full(
            (num_bboxes, ), -1, dtype=torch.long)

        # for each bbox, the max iou over all gts and which gt attains it
        max_overlaps, argmax_overlaps = overlaps.max(dim=1)
        # for each gt, the max iou over all bboxes
        gt_max_overlaps, _ = overlaps.max(dim=0)

        # 2. assign negative: below the negative IoU threshold
        # NOTE: a threshold that is neither a float nor a tuple (e.g. the
        # int 1) matches neither branch, so no negatives are assigned.
        if isinstance(self.neg_iou_thr, float):
            assigned_gt_inds[(max_overlaps >= 0)
                             & (max_overlaps < self.neg_iou_thr)] = 0
        elif isinstance(self.neg_iou_thr, tuple):
            assert len(self.neg_iou_thr) == 2
            assigned_gt_inds[(max_overlaps >= self.neg_iou_thr[0])
                             & (max_overlaps < self.neg_iou_thr[1])] = 0

        # 3. assign positive: above positive IoU threshold
        pos_inds = max_overlaps >= self.pos_iou_thr
        assigned_gt_inds[pos_inds] = argmax_overlaps[pos_inds] + 1

        # 4. assign fg: for each gt, the bboxes with highest IoU (ties
        # included); this runs last so it overrides steps 2 and 3
        for i in range(num_gts):
            if gt_max_overlaps[i] >= self.min_pos_iou:
                assigned_gt_inds[overlaps[:, i] == gt_max_overlaps[i]] = i + 1

        if gt_labels is not None:
            # label 0 is kept for every bbox that is not a positive sample
            assigned_labels = assigned_gt_inds.new_zeros((num_bboxes, ))
            pos_mask = assigned_gt_inds > 0
            if pos_mask.any():
                # gt indices are 1-based, hence the `- 1`
                assigned_labels[pos_mask] = gt_labels[
                    assigned_gt_inds[pos_mask] - 1]
        else:
            assigned_labels = None

        return AssignResult(
            num_gts, assigned_gt_inds, max_overlaps, labels=assigned_labels)
class AssignResult(object):
    """Container for the outcome of assigning gt bboxes to bboxes.

    Args:
        num_gts (int): Number of gt bboxes the assignment was made against.
        gt_inds (Tensor): Assigned gt index for every bbox.
        max_overlaps (Tensor): Max overlap with any gt, for every bbox.
        labels (Tensor, optional): Assigned label for every bbox.
    """

    def __init__(self, num_gts, gt_inds, max_overlaps, labels=None):
        self.num_gts = num_gts
        self.gt_inds = gt_inds
        self.max_overlaps = max_overlaps
        self.labels = labels

    def add_gt_(self, gt_labels):
        """Prepend the gt bboxes themselves to the result, in place.

        Each prepended gt is given its own 1-based index in `gt_inds` and an
        overlap of 1 in `max_overlaps`. If labels are stored, `gt_labels`
        is prepended to them as well.

        Args:
            gt_labels (Tensor): Labels of the gt bboxes; its length decides
                how many entries are prepended.
        """
        # Use one count for every prepended tensor so that gt_inds,
        # max_overlaps and labels always stay the same length.
        num_added = len(gt_labels)
        self_inds = torch.arange(
            1, num_added + 1, dtype=torch.long, device=gt_labels.device)
        self.gt_inds = torch.cat([self_inds, self.gt_inds])
        self.max_overlaps = torch.cat(
            [self.max_overlaps.new_ones(num_added), self.max_overlaps])
        if self.labels is not None:
            self.labels = torch.cat([gt_labels, self.labels])
...@@ -4,23 +4,23 @@ from .transforms import bbox2delta ...@@ -4,23 +4,23 @@ from .transforms import bbox2delta
from ..utils import multi_apply from ..utils import multi_apply
def bbox_target(pos_proposals_list, def bbox_target(pos_bboxes_list,
neg_proposals_list, neg_bboxes_list,
pos_gt_bboxes_list, pos_gt_bboxes_list,
pos_gt_labels_list, pos_gt_labels_list,
cfg, cfg,
reg_num_classes=1, reg_classes=1,
target_means=[.0, .0, .0, .0], target_means=[.0, .0, .0, .0],
target_stds=[1.0, 1.0, 1.0, 1.0], target_stds=[1.0, 1.0, 1.0, 1.0],
concat=True): concat=True):
labels, label_weights, bbox_targets, bbox_weights = multi_apply( labels, label_weights, bbox_targets, bbox_weights = multi_apply(
proposal_target_single, bbox_target_single,
pos_proposals_list, pos_bboxes_list,
neg_proposals_list, neg_bboxes_list,
pos_gt_bboxes_list, pos_gt_bboxes_list,
pos_gt_labels_list, pos_gt_labels_list,
cfg=cfg, cfg=cfg,
reg_num_classes=reg_num_classes, reg_classes=reg_classes,
target_means=target_means, target_means=target_means,
target_stds=target_stds) target_stds=target_stds)
...@@ -32,34 +32,34 @@ def bbox_target(pos_proposals_list, ...@@ -32,34 +32,34 @@ def bbox_target(pos_proposals_list,
return labels, label_weights, bbox_targets, bbox_weights return labels, label_weights, bbox_targets, bbox_weights
def proposal_target_single(pos_proposals, def bbox_target_single(pos_bboxes,
neg_proposals, neg_bboxes,
pos_gt_bboxes, pos_gt_bboxes,
pos_gt_labels, pos_gt_labels,
cfg, cfg,
reg_num_classes=1, reg_classes=1,
target_means=[.0, .0, .0, .0], target_means=[.0, .0, .0, .0],
target_stds=[1.0, 1.0, 1.0, 1.0]): target_stds=[1.0, 1.0, 1.0, 1.0]):
num_pos = pos_proposals.size(0) num_pos = pos_bboxes.size(0)
num_neg = neg_proposals.size(0) num_neg = neg_bboxes.size(0)
num_samples = num_pos + num_neg num_samples = num_pos + num_neg
labels = pos_proposals.new_zeros(num_samples, dtype=torch.long) labels = pos_bboxes.new_zeros(num_samples, dtype=torch.long)
label_weights = pos_proposals.new_zeros(num_samples) label_weights = pos_bboxes.new_zeros(num_samples)
bbox_targets = pos_proposals.new_zeros(num_samples, 4) bbox_targets = pos_bboxes.new_zeros(num_samples, 4)
bbox_weights = pos_proposals.new_zeros(num_samples, 4) bbox_weights = pos_bboxes.new_zeros(num_samples, 4)
if num_pos > 0: if num_pos > 0:
labels[:num_pos] = pos_gt_labels labels[:num_pos] = pos_gt_labels
pos_weight = 1.0 if cfg.pos_weight <= 0 else cfg.pos_weight pos_weight = 1.0 if cfg.pos_weight <= 0 else cfg.pos_weight
label_weights[:num_pos] = pos_weight label_weights[:num_pos] = pos_weight
pos_bbox_targets = bbox2delta(pos_proposals, pos_gt_bboxes, pos_bbox_targets = bbox2delta(pos_bboxes, pos_gt_bboxes, target_means,
target_means, target_stds) target_stds)
bbox_targets[:num_pos, :] = pos_bbox_targets bbox_targets[:num_pos, :] = pos_bbox_targets
bbox_weights[:num_pos, :] = 1 bbox_weights[:num_pos, :] = 1
if num_neg > 0: if num_neg > 0:
label_weights[-num_neg:] = 1.0 label_weights[-num_neg:] = 1.0
if reg_num_classes > 1: if reg_classes > 1:
bbox_targets, bbox_weights = expand_target(bbox_targets, bbox_weights, bbox_targets, bbox_weights = expand_target(bbox_targets, bbox_weights,
labels, reg_num_classes) labels, reg_classes)
return labels, label_weights, bbox_targets, bbox_weights return labels, label_weights, bbox_targets, bbox_weights
......
This diff is collapsed.
...@@ -215,7 +215,7 @@ class CocoDataset(Dataset): ...@@ -215,7 +215,7 @@ class CocoDataset(Dataset):
'proposals should have shapes (n, 4) or (n, 5), ' 'proposals should have shapes (n, 4) or (n, 5), '
'but found {}'.format(proposals.shape)) 'but found {}'.format(proposals.shape))
if proposals.shape[1] == 5: if proposals.shape[1] == 5:
scores = proposals[:, 4] scores = proposals[:, 4, None]
proposals = proposals[:, :4] proposals = proposals[:, :4]
else: else:
scores = None scores = None
...@@ -237,8 +237,8 @@ class CocoDataset(Dataset): ...@@ -237,8 +237,8 @@ class CocoDataset(Dataset):
if self.proposals is not None: if self.proposals is not None:
proposals = self.bbox_transform(proposals, img_shape, proposals = self.bbox_transform(proposals, img_shape,
scale_factor, flip) scale_factor, flip)
proposals = np.hstack([proposals, scores[:, None] proposals = np.hstack(
]) if scores is not None else proposals [proposals, scores]) if scores is not None else proposals
gt_bboxes = self.bbox_transform(gt_bboxes, img_shape, scale_factor, gt_bboxes = self.bbox_transform(gt_bboxes, img_shape, scale_factor,
flip) flip)
gt_bboxes_ignore = self.bbox_transform(gt_bboxes_ignore, img_shape, gt_bboxes_ignore = self.bbox_transform(gt_bboxes_ignore, img_shape,
...@@ -295,14 +295,14 @@ class CocoDataset(Dataset): ...@@ -295,14 +295,14 @@ class CocoDataset(Dataset):
flip=flip) flip=flip)
if proposal is not None: if proposal is not None:
if proposal.shape[1] == 5: if proposal.shape[1] == 5:
score = proposal[:, 4] score = proposal[:, 4, None]
proposal = proposal[:, :4] proposal = proposal[:, :4]
else: else:
score = None score = None
_proposal = self.bbox_transform(proposal, img_shape, _proposal = self.bbox_transform(proposal, img_shape,
scale_factor, flip) scale_factor, flip)
_proposal = np.hstack([_proposal, score[:, None] _proposal = np.hstack(
]) if score is not None else _proposal [_proposal, score]) if score is not None else _proposal
_proposal = to_tensor(_proposal) _proposal = to_tensor(_proposal)
else: else:
_proposal = None _proposal = None
......
...@@ -59,16 +59,20 @@ class BBoxHead(nn.Module): ...@@ -59,16 +59,20 @@ class BBoxHead(nn.Module):
bbox_pred = self.fc_reg(x) if self.with_reg else None bbox_pred = self.fc_reg(x) if self.with_reg else None
return cls_score, bbox_pred return cls_score, bbox_pred
def get_bbox_target(self, pos_proposals, neg_proposals, pos_gt_bboxes, def get_target(self, sampling_results, gt_bboxes, gt_labels,
pos_gt_labels, rcnn_train_cfg): rcnn_train_cfg):
reg_num_classes = 1 if self.reg_class_agnostic else self.num_classes pos_proposals = [res.pos_bboxes for res in sampling_results]
neg_proposals = [res.neg_bboxes for res in sampling_results]
pos_gt_bboxes = [res.pos_gt_bboxes for res in sampling_results]
pos_gt_labels = [res.pos_gt_labels for res in sampling_results]
reg_classes = 1 if self.reg_class_agnostic else self.num_classes
cls_reg_targets = bbox_target( cls_reg_targets = bbox_target(
pos_proposals, pos_proposals,
neg_proposals, neg_proposals,
pos_gt_bboxes, pos_gt_bboxes,
pos_gt_labels, pos_gt_labels,
rcnn_train_cfg, rcnn_train_cfg,
reg_num_classes, reg_classes,
target_means=self.target_means, target_means=self.target_means,
target_stds=self.target_stds) target_stds=self.target_stds)
return cls_reg_targets return cls_reg_targets
......
...@@ -4,7 +4,7 @@ import torch.nn as nn ...@@ -4,7 +4,7 @@ import torch.nn as nn
from .base import BaseDetector from .base import BaseDetector
from .test_mixins import RPNTestMixin, BBoxTestMixin, MaskTestMixin from .test_mixins import RPNTestMixin, BBoxTestMixin, MaskTestMixin
from .. import builder from .. import builder
from mmdet.core import sample_bboxes, bbox2roi, bbox2result, multi_apply from mmdet.core import (assign_and_sample, bbox2roi, bbox2result, multi_apply)
class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin, class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
...@@ -80,10 +80,11 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin, ...@@ -80,10 +80,11 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
gt_labels, gt_labels,
gt_masks=None, gt_masks=None,
proposals=None): proposals=None):
losses = dict()
x = self.extract_feat(img) x = self.extract_feat(img)
losses = dict()
# RPN forward and loss
if self.with_rpn: if self.with_rpn:
rpn_outs = self.rpn_head(x) rpn_outs = self.rpn_head(x)
rpn_loss_inputs = rpn_outs + (gt_bboxes, img_meta, rpn_loss_inputs = rpn_outs + (gt_bboxes, img_meta,
...@@ -96,44 +97,43 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin, ...@@ -96,44 +97,43 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
else: else:
proposal_list = proposals proposal_list = proposals
# assign gts and sample proposals
if self.with_bbox or self.with_mask:
assign_results, sampling_results = multi_apply(
assign_and_sample,
proposal_list,
gt_bboxes,
gt_bboxes_ignore,
gt_labels,
cfg=self.train_cfg.rcnn)
# bbox head forward and loss
if self.with_bbox: if self.with_bbox:
(pos_proposals, neg_proposals, pos_assigned_gt_inds, pos_gt_bboxes, rois = bbox2roi([res.bboxes for res in sampling_results])
pos_gt_labels) = multi_apply( # TODO: a more flexible way to decide which feature maps to use
sample_bboxes, bbox_feats = self.bbox_roi_extractor(
proposal_list,
gt_bboxes,
gt_bboxes_ignore,
gt_labels,
cfg=self.train_cfg.rcnn)
(labels, label_weights, bbox_targets,
bbox_weights) = self.bbox_head.get_bbox_target(
pos_proposals, neg_proposals, pos_gt_bboxes, pos_gt_labels,
self.train_cfg.rcnn)
rois = bbox2roi([
torch.cat([pos, neg], dim=0)
for pos, neg in zip(pos_proposals, neg_proposals)
])
# TODO: a more flexible way to configurate feat maps
roi_feats = self.bbox_roi_extractor(
x[:self.bbox_roi_extractor.num_inputs], rois) x[:self.bbox_roi_extractor.num_inputs], rois)
cls_score, bbox_pred = self.bbox_head(roi_feats) cls_score, bbox_pred = self.bbox_head(bbox_feats)
loss_bbox = self.bbox_head.loss(cls_score, bbox_pred, labels, bbox_targets = self.bbox_head.get_target(
label_weights, bbox_targets, sampling_results, gt_bboxes, gt_labels, self.train_cfg.rcnn)
bbox_weights) loss_bbox = self.bbox_head.loss(cls_score, bbox_pred,
*bbox_targets)
losses.update(loss_bbox) losses.update(loss_bbox)
# mask head forward and loss
if self.with_mask: if self.with_mask:
mask_targets = self.mask_head.get_mask_target( pos_rois = bbox2roi([res.pos_bboxes for res in sampling_results])
pos_proposals, pos_assigned_gt_inds, gt_masks,
self.train_cfg.rcnn)
pos_rois = bbox2roi(pos_proposals)
mask_feats = self.mask_roi_extractor( mask_feats = self.mask_roi_extractor(
x[:self.mask_roi_extractor.num_inputs], pos_rois) x[:self.mask_roi_extractor.num_inputs], pos_rois)
mask_pred = self.mask_head(mask_feats) mask_pred = self.mask_head(mask_feats)
mask_targets = self.mask_head.get_target(
sampling_results, gt_masks, self.train_cfg.rcnn)
pos_labels = torch.cat(
[res.pos_gt_labels for res in sampling_results])
loss_mask = self.mask_head.loss(mask_pred, mask_targets, loss_mask = self.mask_head.loss(mask_pred, mask_targets,
torch.cat(pos_gt_labels)) pos_labels)
losses.update(loss_mask) losses.update(loss_mask)
return losses return losses
...@@ -145,8 +145,7 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin, ...@@ -145,8 +145,7 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
x = self.extract_feat(img) x = self.extract_feat(img)
proposal_list = self.simple_test_rpn( proposal_list = self.simple_test_rpn(
x, img_meta, x, img_meta, self.test_cfg.rpn) if proposals is None else proposals
self.test_cfg.rpn) if proposals is None else proposals
det_bboxes, det_labels = self.simple_test_bboxes( det_bboxes, det_labels = self.simple_test_bboxes(
x, img_meta, proposal_list, self.test_cfg.rcnn, rescale=rescale) x, img_meta, proposal_list, self.test_cfg.rcnn, rescale=rescale)
......
...@@ -86,8 +86,11 @@ class FCNMaskHead(nn.Module): ...@@ -86,8 +86,11 @@ class FCNMaskHead(nn.Module):
mask_pred = self.conv_logits(x) mask_pred = self.conv_logits(x)
return mask_pred return mask_pred
def get_mask_target(self, pos_proposals, pos_assigned_gt_inds, gt_masks, def get_target(self, sampling_results, gt_masks, rcnn_train_cfg):
rcnn_train_cfg): pos_proposals = [res.pos_bboxes for res in sampling_results]
pos_assigned_gt_inds = [
res.pos_assigned_gt_inds for res in sampling_results
]
mask_targets = mask_target(pos_proposals, pos_assigned_gt_inds, mask_targets = mask_target(pos_proposals, pos_assigned_gt_inds,
gt_masks, rcnn_train_cfg) gt_masks, rcnn_train_cfg)
return mask_targets return mask_targets
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment