#include "DbNet.h"
#include "OcrUtils.h"

#include <algorithm>
#include <array>
#include <cassert>
#include <cstdint>
#include <functional>
#include <numeric>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

// Intended to select the GPU used for inference.
// NOTE(review): currently a no-op -- initModel() always uses ROCm device 0.
void DbNet::setGpuIndex(int gpuIndex) {
    (void) gpuIndex;
}

// Releases the ONNX Runtime session created in initModel() and the cached
// input/output name strings.
// NOTE(review): `session` is a raw owning pointer; this delete is only safe if
// the member is initialized to nullptr when initModel() was never called --
// confirm in DbNet.h.
DbNet::~DbNet() {
    delete session;
    inputNamesPtr.clear();
    outputNamesPtr.clear();
}

// Sets the inter-op thread count (parallelism across graph nodes) on the
// session options. The options are consumed when the session is created, so
// this must be called before initModel() to have any effect.
void DbNet::setNumThread(int numOfThread) {
    numThread = numOfThread;
    sessionOptions.SetInterOpNumThreads(numThread);
}

// Loads the detection model at `pathStr` and runs it on a DCU through the ROCm
// execution provider, caching the model's input and output names.
// NOTE(review): calling this twice appends the ROCm provider again and leaks
// the previously created session.
void DbNet::initModel(const std::string &pathStr) {
    // Configure the DCU (ROCm execution provider); device 0 is hard-coded.
    OrtROCMProviderOptions rocm_options;
    rocm_options.device_id = 0;
    sessionOptions.AppendExecutionProvider_ROCM(rocm_options);
    sessionOptions.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_BASIC);
    session = new Ort::Session(env, pathStr.c_str(), sessionOptions);
    inputNamesPtr = getInputNames(session);
    outputNamesPtr = getOutputNames(session);
}

// Turns the binarized + dilated probability mask into scored text boxes
// expressed in source-image coordinates.
//
//   predMat        probability map used for scoring (CV_32F as built by
//                  DbNet::getTextBoxes()).
//   dilateMat      binary mask whose contours become box candidates.
//   s              scale parameters mapping network-input coords back to the
//                  source image (ratioWidth/ratioHeight, srcWidth/srcHeight).
//   boxScoreThresh candidates scoring below this (via boxScoreFast) are dropped.
//   unClipRatio    expansion ratio passed to unClip().
//
// At most the first 1000 contours are examined. The surviving boxes are
// returned in reverse contour order.
std::vector<TextBox> findRsBoxes(const cv::Mat &predMat, const cv::Mat &dilateMat, ScaleParam &s,
                                 const float boxScoreThresh, const float unClipRatio) {
    const int longSideThresh = 3;          // long-side threshold for the min-area box
    const size_t maxCandidates = 1000;     // cap on the number of contours examined

    std::vector<std::vector<cv::Point>> contours;
    std::vector<cv::Vec4i> hierarchy;
    cv::findContours(dilateMat, contours, hierarchy, cv::RETR_LIST, cv::CHAIN_APPROX_SIMPLE);

    const size_t numContours = (std::min)(contours.size(), maxCandidates);

    std::vector<TextBox> rsBoxes;
    rsBoxes.reserve(numContours);
    for (size_t i = 0; i < numContours; ++i) {
        const std::vector<cv::Point> &contour = contours[i];
        // Contours with two or fewer points are skipped.
        if (contour.size() <= 2) {
            continue;
        }
        const cv::RotatedRect minAreaRect = cv::minAreaRect(contour);

        float longSide = 0.0f;
        const auto minBoxes = getMinBoxes(minAreaRect, longSide);
        if (longSide < longSideThresh) {
            continue;
        }

        const float boxScore = boxScoreFast(minBoxes, predMat);
        if (boxScore < boxScoreThresh) {
            continue;
        }

        //-----unClip-----
        const cv::RotatedRect clipRect = unClip(minBoxes, unClipRatio);
        if (clipRect.size.height < 1.001 && clipRect.size.width < 1.001) {
            continue;
        }
        //-----unClip-----

        const auto clipMinBoxes = getMinBoxes(clipRect, longSide);
        if (longSide < longSideThresh + 2) {
            continue;
        }

        // Map each corner back to source-image coordinates, clamped to the image.
        std::vector<cv::Point> intClipMinBoxes;
        intClipMinBoxes.reserve(clipMinBoxes.size());
        for (const auto &clipMinBox : clipMinBoxes) {
            const float x = clipMinBox.x / s.ratioWidth;
            const float y = clipMinBox.y / s.ratioHeight;
            const int ptX = (std::min)((std::max)(static_cast<int>(x), 0), s.srcWidth - 1);
            const int ptY = (std::min)((std::max)(static_cast<int>(y), 0), s.srcHeight - 1);
            intClipMinBoxes.emplace_back(ptX, ptY);
        }
        rsBoxes.push_back(TextBox{std::move(intClipMinBoxes), boxScore});
    }
    std::reverse(rsBoxes.begin(), rsBoxes.end());
    return rsBoxes;
}

// Runs the detection network on `src` and returns the detected text boxes in
// source-image coordinates.
//
//   src            input image; resized to (s.dstWidth, s.dstHeight) first.
//   s              scale parameters for resizing and mapping boxes back.
//   boxScoreThresh minimum box score kept by findRsBoxes().
//   boxThresh      per-pixel probability threshold used to binarize the map.
//   unClipRatio    box expansion ratio passed to findRsBoxes().
//
// Throws std::runtime_error if the model output is not at least 4-D or holds
// fewer values than its last two dimensions imply.
std::vector<TextBox>
DbNet::getTextBoxes(cv::Mat &src, ScaleParam &s, float boxScoreThresh, float boxThresh, float unClipRatio) {
    cv::Mat srcResize;
    cv::resize(src, srcResize, cv::Size(s.dstWidth, s.dstHeight));
    std::vector<float> inputTensorValues = substractMeanNormalize(srcResize, meanValues, normValues);
    const std::array<int64_t, 4> inputShape{1, srcResize.channels(), srcResize.rows, srcResize.cols};

    auto memoryInfo = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
    Ort::Value inputTensor = Ort::Value::CreateTensor<float>(memoryInfo, inputTensorValues.data(),
                                                             inputTensorValues.size(), inputShape.data(),
                                                             inputShape.size());
    assert(inputTensor.IsTensor());

    // Only the model's first input and first output are used.
    std::vector<const char *> inputNames = {inputNamesPtr.front().get()};
    std::vector<const char *> outputNames = {outputNamesPtr.front().get()};
    auto outputTensor = session->Run(Ort::RunOptions{nullptr}, inputNames.data(), &inputTensor,
                                     inputNames.size(), outputNames.data(), outputNames.size());
    assert(outputTensor.size() == 1 && outputTensor.front().IsTensor());

    const std::vector<int64_t> outputShape = outputTensor.front().GetTensorTypeAndShapeInfo().GetShape();
    // Dimensions 2 and 3 are read as height and width below.
    if (outputShape.size() < 4) {
        throw std::runtime_error("DbNet::getTextBoxes: expected a 4-D output tensor");
    }
    // Seed with int64_t: a plain `1` would make std::accumulate compute in int.
    const int64_t outputCount = std::accumulate(outputShape.begin(), outputShape.end(), int64_t{1},
                                                std::multiplies<int64_t>());

    //-----Data preparation-----
    const int outHeight = static_cast<int>(outputShape[2]);
    const int outWidth = static_cast<int>(outputShape[3]);
    const size_t area = static_cast<size_t>(outHeight) * static_cast<size_t>(outWidth);
    if (outputCount < static_cast<int64_t>(area)) {
        throw std::runtime_error("DbNet::getTextBoxes: output tensor smaller than its H*W plane");
    }

    // Copy the first H*W values of the output (the probability map).
    const float *floatArray = outputTensor.front().GetTensorMutableData<float>();
    std::vector<float> predData(floatArray, floatArray + area);

    // 8-bit version of the map for thresholding. Clamp before the cast:
    // converting an out-of-range (or NaN) float to unsigned char is UB.
    std::vector<unsigned char> cbufData(area);
    for (size_t i = 0; i < area; ++i) {
        float scaled = predData[i] * 255.0f;
        if (!(scaled > 0.0f)) {
            scaled = 0.0f;
        } else if (scaled > 255.0f) {
            scaled = 255.0f;
        }
        cbufData[i] = static_cast<unsigned char>(scaled);
    }

    // These Mats wrap the vectors' storage without copying; both vectors
    // outlive the findRsBoxes() call below.
    cv::Mat predMat(outHeight, outWidth, CV_32F, predData.data());
    cv::Mat cBufMat(outHeight, outWidth, CV_8UC1, cbufData.data());

    //-----boxThresh-----
    const double maxValue = 255;
    const double threshold = boxThresh * 255;
    cv::Mat thresholdMat;
    cv::threshold(cBufMat, thresholdMat, threshold, maxValue, cv::THRESH_BINARY);

    //-----dilate----- (2x2 rectangular structuring element)
    cv::Mat dilateMat;
    const cv::Mat dilateElement = cv::getStructuringElement(cv::MORPH_RECT, cv::Size(2, 2));
    cv::dilate(thresholdMat, dilateMat, dilateElement);

    return findRsBoxes(predMat, dilateMat, s, boxScoreThresh, unClipRatio);
}