Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
4da1d448
Commit
4da1d448
authored
Nov 17, 2022
by
Alan Turner
Browse files
Add tune_models.py
parent
1cce5d43
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
52 additions
and
18 deletions
+52
-18
tools/tune_ck.py
tools/tune_ck.py
+8
-4
tools/tune_ck.sh
tools/tune_ck.sh
+0
-14
tools/tune_models.py
tools/tune_models.py
+44
-0
No files found.
tools/tune_ck.py
View file @
4da1d448
...
...
@@ -22,9 +22,9 @@ def pretty_print(obj):
def
run_driver
(
b
):
print
(
b
)
outfile
=
open
(
"temp2.json"
,
"w"
)
json
.
dump
(
b
,
outfile
)
outfile
.
close
()
#
outfile = open("temp2.json", "w")
#
json.dump(b, outfile)
#
outfile.close()
with
tmp_file
(
lambda
tf
:
json
.
dump
(
b
,
tf
))
as
tf
:
cp
=
subprocess
.
run
(
'./bin/gpu-driver {}'
.
format
(
tf
),
capture_output
=
True
,
...
...
@@ -134,5 +134,9 @@ def run(args):
tuned
=
benchmark_log
(
args
.
log
,
args
.
n
)
json
.
dump
(
tuned
,
open
(
args
.
out
,
'w+'
))
def
tune
(
log
,
n
,
out
):
tuned
=
benchmark_log
(
log
,
n
)
json
.
dump
(
tuned
,
open
(
out
,
'w+'
))
run
(
parse_args
())
if
__name__
==
'__main__'
:
run
(
parse_args
())
tools/tune_ck.sh
deleted
100755 → 0
View file @
1cce5d43
#!/bin/bash
MODEL
=
$1
LOG
=
"ck_bbc.log"
TUNING_DB
=
"ck_bbc.json"
rm
$LOG
touch
$LOG
for
N
in
1 16 32 64
do
MIGRAPHX_LOG_CK_GEMM
=
1 ./bin/driver run
$MODEL
-g
--fill1
input_ids
--input-dim
@input_ids
$N
384 |
grep
'ck_gemm.*: \[{'
|
sort
-u
>>
$LOG
done
python3 ../tools/tune_ck.py
-n
16
-l
$LOG
-o
$TUNING_DB
export
MIGRAPHX_CK_TUNING
=
$TUNING_DB
\ No newline at end of file
tools/tune_models.py
0 → 100644
View file @
4da1d448
import
os
,
json
,
subprocess
,
tempfile
,
sys
,
argparse
,
contextlib
,
time
import
tune_ck
as
tc
def
parse_args
():
parser
=
argparse
.
ArgumentParser
(
description
=
"Tune CK GEMMs for one or more ONNX models"
)
parser
.
add_argument
(
'--models'
,
'-m'
,
nargs
=
'+'
,
help
=
'ONNX models to be tuned'
,
required
=
True
)
parser
.
add_argument
(
'--batch_sizes'
,
'-b'
,
nargs
=
'+'
,
help
=
'Batch sizes to tune'
,
required
=
True
)
parser
.
add_argument
(
'--sequence_length'
,
'-s'
,
type
=
int
,
default
=
384
,
help
=
'Sequence length for transformer models'
)
parser
.
add_argument
(
'-n'
,
type
=
int
,
default
=
16
,
help
=
'Number of instances to tune'
)
args
=
parser
.
parse_args
()
return
args
def
tune_models
(
models
,
batch_sizes
,
seq_len
,
n
):
time_stamp
=
time
.
strftime
(
"%Y_%m_%d_%H_%M"
)
log_file
=
"ck_tuning_{}.log"
.
format
(
time_stamp
)
json_file
=
"ck_tuning_{}.json"
.
format
(
time_stamp
)
for
model
in
models
:
for
batch
in
batch_sizes
:
out
=
subprocess
.
run
(
'MIGRAPHX_LOG_CK_GEMM=1 ../build/bin/driver run {} -g --fill1 input_ids --input-dim @input_ids {} {} | grep
\'
ck_gemm.*: \[{{
\'
| sort -u >> {}'
.
format
(
model
,
batch
,
seq_len
,
log_file
),
capture_output
=
True
,
check
=
True
,
shell
=
True
)
tc
.
tune
(
log_file
,
n
,
json_file
)
print
(
"
\n
Tuning results have been saved to:
\n
{}
\n
"
.
format
(
json_file
))
def
run
(
args
):
tune_models
(
args
.
models
,
args
.
batch_sizes
,
args
.
sequence_length
,
args
.
n
)
run
(
parse_args
())
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment