Commit f9b31893 authored by yhcao6
Browse files

refactor

parent 763153dc
......@@ -12,14 +12,15 @@ class BaseSampler(metaclass=ABCMeta):
self.neg_sampler = self
@abstractmethod
def _sample_pos(self, assign_result, num_expected):
def _sample_pos(self, assign_result, num_expected, **kwargs):
pass
@abstractmethod
def _sample_neg(self, assign_result, num_expected):
def _sample_neg(self, assign_result, num_expected, **kwargs):
pass
def sample(self, assign_result, bboxes, gt_bboxes, gt_labels=None):
def sample(self, assign_result, bboxes, gt_bboxes, gt_labels=None,
**kwargs):
"""Sample positive and negative bboxes.
This is a simple implementation of bbox sampling given candidates,
......@@ -44,8 +45,9 @@ class BaseSampler(metaclass=ABCMeta):
gt_flags = torch.cat([gt_ones, gt_flags])
num_expected_pos = int(self.num * self.pos_fraction)
kwargs.update(dict(bboxes=bboxes))
pos_inds = self.pos_sampler._sample_pos(assign_result,
num_expected_pos)
num_expected_pos, **kwargs)
# We found that sampled indices have duplicated items occasionally.
# (may be a bug of PyTorch)
pos_inds = pos_inds.unique()
......@@ -57,7 +59,7 @@ class BaseSampler(metaclass=ABCMeta):
if num_expected_neg > neg_upper_bound:
num_expected_neg = neg_upper_bound
neg_inds = self.neg_sampler._sample_neg(assign_result,
num_expected_neg)
num_expected_neg, **kwargs)
neg_inds = neg_inds.unique()
return SamplingResult(pos_inds, neg_inds, bboxes, gt_bboxes,
......
......@@ -6,7 +6,7 @@ from .random_sampler import RandomSampler
class InstanceBalancedPosSampler(RandomSampler):
def _sample_pos(self, assign_result, num_expected):
def _sample_pos(self, assign_result, num_expected, bboxes=None):
pos_inds = torch.nonzero(assign_result.gt_inds > 0)
if pos_inds.numel() != 0:
pos_inds = pos_inds.squeeze(1)
......
......@@ -19,7 +19,7 @@ class IoUBalancedNegSampler(RandomSampler):
self.hard_thr = hard_thr
self.hard_fraction = hard_fraction
def _sample_neg(self, assign_result, num_expected):
def _sample_neg(self, assign_result, num_expected, bboxes=None):
neg_inds = torch.nonzero(assign_result.gt_inds == 0)
if neg_inds.numel() != 0:
neg_inds = neg_inds.squeeze(1)
......
......@@ -2,7 +2,6 @@ import torch
from .base_sampler import BaseSampler
from ..transforms import bbox2roi
from .sampling_result import SamplingResult
class OHEMSampler(BaseSampler):
......@@ -11,14 +10,19 @@ class OHEMSampler(BaseSampler):
num,
pos_fraction,
neg_pos_ub=-1,
add_gt_as_proposals=True,):
add_gt_as_proposals=True,
bbox_roi_extractor=None,
bbox_head=None):
super(OHEMSampler, self).__init__()
self.num = num
self.pos_fraction = pos_fraction
self.neg_pos_ub = neg_pos_ub
self.add_gt_as_proposals = add_gt_as_proposals
self.bbox_roi_extractor = bbox_roi_extractor
self.bbox_head = bbox_head
def _sample_pos(self, assign_result, num_expected, loss_all):
def _sample_pos(self, assign_result, num_expected, bboxes=None,
feats=None):
"""Hard sample some positive samples."""
pos_inds = torch.nonzero(assign_result.gt_inds > 0)
if pos_inds.numel() != 0:
......@@ -26,10 +30,24 @@ class OHEMSampler(BaseSampler):
if pos_inds.numel() <= num_expected:
return pos_inds
else:
_, topk_loss_pos_inds = loss_all[pos_inds].topk(num_expected)
with torch.no_grad():
rois = bbox2roi([bboxes[pos_inds]])
bbox_feats = self.bbox_roi_extractor(
feats[:self.bbox_roi_extractor.num_inputs], rois)
cls_score, _ = self.bbox_head(bbox_feats)
loss_all = self.bbox_head.loss(
cls_score=cls_score,
bbox_pred=None,
labels=assign_result.labels[pos_inds],
label_weights=cls_score.new_ones(cls_score.size(0)),
bbox_targets=None,
bbox_weights=None,
reduction='none')['loss_cls']
_, topk_loss_pos_inds = loss_all.topk(num_expected)
return pos_inds[topk_loss_pos_inds]
def _sample_neg(self, assign_result, num_expected, loss_all):
def _sample_neg(self, assign_result, num_expected, bboxes=None,
feats=None):
"""Hard sample some negative samples."""
neg_inds = torch.nonzero(assign_result.gt_inds == 0)
if neg_inds.numel() != 0:
......@@ -37,63 +55,18 @@ class OHEMSampler(BaseSampler):
if len(neg_inds) <= num_expected:
return neg_inds
else:
_, topk_loss_neg_inds = loss_all[neg_inds].topk(num_expected)
with torch.no_grad():
rois = bbox2roi([bboxes[neg_inds]])
bbox_feats = self.bbox_roi_extractor(
feats[:self.bbox_roi_extractor.num_inputs], rois)
cls_score, _ = self.bbox_head(bbox_feats)
loss_all = self.bbox_head.loss(
cls_score=cls_score,
bbox_pred=None,
labels=assign_result.labels[neg_inds],
label_weights=cls_score.new_ones(cls_score.size(0)),
bbox_targets=None,
bbox_weights=None,
reduction='none')['loss_cls']
_, topk_loss_neg_inds = loss_all.topk(num_expected)
return neg_inds[topk_loss_neg_inds]
def sample(self, assign_result, bboxes, gt_bboxes, gt_labels=None,
           feats=None, bbox_roi_extractor=None, bbox_head=None):
    """Sample positive and negative bboxes with online hard example mining.

    All candidate boxes are scored once with ``bbox_head``; the resulting
    per-sample classification loss is handed to ``_sample_pos`` /
    ``_sample_neg`` so that the hardest examples are kept.

    Args:
        assign_result (:obj:`AssignResult`): Bbox assigning results.
        bboxes (Tensor): Boxes to be sampled from.
        gt_bboxes (Tensor): Ground truth bboxes.
        gt_labels (Tensor, optional): Class labels of ground truth bboxes.
        feats (list[Tensor], optional): Feature maps fed to
            ``bbox_roi_extractor`` when scoring the candidates.
        bbox_roi_extractor (nn.Module, optional): Extracts RoI features
            from ``feats``.
        bbox_head (nn.Module, optional): Produces classification scores
            used for hard mining.

    Returns:
        :obj:`SamplingResult`: Sampling result.
    """
    # Only the box coordinates are needed; drop any trailing score column.
    bboxes = bboxes[:, :4]

    gt_flags = bboxes.new_zeros((bboxes.shape[0], ), dtype=torch.uint8)
    if self.add_gt_as_proposals:
        # Prepend GT boxes as extra proposals and mark them via gt_flags.
        bboxes = torch.cat([gt_bboxes, bboxes], dim=0)
        assign_result.add_gt_(gt_labels)
        gt_ones = bboxes.new_ones(gt_bboxes.shape[0], dtype=torch.uint8)
        gt_flags = torch.cat([gt_ones, gt_flags])

    # Score every candidate once; the per-sample cls loss drives hard mining.
    with torch.no_grad():
        rois = bbox2roi([bboxes])
        bbox_feats = bbox_roi_extractor(
            feats[:bbox_roi_extractor.num_inputs], rois)
        cls_score, _ = bbox_head(bbox_feats)
        loss_all = bbox_head.loss(
            cls_score=cls_score,
            bbox_pred=None,
            labels=assign_result.labels,
            label_weights=cls_score.new_ones(cls_score.size(0)),
            bbox_targets=None,
            bbox_weights=None,
            reduction='none')['loss_cls']

    num_expected_pos = int(self.num * self.pos_fraction)
    pos_inds = self._sample_pos(assign_result, num_expected_pos, loss_all)
    # We found that sampled indices have duplicated items occasionally.
    # (may be a bug of PyTorch)
    pos_inds = pos_inds.unique()

    num_sampled_pos = pos_inds.numel()
    num_expected_neg = self.num - num_sampled_pos
    if self.neg_pos_ub >= 0:
        # Cap negatives at neg_pos_ub times the (at least one) positives.
        neg_upper_bound = int(self.neg_pos_ub * max(1, num_sampled_pos))
        num_expected_neg = min(num_expected_neg, neg_upper_bound)
    neg_inds = self._sample_neg(assign_result, num_expected_neg, loss_all)
    neg_inds = neg_inds.unique()

    return SamplingResult(pos_inds, neg_inds, bboxes, gt_bboxes,
                          assign_result, gt_flags)
......@@ -34,7 +34,7 @@ class RandomSampler(BaseSampler):
rand_inds = torch.from_numpy(rand_inds).long().to(gallery.device)
return gallery[rand_inds]
def _sample_pos(self, assign_result, num_expected):
def _sample_pos(self, assign_result, num_expected, bboxes=None):
"""Randomly sample some positive samples."""
pos_inds = torch.nonzero(assign_result.gt_inds > 0)
if pos_inds.numel() != 0:
......@@ -44,7 +44,7 @@ class RandomSampler(BaseSampler):
else:
return self.random_choice(pos_inds, num_expected)
def _sample_neg(self, assign_result, num_expected):
def _sample_neg(self, assign_result, num_expected, bboxes=None):
"""Randomly sample some negative samples."""
neg_inds = torch.nonzero(assign_result.gt_inds == 0)
if neg_inds.numel() != 0:
......
......@@ -103,7 +103,10 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
# assign gts and sample proposals
if self.with_bbox or self.with_mask:
bbox_assigner = build_assigner(self.train_cfg.rcnn.assigner)
bbox_sampler = build_sampler(self.train_cfg.rcnn.sampler)
bbox_sampler = build_sampler(
self.train_cfg.rcnn.sampler,
dict(bbox_roi_extractor=self.bbox_roi_extractor,
bbox_head=self.bbox_head))
num_imgs = img.size(0)
assign_results = []
sampling_results = []
......@@ -117,9 +120,7 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
proposal_list[i],
gt_bboxes[i],
gt_labels[i],
[xx[i][None] for xx in x],
self.bbox_roi_extractor,
self.bbox_head)
feats=[xx[i][None] for xx in x])
else:
sampling_result = bbox_sampler.sample(
assign_result, proposal_list[i], gt_bboxes[i],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment