Commit 76320bf0 authored by littletomatodonkey's avatar littletomatodonkey
Browse files

Merge branch 'dygraph' of https://github.com/PaddlePaddle/PaddleOCR into dev/add_thread_pred

parents e19bedf5 824ceca6
...@@ -14,12 +14,13 @@ Global: ...@@ -14,12 +14,13 @@ Global:
load_static_weights: True load_static_weights: True
cal_metric_during_train: False cal_metric_during_train: False
pretrained_model: ./pretrain_models/ResNet50_vd_ssld_pretrained/ pretrained_model: ./pretrain_models/ResNet50_vd_ssld_pretrained/
checkpoints: checkpoints:
save_inference_dir: save_inference_dir:
use_visualdl: False use_visualdl: False
infer_img: infer_img:
save_res_path: ./output/sast_r50_vd_ic15/predicts_sast.txt save_res_path: ./output/sast_r50_vd_ic15/predicts_sast.txt
Architecture: Architecture:
model_type: det model_type: det
algorithm: SAST algorithm: SAST
......
Global:
use_gpu: True
epoch_num: 600
log_smooth_window: 20
print_batch_step: 10
save_model_dir: ./output/pgnet_r50_vd_totaltext/
save_epoch_step: 10
# evaluation is run every 0 iterationss after the 1000th iteration
eval_batch_step: [ 0, 1000 ]
# 1. If pretrained_model is saved in static mode, such as classification pretrained model
# from static branch, load_static_weights must be set as True.
# 2. If you want to finetune the pretrained models we provide in the docs,
# you should set load_static_weights as False.
load_static_weights: False
cal_metric_during_train: False
pretrained_model:
checkpoints:
save_inference_dir:
use_visualdl: False
infer_img:
valid_set: totaltext # two mode: totaltext valid curved words, partvgg valid non-curved words
save_res_path: ./output/pgnet_r50_vd_totaltext/predicts_pgnet.txt
character_dict_path: ppocr/utils/ic15_dict.txt
character_type: EN
max_text_length: 50 # the max length in seq
max_text_nums: 30 # the max seq nums in a pic
tcl_len: 64
Architecture:
model_type: e2e
algorithm: PGNet
Transform:
Backbone:
name: ResNet
layers: 50
Neck:
name: PGFPN
Head:
name: PGHead
Loss:
name: PGLoss
tcl_bs: 64
max_text_length: 50 # the same as Global: max_text_length
max_text_nums: 30 # the same as Global:max_text_nums
pad_num: 36 # the length of dict for pad
Optimizer:
name: Adam
beta1: 0.9
beta2: 0.999
lr:
learning_rate: 0.001
regularizer:
name: 'L2'
factor: 0
PostProcess:
name: PGPostProcess
score_thresh: 0.5
Metric:
name: E2EMetric
character_dict_path: ppocr/utils/ic15_dict.txt
main_indicator: f_score_e2e
Train:
dataset:
name: PGDataSet
label_file_list: [.././train_data/total_text/train/]
ratio_list: [1.0]
data_format: icdar #two data format: icdar/textnet
transforms:
- DecodeImage: # load image
img_mode: BGR
channel_first: False
- PGProcessTrain:
batch_size: 14 # same as loader: batch_size_per_card
min_crop_size: 24
min_text_size: 4
max_text_size: 512
- KeepKeys:
keep_keys: [ 'images', 'tcl_maps', 'tcl_label_maps', 'border_maps','direction_maps', 'training_masks', 'label_list', 'pos_list', 'pos_mask' ] # dataloader will return list in this order
loader:
shuffle: True
drop_last: True
batch_size_per_card: 14
num_workers: 16
Eval:
dataset:
name: PGDataSet
data_dir: ./train_data/
label_file_list: [./train_data/total_text/test/]
transforms:
- DecodeImage: # load image
img_mode: RGB
channel_first: False
- E2ELabelEncode:
- E2EResizeForTest:
max_side_len: 768
- NormalizeImage:
scale: 1./255.
mean: [ 0.485, 0.456, 0.406 ]
std: [ 0.229, 0.224, 0.225 ]
order: 'hwc'
- ToCHWImage:
- KeepKeys:
keep_keys: [ 'image', 'shape', 'polys', 'strs', 'tags' ]
loader:
shuffle: False
drop_last: False
batch_size_per_card: 1 # must be 1
num_workers: 2
\ No newline at end of file
...@@ -12,7 +12,8 @@ inference 模型(`paddle.jit.save`保存的模型) ...@@ -12,7 +12,8 @@ inference 模型(`paddle.jit.save`保存的模型)
- [一、训练模型转inference模型](#训练模型转inference模型) - [一、训练模型转inference模型](#训练模型转inference模型)
- [检测模型转inference模型](#检测模型转inference模型) - [检测模型转inference模型](#检测模型转inference模型)
- [识别模型转inference模型](#识别模型转inference模型) - [识别模型转inference模型](#识别模型转inference模型)
- [方向分类模型转inference模型](#方向分类模型转inference模型) - [方向分类模型转inference模型](#方向分类模型转inference模型)
- [端到端模型转inference模型](#端到端模型转inference模型)
- [二、文本检测模型推理](#文本检测模型推理) - [二、文本检测模型推理](#文本检测模型推理)
- [1. 超轻量中文检测模型推理](#超轻量中文检测模型推理) - [1. 超轻量中文检测模型推理](#超轻量中文检测模型推理)
...@@ -27,10 +28,13 @@ inference 模型(`paddle.jit.save`保存的模型) ...@@ -27,10 +28,13 @@ inference 模型(`paddle.jit.save`保存的模型)
- [4. 自定义文本识别字典的推理](#自定义文本识别字典的推理) - [4. 自定义文本识别字典的推理](#自定义文本识别字典的推理)
- [5. 多语言模型的推理](#多语言模型的推理) - [5. 多语言模型的推理](#多语言模型的推理)
- [四、方向分类模型推理](#方向识别模型推理) - [四、端到端模型推理](#端到端模型推理)
- [1. PGNet端到端模型推理](#PGNet端到端模型推理)
- [五、方向分类模型推理](#方向识别模型推理)
- [1. 方向分类模型推理](#方向分类模型推理) - [1. 方向分类模型推理](#方向分类模型推理)
- [、文本检测、方向分类和文字识别串联推理](#文本检测、方向分类和文字识别串联推理) - [、文本检测、方向分类和文字识别串联推理](#文本检测、方向分类和文字识别串联推理)
- [1. 超轻量中文OCR模型推理](#超轻量中文OCR模型推理) - [1. 超轻量中文OCR模型推理](#超轻量中文OCR模型推理)
- [2. 其他模型推理](#其他模型推理) - [2. 其他模型推理](#其他模型推理)
...@@ -118,6 +122,32 @@ python3 tools/export_model.py -c configs/cls/cls_mv3.yml -o Global.pretrained_mo ...@@ -118,6 +122,32 @@ python3 tools/export_model.py -c configs/cls/cls_mv3.yml -o Global.pretrained_mo
├── inference.pdiparams.info # 分类inference模型的参数信息,可忽略 ├── inference.pdiparams.info # 分类inference模型的参数信息,可忽略
└── inference.pdmodel # 分类inference模型的program文件 └── inference.pdmodel # 分类inference模型的program文件
``` ```
<a name="端到端模型转inference模型"></a>
### 端到端模型转inference模型
下载端到端模型:
```
wget -P ./ch_lite/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_train.tar && tar xf ./ch_lite/ch_ppocr_mobile_v2.0_cls_train.tar -C ./ch_lite/
```
端到端模型转inference模型与检测的方式相同,如下:
```
# -c 后面设置训练算法的yml配置文件
# -o 配置可选参数
# Global.pretrained_model 参数设置待转换的训练模型地址,不用添加文件后缀 .pdmodel,.pdopt或.pdparams。
# Global.load_static_weights 参数需要设置为 False。
# Global.save_inference_dir参数设置转换的模型将保存的地址。
python3 tools/export_model.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.pretrained_model=./ch_lite/ch_ppocr_mobile_v2.0_cls_train/best_accuracy Global.load_static_weights=False Global.save_inference_dir=./inference/e2e/
```
转换成功后,在目录下有三个文件:
```
/inference/e2e/
├── inference.pdiparams # 分类inference模型的参数文件
├── inference.pdiparams.info # 分类inference模型的参数信息,可忽略
└── inference.pdmodel # 分类inference模型的program文件
```
<a name="文本检测模型推理"></a> <a name="文本检测模型推理"></a>
## 二、文本检测模型推理 ## 二、文本检测模型推理
...@@ -332,8 +362,38 @@ python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/korean/1.jpg" - ...@@ -332,8 +362,38 @@ python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/korean/1.jpg" -
Predicts of ./doc/imgs_words/korean/1.jpg:('바탕으로', 0.9948904) Predicts of ./doc/imgs_words/korean/1.jpg:('바탕으로', 0.9948904)
``` ```
<a name="端到端模型推理"></a>
## 四、端到端模型推理
端到端模型推理,默认使用PGNet模型的配置参数。当不使用PGNet模型时,在推理时,需要通过传入相应的参数进行算法适配,细节参考下文。
<a name="PGNet端到端模型推理"></a>
### 1. PGNet端到端模型推理
#### (1). 四边形文本检测模型(ICDAR2015)
首先将PGNet端到端训练过程中保存的模型,转换成inference model。以基于Resnet50_vd骨干网络,在ICDAR2015英文数据集训练的模型为例([模型下载地址](https://paddleocr.bj.bcebos.com/dygraph_v2.0/pgnet/en_server_pgnetA.tar)),可以使用如下命令进行转换:
```
python3 tools/export_model.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.pretrained_model=./en_server_pgnetA/iter_epoch_450 Global.load_static_weights=False Global.save_inference_dir=./inference/e2e
```
**PGNet端到端模型推理,需要设置参数`--e2e_algorithm="PGNet"`**,可以执行如下命令:
```
python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/imgs_en/img_10.jpg" --e2e_model_dir="./inference/e2e/" --e2e_pgnet_polygon=False
```
可视化文本检测结果默认保存到`./inference_results`文件夹里面,结果文件的名称前缀为'e2e_res'。结果示例如下:
![](../imgs_results/e2e_res_img_10_pgnet.jpg)
#### (2). 弯曲文本检测模型(Total-Text)
和四边形文本检测模型共用一个推理模型
**PGNet端到端模型推理,需要设置参数`--e2e_algorithm="PGNet"`,同时,还需要增加参数`--e2e_pgnet_polygon=True`,**可以执行如下命令:
```
python3.7 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/imgs_en/img623.jpg" --e2e_model_dir="./inference/e2e/" --e2e_pgnet_polygon=True
```
可视化文本端到端结果默认保存到`./inference_results`文件夹里面,结果文件的名称前缀为'e2e_res'。结果示例如下:
![](../imgs_results/e2e_res_img623_pgnet.jpg)
<a name="方向分类模型推理"></a> <a name="方向分类模型推理"></a>
## 、方向分类模型推理 ## 、方向分类模型推理
下面将介绍方向分类模型推理。 下面将介绍方向分类模型推理。
...@@ -358,7 +418,7 @@ Predicts of ./doc/imgs_words/ch/word_4.jpg:['0', 0.9999982] ...@@ -358,7 +418,7 @@ Predicts of ./doc/imgs_words/ch/word_4.jpg:['0', 0.9999982]
``` ```
<a name="文本检测、方向分类和文字识别串联推理"></a> <a name="文本检测、方向分类和文字识别串联推理"></a>
## 、文本检测、方向分类和文字识别串联推理 ## 、文本检测、方向分类和文字识别串联推理
<a name="超轻量中文OCR模型推理"></a> <a name="超轻量中文OCR模型推理"></a>
### 1. 超轻量中文OCR模型推理 ### 1. 超轻量中文OCR模型推理
......
# 多语言模型
**近期更新**
- 2021.4.9 支持**80种**语言的检测和识别
- 2021.4.9 支持**轻量高精度**英文模型检测识别
- [1 安装](#安装)
- [1.1 paddle 安装](#paddle安装)
- [1.2 paddleocr package 安装](#paddleocr_package_安装)
- [2 快速使用](#快速使用)
- [2.1 命令行运行](#命令行运行)
- [2.1.1 整图预测](#bash_检测+识别)
- [2.1.2 识别预测](#bash_识别)
- [2.1.3 检测预测](#bash_检测)
- [2.2 python 脚本运行](#python_脚本运行)
- [2.2.1 整图预测](#python_检测+识别)
- [2.2.2 识别预测](#python_识别)
- [2.2.3 检测预测](#python_检测)
- [3 自定义训练](#自定义训练)
- [4 支持语种及缩写](#语种缩写)
<a name="安装"></a>
## 1 安装
<a name="paddle安装"></a>
### 1.1 paddle 安装
```
# cpu
pip install paddlepaddle
# gpu
pip instll paddlepaddle-gpu
```
<a name="paddleocr_package_安装"></a>
### 1.2 paddleocr package 安装
pip 安装
```
pip install "paddleocr>=2.0.4" # 推荐使用2.0.4版本
```
本地构建并安装
```
python3 setup.py bdist_wheel
pip3 install dist/paddleocr-x.x.x-py3-none-any.whl # x.x.x是paddleocr的版本号
```
<a name="快速使用"></a>
## 2 快速使用
<a name="命令行运行"></a>
### 2.1 命令行运行
查看帮助信息
```
paddleocr -h
```
* 整图预测(检测+识别)
Paddleocr目前支持80个语种,可以通过修改--lang参数进行切换,具体支持的[语种](#语种缩写)可查看表格。
``` bash
paddleocr --image_dir doc/imgs/japan_2.jpg --lang=japan
```
![](https://raw.githubusercontent.com/PaddlePaddle/PaddleOCR/release/2.0/doc/imgs/japan_2.jpg)
结果是一个list,每个item包含了文本框,文字和识别置信度
```text
[[[671.0, 60.0], [847.0, 63.0], [847.0, 104.0], [671.0, 102.0]], ('もちもち', 0.9993342)]
[[[394.0, 82.0], [536.0, 77.0], [538.0, 127.0], [396.0, 132.0]], ('天然の', 0.9919842)]
[[[880.0, 89.0], [1014.0, 93.0], [1013.0, 127.0], [879.0, 124.0]], ('とろっと', 0.9976762)]
[[[1067.0, 101.0], [1294.0, 101.0], [1294.0, 138.0], [1067.0, 138.0]], ('後味のよい', 0.9988712)]
......
```
* 识别预测
```bash
paddleocr --image_dir doc/imgs_words/japan/1.jpg --det false --lang=japan
```
![](https://raw.githubusercontent.com/PaddlePaddle/PaddleOCR/release/2.0/doc/imgs_words/japan/1.jpg)
结果是一个tuple,返回识别结果和识别置信度
```text
('したがって', 0.99965394)
```
* 检测预测
```
paddleocr --image_dir PaddleOCR/doc/imgs/11.jpg --rec false
```
结果是一个list,每个item只包含文本框
```
[[26.0, 457.0], [137.0, 457.0], [137.0, 477.0], [26.0, 477.0]]
[[25.0, 425.0], [372.0, 425.0], [372.0, 448.0], [25.0, 448.0]]
[[128.0, 397.0], [273.0, 397.0], [273.0, 414.0], [128.0, 414.0]]
......
```
<a name="python_脚本运行"></a>
### 2.2 python 脚本运行
ppocr 也支持在python脚本中运行,便于嵌入到您自己的代码中:
* 整图预测(检测+识别)
```
from paddleocr import PaddleOCR, draw_ocr
# 同样也是通过修改 lang 参数切换语种
ocr = PaddleOCR(lang="korean") # 首次执行会自动下载模型文件
img_path = 'doc/imgs/korean_1.jpg '
result = ocr.ocr(img_path)
# 打印检测框和识别结果
for line in result:
print(line)
# 可视化
from PIL import Image
image = Image.open(img_path).convert('RGB')
boxes = [line[0] for line in result]
txts = [line[1][0] for line in result]
scores = [line[1][1] for line in result]
im_show = draw_ocr(image, boxes, txts, scores, font_path='/path/to/PaddleOCR/doc/korean.ttf')
im_show = Image.fromarray(im_show)
im_show.save('result.jpg')
```
结果可视化:
![](https://raw.githubusercontent.com/PaddlePaddle/PaddleOCR/release/2.0/doc/imgs_results/korean.jpg)
* 识别预测
```
from paddleocr import PaddleOCR
ocr = PaddleOCR(lang="german")
img_path = 'PaddleOCR/doc/imgs_words/german/1.jpg'
result = ocr.ocr(img_path, det=False, cls=True)
for line in result:
print(line)
```
![](https://raw.githubusercontent.com/PaddlePaddle/PaddleOCR/release/2.0/doc/imgs_words/german/1.jpg)
结果是一个tuple,只包含识别结果和识别置信度
```
('leider auch jetzt', 0.97538936)
```
* 检测预测
```python
from paddleocr import PaddleOCR, draw_ocr
ocr = PaddleOCR() # need to run only once to download and load model into memory
img_path = 'PaddleOCR/doc/imgs_en/img_12.jpg'
result = ocr.ocr(img_path, rec=False)
for line in result:
print(line)
# 显示结果
from PIL import Image
image = Image.open(img_path).convert('RGB')
im_show = draw_ocr(image, result, txts=None, scores=None, font_path='/path/to/PaddleOCR/doc/fonts/simfang.ttf')
im_show = Image.fromarray(im_show)
im_show.save('result.jpg')
```
结果是一个list,每个item只包含文本框
```bash
[[26.0, 457.0], [137.0, 457.0], [137.0, 477.0], [26.0, 477.0]]
[[25.0, 425.0], [372.0, 425.0], [372.0, 448.0], [25.0, 448.0]]
[[128.0, 397.0], [273.0, 397.0], [273.0, 414.0], [128.0, 414.0]]
......
```
结果可视化 :
![](https://raw.githubusercontent.com/PaddlePaddle/PaddleOCR/release/2.0/doc/imgs_results/whl/12_det.jpg)
ppocr 还支持方向分类, 更多使用方式请参考:[whl包使用说明](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.0/doc/doc_ch/whl.md)
<a name="自定义训练"></a>
## 3 自定义训练
ppocr 支持使用自己的数据进行自定义训练或finetune, 其中识别模型可以参考 [法语配置文件](../../configs/rec/multi_language/rec_french_lite_train.yml)
修改训练数据路径、字典等参数。
具体数据准备、训练过程可参考:[文本检测](../doc_ch/detection.md)[文本识别](../doc_ch/recognition.md),更多功能如预测部署、
数据标注等功能可以阅读完整的[文档教程](../../README_ch.md)
<a name="语种缩写"></a>
## 4 支持语种及缩写
| 语种 | 描述 | 缩写 |
| --- | --- | --- |
|中文|chinese and english|ch|
|英文|english|en|
|法文|french|fr|
|德文|german|german|
|日文|japan|japan|
|韩文|korean|korean|
|中文繁体|chinese traditional |ch_tra|
|意大利文| Italian |it|
|西班牙文|Spanish |es|
|葡萄牙文| Portuguese|pt|
|俄罗斯文|Russia|ru|
|阿拉伯文|Arabic|ar|
|印地文|Hindi|hi|
|维吾尔|Uyghur|ug|
|波斯文|Persian|fa|
|乌尔都文|Urdu|ur|
|塞尔维亚文(latin)| Serbian(latin) |rs_latin|
|欧西坦文|Occitan |oc|
|马拉地文|Marathi|mr|
|尼泊尔文|Nepali|ne|
|塞尔维亚文(cyrillic)|Serbian(cyrillic)|rs_cyrillic|
|保加利亚文|Bulgarian |bg|
|乌克兰文|Ukranian|uk|
|白俄罗斯文|Belarusian|be|
|泰卢固文|Telugu |te|
|卡纳达文|Kannada |kn|
|泰米尔文|Tamil |ta|
|南非荷兰文 |Afrikaans |af|
|阿塞拜疆文 |Azerbaijani |az|
|波斯尼亚文|Bosnian|bs|
|捷克文|Czech|cs|
|威尔士文 |Welsh |cy|
|丹麦文 |Danish|da|
|爱沙尼亚文 |Estonian |et|
|爱尔兰文 |Irish |ga|
|克罗地亚文|Croatian |hr|
|匈牙利文|Hungarian |hu|
|印尼文|Indonesian|id|
|冰岛文 |Icelandic|is|
|库尔德文 |Kurdish|ku|
|立陶宛文|Lithuanian |lt|
|拉脱维亚文 |Latvian |lv|
|毛利文|Maori|mi|
|马来文 |Malay|ms|
|马耳他文 |Maltese |mt|
|荷兰文 |Dutch |nl|
|挪威文 |Norwegian |no|
|波兰文|Polish |pl|
| 罗马尼亚文|Romanian |ro|
| 斯洛伐克文|Slovak |sk|
| 斯洛文尼亚文|Slovenian |sl|
| 阿尔巴尼亚文|Albanian |sq|
| 瑞典文|Swedish |sv|
| 西瓦希里文|Swahili |sw|
| 塔加洛文|Tagalog |tl|
| 土耳其文|Turkish |tr|
| 乌兹别克文|Uzbek |uz|
| 越南文|Vietnamese |vi|
| 蒙古文|Mongolian |mn|
| 阿巴扎文|Abaza |abq|
| 阿迪赫文|Adyghe |ady|
| 卡巴丹文|Kabardian |kbd|
| 阿瓦尔文|Avar |ava|
| 达尔瓦文|Dargwa |dar|
| 因古什文|Ingush |inh|
| 拉克文|Lak |lbe|
| 莱兹甘文|Lezghian |lez|
|塔巴萨兰文 |Tabassaran |tab|
| 比尔哈文|Bihari |bh|
| 迈蒂利文|Maithili |mai|
| 昂加文|Angika |ang|
| 孟加拉文|Bhojpuri |bho|
| 摩揭陀文 |Magahi |mah|
| 那格浦尔文|Nagpur |sck|
| 尼瓦尔文|Newari |new|
| 保加利亚文 |Goan Konkani|gom|
| 沙特阿拉伯文|Saudi Arabia|sa|
# 端对端OCR算法-PGNet
- [一、简介](#简介)
- [二、环境配置](#环境配置)
- [三、快速使用](#快速使用)
- [四、快速训练](#开始训练)
- [五、预测推理](#预测推理)
<a name="简介"></a>
##简介
OCR算法可以分为两阶段算法和端对端的算法。二阶段OCR算法一般分为两个部分,文本检测和文本识别算法,文件检测算法从图像中得到文本行的检测框,然后识别算法去识别文本框中的内容。而端对端OCR算法可以在一个算法中完成文字检测和文字识别,其基本思想是设计一个同时具有检测单元和识别模块的模型,共享其中两者的CNN特征,并联合训练。由于一个算法即可完成文字识别,端对端模型更小,速度更快。
### PGNet算法介绍
近些年来,端对端OCR算法得到了良好的发展,包括MaskTextSpotter系列、TextSnake、TextDragon、PGNet系列等算法。在这些算法中,PGNet算法具备其他算法不具备的优势,包括:
- 设计PGNet loss指导训练,不需要字符级别的标注
- 不需要NMS和ROI相关操作,加速预测
- 提出预测文本行内的阅读顺序模块;
- 提出基于图的修正模块(GRM)来进一步提高模型识别性能
- 精度更高,预测速度更快
PGNet算法细节详见[论文](https://www.aaai.org/AAAI21Papers/AAAI-2885.WangP.pdf), 算法原理图如下所示:
![](../pgnet_framework.png)
输入图像经过特征提取送入四个分支,分别是:文本边缘偏移量预测TBO模块,文本中心线预测TCL模块,文本方向偏移量预测TDO模块,以及文本字符分类图预测TCC模块。
其中TBO以及TCL的输出经过后处理后可以得到文本的检测结果,TCL、TDO、TCC负责文本识别。
其检测识别效果图如下:
![](../imgs_results/e2e_res_img293_pgnet.png)
![](../imgs_results/e2e_res_img295_pgnet.png)
<a name="环境配置"></a>
##环境配置
请先参考[快速安装](./installation.md)配置PaddleOCR运行环境。
*注意:也可以通过 whl 包安装使用PaddleOCR,具体参考[Paddleocr Package使用说明](./whl.md)。*
<a name="快速使用"></a>
##快速使用
### inference模型下载
本节以训练好的端到端模型为例,快速使用模型预测,首先下载训练好的端到端inference模型[下载地址](https://paddleocr.bj.bcebos.com/dygraph_v2.0/pgnet/e2e_server_pgnetA_infer.tar)
```
mkdir inference && cd inference
# 下载英文端到端模型并解压
wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/pgnet/e2e_server_pgnetA_infer.tar && tar xf e2e_server_pgnetA_infer.tar
```
* windows 环境下如果没有安装wget,下载模型时可将链接复制到浏览器中下载,并解压放置在相应目录下
解压完毕后应有如下文件结构:
```
├── e2e_server_pgnetA_infer
│ ├── inference.pdiparams
│ ├── inference.pdiparams.info
│ └── inference.pdmodel
```
### 单张图像或者图像集合预测
```bash
# 预测image_dir指定的单张图像
python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/imgs_en/img623.jpg" --e2e_model_dir="./inference/e2e/" --e2e_pgnet_polygon=True
# 预测image_dir指定的图像集合
python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/imgs_en/" --e2e_model_dir="./inference/e2e/" --e2e_pgnet_polygon=True
# 如果想使用CPU进行预测,需设置use_gpu参数为False
python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/imgs_en/img623.jpg" --e2e_model_dir="./inference/e2e/" --e2e_pgnet_polygon=True --use_gpu=False
```
<a name="开始训练"></a>
##开始训练
本节以totaltext数据集为例,介绍PaddleOCR中端到端模型的训练、评估与测试。
###数据形式为icdar, 十六点标注数据
解压数据集和下载标注文件后,PaddleOCR/train_data/total_text/train/ 有两个文件夹,分别是:
```
/PaddleOCR/train_data/total_text/train/
|- rgb/ total_text数据集的训练数据
|- gt_0.png
| ...
|- total_text.txt total_text数据集的训练标注
```
提供的标注文件格式如下,中间用"\t"分隔:
```
" 图像文件名 json.dumps编码的图像标注信息"
rgb/gt_0.png [{"transcription": "EST", "points": [[1004.0,689.0],[1019.0,698.0],[1034.0,708.0],[1049.0,718.0],[1064.0,728.0],[1079.0,738.0],[1095.0,748.0],[1094.0,774.0],[1079.0,765.0],[1065.0,756.0],[1050.0,747.0],[1036.0,738.0],[1021.0,729.0],[1007.0,721.0]]}, {...}]
```
json.dumps编码前的图像标注信息是包含多个字典的list,字典中的 `points` 表示文本框的四个点的坐标(x, y),从左上角的点开始顺时针排列。
`transcription` 表示当前文本框的文字,**当其内容为“###”时,表示该文本框无效,在训练时会跳过。**
如果您想在其他数据集上训练,可以按照上述形式构建标注文件。
### 快速启动训练
模型训练一般分两步骤进行,第一步可以选择用合成数据训练,第二步加载第一步训练好的模型训练,这边我们提供了第一步训练好的模型,可以直接加载,从第二步开始训练
[下载地址](https://paddleocr.bj.bcebos.com/dygraph_v2.0/pgnet/train_step1.tar)
```shell
cd PaddleOCR/
下载ResNet50_vd的动态图预训练模型
wget -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/pgnet/train_step1.tar
可以得到以下的文件格式
./pretrain_models/train_step1/
└─ best_accuracy.pdopt
└─ best_accuracy.states
└─ best_accuracy.pdparams
```
*如果您安装的是cpu版本,请将配置文件中的 `use_gpu` 字段修改为false*
```shell
# 单机单卡训练 e2e 模型
python3 tools/train.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.pretrained_model=./pretrain_models/train_step1/best_accuracy Global.load_static_weights=False
# 单机多卡训练,通过 --gpus 参数设置使用的GPU ID
python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.pretrained_model=./pretrain_models/train_step1/best_accuracy Global.load_static_weights=False
```
上述指令中,通过-c 选择训练使用configs/e2e/e2e_r50_vd_pg.yml配置文件。
有关配置文件的详细解释,请参考[链接](./config.md)
您也可以通过-o参数在不需要修改yml文件的情况下,改变训练的参数,比如,调整训练的学习率为0.0001
```shell
python3 tools/train.py -c configs/e2e/e2e_r50_vd_pg.yml -o Optimizer.base_lr=0.0001
```
#### 断点训练
如果训练程序中断,如果希望加载训练中断的模型从而恢复训练,可以通过指定Global.checkpoints指定要加载的模型路径:
```shell
python3 tools/train.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.checkpoints=./your/trained/model
```
**注意**`Global.checkpoints`的优先级高于`Global.pretrain_weights`的优先级,即同时指定两个参数时,优先加载`Global.checkpoints`指定的模型,如果`Global.checkpoints`指定的模型路径有误,会加载`Global.pretrain_weights`指定的模型。
<a name="预测推理"></a>
## 预测推理
PaddleOCR计算三个OCR端到端相关的指标,分别是:Precision、Recall、Hmean。
运行如下代码,根据配置文件`e2e_r50_vd_pg.yml``save_res_path`指定的测试集检测结果文件,计算评估指标。
评估时设置后处理参数`max_side_len=768`,使用不同数据集、不同模型训练,可调整参数进行优化
训练中模型参数默认保存在`Global.save_model_dir`目录下。在评估指标时,需要设置`Global.checkpoints`指向保存的参数文件。
```shell
python3 tools/eval.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.checkpoints="{path/to/weights}/best_accuracy"
```
### 测试端到端效果
测试单张图像的端到端识别效果
```shell
python3 tools/infer_e2e.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.infer_img="./doc/imgs_en/img_10.jpg" Global.pretrained_model="./output/det_db/best_accuracy" Global.load_static_weights=false
```
测试文件夹下所有图像的端到端识别效果
```shell
python3 tools/infer_e2e.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.infer_img="./doc/imgs_en/" Global.pretrained_model="./output/det_db/best_accuracy" Global.load_static_weights=false
```
###转为推理模型
### (1). 四边形文本检测模型(ICDAR2015)
首先将PGNet端到端训练过程中保存的模型,转换成inference model。以基于Resnet50_vd骨干网络,以英文数据集训练的模型为例[模型下载地址](https://paddleocr.bj.bcebos.com/dygraph_v2.0/pgnet/en_server_pgnetA.tar) ,可以使用如下命令进行转换:
```
wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/pgnet/en_server_pgnetA.tar && tar xf en_server_pgnetA.tar
python3 tools/export_model.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.pretrained_model=./en_server_pgnetA/iter_epoch_450 Global.load_static_weights=False Global.save_inference_dir=./inference/e2e
```
**PGNet端到端模型推理,需要设置参数`--e2e_algorithm="PGNet"`**,可以执行如下命令:
```
python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/imgs_en/img_10.jpg" --e2e_model_dir="./inference/e2e/" --e2e_pgnet_polygon=False
```
可视化文本检测结果默认保存到`./inference_results`文件夹里面,结果文件的名称前缀为'e2e_res'。结果示例如下:
![](../imgs_results/e2e_res_img_10_pgnet.jpg)
### (2). 弯曲文本检测模型(Total-Text)
对于弯曲文本样例
**PGNet端到端模型推理,需要设置参数`--e2e_algorithm="PGNet"`,同时,还需要增加参数`--e2e_pgnet_polygon=True`,**可以执行如下命令:
```
python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/imgs_en/img623.jpg" --e2e_model_dir="./inference/e2e/" --e2e_pgnet_polygon=True
```
可视化文本端到端结果默认保存到`./inference_results`文件夹里面,结果文件的名称前缀为'e2e_res'。结果示例如下:
![](../imgs_results/e2e_res_img623_pgnet.jpg)
doc/imgs_results/whl/12_det.jpg

166 KB | W: | H:

doc/imgs_results/whl/12_det.jpg

410 KB | W: | H:

doc/imgs_results/whl/12_det.jpg
doc/imgs_results/whl/12_det.jpg
doc/imgs_results/whl/12_det.jpg
doc/imgs_results/whl/12_det.jpg
  • 2-up
  • Swipe
  • Onion skin
...@@ -66,6 +66,46 @@ model_urls = { ...@@ -66,6 +66,46 @@ model_urls = {
'url': 'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/japan_mobile_v2.0_rec_infer.tar', 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/japan_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/japan_dict.txt' 'dict_path': './ppocr/utils/dict/japan_dict.txt'
},
'chinese_cht': {
'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/chinese_cht_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/chinese_cht_dict.txt'
},
'ta': {
'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ta_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/ta_dict.txt'
},
'te': {
'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/te_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/te_dict.txt'
},
'ka': {
'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ka_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/ka_dict.txt'
},
'latin': {
'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/latin_ppocr_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/latin_dict.txt'
},
'arabic': {
'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/arabic_ppocr_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/arabic_dict.txt'
},
'cyrillic': {
'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/cyrillic_ppocr_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/cyrillic_dict.txt'
},
'devanagari': {
'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/devanagari_ppocr_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/devanagari_dict.txt'
} }
}, },
'cls': 'cls':
...@@ -233,6 +273,29 @@ class PaddleOCR(predict_system.TextSystem): ...@@ -233,6 +273,29 @@ class PaddleOCR(predict_system.TextSystem):
postprocess_params.__dict__.update(**kwargs) postprocess_params.__dict__.update(**kwargs)
self.use_angle_cls = postprocess_params.use_angle_cls self.use_angle_cls = postprocess_params.use_angle_cls
lang = postprocess_params.lang lang = postprocess_params.lang
latin_lang = [
'af', 'az', 'bs', 'cs', 'cy', 'da', 'de', 'en', 'es', 'et', 'fr',
'ga', 'hr', 'hu', 'id', 'is', 'it', 'ku', 'la', 'lt', 'lv', 'mi',
'ms', 'mt', 'nl', 'no', 'oc', 'pi', 'pl', 'pt', 'ro', 'rs_latin',
'sk', 'sl', 'sq', 'sv', 'sw', 'tl', 'tr', 'uz', 'vi'
]
arabic_lang = ['ar', 'fa', 'ug', 'ur']
cyrillic_lang = [
'ru', 'rs_cyrillic', 'be', 'bg', 'uk', 'mn', 'abq', 'ady', 'kbd',
'ava', 'dar', 'inh', 'che', 'lbe', 'lez', 'tab'
]
devanagari_lang = [
'hi', 'mr', 'ne', 'bh', 'mai', 'ang', 'bho', 'mah', 'sck', 'new',
'gom', 'sa', 'bgc'
]
if lang in latin_lang:
lang = "latin"
elif lang in arabic_lang:
lang = "arabic"
elif lang in cyrillic_lang:
lang = "cyrillic"
elif lang in devanagari_lang:
lang = "devanagari"
assert lang in model_urls[ assert lang in model_urls[
'rec'], 'param lang must in {}, but got {}'.format( 'rec'], 'param lang must in {}, but got {}'.format(
model_urls['rec'].keys(), lang) model_urls['rec'].keys(), lang)
......
...@@ -34,6 +34,7 @@ import paddle.distributed as dist ...@@ -34,6 +34,7 @@ import paddle.distributed as dist
from ppocr.data.imaug import transform, create_operators from ppocr.data.imaug import transform, create_operators
from ppocr.data.simple_dataset import SimpleDataSet from ppocr.data.simple_dataset import SimpleDataSet
from ppocr.data.lmdb_dataset import LMDBDataSet from ppocr.data.lmdb_dataset import LMDBDataSet
from ppocr.data.pgnet_dataset import PGDataSet
__all__ = ['build_dataloader', 'transform', 'create_operators'] __all__ = ['build_dataloader', 'transform', 'create_operators']
...@@ -54,7 +55,7 @@ signal.signal(signal.SIGTERM, term_mp) ...@@ -54,7 +55,7 @@ signal.signal(signal.SIGTERM, term_mp)
def build_dataloader(config, mode, device, logger, seed=None): def build_dataloader(config, mode, device, logger, seed=None):
config = copy.deepcopy(config) config = copy.deepcopy(config)
support_dict = ['SimpleDataSet', 'LMDBDataSet'] support_dict = ['SimpleDataSet', 'LMDBDataSet', 'PGDataSet']
module_name = config[mode]['dataset']['name'] module_name = config[mode]['dataset']['name']
assert module_name in support_dict, Exception( assert module_name in support_dict, Exception(
'DataSet only support {}'.format(support_dict)) 'DataSet only support {}'.format(support_dict))
...@@ -72,14 +73,14 @@ def build_dataloader(config, mode, device, logger, seed=None): ...@@ -72,14 +73,14 @@ def build_dataloader(config, mode, device, logger, seed=None):
else: else:
use_shared_memory = True use_shared_memory = True
if mode == "Train": if mode == "Train":
#Distribute data to multiple cards # Distribute data to multiple cards
batch_sampler = DistributedBatchSampler( batch_sampler = DistributedBatchSampler(
dataset=dataset, dataset=dataset,
batch_size=batch_size, batch_size=batch_size,
shuffle=shuffle, shuffle=shuffle,
drop_last=drop_last) drop_last=drop_last)
else: else:
#Distribute data to single card # Distribute data to single card
batch_sampler = BatchSampler( batch_sampler = BatchSampler(
dataset=dataset, dataset=dataset,
batch_size=batch_size, batch_size=batch_size,
......
...@@ -28,6 +28,7 @@ from .label_ops import * ...@@ -28,6 +28,7 @@ from .label_ops import *
from .east_process import * from .east_process import *
from .sast_process import * from .sast_process import *
from .pg_process import *
def transform(data, ops=None): def transform(data, ops=None):
......
...@@ -187,6 +187,34 @@ class CTCLabelEncode(BaseRecLabelEncode): ...@@ -187,6 +187,34 @@ class CTCLabelEncode(BaseRecLabelEncode):
return dict_character return dict_character
class E2ELabelEncode(BaseRecLabelEncode):
def __init__(self,
max_text_length,
character_dict_path=None,
character_type='EN',
use_space_char=False,
**kwargs):
super(E2ELabelEncode,
self).__init__(max_text_length, character_dict_path,
character_type, use_space_char)
self.pad_num = len(self.dict) # the length to pad
def __call__(self, data):
text_label_index_list, temp_text = [], []
texts = data['strs']
for text in texts:
text = text.lower()
temp_text = []
for c_ in text:
if c_ in self.dict:
temp_text.append(self.dict[c_])
temp_text = temp_text + [self.pad_num] * (self.max_text_len -
len(temp_text))
text_label_index_list.append(temp_text)
data['strs'] = np.array(text_label_index_list)
return data
class AttnLabelEncode(BaseRecLabelEncode): class AttnLabelEncode(BaseRecLabelEncode):
""" Convert between text-label and text-index """ """ Convert between text-label and text-index """
......
...@@ -197,7 +197,6 @@ class DetResizeForTest(object): ...@@ -197,7 +197,6 @@ class DetResizeForTest(object):
sys.exit(0) sys.exit(0)
ratio_h = resize_h / float(h) ratio_h = resize_h / float(h)
ratio_w = resize_w / float(w) ratio_w = resize_w / float(w)
# return img, np.array([h, w])
return img, [ratio_h, ratio_w] return img, [ratio_h, ratio_w]
def resize_image_type2(self, img): def resize_image_type2(self, img):
...@@ -206,7 +205,6 @@ class DetResizeForTest(object): ...@@ -206,7 +205,6 @@ class DetResizeForTest(object):
resize_w = w resize_w = w
resize_h = h resize_h = h
# Fix the longer side
if resize_h > resize_w: if resize_h > resize_w:
ratio = float(self.resize_long) / resize_h ratio = float(self.resize_long) / resize_h
else: else:
...@@ -223,3 +221,72 @@ class DetResizeForTest(object): ...@@ -223,3 +221,72 @@ class DetResizeForTest(object):
ratio_w = resize_w / float(w) ratio_w = resize_w / float(w)
return img, [ratio_h, ratio_w] return img, [ratio_h, ratio_w]
class E2EResizeForTest(object):
def __init__(self, **kwargs):
super(E2EResizeForTest, self).__init__()
self.max_side_len = kwargs['max_side_len']
self.valid_set = kwargs['valid_set']
def __call__(self, data):
img = data['image']
src_h, src_w, _ = img.shape
if self.valid_set == 'totaltext':
im_resized, [ratio_h, ratio_w] = self.resize_image_for_totaltext(
img, max_side_len=self.max_side_len)
else:
im_resized, (ratio_h, ratio_w) = self.resize_image(
img, max_side_len=self.max_side_len)
data['image'] = im_resized
data['shape'] = np.array([src_h, src_w, ratio_h, ratio_w])
return data
def resize_image_for_totaltext(self, im, max_side_len=512):
h, w, _ = im.shape
resize_w = w
resize_h = h
ratio = 1.25
if h * ratio > max_side_len:
ratio = float(max_side_len) / resize_h
resize_h = int(resize_h * ratio)
resize_w = int(resize_w * ratio)
max_stride = 128
resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
im = cv2.resize(im, (int(resize_w), int(resize_h)))
ratio_h = resize_h / float(h)
ratio_w = resize_w / float(w)
return im, (ratio_h, ratio_w)
def resize_image(self, im, max_side_len=512):
"""
resize image to a size multiple of max_stride which is required by the network
:param im: the resized image
:param max_side_len: limit of max image size to avoid out of memory in gpu
:return: the resized image and the resize ratio
"""
h, w, _ = im.shape
resize_w = w
resize_h = h
# Fix the longer side
if resize_h > resize_w:
ratio = float(max_side_len) / resize_h
else:
ratio = float(max_side_len) / resize_w
resize_h = int(resize_h * ratio)
resize_w = int(resize_w * ratio)
max_stride = 128
resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
im = cv2.resize(im, (int(resize_w), int(resize_h)))
ratio_h = resize_h / float(h)
ratio_w = resize_w / float(w)
return im, (ratio_h, ratio_w)
This diff is collapsed.
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import os
from paddle.io import Dataset
from .imaug import transform, create_operators
import random
class PGDataSet(Dataset):
def __init__(self, config, mode, logger, seed=None):
super(PGDataSet, self).__init__()
self.logger = logger
self.seed = seed
self.mode = mode
global_config = config['Global']
dataset_config = config[mode]['dataset']
loader_config = config[mode]['loader']
label_file_list = dataset_config.pop('label_file_list')
data_source_num = len(label_file_list)
ratio_list = dataset_config.get("ratio_list", [1.0])
if isinstance(ratio_list, (float, int)):
ratio_list = [float(ratio_list)] * int(data_source_num)
self.data_format = dataset_config.get('data_format', 'icdar')
assert len(
ratio_list
) == data_source_num, "The length of ratio_list should be the same as the file_list."
self.do_shuffle = loader_config['shuffle']
logger.info("Initialize indexs of datasets:%s" % label_file_list)
self.data_lines = self.get_image_info_list(label_file_list, ratio_list,
self.data_format)
self.data_idx_order_list = list(range(len(self.data_lines)))
if mode.lower() == "train":
self.shuffle_data_random()
self.ops = create_operators(dataset_config['transforms'], global_config)
def shuffle_data_random(self):
if self.do_shuffle:
random.seed(self.seed)
random.shuffle(self.data_lines)
return
def extract_polys(self, poly_txt_path):
"""
Read text_polys, txt_tags, txts from give txt file.
"""
text_polys, txt_tags, txts = [], [], []
with open(poly_txt_path) as f:
for line in f.readlines():
poly_str, txt = line.strip().split('\t')
poly = list(map(float, poly_str.split(',')))
if self.mode.lower() == "eval":
while len(poly) < 100:
poly.append(-1)
text_polys.append(
np.array(
poly, dtype=np.float32).reshape(-1, 2))
txts.append(txt)
txt_tags.append(txt == '###')
return np.array(list(map(np.array, text_polys))), \
np.array(txt_tags, dtype=np.bool), txts
def extract_info_textnet(self, im_fn, img_dir=''):
"""
Extract information from line in textnet format.
"""
info_list = im_fn.split('\t')
img_path = ''
for ext in [
'jpg', 'bmp', 'png', 'jpeg', 'rgb', 'tif', 'tiff', 'gif', 'JPG'
]:
if os.path.exists(os.path.join(img_dir, info_list[0] + "." + ext)):
img_path = os.path.join(img_dir, info_list[0] + "." + ext)
break
if img_path == '':
print('Image {0} NOT found in {1}, and it will be ignored.'.format(
info_list[0], img_dir))
nBox = (len(info_list) - 1) // 9
wordBBs, txts, txt_tags = [], [], []
for n in range(0, nBox):
wordBB = list(map(float, info_list[n * 9 + 1:(n + 1) * 9]))
txt = info_list[(n + 1) * 9]
wordBBs.append([[wordBB[0], wordBB[1]], [wordBB[2], wordBB[3]],
[wordBB[4], wordBB[5]], [wordBB[6], wordBB[7]]])
txts.append(txt)
if txt == '###':
txt_tags.append(True)
else:
txt_tags.append(False)
return img_path, np.array(wordBBs, dtype=np.float32), txt_tags, txts
def get_image_info_list(self, file_list, ratio_list, data_format='textnet'):
if isinstance(file_list, str):
file_list = [file_list]
data_lines = []
for idx, data_source in enumerate(file_list):
image_files = []
if data_format == 'icdar':
image_files = [(data_source, x) for x in
os.listdir(os.path.join(data_source, 'rgb'))
if x.split('.')[-1] in [
'jpg', 'bmp', 'png', 'jpeg', 'rgb', 'tif',
'tiff', 'gif', 'JPG'
]]
elif data_format == 'textnet':
with open(data_source) as f:
image_files = [(data_source, x.strip())
for x in f.readlines()]
else:
print("Unrecognized data format...")
exit(-1)
random.seed(self.seed)
image_files = random.sample(
image_files, round(len(image_files) * ratio_list[idx]))
data_lines.extend(image_files)
return data_lines
def __getitem__(self, idx):
file_idx = self.data_idx_order_list[idx]
data_path, data_line = self.data_lines[file_idx]
try:
if self.data_format == 'icdar':
im_path = os.path.join(data_path, 'rgb', data_line)
if self.mode.lower() == "eval":
poly_path = os.path.join(data_path, 'poly_gt',
data_line.split('.')[0] + '.txt')
else:
poly_path = os.path.join(data_path, 'poly',
data_line.split('.')[0] + '.txt')
text_polys, text_tags, text_strs = self.extract_polys(poly_path)
else:
image_dir = os.path.join(os.path.dirname(data_path), 'image')
im_path, text_polys, text_tags, text_strs = self.extract_info_textnet(
data_line, image_dir)
data = {
'img_path': im_path,
'polys': text_polys,
'tags': text_tags,
'strs': text_strs
}
with open(data['img_path'], 'rb') as f:
img = f.read()
data['image'] = img
outs = transform(data, self.ops)
except Exception as e:
self.logger.error(
"When parsing line {}, error happened with msg: {}".format(
self.data_idx_order_list[idx], e))
outs = None
if outs is None:
return self.__getitem__(np.random.randint(self.__len__()))
return outs
def __len__(self):
return len(self.data_idx_order_list)
...@@ -29,10 +29,11 @@ def build_loss(config): ...@@ -29,10 +29,11 @@ def build_loss(config):
# cls loss # cls loss
from .cls_loss import ClsLoss from .cls_loss import ClsLoss
# e2e loss
from .e2e_pg_loss import PGLoss
support_dict = [ support_dict = [
'DBLoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss', 'AttentionLoss', 'DBLoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss', 'AttentionLoss',
'SRNLoss' 'SRNLoss', 'PGLoss']
]
config = copy.deepcopy(config) config = copy.deepcopy(config)
module_name = config.pop('name') module_name = config.pop('name')
......
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from paddle import nn
import paddle
from .det_basic_loss import DiceLoss
from ppocr.utils.e2e_utils.extract_batchsize import pre_process
class PGLoss(nn.Layer):
def __init__(self,
tcl_bs,
max_text_length,
max_text_nums,
pad_num,
eps=1e-6,
**kwargs):
super(PGLoss, self).__init__()
self.tcl_bs = tcl_bs
self.max_text_nums = max_text_nums
self.max_text_length = max_text_length
self.pad_num = pad_num
self.dice_loss = DiceLoss(eps=eps)
def border_loss(self, f_border, l_border, l_score, l_mask):
l_border_split, l_border_norm = paddle.tensor.split(
l_border, num_or_sections=[4, 1], axis=1)
f_border_split = f_border
b, c, h, w = l_border_norm.shape
l_border_norm_split = paddle.expand(
x=l_border_norm, shape=[b, 4 * c, h, w])
b, c, h, w = l_score.shape
l_border_score = paddle.expand(x=l_score, shape=[b, 4 * c, h, w])
b, c, h, w = l_mask.shape
l_border_mask = paddle.expand(x=l_mask, shape=[b, 4 * c, h, w])
border_diff = l_border_split - f_border_split
abs_border_diff = paddle.abs(border_diff)
border_sign = abs_border_diff < 1.0
border_sign = paddle.cast(border_sign, dtype='float32')
border_sign.stop_gradient = True
border_in_loss = 0.5 * abs_border_diff * abs_border_diff * border_sign + \
(abs_border_diff - 0.5) * (1.0 - border_sign)
border_out_loss = l_border_norm_split * border_in_loss
border_loss = paddle.sum(border_out_loss * l_border_score * l_border_mask) / \
(paddle.sum(l_border_score * l_border_mask) + 1e-5)
return border_loss
def direction_loss(self, f_direction, l_direction, l_score, l_mask):
l_direction_split, l_direction_norm = paddle.tensor.split(
l_direction, num_or_sections=[2, 1], axis=1)
f_direction_split = f_direction
b, c, h, w = l_direction_norm.shape
l_direction_norm_split = paddle.expand(
x=l_direction_norm, shape=[b, 2 * c, h, w])
b, c, h, w = l_score.shape
l_direction_score = paddle.expand(x=l_score, shape=[b, 2 * c, h, w])
b, c, h, w = l_mask.shape
l_direction_mask = paddle.expand(x=l_mask, shape=[b, 2 * c, h, w])
direction_diff = l_direction_split - f_direction_split
abs_direction_diff = paddle.abs(direction_diff)
direction_sign = abs_direction_diff < 1.0
direction_sign = paddle.cast(direction_sign, dtype='float32')
direction_sign.stop_gradient = True
direction_in_loss = 0.5 * abs_direction_diff * abs_direction_diff * direction_sign + \
(abs_direction_diff - 0.5) * (1.0 - direction_sign)
direction_out_loss = l_direction_norm_split * direction_in_loss
direction_loss = paddle.sum(direction_out_loss * l_direction_score * l_direction_mask) / \
(paddle.sum(l_direction_score * l_direction_mask) + 1e-5)
return direction_loss
def ctcloss(self, f_char, tcl_pos, tcl_mask, tcl_label, label_t):
f_char = paddle.transpose(f_char, [0, 2, 3, 1])
tcl_pos = paddle.reshape(tcl_pos, [-1, 3])
tcl_pos = paddle.cast(tcl_pos, dtype=int)
f_tcl_char = paddle.gather_nd(f_char, tcl_pos)
f_tcl_char = paddle.reshape(f_tcl_char,
[-1, 64, 37]) # len(Lexicon_Table)+1
f_tcl_char_fg, f_tcl_char_bg = paddle.split(f_tcl_char, [36, 1], axis=2)
f_tcl_char_bg = f_tcl_char_bg * tcl_mask + (1.0 - tcl_mask) * 20.0
b, c, l = tcl_mask.shape
tcl_mask_fg = paddle.expand(x=tcl_mask, shape=[b, c, 36 * l])
tcl_mask_fg.stop_gradient = True
f_tcl_char_fg = f_tcl_char_fg * tcl_mask_fg + (1.0 - tcl_mask_fg) * (
-20.0)
f_tcl_char_mask = paddle.concat([f_tcl_char_fg, f_tcl_char_bg], axis=2)
f_tcl_char_ld = paddle.transpose(f_tcl_char_mask, (1, 0, 2))
N, B, _ = f_tcl_char_ld.shape
input_lengths = paddle.to_tensor([N] * B, dtype='int64')
cost = paddle.nn.functional.ctc_loss(
log_probs=f_tcl_char_ld,
labels=tcl_label,
input_lengths=input_lengths,
label_lengths=label_t,
blank=self.pad_num,
reduction='none')
cost = cost.mean()
return cost
def forward(self, predicts, labels):
images, tcl_maps, tcl_label_maps, border_maps \
, direction_maps, training_masks, label_list, pos_list, pos_mask = labels
# for all the batch_size
pos_list, pos_mask, label_list, label_t = pre_process(
label_list, pos_list, pos_mask, self.max_text_length,
self.max_text_nums, self.pad_num, self.tcl_bs)
f_score, f_border, f_direction, f_char = predicts['f_score'], predicts['f_border'], predicts['f_direction'], \
predicts['f_char']
score_loss = self.dice_loss(f_score, tcl_maps, training_masks)
border_loss = self.border_loss(f_border, border_maps, tcl_maps,
training_masks)
direction_loss = self.direction_loss(f_direction, direction_maps,
tcl_maps, training_masks)
ctc_loss = self.ctcloss(f_char, pos_list, pos_mask, label_list, label_t)
loss_all = score_loss + border_loss + direction_loss + 5 * ctc_loss
losses = {
'loss': loss_all,
"score_loss": score_loss,
"border_loss": border_loss,
"direction_loss": direction_loss,
"ctc_loss": ctc_loss
}
return losses
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment