#include #include #include #include #include #include #include #include #include #include #include using namespace cv::dnn; namespace migraphxSamples { SVTR::SVTR():logFile(NULL) { } SVTR::~SVTR() { configurationFile.release(); } ErrorCode SVTR::Initialize(InitializationParameterOfSVTR InitializationParameterOfSVTR) { // 初始化(获取日志文件,加载配置文件等) ErrorCode errorCode = DoCommonInitialization(InitializationParameterOfSVTR); if (errorCode!=SUCCESS) { LOG_ERROR(logFile, "fail to DoCommonInitialization\n"); return errorCode; } LOG_INFO(logFile, "success to DoCommonInitialization\n"); // 获取配置文件参数 FileNode netNode = configurationFile["OcrSVTR"]; string modelPath = initializationParameter.parentPath + (string)netNode["ModelPath"]; string dictPath = (string)netNode["DictPath"]; // 加载模型 if(Exists(modelPath)==false) { LOG_ERROR(logFile,"%s not exist!\n",modelPath.c_str()); return MODEL_NOT_EXIST; } migraphx::onnx_options onnx_options; onnx_options.map_input_dims["x"]={1,3,48,320}; net = migraphx::parse_onnx(modelPath, onnx_options); LOG_INFO(logFile,"succeed to load model: %s\n",GetFileName(modelPath).c_str()); // 获取模型输入属性 std::pair inputAttribute=*(net.get_parameter_shapes().begin()); inputName=inputAttribute.first; inputShape=inputAttribute.second; inputSize=cv::Size(inputShape.lens()[3],inputShape.lens()[2]); // 设置模型为GPU模式 migraphx::target gpuTarget = migraphx::gpu::target{}; // 编译模型 migraphx::compile_options options; options.device_id=0; // 设置GPU设备,默认为0号设备 options.offload_copy=true; // 设置offload_copy net.compile(gpuTarget,options); LOG_INFO(logFile,"succeed to compile model: %s\n",GetFileName(modelPath).c_str()); // Run once by itself migraphx::parameter_map inputData; inputData[inputName]=migraphx::generate_argument(inputShape); net.eval(inputData); std::ifstream in(dictPath); std::string line; if (in) { while (getline(in, line)) { charactorDict.push_back(line); } charactorDict.insert(charactorDict.begin(), "#"); charactorDict.push_back(" "); } else { std::cout << "no such label file: " << dictPath << ", exit the program..." << std::endl; exit(1); } // log LOG_INFO(logFile,"InputMaxSize:%dx%d\n",inputSize.width,inputSize.height); LOG_INFO(logFile,"InputName:%s\n",inputName.c_str()); return SUCCESS; } ErrorCode SVTR::Infer(cv::Mat &img, std::string &resultsChar, float &resultsdScore, float &maxWHRatio) { if(img.empty()||img.type()!=CV_8UC3) { LOG_ERROR(logFile, "image error!\n"); return IMAGE_ERROR; } cv::Mat srcImage; cv::Mat resizeImg; img.copyTo(srcImage); float ratio = 1.f; int imgC = 3, imgH = 48; int resizeW; int imgW = int((48 * maxWHRatio)); ratio = float(srcImage.cols) / float(srcImage.rows); if (ceil(imgH * ratio) > imgW) { resizeW = imgW; } else { resizeW = int(ceil(imgH * ratio)); } cv::resize(srcImage, resizeImg, cv::Size(resizeW, imgH)); cv::copyMakeBorder(resizeImg, resizeImg, 0, 0, 0, int(imgW - resizeImg.cols), cv::BORDER_CONSTANT, {127, 127, 127}); resizeImg.convertTo(resizeImg, CV_32FC3, 1.0/255.0); std::vector bgrChannels(3); cv::split(resizeImg, bgrChannels); std::vector mean = {0.485f, 0.456f, 0.406f}; std::vector scale = {1 / 0.229f, 1 / 0.224f, 1 / 0.225f}; for (auto i = 0; i < bgrChannels.size(); i++) { bgrChannels[i].convertTo(bgrChannels[i], CV_32FC1, 1.0 * scale[i], (0.0 - mean[i]) * scale[i]); } cv::merge(bgrChannels, resizeImg); cv::Mat inputBlob = cv::dnn::blobFromImage(resizeImg); std::vector inputShapeOfInfer={1,3,48,resizeW}; // 输入数据 migraphx::parameter_map inputData; inputData[inputName]= migraphx::argument{migraphx::shape(inputShape.type(),inputShapeOfInfer), (float*)inputBlob.data}; // 推理 std::vector inferenceResults = net.eval(inputData); // 获取推理结果 migraphx::argument result = inferenceResults[0]; migraphx::shape outputShape = result.get_shape(); int n2 = outputShape.lens()[1]; int n3 = outputShape.lens()[2]; int n = n2 * n3; std::vector out(n); memcpy(out.data(),result.data(),sizeof(float)*outputShape.elements()); out.resize(n); int argmaxIdx; int lastIndex = 0; float score = 0.f; int count = 0; float maxValue = 0.0f; for (int j = 0; j < n2; j++) { argmaxIdx = int(std::distance(&out[(j) * n3], std::max_element(&out[(j) * n3], &out[(j + 1) * n3]))); maxValue = float(*std::max_element(&out[(j) * n3], &out[(j + 1) * n3])); if (argmaxIdx > 0 && (!(n > 0 && argmaxIdx == lastIndex))) { score += maxValue; count += 1; resultsChar += charactorDict[argmaxIdx]; } lastIndex = argmaxIdx; } resultsdScore = score / count; return SUCCESS; } ErrorCode SVTR::DoCommonInitialization(InitializationParameterOfSVTR InitializationParameterOfSVTR) { initializationParameter = InitializationParameterOfSVTR; // 获取日志文件 logFile=LogManager::GetInstance()->GetLogFile(initializationParameter.logName); // 加载配置文件 std::string configFilePath=initializationParameter.configFilePath; if(!Exists(configFilePath)) { LOG_ERROR(logFile, "no configuration file!\n"); return CONFIG_FILE_NOT_EXIST; } if(!configurationFile.open(configFilePath, FileStorage::READ)) { LOG_ERROR(logFile, "fail to open configuration file\n"); return FAIL_TO_OPEN_CONFIG_FILE; } LOG_INFO(logFile, "succeed to open configuration file\n"); // 修改父路径 std::string &parentPath = initializationParameter.parentPath; if (!parentPath.empty()) { if(!IsPathSeparator(parentPath[parentPath.size() - 1])) { parentPath+=PATH_SEPARATOR; } } return SUCCESS; } }