"test/onnx/slice_5arg_reverse_test.onnx" did not exist on "8fa33f1ad5005df75a66cf2734d58fdbedaa6333"
tune_ck.py 2.29 KB
Newer Older
Paul's avatar
Fixes  
Paul committed
1
import os, json, subprocess, tempfile, sys, argparse, contextlib
Paul's avatar
Paul committed
2

Paul's avatar
Format  
Paul committed
3

Paul's avatar
Fixes  
Paul committed
4
5
6
7
8
9
10
11
12
13
14
15
@contextlib.contextmanager
def tmp_file(dump=None):
    """Yield the path of a named temporary file that is removed on exit.

    The file is created in text mode ('w+') with delete=False and is closed
    before its path is yielded, so other processes can open it by name.

    :param dump: optional callable receiving the open file object; use it to
        write the initial contents.
    :yields: the temporary file's path.
    """
    tmp_name = None
    try:
        with tempfile.NamedTemporaryFile(mode='w+', delete=False) as f:
            tmp_name = f.name
            if dump:
                dump(f)
        yield tmp_name
    finally:
        # If creating the file failed, tmp_name is still None: skip cleanup so
        # the original exception is not masked by a TypeError. The consumer may
        # also have removed the file already, which is not an error here.
        if tmp_name is not None:
            with contextlib.suppress(FileNotFoundError):
                os.unlink(tmp_name)

Paul's avatar
Format  
Paul committed
16

Paul's avatar
Fixes  
Paul committed
17
def pretty_print(obj):
    """Write *obj* to stdout as JSON indented by two spaces."""
    rendered = json.dumps(obj, indent=2)
    print(rendered)
Paul's avatar
Paul committed
19

Paul's avatar
Format  
Paul committed
20

Paul's avatar
Paul committed
21
def benchmark_one(config, tuning):
Paul's avatar
Fixes  
Paul committed
22
23
24
25
26
27
28
29
    b = {
        'settings': {
            'iterations': 100
        },
        'compile_op': {
            'name': 'ck_gemm',
            'tuning_val': tuning,
            'inputs': config
Paul's avatar
Paul committed
30
        }
Paul's avatar
Fixes  
Paul committed
31
32
33
34
    }
    print(b)
    with tmp_file(lambda tf: json.dump(b, tf)) as tf:
        cp = subprocess.run('./bin/gpu-driver {}'.format(tf),
Paul's avatar
Format  
Paul committed
35
36
                            capture_output=True,
                            shell=True)
Paul's avatar
Paul committed
37
38
39
40
        for line in cp.stdout.decode().split("\n"):
            s = line.strip()
            if not s:
                continue
Paul's avatar
Paul committed
41
42
            if not ',' in s:
                continue
Paul's avatar
Paul committed
43
44
            fields = s.split(',')
            dtime = fields[-1].strip()
Paul's avatar
Fixes  
Paul committed
45
            print(dtime)
Paul's avatar
Paul committed
46
47
48
            return float(dtime[:-2])
    return sys.float_info.max

Paul's avatar
Format  
Paul committed
49

Paul's avatar
Paul committed
50
51
def benchmark(config, size):
    """Return the tuning index in range(size) with the smallest measured time.

    Each index is benchmarked once via benchmark_one; on ties the lowest index
    wins. A non-positive size leaves nothing to compare and raises ValueError.
    """
    timings = [benchmark_one(config, idx) for idx in range(size)]
    best_idx, _ = min(enumerate(timings), key=lambda pair: pair[1])
    return best_idx
Paul's avatar
Paul committed
53

Paul's avatar
Format  
Paul committed
54

Paul's avatar
Paul committed
55
56
57
58
59
60
61
62
63
64
65
66
67
68
def benchmark_log(f, size=13):
    """Tune every 'ck_gemm:' entry found in the log file *f*.

    Lines (after stripping) that start with 'ck_gemm:' carry a JSON payload
    after the prefix; that decoded payload is the config passed to benchmark().
    All other lines are ignored.

    :param f: path to the log file.
    :param size: number of tuning values tried per config (default 13, the
        value previously hard-coded).
    :returns: list of [config, best_tuning_index] pairs in log order.
    """
    prefix = 'ck_gemm:'
    # Read everything up front so the file is closed before the (slow)
    # benchmarking starts.
    with open(f) as log:
        lines = log.readlines()
    result = []
    for line in lines:
        line = line.strip()
        if not line.startswith(prefix):
            continue
        config = json.loads(line[len(prefix):].strip())
        result.append([config, benchmark(config, size)])
    return result


def parse_args(argv=None):
    """Parse the tuner's command-line options.

    :param argv: argument list to parse; None means argparse's default
        (sys.argv[1:]).
    :returns: argparse.Namespace with 'log' (input logfile path) and 'out'
        (output JSON path).
    Exits via argparse (SystemExit) when an option is missing, since both
    are needed by run().
    """
    parser = argparse.ArgumentParser(description="Simple tuner for CK gemms")
    parser.add_argument('--log',
                        '-l',
                        type=str,
                        metavar='file',
                        required=True,
                        help='Path to logfile')
    parser.add_argument('--out',
                        '-o',
                        type=str,
                        metavar='file',
                        required=True,
                        help='Output json file to save tunings')
    return parser.parse_args(argv)

Paul's avatar
Format  
Paul committed
83

Paul's avatar
Paul committed
84
85
86
87
def run(args):
    """Tune every gemm in args.log and write the results to args.out.

    The output is JSON: the list of [config, best_tuning_index] pairs
    returned by benchmark_log().
    """
    tuned = benchmark_log(args.log)
    # Close (and therefore flush) the output file deterministically instead
    # of relying on garbage collection.
    with open(args.out, 'w') as out:
        json.dump(tuned, out)

Paul's avatar
Format  
Paul committed
88

Paul's avatar
Paul committed
89
# Only tune when executed as a script, not when imported.
if __name__ == '__main__':
    run(parse_args())