# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import logging
import os
import sys
import pathlib
from typing import Any, Dict, Tuple, Union

import pytest
import torch

from transformer_engine.pytorch.quantization import FP8GlobalStateManager, get_fp8_te_dtype
from transformer_engine.common import recipe

from transformer_engine.pytorch import (
    TransformerLayer,
    autocast,
    quantized_model_init,
    DotProductAttention,
    MultiheadAttention,
    get_device_compute_capability,
    Quantizer,
    is_fp8_available,
    is_bf16_available,
)
from transformer_engine.pytorch.attention.dot_product_attention import (
    _attention_backends,
)
from transformer_engine.pytorch.attention.dot_product_attention.utils import (
    FlashAttentionUtils,
    check_set_window_size,
)
from transformer_engine.pytorch.attention import RotaryPositionEmbedding

import transformer_engine.pytorch.cpp_extensions as ext
from transformer_engine.pytorch.cpp_extensions.fused_attn import (
    FusedAttnBackend,
    fused_attn_bwd,
    fused_attn_fwd,
)
from transformer_engine.pytorch.distributed import CudaRNGStatesTracker
from transformer_engine.pytorch.module.base import TransformerEngineBaseModule
from transformer_engine.pytorch.utils import (
    init_method_normal,
    scaled_init_method_normal,
)
from transformer_engine.pytorch.utils import get_cudnn_version
import transformer_engine_torch as tex
from transformer_engine.pytorch.tensor.quantized_tensor import (
    prepare_for_saving,
    restore_from_saved,
)

_current_file = pathlib.Path(__file__).resolve()
sys.path.append(str(_current_file.parent.parent))
from utils import (
    reset_rng_states,
    compare_and_assert,
    ModelConfig,
    dtype_tols,
    get_available_attention_backends,
)

# Check if hardware supports FP8
fp8_available, reason_for_no_fp8 = is_fp8_available(return_reason=True)

# Reset RNG seed and states
seed = 1234
reset_rng_states()


# Reset FP8 global state manager
@pytest.fixture(autouse=True)
def reset_global_fp8_state():
    yield
    FP8GlobalStateManager.reset()


# Define F16 data types to test
param_types = [torch.float16]
if is_bf16_available():
    param_types.append(torch.bfloat16)
param_types_lean = [torch.bfloat16]

model_configs_base = {
    # test: ModelConfig(b, sq, hq, dqk)
    "base_1_0": ModelConfig(8, 128, 16, 64),
    "base_1_1": ModelConfig(4, 128, 16, 64, max_seqlen_kv=256),
    "base_2_0": ModelConfig(2, 2048, 24, 128),
    "base_2_1": ModelConfig(1, 2048, 24, 128, max_seqlen_kv=4096),
    "base_3_0": ModelConfig(8, 1, 16, 128, max_seqlen_kv=2048),
    "base_3_1": ModelConfig(8, 1, 16, 256, max_seqlen_kv=2048),
    "base_4_0": ModelConfig(8, 1, 16, 192, max_seqlen_kv=2048),
    "base_4_1": ModelConfig(8, 128, 16, 192, max_seqlen_kv=2048),
    "base_5_0": ModelConfig(8, 1, 16, 512, max_seqlen_kv=2048),
    "base_5_1": ModelConfig(8, 128, 16, 512, max_seqlen_kv=2048),
    "base_6_0": ModelConfig(8, 1, 16, 1024, max_seqlen_kv=2048),
    "base_6_1": ModelConfig(8, 128, 16, 1024, max_seqlen_kv=2048),
}
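
# A hedged gloss on the legend above (comments only, nothing executed):
# assuming ModelConfig's positional order is (batch_size, max_seqlen_q,
# num_heads, head_dim_qk), "base_1_0" is a self-attention case with batch 8,
# sequence length 128, 16 query heads, and 64-dim QK heads, while a keyword
# such as max_seqlen_kv=256 ("base_1_1") makes it a cross-attention case.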


@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model_configs", [model_configs_base])
@pytest.mark.parametrize("model", model_configs_base.keys())
@pytest.mark.parametrize("ckpt_attn", [False])
@pytest.mark.parametrize("workspace_opt", [True, False])
@pytest.mark.parametrize("qkv_layout", [None])
@pytest.mark.parametrize("swa", [False])
@pytest.mark.parametrize("pad_between_seqs", [False])
def test_dot_product_attention(
    dtype, model_configs, model, ckpt_attn, workspace_opt, qkv_layout, swa, pad_between_seqs
):
    """Test DotProductAttention module"""

    # Get configs
    tols = dict(atol=1e-3, rtol=1e-3)
    if dtype == torch.bfloat16:
        tols = dict(atol=1.5e-2, rtol=1.5e-2)
    config = model_configs[model]
    is_mla = config.head_dim_qk != config.head_dim_v
    is_mqa_gqa = config.num_heads != config.num_gqa_groups
    if qkv_layout is None:
        if config.attn_type == "self":
            qkv_layout = "sb3hd" if not is_mla and not is_mqa_gqa else "sbhd_sbhd_sbhd"
        else:
            qkv_layout = "bshd_bs2hd" if not is_mla and not is_mqa_gqa else "bshd_bshd_bshd"
    if "3" in qkv_layout and config.attn_type == "cross":
        pytest.skip("No need to test this layout for cross attention")

    if config.window_size == (-1, -1) and swa:
        config.window_size = [2, 2]
    config.window_size = check_set_window_size(config.attn_mask_type, config.window_size)
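    # Hedged note on check_set_window_size: it reconciles the window with the
    # mask type, e.g. under a "causal" mask a (2, 2) window is expected to come
    # back as (2, 0) since no right-side context may be attended to, and
    # (-1, -1) means "no sliding window".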

    # Get backends
    is_training = True
    available_backends, _, fused_attn_backends = get_available_attention_backends(
        config,
        qkv_dtype=dtype,
        qkv_layout=qkv_layout,
        pad_between_seqs=pad_between_seqs,
        is_training=is_training,
    )
    flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
    if not fused_attn_supported:
        is_training = False
        available_backends, _, fused_attn_backends = get_available_attention_backends(
            config,
            qkv_dtype=dtype,
            qkv_layout=qkv_layout,
            pad_between_seqs=pad_between_seqs,
            is_training=is_training,
        )
        flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends

    # FlashAttention does not support pad_between_seqs, but _run_dot_product_attention
    # manually pads and unpads the input and output of FlashAttention for testing purposes
    if (
        pad_between_seqs
        and FlashAttentionUtils.is_installed
        and not (
            config.max_seqlen_q != config.max_seqlen_kv
            and config.attn_mask_type in ["causal", "padding_causal"]
        )
        and (config.window_size[0] == -1 or FlashAttentionUtils.v2_3_plus)
    ):
        flash_attn_supported = True

    # Skip if only unfused backend is supported
    if (len(fused_attn_backends) + flash_attn_supported + unfused_attn_supported) < 2:
        pytest.skip("Less than two backends to compare.")

    # UnfusedDotProductAttention backend
    if unfused_attn_supported:
        unfused_attn_fwd, unfused_attn_bwd = _run_dot_product_attention(
            dtype,
            config,
            "UnfusedDotProductAttention",
            ckpt_attn,
            qkv_layout,
            workspace_opt,
            pad_between_seqs,
            is_training,
        )

    # FusedAttention backend
    if fused_attn_supported:
        if len(fused_attn_backends) == 1:
            fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(
                dtype,
                config,
                "FusedAttention",
                ckpt_attn,
                qkv_layout,
                workspace_opt,
                pad_between_seqs,
                is_training,
            )
        if len(fused_attn_backends) == 2:
            os.environ["NVTE_FUSED_ATTN_BACKEND"] = "0"
            fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(
                dtype,
                config,
                "FusedAttention",
                ckpt_attn,
                qkv_layout,
                workspace_opt,
                pad_between_seqs,
                is_training,
            )
            os.environ["NVTE_FUSED_ATTN_BACKEND"] = "1"
            fused_attn_fwd_1, fused_attn_bwd_1 = _run_dot_product_attention(
                dtype,
                config,
                "FusedAttention",
                ckpt_attn,
                qkv_layout,
                workspace_opt,
                pad_between_seqs,
                is_training,
            )

    # FlashAttention backend
    if flash_attn_supported:
        flash_attn_fwd, flash_attn_bwd = _run_dot_product_attention(
            dtype,
            config,
            "FlashAttention",
            ckpt_attn,
            qkv_layout,
            workspace_opt,
            pad_between_seqs,
            is_training,
        )

    # Compare results
    logging.info(f"[test_dot_product_attention]: is_training = {is_training}")
    if unfused_attn_supported and flash_attn_supported:
        logging.info("[test_dot_product_attention]: unfused attn vs flash attn")
        torch.testing.assert_close(flash_attn_fwd, unfused_attn_fwd, **tols)
        for i, _ in enumerate(flash_attn_bwd):
            torch.testing.assert_close(unfused_attn_bwd[i], flash_attn_bwd[i], **tols)
    if unfused_attn_supported and fused_attn_supported:
        logging.info("[test_dot_product_attention]: unfused attn vs fused attn")
        torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, **tols)
        for i, _ in enumerate(unfused_attn_bwd):
            torch.testing.assert_close(fused_attn_bwd[i], unfused_attn_bwd[i], **tols)
    if fused_attn_supported and flash_attn_supported:
        logging.info("[test_dot_product_attention]: fused attn vs flash attn")
        torch.testing.assert_close(fused_attn_fwd, flash_attn_fwd, **tols)
        for i, _ in enumerate(flash_attn_bwd):
            torch.testing.assert_close(fused_attn_bwd[i], flash_attn_bwd[i], **tols)
    if fused_attn_supported and len(fused_attn_backends) == 2:
        logging.info("[test_dot_product_attention]: fused attn backend 0 vs 1")
        torch.testing.assert_close(fused_attn_fwd, fused_attn_fwd_1, **tols)
        for i, _ in enumerate(fused_attn_bwd):
            torch.testing.assert_close(fused_attn_bwd[i], fused_attn_bwd_1[i], **tols)


@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model_configs", [model_configs_base])
@pytest.mark.parametrize("model", ["base_1_1", "base_2_1"])
def test_dpa_checkpoint(dtype, model_configs, model):
    """Test DotProductAttention module with checkpointing"""
    test_dot_product_attention(dtype, model_configs, model, True, True, None, False, False)


model_configs_softmax = {
    # test: ModelConfig(b, sq, hq, dqk)
    "softmax_1_0": ModelConfig(2, 2048, 64, 64, num_gqa_groups=8),
    "softmax_1_1": ModelConfig(2, 2048, 64, 64, num_gqa_groups=8, softmax_type="off-by-one"),
    "softmax_1_2": ModelConfig(2, 2048, 64, 64, num_gqa_groups=8, softmax_type="learnable"),
    "softmax_2_0": ModelConfig(2, 2048, 64, 64, num_gqa_groups=8, attn_mask_type="causal"),
    "softmax_2_1": ModelConfig(
        2, 2048, 64, 64, num_gqa_groups=8, attn_mask_type="causal", softmax_type="off-by-one"
    ),
    "softmax_2_2": ModelConfig(
        2, 2048, 64, 64, num_gqa_groups=8, attn_mask_type="causal", softmax_type="learnable"
    ),
    "softmax_3_0": ModelConfig(2, 2048, 64, 64, num_gqa_groups=8, attn_mask_type="padding"),
    "softmax_3_1": ModelConfig(
        2, 2048, 64, 64, num_gqa_groups=8, attn_mask_type="padding", softmax_type="off-by-one"
    ),
    "softmax_3_2": ModelConfig(
        2, 2048, 64, 64, num_gqa_groups=8, attn_mask_type="padding", softmax_type="learnable"
    ),
    "softmax_4_0": ModelConfig(
        2, 2048, 64, 64, num_gqa_groups=8, window_size=(128, 0), attn_mask_type="causal"
    ),
    "softmax_4_1": ModelConfig(
        2,
        2048,
        64,
        64,
        num_gqa_groups=8,
        window_size=(128, 0),
        attn_mask_type="causal",
        softmax_type="off-by-one",
    ),
    "softmax_4_2": ModelConfig(
        2,
        2048,
        64,
        64,
        num_gqa_groups=8,
        window_size=(128, 0),
        attn_mask_type="causal",
        softmax_type="learnable",
    ),
    "softmax_5_0": ModelConfig(
        2, 2048, 64, 64, num_gqa_groups=8, window_size=(128, 0), attn_mask_type="padding_causal"
    ),
    "softmax_5_1": ModelConfig(
        2,
        2048,
        64,
        64,
        num_gqa_groups=8,
        window_size=(128, 0),
        attn_mask_type="padding_causal",
        softmax_type="off-by-one",
    ),
    "softmax_5_2": ModelConfig(
        2,
        2048,
        64,
        64,
        num_gqa_groups=8,
        window_size=(128, 0),
        attn_mask_type="padding_causal",
        softmax_type="learnable",
    ),
}


@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("model_configs", [model_configs_softmax])
@pytest.mark.parametrize("model", model_configs_softmax.keys())
def test_dpa_softmax(dtype, model_configs, model):
    """Test DotProductAttention module with different softmax types"""
    test_dot_product_attention(
        dtype, model_configs, model, True, True, "bshd_bshd_bshd", False, False
    )


model_configs_mla = {
    # TODO: FlashAttention on ROCm only supports MLA with head_dim_qk == head_dim_v
    #    test:             b,  h, hg, dqk, sq, skv,   p,      mask,      bias   # attn, backend
    # "mla_1_0": ModelConfig(8, 128, 16, 64, head_dim_v=128),  # self , 0
    # "mla_1_1": ModelConfig(4, 128, 16, 64, max_seqlen_kv=256, head_dim_v=128),  # cross, 0
    # "mla_1_2": ModelConfig(4, 128, 16, 192, max_seqlen_kv=256, head_dim_v=128),  # cross, 0
    # "mla_2_0": ModelConfig(2, 2048, 24, 128, attn_mask_type="causal", head_dim_v=64),  # self , 1
    # "mla_2_1": ModelConfig(
    #     1, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal", head_dim_v=64
    # ),  # cross, 1
    # "mla_2_2": ModelConfig(
    #     1, 2048, 24, 192, max_seqlen_kv=4096, attn_mask_type="causal", head_dim_v=128
    # ),  # cross, 1
    # "mla_3_0": ModelConfig(8, 1, 16, 128, max_seqlen_kv=2048, head_dim_v=64),  # inference
    # "mla_3_1": ModelConfig(8, 1, 16, 256, max_seqlen_kv=2048, head_dim_v=128),  # inference
    # "mla_3_2": ModelConfig(8, 1, 16, 192, max_seqlen_kv=2048, head_dim_v=128),  # inference
    # "mla_3_3": ModelConfig(8, 1, 16, 160, max_seqlen_kv=2048, head_dim_v=128),  # inference
    "mla_3_4": ModelConfig(8, 1, 16, 160, max_seqlen_kv=2048, head_dim_v=160),  # inference
}


@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model_configs", [model_configs_mla])
@pytest.mark.parametrize("model", model_configs_mla.keys())
def test_dpa_mla(dtype, model_configs, model):
    """Test DotProductAttention module with Multi-Latent Attention (MLA)"""
    test_dot_product_attention(dtype, model_configs, model, True, True, None, False, False)


model_configs_mask = {
    # test: ModelConfig(b, sq, hq, dqk)
    "mask_1_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="causal"),
    "mask_1_1": ModelConfig(2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="causal"),
    "mask_1_2": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal"),
    "mask_2_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="causal_bottom_right"),
    "mask_2_1": ModelConfig(
        2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="causal_bottom_right"
    ),
    "mask_2_2": ModelConfig(
        2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal_bottom_right"
    ),
    "mask_3_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding"),
    "mask_3_1": ModelConfig(2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding"),
    "mask_3_2": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding"),
    "mask_4_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding_causal"),
    "mask_4_1": ModelConfig(2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding_causal"),
    "mask_4_2": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding_causal"),
    "mask_5_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding_causal_bottom_right"),
    "mask_5_1": ModelConfig(
        2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding_causal_bottom_right"
    ),
    "mask_5_2": ModelConfig(
        2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding_causal_bottom_right"
    ),
    "mask_6_0": ModelConfig(2, 1, 16, 128, max_seqlen_kv=2048, attn_mask_type="causal"),
    "mask_6_1": ModelConfig(2, 1, 16, 256, max_seqlen_kv=2048, attn_mask_type="causal"),
    "mask_7_0": ModelConfig(
        2, 1, 16, 128, max_seqlen_kv=2048, attn_mask_type="causal_bottom_right"
    ),
    "mask_7_1": ModelConfig(
        2, 1, 16, 256, max_seqlen_kv=2048, attn_mask_type="causal_bottom_right"
    ),
    "mask_8_0": ModelConfig(2, 1, 24, 128, max_seqlen_kv=2048, attn_mask_type="padding"),
    "mask_8_1": ModelConfig(2, 1, 16, 256, max_seqlen_kv=2048, attn_mask_type="padding"),
    "mask_9_0": ModelConfig(2, 1, 24, 128, max_seqlen_kv=2048, attn_mask_type="padding_causal"),
    "mask_9_1": ModelConfig(2, 1, 16, 256, max_seqlen_kv=2048, attn_mask_type="padding_causal"),
    "mask_10_0": ModelConfig(
        2, 1, 24, 128, max_seqlen_kv=2048, attn_mask_type="padding_causal_bottom_right"
    ),
    "mask_10_1": ModelConfig(
        2, 1, 16, 256, max_seqlen_kv=2048, attn_mask_type="padding_causal_bottom_right"
    ),
}


@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_mask])
@pytest.mark.parametrize("model", model_configs_mask.keys())
def test_dpa_mask(dtype, model_configs, model):
    """Test DotProductAttention module with different mask types"""
    test_dot_product_attention(dtype, model_configs, model, False, True, None, False, False)


model_configs_bias = {
    # test: ModelConfig(b, sq, hq, dqk)
    "bias_1_0": ModelConfig(4, 128, 16, 64, attn_bias_type="post_scale_bias"),
    "bias_1_1": ModelConfig(2, 128, 16, 64, max_seqlen_kv=256, attn_bias_type="post_scale_bias"),
    "bias_1_2": ModelConfig(4, 2048, 24, 128, attn_bias_type="post_scale_bias"),
    "bias_1_3": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_bias_type="post_scale_bias"),
    "bias_1_4": ModelConfig(4, 2048, 24, 128, attn_bias_type="alibi"),
    "bias_1_5": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_bias_type="alibi"),
    "bias_2_0": ModelConfig(
        4, 128, 16, 64, attn_mask_type="padding", attn_bias_type="post_scale_bias"
    ),
    "bias_2_1": ModelConfig(
        2,
        128,
        16,
        64,
        max_seqlen_kv=256,
        attn_mask_type="padding",
        attn_bias_type="post_scale_bias",
    ),
    "bias_2_2": ModelConfig(
        4, 2048, 24, 128, attn_mask_type="padding", attn_bias_type="post_scale_bias"
    ),
    "bias_2_3": ModelConfig(
        2,
        2048,
        24,
        128,
        max_seqlen_kv=4096,
        attn_mask_type="padding",
        attn_bias_type="post_scale_bias",
    ),
    "bias_2_4": ModelConfig(4, 2048, 24, 128, attn_mask_type="padding", attn_bias_type="alibi"),
    "bias_2_5": ModelConfig(
        2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding", attn_bias_type="alibi"
    ),
    "bias_3_0": ModelConfig(
        4, 128, 16, 64, attn_mask_type="causal", attn_bias_type="post_scale_bias"
    ),
    "bias_3_1": ModelConfig(
        2, 128, 16, 64, max_seqlen_kv=256, attn_mask_type="causal", attn_bias_type="post_scale_bias"
    ),
    "bias_3_2": ModelConfig(
        4, 2048, 24, 128, attn_mask_type="causal", attn_bias_type="post_scale_bias"
    ),
    "bias_3_3": ModelConfig(
        2,
        2048,
        24,
        128,
        max_seqlen_kv=4096,
        attn_mask_type="causal",
        attn_bias_type="post_scale_bias",
    ),
    "bias_3_4": ModelConfig(4, 2048, 24, 128, attn_mask_type="causal", attn_bias_type="alibi"),
    "bias_3_5": ModelConfig(
        2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal", attn_bias_type="alibi"
    ),
    "bias_4_0": ModelConfig(
        4, 128, 16, 64, attn_mask_type="padding_causal", attn_bias_type="post_scale_bias"
    ),
    "bias_4_1": ModelConfig(
        2,
        128,
        16,
        64,
        max_seqlen_kv=256,
        attn_mask_type="padding_causal",
        attn_bias_type="post_scale_bias",
    ),
    "bias_4_2": ModelConfig(
        4, 2048, 24, 128, attn_mask_type="padding_causal", attn_bias_type="post_scale_bias"
    ),
    "bias_4_3": ModelConfig(
        2,
        2048,
        24,
        128,
        max_seqlen_kv=4096,
        attn_mask_type="padding_causal",
        attn_bias_type="post_scale_bias",
    ),
    "bias_4_4": ModelConfig(
        4, 2048, 24, 128, attn_mask_type="padding_causal", attn_bias_type="alibi"
    ),
    "bias_4_5": ModelConfig(
        2,
        2048,
        24,
        128,
        max_seqlen_kv=4096,
        attn_mask_type="padding_causal",
        attn_bias_type="alibi",
    ),
}


@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_bias])
@pytest.mark.parametrize("model", model_configs_bias.keys())
def test_dpa_bias(dtype, model_configs, model):
    """Test DotProductAttention module with different bias types"""
    test_dot_product_attention(dtype, model_configs, model, False, True, None, False, False)


model_configs_bias_shapes = {
    # test: ModelConfig(b, sq, hq, dqk)
    "bias_1_0": ModelConfig(4, 128, 16, 64, attn_bias_type="post_scale_bias", bias_shape="11ss"),
    "bias_1_1": ModelConfig(2, 128, 16, 64, attn_bias_type="post_scale_bias", bias_shape="1hss"),
    "bias_1_2": ModelConfig(4, 2048, 24, 128, attn_bias_type="post_scale_bias", bias_shape="b1ss"),
    "bias_1_3": ModelConfig(2, 2048, 24, 128, attn_bias_type="post_scale_bias", bias_shape="bhss"),
    "bias_1_4": ModelConfig(
        4,
        2048,
        24,
        128,
        attn_mask_type="causal",
        attn_bias_type="alibi",
        bias_shape="1hss",
        alibi_type="custom",
    ),
    "bias_1_5": ModelConfig(
        2,
        2048,
        24,
        128,
        attn_mask_type="causal",
        attn_bias_type="alibi",
        bias_shape="bhss",
        alibi_type="custom",
    ),
}


@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_bias_shapes])
@pytest.mark.parametrize("model", model_configs_bias_shapes.keys())
def test_dpa_bias_shapes(dtype, model_configs, model):
    """Test DotProductAttention module with different bias types and shapes"""
    test_dot_product_attention(dtype, model_configs, model, False, True, None, False, False)


model_configs_swa = {
    # test: ModelConfig(b, sq, hq, dqk)
    "swa_1_1": ModelConfig(2, 2048, 16, 64),
    "swa_1_2": ModelConfig(2, 2048, 24, 128, num_gqa_groups=4),
    "swa_1_3": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096),
    "swa_2_1": ModelConfig(2, 2048, 16, 64, attn_mask_type="causal"),
    "swa_2_2": ModelConfig(2, 2048, 24, 128, num_gqa_groups=4, attn_mask_type="causal"),
    "swa_2_3": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal"),
    "swa_3_1": ModelConfig(2, 2048, 16, 64, attn_mask_type="causal_bottom_right"),
    "swa_3_2": ModelConfig(
        2, 2048, 24, 128, num_gqa_groups=4, attn_mask_type="causal_bottom_right"
    ),
    "swa_3_3": ModelConfig(
        2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal_bottom_right"
    ),
    "swa_4_1": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding"),
    "swa_4_2": ModelConfig(2, 2048, 24, 128, num_gqa_groups=4, attn_mask_type="padding"),
    "swa_4_3": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding"),
    "swa_5_1": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding_causal"),
    "swa_5_2": ModelConfig(2, 2048, 24, 128, num_gqa_groups=4, attn_mask_type="padding_causal"),
    "swa_5_3": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding_causal"),
    "swa_6_1": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding_causal_bottom_right"),
    "swa_6_2": ModelConfig(
        2, 2048, 24, 128, num_gqa_groups=4, attn_mask_type="padding_causal_bottom_right"
    ),
    "swa_6_3": ModelConfig(
        2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding_causal_bottom_right"
    ),
}


@pytest.mark.skipif(not FlashAttentionUtils.v2_3_plus, reason="Flash-attn 2.3+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_swa])
@pytest.mark.parametrize("model", model_configs_swa.keys())
def test_dpa_sliding_window(dtype, model_configs, model):
    """Test DotProductAttention module with sliding window attention"""
    test_dot_product_attention(dtype, model_configs, model, False, True, None, True, False)


model_configs_alibi_slopes = {
    # test: ModelConfig(b, sq, hq, dqk)
    "alibi_1_0": ModelConfig(
        2, 128, 16, 64, attn_mask_type="causal", attn_bias_type="alibi", alibi_type="vanilla"
    ),
    "alibi_1_1": ModelConfig(
        1,
        128,
        16,
        64,
        max_seqlen_kv=256,
        attn_mask_type="causal",
        attn_bias_type="alibi",
        alibi_type="vanilla",
    ),
    "alibi_2_0": ModelConfig(
        2, 1024, 24, 128, attn_mask_type="causal", attn_bias_type="alibi", alibi_type="custom"
    ),
    "alibi_2_1": ModelConfig(
        1,
        1024,
        24,
        128,
        max_seqlen_kv=2048,
        attn_mask_type="causal",
        attn_bias_type="alibi",
        alibi_type="custom",
    ),
}


@pytest.mark.skipif(not FlashAttentionUtils.v2_3_plus, reason="Flash-attn 2.3+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_alibi_slopes])
@pytest.mark.parametrize("model", model_configs_alibi_slopes.keys())
def test_dpa_alibi_slopes(dtype, model_configs, model):
    """Test DotProductAttention module with ALiBi slopes"""
    test_dot_product_attention(dtype, model_configs, model, False, True, None, False, False)


qkv_layouts = [
    "sb3hd",
    "sbh3d",
    "sbhd_sb2hd",
    "sbhd_sbh2d",
    "sbhd_sbhd_sbhd",
    "bs3hd",
    "bsh3d",
    "bshd_bs2hd",
    "bshd_bsh2d",
    "bshd_bshd_bshd",
]
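
# A hedged reading of the layout strings: a digit marks where tensors are
# packed along an extra dimension, e.g. "sb3hd" holds Q, K, and V interleaved
# in one [s, b, 3, h, d] tensor, "sbhd_sb2hd" packs only K and V together, and
# "sbhd_sbhd_sbhd" keeps three separate tensors.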


model_configs_layout = {
    # test: ModelConfig(b, sq, hq, dqk)
    "layout_0_0": ModelConfig(2, 128, 16, 64),
    "layout_0_1": ModelConfig(
        2, 128, 16, 64, attn_mask_type="causal", attn_bias_type="post_scale_bias"
    ),
    "layout_0_2": ModelConfig(1, 128, 16, 64, max_seqlen_kv=256, attn_mask_type="padding"),
    "layout_0_3": ModelConfig(
        1,
        128,
        16,
        64,
        max_seqlen_kv=256,
        attn_mask_type="padding_causal",
        attn_bias_type="post_scale_bias",
    ),
    "layout_1_0": ModelConfig(2, 2048, 24, 128),
    "layout_1_1": ModelConfig(
        2, 2048, 24, 128, attn_mask_type="causal", attn_bias_type="post_scale_bias"
    ),
    "layout_1_2": ModelConfig(1, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding"),
    "layout_1_3": ModelConfig(
        1,
        2048,
        24,
        128,
        max_seqlen_kv=4096,
        attn_mask_type="padding_causal",
        attn_bias_type="post_scale_bias",
    ),
    "layout_2_0": ModelConfig(2, 1, 16, 256, max_seqlen_kv=2048),
    "layout_2_1": ModelConfig(
        2, 2048, 24, 256, attn_mask_type="causal", attn_bias_type="post_scale_bias"
    ),
}


@pytest.mark.skipif(get_cudnn_version() < (8, 9, 5), reason="cuDNN 8.9.5+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_layout])
@pytest.mark.parametrize("model", model_configs_layout.keys())
@pytest.mark.parametrize("qkv_layout", qkv_layouts)
def test_dpa_qkv_layout(dtype, model_configs, model, qkv_layout):
    """Test DotProductAttention module with different QKV layouts"""
    test_dot_product_attention(dtype, model_configs, model, False, True, qkv_layout, False, False)


qkv_layouts_thd = ["t3hd", "th3d", "thd_t2hd", "thd_th2d", "thd_thd_thd"]
model_configs_layout_thd = {
    # test: ModelConfig(b, sq, hq, dqk)
    "layout_0_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding"),
    "layout_0_1": ModelConfig(2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding"),
    "layout_0_2": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding"),
    "layout_1_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding_causal"),
    "layout_1_1": ModelConfig(2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding_causal"),
    "layout_1_2": ModelConfig(
        2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding_causal"
    ),
    "layout_2_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding_causal_bottom_right"),
    "layout_2_1": ModelConfig(
        2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding_causal_bottom_right"
    ),
    "layout_2_2": ModelConfig(
        2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding_causal_bottom_right"
    ),
    "layout_3_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding", window_size=(4, 4)),
    "layout_3_1": ModelConfig(
        2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding", window_size=(4, 4)
    ),
    "layout_3_2": ModelConfig(
        2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding", window_size=(4, 4)
    ),
    "layout_4_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding_causal", window_size=(4, 0)),
    "layout_4_1": ModelConfig(
        2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding_causal", window_size=(4, 0)
    ),
    "layout_4_2": ModelConfig(
        2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding_causal", window_size=(4, 0)
    ),
    "layout_5_0": ModelConfig(
        2, 2048, 16, 64, attn_mask_type="padding_causal_bottom_right", window_size=(4, 0)
    ),
    "layout_5_1": ModelConfig(
        2,
        2048,
        24,
        128,
        num_gqa_groups=1,
        attn_mask_type="padding_causal_bottom_right",
        window_size=(4, 0),
    ),
    "layout_5_2": ModelConfig(
        2,
        2048,
        24,
        128,
        max_seqlen_kv=4096,
        attn_mask_type="padding_causal_bottom_right",
        window_size=(4, 0),
    ),
}


@pytest.mark.skipif(get_cudnn_version() < (9, 0, 0), reason="cuDNN 9.0.0+ is required.")
@pytest.mark.skipif(
    get_device_compute_capability() < (9, 0), reason="THD is only supported on Hopper+."
)
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_layout_thd])
@pytest.mark.parametrize("model", model_configs_layout_thd.keys())
@pytest.mark.parametrize("qkv_layout", qkv_layouts_thd)
def test_dpa_qkv_layout_thd(dtype, model_configs, model, qkv_layout):
    """Test DotProductAttention module with THD QKV layouts"""
    config = model_configs[model]
    if config.num_heads != config.num_gqa_groups and "3" in qkv_layout:
        pytest.skip("qkv_layout not applicable for MQA/GQA")
    logging.info("[test_dpa_qkv_layout_thd]: pad_between_seqs = True")
    pad_between_seqs = True
    test_dot_product_attention(
        dtype, model_configs, model, False, True, qkv_layout, False, pad_between_seqs
    )
    if get_cudnn_version() >= (9, 3, 0):
        logging.info("[test_dpa_qkv_layout_thd]: pad_between_seqs = False")
        # cuDNN 9.3.0+ is required to run pad_between_seqs = False/True in the same run
        pad_between_seqs = False
        test_dot_product_attention(
            dtype, model_configs, model, False, True, qkv_layout, False, pad_between_seqs
        )


def _run_dot_product_attention(
    dtype: torch.dtype,
    config: ModelConfig,
    backend: str,
    ckpt_attn: bool,
    qkv_layout: str,
    workspace_opt: bool,
    pad_between_seqs: bool,
    is_training: bool,
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
    """Run DotProductAttention module with one forward pass and one backward pass"""
    # Set RNG and environment variables
    reset_rng_states()
    os.environ["NVTE_FLASH_ATTN"] = "0"
    os.environ["NVTE_FUSED_ATTN"] = "0"
    if backend == "FlashAttention":
        os.environ["NVTE_FLASH_ATTN"] = "1"
    if backend == "FusedAttention":
        os.environ["NVTE_FUSED_ATTN"] = "1"
        os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] = "1" if workspace_opt else "0"
    _attention_backends["backend_selection_requires_update"] = True
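    # Forcing a backend via NVTE_FLASH_ATTN / NVTE_FUSED_ATTN only takes effect
    # once the cached backend selection is invalidated, which the flag above
    # requests.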

    # Create seqlens
    qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()])
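    # e.g. qkv_layout "bshd_bs2hd" -> first group "bshd" -> keep letters only
    # -> qkv_format "bshd"; a packed layout like "sb3hd" likewise gives "sbhd".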
    if "padding" in config.attn_mask_type or qkv_format == "thd":
        if config.attn_type == "self":
            seqlens_q = torch.randint(
                1, config.max_seqlen_q, [config.batch_size], dtype=torch.int32, device="cuda"
            )
            seqlens_kv = seqlens_q
        if config.attn_type == "cross":
            if config.max_seqlen_q > 1:
                seqlens_q = torch.randint(
                    1, config.max_seqlen_q, [config.batch_size], dtype=torch.int32, device="cuda"
                )
            else:
                seqlens_q = torch.ones([config.batch_size], dtype=torch.int32, device="cuda")
            seqlens_kv = torch.randint(
                1, config.max_seqlen_kv, [config.batch_size], dtype=torch.int32, device="cuda"
            )
    else:
        seqlens_q = torch.full(
            [config.batch_size], config.max_seqlen_q, dtype=torch.int32, device="cuda"
        )
        seqlens_kv = torch.full(
            [config.batch_size], config.max_seqlen_kv, dtype=torch.int32, device="cuda"
        )
    cu_seqlens_q = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda")
    cu_seqlens_kv = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda")
    cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)
    cu_seqlens_kv[1:] = torch.cumsum(seqlens_kv, dim=0)
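    # Worked example: seqlens_q = [3, 5, 2] yields cu_seqlens_q = [0, 3, 8, 10],
    # an exclusive prefix sum whose consecutive entries delimit each sequence
    # inside a packed (thd) tensor.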

    seqlens_q_after_pad = seqlens_q.clone()
    seqlens_kv_after_pad = seqlens_kv.clone()
    cu_seqlens_q_after_pad = cu_seqlens_q.clone()
    cu_seqlens_kv_after_pad = cu_seqlens_kv.clone()
    pad_len = [0] * config.batch_size
    if pad_between_seqs:
        max_pad_len = 3
        pad_len = torch.randint(0, max_pad_len + 1, [config.batch_size], device="cuda")
        seqlens_q_after_pad = seqlens_q + pad_len
        seqlens_kv_after_pad = seqlens_kv + pad_len
        cu_seqlens_q_after_pad[1:] = torch.cumsum(seqlens_q_after_pad, dim=0)
        cu_seqlens_kv_after_pad[1:] = torch.cumsum(seqlens_kv_after_pad, dim=0)
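        # Hedged example: with seqlens_q = [3, 5] and pad_len = [1, 2], the
        # padded offsets become cu_seqlens_q_after_pad = [0, 4, 11]; sequence
        # i's real tokens occupy [cu[i], cu[i + 1] - pad_len[i]) and the rest
        # is padding.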

    # Create attention mask if padding
    attention_mask = None
    if "padding" in config.attn_mask_type:
        if config.attn_type == "self":
            attention_mask_q = torch.Tensor([]).to(dtype=torch.bool)
            for i in range(config.batch_size):
                attention_mask_q = torch.cat(
                    [
                        attention_mask_q,
                        torch.Tensor(
                            [False] * seqlens_q[i] + [True] * (config.max_seqlen_q - seqlens_q[i])
                        )
                        .to(dtype=torch.bool)
                        .unsqueeze(0)
                        .unsqueeze(0)
                        .unsqueeze(0),
                    ],
                    dim=0,
                )
            attention_mask = attention_mask_q.to(device="cuda")
        if config.attn_type == "cross":
            attention_mask_q = torch.Tensor([]).to(dtype=torch.bool)
            attention_mask_kv = torch.Tensor([]).to(dtype=torch.bool)
            for i in range(config.batch_size):
                attention_mask_q = torch.cat(
                    [
                        attention_mask_q,
                        torch.Tensor(
                            [False] * seqlens_q[i] + [True] * (config.max_seqlen_q - seqlens_q[i])
                        )
                        .to(dtype=torch.bool)
                        .unsqueeze(0)
                        .unsqueeze(0)
                        .unsqueeze(0),
                    ],
                    dim=0,
                )
                attention_mask_kv = torch.cat(
                    [
                        attention_mask_kv,
                        torch.Tensor(
                            [False] * seqlens_kv[i]
                            + [True] * (config.max_seqlen_kv - seqlens_kv[i])
                        )
                        .to(dtype=torch.bool)
                        .unsqueeze(0)
                        .unsqueeze(0)
                        .unsqueeze(0),
                    ],
                    dim=0,
                )
            attention_mask = (
                attention_mask_q.to(device="cuda"),
                attention_mask_kv.to(device="cuda"),
            )
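
    # Mask convention in the tensors above: False marks a valid token and True
    # a padded position, i.e. the boolean attention_mask convention that
    # DotProductAttention applies for "padding" mask types.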

    alibi_slopes = None
    if config.attn_bias_type == "alibi" and config.alibi_type == "custom":
        if config.bias_shape == "1hss":
            alibi_slopes = (
                torch.randn(config.num_heads).abs().to(dtype=torch.float32, device="cuda")
            )
        if config.bias_shape == "bhss":
            alibi_slopes = (
                torch.randn(config.batch_size, config.num_heads)
                .abs()
                .to(dtype=torch.float32, device="cuda")
            )

    # Create input tensors
    dim_to_num = {
        "b": config.batch_size,
        "sq": config.max_seqlen_q,
        "skv": config.max_seqlen_kv,
        "h": config.num_heads,
        "hg": config.num_gqa_groups,
        "dqk": config.head_dim_qk,
        "dv": config.head_dim_v,
        "t": cu_seqlens_q_after_pad[-1],
        "tg": cu_seqlens_kv_after_pad[-1],
        "3": 3,
        "2": 2,
        "1": 1,
    }
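    # Shape derivation example: in the loop below, the query group of "sb3hd"
    # is rewritten to "sq_b_3_h_dqk", so the allocated tensor has shape
    # [max_seqlen_q, batch_size, 3, num_heads, head_dim_qk] before Q, K, V are
    # split out along the size-3 dimension.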
    inp = []
    inp_orig = []
    for i, layout in enumerate(qkv_layout.split("_")):
        layout = "_".join(layout)
        if i == 0:
            layout = layout.replace("s", "sq")
        else:
            layout = layout.replace("s", "skv")
            layout = layout.replace("h", "hg")
            layout = layout.replace("t", "tg")
        if i == 2:
            layout = layout.replace("d", "dv")
        else:
            layout = layout.replace("d", "dqk")
        tensor_shape = [dim_to_num[j] for j in layout.split("_")]
        tensor = 0.1 * torch.randn(tensor_shape, dtype=dtype, device="cuda")
        tensor_orig = tensor
        if qkv_format == "thd" and pad_between_seqs:
            tensor_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
            if layout in ["t_h_dqk", "t_3_h_dqk", "t_h_3_dqk"]:
                for i in range(1, config.batch_size + 1):
                    valid_range = (
                        cu_seqlens_q_after_pad[i - 1],
                        cu_seqlens_q_after_pad[i] - pad_len[i - 1],
                    )
                    pad_range = (
                        cu_seqlens_q_after_pad[i] - pad_len[i - 1],
                        cu_seqlens_q_after_pad[i],
                    )
                    tensor[pad_range[0] : pad_range[1]] = 0.0
                    tensor_orig = torch.cat(
                        [tensor_orig, tensor[valid_range[0] : valid_range[1]]], dim=0
                    )
            if layout in ["tg_hg_dqk", "tg_2_hg_dqk", "tg_hg_2_dqk", "tg_hg_dv"]:
                for i in range(1, config.batch_size + 1):
                    valid_range = (
                        cu_seqlens_kv_after_pad[i - 1],
                        cu_seqlens_kv_after_pad[i] - pad_len[i - 1],
                    )
                    pad_range = (
                        cu_seqlens_kv_after_pad[i] - pad_len[i - 1],
                        cu_seqlens_kv_after_pad[i],
                    )
                    tensor[pad_range[0] : pad_range[1]] = 0.0
                    tensor_orig = torch.cat(
                        [tensor_orig, tensor[valid_range[0] : valid_range[1]]], dim=0
                    )
        tensor_count = 1
        split_dim = 0
        for dim, l in enumerate(layout.split("_")):
            if l.isdigit():
                tensor_count = int(l)
                split_dim = dim
                break
        tensors = torch.split(tensor, 1, dim=split_dim) if split_dim != 0 else [tensor]
        tensors_orig = (
            torch.split(tensor_orig, 1, dim=split_dim) if split_dim != 0 else [tensor_orig]
        )
        for j in range(tensor_count):
            if split_dim != 0:
                inp.append(tensors[j].squeeze(split_dim))
                inp_orig.append(tensors_orig[j].squeeze(split_dim))
            else:
                inp.append(tensors[j])
                inp_orig.append(tensors_orig[j])
    for i in range(3):
        inp[i].requires_grad = True
        inp_orig[i].requires_grad = True

    # Create output gradient
    qkv_format_kv = "_".join(qkv_format)
    qkv_format_kv = qkv_format_kv.replace("s", "sq")
    qkv_format_kv = qkv_format_kv.replace("d", "dv")
    out_grad_shape = [dim_to_num[i] for i in qkv_format_kv.split("_")]
    out_grad_shape_new = [*out_grad_shape[:-2], out_grad_shape[-2] * out_grad_shape[-1]]
    out_grad = 0.001 * torch.randint(0, 200, out_grad_shape_new, dtype=dtype, device="cuda")
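    # e.g. for qkv_format "bshd" the last two dims are merged: the output
    # gradient has shape [batch_size, max_seqlen_q, num_heads * head_dim_v].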
    out_grad_orig = out_grad
    if qkv_format == "thd" and pad_between_seqs:
        out_grad_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
        if qkv_format_kv == "t_h_dv":
            for i in range(1, config.batch_size + 1):
                valid_range = (
                    cu_seqlens_q_after_pad[i - 1],
                    cu_seqlens_q_after_pad[i] - pad_len[i - 1],
                )
                pad_range = (cu_seqlens_q_after_pad[i] - pad_len[i - 1], cu_seqlens_q_after_pad[i])
                out_grad[pad_range[0] : pad_range[1]] = 0.0
                out_grad_orig = torch.cat(
                    [out_grad_orig, out_grad[valid_range[0] : valid_range[1]]], dim=0
                )

    # Create bias
    if config.attn_bias_type in ["no_bias", "alibi"]:
        bias = None
    if config.attn_bias_type == "post_scale_bias":
        shape = "_".join(config.bias_shape)
        shape = shape.replace("_s_s", "_sq_skv")
        tensor_shape = [dim_to_num[j] for j in shape.split("_")]
        bias = torch.randn(tensor_shape, dtype=dtype, device="cuda")
        if config.bias_shape != "1hss":
            bias.requires_grad = False

    # Create RNG
    _DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
    _DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed)

    def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
        """Get cuda rng tracker."""
        return _DUMMY_CUDA_RNG_STATE_TRACKER

    # Set up model
    block = DotProductAttention(
        config.num_heads,
        (config.head_dim_qk, config.head_dim_v),
        num_gqa_groups=config.num_gqa_groups,
        attention_dropout=config.dropout_p,
        qkv_format=qkv_format,
        attn_mask_type=config.attn_mask_type,
        sequence_parallel=False,
        tp_size=1,
        get_rng_state_tracker=get_dummy_cuda_rng_tracker,
        tp_group=None,
        layer_number=1,
        attention_type=config.attn_type,
        softmax_type=config.softmax_type,
    ).to(dtype=dtype, device="cuda")
    if not is_training:
        block = block.eval()
    if is_training and config.softmax_type != "vanilla":
        block.softmax_offset.requires_grad = True

    # Run a forward and backward pass
    if backend in ["FlashAttention", "UnfusedDotProductAttention"]:
        q = inp_orig[0]
        k = inp_orig[1]
        v = inp_orig[2]
        d_out = out_grad_orig
    if backend == "FusedAttention":
        q = inp[0]
        k = inp[1]
        v = inp[2]
        d_out = out_grad
    out = block(
        q,
        k,
        v,
        window_size=config.window_size,
        attention_mask=attention_mask,
        qkv_format=qkv_format,
        max_seqlen_q=config.max_seqlen_q,
        max_seqlen_kv=config.max_seqlen_kv,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_kv=cu_seqlens_kv,
        cu_seqlens_q_padded=cu_seqlens_q_after_pad if backend == "FusedAttention" else None,
        cu_seqlens_kv_padded=cu_seqlens_kv_after_pad if backend == "FusedAttention" else None,
        attn_mask_type=config.attn_mask_type,
        checkpoint_core_attention=ckpt_attn,
        core_attention_bias_type=config.attn_bias_type,
        core_attention_bias=bias,
        alibi_slopes=alibi_slopes,
        fast_zero_fill=True,
    )
    if is_training:
        out.backward(d_out)
    d_softmax_offset = None
    if is_training and config.softmax_type != "vanilla":
        d_softmax_offset = block.softmax_offset.grad
    if backend in ["FlashAttention", "UnfusedDotProductAttention"]:
        if is_training:
            return out, (q.grad, k.grad, v.grad, d_softmax_offset)
        else:
            return out, (None, None, None, d_softmax_offset)
    if backend == "FusedAttention":
        if qkv_format == "thd" and pad_between_seqs:
            out_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
            if is_training:
                q_grad_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
                k_grad_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
                v_grad_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
            for i in range(1, config.batch_size + 1):
                valid_range_q = (
                    cu_seqlens_q_after_pad[i - 1],
                    cu_seqlens_q_after_pad[i] - pad_len[i - 1],
                )
                valid_range_kv = (
                    cu_seqlens_kv_after_pad[i - 1],
                    cu_seqlens_kv_after_pad[i] - pad_len[i - 1],
                )
                out_orig = torch.cat([out_orig, out[valid_range_q[0] : valid_range_q[1]]], dim=0)
                if is_training:
                    q_grad_orig = torch.cat(
                        [q_grad_orig, q.grad[valid_range_q[0] : valid_range_q[1]]], dim=0
                    )
                    k_grad_orig = torch.cat(
                        [k_grad_orig, k.grad[valid_range_kv[0] : valid_range_kv[1]]], dim=0
                    )
                    v_grad_orig = torch.cat(
                        [v_grad_orig, v.grad[valid_range_kv[0] : valid_range_kv[1]]], dim=0
                    )
            if is_training:
                return out_orig, (q_grad_orig, k_grad_orig, v_grad_orig, d_softmax_offset)
            else:
                return out_orig, (None, None, None, d_softmax_offset)
        else:
            if is_training:
                return out, (q.grad, k.grad, v.grad, d_softmax_offset)
            else:
                return out, (None, None, None, d_softmax_offset)


model_configs_te_layer = {
    # test: ModelConfig(b, sq, hq, dqk)
    "te_1_0": ModelConfig(2, 128, 16, 64, attn_bias_type="post_scale_bias"),
    "te_1_1": ModelConfig(
        4, 128, 16, 64, attn_mask_type="causal", attn_bias_type="post_scale_bias"
    ),
    "te_1_2": ModelConfig(
        2, 128, 16, 64, attn_mask_type="padding", attn_bias_type="post_scale_bias"
    ),
    "te_1_3": ModelConfig(2, 128, 16, 64, max_seqlen_kv=256, attn_mask_type="padding"),
    "te_2_0": ModelConfig(1, 2048, 16, 64, attn_mask_type="causal"),
    "te_2_1": ModelConfig(2, 2048, 16, 64),
    "te_2_2": ModelConfig(1, 2048, 16, 64, attn_mask_type="padding"),
    "te_2_3": ModelConfig(
        1, 2048, 16, 64, max_seqlen_kv=4096, attn_mask_type="padding_causal_bottom_right"
    ),
    "te_3_0": ModelConfig(4, 128, 16, 64, attn_mask_type="causal", attn_bias_type="alibi"),
    "te_3_1": ModelConfig(4, 2048, 16, 64, attn_mask_type="causal", attn_bias_type="alibi"),
}


@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model_configs", [model_configs_te_layer])
@pytest.mark.parametrize("model", model_configs_te_layer.keys())
@pytest.mark.parametrize("ckpt_attn", [False])
@pytest.mark.parametrize("qkv_format", ["sbhd", "bshd", "thd"])
@pytest.mark.parametrize("fused_qkv_params", [False])
@pytest.mark.parametrize("RoPE", [False])
def test_transformer_layer(
    dtype, model_configs, model, ckpt_attn, qkv_format, fused_qkv_params, RoPE
):
    """Test TransformerLayer module"""

    # Get configs
    config = model_configs[model]
    tols = dict(atol=5e-2, rtol=5e-2)
    workspace_opt = True

    # Test backend availability
    is_training = True
    available_backends, _, fused_attn_backends = get_available_attention_backends(
        config,
        qkv_dtype=dtype,
        qkv_layout=(
            qkv_format.replace("hd", "h3d") if fused_qkv_params else qkv_format.replace("hd", "3hd")
        ),
        is_training=is_training,
    )
    flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
    if not fused_attn_supported:
        is_training = False
        available_backends, _, fused_attn_backends = get_available_attention_backends(
            config,
            qkv_dtype=dtype,
            qkv_layout=(
                qkv_format.replace("hd", "h3d")
                if fused_qkv_params
                else qkv_format.replace("hd", "3hd")
            ),
            is_training=is_training,
        )
        flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends

    # Skip if only unfused backend is supported
    if (len(fused_attn_backends) + flash_attn_supported + unfused_attn_supported) < 2:
        pytest.skip("Less than two backends to compare.")

    # Skip if qkv_format = thd and "padding" not in attn_mask_type
    if qkv_format == "thd" and "padding" not in config.attn_mask_type:
        pytest.skip("THD requires padding mask.")
Tim Moon's avatar
Tim Moon committed
1229
1230

    # UnfusedDotProductAttention backend
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
    if unfused_attn_supported:
        unfused_attn_fwd, unfused_attn_bwd = _run_transformer_layer(
            dtype,
            config,
            "UnfusedDotProductAttention",
            ckpt_attn,
            qkv_format,
            workspace_opt,
            fused_qkv_params,
            RoPE,
1241
            is_training,
1242
        )
Tim Moon's avatar
Tim Moon committed
1243
1244
1245
1246
1247
1248
1249

    # FusedAttention backend
    if fused_attn_supported:
        fused_attn_fwd, fused_attn_bwd = _run_transformer_layer(
            dtype,
            config,
            "FusedAttention",
1250
1251
1252
            ckpt_attn,
            qkv_format,
            workspace_opt,
Tim Moon's avatar
Tim Moon committed
1253
1254
            fused_qkv_params,
            RoPE,
1255
            is_training,
Tim Moon's avatar
Tim Moon committed
1256
        )
1257

Tim Moon's avatar
Tim Moon committed
1258
1259
1260
1261
1262
1263
    # FlashAttention backend
    if flash_attn_supported:
        flash_attn_fwd, flash_attn_bwd = _run_transformer_layer(
            dtype,
            config,
            "FlashAttention",
1264
1265
1266
            ckpt_attn,
            qkv_format,
            workspace_opt,
Tim Moon's avatar
Tim Moon committed
1267
1268
            fused_qkv_params,
            RoPE,
1269
            is_training,
Tim Moon's avatar
Tim Moon committed
1270
        )
1271

1272
    logging.info(f"[test_transformer_layer]: is_training = {is_training}")
1273
    if unfused_attn_supported and fused_attn_supported:
1274
        logging.info("[test_transformer_layer]: unfused attn vs fused attn")
1275
1276
1277
        torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, **tols)
        torch.testing.assert_close(fused_attn_bwd, unfused_attn_bwd, **tols)
    if unfused_attn_supported and flash_attn_supported:
1278
        logging.info("[test_transformer_layer]: unfused attn vs flash attn")
Tim Moon's avatar
Tim Moon committed
1279
1280
        torch.testing.assert_close(flash_attn_fwd, unfused_attn_fwd, **tols)
        torch.testing.assert_close(flash_attn_bwd, unfused_attn_bwd, **tols)
1281
    if fused_attn_supported and flash_attn_supported:
1282
        logging.info("[test_transformer_layer]: fused attn vs flash attn")
1283
1284
        torch.testing.assert_close(fused_attn_fwd, flash_attn_fwd, **tols)
        torch.testing.assert_close(fused_attn_bwd, flash_attn_bwd, **tols)
1285

1286

1287
@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
1288
@pytest.mark.parametrize("dtype", param_types_lean)
1289
1290
@pytest.mark.parametrize("model_configs", [model_configs_te_layer])
@pytest.mark.parametrize("model", ["te_1_2", "te_2_0"])
1291
1292
@pytest.mark.parametrize("qkv_format", ["bshd", "sbhd"])
def test_te_layer_misc(dtype, model_configs, model, qkv_format):
hugo-syn's avatar
hugo-syn committed
1293
    """Test TransformerLayer module with miscellaneous settings"""
1294
1295
1296
    ckpt_attn = True
    fused_qkv_params = True
    RoPE = True
1297
1298
1299
    test_transformer_layer(
        dtype, model_configs, model, ckpt_attn, qkv_format, fused_qkv_params, RoPE
    )
1300

1301

1302
@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
1303
1304
1305
1306
1307
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_te_layer])
@pytest.mark.parametrize("model", ["te_2_0", "te_2_1", "te_2_2"])
def test_te_layer_mqa_gqa(dtype, model_configs, model):
    """Test TransformerLayer module with MQA/GQA"""
1308

1309
    def find_factors(x):
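        """Return every factor of x that is >= 2, including x itself."""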
        f = []
        for i in range(2, x + 1):
            if x % i == 0:
                f.append(i)
        return f

    ckpt_attn = True
    qkv_format = "bshd"
    fused_qkv_params = True
    RoPE = True
    config = model_configs[model]
    num_queries_per_gqa_group = find_factors(config.num_heads)

    for num_q_per_gqa_group in num_queries_per_gqa_group:
        config.num_gqa_groups = config.num_heads // num_q_per_gqa_group
        test_transformer_layer(
            dtype, model_configs, model, ckpt_attn, qkv_format, fused_qkv_params, RoPE
        )


def _run_transformer_layer(
    dtype: torch.dtype,
    config: ModelConfig,
    backend: str,
    ckpt_attn: bool,
    qkv_format: str,
    workspace_opt: bool,
    fused_qkv_params: bool,
    RoPE: bool,
    is_training: bool,
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
    """Run TransformerLayer module with one forward pass and one backward pass"""

    # Set RNG and environment variables
    reset_rng_states()
    os.environ["NVTE_FLASH_ATTN"] = "0"
    os.environ["NVTE_FUSED_ATTN"] = "0"
    if backend == "FlashAttention":
        os.environ["NVTE_FLASH_ATTN"] = "1"
    if backend == "FusedAttention":
        os.environ["NVTE_FUSED_ATTN"] = "1"
    _attention_backends["backend_selection_requires_update"] = True

    # Create input tensor
    if qkv_format == "sbhd":
        inp = torch.randn(
            config.max_seqlen_q,
            config.batch_size,
            config.hidden_size,
            dtype=dtype,
            device="cuda",
            requires_grad=True,
        )
        inp_enc = torch.randn(
            config.max_seqlen_kv,
            config.batch_size,
            config.hidden_size,
            dtype=dtype,
            device="cuda",
            requires_grad=True,
        )
    if qkv_format == "bshd":
        inp = torch.randn(
            config.batch_size,
            config.max_seqlen_q,
            config.hidden_size,
            dtype=dtype,
            device="cuda",
            requires_grad=True,
        )
        inp_enc = torch.randn(
            config.batch_size,
            config.max_seqlen_kv,
            config.hidden_size,
            dtype=dtype,
            device="cuda",
            requires_grad=True,
        )

    # Create seqlens
    if "padding" in config.attn_mask_type or qkv_format == "thd":
        if config.attn_type == "self":
            seqlens_q = torch.randint(
                1, config.max_seqlen_q, [config.batch_size], dtype=torch.int32, device="cuda"
            )
            seqlens_kv = seqlens_q
        if config.attn_type == "cross":
            if config.max_seqlen_q > 1:
                seqlens_q = torch.randint(
                    1, config.max_seqlen_q, [config.batch_size], dtype=torch.int32, device="cuda"
                )
            else:
                seqlens_q = torch.ones([config.batch_size], dtype=torch.int32, device="cuda")
            seqlens_kv = torch.randint(
                1, config.max_seqlen_kv, [config.batch_size], dtype=torch.int32, device="cuda"
            )
    else:
        seqlens_q = torch.full(
            [config.batch_size], config.max_seqlen_q, dtype=torch.int32, device="cuda"
        )
        seqlens_kv = torch.full(
            [config.batch_size], config.max_seqlen_kv, dtype=torch.int32, device="cuda"
        )
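    # cu_seqlens hold cumulative sequence lengths: entry i is the token offset of batch i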
    cu_seqlens_q = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda")
    cu_seqlens_kv = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda")
    cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)
    cu_seqlens_kv[1:] = torch.cumsum(seqlens_kv, dim=0)
    if qkv_format == "thd":
        inp = torch.randn(
            cu_seqlens_q[-1],
            config.hidden_size,
            dtype=dtype,
            device="cuda",
            requires_grad=True,
        )
        inp_enc = torch.randn(
            cu_seqlens_kv[-1],
            config.hidden_size,
            dtype=dtype,
            device="cuda",
            requires_grad=True,
        )

    sigma = 0.02
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    layer_number = 1
    drop_path_rate = 0.0
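    # per-layer drop-path rates, linearly spaced from 0 to drop_path_rate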
    drop_path_rates = [rate.item() for rate in torch.linspace(0, drop_path_rate, config.num_layers)]

    # Create bias
    bias = None
    if config.attn_bias_type == "post_scale_bias":
        bias = torch.randn(
            1,
            config.num_heads,
            config.max_seqlen_q,
            config.max_seqlen_kv,
            dtype=dtype,
            device="cuda",
        )

    # Create RoPE
    rotary_pos_emb = None
    if RoPE:
        PE = RotaryPositionEmbedding(dim=config.head_dim_qk)
        rotary_pos_emb = PE(config.max_seqlen_q).to(device="cuda")

    # Set up model
    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_heads,
        num_gqa_groups=config.num_gqa_groups,
        layernorm_epsilon=1e-5,
        hidden_dropout=0.0,
        attention_dropout=config.dropout_p,
        init_method=init_method,
        output_layer_init_method=output_layer_init_method,
        layer_number=layer_number,
        kv_channels=config.head_dim_qk,
        self_attn_mask_type=config.attn_mask_type,
        tp_group=None,
        tp_size=1,
        params_dtype=dtype,
        get_rng_state_tracker=None,
        fuse_wgrad_accumulation=False,
        seq_length=config.max_seqlen_q,
        micro_batch_size=config.batch_size,
        sequence_parallel=False,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
        layer_type="encoder" if config.attn_type == "self" else "decoder",
        drop_path_rate=drop_path_rates[layer_number - 1],
        set_parallel_mode=True,
        fuse_qkv_params=fused_qkv_params,
        zero_centered_gamma=False,
        qkv_weight_interleaved=False,
        ub_tp_comm_overlap=False,
        bias=True,
        attn_input_format=qkv_format,
    ).to(dtype=dtype, device="cuda")
    if not is_training:
        block = block.eval()

    # Create ALiBi slopes
    alibi_slopes = None
    if config.attn_bias_type == "alibi" and config.alibi_type == "custom":
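        # random positive per-head slopes for the custom ALiBi variant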
        alibi_slopes = torch.randn(config.num_heads).abs().to(dtype=torch.float32, device="cuda")

    # Run a forward and backward pass
    out = block(
        inp,
        self_attn_mask_type=config.attn_mask_type,
        encoder_output=inp_enc if config.attn_type == "cross" else None,
        enc_dec_attn_mask_type=config.attn_mask_type if config.attn_type == "cross" else None,
        checkpoint_core_attention=False,
        rotary_pos_emb=rotary_pos_emb,
        core_attention_bias_type=config.attn_bias_type,
        core_attention_bias=bias,
        alibi_slopes=alibi_slopes,
        max_seqlen_q=config.max_seqlen_q,
        max_seqlen_kv=config.max_seqlen_kv,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_kv=cu_seqlens_kv,
    )
    if is_training:
        loss = out.sum()
        loss.backward()

    return out, inp.grad


model_configs_fp8_extra_state = {
    # test: ModelConfig(b, sq, hq, dqk)
    "large": ModelConfig(2, 128, 4, 128, num_layers=1),
}


@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper.")
@pytest.mark.skipif(get_cudnn_version() < (9, 3, 0), reason="cuDNN 9.3.0+ is required.")
@pytest.mark.parametrize("model", ["large"])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_dpa_fp8_extra_state(model, dtype):
    """Test DotProductAttention module in FP8 with checkpointing"""
    config = model_configs_fp8_extra_state[model]
    # Test backend availability
    is_training = True
    available_backends, _, fused_attn_backends = get_available_attention_backends(
        config,
        qkv_dtype=torch.float8_e4m3fn,
        qkv_layout="sb3hd",
        is_training=is_training,
    )
    flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
    if not fused_attn_supported and not flash_attn_supported:
        pytest.skip("No attention backend available.")

    outputs = _run_dpa_fp8_extra_state(dtype, config, checkpoint=False)
    outputs_checkpoint = _run_dpa_fp8_extra_state(dtype, config, checkpoint=True)
    outputs_checkpoint_v1_6 = _run_dpa_fp8_extra_state(
        dtype, config, mimic_v1_6=True, checkpoint=True
    )

    # Check that results match
    tols = dtype_tols(dtype)
    if dtype in (torch.float16, torch.bfloat16):
        tols.update(dict(rtol=2e-2, atol=2e-3))
    for i, (ref, test) in enumerate(zip(outputs, outputs_checkpoint)):
        torch.testing.assert_close(
            test,
            ref,
            **tols,
        )
    for i, (ref, test) in enumerate(zip(outputs, outputs_checkpoint_v1_6)):
        torch.testing.assert_close(
            test,
            ref,
            **tols,
        )


def _run_dpa_fp8_extra_state(dtype, config, checkpoint=False, mimic_v1_6=False):
    """Run DotProductAttention module in FP8 with checkpointing"""
    steps = 10
    path = "checkpoint.pt"
    fp8_enabled = True
    fp8_recipe = recipe.DelayedScaling(
        margin=0,
        fp8_format=recipe.Format.HYBRID,
        amax_history_len=1,
        amax_compute_algo="most_recent",
        fp8_dpa=fp8_enabled,
        fp8_mha=False,
    )

    reset_rng_states()
    hidden_states = torch.randn(
        (config.max_seqlen_q, config.batch_size, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )

    def get_model(dtype, config):
        sigma = 0.023
        init_method = init_method_normal(sigma)
        output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

        with quantized_model_init(enabled=fp8_enabled, recipe=fp8_recipe):
            block = TransformerLayer(
                config.hidden_size,
                4 * config.hidden_size,
                config.num_heads,
                init_method=init_method,
                output_layer_init_method=output_layer_init_method,
                hidden_dropout=0.0,
                attention_dropout=0.0,
                fuse_qkv_params=True,
                params_dtype=dtype,
                device="cuda",
            )
        return block

    block = get_model(dtype, config)
    for i in range(steps // 2):
        with autocast(enabled=fp8_enabled, recipe=fp8_recipe):
            output = block(hidden_states, None)
            loss = output.sum()
            loss.backward()

    if checkpoint:
        sd = block.state_dict()
        if mimic_v1_6:
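            # older (v1.6-era) checkpoints stored the core-attention extra state
            # under a "fused_attention" submodule; emulate that key layout here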
            sd["self_attention.core_attention.fused_attention._extra_state"] = sd[
                "self_attention.core_attention._extra_state"
            ]
            del sd["self_attention.core_attention._extra_state"]
        torch.save(sd, path)

        param_grads = []
        for p in block.parameters():
            if p.requires_grad:
                param_grads.append(p.grad.clone())

        _cpu_rng_state_new = torch.get_rng_state()
        _cuda_rng_state_new = torch.cuda.get_rng_state()

        del block
        block = get_model(dtype, config)
        block.load_state_dict(torch.load(path, weights_only=False))
        torch.set_rng_state(_cpu_rng_state_new)
        torch.cuda.set_rng_state(_cuda_rng_state_new)

        for p in block.parameters():
            if p.requires_grad:
                p.grad = param_grads.pop(0)

        assert not param_grads, "Parameter count mismatch after checkpoint reload."

    for i in range((steps + 1) // 2):
        with autocast(enabled=fp8_enabled, recipe=fp8_recipe):
            output = block(hidden_states, None)
            loss = output.sum()
            loss.backward()

    torch.cuda.synchronize()

    if os.path.exists(path):
        os.remove(path)

    outputs = [output, hidden_states.grad]
    for p in block.parameters():
        if p.requires_grad:
            outputs.append(p.grad)

    return outputs


model_configs_fp8_vs_f16 = {
    # test: ModelConfig(b, sq, hq, dqk)
    "fp8_9": ModelConfig(2, 2048, 16, 128),
    "fp8_10": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12),
    "fp8_11": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4),
    "fp8_12": ModelConfig(2, 2048, 16, 128, attn_mask_type="causal"),
    "fp8_13": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12, attn_mask_type="causal"),
    "fp8_14": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="causal"),
    "fp8_15": ModelConfig(2, 2048, 16, 128, attn_mask_type="padding"),
    "fp8_16": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12, attn_mask_type="padding"),
    "fp8_17": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="padding"),
    "fp8_18": ModelConfig(2, 2048, 16, 128, attn_mask_type="padding_causal"),
    "fp8_19": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12, attn_mask_type="padding_causal"),
    "fp8_20": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="padding_causal"),
}

param_types_fp8_vs_f16 = [torch.float16, torch.bfloat16]
qkv_layout_fp8_vs_f16 = ["sbh3d", "bshd_bshd_bshd", "sbhd_sbhd_sbhd"]
qkv_format_fp8_vs_f16 = ["bshd", "sbhd"]


@pytest.mark.skipif(get_cudnn_version() < (9, 2, 1), reason="cuDNN 9.2.1+ is required.")
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper+.")
@pytest.mark.parametrize("dtype", param_types_fp8_vs_f16)
@pytest.mark.parametrize("model", model_configs_fp8_vs_f16.keys())
@pytest.mark.parametrize("qkv_format", qkv_format_fp8_vs_f16)
@pytest.mark.parametrize("input_layernorm", [True, False])
@pytest.mark.parametrize("fp8_dpa_bwd", [True, False])
@pytest.mark.parametrize("RoPE", [True, False])
@pytest.mark.parametrize("is_training", [True, False])
@pytest.mark.parametrize("scaling_mode", ["delayed", "current"])
@pytest.mark.parametrize("scaling_mode", ["delayed", "current"])
def test_mha_fp8_vs_f16(
    dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd, RoPE, is_training, scaling_mode
):
    """Test MultiHeadAttention module in FP8"""
    os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1"
    os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0"
    config = model_configs_fp8_vs_f16[model]

    # Test backend availability
    if scaling_mode == "delayed":
        fp8_recipe = recipe.DelayedScaling(
            margin=0,
            fp8_format=recipe.Format.HYBRID,
            amax_history_len=1,
            amax_compute_algo="most_recent",
            fp8_dpa=True,
            fp8_mha=True,
        )
    elif scaling_mode == "current":
        fp8_recipe = recipe.Float8CurrentScaling(
            fp8_format=recipe.Format.HYBRID,
            fp8_dpa=True,
            fp8_mha=True,
        )
    fp8_meta = {}
    fp8_meta["recipe"] = fp8_recipe
    available_backends, _, fused_attn_backends = get_available_attention_backends(
        config,
        qkv_dtype=torch.float8_e4m3fn,
        qkv_layout=qkv_format.replace("hd", "h3d"),
        fp8=True,
        fp8_meta=fp8_meta,
        is_training=is_training,
    )
    flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
    if flash_attn_supported + fused_attn_supported < 1:
        pytest.skip("No FP8 attention backend available.")
    if not fp8_dpa_bwd:
        available_backends, _, fused_attn_backends = get_available_attention_backends(
            config,
            qkv_dtype=dtype,
            qkv_layout=qkv_format.replace("hd", "h3d"),
            is_training=is_training,
        )
        _, fused_attn_supported, _ = available_backends
        if not fused_attn_supported:
            pytest.skip("No attention backend available.")

    if flash_attn_supported:
        os.environ["NVTE_FLASH_ATTN"] = "1"
        os.environ["NVTE_FUSED_ATTN"] = "0"
        _attention_backends["backend_selection_requires_update"] = True
        logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = True (FlashAttention)")
        flash_attn_fwd_fp8, param_names, flash_attn_bwd_fp8 = _run_mha_fp8_vs_f16(
            dtype, config, True, qkv_format, input_layernorm, RoPE, is_training, fp8_recipe
        )

    os.environ["NVTE_FLASH_ATTN"] = "0"
    os.environ["NVTE_FUSED_ATTN"] = "1"
    _attention_backends["backend_selection_requires_update"] = True
    logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = True (FusedAttention)")
    fused_attn_fwd_fp8, param_names, fused_attn_bwd_fp8 = _run_mha_fp8_vs_f16(
        dtype, config, True, qkv_format, input_layernorm, RoPE, is_training, fp8_recipe
    )

    logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = False")
    fused_attn_fwd_f16, param_names, fused_attn_bwd_f16 = _run_mha_fp8_vs_f16(
        dtype, config, False, qkv_format, input_layernorm, RoPE, is_training, fp8_recipe
    )

    atol = 5e-1
    rtol = 5e-1
    rmse_tol = 0.15
    if flash_attn_supported:
        logging.debug("========== {:^25s} ==========".format("flash fp8 vs fused f16:"))
        logging.debug("========== {:^25s} ==========".format("forward output"))
        compare_and_assert(
            flash_attn_fwd_fp8,
            fused_attn_fwd_f16,
            "flash_attn_fwd_fp8",
            "fused_attn_fwd_f16",
            atol,
            rtol,
            rmse_tol,
            True,
        )
    logging.debug("========== {:^25s} ==========".format("fused fp8 vs fused f16:"))
    logging.debug("========== {:^25s} ==========".format("forward output"))
    compare_and_assert(
        fused_attn_fwd_fp8,
        fused_attn_fwd_f16,
        "fused_attn_fwd_fp8",
        "fused_attn_fwd_f16",
        atol,
        rtol,
        rmse_tol,
        True,
    )

    if is_training:
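        # only the first entry (hidden_states.grad) is compared here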
        for i in range(len(param_names[:1])):
            logging.debug("========== {:^25s} ==========".format(param_names[i]))
            compare_and_assert(
                fused_attn_bwd_fp8[i],
                fused_attn_bwd_f16[i],
                f"fused_attn_bwd_fp8[{i}]",
                f"fused_attn_bwd_f16[{i}]",
                atol,
                rtol,
                rmse_tol,
                True,
            )


def _run_mha_fp8_vs_f16(
    dtype, config, fp8_mha, qkv_format, input_layernorm, RoPE, is_training, fp8_recipe
):
    """Run MultiHeadAttention module in FP8"""
    reset_rng_states()
    _DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
    _DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed)

    def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
        """Get cuda rng tracker."""
        return _DUMMY_CUDA_RNG_STATE_TRACKER

    with quantized_model_init(enabled=fp8_mha, recipe=fp8_recipe):
        rotary_pos_emb = None
        if RoPE:
            PE = RotaryPositionEmbedding(dim=config.head_dim_qk)
            rotary_pos_emb = PE(config.max_seqlen_q).to(device="cuda")
        mha = MultiheadAttention(
            hidden_size=config.hidden_size,
            num_attention_heads=config.num_heads,
            kv_channels=config.head_dim_qk,
            num_gqa_groups=config.num_gqa_groups,
            attention_dropout=config.dropout_p,
            layer_number=1,
            bias=True,
            get_rng_state_tracker=get_dummy_cuda_rng_tracker,
            params_dtype=dtype,
            input_layernorm=input_layernorm,
            fuse_qkv_params=True,
            attention_type="self",
            qkv_weight_interleaved=True,
            qkv_format=qkv_format,
        ).to(dtype=dtype, device="cuda")
        if not is_training:
            mha = mha.eval()
    if "padding" in config.attn_mask_type or qkv_format == "thd":
        if config.attn_type == "self":
            seqlens_q = torch.randint(
                1, config.max_seqlen_q, [config.batch_size], dtype=torch.int32, device="cuda"
            )
            seqlens_kv = seqlens_q
        if config.attn_type == "cross":
            seqlens_q = torch.randint(
                1, config.max_seqlen_q, [config.batch_size], dtype=torch.int32, device="cuda"
            )
            seqlens_kv = torch.randint(
                1, config.max_seqlen_kv, [config.batch_size], dtype=torch.int32, device="cuda"
            )
    else:
        seqlens_q = torch.full(
            [config.batch_size], config.max_seqlen_q, dtype=torch.int32, device="cuda"
        )
        seqlens_kv = torch.full(
            [config.batch_size], config.max_seqlen_kv, dtype=torch.int32, device="cuda"
        )
    cu_seqlens_q = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda")
    cu_seqlens_kv = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda")
    cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)
    cu_seqlens_kv[1:] = torch.cumsum(seqlens_kv, dim=0)
    dim_to_num = {
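        # map each layout letter to its runtime dimension size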
        "b": config.batch_size,
        "sq": config.max_seqlen_q,
        "skv": config.max_seqlen_kv,
        "h": config.num_heads,
        "hg": config.num_gqa_groups,
        "d": config.head_dim_qk,
        "t": cu_seqlens_q[-1],
        "tg": cu_seqlens_kv[-1],
        "3": 3,
        "2": 2,
        "1": 1,
    }
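    # e.g. "bshd" -> "b_sq_h_d", so each letter can be looked up in dim_to_num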
    layout = "_".join(qkv_format)
    layout = layout.replace("s", "sq")
    tensor_shape = [dim_to_num[j] for j in layout.split("_")]
    tensor = 0.01 * torch.randint(-100, 100, tensor_shape, dtype=dtype, device="cuda")
    hidden_states = tensor.view(*tensor.shape[:-2], -1)
    if is_training:
        hidden_states.requires_grad = True
    tensor = 0.01 * torch.randn(tensor_shape, dtype=dtype, device="cuda")
    out_grad = tensor.view(*tensor.shape[:-2], -1)

    with autocast(enabled=fp8_mha, recipe=fp8_recipe):
        out = mha(
            hidden_states,
            attn_mask_type=config.attn_mask_type,
            checkpoint_core_attention=False,
            core_attention_bias_type=config.attn_bias_type,
            is_first_microbatch=None,
            rotary_pos_emb=rotary_pos_emb,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_kv=cu_seqlens_kv,
        )
    if is_training:
        out.backward(out_grad)
    param_names = []
    param_names.append("hidden_states.grad")
    params = []
    params.append(hidden_states)
    for name, param in mha.named_parameters():
        if param.requires_grad:
            param_names.append(name + ".grad")
            params.append(param)
    if is_training:
        return out, param_names, tuple(x.grad for x in params)
    return out, param_names, tuple(None for x in params)


@pytest.mark.skipif(get_cudnn_version() < (9, 2, 1), reason="cuDNN 9.2.1+ is required.")
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper+.")
@pytest.mark.parametrize("dtype", param_types_fp8_vs_f16)
@pytest.mark.parametrize("model", model_configs_fp8_vs_f16.keys())
@pytest.mark.parametrize("qkv_layout", qkv_layout_fp8_vs_f16)
@pytest.mark.parametrize("fp8_dpa_bwd", [True, False])
@pytest.mark.parametrize("is_training", [True, False])
@pytest.mark.parametrize("scaling_mode", ["delayed", "current"])
def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scaling_mode):
    """Test DotProductAttention module in FP8"""
    config = model_configs_fp8_vs_f16[model]

    # TODO(cyang): think of another way to verify dropout results
    # test cuDNN FP8 dropout
    # 1. we modify the config here to not affect mha_fp8_vs_f16 tests
    # 2. there is no other backend that implements dropout the same way as cuDNN FP8, and as an
    #    indirect verification method, we create Q/K/V as all 1s and check if O is all 1s
    # 3. we avoid running FP16/BF16 kernels as they do not have dropout support on Blackwell
    # if "padding" not in config.attn_mask_type and "causal" not in config.attn_mask_type:
    #    if get_device_compute_capability() >= (10, 0):
    #        config.dropout_p = 0.1

    os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0"
    os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1"
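    # assumption: this lets the unfused backend emulate FP8 quantization so it can
    # serve as an extra comparison point below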
    os.environ["NVTE_UnfusedDPA_Emulate_FP8"] = "1"

    # Test backend availability
    if scaling_mode == "delayed":
        fp8_recipe = recipe.DelayedScaling(
            margin=0,
            fp8_format=recipe.Format.HYBRID,
            amax_history_len=1,
            amax_compute_algo="most_recent",
            fp8_dpa=True,
        )
    elif scaling_mode == "current":
        fp8_recipe = recipe.Float8CurrentScaling(
            fp8_format=recipe.Format.HYBRID,
            fp8_dpa=True,
        )
    fp8_meta = {}
    fp8_meta["recipe"] = fp8_recipe
    available_backends, _, fused_attn_backends = get_available_attention_backends(
        config,
        qkv_dtype=torch.float8_e4m3fn,
        qkv_layout=qkv_layout,
        fp8=True,
        fp8_meta=fp8_meta,
        is_training=is_training,
    )
    flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
    if flash_attn_supported + fused_attn_supported < 1:
        pytest.skip("No FP8 attention backend available.")
    if not fp8_dpa_bwd:
        available_backends, _, fused_attn_backends = get_available_attention_backends(
            config,
            qkv_dtype=dtype,
            qkv_layout=qkv_layout,
            is_training=is_training,
        )
        _, fused_attn_supported, _ = available_backends
        if not fused_attn_supported:
            pytest.skip("No attention backend available.")
    if config.num_heads != config.num_gqa_groups and "3" in qkv_layout:
        pytest.skip("qkv_layout not applicable for MQA/GQA")

    if flash_attn_supported:
        os.environ["NVTE_FLASH_ATTN"] = "1"
        os.environ["NVTE_FUSED_ATTN"] = "0"
        _attention_backends["backend_selection_requires_update"] = True
        logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True (FlashAttention)")
        flash_attn_fwd_fp8, flash_attn_bwd_fp8 = _run_dpa_fp8_vs_f16(
            dtype, config, True, qkv_layout, is_training, fp8_recipe
        )

    if unfused_attn_supported:
        os.environ["NVTE_FLASH_ATTN"] = "0"
        os.environ["NVTE_FUSED_ATTN"] = "0"
        _attention_backends["backend_selection_requires_update"] = True
        logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True (UnfusedDotProductAttention)")
        unfused_attn_fwd_fp8, unfused_attn_bwd_fp8 = _run_dpa_fp8_vs_f16(
            dtype, config, True, qkv_layout, is_training, fp8_recipe
        )

    os.environ["NVTE_FLASH_ATTN"] = "0"
    os.environ["NVTE_FUSED_ATTN"] = "1"
    _attention_backends["backend_selection_requires_update"] = True
    logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True (FusedAttention)")
    fused_attn_fwd_fp8, fused_attn_bwd_fp8 = _run_dpa_fp8_vs_f16(
        dtype, config, True, qkv_layout, is_training, fp8_recipe
    )
    os.environ["NVTE_FLASH_ATTN"] = "0"
    os.environ["NVTE_FUSED_ATTN"] = "1"
    if config.dropout_p == 0.0:
        # test cuDNN FP8 dropout: need a FP16/BF16 reference on Blackwell
        logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = False (FusedAttention)")
        fused_attn_fwd_f16, fused_attn_bwd_f16 = _run_dpa_fp8_vs_f16(
            dtype, config, False, qkv_layout, is_training, fp8_recipe
        )
    atol = 5e-1
    rtol = 5e-2
    rmse_tol = 0.11
    bwd_names = ["dq", "dk", "dv"]
    if flash_attn_supported:
        logging.debug("========== {:^25s} ==========".format("flash fp8 vs fused f16:"))
        logging.debug("========== {:^25s} ==========".format("forward output"))
        compare_and_assert(
            flash_attn_fwd_fp8,
            fused_attn_fwd_f16,
            "flash_attn_fwd_fp8",
            "fused_attn_fwd_f16",
            atol,
            rtol,
            rmse_tol,
            True,
        )

    if unfused_attn_supported:
        logging.debug("========== {:^25s} ==========".format("unfused fp8 vs fused f16:"))
        logging.debug("========== {:^25s} ==========".format("forward output"))
        compare_and_assert(
            unfused_attn_fwd_fp8,
            fused_attn_fwd_f16,
            "unfused_attn_fwd_fp8",
            "fused_attn_fwd_f16",
            atol,
            rtol,
            rmse_tol,
            True,
        )
        if is_training:
            for i, _ in enumerate(fused_attn_bwd_f16):
                logging.debug("========== {:^25s} ==========".format(bwd_names[i]))
                compare_and_assert(
                    unfused_attn_bwd_fp8[i],
                    fused_attn_bwd_f16[i],
                    f"unfused_attn_bwd_fp8[{i}]",
                    f"fused_attn_bwd_f16[{i}]",
                    atol,
                    rtol,
                    rmse_tol,
                    True,
                )
    if config.dropout_p != 0.0:
        # test cuDNN FP8 dropout
        assert torch.all(
            fused_attn_fwd_fp8 == 1
        ), "fused_attn_fwd_fp8 must be all 1s when Q/K/V are all 1s."
    else:
        logging.debug("========== {:^25s} ==========".format("fused fp8 vs fused f16:"))
        logging.debug("========== {:^25s} ==========".format("forward output"))
        compare_and_assert(
            fused_attn_fwd_fp8,
            fused_attn_fwd_f16,
            "fused_attn_fwd_fp8",
            "fused_attn_fwd_f16",
            atol,
            rtol,
            rmse_tol,
            True,
        )
        if is_training:
            for i, _ in enumerate(fused_attn_bwd_f16):
                logging.debug("========== {:^25s} ==========".format(bwd_names[i]))
                compare_and_assert(
                    fused_attn_bwd_fp8[i],
                    fused_attn_bwd_f16[i],
                    f"fused_attn_bwd_fp8[{i}]",
                    f"fused_attn_bwd_f16[{i}]",
                    atol,
                    rtol,
                    rmse_tol,
                    True,
                )
    os.environ["NVTE_UnfusedDPA_Emulate_FP8"] = "0"


def _run_dpa_fp8_vs_f16(dtype, config, fp8_dpa, qkv_layout, is_training, fp8_recipe):
    """Run DotProductAttention module in FP8"""
    reset_rng_states()
    _DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
    _DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed)
    def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
        """Get cuda rng tracker."""
        return _DUMMY_CUDA_RNG_STATE_TRACKER

    qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()])
    with quantized_model_init(enabled=fp8_dpa):
        dpa = DotProductAttention(
            config.num_heads,
            config.head_dim_qk,
            num_gqa_groups=config.num_gqa_groups,
            attention_dropout=config.dropout_p,
            sequence_parallel=False,
            tp_size=1,
            get_rng_state_tracker=get_dummy_cuda_rng_tracker,
            tp_group=None,
            layer_number=1,
            attention_type="self",
            qkv_format=qkv_format,
        ).to(dtype=dtype, device="cuda")
        if not is_training:
            dpa = dpa.eval()
    if "padding" in config.attn_mask_type or qkv_format == "thd":
        if config.attn_type == "self":
            seqlens_q = torch.randint(
                1, config.max_seqlen_q, [config.batch_size], dtype=torch.int32, device="cuda"
            )
            seqlens_kv = seqlens_q
        if config.attn_type == "cross":
            seqlens_q = torch.randint(
                1, config.max_seqlen_q, [config.batch_size], dtype=torch.int32, device="cuda"
            )
            seqlens_kv = torch.randint(
                1, config.max_seqlen_kv, [config.batch_size], dtype=torch.int32, device="cuda"
            )
    else:
        seqlens_q = torch.full(
            [config.batch_size], config.max_seqlen_q, dtype=torch.int32, device="cuda"
        )
        seqlens_kv = torch.full(
            [config.batch_size], config.max_seqlen_kv, dtype=torch.int32, device="cuda"
        )
    cu_seqlens_q = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda")
    cu_seqlens_kv = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda")
    cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)
    cu_seqlens_kv[1:] = torch.cumsum(seqlens_kv, dim=0)

    dim_to_num = {
        "b": config.batch_size,
        "sq": config.max_seqlen_q,
        "skv": config.max_seqlen_kv,
        "h": config.num_heads,
        "hg": config.num_gqa_groups,
        "d": config.head_dim_qk,
        "t": cu_seqlens_q[-1],
        "tg": cu_seqlens_kv[-1],
        "3": 3,
        "2": 2,
        "1": 1,
    }
    inp = []
    for i, layout in enumerate(qkv_layout.split("_")):
        layout = "_".join(layout)
        if i == 0:
            layout = layout.replace("s", "sq")
        else:
            layout = layout.replace("s", "skv")
            layout = layout.replace("h", "hg")
            layout = layout.replace("t", "tg")
        tensor_shape = [dim_to_num[j] for j in layout.split("_")]
        if config.dropout_p == 0.0:
            tensor = torch.randn(tensor_shape, dtype=dtype, device="cuda")
        else:
            # test cuDNN FP8 dropout
            tensor = torch.ones(tensor_shape, dtype=dtype, device="cuda")
        tensor_count = 1
        split_dim = 0
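        # locate the packed dimension: the digit in the layout (e.g. "3" in "sbh3d")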
        for dim, l in enumerate(layout.split("_")):
            if l.isdigit():
                tensor_count = int(l)
                split_dim = dim
                break
        tensors = torch.split(tensor, 1, dim=split_dim) if split_dim != 0 else [tensor]
        for j in range(tensor_count):
            if split_dim != 0:
                inp.append(tensors[j].squeeze(split_dim))
            else:
                inp.append(tensors[j])
    for i in range(3):
        inp[i].requires_grad = True

    qkv_format_kv = "_".join(qkv_format)
    qkv_format_kv = qkv_format_kv.replace("s", "sq")
    out_grad_shape = [dim_to_num[i] for i in qkv_format_kv.split("_")]
    out_grad_shape_new = [*out_grad_shape[:-2], out_grad_shape[-2] * out_grad_shape[-1]]
    out_grad = torch.randn(out_grad_shape_new, dtype=dtype, device="cuda")
    with autocast(enabled=fp8_dpa, recipe=fp8_recipe):
        out = dpa(
            inp[0],
            inp[1],
            inp[2],
            qkv_format=qkv_format,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_kv=cu_seqlens_kv,
            max_seqlen_q=config.max_seqlen_q,
            max_seqlen_kv=config.max_seqlen_kv,
            attn_mask_type=config.attn_mask_type,
            checkpoint_core_attention=False,
            core_attention_bias_type=config.attn_bias_type,
            fp8_output=fp8_dpa,
        )
    if is_training:
        out.backward(out_grad)
    if is_training:
        return out, (inp[0].grad, inp[1].grad, inp[2].grad)
    return out, (None, None, None)


model_configs_fp8 = {
    # test: ModelConfig(b, sq, hq, dqk)
    "fp8_1": ModelConfig(1, 512, 1, 64),
    "fp8_2": ModelConfig(4, 512, 16, 64),
    "fp8_3": ModelConfig(1, 2048, 1, 128),
    "fp8_4": ModelConfig(2, 2048, 24, 128),
    "fp8_5": ModelConfig(1, 512, 1, 64, attn_mask_type="causal"),
    "fp8_6": ModelConfig(4, 512, 16, 64, attn_mask_type="causal"),
    "fp8_7": ModelConfig(1, 2048, 1, 128, attn_mask_type="causal"),
    "fp8_8": ModelConfig(2, 2048, 24, 128, attn_mask_type="causal"),
}
param_types_fp8 = [torch.float16, torch.bfloat16]
cudnn_frontend_version = int(os.getenv("NVTE_FUSED_ATTN_FE_VER", "1"))
models_v0 = ["fp8_1", "fp8_2", "fp8_5", "fp8_6"]
models_v1 = ["fp8_3", "fp8_4", "fp8_7", "fp8_8"]


@pytest.mark.skipif(
    (
        get_cudnn_version() < (8, 9, 3)
        if cudnn_frontend_version == 0
        else get_cudnn_version() < (9, 2, 1)
    ),
    reason=f"""cuDNN {"8.9.3" if cudnn_frontend_version == 0 else "9.2.1"}+ is required.""",
)
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper+.")
@pytest.mark.parametrize("dtype", param_types_fp8)
@pytest.mark.parametrize("model", models_v1 if cudnn_frontend_version == 1 else models_v0)
def test_custom_mha_fp8_vs_f16(dtype, model):
    """Test FP8 dot product attention implementations based on cuDNN frontend
    v0.9 and v1.0+. Each test compares results from a custom implementation of
    an FP8 MHA module, i.e. Custom_MHA_FP8(), to results from an F16 MHA
    implementation, i.e. transformer_engine.pytorch.attention.MultiHeadAttention.
    Both paths take F16 input and output. QKV layout is t3hd or bs3hd."""

    config = model_configs_fp8[model]

    # Test backend availability
    is_training = True
    available_backends, _, fused_attn_backends = get_available_attention_backends(
        config,
        qkv_dtype=torch.float8_e4m3fn,
        qkv_layout="t3hd" if cudnn_frontend_version == 0 else "bs3hd",
        is_training=is_training,
    )
    flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
    if not (fused_attn_backends and unfused_attn_supported):
        pytest.skip("Not enough backends to run this test with.")

    fused_attn_fwd_fp8, fused_attn_bwd_fp8 = _run_custom_mha_fp8(dtype, config, "FusedAttention")
    unfused_attn_fwd_f16, unfused_attn_bwd_f16 = _run_ref_mha_f16(dtype, config, "UnfusedAttention")
    atol = 5e-1
    rtol = 5e-1
    rmse_tol = 0.13
    compare_and_assert(
        fused_attn_fwd_fp8,
        unfused_attn_fwd_f16,
        "fused_attn_fwd_fp8",
        "unfused_attn_fwd_f16",
        atol,
        rtol,
        rmse_tol,
        True,
    )
    compare_and_assert(
        fused_attn_bwd_fp8,
        unfused_attn_bwd_f16,
        "fused_attn_bwd_fp8",
        "unfused_attn_bwd_f16",
        atol,
        rtol,
        rmse_tol,
        True,
    )


def _run_custom_mha_fp8(dtype, config, backend):
    """Run Custom_MHA_FP8 with FP8 FusedAttention backend. Both input and output
    are in F16. QKV GEMM, DPA, and projection GEMM are calculated in FP8."""
    reset_rng_states()
    os.environ["NVTE_FLASH_ATTN"] = "0"
    os.environ["NVTE_FUSED_ATTN"] = "0"
    if backend == "FlashAttention":
        os.environ["NVTE_FLASH_ATTN"] = "1"
    if backend == "FusedAttention":
        os.environ["NVTE_FUSED_ATTN"] = "1"
    _attention_backends["backend_selection_requires_update"] = True

    inp = 0.0001 * torch.randint(
        -100,
        100,
        (config.batch_size * config.max_seqlen_q, config.num_heads * config.head_dim_qk),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )
    seqlens = torch.full([config.batch_size], config.max_seqlen_q, dtype=torch.int32, device="cuda")
    cu_seqlens = torch.zeros(config.batch_size + 1, device="cuda", dtype=torch.int32)
    cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)

    out_grad = 0.01 * torch.randn(
        config.batch_size * config.max_seqlen_q,
        config.num_heads * config.head_dim_qk,
        dtype=dtype,
        device="cuda",
    )
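    # persist to disk so the F16 reference run consumes the identical gradient tensor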
    torch.save(out_grad, "out_grad.pt")

    fp8_recipe = recipe.DelayedScaling(
        margin=0,
        fp8_format=recipe.Format.HYBRID,
        amax_history_len=1,
        amax_compute_algo="most_recent",
    )

    mha = Custom_MHA_FP8(config).to(dtype=dtype, device="cuda")
    with autocast(enabled=True, recipe=fp8_recipe):
        out = mha(inp, cu_seqlens, config.max_seqlen_q)
    out.backward(out_grad)

    out = torch.load("out.pt")
    dqkv = torch.load("dqkv.pt")
    return (
        out.view(config.batch_size, config.max_seqlen_q, -1),
        dqkv.view(
            config.batch_size, config.max_seqlen_q, 3, config.num_heads, config.head_dim_qk
        ).contiguous(),
    )


def _run_ref_mha_f16(dtype, config, backend):
    """Run reference F16 FusedAttention. Both input and output
    are in F16. QKV GEMM, DPA, and projection GEMM are also in F16."""

    os.environ["NVTE_FLASH_ATTN"] = "0"
    os.environ["NVTE_FUSED_ATTN"] = "0"
    if backend == "FlashAttention":
        os.environ["NVTE_FLASH_ATTN"] = "1"
    if backend == "FusedAttention":
        os.environ["NVTE_FUSED_ATTN"] = "1"
    _attention_backends["backend_selection_requires_update"] = True

    inp = torch.load("qkv.pt").to(device="cuda")
    inp.requires_grad = True
    seqlens = torch.full([config.batch_size], config.max_seqlen_q, dtype=torch.int32, device="cuda")
    cu_seqlens = torch.zeros(config.batch_size + 1, device="cuda", dtype=torch.int32)
    cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)
    out_grad = (
        torch.load("out_grad.pt").to(device="cuda").view(config.batch_size, config.max_seqlen_q, -1)
    )

    _DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
    _DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed)

    def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
        """Get cuda rng tracker."""
        return _DUMMY_CUDA_RNG_STATE_TRACKER
    block = DotProductAttention(
        config.num_heads,
        config.head_dim_qk,
        attention_dropout=config.dropout_p,
        sequence_parallel=False,
        tp_size=1,
        get_rng_state_tracker=get_dummy_cuda_rng_tracker,
        tp_group=None,
        layer_number=1,
        attention_type="self",
        qkv_format="bshd",
    ).to(dtype=dtype, device="cuda")

    q = inp[:, :, 0, :, :]
    k = inp[:, :, 1, :, :]
    v = inp[:, :, 2, :, :]
    out = block(q, k, v, attn_mask_type=config.attn_mask_type)
    out.backward(out_grad)

    return out, inp.grad


_CUBLASLT_WORKSPACE_SIZE_BYTES = 33_554_432  # 32MiB
_2X_ACC_FPROP = False
_2X_ACC_DGRAD = False
_2X_ACC_WGRAD = False

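# FP8 meta-tensor indices for the custom MHA implementation below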
META_QKV = tex.FP8FwdTensors.GEMM1_OUTPUT
META_DQKV = tex.FP8BwdTensors.GRAD_OUTPUT1
META_O = tex.FP8FwdTensors.GEMM2_INPUT
META_DO = tex.FP8BwdTensors.GRAD_INPUT2
META_S = tex.FP8FwdTensors.GEMM3_OUTPUT
META_DP = tex.FP8BwdTensors.GRAD_INPUT3


class _custom_mha_fp8(torch.autograd.Function):
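    """Autograd function for FP8 MHA: an FP8 QKV GEMM followed by cuDNN fused
    attention. F16 copies of QKV, the output, and dQKV are saved to disk for
    the reference run."""
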
    @staticmethod
    def forward(
        ctx,
        inp: torch.Tensor,
        qkv_weight: torch.Tensor,
        qkv_bias: torch.Tensor,
        cu_seqlens: torch.Tensor,
        num_heads: int,
        p_dropout: float,
        max_s: int,
        fast_zero_fill: bool,
        fp8_meta: Dict[str, Any],
        workspace: torch.Tensor,
        is_training: bool,
        mask_type: str,
        quantizers: list[Quantizer],
    ) -> torch.Tensor:
        qkv_dtype = inp.dtype

        assert inp.dim() == 2
        in_features = qkv_weight.shape[-1]
        h = num_heads
        d = in_features // h
        b = cu_seqlens.numel() - 1

        input_quantizer = quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT]
        qkv_quantizer = quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_INPUT]
        qkv_weight_quantizer = quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT]
        o_quantizer = quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_OUTPUT]
        dO_quantizer = quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT1]
        dQKV_quantizer = quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_INPUT1]
        s_quantizer = quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT2]
        dP_quantizer = quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT3]

        inp_fp8 = input_quantizer(inp)

        qkv_weight_fp8 = qkv_weight_quantizer(qkv_weight)

        qkv, *_ = ext.general_gemm(
            qkv_weight_fp8,
            inp_fp8,
            workspace,
            bias=qkv_bias,
            out_dtype=qkv_weight_fp8.dtype,
            quantization_params=qkv_quantizer,
            use_split_accumulator=_2X_ACC_FPROP,
        )
        qkv = qkv.view(-1, 3, h, d)
        qkv_fp16 = qkv.dequantize().view(b, max_s, 3, h, d).contiguous()
        torch.save(qkv_fp16, "qkv.pt")
        if cudnn_frontend_version == 1:
            qkv = qkv.view(b, max_s, 3, h, d)  # bs3hd

        # FMHA
        q_data = qkv._data[:, :, 0, :, :] if cudnn_frontend_version == 1 else qkv._data[:, 0, :, :]
        k_data = qkv._data[:, :, 1, :, :] if cudnn_frontend_version == 1 else qkv._data[:, 1, :, :]
        v_data = qkv._data[:, :, 2, :, :] if cudnn_frontend_version == 1 else qkv._data[:, 2, :, :]
        q = qkv.make_like(tensor=qkv, data=q_data, shape=q_data.shape)
        k = qkv.make_like(tensor=qkv, data=k_data, shape=k_data.shape)
        v = qkv.make_like(tensor=qkv, data=v_data, shape=v_data.shape)

        out, aux_ctx_tensors = fused_attn_fwd(
            is_training,
            max_s,
            max_s,
            cu_seqlens,
            cu_seqlens,
            q,
            k,
            v,
            qkv_dtype,
            FusedAttnBackend["FP8"],
            attn_scale=None,
            dropout=p_dropout,
            fast_zero_fill=fast_zero_fill,
            qkv_layout="bs3hd" if cudnn_frontend_version == 1 else "t3hd",
            attn_bias_type="no_bias",
            attn_mask_type=mask_type if cudnn_frontend_version == 1 else "padding",
            rng_gen=None,
            o_quantizer=o_quantizer,
            s_quantizer=s_quantizer,
        )

        tensors_to_save, tensor_objects = prepare_for_saving(
            q, k, v, inp_fp8, qkv_weight_fp8, workspace, out
        )

        ctx.save_for_backward(*tensors_to_save)
        ctx.tensor_objects = tensor_objects
        ctx.aux_ctx_tensors = aux_ctx_tensors
        ctx.qkv_dtype = qkv_dtype
        ctx.fp8_meta = fp8_meta
        ctx.cu_seqlens = cu_seqlens
        ctx.p_dropout = p_dropout
        ctx.max_s = max_s
        ctx.fast_zero_fill = fast_zero_fill
        ctx.hidden_size = in_features
        ctx.num_heads = num_heads
        ctx.mask_type = mask_type
        ctx.dtype = inp.dtype
        ctx.dQKV_quantizer = dQKV_quantizer
        ctx.dO_quantizer = dO_quantizer
        ctx.dP_quantizer = dP_quantizer
        ctx.S_quantizer = s_quantizer

        out = out.view(-1, in_features)  # (bs)(hd)
        out_fp16 = out.dequantize()
        torch.save(out_fp16, "out.pt")  # (bs)(hd)
        return out_fp16

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
        with torch.cuda.nvtx.range("_DPA"):
            saved_tensors = ctx.saved_tensors
            (q, k, v, inp_fp8, qkv_weight_fp8, workspace, out) = restore_from_saved(
                ctx.tensor_objects, saved_tensors
            )
            proj_dgrad = ctx.dO_quantizer(grad_output)
            fp8_dtype_backward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=False)
            dq, dk, dv, *rest = fused_attn_bwd(
                ctx.max_s,
                ctx.max_s,
                ctx.cu_seqlens,
                ctx.cu_seqlens,
                q,
                k,
                v,
                out,
                proj_dgrad.view_as(out),
                ctx.qkv_dtype,
                fp8_dtype_backward,
                ctx.aux_ctx_tensors,
                FusedAttnBackend["FP8"],
                None,
                None,
                ctx.S_quantizer,
                ctx.dP_quantizer,
                ctx.dQKV_quantizer,
                attn_scale=None,
                dropout=ctx.p_dropout,
                fast_zero_fill=ctx.fast_zero_fill,
                qkv_layout="bs3hd" if cudnn_frontend_version == 1 else "t3hd",
                attn_bias_type="no_bias",
                attn_mask_type=ctx.mask_type if cudnn_frontend_version == 1 else "padding",
            )
            dim = 2 if cudnn_frontend_version == 1 else 1
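            # dq/dk/dv from fused_attn_bwd share one contiguous buffer; rebuild the
            # packed dQKV by re-viewing dq's storage with a size-3 axis inserted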
            dqkv = torch.Tensor().to(device=dq._data.device, dtype=dq._data.dtype)
            dqkv_shape = list(dq._data.shape)
            dqkv_shape.insert(dim, 3)
            dqkv_stride = list(dq._data.stride())
            dqkv_stride.insert(dim, int(dqkv_stride[-3] / 3))
            dqkv.set_(
                dq._data.untyped_storage(), dq._data.storage_offset(), dqkv_shape, dqkv_stride
            )  # bs3hd
            dqkv_c = dqkv.view(-1, 3 * ctx.hidden_size)
            dqkv_c = dq.make_like(tensor=dq, data=dqkv_c, shape=dqkv_c.shape)
            dqkv_c_fp16 = dqkv_c.dequantize()
            torch.save(dqkv_c_fp16, "dqkv.pt")
            qkv_bgrad, dqkv = ext.bgrad_quantize(dqkv_c_fp16, ctx.dQKV_quantizer)
            dqkv_c._transpose = None
            dqkv_c._create_transpose()

            # QKV DGRAD
            qkv_dgrad, *_ = ext.general_gemm(
                qkv_weight_fp8,
                dqkv_c,
                workspace,
                ctx.dtype,
                use_split_accumulator=_2X_ACC_DGRAD,
                layout="NN",
            )

            # QKV WGRAD
            qkv_wgrad, *_ = ext.general_gemm(
                inp_fp8,
                dqkv,
                workspace,
                ctx.dtype,
                use_split_accumulator=_2X_ACC_WGRAD,
                layout="NT",
            )

        return (
            qkv_dgrad,
            qkv_wgrad,
            qkv_bgrad,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
        )


class Custom_MHA_FP8(TransformerEngineBaseModule):
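    """Custom FP8 multi-head attention module: FP8 QKV GEMM plus cuDNN fused
    attention, driven by the _custom_mha_fp8 autograd function above."""
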
    def __init__(self, config, params_dtype: torch.dtype = torch.float32):
        super().__init__()
        self.p_dropout = config.dropout_p
        self.h = config.num_heads
        self.hidden_size = config.hidden_size
        self.head_dim = config.head_dim_qk
        self.fast_zero_fill = True
        self.mask_type = config.attn_mask_type

        self.qkv_weight = torch.nn.Parameter(
            torch.empty(
                self.hidden_size * 3,
                self.hidden_size,
                device=torch.cuda.current_device(),
                dtype=params_dtype,
            )
        )
        self.qkv_bias = torch.nn.Parameter(
            torch.empty(
                self.hidden_size * 3,
                device=torch.cuda.current_device(),
                dtype=params_dtype,
            )
        )
        with torch.no_grad():
            self.qkv_bias.zero_()
            self.qkv_weight.fill_(1.0)
        self.workspace = torch.empty(
            _CUBLASLT_WORKSPACE_SIZE_BYTES, dtype=torch.int8, device="cuda"
        )

    def forward(
        self,
        inp: torch.Tensor,
        cu_seqlens,
        max_s,
    ) -> torch.Tensor:
        with self.prepare_forward(inp, num_gemms=3) as inp:
            out = _custom_mha_fp8.apply(
                inp,
                self.qkv_weight,
                self.qkv_bias,
                cu_seqlens,
                self.h,
                self.p_dropout,
                max_s,
                self.fast_zero_fill,
                self.fp8_meta,
                self.workspace,
                self.training,
                self.mask_type,
                self.quantizers,
            )
        return out