# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Test TE Paddle Layer-level APIs"""

import os
from utils import assert_allclose, is_fused_attention_supported

import paddle
import pytest

from transformer_engine.common.recipe import DelayedScaling
import transformer_engine.paddle as te
from transformer_engine.paddle.fp8 import is_fp8_available, fp8_autocast

# Evaluated once at import time; ``reason`` doubles as the skip message of the FP8 tests.
is_fp8_supported, reason = is_fp8_available()
# Shapes shared by the Linear-family tests: (batch size, input dim, output / FFN dim).
LINEAR_CASES = [(16, 16, 32), (32, 32, 64)]
# Shapes for the LayerNorm tests: (batch size, hidden size).
NORM_CASES = [(16, 32), (256, 1024)]
@pytest.fixture(autouse=True)
def setup():
    """Seed paddle's RNG before each test and restore the default dtype afterwards.

    Many tests in this module call ``paddle.set_default_dtype`` to select their
    activation dtype. Without restoring it, that process-global setting would leak
    into whichever test happens to run next.
    """
    paddle.seed(10)
    saved_default_dtype = paddle.get_default_dtype()
    yield
    # Teardown: undo any ``paddle.set_default_dtype`` performed by the test.
    paddle.set_default_dtype(saved_default_dtype)
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize("use_fp8", [True, False])
def test_checkpoint(use_fp8):
    """Test checkpoint save / load

    A calibrated ``te.Linear`` is saved, its state dict is loaded into a fresh
    ``te.Linear``, and both must then produce matching outputs.

    The checkpoint lives in a private temporary directory. Parametrized or parallel
    runs therefore cannot collide on a shared file name, and the file is removed
    even when loading or inference raises.
    """
    import tempfile  # local import: only this test needs it

    bs = 16
    in_features = 16
    out_features = 32
    input_tensor = paddle.uniform(shape=(bs, in_features), dtype="float32")
    model = te.Linear(in_features, out_features)
    model_loaded = te.Linear(in_features, out_features)
    # Populate amax_history
    with fp8_autocast(enabled=False, calibrating=True):
        _ = model(input_tensor)
    with tempfile.TemporaryDirectory() as tmp_dir:
        file_name = os.path.join(tmp_dir, "model.pdparams")
        # Save model
        paddle.save(model.state_dict(), file_name)
        # Get ref output
        with fp8_autocast(enabled=use_fp8):
            out_ref = model(input_tensor)
        # Load model (before the temporary directory is deleted)
        model_loaded.set_state_dict(paddle.load(file_name))
    # Get actual output
    with fp8_autocast(enabled=use_fp8):
        out = model_loaded(input_tensor)

    assert_allclose(out, out_ref)
def calc_output_and_grad(layer, x, dy):
    """
    Run ``layer`` forward and backward once.

    The layer is fed a tensor built from ``x`` with ``paddle.to_tensor``, which
    inherits ``x.stop_gradient``. The output is then backpropagated with ``dy``.

    Returns ``(y, dx)``. ``dx`` is the gradient w.r.t. that input tensor, or
    ``None`` when ``x.stop_gradient`` is set.
    """
    layer_input = paddle.to_tensor(x)
    layer_input.stop_gradient = x.stop_gradient
    output = layer(layer_input)
    output.backward(dy)
    input_grad = None if layer_input.stop_gradient else layer_input.grad
    return output, input_grad
def calc_output_and_grad_ln_out(layer, x, dy, return_ln_out=False):
    """
    Calculate forward and backward pass for layernorm-fused layers.

    The layer is fed a tensor built from ``x`` with ``paddle.to_tensor``, which
    inherits ``x.stop_gradient``. The output is then backpropagated with ``dy``.

    Args:
        layer: layer under test. When ``return_ln_out`` is true it must return a
            ``(y, ln_out)`` pair, otherwise just ``y``.
        x: input tensor.
        dy: upstream gradient passed to ``y.backward``.
        return_ln_out: whether the layer also returns its normalization output.

    Returns:
        ``(y, ln_out, dx)``. ``ln_out`` is ``None`` unless ``return_ln_out``, and
        ``dx`` is ``None`` when ``x.stop_gradient`` is set.

    Note: this is a plain module-level function. Decorating it with
    ``@staticmethod`` made it uncallable before Python 3.10.
    """
    inp = paddle.to_tensor(x)
    inp.stop_gradient = x.stop_gradient
    outputs = layer(inp)
    if return_ln_out:
        y, ln_out = outputs
    else:
        y, ln_out = outputs, None
    y.backward(dy)

    return y, ln_out, inp.grad if not inp.stop_gradient else None
class TestLinear:
    """
    Tests for Linear layer
    """

    @staticmethod
    def _create_layer_pair(in_features, out_features, has_bias, no_wgrad, no_dbias):
        """
        Build a TE ``Linear`` and a paddle-backend reference ``Linear`` with equal parameters.

        The reference weight layout is the transpose of the TE one, so it is filled
        from ``layer_te.weight.T``. On both layers, ``weight.stop_gradient`` is set to
        ``no_wgrad``. When ``has_bias`` is true, ``bias.stop_gradient`` is set to
        ``no_dbias``.

        Returns ``(layer_te, layer_pd)``.
        """
        bias_attr = None if has_bias else False
        layer_te = te.Linear(
            in_features=in_features,
            out_features=out_features,
            bias_attr=bias_attr,
        )
        layer_pd = te.Linear(
            in_features=in_features,
            out_features=out_features,
            bias_attr=bias_attr,
            backend="paddle",
        )
        layer_pd.weight.copy_(layer_te.weight.T, True)
        if has_bias:
            layer_pd.bias.copy_(layer_te.bias, True)

        for layer in (layer_te, layer_pd):
            layer.weight.stop_gradient = no_wgrad
            if has_bias:
                layer.bias.stop_gradient = no_dbias
        return layer_te, layer_pd

    @staticmethod
    def _assert_results_close(
        layer_te,
        layer_pd,
        results,
        results_ref,
        *,
        has_bias,
        no_dbias,
        no_dgrad,
        no_wgrad,
        rtol,
        atol,
    ):
        """
        Compare TE outputs and gradients against the paddle-backend reference.

        ``results`` and ``results_ref`` are ``(out, grad_input)`` pairs as returned by
        ``calc_output_and_grad``. Gradients are compared only where they were
        requested. The reference weight gradient is transposed to match the TE layout.
        """
        out, grad_input = results
        out_ref, grad_input_ref = results_ref
        assert_allclose(out, out_ref, rtol=rtol, atol=atol)
        if not no_dgrad:
            assert_allclose(grad_input, grad_input_ref, rtol=rtol, atol=atol)
        if not no_wgrad:
            assert_allclose(layer_te.weight.grad, layer_pd.weight.grad.T, rtol=rtol, atol=atol)
        if has_bias and not no_dbias:
            assert_allclose(layer_te.bias.grad, layer_pd.bias.grad, rtol=rtol, atol=atol)

    @staticmethod
    @pytest.mark.skipif(
        paddle.device.cuda.get_device_capability() < (8, 0),
        reason="BF16 Linear requires Ampere+ GPU",
    )
    @pytest.mark.parametrize("bs,in_features,out_features", LINEAR_CASES)
    @pytest.mark.parametrize("has_bias,no_dbias", [[True, False], [True, True], [False, False]])
    @pytest.mark.parametrize("no_dgrad", [True, False])
    @pytest.mark.parametrize("no_wgrad", [True, False])
    @pytest.mark.parametrize("activation_dtype", ["bfloat16", "float32"])
    def test_linear_bf16(
        bs, in_features, out_features, has_bias, no_dbias, no_dgrad, no_wgrad, activation_dtype
    ):
        """
        Test BF16 Linear
        """
        rtol = 5e-2
        atol = 5e-2

        input_tensor = paddle.uniform(shape=(bs, in_features), dtype=activation_dtype)
        input_tensor.stop_gradient = no_dgrad
        grad_out = paddle.uniform(shape=(bs, out_features), dtype=activation_dtype)

        paddle.set_default_dtype(activation_dtype)
        layer_te, layer_pd = TestLinear._create_layer_pair(
            in_features, out_features, has_bias, no_wgrad, no_dbias
        )

        results_ref = calc_output_and_grad(layer_pd, input_tensor, grad_out)
        results = calc_output_and_grad(layer_te, input_tensor, grad_out)

        TestLinear._assert_results_close(
            layer_te,
            layer_pd,
            results,
            results_ref,
            has_bias=has_bias,
            no_dbias=no_dbias,
            no_dgrad=no_dgrad,
            no_wgrad=no_wgrad,
            rtol=rtol,
            atol=atol,
        )

    @staticmethod
    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize("bs,in_features,out_features", LINEAR_CASES)
    @pytest.mark.parametrize("has_bias,no_dbias", [[True, False], [True, True], [False, False]])
    @pytest.mark.parametrize("no_dgrad", [True, False])
    @pytest.mark.parametrize("no_wgrad", [True, False])
    @pytest.mark.parametrize("fp8_wgrad", [True, False])
    @pytest.mark.parametrize("do_calibration", [True, False])
    @pytest.mark.parametrize("activation_dtype", ["bfloat16", "float32"])
    def test_linear_fp8(
        bs,
        in_features,
        out_features,
        has_bias,
        no_dbias,
        no_dgrad,
        no_wgrad,
        fp8_wgrad,
        do_calibration,
        activation_dtype,
    ):
        """
        Test FP8 Linear
        """
        rtol = 0.1
        atol = 0.5

        input_tensor = paddle.uniform(shape=(bs, in_features), dtype=activation_dtype)
        input_tensor.stop_gradient = no_dgrad
        grad_out = paddle.uniform(shape=(bs, out_features), dtype=activation_dtype)

        recipe = DelayedScaling(override_linear_precision=(False, False, not fp8_wgrad))

        paddle.set_default_dtype(activation_dtype)
        layer_te, layer_pd = TestLinear._create_layer_pair(
            in_features, out_features, has_bias, no_wgrad, no_dbias
        )

        with fp8_autocast(
            enabled=not do_calibration, calibrating=do_calibration, fp8_recipe=recipe
        ):
            results_ref = calc_output_and_grad(layer_pd, input_tensor, grad_out)
            results = calc_output_and_grad(layer_te, input_tensor, grad_out)

        TestLinear._assert_results_close(
            layer_te,
            layer_pd,
            results,
            results_ref,
            has_bias=has_bias,
            no_dbias=no_dbias,
            no_dgrad=no_dgrad,
            no_wgrad=no_wgrad,
            rtol=rtol,
            atol=atol,
        )
        if do_calibration:
            # Calibration mode must have recorded forward amax values.
            assert paddle.count_nonzero(layer_te.fp8_meta["scaling_fwd"].amax_history).item() > 0

    @staticmethod
    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize("bs,in_features,out_features", LINEAR_CASES)
    @pytest.mark.parametrize("activation_dtype", ["bfloat16"])
    @pytest.mark.parametrize("num_microbatch", [8])
    def test_linear_fp8_microbatch(bs, in_features, out_features, activation_dtype, num_microbatch):
        """
        Test FP8 Linear with ``is_first_microbatch``.

        Two layers share identical parameters. One is called with
        ``is_first_microbatch=(iteration == 0)`` and the other without it. Their
        outputs and weight gradients must agree after every microbatch.
        """
        rtol = 0.1
        atol = 0.1

        recipe = DelayedScaling()

        paddle.set_default_dtype(activation_dtype)
        layer_cached = te.Linear(
            in_features=in_features,
            out_features=out_features,
        )
        layer_normal = te.Linear(
            in_features=in_features,
            out_features=out_features,
        )
        layer_cached.weight.copy_(layer_normal.weight, True)
        layer_cached.bias.copy_(layer_normal.bias, True)

        for iteration in range(num_microbatch):
            input_tensor = paddle.uniform(shape=(bs, in_features), dtype=activation_dtype)
            grad_out = paddle.uniform(shape=(bs, out_features), dtype=activation_dtype)

            with fp8_autocast(enabled=True, fp8_recipe=recipe):
                out = layer_cached(input_tensor, is_first_microbatch=(iteration == 0))
                out.backward(grad_out)

            with fp8_autocast(enabled=True, fp8_recipe=recipe):
                out_ref = layer_normal(input_tensor)
                out_ref.backward(grad_out)

            assert_allclose(out, out_ref, rtol=rtol, atol=atol)
            assert_allclose(
                layer_cached.weight.grad, layer_normal.weight.grad, rtol=rtol, atol=atol
            )
@pytest.mark.parametrize("bs,hidden_size", NORM_CASES)
@pytest.mark.parametrize("has_bias,no_dbias", [[True, False], [True, True], [False, False]])
@pytest.mark.parametrize("no_dgrad", [True, False])
@pytest.mark.parametrize("no_wgrad", [True, False])
@pytest.mark.parametrize("activation_dtype", ["bfloat16", "float32"])
def test_layernorm_bf16(bs, hidden_size, has_bias, no_dbias, no_dgrad, no_wgrad, activation_dtype):
    """
    Test BF16 LayerNorm

    Compares the TE ``LayerNorm`` against the paddle-backend ``LayerNorm``. Both
    start from identical parameters. The forward output is always checked; input,
    weight and bias gradients are checked only when requested.
    """
    eps = 1e-3
    tol = {"rtol": 1e-2, "atol": 1e-2}

    x = paddle.uniform(shape=(bs, hidden_size), dtype=activation_dtype)
    x.stop_gradient = no_dgrad
    grad_out = paddle.uniform(shape=(bs, hidden_size), dtype=activation_dtype)

    paddle.set_default_dtype(activation_dtype)
    bias_attr = None if has_bias else False
    layer_te = te.LayerNorm(hidden_size=hidden_size, eps=eps, bias_attr=bias_attr)
    layer_pd = te.LayerNorm(
        hidden_size=hidden_size, eps=eps, bias_attr=bias_attr, backend="paddle"
    )

    # Both backends use the same parameter layout here, so copy verbatim.
    shared_params = ["weight", "bias"] if has_bias else ["weight"]
    for param_name in shared_params:
        getattr(layer_pd, param_name).copy_(getattr(layer_te, param_name), True)

    for layer in (layer_te, layer_pd):
        layer.weight.stop_gradient = no_wgrad
        if has_bias:
            layer.bias.stop_gradient = no_dbias

    out_ref, grad_input_ref = calc_output_and_grad(layer_pd, x, grad_out)
    out, grad_input = calc_output_and_grad(layer_te, x, grad_out)

    assert_allclose(out, out_ref, **tol)
    if not no_dgrad:
        assert_allclose(grad_input, grad_input_ref, **tol)
    if not no_wgrad:
        assert_allclose(layer_te.weight.grad, layer_pd.weight.grad, **tol)
    if has_bias and not no_dbias:
        assert_allclose(layer_te.bias.grad, layer_pd.bias.grad, **tol)
class TestLayerNormLinear:
    """
    Tests for LayerNormLinear layer
    """

    @staticmethod
    def _create_layer_pair(
        in_features,
        out_features,
        eps,
        normalization,
        has_bias,
        return_ln_out,
        no_wgrad,
        no_dbias,
    ):
        """
        Build a TE ``LayerNormLinear`` and a paddle-backend reference with equal parameters.

        Normalization parameters are copied verbatim. The linear weight is copied
        transposed, because the reference weight layout is the transpose of the TE one.

        ``ln_bias`` is only copied and configured when ``normalization == "LayerNorm"``.
        The ``stop_gradient`` flags are set on both layers:

        * weights (``weight``, ``ln_weight``) from ``no_wgrad``;
        * biases (``ln_bias``, ``bias``) from ``no_dbias``.

        Returns ``(layer_te, layer_pd)``.
        """
        has_ln_bias = normalization == "LayerNorm"
        layer_kwargs = {
            "in_features": in_features,
            "out_features": out_features,
            "eps": eps,
            "normalization": normalization,
            "bias_attr": None if has_bias else False,
            "return_layernorm_output": return_ln_out,
        }
        layer_te = te.LayerNormLinear(**layer_kwargs)
        layer_pd = te.LayerNormLinear(**layer_kwargs, backend="paddle")

        layer_pd.ln_weight.copy_(layer_te.ln_weight, True)
        if has_ln_bias:
            layer_pd.ln_bias.copy_(layer_te.ln_bias, True)
        layer_pd.weight.copy_(layer_te.weight.T, True)
        if has_bias:
            layer_pd.bias.copy_(layer_te.bias, True)

        for layer in (layer_te, layer_pd):
            layer.weight.stop_gradient = no_wgrad
            layer.ln_weight.stop_gradient = no_wgrad
            if has_ln_bias:
                layer.ln_bias.stop_gradient = no_dbias
            if has_bias:
                layer.bias.stop_gradient = no_dbias
        return layer_te, layer_pd

    @staticmethod
    def _assert_results_close(
        layer_te,
        layer_pd,
        results,
        results_ref,
        *,
        normalization,
        has_bias,
        no_dbias,
        no_dgrad,
        no_wgrad,
        return_ln_out,
        rtol,
        atol,
    ):
        """
        Compare TE outputs and gradients against the paddle-backend reference.

        ``results`` and ``results_ref`` are ``(out, ln_out, grad_input)`` triples as
        returned by ``calc_output_and_grad_ln_out``. Gradients are compared only where
        they were requested. The reference linear-weight gradient is transposed to
        match the TE layout.
        """
        out, ln_out, grad_input = results
        out_ref, ln_out_ref, grad_input_ref = results_ref
        has_ln_bias = normalization == "LayerNorm"

        assert_allclose(out, out_ref, rtol=rtol, atol=atol)
        if not no_dgrad:
            assert_allclose(grad_input, grad_input_ref, rtol=rtol, atol=atol)
        if not no_wgrad:
            assert_allclose(layer_te.weight.grad, layer_pd.weight.grad.T, rtol=rtol, atol=atol)
            assert_allclose(layer_te.ln_weight.grad, layer_pd.ln_weight.grad, rtol=rtol, atol=atol)
        if not no_dbias:
            if has_ln_bias:
                assert_allclose(layer_te.ln_bias.grad, layer_pd.ln_bias.grad, rtol=rtol, atol=atol)
            if has_bias:
                assert_allclose(layer_te.bias.grad, layer_pd.bias.grad, rtol=rtol, atol=atol)
        if return_ln_out:
            assert_allclose(ln_out, ln_out_ref, rtol=rtol, atol=atol)

    @staticmethod
    @pytest.mark.skipif(
        paddle.device.cuda.get_device_capability() < (8, 0),
        reason="BF16 Linear requires Ampere+ GPU",
    )
    @pytest.mark.parametrize("bs,in_features,out_features", LINEAR_CASES)
    @pytest.mark.parametrize("has_bias,no_dbias", [[True, False], [True, True], [False, False]])
    @pytest.mark.parametrize("no_dgrad", [True, False])
    @pytest.mark.parametrize("no_wgrad", [True, False])
    @pytest.mark.parametrize("return_ln_out", [True, False])
    @pytest.mark.parametrize("activation_dtype", ["bfloat16", "float32"])
    @pytest.mark.parametrize("normalization", ["RMSNorm", "LayerNorm"])
    def test_layernorm_linear_bf16(
        bs,
        in_features,
        out_features,
        has_bias,
        no_dbias,
        no_dgrad,
        no_wgrad,
        return_ln_out,
        activation_dtype,
        normalization,
    ):
        """
        Test BF16 LayerNormLinear Layer
        """
        paddle.set_default_dtype(activation_dtype)
        rtol = 5e-2
        atol = 5e-2

        input_tensor = paddle.uniform(shape=(bs, in_features), dtype=activation_dtype)
        input_tensor.stop_gradient = no_dgrad
        grad_out = paddle.uniform(shape=(bs, out_features), dtype=activation_dtype)
        eps = 1e-3

        layer_te, layer_pd = TestLayerNormLinear._create_layer_pair(
            in_features,
            out_features,
            eps,
            normalization,
            has_bias,
            return_ln_out,
            no_wgrad,
            no_dbias,
        )

        results_ref = calc_output_and_grad_ln_out(
            layer_pd, input_tensor, grad_out, return_ln_out=return_ln_out
        )
        results = calc_output_and_grad_ln_out(
            layer_te, input_tensor, grad_out, return_ln_out=return_ln_out
        )

        TestLayerNormLinear._assert_results_close(
            layer_te,
            layer_pd,
            results,
            results_ref,
            normalization=normalization,
            has_bias=has_bias,
            no_dbias=no_dbias,
            no_dgrad=no_dgrad,
            no_wgrad=no_wgrad,
            return_ln_out=return_ln_out,
            rtol=rtol,
            atol=atol,
        )

    @staticmethod
    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize("bs,in_features,out_features", LINEAR_CASES)
    @pytest.mark.parametrize("has_bias,no_dbias", [[True, False], [True, True], [False, False]])
    @pytest.mark.parametrize("no_dgrad", [True, False])
    @pytest.mark.parametrize("no_wgrad", [True, False])
    @pytest.mark.parametrize("fp8_wgrad", [True, False])
    @pytest.mark.parametrize("do_calibration", [True, False])
    @pytest.mark.parametrize("return_ln_out", [True, False])
    @pytest.mark.parametrize("activation_dtype", ["bfloat16", "float32"])
    @pytest.mark.parametrize("normalization", ["RMSNorm", "LayerNorm"])
    def test_layernorm_linear_fp8(
        bs,
        in_features,
        out_features,
        has_bias,
        no_dbias,
        no_dgrad,
        no_wgrad,
        fp8_wgrad,
        do_calibration,
        return_ln_out,
        activation_dtype,
        normalization,
    ):
        """
        Test FP8 LayerNormLinear Layer
        """
        paddle.set_default_dtype(activation_dtype)
        rtol = 0.1
        atol = 0.75

        input_tensor = paddle.uniform(shape=(bs, in_features), dtype=activation_dtype)
        input_tensor.stop_gradient = no_dgrad
        grad_out = paddle.uniform(shape=(bs, out_features), dtype=activation_dtype)
        eps = 1e-3

        recipe = DelayedScaling(override_linear_precision=(False, False, not fp8_wgrad))

        layer_te, layer_pd = TestLayerNormLinear._create_layer_pair(
            in_features,
            out_features,
            eps,
            normalization,
            has_bias,
            return_ln_out,
            no_wgrad,
            no_dbias,
        )

        with fp8_autocast(
            enabled=not do_calibration, calibrating=do_calibration, fp8_recipe=recipe
        ):
            results_ref = calc_output_and_grad_ln_out(
                layer_pd, input_tensor, grad_out, return_ln_out=return_ln_out
            )
            results = calc_output_and_grad_ln_out(
                layer_te, input_tensor, grad_out, return_ln_out=return_ln_out
            )

        TestLayerNormLinear._assert_results_close(
            layer_te,
            layer_pd,
            results,
            results_ref,
            normalization=normalization,
            has_bias=has_bias,
            no_dbias=no_dbias,
            no_dgrad=no_dgrad,
            no_wgrad=no_wgrad,
            return_ln_out=return_ln_out,
            rtol=rtol,
            atol=atol,
        )
        if do_calibration:
            # Calibration mode must have recorded forward amax values.
            assert paddle.count_nonzero(layer_te.fp8_meta["scaling_fwd"].amax_history).item() > 0

    @staticmethod
    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize("bs,in_features,out_features", LINEAR_CASES)
    @pytest.mark.parametrize("activation_dtype", ["bfloat16"])
    @pytest.mark.parametrize("num_microbatch", [8])
    def test_layernorm_linear_fp8_microbatch(
        bs, in_features, out_features, activation_dtype, num_microbatch
    ):
        """
        Test FP8 LayerNormLinear Layer with ``is_first_microbatch``.

        Two layers share identical parameters. One is called with
        ``is_first_microbatch=(iteration == 0)`` and the other without it. Their
        outputs, weight gradients and ln_weight gradients must agree after every
        microbatch.
        """
        paddle.set_default_dtype(activation_dtype)
        eps = 1e-3
        rtol = 0.5
        atol = 0.5

        recipe = DelayedScaling()

        layer_cached = te.LayerNormLinear(
            in_features=in_features,
            out_features=out_features,
            eps=eps,
        )

        layer_normal = te.LayerNormLinear(
            in_features=in_features,
            out_features=out_features,
            eps=eps,
        )

        layer_cached.ln_weight.copy_(layer_normal.ln_weight, True)
        layer_cached.ln_bias.copy_(layer_normal.ln_bias, True)
        layer_cached.weight.copy_(layer_normal.weight, True)
        layer_cached.bias.copy_(layer_normal.bias, True)

        for iteration in range(num_microbatch):
            input_tensor = paddle.uniform(shape=(bs, in_features), dtype=activation_dtype)
            grad_out = paddle.uniform(shape=(bs, out_features), dtype=activation_dtype)

            with fp8_autocast(enabled=True, fp8_recipe=recipe):
                out = layer_cached(input_tensor, is_first_microbatch=(iteration == 0))
                out.backward(grad_out)

            with fp8_autocast(enabled=True, fp8_recipe=recipe):
                out_ref = layer_normal(input_tensor)
                out_ref.backward(grad_out)

            assert_allclose(out, out_ref, rtol=rtol, atol=atol)
            assert_allclose(
                layer_cached.weight.grad, layer_normal.weight.grad, rtol=rtol, atol=atol
            )
            assert_allclose(
                layer_cached.ln_weight.grad, layer_normal.ln_weight.grad, rtol=rtol, atol=atol
            )
class TestLayerNormMLP:
561
    """
562
    Test LayerNormMLP Layer
563
564
    """

565
    @staticmethod
    @pytest.mark.skipif(
        paddle.device.cuda.get_device_capability() < (8, 0),
        reason="BF16 Linear requires Ampere+ GPU",
    )
    @pytest.mark.parametrize("bs,hidden_size,ffn_hidden_size", LINEAR_CASES)
    @pytest.mark.parametrize("has_bias,no_dbias", [[True, False], [True, True], [False, False]])
    @pytest.mark.parametrize("no_dgrad", [True, False])
    @pytest.mark.parametrize("no_wgrad", [True, False])
    @pytest.mark.parametrize("return_ln_out", [True, False])
    @pytest.mark.parametrize("activation_dtype", ["bfloat16", "float32"])
    @pytest.mark.parametrize("normalization", ["RMSNorm", "LayerNorm"])
    @pytest.mark.parametrize("activation", ["gelu", "swiglu"])
    def test_layernorm_mlp_bf16(
        bs,
        hidden_size,
        ffn_hidden_size,
        has_bias,
        no_dbias,
        no_dgrad,
        no_wgrad,
        return_ln_out,
        activation_dtype,
        normalization,
        activation,
    ):
        """
        Compare the TE ``LayerNormMLP`` against its paddle-backend counterpart.

        Both layers start from identical parameters; the FC weights are copied
        transposed. The forward output is always checked. Input, weight and bias
        gradients, plus the optional normalization output, are checked only when
        requested.
        """
        paddle.set_default_dtype(activation_dtype)
        tol = {"rtol": 5e-2, "atol": 5e-2}

        input_tensor = paddle.uniform(shape=(bs, hidden_size), dtype=activation_dtype)
        input_tensor.stop_gradient = no_dgrad
        grad_out = paddle.uniform(shape=(bs, hidden_size), dtype=activation_dtype)
        has_ln_bias = normalization == "LayerNorm"

        layer_kwargs = {
            "hidden_size": hidden_size,
            "ffn_hidden_size": ffn_hidden_size,
            "eps": 1e-3,
            "normalization": normalization,
            "activation": activation,
            "bias_attr": None if has_bias else False,
            "return_layernorm_output": return_ln_out,
        }
        layer_te = te.LayerNormMLP(**layer_kwargs)
        layer_pd = te.LayerNormMLP(**layer_kwargs, backend="paddle")

        ln_param_names = ["ln_weight", "ln_bias"] if has_ln_bias else ["ln_weight"]
        fc_weight_names = ("fc1_weight", "fc2_weight")
        fc_bias_names = ["fc1_bias", "fc2_bias"] if has_bias else []

        # Normalization params and FC biases share layout; FC weights are transposed.
        for name in ln_param_names:
            getattr(layer_pd, name).copy_(getattr(layer_te, name), True)
        for name in fc_weight_names:
            getattr(layer_pd, name).copy_(getattr(layer_te, name).T, True)
        for name in fc_bias_names:
            getattr(layer_pd, name).copy_(getattr(layer_te, name), True)

        bias_names = ln_param_names[1:] + fc_bias_names
        for layer in (layer_te, layer_pd):
            for name in (*fc_weight_names, "ln_weight"):
                getattr(layer, name).stop_gradient = no_wgrad
            for name in bias_names:
                getattr(layer, name).stop_gradient = no_dbias

        out_ref, ln_out_ref, grad_input_ref = calc_output_and_grad_ln_out(
            layer_pd, input_tensor, grad_out, return_ln_out=return_ln_out
        )
        out, ln_out, grad_input = calc_output_and_grad_ln_out(
            layer_te, input_tensor, grad_out, return_ln_out=return_ln_out
        )

        assert_allclose(out, out_ref, **tol)
        if not no_dgrad:
            assert_allclose(grad_input, grad_input_ref, **tol)
        if not no_wgrad:
            assert_allclose(layer_te.ln_weight.grad, layer_pd.ln_weight.grad, **tol)
            for name in fc_weight_names:
                assert_allclose(
                    getattr(layer_te, name).grad, getattr(layer_pd, name).grad.T, **tol
                )
        if not no_dbias:
            for name in bias_names:
                assert_allclose(getattr(layer_te, name).grad, getattr(layer_pd, name).grad, **tol)
        if return_ln_out:
            assert_allclose(ln_out, ln_out_ref, **tol)
    @staticmethod
    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize("bs,hidden_size,ffn_hidden_size", LINEAR_CASES)
    @pytest.mark.parametrize("has_bias,no_dbias", [[True, False], [True, True], [False, False]])
    @pytest.mark.parametrize("no_dgrad", [True, False])
    @pytest.mark.parametrize("no_wgrad", [True, False])
    @pytest.mark.parametrize("fp8_wgrad", [True, False])
    @pytest.mark.parametrize("do_calibration", [True, False])
    @pytest.mark.parametrize("return_ln_out", [True, False])
    @pytest.mark.parametrize("activation_dtype", ["bfloat16", "float32"])
    @pytest.mark.parametrize("normalization", ["RMSNorm", "LayerNorm"])
    @pytest.mark.parametrize("activation", ["gelu", "swiglu"])
    def test_layernorm_mlp_fp8(
        bs,
        hidden_size,
        ffn_hidden_size,
        has_bias,
        no_dbias,
        no_dgrad,
        no_wgrad,
        fp8_wgrad,
        do_calibration,
        return_ln_out,
        activation_dtype,
        normalization,
        activation,
    ):
        """
        Compare the TE ``LayerNormMLP`` against the paddle backend under FP8.

        Both forward/backward passes run inside one ``fp8_autocast`` context, either
        in FP8 mode or in calibration mode depending on ``do_calibration``.
        ``fp8_wgrad`` toggles FP8 for the weight-gradient GEMM via the recipe's
        ``override_linear_precision``. In calibration mode, the TE layer's forward
        amax history must also end up non-zero.
        """
        paddle.set_default_dtype(activation_dtype)
        tol = {"rtol": 0.1, "atol": 0.75}

        input_tensor = paddle.uniform(shape=(bs, hidden_size), dtype=activation_dtype)
        input_tensor.stop_gradient = no_dgrad
        grad_out = paddle.uniform(shape=(bs, hidden_size), dtype=activation_dtype)
        has_ln_bias = normalization == "LayerNorm"

        recipe = DelayedScaling(override_linear_precision=(False, False, not fp8_wgrad))

        layer_kwargs = {
            "hidden_size": hidden_size,
            "ffn_hidden_size": ffn_hidden_size,
            "eps": 1e-3,
            "normalization": normalization,
            "activation": activation,
            "bias_attr": None if has_bias else False,
            "return_layernorm_output": return_ln_out,
        }
        layer_te = te.LayerNormMLP(**layer_kwargs)
        layer_pd = te.LayerNormMLP(**layer_kwargs, backend="paddle")

        ln_param_names = ["ln_weight", "ln_bias"] if has_ln_bias else ["ln_weight"]
        fc_weight_names = ("fc1_weight", "fc2_weight")
        fc_bias_names = ["fc1_bias", "fc2_bias"] if has_bias else []

        # Normalization params and FC biases share layout; FC weights are transposed.
        for name in ln_param_names:
            getattr(layer_pd, name).copy_(getattr(layer_te, name), True)
        for name in fc_weight_names:
            getattr(layer_pd, name).copy_(getattr(layer_te, name).T, True)
        for name in fc_bias_names:
            getattr(layer_pd, name).copy_(getattr(layer_te, name), True)

        bias_names = ln_param_names[1:] + fc_bias_names
        for layer in (layer_te, layer_pd):
            for name in (*fc_weight_names, "ln_weight"):
                getattr(layer, name).stop_gradient = no_wgrad
            for name in bias_names:
                getattr(layer, name).stop_gradient = no_dbias

        with fp8_autocast(
            enabled=not do_calibration, calibrating=do_calibration, fp8_recipe=recipe
        ):
            out_ref, ln_out_ref, grad_input_ref = calc_output_and_grad_ln_out(
                layer_pd, input_tensor, grad_out, return_ln_out=return_ln_out
            )
            out, ln_out, grad_input = calc_output_and_grad_ln_out(
                layer_te, input_tensor, grad_out, return_ln_out=return_ln_out
            )

        assert_allclose(out, out_ref, **tol)
        if not no_dgrad:
            assert_allclose(grad_input, grad_input_ref, **tol)
        if not no_wgrad:
            assert_allclose(layer_te.ln_weight.grad, layer_pd.ln_weight.grad, **tol)
            for name in fc_weight_names:
                assert_allclose(
                    getattr(layer_te, name).grad, getattr(layer_pd, name).grad.T, **tol
                )
        if not no_dbias:
            for name in bias_names:
                assert_allclose(getattr(layer_te, name).grad, getattr(layer_pd, name).grad, **tol)
        if return_ln_out:
            assert_allclose(ln_out, ln_out_ref, **tol)

        if do_calibration:
            assert paddle.count_nonzero(layer_te.fp8_meta["scaling_fwd"].amax_history).item() > 0
Shijie's avatar
Shijie committed
800

801
802
    @staticmethod
    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize("bs,hidden_size,ffn_hidden_size", LINEAR_CASES)
    @pytest.mark.parametrize("activation_dtype", ["bfloat16"])
    @pytest.mark.parametrize("num_microbatch", [8])
    def test_layernorm_mlp_fp8_microbatch(
        bs, hidden_size, ffn_hidden_size, activation_dtype, num_microbatch
    ):
        """
        Test FP8 LayerNormMLP Layer across several micro-batches.

        Two LayerNormMLP layers with identical parameters are calibrated on the same
        input. Then, for each micro-batch, one layer is called with
        ``is_first_microbatch`` (True only on the first step) and the other without it.
        Outputs and the ln/fc1/fc2 weight gradients must match. Gradients are never
        cleared between steps, so they accumulate identically in both layers.
        """
        paddle.set_default_dtype(activation_dtype)
        tolerances = {"rtol": 1e-5, "atol": 1e-5}
        eps = 1e-3
        fp8_recipe = DelayedScaling()

        def _make_layer():
            return te.LayerNormMLP(
                hidden_size=hidden_size,
                ffn_hidden_size=ffn_hidden_size,
                eps=eps,
            )

        layer_cached = _make_layer()
        layer_normal = _make_layer()

        # Give the reference layer exactly the same parameters as the cached one.
        for param_name in (
            "ln_weight",
            "ln_bias",
            "fc1_weight",
            "fc2_weight",
            "fc1_bias",
            "fc2_bias",
        ):
            getattr(layer_normal, param_name).copy_(getattr(layer_cached, param_name), True)

        # Calibration to make sure weight scale is the same
        calib_input = paddle.uniform(shape=(bs, hidden_size), dtype=activation_dtype)
        for layer in (layer_cached, layer_normal):
            with fp8_autocast(enabled=False, calibrating=True, fp8_recipe=fp8_recipe):
                _ = layer(calib_input)

        for step in range(num_microbatch):
            x = paddle.uniform(shape=(bs, hidden_size), dtype=activation_dtype)
            dy = paddle.uniform(shape=(bs, hidden_size), dtype=activation_dtype)

            with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
                out = layer_cached(x, is_first_microbatch=(step == 0))
                out.backward(dy)

            with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
                out_ref = layer_normal(x)
                out_ref.backward(dy)

            assert_allclose(out, out_ref, **tolerances)
            for param_name in ("ln_weight", "fc1_weight", "fc2_weight"):
                assert_allclose(
                    getattr(layer_cached, param_name).grad,
                    getattr(layer_normal, param_name).grad,
                    **tolerances,
                )


@pytest.mark.parametrize("bs", [1, 2])
@pytest.mark.parametrize("hidden_size, num_heads", [[1024, 16]])
@pytest.mark.parametrize("q_seqlen, kv_seqlen", [[1024, 1024]])
@pytest.mark.parametrize("attn_type", ["self", "cross"])
@pytest.mark.parametrize("mask_type", ["causal", "padding"])
@pytest.mark.parametrize("math_dtype", ["bfloat16", "float16"])
@pytest.mark.parametrize("deterministic", [True, False])
def test_dot_product_attention(
    bs, hidden_size, num_heads, q_seqlen, kv_seqlen, attn_type, mask_type, math_dtype, deterministic
):
    """
    Test DotProductAttention Layer

    Compares the ``transformer_engine`` backend against the ``paddle`` backend on
    randomly padded inputs. Reference outputs and gradients are zeroed past each
    sample's actual sequence length before comparison. When ``deterministic`` is
    set, the ``transformer_engine`` backend is run a second time and must reproduce
    its first results to within 1e-12. The environment variable used to request
    this is always restored, even if an assertion fails.
    """
    paddle.set_default_dtype(math_dtype)
    rtol = 1e-4
    atol = 2e-2
    head_size = hidden_size // num_heads

    # Skip if cuDNN fused attention is not supported
    if not is_fused_attention_supported(
        num_heads=num_heads,
        num_gqa_groups=num_heads,
        q_seqlen=q_seqlen,
        kv_seqlen=kv_seqlen,
        head_size=head_size,
        dtype=math_dtype,
        dropout=0.0,
        qkv_layout="bshd_bshd_bshd",
        bias_type="no_bias",
        mask_type=mask_type,
    ):
        pytest.skip("cuDNN fused attention is not supported")

    attn_q_input = paddle.normal(
        mean=0.0, std=0.02, shape=(bs, q_seqlen, num_heads, head_size)
    ).astype(math_dtype)
    attn_k_input = paddle.normal(
        mean=0.0, std=0.02, shape=(bs, kv_seqlen, num_heads, head_size)
    ).astype(math_dtype)
    attn_v_input = paddle.normal(
        mean=0.0, std=0.02, shape=(bs, kv_seqlen, num_heads, head_size)
    ).astype(math_dtype)

    # Per-sample valid lengths. Self-attention shares the query lengths for keys/values.
    q_actual_seqlen = paddle.randint(low=20, high=q_seqlen, shape=(bs,), dtype="int32")
    kv_actual_seqlen = (
        paddle.randint(low=20, high=kv_seqlen, shape=(bs,), dtype="int32")
        if attn_type == "cross"
        else q_actual_seqlen
    )
    # The mask starts all-True. Each sample's valid [0:q_len, 0:kv_len] window is cleared below.
    attn_mask = paddle.ones(shape=(bs, 1, q_seqlen, kv_seqlen), dtype="bool")

    grad_out = paddle.normal(mean=0.0, std=0.02, shape=(bs, q_seqlen, num_heads, head_size)).astype(
        "float32"
    )
    # No upstream gradient flows into padded query positions.
    for i in range(0, bs):
        grad_out[i, q_actual_seqlen[i] :, :, :] = 0
    grad_out = grad_out.astype(math_dtype)

    for i in range(0, bs):
        attn_mask[i, 0, 0 : q_actual_seqlen[i], 0 : kv_actual_seqlen[i]] = False

    # Disallow non-deterministic algorithms for this test only. Remember any
    # pre-existing value so the process environment is left exactly as found,
    # whether the test passes or fails.
    env_key = "NVTE_ALLOW_NONDETERMINISTIC_ALGO"
    prev_env_value = os.environ.get(env_key)
    if deterministic:
        os.environ[env_key] = "0"

    try:
        layer_te = te.DotProductAttention(
            num_heads,
            head_size,
            attention_dropout=0.0,
            attn_mask_type=mask_type,
            attention_type=attn_type,
            backend="transformer_engine",
        )
        layer_pd = te.DotProductAttention(
            num_heads,
            head_size,
            attention_dropout=0.0,
            attn_mask_type=mask_type,
            attention_type=attn_type,
            backend="paddle",
        )

        def calc_attn_output_and_grad(layer, q, k, v, mask, dout):
            """Run forward + backward on fresh leaf copies of q/k/v; return output and their grads."""
            _q = paddle.to_tensor(q, stop_gradient=False)
            _k = paddle.to_tensor(k, stop_gradient=False)
            _v = paddle.to_tensor(v, stop_gradient=False)

            out = layer(_q, _k, _v, mask)
            out.backward(dout)
            return out, _q.grad, _k.grad, _v.grad

        out, q_grad, k_grad, v_grad = calc_attn_output_and_grad(
            layer_te, attn_q_input, attn_k_input, attn_v_input, attn_mask, grad_out
        )
        out_ref, q_grad_ref, k_grad_ref, v_grad_ref = calc_attn_output_and_grad(
            layer_pd, attn_q_input, attn_k_input, attn_v_input, attn_mask, grad_out
        )

        # Keep only the valid (unpadded) region of the reference; everything else is zero.
        valid_out_ref = paddle.full_like(out_ref, 0)
        for i in range(0, bs):
            valid_out_ref[i, 0 : q_actual_seqlen[i], :, :] = out_ref[i, 0 : q_actual_seqlen[i], :, :]

        valid_q_grad_ref = paddle.full_like(q_grad_ref, 0)
        valid_k_grad_ref = paddle.full_like(k_grad_ref, 0)
        valid_v_grad_ref = paddle.full_like(v_grad_ref, 0)
        for i in range(0, bs):
            valid_q_grad_ref[i, 0 : q_actual_seqlen[i], :, :] = q_grad_ref[
                i, 0 : q_actual_seqlen[i], :, :
            ]
            valid_k_grad_ref[i, 0 : kv_actual_seqlen[i], :, :] = k_grad_ref[
                i, 0 : kv_actual_seqlen[i], :, :
            ]
            valid_v_grad_ref[i, 0 : kv_actual_seqlen[i], :, :] = v_grad_ref[
                i, 0 : kv_actual_seqlen[i], :, :
            ]

        assert_allclose(out, valid_out_ref, rtol=rtol, atol=atol)
        assert_allclose(q_grad, valid_q_grad_ref, rtol=rtol, atol=atol)
        assert_allclose(k_grad, valid_k_grad_ref, rtol=rtol, atol=atol)
        assert_allclose(v_grad, valid_v_grad_ref, rtol=rtol, atol=atol)

        if deterministic:
            # A second run must reproduce the first one bit-for-bit (within 1e-12).
            out2, q_grad2, k_grad2, v_grad2 = calc_attn_output_and_grad(
                layer_te, attn_q_input, attn_k_input, attn_v_input, attn_mask, grad_out
            )
            assert_allclose(out, out2, rtol=1e-12, atol=1e-12)
            assert_allclose(q_grad, q_grad2, rtol=1e-12, atol=1e-12)
            assert_allclose(k_grad, k_grad2, rtol=1e-12, atol=1e-12)
            assert_allclose(v_grad, v_grad2, rtol=1e-12, atol=1e-12)
    finally:
        if deterministic:
            if prev_env_value is None:
                os.environ.pop(env_key, None)
            else:
                os.environ[env_key] = prev_env_value
Shijie's avatar
Shijie committed
998
999


1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
@pytest.mark.parametrize("bs", [1, 2])
@pytest.mark.parametrize("num_gqa_groups", [1, 2, 4])
@pytest.mark.parametrize("hidden_size, num_heads, ffn_hidden_size", [[256, 4, 1024]])
@pytest.mark.parametrize("q_seqlen, kv_seqlen", [[1024, 1024]])
@pytest.mark.parametrize("has_bias, no_dbias", [[False, True], [True, True], [True, False]])
@pytest.mark.parametrize("no_wgrad", [True, False])
@pytest.mark.parametrize("mask_type", ["causal", "padding"])
@pytest.mark.parametrize("math_dtype", ["bfloat16", "float16"])
@pytest.mark.parametrize("output_layernorm", [True, False])
@pytest.mark.parametrize("return_layernorm_output", [True, False])
@pytest.mark.parametrize("normalization", ["RMSNorm", "LayerNorm"])
def test_transformer_encoder_layer(
    bs,
    hidden_size,
    num_heads,
    num_gqa_groups,
    ffn_hidden_size,
    has_bias,
    no_dbias,
    no_wgrad,
    q_seqlen,
    kv_seqlen,
    mask_type,
    math_dtype,
    output_layernorm,
    return_layernorm_output,
    normalization,
):
    """
    Test Transformer Encoder Layer

    Builds the same encoder ``TransformerLayer`` with the ``transformer_engine`` and
    ``paddle`` backends and copies every TE parameter into the paddle twin.
    Weights are transposed where the layouts differ. The test then compares
    outputs, input gradients and, when enabled, the QKV projection's weight and
    bias gradients.
    """
    paddle.set_default_dtype(math_dtype)
    rtol = atol = 5e-2
    eps = 1e-3
    has_ln_bias = normalization == "LayerNorm"

    # Skip if cuDNN fused attention is not supported
    if not is_fused_attention_supported(
        num_heads=num_heads,
        num_gqa_groups=num_gqa_groups,
        q_seqlen=q_seqlen,
        kv_seqlen=kv_seqlen,
        head_size=hidden_size // num_heads,
        dtype=math_dtype,
        dropout=0.0,
        qkv_layout="bshd_bshd_bshd",
        bias_type="no_bias",
        mask_type=mask_type,
    ):
        pytest.skip("cuDNN fused attention is not supported")

    encoder_input = paddle.uniform(shape=(bs, q_seqlen, hidden_size), dtype=math_dtype)

    # Every sample uses the full sequence length here.
    q_actual_seqlen = paddle.ones(shape=(bs,), dtype="int32") * q_seqlen
    kv_actual_seqlen = q_actual_seqlen
    attn_mask = paddle.ones(shape=(bs, 1, q_seqlen, kv_seqlen), dtype="bool")

    grad_out = paddle.normal(mean=0.0, std=0.02, shape=(bs, q_seqlen, hidden_size)).astype(
        "float32"
    )
    for sample in range(bs):
        grad_out[sample, q_actual_seqlen[sample] :, :] = 0
        attn_mask[sample, 0, 0 : q_actual_seqlen[sample], 0 : kv_actual_seqlen[sample]] = False
    grad_out = grad_out.astype(math_dtype)

    shared_layer_kwargs = dict(
        num_gqa_groups=num_gqa_groups,
        layernorm_epsilon=eps,
        hidden_dropout=0.0,
        attention_dropout=0.0,
        weight_attr=None,
        bias_attr=None if has_bias else False,
        self_attn_mask_type=mask_type,
        apply_residual_connection_post_layernorm=return_layernorm_output,
        output_layernorm=output_layernorm,
        layer_type="encoder",
        normalization=normalization,
    )
    layer_te = te.TransformerLayer(
        hidden_size, ffn_hidden_size, num_heads, backend="transformer_engine", **shared_layer_kwargs
    )
    layer_pd = te.TransformerLayer(
        hidden_size, ffn_hidden_size, num_heads, backend="paddle", **shared_layer_kwargs
    )

    def _tie(pd_param, te_param, stop_gradient, transpose=False):
        """Copy ``te_param`` (optionally transposed) into ``pd_param``; set both stop_gradient flags."""
        pd_param.copy_(te_param.T if transpose else te_param, True)
        pd_param.stop_gradient = stop_gradient
        te_param.stop_gradient = stop_gradient

    te_attn = layer_te.self_attention
    pd_attn = layer_pd.self_attention

    # MultiHeadAttention params
    if output_layernorm:
        _tie(pd_attn.qkv.weight, te_attn.qkv.weight, no_wgrad, transpose=True)
        if has_bias:
            _tie(pd_attn.qkv.bias, te_attn.qkv.bias, no_dbias)
    else:
        te_ln_qkv = te_attn.layernorm_qkv
        pd_ln_qkv = pd_attn.layernorm_qkv
        _tie(pd_ln_qkv.ln_weight, te_ln_qkv.ln_weight, no_wgrad)
        _tie(pd_ln_qkv.weight, te_ln_qkv.weight, no_wgrad, transpose=True)
        if has_ln_bias:
            _tie(pd_ln_qkv.ln_bias, te_ln_qkv.ln_bias, no_dbias)
        if has_bias:
            _tie(pd_ln_qkv.bias, te_ln_qkv.bias, no_dbias)

    _tie(pd_attn.proj.weight, te_attn.proj.weight, no_wgrad, transpose=True)
    if has_bias:
        _tie(pd_attn.proj.bias, te_attn.proj.bias, no_dbias)

    # LayerNorm MLP params
    te_mlp = layer_te.layernorm_mlp
    pd_mlp = layer_pd.layernorm_mlp
    _tie(pd_mlp.ln_weight, te_mlp.ln_weight, no_wgrad)
    _tie(pd_mlp.fc1_weight, te_mlp.fc1_weight, no_wgrad, transpose=True)
    _tie(pd_mlp.fc2_weight, te_mlp.fc2_weight, no_wgrad, transpose=True)
    if has_ln_bias:
        _tie(pd_mlp.ln_bias, te_mlp.ln_bias, no_dbias)
    if has_bias:
        _tie(pd_mlp.fc1_bias, te_mlp.fc1_bias, no_dbias)
        _tie(pd_mlp.fc2_bias, te_mlp.fc2_bias, no_dbias)

    if output_layernorm:
        _tie(layer_pd.layernorm.weight, layer_te.layernorm.weight, no_wgrad)
        _tie(layer_pd.layernorm.bias, layer_te.layernorm.bias, no_dbias)

    def calc_transformer_output_and_grad(layer, encoder_input, mask, dout):
        """Forward + backward on a fresh leaf copy of ``encoder_input``."""
        x = paddle.to_tensor(encoder_input, stop_gradient=False)
        y = layer(x, mask)
        y.backward(dout)
        return y, x.grad

    out_ref, grad_input_ref = calc_transformer_output_and_grad(
        layer_pd, encoder_input, attn_mask, grad_out
    )
    out, grad_input = calc_transformer_output_and_grad(layer_te, encoder_input, attn_mask, grad_out)

    assert_allclose(out, out_ref, rtol=rtol, atol=atol)
    assert_allclose(grad_input, grad_input_ref, rtol=rtol, atol=atol)

    # Only the QKV projection's parameter gradients are cross-checked.
    te_qkv = te_attn.qkv if output_layernorm else te_attn.layernorm_qkv
    pd_qkv = pd_attn.qkv if output_layernorm else pd_attn.layernorm_qkv
    if not no_wgrad:
        assert_allclose(te_qkv.weight.grad, pd_qkv.weight.grad.T, rtol=rtol, atol=atol)
    if not no_dbias:
        assert_allclose(te_qkv.bias.grad, pd_qkv.bias.grad, rtol=0.01, atol=0.5)


@pytest.mark.parametrize("bs", [1, 2])
@pytest.mark.parametrize("num_gqa_groups", [1, 2, 4])
@pytest.mark.parametrize("hidden_size, num_heads, ffn_hidden_size", [[256, 4, 1024]])
@pytest.mark.parametrize("q_seqlen, kv_seqlen", [[1024, 1024]])
@pytest.mark.parametrize("has_bias, no_dbias", [[False, True], [True, True], [True, False]])
@pytest.mark.parametrize("no_wgrad", [True, False])
@pytest.mark.parametrize("mask_type", ["causal", "padding"])
@pytest.mark.parametrize("math_dtype", ["bfloat16", "float16"])
@pytest.mark.parametrize("output_layernorm", [True, False])
@pytest.mark.parametrize("return_layernorm_output", [True, False])
@pytest.mark.parametrize("recompute_core_attention", [True, False])
@pytest.mark.parametrize("normalization", ["RMSNorm", "LayerNorm"])
def test_transformer_decoder_layer(
    bs,
    hidden_size,
    num_heads,
    num_gqa_groups,
    ffn_hidden_size,
    has_bias,
    no_dbias,
    no_wgrad,
    q_seqlen,
    kv_seqlen,
    mask_type,
    math_dtype,
    output_layernorm,
    return_layernorm_output,
    recompute_core_attention,
    normalization,
):
Shijie's avatar
Shijie committed
1249
1250
1251
1252
1253
    """
    Test Transformer Decoder Layer
    """
    paddle.set_default_dtype(math_dtype)
    rtol = 5e-2
1254
    atol = 6e-2
Shijie's avatar
Shijie committed
1255
    eps = 1e-3
1256
    has_ln_bias = normalization == "LayerNorm"
Shijie's avatar
Shijie committed
1257

Tim Moon's avatar
Tim Moon committed
1258
1259
    # Skip if cuDNN fused attention is not supported
    if not is_fused_attention_supported(
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
        num_heads=num_heads,
        num_gqa_groups=num_gqa_groups,
        q_seqlen=q_seqlen,
        kv_seqlen=kv_seqlen,
        head_size=hidden_size // num_heads,
        dtype=math_dtype,
        dropout=0.0,
        qkv_layout="bshd_bshd_bshd",
        bias_type="no_bias",
        mask_type=mask_type,
Tim Moon's avatar
Tim Moon committed
1270
1271
1272
    ):
        pytest.skip("cuDNN fused attention is not supported")

1273
1274
1275
1276
1277
1278
    encoder_input = paddle.normal(mean=0.0, std=0.1, shape=(bs, q_seqlen, hidden_size)).astype(
        math_dtype
    )
    encoder_output = paddle.normal(mean=0.0, std=0.1, shape=(bs, kv_seqlen, hidden_size)).astype(
        math_dtype
    )
Shijie's avatar
Shijie committed
1279

1280
    q_actual_seqlen = paddle.ones(shape=(bs,), dtype="int32") * q_seqlen
Shijie's avatar
Shijie committed
1281
    kv_actual_seqlen = q_actual_seqlen
1282
    attn_mask = paddle.ones(shape=(bs, 1, q_seqlen, kv_seqlen), dtype="bool")
Shijie's avatar
Shijie committed
1283

1284
1285
1286
    grad_out = paddle.normal(mean=0.0, std=0.01, shape=(bs, q_seqlen, hidden_size)).astype(
        "float32"
    )
Shijie's avatar
Shijie committed
1287
1288
1289
1290
1291
1292

    # rounding to avoid numerical issues
    encoder_input = paddle.round(encoder_input * 1000) / 1000
    encoder_output = paddle.round(encoder_output * 1000) / 1000
    grad_out = paddle.round(grad_out * 1000) / 1000

Shijie's avatar
Shijie committed
1293
    for i in range(0, bs):
1294
        grad_out[i, q_actual_seqlen[i] :, :] = 0
Shijie's avatar
Shijie committed
1295
1296
1297
    grad_out = grad_out.astype(math_dtype)

    for i in range(0, bs):
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
        attn_mask[i, 0, 0 : q_actual_seqlen[i], 0 : kv_actual_seqlen[i]] = False

    layer_te = te.TransformerLayer(
        hidden_size,
        ffn_hidden_size,
        num_heads,
        num_gqa_groups=num_gqa_groups,
        layernorm_epsilon=eps,
        hidden_dropout=0.0,
        attention_dropout=0.0,
        weight_attr=None,
        bias_attr=None if has_bias else False,
        self_attn_mask_type=mask_type,
        apply_residual_connection_post_layernorm=return_layernorm_output,
        output_layernorm=output_layernorm,
        layer_type="decoder",
        normalization=normalization,
        backend="transformer_engine",
    )
    layer_pd = te.TransformerLayer(
        hidden_size,
        ffn_hidden_size,
        num_heads,
        num_gqa_groups=num_gqa_groups,
        layernorm_epsilon=eps,
        hidden_dropout=0.0,
        attention_dropout=0.0,
        weight_attr=None,
        bias_attr=None if has_bias else False,
        self_attn_mask_type=mask_type,
        apply_residual_connection_post_layernorm=return_layernorm_output,
        output_layernorm=output_layernorm,
        layer_type="decoder",
        normalization=normalization,
        backend="paddle",
    )
Shijie's avatar
Shijie committed
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345

    # MultiHeadAttention params - self attn
    if output_layernorm:
        layer_pd.self_attention.qkv.weight.copy_(layer_te.self_attention.qkv.weight.T, True)
        layer_pd.self_attention.qkv.weight.stop_gradient = no_wgrad
        layer_te.self_attention.qkv.weight.stop_gradient = no_wgrad
        if has_bias:
            layer_pd.self_attention.qkv.bias.copy_(layer_te.self_attention.qkv.bias, True)
            layer_pd.self_attention.qkv.bias.stop_gradient = no_dbias
            layer_te.self_attention.qkv.bias.stop_gradient = no_dbias
    else:
        layer_pd.self_attention.layernorm_qkv.ln_weight.copy_(
1346
1347
            layer_te.self_attention.layernorm_qkv.ln_weight, True
        )
Shijie's avatar
Shijie committed
1348
        layer_pd.self_attention.layernorm_qkv.weight.copy_(
1349
1350
            layer_te.self_attention.layernorm_qkv.weight.T, True
        )
Shijie's avatar
Shijie committed
1351
1352
1353
1354
        layer_pd.self_attention.layernorm_qkv.ln_weight.stop_gradient = no_wgrad
        layer_pd.self_attention.layernorm_qkv.weight.stop_gradient = no_wgrad
        layer_te.self_attention.layernorm_qkv.ln_weight.stop_gradient = no_wgrad
        layer_te.self_attention.layernorm_qkv.weight.stop_gradient = no_wgrad
1355
1356
        if has_ln_bias:
            layer_pd.self_attention.layernorm_qkv.ln_bias.copy_(
1357
1358
                layer_te.self_attention.layernorm_qkv.ln_bias, True
            )
1359
1360
            layer_pd.self_attention.layernorm_qkv.ln_bias.stop_gradient = no_dbias
            layer_te.self_attention.layernorm_qkv.ln_bias.stop_gradient = no_dbias
Shijie's avatar
Shijie committed
1361
1362
        if has_bias:
            layer_pd.self_attention.layernorm_qkv.bias.copy_(
1363
1364
                layer_te.self_attention.layernorm_qkv.bias, True
            )
Shijie's avatar
Shijie committed
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
            layer_pd.self_attention.layernorm_qkv.bias.stop_gradient = no_dbias
            layer_te.self_attention.layernorm_qkv.bias.stop_gradient = no_dbias

    layer_pd.self_attention.proj.weight.copy_(layer_te.self_attention.proj.weight.T, True)
    layer_pd.self_attention.proj.weight.stop_gradient = no_wgrad
    layer_te.self_attention.proj.weight.stop_gradient = no_wgrad
    if has_bias:
        layer_pd.self_attention.proj.bias.copy_(layer_te.self_attention.proj.bias, True)
        layer_pd.self_attention.proj.bias.stop_gradient = no_dbias
        layer_te.self_attention.proj.bias.stop_gradient = no_dbias

    # MultiHeadAttention params - cross attn
    layer_pd.inter_attention.layernorm_query.ln_weight.copy_(
1378
1379
        layer_te.inter_attention.layernorm_query.ln_weight, True
    )
Shijie's avatar
Shijie committed
1380
    layer_pd.inter_attention.layernorm_query.weight.copy_(
1381
1382
        layer_te.inter_attention.layernorm_query.weight.T, True
    )
Shijie's avatar
Shijie committed
1383
1384
1385
1386
    layer_pd.inter_attention.layernorm_query.ln_weight.stop_gradient = no_wgrad
    layer_pd.inter_attention.layernorm_query.weight.stop_gradient = no_wgrad
    layer_te.inter_attention.layernorm_query.ln_weight.stop_gradient = no_wgrad
    layer_te.inter_attention.layernorm_query.weight.stop_gradient = no_wgrad
1387
1388
    if has_ln_bias:
        layer_pd.inter_attention.layernorm_query.ln_bias.copy_(
1389
1390
            layer_te.inter_attention.layernorm_query.ln_bias, True
        )
1391
1392
        layer_pd.inter_attention.layernorm_query.ln_bias.stop_gradient = no_dbias
        layer_te.inter_attention.layernorm_query.ln_bias.stop_gradient = no_dbias
Shijie's avatar
Shijie committed
1393
1394
    if has_bias:
        layer_pd.inter_attention.layernorm_query.bias.copy_(
1395
1396
            layer_te.inter_attention.layernorm_query.bias, True
        )
Shijie's avatar
Shijie committed
1397
1398
1399
        layer_pd.inter_attention.layernorm_query.bias.stop_gradient = no_dbias
        layer_te.inter_attention.layernorm_query.bias.stop_gradient = no_dbias

1400
1401
1402
    layer_pd.inter_attention.key_value.weight.copy_(
        layer_te.inter_attention.key_value.weight.T, True
    )
Shijie's avatar
Shijie committed
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
1424
1425
    layer_pd.inter_attention.key_value.weight.stop_gradient = no_wgrad
    layer_te.inter_attention.key_value.weight.stop_gradient = no_wgrad
    layer_pd.inter_attention.proj.weight.copy_(layer_te.inter_attention.proj.weight.T, True)
    layer_pd.inter_attention.proj.weight.stop_gradient = no_wgrad
    layer_te.inter_attention.proj.weight.stop_gradient = no_wgrad
    if has_bias:
        layer_pd.inter_attention.key_value.bias.copy_(layer_te.inter_attention.key_value.bias, True)
        layer_pd.inter_attention.key_value.bias.stop_gradient = no_dbias
        layer_te.inter_attention.key_value.bias.stop_gradient = no_dbias
        layer_pd.inter_attention.proj.bias.copy_(layer_te.inter_attention.proj.bias, True)
        layer_pd.inter_attention.proj.bias.stop_gradient = no_dbias
        layer_te.inter_attention.proj.bias.stop_gradient = no_dbias

    # LayerNorm MLP params
    layer_pd.layernorm_mlp.ln_weight.copy_(layer_te.layernorm_mlp.ln_weight, True)
    layer_pd.layernorm_mlp.fc1_weight.copy_(layer_te.layernorm_mlp.fc1_weight.T, True)
    layer_pd.layernorm_mlp.fc2_weight.copy_(layer_te.layernorm_mlp.fc2_weight.T, True)
    layer_pd.layernorm_mlp.ln_weight.stop_gradient = no_wgrad
    layer_pd.layernorm_mlp.fc1_weight.stop_gradient = no_wgrad
    layer_pd.layernorm_mlp.fc2_weight.stop_gradient = no_wgrad
    layer_te.layernorm_mlp.ln_weight.stop_gradient = no_wgrad
    layer_te.layernorm_mlp.fc1_weight.stop_gradient = no_wgrad
    layer_te.layernorm_mlp.fc2_weight.stop_gradient = no_wgrad
1426
1427
1428
1429
    if has_ln_bias:
        layer_pd.layernorm_mlp.ln_bias.copy_(layer_te.layernorm_mlp.ln_bias, True)
        layer_pd.layernorm_mlp.ln_bias.stop_gradient = no_dbias
        layer_te.layernorm_mlp.ln_bias.stop_gradient = no_dbias
Shijie's avatar
Shijie committed
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
1442
1443
1444
1445
    if has_bias:
        layer_pd.layernorm_mlp.fc1_bias.copy_(layer_te.layernorm_mlp.fc1_bias, True)
        layer_pd.layernorm_mlp.fc2_bias.copy_(layer_te.layernorm_mlp.fc2_bias, True)
        layer_pd.layernorm_mlp.fc1_bias.stop_gradient = no_dbias
        layer_pd.layernorm_mlp.fc2_bias.stop_gradient = no_dbias
        layer_te.layernorm_mlp.fc1_bias.stop_gradient = no_dbias
        layer_te.layernorm_mlp.fc2_bias.stop_gradient = no_dbias

    if output_layernorm:
        layer_pd.layernorm.weight.copy_(layer_te.layernorm.weight, True)
        layer_pd.layernorm.bias.copy_(layer_te.layernorm.bias, True)
        layer_pd.layernorm.weight.stop_gradient = no_wgrad
        layer_pd.layernorm.bias.stop_gradient = no_dbias
        layer_te.layernorm.weight.stop_gradient = no_wgrad
        layer_te.layernorm.bias.stop_gradient = no_dbias

1446
1447
1448
1449
1450
1451
1452
1453
1454
    def calc_transformer_output_and_grad(
        layer,
        encoder_input,
        mask,
        encoder_output,
        enc_dec_attn_mask,
        dout,
        recompute_core_attention=False,
    ):
        """Run ``layer`` forward and backward; return its output and input grads.

        ``encoder_input`` and ``encoder_output`` are passed through
        ``paddle.to_tensor(..., stop_gradient=False)`` so that gradients are
        tracked for them and can be read back after ``backward``.

        Args:
            layer: layer under test, called positionally as
                ``layer(input, mask, encoder_output, enc_dec_attn_mask, ...)``.
            encoder_input: input hidden states for the layer.
            mask: self-attention mask.
            encoder_output: encoder hidden states passed to the layer.
            enc_dec_attn_mask: mask for the encoder-decoder attention.
            dout: upstream gradient fed to ``out.backward``.
            recompute_core_attention: forwarded unchanged to the layer call.

        Returns:
            Tuple ``(out, encoder_input_grad, encoder_output_grad)``.
        """
        _encoder_input = paddle.to_tensor(encoder_input, stop_gradient=False)
        _encoder_output = paddle.to_tensor(encoder_output, stop_gradient=False)
        out = layer(
            _encoder_input,
            mask,
            _encoder_output,
            enc_dec_attn_mask,
            recompute_core_attention=recompute_core_attention,
        )
        out.backward(dout)
        return out, _encoder_input.grad, _encoder_output.grad

    out_ref, grad_encoder_input_ref, grad_encoder_output_ref = calc_transformer_output_and_grad(
1468
1469
        layer_pd, encoder_input, attn_mask, encoder_output, attn_mask, grad_out
    )
Shijie's avatar
Shijie committed
1470
    out, grad_encoder_input, grad_encoder_output = calc_transformer_output_and_grad(
Tian Zheng's avatar
Tian Zheng committed
1471
1472
1473
1474
1475
1476
        layer_te,
        encoder_input,
        attn_mask,
        encoder_output,
        attn_mask,
        grad_out,
1477
1478
        recompute_core_attention=recompute_core_attention,
    )
Shijie's avatar
Shijie committed
1479
1480
1481
1482
1483
1484

    assert_allclose(out, out_ref, rtol=rtol, atol=atol)
    assert_allclose(grad_encoder_input, grad_encoder_input_ref, rtol=rtol, atol=atol)
    assert_allclose(grad_encoder_output, grad_encoder_output_ref, rtol=rtol, atol=atol)
    if not no_wgrad:
        if output_layernorm:
1485
1486
1487
1488
1489
1490
            assert_allclose(
                layer_te.self_attention.qkv.weight.grad,
                layer_pd.self_attention.qkv.weight.grad.T,
                rtol=rtol,
                atol=atol,
            )
Shijie's avatar
Shijie committed
1491
        else:
1492
1493
1494
1495
1496
1497
1498
1499
1500
1501
1502
1503
            assert_allclose(
                layer_te.self_attention.layernorm_qkv.weight.grad,
                layer_pd.self_attention.layernorm_qkv.weight.grad.T,
                rtol=rtol,
                atol=atol,
            )
            assert_allclose(
                layer_te.inter_attention.layernorm_query.weight.grad,
                layer_pd.inter_attention.layernorm_query.weight.grad.T,
                rtol=rtol,
                atol=atol,
            )
Shijie's avatar
Shijie committed
1504
1505
    if not no_dbias:
        if output_layernorm:
1506
1507
1508
1509
1510
1511
            assert_allclose(
                layer_te.self_attention.qkv.bias.grad,
                layer_pd.self_attention.qkv.bias.grad,
                rtol=0.5,
                atol=0.6,
            )
Shijie's avatar
Shijie committed
1512
        else:
1513
1514
1515
1516
1517
1518
1519
1520
1521
1522
1523
1524
            assert_allclose(
                layer_te.self_attention.layernorm_qkv.bias.grad,
                layer_pd.self_attention.layernorm_qkv.bias.grad,
                rtol=0.01,
                atol=0.5,
            )
            assert_allclose(
                layer_te.inter_attention.layernorm_query.bias.grad,
                layer_pd.inter_attention.layernorm_query.bias.grad,
                rtol=rtol,
                atol=atol,
            )
1525
1526
1527


@pytest.mark.skipif(not is_fp8_supported, reason=reason)
1528
1529
1530
1531
1532
1533
1534
1535
1536
1537
1538
1539
1540
1541
1542
1543
1544
@pytest.mark.parametrize("bs", [8])
@pytest.mark.parametrize("hidden_size, num_heads, ffn_hidden_size", [[1024, 16, 4096]])
@pytest.mark.parametrize("q_seqlen, kv_seqlen", [[128, 128]])
@pytest.mark.parametrize("mask_type", ["causal"])
@pytest.mark.parametrize("math_dtype", ["bfloat16"])
@pytest.mark.parametrize("num_microbatch", [8])
def test_transformer_encoder_layer_microbatch(
    bs,
    hidden_size,
    num_heads,
    ffn_hidden_size,
    q_seqlen,
    kv_seqlen,
    mask_type,
    math_dtype,
    num_microbatch,
):
1545
1546
1547
1548
1549
1550
1551
1552
1553
1554
    """
    Test Transformer Encoder Layer with FP8 weight caching
    """
    paddle.set_default_dtype(math_dtype)
    rtol = 1e-5
    atol = 1e-5
    eps = 1e-3

    # Skip if cuDNN fused attention is not supported
    if not is_fused_attention_supported(
1555
1556
1557
1558
1559
1560
1561
1562
1563
1564
        num_heads=num_heads,
        num_gqa_groups=num_heads,
        q_seqlen=q_seqlen,
        kv_seqlen=kv_seqlen,
        head_size=hidden_size // num_heads,
        dtype=math_dtype,
        dropout=0.0,
        qkv_layout="bs3hd",
        bias_type="no_bias",
        mask_type=mask_type,
1565
1566
1567
    ):
        pytest.skip("cuDNN fused attention is not supported")

1568
1569
1570
1571
1572
1573
1574
1575
1576
1577
1578
1579
1580
1581
1582
1583
1584
1585
1586
1587
1588
1589
1590
1591
    layer_cached = te.TransformerLayer(
        hidden_size,
        ffn_hidden_size,
        num_heads,
        layernorm_epsilon=eps,
        hidden_dropout=0.0,
        attention_dropout=0.0,
        weight_attr=None,
        bias_attr=None,
        self_attn_mask_type=mask_type,
        layer_type="encoder",
    )
    layer_normal = te.TransformerLayer(
        hidden_size,
        ffn_hidden_size,
        num_heads,
        layernorm_epsilon=eps,
        hidden_dropout=0.0,
        attention_dropout=0.0,
        weight_attr=None,
        bias_attr=None,
        self_attn_mask_type=mask_type,
        layer_type="encoder",
    )
1592
1593

    layer_normal.self_attention.layernorm_qkv.ln_weight.copy_(
1594
1595
        layer_cached.self_attention.layernorm_qkv.ln_weight, True
    )
1596
    layer_normal.self_attention.layernorm_qkv.ln_bias.copy_(
1597
1598
        layer_cached.self_attention.layernorm_qkv.ln_bias, True
    )
1599
    layer_normal.self_attention.layernorm_qkv.weight.copy_(
1600
1601
        layer_cached.self_attention.layernorm_qkv.weight, True
    )
1602
    layer_normal.self_attention.layernorm_qkv.bias.copy_(
1603
1604
        layer_cached.self_attention.layernorm_qkv.bias, True
    )
1605
1606
1607
1608
1609
1610
1611
1612
1613
1614
1615
1616
1617
1618
1619
1620
1621

    layer_normal.self_attention.proj.weight.copy_(layer_cached.self_attention.proj.weight, True)
    layer_normal.self_attention.proj.bias.copy_(layer_cached.self_attention.proj.bias, True)

    # LayerNorm MLP params
    layer_normal.layernorm_mlp.ln_weight.copy_(layer_cached.layernorm_mlp.ln_weight, True)
    layer_normal.layernorm_mlp.ln_bias.copy_(layer_cached.layernorm_mlp.ln_bias, True)
    layer_normal.layernorm_mlp.fc1_weight.copy_(layer_cached.layernorm_mlp.fc1_weight, True)
    layer_normal.layernorm_mlp.fc2_weight.copy_(layer_cached.layernorm_mlp.fc2_weight, True)
    layer_normal.layernorm_mlp.fc1_bias.copy_(layer_cached.layernorm_mlp.fc1_bias, True)
    layer_normal.layernorm_mlp.fc2_bias.copy_(layer_cached.layernorm_mlp.fc2_bias, True)

    recipe = DelayedScaling()

    def generate_input():
        """Build one random microbatch: input, attention mask, and output grad.

        Uses ``bs``, ``q_seqlen``, ``kv_seqlen``, ``hidden_size`` and
        ``math_dtype`` from the enclosing test. Every sequence is given the
        full length ``q_seqlen``. As a result, the grad-zeroing loop below
        touches no elements. The mask loop sets the leading
        ``q_seqlen x q_seqlen`` window of each mask to ``False``.

        Returns:
            Tuple ``(encoder_input, attn_mask, grad_out)``.
        """
        encoder_input = paddle.uniform(shape=(bs, q_seqlen, hidden_size), dtype=math_dtype)

        q_actual_seqlen = paddle.ones(shape=(bs,), dtype="int32") * q_seqlen
        # Self-attention test: key/value lengths mirror the query lengths.
        kv_actual_seqlen = q_actual_seqlen
        attn_mask = paddle.ones(shape=(bs, 1, q_seqlen, kv_seqlen), dtype="bool")

        # Built in float32 and zeroed past each sequence's actual length,
        # then cast to the test's compute dtype.
        grad_out = paddle.normal(mean=0.0, std=0.02, shape=(bs, q_seqlen, hidden_size)).astype(
            "float32"
        )
        for i in range(0, bs):
            grad_out[i, q_actual_seqlen[i] :, :] = 0
        grad_out = grad_out.astype(math_dtype)

        # Clear the mask over the valid (query, key) region of each sample.
        for i in range(0, bs):
            attn_mask[i, 0, 0 : q_actual_seqlen[i], 0 : kv_actual_seqlen[i]] = False

        return encoder_input, attn_mask, grad_out

    # Calibration to make sure weight scale is the same
    encoder_input, mask, _ = generate_input()
    with fp8_autocast(enabled=False, calibrating=True, fp8_recipe=recipe):
        _ = layer_cached(encoder_input, mask)

    with fp8_autocast(enabled=False, calibrating=True, fp8_recipe=recipe):
        _ = layer_normal(encoder_input, mask)

    for iteration in range(num_microbatch):
        encoder_input, mask, grad_out = generate_input()

        with fp8_autocast(enabled=True, fp8_recipe=recipe):
            out = layer_cached(encoder_input, mask, is_first_microbatch=(iteration == 0))
            out.backward(grad_out)

        with fp8_autocast(enabled=True, fp8_recipe=recipe):
            out_ref = layer_normal(encoder_input, mask)
            out_ref.backward(grad_out)

        assert_allclose(out, out_ref, rtol=rtol, atol=atol)
1658
1659
1660
1661
1662
1663
        assert_allclose(
            layer_cached.self_attention.layernorm_qkv.weight.grad,
            layer_normal.self_attention.layernorm_qkv.weight.grad,
            rtol=rtol,
            atol=atol,
        )