test_attention.py 103 KB
Newer Older
1
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
#
# See LICENSE for license information.
4
import logging
Tim Moon's avatar
Tim Moon committed
5
import os
6
7
import sys
import pathlib
8
from typing import Any, Dict, Tuple, Union
Tim Moon's avatar
Tim Moon committed
9

10
import pytest
Tim Moon's avatar
Tim Moon committed
11
import torch
12

13
from transformer_engine.pytorch.quantization import FP8GlobalStateManager, get_fp8_te_dtype
Tim Moon's avatar
Tim Moon committed
14
from transformer_engine.common import recipe
15
16
17
18
from transformer_engine.pytorch import (
    TransformerLayer,
    autocast,
    quantized_model_init,
Tim Moon's avatar
Tim Moon committed
19
    DotProductAttention,
20
21
22
23
24
25
26
    MultiheadAttention,
    get_device_compute_capability,
    Quantizer,
    is_fp8_available,
    is_bf16_available,
)
from transformer_engine.pytorch.attention.dot_product_attention import (
27
28
    _attention_backends,
)
29
from transformer_engine.pytorch.attention.dot_product_attention.utils import (
30
    FlashAttentionUtils,
31
    check_set_window_size,
Tim Moon's avatar
Tim Moon committed
32
)
33
from transformer_engine.pytorch.attention import RotaryPositionEmbedding
Tim Moon's avatar
Tim Moon committed
34
35
36
37
38
39
import transformer_engine.pytorch.cpp_extensions as ext
from transformer_engine.pytorch.cpp_extensions.fused_attn import (
    FusedAttnBackend,
    fused_attn_bwd,
    fused_attn_fwd,
)
40
41
from transformer_engine.pytorch.distributed import CudaRNGStatesTracker
from transformer_engine.pytorch.module.base import TransformerEngineBaseModule
42
43
44
45
from transformer_engine.pytorch.utils import (
    init_method_normal,
    scaled_init_method_normal,
)
46
from transformer_engine.pytorch.utils import get_cudnn_version
47
import transformer_engine_torch as tex
48
49
from transformer_engine.pytorch.quantized_tensor import (
    Quantizer,
50
51
52
    prepare_for_saving,
    restore_from_saved,
)
53

54
55
56
57
# Append the directory two levels above this file (the parent of its folder)
# to ``sys.path`` so that the ``utils`` helpers imported next can be found.
_current_file = pathlib.Path(__file__).resolve()
sys.path.append(os.fspath(_current_file.parent.parent))
from utils import (
    reset_rng_states,
58
    compare_and_assert,
59
60
61
62
63
    ModelConfig,
    dtype_tols,
    get_available_attention_backends,
)

64
# Check if hardware supports FP8 attention. FP8 itself must be available, and
# FP8 attention is additionally limited to compute capabilities from sm90 up to
# (but not including) sm120; outside that range it is reported as unavailable.
fp8_available, reason_for_no_fp8 = is_fp8_available(return_reason=True)
fp8_attn_available, reason_for_no_fp8_attn = fp8_available, reason_for_no_fp8
device_compute_capability = get_device_compute_capability()
if fp8_available and not ((9, 0) <= device_compute_capability < (12, 0)):
    fp8_attn_available = False
    reason_for_no_fp8_attn = (
        "FP8 attention is not supported for compute capability ="
        f" sm{device_compute_capability[0] * 10 + device_compute_capability[1]}"
    )
74

75
# Reset RNG seed and states: ``seed`` is a module-level constant, and the RNG
# states are reset once, when this module is imported.
seed = 1234
reset_rng_states()
78
79


80
# Reset FP8 global state manager
@pytest.fixture(autouse=True)
def reset_global_fp8_state():
    """Reset ``FP8GlobalStateManager`` after every test in this module.

    ``autouse=True`` applies the fixture to every test without it being
    requested. The reset runs in the teardown phase (after ``yield``), so FP8
    global state left behind by one test is cleared before the next one.
    """
    yield
    FP8GlobalStateManager.reset()
85
86


87
88
# Define F16 data types to test: float16 always, bfloat16 only when the device
# reports bf16 support.
param_types = [torch.float16]
if is_bf16_available():
    param_types.append(torch.bfloat16)
# Reduced dtype list used by several of the test_dpa_* sweeps.
# NOTE(review): unlike ``param_types`` this is not gated on is_bf16_available().
param_types_lean = [torch.bfloat16]
92

93
# Baseline configs (no explicit mask or bias), keyed by test id. Consumed by
# test_dot_product_attention and test_dpa_checkpoint via ``model_configs[model]``.
model_configs_base = {
    # test: ModelConfig(b, sq, hq, dqk)
    "base_1_0": ModelConfig(8, 128, 16, 64),
    "base_1_1": ModelConfig(4, 128, 16, 64, max_seqlen_kv=256),
    "base_2_0": ModelConfig(2, 2048, 24, 128),
    "base_2_1": ModelConfig(1, 2048, 24, 128, max_seqlen_kv=4096),
    "base_3_0": ModelConfig(8, 1, 16, 128, max_seqlen_kv=2048),
    "base_3_1": ModelConfig(8, 1, 16, 256, max_seqlen_kv=2048),
    "base_4_0": ModelConfig(8, 1, 16, 192, max_seqlen_kv=2048),
    "base_4_1": ModelConfig(8, 128, 16, 192, max_seqlen_kv=2048),
    "base_5_0": ModelConfig(8, 1, 16, 512, max_seqlen_kv=2048),
    "base_5_1": ModelConfig(8, 128, 16, 512, max_seqlen_kv=2048),
    "base_6_0": ModelConfig(8, 1, 16, 1024, max_seqlen_kv=2048),
    "base_6_1": ModelConfig(8, 128, 16, 1024, max_seqlen_kv=2048),
}

109

110
@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model_configs", [model_configs_base])
@pytest.mark.parametrize("model", model_configs_base.keys())
@pytest.mark.parametrize("ckpt_attn", [False])
@pytest.mark.parametrize("workspace_opt", [True, False])
@pytest.mark.parametrize("qkv_layout", [None])
@pytest.mark.parametrize("swa", [False])
@pytest.mark.parametrize("pad_between_seqs", [False])
def test_dot_product_attention(
    dtype, model_configs, model, ckpt_attn, workspace_opt, qkv_layout, swa, pad_between_seqs
):
    """Test DotProductAttention module

    Runs ``model_configs[model]`` through every attention backend reported as
    available (UnfusedDotProductAttention, FusedAttention -- both sub-backends
    when two are reported -- and FlashAttention). It then asserts that forward
    outputs and backward results agree pairwise within dtype-dependent
    tolerances. When ``config.return_max_logit`` is set, the fused and unfused
    max logits are compared as well.

    Besides being a parametrized test itself, this function is called directly
    with explicit arguments by other tests in this module.

    Args:
        dtype: torch dtype of the inputs; bfloat16 selects looser tolerances.
        model_configs: mapping from test id to ``ModelConfig``.
        model: key into ``model_configs``.
        ckpt_attn: forwarded to ``_run_dot_product_attention``.
        workspace_opt: forwarded to ``_run_dot_product_attention``.
        qkv_layout: QKV layout string, or ``None`` to derive one from the config.
        swa: if True and the config has no window, use a ``[2, 2]`` window.
        pad_between_seqs: forwarded to backend selection and the runner.
    """
    import copy

    # Get configs
    tols = dict(atol=1e-3, rtol=1e-3)
    if dtype == torch.bfloat16:
        tols = dict(atol=1.5e-2, rtol=1.5e-2)
    # Work on a shallow copy: the window-size and mask-type rewrites below must
    # not leak into the module-level config dicts shared by other parametrizations.
    config = copy.copy(model_configs[model])
    is_mla = config.head_dim_qk != config.head_dim_v
    is_mqa_gqa = config.num_heads != config.num_gqa_groups
    if qkv_layout is None:
        if config.attn_type == "self":
            qkv_layout = "sb3hd" if not is_mla and not is_mqa_gqa else "sbhd_sbhd_sbhd"
        else:
            qkv_layout = "bshd_bs2hd" if not is_mla and not is_mqa_gqa else "bshd_bshd_bshd"
    if "3" in qkv_layout and config.attn_type == "cross":
        pytest.skip("No need to test this layout for cross attention")

    if config.window_size == (-1, -1) and swa:
        config.window_size = [2, 2]
    config.window_size = check_set_window_size(config.attn_mask_type, config.window_size)

    # For the THD format, force a padding variant of the configured mask type.
    qkv_format = qkv_layout.replace("3", "").replace("2", "").split("_")[0]
    if qkv_format == "thd" and "padding" not in config.attn_mask_type:
        config.attn_mask_type = (
            "padding_" + config.attn_mask_type if config.attn_mask_type != "no_mask" else "padding"
        )

    def _query_backends(training):
        """Return (available_backends, _, fused_attn_backends) for this config."""
        return get_available_attention_backends(
            config,
            qkv_dtype=dtype,
            qkv_layout=qkv_layout,
            pad_between_seqs=pad_between_seqs,
            is_training=training,
        )

    # Get backends. If FusedAttention is not available for training, re-query
    # with is_training=False and run the comparison in inference mode.
    is_training = True
    available_backends, _, fused_attn_backends = _query_backends(is_training)
    flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
    if not fused_attn_supported:
        is_training = False
        available_backends, _, fused_attn_backends = _query_backends(is_training)
        flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends

    # FlashAttention does not support pad_between_seqs, but _run_dot_product_attention
    # manually pads and unpads the input and output of FlashAttention for testing purposes
    if (
        pad_between_seqs
        and FlashAttentionUtils.is_installed
        and not (
            config.max_seqlen_q != config.max_seqlen_kv
            and config.attn_mask_type in ["causal", "padding_causal"]
        )
        and (config.window_size[0] == -1 or FlashAttentionUtils.v2_3_plus)
    ):
        flash_attn_supported = True

    # Skip if only unfused backend is supported
    if (len(fused_attn_backends) + flash_attn_supported + unfused_attn_supported) < 2:
        pytest.skip("Less than two backends to compare.")

    def _run(backend):
        """Run one backend; ``is_training`` is read at call time."""
        return _run_dot_product_attention(
            dtype,
            config,
            backend,
            ckpt_attn,
            qkv_layout,
            workspace_opt,
            pad_between_seqs,
            is_training,
        )

    # UnfusedDotProductAttention backend
    if unfused_attn_supported:
        unfused_attn_fwd, unfused_max_logit, unfused_attn_bwd = _run("UnfusedDotProductAttention")

    # FusedAttention backend
    if fused_attn_supported:
        if len(fused_attn_backends) == 2:
            # Run both FusedAttention sub-backends, then restore the caller's
            # NVTE_FUSED_ATTN_BACKEND so the override does not leak into later tests.
            saved_fused_backend = os.environ.get("NVTE_FUSED_ATTN_BACKEND")
            try:
                os.environ["NVTE_FUSED_ATTN_BACKEND"] = "0"
                fused_attn_fwd, fused_max_logit, fused_attn_bwd = _run("FusedAttention")
                os.environ["NVTE_FUSED_ATTN_BACKEND"] = "1"
                fused_attn_fwd_1, _, fused_attn_bwd_1 = _run("FusedAttention")
            finally:
                if saved_fused_backend is None:
                    os.environ.pop("NVTE_FUSED_ATTN_BACKEND", None)
                else:
                    os.environ["NVTE_FUSED_ATTN_BACKEND"] = saved_fused_backend
        else:
            fused_attn_fwd, fused_max_logit, fused_attn_bwd = _run("FusedAttention")

    # FlashAttention backend
    if flash_attn_supported:
        flash_attn_fwd, _, flash_attn_bwd = _run("FlashAttention")

    # Compare results
    logging.info(f"[test_dot_product_attention]: is_training = {is_training}")
    if unfused_attn_supported and flash_attn_supported:
        logging.info("[test_dot_product_attention]: unfused attn vs flash attn")
        torch.testing.assert_close(flash_attn_fwd, unfused_attn_fwd, **tols)
        for i, _ in enumerate(flash_attn_bwd):
            torch.testing.assert_close(unfused_attn_bwd[i], flash_attn_bwd[i], **tols)
    if unfused_attn_supported and fused_attn_supported:
        logging.info("[test_dot_product_attention]: unfused attn vs fused attn")
        torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, **tols)
        if config.return_max_logit:
            torch.testing.assert_close(fused_max_logit, unfused_max_logit, **tols)
        for i, _ in enumerate(unfused_attn_bwd):
            torch.testing.assert_close(fused_attn_bwd[i], unfused_attn_bwd[i], **tols)
    if fused_attn_supported and flash_attn_supported:
        logging.info("[test_dot_product_attention]: fused attn vs flash attn")
        torch.testing.assert_close(fused_attn_fwd, flash_attn_fwd, **tols)
        for i, _ in enumerate(flash_attn_bwd):
            torch.testing.assert_close(fused_attn_bwd[i], flash_attn_bwd[i], **tols)
    if fused_attn_supported and len(fused_attn_backends) == 2:
        logging.info("[test_dot_product_attention]: fused attn backend 0 vs 1")
        torch.testing.assert_close(fused_attn_fwd, fused_attn_fwd_1, **tols)
        for i, _ in enumerate(fused_attn_bwd):
            torch.testing.assert_close(fused_attn_bwd[i], fused_attn_bwd_1[i], **tols)

274

275
@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model_configs", [model_configs_base])
@pytest.mark.parametrize("model", ["base_1_1", "base_2_1"])
def test_dpa_checkpoint(dtype, model_configs, model):
    """Test DotProductAttention module with checkpointing

    Runs two base configs through test_dot_product_attention with
    ``ckpt_attn=True`` and the layout derived from the config.
    """
    test_dot_product_attention(
        dtype,
        model_configs,
        model,
        ckpt_attn=True,
        workspace_opt=True,
        qkv_layout=None,
        swa=False,
        pad_between_seqs=False,
    )
282

283

284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
# Configs for test_dpa_max_logit, which enables ``return_max_logit`` before
# running them through test_dot_product_attention.
model_configs_max_logit = {
    # test: ModelConfig(b, sq, hq, dqk)
    "max_logit_1": ModelConfig(1, 2048, 24, 128, max_seqlen_kv=4096),
    "max_logit_2": ModelConfig(2, 2048, 24, 128, attn_mask_type="causal"),
    "max_logit_3": ModelConfig(
        2,
        1,
        16,
        128,
        attn_mask_type="padding_causal",
        max_seqlen_kv=2048,
    ),
    "max_logit_4": ModelConfig(
        8,
        128,
        16,
        192,
        attn_bias_type="post_scale_bias",
        max_seqlen_kv=2048,
    ),
    "max_logit_5": ModelConfig(
        8,
        128,
        16,
        512,
        attn_mask_type="causal",
        max_seqlen_kv=2048,
        window_size=(20, 0),
    ),
    "max_logit_6": ModelConfig(8, 1, 16, 1024, max_seqlen_kv=2048),
}


@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model_configs", [model_configs_max_logit])
@pytest.mark.parametrize("model", model_configs_max_logit.keys())
@pytest.mark.parametrize("qkv_layout", ["sbhd_sbhd_sbhd", "thd_thd_thd"])
def test_dpa_max_logit(dtype, model_configs, model, qkv_layout):
    """Test DotProductAttention module with ``return_max_logit`` enabled

    The flag is set on a shallow copy of the config, which is handed to
    test_dot_product_attention in a one-entry mapping. The shared module-level
    config dict is therefore left unmodified.
    """
    import copy

    config = copy.copy(model_configs[model])
    config.return_max_logit = True
    test_dot_product_attention(
        dtype,
        {model: config},
        model,
        ckpt_attn=False,
        workspace_opt=True,
        qkv_layout=qkv_layout,
        swa=False,
        pad_between_seqs=False,
    )


311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
# Softmax-variant configs: ModelConfig(2, 2048, 64, 64, num_gqa_groups=8) plus
# per-entry options, keyed "softmax_N_M":
#   N (mask / window): 1 no attn_mask_type argument, 2 "causal", 3 "padding",
#                      4 "causal" + window_size=(128, 0),
#                      5 "padding_causal" + window_size=(128, 0)
#   M (softmax_type):  0 no softmax_type argument, 1 "off-by-one", 2 "learnable"
model_configs_softmax = {
    f"softmax_{group_id}_{variant_id}": ModelConfig(
        2,
        2048,
        64,
        64,
        num_gqa_groups=8,
        **group_kwargs,
        **({} if softmax_type is None else {"softmax_type": softmax_type}),
    )
    for group_id, group_kwargs in enumerate(
        (
            {},
            {"attn_mask_type": "causal"},
            {"attn_mask_type": "padding"},
            {"window_size": (128, 0), "attn_mask_type": "causal"},
            {"window_size": (128, 0), "attn_mask_type": "padding_causal"},
        ),
        start=1,
    )
    for variant_id, softmax_type in enumerate((None, "off-by-one", "learnable"))
}


@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("model_configs", [model_configs_softmax])
@pytest.mark.parametrize("model", model_configs_softmax.keys())
def test_dpa_softmax(dtype, model_configs, model):
    """Test DotProductAttention module with different softmax types

    Runs in bfloat16 only, with ``ckpt_attn=True`` and the fixed
    ``bshd_bshd_bshd`` layout.
    """
    test_dot_product_attention(
        dtype,
        model_configs,
        model,
        ckpt_attn=True,
        workspace_opt=True,
        qkv_layout="bshd_bshd_bshd",
        swa=False,
        pad_between_seqs=False,
    )


390
# Multi-Latent Attention (MLA) configs for test_dpa_mla.
model_configs_mla = {
    # TODO: FlashAttention on ROCm only supports MLA with head_dim_qk == head_dim_v.
    # Entries with head_dim_qk != head_dim_v are kept below as comments.
    # test: ModelConfig(b, sq, hq, dqk, head_dim_v=dv)  # attn, backend
    # "mla_1_0": ModelConfig(8, 128, 16, 64, head_dim_v=128),  # self , 0
    # "mla_1_1": ModelConfig(4, 128, 16, 64, max_seqlen_kv=256, head_dim_v=128),  # cross, 0
    # "mla_1_2": ModelConfig(4, 128, 16, 192, max_seqlen_kv=256, head_dim_v=128),  # cross, 0
    # "mla_2_0": ModelConfig(2, 2048, 24, 128, attn_mask_type="causal", head_dim_v=64),  # self , 1
    # "mla_2_1": ModelConfig(
    #     1, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal", head_dim_v=64
    # ),  # cross, 1
    # "mla_2_2": ModelConfig(
    #     1, 2048, 24, 192, max_seqlen_kv=4096, attn_mask_type="causal", head_dim_v=128
    # ),  # cross, 1
    # "mla_3_0": ModelConfig(8, 1, 16, 128, max_seqlen_kv=2048, head_dim_v=64),  # inference
    # "mla_3_1": ModelConfig(8, 1, 16, 256, max_seqlen_kv=2048, head_dim_v=128),  # inference
    # "mla_3_2": ModelConfig(8, 1, 16, 192, max_seqlen_kv=2048, head_dim_v=128),  # inference
    # "mla_3_3": ModelConfig(8, 1, 16, 160, max_seqlen_kv=2048, head_dim_v=128),  # inference
    "mla_3_4": ModelConfig(8, 1, 16, 160, max_seqlen_kv=2048, head_dim_v=160),  # inference
}


@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model_configs", [model_configs_mla])
@pytest.mark.parametrize("model", model_configs_mla.keys())
def test_dpa_mla(dtype, model_configs, model):
    """Test DotProductAttention module with Multi-Latent Attention (MLA)

    Uses ``ckpt_attn=True`` and lets test_dot_product_attention derive the
    QKV layout from the config.
    """
    test_dot_product_attention(
        dtype,
        model_configs,
        model,
        ckpt_attn=True,
        workspace_opt=True,
        qkv_layout=None,
        swa=False,
        pad_between_seqs=False,
    )


420
# Attention-mask configs for test_dpa_mask, covering causal, causal_bottom_right,
# padding, padding_causal and padding_causal_bottom_right masks.
model_configs_mask = {
    # test: ModelConfig(b, sq, hq, dqk)
    "mask_1_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="causal"),
    "mask_1_1": ModelConfig(2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="causal"),
    "mask_1_2": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal"),
    "mask_2_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="causal_bottom_right"),
    "mask_2_1": ModelConfig(
        2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="causal_bottom_right"
    ),
    "mask_2_2": ModelConfig(
        2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal_bottom_right"
    ),
    "mask_3_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding"),
    "mask_3_1": ModelConfig(2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding"),
    "mask_3_2": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding"),
    "mask_4_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding_causal"),
    "mask_4_1": ModelConfig(2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding_causal"),
    "mask_4_2": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding_causal"),
    "mask_5_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding_causal_bottom_right"),
    "mask_5_1": ModelConfig(
        2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding_causal_bottom_right"
    ),
    "mask_5_2": ModelConfig(
        2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding_causal_bottom_right"
    ),
    "mask_6_0": ModelConfig(2, 1, 16, 128, max_seqlen_kv=2048, attn_mask_type="causal"),
    "mask_6_1": ModelConfig(2, 1, 16, 256, max_seqlen_kv=2048, attn_mask_type="causal"),
    "mask_7_0": ModelConfig(
        2, 1, 16, 128, max_seqlen_kv=2048, attn_mask_type="causal_bottom_right"
    ),
    "mask_7_1": ModelConfig(
        2, 1, 16, 256, max_seqlen_kv=2048, attn_mask_type="causal_bottom_right"
    ),
    "mask_8_0": ModelConfig(2, 1, 24, 128, max_seqlen_kv=2048, attn_mask_type="padding"),
    "mask_8_1": ModelConfig(2, 1, 16, 256, max_seqlen_kv=2048, attn_mask_type="padding"),
    "mask_9_0": ModelConfig(2, 1, 24, 128, max_seqlen_kv=2048, attn_mask_type="padding_causal"),
    "mask_9_1": ModelConfig(2, 1, 16, 256, max_seqlen_kv=2048, attn_mask_type="padding_causal"),
    "mask_10_0": ModelConfig(
        2, 1, 24, 128, max_seqlen_kv=2048, attn_mask_type="padding_causal_bottom_right"
    ),
    "mask_10_1": ModelConfig(
        2, 1, 16, 256, max_seqlen_kv=2048, attn_mask_type="padding_causal_bottom_right"
    ),
}
464

465

466
@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_mask])
@pytest.mark.parametrize("model", model_configs_mask.keys())
def test_dpa_mask(dtype, model_configs, model):
    """Test DotProductAttention module with different mask types"""
    test_dot_product_attention(
        dtype,
        model_configs,
        model,
        ckpt_attn=False,
        workspace_opt=True,
        qkv_layout=None,
        swa=False,
        pad_between_seqs=False,
    )
473

474

475
# Bias configs for test_dpa_bias: post_scale_bias and alibi, combined with no
# mask, padding, causal and padding_causal masks.
model_configs_bias = {
    # test: ModelConfig(b, sq, hq, dqk)
    "bias_1_0": ModelConfig(4, 128, 16, 64, attn_bias_type="post_scale_bias"),
    "bias_1_1": ModelConfig(2, 128, 16, 64, max_seqlen_kv=256, attn_bias_type="post_scale_bias"),
    "bias_1_2": ModelConfig(4, 2048, 24, 128, attn_bias_type="post_scale_bias"),
    "bias_1_3": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_bias_type="post_scale_bias"),
    "bias_1_4": ModelConfig(4, 2048, 24, 128, attn_bias_type="alibi"),
    "bias_1_5": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_bias_type="alibi"),
    "bias_2_0": ModelConfig(
        4, 128, 16, 64, attn_mask_type="padding", attn_bias_type="post_scale_bias"
    ),
    "bias_2_1": ModelConfig(
        2,
        128,
        16,
        64,
        max_seqlen_kv=256,
        attn_mask_type="padding",
        attn_bias_type="post_scale_bias",
    ),
    "bias_2_2": ModelConfig(
        4, 2048, 24, 128, attn_mask_type="padding", attn_bias_type="post_scale_bias"
    ),
    "bias_2_3": ModelConfig(
        2,
        2048,
        24,
        128,
        max_seqlen_kv=4096,
        attn_mask_type="padding",
        attn_bias_type="post_scale_bias",
    ),
    "bias_2_4": ModelConfig(4, 2048, 24, 128, attn_mask_type="padding", attn_bias_type="alibi"),
    "bias_2_5": ModelConfig(
        2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding", attn_bias_type="alibi"
    ),
    "bias_3_0": ModelConfig(
        4, 128, 16, 64, attn_mask_type="causal", attn_bias_type="post_scale_bias"
    ),
    "bias_3_1": ModelConfig(
        2, 128, 16, 64, max_seqlen_kv=256, attn_mask_type="causal", attn_bias_type="post_scale_bias"
    ),
    "bias_3_2": ModelConfig(
        4, 2048, 24, 128, attn_mask_type="causal", attn_bias_type="post_scale_bias"
    ),
    "bias_3_3": ModelConfig(
        2,
        2048,
        24,
        128,
        max_seqlen_kv=4096,
        attn_mask_type="causal",
        attn_bias_type="post_scale_bias",
    ),
    "bias_3_4": ModelConfig(4, 2048, 24, 128, attn_mask_type="causal", attn_bias_type="alibi"),
    "bias_3_5": ModelConfig(
        2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal", attn_bias_type="alibi"
    ),
    "bias_4_0": ModelConfig(
        4, 128, 16, 64, attn_mask_type="padding_causal", attn_bias_type="post_scale_bias"
    ),
    "bias_4_1": ModelConfig(
        2,
        128,
        16,
        64,
        max_seqlen_kv=256,
        attn_mask_type="padding_causal",
        attn_bias_type="post_scale_bias",
    ),
    "bias_4_2": ModelConfig(
        4, 2048, 24, 128, attn_mask_type="padding_causal", attn_bias_type="post_scale_bias"
    ),
    "bias_4_3": ModelConfig(
        2,
        2048,
        24,
        128,
        max_seqlen_kv=4096,
        attn_mask_type="padding_causal",
        attn_bias_type="post_scale_bias",
    ),
    "bias_4_4": ModelConfig(
        4, 2048, 24, 128, attn_mask_type="padding_causal", attn_bias_type="alibi"
    ),
    "bias_4_5": ModelConfig(
        2,
        2048,
        24,
        128,
        max_seqlen_kv=4096,
        attn_mask_type="padding_causal",
        attn_bias_type="alibi",
    ),
}
570

571

572
@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_bias])
@pytest.mark.parametrize("model", model_configs_bias.keys())
def test_dpa_bias(dtype, model_configs, model):
    """Test DotProductAttention module with different bias types"""
    test_dot_product_attention(
        dtype,
        model_configs,
        model,
        ckpt_attn=False,
        workspace_opt=True,
        qkv_layout=None,
        swa=False,
        pad_between_seqs=False,
    )
579

580

581
# Bias-shape configs for test_dpa_bias_shapes: each entry sets ``bias_shape``
# ("11ss", "1hss", "b1ss", "bhss") for post_scale_bias or custom alibi.
model_configs_bias_shapes = {
    # test: ModelConfig(b, sq, hq, dqk)
    "bias_1_0": ModelConfig(4, 128, 16, 64, attn_bias_type="post_scale_bias", bias_shape="11ss"),
    "bias_1_1": ModelConfig(2, 128, 16, 64, attn_bias_type="post_scale_bias", bias_shape="1hss"),
    "bias_1_2": ModelConfig(4, 2048, 24, 128, attn_bias_type="post_scale_bias", bias_shape="b1ss"),
    "bias_1_3": ModelConfig(2, 2048, 24, 128, attn_bias_type="post_scale_bias", bias_shape="bhss"),
    "bias_1_4": ModelConfig(
        4,
        2048,
        24,
        128,
        attn_mask_type="causal",
        attn_bias_type="alibi",
        bias_shape="1hss",
        alibi_type="custom",
    ),
    "bias_1_5": ModelConfig(
        2,
        2048,
        24,
        128,
        attn_mask_type="causal",
        attn_bias_type="alibi",
        bias_shape="bhss",
        alibi_type="custom",
    ),
}

609

610
@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_bias_shapes])
@pytest.mark.parametrize("model", model_configs_bias_shapes.keys())
def test_dpa_bias_shapes(dtype, model_configs, model):
    """Test DotProductAttention module with different bias types and shapes"""
    test_dot_product_attention(
        dtype,
        model_configs,
        model,
        ckpt_attn=False,
        workspace_opt=True,
        qkv_layout=None,
        swa=False,
        pad_between_seqs=False,
    )

618

619
# Sliding-window configs for test_dpa_sliding_window. None of them sets
# window_size; the test passes swa=True, so test_dot_product_attention applies a
# [2, 2] window to configs whose window_size is (-1, -1).
model_configs_swa = {
    # test: ModelConfig(b, sq, hq, dqk)
    "swa_1_1": ModelConfig(2, 2048, 16, 64),
    "swa_1_2": ModelConfig(2, 2048, 24, 128, num_gqa_groups=4),
    "swa_1_3": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096),
    "swa_2_1": ModelConfig(2, 2048, 16, 64, attn_mask_type="causal"),
    "swa_2_2": ModelConfig(2, 2048, 24, 128, num_gqa_groups=4, attn_mask_type="causal"),
    "swa_2_3": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal"),
    "swa_3_1": ModelConfig(2, 2048, 16, 64, attn_mask_type="causal_bottom_right"),
    "swa_3_2": ModelConfig(
        2, 2048, 24, 128, num_gqa_groups=4, attn_mask_type="causal_bottom_right"
    ),
    "swa_3_3": ModelConfig(
        2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal_bottom_right"
    ),
    "swa_4_1": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding"),
    "swa_4_2": ModelConfig(2, 2048, 24, 128, num_gqa_groups=4, attn_mask_type="padding"),
    "swa_4_3": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding"),
    "swa_5_1": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding_causal"),
    "swa_5_2": ModelConfig(2, 2048, 24, 128, num_gqa_groups=4, attn_mask_type="padding_causal"),
    "swa_5_3": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding_causal"),
    "swa_6_1": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding_causal_bottom_right"),
    "swa_6_2": ModelConfig(
        2, 2048, 24, 128, num_gqa_groups=4, attn_mask_type="padding_causal_bottom_right"
    ),
    "swa_6_3": ModelConfig(
        2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding_causal_bottom_right"
    ),
}
648
649


650
@pytest.mark.skipif(not FlashAttentionUtils.v2_3_plus, reason="Flash-attn 2.3+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_swa])
@pytest.mark.parametrize("model", model_configs_swa.keys())
def test_dpa_sliding_window(dtype, model_configs, model):
    """Test DotProductAttention module with sliding window attention"""
    test_dot_product_attention(
        dtype,
        model_configs,
        model,
        ckpt_attn=False,
        workspace_opt=True,
        qkv_layout=None,
        swa=True,
        pad_between_seqs=False,
    )

658

659
# ALiBi configs for test_dpa_alibi_slopes: causal mask with "vanilla" or
# "custom" alibi_type.
model_configs_alibi_slopes = {
    # test: ModelConfig(b, sq, hq, dqk)
    "alibi_1_0": ModelConfig(
        2, 128, 16, 64, attn_mask_type="causal", attn_bias_type="alibi", alibi_type="vanilla"
    ),
    "alibi_1_1": ModelConfig(
        1,
        128,
        16,
        64,
        max_seqlen_kv=256,
        attn_mask_type="causal",
        attn_bias_type="alibi",
        alibi_type="vanilla",
    ),
    "alibi_2_0": ModelConfig(
        2, 1024, 24, 128, attn_mask_type="causal", attn_bias_type="alibi", alibi_type="custom"
    ),
    "alibi_2_1": ModelConfig(
        1,
        1024,
        24,
        128,
        max_seqlen_kv=2048,
        attn_mask_type="causal",
        attn_bias_type="alibi",
        alibi_type="custom",
    ),
}
688
689


690
@pytest.mark.skipif(not FlashAttentionUtils.v2_3_plus, reason="Flash-attn 2.3+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_alibi_slopes])
@pytest.mark.parametrize("model", model_configs_alibi_slopes.keys())
def test_dpa_alibi_slopes(dtype, model_configs, model):
    """Test DotProductAttention module with ALiBi slopes"""
    test_dot_product_attention(
        dtype,
        model_configs,
        model,
        ckpt_attn=False,
        workspace_opt=True,
        qkv_layout=None,
        swa=False,
        pad_between_seqs=False,
    )
697

698

699
# QKV memory layouts exercised by test_dpa_qkv_layout, covering the sbhd-ordered
# and bshd-ordered families.
qkv_layouts = [
    "sb3hd",
    "sbh3d",
    "sbhd_sb2hd",
    "sbhd_sbh2d",
    "sbhd_sbhd_sbhd",
    "bs3hd",
    "bsh3d",
    "bshd_bs2hd",
    "bshd_bsh2d",
    "bshd_bshd_bshd",
]
711

712

713
# Configs for test_dpa_qkv_layout; each is run against every layout in qkv_layouts.
model_configs_layout = {
    # test: ModelConfig(b, sq, hq, dqk)
    "layout_0_0": ModelConfig(2, 128, 16, 64),
    "layout_0_1": ModelConfig(
        2, 128, 16, 64, attn_mask_type="causal", attn_bias_type="post_scale_bias"
    ),
    "layout_0_2": ModelConfig(1, 128, 16, 64, max_seqlen_kv=256, attn_mask_type="padding"),
    "layout_0_3": ModelConfig(
        1,
        128,
        16,
        64,
        max_seqlen_kv=256,
        attn_mask_type="padding_causal",
        attn_bias_type="post_scale_bias",
    ),
    "layout_1_0": ModelConfig(2, 2048, 24, 128),
    "layout_1_1": ModelConfig(
        2, 2048, 24, 128, attn_mask_type="causal", attn_bias_type="post_scale_bias"
    ),
    "layout_1_2": ModelConfig(1, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding"),
    "layout_1_3": ModelConfig(
        1,
        2048,
        24,
        128,
        max_seqlen_kv=4096,
        attn_mask_type="padding_causal",
        attn_bias_type="post_scale_bias",
    ),
    "layout_2_0": ModelConfig(2, 1, 16, 256, max_seqlen_kv=2048),
    "layout_2_1": ModelConfig(
        2, 2048, 24, 256, attn_mask_type="causal", attn_bias_type="post_scale_bias"
    ),
}

749

750
@pytest.mark.skipif(get_cudnn_version() < (8, 9, 5), reason="cuDNN 8.9.5+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_layout])
@pytest.mark.parametrize("model", model_configs_layout.keys())
@pytest.mark.parametrize("qkv_layout", qkv_layouts)
def test_dpa_qkv_layout(dtype, model_configs, model, qkv_layout):
    """Test DotProductAttention module with different QKV layouts.

    Delegates to ``test_dot_product_attention`` for a single layout from
    ``qkv_layouts``, without padding between sequences.
    """
    pad_between_seqs = False
    # NOTE(review): the unnamed positional flags (False, True, ..., False) are
    # presumably ckpt_attn, workspace_opt and swa -- confirm against the
    # test_dot_product_attention signature.
    test_dot_product_attention(
        dtype, model_configs, model, False, True, qkv_layout, False, pad_between_seqs
    )


760
# THD layouts: packed QKV (two packing positions), packed KV (two positions), separate Q/K/V.
qkv_layouts_thd = "t3hd th3d thd_t2hd thd_th2d thd_thd_thd".split()
761
# Configurations for test_dpa_qkv_layout_thd.
# ModelConfig positional args are (b, sq, hq, dqk). Keys are layout_<m>_<n>, where
# <m> selects the mask type / window_size and <n> selects the shape:
#   <n> = 0: (2, 2048, 16, 64)
#   <n> = 1: (2, 2048, 24, 128) with num_gqa_groups=1
#   <n> = 2: (2, 2048, 24, 128) with max_seqlen_kv=4096
model_configs_layout_thd = dict(
    # <m> = 0: padding
    layout_0_0=ModelConfig(2, 2048, 16, 64, attn_mask_type="padding"),
    layout_0_1=ModelConfig(2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding"),
    layout_0_2=ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding"),
    # <m> = 1: padding_causal
    layout_1_0=ModelConfig(2, 2048, 16, 64, attn_mask_type="padding_causal"),
    layout_1_1=ModelConfig(
        2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding_causal"
    ),
    layout_1_2=ModelConfig(
        2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding_causal"
    ),
    # <m> = 2: padding_causal_bottom_right
    layout_2_0=ModelConfig(2, 2048, 16, 64, attn_mask_type="padding_causal_bottom_right"),
    layout_2_1=ModelConfig(
        2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding_causal_bottom_right"
    ),
    layout_2_2=ModelConfig(
        2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding_causal_bottom_right"
    ),
    # <m> = 3: padding with window_size=(4, 4)
    layout_3_0=ModelConfig(2, 2048, 16, 64, attn_mask_type="padding", window_size=(4, 4)),
    layout_3_1=ModelConfig(
        2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding", window_size=(4, 4)
    ),
    layout_3_2=ModelConfig(
        2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding", window_size=(4, 4)
    ),
    # <m> = 4: padding_causal with window_size=(4, 0)
    layout_4_0=ModelConfig(
        2, 2048, 16, 64, attn_mask_type="padding_causal", window_size=(4, 0)
    ),
    layout_4_1=ModelConfig(
        2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding_causal", window_size=(4, 0)
    ),
    layout_4_2=ModelConfig(
        2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding_causal", window_size=(4, 0)
    ),
    # <m> = 5: padding_causal_bottom_right with window_size=(4, 0)
    layout_5_0=ModelConfig(
        2, 2048, 16, 64, attn_mask_type="padding_causal_bottom_right", window_size=(4, 0)
    ),
    layout_5_1=ModelConfig(
        2,
        2048,
        24,
        128,
        num_gqa_groups=1,
        attn_mask_type="padding_causal_bottom_right",
        window_size=(4, 0),
    ),
    layout_5_2=ModelConfig(
        2,
        2048,
        24,
        128,
        max_seqlen_kv=4096,
        attn_mask_type="padding_causal_bottom_right",
        window_size=(4, 0),
    ),
)


816
817
818
819
@pytest.mark.skipif(get_cudnn_version() < (9, 0, 0), reason="cuDNN 9.0.0+ is required.")
@pytest.mark.skipif(
    get_device_compute_capability() < (9, 0), reason="THD is only supported on Hopper+."
)
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_layout_thd])
@pytest.mark.parametrize("model", model_configs_layout_thd.keys())
@pytest.mark.parametrize("qkv_layout", qkv_layouts_thd)
def test_dpa_qkv_layout_thd(dtype, model_configs, model, qkv_layout):
    """Test DotProductAttention module with different QKV layouts.

    THD variant. The test always runs with ``pad_between_seqs = True``. On cuDNN
    9.3.0+ it then runs a second time with ``pad_between_seqs = False``.
    Layouts containing "3" (packed QKV) are skipped when
    ``num_heads != num_gqa_groups``.
    """
    config = model_configs[model]
    is_mqa_gqa = config.num_heads != config.num_gqa_groups
    if is_mqa_gqa and "3" in qkv_layout:
        pytest.skip("qkv_layout not applicable for MQA/GQA")

    for pad_between_seqs in (True, False):
        # cuDNN 9.3.0+ is required to run pad_between_seqs = False/True in the same run
        if not pad_between_seqs and get_cudnn_version() < (9, 3, 0):
            break
        logging.info(f"[test_dpa_qkv_layout_thd]: pad_between_seqs = {pad_between_seqs}")
        test_dot_product_attention(
            dtype, model_configs, model, False, True, qkv_layout, False, pad_between_seqs
        )
841

842

843
def _make_seqlens(max_seqlen: int, batch_size: int) -> torch.Tensor:
    """Return ``batch_size`` random int32 CUDA sequence lengths.

    Lengths are drawn from ``[1, max_seqlen)``. The upper bound is exclusive,
    so no sequence reaches ``max_seqlen``. When ``max_seqlen <= 1`` that range
    is empty, so every length is 1 instead of raising.
    """
    if max_seqlen > 1:
        return torch.randint(1, max_seqlen, [batch_size], dtype=torch.int32, device="cuda")
    return torch.ones([batch_size], dtype=torch.int32, device="cuda")


def _make_padding_mask(seqlens: torch.Tensor, max_seqlen: int) -> torch.Tensor:
    """Return a ``[b, 1, 1, max_seqlen]`` bool mask that is True at padded positions.

    Position ``s`` of sequence ``b`` counts as padding when ``s >= seqlens[b]``.
    """
    positions = torch.arange(max_seqlen, device=seqlens.device)
    return (positions[None, :] >= seqlens[:, None])[:, None, None, :]


def _gather_valid_tokens(
    tensor: torch.Tensor,
    cu_seqlens_padded: torch.Tensor,
    pad_len: torch.Tensor,
    batch_size: int,
    zero_padding: bool = False,
) -> torch.Tensor:
    """Concatenate the non-padding tokens of every sequence in a padded THD tensor.

    Sequence ``b`` occupies ``tensor[cu_seqlens_padded[b]:cu_seqlens_padded[b + 1]]``.
    The last ``pad_len[b]`` tokens of that range are padding and are dropped.
    If ``zero_padding`` is set, those padding tokens are first zeroed in place
    in ``tensor``.
    """
    chunks = []
    for b in range(batch_size):
        start, end = cu_seqlens_padded[b], cu_seqlens_padded[b + 1]
        valid_end = end - pad_len[b]
        if zero_padding:
            tensor[valid_end:end] = 0.0
        chunks.append(tensor[start:valid_end])
    if not chunks:
        return tensor[:0]
    return torch.cat(chunks, dim=0)


def _run_dot_product_attention(
    dtype: torch.dtype,
    config: ModelConfig,
    backend: str,
    ckpt_attn: bool,
    qkv_layout: str,
    workspace_opt: bool,
    pad_between_seqs: bool,
    is_training: bool,
) -> Tuple[
    torch.Tensor,
    Union[torch.Tensor, None],
    Tuple[
        Union[torch.Tensor, None],
        Union[torch.Tensor, None],
        Union[torch.Tensor, None],
        Union[torch.Tensor, None],
    ],
]:
    """Run DotProductAttention module with one forward pass and one backward pass.

    Args:
        dtype: dtype of the Q/K/V inputs, the output gradient and the module.
        config: test configuration (shapes, mask/bias types, window size, ...).
        backend: one of "FlashAttention", "FusedAttention" or
            "UnfusedDotProductAttention". It is selected through the
            NVTE_FLASH_ATTN / NVTE_FUSED_ATTN environment variables.
        ckpt_attn: passed as ``checkpoint_core_attention``.
        qkv_layout: QKV layout string, e.g. "sbhd_sb2hd" or "thd_thd_thd".
        workspace_opt: for FusedAttention, sets
            NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT.
        pad_between_seqs: for THD, insert 0-3 random padding tokens after each
            sequence. FusedAttention receives the padded tensors; the other
            backends receive the tensors with the padding removed.
        is_training: run the backward pass when True; otherwise put the module
            in eval mode and skip the backward pass.

    Returns:
        ``(out, max_logit, (dq, dk, dv, d_softmax_offset))``.

        - ``max_logit`` is None unless ``config.return_max_logit`` is set.
        - ``dq``/``dk``/``dv`` are None when not training.
        - ``d_softmax_offset`` is None unless training with a non-vanilla
          softmax.
        - For FusedAttention with padded THD inputs, ``out`` and the gradients
          have the padding tokens stripped, so they line up with the other
          backends.
    """
    # Set RNG and environment variables
    reset_rng_states()
    os.environ["NVTE_FLASH_ATTN"] = "0"
    os.environ["NVTE_FUSED_ATTN"] = "0"
    if backend == "FlashAttention":
        os.environ["NVTE_FLASH_ATTN"] = "1"
    if backend == "FusedAttention":
        os.environ["NVTE_FUSED_ATTN"] = "1"
        os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] = "1" if workspace_opt else "0"
    _attention_backends["backend_selection_requires_update"] = True

    # Create seqlens: random lengths for padding masks / THD, full lengths otherwise
    qkv_format = "".join([c for c in qkv_layout.split("_")[0] if c.isalpha()])
    if "padding" in config.attn_mask_type or qkv_format == "thd":
        if config.attn_type == "self":
            seqlens_q = _make_seqlens(config.max_seqlen_q, config.batch_size)
            seqlens_kv = seqlens_q
        if config.attn_type == "cross":
            seqlens_q = _make_seqlens(config.max_seqlen_q, config.batch_size)
            seqlens_kv = _make_seqlens(config.max_seqlen_kv, config.batch_size)
    else:
        seqlens_q = torch.full(
            [config.batch_size], config.max_seqlen_q, dtype=torch.int32, device="cuda"
        )
        seqlens_kv = torch.full(
            [config.batch_size], config.max_seqlen_kv, dtype=torch.int32, device="cuda"
        )
    cu_seqlens_q = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda")
    cu_seqlens_kv = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda")
    cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)
    cu_seqlens_kv[1:] = torch.cumsum(seqlens_kv, dim=0)

    # Optionally append 0..max_pad_len padding tokens after every sequence
    seqlens_q_after_pad = seqlens_q.clone()
    seqlens_kv_after_pad = seqlens_kv.clone()
    cu_seqlens_q_after_pad = cu_seqlens_q.clone()
    cu_seqlens_kv_after_pad = cu_seqlens_kv.clone()
    pad_len = [0] * config.batch_size
    if pad_between_seqs:
        max_pad_len = 3
        pad_len = torch.randint(0, max_pad_len + 1, [config.batch_size], device="cuda")
        seqlens_q_after_pad = seqlens_q + pad_len
        seqlens_kv_after_pad = seqlens_kv + pad_len
        cu_seqlens_q_after_pad[1:] = torch.cumsum(seqlens_q_after_pad, dim=0)
        cu_seqlens_kv_after_pad[1:] = torch.cumsum(seqlens_kv_after_pad, dim=0)

    # Create attention mask if padding (a single mask for self, a (q, kv) pair for cross)
    attention_mask = None
    if "padding" in config.attn_mask_type:
        if config.attn_type == "self":
            attention_mask = _make_padding_mask(seqlens_q, config.max_seqlen_q)
        if config.attn_type == "cross":
            attention_mask = (
                _make_padding_mask(seqlens_q, config.max_seqlen_q),
                _make_padding_mask(seqlens_kv, config.max_seqlen_kv),
            )

    alibi_slopes = None
    if config.attn_bias_type == "alibi" and config.alibi_type == "custom":
        if config.bias_shape == "1hss":
            alibi_slopes = (
                torch.randn(config.num_heads).abs().to(dtype=torch.float32, device="cuda")
            )
        if config.bias_shape == "bhss":
            alibi_slopes = (
                torch.randn(config.batch_size, config.num_heads)
                .abs()
                .to(dtype=torch.float32, device="cuda")
            )

    # Create input tensors
    dim_to_num = {
        "b": config.batch_size,
        "sq": config.max_seqlen_q,
        "skv": config.max_seqlen_kv,
        "h": config.num_heads,
        "hg": config.num_gqa_groups,
        "dqk": config.head_dim_qk,
        "dv": config.head_dim_v,
        "t": cu_seqlens_q_after_pad[-1],
        "tg": cu_seqlens_kv_after_pad[-1],
        "3": 3,
        "2": 2,
        "1": 1,
    }
    inp = []
    inp_orig = []
    for i, layout in enumerate(qkv_layout.split("_")):
        # Translate e.g. "sb3hd" into dimension names such as "sq_b_3_h_dqk"
        layout = "_".join(layout)
        if i == 0:
            layout = layout.replace("s", "sq")
        else:
            layout = layout.replace("s", "skv")
            layout = layout.replace("h", "hg")
            layout = layout.replace("t", "tg")
        if i == 2:
            layout = layout.replace("d", "dv")
        else:
            layout = layout.replace("d", "dqk")
        tensor_shape = [dim_to_num[j] for j in layout.split("_")]
        tensor = 0.1 * torch.randn(tensor_shape, dtype=dtype, device="cuda")
        # tensor: with padding tokens (zeroed); tensor_orig: without padding tokens
        tensor_orig = tensor
        if qkv_format == "thd" and pad_between_seqs:
            token_dim = layout.split("_")[0]
            if token_dim in ("t", "tg"):
                cu_seqlens_padded = (
                    cu_seqlens_q_after_pad if token_dim == "t" else cu_seqlens_kv_after_pad
                )
                tensor_orig = _gather_valid_tokens(
                    tensor, cu_seqlens_padded, pad_len, config.batch_size, zero_padding=True
                )
        # Split packed tensors (a "3" or "2" dimension) into separate q/k/v tensors
        tensor_count = 1
        split_dim = 0
        for dim, dim_name in enumerate(layout.split("_")):
            if dim_name.isdigit():
                tensor_count = int(dim_name)
                split_dim = dim
                break
        tensors = torch.split(tensor, 1, dim=split_dim) if split_dim != 0 else [tensor]
        tensors_orig = (
            torch.split(tensor_orig, 1, dim=split_dim) if split_dim != 0 else [tensor_orig]
        )
        for j in range(tensor_count):
            if split_dim != 0:
                inp.append(tensors[j].squeeze(split_dim))
                inp_orig.append(tensors_orig[j].squeeze(split_dim))
            else:
                inp.append(tensors[j])
                inp_orig.append(tensors_orig[j])
    for i in range(3):
        inp[i].requires_grad = True
        inp_orig[i].requires_grad = True

    # Create output gradient, shaped like the output with heads and head_dim_v flattened
    out_grad_format = "_".join(qkv_format).replace("s", "sq").replace("d", "dv")
    out_grad_shape = [dim_to_num[c] for c in out_grad_format.split("_")]
    out_grad_shape_new = [*out_grad_shape[:-2], out_grad_shape[-2] * out_grad_shape[-1]]
    out_grad = 0.001 * torch.randint(0, 200, out_grad_shape_new, dtype=dtype, device="cuda")
    out_grad_orig = out_grad
    if qkv_format == "thd" and pad_between_seqs:
        out_grad_orig = _gather_valid_tokens(
            out_grad, cu_seqlens_q_after_pad, pad_len, config.batch_size, zero_padding=True
        )

    # Create bias. Bias types other than "post_scale_bias" get no bias tensor;
    # previously this left `bias` unbound and raised NameError.
    bias = None
    if config.attn_bias_type == "post_scale_bias":
        shape = "_".join(config.bias_shape).replace("_s_s", "_sq_skv")
        tensor_shape = [dim_to_num[j] for j in shape.split("_")]
        bias = torch.randn(tensor_shape, dtype=dtype, device="cuda")
        if config.bias_shape != "1hss":
            bias.requires_grad = False

    # Create RNG
    _DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
    _DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed)

    def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
        """Get cuda rng tracker."""
        return _DUMMY_CUDA_RNG_STATE_TRACKER

    # Set up model
    block = DotProductAttention(
        config.num_heads,
        (config.head_dim_qk, config.head_dim_v),
        num_gqa_groups=config.num_gqa_groups,
        attention_dropout=config.dropout_p,
        qkv_format=qkv_format,
        attn_mask_type=config.attn_mask_type,
        sequence_parallel=False,
        tp_size=1,
        get_rng_state_tracker=get_dummy_cuda_rng_tracker,
        tp_group=None,
        layer_number=1,
        attention_type=config.attn_type,
        softmax_type=config.softmax_type,
        return_max_logit=config.return_max_logit,
    ).to(dtype=dtype, device="cuda")
    if not is_training:
        block = block.eval()
    if is_training and config.softmax_type != "vanilla":
        block.softmax_offset.requires_grad = True

    # Run a forward and backward pass. FusedAttention consumes the padded tensors;
    # the other backends consume the tensors with padding removed.
    if backend in ["FlashAttention", "UnfusedDotProductAttention"]:
        q = inp_orig[0]
        k = inp_orig[1]
        v = inp_orig[2]
        d_out = out_grad_orig
    if backend == "FusedAttention":
        q = inp[0]
        k = inp[1]
        v = inp[2]
        d_out = out_grad
    out = block(
        q,
        k,
        v,
        window_size=config.window_size,
        attention_mask=attention_mask,
        qkv_format=qkv_format,
        max_seqlen_q=config.max_seqlen_q,
        max_seqlen_kv=config.max_seqlen_kv,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_kv=cu_seqlens_kv,
        cu_seqlens_q_padded=cu_seqlens_q_after_pad if backend == "FusedAttention" else None,
        cu_seqlens_kv_padded=cu_seqlens_kv_after_pad if backend == "FusedAttention" else None,
        attn_mask_type=config.attn_mask_type,
        checkpoint_core_attention=ckpt_attn,
        core_attention_bias_type=config.attn_bias_type,
        core_attention_bias=bias,
        alibi_slopes=alibi_slopes,
        fast_zero_fill=True,
    )
    max_logit = None
    if config.return_max_logit:
        out, max_logit = out
    if is_training:
        out.backward(d_out)

    d_softmax_offset = None
    if is_training and config.softmax_type != "vanilla":
        d_softmax_offset = block.softmax_offset.grad

    if is_training:
        grads = (q.grad, k.grad, v.grad, d_softmax_offset)
    else:
        grads = (None, None, None, d_softmax_offset)

    if backend == "FusedAttention" and qkv_format == "thd" and pad_between_seqs:
        # Strip the padding tokens so results line up with the unpadded backends
        def strip(x: torch.Tensor, cu_seqlens_padded: torch.Tensor) -> torch.Tensor:
            return _gather_valid_tokens(x, cu_seqlens_padded, pad_len, config.batch_size)

        out = strip(out, cu_seqlens_q_after_pad)
        if is_training:
            grads = (
                strip(q.grad, cu_seqlens_q_after_pad),
                strip(k.grad, cu_seqlens_kv_after_pad),
                strip(v.grad, cu_seqlens_kv_after_pad),
                d_softmax_offset,
            )
    return out, max_logit, grads
1212

1213

1214
# Configurations for the TransformerLayer tests.
# ModelConfig positional args are (b, sq, hq, dqk); keyword args override the rest.
model_configs_te_layer = dict(
    te_1_0=ModelConfig(2, 128, 16, 64, attn_bias_type="post_scale_bias"),
    te_1_1=ModelConfig(
        4, 128, 16, 64, attn_mask_type="causal", attn_bias_type="post_scale_bias"
    ),
    te_1_2=ModelConfig(
        2, 128, 16, 64, attn_mask_type="padding", attn_bias_type="post_scale_bias"
    ),
    te_1_3=ModelConfig(2, 128, 16, 64, max_seqlen_kv=256, attn_mask_type="padding"),
    te_2_0=ModelConfig(1, 2048, 16, 64, attn_mask_type="causal"),
    te_2_1=ModelConfig(2, 2048, 16, 64),
    te_2_2=ModelConfig(1, 2048, 16, 64, attn_mask_type="padding"),
    te_2_3=ModelConfig(
        1, 2048, 16, 64, max_seqlen_kv=4096, attn_mask_type="padding_causal_bottom_right"
    ),
    te_3_0=ModelConfig(4, 128, 16, 64, attn_mask_type="causal", attn_bias_type="alibi"),
    te_3_1=ModelConfig(4, 2048, 16, 64, attn_mask_type="causal", attn_bias_type="alibi"),
)
1233

1234

1235
@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model_configs", [model_configs_te_layer])
@pytest.mark.parametrize("model", model_configs_te_layer.keys())
@pytest.mark.parametrize("ckpt_attn", [False])
@pytest.mark.parametrize("qkv_format", ["sbhd", "bshd", "thd"])
@pytest.mark.parametrize("fused_qkv_params", [False])
@pytest.mark.parametrize("RoPE", [False])
def test_transformer_layer(
    dtype, model_configs, model, ckpt_attn, qkv_format, fused_qkv_params, RoPE
):
    """Test TransformerLayer module.

    Runs ``_run_transformer_layer`` on every supported attention backend, in the
    order unfused, FusedAttention, FlashAttention. Then checks that every pair of
    supported backends agrees on both forward and backward results within
    ``atol = rtol = 5e-2``. If FusedAttention is not supported in training mode
    for this config, backend support is re-queried with ``is_training = False``
    and every backend runs in that mode.
    """
    # Get configs
    config = model_configs[model]
    tols = dict(atol=5e-2, rtol=5e-2)
    workspace_opt = True
    # Packed-QKV layout used for the backend-availability query
    qkv_layout = (
        qkv_format.replace("hd", "h3d") if fused_qkv_params else qkv_format.replace("hd", "3hd")
    )

    # Test backend availability
    is_training = True
    available_backends, _, fused_attn_backends = get_available_attention_backends(
        config, qkv_dtype=dtype, qkv_layout=qkv_layout, is_training=is_training
    )
    flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
    if not fused_attn_supported:
        is_training = False
        available_backends, _, fused_attn_backends = get_available_attention_backends(
            config, qkv_dtype=dtype, qkv_layout=qkv_layout, is_training=is_training
        )
        flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends

    # Skip unless at least two backends are available (each FusedAttention
    # backend in fused_attn_backends counts separately)
    if (len(fused_attn_backends) + flash_attn_supported + unfused_attn_supported) < 2:
        pytest.skip("Less than two backends to compare.")

    # Skip if qkv_format = thd and "padding" not in attn_mask_type
    if qkv_format == "thd" and "padding" not in config.attn_mask_type:
        pytest.skip("THD requires padding mask.")

    # Run every supported backend; each result is a (forward, backward) pair
    backend_support = (
        ("UnfusedDotProductAttention", unfused_attn_supported),
        ("FusedAttention", fused_attn_supported),
        ("FlashAttention", flash_attn_supported),
    )
    results = {}
    for backend, supported in backend_support:
        if supported:
            results[backend] = _run_transformer_layer(
                dtype,
                config,
                backend,
                ckpt_attn,
                qkv_format,
                workspace_opt,
                fused_qkv_params,
                RoPE,
                is_training,
            )

    # Compare every pair of backends that ran: (actual, expected, log label)
    logging.info(f"[test_transformer_layer]: is_training = {is_training}")
    comparisons = (
        ("FusedAttention", "UnfusedDotProductAttention", "unfused attn vs fused attn"),
        ("FlashAttention", "UnfusedDotProductAttention", "unfused attn vs flash attn"),
        ("FusedAttention", "FlashAttention", "fused attn vs flash attn"),
    )
    for actual, expected, label in comparisons:
        if actual in results and expected in results:
            logging.info(f"[test_transformer_layer]: {label}")
            actual_fwd, actual_bwd = results[actual]
            expected_fwd, expected_bwd = results[expected]
            torch.testing.assert_close(actual_fwd, expected_fwd, **tols)
            torch.testing.assert_close(actual_bwd, expected_bwd, **tols)
1340

1341

1342
@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_te_layer])
@pytest.mark.parametrize("model", ["te_1_2", "te_2_0"])
@pytest.mark.parametrize("qkv_format", ["bshd", "sbhd"])
def test_te_layer_misc(dtype, model_configs, model, qkv_format):
    """Test TransformerLayer module with miscellaneous settings.

    Re-runs ``test_transformer_layer`` with ``ckpt_attn``, ``fused_qkv_params``
    and ``RoPE`` all enabled.
    """
    test_transformer_layer(
        dtype,
        model_configs,
        model,
        ckpt_attn=True,
        qkv_format=qkv_format,
        fused_qkv_params=True,
        RoPE=True,
    )
1355

1356

1357
@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_te_layer])
@pytest.mark.parametrize("model", ["te_2_0", "te_2_1", "te_2_2"])
def test_te_layer_mqa_gqa(dtype, model_configs, model):
    """Test TransformerLayer module with MQA/GQA.

    Runs ``test_transformer_layer`` once for every divisor ``g > 1`` of
    ``num_heads``, using ``num_heads // g`` KV groups. ``g == num_heads`` gives
    a single group (MQA).

    The config object is shared with every other test that uses
    ``model_configs_te_layer``. Its original ``num_gqa_groups`` is therefore
    restored afterwards, even on failure or skip, so that the MQA/GQA setting
    does not leak into later tests.
    """
    ckpt_attn = True
    qkv_format = "bshd"
    fused_qkv_params = True
    RoPE = True
    config = model_configs[model]
    # Divisors of num_heads greater than 1 = query heads per KV group
    num_querys_per_gqa_group = [
        f for f in range(2, config.num_heads + 1) if config.num_heads % f == 0
    ]

    original_num_gqa_groups = config.num_gqa_groups
    try:
        for num_q_per_gqa_group in num_querys_per_gqa_group:
            config.num_gqa_groups = config.num_heads // num_q_per_gqa_group
            test_transformer_layer(
                dtype, model_configs, model, ckpt_attn, qkv_format, fused_qkv_params, RoPE
            )
    finally:
        config.num_gqa_groups = original_num_gqa_groups
1383

1384

1385
def _run_transformer_layer(
    dtype: torch.dtype,
    config: ModelConfig,
    backend: str,
    ckpt_attn: bool,
    qkv_format: str,
    workspace_opt: bool,
    fused_qkv_params: bool,
    RoPE: bool,
    is_training: bool,
) -> Tuple[torch.Tensor, Union[torch.Tensor, None]]:
    """Run TransformerLayer module with one forward pass and one backward pass

    Args:
        dtype: parameter and input dtype.
        config: model configuration (sizes, mask/bias types, attention type, ...).
        backend: "FlashAttention" or "FusedAttention" sets the matching
            NVTE_FLASH_ATTN / NVTE_FUSED_ATTN variable to "1"; any other value
            leaves both set to "0".
        ckpt_attn: accepted for signature compatibility; the forward call always
            passes ``checkpoint_core_attention=False``.
        qkv_format: input layout, "sbhd", "bshd" or "thd".
        workspace_opt: accepted for signature compatibility; unused here.
        fused_qkv_params: forwarded as ``TransformerLayer(fuse_qkv_params=...)``.
        RoPE: if True, rotary position embeddings are built and passed to the layer.
        is_training: if False, the layer is switched to eval mode and no backward
            pass is run.

    Returns:
        ``(out, inp.grad)``: the layer output and the gradient of the main input,
        which is ``None`` when ``is_training`` is False.
    """

    # Set RNG and environment variables
    reset_rng_states()
    os.environ["NVTE_FLASH_ATTN"] = "0"
    os.environ["NVTE_FUSED_ATTN"] = "0"
    if backend == "FlashAttention":
        os.environ["NVTE_FLASH_ATTN"] = "1"
    if backend == "FusedAttention":
        os.environ["NVTE_FUSED_ATTN"] = "1"
    _attention_backends["backend_selection_requires_update"] = True

    def rand_input(*shape):
        """Random CUDA leaf tensor of ``dtype`` that tracks gradients."""
        return torch.randn(*shape, dtype=dtype, device="cuda", requires_grad=True)

    def rand_seqlens(max_seqlen):
        """Per-sequence lengths drawn from [1, max_seqlen), or all ones if max_seqlen <= 1."""
        if max_seqlen > 1:
            return torch.randint(
                1, max_seqlen, [config.batch_size], dtype=torch.int32, device="cuda"
            )
        # torch.randint needs low < high, so length-1 sequences cannot be sampled
        return torch.ones([config.batch_size], dtype=torch.int32, device="cuda")

    # Create input tensor (thd inputs are created below, once the token count is known)
    if qkv_format == "sbhd":
        inp = rand_input(config.max_seqlen_q, config.batch_size, config.hidden_size)
        inp_enc = rand_input(config.max_seqlen_kv, config.batch_size, config.hidden_size)
    if qkv_format == "bshd":
        inp = rand_input(config.batch_size, config.max_seqlen_q, config.hidden_size)
        inp_enc = rand_input(config.batch_size, config.max_seqlen_kv, config.hidden_size)

    # Create seqlens
    if "padding" in config.attn_mask_type or qkv_format == "thd":
        if config.attn_type == "self":
            seqlens_q = rand_seqlens(config.max_seqlen_q)
            seqlens_kv = seqlens_q
        if config.attn_type == "cross":
            seqlens_q = rand_seqlens(config.max_seqlen_q)
            seqlens_kv = rand_seqlens(config.max_seqlen_kv)
    else:
        seqlens_q = torch.full(
            [config.batch_size], config.max_seqlen_q, dtype=torch.int32, device="cuda"
        )
        seqlens_kv = torch.full(
            [config.batch_size], config.max_seqlen_kv, dtype=torch.int32, device="cuda"
        )
    cu_seqlens_q = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda")
    cu_seqlens_kv = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda")
    cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)
    cu_seqlens_kv[1:] = torch.cumsum(seqlens_kv, dim=0)
    if qkv_format == "thd":
        # Packed layout: the leading dimension is the total number of tokens
        inp = rand_input(cu_seqlens_q[-1], config.hidden_size)
        inp_enc = rand_input(cu_seqlens_kv[-1], config.hidden_size)

    sigma = 0.02
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    layer_number = 1
    drop_path_rate = 0.0
    drop_path_rates = [rate.item() for rate in torch.linspace(0, drop_path_rate, config.num_layers)]

    # Create bias
    bias = None
    if config.attn_bias_type == "post_scale_bias":
        bias = torch.randn(
            1,
            config.num_heads,
            config.max_seqlen_q,
            config.max_seqlen_kv,
            dtype=dtype,
            device="cuda",
        )

    # Create RoPE
    rotary_pos_emb = None
    if RoPE:
        PE = RotaryPositionEmbedding(dim=config.head_dim_qk)
        rotary_pos_emb = PE(config.max_seqlen_q).to(device="cuda")

    # Set up model
    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_heads,
        num_gqa_groups=config.num_gqa_groups,
        layernorm_epsilon=1e-5,
        hidden_dropout=0.0,
        attention_dropout=config.dropout_p,
        init_method=init_method,
        output_layer_init_method=output_layer_init_method,
        layer_number=layer_number,
        kv_channels=config.head_dim_qk,
        self_attn_mask_type=config.attn_mask_type,
        tp_group=None,
        tp_size=1,
        params_dtype=dtype,
        get_rng_state_tracker=None,
        fuse_wgrad_accumulation=False,
        seq_length=config.max_seqlen_q,
        micro_batch_size=config.batch_size,
        sequence_parallel=False,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
        layer_type="encoder" if config.attn_type == "self" else "decoder",
        drop_path_rate=drop_path_rates[layer_number - 1],
        set_parallel_mode=True,
        fuse_qkv_params=fused_qkv_params,
        zero_centered_gamma=False,
        qkv_weight_interleaved=False,
        ub_tp_comm_overlap=False,
        bias=True,
        attn_input_format=qkv_format,
    ).to(dtype=dtype, device="cuda")
    if not is_training:
        block = block.eval()

    # Create ALiBi slopes
    alibi_slopes = None
    if config.attn_bias_type == "alibi" and config.alibi_type == "custom":
        alibi_slopes = torch.randn(config.num_heads).abs().to(dtype=torch.float32, device="cuda")

    # Run a forward and backward pass
    out = block(
        inp,
        self_attn_mask_type=config.attn_mask_type,
        encoder_output=inp_enc if config.attn_type == "cross" else None,
        enc_dec_attn_mask_type=config.attn_mask_type if config.attn_type == "cross" else None,
        # NOTE(review): ckpt_attn is not forwarded here -- confirm this is intentional
        checkpoint_core_attention=False,
        rotary_pos_emb=rotary_pos_emb,
        core_attention_bias_type=config.attn_bias_type,
        core_attention_bias=bias,
        alibi_slopes=alibi_slopes,
        max_seqlen_q=config.max_seqlen_q,
        max_seqlen_kv=config.max_seqlen_kv,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_kv=cu_seqlens_kv,
    )
    if is_training:
        loss = out.sum()
        loss.backward()

    return out, inp.grad
1577
1578


1579
# Model configs for test_dpa_fp8_extra_state, keyed by the "model" parameter.
model_configs_fp8_extra_state = {
    # test: ModelConfig(b, sq, hq, dqk)
    "large": ModelConfig(2, 128, 4, 128, num_layers=1),
}


1585
@pytest.mark.skipif(not fp8_attn_available, reason=reason_for_no_fp8_attn)
@pytest.mark.skipif(get_cudnn_version() < (9, 3, 0), reason="cuDNN 9.3.0+ is required.")
@pytest.mark.parametrize("model", ["large"])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_dpa_fp8_extra_state(model, dtype):
    """Test DotProductAttention module in FP8 with checkpointing

    Trains the same FP8 TransformerLayer three times:
      * without checkpointing,
      * with a save/load round trip halfway through,
      * with a round trip whose core-attention extra state is moved under the
        legacy ``...core_attention.fused_attention._extra_state`` key (mimic_v1_6).
    It then checks that the output, input gradient and parameter gradients of
    both checkpointed runs match the uncheckpointed run.
    """
    config = model_configs_fp8_extra_state[model]
    # Test backend availability
    is_training = True
    available_backends, _, _ = get_available_attention_backends(
        config,
        qkv_dtype=torch.float8_e4m3fn,
        qkv_layout="sb3hd",
        is_training=is_training,
    )
    flash_attn_supported, fused_attn_supported, _ = available_backends
    if not fused_attn_supported and not flash_attn_supported:
        pytest.skip("No attention backend available.")

    outputs = _run_dpa_fp8_extra_state(dtype, config, checkpoint=False)
    outputs_checkpoint = _run_dpa_fp8_extra_state(dtype, config, checkpoint=True)
    outputs_checkpoint_v1_6 = _run_dpa_fp8_extra_state(
        dtype, config, mimic_v1_6=True, checkpoint=True
    )

    # Check that results match
    tols = dtype_tols(dtype)
    if dtype in (torch.float16, torch.bfloat16):
        tols.update(dict(rtol=2e-2, atol=2e-3))
    for test_outputs in (outputs_checkpoint, outputs_checkpoint_v1_6):
        # zip() would silently skip trailing tensors if the lists differed in length
        assert len(test_outputs) == len(outputs)
        for ref, test in zip(outputs, test_outputs):
            torch.testing.assert_close(
                test,
                ref,
                **tols,
            )


1628
1629
def _run_dpa_fp8_extra_state(dtype, config, checkpoint=False, mimic_v1_6=False):
    """Run DotProductAttention module in FP8 with checkpointing

    Trains an FP8 ``TransformerLayer`` for ``steps`` iterations.

    If ``checkpoint`` is True, the layer is handled specially after the first
    half of the steps:
      * its state dict is saved,
      * the layer is re-created and the saved state is loaded,
      * the parameter gradients and the CPU/CUDA RNG state from before
        re-creation are restored.
    If ``mimic_v1_6`` is also True, the core-attention extra state is stored
    under the legacy ``self_attention.core_attention.fused_attention._extra_state``
    key before saving.

    Returns:
        ``[output, hidden_states.grad, *param_grads]``. ``output`` comes from the
        last step; gradients are accumulated over all steps (they are never zeroed).
    """
    import tempfile

    steps = 10
    fp8_enabled = True
    fp8_recipe = recipe.DelayedScaling(
        margin=0,
        fp8_format=recipe.Format.HYBRID,
        amax_history_len=1,
        amax_compute_algo="most_recent",
        fp8_dpa=fp8_enabled,
        fp8_mha=False,
    )

    reset_rng_states()
    hidden_states = torch.randn(
        (config.max_seqlen_q, config.batch_size, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )

    def get_model(dtype, config):
        """Build a fresh FP8-initialized TransformerLayer."""
        sigma = 0.023
        init_method = init_method_normal(sigma)
        output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

        with quantized_model_init(enabled=fp8_enabled, recipe=fp8_recipe):
            block = TransformerLayer(
                config.hidden_size,
                4 * config.hidden_size,
                config.num_heads,
                init_method=init_method,
                output_layer_init_method=output_layer_init_method,
                hidden_dropout=0.0,
                attention_dropout=0.0,
                fuse_qkv_params=True,
                params_dtype=dtype,
                device="cuda",
            )
        return block

    block = get_model(dtype, config)
    for _ in range(steps // 2):
        with autocast(enabled=fp8_enabled, recipe=fp8_recipe):
            output = block(hidden_states, None)
            loss = output.sum()
            loss.backward()

    if checkpoint:
        sd = block.state_dict()
        if mimic_v1_6:
            sd["self_attention.core_attention.fused_attention._extra_state"] = sd[
                "self_attention.core_attention._extra_state"
            ]
            del sd["self_attention.core_attention._extra_state"]

        # Round-trip through a private temporary directory instead of a fixed file
        # in the working directory: parallel test workers cannot clobber each other,
        # no unrelated ./checkpoint.pt is ever deleted, and the file is removed even
        # if loading fails.
        with tempfile.TemporaryDirectory() as tmp_dir:
            path = os.path.join(tmp_dir, "checkpoint.pt")
            torch.save(sd, path)

            param_grads = [p.grad.clone() for p in block.parameters() if p.requires_grad]

            # get_model() draws from the RNG for weight init; capture the state so
            # the remaining steps see the same RNG stream as the uncheckpointed run.
            _cpu_rng_state_new = torch.get_rng_state()
            _cuda_rng_state_new = torch.cuda.get_rng_state()

            del block
            block = get_model(dtype, config)
            # The file was written by this function itself. weights_only=False is
            # presumably needed because the FP8 _extra_state entries are not plain
            # tensors -- confirm before tightening.
            block.load_state_dict(torch.load(path, weights_only=False))
        torch.set_rng_state(_cpu_rng_state_new)
        torch.cuda.set_rng_state(_cuda_rng_state_new)

        for p in block.parameters():
            if p.requires_grad:
                p.grad = param_grads.pop(0)

        assert not param_grads, "Oops!"

    for _ in range((steps + 1) // 2):
        with autocast(enabled=fp8_enabled, recipe=fp8_recipe):
            output = block(hidden_states, None)
            loss = output.sum()
            loss.backward()

    torch.cuda.synchronize()

    outputs = [output, hidden_states.grad]
    outputs.extend(p.grad for p in block.parameters() if p.requires_grad)
    return outputs


1725
# Model configs shared by test_mha_fp8_vs_f16 and test_dpa_fp8_vs_f16,
# keyed by the "model" test parameter.
model_configs_fp8_vs_f16 = {
    # test: ModelConfig(b, sq, hq, dqk)
    "fp8_9": ModelConfig(2, 2048, 16, 128),
    "fp8_10": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12),
    "fp8_11": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4),
    "fp8_12": ModelConfig(2, 2048, 16, 128, attn_mask_type="causal"),
    "fp8_13": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12, attn_mask_type="causal"),
    "fp8_14": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="causal"),
    "fp8_15": ModelConfig(2, 2048, 16, 128, attn_mask_type="padding"),
    "fp8_16": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12, attn_mask_type="padding"),
    "fp8_17": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="padding"),
    "fp8_18": ModelConfig(2, 2048, 16, 128, attn_mask_type="padding_causal"),
    "fp8_19": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12, attn_mask_type="padding_causal"),
    "fp8_20": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="padding_causal"),
}

# Non-FP8 dtypes used for parameters/inputs and for the F16 reference runs
param_types_fp8_vs_f16 = [torch.float16, torch.bfloat16]
# Layouts swept by test_dpa_fp8_vs_f16 and formats swept by test_mha_fp8_vs_f16
qkv_layout_fp8_vs_f16 = ["sbh3d", "bshd_bshd_bshd", "sbhd_sbhd_sbhd"]
qkv_format_fp8_vs_f16 = ["bshd", "sbhd"]

1745

1746
@pytest.mark.skipif(get_cudnn_version() < (9, 2, 1), reason="cuDNN 9.2.1+ is required.")
@pytest.mark.skipif(not fp8_attn_available, reason=reason_for_no_fp8_attn)
@pytest.mark.parametrize("dtype", param_types_fp8_vs_f16)
@pytest.mark.parametrize("model", model_configs_fp8_vs_f16.keys())
@pytest.mark.parametrize("qkv_format", qkv_format_fp8_vs_f16)
@pytest.mark.parametrize("input_layernorm", [True, False])
@pytest.mark.parametrize("fp8_dpa_bwd", [True, False])
@pytest.mark.parametrize("RoPE", [True, False])
@pytest.mark.parametrize("is_training", [True, False])
@pytest.mark.parametrize("scaling_mode", ["delayed", "current"])
def test_mha_fp8_vs_f16(
    dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd, RoPE, is_training, scaling_mode
):
    """Test MultiHeadAttention module in FP8

    Runs MultiheadAttention with ``fp8_mha=True`` on FlashAttention (when
    supported) and on FusedAttention. The reference is a FusedAttention run
    with ``fp8_mha=False``. The test compares against the reference:
      * the forward outputs,
      * when training, the gradient of the input hidden states.
    """
    os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1"
    os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0"
    config = model_configs_fp8_vs_f16[model]

    # Test backend availability
    if scaling_mode == "delayed":
        fp8_recipe = recipe.DelayedScaling(
            margin=0,
            fp8_format=recipe.Format.HYBRID,
            amax_history_len=1,
            amax_compute_algo="most_recent",
            fp8_dpa=True,
            fp8_mha=True,
        )
    elif scaling_mode == "current":
        fp8_recipe = recipe.Float8CurrentScaling(
            fp8_format=recipe.Format.HYBRID,
            fp8_dpa=True,
            fp8_mha=True,
        )
    else:
        raise ValueError(f"Unsupported scaling_mode: {scaling_mode}")
    fp8_meta = {"recipe": fp8_recipe}
    available_backends, _, _ = get_available_attention_backends(
        config,
        qkv_dtype=torch.float8_e4m3fn,
        qkv_layout=qkv_format.replace("hd", "h3d"),
        fp8=True,
        fp8_meta=fp8_meta,
        is_training=is_training,
    )
    flash_attn_supported, fused_attn_supported, _ = available_backends
    if flash_attn_supported + fused_attn_supported < 1:
        pytest.skip("No FP8 attention backend available.")
    if not fp8_dpa_bwd:
        # Without the FP8 backward, the FP16/BF16 fused kernels must be available too
        available_backends, _, _ = get_available_attention_backends(
            config,
            qkv_dtype=dtype,
            qkv_layout=qkv_format.replace("hd", "h3d"),
            is_training=is_training,
        )
        _, fused_attn_supported, _ = available_backends
        if not fused_attn_supported:
            pytest.skip("No attention backend available.")

    if flash_attn_supported:
        os.environ["NVTE_FLASH_ATTN"] = "1"
        os.environ["NVTE_FUSED_ATTN"] = "0"
        _attention_backends["backend_selection_requires_update"] = True
        logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = True")
        flash_attn_fwd_fp8, _, _ = _run_mha_fp8_vs_f16(
            dtype, config, True, qkv_format, input_layernorm, RoPE, is_training, fp8_recipe
        )

    os.environ["NVTE_FLASH_ATTN"] = "0"
    os.environ["NVTE_FUSED_ATTN"] = "1"
    _attention_backends["backend_selection_requires_update"] = True
    logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = True")
    fused_attn_fwd_fp8, param_names, fused_attn_bwd_fp8 = _run_mha_fp8_vs_f16(
        dtype, config, True, qkv_format, input_layernorm, RoPE, is_training, fp8_recipe
    )

    logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = False")
    fused_attn_fwd_f16, param_names, fused_attn_bwd_f16 = _run_mha_fp8_vs_f16(
        dtype, config, False, qkv_format, input_layernorm, RoPE, is_training, fp8_recipe
    )

    atol = 5e-1
    rtol = 5e-1
    rmse_tol = 0.15
    if flash_attn_supported:
        logging.debug("========== {:^25s} ==========".format("flash fp8 vs fused f16:"))
        logging.debug("========== {:^25s} ==========".format("forward output"))
        compare_and_assert(
            flash_attn_fwd_fp8,
            fused_attn_fwd_f16,
            "flash_attn_fwd_fp8",
            "fused_attn_fwd_f16",
            atol,
            rtol,
            rmse_tol,
            True,
        )
    logging.debug("========== {:^25s} ==========".format("fused fp8 vs fused f16:"))
    logging.debug("========== {:^25s} ==========".format("forward output"))
    compare_and_assert(
        fused_attn_fwd_fp8,
        fused_attn_fwd_f16,
        "fused_attn_fwd_fp8",
        "fused_attn_fwd_f16",
        atol,
        rtol,
        rmse_tol,
        True,
    )

    if is_training:
        # Only the first gradient (param_names[0] == "hidden_states.grad") is
        # compared; parameter gradients are not checked.
        for i, name in enumerate(param_names[:1]):
            logging.debug("========== {:^25s} ==========".format(name))
            compare_and_assert(
                fused_attn_bwd_fp8[i],
                fused_attn_bwd_f16[i],
                f"fused_attn_bwd_fp8[{i}]",
                f"fused_attn_bwd_f16[{i}]",
                atol,
                rtol,
                rmse_tol,
                True,
            )

1869

1870
1871
1872
def _run_mha_fp8_vs_f16(
    dtype, config, fp8_mha, qkv_format, input_layernorm, RoPE, is_training, fp8_recipe
):
    """Run MultiHeadAttention module in FP8

    Builds a self-attention ``MultiheadAttention``. Model init happens under
    ``quantized_model_init`` and the forward pass under ``autocast``, both
    enabled only when ``fp8_mha`` is True. When ``is_training`` is True, one
    backward pass is run with a random output gradient.

    Returns:
        ``(out, param_names, grads)``:
          * ``param_names`` is ``["hidden_states.grad"]`` followed by
            ``"<name>.grad"`` for every trainable parameter.
          * ``grads`` holds the matching gradients, all ``None`` when
            ``is_training`` is False.
    """
    reset_rng_states()
    _DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
    _DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed)

    def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
        """Get cuda rng tracker."""
        return _DUMMY_CUDA_RNG_STATE_TRACKER

    with quantized_model_init(enabled=fp8_mha, recipe=fp8_recipe):
        rotary_pos_emb = None
        if RoPE:
            PE = RotaryPositionEmbedding(dim=config.head_dim_qk)
            rotary_pos_emb = PE(config.max_seqlen_q).to(device="cuda")
        mha = MultiheadAttention(
            hidden_size=config.hidden_size,
            num_attention_heads=config.num_heads,
            kv_channels=config.head_dim_qk,
            num_gqa_groups=config.num_gqa_groups,
            attention_dropout=config.dropout_p,
            layer_number=1,
            bias=True,
            get_rng_state_tracker=get_dummy_cuda_rng_tracker,
            params_dtype=dtype,
            input_layernorm=input_layernorm,
            fuse_qkv_params=True,
            attention_type="self",
            qkv_weight_interleaved=True,
            qkv_format=qkv_format,
        ).to(dtype=dtype, device="cuda")
        if not is_training:
            mha = mha.eval()

    # Sequence lengths: random for padded masks / thd, otherwise full length
    if "padding" in config.attn_mask_type or qkv_format == "thd":
        if config.attn_type == "self":
            seqlens_q = torch.randint(
                1, config.max_seqlen_q, [config.batch_size], dtype=torch.int32, device="cuda"
            )
            seqlens_kv = seqlens_q
        if config.attn_type == "cross":
            seqlens_q = torch.randint(
                1, config.max_seqlen_q, [config.batch_size], dtype=torch.int32, device="cuda"
            )
            seqlens_kv = torch.randint(
                1, config.max_seqlen_kv, [config.batch_size], dtype=torch.int32, device="cuda"
            )
    else:
        seqlens_q = torch.full(
            [config.batch_size], config.max_seqlen_q, dtype=torch.int32, device="cuda"
        )
        seqlens_kv = torch.full(
            [config.batch_size], config.max_seqlen_kv, dtype=torch.int32, device="cuda"
        )
    cu_seqlens_q = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda")
    cu_seqlens_kv = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda")
    cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)
    cu_seqlens_kv[1:] = torch.cumsum(seqlens_kv, dim=0)

    dim_to_num = {
        "b": config.batch_size,
        "sq": config.max_seqlen_q,
        "skv": config.max_seqlen_kv,
        "h": config.num_heads,
        "hg": config.num_gqa_groups,
        "d": config.head_dim_qk,
        "t": cu_seqlens_q[-1],
        "tg": cu_seqlens_kv[-1],
        "3": 3,
        "2": 2,
        "1": 1,
    }
    # One dimension per character of qkv_format; "s" is the query sequence length
    layout = "_".join(qkv_format)
    layout = layout.replace("s", "sq")
    tensor_shape = [dim_to_num[j] for j in layout.split("_")]

    # Inputs and output gradient are created 4D, then the trailing (h, d) dims
    # are merged into a single hidden dimension.
    hidden_tensor = 0.01 * torch.randint(-100, 100, tensor_shape, dtype=dtype, device="cuda")
    hidden_states = hidden_tensor.view(*hidden_tensor.shape[:-2], -1)
    if is_training:
        hidden_states.requires_grad = True
    grad_tensor = 0.01 * torch.randn(tensor_shape, dtype=dtype, device="cuda")
    out_grad = grad_tensor.view(*grad_tensor.shape[:-2], -1)

    with autocast(enabled=fp8_mha, recipe=fp8_recipe):
        out = mha(
            hidden_states,
            attn_mask_type=config.attn_mask_type,
            checkpoint_core_attention=False,
            core_attention_bias_type=config.attn_bias_type,
            is_first_microbatch=None,
            rotary_pos_emb=rotary_pos_emb,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_kv=cu_seqlens_kv,
        )
    if is_training:
        out.backward(out_grad)

    param_names = ["hidden_states.grad"]
    params = [hidden_states]
    for name, param in mha.named_parameters():
        if param.requires_grad:
            param_names.append(name + ".grad")
            params.append(param)

    if is_training:
        return out, param_names, tuple(x.grad for x in params)
    return out, param_names, (None,) * len(params)
1980

1981

1982
@pytest.mark.skipif(get_cudnn_version() < (9, 2, 1), reason="cuDNN 9.2.1+ is required.")
@pytest.mark.skipif(not fp8_attn_available, reason=reason_for_no_fp8_attn)
@pytest.mark.parametrize("dtype", param_types_fp8_vs_f16)
@pytest.mark.parametrize("model", model_configs_fp8_vs_f16.keys())
@pytest.mark.parametrize("qkv_layout", qkv_layout_fp8_vs_f16)
@pytest.mark.parametrize("fp8_dpa_bwd", [True, False])
@pytest.mark.parametrize("is_training", [True, False])
@pytest.mark.parametrize("scaling_mode", ["delayed", "current"])
def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scaling_mode):
    """Test DotProductAttention module in FP8

    Runs DotProductAttention with ``fp8_dpa=True`` on each available backend:
    FlashAttention, UnfusedDotProductAttention (with NVTE_UnfusedDPA_Emulate_FP8
    set) and FusedAttention.

    Without dropout, each backend is compared against a FusedAttention run with
    ``fp8_dpa=False``:
      * forward outputs for every backend,
      * dq/dk/dv for the unfused and fused backends when training.

    With nonzero ``config.dropout_p``, only the FP8 fused output is checked to be
    all ones. This presumably relies on _run_dpa_fp8_vs_f16 creating all-ones
    Q/K/V in that case -- see the TODO below.
    """
    config = model_configs_fp8_vs_f16[model]

    # TODO(cyang): think of another way to verify dropout results
    # test cuDNN FP8 dropout
    # 1. we modify the config here to not affect mha_fp8_vs_f16 tests
    # 2. there is no other backend that implements dropout the same way as cuDNN FP8, and as an
    #    indirect verification method, we create Q/K/V as all 1s and check if O is all 1s
    # 3. we avoid running FP16/BF16 kernels as they do not have dropout support on Blackwell
    # if "padding" not in config.attn_mask_type and "causal" not in config.attn_mask_type:
    #    if get_device_compute_capability() >= (10, 0):
    #        config.dropout_p = 0.1

    os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0"
    os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1"
    os.environ["NVTE_UnfusedDPA_Emulate_FP8"] = "1"
    try:
        # Test backend availability
        if scaling_mode == "delayed":
            fp8_recipe = recipe.DelayedScaling(
                margin=0,
                fp8_format=recipe.Format.HYBRID,
                amax_history_len=1,
                amax_compute_algo="most_recent",
                fp8_dpa=True,
            )
        elif scaling_mode == "current":
            fp8_recipe = recipe.Float8CurrentScaling(
                fp8_format=recipe.Format.HYBRID,
                fp8_dpa=True,
            )
        else:
            raise ValueError(f"Unsupported scaling_mode: {scaling_mode}")
        fp8_meta = {"recipe": fp8_recipe}
        available_backends, _, _ = get_available_attention_backends(
            config,
            qkv_dtype=torch.float8_e4m3fn,
            qkv_layout=qkv_layout,
            fp8=True,
            fp8_meta=fp8_meta,
            is_training=is_training,
        )
        flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
        if flash_attn_supported + fused_attn_supported < 1:
            pytest.skip("No FP8 attention backend available.")
        if not fp8_dpa_bwd:
            # Without the FP8 backward, the FP16/BF16 fused kernels must be available too
            available_backends, _, _ = get_available_attention_backends(
                config,
                qkv_dtype=dtype,
                qkv_layout=qkv_layout,
                is_training=is_training,
            )
            _, fused_attn_supported, _ = available_backends
            if not fused_attn_supported:
                pytest.skip("No attention backend available.")
        if config.num_heads != config.num_gqa_groups and "3" in qkv_layout:
            pytest.skip("qkv_layout not applicable for MQA/GQA")

        if flash_attn_supported:
            os.environ["NVTE_FLASH_ATTN"] = "1"
            os.environ["NVTE_FUSED_ATTN"] = "0"
            _attention_backends["backend_selection_requires_update"] = True
            logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True (FlashAttention)")
            flash_attn_fwd_fp8, flash_attn_bwd_fp8 = _run_dpa_fp8_vs_f16(
                dtype, config, True, qkv_layout, is_training, fp8_recipe
            )

        if unfused_attn_supported:
            os.environ["NVTE_FLASH_ATTN"] = "0"
            os.environ["NVTE_FUSED_ATTN"] = "0"
            _attention_backends["backend_selection_requires_update"] = True
            logging.info(
                "[test_dpa_fp8_vs_f16]: run with fp8_dpa = True (UnfusedDotProductAttention)"
            )
            unfused_attn_fwd_fp8, unfused_attn_bwd_fp8 = _run_dpa_fp8_vs_f16(
                dtype, config, True, qkv_layout, is_training, fp8_recipe
            )

        os.environ["NVTE_FLASH_ATTN"] = "0"
        os.environ["NVTE_FUSED_ATTN"] = "1"
        _attention_backends["backend_selection_requires_update"] = True
        logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True (FusedAttention)")
        fused_attn_fwd_fp8, fused_attn_bwd_fp8 = _run_dpa_fp8_vs_f16(
            dtype, config, True, qkv_layout, is_training, fp8_recipe
        )

        os.environ["NVTE_FLASH_ATTN"] = "0"
        os.environ["NVTE_FUSED_ATTN"] = "1"
        if config.dropout_p == 0.0:
            # The FP16/BF16 reference is only produced without dropout; per the
            # TODO above, those kernels lack dropout support on Blackwell.
            logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = False (FusedAttention)")
            fused_attn_fwd_f16, fused_attn_bwd_f16 = _run_dpa_fp8_vs_f16(
                dtype, config, False, qkv_layout, is_training, fp8_recipe
            )

        atol = 5e-1
        rtol = 5e-2
        rmse_tol = 0.11
        bwd_names = ["dq", "dk", "dv"]

        if config.dropout_p != 0.0:
            # test cuDNN FP8 dropout. There is no F16 reference in this case, so
            # the other backends cannot be compared against it.
            assert torch.all(
                fused_attn_fwd_fp8 == 1
            ), "fused_attn_fwd_fp8 must be all 1s when Q/K/V are all 1s."
        else:

            def compare_with_f16(backend, fwd_fp8, bwd_fp8, check_bwd):
                """Compare one FP8 backend's results against the fused F16 reference."""
                logging.debug(
                    "========== {:^25s} ==========".format(f"{backend} fp8 vs fused f16:")
                )
                logging.debug("========== {:^25s} ==========".format("forward output"))
                compare_and_assert(
                    fwd_fp8,
                    fused_attn_fwd_f16,
                    f"{backend}_attn_fwd_fp8",
                    "fused_attn_fwd_f16",
                    atol,
                    rtol,
                    rmse_tol,
                    True,
                )
                if not check_bwd:
                    return
                for i, _ in enumerate(fused_attn_bwd_f16):
                    logging.debug("========== {:^25s} ==========".format(bwd_names[i]))
                    compare_and_assert(
                        bwd_fp8[i],
                        fused_attn_bwd_f16[i],
                        f"{backend}_attn_bwd_fp8[{i}]",
                        f"fused_attn_bwd_f16[{i}]",
                        atol,
                        rtol,
                        rmse_tol,
                        True,
                    )

            # FlashAttention: forward output only
            if flash_attn_supported:
                compare_with_f16("flash", flash_attn_fwd_fp8, flash_attn_bwd_fp8, False)
            if unfused_attn_supported:
                compare_with_f16("unfused", unfused_attn_fwd_fp8, unfused_attn_bwd_fp8, is_training)
            compare_with_f16("fused", fused_attn_fwd_fp8, fused_attn_bwd_fp8, is_training)
    finally:
        # Reset even when the test is skipped or fails, so this setting does not
        # leak into later tests.
        os.environ["NVTE_UnfusedDPA_Emulate_FP8"] = "0"
2158
2159


2160
def _run_dpa_fp8_vs_f16(dtype, config, fp8_dpa, qkv_layout, is_training, fp8_recipe):
    """Run a DotProductAttention module, optionally in FP8, for one forward/backward pass.

    The module is created under ``quantized_model_init(enabled=fp8_dpa)`` and called under
    ``autocast(enabled=fp8_dpa, recipe=fp8_recipe)``. Q/K/V are generated according to
    ``qkv_layout``: one ``_``-separated group per tensor, and a digit inside a group (as in
    ``bs3hd``) means that many tensors are packed along that dimension. When
    ``config.dropout_p`` is non-zero the inputs are all ones so that dropout can be checked.

    Args:
        dtype: Dtype of the module, the inputs and the output gradient.
        config: Test model configuration.
        fp8_dpa: Whether to run the module in FP8.
        qkv_layout: Q/K/V layout string.
        is_training: If True, keep the module in train mode and run backward.
        fp8_recipe: Recipe passed to ``autocast``.

    Returns:
        ``(out, (dq, dk, dv))``; the gradients are ``None`` when ``is_training`` is False.

    Raises:
        ValueError: If variable-length inputs (padding mask or ``thd`` format) are needed
            and ``config.attn_type`` is neither ``"self"`` nor ``"cross"``.
    """
    reset_rng_states()
    _DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
    _DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed)

    def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
        """Get cuda rng tracker."""
        return _DUMMY_CUDA_RNG_STATE_TRACKER

    # The tensor format is the letters of the first layout group, e.g. "bs3hd" -> "bshd".
    qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()])
    with quantized_model_init(enabled=fp8_dpa):
        dpa = DotProductAttention(
            config.num_heads,
            config.head_dim_qk,
            num_gqa_groups=config.num_gqa_groups,
            attention_dropout=config.dropout_p,
            sequence_parallel=False,
            tp_size=1,
            get_rng_state_tracker=get_dummy_cuda_rng_tracker,
            tp_group=None,
            layer_number=1,
            attention_type="self",
            qkv_format=qkv_format,
        ).to(dtype=dtype, device="cuda")
        if not is_training:
            dpa = dpa.eval()

    if "padding" in config.attn_mask_type or qkv_format == "thd":
        # Variable-length sequences: random lengths in [1, max_seqlen) per batch entry.
        if config.attn_type == "self":
            seqlens_q = torch.randint(
                1, config.max_seqlen_q, [config.batch_size], dtype=torch.int32, device="cuda"
            )
            seqlens_kv = seqlens_q
        elif config.attn_type == "cross":
            seqlens_q = torch.randint(
                1, config.max_seqlen_q, [config.batch_size], dtype=torch.int32, device="cuda"
            )
            seqlens_kv = torch.randint(
                1, config.max_seqlen_kv, [config.batch_size], dtype=torch.int32, device="cuda"
            )
        else:
            raise ValueError(f"Unsupported attn_type for variable-length inputs: {config.attn_type!r}")
    else:
        seqlens_q = torch.full(
            [config.batch_size], config.max_seqlen_q, dtype=torch.int32, device="cuda"
        )
        seqlens_kv = torch.full(
            [config.batch_size], config.max_seqlen_kv, dtype=torch.int32, device="cuda"
        )
    cu_seqlens_q = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda")
    cu_seqlens_kv = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda")
    cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)
    cu_seqlens_kv[1:] = torch.cumsum(seqlens_kv, dim=0)

    # Size of every dimension symbol that can appear in a layout string.
    dim_to_num = {
        "b": config.batch_size,
        "sq": config.max_seqlen_q,
        "skv": config.max_seqlen_kv,
        "h": config.num_heads,
        "hg": config.num_gqa_groups,
        "d": config.head_dim_qk,
        "t": cu_seqlens_q[-1],
        "tg": cu_seqlens_kv[-1],
        "3": 3,
        "2": 2,
        "1": 1,
    }

    inp = []
    for i, layout in enumerate(qkv_layout.split("_")):
        # "bs3hd" -> "b_s_3_h_d" so that every symbol becomes its own token.
        layout = "_".join(layout)
        if i == 0:
            layout = layout.replace("s", "sq")
        else:
            # K/V tensors use the KV sequence length, GQA head count and KV token count.
            layout = layout.replace("s", "skv")
            layout = layout.replace("h", "hg")
            layout = layout.replace("t", "tg")
        tensor_shape = [dim_to_num[j] for j in layout.split("_")]
        if config.dropout_p == 0.0:
            tensor = torch.randn(tensor_shape, dtype=dtype, device="cuda")
        else:
            # test cuDNN FP8 dropout
            tensor = torch.ones(tensor_shape, dtype=dtype, device="cuda")
        # A digit in the layout marks a packed dimension holding that many tensors.
        tensor_count = 1
        split_dim = 0
        for dim, symbol in enumerate(layout.split("_")):
            if symbol.isdigit():
                tensor_count = int(symbol)
                split_dim = dim
                break
        tensors = torch.split(tensor, 1, dim=split_dim) if split_dim != 0 else [tensor]
        for j in range(tensor_count):
            if split_dim != 0:
                inp.append(tensors[j].squeeze(split_dim))
            else:
                inp.append(tensors[j])
    for i in range(3):
        inp[i].requires_grad = True

    # The output gradient follows the query layout with heads and head dim flattened.
    out_format = "_".join(qkv_format)
    out_format = out_format.replace("s", "sq")
    out_grad_shape = [dim_to_num[i] for i in out_format.split("_")]
    out_grad_shape_new = [*out_grad_shape[:-2], out_grad_shape[-2] * out_grad_shape[-1]]
    out_grad = torch.randn(out_grad_shape_new, dtype=dtype, device="cuda")

    with autocast(enabled=fp8_dpa, recipe=fp8_recipe):
        out = dpa(
            inp[0],
            inp[1],
            inp[2],
            qkv_format=qkv_format,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_kv=cu_seqlens_kv,
            max_seqlen_q=config.max_seqlen_q,
            max_seqlen_kv=config.max_seqlen_kv,
            attn_mask_type=config.attn_mask_type,
            checkpoint_core_attention=False,
            core_attention_bias_type=config.attn_bias_type,
            fp8_output=fp8_dpa,
        )

    if is_training:
        out.backward(out_grad)
        return out, (inp[0].grad, inp[1].grad, inp[2].grad)
    return out, (None, None, None)


model_configs_fp8 = {
    # test: ModelConfig(b, sq, hq, dqk)
    "fp8_1": ModelConfig(1, 512, 1, 64),
    "fp8_2": ModelConfig(4, 512, 16, 64),
    "fp8_3": ModelConfig(1, 2048, 1, 128),
    "fp8_4": ModelConfig(2, 2048, 24, 128),
    "fp8_5": ModelConfig(1, 512, 1, 64, attn_mask_type="causal"),
    "fp8_6": ModelConfig(4, 512, 16, 64, attn_mask_type="causal"),
    "fp8_7": ModelConfig(1, 2048, 1, 128, attn_mask_type="causal"),
    "fp8_8": ModelConfig(2, 2048, 24, 128, attn_mask_type="causal"),
}
param_types_fp8 = [torch.float16, torch.bfloat16]
# cuDNN frontend path used by the custom FP8 MHA test: 0 selects the v0.9 path
# (t3hd layout), anything else the v1.0+ path (bs3hd layout). Defaults to 1.
cudnn_frontend_version = int(os.getenv("NVTE_FUSED_ATTN_FE_VER", "1"))
# Configs exercised per frontend version: 512-token configs for v0, 2048-token for v1.
models_v0 = ["fp8_1", "fp8_2", "fp8_5", "fp8_6"]
models_v1 = ["fp8_3", "fp8_4", "fp8_7", "fp8_8"]


@pytest.mark.skipif(
    (
        get_cudnn_version() < (8, 9, 3)
        if cudnn_frontend_version == 0
        else get_cudnn_version() < (9, 2, 1)
    ),
    reason=f"""cuDNN {"8.9.3" if cudnn_frontend_version == 0 else "9.2.1"}+ is required.""",
)
@pytest.mark.skipif(not fp8_attn_available, reason=reason_for_no_fp8_attn)
@pytest.mark.parametrize("dtype", param_types_fp8)
@pytest.mark.parametrize("model", models_v1 if cudnn_frontend_version == 1 else models_v0)
def test_custom_mha_fp8_vs_f16(dtype, model):
    """Compare a custom FP8 MHA against an F16 DotProductAttention reference.

    ``Custom_MHA_FP8`` runs an FP8 QKV GEMM followed by cuDNN FP8 fused attention,
    using the frontend v0.9 path (``t3hd`` layout) or the v1.0+ path (``bs3hd``
    layout) depending on ``cudnn_frontend_version``. The reference runs
    ``DotProductAttention`` in F16 with the unfused backend on the dequantized QKV
    saved by the FP8 run, so it isolates the attention computation. Both the
    forward outputs and the dQKV gradients are checked with ``compare_and_assert``.
    """
    config = model_configs_fp8[model]

    # Test backend availability
    is_training = True
    available_backends, _, fused_attn_backends = get_available_attention_backends(
        config,
        qkv_dtype=torch.float8_e4m3fn,
        qkv_layout="t3hd" if cudnn_frontend_version == 0 else "bs3hd",
        is_training=is_training,
    )
    _, _, unfused_attn_supported = available_backends
    if not (fused_attn_backends and unfused_attn_supported):
        pytest.skip("Not enough backends to run this test with.")

    # Order matters: the FP8 run writes the QKV / output-gradient files that the
    # F16 reference reads back.
    fused_attn_fwd_fp8, fused_attn_bwd_fp8 = _run_custom_mha_fp8(dtype, config, "FusedAttention")
    unfused_attn_fwd_f16, unfused_attn_bwd_f16 = _run_ref_mha_f16(dtype, config, "UnfusedAttention")

    atol = 5e-1
    rtol = 5e-1
    rmse_tol = 0.13
    compare_and_assert(
        fused_attn_fwd_fp8,
        unfused_attn_fwd_f16,
        "fused_attn_fwd_fp8",
        "unfused_attn_fwd_f16",
        atol,
        rtol,
        rmse_tol,
        True,
    )
    compare_and_assert(
        fused_attn_bwd_fp8,
        unfused_attn_bwd_f16,
        "fused_attn_bwd_fp8",
        "unfused_attn_bwd_f16",
        atol,
        rtol,
        rmse_tol,
        True,
    )


def _run_custom_mha_fp8(dtype, config, backend):
    """Run Custom_MHA_FP8 (FP8 QKV GEMM + FP8 fused attention) for one forward/backward pass.

    Input and output are in F16; the QKV GEMM and the attention are computed in FP8
    under a DelayedScaling HYBRID recipe.

    Side effects:
        * Sets ``NVTE_FLASH_ATTN`` / ``NVTE_FUSED_ATTN`` according to ``backend`` and
          flags the attention backend selection for refresh.
        * Writes ``out_grad.pt`` to the current directory. The module itself writes
          ``qkv.pt``, ``out.pt`` and ``dqkv.pt``, and the last two are read back here.

    Returns:
        ``(out, dqkv)`` as loaded from disk, viewed as ``(b, s, h * d)`` and
        ``(b, s, 3, h, d)`` respectively.
    """
    reset_rng_states()
    os.environ["NVTE_FLASH_ATTN"] = "0"
    os.environ["NVTE_FUSED_ATTN"] = "0"
    if backend == "FlashAttention":
        os.environ["NVTE_FLASH_ATTN"] = "1"
    if backend == "FusedAttention":
        os.environ["NVTE_FUSED_ATTN"] = "1"
    _attention_backends["backend_selection_requires_update"] = True

    # Small integer-valued inputs (multiples of 1e-4 in [-0.01, 0.01)).
    inp = 0.0001 * torch.randint(
        -100,
        100,
        (config.batch_size * config.max_seqlen_q, config.num_heads * config.head_dim_qk),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )
    seqlens = torch.full([config.batch_size], config.max_seqlen_q, dtype=torch.int32, device="cuda")
    cu_seqlens = torch.zeros(config.batch_size + 1, device="cuda", dtype=torch.int32)
    cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)

    out_grad = 0.01 * torch.randn(
        config.batch_size * config.max_seqlen_q,
        config.num_heads * config.head_dim_qk,
        dtype=dtype,
        device="cuda",
    )
    # Shared with the F16 reference run so both backpropagate the same gradient.
    torch.save(out_grad, "out_grad.pt")

    fp8_recipe = recipe.DelayedScaling(
        margin=0,
        fp8_format=recipe.Format.HYBRID,
        amax_history_len=1,
        amax_compute_algo="most_recent",
    )

    mha = Custom_MHA_FP8(config).to(dtype=dtype, device="cuda")
    with autocast(enabled=True, recipe=fp8_recipe):
        out = mha(inp, cu_seqlens, config.max_seqlen_q)
    out.backward(out_grad)

    # Compare against the dequantized attention output / dQKV that the module saved.
    out = torch.load("out.pt")
    dqkv = torch.load("dqkv.pt")
    return (
        out.view(config.batch_size, config.max_seqlen_q, -1),
        dqkv.view(
            config.batch_size, config.max_seqlen_q, 3, config.num_heads, config.head_dim_qk
        ).contiguous(),
    )


def _run_ref_mha_f16(dtype, config, backend):
    """Run the F16 reference: DotProductAttention on the QKV saved by the FP8 run.

    Loads the packed QKV from ``qkv.pt`` (5-D, Q/K/V packed along dim 2) and the output
    gradient from ``out_grad.pt``, both written during ``_run_custom_mha_fp8``, so this
    must run after it. No GEMMs are involved here; only the attention is recomputed in
    ``dtype`` with the backend selected via ``NVTE_FLASH_ATTN`` / ``NVTE_FUSED_ATTN``.

    Returns:
        ``(out, dqkv)``, where ``dqkv`` is the gradient w.r.t. the loaded packed QKV.
    """
    os.environ["NVTE_FLASH_ATTN"] = "0"
    os.environ["NVTE_FUSED_ATTN"] = "0"
    if backend == "FlashAttention":
        os.environ["NVTE_FLASH_ATTN"] = "1"
    if backend == "FusedAttention":
        os.environ["NVTE_FUSED_ATTN"] = "1"
    _attention_backends["backend_selection_requires_update"] = True

    inp = torch.load("qkv.pt").to(device="cuda")
    inp.requires_grad = True
    out_grad = (
        torch.load("out_grad.pt").to(device="cuda").view(config.batch_size, config.max_seqlen_q, -1)
    )

    _DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
    _DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed)

    def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
        """Get cuda rng tracker."""
        return _DUMMY_CUDA_RNG_STATE_TRACKER

    block = DotProductAttention(
        config.num_heads,
        config.head_dim_qk,
        attention_dropout=config.dropout_p,
        sequence_parallel=False,
        tp_size=1,
        get_rng_state_tracker=get_dummy_cuda_rng_tracker,
        tp_group=None,
        layer_number=1,
        attention_type="self",
        qkv_format="bshd",
    ).to(dtype=dtype, device="cuda")

    # Unpack Q/K/V as views into the packed tensor so inp.grad collects all three.
    q = inp[:, :, 0, :, :]
    k = inp[:, :, 1, :, :]
    v = inp[:, :, 2, :, :]
    out = block(q, k, v, attn_mask_type=config.attn_mask_type)
    out.backward(out_grad)

    return out, inp.grad


# Size of the cuBLASLt workspace buffer allocated by Custom_MHA_FP8.
_CUBLASLT_WORKSPACE_SIZE_BYTES = 33_554_432  # 32MiB
# use_split_accumulator flags for the forward, dgrad and wgrad GEMMs.
_2X_ACC_FPROP = False
_2X_ACC_DGRAD = False
_2X_ACC_WGRAD = False

# FP8 meta slots for the MHA tensors.
# NOTE(review): _custom_mha_fp8 indexes its quantizers with tex.FP8*Tensors directly and
# does not reference these names; confirm whether they are still used elsewhere.
META_QKV = tex.FP8FwdTensors.GEMM1_OUTPUT
META_DQKV = tex.FP8BwdTensors.GRAD_OUTPUT1
META_O = tex.FP8FwdTensors.GEMM2_INPUT
META_DO = tex.FP8BwdTensors.GRAD_INPUT2
META_S = tex.FP8FwdTensors.GEMM3_OUTPUT
META_DP = tex.FP8BwdTensors.GRAD_INPUT3


class _custom_mha_fp8(torch.autograd.Function):
    """Autograd function: FP8 QKV projection followed by cuDNN FP8 fused attention.

    Forward quantizes the input and QKV weight, runs the QKV GEMM (with bias) straight
    into an FP8 QKV tensor, and runs FP8 fused attention on the Q/K/V slices. Backward
    runs FP8 fused attention backward, re-packs dQ/dK/dV into one dQKV tensor and
    computes the QKV dgrad, wgrad and bias grad.

    For the test's benefit, the dequantized QKV, attention output and dQKV are saved to
    ``qkv.pt``, ``out.pt`` and ``dqkv.pt`` in the current directory.
    """

    @staticmethod
    def forward(
        ctx,
        inp: torch.Tensor,
        qkv_weight: torch.Tensor,
        qkv_bias: torch.Tensor,
        cu_seqlens: torch.Tensor,
        num_heads: int,
        p_dropout: float,
        max_s: int,
        fast_zero_fill: bool,
        fp8_meta: Dict[str, Any],
        workspace: torch.Tensor,
        is_training: bool,
        mask_type: str,
        quantizers: Dict[str, Any],
    ) -> torch.Tensor:
        """Quantize, project to QKV and run FP8 fused attention.

        Args:
            inp: 2-D input ``(tokens, in_features)``.
            qkv_weight: QKV projection weight; its last dim is ``in_features``.
            qkv_bias: QKV projection bias.
            cu_seqlens: Cumulative sequence lengths (``batch + 1`` entries).
            num_heads: Number of attention heads; head dim is ``in_features // num_heads``.
            p_dropout: Attention dropout probability.
            max_s: Maximum sequence length (used for both Q and KV).
            fast_zero_fill: Passed through to fused attention.
            fp8_meta: FP8 metadata; ``fp8_meta["recipe"]`` is read in backward.
            workspace: cuBLASLt workspace buffer.
            is_training: Whether fused attention runs in training mode.
            mask_type: Attention mask type (frontend v1 only; v0 always uses "padding").
            quantizers: Quantizers keyed by ``"scaling_fwd"`` / ``"scaling_bwd"``, each
                indexed by ``tex.FP8FwdTensors`` / ``tex.FP8BwdTensors`` slots.

        Returns:
            The dequantized attention output viewed as ``(-1, in_features)``.
        """
        qkv_dtype = inp.dtype

        assert inp.dim() == 2
        in_features = qkv_weight.shape[-1]
        h = num_heads
        d = in_features // h
        b = cu_seqlens.numel() - 1

        input_quantizer = quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT]
        qkv_quantizer = quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_INPUT]
        qkv_weight_quantizer = quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT]
        o_quantizer = quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_OUTPUT]
        dO_quantizer = quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT1]
        dQKV_quantizer = quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_INPUT1]
        # NOTE(review): S is a forward-pass tensor but its quantizer comes from the
        # "scaling_bwd" group; presumably intentional, confirm.
        s_quantizer = quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT2]
        dP_quantizer = quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT3]

        inp_fp8 = input_quantizer(inp)
        qkv_weight_fp8 = qkv_weight_quantizer(qkv_weight)

        # QKV GEMM in FP8; the result is quantized directly by qkv_quantizer.
        qkv, *_ = ext.general_gemm(
            qkv_weight_fp8,
            inp_fp8,
            workspace,
            bias=qkv_bias,
            out_dtype=qkv_weight_fp8.dtype,
            quantization_params=qkv_quantizer,
            use_split_accumulator=_2X_ACC_FPROP,
        )
        qkv = qkv.view(-1, 3, h, d)
        # Dequantized copy consumed by the F16 reference run.
        qkv_fp16 = qkv.dequantize().view(b, max_s, 3, h, d).contiguous()
        torch.save(qkv_fp16, "qkv.pt")
        if cudnn_frontend_version == 1:
            qkv = qkv.view(b, max_s, 3, h, d)  # bs3hd

        # FMHA: slice Q/K/V out of the packed FP8 buffer and wrap each slice like qkv.
        q_data = qkv._data[:, :, 0, :, :] if cudnn_frontend_version == 1 else qkv._data[:, 0, :, :]
        k_data = qkv._data[:, :, 1, :, :] if cudnn_frontend_version == 1 else qkv._data[:, 1, :, :]
        v_data = qkv._data[:, :, 2, :, :] if cudnn_frontend_version == 1 else qkv._data[:, 2, :, :]
        q = qkv.make_like(tensor=qkv, data=q_data, shape=q_data.shape)
        k = qkv.make_like(tensor=qkv, data=k_data, shape=k_data.shape)
        v = qkv.make_like(tensor=qkv, data=v_data, shape=v_data.shape)

        out, aux_ctx_tensors = fused_attn_fwd(
            is_training,
            max_s,
            max_s,
            cu_seqlens,
            cu_seqlens,
            q,
            k,
            v,
            qkv_dtype,
            FusedAttnBackend["FP8"],
            attn_scale=None,
            dropout=p_dropout,
            fast_zero_fill=fast_zero_fill,
            qkv_layout="bs3hd" if cudnn_frontend_version == 1 else "t3hd",
            attn_bias_type="no_bias",
            attn_mask_type=mask_type if cudnn_frontend_version == 1 else "padding",
            rng_gen=None,
            o_quantizer=o_quantizer,
            s_quantizer=s_quantizer,
        )

        tensors_to_save, tensor_objects = prepare_for_saving(
            q, k, v, inp_fp8, qkv_weight_fp8, workspace, out
        )

        ctx.save_for_backward(*tensors_to_save)
        ctx.tensor_objects = tensor_objects
        ctx.aux_ctx_tensors = aux_ctx_tensors
        ctx.qkv_dtype = qkv_dtype
        ctx.fp8_meta = fp8_meta
        ctx.cu_seqlens = cu_seqlens
        ctx.p_dropout = p_dropout
        ctx.max_s = max_s
        ctx.fast_zero_fill = fast_zero_fill
        ctx.hidden_size = in_features
        ctx.num_heads = num_heads
        ctx.mask_type = mask_type
        ctx.dtype = inp.dtype

        ctx.dQKV_quantizer = dQKV_quantizer
        ctx.dO_quantizer = dO_quantizer
        ctx.dP_quantizer = dP_quantizer
        ctx.S_quantizer = s_quantizer

        out = out.view(-1, in_features)  # (bs)(hd)
        out_fp16 = out.dequantize()
        torch.save(out_fp16, "out.pt")  # (bs)(hd)
        return out_fp16

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
        """Backpropagate through FP8 fused attention and the QKV GEMM.

        Returns one entry per ``forward`` input: gradients for ``inp``, ``qkv_weight``
        and ``qkv_bias``, and ``None`` for the ten non-differentiable arguments.
        """
        with torch.cuda.nvtx.range("_DPA"):
            saved_tensors = ctx.saved_tensors
            (q, k, v, inp_fp8, qkv_weight_fp8, workspace, out) = restore_from_saved(
                ctx.tensor_objects, saved_tensors
            )

            proj_dgrad = ctx.dO_quantizer(grad_output)
            fp8_dtype_backward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=False)

            dq, dk, dv, *rest = fused_attn_bwd(
                ctx.max_s,
                ctx.max_s,
                ctx.cu_seqlens,
                ctx.cu_seqlens,
                q,
                k,
                v,
                out,
                proj_dgrad.view_as(out),
                ctx.qkv_dtype,
                fp8_dtype_backward,
                ctx.aux_ctx_tensors,
                FusedAttnBackend["FP8"],
                None,
                None,
                ctx.S_quantizer,
                ctx.dP_quantizer,
                ctx.dQKV_quantizer,
                attn_scale=None,
                dropout=ctx.p_dropout,
                fast_zero_fill=ctx.fast_zero_fill,
                qkv_layout="bs3hd" if cudnn_frontend_version == 1 else "t3hd",
                attn_bias_type="no_bias",
                attn_mask_type=ctx.mask_type if cudnn_frontend_version == 1 else "padding",
            )
            # Build a packed dQKV view over dq's storage by inserting a size-3 dimension
            # whose stride is a third of the original stride at position -3.
            dim = 2 if cudnn_frontend_version == 1 else 1
            dqkv = torch.Tensor().to(device=dq._data.device, dtype=dq._data.dtype)
            dqkv_shape = list(dq._data.shape)
            dqkv_shape.insert(dim, 3)
            dqkv_stride = list(dq._data.stride())
            dqkv_stride.insert(dim, dqkv_stride[-3] // 3)
            dqkv.set_(
                dq._data.untyped_storage(), dq._data.storage_offset(), dqkv_shape, dqkv_stride
            )  # bs3hd

            dqkv_c = dqkv.view(-1, 3 * ctx.hidden_size)
            dqkv_c = dq.make_like(tensor=dq, data=dqkv_c, shape=dqkv_c.shape)
            dqkv_c_fp16 = dqkv_c.dequantize()
            torch.save(dqkv_c_fp16, "dqkv.pt")

            qkv_bgrad, dqkv = ext.bgrad_quantize(dqkv_c_fp16, ctx.dQKV_quantizer)
            dqkv_c._transpose = None
            dqkv_c._create_transpose()

            # QKV DGRAD
            qkv_dgrad, *_ = ext.general_gemm(
                qkv_weight_fp8,
                dqkv_c,
                workspace,
                ctx.dtype,
                use_split_accumulator=_2X_ACC_DGRAD,
                layout="NN",
            )

            # QKV WGRAD
            qkv_wgrad, *_ = ext.general_gemm(
                inp_fp8,
                dqkv,
                workspace,
                ctx.dtype,
                use_split_accumulator=_2X_ACC_WGRAD,
                layout="NT",
            )

        return (
            qkv_dgrad,
            qkv_wgrad,
            qkv_bgrad,
            None,  # cu_seqlens
            None,  # num_heads
            None,  # p_dropout
            None,  # max_s
            None,  # fast_zero_fill
            None,  # fp8_meta
            None,  # workspace
            None,  # is_training
            None,  # mask_type
            None,  # quantizers
        )


class Custom_MHA_FP8(TransformerEngineBaseModule):
    """Minimal test module wrapping ``_custom_mha_fp8`` (FP8 QKV GEMM + FP8 attention).

    Parameters use fixed values (QKV weight all ones, bias all zeros), so runs are
    reproducible.
    """

    def __init__(self, config, params_dtype: torch.dtype = torch.float32):
        """Create parameters and workspace from a test ``config``.

        Args:
            config: Supplies dropout, head count, hidden size, head dim and mask type.
            params_dtype: Dtype of the QKV weight and bias.
        """
        super().__init__()
        self.p_dropout = config.dropout_p
        self.h = config.num_heads
        self.hidden_size = config.hidden_size
        self.head_dim = config.head_dim_qk
        self.fast_zero_fill = True
        self.mask_type = config.attn_mask_type

        # Fused QKV projection: (3 * hidden, hidden) weight and 3 * hidden bias.
        self.qkv_weight = torch.nn.Parameter(
            torch.empty(
                self.hidden_size * 3,
                self.hidden_size,
                device=torch.cuda.current_device(),
                dtype=params_dtype,
            )
        )
        self.qkv_bias = torch.nn.Parameter(
            torch.empty(
                self.hidden_size * 3,
                device=torch.cuda.current_device(),
                dtype=params_dtype,
            )
        )
        with torch.no_grad():
            self.qkv_bias.zero_()
            self.qkv_weight.fill_(1.0)
        # Workspace buffer handed to every GEMM in _custom_mha_fp8.
        self.workspace = torch.empty(
            _CUBLASLT_WORKSPACE_SIZE_BYTES, dtype=torch.int8, device="cuda"
        )

    def forward(
        self,
        inp: torch.Tensor,
        cu_seqlens,
        max_s,
    ) -> torch.Tensor:
        """Run the FP8 QKV projection and attention on a 2-D ``inp``.

        ``prepare_forward`` is given ``num_gemms=3``; the autograd function indexes
        quantizer slots up to GEMM3 / GRAD_*3 (presumably what that count provisions).
        """
        with self.prepare_forward(inp, num_gemms=3) as inp:
            out = _custom_mha_fp8.apply(
                inp,
                self.qkv_weight,
                self.qkv_bias,
                cu_seqlens,
                self.h,
                self.p_dropout,
                max_s,
                self.fast_zero_fill,
                self.fp8_meta,
                self.workspace,
                self.training,
                self.mask_type,
                self.quantizers,
            )
        return out