tune_models.py
import os, subprocess, argparse, time, json
import tune_ck as tc


def parse_args():
    parser = argparse.ArgumentParser(
        description="Tune CK GEMMs for one or more ONNX models")
    parser.add_argument('--models',
                        '-m',
                        nargs='+',
                        help='ONNX models to be tuned',
                        required=True)
    parser.add_argument('--batch_sizes',
                        '-b',
                        nargs='+',
                        help='Batch sizes to tune',
                        required=True)
    parser.add_argument('--sequence_length',
                        '-s',
                        type=int,
                        default=384,
                        help='Sequence length for transformer models')
    parser.add_argument('-n',
                        type=int,
                        default=18,
                        help='Number of instances to tune')
    parser.add_argument(
        '--update',
        '-u',
        type=str,
        help=
        'Existing tuning JSON. Configs already present will not be re-tuned.')
    parser.add_argument("-q", "--quantize_int8", action="store_true")
    args = parser.parse_args()
    return args


def tune_models(models, batch_sizes, seq_len, n, existing, q_int8):
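    """Collect and tune CK GEMM configs for each model and batch size.

    Each model is run through the MIGraphX driver with CK GEMM logging
    enabled, the unique configs are tuned with tune_ck, and the results are
    written to a timestamped JSON file.
    """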
    time_stamp = time.strftime("%Y_%m_%d_%H_%M")
    log_file = "ck_tuning_{}.log".format(time_stamp)
    json_file = "ck_tuning_{}.json".format(time_stamp)
    prec_str = "--int8" if q_int8 else ""
    for model in models:
        for batch in batch_sizes:
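            # Default driver arguments; overridden below for BERT- and
            # SQuAD-style models.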
            params = "--input-dim @sample {} 4 64 64 @timestep 1 @encoder_hidden_states {} 64 1024 --fp16 {} ".format(
                batch, batch, prec_str)
            if "bert" in model:
                params = "{} --fp16 --fill1 input_ids --input-dim @input_ids {} {} ".format(
                    prec_str, batch, seq_len)
            if "squad" in model:
                params = "--fill1 input_ids:0 unique_ids_raw_output___9:0 input_mask:0 segment_ids:0 --input-dim @input_ids:0 {} 256 @input_mask:0 {} 256 @segment_ids:0 {} 256 --fp16 {}".format(
                    batch, batch, batch, prec_str)
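            # Run the model through the MIGraphX driver with CK GEMM logging
            # enabled and append the unique ck_gemm configs to the log file.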
            subprocess.run(
                'MIGRAPHX_LOG_CK_GEMM=1 ../build/bin/driver run {} -g {} | grep \'ck_gemm.*: \[{{\' | sort -u >> {}'
                .format(model, params, log_file),
                capture_output=True,
                check=True,
                shell=True)

    if existing is not None:
        # Keep only the configs that are not already present in the existing
        # tuning JSON, so previously tuned configs are not re-tuned.
        f = open(existing)
        configs = json.load(f)
        configs = [str(s).replace(" ", "") for l in configs for s in l]
        update_logs = []
        with open(log_file, "r") as lf:
            logs = [line for line in lf]
            # Normalize log lines so they can be compared against the configs
            # recorded in the existing tuning JSON.
            stripped_logs = [
                line.replace("ck_gemm: ", "")
                .replace("ck_gemm_softmax_gemm: ", "")
                .replace("\"", "'")
                .replace("\n", "") for line in logs
            ]

            for raw, stripped in zip(logs, stripped_logs):
                if stripped not in configs:
                    update_logs.append(raw)

        with open(log_file, "w") as lf:
            for line in update_logs:
                lf.write(line)

        f.close()

    # Tune the collected CK GEMM configs and write the results to a JSON file.
    tc.tune(log_file, n, json_file)

    # When updating, merge the previous tuning results with the newly
    # generated ones so the output JSON covers both.
    if existing is not None:
        with open(existing, "r") as f_old, open(json_file, "r") as f_new:
            old = json.load(f_old)
            new = json.load(f_new)
        with open(json_file, "w") as f_out:
            json.dump(old + new, f_out)

    # Point MIGRAPHX_CK_TUNING at the results so MIGraphX runs launched from
    # this process pick them up.
    tuning_path = os.path.abspath(json_file)
    os.environ["MIGRAPHX_CK_TUNING"] = tuning_path
    print("\nTuning results have been saved to:\n{}\n".format(json_file))


def run(args):
    tune_models(args.models, args.batch_sizes, args.sequence_length, args.n,
                args.update, args.quantize_int8)


if __name__ == "__main__":
    run(parse_args())
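# Example invocations (file names and batch sizes are placeholders):
#   python tune_models.py -m model_a.onnx model_b.onnx -b 1 32 -n 18
#   python tune_models.py -m model_a.onnx -b 1 -u previous_tuning.json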