Commit d6b69bda authored by yhcao6's avatar yhcao6
Browse files

add RepeatDataset

parent a6ee0532
...@@ -3,9 +3,10 @@ from .coco import CocoDataset ...@@ -3,9 +3,10 @@ from .coco import CocoDataset
from .loader import GroupSampler, DistributedGroupSampler, build_dataloader from .loader import GroupSampler, DistributedGroupSampler, build_dataloader
from .utils import to_tensor, random_scale, show_ann, get_dataset from .utils import to_tensor, random_scale, show_ann, get_dataset
from .concat_dataset import ConcatDataset from .concat_dataset import ConcatDataset
from .repeat_dataset import RepeatDataset
__all__ = [ __all__ = [
'CustomDataset', 'CocoDataset', 'GroupSampler', 'DistributedGroupSampler', 'CustomDataset', 'CocoDataset', 'GroupSampler', 'DistributedGroupSampler',
'ConcatDataset', 'build_dataloader', 'to_tensor', 'random_scale', 'ConcatDataset', 'build_dataloader', 'to_tensor', 'random_scale',
'show_ann', 'get_dataset' 'show_ann', 'get_dataset', 'RepeatDataset'
] ]
import numpy as np
class RepeatDataset(object):
def __init__(self, dataset, repeat_times):
self.dataset = dataset
self.repeat_times = repeat_times
if hasattr(self.dataset, 'flag'):
self.flag = np.tile(self.dataset.flag, repeat_times)
self.length = len(self.dataset) * self.repeat_times
def __getitem__(self, idx):
return self.dataset[idx % len(self.dataset)]
def __len__(self):
return self.length
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment