Commit f5409f95 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

additional refinement of input and output names mapping

parent 34fcdc47
...@@ -131,11 +131,21 @@ def model_output_names(model_file_name): ...@@ -131,11 +131,21 @@ def model_output_names(model_file_name):
def get_input_shapes(sample_case, param_names): def get_input_shapes(sample_case, param_names):
param_shape_map = {} param_shape_map = {}
name_array = []
shape_array = []
for i in range(len(param_names)): for i in range(len(param_names)):
file_name = sample_case + '/input_' + str(i) + '.pb' file_name = sample_case + '/input_' + str(i) + '.pb'
name, data = read_pb_file(file_name) name, data = read_pb_file(file_name)
param_shape_map[name] = list(data.shape) shape_array.append(data.shape)
print("{}: {}".format(name, data.shape)) if name:
name_array.append(name)
if len(name_array) < len(shape_array):
param_shape_map = {}
for i in range(len(param_names)):
param_shape_map[param_names[i]] = shape_array[i]
return param_shape_map
for name in param_names: for name in param_names:
if not name in param_shape_map.keys(): if not name in param_shape_map.keys():
...@@ -218,6 +228,8 @@ def main(): ...@@ -218,6 +228,8 @@ def main():
cases = get_test_cases(test_loc) cases = get_test_cases(test_loc)
sample_case = test_loc + '/' + cases[0] sample_case = test_loc + '/' + cases[0]
param_shapes = get_input_shapes(sample_case, param_names) param_shapes = get_input_shapes(sample_case, param_names)
for name, dims in param_shapes.items():
print("Input: {}, shape: {}".format(name, dims))
# read and compile model # read and compile model
model = migraphx.parse_onnx(model_path_name, map_input_dims=param_shapes) model = migraphx.parse_onnx(model_path_name, map_input_dims=param_shapes)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment