test_parallel_state.py 7.49 KB
Newer Older
1
2
3
4
5
6
7
8
import logging
import os

from torch.testing._internal import common_utils

# Only let WARNING-and-above records from the "torch" logger through.
logging.getLogger("torch").setLevel(logging.WARNING)

from apex.transformer import parallel_state
9
10
from apex.transformer.testing.distributed_test_base import NcclDistributedTestBase
from apex.transformer.testing.distributed_test_base import UccDistributedTestBase
11
12
13
14
15
16
17
18
19
20
21
22
23
24

# Only let WARNING-and-above records from the "apex" logger through.
logging.getLogger("apex").setLevel(logging.WARNING)


# Set the BACKEND environment variable for this process.
# NOTE(review): presumably read by the distributed test base classes to choose
# the communication backend -- confirm, since a UCC-based test class is also
# defined in this file.
os.environ["BACKEND"] = "NCCL"
# NOTE(review): not referenced by any code visible in this file.
DATA_PARALLEL_WORLD_SIZE: int = 1


def calc_expected_tensor_model_paralell_rank(
    rank: int,
    tensor_model_parallel_world_size: int,
) -> int:
    """Return the remainder of ``rank`` divided by the tensor-parallel size.

    The tests in this file compare this value against
    ``parallel_state.get_tensor_model_parallel_rank()``.

    Args:
        rank: Global rank of the calling process.
        tensor_model_parallel_world_size: Number of ranks per tensor-parallel
            group.

    Returns:
        ``rank % tensor_model_parallel_world_size``.
    """
    _, position_in_group = divmod(rank, tensor_model_parallel_world_size)
    return position_in_group


25
class ParallelStateTestBase:
    """Backend-agnostic checks for ``apex.transformer.parallel_state``.

    This class is a mixin: it is combined with a distributed test base class
    (see the concrete test classes defined later in this file), which is
    expected to supply ``self.rank``, ``self.world_size`` and the ``unittest``
    assertion helpers used below.
    """

    def _check_tensor_model_parallel_state(self, **extra_init_kwargs) -> None:
        """Verify tensor-parallel bookkeeping for every valid tensor-parallel size.

        For each divisor ``tp`` of ``self.world_size`` the model-parallel state
        is initialized with tensor size ``tp`` and pipeline size
        ``world_size // tp``, the tensor-parallel world size / rank / source
        rank are checked, and the state is destroyed again.

        Args:
            **extra_init_kwargs: Additional keyword arguments forwarded
                verbatim to ``parallel_state.initialize_model_parallel``.
        """
        self.assertFalse(parallel_state.model_parallel_is_initialized())

        for tensor_model_parallel_world_size in range(1, self.world_size + 1):
            with self.subTest(
                tensor_model_parallel_world_size=tensor_model_parallel_world_size
            ):
                if self.world_size % tensor_model_parallel_world_size:
                    # Not a divisor of the world size: nothing to check.
                    continue

                pipeline_model_parallel_world_size = (
                    self.world_size // tensor_model_parallel_world_size
                )

                parallel_state.initialize_model_parallel(
                    tensor_model_parallel_size_=tensor_model_parallel_world_size,
                    pipeline_model_parallel_size_=pipeline_model_parallel_world_size,
                    **extra_init_kwargs,
                )
                try:
                    self.assertEqual(
                        tensor_model_parallel_world_size,
                        parallel_state.get_tensor_model_parallel_world_size(),
                    )
                    self.assertEqual(
                        calc_expected_tensor_model_paralell_rank(
                            self.rank, tensor_model_parallel_world_size
                        ),
                        parallel_state.get_tensor_model_parallel_rank(),
                    )

                    # The source rank is the first global rank of the
                    # tensor-parallel group this rank falls into.
                    expected_tensor_model_parallel_src_rank = (
                        self.rank // tensor_model_parallel_world_size
                    ) * tensor_model_parallel_world_size
                    self.assertEqual(
                        expected_tensor_model_parallel_src_rank,
                        parallel_state.get_tensor_model_parallel_src_rank(),
                    )
                finally:
                    # Always tear down, so that one failing sub-test does not
                    # leave state behind and break the following ones.
                    parallel_state.destroy_model_parallel()
                self.assertFalse(parallel_state.model_parallel_is_initialized())

    def test_initialize_model_parallel(self) -> None:
        """Check tensor-parallel state with default initialization arguments."""
        self._check_tensor_model_parallel_state()

    def test_initialize_model_parallel_with_virtual_and_split(self) -> None:
        """Check state after initializing with virtual pipeline size and split rank.

        Skipped when fewer than 4 ranks are available.
        """
        if self.world_size < 4:
            self.skipTest("requires >= 4 GPUs")
        self.assertFalse(parallel_state.model_parallel_is_initialized())

        tensor_model_parallel_world_size = 1 + int(self.world_size > 4)
        pipeline_model_parallel_world_size = (
            self.world_size // tensor_model_parallel_world_size
        )
        virtual_pipeline_model_parallel_world_size = 2
        pipeline_model_parallel_split_rank = pipeline_model_parallel_world_size // 2

        parallel_state.initialize_model_parallel(
            tensor_model_parallel_size_=tensor_model_parallel_world_size,
            pipeline_model_parallel_size_=pipeline_model_parallel_world_size,
            virtual_pipeline_model_parallel_size_=virtual_pipeline_model_parallel_world_size,
            pipeline_model_parallel_split_rank_=pipeline_model_parallel_split_rank,
        )
        try:
            self.assertEqual(
                calc_expected_tensor_model_paralell_rank(
                    self.rank, tensor_model_parallel_world_size
                ),
                parallel_state.get_tensor_model_parallel_rank(),
            )
            self.assertEqual(
                pipeline_model_parallel_world_size,
                parallel_state.get_pipeline_model_parallel_world_size(),
            )
            self.assertEqual(
                virtual_pipeline_model_parallel_world_size,
                parallel_state.get_virtual_pipeline_model_parallel_world_size(),
            )

            # Per the group layout documented for `initialize_model_parallel`,
            # a pipeline group consists of global ranks strided by the number
            # of pipeline groups (world_size // pipeline size), so a rank's
            # position inside its pipeline group is the quotient below.
            # The earlier expression `(rank - rank % tp) % pp` yields the same
            # value only when the tensor-parallel size is 1.
            expected_pipeline_rank = self.rank // (
                self.world_size // pipeline_model_parallel_world_size
            )
            self.assertEqual(
                expected_pipeline_rank,
                parallel_state.get_pipeline_model_parallel_rank(),
            )
            # virtual pipeline model parallel rank is lazily set, i.e., right after the call of
            # `initialize_model_parallel`, it's set to 0.
            self.assertEqual(
                0, parallel_state.get_virtual_pipeline_model_parallel_rank(),
            )
            self.assertEqual(
                pipeline_model_parallel_split_rank,
                parallel_state.get_pipeline_model_parallel_split_rank(),
            )

            # The split rank setter/getter pair must round-trip an arbitrary value.
            fake_split_rank = 77
            parallel_state.set_pipeline_model_parallel_split_rank(fake_split_rank)
            self.assertEqual(
                fake_split_rank,
                parallel_state.get_pipeline_model_parallel_split_rank(),
            )

            # relative position embedding groups check
            # (compared against the split rank passed at initialization time,
            # not the fake value set just above)
            self.assertEqual(
                expected_pipeline_rank < pipeline_model_parallel_split_rank,
                parallel_state.is_rank_in_encoder_relative_position_embedding_group(),
            )
            self.assertEqual(
                expected_pipeline_rank >= pipeline_model_parallel_split_rank,
                parallel_state.is_rank_in_decoder_relative_position_embedding_group(),
            )
        finally:
            # Tear down even on assertion failure so later tests start clean.
            parallel_state.destroy_model_parallel()

    def test_initialize_model_parallel_decoder_only(self) -> None:
        """Initialize model parallelism for decoder-only Transformers like GPT-3"""
        self._check_tensor_model_parallel_state(
            pipeline_model_parallel_split_rank_=0,
        )

179

180
181
182
183
class NcclParallelStateTest(ParallelStateTestBase, NcclDistributedTestBase):
    """Run every ``ParallelStateTestBase`` check on top of ``NcclDistributedTestBase``."""
class UccParallelStateTest(ParallelStateTestBase, UccDistributedTestBase):
    """Run every ``ParallelStateTestBase`` check on top of ``UccDistributedTestBase``."""


184
185
if __name__ == "__main__":
    # When executed as a script, hand control to the test runner from
    # `torch.testing._internal.common_utils`.
    common_utils.run_tests()