# test_comm.py: distributed smoke test for the legacy collective-communication helpers.
import pytest
import torch
import torch.distributed as dist
4

5
from colossalai.legacy.communication import all_gather, all_reduce, reduce_scatter
6
7
8
from colossalai.legacy.context import ParallelMode
from colossalai.legacy.core import global_context as gpc
from colossalai.legacy.initialize import launch
9
10
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.utils import get_current_device

# Launch configuration for the legacy global context.
# NOTE(review): this requests data=8, but test_comm spawns only 4 processes.
# The legacy context presumably derives the data-parallel size from
# world_size / (pipeline * tensor) instead of reading this key. Confirm that,
# and consider aligning the two values.
CONFIG = dict(parallel=dict(data=8, pipeline=1, tensor=dict(mode=None, size=1)))

# Length of the 1-D tensor each rank builds in the check_* helpers.
SIZE = 8


def check_all_gather():
    """Run the asynchronous ``all_gather`` helper over the global group and verify it.

    Each rank contributes ``[rank * SIZE, ..., rank * SIZE + SIZE - 1]``. Gathering
    these along dim 0 in rank order must therefore yield ``arange(world_size * SIZE)``.
    The "After" line is printed before ``wait()``, so it may show the output buffer
    before the collective has finished.
    """
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    tensor = torch.tensor([rank * SIZE + j for j in range(SIZE)])
    tensor = tensor.to(get_current_device())
    print("Before:   Rank {0} - {1}".format(rank, tensor))

    # ``work`` is the async handle returned alongside the output tensor.
    tensor, work = all_gather(tensor, 0, ParallelMode.GLOBAL, async_op=True)
    print("After:    Rank {0} - {1}".format(rank, tensor))

    work.wait()
    print("Complete: Rank {0} - {1}".format(rank, tensor))

    torch.cuda.synchronize()

    # Rank r's slice starts at r * SIZE, so the concatenation is one contiguous range.
    expected = torch.arange(world_size * SIZE, device=tensor.device)
    assert torch.equal(tensor, expected), "all_gather mismatch on rank {}: {}".format(rank, tensor)


def check_reduce_scatter():
    """Run the asynchronous ``reduce_scatter`` helper over the global group and verify it.

    Each rank contributes ``[rank * SIZE + j for j in range(SIZE)]``. The inputs are
    summed element-wise across ranks, and rank ``r`` keeps the ``r``-th of
    ``world_size`` equal chunks of that sum along dim 0.
    """
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    tensor = torch.tensor([rank * SIZE + j for j in range(SIZE)])
    tensor = tensor.to(get_current_device())
    print("Before:   Rank {0} - {1}".format(rank, tensor))

    # ``work`` is the async handle. Naming it ``op`` read like a reduction operator.
    tensor, work = reduce_scatter(tensor, 0, ParallelMode.GLOBAL, async_op=True)
    print("After:    Rank {0} - {1}".format(rank, tensor))

    work.wait()
    print("Complete: Rank {0} - {1}".format(rank, tensor))

    torch.cuda.synchronize()

    # Assumes the helper's default reduction is SUM (verify against its signature).
    # Element j of the sum is sum_r (r * SIZE + j) = SIZE * W * (W - 1) / 2 + W * j.
    reduced = torch.arange(SIZE, device=tensor.device) * world_size + SIZE * world_size * (world_size - 1) // 2
    expected = reduced.chunk(world_size)[rank]
    assert torch.equal(tensor, expected), "reduce_scatter mismatch on rank {}: {}".format(rank, tensor)


def check_all_reduce():
    """Run the asynchronous ``all_reduce`` helper over the global group and verify it.

    Each rank contributes ``[rank * SIZE + j for j in range(SIZE)]``. Every rank
    must end up with the element-wise sum of all contributions.
    """
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    tensor = torch.tensor([rank * SIZE + j for j in range(SIZE)])
    tensor = tensor.to(get_current_device())
    print("Before:   Rank {0} - {1}".format(rank, tensor))

    # ``work`` is the async handle returned alongside the output tensor.
    tensor, work = all_reduce(tensor, ParallelMode.GLOBAL, async_op=True)
    print("After:    Rank {0} - {1}".format(rank, tensor))

    work.wait()
    print("Complete: Rank {0} - {1}".format(rank, tensor))

    torch.cuda.synchronize()

    # Assumes the helper's default reduction is SUM (verify against its signature).
    # Element j of the sum is sum_r (r * SIZE + j) = SIZE * W * (W - 1) / 2 + W * j.
    expected = torch.arange(SIZE, device=tensor.device) * world_size + SIZE * world_size * (world_size - 1) // 2
    assert torch.equal(tensor, expected), "all_reduce mismatch on rank {}: {}".format(rank, tensor)


def check_layer(rank, world_size, port):
    """Per-process worker run by ``spawn``.

    Initialises the legacy global context over NCCL on ``localhost:port``, runs
    the three collective checks, and then tears the context down.

    Args:
        rank: global rank of this process.
        world_size: total number of spawned processes.
        port: TCP port used for the rendezvous.
    """
    launch(config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")

    try:
        assert dist.get_rank() == gpc.get_global_rank()
        print("Rank {} / {}".format(dist.get_rank(), dist.get_world_size()))

        check_all_gather()
        check_reduce_scatter()
        check_all_reduce()
    finally:
        # Tear down the global context and cached GPU memory even when a check
        # fails, so the original assertion is what surfaces.
        gpc.destroy()
        torch.cuda.empty_cache()


@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_comm():
    """Spawn 4 worker processes, each running ``check_layer``."""
    spawn(check_layer, 4)


if __name__ == "__main__":
    # Allow running the distributed test directly, without pytest collection.
    test_comm()