Unverified Commit 006d84bf authored by 崔浩's avatar 崔浩 Committed by GitHub
Browse files

Merge branch 'PaddlePaddle:dygraph' into dygraph

parents 302ca30c 8beeb84c
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import paddle
# A global variable to record the number of calling times for profiler
# functions. It is used to specify the tracing range of training steps.
_profiler_step_id = 0
# A global variable to avoid parsing from string every time.
_profiler_options = None
class ProfilerOptions(object):
'''
Use a string to initialize a ProfilerOptions.
The string should be in the format: "key1=value1;key2=value;key3=value3".
For example:
"profile_path=model.profile"
"batch_range=[50, 60]; profile_path=model.profile"
"batch_range=[50, 60]; tracer_option=OpDetail; profile_path=model.profile"
ProfilerOptions supports following key-value pair:
batch_range - a integer list, e.g. [100, 110].
state - a string, the optional values are 'CPU', 'GPU' or 'All'.
sorted_key - a string, the optional values are 'calls', 'total',
'max', 'min' or 'ave.
tracer_option - a string, the optional values are 'Default', 'OpDetail',
'AllOpDetail'.
profile_path - a string, the path to save the serialized profile data,
which can be used to generate a timeline.
exit_on_finished - a boolean.
'''
def __init__(self, options_str):
assert isinstance(options_str, str)
self._options = {
'batch_range': [10, 20],
'state': 'All',
'sorted_key': 'total',
'tracer_option': 'Default',
'profile_path': '/tmp/profile',
'exit_on_finished': True
}
self._parse_from_string(options_str)
def _parse_from_string(self, options_str):
for kv in options_str.replace(' ', '').split(';'):
key, value = kv.split('=')
if key == 'batch_range':
value_list = value.replace('[', '').replace(']', '').split(',')
value_list = list(map(int, value_list))
if len(value_list) >= 2 and value_list[0] >= 0 and value_list[
1] > value_list[0]:
self._options[key] = value_list
elif key == 'exit_on_finished':
self._options[key] = value.lower() in ("yes", "true", "t", "1")
elif key in [
'state', 'sorted_key', 'tracer_option', 'profile_path'
]:
self._options[key] = value
def __getitem__(self, name):
if self._options.get(name, None) is None:
raise ValueError(
"ProfilerOptions does not have an option named %s." % name)
return self._options[name]
def add_profiler_step(options_str=None):
'''
Enable the operator-level timing using PaddlePaddle's profiler.
The profiler uses a independent variable to count the profiler steps.
One call of this function is treated as a profiler step.
Args:
profiler_options - a string to initialize the ProfilerOptions.
Default is None, and the profiler is disabled.
'''
if options_str is None:
return
global _profiler_step_id
global _profiler_options
if _profiler_options is None:
_profiler_options = ProfilerOptions(options_str)
if _profiler_step_id == _profiler_options['batch_range'][0]:
paddle.utils.profiler.start_profiler(
_profiler_options['state'], _profiler_options['tracer_option'])
elif _profiler_step_id == _profiler_options['batch_range'][1]:
paddle.utils.profiler.stop_profiler(_profiler_options['sorted_key'],
_profiler_options['profile_path'])
if _profiler_options['exit_on_finished']:
sys.exit(0)
_profiler_step_id += 1
......@@ -108,14 +108,15 @@ def load_dygraph_params(config, model, logger, optimizer):
for k1, k2 in zip(state_dict.keys(), params.keys()):
if list(state_dict[k1].shape) == list(params[k2].shape):
new_state_dict[k1] = params[k2]
else:
logger.info(
f"The shape of model params {k1} {state_dict[k1].shape} not matched with loaded params {k2} {params[k2].shape} !"
)
else:
logger.info(
f"The shape of model params {k1} {state_dict[k1].shape} not matched with loaded params {k2} {params[k2].shape} !"
)
model.set_state_dict(new_state_dict)
logger.info(f"loaded pretrained_model successful from {pm}")
return {}
def load_pretrained_params(model, path):
if path is None:
return False
......@@ -138,6 +139,7 @@ def load_pretrained_params(model, path):
print(f"load pretrain successful from {path}")
return model
def save_model(model,
optimizer,
model_path,
......
......@@ -30,13 +30,13 @@ python3 -m pip install paddlepaddle-gpu==2.1.1 -i https://mirror.baidu.com/pypi/
# CPU
python3 -m pip install paddlepaddle==2.1.1 -i https://mirror.baidu.com/pypi/simple
# For more,refer[Installation](https://www.paddlepaddle.org.cn/install/quick)。
```
For more,refer [Installation](https://www.paddlepaddle.org.cn/install/quick) .
- **(2) Install Layout-Parser**
```bash
pip3 install -U premailer paddleocr https://paddleocr.bj.bcebos.com/whl/layoutparser-0.0.0-py3-none-any.whl
pip3 install -U https://paddleocr.bj.bcebos.com/whl/layoutparser-0.0.0-py3-none-any.whl
```
### 2.2 Install PaddleOCR(including PP-OCR and PP-Structure)
......@@ -124,8 +124,6 @@ Most of the parameters are consistent with the paddleocr whl package, see [doc o
After running, each image will have a directory with the same name under the directory specified in the output field. Each table in the picture will be stored as an excel and figure area will be cropped and saved, the excel and image file name will be the coordinates of the table in the image.
## 4. PP-Structure Pipeline
the process is as follows
![pipeline](../doc/table/pipeline_en.jpg)
In PP-Structure, the image will be analyzed by layoutparser first. In the layout analysis, the area in the image will be classified, including **text, title, image, list and table** 5 categories. For the first 4 types of areas, directly use the PP-OCR to complete the text detection and recognition. The table area will be converted to an excel file of the same table style via Table OCR.
......@@ -180,10 +178,10 @@ OCR and table recognition model
|model name|description|model size|download|
| --- | --- | --- | --- |
|ch_ppocr_mobile_slim_v2.0_det|Slim pruned lightweight model, supporting Chinese, English, multilingual text detection|2.6M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/slim/ch_ppocr_mobile_v2.0_det_prune_infer.tar) |
|ch_ppocr_mobile_slim_v2.0_rec|Slim pruned and quantized lightweight model, supporting Chinese, English and number recognition|6M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_slim_infer.tar) |
|en_ppocr_mobile_v2.0_table_det|Text detection of English table scenes trained on PubLayNet dataset|4.7M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_det_infer.tar) |
|en_ppocr_mobile_v2.0_table_rec|Text recognition of English table scene trained on PubLayNet dataset|6.9M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_rec_infer.tar) |
|en_ppocr_mobile_v2.0_table_structure|Table structure prediction of English table scene trained on PubLayNet dataset|18.6M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar) |
|ch_ppocr_mobile_slim_v2.0_det|Slim pruned lightweight model, supporting Chinese, English, multilingual text detection|2.6M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/slim/ch_ppocr_mobile_v2.0_det_prune_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/slim/ch_ppocr_mobile_v2.0_det_prune_infer.tar) |
|ch_ppocr_mobile_slim_v2.0_rec|Slim pruned and quantized lightweight model, supporting Chinese, English and number recognition|6M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_slim_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_slim_train.tar) |
|en_ppocr_mobile_v2.0_table_det|Text detection of English table scenes trained on PubLayNet dataset|4.7M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_det_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/table/en_ppocr_mobile_v2.0_table_det_train.tar) |
|en_ppocr_mobile_v2.0_table_rec|Text recognition of English table scene trained on PubLayNet dataset|6.9M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_rec_infer.tar) [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/table/en_ppocr_mobile_v2.0_table_rec_train.tar) |
|en_ppocr_mobile_v2.0_table_structure|Table structure prediction of English table scene trained on PubLayNet dataset|18.6M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/table/en_ppocr_mobile_v2.0_table_structure_train.tar) |
If you need to use other models, you can download the model in [model_list](../doc/doc_en/models_list_en.md) or use your own trained model to configure it to the three fields of `det_model_dir`, `rec_model_dir`, `table_model_dir` .
......@@ -30,13 +30,13 @@ python3 -m pip install paddlepaddle-gpu==2.1.1 -i https://mirror.baidu.com/pypi/
# CPU安装
python3 -m pip install paddlepaddle==2.1.1 -i https://mirror.baidu.com/pypi/simple
# 更多需求,请参照[安装文档](https://www.paddlepaddle.org.cn/install/quick)中的说明进行操作。
```
更多需求,请参照[安装文档](https://www.paddlepaddle.org.cn/install/quick)中的说明进行操作。
- **(2) 安装 Layout-Parser**
```bash
pip3 install -U premailer paddleocr https://paddleocr.bj.bcebos.com/whl/layoutparser-0.0.0-py3-none-any.whl
pip3 install -U https://paddleocr.bj.bcebos.com/whl/layoutparser-0.0.0-py3-none-any.whl
```
### 2.2 安装PaddleOCR(包含PP-OCR和PP-Structure)
......@@ -179,10 +179,10 @@ OCR和表格识别模型
|模型名称|模型简介|推理模型大小|下载地址|
| --- | --- | --- | --- |
|ch_ppocr_mobile_slim_v2.0_det|slim裁剪版超轻量模型,支持中英文、多语种文本检测|2.6M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/slim/ch_ppocr_mobile_v2.0_det_prune_infer.tar) |
|ch_ppocr_mobile_slim_v2.0_rec|slim裁剪量化版超轻量模型,支持中英文、数字识别|6M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_slim_infer.tar) |
|en_ppocr_mobile_v2.0_table_det|PubLayNet数据集训练的英文表格场景的文字检测|4.7M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_det_infer.tar) |
|en_ppocr_mobile_v2.0_table_rec|PubLayNet数据集训练的英文表格场景的文字识别|6.9M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_rec_infer.tar) |
|en_ppocr_mobile_v2.0_table_structure|PubLayNet数据集训练的英文表格场景的表格结构预测|18.6M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar) |
|ch_ppocr_mobile_slim_v2.0_det|slim裁剪版超轻量模型,支持中英文、多语种文本检测|2.6M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/slim/ch_ppocr_mobile_v2.0_det_prune_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/slim/ch_ppocr_mobile_v2.0_det_prune_infer.tar) |
|ch_ppocr_mobile_slim_v2.0_rec|slim裁剪量化版超轻量模型,支持中英文、数字识别|6M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_slim_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_slim_train.tar) |
|en_ppocr_mobile_v2.0_table_det|PubLayNet数据集训练的英文表格场景的文字检测|4.7M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_det_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.1/table/en_ppocr_mobile_v2.0_table_det_train.tar) |
|en_ppocr_mobile_v2.0_table_rec|PubLayNet数据集训练的英文表格场景的文字识别|6.9M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_rec_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.1/table/en_ppocr_mobile_v2.0_table_rec_train.tar) |
|en_ppocr_mobile_v2.0_table_structure|PubLayNet数据集训练的英文表格场景的表格结构预测|18.6M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.1/table/en_ppocr_mobile_v2.0_table_structure_train.tar) |
如需要使用其他模型,可以在 [model_list](../doc/doc_ch/models_list.md) 下载模型或者使用自己训练好的模型配置到`det_model_dir`,`rec_model_dir`,`table_model_dir`三个字段即可。
......@@ -4,9 +4,9 @@
[1.1 Requirements](#Requirements)
[1.2 Install PaddleDetection](#Install PaddleDetection)
[1.2 Install PaddleDetection](#Install_PaddleDetection)
[2. Data preparation](#Data preparation)
[2. Data preparation](#Data_reparation)
[3. Configuration](#Configuration)
......@@ -16,7 +16,7 @@
[6. Deployment](#Deployment)
[6.1 Export model](#Export model)
[6.1 Export model](#Export_model)
[6.2 Inference](#Inference)
......@@ -35,7 +35,7 @@
- CUDA >= 10.1
- cuDNN >= 7.6
<a name="Install PaddleDetection"></a>
<a name="Install_PaddleDetection"></a>
### 1.2 Install PaddleDetection
......@@ -51,7 +51,7 @@ pip install -r requirements.txt
For more installation tutorials, please refer to: [Install doc](https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.1/docs/tutorials/INSTALL_cn.md)
<a name="Data preparation"></a>
<a name="Data_preparation"></a>
## 2. Data preparation
......@@ -165,7 +165,7 @@ python tools/infer.py -c configs/ppyolo/ppyolov2_r50vd_dcn_365e_coco.yml --infer
Use your trained model in Layout Parser
<a name="Export model"></a>
<a name="Export_model"></a>
### 6.1 Export model
......
......@@ -41,7 +41,7 @@ wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_tab
wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar && tar xf en_ppocr_mobile_v2.0_table_structure_infer.tar
cd ..
# run
python3 table/predict_table.py --det_model_dir=inference/en_ppocr_mobile_v2.0_table_det_infer --rec_model_dir=inference/en_ppocr_mobile_v2.0_table_rec_infer --table_model_dir=inference/en_ppocr_mobile_v2.0_table_structure_infer --image_dir=../doc/table/table.jpg --rec_char_dict_path=../ppocr/utils/ppocr_keys_v1.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --rec_char_type=ch --det_limit_side_len=736 --det_limit_type=min --output ../output/table
python3 table/predict_table.py --det_model_dir=inference/en_ppocr_mobile_v2.0_table_det_infer --rec_model_dir=inference/en_ppocr_mobile_v2.0_table_rec_infer --table_model_dir=inference/en_ppocr_mobile_v2.0_table_structure_infer --image_dir=../doc/table/table.jpg --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --rec_char_type=EN --det_limit_side_len=736 --det_limit_type=min --output ../output/table
```
Note: The above model is trained on the PubLayNet dataset and only supports English scanning scenarios. If you need to identify other scenarios, you need to train the model yourself and replace the three fields `det_model_dir`, `rec_model_dir`, `table_model_dir`.
......
# 表格识别
* [1. 表格识别 pipeline](#1)
* [2. 性能](#2)
* [3. 使用](#3)
+ [3.1 快速开始](#31)
+ [3.2 训练](#32)
+ [3.3 评估](#33)
+ [3.4 预测](#34)
<a name="1"></a>
## 1. 表格识别 pipeline
表格识别主要包含三个模型
1. 单行文本检测-DB
2. 单行文本识别-CRNN
......@@ -17,6 +27,8 @@
3. 由单行文字的坐标、识别结果和单元格的坐标一起组合出单元格的识别结果。
4. 单元格的识别结果和表格结构一起构造表格的html字符串。
<a name="2"></a>
## 2. 性能
我们在 PubTabNet<sup>[1]</sup> 评估数据集上对算法进行了评估,性能如下
......@@ -26,8 +38,9 @@
| EDD<sup>[2]</sup> | 88.3 |
| Ours | 93.32 |
<a name="3"></a>
## 3. 使用
<a name="31"></a>
### 3.1 快速开始
```python
......@@ -43,12 +56,12 @@ wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_tab
wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar && tar xf en_ppocr_mobile_v2.0_table_structure_infer.tar
cd ..
# 执行预测
python3 table/predict_table.py --det_model_dir=inference/en_ppocr_mobile_v2.0_table_det_infer --rec_model_dir=inference/en_ppocr_mobile_v2.0_table_rec_infer --table_model_dir=inference/en_ppocr_mobile_v2.0_table_structure_infer --image_dir=../doc/table/table.jpg --rec_char_dict_path=../ppocr/utils/ppocr_keys_v1.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --rec_char_type=ch --det_limit_side_len=736 --det_limit_type=min --output ../output/table
python3 table/predict_table.py --det_model_dir=inference/en_ppocr_mobile_v2.0_table_det_infer --rec_model_dir=inference/en_ppocr_mobile_v2.0_table_rec_infer --table_model_dir=inference/en_ppocr_mobile_v2.0_table_structure_infer --image_dir=../doc/table/table.jpg --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --rec_char_type=EN --det_limit_side_len=736 --det_limit_type=min --output ../output/table
```
运行完成后,每张图片的excel表格会保存到output字段指定的目录下
note: 上述模型是在 PubLayNet 数据集上训练的表格识别模型,仅支持英文扫描场景,如需识别其他场景需要自己训练模型后替换 `det_model_dir`,`rec_model_dir`,`table_model_dir`三个字段即可。
<a name="32"></a>
### 3.2 训练
在这一章节中,我们仅介绍表格结构模型的训练,[文字检测](../../doc/doc_ch/detection.md)[文字识别](../../doc/doc_ch/recognition.md)的模型训练请参考对应的文档。
......@@ -75,7 +88,7 @@ python3 tools/train.py -c configs/table/table_mv3.yml -o Global.checkpoints=./yo
**注意**`Global.checkpoints`的优先级高于`Global.pretrain_weights`的优先级,即同时指定两个参数时,优先加载`Global.checkpoints`指定的模型,如果`Global.checkpoints`指定的模型路径有误,会加载`Global.pretrain_weights`指定的模型。
<a name="33"></a>
### 3.3 评估
表格使用 [TEDS(Tree-Edit-Distance-based Similarity)](https://github.com/ibm-aur-nlp/PubTabNet/tree/master/src) 作为模型的评估指标。在进行模型评估之前,需要将pipeline中的三个模型分别导出为inference模型(我们已经提供好),还需要准备评估的gt, gt示例如下:
......@@ -100,7 +113,7 @@ python3 table/eval_table.py --det_model_dir=path/to/det_model_dir --rec_model_di
```bash
teds: 93.32
```
<a name="34"></a>
### 3.4 预测
```python
......
shapely
scikit-image==0.17.2
scikit-image==0.18.3
imgaug==0.4.0
pyclipper
lmdb
......@@ -7,4 +7,9 @@ tqdm
numpy
visualdl
python-Levenshtein
opencv-contrib-python==4.4.0.46
\ No newline at end of file
opencv-contrib-python==4.4.0.46
cython
lxml
premailer
openpyxl
fasttext==0.9.1
\ No newline at end of file
# 介绍
test.sh和params.txt文件配合使用,完成OCR轻量检测和识别模型从训练到预测的流程测试。
# 安装依赖
- 安装PaddlePaddle >= 2.0
- 安装PaddleOCR依赖
```
pip3 install -r ../requirements.txt
```
- 安装autolog
```
git clone https://github.com/LDOUBLEV/AutoLog
cd AutoLog
pip3 install -r requirements.txt
python3 setup.py bdist_wheel
pip3 install ./dist/auto_log-1.0.0-py3-none-any.whl
cd ../
```
# 目录介绍
```bash
tests/
├── ocr_det_params.txt # 测试OCR检测模型的参数配置文件
├── ocr_rec_params.txt # 测试OCR识别模型的参数配置文件
└── prepare.sh # 完成test.sh运行所需要的数据和模型下载
└── test.sh # 根据
```
# 使用方法
test.sh包含四种运行模式,每种模式的运行数据不同,分别用于测试速度和精度,分别是:
- 模式1 lite_train_infer,使用少量数据训练,用于快速验证训练到预测的走通流程,不验证精度和速度;
```
bash test/prepare.sh ./tests/ocr_det_params.txt 'lite_train_infer'
bash tests/test.sh ./tests/ocr_det_params.txt 'lite_train_infer'
```
- 模式2 whole_infer,使用少量数据训练,一定量数据预测,用于验证训练后的模型执行预测,预测速度是否合理;
```
bash tests/prepare.sh ./tests/ocr_det_params.txt 'whole_infer'
bash tests/test.sh ./tests/ocr_det_params.txt 'whole_infer'
```
- 模式3 infer 不训练,全量数据预测,走通开源模型评估、动转静,检查inference model预测时间和精度;
```
bash tests/prepare.sh ./tests/ocr_det_params.txt 'infer'
用法1:
bash tests/test.sh ./tests/ocr_det_params.txt 'infer'
用法2: 指定GPU卡预测,第三个传入参数为GPU卡号
bash tests/test.sh ./tests/ocr_det_params.txt 'infer' '1'
```
模式4: whole_train_infer , CE: 全量数据训练,全量数据预测,验证模型训练精度,预测精度,预测速度
```
bash tests/prepare.sh ./tests/ocr_det_params.txt 'whole_train_infer'
bash tests/test.sh ./tests/ocr_det_params.txt 'whole_train_infer'
```
......@@ -27,7 +27,7 @@ from ppocr.data import build_dataloader
from ppocr.modeling.architectures import build_model
from ppocr.postprocess import build_post_process
from ppocr.metrics import build_metric
from ppocr.utils.save_load import init_model, load_pretrained_params
from ppocr.utils.save_load import init_model, load_dygraph_params
from ppocr.utils.utility import print_dict
import tools.program as program
......@@ -54,13 +54,13 @@ def main():
config['Architecture']["Head"]['out_channels'] = char_num
model = build_model(config['Architecture'])
use_srn = config['Architecture']['algorithm'] == "SRN"
extra_input = config['Architecture']['algorithm'] in ["SRN", "SAR"]
if "model_type" in config['Architecture'].keys():
model_type = config['Architecture']['model_type']
else:
model_type = None
best_model_dict = init_model(config, model)
best_model_dict = load_dygraph_params(config, model, logger, None)
if len(best_model_dict):
logger.info('metric in ckpt ***************')
for k, v in best_model_dict.items():
......@@ -71,7 +71,7 @@ def main():
# start eval
metric = program.eval(model, valid_dataloader, post_process_class,
eval_class, model_type, use_srn)
eval_class, model_type, extra_input)
logger.info('metric eval ***************')
for k, v in metric.items():
logger.info('{}:{}'.format(k, v))
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import sys
import pickle
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
from ppocr.data import build_dataloader
from ppocr.modeling.architectures import build_model
from ppocr.postprocess import build_post_process
from ppocr.utils.save_load import init_model, load_dygraph_params
from ppocr.utils.utility import print_dict
import tools.program as program
def main():
global_config = config['Global']
# build dataloader
config['Eval']['dataset']['name'] = config['Train']['dataset']['name']
config['Eval']['dataset']['data_dir'] = config['Train']['dataset'][
'data_dir']
config['Eval']['dataset']['label_file_list'] = config['Train']['dataset'][
'label_file_list']
eval_dataloader = build_dataloader(config, 'Eval', device, logger)
# build post process
post_process_class = build_post_process(config['PostProcess'],
global_config)
# build model
# for rec algorithm
if hasattr(post_process_class, 'character'):
char_num = len(getattr(post_process_class, 'character'))
config['Architecture']["Head"]['out_channels'] = char_num
#set return_features = True
config['Architecture']["Head"]["return_feats"] = True
model = build_model(config['Architecture'])
best_model_dict = load_dygraph_params(config, model, logger, None)
if len(best_model_dict):
logger.info('metric in ckpt ***************')
for k, v in best_model_dict.items():
logger.info('{}:{}'.format(k, v))
# get features from train data
char_center = program.get_center(model, eval_dataloader, post_process_class)
#serialize to disk
with open("train_center.pkl", 'wb') as f:
pickle.dump(char_center, f)
return
if __name__ == '__main__':
config, device, logger, vdl_writer = program.preprocess()
main()
......@@ -49,6 +49,12 @@ def export_single_model(model, arch_config, save_path, logger):
]
]
model = to_static(model, input_spec=other_shape)
elif arch_config["algorithm"] == "SAR":
other_shape = [
paddle.static.InputSpec(
shape=[None, 3, 48, 160], dtype="float32"),
]
model = to_static(model, input_spec=other_shape)
else:
infer_shape = [3, -1, -1]
if arch_config["model_type"] == "rec":
......@@ -60,6 +66,8 @@ def export_single_model(model, arch_config, save_path, logger):
"When there is tps in the network, variable length input is not supported, and the input size needs to be the same as during training"
)
infer_shape[-1] = 100
if arch_config["algorithm"] == "NRTR":
infer_shape = [1, 32, 100]
elif arch_config["model_type"] == "table":
infer_shape = [3, 488, 488]
model = to_static(
......@@ -93,6 +101,9 @@ def main():
for key in config["Architecture"]["Models"]:
config["Architecture"]["Models"][key]["Head"][
"out_channels"] = char_num
# just one final tensor needs to to exported for inference
config["Architecture"]["Models"][key][
"return_all_feats"] = False
else: # base rec model
config["Architecture"]["Head"]["out_channels"] = char_num
model = build_model(config["Architecture"])
......
......@@ -131,14 +131,9 @@ def main(args):
img_list.append(img)
try:
img_list, cls_res, predict_time = text_classifier(img_list)
except:
except Exception as E:
logger.info(traceback.format_exc())
logger.info(
"ERROR!!!! \n"
"Please read the FAQ:https://github.com/PaddlePaddle/PaddleOCR#faq \n"
"If your model has tps module: "
"TPS does not support variable shape.\n"
"Please set --rec_image_shape='3,32,100' and --rec_char_type='en' ")
logger.info(E)
exit()
for ino in range(len(img_list)):
logger.info("Predicts of {}:{}".format(valid_image_file_list[ino],
......
......@@ -30,7 +30,7 @@ from ppocr.utils.logging import get_logger
from ppocr.utils.utility import get_image_file_list, check_and_read_gif
from ppocr.data import create_operators, transform
from ppocr.postprocess import build_post_process
import json
logger = get_logger()
......@@ -89,6 +89,14 @@ class TextDetector(object):
postprocess_params["sample_pts_num"] = 2
postprocess_params["expand_scale"] = 1.0
postprocess_params["shrink_ratio_of_width"] = 0.3
elif self.det_algorithm == "PSE":
postprocess_params['name'] = 'PSEPostProcess'
postprocess_params["thresh"] = args.det_pse_thresh
postprocess_params["box_thresh"] = args.det_pse_box_thresh
postprocess_params["min_area"] = args.det_pse_min_area
postprocess_params["box_type"] = args.det_pse_box_type
postprocess_params["scale"] = args.det_pse_scale
self.det_pse_box_type = args.det_pse_box_type
else:
logger.info("unknown det_algorithm:{}".format(self.det_algorithm))
sys.exit(0)
......@@ -209,7 +217,7 @@ class TextDetector(object):
preds['f_score'] = outputs[1]
preds['f_tco'] = outputs[2]
preds['f_tvo'] = outputs[3]
elif self.det_algorithm == 'DB':
elif self.det_algorithm in ['DB', 'PSE']:
preds['maps'] = outputs[0]
else:
raise NotImplementedError
......@@ -217,7 +225,9 @@ class TextDetector(object):
#self.predictor.try_shrink_memory()
post_result = self.postprocess_op(preds, shape_list)
dt_boxes = post_result[0]['points']
if self.det_algorithm == "SAST" and self.det_sast_polygon:
if (self.det_algorithm == "SAST" and
self.det_sast_polygon) or (self.det_algorithm == "PSE" and
self.det_pse_box_type == 'poly'):
dt_boxes = self.filter_tag_det_res_only_clip(dt_boxes, ori_im.shape)
else:
dt_boxes = self.filter_tag_det_res(dt_boxes, ori_im.shape)
......@@ -243,6 +253,7 @@ if __name__ == "__main__":
if not os.path.exists(draw_img_save):
os.makedirs(draw_img_save)
save_results = []
for image_file in image_file_list:
img, flag = check_and_read_gif(image_file)
if not flag:
......@@ -256,8 +267,11 @@ if __name__ == "__main__":
if count > 0:
total_time += elapse
count += 1
logger.info("Predict time of {}: {}".format(image_file, elapse))
save_pred = os.path.basename(image_file) + "\t" + str(
json.dumps(np.array(dt_boxes).astype(np.int32).tolist())) + "\n"
save_results.append(save_pred)
logger.info(save_pred)
logger.info("The predict time of {}: {}".format(image_file, elapse))
src_im = utility.draw_text_det_res(dt_boxes, image_file)
img_name_pure = os.path.split(image_file)[-1]
img_path = os.path.join(draw_img_save,
......@@ -265,5 +279,8 @@ if __name__ == "__main__":
cv2.imwrite(img_path, src_im)
logger.info("The visualized image saved in {}".format(img_path))
with open(os.path.join(draw_img_save, "det_results.txt"), 'w') as f:
f.writelines(save_results)
f.close()
if args.benchmark:
text_detector.autolog.report()
......@@ -74,7 +74,7 @@ class TextE2E(object):
self.preprocess_op = create_operators(pre_process_list)
self.postprocess_op = build_post_process(postprocess_params)
self.predictor, self.input_tensor, self.output_tensors = utility.create_predictor(
self.predictor, self.input_tensor, self.output_tensors, _ = utility.create_predictor(
args, 'e2e', logger) # paddle.jit.load(args.det_model_dir)
# self.predictor.eval()
......
......@@ -13,7 +13,7 @@
# limitations under the License.
import os
import sys
from PIL import Image
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
......@@ -38,26 +38,34 @@ logger = get_logger()
class TextRecognizer(object):
def __init__(self, args):
self.rec_image_shape = [int(v) for v in args.rec_image_shape.split(",")]
self.character_type = args.rec_char_type
self.rec_batch_num = args.rec_batch_num
self.rec_algorithm = args.rec_algorithm
postprocess_params = {
'name': 'CTCLabelDecode',
"character_type": args.rec_char_type,
"character_dict_path": args.rec_char_dict_path,
"use_space_char": args.use_space_char
}
if self.rec_algorithm == "SRN":
postprocess_params = {
'name': 'SRNLabelDecode',
"character_type": args.rec_char_type,
"character_dict_path": args.rec_char_dict_path,
"use_space_char": args.use_space_char
}
elif self.rec_algorithm == "RARE":
postprocess_params = {
'name': 'AttnLabelDecode',
"character_type": args.rec_char_type,
"character_dict_path": args.rec_char_dict_path,
"use_space_char": args.use_space_char
}
elif self.rec_algorithm == 'NRTR':
postprocess_params = {
'name': 'NRTRLabelDecode',
"character_dict_path": args.rec_char_dict_path,
"use_space_char": args.use_space_char
}
elif self.rec_algorithm == "SAR":
postprocess_params = {
'name': 'SARLabelDecode',
"character_dict_path": args.rec_char_dict_path,
"use_space_char": args.use_space_char
}
......@@ -87,9 +95,19 @@ class TextRecognizer(object):
def resize_norm_img(self, img, max_wh_ratio):
imgC, imgH, imgW = self.rec_image_shape
if self.rec_algorithm == 'NRTR':
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
# return padding_im
image_pil = Image.fromarray(np.uint8(img))
img = image_pil.resize([100, 32], Image.ANTIALIAS)
img = np.array(img)
norm_img = np.expand_dims(img, -1)
norm_img = norm_img.transpose((2, 0, 1))
return norm_img.astype(np.float32) / 128. - 1.
assert imgC == img.shape[2]
if self.character_type == "ch":
imgW = int((32 * max_wh_ratio))
max_wh_ratio = max(max_wh_ratio, imgW / imgH)
imgW = int((32 * max_wh_ratio))
h, w = img.shape[:2]
ratio = w / float(h)
if math.ceil(imgH * ratio) > imgW:
......@@ -177,6 +195,41 @@ class TextRecognizer(object):
return (norm_img, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
gsrm_slf_attn_bias2)
def resize_norm_img_sar(self, img, image_shape,
width_downsample_ratio=0.25):
imgC, imgH, imgW_min, imgW_max = image_shape
h = img.shape[0]
w = img.shape[1]
valid_ratio = 1.0
# make sure new_width is an integral multiple of width_divisor.
width_divisor = int(1 / width_downsample_ratio)
# resize
ratio = w / float(h)
resize_w = math.ceil(imgH * ratio)
if resize_w % width_divisor != 0:
resize_w = round(resize_w / width_divisor) * width_divisor
if imgW_min is not None:
resize_w = max(imgW_min, resize_w)
if imgW_max is not None:
valid_ratio = min(1.0, 1.0 * resize_w / imgW_max)
resize_w = min(imgW_max, resize_w)
resized_image = cv2.resize(img, (resize_w, imgH))
resized_image = resized_image.astype('float32')
# norm
if image_shape[0] == 1:
resized_image = resized_image / 255
resized_image = resized_image[np.newaxis, :]
else:
resized_image = resized_image.transpose((2, 0, 1)) / 255
resized_image -= 0.5
resized_image /= 0.5
resize_shape = resized_image.shape
padding_im = -1.0 * np.ones((imgC, imgH, imgW_max), dtype=np.float32)
padding_im[:, :, 0:resize_w] = resized_image
pad_shape = padding_im.shape
return padding_im, resize_shape, pad_shape, valid_ratio
def __call__(self, img_list):
img_num = len(img_list)
# Calculate the aspect ratio of all text bars
......@@ -199,11 +252,19 @@ class TextRecognizer(object):
wh_ratio = w * 1.0 / h
max_wh_ratio = max(max_wh_ratio, wh_ratio)
for ino in range(beg_img_no, end_img_no):
if self.rec_algorithm != "SRN":
if self.rec_algorithm != "SRN" and self.rec_algorithm != "SAR":
norm_img = self.resize_norm_img(img_list[indices[ino]],
max_wh_ratio)
norm_img = norm_img[np.newaxis, :]
norm_img_batch.append(norm_img)
elif self.rec_algorithm == "SAR":
norm_img, _, _, valid_ratio = self.resize_norm_img_sar(
img_list[indices[ino]], self.rec_image_shape)
norm_img = norm_img[np.newaxis, :]
valid_ratio = np.expand_dims(valid_ratio, axis=0)
valid_ratios = []
valid_ratios.append(valid_ratio)
norm_img_batch.append(norm_img)
else:
norm_img = self.process_image_srn(
img_list[indices[ino]], self.rec_image_shape, 8, 25)
......@@ -249,17 +310,38 @@ class TextRecognizer(object):
if self.benchmark:
self.autolog.times.stamp()
preds = {"predict": outputs[2]}
elif self.rec_algorithm == "SAR":
valid_ratios = np.concatenate(valid_ratios)
inputs = [
norm_img_batch,
valid_ratios,
]
input_names = self.predictor.get_input_names()
for i in range(len(input_names)):
input_tensor = self.predictor.get_input_handle(input_names[
i])
input_tensor.copy_from_cpu(inputs[i])
self.predictor.run()
outputs = []
for output_tensor in self.output_tensors:
output = output_tensor.copy_to_cpu()
outputs.append(output)
if self.benchmark:
self.autolog.times.stamp()
preds = outputs[0]
else:
self.input_tensor.copy_from_cpu(norm_img_batch)
self.predictor.run()
outputs = []
for output_tensor in self.output_tensors:
output = output_tensor.copy_to_cpu()
outputs.append(output)
if self.benchmark:
self.autolog.times.stamp()
preds = outputs[0]
if len(outputs) != 1:
preds = outputs
else:
preds = outputs[0]
rec_result = self.postprocess_op(preds)
for rno in range(len(rec_result)):
rec_res[indices[beg_img_no + rno]] = rec_result[rno]
......@@ -278,7 +360,7 @@ def main(args):
if args.warmup:
img = np.random.uniform(0, 255, [32, 320, 3]).astype(np.uint8)
for i in range(2):
res = text_recognizer([img])
res = text_recognizer([img] * int(args.rec_batch_num))
for image_file in image_file_list:
img, flag = check_and_read_gif(image_file)
......
......@@ -173,6 +173,9 @@ def main(args):
logger.info("The predict total time is {}".format(time.time() - _st))
logger.info("\nThe predict total time is {}".format(total_time))
if args.benchmark:
text_sys.text_detector.autolog.report()
text_sys.text_recognizer.autolog.report()
if __name__ == "__main__":
......
......@@ -35,7 +35,7 @@ def init_args():
parser.add_argument("--use_gpu", type=str2bool, default=True)
parser.add_argument("--ir_optim", type=str2bool, default=True)
parser.add_argument("--use_tensorrt", type=str2bool, default=False)
parser.add_argument("--min_subgraph_size", type=int, default=10)
parser.add_argument("--min_subgraph_size", type=int, default=15)
parser.add_argument("--precision", type=str, default="fp32")
parser.add_argument("--gpu_mem", type=int, default=500)
......@@ -63,11 +63,17 @@ def init_args():
parser.add_argument("--det_sast_nms_thresh", type=float, default=0.2)
parser.add_argument("--det_sast_polygon", type=str2bool, default=False)
# PSE parmas
parser.add_argument("--det_pse_thresh", type=float, default=0)
parser.add_argument("--det_pse_box_thresh", type=float, default=0.85)
parser.add_argument("--det_pse_min_area", type=float, default=16)
parser.add_argument("--det_pse_box_type", type=str, default='box')
parser.add_argument("--det_pse_scale", type=int, default=1)
# params for text recognizer
parser.add_argument("--rec_algorithm", type=str, default='CRNN')
parser.add_argument("--rec_model_dir", type=str)
parser.add_argument("--rec_image_shape", type=str, default="3, 32, 320")
parser.add_argument("--rec_char_type", type=str, default='ch')
parser.add_argument("--rec_batch_num", type=int, default=6)
parser.add_argument("--max_text_length", type=int, default=25)
parser.add_argument(
......@@ -236,11 +242,11 @@ def create_predictor(args, mode, logger):
max_input_shape.update(max_pact_shape)
opt_input_shape.update(opt_pact_shape)
elif mode == "rec":
min_input_shape = {"x": [args.rec_batch_num, 3, 32, 10]}
min_input_shape = {"x": [1, 3, 32, 10]}
max_input_shape = {"x": [args.rec_batch_num, 3, 32, 2000]}
opt_input_shape = {"x": [args.rec_batch_num, 3, 32, 320]}
elif mode == "cls":
min_input_shape = {"x": [args.rec_batch_num, 3, 48, 10]}
min_input_shape = {"x": [1, 3, 48, 10]}
max_input_shape = {"x": [args.rec_batch_num, 3, 48, 2000]}
opt_input_shape = {"x": [args.rec_batch_num, 3, 48, 320]}
else:
......@@ -261,10 +267,11 @@ def create_predictor(args, mode, logger):
# cache 10 different shapes for mkldnn to avoid memory leak
config.set_mkldnn_cache_capacity(10)
config.enable_mkldnn()
if args.precision == "fp16":
config.enable_mkldnn_bfloat16()
# enable memory optim
config.enable_memory_optim()
#config.disable_glog_info()
config.disable_glog_info()
config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass")
if mode == 'table':
......
......@@ -34,23 +34,21 @@ import paddle
from ppocr.data import create_operators, transform
from ppocr.modeling.architectures import build_model
from ppocr.postprocess import build_post_process
from ppocr.utils.save_load import init_model
from ppocr.utils.save_load import init_model, load_dygraph_params
from ppocr.utils.utility import get_image_file_list
import tools.program as program
def draw_det_res(dt_boxes, config, img, img_name):
def draw_det_res(dt_boxes, config, img, img_name, save_path):
if len(dt_boxes) > 0:
import cv2
src_im = img
for box in dt_boxes:
box = box.astype(np.int32).reshape((-1, 1, 2))
cv2.polylines(src_im, [box], True, color=(255, 255, 0), thickness=2)
save_det_path = os.path.dirname(config['Global'][
'save_res_path']) + "/det_results/"
if not os.path.exists(save_det_path):
os.makedirs(save_det_path)
save_path = os.path.join(save_det_path, os.path.basename(img_name))
if not os.path.exists(save_path):
os.makedirs(save_path)
save_path = os.path.join(save_path, os.path.basename(img_name))
cv2.imwrite(save_path, src_im)
logger.info("The detected Image saved in {}".format(save_path))
......@@ -61,8 +59,7 @@ def main():
# build model
model = build_model(config['Architecture'])
init_model(config, model)
_ = load_dygraph_params(config, model, logger, None)
# build post process
post_process_class = build_post_process(config['PostProcess'])
......@@ -96,17 +93,41 @@ def main():
images = paddle.to_tensor(images)
preds = model(images)
post_result = post_process_class(preds, shape_list)
boxes = post_result[0]['points']
# write result
src_img = cv2.imread(file)
dt_boxes_json = []
for box in boxes:
tmp_json = {"transcription": ""}
tmp_json['points'] = box.tolist()
dt_boxes_json.append(tmp_json)
# parser boxes if post_result is dict
if isinstance(post_result, dict):
det_box_json = {}
for k in post_result.keys():
boxes = post_result[k][0]['points']
dt_boxes_list = []
for box in boxes:
tmp_json = {"transcription": ""}
tmp_json['points'] = box.tolist()
dt_boxes_list.append(tmp_json)
det_box_json[k] = dt_boxes_list
save_det_path = os.path.dirname(config['Global'][
'save_res_path']) + "/det_results_{}/".format(k)
draw_det_res(boxes, config, src_img, file, save_det_path)
else:
boxes = post_result[0]['points']
dt_boxes_json = []
# write result
for box in boxes:
tmp_json = {"transcription": ""}
tmp_json['points'] = box.tolist()
dt_boxes_json.append(tmp_json)
save_det_path = os.path.dirname(config['Global'][
'save_res_path']) + "/det_results/"
draw_det_res(boxes, config, src_img, file, save_det_path)
otstr = file + "\t" + json.dumps(dt_boxes_json) + "\n"
fout.write(otstr.encode())
src_img = cv2.imread(file)
draw_det_res(boxes, config, src_img, file)
save_det_path = os.path.dirname(config['Global'][
'save_res_path']) + "/det_results/"
draw_det_res(boxes, config, src_img, file, save_det_path)
logger.info("success!")
......
......@@ -74,6 +74,10 @@ def main():
'image', 'encoder_word_pos', 'gsrm_word_pos',
'gsrm_slf_attn_bias1', 'gsrm_slf_attn_bias2'
]
elif config['Architecture']['algorithm'] == "SAR":
op[op_name]['keep_keys'] = [
'image', 'valid_ratio'
]
else:
op[op_name]['keep_keys'] = ['image']
transforms.append(op)
......@@ -106,11 +110,16 @@ def main():
paddle.to_tensor(gsrm_slf_attn_bias1_list),
paddle.to_tensor(gsrm_slf_attn_bias2_list)
]
if config['Architecture']['algorithm'] == "SAR":
valid_ratio = np.expand_dims(batch[-1], axis=0)
img_metas = [paddle.to_tensor(valid_ratio)]
images = np.expand_dims(batch[0], axis=0)
images = paddle.to_tensor(images)
if config['Architecture']['algorithm'] == "SRN":
preds = model(images, others)
elif config['Architecture']['algorithm'] == "SAR":
preds = model(images, img_metas)
else:
preds = model(images)
post_result = post_process_class(preds)
......@@ -121,7 +130,7 @@ def main():
if len(post_result[key][0]) >= 2:
rec_info[key] = {
"label": post_result[key][0][0],
"score": post_result[key][0][1],
"score": float(post_result[key][0][1]),
}
info = json.dumps(rec_info)
else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment