"ts/nni_manager/git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "51c6afde97a8a29f2c4a15567daf789c66376fbf"
Commit 4da1d448 authored by Alan Turner's avatar Alan Turner
Browse files

Add tune_models.py

parent 1cce5d43
...@@ -22,9 +22,9 @@ def pretty_print(obj): ...@@ -22,9 +22,9 @@ def pretty_print(obj):
def run_driver(b): def run_driver(b):
print(b) print(b)
outfile = open("temp2.json", "w") #outfile = open("temp2.json", "w")
json.dump(b, outfile) #json.dump(b, outfile)
outfile.close() #outfile.close()
with tmp_file(lambda tf: json.dump(b, tf)) as tf: with tmp_file(lambda tf: json.dump(b, tf)) as tf:
cp = subprocess.run('./bin/gpu-driver {}'.format(tf), cp = subprocess.run('./bin/gpu-driver {}'.format(tf),
capture_output=True, capture_output=True,
...@@ -134,5 +134,9 @@ def run(args): ...@@ -134,5 +134,9 @@ def run(args):
tuned = benchmark_log(args.log, args.n) tuned = benchmark_log(args.log, args.n)
json.dump(tuned, open(args.out, 'w+')) json.dump(tuned, open(args.out, 'w+'))
def tune(log, n, out):
tuned = benchmark_log(log, n)
json.dump(tuned, open(out, 'w+'))
run(parse_args()) if __name__ == '__main__':
run(parse_args())
#!/bin/bash
MODEL=$1
LOG="ck_bbc.log"
TUNING_DB="ck_bbc.json"
rm $LOG
touch $LOG
for N in 1 16 32 64
do
MIGRAPHX_LOG_CK_GEMM=1 ./bin/driver run $MODEL -g --fill1 input_ids --input-dim @input_ids $N 384 | grep 'ck_gemm.*: \[{' | sort -u >> $LOG
done
python3 ../tools/tune_ck.py -n 16 -l $LOG -o $TUNING_DB
export MIGRAPHX_CK_TUNING=$TUNING_DB
\ No newline at end of file
import os, json, subprocess, tempfile, sys, argparse, contextlib, time
import tune_ck as tc
def parse_args():
parser = argparse.ArgumentParser(description="Tune CK GEMMs for one or more ONNX models")
parser.add_argument('--models',
'-m',
nargs='+',
help='ONNX models to be tuned',
required=True)
parser.add_argument('--batch_sizes',
'-b',
nargs='+',
help='Batch sizes to tune',
required=True)
parser.add_argument('--sequence_length',
'-s', type=int,
default=384,
help='Sequence length for transformer models')
parser.add_argument('-n',
type=int,
default=16,
help='Number of instances to tune')
args = parser.parse_args()
return args
def tune_models(models, batch_sizes, seq_len, n):
time_stamp = time.strftime("%Y_%m_%d_%H_%M")
log_file = "ck_tuning_{}.log".format(time_stamp)
json_file = "ck_tuning_{}.json".format(time_stamp)
for model in models:
for batch in batch_sizes:
out = subprocess.run('MIGRAPHX_LOG_CK_GEMM=1 ../build/bin/driver run {} -g --fill1 input_ids --input-dim @input_ids {} {} | grep \'ck_gemm.*: \[{{\' | sort -u >> {}'.format(model, batch, seq_len, log_file),
capture_output=True,
check=True,
shell=True)
tc.tune(log_file, n, json_file)
print("\nTuning results have been saved to:\n{}\n".format(json_file))
def run(args):
tune_models(args.models, args.batch_sizes, args.sequence_length, args.n)
run(parse_args())
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment