#!/usr/bin/env python
# -*- encoding: utf-8 -*-

from functools import partial
from pathlib import Path

import pytest
import torch.multiprocessing as mp

from colossalai import init_dist
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc

CONFIG_PATH = Path(__file__).parent.joinpath('configs/parallel_2d_init.py').absolute()
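# The config file at CONFIG_PATH is not reproduced here; judging from the
# checks below, it presumably arranges the 16 processes as data-parallel
# size 2 x pipeline size 2 x tensor-parallel size 4, with the tensor groups
# laid out as a 2 x 2 grid (mode '2d').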


def check_data_parallel_rank(rank):
    # Ranks 0-7 sit at position 0 of their data-parallel group, ranks 8-15
    # at position 1 (i.e. local rank == rank // 8).
    if rank in [0, 1, 2, 3, 4, 5, 6, 7]:
        assert gpc.get_local_rank(ParallelMode.DATA) == 0
    elif rank in [8, 9, 10, 11, 12, 13, 14, 15]:
        assert gpc.get_local_rank(ParallelMode.DATA) == 1


def check_pipeline_parallel_rank(rank):
    # The pipeline stage alternates every four ranks: 0-3 and 8-11 are stage 0,
    # 4-7 and 12-15 are stage 1 (i.e. local rank == (rank // 4) % 2).
    if rank in [0, 1, 2, 3]:
        assert gpc.get_local_rank(ParallelMode.PIPELINE) == 0
    elif rank in [4, 5, 6, 7]:
        assert gpc.get_local_rank(ParallelMode.PIPELINE) == 1
    elif rank in [8, 9, 10, 11]:
        assert gpc.get_local_rank(ParallelMode.PIPELINE) == 0
    elif rank in [12, 13, 14, 15]:
        assert gpc.get_local_rank(ParallelMode.PIPELINE) == 1


def check_tensor_parallel_rank(rank):
    # The tensor-parallel local rank cycles through 0-3 as rank % 4.
    if rank in [0, 4, 8, 12]:
        assert gpc.get_local_rank(ParallelMode.TENSOR) == 0
    elif rank in [1, 5, 9, 13]:
        assert gpc.get_local_rank(ParallelMode.TENSOR) == 1
    elif rank in [2, 6, 10, 14]:
        assert gpc.get_local_rank(ParallelMode.TENSOR) == 2
    elif rank in [3, 7, 11, 15]:
        assert gpc.get_local_rank(ParallelMode.TENSOR) == 3


def check_2d_parallel_rank(rank):
    # Each tensor-parallel group of four forms a 2 x 2 grid: the local rank is
    # (rank % 4) // 2 in PARALLEL_2D_COL and (rank % 4) % 2 in PARALLEL_2D_ROW.
    if rank in [0, 4, 8, 12]:
        assert gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0
        assert gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) == 0
    elif rank in [1, 5, 9, 13]:
        assert gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0
        assert gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) == 1
    elif rank in [2, 6, 10, 14]:
        assert gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 1
        assert gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) == 0
    elif rank in [3, 7, 11, 15]:
        assert gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 1
        assert gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) == 1
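

# A hypothetical cross-check helper (not part of the original test): the
# closed form that the lookup tables above encode, assuming the
# 2 (data) x 2 (pipeline) x 4 (tensor, as a 2 x 2 grid) layout.
def _expected_local_ranks(rank):
    tensor = rank % 4
    return {
        ParallelMode.DATA: rank // 8,
        ParallelMode.PIPELINE: (rank // 4) % 2,
        ParallelMode.TENSOR: tensor,
        ParallelMode.PARALLEL_2D_COL: tensor // 2,
        ParallelMode.PARALLEL_2D_ROW: tensor % 2,
    }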


def init_2d(local_rank, world_size, backend, port, host):
    # local_rank is supplied positionally by mp.spawn; the remaining arguments
    # are bound via functools.partial in test_2d_init below.
    dist_args = dict(
        config=CONFIG_PATH,
        local_rank=local_rank,
        world_size=world_size,
        backend=backend,
        port=port,
        host=host
    )
    init_dist(**dist_args)

    check_tensor_parallel_rank(local_rank)
    check_data_parallel_rank(local_rank)
    check_2d_parallel_rank(local_rank)
    check_pipeline_parallel_rank(local_rank)

    # Tear down all process groups so subsequent tests can re-initialize.
    gpc.destroy()


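# Note: the custom 'cpu' mark below is assumed to be registered in the
# project's pytest configuration; unregistered marks trigger pytest warnings.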
@pytest.mark.cpu
def test_2d_init():
    """
    As no computation or communication is done, we can run this test on CPU.
    """
    world_size = 16
    test_fn = partial(init_2d,
                      world_size=world_size,
                      backend='gloo',
                      port='29500',
                      host='localhost'
                      )
    mp.spawn(test_fn, nprocs=world_size)


if __name__ == '__main__':
    test_2d_init()