benchmark_ours.py 9.3 KB
Newer Older
zhangshao's avatar
zhangshao committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
import os
import math
import numpy
import torch
import torch.utils.benchmark as benchmark
from collections import namedtuple
import argparse


def flops(batch, seq_len, headdim, qheads, causal, mode="fwd"):
    """Estimate the floating-point operation count of one attention call.

    Args:
        batch: Number of sequences; only used when ``seq_len`` is an int.
        seq_len: Either a single int (every sequence has this length) or an
            indexable sequence of boundaries, in which case the length of
            segment ``k`` is ``seq_len[k + 1] - seq_len[k]`` and ``batch``
            is ignored.
        headdim: Per-head feature size.
        qheads: Number of query heads.
        causal: When truthy the count is halved.
        mode: One of ``"fwd"``, ``"bwd"`` or ``"fwd_bwd"``.

    Returns:
        The forward count for ``"fwd"``; 2.5x that for ``"bwd"``; 3.5x that
        for ``"fwd_bwd"``.
    """
    assert mode in ["fwd", "bwd", "fwd_bwd"]
    halving = 2 if causal else 1
    if isinstance(seq_len, int):
        base = 4 * batch * seq_len**2
    else:
        # Sum of squared segment lengths over consecutive boundary pairs.
        squared_total = sum((hi - lo)**2 for lo, hi in zip(seq_len[:-1], seq_len[1:]))
        base = 4 * squared_total
    forward = base * qheads * headdim // halving
    if mode == "fwd":
        return forward
    if mode == "bwd":
        return 2.5 * forward
    return 3.5 * forward

def benchmark_forward(
    fn, *inputs, repeats=100, desc="", verbose=True, amp=False, amp_dtype=torch.float16, **kwinputs
):
    """Time ``fn(*inputs, **kwinputs)`` with ``torch.utils.benchmark.Timer``.

    Args:
        fn: Callable to measure; its return value is discarded.
        *inputs: Positional arguments forwarded to ``fn``.
        repeats: Value handed to ``Timer.timeit``.
        desc: Label printed before the run when ``verbose`` is true.
        verbose: Print the label and the resulting measurement.
        amp: Passed as ``enabled`` to ``torch.autocast`` around each call.
        amp_dtype: Passed as ``dtype`` to ``torch.autocast``.
        **kwinputs: Keyword arguments forwarded to ``fn``.

    Returns:
        A ``(timer, measurement)`` pair, where ``measurement`` is whatever
        ``timer.timeit(repeats)`` returned.
    """
    if verbose:
        print(desc, "- Forward pass")

    def run_under_autocast(*call_args, **call_kwargs):
        # The autocast context is always entered; `amp` only toggles whether it is active.
        autocast_ctx = torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp)
        with autocast_ctx:
            fn(*call_args, **call_kwargs)

    # Names the timed statement resolves when it is executed by the Timer.
    timer_namespace = {
        "fn_amp": run_under_autocast,
        "inputs": inputs,
        "kwinputs": kwinputs,
    }
    timer = benchmark.Timer(
        stmt="fn_amp(*inputs, **kwinputs)",
        globals=timer_namespace,
        num_threads=torch.get_num_threads(),
    )
    measurement = timer.timeit(repeats)
    if verbose:
        print(measurement)
    return timer, measurement

def efficiency(flop, time):
    """Return ``flop / time`` scaled down by ``10**12``.

    With ``flop`` as an operation count and ``time`` in seconds this is a
    throughput in units of 10^12 operations per second.
    """
    ops_per_time_unit = flop / time
    return ops_per_time_unit / 10**12

def warp_tensor(tensor, gpu_is_ours, is_varlen=False, num_head=None):
    """Return ``tensor`` in the layout used for the current device.

    Args:
        tensor: Input tensor.
        gpu_is_ours: When truthy the tensor is returned unchanged.
        is_varlen: Selects which re-layout is applied when ``gpu_is_ours``
            is falsy.
        num_head: Size of the middle dimension for the variable-length view;
            only read when ``is_varlen`` is truthy and ``gpu_is_ours`` is
            falsy.

    Returns:
        The same tensor object when ``gpu_is_ours`` is truthy. Otherwise a
        contiguous copy with dimensions 1 and 2 swapped (fixed-length case),
        or a ``(-1, num_head, last_dim)`` view (variable-length case).
    """
    if gpu_is_ours:
        return tensor
    if is_varlen:
        last_dim = tensor.shape[-1]
        return tensor.view(-1, num_head, last_dim)
    swapped = tensor.transpose(1, 2)
    return swapped.contiguous()


# Command-line options of the benchmark script.
parser = argparse.ArgumentParser(description='test')
# Forwarded as `repeats` to benchmark_forward, which hands it to Timer.timeit.
parser.add_argument('--repeats',    default=1, type=int, help='run times during once benchmark')
# Number of measurement rounds per test case; the first round's timing is not recorded.
parser.add_argument('--iterations', default=6, type=int, help='times of benchmark')
# Key into the reference table used for the comparison; "A800" is used when omitted.
parser.add_argument('--compare', default=None, type=str, help='competitor card name')
# When set on a non-NVIDIA device, only the raw measured speeds are printed and the script exits.
parser.add_argument('--ratio', default=False, action='store_true', help='whether compute ratio of ours/nvidia')
# NOTE: parsing happens at module level, so importing this file consumes sys.argv.
args = parser.parse_args()

# Test cases to benchmark, one namedtuple per case.
#   causal      - forwarded to the attention call and to flops()
#   batch_size  - leading dimension of the fixed-length q/k/v tensors
#   qheads      - number of query heads
#   kvheads     - number of key/value heads
#   seq_len     - an int for a fixed-length case, or a tuple of cumulative
#                 sequence boundaries (starting at 0) for a variable-length case
#   head_size   - per-head feature size
#   window_size - two-element list whose items are forwarded to the attention call
# NOTE: the reference table in the reporting section holds one value per entry
# of this list, in the same order; keep both in sync when adding cases.
params = namedtuple('param', ['causal', 'batch_size','qheads','kvheads','seq_len','head_size','window_size'])
params_list = [
    params(batch_size=4, qheads=32, kvheads=32, seq_len=(0, 1000, 2000, 3000, 4000), head_size=128, causal=True, window_size=[-1,-1]),
    params(batch_size=2, qheads=32, kvheads=32, seq_len=(0, 2000, 4000), head_size=128, causal=True, window_size=[-1,-1]),
    params(batch_size=4, qheads=16, kvheads=2, seq_len=(0, 1000, 2000, 3000, 4000), head_size=128, causal=True, window_size=[-1,-1]),
    params(batch_size=2, qheads=16, kvheads=2, seq_len=(0, 2000, 4000), head_size=128, causal=True, window_size=[-1,-1]),
    params(batch_size=1, qheads=16, kvheads=2, seq_len=(0, 20000), head_size=128, causal=True, window_size=[-1,-1]),
    params(batch_size=1, qheads=16, kvheads=2, seq_len=(0, 20305), head_size=128, causal=True, window_size=[-1,-1]),
    params(batch_size=1, qheads=16, kvheads=16, seq_len=8192, head_size=128, causal=True, window_size=[-1,-1]),
    params(batch_size=1, qheads=32, kvheads=32, seq_len=8192, head_size=128, causal=True, window_size=[-1,-1]),
    params(batch_size=1, qheads=32, kvheads=4, seq_len=8192, head_size=128, causal=True, window_size=[-1,-1]),
    params(batch_size=1, qheads=52, kvheads=4, seq_len=8192, head_size=128, causal=True, window_size=[-1,-1]),
    params(batch_size=1, qheads=16, kvheads=2, seq_len=8192, head_size=128, causal=True, window_size=[-1,-1]),
    params(batch_size=1, qheads=26, kvheads=2, seq_len=8192, head_size=128, causal=True, window_size=[-1,-1]),
    params(batch_size=1, qheads=8, kvheads=1, seq_len=8192, head_size=128, causal=True, window_size=[-1,-1]),
    params(batch_size=1, qheads=13, kvheads=1, seq_len=8192, head_size=128, causal=True, window_size=[-1,-1]),
]

# flash_attn supplies the Python-level flash_attn_func used for fixed-length cases;
# flash_attn_2_cuda is imported separately because its varlen_fwd is called directly
# for variable-length cases.
import flash_attn
import flash_attn_2_cuda as _C_flashattention
print("load flash_attn from package")
# Properties of CUDA device 0; only its name is inspected.
gpu_card_info = torch.cuda.get_device_properties(0)
# True whenever the device name does not contain the substring "NVIDIA".
gpu_is_ours   = bool("NVIDIA" not in gpu_card_info.name)


# One throughput value (as returned by efficiency()) per entry of params_list, in order.
speed_on_this_gpu = []
for case in params_list:
    torch.cuda.empty_cache()
    cost_time   = []
    device      = "cuda"
    causal      = case.causal
    batch_size  = case.batch_size
    qheads      = case.qheads
    kvheads     = case.kvheads
    seq_len     = case.seq_len
    head_size   = case.head_size
    window_size = case.window_size
    flops_count = flops(batch_size, seq_len, head_size, qheads, causal)
    repeats     = args.repeats
    iterations  = args.iterations
    is_varlen   = isinstance(seq_len, tuple)

    if is_varlen:
        # Derived once per case from the original tuple. Rebinding `seq_len`
        # to a tensor inside the iteration loop would make every later
        # iteration compute `max_seqlen_q` from tensor elements (yielding a
        # 0-dim CUDA tensor instead of an int).
        max_seqlen_q   = max(hi - lo for lo, hi in zip(seq_len[:-1], seq_len[1:]))
        total_seqlen_q = seq_len[-1]
        cu_seqlens     = torch.tensor(list(seq_len), dtype=torch.int32).cuda()
        softmax_scale  = 1.0 / math.sqrt(head_size)

    for i in range(iterations):
        torch.cuda.empty_cache()
        if is_varlen:
            q = warp_tensor(torch.randn(qheads * total_seqlen_q, head_size, device=device, dtype=torch.float16), gpu_is_ours, is_varlen, qheads)
            k = warp_tensor(torch.randn(kvheads * total_seqlen_q, head_size, device=device, dtype=torch.float16), gpu_is_ours, is_varlen, kvheads)
            v = warp_tensor(torch.randn(kvheads * total_seqlen_q, head_size, device=device, dtype=torch.float16), gpu_is_ours, is_varlen, kvheads)
            # The positional layout must match the installed flash_attn_2_cuda.varlen_fwd
            # binding. Versions whose string contains "2.6" receive one extra float before
            # the last two arguments (presumably softcap - confirm against the installed build).
            if "2.6" in str(flash_attn.__version__):
                fa_varlen_args = (q, k, v, None, cu_seqlens, cu_seqlens, None, None, max_seqlen_q, max_seqlen_q, 0.0, softmax_scale, False, causal, window_size[0], window_size[1], 0.0, False, None)
            else:
                fa_varlen_args = (q, k, v, None, cu_seqlens, cu_seqlens, None, None, max_seqlen_q, max_seqlen_q, 0.0, softmax_scale, False, causal, window_size[0], window_size[1], False, None)
            t = benchmark_forward(_C_flashattention.varlen_fwd, *fa_varlen_args, repeats=repeats, verbose=False)[1].times[0]
        else:
            q = warp_tensor(torch.randn(batch_size, qheads,  seq_len, head_size, device=device, dtype=torch.float16, requires_grad=True), gpu_is_ours)
            k = warp_tensor(torch.randn(batch_size, kvheads, seq_len, head_size, device=device, dtype=torch.float16, requires_grad=True), gpu_is_ours)
            v = warp_tensor(torch.randn(batch_size, kvheads, seq_len, head_size, device=device, dtype=torch.float16, requires_grad=True), gpu_is_ours)
            t = benchmark_forward(flash_attn.flash_attn_interface.flash_attn_func, q, k, v, 0.0, causal=causal, window_size=window_size, repeats=repeats, verbose=False)[1].times[0]
        # The first round is never recorded.
        # NOTE: with --iterations < 2 nothing is recorded and the mean below is NaN.
        if i > 0:
            cost_time.append(t)
        # delete the data each time to avoid detecting the cache
        del q, k, v

    cost_time = numpy.array(cost_time)
    cost_time_mean = cost_time.mean()
    # remove bursts of dirty data: drop samples below 80% of the mean, then re-average
    cost_time = numpy.delete(cost_time, numpy.where(cost_time < (0.8 * cost_time_mean)))
    cost_time_mean = cost_time.mean()
    speed = efficiency(flops_count, cost_time_mean)
    speed_on_this_gpu.append(speed)

# Report the results. On a non-NVIDIA device the measured speeds are compared
# against a recorded NVIDIA reference; on an NVIDIA device they are only printed.
if gpu_is_ours:
    if args.ratio:
        # Raw numbers only, one per test case, then stop.
        for it in speed_on_this_gpu:
            print(it)
        exit()
    # prepare performance sheet for comparison (one value per test case, same order)
    nvidia_performance = {
        # For L20 the value of "repeats" has very little effect, so a single data set
        # (recorded with repeats of 100) is kept.
        "L20": [81.95, 89.90, 74.01, 81.75, 108.61, 108.59, 101.95, 106.80, 106.89, 108.62, 102.55, 105.85, 94.71, 100.60],
        # For A800 the value of "repeats" has a very significant effect; only one
        # data set is recorded here.
        "A800": [103.01, 130.44, 78.70, 99.94, 203.21, 203.51, 191.49, 204.63, 207.69, 213.23, 192.70, 204.25, 163.50, 185.51],
    }
    # acquire corresponding card; "A800" is the default competitor
    nvidia_competitor = args.compare if args.compare is not None else "A800"
    if nvidia_competitor not in nvidia_performance:
        # Stop here: looking the name up anyway would only raise a KeyError.
        print("\033[1;31mPerformance of competitor {} is not recorded yet!\033[0m".format(nvidia_competitor))
        exit(1)
    nvidia_speed = nvidia_performance[nvidia_competitor]
    # check data alignment: the reference must hold one value per measured case
    if len(nvidia_speed) != len(speed_on_this_gpu):
        print("\x1b[31mPerformance data of ours and {} is not correct\x1b[0m\n\n".format(nvidia_competitor))
        exit()
    # output info
    speed_ratio = []
    print("ours             {}             Ratio".format(nvidia_competitor))
    for ours, nvidia in zip(speed_on_this_gpu, nvidia_speed):
        print("{:.9f}\t{:.9f}\t{:.2f}%".format(ours, nvidia, ours / nvidia * 100))
        speed_ratio.append(ours / nvidia)
    speed_on_this_gpu = numpy.array(speed_on_this_gpu)
    nvidia_speed = numpy.array(nvidia_speed)
    speed_ratio  = numpy.array(speed_ratio)
    print("============================================")
    # The summary ratio is the mean of the per-case ratios, not the ratio of the means.
    print("{:.9f}\t{:.9f}\t{:.2f}%".format(speed_on_this_gpu.mean(), nvidia_speed.mean(), speed_ratio.mean() * 100))
    print("Mean  of ours      : {:.9f}".format(speed_on_this_gpu.mean()))
    print("Mean  of NVIDIA {}: {:.9f}".format(nvidia_competitor, nvidia_speed.mean()))
    print("Ratio to NVIDIA {}: \x1b[32m{:.2f}%\x1b[0m\n\n".format(nvidia_competitor, 100 * speed_ratio.mean()))

else:
    for it in speed_on_this_gpu:
        print(it)