repeat_dataset.py 463 Bytes
Newer Older
yhcao6's avatar
yhcao6 committed
1
2
3
4
5
import numpy as np


class RepeatDataset(object):
    """Dataset wrapper that presents ``dataset`` repeated ``times`` times.

    Index ``i`` of the wrapper maps to ``dataset[i % len(dataset)]``, so the
    wrapper exposes ``times * len(dataset)`` samples without copying data.

    Args:
        dataset: The dataset to repeat. Must support ``len()`` and integer
            indexing.
        times (int): Number of repetitions. Must be non-negative.

    Attributes:
        dataset: The wrapped dataset.
        times (int): Number of repetitions.
        flag (np.ndarray): Only set when ``dataset`` has a ``flag``
            attribute; it is ``dataset.flag`` tiled ``times`` times so it
            follows the repeated index space. (Assumes ``dataset.flag`` has
            one entry per sample -- TODO confirm against the wrapped dataset.)
    """

    def __init__(self, dataset, times):
        if times < 0:
            raise ValueError(
                'times must be non-negative, got {}'.format(times))
        self.dataset = dataset
        self.times = times
        if hasattr(self.dataset, 'flag'):
            # Repeat the per-sample flags so they stay aligned with the
            # repeated indices served by __getitem__.
            self.flag = np.tile(self.dataset.flag, times)
        # Cache the wrapped length; __getitem__ and __len__ both need it.
        self._original_length = len(self.dataset)

    def __getitem__(self, idx):
        """Return sample ``idx`` of the repeated dataset.

        Negative indices are accepted as for a Python sequence.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
                Raising here, rather than letting ``idx % length`` wrap
                silently, is what lets ``for x in ds`` / ``list(ds)``
                terminate under Python's ``__getitem__`` iteration protocol.
                It also avoids a ZeroDivisionError on an empty dataset.
        """
        total = len(self)
        if not -total <= idx < total:
            raise IndexError(
                'index {} out of range for RepeatDataset of length {}'.format(
                    idx, total))
        # total is a multiple of the original length, so for in-range
        # negative idx, idx % length equals (idx + total) % length.
        return self.dataset[idx % self._original_length]

    def __len__(self):
        """Return the number of samples: ``times * len(dataset)``."""
        return self.times * self._original_length