import bisect

import numpy as np
from torch.utils.data.dataset import ConcatDataset as _ConcatDataset


class ConcatDataset(_ConcatDataset):
    """
    Same as torch.utils.data.dataset.ConcatDataset, but concat the group flag
    for image aspect ratio.
    """

    def __init__(self, datasets):
        """
        Concatenate ``datasets`` and, when available, their ``flag`` arrays.

        flag: Images with aspect ratio greater than 1 will be set as group 1,
        otherwise group 0.

        Args:
            datasets: datasets to concatenate. If the first one has a ``flag``
                attribute, every dataset is expected to have one and the
                flags are concatenated, in order, into ``self.flag``.
        """
        super(ConcatDataset, self).__init__(datasets)
        # Use the list materialised by the parent class rather than the raw
        # argument, so any iterable accepted by the parent also works here.
        if hasattr(self.datasets[0], 'flag'):
            self.flag = np.concatenate(
                [dataset.flag for dataset in self.datasets])

    def get_idxs(self, idx):
        """
        Map a global sample index to its position inside the concatenation.

        Args:
            idx: index into the concatenated dataset. Negative values count
                from the end, as with regular sequence indexing.

        Returns:
            A ``(dataset_idx, sample_idx)`` tuple: the index of the dataset
            holding the sample and the sample's index within that dataset.

        Raises:
            ValueError: if ``idx`` is negative and its absolute value exceeds
                the total length.
        """
        if idx < 0:
            total = len(self)
            if -idx > total:
                raise ValueError(
                    'absolute value of index should not exceed dataset length')
            # Normalise first: bisect on a negative value would always
            # select dataset 0 and return a meaningless offset.
            idx += total
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        if dataset_idx == 0:
            return dataset_idx, idx
        # Subtract the number of samples held by all preceding datasets.
        return dataset_idx, idx - self.cumulative_sizes[dataset_idx - 1]