# fa_bwd_benchmark.py
# Benchmark script for the FlashAttention backward kernel (committed by zhangshao).
import torch
import torch.utils.benchmark as benchmark
from collections import namedtuple
import math
import importlib.util
import csv

# Load the locally built FlashAttention extension straight from its shared
# object so the benchmark exercises this build.
# NOTE: the path is relative, so it is resolved against the current working
# directory of the process running the script.
path_to_so = '../build/flash-attention.so'
print("load from {}".format(path_to_so))
spec = importlib.util.spec_from_file_location("flash_attn_2_cuda", path_to_so)
if spec is None or spec.loader is None:
    # spec_from_file_location returns None when no loader matches the file.
    raise ImportError("cannot create an import spec for {}".format(path_to_so))
flash_attn_2_cuda = importlib.util.module_from_spec(spec)
spec.loader.exec_module(flash_attn_2_cuda)
# Bind the alias to the module object loaded above. A plain
# `import flash_attn_2_cuda` would go back through sys.modules / sys.path;
# module_from_spec() is not documented to register the module there, so such
# an import could fail or pick up a different installed copy.
_C_flashattention = flash_attn_2_cuda

def benchmark_backward(fn, *inputs, repeats=1, desc="", verbose=False, amp=False, amp_dtype=torch.float16, **kwinputs):
    """Time ``fn(*inputs, **kwinputs)`` with ``torch.utils.benchmark.Timer``.

    Args:
        fn: Callable to time; its return value is discarded.
        *inputs: Positional arguments forwarded to ``fn``.
        repeats: Value passed to ``Timer.timeit``.
        desc: Label printed ahead of the run when ``verbose`` is set.
        verbose: When true, print the label and the resulting measurement.
        amp: Whether the CUDA autocast context around ``fn`` is enabled.
        amp_dtype: dtype handed to ``torch.autocast``.
        **kwinputs: Keyword arguments forwarded to ``fn``.

    Returns:
        The first entry of the measurement's ``times`` list.
    """
    def run_under_autocast(*args, **kwargs):
        # The autocast context is always entered; ``enabled`` decides whether
        # it actually changes anything.
        autocast_ctx = torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp)
        with autocast_ctx:
            fn(*args, **kwargs)

    if verbose:
        print(desc, "- Backward pass")

    timer = benchmark.Timer(
        stmt="fn_amp(*inputs, **kwinputs)",
        globals=dict(fn_amp=run_under_autocast, inputs=inputs, kwinputs=kwinputs),
        num_threads=torch.get_num_threads(),
    )
    measurement = timer.timeit(repeats)
    if verbose:
        print(measurement)

    per_run_times = measurement.times
    return per_run_times[0]

def flops(batch, seqlen, headdim, nheads, causal, mode="fwd"):
    """Return the operation count this script uses for one attention pass.

    The forward count is ``4 * batch * seqlen**2 * nheads * headdim``,
    floor-divided by 2 when ``causal`` is truthy. ``mode`` selects the
    multiplier applied to it: 1 for ``"fwd"``, 2.5 for ``"bwd"`` and 3.5 for
    ``"fwd_bwd"``.

    Raises:
        AssertionError: if ``mode`` is not one of the three values above.
    """
    assert mode in ["fwd", "bwd", "fwd_bwd"]
    causal_divisor = 2 if causal else 1
    forward_flops = 4 * batch * seqlen**2 * nheads * headdim // causal_divisor
    if mode == "fwd":
        return forward_flops
    if mode == "bwd":
        return 2.5 * forward_flops
    return 3.5 * forward_flops

def efficiency(flop, time):
    """Return ``flop / time`` scaled down by ``10**12``.

    With ``time`` in seconds the result is in units of 1e12 operations per
    second.
    """
    flop_per_time = flop / time
    return flop_per_time / 10**12

# Every benchmarked configuration shares causal=True, batch_size=1,
# head_size=128 and window_size=[-1, -1]; only the head counts and the
# sequence length vary, so those are listed compactly here.
_HEAD_AND_SEQ_CONFIGS = (
    # (nheads, nheads_k, seq_len)
    (16, 16, 8192),
    (32, 32, 8192),
    (32, 4, 8192),
    (52, 4, 8192),
    (16, 2, 8192),
    (26, 2, 8192),
    (8, 1, 8192),
    (13, 1, 8192),

    (32, 32, 4096),
    (16, 16, 4096),
    (8, 8, 4096),
    (4, 4, 4096),
    (40, 40, 4096),

    (20, 20, 4096),
    (10, 10, 4096),
    (5, 5, 4096),
    (32, 8, 8192),
    (16, 4, 8192),

    (8, 2, 8192),
    (4, 1, 8192),
    (28, 4, 4096),
    (14, 2, 4096),
    (7, 1, 4096),
)

# One dict per configuration, in the order listed above. Each dict gets its
# own window_size list.
params_list = [
    {
        'causal': True,
        'batch_size': 1,
        'nheads': num_heads,
        'nheads_k': num_kv_heads,
        'seq_len': seq_length,
        'head_size': 128,
        'window_size': [-1, -1],
    }
    for num_heads, num_kv_heads, seq_length in _HEAD_AND_SEQ_CONFIGS
]

# Destination file for the collected results.
csv_file_name = "bwd_results.csv"
# Column order of the CSV; the keys of each result row match these names.
fieldnames = [
    "batch_size",
    "seq_len",
    "head_size",
    "nheads",
    "nheads_k",
    "causal",
    "bwd_speed",
]
# Accumulates one dict per benchmarked configuration.
results = list()

# Settings shared by every configuration; they do not depend on the loop
# variable, so they are set once up front.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dropout_p = 0
iterations = 12  # calls to the kernel per configuration, warm-up included
warmup = 2  # leading calls whose timings are discarded

# Benchmark the backward kernel once per configuration and collect the
# throughput reported by efficiency().
for params in params_list:
    batch_size = params['batch_size']
    nheads = params['nheads']
    nheads_k = params['nheads_k']
    head_size = params['head_size']
    seq_len = params['seq_len']
    causal = params['causal']
    window_size_left, window_size_right = params['window_size']
    softmax_scale = 1.0 / math.sqrt(head_size)

    # Tensors are built as (batch, heads, seq_len, head_size).
    # NOTE(review): upstream flash-attn documents mha_bwd inputs as
    # (batch, seq_len, heads, head_size) — confirm which layout this build
    # of the extension expects.
    q_shape = (batch_size, nheads, seq_len, head_size)
    kv_shape = (batch_size, nheads_k, seq_len, head_size)
    q = torch.randn(q_shape, device=device, dtype=torch.float16, requires_grad=True)
    k = torch.randn(kv_shape, device=device, dtype=torch.float16, requires_grad=True)
    v = torch.randn(kv_shape, device=device, dtype=torch.float16, requires_grad=True)
    o = torch.randn(q_shape, device=device, dtype=torch.float16, requires_grad=True)
    do = torch.randn(q_shape, device=device, dtype=torch.float16, requires_grad=True)
    # Output buffers for the gradients, passed to the kernel.
    dq = torch.empty(q_shape, device=device, dtype=torch.float16)
    dk = torch.empty(kv_shape, device=device, dtype=torch.float16)
    dv = torch.empty(kv_shape, device=device, dtype=torch.float16)
    # NOTE(review): lse is sized with nheads_k and stored as float16; upstream
    # flash-attn documents softmax_lse as float32 with one row per *query*
    # head. Left unchanged here — confirm against this build's bwd signature.
    lse = torch.randn(batch_size, nheads_k, seq_len, device=device, dtype=torch.float16)

    # Positional arguments for _C_flashattention.bwd. The trailing unnamed
    # values presumably map to alibi_slopes (None), softcap (0.0),
    # deterministic (False), generator (None) and rng_state (None) as in
    # upstream flash-attn's mha_bwd — TODO confirm for this build.
    input_params = (
        do,
        q,
        k,
        v,
        o,
        lse,
        dq,
        dk,
        dv,
        None,
        dropout_p,
        softmax_scale,
        causal,
        window_size_left,
        window_size_right,
        0.0,
        False,
        None,
        None)

    # Run the kernel several times and average the timings.
    cost_time_list = []
    for i in range(iterations):
        cost_time = benchmark_backward(_C_flashattention.bwd, *input_params, repeats=1)
        if i >= warmup:
            cost_time_list.append(cost_time)
        torch.cuda.synchronize()
        torch.cuda.empty_cache()

    # Drop the single slowest sample, then average what is left. Using the
    # list length keeps the divisor in step with the samples actually summed.
    max_cost_time = max(cost_time_list)
    cost_time_list.remove(max_cost_time)
    fa_average_cost = sum(cost_time_list) / len(cost_time_list)

    calculation_amount_bwd = flops(batch_size, seq_len, head_size, nheads, causal, "bwd")
    speed_bwd = efficiency(calculation_amount_bwd, fa_average_cost)
    results.append({
        "batch_size": batch_size,
        "seq_len": seq_len,
        "head_size": head_size,
        "nheads": nheads,
        "nheads_k": nheads_k,
        "causal": causal,
        "bwd_speed": speed_bwd,
    })
    print("bs= {}, seq_len={}, head_size={}, nheads={}, nheads_k={}, causal={}, bwd speed={} tflops".format(batch_size, seq_len, head_size, nheads, nheads_k, causal, speed_bwd))
# Write the header followed by every collected result row to the CSV file.
with open(csv_file_name, 'w', newline='') as out_file:
    csv_writer = csv.DictWriter(out_file, fieldnames=fieldnames)
    csv_writer.writeheader()  # header row
    csv_writer.writerows(results)