import os
from collections import defaultdict

import matplotlib.pyplot as plt
import numpy as np
import torch
from st_attn import sliding_tile_attention


def flops(batch, seqlen, nheads, headdim, causal, mode="fwd"):
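    """Attention FLOP count, following the usual FlashAttention benchmark
    convention: the forward pass is two seqlen x seqlen x headdim matmuls
    per head (4 * seqlen^2 * headdim FLOPs), halved under a causal mask;
    backward is counted as 2.5x forward and fwd+bwd as 3.5x.
    """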
    assert mode in ["fwd", "bwd", "fwd_bwd"]
    f = 4 * batch * seqlen**2 * nheads * headdim // (2 if causal else 1)
    return f if mode == "fwd" else (2.5 * f if mode == "bwd" else 3.5 * f)


def efficiency(flop, time):
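    """Convert a raw FLOP count and a time in microseconds to TFLOPS.

    Example: efficiency(1e12, 1e6) == 1.0 (one TFLOP of work in one second).
    """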
    flop = flop / 1e12
    time = time / 1e6
    return flop / time


def benchmark_attention(configurations):
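    """Benchmark sliding_tile_attention over (B, H, N, D, causal) configurations.

    Returns a dict keyed by pass ('fwd'/'bwd') mapping (headdim, causal) to a
    list of (seqlen, TFLOPS) pairs; only the forward pass is currently timed.
    """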
    results = {'fwd': defaultdict(list), 'bwd': defaultdict(list)}

    for B, H, N, D, causal in configurations:
        print("=" * 60)
        print(f"Timing forward and backward pass for B={B}, H={H}, N={N}, D={D}, causal={causal}")

        q = torch.randn(B, H, N, D, dtype=torch.bfloat16, device='cuda', requires_grad=False).contiguous()
        k = torch.randn(B, H, N, D, dtype=torch.bfloat16, device='cuda', requires_grad=False).contiguous()
        v = torch.randn(B, H, N, D, dtype=torch.bfloat16, device='cuda', requires_grad=False).contiguous()

        # grad_output and the fp32 gradient buffers are only consumed by the
        # backward-pass benchmark below, which is currently commented out.
        grad_output = torch.randn_like(q, requires_grad=False).contiguous()

        qg = torch.zeros_like(q, requires_grad=False, dtype=torch.float).contiguous()
        kg = torch.zeros_like(k, requires_grad=False, dtype=torch.float).contiguous()
        vg = torch.zeros_like(v, requires_grad=False, dtype=torch.float).contiguous()

        # Prepare for timing forward pass
        start_events_fwd = [torch.cuda.Event(enable_timing=True) for _ in range(10)]
        end_events_fwd = [torch.cuda.Event(enable_timing=True) for _ in range(10)]
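        # CUDA events record timestamps on the GPU stream itself, so the timed
        # region is not distorted by host-side launch overhead.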

        torch.cuda.empty_cache()
        torch.cuda.synchronize()
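
        # NOTE (assumption): [[6, 6, 6]] * 24 appears to give each of the 24
        # heads (H=24 in the default configuration) a 6x6x6 tile window; the
        # trailing (0, False) arguments are kept as-is from the original call
        # and are assumed to be the kernel's text-length / has-text flags.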

        # Warmup for forward pass
        for _ in range(10):
            o = sliding_tile_attention(q, k, v, [[6, 6, 6]] * 24, 0, False)

        # Time the forward pass
        for i in range(10):
            start_events_fwd[i].record()
            o = sliding_tile_attention(q, k, v, [[6, 6, 6]] * 24, 0, False)
            end_events_fwd[i].record()

        torch.cuda.synchronize()
        times_fwd = [s.elapsed_time(e) for s, e in zip(start_events_fwd, end_events_fwd)]
        time_us_fwd = np.mean(times_fwd) * 1000  # Event.elapsed_time() returns ms; convert to us

        tflops_fwd = efficiency(flops(B, N, H, D, causal, 'fwd'), time_us_fwd)
        results['fwd'][(D, causal)].append((N, tflops_fwd))

        print(f"Average time for forward pass in us: {time_us_fwd:.2f}")
        print(f"Average efficiency for forward pass in TFLOPS: {tflops_fwd}")
        print("-" * 60)

        # Backward-pass benchmark (disabled): it relies on a `tk.mha_backward`
        # kernel and an `l_vec` tensor that are not defined in this script.
        # torch.cuda.empty_cache()
        # torch.cuda.synchronize()

        # # Prepare for timing backward pass
        # start_events_bwd = [torch.cuda.Event(enable_timing=True) for _ in range(10)]
        # end_events_bwd = [torch.cuda.Event(enable_timing=True) for _ in range(10)]

        # # Warmup for backward pass
        # for _ in range(10):
        #     qg, kg, vg = tk.mha_backward(q, k, v, o, l_vec, grad_output, causal)

        # # Time the backward pass
        # for i in range(10):
        #     start_events_bwd[i].record()
        #     qg, kg, vg = tk.mha_backward(q, k, v, o, l_vec, grad_output, causal)
        #     end_events_bwd[i].record()

        # torch.cuda.synchronize()
        # times_bwd = [s.elapsed_time(e) for s, e in zip(start_events_bwd, end_events_bwd)]
        # time_us_bwd = np.mean(times_bwd) * 1000

        # tflops_bwd = efficiency(flops(B, N, H, D, causal, 'bwd'), time_us_bwd)
        # results['bwd'][(D, causal)].append((N, tflops_bwd))

        # print(f"Average time for backward pass in us: {time_us_bwd:.2f}")
        # print(f"Average efficiency for backward pass in TFLOPS: {tflops_bwd}")
        print("=" * 60)

        torch.cuda.empty_cache()
        torch.cuda.synchronize()

    return results


def plot_results(results):
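    """Write one TFLOPS-vs-sequence-length bar chart per (pass, headdim,
    causal) combination into benchmark_results/.
    """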
    os.makedirs('benchmark_results', exist_ok=True)
    for mode in ['fwd', 'bwd']:
        for (D, causal), values in results[mode].items():
            seq_lens = [x[0] for x in values]
            tflops = [x[1] for x in values]

            plt.figure(figsize=(10, 6))
            bars = plt.bar(range(len(seq_lens)), tflops, tick_label=seq_lens)
            plt.xlabel('Sequence Length')
            plt.ylabel('TFLOPS')
            plt.title(f'{mode.upper()} Pass - Head Dim: {D}, Causal: {causal}')
            plt.grid(True)

            # Adding the numerical y value on top of each bar
            for bar in bars:
                yval = bar.get_height()
                plt.text(bar.get_x() + bar.get_width() / 2, yval, round(yval, 2), ha='center', va='bottom')

            filename = f'benchmark_results/{mode}_D{D}_causal{causal}.png'
            plt.savefig(filename)
            plt.close()


# Example list of configurations to test, as (B, H, N, D, causal) tuples
configurations = [
    (2, 24, 82944, 128, False),
    # (16, 16, 768*16,  128, False),
    # (16, 16, 768*2,  128, False),
    # (16, 16, 768*4,  128, False),
    # (16, 16, 768*8,  128, False),
    # (16, 16, 768*16, 128, False),
    # (16, 16, 768,    128, True),
    # (16, 16, 768*2,  128, True),
    # (16, 16, 768*4,  128, True),
    # (16, 16, 768*8,  128, True),
    # (16, 16, 768*16, 128, True),
    # (16, 32, 768,    64,  False),
    # (16, 32, 768*2,  64,  False),
    # (16, 32, 768*4,  64,  False),
    # (16, 32, 768*8,  64,  False),
    # (16, 32, 768*16, 64,  False),
    # (16, 32, 768,    64,  True),
    # (16, 32, 768*2,  64,  True),
    # (16, 32, 768*4,  64,  True),
    # (16, 32, 768*8,  64,  True),
    # (16, 32, 768*16, 64,  True),
]

if __name__ == '__main__':
    results = benchmark_attention(configurations)
    # plot_results(results)