"src/vscode:/vscode.git/clone" did not exist on "91851114e4fc6405397baf46f121697a3913c85b"
Unverified Commit 3df64d50 authored by zhoujun's avatar zhoujun Committed by GitHub
Browse files

Merge pull request #5576 from WenmuZhou/android

[android demo] Fix error when run same image
parents 816be4d7 056b7606
# 如何快速测试 - [Android Demo](#android-demo)
### 1. 安装最新版本的Android Studio - [1. 简介](#1-简介)
- [2. 近期更新](#2-近期更新)
- [3. 快速使用](#3-快速使用)
- [3.1 安装最新版本的Android Studio](#31-安装最新版本的android-studio)
- [3.2 安装 NDK 20 以上版本](#32-安装-ndk-20-以上版本)
- [3.3 导入项目](#33-导入项目)
- [4 更多支持](#4-更多支持)
# Android Demo
## 1. 简介
此为PaddleOCR的Android Demo,目前支持文本检测,文本方向分类器和文本识别模型的使用。使用 [PaddleLite v2.10](https://github.com/PaddlePaddle/Paddle-Lite/tree/release/v2.10) 进行开发。
## 2. 近期更新
* 2022.02.27
* 预测库更新到PaddleLite v2.10
* 支持6种运行模式:
* 检测+分类+识别
* 检测+识别
* 分类+识别
* 检测
* 识别
* 分类
## 3. 快速使用
### 3.1 安装最新版本的Android Studio
可以从 https://developer.android.com/studio 下载。本Demo使用是4.0版本Android Studio编写。 可以从 https://developer.android.com/studio 下载。本Demo使用是4.0版本Android Studio编写。
### 2. 按照NDK 20 以上版本 ### 3.2 安装 NDK 20 以上版本
Demo测试的时候使用的是NDK 20b版本,20版本以上均可以支持编译成功。 Demo测试的时候使用的是NDK 20b版本,20版本以上均可以支持编译成功。
如果您是初学者,可以用以下方式安装和测试NDK编译环境。 如果您是初学者,可以用以下方式安装和测试NDK编译环境。
点击 File -> New ->New Project, 新建 "Native C++" project 点击 File -> New ->New Project, 新建 "Native C++" project
### 3. 导入项目 ### 3.3 导入项目
点击 File->New->Import Project..., 然后跟着Android Studio的引导导入 点击 File->New->Import Project..., 然后跟着Android Studio的引导导入
## 4 更多支持
# 获得更多支持 前往[Paddle-Lite](https://github.com/PaddlePaddle/Paddle-Lite),获得更多开发支持
前往[端计算模型生成平台EasyEdge](https://ai.baidu.com/easyedge/app/open_source_demo?referrerUrl=paddlelite),获得更多开发支持:
- Demo APP:可使用手机扫码安装,方便手机端快速体验文字识别
- SDK:模型被封装为适配不同芯片硬件和操作系统SDK,包括完善的接口,方便进行二次开发
...@@ -8,8 +8,8 @@ android { ...@@ -8,8 +8,8 @@ android {
applicationId "com.baidu.paddle.lite.demo.ocr" applicationId "com.baidu.paddle.lite.demo.ocr"
minSdkVersion 23 minSdkVersion 23
targetSdkVersion 29 targetSdkVersion 29
versionCode 1 versionCode 2
versionName "1.0" versionName "2.0"
testInstrumentationRunner "android.support.test.runner.AndroidJUnitRunner" testInstrumentationRunner "android.support.test.runner.AndroidJUnitRunner"
externalNativeBuild { externalNativeBuild {
cmake { cmake {
...@@ -17,11 +17,6 @@ android { ...@@ -17,11 +17,6 @@ android {
arguments '-DANDROID_PLATFORM=android-23', '-DANDROID_STL=c++_shared' ,"-DANDROID_ARM_NEON=TRUE" arguments '-DANDROID_PLATFORM=android-23', '-DANDROID_STL=c++_shared' ,"-DANDROID_ARM_NEON=TRUE"
} }
} }
ndk {
// abiFilters "arm64-v8a", "armeabi-v7a"
abiFilters "arm64-v8a", "armeabi-v7a"
ldLibs "jnigraphics"
}
} }
buildTypes { buildTypes {
release { release {
...@@ -48,7 +43,7 @@ dependencies { ...@@ -48,7 +43,7 @@ dependencies {
def archives = [ def archives = [
[ [
'src' : 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/lite/paddle_lite_libs_v2_9_0.tar.gz', 'src' : 'https://paddleocr.bj.bcebos.com/libs/paddle_lite_libs_v2_10.tar.gz',
'dest': 'PaddleLite' 'dest': 'PaddleLite'
], ],
[ [
...@@ -56,7 +51,7 @@ def archives = [ ...@@ -56,7 +51,7 @@ def archives = [
'dest': 'OpenCV' 'dest': 'OpenCV'
], ],
[ [
'src' : 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/lite/ocr_v2_for_cpu.tar.gz', 'src' : 'https://paddleocr.bj.bcebos.com/PP-OCRv2/lite/ch_PP-OCRv2.tar.gz',
'dest' : 'src/main/assets/models' 'dest' : 'src/main/assets/models'
], ],
[ [
......
...@@ -14,7 +14,6 @@ ...@@ -14,7 +14,6 @@
android:roundIcon="@mipmap/ic_launcher_round" android:roundIcon="@mipmap/ic_launcher_round"
android:supportsRtl="true" android:supportsRtl="true"
android:theme="@style/AppTheme"> android:theme="@style/AppTheme">
<!-- to test MiniActivity, change this to com.baidu.paddle.lite.demo.ocr.MiniActivity -->
<activity android:name="com.baidu.paddle.lite.demo.ocr.MainActivity"> <activity android:name="com.baidu.paddle.lite.demo.ocr.MainActivity">
<intent-filter> <intent-filter>
<action android:name="android.intent.action.MAIN"/> <action android:name="android.intent.action.MAIN"/>
......
...@@ -13,7 +13,7 @@ static paddle::lite_api::PowerMode str_to_cpu_mode(const std::string &cpu_mode); ...@@ -13,7 +13,7 @@ static paddle::lite_api::PowerMode str_to_cpu_mode(const std::string &cpu_mode);
extern "C" JNIEXPORT jlong JNICALL extern "C" JNIEXPORT jlong JNICALL
Java_com_baidu_paddle_lite_demo_ocr_OCRPredictorNative_init( Java_com_baidu_paddle_lite_demo_ocr_OCRPredictorNative_init(
JNIEnv *env, jobject thiz, jstring j_det_model_path, JNIEnv *env, jobject thiz, jstring j_det_model_path,
jstring j_rec_model_path, jstring j_cls_model_path, jint j_thread_num, jstring j_rec_model_path, jstring j_cls_model_path, jint j_use_opencl, jint j_thread_num,
jstring j_cpu_mode) { jstring j_cpu_mode) {
std::string det_model_path = jstring_to_cpp_string(env, j_det_model_path); std::string det_model_path = jstring_to_cpp_string(env, j_det_model_path);
std::string rec_model_path = jstring_to_cpp_string(env, j_rec_model_path); std::string rec_model_path = jstring_to_cpp_string(env, j_rec_model_path);
...@@ -21,6 +21,7 @@ Java_com_baidu_paddle_lite_demo_ocr_OCRPredictorNative_init( ...@@ -21,6 +21,7 @@ Java_com_baidu_paddle_lite_demo_ocr_OCRPredictorNative_init(
int thread_num = j_thread_num; int thread_num = j_thread_num;
std::string cpu_mode = jstring_to_cpp_string(env, j_cpu_mode); std::string cpu_mode = jstring_to_cpp_string(env, j_cpu_mode);
ppredictor::OCR_Config conf; ppredictor::OCR_Config conf;
conf.use_opencl = j_use_opencl;
conf.thread_num = thread_num; conf.thread_num = thread_num;
conf.mode = str_to_cpu_mode(cpu_mode); conf.mode = str_to_cpu_mode(cpu_mode);
ppredictor::OCR_PPredictor *orc_predictor = ppredictor::OCR_PPredictor *orc_predictor =
...@@ -57,32 +58,31 @@ str_to_cpu_mode(const std::string &cpu_mode) { ...@@ -57,32 +58,31 @@ str_to_cpu_mode(const std::string &cpu_mode) {
extern "C" JNIEXPORT jfloatArray JNICALL extern "C" JNIEXPORT jfloatArray JNICALL
Java_com_baidu_paddle_lite_demo_ocr_OCRPredictorNative_forward( Java_com_baidu_paddle_lite_demo_ocr_OCRPredictorNative_forward(
JNIEnv *env, jobject thiz, jlong java_pointer, jfloatArray buf, JNIEnv *env, jobject thiz, jlong java_pointer, jobject original_image,jint j_max_size_len, jint j_run_det, jint j_run_cls, jint j_run_rec) {
jfloatArray ddims, jobject original_image) {
LOGI("begin to run native forward"); LOGI("begin to run native forward");
if (java_pointer == 0) { if (java_pointer == 0) {
LOGE("JAVA pointer is NULL"); LOGE("JAVA pointer is NULL");
return cpp_array_to_jfloatarray(env, nullptr, 0); return cpp_array_to_jfloatarray(env, nullptr, 0);
} }
cv::Mat origin = bitmap_to_cv_mat(env, original_image); cv::Mat origin = bitmap_to_cv_mat(env, original_image);
if (origin.size == 0) { if (origin.size == 0) {
LOGE("origin bitmap cannot convert to CV Mat"); LOGE("origin bitmap cannot convert to CV Mat");
return cpp_array_to_jfloatarray(env, nullptr, 0); return cpp_array_to_jfloatarray(env, nullptr, 0);
} }
int max_size_len = j_max_size_len;
int run_det = j_run_det;
int run_cls = j_run_cls;
int run_rec = j_run_rec;
ppredictor::OCR_PPredictor *ppredictor = ppredictor::OCR_PPredictor *ppredictor =
(ppredictor::OCR_PPredictor *)java_pointer; (ppredictor::OCR_PPredictor *)java_pointer;
std::vector<float> dims_float_arr = jfloatarray_to_float_vector(env, ddims);
std::vector<int64_t> dims_arr; std::vector<int64_t> dims_arr;
dims_arr.resize(dims_float_arr.size());
std::copy(dims_float_arr.cbegin(), dims_float_arr.cend(), dims_arr.begin());
// 这里值有点大,就不调用jfloatarray_to_float_vector了
int64_t buf_len = (int64_t)env->GetArrayLength(buf);
jfloat *buf_data = env->GetFloatArrayElements(buf, JNI_FALSE);
float *data = (jfloat *)buf_data;
std::vector<ppredictor::OCRPredictResult> results = std::vector<ppredictor::OCRPredictResult> results =
ppredictor->infer_ocr(dims_arr, data, buf_len, NET_OCR, origin); ppredictor->infer_ocr(origin, max_size_len, run_det, run_cls, run_rec);
LOGI("infer_ocr finished with boxes %ld", results.size()); LOGI("infer_ocr finished with boxes %ld", results.size());
// 这里将std::vector<ppredictor::OCRPredictResult> 序列化成 // 这里将std::vector<ppredictor::OCRPredictResult> 序列化成
// float数组,传输到java层再反序列化 // float数组,传输到java层再反序列化
std::vector<float> float_arr; std::vector<float> float_arr;
...@@ -90,13 +90,18 @@ Java_com_baidu_paddle_lite_demo_ocr_OCRPredictorNative_forward( ...@@ -90,13 +90,18 @@ Java_com_baidu_paddle_lite_demo_ocr_OCRPredictorNative_forward(
float_arr.push_back(r.points.size()); float_arr.push_back(r.points.size());
float_arr.push_back(r.word_index.size()); float_arr.push_back(r.word_index.size());
float_arr.push_back(r.score); float_arr.push_back(r.score);
// add det point
for (const std::vector<int> &point : r.points) { for (const std::vector<int> &point : r.points) {
float_arr.push_back(point.at(0)); float_arr.push_back(point.at(0));
float_arr.push_back(point.at(1)); float_arr.push_back(point.at(1));
} }
// add rec word idx
for (int index : r.word_index) { for (int index : r.word_index) {
float_arr.push_back(index); float_arr.push_back(index);
} }
// add cls result
float_arr.push_back(r.cls_label);
float_arr.push_back(r.cls_score);
} }
return cpp_array_to_jfloatarray(env, float_arr.data(), float_arr.size()); return cpp_array_to_jfloatarray(env, float_arr.data(), float_arr.size());
} }
......
...@@ -17,15 +17,15 @@ int OCR_PPredictor::init(const std::string &det_model_content, ...@@ -17,15 +17,15 @@ int OCR_PPredictor::init(const std::string &det_model_content,
const std::string &rec_model_content, const std::string &rec_model_content,
const std::string &cls_model_content) { const std::string &cls_model_content) {
_det_predictor = std::unique_ptr<PPredictor>( _det_predictor = std::unique_ptr<PPredictor>(
new PPredictor{_config.thread_num, NET_OCR, _config.mode}); new PPredictor{_config.use_opencl,_config.thread_num, NET_OCR, _config.mode});
_det_predictor->init_nb(det_model_content); _det_predictor->init_nb(det_model_content);
_rec_predictor = std::unique_ptr<PPredictor>( _rec_predictor = std::unique_ptr<PPredictor>(
new PPredictor{_config.thread_num, NET_OCR_INTERNAL, _config.mode}); new PPredictor{_config.use_opencl,_config.thread_num, NET_OCR_INTERNAL, _config.mode});
_rec_predictor->init_nb(rec_model_content); _rec_predictor->init_nb(rec_model_content);
_cls_predictor = std::unique_ptr<PPredictor>( _cls_predictor = std::unique_ptr<PPredictor>(
new PPredictor{_config.thread_num, NET_OCR_INTERNAL, _config.mode}); new PPredictor{_config.use_opencl,_config.thread_num, NET_OCR_INTERNAL, _config.mode});
_cls_predictor->init_nb(cls_model_content); _cls_predictor->init_nb(cls_model_content);
return RETURN_OK; return RETURN_OK;
} }
...@@ -34,15 +34,16 @@ int OCR_PPredictor::init_from_file(const std::string &det_model_path, ...@@ -34,15 +34,16 @@ int OCR_PPredictor::init_from_file(const std::string &det_model_path,
const std::string &rec_model_path, const std::string &rec_model_path,
const std::string &cls_model_path) { const std::string &cls_model_path) {
_det_predictor = std::unique_ptr<PPredictor>( _det_predictor = std::unique_ptr<PPredictor>(
new PPredictor{_config.thread_num, NET_OCR, _config.mode}); new PPredictor{_config.use_opencl, _config.thread_num, NET_OCR, _config.mode});
_det_predictor->init_from_file(det_model_path); _det_predictor->init_from_file(det_model_path);
_rec_predictor = std::unique_ptr<PPredictor>( _rec_predictor = std::unique_ptr<PPredictor>(
new PPredictor{_config.thread_num, NET_OCR_INTERNAL, _config.mode}); new PPredictor{_config.use_opencl,_config.thread_num, NET_OCR_INTERNAL, _config.mode});
_rec_predictor->init_from_file(rec_model_path); _rec_predictor->init_from_file(rec_model_path);
_cls_predictor = std::unique_ptr<PPredictor>( _cls_predictor = std::unique_ptr<PPredictor>(
new PPredictor{_config.thread_num, NET_OCR_INTERNAL, _config.mode}); new PPredictor{_config.use_opencl,_config.thread_num, NET_OCR_INTERNAL, _config.mode});
_cls_predictor->init_from_file(cls_model_path); _cls_predictor->init_from_file(cls_model_path);
return RETURN_OK; return RETURN_OK;
} }
...@@ -77,33 +78,126 @@ visual_img(const std::vector<std::vector<std::vector<int>>> &filter_boxes, ...@@ -77,33 +78,126 @@ visual_img(const std::vector<std::vector<std::vector<int>>> &filter_boxes,
} }
std::vector<OCRPredictResult> std::vector<OCRPredictResult>
OCR_PPredictor::infer_ocr(const std::vector<int64_t> &dims, OCR_PPredictor::infer_ocr(cv::Mat &origin,int max_size_len, int run_det, int run_cls, int run_rec) {
const float *input_data, int input_len, int net_flag, LOGI("ocr cpp start *****************");
cv::Mat &origin) { LOGI("ocr cpp det: %d, cls: %d, rec: %d", run_det, run_cls, run_rec);
std::vector<OCRPredictResult> ocr_results;
if(run_det){
infer_det(origin, max_size_len, ocr_results);
}
if(run_rec){
if(ocr_results.size()==0){
OCRPredictResult res;
ocr_results.emplace_back(std::move(res));
}
for(int i = 0; i < ocr_results.size();i++) {
infer_rec(origin, run_cls, ocr_results[i]);
}
}else if(run_cls){
ClsPredictResult cls_res = infer_cls(origin);
OCRPredictResult res;
res.cls_score = cls_res.cls_score;
res.cls_label = cls_res.cls_label;
ocr_results.push_back(res);
}
LOGI("ocr cpp end *****************");
return ocr_results;
}
cv::Mat DetResizeImg(const cv::Mat img, int max_size_len,
std::vector<float> &ratio_hw) {
int w = img.cols;
int h = img.rows;
float ratio = 1.f;
int max_wh = w >= h ? w : h;
if (max_wh > max_size_len) {
if (h > w) {
ratio = static_cast<float>(max_size_len) / static_cast<float>(h);
} else {
ratio = static_cast<float>(max_size_len) / static_cast<float>(w);
}
}
int resize_h = static_cast<int>(float(h) * ratio);
int resize_w = static_cast<int>(float(w) * ratio);
if (resize_h % 32 == 0)
resize_h = resize_h;
else if (resize_h / 32 < 1 + 1e-5)
resize_h = 32;
else
resize_h = (resize_h / 32 - 1) * 32;
if (resize_w % 32 == 0)
resize_w = resize_w;
else if (resize_w / 32 < 1 + 1e-5)
resize_w = 32;
else
resize_w = (resize_w / 32 - 1) * 32;
cv::Mat resize_img;
cv::resize(img, resize_img, cv::Size(resize_w, resize_h));
ratio_hw.push_back(static_cast<float>(resize_h) / static_cast<float>(h));
ratio_hw.push_back(static_cast<float>(resize_w) / static_cast<float>(w));
return resize_img;
}
void OCR_PPredictor::infer_det(cv::Mat &origin, int max_size_len, std::vector<OCRPredictResult> &ocr_results) {
std::vector<float> mean = {0.485f, 0.456f, 0.406f};
std::vector<float> scale = {1 / 0.229f, 1 / 0.224f, 1 / 0.225f};
PredictorInput input = _det_predictor->get_first_input(); PredictorInput input = _det_predictor->get_first_input();
input.set_dims(dims);
input.set_data(input_data, input_len); std::vector<float> ratio_hw;
cv::Mat input_image = DetResizeImg(origin, max_size_len, ratio_hw);
input_image.convertTo(input_image, CV_32FC3, 1 / 255.0f);
const float *dimg = reinterpret_cast<const float *>(input_image.data);
int input_size = input_image.rows * input_image.cols;
input.set_dims({1, 3, input_image.rows, input_image.cols});
neon_mean_scale(dimg, input.get_mutable_float_data(), input_size, mean,
scale);
LOGI("ocr cpp det shape %d,%d", input_image.rows,input_image.cols);
std::vector<PredictorOutput> results = _det_predictor->infer(); std::vector<PredictorOutput> results = _det_predictor->infer();
PredictorOutput &res = results.at(0); PredictorOutput &res = results.at(0);
std::vector<std::vector<std::vector<int>>> filtered_box = calc_filtered_boxes( std::vector<std::vector<std::vector<int>>> filtered_box = calc_filtered_boxes(
res.get_float_data(), res.get_size(), (int)dims[2], (int)dims[3], origin); res.get_float_data(), res.get_size(), input_image.rows, input_image.cols, origin);
LOGI("Filter_box size %ld", filtered_box.size()); LOGI("ocr cpp det Filter_box size %ld", filtered_box.size());
return infer_rec(filtered_box, origin);
for(int i = 0;i<filtered_box.size();i++){
LOGI("ocr cpp box %d,%d,%d,%d,%d,%d,%d,%d", filtered_box[i][0][0],filtered_box[i][0][1], filtered_box[i][1][0],filtered_box[i][1][1], filtered_box[i][2][0],filtered_box[i][2][1], filtered_box[i][3][0],filtered_box[i][3][1]);
OCRPredictResult res;
res.points = filtered_box[i];
ocr_results.push_back(res);
}
} }
std::vector<OCRPredictResult> OCR_PPredictor::infer_rec( void OCR_PPredictor::infer_rec(const cv::Mat &origin_img, int run_cls, OCRPredictResult& ocr_result) {
const std::vector<std::vector<std::vector<int>>> &boxes,
const cv::Mat &origin_img) {
std::vector<float> mean = {0.5f, 0.5f, 0.5f}; std::vector<float> mean = {0.5f, 0.5f, 0.5f};
std::vector<float> scale = {1 / 0.5f, 1 / 0.5f, 1 / 0.5f}; std::vector<float> scale = {1 / 0.5f, 1 / 0.5f, 1 / 0.5f};
std::vector<int64_t> dims = {1, 3, 0, 0}; std::vector<int64_t> dims = {1, 3, 0, 0};
std::vector<OCRPredictResult> ocr_results;
PredictorInput input = _rec_predictor->get_first_input(); PredictorInput input = _rec_predictor->get_first_input();
for (auto bp = boxes.crbegin(); bp != boxes.crend(); ++bp) {
const std::vector<std::vector<int>> &box = *bp; const std::vector<std::vector<int>> &box = ocr_result.points;
cv::Mat crop_img = get_rotate_crop_image(origin_img, box); cv::Mat crop_img;
crop_img = infer_cls(crop_img); if(box.size()>0){
crop_img = get_rotate_crop_image(origin_img, box);
}
else{
crop_img = origin_img;
}
if(run_cls){
ClsPredictResult cls_res = infer_cls(crop_img);
crop_img = cls_res.img;
ocr_result.cls_score = cls_res.cls_score;
ocr_result.cls_label = cls_res.cls_label;
}
float wh_ratio = float(crop_img.cols) / float(crop_img.rows); float wh_ratio = float(crop_img.cols) / float(crop_img.rows);
cv::Mat input_image = crnn_resize_img(crop_img, wh_ratio); cv::Mat input_image = crnn_resize_img(crop_img, wh_ratio);
...@@ -122,8 +216,6 @@ std::vector<OCRPredictResult> OCR_PPredictor::infer_rec( ...@@ -122,8 +216,6 @@ std::vector<OCRPredictResult> OCR_PPredictor::infer_rec(
const float *predict_batch = results.at(0).get_float_data(); const float *predict_batch = results.at(0).get_float_data();
const std::vector<int64_t> predict_shape = results.at(0).get_shape(); const std::vector<int64_t> predict_shape = results.at(0).get_shape();
OCRPredictResult res;
// ctc decode // ctc decode
int argmax_idx; int argmax_idx;
int last_index = 0; int last_index = 0;
...@@ -140,27 +232,19 @@ std::vector<OCRPredictResult> OCR_PPredictor::infer_rec( ...@@ -140,27 +232,19 @@ std::vector<OCRPredictResult> OCR_PPredictor::infer_rec(
if (argmax_idx > 0 && (!(n > 0 && argmax_idx == last_index))) { if (argmax_idx > 0 && (!(n > 0 && argmax_idx == last_index))) {
score += max_value; score += max_value;
count += 1; count += 1;
res.word_index.push_back(argmax_idx); ocr_result.word_index.push_back(argmax_idx);
} }
last_index = argmax_idx; last_index = argmax_idx;
} }
score /= count; score /= count;
if (res.word_index.empty()) { ocr_result.score = score;
continue; LOGI("ocr cpp rec word size %ld", count);
}
res.score = score;
res.points = box;
ocr_results.emplace_back(std::move(res));
}
LOGI("ocr_results finished %lu", ocr_results.size());
return ocr_results;
} }
cv::Mat OCR_PPredictor::infer_cls(const cv::Mat &img, float thresh) { ClsPredictResult OCR_PPredictor::infer_cls(const cv::Mat &img, float thresh) {
std::vector<float> mean = {0.5f, 0.5f, 0.5f}; std::vector<float> mean = {0.5f, 0.5f, 0.5f};
std::vector<float> scale = {1 / 0.5f, 1 / 0.5f, 1 / 0.5f}; std::vector<float> scale = {1 / 0.5f, 1 / 0.5f, 1 / 0.5f};
std::vector<int64_t> dims = {1, 3, 0, 0}; std::vector<int64_t> dims = {1, 3, 0, 0};
std::vector<OCRPredictResult> ocr_results;
PredictorInput input = _cls_predictor->get_first_input(); PredictorInput input = _cls_predictor->get_first_input();
...@@ -182,7 +266,7 @@ cv::Mat OCR_PPredictor::infer_cls(const cv::Mat &img, float thresh) { ...@@ -182,7 +266,7 @@ cv::Mat OCR_PPredictor::infer_cls(const cv::Mat &img, float thresh) {
float score = 0; float score = 0;
int label = 0; int label = 0;
for (int64_t i = 0; i < results.at(0).get_size(); i++) { for (int64_t i = 0; i < results.at(0).get_size(); i++) {
LOGI("output scores [%f]", scores[i]); LOGI("ocr cpp cls output scores [%f]", scores[i]);
if (scores[i] > score) { if (scores[i] > score) {
score = scores[i]; score = scores[i];
label = i; label = i;
...@@ -193,7 +277,12 @@ cv::Mat OCR_PPredictor::infer_cls(const cv::Mat &img, float thresh) { ...@@ -193,7 +277,12 @@ cv::Mat OCR_PPredictor::infer_cls(const cv::Mat &img, float thresh) {
if (label % 2 == 1 && score > thresh) { if (label % 2 == 1 && score > thresh) {
cv::rotate(srcimg, srcimg, 1); cv::rotate(srcimg, srcimg, 1);
} }
return srcimg; ClsPredictResult res;
res.cls_label = label;
res.cls_score = score;
res.img = srcimg;
LOGI("ocr cpp cls word cls %ld, %f", label, score);
return res;
} }
std::vector<std::vector<std::vector<int>>> std::vector<std::vector<std::vector<int>>>
......
...@@ -15,6 +15,7 @@ namespace ppredictor { ...@@ -15,6 +15,7 @@ namespace ppredictor {
* Config * Config
*/ */
struct OCR_Config { struct OCR_Config {
int use_opencl = 0;
int thread_num = 4; // Thread num int thread_num = 4; // Thread num
paddle::lite_api::PowerMode mode = paddle::lite_api::PowerMode mode =
paddle::lite_api::LITE_POWER_HIGH; // PaddleLite Mode paddle::lite_api::LITE_POWER_HIGH; // PaddleLite Mode
...@@ -27,8 +28,15 @@ struct OCRPredictResult { ...@@ -27,8 +28,15 @@ struct OCRPredictResult {
std::vector<int> word_index; std::vector<int> word_index;
std::vector<std::vector<int>> points; std::vector<std::vector<int>> points;
float score; float score;
float cls_score;
int cls_label=-1;
}; };
struct ClsPredictResult {
float cls_score;
int cls_label=-1;
cv::Mat img;
};
/** /**
* OCR there are 2 models * OCR there are 2 models
* 1. First model(det),select polygones to show where are the texts * 1. First model(det),select polygones to show where are the texts
...@@ -62,8 +70,7 @@ public: ...@@ -62,8 +70,7 @@ public:
* @return * @return
*/ */
virtual std::vector<OCRPredictResult> virtual std::vector<OCRPredictResult>
infer_ocr(const std::vector<int64_t> &dims, const float *input_data, infer_ocr(cv::Mat &origin, int max_size_len, int run_det, int run_cls, int run_rec);
int input_len, int net_flag, cv::Mat &origin);
virtual NET_TYPE get_net_flag() const; virtual NET_TYPE get_net_flag() const;
...@@ -80,16 +87,17 @@ private: ...@@ -80,16 +87,17 @@ private:
calc_filtered_boxes(const float *pred, int pred_size, int output_height, calc_filtered_boxes(const float *pred, int pred_size, int output_height,
int output_width, const cv::Mat &origin); int output_width, const cv::Mat &origin);
void
infer_det(cv::Mat &origin, int max_side_len, std::vector<OCRPredictResult>& ocr_results);
/** /**
* infer for second model * infer for rec model
* *
* @param boxes * @param boxes
* @param origin * @param origin
* @return * @return
*/ */
std::vector<OCRPredictResult> void
infer_rec(const std::vector<std::vector<std::vector<int>>> &boxes, infer_rec(const cv::Mat &origin, int run_cls, OCRPredictResult& ocr_result);
const cv::Mat &origin);
/** /**
* infer for cls model * infer for cls model
...@@ -98,7 +106,7 @@ private: ...@@ -98,7 +106,7 @@ private:
* @param origin * @param origin
* @return * @return
*/ */
cv::Mat infer_cls(const cv::Mat &origin, float thresh = 0.9); ClsPredictResult infer_cls(const cv::Mat &origin, float thresh = 0.9);
/** /**
* Postprocess or sencod model to extract text * Postprocess or sencod model to extract text
......
...@@ -2,9 +2,9 @@ ...@@ -2,9 +2,9 @@
#include "common.h" #include "common.h"
namespace ppredictor { namespace ppredictor {
PPredictor::PPredictor(int thread_num, int net_flag, PPredictor::PPredictor(int use_opencl, int thread_num, int net_flag,
paddle::lite_api::PowerMode mode) paddle::lite_api::PowerMode mode)
: _thread_num(thread_num), _net_flag(net_flag), _mode(mode) {} : _use_opencl(use_opencl), _thread_num(thread_num), _net_flag(net_flag), _mode(mode) {}
int PPredictor::init_nb(const std::string &model_content) { int PPredictor::init_nb(const std::string &model_content) {
paddle::lite_api::MobileConfig config; paddle::lite_api::MobileConfig config;
...@@ -19,10 +19,40 @@ int PPredictor::init_from_file(const std::string &model_content) { ...@@ -19,10 +19,40 @@ int PPredictor::init_from_file(const std::string &model_content) {
} }
template <typename ConfigT> int PPredictor::_init(ConfigT &config) { template <typename ConfigT> int PPredictor::_init(ConfigT &config) {
bool is_opencl_backend_valid = paddle::lite_api::IsOpenCLBackendValid(/*check_fp16_valid = false*/);
if (is_opencl_backend_valid) {
if (_use_opencl != 0) {
// Make sure you have write permission of the binary path.
// We strongly recommend each model has a unique binary name.
const std::string bin_path = "/data/local/tmp/";
const std::string bin_name = "lite_opencl_kernel.bin";
config.set_opencl_binary_path_name(bin_path, bin_name);
// opencl tune option
// CL_TUNE_NONE: 0
// CL_TUNE_RAPID: 1
// CL_TUNE_NORMAL: 2
// CL_TUNE_EXHAUSTIVE: 3
const std::string tuned_path = "/data/local/tmp/";
const std::string tuned_name = "lite_opencl_tuned.bin";
config.set_opencl_tune(paddle::lite_api::CL_TUNE_NORMAL, tuned_path, tuned_name);
// opencl precision option
// CL_PRECISION_AUTO: 0, first fp16 if valid, default
// CL_PRECISION_FP32: 1, force fp32
// CL_PRECISION_FP16: 2, force fp16
config.set_opencl_precision(paddle::lite_api::CL_PRECISION_FP32);
LOGI("ocr cpp device: running on gpu.");
}
} else {
LOGI("ocr cpp device: running on cpu.");
// you can give backup cpu nb model instead
// config.set_model_from_file(cpu_nb_model_dir);
}
config.set_threads(_thread_num); config.set_threads(_thread_num);
config.set_power_mode(_mode); config.set_power_mode(_mode);
_predictor = paddle::lite_api::CreatePaddlePredictor(config); _predictor = paddle::lite_api::CreatePaddlePredictor(config);
LOGI("paddle instance created"); LOGI("ocr cpp paddle instance created");
return RETURN_OK; return RETURN_OK;
} }
...@@ -43,18 +73,18 @@ std::vector<PredictorInput> PPredictor::get_inputs(int num) { ...@@ -43,18 +73,18 @@ std::vector<PredictorInput> PPredictor::get_inputs(int num) {
PredictorInput PPredictor::get_first_input() { return get_input(0); } PredictorInput PPredictor::get_first_input() { return get_input(0); }
std::vector<PredictorOutput> PPredictor::infer() { std::vector<PredictorOutput> PPredictor::infer() {
LOGI("infer Run start %d", _net_flag); LOGI("ocr cpp infer Run start %d", _net_flag);
std::vector<PredictorOutput> results; std::vector<PredictorOutput> results;
if (!_is_input_get) { if (!_is_input_get) {
return results; return results;
} }
_predictor->Run(); _predictor->Run();
LOGI("infer Run end"); LOGI("ocr cpp infer Run end");
for (int i = 0; i < _predictor->GetOutputNames().size(); i++) { for (int i = 0; i < _predictor->GetOutputNames().size(); i++) {
std::unique_ptr<const paddle::lite_api::Tensor> output_tensor = std::unique_ptr<const paddle::lite_api::Tensor> output_tensor =
_predictor->GetOutput(i); _predictor->GetOutput(i);
LOGI("output tensor[%d] size %ld", i, product(output_tensor->shape())); LOGI("ocr cpp output tensor[%d] size %ld", i, product(output_tensor->shape()));
PredictorOutput result{std::move(output_tensor), i, _net_flag}; PredictorOutput result{std::move(output_tensor), i, _net_flag};
results.emplace_back(std::move(result)); results.emplace_back(std::move(result));
} }
......
...@@ -22,7 +22,7 @@ public: ...@@ -22,7 +22,7 @@ public:
class PPredictor : public PPredictor_Interface { class PPredictor : public PPredictor_Interface {
public: public:
PPredictor( PPredictor(
int thread_num, int net_flag = 0, int use_opencl, int thread_num, int net_flag = 0,
paddle::lite_api::PowerMode mode = paddle::lite_api::LITE_POWER_HIGH); paddle::lite_api::PowerMode mode = paddle::lite_api::LITE_POWER_HIGH);
virtual ~PPredictor() {} virtual ~PPredictor() {}
...@@ -54,6 +54,7 @@ protected: ...@@ -54,6 +54,7 @@ protected:
template <typename ConfigT> int _init(ConfigT &config); template <typename ConfigT> int _init(ConfigT &config);
private: private:
int _use_opencl;
int _thread_num; int _thread_num;
paddle::lite_api::PowerMode _mode; paddle::lite_api::PowerMode _mode;
std::shared_ptr<paddle::lite_api::PaddlePredictor> _predictor; std::shared_ptr<paddle::lite_api::PaddlePredictor> _predictor;
......
...@@ -13,6 +13,7 @@ import android.graphics.BitmapFactory; ...@@ -13,6 +13,7 @@ import android.graphics.BitmapFactory;
import android.graphics.drawable.BitmapDrawable; import android.graphics.drawable.BitmapDrawable;
import android.media.ExifInterface; import android.media.ExifInterface;
import android.content.res.AssetManager; import android.content.res.AssetManager;
import android.media.FaceDetector;
import android.net.Uri; import android.net.Uri;
import android.os.Bundle; import android.os.Bundle;
import android.os.Environment; import android.os.Environment;
...@@ -27,7 +28,9 @@ import android.view.Menu; ...@@ -27,7 +28,9 @@ import android.view.Menu;
import android.view.MenuInflater; import android.view.MenuInflater;
import android.view.MenuItem; import android.view.MenuItem;
import android.view.View; import android.view.View;
import android.widget.CheckBox;
import android.widget.ImageView; import android.widget.ImageView;
import android.widget.Spinner;
import android.widget.TextView; import android.widget.TextView;
import android.widget.Toast; import android.widget.Toast;
...@@ -68,23 +71,24 @@ public class MainActivity extends AppCompatActivity { ...@@ -68,23 +71,24 @@ public class MainActivity extends AppCompatActivity {
protected ImageView ivInputImage; protected ImageView ivInputImage;
protected TextView tvOutputResult; protected TextView tvOutputResult;
protected TextView tvInferenceTime; protected TextView tvInferenceTime;
protected CheckBox cbOpencl;
protected Spinner spRunMode;
// Model settings of object detection // Model settings of ocr
protected String modelPath = ""; protected String modelPath = "";
protected String labelPath = ""; protected String labelPath = "";
protected String imagePath = ""; protected String imagePath = "";
protected int cpuThreadNum = 1; protected int cpuThreadNum = 1;
protected String cpuPowerMode = ""; protected String cpuPowerMode = "";
protected String inputColorFormat = ""; protected int detLongSize = 960;
protected long[] inputShape = new long[]{};
protected float[] inputMean = new float[]{};
protected float[] inputStd = new float[]{};
protected float scoreThreshold = 0.1f; protected float scoreThreshold = 0.1f;
private String currentPhotoPath; private String currentPhotoPath;
private AssetManager assetManager =null; private AssetManager assetManager = null;
protected Predictor predictor = new Predictor(); protected Predictor predictor = new Predictor();
private Bitmap cur_predict_image = null;
@Override @Override
protected void onCreate(Bundle savedInstanceState) { protected void onCreate(Bundle savedInstanceState) {
super.onCreate(savedInstanceState); super.onCreate(savedInstanceState);
...@@ -98,10 +102,12 @@ public class MainActivity extends AppCompatActivity { ...@@ -98,10 +102,12 @@ public class MainActivity extends AppCompatActivity {
// Setup the UI components // Setup the UI components
tvInputSetting = findViewById(R.id.tv_input_setting); tvInputSetting = findViewById(R.id.tv_input_setting);
cbOpencl = findViewById(R.id.cb_opencl);
tvStatus = findViewById(R.id.tv_model_img_status); tvStatus = findViewById(R.id.tv_model_img_status);
ivInputImage = findViewById(R.id.iv_input_image); ivInputImage = findViewById(R.id.iv_input_image);
tvInferenceTime = findViewById(R.id.tv_inference_time); tvInferenceTime = findViewById(R.id.tv_inference_time);
tvOutputResult = findViewById(R.id.tv_output_result); tvOutputResult = findViewById(R.id.tv_output_result);
spRunMode = findViewById(R.id.sp_run_mode);
tvInputSetting.setMovementMethod(ScrollingMovementMethod.getInstance()); tvInputSetting.setMovementMethod(ScrollingMovementMethod.getInstance());
tvOutputResult.setMovementMethod(ScrollingMovementMethod.getInstance()); tvOutputResult.setMovementMethod(ScrollingMovementMethod.getInstance());
...@@ -111,26 +117,26 @@ public class MainActivity extends AppCompatActivity { ...@@ -111,26 +117,26 @@ public class MainActivity extends AppCompatActivity {
public void handleMessage(Message msg) { public void handleMessage(Message msg) {
switch (msg.what) { switch (msg.what) {
case RESPONSE_LOAD_MODEL_SUCCESSED: case RESPONSE_LOAD_MODEL_SUCCESSED:
if(pbLoadModel!=null && pbLoadModel.isShowing()){ if (pbLoadModel != null && pbLoadModel.isShowing()) {
pbLoadModel.dismiss(); pbLoadModel.dismiss();
} }
onLoadModelSuccessed(); onLoadModelSuccessed();
break; break;
case RESPONSE_LOAD_MODEL_FAILED: case RESPONSE_LOAD_MODEL_FAILED:
if(pbLoadModel!=null && pbLoadModel.isShowing()){ if (pbLoadModel != null && pbLoadModel.isShowing()) {
pbLoadModel.dismiss(); pbLoadModel.dismiss();
} }
Toast.makeText(MainActivity.this, "Load model failed!", Toast.LENGTH_SHORT).show(); Toast.makeText(MainActivity.this, "Load model failed!", Toast.LENGTH_SHORT).show();
onLoadModelFailed(); onLoadModelFailed();
break; break;
case RESPONSE_RUN_MODEL_SUCCESSED: case RESPONSE_RUN_MODEL_SUCCESSED:
if(pbRunModel!=null && pbRunModel.isShowing()){ if (pbRunModel != null && pbRunModel.isShowing()) {
pbRunModel.dismiss(); pbRunModel.dismiss();
} }
onRunModelSuccessed(); onRunModelSuccessed();
break; break;
case RESPONSE_RUN_MODEL_FAILED: case RESPONSE_RUN_MODEL_FAILED:
if(pbRunModel!=null && pbRunModel.isShowing()){ if (pbRunModel != null && pbRunModel.isShowing()) {
pbRunModel.dismiss(); pbRunModel.dismiss();
} }
Toast.makeText(MainActivity.this, "Run model failed!", Toast.LENGTH_SHORT).show(); Toast.makeText(MainActivity.this, "Run model failed!", Toast.LENGTH_SHORT).show();
...@@ -175,71 +181,47 @@ public class MainActivity extends AppCompatActivity { ...@@ -175,71 +181,47 @@ public class MainActivity extends AppCompatActivity {
super.onResume(); super.onResume();
SharedPreferences sharedPreferences = PreferenceManager.getDefaultSharedPreferences(this); SharedPreferences sharedPreferences = PreferenceManager.getDefaultSharedPreferences(this);
boolean settingsChanged = false; boolean settingsChanged = false;
boolean model_settingsChanged = false;
String model_path = sharedPreferences.getString(getString(R.string.MODEL_PATH_KEY), String model_path = sharedPreferences.getString(getString(R.string.MODEL_PATH_KEY),
getString(R.string.MODEL_PATH_DEFAULT)); getString(R.string.MODEL_PATH_DEFAULT));
String label_path = sharedPreferences.getString(getString(R.string.LABEL_PATH_KEY), String label_path = sharedPreferences.getString(getString(R.string.LABEL_PATH_KEY),
getString(R.string.LABEL_PATH_DEFAULT)); getString(R.string.LABEL_PATH_DEFAULT));
String image_path = sharedPreferences.getString(getString(R.string.IMAGE_PATH_KEY), String image_path = sharedPreferences.getString(getString(R.string.IMAGE_PATH_KEY),
getString(R.string.IMAGE_PATH_DEFAULT)); getString(R.string.IMAGE_PATH_DEFAULT));
settingsChanged |= !model_path.equalsIgnoreCase(modelPath); model_settingsChanged |= !model_path.equalsIgnoreCase(modelPath);
settingsChanged |= !label_path.equalsIgnoreCase(labelPath); settingsChanged |= !label_path.equalsIgnoreCase(labelPath);
settingsChanged |= !image_path.equalsIgnoreCase(imagePath); settingsChanged |= !image_path.equalsIgnoreCase(imagePath);
int cpu_thread_num = Integer.parseInt(sharedPreferences.getString(getString(R.string.CPU_THREAD_NUM_KEY), int cpu_thread_num = Integer.parseInt(sharedPreferences.getString(getString(R.string.CPU_THREAD_NUM_KEY),
getString(R.string.CPU_THREAD_NUM_DEFAULT))); getString(R.string.CPU_THREAD_NUM_DEFAULT)));
settingsChanged |= cpu_thread_num != cpuThreadNum; model_settingsChanged |= cpu_thread_num != cpuThreadNum;
String cpu_power_mode = String cpu_power_mode =
sharedPreferences.getString(getString(R.string.CPU_POWER_MODE_KEY), sharedPreferences.getString(getString(R.string.CPU_POWER_MODE_KEY),
getString(R.string.CPU_POWER_MODE_DEFAULT)); getString(R.string.CPU_POWER_MODE_DEFAULT));
settingsChanged |= !cpu_power_mode.equalsIgnoreCase(cpuPowerMode); model_settingsChanged |= !cpu_power_mode.equalsIgnoreCase(cpuPowerMode);
String input_color_format =
sharedPreferences.getString(getString(R.string.INPUT_COLOR_FORMAT_KEY), int det_long_size = Integer.parseInt(sharedPreferences.getString(getString(R.string.DET_LONG_SIZE_KEY),
getString(R.string.INPUT_COLOR_FORMAT_DEFAULT)); getString(R.string.DET_LONG_SIZE_DEFAULT)));
settingsChanged |= !input_color_format.equalsIgnoreCase(inputColorFormat); settingsChanged |= det_long_size != detLongSize;
long[] input_shape =
Utils.parseLongsFromString(sharedPreferences.getString(getString(R.string.INPUT_SHAPE_KEY),
getString(R.string.INPUT_SHAPE_DEFAULT)), ",");
float[] input_mean =
Utils.parseFloatsFromString(sharedPreferences.getString(getString(R.string.INPUT_MEAN_KEY),
getString(R.string.INPUT_MEAN_DEFAULT)), ",");
float[] input_std =
Utils.parseFloatsFromString(sharedPreferences.getString(getString(R.string.INPUT_STD_KEY)
, getString(R.string.INPUT_STD_DEFAULT)), ",");
settingsChanged |= input_shape.length != inputShape.length;
settingsChanged |= input_mean.length != inputMean.length;
settingsChanged |= input_std.length != inputStd.length;
if (!settingsChanged) {
for (int i = 0; i < input_shape.length; i++) {
settingsChanged |= input_shape[i] != inputShape[i];
}
for (int i = 0; i < input_mean.length; i++) {
settingsChanged |= input_mean[i] != inputMean[i];
}
for (int i = 0; i < input_std.length; i++) {
settingsChanged |= input_std[i] != inputStd[i];
}
}
float score_threshold = float score_threshold =
Float.parseFloat(sharedPreferences.getString(getString(R.string.SCORE_THRESHOLD_KEY), Float.parseFloat(sharedPreferences.getString(getString(R.string.SCORE_THRESHOLD_KEY),
getString(R.string.SCORE_THRESHOLD_DEFAULT))); getString(R.string.SCORE_THRESHOLD_DEFAULT)));
settingsChanged |= scoreThreshold != score_threshold; settingsChanged |= scoreThreshold != score_threshold;
if (settingsChanged) { if (settingsChanged) {
modelPath = model_path;
labelPath = label_path; labelPath = label_path;
imagePath = image_path; imagePath = image_path;
detLongSize = det_long_size;
scoreThreshold = score_threshold;
set_img();
}
if (model_settingsChanged) {
modelPath = model_path;
cpuThreadNum = cpu_thread_num; cpuThreadNum = cpu_thread_num;
cpuPowerMode = cpu_power_mode; cpuPowerMode = cpu_power_mode;
inputColorFormat = input_color_format;
inputShape = input_shape;
inputMean = input_mean;
inputStd = input_std;
scoreThreshold = score_threshold;
// Update UI // Update UI
tvInputSetting.setText("Model: " + modelPath.substring(modelPath.lastIndexOf("/") + 1) + "\n" + "CPU" + tvInputSetting.setText("Model: " + modelPath.substring(modelPath.lastIndexOf("/") + 1) + "\nOPENCL: " + cbOpencl.isChecked() + "\nCPU Thread Num: " + cpuThreadNum + "\nCPU Power Mode: " + cpuPowerMode);
" Thread Num: " + Integer.toString(cpuThreadNum) + "\n" + "CPU Power Mode: " + cpuPowerMode);
tvInputSetting.scrollTo(0, 0); tvInputSetting.scrollTo(0, 0);
// Reload model if configure has been changed // Reload model if configure has been changed
// loadModel(); loadModel();
set_img();
} }
} }
...@@ -254,20 +236,28 @@ public class MainActivity extends AppCompatActivity { ...@@ -254,20 +236,28 @@ public class MainActivity extends AppCompatActivity {
} }
public boolean onLoadModel() { public boolean onLoadModel() {
return predictor.init(MainActivity.this, modelPath, labelPath, cpuThreadNum, if (predictor.isLoaded()) {
predictor.releaseModel();
}
return predictor.init(MainActivity.this, modelPath, labelPath, cbOpencl.isChecked() ? 1 : 0, cpuThreadNum,
cpuPowerMode, cpuPowerMode,
inputColorFormat, detLongSize, scoreThreshold);
inputShape, inputMean,
inputStd, scoreThreshold);
} }
public boolean onRunModel() { public boolean onRunModel() {
return predictor.isLoaded() && predictor.runModel(); String run_mode = spRunMode.getSelectedItem().toString();
int run_det = run_mode.contains("检测") ? 1 : 0;
int run_cls = run_mode.contains("分类") ? 1 : 0;
int run_rec = run_mode.contains("识别") ? 1 : 0;
return predictor.isLoaded() && predictor.runModel(run_det, run_cls, run_rec);
} }
public void onLoadModelSuccessed() { public void onLoadModelSuccessed() {
// Load test image from path and run model // Load test image from path and run model
tvInputSetting.setText("Model: " + modelPath.substring(modelPath.lastIndexOf("/") + 1) + "\nOPENCL: " + cbOpencl.isChecked() + "\nCPU Thread Num: " + cpuThreadNum + "\nCPU Power Mode: " + cpuPowerMode);
tvInputSetting.scrollTo(0, 0);
tvStatus.setText("STATUS: load model successed"); tvStatus.setText("STATUS: load model successed");
} }
public void onLoadModelFailed() { public void onLoadModelFailed() {
...@@ -290,20 +280,13 @@ public class MainActivity extends AppCompatActivity { ...@@ -290,20 +280,13 @@ public class MainActivity extends AppCompatActivity {
tvStatus.setText("STATUS: run model failed"); tvStatus.setText("STATUS: run model failed");
} }
public void onImageChanged(Bitmap image) {
// Rerun model if users pick test image from gallery or camera
if (image != null && predictor.isLoaded()) {
predictor.setInputImage(image);
runModel();
}
}
public void set_img() { public void set_img() {
// Load test image from path and run model // Load test image from path and run model
try { try {
assetManager= getAssets(); assetManager = getAssets();
InputStream in=assetManager.open(imagePath); InputStream in = assetManager.open(imagePath);
Bitmap bmp=BitmapFactory.decodeStream(in); Bitmap bmp = BitmapFactory.decodeStream(in);
cur_predict_image = bmp;
ivInputImage.setImageBitmap(bmp); ivInputImage.setImageBitmap(bmp);
} catch (IOException e) { } catch (IOException e) {
Toast.makeText(MainActivity.this, "Load image failed!", Toast.LENGTH_SHORT).show(); Toast.makeText(MainActivity.this, "Load image failed!", Toast.LENGTH_SHORT).show();
...@@ -430,7 +413,7 @@ public class MainActivity extends AppCompatActivity { ...@@ -430,7 +413,7 @@ public class MainActivity extends AppCompatActivity {
Cursor cursor = managedQuery(uri, proj, null, null, null); Cursor cursor = managedQuery(uri, proj, null, null, null);
cursor.moveToFirst(); cursor.moveToFirst();
if (image != null) { if (image != null) {
// onImageChanged(image); cur_predict_image = image;
ivInputImage.setImageBitmap(image); ivInputImage.setImageBitmap(image);
} }
} catch (IOException e) { } catch (IOException e) {
...@@ -451,7 +434,7 @@ public class MainActivity extends AppCompatActivity { ...@@ -451,7 +434,7 @@ public class MainActivity extends AppCompatActivity {
Bitmap image = BitmapFactory.decodeFile(currentPhotoPath); Bitmap image = BitmapFactory.decodeFile(currentPhotoPath);
image = Utils.rotateBitmap(image, orientation); image = Utils.rotateBitmap(image, orientation);
if (image != null) { if (image != null) {
// onImageChanged(image); cur_predict_image = image;
ivInputImage.setImageBitmap(image); ivInputImage.setImageBitmap(image);
} }
} else { } else {
...@@ -464,28 +447,28 @@ public class MainActivity extends AppCompatActivity { ...@@ -464,28 +447,28 @@ public class MainActivity extends AppCompatActivity {
} }
} }
public void btn_load_model_click(View view) { public void btn_reset_img_click(View view) {
if (predictor.isLoaded()){ ivInputImage.setImageBitmap(cur_predict_image);
tvStatus.setText("STATUS: model has been loaded"); }
}else{
public void cb_opencl_click(View view) {
tvStatus.setText("STATUS: load model ......"); tvStatus.setText("STATUS: load model ......");
loadModel(); loadModel();
} }
}
public void btn_run_model_click(View view) { public void btn_run_model_click(View view) {
Bitmap image =((BitmapDrawable)ivInputImage.getDrawable()).getBitmap(); Bitmap image = ((BitmapDrawable) ivInputImage.getDrawable()).getBitmap();
if(image == null) { if (image == null) {
tvStatus.setText("STATUS: image is not exists"); tvStatus.setText("STATUS: image is not exists");
} } else if (!predictor.isLoaded()) {
else if (!predictor.isLoaded()){
tvStatus.setText("STATUS: model is not loaded"); tvStatus.setText("STATUS: model is not loaded");
}else{ } else {
tvStatus.setText("STATUS: run model ...... "); tvStatus.setText("STATUS: run model ...... ");
predictor.setInputImage(image); predictor.setInputImage(image);
runModel(); runModel();
} }
} }
public void btn_choice_img_click(View view) { public void btn_choice_img_click(View view) {
if (requestAllPermissions()) { if (requestAllPermissions()) {
openGallery(); openGallery();
...@@ -506,4 +489,32 @@ public class MainActivity extends AppCompatActivity { ...@@ -506,4 +489,32 @@ public class MainActivity extends AppCompatActivity {
worker.quit(); worker.quit();
super.onDestroy(); super.onDestroy();
} }
/**
 * Maps the run-mode spinner selection to the integer mode code passed to the
 * native predictor.
 *
 * Matching is keyword-based (consistent with onRunModel, which uses
 * String.contains on the same labels), so a label is recognized regardless of
 * keyword order: both "分类+识别" and "识别+分类" map to mode 3. The original
 * switch only matched the literal "识别+分类", which silently fell through to
 * the default (mode 1) when the spinner label was "分类+识别".
 *
 * Mode codes: 1 = det+cls+rec, 2 = det+rec, 3 = cls+rec,
 *             4 = det only, 5 = rec only, 6 = cls only.
 *
 * @return the run-mode code; defaults to 1 (full pipeline) for any
 *         unrecognized label
 */
public int get_run_mode() {
    String run_mode = spRunMode.getSelectedItem().toString();
    boolean run_det = run_mode.contains("检测"); // detection
    boolean run_cls = run_mode.contains("分类"); // direction classification
    boolean run_rec = run_mode.contains("识别"); // recognition
    if (run_det && run_cls && run_rec) {
        return 1;
    } else if (run_det && run_rec) {
        return 2;
    } else if (run_cls && run_rec) {
        return 3;
    } else if (run_det) {
        return 4;
    } else if (run_rec) {
        return 5;
    } else if (run_cls) {
        return 6;
    }
    // Fallback: run the full pipeline, matching the original default branch.
    return 1;
}
} }
package com.baidu.paddle.lite.demo.ocr;
import android.graphics.Bitmap;
import android.graphics.BitmapFactory;
import android.os.Build;
import android.os.Bundle;
import android.os.Handler;
import android.os.HandlerThread;
import android.os.Message;
import android.util.Log;
import android.view.View;
import android.widget.Button;
import android.widget.ImageView;
import android.widget.TextView;
import android.widget.Toast;
import androidx.appcompat.app.AppCompatActivity;
import java.io.IOException;
import java.io.InputStream;
/**
 * Minimal OCR demo activity: loads the bundled Paddle-Lite model once at
 * startup and runs inference on a bundled sample image each time the button
 * is pressed. All model work happens on a dedicated worker thread so the UI
 * thread never blocks.
 */
public class MiniActivity extends AppCompatActivity {

    // Command codes sent to the worker thread.
    public static final int REQUEST_LOAD_MODEL = 0;
    public static final int REQUEST_RUN_MODEL = 1;
    public static final int REQUEST_UNLOAD_MODEL = 2;
    // Response codes. NOTE(review): declared for symmetry with MainActivity;
    // this activity reports results via runOnUiThread instead of a receiver.
    public static final int RESPONSE_LOAD_MODEL_SUCCESSED = 0;
    public static final int RESPONSE_LOAD_MODEL_FAILED = 1;
    public static final int RESPONSE_RUN_MODEL_SUCCESSED = 2;
    public static final int RESPONSE_RUN_MODEL_FAILED = 3;

    private static final String TAG = "MiniActivity";

    protected Handler receiver = null; // Receive messages from worker thread
    protected Handler sender = null; // Send command to worker thread
    protected HandlerThread worker = null; // Worker thread to load&run model
    // volatile: written on the worker thread, read on the UI thread.
    protected volatile Predictor predictor = null;

    // Paths of the model/label assets bundled in the APK.
    private String assetModelDirPath = "models/ocr_v2_for_cpu";
    private String assetlabelFilePath = "labels/ppocr_keys_v1.txt";

    private Button button;
    private ImageView imageView; // image result
    private TextView textView; // text result

    @Override
    protected void onCreate(Bundle savedInstanceState) {
        super.onCreate(savedInstanceState);
        setContentView(R.layout.activity_mini);
        Log.i(TAG, "SHOW in Logcat");

        // Prepare the worker thread for model loading and inference.
        worker = new HandlerThread("Predictor Worker");
        worker.start();
        sender = new Handler(worker.getLooper()) {
            public void handleMessage(Message msg) {
                switch (msg.what) {
                    case REQUEST_LOAD_MODEL:
                        // Load model; report failure on the UI thread.
                        if (!onLoadModel()) {
                            runOnUiThread(new Runnable() {
                                @Override
                                public void run() {
                                    Toast.makeText(MiniActivity.this, "Load model failed!", Toast.LENGTH_SHORT).show();
                                }
                            });
                        }
                        break;
                    case REQUEST_RUN_MODEL:
                        // Run model if model is loaded; publish result on the UI thread.
                        final boolean isSuccessed = onRunModel();
                        runOnUiThread(new Runnable() {
                            @Override
                            public void run() {
                                if (isSuccessed) {
                                    onRunModelSuccessed();
                                } else {
                                    Toast.makeText(MiniActivity.this, "Run model failed!", Toast.LENGTH_SHORT).show();
                                }
                            }
                        });
                        break;
                }
            }
        };
        sender.sendEmptyMessage(REQUEST_LOAD_MODEL); // corresponding to REQUEST_LOAD_MODEL, to call onLoadModel()

        imageView = findViewById(R.id.imageView);
        textView = findViewById(R.id.sample_text);
        button = findViewById(R.id.button);
        button.setOnClickListener(new View.OnClickListener() {
            @Override
            public void onClick(View v) {
                // Each tap queues one inference pass on the worker thread.
                sender.sendEmptyMessage(REQUEST_RUN_MODEL);
            }
        });
    }

    @Override
    protected void onDestroy() {
        onUnloadModel();
        // quitSafely (API 18+) lets already-queued messages finish; quit drops them.
        if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.JELLY_BEAN_MR2) {
            worker.quitSafely();
        } else {
            worker.quit();
        }
        super.onDestroy();
    }

    /**
     * Lazily creates the Predictor and initializes it with the bundled model
     * and label assets. Called on the worker thread (via REQUEST_LOAD_MODEL).
     *
     * @return true if the model was loaded successfully
     */
    private boolean onLoadModel() {
        if (predictor == null) {
            predictor = new Predictor();
        }
        return predictor.init(this, assetModelDirPath, assetlabelFilePath);
    }

    /**
     * Decodes the bundled sample image and runs one inference pass.
     * Called on the worker thread (via REQUEST_RUN_MODEL).
     *
     * Fix: the asset InputStream is now closed via try-with-resources
     * (minSdkVersion 23 supports it); the original leaked the stream.
     *
     * @return true if the model was loaded and inference succeeded
     */
    private boolean onRunModel() {
        String assetImagePath = "images/0.jpg";
        try (InputStream imageStream = getAssets().open(assetImagePath)) {
            Bitmap image = BitmapFactory.decodeStream(imageStream);
            // Input is Bitmap
            predictor.setInputImage(image);
            return predictor.isLoaded() && predictor.runModel();
        } catch (IOException e) {
            e.printStackTrace();
            return false;
        }
    }

    /** Publishes the inference result to the UI; must run on the UI thread. */
    private void onRunModelSuccessed() {
        Log.i(TAG, "onRunModelSuccessed");
        textView.setText(predictor.outputResult);
        imageView.setImageBitmap(predictor.outputImage);
    }

    /** Releases native model resources; safe to call when nothing is loaded. */
    private void onUnloadModel() {
        if (predictor != null) {
            predictor.releaseModel();
        }
    }
}
...@@ -29,22 +29,22 @@ public class OCRPredictorNative { ...@@ -29,22 +29,22 @@ public class OCRPredictorNative {
public OCRPredictorNative(Config config) { public OCRPredictorNative(Config config) {
this.config = config; this.config = config;
loadLibrary(); loadLibrary();
nativePointer = init(config.detModelFilename, config.recModelFilename,config.clsModelFilename, nativePointer = init(config.detModelFilename, config.recModelFilename, config.clsModelFilename, config.useOpencl,
config.cpuThreadNum, config.cpuPower); config.cpuThreadNum, config.cpuPower);
Log.i("OCRPredictorNative", "load success " + nativePointer); Log.i("OCRPredictorNative", "load success " + nativePointer);
} }
public ArrayList<OcrResultModel> runImage(float[] inputData, int width, int height, int channels, Bitmap originalImage) { public ArrayList<OcrResultModel> runImage(Bitmap originalImage, int max_size_len, int run_det, int run_cls, int run_rec) {
Log.i("OCRPredictorNative", "begin to run image " + inputData.length + " " + width + " " + height); Log.i("OCRPredictorNative", "begin to run image ");
float[] dims = new float[]{1, channels, height, width}; float[] rawResults = forward(nativePointer, originalImage, max_size_len, run_det, run_cls, run_rec);
float[] rawResults = forward(nativePointer, inputData, dims, originalImage);
ArrayList<OcrResultModel> results = postprocess(rawResults); ArrayList<OcrResultModel> results = postprocess(rawResults);
return results; return results;
} }
public static class Config { public static class Config {
public int useOpencl;
public int cpuThreadNum; public int cpuThreadNum;
public String cpuPower; public String cpuPower;
public String detModelFilename; public String detModelFilename;
...@@ -53,16 +53,16 @@ public class OCRPredictorNative { ...@@ -53,16 +53,16 @@ public class OCRPredictorNative {
} }
public void destory(){ public void destory() {
if (nativePointer > 0) { if (nativePointer > 0) {
release(nativePointer); release(nativePointer);
nativePointer = 0; nativePointer = 0;
} }
} }
protected native long init(String detModelPath, String recModelPath,String clsModelPath, int threadNum, String cpuMode); protected native long init(String detModelPath, String recModelPath, String clsModelPath, int useOpencl, int threadNum, String cpuMode);
protected native float[] forward(long pointer, float[] buf, float[] ddims, Bitmap originalImage); protected native float[] forward(long pointer, Bitmap originalImage,int max_size_len, int run_det, int run_cls, int run_rec);
protected native void release(long pointer); protected native void release(long pointer);
...@@ -73,9 +73,9 @@ public class OCRPredictorNative { ...@@ -73,9 +73,9 @@ public class OCRPredictorNative {
while (begin < raw.length) { while (begin < raw.length) {
int point_num = Math.round(raw[begin]); int point_num = Math.round(raw[begin]);
int word_num = Math.round(raw[begin + 1]); int word_num = Math.round(raw[begin + 1]);
OcrResultModel model = parse(raw, begin + 2, point_num, word_num); OcrResultModel res = parse(raw, begin + 2, point_num, word_num);
begin += 2 + 1 + point_num * 2 + word_num; begin += 2 + 1 + point_num * 2 + word_num + 2;
results.add(model); results.add(res);
} }
return results; return results;
...@@ -83,19 +83,22 @@ public class OCRPredictorNative { ...@@ -83,19 +83,22 @@ public class OCRPredictorNative {
private OcrResultModel parse(float[] raw, int begin, int pointNum, int wordNum) { private OcrResultModel parse(float[] raw, int begin, int pointNum, int wordNum) {
int current = begin; int current = begin;
OcrResultModel model = new OcrResultModel(); OcrResultModel res = new OcrResultModel();
model.setConfidence(raw[current]); res.setConfidence(raw[current]);
current++; current++;
for (int i = 0; i < pointNum; i++) { for (int i = 0; i < pointNum; i++) {
model.addPoints(Math.round(raw[current + i * 2]), Math.round(raw[current + i * 2 + 1])); res.addPoints(Math.round(raw[current + i * 2]), Math.round(raw[current + i * 2 + 1]));
} }
current += (pointNum * 2); current += (pointNum * 2);
for (int i = 0; i < wordNum; i++) { for (int i = 0; i < wordNum; i++) {
int index = Math.round(raw[current + i]); int index = Math.round(raw[current + i]);
model.addWordIndex(index); res.addWordIndex(index);
} }
current += wordNum;
res.setClsIdx(raw[current]);
res.setClsConfidence(raw[current + 1]);
Log.i("OCRPredictorNative", "word finished " + wordNum); Log.i("OCRPredictorNative", "word finished " + wordNum);
return model; return res;
} }
......
...@@ -10,6 +10,9 @@ public class OcrResultModel { ...@@ -10,6 +10,9 @@ public class OcrResultModel {
private List<Integer> wordIndex; private List<Integer> wordIndex;
private String label; private String label;
private float confidence; private float confidence;
private float cls_idx;
private String cls_label;
private float cls_confidence;
public OcrResultModel() { public OcrResultModel() {
super(); super();
...@@ -49,4 +52,28 @@ public class OcrResultModel { ...@@ -49,4 +52,28 @@ public class OcrResultModel {
public void setConfidence(float confidence) { public void setConfidence(float confidence) {
this.confidence = confidence; this.confidence = confidence;
} }
// Class index predicted by the direction classifier (set from the native output).
public float getClsIdx() {
    return cls_idx;
}
// Records the classifier's predicted class index (parsed from the native result buffer).
public void setClsIdx(float idx) {
    this.cls_idx = idx;
}
// Human-readable label for the classifier result; may be null if never set.
public String getClsLabel() {
    return cls_label;
}
// Sets the human-readable label for the classifier result.
public void setClsLabel(String label) {
    this.cls_label = label;
}
// Confidence score of the classifier prediction (set from the native output).
public float getClsConfidence() {
    return cls_confidence;
}
// Records the classifier's confidence score (parsed from the native result buffer).
public void setClsConfidence(float confidence) {
    this.cls_confidence = confidence;
}
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment