Commit c29ebbc5 authored by yhcao6's avatar yhcao6
Browse files

fix

parent 5f6bbcf4
......@@ -58,9 +58,8 @@ class BaseSampler(metaclass=ABCMeta):
gt_flags = torch.cat([gt_ones, gt_flags])
num_expected_pos = int(self.num * self.pos_fraction)
kwargs.update(dict(bboxes=bboxes))
pos_inds = self.pos_sampler._sample_pos(assign_result,
num_expected_pos, **kwargs)
pos_inds = self.pos_sampler._sample_pos(
assign_result, num_expected_pos, bboxes=bboxes, **kwargs)
# We found that sampled indices have duplicated items occasionally.
# (may be a bug of PyTorch)
pos_inds = pos_inds.unique()
......@@ -71,8 +70,8 @@ class BaseSampler(metaclass=ABCMeta):
neg_upper_bound = int(self.neg_pos_ub * _pos)
if num_expected_neg > neg_upper_bound:
num_expected_neg = neg_upper_bound
neg_inds = self.neg_sampler._sample_neg(assign_result,
num_expected_neg, **kwargs)
neg_inds = self.neg_sampler._sample_neg(
assign_result, num_expected_neg, bboxes=bboxes, **kwargs)
neg_inds = neg_inds.unique()
return SamplingResult(pos_inds, neg_inds, bboxes, gt_bboxes,
......
......@@ -9,17 +9,16 @@ class OHEMSampler(BaseSampler):
def __init__(self,
num,
pos_fraction,
neg_pos_ub,
add_gt_as_proposals,
context,
neg_pos_ub=-1,
add_gt_as_proposals=True,
**kwargs):
super(OHEMSampler, self).__init__(num, pos_fraction, neg_pos_ub,
add_gt_as_proposals, **kwargs)
add_gt_as_proposals)
self.bbox_roi_extractor = context.bbox_roi_extractor
self.bbox_head = context.bbox_head
def hard_mining(self, inds, num_expected, bboxes, labels, feats):
# hard mining from the gallery.
with torch.no_grad():
rois = bbox2roi([bboxes])
bbox_feats = self.bbox_roi_extractor(
......@@ -42,7 +41,7 @@ class OHEMSampler(BaseSampler):
bboxes=None,
feats=None,
**kwargs):
# Hard sample some positive samples
# Sample some hard positive samples
pos_inds = torch.nonzero(assign_result.gt_inds > 0)
if pos_inds.numel() != 0:
pos_inds = pos_inds.squeeze(1)
......@@ -58,7 +57,7 @@ class OHEMSampler(BaseSampler):
bboxes=None,
feats=None,
**kwargs):
# Hard sample some negative samples
# Sample some hard negative samples
neg_inds = torch.nonzero(assign_result.gt_inds == 0)
if neg_inds.numel() != 0:
neg_inds = neg_inds.squeeze(1)
......
......@@ -13,11 +13,7 @@ class RandomSampler(BaseSampler):
add_gt_as_proposals=True,
**kwargs):
super(RandomSampler, self).__init__(num, pos_fraction, neg_pos_ub,
add_gt_as_proposals, **kwargs)
self.num = num
self.pos_fraction = pos_fraction
self.neg_pos_ub = neg_pos_ub
self.add_gt_as_proposals = add_gt_as_proposals
add_gt_as_proposals)
@staticmethod
def random_choice(gallery, num):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment