"...composable_kernel_onnxruntime.git" did not exist on "e6ee659488f8a6368273d8f664057ec5b8e724f5"
Unverified Commit 45a4aba4 authored by Double_V's avatar Double_V Committed by GitHub
Browse files

Merge branch 'dygraph' into sdmgr

parents 98162be4 033cc4cf
# Enhanced CTC Loss
在OCR识别中, CRNN是一种在工业界广泛使用的文字识别算法。 在训练阶段,其采用CTCLoss来计算网络损失; 在推理阶段,其采用CTCDecode来获得解码结果。虽然CRNN算法在实际业务中被证明能够获得很好的识别效果, 然而用户对识别准确率的要求却是无止境的,如何进一步提升文字识别的准确率呢? 本文以CTCLoss为切人点,分别从难例挖掘、 多任务学习、 Metric Learning 3个不同的角度探索了CTCLoss的改进融合方案,提出了EnhancedCTCLoss,其包括如下3个组成部分: Focal-CTC Loss,A-CTC Loss, C-CTC Loss。
## 1. Focal-CTC Loss
Focal Loss 出自论文《Focal Loss for Dense Object Detection》, 该loss最先提出的时候主要是为了解决one-stage目标检测中正负样本比例严重失衡的问题。该损失函数降低了大量简单负样本在训练中所占的权重,也可理解为一种困难样本挖掘。
其损失函数形式如下:
<div align="center">
<img src="./focal_loss_formula.png" width = "600" />
</div>
其中, y' 是经过激活函数的输出,取值在0-1之间。其在原始的交叉熵损失的基础上加了一个调制系数(1 – y’)^ &gamma;和平衡因子&alpha;。 当&alpha; = 1,y=1时,其损失函数与交叉熵损失的对比如下图所示:
<div align="center">
<img src="./focal_loss_image.png" width = "600" />
</div>
从上图可以看到, 当&gamma;> 0时,调整系数(1-y’)^&gamma; 赋予易分类样本损失一个更小的权重,使得网络更关注于困难的、错分的样本。 调整因子&gamma;用于调节简单样本权重降低的速率,当&gamma;为0时即为交叉熵损失函数,当&gamma;增加时,调整因子的影响也会随之增大。实验发现&gamma;为2是最优。平衡因子&alpha;用来平衡正负样本本身的比例不均,文中&alpha;取0.25。
对于经典的CTC算法,假设某个特征序列(f<sub>1</sub>, f<sub>2</sub>, ......f<sub>t</sub>), 经过CTC解码之后结果等于label的概率为y’, 则CTC解码结果不为label的概率即为(1-y’);不难发现 CTCLoss值和y’有如下关系:
<div align="center">
<img src="./equation_ctcloss.png" width = "250" />
</div>
结合Focal Loss的思想,赋予困难样本较大的权重,简单样本较小的权重,可以使网络更加聚焦于对困难样本的挖掘,进一步提升识别的准确率,由此我们提出了Focal-CTC Loss; 其定义如下所示:
<div align="center">
<img src="./equation_focal_ctc.png" width = "500" />
</div>
实验中,&gamma;取值为2, &alpha;= 1, 具体实现见: [rec_ctc_loss.py](../../ppocr/losses/rec_ctc_loss.py)
## 2. A-CTC Loss
A-CTC Loss是CTC Loss + ACE Loss的简称。 其中ACE Loss出自论文< Aggregation Cross-Entropy for Sequence Recognition>. ACE Loss相比于CTCLoss,主要有如下两点优势:
+ ACE Loss能够解决2-D文本的识别问题; CTCLoss只能够处理1-D文本
+ ACE Loss 在时间复杂度和空间复杂度上优于CTC loss
前人总结的OCR识别算法的优劣如下图所示:
<div align="center">
<img src="./rec_algo_compare.png" width = "1000" />
</div>
虽然ACELoss确实如上图所说,可以处理2D预测,在内存占用及推理速度方面具备优势,但在实践过程中,我们发现单独使用ACE Loss, 识别效果并不如CTCLoss. 因此,我们尝试将CTCLoss和ACELoss进行组合,同时以CTCLoss为主,将ACELoss 定位为一个辅助监督loss。 这一尝试收到了效果,在我们内部的实验数据集上,相比单独使用CTCLoss,识别准确率可以提升1%左右。
A_CTC Loss定义如下:
<div align="center">
<img src="./equation_a_ctc.png" width = "300" />
</div>
实验中,λ = 0.1. ACE loss实现代码见: [ace_loss.py](../../ppocr/losses/ace_loss.py)
## 3. C-CTC Loss
C-CTC Loss是CTC Loss + Center Loss的简称。 其中Center Loss出自论文 < A Discriminative Feature Learning Approach for Deep Face Recognition>. 最早用于人脸识别任务,用于增大累间距离,减小类内距离, 是Metric Learning领域一种较早的、也比较常用的一种算法。
在中文OCR识别任务中,通过对badcase分析, 我们发现中文识别的一大难点是相似字符多,容易误识。 由此我们想到是否可以借鉴Metric Learing的想法, 增大相似字符的类间距,从而提高识别准确率。然而,MetricLearning主要用于图像识别领域,训练数据的标签为一个固定的值;而对于OCR识别来说,其本质上是一个序列识别任务,特征和label之间并不具有显式的对齐关系,因此两者如何结合依然是一个值得探索的方向。
通过尝试Arcmargin, Cosmargin等方法, 我们最终发现Centerloss 有助于进一步提升识别的准确率。C_CTC Loss定义如下:
<div align="center">
<img src="./equation_c_ctc.png" width = "300" />
</div>
实验中,我们设置λ=0.25. center_loss实现代码见: [center_loss.py](../../ppocr/losses/center_loss.py)
值得一提的是, 在C-CTC Loss中,选择随机初始化Center并不能够带来明显的提升. 我们的Center初始化方法如下:
+ 基于原始的CTCLoss, 训练得到一个网络N
+ 挑选出训练集中,识别完全正确的部分, 组成集合G
+ 将G中的每个样本送入网络,进行前向计算, 提取最后一个FC层的输入(即feature)及其经过argmax计算的结果(即index)之间的对应关系
+ 将相同index的feature进行聚合,计算平均值,得到各自字符的初始center.
以配置文件`configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec.yml`为例, center提取命令如下所示:
```
python tools/export_center.py -c configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec.yml -o Global.pretrained_model: "./output/rec_mobile_pp-OCRv2/best_accuracy"
```
运行完后,会在PaddleOCR主目录下生成`train_center.pkl`.
## 4. 实验
对于上述的三种方案,我们基于百度内部数据集进行了训练、评测,实验情况如下表所示:
|algorithm| Focal_CTC | A_CTC | C-CTC |
|:------| :------| ------: | :------: |
|gain| +0.3% | +0.7% | +1.7% |
基于上述实验结论,我们在PP-OCRv2中,采用了C-CTC的策略。 值得一提的是,由于PP-OCRv2 处理的是6625个中文字符的识别任务,字符集比较大,形似字较多,所以在该任务上C-CTC 方案带来的提升较大。 但如果换做其他OCR识别任务,结论可能会有所不同。大家可以尝试Focal-CTC,A-CTC, C-CTC以及组合方案EnhancedCTC,相信会带来不同程度的提升效果。
统一的融合方案见如下文件: [rec_enhanced_ctc_loss.py](../../ppocr/losses/rec_enhanced_ctc_loss.py)
......@@ -38,7 +38,7 @@ class CTCLoss(nn.Layer):
if self.use_focal_loss:
weight = paddle.exp(-loss)
weight = paddle.subtract(paddle.to_tensor([1.0]), weight)
weight = paddle.square(weight) * self.focal_loss_alpha
weight = paddle.square(weight)
loss = paddle.multiply(loss, weight)
loss = loss.mean() # sum
loss = loss.mean()
return {'loss': loss}
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
from paddle import nn
from .ace_loss import ACELoss
from .center_loss import CenterLoss
from .rec_ctc_loss import CTCLoss
class EnhancedCTCLoss(nn.Layer):
def __init__(self,
use_focal_loss=False,
use_ace_loss=False,
ace_loss_weight=0.1,
use_center_loss=False,
center_loss_weight=0.05,
num_classes=6625,
feat_dim=96,
init_center=False,
center_file_path=None,
**kwargs):
super(EnhancedCTCLoss, self).__init__()
self.ctc_loss_func = CTCLoss(use_focal_loss=use_focal_loss)
self.use_ace_loss = False
if use_ace_loss:
self.use_ace_loss = use_ace_loss
self.ace_loss_func = ACELoss()
self.ace_loss_weight = ace_loss_weight
self.use_center_loss = False
if use_center_loss:
self.use_center_loss = use_center_loss
self.center_loss_func = CenterLoss(
num_classes=num_classes,
feat_dim=feat_dim,
init_center=init_center,
center_file_path=center_file_path)
self.center_loss_weight = center_loss_weight
def __call__(self, predicts, batch):
loss = self.ctc_loss_func(predicts, batch)["loss"]
if self.use_center_loss:
center_loss = self.center_loss_func(
predicts, batch)["loss_center"] * self.center_loss_weight
loss = loss + center_loss
if self.use_ace_loss:
ace_loss = self.ace_loss_func(
predicts, batch)["loss_ace"] * self.ace_loss_weight
loss = loss + ace_loss
return {'enhanced_ctc_loss': loss}
......@@ -40,13 +40,13 @@ infer_quant:False
inference:tools/infer/predict_det.py
--use_gpu:True|False
--enable_mkldnn:True|False
--cpu_threads:6
--cpu_threads:1|6
--rec_batch_num:1
--use_tensorrt:False|True
--precision:fp32|fp16|int8
--det_model_dir:
--image_dir:./inference/ch_det_data_50/all-sum-510/
--save_log_path:null
null:null
--benchmark:True
null:null
===========================cpp_infer_params===========================
......@@ -79,4 +79,20 @@ op.det.local_service_conf.thread_num:1|6
op.det.local_service_conf.use_trt:False|True
op.det.local_service_conf.precision:fp32|fp16|int8
pipline:pipeline_http_client.py --image_dir=../../doc/imgs
===========================kl_quant_params===========================
infer_model:./inference/ch_ppocr_mobile_v2.0_det_infer/
infer_export:tools/export_model.py -c configs/det/ch_ppocr_v2.0/ch_det_mv3_db_v2.0.yml -o
infer_quant:False
inference:tools/infer/predict_det.py
--use_gpu:True|False
--enable_mkldnn:True|False
--cpu_threads:1|6
--rec_batch_num:1
--use_tensorrt:False|True
--precision:fp32|fp16|int8
--det_model_dir:
--image_dir:./inference/ch_det_data_50/all-sum-510/
null:null
--benchmark:True
null:null
null:null
\ No newline at end of file
......@@ -12,10 +12,10 @@ train_model_name:latest
train_infer_img_dir:./train_data/icdar2015/text_localization/ch4_test_images/
null:null
##
trainer:norm_train|pact_train
norm_train:tools/train.py -c tests/configs/det_r50_vd_db.yml -o Global.pretrained_model=""
pact_train:null
fpgm_train:null
trainer:norm_train|pact_train|fpgm_export
norm_train:tools/train.py -c tests/configs/det_r50_vd_db.yml -o
quant_export:deploy/slim/quantization/export_model.py -c tests/configs/det_r50_vd_db.yml -o
fpgm_export:deploy/slim/prune/export_prune_model.py -c tests/configs/det_r50_vd_db.yml -o
distill_train:null
null:null
null:null
......@@ -34,8 +34,8 @@ distill_export:null
export1:null
export2:null
##
infer_model:./inference/ch_ppocr_server_v2.0_det_infer/
infer_export:null
train_model:./inference/ch_ppocr_server_v2.0_det_train/best_accuracy
infer_export:tools/export_model.py -c configs/det/ch_ppocr_v2.0/ch_det_res18_db_v2.0.yml -o
infer_quant:False
inference:tools/infer/predict_det.py
--use_gpu:True|False
......
===========================train_params===========================
model_name:ocr_system
python:python3.7
gpu_list:null
Global.use_gpu:null
Global.auto_cast:null
Global.epoch_num:null
Global.save_model_dir:./output/
Train.loader.batch_size_per_card:null
Global.pretrained_model:null
train_model_name:null
train_infer_img_dir:null
null:null
##
trainer:
norm_train:null
pact_train:null
fpgm_train:null
distill_train:null
null:null
null:null
##
===========================eval_params===========================
eval:null
null:null
##
===========================infer_params===========================
Global.save_inference_dir:./output/
Global.pretrained_model:
norm_export:null
quant_export:null
fpgm_export:null
distill_export:null
export1:null
export2:null
##
infer_model:./inference/ch_ppocr_mobile_v2.0_det_infer/
kl_quant:deploy/slim/quantization/quant_kl.py -c configs/det/ch_ppocr_v2.0/ch_det_mv3_db_v2.0.yml -o
infer_quant:True
inference:tools/infer/predict_det.py
--use_gpu:TrueFalse
--enable_mkldnn:True|False
--cpu_threads:1|6
--rec_batch_num:1
--use_tensorrt:False|True
--precision:fp32|fp16|int8
--det_model_dir:
--image_dir:./inference/ch_det_data_50/all-sum-510/
--save_log_path:null
--benchmark:True
null:null
#!/bin/bash
FILENAME=$1
# MODE be one of ['lite_train_infer' 'whole_infer' 'whole_train_infer', 'infer', 'cpp_infer', 'serving_infer']
# MODE be one of ['lite_train_infer' 'whole_infer' 'whole_train_infer', 'infer',
# 'cpp_infer', 'serving_infer', 'klquant_infer']
MODE=$2
dataline=$(cat ${FILENAME})
......@@ -72,9 +74,9 @@ elif [ ${MODE} = "infer" ];then
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_train.tar
cd ./inference && tar xf ${eval_model_name}.tar && tar xf ch_det_data_50.tar && cd ../
elif [ ${model_name} = "ocr_server_det" ]; then
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_det_infer.tar
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_det_train.tar
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/ch_det_data_50.tar
cd ./inference && tar xf ch_ppocr_server_v2.0_det_infer.tar && tar xf ch_det_data_50.tar && cd ../
cd ./inference && tar xf ch_ppocr_server_v2.0_det_train.tar && tar xf ch_det_data_50.tar && cd ../
elif [ ${model_name} = "ocr_system_mobile" ]; then
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/ch_det_data_50.tar
......@@ -98,6 +100,12 @@ elif [ ${MODE} = "infer" ];then
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_rec_infer.tar
cd ./inference && tar xf ${eval_model_name}.tar && tar xf rec_inference.tar && cd ../
fi
elif [ ${MODE} = "klquant_infer" ];then
if [ ${model_name} = "ocr_det" ]; then
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/ch_det_data_50.tar
cd ./inference && tar xf ch_ppocr_mobile_v2.0_det_infer.tar && tar xf ch_det_data_50.tar && cd ../
fi
elif [ ${MODE} = "cpp_infer" ];then
if [ ${model_name} = "ocr_det" ]; then
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/ch_det_data_50.tar
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment