"vscode:/vscode.git/clone" did not exist on "33f3023e19f0edfc997788d2da04d02f33671901"
test_deterministic_dataloader.py 2.43 KB
Newer Older
zbian's avatar
zbian committed
1
2
3
4
5
6
7
8
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import os
from functools import partial
from pathlib import Path

import pytest
Frank Lee's avatar
Frank Lee committed
9
import torch
zbian's avatar
zbian committed
10
11
import torch.distributed as dist
import torch.multiprocessing as mp
12
from torchvision import transforms, datasets
zbian's avatar
zbian committed
13
14

import colossalai
Frank Lee's avatar
Frank Lee committed
15
from colossalai.context import ParallelMode, Config
zbian's avatar
zbian committed
16
from colossalai.core import global_context as gpc
17
from colossalai.utils import get_dataloader, free_port
18
from colossalai.testing import rerun_if_address_is_in_use
19
from torchvision import transforms
zbian's avatar
zbian committed
20

Frank Lee's avatar
Frank Lee committed
21
22
# Minimal test configuration: CIFAR10 as the training dataset, no pipeline or
# tensor parallelism, and a fixed seed so that dataloading is reproducible.
CONFIG = Config(
    dict(
        train_data=dict(
            dataset=dict(
                type='CIFAR10',
                root=Path(os.environ['DATA']),
                train=True,
                download=True,
            ),
            dataloader=dict(
                num_workers=2,
                batch_size=2,
                shuffle=True,
            ),
        ),
        parallel=dict(
            pipeline=dict(size=1),
            tensor=dict(size=1, mode=None),
        ),
        seed=1024,
    ))


def run_data_sampler(rank, world_size, port):
    """Worker run in each spawned process: verify that a dataloader created
    WITHOUT a data-parallel sampler yields the same first batch on every rank.

    Args:
        rank (int): process rank, supplied by ``mp.spawn``.
        world_size (int): total number of processes in the test.
        port (int): free TCP port used for the distributed rendezvous.
    """
    dist_args = dict(config=CONFIG, rank=rank, world_size=world_size, backend='gloo', port=port, host='localhost')
    colossalai.launch(**dist_args)

    # build dataset
    transform_pipeline = [transforms.ToTensor(), transforms.RandomCrop(size=32, padding=4)]
    transform_pipeline = transforms.Compose(transform_pipeline)
    dataset = datasets.CIFAR10(root=Path(os.environ['DATA']), train=True, download=True, transform=transform_pipeline)

    # build dataloader with add_sampler=False so every rank iterates the full
    # dataset in the same (seeded) order instead of a per-rank shard
    dataloader = get_dataloader(dataset, batch_size=8, add_sampler=False)

    data_iter = iter(dataloader)
    # FIX: ``data_iter.next()`` is the Python 2 iterator protocol and raises
    # AttributeError under Python 3 — use the builtin ``next`` instead.
    img, label = next(data_iter)
    img = img[0]

    # rank 0 broadcasts its first image; other ranks receive it into a clone
    # so their own ``img`` is preserved for the comparison below
    if gpc.get_local_rank(ParallelMode.DATA) != 0:
        img_to_compare = img.clone()
    else:
        img_to_compare = img
    dist.broadcast(img_to_compare, src=0, group=gpc.get_group(ParallelMode.DATA))

    if gpc.get_local_rank(ParallelMode.DATA) != 0:
        # this is without sampler
        # this should be false if data parallel sampler is given to the dataloader
        assert torch.equal(img,
                           img_to_compare), 'Same image was distributed across ranks and expected it to be the same'
    torch.cuda.empty_cache()
zbian's avatar
zbian committed
68
69
70


@pytest.mark.cpu
@rerun_if_address_is_in_use()
def test_data_sampler():
    """Launch four worker processes and check that dataloading without a
    data-parallel sampler is deterministic across all of them."""
    nprocs = 4
    worker = partial(run_data_sampler, world_size=nprocs, port=free_port())
    mp.spawn(worker, nprocs=nprocs)


# Allow running this test directly as a script, without pytest.
if __name__ == '__main__':
    test_data_sampler()