# utils.py
# Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/benchmarks/benchmark_flash_attention.py

import math
import os

import openpyxl
import torch
from block_sparse_attn.utils.benchmark import benchmark_forward


def benchmark_fwd(
    fn,
    *inputs,
    grad=None,  # unused: this helper only times the forward pass
    repeats=10,
    desc="",
    verbose=True,
    amp=False,
    amp_dtype=torch.float16,
    **kwinputs,
):
    """Use PyTorch Benchmark on the forward pass of an arbitrary function."""
    return benchmark_forward(
        fn,
        *inputs,
        repeats=repeats,
        desc=desc,
        verbose=verbose,
        amp=amp,
        amp_dtype=amp_dtype,
        **kwinputs,
    )


def flops(batch, seqlen, headdim, nheads, causal, mode="fwd"):
    """FLOP count for attention: two matmuls (QK^T and PV), each costing
    2 * seqlen^2 * headdim FLOPs per head; causal masking halves the work."""
    assert mode in ["fwd", "bwd", "fwd_bwd"]
    f = 4 * batch * seqlen**2 * nheads * headdim // (2 if causal else 1)
    return f if mode == "fwd" else (2.5 * f if mode == "bwd" else 3.5 * f)
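
# Example (quick sanity check of the FLOP count above):
#   flops(batch=1, seqlen=1024, headdim=64, nheads=8, causal=False)
#   = 4 * 1 * 1024**2 * 8 * 64 = 2,147,483,648 FLOPs (~2.15 GFLOPs);
#   causal=True halves this to 1,073,741,824.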


def efficiency(flop, time):
    """Throughput in TFLOPs/s, given a FLOP count and a time in seconds."""
    return (flop / time / 10**12) if not math.isnan(time) else 0.0
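
# Example: efficiency(2_147_483_648, 0.004) ~= 0.54, i.e. about
# 0.54 TFLOPs/s for a forward pass taking 4 ms.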


def time_fwd(func, *args, **kwargs):
    # benchmark_forward returns a (Timer, Measurement) pair;
    # report the mean runtime in seconds.
    time_f = benchmark_fwd(func, *args, **kwargs)
    return time_f[1].mean


def write_to_excel(label, data, dir_path, file_name):
    """Write a header row (`label`) and data rows to <dir_path>/<file_name>.xlsx."""
    workbook = openpyxl.Workbook()
    sheet = workbook.active
    sheet.append(label)
    os.makedirs(dir_path, exist_ok=True)
    for row in data:
        sheet.append(row)
    # os.path.join avoids a broken path when dir_path lacks a trailing separator.
    workbook.save(os.path.join(dir_path, file_name + ".xlsx"))
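

if __name__ == "__main__":
    # Minimal usage sketch (an illustrative addition, not part of the original
    # benchmark suite): time PyTorch's built-in scaled_dot_product_attention
    # and report achieved TFLOPs/s. Assumes a CUDA device and PyTorch >= 2.0;
    # a block-sparse attention kernel can be swapped in the same way.
    import torch.nn.functional as F

    batch, seqlen, nheads, headdim = 2, 1024, 8, 64
    q, k, v = (
        torch.randn(batch, nheads, seqlen, headdim, device="cuda", dtype=torch.float16)
        for _ in range(3)
    )
    mean_time = time_fwd(F.scaled_dot_product_attention, q, k, v, repeats=10, verbose=False)
    tflops = efficiency(flops(batch, seqlen, headdim, nheads, causal=False), mean_time)
    write_to_excel(
        ["batch", "seqlen", "mean_time_s", "TFLOPs/s"],
        [[batch, seqlen, mean_time, tflops]],
        "results",
        "sdpa_fwd",
    )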