import numpy as np
from torch.utils.data.dataset import ConcatDataset as _ConcatDataset


class ConcatDataset(_ConcatDataset):
    """A wrapper of concatenated dataset.

    Same as :obj:`torch.utils.data.dataset.ConcatDataset`, but also
    concatenates the group flag for image aspect ratio.

    Args:
        datasets (Iterable[:obj:`Dataset`]): The datasets to concatenate.
            Any iterable is accepted (list, tuple, generator, ...); the
            parent class materializes it into the list ``self.datasets``.
            The first dataset must expose a ``CLASSES`` attribute.

    Attributes:
        CLASSES: Copied from the first dataset.
        flag (np.ndarray): Concatenation of every sub-dataset's ``flag``
            array, in dataset order. Only set when the first dataset has a
            ``flag`` attribute.
    """

    def __init__(self, datasets):
        super(ConcatDataset, self).__init__(datasets)
        # Use the parent's materialized list rather than the raw argument:
        # a one-shot iterable (e.g. a generator) is not subscriptable and has
        # already been consumed by the parent constructor.
        self.CLASSES = self.datasets[0].CLASSES
        # The first dataset having ``flag`` is taken to mean every
        # sub-dataset has one; a dataset without it raises AttributeError.
        if hasattr(self.datasets[0], 'flag'):
            self.flag = np.concatenate([ds.flag for ds in self.datasets])