# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""TE FP8 extensions and GEMMs"""

Shijie's avatar
Shijie committed
6
import math
7
8
from typing import Optional, Tuple, Union
import paddle
9
import paddle.nn.functional as F
10
from transformer_engine import transformer_engine_paddle as tex
11
from .constants import TE_DType, FusedAttnBackend, FP8FwdTensors, FP8BwdTensors
12
from .fp8 import FP8TensorMeta
13

14
15
16
# Divisor used to derive the per-thread RNG element count for the
# "F16_max512_seqlen" fused-attention backend:
# ceil(max_seqlen_q * max_seqlen_kv / BACKEND_F16m512_THREADS_PER_CTA).
# (Presumably the backend's threads per CTA, per the name — confirm.)
BACKEND_F16m512_THREADS_PER_CTA = 128
# Fixed per-thread RNG element count passed for the "F16_arbitrary_seqlen" backend.
BACKEND_F16arb_ELTS_PER_THREADS = 16

def gemm(
    A: paddle.Tensor,
    B: paddle.Tensor,
    dtype: paddle.dtype,
    workspace: paddle.Tensor,
    gelu: bool = False,
    gelu_input: Optional[paddle.Tensor] = None,
    grad: bool = False,
    accumulate: bool = False,
    layout: str = "TN",
    out: Optional[paddle.Tensor] = None,
    out_dtype: Optional[paddle.dtype] = None,
    bias: Optional[paddle.Tensor] = None,
    use_bias: bool = False,
) -> Tuple[Union[paddle.Tensor, None], ...]:
    """Non FP8 GEMM.

    Args:
        A, B: input matrices; both must have dtype ``dtype``.
        dtype: expected dtype of ``A`` and ``B``; default dtype of ``out``.
        workspace: scratch buffer for the kernel; ``workspace.shape[0]`` is
            passed as its size.
        gelu: enable the GELU epilogue. When ``grad`` is False a fresh
            ``gelu_input`` buffer (like ``out``, dtype ``dtype``) is allocated;
            when ``grad`` is True the caller-supplied ``gelu_input`` is used.
        gelu_input: see ``gelu``; ignored (set to None) when ``gelu`` is False.
        grad: backward mode; with ``use_bias`` a ``grad_bias`` buffer of
            length ``B.shape[1]`` is allocated and passed as the bias slot.
        accumulate: forwarded to the kernel; a newly allocated ``out`` is
            zero-initialised in this case.
        layout: one of "TN", "NN", "NT"; first char selects transa, second transb.
        out: optional preallocated output.
        out_dtype: dtype for a newly allocated ``out`` (defaults to ``dtype``).
        bias: bias tensor, only used when ``use_bias`` is True.
        use_bias: enable the bias / bias-gradient path.

    Returns:
        ``(out, grad_bias, gelu_input)``; ``grad_bias`` is None unless
        ``grad and use_bias``; ``gelu_input`` is None unless ``gelu``.
    """
    assert layout in ("TN", "NN", "NT"), f"GEMM layout {layout} not supported."
    transa = layout[0] == "T"
    transb = layout[1] == "T"

    if out is None:
        # Output rows come from B and columns from A, honoring the transposes.
        out_shape = [
            B.shape[1] if transb else B.shape[0],
            A.shape[0] if transa else A.shape[1],
        ]
        # Start from zeros when the kernel is asked to accumulate into `out`.
        allocate = paddle.zeros if accumulate else paddle.empty
        out = allocate(
            shape=out_shape,
            dtype=out_dtype if out_dtype is not None else dtype,
        )

    if gelu and not grad:
        gelu_input = paddle.empty_like(out, dtype=dtype)
    elif not gelu:
        gelu_input = None

    if grad and use_bias:
        grad_bias = paddle.empty(shape=[B.shape[1]], dtype=out.dtype)
    else:
        grad_bias = None

    bias = bias if use_bias else None

    assert A.dtype == dtype and B.dtype == dtype, \
        f'Expected dtype={dtype}, but found A.dtype={A.dtype} and B.dtype={B.dtype}'
    input_dtype = TE_DType[dtype]
    output_dtype = TE_DType[out.dtype]
    if use_bias:
        bias_dtype = TE_DType[grad_bias.dtype] if grad else TE_DType[bias.dtype]
    else:
        # No bias: the kernel still needs a dtype id, reuse the output's.
        bias_dtype = output_dtype

    tex.te_gemm(
        A,
        None,
        B,
        None,
        grad_bias if grad else bias,
        out,
        None,    # out_scale
        None,    # out_amax
        gelu_input,
        workspace,
        0,    # A_index
        0,    # B_index
        0,    # D_index
        int(input_dtype),
        int(input_dtype),
        int(output_dtype),
        int(bias_dtype),
        transa,
        transb,
        grad,
        workspace.shape[0],
        accumulate,
        False,    # use_split_accumulator
        0,    # math_sm_count
    )

    return out, grad_bias, gelu_input


def fp8_gemm(
    A: paddle.Tensor,
    A_scale_inv: paddle.Tensor,
    A_fp8_tensor: Union[FP8FwdTensors, FP8BwdTensors],
    A_dtype: tex.DType,
    B: paddle.Tensor,
    B_scale_inv: paddle.Tensor,
    B_fp8_tensor: Union[FP8FwdTensors, FP8BwdTensors],
    B_dtype: tex.DType,
    out_dtype: paddle.dtype,
    workspace: paddle.Tensor,
    gelu: bool = False,
    accumulate: bool = False,
    out: Optional[paddle.Tensor] = None,
    out_index=None,
    fp8_meta_tensor: FP8TensorMeta = None,
    bias: Optional[paddle.Tensor] = None,
    use_bias: bool = False,
    use_split_accumulator: bool = False,
    D_dtype: Optional[tex.DType] = None,
) -> Tuple[paddle.Tensor, Optional[paddle.Tensor]]:
    """TN layout GEMM with fp8 inputs.

    ``A`` is used transposed and ``B`` as-is, so a newly allocated ``out`` has
    shape ``[B.shape[0], A.shape[0]]`` and dtype ``out_dtype``.

    When ``D_dtype`` is an FP8 type, ``fp8_meta_tensor`` and ``out_index`` are
    required. Whenever ``out_index`` is given, ``fp8_meta_tensor.scale`` and
    ``.amax_history`` are passed to the kernel for the output.

    Returns:
        ``(out, gelu_input)``; ``gelu_input`` is None unless ``gelu`` is True.
    """
    if D_dtype is not None and D_dtype in [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]:
        assert fp8_meta_tensor is not None and out_index is not None

    if out is None:
        # Start from zeros when the kernel is asked to accumulate into `out`.
        allocate = paddle.zeros if accumulate else paddle.empty
        out = allocate(shape=[B.shape[0], A.shape[0]], dtype=out_dtype)

    # Use bfloat16 as default bias_dtype
    bias_paddle_dtype = paddle.bfloat16 if bias is None else bias.dtype
    # Pre-GELU output buffer shares the bias dtype.
    gelu_input = paddle.empty_like(out, dtype=bias_paddle_dtype) if gelu else None
    bias_te_dtype = TE_DType[bias_paddle_dtype]

    # An explicit D_dtype overrides the TE dtype derived from `out`.
    out_te_dtype = TE_DType[out.dtype] if D_dtype is None else D_dtype

    tex.te_gemm(
        A,
        A_scale_inv,
        B,
        B_scale_inv,
        bias if use_bias else None,
        out,
        None if out_index is None else fp8_meta_tensor.scale,
        None if out_index is None else fp8_meta_tensor.amax_history,
        gelu_input,    # this is pre_gelu_out
        workspace,
        A_fp8_tensor.value,
        B_fp8_tensor.value,
        0 if out_index is None else out_index,
        int(A_dtype),
        int(B_dtype),
        int(out_te_dtype),
        int(bias_te_dtype),
        True,    # transa
        False,    # transb
        False,    # grad
        workspace.shape[0],
        accumulate,
        use_split_accumulator,
        0,    # math_sm_count
    )

    return out, gelu_input


def cast_to_fp8(
    inp: paddle.Tensor,
    fp8_meta_tensor: FP8TensorMeta,
    fp8_tensor: Union[FP8FwdTensors, FP8BwdTensors],
    otype: tex.DType,
    out: Optional[paddle.Tensor] = None,
) -> paddle.Tensor:
    """Cast input to FP8.

    The result is written into ``out``, a uint8 tensor with ``inp``'s shape
    that is allocated when not supplied. The meta tensor's ``scale``,
    ``amax_history`` and ``scale_inv`` plus ``fp8_tensor.value`` (index) are
    handed to ``tex.cast_to_fp8`` (presumably updated in place — confirm).

    Returns:
        ``out``.
    """
    if out is None:
        out = paddle.empty(
            shape=inp.shape,
            dtype=paddle.uint8,
        )
    else:
        assert out.shape == inp.shape, "Output shape does not match input shape."
        assert out.dtype == paddle.uint8, "Output should be of uint8 dtype."

    tex.cast_to_fp8(
        inp,
        fp8_meta_tensor.scale,
        out,
        fp8_meta_tensor.amax_history,
        fp8_meta_tensor.scale_inv,
        fp8_tensor.value,
        int(otype),
    )
    return out


def cast_from_fp8(
    inp: paddle.Tensor,
    fp8_meta_tensor: FP8TensorMeta,
    fp8_tensor: Union[FP8FwdTensors, FP8BwdTensors],
    itype: tex.DType,
    otype: tex.DType,
) -> paddle.Tensor:
    """Cast input from FP8.

    Dequantizes ``inp`` (FP8 type ``itype``) to ``otype`` using
    ``fp8_meta_tensor.scale_inv`` at index ``fp8_tensor.value``, and returns
    the tensor produced by ``tex.cast_from_fp8``.
    """
    return tex.cast_from_fp8(
        inp,
        fp8_meta_tensor.scale_inv,
        fp8_tensor.value,
        int(itype),
        int(otype),
    )


def transpose(
    inp: paddle.Tensor,
    otype: tex.DType,
) -> paddle.Tensor:
    """Transpose ``inp`` with the ``tex.te_transpose`` extension.

    ``otype`` is forwarded to the kernel as its integer id; the kernel's
    result is returned unchanged.
    """
    transpose_kernel = tex.te_transpose
    return transpose_kernel(inp, int(otype))


def cast_transpose(
    inp: paddle.Tensor,
    fp8_meta_tensor: FP8TensorMeta,
    fp8_tensor: Union[FP8FwdTensors, FP8BwdTensors],
    otype: tex.DType,
    cast_out: Optional[paddle.Tensor] = None,
    transpose_out: Optional[paddle.Tensor] = None,
) -> Tuple[paddle.Tensor, paddle.Tensor]:
    """Cast + Transpose with FP8 output.

    ``inp`` is treated as 2-D (``inp.shape[0]``, ``inp.shape[1]``). Both
    outputs are uint8 tensors: ``cast_out`` has ``inp``'s shape and
    ``transpose_out`` has the swapped shape. Either one is allocated when
    not supplied, otherwise its shape and dtype are validated.

    Returns:
        ``(cast_out, transpose_out)``.
    """
    if cast_out is None:
        cast_out = paddle.empty(
            shape=inp.shape,
            dtype=paddle.uint8,
        )
    else:
        assert cast_out.shape == inp.shape, "cast_out shape does not match input shape."
        assert cast_out.dtype == paddle.uint8, "cast_out should be of uint8 dtype."

    if transpose_out is None:
        transpose_out = paddle.empty(
            shape=[inp.shape[1], inp.shape[0]],
            dtype=paddle.uint8,
        )
    else:
        assert transpose_out.shape == [inp.shape[1], inp.shape[0]
                                      ], "Transposed output shape does not match input shape."
        assert transpose_out.dtype == paddle.uint8, "Output should be of uint8 dtype."

    tex.te_cast_transpose(
        inp,
        fp8_meta_tensor.scale,
        cast_out,
        transpose_out,
        fp8_meta_tensor.amax_history,
        fp8_meta_tensor.scale_inv,
        fp8_tensor.value,
        int(otype),
    )

    return cast_out, transpose_out


def cast_transpose_bgrad(
    inp: paddle.Tensor,
    fp8_meta_tensor: FP8TensorMeta,
    fp8_tensor: Union[FP8FwdTensors, FP8BwdTensors],
    otype: tex.DType,
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
    """Fused Cast + Transpose + Bias Grad.

    Calls ``tex.te_cast_transpose_bgrad`` and keeps its first three outputs;
    the remaining two are discarded.

    Returns:
        ``(grad_bias, cast_out, transpose_out)``.
    """
    grad_bias, cast_out, transpose_out, _, _ = tex.te_cast_transpose_bgrad(
        inp,
        fp8_meta_tensor.scale,
        fp8_meta_tensor.amax_history,
        fp8_meta_tensor.scale_inv,
        fp8_tensor.value,
        int(otype),
    )

    return grad_bias, cast_out, transpose_out


def te_gelu(
    inp: paddle.Tensor,
    otype: tex.DType,
) -> paddle.Tensor:
    """Non FP8 GELU.

    Applies ``tex.te_gelu`` to ``inp`` with output type ``otype`` (forwarded
    as its integer id) and returns the result.
    """
    return tex.te_gelu(
        inp,
        int(otype),
    )


def gelu_fp8(
    inp: paddle.Tensor,
    fp8_meta_tensor: FP8TensorMeta,
    fp8_tensor: Union[FP8FwdTensors, FP8BwdTensors],
    otype: tex.DType,
) -> paddle.Tensor:
    """GELU + FP8 cast.

    Calls ``tex.te_gelu_fp8`` with the meta tensor's ``scale``,
    ``amax_history`` and ``scale_inv`` at index ``fp8_tensor.value``. Only
    the first output is returned; the other two are discarded.
    """
    out, _, _ = tex.te_gelu_fp8(
        inp,
        fp8_meta_tensor.scale,
        fp8_meta_tensor.amax_history,
        fp8_meta_tensor.scale_inv,
        fp8_tensor.value,
        int(otype),
    )

    return out


def swiglu(
    inp: paddle.Tensor,
    otype: tex.DType,
) -> paddle.Tensor:
    """Non FP8 SWIGLU.

    Applies ``tex.te_swiglu`` to ``inp`` with output type ``otype``
    (forwarded as its integer id) and returns the result.
    """
    return tex.te_swiglu(
        inp,
        int(otype),
    )


def swiglu_pd(inp: paddle.Tensor) -> paddle.Tensor:
    """Native SWIGLU built from plain Paddle ops.

    Splits ``inp`` into two halves along the last axis and returns
    ``silu(first_half) * second_half``.
    """
    halves = paddle.chunk(inp, chunks=2, axis=-1)
    gate, up = halves[0], halves[1]
    return F.silu(gate) * up


def swiglu_fp8(
    inp: paddle.Tensor,
    fp8_meta_tensor: FP8TensorMeta,
    fp8_tensor: Union[FP8FwdTensors, FP8BwdTensors],
    otype: tex.DType,
) -> paddle.Tensor:
    """SWIGLU + FP8 cast.

    Runs ``tex.te_swiglu_fp8`` with the meta tensor's scale, amax history and
    inverse scale at index ``fp8_tensor.value``; the kernel's first output is
    returned and its two auxiliary outputs are dropped.
    """
    results = tex.te_swiglu_fp8(
        inp,
        fp8_meta_tensor.scale,
        fp8_meta_tensor.amax_history,
        fp8_meta_tensor.scale_inv,
        fp8_tensor.value,
        int(otype),
    )
    swiglu_out, _amax_unused, _scale_inv_unused = results
    return swiglu_out


def dswiglu(
    grad_output: paddle.Tensor,
    swiglu_input: paddle.Tensor,
    otype: tex.DType,
) -> paddle.Tensor:
    """Backward of SWIGLU.

    Passes ``grad_output`` and the forward input ``swiglu_input`` to
    ``tex.te_dswiglu`` (output type ``otype`` as its integer id) and returns
    the kernel's result.
    """
    otype_id = int(otype)
    return tex.te_dswiglu(grad_output, swiglu_input, otype_id)


def dgelu_cast_transpose_bgrad_fp8(
    grad_output: paddle.Tensor,
    gelu_input: paddle.Tensor,
    fp8_meta_tensor: FP8TensorMeta,
    fp8_tensor: Union[FP8FwdTensors, FP8BwdTensors],
    otype: tex.DType,
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
    """
    Fused dgelu + cast / transpose / reduce the result of
    the GELU backward along the first dimension

    Calls ``tex.te_cast_transpose_bgrad_dgelu`` and keeps its first three
    outputs; the remaining two are discarded.

    Returns:
        ``(cast_dgelu, transpose_dgelu, dbias)``.
    """
    cast_dgelu, transpose_dgelu, dbias, _, _ = tex.te_cast_transpose_bgrad_dgelu(
        grad_output,
        gelu_input,
        fp8_meta_tensor.scale,
        fp8_meta_tensor.amax_history,
        fp8_meta_tensor.scale_inv,
        fp8_tensor.value,
        int(otype),
    )

    return cast_dgelu, transpose_dgelu, dbias


def layernorm_fwd_fp8(
    inp: paddle.Tensor,
    weight: paddle.Tensor,
    bias: paddle.Tensor,
    eps: float,
    fp8_meta_tensor: FP8TensorMeta,
    fp8_tensor: Union[FP8FwdTensors, FP8BwdTensors],
    otype: tex.DType,
    sm_margin: int = 0,
    zero_centered_gamma: bool = False,
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
    """LayerNorm with FP8 output.

    Calls ``tex.te_layernorm_fwd_fp8`` with the meta tensor's ``scale``,
    ``amax_history`` and ``scale_inv`` at index ``fp8_tensor.value``.

    Returns:
        ``(out, mu, rsigma)``; the kernel's two extra outputs are discarded.
    """
    out, mu, rsigma, _, _ = tex.te_layernorm_fwd_fp8(inp, weight, bias, fp8_meta_tensor.scale,
                                                     fp8_meta_tensor.amax_history,
                                                     fp8_meta_tensor.scale_inv, eps,
                                                     fp8_tensor.value, int(otype), sm_margin,
                                                     zero_centered_gamma)
    return out, mu, rsigma


def layernorm_fwd(
    inp: paddle.Tensor,
    weight: paddle.Tensor,
    bias: paddle.Tensor,
    eps: float,
    otype: tex.DType,
    sm_margin: int = 0,
    zero_centered_gamma: bool = False,
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
    """Non-FP8 LayerNorm forward.

    Thin wrapper around ``tex.te_layernorm_fwd``; ``otype`` is passed as its
    integer id and the kernel's result is returned as-is.
    """
    otype_id = int(otype)
    return tex.te_layernorm_fwd(
        inp,
        weight,
        bias,
        eps,
        otype_id,
        sm_margin,
        zero_centered_gamma,
    )


def layernorm_bwd(
    dz: paddle.Tensor,
    x: paddle.Tensor,
    mu: paddle.Tensor,
    rsigma: paddle.Tensor,
    gamma: paddle.Tensor,
    sm_margin: int = 0,
    zero_centered_gamma: bool = False,
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
    """Non-FP8 LayerNorm backward.

    Thin wrapper around ``tex.te_layernorm_bwd``; takes the upstream gradient
    ``dz``, the forward input ``x`` and the saved ``mu``/``rsigma``, and
    returns the kernel's result as-is.
    """
    return tex.te_layernorm_bwd(dz, x, mu, rsigma, gamma, sm_margin, zero_centered_gamma)


def rmsnorm_fwd(
    inp: paddle.Tensor,
    weight: paddle.Tensor,
    eps: float,
    otype: tex.DType,
    sm_margin: int = 0,
    zero_centered_gamma: bool = False,
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
    """Non-FP8 RMSNorm forward.

    Thin wrapper around ``tex.te_rmsnorm_fwd``; ``otype`` is passed as its
    integer id and the kernel's result is returned as-is.
    """
    return tex.te_rmsnorm_fwd(inp, weight, eps, int(otype), sm_margin, zero_centered_gamma)


def rmsnorm_fwd_fp8(
    inp: paddle.Tensor,
    weight: paddle.Tensor,
    eps: float,
    fp8_meta_tensor: FP8TensorMeta,
    fp8_tensor: Union[FP8FwdTensors, FP8BwdTensors],
    otype: tex.DType,
    sm_margin: int = 0,
    zero_centered_gamma: bool = False,
) -> Tuple[paddle.Tensor, paddle.Tensor]:
    """RMSNorm with FP8 output.

    Calls ``tex.te_rmsnorm_fwd_fp8`` with the meta tensor's ``scale``,
    ``amax_history`` and ``scale_inv`` at index ``fp8_tensor.value``.

    Returns:
        ``(out, rsigma)``; the kernel's two extra outputs are discarded.
    """
    out, rsigma, _, _ = tex.te_rmsnorm_fwd_fp8(inp, weight, fp8_meta_tensor.scale,
                                               fp8_meta_tensor.amax_history,
                                               fp8_meta_tensor.scale_inv, eps, fp8_tensor.value,
                                               int(otype), sm_margin, zero_centered_gamma)
    return out, rsigma


def rmsnorm_bwd(
    dz: paddle.Tensor,
    x: paddle.Tensor,
    rsigma: paddle.Tensor,
    gamma: paddle.Tensor,
    sm_margin: int = 0,
    zero_centered_gamma: bool = False,
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
    """Non-FP8 RMSNorm backward.

    Thin wrapper around ``tex.te_rmsnorm_bwd``; takes the upstream gradient
    ``dz``, the forward input ``x`` and the saved ``rsigma``, and returns the
    kernel's result as-is.
    """
    return tex.te_rmsnorm_bwd(dz, x, rsigma, gamma, sm_margin, zero_centered_gamma)


def mask_to_cu_seqlens(
    mask: paddle.Tensor,
    need_kv: bool = False,
) -> Tuple[paddle.Tensor, Optional[paddle.Tensor]]:
    """Convert mask to cu_seqlens.

    ``mask`` has shape ``[b, 1, s_q, s_kv]``. Allocates int32 tensors of
    length ``b + 1`` with element 0 set to 0 and lets
    ``tex.mask_to_cu_seqlens`` fill them (presumably with cumulative
    sequence lengths — confirm against the kernel).

    Returns:
        ``(q_cu_seqlens, kv_cu_seqlens)``; ``kv_cu_seqlens`` is None unless
        ``need_kv`` is True.
    """
    # mask shape: [b, 1, s_q, s_kv]
    q_seqlen, kv_seqlen = mask.shape[2], mask.shape[3]
    q_cu_seqlens = paddle.empty(shape=[mask.shape[0] + 1], dtype=paddle.int32)
    q_cu_seqlens[0] = 0
    kv_cu_seqlens = None
    if need_kv:
        kv_cu_seqlens = paddle.empty(shape=[mask.shape[0] + 1], dtype=paddle.int32)
        kv_cu_seqlens[0] = 0
    tex.mask_to_cu_seqlens(mask, q_cu_seqlens, kv_cu_seqlens, q_seqlen, kv_seqlen, need_kv)
    return q_cu_seqlens, kv_cu_seqlens


def fused_attn_fwd_qkvpacked(
    qkv: paddle.Tensor,
    cu_seqlens: paddle.Tensor,
    is_training: bool,
    max_seqlen: int,
    qkv_dtype: tex.DType,
    fused_attention_backend: tex.NVTE_Fused_Attn_Backend,
    Bias: paddle.Tensor = None,
    attn_scale: float = None,
    dropout: float = 0.0,
    set_zero: bool = True,
    qkv_layout: str = "bs3hd",
    bias_type: str = "no_bias",
    attn_mask_type: str = "padding",
) -> Tuple[paddle.Tensor, Optional[paddle.Tensor], paddle.Tensor]:
    """Fused Attention FWD for packed QKV input.

    ``qkv`` is indexed so that ``shape[0] * shape[1]`` is the total token
    count and ``shape[3]``/``shape[4]`` are the head count ``h`` and head
    size ``d`` (presumably ``[b, s, 3, h, d]`` per the default "bs3hd"
    layout — confirm). The batch size is ``len(cu_seqlens) - 1``.
    ``attn_scale`` defaults to ``1 / sqrt(d)``.

    Returns:
        ``(out, softmax_aux, rng_state)``: ``out`` is ``[b, max_seqlen, h, d]``
        (zero-filled first when ``set_zero``); ``softmax_aux`` is None unless
        ``is_training``; ``rng_state`` is a 2-element int64 tensor passed to
        the kernel.

    Raises:
        ValueError: if the backend is neither "F16_max512_seqlen" nor
            "F16_arbitrary_seqlen".
    """
    assert (qkv_dtype in (tex.DType.kBFloat16,
                          tex.DType.kFloat16)), "Only support bf16/fp16 for fused attention."

    b = cu_seqlens.shape[0] - 1
    total_seqs = qkv.shape[0] * qkv.shape[1]
    h = qkv.shape[3]
    d = qkv.shape[4]

    if attn_scale is None:
        attn_scale = 1.0 / math.sqrt(d)

    if bias_type != "no_bias":
        assert Bias is not None, "bias tensor cannot be None when bias_type is not no_bias."
        assert (Bias.shape == [1, h, max_seqlen, max_seqlen
                              ]), "bias tensor must be in [1, h, max_seqlen, max_seqlen] shape."
        assert (Bias.dtype == qkv.dtype), "bias tensor must be in the same dtype as qkv."

    assert (fused_attention_backend != FusedAttnBackend["No_Backend"]
           ), "Fused attention does not support this input combination."

    is_max512 = fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]
    if is_max512:
        # BF16/FP16 fused attention API from fmha_v1 apex
        rng_elts_per_thread = (max_seqlen * max_seqlen + BACKEND_F16m512_THREADS_PER_CTA -
                               1) // BACKEND_F16m512_THREADS_PER_CTA
    elif fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]:
        # BF16/FP16 fused attention API from fmha_v2
        rng_elts_per_thread = BACKEND_F16arb_ELTS_PER_THREADS
    else:
        # Previously this left rng_elts_per_thread unbound (UnboundLocalError
        # when is_training was False); fail early with the same error the
        # training path already raised for unsupported backends.
        raise ValueError("Unsupported fused attention backend.")

    out_shape = [b, max_seqlen, h, d]
    if set_zero:
        out = paddle.full(shape=out_shape, fill_value=0, dtype=qkv.dtype)
    else:
        out = paddle.empty(shape=out_shape, dtype=qkv.dtype)

    if not is_training:
        softmax_aux = None
    elif is_max512:
        softmax_aux = paddle.empty(shape=[b, h, max_seqlen, max_seqlen], dtype=qkv.dtype)
    else:
        # F16_arbitrary_seqlen, the only other backend accepted above.
        softmax_aux = paddle.empty(shape=[b, h, max_seqlen, 1], dtype='float32')

    rng_state = paddle.empty(shape=[2], dtype=paddle.int64)

    # execute kernel
    tex.te_fused_attn_fwd_qkvpacked(
        qkv,
        cu_seqlens,
        Bias,
        out,
        softmax_aux,
        rng_state,
        b,
        h,
        d,
        total_seqs,
        max_seqlen,
        is_training,
        attn_scale,
        dropout,
        qkv_layout,
        bias_type,
        attn_mask_type,
        int(qkv_dtype),
        rng_elts_per_thread,
    )
    return out, softmax_aux, rng_state


def fused_attn_bwd_qkvpacked(
    qkv: paddle.Tensor,
    cu_seqlens: paddle.Tensor,
    rng_state: paddle.Tensor,
    o: paddle.Tensor,
    d_o: paddle.Tensor,
    softmax_aux: paddle.Tensor,
    fused_attention_backend: tex.NVTE_Fused_Attn_Backend,
    max_seqlen: int,
    qkv_dtype: tex.DType,
    attn_scale: float = None,
    dropout: float = 0.0,
    set_zero: bool = True,
    qkv_layout: str = "bs3hd",
    bias_type: str = "no_bias",
    attn_mask_type: str = "padding",
) -> Tuple[paddle.Tensor, Optional[paddle.Tensor]]:
    """Fused Attention BWD for packed QKV input.

    Uses the same ``qkv`` indexing as the forward pass (``h = shape[3]``,
    ``d = shape[4]``). ``dqkv`` is allocated with ``qkv``'s shape and dtype
    (zero-filled when ``set_zero``). ``dbias`` is allocated as
    ``[1, h, max_seqlen, max_seqlen]`` only when ``bias_type`` is not
    "no_bias".

    Returns:
        ``(dqkv, dbias)`` as returned by the kernel.
    """
    assert (qkv_dtype in (tex.DType.kBFloat16,
                          tex.DType.kFloat16)), "Only support bf16/fp16 for fused attention."

    b = cu_seqlens.shape[0] - 1
    total_seqs = qkv.shape[0] * qkv.shape[1]
    h = qkv.shape[3]
    d = qkv.shape[4]

    if attn_scale is None:
        attn_scale = 1.0 / math.sqrt(d)

    assert (fused_attention_backend != FusedAttnBackend["No_Backend"]
           ), "Fused attention does not support this input combination."

    if set_zero:
        dqkv = paddle.full(shape=qkv.shape, fill_value=0, dtype=qkv.dtype)
    else:
        dqkv = paddle.empty(shape=qkv.shape, dtype=qkv.dtype)

    if bias_type != "no_bias":
        dbias = paddle.empty(shape=[1, h, max_seqlen, max_seqlen], dtype=qkv.dtype)
    else:
        dbias = None
    # execute kernel
    dqkv, dbias = tex.te_fused_attn_bwd_qkvpacked(
        qkv,
        cu_seqlens,
        o,
        d_o,
        softmax_aux,
        dqkv,
        dbias,
        rng_state,
        b,
        h,
        d,
        total_seqs,
        max_seqlen,
        attn_scale,
        dropout,
        qkv_layout,
        bias_type,
        attn_mask_type,
        int(qkv_dtype),
    )

    return dqkv, dbias


def fused_attn_fwd_kvpacked(
    q: paddle.Tensor,
    kv: paddle.Tensor,
    cu_seqlens_q: paddle.Tensor,
    cu_seqlens_kv: paddle.Tensor,
    is_training: bool,
    max_seqlen_q: int,
    max_seqlen_kv: int,
    qkv_dtype: tex.DType,
    fused_attention_backend: tex.NVTE_Fused_Attn_Backend,
    Bias: paddle.Tensor = None,
    attn_scale: float = None,
    dropout: float = 0.0,
    set_zero: bool = True,
    qkv_layout: str = "bshd_bs2hd",
    bias_type: str = "no_bias",
    attn_mask_type: str = "padding",
) -> Tuple[paddle.Tensor, Optional[paddle.Tensor], paddle.Tensor]:
    """Fused Attention FWD for packed KV input.

    ``h``/``d`` are read from ``q.shape[2]``/``q.shape[3]``. Token totals are
    ``shape[0] * shape[1]`` of ``q`` and ``kv``. The batch size is
    ``len(cu_seqlens_q) - 1``. ``attn_scale`` defaults to ``1 / sqrt(d)``.

    Returns:
        ``(out, softmax_aux, rng_state)``: ``out`` is
        ``[b, max_seqlen_q, h, d]`` (zero-filled first when ``set_zero``);
        ``softmax_aux`` is None unless ``is_training``; ``rng_state`` is a
        2-element int64 tensor passed to the kernel.

    Raises:
        ValueError: if the backend is neither "F16_max512_seqlen" nor
            "F16_arbitrary_seqlen".
    """
    assert (qkv_dtype in (tex.DType.kBFloat16,
                          tex.DType.kFloat16)), "Only support bf16/fp16 for fused attention."
    assert (cu_seqlens_q.shape == cu_seqlens_kv.shape
           ), "cu_seqlens_q and cu_seqlens_kv must have the same shape"

    b = cu_seqlens_q.shape[0] - 1
    total_seqs_q = q.shape[0] * q.shape[1]
    total_seqs_kv = kv.shape[0] * kv.shape[1]
    h = q.shape[2]
    d = q.shape[3]

    if attn_scale is None:
        attn_scale = 1.0 / math.sqrt(d)

    if bias_type != "no_bias":
        assert Bias is not None, "bias tensor cannot be None when bias_type is not no_bias."
        assert (Bias.shape == [1, h, max_seqlen_q, max_seqlen_kv
                              ]), "bias tensor must be in [1, h, max_seqlen_q, max_seqlen_kv] shape."
        assert (Bias.dtype == q.dtype), "bias tensor must be in the same dtype as q and kv."

    assert (fused_attention_backend != FusedAttnBackend["No_Backend"]
           ), "Fused attention does not support this input combination."

    is_max512 = fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]
    if is_max512:
        # BF16/FP16 fused attention API from fmha_v1 apex
        rng_elts_per_thread = (max_seqlen_q * max_seqlen_kv + BACKEND_F16m512_THREADS_PER_CTA -
                               1) // BACKEND_F16m512_THREADS_PER_CTA
    elif fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]:
        # BF16/FP16 fused attention API from fmha_v2
        rng_elts_per_thread = BACKEND_F16arb_ELTS_PER_THREADS
    else:
        # Previously this left rng_elts_per_thread unbound (UnboundLocalError
        # when is_training was False); fail early with the same error the
        # training path already raised for unsupported backends.
        raise ValueError("Unsupported fused attention backend.")

    out_shape = [b, max_seqlen_q, h, d]
    if set_zero:
        out = paddle.full(shape=out_shape, fill_value=0, dtype=q.dtype)
    else:
        out = paddle.empty(shape=out_shape, dtype=q.dtype)

    if not is_training:
        softmax_aux = None
    elif is_max512:
        softmax_aux = paddle.empty(shape=[b, h, max_seqlen_q, max_seqlen_kv], dtype=q.dtype)
    else:
        # F16_arbitrary_seqlen, the only other backend accepted above.
        softmax_aux = paddle.empty(shape=[b, h, max_seqlen_q, 1], dtype='float32')

    rng_state = paddle.empty(shape=[2], dtype=paddle.int64)

    # execute kernel
    tex.te_fused_attn_fwd_kvpacked(
        q,
        kv,
        cu_seqlens_q,
        cu_seqlens_kv,
        Bias,
        out,
        softmax_aux,
        rng_state,
        b,
        h,
        d,
        total_seqs_q,
        total_seqs_kv,
        max_seqlen_q,
        max_seqlen_kv,
        is_training,
        attn_scale,
        dropout,
        qkv_layout,
        bias_type,
        attn_mask_type,
        int(qkv_dtype),
        rng_elts_per_thread,
    )

    return out, softmax_aux, rng_state


def fused_attn_bwd_kvpacked(
    q: paddle.Tensor,
    kv: paddle.Tensor,
    cu_seqlens_q: paddle.Tensor,
    cu_seqlens_kv: paddle.Tensor,
    rng_state: paddle.Tensor,
    o: paddle.Tensor,
    d_o: paddle.Tensor,
    softmax_aux: paddle.Tensor,
    fused_attention_backend: tex.NVTE_Fused_Attn_Backend,
    max_seqlen_q: int,
    max_seqlen_kv: int,
    qkv_dtype: tex.DType,
    attn_scale: float = None,
    dropout: float = 0.0,
    set_zero: bool = True,
    qkv_layout: str = "bshd_bs2hd",
    bias_type: str = "no_bias",
    attn_mask_type: str = "padding",
) -> Tuple[paddle.Tensor, paddle.Tensor, Optional[paddle.Tensor]]:
    """Fused Attention BWD for packed KV input.

    ``dq``/``dkv`` are allocated with the shapes and dtypes of ``q``/``kv``
    (zero-filled when ``set_zero``) and passed to the kernel as outputs.
    ``dbias`` is allocated as ``[1, h, max_seqlen_q, max_seqlen_kv]`` only
    when ``bias_type`` is not "no_bias".

    Returns:
        ``(dq, dkv, dbias)``.
    """
    assert (qkv_dtype in (tex.DType.kBFloat16,
                          tex.DType.kFloat16)), "Only support bf16/fp16 for fused attention."
    assert (cu_seqlens_q.shape == cu_seqlens_kv.shape
           ), "cu_seqlens_q and cu_seqlens_kv must have the same shape"

    b = cu_seqlens_q.shape[0] - 1
    total_seqs_q = q.shape[0] * q.shape[1]
    total_seqs_kv = kv.shape[0] * kv.shape[1]
    h = q.shape[2]
    d = q.shape[3]

    if attn_scale is None:
        attn_scale = 1.0 / math.sqrt(d)

    assert (fused_attention_backend != FusedAttnBackend["No_Backend"]
           ), "Fused attention does not support this input combination."

    if set_zero:
        dq = paddle.full(shape=q.shape, fill_value=0, dtype=q.dtype)
        dkv = paddle.full(shape=kv.shape, fill_value=0, dtype=kv.dtype)
    else:
        dq = paddle.empty(shape=q.shape, dtype=q.dtype)
        dkv = paddle.empty(shape=kv.shape, dtype=kv.dtype)
    if bias_type != "no_bias":
        dbias = paddle.empty(shape=[1, h, max_seqlen_q, max_seqlen_kv], dtype=q.dtype)
    else:
        dbias = None
    # execute kernel
    tex.te_fused_attn_bwd_kvpacked(
        q,
        kv,
        cu_seqlens_q,
        cu_seqlens_kv,
        o,
        d_o,
        softmax_aux,
        dq,
        dkv,
        dbias,
        rng_state,
        b,
        h,
        d,
        total_seqs_q,
        total_seqs_kv,
        max_seqlen_q,
        max_seqlen_kv,
        attn_scale,
        dropout,
        qkv_layout,
        bias_type,
        attn_mask_type,
        int(qkv_dtype),
    )
    return dq, dkv, dbias


def fused_attn_fwd(
    q: paddle.Tensor,
    k: paddle.Tensor,
    v: paddle.Tensor,
    cu_seqlens_q: paddle.Tensor,
    cu_seqlens_kv: paddle.Tensor,
    is_training: bool,
    max_seqlen_q: int,
    max_seqlen_kv: int,
    qkv_dtype: tex.DType,
    fused_attention_backend: tex.NVTE_Fused_Attn_Backend,
    Bias: Optional[paddle.Tensor] = None,
    attn_scale: Optional[float] = None,
    dropout: float = 0.0,
    set_zero: bool = True,
    qkv_layout: str = "bshd_bshd_bshd",
    bias_type: str = "no_bias",
    attn_mask_type: str = "padding",
) -> Tuple[paddle.Tensor, Optional[paddle.Tensor], paddle.Tensor]:
    """Fused Attention FWD for unpacked QKV input.

    Args:
        q, k, v: separate query / key / value tensors. Only the
            ``"bshd_bshd_bshd"`` layout is accepted. The head count ``h`` and
            head size ``d`` are read from ``q.shape[-2]`` and ``q.shape[-1]``.
        cu_seqlens_q, cu_seqlens_kv: cumulative sequence-length tensors, which
            must have identical shapes. The batch size is
            ``cu_seqlens_q.shape[0] - 1``.
        is_training: when True, a softmax auxiliary buffer is allocated,
            handed to the kernel and returned. Otherwise it is None.
        max_seqlen_q, max_seqlen_kv: maximum query / key-value sequence lengths.
        qkv_dtype: must be ``tex.DType.kBFloat16`` or ``tex.DType.kFloat16``.
        fused_attention_backend: must be ``FusedAttnBackend["F16_max512_seqlen"]``
            or ``FusedAttnBackend["F16_arbitrary_seqlen"]``.
        Bias: required when ``bias_type != "no_bias"``. Must have shape
            ``[1, h, max_seqlen_q, max_seqlen_kv]`` and the same dtype as ``q``.
        attn_scale: softmax scale; defaults to ``1 / sqrt(d)``.
        dropout: dropout probability forwarded to the kernel.
        set_zero: zero-fill the output buffer before the kernel runs instead
            of leaving it uninitialised.
        qkv_layout, bias_type, attn_mask_type: forwarded to the kernel as strings.

    Returns:
        ``(out, softmax_aux, rng_state)``:

        - ``out`` has shape ``[b, max_seqlen_q, h, d]`` and ``q``'s dtype.
        - ``softmax_aux`` is None when not training.
        - ``rng_state`` is a 2-element int64 buffer, presumably populated by
          the kernel for reuse in the backward pass.

    Raises:
        AssertionError: on unsupported dtype, layout, cu_seqlens shapes, bias
            arguments, or ``No_Backend``.
        ValueError: if ``fused_attention_backend`` is not one of the F16 backends.
    """

    assert (qkv_dtype in (tex.DType.kBFloat16,
                          tex.DType.kFloat16)), "Only support bf16/fp16 for fused attention."
    assert (cu_seqlens_q.shape == cu_seqlens_kv.shape
           ), "cu_seqlens_q and cu_seqlens_kv must have the same shape"
    assert (qkv_layout == "bshd_bshd_bshd"
           ), "Only support bshd_bshd_bshd layout for unpacked QKV input for now."
    b = cu_seqlens_q.shape[0] - 1

    h = q.shape[-2]
    d = q.shape[-1]

    if attn_scale is None:
        attn_scale = 1.0 / math.sqrt(d)

    if bias_type != "no_bias":
        assert Bias is not None, "bias tensor cannot be None when bias_type is not no_bias."
        assert (Bias.shape == [
            1, h, max_seqlen_q, max_seqlen_kv
        ]), "bias tensor must be in [1, h, max_seqlen_q, max_seqlen_kv] shape."
        assert (Bias.dtype == q.dtype), "bias tensor must be in the same dtype as qkv."

    assert (fused_attention_backend != FusedAttnBackend["No_Backend"]
           ), "Fused attention does not support this input combination."

    is_max512 = fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]
    if is_max512:
        # BF16/FP16 fused attention API from fmha_v1 apex:
        # ceil(max_seqlen_q * max_seqlen_kv / threads_per_cta)
        rng_elts_per_thread = (max_seqlen_q * max_seqlen_kv + BACKEND_F16m512_THREADS_PER_CTA -
                               1) // BACKEND_F16m512_THREADS_PER_CTA
    elif fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]:
        # BF16/FP16 fused attention API from fmha_v2
        rng_elts_per_thread = BACKEND_F16arb_ELTS_PER_THREADS
    else:
        # No RNG sizing rule exists for other backends. Fail here, before any
        # buffers are allocated, rather than hitting an unbound
        # rng_elts_per_thread at the kernel call.
        raise ValueError("Unsupported fused attention backend.")

    if set_zero:
        out = paddle.full(shape=[b, max_seqlen_q, h, d], fill_value=0, dtype=q.dtype)
    else:
        out = paddle.empty(shape=[b, max_seqlen_q, h, d], dtype=q.dtype)

    if is_training:
        if is_max512:
            softmax_aux = paddle.empty(shape=[b, h, max_seqlen_q, max_seqlen_kv], dtype=q.dtype)
        else:
            # F16_arbitrary_seqlen (the only other backend accepted above)
            # uses a float32 buffer with a trailing dimension of 1.
            softmax_aux = paddle.empty(shape=[b, h, max_seqlen_q, 1], dtype='float32')
    else:
        softmax_aux = None

    rng_state = paddle.empty(shape=[
        2,
    ], dtype=paddle.int64)

    # execute kernel
    tex.te_fused_attn_fwd(
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_kv,
        Bias,
        out,
        softmax_aux,
        rng_state,
        b,
        h,
        d,
        max_seqlen_q,
        max_seqlen_kv,
        is_training,
        attn_scale,
        dropout,
        qkv_layout,
        bias_type,
        attn_mask_type,
        int(qkv_dtype),
        rng_elts_per_thread,
    )
    return out, softmax_aux, rng_state


def fused_attn_bwd(
    q: paddle.Tensor,
    k: paddle.Tensor,
    v: paddle.Tensor,
    cu_seqlens_q: paddle.Tensor,
    cu_seqlens_kv: paddle.Tensor,
    rng_state: paddle.Tensor,
    o: paddle.Tensor,
    d_o: paddle.Tensor,
    softmax_aux: paddle.Tensor,
    fused_attention_backend: tex.NVTE_Fused_Attn_Backend,
    max_seqlen_q: int,
    max_seqlen_kv: int,
    qkv_dtype: tex.DType,
    attn_scale: Optional[float] = None,
    dropout: float = 0.0,
    set_zero: bool = True,
    qkv_layout: str = "bshd_bshd_bshd",
    bias_type: str = "no_bias",
    attn_mask_type: str = "padding",
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, Optional[paddle.Tensor]]:
    """Fused Attention BWD for unpacked QKV input.

    Args:
        q, k, v: the separate query / key / value tensors from the forward
            pass. Only the ``"bshd_bshd_bshd"`` layout is accepted. The head
            count ``h`` and head size ``d`` are read from ``q.shape[-2]`` and
            ``q.shape[-1]``.
        cu_seqlens_q, cu_seqlens_kv: cumulative sequence-length tensors, which
            must have identical shapes. The batch size is
            ``cu_seqlens_q.shape[0] - 1``.
        rng_state: forwarded to the kernel (presumably the ``rng_state``
            returned by ``fused_attn_fwd``).
        o: forward-pass attention output.
        d_o: gradient with respect to ``o``.
        softmax_aux: forwarded to the kernel (presumably the ``softmax_aux``
            returned by ``fused_attn_fwd``).
        fused_attention_backend: only checked against ``No_Backend`` here; it
            is not passed to the kernel.
        max_seqlen_q, max_seqlen_kv: maximum query / key-value sequence lengths.
        qkv_dtype: must be ``tex.DType.kBFloat16`` or ``tex.DType.kFloat16``.
        attn_scale: softmax scale; defaults to ``1 / sqrt(d)``.
        dropout: dropout probability forwarded to the kernel.
        set_zero: zero-fill ``dq``/``dk``/``dv`` before the kernel runs instead
            of leaving them uninitialised.
        qkv_layout, bias_type, attn_mask_type: forwarded to the kernel as strings.

    Returns:
        ``(dq, dk, dv, dbias)``:

        - ``dq``/``dk``/``dv`` have the shapes and dtypes of ``q``/``k``/``v``.
        - ``dbias`` has shape ``[1, h, max_seqlen_q, max_seqlen_kv]`` and
          ``q``'s dtype, or is None when ``bias_type == "no_bias"``.

    Raises:
        AssertionError: on unsupported dtype, layout, cu_seqlens shapes, or
            ``No_Backend``.
    """

    assert (qkv_dtype in (tex.DType.kBFloat16,
                          tex.DType.kFloat16)), "Only support bf16/fp16 for fused attention."
    assert (cu_seqlens_q.shape == cu_seqlens_kv.shape
           ), "cu_seqlens_q and cu_seqlens_kv must have the same shape"
    assert (qkv_layout == "bshd_bshd_bshd"
           ), "Only support bshd_bshd_bshd layout for unpacked QKV input for now."

    b = cu_seqlens_q.shape[0] - 1
    h = q.shape[-2]
    d = q.shape[-1]

    if attn_scale is None:
        attn_scale = 1.0 / math.sqrt(d)

    assert (fused_attention_backend != FusedAttnBackend["No_Backend"]
           ), "Fused attention does not support this input combination."

    # Gradient buffers mirror their forward inputs; the kernel fills them.
    if set_zero:
        dq = paddle.full(shape=q.shape, fill_value=0, dtype=q.dtype)
        dk = paddle.full(shape=k.shape, fill_value=0, dtype=k.dtype)
        dv = paddle.full(shape=v.shape, fill_value=0, dtype=v.dtype)
    else:
        dq = paddle.empty(shape=q.shape, dtype=q.dtype)
        dk = paddle.empty(shape=k.shape, dtype=k.dtype)
        dv = paddle.empty(shape=v.shape, dtype=v.dtype)
    if bias_type != "no_bias":
        dbias = paddle.empty(shape=[1, h, max_seqlen_q, max_seqlen_kv], dtype=q.dtype)
    else:
        dbias = None
    # execute kernel
    tex.te_fused_attn_bwd(
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_kv,
        o,
        d_o,
        softmax_aux,
        dq,
        dk,
        dv,
        dbias,
        rng_state,
        b,
        h,
        d,
        max_seqlen_q,
        max_seqlen_kv,
        attn_scale,
        dropout,
        qkv_layout,
        bias_type,
        attn_mask_type,
        int(qkv_dtype),
    )
    return dq, dk, dv, dbias


Shijie's avatar
Shijie committed
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
def scaled_softmax_forward(
    inp: paddle.Tensor,
    scale_factor: float,
) -> paddle.Tensor:
    """Forward pass of the fused scaled softmax.

    Thin wrapper around ``tex.te_scaled_softmax_forward``. ``inp`` and
    ``scale_factor`` are passed through positionally, and the extension's
    result is returned unchanged.
    """
    softmax_out = tex.te_scaled_softmax_forward(inp, scale_factor)
    return softmax_out


def scaled_softmax_backward(
    out_grad: paddle.Tensor,
    softmax_results: paddle.Tensor,
    scale_factor: float,
) -> paddle.Tensor:
    """Backward pass of the fused scaled softmax.

    Calls ``tex.te_scaled_softmax_backward`` and discards its return value.
    The very same ``out_grad`` tensor object is returned, so the extension
    presumably writes the input gradient into it in place (TODO confirm).
    """
    grad_buffer = out_grad
    tex.te_scaled_softmax_backward(grad_buffer, softmax_results, scale_factor)
    return grad_buffer


def scaled_masked_softmax_forward(
    inp: paddle.Tensor,
    mask: paddle.Tensor,
    scale_factor: float,
) -> paddle.Tensor:
    """Forward pass of the fused scaled, masked softmax.

    Thin wrapper around ``tex.te_scaled_masked_softmax_forward``. ``inp``,
    ``mask`` and ``scale_factor`` are passed through positionally, and the
    extension's result is returned unchanged.
    """
    masked_out = tex.te_scaled_masked_softmax_forward(inp, mask, scale_factor)
    return masked_out


def scaled_masked_softmax_backward(
    out_grad: paddle.Tensor,
    softmax_results: paddle.Tensor,
    scale_factor: float,
) -> paddle.Tensor:
    """Backward pass of the fused scaled, masked softmax.

    This reuses the *unmasked* ``tex.te_scaled_softmax_backward`` kernel. That
    is presumably valid because the softmax gradient is proportional to the
    saved probabilities, which are ~0 at masked positions (NOTE(review):
    confirm against the masked forward kernel). The extension's return value
    is discarded and the same ``out_grad`` object is returned, so it is
    presumably updated in place.
    """
    grad_buffer = out_grad
    tex.te_scaled_softmax_backward(grad_buffer, softmax_results, scale_factor)
    return grad_buffer


def scaled_upper_triang_masked_softmax_forward(
    inp: paddle.Tensor,
    scale_factor: float,
) -> paddle.Tensor:
    """Forward pass of the fused scaled softmax with an upper-triangular mask.

    Thin wrapper around ``tex.te_scaled_upper_triang_masked_softmax_forward``.
    ``inp`` and ``scale_factor`` are passed through positionally, and the
    extension's result is returned unchanged.
    """
    triang_out = tex.te_scaled_upper_triang_masked_softmax_forward(inp, scale_factor)
    return triang_out


def scaled_upper_triang_masked_softmax_backward(
    out_grad: paddle.Tensor,
    softmax_results: paddle.Tensor,
    scale_factor: float,
) -> paddle.Tensor:
    """Backward pass of the fused upper-triangular-masked scaled softmax.

    Calls ``tex.te_scaled_upper_triang_masked_softmax_backward`` and discards
    its return value. The same ``out_grad`` tensor object is returned, so
    the extension presumably writes the input gradient into it in place
    (TODO confirm).
    """
    grad_buffer = out_grad
    tex.te_scaled_upper_triang_masked_softmax_backward(grad_buffer, softmax_results,
                                                       scale_factor)
    return grad_buffer