Commit 68b1a45f authored by liucong's avatar liucong
Browse files

修改代码

parent ce346218
......@@ -23,9 +23,12 @@ def Sigmoid(x):
return 1 / (1 + np.exp(-x))
if __name__ == '__main__':
# 设置最大输入shape
maxInput={"inputs":[1,3,256,256]}
# 加载模型
model = migraphx.parse_onnx("../Resource/Models/unet_13_256.onnx")
model = migraphx.parse_onnx("../Resource/Models/unet_13_256.onnx", map_input_dims=maxInput)
# 获取模型输入/输出节点信息
print("inputs:")
......@@ -38,10 +41,6 @@ if __name__ == '__main__':
for key,value in outputs.items():
print("{}:{}".format(key,value))
inputName = model.get_parameter_names()
inputShape = inputs[inputName].lens()
print("inputName:{0} \ninputShape:{1}".format(inputName, inputShape))
# 编译模型
model.compile(migraphx.get_target("gpu"), device_id=0) # device_id: 设置GPU设备,默认为0号设备
......
......@@ -44,13 +44,17 @@ ErrorCode Unet::Initialize(InitializationParameterOfSegmentation initParamOfSegm
cv::FileNode netNode = configurationFile["Unet"];
std::string modelPath=(std::string)netNode["ModelPath"];
// 设置最大输入shape
migraphx::onnx_options onnx_options;
onnx_options.map_input_dims["inputs"]={1,3,256,256};
// 加载模型
if(Exists(modelPath)==false)
{
LOG_ERROR(stdout,"%s not exist!\n",modelPath.c_str());
return MODEL_NOT_EXIST;
}
net = migraphx::parse_onnx(modelPath);
net = migraphx::parse_onnx(modelPath, onnx_options);
LOG_INFO(stdout,"succeed to load model: %s\n",GetFileName(modelPath).c_str());
// 获取模型输入/输出节点信息
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment