Commit 0b851159 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

fix review comments

parent d3267bb3
...@@ -111,10 +111,7 @@ def model_parameter_names(model_file_name): ...@@ -111,10 +111,7 @@ def model_parameter_names(model_file_name):
model_proto = onnx.ModelProto() model_proto = onnx.ModelProto()
model_proto.ParseFromString(data_str) model_proto.ParseFromString(data_str)
init_names = set([(i.name) for i in model_proto.graph.initializer]) init_names = set([(i.name) for i in model_proto.graph.initializer])
param_names = [] param_names = [input.name for input in model_proto.graph.input if input.name not in init_names]
for input in model_proto.graph.input:
if input.name not in init_names:
param_names.append(input.name)
return param_names return param_names
...@@ -149,7 +146,7 @@ def get_input_shapes(sample_case, param_names): ...@@ -149,7 +146,7 @@ def get_input_shapes(sample_case, param_names):
return param_shape_map return param_shape_map
for name in param_names: for name in param_names:
if not name in param_shape_map.keys(): if not name in param_shape_map:
print("Input {} does not exist!".format(name)) print("Input {} does not exist!".format(name))
sys.exit() sys.exit()
...@@ -160,7 +157,6 @@ def run_one_case(model, param_map): ...@@ -160,7 +157,6 @@ def run_one_case(model, param_map):
# convert np array to model argument # convert np array to model argument
pp = {} pp = {}
for key, val in param_map.items(): for key, val in param_map.items():
#print("input: {} = {}".format(key, val))
pp[key] = migraphx.argument(val) pp[key] = migraphx.argument(val)
# run the model # run the model
...@@ -257,9 +253,6 @@ def main(): ...@@ -257,9 +253,6 @@ def main():
# run the model and return outputs # run the model and return outputs
output_data = run_one_case(model, input_data) output_data = run_one_case(model, input_data)
# gold_output_data = []
# for i in range(len((output_data))):
# gold_output_data.append(gold_outputs[output_names[i]])
# check output correctness # check output correctness
ret = check_correctness(gold_outputs, output_data) ret = check_correctness(gold_outputs, output_data)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment