# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import logging
import os
import sys
import pathlib
from typing import Any, Dict, Tuple, Union

import pytest
import torch

from transformer_engine.pytorch.quantization import FP8GlobalStateManager, get_fp8_te_dtype
from transformer_engine.common import recipe
from transformer_engine.pytorch import (
    TransformerLayer,
    autocast,
    quantized_model_init,
    DotProductAttention,
    MultiheadAttention,
    get_device_compute_capability,
    Quantizer,
    is_fp8_available,
    is_bf16_available,
)
from transformer_engine.pytorch.attention.dot_product_attention import (
    _attention_backends,
)
from transformer_engine.pytorch.attention.dot_product_attention.utils import (
    FlashAttentionUtils,
    check_set_window_size,
)
from transformer_engine.pytorch.attention import RotaryPositionEmbedding
import transformer_engine.pytorch.cpp_extensions as ext
from transformer_engine.pytorch.cpp_extensions.fused_attn import (
    FusedAttnBackend,
    fused_attn_bwd,
    fused_attn_fwd,
)
from transformer_engine.pytorch.distributed import CudaRNGStatesTracker
from transformer_engine.pytorch.module.base import TransformerEngineBaseModule
from transformer_engine.pytorch.utils import (
    init_method_normal,
    scaled_init_method_normal,
)
from transformer_engine.pytorch.utils import get_cudnn_version
import transformer_engine_torch as tex
from transformer_engine.pytorch.quantized_tensor import (
    Quantizer,
    prepare_for_saving,
    restore_from_saved,
)
from torch.utils.cpp_extension import IS_HIP_EXTENSION

_current_file = pathlib.Path(__file__).resolve()
sys.path.append(str(_current_file.parent.parent))
from utils import (
    reset_rng_states,
    compare_and_assert,
    ModelConfig,
    dtype_tols,
    get_available_attention_backends,
)

# Check if hardware supports FP8 attention.
fp8_available, reason_for_no_fp8 = is_fp8_available(return_reason=True)
fp8_attn_available, reason_for_no_fp8_attn = fp8_available, reason_for_no_fp8
device_compute_capability = get_device_compute_capability()
if not IS_HIP_EXTENSION:
    if fp8_available and (
        device_compute_capability < (9, 0) or device_compute_capability >= (12, 0)
    ):
        fp8_attn_available = False
        reason_for_no_fp8_attn = (
            "FP8 attention is not supported for compute capability ="
            f" sm{device_compute_capability[0] * 10 + device_compute_capability[1]}"
        )

# Reset RNG seed and states
seed = 1234
reset_rng_states()


# Reset FP8 global state manager
@pytest.fixture(autouse=True)
def reset_global_fp8_state():
    yield
    FP8GlobalStateManager.reset()


# Define F16 data types to test
param_types = [torch.float16]
if is_bf16_available():
    param_types.append(torch.bfloat16)
param_types_lean = [torch.bfloat16]

model_configs_base = {
    # test: ModelConfig(b, sq, hq, dqk)
    "base_1_0": ModelConfig(8, 128, 16, 64),
    "base_1_1": ModelConfig(4, 128, 16, 64, max_seqlen_kv=256),
    "base_2_0": ModelConfig(2, 2048, 24, 128),
    "base_2_1": ModelConfig(1, 2048, 24, 128, max_seqlen_kv=4096),
    "base_3_0": ModelConfig(8, 1, 16, 128, max_seqlen_kv=2048),
    "base_3_1": ModelConfig(8, 1, 16, 256, max_seqlen_kv=2048),
    "base_4_0": ModelConfig(8, 1, 16, 192, max_seqlen_kv=2048),
    "base_4_1": ModelConfig(8, 128, 16, 192, max_seqlen_kv=2048),
    "base_5_0": ModelConfig(8, 1, 16, 512, max_seqlen_kv=2048),
    "base_5_1": ModelConfig(8, 128, 16, 512, max_seqlen_kv=2048),
    "base_6_0": ModelConfig(8, 1, 16, 1024, max_seqlen_kv=2048),
    "base_6_1": ModelConfig(8, 128, 16, 1024, max_seqlen_kv=2048),
}


@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model_configs", [model_configs_base])
@pytest.mark.parametrize("model", model_configs_base.keys())
@pytest.mark.parametrize("ckpt_attn", [False])
@pytest.mark.parametrize("workspace_opt", [True, False])
@pytest.mark.parametrize("qkv_layout", [None])
@pytest.mark.parametrize("swa", [False])
@pytest.mark.parametrize("pad_between_seqs", [False])
def test_dot_product_attention(
    dtype,
    model_configs,
    model,
    ckpt_attn,
    workspace_opt,
    qkv_layout,
    swa,
    pad_between_seqs,
):
    """Test DotProductAttention module"""

    # Get configs
    tols = dict(atol=1e-3, rtol=1e-3)
    if dtype == torch.bfloat16:
        tols = dict(atol=1.5e-2, rtol=1.5e-2)
    config = model_configs[model]
    is_mla = config.head_dim_qk != config.head_dim_v
    is_mqa_gqa = config.num_heads != config.num_gqa_groups
    if qkv_layout is None:
        if config.attn_type == "self":
            qkv_layout = "sb3hd" if not is_mla and not is_mqa_gqa else "sbhd_sbhd_sbhd"
        else:
            qkv_layout = "bshd_bs2hd" if not is_mla and not is_mqa_gqa else "bshd_bshd_bshd"
    if "3" in qkv_layout and config.attn_type == "cross":
        pytest.skip("No need to test this layout for cross attention")

    if config.window_size == (-1, -1) and swa:
        config.window_size = [2, 2]
    config.window_size = check_set_window_size(config.attn_mask_type, config.window_size)
    qkv_format = qkv_layout.replace("3", "").replace("2", "").split("_")[0]
    if qkv_format == "thd" and "padding" not in config.attn_mask_type:
        config.attn_mask_type = (
            "padding_" + config.attn_mask_type if config.attn_mask_type != "no_mask" else "padding"
        )

    # Get backends
    is_training = True
    available_backends, _, fused_attn_backends = get_available_attention_backends(
        config,
        qkv_dtype=dtype,
        qkv_layout=qkv_layout,
        pad_between_seqs=pad_between_seqs,
        is_training=is_training,
    )
    flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
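    # If no fused backend supports training for this config, re-query for inference
    # only so that forward-only results can still be compared across backends.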
    if not fused_attn_supported:
        is_training = False
        available_backends, _, fused_attn_backends = get_available_attention_backends(
            config,
            qkv_dtype=dtype,
            qkv_layout=qkv_layout,
            pad_between_seqs=pad_between_seqs,
            is_training=is_training,
        )
        flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends

    # FlashAttention does not support pad_between_seqs, but _run_dot_product_attention
    # manually pads and unpads the input and output of FlashAttention for testing purposes
    if (
        pad_between_seqs
        and FlashAttentionUtils.is_installed
        and not (
            config.max_seqlen_q != config.max_seqlen_kv
            and config.attn_mask_type in ["causal", "padding_causal"]
        )
        and (config.window_size[0] == -1 or FlashAttentionUtils.v2_3_plus)
    ):
        flash_attn_supported = True

    # Skip if only unfused backend is supported
    if (len(fused_attn_backends) + flash_attn_supported + unfused_attn_supported) < 2:
        pytest.skip("Less than two backends to compare.")

    # UnfusedDotProductAttention backend
    if unfused_attn_supported:
        unfused_attn_fwd, unfused_max_logit, unfused_attn_bwd = _run_dot_product_attention(
            dtype,
            config,
            "UnfusedDotProductAttention",
            ckpt_attn,
            qkv_layout,
            workspace_opt,
            pad_between_seqs,
            is_training,
        )

    # FusedAttention backend
    if fused_attn_supported:
        if len(fused_attn_backends) == 1:
            fused_attn_fwd, fused_max_logit, fused_attn_bwd = _run_dot_product_attention(
                dtype,
                config,
                "FusedAttention",
                ckpt_attn,
                qkv_layout,
                workspace_opt,
                pad_between_seqs,
                is_training,
            )
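        # When two fused sub-backends are available, run both by forcing
        # NVTE_FUSED_ATTN_BACKEND and compare them against each other below.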
        if len(fused_attn_backends) == 2:
            os.environ["NVTE_FUSED_ATTN_BACKEND"] = "0"
            fused_attn_fwd, _, fused_attn_bwd = _run_dot_product_attention(
                dtype,
                config,
                "FusedAttention",
                ckpt_attn,
                qkv_layout,
                workspace_opt,
                pad_between_seqs,
                is_training,
            )
            os.environ["NVTE_FUSED_ATTN_BACKEND"] = "1"
            fused_attn_fwd_1, _, fused_attn_bwd_1 = _run_dot_product_attention(
                dtype,
                config,
                "FusedAttention",
                ckpt_attn,
                qkv_layout,
                workspace_opt,
                pad_between_seqs,
                is_training,
            )

    # FlashAttention backend
    if flash_attn_supported:
        flash_attn_fwd, _, flash_attn_bwd = _run_dot_product_attention(
            dtype,
            config,
            "FlashAttention",
            ckpt_attn,
            qkv_layout,
            workspace_opt,
            pad_between_seqs,
            is_training,
        )

    # Compare results
    logging.info(f"[test_dot_product_attention]: is_training = {is_training}")
    if unfused_attn_supported and flash_attn_supported:
        logging.info("[test_dot_product_attention]: unfused attn vs flash attn")
        torch.testing.assert_close(flash_attn_fwd, unfused_attn_fwd, **tols)
        for i, _ in enumerate(flash_attn_bwd):
            torch.testing.assert_close(unfused_attn_bwd[i], flash_attn_bwd[i], **tols)
    if unfused_attn_supported and fused_attn_supported:
        logging.info("[test_dot_product_attention]: unfused attn vs fused attn")
        torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, **tols)
        if config.return_max_logit:
            torch.testing.assert_close(fused_max_logit, unfused_max_logit, **tols)
        for i, _ in enumerate(unfused_attn_bwd):
            torch.testing.assert_close(fused_attn_bwd[i], unfused_attn_bwd[i], **tols)
    if fused_attn_supported and flash_attn_supported:
        logging.info("[test_dot_product_attention]: fused attn vs flash attn")
        torch.testing.assert_close(fused_attn_fwd, flash_attn_fwd, **tols)
        for i, _ in enumerate(flash_attn_bwd):
            torch.testing.assert_close(fused_attn_bwd[i], flash_attn_bwd[i], **tols)
    if fused_attn_supported and len(fused_attn_backends) == 2:
        logging.info("[test_dot_product_attention]: fused attn backend 0 vs 1")
        torch.testing.assert_close(fused_attn_fwd, fused_attn_fwd_1, **tols)
        for i, _ in enumerate(fused_attn_bwd):
            torch.testing.assert_close(fused_attn_bwd[i], fused_attn_bwd_1[i], **tols)


@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model_configs", [model_configs_base])
@pytest.mark.parametrize("model", ["base_1_1", "base_2_1"])
def test_dpa_checkpoint(dtype, model_configs, model):
    """Test DotProductAttention module with checkpointing"""
    test_dot_product_attention(dtype, model_configs, model, True, True, None, False, False)


model_configs_max_logit = {
    # test: ModelConfig(b, sq, hq, dqk)
    "max_logit_1": ModelConfig(1, 2048, 24, 128, max_seqlen_kv=4096),
    "max_logit_2": ModelConfig(2, 2048, 24, 128, attn_mask_type="causal"),
    "max_logit_3": ModelConfig(2, 1, 16, 128, max_seqlen_kv=2048, attn_mask_type="padding_causal"),
    "max_logit_4": ModelConfig(
        8, 128, 16, 192, max_seqlen_kv=2048, attn_bias_type="post_scale_bias"
    ),
    "max_logit_5": ModelConfig(
        8, 128, 16, 512, max_seqlen_kv=2048, attn_mask_type="causal", window_size=(20, 0)
    ),
    "max_logit_6": ModelConfig(8, 1, 16, 1024, max_seqlen_kv=2048),
}


@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model_configs", [model_configs_max_logit])
@pytest.mark.parametrize("model", model_configs_max_logit.keys())
@pytest.mark.parametrize("qkv_layout", ["sbhd_sbhd_sbhd", "thd_thd_thd"])
def test_dpa_max_logit(dtype, model_configs, model, qkv_layout):
    """Test DotProductAttention module with return_max_logit enabled"""
    config = model_configs[model]
    config.return_max_logit = True
    test_dot_product_attention(dtype, model_configs, model, False, True, qkv_layout, False, False)


model_configs_num_splits = {
    # test: ModelConfig(b, sq, hq, dqk)
    "num_splits_1_0": ModelConfig(2, 2048, 24, 128, num_splits=2),
    "num_splits_1_1": ModelConfig(1, 2048, 24, 128, max_seqlen_kv=4096, num_splits=4),
}


@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model_configs", [model_configs_num_splits])
@pytest.mark.parametrize("model", model_configs_num_splits.keys())
def test_dpa_num_splits(dtype, model_configs, model):
    """Test DotProductAttention with FlashAttention-3 num_splits enabled"""
    test_dot_product_attention(
        dtype,
        model_configs,
        model,
        False,
        True,
        None,
        False,
        False,
    )


model_configs_softmax = {
    # test: ModelConfig(b, sq, hq, dqk)
    "softmax_1_0": ModelConfig(2, 2048, 64, 64, num_gqa_groups=8),
    "softmax_1_1": ModelConfig(2, 2048, 64, 64, num_gqa_groups=8, softmax_type="off-by-one"),
    "softmax_1_2": ModelConfig(2, 2048, 64, 64, num_gqa_groups=8, softmax_type="learnable"),
    "softmax_2_0": ModelConfig(2, 2048, 64, 64, num_gqa_groups=8, attn_mask_type="causal"),
    "softmax_2_1": ModelConfig(
        2, 2048, 64, 64, num_gqa_groups=8, attn_mask_type="causal", softmax_type="off-by-one"
    ),
    "softmax_2_2": ModelConfig(
        2, 2048, 64, 64, num_gqa_groups=8, attn_mask_type="causal", softmax_type="learnable"
    ),
    "softmax_3_0": ModelConfig(2, 2048, 64, 64, num_gqa_groups=8, attn_mask_type="padding"),
    "softmax_3_1": ModelConfig(
        2, 2048, 64, 64, num_gqa_groups=8, attn_mask_type="padding", softmax_type="off-by-one"
    ),
    "softmax_3_2": ModelConfig(
        2, 2048, 64, 64, num_gqa_groups=8, attn_mask_type="padding", softmax_type="learnable"
    ),
    "softmax_4_0": ModelConfig(
        2, 2048, 64, 64, num_gqa_groups=8, window_size=(128, 0), attn_mask_type="causal"
    ),
    "softmax_4_1": ModelConfig(
        2,
        2048,
        64,
        64,
        num_gqa_groups=8,
        window_size=(128, 0),
        attn_mask_type="causal",
        softmax_type="off-by-one",
    ),
    "softmax_4_2": ModelConfig(
        2,
        2048,
        64,
        64,
        num_gqa_groups=8,
        window_size=(128, 0),
        attn_mask_type="causal",
        softmax_type="learnable",
    ),
    "softmax_5_0": ModelConfig(
        2, 2048, 64, 64, num_gqa_groups=8, window_size=(128, 0), attn_mask_type="padding_causal"
    ),
    "softmax_5_1": ModelConfig(
        2,
        2048,
        64,
        64,
        num_gqa_groups=8,
        window_size=(128, 0),
        attn_mask_type="padding_causal",
        softmax_type="off-by-one",
    ),
    "softmax_5_2": ModelConfig(
        2,
        2048,
        64,
        64,
        num_gqa_groups=8,
        window_size=(128, 0),
        attn_mask_type="padding_causal",
        softmax_type="learnable",
    ),
}


@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("model_configs", [model_configs_softmax])
@pytest.mark.parametrize("model", model_configs_softmax.keys())
def test_dpa_softmax(dtype, model_configs, model):
    """Test DotProductAttention module with different softmax types"""
    test_dot_product_attention(
        dtype, model_configs, model, True, True, "bshd_bshd_bshd", False, False
    )


model_configs_mla = {
    # TODO: FlashAttention on ROCm only supports MLA with head_dim_qk == head_dim_v
    #    test:             b,  h, hg, dqk, sq, skv,   p,      mask,      bias   # attn , backend
    # "mla_1_0": ModelConfig(8, 128, 16, 64, head_dim_v=128),  # self , 0
    # "mla_1_1": ModelConfig(4, 128, 16, 64, max_seqlen_kv=256, head_dim_v=128),  # cross, 0
    # "mla_1_2": ModelConfig(4, 128, 16, 192, max_seqlen_kv=256, head_dim_v=128),  # cross, 0
    # "mla_2_0": ModelConfig(2, 2048, 24, 128, attn_mask_type="causal", head_dim_v=64),  # self , 1
    # "mla_2_1": ModelConfig(
    #     1, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal", head_dim_v=64
    # ),  # cross, 1
    # "mla_2_2": ModelConfig(
    #     1, 2048, 24, 192, max_seqlen_kv=4096, attn_mask_type="causal", head_dim_v=128
    # ),  # cross, 1
    # "mla_3_0": ModelConfig(8, 1, 16, 128, max_seqlen_kv=2048, head_dim_v=64),  # inference
    # "mla_3_1": ModelConfig(8, 1, 16, 256, max_seqlen_kv=2048, head_dim_v=128),  # inference
    # "mla_3_2": ModelConfig(8, 1, 16, 192, max_seqlen_kv=2048, head_dim_v=128),  # inference
    # "mla_3_3": ModelConfig(8, 1, 16, 160, max_seqlen_kv=2048, head_dim_v=128),  # inference
    "mla_3_4": ModelConfig(8, 1, 16, 160, max_seqlen_kv=2048, head_dim_v=160),  # inference
}


@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model_configs", [model_configs_mla])
@pytest.mark.parametrize("model", model_configs_mla.keys())
def test_dpa_mla(dtype, model_configs, model):
    """Test DotProductAttention module with Multi-Latent Attention (MLA)"""
    test_dot_product_attention(dtype, model_configs, model, True, True, None, False, False)


model_configs_mask = {
    # test: ModelConfig(b, sq, hq, dqk)
    "mask_1_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="causal"),
    "mask_1_1": ModelConfig(2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="causal"),
    "mask_1_2": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal"),
    "mask_2_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="causal_bottom_right"),
    "mask_2_1": ModelConfig(
        2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="causal_bottom_right"
    ),
    "mask_2_2": ModelConfig(
        2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal_bottom_right"
    ),
    "mask_3_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding"),
    "mask_3_1": ModelConfig(2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding"),
    "mask_3_2": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding"),
    "mask_4_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding_causal"),
    "mask_4_1": ModelConfig(2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding_causal"),
    "mask_4_2": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding_causal"),
    "mask_5_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding_causal_bottom_right"),
    "mask_5_1": ModelConfig(
        2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding_causal_bottom_right"
    ),
    "mask_5_2": ModelConfig(
        2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding_causal_bottom_right"
    ),
    "mask_6_0": ModelConfig(2, 1, 16, 128, max_seqlen_kv=2048, attn_mask_type="causal"),
    "mask_6_1": ModelConfig(2, 1, 16, 256, max_seqlen_kv=2048, attn_mask_type="causal"),
    "mask_7_0": ModelConfig(
        2, 1, 16, 128, max_seqlen_kv=2048, attn_mask_type="causal_bottom_right"
    ),
    "mask_7_1": ModelConfig(
        2, 1, 16, 256, max_seqlen_kv=2048, attn_mask_type="causal_bottom_right"
    ),
    "mask_8_0": ModelConfig(2, 1, 24, 128, max_seqlen_kv=2048, attn_mask_type="padding"),
    "mask_8_1": ModelConfig(2, 1, 16, 256, max_seqlen_kv=2048, attn_mask_type="padding"),
    "mask_9_0": ModelConfig(2, 1, 24, 128, max_seqlen_kv=2048, attn_mask_type="padding_causal"),
    "mask_9_1": ModelConfig(2, 1, 16, 256, max_seqlen_kv=2048, attn_mask_type="padding_causal"),
    "mask_10_0": ModelConfig(
        2, 1, 24, 128, max_seqlen_kv=2048, attn_mask_type="padding_causal_bottom_right"
    ),
    "mask_10_1": ModelConfig(
        2, 1, 16, 256, max_seqlen_kv=2048, attn_mask_type="padding_causal_bottom_right"
    ),
}


@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_mask])
@pytest.mark.parametrize("model", model_configs_mask.keys())
def test_dpa_mask(dtype, model_configs, model):
    """Test DotProductAttention module with different mask types"""
    test_dot_product_attention(dtype, model_configs, model, False, True, None, False, False)


model_configs_bias = {
    # test: ModelConfig(b, sq, hq, dqk)
    "bias_1_0": ModelConfig(4, 128, 16, 64, attn_bias_type="post_scale_bias"),
    "bias_1_1": ModelConfig(2, 128, 16, 64, max_seqlen_kv=256, attn_bias_type="post_scale_bias"),
    "bias_1_2": ModelConfig(4, 2048, 24, 128, attn_bias_type="post_scale_bias"),
    "bias_1_3": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_bias_type="post_scale_bias"),
    "bias_1_4": ModelConfig(4, 2048, 24, 128, attn_bias_type="alibi"),
    "bias_1_5": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_bias_type="alibi"),
    "bias_2_0": ModelConfig(
        4, 128, 16, 64, attn_mask_type="padding", attn_bias_type="post_scale_bias"
    ),
    "bias_2_1": ModelConfig(
        2,
        128,
        16,
        64,
        max_seqlen_kv=256,
        attn_mask_type="padding",
        attn_bias_type="post_scale_bias",
    ),
    "bias_2_2": ModelConfig(
        4, 2048, 24, 128, attn_mask_type="padding", attn_bias_type="post_scale_bias"
    ),
    "bias_2_3": ModelConfig(
        2,
        2048,
        24,
        128,
        max_seqlen_kv=4096,
        attn_mask_type="padding",
        attn_bias_type="post_scale_bias",
    ),
    "bias_2_4": ModelConfig(4, 2048, 24, 128, attn_mask_type="padding", attn_bias_type="alibi"),
    "bias_2_5": ModelConfig(
        2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding", attn_bias_type="alibi"
    ),
    "bias_3_0": ModelConfig(
        4, 128, 16, 64, attn_mask_type="causal", attn_bias_type="post_scale_bias"
    ),
    "bias_3_1": ModelConfig(
        2, 128, 16, 64, max_seqlen_kv=256, attn_mask_type="causal", attn_bias_type="post_scale_bias"
    ),
    "bias_3_2": ModelConfig(
        4, 2048, 24, 128, attn_mask_type="causal", attn_bias_type="post_scale_bias"
    ),
    "bias_3_3": ModelConfig(
        2,
        2048,
        24,
        128,
        max_seqlen_kv=4096,
        attn_mask_type="causal",
        attn_bias_type="post_scale_bias",
    ),
    "bias_3_4": ModelConfig(4, 2048, 24, 128, attn_mask_type="causal", attn_bias_type="alibi"),
    "bias_3_5": ModelConfig(
        2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal", attn_bias_type="alibi"
    ),
    "bias_4_0": ModelConfig(
        4, 128, 16, 64, attn_mask_type="padding_causal", attn_bias_type="post_scale_bias"
    ),
    "bias_4_1": ModelConfig(
        2,
        128,
        16,
        64,
        max_seqlen_kv=256,
        attn_mask_type="padding_causal",
        attn_bias_type="post_scale_bias",
    ),
    "bias_4_2": ModelConfig(
        4, 2048, 24, 128, attn_mask_type="padding_causal", attn_bias_type="post_scale_bias"
    ),
    "bias_4_3": ModelConfig(
        2,
        2048,
        24,
        128,
        max_seqlen_kv=4096,
        attn_mask_type="padding_causal",
        attn_bias_type="post_scale_bias",
    ),
    "bias_4_4": ModelConfig(
        4, 2048, 24, 128, attn_mask_type="padding_causal", attn_bias_type="alibi"
    ),
    "bias_4_5": ModelConfig(
        2,
        2048,
        24,
        128,
        max_seqlen_kv=4096,
        attn_mask_type="padding_causal",
        attn_bias_type="alibi",
    ),
}


@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_bias])
@pytest.mark.parametrize("model", model_configs_bias.keys())
def test_dpa_bias(dtype, model_configs, model):
    """Test DotProductAttention module with different bias types"""
    test_dot_product_attention(dtype, model_configs, model, False, True, None, False, False)


model_configs_bias_shapes = {
    # test: ModelConfig(b, sq, hq, dqk)
    "bias_1_0": ModelConfig(4, 128, 16, 64, attn_bias_type="post_scale_bias", bias_shape="11ss"),
    "bias_1_1": ModelConfig(2, 128, 16, 64, attn_bias_type="post_scale_bias", bias_shape="1hss"),
    "bias_1_2": ModelConfig(4, 2048, 24, 128, attn_bias_type="post_scale_bias", bias_shape="b1ss"),
    "bias_1_3": ModelConfig(2, 2048, 24, 128, attn_bias_type="post_scale_bias", bias_shape="bhss"),
    "bias_1_4": ModelConfig(
        4,
        2048,
        24,
        128,
        attn_mask_type="causal",
        attn_bias_type="alibi",
        bias_shape="1hss",
        alibi_type="custom",
    ),
    "bias_1_5": ModelConfig(
        2,
        2048,
        24,
        128,
        attn_mask_type="causal",
        attn_bias_type="alibi",
        bias_shape="bhss",
        alibi_type="custom",
    ),
}


@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_bias_shapes])
@pytest.mark.parametrize("model", model_configs_bias_shapes.keys())
def test_dpa_bias_shapes(dtype, model_configs, model):
    """Test DotProductAttention module with different bias types and shapes"""
    test_dot_product_attention(dtype, model_configs, model, False, True, None, False, False)


model_configs_swa = {
    # test: ModelConfig(b, sq, hq, dqk)
    "swa_1_1": ModelConfig(2, 2048, 16, 64),
    "swa_1_2": ModelConfig(2, 2048, 24, 128, num_gqa_groups=4),
    "swa_1_3": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096),
    "swa_2_1": ModelConfig(2, 2048, 16, 64, attn_mask_type="causal"),
    "swa_2_2": ModelConfig(2, 2048, 24, 128, num_gqa_groups=4, attn_mask_type="causal"),
    "swa_2_3": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal"),
    "swa_3_1": ModelConfig(2, 2048, 16, 64, attn_mask_type="causal_bottom_right"),
    "swa_3_2": ModelConfig(
        2, 2048, 24, 128, num_gqa_groups=4, attn_mask_type="causal_bottom_right"
    ),
    "swa_3_3": ModelConfig(
        2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal_bottom_right"
    ),
    "swa_4_1": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding"),
    "swa_4_2": ModelConfig(2, 2048, 24, 128, num_gqa_groups=4, attn_mask_type="padding"),
    "swa_4_3": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding"),
    "swa_5_1": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding_causal"),
    "swa_5_2": ModelConfig(2, 2048, 24, 128, num_gqa_groups=4, attn_mask_type="padding_causal"),
    "swa_5_3": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding_causal"),
    "swa_6_1": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding_causal_bottom_right"),
    "swa_6_2": ModelConfig(
        2, 2048, 24, 128, num_gqa_groups=4, attn_mask_type="padding_causal_bottom_right"
    ),
    "swa_6_3": ModelConfig(
        2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding_causal_bottom_right"
    ),
}


@pytest.mark.skipif(not FlashAttentionUtils.v2_3_plus, reason="Flash-attn 2.3+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_swa])
@pytest.mark.parametrize("model", model_configs_swa.keys())
def test_dpa_sliding_window(dtype, model_configs, model):
    """Test DotProductAttention module with sliding window attention"""
    test_dot_product_attention(dtype, model_configs, model, False, True, None, True, False)


model_configs_alibi_slopes = {
    # test: ModelConfig(b, sq, hq, dqk)
    "alibi_1_0": ModelConfig(
        2, 128, 16, 64, attn_mask_type="causal", attn_bias_type="alibi", alibi_type="vanilla"
    ),
    "alibi_1_1": ModelConfig(
        1,
        128,
        16,
        64,
        max_seqlen_kv=256,
        attn_mask_type="causal",
        attn_bias_type="alibi",
        alibi_type="vanilla",
    ),
    "alibi_2_0": ModelConfig(
        2, 1024, 24, 128, attn_mask_type="causal", attn_bias_type="alibi", alibi_type="custom"
    ),
    "alibi_2_1": ModelConfig(
        1,
        1024,
        24,
        128,
        max_seqlen_kv=2048,
        attn_mask_type="causal",
        attn_bias_type="alibi",
        alibi_type="custom",
    ),
}


@pytest.mark.skipif(not FlashAttentionUtils.v2_3_plus, reason="Flash-attn 2.3+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_alibi_slopes])
@pytest.mark.parametrize("model", model_configs_alibi_slopes.keys())
def test_dpa_alibi_slopes(dtype, model_configs, model):
    """Test DotProductAttention module with ALiBi slopes"""
    test_dot_product_attention(dtype, model_configs, model, False, True, None, False, False)


qkv_layouts = [
    "sb3hd",
    "sbh3d",
    "sbhd_sb2hd",
    "sbhd_sbh2d",
    "sbhd_sbhd_sbhd",
    "bs3hd",
    "bsh3d",
    "bshd_bs2hd",
    "bshd_bsh2d",
    "bshd_bshd_bshd",
]


model_configs_layout = {
    # test: ModelConfig(b, sq, hq, dqk)
    "layout_0_0": ModelConfig(2, 128, 16, 64),
    "layout_0_1": ModelConfig(
        2, 128, 16, 64, attn_mask_type="causal", attn_bias_type="post_scale_bias"
    ),
    "layout_0_2": ModelConfig(1, 128, 16, 64, max_seqlen_kv=256, attn_mask_type="padding"),
    "layout_0_3": ModelConfig(
        1,
        128,
        16,
        64,
        max_seqlen_kv=256,
        attn_mask_type="padding_causal",
        attn_bias_type="post_scale_bias",
    ),
    "layout_1_0": ModelConfig(2, 2048, 24, 128),
    "layout_1_1": ModelConfig(
        2, 2048, 24, 128, attn_mask_type="causal", attn_bias_type="post_scale_bias"
    ),
    "layout_1_2": ModelConfig(1, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding"),
    "layout_1_3": ModelConfig(
        1,
        2048,
        24,
        128,
        max_seqlen_kv=4096,
        attn_mask_type="padding_causal",
        attn_bias_type="post_scale_bias",
    ),
    "layout_2_0": ModelConfig(2, 1, 16, 256, max_seqlen_kv=2048),
    "layout_2_1": ModelConfig(
        2, 2048, 24, 256, attn_mask_type="causal", attn_bias_type="post_scale_bias"
    ),
}


@pytest.mark.skipif(get_cudnn_version() < (8, 9, 5), reason="cuDNN 8.9.5+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_layout])
@pytest.mark.parametrize("model", model_configs_layout.keys())
@pytest.mark.parametrize("qkv_layout", qkv_layouts)
def test_dpa_qkv_layout(dtype, model_configs, model, qkv_layout):
    """Test DotProductAttention module with different QKV layouts"""
    test_dot_product_attention(dtype, model_configs, model, False, True, qkv_layout, False, False)


qkv_layouts_thd = ["t3hd", "th3d", "thd_t2hd", "thd_th2d", "thd_thd_thd"]
model_configs_layout_thd = {
    # test: ModelConfig(b, sq, hq, dqk)
    "layout_0_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding"),
    "layout_0_1": ModelConfig(2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding"),
    "layout_0_2": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding"),
    "layout_1_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding_causal"),
    "layout_1_1": ModelConfig(2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding_causal"),
    "layout_1_2": ModelConfig(
        2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding_causal"
    ),
    "layout_2_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding_causal_bottom_right"),
    "layout_2_1": ModelConfig(
        2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding_causal_bottom_right"
    ),
    "layout_2_2": ModelConfig(
        2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding_causal_bottom_right"
    ),
    "layout_3_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding", window_size=(4, 4)),
    "layout_3_1": ModelConfig(
        2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding", window_size=(4, 4)
    ),
    "layout_3_2": ModelConfig(
        2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding", window_size=(4, 4)
    ),
    "layout_4_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding_causal", window_size=(4, 0)),
    "layout_4_1": ModelConfig(
        2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding_causal", window_size=(4, 0)
    ),
    "layout_4_2": ModelConfig(
        2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding_causal", window_size=(4, 0)
    ),
    "layout_5_0": ModelConfig(
        2, 2048, 16, 64, attn_mask_type="padding_causal_bottom_right", window_size=(4, 0)
    ),
    "layout_5_1": ModelConfig(
        2,
        2048,
        24,
        128,
        num_gqa_groups=1,
        attn_mask_type="padding_causal_bottom_right",
        window_size=(4, 0),
    ),
    "layout_5_2": ModelConfig(
        2,
        2048,
        24,
        128,
        max_seqlen_kv=4096,
        attn_mask_type="padding_causal_bottom_right",
        window_size=(4, 0),
    ),
}


@pytest.mark.skipif(get_cudnn_version() < (9, 0, 0), reason="cuDNN 9.0.0+ is required.")
@pytest.mark.skipif(
    get_device_compute_capability() < (9, 0), reason="THD is only supported on Hopper+."
)
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_layout_thd])
@pytest.mark.parametrize("model", model_configs_layout_thd.keys())
@pytest.mark.parametrize("qkv_layout", qkv_layouts_thd)
def test_dpa_qkv_layout_thd(dtype, model_configs, model, qkv_layout):
    """Test DotProductAttention module with different QKV layouts"""
    config = model_configs[model]
    if config.num_heads != config.num_gqa_groups and "3" in qkv_layout:
        pytest.skip("qkv_layout not applicable for MQA/GQA")
    logging.info("[test_dpa_qkv_layout_thd]: pad_between_seqs = True")
    pad_between_seqs = True
    test_dot_product_attention(
        dtype, model_configs, model, False, True, qkv_layout, False, pad_between_seqs
    )
    if get_cudnn_version() >= (9, 3, 0):
        logging.info("[test_dpa_qkv_layout_thd]: pad_between_seqs = False")
        # cuDNN 9.3.0+ is required to run pad_between_seqs = False/True in the same run
        pad_between_seqs = False
        test_dot_product_attention(
            dtype, model_configs, model, False, True, qkv_layout, False, pad_between_seqs
        )


def _run_dot_product_attention(
    dtype: torch.dtype,
    config: ModelConfig,
    backend: str,
    ckpt_attn: bool,
    qkv_layout: str,
    workspace_opt: bool,
    pad_between_seqs: bool,
    is_training: bool,
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
    """Run DotProductAttention module with one forward pass and one backward pass"""
    # Set RNG and environment variables
    reset_rng_states()
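    # Select the backend under test via NVTE_FLASH_ATTN / NVTE_FUSED_ATTN and force
    # backend selection to be re-evaluated for this configuration.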
    os.environ["NVTE_FLASH_ATTN"] = "0"
    os.environ["NVTE_FUSED_ATTN"] = "0"
    if backend == "FlashAttention":
        os.environ["NVTE_FLASH_ATTN"] = "1"
    if backend == "FusedAttention":
        os.environ["NVTE_FUSED_ATTN"] = "1"
        os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] = "1" if workspace_opt else "0"
897
    _attention_backends["backend_selection_requires_update"] = True
898

    # Create seqlens
    qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()])
    if "padding" in config.attn_mask_type or qkv_format == "thd":
        if config.attn_type == "self":
            seqlens_q = torch.randint(
                1, config.max_seqlen_q, [config.batch_size], dtype=torch.int32, device="cuda"
            )
            seqlens_kv = seqlens_q
        if config.attn_type == "cross":
            if config.max_seqlen_q > 1:
                seqlens_q = torch.randint(
                    1, config.max_seqlen_q, [config.batch_size], dtype=torch.int32, device="cuda"
                )
            else:
                seqlens_q = torch.ones([config.batch_size], dtype=torch.int32, device="cuda")
            seqlens_kv = torch.randint(
                1, config.max_seqlen_kv, [config.batch_size], dtype=torch.int32, device="cuda"
            )
    else:
        seqlens_q = torch.full(
            [config.batch_size], config.max_seqlen_q, dtype=torch.int32, device="cuda"
        )
        seqlens_kv = torch.full(
            [config.batch_size], config.max_seqlen_kv, dtype=torch.int32, device="cuda"
        )
    cu_seqlens_q = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda")
    cu_seqlens_kv = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda")
    cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)
    cu_seqlens_kv[1:] = torch.cumsum(seqlens_kv, dim=0)

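    # Optionally insert up to max_pad_len padding tokens after each sequence; the
    # *_after_pad tensors hold the padded lengths/offsets used by the FusedAttention path.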
    seqlens_q_after_pad = seqlens_q.clone()
    seqlens_kv_after_pad = seqlens_kv.clone()
    cu_seqlens_q_after_pad = cu_seqlens_q.clone()
    cu_seqlens_kv_after_pad = cu_seqlens_kv.clone()
    pad_len = [0] * config.batch_size
    if pad_between_seqs:
        max_pad_len = 3
        pad_len = torch.randint(0, max_pad_len + 1, [config.batch_size], device="cuda")
        seqlens_q_after_pad = seqlens_q + pad_len
        seqlens_kv_after_pad = seqlens_kv + pad_len
        cu_seqlens_q_after_pad[1:] = torch.cumsum(seqlens_q_after_pad, dim=0)
        cu_seqlens_kv_after_pad[1:] = torch.cumsum(seqlens_kv_after_pad, dim=0)

    # Create attention mask if padding
    attention_mask = None
    if "padding" in config.attn_mask_type:
        if config.attn_type == "self":
            attention_mask_q = torch.Tensor([]).to(dtype=torch.bool)
            for i in range(config.batch_size):
                attention_mask_q = torch.cat(
                    [
                        attention_mask_q,
                        torch.Tensor(
                            [False] * seqlens_q[i] + [True] * (config.max_seqlen_q - seqlens_q[i])
                        )
                        .to(dtype=torch.bool)
                        .unsqueeze(0)
                        .unsqueeze(0)
                        .unsqueeze(0),
                    ],
                    dim=0,
                )
            attention_mask = attention_mask_q.to(device="cuda")
        if config.attn_type == "cross":
            attention_mask_q = torch.Tensor([]).to(dtype=torch.bool)
            attention_mask_kv = torch.Tensor([]).to(dtype=torch.bool)
            for i in range(config.batch_size):
                attention_mask_q = torch.cat(
                    [
                        attention_mask_q,
                        torch.Tensor(
                            [False] * seqlens_q[i] + [True] * (config.max_seqlen_q - seqlens_q[i])
                        )
                        .to(dtype=torch.bool)
                        .unsqueeze(0)
                        .unsqueeze(0)
                        .unsqueeze(0),
                    ],
                    dim=0,
                )
                attention_mask_kv = torch.cat(
                    [
                        attention_mask_kv,
                        torch.Tensor(
                            [False] * seqlens_kv[i]
                            + [True] * (config.max_seqlen_kv - seqlens_kv[i])
                        )
                        .to(dtype=torch.bool)
                        .unsqueeze(0)
                        .unsqueeze(0)
                        .unsqueeze(0),
                    ],
                    dim=0,
                )
            attention_mask = (
                attention_mask_q.to(device="cuda"),
                attention_mask_kv.to(device="cuda"),
            )

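    # For custom ALiBi, draw random per-head ("1hss") or per-batch-and-head ("bhss")
    # slopes to pass through to the attention call.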
    alibi_slopes = None
    if config.attn_bias_type == "alibi" and config.alibi_type == "custom":
        if config.bias_shape == "1hss":
            alibi_slopes = (
                torch.randn(config.num_heads).abs().to(dtype=torch.float32, device="cuda")
            )
        if config.bias_shape == "bhss":
            alibi_slopes = (
                torch.randn(config.batch_size, config.num_heads)
                .abs()
                .to(dtype=torch.float32, device="cuda")
            )

    # Create input tensors
    dim_to_num = {
        "b": config.batch_size,
        "sq": config.max_seqlen_q,
        "skv": config.max_seqlen_kv,
        "h": config.num_heads,
        "hg": config.num_gqa_groups,
        "dqk": config.head_dim_qk,
        "dv": config.head_dim_v,
        "t": cu_seqlens_q_after_pad[-1],
        "tg": cu_seqlens_kv_after_pad[-1],
        "3": 3,
        "2": 2,
        "1": 1,
    }
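    # Build one random tensor per "_"-separated group in qkv_layout, zero out padding
    # positions for THD inputs, and split fused groups (3hd, h3d, 2hd, h2d) into
    # separate q/k/v tensors; inp keeps the padding, inp_orig strips it.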
    inp = []
    inp_orig = []
    for i, layout in enumerate(qkv_layout.split("_")):
        layout = "_".join(layout)
        if i == 0:
            layout = layout.replace("s", "sq")
        else:
            layout = layout.replace("s", "skv")
            layout = layout.replace("h", "hg")
            layout = layout.replace("t", "tg")
        if i == 2:
            layout = layout.replace("d", "dv")
        else:
            layout = layout.replace("d", "dqk")
        tensor_shape = [dim_to_num[j] for j in layout.split("_")]
        tensor = 0.1 * torch.randn(tensor_shape, dtype=dtype, device="cuda")
        # tensor: with padding tokens
        # tensor_orig: without padding tokens
        tensor_orig = tensor
        if qkv_format == "thd" and pad_between_seqs:
            tensor_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
            if layout in ["t_h_dqk", "t_3_h_dqk", "t_h_3_dqk"]:
                for i in range(1, config.batch_size + 1):
                    valid_range = (
                        cu_seqlens_q_after_pad[i - 1],
                        cu_seqlens_q_after_pad[i] - pad_len[i - 1],
                    )
                    pad_range = (
                        cu_seqlens_q_after_pad[i] - pad_len[i - 1],
                        cu_seqlens_q_after_pad[i],
                    )
                    tensor[pad_range[0] : pad_range[1]] = 0.0
                    tensor_orig = torch.cat(
                        [tensor_orig, tensor[valid_range[0] : valid_range[1]]], dim=0
                    )
            if layout in ["tg_hg_dqk", "tg_2_hg_dqk", "tg_hg_2_dqk", "tg_hg_dv"]:
                for i in range(1, config.batch_size + 1):
                    valid_range = (
                        cu_seqlens_kv_after_pad[i - 1],
                        cu_seqlens_kv_after_pad[i] - pad_len[i - 1],
                    )
                    pad_range = (
                        cu_seqlens_kv_after_pad[i] - pad_len[i - 1],
                        cu_seqlens_kv_after_pad[i],
                    )
                    tensor[pad_range[0] : pad_range[1]] = 0.0
                    tensor_orig = torch.cat(
                        [tensor_orig, tensor[valid_range[0] : valid_range[1]]], dim=0
                    )
        tensor_count = 1
        split_dim = 0
        for dim, l in enumerate(layout.split("_")):
            if l.isdigit():
                tensor_count = int(l)
                split_dim = dim
                break
        tensors = torch.split(tensor, 1, dim=split_dim) if split_dim != 0 else [tensor]
        tensors_orig = (
            torch.split(tensor_orig, 1, dim=split_dim) if split_dim != 0 else [tensor_orig]
        )
        for j in range(tensor_count):
            if split_dim != 0:
                inp.append(tensors[j].squeeze(split_dim))
                inp_orig.append(tensors_orig[j].squeeze(split_dim))
            else:
                inp.append(tensors[j])
                inp_orig.append(tensors_orig[j])
    for i in range(3):
        inp[i].requires_grad = True
        inp_orig[i].requires_grad = True

    # Create output gradient
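    # (out_grad matches the shape DotProductAttention returns, with the last two dims
    # flattened to h * dv; out_grad keeps THD padding while out_grad_orig strips it,
    # mirroring inp / inp_orig above)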
    qkv_format_kv = "_".join(qkv_format)
    qkv_format_kv = qkv_format_kv.replace("s", "sq")
    qkv_format_kv = qkv_format_kv.replace("d", "dv")
    out_grad_shape = [dim_to_num[i] for i in qkv_format_kv.split("_")]
    out_grad_shape_new = [*out_grad_shape[:-2], out_grad_shape[-2] * out_grad_shape[-1]]
    out_grad = 0.001 * torch.randint(0, 200, out_grad_shape_new, dtype=dtype, device="cuda")
    out_grad_orig = out_grad
    if qkv_format == "thd" and pad_between_seqs:
        out_grad_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
        if qkv_format_kv == "t_h_dv":
            for i in range(1, config.batch_size + 1):
                valid_range = (
                    cu_seqlens_q_after_pad[i - 1],
                    cu_seqlens_q_after_pad[i] - pad_len[i - 1],
                )
                pad_range = (cu_seqlens_q_after_pad[i] - pad_len[i - 1], cu_seqlens_q_after_pad[i])
                out_grad[pad_range[0] : pad_range[1]] = 0.0
                out_grad_orig = torch.cat(
                    [out_grad_orig, out_grad[valid_range[0] : valid_range[1]]], dim=0
                )

    # Create bias
    if config.attn_bias_type in ["no_bias", "alibi"]:
        bias = None
    if config.attn_bias_type == "post_scale_bias":
        shape = "_".join(config.bias_shape)
        shape = shape.replace("_s_s", "_sq_skv")
        tensor_shape = [dim_to_num[j] for j in shape.split("_")]
        bias = torch.randn(tensor_shape, dtype=dtype, device="cuda")
        if config.bias_shape != "1hss":
            bias.requires_grad = False

    # Create RNG
    _DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
    _DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed)

    def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
        """Get cuda rng tracker."""
        return _DUMMY_CUDA_RNG_STATE_TRACKER

    # Set up model
    block = DotProductAttention(
        config.num_heads,
        (config.head_dim_qk, config.head_dim_v),
        num_gqa_groups=config.num_gqa_groups,
        attention_dropout=config.dropout_p,
        qkv_format=qkv_format,
        attn_mask_type=config.attn_mask_type,
        sequence_parallel=False,
        tp_size=1,
        get_rng_state_tracker=get_dummy_cuda_rng_tracker,
        tp_group=None,
        layer_number=1,
        attention_type=config.attn_type,
        softmax_type=config.softmax_type,
        return_max_logit=config.return_max_logit,
    ).to(dtype=dtype, device="cuda")
    if not is_training:
        block = block.eval()
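    # Non-vanilla softmax (off-by-one / learnable) carries a softmax_offset parameter;
    # enable its gradient so it can be compared across backends.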
    if is_training and config.softmax_type != "vanilla":
        block.softmax_offset.requires_grad = True

    # Run a forward and backward pass
    if backend in ["FlashAttention", "UnfusedDotProductAttention"]:
        q = inp_orig[0]
        k = inp_orig[1]
        v = inp_orig[2]
        d_out = out_grad_orig
    if backend == "FusedAttention":
        q = inp[0]
        k = inp[1]
        v = inp[2]
        d_out = out_grad
    out = block(
        q,
        k,
        v,
        window_size=config.window_size,
        attention_mask=attention_mask,
        qkv_format=qkv_format,
        max_seqlen_q=config.max_seqlen_q,
        max_seqlen_kv=config.max_seqlen_kv,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_kv=cu_seqlens_kv,
        cu_seqlens_q_padded=cu_seqlens_q_after_pad if backend == "FusedAttention" else None,
        cu_seqlens_kv_padded=cu_seqlens_kv_after_pad if backend == "FusedAttention" else None,
        attn_mask_type=config.attn_mask_type,
        checkpoint_core_attention=ckpt_attn,
        core_attention_bias_type=config.attn_bias_type,
        core_attention_bias=bias,
        alibi_slopes=alibi_slopes,
        fast_zero_fill=True,
        # Only pass num_splits when exercising the FlashAttention path
        num_splits=config.num_splits if backend == "FlashAttention" else 1,
    )
    max_logit = None
    if config.return_max_logit:
        out, max_logit = out
    if is_training:
        out.backward(d_out)

    d_softmax_offset = None
    if is_training and config.softmax_type != "vanilla":
        d_softmax_offset = block.softmax_offset.grad

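    # FusedAttention runs on the padded THD buffers, so strip the inter-sequence padding
    # from its output and gradients before returning to make them comparable with the
    # unpadded FlashAttention / unfused results.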
    if backend in ["FlashAttention", "UnfusedDotProductAttention"]:
        if is_training:
1205
            return out, max_logit, (q.grad, k.grad, v.grad, d_softmax_offset)
1206
        else:
1207
            return out, max_logit, (None, None, None, d_softmax_offset)
1208
    if backend == "FusedAttention":
1209
1210
        if qkv_format == "thd" and pad_between_seqs:
            out_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
1211
1212
1213
1214
            if is_training:
                q_grad_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
                k_grad_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
                v_grad_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
            for i in range(1, config.batch_size + 1):
                valid_range_q = (
                    cu_seqlens_q_after_pad[i - 1],
                    cu_seqlens_q_after_pad[i] - pad_len[i - 1],
                )
                valid_range_kv = (
                    cu_seqlens_kv_after_pad[i - 1],
                    cu_seqlens_kv_after_pad[i] - pad_len[i - 1],
                )
                out_orig = torch.cat([out_orig, out[valid_range_q[0] : valid_range_q[1]]], dim=0)
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
                if is_training:
                    q_grad_orig = torch.cat(
                        [q_grad_orig, q.grad[valid_range_q[0] : valid_range_q[1]]], dim=0
                    )
                    k_grad_orig = torch.cat(
                        [k_grad_orig, k.grad[valid_range_kv[0] : valid_range_kv[1]]], dim=0
                    )
                    v_grad_orig = torch.cat(
                        [v_grad_orig, v.grad[valid_range_kv[0] : valid_range_kv[1]]], dim=0
                    )
1235
            if is_training:
1236
1237
1238
1239
1240
                return (
                    out_orig,
                    max_logit,
                    (q_grad_orig, k_grad_orig, v_grad_orig, d_softmax_offset),
                )
1241
            else:
1242
                return out_orig, max_logit, (None, None, None, d_softmax_offset)
1243
1244
        else:
            if is_training:
1245
                return out, max_logit, (q.grad, k.grad, v.grad, d_softmax_offset)
1246
            else:
1247
                return out, max_logit, (None, None, None, d_softmax_offset)


1250
model_configs_te_layer = {
1251
    # test: ModelConfig(b, sq, hq, dqk)
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
    "te_1_0": ModelConfig(2, 128, 16, 64, attn_bias_type="post_scale_bias"),
    "te_1_1": ModelConfig(
        4, 128, 16, 64, attn_mask_type="causal", attn_bias_type="post_scale_bias"
    ),
    "te_1_2": ModelConfig(
        2, 128, 16, 64, attn_mask_type="padding", attn_bias_type="post_scale_bias"
    ),
    "te_1_3": ModelConfig(2, 128, 16, 64, max_seqlen_kv=256, attn_mask_type="padding"),
    "te_2_0": ModelConfig(1, 2048, 16, 64, attn_mask_type="causal"),
    "te_2_1": ModelConfig(2, 2048, 16, 64),
    "te_2_2": ModelConfig(1, 2048, 16, 64, attn_mask_type="padding"),
    "te_2_3": ModelConfig(
        1, 2048, 16, 64, max_seqlen_kv=4096, attn_mask_type="padding_causal_bottom_right"
    ),
    "te_3_0": ModelConfig(4, 128, 16, 64, attn_mask_type="causal", attn_bias_type="alibi"),
    "te_3_1": ModelConfig(4, 2048, 16, 64, attn_mask_type="causal", attn_bias_type="alibi"),
}


@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model_configs", [model_configs_te_layer])
@pytest.mark.parametrize("model", model_configs_te_layer.keys())
@pytest.mark.parametrize("ckpt_attn", [False])
@pytest.mark.parametrize("qkv_format", ["sbhd", "bshd", "thd"])
@pytest.mark.parametrize("fused_qkv_params", [False])
@pytest.mark.parametrize("RoPE", [False])
def test_transformer_layer(
    dtype, model_configs, model, ckpt_attn, qkv_format, fused_qkv_params, RoPE
):
    """Test TransformerLayer module"""

    # Get configs
    config = model_configs[model]
    tols = dict(atol=5e-2, rtol=5e-2)
    workspace_opt = True

    # Test backend availability
    is_training = True
    available_backends, _, fused_attn_backends = get_available_attention_backends(
        config,
        qkv_dtype=dtype,
        qkv_layout=(
            qkv_format.replace("hd", "h3d") if fused_qkv_params else qkv_format.replace("hd", "3hd")
        ),
        is_training=is_training,
    )
    flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
    if not fused_attn_supported:
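        # No fused backend supports training for this config; fall back to inference
        # mode and re-check which backends are available.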
        is_training = False
        available_backends, _, fused_attn_backends = get_available_attention_backends(
            config,
            qkv_dtype=dtype,
            qkv_layout=(
                qkv_format.replace("hd", "h3d")
                if fused_qkv_params
                else qkv_format.replace("hd", "3hd")
            ),
            is_training=is_training,
        )
        flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends

    # Skip if only unfused backend is supported
    if (len(fused_attn_backends) + flash_attn_supported + unfused_attn_supported) < 2:
        pytest.skip("Less than two backends to compare.")
    # Skip if qkv_format = thd and "padding" not in attn_mask_type
    if qkv_format == "thd" and "padding" not in config.attn_mask_type:
        pytest.skip("THD requires padding mask.")

    # UnfusedDotProductAttention backend
    if unfused_attn_supported:
        unfused_attn_fwd, unfused_attn_bwd = _run_transformer_layer(
            dtype,
            config,
            "UnfusedDotProductAttention",
            ckpt_attn,
            qkv_format,
            workspace_opt,
            fused_qkv_params,
            RoPE,
            is_training,
        )

    # FusedAttention backend
    if fused_attn_supported:
        fused_attn_fwd, fused_attn_bwd = _run_transformer_layer(
            dtype,
            config,
            "FusedAttention",
            ckpt_attn,
            qkv_format,
            workspace_opt,
            fused_qkv_params,
            RoPE,
            is_training,
        )

    # FlashAttention backend
    if flash_attn_supported:
        flash_attn_fwd, flash_attn_bwd = _run_transformer_layer(
            dtype,
            config,
            "FlashAttention",
            ckpt_attn,
            qkv_format,
            workspace_opt,
            fused_qkv_params,
            RoPE,
            is_training,
        )

    logging.info(f"[test_transformer_layer]: is_training = {is_training}")
    if unfused_attn_supported and fused_attn_supported:
        logging.info("[test_transformer_layer]: unfused attn vs fused attn")
        torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, **tols)
        torch.testing.assert_close(fused_attn_bwd, unfused_attn_bwd, **tols)
    if unfused_attn_supported and flash_attn_supported:
        logging.info("[test_transformer_layer]: unfused attn vs flash attn")
        torch.testing.assert_close(flash_attn_fwd, unfused_attn_fwd, **tols)
        torch.testing.assert_close(flash_attn_bwd, unfused_attn_bwd, **tols)
    if fused_attn_supported and flash_attn_supported:
        logging.info("[test_transformer_layer]: fused attn vs flash attn")
        torch.testing.assert_close(fused_attn_fwd, flash_attn_fwd, **tols)
        torch.testing.assert_close(fused_attn_bwd, flash_attn_bwd, **tols)


@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_te_layer])
@pytest.mark.parametrize("model", ["te_1_2", "te_2_0"])
@pytest.mark.parametrize("qkv_format", ["bshd", "sbhd"])
def test_te_layer_misc(dtype, model_configs, model, qkv_format):
    """Test TransformerLayer module with miscellaneous settings"""
    ckpt_attn = True
    fused_qkv_params = True
    RoPE = True
    test_transformer_layer(
        dtype, model_configs, model, ckpt_attn, qkv_format, fused_qkv_params, RoPE
    )


@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_te_layer])
@pytest.mark.parametrize("model", ["te_2_0", "te_2_1", "te_2_2"])
def test_te_layer_mqa_gqa(dtype, model_configs, model):
    """Test TransformerLayer module with MQA/GQA"""

    def find_factors(x):
        f = []
        for i in range(2, x + 1):
            if x % i == 0:
                f.append(i)
        return f

    ckpt_attn = True
    qkv_format = "bshd"
    fused_qkv_params = True
    RoPE = True
    config = model_configs[model]
    num_querys_per_gqa_group = find_factors(config.num_heads)

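    # Sweep every divisor of num_heads as the queries-per-KV-group count, covering
    # grouped-query settings down to MQA (a single KV group).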
    for num_q_per_gqa_group in num_querys_per_gqa_group:
        config.num_gqa_groups = config.num_heads // num_q_per_gqa_group
        test_transformer_layer(
            dtype, model_configs, model, ckpt_attn, qkv_format, fused_qkv_params, RoPE
        )


def _run_transformer_layer(
    dtype: torch.dtype,
    config: ModelConfig,
    backend: str,
    ckpt_attn: bool,
    qkv_format: str,
    workspace_opt: bool,
    fused_qkv_params: bool,
    RoPE: bool,
    is_training: bool,
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
    """Run TransformerLayer module with one forward pass and one backward pass"""

    # Set RNG and environment variables
    reset_rng_states()
    os.environ["NVTE_FLASH_ATTN"] = "0"
    os.environ["NVTE_FUSED_ATTN"] = "0"
    if backend == "FlashAttention":
        os.environ["NVTE_FLASH_ATTN"] = "1"
    if backend == "FusedAttention":
        os.environ["NVTE_FUSED_ATTN"] = "1"
    _attention_backends["backend_selection_requires_update"] = True

    # Create input tensor
    if qkv_format == "sbhd":
        inp = torch.randn(
            config.max_seqlen_q,
            config.batch_size,
            config.hidden_size,
            dtype=dtype,
            device="cuda",
            requires_grad=True,
        )
        inp_enc = torch.randn(
            config.max_seqlen_kv,
            config.batch_size,
            config.hidden_size,
            dtype=dtype,
            device="cuda",
            requires_grad=True,
        )
    if qkv_format == "bshd":
        inp = torch.randn(
            config.batch_size,
            config.max_seqlen_q,
            config.hidden_size,
            dtype=dtype,
            device="cuda",
            requires_grad=True,
        )
        inp_enc = torch.randn(
            config.batch_size,
            config.max_seqlen_kv,
            config.hidden_size,
            dtype=dtype,
            device="cuda",
            requires_grad=True,
        )

    # Create seqlens
    if "padding" in config.attn_mask_type or qkv_format == "thd":
        if config.attn_type == "self":
            seqlens_q = torch.randint(
                1, config.max_seqlen_q, [config.batch_size], dtype=torch.int32, device="cuda"
            )
            seqlens_kv = seqlens_q
        if config.attn_type == "cross":
            if config.max_seqlen_q > 1:
                seqlens_q = torch.randint(
                    1, config.max_seqlen_q, [config.batch_size], dtype=torch.int32, device="cuda"
                )
            else:
                seqlens_q = torch.ones([config.batch_size], dtype=torch.int32, device="cuda")
            seqlens_kv = torch.randint(
                1, config.max_seqlen_kv, [config.batch_size], dtype=torch.int32, device="cuda"
            )
    else:
        seqlens_q = torch.full(
            [config.batch_size], config.max_seqlen_q, dtype=torch.int32, device="cuda"
        )
        seqlens_kv = torch.full(
            [config.batch_size], config.max_seqlen_kv, dtype=torch.int32, device="cuda"
        )
    cu_seqlens_q = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda")
    cu_seqlens_kv = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda")
    cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)
    cu_seqlens_kv[1:] = torch.cumsum(seqlens_kv, dim=0)
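    # For the packed thd format, sequences share a single token dimension, so the
    # inputs are sized by the total token counts in cu_seqlens_q/cu_seqlens_kv.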
    if qkv_format == "thd":
        inp = torch.randn(
            cu_seqlens_q[-1],
            config.hidden_size,
            dtype=dtype,
            device="cuda",
            requires_grad=True,
        )
        inp_enc = torch.randn(
            cu_seqlens_kv[-1],
            config.hidden_size,
            dtype=dtype,
            device="cuda",
            requires_grad=True,
        )

    sigma = 0.02
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    layer_number = 1
    drop_path_rate = 0.0
    drop_path_rates = [rate.item() for rate in torch.linspace(0, drop_path_rate, config.num_layers)]

    # Create bias
    bias = None
    if config.attn_bias_type == "post_scale_bias":
        bias = torch.randn(
            1,
            config.num_heads,
            config.max_seqlen_q,
            config.max_seqlen_kv,
            dtype=dtype,
            device="cuda",
        )

    # Create RoPE
    rotary_pos_emb = None
    if RoPE:
        PE = RotaryPositionEmbedding(dim=config.head_dim_qk)
        rotary_pos_emb = PE(config.max_seqlen_q).to(device="cuda")

    # Set up model
    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_heads,
        num_gqa_groups=config.num_gqa_groups,
        layernorm_epsilon=1e-5,
        hidden_dropout=0.0,
        attention_dropout=config.dropout_p,
        init_method=init_method,
        output_layer_init_method=output_layer_init_method,
        layer_number=layer_number,
        kv_channels=config.head_dim_qk,
        self_attn_mask_type=config.attn_mask_type,
        tp_group=None,
        tp_size=1,
        params_dtype=dtype,
        get_rng_state_tracker=None,
        fuse_wgrad_accumulation=False,
        seq_length=config.max_seqlen_q,
        micro_batch_size=config.batch_size,
        sequence_parallel=False,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
        layer_type="encoder" if config.attn_type == "self" else "decoder",
        drop_path_rate=drop_path_rates[layer_number - 1],
        set_parallel_mode=True,
        fuse_qkv_params=fused_qkv_params,
        zero_centered_gamma=False,
        qkv_weight_interleaved=False,
        ub_tp_comm_overlap=False,
        bias=True,
        attn_input_format=qkv_format,
    ).to(dtype=dtype, device="cuda")
    if not is_training:
        block = block.eval()

    # Create ALiBi slopes
    alibi_slopes = None
    if config.attn_bias_type == "alibi" and config.alibi_type == "custom":
        alibi_slopes = torch.randn(config.num_heads).abs().to(dtype=torch.float32, device="cuda")

    # Run a forward and backward pass
    out = block(
        inp,
        self_attn_mask_type=config.attn_mask_type,
        encoder_output=inp_enc if config.attn_type == "cross" else None,
        enc_dec_attn_mask_type=config.attn_mask_type if config.attn_type == "cross" else None,
        checkpoint_core_attention=False,
        rotary_pos_emb=rotary_pos_emb,
        core_attention_bias_type=config.attn_bias_type,
        core_attention_bias=bias,
        alibi_slopes=alibi_slopes,
        max_seqlen_q=config.max_seqlen_q,
        max_seqlen_kv=config.max_seqlen_kv,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_kv=cu_seqlens_kv,
    )
    if is_training:
        loss = out.sum()
        loss.backward()

    return out, inp.grad


model_configs_fp8_extra_state = {
    # test: ModelConfig(b, sq, hq, dqk)
    "large": ModelConfig(2, 128, 4, 128, num_layers=1),
}


@pytest.mark.skipif(not fp8_attn_available, reason=reason_for_no_fp8_attn)
@pytest.mark.skipif(get_cudnn_version() < (9, 3, 0), reason="cuDNN 9.3.0+ is required.")
@pytest.mark.parametrize("model", ["large"])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_dpa_fp8_extra_state(model, dtype):
    """Test DotProductAttention module in FP8 with checkpointing"""
    config = model_configs_fp8_extra_state[model]
    # Test backend availability
    is_training = True
    available_backends, _, fused_attn_backends = get_available_attention_backends(
        config,
        qkv_dtype=torch.float8_e4m3fn,
        qkv_layout="sb3hd",
        is_training=is_training,
    )
    flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
    if not fused_attn_supported and not flash_attn_supported:
        pytest.skip("No attention backend available.")

    outputs = _run_dpa_fp8_extra_state(dtype, config, checkpoint=False)
    outputs_checkpoint = _run_dpa_fp8_extra_state(dtype, config, checkpoint=True)
    outputs_checkpoint_v1_6 = _run_dpa_fp8_extra_state(
        dtype, config, mimic_v1_6=True, checkpoint=True
    )

    # Check that results match
    tols = dtype_tols(dtype)
    if dtype in (torch.float16, torch.bfloat16):
        tols.update(dict(rtol=2e-2, atol=2e-3))
    for i, (ref, test) in enumerate(zip(outputs, outputs_checkpoint)):
        torch.testing.assert_close(
            test,
            ref,
            **tols,
        )
    for i, (ref, test) in enumerate(zip(outputs, outputs_checkpoint_v1_6)):
        torch.testing.assert_close(
            test,
            ref,
            **tols,
        )


def _run_dpa_fp8_extra_state(dtype, config, checkpoint=False, mimic_v1_6=False):
    """Run DotProductAttention module in FP8 with checkpointing"""
    steps = 10
    path = "checkpoint.pt"
    fp8_enabled = True
    fp8_recipe = recipe.DelayedScaling(
        margin=0,
        fp8_format=recipe.Format.HYBRID,
        amax_history_len=1,
        amax_compute_algo="most_recent",
        fp8_dpa=fp8_enabled,
        fp8_mha=False,
    )

    reset_rng_states()
    hidden_states = torch.randn(
        (config.max_seqlen_q, config.batch_size, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )

    def get_model(dtype, config):
        sigma = 0.023
        init_method = init_method_normal(sigma)
        output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

        with quantized_model_init(enabled=fp8_enabled, recipe=fp8_recipe):
            block = TransformerLayer(
                config.hidden_size,
                4 * config.hidden_size,
                config.num_heads,
                init_method=init_method,
                output_layer_init_method=output_layer_init_method,
                hidden_dropout=0.0,
                attention_dropout=0.0,
                fuse_qkv_params=True,
                params_dtype=dtype,
                device="cuda",
            )
        return block

    block = get_model(dtype, config)
    for i in range(steps // 2):
        with autocast(enabled=fp8_enabled, recipe=fp8_recipe):
            output = block(hidden_states, None)
            loss = output.sum()
            loss.backward()

    if checkpoint:
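        # Save the state_dict, parameter grads, and RNG states, then rebuild the model
        # and reload everything to verify that the FP8 extra_state survives checkpointing.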
        sd = block.state_dict()
        if mimic_v1_6:
            sd["self_attention.core_attention.fused_attention._extra_state"] = sd[
                "self_attention.core_attention._extra_state"
            ]
            del sd["self_attention.core_attention._extra_state"]
        torch.save(sd, path)

        param_grads = []
        for p in block.parameters():
            if p.requires_grad:
                param_grads.append(p.grad.clone())

        _cpu_rng_state_new = torch.get_rng_state()
        _cuda_rng_state_new = torch.cuda.get_rng_state()

        del block
        block = get_model(dtype, config)
        block.load_state_dict(torch.load(path, weights_only=False))
        torch.set_rng_state(_cpu_rng_state_new)
        torch.cuda.set_rng_state(_cuda_rng_state_new)

        for p in block.parameters():
            if p.requires_grad:
                p.grad = param_grads.pop(0)

        assert not param_grads, "Not all parameter gradients were restored from the checkpoint."

    for i in range((steps + 1) // 2):
        with autocast(enabled=fp8_enabled, recipe=fp8_recipe):
            output = block(hidden_states, None)
            loss = output.sum()
            loss.backward()

    torch.cuda.synchronize()

    if os.path.exists(path):
        os.remove(path)

    outputs = [output, hidden_states.grad]
    for p in block.parameters():
        if p.requires_grad:
            outputs.append(p.grad)

    return outputs


model_configs_fp8_vs_f16 = {
    # test: ModelConfig(b, sq, hq, dqk)
    "fp8_9": ModelConfig(2, 2048, 16, 128),
    "fp8_10": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12),
    "fp8_11": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4),
    "fp8_12": ModelConfig(2, 2048, 16, 128, attn_mask_type="causal"),
    "fp8_13": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12, attn_mask_type="causal"),
    "fp8_14": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="causal"),
    "fp8_15": ModelConfig(2, 2048, 16, 128, attn_mask_type="padding"),
    "fp8_16": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12, attn_mask_type="padding"),
    "fp8_17": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="padding"),
    "fp8_18": ModelConfig(2, 2048, 16, 128, attn_mask_type="padding_causal"),
    "fp8_19": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12, attn_mask_type="padding_causal"),
    "fp8_20": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="padding_causal"),
}

param_types_fp8_vs_f16 = [torch.float16, torch.bfloat16]
qkv_layout_fp8_vs_f16 = ["sbh3d", "bshd_bshd_bshd", "sbhd_sbhd_sbhd"]
qkv_format_fp8_vs_f16 = ["bshd", "sbhd"]


@pytest.mark.skipif(get_cudnn_version() < (9, 2, 1), reason="cuDNN 9.2.1+ is required.")
@pytest.mark.skipif(not fp8_attn_available, reason=reason_for_no_fp8_attn)
@pytest.mark.parametrize("dtype", param_types_fp8_vs_f16)
@pytest.mark.parametrize("model", model_configs_fp8_vs_f16.keys())
@pytest.mark.parametrize("qkv_format", qkv_format_fp8_vs_f16)
@pytest.mark.parametrize("input_layernorm", [True, False])
@pytest.mark.parametrize("fp8_dpa_bwd", [True, False])
@pytest.mark.parametrize("RoPE", [True, False])
@pytest.mark.parametrize("is_training", [True, False])
@pytest.mark.parametrize("scaling_mode", ["delayed", "current"])
def test_mha_fp8_vs_f16(
    dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd, RoPE, is_training, scaling_mode
):
    """Test MultiHeadAttention module in FP8"""
    os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1"
    os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0"
    config = model_configs_fp8_vs_f16[model]

    # Test backend availability
    if scaling_mode == "delayed":
        fp8_recipe = recipe.DelayedScaling(
            margin=0,
            fp8_format=recipe.Format.HYBRID,
            amax_history_len=1,
            amax_compute_algo="most_recent",
            fp8_dpa=True,
            fp8_mha=True,
        )
    elif scaling_mode == "current":
        fp8_recipe = recipe.Float8CurrentScaling(
            fp8_format=recipe.Format.HYBRID,
            fp8_dpa=True,
            fp8_mha=True,
        )
    fp8_meta = {}
    fp8_meta["recipe"] = fp8_recipe
    available_backends, _, fused_attn_backends = get_available_attention_backends(
        config,
        qkv_dtype=torch.float8_e4m3fn,
        qkv_layout=qkv_format.replace("hd", "h3d"),
        fp8=True,
        fp8_meta=fp8_meta,
        is_training=is_training,
    )
    flash_attn_supported, fused_attn_supported_fp8, unfused_attn_supported = available_backends
    if flash_attn_supported + fused_attn_supported_fp8 < 1:
        pytest.skip("No FP8 attention backend available.")
    fused_attn_supported_f16 = False
    if not fp8_dpa_bwd:
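        # With NVTE_FP8_DPA_BWD=0 the backward pass runs in F16, so an F16 fused
        # backend must also be available.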
        available_backends, _, fused_attn_backends = get_available_attention_backends(
            config,
            qkv_dtype=dtype,
            qkv_layout=qkv_format.replace("hd", "h3d"),
            is_training=is_training,
        )
        _, fused_attn_supported_f16, _ = available_backends
        if not fused_attn_supported_f16:
            pytest.skip("No attention backend available.")

    if flash_attn_supported:
        os.environ["NVTE_FLASH_ATTN"] = "1"
        os.environ["NVTE_FUSED_ATTN"] = "0"
        _attention_backends["backend_selection_requires_update"] = True
        logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = True")
        flash_attn_fwd_fp8, param_names, flash_attn_bwd_fp8 = _run_mha_fp8_vs_f16(
            dtype, config, True, qkv_format, input_layernorm, RoPE, is_training, fp8_recipe
        )

    if fused_attn_supported_fp8:
        os.environ["NVTE_FLASH_ATTN"] = "0"
        os.environ["NVTE_FUSED_ATTN"] = "1"
        _attention_backends["backend_selection_requires_update"] = True
        logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = True")
        fused_attn_fwd_fp8, param_names, fused_attn_bwd_fp8 = _run_mha_fp8_vs_f16(
            dtype, config, True, qkv_format, input_layernorm, RoPE, is_training, fp8_recipe
        )

    if fused_attn_supported_f16:
        os.environ["NVTE_FLASH_ATTN"] = "0"
        os.environ["NVTE_FUSED_ATTN"] = "1"
        _attention_backends["backend_selection_requires_update"] = True
        logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = False")
        fused_attn_fwd_f16, param_names, fused_attn_bwd_f16 = _run_mha_fp8_vs_f16(
            dtype, config, False, qkv_format, input_layernorm, RoPE, is_training, fp8_recipe
        )

    atol = 5e-1
    rtol = 5e-1
    rmse_tol = 0.15
    if flash_attn_supported and fused_attn_supported_f16:
        logging.debug("========== {:^25s} ==========".format("flash fp8 vs fused f16:"))
        logging.debug("========== {:^25s} ==========".format("forward output"))
        compare_and_assert(
            flash_attn_fwd_fp8,
            fused_attn_fwd_f16,
            "flash_attn_fwd_fp8",
            "fused_attn_fwd_f16",
            atol,
            rtol,
            rmse_tol,
            True,
        )
    if fused_attn_supported_fp8 and fused_attn_supported_f16:
        logging.debug("========== {:^25s} ==========".format("fused fp8 vs fused f16:"))
        logging.debug("========== {:^25s} ==========".format("forward output"))
        compare_and_assert(
            fused_attn_fwd_fp8,
            fused_attn_fwd_f16,
            "fused_attn_fwd_fp8",
            "fused_attn_fwd_f16",
            atol,
            rtol,
            rmse_tol,
            True,
        )

        if is_training:
            for i in range(len(param_names[:1])):
                logging.debug("========== {:^25s} ==========".format(param_names[i]))
                compare_and_assert(
                    fused_attn_bwd_fp8[i],
                    fused_attn_bwd_f16[i],
                    f"fused_attn_bwd_fp8[{i}]",
                    f"fused_attn_bwd_f16[{i}]",
                    atol,
                    rtol,
                    rmse_tol,
                    True,
                )


def _run_mha_fp8_vs_f16(
    dtype, config, fp8_mha, qkv_format, input_layernorm, RoPE, is_training, fp8_recipe
):
    """Run MultiHeadAttention module in FP8"""
    reset_rng_states()
    _DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
    _DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed)
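    # A fixed CUDA RNG tracker keeps any dropout state reproducible across the FP8 and F16 runs.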

    def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
        """Get cuda rng tracker."""
        return _DUMMY_CUDA_RNG_STATE_TRACKER

    with quantized_model_init(enabled=fp8_mha, recipe=fp8_recipe):
        rotary_pos_emb = None
        if RoPE:
            PE = RotaryPositionEmbedding(dim=config.head_dim_qk)
            rotary_pos_emb = PE(config.max_seqlen_q).to(device="cuda")
        mha = MultiheadAttention(
            hidden_size=config.hidden_size,
            num_attention_heads=config.num_heads,
            kv_channels=config.head_dim_qk,
            num_gqa_groups=config.num_gqa_groups,
            attention_dropout=config.dropout_p,
            layer_number=1,
            bias=True,
            get_rng_state_tracker=get_dummy_cuda_rng_tracker,
            params_dtype=dtype,
            input_layernorm=input_layernorm,
            fuse_qkv_params=True,
            attention_type="self",
            qkv_weight_interleaved=True,
            qkv_format=qkv_format,
        ).to(dtype=dtype, device="cuda")
        if not is_training:
            mha = mha.eval()

    if "padding" in config.attn_mask_type or qkv_format == "thd":
        if config.attn_type == "self":
            seqlens_q = torch.randint(
                1, config.max_seqlen_q, [config.batch_size], dtype=torch.int32, device="cuda"
            )
            seqlens_kv = seqlens_q
        if config.attn_type == "cross":
            seqlens_q = torch.randint(
                1, config.max_seqlen_q, [config.batch_size], dtype=torch.int32, device="cuda"
            )
            seqlens_kv = torch.randint(
                1, config.max_seqlen_kv, [config.batch_size], dtype=torch.int32, device="cuda"
            )
    else:
        seqlens_q = torch.full(
            [config.batch_size], config.max_seqlen_q, dtype=torch.int32, device="cuda"
        )
        seqlens_kv = torch.full(
            [config.batch_size], config.max_seqlen_kv, dtype=torch.int32, device="cuda"
        )
    cu_seqlens_q = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda")
    cu_seqlens_kv = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda")
    cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)
    cu_seqlens_kv[1:] = torch.cumsum(seqlens_kv, dim=0)

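    # Map layout symbols to concrete sizes so the qkv_format string can be expanded into a tensor shape.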
    dim_to_num = {
        "b": config.batch_size,
        "sq": config.max_seqlen_q,
        "skv": config.max_seqlen_kv,
        "h": config.num_heads,
        "hg": config.num_gqa_groups,
        "d": config.head_dim_qk,
        "t": cu_seqlens_q[-1],
        "tg": cu_seqlens_kv[-1],
        "3": 3,
        "2": 2,
        "1": 1,
    }
    layout = "_".join(qkv_format)
    layout = layout.replace("s", "sq")
    tensor_shape = [dim_to_num[j] for j in layout.split("_")]
    tensor = 0.01 * torch.randint(-100, 100, tensor_shape, dtype=dtype, device="cuda")
    hidden_states = tensor.view(*tensor.shape[:-2], -1)
    if is_training:
        hidden_states.requires_grad = True
    tensor = 0.01 * torch.randn(tensor_shape, dtype=dtype, device="cuda")
    out_grad = tensor.view(*tensor.shape[:-2], -1)

    with autocast(enabled=fp8_mha, recipe=fp8_recipe):
        out = mha(
            hidden_states,
            attn_mask_type=config.attn_mask_type,
            checkpoint_core_attention=False,
            core_attention_bias_type=config.attn_bias_type,
            is_first_microbatch=None,
            rotary_pos_emb=rotary_pos_emb,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_kv=cu_seqlens_kv,
        )
    if is_training:
        out.backward(out_grad)

    param_names = []
    param_names.append("hidden_states.grad")
    params = []
    params.append(hidden_states)
    for name, param in mha.named_parameters():
        if param.requires_grad:
            param_names.append(name + ".grad")
            params.append(param)

    if is_training:
        return out, param_names, tuple(x.grad for x in params)
    return out, param_names, tuple(None for x in params)


@pytest.mark.skipif(get_cudnn_version() < (9, 2, 1), reason="cuDNN 9.2.1+ is required.")
@pytest.mark.skipif(not fp8_attn_available, reason=reason_for_no_fp8_attn)
@pytest.mark.parametrize("dtype", param_types_fp8_vs_f16)
@pytest.mark.parametrize("model", model_configs_fp8_vs_f16.keys())
@pytest.mark.parametrize("qkv_layout", qkv_layout_fp8_vs_f16)
@pytest.mark.parametrize("fp8_dpa_bwd", [True, False])
@pytest.mark.parametrize("is_training", [True, False])
@pytest.mark.parametrize("scaling_mode", ["delayed", "current"])
def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scaling_mode):
    """Test DotProductAttention module in FP8"""
    config = model_configs_fp8_vs_f16[model]

    # TODO(cyang): think of another way to verify dropout results
    # test cuDNN FP8 dropout
    # 1. we modify the config here to not affect mha_fp8_vs_f16 tests
    # 2. there is no other backend that implements dropout the same way as cuDNN FP8, and as an
    #    indirect verification method, we create Q/K/V as all 1s and check if O is all 1s
    # 3. we avoid running FP16/BF16 kernels as they do not have dropout support on Blackwell
    # if "padding" not in config.attn_mask_type and "causal" not in config.attn_mask_type:
    #    if get_device_compute_capability() >= (10, 0):
    #        config.dropout_p = 0.1

    os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0"
    os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1"
    os.environ["NVTE_UnfusedDPA_Emulate_FP8"] = "1"

    # Test backend availability
    if scaling_mode == "delayed":
        fp8_recipe = recipe.DelayedScaling(
            margin=0,
            fp8_format=recipe.Format.HYBRID,
            amax_history_len=1,
            amax_compute_algo="most_recent",
            fp8_dpa=True,
        )
    elif scaling_mode == "current":
        fp8_recipe = recipe.Float8CurrentScaling(
            fp8_format=recipe.Format.HYBRID,
            fp8_dpa=True,
        )
    fp8_meta = {}
    fp8_meta["recipe"] = fp8_recipe
    available_backends, _, fused_attn_backends = get_available_attention_backends(
        config,
        qkv_dtype=torch.float8_e4m3fn,
        qkv_layout=qkv_layout,
        fp8=True,
        fp8_meta=fp8_meta,
        is_training=is_training,
    )
    flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
    if flash_attn_supported + fused_attn_supported < 1:
        pytest.skip("No FP8 attention backend available.")
    if not fp8_dpa_bwd:
        available_backends, _, fused_attn_backends = get_available_attention_backends(
            config,
            qkv_dtype=dtype,
            qkv_layout=qkv_layout,
            is_training=is_training,
        )
        _, fused_attn_supported, _ = available_backends
        if not fused_attn_supported:
            pytest.skip("No attention backend available.")
    if config.num_heads != config.num_gqa_groups and "3" in qkv_layout:
        pytest.skip("qkv_layout not applicable for MQA/GQA")

    if flash_attn_supported:
        os.environ["NVTE_FLASH_ATTN"] = "1"
        os.environ["NVTE_FUSED_ATTN"] = "0"
        _attention_backends["backend_selection_requires_update"] = True
        logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True (FlashAttention)")
        flash_attn_fwd_fp8, flash_attn_bwd_fp8 = _run_dpa_fp8_vs_f16(
            dtype, config, True, qkv_layout, is_training, fp8_recipe
        )

    if unfused_attn_supported:
        os.environ["NVTE_FLASH_ATTN"] = "0"
        os.environ["NVTE_FUSED_ATTN"] = "0"
        _attention_backends["backend_selection_requires_update"] = True
        logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True (UnfusedDotProductAttention)")
        unfused_attn_fwd_fp8, unfused_attn_bwd_fp8 = _run_dpa_fp8_vs_f16(
            dtype, config, True, qkv_layout, is_training, fp8_recipe
        )

    os.environ["NVTE_FLASH_ATTN"] = "0"
    os.environ["NVTE_FUSED_ATTN"] = "1"
    _attention_backends["backend_selection_requires_update"] = True
    logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True (FusedAttention)")
    fused_attn_fwd_fp8, fused_attn_bwd_fp8 = _run_dpa_fp8_vs_f16(
        dtype, config, True, qkv_layout, is_training, fp8_recipe
    )

    os.environ["NVTE_FLASH_ATTN"] = "0"
    os.environ["NVTE_FUSED_ATTN"] = "1"
    if config.dropout_p == 0.0:
        # test cuDNN FP8 dropout: need a FP16/BF16 reference on Blackwell
        logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = False (FusedAttention)")
        fused_attn_fwd_f16, fused_attn_bwd_f16 = _run_dpa_fp8_vs_f16(
            dtype, config, False, qkv_layout, is_training, fp8_recipe
        )

    atol = 5e-1
    rtol = 5e-2
    rmse_tol = 0.11
    bwd_names = ["dq", "dk", "dv"]
    if flash_attn_supported:
        logging.debug("========== {:^25s} ==========".format("flash fp8 vs fused f16:"))
        logging.debug("========== {:^25s} ==========".format("forward output"))
        compare_and_assert(
            flash_attn_fwd_fp8,
            fused_attn_fwd_f16,
            "flash_attn_fwd_fp8",
            "fused_attn_fwd_f16",
            atol,
            rtol,
            rmse_tol,
            True,
        )
    if unfused_attn_supported:
        logging.debug("========== {:^25s} ==========".format("unfused fp8 vs fused f16:"))
        logging.debug("========== {:^25s} ==========".format("forward output"))
        compare_and_assert(
            unfused_attn_fwd_fp8,
            fused_attn_fwd_f16,
            "unfused_attn_fwd_fp8",
            "fused_attn_fwd_f16",
            atol,
            rtol,
            rmse_tol,
            True,
        )
        if is_training:
            for i, _ in enumerate(fused_attn_bwd_f16):
                logging.debug("========== {:^25s} ==========".format(bwd_names[i]))
                compare_and_assert(
                    unfused_attn_bwd_fp8[i],
                    fused_attn_bwd_f16[i],
                    f"unfused_attn_bwd_fp8[{i}]",
                    f"fused_attn_bwd_f16[{i}]",
                    atol,
                    rtol,
                    rmse_tol,
                    True,
                )
    if config.dropout_p != 0.0:
        # test cuDNN FP8 dropout
        assert torch.all(
            fused_attn_fwd_fp8 == 1
        ), "fused_attn_fwd_fp8 must be all 1s when Q/K/V are all 1s."
    else:
        logging.debug("========== {:^25s} ==========".format("fused fp8 vs fused f16:"))
        logging.debug("========== {:^25s} ==========".format("forward output"))
        compare_and_assert(
            fused_attn_fwd_fp8,
            fused_attn_fwd_f16,
            "fused_attn_fwd_fp8",
            "fused_attn_fwd_f16",
            atol,
            rtol,
            rmse_tol,
            True,
        )
        if is_training:
            for i, _ in enumerate(fused_attn_bwd_f16):
                logging.debug("========== {:^25s} ==========".format(bwd_names[i]))
                compare_and_assert(
                    fused_attn_bwd_fp8[i],
                    fused_attn_bwd_f16[i],
                    f"fused_attn_bwd_fp8[{i}]",
                    f"fused_attn_bwd_f16[{i}]",
                    atol,
                    rtol,
                    rmse_tol,
                    True,
                )
    os.environ["NVTE_UnfusedDPA_Emulate_FP8"] = "0"


def _run_dpa_fp8_vs_f16(dtype, config, fp8_dpa, qkv_layout, is_training, fp8_recipe):
    """Run DotProductAttention module in FP8"""
    reset_rng_states()
    _DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
    _DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed)

    def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
        """Get cuda rng tracker."""
        return _DUMMY_CUDA_RNG_STATE_TRACKER

    qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()])
    with quantized_model_init(enabled=fp8_dpa):
        dpa = DotProductAttention(
            config.num_heads,
            config.head_dim_qk,
            num_gqa_groups=config.num_gqa_groups,
            attention_dropout=config.dropout_p,
            sequence_parallel=False,
            tp_size=1,
            get_rng_state_tracker=get_dummy_cuda_rng_tracker,
            tp_group=None,
            layer_number=1,
            attention_type="self",
            qkv_format=qkv_format,
        ).to(dtype=dtype, device="cuda")
        if not is_training:
            dpa = dpa.eval()

    if "padding" in config.attn_mask_type or qkv_format == "thd":
        if config.attn_type == "self":
            seqlens_q = torch.randint(
                1, config.max_seqlen_q, [config.batch_size], dtype=torch.int32, device="cuda"
            )
            seqlens_kv = seqlens_q
        if config.attn_type == "cross":
            seqlens_q = torch.randint(
                1, config.max_seqlen_q, [config.batch_size], dtype=torch.int32, device="cuda"
            )
            seqlens_kv = torch.randint(
                1, config.max_seqlen_kv, [config.batch_size], dtype=torch.int32, device="cuda"
            )
    else:
        seqlens_q = torch.full(
            [config.batch_size], config.max_seqlen_q, dtype=torch.int32, device="cuda"
        )
        seqlens_kv = torch.full(
            [config.batch_size], config.max_seqlen_kv, dtype=torch.int32, device="cuda"
        )
    cu_seqlens_q = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda")
    cu_seqlens_kv = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda")
    cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)
    cu_seqlens_kv[1:] = torch.cumsum(seqlens_kv, dim=0)

    dim_to_num = {
        "b": config.batch_size,
        "sq": config.max_seqlen_q,
        "skv": config.max_seqlen_kv,
        "h": config.num_heads,
        "hg": config.num_gqa_groups,
        "d": config.head_dim_qk,
        "t": cu_seqlens_q[-1],
        "tg": cu_seqlens_kv[-1],
        "3": 3,
        "2": 2,
        "1": 1,
    }
    inp = []
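    # Build Q/K/V from the qkv_layout string: a group containing a digit (e.g. "sbh3d") is
    # created as one tensor and split on that dimension; separate groups get one tensor each.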
    for i, layout in enumerate(qkv_layout.split("_")):
        layout = "_".join(layout)
        if i == 0:
            layout = layout.replace("s", "sq")
        else:
            layout = layout.replace("s", "skv")
            layout = layout.replace("h", "hg")
            layout = layout.replace("t", "tg")
        tensor_shape = [dim_to_num[j] for j in layout.split("_")]
        if config.dropout_p == 0.0:
            tensor = torch.randn(tensor_shape, dtype=dtype, device="cuda")
        else:
            # test cuDNN FP8 dropout
            tensor = torch.ones(tensor_shape, dtype=dtype, device="cuda")
        tensor_count = 1
        split_dim = 0
        for dim, l in enumerate(layout.split("_")):
            if l.isdigit():
                tensor_count = int(l)
                split_dim = dim
                break
        tensors = torch.split(tensor, 1, dim=split_dim) if split_dim != 0 else [tensor]
        for j in range(tensor_count):
            if split_dim != 0:
                inp.append(tensors[j].squeeze(split_dim))
            else:
                inp.append(tensors[j])
    for i in range(3):
        inp[i].requires_grad = True

    qkv_format_kv = "_".join(qkv_format)
    qkv_format_kv = qkv_format_kv.replace("s", "sq")
    out_grad_shape = [dim_to_num[i] for i in qkv_format_kv.split("_")]
    out_grad_shape_new = [*out_grad_shape[:-2], out_grad_shape[-2] * out_grad_shape[-1]]
    out_grad = torch.randn(out_grad_shape_new, dtype=dtype, device="cuda")

    with autocast(enabled=fp8_dpa, recipe=fp8_recipe):
        out = dpa(
            inp[0],
            inp[1],
            inp[2],
            qkv_format=qkv_format,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_kv=cu_seqlens_kv,
            max_seqlen_q=config.max_seqlen_q,
            max_seqlen_kv=config.max_seqlen_kv,
            attn_mask_type=config.attn_mask_type,
            checkpoint_core_attention=False,
            core_attention_bias_type=config.attn_bias_type,
            fp8_output=fp8_dpa,
        )
    if is_training:
        out.backward(out_grad)

    if is_training:
        return out, (inp[0].grad, inp[1].grad, inp[2].grad)
    return out, (None, None, None)


model_configs_fp8 = {
    # test: ModelConfig(b, sq, hq, dqk)
    "fp8_1": ModelConfig(1, 512, 1, 64),
    "fp8_2": ModelConfig(4, 512, 16, 64),
    "fp8_3": ModelConfig(1, 2048, 1, 128),
    "fp8_4": ModelConfig(2, 2048, 24, 128),
    "fp8_5": ModelConfig(1, 512, 1, 64, attn_mask_type="causal"),
    "fp8_6": ModelConfig(4, 512, 16, 64, attn_mask_type="causal"),
    "fp8_7": ModelConfig(1, 2048, 1, 128, attn_mask_type="causal"),
    "fp8_8": ModelConfig(2, 2048, 24, 128, attn_mask_type="causal"),
}
param_types_fp8 = [torch.float16, torch.bfloat16]
cudnn_frontend_version = int(os.getenv("NVTE_FUSED_ATTN_FE_VER", "1"))
models_v0 = ["fp8_1", "fp8_2", "fp8_5", "fp8_6"]
models_v1 = ["fp8_3", "fp8_4", "fp8_7", "fp8_8"]


@pytest.mark.skipif(
    (
        get_cudnn_version() < (8, 9, 3)
        if cudnn_frontend_version == 0
        else get_cudnn_version() < (9, 2, 1)
    ),
    reason=f"""cuDNN {"8.9.3" if cudnn_frontend_version == 0 else "9.2.1"}+ is required.""",
)
@pytest.mark.skipif(not fp8_attn_available, reason=reason_for_no_fp8_attn)
@pytest.mark.parametrize("dtype", param_types_fp8)
@pytest.mark.parametrize("model", models_v1 if cudnn_frontend_version == 1 else models_v0)
def test_custom_mha_fp8_vs_f16(dtype, model):
    """Test FP8 dot product attention implementations based on cuDNN frontend
    v0.9 and v1.0+. Each test compares results from a custom implementation of
    an FP8 MHA module, i.e. Custom_MHA_FP8(), to results from an F16 MHA
    implementation, i.e. transformer_engine.pytorch.attention.MultiHeadAttention.
    Both paths take F16 input and output. QKV layout is t3hd or bs3hd"""

    config = model_configs_fp8[model]

    # Test backend availability
    is_training = True
    available_backends, _, fused_attn_backends = get_available_attention_backends(
        config,
        qkv_dtype=torch.float8_e4m3fn,
        qkv_layout="t3hd" if cudnn_frontend_version == 0 else "bs3hd",
        is_training=is_training,
    )
    flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
    if not (fused_attn_backends and unfused_attn_supported):
        pytest.skip("Not enough backends to run this test with.")

    fused_attn_fwd_fp8, fused_attn_bwd_fp8 = _run_custom_mha_fp8(dtype, config, "FusedAttention")
    unfused_attn_fwd_f16, unfused_attn_bwd_f16 = _run_ref_mha_f16(dtype, config, "UnfusedAttention")

    atol = 5e-1
    rtol = 5e-1
    rmse_tol = 0.13
    compare_and_assert(
        fused_attn_fwd_fp8,
        unfused_attn_fwd_f16,
        "fused_attn_fwd_fp8",
        "unfused_attn_fwd_f16",
        atol,
        rtol,
        rmse_tol,
        True,
    )
    compare_and_assert(
        fused_attn_bwd_fp8,
        unfused_attn_bwd_f16,
        "fused_attn_bwd_fp8",
        "unfused_attn_bwd_f16",
        atol,
        rtol,
        rmse_tol,
        True,
    )


def _run_custom_mha_fp8(dtype, config, backend):
    """Run Custom_MHA_FP8 with FP8 FusedAttention backend. Both input and output
    are in F16. QKV GEMM, DPA, and projection GEMM are calculated in FP8."""
    reset_rng_states()
    os.environ["NVTE_FLASH_ATTN"] = "0"
    os.environ["NVTE_FUSED_ATTN"] = "0"
    if backend == "FlashAttention":
        os.environ["NVTE_FLASH_ATTN"] = "1"
    if backend == "FusedAttention":
        os.environ["NVTE_FUSED_ATTN"] = "1"
    _attention_backends["backend_selection_requires_update"] = True

    inp = 0.0001 * torch.randint(
        -100,
        100,
        (config.batch_size * config.max_seqlen_q, config.num_heads * config.head_dim_qk),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )
    seqlens = torch.full([config.batch_size], config.max_seqlen_q, dtype=torch.int32, device="cuda")
    cu_seqlens = torch.zeros(config.batch_size + 1, device="cuda", dtype=torch.int32)
    cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)

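    # Create the output gradient and save it to disk so the F16 reference run reuses
    # exactly the same values.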
    out_grad = 0.01 * torch.randn(
        config.batch_size * config.max_seqlen_q,
        config.num_heads * config.head_dim_qk,
        dtype=dtype,
        device="cuda",
    )
    torch.save(out_grad, "out_grad.pt")

    fp8_recipe = recipe.DelayedScaling(
        margin=0,
        fp8_format=recipe.Format.HYBRID,
        amax_history_len=1,
        amax_compute_algo="most_recent",
    )

    mha = Custom_MHA_FP8(config).to(dtype=dtype, device="cuda")
    with autocast(enabled=True, recipe=fp8_recipe):
        out = mha(inp, cu_seqlens, config.max_seqlen_q)
    out.backward(out_grad)

    out = torch.load("out.pt")
    dqkv = torch.load("dqkv.pt")
    return (
        out.view(config.batch_size, config.max_seqlen_q, -1),
        dqkv.view(
            config.batch_size, config.max_seqlen_q, 3, config.num_heads, config.head_dim_qk
        ).contiguous(),
    )


def _run_ref_mha_f16(dtype, config, backend):
    """Run the F16 reference attention with the given backend. Both input and output
    are in F16. QKV GEMM, DPA, and projection GEMM are also in F16."""

    os.environ["NVTE_FLASH_ATTN"] = "0"
    os.environ["NVTE_FUSED_ATTN"] = "0"
    if backend == "FlashAttention":
        os.environ["NVTE_FLASH_ATTN"] = "1"
    if backend == "FusedAttention":
        os.environ["NVTE_FUSED_ATTN"] = "1"
    _attention_backends["backend_selection_requires_update"] = True

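    # Reuse the QKV tensor saved by the FP8 run so both paths see identical inputs.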
    inp = torch.load("qkv.pt").to(device="cuda")
    inp.requires_grad = True
    seqlens = torch.full([config.batch_size], config.max_seqlen_q, dtype=torch.int32, device="cuda")
    cu_seqlens = torch.zeros(config.batch_size + 1, device="cuda", dtype=torch.int32)
    cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)
    out_grad = (
        torch.load("out_grad.pt").to(device="cuda").view(config.batch_size, config.max_seqlen_q, -1)
    )

    _DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
    _DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed)

    def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
        """Get cuda rng tracker."""
        return _DUMMY_CUDA_RNG_STATE_TRACKER

    block = DotProductAttention(
        config.num_heads,
        config.head_dim_qk,
        attention_dropout=config.dropout_p,
        sequence_parallel=False,
        tp_size=1,
        get_rng_state_tracker=get_dummy_cuda_rng_tracker,
        tp_group=None,
        layer_number=1,
        attention_type="self",
        qkv_format="bshd",
    ).to(dtype=dtype, device="cuda")

    q = inp[:, :, 0, :, :]
    k = inp[:, :, 1, :, :]
    v = inp[:, :, 2, :, :]
    out = block(q, k, v, attn_mask_type=config.attn_mask_type)
    out.backward(out_grad)

    return out, inp.grad


_CUBLASLT_WORKSPACE_SIZE_BYTES = 33_554_432  # 32MiB
_2X_ACC_FPROP = False
_2X_ACC_DGRAD = False
_2X_ACC_WGRAD = False

META_QKV = tex.FP8FwdTensors.GEMM1_OUTPUT
META_DQKV = tex.FP8BwdTensors.GRAD_OUTPUT1
META_O = tex.FP8FwdTensors.GEMM2_INPUT
META_DO = tex.FP8BwdTensors.GRAD_INPUT2
META_S = tex.FP8FwdTensors.GEMM3_OUTPUT
META_DP = tex.FP8BwdTensors.GRAD_INPUT3


class _custom_mha_fp8(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        inp: torch.Tensor,
        qkv_weight: torch.Tensor,
        qkv_bias: torch.Tensor,
        cu_seqlens: torch.Tensor,
        num_heads: int,
        p_dropout: float,
        max_s: int,
        fast_zero_fill: bool,
        fp8_meta: Dict[str, Any],
        is_training: bool,
        mask_type: str,
        quantizers: list[Quantizer],
    ) -> torch.Tensor:
        qkv_dtype = inp.dtype

        assert inp.dim() == 2
        in_features = qkv_weight.shape[-1]
        h = num_heads
        d = in_features // h
        b = cu_seqlens.numel() - 1

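        # Fetch per-tensor FP8 quantizers from the module's scaling metadata: forward
        # entries for input/weight/QKV/O, backward entries for dO/dQKV/S/dP.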
        input_quantizer = quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT]
        qkv_quantizer = quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_INPUT]
        qkv_weight_quantizer = quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT]
        o_quantizer = quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_OUTPUT]
        dO_quantizer = quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT1]
        dQKV_quantizer = quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_INPUT1]
        s_quantizer = quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT2]
        dP_quantizer = quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT3]

        inp_fp8 = input_quantizer(inp)

        qkv_weight_fp8 = qkv_weight_quantizer(qkv_weight)

        qkv, *_ = ext.general_gemm(
            qkv_weight_fp8,
            inp_fp8,
            bias=qkv_bias,
            out_dtype=qkv_weight_fp8.dtype,
            quantization_params=qkv_quantizer,
            use_split_accumulator=_2X_ACC_FPROP,
        )
        qkv = qkv.view(-1, 3, h, d)
        qkv_fp16 = qkv.dequantize().view(b, max_s, 3, h, d).contiguous()
        torch.save(qkv_fp16, "qkv.pt")
        if cudnn_frontend_version == 1:
            qkv = qkv.view(b, max_s, 3, h, d)  # bs3hd

        # FMHA
        q_data = qkv._data[:, :, 0, :, :] if cudnn_frontend_version == 1 else qkv._data[:, 0, :, :]
        k_data = qkv._data[:, :, 1, :, :] if cudnn_frontend_version == 1 else qkv._data[:, 1, :, :]
        v_data = qkv._data[:, :, 2, :, :] if cudnn_frontend_version == 1 else qkv._data[:, 2, :, :]
        q = qkv.make_like(tensor=qkv, data=q_data, shape=q_data.shape)
        k = qkv.make_like(tensor=qkv, data=k_data, shape=k_data.shape)
        v = qkv.make_like(tensor=qkv, data=v_data, shape=v_data.shape)

        out, aux_ctx_tensors = fused_attn_fwd(
            is_training,
            max_s,
            max_s,
            cu_seqlens,
            cu_seqlens,
            q,
            k,
            v,
            qkv_dtype,
            FusedAttnBackend["FP8"],
            attn_scale=None,
            dropout=p_dropout,
            fast_zero_fill=fast_zero_fill,
            qkv_layout="bs3hd" if cudnn_frontend_version == 1 else "t3hd",
            attn_bias_type="no_bias",
            attn_mask_type=mask_type if cudnn_frontend_version == 1 else "padding",
            rng_gen=None,
            o_quantizer=o_quantizer,
            s_quantizer=s_quantizer,
        )

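        # Stash FP8 tensors and metadata needed by the hand-written backward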
        tensors_to_save, tensor_objects = prepare_for_saving(q, k, v, inp_fp8, qkv_weight_fp8, out)

        ctx.save_for_backward(*tensors_to_save)
        ctx.tensor_objects = tensor_objects
        ctx.aux_ctx_tensors = aux_ctx_tensors
        ctx.qkv_dtype = qkv_dtype
        ctx.fp8_meta = fp8_meta
        ctx.cu_seqlens = cu_seqlens
        ctx.p_dropout = p_dropout
        ctx.max_s = max_s
        ctx.fast_zero_fill = fast_zero_fill
        ctx.hidden_size = in_features
        ctx.num_heads = num_heads
        ctx.mask_type = mask_type
        ctx.dtype = inp.dtype

        ctx.dQKV_quantizer = dQKV_quantizer
        ctx.dO_quantizer = dO_quantizer
        ctx.dP_quantizer = dP_quantizer
        ctx.S_quantizer = s_quantizer

        out = out.view(-1, in_features)  # (bs)(hd)
        out_fp16 = out.dequantize()
        torch.save(out_fp16, "out.pt")  # (bs)(hd)
        return out_fp16

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
        with torch.cuda.nvtx.range("_DPA"):
            saved_tensors = ctx.saved_tensors
            (q, k, v, inp_fp8, qkv_weight_fp8, out) = restore_from_saved(
                ctx.tensor_objects, saved_tensors
            )

            proj_dgrad = ctx.dO_quantizer(grad_output)
            fp8_dtype_backward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=False)

            dq, dk, dv, *rest = fused_attn_bwd(
                ctx.max_s,
                ctx.max_s,
                ctx.cu_seqlens,
                ctx.cu_seqlens,
                q,
                k,
                v,
                out,
                proj_dgrad.view_as(out),
                ctx.qkv_dtype,
                fp8_dtype_backward,
                ctx.aux_ctx_tensors,
                FusedAttnBackend["FP8"],
                None,
                None,
                ctx.S_quantizer,
                ctx.dP_quantizer,
                ctx.dQKV_quantizer,
                attn_scale=None,
                dropout=ctx.p_dropout,
                fast_zero_fill=ctx.fast_zero_fill,
                qkv_layout="bs3hd" if cudnn_frontend_version == 1 else "t3hd",
                attn_bias_type="no_bias",
                attn_mask_type=ctx.mask_type if cudnn_frontend_version == 1 else "padding",
            )
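            # Reassemble dq/dk/dv into one interleaved dQKV tensor aliasing dq's storage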
            dim = 2 if cudnn_frontend_version == 1 else 1
            dqkv = torch.Tensor().to(device=dq._data.device, dtype=dq._data.dtype)
            dqkv_shape = list(dq._data.shape)
            dqkv_shape.insert(dim, 3)
            dqkv_stride = list(dq._data.stride())
            dqkv_stride.insert(dim, int(dqkv_stride[-3] / 3))
            dqkv.set_(
                dq._data.untyped_storage(), dq._data.storage_offset(), dqkv_shape, dqkv_stride
            )  # bs3hd

            dqkv_c = dqkv.view(-1, 3 * ctx.hidden_size)
            dqkv_c = dq.make_like(tensor=dq, data=dqkv_c, shape=dqkv_c.shape)
            dqkv_c_fp16 = dqkv_c.dequantize()
            torch.save(dqkv_c_fp16, "dqkv.pt")

            qkv_bgrad, dqkv = ext.bgrad_quantize(dqkv_c_fp16, ctx.dQKV_quantizer)
            dqkv_c._transpose = None
            dqkv_c._create_transpose()

            # QKV DGRAD
            qkv_dgrad, *_ = ext.general_gemm(
                qkv_weight_fp8,
                dqkv_c,
                ctx.dtype,
                use_split_accumulator=_2X_ACC_DGRAD,
                layout="NN",
            )

            # QKV WGRAD
            qkv_wgrad, *_ = ext.general_gemm(
                inp_fp8,
                dqkv,
                ctx.dtype,
                use_split_accumulator=_2X_ACC_WGRAD,
                layout="NT",
            )

        return (
            qkv_dgrad,
            qkv_wgrad,
            qkv_bgrad,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
        )


class Custom_MHA_FP8(TransformerEngineBaseModule):
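    """Small TransformerEngine module wrapping _custom_mha_fp8 with a fused QKV projection."""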
    def __init__(self, config, params_dtype: torch.dtype = torch.float32):
        super().__init__()
        self.p_dropout = config.dropout_p
        self.h = config.num_heads
        self.hidden_size = config.hidden_size
        self.head_dim = config.head_dim_qk
        self.fast_zero_fill = True
        self.mask_type = config.attn_mask_type

        self.qkv_weight = torch.nn.Parameter(
            torch.empty(
                self.hidden_size * 3,
                self.hidden_size,
                device=torch.cuda.current_device(),
                dtype=params_dtype,
            )
        )
        self.qkv_bias = torch.nn.Parameter(
            torch.empty(
                self.hidden_size * 3,
                device=torch.cuda.current_device(),
                dtype=params_dtype,
            )
        )
        with torch.no_grad():
            self.qkv_bias.zero_()
            self.qkv_weight.fill_(1.0)

    def forward(
        self,
        inp: torch.Tensor,
        cu_seqlens,
        max_s,
    ) -> torch.Tensor:
        with self.prepare_forward(inp, num_gemms=3) as inp:
            out = _custom_mha_fp8.apply(
                inp,
                self.qkv_weight,
                self.qkv_bias,
                cu_seqlens,
                self.h,
                self.p_dropout,
                max_s,
                self.fast_zero_fill,
                self.fp8_meta,
                self.training,
                self.mask_type,
                self.quantizers,
            )
        return out