# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import os
import sys
import time
import subprocess

import pandas as pd
import numpy as np
import torch
import nvtx
import transformer_engine

from tests.pytorch.utils import (
    ModelConfig,
    get_available_attention_backends,
)
from tests.pytorch.attention.test_attention import _run_dot_product_attention

# Show 4 decimal places in the final printed summary table.
pd.set_option("display.precision", 4)

# --- benchmark configuration ---
# data type
dtype = torch.bfloat16
# number of timed iterations after 3 warmup iterations
num_iters = 3
# checkpointing
ckpt_attn = False
# workspace optimization path for cuDNN attention
workspace_opt = True
# QKV memory layout
qkv_layout = "bshd_bshd_bshd"
# padding between sequences for qkv_format=thd
pad_between_seqs = False
# training mode
is_training = True

# Model configurations to benchmark:
#   test:               b,  h, hg,   d,   sq,  skv,   p,     mask,              bias
model_configs = {
    "test_0": ModelConfig(2, 16, 16, 64, 512, 512, 0.0, "no_mask", "no_bias"),  # short seq
    "test_1": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "causal", "no_bias"),  # longer seq, mask
    "test_2": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "causal", "post_scale_bias"),  # bias
    "test_3": ModelConfig(2, 32, 4, 128, 8192, 8192, 0.0, "causal", "no_bias"),  # GQA
}

43

44
45
46
def benchmark_dot_product_attention(model, fused_attn_supported, flash_attn_supported):
    """Benchmark cuDNN fused attention and/or flash-attention for one model.

    Runs 3 warmup iterations (cross-checking fused vs flash results when both
    backends are available), then times ``num_iters`` iterations per backend
    and appends the module-level wall-clock times (ms per iteration) as a new
    row of ``times.csv``. The per-kernel columns are written as 0 here; they
    are filled in later by ``parse_results`` from the nsys trace. The timed
    region is wrapped in cudaProfilerStart/Stop so that
    ``nsys profile --capture-range=cudaProfilerApi`` records only it.

    Args:
        model: Key into the module-level ``model_configs`` dict.
        fused_attn_supported: Whether to run the cuDNN FusedAttention backend.
        flash_attn_supported: Whether to run the FlashAttention backend.
    """
    config = model_configs[model]
    # bf16 has fewer mantissa bits than fp16/fp32, so use looser tolerances.
    if dtype == torch.bfloat16:
        tols = dict(atol=2.5e-2, rtol=2.5e-2)
    else:
        tols = dict(atol=5e-3, rtol=5e-3)

    warmup_iters = 3
    for _ in range(warmup_iters):
        if fused_attn_supported:
            fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(
                dtype,
                config,
                "FusedAttention",
                ckpt_attn,
                qkv_layout,
                workspace_opt,
                pad_between_seqs,
                is_training,
            )
        if flash_attn_supported:
            flash_attn_fwd, flash_attn_bwd = _run_dot_product_attention(
                dtype,
                config,
                "FlashAttention",
                ckpt_attn,
                qkv_layout,
                workspace_opt,
                pad_between_seqs,
                is_training,
            )
        # Sanity check: the two backends must agree numerically (fwd and all
        # bwd gradients) before we trust the timing comparison.
        if fused_attn_supported and flash_attn_supported:
            torch.testing.assert_close(fused_attn_fwd, flash_attn_fwd, **tols)
            for grad_idx, _ in enumerate(flash_attn_bwd):
                torch.testing.assert_close(
                    fused_attn_bwd[grad_idx], flash_attn_bwd[grad_idx], **tols
                )

    # Only the region between these two calls is captured when this function
    # is run under `nsys profile --capture-range=cudaProfilerApi`.
    torch.cuda.cudart().cudaProfilerStart()
    torch.cuda.synchronize()
    fused_attn_start = time.time()
    if fused_attn_supported:
        for _ in range(num_iters):
            fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(
                dtype,
                config,
                "FusedAttention",
                ckpt_attn,
                qkv_layout,
                workspace_opt,
                pad_between_seqs,
                is_training,
            )
    torch.cuda.synchronize()
    fused_attn_time = time.time() - fused_attn_start if fused_attn_supported else 0

    torch.cuda.synchronize()
    flash_attn_start = time.time()
    if flash_attn_supported:
        for _ in range(num_iters):
            flash_attn_fwd, flash_attn_bwd = _run_dot_product_attention(
                dtype,
                config,
                "FlashAttention",
                ckpt_attn,
                qkv_layout,
                workspace_opt,
                pad_between_seqs,
                is_training,
            )
    torch.cuda.synchronize()
    flash_attn_time = time.time() - flash_attn_start if flash_attn_supported else 0

    # Append module-level times in ms/iter; kernel-level columns stay 0 until
    # parse_results fills them from the profiler trace.
    df = pd.read_csv("times.csv")
    df = pd.concat(
        [
            df,
            pd.DataFrame(
                [
                    [
                        fused_attn_time * 1e3 / num_iters,
                        0,
                        0,
                        0,
                        flash_attn_time * 1e3 / num_iters,
                        0,
                        0,
                        0,
                        0,
                    ]
                ],
                columns=df.columns,
            ),
        ],
        ignore_index=True,
    )
    df.to_csv("times.csv", index=False)
    torch.cuda.cudart().cudaProfilerStop()

143

144
def parse_results(per_cudnn, per_flash, model):
    """Fold per-kernel GPU times from the nsys trace into times.csv.

    Reads ``prof_{model}_cuda_gpu_trace.csv`` (one row per kernel launch),
    averages each kernel slot across iterations, and writes the fwd / bwd /
    total kernel times in ms into the last row of ``times.csv``.

    Args:
        per_cudnn: Kernels launched per iteration by cuDNN attention (0 = skip).
        per_flash: Kernels launched per iteration by flash-attention (0 = skip).
        model: Model key used to locate the trace CSV.
    """
    trace = pd.read_csv(os.path.join("./", f"prof_{model}_cuda_gpu_trace.csv"))
    timings = pd.read_csv("times.csv")
    last = len(timings.index) - 1

    def _fill(name_substr, kernels_per_iter, prefix):
        # Rows whose kernel name matches the backend; reshape to
        # (iterations, kernels_per_iter) and average each slot across
        # iterations. Durations are ns; /1e6 converts to ms.
        durations = trace[trace["Name"].str.contains(name_substr)]["Duration (ns)"].to_numpy()
        slot_avg = np.average(durations.reshape(-1, kernels_per_iter), axis=0)
        # Slot 0 is the forward kernel; slots 1-3 are treated as backward.
        timings.loc[last, f"{prefix} Kernels (fwd)"] = slot_avg[0] / 1e6
        timings.loc[last, f"{prefix} Kernels (bwd)"] = slot_avg[1:4].sum() / 1e6
        timings.loc[last, f"{prefix} Kernels (fwd+bwd)"] = slot_avg.sum() / 1e6

    if per_cudnn > 0:
        _fill("cudnn", per_cudnn, "FusedAttention")
    if per_flash > 0:
        _fill("flash", per_flash, "FlashAttention")

    if per_cudnn > 0 and per_flash > 0:
        timings.loc[last, "Fused vs Flash Kernels Speedup (fwd+bwd)"] = (
            timings.loc[last, "FlashAttention Kernels (fwd+bwd)"]
            / timings.loc[last, "FusedAttention Kernels (fwd+bwd)"]
        )
    timings.to_csv("times.csv", index=False)

173
174
175

def main():
    """Profile every config in ``model_configs`` with nsys and print a summary.

    For each model, three subprocesses are spawned:
      1. ``nsys profile`` running ``benchmark_dot_product_attention`` in a
         fresh Python interpreter (capture limited to the cudaProfilerApi
         range set inside that function),
      2. ``nsys stats`` exporting the GPU kernel trace to CSV,
      3. a Python child running ``parse_results`` to fold per-kernel times
         into times.csv.
    Finally times.csv is read back and a cuDNN-vs-flash comparison is printed.
    """
    # Fresh results table; the profiled children append one row per model.
    times = pd.DataFrame(
        columns=[
            "FusedAttention Module",
            "FusedAttention Kernels (fwd)",
            "FusedAttention Kernels (bwd)",
            "FusedAttention Kernels (fwd+bwd)",
            "FlashAttention Module",
            "FlashAttention Kernels (fwd)",
            "FlashAttention Kernels (bwd)",
            "FlashAttention Kernels (fwd+bwd)",
            "Fused vs Flash Kernels Speedup (fwd+bwd)",
        ]
    )
    times.to_csv("times.csv", index=False)

    device_id = torch.cuda.current_device()
    device_properties = torch.cuda.get_device_properties(device_id)
    print(
        f"Device {device_id}: "
        f"{device_properties.name} GPU, "
        f"sm{device_properties.major}{device_properties.minor} compute capability, "
        f"{device_properties.total_memory/1024**3:.1f}GB memory"
    )
    for model in model_configs.keys():
        config = model_configs[model]
        # Which backends can run this config (dtype/layout/window/padding)?
        available_backends, _, fused_attn_backends = get_available_attention_backends(
            config,
            qkv_dtype=dtype,
            qkv_layout=qkv_layout,
            window_size=config.window_size,
            pad_between_seqs=pad_between_seqs,
        )
        flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends

        print(
            f'Running {model} with {"cuDNN attention" if fused_attn_supported else ""}'
            f'{" and flash-attention" if flash_attn_supported else ""}...'
        )

        # The nested quotes make the whole `python -c "..."` argument survive
        # the shell=True invocation below as a single string; the child
        # re-imports this module and runs the benchmark in a new process.
        prof_cmd = [
            "nsys",
            "profile",
            "--capture-range=cudaProfilerApi",
            "--capture-range-end=stop-shutdown",
            "--force-overwrite=true",
            f"--output=prof_{model}",
            "python",
            "-c",
            f""" "import benchmark_attention;""",
            f"""benchmark_attention.benchmark_dot_product_attention("""
            f"""'{model}', {fused_attn_supported}, {flash_attn_supported})" """,
        ]
        prof_cmd = " ".join(prof_cmd)
        subprocess.call(prof_cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, shell=True)
        # Export the kernel-level GPU trace to prof_{model}_cuda_gpu_trace.csv.
        stats_cmd = [
            "nsys",
            "stats",
            "-q",
            "-r",
            "cuda_gpu_trace",
            "--format",
            "csv,column",
            "--force-overwrite=true",
            "--force-export=true",
            f"--output=prof_{model}",
            f"prof_{model}.nsys-rep",
        ]
        # Kernels launched per fwd+bwd iteration, used by parse_results to
        # regroup the flat trace. NOTE(review): the base count of 4 (plus 1
        # for post_scale_bias, plus 2 for GQA) is empirical and may depend on
        # the cuDNN version — confirm if trace parsing misaligns.
        if fused_attn_supported:
            num_kernels_cudnn = 4
            if config.attn_bias_type == "post_scale_bias":
                num_kernels_cudnn = num_kernels_cudnn + 1
            if config.num_heads != config.num_gqa_groups:
                num_kernels_cudnn = num_kernels_cudnn + 2
        else:
            num_kernels_cudnn = 0
        num_kernels_flash = 4 if flash_attn_supported else 0
        stats_cmd = " ".join(stats_cmd)
        subprocess.call(stats_cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, shell=True)
        # Parse in a child process too, so this process never holds the CSVs.
        parse_cmd = [
            "python",
            "-c",
            f""" "import benchmark_attention;""",
            f"""benchmark_attention.parse_results("""
            f"""{num_kernels_cudnn}, {num_kernels_flash}, '{model}')" """,
        ]
        parse_cmd = " ".join(parse_cmd)
        subprocess.call(parse_cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, shell=True)

    # Summarize: one row per model, kernel-level fwd+bwd times and speedup.
    df_times = pd.read_csv("times.csv")
    df_times.index = list(model_configs.keys())
    a = df_times[
        [
            "FusedAttention Kernels (fwd+bwd)",
            "FlashAttention Kernels (fwd+bwd)",
            "Fused vs Flash Kernels Speedup (fwd+bwd)",
        ]
    ]
    a.columns = ["cuDNN fwd+bwd (ms)", "flash-attn fwd+bwd (ms)", "cuDNN vs flash speedup"]
    print()
    print(a)

276

277
278
# Entry point: build times.csv and profile every config in model_configs.
if __name__ == "__main__":
    main()