# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import logging
import os
import sys
import pathlib
from typing import Any, Dict, Tuple, Union

import pytest
import torch

from transformer_engine.pytorch.quantization import FP8GlobalStateManager, get_fp8_te_dtype
from transformer_engine.common import recipe
from transformer_engine.pytorch import (
    TransformerLayer,
    autocast,
    quantized_model_init,
    DotProductAttention,
    MultiheadAttention,
    get_device_compute_capability,
    Quantizer,
    is_fp8_available,
    is_bf16_available,
)
from transformer_engine.pytorch.attention.dot_product_attention import (
    _attention_backends,
)
from transformer_engine.pytorch.attention.dot_product_attention.utils import (
    FlashAttentionUtils,
    check_set_window_size,
)
from transformer_engine.pytorch.attention import RotaryPositionEmbedding
import transformer_engine.pytorch.cpp_extensions as ext
from transformer_engine.pytorch.cpp_extensions.fused_attn import (
    FusedAttnBackend,
    fused_attn_bwd,
    fused_attn_fwd,
)
from transformer_engine.pytorch.distributed import CudaRNGStatesTracker
from transformer_engine.pytorch.module.base import TransformerEngineBaseModule
from transformer_engine.pytorch.utils import (
    init_method_normal,
    scaled_init_method_normal,
)
from transformer_engine.pytorch.utils import get_cudnn_version
import transformer_engine_torch as tex
from transformer_engine.pytorch.quantized_tensor import (
    Quantizer,
    prepare_for_saving,
    restore_from_saved,
)
from torch.utils.cpp_extension import IS_HIP_EXTENSION

_current_file = pathlib.Path(__file__).resolve()
sys.path.append(str(_current_file.parent.parent))
from utils import (
    reset_rng_states,
    compare_and_assert,
    ModelConfig,
    dtype_tols,
    get_available_attention_backends,
)

# Check if hardware supports FP8 attention.
fp8_available, reason_for_no_fp8 = is_fp8_available(return_reason=True)
fp8_attn_available, reason_for_no_fp8_attn = fp8_available, reason_for_no_fp8
device_compute_capability = get_device_compute_capability()
if not IS_HIP_EXTENSION:
    if fp8_available and (device_compute_capability < (9, 0) or device_compute_capability >= (12, 0)):
        fp8_attn_available = False
        reason_for_no_fp8_attn = (
            "FP8 attention is not supported for compute capability ="
            f" sm{device_compute_capability[0] * 10 + device_compute_capability[1]}"
        )


# Determine whether deterministic algorithms are required
_deterministic = (
    not bool(int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1")))
    or torch.are_deterministic_algorithms_enabled()
)


# Reset RNG seed and states
seed = 1234
reset_rng_states()


# Reset FP8 global state manager
@pytest.fixture(autouse=True)
def reset_global_fp8_state():
    yield
    FP8GlobalStateManager.reset()


# Define F16 data types to test
param_types = [torch.float16]
if is_bf16_available():
    param_types.append(torch.bfloat16)
param_types_lean = [torch.bfloat16]

model_configs_base = {
    # test: ModelConfig(b, sq, hq, dqk)
    "base_1_0": ModelConfig(8, 128, 16, 64),
    "base_1_1": ModelConfig(4, 128, 16, 64, max_seqlen_kv=256),
    "base_2_0": ModelConfig(2, 2048, 24, 128),
    "base_2_1": ModelConfig(1, 2048, 24, 128, max_seqlen_kv=4096),
    "base_3_0": ModelConfig(8, 1, 16, 128, max_seqlen_kv=2048),
    "base_3_1": ModelConfig(8, 1, 16, 256, max_seqlen_kv=2048),
    "base_4_0": ModelConfig(8, 1, 16, 192, max_seqlen_kv=2048),
    "base_4_1": ModelConfig(8, 128, 16, 192, max_seqlen_kv=2048),
    "base_5_0": ModelConfig(8, 1, 16, 512, max_seqlen_kv=2048),
    "base_5_1": ModelConfig(8, 128, 16, 512, max_seqlen_kv=2048),
    "base_6_0": ModelConfig(8, 1, 16, 1024, max_seqlen_kv=2048),
    "base_6_1": ModelConfig(8, 128, 16, 1024, max_seqlen_kv=2048),
}


@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model_configs", [model_configs_base])
@pytest.mark.parametrize("model", model_configs_base.keys())
@pytest.mark.parametrize("ckpt_attn", [False])
@pytest.mark.parametrize("workspace_opt", [True, False])
@pytest.mark.parametrize("qkv_layout", [None])
@pytest.mark.parametrize("swa", [False])
@pytest.mark.parametrize("pad_between_seqs", [False])
def test_dot_product_attention(
    dtype,
    model_configs,
    model,
    ckpt_attn,
    workspace_opt,
    qkv_layout,
    swa,
    pad_between_seqs,
):
    """Test DotProductAttention module"""

    # Get configs
    tols = dict(atol=1e-3, rtol=1e-3)
    if dtype == torch.bfloat16:
        tols = dict(atol=1.5e-2, rtol=1.5e-2)
    config = model_configs[model]
    is_mla = config.head_dim_qk != config.head_dim_v
    is_mqa_gqa = config.num_heads != config.num_gqa_groups
    if qkv_layout is None:
        if config.attn_type == "self":
            qkv_layout = "sb3hd" if not is_mla and not is_mqa_gqa else "sbhd_sbhd_sbhd"
        else:
            qkv_layout = "bshd_bs2hd" if not is_mla and not is_mqa_gqa else "bshd_bshd_bshd"
    if "3" in qkv_layout and config.attn_type == "cross":
        pytest.skip("No need to test this layout for cross attention")

    if config.window_size == (-1, -1) and swa:
        config.window_size = [2, 2]

    config.window_size = check_set_window_size(config.attn_mask_type, config.window_size)
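    # Derive the base qkv format (e.g. "sbhd" or "thd") by stripping the packing digits from the layout string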
    qkv_format = qkv_layout.replace("3", "").replace("2", "").split("_")[0]
    if qkv_format == "thd" and "padding" not in config.attn_mask_type:
        config.attn_mask_type = (
            "padding_" + config.attn_mask_type if config.attn_mask_type != "no_mask" else "padding"
        )

    # Get backends
    is_training = True
    available_backends, _, fused_attn_backends = get_available_attention_backends(
        config,
        qkv_dtype=dtype,
        qkv_layout=qkv_layout,
        pad_between_seqs=pad_between_seqs,
        is_training=is_training,
        deterministic=_deterministic,
    )
    flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends

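    # If fused attention cannot run this config in training mode, re-query the backends for inference only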
    if not fused_attn_supported:
        is_training = False
        available_backends, _, fused_attn_backends = get_available_attention_backends(
            config,
            qkv_dtype=dtype,
            qkv_layout=qkv_layout,
            pad_between_seqs=pad_between_seqs,
            is_training=is_training,
            deterministic=_deterministic,
        )
        flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends

    # FlashAttention does not support pad_between_seqs, but _run_dot_product_attention
    # manually pads and unpads the input and output of FlashAttention for testing purposes
    if (
        pad_between_seqs
        and FlashAttentionUtils.is_installed
        and not (
            config.max_seqlen_q != config.max_seqlen_kv
            and config.attn_mask_type in ["causal", "padding_causal"]
        )
        and (config.window_size[0] == -1 or FlashAttentionUtils.v2_3_plus)
    ):
        flash_attn_supported = True

    # Skip if only unfused backend is supported
    if (len(fused_attn_backends) + flash_attn_supported + unfused_attn_supported) < 2:
        pytest.skip("Less than two backends to compare.")

    # UnfusedDotProductAttention backend
    if unfused_attn_supported:
        unfused_attn_fwd, unfused_max_logit, unfused_attn_bwd = _run_dot_product_attention(
            dtype,
            config,
            "UnfusedDotProductAttention",
            ckpt_attn,
            qkv_layout,
            workspace_opt,
            pad_between_seqs,
            is_training,
        )

    # FusedAttention backend
    if fused_attn_supported:
        if len(fused_attn_backends) == 1:
            fused_attn_fwd, fused_max_logit, fused_attn_bwd = _run_dot_product_attention(
                dtype,
                config,
                "FusedAttention",
                ckpt_attn,
                qkv_layout,
                workspace_opt,
                pad_between_seqs,
                is_training,
            )
        if len(fused_attn_backends) == 2:
            os.environ["NVTE_FUSED_ATTN_BACKEND"] = "0"
            fused_attn_fwd, _, fused_attn_bwd = _run_dot_product_attention(
                dtype,
                config,
                "FusedAttention",
                ckpt_attn,
                qkv_layout,
                workspace_opt,
                pad_between_seqs,
                is_training,
            )
            os.environ["NVTE_FUSED_ATTN_BACKEND"] = "1"
            fused_attn_fwd_1, _, fused_attn_bwd_1 = _run_dot_product_attention(
                dtype,
                config,
                "FusedAttention",
                ckpt_attn,
                qkv_layout,
                workspace_opt,
                pad_between_seqs,
                is_training,
            )

    # FlashAttention backend
    if flash_attn_supported:
        flash_attn_fwd, _, flash_attn_bwd = _run_dot_product_attention(
            dtype,
            config,
            "FlashAttention",
            ckpt_attn,
            qkv_layout,
            workspace_opt,
            pad_between_seqs,
            is_training,
        )

    # Compare results
    logging.info(f"[test_dot_product_attention]: is_training = {is_training}")
    if unfused_attn_supported and flash_attn_supported:
        logging.info("[test_dot_product_attention]: unfused attn vs flash attn")
        torch.testing.assert_close(flash_attn_fwd, unfused_attn_fwd, **tols)
        for i, _ in enumerate(flash_attn_bwd):
            torch.testing.assert_close(unfused_attn_bwd[i], flash_attn_bwd[i], **tols)
    if unfused_attn_supported and fused_attn_supported:
        logging.info("[test_dot_product_attention]: unfused attn vs fused attn")
        torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, **tols)
        if config.return_max_logit:
            torch.testing.assert_close(fused_max_logit, unfused_max_logit, **tols)
        for i, _ in enumerate(unfused_attn_bwd):
            torch.testing.assert_close(fused_attn_bwd[i], unfused_attn_bwd[i], **tols)
    if fused_attn_supported and flash_attn_supported:
        logging.info("[test_dot_product_attention]: fused attn vs flash attn")
        torch.testing.assert_close(fused_attn_fwd, flash_attn_fwd, **tols)
        for i, _ in enumerate(flash_attn_bwd):
            torch.testing.assert_close(fused_attn_bwd[i], flash_attn_bwd[i], **tols)
    if fused_attn_supported and len(fused_attn_backends) == 2:
        logging.info("[test_dot_product_attention]: fused attn backend 0 vs 1")
        torch.testing.assert_close(fused_attn_fwd, fused_attn_fwd_1, **tols)
        for i, _ in enumerate(fused_attn_bwd):
            torch.testing.assert_close(fused_attn_bwd[i], fused_attn_bwd_1[i], **tols)


@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model_configs", [model_configs_base])
@pytest.mark.parametrize("model", ["base_1_1", "base_2_1"])
def test_dpa_checkpoint(dtype, model_configs, model):
    """Test DotProductAttention module with checkpointing"""
    test_dot_product_attention(dtype, model_configs, model, True, True, None, False, False)


model_configs_max_logit = {
    # test: ModelConfig(b, sq, hq, dqk)
    "max_logit_1": ModelConfig(1, 2048, 24, 128, max_seqlen_kv=4096),
    "max_logit_2": ModelConfig(2, 2048, 24, 128, attn_mask_type="causal"),
    "max_logit_3": ModelConfig(2, 1, 16, 128, max_seqlen_kv=2048, attn_mask_type="padding_causal"),
    "max_logit_4": ModelConfig(
        8, 128, 16, 192, max_seqlen_kv=2048, attn_bias_type="post_scale_bias"
    ),
    "max_logit_5": ModelConfig(
        8, 128, 16, 512, max_seqlen_kv=2048, attn_mask_type="causal", window_size=(20, 0)
    ),
    "max_logit_6": ModelConfig(8, 1, 16, 1024, max_seqlen_kv=2048),
}


@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model_configs", [model_configs_max_logit])
@pytest.mark.parametrize("model", model_configs_max_logit.keys())
@pytest.mark.parametrize("qkv_layout", ["sbhd_sbhd_sbhd", "thd_thd_thd"])
def test_dpa_max_logit(dtype, model_configs, model, qkv_layout):
    """Test DotProductAttention module with return_max_logit enabled"""
    config = model_configs[model]
    config.return_max_logit = True
    test_dot_product_attention(dtype, model_configs, model, False, True, qkv_layout, False, False)


model_configs_num_splits = {
    # test: ModelConfig(b, sq, hq, dqk)
    "num_splits_1_0": ModelConfig(2, 2048, 24, 128, num_splits=2),
    "num_splits_1_1": ModelConfig(1, 2048, 24, 128, max_seqlen_kv=4096, num_splits=4),
}


@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model_configs", [model_configs_num_splits])
@pytest.mark.parametrize("model", model_configs_num_splits.keys())
def test_dpa_num_splits(dtype, model_configs, model):
    """Test DotProductAttention with FlashAttention-3 num_splits enabled"""
    test_dot_product_attention(
        dtype,
        model_configs,
        model,
        False,
        True,
        None,
        False,
        False,
    )


model_configs_softmax = {
    # test: ModelConfig(b, sq, hq, dqk)
    "softmax_1_0": ModelConfig(2, 2048, 64, 64, num_gqa_groups=8),
    "softmax_1_1": ModelConfig(2, 2048, 64, 64, num_gqa_groups=8, softmax_type="off-by-one"),
    "softmax_1_2": ModelConfig(2, 2048, 64, 64, num_gqa_groups=8, softmax_type="learnable"),
    "softmax_2_0": ModelConfig(2, 2048, 64, 64, num_gqa_groups=8, attn_mask_type="causal"),
    "softmax_2_1": ModelConfig(
        2, 2048, 64, 64, num_gqa_groups=8, attn_mask_type="causal", softmax_type="off-by-one"
    ),
    "softmax_2_2": ModelConfig(
        2, 2048, 64, 64, num_gqa_groups=8, attn_mask_type="causal", softmax_type="learnable"
    ),
    "softmax_3_0": ModelConfig(2, 2048, 64, 64, num_gqa_groups=8, attn_mask_type="padding"),
    "softmax_3_1": ModelConfig(
        2, 2048, 64, 64, num_gqa_groups=8, attn_mask_type="padding", softmax_type="off-by-one"
    ),
    "softmax_3_2": ModelConfig(
        2, 2048, 64, 64, num_gqa_groups=8, attn_mask_type="padding", softmax_type="learnable"
    ),
    "softmax_4_0": ModelConfig(
        2, 2048, 64, 64, num_gqa_groups=8, window_size=(128, 0), attn_mask_type="causal"
    ),
    "softmax_4_1": ModelConfig(
        2,
        2048,
        64,
        64,
        num_gqa_groups=8,
        window_size=(128, 0),
        attn_mask_type="causal",
        softmax_type="off-by-one",
    ),
    "softmax_4_2": ModelConfig(
        2,
        2048,
        64,
        64,
        num_gqa_groups=8,
        window_size=(128, 0),
        attn_mask_type="causal",
        softmax_type="learnable",
    ),
    "softmax_5_0": ModelConfig(
        2, 2048, 64, 64, num_gqa_groups=8, window_size=(128, 0), attn_mask_type="padding_causal"
    ),
    "softmax_5_1": ModelConfig(
        2,
        2048,
        64,
        64,
        num_gqa_groups=8,
        window_size=(128, 0),
        attn_mask_type="padding_causal",
        softmax_type="off-by-one",
    ),
    "softmax_5_2": ModelConfig(
        2,
        2048,
        64,
        64,
        num_gqa_groups=8,
        window_size=(128, 0),
        attn_mask_type="padding_causal",
        softmax_type="learnable",
    ),
}


@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("model_configs", [model_configs_softmax])
@pytest.mark.parametrize("model", model_configs_softmax.keys())
def test_dpa_softmax(dtype, model_configs, model):
    """Test DotProductAttention module with different softmax types"""
    test_dot_product_attention(
        dtype, model_configs, model, True, True, "bshd_bshd_bshd", False, False
    )


@pytest.mark.skipif(get_cudnn_version() < (9, 18, 0), reason="cuDNN 9.18.0+ is required.")
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("model_configs", [model_configs_softmax])
@pytest.mark.parametrize("model", model_configs_softmax.keys())
def test_dpa_softmax_thd(dtype, model_configs, model):
    """Test DotProductAttention module with different softmax types and the THD layout"""
    test_dot_product_attention(dtype, model_configs, model, True, True, "thd_thd_thd", False, False)


model_configs_mla = {
    # TODO: FlashAttention on ROCm only supports MLA with head_dim_qk = head_dim_v
    #    test:             b,  h, hg, dqk, sq, skv,   p,      mask,      bias   # attn , backend
    # "mla_1_0": ModelConfig(8, 128, 16, 64, head_dim_v=128),  # self , 0
    # "mla_1_1": ModelConfig(4, 128, 16, 64, max_seqlen_kv=256, head_dim_v=128),  # cross, 0
    # "mla_1_2": ModelConfig(4, 128, 16, 192, max_seqlen_kv=256, head_dim_v=128),  # cross, 0
    # "mla_2_0": ModelConfig(2, 2048, 24, 128, attn_mask_type="causal", head_dim_v=64),  # self , 1
    # "mla_2_1": ModelConfig(
    #     1, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal", head_dim_v=64
    # ),  # cross, 1
    # "mla_2_2": ModelConfig(
    #     1, 2048, 24, 192, max_seqlen_kv=4096, attn_mask_type="causal", head_dim_v=128
    # ),  # cross, 1
    # "mla_3_0": ModelConfig(8, 1, 16, 128, max_seqlen_kv=2048, head_dim_v=64),  # inference
    # "mla_3_1": ModelConfig(8, 1, 16, 256, max_seqlen_kv=2048, head_dim_v=128),  # inference
    # "mla_3_2": ModelConfig(8, 1, 16, 192, max_seqlen_kv=2048, head_dim_v=128),  # inference
    # "mla_3_3": ModelConfig(8, 1, 16, 160, max_seqlen_kv=2048, head_dim_v=128),  # inference
    "mla_3_4": ModelConfig(8, 1, 16, 160, max_seqlen_kv=2048, head_dim_v=160),  # inference
}


@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model_configs", [model_configs_mla])
@pytest.mark.parametrize("model", model_configs_mla.keys())
def test_dpa_mla(dtype, model_configs, model):
    """Test DotProductAttention module with Multi-Latent Attention (MLA)"""
    test_dot_product_attention(dtype, model_configs, model, True, True, None, False, False)


model_configs_mask = {
    # test: ModelConfig(b, sq, hq, dqk)
    "mask_1_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="causal"),
    "mask_1_1": ModelConfig(2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="causal"),
    "mask_1_2": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal"),
    "mask_2_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="causal_bottom_right"),
    "mask_2_1": ModelConfig(
        2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="causal_bottom_right"
    ),
    "mask_2_2": ModelConfig(
        2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal_bottom_right"
    ),
    "mask_3_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding"),
    "mask_3_1": ModelConfig(2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding"),
    "mask_3_2": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding"),
    "mask_4_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding_causal"),
    "mask_4_1": ModelConfig(2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding_causal"),
    "mask_4_2": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding_causal"),
    "mask_5_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding_causal_bottom_right"),
    "mask_5_1": ModelConfig(
        2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding_causal_bottom_right"
    ),
    "mask_5_2": ModelConfig(
        2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding_causal_bottom_right"
    ),
    "mask_6_0": ModelConfig(2, 1, 16, 128, max_seqlen_kv=2048, attn_mask_type="causal"),
    "mask_6_1": ModelConfig(2, 1, 16, 256, max_seqlen_kv=2048, attn_mask_type="causal"),
    "mask_7_0": ModelConfig(
        2, 1, 16, 128, max_seqlen_kv=2048, attn_mask_type="causal_bottom_right"
    ),
    "mask_7_1": ModelConfig(
        2, 1, 16, 256, max_seqlen_kv=2048, attn_mask_type="causal_bottom_right"
    ),
    "mask_8_0": ModelConfig(2, 1, 24, 128, max_seqlen_kv=2048, attn_mask_type="padding"),
    "mask_8_1": ModelConfig(2, 1, 16, 256, max_seqlen_kv=2048, attn_mask_type="padding"),
    "mask_9_0": ModelConfig(2, 1, 24, 128, max_seqlen_kv=2048, attn_mask_type="padding_causal"),
    "mask_9_1": ModelConfig(2, 1, 16, 256, max_seqlen_kv=2048, attn_mask_type="padding_causal"),
    "mask_10_0": ModelConfig(
        2, 1, 24, 128, max_seqlen_kv=2048, attn_mask_type="padding_causal_bottom_right"
    ),
    "mask_10_1": ModelConfig(
        2, 1, 16, 256, max_seqlen_kv=2048, attn_mask_type="padding_causal_bottom_right"
    ),
}


@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_mask])
@pytest.mark.parametrize("model", model_configs_mask.keys())
def test_dpa_mask(dtype, model_configs, model):
    """Test DotProductAttention module with different mask types"""
    test_dot_product_attention(dtype, model_configs, model, False, True, None, False, False)


model_configs_bias = {
    # test: ModelConfig(b, sq, hq, dqk)
    "bias_1_0": ModelConfig(4, 128, 16, 64, attn_bias_type="post_scale_bias"),
    "bias_1_1": ModelConfig(2, 128, 16, 64, max_seqlen_kv=256, attn_bias_type="post_scale_bias"),
    "bias_1_2": ModelConfig(4, 2048, 24, 128, attn_bias_type="post_scale_bias"),
    "bias_1_3": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_bias_type="post_scale_bias"),
    "bias_1_4": ModelConfig(4, 2048, 24, 128, attn_bias_type="alibi"),
    "bias_1_5": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_bias_type="alibi"),
    "bias_2_0": ModelConfig(
        4, 128, 16, 64, attn_mask_type="padding", attn_bias_type="post_scale_bias"
    ),
    "bias_2_1": ModelConfig(
        2,
        128,
        16,
        64,
        max_seqlen_kv=256,
        attn_mask_type="padding",
        attn_bias_type="post_scale_bias",
    ),
    "bias_2_2": ModelConfig(
        4, 2048, 24, 128, attn_mask_type="padding", attn_bias_type="post_scale_bias"
    ),
    "bias_2_3": ModelConfig(
        2,
        2048,
        24,
        128,
        max_seqlen_kv=4096,
        attn_mask_type="padding",
        attn_bias_type="post_scale_bias",
    ),
    "bias_2_4": ModelConfig(4, 2048, 24, 128, attn_mask_type="padding", attn_bias_type="alibi"),
    "bias_2_5": ModelConfig(
        2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding", attn_bias_type="alibi"
    ),
    "bias_3_0": ModelConfig(
        4, 128, 16, 64, attn_mask_type="causal", attn_bias_type="post_scale_bias"
    ),
    "bias_3_1": ModelConfig(
        2, 128, 16, 64, max_seqlen_kv=256, attn_mask_type="causal", attn_bias_type="post_scale_bias"
    ),
    "bias_3_2": ModelConfig(
        4, 2048, 24, 128, attn_mask_type="causal", attn_bias_type="post_scale_bias"
    ),
    "bias_3_3": ModelConfig(
        2,
        2048,
        24,
        128,
        max_seqlen_kv=4096,
        attn_mask_type="causal",
        attn_bias_type="post_scale_bias",
    ),
    "bias_3_4": ModelConfig(4, 2048, 24, 128, attn_mask_type="causal", attn_bias_type="alibi"),
    "bias_3_5": ModelConfig(
        2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal", attn_bias_type="alibi"
    ),
    "bias_4_0": ModelConfig(
        4, 128, 16, 64, attn_mask_type="padding_causal", attn_bias_type="post_scale_bias"
    ),
    "bias_4_1": ModelConfig(
        2,
        128,
        16,
        64,
        max_seqlen_kv=256,
        attn_mask_type="padding_causal",
        attn_bias_type="post_scale_bias",
    ),
    "bias_4_2": ModelConfig(
        4, 2048, 24, 128, attn_mask_type="padding_causal", attn_bias_type="post_scale_bias"
    ),
    "bias_4_3": ModelConfig(
        2,
        2048,
        24,
        128,
        max_seqlen_kv=4096,
        attn_mask_type="padding_causal",
        attn_bias_type="post_scale_bias",
    ),
    "bias_4_4": ModelConfig(
        4, 2048, 24, 128, attn_mask_type="padding_causal", attn_bias_type="alibi"
    ),
    "bias_4_5": ModelConfig(
        2,
        2048,
        24,
        128,
        max_seqlen_kv=4096,
        attn_mask_type="padding_causal",
        attn_bias_type="alibi",
    ),
}


@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_bias])
@pytest.mark.parametrize("model", model_configs_bias.keys())
def test_dpa_bias(dtype, model_configs, model):
    """Test DotProductAttention module with different bias types"""
    test_dot_product_attention(dtype, model_configs, model, False, True, None, False, False)


model_configs_bias_shapes = {
    # test: ModelConfig(b, sq, hq, dqk)
    "bias_1_0": ModelConfig(4, 128, 16, 64, attn_bias_type="post_scale_bias", bias_shape="11ss"),
    "bias_1_1": ModelConfig(2, 128, 16, 64, attn_bias_type="post_scale_bias", bias_shape="1hss"),
    "bias_1_2": ModelConfig(4, 2048, 24, 128, attn_bias_type="post_scale_bias", bias_shape="b1ss"),
    "bias_1_3": ModelConfig(2, 2048, 24, 128, attn_bias_type="post_scale_bias", bias_shape="bhss"),
    "bias_1_4": ModelConfig(
        4,
        2048,
        24,
        128,
        attn_mask_type="causal",
        attn_bias_type="alibi",
        bias_shape="1hss",
        alibi_type="custom",
    ),
    "bias_1_5": ModelConfig(
        2,
        2048,
        24,
        128,
        attn_mask_type="causal",
        attn_bias_type="alibi",
        bias_shape="bhss",
        alibi_type="custom",
    ),
}


@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_bias_shapes])
@pytest.mark.parametrize("model", model_configs_bias_shapes.keys())
def test_dpa_bias_shapes(dtype, model_configs, model):
    """Test DotProductAttention module with different bias types and shapes"""
    test_dot_product_attention(dtype, model_configs, model, False, True, None, False, False)


model_configs_swa = {
    # test: ModelConfig(b, sq, hq, dqk)
    "swa_1_1": ModelConfig(2, 2048, 16, 64),
    "swa_1_2": ModelConfig(2, 2048, 24, 128, num_gqa_groups=4),
    "swa_1_3": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096),
    "swa_2_1": ModelConfig(2, 2048, 16, 64, attn_mask_type="causal"),
    "swa_2_2": ModelConfig(2, 2048, 24, 128, num_gqa_groups=4, attn_mask_type="causal"),
    "swa_2_3": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal"),
    "swa_3_1": ModelConfig(2, 2048, 16, 64, attn_mask_type="causal_bottom_right"),
    "swa_3_2": ModelConfig(
        2, 2048, 24, 128, num_gqa_groups=4, attn_mask_type="causal_bottom_right"
    ),
    "swa_3_3": ModelConfig(
        2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal_bottom_right"
    ),
    "swa_4_1": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding"),
    "swa_4_2": ModelConfig(2, 2048, 24, 128, num_gqa_groups=4, attn_mask_type="padding"),
    "swa_4_3": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding"),
    "swa_5_1": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding_causal"),
    "swa_5_2": ModelConfig(2, 2048, 24, 128, num_gqa_groups=4, attn_mask_type="padding_causal"),
    "swa_5_3": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding_causal"),
    "swa_6_1": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding_causal_bottom_right"),
    "swa_6_2": ModelConfig(
        2, 2048, 24, 128, num_gqa_groups=4, attn_mask_type="padding_causal_bottom_right"
    ),
    "swa_6_3": ModelConfig(
        2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding_causal_bottom_right"
    ),
}


@pytest.mark.skipif(not FlashAttentionUtils.v2_3_plus, reason="Flash-attn 2.3+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_swa])
@pytest.mark.parametrize("model", model_configs_swa.keys())
@pytest.mark.parametrize("qkv_layout", ["thd_thd_thd", "sbhd_sbhd_sbhd"])
def test_dpa_sliding_window(dtype, model_configs, model, qkv_layout):
    """Test DotProductAttention module with sliding window attention"""
    test_dot_product_attention(dtype, model_configs, model, False, True, qkv_layout, True, False)


model_configs_alibi_slopes = {
    # test: ModelConfig(b, sq, hq, dqk)
    "alibi_1_0": ModelConfig(
        2, 128, 16, 64, attn_mask_type="causal", attn_bias_type="alibi", alibi_type="vanilla"
    ),
    "alibi_1_1": ModelConfig(
        1,
        128,
        16,
        64,
        max_seqlen_kv=256,
        attn_mask_type="causal",
        attn_bias_type="alibi",
        alibi_type="vanilla",
    ),
    "alibi_2_0": ModelConfig(
        2, 1024, 24, 128, attn_mask_type="causal", attn_bias_type="alibi", alibi_type="custom"
    ),
    "alibi_2_1": ModelConfig(
        1,
        1024,
        24,
        128,
        max_seqlen_kv=2048,
        attn_mask_type="causal",
        attn_bias_type="alibi",
        alibi_type="custom",
    ),
}


@pytest.mark.skipif(not FlashAttentionUtils.v2_3_plus, reason="Flash-attn 2.3+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_alibi_slopes])
@pytest.mark.parametrize("model", model_configs_alibi_slopes.keys())
def test_dpa_alibi_slopes(dtype, model_configs, model):
    """Test DotProductAttention module with ALiBi slopes"""
    test_dot_product_attention(dtype, model_configs, model, False, True, None, False, False)


qkv_layouts = [
    "sb3hd",
    "sbh3d",
    "sbhd_sb2hd",
    "sbhd_sbh2d",
    "sbhd_sbhd_sbhd",
    "bs3hd",
    "bsh3d",
    "bshd_bs2hd",
    "bshd_bsh2d",
    "bshd_bshd_bshd",
]


model_configs_layout = {
    # test: ModelConfig(b, sq, hq, dqk)
    "layout_0_0": ModelConfig(2, 128, 16, 64),
    "layout_0_1": ModelConfig(
        2, 128, 16, 64, attn_mask_type="causal", attn_bias_type="post_scale_bias"
    ),
    "layout_0_2": ModelConfig(1, 128, 16, 64, max_seqlen_kv=256, attn_mask_type="padding"),
    "layout_0_3": ModelConfig(
        1,
        128,
        16,
        64,
        max_seqlen_kv=256,
        attn_mask_type="padding_causal",
        attn_bias_type="post_scale_bias",
    ),
    "layout_1_0": ModelConfig(2, 2048, 24, 128),
    "layout_1_1": ModelConfig(
        2, 2048, 24, 128, attn_mask_type="causal", attn_bias_type="post_scale_bias"
    ),
    "layout_1_2": ModelConfig(1, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding"),
    "layout_1_3": ModelConfig(
        1,
        2048,
        24,
        128,
        max_seqlen_kv=4096,
        attn_mask_type="padding_causal",
        attn_bias_type="post_scale_bias",
    ),
    "layout_2_0": ModelConfig(2, 1, 16, 256, max_seqlen_kv=2048),
    "layout_2_1": ModelConfig(
        2, 2048, 24, 256, attn_mask_type="causal", attn_bias_type="post_scale_bias"
    ),
}


@pytest.mark.skipif(get_cudnn_version() < (8, 9, 5), reason="cuDNN 8.9.5+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_layout])
@pytest.mark.parametrize("model", model_configs_layout.keys())
@pytest.mark.parametrize("qkv_layout", qkv_layouts)
def test_dpa_qkv_layout(dtype, model_configs, model, qkv_layout):
    """Test DotProductAttention module with different QKV layouts"""
    test_dot_product_attention(dtype, model_configs, model, False, True, qkv_layout, False, False)


qkv_layouts_thd = ["t3hd", "th3d", "thd_t2hd", "thd_th2d", "thd_thd_thd"]
model_configs_layout_thd = {
    # test: ModelConfig(b, sq, hq, dqk)
    "layout_0_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding"),
    "layout_0_1": ModelConfig(2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding"),
    "layout_0_2": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding"),
    "layout_1_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding_causal"),
    "layout_1_1": ModelConfig(2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding_causal"),
    "layout_1_2": ModelConfig(
        2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding_causal"
    ),
    "layout_2_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding_causal_bottom_right"),
    "layout_2_1": ModelConfig(
        2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding_causal_bottom_right"
    ),
    "layout_2_2": ModelConfig(
        2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding_causal_bottom_right"
    ),
    "layout_3_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding", window_size=(4, 4)),
    "layout_3_1": ModelConfig(
        2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding", window_size=(4, 4)
    ),
    "layout_3_2": ModelConfig(
        2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding", window_size=(4, 4)
    ),
    "layout_4_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding_causal", window_size=(4, 0)),
    "layout_4_1": ModelConfig(
        2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding_causal", window_size=(4, 0)
    ),
    "layout_4_2": ModelConfig(
        2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding_causal", window_size=(4, 0)
    ),
    "layout_5_0": ModelConfig(
        2, 2048, 16, 64, attn_mask_type="padding_causal_bottom_right", window_size=(4, 0)
    ),
    "layout_5_1": ModelConfig(
        2,
        2048,
        24,
        128,
        num_gqa_groups=1,
        attn_mask_type="padding_causal_bottom_right",
        window_size=(4, 0),
    ),
    "layout_5_2": ModelConfig(
        2,
        2048,
        24,
        128,
        max_seqlen_kv=4096,
        attn_mask_type="padding_causal_bottom_right",
        window_size=(4, 0),
    ),
}


@pytest.mark.skipif(get_cudnn_version() < (9, 0, 0), reason="cuDNN 9.0.0+ is required.")
@pytest.mark.skipif(
    get_device_compute_capability() < (9, 0), reason="THD is only supported on Hopper+."
)
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_layout_thd])
@pytest.mark.parametrize("model", model_configs_layout_thd.keys())
@pytest.mark.parametrize("qkv_layout", qkv_layouts_thd)
def test_dpa_qkv_layout_thd(dtype, model_configs, model, qkv_layout):
    """Test DotProductAttention module with different QKV layouts"""
    config = model_configs[model]
    if config.num_heads != config.num_gqa_groups and "3" in qkv_layout:
        pytest.skip("qkv_layout not applicable for MQA/GQA")
    logging.info("[test_dpa_qkv_layout_thd]: pad_between_seqs = True")
    pad_between_seqs = True
    test_dot_product_attention(
        dtype, model_configs, model, False, True, qkv_layout, False, pad_between_seqs
    )
    if get_cudnn_version() >= (9, 3, 0):
        logging.info("[test_dpa_qkv_layout_thd]: pad_between_seqs = False")
        # cuDNN 9.3.0+ is required to run pad_between_seqs = False/True in the same run
        pad_between_seqs = False
        test_dot_product_attention(
            dtype, model_configs, model, False, True, qkv_layout, False, pad_between_seqs
        )


def _run_dot_product_attention(
    dtype: torch.dtype,
    config: ModelConfig,
    backend: str,
    ckpt_attn: bool,
    qkv_layout: str,
    workspace_opt: bool,
    pad_between_seqs: bool,
    is_training: bool,
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
    """Run DotProductAttention module with one forward pass and one backward pass"""
    # Set RNG and environment variables
    reset_rng_states()
    os.environ["NVTE_FLASH_ATTN"] = "0"
    os.environ["NVTE_FUSED_ATTN"] = "0"
    os.environ["NVTE_UNFUSED_ATTN"] = "0"
    if backend == "FlashAttention":
        os.environ["NVTE_FLASH_ATTN"] = "1"
    if backend == "FusedAttention":
        os.environ["NVTE_FUSED_ATTN"] = "1"
        os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] = "1" if workspace_opt else "0"
    if backend == "UnfusedDotProductAttention":
        os.environ["NVTE_UNFUSED_ATTN"] = "1"
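    # Invalidate the cached backend selection so the environment variables above take effect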
    _attention_backends["backend_selection_requires_update"] = True

    # Create seqlens
    qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()])
    if "padding" in config.attn_mask_type or qkv_format == "thd":
        if config.attn_type == "self":
            seqlens_q = torch.randint(
                1, config.max_seqlen_q, [config.batch_size], dtype=torch.int32, device="cuda"
            )
            seqlens_kv = seqlens_q
        if config.attn_type == "cross":
            if config.max_seqlen_q > 1:
                seqlens_q = torch.randint(
                    1, config.max_seqlen_q, [config.batch_size], dtype=torch.int32, device="cuda"
                )
            else:
                seqlens_q = torch.ones([config.batch_size], dtype=torch.int32, device="cuda")
            seqlens_kv = torch.randint(
                1, config.max_seqlen_kv, [config.batch_size], dtype=torch.int32, device="cuda"
            )
    else:
        seqlens_q = torch.full(
            [config.batch_size], config.max_seqlen_q, dtype=torch.int32, device="cuda"
        )
        seqlens_kv = torch.full(
            [config.batch_size], config.max_seqlen_kv, dtype=torch.int32, device="cuda"
        )
    cu_seqlens_q = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda")
    cu_seqlens_kv = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda")
    cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)
    cu_seqlens_kv[1:] = torch.cumsum(seqlens_kv, dim=0)

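    # Optionally insert padding tokens between sequences and track the padded cumulative offsets separately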
    seqlens_q_after_pad = seqlens_q.clone()
    seqlens_kv_after_pad = seqlens_kv.clone()
    cu_seqlens_q_after_pad = cu_seqlens_q.clone()
    cu_seqlens_kv_after_pad = cu_seqlens_kv.clone()
    pad_len = [0] * config.batch_size
    if pad_between_seqs:
        max_pad_len = 3
        pad_len = torch.randint(0, max_pad_len + 1, [config.batch_size], device="cuda")  # 3
        seqlens_q_after_pad = seqlens_q + pad_len
        seqlens_kv_after_pad = seqlens_kv + pad_len
        cu_seqlens_q_after_pad[1:] = torch.cumsum(seqlens_q_after_pad, dim=0)
        cu_seqlens_kv_after_pad[1:] = torch.cumsum(seqlens_kv_after_pad, dim=0)

    # Create attention mask if padding
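    # Padded positions are marked True so that they are excluded from the attention computation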
    attention_mask = None
    if "padding" in config.attn_mask_type:
        if config.attn_type == "self":
            attention_mask_q = torch.Tensor([]).to(dtype=torch.bool)
            for i in range(config.batch_size):
                attention_mask_q = torch.cat(
                    [
                        attention_mask_q,
                        torch.Tensor(
                            [False] * seqlens_q[i] + [True] * (config.max_seqlen_q - seqlens_q[i])
                        )
                        .to(dtype=torch.bool)
                        .unsqueeze(0)
                        .unsqueeze(0)
                        .unsqueeze(0),
                    ],
                    dim=0,
                )
            attention_mask = attention_mask_q.to(device="cuda")
        if config.attn_type == "cross":
            attention_mask_q = torch.Tensor([]).to(dtype=torch.bool)
            attention_mask_kv = torch.Tensor([]).to(dtype=torch.bool)
            for i in range(config.batch_size):
                attention_mask_q = torch.cat(
                    [
                        attention_mask_q,
                        torch.Tensor(
                            [False] * seqlens_q[i] + [True] * (config.max_seqlen_q - seqlens_q[i])
                        )
                        .to(dtype=torch.bool)
                        .unsqueeze(0)
                        .unsqueeze(0)
                        .unsqueeze(0),
                    ],
                    dim=0,
                )
                attention_mask_kv = torch.cat(
                    [
                        attention_mask_kv,
                        torch.Tensor(
                            [False] * seqlens_kv[i]
                            + [True] * (config.max_seqlen_kv - seqlens_kv[i])
                        )
                        .to(dtype=torch.bool)
                        .unsqueeze(0)
                        .unsqueeze(0)
                        .unsqueeze(0),
                    ],
                    dim=0,
                )
            attention_mask = (
                attention_mask_q.to(device="cuda"),
                attention_mask_kv.to(device="cuda"),
            )

    alibi_slopes = None
    if config.attn_bias_type == "alibi" and config.alibi_type == "custom":
        if config.bias_shape == "1hss":
            alibi_slopes = (
                torch.randn(config.num_heads).abs().to(dtype=torch.float32, device="cuda")
            )
        if config.bias_shape == "bhss":
            alibi_slopes = (
                torch.randn(config.batch_size, config.num_heads)
                .abs()
                .to(dtype=torch.float32, device="cuda")
            )

    # Create input tensors
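    # Map each layout letter/digit to a concrete dimension so tensor shapes can be built from the layout string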
    dim_to_num = {
        "b": config.batch_size,
        "sq": config.max_seqlen_q,
        "skv": config.max_seqlen_kv,
        "h": config.num_heads,
        "hg": config.num_gqa_groups,
        "dqk": config.head_dim_qk,
        "dv": config.head_dim_v,
        "t": cu_seqlens_q_after_pad[-1],
        "tg": cu_seqlens_kv_after_pad[-1],
        "3": 3,
        "2": 2,
        "1": 1,
    }
    inp = []
    inp_orig = []
    for i, layout in enumerate(qkv_layout.split("_")):
        layout = "_".join(layout)
        if i == 0:
            layout = layout.replace("s", "sq")
        else:
            layout = layout.replace("s", "skv")
            layout = layout.replace("h", "hg")
            layout = layout.replace("t", "tg")
        if i == 2:
            layout = layout.replace("d", "dv")
        else:
            layout = layout.replace("d", "dqk")
        tensor_shape = [dim_to_num[j] for j in layout.split("_")]
        tensor = 0.1 * torch.randn(tensor_shape, dtype=dtype, device="cuda")
        # tensor: with padding tokens
        # tensor_orig: without padding tokens
        tensor_orig = tensor
        if qkv_format == "thd" and pad_between_seqs:
            tensor_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
            if layout in ["t_h_dqk", "t_3_h_dqk", "t_h_3_dqk"]:
                for i in range(1, config.batch_size + 1):
                    valid_range = (
                        cu_seqlens_q_after_pad[i - 1],
                        cu_seqlens_q_after_pad[i] - pad_len[i - 1],
                    )
                    pad_range = (
                        cu_seqlens_q_after_pad[i] - pad_len[i - 1],
                        cu_seqlens_q_after_pad[i],
                    )
                    tensor[pad_range[0] : pad_range[1]] = 0.0
                    tensor_orig = torch.cat(
                        [tensor_orig, tensor[valid_range[0] : valid_range[1]]], dim=0
                    )
            if layout in ["tg_hg_dqk", "tg_2_hg_dqk", "tg_hg_2_dqk", "tg_hg_dv"]:
                for i in range(1, config.batch_size + 1):
                    valid_range = (
                        cu_seqlens_kv_after_pad[i - 1],
                        cu_seqlens_kv_after_pad[i] - pad_len[i - 1],
                    )
                    pad_range = (
                        cu_seqlens_kv_after_pad[i] - pad_len[i - 1],
                        cu_seqlens_kv_after_pad[i],
                    )
                    tensor[pad_range[0] : pad_range[1]] = 0.0
                    tensor_orig = torch.cat(
                        [tensor_orig, tensor[valid_range[0] : valid_range[1]]], dim=0
                    )
        tensor_count = 1
        split_dim = 0
        for dim, l in enumerate(layout.split("_")):
            if l.isdigit():
                tensor_count = int(l)
                split_dim = dim
                break
        tensors = torch.split(tensor, 1, dim=split_dim) if split_dim != 0 else [tensor]
        tensors_orig = (
            torch.split(tensor_orig, 1, dim=split_dim) if split_dim != 0 else [tensor_orig]
        )
        for j in range(tensor_count):
            if split_dim != 0:
                inp.append(tensors[j].squeeze(split_dim))
                inp_orig.append(tensors_orig[j].squeeze(split_dim))
            else:
                inp.append(tensors[j])
                inp_orig.append(tensors_orig[j])
    for i in range(3):
        inp[i].requires_grad = True
        inp_orig[i].requires_grad = True

    # Create output gradient
    qkv_format_kv = "_".join(qkv_format)
    qkv_format_kv = qkv_format_kv.replace("s", "sq")
    qkv_format_kv = qkv_format_kv.replace("d", "dv")
    out_grad_shape = [dim_to_num[i] for i in qkv_format_kv.split("_")]
    out_grad_shape_new = [*out_grad_shape[:-2], out_grad_shape[-2] * out_grad_shape[-1]]
    out_grad = 0.001 * torch.randint(0, 200, out_grad_shape_new, dtype=dtype, device="cuda")
    out_grad_orig = out_grad
    if qkv_format == "thd" and pad_between_seqs:
        out_grad_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
        if qkv_format_kv == "t_h_dv":
            for i in range(1, config.batch_size + 1):
                valid_range = (
                    cu_seqlens_q_after_pad[i - 1],
                    cu_seqlens_q_after_pad[i] - pad_len[i - 1],
                )
                pad_range = (cu_seqlens_q_after_pad[i] - pad_len[i - 1], cu_seqlens_q_after_pad[i])
                out_grad[pad_range[0] : pad_range[1]] = 0.0
                out_grad_orig = torch.cat(
                    [out_grad_orig, out_grad[valid_range[0] : valid_range[1]]], dim=0
                )

    # Create bias
    if config.attn_bias_type in ["no_bias", "alibi"]:
        bias = None
    if config.attn_bias_type == "post_scale_bias":
        shape = "_".join(config.bias_shape)
        shape = shape.replace("_s_s", "_sq_skv")
        tensor_shape = [dim_to_num[j] for j in shape.split("_")]
        bias = torch.randn(tensor_shape, dtype=dtype, device="cuda")
        if config.bias_shape != "1hss":
            bias.requires_grad = False

    # Create RNG
    _DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
    _DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed)

    def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
        """Get cuda rng tracker."""
        return _DUMMY_CUDA_RNG_STATE_TRACKER

    # Set up model
    block = DotProductAttention(
        config.num_heads,
        (config.head_dim_qk, config.head_dim_v),
        num_gqa_groups=config.num_gqa_groups,
        attention_dropout=config.dropout_p,
        qkv_format=qkv_format,
        attn_mask_type=config.attn_mask_type,
        sequence_parallel=False,
        tp_size=1,
        get_rng_state_tracker=get_dummy_cuda_rng_tracker,
        tp_group=None,
        layer_number=1,
        attention_type=config.attn_type,
        softmax_type=config.softmax_type,
        return_max_logit=config.return_max_logit,
    ).to(dtype=dtype, device="cuda")
    if not is_training:
        block = block.eval()
    if is_training and config.softmax_type != "vanilla":
        block.softmax_offset.requires_grad = True

    # Run a forward and backward pass
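    # FlashAttention and unfused backends consume the unpadded tensors; FusedAttention consumes the padded ones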
    if backend in ["FlashAttention", "UnfusedDotProductAttention"]:
        q = inp_orig[0]
        k = inp_orig[1]
        v = inp_orig[2]
        d_out = out_grad_orig
    if backend == "FusedAttention":
        q = inp[0]
        k = inp[1]
        v = inp[2]
        d_out = out_grad
    out = block(
        q,
        k,
        v,
        window_size=config.window_size,
        attention_mask=attention_mask,
        qkv_format=qkv_format,
        max_seqlen_q=config.max_seqlen_q,
        max_seqlen_kv=config.max_seqlen_kv,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_kv=cu_seqlens_kv,
        cu_seqlens_q_padded=cu_seqlens_q_after_pad if backend == "FusedAttention" else None,
        cu_seqlens_kv_padded=cu_seqlens_kv_after_pad if backend == "FusedAttention" else None,
        attn_mask_type=config.attn_mask_type,
        checkpoint_core_attention=ckpt_attn,
        core_attention_bias_type=config.attn_bias_type,
        core_attention_bias=bias,
        alibi_slopes=alibi_slopes,
        fast_zero_fill=True,
        # Only pass num_splits when exercising the FlashAttention path
        num_splits=config.num_splits if backend == "FlashAttention" else 1,
    )
    max_logit = None
    if config.return_max_logit:
        out, max_logit = out
    if is_training:
        out.backward(d_out)

    d_softmax_offset = None
    if is_training and config.softmax_type != "vanilla":
        d_softmax_offset = block.softmax_offset.grad

    if backend in ["FlashAttention", "UnfusedDotProductAttention"]:
        if is_training:
            return out, max_logit, (q.grad, k.grad, v.grad, d_softmax_offset)
        else:
            return out, max_logit, (None, None, None, d_softmax_offset)
    if backend == "FusedAttention":
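        # Strip padding tokens from the fused-attention output and gradients before returning them for comparison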
        if qkv_format == "thd" and pad_between_seqs:
            out_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
            if is_training:
                q_grad_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
                k_grad_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
                v_grad_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
            for i in range(1, config.batch_size + 1):
                valid_range_q = (
                    cu_seqlens_q_after_pad[i - 1],
                    cu_seqlens_q_after_pad[i] - pad_len[i - 1],
                )
                valid_range_kv = (
                    cu_seqlens_kv_after_pad[i - 1],
                    cu_seqlens_kv_after_pad[i] - pad_len[i - 1],
                )
                out_orig = torch.cat([out_orig, out[valid_range_q[0] : valid_range_q[1]]], dim=0)
                if is_training:
                    q_grad_orig = torch.cat(
                        [q_grad_orig, q.grad[valid_range_q[0] : valid_range_q[1]]], dim=0
                    )
                    k_grad_orig = torch.cat(
                        [k_grad_orig, k.grad[valid_range_kv[0] : valid_range_kv[1]]], dim=0
                    )
                    v_grad_orig = torch.cat(
                        [v_grad_orig, v.grad[valid_range_kv[0] : valid_range_kv[1]]], dim=0
                    )
            if is_training:
                return (
                    out_orig,
                    max_logit,
                    (q_grad_orig, k_grad_orig, v_grad_orig, d_softmax_offset),
                )
            else:
                return out_orig, max_logit, (None, None, None, d_softmax_offset)
        else:
            if is_training:
                return out, max_logit, (q.grad, k.grad, v.grad, d_softmax_offset)
            else:
                return out, max_logit, (None, None, None, d_softmax_offset)


model_configs_te_layer = {
    # test: ModelConfig(b, sq, hq, dqk)
    "te_1_0": ModelConfig(2, 128, 16, 64, attn_bias_type="post_scale_bias"),
    "te_1_1": ModelConfig(
        4, 128, 16, 64, attn_mask_type="causal", attn_bias_type="post_scale_bias"
    ),
    "te_1_2": ModelConfig(
        2, 128, 16, 64, attn_mask_type="padding", attn_bias_type="post_scale_bias"
    ),
    "te_1_3": ModelConfig(2, 128, 16, 64, max_seqlen_kv=256, attn_mask_type="padding"),
    "te_2_0": ModelConfig(1, 2048, 16, 64, attn_mask_type="causal"),
    "te_2_1": ModelConfig(2, 2048, 16, 64),
    "te_2_2": ModelConfig(1, 2048, 16, 64, attn_mask_type="padding"),
    "te_2_3": ModelConfig(
        1, 2048, 16, 64, max_seqlen_kv=4096, attn_mask_type="padding_causal_bottom_right"
    ),
    "te_3_0": ModelConfig(4, 128, 16, 64, attn_mask_type="causal", attn_bias_type="alibi"),
    "te_3_1": ModelConfig(4, 2048, 16, 64, attn_mask_type="causal", attn_bias_type="alibi"),
}
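# Reading the keys above: the positional ModelConfig arguments are (batch_size, max_seqlen_q,
# num_heads, head_dim_qk), so e.g. "te_2_0" = ModelConfig(1, 2048, 16, 64, attn_mask_type="causal")
# describes a 1 x 2048 x 16 x 64 causal self-attention problem; keyword arguments override the
# ModelConfig defaults (mask type, bias type, max_seqlen_kv, ...).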


@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model_configs", [model_configs_te_layer])
@pytest.mark.parametrize("model", model_configs_te_layer.keys())
@pytest.mark.parametrize("ckpt_attn", [False])
@pytest.mark.parametrize("qkv_format", ["sbhd", "bshd", "thd"])
@pytest.mark.parametrize("fused_qkv_params", [False])
@pytest.mark.parametrize("RoPE", [False])
def test_transformer_layer(
    dtype, model_configs, model, ckpt_attn, qkv_format, fused_qkv_params, RoPE
):
    """Test TransformerLayer module"""

    # Get configs
    config = model_configs[model]
    tols = dict(atol=5e-2, rtol=5e-2)
    workspace_opt = True

    # Test backend availability
    is_training = True
    available_backends, _, fused_attn_backends = get_available_attention_backends(
        config,
        qkv_dtype=dtype,
        qkv_layout=(
            qkv_format.replace("hd", "h3d") if fused_qkv_params else qkv_format.replace("hd", "3hd")
        ),
        is_training=is_training,
        deterministic=_deterministic,
    )
    flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
    if not fused_attn_supported:
        is_training = False
        available_backends, _, fused_attn_backends = get_available_attention_backends(
            config,
            qkv_dtype=dtype,
            qkv_layout=(
                qkv_format.replace("hd", "h3d")
                if fused_qkv_params
                else qkv_format.replace("hd", "3hd")
            ),
            is_training=is_training,
            deterministic=_deterministic,
        )
        flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends

    # Skip if only unfused backend is supported
    if (len(fused_attn_backends) + flash_attn_supported + unfused_attn_supported) < 2:
        pytest.skip("Less than two backends to compare.")
    # Skip if qkv_format = thd and "padding" not in attn_mask_type
    if qkv_format == "thd" and "padding" not in config.attn_mask_type:
        pytest.skip("THD requires padding mask.")

    # UnfusedDotProductAttention backend
    if unfused_attn_supported:
        unfused_attn_fwd, unfused_attn_bwd = _run_transformer_layer(
            dtype,
            config,
            "UnfusedDotProductAttention",
            ckpt_attn,
            qkv_format,
            workspace_opt,
            fused_qkv_params,
            RoPE,
            is_training,
        )

    # FusedAttention backend
    if fused_attn_supported:
        fused_attn_fwd, fused_attn_bwd = _run_transformer_layer(
            dtype,
            config,
            "FusedAttention",
            ckpt_attn,
            qkv_format,
            workspace_opt,
            fused_qkv_params,
            RoPE,
            is_training,
        )

    # FlashAttention backend
    if flash_attn_supported:
        flash_attn_fwd, flash_attn_bwd = _run_transformer_layer(
            dtype,
            config,
            "FlashAttention",
            ckpt_attn,
            qkv_format,
            workspace_opt,
            fused_qkv_params,
            RoPE,
            is_training,
        )

    logging.info(f"[test_transformer_layer]: is_training = {is_training}")
    if unfused_attn_supported and fused_attn_supported:
        logging.info("[test_transformer_layer]: unfused attn vs fused attn")
        torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, **tols)
        torch.testing.assert_close(fused_attn_bwd, unfused_attn_bwd, **tols)
    if unfused_attn_supported and flash_attn_supported:
        logging.info("[test_transformer_layer]: unfused attn vs flash attn")
        torch.testing.assert_close(flash_attn_fwd, unfused_attn_fwd, **tols)
        torch.testing.assert_close(flash_attn_bwd, unfused_attn_bwd, **tols)
    if fused_attn_supported and flash_attn_supported:
        logging.info("[test_transformer_layer]: fused attn vs flash attn")
        torch.testing.assert_close(fused_attn_fwd, flash_attn_fwd, **tols)
        torch.testing.assert_close(fused_attn_bwd, flash_attn_bwd, **tols)


@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_te_layer])
@pytest.mark.parametrize("model", ["te_1_2", "te_2_0"])
@pytest.mark.parametrize("qkv_format", ["bshd", "sbhd"])
def test_te_layer_misc(dtype, model_configs, model, qkv_format):
    """Test TransformerLayer module with miscellaneous settings"""
    ckpt_attn = True
    fused_qkv_params = True
    RoPE = True
    test_transformer_layer(
        dtype, model_configs, model, ckpt_attn, qkv_format, fused_qkv_params, RoPE
    )


@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_te_layer])
@pytest.mark.parametrize("model", ["te_2_0", "te_2_1", "te_2_2"])
def test_te_layer_mqa_gqa(dtype, model_configs, model):
    """Test TransformerLayer module with MQA/GQA"""

    def find_factors(x):
        f = []
        for i in range(2, x + 1):
            if x % i == 0:
                f.append(i)
        return f

    ckpt_attn = True
    qkv_format = "bshd"
    fused_qkv_params = True
    RoPE = True
    config = model_configs[model]
    num_querys_per_gqa_group = find_factors(config.num_heads)
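    # For the 16-head configs used in this test, find_factors(16) == [2, 4, 8, 16], so the loop
    # below sweeps num_gqa_groups over 8, 4, 2 and 1, i.e. from mild GQA grouping down to MQA.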

    for num_q_per_gqa_group in num_querys_per_gqa_group:
        config.num_gqa_groups = config.num_heads // num_q_per_gqa_group
        test_transformer_layer(
            dtype, model_configs, model, ckpt_attn, qkv_format, fused_qkv_params, RoPE
        )


def _run_transformer_layer(
    dtype: torch.dtype,
    config: ModelConfig,
    backend: str,
    ckpt_attn: bool,
    qkv_format: str,
    workspace_opt: bool,
    fused_qkv_params: bool,
    RoPE: bool,
    is_training: bool,
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
    """Run TransformerLayer module with one forward pass and one backward pass"""

    # Set RNG and environment variables
    reset_rng_states()
    os.environ["NVTE_FLASH_ATTN"] = "0"
    os.environ["NVTE_FUSED_ATTN"] = "0"
    os.environ["NVTE_UNFUSED_ATTN"] = "0"
    if backend == "FlashAttention":
        os.environ["NVTE_FLASH_ATTN"] = "1"
    if backend == "FusedAttention":
        os.environ["NVTE_FUSED_ATTN"] = "1"
    if backend == "UnfusedDotProductAttention":
        os.environ["NVTE_UNFUSED_ATTN"] = "1"
    _attention_backends["backend_selection_requires_update"] = True
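    # The three NVTE_*_ATTN flags above force a single backend for this run; flagging
    # backend_selection_requires_update makes the attention module redo backend selection on the
    # next forward pass instead of reusing a previously cached choice.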

    # Create input tensor
    if qkv_format == "sbhd":
        inp = torch.randn(
            config.max_seqlen_q,
            config.batch_size,
            config.hidden_size,
            dtype=dtype,
            device="cuda",
            requires_grad=True,
        )
        inp_enc = torch.randn(
            config.max_seqlen_kv,
            config.batch_size,
            config.hidden_size,
            dtype=dtype,
            device="cuda",
            requires_grad=True,
        )
    if qkv_format == "bshd":
        inp = torch.randn(
            config.batch_size,
            config.max_seqlen_q,
            config.hidden_size,
            dtype=dtype,
            device="cuda",
            requires_grad=True,
        )
        inp_enc = torch.randn(
            config.batch_size,
            config.max_seqlen_kv,
            config.hidden_size,
            dtype=dtype,
            device="cuda",
            requires_grad=True,
        )

    # Create seqlens
    if "padding" in config.attn_mask_type or qkv_format == "thd":
        if config.attn_type == "self":
            seqlens_q = torch.randint(
                1, config.max_seqlen_q, [config.batch_size], dtype=torch.int32, device="cuda"
            )
            seqlens_kv = seqlens_q
        if config.attn_type == "cross":
            if config.max_seqlen_q > 1:
                seqlens_q = torch.randint(
                    1, config.max_seqlen_q, [config.batch_size], dtype=torch.int32, device="cuda"
                )
            else:
                seqlens_q = torch.ones([config.batch_size], dtype=torch.int32, device="cuda")
            seqlens_kv = torch.randint(
                1, config.max_seqlen_kv, [config.batch_size], dtype=torch.int32, device="cuda"
            )
    else:
        seqlens_q = torch.full(
            [config.batch_size], config.max_seqlen_q, dtype=torch.int32, device="cuda"
        )
        seqlens_kv = torch.full(
            [config.batch_size], config.max_seqlen_kv, dtype=torch.int32, device="cuda"
        )
    cu_seqlens_q = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda")
    cu_seqlens_kv = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda")
    cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)
    cu_seqlens_kv[1:] = torch.cumsum(seqlens_kv, dim=0)
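    # cu_seqlens_* are exclusive prefix sums of the per-sequence lengths, e.g. (hypothetical
    # values) seqlens_q = [3, 5, 2] gives cu_seqlens_q = [0, 3, 8, 10]; in the "thd" branch below,
    # sequence i then occupies rows cu_seqlens_q[i]:cu_seqlens_q[i+1] of the packed input.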
    if qkv_format == "thd":
        inp = torch.randn(
            cu_seqlens_q[-1],
            config.hidden_size,
            dtype=dtype,
            device="cuda",
            requires_grad=True,
        )
        inp_enc = torch.randn(
            cu_seqlens_kv[-1],
            config.hidden_size,
            dtype=dtype,
            device="cuda",
            requires_grad=True,
        )

    sigma = 0.02
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    layer_number = 1
    drop_path_rate = 0.0
    drop_path_rates = [rate.item() for rate in torch.linspace(0, drop_path_rate, config.num_layers)]

    # Create bias
    bias = None
    if config.attn_bias_type == "post_scale_bias":
        bias = torch.randn(
            1,
            config.num_heads,
            config.max_seqlen_q,
            config.max_seqlen_kv,
            dtype=dtype,
            device="cuda",
        )

    # Create RoPE
    rotary_pos_emb = None
    if RoPE:
        PE = RotaryPositionEmbedding(dim=config.head_dim_qk)
        rotary_pos_emb = PE(config.max_seqlen_q).to(device="cuda")

    # Set up model
    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_heads,
        num_gqa_groups=config.num_gqa_groups,
        layernorm_epsilon=1e-5,
        hidden_dropout=0.0,
        attention_dropout=config.dropout_p,
        init_method=init_method,
        output_layer_init_method=output_layer_init_method,
        layer_number=layer_number,
        kv_channels=config.head_dim_qk,
        self_attn_mask_type=config.attn_mask_type,
        tp_group=None,
        tp_size=1,
        params_dtype=dtype,
        get_rng_state_tracker=None,
        fuse_wgrad_accumulation=False,
        seq_length=config.max_seqlen_q,
        micro_batch_size=config.batch_size,
        sequence_parallel=False,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
        layer_type="encoder" if config.attn_type == "self" else "decoder",
        drop_path_rate=drop_path_rates[layer_number - 1],
        set_parallel_mode=True,
        fuse_qkv_params=fused_qkv_params,
        zero_centered_gamma=False,
        qkv_weight_interleaved=False,
        ub_tp_comm_overlap=False,
        bias=True,
        attn_input_format=qkv_format,
    ).to(dtype=dtype, device="cuda")
    if not is_training:
        block = block.eval()

    # Create ALiBi slopes
    alibi_slopes = None
    if config.attn_bias_type == "alibi" and config.alibi_type == "custom":
        alibi_slopes = torch.randn(config.num_heads).abs().to(dtype=torch.float32, device="cuda")

    # Run a forward and backward pass
    out = block(
        inp,
        self_attn_mask_type=config.attn_mask_type,
        encoder_output=inp_enc if config.attn_type == "cross" else None,
        enc_dec_attn_mask_type=config.attn_mask_type if config.attn_type == "cross" else None,
        checkpoint_core_attention=False,
        rotary_pos_emb=rotary_pos_emb,
        core_attention_bias_type=config.attn_bias_type,
        core_attention_bias=bias,
        alibi_slopes=alibi_slopes,
        max_seqlen_q=config.max_seqlen_q,
        max_seqlen_kv=config.max_seqlen_kv,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_kv=cu_seqlens_kv,
    )
    if is_training:
        loss = out.sum()
        loss.backward()

    return out, inp.grad


model_configs_fp8_extra_state = {
    # test: ModelConfig(b, sq, hq, dqk)
    "large": ModelConfig(2, 128, 4, 128, num_layers=1),
}


@pytest.mark.skipif(not fp8_attn_available, reason=reason_for_no_fp8_attn)
@pytest.mark.skipif(get_cudnn_version() < (9, 3, 0), reason="cuDNN 9.3.0+ is required.")
@pytest.mark.parametrize("model", ["large"])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_dpa_fp8_extra_state(model, dtype):
    """Test DotProductAttention module in FP8 with checkpointing"""
    config = model_configs_fp8_extra_state[model]
    # Test backend availability
    is_training = True
    available_backends, _, fused_attn_backends = get_available_attention_backends(
        config,
        qkv_dtype=torch.float8_e4m3fn,
        qkv_layout="sb3hd",
        is_training=is_training,
        deterministic=_deterministic,
    )
    flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
    if not fused_attn_supported and not flash_attn_supported:
        pytest.skip("No attention backend available.")

    outputs = _run_dpa_fp8_extra_state(dtype, config, checkpoint=False)
    outputs_checkpoint = _run_dpa_fp8_extra_state(dtype, config, checkpoint=True)
    outputs_checkpoint_v1_6 = _run_dpa_fp8_extra_state(
        dtype, config, mimic_v1_6=True, checkpoint=True
    )

    # Check that results match
    tols = dtype_tols(dtype)
    if dtype in (torch.float16, torch.bfloat16):
        tols.update(dict(rtol=2e-2, atol=2e-3))
    for i, (ref, test) in enumerate(zip(outputs, outputs_checkpoint)):
        torch.testing.assert_close(
            test,
            ref,
            **tols,
        )
    for i, (ref, test) in enumerate(zip(outputs, outputs_checkpoint_v1_6)):
        torch.testing.assert_close(
            test,
            ref,
            **tols,
        )


def _run_dpa_fp8_extra_state(dtype, config, checkpoint=False, mimic_v1_6=False):
    """Run DotProductAttention module in FP8 with checkpointing"""
    steps = 10
    path = "checkpoint.pt"
    fp8_enabled = True
    fp8_recipe = recipe.DelayedScaling(
        margin=0,
        fp8_format=recipe.Format.HYBRID,
        amax_history_len=1,
        amax_compute_algo="most_recent",
        fp8_dpa=fp8_enabled,
        fp8_mha=False,
    )
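    # With amax_history_len=1 and amax_compute_algo="most_recent", the FP8 scales follow only the
    # latest iteration's amax, which should keep the pre- and post-checkpoint halves of this test
    # on the same scaling state (a reading of the recipe settings above, not an API requirement).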

    reset_rng_states()
    hidden_states = torch.randn(
        (config.max_seqlen_q, config.batch_size, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )

    def get_model(dtype, config):
        sigma = 0.023
        init_method = init_method_normal(sigma)
        output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

        with quantized_model_init(enabled=fp8_enabled, recipe=fp8_recipe):
            block = TransformerLayer(
                config.hidden_size,
                4 * config.hidden_size,
                config.num_heads,
                init_method=init_method,
                output_layer_init_method=output_layer_init_method,
                hidden_dropout=0.0,
                attention_dropout=0.0,
                fuse_qkv_params=True,
                params_dtype=dtype,
                device="cuda",
            )
        return block

    block = get_model(dtype, config)
    for i in range(steps // 2):
        with autocast(enabled=fp8_enabled, recipe=fp8_recipe):
            output = block(hidden_states, None)
            loss = output.sum()
            loss.backward()

    if checkpoint:
        sd = block.state_dict()
        if mimic_v1_6:
            sd["self_attention.core_attention.fused_attention._extra_state"] = sd[
                "self_attention.core_attention._extra_state"
            ]
            del sd["self_attention.core_attention._extra_state"]
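            # Renaming the key mimics checkpoints written by TE v1.6, which kept the attention
            # _extra_state under the fused_attention submodule rather than under core_attention.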
        torch.save(sd, path)

        param_grads = []
        for p in block.parameters():
            if p.requires_grad:
                param_grads.append(p.grad.clone())

        _cpu_rng_state_new = torch.get_rng_state()
        _cuda_rng_state_new = torch.cuda.get_rng_state()

        del block
        block = get_model(dtype, config)
        block.load_state_dict(torch.load(path, weights_only=False))
        torch.set_rng_state(_cpu_rng_state_new)
        torch.cuda.set_rng_state(_cuda_rng_state_new)

        for p in block.parameters():
            if p.requires_grad:
                p.grad = param_grads.pop(0)

        assert not param_grads, "Not all saved parameter gradients were restored."

    for i in range((steps + 1) // 2):
        with autocast(enabled=fp8_enabled, recipe=fp8_recipe):
            output = block(hidden_states, None)
            loss = output.sum()
            loss.backward()

    torch.cuda.synchronize()

    if os.path.exists(path):
        os.remove(path)

    outputs = [output, hidden_states.grad]
    for p in block.parameters():
        if p.requires_grad:
            outputs.append(p.grad)

    return outputs


model_configs_fp8_vs_f16 = {
    # test: ModelConfig(b, sq, hq, dqk)
    "fp8_9": ModelConfig(2, 2048, 16, 128),
    "fp8_10": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12),
    "fp8_11": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4),
    "fp8_12": ModelConfig(2, 2048, 16, 128, attn_mask_type="causal"),
    "fp8_13": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12, attn_mask_type="causal"),
    "fp8_14": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="causal"),
    "fp8_15": ModelConfig(2, 2048, 16, 128, attn_mask_type="padding"),
    "fp8_16": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12, attn_mask_type="padding"),
    "fp8_17": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="padding"),
    "fp8_18": ModelConfig(2, 2048, 16, 128, attn_mask_type="padding_causal"),
    "fp8_19": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12, attn_mask_type="padding_causal"),
    "fp8_20": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="padding_causal"),
}

param_types_fp8_vs_f16 = [torch.float16, torch.bfloat16]
qkv_layout_fp8_vs_f16 = ["sbh3d", "bshd_bshd_bshd", "sbhd_sbhd_sbhd"]
qkv_format_fp8_vs_f16 = ["bshd", "sbhd"]
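# Layout string convention used below: "sbh3d" packs Q/K/V into a single tensor with the 3-way
# split adjacent to the head dimension, while "bshd_bshd_bshd" and "sbhd_sbhd_sbhd" keep Q, K and
# V as three separate tensors in batch-first or sequence-first order.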


@pytest.mark.skipif(get_cudnn_version() < (9, 2, 1), reason="cuDNN 9.2.1+ is required.")
@pytest.mark.skipif(not fp8_attn_available, reason=reason_for_no_fp8_attn)
@pytest.mark.parametrize("dtype", param_types_fp8_vs_f16)
@pytest.mark.parametrize("model", model_configs_fp8_vs_f16.keys())
@pytest.mark.parametrize("qkv_format", qkv_format_fp8_vs_f16)
@pytest.mark.parametrize("input_layernorm", [True, False])
@pytest.mark.parametrize("fp8_dpa_bwd", [True, False])
@pytest.mark.parametrize("RoPE", [True, False])
@pytest.mark.parametrize("is_training", [True, False])
@pytest.mark.parametrize("scaling_mode", ["delayed", "current"])
def test_mha_fp8_vs_f16(
    dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd, RoPE, is_training, scaling_mode
):
    """Test MultiHeadAttention module in FP8"""
    os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1"
    os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0"
    config = model_configs_fp8_vs_f16[model]

    # Test backend availability
    if scaling_mode == "delayed":
        fp8_recipe = recipe.DelayedScaling(
            margin=0,
            fp8_format=recipe.Format.HYBRID,
            amax_history_len=1,
            amax_compute_algo="most_recent",
            fp8_dpa=True,
            fp8_mha=True,
        )
    elif scaling_mode == "current":
        fp8_recipe = recipe.Float8CurrentScaling(
            fp8_format=recipe.Format.HYBRID,
            fp8_dpa=True,
            fp8_mha=True,
        )
    fp8_meta = {}
    fp8_meta["recipe"] = fp8_recipe
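    # DelayedScaling derives FP8 scales from the amax history of previous iterations, whereas
    # Float8CurrentScaling computes them from the current tensor's amax; both recipes are driven
    # through the same fp8_dpa/fp8_mha switches here.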
    available_backends, _, fused_attn_backends = get_available_attention_backends(
        config,
        qkv_dtype=torch.float8_e4m3fn,
        qkv_layout=qkv_format.replace("hd", "h3d"),
        fp8=True,
        fp8_meta=fp8_meta,
        is_training=is_training,
        deterministic=_deterministic,
    )
    flash_attn_supported, fused_attn_supported_fp8, unfused_attn_supported = available_backends
    if flash_attn_supported + fused_attn_supported_fp8 < 1:
        pytest.skip("No FP8 attention backend available.")
    fused_attn_supported_f16 = False
    if not fp8_dpa_bwd:
        available_backends, _, fused_attn_backends = get_available_attention_backends(
            config,
            qkv_dtype=dtype,
            qkv_layout=qkv_format.replace("hd", "h3d"),
            is_training=is_training,
            deterministic=_deterministic,
        )
        _, fused_attn_supported_f16, _ = available_backends
        if not fused_attn_supported_f16:
            pytest.skip("No attention backend available.")

    if flash_attn_supported:
        os.environ["NVTE_FLASH_ATTN"] = "1"
        os.environ["NVTE_FUSED_ATTN"] = "0"
        os.environ["NVTE_UNFUSED_ATTN"] = "0"
        _attention_backends["backend_selection_requires_update"] = True
        logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = True")
        flash_attn_fwd_fp8, param_names, flash_attn_bwd_fp8 = _run_mha_fp8_vs_f16(
            dtype, config, True, qkv_format, input_layernorm, RoPE, is_training, fp8_recipe
        )

    if fused_attn_supported_fp8:
        os.environ["NVTE_FLASH_ATTN"] = "0"
        os.environ["NVTE_FUSED_ATTN"] = "1"
        os.environ["NVTE_UNFUSED_ATTN"] = "0"
        _attention_backends["backend_selection_requires_update"] = True
        logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = True")
        fused_attn_fwd_fp8, param_names, fused_attn_bwd_fp8 = _run_mha_fp8_vs_f16(
            dtype, config, True, qkv_format, input_layernorm, RoPE, is_training, fp8_recipe
        )

    if fused_attn_supported_f16:
        os.environ["NVTE_FLASH_ATTN"] = "0"
        os.environ["NVTE_FUSED_ATTN"] = "1"
        os.environ["NVTE_UNFUSED_ATTN"] = "0"
        _attention_backends["backend_selection_requires_update"] = True
        logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = False")
        fused_attn_fwd_f16, param_names, fused_attn_bwd_f16 = _run_mha_fp8_vs_f16(
            dtype, config, False, qkv_format, input_layernorm, RoPE, is_training, fp8_recipe
        )

    atol = 5e-1
    rtol = 5e-1
    rmse_tol = 0.15
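    # compare_and_assert (from the local test utils) is assumed to check elementwise closeness
    # under atol/rtol plus a relative RMSE against rmse_tol; the loose tolerances reflect that an
    # FP8 run is being compared against an F16 reference.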
    if flash_attn_supported and fused_attn_supported_f16:
        logging.debug("========== {:^25s} ==========".format("flash fp8 vs fused f16:"))
        logging.debug("========== {:^25s} ==========".format("forward output"))
        compare_and_assert(
            flash_attn_fwd_fp8,
            fused_attn_fwd_f16,
            "flash_attn_fwd_fp8",
            "fused_attn_fwd_f16",
            atol,
            rtol,
            rmse_tol,
            True,
        )
    if fused_attn_supported_fp8 and fused_attn_supported_f16:
        logging.debug("========== {:^25s} ==========".format("fused fp8 vs fused f16:"))
        logging.debug("========== {:^25s} ==========".format("forward output"))
        compare_and_assert(
            fused_attn_fwd_fp8,
            fused_attn_fwd_f16,
            "fused_attn_fwd_fp8",
            "fused_attn_fwd_f16",
            atol,
            rtol,
            rmse_tol,
            True,
        )

        if is_training:
            for i in range(len(param_names[:1])):
                logging.debug("========== {:^25s} ==========".format(param_names[i]))
                compare_and_assert(
                    fused_attn_bwd_fp8[i],
                    fused_attn_bwd_f16[i],
                    f"fused_attn_bwd_fp8[{i}]",
                    f"fused_attn_bwd_f16[{i}]",
                    atol,
                    rtol,
                    rmse_tol,
                    True,
                )


def _run_mha_fp8_vs_f16(
    dtype, config, fp8_mha, qkv_format, input_layernorm, RoPE, is_training, fp8_recipe
):
    """Run MultiHeadAttention module in FP8"""
    reset_rng_states()
    _DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
    _DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed)

    def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
        """Get cuda rng tracker."""
        return _DUMMY_CUDA_RNG_STATE_TRACKER

    with quantized_model_init(enabled=fp8_mha, recipe=fp8_recipe):
        rotary_pos_emb = None
        if RoPE:
            PE = RotaryPositionEmbedding(dim=config.head_dim_qk)
            rotary_pos_emb = PE(config.max_seqlen_q).to(device="cuda")
        mha = MultiheadAttention(
            hidden_size=config.hidden_size,
            num_attention_heads=config.num_heads,
            kv_channels=config.head_dim_qk,
            num_gqa_groups=config.num_gqa_groups,
            attention_dropout=config.dropout_p,
            layer_number=1,
            bias=True,
            get_rng_state_tracker=get_dummy_cuda_rng_tracker,
            params_dtype=dtype,
            input_layernorm=input_layernorm,
            fuse_qkv_params=True,
            attention_type="self",
            qkv_weight_interleaved=True,
            qkv_format=qkv_format,
        ).to(dtype=dtype, device="cuda")
        if not is_training:
            mha = mha.eval()

    if "padding" in config.attn_mask_type or qkv_format == "thd":
        if config.attn_type == "self":
            seqlens_q = torch.randint(
                1, config.max_seqlen_q, [config.batch_size], dtype=torch.int32, device="cuda"
            )
            seqlens_kv = seqlens_q
        if config.attn_type == "cross":
            seqlens_q = torch.randint(
                1, config.max_seqlen_q, [config.batch_size], dtype=torch.int32, device="cuda"
            )
            seqlens_kv = torch.randint(
                1, config.max_seqlen_kv, [config.batch_size], dtype=torch.int32, device="cuda"
            )
    else:
        seqlens_q = torch.full(
            [config.batch_size], config.max_seqlen_q, dtype=torch.int32, device="cuda"
        )
        seqlens_kv = torch.full(
            [config.batch_size], config.max_seqlen_kv, dtype=torch.int32, device="cuda"
        )
    cu_seqlens_q = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda")
    cu_seqlens_kv = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda")
    cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)
    cu_seqlens_kv[1:] = torch.cumsum(seqlens_kv, dim=0)

    dim_to_num = {
        "b": config.batch_size,
        "sq": config.max_seqlen_q,
        "skv": config.max_seqlen_kv,
        "h": config.num_heads,
        "hg": config.num_gqa_groups,
        "d": config.head_dim_qk,
        "t": cu_seqlens_q[-1],
        "tg": cu_seqlens_kv[-1],
        "3": 3,
        "2": 2,
        "1": 1,
    }
    layout = "_".join(qkv_format)
    layout = layout.replace("s", "sq")
    tensor_shape = [dim_to_num[j] for j in layout.split("_")]
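    # Worked example of the mapping above: qkv_format "sbhd" becomes layout "sq_b_h_d", so
    # tensor_shape is [max_seqlen_q, batch_size, num_heads, head_dim_qk]; the last two dims are
    # then flattened into the hidden size for the MultiheadAttention input below.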
    tensor = 0.01 * torch.randint(-100, 100, tensor_shape, dtype=dtype, device="cuda")
    hidden_states = tensor.view(*tensor.shape[:-2], -1)
    if is_training:
        hidden_states.requires_grad = True
    tensor = 0.01 * torch.randn(tensor_shape, dtype=dtype, device="cuda")
    out_grad = tensor.view(*tensor.shape[:-2], -1)

    with autocast(enabled=fp8_mha, recipe=fp8_recipe):
        out = mha(
            hidden_states,
            attn_mask_type=config.attn_mask_type,
            checkpoint_core_attention=False,
            core_attention_bias_type=config.attn_bias_type,
            is_first_microbatch=None,
            rotary_pos_emb=rotary_pos_emb,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_kv=cu_seqlens_kv,
        )
    if is_training:
        out.backward(out_grad)

    param_names = []
    param_names.append("hidden_states.grad")
    params = []
    params.append(hidden_states)
    for name, param in mha.named_parameters():
        if param.requires_grad:
            param_names.append(name + ".grad")
            params.append(param)

    if is_training:
        return out, param_names, tuple(x.grad for x in params)
    return out, param_names, tuple(None for x in params)


@pytest.mark.skipif(get_cudnn_version() < (9, 2, 1), reason="cuDNN 9.2.1+ is required.")
@pytest.mark.skipif(not fp8_attn_available, reason=reason_for_no_fp8_attn)
@pytest.mark.parametrize("dtype", param_types_fp8_vs_f16)
@pytest.mark.parametrize("model", model_configs_fp8_vs_f16.keys())
@pytest.mark.parametrize("qkv_layout", qkv_layout_fp8_vs_f16)
@pytest.mark.parametrize("fp8_dpa_bwd", [True, False])
@pytest.mark.parametrize("is_training", [True, False])
@pytest.mark.parametrize("scaling_mode", ["delayed", "current"])
def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scaling_mode):
    """Test DotProductAttention module in FP8"""
    config = model_configs_fp8_vs_f16[model]

    # TODO(cyang): think of another way to verify dropout results
    # test cuDNN FP8 dropout
    # 1. we modify the config here to not affect mha_fp8_vs_f16 tests
    # 2. there is no other backend that implements dropout the same way as cuDNN FP8, and as an
    #    indirect verification method, we create Q/K/V as all 1s and check if O is all 1s
    # 3. we avoid running FP16/BF16 kernels as they do not have dropout support on Blackwell
    # if "padding" not in config.attn_mask_type and "causal" not in config.attn_mask_type:
    #    if get_device_compute_capability() >= (10, 0):
    #        config.dropout_p = 0.1

    os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0"
    os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1"
    os.environ["NVTE_UnfusedDPA_Emulate_FP8"] = "1"

    # Test backend availability
    if scaling_mode == "delayed":
        fp8_recipe = recipe.DelayedScaling(
            margin=0,
            fp8_format=recipe.Format.HYBRID,
            amax_history_len=1,
            amax_compute_algo="most_recent",
            fp8_dpa=True,
        )
    elif scaling_mode == "current":
        fp8_recipe = recipe.Float8CurrentScaling(
            fp8_format=recipe.Format.HYBRID,
            fp8_dpa=True,
        )
    fp8_meta = {}
    fp8_meta["recipe"] = fp8_recipe
    available_backends, _, fused_attn_backends = get_available_attention_backends(
        config,
        qkv_dtype=torch.float8_e4m3fn,
        qkv_layout=qkv_layout,
        fp8=True,
        fp8_meta=fp8_meta,
        is_training=is_training,
        deterministic=_deterministic,
    )
    flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
    if flash_attn_supported + fused_attn_supported < 1:
        pytest.skip("No FP8 attention backend available.")
    if not fp8_dpa_bwd:
        available_backends, _, fused_attn_backends = get_available_attention_backends(
            config,
            qkv_dtype=dtype,
            qkv_layout=qkv_layout,
            is_training=is_training,
            deterministic=_deterministic,
        )
        _, fused_attn_supported, _ = available_backends
        if not fused_attn_supported:
            pytest.skip("No attention backend available.")
    if config.num_heads != config.num_gqa_groups and "3" in qkv_layout:
        pytest.skip("qkv_layout not applicable for MQA/GQA")

    if flash_attn_supported:
        os.environ["NVTE_FLASH_ATTN"] = "1"
        os.environ["NVTE_FUSED_ATTN"] = "0"
        os.environ["NVTE_UNFUSED_ATTN"] = "0"
        _attention_backends["backend_selection_requires_update"] = True
        logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True (FlashAttention)")
        flash_attn_fwd_fp8, flash_attn_bwd_fp8 = _run_dpa_fp8_vs_f16(
            dtype, config, True, qkv_layout, is_training, fp8_recipe
        )

    if unfused_attn_supported:
        os.environ["NVTE_FLASH_ATTN"] = "0"
        os.environ["NVTE_FUSED_ATTN"] = "0"
        os.environ["NVTE_UNFUSED_ATTN"] = "1"
        _attention_backends["backend_selection_requires_update"] = True
        logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True (UnfusedDotProductAttention)")
        unfused_attn_fwd_fp8, unfused_attn_bwd_fp8 = _run_dpa_fp8_vs_f16(
            dtype, config, True, qkv_layout, is_training, fp8_recipe
        )

    os.environ["NVTE_FLASH_ATTN"] = "0"
    os.environ["NVTE_FUSED_ATTN"] = "1"
    os.environ["NVTE_UNFUSED_ATTN"] = "0"
    _attention_backends["backend_selection_requires_update"] = True
    logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True (FusedAttention)")
    fused_attn_fwd_fp8, fused_attn_bwd_fp8 = _run_dpa_fp8_vs_f16(
        dtype, config, True, qkv_layout, is_training, fp8_recipe
    )

    os.environ["NVTE_FLASH_ATTN"] = "0"
    os.environ["NVTE_FUSED_ATTN"] = "1"
    os.environ["NVTE_UNFUSED_ATTN"] = "0"
    if config.dropout_p == 0.0:
        # the F16/BF16 reference is only run when dropout is off; the cuDNN FP8 dropout case is verified separately below
        logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = False (FusedAttention)")
        fused_attn_fwd_f16, fused_attn_bwd_f16 = _run_dpa_fp8_vs_f16(
            dtype, config, False, qkv_layout, is_training, fp8_recipe
        )

    atol = 5e-1
    rtol = 5e-2
    rmse_tol = 0.11
    bwd_names = ["dq", "dk", "dv"]
    if flash_attn_supported:
        logging.debug("========== {:^25s} ==========".format("flash fp8 vs fused f16:"))
        logging.debug("========== {:^25s} ==========".format("forward output"))
        compare_and_assert(
            flash_attn_fwd_fp8,
            fused_attn_fwd_f16,
            "flash_attn_fwd_fp8",
            "fused_attn_fwd_f16",
            atol,
            rtol,
            rmse_tol,
            True,
        )
    if unfused_attn_supported:
        logging.debug("========== {:^25s} ==========".format("unfused fp8 vs fused f16:"))
        logging.debug("========== {:^25s} ==========".format("forward output"))
        compare_and_assert(
            unfused_attn_fwd_fp8,
            fused_attn_fwd_f16,
            "unfused_attn_fwd_fp8",
            "fused_attn_fwd_f16",
            atol,
            rtol,
            rmse_tol,
            True,
        )
        if is_training:
            for i, _ in enumerate(fused_attn_bwd_f16):
                logging.debug("========== {:^25s} ==========".format(bwd_names[i]))
                compare_and_assert(
                    unfused_attn_bwd_fp8[i],
                    fused_attn_bwd_f16[i],
                    f"unfused_attn_bwd_fp8[{i}]",
                    f"fused_attn_bwd_f16[{i}]",
                    atol,
                    rtol,
                    rmse_tol,
                    True,
                )
    if config.dropout_p != 0.0:
        # test cuDNN FP8 dropout
        assert torch.all(
            fused_attn_fwd_fp8 == 1
        ), "fused_attn_fwd_fp8 must be all 1s when Q/K/V are all 1s."
    else:
        logging.debug("========== {:^25s} ==========".format("fused fp8 vs fused f16:"))
        logging.debug("========== {:^25s} ==========".format("forward output"))
        compare_and_assert(
            fused_attn_fwd_fp8,
            fused_attn_fwd_f16,
            "fused_attn_fwd_fp8",
            "fused_attn_fwd_f16",
            atol,
            rtol,
            rmse_tol,
            True,
        )
        if is_training:
            for i, _ in enumerate(fused_attn_bwd_f16):
                logging.debug("========== {:^25s} ==========".format(bwd_names[i]))
                compare_and_assert(
                    fused_attn_bwd_fp8[i],
                    fused_attn_bwd_f16[i],
                    f"fused_attn_bwd_fp8[{i}]",
                    f"fused_attn_bwd_f16[{i}]",
                    atol,
                    rtol,
                    rmse_tol,
                    True,
                )
    os.environ["NVTE_UnfusedDPA_Emulate_FP8"] = "0"


def _run_dpa_fp8_vs_f16(dtype, config, fp8_dpa, qkv_layout, is_training, fp8_recipe):
    """Run DotProductAttention module in FP8"""
    reset_rng_states()
    _DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
    _DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed)

    def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
        """Get cuda rng tracker."""
        return _DUMMY_CUDA_RNG_STATE_TRACKER

    qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()])
    with quantized_model_init(enabled=fp8_dpa):
        dpa = DotProductAttention(
            config.num_heads,
            config.head_dim_qk,
            num_gqa_groups=config.num_gqa_groups,
            attention_dropout=config.dropout_p,
            sequence_parallel=False,
            tp_size=1,
            get_rng_state_tracker=get_dummy_cuda_rng_tracker,
            tp_group=None,
            layer_number=1,
            attention_type="self",
            qkv_format=qkv_format,
        ).to(dtype=dtype, device="cuda")
        if not is_training:
            dpa = dpa.eval()

    if "padding" in config.attn_mask_type or qkv_format == "thd":
        if config.attn_type == "self":
            seqlens_q = torch.randint(
                1, config.max_seqlen_q, [config.batch_size], dtype=torch.int32, device="cuda"
            )
            seqlens_kv = seqlens_q
        if config.attn_type == "cross":
            seqlens_q = torch.randint(
                1, config.max_seqlen_q, [config.batch_size], dtype=torch.int32, device="cuda"
            )
            seqlens_kv = torch.randint(
                1, config.max_seqlen_kv, [config.batch_size], dtype=torch.int32, device="cuda"
            )
    else:
        seqlens_q = torch.full(
            [config.batch_size], config.max_seqlen_q, dtype=torch.int32, device="cuda"
        )
        seqlens_kv = torch.full(
            [config.batch_size], config.max_seqlen_kv, dtype=torch.int32, device="cuda"
        )
    cu_seqlens_q = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda")
    cu_seqlens_kv = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda")
    cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)
    cu_seqlens_kv[1:] = torch.cumsum(seqlens_kv, dim=0)

    dim_to_num = {
        "b": config.batch_size,
        "sq": config.max_seqlen_q,
        "skv": config.max_seqlen_kv,
        "h": config.num_heads,
        "hg": config.num_gqa_groups,
        "d": config.head_dim_qk,
        "t": cu_seqlens_q[-1],
        "tg": cu_seqlens_kv[-1],
        "3": 3,
        "2": 2,
        "1": 1,
    }
    inp = []
    for i, layout in enumerate(qkv_layout.split("_")):
        layout = "_".join(layout)
        if i == 0:
            layout = layout.replace("s", "sq")
        else:
            layout = layout.replace("s", "skv")
            layout = layout.replace("h", "hg")
            layout = layout.replace("t", "tg")
        tensor_shape = [dim_to_num[j] for j in layout.split("_")]
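        # Worked example: for qkv_layout "sbh3d" this layout string is "sq_b_h_3_d"; the digit
        # "3" marks the packed Q/K/V axis, so the tensor built below is split three ways along
        # that dim, whereas "bshd_bshd_bshd" produces three separate tensors (no digit, no split).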
        if config.dropout_p == 0.0:
            tensor = torch.randn(tensor_shape, dtype=dtype, device="cuda")
        else:
            # test cuDNN FP8 dropout
            tensor = torch.ones(tensor_shape, dtype=dtype, device="cuda")
        tensor_count = 1
        split_dim = 0
        for dim, l in enumerate(layout.split("_")):
            if l.isdigit():
                tensor_count = int(l)
                split_dim = dim
                break
        tensors = torch.split(tensor, 1, dim=split_dim) if split_dim != 0 else [tensor]
        for j in range(tensor_count):
            if split_dim != 0:
                inp.append(tensors[j].squeeze(split_dim))
            else:
                inp.append(tensors[j])
    for i in range(3):
        inp[i].requires_grad = True

    qkv_format_kv = "_".join(qkv_format)
    qkv_format_kv = qkv_format_kv.replace("s", "sq")
    out_grad_shape = [dim_to_num[i] for i in qkv_format_kv.split("_")]
    out_grad_shape_new = [*out_grad_shape[:-2], out_grad_shape[-2] * out_grad_shape[-1]]
    out_grad = torch.randn(out_grad_shape_new, dtype=dtype, device="cuda")

    with autocast(enabled=fp8_dpa, recipe=fp8_recipe):
        out = dpa(
            inp[0],
            inp[1],
            inp[2],
            qkv_format=qkv_format,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_kv=cu_seqlens_kv,
            max_seqlen_q=config.max_seqlen_q,
            max_seqlen_kv=config.max_seqlen_kv,
            attn_mask_type=config.attn_mask_type,
            checkpoint_core_attention=False,
            core_attention_bias_type=config.attn_bias_type,
            fp8_output=fp8_dpa,
        )
    if is_training:
        out.backward(out_grad)

    if is_training:
        return out, (inp[0].grad, inp[1].grad, inp[2].grad)
    return out, (None, None, None)


model_configs_fp8 = {
    # test: ModelConfig(b, sq, hq, dqk)
    "fp8_1": ModelConfig(1, 512, 1, 64),
    "fp8_2": ModelConfig(4, 512, 16, 64),
    "fp8_3": ModelConfig(1, 2048, 1, 128),
    "fp8_4": ModelConfig(2, 2048, 24, 128),
    "fp8_5": ModelConfig(1, 512, 1, 64, attn_mask_type="causal"),
    "fp8_6": ModelConfig(4, 512, 16, 64, attn_mask_type="causal"),
    "fp8_7": ModelConfig(1, 2048, 1, 128, attn_mask_type="causal"),
    "fp8_8": ModelConfig(2, 2048, 24, 128, attn_mask_type="causal"),
}
param_types_fp8 = [torch.float16, torch.bfloat16]
cudnn_frontend_version = int(os.getenv("NVTE_FUSED_ATTN_FE_VER", "1"))
models_v0 = ["fp8_1", "fp8_2", "fp8_5", "fp8_6"]
models_v1 = ["fp8_3", "fp8_4", "fp8_7", "fp8_8"]


@pytest.mark.skipif(
    (
        get_cudnn_version() < (8, 9, 3)
        if cudnn_frontend_version == 0
        else get_cudnn_version() < (9, 2, 1)
    ),
    reason=f"""cuDNN {"8.9.3" if cudnn_frontend_version == 0 else "9.2.1"}+ is required.""",
)
@pytest.mark.skipif(not fp8_attn_available, reason=reason_for_no_fp8_attn)
@pytest.mark.parametrize("dtype", param_types_fp8)
@pytest.mark.parametrize("model", models_v1 if cudnn_frontend_version == 1 else models_v0)
def test_custom_mha_fp8_vs_f16(dtype, model):
    """Test FP8 dot product attention implementations based on cuDNN frontend
    v0.9 and v1.0+. Each test compares results from a custom implementation of
    an FP8 MHA module, i.e. Custom_MHA_FP8(), to results from an F16 MHA
    implementation, i.e. transformer_engine.pytorch.attention.MultiHeadAttention.
    Both paths take F16 input and output. QKV layout is t3hd or bs3hd."""

    config = model_configs_fp8[model]

    # Test backend availability
    is_training = True
    available_backends, _, fused_attn_backends = get_available_attention_backends(
        config,
        qkv_dtype=torch.float8_e4m3fn,
        qkv_layout="t3hd" if cudnn_frontend_version == 0 else "bs3hd",
        is_training=is_training,
        deterministic=_deterministic,
    )
    flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
    if not (fused_attn_backends and unfused_attn_supported):
        pytest.skip("Not enough backends to run this test with.")

    fused_attn_fwd_fp8, fused_attn_bwd_fp8 = _run_custom_mha_fp8(dtype, config, "FusedAttention")
    unfused_attn_fwd_f16, unfused_attn_bwd_f16 = _run_ref_mha_f16(
        dtype, config, "UnfusedDotProductAttention"
    )

    atol = 5e-1
    rtol = 5e-1
    rmse_tol = 0.13
    compare_and_assert(
        fused_attn_fwd_fp8,
        unfused_attn_fwd_f16,
        "fused_attn_fwd_fp8",
        "unfused_attn_fwd_f16",
        atol,
        rtol,
        rmse_tol,
        True,
    )
    compare_and_assert(
        fused_attn_bwd_fp8,
        unfused_attn_bwd_f16,
        "fused_attn_bwd_fp8",
        "unfused_attn_bwd_f16",
        atol,
        rtol,
        rmse_tol,
        True,
    )


def _run_custom_mha_fp8(dtype, config, backend):
    """Run Custom_MHA_FP8 with FP8 FusedAttention backend. Both input and output
    are in F16. QKV GEMM, DPA, and projection GEMM are calculated in FP8."""
    reset_rng_states()
    os.environ["NVTE_FLASH_ATTN"] = "0"
    os.environ["NVTE_FUSED_ATTN"] = "0"
    os.environ["NVTE_UNFUSED_ATTN"] = "0"
    if backend == "FlashAttention":
        os.environ["NVTE_FLASH_ATTN"] = "1"
    if backend == "FusedAttention":
        os.environ["NVTE_FUSED_ATTN"] = "1"
    if backend == "UnfusedDotProductAttention":
        os.environ["NVTE_UNFUSED_ATTN"] = "1"
    _attention_backends["backend_selection_requires_update"] = True

    inp = 0.0001 * torch.randint(
        -100,
        100,
        (config.batch_size * config.max_seqlen_q, config.num_heads * config.head_dim_qk),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )
    seqlens = torch.full([config.batch_size], config.max_seqlen_q, dtype=torch.int32, device="cuda")
    cu_seqlens = torch.zeros(config.batch_size + 1, device="cuda", dtype=torch.int32)
    cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)

    out_grad = 0.01 * torch.randn(
        config.batch_size * config.max_seqlen_q,
        config.num_heads * config.head_dim_qk,
        dtype=dtype,
        device="cuda",
    )
    torch.save(out_grad, "out_grad.pt")

    fp8_recipe = recipe.DelayedScaling(
        margin=0,
        fp8_format=recipe.Format.HYBRID,
        amax_history_len=1,
        amax_compute_algo="most_recent",
    )

    mha = Custom_MHA_FP8(config).to(dtype=dtype, device="cuda")
    with autocast(enabled=True, recipe=fp8_recipe):
        out = mha(inp, cu_seqlens, config.max_seqlen_q)
    out.backward(out_grad)

    out = torch.load("out.pt")
    dqkv = torch.load("dqkv.pt")
    return (
        out.view(config.batch_size, config.max_seqlen_q, -1),
        dqkv.view(
            config.batch_size, config.max_seqlen_q, 3, config.num_heads, config.head_dim_qk
        ).contiguous(),
    )


def _run_ref_mha_f16(dtype, config, backend):
    """Run reference F16 FusedAttention. Both input and output
    are in F16. QKV GEMM, DPA, and projection GEMM are also in F16."""

    os.environ["NVTE_FLASH_ATTN"] = "0"
    os.environ["NVTE_FUSED_ATTN"] = "0"
    os.environ["NVTE_UNFUSED_ATTN"] = "0"
    if backend == "FlashAttention":
        os.environ["NVTE_FLASH_ATTN"] = "1"
    if backend == "FusedAttention":
        os.environ["NVTE_FUSED_ATTN"] = "1"
    if backend == "UnfusedDotProductAttention":
        os.environ["NVTE_UNFUSED_ATTN"] = "1"
    _attention_backends["backend_selection_requires_update"] = True

    inp = torch.load("qkv.pt").to(device="cuda")
    inp.requires_grad = True
    seqlens = torch.full([config.batch_size], config.max_seqlen_q, dtype=torch.int32, device="cuda")
    cu_seqlens = torch.zeros(config.batch_size + 1, device="cuda", dtype=torch.int32)
    cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)
    out_grad = (
        torch.load("out_grad.pt").to(device="cuda").view(config.batch_size, config.max_seqlen_q, -1)
    )

    _DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
    _DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed)

    def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
        """Get cuda rng tracker."""
        return _DUMMY_CUDA_RNG_STATE_TRACKER

    block = DotProductAttention(
        config.num_heads,
        config.head_dim_qk,
        attention_dropout=config.dropout_p,
        sequence_parallel=False,
        tp_size=1,
        get_rng_state_tracker=get_dummy_cuda_rng_tracker,
        tp_group=None,
        layer_number=1,
        attention_type="self",
        qkv_format="bshd",
    ).to(dtype=dtype, device="cuda")

    q = inp[:, :, 0, :, :]
    k = inp[:, :, 1, :, :]
    v = inp[:, :, 2, :, :]
    out = block(q, k, v, attn_mask_type=config.attn_mask_type)
    out.backward(out_grad)

    return out, inp.grad


_CUBLASLT_WORKSPACE_SIZE_BYTES = 33_554_432  # 32MiB
_2X_ACC_FPROP = False
_2X_ACC_DGRAD = False
_2X_ACC_WGRAD = False
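# The _2X_ACC_* switches feed the use_split_accumulator argument of the FP8 GEMMs in this custom
# module; split accumulation trades a little speed for extra accumulation precision and is kept
# off here.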

META_QKV = tex.FP8FwdTensors.GEMM1_OUTPUT
META_DQKV = tex.FP8BwdTensors.GRAD_OUTPUT1
META_O = tex.FP8FwdTensors.GEMM2_INPUT
META_DO = tex.FP8BwdTensors.GRAD_INPUT2
META_S = tex.FP8FwdTensors.GEMM3_OUTPUT
META_DP = tex.FP8BwdTensors.GRAD_INPUT3


class _custom_mha_fp8(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        inp: torch.Tensor,
        qkv_weight: torch.Tensor,
        qkv_bias: torch.Tensor,
        cu_seqlens: torch.Tensor,
        num_heads: int,
        p_dropout: float,
        max_s: int,
        fast_zero_fill: bool,
        fp8_meta: Dict[str, Any],
        is_training: bool,
        mask_type: str,
2589
        quantizers: list[Quantizer],
2590
    ) -> torch.Tensor:
        qkv_dtype = inp.dtype

        assert inp.dim() == 2
        in_features = qkv_weight.shape[-1]
        h = num_heads
        d = in_features // h
        b = cu_seqlens.numel() - 1

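        # Per-tensor quantizers from the module's forward/backward scaling
        # containers; each carries the FP8 scale for one tensor role
        # (GEMM inputs/weights/outputs, attention output O, softmax S, and
        # their gradients).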
        input_quantizer = quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT]
        qkv_quantizer = quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_INPUT]
        qkv_weight_quantizer = quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT]
        o_quantizer = quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_OUTPUT]
        dO_quantizer = quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT1]
        dQKV_quantizer = quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_INPUT1]
        s_quantizer = quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT2]
        dP_quantizer = quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT3]

        inp_fp8 = input_quantizer(inp)

        qkv_weight_fp8 = qkv_weight_quantizer(qkv_weight)

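        # QKV projection GEMM in FP8; the output is quantized with qkv_quantizer
        # so it can feed the FP8 fused-attention kernel directly.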
        qkv, *_ = ext.general_gemm(
            qkv_weight_fp8,
            inp_fp8,
            bias=qkv_bias,
            out_dtype=qkv_weight_fp8.dtype,
            quantization_params=qkv_quantizer,
            use_split_accumulator=_2X_ACC_FPROP,
        )
        qkv = qkv.view(-1, 3, h, d)
        qkv_fp16 = qkv.dequantize().view(b, max_s, 3, h, d).contiguous()
        torch.save(qkv_fp16, "qkv.pt")
        if cudnn_frontend_version == 1:
            qkv = qkv.view(b, max_s, 3, h, d)  # bs3hd

        # FMHA
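        # Slice the packed FP8 tensor's raw data into Q/K/V and rebuild FP8
        # tensors via make_like so they share qkv's quantization metadata.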
        q_data = qkv._data[:, :, 0, :, :] if cudnn_frontend_version == 1 else qkv._data[:, 0, :, :]
        k_data = qkv._data[:, :, 1, :, :] if cudnn_frontend_version == 1 else qkv._data[:, 1, :, :]
        v_data = qkv._data[:, :, 2, :, :] if cudnn_frontend_version == 1 else qkv._data[:, 2, :, :]
        q = qkv.make_like(tensor=qkv, data=q_data, shape=q_data.shape)
        k = qkv.make_like(tensor=qkv, data=k_data, shape=k_data.shape)
        v = qkv.make_like(tensor=qkv, data=v_data, shape=v_data.shape)

        # FP8 fused-attention forward. o_quantizer and s_quantizer supply the
        # FP8 scales for the attention output O and the intermediate softmax
        # tensor S; aux_ctx_tensors carries what the backward pass needs.
        out, aux_ctx_tensors = fused_attn_fwd(
            is_training,
            max_s,
            max_s,
            cu_seqlens,
            cu_seqlens,
            q,
            k,
            v,
            qkv_dtype,
            FusedAttnBackend["FP8"],
            attn_scale=None,
            dropout=p_dropout,
            fast_zero_fill=fast_zero_fill,
            qkv_layout="bs3hd" if cudnn_frontend_version == 1 else "t3hd",
            attn_bias_type="no_bias",
            attn_mask_type=mask_type if cudnn_frontend_version == 1 else "padding",
            rng_gen=None,
            o_quantizer=o_quantizer,
            s_quantizer=s_quantizer,
        )

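        # prepare_for_saving splits quantized tensors into raw tensors that
        # save_for_backward can hold plus the Python-side objects needed by
        # restore_from_saved in backward.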
        tensors_to_save, tensor_objects = prepare_for_saving(q, k, v, inp_fp8, qkv_weight_fp8, out)

        ctx.save_for_backward(*tensors_to_save)
        ctx.tensor_objects = tensor_objects
        ctx.aux_ctx_tensors = aux_ctx_tensors
        ctx.qkv_dtype = qkv_dtype
        ctx.fp8_meta = fp8_meta
        ctx.cu_seqlens = cu_seqlens
        ctx.p_dropout = p_dropout
        ctx.max_s = max_s
        ctx.fast_zero_fill = fast_zero_fill
        ctx.hidden_size = in_features
        ctx.num_heads = num_heads
        ctx.mask_type = mask_type
        ctx.dtype = inp.dtype

        ctx.dQKV_quantizer = dQKV_quantizer
        ctx.dO_quantizer = dO_quantizer
        ctx.dP_quantizer = dP_quantizer
        ctx.S_quantizer = s_quantizer

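        # Hand the caller a dequantized (F16) output and save it for the
        # comparison against the F16 reference run.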
        out = out.view(-1, in_features)  # (bs)(hd)
        out_fp16 = out.dequantize()
        torch.save(out_fp16, "out.pt")  # (bs)(hd)
        return out_fp16

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
        with torch.cuda.nvtx.range("_DPA"):
            saved_tensors = ctx.saved_tensors
            (q, k, v, inp_fp8, qkv_weight_fp8, out) = restore_from_saved(
                ctx.tensor_objects, saved_tensors
            )

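            # Quantize the incoming F16 gradient to FP8 for the fused backward.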
            proj_dgrad = ctx.dO_quantizer(grad_output)
            fp8_dtype_backward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=False)

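            # FP8 fused-attention backward: produces FP8 dQ/dK/dV from the
            # saved forward tensors and the quantized upstream gradient.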
            dq, dk, dv, *rest = fused_attn_bwd(
                ctx.max_s,
                ctx.max_s,
                ctx.cu_seqlens,
                ctx.cu_seqlens,
                q,
                k,
                v,
                out,
                proj_dgrad.view_as(out),
                ctx.qkv_dtype,
                fp8_dtype_backward,
                ctx.aux_ctx_tensors,
                FusedAttnBackend["FP8"],
                None,
                None,
                ctx.S_quantizer,
                ctx.dP_quantizer,
                ctx.dQKV_quantizer,
                attn_scale=None,
                dropout=ctx.p_dropout,
                fast_zero_fill=ctx.fast_zero_fill,
                qkv_layout="bs3hd" if cudnn_frontend_version == 1 else "t3hd",
                attn_bias_type="no_bias",
                attn_mask_type=ctx.mask_type if cudnn_frontend_version == 1 else "padding",
            )
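            # dq/dk/dv are interleaved slices of one allocation, so a packed
            # QKV gradient can be rebuilt by re-inserting the "3" dimension
            # into dq's shape and strides over the same storage.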
            dim = 2 if cudnn_frontend_version == 1 else 1
            dqkv = torch.Tensor().to(device=dq._data.device, dtype=dq._data.dtype)
            dqkv_shape = list(dq._data.shape)
            dqkv_shape.insert(dim, 3)
            dqkv_stride = list(dq._data.stride())
            dqkv_stride.insert(dim, int(dqkv_stride[-3] / 3))
            dqkv.set_(
                dq._data.untyped_storage(), dq._data.storage_offset(), dqkv_shape, dqkv_stride
            )  # bs3hd

            dqkv_c = dqkv.view(-1, 3 * ctx.hidden_size)
            dqkv_c = dq.make_like(tensor=dq, data=dqkv_c, shape=dqkv_c.shape)
            dqkv_c_fp16 = dqkv_c.dequantize()
            torch.save(dqkv_c_fp16, "dqkv.pt")

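            # Fuse the bias-gradient reduction with re-quantization of dQKV,
            # then rebuild dqkv_c's FP8 transpose data for the GEMMs below.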
            qkv_bgrad, dqkv = ext.bgrad_quantize(dqkv_c_fp16, ctx.dQKV_quantizer)
            dqkv_c._transpose = None
            dqkv_c._create_transpose()

            # QKV DGRAD
            qkv_dgrad, *_ = ext.general_gemm(
                qkv_weight_fp8,
                dqkv_c,
                ctx.dtype,
                use_split_accumulator=_2X_ACC_DGRAD,
                layout="NN",
            )

            # QKV WGRAD
            qkv_wgrad, *_ = ext.general_gemm(
                inp_fp8,
                dqkv,
                ctx.dtype,
                use_split_accumulator=_2X_ACC_WGRAD,
                layout="NT",
            )

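        # One gradient per forward input: dgrad for inp, wgrad/bgrad for the
        # QKV weight and bias, and None for the non-tensor arguments.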
        return (
            qkv_dgrad,
            qkv_wgrad,
            qkv_bgrad,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
        )


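# Thin TransformerEngineBaseModule wrapper that owns the fused QKV weight/bias
# and the FP8 metadata, and dispatches to the custom autograd function above.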
class Custom_MHA_FP8(TransformerEngineBaseModule):
    def __init__(self, config, params_dtype: torch.dtype = torch.float32):
        super().__init__()
        self.p_dropout = config.dropout_p
        self.h = config.num_heads
        self.hidden_size = config.hidden_size
        self.head_dim = config.head_dim_qk
        self.fast_zero_fill = True
        self.mask_type = config.attn_mask_type

        self.qkv_weight = torch.nn.Parameter(
            torch.empty(
                self.hidden_size * 3,
                self.hidden_size,
                device=torch.cuda.current_device(),
                dtype=params_dtype,
            )
        )
        self.qkv_bias = torch.nn.Parameter(
            torch.empty(
                self.hidden_size * 3,
                device=torch.cuda.current_device(),
                dtype=params_dtype,
            )
        )
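        # Deterministic, simple initialization: unit weights and zero bias.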
        with torch.no_grad():
            self.qkv_bias.zero_()
            self.qkv_weight.fill_(1.0)

    def forward(
        self,
        inp: torch.Tensor,
        cu_seqlens,
        max_s,
    ) -> torch.Tensor:
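        # prepare_forward_ctx sets up the module's FP8 state (scales/amax
        # history) for its GEMMs before invoking the custom autograd function.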
        with self.prepare_forward_ctx(inp, num_gemms=3) as inp:
            out = _custom_mha_fp8.apply(
                inp,
                self.qkv_weight,
                self.qkv_bias,
                cu_seqlens,
                self.h,
                self.p_dropout,
                max_s,
                self.fast_zero_fill,
                self.fp8_meta,
                self.training,
                self.mask_type,
                self.quantizers,
            )
        return out