tune_ck.py 2.54 KB
Newer Older
Paul's avatar
Fixes  
Paul committed
1
import os, json, subprocess, tempfile, sys, argparse, contextlib
Paul's avatar
Paul committed
2

Paul's avatar
Format  
Paul committed
3

Paul's avatar
Fixes  
Paul committed
4
5
6
7
8
9
10
11
12
13
14
15
@contextlib.contextmanager
def tmp_file(dump=None):
    """Yield the path of a closed temporary file and remove it on exit.

    Args:
        dump: Optional callable invoked with the open text-mode file
            object so the caller can write initial content before the
            file is closed.

    Yields:
        str: Path of the temporary file. The handle is already closed,
        so the file can be reopened by name.
    """
    tmp_name = None
    try:
        # delete=False keeps the file on disk after the handle closes;
        # removal happens explicitly in the finally clause below.
        with tempfile.NamedTemporaryFile(mode='w+', delete=False) as f:
            tmp_name = f.name
            if dump:
                dump(f)
        yield tmp_name
    finally:
        # tmp_name is still None when creating the file itself failed;
        # unlinking None would raise TypeError and hide the real error.
        if tmp_name is not None:
            os.unlink(tmp_name)

Paul's avatar
Format  
Paul committed
16

Paul's avatar
Fixes  
Paul committed
17
def pretty_print(obj):
    """Print ``obj`` to stdout as JSON indented by two spaces.

    Args:
        obj: Any JSON-serializable value.
    """
    print(json.dumps(obj, indent=2))
Paul's avatar
Paul committed
19

Paul's avatar
Paul committed
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
def run_driver(b):
    """Run ``./bin/gpu-driver`` on a JSON job and yield its result lines.

    ``b`` is printed, serialized to a temporary JSON file, and the path of
    that file is passed to the driver as its only argument. The driver's
    exit status is not inspected.

    Args:
        b: JSON-serializable job description.

    Yields:
        str: For each stdout line containing ``']: '``, the stripped text
        between the first and second occurrence of that separator (or up
        to the end of the line). Other lines are skipped.
    """
    print(b)
    with tmp_file(lambda tf: json.dump(b, tf)) as tf:
        # An argument list (no shell) hands the temp path to the driver
        # verbatim, even if it contains spaces or shell metacharacters.
        cp = subprocess.run(['./bin/gpu-driver', tf], capture_output=True)
        for line in cp.stdout.decode().split("\n"):
            s = line.strip()
            # Blank lines cannot contain the separator, so this single
            # check also filters them out.
            if ']: ' not in s:
                continue
            yield s.split(']: ')[1].strip()

def convert_to_float(s):
    """Parse a number that carries a two-character suffix into a float.

    Args:
        s: Text such as ``'1.5ms'``. The last two characters are dropped
            unconditionally; the suffix itself is not checked.

    Returns:
        float: The numeric value of the remaining text.

    Raises:
        ValueError: If the text without its suffix is not a number.
    """
    # Return a real float rather than the sliced string so callers compare
    # values numerically; as strings, '10.5' sorts before '9.2'.
    return float(s[:-2])
Paul's avatar
Format  
Paul committed
36

Paul's avatar
Paul committed
37
38
39
40
41
def get_device_time(s):
    """Return the converted last comma-separated field of ``s``.

    The text after the final comma (the whole string when there is no
    comma) is stripped of surrounding whitespace and passed to
    ``convert_to_float``.
    """
    _, _, last_field = s.rpartition(',')
    return convert_to_float(last_field.strip())

def benchmark_ck(config, tuning):
    """Time one ``ck_gemm`` tuning value for ``config`` using the driver.

    Args:
        config: Value passed to the driver as the operation's ``inputs``.
        tuning: Value passed to the driver as ``tuning_val``.

    Returns:
        float: The time parsed from the first driver output line that
        yields a number, or ``sys.float_info.max`` when no line does, so
        a failed run can never be selected as the minimum.
    """
    b = {
        'settings': {
            'iterations': 100
        },
        'compile_op': {
            'name': 'ck_gemm',
            'tuning_val': tuning,
            'inputs': config
        }
    }
    for line in run_driver(b):
        try:
            # float() accepts both str and float, so the result is always
            # numeric and comparable with the float sentinel below.
            dtime = float(get_device_time(line))
        except ValueError:
            # Not a timing line; keep looking.
            continue
        print(dtime)
        return dtime
    return sys.float_info.max

Paul's avatar
Format  
Paul committed
58

Paul's avatar
Paul committed
59
def benchmark(config, size):
    """Return the tuning index in ``range(size)`` with the smallest time.

    Every index is benchmarked with ``benchmark_ck``; ties resolve to the
    lowest index. Raises ``ValueError`` when ``size`` is 0 because there
    is nothing to take the minimum of.
    """
    timings = []
    for tuning in range(size):
        timings.append(benchmark_ck(config, tuning))
    fastest = min(timings)
    return timings.index(fastest)
Paul's avatar
Paul committed
62

Paul's avatar
Format  
Paul committed
63

Paul's avatar
Paul committed
64
def parse_log(f):
    """Yield the JSON payload of every ``ck_gemm:`` line in a log file.

    Args:
        f: Path of the log file to read.

    Yields:
        The object decoded from the text that follows the ``ck_gemm:``
        prefix on each matching line; other lines are ignored.

    Raises:
        json.JSONDecodeError: If a matching line's payload is not valid
            JSON.
    """
    prefix = 'ck_gemm:'
    # The with-block closes the file once iteration finishes or the
    # generator is closed; iterating the handle avoids loading the whole
    # log into memory.
    with open(f) as log:
        for line in log:
            line = line.strip()
            if not line.startswith(prefix):
                continue
            yield json.loads(line[len(prefix):].strip())

def benchmark_log(f, size=13):
    """Find the best tuning index for every config recorded in a log.

    Args:
        f: Path of the log file handed to ``parse_log``.
        size: Number of tuning indices tried per config. Defaults to 13,
            the value that was previously hard-coded.

    Returns:
        list: ``[config, best_index]`` pairs in log order.
    """
    return [[config, benchmark(config, size)] for config in parse_log(f)]


def parse_args(argv=None):
    """Parse the tuner's command-line options.

    Args:
        argv: Optional list of argument strings. When ``None`` (the
            default) argparse reads ``sys.argv[1:]``.

    Returns:
        argparse.Namespace: With ``log`` (input log path) and ``out``
        (output JSON path) attributes.
    """
    parser = argparse.ArgumentParser(description="Simple tuner for CK gemms")
    # Both paths are opened unconditionally later on, so a missing option
    # is reported here as a usage error instead of failing in open(None).
    parser.add_argument('--log',
                        '-l',
                        type=str,
                        metavar='file',
                        required=True,
                        help='Path to logfile')
    parser.add_argument('--out',
                        '-o',
                        type=str,
                        metavar='file',
                        required=True,
                        help='Output json file to save tunings')
    return parser.parse_args(argv)

Paul's avatar
Format  
Paul committed
96

Paul's avatar
Paul committed
97
98
99
100
def run(args):
    """Benchmark every config in ``args.log`` and save results as JSON.

    Args:
        args: Object with ``log`` (input log path) and ``out`` (output
            file path) attributes.
    """
    tuned = benchmark_log(args.log)
    # The context manager flushes and closes the output file
    # deterministically instead of leaving that to garbage collection.
    with open(args.out, 'w') as out:
        json.dump(tuned, out)

Paul's avatar
Format  
Paul committed
101

Paul's avatar
Paul committed
102
# Run the tuner only when executed as a script, so importing this module
# does not parse sys.argv or launch benchmarks.
if __name__ == '__main__':
    run(parse_args())