# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Test TE operators"""

import struct

from utils import (
    assert_allclose,
    create_fp8_meta,
    get_fused_attention_backend,
    is_fused_attention_supported,
)
import numpy as np
import paddle
import paddle.nn.functional as F
import pytest

from utils import (
    assert_allclose,
    create_fp8_meta,
    get_fused_attention_backend,
    is_fused_attention_supported,
)

import transformer_engine_paddle as tex
from transformer_engine.paddle.cpp_extensions import (
    cast_to_fp8,
    cast_from_fp8,
    gemm,
    fp8_gemm,
    transpose,
    cast_transpose,
    cast_transpose_bgrad,
    te_gelu,
    gelu_fp8,
    dgelu_cast_transpose_bgrad_fp8,
    layernorm_fwd_fp8,
    layernorm_fwd,
    layernorm_bwd,
    rmsnorm_fwd_fp8,
    rmsnorm_fwd,
    rmsnorm_bwd,
    fused_attn_fwd_qkvpacked,
    fused_attn_bwd_qkvpacked,
    fused_attn_fwd_kvpacked,
    fused_attn_bwd_kvpacked,
    fused_attn_fwd,
    fused_attn_bwd,
    scaled_softmax_forward,
    scaled_softmax_backward,
    scaled_masked_softmax_forward,
    scaled_masked_softmax_backward,
    scaled_upper_triang_masked_softmax_forward,
    scaled_upper_triang_masked_softmax_backward,
)
from transformer_engine.paddle.fp8 import is_fp8_available
from transformer_engine.paddle.constants import FP8FwdTensors
from transformer_engine.common.recipe import DelayedScaling

# (m, n, k) shapes for the GEMM tests: A is (m, k), B is (n, k), output is (m, n).
GEMM_CASES = [(256, 256, 512), (32, 32, 32), (16384, 1024, 2816), (16384, 2816, 1024),
              (16384, 1024, 1024)]
# (supported, reason) pair consumed by the FP8 ``pytest.mark.skipif`` markers below.
is_fp8_supported, reason = is_fp8_available()

# Attention shapes. NOTE(review): presumably (batch, seqlen, heads, head_size) for the
# self/flash cases and (batch, q_seqlen, kv_seqlen, heads, head_size) for cross attention,
# matching TestFusedAttn.set_input's (b, s_q, s_kv, h, d) -- confirm against the
# parametrized tests.
SELF_ATTN_CASES = [(32, 512, 16, 64), (32, 128, 16, 64)]
CROSS_ATTN_CASES = [(32, 128, 512, 16, 64)]
FLASH_ATTN_CASES = [(4, 1024, 16, 64), (2, 2048, 16, 128)]
ATTN_DTYPES = [tex.DType.kFloat16, tex.DType.kBFloat16]


@pytest.fixture(autouse=True)
def setup():
    """Reseed the NumPy and Paddle RNGs so every test starts from a fixed random state."""
    numpy_seed, paddle_seed = 10, 11
    np.random.seed(numpy_seed)
    paddle.seed(paddle_seed)
    # Hand control to the test; nothing to tear down afterwards.
    yield


@pytest.mark.parametrize('fp8_dtype', [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2])
@pytest.mark.parametrize('inplace', [True, False])
def test_quantize_dequantize(fp8_dtype, inplace):
    """
    Test cast_to_fp8 and cast_from_fp8

    Round-trips a random float32 tensor through FP8 and back and checks that the
    dequantized values are close to the input. With ``inplace=True`` a preallocated
    uint8 buffer is passed to ``cast_to_fp8`` as ``out``; otherwise ``None`` is passed.
    """
    a = paddle.rand(shape=(32, 32), dtype='float32')
    # Init fp8_meta
    fp8_meta = create_fp8_meta()
    # FP8 payloads are held in uint8 tensors (see the preallocated buffer below).
    a_fp8 = paddle.zeros(shape=a.shape, dtype=paddle.uint8) if inplace else None
    a_fp8 = cast_to_fp8(a, fp8_meta, FP8FwdTensors.GEMM1_OUTPUT, otype=fp8_dtype, out=a_fp8)
    b = cast_from_fp8(
        a_fp8,
        fp8_meta,
        FP8FwdTensors.GEMM1_OUTPUT,
        itype=fp8_dtype,
        otype=tex.DType.kFloat32,
    )
    # Loose tolerances account for FP8's few mantissa bits.
    assert_allclose(a, b, rtol=5e-2, atol=5e-2)
def copy_bits_from_float_to_uint16(f):
    """
    Return the upper 16 bits of ``f``'s IEEE-754 float32 encoding.

    The value is packed as a little-endian float32, reinterpreted as a 32-bit
    unsigned integer, and shifted right by 16. The result is the bfloat16 bit
    pattern of ``f`` with the dropped low mantissa bits truncated.
    """
    packed = struct.pack('<f', f)
    (word,) = struct.unpack('<I', packed)
    return word >> 16


def convert_float_to_uint16(float_list):
    """
    convert float to uint16

    Converts floats to their bfloat16 bit patterns, stored as ``np.uint16``. Each
    value is rounded to float32, and then the upper 16 bits of its IEEE-754 encoding
    are kept; the low mantissa bits are truncated. This is the same result as
    applying ``copy_bits_from_float_to_uint16`` element-wise.

    Args:
        float_list: array-like of floats (typically an ``np.ndarray`` of float32).

    Returns:
        ``np.ndarray`` of dtype ``uint16`` with the same shape as ``float_list``.
    """
    # Vectorized: reinterpret the float32 storage as uint32 instead of struct-packing
    # every element in a Python loop. This also accepts zero-size inputs, which
    # ``np.nditer`` rejects.
    as_fp32 = np.asarray(float_list, dtype=np.float32)
    high_bits = as_fp32.view(np.uint32) >> 16
    # np.asarray keeps an ndarray result even for 0-d inputs.
    return np.asarray(high_bits, dtype=np.uint16)


class TestTranspose:
    """
    Test transpose operators
    """

    @staticmethod
    def test_transpose_bf16():
        """
        Test BF16 transpose
        """
        a = paddle.rand(shape=(16, 32), dtype='bfloat16')
        a_transposed = transpose(a, otype=tex.DType.kBFloat16)
        assert_allclose(a_transposed, a.T)

    @staticmethod
    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize('fp8_dtype', [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2])
    def test_transpose_fp8(fp8_dtype):
        """
        Test FP8 transpose

        Casts to FP8, transposes the FP8 tensor, casts back to float32 and compares
        against the float32 transpose of the input.
        """
        # Integers in [min_val, max_val) are exactly representable in E4M3 and E5M2.
        # The round trip is therefore lossless, assuming create_fp8_meta uses unit
        # scaling (TODO confirm).
        min_val = -8
        max_val = 8
        a = paddle.cast(paddle.randint(min_val, max_val, shape=(16, 32)), 'float32')
        fp8_meta = create_fp8_meta()
        a_fp8 = cast_to_fp8(a, fp8_meta, FP8FwdTensors.GEMM1_INPUT, otype=fp8_dtype)
        a_fp8_transposed = transpose(a_fp8, otype=fp8_dtype)
        a_transposed = cast_from_fp8(a_fp8_transposed,
                                     fp8_meta,
                                     FP8FwdTensors.GEMM1_INPUT,
                                     itype=fp8_dtype,
                                     otype=tex.DType.kFloat32)
        assert_allclose(a_transposed, a.T)

    @staticmethod
    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize('fp8_dtype', [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2])
    @pytest.mark.parametrize('inplace', [True, False])
    def test_cast_transpose(fp8_dtype, inplace):
        """
        Test cast_transpose

        With ``inplace=True``, preallocated uint8 buffers are passed as ``cast_out`` and
        ``transpose_out``. Otherwise ``None`` is passed for both.
        """
        min_val = -8
        max_val = 8
        a = paddle.cast(paddle.randint(min_val, max_val, shape=(16, 32)), 'float32')
        fp8_meta = create_fp8_meta()

        a_fp8_casted, a_fp8_transposed = None, None
        if inplace:
            a_fp8_casted = paddle.zeros(shape=a.shape, dtype=paddle.uint8)
            a_fp8_transposed = paddle.zeros(shape=a.T.shape, dtype=paddle.uint8)

        a_fp8_casted, a_fp8_transposed = cast_transpose(a,
                                                        fp8_meta,
                                                        FP8FwdTensors.GEMM1_INPUT,
                                                        otype=fp8_dtype,
                                                        cast_out=a_fp8_casted,
                                                        transpose_out=a_fp8_transposed)

        a_transposed = cast_from_fp8(a_fp8_transposed,
                                     fp8_meta,
                                     FP8FwdTensors.GEMM1_INPUT,
                                     itype=fp8_dtype,
                                     otype=tex.DType.kFloat32)

        a_casted = cast_from_fp8(a_fp8_casted,
                                 fp8_meta,
                                 FP8FwdTensors.GEMM1_INPUT,
                                 itype=fp8_dtype,
                                 otype=tex.DType.kFloat32)

        assert_allclose(a_casted, a)
        assert_allclose(a_transposed, a.T)

    @staticmethod
    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize('fp8_dtype', [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2])
    def test_cast_transpose_bgrad(fp8_dtype):
        """
        Test cast_transpose_bgrad

        Checks the FP8 cast and the transpose. Also checks that ``bgrad`` equals the
        input summed over axis 0.
        """
        min_val = -8
        max_val = 8
        a = paddle.cast(paddle.randint(min_val, max_val, shape=(16, 32)), 'float32')
        fp8_meta = create_fp8_meta()
        bgrad, a_fp8_casted, a_fp8_transposed = cast_transpose_bgrad(a,
                                                                     fp8_meta,
                                                                     FP8FwdTensors.GEMM1_INPUT,
                                                                     otype=fp8_dtype)

        a_transposed = cast_from_fp8(a_fp8_transposed,
                                     fp8_meta,
                                     FP8FwdTensors.GEMM1_INPUT,
                                     itype=fp8_dtype,
                                     otype=tex.DType.kFloat32)

        a_casted = cast_from_fp8(a_fp8_casted,
                                 fp8_meta,
                                 FP8FwdTensors.GEMM1_INPUT,
                                 itype=fp8_dtype,
                                 otype=tex.DType.kFloat32)

        assert_allclose(a_casted, a)
        assert_allclose(a_transposed, a.T)
        assert_allclose(bgrad, a.sum(axis=0))

class TestActivation:
    """
    Test activation operators
    """

    @staticmethod
    def test_gelu_bf16():
        """
        Test BF16 GELU Forward
        """
        # `* 2 - 1` maps paddle.rand's [0, 1) samples onto [-1, 1).
        a = paddle.rand(shape=(16, 32), dtype='bfloat16') * 2 - 1
        gelu_out = te_gelu(a, otype=tex.DType.kBFloat16)
        gelu_ref = paddle.nn.GELU()(a)

        assert_allclose(gelu_out, gelu_ref, rtol=1e-2)

    @staticmethod
    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize('fp8_dtype', [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2])
    def test_gelu_fp8(fp8_dtype):
        """
        Test FP8 GELU Forward

        Runs GELU with an FP8 output, dequantizes it, and compares against paddle's
        float32 GELU.
        """
        a = paddle.rand(shape=(16, 32), dtype='float32') * 2 - 1
        fp8_meta = create_fp8_meta()

        gelu_out_fp8 = gelu_fp8(a, fp8_meta, FP8FwdTensors.GEMM1_INPUT, otype=fp8_dtype)

        gelu_out = cast_from_fp8(gelu_out_fp8,
                                 fp8_meta,
                                 FP8FwdTensors.GEMM1_INPUT,
                                 itype=fp8_dtype,
                                 otype=tex.DType.kFloat32)

        gelu_ref = paddle.nn.GELU()(a)

        assert_allclose(gelu_out, gelu_ref, rtol=0.1, atol=0.01)

    @staticmethod
    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize('fp8_dtype', [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2])
    def test_gelu_bwd_fp8(fp8_dtype):
        """
        Test FP8 GELU Backward

        Compares the fused dGELU + cast + transpose + bias-grad kernel against
        paddle autograd through ``paddle.nn.GELU``.
        """
        # y = GELU(x), calculate ref
        x = paddle.rand(shape=(16, 32), dtype='float32') * 2 - 1
        x.stop_gradient = False
        y = paddle.nn.GELU()(x)
        y_grad = paddle.rand(shape=(16, 32), dtype='float32') * 2 - 1
        # Third positional argument is retain_graph.
        paddle.autograd.backward([y], [y_grad], True)
        # calculate fp8
        fp8_meta = create_fp8_meta()
        x_grad_fp8, x_grad_t_fp8, dbias = dgelu_cast_transpose_bgrad_fp8(y_grad,
                                                                         x,
                                                                         fp8_meta,
                                                                         FP8FwdTensors.GEMM1_INPUT,
                                                                         otype=fp8_dtype)

        x_grad = cast_from_fp8(x_grad_fp8,
                               fp8_meta,
                               FP8FwdTensors.GEMM1_INPUT,
                               itype=fp8_dtype,
                               otype=tex.DType.kFloat32)

        x_grad_t = cast_from_fp8(x_grad_t_fp8,
                                 fp8_meta,
                                 FP8FwdTensors.GEMM1_INPUT,
                                 itype=fp8_dtype,
                                 otype=tex.DType.kFloat32)

        assert_allclose(x_grad, x.grad, rtol=0.1, atol=0.01)
        assert_allclose(x_grad_t, x.grad.T, rtol=0.1, atol=0.01)
        # The bias gradient is the column sum of the input gradient.
        assert_allclose(dbias, x.grad.sum(axis=0), rtol=0.1, atol=0.01)


class TestGemm:
    """
    Tests for gemm(cuBLASLt) operator
    """

    @staticmethod
    @pytest.mark.skipif(paddle.device.cuda.get_device_capability() < (8, 0),
                        reason="BF16 GEMM requires Ampere+ GPU")
    @pytest.mark.parametrize('m,n,k', GEMM_CASES)
    def test_bf16(m, n, k):
        """
        Test "TN" BF16 GEMM
        """
        a = paddle.rand(shape=(m, k), dtype='bfloat16')
        b = paddle.rand(shape=(n, k), dtype='bfloat16')

        # 32 MiB scratch buffer handed to the GEMM as its workspace.
        workspace = paddle.zeros(shape=[33_554_432], dtype='uint8')

        ref_out = paddle.matmul(a, b.T)
        # CublasLt inside tex.te_gemm assumes inputs are column major.
        # Mathematically, A@B=C is equivalent to B^T@A^T=C^T, where X^T is the
        # transpose of X.
        # Here we perform "TN" GEMM in column major, i.e., b@a^T = C^T,
        # which is equivalent to a@b^T = C in row major.
        actual_out, _, _ = gemm(b, a, paddle.bfloat16, workspace, False, None, False, False, "TN",
                                None, None, False)

        assert_allclose(actual_out, ref_out, rtol=1.6e-2, atol=1e-5)

    @staticmethod
    @pytest.mark.skipif(paddle.device.cuda.get_device_capability() < (8, 0),
                        reason="BF16 GEMM requires Ampere+ GPU")
    @pytest.mark.parametrize('m,n,k', GEMM_CASES)
    def test_bf16_inplace(m, n, k):
        """
        Test "TN" BF16 GEMM, with accumulate=True

        The GEMM result is accumulated into a copy of ``c``, so the expected output
        is ``c + a @ b^T``.
        """
        min_val = -16
        max_val = 16
        a = paddle.rand(shape=(m, k), dtype='bfloat16')
        b = paddle.rand(shape=(n, k), dtype='bfloat16')
        c = paddle.cast(paddle.randint(min_val, max_val, shape=(m, n)), 'bfloat16')
        workspace = paddle.zeros(shape=[33_554_432], dtype='uint8')

        ref_out = c + paddle.matmul(a, b.T)

        actual_out = paddle.clone(c)
        # Same call as test_bf16, but with accumulate=True and actual_out as the output buffer.
        _, _, _ = gemm(b, a, paddle.bfloat16, workspace, False, None, False, True, "TN", actual_out,
                       None, False)

        assert_allclose(actual_out, ref_out, rtol=5e-2, atol=5e-2)

    @staticmethod
    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize('m,n,k', GEMM_CASES)
    def test_fp8_randint(m, n, k):
        """
        Test "TN" FP8 GEMM

        Inputs are small integers, so they are exactly representable in E4M3, assuming
        unit FP8 scaling (TODO confirm). The FP8 GEMM is compared against a float32
        matmul with default tolerances.
        """
        min_val = -4
        max_val = 4
        fp8_dtype = tex.DType.kFloat8E4M3
        out_dtype = paddle.float32
        fp8_meta = create_fp8_meta(num_gemms=1)

        a = paddle.cast(paddle.randint(min_val, max_val, shape=(m, k)), 'float32')
        a_casted = cast_to_fp8(a, fp8_meta, FP8FwdTensors.GEMM1_INPUT, otype=fp8_dtype)
        b = paddle.cast(paddle.randint(min_val, max_val, shape=(n, k)), 'float32')
        b_casted = cast_to_fp8(b, fp8_meta, FP8FwdTensors.GEMM1_WEIGHT, otype=fp8_dtype)
        workspace = paddle.zeros(shape=[33_554_432], dtype='uint8')

        ref_out = paddle.matmul(a, b.T)
        actual_out = fp8_gemm(b_casted, fp8_meta.scale_inv, FP8FwdTensors.GEMM1_WEIGHT, fp8_dtype,
                              a_casted, fp8_meta.scale_inv, FP8FwdTensors.GEMM1_INPUT, fp8_dtype,
                              out_dtype, workspace)

        assert_allclose(actual_out, ref_out)


class TestLayerNorm:
    """
    Test layernorm operators
    """

    @staticmethod
    def calc_fwd_ref(x, eps, gamma, beta):
        """
        Calculate reference using paddle layer_norm op

        Normalizes over every axis after the first. Returns ``(y, mean, rsigma)``.
        ``mean`` and ``rsigma`` are taken over the last axis, which matches the
        normalized axes for the 2-D inputs used in this class.
        """
        y = paddle.nn.functional.layer_norm(x=x,
                                            normalized_shape=x.shape[1:],
                                            weight=gamma,
                                            bias=beta,
                                            epsilon=eps)
        mean = paddle.mean(x, axis=-1)
        # Layer norm divides by sqrt(Var[x] + eps) using the biased (population) variance,
        # so rsigma must be computed the same way. The unbiased variance (divisor H - 1)
        # without eps differs from it by a factor of about sqrt((H - 1) / H).
        var = paddle.var(x, axis=-1, unbiased=False)
        inv_var = paddle.sqrt(1. / (var + eps))
        return y, mean, inv_var

    @staticmethod
    def calc_bwd_ref(x, eps, gamma, beta, dy):
        """
        Calculate reference using paddle layer_norm op

        Back-propagates ``dy`` through paddle's layer_norm and returns
        ``(dx, dgamma, dbeta)``. Note: this sets ``stop_gradient = False`` on the
        caller's ``x``, ``gamma`` and ``beta``.
        """
        x.stop_gradient = False
        gamma.stop_gradient = False
        beta.stop_gradient = False

        y = paddle.nn.functional.layer_norm(x=x,
                                            normalized_shape=x.shape[1:],
                                            weight=gamma,
                                            bias=beta,
                                            epsilon=eps)

        paddle.autograd.backward([y], [dy], True)

        return x.grad, gamma.grad, beta.grad

    def test_layernorm_fwd(self):
        """
        Test BF16 LayerNorm Forward
        """
        N, H = (16, 32)
        eps = 1e-3
        x = paddle.uniform(shape=(N, H), dtype='bfloat16')
        gamma = paddle.uniform(shape=(H,), dtype='bfloat16')
        beta = paddle.uniform(shape=(H,), dtype='bfloat16')

        y, mu, rsigma = layernorm_fwd(x, gamma, beta, eps, tex.DType.kBFloat16)

        y_ref, mu_ref, rsigma_ref = self.calc_fwd_ref(x, eps, gamma, beta)

        assert_allclose(y, y_ref, rtol=1e-4, atol=1e-4)
        assert_allclose(mu, mu_ref, rtol=1e-3, atol=1e-3)
        # The reference statistics are computed in bf16, hence the loose tolerance.
        assert_allclose(rsigma, rsigma_ref, rtol=5e-2, atol=5e-2)

    @staticmethod
    def test_layernorm_fwd_fp8():
        """
        Test FP8 LayerNorm Forward

        TE's own float32 layernorm is the reference. The output comparison therefore
        checks the FP8 cast error, and mu/rsigma must agree with the default tolerances.
        """
        fp8_dtype = tex.DType.kFloat8E4M3
        N, H = (16, 32)
        eps = 1e-3

        x = paddle.uniform(shape=(N, H), dtype='float32')
        gamma = paddle.uniform(shape=(H,), dtype='float32')
        beta = paddle.uniform(shape=(H,), dtype='float32')

        fp8_tensor = FP8FwdTensors.GEMM1_INPUT
        fp8_meta = create_fp8_meta()

        y_ref, mu_ref, rsigma_ref = layernorm_fwd(x, gamma, beta, eps, tex.DType.kFloat32)

        y_fp8, mu, rsigma = layernorm_fwd_fp8(x, gamma, beta, eps, fp8_meta, fp8_tensor, fp8_dtype)

        y = cast_from_fp8(y_fp8, fp8_meta, fp8_tensor, itype=fp8_dtype, otype=tex.DType.kFloat32)

        assert_allclose(y, y_ref, rtol=0.1, atol=0.01)
        assert_allclose(mu, mu_ref)
        assert_allclose(rsigma, rsigma_ref)

    def test_layernorm_bwd(self):
        """
        Test BF16 LayerNorm Backward
        """
        N, H = (16, 32)
        eps = 1e-3
        x = paddle.uniform(shape=(N, H), dtype='bfloat16')
        dy = paddle.uniform(shape=(N, H), dtype='bfloat16')
        gamma = paddle.uniform(shape=(H,), dtype='bfloat16')
        beta = paddle.uniform(shape=(H,), dtype='bfloat16')

        dx_ref, dgamma_ref, dbeta_ref = self.calc_bwd_ref(x, eps, gamma, beta, dy)

        # The backward kernel consumes the mean and rsigma saved by the forward pass.
        _, mu, rsigma = layernorm_fwd(x, gamma, beta, eps, tex.DType.kBFloat16)
        dx, dgamma, dbeta = layernorm_bwd(dy, x, mu, rsigma, gamma)

        assert_allclose(dx, dx_ref, rtol=1e-5, atol=1e-5)
        assert_allclose(dgamma, dgamma_ref, rtol=1e-5, atol=1e-5)
        assert_allclose(dbeta, dbeta_ref, rtol=1e-5, atol=1e-5)


class TestRMSNorm:
    """
    Test rmsnorm operators
    """

    @staticmethod
    def calc_fwd_ref(x, eps, gamma):
        """
        Calculate rmsnorm reference using paddle op

        y = x / sqrt(mean(x ** 2 over the last axis) + eps) * gamma
        """

        norm = paddle.rsqrt(paddle.mean(x**2, axis=-1, keepdim=True) + eps)
        y = x * norm * gamma

        return y

    def calc_bwd_ref(self, x, eps, gamma, dy):
        """
        Calculate rmsnorm bwd reference using paddle op

        Back-propagates ``dy`` through ``calc_fwd_ref`` and returns ``(dx, dgamma)``.
        Note: this sets ``stop_gradient = False`` on the caller's ``x`` and ``gamma``.
        """
        x.stop_gradient = False
        gamma.stop_gradient = False

        y = self.calc_fwd_ref(x, eps, gamma)

        paddle.autograd.backward([y], [dy], True)

        return x.grad, gamma.grad

    def test_rmsnorm_fwd(self):
        """
        Test BF16 RMSNorm Forward
        """
        N, H = (16, 32)
        eps = 1e-3
        x = paddle.uniform(shape=(N, H), dtype='bfloat16')
        gamma = paddle.uniform(shape=(H,), dtype='bfloat16')

        y, _ = rmsnorm_fwd(x, gamma, eps, tex.DType.kBFloat16)

        y_ref = self.calc_fwd_ref(x, eps, gamma)

        assert_allclose(y, y_ref, rtol=1e-2, atol=1e-2)

    @staticmethod
    def test_rmsnorm_fwd_fp8():
        """
        Test FP8 RMSNorm Forward

        TE's own float32 rmsnorm is the reference, so the output comparison checks
        the FP8 cast error.
        """
        fp8_dtype = tex.DType.kFloat8E4M3
        N, H = (16, 32)
        eps = 1e-3

        x = paddle.uniform(shape=(N, H), dtype='float32')
        gamma = paddle.uniform(shape=(H,), dtype='float32')

        fp8_tensor = FP8FwdTensors.GEMM1_INPUT
        fp8_meta = create_fp8_meta()

        y_ref, rsigma_ref = rmsnorm_fwd(x, gamma, eps, tex.DType.kFloat32)

        y_fp8, rsigma = rmsnorm_fwd_fp8(x, gamma, eps, fp8_meta, fp8_tensor, fp8_dtype)

        y = cast_from_fp8(y_fp8, fp8_meta, fp8_tensor, itype=fp8_dtype, otype=tex.DType.kFloat32)

        assert_allclose(y, y_ref, rtol=0.1, atol=0.01)
        assert_allclose(rsigma, rsigma_ref)

    def test_rmsnorm_bwd(self):
        """
        Test BF16 RMSNorm Backward
        """
        N, H = (16, 32)
        eps = 1e-3
        x = paddle.uniform(shape=(N, H), dtype='bfloat16')
        dy = paddle.uniform(shape=(N, H), dtype='bfloat16')
        gamma = paddle.uniform(shape=(H,), dtype='bfloat16')

        dx_ref, dgamma_ref = self.calc_bwd_ref(x, eps, gamma, dy)

        # The backward kernel consumes the rsigma saved by the forward pass.
        _, rsigma = rmsnorm_fwd(x, gamma, eps, tex.DType.kBFloat16)
        dx, dgamma = rmsnorm_bwd(dy, x, rsigma, gamma)

        assert_allclose(dx, dx_ref, rtol=1e-2, atol=1e-2)
        assert_allclose(dgamma, dgamma_ref, rtol=1e-2, atol=5e-2)


class TestFusedAttn:
    """
    Test fused attention operators
    """

    def set_input(self, b, s_q, s_kv, h, d, dtype, attn_mode='self_attn', is_causal_masking=False):
        """
        set test input

        Args:
            b: batch size.
            s_q: max query sequence length.
            s_kv: max key/value sequence length.
            h: number of heads.
            d: head size.
            dtype: dtype name of the inputs. For "bfloat16", the data is sampled as float32
                and stored as uint16 bit patterns via ``convert_float_to_uint16``.
            attn_mode: "self_attn" (K and V reuse Q's data; requires s_q == s_kv) or any
                other value for cross attention with separately sampled K/V.
            is_causal_masking: use a causal mask over full-length sequences (self attention
                only) instead of a padding mask over random actual lengths.
        """

        def _random(shape):
            if self.dtype == "bfloat16":
                data = np.random.normal(loc=0.0, scale=0.02, size=shape).astype("float32")
                return convert_float_to_uint16(data)
            return np.random.normal(loc=0.0, scale=0.02, size=shape).astype(self.dtype)

        self.batch_size = b
        self.q_seqlen = s_q
        self.kv_seqlen = s_kv
        self.num_heads = h
        self.head_size = d
        self.dropout_prob = 0.0
        self.scaling_factor = 1.0 / np.sqrt(d)
        self.q_shape = (b, s_q, h, d)
        self.kv_shape = (b, s_kv, h, d)
        self.fuse_qkv_shape = (b, s_q, 3, h, d)
        self.fuse_kv_shape = (b, s_kv, 2, h, d)
        self.bias_shape = (1, h, s_q, s_kv)
        self.attn_mode = attn_mode
        # Must be set before the first _random() call, which reads self.dtype.
        self.dtype = dtype
        self.is_causal_masking = is_causal_masking

        self.q = _random(self.q_shape)
        if self.attn_mode == "self_attn":
            assert self.q_seqlen == self.kv_seqlen, "self attention requires q_seqlen == kv_seqlen"
            self.kv = self.q
        else:
            self.kv = _random(self.kv_shape)

        self.q_actual_seqlen = None
        if self.is_causal_masking:
            self.q_actual_seqlen = np.full(
                self.batch_size,
                self.q_seqlen,
                dtype=np.int32,
            )
        else:
            # Random valid lengths in [20, q_seqlen) so that padding is exercised.
            self.q_actual_seqlen = np.random.randint(
                low=20,
                high=self.q_seqlen,
                size=(self.batch_size,),
                dtype=np.int32,
            )
        self.kv_actual_seqlen = self.q_actual_seqlen

        # Cumulative offsets with a leading 0: sequence i occupies [cu[i], cu[i + 1]).
        self.q_cu_seqlen = np.cumsum(self.q_actual_seqlen)
        self.q_cu_seqlen = np.insert(self.q_cu_seqlen, 0, 0)
        self.kv_cu_seqlen = np.cumsum(self.kv_actual_seqlen)
        self.kv_cu_seqlen = np.insert(self.kv_cu_seqlen, 0, 0)

        # 1 = masked out, 0 = attended; the reference path replaces masked logits with -1e4.
        self.attn_mask = np.ones(
            shape=(self.batch_size, 1, self.q_seqlen, self.kv_seqlen),
            dtype=np.int32,
        )
        if self.is_causal_masking:
            assert attn_mode == "self_attn", "only support causal masking for self attention"
            for i in range(0, self.batch_size):
                for j in range(self.q_actual_seqlen[i]):
                    # Query j may attend to keys 0..j.
                    self.attn_mask[i, :, j, :j + 1] = 0
        else:
            for i in range(0, self.batch_size):
                self.attn_mask[i, :, :self.q_actual_seqlen[i], :self.kv_actual_seqlen[i]] = 0

        dout = _random((self.batch_size, self.q_seqlen, self.num_heads, self.head_size))
        self.dout = paddle.to_tensor(dout, dtype=self.dtype)

    def _get_reference_out(self):
        """
        Compute reference attention output and Q/K/V gradients with plain paddle ops.

        Computes softmax(mask(Q * scaling_factor @ K^T)) @ V in [b, h, s, d] layout.
        It then back-propagates ``self.dout`` and returns ``(out, dq, dk, dv)``,
        with ``out`` in [b, s, h, d] layout.
        """
        paddle.disable_static(place=paddle.CUDAPlace(0))
        q_tensor = paddle.to_tensor(self.q, stop_gradient=False)
        k_tensor = paddle.to_tensor(self.kv, stop_gradient=False)
        v_tensor = paddle.to_tensor(self.kv, stop_gradient=False)

        q_out = paddle.transpose(x=q_tensor, perm=[0, 2, 1, 3])    # [b, s, h, d] -> [b, h, s, d]
        k_out = paddle.transpose(x=k_tensor, perm=[0, 2, 1, 3])    # [b, s, h, d] -> [b, h, s, d]
        v_out = paddle.transpose(x=v_tensor, perm=[0, 2, 1, 3])    # [b, s, h, d] -> [b, h, s, d]

        qk_out = paddle.matmul(
            x=q_out * self.scaling_factor,
            y=k_out,
            transpose_x=False,
            transpose_y=True,
        )

        # attn_mask is 1 where attention is disallowed (see set_input). Those logits are
        # replaced by -1e4 so that softmax gives them ~0 weight.
        attn_mask = paddle.to_tensor(self.attn_mask, stop_gradient=True).cast('bool')
        attn_mask_vals = paddle.full(qk_out.shape, -1e4, qk_out.dtype)
        attn_mask_out = paddle.where(attn_mask, attn_mask_vals, qk_out)
        # Softmax is evaluated in float32, then cast back to the input dtype.
        attn_mask_out = paddle.cast(attn_mask_out, 'float32')
        softmax_out = F.softmax(attn_mask_out)
        softmax_out = paddle.cast(softmax_out, self.dtype)

        if self.dropout_prob:
            # ``self.training`` is not assigned anywhere in the visible class. Default to
            # training mode so this branch does not raise AttributeError.
            dropout_out = F.dropout(
                softmax_out,
                self.dropout_prob,
                training=getattr(self, 'training', True),
                mode="upscale_in_train",
            )
            qkv_out = paddle.matmul(dropout_out, v_out)
        else:
            qkv_out = paddle.matmul(softmax_out, v_out)

        out = paddle.transpose(qkv_out, perm=[0, 2, 1, 3])    # [b, h, s, d] -> [b, s, h, d]

        paddle.autograd.backward(
            [out],
            [self.dout],
            retain_graph=True,
        )
        return out, q_tensor.grad, k_tensor.grad, v_tensor.grad

    def _get_fused_attention_out(self):
        """
        Run the fused attention forward and backward kernels on the current inputs.

        Self attention uses the packed-QKV kernels on a [b, s, 3, h, d] tensor. Otherwise
        Q is passed separately and K/V are packed as [b, s, 2, h, d]. Returns
        ``(out, q_grad, k_grad, v_grad)``; the gradients are sliced out of the packed
        gradient tensors along axis 2.
        """
        paddle.disable_static(place=paddle.CUDAPlace(0))

        if self.attn_mode == "self_attn":
            qkv = np.stack([self.q, self.kv, self.kv], axis=2)    # [b, s, 3, h, d]
            qkv_tensor = paddle.to_tensor(qkv, stop_gradient=False)
        else:
            q_tensor = paddle.to_tensor(self.q, stop_gradient=False)
            kv = np.stack([self.kv, self.kv], axis=2)    # [b, s, 2, h, d]
            kv_tensor = paddle.to_tensor(kv, stop_gradient=False)

        q_cu_seqlen_tensor = paddle.to_tensor(self.q_cu_seqlen, dtype="int32", stop_gradient=True)
        kv_cu_seqlen_tensor = paddle.to_tensor(self.kv_cu_seqlen, dtype="int32", stop_gradient=True)

        # Computed once and shared by backend selection and the packed-QKV kernels.
        attn_mask_type = "causal" if self.is_causal_masking else "padding"
        qkv_layout = ("bs3hd" if self.attn_mode == "self_attn" else "bshd_bs2hd")
        fused_attention_backend = get_fused_attention_backend(
            num_heads=self.num_heads,
            num_gqa_groups=self.num_heads,
            q_seqlen=self.q_seqlen,
            kv_seqlen=self.kv_seqlen,
            head_size=self.head_size,
            dtype=self.dtype,
            dropout=self.dropout_prob,
            qkv_layout=qkv_layout,
            bias_type="no_bias",
            mask_type=attn_mask_type,
        )

        qkv_dtype = tex.DType.kBFloat16 if self.dtype == "bfloat16" else tex.DType.kFloat16
        out, softmax_aux_tensor, q_grad, k_grad, v_grad = None, None, None, None, None
        if self.attn_mode == 'self_attn':
            out, softmax_aux_tensor, rng_state = fused_attn_fwd_qkvpacked(
                qkv_tensor,
                q_cu_seqlen_tensor,
                is_training=True,
                max_seqlen=self.q_seqlen,
                qkv_dtype=qkv_dtype,
                fused_attention_backend=fused_attention_backend,
                Bias=None,
                attn_scale=self.scaling_factor,
                dropout=self.dropout_prob,
                set_zero=False,
                attn_mask_type=attn_mask_type)
            dqkv, _ = fused_attn_bwd_qkvpacked(
                qkv_tensor,
                q_cu_seqlen_tensor,
                rng_state,
                out,
                self.dout,
                softmax_aux_tensor,
                max_seqlen=self.q_seqlen,
                qkv_dtype=qkv_dtype,
                fused_attention_backend=fused_attention_backend,
                attn_scale=self.scaling_factor,
                dropout=self.dropout_prob,
                set_zero=False,
                attn_mask_type=attn_mask_type)
            q_grad = dqkv[:, :, 0, :, :]
            k_grad = dqkv[:, :, 1, :, :]
            v_grad = dqkv[:, :, 2, :, :]
        else:    # attn_mode == 'cross_attn'
            # set_input only allows causal masking for self attention. No attn_mask_type is
            # passed here, so the kernels use their default (presumably padding -- confirm).
            out, softmax_aux_tensor, rng_state = fused_attn_fwd_kvpacked(
                q_tensor,
                kv_tensor,
                q_cu_seqlen_tensor,
                kv_cu_seqlen_tensor,
                is_training=True,
                max_seqlen_q=self.q_seqlen,
                max_seqlen_kv=self.kv_seqlen,
                qkv_dtype=qkv_dtype,
                fused_attention_backend=fused_attention_backend,
                Bias=None,
                attn_scale=self.scaling_factor,
                dropout=self.dropout_prob,
                set_zero=False)
            dq, dkv, _ = fused_attn_bwd_kvpacked(q_tensor,
                                                 kv_tensor,
                                                 q_cu_seqlen_tensor,
                                                 kv_cu_seqlen_tensor,
                                                 rng_state,
                                                 out,
                                                 self.dout,
                                                 softmax_aux_tensor,
                                                 fused_attention_backend=fused_attention_backend,
                                                 max_seqlen_q=self.q_seqlen,
                                                 max_seqlen_kv=self.kv_seqlen,
                                                 qkv_dtype=qkv_dtype,
                                                 attn_scale=self.scaling_factor,
                                                 dropout=self.dropout_prob,
                                                 set_zero=False)
            q_grad = dq
            k_grad = dkv[:, :, 0, :, :]
            v_grad = dkv[:, :, 1, :, :]

        return out, q_grad, k_grad, v_grad

    def _get_fused_attention_with_separate_qkv(self):
        """Run the fused attention forward and backward kernels on separate Q, K, V.

        Uses the ``bshd_bshd_bshd`` layout and the state prepared by ``set_input``
        (``self.q``, ``self.kv``, cumulative sequence lengths, ``self.dout``, ...).

        Returns:
            tuple: ``(out, dq, dk, dv)``, the forward output followed by the
            gradients with respect to Q, K and V.
        """
        paddle.disable_static(place=paddle.CUDAPlace(0))

        # K and V are both created from self.kv.
        query = paddle.to_tensor(self.q, stop_gradient=False)
        key = paddle.to_tensor(self.kv, stop_gradient=False)
        value = paddle.to_tensor(self.kv, stop_gradient=False)

        cu_seqlens_q = paddle.to_tensor(self.q_cu_seqlen, dtype="int32", stop_gradient=True)
        cu_seqlens_kv = paddle.to_tensor(self.kv_cu_seqlen, dtype="int32", stop_gradient=True)

        layout = "bshd_bshd_bshd"
        mask_type = "causal" if self.is_causal_masking else "padding"
        backend = get_fused_attention_backend(
            num_heads=self.num_heads,
            num_gqa_groups=self.num_heads,
            q_seqlen=self.q_seqlen,
            kv_seqlen=self.kv_seqlen,
            head_size=self.head_size,
            dtype=self.dtype,
            dropout=self.dropout_prob,
            qkv_layout=layout,
            bias_type="no_bias",
            mask_type=mask_type,
        )

        qkv_dtype = tex.DType.kBFloat16 if self.dtype == "bfloat16" else tex.DType.kFloat16

        # Keyword arguments common to the forward and the backward kernel.
        shared_kwargs = {
            "max_seqlen_q": self.q_seqlen,
            "max_seqlen_kv": self.kv_seqlen,
            "qkv_dtype": qkv_dtype,
            "fused_attention_backend": backend,
            "attn_scale": self.scaling_factor,
            "dropout": self.dropout_prob,
            "set_zero": False,
            "qkv_layout": layout,
            "attn_mask_type": mask_type,
        }

        out, softmax_aux, rng_state = fused_attn_fwd(
            query,
            key,
            value,
            cu_seqlens_q,
            cu_seqlens_kv,
            is_training=True,
            Bias=None,
            **shared_kwargs,
        )
        dq, dk, dv, _ = fused_attn_bwd(
            query,
            key,
            value,
            cu_seqlens_q,
            cu_seqlens_kv,
            rng_state,
            out,
            self.dout,
            softmax_aux,
            **shared_kwargs,
        )

        return out, dq, dk, dv

Shijie's avatar
Shijie committed
850
851
852
853
854
855
856
    @pytest.mark.parametrize('b, s, h, d', SELF_ATTN_CASES)
    @pytest.mark.parametrize('dtype', ['float16', 'bfloat16'])
    @pytest.mark.parametrize('is_causal_masking', [True, False])
    def test_self_attn_forward_backward(self, b, s, h, d, dtype, is_causal_masking):
        """
        Self attention with packed QKV: compare the fused kernel's output and
        Q/K/V gradients against the reference implementation.
        """
        mask_type = "causal" if is_causal_masking else "padding"
        supported = is_fused_attention_supported(
            num_heads=h,
            num_gqa_groups=h,
            q_seqlen=s,
            kv_seqlen=s,
            head_size=d,
            dtype=dtype,
            dropout=0.0,
            qkv_layout="bs3hd",
            bias_type="no_bias",
            mask_type=mask_type,
        )
        if not supported:
            pytest.skip("cuDNN fused attention is not supported")

        self.set_input(b, s, s, h, d, dtype, "self_attn", is_causal_masking)
        ref_out, ref_dq, ref_dk, ref_dv = self._get_reference_out()
        out, dq, dk, dv = self._get_fused_attention_out()
        # Checked in order: output, then dQ, dK, dV.
        for expected, actual in ((ref_out, out), (ref_dq, dq), (ref_dk, dk), (ref_dv, dv)):
            assert_allclose(expected, actual, rtol=1e-3, atol=1e-2)

    @pytest.mark.parametrize('b, s_q, s_kv, h, d', CROSS_ATTN_CASES)
    @pytest.mark.parametrize('dtype', ['float16', 'bfloat16'])
    def test_cross_attn_forward_backward(self, b, s_q, s_kv, h, d, dtype):
        """
        Cross attention with separate Q and packed KV: compare the fused
        kernel's output and Q/K/V gradients against the reference implementation.
        """
        supported = is_fused_attention_supported(
            num_heads=h,
            num_gqa_groups=h,
            q_seqlen=s_q,
            kv_seqlen=s_kv,
            head_size=d,
            dtype=dtype,
            dropout=0.0,
            qkv_layout="bshd_bs2hd",
            bias_type="no_bias",
            mask_type="padding",
        )
        if not supported:
            pytest.skip("cuDNN fused attention is not supported")

        self.set_input(b, s_q, s_kv, h, d, dtype, "cross_attn")
        ref_out, ref_dq, ref_dk, ref_dv = self._get_reference_out()
        out, dq, dk, dv = self._get_fused_attention_out()
        # Checked in order: output, then dQ, dK, dV.
        for expected, actual in ((ref_out, out), (ref_dq, dq), (ref_dk, dk), (ref_dv, dv)):
            assert_allclose(expected, actual, rtol=1e-3, atol=1e-2)

    @pytest.mark.parametrize('b, s, h, d', FLASH_ATTN_CASES)
    @pytest.mark.parametrize('dtype', ['float16', 'bfloat16'])
    @pytest.mark.parametrize('is_causal_masking', [True])
    def test_flash_attn_forward_backward(self, b, s, h, d, dtype, is_causal_masking):
        """
        Self attention on the FLASH_ATTN_CASES shapes: compare the fused
        kernel's output and Q/K/V gradients against the reference implementation.
        """
        mask_type = "causal" if is_causal_masking else "padding"
        supported = is_fused_attention_supported(
            num_heads=h,
            num_gqa_groups=h,
            q_seqlen=s,
            kv_seqlen=s,
            head_size=d,
            dtype=dtype,
            dropout=0.0,
            qkv_layout="bs3hd",
            bias_type="no_bias",
            mask_type=mask_type,
        )
        if not supported:
            pytest.skip("cuDNN fused attention is not supported")

        self.set_input(b, s, s, h, d, dtype, "self_attn", is_causal_masking)
        ref_out, ref_dq, ref_dk, ref_dv = self._get_reference_out()
        out, dq, dk, dv = self._get_fused_attention_out()
        # Checked in order: output, then dQ, dK, dV.
        for expected, actual in ((ref_out, out), (ref_dq, dq), (ref_dk, dk), (ref_dv, dv)):
            assert_allclose(expected, actual, rtol=1e-3, atol=1e-2)

Shijie's avatar
Shijie committed
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
    @pytest.mark.parametrize('b, s, h, d', FLASH_ATTN_CASES)
    @pytest.mark.parametrize('dtype', ['float16', 'bfloat16'])
    @pytest.mark.parametrize('is_causal_masking', [False, True])
    def test_fused_attn_with_separate_qkv_forward_backward(self, b, s, h, d, dtype,
                                                           is_causal_masking):
        """
        Fused attention with separate Q, K and V inputs (``bshd_bshd_bshd``
        layout): compare the output and Q/K/V gradients against the reference
        implementation.
        """
        mask_type = "causal" if is_causal_masking else "padding"
        supported = is_fused_attention_supported(
            num_heads=h,
            num_gqa_groups=h,
            q_seqlen=s,
            kv_seqlen=s,
            head_size=d,
            dtype=dtype,
            dropout=0.0,
            qkv_layout="bshd_bshd_bshd",
            bias_type="no_bias",
            mask_type=mask_type,
        )
        if not supported:
            pytest.skip("cuDNN fused attention is not supported")

        self.set_input(b, s, s, h, d, dtype, "self_attn", is_causal_masking)
        ref_out, ref_dq, ref_dk, ref_dv = self._get_reference_out()
        out, dq, dk, dv = self._get_fused_attention_with_separate_qkv()
        # Checked in order: output, then dQ, dK, dV.
        for expected, actual in ((ref_out, out), (ref_dq, dq), (ref_dk, dk), (ref_dv, dv)):
            assert_allclose(expected, actual, rtol=1e-3, atol=1e-2)

Shijie's avatar
Shijie committed
962
963
964
965
966
967

class TestSoftmax:
    """
    Test softmax operators

    Each case compares a fused softmax kernel with ``F.softmax`` applied to the
    equivalently scaled (and masked) logits. It checks both the forward output
    and the gradient produced from a random upstream gradient ``dy``.
    """

    @staticmethod
    @pytest.mark.parametrize('dtype', ['float16', 'bfloat16'])
    def test_scaled_softmax_fwd_bwd(dtype):
        """test scaled softmax"""
        batch, heads, seqlen = 16, 4, 32
        scale = 0.8
        shape = (batch, heads, seqlen, seqlen)

        x = paddle.uniform(shape=shape, dtype=dtype)
        x.stop_gradient = False
        dy = paddle.uniform(shape=shape, dtype=dtype)

        # Reference: plain softmax over the scaled logits.
        y_ref = F.softmax(scale * x)
        y = scaled_softmax_forward(x, scale)

        # Third positional argument is retain_graph=True.
        paddle.autograd.backward([y_ref], [dy], True)
        dx_ref = x.grad
        dx = scaled_softmax_backward(dy, y, scale)

        for expected, actual in ((y_ref, y), (dx_ref, dx)):
            assert_allclose(expected, actual, rtol=1e-4, atol=1e-3)

    @staticmethod
    @pytest.mark.parametrize('dtype', ['float16', 'bfloat16'])
    def test_scaled_masked_softmax_fwd_bwd(dtype):
        """test scaled masked softmax"""
        batch, heads, seqlen = 16, 4, 32
        scale = 0.8
        shape = (batch, heads, seqlen, seqlen)

        x = paddle.uniform(shape=shape, dtype=dtype)
        x.stop_gradient = False
        dy = paddle.uniform(shape=shape, dtype=dtype)

        # Positions where the first (batch, head) slice of x exceeds 0.3 are masked.
        logits0 = x[0, 0]
        mask = paddle.reshape(logits0 > 0.3, shape=(1, 1, seqlen, seqlen))
        # Additive reference mask of shape (S, S): 0 where kept, -1e4 where masked.
        mask_ref = ((logits0 <= 0.3).astype(dtype) - 1.0) * 1e4

        y_ref = F.softmax(scale * x + mask_ref)
        y = scaled_masked_softmax_forward(x, mask, scale)

        paddle.autograd.backward([y_ref], [dy], True)
        dx_ref = x.grad
        dx = scaled_masked_softmax_backward(dy, y, scale)

        for expected, actual in ((y_ref, y), (dx_ref, dx)):
            assert_allclose(expected, actual, rtol=1e-4, atol=1e-3)

    @staticmethod
    @pytest.mark.parametrize('dtype', ['float16', 'bfloat16'])
    def test_scaled_upper_triang_masked_softmax_fwd_bwd(dtype):
        """test scaled upper triang masked softmax"""
        batch, seqlen = 16, 32
        scale = 0.8
        shape = (batch, seqlen, seqlen)

        x = paddle.uniform(shape=shape, dtype=dtype)
        x.stop_gradient = False
        dy = paddle.uniform(shape=shape, dtype=dtype)

        # 1 on and below the diagonal (kept), 0 strictly above it (masked).
        causal = paddle.to_tensor(np.tril(np.ones((seqlen, seqlen), dtype='int32')))
        mask_ref = (causal.astype(dtype) - 1.0) * 1e4

        y_ref = F.softmax(scale * x + mask_ref)
        y = scaled_upper_triang_masked_softmax_forward(x, scale)

        paddle.autograd.backward([y_ref], [dy], True)
        dx_ref = x.grad
        dx = scaled_upper_triang_masked_softmax_backward(dy, y, scale)

        for expected, actual in ((y_ref, y), (dx_ref, dx)):
            assert_allclose(expected, actual, rtol=1e-4, atol=5e-3)
1041
1042


1043
1044
@pytest.mark.parametrize('update_weight_scale_inv', [True, False])
def test_amax_and_scale_update(update_weight_scale_inv):
    """Test update_scale

    Compares ``tex.amax_and_scale_update_inplace`` with a paddle reference:

    * scale is ``fp8_max / amax / 2**margin``, where ``amax`` is the per-GEMM
      (column-wise) max of the history. It falls back to the previous scale
      when amax is non-positive or non-finite.
    * scale_inv is ``1 / scale`` for every GEMM when ``update_weight_scale_inv``
      is True. Otherwise it is refreshed only where ``non_weight_mask`` is
      True, and the other entries are expected to stay at their initial zero.
    * the amax history is rolled by -1 along axis 0 (old row 0 wraps to the
      end), and row 0 is then reset to zero.
    """
    num_gemm = 6
    history_len = 1024

    recipe = DelayedScaling()
    fp8_max = recipe.fp8_format.value.max_fwd
    # Alternates True / False across the GEMM slots.
    non_weight_mask = paddle.to_tensor([True, False] * (num_gemm // 2))

    amax_history_tensor = paddle.rand(shape=[history_len, num_gemm], dtype='float32')
    rolled_history_ref = paddle.roll(amax_history_tensor, -1, axis=0)
    rolled_history_ref[0] = 0.0
    # Taken before the kernel call, which updates amax_history_tensor in place.
    amax_tensor = paddle.max(amax_history_tensor, axis=0)
    # Scale in effect before the update.
    scale_tensor = paddle.ones(shape=[num_gemm], dtype='float32')

    def calc_ref(amax, scale, fp8_max, margin=0):
        """Calculate reference scale"""
        sf = (fp8_max / amax) / (2**margin)
        # Keep the previous scale when amax cannot yield a meaningful factor.
        sf = paddle.where(amax > 0.0, sf, scale)
        sf = paddle.where(paddle.isfinite(amax), sf, scale)
        return sf

    scale_ref = calc_ref(amax_tensor, scale_tensor, fp8_max, 0.)
    if update_weight_scale_inv:
        scale_inv_ref = 1. / scale_ref
    else:
        scale_inv_ref = paddle.zeros_like(scale_tensor)
        scale_inv_ref = paddle.where(non_weight_mask, 1. / scale_ref, scale_inv_ref)

    # Buffers the kernel updates in place. The scale buffer starts from the same
    # previous scale that the reference falls back to, so a fallback in the kernel
    # is compared like for like with calc_ref.
    scale_actual = scale_tensor.clone()
    scale_inv_actual = paddle.zeros_like(scale_tensor)

    tex.amax_and_scale_update_inplace(_amax_history=amax_history_tensor,
                                      _scale=scale_actual,
                                      _scale_inv=scale_inv_actual,
                                      non_weight_mask=non_weight_mask,
                                      update_weight_scale_inv=update_weight_scale_inv,
                                      fp8_max=fp8_max,
                                      margin=0.,
                                      amax_compute="max")

    assert_allclose(scale_actual, scale_ref, rtol=1e-7, atol=1e-7)
    assert_allclose(scale_inv_actual, scale_inv_ref, rtol=1e-7, atol=1e-7)
    assert_allclose(amax_history_tensor, rolled_history_ref, rtol=1e-7, atol=1e-7)


def test_update_latest_history():
    """Test update_latest_history

    After ``tex.update_latest_amax_history_inplace`` runs, row 0 of the history
    must equal the amax values that were passed in.
    """
    gemm_count, history_length = 6, 1024

    history = paddle.rand(shape=[history_length, gemm_count], dtype='float32')
    latest_amax = paddle.rand(shape=[gemm_count], dtype='float32')

    tex.update_latest_amax_history_inplace(_history=history, amax=latest_amax)

    assert_allclose(history[0], latest_amax, rtol=1e-7, atol=1e-7)