# gemm_perf.py
#%matplotlib
import argparse
import csv
import datetime
import enum
import itertools
import os
import re
import subprocess
from subprocess import STDOUT


class GEMM_Provider(enum.Enum):
    """Backends a GEMM benchmark run can be attributed to.

    Members are numbered 1, 2, 3 in declaration order.
    """

    CK = enum.auto()       # value 1
    ROCBLAS = enum.auto()  # value 2
    MLIR = enum.auto()     # value 3


def parse_args():
    """Parse the command-line options of the GEMM performance tool.

    Options are read from ``sys.argv``.

    Returns:
        argparse.Namespace with the attributes ``gemm``, ``int8`` and
        ``gemm_softmax_gemm`` (bools), ``batch_sizes`` (list of int,
        required), ``exclude_lens`` (list of int, empty when the flag is
        not given) and ``timeout`` (int, seconds, default 600).
    """
    parser = argparse.ArgumentParser(description="GEMM performance tools")
    parser.add_argument('--gemm',
                        action='store_true',
                        help='Run performance comparison on a range of GEMM problem sizes (fp16/int8)')
    parser.add_argument('--int8',
                        action='store_true',
                        help='Quantize GEMMs to int8 precision (not available for GEMM-Softmax-GEMM)')
    parser.add_argument('--gemm-softmax-gemm',
                        action='store_true',
                        help='Run performance comparison on a range of GEMM-Softmax-GEMM problem sizes (fp16)')
    # '--batch_sizes' is kept for existing command lines; '--batch-sizes'
    # matches the hyphenated spelling of the other options. The destination
    # attribute stays 'batch_sizes' either way.
    parser.add_argument('--batch_sizes',
                        '--batch-sizes',
                        '-b',
                        dest='batch_sizes',
                        nargs='+',
                        type=int,
                        help='Batch sizes to run',
                        required=True)
    # An empty default (instead of None) lets callers iterate the value
    # without a None check when the flag is omitted.
    parser.add_argument('--exclude-lens',
                        '-e',
                        nargs='+',
                        type=int,
                        default=[],
                        help='Exclude lengths from [64, 256, 384, 512, 768, 1024, 1920, 2048, 2304, 3072, 4096]')
    parser.add_argument('--timeout',
                        '-t',
                        type=int,
                        default=600,
                        help='Time in seconds before compilation timeout')
    return parser.parse_args()


class CSVFile:
    """Append-only CSV file that is re-opened for every row written."""

    def __init__(self, path="output.csv"):
        """Remember the target path; the file is not touched until a write."""
        self.path = path

    def write_row(self, row=None):
        """Append one row to the file, creating the file if needed.

        Args:
            row: Iterable of cell values; each one is converted with
                ``str``. ``None`` (the default) writes an empty row.
        """
        cells = [] if row is None else [str(cell) for cell in row]
        # newline="" is required by the csv module so that the writer's own
        # line terminator is not translated a second time by the text layer.
        with open(self.path, "a", newline="") as f:
            csv.writer(f).writerow(cells)


def get_device_name():
    """Return the first ``gfx<digits><letters>`` token printed by ``rocminfo``.

    Returns:
        The matched token as a string.

    Raises:
        OSError: If ``rocminfo`` cannot be launched.
        subprocess.CalledProcessError: If ``rocminfo`` exits with a
            non-zero status.
        IndexError: If the output contains no ``gfx`` token.
    """
    # No shell is needed to run a single program; text=True yields a str so
    # the pattern is matched against the real output rather than a bytes repr.
    out = subprocess.run(["rocminfo"],
                         capture_output=True,
                         check=True,
                         text=True,
                         errors="replace")
    match = re.search(r"gfx\d*[a-z]*", out.stdout)
    if match is None:
        raise IndexError("no gfx device name found in rocminfo output")
    return match.group(0)


def get_gemm_time(config, fp16, provider, timeout):
    """Benchmark one batched GEMM with the MIGraphX driver.

    Runs ``../build/bin/driver perf`` on ``../test/onnx/matmul_half.onnx``
    with input ``@1`` given dimensions ``b m k`` and input ``@2`` given
    dimensions ``b k n``.

    Args:
        config: Tuple ``(b, m, n, k)``.
        fp16: If truthy the driver is run with ``--fp16``, else ``--int8``.
        provider: ``GEMM_Provider`` member. ``MIGRAPHX_ENABLE_CK`` is set to
            ``"1"`` only for CK and ``MIGRAPHX_ENABLE_MLIR`` only for MLIR;
            for ROCBLAS both are ``"0"``.
        timeout: Seconds to wait for the driver before giving up.

    Returns:
        The number following ``Total time: `` in the driver's ``Summary``
        output as a float, or ``-100.0`` when the driver could not be
        started, exited with an error, timed out, or printed no parsable
        summary.
    """
    model = "../test/onnx/matmul_half.onnx"
    b, m, n, k = config
    prec_str = "--fp16" if fp16 else "--int8"
    cmd = [
        "../build/bin/driver", "perf", model,
        "--input-dim",
        "@1", str(b), str(m), str(k),
        "@2", str(b), str(k), str(n),
        prec_str,
    ]
    env = dict(os.environ,
               MIGRAPHX_ENABLE_CK="1" if provider == GEMM_Provider.CK else "0",
               MIGRAPHX_ENABLE_MLIR="1" if provider == GEMM_Provider.MLIR else "0")

    try:
        out = subprocess.run(cmd,
                             capture_output=True,
                             check=True,
                             timeout=timeout,
                             text=True,
                             errors="replace",
                             env=env)
    except (subprocess.SubprocessError, OSError) as e:
        # Covers a non-zero exit status, a timeout and a missing executable.
        failure = f"An exception occurred: {str(e)}"
    else:
        # First "Total time" value that follows the first "Summary" heading.
        match = re.search(r"Summary.*?Total time: (\d+\.\d*)", out.stdout,
                          re.DOTALL)
        if match is not None:
            return float(match.group(1))
        failure = "Could not find 'Total time' in the driver summary"

    # A failed configuration is reported and recorded with a sentinel so
    # that one bad shape does not abort a whole sweep.
    print(failure)
    print(f"{provider.name} GEMM {b}, {m}, {n}, {k}:", end=" ")
    print("-100.0")
    return -100.0


def get_gemm_softmax_gemm_time(config, provider, timeout):
    """Benchmark one fp16 GEMM-Softmax-GEMM with the MIGraphX driver.

    Runs ``../build/bin/driver perf`` on
    ``../test/onnx/gemm_softmax_gemm_half.onnx`` with input ``@a`` given
    dimensions ``b m k``, ``@b`` given ``b k n`` and ``@b1`` given ``b n o``.

    Args:
        config: Tuple ``(b, m, n, k, o)``.
        provider: ``GEMM_Provider`` member. ``MIGRAPHX_ENABLE_CK`` is set to
            ``"1"`` only for CK and ``MIGRAPHX_ENABLE_MLIR`` only for MLIR;
            for ROCBLAS both are ``"0"``.
        timeout: Seconds to wait for the driver before giving up.

    Returns:
        The number following ``Total time: `` in the driver's ``Summary``
        output as a float, or ``-100.0`` when the driver could not be
        started, exited with an error, timed out, or printed no parsable
        summary.
    """
    model = "../test/onnx/gemm_softmax_gemm_half.onnx"
    b, m, n, k, o = config
    cmd = [
        "../build/bin/driver", "perf", model,
        "--input-dim",
        "@a", str(b), str(m), str(k),
        "@b", str(b), str(k), str(n),
        "@b1", str(b), str(n), str(o),
        "--fp16",
    ]
    env = dict(os.environ,
               MIGRAPHX_ENABLE_CK="1" if provider == GEMM_Provider.CK else "0",
               MIGRAPHX_ENABLE_MLIR="1" if provider == GEMM_Provider.MLIR else "0")

    try:
        out = subprocess.run(cmd,
                             capture_output=True,
                             check=True,
                             timeout=timeout,
                             text=True,
                             errors="replace",
                             env=env)
    except (subprocess.SubprocessError, OSError) as e:
        # Covers a non-zero exit status, a timeout and a missing executable.
        failure = f"An exception occurred: {str(e)}"
    else:
        # First "Total time" value that follows the first "Summary" heading.
        match = re.search(r"Summary.*?Total time: (\d+\.\d*)", out.stdout,
                          re.DOTALL)
        if match is not None:
            return float(match.group(1))
        failure = "Could not find 'Total time' in the driver summary"

    # Report the failure instead of dropping it silently, then record the
    # sentinel so that one bad shape does not abort a whole sweep.
    print(failure)
    print(f"{provider.name} GEMM-Softmax-GEMM {b}, {m}, {n}, {k}, {o}:", end=" ")
    print("-100.0")
    return -100.0


def run_gemm_perf(batches, sizes, fp16, timeout=600):
    """Time every GEMM shape with each provider and write one CSV per batch.

    For each batch size ``b`` the file ``gemm_perf_<half|int>_<b>.csv`` is
    appended with a device/timestamp row, a header row and one row per
    shape holding the CK, rocBLAS and MLIR timings.

    Args:
        batches: Iterable of batch sizes.
        sizes: Either a list of ints, which is expanded to every
            ``(m, n, k)`` combination, or a list of ready-made
            ``(m, n, k)`` tuples, which is used as given.
        fp16: Truthy for fp16 runs, falsy for int8 runs.
        timeout: Seconds allowed per driver invocation; forwarded to
            ``get_gemm_time``.
    """
    prec_str = "half" if fp16 else "int"
    sizes = list(sizes)
    if sizes and isinstance(sizes[0], int):
        # Same order as nested loops over m, n, k.
        shapes = list(itertools.product(sizes, repeat=3))
    else:
        shapes = [tuple(shape) for shape in sizes]

    for b in batches:
        out = CSVFile(f"gemm_perf_{prec_str}_{b}.csv")
        out.write_row([get_device_name(), datetime.datetime.now()])
        out.write_row(["batch_size", "m", "n", "k", "CK Total Time (ms)", "rocBLAS Total Time (ms)", "MLIR Total Time (ms)"])
        for shape in shapes:
            config = (b,) + shape
            # Column order must match the header row above.
            times = [
                get_gemm_time(config, fp16, provider, timeout)
                for provider in (GEMM_Provider.CK,
                                 GEMM_Provider.ROCBLAS,
                                 GEMM_Provider.MLIR)
            ]
            out.write_row(list(config) + times)


def run_gemm_softmax_gemm_perf(batches, sizes, timeout=600):
    """Time every GEMM-Softmax-GEMM shape with each provider, one CSV per batch.

    For each batch size ``b`` the file ``gemm_softmax_gemm_perf_fp16_<b>.csv``
    is appended with a device/timestamp row, a header row and one row per
    shape holding the CK, rocBLAS and MLIR timings.

    Args:
        batches: Iterable of batch sizes.
        sizes: Either a list of ints, which is expanded to every
            ``(m, n, k, o)`` combination, or a list of ready-made
            ``(m, n, k, o)`` tuples, which is used as given.
        timeout: Seconds allowed per driver invocation; forwarded to
            ``get_gemm_softmax_gemm_time``.
    """
    sizes = list(sizes)
    if sizes and isinstance(sizes[0], int):
        # Same order as nested loops over m, n, k, o.
        shapes = list(itertools.product(sizes, repeat=4))
    else:
        shapes = [tuple(shape) for shape in sizes]

    for b in batches:
        out = CSVFile(f"gemm_softmax_gemm_perf_fp16_{b}.csv")
        out.write_row([get_device_name(), datetime.datetime.now()])
        out.write_row(["batch_size", "m", "n", "k", "o", "CK Total Time (ms)", "rocBLAS Total Time (ms)", "MLIR Total Time (ms)"])
        for shape in shapes:
            config = (b,) + shape
            # Column order must match the header row above.
            times = [
                get_gemm_softmax_gemm_time(config, provider, timeout)
                for provider in (GEMM_Provider.CK,
                                 GEMM_Provider.ROCBLAS,
                                 GEMM_Provider.MLIR)
            ]
            out.write_row(list(config) + times)


if __name__ == "__main__":
    # Script entry point: build the list of problem sizes, then run the
    # sweeps selected on the command line.
    args = parse_args()
    # --exclude-lens is optional, so the attribute may be None or empty.
    exclude = {int(x) for x in (args.exclude_lens or [])}
    sizes = [64, 256, 384, 512, 768, 1024, 1920, 2048, 2304, 3072, 4096]
    sizes = [x for x in sizes if x not in exclude]
    fp16 = not args.int8
    timeout = int(args.timeout)
    if args.gemm:
        gemm_sizes = [(m, n, k) for m in sizes for n in sizes for k in sizes]
        # The parser stores the batch sizes under 'batch_sizes'.
        run_gemm_perf(args.batch_sizes, gemm_sizes, fp16, timeout)
    if args.gemm_softmax_gemm:
        gemm_softmax_gemm_sizes = [(m, n, k, o) for m in sizes for n in sizes for k in sizes for o in sizes]
        # Pass the 4-dimensional shapes built here, not the 3-dimensional
        # GEMM shapes (which do not even exist when --gemm was not given).
        run_gemm_softmax_gemm_perf(args.batch_sizes, gemm_softmax_gemm_sizes, timeout)