# Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/benchmarks/benchmark_flash_attention.py
import math
import os

import openpyxl
import torch

from block_sparse_attn.utils.benchmark import benchmark_forward


def benchmark_fwd(
    fn,
    *inputs,
    grad=None,  # unused in this forward-only benchmark; kept for API compatibility
    repeats=10,
    desc="",
    verbose=True,
    amp=False,
    amp_dtype=torch.float16,
    **kwinputs,
):
    """Use PyTorch's benchmark utilities to time the forward pass of an arbitrary function."""
    return benchmark_forward(
        fn,
        *inputs,
        repeats=repeats,
        desc=desc,
        verbose=verbose,
        amp=amp,
        amp_dtype=amp_dtype,
        **kwinputs,
    )


def flops(batch, seqlen, headdim, nheads, causal, mode="fwd"):
    """Theoretical FLOP count of an attention pass; causal masking halves the work."""
    assert mode in ["fwd"]
    f = 4 * batch * seqlen**2 * nheads * headdim // (2 if causal else 1)
    return f if mode == "fwd" else (2.5 * f if mode == "bwd" else 3.5 * f)


def efficiency(flop, time):
    """Convert a FLOP count and a runtime in seconds to TFLOPs/s."""
    return (flop / time / 10**12) if not math.isnan(time) else 0.0


def time_fwd(func, *args, **kwargs):
    """Return the mean forward-pass time in seconds."""
    time_f = benchmark_fwd(func, *args, **kwargs)
    return time_f[1].mean


def write_to_excel(label, data, dir_path, file_name):
    """Write a header row followed by the result rows to <dir_path>/<file_name>.xlsx."""
    workbook = openpyxl.Workbook()
    sheet = workbook.active
    sheet.append(label)
    os.makedirs(dir_path, exist_ok=True)
    for row in data:
        sheet.append(row)
    workbook.save(os.path.join(dir_path, file_name + ".xlsx"))
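

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original file): times a dense
    # scaled-dot-product-attention forward pass with the helpers above and
    # logs TFLOPs/s to an Excel sheet. The shapes, output directory, and file
    # name below are illustrative assumptions, not values from the source.
    import torch.nn.functional as F

    batch, seqlen, nheads, headdim = 2, 1024, 8, 64
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.float16 if device == "cuda" else torch.float32

    q, k, v = (
        torch.randn(batch, nheads, seqlen, headdim, device=device, dtype=dtype)
        for _ in range(3)
    )

    # Mean forward time in seconds over `repeats` runs.
    t = time_fwd(F.scaled_dot_product_attention, q, k, v, repeats=10, verbose=False)
    tflops = efficiency(flops(batch, seqlen, headdim, nheads, causal=False), t)

    write_to_excel(
        ["batch", "seqlen", "nheads", "headdim", "time_s", "TFLOPs/s"],
        [[batch, seqlen, nheads, headdim, t, tflops]],
        "./benchmark_results/",  # hypothetical output directory
        "sdpa_fwd_demo",  # hypothetical file name
    )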