Commit 4f8b5113 authored by Leif's avatar Leif
Browse files

Merge remote-tracking branch 'origin/dygraph' into dygraph

parents d73ed79c 370f0fef
...@@ -704,8 +704,9 @@ class Canvas(QWidget): ...@@ -704,8 +704,9 @@ class Canvas(QWidget):
def keyPressEvent(self, ev): def keyPressEvent(self, ev):
key = ev.key() key = ev.key()
shapesBackup = []
shapesBackup = copy.deepcopy(self.shapes) shapesBackup = copy.deepcopy(self.shapes)
if len(shapesBackup) == 0:
return
self.shapesBackups.pop() self.shapesBackups.pop()
self.shapesBackups.append(shapesBackup) self.shapesBackups.append(shapesBackup)
if key == Qt.Key_Escape and self.current: if key == Qt.Key_Escape and self.current:
......
...@@ -18,6 +18,7 @@ Global: ...@@ -18,6 +18,7 @@ Global:
Architecture: Architecture:
name: DistillationModel name: DistillationModel
algorithm: Distillation algorithm: Distillation
model_type: det
Models: Models:
Teacher: Teacher:
freeze_params: true freeze_params: true
......
...@@ -111,7 +111,7 @@ def main(): ...@@ -111,7 +111,7 @@ def main():
valid_dataloader = build_dataloader(config, 'Eval', device, logger) valid_dataloader = build_dataloader(config, 'Eval', device, logger)
use_srn = config['Architecture']['algorithm'] == "SRN" use_srn = config['Architecture']['algorithm'] == "SRN"
model_type = config['Architecture']['model_type'] model_type = config['Architecture'].get('model_type', None)
# start eval # start eval
metric = program.eval(model, valid_dataloader, post_process_class, metric = program.eval(model, valid_dataloader, post_process_class,
eval_class, model_type, use_srn) eval_class, model_type, use_srn)
...@@ -120,8 +120,7 @@ def main(): ...@@ -120,8 +120,7 @@ def main():
for k, v in metric.items(): for k, v in metric.items():
logger.info('{}:{}'.format(k, v)) logger.info('{}:{}'.format(k, v))
infer_shape = [3, 32, 100] if config['Architecture'][ infer_shape = [3, 32, 100] if model_type == "rec" else [3, 640, 640]
'model_type'] != "det" else [3, 640, 640]
save_path = config["Global"]["save_inference_dir"] save_path = config["Global"]["save_inference_dir"]
......
...@@ -49,7 +49,6 @@ https://aistudio.baidu.com/aistudio/datasetdetail/8429 ...@@ -49,7 +49,6 @@ https://aistudio.baidu.com/aistudio/datasetdetail/8429
- 每个样本固定10个字符,字符随机截取自语料库中的句子 - 每个样本固定10个字符,字符随机截取自语料库中的句子
- 图片分辨率统一为280x32 - 图片分辨率统一为280x32
![](../datasets/ch_doc1.jpg) ![](../datasets/ch_doc1.jpg)
![](../datasets/ch_doc2.jpg)
![](../datasets/ch_doc3.jpg) ![](../datasets/ch_doc3.jpg)
- **下载地址**:https://pan.baidu.com/s/1QkI7kjah8SPHwOQ40rS1Pw (密码:lu7m) - **下载地址**:https://pan.baidu.com/s/1QkI7kjah8SPHwOQ40rS1Pw (密码:lu7m)
......
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
```shell ```shell
python3 -m paddle.distributed.launch \ python3 -m paddle.distributed.launch \
--log_dir=./log/ \ --log_dir=./log/ \
--gpus '0,1,2,3,4,5,6,7' \ --gpus "0,1,2,3,4,5,6,7" \
tools/train.py \ tools/train.py \
-c configs/rec/rec_mv3_none_bilstm_ctc.yml -c configs/rec/rec_mv3_none_bilstm_ctc.yml
``` ```
......
...@@ -50,7 +50,6 @@ https://aistudio.baidu.com/aistudio/datasetdetail/8429 ...@@ -50,7 +50,6 @@ https://aistudio.baidu.com/aistudio/datasetdetail/8429
- Each sample is fixed with 10 characters, and the characters are randomly intercepted from the sentences in the corpus - Each sample is fixed with 10 characters, and the characters are randomly intercepted from the sentences in the corpus
- Image resolution is 280x32 - Image resolution is 280x32
![](../datasets/ch_doc1.jpg) ![](../datasets/ch_doc1.jpg)
![](../datasets/ch_doc2.jpg)
![](../datasets/ch_doc3.jpg) ![](../datasets/ch_doc3.jpg)
- **Download link**:https://pan.baidu.com/s/1QkI7kjah8SPHwOQ40rS1Pw (Password: lu7m) - **Download link**:https://pan.baidu.com/s/1QkI7kjah8SPHwOQ40rS1Pw (Password: lu7m)
......
...@@ -13,7 +13,7 @@ Take recognition as an example. After the data is prepared locally, start the tr ...@@ -13,7 +13,7 @@ Take recognition as an example. After the data is prepared locally, start the tr
```shell ```shell
python3 -m paddle.distributed.launch \ python3 -m paddle.distributed.launch \
--log_dir=./log/ \ --log_dir=./log/ \
--gpus '0,1,2,3,4,5,6,7' \ --gpus "0,1,2,3,4,5,6,7" \
tools/train.py \ tools/train.py \
-c configs/rec/rec_mv3_none_bilstm_ctc.yml -c configs/rec/rec_mv3_none_bilstm_ctc.yml
``` ```
......
...@@ -32,6 +32,7 @@ class CopyPaste(object): ...@@ -32,6 +32,7 @@ class CopyPaste(object):
self.aug = IaaAugment(augmenter_args) self.aug = IaaAugment(augmenter_args)
def __call__(self, data): def __call__(self, data):
point_num = data['polys'].shape[1]
src_img = data['image'] src_img = data['image']
src_polys = data['polys'].tolist() src_polys = data['polys'].tolist()
src_ignores = data['ignore_tags'].tolist() src_ignores = data['ignore_tags'].tolist()
...@@ -57,6 +58,9 @@ class CopyPaste(object): ...@@ -57,6 +58,9 @@ class CopyPaste(object):
src_img, box = self.paste_img(src_img, box_img, src_polys) src_img, box = self.paste_img(src_img, box_img, src_polys)
if box is not None: if box is not None:
box = box.tolist()
for _ in range(len(box), point_num):
box.append(box[-1])
src_polys.append(box) src_polys.append(box)
src_ignores.append(tag) src_ignores.append(tag)
src_img = cv2.cvtColor(np.array(src_img), cv2.COLOR_RGB2BGR) src_img = cv2.cvtColor(np.array(src_img), cv2.COLOR_RGB2BGR)
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
import numpy as np import numpy as np
import os import os
import random import random
import traceback
from paddle.io import Dataset from paddle.io import Dataset
from .imaug import transform, create_operators from .imaug import transform, create_operators
...@@ -93,7 +94,8 @@ class SimpleDataSet(Dataset): ...@@ -93,7 +94,8 @@ class SimpleDataSet(Dataset):
img = f.read() img = f.read()
data['image'] = img data['image'] = img
data = transform(data, load_data_ops) data = transform(data, load_data_ops)
if data is None:
if data is None or data['polys'].shape[1]!=4:
continue continue
ext_data.append(data) ext_data.append(data)
return ext_data return ext_data
...@@ -115,10 +117,10 @@ class SimpleDataSet(Dataset): ...@@ -115,10 +117,10 @@ class SimpleDataSet(Dataset):
data['image'] = img data['image'] = img
data['ext_data'] = self.get_ext_data() data['ext_data'] = self.get_ext_data()
outs = transform(data, self.ops) outs = transform(data, self.ops)
except Exception as e: except:
self.logger.error( self.logger.error(
"When parsing line {}, error happened with msg: {}".format( "When parsing line {}, error happened with msg: {}".format(
data_line, e)) data_line, traceback.format_exc()))
outs = None outs = None
if outs is None: if outs is None:
# during evaluation, we should fix the idx to get same results for many times of evaluation. # during evaluation, we should fix the idx to get same results for many times of evaluation.
......
...@@ -25,16 +25,14 @@ __all__ = ["ResNet"] ...@@ -25,16 +25,14 @@ __all__ = ["ResNet"]
class ConvBNLayer(nn.Layer): class ConvBNLayer(nn.Layer):
def __init__( def __init__(self,
self, in_channels,
in_channels, out_channels,
out_channels, kernel_size,
kernel_size, stride=1,
stride=1, groups=1,
groups=1, is_vd_mode=False,
is_vd_mode=False, act=None):
act=None,
name=None, ):
super(ConvBNLayer, self).__init__() super(ConvBNLayer, self).__init__()
self.is_vd_mode = is_vd_mode self.is_vd_mode = is_vd_mode
...@@ -47,19 +45,8 @@ class ConvBNLayer(nn.Layer): ...@@ -47,19 +45,8 @@ class ConvBNLayer(nn.Layer):
stride=stride, stride=stride,
padding=(kernel_size - 1) // 2, padding=(kernel_size - 1) // 2,
groups=groups, groups=groups,
weight_attr=ParamAttr(name=name + "_weights"),
bias_attr=False) bias_attr=False)
if name == "conv1": self._batch_norm = nn.BatchNorm(out_channels, act=act)
bn_name = "bn_" + name
else:
bn_name = "bn" + name[3:]
self._batch_norm = nn.BatchNorm(
out_channels,
act=act,
param_attr=ParamAttr(name=bn_name + '_scale'),
bias_attr=ParamAttr(bn_name + '_offset'),
moving_mean_name=bn_name + '_mean',
moving_variance_name=bn_name + '_variance')
def forward(self, inputs): def forward(self, inputs):
if self.is_vd_mode: if self.is_vd_mode:
...@@ -75,29 +62,25 @@ class BottleneckBlock(nn.Layer): ...@@ -75,29 +62,25 @@ class BottleneckBlock(nn.Layer):
out_channels, out_channels,
stride, stride,
shortcut=True, shortcut=True,
if_first=False, if_first=False):
name=None):
super(BottleneckBlock, self).__init__() super(BottleneckBlock, self).__init__()
self.conv0 = ConvBNLayer( self.conv0 = ConvBNLayer(
in_channels=in_channels, in_channels=in_channels,
out_channels=out_channels, out_channels=out_channels,
kernel_size=1, kernel_size=1,
act='relu', act='relu')
name=name + "_branch2a")
self.conv1 = ConvBNLayer( self.conv1 = ConvBNLayer(
in_channels=out_channels, in_channels=out_channels,
out_channels=out_channels, out_channels=out_channels,
kernel_size=3, kernel_size=3,
stride=stride, stride=stride,
act='relu', act='relu')
name=name + "_branch2b")
self.conv2 = ConvBNLayer( self.conv2 = ConvBNLayer(
in_channels=out_channels, in_channels=out_channels,
out_channels=out_channels * 4, out_channels=out_channels * 4,
kernel_size=1, kernel_size=1,
act=None, act=None)
name=name + "_branch2c")
if not shortcut: if not shortcut:
self.short = ConvBNLayer( self.short = ConvBNLayer(
...@@ -105,8 +88,7 @@ class BottleneckBlock(nn.Layer): ...@@ -105,8 +88,7 @@ class BottleneckBlock(nn.Layer):
out_channels=out_channels * 4, out_channels=out_channels * 4,
kernel_size=1, kernel_size=1,
stride=1, stride=1,
is_vd_mode=False if if_first else True, is_vd_mode=False if if_first else True)
name=name + "_branch1")
self.shortcut = shortcut self.shortcut = shortcut
...@@ -125,13 +107,13 @@ class BottleneckBlock(nn.Layer): ...@@ -125,13 +107,13 @@ class BottleneckBlock(nn.Layer):
class BasicBlock(nn.Layer): class BasicBlock(nn.Layer):
def __init__(self, def __init__(
in_channels, self,
out_channels, in_channels,
stride, out_channels,
shortcut=True, stride,
if_first=False, shortcut=True,
name=None): if_first=False, ):
super(BasicBlock, self).__init__() super(BasicBlock, self).__init__()
self.stride = stride self.stride = stride
self.conv0 = ConvBNLayer( self.conv0 = ConvBNLayer(
...@@ -139,14 +121,12 @@ class BasicBlock(nn.Layer): ...@@ -139,14 +121,12 @@ class BasicBlock(nn.Layer):
out_channels=out_channels, out_channels=out_channels,
kernel_size=3, kernel_size=3,
stride=stride, stride=stride,
act='relu', act='relu')
name=name + "_branch2a")
self.conv1 = ConvBNLayer( self.conv1 = ConvBNLayer(
in_channels=out_channels, in_channels=out_channels,
out_channels=out_channels, out_channels=out_channels,
kernel_size=3, kernel_size=3,
act=None, act=None)
name=name + "_branch2b")
if not shortcut: if not shortcut:
self.short = ConvBNLayer( self.short = ConvBNLayer(
...@@ -154,8 +134,7 @@ class BasicBlock(nn.Layer): ...@@ -154,8 +134,7 @@ class BasicBlock(nn.Layer):
out_channels=out_channels, out_channels=out_channels,
kernel_size=1, kernel_size=1,
stride=1, stride=1,
is_vd_mode=False if if_first else True, is_vd_mode=False if if_first else True)
name=name + "_branch1")
self.shortcut = shortcut self.shortcut = shortcut
...@@ -201,22 +180,19 @@ class ResNet(nn.Layer): ...@@ -201,22 +180,19 @@ class ResNet(nn.Layer):
out_channels=32, out_channels=32,
kernel_size=3, kernel_size=3,
stride=2, stride=2,
act='relu', act='relu')
name="conv1_1")
self.conv1_2 = ConvBNLayer( self.conv1_2 = ConvBNLayer(
in_channels=32, in_channels=32,
out_channels=32, out_channels=32,
kernel_size=3, kernel_size=3,
stride=1, stride=1,
act='relu', act='relu')
name="conv1_2")
self.conv1_3 = ConvBNLayer( self.conv1_3 = ConvBNLayer(
in_channels=32, in_channels=32,
out_channels=64, out_channels=64,
kernel_size=3, kernel_size=3,
stride=1, stride=1,
act='relu', act='relu')
name="conv1_3")
self.pool2d_max = nn.MaxPool2D(kernel_size=3, stride=2, padding=1) self.pool2d_max = nn.MaxPool2D(kernel_size=3, stride=2, padding=1)
self.stages = [] self.stages = []
...@@ -226,13 +202,6 @@ class ResNet(nn.Layer): ...@@ -226,13 +202,6 @@ class ResNet(nn.Layer):
block_list = [] block_list = []
shortcut = False shortcut = False
for i in range(depth[block]): for i in range(depth[block]):
if layers in [101, 152] and block == 2:
if i == 0:
conv_name = "res" + str(block + 2) + "a"
else:
conv_name = "res" + str(block + 2) + "b" + str(i)
else:
conv_name = "res" + str(block + 2) + chr(97 + i)
bottleneck_block = self.add_sublayer( bottleneck_block = self.add_sublayer(
'bb_%d_%d' % (block, i), 'bb_%d_%d' % (block, i),
BottleneckBlock( BottleneckBlock(
...@@ -241,8 +210,7 @@ class ResNet(nn.Layer): ...@@ -241,8 +210,7 @@ class ResNet(nn.Layer):
out_channels=num_filters[block], out_channels=num_filters[block],
stride=2 if i == 0 and block != 0 else 1, stride=2 if i == 0 and block != 0 else 1,
shortcut=shortcut, shortcut=shortcut,
if_first=block == i == 0, if_first=block == i == 0))
name=conv_name))
shortcut = True shortcut = True
block_list.append(bottleneck_block) block_list.append(bottleneck_block)
self.out_channels.append(num_filters[block] * 4) self.out_channels.append(num_filters[block] * 4)
...@@ -252,7 +220,6 @@ class ResNet(nn.Layer): ...@@ -252,7 +220,6 @@ class ResNet(nn.Layer):
block_list = [] block_list = []
shortcut = False shortcut = False
for i in range(depth[block]): for i in range(depth[block]):
conv_name = "res" + str(block + 2) + chr(97 + i)
basic_block = self.add_sublayer( basic_block = self.add_sublayer(
'bb_%d_%d' % (block, i), 'bb_%d_%d' % (block, i),
BasicBlock( BasicBlock(
...@@ -261,8 +228,7 @@ class ResNet(nn.Layer): ...@@ -261,8 +228,7 @@ class ResNet(nn.Layer):
out_channels=num_filters[block], out_channels=num_filters[block],
stride=2 if i == 0 and block != 0 else 1, stride=2 if i == 0 and block != 0 else 1,
shortcut=shortcut, shortcut=shortcut,
if_first=block == i == 0, if_first=block == i == 0))
name=conv_name))
shortcut = True shortcut = True
block_list.append(basic_block) block_list.append(basic_block)
self.out_channels.append(num_filters[block]) self.out_channels.append(num_filters[block])
......
# 视觉问答(VQA) # 文档视觉问答(DOC-VQA)
VQA主要特性如下: VQA指视觉问答,主要针对图像内容进行提问和回答,DOC-VQA是VQA任务中的一种,DOC-VQA主要针对文本图像的文字内容提出问题。
PP-Structure 里的 DOC-VQA算法基于PaddleNLP自然语言处理算法库进行开发。
主要特性如下:
- 集成[LayoutXLM](https://arxiv.org/pdf/2104.08836.pdf)模型以及PP-OCR预测引擎。 - 集成[LayoutXLM](https://arxiv.org/pdf/2104.08836.pdf)模型以及PP-OCR预测引擎。
- 支持基于多模态方法的语义实体识别 (Semantic Entity Recognition, SER) 以及关系抽取 (Relation Extraction, RE) 任务。基于 SER 任务,可以完成对图像中的文本识别与分类;基于 RE 任务,可以完成对图象中的文本内容的关系提取(比如判断问题对) - 支持基于多模态方法的语义实体识别 (Semantic Entity Recognition, SER) 以及关系抽取 (Relation Extraction, RE) 任务。基于 SER 任务,可以完成对图像中的文本识别与分类;基于 RE 任务,可以完成对图象中的文本内容的关系提取,如判断问题对(pair)。
- 支持SER任务与OCR引擎联合的端到端系统预测与评估。 - 支持SER任务和RE任务的自定义训练。
- 支持SER任务和RE任务的自定义训练 - 支持OCR+SER的端到端系统预测与评估。
- 支持OCR+SER+RE的端到端系统预测。
本项目是 [LayoutXLM: Multimodal Pre-training for Multilingual Visually-rich Document Understanding](https://arxiv.org/pdf/2104.08836.pdf) 在 Paddle 2.2上的开源实现, 本项目是 [LayoutXLM: Multimodal Pre-training for Multilingual Visually-rich Document Understanding](https://arxiv.org/pdf/2104.08836.pdf) 在 Paddle 2.2上的开源实现,
包含了在 [XFUND数据集](https://github.com/doc-analysis/XFUND) 上的微调代码。 包含了在 [XFUND数据集](https://github.com/doc-analysis/XFUND) 上的微调代码。
## 1. 效果演示 ## 1 性能
我们在 [XFUN](https://github.com/doc-analysis/XFUND) 评估数据集上对算法进行了评估,性能如下
|任务| f1 | 模型下载地址|
|:---:|:---:| :---:|
|SER|0.9056| [链接](https://paddleocr.bj.bcebos.com/pplayout/PP-Layout_v1.0_ser_pretrained.tar)|
|RE|0.7113| [链接](https://paddleocr.bj.bcebos.com/pplayout/PP-Layout_v1.0_re_pretrained.tar)|
## 2. 效果演示
**注意:** 测试图片来源于XFUN数据集。 **注意:** 测试图片来源于XFUN数据集。
### 1.1 SER ### 2.1 SER
<div align="center"> ![](./images/result_ser/zh_val_0_ser.jpg) | ![](./images/result_ser/zh_val_42_ser.jpg)
<img src="./images/result_ser/zh_val_0_ser.jpg" width = "600" /> ---|---
</div>
<div align="center"> 图中不同颜色的框表示不同的类别,对于XFUN数据集,有`QUESTION`, `ANSWER`, `HEADER` 3种类别
<img src="./images/result_ser/zh_val_42_ser.jpg" width = "600" />
</div>
其中不同颜色的框表示不同的类别,对于XFUN数据集,有`QUESTION`, `ANSWER`, `HEADER` 3种类别,在OCR检测框的左上方也标出了对应的类别和OCR识别结果。 * 深紫色:HEADER
* 浅紫色:QUESTION
* 军绿色:ANSWER
在OCR检测框的左上方也标出了对应的类别和OCR识别结果。
### 1.2 RE
* Coming soon! ### 2.2 RE
![](./images/result_re/zh_val_21_re.jpg) | ![](./images/result_re/zh_val_40_re.jpg)
---|---
## 2. 安装 图中红色框表示问题,蓝色框表示答案,问题和答案之间使用绿色线连接。在OCR检测框的左上方也标出了对应的类别和OCR识别结果。
### 2.1 安装依赖
## 3. 安装
### 3.1 安装依赖
- **(1) 安装PaddlePaddle** - **(1) 安装PaddlePaddle**
...@@ -53,12 +73,12 @@ python3 -m pip install paddlepaddle==2.2 -i https://mirror.baidu.com/pypi/simple ...@@ -53,12 +73,12 @@ python3 -m pip install paddlepaddle==2.2 -i https://mirror.baidu.com/pypi/simple
更多需求,请参照[安装文档](https://www.paddlepaddle.org.cn/install/quick)中的说明进行操作。 更多需求,请参照[安装文档](https://www.paddlepaddle.org.cn/install/quick)中的说明进行操作。
### 2.2 安装PaddleOCR(包含 PP-OCR 和 VQA ) ### 3.2 安装PaddleOCR(包含 PP-OCR 和 VQA )
- **(1)pip快速安装PaddleOCR whl包(仅预测)** - **(1)pip快速安装PaddleOCR whl包(仅预测)**
```bash ```bash
pip install "paddleocr>=2.2" # 推荐使用2.2+版本 pip install paddleocr
``` ```
- **(2)下载VQA源码(预测+训练)** - **(2)下载VQA源码(预测+训练)**
...@@ -85,13 +105,14 @@ pip install -e . ...@@ -85,13 +105,14 @@ pip install -e .
- **(4)安装VQA的`requirements`** - **(4)安装VQA的`requirements`**
```bash ```bash
cd ppstructure/vqa
pip install -r requirements.txt pip install -r requirements.txt
``` ```
## 3. 使用 ## 4. 使用
### 3.1 数据和预训练模型准备 ### 4.1 数据和预训练模型准备
处理好的XFUN中文数据集下载地址:[https://paddleocr.bj.bcebos.com/dataset/XFUND.tar](https://paddleocr.bj.bcebos.com/dataset/XFUND.tar) 处理好的XFUN中文数据集下载地址:[https://paddleocr.bj.bcebos.com/dataset/XFUND.tar](https://paddleocr.bj.bcebos.com/dataset/XFUND.tar)
...@@ -104,18 +125,15 @@ wget https://paddleocr.bj.bcebos.com/dataset/XFUND.tar ...@@ -104,18 +125,15 @@ wget https://paddleocr.bj.bcebos.com/dataset/XFUND.tar
如果希望转换XFUN中其他语言的数据集,可以参考[XFUN数据转换脚本](helper/trans_xfun_data.py) 如果希望转换XFUN中其他语言的数据集,可以参考[XFUN数据转换脚本](helper/trans_xfun_data.py)
如果希望直接体验预测过程,可以下载我们提供的SER预训练模型,跳过训练过程,直接预测即可。 如果希望直接体验预测过程,可以下载我们提供的预训练模型,跳过训练过程,直接预测即可。
* SER任务预训练模型下载链接:[链接](https://paddleocr.bj.bcebos.com/pplayout/PP-Layout_v1.0_ser_pretrained.tar)
* RE任务预训练模型下载链接:coming soon!
### 3.2 SER任务 ### 4.2 SER任务
* 启动训练 * 启动训练
```shell ```shell
python train_ser.py \ python3.7 train_ser.py \
--model_name_or_path "layoutxlm-base-uncased" \ --model_name_or_path "layoutxlm-base-uncased" \
--train_data_dir "XFUND/zh_train/image" \ --train_data_dir "XFUND/zh_train/image" \
--train_label_path "XFUND/zh_train/xfun_normalize_train.json" \ --train_label_path "XFUND/zh_train/xfun_normalize_train.json" \
...@@ -131,13 +149,7 @@ python train_ser.py \ ...@@ -131,13 +149,7 @@ python train_ser.py \
--seed 2048 --seed 2048
``` ```
最终会打印出`precision`, `recall`, `f1`等指标,如下所示。 最终会打印出`precision`, `recall`, `f1`等指标,模型和训练日志会保存在`./output/ser/`文件夹中。
```
best metrics: {'loss': 1.066644651549203, 'precision': 0.8770182068017863, 'recall': 0.9361936193619362, 'f1': 0.9056402979780063}
```
模型和训练日志会保存在`./output/ser/`文件夹中。
* 使用评估集合中提供的OCR识别结果进行预测 * 使用评估集合中提供的OCR识别结果进行预测
...@@ -159,21 +171,73 @@ export CUDA_VISIBLE_DEVICES=0 ...@@ -159,21 +171,73 @@ export CUDA_VISIBLE_DEVICES=0
python3.7 infer_ser_e2e.py \ python3.7 infer_ser_e2e.py \
--model_name_or_path "./output/PP-Layout_v1.0_ser_pretrained/" \ --model_name_or_path "./output/PP-Layout_v1.0_ser_pretrained/" \
--max_seq_length 512 \ --max_seq_length 512 \
--output_dir "output_res_e2e/" --output_dir "output_res_e2e/" \
--infer_imgs "images/input/zh_val_0.jpg"
``` ```
*`OCR引擎 + SER`预测系统进行端到端评估 *`OCR引擎 + SER`预测系统进行端到端评估
```shell ```shell
export CUDA_VISIBLE_DEVICES=0 export CUDA_VISIBLE_DEVICES=0
python helper/eval_with_label_end2end.py --gt_json_path XFUND/zh_val/xfun_normalize_val.json --pred_json_path output_res/infer_results.txt python3.7 helper/eval_with_label_end2end.py --gt_json_path XFUND/zh_val/xfun_normalize_val.json --pred_json_path output_res/infer_results.txt
``` ```
3.3 RE任务 ### 3.3 RE任务
coming soon! * 启动训练
```shell
python3 train_re.py \
--model_name_or_path "layoutxlm-base-uncased" \
--train_data_dir "XFUND/zh_train/image" \
--train_label_path "XFUND/zh_train/xfun_normalize_train.json" \
--eval_data_dir "XFUND/zh_val/image" \
--eval_label_path "XFUND/zh_val/xfun_normalize_val.json" \
--label_map_path 'labels/labels_ser.txt' \
--num_train_epochs 2 \
--eval_steps 10 \
--save_steps 500 \
--output_dir "output/re/" \
--learning_rate 5e-5 \
--warmup_steps 50 \
--per_gpu_train_batch_size 8 \
--per_gpu_eval_batch_size 8 \
--evaluate_during_training \
--seed 2048
```
最终会打印出`precision`, `recall`, `f1`等指标,模型和训练日志会保存在`./output/re/`文件夹中。
* 使用评估集合中提供的OCR识别结果进行预测
```shell
export CUDA_VISIBLE_DEVICES=0
python3 infer_re.py \
--model_name_or_path "./PP-Layout_v1.0_re_pretrained/" \
--max_seq_length 512 \
--eval_data_dir "XFUND/zh_val/image" \
--eval_label_path "XFUND/zh_val/xfun_normalize_val.json" \
--label_map_path 'labels/labels_ser.txt' \
--output_dir "output_res" \
--per_gpu_eval_batch_size 1 \
--seed 2048
```
最终会在`output_res`目录下保存预测结果可视化图像以及预测结果文本文件,文件名为`infer_results.txt`
* 使用`OCR引擎 + SER + RE`串联结果
```shell
export CUDA_VISIBLE_DEVICES=0
# python3.7 infer_ser_re_e2e.py \
--model_name_or_path "./PP-Layout_v1.0_ser_pretrained/" \
--re_model_name_or_path "./PP-Layout_v1.0_re_pretrained/" \
--max_seq_length 512 \
--output_dir "output_ser_re_e2e_train/" \
--infer_imgs "images/input/zh_val_21.jpg"
```
## 参考链接 ## 参考链接
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import numbers
import numpy as np
class DataCollator:
"""
data batch
"""
def __call__(self, batch):
data_dict = {}
to_tensor_keys = []
for sample in batch:
for k, v in sample.items():
if k not in data_dict:
data_dict[k] = []
if isinstance(v, (np.ndarray, paddle.Tensor, numbers.Number)):
if k not in to_tensor_keys:
to_tensor_keys.append(k)
data_dict[k].append(v)
for k in to_tensor_keys:
data_dict[k] = paddle.to_tensor(data_dict[k])
return data_dict
import os
import sys
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
import random
import cv2
import matplotlib.pyplot as plt
import numpy as np
import paddle
from paddlenlp.transformers import LayoutXLMTokenizer, LayoutXLMModel, LayoutXLMForRelationExtraction
from xfun import XFUNDataset
from utils import parse_args, get_bio_label_maps, draw_re_results
from data_collator import DataCollator
from ppocr.utils.logging import get_logger
def infer(args):
os.makedirs(args.output_dir, exist_ok=True)
logger = get_logger()
label2id_map, id2label_map = get_bio_label_maps(args.label_map_path)
pad_token_label_id = paddle.nn.CrossEntropyLoss().ignore_index
tokenizer = LayoutXLMTokenizer.from_pretrained(args.model_name_or_path)
model = LayoutXLMForRelationExtraction.from_pretrained(
args.model_name_or_path)
eval_dataset = XFUNDataset(
tokenizer,
data_dir=args.eval_data_dir,
label_path=args.eval_label_path,
label2id_map=label2id_map,
img_size=(224, 224),
max_seq_len=args.max_seq_length,
pad_token_label_id=pad_token_label_id,
contains_re=True,
add_special_ids=False,
return_attention_mask=True,
load_mode='all')
eval_dataloader = paddle.io.DataLoader(
eval_dataset,
batch_size=args.per_gpu_eval_batch_size,
num_workers=8,
shuffle=False,
collate_fn=DataCollator())
# 读取gt的oct数据
ocr_info_list = load_ocr(args.eval_data_dir, args.eval_label_path)
for idx, batch in enumerate(eval_dataloader):
logger.info("[Infer] process: {}/{}".format(idx, len(eval_dataloader)))
with paddle.no_grad():
outputs = model(**batch)
pred_relations = outputs['pred_relations']
ocr_info = ocr_info_list[idx]
image_path = ocr_info['image_path']
ocr_info = ocr_info['ocr_info']
# 根据entity里的信息,做token解码后去过滤不要的ocr_info
ocr_info = filter_bg_by_txt(ocr_info, batch, tokenizer)
# 进行 relations 到 ocr信息的转换
result = []
used_tail_id = []
for relations in pred_relations:
for relation in relations:
if relation['tail_id'] in used_tail_id:
continue
if relation['head_id'] not in ocr_info or relation[
'tail_id'] not in ocr_info:
continue
used_tail_id.append(relation['tail_id'])
ocr_info_head = ocr_info[relation['head_id']]
ocr_info_tail = ocr_info[relation['tail_id']]
result.append((ocr_info_head, ocr_info_tail))
img = cv2.imread(image_path)
img_show = draw_re_results(img, result)
save_path = os.path.join(args.output_dir, os.path.basename(image_path))
cv2.imwrite(save_path, img_show)
def load_ocr(img_folder, json_path):
import json
d = []
with open(json_path, "r") as fin:
lines = fin.readlines()
for line in lines:
image_name, info_str = line.split("\t")
info_dict = json.loads(info_str)
info_dict['image_path'] = os.path.join(img_folder, image_name)
d.append(info_dict)
return d
def filter_bg_by_txt(ocr_info, batch, tokenizer):
entities = batch['entities'][0]
input_ids = batch['input_ids'][0]
new_info_dict = {}
for i in range(len(entities['start'])):
entitie_head = entities['start'][i]
entitie_tail = entities['end'][i]
word_input_ids = input_ids[entitie_head:entitie_tail].numpy().tolist()
txt = tokenizer.convert_ids_to_tokens(word_input_ids)
txt = tokenizer.convert_tokens_to_string(txt)
for i, info in enumerate(ocr_info):
if info['text'] == txt:
new_info_dict[i] = info
return new_info_dict
def post_process(pred_relations, ocr_info, img):
result = []
for relations in pred_relations:
for relation in relations:
ocr_info_head = ocr_info[relation['head_id']]
ocr_info_tail = ocr_info[relation['tail_id']]
result.append((ocr_info_head, ocr_info_tail))
return result
def draw_re(result, image_path, output_folder):
img = cv2.imread(image_path)
from matplotlib import pyplot as plt
for ocr_info_head, ocr_info_tail in result:
cv2.rectangle(
img,
tuple(ocr_info_head['bbox'][:2]),
tuple(ocr_info_head['bbox'][2:]), (255, 0, 0),
thickness=2)
cv2.rectangle(
img,
tuple(ocr_info_tail['bbox'][:2]),
tuple(ocr_info_tail['bbox'][2:]), (0, 0, 255),
thickness=2)
center_p1 = [(ocr_info_head['bbox'][0] + ocr_info_head['bbox'][2]) // 2,
(ocr_info_head['bbox'][1] + ocr_info_head['bbox'][3]) // 2]
center_p2 = [(ocr_info_tail['bbox'][0] + ocr_info_tail['bbox'][2]) // 2,
(ocr_info_tail['bbox'][1] + ocr_info_tail['bbox'][3]) // 2]
cv2.line(
img, tuple(center_p1), tuple(center_p2), (0, 255, 0), thickness=2)
plt.imshow(img)
plt.savefig(
os.path.join(output_folder, os.path.basename(image_path)), dpi=600)
# plt.show()
if __name__ == "__main__":
args = parse_args()
infer(args)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment