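# NOTE(review): the two lines below were captured from a web code viewer
# together with this file; they are not code.
# "benchmarks/disagg_benchmarks/round_robin_proxy.py" did not exist on "ceccb7160f027695f2e89eab8c279a6374d3f719"
# concat_dataset.py 698 Bytes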
import numpy as np
from torch.utils.data.dataset import ConcatDataset as _ConcatDataset


class ConcatDataset(_ConcatDataset):
    """A wrapper of concatenated dataset.

    Same as :obj:`torch.utils.data.dataset.ConcatDataset`, but also
    concatenates the group flag for image aspect ratio of the sub-datasets
    and exposes the ``CLASSES`` of the first sub-dataset.

    Args:
        datasets (Iterable[:obj:`Dataset`]): The datasets to concatenate.
            Any iterable is accepted (e.g. a list or a generator); the
            parent class materializes it into ``self.datasets``.

    Attributes:
        CLASSES: Copied from the first dataset. The sub-datasets are
            presumably expected to share the same class names; this is not
            checked here.
        flag (np.ndarray): Concatenation of every sub-dataset's ``flag``,
            in dataset order. Only set when the first dataset has a
            ``flag`` attribute.
    """

    def __init__(self, datasets):
        super(ConcatDataset, self).__init__(datasets)
        # Use ``self.datasets`` rather than the ``datasets`` argument: the
        # parent stores the argument as a list, so this also works when the
        # caller passed a one-shot iterable that has already been consumed.
        first = self.datasets[0]
        self.CLASSES = first.CLASSES
        if hasattr(first, 'flag'):
            # Samples of the concatenated dataset are indexed in dataset
            # order, so concatenating flags in the same order keeps flag[i]
            # aligned with sample i.
            self.flag = np.concatenate([ds.flag for ds in self.datasets])