# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

from mpu import layers
from commons import set_random_seed
from commons import print_separator
from commons import initialize_distributed
import mpu
from torch.nn.parameter import Parameter
import torch.nn.init as init
import torch
11
12
13
14
15
import random
import sys
sys.path.append("../..")


16
def test_parallel_embedding(tensor_model_parallel_size):
    """Compare tensor-parallel embeddings against a plain ``torch.nn.Embedding``.

    A reference ``torch.nn.Embedding``, a ``layers.ParallelEmbedding`` and a
    ``layers.VocabParallelEmbedding`` are each constructed right after the
    same ``set_random_seed(seed)`` call. Each is fed the same token ids and
    reduced to a scalar loss with the same random weighting. The test asserts
    that:

    * both parallel losses equal the reference loss, and
    * each rank's weight gradient equals its slice of the reference gradient:
      split along dim 1 (hidden) for ``ParallelEmbedding`` and along dim 0
      (vocabulary) for ``VocabParallelEmbedding``.

    Args:
        tensor_model_parallel_size: requested tensor model parallel size. It
            is passed to ``mpu.initialize_model_parallel`` and then replaced
            by ``mpu.get_tensor_model_parallel_world_size()``.
    """
    if torch.distributed.get_rank() == 0:
        print('> testing parallel embedding with model parallel size {} ...'.
              format(tensor_model_parallel_size))

    mpu.initialize_model_parallel(tensor_model_parallel_size)
    tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size()

    batch_size = 17
    seq_length = 23
    vocab_size = 48
    hidden_size = 16
    seed = 1236

    # Inputs use their own seed; every embedding below is created right
    # after reseeding with `seed`, so their constructions start from the
    # same RNG state.
    set_random_seed(123)
    input_data = torch.LongTensor(
        size=(batch_size, seq_length)).random_(0, vocab_size).cuda()
    loss_weight = torch.randn([batch_size, seq_length, hidden_size]).cuda()

    # Reference, non-parallel embedding.
    set_random_seed(seed)
    embedding_original = torch.nn.Embedding(vocab_size, hidden_size).cuda()
    output = embedding_original(input_data)
    loss_original = torch.mul(output, loss_weight).sum()
    loss_original.backward()

    # Embedding partitioned by ParallelEmbedding.
    set_random_seed(seed)
    embedding_parallel = layers.ParallelEmbedding(
        vocab_size, hidden_size, init_method=init.normal_).cuda()
    output = embedding_parallel(input_data)
    loss_parallel = torch.mul(output, loss_weight).sum()
    loss_parallel.backward()

    # Embedding partitioned by VocabParallelEmbedding.
    set_random_seed(seed)
    embedding_vocab_parallel = layers.VocabParallelEmbedding(
        vocab_size, hidden_size, init_method=init.normal_).cuda()
    output = embedding_vocab_parallel(input_data)
    loss_vocab_parallel = torch.mul(output, loss_weight).sum()
    loss_vocab_parallel.backward()

    rank = mpu.get_tensor_model_parallel_rank()

    torch.distributed.barrier()
    error = loss_parallel.sub(loss_original).abs()
    print('   error in loss (parallel) on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 1.0e-12, 'error: {}'.format(error)

    torch.distributed.barrier()
    error = loss_vocab_parallel.sub(loss_original).abs()
    print('   error in loss (vocab parallel) on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 1.0e-12, 'error: {}'.format(error)

    # ParallelEmbedding gradient: this rank's slice of the hidden dimension.
    weight_grad_orig = torch.split(embedding_original.weight.grad,
                                   hidden_size // tensor_model_parallel_size,
                                   1)[rank]
    error = embedding_parallel.weight.grad.sub(weight_grad_orig).abs().max()
    print('   error in grad (parallel) on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 1.0e-12, 'error: {}'.format(error)

    # VocabParallelEmbedding gradient: this rank's slice of the vocabulary.
    weight_grad_orig = torch.split(embedding_original.weight.grad,
                                   vocab_size // tensor_model_parallel_size,
                                   0)[rank]
    error = embedding_vocab_parallel.weight.grad.sub(
        weight_grad_orig).abs().max()
    print('   error in grad (vocab parallel) on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 1.0e-12, 'error: {}'.format(error)

    # Reset groups
    mpu.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print('>> passed the test :-)')


94
def test_initialize_affine_weight(tensor_model_parallel_size):
    """Check ``layers._initialize_affine_weight`` against a sliced master weight.

    For both partitioning modes, the partition produced on this rank is
    compared with the matching slice of a full ``(output_size, input_size)``
    weight. That full weight is initialized with ``torch.nn.init.normal_``
    from the same seed:

    * column parallel: each rank owns ``output_size_coeff`` rows (dim 0);
    * row parallel: each rank owns ``input_size_coeff`` columns (dim 1).

    Args:
        tensor_model_parallel_size: requested tensor model parallel size. It
            is passed to ``mpu.initialize_model_parallel`` and then replaced
            by ``mpu.get_tensor_model_parallel_world_size()``.
    """
    mpu.initialize_model_parallel(tensor_model_parallel_size)
    if torch.distributed.get_rank() == 0:
        print('> testing initialize_affine_weight with model parallel '
              'size: {}'.format(tensor_model_parallel_size))
    tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size()

    seed = 12345
    input_size_coeff = 13
    input_size = input_size_coeff * tensor_model_parallel_size
    output_size_coeff = 17
    output_size = output_size_coeff * tensor_model_parallel_size

    # ---------------
    # Column parallel
    # ---------------
    weight = torch.empty(output_size_coeff, input_size)
    set_random_seed(seed)
    layers._initialize_affine_weight(weight, output_size, input_size,
                                     output_size_coeff, 0,
                                     torch.nn.init.normal_)
    # Target: initialize the full weight from the same seed, take our rows.
    set_random_seed(seed)
    master_weight = torch.empty(output_size, input_size)
    torch.nn.init.normal_(master_weight)
    rank = mpu.get_tensor_model_parallel_rank()
    my_weight = torch.split(master_weight, output_size_coeff,
                            dim=0)[rank].contiguous().clone()

    # Compare.
    error = weight.sub(my_weight).abs().max()
    torch.distributed.barrier()
    print('   column parallel max error (should be zero) on global rank '
          '{}: {}'.format(torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    # ------------
    # Row parallel
    # ------------
    # Same module alias as the column-parallel case above.
    weight = torch.empty(output_size, input_size_coeff)
    set_random_seed(seed)
    layers._initialize_affine_weight(weight, output_size, input_size,
                                     input_size_coeff, 1,
                                     torch.nn.init.normal_)
    # Target: initialize the full weight from the same seed, take our columns.
    set_random_seed(seed)
    master_weight = torch.empty(output_size, input_size)
    torch.nn.init.normal_(master_weight)
    my_weight = torch.split(master_weight, input_size_coeff,
                            dim=1)[rank].contiguous().clone()

    # Compare.
    error = weight.sub(my_weight).abs().max()
    torch.distributed.barrier()
    print('   row parallel max error (should be zero) on global rank '
          '{}: {}'.format(torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    # Reset groups
    mpu.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(' >> passed the test :-)')


class IdentityLayer2D(torch.nn.Module):
    """Module whose forward pass returns its own learnable ``(m, n)`` weight.

    It is used as a differentiable input source. After ``backward()``, the
    gradient of the loss with respect to the returned tensor is stored in
    ``self.weight.grad``.
    """

    def __init__(self, m, n):
        super().__init__()
        # torch.empty allocates uninitialized storage, like the legacy
        # torch.Tensor(m, n) constructor; xavier_normal_ fills it in place.
        self.weight = Parameter(torch.empty(m, n))
        torch.nn.init.xavier_normal_(self.weight)

    def forward(self):
        """Return the weight parameter itself (not a copy)."""
        return self.weight


173
def test_column_parallel_linear(tensor_model_parallel_size):
    """Check ``mpu.ColumnParallelLinear`` gradients against dense formulas.

    With ``A = linear_layer.master_weight`` (kept via
    ``keep_master_weight_for_test=True``), ``X`` the input and ``dLdY`` the
    loss weighting, the expected gradients are computed as:

    * ``dLdA = dLdY^T @ X``; this rank's rows (dim 0, ``output_size_coeff``
      per rank) are compared with ``linear_layer.weight.grad``;
    * ``dLdb = column sums of dLdY``; this rank's slice is compared with
      ``linear_layer.bias.grad``;
    * ``dLdX = dLdY @ A``; the full tensor is compared with the input's
      gradient.

    Args:
        tensor_model_parallel_size: requested tensor model parallel size. It
            is passed to ``mpu.initialize_model_parallel`` and then replaced
            by ``mpu.get_tensor_model_parallel_world_size()``.
    """
    mpu.initialize_model_parallel(tensor_model_parallel_size)
    if torch.distributed.get_rank() == 0:
        print('> testing ColumnParallelLinear with model parallel '
              'size: {}'.format(tensor_model_parallel_size))
    tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size()

    seed = 12345
    set_random_seed(seed)
    input_size_coeff = 13
    input_size = input_size_coeff * tensor_model_parallel_size
    output_size_coeff = 17
    output_size = output_size_coeff * tensor_model_parallel_size
    batch_size = 7

    # Network
    identity_layer = IdentityLayer2D(batch_size, input_size).cuda()
    linear_layer = mpu.ColumnParallelLinear(
        input_size, output_size, keep_master_weight_for_test=True).cuda()
    loss_weight = torch.randn([batch_size, output_size]).cuda()
    # Forward
    input_ = identity_layer()
    output = linear_layer(input_)
    loss = torch.mul(output, loss_weight).sum()
    # Backward
    loss.backward()

    # Values.
    dLdY = loss_weight
    X = identity_layer.weight
    A = linear_layer.master_weight.cuda()
    dLdA = torch.matmul(dLdY.t(), X)
    # Multiplying by a row of ones sums dLdY over the batch dimension.
    dLdb = torch.matmul(torch.ones(batch_size, 1).cuda().t(), dLdY).view(-1)
    dLdX = torch.matmul(dLdY, A)

    rank = mpu.get_tensor_model_parallel_rank()
    my_dLdA = torch.split(dLdA, output_size_coeff,
                          dim=0)[rank].contiguous().clone()
    error = my_dLdA.sub(linear_layer.weight.grad).abs().max()
    torch.distributed.barrier()
    print('   error in dLdA on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    my_dLdb = torch.split(dLdb, output_size_coeff,
                          dim=0)[rank].contiguous().clone()
    error = my_dLdb.sub(linear_layer.bias.grad).abs().max()
    torch.distributed.barrier()
    print('   error in dLdb on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    error = dLdX.sub(identity_layer.weight.grad).abs().max()
    torch.distributed.barrier()
    print('   error in dLdX on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    # Reset groups
    mpu.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(' >> passed the test :-)')


240
def test_row_parallel_linear(tensor_model_parallel_size):
    """Check ``mpu.RowParallelLinear`` gradients against dense formulas.

    With ``A = linear_layer.master_weight`` (kept via
    ``keep_master_weight_for_test=True``), ``X`` the input and ``dLdY`` the
    loss weighting, the expected gradients are computed as:

    * ``dLdA = dLdY^T @ X``; this rank's columns (dim 1, ``input_size_coeff``
      per rank) are compared with ``linear_layer.weight.grad``;
    * ``dLdb = column sums of dLdY``; the full vector is compared with
      ``linear_layer.bias.grad``;
    * ``dLdX = dLdY @ A``; the full tensor is compared with the input's
      gradient.

    Args:
        tensor_model_parallel_size: requested tensor model parallel size. It
            is passed to ``mpu.initialize_model_parallel`` and then replaced
            by ``mpu.get_tensor_model_parallel_world_size()``.
    """
    mpu.initialize_model_parallel(tensor_model_parallel_size)
    if torch.distributed.get_rank() == 0:
        print('> testing RowParallelLinear with model parallel '
              'size: {}'.format(tensor_model_parallel_size))
    tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size()

    seed = 12345
    set_random_seed(seed)
    input_size_coeff = 13
    input_size = input_size_coeff * tensor_model_parallel_size
    output_size_coeff = 17
    output_size = output_size_coeff * tensor_model_parallel_size
    batch_size = 7

    # Network
    identity_layer = IdentityLayer2D(batch_size, input_size).cuda()
    linear_layer = mpu.RowParallelLinear(
        input_size, output_size, keep_master_weight_for_test=True).cuda()
    loss_weight = torch.randn([batch_size, output_size]).cuda()
    # Forward
    input_ = identity_layer()
    output = linear_layer(input_)
    loss = torch.mul(output, loss_weight).sum()
    # Backward
    loss.backward()

    # Values.
    dLdY = loss_weight
    X = identity_layer.weight
    A = linear_layer.master_weight.cuda()
    dLdA = torch.matmul(dLdY.t(), X)
    # Multiplying by a row of ones sums dLdY over the batch dimension.
    dLdb = torch.matmul(torch.ones(batch_size, 1).cuda().t(), dLdY).view(-1)
    dLdX = torch.matmul(dLdY, A)

    rank = mpu.get_tensor_model_parallel_rank()
    my_dLdA = torch.split(dLdA, input_size_coeff,
                          dim=1)[rank].contiguous().clone()
    error = my_dLdA.sub(linear_layer.weight.grad).abs().max()
    torch.distributed.barrier()
    print('   error in dLdA on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    error = dLdb.sub(linear_layer.bias.grad).abs().max()
    torch.distributed.barrier()
    print('   error in dLdb on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    error = dLdX.sub(identity_layer.weight.grad).abs().max()
    torch.distributed.barrier()
    print('   error in dLdX on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    # Reset groups
    mpu.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(' >> passed the test :-)')


class IdentityLayer3D(torch.nn.Module):
    """Module whose forward pass returns its own learnable ``(m, n, k)`` weight.

    It is used as a differentiable 3-D input source. After ``backward()``,
    the gradient of the loss with respect to the returned tensor is stored in
    ``self.weight.grad``.
    """

    def __init__(self, m, n, k):
        super().__init__()
        # torch.empty allocates uninitialized storage, like the legacy
        # torch.Tensor(m, n, k) constructor; xavier_normal_ fills it in place.
        self.weight = Parameter(torch.empty(m, n, k))
        torch.nn.init.xavier_normal_(self.weight)

    def forward(self):
        """Return the weight parameter itself (not a copy)."""
        return self.weight


315
def parallel_self_attention(tensor_model_parallel_size, num_att_heads_per_partition,
                            hidden_size_per_att_head, dropout_prob, batch_size,
                            sequence_length):
    """Run one forward/backward pass through ``mpu.BertParallelSelfAttention``.

    The function initializes tensor model parallelism with
    ``tensor_model_parallel_size`` and builds the attention layer on a seeded
    random input. It backpropagates a randomly weighted sum of the layer's
    output, then destroys the model parallel groups so the caller can run
    again with a different size.

    The number of attention heads is
    ``num_att_heads_per_partition * torch.distributed.get_world_size()``.
    This uses the global world size, not the tensor model parallel size, so
    ``hidden_size`` is the same for every ``tensor_model_parallel_size`` and
    results from different sizes can be compared.

    Returns:
        Tuple ``(rank, hidden_size, tensor_model_parallel_size, loss,
        attention_layer, identity_layer)``. ``rank`` is the tensor model
        parallel rank, and ``tensor_model_parallel_size`` is the size that
        ``mpu`` reported after initialization.
    """
    mpu.initialize_model_parallel(tensor_model_parallel_size)
    tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size()

    seed = 12345
    set_random_seed(seed)

    num_att_heads = num_att_heads_per_partition * \
        torch.distributed.get_world_size()
    hidden_size = hidden_size_per_att_head * num_att_heads

    # Network
    identity_layer = IdentityLayer3D(batch_size, sequence_length,
                                     hidden_size).cuda()
    attention_layer = mpu.BertParallelSelfAttention(hidden_size, num_att_heads,
                                                    dropout_prob).cuda()
    loss_weight = torch.randn([batch_size, sequence_length, hidden_size]).cuda()
    attention_mask = torch.randn([batch_size, 1, 1, sequence_length]).cuda()
    # Forward
    input_ = identity_layer()
    output = attention_layer(input_, attention_mask)
    loss = torch.mul(output, loss_weight).sum()
    # Backward
    loss.backward()

    rank = mpu.get_tensor_model_parallel_rank()
    mpu.destroy_model_parallel()
    return rank, hidden_size, tensor_model_parallel_size, loss, \
        attention_layer, identity_layer


348
def test_parallel_self_attention(tensor_model_parallel_size):
    """Compare tensor-parallel self-attention against the size-1 baseline.

    ``parallel_self_attention`` runs twice: once with tensor model parallel
    size 1 and once with ``tensor_model_parallel_size``. The test then checks
    that the two runs agree on:

    * the hidden size,
    * the loss,
    * this rank's part of the fused ``query_key_value`` weight gradient, and
    * the gradient with respect to the input.

    Args:
        tensor_model_parallel_size: tensor model parallel size for the second
            run; it is replaced by the size reported by that run.
    """
    if torch.distributed.get_rank() == 0:
        print('> testing ParallelSelfAttention with model parallel '
              'size: {}'.format(tensor_model_parallel_size))

    num_att_heads_per_partition = 3
    hidden_size_per_att_head = 7
    dropout_prob = 0.0  # has to be zero
    batch_size = 5
    sequence_length = 13

    # Baseline run without tensor parallelism; its rank and size are unused.
    _, hidden_size_1, _, loss_1, attention_layer_1, identity_layer_1 = \
        parallel_self_attention(
            1, num_att_heads_per_partition,
            hidden_size_per_att_head, dropout_prob, batch_size, sequence_length)

    rank, hidden_size, tensor_model_parallel_size, loss, \
        attention_layer, identity_layer = parallel_self_attention(
            tensor_model_parallel_size, num_att_heads_per_partition,
            hidden_size_per_att_head, dropout_prob, batch_size, sequence_length)
    assert hidden_size_1 == hidden_size

    error = loss_1.sub(loss).abs().max()
    torch.distributed.barrier()
    print('   loss error on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 5.0e-6

    # Split the baseline's fused QKV weight gradient into chunks of
    # hidden_size // tensor_model_parallel_size rows, then take every
    # tensor_model_parallel_size-th chunk starting at `rank`. Presumably one
    # chunk each from the Q, K and V blocks; confirm against
    # BertParallelSelfAttention's weight layout.
    my_lin_grad_list = torch.split(
        attention_layer_1.query_key_value.weight.grad,
        hidden_size // tensor_model_parallel_size, 0)[rank::tensor_model_parallel_size]
    my_lin_grad = torch.cat(my_lin_grad_list, dim=0)
    error = my_lin_grad.sub(
        attention_layer.query_key_value.weight.grad).abs().max()
    torch.distributed.barrier()
    print('   weight gradient error on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 5.0e-6

    error = identity_layer_1.weight.grad.sub(
        identity_layer.weight.grad).abs().max()
    torch.distributed.barrier()
    print('   input gradient error on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 5.0e-6

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(' >> passed the test :-)')


def parallel_transformer(tensor_model_parallel_size, num_att_heads_per_partition,
                         hidden_size_per_att_head, batch_size, sequence_length):
    """Run one forward/backward pass through ``mpu.BertParallelTransformerLayer``.

    The function initializes tensor model parallelism with
    ``tensor_model_parallel_size`` and builds a transformer layer. The layer
    has both dropout arguments set to 0.0, uses ``torch.nn.functional.relu``
    and an intermediate size of ``4 * hidden_size``. The function runs the
    layer on a seeded random input, backpropagates a randomly weighted sum of
    its output, and then destroys the model parallel groups.

    The head count uses ``torch.distributed.get_world_size()`` rather than
    the tensor model parallel size, so ``hidden_size`` does not depend on
    ``tensor_model_parallel_size``.

    Returns:
        Tuple ``(rank, hidden_size, tensor_model_parallel_size, loss,
        transformer_layer, identity_layer)``. ``rank`` is the tensor model
        parallel rank, and ``tensor_model_parallel_size`` is the size that
        ``mpu`` reported after initialization.
    """
    mpu.initialize_model_parallel(tensor_model_parallel_size)
    tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size()

    seed = 12345
    set_random_seed(seed)

    num_att_heads = num_att_heads_per_partition * \
        torch.distributed.get_world_size()
    hidden_size = hidden_size_per_att_head * num_att_heads
    intermediate_size = 4 * hidden_size

    # Network
    identity_layer = IdentityLayer3D(batch_size, sequence_length,
                                     hidden_size).cuda()
    transformer_layer = mpu.BertParallelTransformerLayer(
        hidden_size, intermediate_size, num_att_heads, 0.0, 0.0,
        torch.nn.functional.relu, 1.0e-5).cuda()

    loss_weight = torch.randn([batch_size, sequence_length, hidden_size]).cuda()
    attention_mask = torch.randn([batch_size, 1, 1, sequence_length]).cuda()
    # Forward
    input_ = identity_layer()
    output = transformer_layer(input_, attention_mask)
    loss = torch.mul(output, loss_weight).sum()
    # Backward
    loss.backward()

    rank = mpu.get_tensor_model_parallel_rank()
    mpu.destroy_model_parallel()
    return rank, hidden_size, tensor_model_parallel_size, loss, \
        transformer_layer, identity_layer


436
def test_parallel_transformer_layer(tensor_model_parallel_size):
    """Compare a tensor-parallel transformer layer against the size-1 baseline.

    ``parallel_transformer`` runs twice: once with tensor model parallel
    size 1 and once with ``tensor_model_parallel_size``. The test checks that
    the two runs agree on the hidden size, the loss, and the gradient with
    respect to the input.

    Args:
        tensor_model_parallel_size: tensor model parallel size for the second
            run; it is replaced by the size reported by that run.
    """
    if torch.distributed.get_rank() == 0:
        print('> testing ParallelTransformerLayer with model parallel '
              'size: {}'.format(tensor_model_parallel_size))

    num_att_heads_per_partition = 3
    hidden_size_per_att_head = 7
    batch_size = 5
    sequence_length = 13

    # Baseline run without tensor parallelism; its rank and size are unused.
    _, hidden_size_1, _, loss_1, _, identity_layer_1 = parallel_transformer(
        1, num_att_heads_per_partition,
        hidden_size_per_att_head, batch_size, sequence_length)

    _, hidden_size, _, loss, _, identity_layer = parallel_transformer(
        tensor_model_parallel_size, num_att_heads_per_partition,
        hidden_size_per_att_head, batch_size, sequence_length)
    # Same sanity check as test_parallel_self_attention: both runs must
    # build layers of the same width for the comparisons below to be valid.
    assert hidden_size_1 == hidden_size

    error = loss_1.sub(loss).abs().max()
    torch.distributed.barrier()
    print('   loss error on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 5.0e-5, 'error: {}'.format(error)

    error = identity_layer_1.weight.grad.sub(
        identity_layer.weight.grad).abs().max()
    torch.distributed.barrier()
    print('   input gradient error on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 5.0e-5, 'error: {}'.format(error)

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(' >> passed the test :-)')


if __name__ == '__main__':

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    initialize_distributed()
    world_size = torch.distributed.get_world_size()

    # (separator label, test function) pairs, run in this order. Each test
    # runs once per tensor model parallel size 1, 2, 4, ... up to and
    # including world_size.
    tests = [
        ('test initialize affine weight', test_initialize_affine_weight),
        ('test parallel embedding', test_parallel_embedding),
        ('test column-parallel linear', test_column_parallel_linear),
        ('test row-parallel linear', test_row_parallel_linear),
        ('test parallel self-attention', test_parallel_self_attention),
        ('test parallel transformer', test_parallel_transformer_layer),
    ]
    for separator, test_fn in tests:
        print_separator(separator)
        tensor_model_parallel_size = 1
        while tensor_model_parallel_size <= world_size:
            test_fn(tensor_model_parallel_size)
            tensor_model_parallel_size *= 2