test_deterministic_dataloader.py 2.77 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import os
from functools import partial
from pathlib import Path

import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torchvision import transforms
from torch.utils.data import DataLoader

import colossalai
from colossalai.builder import build_dataset, build_transform
from colossalai.context import ParallelMode, Config
from colossalai.core import global_context as gpc

# Test configuration: CIFAR10 train split, fed through a small shuffled
# DataLoader. Pipeline and tensor parallel sizes are both 1, so the DATA
# parallel group presumably spans every spawned rank (confirm against
# colossalai.launch). The dataset root is read from the DATA environment
# variable; this lookup raises KeyError at import time when DATA is unset.
CONFIG = Config({
    'train_data': {
        'dataset': {
            'type': 'CIFAR10',
            'root': Path(os.environ['DATA']),
            'train': True,
            'download': True,
        },
        'dataloader': {
            'num_workers': 2,
            'batch_size': 2,
            'shuffle': True,
        },
        'transform_pipeline': [
            {'type': 'ToTensor'},
            {'type': 'RandomCrop', 'size': 32},
            {'type': 'Normalize', 'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5)},
        ],
    },
    'parallel': {
        'pipeline': {'size': 1},
        'tensor': {'size': 1, 'mode': None},
    },
    'seed': 1024,
})


def run_data_sampler(rank, world_size):
    """Check that, without a data-parallel sampler, every rank loads the same first image.

    Launches Colossal-AI for this process (gloo backend on localhost) using
    ``CONFIG``. It then builds the transform pipeline, dataset and DataLoader
    described by ``gpc.config.train_data`` and takes the first image of the
    first batch. Rank 0's image is broadcast over the DATA parallel group, and
    every other DATA rank asserts that it loaded an identical tensor.

    Args:
        rank: global rank of this process (supplied by ``mp.spawn``).
        world_size: total number of spawned processes.
    """
    colossalai.launch(
        config=CONFIG,
        rank=rank,
        world_size=world_size,
        backend='gloo',
        port='29904',
        host='localhost',
    )

    dataset_cfg = gpc.config.train_data.dataset
    dataloader_cfg = gpc.config.train_data.dataloader
    transform_cfg = gpc.config.train_data.transform_pipeline

    # build transform
    dataset_cfg['transform'] = transforms.Compose([build_transform(cfg) for cfg in transform_cfg])

    # build dataset
    # The dataset config may request ``download=True``. Let global rank 0
    # build (and possibly download) it first, and make the other ranks wait.
    # Otherwise every rank could write into the same root concurrently.
    is_first_rank = dist.get_rank() == 0
    if not is_first_rank:
        dist.barrier()
    dataset = build_dataset(dataset_cfg)
    if is_first_rank:
        dist.barrier()

    # build dataloader
    dataloader = DataLoader(dataset=dataset, **dataloader_cfg)

    # Use the builtin ``next``: DataLoader iterators no longer expose a
    # Python-2 style ``.next()`` method in recent PyTorch releases.
    img, _ = next(iter(dataloader))
    img = img[0]

    data_rank = gpc.get_local_rank(ParallelMode.DATA)
    # Rank 0 broadcasts its own tensor. The other ranks receive into a copy,
    # so ``img`` still holds the locally loaded image for the comparison below.
    img_to_compare = img if data_rank == 0 else img.clone()
    dist.broadcast(img_to_compare, src=0, group=gpc.get_group(ParallelMode.DATA))

    if data_rank != 0:
        # No data-parallel sampler is given to the dataloader. Presumably
        # colossalai.launch seeds every rank with the same CONFIG seed, so the
        # shuffled order matches across ranks. With a data-parallel sampler
        # this comparison would be expected to fail.
        assert torch.equal(img, img_to_compare), \
            'Without a data-parallel sampler, every rank is expected to load the same first image'
    torch.cuda.empty_cache()


@pytest.mark.cpu
def test_data_sampler():
    """Spawn four worker processes, each running ``run_data_sampler`` with its own rank."""
    nprocs = 4
    worker = partial(run_data_sampler, world_size=nprocs)
    # mp.spawn passes the process index as the first positional argument (``rank``).
    mp.spawn(worker, nprocs=nprocs)


if __name__ == '__main__':
    # Allow running this test directly without pytest.
    test_data_sampler()