# gemm_perf.py
#%matplotlib
import subprocess, csv, re, datetime, argparse, os
from subprocess import STDOUT, check_output
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
from pylab import *
import random

def parse_args():
    """Parse the command-line switches that select which benchmark to run.

    Returns:
        The ``argparse.Namespace`` with boolean attributes ``bert`` and
        ``gemm``; each is ``True`` only when its flag was passed.
    """
    switches = (
        ('--bert', 'Run GEMM performance comparisons on BERT model'),
        ('--gemm',
         'Run performance comparison on a range of GEMM problem sizes'),
    )
    cli = argparse.ArgumentParser(description="GEMM performance tools")
    for flag, description in switches:
        cli.add_argument(flag, action='store_true', help=description)
    return cli.parse_args()


class CSVFile:
    """Append-only helper that writes rows to a CSV file."""

    def __init__(self, path="output.csv"):
        """Remember the destination; the file is not opened until a row is written.

        Args:
            path: File that rows are appended to.
        """
        self.path = path

    def write_row(self, row=None):
        """Append one row to the CSV file, creating the file if necessary.

        Args:
            row: Iterable of cell values. Omitting it (or passing ``None``)
                writes an empty row.
        """
        # A fresh list per call avoids sharing one mutable default object.
        cells = [] if row is None else row
        # newline="" hands line-ending control to the csv module; without it,
        # platforms that translate "\n" produce "\r\r\n" row terminators.
        with open(self.path, "a+", newline="") as f:
            csv.writer(f).writerow(cells)


def get_device_name():
    """Return the first ``gfx...`` identifier printed by ``rocminfo``.

    Returns:
        The first substring of the ``rocminfo`` output that matches
        ``gfx`` followed by optional digits and lowercase letters.

    Raises:
        subprocess.CalledProcessError: If ``rocminfo`` exits with a non-zero
            status.
        RuntimeError: If the output contains no ``gfx`` identifier.
    """
    out = subprocess.run("rocminfo",
                         capture_output=True,
                         check=True,
                         shell=True)
    # Search the decoded text rather than the ``b'...'`` repr of the bytes.
    text = out.stdout.decode(errors="replace")
    match = re.search(r"gfx\d*[a-z]*", text)
    if match is None:
        raise RuntimeError(
            "rocminfo output does not contain a gfx device name")
    return match.group(0)


def _summary_time(summary, label_pattern, default="0.0"):
    """Return the first ``<label>: <number>`` value found in *summary*.

    Args:
        summary: Text of the driver's summary section.
        label_pattern: Regex (alternation allowed) for the entry label.
        default: Value returned when no matching entry exists.

    Returns:
        The number as it appears in the text (a string), or *default*.
    """
    match = re.search(rf"(?:{label_pattern}): (\d+\.\d*)", summary)
    return match.group(1) if match else default


def run_perf(model,
             batch_size,
             int8=False,
             use_ck=False,
             use_large_k=False,
             disable_fusion=False):
    """Run ``driver perf`` on *model* and collect total and GEMM timings.

    The command is echoed to stdout together with the total time, and the
    command plus the driver's summary section are appended to
    ``perf_summaries.txt``.

    Args:
        model: Path of the model file passed to the driver.
        batch_size: Value used for both ``--input-dim`` and ``--batch``.
        int8: Add ``--int8`` to the driver command line.
        use_ck: Export ``MIGRAPHX_ENABLE_CK=1`` for the run.
        use_large_k: Export ``MIGRAPHX_USE_LARGE_K=1`` (only when *use_ck*).
        disable_fusion: Export ``MIGRAPHX_DISABLE_CK_FUSION=1`` (only when
            *use_ck*).

    Returns:
        A list of five strings: total time, ``ck_gemm_kernel`` time,
        rocBLAS GEMM time (``gpu::quant_gemm`` or ``gpu::gemm``),
        ``gpu::int8_gemm_pack_a`` time, and the sum of the three GEMM
        figures. Entries missing from the summary are reported as ``"0.0"``.

    Raises:
        subprocess.CalledProcessError: If the driver exits with a non-zero
            status.
        RuntimeError: If the output has no summary section or no total time.
    """
    env_vars = ""
    if use_ck:
        env_vars += "MIGRAPHX_ENABLE_CK=1 "
        # The remaining switches are only exported when CK is enabled.
        if use_large_k:
            env_vars += "MIGRAPHX_USE_LARGE_K=1 "
        if disable_fusion:
            env_vars += "MIGRAPHX_DISABLE_CK_FUSION=1 "
    int8_str = "--int8" if int8 else ""
    cmd = f"{env_vars} ../build/bin/driver perf {model} --fill1 input_ids --input-dim @input_ids {batch_size} 384 --batch {batch_size} --fp16 {int8_str}  --exhaustive-tune"
    # The environment assignments are part of the command string, so the
    # command has to go through the shell.
    out = subprocess.run(cmd, capture_output=True, check=True, shell=True)

    # Work on decoded text instead of the ``b'...'`` repr of the raw bytes.
    text = out.stdout.decode(errors="replace")
    start = text.find("Summary")
    if start < 0:
        raise RuntimeError(f"no Summary section in driver output for: {cmd}")
    summary = text[start:]

    total_time = _summary_time(summary, "Total time", default=None)
    if total_time is None:
        raise RuntimeError(f"no 'Total time' entry in summary for: {cmd}")

    gemm_times = [
        _summary_time(summary, "ck_gemm_kernel"),
        _summary_time(summary, "gpu::quant_gemm|gpu::gemm"),
        _summary_time(summary, "gpu::int8_gemm_pack_a"),
    ]
    gemm_times.append(str(sum(map(float, gemm_times))))

    print(cmd)
    print(total_time + "ms")
    with open("perf_summaries.txt", "a+") as f:
        f.write(cmd + "\n")
        f.write(summary + "\n\n")

    return [total_time] + gemm_times


def run_ck_perf(model, batch_size, int8=False, use_large_k=False):
    """Benchmark *model* with CK enabled, once with and once without fusion.

    Args:
        model: Path of the model file passed on to ``run_perf``.
        batch_size: Batch size passed on to ``run_perf``.
        int8: Forwarded to ``run_perf``.
        use_large_k: Forwarded to ``run_perf``.

    Returns:
        A list shaped like the ``run_perf`` result: the total time comes
        from the run with fusion, the remaining entries from the run with
        fusion disabled.
    """
    # CK with fusions: only the overall time is taken from this run.
    fused = run_perf(model,
                     batch_size,
                     int8=int8,
                     use_ck=True,
                     use_large_k=use_large_k,
                     disable_fusion=False)
    # CK without fusions: supplies the GEMM timing entries.
    unfused = run_perf(model,
                       batch_size,
                       int8=int8,
                       use_ck=True,
                       use_large_k=use_large_k,
                       disable_fusion=True)

    return [fused[0]] + unfused[1:]


def run_bert_perf(model="/code/bert_base_cased_1_fp16_gpu.onnx",
                  batch_sizes=("1", "64")):
    """Benchmark a BERT model under several GEMM back-end setups.

    A preamble (timestamp, device name, model path) is appended to the
    default ``CSVFile``. Then, for every batch size and for int8 followed by
    fp16, the function writes a label row, a header row, one result row per
    configuration ("CK", "CK + rocBLAS(k>2048)", "rocBLAS") and an empty
    separator row.

    Args:
        model: Path of the ONNX model handed to the driver.
        batch_sizes: Batch sizes to benchmark, in the order given.
    """
    device_id = get_device_name()
    cf = CSVFile()
    cf.write_row([str(datetime.datetime.now())])
    cf.write_row([device_id])
    cf.write_row([model])
    headers = [
        "", "Total Time (ms)", "CK GEMM Time (ms)", "RB GEMM Time (ms)",
        "GEMM Pack Time (ms)", "Total GEMM Time (ms)"
    ]

    for batch_size in batch_sizes:
        # int8 first, then fp16.
        for quantize in (True, False):
            precision = "Int8" if quantize else "FP16"
            cf.write_row([f"{precision} / BatchSize: {batch_size}"])
            cf.write_row(headers)
            # CK Only
            cf.write_row(["CK"] +
                         run_ck_perf(model, batch_size, quantize, True))
            # CK + rocBLAS (k>2048)
            cf.write_row(["CK + rocBLAS(k>2048)"] +
                         run_ck_perf(model, batch_size, quantize, False))
            # rocBLAS Only
            cf.write_row(["rocBLAS"] + run_perf(model, batch_size, quantize))
            cf.write_row()

def _driver_total_time(stdout):
    """Return the ``Total time`` figure from the driver's summary section.

    Args:
        stdout: Raw bytes captured from the driver process.

    Returns:
        The number following ``Total time: `` as a string.

    Raises:
        RuntimeError: If there is no summary section or no total time in it.
    """
    text = stdout.decode(errors="replace")
    start = text.find("Summary")
    match = None
    if start >= 0:
        match = re.search(r"Total time: (\d+\.\d*)", text[start:])
    if match is None:
        raise RuntimeError(
            "driver output has no 'Total time' entry in its summary")
    return match.group(1)


def gemm_perf(b, m, n, k, fp16):
    """Time one GEMM problem with CK disabled and enabled; return the delta.

    The problem dimensions are printed first, followed by the result on the
    same line.

    Args:
        b: Batch dimension of both inputs.
        m: Row count of the first input.
        n: Column count of the second input.
        k: Shared inner dimension.
        fp16: Use ``matmul_half.onnx`` when true, ``matmul_int8.onnx``
            otherwise.

    Returns:
        ``ck_time - rb_time`` as a float, or the sentinel ``-69.0`` when the
        CK run fails or exceeds its 300 second timeout.

    Raises:
        subprocess.CalledProcessError: If the run with CK disabled fails.
        RuntimeError: If a successful run's output cannot be parsed.
    """
    print(f"{b}, {m}, {n}, {k}:", end=" ")
    model = "../test/onnx/matmul_half.onnx" if fp16 else "../test/onnx/matmul_int8.onnx"
    base_cmd = f"../build/bin/driver perf {model} --input-dim @1 {b} {m} {k} @2 {b} {k} {n}"

    # rocBLAS run: CK explicitly disabled through the environment.
    out = subprocess.run(base_cmd.split(),
                         capture_output=True,
                         check=True,
                         env=dict(os.environ, MIGRAPHX_ENABLE_CK="0"))
    rb_time = _driver_total_time(out.stdout)

    # CK run: same problem with CK enabled and exhaustive tuning.
    cmd = f"{base_cmd} --exhaustive-tune"
    try:
        out = subprocess.run(cmd.split(),
                             capture_output=True,
                             check=True,
                             timeout=300,
                             env=dict(os.environ, MIGRAPHX_ENABLE_CK="1"))
    except (subprocess.CalledProcessError, subprocess.TimeoutExpired):
        # Only process failures map to the sentinel; KeyboardInterrupt and
        # programming errors propagate.
        print("-69.0")
        return -69.0
    ck_time = _driver_total_time(out.stdout)

    diff = float(ck_time) - float(rb_time)
    print(f"{diff}")
    return diff

def run_gemm_perf():
    """Sweep ``gemm_perf`` over every (m, n, k) combination and save results.

    Each problem is run with ``fp16=False``. The collected
    ``(b, m, n, k, result)`` tuples are printed and then written, one
    comma-separated line per tuple, to ``gemm_results.txt``.
    """
    batches = [1]
    sizes = [64, 256, 384, 768, 1024, 2048, 2304, 3072]

    results = []
    for batch in batches:
        for rows in sizes:
            for cols in sizes:
                for inner in sizes:
                    delta = gemm_perf(batch, rows, cols, inner, False)
                    results.append((batch, rows, cols, inner, delta))
    print(results)

    report = [
        f"{batch}, {rows}, {cols}, {inner}, {delta}\n"
        for batch, rows, cols, inner, delta in results
    ]
    with open("gemm_results.txt", "w+") as f:
        f.writelines(report)

if __name__ == "__main__":
    args = parse_args()
    # The flags are independent: when both are given, the BERT comparison
    # runs first and the GEMM sweep second.
    for requested, benchmark in ((args.bert, run_bert_perf),
                                 (args.gemm, run_gemm_perf)):
        if requested:
            benchmark()