# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import logging
import math
import os
import sys
import pathlib
from typing import Any, Dict, Tuple, Union

import pytest
import torch

from transformer_engine.common import recipe
from transformer_engine.pytorch import TransformerLayer, fp8_autocast, fp8_model_init
from transformer_engine.pytorch.attention.dot_product_attention import (
    DotProductAttention,
    _attention_backends,
)
from transformer_engine.pytorch.attention.multi_head_attention import MultiheadAttention
from transformer_engine.pytorch.attention.dot_product_attention.utils import (
    FlashAttentionUtils,
    check_set_window_size,
)
from transformer_engine.pytorch.attention import RotaryPositionEmbedding
import transformer_engine.pytorch.cpp_extensions as ext
from transformer_engine.pytorch.cpp_extensions.fused_attn import (
    FusedAttnBackend,
    fused_attn_bwd,
    fused_attn_fwd,
)
from transformer_engine.pytorch.distributed import CudaRNGStatesTracker
import transformer_engine.pytorch.fp8 as fp8
from transformer_engine.pytorch.module.base import TransformerEngineBaseModule
from transformer_engine.pytorch.utils import (
    get_device_compute_capability,
    init_method_normal,
    scaled_init_method_normal,
    is_bf16_compatible,
)
from transformer_engine.pytorch.utils import get_cudnn_version
import transformer_engine_torch as tex

from transformer_engine.pytorch.tensor.quantized_tensor import (
    Quantizer,
    prepare_for_saving,
    restore_from_saved,
)

_current_file = pathlib.Path(__file__).resolve()
sys.path.append(str(_current_file.parent.parent))
from utils import (
    reset_rng_states,
    ModelConfig,
    dtype_tols,
    get_available_attention_backends,
)

# Only run FP8 tests on H100
fp8_available, reason_for_no_fp8 = fp8.FP8GlobalStateManager.is_fp8_available()

seed = 1234

# Reset RNG states
reset_rng_states()


@pytest.fixture(autouse=True)
def reset_global_fp8_state():
    yield
    fp8.FP8GlobalStateManager.reset()
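

# Model configurations are grouped by the feature under test. As implied by the
# keyword arguments used below (max_seqlen_kv, num_gqa_groups, head_dim_v, ...),
# ModelConfig's positional arguments are (batch_size, max_seqlen_q, num_heads,
# head_dim_qk); everything else is passed by keyword.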
model_configs_base = {
    #     test:             b,   sq,  h,   d
    "base_1_0": ModelConfig(8, 128, 16, 64),
    "base_1_1": ModelConfig(4, 128, 16, 64, max_seqlen_kv=256),
    "base_2_0": ModelConfig(2, 2048, 24, 128),
    "base_2_1": ModelConfig(1, 2048, 24, 128, max_seqlen_kv=4096),
    "base_3_0": ModelConfig(8, 1, 16, 128, max_seqlen_kv=2048),
    "base_3_1": ModelConfig(8, 1, 16, 256, max_seqlen_kv=2048),
    "base_4_0": ModelConfig(8, 1, 16, 192, max_seqlen_kv=2048),
    "base_4_1": ModelConfig(8, 128, 16, 192, max_seqlen_kv=2048),
    "base_5_0": ModelConfig(8, 1, 16, 512, max_seqlen_kv=2048),
    "base_5_1": ModelConfig(8, 128, 16, 512, max_seqlen_kv=2048),
    "base_6_0": ModelConfig(8, 1, 16, 1024, max_seqlen_kv=2048),
    "base_6_1": ModelConfig(8, 128, 16, 1024, max_seqlen_kv=2048),
}


param_types = [torch.float16]
if is_bf16_compatible():  # bf16 requires sm_80 or higher
    param_types.append(torch.bfloat16)
param_types_lean = [torch.bfloat16]


@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model_configs", [model_configs_base])
@pytest.mark.parametrize("model", model_configs_base.keys())
@pytest.mark.parametrize("ckpt_attn", [False])
@pytest.mark.parametrize("workspace_opt", [True, False])
@pytest.mark.parametrize("qkv_layout", [None])
@pytest.mark.parametrize("swa", [False])
@pytest.mark.parametrize("pad_between_seqs", [False])
def test_dot_product_attention(
    dtype, model_configs, model, ckpt_attn, workspace_opt, qkv_layout, swa, pad_between_seqs
):
    """Test DotProductAttention module"""

    # Get configs
    tols = dict(atol=1e-3, rtol=1e-3)
    if dtype == torch.bfloat16:
        tols = dict(atol=1.5e-2, rtol=1.5e-2)
    config = model_configs[model]
    is_mla = config.head_dim_qk != config.head_dim_v
    is_mqa_gqa = config.num_heads != config.num_gqa_groups
    if qkv_layout is None:
        if config.attn_type == "self":
            qkv_layout = "sb3hd" if not is_mla and not is_mqa_gqa else "sbhd_sbhd_sbhd"
        else:
            qkv_layout = "bshd_bs2hd" if not is_mla and not is_mqa_gqa else "bshd_bshd_bshd"
    if "3" in qkv_layout and config.attn_type == "cross":
        pytest.skip("No need to test this layout for cross attention")

    if config.window_size == (-1, -1) and swa:
        config.window_size = [2, 2]
    config.window_size = check_set_window_size(config.attn_mask_type, config.window_size)

    is_training = True
    available_backends, _, fused_attn_backends = get_available_attention_backends(
        config,
        qkv_dtype=dtype,
        qkv_layout=qkv_layout,
        window_size=config.window_size,
        pad_between_seqs=pad_between_seqs,
        is_training=is_training,
    )
    flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
    if not fused_attn_supported:
        is_training = False
        available_backends, _, fused_attn_backends = get_available_attention_backends(
            config,
            qkv_dtype=dtype,
            qkv_layout=qkv_layout,
            window_size=config.window_size,
            pad_between_seqs=pad_between_seqs,
            is_training=is_training,
        )
        flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends

    # FlashAttention does not support pad_between_seqs, but _run_dot_product_attention
    # manually pads and unpads the input and output of FlashAttention for testing purposes
    if (
        pad_between_seqs
        and FlashAttentionUtils.is_installed
        and not (
            config.max_seqlen_q != config.max_seqlen_kv
            and config.attn_mask_type in ["causal", "padding_causal"]
        )
        and (config.window_size[0] == -1 or FlashAttentionUtils.v2_3_plus)
    ):
        flash_attn_supported = True

    # Skip if only unfused backend is supported
    if (len(fused_attn_backends) + flash_attn_supported + unfused_attn_supported) < 2:
        pytest.skip("Less than two backends to compare.")
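
    # Run the same forward and backward pass with every available backend and
    # compare results pairwise. When two fused (cuDNN) sub-backends are available,
    # NVTE_FUSED_ATTN_BACKEND forces each one in turn so that they can also be
    # compared against each other.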
    # UnfusedDotProductAttention backend
    if unfused_attn_supported:
        unfused_attn_fwd, unfused_attn_bwd = _run_dot_product_attention(
            dtype,
            config,
            "UnfusedDotProductAttention",
            ckpt_attn,
            qkv_layout,
            workspace_opt,
            pad_between_seqs,
            is_training,
        )

    # FusedAttention backend
    if fused_attn_supported:
        if len(fused_attn_backends) == 1:
            fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(
                dtype,
                config,
                "FusedAttention",
                ckpt_attn,
                qkv_layout,
                workspace_opt,
                pad_between_seqs,
                is_training,
            )
        if len(fused_attn_backends) == 2:
            os.environ["NVTE_FUSED_ATTN_BACKEND"] = "0"
            fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(
                dtype,
                config,
                "FusedAttention",
                ckpt_attn,
                qkv_layout,
                workspace_opt,
                pad_between_seqs,
                is_training,
            )
            os.environ["NVTE_FUSED_ATTN_BACKEND"] = "1"
            fused_attn_fwd_1, fused_attn_bwd_1 = _run_dot_product_attention(
                dtype,
                config,
                "FusedAttention",
                ckpt_attn,
                qkv_layout,
                workspace_opt,
                pad_between_seqs,
                is_training,
            )

    # FlashAttention backend
    if flash_attn_supported:
        from torch.utils.cpp_extension import IS_HIP_EXTENSION

        if IS_HIP_EXTENSION and config.head_dim_qk < config.head_dim_v:
            pytest.skip("FlashAttention on ROCm does not support MLA with head_dim_qk < head_dim_v")
        flash_attn_fwd, flash_attn_bwd = _run_dot_product_attention(
            dtype,
            config,
            "FlashAttention",
            ckpt_attn,
            qkv_layout,
            workspace_opt,
            pad_between_seqs,
            is_training,
        )

    logging.info(f"[test_dot_product_attention]: is_training = {is_training}")
    if unfused_attn_supported and flash_attn_supported:
        logging.info("[test_dot_product_attention]: unfused attn vs flash attn")
        torch.testing.assert_close(flash_attn_fwd, unfused_attn_fwd, **tols)
        for i, _ in enumerate(flash_attn_bwd):
            torch.testing.assert_close(unfused_attn_bwd[i], flash_attn_bwd[i], **tols)
    if unfused_attn_supported and fused_attn_supported:
        logging.info("[test_dot_product_attention]: unfused attn vs fused attn")
        torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, **tols)
        for i, _ in enumerate(unfused_attn_bwd):
            torch.testing.assert_close(fused_attn_bwd[i], unfused_attn_bwd[i], **tols)
    if fused_attn_supported and flash_attn_supported:
        logging.info("[test_dot_product_attention]: fused attn vs flash attn")
        torch.testing.assert_close(fused_attn_fwd, flash_attn_fwd, **tols)
        for i, _ in enumerate(flash_attn_bwd):
            torch.testing.assert_close(fused_attn_bwd[i], flash_attn_bwd[i], **tols)
    if fused_attn_supported and len(fused_attn_backends) == 2:
        logging.info("[test_dot_product_attention]: fused attn backend 0 vs 1")
        torch.testing.assert_close(fused_attn_fwd, fused_attn_fwd_1, **tols)
        for i, _ in enumerate(fused_attn_bwd):
            torch.testing.assert_close(fused_attn_bwd[i], fused_attn_bwd_1[i], **tols)


@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model_configs", [model_configs_base])
@pytest.mark.parametrize("model", ["base_1_1", "base_2_1"])
def test_dpa_checkpoint(dtype, model_configs, model):
    """Test DotProductAttention module with checkpointing"""
    test_dot_product_attention(dtype, model_configs, model, True, True, None, False, False)


model_configs_mla = {
    #    test:             b,   sq,  h, dqk                                    # attn , backend
    "mla_1_0": ModelConfig(8, 128, 16, 64, head_dim_v=128),  # self , 0
    "mla_1_1": ModelConfig(4, 128, 16, 64, max_seqlen_kv=256, head_dim_v=128),  # cross, 0
    "mla_1_2": ModelConfig(4, 128, 16, 192, max_seqlen_kv=256, head_dim_v=128),  # cross, 0
    "mla_2_0": ModelConfig(2, 2048, 24, 128, attn_mask_type="causal", head_dim_v=64),  # self , 1
    "mla_2_1": ModelConfig(
        1, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal", head_dim_v=64
    ),  # cross, 1
    "mla_2_2": ModelConfig(
        1, 2048, 24, 192, max_seqlen_kv=4096, attn_mask_type="causal", head_dim_v=128
    ),  # cross, 1
    "mla_3_0": ModelConfig(8, 1, 16, 128, max_seqlen_kv=2048, head_dim_v=64),  # inference
    "mla_3_1": ModelConfig(8, 1, 16, 256, max_seqlen_kv=2048, head_dim_v=128),  # inference
    "mla_3_2": ModelConfig(8, 1, 16, 192, max_seqlen_kv=2048, head_dim_v=128),  # inference
    "mla_3_3": ModelConfig(8, 1, 16, 160, max_seqlen_kv=2048, head_dim_v=128),  # inference
    "mla_3_4": ModelConfig(8, 1, 16, 160, max_seqlen_kv=2048, head_dim_v=160),  # inference
}


@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model_configs", [model_configs_mla])
@pytest.mark.parametrize("model", model_configs_mla.keys())
def test_dpa_mla(dtype, model_configs, model):
    """Test DotProductAttention module with Multi-Latent Attention (MLA)"""
    test_dot_product_attention(dtype, model_configs, model, True, True, None, False, False)


model_configs_mask = {
    #     test:             b,   sq,  h,   d
    "mask_1_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="causal"),
    "mask_1_1": ModelConfig(2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="causal"),
    "mask_1_2": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal"),
    "mask_2_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="causal_bottom_right"),
    "mask_2_1": ModelConfig(
        2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="causal_bottom_right"
    ),
    "mask_2_2": ModelConfig(
        2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal_bottom_right"
    ),
    "mask_3_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding"),
    "mask_3_1": ModelConfig(2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding"),
    "mask_3_2": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding"),
    "mask_4_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding_causal"),
    "mask_4_1": ModelConfig(2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding_causal"),
    "mask_4_2": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding_causal"),
    "mask_5_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding_causal_bottom_right"),
    "mask_5_1": ModelConfig(
        2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding_causal_bottom_right"
    ),
    "mask_5_2": ModelConfig(
        2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding_causal_bottom_right"
    ),
    "mask_6_0": ModelConfig(2, 1, 16, 128, max_seqlen_kv=2048, attn_mask_type="causal"),
    "mask_6_1": ModelConfig(2, 1, 16, 256, max_seqlen_kv=2048, attn_mask_type="causal"),
    "mask_7_0": ModelConfig(
        2, 1, 16, 128, max_seqlen_kv=2048, attn_mask_type="causal_bottom_right"
    ),
    "mask_7_1": ModelConfig(
        2, 1, 16, 256, max_seqlen_kv=2048, attn_mask_type="causal_bottom_right"
    ),
    "mask_8_0": ModelConfig(2, 1, 24, 128, max_seqlen_kv=2048, attn_mask_type="padding"),
    "mask_8_1": ModelConfig(2, 1, 16, 256, max_seqlen_kv=2048, attn_mask_type="padding"),
    "mask_9_0": ModelConfig(2, 1, 24, 128, max_seqlen_kv=2048, attn_mask_type="padding_causal"),
    "mask_9_1": ModelConfig(2, 1, 16, 256, max_seqlen_kv=2048, attn_mask_type="padding_causal"),
    "mask_10_0": ModelConfig(
        2, 1, 24, 128, max_seqlen_kv=2048, attn_mask_type="padding_causal_bottom_right"
    ),
    "mask_10_1": ModelConfig(
        2, 1, 16, 256, max_seqlen_kv=2048, attn_mask_type="padding_causal_bottom_right"
    ),
}


@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_mask])
@pytest.mark.parametrize("model", model_configs_mask.keys())
def test_dpa_mask(dtype, model_configs, model):
    """Test DotProductAttention module with different mask types"""
    test_dot_product_attention(dtype, model_configs, model, False, True, None, False, False)


model_configs_bias = {
    #     test:             b,   sq,  h,   d
    "bias_1_0": ModelConfig(4, 128, 16, 64, attn_bias_type="post_scale_bias"),
    "bias_1_1": ModelConfig(2, 128, 16, 64, max_seqlen_kv=256, attn_bias_type="post_scale_bias"),
    "bias_1_2": ModelConfig(4, 2048, 24, 128, attn_bias_type="post_scale_bias"),
    "bias_1_3": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_bias_type="post_scale_bias"),
    "bias_1_4": ModelConfig(4, 2048, 24, 128, attn_bias_type="alibi"),  # skipped
    "bias_1_5": ModelConfig(
        2, 2048, 24, 128, max_seqlen_kv=4096, attn_bias_type="alibi"
    ),  # skipped
    "bias_2_0": ModelConfig(
        4, 128, 16, 64, attn_mask_type="padding", attn_bias_type="post_scale_bias"
    ),  # skipped
    "bias_2_1": ModelConfig(
        2,
        128,
        16,
        64,
        max_seqlen_kv=256,
        attn_mask_type="padding",
        attn_bias_type="post_scale_bias",
    ),  # skipped
    "bias_2_2": ModelConfig(
        4, 2048, 24, 128, attn_mask_type="padding", attn_bias_type="post_scale_bias"
    ),  # skipped
    "bias_2_3": ModelConfig(
        2,
        2048,
        24,
        128,
        max_seqlen_kv=4096,
        attn_mask_type="padding",
        attn_bias_type="post_scale_bias",
    ),  # skipped
    "bias_2_4": ModelConfig(
        4, 2048, 24, 128, attn_mask_type="padding", attn_bias_type="alibi"
    ),  # skipped
    "bias_2_5": ModelConfig(
        2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding", attn_bias_type="alibi"
    ),  # skipped
    "bias_3_0": ModelConfig(
        4, 128, 16, 64, attn_mask_type="causal", attn_bias_type="post_scale_bias"
    ),
    "bias_3_1": ModelConfig(
        2, 128, 16, 64, max_seqlen_kv=256, attn_mask_type="causal", attn_bias_type="post_scale_bias"
    ),
    "bias_3_2": ModelConfig(
        4, 2048, 24, 128, attn_mask_type="causal", attn_bias_type="post_scale_bias"
    ),
    "bias_3_3": ModelConfig(
        2,
        2048,
        24,
        128,
        max_seqlen_kv=4096,
        attn_mask_type="causal",
        attn_bias_type="post_scale_bias",
    ),  # skipped
    "bias_3_4": ModelConfig(4, 2048, 24, 128, attn_mask_type="causal", attn_bias_type="alibi"),
    "bias_3_5": ModelConfig(
        2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal", attn_bias_type="alibi"
    ),  # skipped
    "bias_4_0": ModelConfig(
        4, 128, 16, 64, attn_mask_type="padding_causal", attn_bias_type="post_scale_bias"
    ),  # skipped
    "bias_4_1": ModelConfig(
        2,
        128,
        16,
        64,
        max_seqlen_kv=256,
        attn_mask_type="padding_causal",
        attn_bias_type="post_scale_bias",
    ),  # skipped
    "bias_4_2": ModelConfig(
        4, 2048, 24, 128, attn_mask_type="padding_causal", attn_bias_type="post_scale_bias"
    ),  # skipped
    "bias_4_3": ModelConfig(
        2,
        2048,
        24,
        128,
        max_seqlen_kv=4096,
        attn_mask_type="padding_causal",
        attn_bias_type="post_scale_bias",
    ),  # skipped
    "bias_4_4": ModelConfig(
        4, 2048, 24, 128, attn_mask_type="padding_causal", attn_bias_type="alibi"
    ),  # skipped
    "bias_4_5": ModelConfig(
        2,
        2048,
        24,
        128,
        max_seqlen_kv=4096,
        attn_mask_type="padding_causal",
        attn_bias_type="alibi",
    ),  # skipped
}


@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_bias])
@pytest.mark.parametrize("model", model_configs_bias.keys())
def test_dpa_bias(dtype, model_configs, model):
    """Test DotProductAttention module with different bias types"""
    test_dot_product_attention(dtype, model_configs, model, False, True, None, False, False)


model_configs_bias_shapes = {
    #     test:             b,   sq,  h,   d
    "bias_1_0": ModelConfig(4, 128, 16, 64, attn_bias_type="post_scale_bias", bias_shape="11ss"),
    "bias_1_1": ModelConfig(2, 128, 16, 64, attn_bias_type="post_scale_bias", bias_shape="1hss"),
    "bias_1_2": ModelConfig(4, 2048, 24, 128, attn_bias_type="post_scale_bias", bias_shape="b1ss"),
    "bias_1_3": ModelConfig(2, 2048, 24, 128, attn_bias_type="post_scale_bias", bias_shape="bhss"),
    "bias_1_4": ModelConfig(
        4,
        2048,
        24,
        128,
        attn_mask_type="causal",
        attn_bias_type="alibi",
        bias_shape="1hss",
        alibi_type="custom",
    ),
    "bias_1_5": ModelConfig(
        2,
        2048,
        24,
        128,
        attn_mask_type="causal",
        attn_bias_type="alibi",
        bias_shape="bhss",
        alibi_type="custom",
    ),
}


@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_bias_shapes])
@pytest.mark.parametrize("model", model_configs_bias_shapes.keys())
def test_dpa_bias_shapes(dtype, model_configs, model):
    """Test DotProductAttention module with different bias types and shapes"""
    test_dot_product_attention(dtype, model_configs, model, False, True, None, False, False)


# Sliding-window configs: with swa=True, test_dot_product_attention sets
# window_size = (2, 2) (left and right context), further adjusted for the mask
# type by check_set_window_size.
model_configs_swa = {
    #    test:             b,   sq,  h,   d
    "swa_1_1": ModelConfig(2, 2048, 16, 64),
    "swa_1_2": ModelConfig(2, 2048, 24, 128, num_gqa_groups=4),
    "swa_1_3": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096),
    "swa_2_1": ModelConfig(2, 2048, 16, 64, attn_mask_type="causal"),
    "swa_2_2": ModelConfig(2, 2048, 24, 128, num_gqa_groups=4, attn_mask_type="causal"),
    "swa_2_3": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal"),
    "swa_3_1": ModelConfig(2, 2048, 16, 64, attn_mask_type="causal_bottom_right"),
    "swa_3_2": ModelConfig(
        2, 2048, 24, 128, num_gqa_groups=4, attn_mask_type="causal_bottom_right"
    ),
    "swa_3_3": ModelConfig(
        2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal_bottom_right"
    ),
    "swa_4_1": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding"),
    "swa_4_2": ModelConfig(2, 2048, 24, 128, num_gqa_groups=4, attn_mask_type="padding"),
    "swa_4_3": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding"),
    "swa_5_1": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding_causal"),
    "swa_5_2": ModelConfig(2, 2048, 24, 128, num_gqa_groups=4, attn_mask_type="padding_causal"),
    "swa_5_3": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding_causal"),
    "swa_6_1": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding_causal_bottom_right"),
    "swa_6_2": ModelConfig(
        2, 2048, 24, 128, num_gqa_groups=4, attn_mask_type="padding_causal_bottom_right"
    ),
    "swa_6_3": ModelConfig(
        2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding_causal_bottom_right"
    ),
}


@pytest.mark.skipif(not FlashAttentionUtils.v2_3_plus, reason="Flash-attn 2.3+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_swa])
@pytest.mark.parametrize("model", model_configs_swa.keys())
def test_dpa_sliding_window(dtype, model_configs, model):
    """Test DotProductAttention module with sliding window attention"""
    test_dot_product_attention(dtype, model_configs, model, False, True, None, True, False)


model_configs_alibi_slopes = {
    #     test:               b,   sq,  h,   d
    "alibi_1_0": ModelConfig(
        2, 128, 16, 64, attn_mask_type="causal", attn_bias_type="alibi", alibi_type="vanilla"
    ),
    "alibi_1_1": ModelConfig(
        1,
        128,
        16,
        64,
        max_seqlen_kv=256,
        attn_mask_type="causal",
        attn_bias_type="alibi",
        alibi_type="vanilla",
    ),
    "alibi_2_0": ModelConfig(
        2, 1024, 24, 128, attn_mask_type="causal", attn_bias_type="alibi", alibi_type="custom"
    ),
    "alibi_2_1": ModelConfig(
        1,
        1024,
        24,
        128,
        max_seqlen_kv=2048,
        attn_mask_type="causal",
        attn_bias_type="alibi",
        alibi_type="custom",
    ),
}


@pytest.mark.skipif(not FlashAttentionUtils.v2_3_plus, reason="Flash-attn 2.3+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_alibi_slopes])
@pytest.mark.parametrize("model", model_configs_alibi_slopes.keys())
def test_dpa_alibi_slopes(dtype, model_configs, model):
    """Test DotProductAttention module with ALiBi slopes"""
    test_dot_product_attention(dtype, model_configs, model, False, True, None, False, False)


qkv_layouts = [
    "sb3hd",
    "sbh3d",
    "sbhd_sb2hd",
    "sbhd_sbh2d",
    "sbhd_sbhd_sbhd",
    "bs3hd",
    "bsh3d",
    "bshd_bs2hd",
    "bshd_bsh2d",
    "bshd_bshd_bshd",
]

model_configs_layout = {
    #       test:             b,   sq,  h,   d
    "layout_0_0": ModelConfig(2, 128, 16, 64),
    "layout_0_1": ModelConfig(
        2, 128, 16, 64, attn_mask_type="causal", attn_bias_type="post_scale_bias"
    ),
    "layout_0_2": ModelConfig(1, 128, 16, 64, max_seqlen_kv=256, attn_mask_type="padding"),
    "layout_0_3": ModelConfig(
        1,
        128,
        16,
        64,
        max_seqlen_kv=256,
        attn_mask_type="padding_causal",
        attn_bias_type="post_scale_bias",
    ),
    "layout_1_0": ModelConfig(2, 2048, 24, 128),
    "layout_1_1": ModelConfig(
        2, 2048, 24, 128, attn_mask_type="causal", attn_bias_type="post_scale_bias"
    ),
    "layout_1_2": ModelConfig(1, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding"),
    "layout_1_3": ModelConfig(
        1,
        2048,
        24,
        128,
        max_seqlen_kv=4096,
        attn_mask_type="padding_causal",
        attn_bias_type="post_scale_bias",
    ),
    "layout_2_0": ModelConfig(2, 1, 16, 256, max_seqlen_kv=2048),
    "layout_2_1": ModelConfig(
        2, 2048, 24, 256, attn_mask_type="causal", attn_bias_type="post_scale_bias"
    ),
}


@pytest.mark.skipif(get_cudnn_version() < (8, 9, 5), reason="cuDNN 8.9.5+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_layout])
@pytest.mark.parametrize("model", model_configs_layout.keys())
@pytest.mark.parametrize("qkv_layout", qkv_layouts)
def test_dpa_qkv_layout(dtype, model_configs, model, qkv_layout):
    """Test DotProductAttention module with different QKV layouts"""
    test_dot_product_attention(dtype, model_configs, model, False, True, qkv_layout, False, False)
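

# THD layouts ("t3hd", ..., "thd_thd_thd") pack variable-length sequences
# back-to-back with no batch dimension: "t" is the total token count across the
# batch, and cu_seqlens offsets (built in _run_dot_product_attention) locate each
# individual sequence.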


qkv_layouts_thd = ["t3hd", "th3d", "thd_t2hd", "thd_th2d", "thd_thd_thd"]

model_configs_layout_thd = {
    #       test:             b,   sq,  h,   d
    "layout_0_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding"),
    "layout_0_1": ModelConfig(2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding"),
    "layout_0_2": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding"),
    "layout_1_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding_causal"),
    "layout_1_1": ModelConfig(2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding_causal"),
    "layout_1_2": ModelConfig(
        2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding_causal"
    ),
    "layout_2_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding_causal_bottom_right"),
    "layout_2_1": ModelConfig(
        2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding_causal_bottom_right"
    ),
    "layout_2_2": ModelConfig(
        2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding_causal_bottom_right"
    ),
    "layout_3_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding", window_size=(4, 4)),
    "layout_3_1": ModelConfig(
        2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding", window_size=(4, 4)
    ),
    "layout_3_2": ModelConfig(
        2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding", window_size=(4, 4)
    ),
    "layout_4_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding_causal", window_size=(4, 0)),
    "layout_4_1": ModelConfig(
        2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding_causal", window_size=(4, 0)
    ),
    "layout_4_2": ModelConfig(
        2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding_causal", window_size=(4, 0)
    ),
    "layout_5_0": ModelConfig(
        2, 2048, 16, 64, attn_mask_type="padding_causal_bottom_right", window_size=(4, 0)
    ),
    "layout_5_1": ModelConfig(
        2,
        2048,
        24,
        128,
        num_gqa_groups=1,
        attn_mask_type="padding_causal_bottom_right",
        window_size=(4, 0),
    ),
    "layout_5_2": ModelConfig(
        2,
        2048,
        24,
        128,
        max_seqlen_kv=4096,
        attn_mask_type="padding_causal_bottom_right",
        window_size=(4, 0),
    ),
}


@pytest.mark.skipif(get_cudnn_version() < (9, 0, 0), reason="cuDNN 9.0.0+ is required.")
@pytest.mark.skipif(
    get_device_compute_capability() < (9, 0), reason="THD is only supported on Hopper+."
)
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_layout_thd])
@pytest.mark.parametrize("model", model_configs_layout_thd.keys())
@pytest.mark.parametrize("qkv_layout", qkv_layouts_thd)
def test_dpa_qkv_layout_thd(dtype, model_configs, model, qkv_layout):
    """Test DotProductAttention module with different QKV layouts in THD format"""
    config = model_configs[model]
    if config.num_heads != config.num_gqa_groups and "3" in qkv_layout:
        pytest.skip("qkv_layout not applicable for MQA/GQA")
    logging.info("[test_dpa_qkv_layout_thd]: pad_between_seqs = True")
    pad_between_seqs = True
    test_dot_product_attention(
        dtype, model_configs, model, False, True, qkv_layout, False, pad_between_seqs
    )
    if get_cudnn_version() >= (9, 3, 0):
        # cuDNN 9.3.0+ is required to run pad_between_seqs = False/True in the same run
        logging.info("[test_dpa_qkv_layout_thd]: pad_between_seqs = False")
        pad_between_seqs = False
        test_dot_product_attention(
            dtype, model_configs, model, False, True, qkv_layout, False, pad_between_seqs
        )


def _run_dot_product_attention(
    dtype: torch.dtype,
    config: ModelConfig,
    backend: str,
    ckpt_attn: bool,
    qkv_layout: str,
    workspace_opt: bool,
    pad_between_seqs: bool,
    is_training: bool,
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
    """Run DotProductAttention module with one forward pass and one backward pass"""

    # Set RNG and environment variables
    reset_rng_states()
    os.environ["NVTE_FLASH_ATTN"] = "0"
    os.environ["NVTE_FUSED_ATTN"] = "0"
    if backend == "FlashAttention":
        os.environ["NVTE_FLASH_ATTN"] = "1"
    if backend == "FusedAttention":
        os.environ["NVTE_FUSED_ATTN"] = "1"
        os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] = "1" if workspace_opt else "0"
    _attention_backends["backend_selection_requires_update"] = True

    # Create seqlens
    qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()])
    if "padding" in config.attn_mask_type or qkv_format == "thd":
        if config.attn_type == "self":
            seqlens_q = torch.randint(
                1, config.max_seqlen_q, [config.batch_size], dtype=torch.int32, device="cuda"
            )
            seqlens_kv = seqlens_q
        if config.attn_type == "cross":
            if config.max_seqlen_q > 1:
                seqlens_q = torch.randint(
                    1, config.max_seqlen_q, [config.batch_size], dtype=torch.int32, device="cuda"
                )
            else:
                seqlens_q = torch.ones([config.batch_size], dtype=torch.int32, device="cuda")
            seqlens_kv = torch.randint(
                1, config.max_seqlen_kv, [config.batch_size], dtype=torch.int32, device="cuda"
            )
    else:
        seqlens_q = torch.full(
            [config.batch_size], config.max_seqlen_q, dtype=torch.int32, device="cuda"
        )
        seqlens_kv = torch.full(
            [config.batch_size], config.max_seqlen_kv, dtype=torch.int32, device="cuda"
        )
    cu_seqlens_q = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda")
    cu_seqlens_kv = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda")
    cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)
    cu_seqlens_kv[1:] = torch.cumsum(seqlens_kv, dim=0)

    seqlens_q_after_pad = seqlens_q.clone()
    seqlens_kv_after_pad = seqlens_kv.clone()
    cu_seqlens_q_after_pad = cu_seqlens_q.clone()
    cu_seqlens_kv_after_pad = cu_seqlens_kv.clone()
    pad_len = [0] * config.batch_size
    if pad_between_seqs:
        max_pad_len = 3
        pad_len = torch.randint(0, max_pad_len + 1, [config.batch_size], device="cuda")
        seqlens_q_after_pad = seqlens_q + pad_len
        seqlens_kv_after_pad = seqlens_kv + pad_len
        cu_seqlens_q_after_pad[1:] = torch.cumsum(seqlens_q_after_pad, dim=0)
        cu_seqlens_kv_after_pad[1:] = torch.cumsum(seqlens_kv_after_pad, dim=0)
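
    # Example: for seqlens_q = [3, 5] and pad_len = [2, 1], cu_seqlens_q = [0, 3, 8]
    # indexes the unpadded token stream, while cu_seqlens_q_after_pad = [0, 5, 11]
    # indexes the padded THD buffer, in which token ranges [3, 5) and [10, 11) are
    # padding.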

    # Create attention mask if padding
    attention_mask = None
    if "padding" in config.attn_mask_type:
        if config.attn_type == "self":
            attention_mask_q = torch.Tensor([]).to(dtype=torch.bool)
            for i in range(config.batch_size):
                attention_mask_q = torch.cat(
                    [
                        attention_mask_q,
                        torch.Tensor(
                            [False] * seqlens_q[i] + [True] * (config.max_seqlen_q - seqlens_q[i])
                        )
                        .to(dtype=torch.bool)
                        .unsqueeze(0)
                        .unsqueeze(0)
                        .unsqueeze(0),
                    ],
                    dim=0,
                )
            attention_mask = attention_mask_q.to(device="cuda")
        if config.attn_type == "cross":
            attention_mask_q = torch.Tensor([]).to(dtype=torch.bool)
            attention_mask_kv = torch.Tensor([]).to(dtype=torch.bool)
            for i in range(config.batch_size):
                attention_mask_q = torch.cat(
                    [
                        attention_mask_q,
                        torch.Tensor(
                            [False] * seqlens_q[i] + [True] * (config.max_seqlen_q - seqlens_q[i])
                        )
                        .to(dtype=torch.bool)
                        .unsqueeze(0)
                        .unsqueeze(0)
                        .unsqueeze(0),
                    ],
                    dim=0,
                )
                attention_mask_kv = torch.cat(
                    [
                        attention_mask_kv,
                        torch.Tensor(
                            [False] * seqlens_kv[i]
                            + [True] * (config.max_seqlen_kv - seqlens_kv[i])
                        )
                        .to(dtype=torch.bool)
                        .unsqueeze(0)
                        .unsqueeze(0)
                        .unsqueeze(0),
                    ],
                    dim=0,
                )
            attention_mask = (
                attention_mask_q.to(device="cuda"),
                attention_mask_kv.to(device="cuda"),
            )
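
    # In these boolean masks, False marks valid tokens and True marks positions
    # attention should ignore (the padding beyond each sequence's actual length).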

    alibi_slopes = None
    if config.attn_bias_type == "alibi" and config.alibi_type == "custom":
        if config.bias_shape == "1hss":
            alibi_slopes = (
                torch.randn(config.num_heads).abs().to(dtype=torch.float32, device="cuda")
            )
        if config.bias_shape == "bhss":
            alibi_slopes = (
                torch.randn(config.batch_size, config.num_heads)
                .abs()
                .to(dtype=torch.float32, device="cuda")
            )

    # Create input tensors
    dim_to_num = {
        "b": config.batch_size,
        "sq": config.max_seqlen_q,
        "skv": config.max_seqlen_kv,
        "h": config.num_heads,
        "hg": config.num_gqa_groups,
        "dqk": config.head_dim_qk,
        "dv": config.head_dim_v,
        "t": cu_seqlens_q_after_pad[-1],
        "tg": cu_seqlens_kv_after_pad[-1],
        "3": 3,
        "2": 2,
        "1": 1,
    }
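
    # Example: qkv_layout = "sbh3d" expands to the per-dim key string "sq_b_h_3_dqk",
    # i.e. one packed tensor of shape [max_seqlen_q, batch_size, num_heads, 3,
    # head_dim_qk] that is split along the "3" dimension into q, k, and v below.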
    inp = []
    inp_orig = []
    for i, layout in enumerate(qkv_layout.split("_")):
        layout = "_".join(layout)
        if i == 0:
            layout = layout.replace("s", "sq")
        else:
            layout = layout.replace("s", "skv")
            layout = layout.replace("h", "hg")
            layout = layout.replace("t", "tg")
        if i == 2:
            layout = layout.replace("d", "dv")
        else:
            layout = layout.replace("d", "dqk")
        tensor_shape = [dim_to_num[j] for j in layout.split("_")]
        tensor = 0.1 * torch.randn(tensor_shape, dtype=dtype, device="cuda")
        tensor_orig = tensor
        if qkv_format == "thd" and pad_between_seqs:
            tensor_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
            if layout in ["t_h_dqk", "t_3_h_dqk", "t_h_3_dqk"]:
                for i in range(1, config.batch_size + 1):
                    valid_range = (
                        cu_seqlens_q_after_pad[i - 1],
                        cu_seqlens_q_after_pad[i] - pad_len[i - 1],
                    )
                    pad_range = (
                        cu_seqlens_q_after_pad[i] - pad_len[i - 1],
                        cu_seqlens_q_after_pad[i],
                    )
                    tensor[pad_range[0] : pad_range[1]] = 0.0
                    tensor_orig = torch.cat(
                        [tensor_orig, tensor[valid_range[0] : valid_range[1]]], dim=0
                    )
            if layout in ["tg_hg_dqk", "tg_2_hg_dqk", "tg_hg_2_dqk", "tg_hg_dv"]:
                for i in range(1, config.batch_size + 1):
                    valid_range = (
                        cu_seqlens_kv_after_pad[i - 1],
                        cu_seqlens_kv_after_pad[i] - pad_len[i - 1],
                    )
                    pad_range = (
                        cu_seqlens_kv_after_pad[i] - pad_len[i - 1],
                        cu_seqlens_kv_after_pad[i],
                    )
                    tensor[pad_range[0] : pad_range[1]] = 0.0
                    tensor_orig = torch.cat(
                        [tensor_orig, tensor[valid_range[0] : valid_range[1]]], dim=0
                    )
        tensor_count = 1
        split_dim = 0
        for dim, l in enumerate(layout.split("_")):
            if l.isdigit():
                tensor_count = int(l)
                split_dim = dim
                break
        tensors = torch.split(tensor, 1, dim=split_dim) if split_dim != 0 else [tensor]
        tensors_orig = (
            torch.split(tensor_orig, 1, dim=split_dim) if split_dim != 0 else [tensor_orig]
        )
        for j in range(tensor_count):
            if split_dim != 0:
                inp.append(tensors[j].squeeze(split_dim))
                inp_orig.append(tensors_orig[j].squeeze(split_dim))
            else:
                inp.append(tensors[j])
                inp_orig.append(tensors_orig[j])
    for i in range(3):
        inp[i].requires_grad = True
        inp_orig[i].requires_grad = True

    # Create output gradient
    qkv_format_kv = "_".join(qkv_format)
    qkv_format_kv = qkv_format_kv.replace("s", "sq")
    qkv_format_kv = qkv_format_kv.replace("d", "dv")
    out_grad_shape = [dim_to_num[i] for i in qkv_format_kv.split("_")]
    # The last two dims (h, dv) are merged because DotProductAttention returns its
    # output with the head and value dims flattened into a single hidden dim
    out_grad_shape_new = [*out_grad_shape[:-2], out_grad_shape[-2] * out_grad_shape[-1]]
    out_grad = 0.001 * torch.randint(0, 200, out_grad_shape_new, dtype=dtype, device="cuda")
    out_grad_orig = out_grad
    if qkv_format == "thd" and pad_between_seqs:
        out_grad_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
        if qkv_format_kv == "t_h_dv":
            for i in range(1, config.batch_size + 1):
                valid_range = (
                    cu_seqlens_q_after_pad[i - 1],
                    cu_seqlens_q_after_pad[i] - pad_len[i - 1],
                )
                pad_range = (cu_seqlens_q_after_pad[i] - pad_len[i - 1], cu_seqlens_q_after_pad[i])
                out_grad[pad_range[0] : pad_range[1]] = 0.0
                out_grad_orig = torch.cat(
                    [out_grad_orig, out_grad[valid_range[0] : valid_range[1]]], dim=0
                )

    # Create bias
    if config.attn_bias_type in ["no_bias", "alibi"]:
        bias = None
    if config.attn_bias_type == "post_scale_bias":
        shape = "_".join(config.bias_shape)
        shape = shape.replace("_s_s", "_sq_skv")
        tensor_shape = [dim_to_num[j] for j in shape.split("_")]
        bias = torch.randn(tensor_shape, dtype=dtype, device="cuda")
        if config.bias_shape != "1hss":
            bias.requires_grad = False

    # Create RNG tracker used for attention dropout
    _DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
    _DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed)

    def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
        """Get cuda rng tracker."""
        return _DUMMY_CUDA_RNG_STATE_TRACKER

    # Set up model
    block = DotProductAttention(
        config.num_heads,
        (config.head_dim_qk, config.head_dim_v),
        num_gqa_groups=config.num_gqa_groups,
        attention_dropout=config.dropout_p,
        qkv_format=qkv_format,
        attn_mask_type=config.attn_mask_type,
        sequence_parallel=False,
        tp_size=1,
        get_rng_state_tracker=get_dummy_cuda_rng_tracker,
        tp_group=None,
        layer_number=1,
        attention_type=config.attn_type,
    ).to(dtype=dtype, device="cuda")
    if not is_training:
        block = block.eval()
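
    # FusedAttention consumes the padded buffers (inp) together with
    # cu_seqlens_*_padded, while FlashAttention and UnfusedDotProductAttention
    # consume the unpadded buffers (inp_orig).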

    # Run a forward and backward pass
    if backend in ["FlashAttention", "UnfusedDotProductAttention"]:
        q = inp_orig[0]
        k = inp_orig[1]
        v = inp_orig[2]
        d_out = out_grad_orig
    if backend == "FusedAttention":
        q = inp[0]
        k = inp[1]
        v = inp[2]
        d_out = out_grad
    out = block(
        q,
        k,
        v,
        window_size=config.window_size,
        attention_mask=attention_mask,
        qkv_format=qkv_format,
        max_seqlen_q=config.max_seqlen_q,
        max_seqlen_kv=config.max_seqlen_kv,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_kv=cu_seqlens_kv,
        cu_seqlens_q_padded=cu_seqlens_q_after_pad if backend == "FusedAttention" else None,
        cu_seqlens_kv_padded=cu_seqlens_kv_after_pad if backend == "FusedAttention" else None,
        attn_mask_type=config.attn_mask_type,
        checkpoint_core_attention=ckpt_attn,
        core_attention_bias_type=config.attn_bias_type,
        core_attention_bias=bias,
        alibi_slopes=alibi_slopes,
        fast_zero_fill=True,
    )
    if is_training:
        out.backward(d_out)

    if backend in ["FlashAttention", "UnfusedDotProductAttention"]:
        if is_training:
            return out, (q.grad, k.grad, v.grad)
        else:
            return out, (None, None, None)
    if backend == "FusedAttention":
        if qkv_format == "thd" and pad_between_seqs:
            # Strip the inter-sequence padding out of the output and gradients so
            # they are directly comparable with the unpadded backends
            out_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
            if is_training:
                q_grad_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
                k_grad_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
                v_grad_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
            for i in range(1, config.batch_size + 1):
                valid_range_q = (
                    cu_seqlens_q_after_pad[i - 1],
                    cu_seqlens_q_after_pad[i] - pad_len[i - 1],
                )
                valid_range_kv = (
                    cu_seqlens_kv_after_pad[i - 1],
                    cu_seqlens_kv_after_pad[i] - pad_len[i - 1],
                )
                out_orig = torch.cat([out_orig, out[valid_range_q[0] : valid_range_q[1]]], dim=0)
                if is_training:
                    q_grad_orig = torch.cat(
                        [q_grad_orig, q.grad[valid_range_q[0] : valid_range_q[1]]], dim=0
                    )
                    k_grad_orig = torch.cat(
                        [k_grad_orig, k.grad[valid_range_kv[0] : valid_range_kv[1]]], dim=0
                    )
                    v_grad_orig = torch.cat(
                        [v_grad_orig, v.grad[valid_range_kv[0] : valid_range_kv[1]]], dim=0
                    )
            if is_training:
                return out_orig, (q_grad_orig, k_grad_orig, v_grad_orig)
            else:
                return out_orig, (None, None, None)
        else:
            if is_training:
                return out, (q.grad, k.grad, v.grad)
            else:
                return out, (None, None, None)


model_configs_te_layer = {
    #   test:             b,   sq,  h,   d
    "te_1_0": ModelConfig(2, 128, 16, 64, attn_bias_type="post_scale_bias"),
    "te_1_1": ModelConfig(
        4, 128, 16, 64, attn_mask_type="causal", attn_bias_type="post_scale_bias"
    ),
    "te_1_2": ModelConfig(
        2, 128, 16, 64, attn_mask_type="padding", attn_bias_type="post_scale_bias"
    ),
    "te_1_3": ModelConfig(2, 128, 16, 64, max_seqlen_kv=256, attn_mask_type="padding"),
    "te_2_0": ModelConfig(1, 2048, 16, 64, attn_mask_type="causal"),
    "te_2_1": ModelConfig(2, 2048, 16, 64),
    "te_2_2": ModelConfig(1, 2048, 16, 64, attn_mask_type="padding"),
    "te_2_3": ModelConfig(
        1, 2048, 16, 64, max_seqlen_kv=4096, attn_mask_type="padding_causal_bottom_right"
    ),
    "te_3_0": ModelConfig(4, 128, 16, 64, attn_mask_type="causal", attn_bias_type="alibi"),
    "te_3_1": ModelConfig(4, 2048, 16, 64, attn_mask_type="causal", attn_bias_type="alibi"),
}


@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model_configs", [model_configs_te_layer])
@pytest.mark.parametrize("model", model_configs_te_layer.keys())
@pytest.mark.parametrize("ckpt_attn", [False])
@pytest.mark.parametrize("qkv_format", ["sbhd", "bshd", "thd"])
@pytest.mark.parametrize("fused_qkv_params", [False])
@pytest.mark.parametrize("RoPE", [False])
def test_transformer_layer(
    dtype, model_configs, model, ckpt_attn, qkv_format, fused_qkv_params, RoPE
):
    """Test TransformerLayer module"""

    # Get configs
    config = model_configs[model]
    tols = dict(atol=5e-2, rtol=5e-2)
    workspace_opt = True

    # Test backend availability
    is_training = True
    available_backends, _, fused_attn_backends = get_available_attention_backends(
        config,
        qkv_dtype=dtype,
        qkv_layout=(
            qkv_format.replace("hd", "h3d") if fused_qkv_params else qkv_format.replace("hd", "3hd")
        ),
        is_training=is_training,
    )
    flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
    if not fused_attn_supported:
        is_training = False
        available_backends, _, fused_attn_backends = get_available_attention_backends(
            config,
            qkv_dtype=dtype,
            qkv_layout=(
                qkv_format.replace("hd", "h3d")
                if fused_qkv_params
                else qkv_format.replace("hd", "3hd")
            ),
            is_training=is_training,
        )
        flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends

    # Skip if only unfused backend is supported
    if (len(fused_attn_backends) + flash_attn_supported + unfused_attn_supported) < 2:
        pytest.skip("Less than two backends to compare.")
    # Skip if qkv_format = thd and "padding" not in attn_mask_type
    if qkv_format == "thd" and "padding" not in config.attn_mask_type:
        pytest.skip("THD requires padding mask.")

    # UnfusedDotProductAttention backend
    if unfused_attn_supported:
        unfused_attn_fwd, unfused_attn_bwd = _run_transformer_layer(
            dtype,
            config,
            "UnfusedDotProductAttention",
            ckpt_attn,
            qkv_format,
            workspace_opt,
            fused_qkv_params,
            RoPE,
            is_training,
        )

    # FusedAttention backend
    if fused_attn_supported:
        fused_attn_fwd, fused_attn_bwd = _run_transformer_layer(
            dtype,
            config,
            "FusedAttention",
            ckpt_attn,
            qkv_format,
            workspace_opt,
            fused_qkv_params,
            RoPE,
            is_training,
        )

    # FlashAttention backend
    if flash_attn_supported:
        flash_attn_fwd, flash_attn_bwd = _run_transformer_layer(
            dtype,
            config,
            "FlashAttention",
            ckpt_attn,
            qkv_format,
            workspace_opt,
            fused_qkv_params,
            RoPE,
            is_training,
        )

    logging.info(f"[test_transformer_layer]: is_training = {is_training}")
    if unfused_attn_supported and fused_attn_supported:
        logging.info("[test_transformer_layer]: unfused attn vs fused attn")
        torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, **tols)
        torch.testing.assert_close(fused_attn_bwd, unfused_attn_bwd, **tols)
    if unfused_attn_supported and flash_attn_supported:
        logging.info("[test_transformer_layer]: unfused attn vs flash attn")
        torch.testing.assert_close(flash_attn_fwd, unfused_attn_fwd, **tols)
        torch.testing.assert_close(flash_attn_bwd, unfused_attn_bwd, **tols)
    if fused_attn_supported and flash_attn_supported:
        logging.info("[test_transformer_layer]: fused attn vs flash attn")
        torch.testing.assert_close(fused_attn_fwd, flash_attn_fwd, **tols)
        torch.testing.assert_close(fused_attn_bwd, flash_attn_bwd, **tols)


@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_te_layer])
@pytest.mark.parametrize("model", ["te_1_2", "te_2_0"])
@pytest.mark.parametrize("qkv_format", ["bshd", "sbhd"])
def test_te_layer_misc(dtype, model_configs, model, qkv_format):
    """Test TransformerLayer module with miscellaneous settings"""
    ckpt_attn = True
    fused_qkv_params = True
    RoPE = True
    test_transformer_layer(
        dtype, model_configs, model, ckpt_attn, qkv_format, fused_qkv_params, RoPE
    )


@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_te_layer])
@pytest.mark.parametrize("model", ["te_2_0", "te_2_1", "te_2_2"])
def test_te_layer_mqa_gqa(dtype, model_configs, model):
    """Test TransformerLayer module with MQA/GQA"""

    def find_factors(x):
        f = []
        for i in range(2, x + 1):
            if x % i == 0:
                f.append(i)
        return f

    ckpt_attn = True
    qkv_format = "bshd"
    fused_qkv_params = True
    RoPE = True
    config = model_configs[model]
    # Sweep every factor of num_heads as the number of query heads per GQA group
    num_queries_per_gqa_group = find_factors(config.num_heads)
    for num_q_per_gqa_group in num_queries_per_gqa_group:
        config.num_gqa_groups = config.num_heads // num_q_per_gqa_group
        test_transformer_layer(
            dtype, model_configs, model, ckpt_attn, qkv_format, fused_qkv_params, RoPE
        )


def _run_transformer_layer(
    dtype: torch.dtype,
    config: ModelConfig,
    backend: str,
    ckpt_attn: bool,
    qkv_format: str,
    workspace_opt: bool,
    fused_qkv_params: bool,
    RoPE: bool,
    is_training: bool,
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
    """Run TransformerLayer module with one forward pass and one backward pass"""

    # Set RNG and environment variables
    reset_rng_states()
    os.environ["NVTE_FLASH_ATTN"] = "0"
    os.environ["NVTE_FUSED_ATTN"] = "0"
    if backend == "FlashAttention":
        os.environ["NVTE_FLASH_ATTN"] = "1"
    if backend == "FusedAttention":
        os.environ["NVTE_FUSED_ATTN"] = "1"
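    # Backend selection is cached after first use, so flag it as stale after
    # changing the NVTE_* environment variables above.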
    _attention_backends["backend_selection_requires_update"] = True

    # Create input tensor
    if qkv_format == "sbhd":
        inp = torch.randn(
            config.max_seqlen_q,
            config.batch_size,
            config.hidden_size,
            dtype=dtype,
            device="cuda",
            requires_grad=True,
        )
        inp_enc = torch.randn(
            config.max_seqlen_kv,
            config.batch_size,
            config.hidden_size,
            dtype=dtype,
            device="cuda",
            requires_grad=True,
        )
    if qkv_format == "bshd":
        inp = torch.randn(
            config.batch_size,
            config.max_seqlen_q,
            config.hidden_size,
            dtype=dtype,
            device="cuda",
            requires_grad=True,
        )
        inp_enc = torch.randn(
            config.batch_size,
            config.max_seqlen_kv,
            config.hidden_size,
            dtype=dtype,
            device="cuda",
            requires_grad=True,
        )

    # Create seqlens
    if "padding" in config.attn_mask_type or qkv_format == "thd":
        if config.attn_type == "self":
            seqlens_q = torch.randint(
                1, config.max_seqlen_q, [config.batch_size], dtype=torch.int32, device="cuda"
            )
            seqlens_kv = seqlens_q
        if config.attn_type == "cross":
            if config.max_seqlen_q > 1:
                seqlens_q = torch.randint(
                    1, config.max_seqlen_q, [config.batch_size], dtype=torch.int32, device="cuda"
                )
            else:
                seqlens_q = torch.ones([config.batch_size], dtype=torch.int32, device="cuda")
            seqlens_kv = torch.randint(
                1, config.max_seqlen_kv, [config.batch_size], dtype=torch.int32, device="cuda"
            )
    else:
        seqlens_q = torch.full(
            [config.batch_size], config.max_seqlen_q, dtype=torch.int32, device="cuda"
        )
        seqlens_kv = torch.full(
            [config.batch_size], config.max_seqlen_kv, dtype=torch.int32, device="cuda"
        )
    cu_seqlens_q = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda")
    cu_seqlens_kv = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda")
    cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)
    cu_seqlens_kv[1:] = torch.cumsum(seqlens_kv, dim=0)
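    # cu_seqlens_* are exclusive prefix sums of the per-sequence lengths,
    # i.e. [0, s0, s0+s1, ...]; packed (thd) and padded layouts use these
    # offsets to locate each sequence.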
    if qkv_format == "thd":
        inp = torch.randn(
            cu_seqlens_q[-1],
            config.hidden_size,
            dtype=dtype,
            device="cuda",
            requires_grad=True,
        )
        inp_enc = torch.randn(
            cu_seqlens_kv[-1],
            config.hidden_size,
            dtype=dtype,
            device="cuda",
            requires_grad=True,
        )
1349
1350
1351
1352
1353
1354
1355

    sigma = 0.02
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    layer_number = 1
    drop_path_rate = 0.0
1356
    drop_path_rates = [rate.item() for rate in torch.linspace(0, drop_path_rate, config.num_layers)]
1357

1358
    # Create bias
1359
    bias = None
1360
1361
1362
1363
1364
1365
1366
1367
1368
    if config.attn_bias_type == "post_scale_bias":
        bias = torch.randn(
            1,
            config.num_heads,
            config.max_seqlen_q,
            config.max_seqlen_kv,
            dtype=dtype,
            device="cuda",
        )
1369
1370
1371
1372

    # Create RoPE
    rotary_pos_emb = None
    if RoPE:
1373
        PE = RotaryPositionEmbedding(dim=config.head_dim_qk)
1374
        rotary_pos_emb = PE(config.max_seqlen_q).to(device="cuda")
1375
1376

    # Set up model
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_heads,
        num_gqa_groups=config.num_gqa_groups,
        layernorm_epsilon=1e-5,
        hidden_dropout=0.0,
        attention_dropout=config.dropout_p,
        init_method=init_method,
        output_layer_init_method=output_layer_init_method,
        layer_number=layer_number,
1388
        kv_channels=config.head_dim_qk,
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
        self_attn_mask_type=config.attn_mask_type,
        tp_group=None,
        tp_size=1,
        params_dtype=dtype,
        get_rng_state_tracker=None,
        fuse_wgrad_accumulation=False,
        seq_length=config.max_seqlen_q,
        micro_batch_size=config.batch_size,
        sequence_parallel=False,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
1400
        layer_type="encoder" if config.attn_type == "self" else "decoder",
1401
1402
1403
1404
1405
1406
1407
1408
1409
        drop_path_rate=drop_path_rates[layer_number - 1],
        set_parallel_mode=True,
        fuse_qkv_params=fused_qkv_params,
        zero_centered_gamma=False,
        qkv_weight_interleaved=False,
        ub_tp_comm_overlap=False,
        bias=True,
        attn_input_format=qkv_format,
    ).to(dtype=dtype, device="cuda")
1410
1411
    if not is_training:
        block = block.eval()
1412

1413
1414
1415
    # Create ALiBi slopes
    alibi_slopes = None
    if config.attn_bias_type == "alibi" and config.alibi_type == "custom":
1416
        alibi_slopes = torch.randn(config.num_heads).abs().to(dtype=torch.float32, device="cuda")
1417

1418
    # Run a forward and backward pass
1419
1420
    out = block(
        inp,
1421
        self_attn_mask_type=config.attn_mask_type,
1422
1423
        encoder_output=inp_enc if config.attn_type == "cross" else None,
        enc_dec_attn_mask_type=config.attn_mask_type if config.attn_type == "cross" else None,
1424
1425
1426
        checkpoint_core_attention=False,
        rotary_pos_emb=rotary_pos_emb,
        core_attention_bias_type=config.attn_bias_type,
1427
        core_attention_bias=bias,
1428
        alibi_slopes=alibi_slopes,
1429
1430
1431
1432
        max_seqlen_q=config.max_seqlen_q,
        max_seqlen_kv=config.max_seqlen_kv,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_kv=cu_seqlens_kv,
1433
    )
1434
1435
1436
    if is_training:
        loss = out.sum()
        loss.backward()
1437
1438

    return out, inp.grad


model_configs_fp8_extra_state = {
    "large": ModelConfig(2, 128, 4, 128, num_layers=1),
}


@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper.")
@pytest.mark.skipif(get_cudnn_version() < (9, 3, 0), reason="cuDNN 9.3.0+ is required.")
@pytest.mark.parametrize("model", ["large"])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_sanity_attention_extra_state(model, dtype):
    config = model_configs_fp8_extra_state[model]
    # Test backend availability
    is_training = True
    available_backends, _, fused_attn_backends = get_available_attention_backends(
        config,
        qkv_dtype=torch.float8_e4m3fn,
        qkv_layout="sb3hd",
        is_training=is_training,
    )
    flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
    if not fused_attn_supported and not flash_attn_supported:
        pytest.skip("No attention backend available.")

    outputs = _run_attention_extra_state(dtype, config, checkpoint=False)
    outputs_checkpoint = _run_attention_extra_state(dtype, config, checkpoint=True)
    outputs_checkpoint_v1_6 = _run_attention_extra_state(
        dtype, config, mimic_v1_6=True, checkpoint=True
    )

    # Check that results match
    tols = dtype_tols(dtype)
    if dtype in (torch.float16, torch.bfloat16):
        tols.update(dict(rtol=2e-2, atol=2e-3))
    for i, (ref, test) in enumerate(zip(outputs, outputs_checkpoint)):
        torch.testing.assert_close(
            test,
            ref,
            **tols,
        )
    for i, (ref, test) in enumerate(zip(outputs, outputs_checkpoint_v1_6)):
        torch.testing.assert_close(
            test,
            ref,
            **tols,
        )


def _run_attention_extra_state(dtype, config, checkpoint=False, mimic_v1_6=False):
    steps = 10
    path = "checkpoint.pt"
    fp8_enabled = True
    fp8_recipe = recipe.DelayedScaling(
        margin=0,
        fp8_format=recipe.Format.HYBRID,
        amax_history_len=1,
        amax_compute_algo="most_recent",
        fp8_dpa=fp8_enabled,
        fp8_mha=False,
    )

    reset_rng_states()
    hidden_states = torch.randn(
        (config.max_seqlen_q, config.batch_size, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )

    def get_model(dtype, config):
        sigma = 0.023
        init_method = init_method_normal(sigma)
        output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

        with fp8_model_init(enabled=fp8_enabled, recipe=fp8_recipe):
            block = TransformerLayer(
                config.hidden_size,
                4 * config.hidden_size,
                config.num_heads,
                init_method=init_method,
                output_layer_init_method=output_layer_init_method,
                hidden_dropout=0.0,
                attention_dropout=0.0,
                fuse_qkv_params=True,
                params_dtype=dtype,
                device="cuda",
            )
        return block

    block = get_model(dtype, config)
    for i in range(steps // 2):
        with fp8_autocast(enabled=fp8_enabled, fp8_recipe=fp8_recipe):
            output = block(hidden_states, None)
            loss = output.sum()
            loss.backward()

    if checkpoint:
        sd = block.state_dict()
        if mimic_v1_6:
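            # TE v1.6-style checkpoints kept the attention extra state under the
            # "fused_attention" submodule; move the key so loading exercises the
            # remapping of older checkpoints.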
            sd["self_attention.core_attention.fused_attention._extra_state"] = sd[
                "self_attention.core_attention._extra_state"
            ]
            del sd["self_attention.core_attention._extra_state"]
        torch.save(sd, path)

        param_grads = []
        for p in block.parameters():
            if p.requires_grad:
                param_grads.append(p.grad.clone())

        _cpu_rng_state_new = torch.get_rng_state()
        _cuda_rng_state_new = torch.cuda.get_rng_state()

        del block
        block = get_model(dtype, config)
        block.load_state_dict(torch.load(path, weights_only=False))
        torch.set_rng_state(_cpu_rng_state_new)
        torch.cuda.set_rng_state(_cuda_rng_state_new)

        for p in block.parameters():
            if p.requires_grad:
                p.grad = param_grads.pop(0)

        assert not param_grads, "Not all saved parameter gradients were restored."

    for i in range((steps + 1) // 2):
        with fp8_autocast(enabled=fp8_enabled, fp8_recipe=fp8_recipe):
            output = block(hidden_states, None)
            loss = output.sum()
            loss.backward()

    torch.cuda.synchronize()

    if os.path.exists(path):
        os.remove(path)

    outputs = [output, hidden_states.grad]
    for p in block.parameters():
        if p.requires_grad:
            outputs.append(p.grad)

    return outputs


model_configs_fp8_vs_f16 = {
    #  test:             b,  h, hg,   d,   sq,  skv,   p,      mask,      bias
    "fp8_9": ModelConfig(2, 2048, 16, 128),
    "fp8_10": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12),
    "fp8_11": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4),
    "fp8_12": ModelConfig(2, 2048, 16, 128, attn_mask_type="causal"),
    "fp8_13": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12, attn_mask_type="causal"),
    "fp8_14": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="causal"),
    "fp8_15": ModelConfig(2, 2048, 16, 128, attn_mask_type="padding"),
    "fp8_16": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12, attn_mask_type="padding"),
    "fp8_17": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="padding"),
    "fp8_18": ModelConfig(2, 2048, 16, 128, attn_mask_type="padding_causal"),
    "fp8_19": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12, attn_mask_type="padding_causal"),
    "fp8_20": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="padding_causal"),
}

param_types_fp8_vs_f16 = [torch.float16, torch.bfloat16]
qkv_layout_fp8_vs_f16 = ["sbh3d", "bshd_bshd_bshd", "sbhd_sbhd_sbhd"]
qkv_format_fp8_vs_f16 = ["bshd", "sbhd"]


def _rmse(a, b):
    return math.sqrt((torch.pow((a - b), 2) / a.numel()).sum())
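# _error() treats the elementwise assert_close comparison as advisory (failures
# are only logged) and gates instead on RMSE normalized by the combined dynamic
# range of the two tensors, which is better suited to low-precision outputs.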


def _error(a, b, name_a, name_b, atol, rtol, rmse_tol):
    logging.debug(name_a + " min {:.6f} max {:.6f}".format(a.min().item(), a.max().item()))
    logging.debug(name_b + " min {:.6f} max {:.6f}".format(b.min().item(), b.max().item()))
    try:
        if a.dtype != b.dtype:
            a = a.to(b.dtype)
        torch.testing.assert_close(a, b, atol=atol, rtol=rtol)
    except Exception as e:
        logging.debug(e)

    rmse = _rmse(a, b)
    logging.debug(name_a + " vs " + name_b + " RMSE: {:.6f}".format(rmse))
    rmse_range = max(a.max().item(), b.max().item()) - min(a.min().item(), b.min().item())
    assert rmse < rmse_tol * rmse_range, (
        name_a
        + " vs "
        + name_b
        + " RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format(
            rmse, rmse_tol * rmse_range, rmse_tol, rmse_range
        )
    )


@pytest.mark.skipif(get_cudnn_version() < (9, 2, 1), reason="cuDNN 9.2.1+ is required.")
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper+.")
@pytest.mark.parametrize("dtype", param_types_fp8_vs_f16)
@pytest.mark.parametrize("model", model_configs_fp8_vs_f16.keys())
@pytest.mark.parametrize("qkv_format", qkv_format_fp8_vs_f16)
@pytest.mark.parametrize("input_layernorm", [True, False])
@pytest.mark.parametrize("fp8_dpa_bwd", [True, False])
@pytest.mark.parametrize("RoPE", [True, False])
@pytest.mark.parametrize("is_training", [True, False])
def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd, RoPE, is_training):
    os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1"
    os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0"
    config = model_configs_fp8_vs_f16[model]

    # Test backend availability
    available_backends, _, fused_attn_backends = get_available_attention_backends(
        config,
        qkv_dtype=torch.float8_e4m3fn,
        qkv_layout=qkv_format.replace("hd", "h3d"),
        is_training=is_training,
    )
    flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
    # Skip if only unfused backend is supported
    if (len(fused_attn_backends) + flash_attn_supported + unfused_attn_supported) < 2:
        pytest.skip("Less than two backends to compare.")
    if not fp8_dpa_bwd:
        available_backends, _, fused_attn_backends = get_available_attention_backends(
            config,
            qkv_dtype=dtype,
            qkv_layout=qkv_format.replace("hd", "h3d"),
            is_training=is_training,
        )
        _, fused_attn_supported, _ = available_backends
        if not fused_attn_supported:
            pytest.skip("No attention backend available.")

    if flash_attn_supported:
        os.environ["NVTE_FLASH_ATTN"] = "1"
        os.environ["NVTE_FUSED_ATTN"] = "0"
        _attention_backends["backend_selection_requires_update"] = True
        logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = True")
        flash_attn_fwd_fp8, param_names, flash_attn_bwd_fp8 = _run_mha_fp8_vs_f16(
            dtype, config, True, qkv_format, input_layernorm, RoPE, is_training
        )

    os.environ["NVTE_FLASH_ATTN"] = "0"
    os.environ["NVTE_FUSED_ATTN"] = "1"
    _attention_backends["backend_selection_requires_update"] = True
    logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = True")
    fused_attn_fwd_fp8, param_names, fused_attn_bwd_fp8 = _run_mha_fp8_vs_f16(
        dtype, config, True, qkv_format, input_layernorm, RoPE, is_training
    )

    logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = False")
    fused_attn_fwd_f16, param_names, fused_attn_bwd_f16 = _run_mha_fp8_vs_f16(
        dtype, config, False, qkv_format, input_layernorm, RoPE, is_training
    )

    atol = 5e-1
    rtol = 5e-1
    rmse_tol = 0.15
    logging.debug("========== {:^25s} ==========".format("forward output"))
    if flash_attn_supported:
        _error(
            flash_attn_fwd_fp8,
            fused_attn_fwd_f16,
            "flash_attn_fwd_fp8",
            "fused_attn_fwd_f16",
            atol,
            rtol,
            rmse_tol,
        )
    _error(
        fused_attn_fwd_fp8,
        fused_attn_fwd_f16,
        "fused_attn_fwd_fp8",
        "fused_attn_fwd_f16",
        atol,
        rtol,
        rmse_tol,
    )

    if is_training:
        for i in range(len(param_names[:1])):
            logging.debug("========== {:^25s} ==========".format(param_names[i]))
            _error(
                fused_attn_bwd_fp8[i],
                fused_attn_bwd_f16[i],
                f"fused_attn_bwd_fp8[{i}]",
                f"fused_attn_bwd_f16[{i}]",
                atol,
                rtol,
                rmse_tol,
            )


def _run_mha_fp8_vs_f16(dtype, config, fp8_mha, qkv_format, input_layernorm, RoPE, is_training):
    reset_rng_states()
    _DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
    _DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed)

    def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
        """Get cuda rng tracker."""
        return _DUMMY_CUDA_RNG_STATE_TRACKER

    fp8_recipe = recipe.DelayedScaling(
        margin=0,
        fp8_format=recipe.Format.HYBRID,
        amax_history_len=1,
        amax_compute_algo="most_recent",
        fp8_dpa=fp8_mha,
        fp8_mha=fp8_mha,
    )

    with fp8_model_init(enabled=fp8_mha, recipe=fp8_recipe):
        rotary_pos_emb = None
        if RoPE:
            PE = RotaryPositionEmbedding(dim=config.head_dim_qk)
            rotary_pos_emb = PE(config.max_seqlen_q).to(device="cuda")
        mha = MultiheadAttention(
            hidden_size=config.hidden_size,
            num_attention_heads=config.num_heads,
            kv_channels=config.head_dim_qk,
            num_gqa_groups=config.num_gqa_groups,
            attention_dropout=config.dropout_p,
            layer_number=1,
            bias=True,
            get_rng_state_tracker=get_dummy_cuda_rng_tracker,
            params_dtype=dtype,
            input_layernorm=input_layernorm,
            fuse_qkv_params=True,
            attention_type="self",
            qkv_weight_interleaved=True,
            qkv_format=qkv_format,
        ).to(dtype=dtype, device="cuda")
        if not is_training:
            mha = mha.eval()

    if "padding" in config.attn_mask_type or qkv_format == "thd":
        if config.attn_type == "self":
            seqlens_q = torch.randint(
                1, config.max_seqlen_q, [config.batch_size], dtype=torch.int32, device="cuda"
            )
            seqlens_kv = seqlens_q
        if config.attn_type == "cross":
            seqlens_q = torch.randint(
                1, config.max_seqlen_q, [config.batch_size], dtype=torch.int32, device="cuda"
            )
            seqlens_kv = torch.randint(
                1, config.max_seqlen_kv, [config.batch_size], dtype=torch.int32, device="cuda"
            )
    else:
        seqlens_q = torch.full(
            [config.batch_size], config.max_seqlen_q, dtype=torch.int32, device="cuda"
        )
        seqlens_kv = torch.full(
            [config.batch_size], config.max_seqlen_kv, dtype=torch.int32, device="cuda"
        )
    cu_seqlens_q = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda")
    cu_seqlens_kv = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda")
    cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)
    cu_seqlens_kv[1:] = torch.cumsum(seqlens_kv, dim=0)

    dim_to_num = {
        "b": config.batch_size,
        "sq": config.max_seqlen_q,
        "skv": config.max_seqlen_kv,
        "h": config.num_heads,
        "hg": config.num_gqa_groups,
        "d": config.head_dim_qk,
        "t": cu_seqlens_q[-1],
        "tg": cu_seqlens_kv[-1],
        "3": 3,
        "2": 2,
        "1": 1,
    }
    layout = "_".join(qkv_format)
    layout = layout.replace("s", "sq")
    tensor_shape = [dim_to_num[j] for j in layout.split("_")]
    tensor = 0.01 * torch.randint(-100, 100, tensor_shape, dtype=dtype, device="cuda")
    hidden_states = tensor.view(*tensor.shape[:-2], -1)
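    # e.g. qkv_format "bshd" expands to layout "b_sq_h_d" and tensor_shape
    # [batch_size, max_seqlen_q, num_heads, head_dim_qk]; the trailing (h, d)
    # dims are then flattened into the hidden dimension.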
    if is_training:
        hidden_states.requires_grad = True

    tensor = 0.01 * torch.randn(tensor_shape, dtype=dtype, device="cuda")
    out_grad = tensor.view(*tensor.shape[:-2], -1)

    with fp8_autocast(enabled=fp8_mha, fp8_recipe=fp8_recipe):
        out = mha(
            hidden_states,
            attn_mask_type=config.attn_mask_type,
            checkpoint_core_attention=False,
            core_attention_bias_type=config.attn_bias_type,
            is_first_microbatch=None,
            rotary_pos_emb=rotary_pos_emb,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_kv=cu_seqlens_kv,
        )
    if is_training:
        out.backward(out_grad)

    param_names = []
    param_names.append("hidden_states.grad")
    params = []
    params.append(hidden_states)
    for name, param in mha.named_parameters():
        if param.requires_grad:
            param_names.append(name + ".grad")
            params.append(param)

    if is_training:
        return out, param_names, tuple(x.grad for x in params)
    return out, param_names, tuple(None for x in params)


@pytest.mark.skipif(get_cudnn_version() < (9, 2, 1), reason="cuDNN 9.2.1+ is required.")
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper+.")
@pytest.mark.parametrize("dtype", param_types_fp8_vs_f16)
@pytest.mark.parametrize("model", model_configs_fp8_vs_f16.keys())
@pytest.mark.parametrize("qkv_layout", qkv_layout_fp8_vs_f16)
@pytest.mark.parametrize("fp8_dpa_bwd", [True, False])
@pytest.mark.parametrize("is_training", [True, False])
def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training):
    config = model_configs_fp8_vs_f16[model]

    # TODO(cyang): think of another way to verify dropout results
    # test cuDNN FP8 dropout
    # 1. we modify the config here to not affect mha_fp8_vs_f16 tests
    # 2. there is no other backend that implements dropout the same way as cuDNN FP8, and as an
    #    indirect verification method, we create Q/K/V as all 1s and check if O is all 1s
    # 3. we avoid running FP16/BF16 kernels as they do not have dropout support on Blackwell
    # if "padding" not in config.attn_mask_type and "causal" not in config.attn_mask_type:
    #    if get_device_compute_capability() >= (10, 0):
    #        config.dropout_p = 0.1

    os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0"
    os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1"

    # Test backend availability
    available_backends, _, fused_attn_backends = get_available_attention_backends(
        config,
        qkv_dtype=torch.float8_e4m3fn,
        qkv_layout=qkv_layout,
        is_training=is_training,
    )
    flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
    # Skip if no FP8-capable backend is available
    if flash_attn_supported + fused_attn_supported < 1:
        pytest.skip("No FP8 attention backend available.")
    if not fp8_dpa_bwd:
        available_backends, _, fused_attn_backends = get_available_attention_backends(
            config,
            qkv_dtype=dtype,
            qkv_layout=qkv_layout,
            is_training=is_training,
        )
        _, fused_attn_supported, _ = available_backends
        if not fused_attn_supported:
            pytest.skip("No attention backend available.")
    if config.num_heads != config.num_gqa_groups and "3" in qkv_layout:
        pytest.skip("qkv_layout not applicable for MQA/GQA")

    if flash_attn_supported:
        os.environ["NVTE_FLASH_ATTN"] = "1"
        os.environ["NVTE_FUSED_ATTN"] = "0"
        _attention_backends["backend_selection_requires_update"] = True
        logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True")
        flash_attn_fwd_fp8, flash_attn_bwd_fp8 = _run_dpa_fp8_vs_f16(
            dtype, config, True, qkv_layout, is_training
        )

    os.environ["NVTE_FLASH_ATTN"] = "0"
    os.environ["NVTE_FUSED_ATTN"] = "1"
    _attention_backends["backend_selection_requires_update"] = True
    logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True")
    fused_attn_fwd_fp8, fused_attn_bwd_fp8 = _run_dpa_fp8_vs_f16(
        dtype, config, True, qkv_layout, is_training
    )

    if config.dropout_p == 0.0:
        # test cuDNN FP8 dropout: need a FP16/BF16 reference on Blackwell
        logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = False")
        fused_attn_fwd_f16, fused_attn_bwd_f16 = _run_dpa_fp8_vs_f16(
            dtype, config, False, qkv_layout, is_training
        )

    atol = 5e-1
    rtol = 5e-2
    rmse_tol = 0.11
    bwd_names = ["dq", "dk", "dv"]
    logging.debug("========== {:^25s} ==========".format("forward output"))
    if flash_attn_supported:
        _error(
            flash_attn_fwd_fp8,
            fused_attn_fwd_f16,
            "flash_attn_fwd_fp8",
            "fused_attn_fwd_f16",
            atol,
            rtol,
            rmse_tol,
        )
    if config.dropout_p != 0.0:
        # test cuDNN FP8 dropout
        assert torch.all(
            fused_attn_fwd_fp8 == 1
        ), "fused_attn_fwd_fp8 must be all 1s when Q/K/V are all 1s."
    else:
        _error(
            fused_attn_fwd_fp8,
            fused_attn_fwd_f16,
            "fused_attn_fwd_fp8",
            "fused_attn_fwd_f16",
            atol,
            rtol,
            rmse_tol,
        )
        if is_training:
            for i, _ in enumerate(fused_attn_bwd_f16):
                logging.debug("========== {:^25s} ==========".format(bwd_names[i]))
                _error(
                    fused_attn_bwd_fp8[i],
                    fused_attn_bwd_f16[i],
                    f"fused_attn_bwd_fp8[{i}]",
                    f"fused_attn_bwd_f16[{i}]",
                    atol,
                    rtol,
                    rmse_tol,
                )


def _run_dpa_fp8_vs_f16(dtype, config, fp8_dpa, qkv_layout, is_training):
    reset_rng_states()
    _DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
    _DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed)

    def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
        """Get cuda rng tracker."""
        return _DUMMY_CUDA_RNG_STATE_TRACKER

    fp8_recipe = recipe.DelayedScaling(
        margin=0,
        fp8_format=recipe.Format.HYBRID,
        amax_history_len=1,
        amax_compute_algo="most_recent",
        fp8_dpa=fp8_dpa,
    )
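    # Derive the memory format from the first tensor's layout by stripping the
    # packing digits, e.g. "sbh3d" -> "sbhd".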

    qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()])
    with fp8_model_init(enabled=fp8_dpa):
        dpa = DotProductAttention(
            config.num_heads,
            config.head_dim_qk,
            num_gqa_groups=config.num_gqa_groups,
            attention_dropout=config.dropout_p,
            sequence_parallel=False,
            tp_size=1,
            get_rng_state_tracker=get_dummy_cuda_rng_tracker,
            tp_group=None,
            layer_number=1,
            attention_type="self",
            qkv_format=qkv_format,
        ).to(dtype=dtype, device="cuda")
        if not is_training:
            dpa = dpa.eval()

    if "padding" in config.attn_mask_type or qkv_format == "thd":
        if config.attn_type == "self":
            seqlens_q = torch.randint(
                1, config.max_seqlen_q, [config.batch_size], dtype=torch.int32, device="cuda"
            )
            seqlens_kv = seqlens_q
        if config.attn_type == "cross":
            seqlens_q = torch.randint(
                1, config.max_seqlen_q, [config.batch_size], dtype=torch.int32, device="cuda"
            )
            seqlens_kv = torch.randint(
                1, config.max_seqlen_kv, [config.batch_size], dtype=torch.int32, device="cuda"
            )
    else:
        seqlens_q = torch.full(
            [config.batch_size], config.max_seqlen_q, dtype=torch.int32, device="cuda"
        )
        seqlens_kv = torch.full(
            [config.batch_size], config.max_seqlen_kv, dtype=torch.int32, device="cuda"
        )
    cu_seqlens_q = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda")
    cu_seqlens_kv = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda")
    cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)
    cu_seqlens_kv[1:] = torch.cumsum(seqlens_kv, dim=0)

    dim_to_num = {
        "b": config.batch_size,
        "sq": config.max_seqlen_q,
        "skv": config.max_seqlen_kv,
        "h": config.num_heads,
        "hg": config.num_gqa_groups,
        "d": config.head_dim_qk,
        "t": cu_seqlens_q[-1],
        "tg": cu_seqlens_kv[-1],
        "3": 3,
        "2": 2,
        "1": 1,
    }
    inp = []
    for i, layout in enumerate(qkv_layout.split("_")):
        layout = "_".join(layout)
        if i == 0:
            layout = layout.replace("s", "sq")
        else:
            layout = layout.replace("s", "skv")
            layout = layout.replace("h", "hg")
            layout = layout.replace("t", "tg")
        tensor_shape = [dim_to_num[j] for j in layout.split("_")]
        if config.dropout_p == 0.0:
            tensor = torch.randn(tensor_shape, dtype=dtype, device="cuda")
        else:
            # test cuDNN FP8 dropout
            tensor = torch.ones(tensor_shape, dtype=dtype, device="cuda")
        tensor_count = 1
        split_dim = 0
        for dim, l in enumerate(layout.split("_")):
            if l.isdigit():
                tensor_count = int(l)
                split_dim = dim
                break
        tensors = torch.split(tensor, 1, dim=split_dim) if split_dim != 0 else [tensor]
        for j in range(tensor_count):
            if split_dim != 0:
                inp.append(tensors[j].squeeze(split_dim))
            else:
                inp.append(tensors[j])
    for i in range(3):
        inp[i].requires_grad = True
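    # For a packed layout such as "sbh3d", the single tensor above is split along
    # its "3" dim into separate Q/K/V views; for "bshd_bshd_bshd", each group is
    # already its own tensor.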

    qkv_format_kv = "_".join(qkv_format)
    qkv_format_kv = qkv_format_kv.replace("s", "sq")
    out_grad_shape = [dim_to_num[i] for i in qkv_format_kv.split("_")]
    out_grad_shape_new = [*out_grad_shape[:-2], out_grad_shape[-2] * out_grad_shape[-1]]
    out_grad = torch.randn(out_grad_shape_new, dtype=dtype, device="cuda")

    with fp8_autocast(enabled=fp8_dpa, fp8_recipe=fp8_recipe):
        out = dpa(
            inp[0],
            inp[1],
            inp[2],
            qkv_format=qkv_format,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_kv=cu_seqlens_kv,
            max_seqlen_q=config.max_seqlen_q,
            max_seqlen_kv=config.max_seqlen_kv,
            attn_mask_type=config.attn_mask_type,
            checkpoint_core_attention=False,
            core_attention_bias_type=config.attn_bias_type,
        )
    if is_training:
        out.backward(out_grad)

    if is_training:
        return out, (inp[0].grad, inp[1].grad, inp[2].grad)
    return out, (None, None, None)


model_configs_fp8 = {
    #  test:             b,  h, hg,   d,   sq,  skv,   p,      mask,      bias
    "fp8_1": ModelConfig(1, 512, 1, 64),
    "fp8_2": ModelConfig(4, 512, 16, 64),
    "fp8_3": ModelConfig(1, 2048, 1, 128),
    "fp8_4": ModelConfig(2, 2048, 24, 128),
    "fp8_5": ModelConfig(1, 512, 1, 64, attn_mask_type="causal"),
    "fp8_6": ModelConfig(4, 512, 16, 64, attn_mask_type="causal"),
    "fp8_7": ModelConfig(1, 2048, 1, 128, attn_mask_type="causal"),
    "fp8_8": ModelConfig(2, 2048, 24, 128, attn_mask_type="causal"),
}
param_types_fp8 = [torch.float16, torch.bfloat16]
cudnn_frontend_version = int(os.getenv("NVTE_FUSED_ATTN_FE_VER", "1"))
models_v0 = ["fp8_1", "fp8_2", "fp8_5", "fp8_6"]
models_v1 = ["fp8_3", "fp8_4", "fp8_7", "fp8_8"]


@pytest.mark.skipif(
    (
        get_cudnn_version() < (8, 9, 3)
        if cudnn_frontend_version == 0
        else get_cudnn_version() < (9, 2, 1)
    ),
    reason=f"""cuDNN {"8.9.3" if cudnn_frontend_version == 0 else "9.2.1"}+ is required.""",
)
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper+.")
@pytest.mark.parametrize("dtype", param_types_fp8)
@pytest.mark.parametrize("model", models_v1 if cudnn_frontend_version == 1 else models_v0)
def test_custom_mha_fp8_vs_f16(dtype, model):
    """Test FP8 dot product attention implementations based on cuDNN frontend
    v0.9 and v1.0+. Each test compares results from a custom implementation of
    an FP8 MHA module, i.e. Custom_MHA_FP8(), to results from an F16 reference
    built on transformer_engine.pytorch.attention.DotProductAttention.
    Both paths take F16 input and output. QKV layout is t3hd or bs3hd."""

    config = model_configs_fp8[model]

    # Test backend availability
    is_training = True
    available_backends, _, fused_attn_backends = get_available_attention_backends(
        config,
        qkv_dtype=torch.float8_e4m3fn,
        qkv_layout="t3hd" if cudnn_frontend_version == 0 else "bs3hd",
        is_training=is_training,
    )
    flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
    if not (fused_attn_backends and unfused_attn_supported):
        pytest.skip("Not enough backends to run this test with.")

    fused_attn_fwd_fp8, fused_attn_bwd_fp8 = _run_custom_mha_fp8(dtype, config, "FusedAttention")
    unfused_attn_fwd_f16, unfused_attn_bwd_f16 = _run_ref_mha_f16(dtype, config, "UnfusedAttention")

    atol = 5e-1
    rtol = 5e-1
    rmse_tol = 0.13
    _error(
        fused_attn_fwd_fp8,
        unfused_attn_fwd_f16,
        "fused_attn_fwd_fp8",
        "unfused_attn_fwd_f16",
        atol,
        rtol,
        rmse_tol,
    )
    _error(
        fused_attn_bwd_fp8,
        unfused_attn_bwd_f16,
        "fused_attn_bwd_fp8",
        "unfused_attn_bwd_f16",
        atol,
        rtol,
        rmse_tol,
    )


def _run_custom_mha_fp8(dtype, config, backend):
    """Run Custom_MHA_FP8 with FP8 FusedAttention backend. Both input and output
    are in F16. QKV GEMM, DPA, and projection GEMM are calculated in FP8."""
    reset_rng_states()
    os.environ["NVTE_FLASH_ATTN"] = "0"
    os.environ["NVTE_FUSED_ATTN"] = "0"
    if backend == "FlashAttention":
        os.environ["NVTE_FLASH_ATTN"] = "1"
    if backend == "FusedAttention":
        os.environ["NVTE_FUSED_ATTN"] = "1"
    _attention_backends["backend_selection_requires_update"] = True

    inp = 0.0001 * torch.randint(
        -100,
        100,
        (config.batch_size * config.max_seqlen_q, config.num_heads * config.head_dim_qk),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )
    seqlens = torch.full([config.batch_size], config.max_seqlen_q, dtype=torch.int32, device="cuda")
    cu_seqlens = torch.zeros(config.batch_size + 1, device="cuda", dtype=torch.int32)
    cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)

    out_grad = 0.01 * torch.randn(
        config.batch_size * config.max_seqlen_q,
        config.num_heads * config.head_dim_qk,
        dtype=dtype,
        device="cuda",
    )
    torch.save(out_grad, "out_grad.pt")

    fp8_recipe = recipe.DelayedScaling(
        margin=0,
        fp8_format=recipe.Format.HYBRID,
        amax_history_len=1,
        amax_compute_algo="most_recent",
    )

    mha = Custom_MHA_FP8(config).to(dtype=dtype, device="cuda")
    with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
        out = mha(inp, cu_seqlens, config.max_seqlen_q)
    out.backward(out_grad)

    out = torch.load("out.pt")
    dqkv = torch.load("dqkv.pt")
    return (
        out.view(config.batch_size, config.max_seqlen_q, -1),
        dqkv.view(
            config.batch_size, config.max_seqlen_q, 3, config.num_heads, config.head_dim_qk
        ).contiguous(),
    )
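# Custom_MHA_FP8 saves its dequantized QKV, output, and dQKV to qkv.pt, out.pt,
# and dqkv.pt; _run_custom_mha_fp8() and _run_ref_mha_f16() exchange tensors
# through these files so both paths see identical inputs and output gradients.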


def _run_ref_mha_f16(dtype, config, backend):
    """Run the reference attention path in F16, using DotProductAttention on the
    QKV produced (and dequantized) by the FP8 run. Both input and output are in
    F16."""

    os.environ["NVTE_FLASH_ATTN"] = "0"
    os.environ["NVTE_FUSED_ATTN"] = "0"
    if backend == "FlashAttention":
        os.environ["NVTE_FLASH_ATTN"] = "1"
    if backend == "FusedAttention":
        os.environ["NVTE_FUSED_ATTN"] = "1"
    _attention_backends["backend_selection_requires_update"] = True

    inp = torch.load("qkv.pt").to(device="cuda")
    inp.requires_grad = True
    seqlens = torch.full([config.batch_size], config.max_seqlen_q, dtype=torch.int32, device="cuda")
    cu_seqlens = torch.zeros(config.batch_size + 1, device="cuda", dtype=torch.int32)
    cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)
    out_grad = (
        torch.load("out_grad.pt").to(device="cuda").view(config.batch_size, config.max_seqlen_q, -1)
    )

    _DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
    _DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed)

    def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
        """Get cuda rng tracker."""
        return _DUMMY_CUDA_RNG_STATE_TRACKER

    block = DotProductAttention(
        config.num_heads,
        config.head_dim_qk,
        attention_dropout=config.dropout_p,
        sequence_parallel=False,
        tp_size=1,
        get_rng_state_tracker=get_dummy_cuda_rng_tracker,
        tp_group=None,
        layer_number=1,
        attention_type="self",
        qkv_format="bshd",
    ).to(dtype=dtype, device="cuda")

    q = inp[:, :, 0, :, :]
    k = inp[:, :, 1, :, :]
    v = inp[:, :, 2, :, :]
    out = block(q, k, v, attn_mask_type=config.attn_mask_type)
    out.backward(out_grad)

    return out, inp.grad


_CUBLASLT_WORKSPACE_SIZE_BYTES = 33_554_432  # 32MiB
_2X_ACC_FPROP = False
_2X_ACC_DGRAD = False
_2X_ACC_WGRAD = False

META_QKV = tex.FP8FwdTensors.GEMM1_OUTPUT
META_DQKV = tex.FP8BwdTensors.GRAD_OUTPUT1
META_O = tex.FP8FwdTensors.GEMM2_INPUT
META_DO = tex.FP8BwdTensors.GRAD_INPUT2
META_S = tex.FP8FwdTensors.GEMM3_OUTPUT
META_DP = tex.FP8BwdTensors.GRAD_INPUT3
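# Aliases for the FP8 meta-tensor slots used by the custom FP8 MHA below:
# META_QKV/META_O/META_S index forward tensors and META_DQKV/META_DO/META_DP
# index their backward counterparts.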


class _custom_mha_fp8(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        inp: torch.Tensor,
        qkv_weight: torch.Tensor,
        qkv_bias: torch.Tensor,
        cu_seqlens: torch.Tensor,
        num_heads: int,
        p_dropout: float,
        max_s: int,
        fast_zero_fill: bool,
        fp8_meta: Dict[str, Any],
        workspace: torch.Tensor,
        is_training: bool,
        mask_type: str,
        quantizers: list[Quantizer],
    ) -> torch.Tensor:
        qkv_dtype = inp.dtype

        assert inp.dim() == 2
        in_features = qkv_weight.shape[-1]
        h = num_heads
        d = in_features // h
        b = cu_seqlens.numel() - 1

        input_quantizer = quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT]
        qkv_quantizer = quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_INPUT]
        qkv_weight_quantizer = quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT]
        o_quantizer = quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_OUTPUT]
        dO_quantizer = quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT1]
        dQKV_quantizer = quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_INPUT1]
        s_quantizer = quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT2]
        dP_quantizer = quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT3]

        inp_fp8 = input_quantizer(inp)

        qkv_weight_fp8 = qkv_weight_quantizer(qkv_weight)

        qkv, *_ = ext.general_gemm(
            qkv_weight_fp8,
            inp_fp8,
            workspace,
            bias=qkv_bias,
            out_dtype=qkv_weight_fp8.dtype,
            quantization_params=qkv_quantizer,
            use_split_accumulator=_2X_ACC_FPROP,
        )
        qkv = qkv.view(-1, 3, h, d)
        qkv_fp16 = qkv.dequantize().view(b, max_s, 3, h, d).contiguous()
        torch.save(qkv_fp16, "qkv.pt")
        if cudnn_frontend_version == 1:
            qkv = qkv.view(b, max_s, 3, h, d)  # bs3hd

        # FMHA
        q_data = qkv._data[:, :, 0, :, :] if cudnn_frontend_version == 1 else qkv._data[:, 0, :, :]
        k_data = qkv._data[:, :, 1, :, :] if cudnn_frontend_version == 1 else qkv._data[:, 1, :, :]
        v_data = qkv._data[:, :, 2, :, :] if cudnn_frontend_version == 1 else qkv._data[:, 2, :, :]
        q = qkv.make_like(tensor=qkv, data=q_data, shape=q_data.shape)
        k = qkv.make_like(tensor=qkv, data=k_data, shape=k_data.shape)
        v = qkv.make_like(tensor=qkv, data=v_data, shape=v_data.shape)
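        # q, k, and v are FP8 tensors viewing slices of the packed QKV storage;
        # make_like reuses qkv's quantization metadata (scale/amax) for each view.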

        out, aux_ctx_tensors = fused_attn_fwd(
            is_training,
            max_s,
            max_s,
            cu_seqlens,
            cu_seqlens,
            q,
            k,
            v,
            qkv_dtype,
            FusedAttnBackend["FP8"],
            attn_scale=None,
            dropout=p_dropout,
            fast_zero_fill=fast_zero_fill,
            qkv_layout="bs3hd" if cudnn_frontend_version == 1 else "t3hd",
            attn_bias_type="no_bias",
            attn_mask_type=mask_type if cudnn_frontend_version == 1 else "padding",
            rng_gen=None,
            o_quantizer=o_quantizer,
            s_quantizer=s_quantizer,
        )

        tensors_to_save, tensor_objects = prepare_for_saving(
            q, k, v, inp_fp8, qkv_weight_fp8, workspace, out
        )

        ctx.save_for_backward(*tensors_to_save)
        ctx.tensor_objects = tensor_objects
        ctx.aux_ctx_tensors = aux_ctx_tensors
        ctx.qkv_dtype = qkv_dtype
        ctx.fp8_meta = fp8_meta
        ctx.cu_seqlens = cu_seqlens
        ctx.p_dropout = p_dropout
        ctx.max_s = max_s
        ctx.fast_zero_fill = fast_zero_fill
        ctx.hidden_size = in_features
        ctx.num_heads = num_heads
        ctx.mask_type = mask_type
        ctx.dtype = inp.dtype

        ctx.dQKV_quantizer = dQKV_quantizer
        ctx.dO_quantizer = dO_quantizer
        ctx.dP_quantizer = dP_quantizer
        ctx.S_quantizer = s_quantizer

        out = out.view(-1, in_features)  # (bs)(hd)
        out_fp16 = out.dequantize()
        torch.save(out_fp16, "out.pt")  # (bs)(hd)
        return out_fp16

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
        with torch.cuda.nvtx.range("_DPA"):
            saved_tensors = ctx.saved_tensors
            (q, k, v, inp_fp8, qkv_weight_fp8, workspace, out) = restore_from_saved(
                ctx.tensor_objects, saved_tensors
            )

            proj_dgrad = ctx.dO_quantizer(grad_output)
            fp8_dtype_backward = fp8.get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=False)

            dq, dk, dv, *rest = fused_attn_bwd(
                ctx.max_s,
                ctx.max_s,
                ctx.cu_seqlens,
                ctx.cu_seqlens,
                q,
                k,
                v,
                out,
                proj_dgrad.view_as(out),
                ctx.qkv_dtype,
                fp8_dtype_backward,
                ctx.aux_ctx_tensors,
                FusedAttnBackend["FP8"],
                None,
                None,
                ctx.S_quantizer,
                ctx.dP_quantizer,
                ctx.dQKV_quantizer,
                attn_scale=None,
                dropout=ctx.p_dropout,
                fast_zero_fill=ctx.fast_zero_fill,
                qkv_layout="bs3hd" if cudnn_frontend_version == 1 else "t3hd",
                attn_bias_type="no_bias",
                attn_mask_type=ctx.mask_type if cudnn_frontend_version == 1 else "padding",
            )
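            # dq/dk/dv returned above share one interleaved allocation; the stride
            # surgery below reinterprets dq's underlying storage as the packed
            # (..., 3, h, d) dQKV tensor without copying.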
            dim = 2 if cudnn_frontend_version == 1 else 1
            dqkv = torch.Tensor().to(device=dq._data.device, dtype=dq._data.dtype)
            dqkv_shape = list(dq._data.shape)
            dqkv_shape.insert(dim, 3)
            dqkv_stride = list(dq._data.stride())
            dqkv_stride.insert(dim, int(dqkv_stride[-3] / 3))
            dqkv.set_(
                dq._data.untyped_storage(), dq._data.storage_offset(), dqkv_shape, dqkv_stride
            )  # bs3hd

            dqkv_c = dqkv.view(-1, 3 * ctx.hidden_size)
            dqkv_c = dq.make_like(tensor=dq, data=dqkv_c, shape=dqkv_c.shape)
            dqkv_c_fp16 = dqkv_c.dequantize()
            torch.save(dqkv_c_fp16, "dqkv.pt")

            qkv_bgrad, dqkv = ext.bgrad_quantize(dqkv_c_fp16, ctx.dQKV_quantizer)
            dqkv_c._transpose = None
            dqkv_c._create_transpose()

            # QKV DGRAD
            qkv_dgrad, *_ = ext.general_gemm(
                qkv_weight_fp8,
                dqkv_c,
                workspace,
                ctx.dtype,
                use_split_accumulator=_2X_ACC_DGRAD,
                layout="NN",
            )

            # QKV WGRAD
            qkv_wgrad, *_ = ext.general_gemm(
                inp_fp8,
                dqkv,
                workspace,
                ctx.dtype,
                use_split_accumulator=_2X_ACC_WGRAD,
                layout="NT",
            )

        return (
            qkv_dgrad,
            qkv_wgrad,
            qkv_bgrad,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
        )


class Custom_MHA_FP8(TransformerEngineBaseModule):
    def __init__(self, config, params_dtype: torch.dtype = torch.float32):
        super().__init__()
        self.p_dropout = config.dropout_p
        self.h = config.num_heads
        self.hidden_size = config.hidden_size
        self.head_dim = config.head_dim_qk
        self.fast_zero_fill = True
        self.mask_type = config.attn_mask_type

        self.qkv_weight = torch.nn.Parameter(
            torch.empty(
                self.hidden_size * 3,
                self.hidden_size,
                device=torch.cuda.current_device(),
                dtype=params_dtype,
            )
        )
        self.qkv_bias = torch.nn.Parameter(
            torch.empty(
                self.hidden_size * 3,
                device=torch.cuda.current_device(),
                dtype=params_dtype,
            )
        )
        with torch.no_grad():
            self.qkv_bias.zero_()
            self.qkv_weight.fill_(1.0)
        self.workspace = torch.empty(
            _CUBLASLT_WORKSPACE_SIZE_BYTES, dtype=torch.int8, device="cuda"
        )

    def forward(
        self,
        inp: torch.Tensor,
        cu_seqlens,
        max_s,
    ) -> torch.Tensor:
        with self.prepare_forward(inp, num_gemms=3) as inp:
            out = _custom_mha_fp8.apply(
                inp,
                self.qkv_weight,
                self.qkv_bias,
                cu_seqlens,
                self.h,
                self.p_dropout,
                max_s,
                self.fast_zero_fill,
                self.fp8_meta,
                self.workspace,
                self.training,
                self.mask_type,
                self.quantizers,
            )
        return out