benchmark_fa_varlen.py 4.85 KB
Newer Older
zhangshao's avatar
zhangshao committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import pickle
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
# from openpyxl import Workbook
from einops import rearrange, repeat

from flash_attn.utils.benchmark import benchmark_all, benchmark_forward, benchmark_backward
from flash_attn.utils.benchmark import benchmark_fwd_bwd, benchmark_combined

from flash_attn import flash_attn_qkvpacked_func,flash_attn_func
from flash_attn import flash_attn_varlen_func

# Excel export is disabled. The `from openpyxl import Workbook` import at the
# top of the file is commented out, so calling `Workbook()` here raised a
# NameError before any benchmark could run. Every `ws.append(...)` /
# `wb.save(...)` call below is commented out as well.
# To re-enable export:
#   1. restore that import,
#   2. set `wb = Workbook()` and `ws = wb.active` here,
#   3. uncomment those calls.
wb = None
ws = None



def flops(batch, seqlen, headdim, nheads, causal, mode="fwd"):
    """Estimate the FLOP count of one attention call.

    The forward pass is counted as ``4 * batch * seqlen**2 * nheads * headdim``
    (presumably 2 FLOPs per multiply-add over the QK^T and PV products). A
    causal mask halves that count. The backward pass is costed at 2.5x the
    forward count, and forward + backward at 3.5x.

    Args:
        batch: number of sequences.
        seqlen: sequence length used for both queries and keys.
        headdim: per-head feature dimension.
        nheads: number of heads (the benchmark loop passes the query head count).
        causal: whether the causal-mask halving applies.
        mode: ``"fwd"``, ``"bwd"`` or ``"fwd_bwd"``.

    Returns:
        The FLOP count. It is an ``int`` for non-causal ``"fwd"``; every other
        case goes through a float division or multiplication and returns a float.
    """
    assert mode in ["fwd", "bwd", "fwd_bwd"]
    fwd_count = 4 * batch * seqlen**2 * nheads * headdim
    if causal:
        fwd_count = fwd_count / 2
    if mode == "fwd":
        return fwd_count
    multiplier = 2.5 if mode == "bwd" else 3.5
    return multiplier * fwd_count

def efficiency(flop, time):
    """Convert a FLOP count and an elapsed time into TFLOP/s.

    ``time`` is in seconds (the benchmark loop prints ``time * 1000`` as
    milliseconds). A NaN ``time`` yields ``0.0`` instead of propagating NaN.
    """
    if math.isnan(time):
        return 0.0
    return flop / time / 10**12

def time_forward(func, *args, **kwargs):
    """Time the forward pass of ``func(*args, **kwargs)`` via ``benchmark_forward``.

    Every argument is forwarded unchanged to ``benchmark_forward``. The call
    site also passes harness options such as ``repeats`` and ``verbose``.

    Returns:
        The ``.mean`` of the second value returned by ``benchmark_forward``.
        That value is presumably a torch.utils.benchmark ``Measurement``, so
        the mean is in seconds; confirm against flash_attn.utils.benchmark.
    """
    _timer, measurement = benchmark_forward(func, *args, **kwargs)
    return measurement.mean

def padding_bmhk(t, pad=32):   # BMHK
    """Return a copy of ``t`` whose rows are stored with trailing padding.

    ``t`` is unpacked as ``(batch, seqlen, nheads, dim)``. Each
    ``(batch, seqlen)`` row of ``nheads * dim`` elements is copied into a
    zero-padded buffer of ``nheads * dim + pad`` elements. The padding is then
    sliced away again. The result has the same shape and values as ``t`` but a
    row stride of ``nheads * dim + pad``, so rows are not densely packed.
    This is presumably meant to benchmark the kernel on strided inputs.

    Args:
        t: 4-D tensor in BMHK layout.
        pad: elements of padding appended to each row. The default of 32 is
            the value previously hard-coded here. ``0`` gives a dense copy.

    Returns:
        A tensor with ``t``'s shape that views the padded buffer.
    """
    batch, seqlen, nheads, dim = t.shape
    row = nheads * dim
    padded = torch.nn.functional.pad(t.reshape(batch, seqlen, row), (0, pad), 'constant', 0)
    # Slice by explicit row length rather than ``[:-pad]``. With pad == 0,
    # ``[:-0]`` is ``[:0]`` and would return an empty tensor.
    return padded[:, :, :row].reshape(batch, seqlen, nheads, dim)

# ---- Benchmark configuration ----
repeats = 30  # timed repetitions per configuration (forwarded to time_forward)
device = 'cuda'
dtype = torch.float16
# (batch_size, seqlen) pairs to sweep.
bs_seqlen_vals = [(1,128), (1, 1024), (1, 2048), (1, 4096), (1, 6144), (1, 8192), (1, 10*1024), (1, 12*1024), (1, 16*1024), (1, 32*1024), (1, 64*1024)]
# bs_seqlen_vals = [(1, 1024), (1, 2048), (1, 4096), (1, 8192), (1, 16*1024), (1, 32*1024)]
# bs_seqlen_vals += [(8, 1024), (8, 2048), (8, 4096), (8, 8192), (8, 16*1024), (8, 32*1024)]
# bs_seqlen_vals += [(16, 2049), (32, 1024), (64, 512), (128, 256), (256, 128)]
causal_vals = [True]
headdim_vals = [128]
# (nheads_q, nheads_k) pairs. Entries with nheads_k < nheads_q use fewer
# key/value heads than query heads (MQA/GQA-style configs).
# Each pair appears once. A second (16, 16) entry was removed: it re-ran the
# same configuration and overwrote that configuration's results.
nheads_vals = [(32, 2), (16, 1), (8, 1), (32, 8),
    (32, 32), (16, 16), (8, 8), (4, 4), (40, 40),
    (20, 20), (10, 10), (5, 5), (32, 4), (16, 2),
    (14, 2), (7, 1), (20, 4), (10, 2), (5, 1)]
# nheads_vals=[(28,4)]
dropout_p = 0.0  # passed to flash_attn_varlen_func as its dropout argument
pad = 0  # NOTE(review): not referenced by the benchmark loop below
methods = (["Flash2"])
# Result tables keyed by (config, method). Only time_f and speed_f are filled
# in by the benchmark loop; the backward and fwd+bwd tables are currently unused.
time_f = {}
time_b = {}
time_f_b = {}
speed_f = {}
speed_b = {}
speed_f_b = {}

# ws.append(['batch_size', 'total_q', 'total_kv', 'nheads_q', 'num_heads_kv', 'causal', 'dim', 'dtype', 'tflops', 'time(ms)'])

# Sweep over every (batch_size, seqlen) x causal x headdim x (nheads_q, nheads_k)
# combination. For each one, time the forward pass of flash_attn_varlen_func,
# then record and print its latency and throughput.
for batch_size, seqlen in bs_seqlen_vals:
    for causal in causal_vals:
        for headdim in headdim_vals:
            for nheads_q, nheads_k in nheads_vals:
                # Key into the time_f / speed_f result tables.
                config = (causal, headdim, batch_size, seqlen, nheads_q, nheads_k)
                # Random inputs in (batch, seqlen, nheads, headdim) layout; k and v
                # get nheads_k heads, which may be fewer than nheads_q.
                q = torch.randn(batch_size,seqlen, nheads_q , headdim, device=device, dtype=dtype, requires_grad=True)
                k = torch.randn(batch_size,seqlen, nheads_k, headdim, device=device, dtype=dtype, requires_grad=True)
                v = torch.randn(batch_size,seqlen, nheads_k, headdim, device=device, dtype=dtype, requires_grad=True)
                # Re-store each tensor with padded rows (row stride larger than
                # nheads * headdim), presumably to benchmark strided inputs.
                q = padding_bmhk(q)
                k = padding_bmhk(k)
                v = padding_bmhk(v)
                # # print(q.shape)
                # print(q.stride())
                # Flatten batch and seqlen into one token axis:
                # (batch_size * seqlen, nheads, headdim). This is presumably the
                # packed layout flash_attn_varlen_func expects. The padded strides
                # should let reshape merge these two dims as a view, keeping the
                # padded row stride (check with q.stride()).
                q = q.reshape(batch_size*seqlen, nheads_q, headdim)
                k = k.reshape(batch_size*seqlen, nheads_k, headdim)
                v = v.reshape(batch_size*seqlen, nheads_k, headdim)

                # print(q.shape)
                # print(q.stride())
                # print(k.shape)
                # print(k.stride())
                # print(v.shape)
                # exit(-1)
                # Sequence boundaries [0, seqlen, 2*seqlen, ..., batch_size*seqlen]:
                # every sequence has the same length seqlen. The same tensor is
                # used for both queries and keys.
                cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32,
                                    device=device)

                # Positional args after q, k, v are presumably (cu_seqlens_q,
                # cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p). Check them
                # against flash_attn_varlen_func's signature. repeats / verbose
                # are options for the benchmark harness.
                f = time_forward(
                    flash_attn_varlen_func, q, k, v, cu_seqlens, cu_seqlens, seqlen, seqlen, dropout_p,
                    causal=causal, repeats=repeats, verbose=False
                )

                # Only the "Flash2" method is measured; `methods` contains just
                # that name, so the loop below reports this single timing.
                time_f[config, "Flash2"] = f
                print(f"### causal={causal}, headdim={headdim}, batch_size={batch_size}, nheads_q={nheads_q}, nheads_k={nheads_k}, seqlen={seqlen} ###")
                for method in methods:
                    # time_f_b[config, method] = time_f[config, method] + time_b[config, method]
                    # FLOPs are counted with the query head count (nheads_q).
                    speed_f[config, method] = efficiency(
                        flops(batch_size, seqlen, headdim, nheads_q, causal, mode="fwd"),
                        time_f[config, method]
                    )

                    # Timings are multiplied by 1000 for display in milliseconds.
                    print(
                        f"{method} fwd: {speed_f[config, method]:.2f} TFLOPs/s,  {time_f[config, method]*1000:.2f} ms"
                        # f"bwd: {speed_b[config, method]:.2f} TFLOPs/s, "
                        # f"fwd + bwd: {speed_f_b[config, method]:.2f} TFLOPs/s"
                    )

                    # Excel export (disabled): would append one row per config.
                    # ws.append([batch_size, seqlen, seqlen, nheads_q, nheads_k, causal, headdim, "float16", round(speed_f[config, method], 2), round(time_f[config, method]*1000, 2)])
                    # exit(0)

# Excel export (disabled): would write the collected rows to disk.
# wb.save("varlen_64_32_4_018a7dd_waq.xlsx")