Commit 2acccfb0 authored by liucong's avatar liucong
Browse files

修改input/output代码

parent aa4c558e
......@@ -296,9 +296,20 @@ class det_rec_functions(object):
# 解析检测模型
detInput = {"x":[1,3,2496,2496]}
self.modelDet = migraphx.parse_onnx(self.det_file, map_input_dims=detInput)
self.inputName = self.modelDet.get_parameter_names()[0]
self.inputShape = self.modelDet.get_parameter_shapes()[self.inputName].lens()
# 获取模型输入/输出节点信息
print("det_inputs:")
inputs_det = self.modelDet.get_inputs()
for key,value in inputs_det.items():
print("{}:{}".format(key,value))
print("det_outputs:")
outputs_det = self.modelDet.get_outputs()
for key,value in outputs_det.items():
print("{}:{}".format(key,value))
self.inputName = self.modelDet.get_parameter_names()[0]
self.inputShape = inputs_det[self.inputName].lens()
print("DB inputName:{0} \nDB inputShape:{1}".format(self.inputName, self.inputShape))
# 模型编译
......@@ -308,9 +319,20 @@ class det_rec_functions(object):
# 解析识别模型
recInput = {"x":[1,3,48,320]}
self.modelRec = migraphx.parse_onnx(self.rec_file, map_input_dims=recInput)
self.inputName = self.modelRec.get_parameter_names()[0]
self.inputShape = self.modelRec.get_parameter_shapes()[self.inputName].lens()
# 获取模型输入/输出节点信息
print("rec_inputs:")
inputs_rec = self.modelRec.get_inputs()
for key,value in inputs_rec.items():
print("{}:{}".format(key,value))
print("rec_outputs:")
outputs_rec = self.modelRec.get_outputs()
for key,value in outputs_rec.items():
print("{}:{}".format(key,value))
self.inputName = self.modelRec.get_parameter_names()[0]
self.inputShape = inputs_rec[self.inputName].lens()
print("SVTR inputName:{0} \nSVTR inputShape:{1}".format(self.inputName, self.inputShape))
# 模型编译
......
......@@ -57,10 +57,21 @@ ErrorCode DB::Initialize(InitializationParameterOfDB InitializationParameterOfDB
net = migraphx::parse_onnx(modelPath, onnx_options);
LOG_INFO(stdout,"succeed to load model: %s\n",GetFileName(modelPath).c_str());
// 获取模型输入属性
std::unordered_map<std::string, migraphx::shape> inputMap=net.get_parameter_shapes();
inputName=inputMap.begin()->first;
inputShape=inputMap.begin()->second;
// 获取模型输入/输出节点信息
std::cout<<"DB_inputs:"<<std::endl;
std::unordered_map<std::string, migraphx::shape> inputs=net.get_inputs();
for(auto i:inputs)
{
std::cout<<i.first<<":"<<i.second<<std::endl;
}
std::cout<<"DB_outputs:"<<std::endl;
std::unordered_map<std::string, migraphx::shape> outputs=net.get_outputs();
for(auto i:outputs)
{
std::cout<<i.first<<":"<<i.second<<std::endl;
}
inputName=inputs.begin()->first;
inputShape=inputs.begin()->second;
int N=inputShape.lens()[0];
int C=inputShape.lens()[1];
int H=inputShape.lens()[2];
......
......@@ -50,10 +50,22 @@ ErrorCode SVTR::Initialize(InitializationParameterOfSVTR InitializationParameter
net = migraphx::parse_onnx(modelPath, onnx_options);
LOG_INFO(stdout,"succeed to load model: %s\n",GetFileName(modelPath).c_str());
// 获取模型输入属性
std::unordered_map<std::string, migraphx::shape> inputMap=net.get_parameter_shapes();
inputName=inputMap.begin()->first;
inputShape=inputMap.begin()->second;
// 获取模型输入/输出节点信息
std::cout<<"SVTR_inputs:"<<std::endl;
std::unordered_map<std::string, migraphx::shape> inputs=net.get_inputs();
for(auto i:inputs)
{
std::cout<<i.first<<":"<<i.second<<std::endl;
}
std::cout<<"DSVTR_outputs:"<<std::endl;
std::unordered_map<std::string, migraphx::shape> outputs=net.get_outputs();
for(auto i:outputs)
{
std::cout<<i.first<<":"<<i.second<<std::endl;
}
inputName=inputs.begin()->first;
inputShape=inputs.begin()->second;
int N=inputShape.lens()[0];
int C=inputShape.lens()[1];
int H=inputShape.lens()[2];
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment