Commit 68b1a45f authored by liucong's avatar liucong
Browse files

修改代码

parent ce346218
...@@ -23,9 +23,12 @@ def Sigmoid(x): ...@@ -23,9 +23,12 @@ def Sigmoid(x):
return 1 / (1 + np.exp(-x)) return 1 / (1 + np.exp(-x))
if __name__ == '__main__': if __name__ == '__main__':
# 设置最大输入shape
maxInput={"inputs":[1,3,256,256]}
# 加载模型 # 加载模型
model = migraphx.parse_onnx("../Resource/Models/unet_13_256.onnx") model = migraphx.parse_onnx("../Resource/Models/unet_13_256.onnx", map_input_dims=maxInput)
# 获取模型输入/输出节点信息 # 获取模型输入/输出节点信息
print("inputs:") print("inputs:")
...@@ -38,10 +41,6 @@ if __name__ == '__main__': ...@@ -38,10 +41,6 @@ if __name__ == '__main__':
for key,value in outputs.items(): for key,value in outputs.items():
print("{}:{}".format(key,value)) print("{}:{}".format(key,value))
inputName = model.get_parameter_names()
inputShape = inputs[inputName].lens()
print("inputName:{0} \ninputShape:{1}".format(inputName, inputShape))
# 编译模型 # 编译模型
model.compile(migraphx.get_target("gpu"), device_id=0) # device_id: 设置GPU设备,默认为0号设备 model.compile(migraphx.get_target("gpu"), device_id=0) # device_id: 设置GPU设备,默认为0号设备
......
...@@ -44,13 +44,17 @@ ErrorCode Unet::Initialize(InitializationParameterOfSegmentation initParamOfSegm ...@@ -44,13 +44,17 @@ ErrorCode Unet::Initialize(InitializationParameterOfSegmentation initParamOfSegm
cv::FileNode netNode = configurationFile["Unet"]; cv::FileNode netNode = configurationFile["Unet"];
std::string modelPath=(std::string)netNode["ModelPath"]; std::string modelPath=(std::string)netNode["ModelPath"];
// 设置最大输入shape
migraphx::onnx_options onnx_options;
onnx_options.map_input_dims["inputs"]={1,3,256,256};
// 加载模型 // 加载模型
if(Exists(modelPath)==false) if(Exists(modelPath)==false)
{ {
LOG_ERROR(stdout,"%s not exist!\n",modelPath.c_str()); LOG_ERROR(stdout,"%s not exist!\n",modelPath.c_str());
return MODEL_NOT_EXIST; return MODEL_NOT_EXIST;
} }
net = migraphx::parse_onnx(modelPath); net = migraphx::parse_onnx(modelPath, onnx_options);
LOG_INFO(stdout,"succeed to load model: %s\n",GetFileName(modelPath).c_str()); LOG_INFO(stdout,"succeed to load model: %s\n",GetFileName(modelPath).c_str());
// 获取模型输入/输出节点信息 // 获取模型输入/输出节点信息
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment