Commit 406463ef authored by Khanh Tran's avatar Khanh Tran
Browse files

update from original repo

parents 4d22bf3a bc85ebd4
# 文字检测
本节以icdar15数据集为例,介绍PaddleOCR中检测模型的训练、评估与测试。
## 数据准备
icdar2015数据集可以从[官网](https://rrc.cvc.uab.es/?ch=4&com=downloads)下载到,首次下载需注册。
将下载到的数据集解压到工作目录下,假设解压在 PaddleOCR/train_data/ 下。另外,PaddleOCR将零散的标注文件整理成单独的标注文件
,您可以通过wget的方式进行下载。
```
# 在PaddleOCR路径下
cd PaddleOCR/
wget -P ./train_data/ https://paddleocr.bj.bcebos.com/dataset/train_icdar2015_label.txt
wget -P ./train_data/ https://paddleocr.bj.bcebos.com/dataset/test_icdar2015_label.txt
```
解压数据集和下载标注文件后,PaddleOCR/train_data/ 有两个文件夹和两个文件,分别是:
```
/PaddleOCR/train_data/icdar2015/text_localization/
└─ icdar_c4_train_imgs/ icdar数据集的训练数据
└─ ch4_test_images/ icdar数据集的测试数据
└─ train_icdar2015_label.txt icdar数据集的训练标注
└─ test_icdar2015_label.txt icdar数据集的测试标注
```
提供的标注文件格式为:
```
" 图像文件名 json.dumps编码的图像标注信息"
ch4_test_images/img_61.jpg [{"transcription": "MASA", "points": [[310, 104], [416, 141], [418, 216], [312, 179]], ...}]
```
json.dumps编码前的图像标注信息是包含多个字典的list,字典中的 `points` 表示文本框的四个点的坐标(x, y),从左上角的点开始顺时针排列。
`transcription` 表示当前文本框的文字,在文本检测任务中并不需要这个信息。
如果您想在其他数据集上训练PaddleOCR,可以按照上述形式构建标注文件。
## 快速启动训练
首先下载pretrain model,PaddleOCR的检测模型目前支持两种backbone,分别是MobileNetV3、ResNet50_vd,
您可以根据需求使用[PaddleClas](https://github.com/PaddlePaddle/PaddleClas/tree/master/ppcls/modeling/architectures)中的模型更换backbone。
```
cd PaddleOCR/
# 下载MobileNetV3的预训练模型
wget -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV3_large_x0_5_pretrained.tar
# 下载ResNet50的预训练模型
wget -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_vd_ssld_pretrained.tar
```
**启动训练**
*如果您安装的是cpu版本,请将配置文件中的 `use_gpu` 字段修改为false*
```
python3 tools/train.py -c configs/det/det_mv3_db.yml
```
上述指令中,通过-c 选择训练使用configs/det/det_db_mv3.yml配置文件。
有关配置文件的详细解释,请参考[链接](./config.md)
您也可以通过-o参数在不需要修改yml文件的情况下,改变训练的参数,比如,调整训练的学习率为0.0001
```
python3 tools/train.py -c configs/det/det_mv3_db.yml -o Optimizer.base_lr=0.0001
```
## 指标评估
PaddleOCR计算三个OCR检测相关的指标,分别是:Precision、Recall、Hmean。
运行如下代码,根据配置文件det_db_mv3.yml中save_res_path指定的测试集检测结果文件,计算评估指标。
评估时设置后处理参数box_thresh=0.6,unclip_ratio=1.5,使用不同数据集、不同模型训练,可调整这两个参数进行优化
```
python3 tools/eval.py -c configs/det/det_mv3_db.yml -o Global.checkpoints="{path/to/weights}/best_accuracy" PostProcess.box_thresh=0.6 PostProcess.unclip_ratio=1.5
```
训练中模型参数默认保存在Global.save_model_dir目录下。在评估指标时,需要设置Global.checkpoints指向保存的参数文件。
比如:
```
python3 tools/eval.py -c configs/det/det_mv3_db.yml -o Global.checkpoints="./output/det_db/best_accuracy" PostProcess.box_thresh=0.6 PostProcess.unclip_ratio=1.5
```
* 注:box_thresh、unclip_ratio是DB后处理所需要的参数,在评估EAST模型时不需要设置
## 测试检测效果
测试单张图像的检测效果
```
python3 tools/infer_det.py -c configs/det/det_mv3_db.yml -o TestReader.infer_img="./doc/imgs_en/img_10.jpg" Global.checkpoints="./output/det_db/best_accuracy"
```
测试DB模型时,调整后处理阈值,
```
python3 tools/infer_det.py -c configs/det/det_mv3_db.yml -o TestReader.infer_img="./doc/imgs_en/img_10.jpg" Global.checkpoints="./output/det_db/best_accuracy" PostProcess.box_thresh=0.6 PostProcess.unclip_ratio=1.5
```
测试文件夹下所有图像的检测效果
```
python3 tools/infer_det.py -c configs/det/det_mv3_db.yml -o TestReader.infer_img="./doc/imgs_en/" Global.checkpoints="./output/det_db/best_accuracy"
```
# 基于预测引擎推理
inference 模型(fluid.io.save_inference_model保存的模型)
一般是模型训练完成后保存的固化模型,多用于预测部署。
训练过程中保存的模型是checkpoints模型,保存的是模型的参数,多用于恢复训练等。
与checkpoints模型相比,inference 模型会额外保存模型的结构信息,在预测部署、加速推理上性能优越,灵活方便,适合与实际系统集成。更详细的介绍请参考文档[分类预测框架](https://paddleclas.readthedocs.io/zh_CN/latest/extension/paddle_inference.html).
接下来首先介绍如何将训练的模型转换成inference模型,然后将依次介绍文本检测、文本识别以及两者串联基于预测引擎推理。
## 训练模型转inference模型
### 检测模型转inference模型
下载超轻量级中文检测模型:
```
wget -P ./ch_lite/ https://paddleocr.bj.bcebos.com/ch_models/ch_det_mv3_db.tar && tar xf ./ch_lite/ch_det_mv3_db.tar -C ./ch_lite/
```
上述模型是以MobileNetV3为backbone训练的DB算法,将训练好的模型转换成inference模型只需要运行如下命令:
```
python3 tools/export_model.py -c configs/det/det_mv3_db.yml -o Global.checkpoints=./ch_lite/det_mv3_db/best_accuracy Global.save_inference_dir=./inference/det_db/
```
转inference模型时,使用的配置文件和训练时使用的配置文件相同。另外,还需要设置配置文件中的Global.checkpoints、Global.save_inference_dir参数。
其中Global.checkpoints指向训练中保存的模型参数文件,Global.save_inference_dir是生成的inference模型要保存的目录。
转换成功后,在save_inference_dir 目录下有两个文件:
```
inference/det_db/
└─ model 检测inference模型的program文件
└─ params 检测inference模型的参数文件
```
### 识别模型转inference模型
下载超轻量中文识别模型:
```
wget -P ./ch_lite/ https://paddleocr.bj.bcebos.com/ch_models/ch_rec_mv3_crnn.tar && tar xf ./ch_lite/ch_rec_mv3_crnn.tar -C ./ch_lite/
```
识别模型转inference模型与检测的方式相同,如下:
```
python3 tools/export_model.py -c configs/rec/rec_chinese_lite_train.yml -o Global.checkpoints=./ch_lite/rec_mv3_crnn/best_accuracy \
Global.save_inference_dir=./inference/rec_crnn/
```
如果您是在自己的数据集上训练的模型,并且调整了中文字符的字典文件,请注意修改配置文件中的character_dict_path是否是所需要的字典文件。
转换成功后,在目录下有两个文件:
```
/inference/rec_crnn/
└─ model 识别inference模型的program文件
└─ params 识别inference模型的参数文件
```
## 文本检测模型推理
下面将介绍超轻量中文检测模型推理、DB文本检测模型推理和EAST文本检测模型推理。默认配置是根据DB文本检测模型推理设置的。由于EAST和DB算法差别很大,在推理时,需要通过传入相应的参数适配EAST文本检测算法。
### 1.超轻量中文检测模型推理
超轻量中文检测模型推理,可以执行如下命令:
```
python3 tools/infer/predict_det.py --image_dir="./doc/imgs/2.jpg" --det_model_dir="./inference/det_db/"
```
可视化文本检测结果默认保存到 ./inference_results 文件夹里面,结果文件的名称前缀为'det_res'。结果示例如下:
![](imgs_results/det_res_2.jpg)
通过设置参数det_max_side_len的大小,改变检测算法中图片规范化的最大值。当图片的长宽都小于det_max_side_len,则使用原图预测,否则将图片等比例缩放到最大值,进行预测。该参数默认设置为det_max_side_len=960. 如果输入图片的分辨率比较大,而且想使用更大的分辨率预测,可以执行如下命令:
```
python3 tools/infer/predict_det.py --image_dir="./doc/imgs/2.jpg" --det_model_dir="./inference/det_db/" --det_max_side_len=1200
```
如果想使用CPU进行预测,执行命令如下
```
python3 tools/infer/predict_det.py --image_dir="./doc/imgs/2.jpg" --det_model_dir="./inference/det_db/" --use_gpu=False
```
### 2.DB文本检测模型推理
首先将DB文本检测训练过程中保存的模型,转换成inference model。以基于Resnet50_vd骨干网络,在ICDAR2015英文数据集训练的模型为例([模型下载地址](https://paddleocr.bj.bcebos.com/det_r50_vd_db.tar)),可以使用如下命令进行转换:
```
# -c后面设置训练算法的yml配置文件
# Global.checkpoints参数设置待转换的训练模型地址,不用添加文件后缀.pdmodel,.pdopt或.pdparams。
# Global.save_inference_dir参数设置转换的模型将保存的地址。
python3 tools/export_model.py -c configs/det/det_r50_vd_db.yml -o Global.checkpoints="./models/det_r50_vd_db/best_accuracy" Global.save_inference_dir="./inference/det_db"
```
DB文本检测模型推理,可以执行如下命令:
```
python3 tools/infer/predict_det.py --image_dir="./doc/imgs_en/img_10.jpg" --det_model_dir="./inference/det_db/"
```
可视化文本检测结果默认保存到 ./inference_results 文件夹里面,结果文件的名称前缀为'det_res'。结果示例如下:
![](imgs_results/det_res_img_10_db.jpg)
**注意**:由于ICDAR2015数据集只有1000张训练图像,主要针对英文场景,所以上述模型对中文文本图像检测效果非常差。
### 3.EAST文本检测模型推理
首先将EAST文本检测训练过程中保存的模型,转换成inference model。以基于Resnet50_vd骨干网络,在ICDAR2015英文数据集训练的模型为例([模型下载地址](https://paddleocr.bj.bcebos.com/det_r50_vd_east.tar)),可以使用如下命令进行转换:
```
# -c后面设置训练算法的yml配置文件
# Global.checkpoints参数设置待转换的训练模型地址,不用添加文件后缀.pdmodel,.pdopt或.pdparams。
# Global.save_inference_dir参数设置转换的模型将保存的地址。
python3 tools/export_model.py -c configs/det/det_r50_vd_east.yml -o Global.checkpoints="./models/det_r50_vd_east/best_accuracy" Global.save_inference_dir="./inference/det_east"
```
EAST文本检测模型推理,需要设置参数det_algorithm,指定检测算法类型为EAST,可以执行如下命令:
```
python3 tools/infer/predict_det.py --image_dir="./doc/imgs_en/img_10.jpg" --det_model_dir="./inference/det_east/" --det_algorithm="EAST"
```
可视化文本检测结果默认保存到 ./inference_results 文件夹里面,结果文件的名称前缀为'det_res'。结果示例如下:
![](imgs_results/det_res_img_10_east.jpg)
**注意**:本代码库中EAST后处理中NMS采用的Python版本,所以预测速度比较耗时。如果采用C++版本,会有明显加速。
## 文本识别模型推理
下面将介绍超轻量中文识别模型推理和基于CTC损失的识别模型推理。**而基于Attention损失的识别模型推理还在调试中**。对于中文文本识别,建议优先选择基于CTC损失的识别模型,实践中也发现基于Attention损失的效果不如基于CTC损失的识别模型。
### 1.超轻量中文识别模型推理
超轻量中文识别模型推理,可以执行如下命令:
```
python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/ch/word_4.jpg" --rec_model_dir="./inference/rec_crnn/"
```
![](imgs_words/ch/word_4.jpg)
执行命令后,上面图像的预测结果(识别的文本和得分)会打印到屏幕上,示例如下:
Predicts of ./doc/imgs_words/ch/word_4.jpg:['实力活力', 0.89552695]
### 2.基于CTC损失的识别模型推理
我们以STAR-Net为例,介绍基于CTC损失的识别模型推理。 CRNN和Rosetta使用方式类似,不用设置识别算法参数rec_algorithm。
首先将STAR-Net文本识别训练过程中保存的模型,转换成inference model。以基于Resnet34_vd骨干网络,使用MJSynth和SynthText两个英文文本识别合成数据集训练
的模型为例([模型下载地址](https://paddleocr.bj.bcebos.com/rec_r34_vd_tps_bilstm_ctc.tar)),可以使用如下命令进行转换:
```
# -c后面设置训练算法的yml配置文件
# Global.checkpoints参数设置待转换的训练模型地址,不用添加文件后缀.pdmodel,.pdopt或.pdparams。
# Global.save_inference_dir参数设置转换的模型将保存的地址。
python3 tools/export_model.py -c configs/rec/rec_r34_vd_tps_bilstm_ctc.yml -o Global.checkpoints="./models/rec_r34_vd_tps_bilstm_ctc/best_accuracy" Global.save_inference_dir="./inference/starnet"
```
STAR-Net文本识别模型推理,可以执行如下命令:
```
python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_336.png" --rec_model_dir="./inference/starnet/" --rec_image_shape="3, 32, 100" --rec_char_type="en"
```
### 3.基于Attention损失的识别模型推理
基于Attention损失的识别模型与ctc不同,需要额外设置识别算法参数 --rec_algorithm="RARE"
RARE 文本识别模型推理,可以执行如下命令:
```
python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_336.png" --rec_model_dir="./inference/sare/" --rec_image_shape="3, 32, 100" --rec_char_type="en" --rec_algorithm="RARE"
```
![](imgs_words_en/word_336.png)
执行命令后,上面图像的识别结果如下:
Predicts of ./doc/imgs_words_en/word_336.png:['super', 0.9999555]
**注意**:由于上述模型是参考[DTRB](https://arxiv.org/abs/1904.01906)文本识别训练和评估流程,与超轻量级中文识别模型训练有两方面不同:
- 训练时采用的图像分辨率不同,训练上述模型采用的图像分辨率是[3,32,100],而中文模型训练时,为了保证长文本的识别效果,训练时采用的图像分辨率是[3, 32, 320]。预测推理程序默认的的形状参数是训练中文采用的图像分辨率,即[3, 32, 320]。因此,这里推理上述英文模型时,需要通过参数rec_image_shape设置识别图像的形状。
- 字符列表,DTRB论文中实验只是针对26个小写英文本母和10个数字进行实验,总共36个字符。所有大小字符都转成了小写字符,不在上面列表的字符都忽略,认为是空格。因此这里没有输入字符字典,而是通过如下命令生成字典.因此在推理时需要设置参数rec_char_type,指定为英文"en"。
```
self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
dict_character = list(self.character_str)
```
## 文本检测、识别串联推理
### 1.超轻量中文OCR模型推理
在执行预测时,需要通过参数image_dir指定单张图像或者图像集合的路径、参数det_model_dir指定检测inference模型的路径和参数rec_model_dir指定识别inference模型的路径。可视化识别结果默认保存到 ./inference_results 文件夹里面。
```
python3 tools/infer/predict_system.py --image_dir="./doc/imgs/2.jpg" --det_model_dir="./inference/det_db/" --rec_model_dir="./inference/rec_crnn/"
```
执行命令后,识别结果图像如下:
![](imgs_results/2.jpg)
### 2.其他模型推理
如果想尝试使用其他检测算法或者识别算法,请参考上述文本检测模型推理和文本识别模型推理,更新相应配置和模型,下面给出基于EAST文本检测和STAR-Net文本识别执行命令:
```
python3 tools/infer/predict_system.py --image_dir="./doc/imgs_en/img_10.jpg" --det_model_dir="./inference/det_east/" --det_algorithm="EAST" --rec_model_dir="./inference/starnet/" --rec_image_shape="3, 32, 100" --rec_char_type="en"
```
执行命令后,识别结果图像如下:
![](imgs_results/img_10.jpg)
## 快速安装
经测试PaddleOCR可在glibc 2.23上运行,您也可以测试其他glibc版本或安装glic 2.23
PaddleOCR 工作环境
- PaddlePaddle1.7
- python3
- glibc 2.23
建议使用我们提供的docker运行PaddleOCR,有关docker使用请参考[链接](https://docs.docker.com/get-started/)
*如您希望使用 mac 或 windows直接运行预测代码,可以从第2步开始执行。*
1. (建议)准备docker环境。第一次使用这个镜像,会自动下载该镜像,请耐心等待。
```
# 切换到工作目录下
cd /home/Projects
# 首次运行需创建一个docker容器,再次运行时不需要运行当前命令
# 创建一个名字为ppocr的docker容器,并将当前目录映射到容器的/paddle目录下
如果您希望在CPU环境下使用docker,使用docker而不是nvidia-docker创建docker
sudo docker run --name ppocr -v $PWD:/paddle --network=host -it hub.baidubce.com/paddlepaddle/paddle:latest-gpu-cuda9.0-cudnn7-dev /bin/bash
如果您的机器安装的是CUDA9,请运行以下命令创建容器
sudo nvidia-docker run --name ppocr -v $PWD:/paddle --network=host -it hub.baidubce.com/paddlepaddle/paddle:latest-gpu-cuda9.0-cudnn7-dev /bin/bash
如果您的机器安装的是CUDA10,请运行以下命令创建容器
sudo nvidia-docker run --name ppocr -v $PWD:/paddle --network=host -it hub.baidubce.com/paddlepaddle/paddle:latest-gpu-cuda10.0-cudnn7-dev /bin/bash
您也可以访问[DockerHub](https://hub.docker.com/r/paddlepaddle/paddle/tags/)获取与您机器适配的镜像。
# ctrl+P+Q可退出docker,重新进入docker使用如下命令
sudo docker container exec -it ppocr /bin/bash
```
注意:如果docker pull过慢,可以按照如下步骤手动下载后加载docker,以cuda9 docker为例,使用cuda10 docker只需要将cuda9改为cuda10即可。
```
# 下载CUDA9 docker的压缩文件,并解压
wget https://paddleocr.bj.bcebos.com/docker/docker_pdocr_cuda9.tar.gz
# 为减少下载时间,上传的docker image是压缩过的,需要解压使用
tar zxf docker_pdocr_cuda9.tar.gz
# 创建image
docker load < docker_pdocr_cuda9.tar
# 完成上述步骤后通过docker images检查是否加载了下载的镜像
docker images
# 执行docker images后如果有下面的输出,即可按照按照 步骤1 创建docker环境。
hub.baidubce.com/paddlepaddle/paddle latest-gpu-cuda9.0-cudnn7-dev f56310dcc829
```
2. 安装PaddlePaddle Fluid v1.7(暂不支持更高版本,适配工作进行中)
```
pip3 install --upgrade pip
如果您的机器安装的是CUDA9,请运行以下命令安装
python3 -m pip install paddlepaddle-gpu==1.7.2.post97 -i https://pypi.tuna.tsinghua.edu.cn/simple
如果您的机器安装的是CUDA10,请运行以下命令安装
python3 -m pip install paddlepaddle-gpu==1.7.2.post107 -i https://pypi.tuna.tsinghua.edu.cn/simple
如果您的机器是CPU,请运行以下命令安装
python3 -m pip install paddlepaddle==1.7.2 -i https://pypi.tuna.tsinghua.edu.cn/simple
更多的版本需求,请参照[安装文档](https://www.paddlepaddle.org.cn/install/quick)中的说明进行操作。
```
3. 克隆PaddleOCR repo代码
```
【推荐】git clone https://github.com/PaddlePaddle/PaddleOCR
如果因为网络问题无法pull成功,也可选择使用码云上的托管:
git clone https://gitee.com/paddlepaddle/PaddleOCR
注:码云托管代码可能无法实时同步本github项目更新,存在3~5天延时,请优先使用推荐方式。
```
4. 安装第三方库
```
cd PaddleOCR
pip3 install -r requirments.txt
```
## 文字识别
### 数据准备
PaddleOCR 支持两种数据格式: `lmdb` 用于训练公开数据,调试算法; `通用数据` 训练自己的数据:
请按如下步骤设置数据集:
训练数据的默认存储路径是 `PaddleOCR/train_data`,如果您的磁盘上已有数据集,只需创建软链接至数据集目录:
```
ln -sf <path/to/dataset> <path/to/paddle_detection>/train_data/dataset
```
* 数据下载
若您本地没有数据集,可以在官网下载 [icdar2015](http://rrc.cvc.uab.es/?ch=4&com=downloads) 数据,用于快速验证。也可以参考[DTRB](https://github.com/clovaai/deep-text-recognition-benchmark#download-lmdb-dataset-for-traininig-and-evaluation-from-here),下载 benchmark 所需的lmdb格式数据集。
* 使用自己数据集:
若您希望使用自己的数据进行训练,请参考下文组织您的数据。
- 训练集
首先请将训练图片放入同一个文件夹(train_images),并用一个txt文件(rec_gt_train.txt)记录图片路径和标签。
* 注意: 默认请将图片路径和图片标签用 \t 分割,如用其他方式分割将造成训练报错
```
" 图像文件名 图像标注信息 "
train_data/train_0001.jpg 简单可依赖
train_data/train_0002.jpg 用科技让复杂的世界更简单
```
PaddleOCR 提供了一份用于训练 icdar2015 数据集的标签文件,通过以下方式下载:
```
# 训练集标签
wget -P ./train_data/ic15_data https://paddleocr.bj.bcebos.com/dataset/rec_gt_train.txt
# 测试集标签
wget -P ./train_data/ic15_data https://paddleocr.bj.bcebos.com/dataset/rec_gt_test.txt
```
最终训练集应有如下文件结构:
```
|-train_data
|-ic15_data
|- rec_gt_train.txt
|- train
|- word_001.png
|- word_002.jpg
|- word_003.jpg
| ...
```
- 测试集
同训练集类似,测试集也需要提供一个包含所有图片的文件夹(test)和一个rec_gt_test.txt,测试集的结构如下所示:
```
|-train_data
|-ic15_data
|- rec_gt_test.txt
|- test
|- word_001.jpg
|- word_002.jpg
|- word_003.jpg
| ...
```
- 字典
最后需要提供一个字典({word_dict_name}.txt),使模型在训练时,可以将所有出现的字符映射为字典的索引。
因此字典需要包含所有希望被正确识别的字符,{word_dict_name}.txt需要写成如下格式,并以 `utf-8` 编码格式保存:
```
l
d
a
d
r
n
```
word_dict.txt 每行有一个单字,将字符与数字索引映射在一起,“and” 将被映射成 [2 5 1]
`ppocr/utils/ppocr_keys_v1.txt` 是一个包含6623个字符的中文字典,
`ppocr/utils/ic15_dict.txt` 是一个包含36个字符的英文字典,
您可以按需使用。
如需自定义dic文件,请修改 `configs/rec/rec_icdar15_train.yml` 中的 `character_dict_path` 字段, 并将 `character_type` 设置为 `ch`
### 启动训练
PaddleOCR提供了训练脚本、评估脚本和预测脚本,本节将以 CRNN 识别模型为例:
首先下载pretrain model,您可以下载训练好的模型在 icdar2015 数据上进行finetune
```
cd PaddleOCR/
# 下载MobileNetV3的预训练模型
wget -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/rec_mv3_none_bilstm_ctc.tar
# 解压模型参数
cd pretrain_models
tar -xf rec_mv3_none_bilstm_ctc.tar && rm -rf rec_mv3_none_bilstm_ctc.tar
```
开始训练:
*如果您安装的是cpu版本,请将配置文件中的 `use_gpu` 字段修改为false*
```
# 设置PYTHONPATH路径
export PYTHONPATH=$PYTHONPATH:.
# GPU训练 支持单卡,多卡训练,通过CUDA_VISIBLE_DEVICES指定卡号
export CUDA_VISIBLE_DEVICES=0,1,2,3
# 训练icdar15英文数据
python3 tools/train.py -c configs/rec/rec_icdar15_train.yml
```
PaddleOCR支持训练和评估交替进行, 可以在 `configs/rec/rec_icdar15_train.yml` 中修改 `eval_batch_step` 设置评估频率,默认每500个iter评估一次。评估过程中默认将最佳acc模型,保存为 `output/rec_CRNN/best_accuracy`
如果验证集很大,测试将会比较耗时,建议减少评估次数,或训练完再进行评估。
* 提示: 可通过 -c 参数选择 `configs/rec/` 路径下的多种模型配置进行训练,PaddleOCR支持的识别算法有:
| 配置文件 | 算法名称 | backbone | trans | seq | pred |
| :--------: | :-------: | :-------: | :-------: | :-----: | :-----: |
| rec_chinese_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc |
| rec_icdar15_train.yml | CRNN | Mobilenet_v3 large 0.5 | None | BiLSTM | ctc |
| rec_mv3_none_bilstm_ctc.yml | CRNN | Mobilenet_v3 large 0.5 | None | BiLSTM | ctc |
| rec_mv3_none_none_ctc.yml | Rosetta | Mobilenet_v3 large 0.5 | None | None | ctc |
| rec_mv3_tps_bilstm_ctc.yml | STARNet | Mobilenet_v3 large 0.5 | tps | BiLSTM | ctc |
| rec_mv3_tps_bilstm_attn.yml | RARE | Mobilenet_v3 large 0.5 | tps | BiLSTM | attention |
| rec_r34_vd_none_bilstm_ctc.yml | CRNN | Resnet34_vd | None | BiLSTM | ctc |
| rec_r34_vd_none_none_ctc.yml | Rosetta | Resnet34_vd | None | None | ctc |
| rec_r34_vd_tps_bilstm_attn.yml | RARE | Resnet34_vd | tps | BiLSTM | attention |
| rec_r34_vd_tps_bilstm_ctc.yml | STARNet | Resnet34_vd | tps | BiLSTM | ctc |
训练中文数据,推荐使用`rec_chinese_lite_train.yml`,如您希望尝试其他算法在中文数据集上的效果,请参考下列说明修改配置文件:
`rec_mv3_none_none_ctc.yml` 为例:
```
Global:
...
# 修改 image_shape 以适应长文本
image_shape: [3, 32, 320]
...
# 修改字符类型
character_type: ch
# 添加自定义字典,如修改字典请将路径指向新字典
character_dict_path: ./ppocr/utils/ppocr_keys_v1.txt
...
# 修改reader类型
reader_yml: ./configs/rec/rec_chinese_reader.yml
...
...
```
**注意,预测/评估时的配置文件请务必与训练一致。**
### 评估
评估数据集可以通过 `configs/rec/rec_icdar15_reader.yml` 修改EvalReader中的 `label_file_path` 设置。
*注意* 评估时必须确保配置文件中 infer_img 字段为空
```
export CUDA_VISIBLE_DEVICES=0
# GPU 评估, Global.checkpoints 为待测权重
python3 tools/eval.py -c configs/rec/rec_icdar15_train.yml -o Global.checkpoints={path/to/weights}/best_accuracy
```
### 预测
* 训练引擎的预测
使用 PaddleOCR 训练好的模型,可以通过以下脚本进行快速预测。
默认预测图片存储在 `infer_img` 里,通过 `-o Global.checkpoints` 指定权重:
```
# 预测英文结果
python3 tools/infer_rec.py -c configs/rec/rec_icdar15_train.yml -o Global.checkpoints={path/to/weights}/best_accuracy Global.infer_img=doc/imgs_words/en/word_1.png
```
预测图片:
![](./imgs_words/en/word_1.png)
得到输入图像的预测结果:
```
infer_img: doc/imgs_words/en/word_1.png
index: [19 24 18 23 29]
word : joint
```
预测使用的配置文件必须与训练一致,如您通过 `python3 tools/train.py -c configs/rec/rec_chinese_lite_train.yml` 完成了中文模型的训练,
您可以使用如下命令进行中文模型预测。
```
# 预测中文结果
python3 tools/infer_rec.py -c configs/rec/rec_chinese_lite_train.yml -o Global.checkpoints={path/to/weights}/best_accuracy Global.infer_img=doc/imgs_words/ch/word_1.jpg
```
预测图片:
![](./imgs_words/ch/word_1.jpg)
得到输入图像的预测结果:
```
infer_img: doc/imgs_words/ch/word_1.jpg
index: [2092 177 312 2503]
word : 韩国小馆
```
# 版本更新
- 2020.6.5 支持 `attetnion` 模型导出 `inference_model`
- 2020.6.5 支持单独预测识别时,输出结果得分
- 2020.5.30 提供超轻量级中文OCR在线体验
- 2020.5.30 模型预测、训练支持Windows系统
- 2020.5.30 开源通用中文OCR模型
- 2020.5.14 发布[PaddleOCR公开课](https://www.bilibili.com/video/BV1nf4y1U7RX?p=4)
- 2020.5.14 发布[PaddleOCR实战练习](https://aistudio.baidu.com/aistudio/projectdetail/467229)
- 2020.5.14 开源8.6M超轻量级中文OCR模型
...@@ -61,8 +61,6 @@ class TrainReader(object): ...@@ -61,8 +61,6 @@ class TrainReader(object):
if len(batch_outs) == self.batch_size: if len(batch_outs) == self.batch_size:
yield batch_outs yield batch_outs
batch_outs = [] batch_outs = []
if len(batch_outs) != 0:
yield batch_outs
return batch_iter_reader return batch_iter_reader
......
...@@ -17,6 +17,8 @@ import cv2 ...@@ -17,6 +17,8 @@ import cv2
import numpy as np import numpy as np
import json import json
import sys import sys
from ppocr.utils.utility import initial_logger
logger = initial_logger()
from .data_augment import AugmentData from .data_augment import AugmentData
from .random_crop_data import RandomCropData from .random_crop_data import RandomCropData
...@@ -100,6 +102,7 @@ class DBProcessTrain(object): ...@@ -100,6 +102,7 @@ class DBProcessTrain(object):
img_path, gt_label = self.convert_label_infor(label_infor) img_path, gt_label = self.convert_label_infor(label_infor)
imgvalue = cv2.imread(img_path) imgvalue = cv2.imread(img_path)
if imgvalue is None: if imgvalue is None:
logger.info("{} does not exist!".format(img_path))
return None return None
data = self.make_data_dict(imgvalue, gt_label) data = self.make_data_dict(imgvalue, gt_label)
data = AugmentData(data) data = AugmentData(data)
......
...@@ -41,13 +41,18 @@ class LMDBReader(object): ...@@ -41,13 +41,18 @@ class LMDBReader(object):
self.loss_type = params['loss_type'] self.loss_type = params['loss_type']
self.max_text_length = params['max_text_length'] self.max_text_length = params['max_text_length']
self.mode = params['mode'] self.mode = params['mode']
self.drop_last = False
self.use_tps = False
if "tps" in params:
self.ues_tps = True
if params['mode'] == 'train': if params['mode'] == 'train':
self.batch_size = params['train_batch_size_per_card'] self.batch_size = params['train_batch_size_per_card']
elif params['mode'] == "eval": self.drop_last = True
else:
self.batch_size = params['test_batch_size_per_card'] self.batch_size = params['test_batch_size_per_card']
elif params['mode'] == "test": self.drop_last = False
self.batch_size = 1 self.infer_img = params['infer_img']
self.infer_img = params["infer_img"]
def load_hierarchical_lmdb_dataset(self): def load_hierarchical_lmdb_dataset(self):
lmdb_sets = {} lmdb_sets = {}
dataset_idx = 0 dataset_idx = 0
...@@ -100,13 +105,18 @@ class LMDBReader(object): ...@@ -100,13 +105,18 @@ class LMDBReader(object):
process_id = 0 process_id = 0
def sample_iter_reader(): def sample_iter_reader():
if self.mode == 'test': if self.mode != 'train' and self.infer_img is not None:
image_file_list = get_image_file_list(self.infer_img) image_file_list = get_image_file_list(self.infer_img)
for single_img in image_file_list: for single_img in image_file_list:
img = cv2.imread(single_img) img = cv2.imread(single_img)
if img.shape[-1]==1 or len(list(img.shape))==2: if img.shape[-1] == 1 or len(list(img.shape)) == 2:
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
norm_img = process_image(img, self.image_shape) norm_img = process_image(
img=img,
image_shape=self.image_shape,
char_ops=self.char_ops,
tps=self.use_tps,
infer_mode=True)
yield norm_img yield norm_img
else: else:
lmdb_sets = self.load_hierarchical_lmdb_dataset() lmdb_sets = self.load_hierarchical_lmdb_dataset()
...@@ -126,9 +136,13 @@ class LMDBReader(object): ...@@ -126,9 +136,13 @@ class LMDBReader(object):
if sample_info is None: if sample_info is None:
continue continue
img, label = sample_info img, label = sample_info
outs = process_image(img, self.image_shape, label, outs = process_image(
self.char_ops, self.loss_type, img=img,
self.max_text_length) image_shape=self.image_shape,
label=label,
char_ops=self.char_ops,
loss_type=self.loss_type,
max_text_length=self.max_text_length)
if outs is None: if outs is None:
continue continue
yield outs yield outs
...@@ -136,6 +150,7 @@ class LMDBReader(object): ...@@ -136,6 +150,7 @@ class LMDBReader(object):
if finish_read_num == len(lmdb_sets): if finish_read_num == len(lmdb_sets):
break break
self.close_lmdb_dataset(lmdb_sets) self.close_lmdb_dataset(lmdb_sets)
def batch_iter_reader(): def batch_iter_reader():
batch_outs = [] batch_outs = []
for outs in sample_iter_reader(): for outs in sample_iter_reader():
...@@ -143,10 +158,11 @@ class LMDBReader(object): ...@@ -143,10 +158,11 @@ class LMDBReader(object):
if len(batch_outs) == self.batch_size: if len(batch_outs) == self.batch_size:
yield batch_outs yield batch_outs
batch_outs = [] batch_outs = []
if len(batch_outs) != 0: if not self.drop_last:
yield batch_outs if len(batch_outs) != 0:
yield batch_outs
if self.mode != 'test': if self.infer_img is None:
return batch_iter_reader return batch_iter_reader
return sample_iter_reader return sample_iter_reader
...@@ -165,26 +181,34 @@ class SimpleReader(object): ...@@ -165,26 +181,34 @@ class SimpleReader(object):
self.loss_type = params['loss_type'] self.loss_type = params['loss_type']
self.max_text_length = params['max_text_length'] self.max_text_length = params['max_text_length']
self.mode = params['mode'] self.mode = params['mode']
self.infer_img = params['infer_img']
self.use_tps = False
if "tps" in params:
self.use_tps = True
if params['mode'] == 'train': if params['mode'] == 'train':
self.batch_size = params['train_batch_size_per_card'] self.batch_size = params['train_batch_size_per_card']
elif params['mode'] == 'eval': self.drop_last = True
self.batch_size = params['test_batch_size_per_card']
else: else:
self.batch_size = 1 self.batch_size = params['test_batch_size_per_card']
self.infer_img = params['infer_img'] self.drop_last = False
def __call__(self, process_id): def __call__(self, process_id):
if self.mode != 'train': if self.mode != 'train':
process_id = 0 process_id = 0
def sample_iter_reader(): def sample_iter_reader():
if self.mode == 'test': if self.mode != 'train' and self.infer_img is not None:
image_file_list = get_image_file_list(self.infer_img) image_file_list = get_image_file_list(self.infer_img)
for single_img in image_file_list: for single_img in image_file_list:
img = cv2.imread(single_img) img = cv2.imread(single_img)
if img.shape[-1]==1 or len(list(img.shape))==2: if img.shape[-1] == 1 or len(list(img.shape)) == 2:
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
norm_img = process_image(img, self.image_shape) norm_img = process_image(
img=img,
image_shape=self.image_shape,
char_ops=self.char_ops,
tps=self.use_tps,
infer_mode=True)
yield norm_img yield norm_img
else: else:
with open(self.label_file_path, "rb") as fin: with open(self.label_file_path, "rb") as fin:
...@@ -192,7 +216,7 @@ class SimpleReader(object): ...@@ -192,7 +216,7 @@ class SimpleReader(object):
img_num = len(label_infor_list) img_num = len(label_infor_list)
img_id_list = list(range(img_num)) img_id_list = list(range(img_num))
random.shuffle(img_id_list) random.shuffle(img_id_list)
if sys.platform=="win32": if sys.platform == "win32":
print("multiprocess is not fully compatible with Windows." print("multiprocess is not fully compatible with Windows."
"num_workers will be 1.") "num_workers will be 1.")
self.num_workers = 1 self.num_workers = 1
...@@ -204,7 +228,7 @@ class SimpleReader(object): ...@@ -204,7 +228,7 @@ class SimpleReader(object):
if img is None: if img is None:
logger.info("{} does not exist!".format(img_path)) logger.info("{} does not exist!".format(img_path))
continue continue
if img.shape[-1]==1 or len(list(img.shape))==2: if img.shape[-1] == 1 or len(list(img.shape)) == 2:
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
label = substr[1] label = substr[1]
...@@ -222,9 +246,10 @@ class SimpleReader(object): ...@@ -222,9 +246,10 @@ class SimpleReader(object):
if len(batch_outs) == self.batch_size: if len(batch_outs) == self.batch_size:
yield batch_outs yield batch_outs
batch_outs = [] batch_outs = []
if len(batch_outs) != 0: if not self.drop_last:
yield batch_outs if len(batch_outs) != 0:
yield batch_outs
if self.mode != 'test': if self.infer_img is None:
return batch_iter_reader return batch_iter_reader
return sample_iter_reader return sample_iter_reader
...@@ -48,6 +48,32 @@ def resize_norm_img(img, image_shape): ...@@ -48,6 +48,32 @@ def resize_norm_img(img, image_shape):
return padding_im return padding_im
def resize_norm_img_chinese(img, image_shape):
imgC, imgH, imgW = image_shape
# todo: change to 0 and modified image shape
max_wh_ratio = 0
h, w = img.shape[0], img.shape[1]
ratio = w * 1.0 / h
max_wh_ratio = max(max_wh_ratio, ratio)
imgW = int(32 * max_wh_ratio)
if math.ceil(imgH * ratio) > imgW:
resized_w = imgW
else:
resized_w = int(math.ceil(imgH * ratio))
resized_image = cv2.resize(img, (resized_w, imgH))
resized_image = resized_image.astype('float32')
if image_shape[0] == 1:
resized_image = resized_image / 255
resized_image = resized_image[np.newaxis, :]
else:
resized_image = resized_image.transpose((2, 0, 1)) / 255
resized_image -= 0.5
resized_image /= 0.5
padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
padding_im[:, :, 0:resized_w] = resized_image
return padding_im
def get_img_data(value): def get_img_data(value):
"""get_img_data""" """get_img_data"""
if not value: if not value:
...@@ -66,8 +92,13 @@ def process_image(img, ...@@ -66,8 +92,13 @@ def process_image(img,
label=None, label=None,
char_ops=None, char_ops=None,
loss_type=None, loss_type=None,
max_text_length=None): max_text_length=None,
norm_img = resize_norm_img(img, image_shape) tps=None,
infer_mode=False):
if not infer_mode or char_ops.character_type == "en" or tps != None:
norm_img = resize_norm_img(img, image_shape)
else:
norm_img = resize_norm_img_chinese(img, image_shape)
norm_img = norm_img[np.newaxis, :] norm_img = norm_img[np.newaxis, :]
if label is not None: if label is not None:
char_num = char_ops.get_char_num() char_num = char_ops.get_char_num()
......
...@@ -30,6 +30,8 @@ class RecModel(object): ...@@ -30,6 +30,8 @@ class RecModel(object):
global_params = params['Global'] global_params = params['Global']
char_num = global_params['char_ops'].get_char_num() char_num = global_params['char_ops'].get_char_num()
global_params['char_num'] = char_num global_params['char_num'] = char_num
self.char_type = global_params['character_type']
self.infer_img = global_params['infer_img']
if "TPS" in params: if "TPS" in params:
tps_params = deepcopy(params["TPS"]) tps_params = deepcopy(params["TPS"])
tps_params.update(global_params) tps_params.update(global_params)
...@@ -60,8 +62,8 @@ class RecModel(object): ...@@ -60,8 +62,8 @@ class RecModel(object):
def create_feed(self, mode): def create_feed(self, mode):
image_shape = deepcopy(self.image_shape) image_shape = deepcopy(self.image_shape)
image_shape.insert(0, -1) image_shape.insert(0, -1)
image = fluid.data(name='image', shape=image_shape, dtype='float32')
if mode == "train": if mode == "train":
image = fluid.data(name='image', shape=image_shape, dtype='float32')
if self.loss_type == "attention": if self.loss_type == "attention":
label_in = fluid.data( label_in = fluid.data(
name='label_in', name='label_in',
...@@ -86,6 +88,16 @@ class RecModel(object): ...@@ -86,6 +88,16 @@ class RecModel(object):
use_double_buffer=True, use_double_buffer=True,
iterable=False) iterable=False)
else: else:
if self.char_type == "ch" and self.infer_img:
image_shape[-1] = -1
if self.tps != None:
logger.info(
"WARNRNG!!!\n"
"TPS does not support variable shape in chinese!"
"We set img_shape to be the same , it may affect the inference effect"
)
image_shape = deepcopy(self.image_shape)
image = fluid.data(name='image', shape=image_shape, dtype='float32')
labels = None labels = None
loader = None loader = None
return image, labels, loader return image, labels, loader
...@@ -110,7 +122,11 @@ class RecModel(object): ...@@ -110,7 +122,11 @@ class RecModel(object):
return loader, outputs return loader, outputs
elif mode == "export": elif mode == "export":
predict = predicts['predict'] predict = predicts['predict']
predict = fluid.layers.softmax(predict) if self.loss_type == "ctc":
predict = fluid.layers.softmax(predict)
return [image, {'decoded_out': decoded_out, 'predicts': predict}] return [image, {'decoded_out': decoded_out, 'predicts': predict}]
else: else:
return loader, {'decoded_out': decoded_out} predict = predicts['predict']
if self.loss_type == "ctc":
predict = fluid.layers.softmax(predict)
return loader, {'decoded_out': decoded_out, 'predicts': predict}
...@@ -123,6 +123,8 @@ class AttentionPredict(object): ...@@ -123,6 +123,8 @@ class AttentionPredict(object):
full_ids = fluid.layers.fill_constant_batch_size_like( full_ids = fluid.layers.fill_constant_batch_size_like(
input=init_state, shape=[-1, 1], dtype='int64', value=1) input=init_state, shape=[-1, 1], dtype='int64', value=1)
full_scores = fluid.layers.fill_constant_batch_size_like(
input=init_state, shape=[-1, 1], dtype='float32', value=1)
cond = layers.less_than(x=counter, y=array_len) cond = layers.less_than(x=counter, y=array_len)
while_op = layers.While(cond=cond) while_op = layers.While(cond=cond)
...@@ -171,6 +173,9 @@ class AttentionPredict(object): ...@@ -171,6 +173,9 @@ class AttentionPredict(object):
new_ids = fluid.layers.concat([full_ids, topk_indices], axis=1) new_ids = fluid.layers.concat([full_ids, topk_indices], axis=1)
fluid.layers.assign(new_ids, full_ids) fluid.layers.assign(new_ids, full_ids)
new_scores = fluid.layers.concat([full_scores, topk_scores], axis=1)
fluid.layers.assign(new_scores, full_scores)
layers.increment(x=counter, value=1, in_place=True) layers.increment(x=counter, value=1, in_place=True)
# update the memories # update the memories
...@@ -184,7 +189,7 @@ class AttentionPredict(object): ...@@ -184,7 +189,7 @@ class AttentionPredict(object):
length_cond = layers.less_than(x=counter, y=array_len) length_cond = layers.less_than(x=counter, y=array_len)
finish_cond = layers.logical_not(layers.is_empty(x=topk_indices)) finish_cond = layers.logical_not(layers.is_empty(x=topk_indices))
layers.logical_and(x=length_cond, y=finish_cond, out=cond) layers.logical_and(x=length_cond, y=finish_cond, out=cond)
return full_ids return full_ids, full_scores
def __call__(self, inputs, labels=None, mode=None): def __call__(self, inputs, labels=None, mode=None):
encoder_features = self.encoder(inputs) encoder_features = self.encoder(inputs)
...@@ -223,10 +228,10 @@ class AttentionPredict(object): ...@@ -223,10 +228,10 @@ class AttentionPredict(object):
decoder_size, char_num) decoder_size, char_num)
_, decoded_out = layers.topk(input=predict, k=1) _, decoded_out = layers.topk(input=predict, k=1)
decoded_out = layers.lod_reset(decoded_out, y=label_out) decoded_out = layers.lod_reset(decoded_out, y=label_out)
predicts = {'predict': predict, 'decoded_out': decoded_out} predicts = {'predict':predict, 'decoded_out':decoded_out}
else: else:
ids = self.gru_attention_infer( ids, predict = self.gru_attention_infer(
decoder_boot, self.max_length, char_num, word_vector_dim, decoder_boot, self.max_length, char_num, word_vector_dim,
encoded_vector, encoded_proj, decoder_size) encoded_vector, encoded_proj, decoder_size)
predicts = {'decoded_out': ids} predicts = {'predict':predict, 'decoded_out':ids}
return predicts return predicts
...@@ -48,7 +48,7 @@ def eval_rec_run(exe, config, eval_info_dict, mode): ...@@ -48,7 +48,7 @@ def eval_rec_run(exe, config, eval_info_dict, mode):
total_sample_num = 0 total_sample_num = 0
total_acc_num = 0 total_acc_num = 0
total_batch_num = 0 total_batch_num = 0
if mode == "test": if mode == "eval":
is_remove_duplicate = False is_remove_duplicate = False
else: else:
is_remove_duplicate = True is_remove_duplicate = True
...@@ -91,11 +91,11 @@ def test_rec_benchmark(exe, config, eval_info_dict): ...@@ -91,11 +91,11 @@ def test_rec_benchmark(exe, config, eval_info_dict):
total_correct_number = 0 total_correct_number = 0
eval_data_acc_info = {} eval_data_acc_info = {}
for eval_data in eval_data_list: for eval_data in eval_data_list:
config['EvalReader']['lmdb_sets_dir'] = \ config['TestReader']['lmdb_sets_dir'] = \
eval_data_dir + "/" + eval_data eval_data_dir + "/" + eval_data
eval_reader = reader_main(config=config, mode="eval") eval_reader = reader_main(config=config, mode="test")
eval_info_dict['reader'] = eval_reader eval_info_dict['reader'] = eval_reader
metrics = eval_rec_run(exe, config, eval_info_dict, "eval") metrics = eval_rec_run(exe, config, eval_info_dict, "test")
total_evaluation_data_number += metrics['total_sample_num'] total_evaluation_data_number += metrics['total_sample_num']
total_correct_number += metrics['total_acc_num'] total_correct_number += metrics['total_acc_num']
eval_data_acc_info[eval_data] = metrics eval_data_acc_info[eval_data] = metrics
......
...@@ -32,10 +32,16 @@ class TextRecognizer(object): ...@@ -32,10 +32,16 @@ class TextRecognizer(object):
self.rec_image_shape = image_shape self.rec_image_shape = image_shape
self.character_type = args.rec_char_type self.character_type = args.rec_char_type
self.rec_batch_num = args.rec_batch_num self.rec_batch_num = args.rec_batch_num
self.rec_algorithm = args.rec_algorithm
char_ops_params = {} char_ops_params = {}
char_ops_params["character_type"] = args.rec_char_type char_ops_params["character_type"] = args.rec_char_type
char_ops_params["character_dict_path"] = args.rec_char_dict_path char_ops_params["character_dict_path"] = args.rec_char_dict_path
char_ops_params['loss_type'] = 'ctc' if self.rec_algorithm != "RARE":
char_ops_params['loss_type'] = 'ctc'
self.loss_type = 'ctc'
else:
char_ops_params['loss_type'] = 'attention'
self.loss_type = 'attention'
self.char_ops = CharacterOps(char_ops_params) self.char_ops = CharacterOps(char_ops_params)
def resize_norm_img(self, img, max_wh_ratio): def resize_norm_img(self, img, max_wh_ratio):
...@@ -80,26 +86,43 @@ class TextRecognizer(object): ...@@ -80,26 +86,43 @@ class TextRecognizer(object):
starttime = time.time() starttime = time.time()
self.input_tensor.copy_from_cpu(norm_img_batch) self.input_tensor.copy_from_cpu(norm_img_batch)
self.predictor.zero_copy_run() self.predictor.zero_copy_run()
rec_idx_batch = self.output_tensors[0].copy_to_cpu()
rec_idx_lod = self.output_tensors[0].lod()[0] if self.loss_type == "ctc":
predict_batch = self.output_tensors[1].copy_to_cpu() rec_idx_batch = self.output_tensors[0].copy_to_cpu()
predict_lod = self.output_tensors[1].lod()[0] rec_idx_lod = self.output_tensors[0].lod()[0]
elapse = time.time() - starttime predict_batch = self.output_tensors[1].copy_to_cpu()
predict_time += elapse predict_lod = self.output_tensors[1].lod()[0]
starttime = time.time() elapse = time.time() - starttime
for rno in range(len(rec_idx_lod) - 1): predict_time += elapse
beg = rec_idx_lod[rno] for rno in range(len(rec_idx_lod) - 1):
end = rec_idx_lod[rno + 1] beg = rec_idx_lod[rno]
rec_idx_tmp = rec_idx_batch[beg:end, 0] end = rec_idx_lod[rno + 1]
preds_text = self.char_ops.decode(rec_idx_tmp) rec_idx_tmp = rec_idx_batch[beg:end, 0]
beg = predict_lod[rno] preds_text = self.char_ops.decode(rec_idx_tmp)
end = predict_lod[rno + 1] beg = predict_lod[rno]
probs = predict_batch[beg:end, :] end = predict_lod[rno + 1]
ind = np.argmax(probs, axis=1) probs = predict_batch[beg:end, :]
blank = probs.shape[1] ind = np.argmax(probs, axis=1)
valid_ind = np.where(ind != (blank - 1))[0] blank = probs.shape[1]
score = np.mean(probs[valid_ind, ind[valid_ind]]) valid_ind = np.where(ind != (blank - 1))[0]
rec_res.append([preds_text, score]) score = np.mean(probs[valid_ind, ind[valid_ind]])
rec_res.append([preds_text, score])
else:
rec_idx_batch = self.output_tensors[0].copy_to_cpu()
predict_batch = self.output_tensors[1].copy_to_cpu()
elapse = time.time() - starttime
predict_time += elapse
for rno in range(len(rec_idx_batch)):
end_pos = np.where(rec_idx_batch[rno, :] == 1)[0]
if len(end_pos) <= 1:
preds = rec_idx_batch[rno, 1:]
score = np.mean(predict_batch[rno, 1:])
else:
preds = rec_idx_batch[rno, 1:end_pos[1]]
score = np.mean(predict_batch[rno, 1:end_pos[1]])
preds_text = self.char_ops.decode(preds)
rec_res.append([preds_text, score])
return rec_res, predict_time return rec_res, predict_time
...@@ -116,7 +139,17 @@ if __name__ == "__main__": ...@@ -116,7 +139,17 @@ if __name__ == "__main__":
continue continue
valid_image_file_list.append(image_file) valid_image_file_list.append(image_file)
img_list.append(img) img_list.append(img)
rec_res, predict_time = text_recognizer(img_list) try:
rec_res, predict_time = text_recognizer(img_list)
except Exception as e:
print(e)
logger.info(
"ERROR!!!! \n"
"Please read the FAQ:https://github.com/PaddlePaddle/PaddleOCR#faq \n"
"If your model has tps module: "
"TPS does not support variable shape.\n"
"Please set --rec_image_shape='3,32,100' and --rec_char_type='en' ")
exit()
for ino in range(len(img_list)): for ino in range(len(img_list)):
print("Predicts of %s:%s" % (valid_image_file_list[ino], rec_res[ino])) print("Predicts of %s:%s" % (valid_image_file_list[ino], rec_res[ino]))
print("Total predict time for %d images:%.3f" % print("Total predict time for %d images:%.3f" %
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment