".github/vscode:/vscode.git/clone" did not exist on "8d858c380e59fb307e5e8774ec7fa1866384345c"
Commit 697faa92 authored by liucong's avatar liucong
Browse files

更新resnet50工程

parent a08d10b6
......@@ -106,4 +106,14 @@ ErrorCode Classifier::Classify(const std::vector<cv::Mat> &srcImages,std::vector
- inputData表示MIGraphX的输入数据,inputData是一个映射关系,每个输入节点名都会对应一个输入数据,如果有多个输入,则需要为每个输入节点名创建数据,inputName表示输入节点名,migraphx::argument{inputShape, (float*)inputBlob.data}表示该节点名对应的数据,这里是通过前面预处理的数据inputBlob来创建的,第一个参数表示数据的shape,第二个参数表示数据指针。
- net.eval(inputData)返回模型的推理结果,由于这里只有一个输出节点,所以std::vector中只有一个数据,results[0]表示第一个输出节点,这里对应resnetv24_dense0_fwd节点,获取输出数据。
另外,如果想要指定输出节点,可以在eval()方法中通过提供outputNames参数来实现:
```
...
// 推理
std::vector<std::string> outputNames = {"resnetv24_dense0_fwd"}
std::vector<migraphx::argument> results = net.eval(inputData, outputNames);
...
```
- 如果没有指定outputName参数,则默认输出所有输出节点,此时输出节点的顺序与ONNX中输出节点顺序保持一致,可以通过netron查看ONNX文件的输出节点的顺序。
\ No newline at end of file
......@@ -86,7 +86,7 @@ if __name__ == '__main__':
numberOfOutput=outputShape.elements() # 输出节点元素的个数
# 获取分类结果
result=np.array(result)
print(np.array(result))
```
- Preprocessing()函数返回输入数据(numpy类型),然后通过{inputName: migraphx.argument(image)}构造一个字典输入模型执行推理,如果模型有多个输入,则在字典中需要添加多个输入数据。
......
......@@ -3,6 +3,7 @@
分类器示例
"""
import argparse
import cv2
import numpy as np
import migraphx
......@@ -44,11 +45,26 @@ def Preprocessing(pathOfImage):
return norm_img_data
if __name__ == '__main__':
# 设置最大输入shape
maxInput={"data":[1,3,224,224]}
# 加载模型
model = migraphx.parse_onnx("../Resource/Models/resnet50-v2-7.onnx")
inputName=model.get_parameter_names()[0]
inputShape=model.get_parameter_shapes()[inputName].lens()
print("inputName:{0} \ninputShape:{1}".format(inputName,inputShape))
model = migraphx.parse_onnx("../Resource/Models/resnet50-v2-7.onnx", map_input_dims=maxInput)
# 获取模型输入/输出节点信息
print("inputs:")
inputs = model.get_inputs()
for key,value in inputs.items():
print("{}:{}".format(key,value))
print("outputs:")
outputs = model.get_outputs()
for key,value in outputs.items():
print("{}:{}".format(key,value))
inputName="data"
inputShape=inputs[inputName].lens()
# 编译
model.compile(t=migraphx.get_target("gpu"),device_id=0) # device_id: 设置GPU设备,默认为0号设备
......@@ -58,7 +74,7 @@ if __name__ == '__main__':
image = Preprocessing(pathOfImage)
# 推理
results = model.run({inputName: image}) # 推理结果,list类型
results = model.run({inputName:image}) # 推理结果,list类型
# 获取输出节点属性
result=results[0] # 获取第一个输出节点的数据,migraphx.argument类型
......@@ -67,6 +83,4 @@ if __name__ == '__main__':
numberOfOutput=outputShape.elements() # 输出节点元素的个数
# 获取分类结果
result=np.array(results[0])
print(result)
\ No newline at end of file
print(np.array(result))
\ No newline at end of file
......@@ -43,19 +43,34 @@ ErrorCode Classifier::Initialize(InitializationParameterOfClassifier initializat
useInt8=(bool)(int)netNode["UseInt8"];
useFP16=(bool)(int)netNode["UseFP16"];
// 设置最大输入shape
migraphx::onnx_options onnx_options;
onnx_options.map_input_dims["data"]={1,3,224,224};
// 加载模型
if(Exists(modelPath)==false)
{
LOG_ERROR(stdout,"%s not exist!\n",modelPath.c_str());
return MODEL_NOT_EXIST;
}
net = migraphx::parse_onnx(modelPath);
net = migraphx::parse_onnx(modelPath, onnx_options);
LOG_INFO(stdout,"succeed to load model: %s\n",GetFileName(modelPath).c_str());
// 获取模型输入属性
std::unordered_map<std::string, migraphx::shape> inputMap=net.get_parameter_shapes();
inputName=inputMap.begin()->first;
inputShape=inputMap.begin()->second;
// 获取模型输入/输出节点信息
std::cout<<"inputs:"<<std::endl;
std::unordered_map<std::string, migraphx::shape> inputs=net.get_inputs();
for(auto i:inputs)
{
std::cout<<i.first<<":"<<i.second<<std::endl;
}
std::cout<<"outputs:"<<std::endl;
std::unordered_map<std::string, migraphx::shape> outputs=net.get_outputs();
for(auto i:outputs)
{
std::cout<<i.first<<":"<<i.second<<std::endl;
}
inputName=inputs.begin()->first;
inputShape=inputs.begin()->second;
int N=inputShape.lens()[0];
int C=inputShape.lens()[1];
int H=inputShape.lens()[2];
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment