#include "CrnnNet.h"
#include "OcrUtils.h"
#include <algorithm>
#include <fstream>
#include <numeric>

// Selects the execution provider for text recognition.
// gpuIndex >= 0 requests that CUDA device; a negative index keeps CPU execution.
// No-op unless the binary was built with __CUDA__ defined.
void CrnnNet::setGpuIndex(int gpuIndex) {
#ifdef __CUDA__
    if (gpuIndex >= 0) {
        OrtCUDAProviderOptions cuda_options;
        cuda_options.device_id = gpuIndex;
        cuda_options.arena_extend_strategy = 0;
        // 2 GiB arena cap. Computed in size_t: the plain int expression
        // 2 * 1024 * 1024 * 1024 overflows a 32-bit int (undefined behavior).
        cuda_options.gpu_mem_limit = static_cast<size_t>(2) * 1024 * 1024 * 1024;
        cuda_options.cudnn_conv_algo_search = OrtCudnnConvAlgoSearch::OrtCudnnConvAlgoSearchExhaustive;
        cuda_options.do_copy_in_default_stream = 1;
        sessionOptions.AppendExecutionProvider_CUDA(cuda_options);
        printf("rec try to use GPU%d\n", gpuIndex);
    } else {
        printf("rec use CPU\n");
    }
#endif
}

// Releases the ONNX Runtime session and the input/output tensor names.
// The name strings were allocated by getInputNames/getOutputNames (OcrUtils),
// which use the platform allocator — hence _aligned_free on Windows, free elsewhere.
CrnnNet::~CrnnNet() {
    delete session;
    for (auto name : inputNames) {
#ifdef _WIN32
        _aligned_free(name);
#else
        free(name);
#endif
    }
    inputNames.clear();
    for (auto name : outputNames) {
#ifdef _WIN32
        _aligned_free(name);
#else
        free(name);
#endif
    }
    outputNames.clear();
}

// Configures ONNX Runtime threading and graph optimization for this session.
// Must be called before initModel(), since the options are consumed at
// session construction time.
void CrnnNet::setNumThread(int numOfThread) {
    numThread = numOfThread;
    //===session options===
    // Sets the number of threads used to parallelize the execution within nodes
    // A value of 0 means ORT will pick a default
    //sessionOptions.SetIntraOpNumThreads(numThread);
    //set OMP_NUM_THREADS=16

    // Sets the number of threads used to parallelize the execution of the graph (across nodes)
    // If sequential execution is enabled this value is ignored
    // A value of 0 means ORT will pick a default
    sessionOptions.SetInterOpNumThreads(numThread);

    // Sets graph optimization level
    // ORT_DISABLE_ALL -> To disable all optimizations
    // ORT_ENABLE_BASIC -> To enable basic optimizations (Such as redundant node removals)
    // ORT_ENABLE_EXTENDED -> To enable extended optimizations (Includes level 1 + more complex optimizations like node fusions)
    // ORT_ENABLE_ALL -> To Enable All possible opitmizations
    sessionOptions.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_EXTENDED);
}

// Loads the CRNN recognition model and its character dictionary.
// pathStr: path to the .onnx model (converted to wide string on Windows).
// keysPath: text file with one dictionary entry per line.
// The CTC blank token "#" is prepended at index 0 and a trailing space
// appended, matching the model's output class layout.
void CrnnNet::initModel(const std::string &pathStr, const std::string &keysPath) {
#ifdef _WIN32
    std::wstring crnnPath = strToWstr(pathStr);
    session = new Ort::Session(env, crnnPath.c_str(), sessionOptions);
#else
    session = new Ort::Session(env, pathStr.c_str(), sessionOptions);
#endif
    inputNames = getInputNames(session);
    outputNames = getOutputNames(session);

    //load keys
    std::ifstream in(keysPath.c_str());
    std::string line;
    if (in) {
        while (getline(in, line)) { // getline strips the trailing newline
            keys.push_back(line);
        }
    } else {
        printf("The keys.txt file was not found\n");
        return;
    }
    keys.insert(keys.begin(), "#"); // index 0 = CTC blank
    keys.emplace_back(" ");
    // %zu is the portable size_t specifier (%lu is wrong on LLP64/Windows)
    printf("total keys size(%zu)\n", keys.size());
}

// Returns the index of the maximum element in [first, last).
template<class ForwardIterator>
inline static size_t argmax(ForwardIterator first, ForwardIterator last) {
    return std::distance(first, std::max_element(first, last));
}

// CTC greedy decoding: outputData is an h x w score matrix (h time steps,
// w character classes). For each time step, take the arg-max class; emit its
// character unless it is the blank (index 0) or a repeat of the previous step.
// Returns the decoded string plus the per-character confidence scores.
TextLine CrnnNet::scoreToTextLine(const std::vector<float> &outputData, int h, int w) {
    auto keySize = keys.size();
    auto dataSize = outputData.size();
    std::string strRes;
    std::vector<float> scores;
    int lastIndex = 0;
    int maxIndex;
    float maxValue;

    for (int i = 0; i < h; i++) {
        int start = i * w;
        int stop = (i + 1) * w;
        // Clamp the (exclusive) row end so we never index past the buffer.
        // Cast avoids a signed/unsigned comparison and the dataSize == 0 wrap.
        if (stop > static_cast<int>(dataSize) - 1) {
            stop = (i + 1) * w - 1;
        }
        maxIndex = int(argmax(&outputData[start], &outputData[stop]));
        maxValue = float(*std::max_element(&outputData[start], &outputData[stop]));
        // Skip blanks (index 0) and consecutive duplicates (CTC collapse rule).
        if (maxIndex > 0 && maxIndex < keySize && (!(i > 0 && maxIndex == lastIndex))) {
            scores.emplace_back(maxValue);
            strRes.append(keys[maxIndex]);
        }
        lastIndex = maxIndex;
    }
    return {strRes, scores};
}

// Runs recognition on a single cropped text-line image.
// The crop is resized to the model's fixed input height (dstHeight) with the
// width scaled proportionally, normalized, and fed through the session.
TextLine CrnnNet::getTextLine(const cv::Mat &src) {
    float scale = (float) dstHeight / (float) src.rows;
    int dstWidth = int((float) src.cols * scale);

    cv::Mat srcResize;
    resize(src, srcResize, cv::Size(dstWidth, dstHeight));

    std::vector<float> inputTensorValues = substractMeanNormalize(srcResize, meanValues, normValues);

    // NCHW input layout
    std::array<int64_t, 4> inputShape{1, srcResize.channels(), srcResize.rows, srcResize.cols};

    auto memoryInfo = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);

    Ort::Value inputTensor = Ort::Value::CreateTensor<float>(memoryInfo, inputTensorValues.data(),
                                                             inputTensorValues.size(), inputShape.data(),
                                                             inputShape.size());
    assert(inputTensor.IsTensor());

    auto outputTensor = session->Run(Ort::RunOptions{nullptr}, inputNames.data(), &inputTensor,
                                     inputNames.size(), outputNames.data(), outputNames.size());
    assert(outputTensor.size() == 1 && outputTensor.front().IsTensor());

    std::vector<int64_t> outputShape = outputTensor[0].GetTensorTypeAndShapeInfo().GetShape();

    // int64_t{1} seed: seeding with plain int would deduce an int accumulator
    // and truncate the 64-bit element-count product.
    int64_t outputCount = std::accumulate(outputShape.begin(), outputShape.end(), int64_t{1},
                                          std::multiplies<int64_t>());

    float *floatArray = outputTensor.front().GetTensorMutableData<float>();
    std::vector<float> outputData(floatArray, floatArray + outputCount);
    // outputShape[1] = time steps (h), outputShape[2] = character classes (w)
    return scoreToTextLine(outputData, outputShape[1], outputShape[2]);
}

// Recognizes every cropped line image in partImg, timing each inference.
// When isOutputDebugImg is set, each crop is also saved next to the source
// image (path/imgName) for debugging.
std::vector<TextLine> CrnnNet::getTextLines(std::vector<cv::Mat> &partImg, const char *path, const char *imgName) {
    int size = partImg.size();
    std::vector<TextLine> textLines(size);
    for (int i = 0; i < size; ++i) {
        //OutPut DebugImg
        if (isOutputDebugImg) {
            std::string debugImgFile = getDebugImgFilePath(path, imgName, i, "-debug-");
            saveImg(partImg[i], debugImgFile.c_str());
        }

        //getTextLine
        double startCrnnTime = getCurrentTime();
        TextLine textLine = getTextLine(partImg[i]);
        double endCrnnTime = getCurrentTime();
        textLine.time = endCrnnTime - startCrnnTime;
        textLines[i] = textLine;
    }
    return textLines;
}