# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

# Standard library
import os
import subprocess
import sys
import time

# Third-party
import numpy as np
import pandas as pd
import torch

# NVIDIA / project-local
import nvtx
import transformer_engine
from tests.pytorch.fused_attn.test_fused_attn import (
    ModelConfig,
    _is_flash_attention_supported,
    _is_fused_attention_supported,
    _is_unfused_attention_supported,
    _run_dot_product_attention,
)

# ---- Shared configuration for the attention example runs below ----

# data type used for Q/K/V and outputs
dtype = torch.bfloat16
# number of iterations after 3 warmup iterations
num_iters = 3
# checkpointing
ckpt_attn = False
# workspace optimization path for cuDNN attention
workspace_opt = True
# QKV memory layout
qkv_layout = "bshd_bshd_bshd"
# sliding window attention
swa = False
# padding between sequences for qkv_format=thd
pad_between_seqs = False
# training mode (run backward pass as well as forward)
is_training = True

# Model configurations keyed by test name. Column meanings follow the header
# comment below: batch, heads, kv-heads, head dim, q/kv sequence lengths,
# dropout probability, mask type, bias type (per ModelConfig's signature).
model_configs = {
    #   test:             b,  h, hg,   d,   sq,  skv,   p,     mask,              bias
    "test_0": ModelConfig(2, 16, 16, 64, 512, 512, 0.0, "no_mask", "no_bias"),  # short seq
    "test_1": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "causal", "no_bias"),  # longer seq, mask
    "test_2": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "causal", "post_scale_bias"),  # bias
    "test_3": ModelConfig(2, 32, 4, 128, 8192, 8192, 0.0, "causal", "no_bias"),  # GQA
}
def example_attention(model, fused_attn_supported, flash_attn_supported):
    """Run dot-product attention for one model config on each available backend.

    Runs cuDNN fused attention and/or flash-attention (whichever the caller
    reports as supported) via ``_run_dot_product_attention``, and when both
    backends ran, asserts their forward outputs and backward gradients agree
    within dtype-dependent tolerances.

    Args:
        model: key into ``model_configs`` selecting the test configuration.
        fused_attn_supported: whether the cuDNN fused-attention backend runs.
        flash_attn_supported: whether the flash-attention backend runs.
    """
    config = model_configs[model]
    # bf16 needs looser tolerances than higher-precision dtypes.
    if dtype == torch.bfloat16:
        tols = dict(atol=2.5e-2, rtol=2.5e-2)
    else:
        tols = dict(atol=5e-3, rtol=5e-3)

    if fused_attn_supported:
        print()
        print("Run cuDNN attention...")
        fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(
            dtype,
            config,
            "FusedAttention",
            ckpt_attn,
            qkv_layout,
            workspace_opt,
            swa,
            pad_between_seqs,
            is_training,
        )

    if flash_attn_supported:
        print()
        print("Run flash-attention...")
        flash_attn_fwd, flash_attn_bwd = _run_dot_product_attention(
            dtype,
            config,
            "FlashAttention",
            ckpt_attn,
            qkv_layout,
            workspace_opt,
            swa,
            pad_between_seqs,
            is_training,
        )

    if fused_attn_supported and flash_attn_supported:
        torch.testing.assert_close(fused_attn_fwd, flash_attn_fwd, **tols)
        # Backward results are a sequence of gradient tensors; compare pairwise.
        for i, _ in enumerate(flash_attn_bwd):
            torch.testing.assert_close(fused_attn_bwd[i], flash_attn_bwd[i], **tols)

    print()
    # NOTE(review): printed even when neither backend ran (nothing was
    # compared in that case) — confirm that is the intended behavior.
    print("Test passed.")

def main():
    """Check backend support for each selected model config and run the
    attention example on the supported backends."""
    models = ["test_0"]
    for model in models:
        config = model_configs[model]
        fused_attn_supported, fused_attn_backend = _is_fused_attention_supported(
            config,
            dtype,
            qkv_layout=qkv_layout,
        )
        # NOTE(review): fused attention is force-disabled whenever sliding-window
        # attention is requested — confirm this matches backend capabilities.
        fused_attn_supported = fused_attn_supported and not swa
        flash_attn_supported = _is_flash_attention_supported(config)

        example_attention(model, fused_attn_supported, flash_attn_supported)
# Script entry point: run the example only when executed directly.
if __name__ == "__main__":
    main()