# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Benchmark DotProductAttention with the cuDNN fused-attention and
flash-attention backends.

For each config in ``model_configs`` the script profiles a fwd+bwd run under
``nsys``, exports the GPU trace to CSV, extracts per-kernel timings, and
prints a cuDNN-vs-flash speedup table. Intermediate results accumulate in
``times.csv`` in the working directory.

NOTE(review): this module must be saved as ``benchmark_attention.py`` — the
``python -c`` subprocess commands below re-import it under that name.
"""

import os
import subprocess
import sys
import time

import numpy as np
import pandas as pd
import torch

# NOTE(review): nvtx appears unused here; transformer_engine is presumably
# imported for its registration side effects — confirm before removing either.
import nvtx
import transformer_engine
from tests.pytorch.utils import (
    ModelConfig,
    get_available_attention_backends,
)
from tests.pytorch.attention.test_attention import _run_dot_product_attention

pd.set_option("display.precision", 4)

# data type
dtype = torch.bfloat16
# number of timed iterations (after 3 warmup iterations)
num_iters = 3
# checkpointing
ckpt_attn = False
# workspace optimization path for cuDNN attention
workspace_opt = True
# QKV memory layout
qkv_layout = "bshd_bshd_bshd"
# padding between sequences for qkv_format=thd
pad_between_seqs = False
# training mode
is_training = True

model_configs = {
    # test: b, h, hg, d, sq, skv, p, mask, bias
    "test_0": ModelConfig(2, 16, 16, 64, 512, 512, 0.0, "no_mask", "no_bias"),  # short seq
    "test_1": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "causal", "no_bias"),  # longer seq, mask
    "test_2": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "causal", "post_scale_bias"),  # bias
    "test_3": ModelConfig(2, 32, 4, 128, 8192, 8192, 0.0, "causal", "no_bias"),  # GQA
}


def benchmark_dot_product_attention(model, fused_attn_supported, flash_attn_supported):
    """Time fwd+bwd attention for one config and append a row to ``times.csv``.

    Warmup iterations also cross-check fused vs flash outputs with
    ``torch.testing.assert_close``. The timed region is bracketed by
    cudaProfilerStart/Stop so that ``nsys --capture-range=cudaProfilerApi``
    records only the measured iterations.

    Args:
        model: key into ``model_configs``.
        fused_attn_supported: whether to run the cuDNN FusedAttention backend.
        flash_attn_supported: whether to run the FlashAttention backend.
    """
    config = model_configs[model]
    if dtype == torch.bfloat16:
        tols = dict(atol=2.5e-2, rtol=2.5e-2)
    else:
        tols = dict(atol=5e-3, rtol=5e-3)

    # Warmup, with a numerical cross-check between the two backends.
    warmup_iters = 3
    for _ in range(warmup_iters):
        if fused_attn_supported:
            fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(
                dtype,
                config,
                "FusedAttention",
                ckpt_attn,
                qkv_layout,
                workspace_opt,
                pad_between_seqs,
                is_training,
            )
        if flash_attn_supported:
            flash_attn_fwd, flash_attn_bwd = _run_dot_product_attention(
                dtype,
                config,
                "FlashAttention",
                ckpt_attn,
                qkv_layout,
                workspace_opt,
                pad_between_seqs,
                is_training,
            )
        if fused_attn_supported and flash_attn_supported:
            torch.testing.assert_close(fused_attn_fwd, flash_attn_fwd, **tols)
            # j avoids shadowing an outer loop index (original reused `i`)
            for j, _ in enumerate(flash_attn_bwd):
                torch.testing.assert_close(fused_attn_bwd[j], flash_attn_bwd[j], **tols)

    torch.cuda.cudart().cudaProfilerStart()
    torch.cuda.synchronize()
    fused_attn_start = time.time()
    if fused_attn_supported:
        for _ in range(num_iters):
            fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(
                dtype,
                config,
                "FusedAttention",
                ckpt_attn,
                qkv_layout,
                workspace_opt,
                pad_between_seqs,
                is_training,
            )
    torch.cuda.synchronize()
    fused_attn_time = time.time() - fused_attn_start if fused_attn_supported else 0

    torch.cuda.synchronize()
    flash_attn_start = time.time()
    if flash_attn_supported:
        for _ in range(num_iters):
            flash_attn_fwd, flash_attn_bwd = _run_dot_product_attention(
                dtype,
                config,
                "FlashAttention",
                ckpt_attn,
                qkv_layout,
                workspace_opt,
                pad_between_seqs,
                is_training,
            )
    torch.cuda.synchronize()
    flash_attn_time = time.time() - flash_attn_start if flash_attn_supported else 0

    # Append module-level times in ms/iter; kernel-level columns stay 0 here
    # and are filled in later by parse_results() from the nsys trace.
    df = pd.read_csv("times.csv")
    df = pd.concat(
        [
            df,
            pd.DataFrame(
                [
                    [
                        fused_attn_time * 1e3 / num_iters,
                        0,
                        0,
                        0,
                        flash_attn_time * 1e3 / num_iters,
                        0,
                        0,
                        0,
                        0,
                    ]
                ],
                columns=df.columns,
            ),
        ],
        ignore_index=True,
    )
    df.to_csv("times.csv", index=False)
    torch.cuda.cudart().cudaProfilerStop()


def parse_results(per_cudnn, per_flash, model):
    """Fill the kernel-timing columns of the last ``times.csv`` row from the
    exported nsys GPU trace for *model*.

    Args:
        per_cudnn: number of cuDNN kernels launched per iteration (0 = skip).
        per_flash: number of flash-attention kernels per iteration (0 = skip).
        model: key into ``model_configs``; selects the trace CSV to read.
    """
    filename = f"prof_{model}_cuda_gpu_trace.csv"
    df = pd.read_csv(os.path.join("./", filename))
    df_times = pd.read_csv("times.csv")
    # Target the row benchmark_dot_product_attention() appended last.
    row = len(df_times.index) - 1
    if per_cudnn > 0:
        t_cudnn_all = df[df["Name"].str.contains("cudnn")]["Duration (ns)"].to_numpy()
        # Reshape to (iterations, kernels-per-iteration), then average over
        # iterations. Kernel 0 is treated as fwd, kernels 1..3 as bwd; ns -> ms.
        t_cudnn_all = t_cudnn_all.reshape(-1, per_cudnn)
        t_cudnn_avg = np.average(t_cudnn_all, axis=0)
        df_times.loc[row, "FusedAttention Kernels (fwd)"] = t_cudnn_avg[0] / 1e6
        df_times.loc[row, "FusedAttention Kernels (bwd)"] = t_cudnn_avg[1:4].sum() / 1e6
        df_times.loc[row, "FusedAttention Kernels (fwd+bwd)"] = t_cudnn_avg.sum() / 1e6

    if per_flash > 0:
        t_flash_all = df[df["Name"].str.contains("flash")]["Duration (ns)"].to_numpy()
        t_flash_all = t_flash_all.reshape(-1, per_flash)
        t_flash_avg = np.average(t_flash_all, axis=0)
        df_times.loc[row, "FlashAttention Kernels (fwd)"] = t_flash_avg[0] / 1e6
        df_times.loc[row, "FlashAttention Kernels (bwd)"] = t_flash_avg[1:4].sum() / 1e6
        df_times.loc[row, "FlashAttention Kernels (fwd+bwd)"] = t_flash_avg.sum() / 1e6

    if per_cudnn > 0 and per_flash > 0:
        df_times.loc[row, "Fused vs Flash Kernels Speedup (fwd+bwd)"] = (
            df_times.loc[row, "FlashAttention Kernels (fwd+bwd)"]
            / df_times.loc[row, "FusedAttention Kernels (fwd+bwd)"]
        )
    df_times.to_csv("times.csv", index=False)


def main():
    """Benchmark every config in ``model_configs`` and print a summary table."""
    # Start a fresh results file; the profiled subprocess appends one row per
    # model and parse_results() fills in the kernel columns afterwards.
    times = pd.DataFrame(
        columns=[
            "FusedAttention Module",
            "FusedAttention Kernels (fwd)",
            "FusedAttention Kernels (bwd)",
            "FusedAttention Kernels (fwd+bwd)",
            "FlashAttention Module",
            "FlashAttention Kernels (fwd)",
            "FlashAttention Kernels (bwd)",
            "FlashAttention Kernels (fwd+bwd)",
            "Fused vs Flash Kernels Speedup (fwd+bwd)",
        ]
    )
    times.to_csv("times.csv", index=False)

    device_id = torch.cuda.current_device()
    device_properties = torch.cuda.get_device_properties(device_id)
    print(
        f"Device {device_id}: "
        f"{device_properties.name} GPU, "
        f"sm{device_properties.major}{device_properties.minor} compute capability, "
        f"{device_properties.total_memory/1024**3:.1f}GB memory"
    )
    for model in model_configs.keys():
        config = model_configs[model]
        available_backends, _, fused_attn_backends = get_available_attention_backends(
            config,
            qkv_dtype=dtype,
            qkv_layout=qkv_layout,
            window_size=config.window_size,
            pad_between_seqs=pad_between_seqs,
        )
        flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
        print(
            f'Running {model} with {"cuDNN attention" if fused_attn_supported else ""}'
            f'{" and flash-attention" if flash_attn_supported else ""}...'
        )
        # Profile the benchmark in a subprocess so each model gets its own
        # nsys report. The command strings are built from internal constants
        # only (no untrusted input), so shell=True is acceptable here.
        prof_cmd = [
            "nsys",
            "profile",
            "--capture-range=cudaProfilerApi",
            "--capture-range-end=stop-shutdown",
            "--force-overwrite=true",
            f"--output=prof_{model}",
            "python",
            "-c",
            f""" "import benchmark_attention;""",
            f"""benchmark_attention.benchmark_dot_product_attention("""
            f"""'{model}', {fused_attn_supported}, {flash_attn_supported})" """,
        ]
        prof_cmd = " ".join(prof_cmd)
        subprocess.call(prof_cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, shell=True)

        # Export the GPU trace from the .nsys-rep report to CSV for parse_results().
        stats_cmd = [
            "nsys",
            "stats",
            "-q",
            "-r",
            "cuda_gpu_trace",
            "--format",
            "csv,column",
            "--force-overwrite=true",
            "--force-export=true",
            f"--output=prof_{model}",
            f"prof_{model}.nsys-rep",
        ]
        # Expected cuDNN kernels per iteration: 4 base, +1 with a post-scale
        # bias, +2 when heads != GQA groups (see parse_results reshape).
        if fused_attn_supported:
            num_kernels_cudnn = 4
            if config.attn_bias_type == "post_scale_bias":
                num_kernels_cudnn = num_kernels_cudnn + 1
            if config.num_heads != config.num_gqa_groups:
                num_kernels_cudnn = num_kernels_cudnn + 2
        else:
            num_kernels_cudnn = 0
        num_kernels_flash = 4 if flash_attn_supported else 0
        stats_cmd = " ".join(stats_cmd)
        subprocess.call(stats_cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, shell=True)

        # Run parse_results() via a subprocess as well (same pattern as the
        # profiled run above).
        parse_cmd = [
            "python",
            "-c",
            f""" "import benchmark_attention;""",
            f"""benchmark_attention.parse_results("""
            f"""{num_kernels_cudnn}, {num_kernels_flash}, '{model}')" """,
        ]
        parse_cmd = " ".join(parse_cmd)
        subprocess.call(parse_cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, shell=True)

    df_times = pd.read_csv("times.csv")
    df_times.index = list(model_configs.keys())
    summary = df_times[
        [
            "FusedAttention Kernels (fwd+bwd)",
            "FlashAttention Kernels (fwd+bwd)",
            "Fused vs Flash Kernels Speedup (fwd+bwd)",
        ]
    ]
    summary.columns = ["cuDNN fwd+bwd (ms)", "flash-attn fwd+bwd (ms)", "cuDNN vs flash speedup"]
    print()
    print(summary)


if __name__ == "__main__":
    main()