tune_ck.py 4.07 KB
Newer Older
1
import os, json, subprocess, tempfile, sys, argparse, contextlib, multiprocessing, multiprocessing.dummy
Paul's avatar
Paul committed
2

Alan Turner's avatar
Alan Turner committed
3
# Module-level marker, initialised to -1. parse_log() sets it to 1 when it
# meets a 'ck_gemm_softmax_gemm:' log entry; nothing visible in this file
# reads it back.
ck_function = -1
Paul's avatar
Format  
Paul committed
4

Alan Turner's avatar
Alan Turner committed
5

Paul's avatar
Fixes  
Paul committed
6
7
8
9
10
11
12
13
14
15
16
17
@contextlib.contextmanager
def tmp_file(dump=None):
    """Context manager yielding the path of a freshly created temporary file.

    The file is created in text mode with ``delete=False``, optionally filled
    by calling ``dump`` with the open file object, then closed before its
    path is yielded. The file is removed when the ``with`` block exits,
    whether normally or through an exception.

    Args:
        dump: Optional callable taking the open file object and writing the
            initial contents. Skipped when falsy.

    Yields:
        The path (``str``) of the temporary file.
    """
    tmp_name = None
    try:
        with tempfile.NamedTemporaryFile(mode='w+', delete=False) as f:
            tmp_name = f.name
            if dump:
                dump(f)
        yield tmp_name
    finally:
        # tmp_name is still None when creating the file itself failed;
        # os.unlink(None) would raise TypeError and hide the real error.
        if tmp_name is not None:
            os.unlink(tmp_name)

Paul's avatar
Format  
Paul committed
18

Paul's avatar
Fixes  
Paul committed
19
def pretty_print(obj):
    """Print ``obj`` to stdout as JSON indented by two spaces."""
    rendered = json.dumps(obj, indent=2)
    print(rendered)
Paul's avatar
Paul committed
21

Paul's avatar
Format  
Paul committed
22

Paul's avatar
Paul committed
23
24
25
def run_driver(b):
    """Run ``./bin/gpu-driver`` on the JSON form of ``b`` and yield its output.

    ``b`` is printed, serialised with ``json.dump`` into a temporary file and
    the path of that file is passed to the driver as its only argument. The
    driver's stderr is echoed to stdout. Every stdout line that contains the
    marker ``']: '`` contributes one result: the text between the first and
    second marker (or the end of the line), stripped of whitespace.

    This is a generator, so nothing happens until it is iterated.

    Args:
        b: JSON-serialisable object describing the driver request.

    Yields:
        str: One extracted payload per matching stdout line.

    Raises:
        subprocess.CalledProcessError: If the driver exits with a non-zero
            status.
    """
    print(b)
    with tmp_file(lambda tf: json.dump(b, tf)) as tf:
        if not os.path.exists('./bin/gpu-driver'):
            print("./bin/gpu-driver not found")
            # os.abort() terminates the process immediately; callers'
            # exception handlers cannot intercept it.
            os.abort()
        # An argument list (no shell) hands the temporary path to the driver
        # verbatim, even if it contains spaces or shell metacharacters.
        cp = subprocess.run(['./bin/gpu-driver', tf], capture_output=True)
        print(cp.stderr.decode())
        cp.check_returncode()
        for line in cp.stdout.decode().split("\n"):
            s = line.strip()
            # Blank lines never contain the marker, so one test covers both.
            if ']: ' not in s:
                continue
            yield s.split(']: ')[1].strip()

Paul's avatar
Format  
Paul committed
42

Paul's avatar
Paul committed
43
44
def convert_to_float(s):
    """Return ``s`` without its last two characters.

    Despite the name, the result is still a string; callers convert it with
    ``float()`` themselves.
    NOTE(review): the two dropped characters are presumably a unit suffix
    such as "ms" -- confirm against the driver's output format.
    """
    suffix_len = 2
    return s[:-suffix_len]
Paul's avatar
Format  
Paul committed
45

Paul's avatar
Format  
Paul committed
46

Paul's avatar
Paul committed
47
48
49
50
def get_device_time(s):
    """Extract the time value from the last comma-separated field of ``s``.

    The text after the final comma (or all of ``s`` when it has no comma) is
    stripped of whitespace and passed through ``convert_to_float``, whose
    result is returned unchanged.
    """
    last_field = s.rpartition(',')[2]
    return convert_to_float(last_field.strip())

Paul's avatar
Format  
Paul committed
51

Alan Turner's avatar
Alan Turner committed
52
def run_driver_ck(config, name, tuning, iterations):
    """Build a ``compile_op`` driver request and pass it to ``run_driver``.

    Args:
        config: Value stored under ``compile_op.inputs``.
        name: Value stored under ``compile_op.name``.
        tuning: Value stored under ``compile_op.tuning_val``.
        iterations: Value stored under ``settings.iterations``.

    Returns:
        Whatever ``run_driver`` returns for the assembled request.
    """
    compile_op = {
        'name': name,
        'check': True,
        'tuning_val': tuning,
        'inputs': config,
    }
    request = {
        'settings': {'iterations': iterations},
        'compile_op': compile_op,
    }
    return run_driver(request)
Paul's avatar
Format  
Paul committed
65

Paul's avatar
Format  
Paul committed
66

Alan Turner's avatar
Alan Turner committed
67
def benchmark_ck(config, name, tuning):
    """Benchmark one tuning value and return its measured time.

    Runs the driver with 100 iterations and uses only the first result line:
    its time field is printed and returned as a float.

    Args:
        config: Inputs forwarded to ``run_driver_ck``.
        name: Operation name forwarded to ``run_driver_ck``.
        tuning: Tuning value forwarded to ``run_driver_ck``.

    Returns:
        float: The parsed time, or ``sys.float_info.max`` when the driver
        raised, its output could not be parsed, or it produced no result
        line (in which case "Failed" is printed first).
    """
    try:
        for line in run_driver_ck(config, name, tuning, 100):
            dtime = get_device_time(line)
            print(dtime)
            return float(dtime)
    except Exception:
        # Catch Exception rather than using a bare ``except`` so that
        # KeyboardInterrupt still stops the tuner.
        return sys.float_info.max
    # No result line at all. The former ``sys.exit(1)`` here was swallowed by
    # the bare ``except`` and therefore also ended in the sentinel value, so
    # returning it directly keeps the observable behaviour.
    print("Failed")
    return sys.float_info.max
Paul's avatar
Paul committed
77

Paul's avatar
Format  
Paul committed
78

79
80
def benchmark(config, name, size):
    """Return the tuning index in ``range(size)`` with the smallest time.

    Each index is timed once with ``benchmark_ck``; on ties the lowest index
    is returned.
    """
    timings = []
    for tuning_index in range(size):
        timings.append(benchmark_ck(config, name, tuning_index))
    fastest = min(timings)
    return timings.index(fastest)
Paul's avatar
Paul committed
82

Paul's avatar
Format  
Paul committed
83

Paul's avatar
Paul committed
84
def parse_log(f):
    """Yield ``(config, name)`` pairs for the CK entries found in a log file.

    A line is an entry when, after stripping whitespace, it starts with
    ``ck_gemm:`` or ``ck_gemm_softmax_gemm:``; the rest of the line is parsed
    as JSON. All other lines are ignored.

    Side effect: the module-level ``ck_function`` is set to 1 when a
    ``ck_gemm_softmax_gemm:`` entry is processed.

    Args:
        f: Path of the log file.

    Yields:
        tuple: ``(config, name)`` where ``config`` is the decoded JSON value
        and ``name`` is ``'ck_gemm'`` or ``'ck_gemm_softmax_gemm'``.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If an entry's payload is not valid JSON.
    """
    global ck_function
    # Read everything up front and close the file straight away, so the
    # handle is not left open while the caller consumes the generator.
    with open(f) as log:
        lines = log.readlines()
    for line in lines:
        line = line.strip()
        if line.startswith('ck_gemm:'):
            config = json.loads(line[len('ck_gemm:'):].strip())
            yield (config, 'ck_gemm')
        elif line.startswith('ck_gemm_softmax_gemm:'):
            config = json.loads(line[len('ck_gemm_softmax_gemm:'):].strip())
            ck_function = 1
            yield (config, 'ck_gemm_softmax_gemm')
Paul's avatar
Paul committed
97

Paul's avatar
Format  
Paul committed
98

99
100
def precompile(x):
    """Run the driver once with zero iterations for one solution.

    Used as the worker function for the precompile pool. The driver output
    is drained and discarded. A failure of one solution is reported on
    stderr and does not propagate.

    Args:
        x: Sequence whose first three items are ``(config, name, tuning)``.
    """
    try:
        list(run_driver_ck(x[0], x[1], x[2], 0))
    except Exception as err:
        # Best effort: keep the remaining work going, but do not hide the
        # failure, and let KeyboardInterrupt/SystemExit through.
        print("Precompile failed for {} (tuning {}): {!r}".format(
            x[1], x[2], err),
              file=sys.stderr)

Paul's avatar
Format  
Paul committed
105

106
def precompile_log(f, n, processes=24):
    """Precompile every (log entry, tuning index) combination in parallel.

    Args:
        f: Path of the log file, read with ``parse_log``.
        n: Number of tuning indices (``0 .. n-1``) to run per log entry.
        processes: Number of worker processes in the pool. Defaults to 24,
            the value that was previously hard-coded.
    """
    solutions = ((config, name, i) for config, name in parse_log(f)
                 for i in range(n))
    with multiprocessing.Pool(processes) as p:
        # Drain the iterator so every task finishes before the pool is torn
        # down; the results themselves carry no information.
        for _ in p.imap(precompile, solutions):
            pass
Paul's avatar
Format  
Paul committed
110

Paul's avatar
Format  
Paul committed
111

112
def benchmark_log(f, n):
    """Tune every entry of log file ``f`` over ``n`` tuning indices.

    For each ``(config, name)`` pair from ``parse_log`` the best index found
    by ``benchmark`` is printed and recorded.

    Returns:
        list: One ``[config, best_index]`` pair per log entry, in log order.
    """
    tunings = []
    for entry in parse_log(f):
        config, name = entry
        best = benchmark(config, name, n)
        print("Tuned:", best)
        tunings.append([config, best])
    return tunings


def parse_args():
    """Parse the tuner's command-line arguments from ``sys.argv``.

    ``--log``, ``--out`` and ``-n`` are mandatory: without any of them the
    tuner fails later with a ``TypeError`` (for ``--out`` only after all the
    benchmarking has run), so they are rejected up front with a usage
    message instead.

    Returns:
        argparse.Namespace: with attributes ``log`` (str), ``out`` (str),
        ``precompile`` (bool) and ``n`` (int).
    """
    parser = argparse.ArgumentParser(description="Simple tuner for CK gemms")
    parser.add_argument('--log',
                        '-l',
                        type=str,
                        metavar='file',
                        required=True,
                        help='Path to logfile')
    parser.add_argument('--out',
                        '-o',
                        type=str,
                        metavar='file',
                        required=True,
                        help='Output json file to save tunings')
    parser.add_argument('--precompile',
                        '-p',
                        action='store_true',
                        help='Precompile kernels first in parallel')
    parser.add_argument('-n',
                        type=int,
                        required=True,
                        help='Number of instances to tune')
    args = parser.parse_args()
    return args

Paul's avatar
Format  
Paul committed
141

Paul's avatar
Paul committed
142
def run(args):
    """Run the tuner as configured by the parsed command-line ``args``.

    When ``args.precompile`` is set, the kernels are precompiled first. The
    log ``args.log`` is then benchmarked over ``args.n`` tuning indices and
    the result is written as JSON to ``args.out``.

    Args:
        args: Object with ``precompile``, ``log``, ``n`` and ``out``
            attributes, as returned by ``parse_args``.
    """
    if args.precompile:
        precompile_log(args.log, args.n)
    tuned = benchmark_log(args.log, args.n)
    # The with-block flushes and closes the output file deterministically.
    with open(args.out, 'w') as out_file:
        json.dump(tuned, out_file)

Alan Turner's avatar
Alan Turner committed
148

Alan Turner's avatar
Alan Turner committed
149
150
151
def tune(log, n, out):
    """Benchmark log file ``log`` and write the tunings as JSON to ``out``.

    Args:
        log: Path of the log file to read.
        n: Number of tuning indices to try per log entry.
        out: Path of the JSON file to write (overwritten if it exists).
    """
    tuned = benchmark_log(log, n)
    # The with-block flushes and closes the output file deterministically.
    with open(out, 'w') as out_file:
        json.dump(tuned, out_file)
Paul's avatar
Format  
Paul committed
152

Alan Turner's avatar
Alan Turner committed
153

Alan Turner's avatar
Alan Turner committed
154
155
# Script entry point: parse the command line and run the tuner. Skipped when
# this module is imported.
if __name__ == '__main__':
    run(parse_args())