Commit b56b11f1 authored by Alan Turner's avatar Alan Turner
Browse files

Formatting

parent 4da1d448
...@@ -134,9 +134,11 @@ def run(args): ...@@ -134,9 +134,11 @@ def run(args):
tuned = benchmark_log(args.log, args.n) tuned = benchmark_log(args.log, args.n)
json.dump(tuned, open(args.out, 'w+')) json.dump(tuned, open(args.out, 'w+'))
def tune(log, n, out): def tune(log, n, out):
tuned = benchmark_log(log, n) tuned = benchmark_log(log, n)
json.dump(tuned, open(out, 'w+')) json.dump(tuned, open(out, 'w+'))
# Script entry point: parse CLI arguments and run the tuner, but only
# when executed directly — importing this module stays side-effect free.
if __name__ == '__main__':
    run(parse_args())
import os, json, subprocess, tempfile, sys, argparse, contextlib, time import os, json, subprocess, tempfile, sys, argparse, contextlib, time
import tune_ck as tc import tune_ck as tc
def parse_args(): def parse_args():
parser = argparse.ArgumentParser(description="Tune CK GEMMs for one or more ONNX models") parser = argparse.ArgumentParser(
description="Tune CK GEMMs for one or more ONNX models")
parser.add_argument('--models', parser.add_argument('--models',
'-m', '-m',
nargs='+', nargs='+',
...@@ -14,7 +16,8 @@ def parse_args(): ...@@ -14,7 +16,8 @@ def parse_args():
help='Batch sizes to tune', help='Batch sizes to tune',
required=True) required=True)
parser.add_argument('--sequence_length', parser.add_argument('--sequence_length',
'-s', type=int, '-s',
type=int,
default=384, default=384,
help='Sequence length for transformer models') help='Sequence length for transformer models')
parser.add_argument('-n', parser.add_argument('-n',
...@@ -24,13 +27,16 @@ def parse_args(): ...@@ -24,13 +27,16 @@ def parse_args():
args = parser.parse_args() args = parser.parse_args()
return args return args
def tune_models(models, batch_sizes, seq_len, n): def tune_models(models, batch_sizes, seq_len, n):
time_stamp = time.strftime("%Y_%m_%d_%H_%M") time_stamp = time.strftime("%Y_%m_%d_%H_%M")
log_file = "ck_tuning_{}.log".format(time_stamp) log_file = "ck_tuning_{}.log".format(time_stamp)
json_file = "ck_tuning_{}.json".format(time_stamp) json_file = "ck_tuning_{}.json".format(time_stamp)
for model in models: for model in models:
for batch in batch_sizes: for batch in batch_sizes:
out = subprocess.run('MIGRAPHX_LOG_CK_GEMM=1 ../build/bin/driver run {} -g --fill1 input_ids --input-dim @input_ids {} {} | grep \'ck_gemm.*: \[{{\' | sort -u >> {}'.format(model, batch, seq_len, log_file), out = subprocess.run(
'MIGRAPHX_LOG_CK_GEMM=1 ../build/bin/driver run {} -g --fill1 input_ids --input-dim @input_ids {} {} | grep \'ck_gemm.*: \[{{\' | sort -u >> {}'
.format(model, batch, seq_len, log_file),
capture_output=True, capture_output=True,
check=True, check=True,
shell=True) shell=True)
...@@ -38,7 +44,9 @@ def tune_models(models, batch_sizes, seq_len, n): ...@@ -38,7 +44,9 @@ def tune_models(models, batch_sizes, seq_len, n):
tc.tune(log_file, n, json_file) tc.tune(log_file, n, json_file)
print("\nTuning results have been saved to:\n{}\n".format(json_file)) print("\nTuning results have been saved to:\n{}\n".format(json_file))
def run(args):
    """Kick off tuning for the parsed command-line arguments.

    Thin adapter that unpacks the argparse namespace and forwards the
    individual fields to tune_models.
    """
    models = args.models
    batches = args.batch_sizes
    seq_len = args.sequence_length
    tune_models(models, batches, seq_len, args.n)
run(parse_args()) run(parse_args())
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment