Commit 63288def authored by yhcao6's avatar yhcao6
Browse files

support recursion

parent 57994044
...@@ -7,6 +7,7 @@ from .repeat_dataset import RepeatDataset ...@@ -7,6 +7,7 @@ from .repeat_dataset import RepeatDataset
__all__ = [ __all__ = [
'CustomDataset', 'CocoDataset', 'GroupSampler', 'DistributedGroupSampler', 'CustomDataset', 'CocoDataset', 'GroupSampler', 'DistributedGroupSampler',
'ConcatDataset', 'build_dataloader', 'to_tensor', 'random_scale', 'build_dataloader', 'to_tensor', 'random_scale', 'show_ann',
'show_ann', 'get_dataset', 'RepeatDataset' 'get_dataset', 'ExtraAugmentation', 'ConcatDataset', 'RepeatDataset',
] ]
...@@ -3,16 +3,15 @@ import numpy as np ...@@ -3,16 +3,15 @@ import numpy as np
class RepeatDataset(object): class RepeatDataset(object):
def __init__(self, dataset, repeat_times): def __init__(self, dataset, times):
self.dataset = dataset self.dataset = dataset
self.repeat_times = repeat_times self.times = times
if hasattr(self.dataset, 'flag'): if hasattr(self.dataset, 'flag'):
self.flag = np.tile(self.dataset.flag, repeat_times) self.flag = np.tile(self.dataset.flag, times)
self.length = len(self.dataset) * self.repeat_times self._original_length = len(self.dataset)
def __getitem__(self, idx): def __getitem__(self, idx):
return self.dataset[idx % len(self.dataset)] return self.dataset[idx % self._original_length]
def __len__(self): def __len__(self):
return self.length return self.times * self._original_length
...@@ -75,10 +75,8 @@ def show_ann(coco, img, ann_info): ...@@ -75,10 +75,8 @@ def show_ann(coco, img, ann_info):
def get_dataset(data_cfg): def get_dataset(data_cfg):
repeat_times = None
if data_cfg['type'] == 'RepeatDataset': if data_cfg['type'] == 'RepeatDataset':
repeat_times = data_cfg['repeat_times'] return RepeatDataset(get_dataset(data_cfg['type']), data_cfg['times'])
data_cfg = data_cfg['dataset']
if isinstance(data_cfg['ann_file'], (list, tuple)): if isinstance(data_cfg['ann_file'], (list, tuple)):
ann_files = data_cfg['ann_file'] ann_files = data_cfg['ann_file']
...@@ -114,7 +112,4 @@ def get_dataset(data_cfg): ...@@ -114,7 +112,4 @@ def get_dataset(data_cfg):
dset = ConcatDataset(dsets) dset = ConcatDataset(dsets)
else: else:
dset = dsets[0] dset = dsets[0]
if repeat_times is not None:
dset = RepeatDataset(dset, repeat_times)
return dset return dset
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment