Commit 092b97f6 authored by yhcao6's avatar yhcao6
Browse files

add kwargs to sample_pos, sample_neg

parent f9b31893
...@@ -6,7 +6,7 @@ from .random_sampler import RandomSampler ...@@ -6,7 +6,7 @@ from .random_sampler import RandomSampler
class InstanceBalancedPosSampler(RandomSampler): class InstanceBalancedPosSampler(RandomSampler):
def _sample_pos(self, assign_result, num_expected, bboxes=None): def _sample_pos(self, assign_result, num_expected, **kwargs):
pos_inds = torch.nonzero(assign_result.gt_inds > 0) pos_inds = torch.nonzero(assign_result.gt_inds > 0)
if pos_inds.numel() != 0: if pos_inds.numel() != 0:
pos_inds = pos_inds.squeeze(1) pos_inds = pos_inds.squeeze(1)
......
...@@ -19,7 +19,7 @@ class IoUBalancedNegSampler(RandomSampler): ...@@ -19,7 +19,7 @@ class IoUBalancedNegSampler(RandomSampler):
self.hard_thr = hard_thr self.hard_thr = hard_thr
self.hard_fraction = hard_fraction self.hard_fraction = hard_fraction
def _sample_neg(self, assign_result, num_expected, bboxes=None): def _sample_neg(self, assign_result, num_expected, **kwargs):
neg_inds = torch.nonzero(assign_result.gt_inds == 0) neg_inds = torch.nonzero(assign_result.gt_inds == 0)
if neg_inds.numel() != 0: if neg_inds.numel() != 0:
neg_inds = neg_inds.squeeze(1) neg_inds = neg_inds.squeeze(1)
......
...@@ -22,7 +22,7 @@ class OHEMSampler(BaseSampler): ...@@ -22,7 +22,7 @@ class OHEMSampler(BaseSampler):
self.bbox_head = bbox_head self.bbox_head = bbox_head
def _sample_pos(self, assign_result, num_expected, bboxes=None, def _sample_pos(self, assign_result, num_expected, bboxes=None,
feats=None): feats=None, **kwargs):
"""Hard sample some positive samples.""" """Hard sample some positive samples."""
pos_inds = torch.nonzero(assign_result.gt_inds > 0) pos_inds = torch.nonzero(assign_result.gt_inds > 0)
if pos_inds.numel() != 0: if pos_inds.numel() != 0:
...@@ -35,7 +35,7 @@ class OHEMSampler(BaseSampler): ...@@ -35,7 +35,7 @@ class OHEMSampler(BaseSampler):
bbox_feats = self.bbox_roi_extractor( bbox_feats = self.bbox_roi_extractor(
feats[:self.bbox_roi_extractor.num_inputs], rois) feats[:self.bbox_roi_extractor.num_inputs], rois)
cls_score, _ = self.bbox_head(bbox_feats) cls_score, _ = self.bbox_head(bbox_feats)
loss_all = self.bbox_head.loss( loss_pos = self.bbox_head.loss(
cls_score=cls_score, cls_score=cls_score,
bbox_pred=None, bbox_pred=None,
labels=assign_result.labels[pos_inds], labels=assign_result.labels[pos_inds],
...@@ -43,11 +43,11 @@ class OHEMSampler(BaseSampler): ...@@ -43,11 +43,11 @@ class OHEMSampler(BaseSampler):
bbox_targets=None, bbox_targets=None,
bbox_weights=None, bbox_weights=None,
reduction='none')['loss_cls'] reduction='none')['loss_cls']
_, topk_loss_pos_inds = loss_all.topk(num_expected) _, topk_loss_pos_inds = loss_pos.topk(num_expected)
return pos_inds[topk_loss_pos_inds] return pos_inds[topk_loss_pos_inds]
def _sample_neg(self, assign_result, num_expected, bboxes=None, def _sample_neg(self, assign_result, num_expected, bboxes=None,
feats=None): feats=None, **kwargs):
"""Hard sample some negative samples.""" """Hard sample some negative samples."""
neg_inds = torch.nonzero(assign_result.gt_inds == 0) neg_inds = torch.nonzero(assign_result.gt_inds == 0)
if neg_inds.numel() != 0: if neg_inds.numel() != 0:
...@@ -60,7 +60,7 @@ class OHEMSampler(BaseSampler): ...@@ -60,7 +60,7 @@ class OHEMSampler(BaseSampler):
bbox_feats = self.bbox_roi_extractor( bbox_feats = self.bbox_roi_extractor(
feats[:self.bbox_roi_extractor.num_inputs], rois) feats[:self.bbox_roi_extractor.num_inputs], rois)
cls_score, _ = self.bbox_head(bbox_feats) cls_score, _ = self.bbox_head(bbox_feats)
loss_all = self.bbox_head.loss( loss_neg = self.bbox_head.loss(
cls_score=cls_score, cls_score=cls_score,
bbox_pred=None, bbox_pred=None,
labels=assign_result.labels[neg_inds], labels=assign_result.labels[neg_inds],
...@@ -68,5 +68,5 @@ class OHEMSampler(BaseSampler): ...@@ -68,5 +68,5 @@ class OHEMSampler(BaseSampler):
bbox_targets=None, bbox_targets=None,
bbox_weights=None, bbox_weights=None,
reduction='none')['loss_cls'] reduction='none')['loss_cls']
_, topk_loss_neg_inds = loss_all.topk(num_expected) _, topk_loss_neg_inds = loss_neg.topk(num_expected)
return neg_inds[topk_loss_neg_inds] return neg_inds[topk_loss_neg_inds]
...@@ -34,7 +34,7 @@ class RandomSampler(BaseSampler): ...@@ -34,7 +34,7 @@ class RandomSampler(BaseSampler):
rand_inds = torch.from_numpy(rand_inds).long().to(gallery.device) rand_inds = torch.from_numpy(rand_inds).long().to(gallery.device)
return gallery[rand_inds] return gallery[rand_inds]
def _sample_pos(self, assign_result, num_expected, bboxes=None): def _sample_pos(self, assign_result, num_expected, **kwargs):
"""Randomly sample some positive samples.""" """Randomly sample some positive samples."""
pos_inds = torch.nonzero(assign_result.gt_inds > 0) pos_inds = torch.nonzero(assign_result.gt_inds > 0)
if pos_inds.numel() != 0: if pos_inds.numel() != 0:
...@@ -44,7 +44,7 @@ class RandomSampler(BaseSampler): ...@@ -44,7 +44,7 @@ class RandomSampler(BaseSampler):
else: else:
return self.random_choice(pos_inds, num_expected) return self.random_choice(pos_inds, num_expected)
def _sample_neg(self, assign_result, num_expected, bboxes=None): def _sample_neg(self, assign_result, num_expected, **kwargs):
"""Randomly sample some negative samples.""" """Randomly sample some negative samples."""
neg_inds = torch.nonzero(assign_result.gt_inds == 0) neg_inds = torch.nonzero(assign_result.gt_inds == 0)
if neg_inds.numel() != 0: if neg_inds.numel() != 0:
......
...@@ -114,17 +114,12 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin, ...@@ -114,17 +114,12 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
assign_result = bbox_assigner.assign( assign_result = bbox_assigner.assign(
proposal_list[i], gt_bboxes[i], gt_bboxes_ignore[i], proposal_list[i], gt_bboxes[i], gt_bboxes_ignore[i],
gt_labels[i]) gt_labels[i])
if self.train_cfg.rcnn.sampler.type == 'OHEMSampler':
sampling_result = bbox_sampler.sample( sampling_result = bbox_sampler.sample(
assign_result, assign_result,
proposal_list[i], proposal_list[i],
gt_bboxes[i], gt_bboxes[i],
gt_labels[i], gt_labels[i],
feats=[xx[i][None] for xx in x]) feats=[xx[i][None] for xx in x])
else:
sampling_result = bbox_sampler.sample(
assign_result, proposal_list[i], gt_bboxes[i],
gt_labels[i])
assign_results.append(assign_result) assign_results.append(assign_result)
sampling_results.append(sampling_result) sampling_results.append(sampling_result)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment