Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
34fcdc47
Commit
34fcdc47
authored
Nov 04, 2021
by
Shucai Xiao
Browse files
refine the input and output data file processing
parent
690dd868
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
33 additions
and
11 deletions
+33
-11
tools/test_runner.py
tools/test_runner.py
+33
-11
No files found.
tools/test_runner.py
View file @
34fcdc47
...
...
@@ -59,10 +59,22 @@ def read_pb_file(filename):
def
wrapup_inputs
(
io_folder
,
param_names
):
param_map
=
{}
data_array
=
[]
name_array
=
[]
for
i
in
range
(
len
(
param_names
)):
file_name
=
io_folder
+
'/input_'
+
str
(
i
)
+
'.pb'
name
,
data
=
read_pb_file
(
file_name
)
param_map
[
name
]
=
data
data_array
.
append
(
data
)
if
name
:
name_array
.
append
(
name
)
if
len
(
name_array
)
<
len
(
data_array
):
param_map
=
{}
for
i
in
range
(
len
(
param_names
)):
param_map
[
param_names
[
i
]]
=
data_array
[
i
]
return
param_map
for
name
in
param_names
:
if
not
name
in
param_map
.
keys
():
...
...
@@ -72,12 +84,23 @@ def wrapup_inputs(io_folder, param_names):
return
param_map
def
read_outputs
(
io_folder
,
out_num
):
outputs
=
{}
for
i
in
range
(
out_num
):
def
read_outputs
(
io_folder
,
out_names
):
outputs
=
[]
data_array
=
[]
name_array
=
[]
for
i
in
range
(
len
(
out_names
)):
file_name
=
io_folder
+
'/output_'
+
str
(
i
)
+
'.pb'
name
,
data
=
read_pb_file
(
file_name
)
outputs
[
name
]
=
data
data_array
.
append
(
data
)
if
name
:
name_array
.
append
(
name
)
if
len
(
name_array
)
<
len
(
data_array
):
return
data_array
for
name
in
out_names
:
index
=
name_array
.
index
(
name
)
outputs
.
append
(
data_array
[
index
])
return
outputs
...
...
@@ -126,7 +149,7 @@ def run_one_case(model, param_map):
# convert np array to model argument
pp
=
{}
for
key
,
val
in
param_map
.
items
():
print
(
"input = {}"
.
format
(
val
))
#
print("input
: {}
= {}".format(
key,
val))
pp
[
key
]
=
migraphx
.
argument
(
val
)
# run the model
...
...
@@ -198,7 +221,6 @@ def main():
# read and compile model
model
=
migraphx
.
parse_onnx
(
model_path_name
,
map_input_dims
=
param_shapes
)
# param_names = model.get_parameter_names()
output_shapes
=
model
.
get_output_shapes
()
model
.
compile
(
migraphx
.
get_target
(
target
))
...
...
@@ -209,7 +231,7 @@ def main():
for
case_name
in
cases
:
io_folder
=
test_loc
+
'/'
+
case_name
input_data
=
wrapup_inputs
(
io_folder
,
param_names
)
gold_outputs
=
read_outputs
(
io_folder
,
len
(
output_
shap
es
)
)
gold_outputs
=
read_outputs
(
io_folder
,
output_
nam
es
)
# if input shape is different from model shape, reload and recompile
# model
...
...
@@ -221,12 +243,12 @@ def main():
# run the model and return outputs
output_data
=
run_one_case
(
model
,
input_data
)
gold_output_data
=
[]
for
i
in
range
(
len
((
output_data
))):
gold_output_data
.
append
(
gold_outputs
[
output_names
[
i
]])
#
gold_output_data = []
#
for i in range(len((output_data))):
#
gold_output_data.append(gold_outputs[output_names[i]])
# check output correctness
ret
=
check_correctness
(
gold_output
_data
,
output_data
)
ret
=
check_correctness
(
gold_output
s
,
output_data
)
if
ret
:
correct_num
+=
1
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment