# test_initialize.py
# coding=utf-8

# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

from fairscale.nn.model_parallel import initialize as mpu
from tests.nn.model_parallel.commons import dist_init, spawn_for_all_world_sizes


def run_test_initialize_model_parallel(rank, model_parallel_size):
    """Worker that checks the groups set up by ``mpu.initialize_model_parallel``.

    Args:
        rank: Rank handed to ``dist_init`` for this worker process.
        model_parallel_size: Requested model-parallel size. It is clamped to
            the distributed world size before being passed to
            ``mpu.initialize_model_parallel``.

    Raises:
        AssertionError: If the model-parallel or data-parallel world size or
            rank reported by ``mpu`` (or by ``torch.distributed`` for the
            corresponding group) differs from the expected value.
    """
    # NOTE(review): dist_init presumably sets up the default process group
    # for this worker -- confirm in tests/nn/model_parallel/commons.py.
    dist_init(rank, model_parallel_size)

    if torch.distributed.get_rank() == 0:
        print("> testing initialize_model_parallel with size {} ...".format(model_parallel_size))
    # Never ask for a model-parallel group larger than the whole world.
    model_parallel_size_ = min(model_parallel_size, torch.distributed.get_world_size())
    assert not mpu.model_parallel_is_initialized()
    mpu.initialize_model_parallel(model_parallel_size_)
    assert mpu.model_parallel_is_initialized()

    def check(group, expected_world_size, expected_rank):
        """Assert that torch.distributed agrees with the expected group layout."""
        assert expected_world_size == torch.distributed.get_world_size(group=group)
        assert expected_rank == torch.distributed.get_rank(group=group)

    # Model parallel: the expected rank is the global rank modulo the
    # (clamped) model-parallel size.
    expected_world_size = model_parallel_size_
    expected_rank = torch.distributed.get_rank() % model_parallel_size_
    assert expected_world_size == mpu.get_model_parallel_world_size()
    assert expected_rank == mpu.get_model_parallel_rank()
    check(mpu.get_model_parallel_group(), expected_world_size, expected_rank)

    # Data parallel: use the clamped size for both the world size and the
    # rank so the two expectations are derived from the same value.
    expected_world_size = torch.distributed.get_world_size() // model_parallel_size_
    expected_rank = torch.distributed.get_rank() // model_parallel_size_
    assert expected_world_size == mpu.get_data_parallel_world_size()
    assert expected_rank == mpu.get_data_parallel_rank()
    check(mpu.get_data_parallel_group(), expected_world_size, expected_rank)

    # Reset groups
    mpu.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(">> passed the test :-)")


def run_test_get_model_parallel_src_rank(rank, model_parallel_size_):
    """Worker that checks ``mpu.get_model_parallel_src_rank``.

    The expected source rank is this process's global rank minus its
    model-parallel rank.

    Args:
        rank: Rank handed to ``dist_init`` for this worker process.
        model_parallel_size_: Requested model-parallel size; clamped to the
            distributed world size before initialization.

    Raises:
        AssertionError: If the reported source rank, or the initialization
            state before/after ``initialize_model_parallel``, is unexpected.
    """

    def _report_from_rank_zero(template, *values):
        # Only global rank 0 prints, to avoid duplicated output per worker.
        if torch.distributed.get_rank() == 0:
            print(template.format(*values))

    dist_init(rank, model_parallel_size_)

    _report_from_rank_zero("> testing get_model_parallel_src_rank with size {} ...", model_parallel_size_)

    effective_size = min(model_parallel_size_, torch.distributed.get_world_size())
    assert not mpu.model_parallel_is_initialized()
    mpu.initialize_model_parallel(effective_size)
    assert mpu.model_parallel_is_initialized()

    # Checks
    expected_src_rank = torch.distributed.get_rank() - mpu.get_model_parallel_rank()
    assert mpu.get_model_parallel_src_rank() == expected_src_rank

    # Reset groups
    mpu.destroy_model_parallel()

    torch.distributed.barrier()
    _report_from_rank_zero(">> passed the test :-)")


def test_initialize_model_parallel():
    """Run the initialize_model_parallel worker via ``spawn_for_all_world_sizes``."""
    worker = run_test_initialize_model_parallel
    spawn_for_all_world_sizes(worker)


def test_get_model_parallel_src_rank():
    """Run the get_model_parallel_src_rank worker via ``spawn_for_all_world_sizes``."""
    worker = run_test_get_model_parallel_src_rank
    spawn_for_all_world_sizes(worker)


def test_adjacency(monkeypatch):
    """Check the process groups requested by ``mpu.initialize_model_parallel``.

    ``torch.distributed`` is replaced by a mock that records the rank list
    passed to every ``new_group`` call. The recorded groups are bucketed by
    length, and the test asserts how many groups of each size were requested
    and that every model-parallel group is a contiguous run of ranks.

    Args:
        monkeypatch: pytest fixture used to swap out ``torch.distributed``.
    """
    from collections import defaultdict

    data_parallel_size = 32
    pipeline_length = 8
    model_parallel_size = 4

    # Rank lists passed to new_group, in call order.
    new_groups = []

    class MockDistributed:
        """Stand-in for ``torch.distributed`` that records requested groups."""

        def get_rank(self):
            return 0

        def is_initialized(self):
            return True

        def get_world_size(self):
            return data_parallel_size * pipeline_length * model_parallel_size

        def new_group(self, args, backend=None):
            # Copy so later mutation by the caller cannot alter the record.
            new_groups.append(args.copy())
            return ()

    monkeypatch.setattr(torch, "distributed", MockDistributed())

    mpu.initialize_model_parallel(model_parallel_size, pipeline_length)

    # Bucket the recorded groups by size; the three sizes are distinct, so
    # the size identifies which kind of group it is.
    buckets = defaultdict(list)
    for group in new_groups:
        buckets[len(group)].append(group)

    # Sort the expected sizes too, so the check does not depend on the
    # constants above being declared in ascending order.
    assert sorted(buckets) == sorted([model_parallel_size, pipeline_length, data_parallel_size])

    assert len(buckets[model_parallel_size]) == pipeline_length * data_parallel_size
    assert len(buckets[data_parallel_size]) == model_parallel_size * pipeline_length
    assert len(buckets[pipeline_length]) == model_parallel_size * data_parallel_size

    # Check that model_parallel groups are contiguous
    for group in buckets[model_parallel_size]:
        assert sorted(group) == group
        assert list(range(group[0], group[-1] + 1)) == group