Commit 5f6bbcf4 authored by yhcao6's avatar yhcao6
Browse files

reconstruct interfaces of sampler

parent 96e853f3
...@@ -7,8 +7,16 @@ from .sampling_result import SamplingResult ...@@ -7,8 +7,16 @@ from .sampling_result import SamplingResult
class BaseSampler(metaclass=ABCMeta): class BaseSampler(metaclass=ABCMeta):
def __init__(self, context): def __init__(self,
self.context = context num,
pos_fraction,
neg_pos_ub=-1,
add_gt_as_proposals=True,
**kwargs):
self.num = num
self.pos_fraction = pos_fraction
self.neg_pos_ub = neg_pos_ub
self.add_gt_as_proposals = add_gt_as_proposals
self.pos_sampler = self self.pos_sampler = self
self.neg_sampler = self self.neg_sampler = self
......
from .random_sampler import RandomSampler from .base_sampler import BaseSampler
from ..assign_sampling import build_sampler from ..assign_sampling import build_sampler
class CombinedSampler(RandomSampler): class CombinedSampler(BaseSampler):
def __init__(self, num, pos_fraction, pos_sampler, neg_sampler, **kwargs): def __init__(self, pos_sampler, neg_sampler, **kwargs):
super(CombinedSampler, self).__init__(num, pos_fraction, **kwargs) super(CombinedSampler, self).__init__(**kwargs)
default_args = dict(num=num, pos_fraction=pos_fraction) self.pos_sampler = build_sampler(pos_sampler, **kwargs)
default_args.update(kwargs) self.neg_sampler = build_sampler(neg_sampler, **kwargs)
self.pos_sampler = build_sampler(
pos_sampler, default_args=default_args) def _sample_pos(self, **kwargs):
self.neg_sampler = build_sampler( raise NotImplementedError
neg_sampler, default_args=default_args)
def _sample_neg(self, **kwargs):
raise NotImplementedError
...@@ -9,14 +9,12 @@ class OHEMSampler(BaseSampler): ...@@ -9,14 +9,12 @@ class OHEMSampler(BaseSampler):
def __init__(self, def __init__(self,
num, num,
pos_fraction, pos_fraction,
neg_pos_ub=-1, neg_pos_ub,
add_gt_as_proposals=True, add_gt_as_proposals,
context=None): context,
super(OHEMSampler, self).__init__(context) **kwargs):
self.num = num super(OHEMSampler, self).__init__(num, pos_fraction, neg_pos_ub,
self.pos_fraction = pos_fraction add_gt_as_proposals, **kwargs)
self.neg_pos_ub = neg_pos_ub
self.add_gt_as_proposals = add_gt_as_proposals
self.bbox_roi_extractor = context.bbox_roi_extractor self.bbox_roi_extractor = context.bbox_roi_extractor
self.bbox_head = context.bbox_head self.bbox_head = context.bbox_head
......
...@@ -6,16 +6,16 @@ from .sampling_result import SamplingResult ...@@ -6,16 +6,16 @@ from .sampling_result import SamplingResult
class PseudoSampler(BaseSampler): class PseudoSampler(BaseSampler):
def __init__(self): def __init__(self, **kwargs):
pass pass
def _sample_pos(self): def _sample_pos(self, **kwargs):
raise NotImplementedError raise NotImplementedError
def _sample_neg(self): def _sample_neg(self, **kwargs):
raise NotImplementedError raise NotImplementedError
def sample(self, assign_result, bboxes, gt_bboxes): def sample(self, assign_result, bboxes, gt_bboxes, **kwargs):
pos_inds = torch.nonzero( pos_inds = torch.nonzero(
assign_result.gt_inds > 0).squeeze(-1).unique() assign_result.gt_inds > 0).squeeze(-1).unique()
neg_inds = torch.nonzero( neg_inds = torch.nonzero(
......
...@@ -11,8 +11,9 @@ class RandomSampler(BaseSampler): ...@@ -11,8 +11,9 @@ class RandomSampler(BaseSampler):
pos_fraction, pos_fraction,
neg_pos_ub=-1, neg_pos_ub=-1,
add_gt_as_proposals=True, add_gt_as_proposals=True,
context=None): **kwargs):
super(RandomSampler, self).__init__(context) super(RandomSampler, self).__init__(num, pos_fraction, neg_pos_ub,
add_gt_as_proposals, **kwargs)
self.num = num self.num = num
self.pos_fraction = pos_fraction self.pos_fraction = pos_fraction
self.neg_pos_ub = neg_pos_ub self.neg_pos_ub = neg_pos_ub
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment