Commit 544811c3 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

refine test_runner to match inputs and outputs according to their names

parent eb6abd27
import os import os, sys
import numpy as np import numpy as np
import argparse import argparse
import onnx import onnx
...@@ -54,31 +54,72 @@ def read_pb_file(filename): ...@@ -54,31 +54,72 @@ def read_pb_file(filename):
tensor.ParseFromString(data_str) tensor.ParseFromString(data_str)
np_array = numpy_helper.to_array(tensor) np_array = numpy_helper.to_array(tensor)
return np_array return tensor.name, np_array
def wrapup_inputs(io_folder, parameter_names): def wrapup_inputs(io_folder, param_names):
index = 0
param_map = {} param_map = {}
for param_name in parameter_names: for i in range(len(param_names)):
file_name = io_folder + '/input_' + str(index) + '.pb' file_name = io_folder + '/input_' + str(i) + '.pb'
data = read_pb_file(file_name) name, data = read_pb_file(file_name)
param_map[param_name] = data param_map[name] = data
index = index + 1
for name in param_names:
if not name in param_map.keys():
print("Input {} does not exist!".format(name))
sys.exit()
return param_map return param_map
def read_outputs(io_folder, out_num): def read_outputs(io_folder, out_num):
outputs = [] outputs = {}
for i in range(out_num): for i in range(out_num):
file_name = io_folder + '/output_' + str(i) + '.pb' file_name = io_folder + '/output_' + str(i) + '.pb'
data = read_pb_file(file_name) name, data = read_pb_file(file_name)
outputs.append(data) outputs[name] = data
return outputs return outputs
def model_parameter_names(model_file_name):
with open(model_file_name, 'rb') as pfile:
data_str = pfile.read()
model_proto = onnx.ModelProto()
model_proto.ParseFromString(data_str)
init_names = set([(i.name) for i in model_proto.graph.initializer])
param_names = []
for input in model_proto.graph.input:
if input.name not in init_names:
param_names.append(input.name)
return param_names
def model_output_names(model_file_name):
with open(model_file_name, 'rb') as pfile:
data_str = pfile.read()
model_proto = onnx.ModelProto()
model_proto.ParseFromString(data_str)
output_names = [out.name for out in model_proto.graph.output]
return output_names
def get_input_shapes(sample_case, param_names):
param_shape_map = {}
for i in range(len(param_names)):
file_name = sample_case + '/input_' + str(i) + '.pb'
name, data = read_pb_file(file_name)
param_shape_map[name] = list(data.shape)
print("{}: {}".format(name, data.shape))
for name in param_names:
if not name in param_shape_map.keys():
print("Input {} does not exist!".format(name))
sys.exit()
return param_shape_map
def run_one_case(model, param_map): def run_one_case(model, param_map):
# convert np array to model argument # convert np array to model argument
pp = {} pp = {}
...@@ -106,8 +147,8 @@ def check_correctness(gold_outputs, outputs, rtol=1e-3, atol=1e-3): ...@@ -106,8 +147,8 @@ def check_correctness(gold_outputs, outputs, rtol=1e-3, atol=1e-3):
out_num = len(gold_outputs) out_num = len(gold_outputs)
ret = True ret = True
for i in range(out_num): for i in range(out_num):
print("Expected value: \n{}".format(gold_outputs[i])) # print("Expected value: \n{}".format(gold_outputs[i]))
print("Actual value: \n{}".format(outputs[i])) # print("Actual value: \n{}".format(outputs[i]))
if not np.allclose(gold_outputs[i], outputs[i], rtol, atol): if not np.allclose(gold_outputs[i], outputs[i], rtol, atol):
print("Output {} is incorrect ...".format(i)) print("Output {} is incorrect ...".format(i))
print("Expected value: \n{}".format(gold_outputs[i])) print("Expected value: \n{}".format(gold_outputs[i]))
...@@ -142,21 +183,33 @@ def main(): ...@@ -142,21 +183,33 @@ def main():
# get model full path # get model full path
model_name = get_model_name(test_loc) model_name = get_model_name(test_loc)
model_path_name = test_loc + '/' + model_name model_path_name = test_loc + '/' + model_name
# get param names
param_names = model_parameter_names(model_path_name)
print("param_name = {}".format(param_names))
# get output names
output_names = model_output_names(model_path_name)
# get test cases
cases = get_test_cases(test_loc)
sample_case = test_loc + '/' + cases[0]
param_shapes = get_input_shapes(sample_case, param_names)
# read and compile model # read and compile model
model = migraphx.parse_onnx(model_path_name) model = migraphx.parse_onnx(model_path_name, map_input_dims=param_shapes)
param_names = model.get_parameter_names() # param_names = model.get_parameter_names()
output_shapes = model.get_output_shapes() output_shapes = model.get_output_shapes()
model.compile(migraphx.get_target(target)) model.compile(migraphx.get_target(target))
# get test cases # get test cases
cases = get_test_cases(test_loc)
case_num = len(cases) case_num = len(cases)
correct_num = 0 correct_num = 0
for case_name in cases: for case_name in cases:
io_folder = test_loc + '/' + case_name io_folder = test_loc + '/' + case_name
input_data = wrapup_inputs(io_folder, param_names) input_data = wrapup_inputs(io_folder, param_names)
gold_output_data = read_outputs(io_folder, len(output_shapes)) gold_outputs = read_outputs(io_folder, len(output_shapes))
# if input shape is different from model shape, reload and recompile # if input shape is different from model shape, reload and recompile
# model # model
...@@ -168,6 +221,9 @@ def main(): ...@@ -168,6 +221,9 @@ def main():
# run the model and return outputs # run the model and return outputs
output_data = run_one_case(model, input_data) output_data = run_one_case(model, input_data)
gold_output_data = []
for i in range(len((output_data))):
gold_output_data.append(gold_outputs[output_names[i]])
# check output correctness # check output correctness
ret = check_correctness(gold_output_data, output_data) ret = check_correctness(gold_output_data, output_data)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment