#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import os
from pathlib import Path

import pytest
import torch
import torch.distributed as dist
from torchvision import datasets, transforms

import colossalai
from colossalai.context import Config, ParallelMode
from colossalai.core import global_context as gpc
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.utils import get_dataloader

# Test configuration: CIFAR10 train split with a small shuffled dataloader,
# no pipeline/tensor parallelism, and a fixed seed so that sampling is
# reproducible across ranks.
_dataset_cfg = dict(
    type='CIFAR10',
    root=Path(os.environ['DATA']),
    train=True,
    download=True,
)
_dataloader_cfg = dict(num_workers=2, batch_size=2, shuffle=True)

CONFIG = Config(
    dict(
        train_data=dict(
            dataset=_dataset_cfg,
            dataloader=_dataloader_cfg,
        ),
        parallel=dict(
            pipeline=dict(size=1),
            tensor=dict(size=1, mode=None),
        ),
        seed=1024,
    ))


def run_data_sampler(rank, world_size, port):
    """Worker entry point for the spawned processes.

    Builds a CIFAR10 dataloader WITHOUT a distributed sampler and asserts
    that, with the fixed seed from ``CONFIG``, every data-parallel rank
    receives the identical first batch.

    Args:
        rank: rank of this process, supplied by the spawner.
        world_size: total number of spawned processes.
        port: free TCP port used for the distributed rendezvous.
    """
    dist_args = dict(config=CONFIG, rank=rank, world_size=world_size, backend='gloo', port=port, host='localhost')
    colossalai.launch(**dist_args)

    # build dataset
    transform_pipeline = [transforms.ToTensor(), transforms.RandomCrop(size=32, padding=4)]
    transform_pipeline = transforms.Compose(transform_pipeline)
    dataset = datasets.CIFAR10(root=Path(os.environ['DATA']), train=True, download=True, transform=transform_pipeline)

    # build dataloader with add_sampler=False: no DistributedSampler is
    # attached, so all ranks should iterate the same (seeded) shuffle order
    dataloader = get_dataloader(dataset, batch_size=8, add_sampler=False)

    data_iter = iter(dataloader)
    # Bug fix: `data_iter.next()` is not a Python 3 iterator method and the
    # `.next()` alias was removed from PyTorch DataLoader iterators — use the
    # builtin next() instead.
    img, label = next(data_iter)
    img = img[0]

    # Rank 0 broadcasts its first image; other ranks receive into a clone so
    # that their own local image is preserved for the comparison below.
    if gpc.get_local_rank(ParallelMode.DATA) != 0:
        img_to_compare = img.clone()
    else:
        img_to_compare = img
    dist.broadcast(img_to_compare, src=0, group=gpc.get_group(ParallelMode.DATA))

    if gpc.get_local_rank(ParallelMode.DATA) != 0:
        # this is without sampler
        # this should be false if a data parallel sampler is given to the dataloader
        assert torch.equal(img,
                           img_to_compare), 'Same image was distributed across ranks and expected it to be the same'
    torch.cuda.empty_cache()


@rerun_if_address_is_in_use()
def test_data_sampler():
    """Spawn four gloo workers and check that the dataloader without a
    distributed sampler yields identical data on every rank."""
    world_size = 4
    spawn(run_data_sampler, world_size)


# Allow running this test file directly, outside of pytest.
if __name__ == '__main__':
    test_data_sampler()