#include "CrnnNet.h"
#include "OcrUtils.h"
#include <fstream>
#include <numeric>

/// Currently a no-op: the argument is ignored. initModel() always registers
/// the ROCm execution provider with device_id 0.
void CrnnNet::setGpuIndex(int gpuIndex) {
}

/// Destroys the inference session created by initModel() and drops the
/// cached input/output name holders.
CrnnNet::~CrnnNet() {
    delete session;
    inputNamesPtr.clear();
    outputNamesPtr.clear();
}

/// Stores the requested thread count and applies it to the session options
/// as the inter-op thread count. Only the inter-op setting is touched here.
void CrnnNet::setNumThread(int numOfThread) {
    numThread = numOfThread;
    sessionOptions.SetInterOpNumThreads(numThread);
}

/// Creates the inference session for the model at `pathStr` and loads the
/// character table from `keysPath` (one entry per line).
///
/// After loading, "#" is inserted at index 0 and " " is appended at the end
/// of the table. If the keys file cannot be opened a message is printed and
/// the function returns with an empty table; the session has already been
/// created at that point.
void CrnnNet::initModel(const std::string &pathStr, const std::string &keysPath) {
    // Configure the DCU (ROCm) execution provider; the device index is fixed at 0.
    OrtROCMProviderOptions rocm_options;
    rocm_options.device_id = 0;
    sessionOptions.AppendExecutionProvider_ROCM(rocm_options);
    sessionOptions.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_BASIC);

    session = new Ort::Session(env, pathStr.c_str(), sessionOptions);
    inputNamesPtr = getInputNames(session);
    outputNamesPtr = getOutputNames(session);

    // Load keys.
    std::ifstream in(keysPath.c_str());
    std::string line;
    if (in) {
        // std::getline strips the trailing newline from each line.
        while (std::getline(in, line)) {
            keys.push_back(line);
        }
    } else {
        printf("The keys.txt file was not found\n");
        return;
    }
    keys.insert(keys.begin(), "#");
    keys.emplace_back(" ");
    // %zu is the correct conversion for size_t (%lu mismatches on LLP64 targets).
    printf("total keys size(%zu)\n", keys.size());
}

/// Returns the offset from `first` of the largest element in [first, last).
/// For an empty range the result is 0.
template<class ForwardIterator>
inline static size_t argmax(ForwardIterator first, ForwardIterator last) {
    return std::distance(first, std::max_element(first, last));
}

/// Decodes a row-major h x w score matrix into text.
///
/// For every row the index of the highest score is taken. The matching entry
/// of `keys` is appended when the index is non-zero, lies inside the key
/// table, and differs from the previous row's index (consecutive repeats are
/// collapsed). Index 0 is never emitted; it is the slot initModel() fills
/// with "#" (presumably the CTC blank -- confirm against the model).
///
/// @param outputData flattened scores, expected to hold at least h * w values
/// @param h          number of rows (time steps)
/// @param w          number of scores per row (classes)
/// @return TextLine built from the decoded string and the winning score of
///         every emitted character
TextLine CrnnNet::scoreToTextLine(const std::vector<float> &outputData, size_t h, size_t w) {
    const size_t keySize = keys.size();
    const size_t dataSize = outputData.size();
    std::string strRes;
    std::vector<float> scores;
    size_t lastIndex = 0;

    for (size_t i = 0; i < h; i++) {
        const size_t start = i * w;
        // Nothing left to decode: empty rows, or the buffer is shorter than h * w.
        if (w == 0 || start >= dataSize) {
            break;
        }
        // Half-open row range [start, stop). The end is clamped to the buffer
        // size so the final class of the final row is still part of the search
        // (the previous "- 1" adjustment silently dropped it).
        const size_t stop = (start + w < dataSize) ? (start + w) : dataSize;
        const float *rowBegin = outputData.data() + start;
        const float *rowEnd = outputData.data() + stop;

        // One pass over the row yields both the winning index and its score.
        const size_t maxIndex = argmax(rowBegin, rowEnd);
        const float maxValue = rowBegin[maxIndex];

        if (maxIndex > 0 && maxIndex < keySize && !(i > 0 && maxIndex == lastIndex)) {
            scores.emplace_back(maxValue);
            strRes.append(keys[maxIndex]);
        }
        lastIndex = maxIndex;
    }
    return {strRes, scores};
}

/// Recognises the text in a single cropped image.
///
/// The image is resized to height `dstHeight` keeping its aspect ratio,
/// normalised with `meanValues` / `normValues`, run through the session, and
/// the resulting score tensor is decoded by scoreToTextLine().
///
/// @param src cropped text image
/// @return decoded TextLine; an empty TextLine when `src` is empty or the
///         model output has fewer than three dimensions
TextLine CrnnNet::getTextLine(const cv::Mat &src) {
    // An empty image would otherwise divide by zero when computing the scale.
    if (src.empty()) {
        return {std::string(), std::vector<float>()};
    }

    const float scale = static_cast<float>(dstHeight) / static_cast<float>(src.rows);
    int dstWidth = static_cast<int>(static_cast<float>(src.cols) * scale);
    // Very narrow crops can round down to zero, which cv::resize rejects.
    if (dstWidth < 1) {
        dstWidth = 1;
    }

    cv::Mat srcResize;
    cv::resize(src, srcResize, cv::Size(dstWidth, dstHeight));
    std::vector<float> inputTensorValues = substractMeanNormalize(srcResize, meanValues, normValues);

    // Input tensor shape: {1, channels, rows, cols}.
    std::array<int64_t, 4> inputShape{1, srcResize.channels(), srcResize.rows, srcResize.cols};
    auto memoryInfo = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
    Ort::Value inputTensor = Ort::Value::CreateTensor<float>(memoryInfo,
                                                             inputTensorValues.data(),
                                                             inputTensorValues.size(),
                                                             inputShape.data(),
                                                             inputShape.size());
    assert(inputTensor.IsTensor());

    // Only the first input and the first output of the model are used.
    std::vector<const char *> inputNames = {inputNamesPtr.data()->get()};
    std::vector<const char *> outputNames = {outputNamesPtr.data()->get()};
    auto outputTensor = session->Run(Ort::RunOptions{nullptr},
                                     inputNames.data(), &inputTensor, inputNames.size(),
                                     outputNames.data(), outputNames.size());
    assert(outputTensor.size() == 1 && outputTensor.front().IsTensor());

    std::vector<int64_t> outputShape = outputTensor[0].GetTensorTypeAndShapeInfo().GetShape();
    // Dimensions 1 and 2 are read below; bail out rather than index past the end.
    if (outputShape.size() < 3) {
        return {std::string(), std::vector<float>()};
    }

    // The initial value is int64_t so the product is not accumulated in a plain int.
    const int64_t outputCount = std::accumulate(outputShape.begin(), outputShape.end(),
                                                int64_t{1}, std::multiplies<int64_t>());
    float *floatArray = outputTensor.front().GetTensorMutableData<float>();
    std::vector<float> outputData(floatArray, floatArray + outputCount);
    return scoreToTextLine(outputData, outputShape[1], outputShape[2]);
}

/// Runs getTextLine() on every image in `partImg`, in order.
///
/// When `isOutputDebugImg` is set, each image is first saved to the path
/// produced by getDebugImgFilePath(path, imgName, i, "-debug-"). The elapsed
/// time of each recognition (difference of two getCurrentTime() readings) is
/// stored in the result's `time` field.
///
/// @return one TextLine per input image, at the same index
std::vector<TextLine> CrnnNet::getTextLines(std::vector<cv::Mat> &partImg, const char *path, const char *imgName) {
    const int size = static_cast<int>(partImg.size());
    std::vector<TextLine> textLines(size);
    for (int i = 0; i < size; ++i) {
        // Output debug image.
        if (isOutputDebugImg) {
            std::string debugImgFile = getDebugImgFilePath(path, imgName, i, "-debug-");
            saveImg(partImg[i], debugImgFile.c_str());
        }

        // Recognise and time a single crop.
        const double startCrnnTime = getCurrentTime();
        TextLine textLine = getTextLine(partImg[i]);
        const double endCrnnTime = getCurrentTime();
        textLine.time = endCrnnTime - startCrnnTime;
        textLines[i] = textLine;
    }
    return textLines;
}