Commit d55e9065 authored by WenmuZhou's avatar WenmuZhou
Browse files

Merge branch 'dygraph' of https://github.com/PaddlePaddle/PaddleOCR into doc

parents 0d6a4862 2b6c887a
Global:
debug: false
use_gpu: true
epoch_num: 100
log_smooth_window: 20
print_batch_step: 10
save_model_dir: ./output/rec_ppocr_v3_rotnet
save_epoch_step: 3
eval_batch_step: [0, 2000]
cal_metric_during_train: true
pretrained_model: null
checkpoints: null
save_inference_dir: null
use_visualdl: false
infer_img: doc/imgs_words/ch/word_1.jpg
character_dict_path: ppocr/utils/ppocr_keys_v1.txt
max_text_length: 25
infer_mode: false
use_space_char: true
save_res_path: ./output/rec/predicts_chinese_lite_v2.0.txt
Optimizer:
name: Adam
beta1: 0.9
beta2: 0.999
lr:
name: Cosine
learning_rate: 0.001
regularizer:
name: L2
factor: 1.0e-05
Architecture:
model_type: cls
algorithm: CLS
Transform: null
Backbone:
name: MobileNetV1Enhance
scale: 0.5
last_conv_stride: [1, 2]
last_pool_type: avg
Neck:
Head:
name: ClsHead
class_dim: 4
Loss:
name: ClsLoss
main_indicator: acc
PostProcess:
name: ClsPostProcess
Metric:
name: ClsMetric
main_indicator: acc
Train:
dataset:
name: SimpleDataSet
data_dir: ./train_data
label_file_list:
- ./train_data/train_list.txt
transforms:
- DecodeImage:
img_mode: BGR
channel_first: false
- RecAug:
use_tia: False
- RandAugment:
- SSLRotateResize:
image_shape: [3, 48, 320]
- KeepKeys:
keep_keys: ["image", "label"]
loader:
collate_fn: "SSLRotateCollate"
shuffle: true
batch_size_per_card: 32
drop_last: true
num_workers: 8
Eval:
dataset:
name: SimpleDataSet
data_dir: ./train_data
label_file_list:
- ./train_data/val_list.txt
transforms:
- DecodeImage:
img_mode: BGR
channel_first: false
- SSLRotateResize:
image_shape: [3, 48, 320]
- KeepKeys:
keep_keys: ["image", "label"]
loader:
collate_fn: "SSLRotateCollate"
shuffle: false
drop_last: false
batch_size_per_card: 64
num_workers: 8
profiler_options: null
Global:
use_gpu: True
epoch_num: 20
log_smooth_window: 20
print_batch_step: 10
save_model_dir: ./output/rec/svtr/
save_epoch_step: 1
# evaluation is run every 2000 iterations after the 0th iteration
eval_batch_step: [0, 2000]
cal_metric_during_train: True
pretrained_model:
checkpoints:
save_inference_dir:
use_visualdl: False
infer_img: doc/imgs_words_en/word_10.png
# for data or label process
character_dict_path:
character_type: en
max_text_length: 25
infer_mode: False
use_space_char: False
save_res_path: ./output/rec/predicts_svtr_tiny.txt
Optimizer:
name: AdamW
beta1: 0.9
beta2: 0.99
epsilon: 0.00000008
weight_decay: 0.05
no_weight_decay_name: norm pos_embed
one_dim_param_no_weight_decay: true
lr:
name: Cosine
learning_rate: 0.0005
warmup_epoch: 2
Architecture:
model_type: rec
algorithm: SVTR
Transform:
name: STN_ON
tps_inputsize: [32, 64]
tps_outputsize: [32, 100]
num_control_points: 20
tps_margins: [0.05,0.05]
stn_activation: none
Backbone:
name: SVTRNet
img_size: [32, 100]
out_char_num: 25
out_channels: 192
patch_merging: 'Conv'
embed_dim: [64, 128, 256]
depth: [3, 6, 3]
num_heads: [2, 4, 8]
mixer: ['Local','Local','Local','Local','Local','Local','Global','Global','Global','Global','Global','Global']
local_mixer: [[7, 11], [7, 11], [7, 11]]
last_stage: True
prenorm: false
Neck:
name: SequenceEncoder
encoder_type: reshape
Head:
name: CTCHead
Loss:
name: CTCLoss
PostProcess:
name: CTCLabelDecode
Metric:
name: RecMetric
main_indicator: acc
Train:
dataset:
name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/training/
transforms:
- DecodeImage: # load image
img_mode: BGR
channel_first: False
- CTCLabelEncode: # Class handling label
- RecResizeImg:
character_dict_path:
image_shape: [3, 64, 256]
padding: False
- KeepKeys:
keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
loader:
shuffle: True
batch_size_per_card: 512
drop_last: True
num_workers: 4
Eval:
dataset:
name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/validation/
transforms:
- DecodeImage: # load image
img_mode: BGR
channel_first: False
- CTCLabelEncode: # Class handling label
- RecResizeImg:
character_dict_path:
image_shape: [3, 64, 256]
padding: False
- KeepKeys:
keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
loader:
shuffle: False
drop_last: False
batch_size_per_card: 256
num_workers: 2
...@@ -22,9 +22,11 @@ PP-OCR has supported muti deployment schemes. Click the link to get the specific ...@@ -22,9 +22,11 @@ PP-OCR has supported muti deployment schemes. Click the link to get the specific
- [Python Inference](../doc/doc_en/inference_ppocr_en.md) - [Python Inference](../doc/doc_en/inference_ppocr_en.md)
- [C++ Inference](./cpp_infer/readme.md) - [C++ Inference](./cpp_infer/readme.md)
- [Serving](./pdserving/README.md) - [Serving (Python/C++)](./pdserving/README.md)
- [Paddle-Lite](./lite/readme.md) - [Paddle-Lite (ARM CPU/OpenCL ARM GPU/Metal ARM GPU)](./lite/readme.md)
- [Paddle.js](./paddlejs/README.md) - [Paddle.js](./paddlejs/README.md)
- [Jetson Inference]()
- [XPU Inference]()
- [Paddle2ONNX](./paddle2onnx/readme.md) - [Paddle2ONNX](./paddle2onnx/readme.md)
If you need the deployment tutorial of academic algorithm models other than PP-OCR, please directly enter the main page of corresponding algorithms, [entrance](../doc/doc_en/algorithm_overview_en.md) If you need the deployment tutorial of academic algorithm models other than PP-OCR, please directly enter the main page of corresponding algorithms, [entrance](../doc/doc_en/algorithm_overview_en.md)
\ No newline at end of file
...@@ -22,9 +22,11 @@ PP-OCR模型已打通多种场景部署方案,点击链接获取具体的使 ...@@ -22,9 +22,11 @@ PP-OCR模型已打通多种场景部署方案,点击链接获取具体的使
- [Python 推理](../doc/doc_ch/inference_ppocr.md) - [Python 推理](../doc/doc_ch/inference_ppocr.md)
- [C++ 推理](./cpp_infer/readme_ch.md) - [C++ 推理](./cpp_infer/readme_ch.md)
- [Serving 服务化部署](./pdserving/README_CN.md) - [Serving 服务化部署(Python/C++)](./pdserving/README_CN.md)
- [Paddle-Lite 端侧部署](./lite/readme_ch.md) - [Paddle-Lite 端侧部署(ARM CPU/OpenCL ARM GPU/Metal ARM GPU)](./lite/readme_ch.md)
- [Paddle.js 服务化部署](./paddlejs/README_ch.md) - [Paddle.js 部署](./paddlejs/README_ch.md)
- [Jetson 推理]()
- [XPU 推理]()
- [Paddle2ONNX 推理](./paddle2onnx/readme_ch.md) - [Paddle2ONNX 推理](./paddle2onnx/readme_ch.md)
需要PP-OCR以外的学术算法模型的推理部署,请直接进入相应算法主页面,[入口](../doc/doc_ch/algorithm_overview.md) 需要PP-OCR以外的学术算法模型的推理部署,请直接进入相应算法主页面,[入口](../doc/doc_ch/algorithm_overview.md)
\ No newline at end of file
...@@ -35,17 +35,7 @@ from ppocr.metrics import build_metric ...@@ -35,17 +35,7 @@ from ppocr.metrics import build_metric
import tools.program as program import tools.program as program
from paddleslim.dygraph.quant import QAT from paddleslim.dygraph.quant import QAT
from ppocr.data import build_dataloader from ppocr.data import build_dataloader
from tools.export_model import export_single_model
def export_single_model(quanter, model, infer_shape, save_path, logger):
quanter.save_quantized_model(
model,
save_path,
input_spec=[
paddle.static.InputSpec(
shape=[None] + infer_shape, dtype='float32')
])
logger.info('inference QAT model is saved to {}'.format(save_path))
def main(): def main():
...@@ -84,17 +74,54 @@ def main(): ...@@ -84,17 +74,54 @@ def main():
config['Global']) config['Global'])
# build model # build model
# for rec algorithm
if hasattr(post_process_class, 'character'): if hasattr(post_process_class, 'character'):
char_num = len(getattr(post_process_class, 'character')) char_num = len(getattr(post_process_class, 'character'))
if config['Architecture']["algorithm"] in ["Distillation", if config['Architecture']["algorithm"] in ["Distillation",
]: # distillation model ]: # distillation model
for key in config['Architecture']["Models"]: for key in config['Architecture']["Models"]:
config['Architecture']["Models"][key]["Head"][ if config['Architecture']['Models'][key]['Head'][
'out_channels'] = char_num 'name'] == 'MultiHead': # for multi head
if config['PostProcess'][
'name'] == 'DistillationSARLabelDecode':
char_num = char_num - 2
# update SARLoss params
assert list(config['Loss']['loss_config_list'][-1].keys())[
0] == 'DistillationSARLoss'
config['Loss']['loss_config_list'][-1][
'DistillationSARLoss']['ignore_index'] = char_num + 1
out_channels_list = {}
out_channels_list['CTCLabelDecode'] = char_num
out_channels_list['SARLabelDecode'] = char_num + 2
config['Architecture']['Models'][key]['Head'][
'out_channels_list'] = out_channels_list
else:
config['Architecture']["Models"][key]["Head"][
'out_channels'] = char_num
elif config['Architecture']['Head'][
'name'] == 'MultiHead': # for multi head
if config['PostProcess']['name'] == 'SARLabelDecode':
char_num = char_num - 2
# update SARLoss params
assert list(config['Loss']['loss_config_list'][1].keys())[
0] == 'SARLoss'
if config['Loss']['loss_config_list'][1]['SARLoss'] is None:
config['Loss']['loss_config_list'][1]['SARLoss'] = {
'ignore_index': char_num + 1
}
else:
config['Loss']['loss_config_list'][1]['SARLoss'][
'ignore_index'] = char_num + 1
out_channels_list = {}
out_channels_list['CTCLabelDecode'] = char_num
out_channels_list['SARLabelDecode'] = char_num + 2
config['Architecture']['Head'][
'out_channels_list'] = out_channels_list
else: # base rec model else: # base rec model
config['Architecture']["Head"]['out_channels'] = char_num config['Architecture']["Head"]['out_channels'] = char_num
if config['PostProcess']['name'] == 'SARLabelDecode': # for SAR model
config['Loss']['ignore_index'] = char_num - 1
model = build_model(config['Architecture']) model = build_model(config['Architecture'])
# get QAT model # get QAT model
...@@ -120,21 +147,22 @@ def main(): ...@@ -120,21 +147,22 @@ def main():
for k, v in metric.items(): for k, v in metric.items():
logger.info('{}:{}'.format(k, v)) logger.info('{}:{}'.format(k, v))
infer_shape = [3, 32, 100] if model_type == "rec" else [3, 640, 640]
save_path = config["Global"]["save_inference_dir"] save_path = config["Global"]["save_inference_dir"]
arch_config = config["Architecture"] arch_config = config["Architecture"]
arch_config = config["Architecture"]
if arch_config["algorithm"] in ["Distillation", ]: # distillation model if arch_config["algorithm"] in ["Distillation", ]: # distillation model
archs = list(arch_config["Models"].values())
for idx, name in enumerate(model.model_name_list): for idx, name in enumerate(model.model_name_list):
model.model_list[idx].eval() model.model_list[idx].eval()
sub_model_save_path = os.path.join(save_path, name, "inference") sub_model_save_path = os.path.join(save_path, name, "inference")
export_single_model(quanter, model.model_list[idx], infer_shape, export_single_model(model.model_list[idx], archs[idx],
sub_model_save_path, logger) sub_model_save_path, logger, quanter)
else: else:
save_path = os.path.join(save_path, "inference") save_path = os.path.join(save_path, "inference")
model.eval() export_single_model(model, arch_config, save_path, logger, quanter)
export_single_model(quanter, model, infer_shape, save_path, logger)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -112,10 +112,48 @@ def main(config, device, logger, vdl_writer): ...@@ -112,10 +112,48 @@ def main(config, device, logger, vdl_writer):
if config['Architecture']["algorithm"] in ["Distillation", if config['Architecture']["algorithm"] in ["Distillation",
]: # distillation model ]: # distillation model
for key in config['Architecture']["Models"]: for key in config['Architecture']["Models"]:
config['Architecture']["Models"][key]["Head"][ if config['Architecture']['Models'][key]['Head'][
'out_channels'] = char_num 'name'] == 'MultiHead': # for multi head
if config['PostProcess'][
'name'] == 'DistillationSARLabelDecode':
char_num = char_num - 2
# update SARLoss params
assert list(config['Loss']['loss_config_list'][-1].keys())[
0] == 'DistillationSARLoss'
config['Loss']['loss_config_list'][-1][
'DistillationSARLoss']['ignore_index'] = char_num + 1
out_channels_list = {}
out_channels_list['CTCLabelDecode'] = char_num
out_channels_list['SARLabelDecode'] = char_num + 2
config['Architecture']['Models'][key]['Head'][
'out_channels_list'] = out_channels_list
else:
config['Architecture']["Models"][key]["Head"][
'out_channels'] = char_num
elif config['Architecture']['Head'][
'name'] == 'MultiHead': # for multi head
if config['PostProcess']['name'] == 'SARLabelDecode':
char_num = char_num - 2
# update SARLoss params
assert list(config['Loss']['loss_config_list'][1].keys())[
0] == 'SARLoss'
if config['Loss']['loss_config_list'][1]['SARLoss'] is None:
config['Loss']['loss_config_list'][1]['SARLoss'] = {
'ignore_index': char_num + 1
}
else:
config['Loss']['loss_config_list'][1]['SARLoss'][
'ignore_index'] = char_num + 1
out_channels_list = {}
out_channels_list['CTCLabelDecode'] = char_num
out_channels_list['SARLabelDecode'] = char_num + 2
config['Architecture']['Head'][
'out_channels_list'] = out_channels_list
else: # base rec model else: # base rec model
config['Architecture']["Head"]['out_channels'] = char_num config['Architecture']["Head"]['out_channels'] = char_num
if config['PostProcess']['name'] == 'SARLabelDecode': # for SAR model
config['Loss']['ignore_index'] = char_num - 1
model = build_model(config['Architecture']) model = build_model(config['Architecture'])
pre_best_model_dict = dict() pre_best_model_dict = dict()
...@@ -137,7 +175,7 @@ def main(config, device, logger, vdl_writer): ...@@ -137,7 +175,7 @@ def main(config, device, logger, vdl_writer):
config['Optimizer'], config['Optimizer'],
epochs=config['Global']['epoch_num'], epochs=config['Global']['epoch_num'],
step_each_epoch=len(train_dataloader), step_each_epoch=len(train_dataloader),
parameters=model.parameters()) model=model)
# resume PACT training process # resume PACT training process
if config["Global"]["checkpoints"] is not None: if config["Global"]["checkpoints"] is not None:
......
...@@ -25,8 +25,8 @@ ...@@ -25,8 +25,8 @@
|模型|骨干网络|配置文件|precision|recall|Hmean|下载链接| |模型|骨干网络|配置文件|precision|recall|Hmean|下载链接|
| --- | --- | --- | --- | --- | --- | --- | | --- | --- | --- | --- | --- | --- | --- |
|DB|ResNet50_vd|configs/det/det_r50_vd_db.yml|86.41%|78.72%|82.38%|[训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_db_v2.0_train.tar)| |DB|ResNet50_vd|[configs/det/det_r50_vd_db.yml](../../configs/det/det_r50_vd_db.yml)|86.41%|78.72%|82.38%|[训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_db_v2.0_train.tar)|
|DB|MobileNetV3|configs/det/det_mv3_db.yml|77.29%|73.08%|75.12%|[训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_mv3_db_v2.0_train.tar)| |DB|MobileNetV3|[configs/det/det_mv3_db.yml](../../configs/det/det_mv3_db.yml)|77.29%|73.08%|75.12%|[训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_mv3_db_v2.0_train.tar)|
<a name="2"></a> <a name="2"></a>
......
...@@ -27,7 +27,7 @@ ...@@ -27,7 +27,7 @@
#### 2、CDLA数据集 #### 2、CDLA数据集
- **数据来源**:https://github.com/buptlihang/CDLA - **数据来源**:https://github.com/buptlihang/CDLA
- **数据简介**publaynet数据集的训练集合中包含5000张图像,验证集合中包含1000张图像。总共包含10个类别,分别是: `Text, Title, Figure, Figure caption, Table, Table caption, Header, Footer, Reference, Equation`。部分图像以及标注框可视化如下所示。 - **数据简介**CDLA据集的训练集合中包含5000张图像,验证集合中包含1000张图像。总共包含10个类别,分别是: `Text, Title, Figure, Figure caption, Table, Table caption, Header, Footer, Reference, Equation`。部分图像以及标注框可视化如下所示。
<div align="center"> <div align="center">
<img src="../datasets/CDLA_demo/val_0633.jpg" width="500"> <img src="../datasets/CDLA_demo/val_0633.jpg" width="500">
......
...@@ -17,6 +17,8 @@ ...@@ -17,6 +17,8 @@
PP-OCR是PaddleOCR自研的实用的超轻量OCR系统。在实现[前沿算法](algorithm.md)的基础上,考虑精度与速度的平衡,进行**模型瘦身****深度优化**,使其尽可能满足产业落地需求。 PP-OCR是PaddleOCR自研的实用的超轻量OCR系统。在实现[前沿算法](algorithm.md)的基础上,考虑精度与速度的平衡,进行**模型瘦身****深度优化**,使其尽可能满足产业落地需求。
#### PP-OCR
PP-OCR是一个两阶段的OCR系统,其中文本检测算法选用[DB](algorithm_det_db.md),文本识别算法选用[CRNN](algorithm_rec_crnn.md),并在检测和识别模块之间添加[文本方向分类器](angle_class.md),以应对不同方向的文本识别。 PP-OCR是一个两阶段的OCR系统,其中文本检测算法选用[DB](algorithm_det_db.md),文本识别算法选用[CRNN](algorithm_rec_crnn.md),并在检测和识别模块之间添加[文本方向分类器](angle_class.md),以应对不同方向的文本识别。
PP-OCR系统pipeline如下: PP-OCR系统pipeline如下:
...@@ -28,9 +30,13 @@ PP-OCR系统pipeline如下: ...@@ -28,9 +30,13 @@ PP-OCR系统pipeline如下:
PP-OCR系统在持续迭代优化,目前已发布PP-OCR和PP-OCRv2两个版本: PP-OCR系统在持续迭代优化,目前已发布PP-OCR和PP-OCRv2两个版本:
[1] PP-OCR从骨干网络选择和调整、预测头部的设计、数据增强、学习率变换策略、正则化参数选择、预训练模型使用以及模型自动裁剪量化8个方面,采用19个有效策略,对各个模块的模型进行效果调优和瘦身(如绿框所示),最终得到整体大小为3.5M的超轻量中英文OCR和2.8M的英文数字OCR。更多细节请参考PP-OCR技术方案 https://arxiv.org/abs/2009.09941 PP-OCR从骨干网络选择和调整、预测头部的设计、数据增强、学习率变换策略、正则化参数选择、预训练模型使用以及模型自动裁剪量化8个方面,采用19个有效策略,对各个模块的模型进行效果调优和瘦身(如绿框所示),最终得到整体大小为3.5M的超轻量中英文OCR和2.8M的英文数字OCR。更多细节请参考PP-OCR技术方案 https://arxiv.org/abs/2009.09941
#### PP-OCRv2
PP-OCRv2在PP-OCR的基础上,进一步在5个方面重点优化,检测模型采用CML协同互学习知识蒸馏策略和CopyPaste数据增广策略;识别模型采用LCNet轻量级骨干网络、UDML 改进知识蒸馏策略和[Enhanced CTC loss](./doc/doc_ch/enhanced_ctc_loss.md)损失函数改进(如上图红框所示),进一步在推理速度和预测效果上取得明显提升。更多细节请参考PP-OCRv2[技术报告](https://arxiv.org/abs/2109.03144)
[2] PP-OCRv2在PP-OCR的基础上,进一步在5个方面重点优化,检测模型采用CML协同互学习知识蒸馏策略和CopyPaste数据增广策略;识别模型采用LCNet轻量级骨干网络、UDML 改进知识蒸馏策略和[Enhanced CTC loss](./doc/doc_ch/enhanced_ctc_loss.md)损失函数改进(如上图红框所示),进一步在推理速度和预测效果上取得明显提升。更多细节请参考PP-OCRv2[技术报告](https://arxiv.org/abs/2109.03144) #### PP-OCRv3
<a name="2"></a> <a name="2"></a>
......
...@@ -25,8 +25,8 @@ On the ICDAR2015 dataset, the text detection result is as follows: ...@@ -25,8 +25,8 @@ On the ICDAR2015 dataset, the text detection result is as follows:
|Model|Backbone|Configuration|Precision|Recall|Hmean|Download| |Model|Backbone|Configuration|Precision|Recall|Hmean|Download|
| --- | --- | --- | --- | --- | --- | --- | | --- | --- | --- | --- | --- | --- | --- |
|DB|ResNet50_vd|configs/det/det_r50_vd_db.yml|86.41%|78.72%|82.38%|[trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_db_v2.0_train.tar)| |DB|ResNet50_vd|[configs/det/det_r50_vd_db.yml](../../configs/det/det_r50_vd_db.yml)|86.41%|78.72%|82.38%|[trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_db_v2.0_train.tar)|
|DB|MobileNetV3|configs/det/det_mv3_db.yml|77.29%|73.08%|75.12%|[trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_mv3_db_v2.0_train.tar)| |DB|MobileNetV3|[configs/det/det_mv3_db.yml](../../configs/det/det_mv3_db.yml)|77.29%|73.08%|75.12%|[trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_mv3_db_v2.0_train.tar)|
<a name="2"></a> <a name="2"></a>
......
doc/joinus.PNG

201 KB | W: | H:

doc/joinus.PNG

200 KB | W: | H:

doc/joinus.PNG
doc/joinus.PNG
doc/joinus.PNG
doc/joinus.PNG
  • 2-up
  • Swipe
  • Onion skin
...@@ -72,6 +72,7 @@ def build_dataloader(config, mode, device, logger, seed=None): ...@@ -72,6 +72,7 @@ def build_dataloader(config, mode, device, logger, seed=None):
use_shared_memory = loader_config['use_shared_memory'] use_shared_memory = loader_config['use_shared_memory']
else: else:
use_shared_memory = True use_shared_memory = True
if mode == "Train": if mode == "Train":
# Distribute data to multiple cards # Distribute data to multiple cards
batch_sampler = DistributedBatchSampler( batch_sampler = DistributedBatchSampler(
......
...@@ -56,3 +56,17 @@ class ListCollator(object): ...@@ -56,3 +56,17 @@ class ListCollator(object):
for idx in to_tensor_idxs: for idx in to_tensor_idxs:
data_dict[idx] = paddle.to_tensor(data_dict[idx]) data_dict[idx] = paddle.to_tensor(data_dict[idx])
return list(data_dict.values()) return list(data_dict.values())
class SSLRotateCollate(object):
"""
bach: [
[(4*3xH*W), (4,)]
[(4*3xH*W), (4,)]
...
]
"""
def __call__(self, batch):
output = [np.concatenate(d, axis=0) for d in zip(*batch)]
return output
...@@ -23,7 +23,8 @@ from .random_crop_data import EastRandomCropData, RandomCropImgMask ...@@ -23,7 +23,8 @@ from .random_crop_data import EastRandomCropData, RandomCropImgMask
from .make_pse_gt import MakePseGt from .make_pse_gt import MakePseGt
from .rec_img_aug import RecAug, RecConAug, RecResizeImg, ClsResizeImg, \ from .rec_img_aug import RecAug, RecConAug, RecResizeImg, ClsResizeImg, \
SRNRecResizeImg, NRTRRecResizeImg, SARRecResizeImg, PRENResizeImg SRNRecResizeImg, NRTRRecResizeImg, SARRecResizeImg, PRENResizeImg, SVTRRecResizeImg
from .ssl_img_aug import SSLRotateResize
from .randaugment import RandAugment from .randaugment import RandAugment
from .copy_paste import CopyPaste from .copy_paste import CopyPaste
from .ColorJitter import ColorJitter from .ColorJitter import ColorJitter
......
...@@ -113,14 +113,14 @@ class BaseRecLabelEncode(object): ...@@ -113,14 +113,14 @@ class BaseRecLabelEncode(object):
dict_character = list(self.character_str) dict_character = list(self.character_str)
self.lower = True self.lower = True
else: else:
self.character_str = "" self.character_str = []
with open(character_dict_path, "rb") as fin: with open(character_dict_path, "rb") as fin:
lines = fin.readlines() lines = fin.readlines()
for line in lines: for line in lines:
line = line.decode('utf-8').strip("\n").strip("\r\n") line = line.decode('utf-8').strip("\n").strip("\r\n")
self.character_str += line self.character_str.append(line)
if use_space_char: if use_space_char:
self.character_str += " " self.character_str.append(" ")
dict_character = list(self.character_str) dict_character = list(self.character_str)
dict_character = self.add_special_char(dict_character) dict_character = self.add_special_char(dict_character)
self.dict = {} self.dict = {}
......
...@@ -16,6 +16,7 @@ import math ...@@ -16,6 +16,7 @@ import math
import cv2 import cv2
import numpy as np import numpy as np
import random import random
import copy
from PIL import Image from PIL import Image
from .text_image_aug import tia_perspective, tia_stretch, tia_distort from .text_image_aug import tia_perspective, tia_stretch, tia_distort
...@@ -81,7 +82,7 @@ class ClsResizeImg(object): ...@@ -81,7 +82,7 @@ class ClsResizeImg(object):
def __call__(self, data): def __call__(self, data):
img = data['image'] img = data['image']
norm_img = resize_norm_img(img, self.image_shape) norm_img, _ = resize_norm_img(img, self.image_shape)
data['image'] = norm_img data['image'] = norm_img
return data return data
...@@ -206,6 +207,25 @@ class PRENResizeImg(object): ...@@ -206,6 +207,25 @@ class PRENResizeImg(object):
return data return data
class SVTRRecResizeImg(object):
def __init__(self,
image_shape,
infer_mode=False,
character_dict_path='./ppocr/utils/ppocr_keys_v1.txt',
padding=True,
**kwargs):
self.image_shape = image_shape
self.infer_mode = infer_mode
self.character_dict_path = character_dict_path
self.padding = padding
def __call__(self, data):
img = data['image']
norm_img = resize_norm_img_svtr(img, self.image_shape, self.padding)
data['image'] = norm_img
return data
def resize_norm_img_sar(img, image_shape, width_downsample_ratio=0.25): def resize_norm_img_sar(img, image_shape, width_downsample_ratio=0.25):
imgC, imgH, imgW_min, imgW_max = image_shape imgC, imgH, imgW_min, imgW_max = image_shape
h = img.shape[0] h = img.shape[0]
...@@ -324,6 +344,58 @@ def resize_norm_img_srn(img, image_shape): ...@@ -324,6 +344,58 @@ def resize_norm_img_srn(img, image_shape):
return np.reshape(img_black, (c, row, col)).astype(np.float32) return np.reshape(img_black, (c, row, col)).astype(np.float32)
def resize_norm_img_svtr(img, image_shape, padding=False):
imgC, imgH, imgW = image_shape
h = img.shape[0]
w = img.shape[1]
if not padding:
if h > 2.0 * w:
image = Image.fromarray(img)
image1 = image.rotate(90, expand=True)
image2 = image.rotate(-90, expand=True)
img1 = np.array(image1)
img2 = np.array(image2)
else:
img1 = copy.deepcopy(img)
img2 = copy.deepcopy(img)
resized_image = cv2.resize(
img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
resized_image1 = cv2.resize(
img1, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
resized_image2 = cv2.resize(
img2, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
resized_w = imgW
else:
ratio = w / float(h)
if math.ceil(imgH * ratio) > imgW:
resized_w = imgW
else:
resized_w = int(math.ceil(imgH * ratio))
resized_image = cv2.resize(img, (resized_w, imgH))
resized_image = resized_image.astype('float32')
resized_image1 = resized_image1.astype('float32')
resized_image2 = resized_image2.astype('float32')
if image_shape[0] == 1:
resized_image = resized_image / 255
resized_image = resized_image[np.newaxis, :]
else:
resized_image = resized_image.transpose((2, 0, 1)) / 255
resized_image1 = resized_image1.transpose((2, 0, 1)) / 255
resized_image2 = resized_image2.transpose((2, 0, 1)) / 255
resized_image -= 0.5
resized_image /= 0.5
resized_image1 -= 0.5
resized_image1 /= 0.5
resized_image2 -= 0.5
resized_image2 /= 0.5
padding_im = np.zeros((3, imgC, imgH, imgW), dtype=np.float32)
padding_im[0, :, :, 0:resized_w] = resized_image
padding_im[1, :, :, 0:resized_w] = resized_image1
padding_im[2, :, :, 0:resized_w] = resized_image2
return padding_im
def srn_other_inputs(image_shape, num_heads, max_text_length): def srn_other_inputs(image_shape, num_heads, max_text_length):
imgC, imgH, imgW = image_shape imgC, imgH, imgW = image_shape
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import cv2
import numpy as np
import random
from PIL import Image
from .rec_img_aug import resize_norm_img
class SSLRotateResize(object):
def __init__(self,
image_shape,
padding=False,
select_all=True,
mode="train",
**kwargs):
self.image_shape = image_shape
self.padding = padding
self.select_all = select_all
self.mode = mode
def __call__(self, data):
img = data["image"]
data["image_r90"] = cv2.rotate(img, cv2.ROTATE_90_CLOCKWISE)
data["image_r180"] = cv2.rotate(data["image_r90"],
cv2.ROTATE_90_CLOCKWISE)
data["image_r270"] = cv2.rotate(data["image_r180"],
cv2.ROTATE_90_CLOCKWISE)
images = []
for key in ["image", "image_r90", "image_r180", "image_r270"]:
images.append(
resize_norm_img(
data.pop(key),
image_shape=self.image_shape,
padding=self.padding)[0])
data["image"] = np.stack(images, axis=0)
data["label"] = np.array(list(range(4)))
if not self.select_all:
data["image"] = data["image"][0::2] # just choose 0 and 180
data["label"] = data["label"][0:2] # label needs to be continuous
if self.mode == "test":
data["image"] = data["image"][0]
data["label"] = data["label"][0]
return data
...@@ -296,47 +296,49 @@ class PatchEmbed(nn.Layer): ...@@ -296,47 +296,49 @@ class PatchEmbed(nn.Layer):
if sub_num == 2: if sub_num == 2:
self.proj = nn.Sequential( self.proj = nn.Sequential(
ConvBNLayer( ConvBNLayer(
in_channels, in_channels=in_channels,
embed_dim // 2, out_channels=embed_dim // 2,
3, kernel_size=3,
2, stride=2,
1, padding=1,
act=nn.GELU, act=nn.GELU,
bias_attr=None), bias_attr=None),
ConvBNLayer( ConvBNLayer(
embed_dim // 2, in_channels=embed_dim // 2,
embed_dim, out_channels=embed_dim,
3, kernel_size=3,
2, stride=2,
1, padding=1,
act=nn.GELU, act=nn.GELU,
bias_attr=None)) bias_attr=None))
if sub_num == 3: if sub_num == 3:
self.proj = nn.Sequential( self.proj = nn.Sequential(
ConvBNLayer( ConvBNLayer(
in_channels, in_channels=in_channels,
embed_dim // 4, out_channels=embed_dim // 4,
3, kernel_size=3,
2, stride=2,
1, padding=1,
act=nn.GELU, act=nn.GELU,
bias_attr=None), bias_attr=None),
ConvBNLayer( ConvBNLayer(
embed_dim // 4, in_channels=embed_dim // 4,
embed_dim // 2, out_channels=embed_dim // 2,
3, kernel_size=3,
2, stride=2,
1, padding=1,
act=nn.GELU, act=nn.GELU,
bias_attr=None), bias_attr=None),
ConvBNLayer( ConvBNLayer(
embed_dim // 2, embed_dim // 2,
embed_dim, embed_dim,
3, in_channels=embed_dim // 2,
2, out_channels=embed_dim,
1, kernel_size=3,
stride=2,
padding=1,
act=nn.GELU, act=nn.GELU,
bias_attr=None), ) bias_attr=None))
def forward(self, x): def forward(self, x):
B, C, H, W = x.shape B, C, H, W = x.shape
...@@ -455,7 +457,7 @@ class SVTRNet(nn.Layer): ...@@ -455,7 +457,7 @@ class SVTRNet(nn.Layer):
qkv_bias=qkv_bias, qkv_bias=qkv_bias,
qk_scale=qk_scale, qk_scale=qk_scale,
drop=drop_rate, drop=drop_rate,
act_layer=nn.Swish, act_layer=eval(act),
attn_drop=attn_drop_rate, attn_drop=attn_drop_rate,
drop_path=dpr[0:depth[0]][i], drop_path=dpr[0:depth[0]][i],
norm_layer=norm_layer, norm_layer=norm_layer,
......
...@@ -128,6 +128,8 @@ class STN_ON(nn.Layer): ...@@ -128,6 +128,8 @@ class STN_ON(nn.Layer):
self.out_channels = in_channels self.out_channels = in_channels
def forward(self, image): def forward(self, image):
if len(image.shape)==5:
image = image.reshape([0, image.shape[-3], image.shape[-2], image.shape[-1]])
stn_input = paddle.nn.functional.interpolate( stn_input = paddle.nn.functional.interpolate(
image, self.tps_inputsize, mode="bilinear", align_corners=True) image, self.tps_inputsize, mode="bilinear", align_corners=True)
stn_img_feat, ctrl_points = self.stn_head(stn_input) stn_img_feat, ctrl_points = self.stn_head(stn_input)
......
...@@ -138,9 +138,9 @@ class TPSSpatialTransformer(nn.Layer): ...@@ -138,9 +138,9 @@ class TPSSpatialTransformer(nn.Layer):
assert source_control_points.shape[2] == 2 assert source_control_points.shape[2] == 2
batch_size = paddle.shape(source_control_points)[0] batch_size = paddle.shape(source_control_points)[0]
self.padding_matrix = paddle.expand( padding_matrix = paddle.expand(
self.padding_matrix, shape=[batch_size, 3, 2]) self.padding_matrix, shape=[batch_size, 3, 2])
Y = paddle.concat([source_control_points, self.padding_matrix], 1) Y = paddle.concat([source_control_points, padding_matrix], 1)
mapping_matrix = paddle.matmul(self.inverse_kernel, Y) mapping_matrix = paddle.matmul(self.inverse_kernel, Y)
source_coordinate = paddle.matmul(self.target_coordinate_repr, source_coordinate = paddle.matmul(self.target_coordinate_repr,
mapping_matrix) mapping_matrix)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment