test_layers.py 65.3 KB
Newer Older
1
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
4
5
#
# See LICENSE for license information.
"""Test TE Paddle Layer-level APIs"""

6
import os
7
from utils import assert_allclose, is_fused_attention_supported
8
9

import paddle
Tim Moon's avatar
Tim Moon committed
10
import pytest
11

Tim Moon's avatar
Tim Moon committed
12
from transformer_engine.common.recipe import DelayedScaling
13
import transformer_engine.paddle as te
14
from transformer_engine.paddle.fp8 import is_fp8_available, fp8_autocast
Tim Moon's avatar
Tim Moon committed
15

16
17
is_fp8_supported, reason = is_fp8_available()
LINEAR_CASES = [(16, 16, 32), (32, 32, 64)]
18
NORM_CASES = [(16, 32), (256, 1024)]
19
20


Tian Zheng's avatar
Tian Zheng committed
21
22
23
24
25
26
27
@pytest.fixture(autouse=True)
def setup():
    """Seed Paddle's random number generator with a fixed value before each test"""
    fixed_seed = 10
    paddle.seed(fixed_seed)
    # Nothing to tear down; control returns to the test right after seeding.
    yield


28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('use_fp8', [True, False])
def test_checkpoint(use_fp8):
    """Test checkpoint save / load

    The state dict of one ``te.Linear`` is written to disk and loaded into a second,
    separately constructed ``te.Linear``. Both layers must then produce matching outputs
    for the same input, with FP8 execution switched on or off through ``use_fp8``.
    """
    bs = 16
    in_features = 16
    out_features = 32
    file_name = "model.pdparams"
    input_tensor = paddle.uniform(shape=(bs, in_features), dtype='float32')
    model = te.Linear(in_features, out_features)
    model_loaded = te.Linear(in_features, out_features)
    # Populate amax_history
    with fp8_autocast(enabled=False, calibrating=True):
        _ = model(input_tensor)
    try:
        # Save model
        paddle.save(model.state_dict(), file_name)
        # Get ref output
        with fp8_autocast(enabled=use_fp8):
            out_ref = model(input_tensor)
        # Load model
        model_loaded.set_state_dict(paddle.load(file_name))
    finally:
        # Always delete the checkpoint, even when saving, the reference forward pass or
        # loading raises, so a failing run does not leave the file in the working directory.
        if os.path.exists(file_name):
            os.remove(file_name)
    # Get actual output
    with fp8_autocast(enabled=use_fp8):
        out = model_loaded(input_tensor)

    assert_allclose(out, out_ref)
56
57
58
59
60
61
62
63
64
65
66
67
68
69


def calc_output_and_grad(layer, x, dy):
    """
    Run ``layer`` forward on a tensor built from ``x`` and backpropagate ``dy``.

    Returns a pair ``(output, input_grad)``; ``input_grad`` is None when
    ``x.stop_gradient`` is set.
    """
    layer_input = paddle.to_tensor(x)
    # The new tensor must carry the same gradient setting as the caller's tensor.
    layer_input.stop_gradient = x.stop_gradient
    output = layer(layer_input)
    output.backward(dy)

    if layer_input.stop_gradient:
        return output, None
    return output, layer_input.grad


70
71
def calc_output_and_grad_ln_out(layer, x, dy, return_ln_out=False):
    """
    Calculate forward and backward pass for layernorm

    ``layer`` is called on a tensor built from ``x`` and ``dy`` is backpropagated through
    its main output. When ``return_ln_out`` is True the layer is expected to return a pair
    ``(output, ln_out)``; otherwise it returns the output alone.

    Returns a triple ``(output, ln_out, input_grad)``. ``ln_out`` is None unless
    ``return_ln_out`` is True, and ``input_grad`` is None when ``x.stop_gradient`` is set.
    """
    # NOTE: this is a module-level function, so it must not be wrapped in @staticmethod;
    # staticmethod objects are not directly callable before Python 3.10.
    inp = paddle.to_tensor(x)
    inp.stop_gradient = x.stop_gradient
    outputs = layer(inp)
    ln_out = None
    if return_ln_out:
        y, ln_out = outputs
    else:
        y = outputs
    y.backward(dy)

    input_grad = None if inp.stop_gradient else inp.grad
    return y, ln_out, input_grad
86
87


88
89
90
91
class TestLinear:
    """
    Test suite for the Linear layer
    """

    @staticmethod
    @pytest.mark.skipif(paddle.device.cuda.get_device_capability() < (8, 0),
                        reason="BF16 Linear requires Ampere+ GPU")
    @pytest.mark.parametrize('bs,in_features,out_features', LINEAR_CASES)
    @pytest.mark.parametrize('has_bias,no_dbias', [[True, False], [True, True], [False, False]])
    @pytest.mark.parametrize('no_dgrad', [True, False])
    @pytest.mark.parametrize('no_wgrad', [True, False])
    @pytest.mark.parametrize('activation_dtype', ['bfloat16', 'float32'])
    def test_linear_bf16(bs, in_features, out_features, has_bias, no_dbias, no_dgrad, no_wgrad,
                         activation_dtype):
        """
        Linear without FP8: a layer built without a backend argument must match the
        'paddle' backend layer in output and in every requested gradient
        """
        tols = {'rtol': 5e-2, 'atol': 5e-2}

        x = paddle.uniform(shape=(bs, in_features), dtype=activation_dtype)
        x.stop_gradient = no_dgrad
        dy = paddle.uniform(shape=(bs, out_features), dtype=activation_dtype)

        paddle.set_default_dtype(activation_dtype)
        bias_attr = None if has_bias else False
        test_layer = te.Linear(in_features, out_features, bias_attr=bias_attr)
        ref_layer = te.Linear(in_features, out_features, bias_attr=bias_attr, backend='paddle')

        # The reference layer receives the transposed weight of the layer under test.
        ref_layer.weight.copy_(test_layer.weight.T, True)
        if has_bias:
            ref_layer.bias.copy_(test_layer.bias, True)

        for layer in (test_layer, ref_layer):
            layer.weight.stop_gradient = no_wgrad
            if has_bias:
                layer.bias.stop_gradient = no_dbias

        out_ref, dx_ref = calc_output_and_grad(ref_layer, x, dy)
        out, dx = calc_output_and_grad(test_layer, x, dy)

        assert_allclose(out, out_ref, **tols)
        if not no_dgrad:
            assert_allclose(dx, dx_ref, **tols)
        if not no_wgrad:
            assert_allclose(test_layer.weight.grad, ref_layer.weight.grad.T, **tols)
        if has_bias and not no_dbias:
            assert_allclose(test_layer.bias.grad, ref_layer.bias.grad, **tols)

    @staticmethod
    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize('bs,in_features,out_features', LINEAR_CASES)
    @pytest.mark.parametrize('has_bias,no_dbias', [[True, False], [True, True], [False, False]])
    @pytest.mark.parametrize('no_dgrad', [True, False])
    @pytest.mark.parametrize('no_wgrad', [True, False])
    @pytest.mark.parametrize('fp8_wgrad', [True, False])
    @pytest.mark.parametrize('do_calibration', [True, False])
    @pytest.mark.parametrize('activation_dtype', ['bfloat16', 'float32'])
    def test_linear_fp8(bs, in_features, out_features, has_bias, no_dbias, no_dgrad, no_wgrad,
                        fp8_wgrad, do_calibration, activation_dtype):
        """
        Linear under fp8_autocast (FP8 enabled, or calibration only when
        ``do_calibration`` is set), compared against the 'paddle' backend layer
        """
        tols = {'rtol': 0.1, 'atol': 0.5}

        x = paddle.uniform(shape=(bs, in_features), dtype=activation_dtype)
        x.stop_gradient = no_dgrad
        dy = paddle.uniform(shape=(bs, out_features), dtype=activation_dtype)

        fp8_recipe = DelayedScaling(override_linear_precision=(False, False, not fp8_wgrad))

        paddle.set_default_dtype(activation_dtype)
        shared_kwargs = {
            'in_features': in_features,
            'out_features': out_features,
            'bias_attr': None if has_bias else False,
        }
        test_layer = te.Linear(**shared_kwargs)
        ref_layer = te.Linear(backend='paddle', **shared_kwargs)

        ref_layer.weight.copy_(test_layer.weight.T, True)
        if has_bias:
            ref_layer.bias.copy_(test_layer.bias, True)

        for layer in (test_layer, ref_layer):
            layer.weight.stop_gradient = no_wgrad
            if has_bias:
                layer.bias.stop_gradient = no_dbias

        # FP8 execution and calibration are mutually exclusive in this test.
        with fp8_autocast(enabled=not do_calibration,
                          calibrating=do_calibration,
                          fp8_recipe=fp8_recipe):
            out_ref, dx_ref = calc_output_and_grad(ref_layer, x, dy)
            out, dx = calc_output_and_grad(test_layer, x, dy)

        assert_allclose(out, out_ref, **tols)
        if not no_dgrad:
            assert_allclose(dx, dx_ref, **tols)
        if not no_wgrad:
            assert_allclose(test_layer.weight.grad, ref_layer.weight.grad.T, **tols)
        if has_bias and not no_dbias:
            assert_allclose(test_layer.bias.grad, ref_layer.bias.grad, **tols)
        if do_calibration:
            # Calibration must have recorded at least one non-zero amax value.
            amax_history = test_layer.fp8_meta["scaling_fwd"].amax_history
            assert paddle.count_nonzero(amax_history).item() > 0

    @staticmethod
    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize('bs,in_features,out_features', LINEAR_CASES)
    @pytest.mark.parametrize('activation_dtype', ['bfloat16'])
    @pytest.mark.parametrize('num_microbatch', [8])
    def test_linear_fp8_microbatch(bs, in_features, out_features, activation_dtype, num_microbatch):
        """
        FP8 Linear over several microbatches: a layer called with ``is_first_microbatch``
        must agree with an identical layer called without it
        """
        tols = {'rtol': 0.1, 'atol': 0.1}

        fp8_recipe = DelayedScaling()

        paddle.set_default_dtype(activation_dtype)
        layer_cached = te.Linear(in_features=in_features, out_features=out_features)
        layer_normal = te.Linear(in_features=in_features, out_features=out_features)
        layer_cached.weight.copy_(layer_normal.weight, True)
        layer_cached.bias.copy_(layer_normal.bias, True)

        for step in range(num_microbatch):
            x = paddle.uniform(shape=(bs, in_features), dtype=activation_dtype)
            dy = paddle.uniform(shape=(bs, out_features), dtype=activation_dtype)
            is_first = step == 0

            with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
                out = layer_cached(x, is_first_microbatch=is_first)
                out.backward(dy)

            with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
                out_ref = layer_normal(x)
                out_ref.backward(dy)

            assert_allclose(out, out_ref, **tols)
            assert_allclose(layer_cached.weight.grad, layer_normal.weight.grad, **tols)

244
245

@pytest.mark.parametrize('bs,hidden_size', NORM_CASES)
@pytest.mark.parametrize('has_bias,no_dbias', [[True, False], [True, True], [False, False]])
@pytest.mark.parametrize('no_dgrad', [True, False])
@pytest.mark.parametrize('no_wgrad', [True, False])
@pytest.mark.parametrize('activation_dtype', ['bfloat16', 'float32'])
def test_layernorm_bf16(bs, hidden_size, has_bias, no_dbias, no_dgrad, no_wgrad, activation_dtype):
    """
    LayerNorm without FP8: a layer built without a backend argument must match the
    'paddle' backend layer in output and in every requested gradient
    """
    eps = 1e-3
    tols = {'rtol': 1e-2, 'atol': 1e-2}

    inp = paddle.uniform(shape=(bs, hidden_size), dtype=activation_dtype)
    inp.stop_gradient = no_dgrad
    dy = paddle.uniform(shape=(bs, hidden_size), dtype=activation_dtype)

    paddle.set_default_dtype(activation_dtype)
    bias_attr = None if has_bias else False
    test_layer = te.LayerNorm(hidden_size=hidden_size, eps=eps, bias_attr=bias_attr)
    ref_layer = te.LayerNorm(hidden_size=hidden_size,
                             eps=eps,
                             bias_attr=bias_attr,
                             backend='paddle')

    # Give the reference layer the same parameters as the layer under test.
    ref_layer.weight.copy_(test_layer.weight, True)
    if has_bias:
        ref_layer.bias.copy_(test_layer.bias, True)

    for layer in (test_layer, ref_layer):
        layer.weight.stop_gradient = no_wgrad
        if has_bias:
            layer.bias.stop_gradient = no_dbias

    out_ref, dx_ref = calc_output_and_grad(ref_layer, inp, dy)
    out, dx = calc_output_and_grad(test_layer, inp, dy)

    assert_allclose(out, out_ref, **tols)
    if not no_dgrad:
        assert_allclose(dx, dx_ref, **tols)
    if not no_wgrad:
        assert_allclose(test_layer.weight.grad, ref_layer.weight.grad, **tols)
    if has_bias and not no_dbias:
        assert_allclose(test_layer.bias.grad, ref_layer.bias.grad, **tols)
288
289


290
class TestLayerNormLinear:
    """
    Test suite for the LayerNormLinear layer
    """

    @staticmethod
    @pytest.mark.skipif(paddle.device.cuda.get_device_capability() < (8, 0),
                        reason="BF16 Linear requires Ampere+ GPU")
    @pytest.mark.parametrize('bs,in_features,out_features', LINEAR_CASES)
    @pytest.mark.parametrize('has_bias,no_dbias', [[True, False], [True, True], [False, False]])
    @pytest.mark.parametrize('no_dgrad', [True, False])
    @pytest.mark.parametrize('no_wgrad', [True, False])
    @pytest.mark.parametrize('return_ln_out', [True, False])
    @pytest.mark.parametrize('activation_dtype', ['bfloat16', 'float32'])
    def test_layernorm_linear_bf16(bs, in_features, out_features, has_bias, no_dbias, no_dgrad,
                                   no_wgrad, return_ln_out, activation_dtype):
        """
        LayerNormLinear without FP8: a layer built without a backend argument must match
        the 'paddle' backend layer in output, gradients and optional layernorm output
        """
        paddle.set_default_dtype(activation_dtype)
        tols = {'rtol': 5e-2, 'atol': 5e-2}

        x = paddle.uniform(shape=(bs, in_features), dtype=activation_dtype)
        x.stop_gradient = no_dgrad
        dy = paddle.uniform(shape=(bs, out_features), dtype=activation_dtype)
        eps = 1e-3

        shared_kwargs = {
            'in_features': in_features,
            'out_features': out_features,
            'eps': eps,
            'bias_attr': None if has_bias else False,
            'return_layernorm_output': return_ln_out,
        }
        test_layer = te.LayerNormLinear(**shared_kwargs)
        ref_layer = te.LayerNormLinear(backend='paddle', **shared_kwargs)

        # Copy parameters into the reference layer; its linear weight is stored transposed.
        ref_layer.ln_weight.copy_(test_layer.ln_weight, True)
        ref_layer.ln_bias.copy_(test_layer.ln_bias, True)
        ref_layer.weight.copy_(test_layer.weight.T, True)
        if has_bias:
            ref_layer.bias.copy_(test_layer.bias, True)

        for layer in (test_layer, ref_layer):
            layer.weight.stop_gradient = no_wgrad
            layer.ln_weight.stop_gradient = no_wgrad
            layer.ln_bias.stop_gradient = no_dbias
            if has_bias:
                layer.bias.stop_gradient = no_dbias

        out_ref, ln_out_ref, dx_ref = calc_output_and_grad_ln_out(ref_layer,
                                                                  x,
                                                                  dy,
                                                                  return_ln_out=return_ln_out)
        out, ln_out, dx = calc_output_and_grad_ln_out(test_layer,
                                                      x,
                                                      dy,
                                                      return_ln_out=return_ln_out)

        assert_allclose(out, out_ref, **tols)
        if not no_dgrad:
            assert_allclose(dx, dx_ref, **tols)
        if not no_wgrad:
            assert_allclose(test_layer.weight.grad, ref_layer.weight.grad.T, **tols)
            assert_allclose(test_layer.ln_weight.grad, ref_layer.ln_weight.grad, **tols)
        if not no_dbias:
            assert_allclose(test_layer.ln_bias.grad, ref_layer.ln_bias.grad, **tols)
            if has_bias:
                assert_allclose(test_layer.bias.grad, ref_layer.bias.grad, **tols)
        if return_ln_out:
            assert_allclose(ln_out, ln_out_ref, **tols)

    @staticmethod
    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize('bs,in_features,out_features', LINEAR_CASES)
    @pytest.mark.parametrize('has_bias,no_dbias', [[True, False], [True, True], [False, False]])
    @pytest.mark.parametrize('no_dgrad', [True, False])
    @pytest.mark.parametrize('no_wgrad', [True, False])
    @pytest.mark.parametrize('fp8_wgrad', [True, False])
    @pytest.mark.parametrize('do_calibration', [True, False])
    @pytest.mark.parametrize('return_ln_out', [True, False])
    @pytest.mark.parametrize('activation_dtype', ['bfloat16', 'float32'])
    def test_layernorm_linear_fp8(bs, in_features, out_features, has_bias, no_dbias, no_dgrad,
                                  no_wgrad, fp8_wgrad, do_calibration, return_ln_out,
                                  activation_dtype):
        """
        LayerNormLinear under fp8_autocast (FP8 enabled, or calibration only when
        ``do_calibration`` is set), compared against the 'paddle' backend layer
        """
        paddle.set_default_dtype(activation_dtype)
        tols = {'rtol': 0.1, 'atol': 0.75}

        x = paddle.uniform(shape=(bs, in_features), dtype=activation_dtype)
        x.stop_gradient = no_dgrad
        dy = paddle.uniform(shape=(bs, out_features), dtype=activation_dtype)
        eps = 1e-3

        fp8_recipe = DelayedScaling(override_linear_precision=(False, False, not fp8_wgrad))

        shared_kwargs = {
            'in_features': in_features,
            'out_features': out_features,
            'eps': eps,
            'bias_attr': None if has_bias else False,
            'return_layernorm_output': return_ln_out,
        }
        test_layer = te.LayerNormLinear(**shared_kwargs)
        ref_layer = te.LayerNormLinear(backend='paddle', **shared_kwargs)

        ref_layer.ln_weight.copy_(test_layer.ln_weight, True)
        ref_layer.ln_bias.copy_(test_layer.ln_bias, True)
        ref_layer.weight.copy_(test_layer.weight.T, True)
        if has_bias:
            ref_layer.bias.copy_(test_layer.bias, True)

        for layer in (test_layer, ref_layer):
            layer.weight.stop_gradient = no_wgrad
            layer.ln_weight.stop_gradient = no_wgrad
            layer.ln_bias.stop_gradient = no_dbias
            if has_bias:
                layer.bias.stop_gradient = no_dbias

        # FP8 execution and calibration are mutually exclusive in this test.
        with fp8_autocast(enabled=not do_calibration,
                          calibrating=do_calibration,
                          fp8_recipe=fp8_recipe):
            out_ref, ln_out_ref, dx_ref = calc_output_and_grad_ln_out(
                ref_layer, x, dy, return_ln_out=return_ln_out)
            out, ln_out, dx = calc_output_and_grad_ln_out(test_layer,
                                                          x,
                                                          dy,
                                                          return_ln_out=return_ln_out)

        assert_allclose(out, out_ref, **tols)
        if not no_dgrad:
            assert_allclose(dx, dx_ref, **tols)
        if not no_wgrad:
            assert_allclose(test_layer.weight.grad, ref_layer.weight.grad.T, **tols)
            assert_allclose(test_layer.ln_weight.grad, ref_layer.ln_weight.grad, **tols)
        if not no_dbias:
            assert_allclose(test_layer.ln_bias.grad, ref_layer.ln_bias.grad, **tols)
            if has_bias:
                assert_allclose(test_layer.bias.grad, ref_layer.bias.grad, **tols)
        if return_ln_out:
            assert_allclose(ln_out, ln_out_ref, **tols)
        if do_calibration:
            # Calibration must have recorded at least one non-zero amax value.
            amax_history = test_layer.fp8_meta["scaling_fwd"].amax_history
            assert paddle.count_nonzero(amax_history).item() > 0

    @staticmethod
    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize('bs,in_features,out_features', LINEAR_CASES)
    @pytest.mark.parametrize('activation_dtype', ['bfloat16'])
    @pytest.mark.parametrize('num_microbatch', [8])
    def test_layernorm_linear_fp8_microbatch(bs, in_features, out_features, activation_dtype,
                                             num_microbatch):
        """
        FP8 LayerNormLinear over several microbatches: a layer called with
        ``is_first_microbatch`` must agree with an identical layer called without it
        """
        paddle.set_default_dtype(activation_dtype)
        eps = 1e-3
        tols = {'rtol': 0.5, 'atol': 0.5}

        fp8_recipe = DelayedScaling()

        layer_cached = te.LayerNormLinear(in_features=in_features,
                                          out_features=out_features,
                                          eps=eps)
        layer_normal = te.LayerNormLinear(in_features=in_features,
                                          out_features=out_features,
                                          eps=eps)

        # Both layers start from identical parameters.
        layer_cached.ln_weight.copy_(layer_normal.ln_weight, True)
        layer_cached.ln_bias.copy_(layer_normal.ln_bias, True)
        layer_cached.weight.copy_(layer_normal.weight, True)
        layer_cached.bias.copy_(layer_normal.bias, True)

        for step in range(num_microbatch):
            x = paddle.uniform(shape=(bs, in_features), dtype=activation_dtype)
            dy = paddle.uniform(shape=(bs, out_features), dtype=activation_dtype)
            is_first = step == 0

            with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
                out = layer_cached(x, is_first_microbatch=is_first)
                out.backward(dy)

            with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
                out_ref = layer_normal(x)
                out_ref.backward(dy)

            assert_allclose(out, out_ref, **tols)
            assert_allclose(layer_cached.weight.grad, layer_normal.weight.grad, **tols)
            assert_allclose(layer_cached.ln_weight.grad, layer_normal.ln_weight.grad, **tols)

511
512

class TestLayerNormMLP:
513
    """
514
    Test LayerNormMLP Layer
515
516
    """

517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
    @staticmethod
    @pytest.mark.skipif(paddle.device.cuda.get_device_capability() < (8, 0),
                        reason="BF16 Linear requires Ampere+ GPU")
    @pytest.mark.parametrize('bs,hidden_size,ffn_hidden_size', LINEAR_CASES)
    @pytest.mark.parametrize('has_bias,no_dbias', [[True, False], [True, True], [False, False]])
    @pytest.mark.parametrize('no_dgrad', [True, False])
    @pytest.mark.parametrize('no_wgrad', [True, False])
    @pytest.mark.parametrize('return_ln_out', [True, False])
    @pytest.mark.parametrize('activation_dtype', ['bfloat16', 'float32'])
    def test_layernorm_mlp_bf16(bs, hidden_size, ffn_hidden_size, has_bias, no_dbias, no_dgrad,
                                no_wgrad, return_ln_out, activation_dtype):
        """
        LayerNormMLP without FP8: a layer built without a backend argument must match the
        'paddle' backend layer in output, gradients and optional layernorm output
        """
        paddle.set_default_dtype(activation_dtype)
        tols = {'rtol': 5e-2, 'atol': 5e-2}

        x = paddle.uniform(shape=(bs, hidden_size), dtype=activation_dtype)
        x.stop_gradient = no_dgrad
        dy = paddle.uniform(shape=(bs, hidden_size), dtype=activation_dtype)
        eps = 1e-3

        shared_kwargs = {
            'hidden_size': hidden_size,
            'ffn_hidden_size': ffn_hidden_size,
            'eps': eps,
            'bias_attr': None if has_bias else False,
            'return_layernorm_output': return_ln_out,
        }
        test_layer = te.LayerNormMLP(**shared_kwargs)
        ref_layer = te.LayerNormMLP(backend='paddle', **shared_kwargs)

        # Copy parameters into the reference layer; its FC weights are stored transposed.
        ref_layer.ln_weight.copy_(test_layer.ln_weight, True)
        ref_layer.ln_bias.copy_(test_layer.ln_bias, True)
        ref_layer.fc1_weight.copy_(test_layer.fc1_weight.T, True)
        ref_layer.fc2_weight.copy_(test_layer.fc2_weight.T, True)
        if has_bias:
            ref_layer.fc1_bias.copy_(test_layer.fc1_bias, True)
            ref_layer.fc2_bias.copy_(test_layer.fc2_bias, True)

        for layer in (test_layer, ref_layer):
            layer.fc1_weight.stop_gradient = no_wgrad
            layer.fc2_weight.stop_gradient = no_wgrad
            layer.ln_weight.stop_gradient = no_wgrad
            layer.ln_bias.stop_gradient = no_dbias
            if has_bias:
                layer.fc1_bias.stop_gradient = no_dbias
                layer.fc2_bias.stop_gradient = no_dbias

        out_ref, ln_out_ref, dx_ref = calc_output_and_grad_ln_out(ref_layer,
                                                                  x,
                                                                  dy,
                                                                  return_ln_out=return_ln_out)
        out, ln_out, dx = calc_output_and_grad_ln_out(test_layer,
                                                      x,
                                                      dy,
                                                      return_ln_out=return_ln_out)

        assert_allclose(out, out_ref, **tols)
        if not no_dgrad:
            assert_allclose(dx, dx_ref, **tols)
        if not no_wgrad:
            assert_allclose(test_layer.ln_weight.grad, ref_layer.ln_weight.grad, **tols)
            assert_allclose(test_layer.fc1_weight.grad, ref_layer.fc1_weight.grad.T, **tols)
            assert_allclose(test_layer.fc2_weight.grad, ref_layer.fc2_weight.grad.T, **tols)
        if not no_dbias:
            assert_allclose(test_layer.ln_bias.grad, ref_layer.ln_bias.grad, **tols)
            if has_bias:
                assert_allclose(test_layer.fc1_bias.grad, ref_layer.fc1_bias.grad, **tols)
                assert_allclose(test_layer.fc2_bias.grad, ref_layer.fc2_bias.grad, **tols)
        if return_ln_out:
            assert_allclose(ln_out, ln_out_ref, **tols)

    @staticmethod
    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize('bs,hidden_size,ffn_hidden_size', LINEAR_CASES)
    @pytest.mark.parametrize('has_bias,no_dbias', [[True, False], [True, True], [False, False]])
    @pytest.mark.parametrize('no_dgrad', [True, False])
    @pytest.mark.parametrize('no_wgrad', [True, False])
    @pytest.mark.parametrize('fp8_wgrad', [True, False])
    @pytest.mark.parametrize('do_calibration', [True, False])
    @pytest.mark.parametrize('return_ln_out', [True, False])
    @pytest.mark.parametrize('activation_dtype', ['bfloat16', 'float32'])
    def test_layernorm_mlp_fp8(bs, hidden_size, ffn_hidden_size, has_bias, no_dbias, no_dgrad,
                               no_wgrad, fp8_wgrad, do_calibration, return_ln_out,
                               activation_dtype):
        """
        LayerNormMLP under fp8_autocast (FP8 enabled, or calibration only when
        ``do_calibration`` is set), compared against the 'paddle' backend layer
        """
        paddle.set_default_dtype(activation_dtype)
        tols = {'rtol': 0.1, 'atol': 0.75}

        x = paddle.uniform(shape=(bs, hidden_size), dtype=activation_dtype)
        x.stop_gradient = no_dgrad
        dy = paddle.uniform(shape=(bs, hidden_size), dtype=activation_dtype)
        eps = 1e-3

        fp8_recipe = DelayedScaling(override_linear_precision=(False, False, not fp8_wgrad))

        shared_kwargs = {
            'hidden_size': hidden_size,
            'ffn_hidden_size': ffn_hidden_size,
            'eps': eps,
            'bias_attr': None if has_bias else False,
            'return_layernorm_output': return_ln_out,
        }
        test_layer = te.LayerNormMLP(**shared_kwargs)
        ref_layer = te.LayerNormMLP(backend='paddle', **shared_kwargs)

        # Copy parameters into the reference layer; its FC weights are stored transposed.
        ref_layer.ln_weight.copy_(test_layer.ln_weight, True)
        ref_layer.ln_bias.copy_(test_layer.ln_bias, True)
        ref_layer.fc1_weight.copy_(test_layer.fc1_weight.T, True)
        ref_layer.fc2_weight.copy_(test_layer.fc2_weight.T, True)
        if has_bias:
            ref_layer.fc1_bias.copy_(test_layer.fc1_bias, True)
            ref_layer.fc2_bias.copy_(test_layer.fc2_bias, True)

        for layer in (test_layer, ref_layer):
            layer.fc1_weight.stop_gradient = no_wgrad
            layer.fc2_weight.stop_gradient = no_wgrad
            layer.ln_weight.stop_gradient = no_wgrad
            layer.ln_bias.stop_gradient = no_dbias
            if has_bias:
                layer.fc1_bias.stop_gradient = no_dbias
                layer.fc2_bias.stop_gradient = no_dbias

        # FP8 execution and calibration are mutually exclusive in this test.
        with fp8_autocast(enabled=not do_calibration,
                          calibrating=do_calibration,
                          fp8_recipe=fp8_recipe):
            out_ref, ln_out_ref, dx_ref = calc_output_and_grad_ln_out(
                ref_layer, x, dy, return_ln_out=return_ln_out)
            out, ln_out, dx = calc_output_and_grad_ln_out(test_layer,
                                                          x,
                                                          dy,
                                                          return_ln_out=return_ln_out)

        assert_allclose(out, out_ref, **tols)
        if not no_dgrad:
            assert_allclose(dx, dx_ref, **tols)
        if not no_wgrad:
            assert_allclose(test_layer.ln_weight.grad, ref_layer.ln_weight.grad, **tols)
            assert_allclose(test_layer.fc1_weight.grad, ref_layer.fc1_weight.grad.T, **tols)
            assert_allclose(test_layer.fc2_weight.grad, ref_layer.fc2_weight.grad.T, **tols)
        if not no_dbias:
            assert_allclose(test_layer.ln_bias.grad, ref_layer.ln_bias.grad, **tols)
            if has_bias:
                assert_allclose(test_layer.fc1_bias.grad, ref_layer.fc1_bias.grad, **tols)
                assert_allclose(test_layer.fc2_bias.grad, ref_layer.fc2_bias.grad, **tols)
        if return_ln_out:
            assert_allclose(ln_out, ln_out_ref, **tols)

        if do_calibration:
            # Calibration must have recorded at least one non-zero amax value.
            amax_history = test_layer.fp8_meta["scaling_fwd"].amax_history
            assert paddle.count_nonzero(amax_history).item() > 0
Shijie's avatar
Shijie committed
714

715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
    @staticmethod
    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize('bs,hidden_size,ffn_hidden_size', LINEAR_CASES)
    @pytest.mark.parametrize('activation_dtype', ['bfloat16'])
    @pytest.mark.parametrize('num_microbatch', [8])
    def test_layernorm_mlp_fp8_microbatch(bs, hidden_size, ffn_hidden_size, activation_dtype,
                                          num_microbatch):
        """
        Test FP8 LayerNormMLP Layer with the `is_first_microbatch` argument

        Two identically parameterized `te.LayerNormMLP` layers are run under FP8
        for `num_microbatch` iterations. One layer is called with
        `is_first_microbatch` (True only on iteration 0), the other is called
        without it. Outputs and the ln/fc1/fc2 weight gradients must agree after
        every iteration. No gradients are cleared between iterations.
        """
        paddle.set_default_dtype(activation_dtype)
        tolerance = {'rtol': 1e-5, 'atol': 1e-5}
        layer_kwargs = {
            'hidden_size': hidden_size,
            'ffn_hidden_size': ffn_hidden_size,
            'eps': 1e-3,
        }

        fp8_recipe = DelayedScaling()

        # Construction order is kept (cached first, then normal) so the random
        # initialisation of each layer is unchanged.
        layer_cached = te.LayerNormMLP(**layer_kwargs)
        layer_normal = te.LayerNormMLP(**layer_kwargs)

        # Give both layers the exact same parameters
        synced_params = ('ln_weight', 'ln_bias', 'fc1_weight', 'fc2_weight', 'fc1_bias',
                         'fc2_bias')
        for param_name in synced_params:
            getattr(layer_normal, param_name).copy_(getattr(layer_cached, param_name), True)

        # Calibration to make sure weight scale is the same
        input_tensor = paddle.uniform(shape=(bs, hidden_size), dtype=activation_dtype)
        for layer in (layer_cached, layer_normal):
            with fp8_autocast(enabled=False, calibrating=True, fp8_recipe=fp8_recipe):
                _ = layer(input_tensor)

        checked_grads = ('ln_weight', 'fc1_weight', 'fc2_weight')
        for microbatch_idx in range(num_microbatch):
            input_tensor = paddle.uniform(shape=(bs, hidden_size), dtype=activation_dtype)
            grad_out = paddle.uniform(shape=(bs, hidden_size), dtype=activation_dtype)

            # Layer under test: told which call is the first microbatch
            first_microbatch = microbatch_idx == 0
            with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
                out = layer_cached(input_tensor, is_first_microbatch=first_microbatch)
                out.backward(grad_out)

            # Reference layer: plain FP8 call without the microbatch hint
            with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
                out_ref = layer_normal(input_tensor)
                out_ref.backward(grad_out)

            assert_allclose(out, out_ref, **tolerance)
            for param_name in checked_grads:
                assert_allclose(getattr(layer_cached, param_name).grad,
                                getattr(layer_normal, param_name).grad, **tolerance)

Shijie's avatar
Shijie committed
784
785
786

@pytest.mark.parametrize('bs', [1, 2, 8])
@pytest.mark.parametrize('hidden_size, num_heads', [[1024, 16], [768, 12]])
@pytest.mark.parametrize('q_seqlen, kv_seqlen', [[512, 512], [1024, 1024]])
@pytest.mark.parametrize('attn_type', ['self', 'cross'])
@pytest.mark.parametrize('mask_type', ['causal', 'padding'])
@pytest.mark.parametrize('math_dtype', ['bfloat16', 'float16'])
def test_dot_product_attention(bs, hidden_size, num_heads, q_seqlen, kv_seqlen, attn_type,
                               mask_type, math_dtype):
    """
    Test DotProductAttention Layer

    Runs the same q/k/v inputs, mask and output gradient through a
    `te.DotProductAttention` built with backend='transformer_engine' and one
    built with backend='paddle' (the reference), then compares the output and
    the q/k/v gradients. Before comparison, the reference output and gradients
    are zeroed past each sample's actual sequence length.

    The test is skipped when `is_fused_attention_supported` reports that the
    configuration is not supported.
    """
    paddle.set_default_dtype(math_dtype)
    rtol = 1e-4
    atol = 2e-2
    head_size = hidden_size // num_heads

    # Skip if cuDNN fused attention is not supported
    if not is_fused_attention_supported(
            num_heads=num_heads,
            num_gqa_groups=num_heads,
            q_seqlen=q_seqlen,
            kv_seqlen=kv_seqlen,
            head_size=head_size,
            dtype=math_dtype,
            dropout=0.0,
            qkv_layout="bshd_bshd_bshd",
            bias_type="no_bias",
            mask_type=mask_type,
    ):
        pytest.skip("cuDNN fused attention is not supported")

    # Inputs are laid out as (batch, seqlen, heads, head_size)
    attn_q_input = paddle.normal(mean=0.0, std=0.02,
                                 shape=(bs, q_seqlen, num_heads, head_size)).astype(math_dtype)
    attn_k_input = paddle.normal(mean=0.0, std=0.02,
                                 shape=(bs, kv_seqlen, num_heads, head_size)).astype(math_dtype)
    attn_v_input = paddle.normal(mean=0.0, std=0.02,
                                 shape=(bs, kv_seqlen, num_heads, head_size)).astype(math_dtype)

    # Random per-sample valid lengths; self attention reuses the query lengths
    # for key/value, cross attention draws independent key/value lengths.
    q_actual_seqlen = paddle.randint(low=20, high=q_seqlen, shape=(bs,), dtype='int32')
    kv_actual_seqlen = paddle.randint(low=20, high=kv_seqlen, shape=(bs,),
                                      dtype='int32') if attn_type == 'cross' else q_actual_seqlen
    attn_mask = paddle.ones(shape=(bs, 1, q_seqlen, kv_seqlen), dtype='bool')

    # The output gradient is zeroed past each sample's valid query length
    # (done in float32, then cast to the math dtype).
    grad_out = paddle.normal(mean=0.0, std=0.02,
                             shape=(bs, q_seqlen, num_heads, head_size)).astype('float32')
    for i in range(0, bs):
        grad_out[i, q_actual_seqlen[i]:, :, :] = 0
    grad_out = grad_out.astype(math_dtype)

    # Mask starts all True; the valid (q, kv) region of each sample is set to False.
    # NOTE(review): presumably True marks positions to be masked out -- confirm
    # against the DotProductAttention mask convention.
    for i in range(0, bs):
        attn_mask[i, 0, 0:q_actual_seqlen[i], 0:kv_actual_seqlen[i]] = False

    # Layer under test
    layer_te = te.DotProductAttention(num_heads,
                                      head_size,
                                      attention_dropout=0.0,
                                      attn_mask_type=mask_type,
                                      attention_type=attn_type,
                                      backend='transformer_engine')
    # Reference layer
    layer_pd = te.DotProductAttention(num_heads,
                                      head_size,
                                      attention_dropout=0.0,
                                      attn_mask_type=mask_type,
                                      attention_type=attn_type,
                                      backend='paddle')

    def calc_attn_output_and_grad(layer, q, k, v, mask, dout):
        """Run forward/backward on fresh copies of q/k/v; return output and input grads"""
        _q = paddle.to_tensor(q, stop_gradient=False)
        _k = paddle.to_tensor(k, stop_gradient=False)
        _v = paddle.to_tensor(v, stop_gradient=False)

        out = layer(_q, _k, _v, mask)
        out.backward(dout)
        return out, _q.grad, _k.grad, _v.grad

    out, q_grad, k_grad, v_grad = calc_attn_output_and_grad(layer_te, attn_q_input, attn_k_input,
                                                            attn_v_input, attn_mask, grad_out)
    out_ref, q_grad_ref, k_grad_ref, v_grad_ref = calc_attn_output_and_grad(
        layer_pd, attn_q_input, attn_k_input, attn_v_input, attn_mask, grad_out)

    # Keep only the valid rows of the reference output; everything else is zero
    valid_out_ref = paddle.full_like(out_ref, 0)
    for i in range(0, bs):
        valid_out_ref[i, 0:q_actual_seqlen[i], :, :] = out_ref[i, 0:q_actual_seqlen[i], :, :]

    # Same for the reference gradients (q uses query lengths, k/v use key/value lengths)
    valid_q_grad_ref = paddle.full_like(q_grad_ref, 0)
    valid_k_grad_ref = paddle.full_like(k_grad_ref, 0)
    valid_v_grad_ref = paddle.full_like(v_grad_ref, 0)
    for i in range(0, bs):
        valid_q_grad_ref[i, 0:q_actual_seqlen[i], :, :] = q_grad_ref[i, 0:q_actual_seqlen[i], :, :]
        valid_k_grad_ref[i, 0:kv_actual_seqlen[i], :, :] = k_grad_ref[i,
                                                                      0:kv_actual_seqlen[i], :, :]
        valid_v_grad_ref[i, 0:kv_actual_seqlen[i], :, :] = v_grad_ref[i,
                                                                      0:kv_actual_seqlen[i], :, :]

    assert_allclose(out, valid_out_ref, rtol=rtol, atol=atol)
    assert_allclose(q_grad, valid_q_grad_ref, rtol=rtol, atol=atol)
    assert_allclose(k_grad, valid_k_grad_ref, rtol=rtol, atol=atol)
    assert_allclose(v_grad, valid_v_grad_ref, rtol=rtol, atol=atol)


@pytest.mark.parametrize('bs', [1, 2, 8])
@pytest.mark.parametrize('num_gqa_groups', [1, 4, 16])
@pytest.mark.parametrize('hidden_size, num_heads, ffn_hidden_size', [[1024, 16, 4096]])
@pytest.mark.parametrize('q_seqlen, kv_seqlen', [[512, 512], [1024, 1024]])
@pytest.mark.parametrize('has_bias, no_dbias', [[False, True], [True, True], [True, False]])
@pytest.mark.parametrize('no_wgrad', [True, False])
@pytest.mark.parametrize('mask_type', ['causal', 'padding'])
@pytest.mark.parametrize('math_dtype', ['bfloat16', 'float16'])
@pytest.mark.parametrize('output_layernorm', [True, False])
@pytest.mark.parametrize('return_layernorm_output', [True, False])
def test_transformer_encoder_layer(bs, hidden_size, num_heads, num_gqa_groups, ffn_hidden_size,
                                   has_bias, no_dbias, no_wgrad, q_seqlen, kv_seqlen, mask_type,
                                   math_dtype, output_layernorm, return_layernorm_output):
    """
    Test Transformer Encoder Layer

    Builds two `te.TransformerLayer` encoder layers, one with
    backend='transformer_engine' and one with backend='paddle' (the reference),
    copies the TE parameters into the reference layer (linear weights are
    copied transposed), and compares the output, the input gradient and, when
    enabled, the QKV weight and bias gradients.

    `no_wgrad` / `no_dbias` are applied through `stop_gradient` on the weight
    and bias parameters of both layers. `return_layernorm_output` is forwarded
    as `apply_residual_connection_post_layernorm`.

    The test is skipped when `is_fused_attention_supported` reports that the
    configuration is not supported.
    """
    paddle.set_default_dtype(math_dtype)
    rtol = 5e-2
    atol = 5e-2
    eps = 1e-3

    # Skip if cuDNN fused attention is not supported
    if not is_fused_attention_supported(
            num_heads=num_heads,
            num_gqa_groups=num_gqa_groups,
            q_seqlen=q_seqlen,
            kv_seqlen=kv_seqlen,
            head_size=hidden_size // num_heads,
            dtype=math_dtype,
            dropout=0.0,
            qkv_layout="bshd_bshd_bshd",
            bias_type="no_bias",
            mask_type=mask_type,
    ):
        pytest.skip("cuDNN fused attention is not supported")

    encoder_input = paddle.uniform(shape=(bs, q_seqlen, hidden_size), dtype=math_dtype)

    # Every sample uses the full query length, so the slices zeroed in grad_out
    # below are empty and the mask is cleared for all q_seqlen query rows.
    q_actual_seqlen = paddle.ones(shape=(bs,), dtype='int32') * q_seqlen
    kv_actual_seqlen = q_actual_seqlen
    attn_mask = paddle.ones(shape=(bs, 1, q_seqlen, kv_seqlen), dtype='bool')

    grad_out = paddle.normal(mean=0.0, std=0.02,
                             shape=(bs, q_seqlen, hidden_size)).astype('float32')
    for i in range(0, bs):
        grad_out[i, q_actual_seqlen[i]:, :] = 0
    grad_out = grad_out.astype(math_dtype)

    # Mask starts all True; the valid (q, kv) region of each sample is set to False
    for i in range(0, bs):
        attn_mask[i, 0, 0:q_actual_seqlen[i], 0:kv_actual_seqlen[i]] = False

    # Layer under test
    layer_te = te.TransformerLayer(hidden_size,
                                   ffn_hidden_size,
                                   num_heads,
                                   num_gqa_groups=num_gqa_groups,
                                   layernorm_epsilon=eps,
                                   hidden_dropout=0.0,
                                   attention_dropout=0.0,
                                   weight_attr=None,
                                   bias_attr=None if has_bias else False,
                                   self_attn_mask_type=mask_type,
                                   apply_residual_connection_post_layernorm=return_layernorm_output,
                                   output_layernorm=output_layernorm,
                                   layer_type='encoder',
                                   backend='transformer_engine')
    # Reference layer
    layer_pd = te.TransformerLayer(hidden_size,
                                   ffn_hidden_size,
                                   num_heads,
                                   num_gqa_groups=num_gqa_groups,
                                   layernorm_epsilon=eps,
                                   hidden_dropout=0.0,
                                   attention_dropout=0.0,
                                   weight_attr=None,
                                   bias_attr=None if has_bias else False,
                                   self_attn_mask_type=mask_type,
                                   apply_residual_connection_post_layernorm=return_layernorm_output,
                                   output_layernorm=output_layernorm,
                                   layer_type='encoder',
                                   backend='paddle')

    # MultiHeadAttention params
    # With output_layernorm the attention block exposes a plain `qkv` projection;
    # otherwise it exposes a `layernorm_qkv` module with its own ln_weight/ln_bias.
    if output_layernorm:
        layer_pd.self_attention.qkv.weight.copy_(layer_te.self_attention.qkv.weight.T, True)
        layer_pd.self_attention.qkv.weight.stop_gradient = no_wgrad
        layer_te.self_attention.qkv.weight.stop_gradient = no_wgrad
        if has_bias:
            layer_pd.self_attention.qkv.bias.copy_(layer_te.self_attention.qkv.bias, True)
            layer_pd.self_attention.qkv.bias.stop_gradient = no_dbias
            layer_te.self_attention.qkv.bias.stop_gradient = no_dbias
    else:
        layer_pd.self_attention.layernorm_qkv.ln_weight.copy_(
            layer_te.self_attention.layernorm_qkv.ln_weight, True)
        layer_pd.self_attention.layernorm_qkv.ln_bias.copy_(
            layer_te.self_attention.layernorm_qkv.ln_bias, True)
        layer_pd.self_attention.layernorm_qkv.weight.copy_(
            layer_te.self_attention.layernorm_qkv.weight.T, True)
        layer_pd.self_attention.layernorm_qkv.ln_weight.stop_gradient = no_wgrad
        layer_pd.self_attention.layernorm_qkv.ln_bias.stop_gradient = no_dbias
        layer_pd.self_attention.layernorm_qkv.weight.stop_gradient = no_wgrad
        layer_te.self_attention.layernorm_qkv.ln_weight.stop_gradient = no_wgrad
        layer_te.self_attention.layernorm_qkv.ln_bias.stop_gradient = no_dbias
        layer_te.self_attention.layernorm_qkv.weight.stop_gradient = no_wgrad
        if has_bias:
            layer_pd.self_attention.layernorm_qkv.bias.copy_(
                layer_te.self_attention.layernorm_qkv.bias, True)
            layer_pd.self_attention.layernorm_qkv.bias.stop_gradient = no_dbias
            layer_te.self_attention.layernorm_qkv.bias.stop_gradient = no_dbias

    # Attention output projection
    layer_pd.self_attention.proj.weight.copy_(layer_te.self_attention.proj.weight.T, True)
    layer_pd.self_attention.proj.weight.stop_gradient = no_wgrad
    layer_te.self_attention.proj.weight.stop_gradient = no_wgrad
    if has_bias:
        layer_pd.self_attention.proj.bias.copy_(layer_te.self_attention.proj.bias, True)
        layer_pd.self_attention.proj.bias.stop_gradient = no_dbias
        layer_te.self_attention.proj.bias.stop_gradient = no_dbias

    # LayerNorm MLP params
    layer_pd.layernorm_mlp.ln_weight.copy_(layer_te.layernorm_mlp.ln_weight, True)
    layer_pd.layernorm_mlp.ln_bias.copy_(layer_te.layernorm_mlp.ln_bias, True)
    layer_pd.layernorm_mlp.fc1_weight.copy_(layer_te.layernorm_mlp.fc1_weight.T, True)
    layer_pd.layernorm_mlp.fc2_weight.copy_(layer_te.layernorm_mlp.fc2_weight.T, True)
    layer_pd.layernorm_mlp.ln_weight.stop_gradient = no_wgrad
    layer_pd.layernorm_mlp.ln_bias.stop_gradient = no_dbias
    layer_pd.layernorm_mlp.fc1_weight.stop_gradient = no_wgrad
    layer_pd.layernorm_mlp.fc2_weight.stop_gradient = no_wgrad
    layer_te.layernorm_mlp.ln_weight.stop_gradient = no_wgrad
    layer_te.layernorm_mlp.ln_bias.stop_gradient = no_dbias
    layer_te.layernorm_mlp.fc1_weight.stop_gradient = no_wgrad
    layer_te.layernorm_mlp.fc2_weight.stop_gradient = no_wgrad
    if has_bias:
        layer_pd.layernorm_mlp.fc1_bias.copy_(layer_te.layernorm_mlp.fc1_bias, True)
        layer_pd.layernorm_mlp.fc2_bias.copy_(layer_te.layernorm_mlp.fc2_bias, True)
        layer_pd.layernorm_mlp.fc1_bias.stop_gradient = no_dbias
        layer_pd.layernorm_mlp.fc2_bias.stop_gradient = no_dbias
        layer_te.layernorm_mlp.fc1_bias.stop_gradient = no_dbias
        layer_te.layernorm_mlp.fc2_bias.stop_gradient = no_dbias

    # Standalone LayerNorm params, only touched when output_layernorm is set
    if output_layernorm:
        layer_pd.layernorm.weight.copy_(layer_te.layernorm.weight, True)
        layer_pd.layernorm.bias.copy_(layer_te.layernorm.bias, True)
        layer_pd.layernorm.weight.stop_gradient = no_wgrad
        layer_pd.layernorm.bias.stop_gradient = no_dbias
        layer_te.layernorm.weight.stop_gradient = no_wgrad
        layer_te.layernorm.bias.stop_gradient = no_dbias

    def calc_transformer_output_and_grad(layer, encoder_input, mask, dout):
        """Run forward/backward on a fresh copy of the input; return output and input grad"""
        _encoder_input = paddle.to_tensor(encoder_input, stop_gradient=False)
        out = layer(_encoder_input, mask)
        out.backward(dout)
        return out, _encoder_input.grad

    out_ref, grad_input_ref = calc_transformer_output_and_grad(layer_pd, encoder_input, attn_mask,
                                                               grad_out)
    out, grad_input = calc_transformer_output_and_grad(layer_te, encoder_input, attn_mask, grad_out)

    assert_allclose(out, out_ref, rtol=rtol, atol=atol)
    assert_allclose(grad_input, grad_input_ref, rtol=rtol, atol=atol)
    # Reference weights were copied transposed, so their grads are transposed back
    if not no_wgrad:
        if output_layernorm:
            assert_allclose(layer_te.self_attention.qkv.weight.grad,
                            layer_pd.self_attention.qkv.weight.grad.T,
                            rtol=rtol,
                            atol=atol)
        else:
            assert_allclose(layer_te.self_attention.layernorm_qkv.weight.grad,
                            layer_pd.self_attention.layernorm_qkv.weight.grad.T,
                            rtol=rtol,
                            atol=atol)
    # In the parametrization above no_dbias is False only together with
    # has_bias=True, so the bias parameters exist whenever this branch runs.
    # The bias gradient check uses its own tolerances (rtol=0.01, atol=0.5).
    if not no_dbias:
        if output_layernorm:
            assert_allclose(layer_te.self_attention.qkv.bias.grad,
                            layer_pd.self_attention.qkv.bias.grad,
                            rtol=0.01,
                            atol=0.5)
        else:
            assert_allclose(layer_te.self_attention.layernorm_qkv.bias.grad,
                            layer_pd.self_attention.layernorm_qkv.bias.grad,
                            rtol=0.01,
                            atol=0.5)


@pytest.mark.parametrize('bs', [1, 2, 8])
Shijie's avatar
Shijie committed
1066
@pytest.mark.parametrize('num_gqa_groups', [1, 4, 16])
Shijie's avatar
Shijie committed
1067
@pytest.mark.parametrize('hidden_size, num_heads, ffn_hidden_size', [[1024, 16, 4096]])
Shijie's avatar
Shijie committed
1068
@pytest.mark.parametrize('q_seqlen, kv_seqlen', [[512, 512], [1024, 1024]])
Shijie's avatar
Shijie committed
1069
1070
1071
1072
1073
1074
@pytest.mark.parametrize('has_bias, no_dbias', [[False, True], [True, True], [True, False]])
@pytest.mark.parametrize('no_wgrad', [True, False])
@pytest.mark.parametrize('mask_type', ['causal', 'padding'])
@pytest.mark.parametrize('math_dtype', ['bfloat16', 'float16'])
@pytest.mark.parametrize('output_layernorm', [True, False])
@pytest.mark.parametrize('return_layernorm_output', [True, False])
Tian Zheng's avatar
Tian Zheng committed
1075
@pytest.mark.parametrize('recompute_core_attention', [True, False])
Shijie's avatar
Shijie committed
1076
1077
1078
def test_transformer_decoder_layer(bs, hidden_size, num_heads, num_gqa_groups, ffn_hidden_size,
                                   has_bias, no_dbias, no_wgrad, q_seqlen, kv_seqlen, mask_type,
                                   math_dtype, output_layernorm, return_layernorm_output,
Tian Zheng's avatar
Tian Zheng committed
1079
                                   recompute_core_attention):
Shijie's avatar
Shijie committed
1080
1081
1082
1083
1084
    """
    Test Transformer Decoder Layer
    """
    paddle.set_default_dtype(math_dtype)
    rtol = 5e-2
1085
    atol = 6e-2
Shijie's avatar
Shijie committed
1086
1087
    eps = 1e-3

Tim Moon's avatar
Tim Moon committed
1088
1089
    # Skip if cuDNN fused attention is not supported
    if not is_fused_attention_supported(
1090
            num_heads=num_heads,
Shijie's avatar
Shijie committed
1091
            num_gqa_groups=num_gqa_groups,
1092
1093
1094
1095
1096
            q_seqlen=q_seqlen,
            kv_seqlen=kv_seqlen,
            head_size=hidden_size // num_heads,
            dtype=math_dtype,
            dropout=0.0,
Shijie's avatar
Shijie committed
1097
            qkv_layout="bshd_bshd_bshd",
1098
1099
            bias_type="no_bias",
            mask_type=mask_type,
Tim Moon's avatar
Tim Moon committed
1100
1101
1102
    ):
        pytest.skip("cuDNN fused attention is not supported")

Shijie's avatar
Shijie committed
1103
1104
1105
1106
    encoder_input = paddle.normal(mean=0.0, std=0.1,
                                  shape=(bs, q_seqlen, hidden_size)).astype(math_dtype)
    encoder_output = paddle.normal(mean=0.0, std=0.1,
                                   shape=(bs, kv_seqlen, hidden_size)).astype(math_dtype)
Shijie's avatar
Shijie committed
1107
1108
1109
1110
1111

    q_actual_seqlen = paddle.ones(shape=(bs,), dtype='int32') * q_seqlen
    kv_actual_seqlen = q_actual_seqlen
    attn_mask = paddle.ones(shape=(bs, 1, q_seqlen, kv_seqlen), dtype='bool')

Shijie's avatar
Shijie committed
1112
1113
1114
1115
1116
1117
1118
1119
    grad_out = paddle.normal(mean=0.0, std=0.01,
                             shape=(bs, q_seqlen, hidden_size)).astype('float32')

    # rounding to avoid numerical issues
    encoder_input = paddle.round(encoder_input * 1000) / 1000
    encoder_output = paddle.round(encoder_output * 1000) / 1000
    grad_out = paddle.round(grad_out * 1000) / 1000

Shijie's avatar
Shijie committed
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
    for i in range(0, bs):
        grad_out[i, q_actual_seqlen[i]:, :] = 0
    grad_out = grad_out.astype(math_dtype)

    for i in range(0, bs):
        attn_mask[i, 0, 0:q_actual_seqlen[i], 0:kv_actual_seqlen[i]] = False

    layer_te = te.TransformerLayer(hidden_size,
                                   ffn_hidden_size,
                                   num_heads,
Shijie's avatar
Shijie committed
1130
                                   num_gqa_groups=num_gqa_groups,
Shijie's avatar
Shijie committed
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
                                   layernorm_epsilon=eps,
                                   hidden_dropout=0.0,
                                   attention_dropout=0.0,
                                   weight_attr=None,
                                   bias_attr=None if has_bias else False,
                                   self_attn_mask_type=mask_type,
                                   apply_residual_connection_post_layernorm=return_layernorm_output,
                                   output_layernorm=output_layernorm,
                                   layer_type='decoder',
                                   backend='transformer_engine')
    layer_pd = te.TransformerLayer(hidden_size,
                                   ffn_hidden_size,
                                   num_heads,
Shijie's avatar
Shijie committed
1144
                                   num_gqa_groups=num_gqa_groups,
Shijie's avatar
Shijie committed
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
                                   layernorm_epsilon=eps,
                                   hidden_dropout=0.0,
                                   attention_dropout=0.0,
                                   weight_attr=None,
                                   bias_attr=None if has_bias else False,
                                   self_attn_mask_type=mask_type,
                                   apply_residual_connection_post_layernorm=return_layernorm_output,
                                   output_layernorm=output_layernorm,
                                   layer_type='decoder',
                                   backend='paddle')

    # MultiHeadAttention params - self attn
    if output_layernorm:
        layer_pd.self_attention.qkv.weight.copy_(layer_te.self_attention.qkv.weight.T, True)
        layer_pd.self_attention.qkv.weight.stop_gradient = no_wgrad
        layer_te.self_attention.qkv.weight.stop_gradient = no_wgrad
        if has_bias:
            layer_pd.self_attention.qkv.bias.copy_(layer_te.self_attention.qkv.bias, True)
            layer_pd.self_attention.qkv.bias.stop_gradient = no_dbias
            layer_te.self_attention.qkv.bias.stop_gradient = no_dbias
    else:
        layer_pd.self_attention.layernorm_qkv.ln_weight.copy_(
            layer_te.self_attention.layernorm_qkv.ln_weight, True)
        layer_pd.self_attention.layernorm_qkv.ln_bias.copy_(
            layer_te.self_attention.layernorm_qkv.ln_bias, True)
        layer_pd.self_attention.layernorm_qkv.weight.copy_(
            layer_te.self_attention.layernorm_qkv.weight.T, True)
        layer_pd.self_attention.layernorm_qkv.ln_weight.stop_gradient = no_wgrad
        layer_pd.self_attention.layernorm_qkv.ln_bias.stop_gradient = no_dbias
        layer_pd.self_attention.layernorm_qkv.weight.stop_gradient = no_wgrad
        layer_te.self_attention.layernorm_qkv.ln_weight.stop_gradient = no_wgrad
        layer_te.self_attention.layernorm_qkv.ln_bias.stop_gradient = no_dbias
        layer_te.self_attention.layernorm_qkv.weight.stop_gradient = no_wgrad
        if has_bias:
            layer_pd.self_attention.layernorm_qkv.bias.copy_(
                layer_te.self_attention.layernorm_qkv.bias, True)
            layer_pd.self_attention.layernorm_qkv.bias.stop_gradient = no_dbias
            layer_te.self_attention.layernorm_qkv.bias.stop_gradient = no_dbias

    layer_pd.self_attention.proj.weight.copy_(layer_te.self_attention.proj.weight.T, True)
    layer_pd.self_attention.proj.weight.stop_gradient = no_wgrad
    layer_te.self_attention.proj.weight.stop_gradient = no_wgrad
    if has_bias:
        layer_pd.self_attention.proj.bias.copy_(layer_te.self_attention.proj.bias, True)
        layer_pd.self_attention.proj.bias.stop_gradient = no_dbias
        layer_te.self_attention.proj.bias.stop_gradient = no_dbias

    # MultiHeadAttention params - cross attn
    layer_pd.inter_attention.layernorm_query.ln_weight.copy_(
        layer_te.inter_attention.layernorm_query.ln_weight, True)
    layer_pd.inter_attention.layernorm_query.ln_bias.copy_(
        layer_te.inter_attention.layernorm_query.ln_bias, True)
    layer_pd.inter_attention.layernorm_query.weight.copy_(
        layer_te.inter_attention.layernorm_query.weight.T, True)
    layer_pd.inter_attention.layernorm_query.ln_weight.stop_gradient = no_wgrad
    layer_pd.inter_attention.layernorm_query.ln_bias.stop_gradient = no_dbias
    layer_pd.inter_attention.layernorm_query.weight.stop_gradient = no_wgrad
    layer_te.inter_attention.layernorm_query.ln_weight.stop_gradient = no_wgrad
    layer_te.inter_attention.layernorm_query.ln_bias.stop_gradient = no_dbias
    layer_te.inter_attention.layernorm_query.weight.stop_gradient = no_wgrad
    if has_bias:
        layer_pd.inter_attention.layernorm_query.bias.copy_(
            layer_te.inter_attention.layernorm_query.bias, True)
        layer_pd.inter_attention.layernorm_query.bias.stop_gradient = no_dbias
        layer_te.inter_attention.layernorm_query.bias.stop_gradient = no_dbias

    layer_pd.inter_attention.key_value.weight.copy_(layer_te.inter_attention.key_value.weight.T,
                                                    True)
    layer_pd.inter_attention.key_value.weight.stop_gradient = no_wgrad
    layer_te.inter_attention.key_value.weight.stop_gradient = no_wgrad
    layer_pd.inter_attention.proj.weight.copy_(layer_te.inter_attention.proj.weight.T, True)
    layer_pd.inter_attention.proj.weight.stop_gradient = no_wgrad
    layer_te.inter_attention.proj.weight.stop_gradient = no_wgrad
    if has_bias:
        layer_pd.inter_attention.key_value.bias.copy_(layer_te.inter_attention.key_value.bias, True)
        layer_pd.inter_attention.key_value.bias.stop_gradient = no_dbias
        layer_te.inter_attention.key_value.bias.stop_gradient = no_dbias
        layer_pd.inter_attention.proj.bias.copy_(layer_te.inter_attention.proj.bias, True)
        layer_pd.inter_attention.proj.bias.stop_gradient = no_dbias
        layer_te.inter_attention.proj.bias.stop_gradient = no_dbias

    # LayerNorm MLP params
    layer_pd.layernorm_mlp.ln_weight.copy_(layer_te.layernorm_mlp.ln_weight, True)
    layer_pd.layernorm_mlp.ln_bias.copy_(layer_te.layernorm_mlp.ln_bias, True)
    layer_pd.layernorm_mlp.fc1_weight.copy_(layer_te.layernorm_mlp.fc1_weight.T, True)
    layer_pd.layernorm_mlp.fc2_weight.copy_(layer_te.layernorm_mlp.fc2_weight.T, True)
    layer_pd.layernorm_mlp.ln_weight.stop_gradient = no_wgrad
    layer_pd.layernorm_mlp.ln_bias.stop_gradient = no_dbias
    layer_pd.layernorm_mlp.fc1_weight.stop_gradient = no_wgrad
    layer_pd.layernorm_mlp.fc2_weight.stop_gradient = no_wgrad
    layer_te.layernorm_mlp.ln_weight.stop_gradient = no_wgrad
    layer_te.layernorm_mlp.ln_bias.stop_gradient = no_dbias
    layer_te.layernorm_mlp.fc1_weight.stop_gradient = no_wgrad
    layer_te.layernorm_mlp.fc2_weight.stop_gradient = no_wgrad
    if has_bias:
        layer_pd.layernorm_mlp.fc1_bias.copy_(layer_te.layernorm_mlp.fc1_bias, True)
        layer_pd.layernorm_mlp.fc2_bias.copy_(layer_te.layernorm_mlp.fc2_bias, True)
        layer_pd.layernorm_mlp.fc1_bias.stop_gradient = no_dbias
        layer_pd.layernorm_mlp.fc2_bias.stop_gradient = no_dbias
        layer_te.layernorm_mlp.fc1_bias.stop_gradient = no_dbias
        layer_te.layernorm_mlp.fc2_bias.stop_gradient = no_dbias

    if output_layernorm:
        layer_pd.layernorm.weight.copy_(layer_te.layernorm.weight, True)
        layer_pd.layernorm.bias.copy_(layer_te.layernorm.bias, True)
        layer_pd.layernorm.weight.stop_gradient = no_wgrad
        layer_pd.layernorm.bias.stop_gradient = no_dbias
        layer_te.layernorm.weight.stop_gradient = no_wgrad
        layer_te.layernorm.bias.stop_gradient = no_dbias

Tian Zheng's avatar
Tian Zheng committed
1255
1256
1257
1258
1259
1260
1261
    def calc_transformer_output_and_grad(layer,
                                         encoder_input,
                                         mask,
                                         encoder_output,
                                         enc_dec_attn_mask,
                                         dout,
                                         recompute_core_attention=False):
        """
        Run a single forward + backward pass through `layer`.

        `encoder_input` and `encoder_output` are re-wrapped with
        `stop_gradient=False` so that their gradients can be read back after
        `backward`. Returns the layer output followed by the gradients of the
        two re-wrapped tensors, in that order.
        """
        # Gradient-tracking copies of both inputs, in argument order.
        tracked_input, tracked_memory = (paddle.to_tensor(tensor, stop_gradient=False)
                                         for tensor in (encoder_input, encoder_output))

        result = layer(tracked_input,
                       mask,
                       tracked_memory,
                       enc_dec_attn_mask,
                       recompute_core_attention=recompute_core_attention)
        # Back-propagate the supplied upstream gradient.
        result.backward(dout)

        return result, tracked_input.grad, tracked_memory.grad

    out_ref, grad_encoder_input_ref, grad_encoder_output_ref = calc_transformer_output_and_grad(
        layer_pd, encoder_input, attn_mask, encoder_output, attn_mask, grad_out)
    out, grad_encoder_input, grad_encoder_output = calc_transformer_output_and_grad(
Tian Zheng's avatar
Tian Zheng committed
1275
1276
1277
1278
1279
1280
1281
        layer_te,
        encoder_input,
        attn_mask,
        encoder_output,
        attn_mask,
        grad_out,
        recompute_core_attention=recompute_core_attention)
Shijie's avatar
Shijie committed
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295

    assert_allclose(out, out_ref, rtol=rtol, atol=atol)
    assert_allclose(grad_encoder_input, grad_encoder_input_ref, rtol=rtol, atol=atol)
    assert_allclose(grad_encoder_output, grad_encoder_output_ref, rtol=rtol, atol=atol)
    if not no_wgrad:
        if output_layernorm:
            assert_allclose(layer_te.self_attention.qkv.weight.grad,
                            layer_pd.self_attention.qkv.weight.grad.T,
                            rtol=rtol,
                            atol=atol)
        else:
            assert_allclose(layer_te.self_attention.layernorm_qkv.weight.grad,
                            layer_pd.self_attention.layernorm_qkv.weight.grad.T,
                            rtol=rtol,
Shijie's avatar
Shijie committed
1296
                            atol=atol)
Shijie's avatar
Shijie committed
1297
1298
1299
1300
1301
1302
1303
1304
            assert_allclose(layer_te.inter_attention.layernorm_query.weight.grad,
                            layer_pd.inter_attention.layernorm_query.weight.grad.T,
                            rtol=rtol,
                            atol=atol)
    if not no_dbias:
        if output_layernorm:
            assert_allclose(layer_te.self_attention.qkv.bias.grad,
                            layer_pd.self_attention.qkv.bias.grad,
Shijie's avatar
Shijie committed
1305
                            rtol=0.5,
1306
                            atol=0.6)
Shijie's avatar
Shijie committed
1307
1308
1309
1310
1311
1312
1313
1314
1315
        else:
            assert_allclose(layer_te.self_attention.layernorm_qkv.bias.grad,
                            layer_pd.self_attention.layernorm_qkv.bias.grad,
                            rtol=0.01,
                            atol=0.5)
            assert_allclose(layer_te.inter_attention.layernorm_query.bias.grad,
                            layer_pd.inter_attention.layernorm_query.bias.grad,
                            rtol=rtol,
                            atol=atol)
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434


@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('bs', [8])
@pytest.mark.parametrize('hidden_size, num_heads, ffn_hidden_size', [[1024, 16, 4096]])
@pytest.mark.parametrize('q_seqlen, kv_seqlen', [[128, 128]])
@pytest.mark.parametrize('mask_type', ['causal'])
@pytest.mark.parametrize('math_dtype', ['bfloat16'])
@pytest.mark.parametrize('num_microbatch', [8])
def test_transformer_encoder_layer_microbatch(bs, hidden_size, num_heads, ffn_hidden_size, q_seqlen,
                                              kv_seqlen, mask_type, math_dtype, num_microbatch):
    """
    Test Transformer Encoder Layer with FP8 weight caching

    Two TE encoder layers are built with identical parameters. For each of
    `num_microbatch` iterations, `layer_cached` is called with
    `is_first_microbatch` (True only on iteration 0) while `layer_normal` is
    called without it. The outputs and the `layernorm_qkv` weight gradients of
    both layers are required to match.
    """
    paddle.set_default_dtype(math_dtype)
    rtol = atol = 1e-5
    eps = 1e-3

    # Skip if cuDNN fused attention is not supported
    fused_attn_config = {
        'num_heads': num_heads,
        'num_gqa_groups': num_heads,
        'q_seqlen': q_seqlen,
        'kv_seqlen': kv_seqlen,
        'head_size': hidden_size // num_heads,
        'dtype': math_dtype,
        'dropout': 0.0,
        'qkv_layout': "bs3hd",
        'bias_type': "no_bias",
        'mask_type': mask_type,
    }
    if not is_fused_attention_supported(**fused_attn_config):
        pytest.skip("cuDNN fused attention is not supported")

    def build_layer():
        """Construct one encoder layer; both layers under test share this config."""
        return te.TransformerLayer(hidden_size,
                                   ffn_hidden_size,
                                   num_heads,
                                   layernorm_epsilon=eps,
                                   hidden_dropout=0.0,
                                   attention_dropout=0.0,
                                   weight_attr=None,
                                   bias_attr=None,
                                   self_attn_mask_type=mask_type,
                                   layer_type='encoder')

    def sync_params(dst, src, names):
        """Copy each named parameter of `src` into the same-named parameter of `dst`."""
        for name in names:
            getattr(dst, name).copy_(getattr(src, name), True)

    # Creation order is kept (cached first) so random initialisation is unchanged.
    layer_cached = build_layer()
    layer_normal = build_layer()

    # Make layer_normal start from exactly the same parameters as layer_cached.
    sync_params(layer_normal.self_attention.layernorm_qkv,
                layer_cached.self_attention.layernorm_qkv,
                ('ln_weight', 'ln_bias', 'weight', 'bias'))
    sync_params(layer_normal.self_attention.proj, layer_cached.self_attention.proj,
                ('weight', 'bias'))
    # LayerNorm MLP params
    sync_params(layer_normal.layernorm_mlp, layer_cached.layernorm_mlp,
                ('ln_weight', 'ln_bias', 'fc1_weight', 'fc2_weight', 'fc1_bias', 'fc2_bias'))

    recipe = DelayedScaling()

    def generate_input():
        """Draw a fresh (input, attention mask, output gradient) triple."""
        hidden_states = paddle.uniform(shape=(bs, q_seqlen, hidden_size), dtype=math_dtype)

        q_lens = paddle.ones(shape=(bs,), dtype='int32') * q_seqlen
        kv_lens = q_lens
        pad_mask = paddle.ones(shape=(bs, 1, q_seqlen, kv_seqlen), dtype='bool')

        upstream_grad = paddle.normal(mean=0.0, std=0.02,
                                      shape=(bs, q_seqlen, hidden_size)).astype('float32')
        for sample in range(bs):
            # Zero the gradient beyond the actual sequence length ...
            upstream_grad[sample, q_lens[sample]:, :] = 0
            # ... and clear the mask inside the actual (q, kv) window.
            pad_mask[sample, 0, 0:q_lens[sample], 0:kv_lens[sample]] = False

        return hidden_states, pad_mask, upstream_grad.astype(math_dtype)

    def fp8_forward_backward(layer, inputs, attn_mask, dout, **forward_kwargs):
        """Run forward and backward for `layer` inside one FP8 autocast region."""
        with fp8_autocast(enabled=True, fp8_recipe=recipe):
            result = layer(inputs, attn_mask, **forward_kwargs)
            result.backward(dout)
        return result

    # Calibration to make sure weight scale is the same
    encoder_input, mask, _ = generate_input()
    for layer in (layer_cached, layer_normal):
        with fp8_autocast(enabled=False, calibrating=True, fp8_recipe=recipe):
            _ = layer(encoder_input, mask)

    for iteration in range(num_microbatch):
        encoder_input, mask, grad_out = generate_input()

        out = fp8_forward_backward(layer_cached,
                                   encoder_input,
                                   mask,
                                   grad_out,
                                   is_first_microbatch=(iteration == 0))
        out_ref = fp8_forward_backward(layer_normal, encoder_input, mask, grad_out)

        assert_allclose(out, out_ref, rtol=rtol, atol=atol)
        assert_allclose(layer_cached.self_attention.layernorm_qkv.weight.grad,
                        layer_normal.self_attention.layernorm_qkv.weight.grad,
                        rtol=rtol,
                        atol=atol)