# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import os
from importlib.metadata import version

import pytest
import torch
from pkg_resources import packaging

from transformer_engine.pytorch import TransformerLayer
from transformer_engine.pytorch.attention import DotProductAttention
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.utils import (
    init_method_normal,
    scaled_init_method_normal,
    get_device_compute_capability,
)

from test_numerics import get_dummy_cuda_rng_tracker, reset_rng_states

fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()

_flash_attn_version = packaging.version.Version(version("flash-attn"))
_flash_attn_2_available = _flash_attn_version >= packaging.version.Version("2")

class ModelConfig:
    def __init__(
        self, num_layers, hidden_size, num_attention_heads, head_dim, seq_len,
        dropout_p, attn_mask_type,
    ):
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        self.head_dim = head_dim
        assert (hidden_size == num_attention_heads * head_dim
                ), "hidden_size must equal num_attention_heads * head_dim."
        self.seq_len = seq_len
        self.dropout_p = dropout_p
        self.attn_mask_type = attn_mask_type

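# Each entry is ModelConfig(num_layers, hidden_size, num_attention_heads, head_dim,
# seq_len, dropout_p, attn_mask_type).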
model_configs = {
    "test1": ModelConfig(1, 1024, 16, 64, 128, 0.0, "causal"),
    "test2": ModelConfig(1, 1024, 16, 64, 2048, 0.0, "causal"),
    "test3": ModelConfig(1, 2048, 16, 128, 128, 0.0, "causal"),
    "test4": ModelConfig(1, 3072, 24, 128, 2048, 0.0, "causal"),
    "test5": ModelConfig(1, 1024, 16, 64, 128, 0.0, "no_mask"),
}

if os.getenv('NVTE_ADDITIONAL_TESTS', '0') == '1':
    model_configs["test6"] = ModelConfig(1, 1024, 16, 64, 512, 0.0, "causal")
    model_configs["test7"] = ModelConfig(1, 2048, 16, 128, 512, 0.0, "causal")
    model_configs["test8"] = ModelConfig(1, 2048, 16, 128, 2048, 0.0, "causal")
    model_configs["test9"] = ModelConfig(1, 1024, 16, 64, 512, 0.0, "no_mask")

param_types = [torch.float16]
if torch.cuda.is_bf16_supported():
    param_types.append(torch.bfloat16)

batch_sizes = [1, 2]  # add more if needed, e.g. 32

@pytest.mark.skipif(
    get_device_compute_capability() < 8.0, reason="Compute capability 8.0+ is required.")
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("ckpt_attn", [True, False])
@pytest.mark.parametrize("bias_type", ["no_bias", "post_scale_bias"])
def test_dot_product_attention(dtype, bs, model, ckpt_attn, bias_type):
    """Test DotProductAttention module with three backends:
    FlashAttention, FusedAttention, and UnfusedDotProductAttention"""

    config = model_configs[model]

    if bias_type == "no_bias":
        flash_attn_fwd, flash_attn_bwd = _run_dot_product_attention(
                dtype, bs, config, "FlashAttention", ckpt_attn, bias_type)
    fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(
            dtype, bs, config, "FusedAttention", ckpt_attn, bias_type)
    unfused_attn_fwd, unfused_attn_bwd = _run_dot_product_attention(
            dtype, bs, config, "UnfusedDotProductAttention", ckpt_attn, bias_type)

    atol, rtol = (2.5e-2, 2.5e-2) if dtype == torch.bfloat16 else (5e-3, 5e-3)
    if bias_type == "no_bias":
        torch.testing.assert_close(fused_attn_fwd, flash_attn_fwd, atol=atol, rtol=rtol)
        torch.testing.assert_close(fused_attn_bwd, flash_attn_bwd, atol=atol, rtol=rtol)
    torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, atol=atol, rtol=rtol)
    torch.testing.assert_close(fused_attn_bwd, unfused_attn_bwd, atol=atol, rtol=rtol)

def _run_dot_product_attention(dtype, bs, config, backend, ckpt_attn, bias_type):

    reset_rng_states()
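
    # Backend selection: TransformerEngine reads NVTE_FLASH_ATTN and NVTE_FUSED_ATTN at
    # runtime; with both left at "0" the unfused PyTorch implementation is used.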
    os.environ["NVTE_FLASH_ATTN"] = "0"
    os.environ["NVTE_FUSED_ATTN"] = "0"
    if backend == "FlashAttention":
        os.environ["NVTE_FLASH_ATTN"] = "1"
    if backend == "FusedAttention":
        os.environ["NVTE_FUSED_ATTN"] = "1"

    inp = torch.randn(
            config.seq_len, bs, 3, config.num_attention_heads, config.head_dim,
            dtype=dtype).cuda()
    inp.requires_grad = True
    seqlens = torch.empty(bs, dtype=torch.int32).cuda()
    seqlens.fill_(config.seq_len)
    cu_seqlens = torch.zeros(bs + 1, device=inp.device, dtype=torch.int32)
    cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)
    op_grad = torch.randn(
        config.seq_len, bs, config.num_attention_heads * config.head_dim,
        dtype=dtype).cuda()
    if bias_type != "no_bias":
        bias = torch.randn(1, config.num_attention_heads, config.seq_len, config.seq_len,
                dtype=dtype).cuda()
    else:
        bias = None

    block = (
         DotProductAttention(
                config.num_attention_heads,
                config.head_dim,
                attention_dropout=config.dropout_p,
                sequence_parallel=False,
                tp_size=1,
                get_rng_state_tracker=get_dummy_cuda_rng_tracker,
                tp_group=None,
                layer_number=1,
                attention_type="self"
        ).to(dtype=dtype).cuda()
    )

    q = inp[:, :, 0, :, :]
    k = inp[:, :, 1, :, :]
    v = inp[:, :, 2, :, :]
    op = block(q, k, v,
        qkv_format='sbhd',
        cu_seqlens_q=cu_seqlens,
        cu_seqlens_kv=cu_seqlens,
        attn_mask_type=config.attn_mask_type,
        checkpoint_core_attention=ckpt_attn,
        core_attention_bias_type=bias_type,
        core_attention_bias=bias)
    op.backward(op_grad)

    return op, inp.grad

qkv_layouts = [
    'sb3hd', 'sbh3d', 'sbhd_sb2hd', 'sbhd_sbh2d', 'sbhd_sbhd_sbhd',
    'bs3hd', 'bsh3d', 'bshd_bs2hd', 'bshd_bsh2d', 'bshd_bshd_bshd',
    # will add tests for thd layouts later when the support is available in fused attention
    #'t3hd', 'th3d', 'thd_t2hd', 'thd_th2d', 'thd_thd_thd',
    ]
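# Layout naming: b = batch, s = sequence, h = heads, d = head dim, t = total tokens
# across the batch; a digit marks the packed QKV (or KV) dimension within a tensor.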

@pytest.mark.skipif(
    get_device_compute_capability() < 8.0, reason="Compute capability 8.0+ is required.")
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("workspace_opt", [True, False])
@pytest.mark.parametrize("qkv_layout", qkv_layouts)
def test_dpa_qkv_layout(dtype, bs, model, workspace_opt, qkv_layout):
    """Test DotProductAttention module with different QKV layouts"""

    config = model_configs[model]

    flash_attn_fwd, flash_attn_bwd = _run_dpa_qkv_layout(
            dtype, bs, config, "FlashAttention", qkv_layout, workspace_opt)
    fused_attn_fwd, fused_attn_bwd = _run_dpa_qkv_layout(
            dtype, bs, config, "FusedAttention", qkv_layout, workspace_opt)
    unfused_attn_fwd, unfused_attn_bwd = _run_dpa_qkv_layout(
            dtype, bs, config, "UnfusedDotProductAttention", qkv_layout, workspace_opt)

    atol, rtol = (5e-2, 5e-2) if dtype == torch.bfloat16 else (2.5e-3, 2.5e-3)
    torch.testing.assert_close(flash_attn_fwd, unfused_attn_fwd, atol=atol, rtol=rtol)
    torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, atol=atol, rtol=rtol)
    torch.testing.assert_close(fused_attn_fwd, flash_attn_fwd, atol=atol, rtol=rtol)
    for i in range(len(flash_attn_bwd)):
        torch.testing.assert_close(flash_attn_bwd[i], unfused_attn_bwd[i], atol=atol, rtol=rtol)
        torch.testing.assert_close(fused_attn_bwd[i], flash_attn_bwd[i], atol=atol, rtol=rtol)
        torch.testing.assert_close(fused_attn_bwd[i], unfused_attn_bwd[i], atol=atol, rtol=rtol)

def _run_dpa_qkv_layout(dtype, bs, config, backend, qkv_layout, workspace_opt):

    torch.manual_seed(1234)
    torch.cuda.manual_seed(1234)
    os.environ["NVTE_FLASH_ATTN"] = "0"
    os.environ["NVTE_FUSED_ATTN"] = "0"
    if backend == "FlashAttention":
        os.environ["NVTE_FLASH_ATTN"] = "1"
    if backend == "FusedAttention":
        os.environ["NVTE_FUSED_ATTN"] = "1"
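        # Exercise both fused-attention code paths: with and without the workspace optimization.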
        os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] = "1" if workspace_opt else "0"


    dim_to_num = {'b': bs,
        's': config.seq_len,
        'h': config.num_attention_heads,
        'd': config.head_dim,
        't': bs * config.seq_len,
        '3': 3,
        '2': 2}

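    # Build Q, K, V from the layout string: each '_'-separated chunk describes one input
    # tensor, and a digit in the chunk (e.g. the '3' in 'sb3hd') marks a packed dimension
    # that is split apart to recover the individual Q/K/V tensors.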
    inp = []
    for i,layout in enumerate(qkv_layout.split('_')):
        tensor_shape = [dim_to_num[j] for j in layout]
        tensor = 0.1 * torch.randn(tensor_shape, dtype = dtype).cuda()
        tensor_count = 1
        split_dim = 0
        for dim, l in enumerate(layout):
            if l.isdigit():
                tensor_count = int(l)
                split_dim = dim
                break
        tensors = torch.split(tensor, 1, dim = split_dim) if split_dim != 0 else [tensor]
        for j in range(tensor_count):
            if split_dim != 0:
                inp.append(tensors[j].squeeze(split_dim))
            else:
                inp.append(tensors[j])
    for i in range(3):
        inp[i].requires_grad=True

    seqlens = torch.empty(bs, dtype = torch.int32).cuda()
    seqlens.fill_(config.seq_len)
    cu_seqlens = torch.zeros(bs + 1, device = inp[0].device, dtype = torch.int32)
    cu_seqlens[1:] = torch.cumsum(seqlens, dim = 0)
    qkv_format = ''.join([i for i in qkv_layout.split('_')[0] if i.isalpha()])
    qkv_format_no_thd = qkv_format if qkv_format != 'thd' else 'bshd'
    op_grad_shape = [dim_to_num[i] for i in qkv_format_no_thd]
    op_grad_shape_new = [*op_grad_shape[:-2], op_grad_shape[-2] * op_grad_shape[-1]]
    op_grad = 0.001 * torch.randint(0, 200, op_grad_shape_new, dtype = dtype).cuda()

    block = (
         DotProductAttention(
                config.num_attention_heads,
                config.head_dim,
                attention_dropout = config.dropout_p,
                attn_mask_type = config.attn_mask_type,
                sequence_parallel = False,
                tp_size = 1,
                get_rng_state_tracker = None,
                tp_group = None,
                layer_number = 1,
                attention_type = "self"
        ).to(dtype = dtype).cuda()
    )

    if qkv_format != 'thd':
        op = block(inp[0], inp[1], inp[2], qkv_format=qkv_format)
    else:
        cu_seqlens_q = torch.arange(
                0,
                (bs + 1) * config.seq_len,
                step=config.seq_len,
                dtype=torch.int32,
                device=inp[0].device)
        cu_seqlens_kv = torch.arange(
                0,
                (bs + 1) * config.seq_len,
                step=config.seq_len,
                dtype=torch.int32,
                device=inp[1].device)
        op = block(inp[0], inp[1], inp[2],
                qkv_format=qkv_format,
                cu_seqlens_q = cu_seqlens_q,
                cu_seqlens_kv = cu_seqlens_kv)
    op.backward(op_grad)

    return op, (inp[0].grad, inp[1].grad, inp[2].grad)

@pytest.mark.skipif(
    get_device_compute_capability() < 8.0, reason="Compute capability 8.0+ is required.")
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("ckpt_attn", [False])
@pytest.mark.parametrize("bias_type", ["no_bias", "post_scale_bias"])
@pytest.mark.parametrize("fused_qkv_params", [True, False])
def test_transformer_layer(dtype, bs, model, ckpt_attn, bias_type, fused_qkv_params):
    """Test TransformerLayer module when its DotProductAttention is enabled with
    FlashAttention, FusedAttention, or UnfusedDotProductAttention backend"""

    config = model_configs[model]

    if bias_type == "no_bias":
        flash_attn_fwd, flash_attn_bwd = _run_transformer_layer(
                dtype, bs, config, "FlashAttention", ckpt_attn, bias_type, fused_qkv_params)
    fused_attn_fwd, fused_attn_bwd = _run_transformer_layer(
            dtype, bs, config, "FusedAttention", ckpt_attn, bias_type, fused_qkv_params)
    unfused_attn_fwd, unfused_attn_bwd = _run_transformer_layer(
            dtype, bs, config, "UnfusedDotProductAttention", ckpt_attn, bias_type, fused_qkv_params)

    atol, rtol = (5e-1, 5e-2)
    if bias_type == "no_bias":
        torch.testing.assert_close(fused_attn_fwd, flash_attn_fwd, atol=atol, rtol=rtol)
        torch.testing.assert_close(fused_attn_bwd, flash_attn_bwd, atol=atol, rtol=rtol)
    torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, atol=atol, rtol=rtol)
    torch.testing.assert_close(fused_attn_bwd, unfused_attn_bwd, atol=atol, rtol=rtol)

def _run_transformer_layer(dtype, bs, config, backend, ckpt_attn, bias_type, fused_qkv_params):

    reset_rng_states()
    os.environ["NVTE_FLASH_ATTN"] = "0"
    os.environ["NVTE_FUSED_ATTN"] = "0"
    if backend == "FlashAttention":
        os.environ["NVTE_FLASH_ATTN"] = "1"
    if backend == "FusedAttention":
        os.environ["NVTE_FUSED_ATTN"] = "1"

    inp = torch.randn(
            config.seq_len, bs, config.num_attention_heads * config.head_dim,
            dtype=dtype).cuda()
    inp.requires_grad = True
    seqlens = torch.empty(bs, dtype=torch.int32).cuda()
    seqlens.fill_(config.seq_len)
    cu_seqlens = torch.zeros(bs + 1, device=inp.device, dtype=torch.int32)
    cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)

    sigma = 0.02
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    layer_number = 1
    drop_path_rate = 0.0
    drop_path_rates = [
            rate.item() for rate in torch.linspace(0, drop_path_rate, config.num_layers)]
    if bias_type != "no_bias":
        bias = torch.randn(1, config.num_attention_heads, config.seq_len, config.seq_len,
                dtype=dtype).cuda()
    else:
        bias = None

    block = (
        TransformerLayer(
            config.hidden_size,
            4 * config.hidden_size,
            config.num_attention_heads,
            layernorm_epsilon=1e-5,
            hidden_dropout=0.0,
            attention_dropout=config.dropout_p,
            init_method=init_method,
            output_layer_init_method=output_layer_init_method,
            layer_number=layer_number,
            kv_channels=config.head_dim,
            tp_group=None,
            tp_size=1,
            params_dtype=dtype,
            get_rng_state_tracker=None,
            fuse_wgrad_accumulation=False,
            seq_length=config.seq_len,
            micro_batch_size=bs,
            sequence_parallel=False,
            apply_residual_connection_post_layernorm=False,
            output_layernorm=False,
            layer_type="encoder",
            drop_path_rate=drop_path_rates[layer_number - 1],
            set_parallel_mode=True,
            fuse_qkv_params=fused_qkv_params,
            zero_centered_gamma=False,
            qkv_weight_interleaved=False,
            ub_tp_comm_overlap=False,
            bias=True,
        )
        .to(dtype=dtype)
        .cuda()
    )

    num_iters = 5
    for i in range(num_iters):
        op = block(inp, self_attn_mask_type=config.attn_mask_type,
            checkpoint_core_attention=ckpt_attn,
            core_attention_bias_type=bias_type,
            core_attention_bias=bias)
        loss = op.sum()
        loss.backward()

    return op, inp.grad

@pytest.mark.skipif(not _flash_attn_2_available, reason="FA2.0 is not available")
@pytest.mark.skipif(
    get_device_compute_capability() < 8.0, reason="Compute capability 8.0+ is required.")
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
def test_transformer_layer_gqa(dtype, bs, model):
    """Test TransformerLayer module with grouped-query attention, comparing the
    FlashAttention and UnfusedDotProductAttention backends"""

    config = model_configs[model]
    def find_factors(x):
        f = []
        for i in range(1, x + 1):
            if x % i == 0:
                f.append(i)
        return f

    num_querys_per_gqa_group = find_factors(config.num_attention_heads)
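    # Every factor of num_attention_heads gives a whole number of query heads per
    # GQA group, so sweep all of them.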

    for num_q_per_gqa_group in num_querys_per_gqa_group:
        flash_attn_fwd, flash_attn_bwd = _run_transformer_layer_gqa(
                dtype, bs, config, "FlashAttention", num_q_per_gqa_group)
        unfused_attn_fwd, unfused_attn_bwd = _run_transformer_layer_gqa(
                dtype, bs, config, "UnfusedDotProductAttention", num_q_per_gqa_group)

        atol, rtol = 5e-1, 5e-2
        torch.testing.assert_close(flash_attn_fwd, unfused_attn_fwd, atol=atol, rtol=rtol)
        torch.testing.assert_close(flash_attn_bwd, unfused_attn_bwd, atol=atol, rtol=rtol)

def _run_transformer_layer_gqa(dtype, bs, config, backend, num_querys_per_gqa_group):

    reset_rng_states()
    os.environ["NVTE_FLASH_ATTN"] = "0"
    if backend == "FlashAttention":
        os.environ["NVTE_FLASH_ATTN"] = "1"

    inp = torch.randn(
            config.seq_len, bs, config.num_attention_heads * config.head_dim,
            dtype=dtype).cuda()
    inp.requires_grad = True
    seqlens = torch.empty(bs, dtype=torch.int32).cuda()
    seqlens.fill_(config.seq_len)
    cu_seqlens = torch.zeros(bs + 1, device=inp.device, dtype=torch.int32)
    cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)
    op_grad = torch.randn(
        config.seq_len, bs, config.num_attention_heads * config.head_dim,
        dtype=dtype).cuda()

    sigma = 0.02
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    layer_number = 1
    drop_path_rate = 0.0
    drop_path_rates = [
            rate.item() for rate in torch.linspace(0, drop_path_rate, config.num_layers)]

    block = (
        TransformerLayer(
            config.hidden_size,
            4 * config.hidden_size,
            config.num_attention_heads,
            num_gqa_groups=config.num_attention_heads // num_querys_per_gqa_group,
            layernorm_epsilon=1e-5,
            hidden_dropout=0.0,
            attention_dropout=config.dropout_p,
            init_method=init_method,
            output_layer_init_method=output_layer_init_method,
            layer_number=layer_number,
            kv_channels=config.head_dim,
            tp_group=None,
            tp_size=1,
            params_dtype=dtype,
            get_rng_state_tracker=None,
            fuse_wgrad_accumulation=False,
            seq_length=config.seq_len,
            micro_batch_size=bs,
            sequence_parallel=False,
            apply_residual_connection_post_layernorm=False,
            output_layernorm=False,
            layer_type="encoder",
            drop_path_rate=drop_path_rates[layer_number - 1],
            set_parallel_mode=True,
            fuse_qkv_params=True,
            zero_centered_gamma=False,
            qkv_weight_interleaved=False,
            ub_tp_comm_overlap=False,
            bias=True,
        )
        .to(dtype=dtype)
        .cuda()
    )

    op = block(inp, self_attn_mask_type=config.attn_mask_type)
    op.backward(op_grad)

    return op, inp.grad

model_configs_fp8 = {
    "test1": ModelConfig(1, 1024, 16, 64, 512, 0.0, "no_mask"),
}
batch_sizes_fp8 = [1, 4]
param_types_fp8 = [torch.float16]

@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.parametrize("dtype", param_types_fp8)
@pytest.mark.parametrize("bs", batch_sizes_fp8)
@pytest.mark.parametrize("model", model_configs_fp8.keys())
def test_dpa_fp8(dtype, bs, model):
    """Test DotProductAttention with FP8, comparing the fused attention path
    (cpp_extensions fused_attn_fwd/fused_attn_bwd) against UnfusedDotProductAttention"""

    config = model_configs_fp8[model]

    fused_attn_fwd, fused_attn_bwd = _run_dpa_fp8(
            dtype, bs, config, "FusedAttention")
    unfused_attn_fwd, unfused_attn_bwd = _run_dpa_fp8_ref(
            dtype, bs, config, "UnfusedDotProductAttention")

    atol, rtol = (2.5e-2, 2.5e-2)
    torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, atol=atol, rtol=rtol)
    torch.testing.assert_close(fused_attn_bwd, unfused_attn_bwd, atol=atol, rtol=rtol)

def _run_dpa_fp8(dtype, bs, config, backend):

    reset_rng_states()
    os.environ["NVTE_FLASH_ATTN"] = "0"
    os.environ["NVTE_FUSED_ATTN"] = "0"

    inp = 0.01 * torch.randn(
            bs * config.seq_len, config.num_attention_heads * config.head_dim,
            dtype=dtype).cuda()
    inp.requires_grad = True
    seqlens = torch.empty(bs, dtype=torch.int32).cuda()
    seqlens.fill_(config.seq_len)
    cu_seqlens = torch.zeros(bs + 1, device=inp.device, dtype=torch.int32)
    cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)
    op_grad = 0.01 * torch.randn(
        bs * config.seq_len, config.num_attention_heads * config.head_dim,
        dtype=dtype).cuda()
    torch.save(op_grad, 'op_grad.pt')

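    # FP8 recipe: HYBRID uses E4M3 for forward tensors and E5M2 for gradients;
    # amax_history_len=1 with "most_recent" makes the scales track the latest step.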
    fp8_recipe = recipe.DelayedScaling(
        margin=0,
        interval=1,
        fp8_format=recipe.Format.HYBRID,
        amax_history_len=1,
        amax_compute_algo="most_recent",
    )

    dpa = DPA_FP8(config).to(dtype=torch.float16).cuda()
    with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
        op = dpa(inp, cu_seqlens, config.seq_len)
        op.backward(op_grad)

    context = torch.load("ctx.pt")
    dqkv = torch.load('dqkv.pt')
    return (context.view(bs, config.seq_len, -1).transpose(0,1),
        dqkv.view(bs, config.seq_len, 3, config.num_attention_heads, config.head_dim).transpose(0,1).contiguous())

def _run_dpa_fp8_ref(dtype, bs, config, backend):

    os.environ["NVTE_FLASH_ATTN"] = "0"
    os.environ["NVTE_FUSED_ATTN"] = "0"
    if backend == "FlashAttention":
        os.environ["NVTE_FLASH_ATTN"] = "1"
    if backend == "FusedAttention":
        os.environ["NVTE_FUSED_ATTN"] = "1"

    inp = torch.load('qkv.pt').cuda()
    inp.requires_grad = True
    seqlens = torch.empty(bs, dtype=torch.int32).cuda()
    seqlens.fill_(config.seq_len)
    cu_seqlens = torch.zeros(bs + 1, device=inp.device, dtype=torch.int32)
    cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)
    op_grad = torch.load('op_grad.pt').cuda().view(bs, config.seq_len, -1).transpose(0,1)

    block = (
         DotProductAttention(
                config.num_attention_heads,
                config.head_dim,
                attention_dropout=config.dropout_p,
                sequence_parallel=False,
                tp_size=1,
                get_rng_state_tracker=get_dummy_cuda_rng_tracker,
                tp_group=None,
                layer_number=1,
                attention_type="self"
        ).to(dtype=dtype).cuda()
    )

    q = inp[:, :, 0, :, :]
    k = inp[:, :, 1, :, :]
    v = inp[:, :, 2, :, :]
    op = block(q, k, v, attn_mask_type=config.attn_mask_type)
    op.backward(op_grad)

    return op, inp.grad

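# Building blocks for the FP8 implementation below; these module-level imports run at
# import time, so they are also available to _run_dpa_fp8 defined above.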
from torch.nn.parameter import Parameter
import transformer_engine.pytorch.cpp_extensions as ext
import transformer_engine_extensions as tex
import transformer_engine.pytorch.fp8 as fp8
from transformer_engine.pytorch import fp8_autocast
from transformer_engine.pytorch.module.base import TransformerEngineBaseModule, _prepare_backward
from transformer_engine.common import recipe
from typing import Union, Dict, Any, Tuple, List
from transformer_engine.pytorch.cpp_extensions.fused_attn import (
    fused_attn_fwd_qkvpacked,
    fused_attn_bwd_qkvpacked,
    fused_attn_fwd,
    fused_attn_bwd,
    FusedAttnBackend)

_CUBLASLT_WORKSPACE_SIZE_BYTES = 33_554_432  # 32MiB
_2X_ACC_FPROP = False
_2X_ACC_DGRAD = False
_2X_ACC_WGRAD = False

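# Slots in the FP8 scaling metadata used for the attention tensors (packed QKV,
# attention output O, softmax S) and their gradients (DO, DQKV, DS).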
META_QKV  = tex.FP8FwdTensors.GEMM1_OUTPUT
META_O    = tex.FP8FwdTensors.GEMM2_INPUT
META_DO   = tex.FP8BwdTensors.GRAD_INPUT2
META_DQKV = tex.FP8BwdTensors.GRAD_OUTPUT1

META_S    = tex.FP8FwdTensors.GEMM3_WEIGHT
META_DS   = tex.FP8BwdTensors.GRAD_INPUT3

class _dpa_fp8(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        inp: torch.Tensor,
        qkv_weight: torch.Tensor,
        qkv_bias: torch.Tensor,
        cu_seqlens: torch.Tensor,
        num_attention_heads: int,
        p_dropout: float,
        max_s: int,
        fast_zero_fill: bool,
        fp8_meta: Dict[str, Any],
        workspace: torch.Tensor,
        is_training: bool,
    ) -> torch.Tensor:

        assert inp.dim() == 2
        in_features = qkv_weight.shape[-1]
        h = num_attention_heads
        d = in_features // h
        b = cu_seqlens.numel() - 1
        is_nl = False
        if b < 4 and b > 1:
            max_s = 512
            is_nl = True

        fp8_dtype_forward = fp8.get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)

        inputmat, inputmat_t = ext.fp8_cast_transpose_fused(
            inp,
            fp8_meta["scaling_fwd"],
            tex.FP8FwdTensors.GEMM1_INPUT,
            fp8_dtype_forward,
        )

        qkv_weight_fp8, qkv_weight_t_fp8 = ext.fp8_cast_transpose_fused(
            qkv_weight,
            fp8_meta["scaling_fwd"],
            tex.FP8FwdTensors.GEMM1_WEIGHT,
            fp8_dtype_forward,
        )

        M = None
        ZInv = None
        philox_unpacked = None

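        # QKV projection GEMM in FP8: the packed QKV output is produced directly in FP8
        # (uint8 storage, scale slot META_QKV) and fed to the fused attention kernel below.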
        qkv_out, _ = ext.fp8_gemm(
            qkv_weight_fp8,
            fp8_meta["scaling_fwd"].scale_inv,
            tex.FP8FwdTensors.GEMM1_WEIGHT,
            fp8_dtype_forward,
            inputmat,
            fp8_meta["scaling_fwd"].scale_inv,
            tex.FP8FwdTensors.GEMM1_INPUT,
            fp8_dtype_forward,
            torch.uint8,
            workspace,
            bias=qkv_bias,
            use_bias=True,
            out_index=META_QKV,
            fp8_meta_tensor=fp8_meta["scaling_fwd"],
            use_split_accumulator=_2X_ACC_FPROP,
            D_dtype=fp8_dtype_forward,
        )
        qkv_out = qkv_out.view(-1, 3, h, d)
        qkv_out_fp16 = ext.cast_from_fp8(qkv_out, fp8_meta["scaling_fwd"],
                META_QKV, fp8_dtype_forward,
                tex.DType.kFloat16).view(b, max_s, 3, h, d).transpose(0,1).contiguous()
        torch.save(qkv_out_fp16, 'qkv.pt')
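        # Save the dequantized QKV so _run_dpa_fp8_ref can run the unfused reference on
        # exactly the same inputs; ctx.pt and dqkv.pt later carry the FP8 results back.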

        # FMHA
        context_, aux_ctx_tensors, *rest = fused_attn_fwd(
                is_training,
                max_s,
                max_s,
                cu_seqlens,
                cu_seqlens,
                qkv_out[:, 0, :, :],
                qkv_out[:, 1, :, :],
                qkv_out[:, 2, :, :],
                fp8_dtype_forward,
                FusedAttnBackend["FP8"],
                None,
                fp8_meta["scaling_fwd"].scale_inv[META_QKV],
                fp8_meta["scaling_fwd"].scale[META_S],
                fp8_meta["scaling_fwd"].scale[META_O],
                fp8_meta["scaling_fwd"].amax_history[0][META_S],
                fp8_meta["scaling_fwd"].amax_history[0][META_O],
                attn_scale=None,
                dropout=p_dropout,
                fast_zero_fill=fast_zero_fill,
                qkv_layout="t3hd",
                attn_bias_type="no_bias",
                attn_mask_type="padding",
                rng_gen=None,
                )
        M, ZInv, philox_unpacked = aux_ctx_tensors

        context = context_.view(-1, in_features)
        context_t = tex.fp8_transpose(context, fp8_dtype_forward)

        ctx.save_for_backward(
            inputmat_t, qkv_weight_t_fp8, workspace,
            qkv_out,
            context_, context_t,
            fp8_meta["scaling_fwd"].scale,
            fp8_meta["scaling_fwd"].scale_inv,
        )
        ctx.aux_ctx_tensors = aux_ctx_tensors
        ctx.fp8_meta = fp8_meta
        ctx.cu_seqlens = cu_seqlens
        ctx.p_dropout = p_dropout
        ctx.max_s = max_s
        ctx.fast_zero_fill = fast_zero_fill
        ctx.is_nl = is_nl
        ctx.hidden_size = in_features
        ctx.num_attention_heads = num_attention_heads

        context_fp16 = ext.cast_from_fp8(context, fp8_meta["scaling_fwd"],
                META_O, fp8_dtype_forward, tex.DType.kFloat16)
        torch.save(context_fp16, 'ctx.pt')
        return context_fp16


    @staticmethod
    def backward(
        ctx, grad_output: torch.Tensor
    ) -> Tuple[Union[torch.Tensor, None], ...]:

        with _prepare_backward(True, ctx.fp8_meta, None, 1, name="_DPA"):
            (
                inputmat_t,
                qkv_weight_t_fp8,
                workspace,
                qkv_out,
                context, context_t,
                fwd_scales,
                fwd_scale_inverses,
            ) = ctx.saved_tensors
            fp8_dtype_forward = fp8.get_fp8_te_dtype(
                ctx.fp8_meta["recipe"], fprop_tensor=True
            )
            fp8_dtype_backward = fp8.get_fp8_te_dtype(
                ctx.fp8_meta["recipe"], fprop_tensor=False
            )

            proj_dgrad = ext.cast_to_fp8(
                grad_output, ctx.fp8_meta["scaling_bwd"], META_DO, fp8_dtype_backward
            )

            dq, dk, dv, *rest = fused_attn_bwd(
                    ctx.max_s,
                    ctx.max_s,
                    ctx.cu_seqlens,
                    ctx.cu_seqlens,
                    qkv_out[:, 0, :, :],
                    qkv_out[:, 1, :, :],
                    qkv_out[:, 2, :, :],
                    context,
                    proj_dgrad.view_as(context),
                    fp8_dtype_forward,
                    ctx.aux_ctx_tensors,
                    FusedAttnBackend["FP8"],
                    fwd_scale_inverses[META_QKV],  # d_scale_qkv
                    fwd_scale_inverses[META_S],  # d_scale_s
                    fwd_scale_inverses[META_O],  # d_scale_o
                    ctx.fp8_meta['scaling_bwd'].scale_inv[META_DO],  # d_scale_do
                    fwd_scales[META_S],  # q_scale_s
                    ctx.fp8_meta['scaling_bwd'].scale[META_DS],  # q_scale_ds
                    ctx.fp8_meta['scaling_bwd'].scale[META_DQKV],  # q_scale_dqkv
                    ctx.fp8_meta['scaling_bwd'].amax_history[0][META_DS],  # amax_ds
                    ctx.fp8_meta['scaling_bwd'].amax_history[0][META_DQKV],  # amax_dqkv
                    None,
                    ctx.p_dropout,
                    ctx.fast_zero_fill,
                    "t3hd",
                    "no_bias",
                    "padding",
                    )
            dqkv = torch.cat([dq.unsqueeze(1), dk.unsqueeze(1), dv.unsqueeze(1)], dim=1)

            dqkv_grad_output_c = dqkv.view(-1, 3*ctx.hidden_size)
            dqkv_grad_output_c_fp16 = ext.cast_from_fp8(dqkv_grad_output_c,
                ctx.fp8_meta["scaling_bwd"], META_DQKV,
                fp8_dtype_backward, tex.DType.kFloat16)
            torch.save(dqkv_grad_output_c_fp16, 'dqkv.pt')

            qkv_bgrad, dqkv_grad_output_t = ext.fp8_transpose_bgrad_fused(
                dqkv_grad_output_c,
                ctx.fp8_meta["scaling_bwd"],
                META_DQKV,
                fp8_dtype_backward,
                torch.float16,
            )

            # QKV DGRAD
            qkv_dgrad, _ = ext.fp8_gemm(
                qkv_weight_t_fp8,
                fwd_scale_inverses,
                tex.FP8FwdTensors.GEMM1_WEIGHT,
                fp8_dtype_forward,
                dqkv_grad_output_c,
                ctx.fp8_meta["scaling_bwd"].scale_inv,
                META_DQKV,
                fp8_dtype_backward,
                torch.float16,
                workspace,
                use_split_accumulator=_2X_ACC_DGRAD,
            )
            # QKV WGRAD
            qkv_wgrad, _ = ext.fp8_gemm(
                inputmat_t,
                fwd_scale_inverses,
                tex.FP8FwdTensors.GEMM1_INPUT,
                fp8_dtype_forward,
                dqkv_grad_output_t,
                ctx.fp8_meta["scaling_bwd"].scale_inv,
                META_DQKV,
                fp8_dtype_backward,
                torch.float16,
                workspace,
                use_split_accumulator=_2X_ACC_WGRAD,
            )

        return (qkv_dgrad,
            qkv_wgrad,
            qkv_bgrad,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None)

class DPA_FP8(TransformerEngineBaseModule):
    def __init__(
        self,
        config,
        params_dtype: torch.dtype = torch.float32):
        super().__init__()
        self.p_dropout = config.dropout_p
        self.h = config.num_attention_heads
        self.hidden_size = config.hidden_size
        self.head_dim = config.head_dim
        self.fast_zero_fill = True

        self.qkv_weight = Parameter(
            torch.empty(
                self.hidden_size * 3,
                self.hidden_size,
                device=torch.cuda.current_device(),
                dtype=params_dtype,
            )
        )
        self.fp8_weight_shapes.append(self.qkv_weight.shape)
        self.qkv_bias = Parameter(
            torch.empty(
                self.hidden_size * 3,
                device=torch.cuda.current_device(),
                dtype=params_dtype,
            )
        )
        with torch.no_grad():
            self.qkv_bias.zero_()
            self.qkv_weight.fill_(1.0)
        self.workspace = torch.empty(
            _CUBLASLT_WORKSPACE_SIZE_BYTES, dtype=torch.int8, device="cuda"
        )

    def forward(
        self, inp: torch.Tensor,
        cu_seqlens, max_s,
    ) -> torch.Tensor:
        with self.prepare_forward(inp, None, num_gemms=3) as inp:
            out = _dpa_fp8.apply(
                inp,
                self.qkv_weight,
                self.qkv_bias,
                cu_seqlens,
                self.h,
                self.p_dropout,
                max_s,
                self.fast_zero_fill,
                self.fp8_meta,
                self.workspace,
                self.training)
        return out

    def get_fp8_weights_scratchpad(
        self,
        is_first_microbatch: Union[bool, None],
    ) -> List[torch.Tensor]:
        """Needs override."""