# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import logging
import os
import sys
import pathlib
from typing import Any, Dict, Tuple, Union

import pytest
import torch

from transformer_engine.pytorch.quantization import FP8GlobalStateManager, get_fp8_te_dtype
from transformer_engine.common import recipe
from transformer_engine.pytorch import (
    TransformerLayer,
    autocast,
    quantized_model_init,
    DotProductAttention,
    MultiheadAttention,
    get_device_compute_capability,
    Quantizer,
    is_fp8_available,
    is_bf16_available,
)
from transformer_engine.pytorch.attention.dot_product_attention import (
    _attention_backends,
)
from transformer_engine.pytorch.attention.dot_product_attention.utils import (
    FlashAttentionUtils,
    check_set_window_size,
)
from transformer_engine.pytorch.attention import RotaryPositionEmbedding
import transformer_engine.pytorch.cpp_extensions as ext
from transformer_engine.pytorch.cpp_extensions.fused_attn import (
    FusedAttnBackend,
    fused_attn_bwd,
    fused_attn_fwd,
)
from transformer_engine.pytorch.distributed import CudaRNGStatesTracker
from transformer_engine.pytorch.module.base import TransformerEngineBaseModule
from transformer_engine.pytorch.utils import (
    init_method_normal,
    scaled_init_method_normal,
)
from transformer_engine.pytorch.utils import get_cudnn_version
import transformer_engine_torch as tex
from transformer_engine.pytorch.quantized_tensor import (
    Quantizer,
    prepare_for_saving,
    restore_from_saved,
)

_current_file = pathlib.Path(__file__).resolve()
sys.path.append(str(_current_file.parent.parent))
from utils import (
    reset_rng_states,
    compare_and_assert,
    ModelConfig,
    dtype_tols,
    get_available_attention_backends,
)

# Check if hardware supports FP8 attention.
fp8_available, reason_for_no_fp8 = is_fp8_available(return_reason=True)
fp8_attn_available, reason_for_no_fp8_attn = fp8_available, reason_for_no_fp8
device_compute_capability = get_device_compute_capability()
if fp8_available and (device_compute_capability < (9, 0) or device_compute_capability >= (12, 0)):
    fp8_attn_available = False
    reason_for_no_fp8_attn = (
        "FP8 attention is not supported for compute capability ="
        f" sm{device_compute_capability[0] * 10 + device_compute_capability[1]}"
    )

# Get determinism
_deterministic = (
    not bool(int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1")))
    or torch.are_deterministic_algorithms_enabled()
)

# Reset RNG seed and states
seed = 1234
reset_rng_states()


# Reset FP8 global state manager
@pytest.fixture(autouse=True)
def reset_global_fp8_state():
    yield
    FP8GlobalStateManager.reset()


# Define F16 data types to test
param_types = [torch.float16]
if is_bf16_available():
    param_types.append(torch.bfloat16)
param_types_lean = [torch.bfloat16]

model_configs_base = {
    # test: ModelConfig(b, sq, hq, dqk)
    "base_1_0": ModelConfig(8, 128, 16, 64),
    "base_1_1": ModelConfig(4, 128, 16, 64, max_seqlen_kv=256),
    "base_2_0": ModelConfig(2, 2048, 24, 128),
    "base_2_1": ModelConfig(1, 2048, 24, 128, max_seqlen_kv=4096),
    "base_3_0": ModelConfig(8, 1, 16, 128, max_seqlen_kv=2048),
    "base_3_1": ModelConfig(8, 1, 16, 256, max_seqlen_kv=2048),
    "base_4_0": ModelConfig(8, 1, 16, 192, max_seqlen_kv=2048),
    "base_4_1": ModelConfig(8, 128, 16, 192, max_seqlen_kv=2048),
    "base_5_0": ModelConfig(8, 1, 16, 512, max_seqlen_kv=2048),
    "base_5_1": ModelConfig(8, 128, 16, 512, max_seqlen_kv=2048),
    "base_6_0": ModelConfig(8, 1, 16, 1024, max_seqlen_kv=2048),
    "base_6_1": ModelConfig(8, 128, 16, 1024, max_seqlen_kv=2048),
}
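# NOTE: the positional arguments in the configs above follow the shorthand in the comment,
# i.e. ModelConfig(batch_size, max_seqlen_q, num_heads, head_dim_qk); everything else is
# passed by keyword. ModelConfig itself is defined in the shared utils module imported above.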

@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model_configs", [model_configs_base])
@pytest.mark.parametrize("model", model_configs_base.keys())
@pytest.mark.parametrize("ckpt_attn", [False])
@pytest.mark.parametrize("workspace_opt", [True, False])
@pytest.mark.parametrize("qkv_layout", [None])
@pytest.mark.parametrize("swa", [False])
@pytest.mark.parametrize("pad_between_seqs", [False])
def test_dot_product_attention(
    dtype,
    model_configs,
    model,
    ckpt_attn,
    workspace_opt,
    qkv_layout,
    swa,
    pad_between_seqs,
):
    """Test DotProductAttention module"""

    # Get configs
    tols = dict(atol=1e-3, rtol=1e-3)
    if dtype == torch.bfloat16:
        tols = dict(atol=1.5e-2, rtol=1.5e-2)
    config = model_configs[model]
    is_mla = config.head_dim_qk != config.head_dim_v
    is_mqa_gqa = config.num_heads != config.num_gqa_groups
    if qkv_layout is None:
        if config.attn_type == "self":
            qkv_layout = "sb3hd" if not is_mla and not is_mqa_gqa else "sbhd_sbhd_sbhd"
        else:
            qkv_layout = "bshd_bs2hd" if not is_mla and not is_mqa_gqa else "bshd_bshd_bshd"
    if "3" in qkv_layout and config.attn_type == "cross":
        pytest.skip("No need to test this layout for cross attention")

    if config.window_size == (-1, -1) and swa:
        config.window_size = [2, 2]
    config.window_size = check_set_window_size(config.attn_mask_type, config.window_size)

    qkv_format = qkv_layout.replace("3", "").replace("2", "").split("_")[0]
    if qkv_format == "thd" and "padding" not in config.attn_mask_type:
        config.attn_mask_type = (
            "padding_" + config.attn_mask_type if config.attn_mask_type != "no_mask" else "padding"
        )
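    # qkv_format is the per-tensor memory format ("sbhd", "bshd" or "thd") obtained by
    # stripping the packing digits from the layout string, e.g. "sb3hd" -> "sbhd".
    # thd is a packed variable-length format, so the mask type is upgraded above to a
    # padding variant whenever a thd layout is requested without one.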
    # Get backends
    is_training = True
    available_backends, _, fused_attn_backends = get_available_attention_backends(
        config,
        qkv_dtype=dtype,
        qkv_layout=qkv_layout,
        pad_between_seqs=pad_between_seqs,
        is_training=is_training,
        deterministic=_deterministic,
    )
    flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends

    if not fused_attn_supported:
        is_training = False
        available_backends, _, fused_attn_backends = get_available_attention_backends(
            config,
            qkv_dtype=dtype,
            qkv_layout=qkv_layout,
            pad_between_seqs=pad_between_seqs,
            is_training=is_training,
            deterministic=_deterministic,
        )
        flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends

    # FlashAttention does not support pad_between_seqs, but _run_dot_product_attention
    # manually pads and unpads the input and output of FlashAttention for testing purposes
    if (
        pad_between_seqs
        and FlashAttentionUtils.is_installed
        and not (
            config.max_seqlen_q != config.max_seqlen_kv
            and config.attn_mask_type in ["causal", "padding_causal"]
        )
        and (config.window_size[0] == -1 or FlashAttentionUtils.v2_3_plus)
    ):
        flash_attn_supported = True

    # Skip if only unfused backend is supported
    if (len(fused_attn_backends) + flash_attn_supported + unfused_attn_supported) < 2:
        pytest.skip("Less than two backends to compare.")
    # UnfusedDotProductAttention backend
    if unfused_attn_supported:
        unfused_attn_fwd, unfused_max_logit, unfused_attn_bwd = _run_dot_product_attention(
            dtype,
            config,
            "UnfusedDotProductAttention",
            ckpt_attn,
            qkv_layout,
            workspace_opt,
            pad_between_seqs,
            is_training,
        )

    # FusedAttention backend
    if fused_attn_supported:
        if len(fused_attn_backends) == 1:
            fused_attn_fwd, fused_max_logit, fused_attn_bwd = _run_dot_product_attention(
                dtype,
                config,
                "FusedAttention",
                ckpt_attn,
                qkv_layout,
                workspace_opt,
                pad_between_seqs,
                is_training,
            )
        if len(fused_attn_backends) == 2:
            os.environ["NVTE_FUSED_ATTN_BACKEND"] = "0"
            fused_attn_fwd, _, fused_attn_bwd = _run_dot_product_attention(
                dtype,
                config,
                "FusedAttention",
                ckpt_attn,
                qkv_layout,
                workspace_opt,
                pad_between_seqs,
                is_training,
            )
            os.environ["NVTE_FUSED_ATTN_BACKEND"] = "1"
            fused_attn_fwd_1, _, fused_attn_bwd_1 = _run_dot_product_attention(
                dtype,
                config,
                "FusedAttention",
                ckpt_attn,
                qkv_layout,
                workspace_opt,
                pad_between_seqs,
                is_training,
            )

    # FlashAttention backend
    if flash_attn_supported:
        flash_attn_fwd, _, flash_attn_bwd = _run_dot_product_attention(
            dtype,
            config,
            "FlashAttention",
            ckpt_attn,
            qkv_layout,
            workspace_opt,
            pad_between_seqs,
            is_training,
        )
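    # Each *_attn_fwd above is the attention output, *_max_logit is only meaningful when
    # config.return_max_logit is set, and each *_attn_bwd is the tuple
    # (dq, dk, dv, d_softmax_offset) returned by _run_dot_product_attention.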
    # Compare results
    logging.info(f"[test_dot_product_attention]: is_training = {is_training}")
    if unfused_attn_supported and flash_attn_supported:
        logging.info("[test_dot_product_attention]: unfused attn vs flash attn")
        torch.testing.assert_close(flash_attn_fwd, unfused_attn_fwd, **tols)
        for i, _ in enumerate(flash_attn_bwd):
            torch.testing.assert_close(unfused_attn_bwd[i], flash_attn_bwd[i], **tols)
    if unfused_attn_supported and fused_attn_supported:
        logging.info("[test_dot_product_attention]: unfused attn vs fused attn")
        torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, **tols)
        if config.return_max_logit:
            torch.testing.assert_close(fused_max_logit, unfused_max_logit, **tols)
        for i, _ in enumerate(unfused_attn_bwd):
            torch.testing.assert_close(fused_attn_bwd[i], unfused_attn_bwd[i], **tols)
    if fused_attn_supported and flash_attn_supported:
        logging.info("[test_dot_product_attention]: fused attn vs flash attn")
        torch.testing.assert_close(fused_attn_fwd, flash_attn_fwd, **tols)
        for i, _ in enumerate(flash_attn_bwd):
            torch.testing.assert_close(fused_attn_bwd[i], flash_attn_bwd[i], **tols)
    if fused_attn_supported and len(fused_attn_backends) == 2:
        logging.info("[test_dot_product_attention]: fused attn backend 0 vs 1")
        torch.testing.assert_close(fused_attn_fwd, fused_attn_fwd_1, **tols)
        for i, _ in enumerate(fused_attn_bwd):
            torch.testing.assert_close(fused_attn_bwd[i], fused_attn_bwd_1[i], **tols)


@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model_configs", [model_configs_base])
@pytest.mark.parametrize("model", ["base_1_1", "base_2_1"])
def test_dpa_checkpoint(dtype, model_configs, model):
    """Test DotProductAttention module with checkpointing"""
    test_dot_product_attention(dtype, model_configs, model, True, True, None, False, False)


model_configs_max_logit = {
    # test: ModelConfig(b, sq, hq, dqk)
    "max_logit_1": ModelConfig(1, 2048, 24, 128, max_seqlen_kv=4096),
    "max_logit_2": ModelConfig(2, 2048, 24, 128, attn_mask_type="causal"),
    "max_logit_3": ModelConfig(2, 1, 16, 128, max_seqlen_kv=2048, attn_mask_type="padding_causal"),
    "max_logit_4": ModelConfig(
        8, 128, 16, 192, max_seqlen_kv=2048, attn_bias_type="post_scale_bias"
    ),
    "max_logit_5": ModelConfig(
        8, 128, 16, 512, max_seqlen_kv=2048, attn_mask_type="causal", window_size=(20, 0)
    ),
    "max_logit_6": ModelConfig(8, 1, 16, 1024, max_seqlen_kv=2048),
}


@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model_configs", [model_configs_max_logit])
@pytest.mark.parametrize("model", model_configs_max_logit.keys())
@pytest.mark.parametrize("qkv_layout", ["sbhd_sbhd_sbhd", "thd_thd_thd"])
def test_dpa_max_logit(dtype, model_configs, model, qkv_layout):
    """Test DotProductAttention module with checkpointing"""
    config = model_configs[model]
    config.return_max_logit = True
    test_dot_product_attention(dtype, model_configs, model, False, True, qkv_layout, False, False)


model_configs_num_splits = {
    # test: ModelConfig(b, sq, hq, dqk)
    "num_splits_1_0": ModelConfig(2, 2048, 24, 128, num_splits=2),
    "num_splits_1_1": ModelConfig(1, 2048, 24, 128, max_seqlen_kv=4096, num_splits=4),
}


@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model_configs", [model_configs_num_splits])
@pytest.mark.parametrize("model", model_configs_num_splits.keys())
def test_dpa_num_splits(dtype, model_configs, model):
    """Test DotProductAttention with FlashAttention-3 num_splits enabled"""
    test_dot_product_attention(
        dtype,
        model_configs,
        model,
        False,
        True,
        None,
        False,
        False,
    )


model_configs_softmax = {
    # test: ModelConfig(b, sq, hq, dqk)
    "softmax_1_0": ModelConfig(2, 2048, 64, 64, num_gqa_groups=8),
    "softmax_1_1": ModelConfig(2, 2048, 64, 64, num_gqa_groups=8, softmax_type="off-by-one"),
    "softmax_1_2": ModelConfig(2, 2048, 64, 64, num_gqa_groups=8, softmax_type="learnable"),
    "softmax_2_0": ModelConfig(2, 2048, 64, 64, num_gqa_groups=8, attn_mask_type="causal"),
    "softmax_2_1": ModelConfig(
        2, 2048, 64, 64, num_gqa_groups=8, attn_mask_type="causal", softmax_type="off-by-one"
    ),
    "softmax_2_2": ModelConfig(
        2, 2048, 64, 64, num_gqa_groups=8, attn_mask_type="causal", softmax_type="learnable"
    ),
    "softmax_3_0": ModelConfig(2, 2048, 64, 64, num_gqa_groups=8, attn_mask_type="padding"),
    "softmax_3_1": ModelConfig(
        2, 2048, 64, 64, num_gqa_groups=8, attn_mask_type="padding", softmax_type="off-by-one"
    ),
    "softmax_3_2": ModelConfig(
        2, 2048, 64, 64, num_gqa_groups=8, attn_mask_type="padding", softmax_type="learnable"
    ),
    "softmax_4_0": ModelConfig(
        2, 2048, 64, 64, num_gqa_groups=8, window_size=(128, 0), attn_mask_type="causal"
    ),
    "softmax_4_1": ModelConfig(
        2,
        2048,
        64,
        64,
        num_gqa_groups=8,
        window_size=(128, 0),
        attn_mask_type="causal",
        softmax_type="off-by-one",
    ),
    "softmax_4_2": ModelConfig(
        2,
        2048,
        64,
        64,
        num_gqa_groups=8,
        window_size=(128, 0),
        attn_mask_type="causal",
        softmax_type="learnable",
    ),
    "softmax_5_0": ModelConfig(
        2, 2048, 64, 64, num_gqa_groups=8, window_size=(128, 0), attn_mask_type="padding_causal"
    ),
    "softmax_5_1": ModelConfig(
        2,
        2048,
        64,
        64,
        num_gqa_groups=8,
        window_size=(128, 0),
        attn_mask_type="padding_causal",
        softmax_type="off-by-one",
    ),
    "softmax_5_2": ModelConfig(
        2,
        2048,
        64,
        64,
        num_gqa_groups=8,
        window_size=(128, 0),
        attn_mask_type="padding_causal",
        softmax_type="learnable",
    ),
}
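# Softmax variants exercised above: "vanilla" is the standard softmax, while "off-by-one" and
# "learnable" are, loosely, sink-style variants that add an offset to the softmax normalization.
# What this test relies on is only that the non-vanilla types expose a softmax_offset
# parameter, whose gradient is compared across backends in _run_dot_product_attention.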


@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("model_configs", [model_configs_softmax])
@pytest.mark.parametrize("model", model_configs_softmax.keys())
def test_dpa_softmax(dtype, model_configs, model):
    """Test DotProductAttention module with different softmax types"""
    test_dot_product_attention(
        dtype, model_configs, model, True, True, "bshd_bshd_bshd", False, False
    )


@pytest.mark.skipif(get_cudnn_version() < (9, 18, 0), reason="cuDNN 9.18.0+ is required.")
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("model_configs", [model_configs_softmax])
@pytest.mark.parametrize("model", model_configs_softmax.keys())
def test_dpa_softmax_thd(dtype, model_configs, model):
    """Test DotProductAttention module with different softmax types"""
    test_dot_product_attention(dtype, model_configs, model, True, True, "thd_thd_thd", False, False)


model_configs_mla = {
    # test: ModelConfig(b, sq, hq, dqk)
    "mla_1_0": ModelConfig(8, 128, 16, 64, head_dim_v=128),
    "mla_1_1": ModelConfig(4, 128, 16, 64, max_seqlen_kv=256, head_dim_v=128),
    "mla_1_2": ModelConfig(4, 128, 16, 192, max_seqlen_kv=256, head_dim_v=128),
    "mla_2_0": ModelConfig(2, 2048, 24, 128, attn_mask_type="causal", head_dim_v=64),
    "mla_2_1": ModelConfig(
        1, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal", head_dim_v=64
    ),
    "mla_2_2": ModelConfig(
        1, 2048, 24, 192, max_seqlen_kv=4096, attn_mask_type="causal", head_dim_v=128
    ),
    "mla_3_0": ModelConfig(8, 1, 16, 128, max_seqlen_kv=2048, head_dim_v=64),
    "mla_3_1": ModelConfig(8, 1, 16, 256, max_seqlen_kv=2048, head_dim_v=128),
    "mla_3_2": ModelConfig(8, 1, 16, 192, max_seqlen_kv=2048, head_dim_v=128),
    "mla_3_3": ModelConfig(8, 1, 16, 160, max_seqlen_kv=2048, head_dim_v=128),
    "mla_3_4": ModelConfig(8, 1, 16, 160, max_seqlen_kv=2048, head_dim_v=160),
}


@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model_configs", [model_configs_mla])
@pytest.mark.parametrize("model", model_configs_mla.keys())
def test_dpa_mla(dtype, model_configs, model):
    """Test DotProductAttention module with Multi-Latent Attention (MLA)"""
    test_dot_product_attention(dtype, model_configs, model, True, True, None, False, False)


model_configs_mask = {
    # test: ModelConfig(b, sq, hq, dqk)
    "mask_1_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="causal"),
    "mask_1_1": ModelConfig(2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="causal"),
    "mask_1_2": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal"),
    "mask_2_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="causal_bottom_right"),
    "mask_2_1": ModelConfig(
        2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="causal_bottom_right"
    ),
    "mask_2_2": ModelConfig(
        2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal_bottom_right"
    ),
    "mask_3_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding"),
    "mask_3_1": ModelConfig(2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding"),
    "mask_3_2": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding"),
    "mask_4_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding_causal"),
    "mask_4_1": ModelConfig(2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding_causal"),
    "mask_4_2": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding_causal"),
    "mask_5_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding_causal_bottom_right"),
    "mask_5_1": ModelConfig(
        2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding_causal_bottom_right"
    ),
    "mask_5_2": ModelConfig(
        2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding_causal_bottom_right"
    ),
    "mask_6_0": ModelConfig(2, 1, 16, 128, max_seqlen_kv=2048, attn_mask_type="causal"),
    "mask_6_1": ModelConfig(2, 1, 16, 256, max_seqlen_kv=2048, attn_mask_type="causal"),
    "mask_7_0": ModelConfig(
        2, 1, 16, 128, max_seqlen_kv=2048, attn_mask_type="causal_bottom_right"
    ),
    "mask_7_1": ModelConfig(
        2, 1, 16, 256, max_seqlen_kv=2048, attn_mask_type="causal_bottom_right"
    ),
    "mask_8_0": ModelConfig(2, 1, 24, 128, max_seqlen_kv=2048, attn_mask_type="padding"),
    "mask_8_1": ModelConfig(2, 1, 16, 256, max_seqlen_kv=2048, attn_mask_type="padding"),
    "mask_9_0": ModelConfig(2, 1, 24, 128, max_seqlen_kv=2048, attn_mask_type="padding_causal"),
    "mask_9_1": ModelConfig(2, 1, 16, 256, max_seqlen_kv=2048, attn_mask_type="padding_causal"),
    "mask_10_0": ModelConfig(
        2, 1, 24, 128, max_seqlen_kv=2048, attn_mask_type="padding_causal_bottom_right"
    ),
    "mask_10_1": ModelConfig(
        2, 1, 16, 256, max_seqlen_kv=2048, attn_mask_type="padding_causal_bottom_right"
    ),
}


@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_mask])
@pytest.mark.parametrize("model", model_configs_mask.keys())
def test_dpa_mask(dtype, model_configs, model):
    """Test DotProductAttention module with different mask types"""
    test_dot_product_attention(dtype, model_configs, model, False, True, None, False, False)

model_configs_bias = {
    # test: ModelConfig(b, sq, hq, dqk)
    "bias_1_0": ModelConfig(4, 128, 16, 64, attn_bias_type="post_scale_bias"),
    "bias_1_1": ModelConfig(2, 128, 16, 64, max_seqlen_kv=256, attn_bias_type="post_scale_bias"),
    "bias_1_2": ModelConfig(4, 2048, 24, 128, attn_bias_type="post_scale_bias"),
    "bias_1_3": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_bias_type="post_scale_bias"),
    "bias_1_4": ModelConfig(4, 2048, 24, 128, attn_bias_type="alibi"),
    "bias_1_5": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_bias_type="alibi"),
    "bias_2_0": ModelConfig(
        4, 128, 16, 64, attn_mask_type="padding", attn_bias_type="post_scale_bias"
    ),
    "bias_2_1": ModelConfig(
        2,
        128,
        16,
        64,
        max_seqlen_kv=256,
        attn_mask_type="padding",
        attn_bias_type="post_scale_bias",
    ),
    "bias_2_2": ModelConfig(
        4, 2048, 24, 128, attn_mask_type="padding", attn_bias_type="post_scale_bias"
    ),
    "bias_2_3": ModelConfig(
        2,
        2048,
        24,
        128,
        max_seqlen_kv=4096,
        attn_mask_type="padding",
        attn_bias_type="post_scale_bias",
    ),
    "bias_2_4": ModelConfig(4, 2048, 24, 128, attn_mask_type="padding", attn_bias_type="alibi"),
    "bias_2_5": ModelConfig(
        2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding", attn_bias_type="alibi"
    ),
    "bias_3_0": ModelConfig(
        4, 128, 16, 64, attn_mask_type="causal", attn_bias_type="post_scale_bias"
    ),
    "bias_3_1": ModelConfig(
        2, 128, 16, 64, max_seqlen_kv=256, attn_mask_type="causal", attn_bias_type="post_scale_bias"
    ),
    "bias_3_2": ModelConfig(
        4, 2048, 24, 128, attn_mask_type="causal", attn_bias_type="post_scale_bias"
    ),
    "bias_3_3": ModelConfig(
        2,
        2048,
        24,
        128,
        max_seqlen_kv=4096,
        attn_mask_type="causal",
        attn_bias_type="post_scale_bias",
    ),
    "bias_3_4": ModelConfig(4, 2048, 24, 128, attn_mask_type="causal", attn_bias_type="alibi"),
    "bias_3_5": ModelConfig(
        2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal", attn_bias_type="alibi"
    ),
    "bias_4_0": ModelConfig(
        4, 128, 16, 64, attn_mask_type="padding_causal", attn_bias_type="post_scale_bias"
    ),
    "bias_4_1": ModelConfig(
        2,
        128,
        16,
        64,
        max_seqlen_kv=256,
        attn_mask_type="padding_causal",
        attn_bias_type="post_scale_bias",
    ),
    "bias_4_2": ModelConfig(
        4, 2048, 24, 128, attn_mask_type="padding_causal", attn_bias_type="post_scale_bias"
    ),
    "bias_4_3": ModelConfig(
        2,
        2048,
        24,
        128,
        max_seqlen_kv=4096,
        attn_mask_type="padding_causal",
        attn_bias_type="post_scale_bias",
    ),
    "bias_4_4": ModelConfig(
        4, 2048, 24, 128, attn_mask_type="padding_causal", attn_bias_type="alibi"
    ),
    "bias_4_5": ModelConfig(
        2,
        2048,
        24,
        128,
        max_seqlen_kv=4096,
        attn_mask_type="padding_causal",
        attn_bias_type="alibi",
    ),
}

@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_bias])
@pytest.mark.parametrize("model", model_configs_bias.keys())
def test_dpa_bias(dtype, model_configs, model):
    """Test DotProductAttention module with different bias types"""
    test_dot_product_attention(dtype, model_configs, model, False, True, None, False, False)

model_configs_bias_shapes = {
    # test: ModelConfig(b, sq, hq, dqk)
    "bias_1_0": ModelConfig(4, 128, 16, 64, attn_bias_type="post_scale_bias", bias_shape="11ss"),
    "bias_1_1": ModelConfig(2, 128, 16, 64, attn_bias_type="post_scale_bias", bias_shape="1hss"),
    "bias_1_2": ModelConfig(4, 2048, 24, 128, attn_bias_type="post_scale_bias", bias_shape="b1ss"),
    "bias_1_3": ModelConfig(2, 2048, 24, 128, attn_bias_type="post_scale_bias", bias_shape="bhss"),
    "bias_1_4": ModelConfig(
        4,
        2048,
        24,
        128,
        attn_mask_type="causal",
        attn_bias_type="alibi",
        bias_shape="1hss",
        alibi_type="custom",
    ),
    "bias_1_5": ModelConfig(
        2,
        2048,
        24,
        128,
        attn_mask_type="causal",
        attn_bias_type="alibi",
        bias_shape="bhss",
        alibi_type="custom",
    ),
}

@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_bias_shapes])
@pytest.mark.parametrize("model", model_configs_bias_shapes.keys())
def test_dpa_bias_shapes(dtype, model_configs, model):
    """Test DotProductAttention module with different bias types and shapes"""
    test_dot_product_attention(dtype, model_configs, model, False, True, None, False, False)

model_configs_swa = {
    # test: ModelConfig(b, sq, hq, dqk)
    "swa_1_1": ModelConfig(2, 2048, 16, 64),
    "swa_1_2": ModelConfig(2, 2048, 24, 128, num_gqa_groups=4),
    "swa_1_3": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096),
    "swa_2_1": ModelConfig(2, 2048, 16, 64, attn_mask_type="causal"),
    "swa_2_2": ModelConfig(2, 2048, 24, 128, num_gqa_groups=4, attn_mask_type="causal"),
    "swa_2_3": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal"),
    "swa_3_1": ModelConfig(2, 2048, 16, 64, attn_mask_type="causal_bottom_right"),
    "swa_3_2": ModelConfig(
        2, 2048, 24, 128, num_gqa_groups=4, attn_mask_type="causal_bottom_right"
    ),
    "swa_3_3": ModelConfig(
        2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal_bottom_right"
    ),
    "swa_4_1": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding"),
    "swa_4_2": ModelConfig(2, 2048, 24, 128, num_gqa_groups=4, attn_mask_type="padding"),
    "swa_4_3": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding"),
    "swa_5_1": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding_causal"),
    "swa_5_2": ModelConfig(2, 2048, 24, 128, num_gqa_groups=4, attn_mask_type="padding_causal"),
    "swa_5_3": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding_causal"),
    "swa_6_1": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding_causal_bottom_right"),
    "swa_6_2": ModelConfig(
        2, 2048, 24, 128, num_gqa_groups=4, attn_mask_type="padding_causal_bottom_right"
    ),
    "swa_6_3": ModelConfig(
        2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding_causal_bottom_right"
    ),
}
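# window_size is the (left, right) sliding-window extent in tokens around each query position;
# (-1, -1) means no windowing. check_set_window_size() inside test_dot_product_attention adjusts
# the tuple to match the mask type before the backends are compared.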


@pytest.mark.skipif(not FlashAttentionUtils.v2_3_plus, reason="Flash-attn 2.3+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_swa])
@pytest.mark.parametrize("model", model_configs_swa.keys())
@pytest.mark.parametrize("qkv_layout", ["thd_thd_thd", "sbhd_sbhd_sbhd"])
def test_dpa_sliding_window(dtype, model_configs, model, qkv_layout):
    """Test DotProductAttention module with sliding window attention"""
    test_dot_product_attention(dtype, model_configs, model, False, True, qkv_layout, True, False)

model_configs_alibi_slopes = {
    # test: ModelConfig(b, sq, hq, dqk)
    "alibi_1_0": ModelConfig(
        2, 128, 16, 64, attn_mask_type="causal", attn_bias_type="alibi", alibi_type="vanilla"
    ),
    "alibi_1_1": ModelConfig(
        1,
        128,
        16,
        64,
        max_seqlen_kv=256,
        attn_mask_type="causal",
        attn_bias_type="alibi",
        alibi_type="vanilla",
    ),
    "alibi_2_0": ModelConfig(
        2, 1024, 24, 128, attn_mask_type="causal", attn_bias_type="alibi", alibi_type="custom"
    ),
    "alibi_2_1": ModelConfig(
        1,
        1024,
        24,
        128,
        max_seqlen_kv=2048,
        attn_mask_type="causal",
        attn_bias_type="alibi",
        alibi_type="custom",
    ),
}


@pytest.mark.skipif(not FlashAttentionUtils.v2_3_plus, reason="Flash-attn 2.3+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_alibi_slopes])
@pytest.mark.parametrize("model", model_configs_alibi_slopes.keys())
def test_dpa_alibi_slopes(dtype, model_configs, model):
    """Test DotProductAttention module with ALiBi slopes"""
    test_dot_product_attention(dtype, model_configs, model, False, True, None, False, False)

qkv_layouts = [
    "sb3hd",
    "sbh3d",
    "sbhd_sb2hd",
    "sbhd_sbh2d",
    "sbhd_sbhd_sbhd",
    "bs3hd",
    "bsh3d",
    "bshd_bs2hd",
    "bshd_bsh2d",
    "bshd_bshd_bshd",
]
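# Layout naming: "sb3hd" packs q/k/v into a single (s, b, 3, h, d) tensor, "sbhd_sb2hd" keeps q
# separate from a packed (s, b, 2, h, d) kv tensor, and "sbhd_sbhd_sbhd" uses three separate
# tensors; the "bs*" variants are batch-first, and the "t*" (thd) variants further below are
# packed variable-length layouts indexed by cu_seqlens.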

model_configs_layout = {
    # test: ModelConfig(b, sq, hq, dqk)
    "layout_0_0": ModelConfig(2, 128, 16, 64),
    "layout_0_1": ModelConfig(
        2, 128, 16, 64, attn_mask_type="causal", attn_bias_type="post_scale_bias"
    ),
    "layout_0_2": ModelConfig(1, 128, 16, 64, max_seqlen_kv=256, attn_mask_type="padding"),
    "layout_0_3": ModelConfig(
        1,
        128,
        16,
        64,
        max_seqlen_kv=256,
        attn_mask_type="padding_causal",
        attn_bias_type="post_scale_bias",
    ),
    "layout_1_0": ModelConfig(2, 2048, 24, 128),
    "layout_1_1": ModelConfig(
        2, 2048, 24, 128, attn_mask_type="causal", attn_bias_type="post_scale_bias"
    ),
    "layout_1_2": ModelConfig(1, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding"),
    "layout_1_3": ModelConfig(
        1,
        2048,
        24,
        128,
        max_seqlen_kv=4096,
        attn_mask_type="padding_causal",
        attn_bias_type="post_scale_bias",
    ),
    "layout_2_0": ModelConfig(2, 1, 16, 256, max_seqlen_kv=2048),
    "layout_2_1": ModelConfig(
        2, 2048, 24, 256, attn_mask_type="causal", attn_bias_type="post_scale_bias"
    ),
}

@pytest.mark.skipif(get_cudnn_version() < (8, 9, 5), reason="cuDNN 8.9.5+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_layout])
@pytest.mark.parametrize("model", model_configs_layout.keys())
@pytest.mark.parametrize("qkv_layout", qkv_layouts)
def test_dpa_qkv_layout(dtype, model_configs, model, qkv_layout):
    """Test DotProductAttention module with different QKV layouts"""
    test_dot_product_attention(dtype, model_configs, model, False, True, qkv_layout, False, False)


qkv_layouts_thd = ["t3hd", "th3d", "thd_t2hd", "thd_th2d", "thd_thd_thd"]
model_configs_layout_thd = {
    # test: ModelConfig(b, sq, hq, dqk)
    "layout_0_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding"),
    "layout_0_1": ModelConfig(2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding"),
    "layout_0_2": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding"),
    "layout_1_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding_causal"),
    "layout_1_1": ModelConfig(2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding_causal"),
    "layout_1_2": ModelConfig(
        2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding_causal"
    ),
    "layout_2_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding_causal_bottom_right"),
    "layout_2_1": ModelConfig(
        2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding_causal_bottom_right"
    ),
    "layout_2_2": ModelConfig(
        2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding_causal_bottom_right"
    ),
    "layout_3_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding", window_size=(4, 4)),
    "layout_3_1": ModelConfig(
        2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding", window_size=(4, 4)
    ),
    "layout_3_2": ModelConfig(
        2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding", window_size=(4, 4)
    ),
    "layout_4_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding_causal", window_size=(4, 0)),
    "layout_4_1": ModelConfig(
        2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding_causal", window_size=(4, 0)
    ),
    "layout_4_2": ModelConfig(
        2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding_causal", window_size=(4, 0)
    ),
    "layout_5_0": ModelConfig(
        2, 2048, 16, 64, attn_mask_type="padding_causal_bottom_right", window_size=(4, 0)
    ),
    "layout_5_1": ModelConfig(
        2,
        2048,
        24,
        128,
        num_gqa_groups=1,
        attn_mask_type="padding_causal_bottom_right",
        window_size=(4, 0),
    ),
    "layout_5_2": ModelConfig(
        2,
        2048,
        24,
        128,
        max_seqlen_kv=4096,
        attn_mask_type="padding_causal_bottom_right",
        window_size=(4, 0),
    ),
}


@pytest.mark.skipif(get_cudnn_version() < (9, 0, 0), reason="cuDNN 9.0.0+ is required.")
@pytest.mark.skipif(
    get_device_compute_capability() < (9, 0), reason="THD is only supported on Hopper+."
)
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_layout_thd])
@pytest.mark.parametrize("model", model_configs_layout_thd.keys())
@pytest.mark.parametrize("qkv_layout", qkv_layouts_thd)
def test_dpa_qkv_layout_thd(dtype, model_configs, model, qkv_layout):
    """Test DotProductAttention module with different QKV layouts"""
    config = model_configs[model]
    if config.num_heads != config.num_gqa_groups and "3" in qkv_layout:
        pytest.skip("qkv_layout not applicable for MQA/GQA")
    logging.info("[test_dpa_qkv_layout_thd]: pad_between_seqs = True")
    pad_between_seqs = True
    test_dot_product_attention(
        dtype, model_configs, model, False, True, qkv_layout, False, pad_between_seqs
    )
    if get_cudnn_version() >= (9, 3, 0):
        logging.info("[test_dpa_qkv_layout_thd]: pad_between_seqs = False")
        # cuDNN 9.3.0+ is required to run pad_between_seqs = False/True in the same run
        pad_between_seqs = False
        test_dot_product_attention(
            dtype, model_configs, model, False, True, qkv_layout, False, pad_between_seqs
        )

def _run_dot_product_attention(
    dtype: torch.dtype,
    config: ModelConfig,
    backend: str,
    ckpt_attn: bool,
    qkv_layout: str,
    workspace_opt: bool,
    pad_between_seqs: bool,
    is_training: bool,
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
    """Run DotProductAttention module with one forward pass and one backward pass"""
    # Set RNG and environment variables
    reset_rng_states()
    os.environ["NVTE_FLASH_ATTN"] = "0"
    os.environ["NVTE_FUSED_ATTN"] = "0"
    os.environ["NVTE_UNFUSED_ATTN"] = "0"
    if backend == "FlashAttention":
        os.environ["NVTE_FLASH_ATTN"] = "1"
    if backend == "FusedAttention":
        os.environ["NVTE_FUSED_ATTN"] = "1"
        os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] = "1" if workspace_opt else "0"
    if backend == "UnfusedDotProductAttention":
        os.environ["NVTE_UNFUSED_ATTN"] = "1"
    _attention_backends["backend_selection_requires_update"] = True
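    # Exactly one of the NVTE_*_ATTN flags above is set to "1", which pins backend selection
    # to the backend under test; setting backend_selection_requires_update forces
    # DotProductAttention to redo its backend selection instead of reusing a cached choice.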
    # Create seqlens
    qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()])
    if "padding" in config.attn_mask_type or qkv_format == "thd":
        if config.attn_type == "self":
            seqlens_q = torch.randint(
                1, config.max_seqlen_q, [config.batch_size], dtype=torch.int32, device="cuda"
            )
            seqlens_kv = seqlens_q
        if config.attn_type == "cross":
            if config.max_seqlen_q > 1:
                seqlens_q = torch.randint(
                    1, config.max_seqlen_q, [config.batch_size], dtype=torch.int32, device="cuda"
                )
            else:
                seqlens_q = torch.ones([config.batch_size], dtype=torch.int32, device="cuda")
            seqlens_kv = torch.randint(
                1, config.max_seqlen_kv, [config.batch_size], dtype=torch.int32, device="cuda"
            )
    else:
        seqlens_q = torch.full(
            [config.batch_size], config.max_seqlen_q, dtype=torch.int32, device="cuda"
        )
        seqlens_kv = torch.full(
            [config.batch_size], config.max_seqlen_kv, dtype=torch.int32, device="cuda"
        )
    cu_seqlens_q = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda")
    cu_seqlens_kv = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda")
    cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)
    cu_seqlens_kv[1:] = torch.cumsum(seqlens_kv, dim=0)
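    # cu_seqlens_* are cumulative sequence lengths with a leading 0: sequence i occupies
    # tokens cu_seqlens[i]:cu_seqlens[i + 1] in the packed representation, which is the
    # form the varlen (thd) code paths consume.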

    seqlens_q_after_pad = seqlens_q.clone()
    seqlens_kv_after_pad = seqlens_kv.clone()
    cu_seqlens_q_after_pad = cu_seqlens_q.clone()
    cu_seqlens_kv_after_pad = cu_seqlens_kv.clone()
    pad_len = [0] * config.batch_size
    if pad_between_seqs:
        max_pad_len = 3
        pad_len = torch.randint(0, max_pad_len + 1, [config.batch_size], device="cuda")
        seqlens_q_after_pad = seqlens_q + pad_len
        seqlens_kv_after_pad = seqlens_kv + pad_len
        cu_seqlens_q_after_pad[1:] = torch.cumsum(seqlens_q_after_pad, dim=0)
        cu_seqlens_kv_after_pad[1:] = torch.cumsum(seqlens_kv_after_pad, dim=0)
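    # With pad_between_seqs, every sequence is followed by up to max_pad_len filler tokens,
    # so the *_after_pad offsets describe the padded physical layout while cu_seqlens_q/kv
    # keep describing only the valid tokens.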

    # Create attention mask if padding
    attention_mask = None
    if "padding" in config.attn_mask_type:
        if config.attn_type == "self":
            attention_mask_q = torch.Tensor([]).to(dtype=torch.bool)
            for i in range(config.batch_size):
                attention_mask_q = torch.cat(
                    [
                        attention_mask_q,
                        torch.Tensor(
                            [False] * seqlens_q[i] + [True] * (config.max_seqlen_q - seqlens_q[i])
                        )
                        .to(dtype=torch.bool)
                        .unsqueeze(0)
                        .unsqueeze(0)
                        .unsqueeze(0),
                    ],
                    dim=0,
                )
            attention_mask = attention_mask_q.to(device="cuda")
        if config.attn_type == "cross":
            attention_mask_q = torch.Tensor([]).to(dtype=torch.bool)
            attention_mask_kv = torch.Tensor([]).to(dtype=torch.bool)
            for i in range(config.batch_size):
                attention_mask_q = torch.cat(
                    [
                        attention_mask_q,
                        torch.Tensor(
                            [False] * seqlens_q[i] + [True] * (config.max_seqlen_q - seqlens_q[i])
                        )
                        .to(dtype=torch.bool)
                        .unsqueeze(0)
                        .unsqueeze(0)
                        .unsqueeze(0),
                    ],
                    dim=0,
                )
                attention_mask_kv = torch.cat(
                    [
                        attention_mask_kv,
                        torch.Tensor(
                            [False] * seqlens_kv[i]
                            + [True] * (config.max_seqlen_kv - seqlens_kv[i])
                        )
                        .to(dtype=torch.bool)
                        .unsqueeze(0)
                        .unsqueeze(0)
                        .unsqueeze(0),
                    ],
                    dim=0,
                )
            attention_mask = (
                attention_mask_q.to(device="cuda"),
                attention_mask_kv.to(device="cuda"),
            )

    alibi_slopes = None
    if config.attn_bias_type == "alibi" and config.alibi_type == "custom":
        if config.bias_shape == "1hss":
            alibi_slopes = (
                torch.randn(config.num_heads).abs().to(dtype=torch.float32, device="cuda")
            )
        if config.bias_shape == "bhss":
            alibi_slopes = (
                torch.randn(config.batch_size, config.num_heads)
                .abs()
                .to(dtype=torch.float32, device="cuda")
            )

    # Create input tensors
    dim_to_num = {
        "b": config.batch_size,
        "sq": config.max_seqlen_q,
        "skv": config.max_seqlen_kv,
        "h": config.num_heads,
        "hg": config.num_gqa_groups,
        "dqk": config.head_dim_qk,
        "dv": config.head_dim_v,
        "t": cu_seqlens_q_after_pad[-1],
        "tg": cu_seqlens_kv_after_pad[-1],
        "3": 3,
        "2": 2,
        "1": 1,
    }
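    # Each layout string is expanded into per-dimension keys and looked up in dim_to_num to
    # build tensor shapes. For example (assuming qkv_layout="sbhd_sb2hd"), the kv tensor's
    # expanded layout is "skv_b_2_hg_dqk", i.e. shape
    # (max_seqlen_kv, batch_size, 2, num_gqa_groups, head_dim_qk).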
    inp = []
    inp_orig = []
    for i, layout in enumerate(qkv_layout.split("_")):
        layout = "_".join(layout)
        if i == 0:
            layout = layout.replace("s", "sq")
        else:
            layout = layout.replace("s", "skv")
            layout = layout.replace("h", "hg")
            layout = layout.replace("t", "tg")
        if i == 2:
            layout = layout.replace("d", "dv")
        else:
            layout = layout.replace("d", "dqk")
        tensor_shape = [dim_to_num[j] for j in layout.split("_")]
        tensor = 0.1 * torch.randn(tensor_shape, dtype=dtype, device="cuda")
        # tensor: with padding tokens
        # tensor_orig: without padding tokens
        tensor_orig = tensor
        if qkv_format == "thd" and pad_between_seqs:
            tensor_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
            if layout in ["t_h_dqk", "t_3_h_dqk", "t_h_3_dqk"]:
                for i in range(1, config.batch_size + 1):
                    valid_range = (
                        cu_seqlens_q_after_pad[i - 1],
                        cu_seqlens_q_after_pad[i] - pad_len[i - 1],
                    )
                    pad_range = (
                        cu_seqlens_q_after_pad[i] - pad_len[i - 1],
                        cu_seqlens_q_after_pad[i],
                    )
                    tensor[pad_range[0] : pad_range[1]] = 0.0
                    tensor_orig = torch.cat(
                        [tensor_orig, tensor[valid_range[0] : valid_range[1]]], dim=0
                    )
            if layout in ["tg_hg_dqk", "tg_2_hg_dqk", "tg_hg_2_dqk", "tg_hg_dv"]:
                for i in range(1, config.batch_size + 1):
                    valid_range = (
                        cu_seqlens_kv_after_pad[i - 1],
                        cu_seqlens_kv_after_pad[i] - pad_len[i - 1],
                    )
                    pad_range = (
                        cu_seqlens_kv_after_pad[i] - pad_len[i - 1],
                        cu_seqlens_kv_after_pad[i],
                    )
                    tensor[pad_range[0] : pad_range[1]] = 0.0
                    tensor_orig = torch.cat(
                        [tensor_orig, tensor[valid_range[0] : valid_range[1]]], dim=0
                    )
        tensor_count = 1
        split_dim = 0
        for dim, l in enumerate(layout.split("_")):
            if l.isdigit():
                tensor_count = int(l)
                split_dim = dim
                break
        tensors = torch.split(tensor, 1, dim=split_dim) if split_dim != 0 else [tensor]
        tensors_orig = (
            torch.split(tensor_orig, 1, dim=split_dim) if split_dim != 0 else [tensor_orig]
        )
        for j in range(tensor_count):
            if split_dim != 0:
                inp.append(tensors[j].squeeze(split_dim))
                inp_orig.append(tensors_orig[j].squeeze(split_dim))
            else:
                inp.append(tensors[j])
                inp_orig.append(tensors_orig[j])
    for i in range(3):
        inp[i].requires_grad = True
        inp_orig[i].requires_grad = True

    # Create output gradient
    qkv_format_kv = "_".join(qkv_format)
    qkv_format_kv = qkv_format_kv.replace("s", "sq")
    qkv_format_kv = qkv_format_kv.replace("d", "dv")
    out_grad_shape = [dim_to_num[i] for i in qkv_format_kv.split("_")]
    out_grad_shape_new = [*out_grad_shape[:-2], out_grad_shape[-2] * out_grad_shape[-1]]
    out_grad = 0.001 * torch.randint(0, 200, out_grad_shape_new, dtype=dtype, device="cuda")
    out_grad_orig = out_grad
    if qkv_format == "thd" and pad_between_seqs:
        out_grad_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
        if qkv_format_kv == "t_h_dv":
            for i in range(1, config.batch_size + 1):
                valid_range = (
                    cu_seqlens_q_after_pad[i - 1],
                    cu_seqlens_q_after_pad[i] - pad_len[i - 1],
                )
                pad_range = (cu_seqlens_q_after_pad[i] - pad_len[i - 1], cu_seqlens_q_after_pad[i])
                out_grad[pad_range[0] : pad_range[1]] = 0.0
                out_grad_orig = torch.cat(
                    [out_grad_orig, out_grad[valid_range[0] : valid_range[1]]], dim=0
                )

    # Create bias
    if config.attn_bias_type in ["no_bias", "alibi"]:
        bias = None
    if config.attn_bias_type == "post_scale_bias":
        shape = "_".join(config.bias_shape)
        shape = shape.replace("_s_s", "_sq_skv")
        tensor_shape = [dim_to_num[j] for j in shape.split("_")]
        bias = torch.randn(tensor_shape, dtype=dtype, device="cuda")
        if config.bias_shape != "1hss":
            bias.requires_grad = False

    # Create RNG
    _DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
    _DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed)

    def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
        """Get cuda rng tracker."""
        return _DUMMY_CUDA_RNG_STATE_TRACKER

    # Set up model
    block = DotProductAttention(
        config.num_heads,
        (config.head_dim_qk, config.head_dim_v),
        num_gqa_groups=config.num_gqa_groups,
        attention_dropout=config.dropout_p,
        qkv_format=qkv_format,
        attn_mask_type=config.attn_mask_type,
        sequence_parallel=False,
        tp_size=1,
        get_rng_state_tracker=get_dummy_cuda_rng_tracker,
        tp_group=None,
        layer_number=1,
        attention_type=config.attn_type,
        softmax_type=config.softmax_type,
        return_max_logit=config.return_max_logit,
    ).to(dtype=dtype, device="cuda")
    if not is_training:
        block = block.eval()
    if is_training and config.softmax_type != "vanilla":
        block.softmax_offset.requires_grad = True
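    # For non-vanilla softmax types the module's softmax_offset parameter is made trainable
    # above so that the backward pass produces d_softmax_offset, which is returned alongside
    # the input gradients below.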
    # Run a forward and backward pass
    if backend in ["FlashAttention", "UnfusedDotProductAttention"]:
        q = inp_orig[0]
        k = inp_orig[1]
        v = inp_orig[2]
        d_out = out_grad_orig
    if backend == "FusedAttention":
        q = inp[0]
        k = inp[1]
        v = inp[2]
        d_out = out_grad
    out = block(
        q,
        k,
        v,
        window_size=config.window_size,
        attention_mask=attention_mask,
        qkv_format=qkv_format,
        max_seqlen_q=config.max_seqlen_q,
        max_seqlen_kv=config.max_seqlen_kv,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_kv=cu_seqlens_kv,
        cu_seqlens_q_padded=cu_seqlens_q_after_pad if backend == "FusedAttention" else None,
        cu_seqlens_kv_padded=cu_seqlens_kv_after_pad if backend == "FusedAttention" else None,
        attn_mask_type=config.attn_mask_type,
        checkpoint_core_attention=ckpt_attn,
        core_attention_bias_type=config.attn_bias_type,
        core_attention_bias=bias,
        alibi_slopes=alibi_slopes,
        fast_zero_fill=True,
        # Only pass num_splits when exercising the FlashAttention path
        num_splits=config.num_splits if backend == "FlashAttention" else 1,
    )
    max_logit = None
    if config.return_max_logit:
        out, max_logit = out
    if is_training:
        out.backward(d_out)

    d_softmax_offset = None
    if is_training and config.softmax_type != "vanilla":
        d_softmax_offset = block.softmax_offset.grad

    if backend in ["FlashAttention", "UnfusedDotProductAttention"]:
        if is_training:
            return out, max_logit, (q.grad, k.grad, v.grad, d_softmax_offset)
        else:
            return out, max_logit, (None, None, None, d_softmax_offset)
    if backend == "FusedAttention":
        if qkv_format == "thd" and pad_between_seqs:
            out_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
            if is_training:
                q_grad_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
                k_grad_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
                v_grad_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
            for i in range(1, config.batch_size + 1):
                valid_range_q = (
                    cu_seqlens_q_after_pad[i - 1],
                    cu_seqlens_q_after_pad[i] - pad_len[i - 1],
                )
                valid_range_kv = (
                    cu_seqlens_kv_after_pad[i - 1],
                    cu_seqlens_kv_after_pad[i] - pad_len[i - 1],
                )
                out_orig = torch.cat([out_orig, out[valid_range_q[0] : valid_range_q[1]]], dim=0)
                if is_training:
                    q_grad_orig = torch.cat(
                        [q_grad_orig, q.grad[valid_range_q[0] : valid_range_q[1]]], dim=0
                    )
                    k_grad_orig = torch.cat(
                        [k_grad_orig, k.grad[valid_range_kv[0] : valid_range_kv[1]]], dim=0
                    )
                    v_grad_orig = torch.cat(
                        [v_grad_orig, v.grad[valid_range_kv[0] : valid_range_kv[1]]], dim=0
                    )
            if is_training:
                return (
                    out_orig,
                    max_logit,
                    (q_grad_orig, k_grad_orig, v_grad_orig, d_softmax_offset),
                )
            else:
                return out_orig, max_logit, (None, None, None, d_softmax_offset)
        else:
            if is_training:
                return out, max_logit, (q.grad, k.grad, v.grad, d_softmax_offset)
            else:
                return out, max_logit, (None, None, None, d_softmax_offset)

1271

1272
model_configs_te_layer = {
1273
    # test: ModelConfig(b, sq, hq, dqk)
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
    "te_1_0": ModelConfig(2, 128, 16, 64, attn_bias_type="post_scale_bias"),
    "te_1_1": ModelConfig(
        4, 128, 16, 64, attn_mask_type="causal", attn_bias_type="post_scale_bias"
    ),
    "te_1_2": ModelConfig(
        2, 128, 16, 64, attn_mask_type="padding", attn_bias_type="post_scale_bias"
    ),
    "te_1_3": ModelConfig(2, 128, 16, 64, max_seqlen_kv=256, attn_mask_type="padding"),
    "te_2_0": ModelConfig(1, 2048, 16, 64, attn_mask_type="causal"),
    "te_2_1": ModelConfig(2, 2048, 16, 64),
    "te_2_2": ModelConfig(1, 2048, 16, 64, attn_mask_type="padding"),
    "te_2_3": ModelConfig(
        1, 2048, 16, 64, max_seqlen_kv=4096, attn_mask_type="padding_causal_bottom_right"
    ),
    "te_3_0": ModelConfig(4, 128, 16, 64, attn_mask_type="causal", attn_bias_type="alibi"),
    "te_3_1": ModelConfig(4, 2048, 16, 64, attn_mask_type="causal", attn_bias_type="alibi"),
}


@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model_configs", [model_configs_te_layer])
@pytest.mark.parametrize("model", model_configs_te_layer.keys())
@pytest.mark.parametrize("ckpt_attn", [False])
@pytest.mark.parametrize("qkv_format", ["sbhd", "bshd", "thd"])
@pytest.mark.parametrize("fused_qkv_params", [False])
@pytest.mark.parametrize("RoPE", [False])
def test_transformer_layer(
    dtype, model_configs, model, ckpt_attn, qkv_format, fused_qkv_params, RoPE
):
    """Test TransformerLayer module"""

    # Get configs
    config = model_configs[model]
    tols = dict(atol=5e-2, rtol=5e-2)
    workspace_opt = True

    # Test backend availability
    is_training = True
    available_backends, _, fused_attn_backends = get_available_attention_backends(
        config,
        qkv_dtype=dtype,
        qkv_layout=(
            qkv_format.replace("hd", "h3d") if fused_qkv_params else qkv_format.replace("hd", "3hd")
        ),
        is_training=is_training,
        deterministic=_deterministic,
    )
    flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
    if not fused_attn_supported:
        is_training = False
        available_backends, _, fused_attn_backends = get_available_attention_backends(
            config,
            qkv_dtype=dtype,
            qkv_layout=(
                qkv_format.replace("hd", "h3d")
                if fused_qkv_params
                else qkv_format.replace("hd", "3hd")
            ),
            is_training=is_training,
            deterministic=_deterministic,
        )
        flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends

    # Skip if only unfused backend is supported
    if (len(fused_attn_backends) + flash_attn_supported + unfused_attn_supported) < 2:
        pytest.skip("Fewer than two backends to compare.")
    # Skip if qkv_format = thd and "padding" not in attn_mask_type
    if qkv_format == "thd" and "padding" not in config.attn_mask_type:
        pytest.skip("THD requires padding mask.")

    # UnfusedDotProductAttention backend
    if unfused_attn_supported:
        unfused_attn_fwd, unfused_attn_bwd = _run_transformer_layer(
            dtype,
            config,
            "UnfusedDotProductAttention",
            ckpt_attn,
            qkv_format,
            workspace_opt,
            fused_qkv_params,
            RoPE,
            is_training,
        )

    # FusedAttention backend
    if fused_attn_supported:
        fused_attn_fwd, fused_attn_bwd = _run_transformer_layer(
            dtype,
            config,
            "FusedAttention",
            ckpt_attn,
            qkv_format,
            workspace_opt,
            fused_qkv_params,
            RoPE,
            is_training,
        )

    # FlashAttention backend
    if flash_attn_supported:
        flash_attn_fwd, flash_attn_bwd = _run_transformer_layer(
            dtype,
            config,
            "FlashAttention",
            ckpt_attn,
            qkv_format,
            workspace_opt,
            fused_qkv_params,
            RoPE,
            is_training,
        )

    logging.info(f"[test_transformer_layer]: is_training = {is_training}")
    if unfused_attn_supported and fused_attn_supported:
        logging.info("[test_transformer_layer]: unfused attn vs fused attn")
        torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, **tols)
        torch.testing.assert_close(fused_attn_bwd, unfused_attn_bwd, **tols)
    if unfused_attn_supported and flash_attn_supported:
        logging.info("[test_transformer_layer]: unfused attn vs flash attn")
        torch.testing.assert_close(flash_attn_fwd, unfused_attn_fwd, **tols)
        torch.testing.assert_close(flash_attn_bwd, unfused_attn_bwd, **tols)
    if fused_attn_supported and flash_attn_supported:
        logging.info("[test_transformer_layer]: fused attn vs flash attn")
        torch.testing.assert_close(fused_attn_fwd, flash_attn_fwd, **tols)
        torch.testing.assert_close(fused_attn_bwd, flash_attn_bwd, **tols)


@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_te_layer])
@pytest.mark.parametrize("model", ["te_1_2", "te_2_0"])
@pytest.mark.parametrize("qkv_format", ["bshd", "sbhd"])
def test_te_layer_misc(dtype, model_configs, model, qkv_format):
    """Test TransformerLayer module with miscellaneous settings"""
    ckpt_attn = True
    fused_qkv_params = True
    RoPE = True
    test_transformer_layer(
        dtype, model_configs, model, ckpt_attn, qkv_format, fused_qkv_params, RoPE
    )


@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_te_layer])
@pytest.mark.parametrize("model", ["te_2_0", "te_2_1", "te_2_2"])
def test_te_layer_mqa_gqa(dtype, model_configs, model):
    """Test TransformerLayer module with MQA/GQA"""

    def find_factors(x):
        f = []
        for i in range(2, x + 1):
            if x % i == 0:
                f.append(i)
        return f
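    # e.g. find_factors(16) -> [2, 4, 8, 16]; each factor below is used as the number of query
    # heads per GQA group, so num_gqa_groups covers group counts from num_heads // 2 down to 1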

    ckpt_attn = True
    qkv_format = "bshd"
    fused_qkv_params = True
    RoPE = True
    config = model_configs[model]
    num_queries_per_gqa_group = find_factors(config.num_heads)

    for num_q_per_gqa_group in num_queries_per_gqa_group:
        config.num_gqa_groups = config.num_heads // num_q_per_gqa_group
        test_transformer_layer(
            dtype, model_configs, model, ckpt_attn, qkv_format, fused_qkv_params, RoPE
        )


def _run_transformer_layer(
    dtype: torch.dtype,
    config: ModelConfig,
    backend: str,
    ckpt_attn: bool,
    qkv_format: str,
    workspace_opt: bool,
    fused_qkv_params: bool,
    RoPE: bool,
    is_training: bool,
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
    """Run TransformerLayer module with one forward pass and one backward pass"""

    # Set RNG and environment variables
    reset_rng_states()
    os.environ["NVTE_FLASH_ATTN"] = "0"
    os.environ["NVTE_FUSED_ATTN"] = "0"
    os.environ["NVTE_UNFUSED_ATTN"] = "0"
    if backend == "FlashAttention":
        os.environ["NVTE_FLASH_ATTN"] = "1"
    if backend == "FusedAttention":
        os.environ["NVTE_FUSED_ATTN"] = "1"
    if backend == "UnfusedDotProductAttention":
        os.environ["NVTE_UNFUSED_ATTN"] = "1"
    _attention_backends["backend_selection_requires_update"] = True
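    # Setting exactly one NVTE_*_ATTN flag to "1" pins the backend under test; clearing the
    # cached selection makes DotProductAttention redo backend dispatch on its next forward pass.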

    # Create input tensor
    if qkv_format == "sbhd":
        inp = torch.randn(
            config.max_seqlen_q,
            config.batch_size,
            config.hidden_size,
            dtype=dtype,
            device="cuda",
            requires_grad=True,
        )
        inp_enc = torch.randn(
            config.max_seqlen_kv,
            config.batch_size,
            config.hidden_size,
            dtype=dtype,
            device="cuda",
            requires_grad=True,
        )
    if qkv_format == "bshd":
        inp = torch.randn(
            config.batch_size,
            config.max_seqlen_q,
            config.hidden_size,
            dtype=dtype,
            device="cuda",
            requires_grad=True,
        )
        inp_enc = torch.randn(
            config.batch_size,
            config.max_seqlen_kv,
            config.hidden_size,
            dtype=dtype,
            device="cuda",
            requires_grad=True,
        )

    # Create seqlens
    if "padding" in config.attn_mask_type or qkv_format == "thd":
        if config.attn_type == "self":
            seqlens_q = torch.randint(
                1, config.max_seqlen_q, [config.batch_size], dtype=torch.int32, device="cuda"
            )
            seqlens_kv = seqlens_q
        if config.attn_type == "cross":
            if config.max_seqlen_q > 1:
                seqlens_q = torch.randint(
                    1, config.max_seqlen_q, [config.batch_size], dtype=torch.int32, device="cuda"
                )
            else:
                seqlens_q = torch.ones([config.batch_size], dtype=torch.int32, device="cuda")
            seqlens_kv = torch.randint(
                1, config.max_seqlen_kv, [config.batch_size], dtype=torch.int32, device="cuda"
            )
    else:
        seqlens_q = torch.full(
            [config.batch_size], config.max_seqlen_q, dtype=torch.int32, device="cuda"
        )
        seqlens_kv = torch.full(
            [config.batch_size], config.max_seqlen_kv, dtype=torch.int32, device="cuda"
        )
    cu_seqlens_q = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda")
    cu_seqlens_kv = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda")
    cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)
    cu_seqlens_kv[1:] = torch.cumsum(seqlens_kv, dim=0)
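    # cu_seqlens_* are cumulative sequence lengths, e.g. seqlens_q = [3, 5, 2] gives
    # cu_seqlens_q = [0, 3, 8, 10]; the packed THD inputs below are sized by the last entry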
    if qkv_format == "thd":
        inp = torch.randn(
            cu_seqlens_q[-1],
            config.hidden_size,
            dtype=dtype,
            device="cuda",
            requires_grad=True,
        )
        inp_enc = torch.randn(
            cu_seqlens_kv[-1],
            config.hidden_size,
            dtype=dtype,
            device="cuda",
            requires_grad=True,
        )

    sigma = 0.02
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    layer_number = 1
    drop_path_rate = 0.0
    drop_path_rates = [rate.item() for rate in torch.linspace(0, drop_path_rate, config.num_layers)]
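    # Per-layer drop-path rates from a linear ramp; with drop_path_rate = 0.0 every entry is
    # 0.0, so stochastic depth is effectively disabled in this single-layer test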

    # Create bias
    bias = None
    if config.attn_bias_type == "post_scale_bias":
        bias = torch.randn(
            1,
            config.num_heads,
            config.max_seqlen_q,
            config.max_seqlen_kv,
            dtype=dtype,
            device="cuda",
        )

    # Create RoPE
    rotary_pos_emb = None
    if RoPE:
        PE = RotaryPositionEmbedding(dim=config.head_dim_qk)
        rotary_pos_emb = PE(config.max_seqlen_q).to(device="cuda")

    # Set up model
    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_heads,
        num_gqa_groups=config.num_gqa_groups,
        layernorm_epsilon=1e-5,
        hidden_dropout=0.0,
        attention_dropout=config.dropout_p,
        init_method=init_method,
        output_layer_init_method=output_layer_init_method,
        layer_number=layer_number,
        kv_channels=config.head_dim_qk,
        self_attn_mask_type=config.attn_mask_type,
        tp_group=None,
        tp_size=1,
        params_dtype=dtype,
        get_rng_state_tracker=None,
        fuse_wgrad_accumulation=False,
        seq_length=config.max_seqlen_q,
        micro_batch_size=config.batch_size,
        sequence_parallel=False,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
        layer_type="encoder" if config.attn_type == "self" else "decoder",
        drop_path_rate=drop_path_rates[layer_number - 1],
        set_parallel_mode=True,
        fuse_qkv_params=fused_qkv_params,
        zero_centered_gamma=False,
        qkv_weight_interleaved=False,
        ub_tp_comm_overlap=False,
        bias=True,
        attn_input_format=qkv_format,
    ).to(dtype=dtype, device="cuda")
    if not is_training:
        block = block.eval()

    # Create ALiBi slopes
    alibi_slopes = None
    if config.attn_bias_type == "alibi" and config.alibi_type == "custom":
        alibi_slopes = torch.randn(config.num_heads).abs().to(dtype=torch.float32, device="cuda")

    # Run a forward and backward pass
    out = block(
        inp,
        self_attn_mask_type=config.attn_mask_type,
        encoder_output=inp_enc if config.attn_type == "cross" else None,
        enc_dec_attn_mask_type=config.attn_mask_type if config.attn_type == "cross" else None,
        checkpoint_core_attention=False,
        rotary_pos_emb=rotary_pos_emb,
        core_attention_bias_type=config.attn_bias_type,
        core_attention_bias=bias,
        alibi_slopes=alibi_slopes,
        max_seqlen_q=config.max_seqlen_q,
        max_seqlen_kv=config.max_seqlen_kv,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_kv=cu_seqlens_kv,
    )
    if is_training:
        loss = out.sum()
        loss.backward()

    return out, inp.grad


model_configs_fp8_extra_state = {
    # test: ModelConfig(b, sq, hq, dqk)
    "large": ModelConfig(2, 128, 4, 128, num_layers=1),
}


@pytest.mark.skipif(not fp8_attn_available, reason=reason_for_no_fp8_attn)
@pytest.mark.skipif(get_cudnn_version() < (9, 3, 0), reason="cuDNN 9.3.0+ is required.")
@pytest.mark.parametrize("model", ["large"])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_dpa_fp8_extra_state(model, dtype):
    """Test DotProductAttention module in FP8 with checkpointing"""
    config = model_configs_fp8_extra_state[model]
    # Test backend availability
    is_training = True
    available_backends, _, fused_attn_backends = get_available_attention_backends(
        config,
        qkv_dtype=torch.float8_e4m3fn,
        qkv_layout="sb3hd",
        is_training=is_training,
        deterministic=_deterministic,
    )
    flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
    if not fused_attn_supported and not flash_attn_supported:
        pytest.skip("No attention backend available.")

    outputs = _run_dpa_fp8_extra_state(dtype, config, checkpoint=False)
    outputs_checkpoint = _run_dpa_fp8_extra_state(dtype, config, checkpoint=True)
    outputs_checkpoint_v1_6 = _run_dpa_fp8_extra_state(
        dtype, config, mimic_v1_6=True, checkpoint=True
    )

    # Check that results match
    tols = dtype_tols(dtype)
    if dtype in (torch.float16, torch.bfloat16):
        tols.update(dict(rtol=2e-2, atol=2e-3))
    for i, (ref, test) in enumerate(zip(outputs, outputs_checkpoint)):
        torch.testing.assert_close(
            test,
            ref,
            **tols,
        )
    for i, (ref, test) in enumerate(zip(outputs, outputs_checkpoint_v1_6)):
        torch.testing.assert_close(
            test,
            ref,
            **tols,
        )


def _run_dpa_fp8_extra_state(dtype, config, checkpoint=False, mimic_v1_6=False):
    """Run DotProductAttention module in FP8 with checkpointing"""
    steps = 10
    path = "checkpoint.pt"
    fp8_enabled = True
    fp8_recipe = recipe.DelayedScaling(
        margin=0,
        fp8_format=recipe.Format.HYBRID,
        amax_history_len=1,
        amax_compute_algo="most_recent",
        fp8_dpa=fp8_enabled,
        fp8_mha=False,
    )

    reset_rng_states()
    hidden_states = torch.randn(
        (config.max_seqlen_q, config.batch_size, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )

    def get_model(dtype, config):
        sigma = 0.023
        init_method = init_method_normal(sigma)
        output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

        with quantized_model_init(enabled=fp8_enabled, recipe=fp8_recipe):
            block = TransformerLayer(
                config.hidden_size,
                4 * config.hidden_size,
                config.num_heads,
                init_method=init_method,
                output_layer_init_method=output_layer_init_method,
                hidden_dropout=0.0,
                attention_dropout=0.0,
                fuse_qkv_params=True,
                params_dtype=dtype,
                device="cuda",
            )
        return block

    block = get_model(dtype, config)
    for i in range(steps // 2):
        with autocast(enabled=fp8_enabled, recipe=fp8_recipe):
            output = block(hidden_states, None)
            loss = output.sum()
            loss.backward()
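    # Mid-run checkpoint round trip: save the state dict (optionally renamed to mimic the
    # TE v1.6 _extra_state layout), rebuild the model, reload the weights, and restore RNG
    # states and parameter gradients so the second half of training resumes seamlessly.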

    if checkpoint:
        sd = block.state_dict()
        if mimic_v1_6:
            sd["self_attention.core_attention.fused_attention._extra_state"] = sd[
                "self_attention.core_attention._extra_state"
            ]
            del sd["self_attention.core_attention._extra_state"]
        torch.save(sd, path)

        param_grads = []
        for p in block.parameters():
            if p.requires_grad:
                param_grads.append(p.grad.clone())

        _cpu_rng_state_new = torch.get_rng_state()
        _cuda_rng_state_new = torch.cuda.get_rng_state()

        del block
        block = get_model(dtype, config)
        block.load_state_dict(torch.load(path, weights_only=False))
        torch.set_rng_state(_cpu_rng_state_new)
        torch.cuda.set_rng_state(_cuda_rng_state_new)

        for p in block.parameters():
            if p.requires_grad:
                p.grad = param_grads.pop(0)

        assert not param_grads, "Parameter gradients were not fully restored after reload"

    for i in range((steps + 1) // 2):
        with autocast(enabled=fp8_enabled, recipe=fp8_recipe):
            output = block(hidden_states, None)
            loss = output.sum()
            loss.backward()

    torch.cuda.synchronize()

    if os.path.exists(path):
        os.remove(path)

    outputs = [output, hidden_states.grad]
    for p in block.parameters():
        if p.requires_grad:
            outputs.append(p.grad)

    return outputs


model_configs_fp8_vs_f16 = {
    # test: ModelConfig(b, sq, hq, dqk)
    "fp8_9": ModelConfig(2, 2048, 16, 128),
    "fp8_10": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12),
    "fp8_11": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4),
    "fp8_12": ModelConfig(2, 2048, 16, 128, attn_mask_type="causal"),
    "fp8_13": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12, attn_mask_type="causal"),
    "fp8_14": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="causal"),
    "fp8_15": ModelConfig(2, 2048, 16, 128, attn_mask_type="padding"),
    "fp8_16": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12, attn_mask_type="padding"),
    "fp8_17": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="padding"),
    "fp8_18": ModelConfig(2, 2048, 16, 128, attn_mask_type="padding_causal"),
    "fp8_19": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12, attn_mask_type="padding_causal"),
    "fp8_20": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="padding_causal"),
}

param_types_fp8_vs_f16 = [torch.float16, torch.bfloat16]
qkv_layout_fp8_vs_f16 = ["sbh3d", "bshd_bshd_bshd", "sbhd_sbhd_sbhd"]
qkv_format_fp8_vs_f16 = ["bshd", "sbhd"]


@pytest.mark.skipif(get_cudnn_version() < (9, 2, 1), reason="cuDNN 9.2.1+ is required.")
@pytest.mark.skipif(not fp8_attn_available, reason=reason_for_no_fp8_attn)
@pytest.mark.parametrize("dtype", param_types_fp8_vs_f16)
@pytest.mark.parametrize("model", model_configs_fp8_vs_f16.keys())
@pytest.mark.parametrize("qkv_format", qkv_format_fp8_vs_f16)
@pytest.mark.parametrize("input_layernorm", [True, False])
@pytest.mark.parametrize("fp8_dpa_bwd", [True, False])
@pytest.mark.parametrize("RoPE", [True, False])
@pytest.mark.parametrize("is_training", [True, False])
@pytest.mark.parametrize("scaling_mode", ["delayed", "current"])
def test_mha_fp8_vs_f16(
    dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd, RoPE, is_training, scaling_mode
):
    """Test MultiHeadAttention module in FP8"""
    os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1"
    os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0"
    config = model_configs_fp8_vs_f16[model]

    # Test backend availability
    if scaling_mode == "delayed":
        fp8_recipe = recipe.DelayedScaling(
            margin=0,
            fp8_format=recipe.Format.HYBRID,
            amax_history_len=1,
            amax_compute_algo="most_recent",
            fp8_dpa=True,
            fp8_mha=True,
        )
    elif scaling_mode == "current":
        fp8_recipe = recipe.Float8CurrentScaling(
            fp8_format=recipe.Format.HYBRID,
            fp8_dpa=True,
            fp8_mha=True,
        )
    fp8_meta = {}
    fp8_meta["recipe"] = fp8_recipe
    available_backends, _, fused_attn_backends = get_available_attention_backends(
        config,
        qkv_dtype=torch.float8_e4m3fn,
        qkv_layout=qkv_format.replace("hd", "h3d"),
        fp8=True,
        fp8_meta=fp8_meta,
        is_training=is_training,
        deterministic=_deterministic,
    )
    flash_attn_supported, fused_attn_supported_fp8, unfused_attn_supported = available_backends
    if flash_attn_supported + fused_attn_supported_fp8 < 1:
        pytest.skip("No FP8 attention backend available.")
    fused_attn_supported_f16 = False
    if not fp8_dpa_bwd:
        available_backends, _, fused_attn_backends = get_available_attention_backends(
            config,
            qkv_dtype=dtype,
            qkv_layout=qkv_format.replace("hd", "h3d"),
            is_training=is_training,
            deterministic=_deterministic,
        )
        _, fused_attn_supported_f16, _ = available_backends
        if not fused_attn_supported_f16:
            pytest.skip("No attention backend available.")

    if flash_attn_supported:
        os.environ["NVTE_FLASH_ATTN"] = "1"
        os.environ["NVTE_FUSED_ATTN"] = "0"
        os.environ["NVTE_UNFUSED_ATTN"] = "0"
        _attention_backends["backend_selection_requires_update"] = True
        logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = True")
        flash_attn_fwd_fp8, param_names, flash_attn_bwd_fp8 = _run_mha_fp8_vs_f16(
            dtype, config, True, qkv_format, input_layernorm, RoPE, is_training, fp8_recipe
        )

    if fused_attn_supported_fp8:
        os.environ["NVTE_FLASH_ATTN"] = "0"
        os.environ["NVTE_FUSED_ATTN"] = "1"
        os.environ["NVTE_UNFUSED_ATTN"] = "0"
        _attention_backends["backend_selection_requires_update"] = True
        logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = True")
        fused_attn_fwd_fp8, param_names, fused_attn_bwd_fp8 = _run_mha_fp8_vs_f16(
            dtype, config, True, qkv_format, input_layernorm, RoPE, is_training, fp8_recipe
        )

    if fused_attn_supported_f16:
        os.environ["NVTE_FLASH_ATTN"] = "0"
        os.environ["NVTE_FUSED_ATTN"] = "1"
        os.environ["NVTE_UNFUSED_ATTN"] = "0"
        _attention_backends["backend_selection_requires_update"] = True
        logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = False")
        fused_attn_fwd_f16, param_names, fused_attn_bwd_f16 = _run_mha_fp8_vs_f16(
            dtype, config, False, qkv_format, input_layernorm, RoPE, is_training, fp8_recipe
        )

    atol = 5e-1
    rtol = 5e-1
    rmse_tol = 0.15
    if flash_attn_supported and fused_attn_supported_f16:
        logging.debug("========== {:^25s} ==========".format("flash fp8 vs fused f16:"))
        logging.debug("========== {:^25s} ==========".format("forward output"))
        compare_and_assert(
            flash_attn_fwd_fp8,
            fused_attn_fwd_f16,
            "flash_attn_fwd_fp8",
            "fused_attn_fwd_f16",
            atol,
            rtol,
            rmse_tol,
            True,
        )
    if fused_attn_supported_fp8 and fused_attn_supported_f16:
        logging.debug("========== {:^25s} ==========".format("fused fp8 vs fused f16:"))
        logging.debug("========== {:^25s} ==========".format("forward output"))
        compare_and_assert(
            fused_attn_fwd_fp8,
            fused_attn_fwd_f16,
            "fused_attn_fwd_fp8",
            "fused_attn_fwd_f16",
            atol,
            rtol,
            rmse_tol,
            True,
        )

        if is_training:
            for i in range(len(param_names[:1])):
                logging.debug("========== {:^25s} ==========".format(param_names[i]))
                compare_and_assert(
                    fused_attn_bwd_fp8[i],
                    fused_attn_bwd_f16[i],
                    f"fused_attn_bwd_fp8[{i}]",
                    f"fused_attn_bwd_f16[{i}]",
                    atol,
                    rtol,
                    rmse_tol,
                    True,
                )


def _run_mha_fp8_vs_f16(
    dtype, config, fp8_mha, qkv_format, input_layernorm, RoPE, is_training, fp8_recipe
):
    """Run MultiHeadAttention module in FP8"""
    reset_rng_states()
    _DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
    _DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed)
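    # TE attention dropout forks the tracker's "model-parallel-rng" state; this dummy tracker
    # provides one seeded identically on every run so results stay reproducible.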

    def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
        """Get cuda rng tracker."""
        return _DUMMY_CUDA_RNG_STATE_TRACKER

    with quantized_model_init(enabled=fp8_mha, recipe=fp8_recipe):
        rotary_pos_emb = None
        if RoPE:
            PE = RotaryPositionEmbedding(dim=config.head_dim_qk)
            rotary_pos_emb = PE(config.max_seqlen_q).to(device="cuda")
        mha = MultiheadAttention(
            hidden_size=config.hidden_size,
            num_attention_heads=config.num_heads,
            kv_channels=config.head_dim_qk,
            num_gqa_groups=config.num_gqa_groups,
            attention_dropout=config.dropout_p,
            layer_number=1,
            bias=True,
            get_rng_state_tracker=get_dummy_cuda_rng_tracker,
            params_dtype=dtype,
            input_layernorm=input_layernorm,
            fuse_qkv_params=True,
            attention_type="self",
            qkv_weight_interleaved=True,
            qkv_format=qkv_format,
        ).to(dtype=dtype, device="cuda")
        if not is_training:
            mha = mha.eval()

    if "padding" in config.attn_mask_type or qkv_format == "thd":
        if config.attn_type == "self":
            seqlens_q = torch.randint(
                1, config.max_seqlen_q, [config.batch_size], dtype=torch.int32, device="cuda"
            )
            seqlens_kv = seqlens_q
        if config.attn_type == "cross":
            seqlens_q = torch.randint(
                1, config.max_seqlen_q, [config.batch_size], dtype=torch.int32, device="cuda"
            )
            seqlens_kv = torch.randint(
                1, config.max_seqlen_kv, [config.batch_size], dtype=torch.int32, device="cuda"
            )
    else:
        seqlens_q = torch.full(
            [config.batch_size], config.max_seqlen_q, dtype=torch.int32, device="cuda"
        )
        seqlens_kv = torch.full(
            [config.batch_size], config.max_seqlen_kv, dtype=torch.int32, device="cuda"
        )
    cu_seqlens_q = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda")
    cu_seqlens_kv = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda")
    cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)
    cu_seqlens_kv[1:] = torch.cumsum(seqlens_kv, dim=0)

    dim_to_num = {
        "b": config.batch_size,
        "sq": config.max_seqlen_q,
        "skv": config.max_seqlen_kv,
        "h": config.num_heads,
        "hg": config.num_gqa_groups,
        "d": config.head_dim_qk,
        "t": cu_seqlens_q[-1],
        "tg": cu_seqlens_kv[-1],
        "3": 3,
        "2": 2,
        "1": 1,
    }
    layout = "_".join(qkv_format)
    layout = layout.replace("s", "sq")
    tensor_shape = [dim_to_num[j] for j in layout.split("_")]
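    # e.g. qkv_format "sbhd" -> layout "sq_b_h_d" -> tensor_shape
    # [max_seqlen_q, batch_size, num_heads, head_dim_qk]; the last two dims are then merged
    # into the hidden size expected by MultiheadAttention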
    tensor = 0.01 * torch.randint(-100, 100, tensor_shape, dtype=dtype, device="cuda")
    hidden_states = tensor.view(*tensor.shape[:-2], -1)
    if is_training:
        hidden_states.requires_grad = True
    tensor = 0.01 * torch.randn(tensor_shape, dtype=dtype, device="cuda")
    out_grad = tensor.view(*tensor.shape[:-2], -1)

    with autocast(enabled=fp8_mha, recipe=fp8_recipe):
        out = mha(
            hidden_states,
            attn_mask_type=config.attn_mask_type,
            checkpoint_core_attention=False,
            core_attention_bias_type=config.attn_bias_type,
            is_first_microbatch=None,
            rotary_pos_emb=rotary_pos_emb,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_kv=cu_seqlens_kv,
        )
    if is_training:
        out.backward(out_grad)

    param_names = []
    param_names.append("hidden_states.grad")
    params = []
    params.append(hidden_states)
    for name, param in mha.named_parameters():
        if param.requires_grad:
            param_names.append(name + ".grad")
            params.append(param)

    if is_training:
        return out, param_names, tuple(x.grad for x in params)
    return out, param_names, tuple(None for x in params)


@pytest.mark.skipif(get_cudnn_version() < (9, 2, 1), reason="cuDNN 9.2.1+ is required.")
@pytest.mark.skipif(not fp8_attn_available, reason=reason_for_no_fp8_attn)
@pytest.mark.parametrize("dtype", param_types_fp8_vs_f16)
@pytest.mark.parametrize("model", model_configs_fp8_vs_f16.keys())
@pytest.mark.parametrize("qkv_layout", qkv_layout_fp8_vs_f16)
@pytest.mark.parametrize("fp8_dpa_bwd", [True, False])
@pytest.mark.parametrize("is_training", [True, False])
@pytest.mark.parametrize("scaling_mode", ["delayed", "current"])
def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scaling_mode):
    """Test DotProductAttention module in FP8"""
    config = model_configs_fp8_vs_f16[model]

    # TODO(cyang): think of another way to verify dropout results
    # test cuDNN FP8 dropout
    # 1. we modify the config here to not affect mha_fp8_vs_f16 tests
    # 2. there is no other backend that implements dropout the same way as cuDNN FP8, and as an
    #    indirect verification method, we create Q/K/V as all 1s and check if O is all 1s
    # 3. we avoid running FP16/BF16 kernels as they do not have dropout support on Blackwell
    # if "padding" not in config.attn_mask_type and "causal" not in config.attn_mask_type:
    #    if get_device_compute_capability() >= (10, 0):
    #        config.dropout_p = 0.1

    os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0"
    os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1"
    os.environ["NVTE_UnfusedDPA_Emulate_FP8"] = "1"
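    # NVTE_UnfusedDPA_Emulate_FP8=1 asks the unfused backend to emulate FP8 quantization
    # numerics, so its results can be compared against the true FP8 fused/flash kernels below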

    # Test backend availability
    if scaling_mode == "delayed":
        fp8_recipe = recipe.DelayedScaling(
            margin=0,
            fp8_format=recipe.Format.HYBRID,
            amax_history_len=1,
            amax_compute_algo="most_recent",
            fp8_dpa=True,
        )
    elif scaling_mode == "current":
        fp8_recipe = recipe.Float8CurrentScaling(
            fp8_format=recipe.Format.HYBRID,
            fp8_dpa=True,
        )
    fp8_meta = {}
    fp8_meta["recipe"] = fp8_recipe
    available_backends, _, fused_attn_backends = get_available_attention_backends(
        config,
        qkv_dtype=torch.float8_e4m3fn,
        qkv_layout=qkv_layout,
        fp8=True,
        fp8_meta=fp8_meta,
        is_training=is_training,
        deterministic=_deterministic,
    )
    flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
    if flash_attn_supported + fused_attn_supported < 1:
        pytest.skip("No FP8 attention backend available.")
    if not fp8_dpa_bwd:
        available_backends, _, fused_attn_backends = get_available_attention_backends(
            config,
            qkv_dtype=dtype,
            qkv_layout=qkv_layout,
            is_training=is_training,
            deterministic=_deterministic,
        )
        _, fused_attn_supported, _ = available_backends
        if not fused_attn_supported:
            pytest.skip("No attention backend available.")
    if config.num_heads != config.num_gqa_groups and "3" in qkv_layout:
        pytest.skip("qkv_layout not applicable for MQA/GQA")

    if flash_attn_supported:
        os.environ["NVTE_FLASH_ATTN"] = "1"
        os.environ["NVTE_FUSED_ATTN"] = "0"
        os.environ["NVTE_UNFUSED_ATTN"] = "0"
        _attention_backends["backend_selection_requires_update"] = True
        logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True (FlashAttention)")
        flash_attn_fwd_fp8, flash_attn_bwd_fp8 = _run_dpa_fp8_vs_f16(
            dtype, config, True, qkv_layout, is_training, fp8_recipe
        )

    if unfused_attn_supported:
        os.environ["NVTE_FLASH_ATTN"] = "0"
        os.environ["NVTE_FUSED_ATTN"] = "0"
        os.environ["NVTE_UNFUSED_ATTN"] = "1"
        _attention_backends["backend_selection_requires_update"] = True
        logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True (UnfusedDotProductAttention)")
        unfused_attn_fwd_fp8, unfused_attn_bwd_fp8 = _run_dpa_fp8_vs_f16(
            dtype, config, True, qkv_layout, is_training, fp8_recipe
        )

    os.environ["NVTE_FLASH_ATTN"] = "0"
    os.environ["NVTE_FUSED_ATTN"] = "1"
    os.environ["NVTE_UNFUSED_ATTN"] = "0"
    _attention_backends["backend_selection_requires_update"] = True
    logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True (FusedAttention)")
    fused_attn_fwd_fp8, fused_attn_bwd_fp8 = _run_dpa_fp8_vs_f16(
        dtype, config, True, qkv_layout, is_training, fp8_recipe
    )

    os.environ["NVTE_FLASH_ATTN"] = "0"
    os.environ["NVTE_FUSED_ATTN"] = "1"
    os.environ["NVTE_UNFUSED_ATTN"] = "0"
    if config.dropout_p == 0.0:
        # only create an F16/BF16 reference when dropout is disabled; the FP8 dropout case
        # below is checked against all-1s outputs instead of an F16/BF16 run
        logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = False (FusedAttention)")
        fused_attn_fwd_f16, fused_attn_bwd_f16 = _run_dpa_fp8_vs_f16(
            dtype, config, False, qkv_layout, is_training, fp8_recipe
        )

    atol = 5e-1
    rtol = 5e-2
    rmse_tol = 0.11
    bwd_names = ["dq", "dk", "dv"]
    if flash_attn_supported:
        logging.debug("========== {:^25s} ==========".format("flash fp8 vs fused f16:"))
        logging.debug("========== {:^25s} ==========".format("forward output"))
        compare_and_assert(
            flash_attn_fwd_fp8,
            fused_attn_fwd_f16,
            "flash_attn_fwd_fp8",
            "fused_attn_fwd_f16",
            atol,
            rtol,
            rmse_tol,
            True,
        )
    if unfused_attn_supported:
        logging.debug("========== {:^25s} ==========".format("unfused fp8 vs fused f16:"))
        logging.debug("========== {:^25s} ==========".format("forward output"))
        compare_and_assert(
            unfused_attn_fwd_fp8,
            fused_attn_fwd_f16,
            "unfused_attn_fwd_fp8",
            "fused_attn_fwd_f16",
            atol,
            rtol,
            rmse_tol,
            True,
        )
        if is_training:
            for i, _ in enumerate(fused_attn_bwd_f16):
                logging.debug("========== {:^25s} ==========".format(bwd_names[i]))
                compare_and_assert(
                    unfused_attn_bwd_fp8[i],
                    fused_attn_bwd_f16[i],
                    f"unfused_attn_bwd_fp8[{i}]",
                    f"fused_attn_bwd_f16[{i}]",
                    atol,
                    rtol,
                    rmse_tol,
                    True,
                )
    if config.dropout_p != 0.0:
        # test cuDNN FP8 dropout
        assert torch.all(
            fused_attn_fwd_fp8 == 1
        ), "fused_attn_fwd_fp8 must be all 1s when Q/K/V are all 1s."
    else:
        logging.debug("========== {:^25s} ==========".format("fused fp8 vs fused f16:"))
        logging.debug("========== {:^25s} ==========".format("forward output"))
        compare_and_assert(
            fused_attn_fwd_fp8,
            fused_attn_fwd_f16,
            "fused_attn_fwd_fp8",
            "fused_attn_fwd_f16",
            atol,
            rtol,
            rmse_tol,
            True,
        )
        if is_training:
            for i, _ in enumerate(fused_attn_bwd_f16):
                logging.debug("========== {:^25s} ==========".format(bwd_names[i]))
                compare_and_assert(
                    fused_attn_bwd_fp8[i],
                    fused_attn_bwd_f16[i],
                    f"fused_attn_bwd_fp8[{i}]",
                    f"fused_attn_bwd_f16[{i}]",
                    atol,
                    rtol,
                    rmse_tol,
                    True,
                )
    os.environ["NVTE_UnfusedDPA_Emulate_FP8"] = "0"


def _run_dpa_fp8_vs_f16(dtype, config, fp8_dpa, qkv_layout, is_training, fp8_recipe):
    """Run DotProductAttention module in FP8"""
    reset_rng_states()
    _DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
    _DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed)

    def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
        """Get cuda rng tracker."""
        return _DUMMY_CUDA_RNG_STATE_TRACKER

    qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()])
    with quantized_model_init(enabled=fp8_dpa):
        dpa = DotProductAttention(
            config.num_heads,
            config.head_dim_qk,
            num_gqa_groups=config.num_gqa_groups,
            attention_dropout=config.dropout_p,
            sequence_parallel=False,
            tp_size=1,
            get_rng_state_tracker=get_dummy_cuda_rng_tracker,
            tp_group=None,
            layer_number=1,
            attention_type="self",
            qkv_format=qkv_format,
        ).to(dtype=dtype, device="cuda")
        if not is_training:
            dpa = dpa.eval()

    if "padding" in config.attn_mask_type or qkv_format == "thd":
        if config.attn_type == "self":
            seqlens_q = torch.randint(
                1, config.max_seqlen_q, [config.batch_size], dtype=torch.int32, device="cuda"
            )
            seqlens_kv = seqlens_q
        if config.attn_type == "cross":
            seqlens_q = torch.randint(
                1, config.max_seqlen_q, [config.batch_size], dtype=torch.int32, device="cuda"
            )
            seqlens_kv = torch.randint(
                1, config.max_seqlen_kv, [config.batch_size], dtype=torch.int32, device="cuda"
            )
    else:
        seqlens_q = torch.full(
            [config.batch_size], config.max_seqlen_q, dtype=torch.int32, device="cuda"
        )
        seqlens_kv = torch.full(
            [config.batch_size], config.max_seqlen_kv, dtype=torch.int32, device="cuda"
        )
    cu_seqlens_q = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda")
    cu_seqlens_kv = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda")
    cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)
    cu_seqlens_kv[1:] = torch.cumsum(seqlens_kv, dim=0)

    dim_to_num = {
        "b": config.batch_size,
        "sq": config.max_seqlen_q,
        "skv": config.max_seqlen_kv,
        "h": config.num_heads,
        "hg": config.num_gqa_groups,
        "d": config.head_dim_qk,
        "t": cu_seqlens_q[-1],
        "tg": cu_seqlens_kv[-1],
        "3": 3,
        "2": 2,
        "1": 1,
    }
    inp = []
    for i, layout in enumerate(qkv_layout.split("_")):
        layout = "_".join(layout)
        if i == 0:
            layout = layout.replace("s", "sq")
        else:
            layout = layout.replace("s", "skv")
            layout = layout.replace("h", "hg")
            layout = layout.replace("t", "tg")
        tensor_shape = [dim_to_num[j] for j in layout.split("_")]
        if config.dropout_p == 0.0:
            tensor = torch.randn(tensor_shape, dtype=dtype, device="cuda")
        else:
            # test cuDNN FP8 dropout
            tensor = torch.ones(tensor_shape, dtype=dtype, device="cuda")
        tensor_count = 1
        split_dim = 0
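        # Layouts with an interleaving digit (e.g. the "3" in "sbh3d") pack several tensors
        # along that dimension; locate it and split the generated tensor into the individual
        # Q/K/V inputs (layouts without a digit already map to a single tensor)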
        for dim, l in enumerate(layout.split("_")):
            if l.isdigit():
                tensor_count = int(l)
                split_dim = dim
                break
        tensors = torch.split(tensor, 1, dim=split_dim) if split_dim != 0 else [tensor]
        for j in range(tensor_count):
            if split_dim != 0:
                inp.append(tensors[j].squeeze(split_dim))
            else:
                inp.append(tensors[j])
    for i in range(3):
        inp[i].requires_grad = True

    qkv_format_kv = "_".join(qkv_format)
    qkv_format_kv = qkv_format_kv.replace("s", "sq")
    out_grad_shape = [dim_to_num[i] for i in qkv_format_kv.split("_")]
    out_grad_shape_new = [*out_grad_shape[:-2], out_grad_shape[-2] * out_grad_shape[-1]]
    out_grad = torch.randn(out_grad_shape_new, dtype=dtype, device="cuda")

    with autocast(enabled=fp8_dpa, recipe=fp8_recipe):
        out = dpa(
            inp[0],
            inp[1],
            inp[2],
            qkv_format=qkv_format,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_kv=cu_seqlens_kv,
            max_seqlen_q=config.max_seqlen_q,
            max_seqlen_kv=config.max_seqlen_kv,
            attn_mask_type=config.attn_mask_type,
            checkpoint_core_attention=False,
            core_attention_bias_type=config.attn_bias_type,
            fp8_output=fp8_dpa,
        )
    if is_training:
        out.backward(out_grad)

    if is_training:
        return out, (inp[0].grad, inp[1].grad, inp[2].grad)
    return out, (None, None, None)


model_configs_fp8 = {
    # test: ModelConfig(b, sq, hq, dqk)
    "fp8_1": ModelConfig(1, 512, 1, 64),
    "fp8_2": ModelConfig(4, 512, 16, 64),
    "fp8_3": ModelConfig(1, 2048, 1, 128),
    "fp8_4": ModelConfig(2, 2048, 24, 128),
    "fp8_5": ModelConfig(1, 512, 1, 64, attn_mask_type="causal"),
    "fp8_6": ModelConfig(4, 512, 16, 64, attn_mask_type="causal"),
    "fp8_7": ModelConfig(1, 2048, 1, 128, attn_mask_type="causal"),
    "fp8_8": ModelConfig(2, 2048, 24, 128, attn_mask_type="causal"),
}
param_types_fp8 = [torch.float16, torch.bfloat16]
cudnn_frontend_version = int(os.getenv("NVTE_FUSED_ATTN_FE_VER", "1"))
models_v0 = ["fp8_1", "fp8_2", "fp8_5", "fp8_6"]
models_v1 = ["fp8_3", "fp8_4", "fp8_7", "fp8_8"]


@pytest.mark.skipif(
    (
        get_cudnn_version() < (8, 9, 3)
        if cudnn_frontend_version == 0
        else get_cudnn_version() < (9, 2, 1)
    ),
    reason=f"""cuDNN {"8.9.3" if cudnn_frontend_version == 0 else "9.2.1"}+ is required.""",
)
@pytest.mark.skipif(not fp8_attn_available, reason=reason_for_no_fp8_attn)
@pytest.mark.parametrize("dtype", param_types_fp8)
@pytest.mark.parametrize("model", models_v1 if cudnn_frontend_version == 1 else models_v0)
def test_custom_mha_fp8_vs_f16(dtype, model):
    """Test FP8 dot product attention implementations based on cuDNN frontend
    v0.9 and v1.0+. Each test compares results from a custom implementation of
    an FP8 MHA module, i.e. Custom_MHA_FP8(), to results from an F16 MHA
    implementation, i.e. transformer_engine.pytorch.attention.MultiHeadAttention.
    Both paths take F16 input and output. QKV layout is t3hd or bs3hd."""

    config = model_configs_fp8[model]

    # Test backend availability
    is_training = True
    available_backends, _, fused_attn_backends = get_available_attention_backends(
        config,
        qkv_dtype=torch.float8_e4m3fn,
        qkv_layout="t3hd" if cudnn_frontend_version == 0 else "bs3hd",
        is_training=is_training,
        deterministic=_deterministic,
    )
    flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
    if not (fused_attn_backends and unfused_attn_supported):
        pytest.skip("Not enough backends to run this test with.")

    fused_attn_fwd_fp8, fused_attn_bwd_fp8 = _run_custom_mha_fp8(dtype, config, "FusedAttention")
    unfused_attn_fwd_f16, unfused_attn_bwd_f16 = _run_ref_mha_f16(
        dtype, config, "UnfusedDotProductAttention"
    )

    atol = 5e-1
    rtol = 5e-1
    rmse_tol = 0.13
    compare_and_assert(
        fused_attn_fwd_fp8,
        unfused_attn_fwd_f16,
        "fused_attn_fwd_fp8",
        "unfused_attn_fwd_f16",
        atol,
        rtol,
        rmse_tol,
        True,
    )
    compare_and_assert(
        fused_attn_bwd_fp8,
        unfused_attn_bwd_f16,
        "fused_attn_bwd_fp8",
        "unfused_attn_bwd_f16",
        atol,
        rtol,
        rmse_tol,
        True,
    )


def _run_custom_mha_fp8(dtype, config, backend):
    """Run Custom_MHA_FP8 with FP8 FusedAttention backend. Both input and output
    are in F16. QKV GEMM, DPA, and projection GEMM are calculated in FP8."""
    reset_rng_states()
    os.environ["NVTE_FLASH_ATTN"] = "0"
    os.environ["NVTE_FUSED_ATTN"] = "0"
    os.environ["NVTE_UNFUSED_ATTN"] = "0"
    if backend == "FlashAttention":
        os.environ["NVTE_FLASH_ATTN"] = "1"
    if backend == "FusedAttention":
        os.environ["NVTE_FUSED_ATTN"] = "1"
    if backend == "UnfusedDotProductAttention":
        os.environ["NVTE_UNFUSED_ATTN"] = "1"
    _attention_backends["backend_selection_requires_update"] = True

    inp = 0.0001 * torch.randint(
        -100,
        100,
        (config.batch_size * config.max_seqlen_q, config.num_heads * config.head_dim_qk),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )
    seqlens = torch.full([config.batch_size], config.max_seqlen_q, dtype=torch.int32, device="cuda")
    cu_seqlens = torch.zeros(config.batch_size + 1, device="cuda", dtype=torch.int32)
    cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)

    out_grad = 0.01 * torch.randn(
        config.batch_size * config.max_seqlen_q,
        config.num_heads * config.head_dim_qk,
        dtype=dtype,
        device="cuda",
    )
    torch.save(out_grad, "out_grad.pt")

    fp8_recipe = recipe.DelayedScaling(
        margin=0,
        fp8_format=recipe.Format.HYBRID,
        amax_history_len=1,
        amax_compute_algo="most_recent",
    )

    mha = Custom_MHA_FP8(config).to(dtype=dtype, device="cuda")
    with autocast(enabled=True, recipe=fp8_recipe):
        out = mha(inp, cu_seqlens, config.max_seqlen_q)
    out.backward(out_grad)

    out = torch.load("out.pt")
    dqkv = torch.load("dqkv.pt")
    return (
        out.view(config.batch_size, config.max_seqlen_q, -1),
        dqkv.view(
            config.batch_size, config.max_seqlen_q, 3, config.num_heads, config.head_dim_qk
        ).contiguous(),
    )


def _run_ref_mha_f16(dtype, config, backend):
    """Run reference F16 FusedAttention. Both input and output
    are in F16. QKV GEMM, DPA, and projection GEMM are also in F16."""

    os.environ["NVTE_FLASH_ATTN"] = "0"
    os.environ["NVTE_FUSED_ATTN"] = "0"
    os.environ["NVTE_UNFUSED_ATTN"] = "0"
    if backend == "FlashAttention":
        os.environ["NVTE_FLASH_ATTN"] = "1"
    if backend == "FusedAttention":
        os.environ["NVTE_FUSED_ATTN"] = "1"
    if backend == "UnfusedDotProductAttention":
        os.environ["NVTE_UNFUSED_ATTN"] = "1"
    _attention_backends["backend_selection_requires_update"] = True

    inp = torch.load("qkv.pt").to(device="cuda")
    inp.requires_grad = True
    seqlens = torch.full([config.batch_size], config.max_seqlen_q, dtype=torch.int32, device="cuda")
    cu_seqlens = torch.zeros(config.batch_size + 1, device="cuda", dtype=torch.int32)
    cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)
    out_grad = (
        torch.load("out_grad.pt").to(device="cuda").view(config.batch_size, config.max_seqlen_q, -1)
    )

    _DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
    _DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed)

    def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
        """Get cuda rng tracker."""
        return _DUMMY_CUDA_RNG_STATE_TRACKER

    block = DotProductAttention(
        config.num_heads,
        config.head_dim_qk,
        attention_dropout=config.dropout_p,
        sequence_parallel=False,
        tp_size=1,
        get_rng_state_tracker=get_dummy_cuda_rng_tracker,
        tp_group=None,
        layer_number=1,
        attention_type="self",
        qkv_format="bshd",
    ).to(dtype=dtype, device="cuda")

    q = inp[:, :, 0, :, :]
    k = inp[:, :, 1, :, :]
    v = inp[:, :, 2, :, :]
    out = block(q, k, v, attn_mask_type=config.attn_mask_type)
    out.backward(out_grad)

    return out, inp.grad


_CUBLASLT_WORKSPACE_SIZE_BYTES = 33_554_432  # 32MiB
_2X_ACC_FPROP = False
_2X_ACC_DGRAD = False
_2X_ACC_WGRAD = False

META_QKV = tex.FP8FwdTensors.GEMM1_OUTPUT
META_DQKV = tex.FP8BwdTensors.GRAD_OUTPUT1
META_O = tex.FP8FwdTensors.GEMM2_INPUT
META_DO = tex.FP8BwdTensors.GRAD_INPUT2
META_S = tex.FP8FwdTensors.GEMM3_OUTPUT
META_DP = tex.FP8BwdTensors.GRAD_INPUT3


class _custom_mha_fp8(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        inp: torch.Tensor,
        qkv_weight: torch.Tensor,
        qkv_bias: torch.Tensor,
        cu_seqlens: torch.Tensor,
        num_heads: int,
        p_dropout: float,
        max_s: int,
        fast_zero_fill: bool,
        fp8_meta: Dict[str, Any],
        is_training: bool,
        mask_type: str,
        quantizers: list[Quantizer],
    ) -> torch.Tensor:
        qkv_dtype = inp.dtype

        assert inp.dim() == 2
        in_features = qkv_weight.shape[-1]
        h = num_heads
        d = in_features // h
        b = cu_seqlens.numel() - 1

        input_quantizer = quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT]
        qkv_quantizer = quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_INPUT]
        qkv_weight_quantizer = quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT]
        o_quantizer = quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_OUTPUT]
        dO_quantizer = quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT1]
        dQKV_quantizer = quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_INPUT1]
        s_quantizer = quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT2]
        dP_quantizer = quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT3]

        inp_fp8 = input_quantizer(inp)

        qkv_weight_fp8 = qkv_weight_quantizer(qkv_weight)

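        # FP8 QKV projection GEMM; an F16 copy of the result is saved for the reference run.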
        qkv, *_ = ext.general_gemm(
            qkv_weight_fp8,
            inp_fp8,
            bias=qkv_bias,
            out_dtype=qkv_weight_fp8.dtype,
            quantization_params=qkv_quantizer,
            use_split_accumulator=_2X_ACC_FPROP,
        )
        qkv = qkv.view(-1, 3, h, d)
        qkv_fp16 = qkv.dequantize().view(b, max_s, 3, h, d).contiguous()
        torch.save(qkv_fp16, "qkv.pt")
        if cudnn_frontend_version == 1:
            qkv = qkv.view(b, max_s, 3, h, d)  # bs3hd

        # FMHA: slice the FP8 QKV tensor into q/k/v views and run the fused attention forward
        q_data = qkv._data[:, :, 0, :, :] if cudnn_frontend_version == 1 else qkv._data[:, 0, :, :]
        k_data = qkv._data[:, :, 1, :, :] if cudnn_frontend_version == 1 else qkv._data[:, 1, :, :]
        v_data = qkv._data[:, :, 2, :, :] if cudnn_frontend_version == 1 else qkv._data[:, 2, :, :]
        q = qkv.make_like(tensor=qkv, data=q_data, shape=q_data.shape)
        k = qkv.make_like(tensor=qkv, data=k_data, shape=k_data.shape)
        v = qkv.make_like(tensor=qkv, data=v_data, shape=v_data.shape)

        out, aux_ctx_tensors = fused_attn_fwd(
            is_training,
            max_s,
            max_s,
            cu_seqlens,
            cu_seqlens,
            q,
            k,
            v,
            qkv_dtype,
            FusedAttnBackend["FP8"],
            attn_scale=None,
            dropout=p_dropout,
            fast_zero_fill=fast_zero_fill,
            qkv_layout="bs3hd" if cudnn_frontend_version == 1 else "t3hd",
            attn_bias_type="no_bias",
            attn_mask_type=mask_type if cudnn_frontend_version == 1 else "padding",
            rng_gen=None,
            o_quantizer=o_quantizer,
            s_quantizer=s_quantizer,
        )

        tensors_to_save, tensor_objects = prepare_for_saving(q, k, v, inp_fp8, qkv_weight_fp8, out)

        ctx.save_for_backward(*tensors_to_save)
        ctx.tensor_objects = tensor_objects
        ctx.aux_ctx_tensors = aux_ctx_tensors
        ctx.qkv_dtype = qkv_dtype
        ctx.fp8_meta = fp8_meta
        ctx.cu_seqlens = cu_seqlens
        ctx.p_dropout = p_dropout
        ctx.max_s = max_s
        ctx.fast_zero_fill = fast_zero_fill
        ctx.hidden_size = in_features
        ctx.num_heads = num_heads
        ctx.mask_type = mask_type
        ctx.dtype = inp.dtype

        ctx.dQKV_quantizer = dQKV_quantizer
        ctx.dO_quantizer = dO_quantizer
        ctx.dP_quantizer = dP_quantizer
        ctx.S_quantizer = s_quantizer

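        # Dequantize and save the output so the F16 reference run can be compared against it.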
        out = out.view(-1, in_features)  # (bs)(hd)
        out_fp16 = out.dequantize()
        torch.save(out_fp16, "out.pt")  # (bs)(hd)
        return out_fp16

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
        with torch.cuda.nvtx.range("_DPA"):
            saved_tensors = ctx.saved_tensors
            (q, k, v, inp_fp8, qkv_weight_fp8, out) = restore_from_saved(
                ctx.tensor_objects, saved_tensors
            )

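            # Quantize the incoming output gradient to FP8 for the fused attention backward.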
            proj_dgrad = ctx.dO_quantizer(grad_output)
            fp8_dtype_backward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=False)

            dq, dk, dv, *rest = fused_attn_bwd(
                ctx.max_s,
                ctx.max_s,
                ctx.cu_seqlens,
                ctx.cu_seqlens,
                q,
                k,
                v,
                out,
                proj_dgrad.view_as(out),
                ctx.qkv_dtype,
                fp8_dtype_backward,
                ctx.aux_ctx_tensors,
                FusedAttnBackend["FP8"],
                None,
                None,
                ctx.S_quantizer,
                ctx.dP_quantizer,
                ctx.dQKV_quantizer,
                attn_scale=None,
                dropout=ctx.p_dropout,
                fast_zero_fill=ctx.fast_zero_fill,
                qkv_layout="bs3hd" if cudnn_frontend_version == 1 else "t3hd",
                attn_bias_type="no_bias",
                attn_mask_type=ctx.mask_type if cudnn_frontend_version == 1 else "padding",
            )
            dim = 2 if cudnn_frontend_version == 1 else 1
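            # Rebuild the interleaved dQKV tensor from dq's underlying storage by re-inserting
            # the '3' dimension into its shape and strides (dq/dk/dv share one interleaved buffer).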
            dqkv = torch.Tensor().to(device=dq._data.device, dtype=dq._data.dtype)
            dqkv_shape = list(dq._data.shape)
            dqkv_shape.insert(dim, 3)
            dqkv_stride = list(dq._data.stride())
            dqkv_stride.insert(dim, int(dqkv_stride[-3] / 3))
            dqkv.set_(
                dq._data.untyped_storage(), dq._data.storage_offset(), dqkv_shape, dqkv_stride
            )  # bs3hd

            dqkv_c = dqkv.view(-1, 3 * ctx.hidden_size)
            dqkv_c = dq.make_like(tensor=dq, data=dqkv_c, shape=dqkv_c.shape)
            dqkv_c_fp16 = dqkv_c.dequantize()
            torch.save(dqkv_c_fp16, "dqkv.pt")

            qkv_bgrad, dqkv = ext.bgrad_quantize(dqkv_c_fp16, ctx.dQKV_quantizer)
            dqkv_c._transpose = None
            dqkv_c._create_transpose()

            # QKV DGRAD
            qkv_dgrad, *_ = ext.general_gemm(
                qkv_weight_fp8,
                dqkv_c,
                ctx.dtype,
                use_split_accumulator=_2X_ACC_DGRAD,
                layout="NN",
            )

            # QKV WGRAD
            qkv_wgrad, *_ = ext.general_gemm(
                inp_fp8,
                dqkv,
                ctx.dtype,
                use_split_accumulator=_2X_ACC_WGRAD,
                layout="NT",
            )

        # One gradient per forward input; only inp, qkv_weight, and qkv_bias receive gradients.
        return (
            qkv_dgrad,
            qkv_wgrad,
            qkv_bgrad,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
        )


class Custom_MHA_FP8(TransformerEngineBaseModule):
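    """Minimal FP8 MHA module wrapping _custom_mha_fp8 with a fused QKV weight and bias."""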
    def __init__(self, config, params_dtype: torch.dtype = torch.float32):
        super().__init__()
        self.p_dropout = config.dropout_p
        self.h = config.num_heads
        self.hidden_size = config.hidden_size
        self.head_dim = config.head_dim_qk
        self.fast_zero_fill = True
        self.mask_type = config.attn_mask_type

        self.qkv_weight = torch.nn.Parameter(
            torch.empty(
                self.hidden_size * 3,
                self.hidden_size,
                device=torch.cuda.current_device(),
                dtype=params_dtype,
            )
        )
        self.qkv_bias = torch.nn.Parameter(
            torch.empty(
                self.hidden_size * 3,
                device=torch.cuda.current_device(),
                dtype=params_dtype,
            )
        )
        with torch.no_grad():
            self.qkv_bias.zero_()
            self.qkv_weight.fill_(1.0)

    def forward(
        self,
        inp: torch.Tensor,
        cu_seqlens,
        max_s,
    ) -> torch.Tensor:
        with self.prepare_forward_ctx(inp, num_gemms=3) as inp:
            out = _custom_mha_fp8.apply(
                inp,
                self.qkv_weight,
                self.qkv_bias,
                cu_seqlens,
                self.h,
                self.p_dropout,
                max_s,
                self.fast_zero_fill,
                self.fp8_meta,
                self.training,
                self.mask_type,
                self.quantizers,
            )
        return out