# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""TE FP8 extensions and GEMMs"""
import math
from typing import Optional, Tuple, List, Union
import torch
import transformer_engine_extensions as tex
from .constants import TE_DType

# Lookup table from TE's ``tex.DType`` enum to the torch dtype used for the
# corresponding tensors. Both FP8 formats are mapped to ``torch.uint8``.
TORCH_DType = {
    **{fp8_kind: torch.uint8
       for fp8_kind in (tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2)},
    tex.DType.kFloat16: torch.half,
    tex.DType.kBFloat16: torch.bfloat16,
    tex.DType.kFloat32: torch.float32,
    tex.DType.kInt32: torch.int32,
}

def check_tensor(x: torch.Tensor):
    """Assert that ``x`` is a CUDA tensor with contiguous memory layout."""
    on_gpu_and_dense = x.is_cuda and x.is_contiguous()
    assert on_gpu_and_dense, "Tensor should be a GPU tensor and contiguous."

def check_qkv(qkv: torch.Tensor, dtype: torch.dtype):
    """Check that ``qkv`` is a valid packed QKV tensor.

    Requires a contiguous CUDA tensor of dtype ``dtype`` with 4 dimensions and
    size 3 along dim 1, i.e. [total_seqs, 3, num_heads, head_dim].
    """
    check_tensor(qkv)
    # f-string so the expected dtype is actually interpolated into the message,
    # together with what was received, to make failures diagnosable.
    assert (qkv.dtype is dtype
            and qkv.dim() == 4
            and qkv.shape[1] == 3
            ), (f"QKV should be in [total_seqs, 3, num_heads, head_dim] shape "
                f"and {dtype} dtype; got shape {tuple(qkv.shape)} and {qkv.dtype} dtype.")

def check_q(q: torch.Tensor, dtype: torch.dtype):
    """Check that ``q`` is a valid Q tensor.

    Requires a contiguous CUDA tensor of dtype ``dtype`` with 3 dimensions,
    i.e. [total_seqs, num_heads, head_dim].
    """
    check_tensor(q)
    # f-string so the expected dtype is actually interpolated into the message.
    assert (q.dtype is dtype
            and q.dim() == 3
            ), (f"Q should be in [total_seqs, num_heads, head_dim] shape "
                f"and {dtype} dtype; got shape {tuple(q.shape)} and {q.dtype} dtype.")

def check_kv(kv: torch.Tensor, dtype: torch.dtype):
    """Check that ``kv`` is a valid packed KV tensor.

    Requires a contiguous CUDA tensor of dtype ``dtype`` with 4 dimensions and
    size 2 along dim 1, i.e. [total_seqs, 2, num_heads, head_dim].
    """
    check_tensor(kv)
    # f-string so the expected dtype is actually interpolated into the message.
    assert (kv.dtype is dtype
            and kv.dim() == 4
            and kv.shape[1] == 2
            ), (f"KV should be in [total_seqs, 2, num_heads, head_dim] shape "
                f"and {dtype} dtype; got shape {tuple(kv.shape)} and {kv.dtype} dtype.")

def check_o(o: torch.Tensor, dtype: torch.dtype):
    """Check that ``o`` is a valid O (or dO) tensor.

    Requires a contiguous CUDA tensor of dtype ``dtype`` with 3 dimensions,
    i.e. [total_seqs, num_heads, head_dim].
    """
    check_tensor(o)
    # f-string so the expected dtype is actually interpolated into the message.
    assert (o.dtype is dtype
            and o.dim() == 3
            ), (f"O and dO should be in [total_seqs, num_heads, head_dim] shape "
                f"and {dtype} dtype; got shape {tuple(o.shape)} and {o.dtype} dtype.")

def check_stats(stats: torch.Tensor, b: int, h: int, s: int):
    """Validate a softmax statistics tensor (M or ZInv).

    The tensor must be a contiguous CUDA float32 tensor of shape [b, h, s, 1].
    """
    check_tensor(stats)
    assert (
        stats.dtype is torch.float32
        and stats.dim() == 4
        and stats.shape == torch.Size([b, h, s, 1])
    ), ("M and ZInv should be in [batch_size, num_heads, max_seqlen_q, 1]\n"
        "    shape and float32 dtype.")

def check_cu_seqlens(cu_seqlens: torch.Tensor):
    """Validate the cumulative sequence-length tensor.

    It must be a contiguous CUDA int32 tensor with exactly one dimension.
    """
    check_tensor(cu_seqlens)
    is_int32 = cu_seqlens.dtype is torch.int32
    assert is_int32 and cu_seqlens.dim() == 1, \
        "cu_seqlens should be in [batch_size +1] shape and int32 dtype."

def check_scalar(scalar: torch.Tensor):
    """Validate a single-element float32 CUDA tensor (amax, scale or descale)."""
    check_tensor(scalar)
    single_fp32 = (
        scalar.dtype is torch.float32
        and scalar.dim() <= 1
        and scalar.numel() == 1
    )
    assert single_fp32, "amax/scale/descale tensors should be scalars in float32 dtype."

def check_rng_state(rng_state: torch.Tensor):
    """Validate the RNG state: a contiguous CUDA int64 tensor of two elements."""
    check_tensor(rng_state)
    has_two_int64 = rng_state.dtype is torch.int64 and rng_state.numel() == 2
    assert has_two_int64, "rng_state should be [seed, offset] and in int64 dtype."

def fused_attn_fwd_qkvpacked(
    is_training: bool,
    max_seqlen: int,
    cu_seqlens: torch.Tensor,
    qkv: torch.Tensor,
    qkv_dtype: tex.DType,
    bias: torch.Tensor = None,
    d_scale_qkv: torch.Tensor = None,
    q_scale_s: torch.Tensor = None,
    q_scale_o: torch.Tensor = None,
    amax_s: torch.Tensor = None,
    amax_o: torch.Tensor = None,
    attn_scale: float = None,
    dropout: float = 0.0,
    set_zero: bool = True,
    qkv_layout: str = "qkv_interleaved",
    bias_type: str = "no_bias",
    attn_mask_type: str = "padding",
    rng_gen: torch.Generator = None,
) -> Tuple[Union[torch.Tensor, None], ...]:
    """Fused Attention FWD for packed QKV input.

    Parameters
    ----------
    is_training: bool
                if True, runs training and produces auxiliary tensors aux_ctx_tensors
                for the backward; if False, runs inference and doesn't produce aux_ctx_tensors
    max_seqlen: int
                max sequence length for QKV, used for padding; may be larger than max(cu_seqlens)
    cu_seqlens: torch.Tensor
                accumulative sequence lengths for QKV; shape [batch_size + 1]
    qkv: torch.Tensor
                input tensor QKV;
                shape [total_seqs, 3, num_heads, head_dim], where total_seqs = cu_seqlens[-1]
    qkv_dtype: tex.DType
                data type of QKV; in tex.DType, not torch.dtype
    bias: torch.Tensor, default = None
                input tensor Bias when bias_type is "pre_scale_bias" or "post_scale_bias";
                shape [1, num_heads, max_seqlen, max_seqlen], same data type as qkv
    d_scale_qkv: torch.Tensor, default = None
                input tensor for the dequantization of QKV in FP8 computations
    q_scale_s: torch.Tensor, default = None
                input tensor for the quantization of S in FP8 computations, S = Softmax(Q * K.T)
    q_scale_o: torch.Tensor, default = None
                input tensor for the quantization of O in FP8 computations
    amax_s: torch.Tensor, default = None
                output tensor, amax of S, used by the next iteration in FP8 computations
    amax_o: torch.Tensor, default = None
                output tensor, amax of O, used by the next iteration in FP8 computations
    attn_scale: float, default = None
                if not None, use attn_scale as the attention scale for Q*K.T BMM;
                if None, use 1.0/sqrt(head_dim) as the default
    dropout: float, default = 0.0
                dropout probability, 0.0 means no dropout, 1.0 means no output;
                dropout must be 0.0 if is_training is False
    set_zero: bool, default = True
                if True, initializes the output tensor O to zero using the mha_fill method;
                if False, doesn't initialize O after its allocation
    qkv_layout: str, default = "qkv_interleaved"
                layout of QKV; {"qkv_interleaved", "kv_interleaved", "not_interleaved"}
    bias_type: str, default = "no_bias"
                type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias"}
    attn_mask_type: str, default = "padding"
                type of the attention mask; {"padding", "causal", "no_mask"}
    rng_gen: torch.Generator, default = None
                random number generator;
                if None, uses the default CUDA generator from PyTorch; otherwise, uses rng_gen

    Returns
    ----------
    o: torch.Tensor
                output tensor O, of the attention calculation; same data type as QKV;
                shape [total_seqs, num_heads, head_dim], where total_seqs = cu_seqlens[-1]
    aux_ctx_tensors: List[torch.Tensor]
                auxiliary output tensors used for the backward;
                if is_training is True, aux_ctx_tensors = [M, ZInv, rng_state]
                if is_training is False, aux_ctx_tensors = [rng_state]
                M: torch.Tensor
                    max(Q*K.T)
                    shape [batch_size, num_heads, max_seqlen, 1], dtype float32
                ZInv: torch.Tensor
                    1/sum(e^(x - max(x))), where x=Q*K.T
                    shape [batch_size, num_heads, max_seqlen, 1], dtype float32
                rng_state: torch.Tensor
                    state of the random number generator;
                    [seed, offset], dtype int64 (the backward validates it with check_rng_state)
    """

    check_cu_seqlens(cu_seqlens)
    b = cu_seqlens.numel() - 1
    qkv_type = TORCH_DType[qkv_dtype]
    check_qkv(qkv, qkv_type)

    total_seqs = qkv.size(0)
    h = qkv.size(2)
    d = qkv.size(3)

    if attn_scale is None:
        attn_scale = 1.0 / math.sqrt(d)

    if bias_type != "no_bias":
        assert bias is not None, "bias tensor cannot be None when bias_type is not no_bias."
        # torch.Size is a tuple subclass and never compares equal to a list,
        # so the expected shape must be compared as a tuple.
        assert (tuple(bias.shape) == (1, h, max_seqlen, max_seqlen)
               ), "bias tensor must be in [1, h, max_seqlen, max_seqlen] shape."
        assert (bias.dtype == qkv.dtype
               ), "bias tensor must be in the same dtype as qkv."

    # FP8 fused attention API
    if (qkv_type is torch.uint8) and (max_seqlen <= 512) and (d == 64):
        assert (qkv_layout == "qkv_interleaved"
                and bias_type == "no_bias"
                and attn_mask_type == "padding"
                ), """The FP8 fused attention API currently only supports qkv_interleaved layout,
                no_bias type, and padding attention mask type."""
        assert (d_scale_qkv is not None), "d_scale_qkv is required for the FP8 API."
        assert (q_scale_s is not None), "q_scale_s is required for the FP8 API."
        assert (q_scale_o is not None), "q_scale_o is required for the FP8 API."
        assert (amax_s is not None), "amax_s is required for the FP8 API."
        assert (amax_o is not None), "amax_o is required for the FP8 API."
        check_scalar(d_scale_qkv)
        check_scalar(q_scale_s)
        check_scalar(q_scale_o)
        check_scalar(amax_s)
        check_scalar(amax_o)

    # BF16/FP16 fused attention API from fmha_v2
    elif (qkv_type is torch.bfloat16 or qkv_type is torch.float16) and (max_seqlen > 512):
        # add BF/FP16 support for >512 sequence length
        assert False, "The BF16/FP16 support for >512 sequence length is coming!"

    # BF16/FP16 fused attention API from fmha_v1 apex
    elif (qkv_type is torch.bfloat16 or qkv_type is torch.float16) and (max_seqlen <= 512):
        # add BF/FP16 support for <=512 sequence length
        assert False, "The BF16/FP16 support for <=512 sequence length is coming!"

    else:
        assert False, "No support for this dtype and max_seqlen combination."

    # execute kernel
    output_tensors = tex.fused_attn_fwd_qkvpacked(
            b, max_seqlen, total_seqs, h, d,
            is_training, attn_scale, dropout, set_zero, qkv_layout, bias_type, attn_mask_type,
            cu_seqlens,
            qkv,
            qkv_dtype,
            d_scale_qkv,
            q_scale_s,
            q_scale_o,
            amax_s,
            amax_o,
            bias,
            rng_gen,
    )

    # first output is O; the rest are the auxiliary tensors for the backward
    return output_tensors[0], output_tensors[1:]


def fused_attn_bwd_qkvpacked(
    max_seqlen: int,
    cu_seqlens: torch.Tensor,
    qkv: torch.Tensor,
    o: torch.Tensor,
    d_o: torch.Tensor,
    qkv_dtype: tex.DType,
    aux_ctx_tensors: List[torch.Tensor] = None,
    d_scale_qkv: torch.Tensor = None,
    d_scale_s: torch.Tensor = None,
    d_scale_o: torch.Tensor = None,
    d_scale_do: torch.Tensor = None,
    q_scale_s: torch.Tensor = None,
    q_scale_dp: torch.Tensor = None,
    q_scale_dqkv: torch.Tensor = None,
    amax_dp: torch.Tensor = None,
    amax_dqkv: torch.Tensor = None,
    attn_scale: float = None,
    dropout: float = 0.0,
    set_zero: bool = True,
    qkv_layout: str = "qkv_interleaved",
    bias_type: str = "no_bias",
    attn_mask_type: str = "padding",
) -> Tuple[Union[torch.Tensor, None], ...]:
    """Fused Attention BWD for packed QKV input.

    Parameters
    ----------
    max_seqlen: int
                max sequence length for QKV, used for padding; may be larger than max(cu_seqlens)
    cu_seqlens: torch.Tensor
                accumulative sequence lengths for QKV; shape [batch_size + 1]
    qkv: torch.Tensor
                input tensor QKV;
                shape [total_seqs, 3, num_heads, head_dim], where total_seqs = cu_seqlens[-1]
    o: torch.Tensor
                input tensor O (output of forward);
                shape [total_seqs, num_heads, head_dim], where total_seqs = cu_seqlens[-1]
    d_o: torch.Tensor
                input tensor dO (gradient of O);
                shape [total_seqs, num_heads, head_dim], where total_seqs = cu_seqlens[-1]
    qkv_dtype: tex.DType
                data type of QKV; in tex.DType, not torch.dtype
    aux_ctx_tensors: List[torch.Tensor], default = None
                auxiliary output tensors of the forward pass when its is_training is True,
                e.g. aux_ctx_tensors = [M, ZInv, rng_state]; required (must not be None),
                and its last element must be rng_state
    d_scale_qkv: torch.Tensor, default = None
                input tensor for the dequantization of QKV in FP8 computations
    d_scale_s: torch.Tensor, default = None
                input tensor for the dequantization of S in FP8 computations, S = Softmax(Q * K.T)
    d_scale_o: torch.Tensor, default = None
                input tensor for the dequantization of O in FP8 computations
    d_scale_do: torch.Tensor, default = None
                input tensor for the dequantization of dO in FP8 computations
    q_scale_s: torch.Tensor, default = None
                input tensor for the quantization of S in FP8 computations
    q_scale_dp: torch.Tensor, default = None
                input tensor for the quantization of dP in FP8 computations, P = Q * K.T
    q_scale_dqkv: torch.Tensor, default = None
                input tensor for the quantization of dQKV in FP8 computations
    amax_dp: torch.Tensor, default = None
                output tensor, amax of dP, used by the next iteration in FP8 computations
    amax_dqkv: torch.Tensor, default = None
                output tensor, amax of dQKV, used by the next iteration in FP8 computations
    attn_scale: float, default = None
                if not None, use attn_scale as the attention scale for Q*K.T BMM;
                if None, use 1.0/sqrt(head_dim) as the default
    dropout: float, default = 0.0
                dropout probability, 0.0 means no dropout, 1.0 means no output;
                dropout must be 0.0 if is_training is False
    set_zero: bool, default = True
                if True, initializes the output tensor O to zero using the mha_fill method;
                if False, doesn't initialize O after its allocation
    qkv_layout: str, default = "qkv_interleaved"
                layout of QKV; {"qkv_interleaved", "kv_interleaved", "not_interleaved"}
    bias_type: str, default = "no_bias"
                type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias"}
    attn_mask_type: str, default = "padding"
                type of the attention mask; {"padding", "causal", "no_mask"}

    Returns
    ----------
    d_qkv: torch.Tensor
                gradient tensor of QKV; same data type and shape as QKV
    d_bias: torch.Tensor, optional
                gradient tensor of Bias when bias_type is "pre_scale_bias" or "post_scale_bias";
                same data type and shape as Bias
    """

    check_cu_seqlens(cu_seqlens)
    b = cu_seqlens.numel() - 1
    qkv_type = TORCH_DType[qkv_dtype]
    check_qkv(qkv, qkv_type)
    check_o(o, qkv_type)
    check_o(d_o, qkv_type)

    total_seqs = qkv.size(0)
    h = qkv.size(2)
    d = qkv.size(3)

    if attn_scale is None:
        attn_scale = 1.0 / math.sqrt(d)

    # Guard against the None default explicitly so a missing argument produces
    # this assertion rather than a TypeError from len(None).
    assert (aux_ctx_tensors is not None and len(aux_ctx_tensors) >= 1
            ), "aux_ctx_tensors must contain rng_state as its last element."
    rng_state = aux_ctx_tensors[-1]
    check_rng_state(rng_state)

    # FP8 fused attention API
    if (qkv_type is torch.uint8) and (max_seqlen <= 512) and d == 64:
        assert (qkv_layout == "qkv_interleaved"
                and bias_type == "no_bias"
                and attn_mask_type == "padding"
                ), """The FP8 fused attention API currently only supports qkv_interleaved layout,
                no_bias type, and padding attention mask type."""
        assert (d_scale_qkv is not None), "d_scale_qkv is required for the FP8 API."
        assert (d_scale_s is not None), "d_scale_s is required for the FP8 API."
        assert (d_scale_o is not None), "d_scale_o is required for the FP8 API."
        assert (d_scale_do is not None), "d_scale_do is required for the FP8 API."
        assert (q_scale_s is not None), "q_scale_s is required for the FP8 API."
        assert (q_scale_dp is not None), "q_scale_dp is required for the FP8 API."
        assert (q_scale_dqkv is not None), "q_scale_dqkv is required for the FP8 API."
        assert (amax_dp is not None), "amax_dp is required for the FP8 API."
        assert (amax_dqkv is not None), "amax_dqkv is required for the FP8 API."
        assert (len(aux_ctx_tensors) == 3
                ), "aux_ctx_tensors is required to be [M, ZInv, rng_state] for the FP8 API."
        check_scalar(d_scale_qkv)
        check_scalar(d_scale_s)
        check_scalar(d_scale_o)
        check_scalar(d_scale_do)
        check_scalar(q_scale_s)
        check_scalar(q_scale_dp)
        check_scalar(q_scale_dqkv)
        check_scalar(amax_dp)
        check_scalar(amax_dqkv)
        m, z_inv = aux_ctx_tensors[:2]
        check_stats(m, b, h, max_seqlen)
        check_stats(z_inv, b, h, max_seqlen)

    # BF16/FP16 fused attention API from fmha_v2
    elif (qkv_type is torch.bfloat16 or qkv_type is torch.float16) and (max_seqlen > 512):
        # add BF/FP16 support for >512 sequence length
        assert False, "The BF16/FP16 support for >512 sequence length is coming!"

    # BF16/FP16 fused attention API from fmha_v1 apex
    elif (qkv_type is torch.bfloat16 or qkv_type is torch.float16) and (max_seqlen <= 512):
        # add BF/FP16 support for <=512 sequence length
        assert False, "The BF16/FP16 support for <=512 sequence length is coming!"

    else:
        assert False, "No support for this dtype and max_seqlen combination."

    # execute kernel
    output_tensors = tex.fused_attn_bwd_qkvpacked(
            b, max_seqlen, total_seqs, h, d,
            attn_scale, dropout, set_zero, qkv_layout, bias_type, attn_mask_type,
            cu_seqlens,
            qkv, o, d_o,
            qkv_dtype,
            aux_ctx_tensors,
            d_scale_qkv, d_scale_s, d_scale_o, d_scale_do,
            q_scale_s, q_scale_dp, q_scale_dqkv,
            amax_dp, amax_dqkv,
    )

    if bias_type == "no_bias":
        # return d_qkv when bias_type is no_bias
        return output_tensors[0]
    # otherwise return (d_qkv, d_bias)
    return output_tensors


def fused_attn_fwd_kvpacked(
    is_training: bool,
    max_seqlen_q: int,
    max_seqlen_kv: int,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_kv: torch.Tensor,
    q: torch.Tensor,
    kv: torch.Tensor,
    qkv_dtype: tex.DType,
    bias: torch.Tensor = None,
    d_scale_qkv: torch.Tensor = None,
    q_scale_s: torch.Tensor = None,
    q_scale_o: torch.Tensor = None,
    amax_s: torch.Tensor = None,
    amax_o: torch.Tensor = None,
    attn_scale: float = None,
    dropout: float = 0.0,
    set_zero: bool = True,
    qkv_layout: str = "qkv_interleaved",
    bias_type: str = "no_bias",
    attn_mask_type: str = "padding",
    rng_gen: torch.Generator = None,
) -> Tuple[Union[torch.Tensor, None], ...]:
    """Fused Attention FWD for packed KV input.

    Parameters
    ----------
    is_training: bool
                if True, runs training and produces auxiliary tensors aux_ctx_tensors
                for the backward; if False, runs inference and doesn't produce aux_ctx_tensors
    max_seqlen_q: int
                max sequence length for Q, used for padding; may be larger than max(cu_seqlens_q)
    max_seqlen_kv: int
                max sequence length for KV, used for padding; may be larger than max(cu_seqlens_kv)
    cu_seqlens_q: torch.Tensor
                accumulative sequence lengths for Q; shape [batch_size + 1]
    cu_seqlens_kv: torch.Tensor
                accumulative sequence lengths for KV; shape [batch_size + 1]
    q: torch.Tensor
                input tensor Q;
                shape [total_seqs_q, num_heads, head_dim], where total_seqs_q = cu_seqlens_q[-1]
    kv: torch.Tensor
                packed input tensor KV;
                shape [total_seqs_kv, 2, num_heads, head_dim],
                where total_seqs_kv = cu_seqlens_kv[-1]
    qkv_dtype: tex.DType
                data type of Q and KV; in tex.DType, not torch.dtype
    bias: torch.Tensor, default = None
                input tensor Bias when bias_type is "pre_scale_bias" or "post_scale_bias";
                shape [1, num_heads, max_seqlen_q, max_seqlen_kv], same data type as q and kv
    d_scale_qkv: torch.Tensor, default = None
                input tensor for the dequantization of QKV in FP8 computations
    q_scale_s: torch.Tensor, default = None
                input tensor for the quantization of S in FP8 computations, S = Softmax(Q * K.T)
    q_scale_o: torch.Tensor, default = None
                input tensor for the quantization of O in FP8 computations
    amax_s: torch.Tensor, default = None
                output tensor, amax of S, used by the next iteration in FP8 computations
    amax_o: torch.Tensor, default = None
                output tensor, amax of O, used by the next iteration in FP8 computations
    attn_scale: float, default = None
                if not None, use attn_scale as the attention scale for Q*K.T BMM;
                if None, use 1.0/sqrt(head_dim) as the default
    dropout: float, default = 0.0
                dropout probability, 0.0 means no dropout, 1.0 means no output;
                dropout must be 0.0 if is_training is False
    set_zero: bool, default = True
                if True, initializes the output tensor O to zero using the mha_fill method;
                if False, doesn't initialize O after its allocation
    qkv_layout: str, default = "qkv_interleaved"
                layout of QKV; {"qkv_interleaved", "kv_interleaved", "not_interleaved"}
    bias_type: str, default = "no_bias"
                type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias"}
    attn_mask_type: str, default = "padding"
                type of the attention mask; {"padding", "causal", "no_mask"}
    rng_gen: torch.Generator, default = None
                random number generator;
                if None, uses the default CUDA generator from PyTorch; otherwise, uses rng_gen

    Returns
    ----------
    o: torch.Tensor
                output tensor O, of the attention calculation; same data type as Q and KV;
                shape [total_seqs_q, num_heads, head_dim], where total_seqs_q = cu_seqlens_q[-1]
    aux_ctx_tensors: List[torch.Tensor]
                auxiliary output tensors used for the backward;
                if is_training is True, aux_ctx_tensors = [M, ZInv, rng_state]
                if is_training is False, aux_ctx_tensors = [rng_state]
                M: torch.Tensor
                    max(Q*K.T)
                    shape [batch_size, num_heads, max_seqlen_q, 1], dtype float32
                ZInv: torch.Tensor
                    1/sum(e^(x - max(x))), where x=Q*K.T
                    shape [batch_size, num_heads, max_seqlen_q, 1], dtype float32
                rng_state: torch.Tensor
                    state of the random number generator;
                    [seed, offset], dtype int64 (validated by check_rng_state in the backward)
    """

    check_cu_seqlens(cu_seqlens_q)
    check_cu_seqlens(cu_seqlens_kv)
    assert (cu_seqlens_q.numel() == cu_seqlens_kv.numel()
            ), "cu_seqlens_q and cu_seqlens_kv must have the same length."
    b = cu_seqlens_q.numel() - 1
    qkv_type = TORCH_DType[qkv_dtype]
    check_q(q, qkv_type)
    check_kv(kv, qkv_type)

    assert (q.size(1) == kv.size(2)
            and q.size(2) == kv.size(3)
            ), "Q and KV must have the same num_heads and head_dim."
    total_seqs_q = q.size(0)
    total_seqs_kv = kv.size(0)
    h = q.size(1)
    d = q.size(2)

    if attn_scale is None:
        attn_scale = 1.0 / math.sqrt(d)

    if bias_type != "no_bias":
        assert bias is not None, "bias tensor cannot be None when bias_type is not no_bias."
        # torch.Size is a tuple subclass and never compares equal to a list,
        # so the expected shape must be compared as a tuple.
        assert (tuple(bias.shape) == (1, h, max_seqlen_q, max_seqlen_kv)
               ), "bias tensor must be in [1, h, max_seqlen_q, max_seqlen_kv] shape."
        assert (bias.dtype == q.dtype
               ), "bias tensor must be in the same dtype as q and kv."

    # FP8 fused attention API
    if (qkv_type is torch.uint8) and (max_seqlen_q <= 512) and (max_seqlen_kv <= 512) \
            and (d == 64):
        assert False, "The FP8 fused attention API currently only supports packed QKV input."

    # BF16/FP16 fused attention API from fmha_v2
    elif (qkv_type is torch.bfloat16 or qkv_type is torch.float16) \
            and (max_seqlen_q > 512) and (max_seqlen_kv > 512):
        # add BF/FP16 support for >512 sequence length
        assert False, "The BF16/FP16 support for >512 sequence length is coming!"

    # BF16/FP16 fused attention API from fmha_v1 apex
    elif (qkv_type is torch.bfloat16 or qkv_type is torch.float16) \
            and (max_seqlen_q <= 512) and (max_seqlen_kv <= 512):
        # add BF/FP16 support for <=512 sequence length
        assert False, "The BF16/FP16 support for <=512 sequence length is coming!"

    else:
        assert False, "No support for this dtype and max_seqlen combination."

    # execute kernel
    output_tensors = tex.fused_attn_fwd_kvpacked(
            b, max_seqlen_q, max_seqlen_kv, total_seqs_q, total_seqs_kv, h, d,
            is_training, attn_scale, dropout, set_zero, qkv_layout, bias_type, attn_mask_type,
            cu_seqlens_q, cu_seqlens_kv,
            q, kv,
            qkv_dtype,
            d_scale_qkv,
            q_scale_s,
            q_scale_o,
            amax_s,
            amax_o,
            bias,
            rng_gen,
    )

    # first output is O; the rest are the auxiliary tensors for the backward
    return output_tensors[0], output_tensors[1:]


def fused_attn_bwd_kvpacked(
    max_seqlen_q: int,
    max_seqlen_kv: int,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_kv: torch.Tensor,
    q: torch.Tensor,
    kv: torch.Tensor,
    o: torch.Tensor,
    d_o: torch.Tensor,
    qkv_dtype: tex.DType,
    aux_ctx_tensors: Optional[List[torch.Tensor]] = None,
    d_scale_qkv: torch.Tensor = None,
    d_scale_s: torch.Tensor = None,
    d_scale_o: torch.Tensor = None,
    d_scale_do: torch.Tensor = None,
    q_scale_s: torch.Tensor = None,
    q_scale_dp: torch.Tensor = None,
    q_scale_dqkv: torch.Tensor = None,
    amax_dp: torch.Tensor = None,
    amax_dqkv: torch.Tensor = None,
    attn_scale: float = None,
    dropout: float = 0.0,
    set_zero: bool = True,
    qkv_layout: str = "qkv_interleaved",
    bias_type: str = "no_bias",
    attn_mask_type: str = "padding",
) -> Tuple[Union[torch.Tensor, None], ...]:
    """Fused Attention BWD for packed KV input.

    Parameters
    ----------
    max_seqlen_q: int
                max sequence length for Q, used for padding; may be larger than max(cu_seqlens_q)
    max_seqlen_kv: int
                max sequence length for KV, used for padding; may be larger than max(cu_seqlens_kv)
    cu_seqlens_q: torch.Tensor
                accumulative sequence lengths for Q; shape [batch_size + 1]
    cu_seqlens_kv: torch.Tensor
                accumulative sequence lengths for KV; shape [batch_size + 1]
    q: torch.Tensor
                input tensor Q;
                shape [total_seqs_q, num_heads, head_dim], where total_seqs_q = cu_seqlens_q[-1]
    kv: torch.Tensor
                packed input tensor KV;
                shape [total_seqs_kv, 2, num_heads, head_dim],
                where total_seqs_kv = cu_seqlens_kv[-1]
    o: torch.Tensor
                input tensor O (output of forward);
                shape [total_seqs_q, num_heads, head_dim], where total_seqs_q = cu_seqlens_q[-1]
    d_o: torch.Tensor
                input tensor dO (gradient of O);
                shape [total_seqs_q, num_heads, head_dim], where total_seqs_q = cu_seqlens_q[-1]
    qkv_dtype: tex.DType
                data type of QKV; in tex.DType, not torch.dtype
    aux_ctx_tensors: List[torch.Tensor]
                auxiliary output tensors of the forward pass when its is_training is True,
                e.g. aux_ctx_tensors = [M, ZInv, rng_state];
                required despite the None default: it must be non-empty and its last
                element must be rng_state
    d_scale_qkv: torch.Tensor, default = None
                input tensor for the dequantization of QKV in FP8 computations
    d_scale_s: torch.Tensor, default = None
                input tensor for the dequantization of S in FP8 computations, S = Softmax(Q * K.T)
    d_scale_o: torch.Tensor, default = None
                input tensor for the dequantization of O in FP8 computations
    d_scale_do: torch.Tensor, default = None
                input tensor for the dequantization of dO in FP8 computations
    q_scale_s: torch.Tensor, default = None
                input tensor for the quantization of S in FP8 computations
    q_scale_dp: torch.Tensor, default = None
                input tensor for the quantization of dP in FP8 computations, P = Q * K.T
    q_scale_dqkv: torch.Tensor, default = None
                input tensor for the quantization of dQKV in FP8 computations
    amax_dp: torch.Tensor, default = None
                output tensor, amax of dP, used by the next iteration in FP8 computations,
                P = Q * K.T
    amax_dqkv: torch.Tensor, default = None
                output tensor, amax of dQKV, used by the next iteration in FP8 computations
    attn_scale: float, default = None
                if not None, use attn_scale as the attention scale for Q*K.T BMM;
                if None, use 1.0/sqrt(head_dim) as the default
    dropout: float, default = 0.0
                dropout probability, 0.0 means no dropout, 1.0 means no output;
                presumably must match the dropout used in the forward pass (TODO confirm)
    set_zero: bool, default = True
                if True, initializes the output tensor O to zero using the mha_fill method;
                if False, doesn't initialize O after its allocation
    qkv_layout: str, default = "qkv_interleaved"
                layout of QKV; {"qkv_interleaved", "kv_interleaved", "not_interleaved"}
    bias_type: str, default = "no_bias"
                type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias"}
    attn_mask_type: str, default = "padding"
                type of the attention mask; {"padding", "causal", "no_mask"}

    Returns
    ----------
    d_q: torch.Tensor
                gradient tensor of Q; same data type and shape as Q
    d_kv: torch.Tensor
                gradient tensor of KV; same data type and shape as KV
    d_bias: torch.Tensor, optional
                gradient tensor of Bias when bias_type is "pre_scale_bias" or "post_scale_bias";
                same data type and shape as Bias
    """

    check_cu_seqlens(cu_seqlens_q)
    check_cu_seqlens(cu_seqlens_kv)
    assert (cu_seqlens_q.numel() == cu_seqlens_kv.numel()
            ), "cu_seqlens_q and cu_seqlens_kv must have the same length."
    b = cu_seqlens_q.numel() - 1
    qkv_type = TORCH_DType[qkv_dtype]
    check_q(q, qkv_type)
    check_kv(kv, qkv_type)
    check_o(o, qkv_type)
    check_o(d_o, qkv_type)

    assert (q.size(1) == kv.size(2)
            and q.size(2) == kv.size(3)
            ), "Q and KV must have the same num_heads and head_dim."
    total_seqs_q = q.size(0)
    # Q and KV may have different token totals (e.g. cross-attention), so the
    # KV total must be taken from kv itself, not from q.
    total_seqs_kv = kv.size(0)
    h = q.size(1)
    d = q.size(2)

    if attn_scale is None:
        attn_scale = 1.0 / math.sqrt(d)

    # Guard against the None default so callers get the intended assertion
    # instead of a TypeError from len(None).
    assert (aux_ctx_tensors is not None and len(aux_ctx_tensors) >= 1
            ), "aux_ctx_tensors must contain rng_state as its last element."
    rng_state = aux_ctx_tensors[-1]
    check_rng_state(rng_state)

    # FP8 fused attention API
    if (qkv_type is torch.uint8) and (max_seqlen_q <= 512) and (max_seqlen_kv <= 512) \
            and d == 64:
        assert False, "The FP8 fused attention API currently only supports packed QKV input."

    # BF16/FP16 fused attention API from fmha_v2
    elif (qkv_type is torch.bfloat16 or qkv_type is torch.float16) \
            and (max_seqlen_q > 512) and (max_seqlen_kv > 512):
        # add BF/FP16 support for >512 sequence length
        assert False, "The BF16/FP16 support for >512 sequence length is coming!"

    # BF16/FP16 fused attention API from fmha_v1 apex
    elif (qkv_type is torch.bfloat16 or qkv_type is torch.float16) \
            and (max_seqlen_q <= 512) and (max_seqlen_kv <= 512):
        # add BF/FP16 support for <=512 sequence length
        assert False, "The BF16/FP16 support for <=512 sequence length is coming!"

    else:
        assert False, "No support for this dtype and max_seqlen combination."

    # execute kernel (only reachable when assertions are disabled, e.g. python -O,
    # since every branch above currently asserts)
    output_tensors = tex.fused_attn_bwd_kvpacked(
            b, max_seqlen_q, max_seqlen_kv, total_seqs_q, total_seqs_kv, h, d,
            attn_scale, dropout, set_zero, qkv_layout, bias_type, attn_mask_type,
            cu_seqlens_q, cu_seqlens_kv,
            q, kv, o, d_o,
            qkv_dtype,
            aux_ctx_tensors,
            d_scale_qkv, d_scale_s, d_scale_o, d_scale_do,
            q_scale_s, q_scale_dp, q_scale_dqkv,
            amax_dp, amax_dqkv,
    )

    # returns (d_q, d_kv) when bias_type is no_bias; otherwise returns (d_q, d_kv, d_bias)
    if bias_type == "no_bias":
        return output_tensors[:2]
    return output_tensors
Przemek Tredak's avatar
Przemek Tredak committed
752
753
754
755

def fp8_gemm(
    A: torch.Tensor,
    A_scale_inv: torch.Tensor,
    A_fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
    A_dtype: tex.DType,
    B: torch.Tensor,
    B_scale_inv: torch.Tensor,
    B_fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
    B_dtype: tex.DType,
    out_dtype: torch.dtype,
    workspace: torch.Tensor,
    gelu: bool = False,
    accumulate: bool = False,
    out: Optional[torch.Tensor] = None,
    out_index = None,
    fp8_meta_tensor: tex.FP8TensorMeta = None,
    bias: Optional[torch.Tensor] = None,
    use_bias: bool = False,
    use_split_accumulator: bool = False,
    D_dtype: Optional[tex.DType] = None,
    ub_algo: tex.UbufOverlapAlgo = None,
    ub: Union[tex.UbufCommOverlap, tex.UbufP2PCommOverlap] = None,
    extra_output_tensor: torch.Tensor = None,
) -> torch.Tensor:
    """TN layout GEMM with fp8 inputs.

    A is passed with transa=True and B with transb=False. When ``out`` is not
    given, a CUDA tensor of shape (B.shape[0], A.shape[0]) and dtype
    ``out_dtype`` is allocated.

    If ``D_dtype`` is an FP8 type, ``fp8_meta_tensor`` and ``out_index`` must be
    given. The output scale and amax are then read from
    ``fp8_meta_tensor.scale[out_index]`` and
    ``fp8_meta_tensor.amax_history[0][out_index]``.

    When ``gelu`` is True, a pre-GELU buffer shaped like ``out`` is allocated.
    Its dtype is the bias dtype, or bfloat16 when no bias is given.

    ``ub_algo``/``ub`` select a userbuffers comm-overlap entry point instead of
    the plain ``te_gemm_ts`` op.

    Returns:
        * ``(out, pre_gelu_out)`` if ``out`` was allocated here and gelu is on.
        * ``out`` if ``out`` was allocated here and gelu is off.
        * ``pre_gelu_out`` if the caller supplied ``out`` and gelu is on.
        * ``None`` otherwise.
    """
    empty = torch.Tensor()

    if D_dtype is not None and D_dtype in [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]:
        assert fp8_meta_tensor is not None and out_index is not None

    allocated_out = out is None
    if allocated_out:
        out = torch.empty(B.shape[0], A.shape[0], dtype=out_dtype, device="cuda")

    # No bias given -> fall back to bfloat16 for the bias / pre-GELU dtype.
    torch_bias_dtype = torch.bfloat16 if bias is None else bias.dtype
    pre_gelu_out = torch.empty_like(out, dtype=torch_bias_dtype) if gelu else empty
    te_bias_dtype = TE_DType[torch_bias_dtype]

    # An explicit D_dtype overrides the dtype inferred from the output tensor.
    te_out_dtype = D_dtype if D_dtype is not None else TE_DType[out.dtype]

    if out_index is None:
        out_scale = empty
        out_amax = empty
    else:
        out_scale = fp8_meta_tensor.scale[out_index]
        out_amax = fp8_meta_tensor.amax_history[0][out_index]

    args = (
        A, A_scale_inv, A_fp8_tensor, A_dtype,
        True,   # transa
        B, B_scale_inv, B_fp8_tensor, B_dtype,
        False,  # transb
        out,
        out_scale,
        te_out_dtype,
        out_amax,
        bias if use_bias else empty,
        te_bias_dtype,
        pre_gelu_out,  # pre_gelu_out
        False,  # grad
        workspace,
        workspace.shape[0],
        accumulate,
        use_split_accumulator,
    )

    gemm_fn = torch.ops.tex_ts.te_gemm_ts
    if ub_algo is not None:
        assert ub is not None, 'ub object is None!'
        if ub_algo == tex.UbufOverlapAlgo.BULK_OVERLAP_AG:
            gemm_fn = ub.bulk_overlap
            args = args + (1,)
        elif ub_algo == tex.UbufOverlapAlgo.BULK_OVERLAP_RS:
            gemm_fn = ub.bulk_overlap
            args = args + (0,)
        elif ub_algo == tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG:
            gemm_fn = ub.split_overlap_ag
            ag_extra = empty if extra_output_tensor is None else extra_output_tensor
            args = args + (ag_extra,)
        elif ub_algo == tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS:
            gemm_fn = ub.split_overlap_rs
            assert (
                extra_output_tensor is not None
            ), 'SPLIT_PIPELINED_RS requires extra output tensor'
            args = args + (True, extra_output_tensor)
    gemm_fn(*args)

    if allocated_out:
        return (out, pre_gelu_out) if gelu else out
    return pre_gelu_out if gelu else None


def gemm(
    A: torch.Tensor,
    B: torch.Tensor,
    dtype: torch.dtype,
    workspace: torch.Tensor,
    gelu: bool = False,
    gelu_input: Optional[torch.Tensor] = None,
    grad: bool = False,
    accumulate: bool = False,
    layout: str = "TN",
    out: Optional[torch.Tensor] = None,
    bias: Optional[torch.Tensor] = None,
    use_bias: bool = False,
    ub_algo: tex.UbufOverlapAlgo = None,
    ub: tex.UbufCommOverlap = None,
    extra_output_tensor: torch.Tensor = None,
) -> Tuple[Union[torch.Tensor, None], ...]:
    """Non FP8 GEMM.

    ``layout`` is one of "TN", "NN" or "NT". Its first and second characters
    give the transpose flags for A and B. A and B must both have dtype ``dtype``.

    When ``out`` is not given, a CUDA tensor of shape
    (B.shape[1] if transb else B.shape[0], A.shape[0] if transa else A.shape[1])
    is allocated.

    The pre-GELU buffer is handled as follows:
        * With gelu and not grad, it is freshly allocated.
        * Without gelu, it is an empty tensor.
        * With gelu and grad, the caller's ``gelu_input`` is passed through.

    When ``grad`` and ``use_bias`` are both set, a bias-gradient buffer of
    length B.shape[1] is allocated.

    Returns a 3-tuple ``(out_or_None, grad_bias, gelu_input)``. The first
    element is ``out`` only when it was allocated here.
    """
    assert layout in ("TN", "NN", "NT"), f"GEMM layout {layout} not supported."
    transa, transb = layout[0] == "T", layout[1] == "T"
    empty = torch.Tensor()
    dummy_fp8_index = -1  # no FP8 scaling in this path

    allocated_out = out is None
    if allocated_out:
        n_rows = B.shape[1] if transb else B.shape[0]
        n_cols = A.shape[0] if transa else A.shape[1]
        out = torch.empty(n_rows, n_cols, dtype=dtype, device="cuda")

    if not gelu:
        gelu_input = empty
    elif not grad:
        gelu_input = torch.empty_like(out, dtype=dtype)
    # gelu and grad: keep the caller-provided gelu_input untouched.

    grad_bias = (
        torch.empty(B.shape[1], dtype=out.dtype, device="cuda")
        if grad and use_bias
        else empty
    )

    if not use_bias:
        bias = empty

    assert A.dtype == dtype and B.dtype == dtype, \
        f'Expected dtype={dtype}, but found A.dtype={A.dtype} and B.dtype={B.dtype}'
    te_in_dtype = TE_DType[dtype]
    te_out_dtype = TE_DType[out.dtype]
    if not use_bias:
        te_bias_dtype = te_out_dtype
    else:
        te_bias_dtype = TE_DType[(grad_bias if grad else bias).dtype]

    args = (
        A, empty, dummy_fp8_index, te_in_dtype, transa,
        B, empty, dummy_fp8_index, te_in_dtype, transb,
        out,
        empty,  # out_scale
        te_out_dtype,
        empty,  # out_amax
        grad_bias if grad else bias,
        te_bias_dtype,
        gelu_input,
        grad,
        workspace,
        workspace.shape[0],
        accumulate,
        False,  # use_split_accumulator
    )

    gemm_fn = torch.ops.tex_ts.te_gemm_ts
    if ub_algo is not None:
        assert ub is not None, 'ub object is None!'
        if ub_algo == tex.UbufOverlapAlgo.BULK_OVERLAP_AG:
            gemm_fn = ub.bulk_overlap
            args = args + (1,)
        elif ub_algo == tex.UbufOverlapAlgo.BULK_OVERLAP_RS:
            gemm_fn = ub.bulk_overlap
            args = args + (0,)
        elif ub_algo == tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG:
            gemm_fn = ub.split_overlap_ag
            ag_extra = empty if extra_output_tensor is None else extra_output_tensor
            args = args + (ag_extra,)
        elif ub_algo == tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS:
            gemm_fn = ub.split_overlap_rs
            assert (
                extra_output_tensor is not None
            ), 'SPLIT_PIPELINED_RS requires extra output tensor'
            args = args + (False, extra_output_tensor)
    gemm_fn(*args)

    return (out if allocated_out else None), grad_bias, gelu_input


def fp8_cast_transpose_fused(
    inp: torch.Tensor,
    fp8_meta_tensor: tex.FP8TensorMeta,
    fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
    otype: tex.DType,
    cast_out: Optional[torch.Tensor] = None,
    transpose_out: Optional[torch.Tensor] = None,
) -> Union[Tuple[torch.Tensor, torch.Tensor], None]:
    """Cast + Transpose with FP8 output.

    Writes the FP8 cast of ``inp`` into ``cast_out`` and its transpose into
    ``transpose_out``. Scale, amax and scale_inv are taken from
    ``fp8_meta_tensor`` at index ``fp8_tensor``.

    Any output buffer not supplied by the caller is allocated as a uint8
    tensor. ``transpose_out`` gets shape (inp.shape[1], inp.shape[0]), so
    ``inp`` is assumed to be 2-D. A buffer the caller did supply is always
    used; it is never discarded.

    Returns:
        ``(cast_out, transpose_out)`` if at least one buffer was allocated
        here, otherwise ``None``.
    """
    return_outputs = cast_out is None or transpose_out is None

    # Allocate only what is missing so a caller-provided buffer is honored
    # even when the other one is omitted.
    if cast_out is None:
        cast_out = torch.empty_like(inp, dtype=torch.uint8)
    if transpose_out is None:
        transpose_out = torch.empty(
            inp.shape[1], inp.shape[0], device="cuda", dtype=torch.uint8
        )

    tex.fused_cast_transpose(
        inp,
        fp8_meta_tensor.scale[fp8_tensor],
        fp8_meta_tensor.amax_history[0][fp8_tensor],
        fp8_meta_tensor.scale_inv[fp8_tensor],
        cast_out,
        transpose_out,
        otype,
    )

    if return_outputs:
        return cast_out, transpose_out
    return None


def fp8_cast_transpose_bgrad_fused(
    inp: torch.Tensor,
    fp8_meta_tensor: tex.FP8TensorMeta,
    fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
    otype: tex.DType,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Cast + Transpose + BGRAD with FP8 output.

    Forwards ``inp`` to ``tex.fused_cast_transpose_bgrad``. It passes the
    scale, amax (from amax_history row 0) and scale_inv entries selected by
    ``fp8_tensor``, and returns the extension's result unchanged.
    """
    scale = fp8_meta_tensor.scale[fp8_tensor]
    amax = fp8_meta_tensor.amax_history[0][fp8_tensor]
    scale_inv = fp8_meta_tensor.scale_inv[fp8_tensor]
    return tex.fused_cast_transpose_bgrad(inp, scale, amax, scale_inv, otype)


1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
def fp8_transpose_bgrad_fused(
    inp: torch.Tensor,
    fp8_meta_tensor: tex.FP8TensorMeta,
    fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
    otype: tex.DType,
    grad_bias_type: torch.dtype,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Transpose + BGRAD with FP8 output.

    Forwards ``inp`` to ``tex.fused_fp8_transpose_bgrad``. It passes the
    scaling entries selected by ``fp8_tensor`` and the bias-gradient dtype,
    translated from torch.dtype into the extension's DType enum.
    """
    meta_scale = fp8_meta_tensor.scale[fp8_tensor]
    meta_amax = fp8_meta_tensor.amax_history[0][fp8_tensor]
    meta_scale_inv = fp8_meta_tensor.scale_inv[fp8_tensor]
    te_grad_bias_type = TE_DType[grad_bias_type]
    return tex.fused_fp8_transpose_bgrad(
        inp, meta_scale, meta_amax, meta_scale_inv, otype, te_grad_bias_type
    )


Przemek Tredak's avatar
Przemek Tredak committed
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
def fp8_cast_transpose_bgrad_dgelu_fused(
    grad_output: torch.Tensor,
    gelu_input: torch.Tensor,
    fp8_meta_tensor: tex.FP8TensorMeta,
    fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
    otype: tex.DType,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Cast + Transpose + BGRAD + DGELU with FP8 output.

    Thin wrapper over ``tex.fused_cast_transpose_bgrad_dgelu``. It selects the
    scale / amax / scale_inv entries for ``fp8_tensor`` and passes them along
    with the gradient and the saved GELU input.
    """
    scale = fp8_meta_tensor.scale[fp8_tensor]
    amax = fp8_meta_tensor.amax_history[0][fp8_tensor]
    scale_inv = fp8_meta_tensor.scale_inv[fp8_tensor]
    return tex.fused_cast_transpose_bgrad_dgelu(
        grad_output, gelu_input, scale, amax, scale_inv, otype
    )


def fp8_gelu(
    inp: torch.Tensor,
    fp8_meta_tensor: tex.FP8TensorMeta,
    fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
    otype: tex.DType,
) -> torch.Tensor:
    """GeLU with FP8 output.

    Calls the TorchScript op ``tex_ts.fp8_gelu_ts``. It passes the full
    scale / amax_history / scale_inv tensors plus the ``fp8_tensor`` index;
    the op does its own indexing.
    """
    gelu_op = torch.ops.tex_ts.fp8_gelu_ts
    meta = fp8_meta_tensor
    return gelu_op(
        inp,
        meta.scale,
        meta.amax_history,
        meta.scale_inv,
        fp8_tensor,
        otype,
    )


def layernorm_fwd_fp8(
    inp: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    eps: float,
    fp8_meta_tensor: tex.FP8TensorMeta,
    fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
    otype: tex.DType,
    sm_margin: int,
    zero_centered_gamma: bool,
    ln_out: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """LayerNorm with FP8 output.

    Without ``ln_out``, dispatches to ``tex.layernorm_fwd_fp8``, which
    allocates its own output. With ``ln_out``, dispatches to
    ``tex.layernorm_fwd_fp8_noalloc`` and writes into the caller's buffer.
    Either way the extension's return value is passed through unchanged.
    The scaling entries are those selected by ``fp8_tensor``.
    """
    scale = fp8_meta_tensor.scale[fp8_tensor]
    amax = fp8_meta_tensor.amax_history[0][fp8_tensor]
    scale_inv = fp8_meta_tensor.scale_inv[fp8_tensor]

    if ln_out is None:
        return tex.layernorm_fwd_fp8(
            inp, weight, bias, eps,
            scale, amax, scale_inv,
            otype, sm_margin, zero_centered_gamma
        )

    # Note the noalloc variant takes ln_out between scale and amax.
    return tex.layernorm_fwd_fp8_noalloc(
        inp, weight, bias, eps,
        scale, ln_out, amax, scale_inv,
        otype, sm_margin, zero_centered_gamma
    )


1109
1110
1111
1112
1113
1114
1115
1116
def layernorm_fwd_fp8_inf(
    inp: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    eps: float,
    fp8_meta_tensor: tex.FP8TensorMeta,
    fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
    otype: tex.DType,
    zero_centered_gamma,
) -> torch.Tensor:
    """LayerNorm with FP8 output.

    This version of layernorm_fwd_fp8 is specialized for inference, and returns
    only the normalized output. It delegates to the TorchScript op
    ``tex_ts.layernorm_fwd_fp8_inf_ts``, passing the whole scale / amax_history /
    scale_inv tensors plus the ``fp8_tensor`` index.
    """
    meta = fp8_meta_tensor
    ln_inf_op = torch.ops.tex_ts.layernorm_fwd_fp8_inf_ts
    return ln_inf_op(
        inp, weight, bias, eps,
        meta.scale, meta.amax_history, meta.scale_inv,
        fp8_tensor,
        otype,
        zero_centered_gamma,
    )


def layernorm_fwd_inf(
    inp: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    eps: float,
    zero_centered_gamma: bool,
) -> torch.Tensor:
    """LayerNorm forward for inference, without FP8 quantization of the output.

    Unlike ``layernorm_fwd_fp8_inf``, no FP8 scale/amax metadata or output
    dtype is passed. The TorchScript op ``tex_ts.layernorm_fwd_inf_ts``
    receives only the input, affine parameters, epsilon and the
    zero-centered-gamma flag. It returns the normalized output tensor.
    """
    return torch.ops.tex_ts.layernorm_fwd_inf_ts(
        inp,
        weight,
        bias,
        eps,
        zero_centered_gamma,
    )


Przemek Tredak's avatar
Przemek Tredak committed
1155
1156
1157
1158
1159
def cast_to_fp8(
    inp: torch.Tensor,
    fp8_meta_tensor: tex.FP8TensorMeta,
    fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
    otype: tex.DType,
    out: Optional[torch.Tensor] = None,
) -> Optional[torch.Tensor]:
    """Cast input to FP8.

    Without ``out``, calls the TorchScript op ``tex_ts.cast_to_fp8_ts`` and
    returns its newly produced tensor. With ``out``, calls
    ``tex.cast_to_fp8_noalloc`` to write into that buffer and returns None.
    """
    if out is None:
        cast_op = torch.ops.tex_ts.cast_to_fp8_ts
        return cast_op(
            inp,
            fp8_meta_tensor.scale,
            fp8_meta_tensor.amax_history,
            fp8_meta_tensor.scale_inv,
            fp8_tensor,
            otype,
        )

    # In-place path: the extension takes the per-tensor entries, not whole tensors.
    tex.cast_to_fp8_noalloc(
        inp,
        fp8_meta_tensor.scale[fp8_tensor],
        out,
        fp8_meta_tensor.amax_history[0][fp8_tensor],
        fp8_meta_tensor.scale_inv[fp8_tensor],
        otype
    )
    return None


def cast_from_fp8(
    inp: torch.Tensor,
    fp8_meta_tensor: tex.FP8TensorMeta,
    fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
    itype: tex.DType,
    otype: tex.DType,
) -> torch.Tensor:
    """Cast input from FP8.

    Delegates to the TorchScript op ``tex_ts.cast_from_fp8_ts``. Only the
    ``scale_inv`` tensor (plus the ``fp8_tensor`` index) is needed for
    dequantization from ``itype`` to ``otype``.
    """
    dequant_op = torch.ops.tex_ts.cast_from_fp8_ts
    scale_inv = fp8_meta_tensor.scale_inv
    return dequant_op(inp, scale_inv, fp8_tensor, itype, otype)