#include #include #include #include #include namespace migraphxSamples { Crnn::Crnn() { } Crnn::~Crnn() { configurationFile.release(); } ErrorCode Crnn::Initialize(InitializationParameterOfOcr initializationParameterOfOcr, bool dynamic) { // 读取配置文件 std::string configFilePath=initializationParameterOfOcr.configFilePath; if(Exists(configFilePath)==false) { LOG_ERROR(stdout, "no configuration file!\n"); return CONFIG_FILE_NOT_EXIST; } if(!configurationFile.open(configFilePath, cv::FileStorage::READ)) { LOG_ERROR(stdout, "fail to open configuration file\n"); return FAIL_TO_OPEN_CONFIG_FILE; } LOG_INFO(stdout, "succeed to open configuration file\n"); // 获取配置文件参数 cv::FileNode netNode = configurationFile["CrnnDynamic"]; std::string modelPath=(std::string)netNode["ModelPath"]; // 加载模型 if(Exists(modelPath)==false) { LOG_ERROR(stdout,"%s not exist!\n",modelPath.c_str()); return MODEL_NOT_EXIST; } if(dynamic) { migraphx::onnx_options onnx_options; onnx_options.map_input_dims["input"]={1,1,32,512}; net = migraphx::parse_onnx(modelPath, onnx_options); LOG_INFO(stdout,"succeed to load model: %s\n",GetFileName(modelPath).c_str()); // 获取模型输入/输出节点信息 std::cout<<"inputs:"< inputs=net.get_inputs(); for(auto i:inputs) { std::cout< outputs=net.get_outputs(); for(auto i:outputs) { std::cout<first; inputShape=inputs.begin()->second; int N=inputShape.lens()[0]; int C=inputShape.lens()[1]; int H=inputShape.lens()[2]; int W=inputShape.lens()[3]; inputSize=cv::Size(W,H); // log输出日志信息 LOG_INFO(stdout,"InputMaxSize:%dx%d\n",inputSize.width,inputSize.height); LOG_INFO(stdout,"InputName:%s\n",inputName.c_str()); } else { migraphx::onnx_options onnx_options; onnx_options.map_input_dims["input"]={1,1,32,100}; net = migraphx::parse_onnx(modelPath, onnx_options); LOG_INFO(stdout,"succeed to load model: %s\n",GetFileName(modelPath).c_str()); // 获取模型输入/输出节点信息 std::cout<<"inputs:"< inputs=net.get_inputs(); for(auto i:inputs) { std::cout< outputs=net.get_outputs(); for(auto i:outputs) { std::cout<first; inputShape=inputs.begin()->second; int N=inputShape.lens()[0]; int C=inputShape.lens()[1]; int H=inputShape.lens()[2]; int W=inputShape.lens()[3]; inputSize=cv::Size(W,H); // log输出日志信息 LOG_INFO(stdout,"InputSize:%dx%d\n",inputSize.width,inputSize.height); LOG_INFO(stdout,"InputName:%s\n",inputName.c_str()); } // 设置模型为GPU模式 migraphx::target gpuTarget = migraphx::gpu::target{}; // 编译模型 migraphx::compile_options options; options.device_id=0; // 设置GPU设备,默认为0号设备 options.offload_copy=true; net.compile(gpuTarget,options); LOG_INFO(stdout,"succeed to compile model: %s\n",GetFileName(modelPath).c_str()); // warm up std::unordered_map inputData; inputData[inputName]=migraphx::argument{inputShape}; net.eval(inputData); return SUCCESS; } ErrorCode Crnn::Infer(const cv::Mat &srcImage, std::vector &resultsChar, bool raw, bool dynamic) { if(srcImage.empty() || srcImage.type()!=CV_8UC3) { LOG_ERROR(stdout, "image error!\n"); return IMAGE_ERROR; } cv::Mat inputImage, inputBlob; cv::cvtColor(srcImage, inputImage, CV_BGR2GRAY); int height, width, widthRaw; widthRaw = inputImage.cols; if(dynamic) { cv::resize(inputImage, inputImage, cv::Size(widthRaw, 32)); height = inputImage.rows, width = inputImage.cols; } else { cv::resize(inputImage, inputImage, cv::Size(100, 32)); height = inputImage.rows, width = inputImage.cols; } inputBlob = cv::dnn::blobFromImage(inputImage); for(int i=0; i inputData; if(dynamic) { std::vector dynamicShape = {1, 1, 32, width}; inputData[inputName]= migraphx::argument{migraphx::shape(inputShape.type(),dynamicShape), (float*)inputBlob.data}; } else { inputData[inputName]= migraphx::argument{inputShape, (float*)inputBlob.data}; } // 推理 std::vector inferenceResults = net.eval(inputData); // 获取推理结果 std::vector outs; migraphx::argument result = inferenceResults[0]; // 转换为cv::Mat migraphx::shape outputShape = result.get_shape(); int shape[]={outputShape.lens()[0],outputShape.lens()[1],outputShape.lens()[2]}; cv::Mat out(3,shape,CV_32F); memcpy(out.data,result.data(),sizeof(float)*outputShape.elements()); outs.push_back(out); std::vector predChars; const std::string alphabet = "-0123456789abcdefghijklmnopqrstuvwxyz"; //获取字符索引序列 for(uint i = 0; i < outs[0].size[0]; i++) { cv::Mat scores = cv::Mat(1,outs[0].size[2],CV_32F,outs[0].ptr(i)); cv::Point charIdPoint; double maxCharScore; cv::minMaxLoc(scores, 0, &maxCharScore, 0, &charIdPoint); int maxIdx = charIdPoint.x; predChars.push_back(maxIdx); } //字符转录处理 for(uint i=0; i 0 && predChars[i-1]==predChars[i])) { resultsChar.push_back(alphabet[predChars[i]]); } } } } return SUCCESS; } }