Commit eb607cf9 authored by liucong's avatar liucong
Browse files

修改crnn工程格式

parent 082cd068
......@@ -10,33 +10,25 @@ class Crnn:
self.model = migraphx.parse_onnx(path)
# 获取模型输入/输出节点信息
print("inputs:")
inputs = self.model.get_inputs()
for key,value in inputs.items():
print("{}:{}".format(key,value))
print("outputs:")
outputs = self.model.get_outputs()
for key,value in outputs.items():
print("{}:{}".format(key,value))
# 获取模型的输入name
self.inputName = self.model.get_parameter_names()[0]
# 获取模型的输入尺寸
self.inputShape = inputs[self.inputName].lens()
print("inputName:{0} \ninputShape:{1}".format(self.inputName, self.inputShape))
# 模型编译
self.model.compile(t=migraphx.get_target("gpu"), device_id=0) # device_id: 设置GPU设备,默认为0号设备
print("Success to compile")
def infer(self, image):
inputImage = self.prepare_input(image)
# 执行推理
results = self.model.run({self.model.get_parameter_names()[0]: inputImage})
# 获取第一个输出节点的数据,migraphx.argument类型
result=results[0]
result=np.array(result)
......@@ -86,10 +78,8 @@ if __name__ == '__main__':
srcimg = cv2.imread(args.imgpath, 1)
# 执行推理
print("Start to inference")
start = time.time()
resultRaw, resultSim = crnn.infer(srcimg)
print('net forward time: {:.4f}'.format(time.time() - start))
print("============= Ocr Results =============")
print('%-20s => %-20s' % (resultRaw, resultSim))
......
......@@ -15,16 +15,14 @@ Crnn::Crnn()
Crnn::~Crnn()
{
configurationFile.release();
}
ErrorCode Crnn::Initialize(InitializationParameterOfOcr initializationParameterOfOcr, bool dynamic)
{
// 读取配置文件
std::string configFilePath=initializationParameterOfOcr.configFilePath;
if(Exists(configFilePath)==false)
if(!Exists(configFilePath))
{
LOG_ERROR(stdout, "no configuration file!\n");
return CONFIG_FILE_NOT_EXIST;
......@@ -41,7 +39,7 @@ ErrorCode Crnn::Initialize(InitializationParameterOfOcr initializationParameterO
std::string modelPath=(std::string)netNode["ModelPath"];
// 加载模型
if(Exists(modelPath)==false)
if(!Exists(modelPath))
{
LOG_ERROR(stdout,"%s not exist!\n",modelPath.c_str());
return MODEL_NOT_EXIST;
......@@ -56,19 +54,8 @@ ErrorCode Crnn::Initialize(InitializationParameterOfOcr initializationParameterO
LOG_INFO(stdout,"succeed to load model: %s\n",GetFileName(modelPath).c_str());
// 获取模型输入/输出节点信息
std::cout<<"inputs:"<<std::endl;
std::unordered_map<std::string, migraphx::shape> inputs=net.get_inputs();
for(auto i:inputs)
{
std::cout<<i.first<<":"<<i.second<<std::endl;
}
std::cout<<"outputs:"<<std::endl;
std::unordered_map<std::string, migraphx::shape> outputs=net.get_outputs();
for(auto i:outputs)
{
std::cout<<i.first<<":"<<i.second<<std::endl;
}
inputName=inputs.begin()->first;
inputShape=inputs.begin()->second;
int N=inputShape.lens()[0];
......@@ -90,19 +77,8 @@ ErrorCode Crnn::Initialize(InitializationParameterOfOcr initializationParameterO
LOG_INFO(stdout,"succeed to load model: %s\n",GetFileName(modelPath).c_str());
// 获取模型输入/输出节点信息
std::cout<<"inputs:"<<std::endl;
std::unordered_map<std::string, migraphx::shape> inputs=net.get_inputs();
for(auto i:inputs)
{
std::cout<<i.first<<":"<<i.second<<std::endl;
}
std::cout<<"outputs:"<<std::endl;
std::unordered_map<std::string, migraphx::shape> outputs=net.get_outputs();
for(auto i:outputs)
{
std::cout<<i.first<<":"<<i.second<<std::endl;
}
inputName=inputs.begin()->first;
inputShape=inputs.begin()->second;
int N=inputShape.lens()[0];
......@@ -169,11 +145,11 @@ ErrorCode Crnn::Infer(const cv::Mat &srcImage, std::vector<char> &resultsChar, b
if(dynamic)
{
std::vector<std::size_t> dynamicShape = {1, 1, 32, width};
inputData[inputName]= migraphx::argument{migraphx::shape(inputShape.type(),dynamicShape), (float*)inputBlob.data};
inputData[inputName] = migraphx::argument{migraphx::shape(inputShape.type(),dynamicShape), (float*)inputBlob.data};
}
else
{
inputData[inputName]= migraphx::argument{inputShape, (float*)inputBlob.data};
inputData[inputName] = migraphx::argument{inputShape, (float*)inputBlob.data};
}
// 推理
......@@ -205,7 +181,7 @@ ErrorCode Crnn::Infer(const cv::Mat &srcImage, std::vector<char> &resultsChar, b
}
//字符转录处理
for(uint i=0; i<predChars.size(); i++)
for(uint i=0; i < predChars.size(); i++)
{
if(raw)
{
......@@ -215,7 +191,7 @@ ErrorCode Crnn::Infer(const cv::Mat &srcImage, std::vector<char> &resultsChar, b
{
if(predChars[i] != 0)
{
if(!(i > 0 && predChars[i-1]==predChars[i]))
if(!(i > 0 && predChars[i-1] == predChars[i]))
{
resultsChar.push_back(alphabet[predChars[i]]);
}
......
......@@ -55,8 +55,8 @@ void Sample_Crnn()
migraphxSamples::Crnn crnn;
migraphxSamples::InitializationParameterOfOcr initParamOfOcrCRNN;
initParamOfOcrCRNN.configFilePath = CONFIG_FILE;
migraphxSamples::ErrorCode errorCode=crnn.Initialize(initParamOfOcrCRNN, false);
if(errorCode!=migraphxSamples::SUCCESS)
migraphxSamples::ErrorCode errorCode = crnn.Initialize(initParamOfOcrCRNN, false);
if(errorCode! = migraphxSamples::SUCCESS)
{
LOG_ERROR(stdout, "fail to initialize crnn!\n");
exit(-1);
......@@ -64,17 +64,13 @@ void Sample_Crnn()
LOG_INFO(stdout, "succeed to initialize crnn\n");
// 读取测试图片
cv:: Mat srcImage=cv::imread("../Resource/Images/text.jpg", 1);
cv:: Mat srcImage = cv::imread("../Resource/Images/text.jpg", 1);
// 推理
std::vector<char> resultRaw;
std::vector<char> resultSim;
double time1 = cv::getTickCount();
crnn.Infer(srcImage,resultRaw, true, false);
crnn.Infer(srcImage,resultSim, false, false);
double time2 = cv::getTickCount();
double elapsedTime = (time2 - time1)*1000*0.5 / cv::getTickFrequency();
LOG_INFO(stdout, "inference time:%f ms\n", elapsedTime);
crnn.Infer(srcImage, resultRaw, true, false);
crnn.Infer(srcImage, resultSim, false, false);
// 获取推理结果
LOG_INFO(stdout,"========== Ocr Results ==========\n");
......@@ -96,8 +92,8 @@ void Sample_Crnn_Dynamic()
migraphxSamples::Crnn crnn;
migraphxSamples::InitializationParameterOfOcr initParamOfOcrCRNN;
initParamOfOcrCRNN.configFilePath = CONFIG_FILE;
migraphxSamples::ErrorCode errorCode=crnn.Initialize(initParamOfOcrCRNN, true);
if(errorCode!=migraphxSamples::SUCCESS)
migraphxSamples::ErrorCode errorCode = crnn.Initialize(initParamOfOcrCRNN, true);
if(errorCode! = migraphxSamples::SUCCESS)
{
LOG_ERROR(stdout, "fail to initialize crnn!\n");
exit(-1);
......@@ -111,7 +107,7 @@ void Sample_Crnn_Dynamic()
cv::glob(folder,imagePathList);
for (int i = 0; i < imagePathList.size(); ++i)
{
cv:: Mat srcImage=cv::imread(imagePathList[i], 1);
cv:: Mat srcImage = cv::imread(imagePathList[i], 1);
srcImages.push_back(srcImage);
}
......@@ -120,7 +116,7 @@ void Sample_Crnn_Dynamic()
for(int i=0; i<srcImages.size(); ++i)
{
std::vector<char> resultSim;
crnn.Infer(srcImages[i],resultSim, false, true);
crnn.Infer(srcImages[i], resultSim, false, true);
for(int i = 0; i < resultSim.size(); i++)
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment