Commit 6a363603 authored by yhcao6

reuse hard mining code

parent 09c3bc4c
@@ -21,34 +21,39 @@ class OHEMSampler(BaseSampler):
         self.bbox_roi_extractor = bbox_roi_extractor
         self.bbox_head = bbox_head
 
+    def hard_mining(self, gallery, assign_result, num_expected, bboxes, feats):
+        # hard mining from the gallery
+        with torch.no_grad():
+            rois = bbox2roi([bboxes[gallery]])
+            bbox_feats = self.bbox_roi_extractor(
+                feats[:self.bbox_roi_extractor.num_inputs], rois)
+            cls_score, _ = self.bbox_head(bbox_feats)
+            loss = self.bbox_head.loss(
+                cls_score=cls_score,
+                bbox_pred=None,
+                labels=assign_result.labels[gallery],
+                label_weights=cls_score.new_ones(cls_score.size(0)),
+                bbox_targets=None,
+                bbox_weights=None,
+                reduce=False)['loss_cls']
+            _, topk_loss_inds = loss.topk(num_expected)
+        return gallery[topk_loss_inds]
+
     def _sample_pos(self,
                     assign_result,
                     num_expected,
                     bboxes=None,
                     feats=None,
                     **kwargs):
-        """Hard sample some positive samples."""
+        # Hard sample some positive samples
         pos_inds = torch.nonzero(assign_result.gt_inds > 0)
         if pos_inds.numel() != 0:
             pos_inds = pos_inds.squeeze(1)
         if pos_inds.numel() <= num_expected:
             return pos_inds
         else:
-            with torch.no_grad():
-                rois = bbox2roi([bboxes[pos_inds]])
-                bbox_feats = self.bbox_roi_extractor(
-                    feats[:self.bbox_roi_extractor.num_inputs], rois)
-                cls_score, _ = self.bbox_head(bbox_feats)
-                loss_pos = self.bbox_head.loss(
-                    cls_score=cls_score,
-                    bbox_pred=None,
-                    labels=assign_result.labels[pos_inds],
-                    label_weights=cls_score.new_ones(cls_score.size(0)),
-                    bbox_targets=None,
-                    bbox_weights=None,
-                    reduce=False)['loss_cls']
-                _, topk_loss_pos_inds = loss_pos.topk(num_expected)
-                return pos_inds[topk_loss_pos_inds]
+            return self.hard_mining(pos_inds, assign_result, num_expected,
+                                    bboxes, feats)
 
     def _sample_neg(self,
                     assign_result,
@@ -56,25 +61,12 @@ class OHEMSampler(BaseSampler):
                     bboxes=None,
                     feats=None,
                     **kwargs):
-        """Hard sample some negative samples."""
+        # Hard sample some negative samples
        neg_inds = torch.nonzero(assign_result.gt_inds == 0)
         if neg_inds.numel() != 0:
             neg_inds = neg_inds.squeeze(1)
         if len(neg_inds) <= num_expected:
             return neg_inds
         else:
-            with torch.no_grad():
-                rois = bbox2roi([bboxes[neg_inds]])
-                bbox_feats = self.bbox_roi_extractor(
-                    feats[:self.bbox_roi_extractor.num_inputs], rois)
-                cls_score, _ = self.bbox_head(bbox_feats)
-                loss_neg = self.bbox_head.loss(
-                    cls_score=cls_score,
-                    bbox_pred=None,
-                    labels=assign_result.labels[neg_inds],
-                    label_weights=cls_score.new_ones(cls_score.size(0)),
-                    bbox_targets=None,
-                    bbox_weights=None,
-                    reduce=False)['loss_cls']
-                _, topk_loss_neg_inds = loss_neg.topk(num_expected)
-                return neg_inds[topk_loss_neg_inds]
+            return self.hard_mining(neg_inds, assign_result, num_expected,
+                                    bboxes, feats)
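The helper factored out above boils down to: score every candidate with the bbox head, read the per-candidate classification loss (reduce=False), and keep the num_expected candidates with the largest loss. Below is a minimal, self-contained sketch of just that final top-k selection step; the tensors are made-up placeholders, not outputs of the real RoI extractor or bbox head.

import torch

# Hypothetical candidate indices and per-candidate loss values; in the real
# sampler these come from assign_result and bbox_head.loss(..., reduce=False).
candidate_inds = torch.tensor([2, 5, 7, 9, 11])
per_sample_loss = torch.tensor([0.10, 0.90, 0.30, 1.20, 0.05])
num_expected = 2

# topk returns the largest losses and their positions in the loss vector;
# indexing candidate_inds with those positions yields the "hardest" candidates.
_, topk_loss_inds = per_sample_loss.topk(num_expected)
hard_inds = candidate_inds[topk_loss_inds]
print(hard_inds)  # tensor([9, 5])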
@@ -4,7 +4,7 @@ import torch.nn as nn
 from .base import BaseDetector
 from .test_mixins import RPNTestMixin, BBoxTestMixin, MaskTestMixin
 from .. import builder
-from mmdet.core import (bbox2roi, bbox2result, build_assigner, build_sampler)
+from mmdet.core import bbox2roi, bbox2result, build_assigner, build_sampler
 
 
 class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,