test_data_parallel_sampler.py 1.95 KB
Newer Older
zbian's avatar
zbian committed
1
2
3
4
5
6
7
8
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import os
from functools import partial
from pathlib import Path

import pytest
Frank Lee's avatar
Frank Lee committed
9
import torch
zbian's avatar
zbian committed
10
11
12
13
import torch.distributed as dist
import torch.multiprocessing as mp

import colossalai
14
from torchvision import transforms, datasets
Frank Lee's avatar
Frank Lee committed
15
from colossalai.context import ParallelMode, Config
zbian's avatar
zbian committed
16
from colossalai.core import global_context as gpc
17
from colossalai.utils import get_dataloader, free_port
18
from colossalai.testing import rerun_if_address_is_in_use
zbian's avatar
zbian committed
19

20
21
22
23
24
25
26
# Minimal parallel configuration: pipeline and tensor parallelism are both
# size 1, so every spawned process lands in the data-parallel group.
CONFIG = Config({
    'parallel': {
        'pipeline': {'size': 1},
        'tensor': {'size': 1, 'mode': None},
    },
    'seed': 1024,
})
zbian's avatar
zbian committed
27
28


29
30
def run_data_sampler(rank, world_size, port):
    """Worker entry point: check that the data-parallel sampler shards data.

    Initializes one process of the distributed group over gloo, builds a
    CIFAR10 dataloader with the distributed sampler attached, then asserts
    that every non-zero data-parallel rank received a batch different from
    rank 0's batch.

    Args:
        rank (int): process rank, supplied automatically by ``mp.spawn``.
        world_size (int): total number of processes in the group.
        port (int): free TCP port for the rendezvous on localhost.
    """
    dist_args = dict(config=CONFIG, rank=rank, world_size=world_size, backend='gloo', port=port, host='localhost')
    colossalai.launch(**dist_args)
    print('finished initialization')

    # build dataset — assumes the DATA env var points at the dataset root
    transform_pipeline = transforms.Compose([transforms.ToTensor()])
    dataset = datasets.CIFAR10(root=Path(os.environ['DATA']), train=True, download=True, transform=transform_pipeline)

    # build dataloader; add_sampler=True attaches the distributed sampler
    dataloader = get_dataloader(dataset, batch_size=8, add_sampler=True)

    data_iter = iter(dataloader)
    # FIX: use the builtin next() — the Python-2-style .next() method was
    # removed from PyTorch dataloader iterators and raises AttributeError.
    img, label = next(data_iter)
    img = img[0]    # compare only the first image of the batch

    # Rank 0 broadcasts its image; other ranks receive it into a clone so
    # their own `img` is preserved for the comparison below.
    if gpc.get_local_rank(ParallelMode.DATA) != 0:
        img_to_compare = img.clone()
    else:
        img_to_compare = img
    dist.broadcast(img_to_compare, src=0, group=gpc.get_group(ParallelMode.DATA))

    if gpc.get_local_rank(ParallelMode.DATA) != 0:
        # The sampler must have handed each rank a distinct shard, so the
        # local image should differ from rank 0's broadcast image.
        assert not torch.equal(
            img, img_to_compare), 'Same image was distributed across ranks but expected it to be different'
    torch.cuda.empty_cache()
zbian's avatar
zbian committed
56
57
58


@pytest.mark.cpu
@rerun_if_address_is_in_use()
def test_data_sampler():
    """Spawn a 4-process gloo group and run the sampler check on each rank."""
    nprocs = 4
    worker = partial(run_data_sampler, world_size=nprocs, port=free_port())
    mp.spawn(worker, nprocs=nprocs)


# Allow running this test directly as a script, without pytest.
if __name__ == '__main__':
    test_data_sampler()