Commit 29410168 authored by liucong's avatar liucong
Browse files

修改unet工程格式

parent 2796f59f
...@@ -31,15 +31,8 @@ if __name__ == '__main__': ...@@ -31,15 +31,8 @@ if __name__ == '__main__':
model = migraphx.parse_onnx("../Resource/Models/unet_13_256.onnx", map_input_dims=maxInput) model = migraphx.parse_onnx("../Resource/Models/unet_13_256.onnx", map_input_dims=maxInput)
# 获取模型输入/输出节点信息 # 获取模型输入/输出节点信息
print("inputs:")
inputs = model.get_inputs() inputs = model.get_inputs()
for key,value in inputs.items():
print("{}:{}".format(key,value))
print("outputs:")
outputs = model.get_outputs() outputs = model.get_outputs()
for key,value in outputs.items():
print("{}:{}".format(key,value))
# 编译模型 # 编译模型
model.compile(migraphx.get_target("gpu"), device_id=0) # device_id: 设置GPU设备,默认为0号设备 model.compile(migraphx.get_target("gpu"), device_id=0) # device_id: 设置GPU设备,默认为0号设备
......
...@@ -28,7 +28,7 @@ ErrorCode Unet::Initialize(InitializationParameterOfSegmentation initParamOfSegm ...@@ -28,7 +28,7 @@ ErrorCode Unet::Initialize(InitializationParameterOfSegmentation initParamOfSegm
{ {
// 读取配置文件 // 读取配置文件
std::string configFilePath=initParamOfSegmentationUnet.configFilePath; std::string configFilePath=initParamOfSegmentationUnet.configFilePath;
if(Exists(configFilePath)==false) if(!Exists(configFilePath))
{ {
LOG_ERROR(stdout, "no configuration file!\n"); LOG_ERROR(stdout, "no configuration file!\n");
return CONFIG_FILE_NOT_EXIST; return CONFIG_FILE_NOT_EXIST;
...@@ -49,7 +49,7 @@ ErrorCode Unet::Initialize(InitializationParameterOfSegmentation initParamOfSegm ...@@ -49,7 +49,7 @@ ErrorCode Unet::Initialize(InitializationParameterOfSegmentation initParamOfSegm
onnx_options.map_input_dims["inputs"]={1,3,256,256}; onnx_options.map_input_dims["inputs"]={1,3,256,256};
// 加载模型 // 加载模型
if(Exists(modelPath)==false) if(!Exists(modelPath))
{ {
LOG_ERROR(stdout,"%s not exist!\n",modelPath.c_str()); LOG_ERROR(stdout,"%s not exist!\n",modelPath.c_str());
return MODEL_NOT_EXIST; return MODEL_NOT_EXIST;
...@@ -58,18 +58,8 @@ ErrorCode Unet::Initialize(InitializationParameterOfSegmentation initParamOfSegm ...@@ -58,18 +58,8 @@ ErrorCode Unet::Initialize(InitializationParameterOfSegmentation initParamOfSegm
LOG_INFO(stdout,"succeed to load model: %s\n",GetFileName(modelPath).c_str()); LOG_INFO(stdout,"succeed to load model: %s\n",GetFileName(modelPath).c_str());
// 获取模型输入/输出节点信息 // 获取模型输入/输出节点信息
std::cout<<"inputs:"<<std::endl;
std::unordered_map<std::string, migraphx::shape> inputs=net.get_inputs(); std::unordered_map<std::string, migraphx::shape> inputs=net.get_inputs();
for(auto i:inputs)
{
std::cout<<i.first<<":"<<i.second<<std::endl;
}
std::cout<<"outputs:"<<std::endl;
std::unordered_map<std::string, migraphx::shape> outputs=net.get_outputs(); std::unordered_map<std::string, migraphx::shape> outputs=net.get_outputs();
for(auto i:outputs)
{
std::cout<<i.first<<":"<<i.second<<std::endl;
}
inputName=inputs.begin()->first; inputName=inputs.begin()->first;
inputShape=inputs.begin()->second; inputShape=inputs.begin()->second;
int N=inputShape.lens()[0]; int N=inputShape.lens()[0];
......
...@@ -24,11 +24,7 @@ int main() ...@@ -24,11 +24,7 @@ int main()
// 推理 // 推理
cv::Mat maskImage; cv::Mat maskImage;
double time1 = cv::getTickCount();
unet.Segmentation(srcImage, maskImage); unet.Segmentation(srcImage, maskImage);
double time2 = cv::getTickCount();
double elapsedTime = (time2 - time1) * 1000 / cv::getTickFrequency();
LOG_INFO(stdout, "inference time:%f ms\n", elapsedTime);
LOG_INFO(stdout,"========== Segmentation Results ==========\n"); LOG_INFO(stdout,"========== Segmentation Results ==========\n");
LOG_INFO(stdout,"Segmentation results have been saved to ./Result.jpg\n"); LOG_INFO(stdout,"Segmentation results have been saved to ./Result.jpg\n");
cv::imwrite("./Result.jpg", maskImage); cv::imwrite("./Result.jpg", maskImage);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment