".github/CODE_OF_CONDUCT.md" did not exist on "d1aac35d68a203955a32bca4635429f620fc08dd"
Commit 96e853f3 authored by yhcao6's avatar yhcao6
Browse files

update interface

parent edc5e18b
......@@ -7,7 +7,8 @@ from .sampling_result import SamplingResult
class BaseSampler(metaclass=ABCMeta):
def __init__(self):
def __init__(self, context):
self.context = context
self.pos_sampler = self
self.neg_sampler = self
......@@ -19,7 +20,11 @@ class BaseSampler(metaclass=ABCMeta):
def _sample_neg(self, assign_result, num_expected, **kwargs):
pass
def sample(self, assign_result, bboxes, gt_bboxes, gt_labels=None,
def sample(self,
assign_result,
bboxes,
gt_bboxes,
gt_labels=None,
**kwargs):
"""Sample positive and negative bboxes.
......
......@@ -11,15 +11,14 @@ class OHEMSampler(BaseSampler):
pos_fraction,
neg_pos_ub=-1,
add_gt_as_proposals=True,
bbox_roi_extractor=None,
bbox_head=None):
super(OHEMSampler, self).__init__()
context=None):
super(OHEMSampler, self).__init__(context)
self.num = num
self.pos_fraction = pos_fraction
self.neg_pos_ub = neg_pos_ub
self.add_gt_as_proposals = add_gt_as_proposals
self.bbox_roi_extractor = bbox_roi_extractor
self.bbox_head = bbox_head
self.bbox_roi_extractor = context.bbox_roi_extractor
self.bbox_head = context.bbox_head
def hard_mining(self, inds, num_expected, bboxes, labels, feats):
# hard mining from the gallery.
......
......@@ -10,8 +10,9 @@ class RandomSampler(BaseSampler):
num,
pos_fraction,
neg_pos_ub=-1,
add_gt_as_proposals=True):
super(RandomSampler, self).__init__()
add_gt_as_proposals=True,
context=None):
super(RandomSampler, self).__init__(context)
self.num = num
self.pos_fraction = pos_fraction
self.neg_pos_ub = neg_pos_ub
......
......@@ -104,9 +104,7 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
if self.with_bbox or self.with_mask:
bbox_assigner = build_assigner(self.train_cfg.rcnn.assigner)
bbox_sampler = build_sampler(
self.train_cfg.rcnn.sampler,
bbox_roi_extractor=self.bbox_roi_extractor,
bbox_head=self.bbox_head)
self.train_cfg.rcnn.sampler, context=self)
num_imgs = img.size(0)
assign_results = []
sampling_results = []
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment