Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
f5409f95
"git@developer.sourcefind.cn:ox696c/ktransformers.git" did not exist on "2753a4a65488f7e26556dd1d4d2923121a956fe6"
Commit
f5409f95
authored
Nov 04, 2021
by
Shucai Xiao
Browse files
additional refinement of input and output names mapping
parent
34fcdc47
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
14 additions
and
2 deletions
+14
-2
tools/test_runner.py
tools/test_runner.py
+14
-2
No files found.
tools/test_runner.py
View file @
f5409f95
...
@@ -131,11 +131,21 @@ def model_output_names(model_file_name):
...
@@ -131,11 +131,21 @@ def model_output_names(model_file_name):
def
get_input_shapes
(
sample_case
,
param_names
):
def
get_input_shapes
(
sample_case
,
param_names
):
param_shape_map
=
{}
param_shape_map
=
{}
name_array
=
[]
shape_array
=
[]
for
i
in
range
(
len
(
param_names
)):
for
i
in
range
(
len
(
param_names
)):
file_name
=
sample_case
+
'/input_'
+
str
(
i
)
+
'.pb'
file_name
=
sample_case
+
'/input_'
+
str
(
i
)
+
'.pb'
name
,
data
=
read_pb_file
(
file_name
)
name
,
data
=
read_pb_file
(
file_name
)
param_shape_map
[
name
]
=
list
(
data
.
shape
)
shape_array
.
append
(
data
.
shape
)
print
(
"{}: {}"
.
format
(
name
,
data
.
shape
))
if
name
:
name_array
.
append
(
name
)
if
len
(
name_array
)
<
len
(
shape_array
):
param_shape_map
=
{}
for
i
in
range
(
len
(
param_names
)):
param_shape_map
[
param_names
[
i
]]
=
shape_array
[
i
]
return
param_shape_map
for
name
in
param_names
:
for
name
in
param_names
:
if
not
name
in
param_shape_map
.
keys
():
if
not
name
in
param_shape_map
.
keys
():
...
@@ -218,6 +228,8 @@ def main():
...
@@ -218,6 +228,8 @@ def main():
cases
=
get_test_cases
(
test_loc
)
cases
=
get_test_cases
(
test_loc
)
sample_case
=
test_loc
+
'/'
+
cases
[
0
]
sample_case
=
test_loc
+
'/'
+
cases
[
0
]
param_shapes
=
get_input_shapes
(
sample_case
,
param_names
)
param_shapes
=
get_input_shapes
(
sample_case
,
param_names
)
for
name
,
dims
in
param_shapes
.
items
():
print
(
"Input: {}, shape: {}"
.
format
(
name
,
dims
))
# read and compile model
# read and compile model
model
=
migraphx
.
parse_onnx
(
model_path_name
,
map_input_dims
=
param_shapes
)
model
=
migraphx
.
parse_onnx
(
model_path_name
,
map_input_dims
=
param_shapes
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment