repeat_dataset.py 475 Bytes
Newer Older
yhcao6's avatar
yhcao6 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
import numpy as np


class RepeatDataset(object):
    """Dataset wrapper that repeats another dataset ``repeat_times`` times.

    Position ``i`` of the wrapper maps to ``dataset[i % len(dataset)]``, so the
    wrapped dataset's items are visited in order, ``repeat_times`` times over.

    Args:
        dataset: The dataset to repeat. Must support ``len()`` and integer
            indexing.
        repeat_times (int): Number of times to repeat ``dataset``.

    Attributes:
        flag (np.ndarray): Only set when ``dataset`` has a ``flag``
            attribute; it is ``dataset.flag`` tiled ``repeat_times`` times so
            that ``flag[i]`` corresponds to ``self[i]``.
        length (int): ``len(dataset) * repeat_times``.
    """

    def __init__(self, dataset, repeat_times):
        self.dataset = dataset
        self.repeat_times = repeat_times
        # Cache the base length once instead of calling len() on every access.
        self._ori_len = len(self.dataset)
        if hasattr(self.dataset, 'flag'):
            self.flag = np.tile(self.dataset.flag, repeat_times)
        self.length = self._ori_len * self.repeat_times

    def __getitem__(self, idx):
        """Return the item at position ``idx`` of the repeated sequence.

        Negative indices count from the end, as for built-in sequences.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
                Raising here is what terminates the legacy ``__getitem__``
                iteration protocol (``for x in ds`` / ``list(ds)``); without
                it iteration would never end. It also avoids a
                ZeroDivisionError when the wrapped dataset is empty.
        """
        if not -self.length <= idx < self.length:
            raise IndexError(
                'index {} out of range for RepeatDataset of length {}'.format(
                    idx, self.length))
        return self.dataset[idx % self._ori_len]

    def __len__(self):
        """Return ``len(dataset) * repeat_times``."""
        return self.length