Commit 34fcdc47 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

refine the input and output data file processing

parent 690dd868
......@@ -59,10 +59,22 @@ def read_pb_file(filename):
def wrapup_inputs(io_folder, param_names):
param_map = {}
data_array = []
name_array = []
for i in range(len(param_names)):
file_name = io_folder + '/input_' + str(i) + '.pb'
name, data = read_pb_file(file_name)
param_map[name] = data
data_array.append(data)
if name:
name_array.append(name)
if len(name_array) < len(data_array):
param_map = {}
for i in range(len(param_names)):
param_map[param_names[i]] = data_array[i]
return param_map
for name in param_names:
if not name in param_map.keys():
......@@ -72,12 +84,23 @@ def wrapup_inputs(io_folder, param_names):
return param_map
def read_outputs(io_folder, out_num):
outputs = {}
for i in range(out_num):
def read_outputs(io_folder, out_names):
outputs = []
data_array = []
name_array = []
for i in range(len(out_names)):
file_name = io_folder + '/output_' + str(i) + '.pb'
name, data = read_pb_file(file_name)
outputs[name] = data
data_array.append(data)
if name:
name_array.append(name)
if len(name_array) < len(data_array):
return data_array
for name in out_names:
index = name_array.index(name)
outputs.append(data_array[index])
return outputs
......@@ -126,7 +149,7 @@ def run_one_case(model, param_map):
# convert np array to model argument
pp = {}
for key, val in param_map.items():
print("input = {}".format(val))
#print("input: {} = {}".format(key, val))
pp[key] = migraphx.argument(val)
# run the model
......@@ -198,7 +221,6 @@ def main():
# read and compile model
model = migraphx.parse_onnx(model_path_name, map_input_dims=param_shapes)
# param_names = model.get_parameter_names()
output_shapes = model.get_output_shapes()
model.compile(migraphx.get_target(target))
......@@ -209,7 +231,7 @@ def main():
for case_name in cases:
io_folder = test_loc + '/' + case_name
input_data = wrapup_inputs(io_folder, param_names)
gold_outputs = read_outputs(io_folder, len(output_shapes))
gold_outputs = read_outputs(io_folder, output_names)
# if input shape is different from model shape, reload and recompile
# model
......@@ -221,12 +243,12 @@ def main():
# run the model and return outputs
output_data = run_one_case(model, input_data)
gold_output_data = []
for i in range(len((output_data))):
gold_output_data.append(gold_outputs[output_names[i]])
# gold_output_data = []
# for i in range(len((output_data))):
# gold_output_data.append(gold_outputs[output_names[i]])
# check output correctness
ret = check_correctness(gold_output_data, output_data)
ret = check_correctness(gold_outputs, output_data)
if ret:
correct_num += 1
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment