# gemm_perf.py
#%matplotlib
import subprocess, csv, re, datetime, argparse, os
from subprocess import STDOUT, check_output
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
from pylab import *
import random

def parse_args():
    """Parse the command-line flags of the GEMM performance tool.

    Two independent boolean switches are recognised, ``--bert`` and
    ``--gemm``; each defaults to ``False`` when absent.

    Returns:
        argparse.Namespace: namespace with ``bert`` and ``gemm`` attributes.
    """
    cli = argparse.ArgumentParser(description="GEMM performance tools")

    # Flag -> help text; insertion order is the order shown by --help.
    switches = {
        '--bert':
        'Run GEMM performance comparisons on BERT model',
        '--gemm':
        'Run performance comparison on a range of GEMM problem sizes',
    }
    for flag, help_text in switches.items():
        cli.add_argument(flag, action='store_true', help=help_text)

    return cli.parse_args()
class CSVFile:
    """Append-only writer for a CSV results file.

    The file is opened, appended to and closed again on every call to
    ``write_row``, so rows reach the file as soon as they are written.

    Args:
        path: Location of the CSV file; defaults to ``output.csv`` in the
            current working directory.
    """

    def __init__(self, path="output.csv"):
        self.path = path

    def write_row(self, row=None):
        """Append one row to the CSV file.

        Args:
            row: Iterable of cell values. When omitted (or ``None``) an
                empty row is written, which shows up as a blank line.
        """
        # A ``None`` sentinel replaces the former shared mutable default list.
        cells = [] if row is None else row
        # newline="" hands line-ending control to the csv module, as its
        # documentation requires for files passed to csv.writer.
        with open(self.path, "a", newline="") as f:
            csv.writer(f).writerow(cells)


def get_device_name():
    """Return the first ``gfx...`` token printed by ``rocminfo``.

    Runs ``rocminfo`` through the shell and scans its standard output for
    the first substring of the form ``gfx<digits><lowercase letters>``.

    Returns:
        str: The first matching token, e.g. ``"gfx90a"``.

    Raises:
        subprocess.CalledProcessError: If ``rocminfo`` exits with a non-zero
            status.
        IndexError: If the output contains no ``gfx`` token.
    """
    out = subprocess.run("rocminfo",
                         capture_output=True,
                         check=True,
                         shell=True)
    # Decode the captured bytes instead of matching against their repr().
    text = out.stdout.decode(errors="replace")
    # Raw string: "\d" in a plain literal is an invalid escape sequence.
    matches = re.findall(r"gfx\d*[a-z]*", text)
    return matches[0]


def _find_kernel_time(label_pattern, summary):
    """Return the first time listed for a kernel label in *summary*, or "0.0".

    ``label_pattern`` is a regex fragment matching the label that precedes
    ``": <number>"``. The number is returned exactly as printed (a string).
    """
    hit = re.search(rf"(?:{label_pattern}): (\d+\.\d*)", summary)
    return hit.group(1) if hit else "0.0"


def run_perf(model,
             batch_size,
             int8=False,
             use_ck=False,
             use_large_k=False,
             disable_fusion=False):
    """Run ``driver perf`` on *model* and collect timings from its summary.

    The command and the summary text are appended to ``perf_summaries.txt``
    in the current directory; the command and total time are also printed.

    Args:
        model: Path of the model file passed to the driver.
        batch_size: Batch size, used for both ``--input-dim`` and ``--batch``.
        int8: Add ``--int8`` to the driver command line.
        use_ck: Set ``MIGRAPHX_ENABLE_CK=1`` for the run.
        use_large_k: Set ``MIGRAPHX_USE_LARGE_K=1``; only honoured when
            ``use_ck`` is true.
        disable_fusion: Set ``MIGRAPHX_DISABLE_CK_FUSION=1``; only honoured
            when ``use_ck`` is true.

    Returns:
        list[str]: ``[total, ck_gemm, rocblas_gemm, gemm_pack, gemm_sum]``.
        The first four are the numbers as printed by the driver ("0.0" when a
        kernel is not listed); the last is the sum of the three GEMM entries.

    Raises:
        subprocess.CalledProcessError: If the driver exits non-zero.
        IndexError: If the output has no ``Summary`` section or no
            ``Total time`` line.
    """
    env_vars = ""
    if use_ck:
        env_vars += "MIGRAPHX_ENABLE_CK=1 "
        if use_large_k:
            env_vars += "MIGRAPHX_USE_LARGE_K=1 "
        if disable_fusion:
            env_vars += "MIGRAPHX_DISABLE_CK_FUSION=1 "
    int8_str = "--int8" if int8 else ""
    cmd = f"{env_vars} ../build/bin/driver perf {model} --fill1 input_ids --input-dim @input_ids {batch_size} 384 --batch {batch_size} --fp16 {int8_str}  --exhaustive-tune"
    out = subprocess.run(cmd, capture_output=True, check=True, shell=True)

    # Work on decoded text rather than the repr() of the bytes object, so the
    # summary keeps real newlines and carries no stray quote characters.
    report = out.stdout.decode(errors="replace")
    # DOTALL: keep everything from "Summary" to the end of the output.
    summary = re.findall(r"Summary.*", report, re.DOTALL)[0]
    total_time = re.findall(r"Total time: (\d+\.\d*)", summary)[0]

    ck_gemm_time = _find_kernel_time("ck_gemm_kernel", summary)
    rb_gemm_time = _find_kernel_time("gpu::quant_gemm|gpu::gemm", summary)
    gemm_pack_time = _find_kernel_time("gpu::int8_gemm_pack_a", summary)

    gemm_times = [ck_gemm_time, rb_gemm_time, gemm_pack_time]
    # Explicit accumulation: this file star-imports pylab, which may rebind
    # the name ``sum`` (NOTE(review): confirm against the installed version).
    gemm_total = 0.0
    for value in gemm_times:
        gemm_total += float(value)
    gemm_times.append(str(gemm_total))

    print(cmd)
    print(total_time + "ms")
    with open("perf_summaries.txt", "a+") as f:
        f.write(cmd + "\n")
        f.write(summary + "\n\n")

    return [total_time] + gemm_times
def run_ck_perf(model, batch_size, int8=False, use_large_k=False):
    """Benchmark *model* twice with CK enabled and merge the two results.

    The first run leaves ``disable_fusion`` off and contributes only its
    total time; the second run turns ``disable_fusion`` on and contributes
    every remaining entry (the per-kernel GEMM timings).

    Returns:
        list: Same layout as the list returned by ``run_perf``.
    """
    # CK with fusions: only the overall time is kept.
    fused = run_perf(model,
                     batch_size,
                     int8=int8,
                     use_ck=True,
                     use_large_k=use_large_k,
                     disable_fusion=False)

    # CK without fusions: everything after the total time is kept.
    unfused = run_perf(model,
                       batch_size,
                       int8=int8,
                       use_ck=True,
                       use_large_k=use_large_k,
                       disable_fusion=True)

    merged = [fused[0]]
    merged.extend(unfused[1:])
    return merged
def run_bert_perf():
    """Benchmark the BERT model and append the results to the CSV report.

    A preamble (timestamp, device name, model path) is written first. Then,
    for batch sizes "1" and "64" and, within each, for int8 followed by fp16,
    a section is written consisting of a label row, a header row, three
    result rows (CK, CK + rocBLAS, rocBLAS) and an empty separator row.
    """
    device_id = get_device_name()
    model = "/code/bert_base_cased_1_fp16_gpu.onnx"

    cf = CSVFile()
    for preamble_cell in (str(datetime.datetime.now()), device_id, model):
        cf.write_row([preamble_cell])

    headers = [
        "", "Total Time (ms)", "CK GEMM Time (ms)", "RB GEMM Time (ms)",
        "GEMM Pack Time (ms)", "Total GEMM Time (ms)"
    ]

    for batch_size in ("1", "64"):
        # int8 first, then fp16
        for quantize in (True, False):
            if quantize:
                label = f"Int8 / BatchSize: {batch_size}"
            else:
                label = f"FP16 / BatchSize: {batch_size}"
            cf.write_row([label])
            cf.write_row(headers)

            # CK Only
            ck_only = run_ck_perf(model, batch_size, quantize, True)
            cf.write_row(["CK"] + ck_only)

            # CK + rocBLAS (k>2048)
            ck_mixed = run_ck_perf(model, batch_size, quantize, False)
            cf.write_row(["CK + rocBLAS(k>2048)"] + ck_mixed)

            # rocBLAS Only
            rocblas_only = run_perf(model, batch_size, quantize)
            cf.write_row(["rocBLAS"] + rocblas_only)

            # Empty row separating this section from the next one.
            cf.write_row()
def _parse_driver_total_time(stdout):
    """Return the ``Total time`` figure (as a string) from raw driver output.

    Raises IndexError when the output has no ``Summary`` section or no
    ``Total time`` line.
    """
    # Decode the bytes instead of regex-matching their repr().
    report = stdout.decode(errors="replace")
    # DOTALL: the summary runs from "Summary" to the end of the output.
    summary = re.findall(r"Summary.*", report, re.DOTALL)[0]
    return re.findall(r"Total time: (\d+\.\d*)", summary)[0]


def gemm_perf(b, m, n, k, fp16):
    """Time one batched matmul with and without CK and return the difference.

    The driver is run twice on a matmul test model with input shapes
    ``(b, m, k)`` and ``(b, k, n)``: once with ``MIGRAPHX_ENABLE_CK=0`` and
    once with ``MIGRAPHX_ENABLE_CK=1`` plus ``--exhaustive-tune`` (limited to
    300 seconds). Progress is printed to standard output.

    Args:
        b, m, n, k: Problem dimensions substituted into ``--input-dim``.
        fp16: Use ``matmul_half.onnx`` when true, ``matmul_int8.onnx``
            otherwise.

    Returns:
        float: CK total time minus the CK-disabled total time, or the
        sentinel ``-69.0`` when the CK run fails, times out or cannot start.

    Raises:
        subprocess.CalledProcessError: If the CK-disabled run exits non-zero.
        IndexError: If either run's output lacks the expected summary lines.
    """
    print(f"{b}, {m}, {n}, {k}:", end=" ")
    model = "../test/onnx/matmul_half.onnx" if fp16 else "../test/onnx/matmul_int8.onnx"
    #rocBLAS run
    cmd = f"MIGRAPHX_ENABLE_CK=0 ../build/bin/driver perf {model} --input-dim @1 {b} {m} {k} @2 {b} {k} {n}"
    out = subprocess.run(cmd, capture_output=True, check=True, shell=True)
    rb_time = _parse_driver_total_time(out.stdout)

    cmd = f"../build/bin/driver perf {model} --input-dim @1 {b} {m} {k} @2 {b} {k} {n} --exhaustive-tune"
    try:
        out = subprocess.run(cmd.split(),
                             capture_output=True,
                             check=True,
                             timeout=300,
                             env=dict(os.environ, MIGRAPHX_ENABLE_CK="1"))
    except (subprocess.SubprocessError, OSError):
        # Only process failures are treated as "no result". A bare except
        # would also swallow KeyboardInterrupt and make a sweep unstoppable.
        print("-69.0")
        return -69.0

    ck_time = _parse_driver_total_time(out.stdout)

    diff = float(ck_time) - float(rb_time)
    print(f"{diff}")
    return diff

def run_gemm_perf():
    """Sweep ``gemm_perf`` over a grid of GEMM sizes and save the results.

    Every (m, n, k) combination drawn from the size list is measured for
    each batch value with ``fp16=False``. The collected tuples are printed
    and then written, one comma-separated line each, to
    ``gemm_results.txt`` (overwriting any previous content).
    """
    batches = [1]
    sizes = [64, 256, 384, 768, 1024, 2048, 2304, 3072]

    results = []
    for b in batches:
        for m in sizes:
            for n in sizes:
                for k in sizes:
                    measured = gemm_perf(b, m, n, k, False)
                    results.append((b, m, n, k, measured))

    print(results)
    with open("gemm_results.txt", "w+") as f:
        for b, m, n, k, measured in results:
            f.write(f"{b}, {m}, {n}, {k}, {measured}\n")

if __name__ == "__main__":
    args = parse_args()
    # The flags are independent: each one given triggers its benchmark, and
    # the BERT comparison runs before the GEMM sweep when both are given.
    for requested, benchmark in ((args.bert, run_bert_perf),
                                 (args.gemm, run_gemm_perf)):
        if requested:
            benchmark()