Commit 453f511c authored by shizhm's avatar shizhm
Browse files

完善代码格式

parent 165d6c8b
#include <Crnn.h> #include <Crnn.h>
#include <migraphx/onnx.hpp> #include <migraphx/onnx.hpp>
#include <migraphx/gpu/target.hpp> #include <migraphx/gpu/target.hpp>
#include <Filesystem.h> #include <Filesystem.h>
#include <SimpleLog.h> #include <SimpleLog.h>
...@@ -19,7 +17,7 @@ Crnn::~Crnn() ...@@ -19,7 +17,7 @@ Crnn::~Crnn()
{ {
configurationFile.release(); configurationFile.release();
} }
ErrorCode Crnn::Initialize(InitializationParameterOfOcr initializationParameterOfOcr, bool dynamic) ErrorCode Crnn::Initialize(InitializationParameterOfOcr initializationParameterOfOcr, bool dynamic)
...@@ -50,11 +48,11 @@ ErrorCode Crnn::Initialize(InitializationParameterOfOcr initializationParameterO ...@@ -50,11 +48,11 @@ ErrorCode Crnn::Initialize(InitializationParameterOfOcr initializationParameterO
} }
if(dynamic) if(dynamic)
{ {
migraphx::onnx_options onnx_options; migraphx::onnx_options onnx_options;
onnx_options.map_input_dims["input"]={1,1,32,512}; onnx_options.map_input_dims["input"]={1,1,32,512};
net = migraphx::parse_onnx(modelPath, onnx_options); net = migraphx::parse_onnx(modelPath, onnx_options);
LOG_INFO(stdout,"succeed to load model: %s\n",GetFileName(modelPath).c_str()); LOG_INFO(stdout,"succeed to load model: %s\n",GetFileName(modelPath).c_str());
// 获取模型输入属性 // 获取模型输入属性
...@@ -76,7 +74,7 @@ ErrorCode Crnn::Initialize(InitializationParameterOfOcr initializationParameterO ...@@ -76,7 +74,7 @@ ErrorCode Crnn::Initialize(InitializationParameterOfOcr initializationParameterO
migraphx::onnx_options onnx_options; migraphx::onnx_options onnx_options;
onnx_options.map_input_dims["input"]={1,1,32,100}; onnx_options.map_input_dims["input"]={1,1,32,100};
net = migraphx::parse_onnx(modelPath, onnx_options); net = migraphx::parse_onnx(modelPath, onnx_options);
LOG_INFO(stdout,"succeed to load model: %s\n",GetFileName(modelPath).c_str()); LOG_INFO(stdout,"succeed to load model: %s\n",GetFileName(modelPath).c_str());
// 获取模型输入属性 // 获取模型输入属性
...@@ -100,14 +98,14 @@ ErrorCode Crnn::Initialize(InitializationParameterOfOcr initializationParameterO ...@@ -100,14 +98,14 @@ ErrorCode Crnn::Initialize(InitializationParameterOfOcr initializationParameterO
// 编译模型 // 编译模型
migraphx::compile_options options; migraphx::compile_options options;
options.device_id=0; // 设置GPU设备,默认为0号设备 options.device_id=0; // 设置GPU设备,默认为0号设备
options.offload_copy=true; options.offload_copy=true;
net.compile(gpuTarget,options); net.compile(gpuTarget,options);
LOG_INFO(stdout,"succeed to compile model: %s\n",GetFileName(modelPath).c_str()); LOG_INFO(stdout,"succeed to compile model: %s\n",GetFileName(modelPath).c_str());
// warm up // warm up
std::unordered_map<std::string, migraphx::argument> inputData; std::unordered_map<std::string, migraphx::argument> inputData;
inputData[inputName]=migraphx::argument{inputShape}; inputData[inputName]=migraphx::argument{inputShape};
net.eval(inputData); net.eval(inputData);
return SUCCESS; return SUCCESS;
} }
...@@ -122,7 +120,7 @@ ErrorCode Crnn::Infer(const cv::Mat &srcImage, std::vector<char> &resultsChar, b ...@@ -122,7 +120,7 @@ ErrorCode Crnn::Infer(const cv::Mat &srcImage, std::vector<char> &resultsChar, b
cv::Mat inputImage, inputBlob; cv::Mat inputImage, inputBlob;
cv::cvtColor(srcImage, inputImage, CV_BGR2GRAY); cv::cvtColor(srcImage, inputImage, CV_BGR2GRAY);
int height, width, widthRaw; int height, width, widthRaw;
widthRaw = inputImage.cols; widthRaw = inputImage.cols;
if(dynamic) if(dynamic)
...@@ -136,12 +134,12 @@ ErrorCode Crnn::Infer(const cv::Mat &srcImage, std::vector<char> &resultsChar, b ...@@ -136,12 +134,12 @@ ErrorCode Crnn::Infer(const cv::Mat &srcImage, std::vector<char> &resultsChar, b
height = inputImage.rows, width = inputImage.cols; height = inputImage.rows, width = inputImage.cols;
} }
inputBlob = cv::dnn::blobFromImage(inputImage); inputBlob = cv::dnn::blobFromImage(inputImage);
for(int i=0; i<width * height; i++) for(int i=0; i<width * height; i++)
{ {
*((float*)inputBlob.data+i) = ((*((float*)inputBlob.data+i))/255.f - 0.5)/0.5; *((float*)inputBlob.data+i) = ((*((float*)inputBlob.data+i))/255.f - 0.5)/0.5;
} }
// 创建输入数据 // 创建输入数据
std::unordered_map<std::string, migraphx::argument> inputData; std::unordered_map<std::string, migraphx::argument> inputData;
if(dynamic) if(dynamic)
...@@ -153,13 +151,13 @@ ErrorCode Crnn::Infer(const cv::Mat &srcImage, std::vector<char> &resultsChar, b ...@@ -153,13 +151,13 @@ ErrorCode Crnn::Infer(const cv::Mat &srcImage, std::vector<char> &resultsChar, b
{ {
inputData[inputName]= migraphx::argument{inputShape, (float*)inputBlob.data}; inputData[inputName]= migraphx::argument{inputShape, (float*)inputBlob.data};
} }
// 推理 // 推理
std::vector<migraphx::argument> inferenceResults = net.eval(inputData); std::vector<migraphx::argument> inferenceResults = net.eval(inputData);
// 获取推理结果 // 获取推理结果
std::vector<cv::Mat> outs; std::vector<cv::Mat> outs;
migraphx::argument result = inferenceResults[0]; migraphx::argument result = inferenceResults[0];
// 转换为cv::Mat // 转换为cv::Mat
migraphx::shape outputShape = result.get_shape(); migraphx::shape outputShape = result.get_shape();
...@@ -168,7 +166,7 @@ ErrorCode Crnn::Infer(const cv::Mat &srcImage, std::vector<char> &resultsChar, b ...@@ -168,7 +166,7 @@ ErrorCode Crnn::Infer(const cv::Mat &srcImage, std::vector<char> &resultsChar, b
memcpy(out.data,result.data(),sizeof(float)*outputShape.elements()); memcpy(out.data,result.data(),sizeof(float)*outputShape.elements());
outs.push_back(out); outs.push_back(out);
std::vector<int> predChars; std::vector<int> predChars;
const std::string alphabet = "-0123456789abcdefghijklmnopqrstuvwxyz"; const std::string alphabet = "-0123456789abcdefghijklmnopqrstuvwxyz";
//获取字符索引序列 //获取字符索引序列
...@@ -181,7 +179,7 @@ ErrorCode Crnn::Infer(const cv::Mat &srcImage, std::vector<char> &resultsChar, b ...@@ -181,7 +179,7 @@ ErrorCode Crnn::Infer(const cv::Mat &srcImage, std::vector<char> &resultsChar, b
int maxIdx = charIdPoint.x; int maxIdx = charIdPoint.x;
predChars.push_back(maxIdx); predChars.push_back(maxIdx);
} }
//字符转录处理 //字符转录处理
for(uint i=0; i<predChars.size(); i++) for(uint i=0; i<predChars.size(); i++)
{ {
......
...@@ -2,7 +2,6 @@ ...@@ -2,7 +2,6 @@
#define __OCR_CRNN_H__ #define __OCR_CRNN_H__
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <CommonDefinition.h> #include <CommonDefinition.h>
namespace migraphxSamples namespace migraphxSamples
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment