Commit 29410168 authored by liucong's avatar liucong
Browse files

修改unet工程格式

parent 2796f59f
......@@ -31,15 +31,8 @@ if __name__ == '__main__':
model = migraphx.parse_onnx("../Resource/Models/unet_13_256.onnx", map_input_dims=maxInput)
# 获取模型输入/输出节点信息
print("inputs:")
inputs = model.get_inputs()
for key,value in inputs.items():
print("{}:{}".format(key,value))
print("outputs:")
outputs = model.get_outputs()
for key,value in outputs.items():
print("{}:{}".format(key,value))
# 编译模型
model.compile(migraphx.get_target("gpu"), device_id=0) # device_id: 设置GPU设备,默认为0号设备
......
......@@ -28,7 +28,7 @@ ErrorCode Unet::Initialize(InitializationParameterOfSegmentation initParamOfSegm
{
// 读取配置文件
std::string configFilePath=initParamOfSegmentationUnet.configFilePath;
if(Exists(configFilePath)==false)
if(!Exists(configFilePath))
{
LOG_ERROR(stdout, "no configuration file!\n");
return CONFIG_FILE_NOT_EXIST;
......@@ -49,7 +49,7 @@ ErrorCode Unet::Initialize(InitializationParameterOfSegmentation initParamOfSegm
onnx_options.map_input_dims["inputs"]={1,3,256,256};
// 加载模型
if(Exists(modelPath)==false)
if(!Exists(modelPath))
{
LOG_ERROR(stdout,"%s not exist!\n",modelPath.c_str());
return MODEL_NOT_EXIST;
......@@ -58,18 +58,8 @@ ErrorCode Unet::Initialize(InitializationParameterOfSegmentation initParamOfSegm
LOG_INFO(stdout,"succeed to load model: %s\n",GetFileName(modelPath).c_str());
// 获取模型输入/输出节点信息
std::cout<<"inputs:"<<std::endl;
std::unordered_map<std::string, migraphx::shape> inputs=net.get_inputs();
for(auto i:inputs)
{
std::cout<<i.first<<":"<<i.second<<std::endl;
}
std::cout<<"outputs:"<<std::endl;
std::unordered_map<std::string, migraphx::shape> outputs=net.get_outputs();
for(auto i:outputs)
{
std::cout<<i.first<<":"<<i.second<<std::endl;
}
inputName=inputs.begin()->first;
inputShape=inputs.begin()->second;
int N=inputShape.lens()[0];
......
......@@ -24,11 +24,7 @@ int main()
// 推理
cv::Mat maskImage;
double time1 = cv::getTickCount();
unet.Segmentation(srcImage, maskImage);
double time2 = cv::getTickCount();
double elapsedTime = (time2 - time1) * 1000 / cv::getTickFrequency();
LOG_INFO(stdout, "inference time:%f ms\n", elapsedTime);
LOG_INFO(stdout,"========== Segmentation Results ==========\n");
LOG_INFO(stdout,"Segmentation results have been saved to ./Result.jpg\n");
cv::imwrite("./Result.jpg", maskImage);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment