# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

# Standard library
import os
import subprocess
import sys
import time

# Third-party
import numpy as np
import pandas as pd
import torch

import nvtx
import transformer_engine
from tests.pytorch.fused_attn.test_fused_attn import (
    ModelConfig,
    _get_attention_backends,
    _run_dot_product_attention,
)

# Data type for Q/K/V tensors and the attention computation.
dtype = torch.bfloat16
# Number of timed iterations after 3 warmup iterations.
num_iters = 3
# Whether to checkpoint attention activations (recompute in backward).
ckpt_attn = False
# Enable the workspace optimization path for cuDNN attention.
workspace_opt = True
# QKV memory layout.
qkv_layout = "bshd_bshd_bshd"
# Sliding window attention.
swa = False
# Insert padding between sequences for qkv_format=thd.
pad_between_seqs = False
# Training mode (run the backward pass as well as the forward pass).
is_training = True

# Test configurations keyed by name.
model_configs = {
    #   test:             b,  h, hg,   d,   sq,  skv,   p,     mask,              bias
    "test_0": ModelConfig(2, 16, 16, 64, 512, 512, 0.0, "no_mask", "no_bias"),  # short seq
    "test_1": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "causal", "no_bias"),  # longer seq, mask
    "test_2": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "causal", "post_scale_bias"),  # bias
    "test_3": ModelConfig(2, 32, 4, 128, 8192, 8192, 0.0, "causal", "no_bias"),  # GQA
}

43

44
45
46
47
48
49
50
51
52
def example_attention(model, fused_attn_supported, flash_attn_supported):
    """Run one attention configuration on the available backends and cross-check.

    Args:
        model: key into ``model_configs`` selecting the test configuration.
        fused_attn_supported: whether the cuDNN fused-attention backend is available.
        flash_attn_supported: whether the flash-attention backend is available.

    Raises:
        AssertionError: if both backends ran and their forward/backward outputs
            disagree beyond the dtype-dependent tolerances.
    """
    config = model_configs[model]
    # bf16 gets looser tolerances than other dtypes.
    if dtype == torch.bfloat16:
        tols = dict(atol=2.5e-2, rtol=2.5e-2)
    else:
        tols = dict(atol=5e-3, rtol=5e-3)

    if fused_attn_supported:
        print()
        print("Run cuDNN attention...")
        fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(
            dtype,
            config,
            "FusedAttention",
            ckpt_attn,
            qkv_layout,
            workspace_opt,
            pad_between_seqs,
            is_training,
        )

    if flash_attn_supported:
        print()
        print("Run flash-attention...")
        flash_attn_fwd, flash_attn_bwd = _run_dot_product_attention(
            dtype,
            config,
            "FlashAttention",
            ckpt_attn,
            qkv_layout,
            workspace_opt,
            pad_between_seqs,
            is_training,
        )

    # Compare only when both backends actually produced results.
    if fused_attn_supported and flash_attn_supported:
        torch.testing.assert_close(fused_attn_fwd, flash_attn_fwd, **tols)
        for i, _ in enumerate(flash_attn_bwd):
            torch.testing.assert_close(fused_attn_bwd[i], flash_attn_bwd[i], **tols)

    print()
    print("Test passed.")

87
88
89

def main():
    """Query backend availability for each selected config and run the example."""
    models = ["test_0"]
    for model in models:
        config = model_configs[model]
        # Determine which attention backends support this configuration.
        available_backends, fused_attn_backends = _get_attention_backends(
            config,
            qkv_dtype=dtype,
            qkv_layout=qkv_layout,
            window_size=config.window_size,
            pad_between_seqs=pad_between_seqs,
        )
        flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends

        example_attention(model, fused_attn_supported, flash_attn_supported)


if __name__ == "__main__":
    main()