"configs/vscode:/vscode.git/clone" did not exist on "7e221c1472ad0cf04299825d8661fd8dfe22acfc"
Commit 96e853f3 authored by yhcao6's avatar yhcao6
Browse files

update interface

parent edc5e18b
...@@ -7,7 +7,8 @@ from .sampling_result import SamplingResult ...@@ -7,7 +7,8 @@ from .sampling_result import SamplingResult
class BaseSampler(metaclass=ABCMeta): class BaseSampler(metaclass=ABCMeta):
def __init__(self): def __init__(self, context):
self.context = context
self.pos_sampler = self self.pos_sampler = self
self.neg_sampler = self self.neg_sampler = self
...@@ -19,7 +20,11 @@ class BaseSampler(metaclass=ABCMeta): ...@@ -19,7 +20,11 @@ class BaseSampler(metaclass=ABCMeta):
def _sample_neg(self, assign_result, num_expected, **kwargs): def _sample_neg(self, assign_result, num_expected, **kwargs):
pass pass
def sample(self, assign_result, bboxes, gt_bboxes, gt_labels=None, def sample(self,
assign_result,
bboxes,
gt_bboxes,
gt_labels=None,
**kwargs): **kwargs):
"""Sample positive and negative bboxes. """Sample positive and negative bboxes.
......
...@@ -11,15 +11,14 @@ class OHEMSampler(BaseSampler): ...@@ -11,15 +11,14 @@ class OHEMSampler(BaseSampler):
pos_fraction, pos_fraction,
neg_pos_ub=-1, neg_pos_ub=-1,
add_gt_as_proposals=True, add_gt_as_proposals=True,
bbox_roi_extractor=None, context=None):
bbox_head=None): super(OHEMSampler, self).__init__(context)
super(OHEMSampler, self).__init__()
self.num = num self.num = num
self.pos_fraction = pos_fraction self.pos_fraction = pos_fraction
self.neg_pos_ub = neg_pos_ub self.neg_pos_ub = neg_pos_ub
self.add_gt_as_proposals = add_gt_as_proposals self.add_gt_as_proposals = add_gt_as_proposals
self.bbox_roi_extractor = bbox_roi_extractor self.bbox_roi_extractor = context.bbox_roi_extractor
self.bbox_head = bbox_head self.bbox_head = context.bbox_head
def hard_mining(self, inds, num_expected, bboxes, labels, feats): def hard_mining(self, inds, num_expected, bboxes, labels, feats):
# hard mining from the gallery. # hard mining from the gallery.
......
...@@ -10,8 +10,9 @@ class RandomSampler(BaseSampler): ...@@ -10,8 +10,9 @@ class RandomSampler(BaseSampler):
num, num,
pos_fraction, pos_fraction,
neg_pos_ub=-1, neg_pos_ub=-1,
add_gt_as_proposals=True): add_gt_as_proposals=True,
super(RandomSampler, self).__init__() context=None):
super(RandomSampler, self).__init__(context)
self.num = num self.num = num
self.pos_fraction = pos_fraction self.pos_fraction = pos_fraction
self.neg_pos_ub = neg_pos_ub self.neg_pos_ub = neg_pos_ub
......
...@@ -104,9 +104,7 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin, ...@@ -104,9 +104,7 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
if self.with_bbox or self.with_mask: if self.with_bbox or self.with_mask:
bbox_assigner = build_assigner(self.train_cfg.rcnn.assigner) bbox_assigner = build_assigner(self.train_cfg.rcnn.assigner)
bbox_sampler = build_sampler( bbox_sampler = build_sampler(
self.train_cfg.rcnn.sampler, self.train_cfg.rcnn.sampler, context=self)
bbox_roi_extractor=self.bbox_roi_extractor,
bbox_head=self.bbox_head)
num_imgs = img.size(0) num_imgs = img.size(0)
assign_results = [] assign_results = []
sampling_results = [] sampling_results = []
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment