Commit c29ebbc5 authored by yhcao6's avatar yhcao6
Browse files

fix

parent 5f6bbcf4
...@@ -58,9 +58,8 @@ class BaseSampler(metaclass=ABCMeta): ...@@ -58,9 +58,8 @@ class BaseSampler(metaclass=ABCMeta):
gt_flags = torch.cat([gt_ones, gt_flags]) gt_flags = torch.cat([gt_ones, gt_flags])
num_expected_pos = int(self.num * self.pos_fraction) num_expected_pos = int(self.num * self.pos_fraction)
kwargs.update(dict(bboxes=bboxes)) pos_inds = self.pos_sampler._sample_pos(
pos_inds = self.pos_sampler._sample_pos(assign_result, assign_result, num_expected_pos, bboxes=bboxes, **kwargs)
num_expected_pos, **kwargs)
# We found that sampled indices have duplicated items occasionally. # We found that sampled indices have duplicated items occasionally.
# (may be a bug of PyTorch) # (may be a bug of PyTorch)
pos_inds = pos_inds.unique() pos_inds = pos_inds.unique()
...@@ -71,8 +70,8 @@ class BaseSampler(metaclass=ABCMeta): ...@@ -71,8 +70,8 @@ class BaseSampler(metaclass=ABCMeta):
neg_upper_bound = int(self.neg_pos_ub * _pos) neg_upper_bound = int(self.neg_pos_ub * _pos)
if num_expected_neg > neg_upper_bound: if num_expected_neg > neg_upper_bound:
num_expected_neg = neg_upper_bound num_expected_neg = neg_upper_bound
neg_inds = self.neg_sampler._sample_neg(assign_result, neg_inds = self.neg_sampler._sample_neg(
num_expected_neg, **kwargs) assign_result, num_expected_neg, bboxes=bboxes, **kwargs)
neg_inds = neg_inds.unique() neg_inds = neg_inds.unique()
return SamplingResult(pos_inds, neg_inds, bboxes, gt_bboxes, return SamplingResult(pos_inds, neg_inds, bboxes, gt_bboxes,
......
...@@ -9,17 +9,16 @@ class OHEMSampler(BaseSampler): ...@@ -9,17 +9,16 @@ class OHEMSampler(BaseSampler):
def __init__(self, def __init__(self,
num, num,
pos_fraction, pos_fraction,
neg_pos_ub,
add_gt_as_proposals,
context, context,
neg_pos_ub=-1,
add_gt_as_proposals=True,
**kwargs): **kwargs):
super(OHEMSampler, self).__init__(num, pos_fraction, neg_pos_ub, super(OHEMSampler, self).__init__(num, pos_fraction, neg_pos_ub,
add_gt_as_proposals, **kwargs) add_gt_as_proposals)
self.bbox_roi_extractor = context.bbox_roi_extractor self.bbox_roi_extractor = context.bbox_roi_extractor
self.bbox_head = context.bbox_head self.bbox_head = context.bbox_head
def hard_mining(self, inds, num_expected, bboxes, labels, feats): def hard_mining(self, inds, num_expected, bboxes, labels, feats):
# hard mining from the gallery.
with torch.no_grad(): with torch.no_grad():
rois = bbox2roi([bboxes]) rois = bbox2roi([bboxes])
bbox_feats = self.bbox_roi_extractor( bbox_feats = self.bbox_roi_extractor(
...@@ -42,7 +41,7 @@ class OHEMSampler(BaseSampler): ...@@ -42,7 +41,7 @@ class OHEMSampler(BaseSampler):
bboxes=None, bboxes=None,
feats=None, feats=None,
**kwargs): **kwargs):
# Hard sample some positive samples # Sample some hard positive samples
pos_inds = torch.nonzero(assign_result.gt_inds > 0) pos_inds = torch.nonzero(assign_result.gt_inds > 0)
if pos_inds.numel() != 0: if pos_inds.numel() != 0:
pos_inds = pos_inds.squeeze(1) pos_inds = pos_inds.squeeze(1)
...@@ -58,7 +57,7 @@ class OHEMSampler(BaseSampler): ...@@ -58,7 +57,7 @@ class OHEMSampler(BaseSampler):
bboxes=None, bboxes=None,
feats=None, feats=None,
**kwargs): **kwargs):
# Hard sample some negative samples # Sample some hard negative samples
neg_inds = torch.nonzero(assign_result.gt_inds == 0) neg_inds = torch.nonzero(assign_result.gt_inds == 0)
if neg_inds.numel() != 0: if neg_inds.numel() != 0:
neg_inds = neg_inds.squeeze(1) neg_inds = neg_inds.squeeze(1)
......
...@@ -13,11 +13,7 @@ class RandomSampler(BaseSampler): ...@@ -13,11 +13,7 @@ class RandomSampler(BaseSampler):
add_gt_as_proposals=True, add_gt_as_proposals=True,
**kwargs): **kwargs):
super(RandomSampler, self).__init__(num, pos_fraction, neg_pos_ub, super(RandomSampler, self).__init__(num, pos_fraction, neg_pos_ub,
add_gt_as_proposals, **kwargs) add_gt_as_proposals)
self.num = num
self.pos_fraction = pos_fraction
self.neg_pos_ub = neg_pos_ub
self.add_gt_as_proposals = add_gt_as_proposals
@staticmethod @staticmethod
def random_choice(gallery, num): def random_choice(gallery, num):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment