#%matplotlib
import subprocess, csv, re, datetime, argparse, os, enum
from subprocess import STDOUT


def parse_args():
    """Build the command-line interface and parse the process arguments.

    Returns:
        argparse.Namespace: with attributes ``gemm``, ``int8``,
        ``gemm_softmax_gemm``, ``batch_sizes`` (list of str),
        ``exclude_lens`` (list of int, empty when the option is omitted),
        ``lens_from_file`` (str or None) and ``timeout`` (int).
    """
    parser = argparse.ArgumentParser(description="GEMM performance tools")
    parser.add_argument('--gemm',
                        action='store_true',
                        help='Run performance comparison on a range of GEMM problem sizes (fp16/int8)')
    parser.add_argument('--int8',
                        action='store_true',
                        help='Quantize GEMMs to int8 precision (not available for GEMM-Softmax-GEMM)')
    parser.add_argument('--gemm-softmax-gemm',
                        action='store_true',
                        help='Run performance comparison on a range of GEMM-Softmax-GEMM problem sizes (fp16)')
    # '--batch-sizes' is accepted as an alias so the spelling matches the
    # other dashed options; the original '--batch_sizes' keeps working.
    parser.add_argument('--batch_sizes',
                        '--batch-sizes',
                        '-b',
                        dest='batch_sizes',
                        nargs='+',
                        help='Batch sizes to run',
                        required=True)
    # default=[] keeps callers that iterate over exclude_lens from failing
    # with a TypeError on None when the option is not given.
    parser.add_argument('--exclude-lens',
                        '-e',
                        nargs='+',
                        type=int,
                        default=[],
                        help='Exclude lengths from [64, 256, 384, 512, 768, 1024, 1920, 2048, 2304, 3072, 4096] \
                              Lengths not excluded will be permuted as m, n, k, (o) inputs')
    parser.add_argument('--lens-from-file',
                        '-l',
                        help='Run a list of problem lens from file containing rows of \
                              m0 n0 k0 (o0) \
                              m1 n1 k1 (o1) \
                              ... \
                              (oi) only used for gemm-softmax-gemm')
    parser.add_argument('--timeout',
                        '-t',
                        type=int,
                        default=600,
                        help='Time in seconds before compilation timeout')
    args = parser.parse_args()

    return args


class CSVFile:
    """Append-only CSV writer bound to a single file path."""

    def __init__(self, path="output.csv"):
        # Path of the CSV file; it is opened anew for every written row.
        self.path = path

    def write_row(self, row=()):
        """Append one row to the file.

        Args:
            row: iterable of cell values; each one is converted with ``str``.
                An empty iterable writes an empty line.
        """
        cells = [str(cell) for cell in row]
        # newline="" hands line-ending control to the csv module, as its
        # documentation requires; otherwise platforms that translate "\n"
        # end up with an extra blank line after every row.
        with open(self.path, "a", newline="") as f:
            csv.writer(f).writerow(cells)


class GEMM_Provider(enum.Enum):
    """GEMM backends that the benchmark runs are timed against each other."""

    # The numeric values are arbitrary tags; members are compared by identity.
    CK = 1
    ROCBLAS = 2
    MLIR = 3


def get_migraphx_root():
    """Return the absolute path of the parent of this script's directory."""
    script_path = os.path.abspath(__file__)
    script_dir = os.path.dirname(script_path)
    return os.path.dirname(script_dir)


def get_device_name():
    """Return the first ``gfx...`` token printed by ``rocminfo``.

    Returns:
        str: the first match of ``gfx`` followed by digits and lowercase
        letters in the tool's standard output.

    Raises:
        OSError: if ``rocminfo`` cannot be launched.
        subprocess.CalledProcessError: if ``rocminfo`` exits with a non-zero
            status.
        RuntimeError: if the output contains no ``gfx...`` token.
    """
    # No shell is needed for a fixed command; decoding the output here lets
    # the regex run on real text rather than on the repr of a bytes object.
    out = subprocess.run(["rocminfo"],
                         capture_output=True,
                         check=True,
                         text=True,
                         errors="replace")
    match = re.search(r"gfx\d*[a-z]*", out.stdout)
    if match is None:
        raise RuntimeError("rocminfo output does not contain a gfx device name")
    return match.group(0)


def verify_format(file, single):
    """Check that every line of ``file`` is a row of non-negative integers.

    Args:
        file: path of the file to check; a falsy value (``None``, ``""``)
            yields ``False`` without touching the filesystem.
        single: when true each row must hold exactly three integers,
            otherwise exactly four.

    Returns:
        bool: ``True`` when all lines match (an empty file has no lines that
        fail, so it also yields ``True``), ``False`` otherwise.
    """
    if not file:
        return False
    columns = 3 if single else 4
    # \d+ rather than \d: problem lengths are multi-digit numbers, which a
    # single-digit pattern rejects.
    row_pattern = re.compile(r"^" + r"\s+".join([r"\d+"] * columns) + r"\s*$")
    with open(file, 'r') as f:
        return all(row_pattern.match(line) is not None for line in f)


def parse_lens(file):
    """Read ``file`` and return one tuple of ints per line.

    Each line is split on whitespace and every field is converted with
    ``int``; a blank line therefore yields an empty tuple.
    """
    lens = []
    with open(file, 'r') as handle:
        for row in handle:
            lens.append(tuple(int(field) for field in row.split()))
    return lens


def get_total_time(output):
    """Extract the ``Total time`` figure from the summary of a driver run.

    Args:
        output: the driver's standard output as a string. Both real text and
            the ``str()`` of a bytes object (where line breaks appear as a
            literal backslash followed by ``n``) are accepted.

    Returns:
        float: the number following the first ``Total time: `` that appears
        after the first occurrence of ``Summary``.

    Raises:
        ValueError: if there is no ``Summary`` section or no parsable
            ``Total time`` entry inside it.
    """
    start = output.find("Summary")
    if start < 0:
        raise ValueError("no 'Summary' section found in driver output")
    # Normalise escaped line breaks so both input forms look the same below.
    summary = output[start:].replace("\\n", "\n")
    match = re.search(r"Total time: (\d+\.\d*)", summary)
    if match is None:
        raise ValueError("no 'Total time' entry found in driver summary")
    return float(match.group(1))


def verify_output(output, provider):
    """Tell whether the driver summary mentions the expected provider's kernels.

    Args:
        output: the driver's standard output as a string (real text or the
            ``str()`` of a bytes object).
        provider: the ``GEMM_Provider`` member the run was configured for.

    Returns:
        bool: ``True`` when the provider's marker string occurs anywhere in
        the text following the first ``Summary``; ``False`` when it does
        not, when there is no summary, or when ``provider`` is not a known
        member.
    """
    start = output.find("Summary")
    if start < 0:
        return False
    summary = output[start:].replace("\\n", "\n")
    markers = {
        GEMM_Provider.CK: "ck_gemm",
        GEMM_Provider.ROCBLAS: "gpu::gemm",
        # NOTE(review): the original author flagged this marker with
        # "#######"; confirm the exact name MLIR kernels carry in the summary.
        GEMM_Provider.MLIR: "MLIR",
    }
    marker = markers.get(provider)
    # A plain substring test scans the whole summary. An anchored re.match
    # with ".*" stops at the first line break and so only ever inspects the
    # first summary line.
    return marker is not None and marker in summary


def get_gemm_time(config, fp16, provider, timeout):
    """Run the driver's ``perf`` command on the matmul model and return its total time.

    Args:
        config: ``(b, m, n, k)``; the two inputs are given dimensions
            ``b m k`` and ``b k n``.
        fp16: when true the driver is run with ``--fp16``, otherwise ``--int8``.
        provider: ``GEMM_Provider`` member; sets ``MIGRAPHX_ENABLE_CK`` and
            ``MIGRAPHX_ENABLE_MLIR`` to "1" only for the matching provider.
        timeout: seconds after which the driver process is abandoned.

    Returns:
        float: the total time parsed from the driver summary, or ``-100.0``
        when the driver could not be run, failed, or timed out.
    """
    root = get_migraphx_root()
    model = f"{root}/test/onnx/matmul_half.onnx"
    b, m, n, k = config
    prec_str = "--fp16" if fp16 else "--int8"
    # Build argv as a list so a checkout path containing whitespace is not
    # torn apart by str.split().
    cmd = [
        f"{root}/build/bin/driver", "perf", model,
        "--input-dim",
        "@1", str(b), str(m), str(k),
        "@2", str(b), str(k), str(n),
        prec_str, "--exhaustive-tune",
    ]
    use_CK = "1" if provider == GEMM_Provider.CK else "0"
    use_MLIR = "1" if provider == GEMM_Provider.MLIR else "0"

    try:
        out = subprocess.run(cmd,
                             capture_output=True,
                             check=True,
                             timeout=timeout,
                             env=dict(os.environ,
                                      MIGRAPHX_ENABLE_CK=use_CK,
                                      MIGRAPHX_ENABLE_MLIR=use_MLIR))
    except (subprocess.SubprocessError, OSError) as e:
        # Keep the sentinel result, but say why the run was discarded.
        print(f"{provider.name} failed for config {config}: {e}")
        return -100.0

    # NOTE(review): the return value of verify_output is not acted upon here;
    # decide whether a provider mismatch should invalidate the measurement.
    verify_output(str(out.stdout), provider)
    total_time = get_total_time(str(out.stdout))
    print(f"{provider.name} finished in {total_time}")

    return total_time


def get_gemm_softmax_gemm_time(config, provider, timeout):
    """Run the driver's ``perf`` command on the GEMM-Softmax-GEMM model.

    Args:
        config: ``(b, m, n, k, o)``; the three inputs are given dimensions
            ``b m k``, ``b k n`` and ``b n o``.
        provider: ``GEMM_Provider`` member; sets ``MIGRAPHX_ENABLE_CK`` and
            ``MIGRAPHX_ENABLE_MLIR`` to "1" only for the matching provider.
        timeout: seconds after which the driver process is abandoned.

    Returns:
        float: the total time parsed from the driver summary, or ``-100.0``
        when the driver could not be run, failed, or timed out.
    """
    root = get_migraphx_root()
    model = f"{root}/test/onnx/gemm_softmax_gemm_half.onnx"
    b, m, n, k, o = config
    # Build argv as a list so a checkout path containing whitespace is not
    # torn apart by str.split().
    cmd = [
        f"{root}/build/bin/driver", "perf", model,
        "--input-dim",
        "@a", str(b), str(m), str(k),
        "@b", str(b), str(k), str(n),
        "@b1", str(b), str(n), str(o),
        "--fp16", "--exhaustive-tune",
    ]
    use_CK = "1" if provider == GEMM_Provider.CK else "0"
    use_MLIR = "1" if provider == GEMM_Provider.MLIR else "0"

    try:
        out = subprocess.run(cmd,
                             capture_output=True,
                             check=True,
                             timeout=timeout,
                             env=dict(os.environ,
                                      MIGRAPHX_ENABLE_CK=use_CK,
                                      MIGRAPHX_ENABLE_MLIR=use_MLIR))
    except (subprocess.SubprocessError, OSError) as e:
        # Keep the sentinel result, but say why the run was discarded.
        print(f"{provider.name} failed for config {config}: {e}")
        return -100.0

    # Shared helper instead of a second copy of the summary-parsing logic.
    total_time = get_total_time(str(out.stdout))
    print(f"{provider.name} finished in {total_time}")

    return total_time


def run_gemm_perf(batches, sizes, fp16, timeout):
    """Time every GEMM shape with each provider and log the results to CSV.

    One file named ``gemm_perf_<precision>_<batch>.csv`` is appended to per
    batch size. Each gets a row with the device name and current timestamp,
    a header row, and then one row per shape with the CK, rocBLAS and MLIR
    timings.

    Args:
        batches: iterable of batch sizes.
        sizes: iterable of ``(m, n, k)`` tuples.
        fp16: selects fp16 (true) or int8 (false) precision.
        timeout: per-run timeout in seconds, forwarded to ``get_gemm_time``.
    """
    prec_str = "fp16" if fp16 else "int8"
    # The device does not change between batches; query it only once.
    device_name = get_device_name()
    providers = (GEMM_Provider.CK, GEMM_Provider.ROCBLAS, GEMM_Provider.MLIR)
    for b in batches:
        out = CSVFile(f"gemm_perf_{prec_str}_{b}.csv")
        out.write_row([device_name, datetime.datetime.now()])
        out.write_row(["batch_size", "m", "n", "k", "CK Total Time (ms)", "rocBLAS Total Time (ms)", "MLIR Total Time (ms)"])
        for shape in sizes:
            config = (b,) + tuple(shape)
            # Print every field of the config; a fixed three-slot format
            # string would drop k.
            print("Running gemm with config: " + ", ".join(str(c) for c in config))
            times = [get_gemm_time(config, fp16, provider, timeout)
                     for provider in providers]
            out.write_row(list(config) + times)


def run_gemm_softmax_gemm_perf(batches, sizes, timeout):
    """Time every GEMM-Softmax-GEMM shape with each provider and log to CSV.

    One file named ``gemm_softmax_gemm_perf_fp16_<batch>.csv`` is appended to
    per batch size. Each gets a row with the device name and current
    timestamp, a header row, and then one row per shape with the CK, rocBLAS
    and MLIR timings.

    Args:
        batches: iterable of batch sizes.
        sizes: iterable of ``(m, n, k, o)`` tuples.
        timeout: per-run timeout in seconds, forwarded to
            ``get_gemm_softmax_gemm_time``.
    """
    # The device does not change between batches; query it only once.
    device_name = get_device_name()
    providers = (GEMM_Provider.CK, GEMM_Provider.ROCBLAS, GEMM_Provider.MLIR)
    for b in batches:
        out = CSVFile(f"gemm_softmax_gemm_perf_fp16_{b}.csv")
        out.write_row([device_name, datetime.datetime.now()])
        out.write_row(["batch_size", "m", "n", "k", "o", "CK Total Time (ms)", "rocBLAS Total Time (ms)", "MLIR Total Time (ms)"])
        for shape in sizes:
            config = (b,) + tuple(shape)
            # Print every field of the config; a fixed four-slot format
            # string would drop o.
            print("Running gemm-softmax-gemm with config: " + ", ".join(str(c) for c in config))
            times = [get_gemm_softmax_gemm_time(config, provider, timeout)
                     for provider in providers]
            out.write_row(list(config) + times)


if __name__ == "__main__":
    # Script entry point: choose the problem sizes, then run the requested
    # benchmark(s).
    args = parse_args()
    # --exclude-lens is optional, so tolerate a missing (None) value.
    exclude = {int(x) for x in (args.exclude_lens or [])}
    sizes = [64, 256, 384, 512, 768, 1024, 1920, 2048, 2304, 3072, 4096]
    sizes = [x for x in sizes if x not in exclude]
    fp16 = not args.int8
    timeout = int(args.timeout)

    if not (args.gemm or args.gemm_softmax_gemm):
        print("Nothing to do: pass --gemm and/or --gemm-softmax-gemm")

    if args.gemm:
        if verify_format(args.lens_from_file, True):
            gemm_sizes = parse_lens(args.lens_from_file)
        else:
            if args.lens_from_file:
                # Say so, rather than silently launching the full sweep.
                print(f"Ignoring {args.lens_from_file}: expected rows of 3 integers; "
                      "using all permutations of the default lengths")
            gemm_sizes = [(m, n, k) for m in sizes for n in sizes for k in sizes]
        run_gemm_perf(args.batch_sizes, gemm_sizes, fp16, timeout)

    if args.gemm_softmax_gemm:
        if verify_format(args.lens_from_file, False):
            gemm_softmax_gemm_sizes = parse_lens(args.lens_from_file)
        else:
            if args.lens_from_file:
                print(f"Ignoring {args.lens_from_file}: expected rows of 4 integers; "
                      "using all permutations of the default lengths")
            gemm_softmax_gemm_sizes = [(m, n, k, o) for m in sizes for n in sizes for k in sizes for o in sizes]
        run_gemm_softmax_gemm_perf(args.batch_sizes, gemm_softmax_gemm_sizes, timeout)