Commit 5f6bbcf4 authored by yhcao6's avatar yhcao6
Browse files

reconstruct interfaces of sampler

parent 96e853f3
......@@ -7,8 +7,16 @@ from .sampling_result import SamplingResult
class BaseSampler(metaclass=ABCMeta):
def __init__(self, context):
self.context = context
def __init__(self,
num,
pos_fraction,
neg_pos_ub=-1,
add_gt_as_proposals=True,
**kwargs):
self.num = num
self.pos_fraction = pos_fraction
self.neg_pos_ub = neg_pos_ub
self.add_gt_as_proposals = add_gt_as_proposals
self.pos_sampler = self
self.neg_sampler = self
......
from .random_sampler import RandomSampler
from .base_sampler import BaseSampler
from ..assign_sampling import build_sampler
class CombinedSampler(RandomSampler):
class CombinedSampler(BaseSampler):
def __init__(self, num, pos_fraction, pos_sampler, neg_sampler, **kwargs):
super(CombinedSampler, self).__init__(num, pos_fraction, **kwargs)
default_args = dict(num=num, pos_fraction=pos_fraction)
default_args.update(kwargs)
self.pos_sampler = build_sampler(
pos_sampler, default_args=default_args)
self.neg_sampler = build_sampler(
neg_sampler, default_args=default_args)
def __init__(self, pos_sampler, neg_sampler, **kwargs):
super(CombinedSampler, self).__init__(**kwargs)
self.pos_sampler = build_sampler(pos_sampler, **kwargs)
self.neg_sampler = build_sampler(neg_sampler, **kwargs)
def _sample_pos(self, **kwargs):
raise NotImplementedError
def _sample_neg(self, **kwargs):
raise NotImplementedError
......@@ -9,14 +9,12 @@ class OHEMSampler(BaseSampler):
def __init__(self,
num,
pos_fraction,
neg_pos_ub=-1,
add_gt_as_proposals=True,
context=None):
super(OHEMSampler, self).__init__(context)
self.num = num
self.pos_fraction = pos_fraction
self.neg_pos_ub = neg_pos_ub
self.add_gt_as_proposals = add_gt_as_proposals
neg_pos_ub,
add_gt_as_proposals,
context,
**kwargs):
super(OHEMSampler, self).__init__(num, pos_fraction, neg_pos_ub,
add_gt_as_proposals, **kwargs)
self.bbox_roi_extractor = context.bbox_roi_extractor
self.bbox_head = context.bbox_head
......
......@@ -6,16 +6,16 @@ from .sampling_result import SamplingResult
class PseudoSampler(BaseSampler):
def __init__(self):
def __init__(self, **kwargs):
pass
def _sample_pos(self):
def _sample_pos(self, **kwargs):
raise NotImplementedError
def _sample_neg(self):
def _sample_neg(self, **kwargs):
raise NotImplementedError
def sample(self, assign_result, bboxes, gt_bboxes):
def sample(self, assign_result, bboxes, gt_bboxes, **kwargs):
pos_inds = torch.nonzero(
assign_result.gt_inds > 0).squeeze(-1).unique()
neg_inds = torch.nonzero(
......
......@@ -11,8 +11,9 @@ class RandomSampler(BaseSampler):
pos_fraction,
neg_pos_ub=-1,
add_gt_as_proposals=True,
context=None):
super(RandomSampler, self).__init__(context)
**kwargs):
super(RandomSampler, self).__init__(num, pos_fraction, neg_pos_ub,
add_gt_as_proposals, **kwargs)
self.num = num
self.pos_fraction = pos_fraction
self.neg_pos_ub = neg_pos_ub
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment