Commit 83644927 authored by yhcao6's avatar yhcao6
Browse files

resort formal parameters order of hard mining

parent d2cde908
......@@ -21,23 +21,23 @@ class OHEMSampler(BaseSampler):
self.bbox_roi_extractor = bbox_roi_extractor
self.bbox_head = bbox_head
def hard_mining(self, gallery, assign_result, num_expected, bboxes, feats):
def hard_mining(self, inds, num_expected, bboxes, labels, feats):
# hard mining from the gallery.
with torch.no_grad():
rois = bbox2roi([bboxes[gallery]])
rois = bbox2roi([bboxes])
bbox_feats = self.bbox_roi_extractor(
feats[:self.bbox_roi_extractor.num_inputs], rois)
cls_score, _ = self.bbox_head(bbox_feats)
loss = self.bbox_head.loss(
cls_score=cls_score,
bbox_pred=None,
labels=assign_result.labels[gallery],
labels=labels,
label_weights=cls_score.new_ones(cls_score.size(0)),
bbox_targets=None,
bbox_weights=None,
reduce=False)['loss_cls']
_, topk_loss_inds = loss.topk(num_expected)
return gallery[topk_loss_inds]
return inds[topk_loss_inds]
def _sample_pos(self,
assign_result,
......@@ -52,8 +52,8 @@ class OHEMSampler(BaseSampler):
if pos_inds.numel() <= num_expected:
return pos_inds
else:
return self.hard_mining(pos_inds, assign_result, num_expected,
bboxes, feats)
return self.hard_mining(pos_inds, num_expected, bboxes[pos_inds],
assign_result.labels[pos_inds], feats)
def _sample_neg(self,
assign_result,
......@@ -68,5 +68,5 @@ class OHEMSampler(BaseSampler):
if len(neg_inds) <= num_expected:
return neg_inds
else:
return self.hard_mining(neg_inds, assign_result, num_expected,
bboxes, feats)
return self.hard_mining(neg_inds, num_expected, bboxes[neg_inds],
assign_result.labels[neg_inds], feats)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment