# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
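"""Smoke tests for mpu model/tensor-parallel group initialization.

These tests require a multi-process torch.distributed run; the
initialize_distributed() helper from commons.py is expected to create the
process group before any mpu state is set up.
"""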

import sys
sys.path.append("../..")

from commons import print_separator
from commons import initialize_distributed
import mpu
import torch


def test_initialize_model_parallel(tensor_model_parallel_size):

    if torch.distributed.get_rank() == 0:
        print('> testing initialize_model_parallel with size {} ...'.format(
            tensor_model_parallel_size))
    tensor_model_parallel_size_ = min(tensor_model_parallel_size,
                                      torch.distributed.get_world_size())
    assert not mpu.model_parallel_is_initialized()
    mpu.initialize_model_parallel(tensor_model_parallel_size_)
    assert mpu.model_parallel_is_initialized()

    # Checks.
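    # The check() helper cross-checks mpu's bookkeeping against
    # torch.distributed's own view of the corresponding process group.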
    def check(group, world_size, rank):
        assert world_size == torch.distributed.get_world_size(group=group)
        assert rank == torch.distributed.get_rank(group=group)

    # Model parallel.
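    # Tensor-model-parallel groups are built from consecutive global ranks,
    # so the expected in-group rank is simply global_rank % group_size.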
    world_size = tensor_model_parallel_size_
    rank = torch.distributed.get_rank() % tensor_model_parallel_size_
    assert world_size == mpu.get_tensor_model_parallel_world_size()
    assert rank == mpu.get_tensor_model_parallel_rank()
    check(mpu.get_tensor_model_parallel_group(), world_size, rank)

    # Data parallel.
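    # The data-parallel decomposition is the complement of the tensor-model-
    # parallel one: its world size is the total world size divided by the
    # tensor-model-parallel size, and its rank is the global rank divided by
    # the tensor-model-parallel size.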
    world_size = torch.distributed.get_world_size() // tensor_model_parallel_size_
    rank = torch.distributed.get_rank() // tensor_model_parallel_size_
    assert world_size == mpu.get_data_parallel_world_size()
    assert rank == mpu.get_data_parallel_rank()
    check(mpu.get_data_parallel_group(), world_size, rank)

    # Reset groups
    mpu.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print('>> passed the test :-)')


def test_get_tensor_model_parallel_src_rank(tensor_model_parallel_size_):

    if torch.distributed.get_rank() == 0:
        print('> testing get_tensor_model_parallel_src_rank with size {} ...'.format(
            tensor_model_parallel_size_))
    tensor_model_parallel_size = min(tensor_model_parallel_size_,
                                     torch.distributed.get_world_size())
    assert not mpu.model_parallel_is_initialized()
    mpu.initialize_model_parallel(tensor_model_parallel_size)
    assert mpu.model_parallel_is_initialized()

    # Checks
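    # The source rank of a tensor-model-parallel group is its lowest global
    # rank; with contiguous groups that is global_rank minus the in-group rank.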
    src_rank = torch.distributed.get_rank() - mpu.get_tensor_model_parallel_rank()
    assert mpu.get_tensor_model_parallel_src_rank() == src_rank

    # Reset groups
    mpu.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print('>> passed the test :-)')


if __name__ == '__main__':

    initialize_distributed()
    world_size = torch.distributed.get_world_size()
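    # Sweep tensor-model-parallel sizes 1, 2, 4, ... up to the world size.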
    tensor_model_parallel_size = 1
    while tensor_model_parallel_size <= world_size:
        print_separator('test initialize model parallel')
        test_initialize_model_parallel(tensor_model_parallel_size)
        print_separator('test model parallel source rank')
        test_get_tensor_model_parallel_src_rank(tensor_model_parallel_size)
        tensor_model_parallel_size *= 2
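
# Example launch (illustrative only; the exact command depends on how
# commons.initialize_distributed() discovers the rank and world size):
#   torchrun --nproc_per_node=<num_gpus> test_initialize.py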