# gemm_perf.py
import subprocess, csv, re, datetime, argparse, os, enum


def parse_args():
    """Parse the command-line options of the GEMM performance tool.

    Options are read from ``sys.argv``.

    Returns:
        argparse.Namespace with the attributes ``gemm``, ``int8`` and
        ``gemm_softmax_gemm`` (bool flags), ``batch_sizes`` (list of str,
        required), ``exclude_lens`` (list of str, empty when the flag is
        omitted), ``lens_from_file`` (str or None) and ``timeout`` (int,
        default 600).
    """
    parser = argparse.ArgumentParser(description="GEMM performance tools")
    parser.add_argument('--gemm',
                        action='store_true',
                        help='Run performance comparison on a range of GEMM problem sizes (fp16/int8)')
    parser.add_argument('--int8',
                        action='store_true',
                        help='Quantize GEMMs to int8 precision (not available for GEMM-Softmax-GEMM)')
    parser.add_argument('--gemm-softmax-gemm',
                        action='store_true',
                        help='Run performance comparison on a range of GEMM-Softmax-GEMM problem sizes (fp16)')
    parser.add_argument('--batch_sizes',
                        '-b',
                        nargs='+',
                        help='Batch sizes to run',
                        required=True)
    # An empty default keeps the value iterable when the flag is not given;
    # without it argparse stores None and the caller's loop over it fails.
    parser.add_argument('--exclude-lens',
                        '-e',
                        nargs='+',
                        default=[],
                        help='Exclude lengths from [64, 256, 384, 512, 768, 1024, 1920, 2048, 2304, 3072, 4096] \
                              Lengths not excluded will be permuted as m, n, k, (o) inputs')
    parser.add_argument('--lens-from-file',
                        '-l',
                        help='Run a list of problem lens from file containing rows of \
                              m0 n0 k0 (o0) \
                              m1 n1 k1 (o1) \
                              ... \
                              (oi) only used for gemm-softmax-gemm')
    parser.add_argument('--timeout',
                        '-t',
                        type=int,
                        default=600,
                        help='Time in seconds before compilation timeout')
    return parser.parse_args()


class CSVFile:
    """Append-only CSV writer bound to a single file path."""

    def __init__(self, path="output.csv"):
        # Path of the CSV file; it is opened anew for every written row.
        self.path = path

    def write_row(self, row=None):
        """Append one row to the CSV file, creating the file if needed.

        Args:
            row: iterable of cell values; each value is converted with
                ``str``. ``None`` (the default) writes an empty row.
        """
        cells = [str(cell) for cell in (() if row is None else row)]
        # newline="" hands line-ending control to the csv module, which
        # otherwise produces doubled line endings on some platforms.
        with open(self.path, "a", newline="") as f:
            csv.writer(f).writerow(cells)


@enum.unique
class GEMM_Provider(enum.Enum):
    """GEMM implementations whose timings this script compares."""

    CK = 1
    ROCBLAS = 2
    MLIR = 3


def get_migraphx_root():
    """Return the absolute path of the directory two levels above this file."""
    script_path = os.path.abspath(__file__)
    script_dir, _script_name = os.path.split(script_path)
    return os.path.dirname(script_dir)


def get_device_name():
    """Return the first ``gfx...`` identifier printed by ``rocminfo``.

    Returns:
        str: the first match of ``gfx`` followed by optional digits and
        optional lowercase letters in the ``rocminfo`` output.

    Raises:
        OSError: if ``rocminfo`` cannot be launched.
        subprocess.CalledProcessError: if ``rocminfo`` exits non-zero.
        RuntimeError: if the output contains no ``gfx`` identifier.
    """
    # Decode the output as text instead of matching against the repr() of
    # a bytes object; undecodable bytes are replaced rather than fatal.
    out = subprocess.run(["rocminfo"],
                         capture_output=True,
                         check=True,
                         text=True,
                         errors="replace")
    match = re.search(r"gfx\d*[a-z]*", out.stdout)
    if match is None:
        raise RuntimeError("rocminfo output does not contain a gfx device name")
    return match.group(0)


def verify_format(file, single):
    """Check that every line of ``file`` is a row of problem lengths.

    Args:
        file: path of the file to check; a falsy value (``None`` or ``""``)
            yields ``False`` without touching the filesystem.
        single: when true each row must hold three non-negative integers,
            otherwise four.

    Returns:
        bool: ``True`` only if every line (blank lines included) consists of
        exactly the expected number of whitespace-separated integers.
    """
    if not file:
        return False
    fields = 3 if single else 4
    # \d+ accepts multi-digit lengths such as 1024; a single \d would reject
    # every realistic row and cause the file to be ignored.
    row_pattern = re.compile(r"^" + r"\s+".join([r"\d+"] * fields) + r"\s*$")
    with open(file, 'r') as f:
        return all(row_pattern.match(line) is not None for line in f)


def parse_lens(file):
    """Read ``file`` and return one tuple of ints per line.

    Each line is split on whitespace and every token is converted with
    ``int``; a line without tokens yields an empty tuple.
    """
    lens = []
    with open(file, 'r') as handle:
        for raw_line in handle:
            lens.append(tuple(int(token) for token in raw_line.split()))
    return lens


def get_total_time(output):
    """Extract the number following ``Total time: `` from driver output.

    Only the text from the first occurrence of ``Summary`` onwards is
    searched. ``output`` may be real text or the ``str()`` of a bytes object
    (where newlines appear as a literal backslash followed by ``n``), which
    is the form the callers in this file pass.

    Args:
        output: captured driver output as a string.

    Returns:
        float: the first ``Total time`` value found in the summary.

    Raises:
        ValueError: if there is no ``Summary`` section or no ``Total time``
            entry in it.
    """
    # DOTALL lets the match run across real newlines; for the repr() form
    # there are none and the flag has no effect.
    summary_match = re.search(r"Summary.*", output, re.DOTALL)
    if summary_match is None:
        raise ValueError("no 'Summary' section found in driver output")
    summary = summary_match.group(0).replace("\\n", "\n")
    time_match = re.search(r"Total time: (\d+(?:\.\d*)?)", summary)
    if time_match is None:
        raise ValueError("no 'Total time' entry found in driver summary")
    return float(time_match.group(1))


def verify_output(output, provider):
    """Tell whether the driver summary mentions an operator of ``provider``.

    Only the text from the first occurrence of ``Summary`` onwards is
    inspected.

    Args:
        output: captured driver output as a string (real text or the
            ``str()`` of a bytes object).
        provider: a ``GEMM_Provider`` member.

    Returns:
        bool: ``True`` if one of the provider's operator names occurs in the
        summary; ``False`` if none does, if there is no summary at all, or
        if ``provider`` is not a known member.
    """
    summary_match = re.search(r"Summary.*", output, re.DOTALL)
    if summary_match is None:
        return False
    summary = summary_match.group(0)
    # Operator names that identify each provider in the summary.
    markers = {
        GEMM_Provider.CK: ("ck_gemm", ),
        GEMM_Provider.ROCBLAS: ("gpu::gemm", "gpu::quant_gemm"),
        GEMM_Provider.MLIR: ("mlir_dot", "mlir_quant_dot"),
    }.get(provider, ())
    return any(marker in summary for marker in markers)


def get_gemm_time(config, fp16, provider, timeout):
    """Run the driver on the matmul model and return the reported total time.

    Runs ``<root>/build/bin/driver perf`` on
    ``<root>/test/onnx/matmul_half.onnx`` with input ``@1`` shaped
    ``(b, m, k)`` and input ``@2`` shaped ``(b, k, n)``.

    Args:
        config: sequence ``(b, m, n, k)``.
        fp16: pass ``--fp16`` to the driver when true, ``--int8`` otherwise.
        provider: ``GEMM_Provider`` member selecting which of
            ``MIGRAPHX_ENABLE_CK`` / ``MIGRAPHX_ENABLE_MLIR`` is set to "1".
        timeout: seconds after which the driver process is abandoned.

    Returns:
        float: the total time parsed from the driver summary, or ``-100.0``
        when the driver fails, times out, or its summary cannot be parsed or
        does not mention the requested provider.
    """
    root = get_migraphx_root()
    model = f"{root}/test/onnx/matmul_half.onnx"
    b, m, n, k = config
    prec_str = "--fp16" if fp16 else "--int8"
    # Build argv as a list so that a root path containing whitespace is not
    # split into several arguments.
    cmd = [
        f"{root}/build/bin/driver", "perf", model,
        "--input-dim",
        "@1", str(b), str(m), str(k),
        "@2", str(b), str(k), str(n),
        prec_str, "--exhaustive-tune",
    ]
    # NOTE(review): MIGRAPHX_USE_CK_ONLY is set for every provider here but
    # not in get_gemm_softmax_gemm_time - confirm that this is intended.
    env = dict(os.environ,
               MIGRAPHX_ENABLE_CK="1" if provider == GEMM_Provider.CK else "0",
               MIGRAPHX_ENABLE_MLIR="1" if provider == GEMM_Provider.MLIR else "0",
               MIGRAPHX_USE_CK_ONLY="1",
               MIGRAPHX_MLIR_USE_SPECIFIC_OPS="dot")

    try:
        out = subprocess.run(cmd,
                             capture_output=True,
                             check=True,
                             timeout=timeout,
                             env=env)
    except (subprocess.SubprocessError, OSError, ValueError) as e:
        print(f"{provider.name} encountered an exception: {e}")
        return -100.0

    # The parsing helpers are fed the str() of the captured bytes.
    driver_output = str(out.stdout)
    try:
        if not verify_output(driver_output, provider):
            print(f"{provider.name} was not found in performance summary")
            return -100.0
        total_time = get_total_time(driver_output)
    except (IndexError, ValueError) as e:
        # One unparsable run must not abort the whole sweep.
        print(f"{provider.name} produced output that could not be parsed: {e}")
        return -100.0

    print(f"{provider.name} total time: {total_time} ms")
    return total_time


def get_gemm_softmax_gemm_time(config, provider, timeout):
    """Run the driver on the GEMM-Softmax-GEMM model and return its total time.

    Runs ``<root>/build/bin/driver perf`` with ``--fp16`` on
    ``<root>/test/onnx/gemm_softmax_gemm_half.onnx`` with input ``@a`` shaped
    ``(b, m, k)``, ``@b`` shaped ``(b, k, n)`` and ``@b1`` shaped
    ``(b, n, o)``.

    Args:
        config: sequence ``(b, m, n, k, o)``.
        provider: ``GEMM_Provider`` member selecting which of
            ``MIGRAPHX_ENABLE_CK`` / ``MIGRAPHX_ENABLE_MLIR`` is set to "1".
        timeout: seconds after which the driver process is abandoned.

    Returns:
        float: the total time parsed from the driver summary, or ``-100.0``
        when the driver fails, times out, or its summary cannot be parsed or
        does not mention the requested provider.
    """
    root = get_migraphx_root()
    model = f"{root}/test/onnx/gemm_softmax_gemm_half.onnx"
    b, m, n, k, o = config
    # Build argv as a list so that a root path containing whitespace is not
    # split into several arguments.
    cmd = [
        f"{root}/build/bin/driver", "perf", model,
        "--input-dim",
        "@a", str(b), str(m), str(k),
        "@b", str(b), str(k), str(n),
        "@b1", str(b), str(n), str(o),
        "--fp16", "--exhaustive-tune",
    ]
    env = dict(os.environ,
               MIGRAPHX_ENABLE_CK="1" if provider == GEMM_Provider.CK else "0",
               MIGRAPHX_ENABLE_MLIR="1" if provider == GEMM_Provider.MLIR else "0",
               MIGRAPHX_MLIR_USE_SPECIFIC_OPS="dot")

    try:
        out = subprocess.run(cmd,
                             capture_output=True,
                             check=True,
                             timeout=timeout,
                             env=env)
    except (subprocess.SubprocessError, OSError, ValueError) as e:
        print(f"{provider.name} encountered an exception: {e}")
        return -100.0

    # The parsing helpers are fed the str() of the captured bytes.
    driver_output = str(out.stdout)
    try:
        if not verify_output(driver_output, provider):
            print(f"{provider.name} was not found in performance summary")
            return -100.0
        total_time = get_total_time(driver_output)
    except (IndexError, ValueError) as e:
        # One unparsable run must not abort the whole sweep.
        print(f"{provider.name} produced output that could not be parsed: {e}")
        return -100.0

    print(f"{provider.name} total time: {total_time} ms")
    return total_time


def run_gemm_perf(batches, sizes, fp16, timeout):
    """Time every GEMM shape with each provider and log the results as CSV.

    For each batch size ``b`` the rows are appended to
    ``gemm_perf_<fp16|int8>_<b>.csv``: a line with the device name and the
    current timestamp, a column-header line, then one line per shape holding
    the configuration and the CK, rocBLAS and MLIR times.

    Args:
        batches: iterable of batch sizes.
        sizes: iterable of ``(m, n, k)`` shapes.
        fp16: run in fp16 when true, int8 otherwise.
        timeout: per-run timeout in seconds, forwarded to ``get_gemm_time``.
    """
    prec_str = "fp16" if fp16 else "int8"
    # The device name comes from a subprocess call and does not change
    # between batches, so query it once.
    device_name = get_device_name()
    header = ["batch_size", "m", "n", "k", "CK Total Time (ms)", "rocBLAS Total Time (ms)", "MLIR Total Time (ms)"]
    providers = (GEMM_Provider.CK, GEMM_Provider.ROCBLAS, GEMM_Provider.MLIR)
    for b in batches:
        out = CSVFile(f"gemm_perf_{prec_str}_{b}.csv")
        out.write_row([device_name, datetime.datetime.now()])
        out.write_row(header)
        for shape in sizes:
            config = (b, *shape)
            print("Running {prec} gemm with config: {0}, {1}, {2}, {3}".format(*config, prec=prec_str))
            # Same order as the time columns of the header.
            times = [get_gemm_time(config, fp16, provider, timeout) for provider in providers]
            out.write_row([*config, *times])


def run_gemm_softmax_gemm_perf(batches, sizes, timeout):
    """Time every GEMM-Softmax-GEMM shape with each provider, logging to CSV.

    For each batch size ``b`` the rows are appended to
    ``gemm_softmax_gemm_perf_fp16_<b>.csv``: a line with the device name and
    the current timestamp, a column-header line, then one line per shape
    holding the configuration and the CK, rocBLAS and MLIR times.

    Args:
        batches: iterable of batch sizes.
        sizes: iterable of ``(m, n, k, o)`` shapes.
        timeout: per-run timeout in seconds, forwarded to
            ``get_gemm_softmax_gemm_time``.
    """
    # The device name comes from a subprocess call and does not change
    # between batches, so query it once.
    device_name = get_device_name()
    header = ["batch_size", "m", "n", "k", "o", "CK Total Time (ms)", "rocBLAS Total Time (ms)", "MLIR Total Time (ms)"]
    providers = (GEMM_Provider.CK, GEMM_Provider.ROCBLAS, GEMM_Provider.MLIR)
    for b in batches:
        out = CSVFile(f"gemm_softmax_gemm_perf_fp16_{b}.csv")
        out.write_row([device_name, datetime.datetime.now()])
        out.write_row(header)
        for shape in sizes:
            config = (b, *shape)
            print("Running fp16 gemm-softmax-gemm with config: {0}, {1}, {2}, {3}, {4}".format(*config))
            # Same order as the time columns of the header.
            times = [get_gemm_softmax_gemm_time(config, provider, timeout) for provider in providers]
            out.write_row([*config, *times])


if __name__ == "__main__":
    # Script entry point: build the list of problem sizes and run the
    # selected benchmark(s).
    args = parse_args()
    # --exclude-lens is optional, so its value may be None when omitted.
    exclude = {int(x) for x in (args.exclude_lens or [])}
    sizes = [64, 256, 384, 512, 768, 1024, 1920, 2048, 2304, 3072, 4096]
    sizes = [x for x in sizes if x not in exclude]
    fp16 = not args.int8
    timeout = int(args.timeout)

    if not (args.gemm or args.gemm_softmax_gemm):
        print("Nothing to run: pass --gemm and/or --gemm-softmax-gemm")

    if args.gemm:
        if verify_format(args.lens_from_file, True):
            gemm_sizes = parse_lens(args.lens_from_file)
        else:
            if args.lens_from_file:
                # Tell the user the file is being ignored instead of
                # silently starting the full permutation sweep.
                print(f"{args.lens_from_file} does not contain rows of 'm n k'; "
                      "running all permutations of the default lengths instead")
            gemm_sizes = [(m, n, k) for m in sizes for n in sizes for k in sizes]
        run_gemm_perf(args.batch_sizes, gemm_sizes, fp16, timeout)

    if args.gemm_softmax_gemm:
        if verify_format(args.lens_from_file, False):
            gemm_softmax_gemm_sizes = parse_lens(args.lens_from_file)
        else:
            if args.lens_from_file:
                print(f"{args.lens_from_file} does not contain rows of 'm n k o'; "
                      "running all permutations of the default lengths instead")
            gemm_softmax_gemm_sizes = [(m, n, k, o) for m in sizes for n in sizes for k in sizes for o in sizes]
        run_gemm_softmax_gemm_perf(args.batch_sizes, gemm_softmax_gemm_sizes, timeout)