Commit 83644927 authored by yhcao6's avatar yhcao6
Browse files

resort formal parameters order of hard mining

parent d2cde908
...@@ -21,23 +21,23 @@ class OHEMSampler(BaseSampler): ...@@ -21,23 +21,23 @@ class OHEMSampler(BaseSampler):
self.bbox_roi_extractor = bbox_roi_extractor self.bbox_roi_extractor = bbox_roi_extractor
self.bbox_head = bbox_head self.bbox_head = bbox_head
def hard_mining(self, gallery, assign_result, num_expected, bboxes, feats): def hard_mining(self, inds, num_expected, bboxes, labels, feats):
# hard mining from the gallery. # hard mining from the gallery.
with torch.no_grad(): with torch.no_grad():
rois = bbox2roi([bboxes[gallery]]) rois = bbox2roi([bboxes])
bbox_feats = self.bbox_roi_extractor( bbox_feats = self.bbox_roi_extractor(
feats[:self.bbox_roi_extractor.num_inputs], rois) feats[:self.bbox_roi_extractor.num_inputs], rois)
cls_score, _ = self.bbox_head(bbox_feats) cls_score, _ = self.bbox_head(bbox_feats)
loss = self.bbox_head.loss( loss = self.bbox_head.loss(
cls_score=cls_score, cls_score=cls_score,
bbox_pred=None, bbox_pred=None,
labels=assign_result.labels[gallery], labels=labels,
label_weights=cls_score.new_ones(cls_score.size(0)), label_weights=cls_score.new_ones(cls_score.size(0)),
bbox_targets=None, bbox_targets=None,
bbox_weights=None, bbox_weights=None,
reduce=False)['loss_cls'] reduce=False)['loss_cls']
_, topk_loss_inds = loss.topk(num_expected) _, topk_loss_inds = loss.topk(num_expected)
return gallery[topk_loss_inds] return inds[topk_loss_inds]
def _sample_pos(self, def _sample_pos(self,
assign_result, assign_result,
...@@ -52,8 +52,8 @@ class OHEMSampler(BaseSampler): ...@@ -52,8 +52,8 @@ class OHEMSampler(BaseSampler):
if pos_inds.numel() <= num_expected: if pos_inds.numel() <= num_expected:
return pos_inds return pos_inds
else: else:
return self.hard_mining(pos_inds, assign_result, num_expected, return self.hard_mining(pos_inds, num_expected, bboxes[pos_inds],
bboxes, feats) assign_result.labels[pos_inds], feats)
def _sample_neg(self, def _sample_neg(self,
assign_result, assign_result,
...@@ -68,5 +68,5 @@ class OHEMSampler(BaseSampler): ...@@ -68,5 +68,5 @@ class OHEMSampler(BaseSampler):
if len(neg_inds) <= num_expected: if len(neg_inds) <= num_expected:
return neg_inds return neg_inds
else: else:
return self.hard_mining(neg_inds, assign_result, num_expected, return self.hard_mining(neg_inds, num_expected, bboxes[neg_inds],
bboxes, feats) assign_result.labels[neg_inds], feats)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment