Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
b56b11f1
Commit
b56b11f1
authored
Nov 17, 2022
by
Alan Turner
Browse files
Formatting
parent
4da1d448
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
29 additions
and
19 deletions
+29
-19
tools/tune_ck.py
tools/tune_ck.py
+2
-0
tools/tune_models.py
tools/tune_models.py
+27
-19
No files found.
tools/tune_ck.py
View file @
b56b11f1
...
...
@@ -134,9 +134,11 @@ def run(args):
tuned
=
benchmark_log
(
args
.
log
,
args
.
n
)
json
.
dump
(
tuned
,
open
(
args
.
out
,
'w+'
))
def
tune
(
log
,
n
,
out
):
tuned
=
benchmark_log
(
log
,
n
)
json
.
dump
(
tuned
,
open
(
out
,
'w+'
))
if
__name__
==
'__main__'
:
run
(
parse_args
())
tools/tune_models.py
View file @
b56b11f1
import
os
,
json
,
subprocess
,
tempfile
,
sys
,
argparse
,
contextlib
,
time
import
tune_ck
as
tc
def
parse_args
():
parser
=
argparse
.
ArgumentParser
(
description
=
"Tune CK GEMMs for one or more ONNX models"
)
parser
=
argparse
.
ArgumentParser
(
description
=
"Tune CK GEMMs for one or more ONNX models"
)
parser
.
add_argument
(
'--models'
,
'-m'
,
nargs
=
'+'
,
help
=
'ONNX models to be tuned'
,
'-m'
,
nargs
=
'+'
,
help
=
'ONNX models to be tuned'
,
required
=
True
)
parser
.
add_argument
(
'--batch_sizes'
,
'-b'
,
nargs
=
'+'
,
help
=
'Batch sizes to tune'
,
'-b'
,
nargs
=
'+'
,
help
=
'Batch sizes to tune'
,
required
=
True
)
parser
.
add_argument
(
'--sequence_length'
,
'-s'
,
type
=
int
,
default
=
384
,
parser
.
add_argument
(
'--sequence_length'
,
'-s'
,
type
=
int
,
default
=
384
,
help
=
'Sequence length for transformer models'
)
parser
.
add_argument
(
'-n'
,
type
=
int
,
default
=
16
,
parser
.
add_argument
(
'-n'
,
type
=
int
,
default
=
16
,
help
=
'Number of instances to tune'
)
args
=
parser
.
parse_args
()
return
args
def
tune_models
(
models
,
batch_sizes
,
seq_len
,
n
):
time_stamp
=
time
.
strftime
(
"%Y_%m_%d_%H_%M"
)
log_file
=
"ck_tuning_{}.log"
.
format
(
time_stamp
)
json_file
=
"ck_tuning_{}.json"
.
format
(
time_stamp
)
for
model
in
models
:
for
batch
in
batch_sizes
:
out
=
subprocess
.
run
(
'MIGRAPHX_LOG_CK_GEMM=1 ../build/bin/driver run {} -g --fill1 input_ids --input-dim @input_ids {} {} | grep
\'
ck_gemm.*: \[{{
\'
| sort -u >> {}'
.
format
(
model
,
batch
,
seq_len
,
log_file
),
capture_output
=
True
,
check
=
True
,
shell
=
True
)
out
=
subprocess
.
run
(
'MIGRAPHX_LOG_CK_GEMM=1 ../build/bin/driver run {} -g --fill1 input_ids --input-dim @input_ids {} {} | grep
\'
ck_gemm.*: \[{{
\'
| sort -u >> {}'
.
format
(
model
,
batch
,
seq_len
,
log_file
),
capture_output
=
True
,
check
=
True
,
shell
=
True
)
tc
.
tune
(
log_file
,
n
,
json_file
)
print
(
"
\n
Tuning results have been saved to:
\n
{}
\n
"
.
format
(
json_file
))
def
run
(
args
):
tune_models
(
args
.
models
,
args
.
batch_sizes
,
args
.
sequence_length
,
args
.
n
)
run
(
parse_args
())
\ No newline at end of file
run
(
parse_args
())
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment