Commit 6a363603 authored by yhcao6's avatar yhcao6
Browse files

reuse hard mining code

parent 09c3bc4c
......@@ -21,34 +21,39 @@ class OHEMSampler(BaseSampler):
self.bbox_roi_extractor = bbox_roi_extractor
self.bbox_head = bbox_head
def hard_mining(self, gallery, assign_result, num_expected, bboxes, feats):
# hard mining from the gallery.
with torch.no_grad():
rois = bbox2roi([bboxes[gallery]])
bbox_feats = self.bbox_roi_extractor(
feats[:self.bbox_roi_extractor.num_inputs], rois)
cls_score, _ = self.bbox_head(bbox_feats)
loss = self.bbox_head.loss(
cls_score=cls_score,
bbox_pred=None,
labels=assign_result.labels[gallery],
label_weights=cls_score.new_ones(cls_score.size(0)),
bbox_targets=None,
bbox_weights=None,
reduce=False)['loss_cls']
_, topk_loss_inds = loss.topk(num_expected)
return gallery[topk_loss_inds]
def _sample_pos(self,
assign_result,
num_expected,
bboxes=None,
feats=None,
**kwargs):
"""Hard sample some positive samples."""
# Hard sample some positive samples
pos_inds = torch.nonzero(assign_result.gt_inds > 0)
if pos_inds.numel() != 0:
pos_inds = pos_inds.squeeze(1)
if pos_inds.numel() <= num_expected:
return pos_inds
else:
with torch.no_grad():
rois = bbox2roi([bboxes[pos_inds]])
bbox_feats = self.bbox_roi_extractor(
feats[:self.bbox_roi_extractor.num_inputs], rois)
cls_score, _ = self.bbox_head(bbox_feats)
loss_pos = self.bbox_head.loss(
cls_score=cls_score,
bbox_pred=None,
labels=assign_result.labels[pos_inds],
label_weights=cls_score.new_ones(cls_score.size(0)),
bbox_targets=None,
bbox_weights=None,
reduce=False)['loss_cls']
_, topk_loss_pos_inds = loss_pos.topk(num_expected)
return pos_inds[topk_loss_pos_inds]
return self.hard_mining(pos_inds, assign_result, num_expected,
bboxes, feats)
def _sample_neg(self,
assign_result,
......@@ -56,25 +61,12 @@ class OHEMSampler(BaseSampler):
bboxes=None,
feats=None,
**kwargs):
"""Hard sample some negative samples."""
# Hard sample some negative samples
neg_inds = torch.nonzero(assign_result.gt_inds == 0)
if neg_inds.numel() != 0:
neg_inds = neg_inds.squeeze(1)
if len(neg_inds) <= num_expected:
return neg_inds
else:
with torch.no_grad():
rois = bbox2roi([bboxes[neg_inds]])
bbox_feats = self.bbox_roi_extractor(
feats[:self.bbox_roi_extractor.num_inputs], rois)
cls_score, _ = self.bbox_head(bbox_feats)
loss_neg = self.bbox_head.loss(
cls_score=cls_score,
bbox_pred=None,
labels=assign_result.labels[neg_inds],
label_weights=cls_score.new_ones(cls_score.size(0)),
bbox_targets=None,
bbox_weights=None,
reduce=False)['loss_cls']
_, topk_loss_neg_inds = loss_neg.topk(num_expected)
return neg_inds[topk_loss_neg_inds]
return self.hard_mining(neg_inds, assign_result, num_expected,
bboxes, feats)
......@@ -4,7 +4,7 @@ import torch.nn as nn
from .base import BaseDetector
from .test_mixins import RPNTestMixin, BBoxTestMixin, MaskTestMixin
from .. import builder
from mmdet.core import (bbox2roi, bbox2result, build_assigner, build_sampler)
from mmdet.core import bbox2roi, bbox2result, build_assigner, build_sampler
class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment