tune_ck.py 2.24 KB
Newer Older
Paul's avatar
Fixes  
Paul committed
1
import os, json, subprocess, tempfile, sys, argparse, contextlib
Paul's avatar
Paul committed
2

Paul's avatar
Format  
Paul committed
3

Paul's avatar
Fixes  
Paul committed
4
5
6
7
8
9
10
11
12
13
14
15
@contextlib.contextmanager
def tmp_file(dump=None):
    """Yield the path of a temporary file and delete the file afterwards.

    The file is created with ``delete=False`` and is closed before its path is
    yielded, so other processes can open it by name.

    Args:
        dump: optional callable that receives the open text-mode file object;
            use it to write the initial contents.

    Yields:
        str: path of the temporary file.
    """
    tmp_name = None
    try:
        with tempfile.NamedTemporaryFile(mode='w+', delete=False) as f:
            tmp_name = f.name
            if dump:
                dump(f)
        yield tmp_name
    finally:
        # If NamedTemporaryFile itself failed there is nothing to remove, and
        # os.unlink(None) would raise TypeError and hide the original error.
        if tmp_name is not None:
            # The consumer may already have removed the file.
            with contextlib.suppress(FileNotFoundError):
                os.unlink(tmp_name)

Paul's avatar
Format  
Paul committed
16

Paul's avatar
Fixes  
Paul committed
17
def pretty_print(obj):
    """Write *obj* to stdout as JSON indented by two spaces."""
    text = json.dumps(obj, indent=2)
    print(text)
Paul's avatar
Paul committed
19

Paul's avatar
Format  
Paul committed
20

Paul's avatar
Paul committed
21
def benchmark_one(config, tuning):
    """Compile and time one ``ck_gemm`` configuration with one tuning value.

    Writes a benchmark description to a temporary JSON file and runs
    ``./bin/gpu-driver`` on it. The path is relative to the current working
    directory. The time is then parsed from the driver's stdout.

    Args:
        config: value stored under ``compile_op.inputs`` (as read from the log).
        tuning: value stored under ``compile_op.tuning_val``.

    Returns:
        float: the first parseable time in the driver output. This is the last
        comma-separated field of a line with its two-character unit suffix
        removed. Returns ``sys.float_info.max`` if the driver cannot be started
        or prints no parseable time.
    """
    b = {
        'settings': {
            'iterations': 100
        },
        'compile_op': {
            'name': 'ck_gemm',
            'tuning_val': tuning,
            'inputs': config
        }
    }
    print(b)
    with tmp_file(lambda tf: json.dump(b, tf)) as tf:
        try:
            # Argument list rather than a shell string: no shell parsing or
            # quoting of the temporary path is involved.
            cp = subprocess.run(['./bin/gpu-driver', tf], capture_output=True)
        except OSError:
            # The previous shell invocation produced empty stdout here, which
            # meant "no time"; keep reporting a failed run the same way.
            return sys.float_info.max
        for line in cp.stdout.decode(errors='replace').splitlines():
            s = line.strip()
            if not s:
                continue
            dtime = s.split(',')[-1].strip()
            print(dtime)
            try:
                return float(dtime[:-2])
            except ValueError:
                # Not a timing line (e.g. a diagnostic message); a failed
                # tuning value must not abort the whole tuning run.
                continue
    return sys.float_info.max

Paul's avatar
Format  
Paul committed
47

Paul's avatar
Paul committed
48
49
def benchmark(config, size):
    """Return the tuning index in ``range(size)`` with the lowest measured time.

    Each tuning value is benchmarked once, in ascending order. Ties go to the
    smallest index. Raises ValueError when ``size`` is not positive.
    """
    timings = []
    for tuning in range(size):
        timings.append(benchmark_one(config, tuning))
    fastest = min(timings)
    return timings.index(fastest)
Paul's avatar
Paul committed
51

Paul's avatar
Format  
Paul committed
52

Paul's avatar
Paul committed
53
54
55
56
57
58
59
60
61
62
63
64
65
66
def benchmark_log(f, size=13):
    """Tune every ``ck_gemm`` configuration recorded in a log file.

    Args:
        f: path to a text log. Lines of the form ``ck_gemm: <json>`` (after
            stripping whitespace) are used; all other lines are ignored.
        size: number of tuning values tried per configuration
            (``0 .. size-1``). Defaults to 13, the previously hard-coded value.

    Returns:
        list: ``[config, best_tuning_index]`` pairs in log order.

    Raises:
        json.JSONDecodeError: if a ``ck_gemm:`` line does not hold valid JSON.
    """
    prefix = 'ck_gemm:'
    # Parse everything up front: the log is closed before the (slow)
    # benchmarking starts, and a malformed line fails before any work is done.
    with open(f) as log:
        configs = [
            json.loads(line[len(prefix):].strip())
            for line in (raw.strip() for raw in log)
            if line.startswith(prefix)
        ]
    return [[config, benchmark(config, size)] for config in configs]


def parse_args(argv=None):
    """Parse the tuner's command-line options.

    Args:
        argv: argument list to parse; ``None`` (the default) uses ``sys.argv``.

    Returns:
        argparse.Namespace with ``log`` (input log path) and ``out`` (output
        JSON path).

    Both options are required. Without them, ``run`` would fail later with an
    unhelpful ``open(None)`` error, so argparse rejects the invocation up front.
    """
    parser = argparse.ArgumentParser(description="Simple tuner for CK gemms")
    parser.add_argument('--log',
                        '-l',
                        type=str,
                        metavar='file',
                        required=True,
                        help='Path to logfile')
    parser.add_argument('--out',
                        '-o',
                        type=str,
                        metavar='file',
                        required=True,
                        help='Output json file to save tunings')
    return parser.parse_args(argv)

Paul's avatar
Format  
Paul committed
81

Paul's avatar
Paul committed
82
83
84
85
def run(args):
    """Tune every configuration in ``args.log`` and save the results as JSON.

    The output file at ``args.out`` holds a list of
    ``[config, best_tuning_index]`` pairs.
    """
    tuned = benchmark_log(args.log)
    # Use a context manager so the output is flushed and closed deterministically.
    with open(args.out, 'w') as out:
        json.dump(tuned, out)

Paul's avatar
Format  
Paul committed
86

Paul's avatar
Paul committed
87
# Only run the tuner when executed as a script, not when imported.
if __name__ == '__main__':
    run(parse_args())