# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Test TE operators"""

import struct

import numpy as np
import paddle
import paddle.nn.functional as F
import pytest

from utils import (
    assert_allclose,
    create_fp8_meta,
    get_fused_attention_backend,
    is_fused_attention_supported,
)

from transformer_engine import transformer_engine_paddle as tex
from transformer_engine.paddle.cpp_extensions import (
    cast_to_fp8,
    cast_from_fp8,
    gemm,
    fp8_gemm,
    transpose,
    cast_transpose,
    cast_transpose_bgrad,
    te_gelu,
    gelu_fp8,
    swiglu,
    swiglu_fp8,
    swiglu_pd,
    dswiglu,
    dgelu_cast_transpose_bgrad_fp8,
    layernorm_fwd_fp8,
    layernorm_fwd,
    layernorm_bwd,
    rmsnorm_fwd_fp8,
    rmsnorm_fwd,
    rmsnorm_bwd,
    fused_attn_fwd_qkvpacked,
    fused_attn_bwd_qkvpacked,
    fused_attn_fwd_kvpacked,
    fused_attn_bwd_kvpacked,
    fused_attn_fwd,
    fused_attn_bwd,
    scaled_softmax_forward,
    scaled_softmax_backward,
    scaled_masked_softmax_forward,
    scaled_masked_softmax_backward,
    scaled_upper_triang_masked_softmax_forward,
    scaled_upper_triang_masked_softmax_backward,
)
from transformer_engine.paddle.fp8 import is_fp8_available
from transformer_engine.paddle.constants import FP8FwdTensors
from transformer_engine.common.recipe import DelayedScaling

GEMM_CASES = [
    (256, 256, 512),
    (32, 32, 32),
    (16384, 1024, 2816),
    (16384, 2816, 1024),
    (16384, 1024, 1024),
]
is_fp8_supported, reason = is_fp8_available()

SELF_ATTN_CASES = [(2, 512, 12, 64)]
CROSS_ATTN_CASES = [(2, 128, 512, 12, 64)]
FLASH_ATTN_CASES = [(2, 1024, 16, 64), (2, 2048, 16, 128)]
ATTN_DTYPES = [tex.DType.kFloat16, tex.DType.kBFloat16]


@pytest.fixture(autouse=True)
def setup():
    """Setup random seed before each test"""
    np.random.seed(10)
    paddle.seed(11)
    yield


@pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2])
@pytest.mark.parametrize("inplace", [True, False])
def test_quantize_dequantize(fp8_dtype, inplace):
    """
    Test cast_to_fp8 and cast_from_fp8
    """
    a = paddle.rand(shape=(32, 32), dtype="float32")
    # Init fp8_meta
    fp8_meta = create_fp8_meta()
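    # With inplace=True, cast_to_fp8 writes into the preallocated uint8 buffer
    # passed via `out`; otherwise it allocates the FP8 output tensor itself.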
    a_fp8 = paddle.zeros(shape=a.shape, dtype=paddle.uint8) if inplace else None
    a_fp8 = cast_to_fp8(a, fp8_meta, FP8FwdTensors.GEMM1_OUTPUT, otype=fp8_dtype, out=a_fp8)
    b = cast_from_fp8(
        a_fp8,
        fp8_meta,
        FP8FwdTensors.GEMM1_OUTPUT,
        itype=fp8_dtype,
        otype=tex.DType.kFloat32,
    )
    assert_allclose(a, b, rtol=5e-2, atol=5e-2)


def copy_bits_from_float_to_uint16(f):
    """
    Return the upper 16 bits of a float32, i.e. its bfloat16 bit pattern.
    """
    return struct.unpack("<I", struct.pack("<f", f))[0] >> 16


def convert_float_to_uint16(float_list):
    """
    Convert a float32 array to bfloat16 values stored as uint16.
    """
    new_output = []
    for x in np.nditer(float_list):
        new_output.append(np.uint16(copy_bits_from_float_to_uint16(x)))
    new_output = np.reshape(new_output, float_list.shape).view(np.uint16)

    return new_output


class TestTranspose:
    """
    Test transpose operators
    """

    @staticmethod
    def test_transpose_bf16():
        """
        Test BF16 transpose
        """
        a = paddle.rand(shape=(16, 32), dtype="bfloat16")
        a_transposed = transpose(a, otype=tex.DType.kBFloat16)
        assert_allclose(a_transposed, a.T)

    @staticmethod
    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2])
    def test_transpose_fp8(fp8_dtype):
        """
        Test FP8 transpose
        """
        min_val = -8
        max_val = 8
        a = paddle.cast(paddle.randint(min_val, max_val, shape=(16, 32)), "float32")
        fp8_meta = create_fp8_meta()
        a_fp8 = cast_to_fp8(a, fp8_meta, FP8FwdTensors.GEMM1_INPUT, otype=fp8_dtype)
        a_fp8_transposed = transpose(a_fp8, otype=fp8_dtype)
        a_transposed = cast_from_fp8(
            a_fp8_transposed,
            fp8_meta,
            FP8FwdTensors.GEMM1_INPUT,
            itype=fp8_dtype,
            otype=tex.DType.kFloat32,
        )
        assert_allclose(a_transposed, a.T)

    @staticmethod
    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2])
    @pytest.mark.parametrize("inplace", [True, False])
    def test_cast_transpose(fp8_dtype, inplace):
        """
        Test cast_transpose
        """
        min_val = -8
        max_val = 8
        a = paddle.cast(paddle.randint(min_val, max_val, shape=(16, 32)), "float32")
        fp8_meta = create_fp8_meta()
        a_fp8_casted, a_fp8_transposed = None, None
        if inplace:
            a_fp8_casted = paddle.zeros(shape=a.shape, dtype=paddle.uint8)
            a_fp8_transposed = paddle.zeros(shape=a.T.shape, dtype=paddle.uint8)
        a_fp8_casted, a_fp8_transposed = cast_transpose(
            a,
            fp8_meta,
            FP8FwdTensors.GEMM1_INPUT,
            otype=fp8_dtype,
            cast_out=a_fp8_casted,
            transpose_out=a_fp8_transposed,
        )

        a_transposed = cast_from_fp8(
            a_fp8_transposed,
            fp8_meta,
            FP8FwdTensors.GEMM1_INPUT,
            itype=fp8_dtype,
            otype=tex.DType.kFloat32,
        )

        a_casted = cast_from_fp8(
            a_fp8_casted,
            fp8_meta,
            FP8FwdTensors.GEMM1_INPUT,
            itype=fp8_dtype,
            otype=tex.DType.kFloat32,
        )

        assert_allclose(a_casted, a)
        assert_allclose(a_transposed, a.T)

    @staticmethod
    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2])
    def test_cast_transpose_bgrad(fp8_dtype):
        """
        Test cast_transpose_bgrad
        """
        min_val = -8
        max_val = 8
        a = paddle.cast(paddle.randint(min_val, max_val, shape=(16, 32)), "float32")
        fp8_meta = create_fp8_meta()
        bgrad, a_fp8_casted, a_fp8_transposed = cast_transpose_bgrad(
            a, fp8_meta, FP8FwdTensors.GEMM1_INPUT, otype=fp8_dtype
        )

        a_transposed = cast_from_fp8(
            a_fp8_transposed,
            fp8_meta,
            FP8FwdTensors.GEMM1_INPUT,
            itype=fp8_dtype,
            otype=tex.DType.kFloat32,
        )

        a_casted = cast_from_fp8(
            a_fp8_casted,
            fp8_meta,
            FP8FwdTensors.GEMM1_INPUT,
            itype=fp8_dtype,
            otype=tex.DType.kFloat32,
        )

        assert_allclose(a_casted, a)
        assert_allclose(a_transposed, a.T)
        assert_allclose(bgrad, a.sum(axis=0))


class TestActivation:
    """
    Test activation operators
    """

    @staticmethod
    def test_gelu_bf16():
        """
        Test BF16 GELU Forward
        """
        a = paddle.rand(shape=(16, 32), dtype="bfloat16") * 2 - 1
        gelu_out = te_gelu(a, otype=tex.DType.kBFloat16)
        gelu_ref = paddle.nn.GELU()(a)

        assert_allclose(gelu_out, gelu_ref, rtol=1e-2)

    @staticmethod
    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2])
    def test_gelu_fp8(fp8_dtype):
        """
        Test FP8 GELU Forward
        """
        a = paddle.rand(shape=(16, 32), dtype="float32") * 2 - 1
        fp8_meta = create_fp8_meta()

        gelu_out_fp8 = gelu_fp8(a, fp8_meta, FP8FwdTensors.GEMM1_INPUT, otype=fp8_dtype)

        gelu_out = cast_from_fp8(
            gelu_out_fp8,
            fp8_meta,
            FP8FwdTensors.GEMM1_INPUT,
            itype=fp8_dtype,
            otype=tex.DType.kFloat32,
        )

        gelu_ref = paddle.nn.GELU()(a)

        assert_allclose(gelu_out, gelu_ref, rtol=0.1, atol=0.01)

    @staticmethod
    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2])
    def test_gelu_bwd_fp8(fp8_dtype):
        """
        Test FP8 GELU Backward
        """
        # y = GELU(x), calculate ref
        x = paddle.rand(shape=(16, 32), dtype="float32") * 2 - 1
        x.stop_gradient = False
        y = paddle.nn.GELU()(x)
        y_grad = paddle.rand(shape=(16, 32), dtype="float32") * 2 - 1
        paddle.autograd.backward([y], [y_grad], True)
        # calculate fp8
        fp8_meta = create_fp8_meta()
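        # dgelu_cast_transpose_bgrad_fp8 fuses the GELU backward pass with the FP8
        # cast, the transpose, and the bias-gradient (column-sum) reduction.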
        x_grad_fp8, x_grad_t_fp8, dbias = dgelu_cast_transpose_bgrad_fp8(
            y_grad, x, fp8_meta, FP8FwdTensors.GEMM1_INPUT, otype=fp8_dtype
        )

        x_grad = cast_from_fp8(
            x_grad_fp8,
            fp8_meta,
            FP8FwdTensors.GEMM1_INPUT,
            itype=fp8_dtype,
            otype=tex.DType.kFloat32,
        )

        x_grad_t = cast_from_fp8(
            x_grad_t_fp8,
            fp8_meta,
            FP8FwdTensors.GEMM1_INPUT,
            itype=fp8_dtype,
            otype=tex.DType.kFloat32,
        )

        assert_allclose(x_grad, x.grad, rtol=0.1, atol=0.01)
        assert_allclose(x_grad_t, x.grad.T, rtol=0.1, atol=0.01)
        assert_allclose(dbias, x.grad.sum(axis=0), rtol=0.1, atol=0.01)

    @staticmethod
    def test_swiglu_bf16():
        """
        Test BF16 SwiGLU Forward
        """
        a = paddle.rand(shape=(16, 32), dtype="bfloat16") * 2 - 1
        swiglu_out = swiglu(a, otype=tex.DType.kBFloat16)
        swiglu_ref = swiglu_pd(a)

        assert_allclose(swiglu_out, swiglu_ref, rtol=1e-2)

    @staticmethod
    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2])
    def test_swiglu_fp8(fp8_dtype):
        """
        Test FP8 SwiGLU Forward
        """
        a = paddle.rand(shape=(16, 32), dtype="float32") * 2 - 1
        fp8_meta = create_fp8_meta()

        swiglu_out_fp8 = swiglu_fp8(a, fp8_meta, FP8FwdTensors.GEMM1_INPUT, otype=fp8_dtype)

        swiglu_out = cast_from_fp8(
            swiglu_out_fp8,
            fp8_meta,
            FP8FwdTensors.GEMM1_INPUT,
            itype=fp8_dtype,
            otype=tex.DType.kFloat32,
        )

        swiglu_ref = swiglu_pd(a)

        assert_allclose(swiglu_out, swiglu_ref, rtol=0.1, atol=0.01)

    @staticmethod
    def test_swiglu_bwd():
        """
        Test SwiGLU Backward
        """
        # y = SwiGLU(x), calculate ref
        x = paddle.rand(shape=(16, 32), dtype="bfloat16") * 2 - 1
        x.stop_gradient = False
        y = swiglu_pd(x)
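        # SwiGLU halves the last dimension, so y and the upstream gradient are (16, 16).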
        y_grad = paddle.rand(shape=(16, 16), dtype="bfloat16") * 2 - 1
        paddle.autograd.backward([y], [y_grad], True)
        # calculate with the fused dswiglu kernel
        x_grad = dswiglu(y_grad, x, otype=tex.DType.kBFloat16)

        assert_allclose(x_grad, x.grad, rtol=0.1, atol=0.01)


class TestGemm:
    """
    Tests for gemm(cuBLASLt) operator
    """

    @staticmethod
    @pytest.mark.skipif(
        paddle.device.cuda.get_device_capability() < (8, 0), reason="BF16 GEMM requires Ampere+ GPU"
    )
    @pytest.mark.parametrize("m,n,k", GEMM_CASES)
    def test_bf16(m, n, k):
        """
        Test "TN" BF16 GEMM
        """
        a = paddle.rand(shape=(m, k), dtype="bfloat16")
        b = paddle.rand(shape=(n, k), dtype="bfloat16")

        workspace = paddle.zeros(shape=[33_554_432], dtype="uint8")
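        # The workspace is a 32 MiB byte buffer used by the cuBLASLt GEMM.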

        ref_out = paddle.matmul(a, b.T)
        # CublasLt inside tex.te_gemm assumes inputs are column major.
        # Mathematically, A@B=C is equivalent to B^T@A^T=C^T, where X^T is the
        # transpose of X.
        # Here we perform "TN" GEMM in column major, i.e., b@a^T = C^T,
        # which is equivalent to a@b^T = C in row major.
        actual_out, _, _ = gemm(
            b, a, paddle.bfloat16, workspace, False, None, False, False, "TN", None, None, False
        )

        assert_allclose(actual_out, ref_out, rtol=1.6e-2, atol=1e-5)

    @staticmethod
    @pytest.mark.skipif(
        paddle.device.cuda.get_device_capability() < (8, 0), reason="BF16 GEMM requires Ampere+ GPU"
    )
    @pytest.mark.parametrize("m,n,k", GEMM_CASES)
    def test_bf16_inplace(m, n, k):
        """
        Test "TN" BF16 GEMM, with accumulate=True
        """
        min_val = -16
        max_val = 16
        a = paddle.rand(shape=(m, k), dtype="bfloat16")
        b = paddle.rand(shape=(n, k), dtype="bfloat16")
        c = paddle.cast(paddle.randint(min_val, max_val, shape=(m, n)), "bfloat16")
        workspace = paddle.zeros(shape=[33_554_432], dtype="uint8")

        ref_out = c + paddle.matmul(a, b.T)

        actual_out = paddle.clone(c)
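        # accumulate=True makes the GEMM add into the existing values of actual_out,
        # i.e. actual_out += a @ b.T.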
        _, _, _ = gemm(
            b,
            a,
            paddle.bfloat16,
            workspace,
            False,
            None,
            False,
            True,
            "TN",
            actual_out,
            None,
            False,
        )

        assert_allclose(actual_out, ref_out, rtol=5e-2, atol=5e-2)

    @staticmethod
    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize("m,n,k", GEMM_CASES)
    def test_fp8_randint(m, n, k):
        """
        Test "TN" FP8 GEMM
        """
        min_val = -4
        max_val = 4
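        # Small integer inputs are exactly representable in FP8 E4M3, so the FP8 GEMM
        # result can be compared tightly against the FP32 reference.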
        fp8_dtype = tex.DType.kFloat8E4M3
        out_dtype = paddle.float32
        fp8_meta = create_fp8_meta(num_gemms=1)

        a = paddle.cast(paddle.randint(min_val, max_val, shape=(m, k)), "float32")

        a_casted = cast_to_fp8(a, fp8_meta, FP8FwdTensors.GEMM1_INPUT, otype=fp8_dtype)
        b = paddle.cast(paddle.randint(min_val, max_val, shape=(n, k)), "float32")
        b_casted = cast_to_fp8(b, fp8_meta, FP8FwdTensors.GEMM1_WEIGHT, otype=fp8_dtype)
        workspace = paddle.zeros(shape=[33_554_432], dtype="uint8")

        ref_out = paddle.matmul(a, b.T)
        actual_out, _ = fp8_gemm(
            b_casted,
            fp8_meta.scale_inv,
            FP8FwdTensors.GEMM1_WEIGHT,
            fp8_dtype,
            a_casted,
            fp8_meta.scale_inv,
            FP8FwdTensors.GEMM1_INPUT,
            fp8_dtype,
            out_dtype,
            workspace,
        )

        assert_allclose(actual_out, ref_out)


class TestLayerNorm:
    """
    Test layernorm operators
    """

    @staticmethod
    def calc_fwd_ref(x, eps, gamma, beta):
        """
        Calculate reference using paddle layer_norm op
        """
        y = paddle.nn.functional.layer_norm(
            x=x, normalized_shape=x.shape[1:], weight=gamma, bias=beta, epsilon=eps
        )
        mean = paddle.mean(x, axis=-1)
        var = paddle.var(x, axis=-1)
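        # The TE kernel reports the reciprocal standard deviation (rsigma), so the
        # reference computes sqrt(1 / var) for comparison.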
        inv_var = paddle.sqrt(1.0 / var)
        return y, mean, inv_var

    @staticmethod
    def calc_bwd_ref(x, eps, gamma, beta, dy):
        """
        Calculate reference using paddle layer_norm op
        """
        x.stop_gradient = False
        gamma.stop_gradient = False
        beta.stop_gradient = False

        y = paddle.nn.functional.layer_norm(
            x=x, normalized_shape=x.shape[1:], weight=gamma, bias=beta, epsilon=eps
        )

        paddle.autograd.backward([y], [dy], True)

        return x.grad, gamma.grad, beta.grad

    def test_layernorm_fwd(self):
        """
        Test BF16 LayerNorm Forward
        """
        N, H = (16, 32)
        eps = 1e-3
        x = paddle.uniform(shape=(N, H), dtype="bfloat16")
        gamma = paddle.uniform(shape=(H,), dtype="bfloat16")
        beta = paddle.uniform(shape=(H,), dtype="bfloat16")

        y, mu, rsigma = layernorm_fwd(x, gamma, beta, eps, tex.DType.kBFloat16)

        y_ref, mu_ref, rsigma_ref = self.calc_fwd_ref(x, eps, gamma, beta)

        assert_allclose(y, y_ref, rtol=1e-4, atol=1e-4)
        assert_allclose(mu, mu_ref, rtol=1e-3, atol=1e-3)
        assert_allclose(rsigma, rsigma_ref, rtol=5e-2, atol=5e-2)

    @staticmethod
    def test_layernorm_fwd_fp8():
        """
        Test FP8 LayerNorm Forward
        """
        fp8_dtype = tex.DType.kFloat8E4M3
        N, H = (16, 32)
        eps = 1e-3

        x = paddle.uniform(shape=(N, H), dtype="float32")
        gamma = paddle.uniform(shape=(H,), dtype="float32")
        beta = paddle.uniform(shape=(H,), dtype="float32")

        fp8_tensor = FP8FwdTensors.GEMM1_INPUT
        fp8_meta = create_fp8_meta()

        y_ref, mu_ref, rsigma_ref = layernorm_fwd(x, gamma, beta, eps, tex.DType.kFloat32)

        y_fp8, mu, rsigma = layernorm_fwd_fp8(x, gamma, beta, eps, fp8_meta, fp8_tensor, fp8_dtype)

        y = cast_from_fp8(y_fp8, fp8_meta, fp8_tensor, itype=fp8_dtype, otype=tex.DType.kFloat32)

        assert_allclose(y, y_ref, rtol=0.1, atol=0.01)
        assert_allclose(mu, mu_ref)
        assert_allclose(rsigma, rsigma_ref)

    def test_layernorm_bwd(self):
        """
        Test BF16 LayerNorm Backward
        """
        N, H = (16, 32)
        eps = 1e-3
        x = paddle.uniform(shape=(N, H), dtype="bfloat16")
        dy = paddle.uniform(shape=(N, H), dtype="bfloat16")
        gamma = paddle.uniform(shape=(H,), dtype="bfloat16")
        beta = paddle.uniform(shape=(H,), dtype="bfloat16")

        dx_ref, dgamma_ref, dbeta_ref = self.calc_bwd_ref(x, eps, gamma, beta, dy)

        _, mu, rsigma = layernorm_fwd(x, gamma, beta, eps, tex.DType.kBFloat16)
        dx, dgamma, dbeta = layernorm_bwd(dy, x, mu, rsigma, gamma)

        assert_allclose(dx, dx_ref, rtol=1e-5, atol=1e-5)
        assert_allclose(dgamma, dgamma_ref, rtol=1e-5, atol=1e-5)
        assert_allclose(dbeta, dbeta_ref, rtol=1e-5, atol=1e-5)


class TestRMSNorm:
    """
    Test rmsnorm operators
    """

    @staticmethod
    def calc_fwd_ref(x, eps, gamma):
        """
        Calculate rmsnorm reference using paddle op
        """

        norm = paddle.rsqrt(paddle.mean(x**2, axis=-1, keepdim=True) + eps)
        y = x * norm * gamma

        return y

    def calc_bwd_ref(self, x, eps, gamma, dy):
        """
        Calculate rmsnorm bwd reference using paddle op
        """
        x.stop_gradient = False
        gamma.stop_gradient = False

        y = self.calc_fwd_ref(x, eps, gamma)

        paddle.autograd.backward([y], [dy], True)

        return x.grad, gamma.grad

    def test_rmsnorm_fwd(self):
        """
        Test BF16 RMSNorm Forward
        """
        N, H = (16, 32)
        eps = 1e-3
        x = paddle.uniform(shape=(N, H), dtype="bfloat16")
        gamma = paddle.uniform(shape=(H,), dtype="bfloat16")

        y, _ = rmsnorm_fwd(x, gamma, eps, tex.DType.kBFloat16)

        y_ref = self.calc_fwd_ref(x, eps, gamma)

        assert_allclose(y, y_ref, rtol=1e-2, atol=1e-2)

    @staticmethod
    def test_rmsnorm_fwd_fp8():
        """
        Test FP8 RMSNorm Forward
        """
        fp8_dtype = tex.DType.kFloat8E4M3
        N, H = (16, 32)
        eps = 1e-3

        x = paddle.uniform(shape=(N, H), dtype="float32")
        gamma = paddle.uniform(shape=(H,), dtype="float32")

        fp8_tensor = FP8FwdTensors.GEMM1_INPUT
        fp8_meta = create_fp8_meta()

        y_ref, rsigma_ref = rmsnorm_fwd(x, gamma, eps, tex.DType.kFloat32)

        y_fp8, rsigma = rmsnorm_fwd_fp8(x, gamma, eps, fp8_meta, fp8_tensor, fp8_dtype)

        y = cast_from_fp8(y_fp8, fp8_meta, fp8_tensor, itype=fp8_dtype, otype=tex.DType.kFloat32)

        assert_allclose(y, y_ref, rtol=0.1, atol=0.01)
        assert_allclose(rsigma, rsigma_ref)

    def test_rmsnorm_bwd(self):
        """
        Test BF16 RMSNorm Backward
        """
        N, H = (16, 32)
        eps = 1e-3
        x = paddle.uniform(shape=(N, H), dtype="bfloat16")
        dy = paddle.uniform(shape=(N, H), dtype="bfloat16")
        gamma = paddle.uniform(shape=(H,), dtype="bfloat16")

        dx_ref, dgamma_ref = self.calc_bwd_ref(x, eps, gamma, dy)

        _, rsigma = rmsnorm_fwd(x, gamma, eps, tex.DType.kBFloat16)
        dx, dgamma = rmsnorm_bwd(dy, x, rsigma, gamma)

        assert_allclose(dx, dx_ref, rtol=1e-2, atol=1e-2)
        assert_allclose(dgamma, dgamma_ref, rtol=1e-2, atol=5e-2)


class TestFusedAttn:
    """
    Test fused attention operators
    """

    def set_input(self, b, s_q, s_kv, h, d, dtype, attn_mode="self_attn", is_causal_masking=False):
        """
        set test input
        """

        def _random(shape):
            if self.dtype == "bfloat16":
                data = np.random.normal(loc=0.0, scale=0.02, size=shape).astype("float32")
                return convert_float_to_uint16(data)
            return np.random.normal(loc=0.0, scale=0.02, size=shape).astype(self.dtype)

        self.batch_size = b
        self.q_seqlen = s_q
        self.kv_seqlen = s_kv
        self.num_heads = h
        self.head_size = d
        self.dropout_prob = 0.0
        self.scaling_factor = 1.0 / np.sqrt(d)
        self.q_shape = (b, s_q, h, d)
        self.kv_shape = (b, s_kv, h, d)
        self.fuse_qkv_shape = (b, s_q, 3, h, d)
        self.fuse_kv_shape = (b, s_kv, 2, h, d)
        self.bias_shape = (1, h, s_q, s_kv)
        self.attn_mode = attn_mode
        self.dtype = dtype
        self.is_causal_masking = is_causal_masking

        self.q = _random(self.q_shape)
        if self.attn_mode == "self_attn":
            assert self.q_seqlen == self.kv_seqlen, "self attention requires q_seqlen == kv_seqlen"
            self.kv = self.q
        else:
            self.kv = _random(self.kv_shape)

        self.q_actual_seqlen = None
        if self.is_causal_masking:
            self.q_actual_seqlen = np.full(
                self.batch_size,
                self.q_seqlen,
                dtype=np.int32,
            )
        else:
            self.q_actual_seqlen = np.random.randint(
                low=20,
                high=self.q_seqlen,
                size=(self.batch_size,),
                dtype=np.int32,
            )
        self.kv_actual_seqlen = self.q_actual_seqlen

        self.q_cu_seqlen = np.cumsum(self.q_actual_seqlen)
        self.q_cu_seqlen = np.insert(self.q_cu_seqlen, 0, 0)
        self.kv_cu_seqlen = np.cumsum(self.kv_actual_seqlen)
        self.kv_cu_seqlen = np.insert(self.kv_cu_seqlen, 0, 0)
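        # In attn_mask, 1 marks a position that is masked out; valid (attended)
        # positions are cleared to 0 below.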
        self.attn_mask = np.ones(
            shape=(self.batch_size, 1, self.q_seqlen, self.kv_seqlen),
            dtype=np.int32,
        )
        if self.is_causal_masking:
            assert attn_mode == "self_attn", "only support causal masking for self attention"
            for i in range(0, self.batch_size):
                for j in range(self.q_actual_seqlen[i]):
                    self.attn_mask[i, :, j, : j + 1] = 0
        else:
            for i in range(0, self.batch_size):
                self.attn_mask[i, :, : self.q_actual_seqlen[i], : self.kv_actual_seqlen[i]] = 0

        dout = _random((self.batch_size, self.q_seqlen, self.num_heads, self.head_size))
        self.dout = paddle.to_tensor(dout, dtype=self.dtype)

    def _get_reference_out(self):
        paddle.disable_static(place=paddle.CUDAPlace(0))
        q_tensor = paddle.to_tensor(self.q, stop_gradient=False)
        k_tensor = paddle.to_tensor(self.kv, stop_gradient=False)
        v_tensor = paddle.to_tensor(self.kv, stop_gradient=False)

        q_out = paddle.transpose(x=q_tensor, perm=[0, 2, 1, 3])  # [b, s, h, d] -> [b, h, s, d]
        k_out = paddle.transpose(x=k_tensor, perm=[0, 2, 1, 3])  # [b, s, h, d] -> [b, h, s, d]
        v_out = paddle.transpose(x=v_tensor, perm=[0, 2, 1, 3])  # [b, s, h, d] -> [b, h, s, d]

        qk_out = paddle.matmul(
            x=q_out * self.scaling_factor,
            y=k_out,
            transpose_x=False,
            transpose_y=True,
        )

        attn_mask = paddle.to_tensor(self.attn_mask, stop_gradient=True).cast("bool")
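        # Masked positions receive a large negative bias (-1e4) so that softmax
        # assigns them near-zero probability.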
        attn_mask_vals = paddle.full(qk_out.shape, -1e4, qk_out.dtype)
        attn_mask_out = paddle.where(attn_mask, attn_mask_vals, qk_out)
        attn_mask_out = paddle.cast(attn_mask_out, "float32")
        softmax_out = F.softmax(attn_mask_out)
        softmax_out = paddle.cast(softmax_out, self.dtype)

        if self.dropout_prob:
            dropout_out = F.dropout(
                softmax_out,
                self.dropout_prob,
                training=self.training,
                mode="upscale_in_train",
            )
            qkv_out = paddle.matmul(dropout_out, v_out)
        else:
            qkv_out = paddle.matmul(softmax_out, v_out)

        out = paddle.transpose(qkv_out, perm=[0, 2, 1, 3])  # [b, h, s, d] -> [b, s, h, d]

        paddle.autograd.backward(
            [out],
            [self.dout],
            retain_graph=True,
        )
        return out, q_tensor.grad, k_tensor.grad, v_tensor.grad

    def _get_fused_attention_out(self):
        paddle.disable_static(place=paddle.CUDAPlace(0))

        if self.attn_mode == "self_attn":
            qkv = np.stack([self.q, self.kv, self.kv], axis=2)  # [b, s, 3, h, d]
            qkv_tensor = paddle.to_tensor(qkv, stop_gradient=False)
        else:
            q_tensor = paddle.to_tensor(self.q, stop_gradient=False)
            kv = np.stack([self.kv, self.kv], axis=2)  # [b, s, 2, h, d]
            kv_tensor = paddle.to_tensor(kv, stop_gradient=False)

        q_cu_seqlen_tensor = paddle.to_tensor(self.q_cu_seqlen, dtype="int32", stop_gradient=True)
        kv_cu_seqlen_tensor = paddle.to_tensor(self.kv_cu_seqlen, dtype="int32", stop_gradient=True)

        qkv_layout = "bs3hd" if self.attn_mode == "self_attn" else "bshd_bs2hd"
        fused_attention_backend = get_fused_attention_backend(
            num_heads=self.num_heads,
            num_gqa_groups=self.num_heads,
            q_seqlen=self.q_seqlen,
            kv_seqlen=self.kv_seqlen,
            head_size=self.head_size,
            dtype=self.dtype,
            dropout=self.dropout_prob,
            qkv_layout=qkv_layout,
            bias_type="no_bias",
            mask_type="causal" if self.is_causal_masking else "padding",
        )

        qkv_dtype = tex.DType.kBFloat16 if self.dtype == "bfloat16" else tex.DType.kFloat16
        out, softmax_aux_tensor, q_grad, k_grad, v_grad = None, None, None, None, None
        if self.attn_mode == "self_attn":
            out, softmax_aux_tensor, rng_state = fused_attn_fwd_qkvpacked(
                qkv_tensor,
                q_cu_seqlen_tensor,
                is_training=True,
                max_seqlen=self.q_seqlen,
                qkv_dtype=qkv_dtype,
                fused_attention_backend=fused_attention_backend,
                Bias=None,
                attn_scale=self.scaling_factor,
                dropout=self.dropout_prob,
                set_zero=False,
                attn_mask_type="causal" if self.is_causal_masking else "padding",
            )
            dqkv, _ = fused_attn_bwd_qkvpacked(
                qkv_tensor,
                q_cu_seqlen_tensor,
                rng_state,
                out,
                self.dout,
                softmax_aux_tensor,
                max_seqlen=self.q_seqlen,
                qkv_dtype=qkv_dtype,
                fused_attention_backend=fused_attention_backend,
                attn_scale=self.scaling_factor,
                dropout=self.dropout_prob,
                set_zero=False,
                attn_mask_type="causal" if self.is_causal_masking else "padding",
            )
            q_grad = dqkv[:, :, 0, :, :]
            k_grad = dqkv[:, :, 1, :, :]
            v_grad = dqkv[:, :, 2, :, :]
        else:  # attn_mode == 'cross_attn'
            out, softmax_aux_tensor, rng_state = fused_attn_fwd_kvpacked(
                q_tensor,
                kv_tensor,
                q_cu_seqlen_tensor,
                kv_cu_seqlen_tensor,
                is_training=True,
                max_seqlen_q=self.q_seqlen,
                max_seqlen_kv=self.kv_seqlen,
                qkv_dtype=qkv_dtype,
                fused_attention_backend=fused_attention_backend,
                Bias=None,
                attn_scale=self.scaling_factor,
                dropout=self.dropout_prob,
                set_zero=False,
            )
            dq, dkv, _ = fused_attn_bwd_kvpacked(
                q_tensor,
                kv_tensor,
                q_cu_seqlen_tensor,
                kv_cu_seqlen_tensor,
                rng_state,
                out,
                self.dout,
                softmax_aux_tensor,
                fused_attention_backend=fused_attention_backend,
                max_seqlen_q=self.q_seqlen,
                max_seqlen_kv=self.kv_seqlen,
                qkv_dtype=qkv_dtype,
                attn_scale=self.scaling_factor,
                dropout=self.dropout_prob,
                set_zero=False,
            )
            q_grad = dq
            k_grad = dkv[:, :, 0, :, :]
            v_grad = dkv[:, :, 1, :, :]

        return out, q_grad, k_grad, v_grad

    def _get_fused_attention_with_separate_qkv(self):
        paddle.disable_static(place=paddle.CUDAPlace(0))

        q_tensor = paddle.to_tensor(self.q, stop_gradient=False)
        k_tensor = paddle.to_tensor(self.kv, stop_gradient=False)
        v_tensor = paddle.to_tensor(self.kv, stop_gradient=False)

        q_cu_seqlen_tensor = paddle.to_tensor(self.q_cu_seqlen, dtype="int32", stop_gradient=True)
        kv_cu_seqlen_tensor = paddle.to_tensor(self.kv_cu_seqlen, dtype="int32", stop_gradient=True)

        qkv_layout = "bshd_bshd_bshd"
        fused_attention_backend = get_fused_attention_backend(
            num_heads=self.num_heads,
            num_gqa_groups=self.num_heads,
            q_seqlen=self.q_seqlen,
            kv_seqlen=self.kv_seqlen,
            head_size=self.head_size,
            dtype=self.dtype,
            dropout=self.dropout_prob,
            qkv_layout=qkv_layout,
            bias_type="no_bias",
            mask_type="causal" if self.is_causal_masking else "padding",
        )

        qkv_dtype = tex.DType.kBFloat16 if self.dtype == "bfloat16" else tex.DType.kFloat16
        out, softmax_aux_tensor, rng_state = fused_attn_fwd(
            q_tensor,
            k_tensor,
            v_tensor,
            q_cu_seqlen_tensor,
            kv_cu_seqlen_tensor,
            is_training=True,
            max_seqlen_q=self.q_seqlen,
            max_seqlen_kv=self.kv_seqlen,
            qkv_dtype=qkv_dtype,
            fused_attention_backend=fused_attention_backend,
            Bias=None,
            attn_scale=self.scaling_factor,
            dropout=self.dropout_prob,
            set_zero=False,
            qkv_layout=qkv_layout,
            attn_mask_type="causal" if self.is_causal_masking else "padding",
        )
        dq, dk, dv, _ = fused_attn_bwd(
            q_tensor,
            k_tensor,
            v_tensor,
            q_cu_seqlen_tensor,
            kv_cu_seqlen_tensor,
            rng_state,
            out,
            self.dout,
            softmax_aux_tensor,
            fused_attention_backend=fused_attention_backend,
            max_seqlen_q=self.q_seqlen,
            max_seqlen_kv=self.kv_seqlen,
            qkv_dtype=qkv_dtype,
            attn_scale=self.scaling_factor,
            dropout=self.dropout_prob,
            set_zero=False,
            qkv_layout=qkv_layout,
            attn_mask_type="causal" if self.is_causal_masking else "padding",
        )

        return out, dq, dk, dv

    @pytest.mark.parametrize("b, s, h, d", SELF_ATTN_CASES)
    @pytest.mark.parametrize("dtype", ["float16", "bfloat16"])
    @pytest.mark.parametrize("is_causal_masking", [True, False])
    def test_self_attn_forward_backward(self, b, s, h, d, dtype, is_causal_masking):
        """
        test self attention forward + backward
        """
        if not is_fused_attention_supported(
            num_heads=h,
            num_gqa_groups=h,
            q_seqlen=s,
            kv_seqlen=s,
            head_size=d,
            dtype=dtype,
            dropout=0.0,
            qkv_layout="bs3hd",
            bias_type="no_bias",
            mask_type="causal" if is_causal_masking else "padding",
        ):
            pytest.skip("cuDNN fused attention is not supported")
        self.set_input(b, s, s, h, d, dtype, "self_attn", is_causal_masking)
        reference_out, q_grad_ref, k_grad_ref, v_grad_ref = self._get_reference_out()
        fused_attention_out, q_grad, k_grad, v_grad = self._get_fused_attention_out()
        assert_allclose(reference_out, fused_attention_out, rtol=1e-3, atol=1e-2)
        assert_allclose(q_grad_ref, q_grad, rtol=1e-3, atol=1e-2)
        assert_allclose(k_grad_ref, k_grad, rtol=1e-3, atol=1e-2)
        assert_allclose(v_grad_ref, v_grad, rtol=1e-3, atol=1e-2)

    @pytest.mark.parametrize("b, s_q, s_kv, h, d", CROSS_ATTN_CASES)
    @pytest.mark.parametrize("dtype", ["float16", "bfloat16"])
    def test_cross_attn_forward_backward(self, b, s_q, s_kv, h, d, dtype):
        """
        test cross attention forward + backward
        """
        if not is_fused_attention_supported(
            num_heads=h,
            num_gqa_groups=h,
            q_seqlen=s_q,
            kv_seqlen=s_kv,
            head_size=d,
            dtype=dtype,
            dropout=0.0,
            qkv_layout="bshd_bs2hd",
            bias_type="no_bias",
            mask_type="padding",
        ):
            pytest.skip("cuDNN fused attention is not supported")
        self.set_input(b, s_q, s_kv, h, d, dtype, "cross_attn")
        reference_out, q_grad_ref, k_grad_ref, v_grad_ref = self._get_reference_out()
        fused_attention_out, q_grad, k_grad, v_grad = self._get_fused_attention_out()
        assert_allclose(reference_out, fused_attention_out, rtol=1e-3, atol=1e-2)
        assert_allclose(q_grad_ref, q_grad, rtol=1e-3, atol=1e-2)
        assert_allclose(k_grad_ref, k_grad, rtol=1e-3, atol=1e-2)
        assert_allclose(v_grad_ref, v_grad, rtol=1e-3, atol=1e-2)

    @pytest.mark.parametrize("b, s, h, d", FLASH_ATTN_CASES)
    @pytest.mark.parametrize("dtype", ["float16", "bfloat16"])
    @pytest.mark.parametrize("is_causal_masking", [True])
    def test_flash_attn_forward_backward(self, b, s, h, d, dtype, is_causal_masking):
        """
        test flash attention forward + backward
        """
        if not is_fused_attention_supported(
            num_heads=h,
            num_gqa_groups=h,
            q_seqlen=s,
            kv_seqlen=s,
            head_size=d,
            dtype=dtype,
            dropout=0.0,
            qkv_layout="bs3hd",
            bias_type="no_bias",
            mask_type="causal" if is_causal_masking else "padding",
        ):
            pytest.skip("cuDNN fused attention is not supported")
        self.set_input(b, s, s, h, d, dtype, "self_attn", is_causal_masking)
        reference_out, q_grad_ref, k_grad_ref, v_grad_ref = self._get_reference_out()
        fused_attention_out, q_grad, k_grad, v_grad = self._get_fused_attention_out()
        assert_allclose(reference_out, fused_attention_out, rtol=1e-3, atol=1e-2)
        assert_allclose(q_grad_ref, q_grad, rtol=1e-3, atol=1e-2)
        assert_allclose(k_grad_ref, k_grad, rtol=1e-3, atol=1e-2)
        assert_allclose(v_grad_ref, v_grad, rtol=1e-3, atol=1e-2)

    @pytest.mark.parametrize("b, s, h, d", FLASH_ATTN_CASES)
    @pytest.mark.parametrize("dtype", ["float16", "bfloat16"])
    @pytest.mark.parametrize("is_causal_masking", [False, True])
    def test_fused_attn_with_separate_qkv_forward_backward(
        self, b, s, h, d, dtype, is_causal_masking
    ):
        """
        test flash attention forward + backward with separate qkv inputs
        """
        if not is_fused_attention_supported(
            num_heads=h,
            num_gqa_groups=h,
            q_seqlen=s,
            kv_seqlen=s,
            head_size=d,
            dtype=dtype,
            dropout=0.0,
            qkv_layout="bshd_bshd_bshd",
            bias_type="no_bias",
            mask_type="causal" if is_causal_masking else "padding",
        ):
            pytest.skip("cuDNN fused attention is not supported")
        self.set_input(b, s, s, h, d, dtype, "self_attn", is_causal_masking)
        reference_out, q_grad_ref, k_grad_ref, v_grad_ref = self._get_reference_out()
        fused_attention_out, q_grad, k_grad, v_grad = self._get_fused_attention_with_separate_qkv()
        assert_allclose(reference_out, fused_attention_out, rtol=1e-3, atol=1e-2)
        assert_allclose(q_grad_ref, q_grad, rtol=1e-3, atol=1e-2)
        assert_allclose(k_grad_ref, k_grad, rtol=1e-3, atol=1e-2)
        assert_allclose(v_grad_ref, v_grad, rtol=1e-3, atol=1e-2)


class TestSoftmax:
    """
    Test softmax operators
    """

    @staticmethod
    @pytest.mark.parametrize("dtype", ["float16", "bfloat16"])
    def test_scaled_softmax_fwd_bwd(dtype):
        """test scaled softmax"""
        B, H, S = (16, 4, 32)
        scale = 0.8

        x = paddle.uniform(shape=(B, H, S, S), dtype=dtype)
        x.stop_gradient = False
        dy = paddle.uniform(shape=(B, H, S, S), dtype=dtype)

        y_ref = F.softmax(scale * x)
        y = scaled_softmax_forward(x, scale)

        paddle.autograd.backward([y_ref], [dy], True)
        dx_ref = x.grad
        dx = scaled_softmax_backward(dy, y, scale)

        assert_allclose(y_ref, y, rtol=1e-4, atol=1e-3)
        assert_allclose(dx_ref, dx, rtol=1e-4, atol=1e-3)

    @staticmethod
    @pytest.mark.parametrize("dtype", ["float16", "bfloat16"])
    def test_scaled_masked_softmax_fwd_bwd(dtype):
        """test scaled masked softmax"""
        B, H, S = (16, 4, 32)
        scale = 0.8

        x = paddle.uniform(shape=(B, H, S, S), dtype=dtype)
        x.stop_gradient = False
        dy = paddle.uniform(shape=(B, H, S, S), dtype=dtype)
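        # Positions where x[0, 0] > 0.3 are masked; the reference applies the mask
        # as a -1e4 additive bias before the softmax.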
        mask = paddle.reshape(x[0, 0] > 0.3, shape=(1, 1, S, S))
        mask_flipped = x[0, 0] <= 0.3
        mask_ref = (mask_flipped.astype(dtype) - 1.0) * 1e4

        y_ref = F.softmax(scale * x + mask_ref)
        y = scaled_masked_softmax_forward(x, mask, scale)

        paddle.autograd.backward([y_ref], [dy], True)
        dx_ref = x.grad
        dx = scaled_masked_softmax_backward(dy, y, scale)

        assert_allclose(y_ref, y, rtol=1e-4, atol=1e-3)
        assert_allclose(dx_ref, dx, rtol=1e-4, atol=1e-3)

    @staticmethod
    @pytest.mark.parametrize("dtype", ["float16", "bfloat16"])
    def test_scaled_upper_triang_masked_softmax_fwd_bwd(dtype):
        """test scaled upper triang masked softmax"""
        B, S = (16, 32)
        scale = 0.8

        x = paddle.uniform(shape=(B, S, S), dtype=dtype)
        x.stop_gradient = False
        dy = paddle.uniform(shape=(B, S, S), dtype=dtype)

        mask = paddle.ones((S, S), dtype="int32")
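        # Zero out columns j > i in row i so that, via mask_ref below, future
        # positions receive a -1e4 bias (causal masking).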
        col_beg, col_end = 1, S
        for row in range(0, S):
            mask[row, col_beg:col_end] = 0
            col_beg += 1

        mask_ref = (mask.astype(dtype) - 1.0) * 1e4

        y_ref = F.softmax(scale * x + mask_ref)
        y = scaled_upper_triang_masked_softmax_forward(x, scale)

        paddle.autograd.backward([y_ref], [dy], True)
        dx_ref = x.grad
        dx = scaled_upper_triang_masked_softmax_backward(dy, y, scale)

        assert_allclose(y_ref, y, rtol=1e-4, atol=5e-3)
        assert_allclose(dx_ref, dx, rtol=1e-4, atol=5e-3)


@pytest.mark.parametrize("update_weight_scale_inv", [True, False])
def test_amax_and_scale_update(update_weight_scale_inv):
    """Test update_scale"""
    num_gemm = 6
    history_len = 1024
    recipe = DelayedScaling()
    fp8_dtype = tex.DType.kFloat8E4M3
    fp8_max = recipe.fp8_format.value.max_fwd
    non_weight_mask = paddle.to_tensor([True, False] * (num_gemm // 2))

    amax_history_tensor = paddle.rand(shape=[history_len, num_gemm], dtype="float32")
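    # Reference for the history update: roll the amax window forward by one step
    # and clear row 0, which will hold the next iteration's amax.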
    rolled_history_ref = paddle.roll(amax_history_tensor, -1, axis=0)
    rolled_history_ref[0] = 0.0
    amax_tensor = paddle.max(amax_history_tensor, axis=0)
    scale_tensor = paddle.ones(shape=[num_gemm], dtype="float32")

    def calc_ref(amax, scale, fp8_max, margin=0):
        """Calculate reference scale"""
        sf = (fp8_max / amax) / (2**margin)
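        # Fall back to the previous scale when amax is zero or non-finite.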
        sf = paddle.where(amax > 0.0, sf, scale)
        sf = paddle.where(paddle.isfinite(amax), sf, scale)
        return sf

    scale_ref = calc_ref(amax_tensor, scale_tensor, fp8_max, 0.0)
    if update_weight_scale_inv:
        scale_inv_ref = 1.0 / scale_ref
    else:
        scale_inv_ref = paddle.zeros_like(scale_tensor)
        scale_inv_ref = paddle.where(non_weight_mask, 1.0 / scale_ref, scale_inv_ref)

    # Placeholder
    scale_actual = paddle.zeros_like(scale_tensor)
    scale_inv_actual = paddle.zeros_like(scale_tensor)

    if update_weight_scale_inv:
        non_weight_mask = paddle.empty([0])
    tex.amax_and_scale_update_inplace(
        _amax_history=amax_history_tensor,
        _scale=scale_actual,
        _scale_inv=scale_inv_actual,
        non_weight_mask=non_weight_mask,
        fp8_dtype=int(fp8_dtype),
        margin=0.0,
        amax_compute="max",
    )

    assert_allclose(scale_actual, scale_ref, rtol=1e-7, atol=1e-7)
    assert_allclose(scale_inv_actual, scale_inv_ref, rtol=1e-7, atol=1e-7)
    assert_allclose(amax_history_tensor, rolled_history_ref, rtol=1e-7, atol=1e-7)


def test_update_latest_history():
    """Test update_latest_history"""
    num_gemm = 6
    history_len = 1024

    amax_history_tensor = paddle.rand(shape=[history_len, num_gemm], dtype="float32")
    amax = paddle.rand(shape=[num_gemm], dtype="float32")
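    # The update is expected to write the new amax values into row 0 of the history.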

    tex.update_latest_amax_history_inplace(_history=amax_history_tensor, amax=amax)

    assert_allclose(amax_history_tensor[0], amax, rtol=1e-7, atol=1e-7)