test_data_parallel_sampler.py 1.76 KB
Newer Older
zbian's avatar
zbian committed
1
2
3
4
5
6
7
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import os
from pathlib import Path

import pytest
Frank Lee's avatar
Frank Lee committed
8
import torch
zbian's avatar
zbian committed
9
import torch.distributed as dist
10
from torchvision import datasets, transforms
zbian's avatar
zbian committed
11
12

import colossalai
13
from colossalai.context import Config, ParallelMode
zbian's avatar
zbian committed
14
from colossalai.core import global_context as gpc
15
16
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.utils import get_dataloader
zbian's avatar
zbian committed
17

18
19
20
21
22
23
24
# Launch config for the test: pipeline and tensor parallel sizes are both 1,
# presumably leaving every spawned process in the data-parallel group.
CONFIG = Config({
    'parallel': {
        'pipeline': {'size': 1},
        'tensor': {'size': 1, 'mode': None},
    },
    'seed': 1024,
})


def run_data_sampler(rank, world_size, port):
    """Check that the data-parallel sampler gives each rank a different sample.

    Launches ColossalAI for this process (gloo backend, ``localhost``), builds a
    CIFAR10 dataloader with ``add_sampler=True``, takes the first image of the
    first batch, and compares it against the first image seen by
    data-parallel rank 0.

    Args:
        rank: Rank of this process, forwarded to ``colossalai.launch``.
        world_size: Number of processes, forwarded to ``colossalai.launch``.
        port: Port on ``localhost`` used to initialise the process group.

    Raises:
        KeyError: If the ``DATA`` environment variable (dataset root) is unset.
        AssertionError: If a non-zero data-parallel rank got the same first
            image as rank 0.
    """
    dist_args = dict(config=CONFIG, rank=rank, world_size=world_size, backend='gloo', port=port, host='localhost')
    colossalai.launch(**dist_args)
    print('finished initialization')

    # build dataset
    transform_pipeline = transforms.Compose([transforms.ToTensor()])
    dataset = datasets.CIFAR10(root=Path(os.environ['DATA']), train=True, download=True, transform=transform_pipeline)

    # build dataloader
    dataloader = get_dataloader(dataset, batch_size=8, add_sampler=True)

    # Use the builtin next(): DataLoader iterators in recent PyTorch releases no
    # longer expose the Python 2-style ``.next()`` method.
    img, _ = next(iter(dataloader))
    img = img[0]

    # Rank 0 broadcasts its own tensor; other ranks receive into a copy so that
    # ``img`` still holds their locally sampled image for the comparison below.
    # NOTE(review): ``src=0`` is a global rank while the check uses the local
    # DATA rank; they coincide here because pipeline/tensor sizes are 1.
    is_src = gpc.get_local_rank(ParallelMode.DATA) == 0
    img_to_compare = img if is_src else img.clone()
    dist.broadcast(img_to_compare, src=0, group=gpc.get_group(ParallelMode.DATA))

    if not is_src:
        assert not torch.equal(
            img, img_to_compare), 'Same image was distributed across ranks but expected it to be different'
    torch.cuda.empty_cache()


@rerun_if_address_is_in_use()
def test_data_sampler():
    """Run ``run_data_sampler`` in 4 processes via ``colossalai.testing.spawn``."""
    num_procs = 4
    spawn(run_data_sampler, num_procs)


if __name__ == '__main__':
    test_data_sampler()