benchmark_attention.py 8.9 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import os, sys, time
import subprocess
import pandas as pd
import numpy as np
import torch
import nvtx
import transformer_engine
from tests.pytorch.fused_attn.test_fused_attn import (
    ModelConfig,
    _is_flash_attention_supported,
    _is_fused_attention_supported,
    _is_unfused_attention_supported,
    _run_dot_product_attention
)

# Print floats in pandas output (e.g. the final summary table) with 4 digits
# after the decimal point.
pd.set_option("display.precision", 4)

# ---------------------------------------------------------------------------
# Module-level benchmark settings, read by the functions below.
# ---------------------------------------------------------------------------

# data type of the attention inputs passed to _run_dot_product_attention;
# also selects the fused-vs-flash comparison tolerances in
# benchmark_dot_product_attention
dtype = torch.bfloat16
# number of timed iterations per backend; benchmark_dot_product_attention runs
# 3 warm-up iterations before these
num_iters = 3
# checkpointing (forwarded as-is to _run_dot_product_attention)
ckpt_attn = False
# workspace optimization path for cuDNN attention
workspace_opt = True
# QKV memory layout; presumably three separate tensors, each in
# b(atch), s(equence), h(ead), d(im) order — confirm against Transformer Engine docs
qkv_layout = 'bshd_bshd_bshd'
# sliding window attention; when True, main() treats cuDNN attention as unsupported
swa = False
# padding between sequences for qkv_format=thd
pad_between_seqs = False
# training mode (forwarded as-is to _run_dot_product_attention)
is_training = True

# Benchmark cases, keyed by the name used for the nsys output files and the
# summary table. Column legend (inferred from the header below; ModelConfig is
# the authoritative definition):
#   b: batch size, h: number of attention heads, hg: number of GQA groups
#   (h != hg is treated as GQA by main()), d: head dimension,
#   sq / skv: query / key-value sequence lengths,
#   p: presumably dropout probability, mask: attention mask type,
#   bias: attention bias type.
model_configs = {
    #   test:             b,  h, hg,   d,   sq,  skv,   p,     mask,              bias
    "test_0": ModelConfig(2, 16, 16,  64,  512,  512, 0.0, "no_mask",         "no_bias"), # short seq
    "test_1": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0,  "causal",         "no_bias"), # longer seq, mask
    "test_2": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0,  "causal", "post_scale_bias"), # bias
    "test_3": ModelConfig(2, 32,  4, 128, 8192, 8192, 0.0,  "causal",         "no_bias"), # GQA
}

def _run_backend(config, backend):
    """Run one attention pass through ``_run_dot_product_attention``.

    Forwards the module-level settings (``dtype``, ``ckpt_attn``, ``qkv_layout``,
    ``workspace_opt``, ``swa``, ``pad_between_seqs``, ``is_training``) together
    with ``config`` and ``backend`` (e.g. "FusedAttention" or "FlashAttention")
    and returns that helper's result unchanged; callers unpack it as
    ``(fwd_output, bwd_outputs)``.
    """
    return _run_dot_product_attention(
        dtype, config, backend,
        ckpt_attn, qkv_layout, workspace_opt, swa, pad_between_seqs, is_training,
    )


def _time_backend(config, backend, supported):
    """Return wall-clock seconds for ``num_iters`` runs of ``backend``, or 0.

    The CUDA device is synchronized before and after the loop so that queued
    GPU work is included in the measurement. When ``supported`` is False
    nothing is run and 0 is returned.
    """
    torch.cuda.synchronize()
    # perf_counter is monotonic and high-resolution; time.time() can jump if
    # the system clock is adjusted, which corrupts short interval measurements.
    start = time.perf_counter()
    if supported:
        for _ in range(num_iters):
            _run_backend(config, backend)
    torch.cuda.synchronize()
    return time.perf_counter() - start if supported else 0


def benchmark_dot_product_attention(model, fused_attn_supported, flash_attn_supported):
    """Time cuDNN (FusedAttention) and flash-attention on ``model_configs[model]``.

    Runs 3 warm-up iterations of each supported backend. When both backends
    are supported, it checks that their forward outputs and gradients agree
    within dtype-dependent tolerances. Then, between cudaProfilerStart and
    cudaProfilerStop (the capture range used by the nsys command in main()),
    it times ``num_iters`` iterations of each supported backend. Finally it
    appends one row to ``times.csv`` with the average module time per
    iteration in milliseconds (0 for an unsupported backend). The kernel-time
    columns are written as 0.0 for parse_results() to fill in.

    Args:
        model: key into ``model_configs``.
        fused_attn_supported: whether to run the "FusedAttention" backend.
        flash_attn_supported: whether to run the "FlashAttention" backend.

    Raises:
        AssertionError: if fused and flash results differ beyond tolerance.
    """
    config = model_configs[model]
    if dtype == torch.bfloat16:
        tols = dict(atol=2.5e-2, rtol=2.5e-2)
    else:
        tols = dict(atol=5e-3, rtol=5e-3)

    warmup_iters = 3
    for _ in range(warmup_iters):
        if fused_attn_supported:
            fused_attn_fwd, fused_attn_bwd = _run_backend(config, "FusedAttention")
        if flash_attn_supported:
            flash_attn_fwd, flash_attn_bwd = _run_backend(config, "FlashAttention")
        if fused_attn_supported and flash_attn_supported:
            torch.testing.assert_close(fused_attn_fwd, flash_attn_fwd, **tols)
            # Distinct index name so the warm-up loop counter is not shadowed.
            for j, flash_grad in enumerate(flash_attn_bwd):
                torch.testing.assert_close(fused_attn_bwd[j], flash_grad, **tols)

    torch.cuda.cudart().cudaProfilerStart()
    fused_attn_time = _time_backend(config, "FusedAttention", fused_attn_supported)
    flash_attn_time = _time_backend(config, "FlashAttention", flash_attn_supported)

    df = pd.read_csv('times.csv')
    # Float zeros keep the kernel columns float, so parse_results() can store
    # fractional milliseconds without an int64 -> float upcast.
    new_row = pd.DataFrame(
        [[fused_attn_time * 1e3 / num_iters, 0.0, 0.0, 0.0,
          flash_attn_time * 1e3 / num_iters, 0.0, 0.0, 0.0, 0.0]],
        columns=df.columns,
    )
    # pandas >= 2.1 warns when concatenating with an empty frame (its result
    # dtype rules are changing), so the first row is used directly.
    df = new_row if df.empty else pd.concat([df, new_row], ignore_index=True)
    df.to_csv('times.csv', index=False)
    # Stop the profiler only after the CSV row is written: main() runs this
    # function under nsys with --capture-range-end=stop-shutdown, so the
    # process may be shut down as soon as the capture range ends.
    torch.cuda.cudart().cudaProfilerStop()

def _kernel_times_ms(trace, name_pattern, kernels_per_iter):
    """Return average (fwd, bwd, fwd+bwd) kernel times in ms for one backend.

    Selects the rows of the nsys ``cuda_gpu_trace`` table whose ``Name``
    contains ``name_pattern`` (plain substring match, not a regex). Their
    ``Duration (ns)`` values are grouped into consecutive chunks of
    ``kernels_per_iter`` kernels, one chunk per profiled iteration, and each
    kernel position is averaged across chunks. The first kernel of a chunk is
    reported as forward and all remaining kernels as backward, so
    fwd + bwd == fwd+bwd.

    Raises:
        ValueError: if the number of matching kernels is not a multiple of
            ``kernels_per_iter``.
    """
    matches = trace['Name'].str.contains(name_pattern, regex=False, na=False)
    durations = trace.loc[matches, 'Duration (ns)'].to_numpy()
    if durations.size % kernels_per_iter != 0:
        raise ValueError(
            f"found {durations.size} kernels matching {name_pattern!r}, which is "
            f"not a multiple of the expected {kernels_per_iter} per iteration"
        )
    avg_ns = np.average(durations.reshape(-1, kernels_per_iter), axis=0)
    # Every kernel after the first counts as backward. Summing only [1:4]
    # would drop the extra kernels counted for bias / GQA configurations.
    return avg_ns[0] / 1e6, avg_ns[1:].sum() / 1e6, avg_ns.sum() / 1e6


def parse_results(per_cudnn, per_flash, model):
    """Fill the kernel-time columns of the last row of ``times.csv`` from an nsys trace.

    Reads ``prof_{model}_cuda_gpu_trace.csv`` (the ``nsys stats -r
    cuda_gpu_trace`` export) from the current directory. It averages the
    cuDNN kernels (names containing "cudnn") and flash-attention kernels
    (names containing "void flash") per iteration, and writes their
    fwd / bwd / fwd+bwd times in milliseconds. When both backends ran, it also
    writes the flash-to-cuDNN kernel-time ratio.

    Args:
        per_cudnn: cuDNN kernels per iteration; 0 skips cuDNN.
        per_flash: flash-attention kernels per iteration; 0 skips flash.
        model: model name used in the trace file name.

    Raises:
        ValueError: if a backend's kernel count is not a multiple of its
            per-iteration count.
    """
    filename = f'prof_{model}_cuda_gpu_trace.csv'
    trace = pd.read_csv(os.path.join('./', filename))
    # Read every column as float: rows may hold integer zeros, and assigning
    # fractional milliseconds into an int64 column is a deprecated upcast in
    # recent pandas.
    df_times = pd.read_csv('times.csv', dtype=float)
    row = len(df_times.index) - 1

    if per_cudnn > 0:
        fwd, bwd, total = _kernel_times_ms(trace, 'cudnn', per_cudnn)
        df_times.loc[row, 'FusedAttention Kernels (fwd)'] = fwd
        df_times.loc[row, 'FusedAttention Kernels (bwd)'] = bwd
        df_times.loc[row, 'FusedAttention Kernels (fwd+bwd)'] = total

    if per_flash > 0:
        fwd, bwd, total = _kernel_times_ms(trace, 'void flash', per_flash)
        df_times.loc[row, 'FlashAttention Kernels (fwd)'] = fwd
        df_times.loc[row, 'FlashAttention Kernels (bwd)'] = bwd
        df_times.loc[row, 'FlashAttention Kernels (fwd+bwd)'] = total

    if per_cudnn > 0 and per_flash > 0:
        # Ratio > 1 means the cuDNN kernels took less total time than flash.
        df_times.loc[row, 'Fused vs Flash Kernels Speedup (fwd+bwd)'] = (
            df_times.loc[row, 'FlashAttention Kernels (fwd+bwd)']
            / df_times.loc[row, 'FusedAttention Kernels (fwd+bwd)']
        )
    df_times.to_csv('times.csv', index=False)

def _run_silently(cmd, check=True):
    """Run ``cmd`` (an argv list, executed without a shell) with stdout/stderr discarded.

    Because output is discarded, a non-zero exit status is reported on stderr
    when ``check`` is True; the status is returned either way.
    """
    returncode = subprocess.call(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
    if check and returncode != 0:
        print(f"warning: command exited with status {returncode}: {' '.join(cmd)}",
              file=sys.stderr)
    return returncode


def _num_cudnn_kernels(config):
    """Return the number of cuDNN kernels expected per fwd+bwd iteration for ``config``.

    Base count 4, +1 for a "post_scale_bias" bias type, +2 when
    num_heads != num_gqa_groups. NOTE(review): these counts are empirical and
    must match the profiled trace; parse_results() raises if they do not
    divide the number of cuDNN kernels found.
    """
    num_kernels = 4
    if config.attn_bias_type == 'post_scale_bias':
        num_kernels += 1
    if config.num_heads != config.num_gqa_groups:
        num_kernels += 2
    return num_kernels


def main():
    """Benchmark every entry of ``model_configs`` under nsys and print a summary.

    For each model:
      1. Decide which backends are supported (cuDNN is disabled when ``swa``).
      2. Run benchmark_dot_product_attention() in a child interpreter under
         ``nsys profile``, capturing only the cudaProfilerStart/Stop range.
      3. Export the kernel trace with ``nsys stats -r cuda_gpu_trace``.
      4. Run parse_results() in a child interpreter to fill the kernel times
         into ``times.csv``.
    Finally it prints the kernel fwd+bwd times and the cuDNN-vs-flash speedup
    for each model.

    The child interpreters execute ``import benchmark_attention``, so that
    module must be importable from the working directory (or sys.path).
    """
    times = pd.DataFrame(
            columns=[
                'FusedAttention Module',
                'FusedAttention Kernels (fwd)',
                'FusedAttention Kernels (bwd)',
                'FusedAttention Kernels (fwd+bwd)',
                'FlashAttention Module',
                'FlashAttention Kernels (fwd)',
                'FlashAttention Kernels (bwd)',
                'FlashAttention Kernels (fwd+bwd)',
                'Fused vs Flash Kernels Speedup (fwd+bwd)',
                ])
    times.to_csv('times.csv', index=False)

    device_id = torch.cuda.current_device()
    device_properties = torch.cuda.get_device_properties(device_id)
    print(f"Device {device_id}: "
          f"{device_properties.name} GPU, "
          f"sm{device_properties.major}{device_properties.minor} compute capability, "
          f"{device_properties.total_memory/1024**3:.1f}GB memory")

    for model, config in model_configs.items():
        fused_attn_supported, _ = _is_fused_attention_supported(
            config, dtype, qkv_layout=qkv_layout,
        )
        # cuDNN attention is not benchmarked with sliding window attention.
        fused_attn_supported = fused_attn_supported and not swa
        flash_attn_supported = _is_flash_attention_supported(config)
        backends = [
            name for name, supported in (
                ("cuDNN attention", fused_attn_supported),
                ("flash-attention", flash_attn_supported),
            ) if supported
        ]
        print(f"Running {model} with {' and '.join(backends) or 'no supported backend'}...")

        # argv lists (no shell) avoid fragile hand-written quoting, and
        # sys.executable runs the children with this same interpreter rather
        # than whatever "python" resolves to on PATH.
        prof_cmd = [
            "nsys", "profile",
            "--capture-range=cudaProfilerApi",
            "--capture-range-end=stop-shutdown",
            "--force-overwrite=true",
            f"--output=prof_{model}",
            sys.executable, "-c",
            "import benchmark_attention; "
            "benchmark_attention.benchmark_dot_product_attention("
            f"{model!r}, {fused_attn_supported}, {flash_attn_supported})",
        ]
        # NOTE(review): the profiling run's exit status is not checked. With
        # --capture-range-end=stop-shutdown the application is shut down when
        # the capture range ends, so a non-zero status may not indicate
        # failure — confirm before adding a check.
        _run_silently(prof_cmd, check=False)

        stats_cmd = [
            "nsys", "stats", "-q",
            "-r", "cuda_gpu_trace",
            "--format", "csv,column",
            "--force-overwrite=true",
            "--force-export=true",
            f"--output=prof_{model}",
            f"prof_{model}.nsys-rep",
        ]
        _run_silently(stats_cmd)

        num_kernels_cudnn = _num_cudnn_kernels(config) if fused_attn_supported else 0
        num_kernels_flash = 4 if flash_attn_supported else 0
        parse_cmd = [
            sys.executable, "-c",
            "import benchmark_attention; "
            "benchmark_attention.parse_results("
            f"{num_kernels_cudnn}, {num_kernels_flash}, {model!r})",
        ]
        _run_silently(parse_cmd)

    # One row per model, in model_configs order (one appended per benchmark run).
    df_times = pd.read_csv('times.csv')
    df_times.index = list(model_configs.keys())
    summary = df_times[['FusedAttention Kernels (fwd+bwd)',
                        'FlashAttention Kernels (fwd+bwd)',
                        'Fused vs Flash Kernels Speedup (fwd+bwd)']]
    summary.columns = ['cuDNN fwd+bwd (ms)', 'flash-attn fwd+bwd (ms)', 'cuDNN vs flash speedup']
    print()
    print(summary)


if __name__ == "__main__":
    main()