Commit 83303bc7 authored by LDOUBLEV's avatar LDOUBLEV
Browse files

fix conflicts

parents 3af943f3 af0bac58
//
// Created by fujiayi on 2020/7/1.
//
#pragma once
#include "ppredictor.h"
#include <opencv2/opencv.hpp>
#include <paddle_api.h>
#include <string>
namespace ppredictor {
/**
* Config
*/
struct OCR_Config {
int thread_num = 4; // Thread num
paddle::lite_api::PowerMode mode =
paddle::lite_api::LITE_POWER_HIGH; // PaddleLite Mode
};
/**
* PolyGone Result
*/
struct OCRPredictResult {
std::vector<int> word_index;
std::vector<std::vector<int>> points;
float score;
};
/**
* OCR there are 2 models
* 1. First model(det),select polygones to show where are the texts
* 2. crop from the origin images, use these polygones to infer
*/
class OCR_PPredictor : public PPredictor_Interface {
public:
OCR_PPredictor(const OCR_Config &config);
virtual ~OCR_PPredictor() {}
/**
* 初始化二个模型的Predictor
* @param det_model_content
* @param rec_model_content
* @return
*/
int init(const std::string &det_model_content,
const std::string &rec_model_content,
const std::string &cls_model_content);
int init_from_file(const std::string &det_model_path,
const std::string &rec_model_path,
const std::string &cls_model_path);
/**
* Return OCR result
* @param dims
* @param input_data
* @param input_len
* @param net_flag
* @param origin
* @return
*/
virtual std::vector<OCRPredictResult>
infer_ocr(const std::vector<int64_t> &dims, const float *input_data,
int input_len, int net_flag, cv::Mat &origin);
virtual NET_TYPE get_net_flag() const;
private:
/**
* calcul Polygone from the result image of first model
* @param pred
* @param output_height
* @param output_width
* @param origin
* @return
*/
std::vector<std::vector<std::vector<int>>>
calc_filtered_boxes(const float *pred, int pred_size, int output_height,
int output_width, const cv::Mat &origin);
/**
* infer for second model
*
* @param boxes
* @param origin
* @return
*/
std::vector<OCRPredictResult>
infer_rec(const std::vector<std::vector<std::vector<int>>> &boxes,
const cv::Mat &origin);
/**
* infer for cls model
*
* @param boxes
* @param origin
* @return
*/
cv::Mat infer_cls(const cv::Mat &origin, float thresh = 0.9);
/**
* Postprocess or sencod model to extract text
* @param res
* @return
*/
std::vector<int> postprocess_rec_word_index(const PredictorOutput &res);
/**
* calculate confidence of second model text result
* @param res
* @return
*/
float postprocess_rec_score(const PredictorOutput &res);
std::unique_ptr<PPredictor> _det_predictor;
std::unique_ptr<PPredictor> _rec_predictor;
std::unique_ptr<PPredictor> _cls_predictor;
OCR_Config _config;
};
}
#include "ppredictor.h"
#include "common.h"
namespace ppredictor {
PPredictor::PPredictor(int thread_num, int net_flag,
paddle::lite_api::PowerMode mode)
: _thread_num(thread_num), _net_flag(net_flag), _mode(mode) {}
int PPredictor::init_nb(const std::string &model_content) {
paddle::lite_api::MobileConfig config;
config.set_model_from_buffer(model_content);
return _init(config);
}
int PPredictor::init_from_file(const std::string &model_content) {
paddle::lite_api::MobileConfig config;
config.set_model_from_file(model_content);
return _init(config);
}
template <typename ConfigT> int PPredictor::_init(ConfigT &config) {
config.set_threads(_thread_num);
config.set_power_mode(_mode);
_predictor = paddle::lite_api::CreatePaddlePredictor(config);
LOGI("paddle instance created");
return RETURN_OK;
}
PredictorInput PPredictor::get_input(int index) {
PredictorInput input{_predictor->GetInput(index), index, _net_flag};
_is_input_get = true;
return input;
}
std::vector<PredictorInput> PPredictor::get_inputs(int num) {
std::vector<PredictorInput> results;
for (int i = 0; i < num; i++) {
results.emplace_back(get_input(i));
}
return results;
}
PredictorInput PPredictor::get_first_input() { return get_input(0); }
std::vector<PredictorOutput> PPredictor::infer() {
LOGI("infer Run start %d", _net_flag);
std::vector<PredictorOutput> results;
if (!_is_input_get) {
return results;
}
_predictor->Run();
LOGI("infer Run end");
for (int i = 0; i < _predictor->GetOutputNames().size(); i++) {
std::unique_ptr<const paddle::lite_api::Tensor> output_tensor =
_predictor->GetOutput(i);
LOGI("output tensor[%d] size %ld", i, product(output_tensor->shape()));
PredictorOutput result{std::move(output_tensor), i, _net_flag};
results.emplace_back(std::move(result));
}
return results;
}
NET_TYPE PPredictor::get_net_flag() const { return (NET_TYPE)_net_flag; }
}
\ No newline at end of file
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment