import logging

import torch
import torch.nn as nn
from torch.testing._internal import common_utils

# Set torch's logger level before importing apex (presumably so that
# import-time messages are quieted as well).
logging.getLogger("torch").setLevel(logging.WARNING)

from apex.transformer import parallel_state
from apex.transformer.tensor_parallel import layers
from apex.transformer.testing.commons import set_random_seed
from apex.transformer.testing.distributed_test_base import NcclDistributedTestBase
from apex.transformer.testing.distributed_test_base import UccDistributedTestBase

logging.getLogger("apex").setLevel(logging.WARNING)


# N.B. (mkozuki): Disable TF32 matrix multiply.
# Matrices used in this test are so small that the reduced precision of
# TF32 matmul can make `self.assertEqual` fail.
torch.backends.cuda.matmul.allow_tf32 = False


class TensorParallelLayerTestBase:
    """Backend-agnostic tests for ``apex.transformer.tensor_parallel.layers``.

    This is a mixin. It reads ``self.world_size`` and calls unittest-style
    helpers (``subTest``, ``assertEqual``, ``assertIsNotNone``), so it must be
    combined with a test-case base class that provides them. Every test is
    repeated once per tensor-model-parallel size that evenly divides
    ``world_size``.
    """

    BATCH_SIZE: int = 17
    SEQUENCE_LENGTH: int = 23
    VOCAB_SIZE: int = 48
    HIDDEN_SIZE: int = 16
    # Per-rank shard sizes: the full dimension used by the tests is
    # COEFF * tensor_model_parallel_world_size.
    INPUT_SIZE_COEFF: int = 13
    OUTPUT_SIZE_COEFF: int = 17
    SEED: int = 123

    def _tensor_model_parallel_world_sizes(self) -> list:
        """Return, in increasing order, every size in [1, world_size] dividing world_size."""
        return [
            size
            for size in range(1, self.world_size + 1)
            if self.world_size % size == 0
        ]

    def _run_per_tensor_model_parallel_size(self, check) -> None:
        """Run ``check(tensor_model_parallel_world_size)`` once per valid size.

        Each run gets its own ``subTest`` and freshly initialized model-parallel
        state. ``destroy_model_parallel`` runs in a ``finally`` block. A failing
        assertion is recorded by ``subTest`` and the loop continues, so without
        the ``finally`` the next iteration would start with the previous state
        still initialized.

        Args:
            check: callable taking the tensor-model-parallel world size.
        """
        for tensor_model_parallel_world_size in self._tensor_model_parallel_world_sizes():
            with self.subTest(
                tensor_model_parallel_world_size=tensor_model_parallel_world_size
            ):
                parallel_state.initialize_model_parallel(
                    tensor_model_parallel_size_=tensor_model_parallel_world_size,
                )
                try:
                    check(tensor_model_parallel_world_size)
                finally:
                    parallel_state.destroy_model_parallel()

    def test_parallel_embedding(self) -> None:
        """VocabParallelEmbedding must match nn.Embedding (output, loss, weight grad)."""
        self._run_per_tensor_model_parallel_size(self._check_parallel_embedding)

    def _check_parallel_embedding(self, tensor_model_parallel_world_size: int) -> None:
        """Compare VocabParallelEmbedding against nn.Embedding for one parallel size."""
        set_random_seed(self.SEED + 1)
        input_tensor = torch.randint(
            0,
            self.VOCAB_SIZE,
            (self.BATCH_SIZE, self.SEQUENCE_LENGTH),
            device="cuda",
        )
        loss_weight = torch.randn(
            (self.BATCH_SIZE, self.SEQUENCE_LENGTH, self.HIDDEN_SIZE),
            device="cuda",
        )

        # Reference: a plain nn.Embedding initialized from the same seed.
        set_random_seed(self.SEED)
        embedding_torch = nn.Embedding(self.VOCAB_SIZE, self.HIDDEN_SIZE).cuda()
        output_torch = embedding_torch(input_tensor)
        loss_torch = torch.mul(output_torch, loss_weight).sum()
        loss_torch.backward()

        # N.B. (mkozuki): With affine weight initialization on GPU,
        # it's super difficult to keep the consistency with nn.Embedding.
        # Thus, turning on `use_cpu_initialization`.
        set_random_seed(self.SEED)
        embedding_vocab_parallel = layers.VocabParallelEmbedding(
            self.VOCAB_SIZE,
            self.HIDDEN_SIZE,
            init_method=nn.init.normal_,
            use_cpu_initialization=True,
        ).cuda()
        output_vocab_parallel = embedding_vocab_parallel(input_tensor)
        loss_vocab_parallel = torch.mul(output_vocab_parallel, loss_weight).sum()
        loss_vocab_parallel.backward()

        self.assertEqual(output_torch, output_vocab_parallel)
        self.assertEqual(loss_torch, loss_vocab_parallel)

        # The test expects this rank's weight gradient to equal the rank-th
        # equal-sized chunk (along dim 0) of the reference gradient.
        splitted_weight_torch = torch.split(
            embedding_torch.weight.grad,
            self.VOCAB_SIZE // tensor_model_parallel_world_size,
            0,
        )[parallel_state.get_tensor_model_parallel_rank()]
        self.assertEqual(
            splitted_weight_torch, embedding_vocab_parallel.weight.grad
        )

    def _affine_weight_init_test_impl(
        self, init_device: str, is_column_parallel: bool
    ) -> None:
        """Check ``layers._initialize_affine_weight_{cpu,gpu}`` against a reference.

        Args:
            init_device: ``"cpu"`` exercises ``_initialize_affine_weight_cpu``;
                any other value exercises ``_initialize_affine_weight_gpu``.
            is_column_parallel: shard along dim 0 when True, dim 1 otherwise.
        """
        self._run_per_tensor_model_parallel_size(
            lambda size: self._check_affine_weight_init(
                size, init_device, is_column_parallel
            )
        )

    def _check_affine_weight_init(
        self,
        tensor_model_parallel_world_size: int,
        init_device: str,
        is_column_parallel: bool,
    ) -> None:
        """Affine weight init check for a single tensor-model-parallel size."""
        # Column parallel shards dim 0, row parallel shards dim 1.
        dim = int(not is_column_parallel)
        input_size: int = self.INPUT_SIZE_COEFF * tensor_model_parallel_world_size
        output_size: int = self.OUTPUT_SIZE_COEFF * tensor_model_parallel_world_size

        # Shape of this rank's shard of the (output_size, input_size) weight.
        weight_shape = (
            (self.OUTPUT_SIZE_COEFF, input_size)
            if is_column_parallel
            else (output_size, self.INPUT_SIZE_COEFF)
        )
        weight = torch.empty(weight_shape)
        set_random_seed(self.SEED)

        sharding_dim_size = (
            self.OUTPUT_SIZE_COEFF if is_column_parallel else self.INPUT_SIZE_COEFF
        )

        if init_device == "cpu":
            layers._initialize_affine_weight_cpu(
                weight,
                output_size,
                input_size,
                sharding_dim_size,
                dim,
                nn.init.normal_,
                params_dtype=torch.float32,
            )
        else:
            layers._initialize_affine_weight_gpu(weight, torch.nn.init.normal_, dim)

        # Target: re-seed and reproduce the expected shard directly.
        set_random_seed(self.SEED)
        if init_device == "cpu":
            # CPU path: initialize the full weight, then take this rank's slice.
            main_weight = torch.empty(output_size, input_size)
            nn.init.normal_(main_weight)
            curr_weight = torch.split(main_weight, sharding_dim_size, dim=dim)[
                parallel_state.get_tensor_model_parallel_rank()
            ]
        else:
            curr_weight = torch.empty(*weight_shape)
            nn.init.normal_(curr_weight)
        self.assertEqual(curr_weight, weight)

    def test_affine_weight_init_column_parallel_cpu(self) -> None:
        self._affine_weight_init_test_impl(init_device="cpu", is_column_parallel=True)

    def test_affine_weight_init_column_parallel_gpu(self) -> None:
        self._affine_weight_init_test_impl(init_device="gpu", is_column_parallel=True)

    def test_affine_weight_init_row_parallel_cpu(self) -> None:
        self._affine_weight_init_test_impl(init_device="cpu", is_column_parallel=False)

    def test_affine_weight_init_row_parallel_gpu(self) -> None:
        self._affine_weight_init_test_impl(init_device="gpu", is_column_parallel=False)

    def test_row_parallel_linear(self) -> None:
        """RowParallelLinear gradients must match analytically computed ones."""
        self._run_per_tensor_model_parallel_size(self._check_row_parallel_linear)

    def _check_row_parallel_linear(self, tensor_model_parallel_world_size: int) -> None:
        """RowParallelLinear forward/backward check for one parallel size."""
        input_size: int = self.INPUT_SIZE_COEFF * tensor_model_parallel_world_size
        output_size: int = self.OUTPUT_SIZE_COEFF * tensor_model_parallel_world_size

        set_random_seed(self.SEED)
        linear_layer = layers.RowParallelLinear(
            input_size,
            output_size,
            keep_master_weight_for_test=True,
            params_dtype=torch.float32,
            use_cpu_initialization=True,
        ).cuda()
        loss_weight = torch.randn((self.BATCH_SIZE, output_size)).cuda()

        # Forward and backward. `.cuda()` yields a non-leaf tensor, hence
        # `retain_grad()` to keep its gradient.
        input_tensor = torch.randn(
            self.BATCH_SIZE, input_size, requires_grad=True
        ).cuda()
        input_tensor.retain_grad()
        output, _ = linear_layer(input_tensor)
        loss = torch.mul(output, loss_weight).sum()
        loss.backward()
        self.assertIsNotNone(input_tensor.grad)

        # Reference gradients for loss = sum(dldy * (x @ a.T + b)).
        with torch.no_grad():
            dldy = loss_weight.clone()
            x = input_tensor.clone()
            a = linear_layer.master_weight.cuda()
        dlda = torch.matmul(dldy.t(), x)
        dldb = torch.matmul(torch.ones(self.BATCH_SIZE, 1).cuda().t(), dldy).view(-1)
        dldx = torch.matmul(dldy, a)

        # Row parallel: each rank's weight grad is a slice along dim 1.
        with torch.no_grad():
            curr_dlda = torch.split(dlda, self.INPUT_SIZE_COEFF, dim=1)[
                parallel_state.get_tensor_model_parallel_rank()
            ].clone()
        self.assertEqual(linear_layer.weight.grad, curr_dlda)
        self.assertEqual(input_tensor.grad, dldx)
        self.assertEqual(linear_layer.bias.grad, dldb)

    def test_column_parallel_linear(self) -> None:
        self._column_parallel_linear_test_impl(False, False)

    def test_column_parallel_linear_no_async(self) -> None:
        self._column_parallel_linear_test_impl(True, False)

    def test_column_parallel_linear_gradient_accumulation_fusion(self) -> None:
        self._column_parallel_linear_test_impl(False, True)

    def _column_parallel_linear_test_impl(
        self,
        no_async_tensor_model_parallel_allreduce: bool,
        gradient_accumulation_fusion: bool,
    ) -> None:
        """Check ColumnParallelLinear gradients for every valid parallel size.

        Args:
            no_async_tensor_model_parallel_allreduce: forwarded to the layer.
            gradient_accumulation_fusion: forwarded to the layer; when True a
                random ``main_grad`` buffer is attached to the weight first.
        """
        self._run_per_tensor_model_parallel_size(
            lambda size: self._check_column_parallel_linear(
                size,
                no_async_tensor_model_parallel_allreduce,
                gradient_accumulation_fusion,
            )
        )

    def _check_column_parallel_linear(
        self,
        tensor_model_parallel_world_size: int,
        no_async_tensor_model_parallel_allreduce: bool,
        gradient_accumulation_fusion: bool,
    ) -> None:
        """ColumnParallelLinear forward/backward check for one parallel size."""
        feature_size_coeff = self.INPUT_SIZE_COEFF
        feature_size = feature_size_coeff * tensor_model_parallel_world_size
        hidden_size = feature_size

        set_random_seed(self.SEED)
        input_tensor = torch.randn(
            self.BATCH_SIZE,
            hidden_size,
            feature_size,
            device="cuda",
            requires_grad=True,
        )
        input_tensor.retain_grad()
        loss_weight = torch.randn(
            (self.BATCH_SIZE, hidden_size, feature_size),
            device="cuda",
        )
        linear = layers.ColumnParallelLinear(
            feature_size,
            feature_size,
            bias=False,
            keep_master_weight_for_test=True,
            params_dtype=torch.float32,
            use_cpu_initialization=True,
            no_async_tensor_model_parallel_allreduce=no_async_tensor_model_parallel_allreduce,
            gradient_accumulation_fusion=gradient_accumulation_fusion,
        ).cuda()
        if gradient_accumulation_fusion:
            with torch.no_grad():
                linear.weight.main_grad = torch.randn_like(linear.weight)
        output, _ = linear(input_tensor)
        self.assertEqual(
            output.shape,
            (self.BATCH_SIZE, hidden_size, feature_size),
        )
        loss = torch.mul(output, loss_weight).sum()
        loss.backward()

        with torch.no_grad():
            dldy = loss_weight.clone()
            x = input_tensor.clone()
            a = linear.master_weight.cuda().clone()
        dldx = torch.matmul(dldy, a)
        self.assertEqual(input_tensor.grad, dldx)
        # TODO(mkozuki): Cover the other cases.
        if tensor_model_parallel_world_size == 1 and not gradient_accumulation_fusion:
            # Sum the batched weight gradient over the batch dimension.
            dlda = torch.matmul(torch.transpose(dldy, 1, 2), x).sum(dim=0)
            curr_dlda = torch.split(dlda, feature_size_coeff, dim=0)[
                parallel_state.get_tensor_model_parallel_rank()
            ]
            self.assertEqual(linear.weight.grad, curr_dlda)


class NcclTensorParallelLayerTest(TensorParallelLayerTestBase, NcclDistributedTestBase):
    """Run the tensor-parallel layer tests on the NCCL distributed test base."""


class UccTensorParallelLayerTest(TensorParallelLayerTestBase, UccDistributedTestBase):
    """Run the tensor-parallel layer tests on the UCC distributed test base."""


if __name__ == "__main__":
    # Delegate to PyTorch's internal test runner.
    common_utils.run_tests()