import platform
from functools import partial

from mmcv.parallel import collate
from mmcv.runner import get_dist_info
from torch.utils.data import DataLoader

from .sampler import DistributedGroupSampler, DistributedSampler, GroupSampler

if platform.system() != 'Windows':
    # https://github.com/pytorch/pytorch/issues/973
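    # raise the soft limit on open file descriptors (keeping the hard limit)
    # so that dataloader workers do not hit "Too many open files" errors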
    import resource
    rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
    resource.setrlimit(resource.RLIMIT_NOFILE, (4096, rlimit[1]))


def build_dataloader(dataset,
                     imgs_per_gpu,
                     workers_per_gpu,
                     num_gpus=1,
                     dist=True,
                     shuffle=True,
                     **kwargs):
    """Build PyTorch DataLoader.

    In distributed training, each GPU/process has a dataloader.
    In non-distributed training, there is only one dataloader for all GPUs.

    Args:
        dataset (Dataset): A PyTorch dataset.
        imgs_per_gpu (int): Number of images on each GPU, i.e., batch size of
            each GPU.
        workers_per_gpu (int): How many subprocesses to use for data loading
            for each GPU.
        num_gpus (int): Number of GPUs. Only used in non-distributed training.
        dist (bool): Distributed training/test or not. Default: True.
        shuffle (bool): Whether to shuffle the data at every epoch.
            Default: True.
        kwargs: Any additional keyword arguments to be used to initialize
            the DataLoader.

    Returns:
        DataLoader: A PyTorch dataloader.
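
    Example:
        >>> # illustrative sketch; assumes ``dataset`` is an existing Dataset
        >>> # carrying the group flags expected by the group samplers
        >>> loader = build_dataloader(
        ...     dataset, imgs_per_gpu=2, workers_per_gpu=2, dist=False)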
    """
    if dist:
        rank, world_size = get_dist_info()
        # DistributedGroupSampler always shuffles the data and guarantees
        # that the images on each GPU belong to the same group
        if shuffle:
            sampler = DistributedGroupSampler(dataset, imgs_per_gpu,
                                              world_size, rank)
        else:
            sampler = DistributedSampler(
                dataset, world_size, rank, shuffle=False)
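        # every process loads only its own split of the data, so batch size
        # and the number of workers here are per-GPU (per-process) values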
        batch_size = imgs_per_gpu
        num_workers = workers_per_gpu
    else:
        sampler = GroupSampler(dataset, imgs_per_gpu) if shuffle else None
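        # a single dataloader feeds all GPUs, so scale the totals by num_gpus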
        batch_size = num_gpus * imgs_per_gpu
        num_workers = num_gpus * workers_per_gpu

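    # mmcv's collate groups samples into chunks of ``samples_per_gpu`` so
    # that each GPU can later receive its own chunk of the batch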
    data_loader = DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=sampler,
        num_workers=num_workers,
        collate_fn=partial(collate, samples_per_gpu=imgs_per_gpu),
        pin_memory=False,
        **kwargs)

    return data_loader