# repeat_dataset.py
# Author (per repository blame view): yhcao6
import numpy as np


class RepeatDataset(object):
    """Present a dataset as if it were concatenated with itself ``times`` times.

    Item ``i`` of the wrapper is item ``i % len(dataset)`` of the wrapped
    dataset, and the wrapper's length is ``times * len(dataset)``.

    Args:
        dataset: Object to repeat. It must support ``len()`` and integer
            indexing, and expose a ``CLASSES`` attribute (copied onto the
            wrapper). If it also has a ``flag`` attribute, that is tiled
            ``times`` times and stored as ``self.flag``.
        times: Number of repetitions.
    """

    def __init__(self, dataset, times):
        self.dataset = dataset
        self.times = times
        self.CLASSES = dataset.CLASSES
        if hasattr(self.dataset, 'flag'):
            # Tile the flags so there is one entry per repeated item.
            self.flag = np.tile(self.dataset.flag, times)

        # Cached once; the wrapped dataset's length is assumed not to change.
        self._ori_len = len(self.dataset)

    def __getitem__(self, idx):
        """Return the wrapped item that repeated index ``idx`` maps to.

        Negative indices count from the end of the repeated dataset.

        Raises:
            IndexError: If ``idx`` lies outside ``[-len(self), len(self))``.
        """
        total = self.times * self._ori_len
        # Without this check every integer would be wrapped by the modulo
        # below, so index-based iteration (``for x in ds``, ``list(ds)``)
        # would never stop, and an empty dataset would fail with
        # ZeroDivisionError instead of IndexError.
        if not -total <= idx < total:
            raise IndexError(
                'index {} is out of range for RepeatDataset of length '
                '{}'.format(idx, total))
        return self.dataset[idx % self._ori_len]

    def __len__(self):
        """Return ``times`` multiplied by the wrapped dataset's length."""
        return self.times * self._ori_len