Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
34fcdc47
"vscode:/vscode.git/clone" did not exist on "cef0553d59713a5e5842b9a1d79334d4ffc066b9"
Commit
34fcdc47
authored
Nov 04, 2021
by
Shucai Xiao
Browse files
refine the input and output data file processing
parent
690dd868
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
33 additions
and
11 deletions
+33
-11
tools/test_runner.py
tools/test_runner.py
+33
-11
No files found.
tools/test_runner.py
View file @
34fcdc47
...
@@ -59,10 +59,22 @@ def read_pb_file(filename):
...
@@ -59,10 +59,22 @@ def read_pb_file(filename):
def
wrapup_inputs
(
io_folder
,
param_names
):
def
wrapup_inputs
(
io_folder
,
param_names
):
param_map
=
{}
param_map
=
{}
data_array
=
[]
name_array
=
[]
for
i
in
range
(
len
(
param_names
)):
for
i
in
range
(
len
(
param_names
)):
file_name
=
io_folder
+
'/input_'
+
str
(
i
)
+
'.pb'
file_name
=
io_folder
+
'/input_'
+
str
(
i
)
+
'.pb'
name
,
data
=
read_pb_file
(
file_name
)
name
,
data
=
read_pb_file
(
file_name
)
param_map
[
name
]
=
data
param_map
[
name
]
=
data
data_array
.
append
(
data
)
if
name
:
name_array
.
append
(
name
)
if
len
(
name_array
)
<
len
(
data_array
):
param_map
=
{}
for
i
in
range
(
len
(
param_names
)):
param_map
[
param_names
[
i
]]
=
data_array
[
i
]
return
param_map
for
name
in
param_names
:
for
name
in
param_names
:
if
not
name
in
param_map
.
keys
():
if
not
name
in
param_map
.
keys
():
...
@@ -72,12 +84,23 @@ def wrapup_inputs(io_folder, param_names):
...
@@ -72,12 +84,23 @@ def wrapup_inputs(io_folder, param_names):
return
param_map
return
param_map
def
read_outputs
(
io_folder
,
out_num
):
def
read_outputs
(
io_folder
,
out_names
):
outputs
=
{}
outputs
=
[]
for
i
in
range
(
out_num
):
data_array
=
[]
name_array
=
[]
for
i
in
range
(
len
(
out_names
)):
file_name
=
io_folder
+
'/output_'
+
str
(
i
)
+
'.pb'
file_name
=
io_folder
+
'/output_'
+
str
(
i
)
+
'.pb'
name
,
data
=
read_pb_file
(
file_name
)
name
,
data
=
read_pb_file
(
file_name
)
outputs
[
name
]
=
data
data_array
.
append
(
data
)
if
name
:
name_array
.
append
(
name
)
if
len
(
name_array
)
<
len
(
data_array
):
return
data_array
for
name
in
out_names
:
index
=
name_array
.
index
(
name
)
outputs
.
append
(
data_array
[
index
])
return
outputs
return
outputs
...
@@ -126,7 +149,7 @@ def run_one_case(model, param_map):
...
@@ -126,7 +149,7 @@ def run_one_case(model, param_map):
# convert np array to model argument
# convert np array to model argument
pp
=
{}
pp
=
{}
for
key
,
val
in
param_map
.
items
():
for
key
,
val
in
param_map
.
items
():
print
(
"input = {}"
.
format
(
val
))
#
print("input
: {}
= {}".format(
key,
val))
pp
[
key
]
=
migraphx
.
argument
(
val
)
pp
[
key
]
=
migraphx
.
argument
(
val
)
# run the model
# run the model
...
@@ -198,7 +221,6 @@ def main():
...
@@ -198,7 +221,6 @@ def main():
# read and compile model
# read and compile model
model
=
migraphx
.
parse_onnx
(
model_path_name
,
map_input_dims
=
param_shapes
)
model
=
migraphx
.
parse_onnx
(
model_path_name
,
map_input_dims
=
param_shapes
)
# param_names = model.get_parameter_names()
output_shapes
=
model
.
get_output_shapes
()
output_shapes
=
model
.
get_output_shapes
()
model
.
compile
(
migraphx
.
get_target
(
target
))
model
.
compile
(
migraphx
.
get_target
(
target
))
...
@@ -209,7 +231,7 @@ def main():
...
@@ -209,7 +231,7 @@ def main():
for
case_name
in
cases
:
for
case_name
in
cases
:
io_folder
=
test_loc
+
'/'
+
case_name
io_folder
=
test_loc
+
'/'
+
case_name
input_data
=
wrapup_inputs
(
io_folder
,
param_names
)
input_data
=
wrapup_inputs
(
io_folder
,
param_names
)
gold_outputs
=
read_outputs
(
io_folder
,
len
(
output_
shap
es
)
)
gold_outputs
=
read_outputs
(
io_folder
,
output_
nam
es
)
# if input shape is different from model shape, reload and recompile
# if input shape is different from model shape, reload and recompile
# model
# model
...
@@ -221,12 +243,12 @@ def main():
...
@@ -221,12 +243,12 @@ def main():
# run the model and return outputs
# run the model and return outputs
output_data
=
run_one_case
(
model
,
input_data
)
output_data
=
run_one_case
(
model
,
input_data
)
gold_output_data
=
[]
#
gold_output_data = []
for
i
in
range
(
len
((
output_data
))):
#
for i in range(len((output_data))):
gold_output_data
.
append
(
gold_outputs
[
output_names
[
i
]])
#
gold_output_data.append(gold_outputs[output_names[i]])
# check output correctness
# check output correctness
ret
=
check_correctness
(
gold_output
_data
,
output_data
)
ret
=
check_correctness
(
gold_output
s
,
output_data
)
if
ret
:
if
ret
:
correct_num
+=
1
correct_num
+=
1
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment