# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import logging
import math
import os
import sys
import pathlib
from typing import Any, Dict, Tuple, Union

import pytest
import torch

from transformer_engine.common import recipe
from transformer_engine.pytorch import TransformerLayer, fp8_autocast, fp8_model_init
from transformer_engine.pytorch.attention.dot_product_attention import (
    DotProductAttention,
    _attention_backends,
)
from transformer_engine.pytorch.attention.multi_head_attention import MultiheadAttention
from transformer_engine.pytorch.attention.dot_product_attention.utils import (
    FlashAttentionUtils,
    check_set_window_size,
)
from transformer_engine.pytorch.attention import RotaryPositionEmbedding
import transformer_engine.pytorch.cpp_extensions as ext
from transformer_engine.pytorch.cpp_extensions.fused_attn import (
    FusedAttnBackend,
    fused_attn_bwd,
    fused_attn_fwd,
)
from transformer_engine.pytorch.distributed import CudaRNGStatesTracker
import transformer_engine.pytorch.fp8 as fp8
from transformer_engine.pytorch.module.base import TransformerEngineBaseModule
from transformer_engine.pytorch.utils import (
    get_device_compute_capability,
    init_method_normal,
    scaled_init_method_normal,
    is_bf16_compatible,
)
from transformer_engine.pytorch.utils import get_cudnn_version
import transformer_engine_torch as tex
from transformer_engine.pytorch.tensor.quantized_tensor import (
    Quantizer,
    prepare_for_saving,
    restore_from_saved,
)

_current_file = pathlib.Path(__file__).resolve()
sys.path.append(str(_current_file.parent.parent))
from utils import (
    reset_rng_states,
    ModelConfig,
    dtype_tols,
    get_available_attention_backends,
)

# Only run FP8 tests on H100
fp8_available, reason_for_no_fp8 = fp8.FP8GlobalStateManager.is_fp8_available()

seed = 1234
# Reset RNG states
reset_rng_states()


@pytest.fixture(autouse=True)
def reset_global_fp8_state():
    yield
    fp8.FP8GlobalStateManager.reset()


model_configs_base = {
    #     test:             b,  h, hg,  d,  sq, skv,   p,      mask,      bias
    "base_1_0": ModelConfig(8, 128, 16, 64),
    "base_1_1": ModelConfig(4, 128, 16, 64, max_seqlen_kv=256),
    "base_2_0": ModelConfig(2, 2048, 24, 128),
    "base_2_1": ModelConfig(1, 2048, 24, 128, max_seqlen_kv=4096),
    "base_3_0": ModelConfig(8, 1, 16, 128, max_seqlen_kv=2048),
    "base_3_1": ModelConfig(8, 1, 16, 256, max_seqlen_kv=2048),
    "base_4_0": ModelConfig(8, 1, 16, 192, max_seqlen_kv=2048),
    "base_4_1": ModelConfig(8, 128, 16, 192, max_seqlen_kv=2048),
    "base_5_0": ModelConfig(8, 1, 16, 512, max_seqlen_kv=2048),
    "base_5_1": ModelConfig(8, 128, 16, 512, max_seqlen_kv=2048),
    "base_6_0": ModelConfig(8, 1, 16, 1024, max_seqlen_kv=2048),
    "base_6_1": ModelConfig(8, 128, 16, 1024, max_seqlen_kv=2048),
}


param_types = [torch.float16]
if is_bf16_compatible():  # bf16 requires sm_80 or higher
    param_types.append(torch.bfloat16)
param_types_lean = [torch.bfloat16]


@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model_configs", [model_configs_base])
@pytest.mark.parametrize("model", model_configs_base.keys())
@pytest.mark.parametrize("ckpt_attn", [False])
@pytest.mark.parametrize("workspace_opt", [True, False])
@pytest.mark.parametrize("qkv_layout", [None])
@pytest.mark.parametrize("swa", [False])
@pytest.mark.parametrize("pad_between_seqs", [False])
def test_dot_product_attention(
    dtype, model_configs, model, ckpt_attn, workspace_opt, qkv_layout, swa, pad_between_seqs
):
    """Test DotProductAttention module"""

    # Get configs
    tols = dict(atol=1e-3, rtol=1e-3)
    if dtype == torch.bfloat16:
        tols = dict(atol=1.5e-2, rtol=1.5e-2)
    config = model_configs[model]
    is_mla = config.head_dim_qk != config.head_dim_v
    is_mqa_gqa = config.num_heads != config.num_gqa_groups
    if qkv_layout is None:
        if config.attn_type == "self":
            qkv_layout = "sb3hd" if not is_mla and not is_mqa_gqa else "sbhd_sbhd_sbhd"
        else:
            qkv_layout = "bshd_bs2hd" if not is_mla and not is_mqa_gqa else "bshd_bshd_bshd"
    if "3" in qkv_layout and config.attn_type == "cross":
        pytest.skip("No need to test this layout for cross attention")

    if config.window_size == (-1, -1) and swa:
        config.window_size = [2, 2]
    config.window_size = check_set_window_size(config.attn_mask_type, config.window_size)

    is_training = True
    available_backends, _, fused_attn_backends = get_available_attention_backends(
        config,
        qkv_dtype=dtype,
        qkv_layout=qkv_layout,
        window_size=config.window_size,
        pad_between_seqs=pad_between_seqs,
        is_training=is_training,
    )
    flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
    if not fused_attn_supported:
        is_training = False
        available_backends, _, fused_attn_backends = get_available_attention_backends(
            config,
            qkv_dtype=dtype,
            qkv_layout=qkv_layout,
            window_size=config.window_size,
            pad_between_seqs=pad_between_seqs,
            is_training=is_training,
        )
        flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends

    # FlashAttention does not support pad_between_seqs, but _run_dot_product_attention
    # manually pads and unpads the input and output of FlashAttention for testing purposes
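    # Illustrative example (not one of the test configs): with seqlens_kv=[3, 2] and pad_len=[1, 2],
    # a "thd" tensor packs tokens as [x x x p | x x p p]; the "p" slots are the padding between
    # sequences that the helper zeroes out and strips before handing tensors to FlashAttention.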
    if (
        pad_between_seqs
        and FlashAttentionUtils.is_installed
        and not (
            config.max_seqlen_q != config.max_seqlen_kv
            and config.attn_mask_type in ["causal", "padding_causal"]
        )
        and (config.window_size[0] == -1 or FlashAttentionUtils.v2_3_plus)
    ):
        flash_attn_supported = True

    # Skip if only unfused backend is supported
    if (len(fused_attn_backends) + flash_attn_supported + unfused_attn_supported) < 2:
        pytest.skip("Less than two backends to compare.")

    # UnfusedDotProductAttention backend
    if unfused_attn_supported:
        unfused_attn_fwd, unfused_attn_bwd = _run_dot_product_attention(
            dtype,
            config,
            "UnfusedDotProductAttention",
            ckpt_attn,
            qkv_layout,
            workspace_opt,
            pad_between_seqs,
            is_training,
        )

    # FusedAttention backend
    if fused_attn_supported:
        if len(fused_attn_backends) == 1:
            fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(
                dtype,
                config,
                "FusedAttention",
                ckpt_attn,
                qkv_layout,
                workspace_opt,
                pad_between_seqs,
                is_training,
            )
        if len(fused_attn_backends) == 2:
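            # Two cuDNN fused-attention sub-backends support this config: force each one in turn
            # via NVTE_FUSED_ATTN_BACKEND ("0" below, then "1") and compare their results at the
            # end of the test.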
            os.environ["NVTE_FUSED_ATTN_BACKEND"] = "0"
            fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(
                dtype,
                config,
                "FusedAttention",
                ckpt_attn,
                qkv_layout,
                workspace_opt,
                pad_between_seqs,
                is_training,
            )
            os.environ["NVTE_FUSED_ATTN_BACKEND"] = "1"
            fused_attn_fwd_1, fused_attn_bwd_1 = _run_dot_product_attention(
                dtype,
                config,
                "FusedAttention",
                ckpt_attn,
                qkv_layout,
                workspace_opt,
                pad_between_seqs,
                is_training,
            )

    # FlashAttention backend
    if flash_attn_supported:
        flash_attn_fwd, flash_attn_bwd = _run_dot_product_attention(
            dtype,
            config,
            "FlashAttention",
            ckpt_attn,
            qkv_layout,
            workspace_opt,
            pad_between_seqs,
            is_training,
        )

    logging.info(f"[test_dot_product_attention]: is_training = {is_training}")
    if unfused_attn_supported and flash_attn_supported:
        logging.info("[test_dot_product_attention]: unfused attn vs flash attn")
        torch.testing.assert_close(flash_attn_fwd, unfused_attn_fwd, **tols)
        for i, _ in enumerate(flash_attn_bwd):
            torch.testing.assert_close(unfused_attn_bwd[i], flash_attn_bwd[i], **tols)
    if unfused_attn_supported and fused_attn_supported:
        logging.info("[test_dot_product_attention]: unfused attn vs fused attn")
        torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, **tols)
        for i, _ in enumerate(unfused_attn_bwd):
            torch.testing.assert_close(fused_attn_bwd[i], unfused_attn_bwd[i], **tols)
    if fused_attn_supported and flash_attn_supported:
        logging.info("[test_dot_product_attention]: fused attn vs flash attn")
        torch.testing.assert_close(fused_attn_fwd, flash_attn_fwd, **tols)
        for i, _ in enumerate(flash_attn_bwd):
            torch.testing.assert_close(fused_attn_bwd[i], flash_attn_bwd[i], **tols)
    if fused_attn_supported and len(fused_attn_backends) == 2:
        logging.info("[test_dot_product_attention]: fused attn backend 0 vs 1")
        torch.testing.assert_close(fused_attn_fwd, fused_attn_fwd_1, **tols)
        for i, _ in enumerate(fused_attn_bwd):
            torch.testing.assert_close(fused_attn_bwd[i], fused_attn_bwd_1[i], **tols)


@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model_configs", [model_configs_base])
@pytest.mark.parametrize("model", ["base_1_1", "base_2_1"])
def test_dpa_checkpoint(dtype, model_configs, model):
    """Test DotProductAttention module with checkpointing"""
    test_dot_product_attention(dtype, model_configs, model, True, True, None, False, False)


model_configs_mla = {
    #    test:             b,  h, hg, dqk, sq, skv,   p,      mask,      bias   # attn , backend
    "mla_1_0": ModelConfig(8, 128, 16, 64, head_dim_v=128),  # self , 0
    "mla_1_1": ModelConfig(4, 128, 16, 64, max_seqlen_kv=256, head_dim_v=128),  # cross, 0
    "mla_1_2": ModelConfig(4, 128, 16, 192, max_seqlen_kv=256, head_dim_v=128),  # cross, 0
    "mla_2_0": ModelConfig(2, 2048, 24, 128, attn_mask_type="causal", head_dim_v=64),  # self , 1
    "mla_2_1": ModelConfig(
        1, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal", head_dim_v=64
    ),  # cross, 1
    "mla_2_2": ModelConfig(
        1, 2048, 24, 192, max_seqlen_kv=4096, attn_mask_type="causal", head_dim_v=128
    ),  # cross, 1
    "mla_3_0": ModelConfig(8, 1, 16, 128, max_seqlen_kv=2048, head_dim_v=64),  # inference
    "mla_3_1": ModelConfig(8, 1, 16, 256, max_seqlen_kv=2048, head_dim_v=128),  # inference
    "mla_3_2": ModelConfig(8, 1, 16, 192, max_seqlen_kv=2048, head_dim_v=128),  # inference
    "mla_3_3": ModelConfig(8, 1, 16, 160, max_seqlen_kv=2048, head_dim_v=128),  # inference
    "mla_3_4": ModelConfig(8, 1, 16, 160, max_seqlen_kv=2048, head_dim_v=160),  # inference
}


@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model_configs", [model_configs_mla])
@pytest.mark.parametrize("model", model_configs_mla.keys())
def test_dpa_mla(dtype, model_configs, model):
    """Test DotProductAttention module with Multi-Latent Attention (MLA)"""
    test_dot_product_attention(dtype, model_configs, model, True, True, None, False, False)


model_configs_mask = {
    #     test:             b,  h, hg,   d,   sq,  skv,   p,             mask,      bias
    "mask_1_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="causal"),
    "mask_1_1": ModelConfig(2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="causal"),
    "mask_1_2": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal"),
    "mask_2_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="causal_bottom_right"),
    "mask_2_1": ModelConfig(
        2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="causal_bottom_right"
    ),
    "mask_2_2": ModelConfig(
        2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal_bottom_right"
    ),
    "mask_3_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding"),
    "mask_3_1": ModelConfig(2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding"),
    "mask_3_2": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding"),
    "mask_4_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding_causal"),
    "mask_4_1": ModelConfig(2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding_causal"),
    "mask_4_2": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding_causal"),
    "mask_5_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding_causal_bottom_right"),
    "mask_5_1": ModelConfig(
        2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding_causal_bottom_right"
    ),
    "mask_5_2": ModelConfig(
        2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding_causal_bottom_right"
    ),
    "mask_6_0": ModelConfig(2, 1, 16, 128, max_seqlen_kv=2048, attn_mask_type="causal"),
    "mask_6_1": ModelConfig(2, 1, 16, 256, max_seqlen_kv=2048, attn_mask_type="causal"),
    "mask_7_0": ModelConfig(
        2, 1, 16, 128, max_seqlen_kv=2048, attn_mask_type="causal_bottom_right"
    ),
    "mask_7_1": ModelConfig(
        2, 1, 16, 256, max_seqlen_kv=2048, attn_mask_type="causal_bottom_right"
    ),
    "mask_8_0": ModelConfig(2, 1, 24, 128, max_seqlen_kv=2048, attn_mask_type="padding"),
    "mask_8_1": ModelConfig(2, 1, 16, 256, max_seqlen_kv=2048, attn_mask_type="padding"),
    "mask_9_0": ModelConfig(2, 1, 24, 128, max_seqlen_kv=2048, attn_mask_type="padding_causal"),
    "mask_9_1": ModelConfig(2, 1, 16, 256, max_seqlen_kv=2048, attn_mask_type="padding_causal"),
    "mask_10_0": ModelConfig(
        2, 1, 24, 128, max_seqlen_kv=2048, attn_mask_type="padding_causal_bottom_right"
    ),
    "mask_10_1": ModelConfig(
        2, 1, 16, 256, max_seqlen_kv=2048, attn_mask_type="padding_causal_bottom_right"
    ),
}


@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_mask])
@pytest.mark.parametrize("model", model_configs_mask.keys())
def test_dpa_mask(dtype, model_configs, model):
    """Test DotProductAttention module with different mask types"""
    test_dot_product_attention(dtype, model_configs, model, False, True, None, False, False)


model_configs_bias = {
    #     test:             b,  h, hg,   d,   sq,  skv,   p,             mask,             bias
    "bias_1_0": ModelConfig(4, 128, 16, 64, attn_bias_type="post_scale_bias"),
    "bias_1_1": ModelConfig(2, 128, 16, 64, max_seqlen_kv=256, attn_bias_type="post_scale_bias"),
    "bias_1_2": ModelConfig(4, 2048, 24, 128, attn_bias_type="post_scale_bias"),
    "bias_1_3": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_bias_type="post_scale_bias"),
    "bias_1_4": ModelConfig(4, 2048, 24, 128, attn_bias_type="alibi"),  # skipped
    "bias_1_5": ModelConfig(
        2, 2048, 24, 128, max_seqlen_kv=4096, attn_bias_type="alibi"
    ),  # skipped
    "bias_2_0": ModelConfig(
        4, 128, 16, 64, attn_mask_type="padding", attn_bias_type="post_scale_bias"
    ),  # skipped
    "bias_2_1": ModelConfig(
        2,
        128,
        16,
        64,
        max_seqlen_kv=256,
        attn_mask_type="padding",
        attn_bias_type="post_scale_bias",
    ),  # skipped
    "bias_2_2": ModelConfig(
        4, 2048, 24, 128, attn_mask_type="padding", attn_bias_type="post_scale_bias"
    ),  # skipped
    "bias_2_3": ModelConfig(
        2,
        2048,
        24,
        128,
        max_seqlen_kv=4096,
        attn_mask_type="padding",
        attn_bias_type="post_scale_bias",
    ),  # skipped
    "bias_2_4": ModelConfig(
        4, 2048, 24, 128, attn_mask_type="padding", attn_bias_type="alibi"
    ),  # skipped
    "bias_2_5": ModelConfig(
        2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding", attn_bias_type="alibi"
    ),  # skipped
    "bias_3_0": ModelConfig(
        4, 128, 16, 64, attn_mask_type="causal", attn_bias_type="post_scale_bias"
    ),
    "bias_3_1": ModelConfig(
        2, 128, 16, 64, max_seqlen_kv=256, attn_mask_type="causal", attn_bias_type="post_scale_bias"
    ),
    "bias_3_2": ModelConfig(
        4, 2048, 24, 128, attn_mask_type="causal", attn_bias_type="post_scale_bias"
    ),
    "bias_3_3": ModelConfig(
        2,
        2048,
        24,
        128,
        max_seqlen_kv=4096,
        attn_mask_type="causal",
        attn_bias_type="post_scale_bias",
    ),  # skipped
    "bias_3_4": ModelConfig(4, 2048, 24, 128, attn_mask_type="causal", attn_bias_type="alibi"),
    "bias_3_5": ModelConfig(
        2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal", attn_bias_type="alibi"
    ),  # skipped
    "bias_4_0": ModelConfig(
        4, 128, 16, 64, attn_mask_type="padding_causal", attn_bias_type="post_scale_bias"
    ),  # skipped
    "bias_4_1": ModelConfig(
        2,
        128,
        16,
        64,
        max_seqlen_kv=256,
        attn_mask_type="padding_causal",
        attn_bias_type="post_scale_bias",
    ),  # skipped
    "bias_4_2": ModelConfig(
        4, 2048, 24, 128, attn_mask_type="padding_causal", attn_bias_type="post_scale_bias"
    ),  # skipped
    "bias_4_3": ModelConfig(
        2,
        2048,
        24,
        128,
        max_seqlen_kv=4096,
        attn_mask_type="padding_causal",
        attn_bias_type="post_scale_bias",
    ),  # skipped
    "bias_4_4": ModelConfig(
        4, 2048, 24, 128, attn_mask_type="padding_causal", attn_bias_type="alibi"
    ),  # skipped
    "bias_4_5": ModelConfig(
        2,
        2048,
        24,
        128,
        max_seqlen_kv=4096,
        attn_mask_type="padding_causal",
        attn_bias_type="alibi",
    ),  # skipped
}


@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_bias])
@pytest.mark.parametrize("model", model_configs_bias.keys())
def test_dpa_bias(dtype, model_configs, model):
    """Test DotProductAttention module with different bias types"""
    test_dot_product_attention(dtype, model_configs, model, False, True, None, False, False)


model_configs_bias_shapes = {
    #     test:             b,  h, hg,   d,   sq,  skv,   p,
    "bias_1_0": ModelConfig(4, 128, 16, 64, attn_bias_type="post_scale_bias", bias_shape="11ss"),
    "bias_1_1": ModelConfig(2, 128, 16, 64, attn_bias_type="post_scale_bias", bias_shape="1hss"),
    "bias_1_2": ModelConfig(4, 2048, 24, 128, attn_bias_type="post_scale_bias", bias_shape="b1ss"),
    "bias_1_3": ModelConfig(2, 2048, 24, 128, attn_bias_type="post_scale_bias", bias_shape="bhss"),
    "bias_1_4": ModelConfig(
        4,
        2048,
        24,
        128,
        attn_mask_type="causal",
        attn_bias_type="alibi",
        bias_shape="1hss",
        alibi_type="custom",
    ),
    "bias_1_5": ModelConfig(
        2,
        2048,
        24,
        128,
        attn_mask_type="causal",
        attn_bias_type="alibi",
        bias_shape="bhss",
        alibi_type="custom",
    ),
}


@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_bias_shapes])
@pytest.mark.parametrize("model", model_configs_bias_shapes.keys())
def test_dpa_bias_shapes(dtype, model_configs, model):
    """Test DotProductAttention module with different bias types and shapes"""
    test_dot_product_attention(dtype, model_configs, model, False, True, None, False, False)


model_configs_swa = {
    #    test:             b,  h, hg,   d,   sq,  skv,   p,             mask,             bias
    "swa_1_1": ModelConfig(2, 2048, 16, 64),
    "swa_1_2": ModelConfig(2, 2048, 24, 128, num_gqa_groups=4),
    "swa_1_3": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096),
    "swa_2_1": ModelConfig(2, 2048, 16, 64, attn_mask_type="causal"),
    "swa_2_2": ModelConfig(2, 2048, 24, 128, num_gqa_groups=4, attn_mask_type="causal"),
    "swa_2_3": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal"),
    "swa_3_1": ModelConfig(2, 2048, 16, 64, attn_mask_type="causal_bottom_right"),
    "swa_3_2": ModelConfig(
        2, 2048, 24, 128, num_gqa_groups=4, attn_mask_type="causal_bottom_right"
    ),
    "swa_3_3": ModelConfig(
        2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal_bottom_right"
    ),
    "swa_4_1": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding"),
    "swa_4_2": ModelConfig(2, 2048, 24, 128, num_gqa_groups=4, attn_mask_type="padding"),
    "swa_4_3": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding"),
    "swa_5_1": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding_causal"),
    "swa_5_2": ModelConfig(2, 2048, 24, 128, num_gqa_groups=4, attn_mask_type="padding_causal"),
    "swa_5_3": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding_causal"),
    "swa_6_1": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding_causal_bottom_right"),
    "swa_6_2": ModelConfig(
        2, 2048, 24, 128, num_gqa_groups=4, attn_mask_type="padding_causal_bottom_right"
    ),
    "swa_6_3": ModelConfig(
        2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding_causal_bottom_right"
    ),
}


@pytest.mark.skipif(not FlashAttentionUtils.v2_3_plus, reason="Flash-attn 2.3+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_swa])
@pytest.mark.parametrize("model", model_configs_swa.keys())
def test_dpa_sliding_window(dtype, model_configs, model):
    """Test DotProductAttention module with sliding window attention"""
    test_dot_product_attention(dtype, model_configs, model, False, True, None, True, False)


model_configs_alibi_slopes = {
    #     test:             b,  h, hg,   d,   sq,  skv,   p,      mask,    bias, alibi_type
    "alibi_1_0": ModelConfig(
        2, 128, 16, 64, attn_mask_type="causal", attn_bias_type="alibi", alibi_type="vanilla"
    ),
    "alibi_1_1": ModelConfig(
        1,
        128,
        16,
        64,
        max_seqlen_kv=256,
        attn_mask_type="causal",
        attn_bias_type="alibi",
        alibi_type="vanilla",
    ),
    "alibi_2_0": ModelConfig(
        2, 1024, 24, 128, attn_mask_type="causal", attn_bias_type="alibi", alibi_type="custom"
    ),
    "alibi_2_1": ModelConfig(
        1,
        1024,
        24,
        128,
        max_seqlen_kv=2048,
        attn_mask_type="causal",
        attn_bias_type="alibi",
        alibi_type="custom",
    ),
}


@pytest.mark.skipif(not FlashAttentionUtils.v2_3_plus, reason="Flash-attn 2.3+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_alibi_slopes])
@pytest.mark.parametrize("model", model_configs_alibi_slopes.keys())
def test_dpa_alibi_slopes(dtype, model_configs, model):
    """Test DotProductAttention module with ALiBi slopes"""
    test_dot_product_attention(dtype, model_configs, model, False, True, None, False, False)


qkv_layouts = [
    "sb3hd",
    "sbh3d",
    "sbhd_sb2hd",
    "sbhd_sbh2d",
    "sbhd_sbhd_sbhd",
    "bs3hd",
    "bsh3d",
    "bshd_bs2hd",
    "bshd_bsh2d",
    "bshd_bshd_bshd",
]
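# Layout notation (sketch): "sb3hd" packs Q/K/V into one [s, b, 3, h, d] tensor, "sbhd_sb2hd"
# keeps Q separate from a packed KV tensor, and "sbhd_sbhd_sbhd" uses three separate tensors;
# the "bs*" and "t*" variants follow the same pattern for batch-first and packed (THD) formats.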


model_configs_layout = {
    #       test:             b,  h, hg,   d,   sq,  skv,   p,             mask,             bias
    "layout_0_0": ModelConfig(2, 128, 16, 64),
    "layout_0_1": ModelConfig(
        2, 128, 16, 64, attn_mask_type="causal", attn_bias_type="post_scale_bias"
    ),
    "layout_0_2": ModelConfig(1, 128, 16, 64, max_seqlen_kv=256, attn_mask_type="padding"),
    "layout_0_3": ModelConfig(
        1,
        128,
        16,
        64,
        max_seqlen_kv=256,
        attn_mask_type="padding_causal",
        attn_bias_type="post_scale_bias",
    ),
    "layout_1_0": ModelConfig(2, 2048, 24, 128),
    "layout_1_1": ModelConfig(
        2, 2048, 24, 128, attn_mask_type="causal", attn_bias_type="post_scale_bias"
    ),
    "layout_1_2": ModelConfig(1, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding"),
    "layout_1_3": ModelConfig(
        1,
        2048,
        24,
        128,
        max_seqlen_kv=4096,
        attn_mask_type="padding_causal",
        attn_bias_type="post_scale_bias",
    ),
    "layout_2_0": ModelConfig(2, 1, 16, 256, max_seqlen_kv=2048),
    "layout_2_1": ModelConfig(
        2, 2048, 24, 256, attn_mask_type="causal", attn_bias_type="post_scale_bias"
    ),
}


@pytest.mark.skipif(get_cudnn_version() < (8, 9, 5), reason="cuDNN 8.9.5+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_layout])
@pytest.mark.parametrize("model", model_configs_layout.keys())
@pytest.mark.parametrize("qkv_layout", qkv_layouts)
def test_dpa_qkv_layout(dtype, model_configs, model, qkv_layout):
    """Test DotProductAttention module with different QKV layouts"""
    test_dot_product_attention(dtype, model_configs, model, False, True, qkv_layout, False, False)


qkv_layouts_thd = ["t3hd", "th3d", "thd_t2hd", "thd_th2d", "thd_thd_thd"]
model_configs_layout_thd = {
    #       test:             b,  h, hg,   d,   sq,  skv,   p,             mask,             bias
    "layout_0_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding"),
    "layout_0_1": ModelConfig(2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding"),
    "layout_0_2": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding"),
    "layout_1_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding_causal"),
    "layout_1_1": ModelConfig(2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding_causal"),
    "layout_1_2": ModelConfig(
        2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding_causal"
    ),
    "layout_2_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding_causal_bottom_right"),
    "layout_2_1": ModelConfig(
        2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding_causal_bottom_right"
    ),
    "layout_2_2": ModelConfig(
        2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding_causal_bottom_right"
    ),
    "layout_3_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding", window_size=(4, 4)),
    "layout_3_1": ModelConfig(
        2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding", window_size=(4, 4)
    ),
    "layout_3_2": ModelConfig(
        2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding", window_size=(4, 4)
    ),
    "layout_4_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding_causal", window_size=(4, 0)),
    "layout_4_1": ModelConfig(
        2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding_causal", window_size=(4, 0)
    ),
    "layout_4_2": ModelConfig(
        2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding_causal", window_size=(4, 0)
    ),
    "layout_5_0": ModelConfig(
        2, 2048, 16, 64, attn_mask_type="padding_causal_bottom_right", window_size=(4, 0)
    ),
    "layout_5_1": ModelConfig(
        2,
        2048,
        24,
        128,
        num_gqa_groups=1,
        attn_mask_type="padding_causal_bottom_right",
        window_size=(4, 0),
    ),
    "layout_5_2": ModelConfig(
        2,
        2048,
        24,
        128,
        max_seqlen_kv=4096,
        attn_mask_type="padding_causal_bottom_right",
        window_size=(4, 0),
    ),
}


@pytest.mark.skipif(get_cudnn_version() < (9, 0, 0), reason="cuDNN 9.0.0+ is required.")
@pytest.mark.skipif(
    get_device_compute_capability() < (9, 0), reason="THD is only supported on Hopper+."
)
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_layout_thd])
@pytest.mark.parametrize("model", model_configs_layout_thd.keys())
@pytest.mark.parametrize("qkv_layout", qkv_layouts_thd)
def test_dpa_qkv_layout_thd(dtype, model_configs, model, qkv_layout):
    """Test DotProductAttention module with different QKV layouts"""
    config = model_configs[model]
    if config.num_heads != config.num_gqa_groups and "3" in qkv_layout:
        pytest.skip("qkv_layout not applicable for MQA/GQA")
    logging.info("[test_dpa_qkv_layout_thd]: pad_between_seqs = True")
    pad_between_seqs = True
    test_dot_product_attention(
        dtype, model_configs, model, False, True, qkv_layout, False, pad_between_seqs
    )
    if get_cudnn_version() >= (9, 3, 0):
        logging.info("[test_dpa_qkv_layout_thd]: pad_between_seqs = False")
        # cuDNN 9.3.0+ is required to run pad_between_seqs = False/True in the same run
        pad_between_seqs = False
        test_dot_product_attention(
            dtype, model_configs, model, False, True, qkv_layout, False, pad_between_seqs
        )


def _run_dot_product_attention(
    dtype: torch.dtype,
    config: ModelConfig,
    backend: str,
    ckpt_attn: bool,
    qkv_layout: str,
    workspace_opt: bool,
    pad_between_seqs: bool,
    is_training: bool,
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
    """Run DotProductAttention module with one forward pass and one backward pass"""

    # Set RNG and environment variables
    reset_rng_states()
    os.environ["NVTE_FLASH_ATTN"] = "0"
    os.environ["NVTE_FUSED_ATTN"] = "0"
    if backend == "FlashAttention":
        os.environ["NVTE_FLASH_ATTN"] = "1"
    if backend == "FusedAttention":
        os.environ["NVTE_FUSED_ATTN"] = "1"
        os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] = "1" if workspace_opt else "0"
    _attention_backends["backend_selection_requires_update"] = True
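    # Ask TE to redo backend selection with the environment variables set above rather than
    # reusing a previously cached backend choice.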

    # Create seqlens
    qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()])
    if "padding" in config.attn_mask_type or qkv_format == "thd":
        if config.attn_type == "self":
            seqlens_q = torch.randint(
                1, config.max_seqlen_q, [config.batch_size], dtype=torch.int32, device="cuda"
            )
            seqlens_kv = seqlens_q
        if config.attn_type == "cross":
            if config.max_seqlen_q > 1:
                seqlens_q = torch.randint(
                    1, config.max_seqlen_q, [config.batch_size], dtype=torch.int32, device="cuda"
                )
            else:
                seqlens_q = torch.ones([config.batch_size], dtype=torch.int32, device="cuda")
            seqlens_kv = torch.randint(
                1, config.max_seqlen_kv, [config.batch_size], dtype=torch.int32, device="cuda"
            )
    else:
        seqlens_q = torch.full(
            [config.batch_size], config.max_seqlen_q, dtype=torch.int32, device="cuda"
        )
        seqlens_kv = torch.full(
            [config.batch_size], config.max_seqlen_kv, dtype=torch.int32, device="cuda"
        )
    cu_seqlens_q = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda")
    cu_seqlens_kv = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda")
    cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)
    cu_seqlens_kv[1:] = torch.cumsum(seqlens_kv, dim=0)
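    # Note: cu_seqlens_* are cumulative sequence lengths, e.g. seqlens [3, 5, 2] -> [0, 3, 8, 10];
    # the fused/flash paths use them to locate each sequence inside the packed tensors.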

    seqlens_q_after_pad = seqlens_q.clone()
    seqlens_kv_after_pad = seqlens_kv.clone()
    cu_seqlens_q_after_pad = cu_seqlens_q.clone()
    cu_seqlens_kv_after_pad = cu_seqlens_kv.clone()
    pad_len = [0] * config.batch_size
    if pad_between_seqs:
        max_pad_len = 3
        pad_len = torch.randint(0, max_pad_len + 1, [config.batch_size], device="cuda")  # 3
        seqlens_q_after_pad = seqlens_q + pad_len
        seqlens_kv_after_pad = seqlens_kv + pad_len
        cu_seqlens_q_after_pad[1:] = torch.cumsum(seqlens_q_after_pad, dim=0)
        cu_seqlens_kv_after_pad[1:] = torch.cumsum(seqlens_kv_after_pad, dim=0)

    # Create attention mask if padding
    attention_mask = None
    if "padding" in config.attn_mask_type:
        if config.attn_type == "self":
            attention_mask_q = torch.Tensor([]).to(dtype=torch.bool)
            for i in range(config.batch_size):
                attention_mask_q = torch.cat(
                    [
                        attention_mask_q,
                        torch.Tensor(
                            [False] * seqlens_q[i] + [True] * (config.max_seqlen_q - seqlens_q[i])
                        )
                        .to(dtype=torch.bool)
                        .unsqueeze(0)
                        .unsqueeze(0)
                        .unsqueeze(0),
                    ],
                    dim=0,
                )
            attention_mask = attention_mask_q.to(device="cuda")
        if config.attn_type == "cross":
            attention_mask_q = torch.Tensor([]).to(dtype=torch.bool)
            attention_mask_kv = torch.Tensor([]).to(dtype=torch.bool)
            for i in range(config.batch_size):
                attention_mask_q = torch.cat(
                    [
                        attention_mask_q,
                        torch.Tensor(
                            [False] * seqlens_q[i] + [True] * (config.max_seqlen_q - seqlens_q[i])
                        )
                        .to(dtype=torch.bool)
                        .unsqueeze(0)
                        .unsqueeze(0)
                        .unsqueeze(0),
                    ],
                    dim=0,
                )
                attention_mask_kv = torch.cat(
                    [
                        attention_mask_kv,
                        torch.Tensor(
                            [False] * seqlens_kv[i]
                            + [True] * (config.max_seqlen_kv - seqlens_kv[i])
                        )
                        .to(dtype=torch.bool)
                        .unsqueeze(0)
                        .unsqueeze(0)
                        .unsqueeze(0),
                    ],
                    dim=0,
                )
            attention_mask = (
                attention_mask_q.to(device="cuda"),
                attention_mask_kv.to(device="cuda"),
            )

    alibi_slopes = None
    if config.attn_bias_type == "alibi" and config.alibi_type == "custom":
        if config.bias_shape == "1hss":
            alibi_slopes = (
                torch.randn(config.num_heads).abs().to(dtype=torch.float32, device="cuda")
            )
        if config.bias_shape == "bhss":
            alibi_slopes = (
                torch.randn(config.batch_size, config.num_heads)
                .abs()
                .to(dtype=torch.float32, device="cuda")
            )

    # Create input tensors
    dim_to_num = {
        "b": config.batch_size,
        "sq": config.max_seqlen_q,
        "skv": config.max_seqlen_kv,
        "h": config.num_heads,
        "hg": config.num_gqa_groups,
        "dqk": config.head_dim_qk,
        "dv": config.head_dim_v,
        "t": cu_seqlens_q_after_pad[-1],
        "tg": cu_seqlens_kv_after_pad[-1],
        "3": 3,
        "2": 2,
        "1": 1,
    }
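    # Each qkv_layout segment (e.g. "sb3hd", "bshd") is expanded below into per-dimension keys
    # ("sq"/"skv", "h"/"hg", "dqk"/"dv", ...) and mapped through dim_to_num to build tensor
    # shapes; digit characters ("3"/"2") mark packed QKV/KV dimensions that are split apart.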
    inp = []
    inp_orig = []
    for i, layout in enumerate(qkv_layout.split("_")):
        layout = "_".join(layout)
        if i == 0:
            layout = layout.replace("s", "sq")
        else:
            layout = layout.replace("s", "skv")
            layout = layout.replace("h", "hg")
            layout = layout.replace("t", "tg")
        if i == 2:
            layout = layout.replace("d", "dv")
        else:
            layout = layout.replace("d", "dqk")
        tensor_shape = [dim_to_num[j] for j in layout.split("_")]
        tensor = 0.1 * torch.randn(tensor_shape, dtype=dtype, device="cuda")
        tensor_orig = tensor
        if qkv_format == "thd" and pad_between_seqs:
            tensor_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
            if layout in ["t_h_dqk", "t_3_h_dqk", "t_h_3_dqk"]:
                for i in range(1, config.batch_size + 1):
                    valid_range = (
                        cu_seqlens_q_after_pad[i - 1],
                        cu_seqlens_q_after_pad[i] - pad_len[i - 1],
                    )
                    pad_range = (
                        cu_seqlens_q_after_pad[i] - pad_len[i - 1],
                        cu_seqlens_q_after_pad[i],
                    )
                    tensor[pad_range[0] : pad_range[1]] = 0.0
                    tensor_orig = torch.cat(
                        [tensor_orig, tensor[valid_range[0] : valid_range[1]]], dim=0
                    )
            if layout in ["tg_hg_dqk", "tg_2_hg_dqk", "tg_hg_2_dqk", "tg_hg_dv"]:
                for i in range(1, config.batch_size + 1):
                    valid_range = (
                        cu_seqlens_kv_after_pad[i - 1],
                        cu_seqlens_kv_after_pad[i] - pad_len[i - 1],
                    )
                    pad_range = (
                        cu_seqlens_kv_after_pad[i] - pad_len[i - 1],
                        cu_seqlens_kv_after_pad[i],
                    )
                    tensor[pad_range[0] : pad_range[1]] = 0.0
                    tensor_orig = torch.cat(
                        [tensor_orig, tensor[valid_range[0] : valid_range[1]]], dim=0
                    )
        tensor_count = 1
        split_dim = 0
        for dim, l in enumerate(layout.split("_")):
            if l.isdigit():
                tensor_count = int(l)
                split_dim = dim
                break
        tensors = torch.split(tensor, 1, dim=split_dim) if split_dim != 0 else [tensor]
        tensors_orig = (
            torch.split(tensor_orig, 1, dim=split_dim) if split_dim != 0 else [tensor_orig]
        )
        for j in range(tensor_count):
            if split_dim != 0:
                inp.append(tensors[j].squeeze(split_dim))
                inp_orig.append(tensors_orig[j].squeeze(split_dim))
            else:
                inp.append(tensors[j])
                inp_orig.append(tensors_orig[j])
    for i in range(3):
        inp[i].requires_grad = True
        inp_orig[i].requires_grad = True

    # Create output gradient
    qkv_format_kv = "_".join(qkv_format)
    qkv_format_kv = qkv_format_kv.replace("s", "sq")
    qkv_format_kv = qkv_format_kv.replace("d", "dv")
    out_grad_shape = [dim_to_num[i] for i in qkv_format_kv.split("_")]
    out_grad_shape_new = [*out_grad_shape[:-2], out_grad_shape[-2] * out_grad_shape[-1]]
    out_grad = 0.001 * torch.randint(0, 200, out_grad_shape_new, dtype=dtype, device="cuda")
    out_grad_orig = out_grad
    if qkv_format == "thd" and pad_between_seqs:
        out_grad_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
        if qkv_format_kv == "t_h_dv":
            for i in range(1, config.batch_size + 1):
                valid_range = (
                    cu_seqlens_q_after_pad[i - 1],
                    cu_seqlens_q_after_pad[i] - pad_len[i - 1],
                )
                pad_range = (cu_seqlens_q_after_pad[i] - pad_len[i - 1], cu_seqlens_q_after_pad[i])
                out_grad[pad_range[0] : pad_range[1]] = 0.0
                out_grad_orig = torch.cat(
                    [out_grad_orig, out_grad[valid_range[0] : valid_range[1]]], dim=0
                )

    # Create bias
    if config.attn_bias_type in ["no_bias", "alibi"]:
        bias = None
    if config.attn_bias_type == "post_scale_bias":
        shape = "_".join(config.bias_shape)
        shape = shape.replace("_s_s", "_sq_skv")
        tensor_shape = [dim_to_num[j] for j in shape.split("_")]
        bias = torch.randn(tensor_shape, dtype=dtype, device="cuda")
        if config.bias_shape != "1hss":
            bias.requires_grad = False

    # Create RNG
    _DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
    _DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed)

    def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
        """Get cuda rng tracker."""
        return _DUMMY_CUDA_RNG_STATE_TRACKER
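    # The shared tracker keeps dropout RNG state consistent so that, when dropout_p > 0.0, the
    # backends under comparison see the same random pattern.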

    # Set up model
    block = DotProductAttention(
        config.num_heads,
        (config.head_dim_qk, config.head_dim_v),
        num_gqa_groups=config.num_gqa_groups,
        attention_dropout=config.dropout_p,
        qkv_format=qkv_format,
        attn_mask_type=config.attn_mask_type,
        sequence_parallel=False,
        tp_size=1,
        get_rng_state_tracker=get_dummy_cuda_rng_tracker,
        tp_group=None,
        layer_number=1,
        attention_type=config.attn_type,
    ).to(dtype=dtype, device="cuda")
    if not is_training:
        block = block.eval()

    # Run a forward and backward pass
    if backend in ["FlashAttention", "UnfusedDotProductAttention"]:
        q = inp_orig[0]
        k = inp_orig[1]
        v = inp_orig[2]
        d_out = out_grad_orig
    if backend == "FusedAttention":
        q = inp[0]
        k = inp[1]
        v = inp[2]
        d_out = out_grad
    out = block(
        q,
        k,
        v,
        window_size=config.window_size,
        attention_mask=attention_mask,
        qkv_format=qkv_format,
        max_seqlen_q=config.max_seqlen_q,
        max_seqlen_kv=config.max_seqlen_kv,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_kv=cu_seqlens_kv,
        cu_seqlens_q_padded=cu_seqlens_q_after_pad if backend == "FusedAttention" else None,
        cu_seqlens_kv_padded=cu_seqlens_kv_after_pad if backend == "FusedAttention" else None,
        attn_mask_type=config.attn_mask_type,
        checkpoint_core_attention=ckpt_attn,
        core_attention_bias_type=config.attn_bias_type,
        core_attention_bias=bias,
        alibi_slopes=alibi_slopes,
        fast_zero_fill=True,
    )
    if is_training:
        out.backward(d_out)

    if backend in ["FlashAttention", "UnfusedDotProductAttention"]:
        if is_training:
            return out, (q.grad, k.grad, v.grad)
        else:
            return out, (None, None, None)
    if backend == "FusedAttention":
        if qkv_format == "thd" and pad_between_seqs:
            out_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
            if is_training:
                q_grad_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
                k_grad_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
                v_grad_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
            for i in range(1, config.batch_size + 1):
                valid_range_q = (
                    cu_seqlens_q_after_pad[i - 1],
                    cu_seqlens_q_after_pad[i] - pad_len[i - 1],
                )
                valid_range_kv = (
                    cu_seqlens_kv_after_pad[i - 1],
                    cu_seqlens_kv_after_pad[i] - pad_len[i - 1],
                )
                out_orig = torch.cat([out_orig, out[valid_range_q[0] : valid_range_q[1]]], dim=0)
                if is_training:
                    q_grad_orig = torch.cat(
                        [q_grad_orig, q.grad[valid_range_q[0] : valid_range_q[1]]], dim=0
                    )
                    k_grad_orig = torch.cat(
                        [k_grad_orig, k.grad[valid_range_kv[0] : valid_range_kv[1]]], dim=0
                    )
                    v_grad_orig = torch.cat(
                        [v_grad_orig, v.grad[valid_range_kv[0] : valid_range_kv[1]]], dim=0
                    )
            if is_training:
                return out_orig, (q_grad_orig, k_grad_orig, v_grad_orig)
            else:
                return out_orig, (None, None, None)
        else:
            if is_training:
                return out, (q.grad, k.grad, v.grad)
            else:
                return out, (None, None, None)


model_configs_te_layer = {
    #   test:             b,  h, hg,   d,   sq,  skv,   p,      mask,             bias
    "te_1_0": ModelConfig(2, 128, 16, 64, attn_bias_type="post_scale_bias"),
    "te_1_1": ModelConfig(
        4, 128, 16, 64, attn_mask_type="causal", attn_bias_type="post_scale_bias"
    ),
    "te_1_2": ModelConfig(
        2, 128, 16, 64, attn_mask_type="padding", attn_bias_type="post_scale_bias"
    ),
    "te_1_3": ModelConfig(2, 128, 16, 64, max_seqlen_kv=256, attn_mask_type="padding"),
    "te_2_0": ModelConfig(1, 2048, 16, 64, attn_mask_type="causal"),
    "te_2_1": ModelConfig(2, 2048, 16, 64),
    "te_2_2": ModelConfig(1, 2048, 16, 64, attn_mask_type="padding"),
    "te_2_3": ModelConfig(
        1, 2048, 16, 64, max_seqlen_kv=4096, attn_mask_type="padding_causal_bottom_right"
    ),
    "te_3_0": ModelConfig(4, 128, 16, 64, attn_mask_type="causal", attn_bias_type="alibi"),
    "te_3_1": ModelConfig(4, 2048, 16, 64, attn_mask_type="causal", attn_bias_type="alibi"),
}


@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model_configs", [model_configs_te_layer])
@pytest.mark.parametrize("model", model_configs_te_layer.keys())
@pytest.mark.parametrize("ckpt_attn", [False])
@pytest.mark.parametrize("qkv_format", ["sbhd", "bshd", "thd"])
@pytest.mark.parametrize("fused_qkv_params", [False])
@pytest.mark.parametrize("RoPE", [False])
def test_transformer_layer(
    dtype, model_configs, model, ckpt_attn, qkv_format, fused_qkv_params, RoPE
):
    """Test TransformerLayer module"""

    # Get configs
    config = model_configs[model]
    tols = dict(atol=5e-2, rtol=5e-2)
    workspace_opt = True

    # Test backend availability
    is_training = True
    available_backends, _, fused_attn_backends = get_available_attention_backends(
        config,
        qkv_dtype=dtype,
        qkv_layout=(
            qkv_format.replace("hd", "h3d") if fused_qkv_params else qkv_format.replace("hd", "3hd")
        ),
        is_training=is_training,
    )
    flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
    if not fused_attn_supported:
        is_training = False
        available_backends, _, fused_attn_backends = get_available_attention_backends(
            config,
            qkv_dtype=dtype,
            qkv_layout=(
                qkv_format.replace("hd", "h3d")
                if fused_qkv_params
                else qkv_format.replace("hd", "3hd")
            ),
            is_training=is_training,
        )
        flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends

    # Skip if only unfused backend is supported
    if (len(fused_attn_backends) + flash_attn_supported + unfused_attn_supported) < 2:
        pytest.skip("Less than two backends to compare.")
    # Skip if qkv_format = thd and "padding" not in attn_mask_type
    if qkv_format == "thd" and "padding" not in config.attn_mask_type:
        pytest.skip("THD requires padding mask.")

    # UnfusedDotProductAttention backend
    if unfused_attn_supported:
        unfused_attn_fwd, unfused_attn_bwd = _run_transformer_layer(
            dtype,
            config,
            "UnfusedDotProductAttention",
            ckpt_attn,
            qkv_format,
            workspace_opt,
            fused_qkv_params,
            RoPE,
            is_training,
        )

    # FusedAttention backend
    if fused_attn_supported:
        fused_attn_fwd, fused_attn_bwd = _run_transformer_layer(
            dtype,
            config,
            "FusedAttention",
            ckpt_attn,
            qkv_format,
            workspace_opt,
            fused_qkv_params,
            RoPE,
            is_training,
        )

    # FlashAttention backend
    if flash_attn_supported:
        flash_attn_fwd, flash_attn_bwd = _run_transformer_layer(
            dtype,
            config,
            "FlashAttention",
            ckpt_attn,
            qkv_format,
            workspace_opt,
            fused_qkv_params,
            RoPE,
            is_training,
        )

    logging.info(f"[test_transformer_layer]: is_training = {is_training}")
    if unfused_attn_supported and fused_attn_supported:
        logging.info("[test_transformer_layer]: unfused attn vs fused attn")
        torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, **tols)
        torch.testing.assert_close(fused_attn_bwd, unfused_attn_bwd, **tols)
    if unfused_attn_supported and flash_attn_supported:
        logging.info("[test_transformer_layer]: unfused attn vs flash attn")
        torch.testing.assert_close(flash_attn_fwd, unfused_attn_fwd, **tols)
        torch.testing.assert_close(flash_attn_bwd, unfused_attn_bwd, **tols)
    if fused_attn_supported and flash_attn_supported:
        logging.info("[test_transformer_layer]: fused attn vs flash attn")
        torch.testing.assert_close(fused_attn_fwd, flash_attn_fwd, **tols)
        torch.testing.assert_close(fused_attn_bwd, flash_attn_bwd, **tols)


@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_te_layer])
@pytest.mark.parametrize("model", ["te_1_2", "te_2_0"])
@pytest.mark.parametrize("qkv_format", ["bshd", "sbhd"])
def test_te_layer_misc(dtype, model_configs, model, qkv_format):
    """Test TransformerLayer module with miscellaneous settings"""
    ckpt_attn = True
    fused_qkv_params = True
    RoPE = True
    test_transformer_layer(
        dtype, model_configs, model, ckpt_attn, qkv_format, fused_qkv_params, RoPE
    )


@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_te_layer])
@pytest.mark.parametrize("model", ["te_2_0", "te_2_1", "te_2_2"])
def test_te_layer_mqa_gqa(dtype, model_configs, model):
    """Test TransformerLayer module with MQA/GQA"""

    def find_factors(x):
        f = []
        for i in range(2, x + 1):
            if x % i == 0:
                f.append(i)
        return f
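    # e.g. find_factors(16) == [2, 4, 8, 16]; each factor is used as the number of query heads
    # per GQA group, so num_gqa_groups below sweeps over every divisor of num_heads.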

    ckpt_attn = True
    qkv_format = "bshd"
    fused_qkv_params = True
    RoPE = True
    config = model_configs[model]
    num_querys_per_gqa_group = find_factors(config.num_heads)

    for num_q_per_gqa_group in num_querys_per_gqa_group:
        config.num_gqa_groups = config.num_heads // num_q_per_gqa_group
        test_transformer_layer(
            dtype, model_configs, model, ckpt_attn, qkv_format, fused_qkv_params, RoPE
        )


def _run_transformer_layer(
    dtype: torch.dtype,
    config: ModelConfig,
    backend: str,
    ckpt_attn: bool,
    qkv_format: str,
    workspace_opt: bool,
    fused_qkv_params: bool,
    RoPE: bool,
    is_training: bool,
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
    """Run TransformerLayer module with one forward pass and one backward pass"""

    # Set RNG and environment variables
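    # The requested backend is forced via NVTE_FLASH_ATTN/NVTE_FUSED_ATTN, and the cached
    # backend selection is flagged for re-evaluation on the next forward pass.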
    reset_rng_states()
    os.environ["NVTE_FLASH_ATTN"] = "0"
    os.environ["NVTE_FUSED_ATTN"] = "0"
    if backend == "FlashAttention":
        os.environ["NVTE_FLASH_ATTN"] = "1"
    if backend == "FusedAttention":
        os.environ["NVTE_FUSED_ATTN"] = "1"
    _attention_backends["backend_selection_requires_update"] = True

    # Create input tensor
    if qkv_format == "sbhd":
        inp = torch.randn(
            config.max_seqlen_q,
            config.batch_size,
            config.hidden_size,
            dtype=dtype,
            device="cuda",
            requires_grad=True,
        )
        inp_enc = torch.randn(
            config.max_seqlen_kv,
            config.batch_size,
            config.hidden_size,
            dtype=dtype,
            device="cuda",
            requires_grad=True,
        )
    if qkv_format == "bshd":
        inp = torch.randn(
            config.batch_size,
            config.max_seqlen_q,
            config.hidden_size,
            dtype=dtype,
            device="cuda",
            requires_grad=True,
        )
        inp_enc = torch.randn(
            config.batch_size,
            config.max_seqlen_kv,
            config.hidden_size,
            dtype=dtype,
            device="cuda",
            requires_grad=True,
        )

    # Create seqlens
    if "padding" in config.attn_mask_type or qkv_format == "thd":
        if config.attn_type == "self":
            seqlens_q = torch.randint(
                1, config.max_seqlen_q, [config.batch_size], dtype=torch.int32, device="cuda"
            )
            seqlens_kv = seqlens_q
        if config.attn_type == "cross":
            if config.max_seqlen_q > 1:
                seqlens_q = torch.randint(
                    1, config.max_seqlen_q, [config.batch_size], dtype=torch.int32, device="cuda"
                )
            else:
                seqlens_q = torch.ones([config.batch_size], dtype=torch.int32, device="cuda")
            seqlens_kv = torch.randint(
                1, config.max_seqlen_kv, [config.batch_size], dtype=torch.int32, device="cuda"
            )
    else:
        seqlens_q = torch.full(
            [config.batch_size], config.max_seqlen_q, dtype=torch.int32, device="cuda"
        )
        seqlens_kv = torch.full(
            [config.batch_size], config.max_seqlen_kv, dtype=torch.int32, device="cuda"
        )
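    # cu_seqlens_* hold cumulative sequence lengths ([0, s_0, s_0 + s_1, ...]), the form
    # expected by the padding/THD code paths.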
    cu_seqlens_q = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda")
    cu_seqlens_kv = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda")
    cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)
    cu_seqlens_kv[1:] = torch.cumsum(seqlens_kv, dim=0)
    if qkv_format == "thd":
        inp = torch.randn(
            cu_seqlens_q[-1],
            config.hidden_size,
            dtype=dtype,
            device="cuda",
            requires_grad=True,
        )
        inp_enc = torch.randn(
            cu_seqlens_kv[-1],
            config.hidden_size,
            dtype=dtype,
            device="cuda",
            requires_grad=True,
        )

    sigma = 0.02
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    layer_number = 1
    drop_path_rate = 0.0
    drop_path_rates = [rate.item() for rate in torch.linspace(0, drop_path_rate, config.num_layers)]

    # Create bias
    bias = None
    if config.attn_bias_type == "post_scale_bias":
        bias = torch.randn(
            1,
            config.num_heads,
            config.max_seqlen_q,
            config.max_seqlen_kv,
            dtype=dtype,
            device="cuda",
        )

    # Create RoPE
    rotary_pos_emb = None
    if RoPE:
        PE = RotaryPositionEmbedding(dim=config.head_dim_qk)
        rotary_pos_emb = PE(config.max_seqlen_q).to(device="cuda")

    # Set up model
    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_heads,
        num_gqa_groups=config.num_gqa_groups,
        layernorm_epsilon=1e-5,
        hidden_dropout=0.0,
        attention_dropout=config.dropout_p,
        init_method=init_method,
        output_layer_init_method=output_layer_init_method,
        layer_number=layer_number,
        kv_channels=config.head_dim_qk,
        self_attn_mask_type=config.attn_mask_type,
        tp_group=None,
        tp_size=1,
        params_dtype=dtype,
        get_rng_state_tracker=None,
        fuse_wgrad_accumulation=False,
        seq_length=config.max_seqlen_q,
        micro_batch_size=config.batch_size,
        sequence_parallel=False,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
        layer_type="encoder" if config.attn_type == "self" else "decoder",
        drop_path_rate=drop_path_rates[layer_number - 1],
        set_parallel_mode=True,
        fuse_qkv_params=fused_qkv_params,
        zero_centered_gamma=False,
        qkv_weight_interleaved=False,
        ub_tp_comm_overlap=False,
        bias=True,
        attn_input_format=qkv_format,
    ).to(dtype=dtype, device="cuda")
    if not is_training:
        block = block.eval()

    # Create ALiBi slopes
    alibi_slopes = None
    if config.attn_bias_type == "alibi" and config.alibi_type == "custom":
        alibi_slopes = torch.randn(config.num_heads).abs().to(dtype=torch.float32, device="cuda")

    # Run a forward and backward pass
    out = block(
        inp,
        self_attn_mask_type=config.attn_mask_type,
        encoder_output=inp_enc if config.attn_type == "cross" else None,
        enc_dec_attn_mask_type=config.attn_mask_type if config.attn_type == "cross" else None,
        checkpoint_core_attention=False,
        rotary_pos_emb=rotary_pos_emb,
        core_attention_bias_type=config.attn_bias_type,
        core_attention_bias=bias,
        alibi_slopes=alibi_slopes,
        max_seqlen_q=config.max_seqlen_q,
        max_seqlen_kv=config.max_seqlen_kv,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_kv=cu_seqlens_kv,
    )
    if is_training:
        loss = out.sum()
        loss.backward()

    return out, inp.grad


model_configs_fp8_extra_state = {
    "large": ModelConfig(2, 128, 4, 128, num_layers=1),
}


@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper.")
@pytest.mark.skipif(get_cudnn_version() < (9, 3, 0), reason="cuDNN 9.3.0+ is required.")
@pytest.mark.parametrize("model", ["large"])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_sanity_attention_extra_state(model, dtype):
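    """Test that attention extra state (FP8 metadata) survives a state_dict save/load."""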
    config = model_configs_fp8_extra_state[model]
    # Test backend availability
    is_training = True
    available_backends, _, fused_attn_backends = get_available_attention_backends(
        config,
        qkv_dtype=torch.float8_e4m3fn,
        qkv_layout="sb3hd",
        is_training=is_training,
    )
    flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
    if not fused_attn_supported and not flash_attn_supported:
        pytest.skip("No attention backend available.")

    outputs = _run_attention_extra_state(dtype, config, checkpoint=False)
    outputs_checkpoint = _run_attention_extra_state(dtype, config, checkpoint=True)
    outputs_checkpoint_v1_6 = _run_attention_extra_state(
        dtype, config, mimic_v1_6=True, checkpoint=True
    )

    # Check that results match
    tols = dtype_tols(dtype)
    if dtype in (torch.float16, torch.bfloat16):
        tols.update(dict(rtol=2e-2, atol=2e-3))
    for i, (ref, test) in enumerate(zip(outputs, outputs_checkpoint)):
        torch.testing.assert_close(
            test,
            ref,
            **tols,
        )
    for i, (ref, test) in enumerate(zip(outputs, outputs_checkpoint_v1_6)):
        torch.testing.assert_close(
            test,
            ref,
            **tols,
        )


def _run_attention_extra_state(dtype, config, checkpoint=False, mimic_v1_6=False):
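    """Run a TransformerLayer for several FP8 steps; optionally checkpoint halfway through
    (and optionally mimic the TE v1.6 extra-state key layout), then return the output,
    input grad, and parameter grads."""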
    steps = 10
    path = "checkpoint.pt"
    fp8_enabled = True
    fp8_recipe = recipe.DelayedScaling(
        margin=0,
        fp8_format=recipe.Format.HYBRID,
        amax_history_len=1,
        amax_compute_algo="most_recent",
        fp8_dpa=fp8_enabled,
        fp8_mha=False,
    )

    reset_rng_states()
    hidden_states = torch.randn(
        (config.max_seqlen_q, config.batch_size, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )

    def get_model(dtype, config):
        sigma = 0.023
        init_method = init_method_normal(sigma)
        output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

        with fp8_model_init(enabled=fp8_enabled, recipe=fp8_recipe):
            block = TransformerLayer(
                config.hidden_size,
                4 * config.hidden_size,
                config.num_heads,
                init_method=init_method,
                output_layer_init_method=output_layer_init_method,
                hidden_dropout=0.0,
                attention_dropout=0.0,
                fuse_qkv_params=True,
                params_dtype=dtype,
                device="cuda",
            )
        return block

    block = get_model(dtype, config)
    for i in range(steps // 2):
        with fp8_autocast(enabled=fp8_enabled, fp8_recipe=fp8_recipe):
            output = block(hidden_states, None)
            loss = output.sum()
            loss.backward()

    if checkpoint:
        sd = block.state_dict()
        if mimic_v1_6:
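            # Mimic a TE v1.6 checkpoint, where the attention extra state was stored under
            # the fused_attention submodule key.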
            sd["self_attention.core_attention.fused_attention._extra_state"] = sd[
                "self_attention.core_attention._extra_state"
            ]
            del sd["self_attention.core_attention._extra_state"]
        torch.save(sd, path)

        param_grads = []
        for p in block.parameters():
            if p.requires_grad:
                param_grads.append(p.grad.clone())

        _cpu_rng_state_new = torch.get_rng_state()
        _cuda_rng_state_new = torch.cuda.get_rng_state()

        del block
        block = get_model(dtype, config)
        block.load_state_dict(torch.load(path, weights_only=False))
        torch.set_rng_state(_cpu_rng_state_new)
        torch.cuda.set_rng_state(_cuda_rng_state_new)

        for p in block.parameters():
            if p.requires_grad:
                p.grad = param_grads.pop(0)

        assert not param_grads, "Not all saved parameter gradients were restored."

    for i in range((steps + 1) // 2):
        with fp8_autocast(enabled=fp8_enabled, fp8_recipe=fp8_recipe):
            output = block(hidden_states, None)
            loss = output.sum()
            loss.backward()

    torch.cuda.synchronize()

    if os.path.exists(path):
        os.remove(path)

    outputs = [output, hidden_states.grad]
    for p in block.parameters():
        if p.requires_grad:
            outputs.append(p.grad)

    return outputs


model_configs_fp8_vs_f16 = {
    #  test:             b,  h, hg,   d,   sq,  skv,   p,      mask,      bias
    "fp8_9": ModelConfig(2, 2048, 16, 128),
    "fp8_10": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12),
    "fp8_11": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4),
    "fp8_12": ModelConfig(2, 2048, 16, 128, attn_mask_type="causal"),
    "fp8_13": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12, attn_mask_type="causal"),
    "fp8_14": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="causal"),
    "fp8_15": ModelConfig(2, 2048, 16, 128, attn_mask_type="padding"),
    "fp8_16": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12, attn_mask_type="padding"),
    "fp8_17": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="padding"),
    "fp8_18": ModelConfig(2, 2048, 16, 128, attn_mask_type="padding_causal"),
    "fp8_19": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12, attn_mask_type="padding_causal"),
    "fp8_20": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="padding_causal"),
}

param_types_fp8_vs_f16 = [torch.float16, torch.bfloat16]
qkv_layout_fp8_vs_f16 = ["sbh3d", "bshd_bshd_bshd", "sbhd_sbhd_sbhd"]
qkv_format_fp8_vs_f16 = ["bshd", "sbhd"]


def _rmse(a, b):
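    """Root-mean-square error between two tensors."""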
    return math.sqrt((torch.pow((a - b), 2) / a.numel()).sum())


def _error(a, b, name_a, name_b, atol, rtol, rmse_tol):
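    """Compare two tensors: log min/max, log (but do not fail on) assert_close mismatches,
    and assert that the RMSE is below rmse_tol times the combined value range."""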
    logging.debug(name_a + " min {:.6f} max {:.6f}".format(a.min().item(), a.max().item()))
    logging.debug(name_b + " min {:.6f} max {:.6f}".format(b.min().item(), b.max().item()))
    try:
        if a.dtype != b.dtype:
            a = a.to(b.dtype)
        torch.testing.assert_close(a, b, atol=atol, rtol=rtol)
    except Exception as e:
        logging.debug(e)

    rmse = _rmse(a, b)
    logging.debug(name_a + " vs " + name_b + " RMSE: {:.6f}".format(rmse))
    rmse_range = max(a.max().item(), b.max().item()) - min(a.min().item(), b.min().item())
    assert rmse < rmse_tol * rmse_range, (
        name_a
        + " vs "
        + name_b
        + " RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format(
            rmse, rmse_tol * rmse_range, rmse_tol, rmse_range
        )
    )


@pytest.mark.skipif(get_cudnn_version() < (9, 2, 1), reason="cuDNN 9.2.1+ is required.")
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper+.")
@pytest.mark.parametrize("dtype", param_types_fp8_vs_f16)
@pytest.mark.parametrize("model", model_configs_fp8_vs_f16.keys())
@pytest.mark.parametrize("qkv_format", qkv_format_fp8_vs_f16)
@pytest.mark.parametrize("input_layernorm", [True, False])
@pytest.mark.parametrize("fp8_dpa_bwd", [True, False])
@pytest.mark.parametrize("RoPE", [True, False])
@pytest.mark.parametrize("is_training", [True, False])
def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd, RoPE, is_training):
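    """Compare MultiheadAttention outputs and grads with FP8 attention (fp8_mha/fp8_dpa)
    against an F16 reference run."""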
    os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1"
    os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0"
    config = model_configs_fp8_vs_f16[model]

    # Test backend availability
    available_backends, _, fused_attn_backends = get_available_attention_backends(
        config,
        qkv_dtype=torch.float8_e4m3fn,
        qkv_layout=qkv_format.replace("hd", "h3d"),
        is_training=is_training,
    )
    flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
    # Skip if only unfused backend is supported
    if (len(fused_attn_backends) + flash_attn_supported + unfused_attn_supported) < 2:
        pytest.skip("Less than two backends to compare.")
    if not fp8_dpa_bwd:
        available_backends, _, fused_attn_backends = get_available_attention_backends(
            config,
            qkv_dtype=dtype,
            qkv_layout=qkv_format.replace("hd", "h3d"),
            is_training=is_training,
        )
        _, fused_attn_supported, _ = available_backends
        if not fused_attn_supported:
            pytest.skip("No attention backend available.")

    if flash_attn_supported:
        os.environ["NVTE_FLASH_ATTN"] = "1"
        os.environ["NVTE_FUSED_ATTN"] = "0"
        _attention_backends["backend_selection_requires_update"] = True
        logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = True")
        flash_attn_fwd_fp8, param_names, flash_attn_bwd_fp8 = _run_mha_fp8_vs_f16(
            dtype, config, True, qkv_format, input_layernorm, RoPE, is_training
        )

    os.environ["NVTE_FLASH_ATTN"] = "0"
    os.environ["NVTE_FUSED_ATTN"] = "1"
    _attention_backends["backend_selection_requires_update"] = True
    logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = True")
    fused_attn_fwd_fp8, param_names, fused_attn_bwd_fp8 = _run_mha_fp8_vs_f16(
        dtype, config, True, qkv_format, input_layernorm, RoPE, is_training
    )

    logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = False")
    fused_attn_fwd_f16, param_names, fused_attn_bwd_f16 = _run_mha_fp8_vs_f16(
        dtype, config, False, qkv_format, input_layernorm, RoPE, is_training
    )

    atol = 5e-1
    rtol = 5e-1
    rmse_tol = 0.15
    logging.debug("========== {:^25s} ==========".format("forward output"))
    if flash_attn_supported:
        _error(
            flash_attn_fwd_fp8,
            fused_attn_fwd_f16,
            "flash_attn_fwd_fp8",
            "fused_attn_fwd_f16",
            atol,
            rtol,
            rmse_tol,
        )
    _error(
        fused_attn_fwd_fp8,
        fused_attn_fwd_f16,
        "fused_attn_fwd_fp8",
        "fused_attn_fwd_f16",
        atol,
        rtol,
        rmse_tol,
    )

    if is_training:
        for i in range(len(param_names[:1])):
            logging.debug("========== {:^25s} ==========".format(param_names[i]))
            _error(
                fused_attn_bwd_fp8[i],
                fused_attn_bwd_f16[i],
                f"fused_attn_bwd_fp8[{i}]",
                f"fused_attn_bwd_f16[{i}]",
                atol,
                rtol,
                rmse_tol,
            )


def _run_mha_fp8_vs_f16(dtype, config, fp8_mha, qkv_format, input_layernorm, RoPE, is_training):
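    """Run MultiheadAttention in FP8 or F16 and return the output, parameter names, and grads."""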
    reset_rng_states()
    _DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
    _DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed)

    def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
        """Get cuda rng tracker."""
        return _DUMMY_CUDA_RNG_STATE_TRACKER
    fp8_recipe = recipe.DelayedScaling(
        margin=0,
        fp8_format=recipe.Format.HYBRID,
        amax_history_len=1,
        amax_compute_algo="most_recent",
        fp8_dpa=fp8_mha,
        fp8_mha=fp8_mha,
    )

    with fp8_model_init(enabled=fp8_mha, recipe=fp8_recipe):
        rotary_pos_emb = None
        if RoPE:
            PE = RotaryPositionEmbedding(dim=config.head_dim_qk)
            rotary_pos_emb = PE(config.max_seqlen_q).to(device="cuda")
        mha = MultiheadAttention(
            hidden_size=config.hidden_size,
            num_attention_heads=config.num_heads,
            kv_channels=config.head_dim_qk,
            num_gqa_groups=config.num_gqa_groups,
            attention_dropout=config.dropout_p,
            layer_number=1,
            bias=True,
            get_rng_state_tracker=get_dummy_cuda_rng_tracker,
            params_dtype=dtype,
            input_layernorm=input_layernorm,
            fuse_qkv_params=True,
            attention_type="self",
            qkv_weight_interleaved=True,
            qkv_format=qkv_format,
        ).to(dtype=dtype, device="cuda")
        if not is_training:
            mha = mha.eval()
    if "padding" in config.attn_mask_type or qkv_format == "thd":
        if config.attn_type == "self":
            seqlens_q = torch.randint(
                1, config.max_seqlen_q, [config.batch_size], dtype=torch.int32, device="cuda"
            )
            seqlens_kv = seqlens_q
        if config.attn_type == "cross":
            seqlens_q = torch.randint(
                1, config.max_seqlen_q, [config.batch_size], dtype=torch.int32, device="cuda"
            )
            seqlens_kv = torch.randint(
                1, config.max_seqlen_kv, [config.batch_size], dtype=torch.int32, device="cuda"
            )
    else:
        seqlens_q = torch.full(
            [config.batch_size], config.max_seqlen_q, dtype=torch.int32, device="cuda"
        )
        seqlens_kv = torch.full(
            [config.batch_size], config.max_seqlen_kv, dtype=torch.int32, device="cuda"
        )
    cu_seqlens_q = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda")
    cu_seqlens_kv = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda")
    cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)
    cu_seqlens_kv[1:] = torch.cumsum(seqlens_kv, dim=0)

    dim_to_num = {
        "b": config.batch_size,
        "sq": config.max_seqlen_q,
        "skv": config.max_seqlen_kv,
        "h": config.num_heads,
        "hg": config.num_gqa_groups,
        "d": config.head_dim_qk,
        "t": cu_seqlens_q[-1],
        "tg": cu_seqlens_kv[-1],
        "3": 3,
        "2": 2,
        "1": 1,
    }
    layout = "_".join(qkv_format)
    layout = layout.replace("s", "sq")
    tensor_shape = [dim_to_num[j] for j in layout.split("_")]
    tensor = 0.01 * torch.randint(-100, 100, tensor_shape, dtype=dtype, device="cuda")
    hidden_states = tensor.view(*tensor.shape[:-2], -1)
    if is_training:
        hidden_states.requires_grad = True
    tensor = 0.01 * torch.randn(tensor_shape, dtype=dtype, device="cuda")
    out_grad = tensor.view(*tensor.shape[:-2], -1)

    with fp8_autocast(enabled=fp8_mha, fp8_recipe=fp8_recipe):
        out = mha(
            hidden_states,
            attn_mask_type=config.attn_mask_type,
            checkpoint_core_attention=False,
            core_attention_bias_type=config.attn_bias_type,
            is_first_microbatch=None,
            rotary_pos_emb=rotary_pos_emb,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_kv=cu_seqlens_kv,
        )
    if is_training:
        out.backward(out_grad)
    param_names = []
    param_names.append("hidden_states.grad")
    params = []
    params.append(hidden_states)
    for name, param in mha.named_parameters():
        if param.requires_grad:
            param_names.append(name + ".grad")
            params.append(param)

    if is_training:
        return out, param_names, tuple(x.grad for x in params)
    return out, param_names, tuple(None for x in params)


@pytest.mark.skipif(get_cudnn_version() < (9, 2, 1), reason="cuDNN 9.2.1+ is required.")
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper+.")
@pytest.mark.parametrize("dtype", param_types_fp8_vs_f16)
@pytest.mark.parametrize("model", model_configs_fp8_vs_f16.keys())
@pytest.mark.parametrize("qkv_layout", qkv_layout_fp8_vs_f16)
@pytest.mark.parametrize("fp8_dpa_bwd", [True, False])
@pytest.mark.parametrize("is_training", [True, False])
def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training):
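    """Compare DotProductAttention run with fp8_dpa enabled against an F16 reference run."""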
    config = model_configs_fp8_vs_f16[model]

    # TODO(cyang): think of another way to verify dropout results
    # test cuDNN FP8 dropout
    # 1. we modify the config here to not affect mha_fp8_vs_f16 tests
    # 2. there is no other backend that implements dropout the same way as cuDNN FP8, and as an
    #    indirect verification method, we create Q/K/V as all 1s and check if O is all 1s
    # 3. we avoid running FP16/BF16 kernels as they do not have dropout support on Blackwell
    # if "padding" not in config.attn_mask_type and "causal" not in config.attn_mask_type:
    #    if get_device_compute_capability() >= (10, 0):
    #        config.dropout_p = 0.1

    os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0"
    os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1"

    # Test backend availability
    available_backends, _, fused_attn_backends = get_available_attention_backends(
        config,
        qkv_dtype=torch.float8_e4m3fn,
        qkv_layout=qkv_layout,
        is_training=is_training,
    )
    flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
    # Skip if only unfused backend is supported
    if flash_attn_supported + fused_attn_supported < 1:
        pytest.skip("No FP8 attention backend available.")
    if not fp8_dpa_bwd:
        available_backends, _, fused_attn_backends = get_available_attention_backends(
            config,
            qkv_dtype=dtype,
            qkv_layout=qkv_layout,
            is_training=is_training,
        )
        _, fused_attn_supported, _ = available_backends
        if not fused_attn_supported:
            pytest.skip("No attention backend available.")
    if config.num_heads != config.num_gqa_groups and "3" in qkv_layout:
        pytest.skip("qkv_layout not applicable for MQA/GQA")

    if flash_attn_supported:
        os.environ["NVTE_FLASH_ATTN"] = "1"
        os.environ["NVTE_FUSED_ATTN"] = "0"
        _attention_backends["backend_selection_requires_update"] = True
        logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True")
        flash_attn_fwd_fp8, flash_attn_bwd_fp8 = _run_dpa_fp8_vs_f16(
            dtype, config, True, qkv_layout, is_training
        )

    os.environ["NVTE_FLASH_ATTN"] = "0"
    os.environ["NVTE_FUSED_ATTN"] = "1"
    _attention_backends["backend_selection_requires_update"] = True
    logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True")
    fused_attn_fwd_fp8, fused_attn_bwd_fp8 = _run_dpa_fp8_vs_f16(
        dtype, config, True, qkv_layout, is_training
    )
    if config.dropout_p == 0.0:
        # test cuDNN FP8 dropout: need a FP16/BF16 reference on Blackwell
        logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = False")
        fused_attn_fwd_f16, fused_attn_bwd_f16 = _run_dpa_fp8_vs_f16(
            dtype, config, False, qkv_layout, is_training
        )
    atol = 5e-1
    rtol = 5e-2
    rmse_tol = 0.11
    bwd_names = ["dq", "dk", "dv"]
    logging.debug("========== {:^25s} ==========".format("forward output"))
1922
    if flash_attn_supported:
        _error(
            flash_attn_fwd_fp8,
            fused_attn_fwd_f16,
            "flash_attn_fwd_fp8",
            "fused_attn_fwd_f16",
            atol,
            rtol,
            rmse_tol,
        )
    if config.dropout_p != 0.0:
        # test cuDNN FP8 dropout
        assert torch.all(
            fused_attn_fwd_fp8 == 1
        ), "fused_attn_fwd_fp8 must be all 1s when Q/K/V are all 1s."
    else:
        _error(
            fused_attn_fwd_fp8,
            fused_attn_fwd_f16,
            "fused_attn_fwd_fp8",
            "fused_attn_fwd_f16",
            atol,
            rtol,
            rmse_tol,
        )
        if is_training:
            for i, _ in enumerate(fused_attn_bwd_f16):
                logging.debug("========== {:^25s} ==========".format(bwd_names[i]))
                _error(
                    fused_attn_bwd_fp8[i],
                    fused_attn_bwd_f16[i],
                    f"fused_attn_bwd_fp8[{i}]",
                    f"fused_attn_bwd_f16[{i}]",
                    atol,
                    rtol,
                    rmse_tol,
                )


def _run_dpa_fp8_vs_f16(dtype, config, fp8_dpa, qkv_layout, is_training):
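    """Run DotProductAttention in FP8 or F16 and return the output and input grads."""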
    reset_rng_states()
    _DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
    _DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed)
    def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
        """Get cuda rng tracker."""
        return _DUMMY_CUDA_RNG_STATE_TRACKER

    fp8_recipe = recipe.DelayedScaling(
        margin=0,
        fp8_format=recipe.Format.HYBRID,
        amax_history_len=1,
        amax_compute_algo="most_recent",
        fp8_dpa=fp8_dpa,
    )

    qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()])
    with fp8_model_init(enabled=fp8_dpa):
        dpa = DotProductAttention(
            config.num_heads,
            config.head_dim_qk,
            num_gqa_groups=config.num_gqa_groups,
            attention_dropout=config.dropout_p,
            sequence_parallel=False,
            tp_size=1,
            get_rng_state_tracker=get_dummy_cuda_rng_tracker,
            tp_group=None,
            layer_number=1,
            attention_type="self",
            qkv_format=qkv_format,
        ).to(dtype=dtype, device="cuda")
        if not is_training:
            dpa = dpa.eval()

    if "padding" in config.attn_mask_type or qkv_format == "thd":
        if config.attn_type == "self":
            seqlens_q = torch.randint(
                1, config.max_seqlen_q, [config.batch_size], dtype=torch.int32, device="cuda"
            )
            seqlens_kv = seqlens_q
        if config.attn_type == "cross":
            seqlens_q = torch.randint(
                1, config.max_seqlen_q, [config.batch_size], dtype=torch.int32, device="cuda"
            )
            seqlens_kv = torch.randint(
                1, config.max_seqlen_kv, [config.batch_size], dtype=torch.int32, device="cuda"
            )
    else:
        seqlens_q = torch.full(
            [config.batch_size], config.max_seqlen_q, dtype=torch.int32, device="cuda"
        )
        seqlens_kv = torch.full(
            [config.batch_size], config.max_seqlen_kv, dtype=torch.int32, device="cuda"
        )
    cu_seqlens_q = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda")
    cu_seqlens_kv = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda")
    cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)
    cu_seqlens_kv[1:] = torch.cumsum(seqlens_kv, dim=0)

    dim_to_num = {
        "b": config.batch_size,
        "sq": config.max_seqlen_q,
        "skv": config.max_seqlen_kv,
        "h": config.num_heads,
        "hg": config.num_gqa_groups,
        "d": config.head_dim_qk,
        "t": cu_seqlens_q[-1],
        "tg": cu_seqlens_kv[-1],
        "3": 3,
        "2": 2,
        "1": 1,
    }
    inp = []
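    # Build Q/K/V from qkv_layout: each "_"-separated chunk is mapped through dim_to_num to a
    # tensor shape, and a digit in the chunk (e.g. the "3" in "sb3hd") marks the dimension
    # that is split into the individual Q/K/V tensors.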
    for i, layout in enumerate(qkv_layout.split("_")):
        layout = "_".join(layout)
        if i == 0:
            layout = layout.replace("s", "sq")
        else:
            layout = layout.replace("s", "skv")
            layout = layout.replace("h", "hg")
            layout = layout.replace("t", "tg")
        tensor_shape = [dim_to_num[j] for j in layout.split("_")]
        if config.dropout_p == 0.0:
            tensor = torch.randn(tensor_shape, dtype=dtype, device="cuda")
        else:
            # test cuDNN FP8 dropout
            tensor = torch.ones(tensor_shape, dtype=dtype, device="cuda")
        tensor_count = 1
        split_dim = 0
        for dim, l in enumerate(layout.split("_")):
            if l.isdigit():
                tensor_count = int(l)
                split_dim = dim
                break
        tensors = torch.split(tensor, 1, dim=split_dim) if split_dim != 0 else [tensor]
        for j in range(tensor_count):
            if split_dim != 0:
                inp.append(tensors[j].squeeze(split_dim))
            else:
                inp.append(tensors[j])
    for i in range(3):
        inp[i].requires_grad = True

    qkv_format_kv = "_".join(qkv_format)
    qkv_format_kv = qkv_format_kv.replace("s", "sq")
    out_grad_shape = [dim_to_num[i] for i in qkv_format_kv.split("_")]
    out_grad_shape_new = [*out_grad_shape[:-2], out_grad_shape[-2] * out_grad_shape[-1]]
    out_grad = torch.randn(out_grad_shape_new, dtype=dtype, device="cuda")

    with fp8_autocast(enabled=fp8_dpa, fp8_recipe=fp8_recipe):
        out = dpa(
            inp[0],
            inp[1],
            inp[2],
            qkv_format=qkv_format,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_kv=cu_seqlens_kv,
            max_seqlen_q=config.max_seqlen_q,
            max_seqlen_kv=config.max_seqlen_kv,
            attn_mask_type=config.attn_mask_type,
            checkpoint_core_attention=False,
            core_attention_bias_type=config.attn_bias_type,
        )
    if is_training:
        out.backward(out_grad)
    if is_training:
        return out, (inp[0].grad, inp[1].grad, inp[2].grad)
    return out, (None, None, None)


model_configs_fp8 = {
    #  test:             b,  h, hg,   d,   sq,  skv,   p,      mask,      bias
    "fp8_1": ModelConfig(1, 512, 1, 64),
    "fp8_2": ModelConfig(4, 512, 16, 64),
    "fp8_3": ModelConfig(1, 2048, 1, 128),
    "fp8_4": ModelConfig(2, 2048, 24, 128),
    "fp8_5": ModelConfig(1, 512, 1, 64, attn_mask_type="causal"),
    "fp8_6": ModelConfig(4, 512, 16, 64, attn_mask_type="causal"),
    "fp8_7": ModelConfig(1, 2048, 1, 128, attn_mask_type="causal"),
    "fp8_8": ModelConfig(2, 2048, 24, 128, attn_mask_type="causal"),
}
param_types_fp8 = [torch.float16, torch.bfloat16]
cudnn_frontend_version = int(os.getenv("NVTE_FUSED_ATTN_FE_VER", "1"))
models_v0 = ["fp8_1", "fp8_2", "fp8_5", "fp8_6"]
models_v1 = ["fp8_3", "fp8_4", "fp8_7", "fp8_8"]


@pytest.mark.skipif(
    (
        get_cudnn_version() < (8, 9, 3)
        if cudnn_frontend_version == 0
        else get_cudnn_version() < (9, 2, 1)
    ),
    reason=f"""cuDNN {"8.9.3" if cudnn_frontend_version == 0 else "9.2.1"}+ is required.""",
)
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper+.")
@pytest.mark.parametrize("dtype", param_types_fp8)
@pytest.mark.parametrize("model", models_v1 if cudnn_frontend_version == 1 else models_v0)
def test_custom_mha_fp8_vs_f16(dtype, model):
    """Test FP8 dot product attention implementations based on cuDNN frontend
    v0.9 and v1.0+. Each test compares results from a custom implementation of
    an FP8 MHA module, i.e. Custom_MHA_FP8(), to results from an F16 MHA
    implementation, i.e. transformer_engine.pytorch.attention.MultiHeadAttention.
    Both paths take F16 input and output. QKV layout is t3hd or bs3hd"""

    config = model_configs_fp8[model]

    # Test backend availability
    is_training = True
    available_backends, _, fused_attn_backends = get_available_attention_backends(
        config,
        qkv_dtype=torch.float8_e4m3fn,
        qkv_layout="t3hd" if cudnn_frontend_version == 0 else "bs3hd",
        is_training=is_training,
    )
    flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
    if not (fused_attn_backends and unfused_attn_supported):
        pytest.skip("Not enough backends to run this test with.")

    fused_attn_fwd_fp8, fused_attn_bwd_fp8 = _run_custom_mha_fp8(dtype, config, "FusedAttention")
    unfused_attn_fwd_f16, unfused_attn_bwd_f16 = _run_ref_mha_f16(dtype, config, "UnfusedAttention")

    atol = 5e-1
    rtol = 5e-1
    rmse_tol = 0.13
    _error(
        fused_attn_fwd_fp8,
        unfused_attn_fwd_f16,
        "fused_attn_fwd_fp8",
        "unfused_attn_fwd_f16",
        atol,
        rtol,
        rmse_tol,
    )
    _error(
        fused_attn_bwd_fp8,
        unfused_attn_bwd_f16,
        "fused_attn_bwd_fp8",
        "unfused_attn_bwd_f16",
        atol,
        rtol,
        rmse_tol,
    )


def _run_custom_mha_fp8(dtype, config, backend):
    """Run Custom_MHA_FP8 with FP8 FusedAttention backend. Both input and output
    are in F16. QKV GEMM, DPA, and projection GEMM are calculated in FP8."""
    reset_rng_states()
    os.environ["NVTE_FLASH_ATTN"] = "0"
    os.environ["NVTE_FUSED_ATTN"] = "0"
    if backend == "FlashAttention":
        os.environ["NVTE_FLASH_ATTN"] = "1"
    if backend == "FusedAttention":
        os.environ["NVTE_FUSED_ATTN"] = "1"
    _attention_backends["backend_selection_requires_update"] = True

    inp = 0.0001 * torch.randint(
        -100,
        100,
        (config.batch_size * config.max_seqlen_q, config.num_heads * config.head_dim_qk),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )
    seqlens = torch.full([config.batch_size], config.max_seqlen_q, dtype=torch.int32, device="cuda")
    cu_seqlens = torch.zeros(config.batch_size + 1, device="cuda", dtype=torch.int32)
    cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)

    out_grad = 0.01 * torch.randn(
        config.batch_size * config.max_seqlen_q,
        config.num_heads * config.head_dim_qk,
        dtype=dtype,
        device="cuda",
    )
    torch.save(out_grad, "out_grad.pt")
    fp8_recipe = recipe.DelayedScaling(
        margin=0,
        fp8_format=recipe.Format.HYBRID,
        amax_history_len=1,
        amax_compute_algo="most_recent",
    )

    mha = Custom_MHA_FP8(config).to(dtype=dtype, device="cuda")
    with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
        out = mha(inp, cu_seqlens, config.max_seqlen_q)
    out.backward(out_grad)

    out = torch.load("out.pt")
    dqkv = torch.load("dqkv.pt")
    return (
        out.view(config.batch_size, config.max_seqlen_q, -1),
        dqkv.view(
            config.batch_size, config.max_seqlen_q, 3, config.num_heads, config.head_dim_qk
        ).contiguous(),
    )


def _run_ref_mha_f16(dtype, config, backend):
    """Run reference F16 FusedAttention. Both input and output
    are in F16. QKV GEMM, DPA, and projection GEMM are also in F16."""
    os.environ["NVTE_FLASH_ATTN"] = "0"
    os.environ["NVTE_FUSED_ATTN"] = "0"
    if backend == "FlashAttention":
        os.environ["NVTE_FLASH_ATTN"] = "1"
    if backend == "FusedAttention":
        os.environ["NVTE_FUSED_ATTN"] = "1"
    _attention_backends["backend_selection_requires_update"] = True

    inp = torch.load("qkv.pt").to(device="cuda")
    inp.requires_grad = True
    seqlens = torch.full([config.batch_size], config.max_seqlen_q, dtype=torch.int32, device="cuda")
    cu_seqlens = torch.zeros(config.batch_size + 1, device="cuda", dtype=torch.int32)
    cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)
    out_grad = (
        torch.load("out_grad.pt").to(device="cuda").view(config.batch_size, config.max_seqlen_q, -1)
    )
    _DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
    _DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed)

    def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
        """Get cuda rng tracker."""
        return _DUMMY_CUDA_RNG_STATE_TRACKER

    block = DotProductAttention(
        config.num_heads,
        config.head_dim_qk,
        attention_dropout=config.dropout_p,
        sequence_parallel=False,
        tp_size=1,
        get_rng_state_tracker=get_dummy_cuda_rng_tracker,
        tp_group=None,
        layer_number=1,
        attention_type="self",
        qkv_format="bshd",
    ).to(dtype=dtype, device="cuda")

    q = inp[:, :, 0, :, :]
    k = inp[:, :, 1, :, :]
    v = inp[:, :, 2, :, :]
    out = block(q, k, v, attn_mask_type=config.attn_mask_type)
    out.backward(out_grad)

    return out, inp.grad


_CUBLASLT_WORKSPACE_SIZE_BYTES = 33_554_432  # 32MiB
_2X_ACC_FPROP = False
_2X_ACC_DGRAD = False
_2X_ACC_WGRAD = False

META_QKV = tex.FP8FwdTensors.GEMM1_OUTPUT
META_DQKV = tex.FP8BwdTensors.GRAD_OUTPUT1
META_O = tex.FP8FwdTensors.GEMM2_INPUT
META_DO = tex.FP8BwdTensors.GRAD_INPUT2
META_S = tex.FP8FwdTensors.GEMM3_OUTPUT
META_DP = tex.FP8BwdTensors.GRAD_INPUT3


class _custom_mha_fp8(torch.autograd.Function):
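    """Custom FP8 MHA autograd function: FP8 QKV GEMM + cuDNN FP8 fused attention in forward,
    FP8 attention and GEMM gradients in backward. F16 copies of QKV, the output, and dQKV are
    saved to disk so that the F16 reference path can reuse the same inputs."""
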
    @staticmethod
    def forward(
        ctx,
        inp: torch.Tensor,
        qkv_weight: torch.Tensor,
        qkv_bias: torch.Tensor,
        cu_seqlens: torch.Tensor,
        num_heads: int,
        p_dropout: float,
        max_s: int,
        fast_zero_fill: bool,
        fp8_meta: Dict[str, Any],
        workspace: torch.Tensor,
        is_training: bool,
        mask_type: str,
        quantizers: list[Quantizer],
    ) -> torch.Tensor:
        qkv_dtype = inp.dtype

        assert inp.dim() == 2
        in_features = qkv_weight.shape[-1]
        h = num_heads
        d = in_features // h
        b = cu_seqlens.numel() - 1

        input_quantizer = quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT]
        qkv_quantizer = quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_INPUT]
        qkv_weight_quantizer = quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT]
        o_quantizer = quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_OUTPUT]
        dO_quantizer = quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT1]
        dQKV_quantizer = quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_INPUT1]
        s_quantizer = quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT2]
        dP_quantizer = quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT3]

        inp_fp8 = input_quantizer(inp)

        qkv_weight_fp8 = qkv_weight_quantizer(qkv_weight)

        qkv, *_ = ext.general_gemm(
            qkv_weight_fp8,
            inp_fp8,
            workspace,
            bias=qkv_bias,
            out_dtype=qkv_weight_fp8.dtype,
            quantization_params=qkv_quantizer,
            use_split_accumulator=_2X_ACC_FPROP,
        )
        qkv = qkv.view(-1, 3, h, d)
        qkv_fp16 = qkv.dequantize().view(b, max_s, 3, h, d).contiguous()
        torch.save(qkv_fp16, "qkv.pt")
        if cudnn_frontend_version == 1:
            qkv = qkv.view(b, max_s, 3, h, d)  # bs3hd

        # FMHA
        q_data = qkv._data[:, :, 0, :, :] if cudnn_frontend_version == 1 else qkv._data[:, 0, :, :]
        k_data = qkv._data[:, :, 1, :, :] if cudnn_frontend_version == 1 else qkv._data[:, 1, :, :]
        v_data = qkv._data[:, :, 2, :, :] if cudnn_frontend_version == 1 else qkv._data[:, 2, :, :]
        q = qkv.make_like(tensor=qkv, data=q_data, shape=q_data.shape)
        k = qkv.make_like(tensor=qkv, data=k_data, shape=k_data.shape)
        v = qkv.make_like(tensor=qkv, data=v_data, shape=v_data.shape)

        out, aux_ctx_tensors = fused_attn_fwd(
            is_training,
            max_s,
            max_s,
            cu_seqlens,
            cu_seqlens,
            q,
            k,
            v,
            qkv_dtype,
            FusedAttnBackend["FP8"],
            attn_scale=None,
            dropout=p_dropout,
            fast_zero_fill=fast_zero_fill,
            qkv_layout="bs3hd" if cudnn_frontend_version == 1 else "t3hd",
            attn_bias_type="no_bias",
            attn_mask_type=mask_type if cudnn_frontend_version == 1 else "padding",
            rng_gen=None,
            o_quantizer=o_quantizer,
            s_quantizer=s_quantizer,
        )

        tensors_to_save, tensor_objects = prepare_for_saving(
            q, k, v, inp_fp8, qkv_weight_fp8, workspace, out
        )

        ctx.save_for_backward(*tensors_to_save)
        ctx.tensor_objects = tensor_objects
        ctx.aux_ctx_tensors = aux_ctx_tensors
        ctx.qkv_dtype = qkv_dtype
        ctx.fp8_meta = fp8_meta
        ctx.cu_seqlens = cu_seqlens
        ctx.p_dropout = p_dropout
        ctx.max_s = max_s
        ctx.fast_zero_fill = fast_zero_fill
        ctx.hidden_size = in_features
        ctx.num_heads = num_heads
        ctx.mask_type = mask_type
        ctx.dtype = inp.dtype
        ctx.dQKV_quantizer = dQKV_quantizer
        ctx.dO_quantizer = dO_quantizer
        ctx.dP_quantizer = dP_quantizer
        ctx.S_quantizer = s_quantizer

        out = out.view(-1, in_features)  # (bs)(hd)
        out_fp16 = out.dequantize()
        torch.save(out_fp16, "out.pt")  # (bs)(hd)
        return out_fp16

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
        with torch.cuda.nvtx.range("_DPA"):
            saved_tensors = ctx.saved_tensors
            (q, k, v, inp_fp8, qkv_weight_fp8, workspace, out) = restore_from_saved(
                ctx.tensor_objects, saved_tensors
            )

            proj_dgrad = ctx.dO_quantizer(grad_output)
            fp8_dtype_backward = fp8.get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=False)

            dq, dk, dv, *rest = fused_attn_bwd(
                ctx.max_s,
                ctx.max_s,
                ctx.cu_seqlens,
                ctx.cu_seqlens,
                q,
                k,
                v,
                out,
                proj_dgrad.view_as(out),
                ctx.qkv_dtype,
                fp8_dtype_backward,
                ctx.aux_ctx_tensors,
                FusedAttnBackend["FP8"],
                None,
                None,
                ctx.S_quantizer,
                ctx.dP_quantizer,
                ctx.dQKV_quantizer,
                attn_scale=None,
                dropout=ctx.p_dropout,
                fast_zero_fill=ctx.fast_zero_fill,
                qkv_layout="bs3hd" if cudnn_frontend_version == 1 else "t3hd",
                attn_bias_type="no_bias",
                attn_mask_type=ctx.mask_type if cudnn_frontend_version == 1 else "padding",
            )
            dim = 2 if cudnn_frontend_version == 1 else 1
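            # dq/dk/dv are packed in one contiguous buffer; rebuild an interleaved QKV gradient
            # view (bs3hd/t3hd) directly from dq's underlying storage via a strided view.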
            dqkv = torch.Tensor().to(device=dq._data.device, dtype=dq._data.dtype)
            dqkv_shape = list(dq._data.shape)
            dqkv_shape.insert(dim, 3)
            dqkv_stride = list(dq._data.stride())
            dqkv_stride.insert(dim, int(dqkv_stride[-3] / 3))
            dqkv.set_(
                dq._data.untyped_storage(), dq._data.storage_offset(), dqkv_shape, dqkv_stride
            )  # bs3hd

            dqkv_c = dqkv.view(-1, 3 * ctx.hidden_size)
            dqkv_c = dq.make_like(tensor=dq, data=dqkv_c, shape=dqkv_c.shape)
            dqkv_c_fp16 = dqkv_c.dequantize()
            torch.save(dqkv_c_fp16, "dqkv.pt")

            qkv_bgrad, dqkv = ext.bgrad_quantize(dqkv_c_fp16, ctx.dQKV_quantizer)
            dqkv_c._transpose = None
            dqkv_c._create_transpose()

            # QKV DGRAD
            qkv_dgrad, *_ = ext.general_gemm(
                qkv_weight_fp8,
                dqkv_c,
                workspace,
                ctx.dtype,
                use_split_accumulator=_2X_ACC_DGRAD,
                layout="NN",
            )

            # QKV WGRAD
            qkv_wgrad, *_ = ext.general_gemm(
                inp_fp8,
                dqkv,
                workspace,
                ctx.dtype,
                use_split_accumulator=_2X_ACC_WGRAD,
                layout="NT",
            )

        return (
            qkv_dgrad,
            qkv_wgrad,
            qkv_bgrad,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
        )


class Custom_MHA_FP8(TransformerEngineBaseModule):
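    """Minimal FP8 MHA module wrapping _custom_mha_fp8: a fused QKV projection (weights
    initialized to 1, bias to 0) followed by FP8 fused attention."""
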
    def __init__(self, config, params_dtype: torch.dtype = torch.float32):
        super().__init__()
        self.p_dropout = config.dropout_p
        self.h = config.num_heads
        self.hidden_size = config.hidden_size
        self.head_dim = config.head_dim_qk
        self.fast_zero_fill = True
        self.mask_type = config.attn_mask_type
        self.qkv_weight = torch.nn.Parameter(
            torch.empty(
                self.hidden_size * 3,
                self.hidden_size,
                device=torch.cuda.current_device(),
                dtype=params_dtype,
            )
        )
        self.qkv_bias = torch.nn.Parameter(
            torch.empty(
                self.hidden_size * 3,
                device=torch.cuda.current_device(),
                dtype=params_dtype,
            )
        )
        with torch.no_grad():
            self.qkv_bias.zero_()
            self.qkv_weight.fill_(1.0)
        self.workspace = torch.empty(
            _CUBLASLT_WORKSPACE_SIZE_BYTES, dtype=torch.int8, device="cuda"
        )

    def forward(
        self,
        inp: torch.Tensor,
        cu_seqlens,
        max_s,
    ) -> torch.Tensor:
        with self.prepare_forward(inp, num_gemms=3) as inp:
            out = _custom_mha_fp8.apply(
                inp,
                self.qkv_weight,
                self.qkv_bias,
                cu_seqlens,
                self.h,
                self.p_dropout,
                max_s,
                self.fast_zero_fill,
                self.fp8_meta,
                self.workspace,
                self.training,
                self.mask_type,
                self.quantizers,
            )
        return out