# concat_dataset.py
import numpy as np
from torch.utils.data.dataset import ConcatDataset as _ConcatDataset


class ConcatDataset(_ConcatDataset):
    """Same as torch.utils.data.dataset.ConcatDataset, but also concatenates
    the per-image group flag used for aspect-ratio grouping.

    Attributes:
        flag: Concatenation of every sub-dataset's ``flag`` array, in dataset
            order, so ``self.flag[i]`` is the group of global sample ``i``.
            Images with aspect ratio greater than 1 are expected to be group 1,
            otherwise group 0. Only set when *every* sub-dataset defines
            ``flag``; otherwise the attribute is absent.
    """

    def __init__(self, datasets):
        """Concatenate ``datasets`` and, when available, their group flags.

        Args:
            datasets: Iterable of map-style datasets. It may be a one-shot
                iterator, because only ``self.datasets`` (the list built by
                the parent class) is read afterwards.
        """
        super().__init__(datasets)
        # Use self.datasets (a list materialised by the parent) rather than
        # the raw argument, which may be a generator that is now exhausted.
        # Require the flag on all sub-datasets: a partial flag array would be
        # shorter than len(self) and silently misalign sample indices.
        if all(hasattr(ds, 'flag') for ds in self.datasets):
            self.flag = np.concatenate([ds.flag for ds in self.datasets])