# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""TE FP8 extensions and GEMMs"""

import math
from typing import Optional, Tuple, Union

import paddle
import paddle.nn.functional as F

from transformer_engine import transformer_engine_paddle as tex
from .constants import TE_DType, FusedAttnBackend, FP8FwdTensors, FP8BwdTensors
from .fp8 import FP8TensorMeta, get_global_fp8_state

# Constants used to compute ``rng_elts_per_thread`` for the F16 fused-attention
# backends.
# F16_max512_seqlen: the per-thread RNG element count is
#   ceil(max_seqlen_q * max_seqlen_kv / BACKEND_F16m512_THREADS_PER_CTA).
BACKEND_F16m512_THREADS_PER_CTA: int = 128
# F16_arbitrary_seqlen: a fixed per-thread RNG element count.
BACKEND_F16arb_ELTS_PER_THREADS: int = 16


def gemm(
    A: paddle.Tensor,
    B: paddle.Tensor,
    dtype: paddle.dtype,
    workspace: paddle.Tensor,
    gelu: bool = False,
    gelu_input: Optional[paddle.Tensor] = None,
    grad: bool = False,
    accumulate: bool = False,
    layout: str = "TN",
    out: Optional[paddle.Tensor] = None,
    out_dtype: Optional[paddle.dtype] = None,
    bias: Optional[paddle.Tensor] = None,
    use_bias: bool = False,
) -> Tuple[Union[paddle.Tensor, None], ...]:
    """Non FP8 GEMM.

    Multiplies two higher-precision operands with ``tex.te_gemm``.

    Args:
        A, B: operands; both must have dtype ``dtype``.
        dtype: paddle dtype of the operands.
        workspace: scratch buffer; ``workspace.shape[0]`` is passed as its size.
        gelu: request the GELU epilogue buffer (see Returns).
        gelu_input: forwarded as-is only when ``gelu`` and ``grad`` are both
            set; otherwise it is replaced (new buffer or ``None``).
        grad: backward mode. Together with ``use_bias`` a bias-gradient buffer
            is allocated and passed to the kernel in place of ``bias``.
        accumulate: forwarded to the kernel; a freshly allocated ``out`` is
            zero-filled instead of left uninitialized.
        layout: two characters (A first, then B); ``"T"`` marks the operand as
            transposed. Only ``"TN"``, ``"NN"`` and ``"NT"`` are accepted.
        out: optional preallocated output. If omitted, a buffer of shape
            ``[B.shape[1] if B transposed else B.shape[0],
            A.shape[0] if A transposed else A.shape[1]]`` is allocated.
        out_dtype: dtype for a freshly allocated ``out`` (default ``dtype``).
        bias: bias tensor, ignored unless ``use_bias`` is set.
        use_bias: enable the bias (forward) / bias-gradient (backward) path.

    Returns:
        ``(out, grad_bias, gelu_input)``. ``grad_bias`` has length
        ``B.shape[1]`` and is ``None`` unless ``grad and use_bias``;
        ``gelu_input`` is ``None`` unless ``gelu`` is set.
    """
    assert layout in ("TN", "NN", "NT"), f"GEMM layout {layout} not supported."
    trans_a, trans_b = (flag == "T" for flag in layout)

    if out is None:
        rows = B.shape[1] if trans_b else B.shape[0]
        cols = A.shape[0] if trans_a else A.shape[1]
        # Accumulating GEMMs read the output buffer, so it must start at zero.
        allocate = paddle.zeros if accumulate else paddle.empty
        out = allocate(shape=[rows, cols], dtype=dtype if out_dtype is None else out_dtype)

    # No GELU: drop any buffer. Forward GELU: allocate a fresh one.
    # Backward GELU: keep the caller's gelu_input.
    if not gelu:
        gelu_input = None
    elif not grad:
        gelu_input = paddle.empty_like(out, dtype=dtype)

    needs_bias_grad = grad and use_bias
    grad_bias = paddle.empty(shape=[B.shape[1]], dtype=out.dtype) if needs_bias_grad else None

    if not use_bias:
        bias = None

    assert (
        A.dtype == dtype and B.dtype == dtype
    ), f"Expected dtype={dtype}, but found A.dtype={A.dtype} and B.dtype={B.dtype}"

    in_te_dtype = TE_DType[dtype]
    out_te_dtype = TE_DType[out.dtype]
    if not use_bias:
        bias_te_dtype = out_te_dtype
    elif grad:
        bias_te_dtype = TE_DType[grad_bias.dtype]
    else:
        bias_te_dtype = TE_DType[bias.dtype]

    tex.te_gemm(
        A,
        None,  # A_scale_inv (not FP8)
        B,
        None,  # B_scale_inv (not FP8)
        grad_bias if grad else bias,
        out,
        None,  # out_scale
        None,  # out_amax
        gelu_input,
        workspace,
        0,  # A_index
        0,  # B_index
        0,  # D_index
        int(in_te_dtype),
        int(in_te_dtype),
        int(out_te_dtype),
        int(bias_te_dtype),
        trans_a,
        trans_b,
        grad,
        workspace.shape[0],
        accumulate,
        False,  # use_split_accumulator
        0,  # math_sm_count
    )

    return out, grad_bias, gelu_input


def fp8_gemm(
    A: paddle.Tensor,
    A_scale_inv: paddle.Tensor,
    A_fp8_tensor: Union[FP8FwdTensors, FP8BwdTensors],
    A_dtype: tex.DType,
    B: paddle.Tensor,
    B_scale_inv: paddle.Tensor,
    B_fp8_tensor: Union[FP8FwdTensors, FP8BwdTensors],
    B_dtype: tex.DType,
    out_dtype: paddle.dtype,
    workspace: paddle.Tensor,
    gelu: bool = False,
    accumulate: bool = False,
    out: Optional[paddle.Tensor] = None,
    out_index=None,
    fp8_meta_tensor: Optional[FP8TensorMeta] = None,
    bias: Optional[paddle.Tensor] = None,
    use_bias: bool = False,
    use_split_accumulator: bool = False,
    D_dtype: Optional[tex.DType] = None,
) -> Tuple[paddle.Tensor, Optional[paddle.Tensor]]:
    """TN layout GEMM with fp8 inputs.

    ``A`` is always treated as transposed and ``B`` as not transposed, so a
    freshly allocated output has shape ``[B.shape[0], A.shape[0]]``.

    Args:
        A, B: FP8 operands, each with its inverse scale (``A_scale_inv`` /
            ``B_scale_inv``). ``A_fp8_tensor.value`` / ``B_fp8_tensor.value``
            are passed in the kernel's A/B index slots.
        A_dtype, B_dtype: TE dtypes of the operands.
        out_dtype: paddle dtype for a freshly allocated ``out``.
        workspace: scratch buffer; ``workspace.shape[0]`` is passed as its size.
        gelu: allocate a pre-GELU buffer (returned as ``gelu_input``) with
            ``bias``'s dtype, or bfloat16 when ``bias`` is None.
        accumulate: forwarded to the kernel; a freshly allocated ``out`` is
            zero-filled instead of left uninitialized.
        out: optional preallocated output.
        out_index: when set, ``fp8_meta_tensor.scale`` and
            ``fp8_meta_tensor.amax_history`` are passed as the output
            scale/amax, and ``out_index`` as the output index.
        fp8_meta_tensor: required when ``out_index`` is set.
        bias: bias tensor, forwarded only when ``use_bias`` is set.
        use_split_accumulator: forwarded to the kernel.
        D_dtype: TE dtype of the output (default: TE dtype of ``out``). An FP8
            ``D_dtype`` requires both ``fp8_meta_tensor`` and ``out_index``.

    Returns:
        ``(out, gelu_input)``; ``gelu_input`` is None unless ``gelu`` is set.
    """
    if D_dtype is not None and D_dtype in [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]:
        assert fp8_meta_tensor is not None and out_index is not None
    if out_index is not None:
        # Fail with a clear message instead of an AttributeError on None below.
        assert fp8_meta_tensor is not None, "fp8_meta_tensor is required when out_index is set."

    if out is None:
        # Accumulating GEMMs read the output buffer, so it must start at zero.
        allocate = paddle.zeros if accumulate else paddle.empty
        out = allocate(shape=[B.shape[0], A.shape[0]], dtype=out_dtype)

    # Use bfloat16 as default bias_dtype
    bias_paddle_dtype = paddle.bfloat16 if bias is None else bias.dtype
    gelu_input = paddle.empty_like(out, dtype=bias_paddle_dtype) if gelu else None
    bias_te_dtype = TE_DType[bias_paddle_dtype]

    out_te_dtype = TE_DType[out.dtype] if D_dtype is None else D_dtype

    tex.te_gemm(
        A,
        A_scale_inv,
        B,
        B_scale_inv,
        bias if use_bias else None,
        out,
        None if out_index is None else fp8_meta_tensor.scale,
        None if out_index is None else fp8_meta_tensor.amax_history,
        gelu_input,  # this is pre_gelu_out
        workspace,
        A_fp8_tensor.value,
        B_fp8_tensor.value,
        0 if out_index is None else out_index,
        int(A_dtype),
        int(B_dtype),
        int(out_te_dtype),
        int(bias_te_dtype),
        True,  # transa
        False,  # transb
        False,  # grad
        workspace.shape[0],
        accumulate,
        use_split_accumulator,
        0,  # math_sm_count
    )

    return out, gelu_input


def cast_to_fp8(
    inp: paddle.Tensor,
    fp8_meta_tensor: FP8TensorMeta,
    fp8_tensor: Union[FP8FwdTensors, FP8BwdTensors],
    otype: tex.DType,
    out: Optional[paddle.Tensor] = None,
) -> paddle.Tensor:
    """Cast input to FP8.

    The FP8 result is stored in a ``paddle.uint8`` tensor shaped like ``inp``.
    A caller-supplied ``out`` is validated against that shape and dtype;
    otherwise an uninitialized buffer is allocated.

    ``fp8_meta_tensor``'s scale, amax history and inverse scale, plus
    ``fp8_tensor.value``, are handed to ``tex.cast_to_fp8``. Presumably the
    kernel updates amax/scale_inv in place; verify against the extension.

    Returns:
        The uint8 output buffer.
    """
    if out is not None:
        assert out.shape == inp.shape, "Output shape does not match input shape."
        assert out.dtype == paddle.uint8, "Output should be of uint8 dtype."
    else:
        out = paddle.empty(shape=inp.shape, dtype=paddle.uint8)

    meta = fp8_meta_tensor
    tex.cast_to_fp8(
        inp,
        meta.scale,
        out,
        meta.amax_history,
        meta.scale_inv,
        fp8_tensor.value,
        int(otype),
    )
    return out


def cast_from_fp8(
    inp: paddle.Tensor,
    fp8_meta_tensor: FP8TensorMeta,
    fp8_tensor: Union[FP8FwdTensors, FP8BwdTensors],
    itype: tex.DType,
    otype: tex.DType,
) -> paddle.Tensor:
    """Cast input from FP8.

    Forwards ``inp``, ``fp8_meta_tensor.scale_inv``, ``fp8_tensor.value`` and
    the integer codes of ``itype``/``otype`` to ``tex.cast_from_fp8``, and
    returns its result unchanged.
    """
    in_code, out_code = int(itype), int(otype)
    return tex.cast_from_fp8(inp, fp8_meta_tensor.scale_inv, fp8_tensor.value, in_code, out_code)


def transpose(
    inp: paddle.Tensor,
    otype: tex.DType,
) -> paddle.Tensor:
    """Transpose input.

    Thin wrapper over ``tex.te_transpose``; ``otype`` is passed as its
    integer code and the extension's result is returned unchanged.
    """
    otype_code = int(otype)
    transposed = tex.te_transpose(inp, otype_code)
    return transposed


def cast_transpose(
    inp: paddle.Tensor,
    fp8_meta_tensor: FP8TensorMeta,
    fp8_tensor: Union[FP8FwdTensors, FP8BwdTensors],
    otype: tex.DType,
    cast_out: Optional[paddle.Tensor] = None,
    transpose_out: Optional[paddle.Tensor] = None,
) -> Union[Tuple[paddle.Tensor, paddle.Tensor], None]:
    """Cast + Transpose with FP8 output.

    Both outputs are ``paddle.uint8`` buffers. ``cast_out`` has ``inp``'s
    shape; ``transpose_out`` has shape ``[inp.shape[1], inp.shape[0]]``
    (so ``inp`` is expected to be 2-D). Caller-supplied buffers are validated,
    and missing ones are allocated uninitialized before ``tex.te_cast_transpose``
    is called on them.

    Returns:
        ``(cast_out, transpose_out)``.
    """
    if cast_out is not None:
        assert cast_out.shape == inp.shape, "cast_out shape does not match input shape."
        assert cast_out.dtype == paddle.uint8, "cast_out should be of uint8 dtype."
    else:
        cast_out = paddle.empty(shape=inp.shape, dtype=paddle.uint8)

    transposed_shape = [inp.shape[1], inp.shape[0]]
    if transpose_out is not None:
        assert (
            transpose_out.shape == transposed_shape
        ), "Transposed output shape does not match input shape."
        assert transpose_out.dtype == paddle.uint8, "Output should be of uint8 dtype."
    else:
        transpose_out = paddle.empty(shape=transposed_shape, dtype=paddle.uint8)

    meta = fp8_meta_tensor
    tex.te_cast_transpose(
        inp,
        meta.scale,
        cast_out,
        transpose_out,
        meta.amax_history,
        meta.scale_inv,
        fp8_tensor.value,
        int(otype),
    )

    return cast_out, transpose_out


def cast_transpose_bgrad(
    inp: paddle.Tensor,
    fp8_meta_tensor: FP8TensorMeta,
    fp8_tensor: Union[FP8FwdTensors, FP8BwdTensors],
    otype: tex.DType,
) -> Union[Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor], None]:
    """Fused Cast + Transpose + Bias Grad.

    Calls ``tex.te_cast_transpose_bgrad``, which returns five values. The
    first three are returned in the order ``(grad_bias, cast_out,
    transpose_out)``; the last two are discarded.
    """
    meta = fp8_meta_tensor
    results = tex.te_cast_transpose_bgrad(
        inp,
        meta.scale,
        meta.amax_history,
        meta.scale_inv,
        fp8_tensor.value,
        int(otype),
    )
    # Unpack exactly five values so an arity change in the extension fails loudly.
    bias_grad, fp8_out, fp8_out_t, _, _ = results
    return bias_grad, fp8_out, fp8_out_t


def te_gelu(
    inp: paddle.Tensor,
    otype: tex.DType,
) -> paddle.Tensor:
    """Non FP8 GELU.

    Thin wrapper over ``tex.te_gelu``; ``otype`` is passed as its integer code.
    """
    otype_code = int(otype)
    activated = tex.te_gelu(inp, otype_code)
    return activated


def gelu_fp8(
    inp: paddle.Tensor,
    fp8_meta_tensor: FP8TensorMeta,
    fp8_tensor: Union[FP8FwdTensors, FP8BwdTensors],
    otype: tex.DType,
) -> paddle.Tensor:
    """GELU + FP8 cast.

    Calls ``tex.te_gelu_fp8`` with the FP8 scaling state from
    ``fp8_meta_tensor`` and returns the first of its three results.
    """
    meta = fp8_meta_tensor
    fp8_result, _, _ = tex.te_gelu_fp8(
        inp,
        meta.scale,
        meta.amax_history,
        meta.scale_inv,
        fp8_tensor.value,
        int(otype),
    )
    return fp8_result


def swiglu(
    inp: paddle.Tensor,
    otype: tex.DType,
) -> paddle.Tensor:
    """Non FP8 SWIGLU.

    Thin wrapper over ``tex.te_swiglu``; ``otype`` is passed as its integer code.
    """
    otype_code = int(otype)
    activated = tex.te_swiglu(inp, otype_code)
    return activated


def swiglu_pd(
    inp: paddle.Tensor,
) -> paddle.Tensor:
    """Native SWIGLU.

    Pure-paddle reference: splits the last axis of ``inp`` into two chunks
    ``(gate, up)`` and returns ``silu(gate) * up``.
    """
    gate, up = paddle.chunk(inp, chunks=2, axis=-1)
    return F.silu(gate) * up


def swiglu_fp8(
    inp: paddle.Tensor,
    fp8_meta_tensor: FP8TensorMeta,
    fp8_tensor: Union[FP8FwdTensors, FP8BwdTensors],
    otype: tex.DType,
) -> paddle.Tensor:
    """SWIGLU + FP8 cast.

    Calls ``tex.te_swiglu_fp8`` with the FP8 scaling state from
    ``fp8_meta_tensor`` and returns the first of its three results.
    """
    meta = fp8_meta_tensor
    fp8_result, _, _ = tex.te_swiglu_fp8(
        inp,
        meta.scale,
        meta.amax_history,
        meta.scale_inv,
        fp8_tensor.value,
        int(otype),
    )
    return fp8_result


def dswiglu(
    grad_output: paddle.Tensor,
    swiglu_input: paddle.Tensor,
    otype: tex.DType,
) -> paddle.Tensor:
    """dSWIGLU.

    Thin wrapper over ``tex.te_dswiglu``; ``otype`` is passed as its integer code.
    """
    otype_code = int(otype)
    grad_input = tex.te_dswiglu(grad_output, swiglu_input, otype_code)
    return grad_input


def dgelu_cast_transpose_bgrad_fp8(
    grad_output: paddle.Tensor,
    gelu_input: paddle.Tensor,
    fp8_meta_tensor: FP8TensorMeta,
    fp8_tensor: Union[FP8FwdTensors, FP8BwdTensors],
    otype: tex.DType,
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
    """
    Fused dgelu + cast / transpose / reduce the result of
    the GELU backward along the first dimension.

    Calls ``tex.te_cast_transpose_bgrad_dgelu``, which returns five values,
    and returns the first three as ``(cast_dgelu, transpose_dgelu, dbias)``.
    """
    meta = fp8_meta_tensor
    dgelu_fp8, dgelu_fp8_t, bias_grad, _, _ = tex.te_cast_transpose_bgrad_dgelu(
        grad_output,
        gelu_input,
        meta.scale,
        meta.amax_history,
        meta.scale_inv,
        fp8_tensor.value,
        int(otype),
    )
    return dgelu_fp8, dgelu_fp8_t, bias_grad


def layernorm_fwd_fp8(
    inp: paddle.Tensor,
    weight: paddle.Tensor,
    bias: paddle.Tensor,
    eps: float,
    fp8_meta_tensor: FP8TensorMeta,
    fp8_tensor: Union[FP8FwdTensors, FP8BwdTensors],
    otype: tex.DType,
    sm_margin: int = 0,
    zero_centered_gamma: bool = False,
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
    """LayerNorm with FP8 output.

    Calls ``tex.te_layernorm_fwd_fp8`` with the FP8 scaling state from
    ``fp8_meta_tensor``. The extension returns five values; the first three
    are returned as ``(out, mu, rsigma)``.
    """
    meta = fp8_meta_tensor
    normed, mean, inv_std, _, _ = tex.te_layernorm_fwd_fp8(
        inp,
        weight,
        bias,
        meta.scale,
        meta.amax_history,
        meta.scale_inv,
        eps,
        fp8_tensor.value,
        int(otype),
        sm_margin,
        zero_centered_gamma,
    )
    return normed, mean, inv_std


def layernorm_fwd(
    inp: paddle.Tensor,
    weight: paddle.Tensor,
    bias: paddle.Tensor,
    eps: float,
    otype: tex.DType,
    sm_margin: int = 0,
    zero_centered_gamma: bool = False,
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
    """Non-FP8 LayerNorm forward.

    Thin wrapper over ``tex.te_layernorm_fwd``. ``otype`` is converted to its
    integer code, and the extension's result is returned unchanged.
    """
    otype_code = int(otype)
    result = tex.te_layernorm_fwd(
        inp, weight, bias, eps, otype_code, sm_margin, zero_centered_gamma
    )
    return result


def layernorm_bwd(
    dz: paddle.Tensor,
    x: paddle.Tensor,
    mu: paddle.Tensor,
    rsigma: paddle.Tensor,
    gamma: paddle.Tensor,
    sm_margin: int = 0,
    zero_centered_gamma: bool = False,
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
    """Non-FP8 LayerNorm backward.

    Thin wrapper over ``tex.te_layernorm_bwd``. All arguments are forwarded
    positionally, and the extension's result is returned unchanged.
    """
    grads = tex.te_layernorm_bwd(
        dz, x, mu, rsigma, gamma, sm_margin, zero_centered_gamma
    )
    return grads


def rmsnorm_fwd(
    inp: paddle.Tensor,
    weight: paddle.Tensor,
    eps: float,
    otype: tex.DType,
    sm_margin: int = 0,
    zero_centered_gamma: bool = False,
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
    """Non-FP8 RMSNorm forward.

    Thin wrapper over ``tex.te_rmsnorm_fwd``. ``otype`` is converted to its
    integer code, and the extension's result is returned unchanged.
    """
    otype_code = int(otype)
    result = tex.te_rmsnorm_fwd(inp, weight, eps, otype_code, sm_margin, zero_centered_gamma)
    return result


def rmsnorm_fwd_fp8(
    inp: paddle.Tensor,
    weight: paddle.Tensor,
    eps: float,
    fp8_meta_tensor: FP8TensorMeta,
    fp8_tensor: Union[FP8FwdTensors, FP8BwdTensors],
    otype: tex.DType,
    sm_margin: int = 0,
    zero_centered_gamma: bool = False,
) -> Tuple[paddle.Tensor, paddle.Tensor]:
    """RMSNorm with FP8 output.

    Calls ``tex.te_rmsnorm_fwd_fp8`` with the FP8 scaling state from
    ``fp8_meta_tensor``. The extension returns four values; only the first
    two are kept.

    Returns:
        ``(out, rsigma)``. This is a 2-tuple; the previous annotation wrongly
        advertised three elements.
    """
    out, rsigma, _, _ = tex.te_rmsnorm_fwd_fp8(
        inp,
        weight,
        fp8_meta_tensor.scale,
        fp8_meta_tensor.amax_history,
        fp8_meta_tensor.scale_inv,
        eps,
        fp8_tensor.value,
        int(otype),
        sm_margin,
        zero_centered_gamma,
    )
    return out, rsigma


def rmsnorm_bwd(
    dz: paddle.Tensor,
    x: paddle.Tensor,
    rsigma: paddle.Tensor,
    gamma: paddle.Tensor,
    sm_margin: int = 0,
    zero_centered_gamma: bool = False,
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
    """Non-FP8 RMSNorm backward.

    Thin wrapper over ``tex.te_rmsnorm_bwd``. All arguments are forwarded
    positionally, and the extension's result is returned unchanged.
    """
    grads = tex.te_rmsnorm_bwd(dz, x, rsigma, gamma, sm_margin, zero_centered_gamma)
    return grads


def mask_to_cu_seqlens(
    mask: paddle.Tensor,
    need_kv: bool = False,
) -> Tuple[paddle.Tensor, Optional[paddle.Tensor]]:
    """Convert mask to cu_seqlens.

    Args:
        mask: attention mask of shape ``[b, 1, s_q, s_kv]``.
        need_kv: also build the key/value cumulative sequence lengths.

    Returns:
        ``(q_cu_seqlens, kv_cu_seqlens)``: int32 tensors of length ``b + 1``.
        Element 0 is set to 0 here; the remaining entries are presumably
        filled by ``tex.mask_to_cu_seqlens`` (confirm against the extension).
        ``kv_cu_seqlens`` is ``None`` unless ``need_kv`` is set.

    Raises:
        RuntimeError: if CUDA graphs are enabled in the global FP8 state.
    """
    # mask shape: [b, 1, s_q, s_kv]
    if get_global_fp8_state().is_cudagraph_enabled():
        raise RuntimeError("mask_to_cu_seqlens is not supported with cuda graphs.")

    batch_size = mask.shape[0]
    q_seqlen, kv_seqlen = mask.shape[2], mask.shape[3]

    q_cu_seqlens = paddle.empty(shape=[batch_size + 1], dtype=paddle.int32)
    q_cu_seqlens[0] = 0
    kv_cu_seqlens = None
    if need_kv:
        kv_cu_seqlens = paddle.empty(shape=[batch_size + 1], dtype=paddle.int32)
        kv_cu_seqlens[0] = 0

    tex.mask_to_cu_seqlens(mask, q_cu_seqlens, kv_cu_seqlens, q_seqlen, kv_seqlen, need_kv)
    return q_cu_seqlens, kv_cu_seqlens


def fused_attn_fwd_qkvpacked(
    qkv: paddle.Tensor,
    cu_seqlens: paddle.Tensor,
    is_training: bool,
    max_seqlen: int,
    qkv_dtype: tex.DType,
    fused_attention_backend: tex.NVTE_Fused_Attn_Backend,
    Bias: Optional[paddle.Tensor] = None,
    attn_scale: Optional[float] = None,
    dropout: float = 0.0,
    set_zero: bool = True,
    qkv_layout: str = "bs3hd",
    bias_type: str = "no_bias",
    attn_mask_type: str = "padding",
) -> Tuple[paddle.Tensor, Optional[paddle.Tensor], paddle.Tensor]:
    """Fused Attention FWD for packed QKV input.

    Args:
        qkv: packed QKV tensor, read as rank 5. The head count comes from
            ``qkv.shape[3]`` and the head size from ``qkv.shape[4]``
            (presumably ``[b, s, 3, h, d]`` for the default ``"bs3hd"``
            layout; TODO confirm for other layouts).
        cu_seqlens: cumulative sequence lengths; batch size is
            ``cu_seqlens.shape[0] - 1``.
        is_training: allocate and return the softmax auxiliary tensor.
        max_seqlen: maximum sequence length.
        qkv_dtype: must be bf16 or fp16.
        fused_attention_backend: must not be ``No_Backend``.
        Bias: required when ``bias_type != "no_bias"``; shape
            ``[1, h, max_seqlen, max_seqlen]``, same dtype as ``qkv``.
        attn_scale: softmax scale; defaults to ``1 / sqrt(d)``.
        set_zero: zero-initialize the output (forced for ``thd`` formats).

    Returns:
        ``(out, softmax_aux, rng_state)``.
        - ``out`` has shape ``[b, max_seqlen, h, d]``.
        - ``softmax_aux`` is ``None`` when not training.
        - ``rng_state`` is a 2-element int64 tensor.
    """
    assert qkv_dtype in (
        tex.DType.kBFloat16,
        tex.DType.kFloat16,
    ), "Only support bf16/fp16 for fused attention."

    b = cu_seqlens.shape[0] - 1
    total_seqs = qkv.shape[0] * qkv.shape[1]
    h = qkv.shape[3]
    d = qkv.shape[4]

    if attn_scale is None:
        attn_scale = 1.0 / math.sqrt(d)

    if bias_type != "no_bias":
        assert Bias is not None, "bias tensor cannot be None when bias_type is not no_bias."
        assert Bias.shape == [
            1,
            h,
            max_seqlen,
            max_seqlen,
        ], "bias tensor must be in [1, h, max_seqlen, max_seqlen] shape."
        assert Bias.dtype == qkv.dtype, "bias tensor must be in the same dtype as qkv."

    assert (
        fused_attention_backend != FusedAttnBackend["No_Backend"]
    ), "Fused attention does not support this input combination."

    # Per-thread RNG element count handed to the kernel; only the two F16
    # backends set it, otherwise None is forwarded.
    rng_elts_per_thread = None
    # BF16/FP16 fused attention API from fmha_v1 apex
    if fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]:
        rng_elts_per_thread = (
            max_seqlen * max_seqlen + BACKEND_F16m512_THREADS_PER_CTA - 1
        ) // BACKEND_F16m512_THREADS_PER_CTA

    # BF16/FP16 fused attention API from fmha_v2
    if fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]:
        rng_elts_per_thread = BACKEND_F16arb_ELTS_PER_THREADS

    # Keep only the letters of the first layout group, e.g. "bs3hd" -> "bshd".
    qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()])
    if qkv_format == "thd":
        set_zero = True
    if set_zero:
        out = paddle.full(shape=[b, max_seqlen, h, d], fill_value=0, dtype=qkv.dtype)
    else:
        out = paddle.empty(shape=[b, max_seqlen, h, d], dtype=qkv.dtype)

    if is_training:
        if fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]:
            softmax_aux = paddle.empty(shape=[b, h, max_seqlen, max_seqlen], dtype=qkv.dtype)
        elif fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]:
            softmax_aux = paddle.empty(shape=[b, h, max_seqlen, 1], dtype="float32")
        else:
            raise ValueError("Unsupported fused attention backend.")
    else:
        softmax_aux = None

    rng_state = paddle.empty(
        shape=[
            2,
        ],
        dtype=paddle.int64,
    )

    # execute kernel
    tex.te_fused_attn_fwd_qkvpacked(
        qkv,
        cu_seqlens,
        Bias,
        out,
        softmax_aux,
        rng_state,
        b,
        h,
        d,
        total_seqs,
        max_seqlen,
        is_training,
        attn_scale,
        dropout,
        qkv_layout,
        bias_type,
        attn_mask_type,
        int(qkv_dtype),
        rng_elts_per_thread,
    )
    return out, softmax_aux, rng_state


def fused_attn_bwd_qkvpacked(
    qkv: paddle.Tensor,
    cu_seqlens: paddle.Tensor,
    rng_state: paddle.Tensor,
    o: paddle.Tensor,
    d_o: paddle.Tensor,
    softmax_aux: paddle.Tensor,
    fused_attention_backend: tex.NVTE_Fused_Attn_Backend,
    max_seqlen: int,
    qkv_dtype: tex.DType,
    attn_scale: Optional[float] = None,
    dropout: float = 0.0,
    set_zero: bool = True,
    qkv_layout: str = "bs3hd",
    bias_type: str = "no_bias",
    attn_mask_type: str = "padding",
    deterministic: bool = False,
) -> Tuple[paddle.Tensor, Optional[paddle.Tensor]]:
    """Fused Attention BWD for packed QKV input.

    Args:
        qkv: packed QKV tensor, read as rank 5 (heads at ``shape[3]``, head
            size at ``shape[4]``).
        cu_seqlens: cumulative sequence lengths; batch size is
            ``cu_seqlens.shape[0] - 1``.
        rng_state, o, d_o, softmax_aux: forwarded to the backward kernel.
        fused_attention_backend: must not be ``No_Backend``.
        attn_scale: softmax scale; defaults to ``1 / sqrt(d)``.
        set_zero: zero-initialize ``dqkv`` (forced for ``thd`` formats).
        deterministic: forwarded to the kernel.

    Returns:
        ``(dqkv, dbias)`` as returned by ``tex.te_fused_attn_bwd_qkvpacked``.
        ``dbias`` is allocated with shape ``[1, h, max_seqlen, max_seqlen]``
        only when ``bias_type != "no_bias"``.
    """
    assert qkv_dtype in (
        tex.DType.kBFloat16,
        tex.DType.kFloat16,
    ), "Only support bf16/fp16 for fused attention."

    b = cu_seqlens.shape[0] - 1
    total_seqs = qkv.shape[0] * qkv.shape[1]
    h = qkv.shape[3]
    d = qkv.shape[4]

    if attn_scale is None:
        attn_scale = 1.0 / math.sqrt(d)

    assert (
        fused_attention_backend != FusedAttnBackend["No_Backend"]
    ), "Fused attention does not support this input combination."

    # Keep only the letters of the first layout group, e.g. "bs3hd" -> "bshd".
    qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()])
    if qkv_format == "thd":
        set_zero = True
    if set_zero:
        dqkv = paddle.full(shape=qkv.shape, fill_value=0, dtype=qkv.dtype)
    else:
        dqkv = paddle.empty(shape=qkv.shape, dtype=qkv.dtype)

    if bias_type != "no_bias":
        if qkv_format == "thd":
            # paddle has no `paddle.zero`; the zero-filled allocator is `paddle.zeros`.
            dbias = paddle.zeros(shape=[1, h, max_seqlen, max_seqlen], dtype=qkv.dtype)
        else:
            dbias = paddle.empty(shape=[1, h, max_seqlen, max_seqlen], dtype=qkv.dtype)
    else:
        dbias = None
    # execute kernel
    dqkv, dbias = tex.te_fused_attn_bwd_qkvpacked(
        qkv,
        cu_seqlens,
        o,
        d_o,
        softmax_aux,
        dqkv,
        dbias,
        rng_state,
        b,
        h,
        d,
        total_seqs,
        max_seqlen,
        attn_scale,
        dropout,
        qkv_layout,
        bias_type,
        attn_mask_type,
        int(qkv_dtype),
        deterministic,
    )

    return dqkv, dbias


def fused_attn_fwd_kvpacked(
    q: paddle.Tensor,
    kv: paddle.Tensor,
    cu_seqlens_q: paddle.Tensor,
    cu_seqlens_kv: paddle.Tensor,
    is_training: bool,
    max_seqlen_q: int,
    max_seqlen_kv: int,
    qkv_dtype: tex.DType,
    fused_attention_backend: tex.NVTE_Fused_Attn_Backend,
    Bias: Optional[paddle.Tensor] = None,
    attn_scale: Optional[float] = None,
    dropout: float = 0.0,
    set_zero: bool = True,
    qkv_layout: str = "bshd_bs2hd",
    bias_type: str = "no_bias",
    attn_mask_type: str = "padding",
) -> Tuple[paddle.Tensor, Optional[paddle.Tensor], paddle.Tensor]:
    """Fused Attention FWD for packed KV input.

    Args:
        q: query tensor. The head count comes from ``q.shape[2]`` and the head
            size from ``q.shape[3]``.
        kv: packed key/value tensor; ``kv.shape[0] * kv.shape[1]`` is passed
            as the total KV sequence count.
        cu_seqlens_q, cu_seqlens_kv: cumulative sequence lengths of equal
            shape; batch size is ``cu_seqlens_q.shape[0] - 1``.
        is_training: allocate and return the softmax auxiliary tensor.
        qkv_dtype: must be bf16 or fp16.
        fused_attention_backend: must not be ``No_Backend``.
        Bias: required when ``bias_type != "no_bias"``; shape
            ``[1, h, max_seqlen_q, max_seqlen_kv]``, same dtype as ``q``.
        attn_scale: softmax scale; defaults to ``1 / sqrt(d)``.
        set_zero: zero-initialize the output (forced for ``thd`` formats).

    Returns:
        ``(out, softmax_aux, rng_state)``.
        - ``out`` has shape ``[b, max_seqlen_q, h, d]``.
        - ``softmax_aux`` is ``None`` when not training.
        - ``rng_state`` is a 2-element int64 tensor.
    """
    assert qkv_dtype in (
        tex.DType.kBFloat16,
        tex.DType.kFloat16,
    ), "Only support bf16/fp16 for fused attention."
    assert (
        cu_seqlens_q.shape == cu_seqlens_kv.shape
    ), "cu_seqlens_q and cu_seqlens_kv must have the same shape"

    b = cu_seqlens_q.shape[0] - 1
    total_seqs_q = q.shape[0] * q.shape[1]
    total_seqs_kv = kv.shape[0] * kv.shape[1]
    h = q.shape[2]
    d = q.shape[3]

    if attn_scale is None:
        attn_scale = 1.0 / math.sqrt(d)

    if bias_type != "no_bias":
        assert Bias is not None, "bias tensor cannot be None when bias_type is not no_bias."
        assert Bias.shape == [
            1,
            h,
            max_seqlen_q,
            max_seqlen_kv,
        ], "bias tensor must be in [1, h, max_seqlen_q, max_seqlen_kv] shape."
        assert Bias.dtype == q.dtype, "bias tensor must be in the same dtype as q and kv."

    assert (
        fused_attention_backend != FusedAttnBackend["No_Backend"]
    ), "Fused attention does not support this input combination."

    # Per-thread RNG element count handed to the kernel; only the two F16
    # backends set it, otherwise None is forwarded.
    rng_elts_per_thread = None
    # BF16/FP16 fused attention API from fmha_v1 apex
    if fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]:
        rng_elts_per_thread = (
            max_seqlen_q * max_seqlen_kv + BACKEND_F16m512_THREADS_PER_CTA - 1
        ) // BACKEND_F16m512_THREADS_PER_CTA

    # BF16/FP16 fused attention API from fmha_v2
    if fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]:
        rng_elts_per_thread = BACKEND_F16arb_ELTS_PER_THREADS

    # Keep only the letters of the first (query) layout group, e.g. "bshd".
    qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()])
    if qkv_format == "thd":
        set_zero = True
    if set_zero:
        out = paddle.full(shape=[b, max_seqlen_q, h, d], fill_value=0, dtype=q.dtype)
    else:
        out = paddle.empty(shape=[b, max_seqlen_q, h, d], dtype=q.dtype)

    if is_training:
        if fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]:
            softmax_aux = paddle.empty(shape=[b, h, max_seqlen_q, max_seqlen_kv], dtype=q.dtype)
        elif fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]:
            softmax_aux = paddle.empty(shape=[b, h, max_seqlen_q, 1], dtype="float32")
        else:
            raise ValueError("Unsupported fused attention backend.")
    else:
        softmax_aux = None

    rng_state = paddle.empty(
        shape=[
            2,
        ],
        dtype=paddle.int64,
    )

    # execute kernel
    tex.te_fused_attn_fwd_kvpacked(
        q,
        kv,
        cu_seqlens_q,
        cu_seqlens_kv,
        Bias,
        out,
        softmax_aux,
        rng_state,
        b,
        h,
        d,
        total_seqs_q,
        total_seqs_kv,
        max_seqlen_q,
        max_seqlen_kv,
        is_training,
        attn_scale,
        dropout,
        qkv_layout,
        bias_type,
        attn_mask_type,
        int(qkv_dtype),
        rng_elts_per_thread,
    )

    return out, softmax_aux, rng_state


def fused_attn_bwd_kvpacked(
    q: paddle.Tensor,
    kv: paddle.Tensor,
    cu_seqlens_q: paddle.Tensor,
    cu_seqlens_kv: paddle.Tensor,
    rng_state: paddle.Tensor,
    o: paddle.Tensor,
    d_o: paddle.Tensor,
    softmax_aux: paddle.Tensor,
    fused_attention_backend: tex.NVTE_Fused_Attn_Backend,
    max_seqlen_q: int,
    max_seqlen_kv: int,
    qkv_dtype: tex.DType,
    attn_scale: Optional[float] = None,
    dropout: float = 0.0,
    set_zero: bool = True,
    qkv_layout: str = "bshd_bs2hd",
    bias_type: str = "no_bias",
    attn_mask_type: str = "padding",
    deterministic: bool = False,
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
    """Fused Attention BWD for packed KV input.

    Args:
        q, kv: query and packed key/value tensors. Heads come from
            ``q.shape[2]`` and head size from ``q.shape[3]``.
        cu_seqlens_q, cu_seqlens_kv: cumulative sequence lengths of equal
            shape; batch size is ``cu_seqlens_q.shape[0] - 1``.
        rng_state, o, d_o, softmax_aux: forwarded to the backward kernel.
        fused_attention_backend: must not be ``No_Backend``.
        attn_scale: softmax scale; defaults to ``1 / sqrt(d)``.
        set_zero: zero-initialize ``dq``/``dkv`` (forced for ``thd`` formats).
        deterministic: forwarded to the kernel.

    Returns:
        ``(dq, dkv, dbias)``. ``dq``/``dkv`` match the shapes of ``q``/``kv``.
        ``dbias`` has shape ``[1, h, max_seqlen_q, max_seqlen_kv]`` when
        ``bias_type != "no_bias"``, else ``None``.
    """
    assert qkv_dtype in (
        tex.DType.kBFloat16,
        tex.DType.kFloat16,
    ), "Only support bf16/fp16 for fused attention."
    assert (
        cu_seqlens_q.shape == cu_seqlens_kv.shape
    ), "cu_seqlens_q and cu_seqlens_kv must have the same shape"

    b = cu_seqlens_q.shape[0] - 1
    total_seqs_q = q.shape[0] * q.shape[1]
    total_seqs_kv = kv.shape[0] * kv.shape[1]
    h = q.shape[2]
    d = q.shape[3]

    if attn_scale is None:
        attn_scale = 1.0 / math.sqrt(d)

    assert (
        fused_attention_backend != FusedAttnBackend["No_Backend"]
    ), "Fused attention does not support this input combination."

    # Keep only the letters of the first (query) layout group, e.g. "bshd".
    qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()])
    if qkv_format == "thd":
        set_zero = True
    if set_zero:
        dq = paddle.full(shape=q.shape, fill_value=0, dtype=q.dtype)
        dkv = paddle.full(shape=kv.shape, fill_value=0, dtype=kv.dtype)
    else:
        dq = paddle.empty(shape=q.shape, dtype=q.dtype)
        dkv = paddle.empty(shape=kv.shape, dtype=kv.dtype)
    if bias_type != "no_bias":
        if qkv_format == "thd":
            # paddle has no `paddle.zero`; the zero-filled allocator is `paddle.zeros`.
            dbias = paddle.zeros(shape=[1, h, max_seqlen_q, max_seqlen_kv], dtype=q.dtype)
        else:
            dbias = paddle.empty(shape=[1, h, max_seqlen_q, max_seqlen_kv], dtype=q.dtype)
    else:
        dbias = None
    # execute kernel
    tex.te_fused_attn_bwd_kvpacked(
        q,
        kv,
        cu_seqlens_q,
        cu_seqlens_kv,
        o,
        d_o,
        softmax_aux,
        dq,
        dkv,
        dbias,
        rng_state,
        b,
        h,
        d,
        total_seqs_q,
        total_seqs_kv,
        max_seqlen_q,
        max_seqlen_kv,
        attn_scale,
        dropout,
        qkv_layout,
        bias_type,
        attn_mask_type,
        int(qkv_dtype),
        deterministic,
    )
    return dq, dkv, dbias


Shijie's avatar
Shijie committed
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
def fused_attn_fwd(
    q: paddle.Tensor,
    k: paddle.Tensor,
    v: paddle.Tensor,
    cu_seqlens_q: paddle.Tensor,
    cu_seqlens_kv: paddle.Tensor,
    is_training: bool,
    max_seqlen_q: int,
    max_seqlen_kv: int,
    qkv_dtype: tex.DType,
    fused_attention_backend: tex.NVTE_Fused_Attn_Backend,
    Bias: paddle.Tensor = None,
    attn_scale: float = None,
    dropout: float = 0.0,
    set_zero: bool = True,
    qkv_layout: str = "bshd_bshd_bshd",
    bias_type: str = "no_bias",
    attn_mask_type: str = "padding",
) -> Tuple[paddle.Tensor, Optional[paddle.Tensor], paddle.Tensor]:
    """Fused Attention FWD for unpacked QKV input.

    Validates the inputs, allocates the output / auxiliary buffers and launches
    ``tex.te_fused_attn_fwd``.

    Args:
        q, k, v: query, key and value tensors. Only the ``bshd_bshd_bshd``
            layout is accepted. The head count ``h`` and head dim ``d`` are
            read from ``q.shape[-2]`` and ``q.shape[-1]``.
        cu_seqlens_q, cu_seqlens_kv: cumulative sequence lengths; both must
            have the same shape. The batch size is ``cu_seqlens_q.shape[0] - 1``.
        is_training: when True a ``softmax_aux`` buffer is allocated and filled
            for use by the backward pass; otherwise ``softmax_aux`` is None.
        max_seqlen_q, max_seqlen_kv: maximum query / key-value sequence lengths.
        qkv_dtype: must be ``tex.DType.kBFloat16`` or ``tex.DType.kFloat16``.
        fused_attention_backend: kernel backend; must not be ``No_Backend``.
        Bias: required when ``bias_type != "no_bias"``. It must have shape
            ``[1, h, max_seqlen_q, max_seqlen_kv]`` and the same dtype as ``q``.
        attn_scale: softmax scale; defaults to ``1 / sqrt(d)``.
        dropout: dropout probability forwarded to the kernel.
        set_zero: zero-initialize ``out`` instead of leaving it uninitialized.
            It is forced to True for the ``thd`` format.
        qkv_layout, bias_type, attn_mask_type: forwarded to the kernel.

    Returns:
        ``(out, softmax_aux, rng_state)``. ``out`` has shape
        ``[b, max_seqlen_q, h, d]`` and the dtype of ``q``. ``softmax_aux`` is
        None when not training. ``rng_state`` is a 2-element int64 tensor.

    Raises:
        AssertionError: on unsupported dtype, layout, backend, mismatched
            ``cu_seqlens`` shapes, or a missing/mis-shaped ``Bias``.
        ValueError: when training with a backend that has no known
            ``softmax_aux`` shape.
    """

    assert qkv_dtype in (
        tex.DType.kBFloat16,
        tex.DType.kFloat16,
    ), "Only support bf16/fp16 for fused attention."
    assert (
        cu_seqlens_q.shape == cu_seqlens_kv.shape
    ), "cu_seqlens_q and cu_seqlens_kv must have the same shape"
    assert (
        qkv_layout == "bshd_bshd_bshd"
    ), "Only support bshd_bshd_bshd layout for unpacked QKV input for now."

    b = cu_seqlens_q.shape[0] - 1
    h = q.shape[-2]
    d = q.shape[-1]

    if attn_scale is None:
        attn_scale = 1.0 / math.sqrt(d)

    if bias_type != "no_bias":
        assert Bias is not None, "bias tensor cannot be None when bias_type is not no_bias."
        assert Bias.shape == [
            1,
            h,
            max_seqlen_q,
            max_seqlen_kv,
        ], "bias tensor must be in [1, h, max_seqlen_q, max_seqlen_kv] shape."
        assert Bias.dtype == q.dtype, "bias tensor must be in the same dtype as qkv."

    assert (
        fused_attention_backend != FusedAttnBackend["No_Backend"]
    ), "Fused attention does not support this input combination."

    # Number of RNG elements each thread consumes; depends on the backend.
    rng_elts_per_thread = None
    # BF16/FP16 fused attention API from fmha_v1 apex
    if fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]:
        # ceil(max_seqlen_q * max_seqlen_kv / THREADS_PER_CTA)
        rng_elts_per_thread = (
            max_seqlen_q * max_seqlen_kv + BACKEND_F16m512_THREADS_PER_CTA - 1
        ) // BACKEND_F16m512_THREADS_PER_CTA
    # BF16/FP16 fused attention API from fmha_v2
    if fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]:
        rng_elts_per_thread = BACKEND_F16arb_ELTS_PER_THREADS

    # Keep only the letters of the first layout component (e.g. "bshd").
    qkv_format = "".join(ch for ch in qkv_layout.split("_")[0] if ch.isalpha())
    if qkv_format == "thd":
        # Padded positions are not written by the kernel for thd, so they must be zeroed.
        set_zero = True

    if set_zero:
        out = paddle.full(shape=[b, max_seqlen_q, h, d], fill_value=0, dtype=q.dtype)
    else:
        out = paddle.empty(shape=[b, max_seqlen_q, h, d], dtype=q.dtype)

    if is_training:
        if fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]:
            softmax_aux = paddle.empty(shape=[b, h, max_seqlen_q, max_seqlen_kv], dtype=q.dtype)
        elif fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]:
            softmax_aux = paddle.empty(shape=[b, h, max_seqlen_q, 1], dtype="float32")
        else:
            raise ValueError("Unsupported fused attention backend.")
    else:
        softmax_aux = None

    rng_state = paddle.empty(
        shape=[
            2,
        ],
        dtype=paddle.int64,
    )

    # execute kernel
    tex.te_fused_attn_fwd(
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_kv,
        Bias,
        out,
        softmax_aux,
        rng_state,
        b,
        h,
        d,
        max_seqlen_q,
        max_seqlen_kv,
        is_training,
        attn_scale,
        dropout,
        qkv_layout,
        bias_type,
        attn_mask_type,
        int(qkv_dtype),
        rng_elts_per_thread,
    )
    return out, softmax_aux, rng_state


def fused_attn_bwd(
    q: paddle.Tensor,
    k: paddle.Tensor,
    v: paddle.Tensor,
    cu_seqlens_q: paddle.Tensor,
    cu_seqlens_kv: paddle.Tensor,
    rng_state: paddle.Tensor,
    o: paddle.Tensor,
    d_o: paddle.Tensor,
    softmax_aux: paddle.Tensor,
    fused_attention_backend: tex.NVTE_Fused_Attn_Backend,
    max_seqlen_q: int,
    max_seqlen_kv: int,
    qkv_dtype: tex.DType,
    attn_scale: float = None,
    dropout: float = 0.0,
    set_zero: bool = True,
    qkv_layout: str = "bshd_bshd_bshd",
    bias_type: str = "no_bias",
    attn_mask_type: str = "padding",
    deterministic: bool = False,
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, Optional[paddle.Tensor]]:
    """Fused Attention BWD for unpacked QKV input.

    Validates the inputs, allocates the gradient buffers and launches
    ``tex.te_fused_attn_bwd``.

    Args:
        q, k, v: query, key and value tensors. Only the ``bshd_bshd_bshd``
            layout is accepted. ``h`` and ``d`` are read from ``q.shape[-2]``
            and ``q.shape[-1]``.
        cu_seqlens_q, cu_seqlens_kv: cumulative sequence lengths; both must
            have the same shape. The batch size is ``cu_seqlens_q.shape[0] - 1``.
        rng_state, softmax_aux: auxiliary tensors, presumably those returned by
            the matching forward call (NOTE(review): confirm with callers).
        o, d_o: forward output and its incoming gradient.
        fused_attention_backend: kernel backend; must not be ``No_Backend``.
        max_seqlen_q, max_seqlen_kv: maximum query / key-value sequence lengths.
        qkv_dtype: must be ``tex.DType.kBFloat16`` or ``tex.DType.kFloat16``.
        attn_scale: softmax scale; defaults to ``1 / sqrt(d)``.
        dropout: dropout probability forwarded to the kernel.
        set_zero: zero-initialize ``dq``/``dk``/``dv`` instead of leaving them
            uninitialized. It is forced to True for the ``thd`` format.
        qkv_layout, bias_type, attn_mask_type, deterministic: forwarded to the
            kernel.

    Returns:
        ``(dq, dk, dv, dbias)``. ``dq``/``dk``/``dv`` match the shape and
        dtype of ``q``/``k``/``v``. ``dbias`` has shape
        ``[1, h, max_seqlen_q, max_seqlen_kv]``, or is None when
        ``bias_type == "no_bias"``.

    Raises:
        AssertionError: on unsupported dtype, layout, backend or mismatched
            ``cu_seqlens`` shapes.
    """

    assert qkv_dtype in (
        tex.DType.kBFloat16,
        tex.DType.kFloat16,
    ), "Only support bf16/fp16 for fused attention."
    assert (
        cu_seqlens_q.shape == cu_seqlens_kv.shape
    ), "cu_seqlens_q and cu_seqlens_kv must have the same shape"
    assert (
        qkv_layout == "bshd_bshd_bshd"
    ), "Only support bshd_bshd_bshd layout for unpacked QKV input for now."

    b = cu_seqlens_q.shape[0] - 1
    h = q.shape[-2]
    d = q.shape[-1]

    if attn_scale is None:
        attn_scale = 1.0 / math.sqrt(d)

    assert (
        fused_attention_backend != FusedAttnBackend["No_Backend"]
    ), "Fused attention does not support this input combination."

    # Keep only the letters of the first layout component (e.g. "bshd").
    qkv_format = "".join(ch for ch in qkv_layout.split("_")[0] if ch.isalpha())
    if qkv_format == "thd":
        set_zero = True

    if set_zero:
        dq = paddle.full(shape=q.shape, fill_value=0, dtype=q.dtype)
        dk = paddle.full(shape=k.shape, fill_value=0, dtype=k.dtype)
        dv = paddle.full(shape=v.shape, fill_value=0, dtype=v.dtype)
    else:
        dq = paddle.empty(shape=q.shape, dtype=q.dtype)
        dk = paddle.empty(shape=k.shape, dtype=k.dtype)
        dv = paddle.empty(shape=v.shape, dtype=v.dtype)

    if bias_type != "no_bias":
        if qkv_format == "thd":
            # `paddle.zero` does not exist; the zero-filled allocator is `paddle.zeros`.
            dbias = paddle.zeros(shape=[1, h, max_seqlen_q, max_seqlen_kv], dtype=q.dtype)
        else:
            dbias = paddle.empty(shape=[1, h, max_seqlen_q, max_seqlen_kv], dtype=q.dtype)
    else:
        dbias = None

    # execute kernel
    tex.te_fused_attn_bwd(
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_kv,
        o,
        d_o,
        softmax_aux,
        dq,
        dk,
        dv,
        dbias,
        rng_state,
        b,
        h,
        d,
        max_seqlen_q,
        max_seqlen_kv,
        attn_scale,
        dropout,
        qkv_layout,
        bias_type,
        attn_mask_type,
        int(qkv_dtype),
        deterministic,
    )
    return dq, dk, dv, dbias


def scaled_softmax_forward(
    inp: paddle.Tensor,
    scale_factor: float,
) -> paddle.Tensor:
    """scaled softmax forward

    Thin wrapper returning ``tex.te_scaled_softmax_forward(inp, scale_factor)``.
    Presumably this computes ``softmax(inp * scale_factor)`` (NOTE(review):
    confirm against the C++ extension).
    """
    return tex.te_scaled_softmax_forward(inp, scale_factor)


def scaled_softmax_backward(
    out_grad: paddle.Tensor,
    softmax_results: paddle.Tensor,
    scale_factor: float,
) -> paddle.Tensor:
    """scaled softmax backward

    Calls ``tex.te_scaled_softmax_backward`` and returns ``out_grad`` itself.
    The extension's return value is discarded, so the kernel presumably writes
    the input gradient into ``out_grad`` in place (NOTE(review): confirm).
    """
    tex.te_scaled_softmax_backward(out_grad, softmax_results, scale_factor)
    return out_grad


def scaled_masked_softmax_forward(
    inp: paddle.Tensor,
    mask: paddle.Tensor,
    scale_factor: float,
) -> paddle.Tensor:
    """scaled masked softmax forward

    Thin wrapper returning
    ``tex.te_scaled_masked_softmax_forward(inp, mask, scale_factor)``.
    The exact mask semantics (which value means "masked") are defined by the
    C++ extension and are not visible here.
    """
    return tex.te_scaled_masked_softmax_forward(inp, mask, scale_factor)


def scaled_masked_softmax_backward(
    out_grad: paddle.Tensor,
    softmax_results: paddle.Tensor,
    scale_factor: float,
) -> paddle.Tensor:
    """scaled masked softmax backward

    Reuses the unmasked ``tex.te_scaled_softmax_backward`` kernel; no mask is
    passed. Presumably masked positions already carry zero probability in
    ``softmax_results``, so the mask is not needed (NOTE(review): confirm).
    Returns ``out_grad``, which the kernel presumably overwrites in place.
    """
    tex.te_scaled_softmax_backward(out_grad, softmax_results, scale_factor)
    return out_grad


def scaled_upper_triang_masked_softmax_forward(
    inp: paddle.Tensor,
    scale_factor: float,
) -> paddle.Tensor:
    """scaled upper triang masked softmax forward

    Thin wrapper returning
    ``tex.te_scaled_upper_triang_masked_softmax_forward(inp, scale_factor)``.
    Presumably the upper triangle is masked out, i.e. a causal mask
    (NOTE(review): confirm against the C++ extension).
    """
    return tex.te_scaled_upper_triang_masked_softmax_forward(inp, scale_factor)


def scaled_upper_triang_masked_softmax_backward(
    out_grad: paddle.Tensor,
    softmax_results: paddle.Tensor,
    scale_factor: float,
) -> paddle.Tensor:
    """scaled upper triang masked softmax backward

    Calls ``tex.te_scaled_upper_triang_masked_softmax_backward`` and returns
    ``out_grad`` itself. The extension's return value is discarded, so the
    kernel presumably writes the gradient into ``out_grad`` in place
    (NOTE(review): confirm).
    """
    tex.te_scaled_upper_triang_masked_softmax_backward(out_grad, softmax_results, scale_factor)
    return out_grad