Commit 544811c3 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

refine test_runner to match inputs and outputs according to their names

parent eb6abd27
import os
import os, sys
import numpy as np
import argparse
import onnx
......@@ -54,31 +54,72 @@ def read_pb_file(filename):
tensor.ParseFromString(data_str)
np_array = numpy_helper.to_array(tensor)
return np_array
return tensor.name, np_array
def wrapup_inputs(io_folder, parameter_names):
index = 0
def wrapup_inputs(io_folder, param_names):
param_map = {}
for param_name in parameter_names:
file_name = io_folder + '/input_' + str(index) + '.pb'
data = read_pb_file(file_name)
param_map[param_name] = data
index = index + 1
for i in range(len(param_names)):
file_name = io_folder + '/input_' + str(i) + '.pb'
name, data = read_pb_file(file_name)
param_map[name] = data
for name in param_names:
if not name in param_map.keys():
print("Input {} does not exist!".format(name))
sys.exit()
return param_map
def read_outputs(io_folder, out_num):
outputs = []
outputs = {}
for i in range(out_num):
file_name = io_folder + '/output_' + str(i) + '.pb'
data = read_pb_file(file_name)
outputs.append(data)
name, data = read_pb_file(file_name)
outputs[name] = data
return outputs
def model_parameter_names(model_file_name):
with open(model_file_name, 'rb') as pfile:
data_str = pfile.read()
model_proto = onnx.ModelProto()
model_proto.ParseFromString(data_str)
init_names = set([(i.name) for i in model_proto.graph.initializer])
param_names = []
for input in model_proto.graph.input:
if input.name not in init_names:
param_names.append(input.name)
return param_names
def model_output_names(model_file_name):
with open(model_file_name, 'rb') as pfile:
data_str = pfile.read()
model_proto = onnx.ModelProto()
model_proto.ParseFromString(data_str)
output_names = [out.name for out in model_proto.graph.output]
return output_names
def get_input_shapes(sample_case, param_names):
param_shape_map = {}
for i in range(len(param_names)):
file_name = sample_case + '/input_' + str(i) + '.pb'
name, data = read_pb_file(file_name)
param_shape_map[name] = list(data.shape)
print("{}: {}".format(name, data.shape))
for name in param_names:
if not name in param_shape_map.keys():
print("Input {} does not exist!".format(name))
sys.exit()
return param_shape_map
def run_one_case(model, param_map):
# convert np array to model argument
pp = {}
......@@ -106,8 +147,8 @@ def check_correctness(gold_outputs, outputs, rtol=1e-3, atol=1e-3):
out_num = len(gold_outputs)
ret = True
for i in range(out_num):
print("Expected value: \n{}".format(gold_outputs[i]))
print("Actual value: \n{}".format(outputs[i]))
# print("Expected value: \n{}".format(gold_outputs[i]))
# print("Actual value: \n{}".format(outputs[i]))
if not np.allclose(gold_outputs[i], outputs[i], rtol, atol):
print("Output {} is incorrect ...".format(i))
print("Expected value: \n{}".format(gold_outputs[i]))
......@@ -142,21 +183,33 @@ def main():
# get model full path
model_name = get_model_name(test_loc)
model_path_name = test_loc + '/' + model_name
# get param names
param_names = model_parameter_names(model_path_name)
print("param_name = {}".format(param_names))
# get output names
output_names = model_output_names(model_path_name)
# get test cases
cases = get_test_cases(test_loc)
sample_case = test_loc + '/' + cases[0]
param_shapes = get_input_shapes(sample_case, param_names)
# read and compile model
model = migraphx.parse_onnx(model_path_name)
param_names = model.get_parameter_names()
model = migraphx.parse_onnx(model_path_name, map_input_dims=param_shapes)
# param_names = model.get_parameter_names()
output_shapes = model.get_output_shapes()
model.compile(migraphx.get_target(target))
# get test cases
cases = get_test_cases(test_loc)
case_num = len(cases)
correct_num = 0
for case_name in cases:
io_folder = test_loc + '/' + case_name
input_data = wrapup_inputs(io_folder, param_names)
gold_output_data = read_outputs(io_folder, len(output_shapes))
gold_outputs = read_outputs(io_folder, len(output_shapes))
# if input shape is different from model shape, reload and recompile
# model
......@@ -168,6 +221,9 @@ def main():
# run the model and return outputs
output_data = run_one_case(model, input_data)
gold_output_data = []
for i in range(len((output_data))):
gold_output_data.append(gold_outputs[output_names[i]])
# check output correctness
ret = check_correctness(gold_output_data, output_data)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment