test_attention.py 106 KB
Newer Older
1
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
#
# See LICENSE for license information.
4
import logging
Tim Moon's avatar
Tim Moon committed
5
import os
6
7
import sys
import pathlib
8
from typing import Any, Dict, Tuple, Union
Tim Moon's avatar
Tim Moon committed
9

10
import pytest
Tim Moon's avatar
Tim Moon committed
11
import torch
12

13
from transformer_engine.pytorch.quantization import FP8GlobalStateManager, get_fp8_te_dtype
Tim Moon's avatar
Tim Moon committed
14
from transformer_engine.common import recipe
15
16
17
18
from transformer_engine.pytorch import (
    TransformerLayer,
    autocast,
    quantized_model_init,
Tim Moon's avatar
Tim Moon committed
19
    DotProductAttention,
20
21
22
23
24
25
26
    MultiheadAttention,
    get_device_compute_capability,
    Quantizer,
    is_fp8_available,
    is_bf16_available,
)
from transformer_engine.pytorch.attention.dot_product_attention import (
27
28
    _attention_backends,
)
29
from transformer_engine.pytorch.attention.dot_product_attention.utils import (
30
    FlashAttentionUtils,
31
    check_set_window_size,
Tim Moon's avatar
Tim Moon committed
32
)
33
from transformer_engine.pytorch.attention import RotaryPositionEmbedding
Tim Moon's avatar
Tim Moon committed
34
35
36
37
38
39
import transformer_engine.pytorch.cpp_extensions as ext
from transformer_engine.pytorch.cpp_extensions.fused_attn import (
    FusedAttnBackend,
    fused_attn_bwd,
    fused_attn_fwd,
)
40
41
from transformer_engine.pytorch.distributed import CudaRNGStatesTracker
from transformer_engine.pytorch.module.base import TransformerEngineBaseModule
42
43
44
45
from transformer_engine.pytorch.utils import (
    init_method_normal,
    scaled_init_method_normal,
)
46
from transformer_engine.pytorch.utils import get_cudnn_version
47
import transformer_engine_torch as tex
48
49
from transformer_engine.pytorch.quantized_tensor import (
    Quantizer,
50
51
52
    prepare_for_saving,
    restore_from_saved,
)
53

54
55
56
57
_current_file = pathlib.Path(__file__).resolve()
sys.path.append(str(_current_file.parent.parent))
from utils import (
    reset_rng_states,
58
    compare_and_assert,
59
60
61
62
63
    ModelConfig,
    dtype_tols,
    get_available_attention_backends,
)

64
# Determine FP8 and FP8-attention availability on the current device.
# FP8 attention is only exercised for compute capabilities in [sm90, sm120);
# outside that range the tests are given an explicit skip reason.
fp8_available, reason_for_no_fp8 = is_fp8_available(return_reason=True)
device_compute_capability = get_device_compute_capability()
_fp8_attn_arch_ok = (9, 0) <= device_compute_capability < (12, 0)
if fp8_available and not _fp8_attn_arch_ok:
    _sm_number = device_compute_capability[0] * 10 + device_compute_capability[1]
    fp8_attn_available = False
    reason_for_no_fp8_attn = (
        "FP8 attention is not supported for compute capability ="
        f" sm{_sm_number}"
    )
else:
    # Either FP8 itself is unavailable (reuse its reason) or everything is supported.
    fp8_attn_available = fp8_available
    reason_for_no_fp8_attn = reason_for_no_fp8
74

75
76
77
78
79
80
81
82

# Determinism is requested either by setting NVTE_ALLOW_NONDETERMINISTIC_ALGO=0
# or by enabling PyTorch's deterministic-algorithms mode.
_deterministic = (
    int(os.environ.get("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1")) == 0
    or torch.are_deterministic_algorithms_enabled()
)


83
# Reset RNG seed and states
84
seed = 1234
85
reset_rng_states()
86
87


88
# Reset FP8 global state manager
89
90
91
@pytest.fixture(autouse=True)
def reset_global_fp8_state():
    """Autouse fixture that resets the FP8 global state manager after each test."""
    # Nothing to set up; the reset below runs as fixture teardown.
    yield
    FP8GlobalStateManager.reset()
93

94

95
96
# 16-bit dtypes to parametrize over: FP16 always, BF16 only when the device supports it.
param_types = [torch.float16] + ([torch.bfloat16] if is_bf16_available() else [])
# Reduced dtype list for the larger test matrices.
# NOTE(review): this uses BF16 unconditionally, without the is_bf16_available() guard above.
param_types_lean = [torch.bfloat16]

101
# Baseline configurations for test_dot_product_attention.
# Positional arguments follow the "b, sq, hq, dqk" legend below -- presumably batch size,
# max query sequence length, number of query heads and QK head dim (confirm against
# ModelConfig in the shared test utils).
model_configs_base = {
    # test: ModelConfig(b, sq, hq, dqk)
    "base_1_0": ModelConfig(8, 128, 16, 64),
    "base_1_1": ModelConfig(4, 128, 16, 64, max_seqlen_kv=256),
    "base_2_0": ModelConfig(2, 2048, 24, 128),
    "base_2_1": ModelConfig(1, 2048, 24, 128, max_seqlen_kv=4096),
    "base_3_0": ModelConfig(8, 1, 16, 128, max_seqlen_kv=2048),
    "base_3_1": ModelConfig(8, 1, 16, 256, max_seqlen_kv=2048),
    "base_4_0": ModelConfig(8, 1, 16, 192, max_seqlen_kv=2048),
    "base_4_1": ModelConfig(8, 128, 16, 192, max_seqlen_kv=2048),
    "base_5_0": ModelConfig(8, 1, 16, 512, max_seqlen_kv=2048),
    "base_5_1": ModelConfig(8, 128, 16, 512, max_seqlen_kv=2048),
    "base_6_0": ModelConfig(8, 1, 16, 1024, max_seqlen_kv=2048),
    "base_6_1": ModelConfig(8, 128, 16, 1024, max_seqlen_kv=2048),
}

117

118
@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model_configs", [model_configs_base])
@pytest.mark.parametrize("model", model_configs_base.keys())
@pytest.mark.parametrize("ckpt_attn", [False])
@pytest.mark.parametrize("workspace_opt", [True, False])
@pytest.mark.parametrize("qkv_layout", [None])
@pytest.mark.parametrize("swa", [False])
@pytest.mark.parametrize("pad_between_seqs", [False])
def test_dot_product_attention(
    dtype,
    model_configs,
    model,
    ckpt_attn,
    workspace_opt,
    qkv_layout,
    swa,
    pad_between_seqs,
):
    """Test DotProductAttention module.

    Runs the same configuration through every available attention backend
    (unfused, fused, flash) and checks that forward outputs and backward
    gradients agree pairwise within dtype-dependent tolerances. The test is
    skipped when fewer than two backends are available to compare.

    Args:
        dtype: torch dtype of the QKV tensors; bfloat16 uses looser tolerances.
        model_configs: mapping from config name to ModelConfig.
        model: key into ``model_configs`` selecting the configuration.
        ckpt_attn: forwarded to ``_run_dot_product_attention``.
        workspace_opt: forwarded to ``_run_dot_product_attention``.
        qkv_layout: QKV layout string, or None to pick a default from the config.
        swa: if True and the config has no window, use a (2, 2) sliding window.
        pad_between_seqs: forwarded to backend selection and to the run helper.

    Note:
        The selected config object is mutated in place (``window_size`` and,
        for THD layouts, ``attn_mask_type``).
    """

    # Get configs
    tols = dict(atol=1e-3, rtol=1e-3)
    if dtype == torch.bfloat16:
        tols = dict(atol=1.5e-2, rtol=1.5e-2)
    config = model_configs[model]
    is_mla = config.head_dim_qk != config.head_dim_v
    is_mqa_gqa = config.num_heads != config.num_gqa_groups
    if qkv_layout is None:
        # Packed layouts are only used for plain MHA; MLA and MQA/GQA need separate tensors.
        if config.attn_type == "self":
            qkv_layout = "sb3hd" if not is_mla and not is_mqa_gqa else "sbhd_sbhd_sbhd"
        else:
            qkv_layout = "bshd_bs2hd" if not is_mla and not is_mqa_gqa else "bshd_bshd_bshd"
    if "3" in qkv_layout and config.attn_type == "cross":
        pytest.skip("No need to test this layout for cross attention")

    if config.window_size == (-1, -1) and swa:
        config.window_size = [2, 2]
    config.window_size = check_set_window_size(config.attn_mask_type, config.window_size)
    qkv_format = qkv_layout.replace("3", "").replace("2", "").split("_")[0]
    if qkv_format == "thd" and "padding" not in config.attn_mask_type:
        # THD inputs are always run with a padding-style mask.
        config.attn_mask_type = (
            "padding_" + config.attn_mask_type if config.attn_mask_type != "no_mask" else "padding"
        )

    # Get backends: try training mode first and fall back to inference mode
    # if FusedAttention is not available for training.
    for is_training in (True, False):
        available_backends, _, fused_attn_backends = get_available_attention_backends(
            config,
            qkv_dtype=dtype,
            qkv_layout=qkv_layout,
            pad_between_seqs=pad_between_seqs,
            is_training=is_training,
            deterministic=_deterministic,
        )
        flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
        if fused_attn_supported:
            break

    # FlashAttention does not support pad_between_seqs, but _run_dot_product_attention
    # manually pads and unpads the input and output of FlashAttention for testing purposes
    if (
        pad_between_seqs
        and FlashAttentionUtils.is_installed
        and not (
            config.max_seqlen_q != config.max_seqlen_kv
            and config.attn_mask_type in ["causal", "padding_causal"]
        )
        and (config.window_size[0] == -1 or FlashAttentionUtils.v2_3_plus)
    ):
        flash_attn_supported = True

    # Skip unless at least two backends are available to compare against each other
    if (len(fused_attn_backends) + flash_attn_supported + unfused_attn_supported) < 2:
        pytest.skip("Less than two backends to compare.")

    def _run(backend):
        """Run one attention backend; returns (forward output, max logit, backward grads)."""
        return _run_dot_product_attention(
            dtype,
            config,
            backend,
            ckpt_attn,
            qkv_layout,
            workspace_opt,
            pad_between_seqs,
            is_training,
        )

    # NOTE: result names deliberately avoid `fused_attn_fwd` / `fused_attn_bwd`,
    # which are functions imported at module level.

    # UnfusedDotProductAttention backend
    if unfused_attn_supported:
        unfused_out, unfused_max_logit, unfused_grads = _run("UnfusedDotProductAttention")

    # FusedAttention backend
    if fused_attn_supported:
        if len(fused_attn_backends) == 1:
            fused_out, fused_max_logit, fused_grads = _run("FusedAttention")
        if len(fused_attn_backends) == 2:
            # Keep the max logit from sub-backend 0 as well, so the
            # `config.return_max_logit` comparison below has a bound value on this path.
            os.environ["NVTE_FUSED_ATTN_BACKEND"] = "0"
            fused_out, fused_max_logit, fused_grads = _run("FusedAttention")
            # NVTE_FUSED_ATTN_BACKEND is intentionally left at "1" afterwards, as before.
            os.environ["NVTE_FUSED_ATTN_BACKEND"] = "1"
            fused_out_1, _, fused_grads_1 = _run("FusedAttention")

    # FlashAttention backend
    if flash_attn_supported:
        flash_out, _, flash_grads = _run("FlashAttention")

    # Compare results
    logging.info(f"[test_dot_product_attention]: is_training = {is_training}")
    if unfused_attn_supported and flash_attn_supported:
        logging.info("[test_dot_product_attention]: unfused attn vs flash attn")
        torch.testing.assert_close(flash_out, unfused_out, **tols)
        for i, _ in enumerate(flash_grads):
            torch.testing.assert_close(unfused_grads[i], flash_grads[i], **tols)
    if unfused_attn_supported and fused_attn_supported:
        logging.info("[test_dot_product_attention]: unfused attn vs fused attn")
        torch.testing.assert_close(fused_out, unfused_out, **tols)
        if config.return_max_logit:
            torch.testing.assert_close(fused_max_logit, unfused_max_logit, **tols)
        for i, _ in enumerate(unfused_grads):
            torch.testing.assert_close(fused_grads[i], unfused_grads[i], **tols)
    if fused_attn_supported and flash_attn_supported:
        logging.info("[test_dot_product_attention]: fused attn vs flash attn")
        torch.testing.assert_close(fused_out, flash_out, **tols)
        for i, _ in enumerate(flash_grads):
            torch.testing.assert_close(fused_grads[i], flash_grads[i], **tols)
    if fused_attn_supported and len(fused_attn_backends) == 2:
        logging.info("[test_dot_product_attention]: fused attn backend 0 vs 1")
        torch.testing.assert_close(fused_out, fused_out_1, **tols)
        for i, _ in enumerate(fused_grads):
            torch.testing.assert_close(fused_grads[i], fused_grads_1[i], **tols)

291

292
@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model_configs", [model_configs_base])
@pytest.mark.parametrize("model", ["base_1_1", "base_2_1"])
def test_dpa_checkpoint(dtype, model_configs, model):
    """Run the DotProductAttention backend comparison with ckpt_attn enabled."""
    test_dot_product_attention(
        dtype=dtype,
        model_configs=model_configs,
        model=model,
        ckpt_attn=True,
        workspace_opt=True,
        qkv_layout=None,
        swa=False,
        pad_between_seqs=False,
    )
299

300

301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
# Configurations for test_dpa_max_logit, which sets return_max_logit=True on each
# selected config before running the backend comparison.
model_configs_max_logit = {
    # test: ModelConfig(b, sq, hq, dqk)
    "max_logit_1": ModelConfig(1, 2048, 24, 128, max_seqlen_kv=4096),
    "max_logit_2": ModelConfig(2, 2048, 24, 128, attn_mask_type="causal"),
    "max_logit_3": ModelConfig(2, 1, 16, 128, max_seqlen_kv=2048, attn_mask_type="padding_causal"),
    "max_logit_4": ModelConfig(
        8, 128, 16, 192, max_seqlen_kv=2048, attn_bias_type="post_scale_bias"
    ),
    "max_logit_5": ModelConfig(
        8, 128, 16, 512, max_seqlen_kv=2048, attn_mask_type="causal", window_size=(20, 0)
    ),
    "max_logit_6": ModelConfig(8, 1, 16, 1024, max_seqlen_kv=2048),
}


@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model_configs", [model_configs_max_logit])
@pytest.mark.parametrize("model", model_configs_max_logit.keys())
@pytest.mark.parametrize("qkv_layout", ["sbhd_sbhd_sbhd", "thd_thd_thd"])
def test_dpa_max_logit(dtype, model_configs, model, qkv_layout):
    """Test DotProductAttention module with return_max_logit enabled"""
    # The flag is set on the shared config object, so it stays enabled for later
    # parametrizations that select the same entry.
    model_configs[model].return_max_logit = True
    test_dot_product_attention(
        dtype=dtype,
        model_configs=model_configs,
        model=model,
        ckpt_attn=False,
        workspace_opt=True,
        qkv_layout=qkv_layout,
        swa=False,
        pad_between_seqs=False,
    )


328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
# Configurations for test_dpa_num_splits; each entry passes a num_splits value.
model_configs_num_splits = {
    # test: ModelConfig(b, sq, hq, dqk)
    "num_splits_1_0": ModelConfig(2, 2048, 24, 128, num_splits=2),
    "num_splits_1_1": ModelConfig(1, 2048, 24, 128, max_seqlen_kv=4096, num_splits=4),
}


@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model_configs", [model_configs_num_splits])
@pytest.mark.parametrize("model", model_configs_num_splits.keys())
def test_dpa_num_splits(dtype, model_configs, model):
    """Run the backend comparison on configs that set num_splits (FlashAttention-3)."""
    test_dot_product_attention(
        dtype=dtype,
        model_configs=model_configs,
        model=model,
        ckpt_attn=False,
        workspace_opt=True,
        qkv_layout=None,
        swa=False,
        pad_between_seqs=False,
    )


353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
# Configurations for test_dpa_softmax / test_dpa_softmax_thd.
# Naming: softmax_<G>_<K>. Within each group G, entry _0 leaves softmax_type at its
# default, _1 uses "off-by-one" and _2 uses "learnable"; groups differ in mask type and window.
model_configs_softmax = {
    # test: ModelConfig(b, sq, hq, dqk)
    "softmax_1_0": ModelConfig(2, 2048, 64, 64, num_gqa_groups=8),
    "softmax_1_1": ModelConfig(2, 2048, 64, 64, num_gqa_groups=8, softmax_type="off-by-one"),
    "softmax_1_2": ModelConfig(2, 2048, 64, 64, num_gqa_groups=8, softmax_type="learnable"),
    "softmax_2_0": ModelConfig(2, 2048, 64, 64, num_gqa_groups=8, attn_mask_type="causal"),
    "softmax_2_1": ModelConfig(
        2, 2048, 64, 64, num_gqa_groups=8, attn_mask_type="causal", softmax_type="off-by-one"
    ),
    "softmax_2_2": ModelConfig(
        2, 2048, 64, 64, num_gqa_groups=8, attn_mask_type="causal", softmax_type="learnable"
    ),
    "softmax_3_0": ModelConfig(2, 2048, 64, 64, num_gqa_groups=8, attn_mask_type="padding"),
    "softmax_3_1": ModelConfig(
        2, 2048, 64, 64, num_gqa_groups=8, attn_mask_type="padding", softmax_type="off-by-one"
    ),
    "softmax_3_2": ModelConfig(
        2, 2048, 64, 64, num_gqa_groups=8, attn_mask_type="padding", softmax_type="learnable"
    ),
    "softmax_4_0": ModelConfig(
        2, 2048, 64, 64, num_gqa_groups=8, window_size=(128, 0), attn_mask_type="causal"
    ),
    "softmax_4_1": ModelConfig(
        2,
        2048,
        64,
        64,
        num_gqa_groups=8,
        window_size=(128, 0),
        attn_mask_type="causal",
        softmax_type="off-by-one",
    ),
    "softmax_4_2": ModelConfig(
        2,
        2048,
        64,
        64,
        num_gqa_groups=8,
        window_size=(128, 0),
        attn_mask_type="causal",
        softmax_type="learnable",
    ),
    "softmax_5_0": ModelConfig(
        2, 2048, 64, 64, num_gqa_groups=8, window_size=(128, 0), attn_mask_type="padding_causal"
    ),
    "softmax_5_1": ModelConfig(
        2,
        2048,
        64,
        64,
        num_gqa_groups=8,
        window_size=(128, 0),
        attn_mask_type="padding_causal",
        softmax_type="off-by-one",
    ),
    "softmax_5_2": ModelConfig(
        2,
        2048,
        64,
        64,
        num_gqa_groups=8,
        window_size=(128, 0),
        attn_mask_type="padding_causal",
        softmax_type="learnable",
    ),
}


@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("model_configs", [model_configs_softmax])
@pytest.mark.parametrize("model", model_configs_softmax.keys())
def test_dpa_softmax(dtype, model_configs, model):
    """Compare backends across softmax types using the bshd_bshd_bshd layout."""
    test_dot_product_attention(
        dtype=dtype,
        model_configs=model_configs,
        model=model,
        ckpt_attn=True,
        workspace_opt=True,
        qkv_layout="bshd_bshd_bshd",
        swa=False,
        pad_between_seqs=False,
    )


432
433
434
435
436
437
438
439
440
@pytest.mark.skipif(get_cudnn_version() < (9, 18, 0), reason="cuDNN 9.18.0+ is required.")
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("model_configs", [model_configs_softmax])
@pytest.mark.parametrize("model", model_configs_softmax.keys())
def test_dpa_softmax_thd(dtype, model_configs, model):
    """Compare backends across softmax types using the thd_thd_thd layout."""
    test_dot_product_attention(
        dtype=dtype,
        model_configs=model_configs,
        model=model,
        ckpt_attn=True,
        workspace_opt=True,
        qkv_layout="thd_thd_thd",
        swa=False,
        pad_between_seqs=False,
    )


441
# Configurations for test_dpa_mla; every entry passes head_dim_v explicitly.
# NOTE(review): "mla_3_4" passes head_dim_v=160 alongside a fourth positional value of 160;
# if that positional is head_dim_qk, the entry is not treated as MLA by
# test_dot_product_attention (is_mla compares the two) -- confirm this is intended.
model_configs_mla = {
    # test: ModelConfig(b, sq, hq, dqk)
    "mla_1_0": ModelConfig(8, 128, 16, 64, head_dim_v=128),
    "mla_1_1": ModelConfig(4, 128, 16, 64, max_seqlen_kv=256, head_dim_v=128),
    "mla_1_2": ModelConfig(4, 128, 16, 192, max_seqlen_kv=256, head_dim_v=128),
    "mla_2_0": ModelConfig(2, 2048, 24, 128, attn_mask_type="causal", head_dim_v=64),
    "mla_2_1": ModelConfig(
        1, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal", head_dim_v=64
    ),
    "mla_2_2": ModelConfig(
        1, 2048, 24, 192, max_seqlen_kv=4096, attn_mask_type="causal", head_dim_v=128
    ),
    "mla_3_0": ModelConfig(8, 1, 16, 128, max_seqlen_kv=2048, head_dim_v=64),
    "mla_3_1": ModelConfig(8, 1, 16, 256, max_seqlen_kv=2048, head_dim_v=128),
    "mla_3_2": ModelConfig(8, 1, 16, 192, max_seqlen_kv=2048, head_dim_v=128),
    "mla_3_3": ModelConfig(8, 1, 16, 160, max_seqlen_kv=2048, head_dim_v=128),
    "mla_3_4": ModelConfig(8, 1, 16, 160, max_seqlen_kv=2048, head_dim_v=160),
}


@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model_configs", [model_configs_mla])
@pytest.mark.parametrize("model", model_configs_mla.keys())
def test_dpa_mla(dtype, model_configs, model):
    """Compare backends on Multi-Latent Attention (MLA) configurations."""
    test_dot_product_attention(
        dtype=dtype,
        model_configs=model_configs,
        model=model,
        ckpt_attn=True,
        workspace_opt=True,
        qkv_layout=None,
        swa=False,
        pad_between_seqs=False,
    )


470
# Configurations for test_dpa_mask, covering each attn_mask_type value used below.
# Groups 1-5 use a first positional sequence length of 2048; groups 6-10 use 1 with
# max_seqlen_kv=2048.
model_configs_mask = {
    # test: ModelConfig(b, sq, hq, dqk)
    "mask_1_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="causal"),
    "mask_1_1": ModelConfig(2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="causal"),
    "mask_1_2": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal"),
    "mask_2_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="causal_bottom_right"),
    "mask_2_1": ModelConfig(
        2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="causal_bottom_right"
    ),
    "mask_2_2": ModelConfig(
        2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal_bottom_right"
    ),
    "mask_3_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding"),
    "mask_3_1": ModelConfig(2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding"),
    "mask_3_2": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding"),
    "mask_4_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding_causal"),
    "mask_4_1": ModelConfig(2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding_causal"),
    "mask_4_2": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding_causal"),
    "mask_5_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding_causal_bottom_right"),
    "mask_5_1": ModelConfig(
        2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding_causal_bottom_right"
    ),
    "mask_5_2": ModelConfig(
        2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding_causal_bottom_right"
    ),
    "mask_6_0": ModelConfig(2, 1, 16, 128, max_seqlen_kv=2048, attn_mask_type="causal"),
    "mask_6_1": ModelConfig(2, 1, 16, 256, max_seqlen_kv=2048, attn_mask_type="causal"),
    "mask_7_0": ModelConfig(
        2, 1, 16, 128, max_seqlen_kv=2048, attn_mask_type="causal_bottom_right"
    ),
    "mask_7_1": ModelConfig(
        2, 1, 16, 256, max_seqlen_kv=2048, attn_mask_type="causal_bottom_right"
    ),
    "mask_8_0": ModelConfig(2, 1, 24, 128, max_seqlen_kv=2048, attn_mask_type="padding"),
    "mask_8_1": ModelConfig(2, 1, 16, 256, max_seqlen_kv=2048, attn_mask_type="padding"),
    "mask_9_0": ModelConfig(2, 1, 24, 128, max_seqlen_kv=2048, attn_mask_type="padding_causal"),
    "mask_9_1": ModelConfig(2, 1, 16, 256, max_seqlen_kv=2048, attn_mask_type="padding_causal"),
    "mask_10_0": ModelConfig(
        2, 1, 24, 128, max_seqlen_kv=2048, attn_mask_type="padding_causal_bottom_right"
    ),
    "mask_10_1": ModelConfig(
        2, 1, 16, 256, max_seqlen_kv=2048, attn_mask_type="padding_causal_bottom_right"
    ),
}
514

515

516
@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_mask])
@pytest.mark.parametrize("model", model_configs_mask.keys())
def test_dpa_mask(dtype, model_configs, model):
    """Compare backends across the attention mask types in model_configs_mask."""
    test_dot_product_attention(
        dtype=dtype,
        model_configs=model_configs,
        model=model,
        ckpt_attn=False,
        workspace_opt=True,
        qkv_layout=None,
        swa=False,
        pad_between_seqs=False,
    )
523

524

525
# Configurations for test_dpa_bias.
# Naming: bias_<G>_<K>. Group 1 passes no attn_mask_type, group 2 uses "padding",
# group 3 "causal" and group 4 "padding_causal". Within each group, entries _0.._3 use
# "post_scale_bias" and entries _4.._5 use "alibi".
model_configs_bias = {
    # test: ModelConfig(b, sq, hq, dqk)
    "bias_1_0": ModelConfig(4, 128, 16, 64, attn_bias_type="post_scale_bias"),
    "bias_1_1": ModelConfig(2, 128, 16, 64, max_seqlen_kv=256, attn_bias_type="post_scale_bias"),
    "bias_1_2": ModelConfig(4, 2048, 24, 128, attn_bias_type="post_scale_bias"),
    "bias_1_3": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_bias_type="post_scale_bias"),
    "bias_1_4": ModelConfig(4, 2048, 24, 128, attn_bias_type="alibi"),
    "bias_1_5": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_bias_type="alibi"),
    "bias_2_0": ModelConfig(
        4, 128, 16, 64, attn_mask_type="padding", attn_bias_type="post_scale_bias"
    ),
    "bias_2_1": ModelConfig(
        2,
        128,
        16,
        64,
        max_seqlen_kv=256,
        attn_mask_type="padding",
        attn_bias_type="post_scale_bias",
    ),
    "bias_2_2": ModelConfig(
        4, 2048, 24, 128, attn_mask_type="padding", attn_bias_type="post_scale_bias"
    ),
    "bias_2_3": ModelConfig(
        2,
        2048,
        24,
        128,
        max_seqlen_kv=4096,
        attn_mask_type="padding",
        attn_bias_type="post_scale_bias",
    ),
    "bias_2_4": ModelConfig(4, 2048, 24, 128, attn_mask_type="padding", attn_bias_type="alibi"),
    "bias_2_5": ModelConfig(
        2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding", attn_bias_type="alibi"
    ),
    "bias_3_0": ModelConfig(
        4, 128, 16, 64, attn_mask_type="causal", attn_bias_type="post_scale_bias"
    ),
    "bias_3_1": ModelConfig(
        2, 128, 16, 64, max_seqlen_kv=256, attn_mask_type="causal", attn_bias_type="post_scale_bias"
    ),
    "bias_3_2": ModelConfig(
        4, 2048, 24, 128, attn_mask_type="causal", attn_bias_type="post_scale_bias"
    ),
    "bias_3_3": ModelConfig(
        2,
        2048,
        24,
        128,
        max_seqlen_kv=4096,
        attn_mask_type="causal",
        attn_bias_type="post_scale_bias",
    ),
    "bias_3_4": ModelConfig(4, 2048, 24, 128, attn_mask_type="causal", attn_bias_type="alibi"),
    "bias_3_5": ModelConfig(
        2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal", attn_bias_type="alibi"
    ),
    "bias_4_0": ModelConfig(
        4, 128, 16, 64, attn_mask_type="padding_causal", attn_bias_type="post_scale_bias"
    ),
    "bias_4_1": ModelConfig(
        2,
        128,
        16,
        64,
        max_seqlen_kv=256,
        attn_mask_type="padding_causal",
        attn_bias_type="post_scale_bias",
    ),
    "bias_4_2": ModelConfig(
        4, 2048, 24, 128, attn_mask_type="padding_causal", attn_bias_type="post_scale_bias"
    ),
    "bias_4_3": ModelConfig(
        2,
        2048,
        24,
        128,
        max_seqlen_kv=4096,
        attn_mask_type="padding_causal",
        attn_bias_type="post_scale_bias",
    ),
    "bias_4_4": ModelConfig(
        4, 2048, 24, 128, attn_mask_type="padding_causal", attn_bias_type="alibi"
    ),
    "bias_4_5": ModelConfig(
        2,
        2048,
        24,
        128,
        max_seqlen_kv=4096,
        attn_mask_type="padding_causal",
        attn_bias_type="alibi",
    ),
}
620

621

622
@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_bias])
@pytest.mark.parametrize("model", model_configs_bias.keys())
def test_dpa_bias(dtype, model_configs, model):
    """Compare backends across the attention bias types in model_configs_bias."""
    test_dot_product_attention(
        dtype=dtype,
        model_configs=model_configs,
        model=model,
        ckpt_attn=False,
        workspace_opt=True,
        qkv_layout=None,
        swa=False,
        pad_between_seqs=False,
    )
629

630

631
# Configurations for test_dpa_bias_shapes; each entry passes an explicit bias_shape
# ("11ss", "1hss", "b1ss" or "bhss"). Keys reuse the "bias_1_*" names from
# model_configs_bias but live in this separate dict.
model_configs_bias_shapes = {
    # test: ModelConfig(b, sq, hq, dqk)
    "bias_1_0": ModelConfig(4, 128, 16, 64, attn_bias_type="post_scale_bias", bias_shape="11ss"),
    "bias_1_1": ModelConfig(2, 128, 16, 64, attn_bias_type="post_scale_bias", bias_shape="1hss"),
    "bias_1_2": ModelConfig(4, 2048, 24, 128, attn_bias_type="post_scale_bias", bias_shape="b1ss"),
    "bias_1_3": ModelConfig(2, 2048, 24, 128, attn_bias_type="post_scale_bias", bias_shape="bhss"),
    "bias_1_4": ModelConfig(
        4,
        2048,
        24,
        128,
        attn_mask_type="causal",
        attn_bias_type="alibi",
        bias_shape="1hss",
        alibi_type="custom",
    ),
    "bias_1_5": ModelConfig(
        2,
        2048,
        24,
        128,
        attn_mask_type="causal",
        attn_bias_type="alibi",
        bias_shape="bhss",
        alibi_type="custom",
    ),
}

659

660
@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_bias_shapes])
@pytest.mark.parametrize("model", model_configs_bias_shapes.keys())
def test_dpa_bias_shapes(dtype, model_configs, model):
    """Compare backends across bias type and bias shape combinations."""
    test_dot_product_attention(
        dtype=dtype,
        model_configs=model_configs,
        model=model,
        ckpt_attn=False,
        workspace_opt=True,
        qkv_layout=None,
        swa=False,
        pad_between_seqs=False,
    )

668

669
# Configurations for test_dpa_sliding_window. No entry sets window_size here; the test
# passes swa=True, and test_dot_product_attention then assigns a [2, 2] window whenever
# config.window_size equals (-1, -1).
model_configs_swa = {
    # test: ModelConfig(b, sq, hq, dqk)
    "swa_1_1": ModelConfig(2, 2048, 16, 64),
    "swa_1_2": ModelConfig(2, 2048, 24, 128, num_gqa_groups=4),
    "swa_1_3": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096),
    "swa_2_1": ModelConfig(2, 2048, 16, 64, attn_mask_type="causal"),
    "swa_2_2": ModelConfig(2, 2048, 24, 128, num_gqa_groups=4, attn_mask_type="causal"),
    "swa_2_3": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal"),
    "swa_3_1": ModelConfig(2, 2048, 16, 64, attn_mask_type="causal_bottom_right"),
    "swa_3_2": ModelConfig(
        2, 2048, 24, 128, num_gqa_groups=4, attn_mask_type="causal_bottom_right"
    ),
    "swa_3_3": ModelConfig(
        2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal_bottom_right"
    ),
    "swa_4_1": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding"),
    "swa_4_2": ModelConfig(2, 2048, 24, 128, num_gqa_groups=4, attn_mask_type="padding"),
    "swa_4_3": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding"),
    "swa_5_1": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding_causal"),
    "swa_5_2": ModelConfig(2, 2048, 24, 128, num_gqa_groups=4, attn_mask_type="padding_causal"),
    "swa_5_3": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding_causal"),
    "swa_6_1": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding_causal_bottom_right"),
    "swa_6_2": ModelConfig(
        2, 2048, 24, 128, num_gqa_groups=4, attn_mask_type="padding_causal_bottom_right"
    ),
    "swa_6_3": ModelConfig(
        2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding_causal_bottom_right"
    ),
}
698
699


700
@pytest.mark.skipif(not FlashAttentionUtils.v2_3_plus, reason="Flash-attn 2.3+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_swa])
@pytest.mark.parametrize("model", model_configs_swa.keys())
def test_dpa_sliding_window(dtype, model_configs, model):
    """Compare backends with sliding window attention enabled (swa=True)."""
    test_dot_product_attention(
        dtype=dtype,
        model_configs=model_configs,
        model=model,
        ckpt_attn=False,
        workspace_opt=True,
        qkv_layout=None,
        swa=True,
        pad_between_seqs=False,
    )

708

709
# Configurations for test_dpa_alibi_slopes: causal-mask ALiBi with alibi_type "vanilla"
# (alibi_1_*) and "custom" (alibi_2_*).
model_configs_alibi_slopes = {
    # test: ModelConfig(b, sq, hq, dqk)
    "alibi_1_0": ModelConfig(
        2, 128, 16, 64, attn_mask_type="causal", attn_bias_type="alibi", alibi_type="vanilla"
    ),
    "alibi_1_1": ModelConfig(
        1,
        128,
        16,
        64,
        max_seqlen_kv=256,
        attn_mask_type="causal",
        attn_bias_type="alibi",
        alibi_type="vanilla",
    ),
    "alibi_2_0": ModelConfig(
        2, 1024, 24, 128, attn_mask_type="causal", attn_bias_type="alibi", alibi_type="custom"
    ),
    "alibi_2_1": ModelConfig(
        1,
        1024,
        24,
        128,
        max_seqlen_kv=2048,
        attn_mask_type="causal",
        attn_bias_type="alibi",
        alibi_type="custom",
    ),
}
738
739


740
@pytest.mark.skipif(not FlashAttentionUtils.v2_3_plus, reason="Flash-attn 2.3+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_alibi_slopes])
@pytest.mark.parametrize("model", model_configs_alibi_slopes.keys())
def test_dpa_alibi_slopes(dtype, model_configs, model):
    """Run the DotProductAttention comparison over every config in ``model_configs_alibi_slopes``.

    The work is delegated to ``test_dot_product_attention``; this wrapper only
    supplies the parametrization and the fixed trailing flags.
    """
    # Trailing positional flags of test_dot_product_attention. The third one is the
    # QKV layout (None here) and the last one is pad_between_seqs.
    trailing_flags = (False, True, None, False, False)
    test_dot_product_attention(dtype, model_configs, model, *trailing_flags)
747

748

749
# QKV memory layouts exercised by test_dpa_qkv_layout. For each dimension order
# ("sb..." first, then "bs...") the list holds, in this order:
#   <p>3hd, <p>h3d, <p>hd_<p>2hd, <p>hd_<p>h2d, <p>hd_<p>hd_<p>hd
# e.g. "sb3hd", "sbh3d", "sbhd_sb2hd", "sbhd_sbh2d", "sbhd_sbhd_sbhd".
qkv_layouts = [
    template.format(p=prefix)
    for prefix in ("sb", "bs")
    for template in (
        "{p}3hd",
        "{p}h3d",
        "{p}hd_{p}2hd",
        "{p}hd_{p}h2d",
        "{p}hd_{p}hd_{p}hd",
    )
]
761

762

763
model_configs_layout = {
    # Positional arguments follow ModelConfig(b, sq, hq, dqk).
    # layout_0_*: sq=128, 16 heads, head dim 64.
    "layout_0_0": ModelConfig(2, 128, 16, 64),
    "layout_0_1": ModelConfig(
        2,
        128,
        16,
        64,
        attn_mask_type="causal",
        attn_bias_type="post_scale_bias",
    ),
    "layout_0_2": ModelConfig(1, 128, 16, 64, max_seqlen_kv=256, attn_mask_type="padding"),
    "layout_0_3": ModelConfig(
        1,
        128,
        16,
        64,
        max_seqlen_kv=256,
        attn_mask_type="padding_causal",
        attn_bias_type="post_scale_bias",
    ),
    # layout_1_*: sq=2048, 24 heads, head dim 128.
    "layout_1_0": ModelConfig(2, 2048, 24, 128),
    "layout_1_1": ModelConfig(
        2,
        2048,
        24,
        128,
        attn_mask_type="causal",
        attn_bias_type="post_scale_bias",
    ),
    "layout_1_2": ModelConfig(1, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding"),
    "layout_1_3": ModelConfig(
        1,
        2048,
        24,
        128,
        max_seqlen_kv=4096,
        attn_mask_type="padding_causal",
        attn_bias_type="post_scale_bias",
    ),
    # layout_2_*: head dim 256 (layout_2_0 uses a single query token against 2048 kv tokens).
    "layout_2_0": ModelConfig(2, 1, 16, 256, max_seqlen_kv=2048),
    "layout_2_1": ModelConfig(
        2,
        2048,
        24,
        256,
        attn_mask_type="causal",
        attn_bias_type="post_scale_bias",
    ),
}

799

800
@pytest.mark.skipif(get_cudnn_version() < (8, 9, 5), reason="cuDNN 8.9.5+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_layout])
@pytest.mark.parametrize("model", model_configs_layout.keys())
@pytest.mark.parametrize("qkv_layout", qkv_layouts)
def test_dpa_qkv_layout(dtype, model_configs, model, qkv_layout):
    """Run the DotProductAttention comparison for each config/QKV-layout combination.

    Configs come from ``model_configs_layout`` and layouts from ``qkv_layouts``;
    the comparison itself is performed by ``test_dot_product_attention``.
    """
    # Trailing positional flags of test_dot_product_attention. The third one is the
    # parametrized QKV layout and the last one is pad_between_seqs.
    trailing_flags = (False, True, qkv_layout, False, False)
    test_dot_product_attention(dtype, model_configs, model, *trailing_flags)


810
# Packed ("thd") QKV layouts exercised by test_dpa_qkv_layout_thd.
qkv_layouts_thd = [
    "t3hd",
    "th3d",
    "thd_t2hd",
    "thd_th2d",
    "thd_thd_thd",
]
model_configs_layout_thd = {
    # Positional arguments follow ModelConfig(b, sq, hq, dqk).
    # Within each group the three entries are: the base shape, a num_gqa_groups=1
    # variant, and a max_seqlen_kv=4096 variant.
    # layout_0_*: padding mask.
    "layout_0_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding"),
    "layout_0_1": ModelConfig(2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding"),
    "layout_0_2": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding"),
    # layout_1_*: padding_causal mask.
    "layout_1_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding_causal"),
    "layout_1_1": ModelConfig(2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding_causal"),
    "layout_1_2": ModelConfig(
        2,
        2048,
        24,
        128,
        max_seqlen_kv=4096,
        attn_mask_type="padding_causal",
    ),
    # layout_2_*: padding_causal_bottom_right mask.
    "layout_2_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding_causal_bottom_right"),
    "layout_2_1": ModelConfig(
        2,
        2048,
        24,
        128,
        num_gqa_groups=1,
        attn_mask_type="padding_causal_bottom_right",
    ),
    "layout_2_2": ModelConfig(
        2,
        2048,
        24,
        128,
        max_seqlen_kv=4096,
        attn_mask_type="padding_causal_bottom_right",
    ),
    # layout_3_*: padding mask with window_size=(4, 4).
    "layout_3_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding", window_size=(4, 4)),
    "layout_3_1": ModelConfig(
        2,
        2048,
        24,
        128,
        num_gqa_groups=1,
        attn_mask_type="padding",
        window_size=(4, 4),
    ),
    "layout_3_2": ModelConfig(
        2,
        2048,
        24,
        128,
        max_seqlen_kv=4096,
        attn_mask_type="padding",
        window_size=(4, 4),
    ),
    # layout_4_*: padding_causal mask with window_size=(4, 0).
    "layout_4_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding_causal", window_size=(4, 0)),
    "layout_4_1": ModelConfig(
        2,
        2048,
        24,
        128,
        num_gqa_groups=1,
        attn_mask_type="padding_causal",
        window_size=(4, 0),
    ),
    "layout_4_2": ModelConfig(
        2,
        2048,
        24,
        128,
        max_seqlen_kv=4096,
        attn_mask_type="padding_causal",
        window_size=(4, 0),
    ),
    # layout_5_*: padding_causal_bottom_right mask with window_size=(4, 0).
    "layout_5_0": ModelConfig(
        2,
        2048,
        16,
        64,
        attn_mask_type="padding_causal_bottom_right",
        window_size=(4, 0),
    ),
    "layout_5_1": ModelConfig(
        2,
        2048,
        24,
        128,
        num_gqa_groups=1,
        attn_mask_type="padding_causal_bottom_right",
        window_size=(4, 0),
    ),
    "layout_5_2": ModelConfig(
        2,
        2048,
        24,
        128,
        max_seqlen_kv=4096,
        attn_mask_type="padding_causal_bottom_right",
        window_size=(4, 0),
    ),
}


866
867
868
869
@pytest.mark.skipif(get_cudnn_version() < (9, 0, 0), reason="cuDNN 9.0.0+ is required.")
@pytest.mark.skipif(
    get_device_compute_capability() < (9, 0), reason="THD is only supported on Hopper+."
)
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_layout_thd])
@pytest.mark.parametrize("model", model_configs_layout_thd.keys())
@pytest.mark.parametrize("qkv_layout", qkv_layouts_thd)
def test_dpa_qkv_layout_thd(dtype, model_configs, model, qkv_layout):
    """Run the DotProductAttention comparison for packed (THD) QKV layouts.

    The comparison always runs with padding between sequences, and additionally
    without it when cuDNN is 9.3.0 or newer.
    """
    config = model_configs[model]
    # Layouts containing "3" interleave Q, K and V in one tensor, which only works
    # when K/V have as many heads as Q.
    uses_packed_qkv = "3" in qkv_layout
    if uses_packed_qkv and config.num_gqa_groups != config.num_heads:
        pytest.skip("qkv_layout not applicable for MQA/GQA")

    def compare_backends(pad_between_seqs):
        test_dot_product_attention(
            dtype, model_configs, model, False, True, qkv_layout, False, pad_between_seqs
        )

    logging.info("[test_dpa_qkv_layout_thd]: pad_between_seqs = True")
    compare_backends(True)

    # cuDNN 9.3.0+ is required to run pad_between_seqs = False/True in the same run
    if (9, 3, 0) <= get_cudnn_version():
        logging.info("[test_dpa_qkv_layout_thd]: pad_between_seqs = False")
        compare_backends(False)
891

892

893
def _run_dot_product_attention(
    dtype: torch.dtype,
    config: ModelConfig,
    backend: str,
    ckpt_attn: bool,
    qkv_layout: str,
    workspace_opt: bool,
    pad_between_seqs: bool,
    is_training: bool,
) -> Tuple[torch.Tensor, Any, Tuple[Any, Any, Any, Any]]:
    """Run DotProductAttention module with one forward pass and one backward pass.

    Args:
        dtype: dtype of the generated inputs, output gradient and module.
        config: model configuration (sizes, mask/bias type, window size, ...).
        backend: one of "FlashAttention", "FusedAttention" or
            "UnfusedDotProductAttention"; selects which NVTE_* switch is set to "1".
        ckpt_attn: forwarded as ``checkpoint_core_attention``.
        qkv_layout: layout string such as "sbhd_sbhd_sbhd" or "t3hd".
        workspace_opt: value written to NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT
            (only when ``backend == "FusedAttention"``).
        pad_between_seqs: for the "thd" format, insert up to 3 zeroed padding
            tokens after every sequence.
        is_training: run the backward pass; otherwise the module is put in eval mode.

    Returns:
        ``(out, max_logit, (dq, dk, dv, d_softmax_offset))``. ``max_logit`` is
        ``None`` unless ``config.return_max_logit`` is set; the three input
        gradients are ``None`` when ``is_training`` is false, and
        ``d_softmax_offset`` is ``None`` unless training with a non-"vanilla"
        softmax type. For FusedAttention with "thd" and ``pad_between_seqs``, the
        padding tokens are stripped from ``out`` and the gradients before returning.
    """
    # Set RNG and environment variables
    reset_rng_states()
    os.environ["NVTE_FLASH_ATTN"] = "0"
    os.environ["NVTE_FUSED_ATTN"] = "0"
    os.environ["NVTE_UNFUSED_ATTN"] = "0"
    if backend == "FlashAttention":
        os.environ["NVTE_FLASH_ATTN"] = "1"
    if backend == "FusedAttention":
        os.environ["NVTE_FUSED_ATTN"] = "1"
        os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] = "1" if workspace_opt else "0"
    if backend == "UnfusedDotProductAttention":
        os.environ["NVTE_UNFUSED_ATTN"] = "1"
    # Flag that the environment changed so the backend choice is re-evaluated.
    _attention_backends["backend_selection_requires_update"] = True

    # Create seqlens
    # e.g. "sb3hd" -> "sbhd", "thd_t2hd" -> "thd"
    qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()])
    if "padding" in config.attn_mask_type or qkv_format == "thd":
        if config.attn_type == "self":
            seqlens_q = torch.randint(
                1, config.max_seqlen_q, [config.batch_size], dtype=torch.int32, device="cuda"
            )
            seqlens_kv = seqlens_q
        if config.attn_type == "cross":
            if config.max_seqlen_q > 1:
                seqlens_q = torch.randint(
                    1, config.max_seqlen_q, [config.batch_size], dtype=torch.int32, device="cuda"
                )
            else:
                # randint(1, 1) has an empty range, so a single-token query is handled explicitly
                seqlens_q = torch.ones([config.batch_size], dtype=torch.int32, device="cuda")
            seqlens_kv = torch.randint(
                1, config.max_seqlen_kv, [config.batch_size], dtype=torch.int32, device="cuda"
            )
    else:
        seqlens_q = torch.full(
            [config.batch_size], config.max_seqlen_q, dtype=torch.int32, device="cuda"
        )
        seqlens_kv = torch.full(
            [config.batch_size], config.max_seqlen_kv, dtype=torch.int32, device="cuda"
        )
    cu_seqlens_q = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda")
    cu_seqlens_kv = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda")
    cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)
    cu_seqlens_kv[1:] = torch.cumsum(seqlens_kv, dim=0)

    # "*_after_pad" variants describe the packed buffers once padding tokens are
    # inserted after each sequence; without padding they equal the plain versions.
    seqlens_q_after_pad = seqlens_q.clone()
    seqlens_kv_after_pad = seqlens_kv.clone()
    cu_seqlens_q_after_pad = cu_seqlens_q.clone()
    cu_seqlens_kv_after_pad = cu_seqlens_kv.clone()
    pad_len = [0] * config.batch_size
    if pad_between_seqs:
        max_pad_len = 3
        # Per-sequence pad length in [0, max_pad_len]; the same value is used for q and kv.
        pad_len = torch.randint(0, max_pad_len + 1, [config.batch_size], device="cuda")
        seqlens_q_after_pad = seqlens_q + pad_len
        seqlens_kv_after_pad = seqlens_kv + pad_len
        cu_seqlens_q_after_pad[1:] = torch.cumsum(seqlens_q_after_pad, dim=0)
        cu_seqlens_kv_after_pad[1:] = torch.cumsum(seqlens_kv_after_pad, dim=0)

    def make_padding_mask(seqlens: torch.Tensor, max_seqlen: int) -> torch.Tensor:
        """Return a [batch, 1, 1, max_seqlen] bool mask, True where position >= seqlen."""
        positions = torch.arange(max_seqlen, device=seqlens.device)
        return positions.view(1, 1, 1, -1) >= seqlens.view(-1, 1, 1, 1)

    # Create attention mask if padding
    attention_mask = None
    if "padding" in config.attn_mask_type:
        if config.attn_type == "self":
            attention_mask = make_padding_mask(seqlens_q, config.max_seqlen_q)
        if config.attn_type == "cross":
            attention_mask = (
                make_padding_mask(seqlens_q, config.max_seqlen_q),
                make_padding_mask(seqlens_kv, config.max_seqlen_kv),
            )

    alibi_slopes = None
    if config.attn_bias_type == "alibi" and config.alibi_type == "custom":
        if config.bias_shape == "1hss":
            alibi_slopes = (
                torch.randn(config.num_heads).abs().to(dtype=torch.float32, device="cuda")
            )
        if config.bias_shape == "bhss":
            alibi_slopes = (
                torch.randn(config.batch_size, config.num_heads)
                .abs()
                .to(dtype=torch.float32, device="cuda")
            )

    # Create input tensors
    # Maps every symbolic dimension used in the expanded layout strings to its size.
    dim_to_num = {
        "b": config.batch_size,
        "sq": config.max_seqlen_q,
        "skv": config.max_seqlen_kv,
        "h": config.num_heads,
        "hg": config.num_gqa_groups,
        "dqk": config.head_dim_qk,
        "dv": config.head_dim_v,
        "t": cu_seqlens_q_after_pad[-1],
        "tg": cu_seqlens_kv_after_pad[-1],
        "3": 3,
        "2": 2,
        "1": 1,
    }
    inp = []
    inp_orig = []
    for i, layout in enumerate(qkv_layout.split("_")):
        # "sb3hd" -> "s_b_3_h_d", then rename dims to their q / kv specific keys
        layout = "_".join(layout)
        if i == 0:
            layout = layout.replace("s", "sq")
        else:
            layout = layout.replace("s", "skv")
            layout = layout.replace("h", "hg")
            layout = layout.replace("t", "tg")
        if i == 2:
            layout = layout.replace("d", "dv")
        else:
            layout = layout.replace("d", "dqk")
        tensor_shape = [dim_to_num[j] for j in layout.split("_")]
        tensor = 0.1 * torch.randn(tensor_shape, dtype=dtype, device="cuda")
        # tensor: with padding tokens
        # tensor_orig: without padding tokens
        tensor_orig = tensor
        if qkv_format == "thd" and pad_between_seqs:
            tensor_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
            # The inner loops use their own index (`seq`) so the outer `i` is not clobbered.
            if layout in ["t_h_dqk", "t_3_h_dqk", "t_h_3_dqk"]:
                for seq in range(1, config.batch_size + 1):
                    valid_range = (
                        cu_seqlens_q_after_pad[seq - 1],
                        cu_seqlens_q_after_pad[seq] - pad_len[seq - 1],
                    )
                    pad_range = (
                        cu_seqlens_q_after_pad[seq] - pad_len[seq - 1],
                        cu_seqlens_q_after_pad[seq],
                    )
                    tensor[pad_range[0] : pad_range[1]] = 0.0
                    tensor_orig = torch.cat(
                        [tensor_orig, tensor[valid_range[0] : valid_range[1]]], dim=0
                    )
            if layout in ["tg_hg_dqk", "tg_2_hg_dqk", "tg_hg_2_dqk", "tg_hg_dv"]:
                for seq in range(1, config.batch_size + 1):
                    valid_range = (
                        cu_seqlens_kv_after_pad[seq - 1],
                        cu_seqlens_kv_after_pad[seq] - pad_len[seq - 1],
                    )
                    pad_range = (
                        cu_seqlens_kv_after_pad[seq] - pad_len[seq - 1],
                        cu_seqlens_kv_after_pad[seq],
                    )
                    tensor[pad_range[0] : pad_range[1]] = 0.0
                    tensor_orig = torch.cat(
                        [tensor_orig, tensor[valid_range[0] : valid_range[1]]], dim=0
                    )
        # A digit in the layout ("3" or "2") marks a dimension that interleaves
        # several of q/k/v; split along it to recover the individual tensors.
        tensor_count = 1
        split_dim = 0
        for dim, dim_name in enumerate(layout.split("_")):
            if dim_name.isdigit():
                tensor_count = int(dim_name)
                split_dim = dim
                break
        tensors = torch.split(tensor, 1, dim=split_dim) if split_dim != 0 else [tensor]
        tensors_orig = (
            torch.split(tensor_orig, 1, dim=split_dim) if split_dim != 0 else [tensor_orig]
        )
        for j in range(tensor_count):
            if split_dim != 0:
                inp.append(tensors[j].squeeze(split_dim))
                inp_orig.append(tensors_orig[j].squeeze(split_dim))
            else:
                inp.append(tensors[j])
                inp_orig.append(tensors_orig[j])
    for i in range(3):
        inp[i].requires_grad = True
        inp_orig[i].requires_grad = True

    # Create output gradient
    qkv_format_kv = "_".join(qkv_format)
    qkv_format_kv = qkv_format_kv.replace("s", "sq")
    qkv_format_kv = qkv_format_kv.replace("d", "dv")
    out_grad_shape = [dim_to_num[i] for i in qkv_format_kv.split("_")]
    # The output gradient has the head and head-dim axes merged into one.
    out_grad_shape_new = [*out_grad_shape[:-2], out_grad_shape[-2] * out_grad_shape[-1]]
    out_grad = 0.001 * torch.randint(0, 200, out_grad_shape_new, dtype=dtype, device="cuda")
    out_grad_orig = out_grad
    if qkv_format == "thd" and pad_between_seqs:
        out_grad_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
        if qkv_format_kv == "t_h_dv":
            for seq in range(1, config.batch_size + 1):
                valid_range = (
                    cu_seqlens_q_after_pad[seq - 1],
                    cu_seqlens_q_after_pad[seq] - pad_len[seq - 1],
                )
                pad_range = (
                    cu_seqlens_q_after_pad[seq] - pad_len[seq - 1],
                    cu_seqlens_q_after_pad[seq],
                )
                out_grad[pad_range[0] : pad_range[1]] = 0.0
                out_grad_orig = torch.cat(
                    [out_grad_orig, out_grad[valid_range[0] : valid_range[1]]], dim=0
                )

    # Create bias
    if config.attn_bias_type in ["no_bias", "alibi"]:
        bias = None
    if config.attn_bias_type == "post_scale_bias":
        shape = "_".join(config.bias_shape)
        shape = shape.replace("_s_s", "_sq_skv")
        tensor_shape = [dim_to_num[j] for j in shape.split("_")]
        bias = torch.randn(tensor_shape, dtype=dtype, device="cuda")
        if config.bias_shape != "1hss":
            bias.requires_grad = False

    # Create RNG
    _DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
    _DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed)

    def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
        """Get cuda rng tracker."""
        return _DUMMY_CUDA_RNG_STATE_TRACKER

    # Set up model
    block = DotProductAttention(
        config.num_heads,
        (config.head_dim_qk, config.head_dim_v),
        num_gqa_groups=config.num_gqa_groups,
        attention_dropout=config.dropout_p,
        qkv_format=qkv_format,
        attn_mask_type=config.attn_mask_type,
        sequence_parallel=False,
        tp_size=1,
        get_rng_state_tracker=get_dummy_cuda_rng_tracker,
        tp_group=None,
        layer_number=1,
        attention_type=config.attn_type,
        softmax_type=config.softmax_type,
        return_max_logit=config.return_max_logit,
    ).to(dtype=dtype, device="cuda")
    if not is_training:
        block = block.eval()
    if is_training and config.softmax_type != "vanilla":
        block.softmax_offset.requires_grad = True

    # Run a forward and backward pass
    # Only FusedAttention is fed the padded tensors (together with the padded
    # cu_seqlens below); the other backends get the de-padded copies.
    if backend in ["FlashAttention", "UnfusedDotProductAttention"]:
        q = inp_orig[0]
        k = inp_orig[1]
        v = inp_orig[2]
        d_out = out_grad_orig
    if backend == "FusedAttention":
        q = inp[0]
        k = inp[1]
        v = inp[2]
        d_out = out_grad
    out = block(
        q,
        k,
        v,
        window_size=config.window_size,
        attention_mask=attention_mask,
        qkv_format=qkv_format,
        max_seqlen_q=config.max_seqlen_q,
        max_seqlen_kv=config.max_seqlen_kv,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_kv=cu_seqlens_kv,
        cu_seqlens_q_padded=cu_seqlens_q_after_pad if backend == "FusedAttention" else None,
        cu_seqlens_kv_padded=cu_seqlens_kv_after_pad if backend == "FusedAttention" else None,
        attn_mask_type=config.attn_mask_type,
        checkpoint_core_attention=ckpt_attn,
        core_attention_bias_type=config.attn_bias_type,
        core_attention_bias=bias,
        alibi_slopes=alibi_slopes,
        fast_zero_fill=True,
        # Only pass num_splits when exercising the FlashAttention path
        num_splits=config.num_splits if backend == "FlashAttention" else 1,
    )
    max_logit = None
    if config.return_max_logit:
        out, max_logit = out
    if is_training:
        out.backward(d_out)

    d_softmax_offset = None
    if is_training and config.softmax_type != "vanilla":
        d_softmax_offset = block.softmax_offset.grad

    if backend in ["FlashAttention", "UnfusedDotProductAttention"]:
        if is_training:
            return out, max_logit, (q.grad, k.grad, v.grad, d_softmax_offset)
        else:
            return out, max_logit, (None, None, None, d_softmax_offset)
    if backend == "FusedAttention":
        if qkv_format == "thd" and pad_between_seqs:
            # Strip the padding tokens so results line up with the de-padded backends.
            out_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
            if is_training:
                q_grad_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
                k_grad_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
                v_grad_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
            for seq in range(1, config.batch_size + 1):
                valid_range_q = (
                    cu_seqlens_q_after_pad[seq - 1],
                    cu_seqlens_q_after_pad[seq] - pad_len[seq - 1],
                )
                valid_range_kv = (
                    cu_seqlens_kv_after_pad[seq - 1],
                    cu_seqlens_kv_after_pad[seq] - pad_len[seq - 1],
                )
                out_orig = torch.cat([out_orig, out[valid_range_q[0] : valid_range_q[1]]], dim=0)
                if is_training:
                    q_grad_orig = torch.cat(
                        [q_grad_orig, q.grad[valid_range_q[0] : valid_range_q[1]]], dim=0
                    )
                    k_grad_orig = torch.cat(
                        [k_grad_orig, k.grad[valid_range_kv[0] : valid_range_kv[1]]], dim=0
                    )
                    v_grad_orig = torch.cat(
                        [v_grad_orig, v.grad[valid_range_kv[0] : valid_range_kv[1]]], dim=0
                    )
            if is_training:
                return (
                    out_orig,
                    max_logit,
                    (q_grad_orig, k_grad_orig, v_grad_orig, d_softmax_offset),
                )
            else:
                return out_orig, max_logit, (None, None, None, d_softmax_offset)
        else:
            if is_training:
                return out, max_logit, (q.grad, k.grad, v.grad, d_softmax_offset)
            else:
                return out, max_logit, (None, None, None, d_softmax_offset)
1267

1268

1269
model_configs_te_layer = {
    # Positional arguments follow ModelConfig(b, sq, hq, dqk).
    # te_1_*: sq=128, mostly with post_scale_bias.
    "te_1_0": ModelConfig(2, 128, 16, 64, attn_bias_type="post_scale_bias"),
    "te_1_1": ModelConfig(
        4,
        128,
        16,
        64,
        attn_mask_type="causal",
        attn_bias_type="post_scale_bias",
    ),
    "te_1_2": ModelConfig(
        2,
        128,
        16,
        64,
        attn_mask_type="padding",
        attn_bias_type="post_scale_bias",
    ),
    "te_1_3": ModelConfig(2, 128, 16, 64, max_seqlen_kv=256, attn_mask_type="padding"),
    # te_2_*: sq=2048, no bias type given.
    "te_2_0": ModelConfig(1, 2048, 16, 64, attn_mask_type="causal"),
    "te_2_1": ModelConfig(2, 2048, 16, 64),
    "te_2_2": ModelConfig(1, 2048, 16, 64, attn_mask_type="padding"),
    "te_2_3": ModelConfig(
        1,
        2048,
        16,
        64,
        max_seqlen_kv=4096,
        attn_mask_type="padding_causal_bottom_right",
    ),
    # te_3_*: causal mask with ALiBi bias.
    "te_3_0": ModelConfig(4, 128, 16, 64, attn_mask_type="causal", attn_bias_type="alibi"),
    "te_3_1": ModelConfig(4, 2048, 16, 64, attn_mask_type="causal", attn_bias_type="alibi"),
}
1288

1289

1290
@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model_configs", [model_configs_te_layer])
@pytest.mark.parametrize("model", model_configs_te_layer.keys())
@pytest.mark.parametrize("ckpt_attn", [False])
@pytest.mark.parametrize("qkv_format", ["sbhd", "bshd", "thd"])
@pytest.mark.parametrize("fused_qkv_params", [False])
@pytest.mark.parametrize("RoPE", [False])
def test_transformer_layer(
    dtype, model_configs, model, ckpt_attn, qkv_format, fused_qkv_params, RoPE
):
    """Compare TransformerLayer results across the available attention backends.

    Each supported backend (unfused, fused, flash) runs the same layer through
    ``_run_transformer_layer``; every pair of results must agree in forward and
    backward outputs within ``atol = rtol = 5e-2``.
    """
    # Get configs
    config = model_configs[model]
    tols = {"atol": 5e-2, "rtol": 5e-2}
    workspace_opt = True

    # Layout handed to the backend query, e.g. "sbhd" -> "sb3hd" (or "sbh3d" with fused params).
    packed_dims = "h3d" if fused_qkv_params else "3hd"
    qkv_layout = qkv_format.replace("hd", packed_dims)

    def query_backends(training):
        return get_available_attention_backends(
            config,
            qkv_dtype=dtype,
            qkv_layout=qkv_layout,
            is_training=training,
            deterministic=_deterministic,
        )

    # Test backend availability; fall back to inference mode if fused attention
    # is not available for training.
    is_training = True
    available_backends, _, fused_attn_backends = query_backends(is_training)
    flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
    if not fused_attn_supported:
        is_training = False
        available_backends, _, fused_attn_backends = query_backends(is_training)
        flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends

    # Skip if there are fewer than two backends to compare.
    # NOTE(review): this counts each fused sub-backend separately, while only one
    # fused run happens below -- confirm that is intended.
    num_candidates = len(fused_attn_backends) + flash_attn_supported + unfused_attn_supported
    if num_candidates < 2:
        pytest.skip("Less than two backends to compare.")
    # Skip if qkv_format = thd and "padding" not in attn_mask_type
    if qkv_format == "thd" and "padding" not in config.attn_mask_type:
        pytest.skip("THD requires padding mask.")

    def run_layer(backend):
        return _run_transformer_layer(
            dtype,
            config,
            backend,
            ckpt_attn,
            qkv_format,
            workspace_opt,
            fused_qkv_params,
            RoPE,
            is_training,
        )

    # Run every supported backend, in the order unfused -> fused -> flash.
    results = {}
    if unfused_attn_supported:
        unfused_fwd, unfused_bwd = run_layer("UnfusedDotProductAttention")
        results["UnfusedDotProductAttention"] = (unfused_fwd, unfused_bwd)
    if fused_attn_supported:
        fused_fwd, fused_bwd = run_layer("FusedAttention")
        results["FusedAttention"] = (fused_fwd, fused_bwd)
    if flash_attn_supported:
        flash_fwd, flash_bwd = run_layer("FlashAttention")
        results["FlashAttention"] = (flash_fwd, flash_bwd)

    logging.info(f"[test_transformer_layer]: is_training = {is_training}")
    # (actual, expected, log message) for every pairwise comparison.
    comparisons = (
        (
            "FusedAttention",
            "UnfusedDotProductAttention",
            "[test_transformer_layer]: unfused attn vs fused attn",
        ),
        (
            "FlashAttention",
            "UnfusedDotProductAttention",
            "[test_transformer_layer]: unfused attn vs flash attn",
        ),
        (
            "FusedAttention",
            "FlashAttention",
            "[test_transformer_layer]: fused attn vs flash attn",
        ),
    )
    for actual_name, expected_name, message in comparisons:
        if actual_name not in results or expected_name not in results:
            continue
        logging.info(message)
        actual_fwd, actual_bwd = results[actual_name]
        expected_fwd, expected_bwd = results[expected_name]
        torch.testing.assert_close(actual_fwd, expected_fwd, **tols)
        torch.testing.assert_close(actual_bwd, expected_bwd, **tols)
1397

1398

1399
@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_te_layer])
@pytest.mark.parametrize("model", ["te_1_2", "te_2_0"])
@pytest.mark.parametrize("qkv_format", ["bshd", "sbhd"])
def test_te_layer_misc(dtype, model_configs, model, qkv_format):
    """Run the TransformerLayer backend comparison with the optional features switched on.

    Delegates to ``test_transformer_layer`` with attention checkpointing, fused
    QKV parameters and RoPE all enabled.
    """
    test_transformer_layer(
        dtype,
        model_configs,
        model,
        ckpt_attn=True,
        qkv_format=qkv_format,
        fused_qkv_params=True,
        RoPE=True,
    )
1412

1413

1414
@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_te_layer])
@pytest.mark.parametrize("model", ["te_2_0", "te_2_1", "te_2_2"])
def test_te_layer_mqa_gqa(dtype, model_configs, model):
    """Test TransformerLayer module with MQA/GQA.

    For every divisor ``d >= 2`` of ``config.num_heads`` the config is run through
    ``test_transformer_layer`` with ``num_gqa_groups = num_heads // d``; the last
    divisor (``d == num_heads``) leaves a single group.

    ``config`` is the shared module-level object from ``model_configs``, so its
    original ``num_gqa_groups`` is restored afterwards, even when the delegated
    test skips or fails.
    """

    def find_factors(x):
        """Return all divisors of ``x`` in ``[2, x]``, in ascending order."""
        return [i for i in range(2, x + 1) if x % i == 0]

    ckpt_attn = True
    qkv_format = "bshd"
    fused_qkv_params = True
    RoPE = True
    config = model_configs[model]
    num_querys_per_gqa_group = find_factors(config.num_heads)

    # Remember the shared config's value so this test does not leak state.
    original_num_gqa_groups = config.num_gqa_groups
    try:
        for num_q_per_gqa_group in num_querys_per_gqa_group:
            config.num_gqa_groups = config.num_heads // num_q_per_gqa_group
            test_transformer_layer(
                dtype, model_configs, model, ckpt_attn, qkv_format, fused_qkv_params, RoPE
            )
    finally:
        config.num_gqa_groups = original_num_gqa_groups
1440

1441

1442
def _run_transformer_layer(
    dtype: torch.dtype,
    config: ModelConfig,
    backend: str,
    ckpt_attn: bool,
    qkv_format: str,
    workspace_opt: bool,
    fused_qkv_params: bool,
    RoPE: bool,
    is_training: bool,
) -> Tuple[torch.Tensor, Union[torch.Tensor, None]]:
    """Run TransformerLayer module with one forward pass and one backward pass

    Args:
        dtype: dtype used for the random inputs, the bias and the layer parameters.
        config: model configuration; sizes, mask/bias types, dropout and
            attention type are read from it.
        backend: "FlashAttention", "FusedAttention" or
            "UnfusedDotProductAttention"; the matching NVTE_* environment
            variable is set to "1" and the other two to "0".
        ckpt_attn: accepted for signature compatibility; not read here
            (`checkpoint_core_attention` is passed as False).
        qkv_format: "sbhd", "bshd" or "thd"; decides the input tensor layout.
        workspace_opt: accepted for signature compatibility; not read here.
        fused_qkv_params: forwarded to the layer as `fuse_qkv_params`.
        RoPE: when True, a rotary position embedding is built and passed in.
        is_training: when False the layer is put in eval mode and the
            backward pass is skipped.

    Returns:
        `(out, inp.grad)`. The backward pass only runs when `is_training` is
        True, so `inp.grad` is None otherwise.

    Note:
        The NVTE_* environment variables written here are not restored.
    """

    # Set RNG and environment variables
    reset_rng_states()
    os.environ["NVTE_FLASH_ATTN"] = "0"
    os.environ["NVTE_FUSED_ATTN"] = "0"
    os.environ["NVTE_UNFUSED_ATTN"] = "0"
    if backend == "FlashAttention":
        os.environ["NVTE_FLASH_ATTN"] = "1"
    if backend == "FusedAttention":
        os.environ["NVTE_FUSED_ATTN"] = "1"
    if backend == "UnfusedDotProductAttention":
        os.environ["NVTE_UNFUSED_ATTN"] = "1"
    # Flag the cached backend choice as stale after changing the env vars
    _attention_backends["backend_selection_requires_update"] = True

    # Create input tensor
    if qkv_format == "sbhd":
        inp = torch.randn(
            config.max_seqlen_q,
            config.batch_size,
            config.hidden_size,
            dtype=dtype,
            device="cuda",
            requires_grad=True,
        )
        inp_enc = torch.randn(
            config.max_seqlen_kv,
            config.batch_size,
            config.hidden_size,
            dtype=dtype,
            device="cuda",
            requires_grad=True,
        )
    if qkv_format == "bshd":
        inp = torch.randn(
            config.batch_size,
            config.max_seqlen_q,
            config.hidden_size,
            dtype=dtype,
            device="cuda",
            requires_grad=True,
        )
        inp_enc = torch.randn(
            config.batch_size,
            config.max_seqlen_kv,
            config.hidden_size,
            dtype=dtype,
            device="cuda",
            requires_grad=True,
        )

    # Create seqlens
    if "padding" in config.attn_mask_type or qkv_format == "thd":
        if config.attn_type == "self":
            seqlens_q = torch.randint(
                1, config.max_seqlen_q, [config.batch_size], dtype=torch.int32, device="cuda"
            )
            seqlens_kv = seqlens_q
        if config.attn_type == "cross":
            # torch.randint needs low < high, so a max length of 1 is special-cased
            if config.max_seqlen_q > 1:
                seqlens_q = torch.randint(
                    1, config.max_seqlen_q, [config.batch_size], dtype=torch.int32, device="cuda"
                )
            else:
                seqlens_q = torch.ones([config.batch_size], dtype=torch.int32, device="cuda")
            seqlens_kv = torch.randint(
                1, config.max_seqlen_kv, [config.batch_size], dtype=torch.int32, device="cuda"
            )
    else:
        seqlens_q = torch.full(
            [config.batch_size], config.max_seqlen_q, dtype=torch.int32, device="cuda"
        )
        seqlens_kv = torch.full(
            [config.batch_size], config.max_seqlen_kv, dtype=torch.int32, device="cuda"
        )
    # Cumulative sequence lengths with a leading zero
    cu_seqlens_q = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda")
    cu_seqlens_kv = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda")
    cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)
    cu_seqlens_kv[1:] = torch.cumsum(seqlens_kv, dim=0)
    if qkv_format == "thd":
        # First dimension is the total token count across the batch
        inp = torch.randn(
            cu_seqlens_q[-1],
            config.hidden_size,
            dtype=dtype,
            device="cuda",
            requires_grad=True,
        )
        inp_enc = torch.randn(
            cu_seqlens_kv[-1],
            config.hidden_size,
            dtype=dtype,
            device="cuda",
            requires_grad=True,
        )

    sigma = 0.02
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    layer_number = 1
    drop_path_rate = 0.0
    drop_path_rates = [rate.item() for rate in torch.linspace(0, drop_path_rate, config.num_layers)]

    # Create bias
    bias = None
    if config.attn_bias_type == "post_scale_bias":
        bias = torch.randn(
            1,
            config.num_heads,
            config.max_seqlen_q,
            config.max_seqlen_kv,
            dtype=dtype,
            device="cuda",
        )

    # Create RoPE
    rotary_pos_emb = None
    if RoPE:
        PE = RotaryPositionEmbedding(dim=config.head_dim_qk)
        rotary_pos_emb = PE(config.max_seqlen_q).to(device="cuda")

    # Set up model
    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_heads,
        num_gqa_groups=config.num_gqa_groups,
        layernorm_epsilon=1e-5,
        hidden_dropout=0.0,
        attention_dropout=config.dropout_p,
        init_method=init_method,
        output_layer_init_method=output_layer_init_method,
        layer_number=layer_number,
        kv_channels=config.head_dim_qk,
        self_attn_mask_type=config.attn_mask_type,
        tp_group=None,
        tp_size=1,
        params_dtype=dtype,
        get_rng_state_tracker=None,
        fuse_wgrad_accumulation=False,
        seq_length=config.max_seqlen_q,
        micro_batch_size=config.batch_size,
        sequence_parallel=False,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
        layer_type="encoder" if config.attn_type == "self" else "decoder",
        drop_path_rate=drop_path_rates[layer_number - 1],
        set_parallel_mode=True,
        fuse_qkv_params=fused_qkv_params,
        zero_centered_gamma=False,
        qkv_weight_interleaved=False,
        ub_tp_comm_overlap=False,
        bias=True,
        attn_input_format=qkv_format,
    ).to(dtype=dtype, device="cuda")
    if not is_training:
        block = block.eval()

    # Create ALiBi slopes
    alibi_slopes = None
    if config.attn_bias_type == "alibi" and config.alibi_type == "custom":
        alibi_slopes = torch.randn(config.num_heads).abs().to(dtype=torch.float32, device="cuda")

    # Run a forward and backward pass
    out = block(
        inp,
        self_attn_mask_type=config.attn_mask_type,
        encoder_output=inp_enc if config.attn_type == "cross" else None,
        enc_dec_attn_mask_type=config.attn_mask_type if config.attn_type == "cross" else None,
        checkpoint_core_attention=False,
        rotary_pos_emb=rotary_pos_emb,
        core_attention_bias_type=config.attn_bias_type,
        core_attention_bias=bias,
        alibi_slopes=alibi_slopes,
        max_seqlen_q=config.max_seqlen_q,
        max_seqlen_kv=config.max_seqlen_kv,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_kv=cu_seqlens_kv,
    )
    if is_training:
        loss = out.sum()
        loss.backward()

    return out, inp.grad
1637
1638


1639
# Model configurations used by test_dpa_fp8_extra_state
model_configs_fp8_extra_state = {
    # test: ModelConfig(b, sq, hq, dqk)
    "large": ModelConfig(2, 128, 4, 128, num_layers=1),
}


1645
@pytest.mark.skipif(not fp8_attn_available, reason=reason_for_no_fp8_attn)
@pytest.mark.skipif(get_cudnn_version() < (9, 3, 0), reason="cuDNN 9.3.0+ is required.")
@pytest.mark.parametrize("model", ["large"])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_dpa_fp8_extra_state(model, dtype):
    """Test DotProductAttention module in FP8 with checkpointing

    Runs `_run_dpa_fp8_extra_state` three times (without checkpointing, with
    checkpointing, and with checkpointing using the v1.6-style state-dict key)
    and checks that the checkpointed runs match the uninterrupted run.
    """
    config = model_configs_fp8_extra_state[model]
    # Test backend availability
    is_training = True
    available_backends, _, fused_attn_backends = get_available_attention_backends(
        config,
        qkv_dtype=torch.float8_e4m3fn,
        qkv_layout="sb3hd",
        is_training=is_training,
        deterministic=_deterministic,
    )
    flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
    if not fused_attn_supported and not flash_attn_supported:
        pytest.skip("No attention backend available.")

    outputs = _run_dpa_fp8_extra_state(dtype, config, checkpoint=False)
    outputs_checkpoint = _run_dpa_fp8_extra_state(dtype, config, checkpoint=True)
    outputs_checkpoint_v1_6 = _run_dpa_fp8_extra_state(
        dtype, config, mimic_v1_6=True, checkpoint=True
    )

    # Check that results match
    tols = dtype_tols(dtype)
    if dtype in (torch.float16, torch.bfloat16):
        tols.update(dict(rtol=2e-2, atol=2e-3))
    for ref, test in zip(outputs, outputs_checkpoint):
        torch.testing.assert_close(
            test,
            ref,
            **tols,
        )
    for ref, test in zip(outputs, outputs_checkpoint_v1_6):
        torch.testing.assert_close(
            test,
            ref,
            **tols,
        )


1689
1690
def _run_dpa_fp8_extra_state(dtype, config, checkpoint=False, mimic_v1_6=False):
    """Run DotProductAttention module in FP8 with checkpointing

    Builds a TransformerLayer under FP8 delayed scaling and runs 10
    forward/backward steps on a fixed random input. When `checkpoint` is True,
    the layer's state dict is saved after the first 5 steps, a fresh layer is
    built and loaded from it, and the remaining steps run on the new layer.

    Args:
        dtype: dtype of the input and of the layer parameters.
        config: model configuration providing sizes and `num_layers`.
        checkpoint: save and reload the state dict halfway through.
        mimic_v1_6: before saving, move the core-attention extra state to the
            `...core_attention.fused_attention._extra_state` key.

    Returns:
        A list `[output, hidden_states.grad, *param_grads]` where `param_grads`
        are the gradients of the final layer's trainable parameters.
        Gradients are never zeroed between steps, so they are accumulated.
    """
    # Local import: the checkpoint is written to a private temporary directory
    # so that a failing run leaves no file behind and concurrent runs cannot
    # clobber each other's "checkpoint.pt".
    import tempfile

    steps = 10
    fp8_enabled = True
    fp8_recipe = recipe.DelayedScaling(
        margin=0,
        fp8_format=recipe.Format.HYBRID,
        amax_history_len=1,
        amax_compute_algo="most_recent",
        fp8_dpa=fp8_enabled,
        fp8_mha=False,
    )

    reset_rng_states()
    hidden_states = torch.randn(
        (config.max_seqlen_q, config.batch_size, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )

    def get_model(dtype, config):
        """Build a fresh TransformerLayer under quantized model init."""
        sigma = 0.023
        init_method = init_method_normal(sigma)
        output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

        with quantized_model_init(enabled=fp8_enabled, recipe=fp8_recipe):
            block = TransformerLayer(
                config.hidden_size,
                4 * config.hidden_size,
                config.num_heads,
                init_method=init_method,
                output_layer_init_method=output_layer_init_method,
                hidden_dropout=0.0,
                attention_dropout=0.0,
                fuse_qkv_params=True,
                params_dtype=dtype,
                device="cuda",
            )
        return block

    block = get_model(dtype, config)
    for _ in range(steps // 2):
        with autocast(enabled=fp8_enabled, recipe=fp8_recipe):
            output = block(hidden_states, None)
            loss = output.sum()
            loss.backward()

    if checkpoint:
        sd = block.state_dict()
        if mimic_v1_6:
            sd["self_attention.core_attention.fused_attention._extra_state"] = sd[
                "self_attention.core_attention._extra_state"
            ]
            del sd["self_attention.core_attention._extra_state"]

        with tempfile.TemporaryDirectory() as tmp_dir:
            path = os.path.join(tmp_dir, "checkpoint.pt")
            torch.save(sd, path)

            # Keep the accumulated gradients so they can be re-attached to the
            # parameters of the reloaded layer
            param_grads = [p.grad.clone() for p in block.parameters() if p.requires_grad]

            # Capture the RNG states before constructing the new layer and
            # restore them afterwards
            _cpu_rng_state_new = torch.get_rng_state()
            _cuda_rng_state_new = torch.cuda.get_rng_state()

            del block
            block = get_model(dtype, config)
            block.load_state_dict(torch.load(path, weights_only=False))
            torch.set_rng_state(_cpu_rng_state_new)
            torch.cuda.set_rng_state(_cuda_rng_state_new)

        trainable_params = [p for p in block.parameters() if p.requires_grad]
        assert len(trainable_params) == len(param_grads), "Oops!"
        for p, grad in zip(trainable_params, param_grads):
            p.grad = grad

    for _ in range((steps + 1) // 2):
        with autocast(enabled=fp8_enabled, recipe=fp8_recipe):
            output = block(hidden_states, None)
            loss = output.sum()
            loss.backward()

    torch.cuda.synchronize()

    outputs = [output, hidden_states.grad]
    for p in block.parameters():
        if p.requires_grad:
            outputs.append(p.grad)

    return outputs


1786
# Model configurations shared by test_mha_fp8_vs_f16 and test_dpa_fp8_vs_f16
model_configs_fp8_vs_f16 = {
    # test: ModelConfig(b, sq, hq, dqk)
    "fp8_9": ModelConfig(2, 2048, 16, 128),
    "fp8_10": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12),
    "fp8_11": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4),
    "fp8_12": ModelConfig(2, 2048, 16, 128, attn_mask_type="causal"),
    "fp8_13": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12, attn_mask_type="causal"),
    "fp8_14": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="causal"),
    "fp8_15": ModelConfig(2, 2048, 16, 128, attn_mask_type="padding"),
    "fp8_16": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12, attn_mask_type="padding"),
    "fp8_17": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="padding"),
    "fp8_18": ModelConfig(2, 2048, 16, 128, attn_mask_type="padding_causal"),
    "fp8_19": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12, attn_mask_type="padding_causal"),
    "fp8_20": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="padding_causal"),
}

# Parametrization values for the FP8-vs-F16 tests
param_types_fp8_vs_f16 = [torch.float16, torch.bfloat16]
qkv_layout_fp8_vs_f16 = ["sbh3d", "bshd_bshd_bshd", "sbhd_sbhd_sbhd"]
qkv_format_fp8_vs_f16 = ["bshd", "sbhd"]

1806

1807
@pytest.mark.skipif(get_cudnn_version() < (9, 2, 1), reason="cuDNN 9.2.1+ is required.")
@pytest.mark.skipif(not fp8_attn_available, reason=reason_for_no_fp8_attn)
@pytest.mark.parametrize("dtype", param_types_fp8_vs_f16)
@pytest.mark.parametrize("model", model_configs_fp8_vs_f16.keys())
@pytest.mark.parametrize("qkv_format", qkv_format_fp8_vs_f16)
@pytest.mark.parametrize("input_layernorm", [True, False])
@pytest.mark.parametrize("fp8_dpa_bwd", [True, False])
@pytest.mark.parametrize("RoPE", [True, False])
@pytest.mark.parametrize("is_training", [True, False])
@pytest.mark.parametrize("scaling_mode", ["delayed", "current"])
def test_mha_fp8_vs_f16(
    dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd, RoPE, is_training, scaling_mode
):
    """Test MultiHeadAttention module in FP8

    Runs `_run_mha_fp8_vs_f16` with FP8 enabled on each available FP8 backend
    (FlashAttention, FusedAttention) and, when an F16 fused reference is
    available, compares the FP8 results against it.

    Note:
        The NVTE_* environment variables written here are not restored.
    """
    os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1"
    os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0"
    config = model_configs_fp8_vs_f16[model]

    # Test backend availability
    if scaling_mode == "delayed":
        fp8_recipe = recipe.DelayedScaling(
            margin=0,
            fp8_format=recipe.Format.HYBRID,
            amax_history_len=1,
            amax_compute_algo="most_recent",
            fp8_dpa=True,
            fp8_mha=True,
        )
    elif scaling_mode == "current":
        fp8_recipe = recipe.Float8CurrentScaling(
            fp8_format=recipe.Format.HYBRID,
            fp8_dpa=True,
            fp8_mha=True,
        )
    fp8_meta = {}
    fp8_meta["recipe"] = fp8_recipe
    available_backends, _, fused_attn_backends = get_available_attention_backends(
        config,
        qkv_dtype=torch.float8_e4m3fn,
        qkv_layout=qkv_format.replace("hd", "h3d"),
        fp8=True,
        fp8_meta=fp8_meta,
        is_training=is_training,
        deterministic=_deterministic,
    )
    flash_attn_supported, fused_attn_supported_fp8, unfused_attn_supported = available_backends
    if flash_attn_supported + fused_attn_supported_fp8 < 1:
        pytest.skip("No FP8 attention backend available.")
    # NOTE(review): the F16 reference is only probed (and therefore only run
    # and compared against) when fp8_dpa_bwd is False; with fp8_dpa_bwd=True
    # the FP8 runs below execute but nothing is compared - confirm this is
    # intended.
    fused_attn_supported_f16 = False
    if not fp8_dpa_bwd:
        available_backends, _, fused_attn_backends = get_available_attention_backends(
            config,
            qkv_dtype=dtype,
            qkv_layout=qkv_format.replace("hd", "h3d"),
            is_training=is_training,
            deterministic=_deterministic,
        )
        _, fused_attn_supported_f16, _ = available_backends
        if not fused_attn_supported_f16:
            pytest.skip("No attention backend available.")

    if flash_attn_supported:
        os.environ["NVTE_FLASH_ATTN"] = "1"
        os.environ["NVTE_FUSED_ATTN"] = "0"
        os.environ["NVTE_UNFUSED_ATTN"] = "0"
        _attention_backends["backend_selection_requires_update"] = True
        logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = True")
        flash_attn_fwd_fp8, param_names, flash_attn_bwd_fp8 = _run_mha_fp8_vs_f16(
            dtype, config, True, qkv_format, input_layernorm, RoPE, is_training, fp8_recipe
        )

    if fused_attn_supported_fp8:
        os.environ["NVTE_FLASH_ATTN"] = "0"
        os.environ["NVTE_FUSED_ATTN"] = "1"
        os.environ["NVTE_UNFUSED_ATTN"] = "0"
        _attention_backends["backend_selection_requires_update"] = True
        logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = True")
        fused_attn_fwd_fp8, param_names, fused_attn_bwd_fp8 = _run_mha_fp8_vs_f16(
            dtype, config, True, qkv_format, input_layernorm, RoPE, is_training, fp8_recipe
        )

    if fused_attn_supported_f16:
        os.environ["NVTE_FLASH_ATTN"] = "0"
        os.environ["NVTE_FUSED_ATTN"] = "1"
        os.environ["NVTE_UNFUSED_ATTN"] = "0"
        _attention_backends["backend_selection_requires_update"] = True
        logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = False")
        fused_attn_fwd_f16, param_names, fused_attn_bwd_f16 = _run_mha_fp8_vs_f16(
            dtype, config, False, qkv_format, input_layernorm, RoPE, is_training, fp8_recipe
        )

    atol = 5e-1
    rtol = 5e-1
    rmse_tol = 0.15
    if flash_attn_supported and fused_attn_supported_f16:
        logging.debug("========== {:^25s} ==========".format("flash fp8 vs fused f16:"))
        logging.debug("========== {:^25s} ==========".format("forward output"))
        compare_and_assert(
            flash_attn_fwd_fp8,
            fused_attn_fwd_f16,
            "flash_attn_fwd_fp8",
            "fused_attn_fwd_f16",
            atol,
            rtol,
            rmse_tol,
            True,
        )
    if fused_attn_supported_fp8 and fused_attn_supported_f16:
        logging.debug("========== {:^25s} ==========".format("fused fp8 vs fused f16:"))
        logging.debug("========== {:^25s} ==========".format("forward output"))
        compare_and_assert(
            fused_attn_fwd_fp8,
            fused_attn_fwd_f16,
            "fused_attn_fwd_fp8",
            "fused_attn_fwd_f16",
            atol,
            rtol,
            rmse_tol,
            True,
        )

        if is_training:
            # Only the first entry (the hidden_states gradient) is compared
            for i, param_name in enumerate(param_names[:1]):
                logging.debug("========== {:^25s} ==========".format(param_name))
                compare_and_assert(
                    fused_attn_bwd_fp8[i],
                    fused_attn_bwd_f16[i],
                    f"fused_attn_bwd_fp8[{i}]",
                    f"fused_attn_bwd_f16[{i}]",
                    atol,
                    rtol,
                    rmse_tol,
                    True,
                )
1941

1942

1943
1944
1945
def _run_mha_fp8_vs_f16(
    dtype, config, fp8_mha, qkv_format, input_layernorm, RoPE, is_training, fp8_recipe
):
    """Run MultiHeadAttention module in FP8

    Builds a self-attention MultiheadAttention module, runs one forward pass
    on random input and, when training, one backward pass with a random
    output gradient.

    Args:
        dtype: dtype of the input and of the module parameters.
        config: model configuration providing sizes, mask/bias types and dropout.
        fp8_mha: enables both `quantized_model_init` and `autocast`.
        qkv_format: "bshd" or "sbhd" style format string; one character per
            dimension, used to derive the input shape.
        input_layernorm: forwarded to the module.
        RoPE: when True, a rotary position embedding is built and passed in.
        is_training: when False the module is put in eval mode and no
            backward pass is run.
        fp8_recipe: recipe passed to `quantized_model_init` and `autocast`.

    Returns:
        `(out, param_names, grads)`. `param_names` starts with
        "hidden_states.grad" followed by "<name>.grad" for each trainable
        parameter; `grads` holds the matching gradients, or a tuple of None of
        the same length when `is_training` is False.
    """
    reset_rng_states()
    _DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
    _DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed)

    def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
        """Get cuda rng tracker."""
        return _DUMMY_CUDA_RNG_STATE_TRACKER

    with quantized_model_init(enabled=fp8_mha, recipe=fp8_recipe):
        rotary_pos_emb = None
        if RoPE:
            PE = RotaryPositionEmbedding(dim=config.head_dim_qk)
            rotary_pos_emb = PE(config.max_seqlen_q).to(device="cuda")
        mha = MultiheadAttention(
            hidden_size=config.hidden_size,
            num_attention_heads=config.num_heads,
            kv_channels=config.head_dim_qk,
            num_gqa_groups=config.num_gqa_groups,
            attention_dropout=config.dropout_p,
            layer_number=1,
            bias=True,
            get_rng_state_tracker=get_dummy_cuda_rng_tracker,
            params_dtype=dtype,
            input_layernorm=input_layernorm,
            fuse_qkv_params=True,
            attention_type="self",
            qkv_weight_interleaved=True,
            qkv_format=qkv_format,
        ).to(dtype=dtype, device="cuda")
        if not is_training:
            mha = mha.eval()

    # Create seqlens
    if "padding" in config.attn_mask_type or qkv_format == "thd":
        if config.attn_type == "self":
            seqlens_q = torch.randint(
                1, config.max_seqlen_q, [config.batch_size], dtype=torch.int32, device="cuda"
            )
            seqlens_kv = seqlens_q
        if config.attn_type == "cross":
            seqlens_q = torch.randint(
                1, config.max_seqlen_q, [config.batch_size], dtype=torch.int32, device="cuda"
            )
            seqlens_kv = torch.randint(
                1, config.max_seqlen_kv, [config.batch_size], dtype=torch.int32, device="cuda"
            )
    else:
        seqlens_q = torch.full(
            [config.batch_size], config.max_seqlen_q, dtype=torch.int32, device="cuda"
        )
        seqlens_kv = torch.full(
            [config.batch_size], config.max_seqlen_kv, dtype=torch.int32, device="cuda"
        )
    # Cumulative sequence lengths with a leading zero
    cu_seqlens_q = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda")
    cu_seqlens_kv = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda")
    cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)
    cu_seqlens_kv[1:] = torch.cumsum(seqlens_kv, dim=0)

    dim_to_num = {
        "b": config.batch_size,
        "sq": config.max_seqlen_q,
        "skv": config.max_seqlen_kv,
        "h": config.num_heads,
        "hg": config.num_gqa_groups,
        "d": config.head_dim_qk,
        "t": cu_seqlens_q[-1],
        "tg": cu_seqlens_kv[-1],
        "3": 3,
        "2": 2,
        "1": 1,
    }
    # Turn the format string into dimension keys, e.g. "bshd" -> "b_sq_h_d"
    layout = "_".join(qkv_format)
    layout = layout.replace("s", "sq")
    tensor_shape = [dim_to_num[j] for j in layout.split("_")]
    tensor = 0.01 * torch.randint(-100, 100, tensor_shape, dtype=dtype, device="cuda")
    # Merge the last two dimensions into a single trailing dimension
    hidden_states = tensor.view(*tensor.shape[:-2], -1)
    if is_training:
        hidden_states.requires_grad = True
    tensor = 0.01 * torch.randn(tensor_shape, dtype=dtype, device="cuda")
    out_grad = tensor.view(*tensor.shape[:-2], -1)

    with autocast(enabled=fp8_mha, recipe=fp8_recipe):
        out = mha(
            hidden_states,
            attn_mask_type=config.attn_mask_type,
            checkpoint_core_attention=False,
            core_attention_bias_type=config.attn_bias_type,
            is_first_microbatch=None,
            rotary_pos_emb=rotary_pos_emb,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_kv=cu_seqlens_kv,
        )
    if is_training:
        out.backward(out_grad)

    param_names = []
    param_names.append("hidden_states.grad")
    params = []
    params.append(hidden_states)
    for name, param in mha.named_parameters():
        if param.requires_grad:
            param_names.append(name + ".grad")
            params.append(param)

    if is_training:
        return out, param_names, tuple(x.grad for x in params)
    return out, param_names, tuple(None for _ in params)
2053

2054

2055
@pytest.mark.skipif(get_cudnn_version() < (9, 2, 1), reason="cuDNN 9.2.1+ is required.")
@pytest.mark.skipif(not fp8_attn_available, reason=reason_for_no_fp8_attn)
@pytest.mark.parametrize("dtype", param_types_fp8_vs_f16)
@pytest.mark.parametrize("model", model_configs_fp8_vs_f16.keys())
@pytest.mark.parametrize("qkv_layout", qkv_layout_fp8_vs_f16)
@pytest.mark.parametrize("fp8_dpa_bwd", [True, False])
@pytest.mark.parametrize("is_training", [True, False])
@pytest.mark.parametrize("scaling_mode", ["delayed", "current"])
def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scaling_mode):
    """Test DotProductAttention module in FP8

    Runs `_run_dpa_fp8_vs_f16` in FP8 on FlashAttention and
    UnfusedDotProductAttention (when available) and on FusedAttention, then
    compares each against an F16 FusedAttention reference.

    Note:
        NVTE_UnfusedDPA_Emulate_FP8 is set to "1" for the duration of the test
        and reset to "0" on every exit path (pass, skip or failure). The other
        NVTE_* environment variables written here are not restored.
    """
    config = model_configs_fp8_vs_f16[model]

    # TODO(cyang): think of another way to verify dropout results
    # test cuDNN FP8 dropout
    # 1. we modify the config here to not affect mha_fp8_vs_f16 tests
    # 2. there is no other backend that implements dropout the same way as cuDNN FP8, and as an
    #    indirect verification method, we create Q/K/V as all 1s and check if O is all 1s
    # 3. we avoid running FP16/BF16 kernels as they do not have dropout support on Blackwell
    # if "padding" not in config.attn_mask_type and "causal" not in config.attn_mask_type:
    #    if get_device_compute_capability() >= (10, 0):
    #        config.dropout_p = 0.1

    os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0"
    os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1"
    os.environ["NVTE_UnfusedDPA_Emulate_FP8"] = "1"

    # Everything below may skip or fail; the `finally` clause guarantees that
    # FP8 emulation for the unfused backend does not stay enabled for the
    # tests that run afterwards.
    try:
        # Test backend availability
        if scaling_mode == "delayed":
            fp8_recipe = recipe.DelayedScaling(
                margin=0,
                fp8_format=recipe.Format.HYBRID,
                amax_history_len=1,
                amax_compute_algo="most_recent",
                fp8_dpa=True,
            )
        elif scaling_mode == "current":
            fp8_recipe = recipe.Float8CurrentScaling(
                fp8_format=recipe.Format.HYBRID,
                fp8_dpa=True,
            )
        fp8_meta = {}
        fp8_meta["recipe"] = fp8_recipe
        available_backends, _, fused_attn_backends = get_available_attention_backends(
            config,
            qkv_dtype=torch.float8_e4m3fn,
            qkv_layout=qkv_layout,
            fp8=True,
            fp8_meta=fp8_meta,
            is_training=is_training,
            deterministic=_deterministic,
        )
        flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
        if flash_attn_supported + fused_attn_supported < 1:
            pytest.skip("No FP8 attention backend available.")
        if not fp8_dpa_bwd:
            available_backends, _, fused_attn_backends = get_available_attention_backends(
                config,
                qkv_dtype=dtype,
                qkv_layout=qkv_layout,
                is_training=is_training,
                deterministic=_deterministic,
            )
            _, fused_attn_supported, _ = available_backends
            if not fused_attn_supported:
                pytest.skip("No attention backend available.")
        if config.num_heads != config.num_gqa_groups and "3" in qkv_layout:
            pytest.skip("qkv_layout not applicable for MQA/GQA")

        if flash_attn_supported:
            os.environ["NVTE_FLASH_ATTN"] = "1"
            os.environ["NVTE_FUSED_ATTN"] = "0"
            os.environ["NVTE_UNFUSED_ATTN"] = "0"
            _attention_backends["backend_selection_requires_update"] = True
            logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True (FlashAttention)")
            flash_attn_fwd_fp8, flash_attn_bwd_fp8 = _run_dpa_fp8_vs_f16(
                dtype, config, True, qkv_layout, is_training, fp8_recipe
            )

        if unfused_attn_supported:
            os.environ["NVTE_FLASH_ATTN"] = "0"
            os.environ["NVTE_FUSED_ATTN"] = "0"
            os.environ["NVTE_UNFUSED_ATTN"] = "1"
            _attention_backends["backend_selection_requires_update"] = True
            logging.info(
                "[test_dpa_fp8_vs_f16]: run with fp8_dpa = True (UnfusedDotProductAttention)"
            )
            unfused_attn_fwd_fp8, unfused_attn_bwd_fp8 = _run_dpa_fp8_vs_f16(
                dtype, config, True, qkv_layout, is_training, fp8_recipe
            )

        os.environ["NVTE_FLASH_ATTN"] = "0"
        os.environ["NVTE_FUSED_ATTN"] = "1"
        os.environ["NVTE_UNFUSED_ATTN"] = "0"
        _attention_backends["backend_selection_requires_update"] = True
        logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True (FusedAttention)")
        fused_attn_fwd_fp8, fused_attn_bwd_fp8 = _run_dpa_fp8_vs_f16(
            dtype, config, True, qkv_layout, is_training, fp8_recipe
        )

        os.environ["NVTE_FLASH_ATTN"] = "0"
        os.environ["NVTE_FUSED_ATTN"] = "1"
        os.environ["NVTE_UNFUSED_ATTN"] = "0"
        # NOTE(review): the F16 reference only exists when dropout is 0.0; the
        # flash/unfused comparisons below read it unconditionally, so they
        # would raise NameError if dropout were ever enabled - confirm before
        # re-enabling the dropout path above.
        if config.dropout_p == 0.0:
            # test cuDNN FP8 dropout: need a FP16/BF16 reference on Blackwell
            logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = False (FusedAttention)")
            fused_attn_fwd_f16, fused_attn_bwd_f16 = _run_dpa_fp8_vs_f16(
                dtype, config, False, qkv_layout, is_training, fp8_recipe
            )

        atol = 5e-1
        rtol = 5e-2
        rmse_tol = 0.11
        bwd_names = ["dq", "dk", "dv"]
        if flash_attn_supported:
            logging.debug("========== {:^25s} ==========".format("flash fp8 vs fused f16:"))
            logging.debug("========== {:^25s} ==========".format("forward output"))
            compare_and_assert(
                flash_attn_fwd_fp8,
                fused_attn_fwd_f16,
                "flash_attn_fwd_fp8",
                "fused_attn_fwd_f16",
                atol,
                rtol,
                rmse_tol,
                True,
            )
        if unfused_attn_supported:
            logging.debug("========== {:^25s} ==========".format("unfused fp8 vs fused f16:"))
            logging.debug("========== {:^25s} ==========".format("forward output"))
            compare_and_assert(
                unfused_attn_fwd_fp8,
                fused_attn_fwd_f16,
                "unfused_attn_fwd_fp8",
                "fused_attn_fwd_f16",
                atol,
                rtol,
                rmse_tol,
                True,
            )
            if is_training:
                for i, _ in enumerate(fused_attn_bwd_f16):
                    logging.debug("========== {:^25s} ==========".format(bwd_names[i]))
                    compare_and_assert(
                        unfused_attn_bwd_fp8[i],
                        fused_attn_bwd_f16[i],
                        f"unfused_attn_bwd_fp8[{i}]",
                        f"fused_attn_bwd_f16[{i}]",
                        atol,
                        rtol,
                        rmse_tol,
                        True,
                    )
        if config.dropout_p != 0.0:
            # test cuDNN FP8 dropout
            assert torch.all(
                fused_attn_fwd_fp8 == 1
            ), "fused_attn_fwd_fp8 must be all 1s when Q/K/V are all 1s."
        else:
            logging.debug("========== {:^25s} ==========".format("fused fp8 vs fused f16:"))
            logging.debug("========== {:^25s} ==========".format("forward output"))
            compare_and_assert(
                fused_attn_fwd_fp8,
                fused_attn_fwd_f16,
                "fused_attn_fwd_fp8",
                "fused_attn_fwd_f16",
                atol,
                rtol,
                rmse_tol,
                True,
            )
            if is_training:
                for i, _ in enumerate(fused_attn_bwd_f16):
                    logging.debug("========== {:^25s} ==========".format(bwd_names[i]))
                    compare_and_assert(
                        fused_attn_bwd_fp8[i],
                        fused_attn_bwd_f16[i],
                        f"fused_attn_bwd_fp8[{i}]",
                        f"fused_attn_bwd_f16[{i}]",
                        atol,
                        rtol,
                        rmse_tol,
                        True,
                    )
    finally:
        os.environ["NVTE_UnfusedDPA_Emulate_FP8"] = "0"
2237
2238


2239
def _run_dpa_fp8_vs_f16(dtype, config, fp8_dpa, qkv_layout, is_training, fp8_recipe):
    """Run one forward (and optionally backward) pass of DotProductAttention.

    Args:
        dtype: torch dtype used for the module and for all generated tensors.
        config: model configuration; the attributes read here are batch_size,
            max_seqlen_q, max_seqlen_kv, num_heads, num_gqa_groups, head_dim_qk,
            dropout_p, attn_mask_type, attn_type and attn_bias_type.
        fp8_dpa: enables ``quantized_model_init`` and ``autocast`` and is also
            forwarded to the module as ``fp8_output``.
        qkv_layout: underscore-separated layout string, one group per allocated
            tensor (e.g. ``"bs3hd"``); a digit inside a group is the number of
            tensors packed along that dimension.
        is_training: when False the module is put in eval mode and no backward
            pass is run.
        fp8_recipe: recipe handed to ``autocast``.

    Returns:
        ``(out, (dq, dk, dv))``; the gradients are ``(None, None, None)`` when
        ``is_training`` is False.
    """
    reset_rng_states()
    _DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
    _DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed)

    def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
        """Get cuda rng tracker."""
        return _DUMMY_CUDA_RNG_STATE_TRACKER

    # The format is the first layout group with its digits dropped.
    qkv_format = "".join(ch for ch in qkv_layout.split("_")[0] if ch.isalpha())
    with quantized_model_init(enabled=fp8_dpa):
        dpa = DotProductAttention(
            config.num_heads,
            config.head_dim_qk,
            num_gqa_groups=config.num_gqa_groups,
            attention_dropout=config.dropout_p,
            sequence_parallel=False,
            tp_size=1,
            get_rng_state_tracker=get_dummy_cuda_rng_tracker,
            tp_group=None,
            layer_number=1,
            attention_type="self",
            qkv_format=qkv_format,
        ).to(dtype=dtype, device="cuda")
        if not is_training:
            dpa = dpa.eval()

    # Padding masks and the "thd" format get random per-sequence lengths;
    # everything else uses the full maximum length for every sequence.
    if "padding" in config.attn_mask_type or qkv_format == "thd":
        if config.attn_type == "self":
            seqlens_q = torch.randint(
                1, config.max_seqlen_q, [config.batch_size], dtype=torch.int32, device="cuda"
            )
            seqlens_kv = seqlens_q
        elif config.attn_type == "cross":
            seqlens_q = torch.randint(
                1, config.max_seqlen_q, [config.batch_size], dtype=torch.int32, device="cuda"
            )
            seqlens_kv = torch.randint(
                1, config.max_seqlen_kv, [config.batch_size], dtype=torch.int32, device="cuda"
            )
    else:
        seqlens_q = torch.full(
            [config.batch_size], config.max_seqlen_q, dtype=torch.int32, device="cuda"
        )
        seqlens_kv = torch.full(
            [config.batch_size], config.max_seqlen_kv, dtype=torch.int32, device="cuda"
        )
    cu_seqlens_q = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda")
    cu_seqlens_kv = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda")
    cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)
    cu_seqlens_kv[1:] = torch.cumsum(seqlens_kv, dim=0)

    # Maps each layout symbol to the concrete size of that dimension.
    dim_to_num = {
        "b": config.batch_size,
        "sq": config.max_seqlen_q,
        "skv": config.max_seqlen_kv,
        "h": config.num_heads,
        "hg": config.num_gqa_groups,
        "d": config.head_dim_qk,
        "t": cu_seqlens_q[-1],
        "tg": cu_seqlens_kv[-1],
        "3": 3,
        "2": 2,
        "1": 1,
    }
    inp = []
    for i, layout in enumerate(qkv_layout.split("_")):
        # "bs3hd" -> "b_s_3_h_d" so that each character becomes one symbol.
        layout = "_".join(layout)
        if i == 0:
            layout = layout.replace("s", "sq")
        else:
            # Every group after the first uses the KV-side sizes.
            layout = layout.replace("s", "skv")
            layout = layout.replace("h", "hg")
            layout = layout.replace("t", "tg")
        tensor_shape = [dim_to_num[j] for j in layout.split("_")]
        if config.dropout_p == 0.0:
            tensor = torch.randn(tensor_shape, dtype=dtype, device="cuda")
        else:
            # test cuDNN FP8 dropout
            tensor = torch.ones(tensor_shape, dtype=dtype, device="cuda")
        # Locate the first numeric symbol: it gives how many tensors are packed
        # in this allocation and along which dimension to split them.
        tensor_count = 1
        split_dim = 0
        for dim, symbol in enumerate(layout.split("_")):
            if symbol.isdigit():
                tensor_count = int(symbol)
                split_dim = dim
                break
        tensors = torch.split(tensor, 1, dim=split_dim) if split_dim != 0 else [tensor]
        for j in range(tensor_count):
            if split_dim != 0:
                inp.append(tensors[j].squeeze(split_dim))
            else:
                inp.append(tensors[j])
    for i in range(3):
        inp[i].requires_grad = True

    # The output gradient has the Q-side format with the last two dimensions
    # flattened into one.
    qkv_format_kv = "_".join(qkv_format)
    qkv_format_kv = qkv_format_kv.replace("s", "sq")
    out_grad_shape = [dim_to_num[i] for i in qkv_format_kv.split("_")]
    out_grad_shape_new = [*out_grad_shape[:-2], out_grad_shape[-2] * out_grad_shape[-1]]
    out_grad = torch.randn(out_grad_shape_new, dtype=dtype, device="cuda")

    with autocast(enabled=fp8_dpa, recipe=fp8_recipe):
        out = dpa(
            inp[0],
            inp[1],
            inp[2],
            qkv_format=qkv_format,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_kv=cu_seqlens_kv,
            max_seqlen_q=config.max_seqlen_q,
            max_seqlen_kv=config.max_seqlen_kv,
            attn_mask_type=config.attn_mask_type,
            checkpoint_core_attention=False,
            core_attention_bias_type=config.attn_bias_type,
            fp8_output=fp8_dpa,
        )
    if is_training:
        out.backward(out_grad)

    if is_training:
        return out, (inp[0].grad, inp[1].grad, inp[2].grad)
    return out, (None, None, None)
2363
2364
2365


# Model configurations for the custom FP8 MHA test, keyed by test id.
model_configs_fp8 = {
    # test: ModelConfig(b, sq, hq, dqk)
    "fp8_1": ModelConfig(1, 512, 1, 64),
    "fp8_2": ModelConfig(4, 512, 16, 64),
    "fp8_3": ModelConfig(1, 2048, 1, 128),
    "fp8_4": ModelConfig(2, 2048, 24, 128),
    "fp8_5": ModelConfig(1, 512, 1, 64, attn_mask_type="causal"),
    "fp8_6": ModelConfig(4, 512, 16, 64, attn_mask_type="causal"),
    "fp8_7": ModelConfig(1, 2048, 1, 128, attn_mask_type="causal"),
    "fp8_8": ModelConfig(2, 2048, 24, 128, attn_mask_type="causal"),
}
param_types_fp8 = [torch.float16, torch.bfloat16]
# Frontend selector read from the environment (defaults to 1); it decides
# which of the two model lists below is used for parametrization.
cudnn_frontend_version = int(os.getenv("NVTE_FUSED_ATTN_FE_VER", "1"))
models_v0 = ["fp8_1", "fp8_2", "fp8_5", "fp8_6"]
models_v1 = ["fp8_3", "fp8_4", "fp8_7", "fp8_8"]
2380
2381


2382
2383
2384
2385
2386
2387
2388
2389
@pytest.mark.skipif(
    (
        get_cudnn_version() < (8, 9, 3)
        if cudnn_frontend_version == 0
        else get_cudnn_version() < (9, 2, 1)
    ),
    reason=f"""cuDNN {"8.9.3" if cudnn_frontend_version == 0 else "9.2.1"}+ is required.""",
)
@pytest.mark.skipif(not fp8_attn_available, reason=reason_for_no_fp8_attn)
@pytest.mark.parametrize("dtype", param_types_fp8)
@pytest.mark.parametrize("model", models_v1 if cudnn_frontend_version == 1 else models_v0)
def test_custom_mha_fp8_vs_f16(dtype, model):
    """Test FP8 dot product attention implementations based on cuDNN frontend
    v0.9 and v1.0+. Each test compares results from a custom implementation of
    an FP8 MHA module, i.e. Custom_MHA_FP8(), to results from an F16 reference
    produced by _run_ref_mha_f16() with the UnfusedDotProductAttention backend.
    Both paths take F16 input and output. QKV layout is t3hd or bs3hd"""

    config = model_configs_fp8[model]

    # Test backend availability
    is_training = True
    available_backends, _, fused_attn_backends = get_available_attention_backends(
        config,
        qkv_dtype=torch.float8_e4m3fn,
        qkv_layout="t3hd" if cudnn_frontend_version == 0 else "bs3hd",
        is_training=is_training,
        deterministic=_deterministic,
    )
    # Only the unfused flag is needed: the fused side is checked through
    # fused_attn_backends.
    _, _, unfused_attn_supported = available_backends
    if not (fused_attn_backends and unfused_attn_supported):
        pytest.skip("Not enough backends to run this test with.")

    # The FP8 run must come first: the reference run loads tensors that the
    # FP8 run saves to disk.
    fused_attn_fwd_fp8, fused_attn_bwd_fp8 = _run_custom_mha_fp8(dtype, config, "FusedAttention")
    unfused_attn_fwd_f16, unfused_attn_bwd_f16 = _run_ref_mha_f16(
        dtype, config, "UnfusedDotProductAttention"
    )

    atol = 5e-1
    rtol = 5e-1
    rmse_tol = 0.13
    compare_and_assert(
        fused_attn_fwd_fp8,
        unfused_attn_fwd_f16,
        "fused_attn_fwd_fp8",
        "unfused_attn_fwd_f16",
        atol,
        rtol,
        rmse_tol,
        True,
    )
    compare_and_assert(
        fused_attn_bwd_fp8,
        unfused_attn_bwd_f16,
        "fused_attn_bwd_fp8",
        "unfused_attn_bwd_f16",
        atol,
        rtol,
        rmse_tol,
        True,
    )
2443
2444
2445
2446
2447


def _run_custom_mha_fp8(dtype, config, backend):
    """Run Custom_MHA_FP8 with FP8 FusedAttention backend. Both input and output
    are in F16. QKV GEMM, DPA, and projection GEMM are calculated in FP8.

    Args:
        dtype: torch dtype for the input, the output gradient and the module.
        config: model configuration; batch_size, max_seqlen_q, num_heads and
            head_dim_qk are read here.
        backend: one of "FlashAttention", "FusedAttention" or
            "UnfusedDotProductAttention"; the matching NVTE_* environment
            variable is set to "1" and the other two to "0".

    Returns:
        ``(out, dqkv)`` where ``out`` is viewed as (batch, seqlen, -1) and
        ``dqkv`` as a contiguous (batch, seqlen, 3, heads, head_dim) tensor.

    Side effects:
        Writes "out_grad.pt" to the current working directory and reads back
        "out.pt" and "dqkv.pt", which _custom_mha_fp8 saves during its forward
        and backward passes.
    """
    reset_rng_states()
    # Enable exactly the requested backend and force a re-selection.
    os.environ["NVTE_FLASH_ATTN"] = "1" if backend == "FlashAttention" else "0"
    os.environ["NVTE_FUSED_ATTN"] = "1" if backend == "FusedAttention" else "0"
    os.environ["NVTE_UNFUSED_ATTN"] = "1" if backend == "UnfusedDotProductAttention" else "0"
    _attention_backends["backend_selection_requires_update"] = True

    inp = 0.0001 * torch.randint(
        -100,
        100,
        (config.batch_size * config.max_seqlen_q, config.num_heads * config.head_dim_qk),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )
    # Every sequence uses the full maximum length.
    seqlens = torch.full([config.batch_size], config.max_seqlen_q, dtype=torch.int32, device="cuda")
    cu_seqlens = torch.zeros(config.batch_size + 1, device="cuda", dtype=torch.int32)
    cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)

    out_grad = 0.01 * torch.randn(
        config.batch_size * config.max_seqlen_q,
        config.num_heads * config.head_dim_qk,
        dtype=dtype,
        device="cuda",
    )
    # Persisted so the reference run can use the identical output gradient.
    torch.save(out_grad, "out_grad.pt")

    fp8_recipe = recipe.DelayedScaling(
        margin=0,
        fp8_format=recipe.Format.HYBRID,
        amax_history_len=1,
        amax_compute_algo="most_recent",
    )

    mha = Custom_MHA_FP8(config).to(dtype=dtype, device="cuda")
    with autocast(enabled=True, recipe=fp8_recipe):
        out = mha(inp, cu_seqlens, config.max_seqlen_q)
    out.backward(out_grad)

    # Results are picked up from the files written by the autograd function
    # rather than from the returned tensor or from inp.grad.
    out = torch.load("out.pt")
    dqkv = torch.load("dqkv.pt")
    return (
        out.view(config.batch_size, config.max_seqlen_q, -1),
        dqkv.view(
            config.batch_size, config.max_seqlen_q, 3, config.num_heads, config.head_dim_qk
        ).contiguous(),
    )
2500

2501

2502
2503
2504
def _run_ref_mha_f16(dtype, config, backend):
    """Run the F16 reference attention with the requested backend. Both input
    and output are in F16.

    Args:
        dtype: torch dtype the DotProductAttention module is cast to.
        config: model configuration; batch_size, max_seqlen_q, num_heads,
            head_dim_qk, dropout_p and attn_mask_type are read here.
        backend: one of "FlashAttention", "FusedAttention" or
            "UnfusedDotProductAttention"; the matching NVTE_* environment
            variable is set to "1" and the other two to "0".

    Returns:
        ``(out, inp.grad)`` where ``inp`` is the packed QKV tensor loaded from
        "qkv.pt".

    Side effects:
        Reads "qkv.pt" and "out_grad.pt" from the current working directory;
        within this file they are written by _custom_mha_fp8.forward and
        _run_custom_mha_fp8 respectively, so the FP8 run has to happen first.
    """

    # Enable exactly the requested backend and force a re-selection.
    os.environ["NVTE_FLASH_ATTN"] = "1" if backend == "FlashAttention" else "0"
    os.environ["NVTE_FUSED_ATTN"] = "1" if backend == "FusedAttention" else "0"
    os.environ["NVTE_UNFUSED_ATTN"] = "1" if backend == "UnfusedDotProductAttention" else "0"
    _attention_backends["backend_selection_requires_update"] = True

    inp = torch.load("qkv.pt").to(device="cuda")
    inp.requires_grad = True
    seqlens = torch.full([config.batch_size], config.max_seqlen_q, dtype=torch.int32, device="cuda")
    cu_seqlens = torch.zeros(config.batch_size + 1, device="cuda", dtype=torch.int32)
    cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)
    out_grad = (
        torch.load("out_grad.pt").to(device="cuda").view(config.batch_size, config.max_seqlen_q, -1)
    )

    _DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
    _DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed)

    def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
        """Get cuda rng tracker."""
        return _DUMMY_CUDA_RNG_STATE_TRACKER

    block = DotProductAttention(
        config.num_heads,
        config.head_dim_qk,
        attention_dropout=config.dropout_p,
        sequence_parallel=False,
        tp_size=1,
        get_rng_state_tracker=get_dummy_cuda_rng_tracker,
        tp_group=None,
        layer_number=1,
        attention_type="self",
        qkv_format="bshd",
    ).to(dtype=dtype, device="cuda")

    # Q, K and V are slices along dimension 2 of the packed tensor, so their
    # gradients accumulate into inp.grad.
    q = inp[:, :, 0, :, :]
    k = inp[:, :, 1, :, :]
    v = inp[:, :, 2, :, :]
    out = block(q, k, v, attn_mask_type=config.attn_mask_type)
    out.backward(out_grad)

    return out, inp.grad
2553
2554
2555
2556
2557
2558
2559


_CUBLASLT_WORKSPACE_SIZE_BYTES = 33_554_432  # 32MiB
# Values passed as use_split_accumulator to the forward, dgrad and wgrad GEMMs.
_2X_ACC_FPROP = False
_2X_ACC_DGRAD = False
_2X_ACC_WGRAD = False

# Named FP8 meta-tensor slots for the attention tensors.
META_QKV = tex.FP8FwdTensors.GEMM1_OUTPUT
META_DQKV = tex.FP8BwdTensors.GRAD_OUTPUT1
META_O = tex.FP8FwdTensors.GEMM2_INPUT
META_DO = tex.FP8BwdTensors.GRAD_INPUT2
META_S = tex.FP8FwdTensors.GEMM3_OUTPUT
META_DP = tex.FP8BwdTensors.GRAD_INPUT3
2566
2567


2568
class _custom_mha_fp8(torch.autograd.Function):
    """Autograd function chaining an FP8 QKV GEMM with FP8 fused attention.

    Besides returning tensors, the forward pass saves "qkv.pt" and "out.pt"
    and the backward pass saves "dqkv.pt" in the current working directory.
    """

    @staticmethod
    def forward(
        ctx,
        inp: torch.Tensor,
        qkv_weight: torch.Tensor,
        qkv_bias: torch.Tensor,
        cu_seqlens: torch.Tensor,
        num_heads: int,
        p_dropout: float,
        max_s: int,
        fast_zero_fill: bool,
        fp8_meta: Dict[str, Any],
        is_training: bool,
        mask_type: str,
        quantizers: Dict[str, Any],
    ) -> torch.Tensor:
        """Quantize the input, run the QKV GEMM and FP8 fused attention.

        Args:
            inp: 2-D input tensor (asserted below).
            qkv_weight: QKV projection weight; its last dimension is taken as
                the hidden size.
            qkv_bias: bias passed to the QKV GEMM.
            cu_seqlens: cumulative sequence lengths; the batch size is taken
                as ``numel() - 1``.
            num_heads: number of attention heads.
            p_dropout: dropout probability forwarded to the fused kernel.
            max_s: maximum sequence length, used for both Q and KV.
            fast_zero_fill: forwarded to the fused kernel.
            fp8_meta: FP8 metadata; ``fp8_meta["recipe"]`` is read in backward.
            is_training: forwarded to the fused kernel.
            mask_type: attention mask type, used when cudnn_frontend_version
                is 1 ("padding" is used otherwise).
            quantizers: mapping indexed by "scaling_fwd" / "scaling_bwd" and
                then by an FP8 tensor slot.

        Returns:
            The dequantized attention output with shape (-1, hidden_size).
        """
        qkv_dtype = inp.dtype

        assert inp.dim() == 2
        in_features = qkv_weight.shape[-1]
        h = num_heads
        d = in_features // h
        b = cu_seqlens.numel() - 1

        input_quantizer = quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT]
        qkv_quantizer = quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_INPUT]
        qkv_weight_quantizer = quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT]
        o_quantizer = quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_OUTPUT]
        dO_quantizer = quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT1]
        dQKV_quantizer = quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_INPUT1]
        s_quantizer = quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT2]
        dP_quantizer = quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT3]

        inp_fp8 = input_quantizer(inp)

        qkv_weight_fp8 = qkv_weight_quantizer(qkv_weight)

        # QKV projection; the result is quantized with qkv_quantizer.
        qkv, *_ = ext.general_gemm(
            qkv_weight_fp8,
            inp_fp8,
            bias=qkv_bias,
            out_dtype=qkv_weight_fp8.dtype,
            quantization_params=qkv_quantizer,
            use_split_accumulator=_2X_ACC_FPROP,
        )
        qkv = qkv.view(-1, 3, h, d)
        # Dequantized copy saved to disk for the F16 reference run.
        qkv_fp16 = qkv.dequantize().view(b, max_s, 3, h, d).contiguous()
        torch.save(qkv_fp16, "qkv.pt")
        if cudnn_frontend_version == 1:
            qkv = qkv.view(b, max_s, 3, h, d)  # bs3hd

        # FMHA
        # Q/K/V are slices of the packed raw data along the size-3 dimension,
        # re-wrapped as quantized tensors that mirror qkv.
        q_data = qkv._data[:, :, 0, :, :] if cudnn_frontend_version == 1 else qkv._data[:, 0, :, :]
        k_data = qkv._data[:, :, 1, :, :] if cudnn_frontend_version == 1 else qkv._data[:, 1, :, :]
        v_data = qkv._data[:, :, 2, :, :] if cudnn_frontend_version == 1 else qkv._data[:, 2, :, :]
        q = qkv.make_like(tensor=qkv, data=q_data, shape=q_data.shape)
        k = qkv.make_like(tensor=qkv, data=k_data, shape=k_data.shape)
        v = qkv.make_like(tensor=qkv, data=v_data, shape=v_data.shape)

        out, aux_ctx_tensors = fused_attn_fwd(
            is_training,
            max_s,
            max_s,
            cu_seqlens,
            cu_seqlens,
            q,
            k,
            v,
            qkv_dtype,
            FusedAttnBackend["FP8"],
            attn_scale=None,
            dropout=p_dropout,
            fast_zero_fill=fast_zero_fill,
            qkv_layout="bs3hd" if cudnn_frontend_version == 1 else "t3hd",
            attn_bias_type="no_bias",
            attn_mask_type=mask_type if cudnn_frontend_version == 1 else "padding",
            rng_gen=None,
            o_quantizer=o_quantizer,
            s_quantizer=s_quantizer,
        )

        tensors_to_save, tensor_objects = prepare_for_saving(q, k, v, inp_fp8, qkv_weight_fp8, out)

        ctx.save_for_backward(*tensors_to_save)
        ctx.tensor_objects = tensor_objects
        ctx.aux_ctx_tensors = aux_ctx_tensors
        ctx.qkv_dtype = qkv_dtype
        ctx.fp8_meta = fp8_meta
        ctx.cu_seqlens = cu_seqlens
        ctx.p_dropout = p_dropout
        ctx.max_s = max_s
        ctx.fast_zero_fill = fast_zero_fill
        ctx.hidden_size = in_features
        ctx.num_heads = num_heads
        ctx.mask_type = mask_type
        ctx.dtype = inp.dtype

        ctx.dQKV_quantizer = dQKV_quantizer
        ctx.dO_quantizer = dO_quantizer
        ctx.dP_quantizer = dP_quantizer
        ctx.S_quantizer = s_quantizer

        out = out.view(-1, in_features)  # (bs)(hd)
        out_fp16 = out.dequantize()
        # Saved to disk so the test driver can read the forward result.
        torch.save(out_fp16, "out.pt")  # (bs)(hd)
        return out_fp16

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
        """Run FP8 fused-attention backward followed by the QKV dgrad/wgrad GEMMs.

        Returns:
            One entry per forward input: gradients for ``inp``, ``qkv_weight``
            and ``qkv_bias``, then ``None`` for the nine non-tensor inputs.
        """
        with torch.cuda.nvtx.range("_DPA"):
            saved_tensors = ctx.saved_tensors
            (q, k, v, inp_fp8, qkv_weight_fp8, out) = restore_from_saved(
                ctx.tensor_objects, saved_tensors
            )

            proj_dgrad = ctx.dO_quantizer(grad_output)
            fp8_dtype_backward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=False)

            dq, dk, dv, *_ = fused_attn_bwd(
                ctx.max_s,
                ctx.max_s,
                ctx.cu_seqlens,
                ctx.cu_seqlens,
                q,
                k,
                v,
                out,
                proj_dgrad.view_as(out),
                ctx.qkv_dtype,
                fp8_dtype_backward,
                ctx.aux_ctx_tensors,
                FusedAttnBackend["FP8"],
                None,
                None,
                ctx.S_quantizer,
                ctx.dP_quantizer,
                ctx.dQKV_quantizer,
                attn_scale=None,
                dropout=ctx.p_dropout,
                fast_zero_fill=ctx.fast_zero_fill,
                qkv_layout="bs3hd" if cudnn_frontend_version == 1 else "t3hd",
                attn_bias_type="no_bias",
                attn_mask_type=ctx.mask_type if cudnn_frontend_version == 1 else "padding",
            )
            # Rebuild a packed dQKV view over dq's storage by inserting a
            # size-3 dimension at `dim`.
            # NOTE(review): this assumes dq, dk and dv are interleaved in one
            # allocation that starts at dq -- confirm against fused_attn_bwd.
            dim = 2 if cudnn_frontend_version == 1 else 1
            dqkv = torch.Tensor().to(device=dq._data.device, dtype=dq._data.dtype)
            dqkv_shape = list(dq._data.shape)
            dqkv_shape.insert(dim, 3)
            dqkv_stride = list(dq._data.stride())
            # Strides are integers, so use exact integer division.
            dqkv_stride.insert(dim, dqkv_stride[-3] // 3)
            dqkv.set_(
                dq._data.untyped_storage(), dq._data.storage_offset(), dqkv_shape, dqkv_stride
            )  # bs3hd

            dqkv_c = dqkv.view(-1, 3 * ctx.hidden_size)
            dqkv_c = dq.make_like(tensor=dq, data=dqkv_c, shape=dqkv_c.shape)
            dqkv_c_fp16 = dqkv_c.dequantize()
            # Saved to disk so the test driver can read the QKV gradient.
            torch.save(dqkv_c_fp16, "dqkv.pt")

            qkv_bgrad, dqkv = ext.bgrad_quantize(dqkv_c_fp16, ctx.dQKV_quantizer)
            # Drop any stale transpose and rebuild it before the dgrad GEMM.
            dqkv_c._transpose = None
            dqkv_c._create_transpose()

            # QKV DGRAD
            qkv_dgrad, *_ = ext.general_gemm(
                qkv_weight_fp8,
                dqkv_c,
                ctx.dtype,
                use_split_accumulator=_2X_ACC_DGRAD,
                layout="NN",
            )

            # QKV WGRAD
            qkv_wgrad, *_ = ext.general_gemm(
                inp_fp8,
                dqkv,
                ctx.dtype,
                use_split_accumulator=_2X_ACC_WGRAD,
                layout="NT",
            )

        # forward() takes 12 inputs after ctx, so return exactly 12 entries:
        # three gradients followed by nine Nones.
        return (
            qkv_dgrad,
            qkv_wgrad,
            qkv_bgrad,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
        )
2765

2766

2767
class Custom_MHA_FP8(TransformerEngineBaseModule):
    """Test module that owns the QKV parameters and calls _custom_mha_fp8."""

    def __init__(self, config, params_dtype: torch.dtype = torch.float32):
        """Create the QKV weight and bias on the current CUDA device.

        Args:
            config: model configuration; dropout_p, num_heads, hidden_size,
                head_dim_qk and attn_mask_type are read here.
            params_dtype: dtype of the weight and bias parameters.
        """
        super().__init__()
        self.p_dropout = config.dropout_p
        self.h = config.num_heads
        self.hidden_size = config.hidden_size
        self.head_dim = config.head_dim_qk
        self.fast_zero_fill = True
        self.mask_type = config.attn_mask_type

        self.qkv_weight = torch.nn.Parameter(
            torch.empty(
                self.hidden_size * 3,
                self.hidden_size,
                device=torch.cuda.current_device(),
                dtype=params_dtype,
            )
        )
        self.qkv_bias = torch.nn.Parameter(
            torch.empty(
                self.hidden_size * 3,
                device=torch.cuda.current_device(),
                dtype=params_dtype,
            )
        )
        # Constant initialization: zero bias and all-ones weight.
        with torch.no_grad():
            self.qkv_bias.zero_()
            self.qkv_weight.fill_(1.0)

    def forward(
        self,
        inp: torch.Tensor,
        cu_seqlens,
        max_s,
    ) -> torch.Tensor:
        """Apply _custom_mha_fp8 to ``inp`` inside the module's forward context.

        Args:
            inp: input tensor handed to the autograd function.
            cu_seqlens: cumulative sequence lengths.
            max_s: maximum sequence length.

        Returns:
            The output of ``_custom_mha_fp8.apply``.
        """
        with self.prepare_forward_ctx(inp, num_gemms=3) as inp:
            out = _custom_mha_fp8.apply(
                inp,
                self.qkv_weight,
                self.qkv_bias,
                cu_seqlens,
                self.h,
                self.p_dropout,
                max_s,
                self.fast_zero_fill,
                self.fp8_meta,
                self.training,
                self.mask_type,
                self.quantizers,
            )
        return out