cpp_extensions.py 23.3 KB
Newer Older
1
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
4
5
#
# See LICENSE for license information.
"""TE FP8 extensions and GEMMs"""

Shijie's avatar
Shijie committed
6
import math
7
8
9
from typing import Optional, Tuple, Union
import paddle
import transformer_engine_paddle as tex
10
from .constants import TE_DType, FusedAttnBackend, FP8FwdTensors, FP8BwdTensors
11
from .fp8 import FP8TensorMeta
12

13
14
15
# Used by the F16_max512_seqlen fused-attention backend: rng_elts_per_thread is
# computed as ceil(max_seqlen_q * max_seqlen_kv / BACKEND_F16m512_THREADS_PER_CTA).
BACKEND_F16m512_THREADS_PER_CTA = 128
# Passed directly as rng_elts_per_thread for the F16_arbitrary_seqlen backend.
BACKEND_F16arb_ELTS_PER_THREADS = 16

16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103

def gemm(
    A: paddle.Tensor,
    B: paddle.Tensor,
    dtype: paddle.dtype,
    workspace: paddle.Tensor,
    gelu: bool = False,
    gelu_input: Optional[paddle.Tensor] = None,
    grad: bool = False,
    accumulate: bool = False,
    layout: str = "TN",
    out: Optional[paddle.Tensor] = None,
    bias: Optional[paddle.Tensor] = None,
    use_bias: bool = False,
) -> Tuple[Union[paddle.Tensor, None], ...]:
    """Non FP8 GEMM.

    Thin wrapper around ``tex.te_gemm`` for non-FP8 inputs.

    Args:
        A, B: input matrices; both must have dtype ``dtype``.
        dtype: paddle dtype of ``A`` and ``B`` (and of ``out`` when it is
            allocated here).
        workspace: scratch buffer; ``workspace.shape[0]`` is passed as its size.
        gelu: when True and ``grad`` is False, a ``gelu_input`` buffer shaped
            like ``out`` is allocated and passed to the kernel.
        gelu_input: caller-provided buffer, used when both ``gelu`` and
            ``grad`` are True; replaced by None when ``gelu`` is False.
        grad: backward mode; combined with ``use_bias`` a ``grad_bias`` buffer
            of length ``B.shape[1]`` is allocated and passed in the bias slot.
        accumulate: forwarded to the kernel.
        layout: one of "TN", "NN", "NT"; the first character selects
            ``transa`` and the second ``transb``.
        out: optional output buffer; allocated when None with shape
            ``[B.shape[1] if transb else B.shape[0],
            A.shape[0] if transa else A.shape[1]]``.
        bias: bias tensor; used only when ``use_bias`` is True and ``grad`` is
            False, in which case it must not be None.
        use_bias: enables the bias / bias-gradient path.

    Returns:
        ``(out, grad_bias, gelu_input)``. ``out`` is None when the caller
        supplied it; ``grad_bias`` is None unless ``grad and use_bias``.
    """

    assert layout in ("TN", "NN", "NT"), f"GEMM layout {layout} not supported."
    transa = layout[0] == "T"
    transb = layout[1] == "T"

    return_output = False
    if out is None:
        out = paddle.empty(
            shape=[
                B.shape[1] if transb else B.shape[0],
                A.shape[0] if transa else A.shape[1],
            ],
            dtype=dtype,
        )
        return_output = True

    if gelu and not grad:
        gelu_input = paddle.empty_like(out, dtype=dtype)
    elif not gelu:
        gelu_input = None

    if grad and use_bias:
        grad_bias = paddle.empty(shape=[B.shape[1]], dtype=out.dtype)
    else:
        grad_bias = None

    if use_bias and not grad:
        # Fail with a clear message instead of an AttributeError on bias.dtype.
        assert bias is not None, "bias tensor cannot be None when use_bias is True."
    bias = bias if use_bias else None

    assert A.dtype == dtype and B.dtype == dtype, \
        f'Expected dtype={dtype}, but found A.dtype={A.dtype} and B.dtype={B.dtype}'
    input_dtype = TE_DType[dtype]
    output_dtype = TE_DType[out.dtype]
    if use_bias:
        bias_dtype = TE_DType[grad_bias.dtype] if grad else TE_DType[bias.dtype]
    else:
        # No bias is passed; the kernel still needs some dtype argument.
        bias_dtype = output_dtype

    tex.te_gemm(
        A,
        None,
        B,
        None,
        grad_bias if grad else bias,
        out,
        None,    # out_scale
        None,    # out_amax
        gelu_input,
        workspace,
        0,    # A_index
        0,    # B_index
        0,    # D_index
        int(input_dtype),
        int(input_dtype),
        int(output_dtype),
        int(bias_dtype),
        transa,
        transb,
        grad,
        workspace.shape[0],
        accumulate,
        False,    # use_split_accumulator
        0,    # math_sm_count
    )

    if return_output:
        return out, grad_bias, gelu_input
    return None, grad_bias, gelu_input


def fp8_gemm(
    A: paddle.Tensor,
    A_scale_inv: paddle.Tensor,
    A_fp8_tensor: Union[FP8FwdTensors, FP8BwdTensors],
    A_dtype: tex.DType,
    B: paddle.Tensor,
    B_scale_inv: paddle.Tensor,
    B_fp8_tensor: Union[FP8FwdTensors, FP8BwdTensors],
    B_dtype: tex.DType,
    out_dtype: paddle.dtype,
    workspace: paddle.Tensor,
    gelu: bool = False,
    accumulate: bool = False,
    out: Optional[paddle.Tensor] = None,
    out_index=None,
    fp8_meta_tensor: FP8TensorMeta = None,
    bias: Optional[paddle.Tensor] = None,
    use_bias: bool = False,
    use_split_accumulator: bool = False,
    D_dtype: Optional[tex.DType] = None,
) -> Union[paddle.Tensor, Tuple[paddle.Tensor, paddle.Tensor], None]:
    """TN layout GEMM with fp8 inputs.

    Calls ``tex.te_gemm`` with ``transa=True`` and ``transb=False``.

    Args:
        A, B: FP8 input matrices.
        A_scale_inv, B_scale_inv: scale-inverse buffers for ``A`` / ``B``.
        A_fp8_tensor, B_fp8_tensor: their ``.value`` is passed as the A / B
            index argument of ``tex.te_gemm``.
        A_dtype, B_dtype: TE dtypes of ``A`` and ``B``.
        out_dtype: paddle dtype used when ``out`` is allocated here.
        workspace: scratch buffer; ``workspace.shape[0]`` is passed as its size.
        gelu: when True, a pre-GELU output buffer shaped like ``out`` is
            allocated (dtype of ``bias``, or bfloat16 when ``bias`` is None).
        accumulate: forwarded to the kernel.
        out: optional output buffer; allocated when None with shape
            ``[B.shape[0], A.shape[0]]``.
        out_index: output index; when given, ``fp8_meta_tensor.scale`` and
            ``fp8_meta_tensor.amax_history`` are passed to the kernel.
        fp8_meta_tensor: FP8 metadata for the output; required when
            ``out_index`` is given or ``D_dtype`` is an FP8 type.
        bias: bias tensor, passed only when ``use_bias`` is True.
        use_bias: enables the bias path.
        use_split_accumulator: forwarded to the kernel.
        D_dtype: TE dtype of the output; defaults to the TE dtype of ``out``.

    Returns:
        If ``out`` was allocated here: ``out``, or ``(out, gelu_input)`` when
        ``gelu``. If the caller supplied ``out``: ``gelu_input`` when
        ``gelu``, otherwise None.
    """

    if D_dtype is not None and D_dtype in (tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2):
        assert fp8_meta_tensor is not None and out_index is not None
    if out_index is not None:
        # The output scale / amax history are read from fp8_meta_tensor below.
        assert fp8_meta_tensor is not None, \
            "fp8_meta_tensor cannot be None when out_index is given."

    return_output = False
    if out is None:
        out = paddle.empty(
            shape=[
                B.shape[0],
                A.shape[0],
            ],
            dtype=out_dtype,
        )
        return_output = True

    # bfloat16 is the default bias dtype when no bias tensor is given; it also
    # sets the dtype of the pre-GELU output buffer.
    bias_dtype = paddle.bfloat16 if bias is None else bias.dtype
    gelu_input = paddle.empty_like(out, dtype=bias_dtype) if gelu else None
    te_bias_dtype = TE_DType[bias_dtype]
    # Separate name so the `out_dtype` (paddle dtype) parameter is not shadowed.
    te_out_dtype = TE_DType[out.dtype] if D_dtype is None else D_dtype

    tex.te_gemm(
        A,
        A_scale_inv,
        B,
        B_scale_inv,
        bias if use_bias else None,
        out,
        None if out_index is None else fp8_meta_tensor.scale,
        None if out_index is None else fp8_meta_tensor.amax_history,
        gelu_input,    # this is pre_gelu_out
        workspace,
        A_fp8_tensor.value,
        B_fp8_tensor.value,
        0 if out_index is None else out_index,
        int(A_dtype),
        int(B_dtype),
        int(te_out_dtype),
        int(te_bias_dtype),
        True,    # transa
        False,    # transb
        False,    # grad
        workspace.shape[0],
        accumulate,
        use_split_accumulator,
        0,    # math_sm_count
    )

    if return_output:
        return (out, gelu_input) if gelu else out
    return gelu_input if gelu else None


def cast_to_fp8(
    inp: paddle.Tensor,
    fp8_meta_tensor: FP8TensorMeta,
    fp8_tensor: Union[FP8FwdTensors, FP8BwdTensors],
    otype: tex.DType,
) -> paddle.Tensor:
    """Cast input to FP8.

    Args:
        inp: tensor to cast.
        fp8_meta_tensor: supplies the ``scale``, ``amax_history`` and
            ``scale_inv`` buffers passed to the extension.
        fp8_tensor: its ``.value`` is passed as the metadata index.
        otype: target TE dtype, passed as an int.

    Returns:
        The first of the three outputs of ``tex.cast_to_fp8``; the other two
        are discarded.
    """
    out, _, _ = tex.cast_to_fp8(
        inp,
        fp8_meta_tensor.scale,
        fp8_meta_tensor.amax_history,
        fp8_meta_tensor.scale_inv,
        fp8_tensor.value,
        int(otype),
    )
    return out


def cast_from_fp8(
    inp: paddle.Tensor,
    fp8_meta_tensor: FP8TensorMeta,
    fp8_tensor: Union[FP8FwdTensors, FP8BwdTensors],
    itype: tex.DType,
    otype: tex.DType,
) -> paddle.Tensor:
    """Cast input from FP8.

    Args:
        inp: FP8 tensor to cast back.
        fp8_meta_tensor: supplies the ``scale_inv`` buffer passed to the
            extension.
        fp8_tensor: its ``.value`` is passed as the metadata index.
        itype: TE dtype of ``inp``, passed as an int.
        otype: target TE dtype, passed as an int.

    Returns:
        The tensor returned by ``tex.cast_from_fp8``.
    """
    return tex.cast_from_fp8(
        inp,
        fp8_meta_tensor.scale_inv,
        fp8_tensor.value,
        int(itype),
        int(otype),
    )
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230


def transpose(
    inp: paddle.Tensor,
    otype: tex.DType,
) -> paddle.Tensor:
    """Transpose ``inp`` through ``tex.te_transpose``.

    Args:
        inp: tensor to transpose.
        otype: TE dtype, forwarded to the extension as an int.

    Returns:
        Whatever ``tex.te_transpose`` returns.
    """
    otype_id = int(otype)
    return tex.te_transpose(inp, otype_id)


def cast_transpose(
    inp: paddle.Tensor,
    fp8_meta_tensor: FP8TensorMeta,
    fp8_tensor: Union[FP8FwdTensors, FP8BwdTensors],
    otype: tex.DType,
) -> Tuple[paddle.Tensor, paddle.Tensor]:
    """Cast + Transpose with FP8 output.

    Args:
        inp: tensor to cast and transpose.
        fp8_meta_tensor: supplies the ``scale``, ``amax_history`` and
            ``scale_inv`` buffers passed to the extension.
        fp8_tensor: its ``.value`` is passed as the metadata index.
        otype: target TE dtype, passed as an int.

    Returns:
        ``(cast_out, transpose_out)`` - the first two of the four outputs of
        ``tex.te_cast_transpose``.
    """
    cast_out, transpose_out, _, _ = tex.te_cast_transpose(
        inp,
        fp8_meta_tensor.scale,
        fp8_meta_tensor.amax_history,
        fp8_meta_tensor.scale_inv,
        fp8_tensor.value,
        int(otype),
    )

    return cast_out, transpose_out


248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
def cast_transpose_bgrad(
    inp: paddle.Tensor,
    fp8_meta_tensor: FP8TensorMeta,
    fp8_tensor: Union[FP8FwdTensors, FP8BwdTensors],
    otype: tex.DType,
) -> Union[Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor], None]:
    """Fused cast + transpose + bias gradient.

    Args:
        inp: input tensor.
        fp8_meta_tensor: supplies ``scale``, ``amax_history`` and
            ``scale_inv`` for the extension call.
        fp8_tensor: its ``.value`` selects the metadata entry.
        otype: target TE dtype, forwarded as an int.

    Returns:
        ``(grad_bias, cast_out, transpose_out)``: the first three of the five
        values produced by ``tex.te_cast_transpose_bgrad``.
    """
    outputs = tex.te_cast_transpose_bgrad(
        inp,
        fp8_meta_tensor.scale,
        fp8_meta_tensor.amax_history,
        fp8_meta_tensor.scale_inv,
        fp8_tensor.value,
        int(otype),
    )
    grad_bias, cast_out, transpose_out, _, _ = outputs
    return (grad_bias, cast_out, transpose_out)


267
268
269
270
271
272
273
274
275
276
277
278
279
def te_gelu(
    inp: paddle.Tensor,
    otype: tex.DType,
) -> paddle.Tensor:
    """Apply GELU without any FP8 casting.

    Args:
        inp: input tensor.
        otype: TE output dtype, forwarded as an int.

    Returns:
        The tensor produced by ``tex.te_gelu``.
    """
    otype_id = int(otype)
    return tex.te_gelu(inp, otype_id)


def gelu_fp8(
    inp: paddle.Tensor,
    fp8_meta_tensor: FP8TensorMeta,
    fp8_tensor: Union[FP8FwdTensors, FP8BwdTensors],
    otype: tex.DType,
) -> paddle.Tensor:
    """GELU + FP8 cast.

    Args:
        inp: input tensor.
        fp8_meta_tensor: supplies the ``scale``, ``amax_history`` and
            ``scale_inv`` buffers passed to the extension.
        fp8_tensor: its ``.value`` is passed as the metadata index.
        otype: target TE dtype, passed as an int.

    Returns:
        The first of the three outputs of ``tex.te_gelu_fp8``.
    """
    out, _, _ = tex.te_gelu_fp8(
        inp,
        fp8_meta_tensor.scale,
        fp8_meta_tensor.amax_history,
        fp8_meta_tensor.scale_inv,
        fp8_tensor.value,
        int(otype),
    )

    return out


def dgelu_cast_transpose_bgrad_fp8(
    grad_output: paddle.Tensor,
    gelu_input: paddle.Tensor,
    fp8_meta_tensor: FP8TensorMeta,
    fp8_tensor: Union[FP8FwdTensors, FP8BwdTensors],
    otype: tex.DType,
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
    """
    Fused dgelu + cast / transpose / reduce the result of
    the GELU backward along the first dimension

    Args:
        grad_output: incoming gradient.
        gelu_input: the GELU forward input.
        fp8_meta_tensor: supplies the ``scale``, ``amax_history`` and
            ``scale_inv`` buffers passed to the extension.
        fp8_tensor: its ``.value`` is passed as the metadata index.
        otype: target TE dtype, passed as an int.

    Returns:
        ``(cast_dgelu, transpose_dgelu, dbias)`` - the first three of the five
        outputs of ``tex.te_cast_transpose_bgrad_dgelu``.
    """
    cast_dgelu, transpose_dgelu, dbias, _, _ = tex.te_cast_transpose_bgrad_dgelu(
        grad_output,
        gelu_input,
        fp8_meta_tensor.scale,
        fp8_meta_tensor.amax_history,
        fp8_meta_tensor.scale_inv,
        fp8_tensor.value,
        int(otype),
    )

    return cast_dgelu, transpose_dgelu, dbias


def layernorm_fwd_fp8(
    inp: paddle.Tensor,
    weight: paddle.Tensor,
    bias: paddle.Tensor,
    eps: float,
    fp8_meta_tensor: FP8TensorMeta,
    fp8_tensor: Union[FP8FwdTensors, FP8BwdTensors],
    otype: tex.DType,
    sm_margin: int = 0,
    zero_centered_gamma: bool = False,
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
    """LayerNorm with FP8 output.

    Args:
        inp, weight, bias: LayerNorm input and affine parameters.
        eps: epsilon forwarded to the kernel.
        fp8_meta_tensor: supplies the ``scale``, ``amax_history`` and
            ``scale_inv`` buffers passed to the extension.
        fp8_tensor: its ``.value`` is passed as the metadata index.
        otype: target TE dtype, passed as an int.
        sm_margin: forwarded to the kernel.
        zero_centered_gamma: forwarded to the kernel.

    Returns:
        ``(out, mu, rsigma)`` - the first three of the five outputs of
        ``tex.te_layernorm_fwd_fp8``.
    """
    out, mu, rsigma, _, _ = tex.te_layernorm_fwd_fp8(
        inp,
        weight,
        bias,
        fp8_meta_tensor.scale,
        fp8_meta_tensor.amax_history,
        fp8_meta_tensor.scale_inv,
        eps,
        fp8_tensor.value,
        int(otype),
        sm_margin,
        zero_centered_gamma,
    )
    return out, mu, rsigma


def layernorm_fwd(
    inp: paddle.Tensor,
    weight: paddle.Tensor,
    bias: paddle.Tensor,
    eps: float,
    otype: tex.DType,
    sm_margin: int = 0,
    zero_centered_gamma: bool = False,
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
    """LayerNorm forward without FP8 output.

    Delegates to ``tex.te_layernorm_fwd``; ``otype`` is converted to an int
    and every other argument is passed through unchanged.

    Returns:
        Whatever ``tex.te_layernorm_fwd`` returns.
    """
    otype_id = int(otype)
    return tex.te_layernorm_fwd(
        inp,
        weight,
        bias,
        eps,
        otype_id,
        sm_margin,
        zero_centered_gamma,
    )


def layernorm_bwd(
    dz: paddle.Tensor,
    x: paddle.Tensor,
    mu: paddle.Tensor,
    rsigma: paddle.Tensor,
    gamma: paddle.Tensor,
    sm_margin: int = 0,
    zero_centered_gamma: bool = False,
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
    """LayerNorm backward (non-FP8).

    Args:
        dz: gradient of the LayerNorm output.
        x: LayerNorm input.
        mu, rsigma: statistics saved by the forward pass.
        gamma: LayerNorm weight.
        sm_margin, zero_centered_gamma: forwarded to the kernel.

    Returns:
        Whatever ``tex.te_layernorm_bwd`` returns.
    """
    kernel = tex.te_layernorm_bwd
    return kernel(
        dz,
        x,
        mu,
        rsigma,
        gamma,
        sm_margin,
        zero_centered_gamma,
    )
Shijie's avatar
Shijie committed
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381


def rmsnorm_fwd(
    inp: paddle.Tensor,
    weight: paddle.Tensor,
    eps: float,
    otype: tex.DType,
    sm_margin: int = 0,
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
    """RMSNorm forward without FP8 output.

    Args:
        inp: input tensor.
        weight: RMSNorm weight.
        eps: epsilon forwarded to the kernel.
        otype: TE output dtype, converted to an int.
        sm_margin: forwarded to the kernel.

    Returns:
        Whatever ``tex.te_rmsnorm_fwd`` returns.
    """
    otype_id = int(otype)
    return tex.te_rmsnorm_fwd(
        inp,
        weight,
        eps,
        otype_id,
        sm_margin,
    )


def rmsnorm_fwd_fp8(
    inp: paddle.Tensor,
    weight: paddle.Tensor,
    eps: float,
    fp8_meta_tensor: FP8TensorMeta,
    fp8_tensor: Union[FP8FwdTensors, FP8BwdTensors],
    otype: tex.DType,
    sm_margin: int = 0,
) -> Tuple[paddle.Tensor, paddle.Tensor]:
    """RMSNorm with FP8 output.

    Args:
        inp: input tensor.
        weight: RMSNorm weight.
        eps: epsilon forwarded to the kernel.
        fp8_meta_tensor: supplies the ``scale``, ``amax_history`` and
            ``scale_inv`` buffers passed to the extension.
        fp8_tensor: its ``.value`` is passed as the metadata index.
        otype: target TE dtype, passed as an int.
        sm_margin: forwarded to the kernel.

    Returns:
        ``(out, rsigma)`` - the first two of the four outputs of
        ``tex.te_rmsnorm_fwd_fp8``.
    """
    out, rsigma, _, _ = tex.te_rmsnorm_fwd_fp8(
        inp,
        weight,
        fp8_meta_tensor.scale,
        fp8_meta_tensor.amax_history,
        fp8_meta_tensor.scale_inv,
        eps,
        fp8_tensor.value,
        int(otype),
        sm_margin,
    )
    return out, rsigma


def rmsnorm_bwd(
    dz: paddle.Tensor,
    x: paddle.Tensor,
    rsigma: paddle.Tensor,
    gamma: paddle.Tensor,
    sm_margin: int = 0,
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
    """RMSNorm backward (non-FP8).

    Args:
        dz: gradient of the RMSNorm output.
        x: RMSNorm input.
        rsigma: statistic saved by the forward pass.
        gamma: RMSNorm weight.
        sm_margin: forwarded to the kernel.

    Returns:
        Whatever ``tex.te_rmsnorm_bwd`` returns.
    """
    kernel = tex.te_rmsnorm_bwd
    return kernel(
        dz,
        x,
        rsigma,
        gamma,
        sm_margin,
    )


406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
def mask_to_cu_seqlens(
    mask: paddle.Tensor,
    need_kv: bool = False,
) -> Tuple[paddle.Tensor, Optional[paddle.Tensor]]:
    """Convert mask to cu_seqlens.

    Args:
        mask: attention mask of shape ``[b, 1, s_q, s_kv]``.
        need_kv: when True, also build ``kv_cu_seqlens``.

    Returns:
        ``(q_cu_seqlens, kv_cu_seqlens)``: int32 tensors of length ``b + 1``
        whose first element is set to 0 here before ``tex.mask_to_cu_seqlens``
        is called (presumably the extension fills the remaining entries -
        confirm). ``kv_cu_seqlens`` is None unless ``need_kv``.
    """
    batch_size = mask.shape[0]
    q_seqlen, kv_seqlen = mask.shape[2], mask.shape[3]
    q_cu_seqlens = paddle.empty(shape=[batch_size + 1], dtype=paddle.int32)
    q_cu_seqlens[0] = 0
    kv_cu_seqlens = None
    if need_kv:
        kv_cu_seqlens = paddle.empty(shape=[batch_size + 1], dtype=paddle.int32)
        kv_cu_seqlens[0] = 0
    tex.mask_to_cu_seqlens(mask, q_cu_seqlens, kv_cu_seqlens, q_seqlen, kv_seqlen, need_kv)
    return q_cu_seqlens, kv_cu_seqlens


Shijie's avatar
Shijie committed
423
424
425
426
427
428
def fused_attn_fwd_qkvpacked(
    qkv: paddle.Tensor,
    cu_seqlens: paddle.Tensor,
    is_training: bool,
    max_seqlen: int,
    qkv_dtype: tex.DType,
    fused_attention_backend: tex.NVTE_Fused_Attn_Backend,
    Bias: Optional[paddle.Tensor] = None,
    attn_scale: Optional[float] = None,
    dropout: float = 0.0,
    set_zero: bool = True,
    qkv_layout: str = "bs3hd",
    bias_type: str = "no_bias",
    attn_mask_type: str = "padding",
) -> Tuple[paddle.Tensor, Optional[paddle.Tensor], paddle.Tensor]:
    """Fused Attention FWD for packed QKV input.

    Args:
        qkv: packed QKV tensor; dims 3 and 4 are read as the number of heads
            ``h`` and the head size ``d`` (consistent with the default
            "bs3hd" layout).
        cu_seqlens: cumulative sequence lengths; its length is ``b + 1``.
        is_training: when True, a ``softmax_aux`` buffer is allocated.
        max_seqlen: maximum sequence length.
        qkv_dtype: TE dtype of ``qkv``; must be kBFloat16 or kFloat16.
        fused_attention_backend: F16_max512_seqlen or F16_arbitrary_seqlen.
        Bias: required when ``bias_type`` is not "no_bias"; must have shape
            ``[1, h, max_seqlen, max_seqlen]`` and the dtype of ``qkv``.
        attn_scale: defaults to ``1 / sqrt(d)``.
        dropout: forwarded to the kernel.
        set_zero: zero-fill ``out`` instead of leaving it uninitialized.
        qkv_layout, bias_type, attn_mask_type: forwarded to the kernel.

    Returns:
        ``(out, softmax_aux, rng_state)``: ``out`` has shape
        ``[b, max_seqlen, h, d]``; ``softmax_aux`` is None when not training;
        ``rng_state`` is an int64 tensor of shape ``[2]``.

    Raises:
        ValueError: if ``fused_attention_backend`` is not one of the two
            supported F16 backends.
    """

    assert (qkv_dtype in (tex.DType.kBFloat16,
                          tex.DType.kFloat16)), "Only support bf16/fp16 for fused attention."

    b = cu_seqlens.shape[0] - 1
    total_seqs = qkv.shape[0] * qkv.shape[1]
    h = qkv.shape[3]
    d = qkv.shape[4]

    if attn_scale is None:
        attn_scale = 1.0 / math.sqrt(d)

    if bias_type != "no_bias":
        assert Bias is not None, "bias tensor cannot be None when bias_type is not no_bias."
        assert (Bias.shape == [1, h, max_seqlen, max_seqlen
                              ]), "bias tensor must be in [1, h, max_seqlen, max_seqlen] shape."
        assert (Bias.dtype == qkv.dtype), "bias tensor must be in the same dtype as qkv."

    assert (fused_attention_backend != FusedAttnBackend["No_Backend"]
           ), "Fused attention does not support this input combination."

    if fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]:
        # BF16/FP16 fused attention API from fmha_v1 apex:
        # ceil(max_seqlen * max_seqlen / THREADS_PER_CTA)
        rng_elts_per_thread = (max_seqlen * max_seqlen + BACKEND_F16m512_THREADS_PER_CTA -
                               1) // BACKEND_F16m512_THREADS_PER_CTA
    elif fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]:
        # BF16/FP16 fused attention API from fmha_v2
        rng_elts_per_thread = BACKEND_F16arb_ELTS_PER_THREADS
    else:
        # Previously this left rng_elts_per_thread unbound (NameError later).
        raise ValueError("Unsupported fused attention backend.")

    if set_zero:
        out = paddle.full(shape=[b, max_seqlen, h, d], fill_value=0, dtype=qkv.dtype)
    else:
        out = paddle.empty(shape=[b, max_seqlen, h, d], dtype=qkv.dtype)

    if not is_training:
        softmax_aux = None
    elif fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]:
        softmax_aux = paddle.empty(shape=[b, h, max_seqlen, max_seqlen], dtype=qkv.dtype)
    else:
        # F16_arbitrary_seqlen (the only other backend accepted above).
        softmax_aux = paddle.empty(shape=[b, h, max_seqlen, 1], dtype='float32')

    rng_state = paddle.empty(shape=[
        2,
    ], dtype=paddle.int64)

    # execute kernel
    tex.te_fused_attn_fwd_qkvpacked(
        qkv,
        cu_seqlens,
        Bias,
        out,
        softmax_aux,
        rng_state,
        b,
        h,
        d,
        total_seqs,
        max_seqlen,
        is_training,
        attn_scale,
        dropout,
        qkv_layout,
        bias_type,
        attn_mask_type,
        int(qkv_dtype),
        rng_elts_per_thread,
    )
    return out, softmax_aux, rng_state
Shijie's avatar
Shijie committed
511
512
513
514
515


def fused_attn_bwd_qkvpacked(
    qkv: paddle.Tensor,
    cu_seqlens: paddle.Tensor,
    rng_state: paddle.Tensor,
    o: paddle.Tensor,
    d_o: paddle.Tensor,
    softmax_aux: paddle.Tensor,
    fused_attention_backend: tex.NVTE_Fused_Attn_Backend,
    max_seqlen: int,
    qkv_dtype: tex.DType,
    attn_scale: Optional[float] = None,
    dropout: float = 0.0,
    set_zero: bool = True,
    qkv_layout: str = "bs3hd",
    bias_type: str = "no_bias",
    attn_mask_type: str = "padding",
) -> Tuple[paddle.Tensor, Optional[paddle.Tensor]]:
    """Fused Attention BWD for packed QKV input.

    Args:
        qkv: packed QKV tensor; dims 3 and 4 are read as ``h`` and ``d``.
        cu_seqlens: cumulative sequence lengths; its length is ``b + 1``.
        rng_state: RNG state returned by the forward pass.
        o, d_o: forward output and its gradient.
        softmax_aux: auxiliary tensor saved by the forward pass.
        fused_attention_backend: must not be ``No_Backend``.
        max_seqlen: maximum sequence length.
        qkv_dtype: TE dtype of ``qkv``; must be kBFloat16 or kFloat16.
        attn_scale: defaults to ``1 / sqrt(d)``.
        dropout: forwarded to the kernel.
        set_zero: zero-fill ``dqkv`` instead of leaving it uninitialized.
        qkv_layout, bias_type, attn_mask_type: forwarded to the kernel.

    Returns:
        ``(dqkv, dbias)`` as returned by ``tex.te_fused_attn_bwd_qkvpacked``.
        A ``dbias`` buffer of shape ``[1, h, max_seqlen, max_seqlen]`` is
        passed in only when ``bias_type`` is not "no_bias" (None otherwise).
    """

    assert (qkv_dtype in (tex.DType.kBFloat16,
                          tex.DType.kFloat16)), "Only support bf16/fp16 for fused attention."

    b = cu_seqlens.shape[0] - 1
    total_seqs = qkv.shape[0] * qkv.shape[1]
    h = qkv.shape[3]
    d = qkv.shape[4]

    if attn_scale is None:
        attn_scale = 1.0 / math.sqrt(d)

    assert (fused_attention_backend != FusedAttnBackend["No_Backend"]
           ), "Fused attention does not support this input combination."

    if set_zero:
        dqkv = paddle.full(shape=qkv.shape, fill_value=0, dtype=qkv.dtype)
    else:
        dqkv = paddle.empty(shape=qkv.shape, dtype=qkv.dtype)

    if bias_type != "no_bias":
        dbias = paddle.empty(shape=[1, h, max_seqlen, max_seqlen], dtype=qkv.dtype)
    else:
        dbias = None

    # execute kernel
    dqkv, dbias = tex.te_fused_attn_bwd_qkvpacked(
        qkv,
        cu_seqlens,
        o,
        d_o,
        softmax_aux,
        dqkv,
        dbias,
        rng_state,
        b,
        h,
        d,
        total_seqs,
        max_seqlen,
        attn_scale,
        dropout,
        qkv_layout,
        bias_type,
        attn_mask_type,
        int(qkv_dtype),
    )

    return dqkv, dbias


def fused_attn_fwd_kvpacked(
    q: paddle.Tensor,
    kv: paddle.Tensor,
    cu_seqlens_q: paddle.Tensor,
    cu_seqlens_kv: paddle.Tensor,
    is_training: bool,
    max_seqlen_q: int,
    max_seqlen_kv: int,
    qkv_dtype: tex.DType,
    fused_attention_backend: tex.NVTE_Fused_Attn_Backend,
    Bias: Optional[paddle.Tensor] = None,
    attn_scale: Optional[float] = None,
    dropout: float = 0.0,
    set_zero: bool = True,
    qkv_layout: str = "bshd_bs2hd",
    bias_type: str = "no_bias",
    attn_mask_type: str = "padding",
) -> Tuple[paddle.Tensor, Optional[paddle.Tensor], paddle.Tensor]:
    """Fused Attention FWD for packed KV input.

    Args:
        q: query tensor; dims 2 and 3 are read as the number of heads ``h``
            and the head size ``d``.
        kv: packed key/value tensor.
        cu_seqlens_q, cu_seqlens_kv: cumulative sequence lengths of equal
            shape; their length is ``b + 1``.
        is_training: when True, a ``softmax_aux`` buffer is allocated.
        max_seqlen_q, max_seqlen_kv: maximum query / key-value lengths.
        qkv_dtype: TE dtype of the inputs; must be kBFloat16 or kFloat16.
        fused_attention_backend: F16_max512_seqlen or F16_arbitrary_seqlen.
        Bias: required when ``bias_type`` is not "no_bias"; must have shape
            ``[1, h, max_seqlen_q, max_seqlen_kv]`` and the dtype of ``q``.
        attn_scale: defaults to ``1 / sqrt(d)``.
        dropout: forwarded to the kernel.
        set_zero: zero-fill ``out`` instead of leaving it uninitialized.
        qkv_layout, bias_type, attn_mask_type: forwarded to the kernel.

    Returns:
        ``(out, softmax_aux, rng_state)``: ``out`` has shape
        ``[b, max_seqlen_q, h, d]``; ``softmax_aux`` is None when not
        training; ``rng_state`` is an int64 tensor of shape ``[2]``.

    Raises:
        ValueError: if ``fused_attention_backend`` is not one of the two
            supported F16 backends.
    """

    assert (qkv_dtype in (tex.DType.kBFloat16,
                          tex.DType.kFloat16)), "Only support bf16/fp16 for fused attention."
    assert (cu_seqlens_q.shape == cu_seqlens_kv.shape
           ), "cu_seqlens_q and cu_seqlens_kv must have the same shape"

    b = cu_seqlens_q.shape[0] - 1
    total_seqs_q = q.shape[0] * q.shape[1]
    total_seqs_kv = kv.shape[0] * kv.shape[1]
    h = q.shape[2]
    d = q.shape[3]

    if attn_scale is None:
        attn_scale = 1.0 / math.sqrt(d)

    if bias_type != "no_bias":
        assert Bias is not None, "bias tensor cannot be None when bias_type is not no_bias."
        assert (Bias.shape == [1, h, max_seqlen_q, max_seqlen_kv]
               ), "bias tensor must be in [1, h, max_seqlen_q, max_seqlen_kv] shape."
        assert (Bias.dtype == q.dtype), "bias tensor must be in the same dtype as q and kv."

    assert (fused_attention_backend != FusedAttnBackend["No_Backend"]
           ), "Fused attention does not support this input combination."

    if fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]:
        # BF16/FP16 fused attention API from fmha_v1 apex:
        # ceil(max_seqlen_q * max_seqlen_kv / THREADS_PER_CTA)
        rng_elts_per_thread = (max_seqlen_q * max_seqlen_kv + BACKEND_F16m512_THREADS_PER_CTA -
                               1) // BACKEND_F16m512_THREADS_PER_CTA
    elif fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]:
        # BF16/FP16 fused attention API from fmha_v2
        rng_elts_per_thread = BACKEND_F16arb_ELTS_PER_THREADS
    else:
        # Previously this left rng_elts_per_thread unbound (NameError later).
        raise ValueError("Unsupported fused attention backend.")

    if set_zero:
        out = paddle.full(shape=[b, max_seqlen_q, h, d], fill_value=0, dtype=q.dtype)
    else:
        out = paddle.empty(shape=[b, max_seqlen_q, h, d], dtype=q.dtype)

    if not is_training:
        softmax_aux = None
    elif fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]:
        softmax_aux = paddle.empty(shape=[b, h, max_seqlen_q, max_seqlen_kv], dtype=q.dtype)
    else:
        # F16_arbitrary_seqlen (the only other backend accepted above).
        softmax_aux = paddle.empty(shape=[b, h, max_seqlen_q, 1], dtype='float32')

    rng_state = paddle.empty(shape=[
        2,
    ], dtype=paddle.int64)

    # execute kernel
    tex.te_fused_attn_fwd_kvpacked(
        q,
        kv,
        cu_seqlens_q,
        cu_seqlens_kv,
        Bias,
        out,
        softmax_aux,
        rng_state,
        b,
        h,
        d,
        total_seqs_q,
        total_seqs_kv,
        max_seqlen_q,
        max_seqlen_kv,
        is_training,
        attn_scale,
        dropout,
        qkv_layout,
        bias_type,
        attn_mask_type,
        int(qkv_dtype),
        rng_elts_per_thread,
    )

    return out, softmax_aux, rng_state
Shijie's avatar
Shijie committed
680
681
682
683
684
685
686


def fused_attn_bwd_kvpacked(
    q: paddle.Tensor,
    kv: paddle.Tensor,
    cu_seqlens_q: paddle.Tensor,
    cu_seqlens_kv: paddle.Tensor,
    rng_state: paddle.Tensor,
    o: paddle.Tensor,
    d_o: paddle.Tensor,
    softmax_aux: paddle.Tensor,
    fused_attention_backend: tex.NVTE_Fused_Attn_Backend,
    max_seqlen_q: int,
    max_seqlen_kv: int,
    qkv_dtype: tex.DType,
    attn_scale: Optional[float] = None,
    dropout: float = 0.0,
    set_zero: bool = True,
    qkv_layout: str = "bshd_bs2hd",
    bias_type: str = "no_bias",
    attn_mask_type: str = "padding",
) -> Tuple[paddle.Tensor, paddle.Tensor, Optional[paddle.Tensor]]:
    """Fused Attention BWD for packed KV input.

    Args:
        q: query tensor; dims 2 and 3 are read as ``h`` and ``d``.
        kv: packed key/value tensor.
        cu_seqlens_q, cu_seqlens_kv: cumulative sequence lengths of equal
            shape; their length is ``b + 1``.
        rng_state: RNG state returned by the forward pass.
        o, d_o: forward output and its gradient.
        softmax_aux: auxiliary tensor saved by the forward pass.
        fused_attention_backend: must not be ``No_Backend``.
        max_seqlen_q, max_seqlen_kv: maximum query / key-value lengths.
        qkv_dtype: TE dtype of the inputs; must be kBFloat16 or kFloat16.
        attn_scale: defaults to ``1 / sqrt(d)``.
        dropout: forwarded to the kernel.
        set_zero: zero-fill ``dq`` / ``dkv`` instead of leaving them
            uninitialized.
        qkv_layout, bias_type, attn_mask_type: forwarded to the kernel.

    Returns:
        ``(dq, dkv, dbias)`` - the buffers allocated here and handed to
        ``tex.te_fused_attn_bwd_kvpacked`` (the call's own return value is
        not used, so the kernel presumably fills them in place - confirm).
        ``dbias`` is None when ``bias_type`` is "no_bias".
    """

    assert (qkv_dtype in (tex.DType.kBFloat16,
                          tex.DType.kFloat16)), "Only support bf16/fp16 for fused attention."
    assert (cu_seqlens_q.shape == cu_seqlens_kv.shape
           ), "cu_seqlens_q and cu_seqlens_kv must have the same shape"

    b = cu_seqlens_q.shape[0] - 1
    total_seqs_q = q.shape[0] * q.shape[1]
    total_seqs_kv = kv.shape[0] * kv.shape[1]
    h = q.shape[2]
    d = q.shape[3]

    if attn_scale is None:
        attn_scale = 1.0 / math.sqrt(d)

    assert (fused_attention_backend != FusedAttnBackend["No_Backend"]
           ), "Fused attention does not support this input combination."

    if set_zero:
        dq = paddle.full(shape=q.shape, fill_value=0, dtype=q.dtype)
        dkv = paddle.full(shape=kv.shape, fill_value=0, dtype=kv.dtype)
    else:
        dq = paddle.empty(shape=q.shape, dtype=q.dtype)
        dkv = paddle.empty(shape=kv.shape, dtype=kv.dtype)

    if bias_type != "no_bias":
        dbias = paddle.empty(shape=[1, h, max_seqlen_q, max_seqlen_kv], dtype=q.dtype)
    else:
        dbias = None

    # execute kernel
    tex.te_fused_attn_bwd_kvpacked(
        q,
        kv,
        cu_seqlens_q,
        cu_seqlens_kv,
        o,
        d_o,
        softmax_aux,
        dq,
        dkv,
        dbias,
        rng_state,
        b,
        h,
        d,
        total_seqs_q,
        total_seqs_kv,
        max_seqlen_q,
        max_seqlen_kv,
        attn_scale,
        dropout,
        qkv_layout,
        bias_type,
        attn_mask_type,
        int(qkv_dtype),
    )
    return dq, dkv, dbias


def scaled_softmax_forward(
    inp: paddle.Tensor,
    scale_factor: float,
) -> paddle.Tensor:
    """Scaled softmax, forward pass.

    Args:
        inp: input tensor.
        scale_factor: scale forwarded to the extension kernel.

    Returns:
        The tensor produced by ``tex.te_scaled_softmax_forward``.
    """
    kernel = tex.te_scaled_softmax_forward
    result = kernel(inp, scale_factor)
    return result


def scaled_softmax_backward(
    out_grad: paddle.Tensor,
    softmax_results: paddle.Tensor,
    scale_factor: float,
) -> paddle.Tensor:
    """Scaled softmax, backward pass.

    The return value of ``tex.te_scaled_softmax_backward`` is discarded and
    ``out_grad`` itself is returned, so the kernel presumably writes the
    input gradient into ``out_grad`` in place (assumption - confirm).

    Args:
        out_grad: gradient of the softmax output.
        softmax_results: output saved by the forward pass.
        scale_factor: scale forwarded to the extension kernel.

    Returns:
        ``out_grad``.
    """
    kernel = tex.te_scaled_softmax_backward
    kernel(
        out_grad,
        softmax_results,
        scale_factor,
    )
    grad_input = out_grad
    return grad_input


def scaled_masked_softmax_forward(
    inp: paddle.Tensor,
    mask: paddle.Tensor,
    scale_factor: float,
) -> paddle.Tensor:
    """Scaled masked softmax, forward pass.

    Args:
        inp: input tensor.
        mask: mask tensor forwarded to the extension kernel.
        scale_factor: scale forwarded to the extension kernel.

    Returns:
        The tensor produced by ``tex.te_scaled_masked_softmax_forward``.
    """
    kernel = tex.te_scaled_masked_softmax_forward
    result = kernel(inp, mask, scale_factor)
    return result


def scaled_masked_softmax_backward(
    out_grad: paddle.Tensor,
    softmax_results: paddle.Tensor,
    scale_factor: float,
) -> paddle.Tensor:
    """Scaled masked softmax, backward pass.

    Calls the same extension kernel as the unmasked backward
    (``tex.te_scaled_softmax_backward``); no mask is passed. The call's return
    value is discarded and ``out_grad`` is returned, so the kernel presumably
    updates ``out_grad`` in place (assumption - confirm).

    Args:
        out_grad: gradient of the softmax output.
        softmax_results: output saved by the forward pass.
        scale_factor: scale forwarded to the extension kernel.

    Returns:
        ``out_grad``.
    """
    kernel = tex.te_scaled_softmax_backward
    kernel(
        out_grad,
        softmax_results,
        scale_factor,
    )
    grad_input = out_grad
    return grad_input


def scaled_upper_triang_masked_softmax_forward(
    inp: paddle.Tensor,
    scale_factor: float,
) -> paddle.Tensor:
    """Scaled upper-triangular masked softmax, forward pass.

    Args:
        inp: input tensor.
        scale_factor: scale forwarded to the extension kernel.

    Returns:
        The tensor produced by
        ``tex.te_scaled_upper_triang_masked_softmax_forward``.
    """
    kernel = tex.te_scaled_upper_triang_masked_softmax_forward
    result = kernel(inp, scale_factor)
    return result


def scaled_upper_triang_masked_softmax_backward(
    out_grad: paddle.Tensor,
    softmax_results: paddle.Tensor,
    scale_factor: float,
) -> paddle.Tensor:
    """Scaled upper-triangular masked softmax, backward pass.

    The return value of the extension call is discarded and ``out_grad`` is
    returned, so the kernel presumably updates ``out_grad`` in place
    (assumption - confirm).

    Args:
        out_grad: gradient of the softmax output.
        softmax_results: output saved by the forward pass.
        scale_factor: scale forwarded to the extension kernel.

    Returns:
        ``out_grad``.
    """
    kernel = tex.te_scaled_upper_triang_masked_softmax_backward
    kernel(
        out_grad,
        softmax_results,
        scale_factor,
    )
    grad_input = out_grad
    return grad_input