# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import os

import torch
import pytest

from transformer_engine.pytorch.utils import (
    init_method_normal,
    scaled_init_method_normal,
    get_device_compute_capability,
)
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch import TransformerLayer
from transformer_engine.pytorch.attention import DotProductAttention

from pkg_resources import packaging
from importlib.metadata import version
from test_numerics import get_dummy_cuda_rng_tracker, reset_rng_states
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
_flash_attn_version = packaging.version.Version(version("flash-attn"))
_flash_attn_2_available = _flash_attn_version >= packaging.version.Version("2")

class ModelConfig:
    def __init__(
        self, num_layers, hidden_size, num_attention_heads, head_dim, seq_len,
        dropout_p, attn_mask_type,
    ):
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        self.head_dim = head_dim
        assert (hidden_size == num_attention_heads * head_dim
                ), "hidden_size must equal num_attention_heads * head_dim."
        self.seq_len = seq_len
        self.dropout_p = dropout_p
        self.attn_mask_type = attn_mask_type

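# Model configurations: num_layers, hidden_size, num_attention_heads, head_dim,
# seq_len, dropout_p, attn_mask_type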
model_configs = {
    "test1": ModelConfig(1, 1024, 16, 64, 128, 0.0, "causal"),
    "test2": ModelConfig(1, 1024, 16, 64, 512, 0.0, "causal"),
    "test3": ModelConfig(1, 1024, 16, 64, 2048, 0.0, "causal"),
    "test4": ModelConfig(1, 2048, 16, 128, 128, 0.0, "causal"),
    "test5": ModelConfig(1, 2048, 16, 128, 512, 0.0, "causal"),
    "test6": ModelConfig(1, 2048, 16, 128, 2048, 0.0, "causal"),
    "test7": ModelConfig(1, 1024, 16, 64, 128, 0.0, "no_mask"),
    "test8": ModelConfig(1, 1024, 16, 64, 512, 0.0, "no_mask"),
}

param_types = [torch.float16]
if torch.cuda.is_bf16_supported():
    param_types.append(torch.bfloat16)

batch_sizes = [1, 2, 32]

@pytest.mark.skipif(
    get_device_compute_capability() < 8.0, reason="Compute capability 8.0+ is required.")
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("ckpt_attn", [True, False])
@pytest.mark.parametrize("bias_type", ["no_bias", "post_scale_bias"])
def test_dot_product_attention(dtype, bs, model, ckpt_attn, bias_type):
    """Test DotProductAttention module with three backends:
    FlashAttention, FusedAttention and UnfusedDotProductAttention."""

    config = model_configs[model]

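    # FlashAttention is only compared in the no_bias case; with a bias,
    # FusedAttention is checked against the unfused reference only.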
    if bias_type == "no_bias":
        flash_attn_fwd, flash_attn_bwd = _run_dot_product_attention(
                dtype, bs, config, "FlashAttention", ckpt_attn, bias_type)
    fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(
            dtype, bs, config, "FusedAttention", ckpt_attn, bias_type)
    unfused_attn_fwd, unfused_attn_bwd = _run_dot_product_attention(
            dtype, bs, config, "UnfusedDotProductAttention", ckpt_attn, bias_type)

    atol, rtol = (2.5e-2, 2.5e-2) if dtype == torch.bfloat16 else (5e-3, 5e-3)
    if bias_type == "no_bias":
        assert torch.allclose(fused_attn_fwd, flash_attn_fwd, atol=atol, rtol=rtol)
        assert torch.allclose(fused_attn_bwd, flash_attn_bwd, atol=atol, rtol=rtol)
    assert torch.allclose(fused_attn_fwd, unfused_attn_fwd, atol=atol, rtol=rtol)
    assert torch.allclose(fused_attn_bwd, unfused_attn_bwd, atol=atol, rtol=rtol)

def _run_dot_product_attention(dtype, bs, config, backend, ckpt_attn, bias_type):

    reset_rng_states()
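    # Select the attention backend via environment variables; with both flags
    # set to "0", TE falls back to the unfused DotProductAttention path.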
    os.environ["NVTE_FLASH_ATTN"] = "0"
    os.environ["NVTE_FUSED_ATTN"] = "0"
    if backend == "FlashAttention":
        os.environ["NVTE_FLASH_ATTN"] = "1"
    if backend == "FusedAttention":
        os.environ["NVTE_FUSED_ATTN"] = "1"

    inp = torch.randn(
            config.seq_len, bs, 3, config.num_attention_heads, config.head_dim,
            dtype=dtype).cuda()
    inp.requires_grad = True
    seqlens = torch.empty(bs, dtype=torch.int32).cuda()
    seqlens.fill_(config.seq_len)
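    # cumulative sequence lengths [0, s, 2s, ...] (the cu_seqlens convention of
    # the varlen attention APIs)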
    cu_seqlens = torch.zeros(bs + 1, device=inp.device, dtype=torch.int32)
    cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)
    op_grad = torch.randn(
        config.seq_len, bs, config.num_attention_heads * config.head_dim,
        dtype=dtype).cuda()
    if bias_type != "no_bias":
        bias = torch.randn(1, config.num_attention_heads, config.seq_len, config.seq_len,
                dtype=dtype).cuda()
    else:
        bias = None

    block = (
         DotProductAttention(
                config.num_attention_heads,
                config.head_dim,
                attention_dropout=config.dropout_p,
                sequence_parallel=False,
                tp_size=1,
                get_rng_state_tracker=get_dummy_cuda_rng_tracker,
                tp_group=None,
                layer_number=1,
                attention_type="self"
        ).to(dtype=dtype).cuda()
    )

    q = inp[:, :, 0, :, :]
    k = inp[:, :, 1, :, :]
    v = inp[:, :, 2, :, :]
    op = block(q, k, v, attn_mask_type=config.attn_mask_type,
        checkpoint_core_attention=ckpt_attn,
        core_attention_bias_type=bias_type,
        core_attention_bias=bias)
    op.backward(op_grad)

    return op, inp.grad

@pytest.mark.skipif(
    get_device_compute_capability() < 8.0, reason="Compute capability 8.0+ is required.")
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("ckpt_attn", [False])
@pytest.mark.parametrize("bias_type", ["no_bias", "post_scale_bias"])
@pytest.mark.parametrize("fused_qkv_params", [True, False])
def test_transformer_layer(dtype, bs, model, ckpt_attn, bias_type, fused_qkv_params):
    """Test TransformerLayer module when its DotProductAttention is enabled with
    the FlashAttention, FusedAttention, or UnfusedDotProductAttention backend."""

    config = model_configs[model]

    if bias_type == "no_bias":
        flash_attn_fwd, flash_attn_bwd = _run_transformer_layer(
                dtype, bs, config, "FlashAttention", ckpt_attn, bias_type, fused_qkv_params)
    fused_attn_fwd, fused_attn_bwd = _run_transformer_layer(
            dtype, bs, config, "FusedAttention", ckpt_attn, bias_type, fused_qkv_params)
    unfused_attn_fwd, unfused_attn_bwd = _run_transformer_layer(
            dtype, bs, config, "UnfusedDotProductAttention", ckpt_attn, bias_type, fused_qkv_params)

    atol, rtol = (5e-1, 5e-2)
    if bias_type == "no_bias":
        assert torch.allclose(fused_attn_fwd, flash_attn_fwd, atol=atol, rtol=rtol)
        assert torch.allclose(fused_attn_bwd, flash_attn_bwd, atol=atol, rtol=rtol)
    assert torch.allclose(fused_attn_fwd, unfused_attn_fwd, atol=atol, rtol=rtol)
    assert torch.allclose(fused_attn_bwd, unfused_attn_bwd, atol=atol, rtol=rtol)

def _run_transformer_layer(dtype, bs, config, backend, ckpt_attn, bias_type, fused_qkv_params):

    reset_rng_states()
    os.environ["NVTE_FLASH_ATTN"] = "0"
    os.environ["NVTE_FUSED_ATTN"] = "0"
    if backend == "FlashAttention":
        os.environ["NVTE_FLASH_ATTN"] = "1"
    if backend == "FusedAttention":
        os.environ["NVTE_FUSED_ATTN"] = "1"

    inp = torch.randn(
            config.seq_len, bs, config.num_attention_heads * config.head_dim,
            dtype=dtype).cuda()
    inp.requires_grad = True
    seqlens = torch.empty(bs, dtype=torch.int32).cuda()
    seqlens.fill_(config.seq_len)
    cu_seqlens = torch.zeros(bs + 1, device=inp.device, dtype=torch.int32)
    cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)

    sigma = 0.02
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    layer_number = 1
    drop_path_rate = 0.0
    drop_path_rates = [
            rate.item() for rate in torch.linspace(0, drop_path_rate, config.num_layers)]
    if bias_type != "no_bias":
        bias = torch.randn(1, config.num_attention_heads, config.seq_len, config.seq_len,
                dtype=dtype).cuda()
    else:
        bias = None

    block = (
        TransformerLayer(
            config.hidden_size,
            4 * config.hidden_size,
            config.num_attention_heads,
            layernorm_epsilon=1e-5,
            hidden_dropout=0.0,
            attention_dropout=config.dropout_p,
            init_method=init_method,
            output_layer_init_method=output_layer_init_method,
            layer_number=layer_number,
            kv_channels=config.head_dim,
            tp_group=None,
            tp_size=1,
            params_dtype=dtype,
            get_rng_state_tracker=None,
            fuse_wgrad_accumulation=False,
            seq_length=config.seq_len,
            micro_batch_size=bs,
            sequence_parallel=False,
            apply_residual_connection_post_layernorm=False,
            output_layernorm=False,
            layer_type="encoder",
            drop_path_rate=drop_path_rates[layer_number - 1],
            set_parallel_mode=True,
            fuse_qkv_params=fused_qkv_params,
            zero_centered_gamma=False,
            qkv_weight_interleaved=False,
            ub_tp_comm_overlap=False,
            bias=True,
        )
        .to(dtype=dtype)
        .cuda()
    )

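    # run several forward/backward iterations; input gradients accumulate in inp.grad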
    num_iters = 10
    for _ in range(num_iters):
        op = block(inp, self_attn_mask_type=config.attn_mask_type,
            checkpoint_core_attention=ckpt_attn,
            core_attention_bias_type=bias_type,
            core_attention_bias=bias)
        loss = op.sum()
        loss.backward()

    return op, inp.grad

@pytest.mark.skipif(not _flash_attn_2_available, reason="FA2.0 is not available")
@pytest.mark.skipif(
    get_device_compute_capability() < 8.0, reason="Compute capability 8.0+ is required.")
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
def test_transformer_layer_gqa(dtype, bs, model):
    """Test TransformerLayer module with grouped-query attention,
    comparing the FlashAttention and UnfusedDotProductAttention backends."""

    config = model_configs[model]
    def find_factors(x):
        f = []
        for i in range(1, x + 1):
            if x % i == 0:
                f.append(i)
        return f

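    # every divisor of num_attention_heads is a valid number of query heads per GQA group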
    num_querys_per_gqa_group = find_factors(config.num_attention_heads)

    for num_q_per_gqa_group in num_querys_per_gqa_group:
        flash_attn_fwd, flash_attn_bwd = _run_transformer_layer_gqa(
                dtype, bs, config, "FlashAttention", num_q_per_gqa_group)
        unfused_attn_fwd, unfused_attn_bwd = _run_transformer_layer_gqa(
                dtype, bs, config, "UnfusedDotProductAttention", num_q_per_gqa_group)

        atol, rtol = 5e-1, 5e-2
        assert torch.allclose(flash_attn_fwd, unfused_attn_fwd, atol=atol, rtol=rtol)
        assert torch.allclose(flash_attn_bwd, unfused_attn_bwd, atol=atol, rtol=rtol)

def _run_transformer_layer_gqa(dtype, bs, config, backend, num_querys_per_gqa_group):

    reset_rng_states()
    os.environ["NVTE_FLASH_ATTN"] = "0"
    os.environ["NVTE_FUSED_ATTN"] = "0"
    if backend == "FlashAttention":
        os.environ["NVTE_FLASH_ATTN"] = "1"

    inp = torch.randn(
            config.seq_len, bs, config.num_attention_heads * config.head_dim,
            dtype=dtype).cuda()
    inp.requires_grad = True
    seqlens = torch.empty(bs, dtype=torch.int32).cuda()
    seqlens.fill_(config.seq_len)
    cu_seqlens = torch.zeros(bs + 1, device=inp.device, dtype=torch.int32)
    cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)
    op_grad = torch.randn(
        config.seq_len, bs, config.num_attention_heads * config.head_dim,
        dtype=dtype).cuda()

    sigma = 0.02
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    layer_number = 1
    drop_path_rate = 0.0
    drop_path_rates = [
            rate.item() for rate in torch.linspace(0, drop_path_rate, config.num_layers)]

    block = (
        TransformerLayer(
            config.hidden_size,
            4 * config.hidden_size,
            config.num_attention_heads,
            num_gqa_groups=config.num_attention_heads // num_querys_per_gqa_group,
            layernorm_epsilon=1e-5,
            hidden_dropout=0.0,
            attention_dropout=config.dropout_p,
            init_method=init_method,
            output_layer_init_method=output_layer_init_method,
            layer_number=layer_number,
            kv_channels=config.head_dim,
            tp_group=None,
            tp_size=1,
            params_dtype=dtype,
            get_rng_state_tracker=None,
            fuse_wgrad_accumulation=False,
            seq_length=config.seq_len,
            micro_batch_size=bs,
            sequence_parallel=False,
            apply_residual_connection_post_layernorm=False,
            output_layernorm=False,
            layer_type="encoder",
            drop_path_rate=drop_path_rates[layer_number - 1],
            set_parallel_mode=True,
            fuse_qkv_params=True,
            zero_centered_gamma=False,
            qkv_weight_interleaved=False,
            ub_tp_comm_overlap=False,
            bias=True,
        )
        .to(dtype=dtype)
        .cuda()
    )

    op = block(inp, self_attn_mask_type=config.attn_mask_type)
    op.backward(op_grad)

    return op, inp.grad

model_configs_fp8 = {
    "test1": ModelConfig(1, 1024, 16, 64, 512, 0.0, "no_mask"),
}
batch_sizes_fp8 = [1, 4]
param_types_fp8 = [torch.float16]

@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.parametrize("dtype", param_types_fp8)
@pytest.mark.parametrize("bs", batch_sizes_fp8)
@pytest.mark.parametrize("model", model_configs_fp8.keys())
def test_dpa_fp8(dtype, bs, model):
    """Test DotProductAttention with FP8, comparing the cpp_extensions
    fused_attn_fwd/bwd_qkvpacked path against the UnfusedDotProductAttention reference."""

    config = model_configs_fp8[model]

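    # _run_dpa_fp8 saves its tensors to disk (op_grad.pt, qkv.pt, ctx.pt, dqkv.pt);
    # the reference run reloads them so both paths see identical inputs and gradients.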
    fused_attn_fwd, fused_attn_bwd = _run_dpa_fp8(
            dtype, bs, config, "FusedAttention")
    unfused_attn_fwd, unfused_attn_bwd = _run_dpa_fp8_ref(
            dtype, bs, config, "UnfusedDotProductAttention")

    atol, rtol = (2.5e-2, 2.5e-2)
    assert torch.allclose(fused_attn_fwd, unfused_attn_fwd, atol=atol, rtol=rtol)
    assert torch.allclose(fused_attn_bwd, unfused_attn_bwd, atol=atol, rtol=rtol)

def _run_dpa_fp8(dtype, bs, config, backend):

    reset_rng_states()
    os.environ["NVTE_FLASH_ATTN"] = "0"
    os.environ["NVTE_FUSED_ATTN"] = "0"

    inp = 0.01 * torch.randn(
            bs * config.seq_len, config.num_attention_heads * config.head_dim,
            dtype=dtype).cuda()
    inp.requires_grad = True
    seqlens = torch.empty(bs, dtype=torch.int32).cuda()
    seqlens.fill_(config.seq_len)
    cu_seqlens = torch.zeros(bs + 1, device=inp.device, dtype=torch.int32)
    cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)
    op_grad = 0.01 * torch.randn(
        bs * config.seq_len, config.num_attention_heads * config.head_dim,
        dtype=dtype).cuda()
    torch.save(op_grad, 'op_grad.pt')

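    # delayed-scaling FP8 recipe with a single-entry amax history ("most_recent")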
    fp8_recipe = recipe.DelayedScaling(
        margin=0,
        interval=1,
        fp8_format=recipe.Format.HYBRID,
        amax_history_len=1,
        amax_compute_algo="most_recent",
    )

    dpa = DPA_FP8(config).to(dtype=torch.float16).cuda()
    with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
        op = dpa(inp, cu_seqlens, config.seq_len)
        op.backward(op_grad)

    context = torch.load("ctx.pt")
    dqkv = torch.load('dqkv.pt')
    return (context.view(bs, config.seq_len, -1).transpose(0,1),
        dqkv.view(bs, config.seq_len, 3, config.num_attention_heads, config.head_dim).transpose(0,1).contiguous())

def _run_dpa_fp8_ref(dtype, bs, config, backend):

    os.environ["NVTE_FLASH_ATTN"] = "0"
    os.environ["NVTE_FUSED_ATTN"] = "0"
    if backend == "FlashAttention":
        os.environ["NVTE_FLASH_ATTN"] = "1"
    if backend == "FusedAttention":
        os.environ["NVTE_FUSED_ATTN"] = "1"

    inp = torch.load('qkv.pt').cuda()
    inp.requires_grad = True
    seqlens = torch.empty(bs, dtype=torch.int32).cuda()
    seqlens.fill_(config.seq_len)
    cu_seqlens = torch.zeros(bs + 1, device=inp.device, dtype=torch.int32)
    cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)
    op_grad = torch.load('op_grad.pt').cuda().view(bs, config.seq_len, -1).transpose(0,1)

    block = (
         DotProductAttention(
                config.num_attention_heads,
                config.head_dim,
                attention_dropout=config.dropout_p,
                sequence_parallel=False,
                tp_size=1,
                get_rng_state_tracker=None,
                tp_group=None,
                layer_number=1,
                attention_type="self"
        ).to(dtype=dtype).cuda()
    )

    q = inp[:, :, 0, :, :]
    k = inp[:, :, 1, :, :]
    v = inp[:, :, 2, :, :]
    op = block(q, k, v, attn_mask_type=config.attn_mask_type)
    op.backward(op_grad)
    torch.save(op, 'ctx_ref.pt')
    torch.save(inp.grad, 'dqkv_ref.pt')

    return op, inp.grad

from torch.nn.parameter import Parameter
import transformer_engine.pytorch.cpp_extensions as ext
import transformer_engine_extensions as tex
import transformer_engine.pytorch.fp8 as fp8
from transformer_engine.pytorch import fp8_autocast
from transformer_engine.pytorch.module.base import TransformerEngineBaseModule, _prepare_backward
from transformer_engine.common import recipe
from typing import Union, Dict, Any, Tuple, List
from transformer_engine.pytorch.cpp_extensions.fused_attn import (
    fused_attn_fwd_qkvpacked,
    fused_attn_bwd_qkvpacked,
    FusedAttnBackend)

_CUBLASLT_WORKSPACE_SIZE_BYTES = 33_554_432  # 32MiB
_2X_ACC_FPROP = False
_2X_ACC_DGRAD = False
_2X_ACC_WGRAD = False

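# FP8 scaling-meta indices repurposed for the attention tensors
# (packed QKV, attention output O, softmax S, and their gradients).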
META_QKV  = tex.FP8FwdTensors.GEMM1_OUTPUT
META_O    = tex.FP8FwdTensors.GEMM2_INPUT
META_DO   = tex.FP8BwdTensors.GRAD_INPUT2
META_DQKV = tex.FP8BwdTensors.GRAD_OUTPUT1

META_S    = tex.FP8FwdTensors.GEMM3_WEIGHT
META_DS   = tex.FP8BwdTensors.GRAD_INPUT3

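# Minimal FP8 dot-product attention autograd function: an FP8 QKV-projection GEMM,
# the fused attention forward/backward kernels, and FP8 GEMMs for the QKV dgrad/wgrad.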
class _dpa_fp8(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        inp: torch.Tensor,
        qkv_weight: torch.Tensor,
        qkv_bias: torch.Tensor,
        cu_seqlens: torch.Tensor,
        num_attention_heads: int,
        p_dropout: float,
        max_s: int,
        fast_zero_fill: bool,
        fp8_meta: Dict[str, Any],
        workspace: torch.Tensor,
        is_training: bool,
    ) -> torch.Tensor:

        assert inp.dim() == 2
        in_features = qkv_weight.shape[-1]
        h = num_attention_heads
        d = in_features // h
        b = cu_seqlens.numel() - 1
        is_nl = False
        if b < 4 and b > 1:
            max_s = 512
            is_nl = True

        fp8_dtype_forward = fp8.get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)

        inputmat, inputmat_t = ext.fp8_cast_transpose_fused(
            inp,
            fp8_meta["scaling_fwd"],
            tex.FP8FwdTensors.GEMM1_INPUT,
            fp8_dtype_forward,
        )

        qkv_weight_fp8, qkv_weight_t_fp8 = ext.fp8_cast_transpose_fused(
            qkv_weight,
            fp8_meta["scaling_fwd"],
            tex.FP8FwdTensors.GEMM1_WEIGHT,
            fp8_dtype_forward,
        )

        M = None
        ZInv = None
        philox_unpacked = None

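        # QKV projection: FP8 GEMM that produces the packed QKV tensor directly in FP8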
        qkv_out = ext.fp8_gemm(
            qkv_weight_fp8,
            fp8_meta["scaling_fwd"].scale_inv,
            tex.FP8FwdTensors.GEMM1_WEIGHT,
            fp8_dtype_forward,
            inputmat,
            fp8_meta["scaling_fwd"].scale_inv,
            tex.FP8FwdTensors.GEMM1_INPUT,
            fp8_dtype_forward,
            torch.uint8,
            workspace,
            bias=qkv_bias,
            use_bias=True,
            out_index=META_QKV,
            fp8_meta_tensor=fp8_meta["scaling_fwd"],
            use_split_accumulator=_2X_ACC_FPROP,
            D_dtype=fp8_dtype_forward,
        )
        qkv_out = qkv_out.view(-1, 3, h, d)
        qkv_out_fp16 = ext.cast_from_fp8(qkv_out, fp8_meta["scaling_fwd"],
                META_QKV, fp8_dtype_forward,
                tex.DType.kFloat16).view(b, max_s, 3, h, d).transpose(0,1).contiguous()
        torch.save(qkv_out_fp16, 'qkv.pt')

        # FMHA
        context_, aux_ctx_tensors, *rest = fused_attn_fwd_qkvpacked(
                is_training,
                max_s,
                cu_seqlens,
                qkv_out,
                fp8_dtype_forward,
                FusedAttnBackend["FP8"],
                None,
                fp8_meta["scaling_fwd"].scale_inv[META_QKV],
                fp8_meta["scaling_fwd"].scale[META_S],
                fp8_meta["scaling_fwd"].scale[META_O],
                fp8_meta["scaling_fwd"].amax_history[0][META_S],
                fp8_meta["scaling_fwd"].amax_history[0][META_O],
                attn_scale=None,
                dropout=p_dropout,
                fast_zero_fill=fast_zero_fill,
                qkv_layout="qkv_interleaved",
                attn_bias_type="no_bias",
                attn_mask_type="padding",
                rng_gen=None,
                )
        M, ZInv, philox_unpacked = aux_ctx_tensors

        context = context_.view(-1, in_features)
        context_t = tex.fp8_transpose(context, fp8_dtype_forward)

        ctx.save_for_backward(
            inputmat_t, qkv_weight_t_fp8, workspace,
            qkv_out,
            context_, context_t,
            fp8_meta["scaling_fwd"].scale,
            fp8_meta["scaling_fwd"].scale_inv,
        )
        ctx.aux_ctx_tensors = aux_ctx_tensors
        ctx.fp8_meta = fp8_meta
        ctx.cu_seqlens = cu_seqlens
        ctx.p_dropout = p_dropout
        ctx.max_s = max_s
        ctx.fast_zero_fill = fast_zero_fill
        ctx.is_nl = is_nl
        ctx.hidden_size = in_features
        ctx.num_attention_heads = num_attention_heads

        context_fp16 = ext.cast_from_fp8(context, fp8_meta["scaling_fwd"],
                META_O, fp8_dtype_forward, tex.DType.kFloat16)
        torch.save(context_fp16, 'ctx.pt')
        return context_fp16


    @staticmethod
    def backward(
        ctx, grad_output: torch.Tensor
    ) -> Tuple[Union[torch.Tensor, None], ...]:

        with _prepare_backward(True, ctx.fp8_meta, None, 1, name="_DPA"):
            (
                inputmat_t,
                qkv_weight_t_fp8,
                workspace,
                qkv_out,
                context, context_t,
                fwd_scales,
                fwd_scale_inverses,
            ) = ctx.saved_tensors
            fp8_dtype_forward = fp8.get_fp8_te_dtype(
                ctx.fp8_meta["recipe"], fprop_tensor=True
            )
            fp8_dtype_backward = fp8.get_fp8_te_dtype(
                ctx.fp8_meta["recipe"], fprop_tensor=False
            )

            proj_dgrad = ext.cast_to_fp8(
                grad_output, ctx.fp8_meta["scaling_bwd"], META_DO, fp8_dtype_backward
            )

            dqkv, *rest = fused_attn_bwd_qkvpacked(
                    ctx.max_s,
                    ctx.cu_seqlens,
                    qkv_out,
                    context,
                    proj_dgrad.view_as(context),
                    fp8_dtype_forward,
                    ctx.aux_ctx_tensors,
                    FusedAttnBackend["FP8"],
                    fwd_scale_inverses[META_QKV], # d_scale_qkv,
                    fwd_scale_inverses[META_S], # d_scale_s,
                    fwd_scale_inverses[META_O], # d_scale_o,
                    ctx.fp8_meta['scaling_bwd'].scale_inv[META_DO], # d_scale_do
                    fwd_scales[META_S], # q_scale_s
                    ctx.fp8_meta['scaling_bwd'].scale[META_DS], # q_scale_ds
                    ctx.fp8_meta['scaling_bwd'].scale[META_DQKV], # q_scale_dqkv
                    ctx.fp8_meta['scaling_bwd'].amax_history[0][META_DS], # amax_ds
                    ctx.fp8_meta['scaling_bwd'].amax_history[0][META_DQKV], # amax_dqkv
                    None,
                    ctx.p_dropout,
                    ctx.fast_zero_fill,
                    "qkv_interleaved",
                    "no_bias",
                    "padding",
                    )

            dqkv_grad_output_c = dqkv.view(-1, 3*ctx.hidden_size)
            dqkv_grad_output_c_fp16 = ext.cast_from_fp8(dqkv_grad_output_c,
                ctx.fp8_meta["scaling_bwd"], META_DQKV,
                fp8_dtype_backward, tex.DType.kFloat16)
            torch.save(dqkv_grad_output_c_fp16, 'dqkv.pt')

            qkv_bgrad, dqkv_grad_output_t = ext.fp8_transpose_bgrad_fused(
                dqkv_grad_output_c,
                ctx.fp8_meta["scaling_bwd"],
                META_DQKV,
                fp8_dtype_backward,
                torch.float16,
            )

            # QKV DGRAD
            qkv_dgrad = ext.fp8_gemm(
                qkv_weight_t_fp8,
                fwd_scale_inverses,
                tex.FP8FwdTensors.GEMM1_WEIGHT,
                fp8_dtype_forward,
                dqkv_grad_output_c,
                ctx.fp8_meta["scaling_bwd"].scale_inv,
                META_DQKV,
                fp8_dtype_backward,
                torch.float16,
                workspace,
                use_split_accumulator=_2X_ACC_DGRAD,
            )
            # QKV WGRAD
            qkv_wgrad = ext.fp8_gemm(
                inputmat_t,
                fwd_scale_inverses,
                tex.FP8FwdTensors.GEMM1_INPUT,
                fp8_dtype_forward,
                dqkv_grad_output_t,
                ctx.fp8_meta["scaling_bwd"].scale_inv,
                META_DQKV,
                fp8_dtype_backward,
                torch.float16,
                workspace,
                use_split_accumulator=_2X_ACC_WGRAD,
            )

        return (qkv_dgrad,
            qkv_wgrad,
            qkv_bgrad,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None)

class DPA_FP8(TransformerEngineBaseModule):
    def __init__(
        self,
        config,
        params_dtype: torch.dtype = torch.float32):
        super().__init__()
        self.p_dropout = config.dropout_p
        self.h = config.num_attention_heads
        self.hidden_size = config.hidden_size
        self.head_dim = config.head_dim
        self.fast_zero_fill = True

        self.qkv_weight = Parameter(
            torch.empty(
                self.hidden_size * 3,
                self.hidden_size,
                device=torch.cuda.current_device(),
                dtype=params_dtype,
            )
        )
        self.fp8_weight_shapes.append(self.qkv_weight.shape)
        self.qkv_bias = Parameter(
            torch.empty(
                self.hidden_size * 3,
                device=torch.cuda.current_device(),
                dtype=params_dtype,
            )
        )
        with torch.no_grad():
            self.qkv_bias.zero_()
            self.qkv_weight.fill_(1.0)
        self.workspace = torch.empty(
            _CUBLASLT_WORKSPACE_SIZE_BYTES, dtype=torch.int8, device="cuda"
        )

    def forward(
        self, inp: torch.Tensor,
        cu_seqlens, max_s,
    ) -> torch.Tensor:
        with self.prepare_forward(inp, None, num_gemms=3) as inp:
            out = _dpa_fp8.apply(
                inp,
                self.qkv_weight,
                self.qkv_bias,
                cu_seqlens,
                self.h,
                self.p_dropout,
                max_s,
                self.fast_zero_fill,
                self.fp8_meta,
                self.workspace,
                self.training)
        return out

    def get_fp8_weights_scratchpad(
        self,
        is_first_microbatch: Union[bool, None],
    ) -> List[torch.Tensor]:
        """Needs override."""