tune_ck.py 1.78 KB
Newer Older
Paul's avatar
Paul committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import os, json, subprocess, tempfile, sys, argparse


def benchmark_one(config, tuning):
    """Compile and time one ``ck_gemm`` configuration with one tuning value.

    A request (100 iterations of the ``ck_gemm`` op) is written to a temporary
    JSON file and ``./bin/gpu-driver`` is run on it.

    Args:
        config: Passed through unchanged as the op's ``inputs``.
        tuning: Passed as the op's ``tuning_val``.

    Returns:
        The time parsed from the last comma-separated field of the first
        parseable, non-empty stdout line, with its final two characters
        stripped (presumably a unit suffix such as ``ms`` -- TODO confirm
        against gpu-driver's output format). ``sys.float_info.max`` if no
        such line is printed, so a failed tuning is never selected as best.
    """
    request = {
        'settings': {'iterations': 100},
        'compile_op': {
            'name': 'ck_gemm',
            'tuning_val': tuning,
            'inputs': config
        }
    }
    with tempfile.NamedTemporaryFile(mode="w+") as tf:
        json.dump(request, tf)
        # The driver is a separate process reading tf.name; flush so it sees
        # the complete request instead of whatever was buffered to disk.
        tf.flush()
        # Pass argv as a list: a single string without shell=True is taken
        # as the program name on POSIX and fails to launch.
        cp = subprocess.run(['./bin/gpu-driver', tf.name], capture_output=True)
    for line in cp.stdout.decode().split("\n"):
        s = line.strip()
        if not s:
            continue
        dtime = s.split(',')[-1].strip()
        try:
            return float(dtime[:-2])
        except ValueError:
            # Not a timing line (e.g. a header or an error message); keep
            # scanning rather than aborting the whole tuning run.
            continue
    return sys.float_info.max

def benchmark(config, size):
    """Return the index in ``range(size)`` of the fastest tuning for *config*.

    Each tuning value ``0 .. size-1`` is timed with ``benchmark_one``.

    Raises:
        ValueError: if ``size`` is 0 (there is nothing to choose from).
    """
    times = [benchmark_one(config, i) for i in range(size)]
    # Lower time is better. Failed runs report sys.float_info.max, so
    # taking the minimum also guarantees they are never chosen.
    return times.index(min(times))

def benchmark_log(f, size=13):
    """Find the best tuning for every ``ck_gemm`` configuration in a log.

    Args:
        f: Path to a log file. Lines of the form ``ck_gemm: <json>`` are
            parsed as configurations; every other line is ignored.
        size: Number of tuning values (``0 .. size-1``) tried per
            configuration. Defaults to 13, the previously hard-coded count.

    Returns:
        A list of ``[config, best_tuning_index]`` pairs, in log order.

    Raises:
        json.JSONDecodeError: if a ``ck_gemm:`` line carries invalid JSON.
    """
    prefix = 'ck_gemm:'
    configs = []
    # Read and parse everything first so the file is closed before the
    # (potentially long) benchmarking starts.
    with open(f) as log:
        for line in log:
            line = line.strip()
            if line.startswith(prefix):
                configs.append(json.loads(line[len(prefix):].strip()))
    return [[config, benchmark(config, size)] for config in configs]


def parse_args(argv=None):
    """Parse the tuner's command-line options.

    Args:
        argv: Argument list to parse. ``None`` (the default) parses
            ``sys.argv[1:]``, as before.

    Returns:
        An ``argparse.Namespace`` with ``log`` (input log path) and ``out``
        (output JSON path).
    """
    parser = argparse.ArgumentParser(
        description="Simple tuner for CK gemms")
    # Both paths are required: run() opens them unconditionally, so omitting
    # either would otherwise fail later with an unhelpful TypeError.
    parser.add_argument('--log', '-l',
                        type=str,
                        metavar='file',
                        required=True,
                        help='Path to logfile')
    parser.add_argument('--out', '-o',
                        type=str,
                        metavar='file',
                        required=True,
                        help='Output json file to save tunings')
    return parser.parse_args(argv)

def run(args):
    """Tune every configuration in ``args.log`` and write results to ``args.out``.

    The output is a JSON list of ``[config, best_tuning_index]`` pairs.
    """
    tuned = benchmark_log(args.log)
    # Context manager guarantees the JSON is flushed and the file closed,
    # rather than relying on garbage collection of an anonymous handle.
    with open(args.out, 'w') as out:
        json.dump(tuned, out)

# Only run the tuner when executed as a script, so importing this module
# (e.g. to reuse benchmark helpers) has no side effects.
if __name__ == '__main__':
    run(parse_args())