# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
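
# These tests are meant to run as a multi-process torch.distributed job;
# initialize_distributed() from commons (called in __main__ below) sets up the
# process group before any test executes.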

import sys
sys.path.append("../..")  # extend sys.path first so the mpu import below resolves

from commons import print_separator
from commons import initialize_distributed
import mpu
import torch


def test_initialize_model_parallel(tensor_model_parallel_size):

    if torch.distributed.get_rank() == 0:
        print('> testing initialize_model_parallel with size {} ...'.format(
            tensor_model_parallel_size))
    tensor_model_parallel_size_ = min(tensor_model_parallel_size,
                                      torch.distributed.get_world_size())
    assert not mpu.model_parallel_is_initialized()
    mpu.initialize_model_parallel(tensor_model_parallel_size_)
    assert mpu.model_parallel_is_initialized()

    # Checks.
    def check(group, world_size, rank):
        assert world_size == torch.distributed.get_world_size(group=group)
        assert rank == torch.distributed.get_rank(group=group)

    # Model parallel.
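    # The assertions below encode contiguous rank grouping: global rank r has
    # tensor-model-parallel rank r % tp_size and data-parallel rank r // tp_size.
    # For example, with a world size of 4 and tp_size of 2, the tensor groups are
    # {0, 1}, {2, 3} and the data groups are {0, 2}, {1, 3}.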
    world_size = tensor_model_parallel_size_
    rank = torch.distributed.get_rank() % tensor_model_parallel_size_
    assert world_size == mpu.get_tensor_model_parallel_world_size()
    assert rank == mpu.get_tensor_model_parallel_rank()
    check(mpu.get_tensor_model_parallel_group(), world_size, rank)

    # Data parallel.
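    # A data-parallel group gathers the ranks that share a tensor-model-parallel
    # rank, so there are tp_size groups of world_size // tp_size ranks each.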
    world_size = torch.distributed.get_world_size() // tensor_model_parallel_size_
    rank = torch.distributed.get_rank() // tensor_model_parallel_size_
    assert world_size == mpu.get_data_parallel_world_size()
    assert rank == mpu.get_data_parallel_rank()
    check(mpu.get_data_parallel_group(), world_size, rank)

    # Reset groups
    mpu.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print('>> passed the test :-)')


def test_get_tensor_model_parallel_src_rank(tensor_model_parallel_size_):

    if torch.distributed.get_rank() == 0:
        print('> testing get_tensor_model_parallel_src_rank with size {} ...'.format(
            tensor_model_parallel_size_))
    tensor_model_parallel_size = min(tensor_model_parallel_size_,
                                     torch.distributed.get_world_size())
    assert not mpu.model_parallel_is_initialized()
    mpu.initialize_model_parallel(tensor_model_parallel_size)
    assert mpu.model_parallel_is_initialized()

    # Checks
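    # With contiguous grouping, the source (first) rank of a tensor-model-parallel
    # group is the global rank minus the rank within that group.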
    src_rank = torch.distributed.get_rank() - mpu.get_tensor_model_parallel_rank()
    assert mpu.get_tensor_model_parallel_src_rank() == src_rank

    # Reset groups
    mpu.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print('>> passed the test :-)')


if __name__ == '__main__':

    initialize_distributed()
    world_size = torch.distributed.get_world_size()
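    # Sweep tensor-model-parallel sizes 1, 2, 4, ... up to the launched world size.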
    tensor_model_parallel_size = 1
    while tensor_model_parallel_size <= world_size:
        print_separator('test initialize model parallel')
        test_initialize_model_parallel(tensor_model_parallel_size)
        print_separator('test model parallel source rank')
        test_get_tensor_model_parallel_src_rank(tensor_model_parallel_size)
        tensor_model_parallel_size *= 2