# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Benchmark DotProductAttention with the cuDNN fused attention and flash-attention backends."""

import os
import sys
import time
import subprocess

import pandas as pd
import numpy as np
import torch
import nvtx
import transformer_engine
from tests.pytorch.fused_attn.test_fused_attn import (
    ModelConfig,
    _is_flash_attention_supported,
    _is_fused_attention_supported,
    _is_unfused_attention_supported,
    _run_dot_product_attention,
)

pd.set_option("display.precision", 4)
# data type
dtype = torch.bfloat16
# number of timed iterations (run after 3 warmup iterations)
num_iters = 3
# checkpointing
ckpt_attn = False
# workspace optimization path for cuDNN attention
workspace_opt = True
# QKV memory layout
qkv_layout = "bshd_bshd_bshd"
# sliding window attention
swa = False
# padding between sequences for qkv_format=thd
pad_between_seqs = False
# training mode
is_training = True

model_configs = {
    #   test:             b,  h, hg,   d,   sq,  skv,   p,    mask,      bias
    "test_0": ModelConfig(2, 16, 16, 64, 512, 512, 0.0, "no_mask", "no_bias"),  # short seq
    "test_1": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "causal", "no_bias"),  # longer seq, mask
    "test_2": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "causal", "post_scale_bias"),  # bias
    "test_3": ModelConfig(2, 32, 4, 128, 8192, 8192, 0.0, "causal", "no_bias"),  # GQA
}


def benchmark_dot_product_attention(model, fused_attn_supported, flash_attn_supported):
    """Run the FusedAttention and FlashAttention backends and append their per-iteration
    module times (in ms) to times.csv; kernel times are filled in later by parse_results()."""
    config = model_configs[model]
    if dtype == torch.bfloat16:
        tols = dict(atol=2.5e-2, rtol=2.5e-2)
    else:
        tols = dict(atol=5e-3, rtol=5e-3)

    # warmup; also cross-check that the two backends agree within tolerance
    warmup_iters = 3
    for _ in range(warmup_iters):
        if fused_attn_supported:
            fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(
                dtype,
                config,
                "FusedAttention",
                ckpt_attn,
                qkv_layout,
                workspace_opt,
                swa,
                pad_between_seqs,
                is_training,
            )
        if flash_attn_supported:
            flash_attn_fwd, flash_attn_bwd = _run_dot_product_attention(
                dtype,
                config,
                "FlashAttention",
                ckpt_attn,
                qkv_layout,
                workspace_opt,
                swa,
                pad_between_seqs,
                is_training,
            )
        if fused_attn_supported and flash_attn_supported:
            torch.testing.assert_close(fused_attn_fwd, flash_attn_fwd, **tols)
            for i, _ in enumerate(flash_attn_bwd):
                torch.testing.assert_close(fused_attn_bwd[i], flash_attn_bwd[i], **tols)

    # start the nsys capture range after warmup so only timed iterations are profiled
    torch.cuda.cudart().cudaProfilerStart()
    torch.cuda.synchronize()
    fused_attn_start = time.time()
    if fused_attn_supported:
        for _ in range(num_iters):
            fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(
                dtype,
                config,
                "FusedAttention",
                ckpt_attn,
                qkv_layout,
                workspace_opt,
                swa,
                pad_between_seqs,
                is_training,
            )
    torch.cuda.synchronize()
    fused_attn_time = (time.time() - fused_attn_start) if fused_attn_supported else 0

    torch.cuda.synchronize()
    flash_attn_start = time.time()
    if flash_attn_supported:
        for _ in range(num_iters):
            flash_attn_fwd, flash_attn_bwd = _run_dot_product_attention(
                dtype,
                config,
                "FlashAttention",
                ckpt_attn,
                qkv_layout,
                workspace_opt,
                swa,
                pad_between_seqs,
                is_training,
            )
    torch.cuda.synchronize()
    flash_attn_time = (time.time() - flash_attn_start) if flash_attn_supported else 0

    df = pd.read_csv("times.csv")
    df = pd.concat(
        [
            df,
            pd.DataFrame(
                [
                    [
                        fused_attn_time * 1e3 / num_iters,
                        0,
                        0,
                        0,
                        flash_attn_time * 1e3 / num_iters,
                        0,
                        0,
                        0,
                        0,
                    ]
                ],
                columns=df.columns,
            ),
        ],
        ignore_index=True,
    )
    df.to_csv("times.csv", index=False)
    torch.cuda.cudart().cudaProfilerStop()


def parse_results(per_cudnn, per_flash, model):
    """Average the per-kernel times from the nsys GPU trace and fill in the
    kernel-level columns of the last row of times.csv."""
    filename = f"prof_{model}_cuda_gpu_trace.csv"
    df = pd.read_csv(os.path.join("./", filename))
    df_times = pd.read_csv("times.csv")
    row = len(df_times.index) - 1
    if per_cudnn > 0:
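        # each timed iteration launches `per_cudnn` cuDNN kernels; reshape the trace
        # to (iterations, kernels) and average across iterations per kernel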
t_cudnn_all = df[df["Name"].str.contains("cudnn")]["Duration (ns)"].to_numpy() t_cudnn_all = t_cudnn_all.reshape(-1, per_cudnn) t_cudnn_avg = np.average(t_cudnn_all, axis=0) df_times.loc[row, "FusedAttention Kernels (fwd)"] = t_cudnn_avg[0] / 1e6 df_times.loc[row, "FusedAttention Kernels (bwd)"] = t_cudnn_avg[1:4].sum() / 1e6 df_times.loc[row, "FusedAttention Kernels (fwd+bwd)"] = t_cudnn_avg.sum() / 1e6 if per_flash > 0: t_flash_all = df[df["Name"].str.contains("void flash")]["Duration (ns)"].to_numpy() t_flash_all = t_flash_all.reshape(-1, per_flash) t_flash_avg = np.average(t_flash_all, axis=0) df_times.loc[row, "FlashAttention Kernels (fwd)"] = t_flash_avg[0] / 1e6 df_times.loc[row, "FlashAttention Kernels (bwd)"] = t_flash_avg[1:4].sum() / 1e6 df_times.loc[row, "FlashAttention Kernels (fwd+bwd)"] = t_flash_avg.sum() / 1e6 if per_cudnn > 0 and per_flash > 0: df_times.loc[row, "Fused vs Flash Kernels Speedup (fwd+bwd)"] = ( df_times.loc[row, "FlashAttention Kernels (fwd+bwd)"] / df_times.loc[row, "FusedAttention Kernels (fwd+bwd)"] ) df_times.to_csv("times.csv", index=False) def main(): times = pd.DataFrame( columns=[ "FusedAttention Module", "FusedAttention Kernels (fwd)", "FusedAttention Kernels (bwd)", "FusedAttention Kernels (fwd+bwd)", "FlashAttention Module", "FlashAttention Kernels (fwd)", "FlashAttention Kernels (bwd)", "FlashAttention Kernels (fwd+bwd)", "Fused vs Flash Kernels Speedup (fwd+bwd)", ] ) times.to_csv("times.csv", index=False) device_id = torch.cuda.current_device() device_properties = torch.cuda.get_device_properties(device_id) print( f"Device {device_id}: " f"{device_properties.name} GPU, " f"sm{device_properties.major}{device_properties.minor} compute capability, " f"{device_properties.total_memory/1024**3:.1f}GB memory" ) for model in model_configs.keys(): config = model_configs[model] fused_attn_supported, fused_attn_backend = _is_fused_attention_supported( config, dtype, qkv_layout=qkv_layout, ) fused_attn_supported = fused_attn_supported and not swa flash_attn_supported = _is_flash_attention_supported(config) print( f'Running {model} with {"cuDNN attention" if fused_attn_supported else ""}' f'{" and flash-attention" if flash_attn_supported else ""}...' 
        prof_cmd = [
            "nsys",
            "profile",
            "--capture-range=cudaProfilerApi",
            "--capture-range-end=stop-shutdown",
            "--force-overwrite=true",
            f"--output=prof_{model}",
            "python",
            "-c",
            '"import benchmark_attention; '
            "benchmark_attention.benchmark_dot_product_attention("
            f"'{model}', {fused_attn_supported}, {flash_attn_supported})\"",
        ]
        prof_cmd = " ".join(prof_cmd)
        subprocess.call(prof_cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, shell=True)

        # export the GPU trace to csv, i.e. prof_<model>_cuda_gpu_trace.csv
        stats_cmd = [
            "nsys",
            "stats",
            "-q",
            "-r",
            "cuda_gpu_trace",
            "--format",
            "csv,column",
            "--force-overwrite=true",
            "--force-export=true",
            f"--output=prof_{model}",
            f"prof_{model}.nsys-rep",
        ]
        # expected number of attention kernels per iteration in the trace
        if fused_attn_supported:
            num_kernels_cudnn = 4
            # post_scale_bias launches one extra kernel per iteration
            if config.attn_bias_type == "post_scale_bias":
                num_kernels_cudnn = num_kernels_cudnn + 1
            # GQA launches two extra kernels per iteration
            if config.num_heads != config.num_gqa_groups:
                num_kernels_cudnn = num_kernels_cudnn + 2
        else:
            num_kernels_cudnn = 0
        num_kernels_flash = 4 if flash_attn_supported else 0
        stats_cmd = " ".join(stats_cmd)
        subprocess.call(stats_cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, shell=True)

        parse_cmd = [
            "python",
            "-c",
            '"import benchmark_attention; '
            "benchmark_attention.parse_results("
            f"{num_kernels_cudnn}, {num_kernels_flash}, '{model}')\"",
        ]
        parse_cmd = " ".join(parse_cmd)
        subprocess.call(parse_cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, shell=True)

    df_times = pd.read_csv("times.csv")
    df_times.index = list(model_configs.keys())
    a = df_times[
        [
            "FusedAttention Kernels (fwd+bwd)",
            "FlashAttention Kernels (fwd+bwd)",
            "Fused vs Flash Kernels Speedup (fwd+bwd)",
        ]
    ]
    a.columns = ["cuDNN fwd+bwd (ms)", "flash-attn fwd+bwd (ms)", "cuDNN vs flash speedup"]
    print()
    print(a)


if __name__ == "__main__":
    main()