Commit 80aced81 authored by Leif's avatar Leif
Browse files

Merge remote-tracking branch 'origin/dygraph' into dygraph

parents fce82425 896d149e
......@@ -34,10 +34,10 @@ op:
client_type: local_predictor
#det模型路径
model_config: ./ppocr_det_mobile_2.0_serving
model_config: ./ppocr_det_v3_serving
#Fetch结果列表,以client_config中fetch_var的alias_name为准
fetch_list: ["save_infer_model/scale_0.tmp_1"]
fetch_list: ["sigmoid_0.tmp_0"]
#计算硬件ID,当devices为""或不写时为CPU预测;当devices为"0", "0,1,2"时为GPU预测,表示使用的GPU卡
devices: "0"
......@@ -60,10 +60,10 @@ op:
client_type: local_predictor
#rec模型路径
model_config: ./ppocr_rec_mobile_2.0_serving
model_config: ./ppocr_rec_v3_serving
#Fetch结果列表,以client_config中fetch_var的alias_name为准
fetch_list: ["save_infer_model/scale_0.tmp_1"]
fetch_list: ["softmax_5.tmp_0"]
#计算硬件ID,当devices为""或不写时为CPU预测;当devices为"0", "0,1,2"时为GPU预测,表示使用的GPU卡
devices: "0"
......
......@@ -392,38 +392,8 @@ class OCRReader(object):
return norm_img_batch[0]
def postprocess_old(self, outputs, with_score=False):
rec_res = []
rec_idx_lod = outputs["ctc_greedy_decoder_0.tmp_0.lod"]
rec_idx_batch = outputs["ctc_greedy_decoder_0.tmp_0"]
if with_score:
predict_lod = outputs["softmax_0.tmp_0.lod"]
for rno in range(len(rec_idx_lod) - 1):
beg = rec_idx_lod[rno]
end = rec_idx_lod[rno + 1]
if isinstance(rec_idx_batch, list):
rec_idx_tmp = [x[0] for x in rec_idx_batch[beg:end]]
else: #nd array
rec_idx_tmp = rec_idx_batch[beg:end, 0]
preds_text = self.char_ops.decode(rec_idx_tmp)
if with_score:
beg = predict_lod[rno]
end = predict_lod[rno + 1]
if isinstance(outputs["softmax_0.tmp_0"], list):
outputs["softmax_0.tmp_0"] = np.array(outputs[
"softmax_0.tmp_0"]).astype(np.float32)
probs = outputs["softmax_0.tmp_0"][beg:end, :]
ind = np.argmax(probs, axis=1)
blank = probs.shape[1]
valid_ind = np.where(ind != (blank - 1))[0]
score = np.mean(probs[valid_ind, ind[valid_ind]])
rec_res.append([preds_text, score])
else:
rec_res.append([preds_text])
return rec_res
def postprocess(self, outputs, with_score=False):
preds = outputs["save_infer_model/scale_0.tmp_1"]
preds = outputs["softmax_5.tmp_0"]
try:
preds = preds.numpy()
except:
......
......@@ -56,7 +56,7 @@ class DetOp(Op):
return {"x": det_img[np.newaxis, :].copy()}, False, None, ""
def postprocess(self, input_dicts, fetch_dict, data_id, log_id):
det_out = fetch_dict["save_infer_model/scale_0.tmp_1"]
det_out = fetch_dict["sigmoid_0.tmp_0"]
ratio_list = [
float(self.new_h) / self.ori_h, float(self.new_w) / self.ori_w
]
......
......@@ -55,7 +55,7 @@ class DetOp(Op):
return {"x": det_img[np.newaxis, :].copy()}, False, None, ""
def postprocess(self, input_dicts, fetch_dict, data_id, log_id):
det_out = fetch_dict["save_infer_model/scale_0.tmp_1"]
det_out = fetch_dict["sigmoid_0.tmp_0"]
ratio_list = [
float(self.new_h) / self.ori_h, float(self.new_w) / self.ori_w
]
......
......@@ -392,38 +392,8 @@ class OCRReader(object):
return norm_img_batch[0]
def postprocess_old(self, outputs, with_score=False):
rec_res = []
rec_idx_lod = outputs["ctc_greedy_decoder_0.tmp_0.lod"]
rec_idx_batch = outputs["ctc_greedy_decoder_0.tmp_0"]
if with_score:
predict_lod = outputs["softmax_0.tmp_0.lod"]
for rno in range(len(rec_idx_lod) - 1):
beg = rec_idx_lod[rno]
end = rec_idx_lod[rno + 1]
if isinstance(rec_idx_batch, list):
rec_idx_tmp = [x[0] for x in rec_idx_batch[beg:end]]
else: #nd array
rec_idx_tmp = rec_idx_batch[beg:end, 0]
preds_text = self.char_ops.decode(rec_idx_tmp)
if with_score:
beg = predict_lod[rno]
end = predict_lod[rno + 1]
if isinstance(outputs["softmax_0.tmp_0"], list):
outputs["softmax_0.tmp_0"] = np.array(outputs[
"softmax_0.tmp_0"]).astype(np.float32)
probs = outputs["softmax_0.tmp_0"][beg:end, :]
ind = np.argmax(probs, axis=1)
blank = probs.shape[1]
valid_ind = np.where(ind != (blank - 1))[0]
score = np.mean(probs[valid_ind, ind[valid_ind]])
rec_res.append([preds_text, score])
else:
rec_res.append([preds_text])
return rec_res
def postprocess(self, outputs, with_score=False):
preds = outputs["save_infer_model/scale_0.tmp_1"]
preds = outputs["softmax_5.tmp_0"]
try:
preds = preds.numpy()
except:
......
# PP-OCR Models Pruning
Generally, a more complex model would achive better performance in the task, but it also leads to some redundancy in the model. Model Pruning is a technique that reduces this redundancy by removing the sub-models in the neural network model, so as to reduce model calculation complexity and improve model inference performance.
Generally, a more complex model would achieve better performance in the task, but it also leads to some redundancy in the model. Model Pruning is a technique that reduces this redundancy by removing the sub-models in the neural network model, so as to reduce model calculation complexity and improve model inference performance.
This example uses PaddleSlim provided[APIs of Pruning](https://github.com/PaddlePaddle/PaddleSlim/tree/develop/docs/zh_cn/api_cn/dygraph/pruners) to compress the OCR model.
[PaddleSlim](https://github.com/PaddlePaddle/PaddleSlim), an open source library which integrates model pruning, quantization (including quantization training and offline quantization), distillation, neural network architecture search, and many other commonly used and leading model compression technique in the industry.
......
......@@ -94,7 +94,7 @@ def main(config, device, logger, vdl_writer):
config['Optimizer'],
epochs=config['Global']['epoch_num'],
step_each_epoch=len(train_dataloader),
parameters=model.parameters())
model=model)
# build metric
eval_class = build_metric(config['Metric'])
......
......@@ -22,9 +22,7 @@
### 1. 安装PaddleSlim
```bash
git clone https://github.com/PaddlePaddle/PaddleSlim.git
cd PaddleSlim
python setup.py install
pip3 install paddleslim==2.2.2
```
### 2. 准备训练好的模型
......@@ -43,7 +41,15 @@ python deploy/slim/quantization/quant.py -c configs/det/ch_ppocr_v2.0/ch_det_mv3
wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_train.tar
tar -xf ch_ppocr_mobile_v2.0_det_train.tar
python deploy/slim/quantization/quant.py -c configs/det/ch_ppocr_v2.0/ch_det_mv3_db_v2.0.yml -o Global.pretrained_model=./ch_ppocr_mobile_v2.0_det_train/best_accuracy Global.save_model_dir=./output/quant_model
```
模型蒸馏和模型量化可以同时使用,以PPOCRv3检测模型为例:
```
# 下载检测预训练模型:
wget https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_distill_train.tar
tar xf https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_distill_train.tar
python deploy/slim/quantization/quant.py -c configs/det/ch_PP-OCRv3_det/ch_PP-OCRv3_det_cml.yml -o Global.pretrained_model='./ch_PP-OCRv3_det_distill_train/best_accuracy' Global.save_model_dir=./output/quant_model_distill/
```
如果要训练识别模型的量化,修改配置文件和加载的模型参数即可。
......
......@@ -25,9 +25,7 @@ After training, if you want to further compress the model size and accelerate th
### 1. Install PaddleSlim
```bash
git clone https://github.com/PaddlePaddle/PaddleSlim.git
cd PaddlSlim
python setup.py install
pip3 install paddleslim==2.2.2
```
......@@ -52,6 +50,17 @@ python deploy/slim/quantization/quant.py -c configs/det/ch_ppocr_v2.0/ch_det_mv3
```
Model distillation and model quantization can be used at the same time, taking the PPOCRv3 detection model as an example:
```
# download provided model
wget https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_distill_train.tar
tar xf https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_distill_train.tar
python deploy/slim/quantization/quant.py -c configs/det/ch_PP-OCRv3_det/ch_PP-OCRv3_det_cml.yml -o Global.pretrained_model='./ch_PP-OCRv3_det_distill_train/best_accuracy' Global.save_model_dir=./output/quant_model_distill/
```
If you want to quantify the text recognition model, you can modify the configuration file and loaded model parameters.
### 4. Export inference model
Once we got the model after pruning and fine-tuning, we can export it as an inference model for the deployment of predictive tasks:
......
......@@ -17,9 +17,9 @@ import sys
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '..', '..', '..')))
sys.path.append(
os.path.abspath(os.path.join(__dir__, '..', '..', '..', 'tools')))
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..', '..', '..')))
sys.path.insert(
0, os.path.abspath(os.path.join(__dir__, '..', '..', '..', 'tools')))
import argparse
......@@ -129,7 +129,6 @@ def main():
quanter.quantize(model)
load_model(config, model)
model.eval()
# build metric
eval_class = build_metric(config['Metric'])
......@@ -142,6 +141,7 @@ def main():
# start eval
metric = program.eval(model, valid_dataloader, post_process_class,
eval_class, model_type, use_srn)
model.eval()
logger.info('metric eval ***************')
for k, v in metric.items():
......@@ -156,7 +156,6 @@ def main():
if arch_config["algorithm"] in ["Distillation", ]: # distillation model
archs = list(arch_config["Models"].values())
for idx, name in enumerate(model.model_name_list):
model.model_list[idx].eval()
sub_model_save_path = os.path.join(save_path, name, "inference")
export_single_model(model.model_list[idx], archs[idx],
sub_model_save_path, logger, quanter)
......
......@@ -161,7 +161,13 @@ def main(config, device, logger, vdl_writer):
if config["Global"]["pretrained_model"] is not None:
pre_best_model_dict = load_model(config, model)
quanter = QAT(config=quant_config, act_preprocess=PACT)
freeze_params = False
if config['Architecture']["algorithm"] in ["Distillation"]:
for key in config['Architecture']["Models"]:
freeze_params = freeze_params or config['Architecture']['Models'][
key].get('freeze_params', False)
act = None if freeze_params else PACT
quanter = QAT(config=quant_config, act_preprocess=act)
quanter.quantize(model)
if config['Global']['distributed']:
......
......@@ -114,7 +114,7 @@ A: PGNet不需要字符级别的标注,NMS操作以及ROI操作。同时提出
(3)端到端统计:
端对端召回率:准确检测并正确识别文本行在全部标注文本行的占比;
端到端准确率:准确检测并正确识别文本行在 检测到的文本行数量 的占比;
准确检测的标准是检测框与标注框的IOU大于某个阈值,正确识别的检测框中的文本与标注的文本相同。
准确检测的标准是检测框与标注框的IOU大于某个阈值,正确识别的检测框中的文本与标注的文本相同。
<a name="15"></a>
......
[English](../doc_en/PP-OCRv3_introduction_en.md) | 简体中文
# PP-OCRv3
- [1. 简介](#1)
- [2. 检测优化](#2)
- [3. 识别优化](#3)
- [4. 端到端评估](#4)
<a name="1"></a>
## 1. 简介
PP-OCRv3在PP-OCRv2的基础上进一步升级。检测模型仍然基于DB算法,优化策略采用了带残差注意力机制的FPN结构RSEFPN、增大感受野的PAN结构LKPAN、基于DML训练的更优的教师模型;识别模型将base模型从CRNN替换成了IJCAI 2022论文[SVTR](),并采用SVTR轻量化、带指导训练CTC、数据增广策略RecConAug、自监督训练的更好的预训练模型、无标签数据的使用进行模型加速和效果提升。更多细节请参考PP-OCRv3[技术报告](./PP-OCRv3_introduction.md)。
PP-OCRv3系统pipeline如下:
<div align="center">
<img src="../ppocrv3_framework.png" width="800">
</div>
<a name="2"></a>
## 2. 检测优化
PP-OCRv3采用PP-OCRv2的[CML](https://arxiv.org/pdf/2109.03144.pdf)蒸馏策略,在蒸馏的student模型、teacher模型精度提升,CML蒸馏策略上分别做了优化。
- 在蒸馏student模型精度提升方面,提出了基于残差结构的通道注意力模块RSEFPN(Residual Squeeze-and-Excitation FPN),用于提升student模型精度和召回。
RSEFPN的网络结构如下图所示,RSEFPN在PP-OCRv2的FPN基础上,将FPN中的卷积层更换为了通道注意力结构的RSEConv层。
<div align="center">
<img src=".././ppocr_v3/RSEFPN.png" width="800">
</div>
RSEFPN将PP-OCR检测模型的精度hmean从81.3%提升到84.5%。模型大小从3M变为3.6M。
*注:PP-OCRv2的FPN通道数仅为96和24,如果直接用SE模块代替FPN的卷积会导致精度下降,RSEConv引入残差结构可以防止训练中包含重要特征的通道被抑制。*
- 在蒸馏的teacher模型精度提升方面,提出了LKPAN结构替换PP-OCRv2的FPN结构,并且使用ResNet50作为Backbone,更大的模型带来更多的精度提升。另外,对teacher模型使用[DML](https://arxiv.org/abs/1706.00384)蒸馏策略进一步提升teacher模型的精度。最终teacher的模型指标相比ppocr_server_v2.0从83.2%提升到了86.0%。
*注:[PP-OCRv2的FPN结构](https://github.com/PaddlePaddle/PaddleOCR/blob/77acb3bfe51c8a46c684527f73cd218cefedb4a3/ppocr/modeling/necks/db_fpn.py#L107)对DB算法FPN结构做了轻量级设计*
LKPAN的网络结构如下图所示:
<div align="center">
<img src="../ppocr_v3/LKPAN.png" width="800">
</div>
LKPAN(Large Kernel PAN)是一个具有更大感受野的轻量级[PAN](https://arxiv.org/pdf/1803.01534.pdf)结构。在LKPAN的path augmentation中,使用kernel size为`9*9`的卷积;更大的kernel size意味着更大的感受野,更容易检测大字体的文字以及极端长宽比的文字。LKPAN将PP-OCR检测模型的精度hmean从81.3%提升到84.9%。
*注:LKPAN相比RSEFPN有更多的精度提升,但是考虑到模型大小和预测速度等因素,在student模型中使用RSEFPN。*
采用上述策略,PP-OCRv3相比PP-OCRv2,hmean指标从83.3%提升到85.4%;预测速度从平均117ms/image变为124ms/image。
3. PP-OCRv3检测模型消融实验
|序号|策略|模型大小|hmean|Intel Gold 6148CPU+mkldnn预测耗时|
|-|-|-|-|-|
|0|PP-OCR|3M|81.3%|117ms|
|1|PP-OCRV2|3M|83.3%|117ms|
|2|0 + RESFPN|3.6M|84.5%|124ms|
|3|0 + LKPAN|4.6M|84.9%|156ms|
|4|ppocr_server_v2.0 |124M|83.2%||171ms|
|5|teacher + DML + LKPAN|124M|86.0%|396ms|
|6|0 + 2 + 5 + CML|3.6M|85.4%|124ms|
<a name="3"></a>
## 3. 识别优化
[SVTR](https://arxiv.org/abs/2205.00159) 证明了强大的单视觉模型(无需序列模型)即可高效准确完成文本识别任务,在中英文数据上均有优秀的表现。经过实验验证,SVTR_Tiny在自建的 [中文数据集上](https://arxiv.org/abs/2109.03144) ,识别精度可以提升10.7%,网络结构如下所示:
<img src="../ppocr_v3/svtr_tiny.jpg" width=800>
由于 MKLDNN 加速库支持的模型结构有限,SVTR 在CPU+MKLDNN上相比PP-OCRv2慢了10倍。
PP-OCRv3 期望在提升模型精度的同时,不带来额外的推理耗时。通过分析发现,SVTR_Tiny结构的主要耗时模块为Mixing Block,因此我们对 SVTR_Tiny 的结构进行了一系列优化(详细速度数据请参考下方消融实验表格):
1. 将SVTR网络前半部分替换为PP-LCNet的前三个stage,保留4个 Global Mixing Block ,精度为76%,加速69%,网络结构如下所示:
<img src="../ppocr_v3/svtr_g4.png" width=800>
2. 将4个 Global Attenntion Block 减小到2个,精度为72.9%,加速69%,网络结构如下所示:
<img src="../ppocr_v3/svtr_g2.png" width=800>
3. 实验发现 Global Attention 的预测速度与输入其特征的shape有关,因此后移Global Mixing Block的位置到池化层之后,精度下降为71.9%,速度超越 CNN-base 的PP-OCRv2 22%,网络结构如下所示:
<img src="../ppocr_v3/ppocr_v3.png" width=800>
为了提升模型精度同时不引入额外推理成本,PP-OCRv3参考GTC策略,使用Attention监督CTC训练,预测时完全去除Attention模块,在推理阶段不增加任何耗时, 精度提升3.8%,训练流程如下所示:
<img src="../ppocr_v3/GTC.png" width=800>
在训练策略方面,PP-OCRv3参考 [SSL](https://github.com/ku21fan/STR-Fewer-Labels) 设计了文本方向任务,训练了适用于文本识别的预训练模型,加速模型收敛过程,精度提升了0.6%; 使用UDML蒸馏策略,进一步提升精度1.5%,训练流程所示:
<img src="../ppocr_v3/SSL.png" width="300"> <img src="../ppocr_v3/UDML.png" width="500">
数据增强方面:
1. 基于 [ConCLR](https://www.cse.cuhk.edu.hk/~byu/papers/C139-AAAI2022-ConCLR.pdf) 中的ConAug方法,设计了 RecConAug 数据增强方法,增强数据多样性,精度提升0.5%,增强可视化效果如下所示:
<img src="../ppocr_v3/recconaug.png" width=800>
2. 使用训练好的 SVTR_large 预测 120W 的 lsvt 无标注数据,取出其中得分大于0.95的数据,共得到81W识别数据加入到PP-OCRv3的训练数据中,精度提升1%。
总体来讲PP-OCRv3识别从网络结构、训练策略、数据增强三个方向做了进一步优化:
- 网络结构上:考虑[SVTR](https://arxiv.org/abs/2205.00159) 在中英文效果上的优越性,采用SVTR_Tiny作为base,选取Global Mixing Block和卷积组合提取特征,并将Global Mixing Block位置后移进行加速; 参考 [GTC](https://arxiv.org/pdf/2002.01276.pdf) 策略,使用注意力机制模块指导CTC训练,定位和识别字符,提升不规则文本的识别精度。
- 训练策略上:参考 [SSL](https://github.com/ku21fan/STR-Fewer-Labels) 设计了方向分类前序任务,获取更优预训练模型,加速模型收敛过程,提升精度; 使用UDML蒸馏策略、监督attention、ctc两个分支得到更优模型。
- 数据增强上:基于 [ConCLR](https://www.cse.cuhk.edu.hk/~byu/papers/C139-AAAI2022-ConCLR.pdf) 中的ConAug方法,改进得到 RecConAug 数据增广方法,支持随机结合任意多张图片,提升训练数据的上下文信息丰富度,增强模型鲁棒性;使用 SVTR_large 预测无标签数据,向训练集中补充81w高质量真实数据。
基于上述策略,PP-OCRv3识别模型相比PP-OCRv2,在速度可比的情况下,精度进一步提升4.5%。 具体消融实验如下所示:
实验细节:
| id | 策略 | 模型大小 | 精度 | 速度(cpu + mkldnn)|
|-----|-----|--------|----| --- |
| 01 | PP-OCRv2 | 8M | 69.3% | 8.54ms |
| 02 | SVTR_Tiny | 21M | 80.1% | 97ms |
| 03 | LCNet_SVTR_G4 | 9.2M | 76% | 30ms |
| 04 | LCNet_SVTR_G2 | 13M | 72.98% | 9.37ms |
| 05 | PP-OCRv3 | 12M | 71.9% | 6.6ms |
| 06 | + large input_shape | 12M | 73.98% | 7.6ms |
| 06 | + GTC | 12M | 75.8% | 7.6ms |
| 07 | + RecConAug | 12M | 76.3% | 7.6ms |
| 08 | + SSL pretrain | 12M | 76.9% | 7.6ms |
| 09 | + UDML | 12M | 78.4% | 7.6ms |
| 10 | + unlabeled data | 12M | 79.4% | 7.6ms |
注: 测试速度时,实验01-05输入图片尺寸均为(3,32,320),06-10输入图片尺寸均为(3,48,320)
<a name="4"></a>
## 4. 端到端评估
......@@ -246,7 +246,7 @@ class MyMetric(object):
def get_metric(self):
"""
return metircs {
return metrics {
'acc': 0,
'norm_edit_dis': 0,
}
......
# EAST
- [1. 算法简介](#1)
- [2. 环境配置](#2)
- [3. 模型训练、评估、预测](#3)
- [3.1 训练](#3-1)
- [3.2 评估](#3-2)
- [3.3 预测](#3-3)
- [4. 推理部署](#4)
- [4.1 Python推理](#4-1)
- [4.2 C++推理](#4-2)
- [4.3 Serving服务化部署](#4-3)
- [4.4 更多推理部署](#4-4)
- [5. FAQ](#5)
<a name="1"></a>
## 1. 算法简介
论文信息:
> [EAST: An Efficient and Accurate Scene Text Detector](https://arxiv.org/abs/1704.03155)
> Xinyu Zhou, Cong Yao, He Wen, Yuzhi Wang, Shuchang Zhou, Weiran He, Jiajun Liang
> CVPR, 2017
在ICDAR2015文本检测公开数据集上,算法复现效果如下:
|模型|骨干网络|配置文件|precision|recall|Hmean|下载链接|
| --- | --- | --- | --- | --- | --- | --- |
|EAST|ResNet50_vd|88.71%| 81.36%| 84.88%| [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_east_v2.0_train.tar)|
|EAST| MobileNetV3| 78.2%| 79.1%| 78.65%| [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_east_v2.0_train.tar)|
<a name="2"></a>
## 2. 环境配置
请先参考[《运行环境准备》](./environment.md)配置PaddleOCR运行环境,参考[《项目克隆》](./clone.md)克隆项目代码。
<a name="3"></a>
## 3. 模型训练、评估、预测
上表中的EAST训练模型使用ICDAR2015文本检测公开数据集训练得到,数据集下载可参考 [ocr_datasets](./dataset/ocr_datasets.md)
数据下载完成后,请参考[文本检测训练教程](./detection.md)进行训练。PaddleOCR对代码进行了模块化,训练不同的检测模型只需要**更换配置文件**即可。
<a name="4"></a>
## 4. 推理部署
<a name="4-1"></a>
### 4.1 Python推理
首先将EAST文本检测训练过程中保存的模型,转换成inference model。以基于Resnet50_vd骨干网络,在ICDAR2015英文数据集训练的模型为例([训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_east_v2.0_train.tar)),可以使用如下命令进行转换:
```shell
python3 tools/export_model.py -c configs/det/det_r50_vd_east.yml -o Global.pretrained_model=./det_r50_vd_east_v2.0_train/best_accuracy Global.save_inference_dir=./inference/det_r50_east/
```
EAST文本检测模型推理,需要设置参数--det_algorithm="EAST",执行预测:
```shell
python3 tools/infer/predict_det.py --image_dir="./doc/imgs_en/img_10.jpg" --det_model_dir="./inference/det_r50_east/" --det_algorithm="EAST"
```
可视化文本检测结果默认保存到`./inference_results`文件夹里面,结果文件的名称前缀为'det_res'。
![](../imgs_results/det_res_img_10_east.jpg)
<a name="4-2"></a>
### 4.2 C++推理
由于后处理暂未使用CPP编写,EAST文本检测模型暂不支持CPP推理。
<a name="4-3"></a>
### 4.3 Serving服务化部署
暂未支持
<a name="4-4"></a>
### 4.4 更多推理部署
暂未支持
<a name="5"></a>
## 5. FAQ
## 引用
```bibtex
@inproceedings{zhou2017east,
title={East: an efficient and accurate scene text detector},
author={Zhou, Xinyu and Yao, Cong and Wen, He and Wang, Yuzhi and Zhou, Shuchang and He, Weiran and Liang, Jiajun},
booktitle={Proceedings of the IEEE conference on Computer Vision and Pattern Recognition},
pages={5551--5560},
year={2017}
}
```
# SAST
- [1. 算法简介](#1)
- [2. 环境配置](#2)
- [3. 模型训练、评估、预测](#3)
- [3.1 训练](#3-1)
- [3.2 评估](#3-2)
- [3.3 预测](#3-3)
- [4. 推理部署](#4)
- [4.1 Python推理](#4-1)
- [4.2 C++推理](#4-2)
- [4.3 Serving服务化部署](#4-3)
- [4.4 更多推理部署](#4-4)
- [5. FAQ](#5)
<a name="1"></a>
## 1. 算法简介
论文信息:
> [A Single-Shot Arbitrarily-Shaped Text Detector based on Context Attended Multi-Task Learning](https://arxiv.org/abs/1908.05498)
> Wang, Pengfei and Zhang, Chengquan and Qi, Fei and Huang, Zuming and En, Mengyi and Han, Junyu and Liu, Jingtuo and Ding, Errui and Shi, Guangming
> ACM MM, 2019
在ICDAR2015文本检测公开数据集上,算法复现效果如下:
|模型|骨干网络|配置文件|precision|recall|Hmean|下载链接|
| --- | --- | --- | --- | --- | --- | --- |
|SAST|ResNet50_vd|[configs/det/det_r50_vd_sast_icdar15.yml](../../configs/det/det_r50_vd_sast_icdar15.yml)|91.39%|83.77%|87.42%|[训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_sast_icdar15_v2.0_train.tar)|
在Total-text文本检测公开数据集上,算法复现效果如下:
|模型|骨干网络|配置文件|precision|recall|Hmean|下载链接|
| --- | --- | --- | --- | --- | --- | --- |
|SAST|ResNet50_vd|[configs/det/det_r50_vd_sast_totaltext.yml](../../configs/det/det_r50_vd_sast_totaltext.yml)|89.63%|78.44%|83.66%|[训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_sast_totaltext_v2.0_train.tar)|
<a name="2"></a>
## 2. 环境配置
请先参考[《运行环境准备》](./environment.md)配置PaddleOCR运行环境,参考[《项目克隆》](./clone.md)克隆项目代码。
<a name="3"></a>
## 3. 模型训练、评估、预测
请参考[文本检测训练教程](./detection.md)。PaddleOCR对代码进行了模块化,训练不同的检测模型只需要**更换配置文件**即可。
<a name="4"></a>
## 4. 推理部署
<a name="4-1"></a>
### 4.1 Python推理
#### (1). 四边形文本检测模型(ICDAR2015)
首先将SAST文本检测训练过程中保存的模型,转换成inference model。以基于Resnet50_vd骨干网络,在ICDAR2015英文数据集训练的模型为例([模型下载地址](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_sast_icdar15_v2.0_train.tar)),可以使用如下命令进行转换:
```
python3 tools/export_model.py -c configs/det/det_r50_vd_sast_icdar15.yml -o Global.pretrained_model=./det_r50_vd_sast_icdar15_v2.0_train/best_accuracy Global.save_inference_dir=./inference/det_sast_ic15
```
**SAST文本检测模型推理,需要设置参数`--det_algorithm="SAST"`**,可以执行如下命令:
```
python3 tools/infer/predict_det.py --det_algorithm="SAST" --image_dir="./doc/imgs_en/img_10.jpg" --det_model_dir="./inference/det_sast_ic15/"
```
可视化文本检测结果默认保存到`./inference_results`文件夹里面,结果文件的名称前缀为'det_res'。结果示例如下:
![](../imgs_results/det_res_img_10_sast.jpg)
#### (2). 弯曲文本检测模型(Total-Text)
首先将SAST文本检测训练过程中保存的模型,转换成inference model。以基于Resnet50_vd骨干网络,在Total-Text英文数据集训练的模型为例([模型下载地址](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_sast_totaltext_v2.0_train.tar)),可以使用如下命令进行转换:
```
python3 tools/export_model.py -c configs/det/det_r50_vd_sast_totaltext.yml -o Global.pretrained_model=./det_r50_vd_sast_totaltext_v2.0_train/best_accuracy Global.save_inference_dir=./inference/det_sast_tt
```
SAST文本检测模型推理,需要设置参数`--det_algorithm="SAST"`,同时,还需要增加参数`--det_sast_polygon=True`,可以执行如下命令:
```
python3 tools/infer/predict_det.py --det_algorithm="SAST" --image_dir="./doc/imgs_en/img623.jpg" --det_model_dir="./inference/det_sast_tt/" --det_sast_polygon=True
```
可视化文本检测结果默认保存到`./inference_results`文件夹里面,结果文件的名称前缀为'det_res'。结果示例如下:
![](../imgs_results/det_res_img623_sast.jpg)
**注意**:本代码库中,SAST后处理Locality-Aware NMS有python和c++两种版本,c++版速度明显快于python版。由于c++版本nms编译版本问题,只有python3.5环境下会调用c++版nms,其他情况将调用python版nms。
<a name="4-2"></a>
### 4.2 C++推理
暂未支持
<a name="4-3"></a>
### 4.3 Serving服务化部署
暂未支持
<a name="4-4"></a>
### 4.4 更多推理部署
暂未支持
<a name="5"></a>
## 5. FAQ
## 引用
```bibtex
@inproceedings{wang2019single,
title={A Single-Shot Arbitrarily-Shaped Text Detector based on Context Attended Multi-Task Learning},
author={Wang, Pengfei and Zhang, Chengquan and Qi, Fei and Huang, Zuming and En, Mengyi and Han, Junyu and Liu, Jingtuo and Ding, Errui and Shi, Guangming},
booktitle={Proceedings of the 27th ACM International Conference on Multimedia},
pages={1277--1285},
year={2019}
}
```
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment