#include #include #include #include #include #include #include #include #include #include #include using namespace cv::dnn; namespace migraphxSamples { Crnn::Crnn():logFile(NULL) { } Crnn::~Crnn() { configurationFile.release(); } ErrorCode Crnn::Initialize(InitializationParameterOfOcr initializationParameterOfOcr, bool dynamic) { // 初始化(获取日志文件,加载配置文件等) ErrorCode errorCode = DoCommonInitialization(initializationParameterOfOcr); if(errorCode!=SUCCESS) { LOG_ERROR(logFile,"fail to DoCommonInitialization\n"); return errorCode; } LOG_INFO(logFile,"succeed to DoCommonInitialization\n"); // 获取配置文件参数 std::string modelPath; FileNode netNode = configurationFile["CrnnDynamic"]; modelPath = initializationParameter.parentPath + (string)netNode["ModelPath"]; // 加载模型 if(Exists(modelPath)==false) { LOG_ERROR(logFile,"%s not exist!\n",modelPath.c_str()); return MODEL_NOT_EXIST; } if(dynamic) { migraphx::onnx_options onnx_options; onnx_options.map_input_dims["input"]={1,1,32,512}; net = migraphx::parse_onnx(modelPath, onnx_options); LOG_INFO(logFile,"succeed to load model: %s\n",GetFileName(modelPath).c_str()); // 获取模型输入属性 std::pair inputAttribute=*(net.get_parameter_shapes().begin()); inputName=inputAttribute.first; inputShape=inputAttribute.second; inputSize=cv::Size(inputShape.lens()[3],inputShape.lens()[2]); // log输出日志信息 LOG_INFO(logFile,"InputMaxSize:%dx%d\n",inputSize.width,inputSize.height); LOG_INFO(logFile,"InputName:%s\n",inputName.c_str()); } else { migraphx::onnx_options onnx_options; onnx_options.map_input_dims["input"]={1,1,32,100}; net = migraphx::parse_onnx(modelPath, onnx_options); LOG_INFO(logFile,"succeed to load model: %s\n",GetFileName(modelPath).c_str()); // 获取模型输入属性 std::pair inputAttribute=*(net.get_parameter_shapes().begin()); inputName=inputAttribute.first; inputShape=inputAttribute.second; inputSize=cv::Size(inputShape.lens()[3],inputShape.lens()[2]); // log输出日志信息 LOG_INFO(logFile,"InputSize:%dx%d\n",inputSize.width,inputSize.height); LOG_INFO(logFile,"InputName:%s\n",inputName.c_str()); } // 设置模型为GPU模式 migraphx::target gpuTarget = migraphx::gpu::target{}; // 编译模型 migraphx::compile_options options; options.device_id=0; // 设置GPU设备,默认为0号设备 options.offload_copy=true; // 设置offload_copy net.compile(gpuTarget,options); LOG_INFO(logFile,"succeed to compile model: %s\n",GetFileName(modelPath).c_str()); // Run once by itself migraphx::parameter_map inputData; inputData[inputName]=migraphx::generate_argument(inputShape); net.eval(inputData); return SUCCESS; } ErrorCode Crnn::Infer(const cv::Mat &srcImage, std::vector &resultsChar, bool raw, bool dynamic) { if(srcImage.empty() || srcImage.type()!=CV_8UC3) { LOG_ERROR(logFile, "image error!\n"); return IMAGE_ERROR; } cv::Mat inputImage, inputBlob; cv::cvtColor(srcImage, inputImage, CV_BGR2GRAY); int height, width, widthRaw; widthRaw = inputImage.cols; if(dynamic) { cv::resize(inputImage, inputImage, cv::Size(widthRaw, 32)); height = inputImage.rows, width = inputImage.cols; } else { cv::resize(inputImage, inputImage, cv::Size(100, 32)); height = inputImage.rows, width = inputImage.cols; } inputBlob = cv::dnn::blobFromImage(inputImage); for(int i=0; i dynamicShape = {1, 1, 32, width}; inputData[inputName]= migraphx::argument{migraphx::shape(inputShape.type(),dynamicShape), (float*)inputBlob.data}; } else { inputData[inputName]= migraphx::argument{inputShape, (float*)inputBlob.data}; } // 推理 std::vector inferenceResults = net.eval(inputData); // 获取推理结果 std::vector outs; migraphx::argument result = inferenceResults[0]; // 转换为cv::Mat migraphx::shape outputShape = result.get_shape(); int shape[]={outputShape.lens()[0],outputShape.lens()[1],outputShape.lens()[2]}; cv::Mat out(3,shape,CV_32F); memcpy(out.data,result.data(),sizeof(float)*outputShape.elements()); outs.push_back(out); std::vector predChars; const std::string alphabet = "-0123456789abcdefghijklmnopqrstuvwxyz"; //获取字符索引序列 for(uint i = 0; i < outs[0].size[0]; i++) { cv::Mat scores = Mat(1,outs[0].size[2],CV_32F,outs[0].ptr(i)); cv::Point charIdPoint; double maxCharScore; cv::minMaxLoc(scores, 0, &maxCharScore, 0, &charIdPoint); int maxIdx = charIdPoint.x; predChars.push_back(maxIdx); } //字符转录处理 for(uint i=0; i 0 && predChars[i-1]==predChars[i])) { resultsChar.push_back(alphabet[predChars[i]]); } } } } return SUCCESS; } ErrorCode Crnn::DoCommonInitialization(InitializationParameterOfOcr initializationParameterOfOcr) { initializationParameter=initializationParameterOfOcr; // 获取日志文件 logFile=LogManager::GetInstance()->GetLogFile(initializationParameter.logName); // 加载配置文件 std::string configFilePath=initializationParameter.configFilePath; if(!Exists(configFilePath)) { LOG_ERROR(logFile, "no configuration file!\n"); return CONFIG_FILE_NOT_EXIST; } if(!configurationFile.open(configFilePath, FileStorage::READ)) { LOG_ERROR(logFile, "fail to open configuration file\n"); return FAIL_TO_OPEN_CONFIG_FILE; } LOG_INFO(logFile, "succeed to open configuration file\n"); // 修改父路径 std::string &parentPath = initializationParameter.parentPath; if (!parentPath.empty()) { if(!IsPathSeparator(parentPath[parentPath.size() - 1])) { parentPath+=PATH_SEPARATOR; } } return SUCCESS; } }