Commit 092b97f6 authored by yhcao6's avatar yhcao6
Browse files

add kwargs to sample_pos, sample_neg

parent f9b31893
......@@ -6,7 +6,7 @@ from .random_sampler import RandomSampler
class InstanceBalancedPosSampler(RandomSampler):
def _sample_pos(self, assign_result, num_expected, bboxes=None):
def _sample_pos(self, assign_result, num_expected, **kwargs):
pos_inds = torch.nonzero(assign_result.gt_inds > 0)
if pos_inds.numel() != 0:
pos_inds = pos_inds.squeeze(1)
......
......@@ -19,7 +19,7 @@ class IoUBalancedNegSampler(RandomSampler):
self.hard_thr = hard_thr
self.hard_fraction = hard_fraction
def _sample_neg(self, assign_result, num_expected, bboxes=None):
def _sample_neg(self, assign_result, num_expected, **kwargs):
neg_inds = torch.nonzero(assign_result.gt_inds == 0)
if neg_inds.numel() != 0:
neg_inds = neg_inds.squeeze(1)
......
......@@ -22,7 +22,7 @@ class OHEMSampler(BaseSampler):
self.bbox_head = bbox_head
def _sample_pos(self, assign_result, num_expected, bboxes=None,
feats=None):
feats=None, **kwargs):
"""Hard sample some positive samples."""
pos_inds = torch.nonzero(assign_result.gt_inds > 0)
if pos_inds.numel() != 0:
......@@ -35,7 +35,7 @@ class OHEMSampler(BaseSampler):
bbox_feats = self.bbox_roi_extractor(
feats[:self.bbox_roi_extractor.num_inputs], rois)
cls_score, _ = self.bbox_head(bbox_feats)
loss_all = self.bbox_head.loss(
loss_pos = self.bbox_head.loss(
cls_score=cls_score,
bbox_pred=None,
labels=assign_result.labels[pos_inds],
......@@ -43,11 +43,11 @@ class OHEMSampler(BaseSampler):
bbox_targets=None,
bbox_weights=None,
reduction='none')['loss_cls']
_, topk_loss_pos_inds = loss_all.topk(num_expected)
_, topk_loss_pos_inds = loss_pos.topk(num_expected)
return pos_inds[topk_loss_pos_inds]
def _sample_neg(self, assign_result, num_expected, bboxes=None,
feats=None):
feats=None, **kwargs):
"""Hard sample some negative samples."""
neg_inds = torch.nonzero(assign_result.gt_inds == 0)
if neg_inds.numel() != 0:
......@@ -60,7 +60,7 @@ class OHEMSampler(BaseSampler):
bbox_feats = self.bbox_roi_extractor(
feats[:self.bbox_roi_extractor.num_inputs], rois)
cls_score, _ = self.bbox_head(bbox_feats)
loss_all = self.bbox_head.loss(
loss_neg = self.bbox_head.loss(
cls_score=cls_score,
bbox_pred=None,
labels=assign_result.labels[neg_inds],
......@@ -68,5 +68,5 @@ class OHEMSampler(BaseSampler):
bbox_targets=None,
bbox_weights=None,
reduction='none')['loss_cls']
_, topk_loss_neg_inds = loss_all.topk(num_expected)
_, topk_loss_neg_inds = loss_neg.topk(num_expected)
return neg_inds[topk_loss_neg_inds]
......@@ -34,7 +34,7 @@ class RandomSampler(BaseSampler):
rand_inds = torch.from_numpy(rand_inds).long().to(gallery.device)
return gallery[rand_inds]
def _sample_pos(self, assign_result, num_expected, bboxes=None):
def _sample_pos(self, assign_result, num_expected, **kwargs):
"""Randomly sample some positive samples."""
pos_inds = torch.nonzero(assign_result.gt_inds > 0)
if pos_inds.numel() != 0:
......@@ -44,7 +44,7 @@ class RandomSampler(BaseSampler):
else:
return self.random_choice(pos_inds, num_expected)
def _sample_neg(self, assign_result, num_expected, bboxes=None):
def _sample_neg(self, assign_result, num_expected, **kwargs):
"""Randomly sample some negative samples."""
neg_inds = torch.nonzero(assign_result.gt_inds == 0)
if neg_inds.numel() != 0:
......
......@@ -114,17 +114,12 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
assign_result = bbox_assigner.assign(
proposal_list[i], gt_bboxes[i], gt_bboxes_ignore[i],
gt_labels[i])
if self.train_cfg.rcnn.sampler.type == 'OHEMSampler':
sampling_result = bbox_sampler.sample(
assign_result,
proposal_list[i],
gt_bboxes[i],
gt_labels[i],
feats=[xx[i][None] for xx in x])
else:
sampling_result = bbox_sampler.sample(
assign_result, proposal_list[i], gt_bboxes[i],
gt_labels[i])
sampling_result = bbox_sampler.sample(
assign_result,
proposal_list[i],
gt_bboxes[i],
gt_labels[i],
feats=[xx[i][None] for xx in x])
assign_results.append(assign_result)
sampling_results.append(sampling_result)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment