"applications/Chat/vscode:/vscode.git/clone" did not exist on "cb413ccf28a82602eedfeebb3be5224aa02486dc"
test_deterministic_dataloader.py 2.25 KB
Newer Older
zbian's avatar
zbian committed
1
2
3
4
5
6
7
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
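"""Check that a dataloader built without a distributed sampler feeds the same
batch to every data-parallel rank: rank 0 broadcasts its first image and the
other ranks assert that it matches their own.
"""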

import os
from pathlib import Path

import pytest
import torch
import torch.distributed as dist
from torchvision import datasets, transforms

import colossalai
from colossalai.context import Config
from colossalai.legacy.context import ParallelMode
from colossalai.legacy.core import global_context as gpc
from colossalai.legacy.utils import get_dataloader
from colossalai.testing import rerun_if_address_is_in_use, spawn

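# legacy-style config: pipeline and tensor parallel sizes of 1, so every process
# belongs to the data-parallel group; the seed fixes dataloading order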
CONFIG = Config(
    dict(
        train_data=dict(
            dataset=dict(
                type='CIFAR10',
                root=Path(os.environ['DATA']),
                train=True,
                download=True,
            ),
            dataloader=dict(num_workers=2, batch_size=2, shuffle=True),
        ),
        parallel=dict(
            pipeline=dict(size=1),
            tensor=dict(size=1, mode=None),
        ),
        seed=1024,
    ))


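# worker executed on each spawned process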
def run_data_sampler(rank, world_size, port):
    dist_args = dict(config=CONFIG, rank=rank, world_size=world_size, backend='gloo', port=port, host='localhost')
    colossalai.legacy.launch(**dist_args)

    # build dataset
    transform_pipeline = [transforms.ToTensor(), transforms.RandomCrop(size=32, padding=4)]
    transform_pipeline = transforms.Compose(transform_pipeline)
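    # CIFAR10 is expected under the path given by the DATA environment variable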
    dataset = datasets.CIFAR10(root=Path(os.environ['DATA']), train=True, download=True, transform=transform_pipeline)

    # build dataloader without a distributed sampler
    dataloader = get_dataloader(dataset, batch_size=8, add_sampler=False)

    data_iter = iter(dataloader)
    img, label = next(data_iter)
    img = img[0]

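    # rank 0 broadcasts its image in place; non-zero ranks clone theirs first so
    # their own copy survives for the comparison below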
    if gpc.get_local_rank(ParallelMode.DATA) != 0:
        img_to_compare = img.clone()
    else:
        img_to_compare = img
    dist.broadcast(img_to_compare, src=0, group=gpc.get_group(ParallelMode.DATA))

    if gpc.get_local_rank(ParallelMode.DATA) != 0:
        # no distributed sampler was added, so every rank should receive the same data;
        # this assertion would fail if a data-parallel sampler were given to the dataloader
        assert torch.equal(img, img_to_compare), \
            'expected the same image to be distributed to every rank when no sampler is used'
    torch.cuda.empty_cache()


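# spawn 4 processes; the decorator reruns the test if the chosen port is already in use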
@rerun_if_address_is_in_use()
def test_data_sampler():
    spawn(run_data_sampler, 4)


if __name__ == '__main__':
    test_data_sampler()