test_deterministic_dataloader.py 2.33 KB
Newer Older
zbian's avatar
zbian committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import os
from functools import partial
from pathlib import Path

import pytest
import torch.cuda
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.utils.data import DataLoader

import colossalai
from colossalai.builder import build_dataset
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc

CONFIG = dict(
    train_data=dict(
        dataset=dict(
            type='CIFAR10Dataset',
            root=Path(os.environ['DATA']),
            train=True,
            download=True,
            transform_pipeline=[
                dict(type='ToTensor'),
                dict(type='RandomCrop', size=32),
                dict(type='Normalize', mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
            ]
        ),
        dataloader=dict(
            num_workers=2,
            batch_size=2,
            shuffle=True
        )
    ),
    parallel=dict(
        pipeline=dict(size=1),
        tensor=dict(size=1, mode=None),
    ),
    seed=1024,
)


def run_data_sampler(local_rank, world_size):
    dist_args = dict(
        config=CONFIG,
        local_rank=local_rank,
        world_size=world_size,
        backend='gloo',
        port='29499',
        host='localhost'
    )
    colossalai.init_dist(**dist_args)
    gpc.set_seed()

    print('finished initialization')

    dataset = build_dataset(gpc.config.train_data.dataset)
    dataloader = DataLoader(dataset=dataset, **gpc.config.train_data.dataloader)
    data_iter = iter(dataloader)
    img, label = data_iter.next()
    img = img[0]

    if gpc.get_local_rank(ParallelMode.DATA) != 0:
        img_to_compare = img.clone()
    else:
        img_to_compare = img
    dist.broadcast(img_to_compare, src=0, group=gpc.get_group(ParallelMode.DATA))

    if gpc.get_local_rank(ParallelMode.DATA) != 0:
        # this is without sampler
        # this should be false if data parallel sampler to given to the dataloader
        assert torch.equal(img,
                           img_to_compare), 'Same image was distributed across ranks and expected it to be the same'


@pytest.mark.cpu
def test_data_sampler():
    world_size = 4
    test_func = partial(run_data_sampler, world_size=world_size)
    mp.spawn(test_func, nprocs=world_size)


if __name__ == '__main__':
    test_data_sampler()