Commit a323fce6 authored by WenmuZhou

vqa code integrated into ppocr training system

parent 1ded2ac4
......@@ -63,6 +63,10 @@ class BaseModel(nn.Layer):
in_channels = self.neck.out_channels
# build head; head is needed for det, rec and cls
if 'Head' not in config or config['Head'] is None:
self.use_head = False
else:
self.use_head = True
config["Head"]['in_channels'] = in_channels
self.head = build_head(config["Head"])
......@@ -77,6 +81,7 @@ class BaseModel(nn.Layer):
if self.use_neck:
x = self.neck(x)
y["neck_out"] = x
if self.use_head:
x = self.head(x, targets=data)
if isinstance(x, dict):
y.update(x)
......
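These two hunks make the head (and neck) optional, so a VQA backbone such as LayoutXLM, whose output is already the final prediction, can run without a `Head` section in the config. A minimal, self-contained sketch of the resulting control flow (the callables below are placeholders for illustration, not the real ppocr modules):

```python
# Minimal stand-in illustrating the optional-head branch added above
# (not the real ppocr BaseModel; backbone/head are placeholder callables).
def forward(x, backbone, head=None, data=None):
    y = {}
    x = backbone(x)
    y["backbone_out"] = x
    if head is not None:              # self.use_head is False for VQA configs
        x = head(x, targets=data)
        y.update(x if isinstance(x, dict) else {"head_out": x})
    return y

# VQA backbones (e.g. LayoutXLMForSer) already emit logits, so no head is used:
print(forward("feat", backbone=lambda t: t + "->logits"))
```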
......@@ -43,6 +43,9 @@ def build_backbone(config, model_type):
from .table_resnet_vd import ResNet
from .table_mobilenet_v3 import MobileNetV3
support_dict = ["ResNet", "MobileNetV3"]
elif model_type == 'vqa':
from .vqa_layoutlm import LayoutLMForSer, LayoutXLMForSer, LayoutXLMForRe
support_dict = ["LayoutLMForSer", "LayoutXLMForSer", 'LayoutXLMForRe']
else:
raise NotImplementedError
......
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from paddle import nn
from paddlenlp.transformers import LayoutXLMModel, LayoutXLMForTokenClassification, LayoutXLMForRelationExtraction
from paddlenlp.transformers import LayoutLMModel, LayoutLMForTokenClassification
__all__ = ["LayoutXLMForSer", 'LayoutLMForSer']
class NLPBaseModel(nn.Layer):
def __init__(self,
base_model_class,
model_class,
type='ser',
pretrained_model=None,
checkpoints=None,
**kwargs):
super(NLPBaseModel, self).__init__()
assert pretrained_model is not None or checkpoints is not None, "at least one of pretrained_model and checkpoints must not be None"
if checkpoints is not None:
self.model = model_class.from_pretrained(checkpoints)
else:
base_model = base_model_class.from_pretrained(pretrained_model)
if type == 'ser':
self.model = model_class(
base_model, num_classes=kwargs['num_classes'], dropout=None)
else:
self.model = model_class(base_model, dropout=None)
self.out_channels = 1
class LayoutXLMForSer(NLPBaseModel):
def __init__(self,
num_classes,
pretrained_model='layoutxlm-base-uncased',
checkpoints=None,
**kwargs):
super(LayoutXLMForSer, self).__init__(
LayoutXLMModel,
LayoutXLMForTokenClassification,
'ser',
pretrained_model,
checkpoints,
num_classes=num_classes)
def forward(self, x):
x = self.model(
input_ids=x[0],
bbox=x[2],
image=x[3],
attention_mask=x[4],
token_type_ids=x[5],
position_ids=None,
head_mask=None,
labels=None)
return x[0]
class LayoutLMForSer(NLPBaseModel):
def __init__(self,
num_classes,
pretrained_model='layoutlm-base-uncased',
checkpoints=None,
**kwargs):
super(LayoutLMForSer, self).__init__(
LayoutLMModel,
LayoutLMForTokenClassification,
'ser',
pretrained_model,
checkpoints,
num_classes=num_classes)
def forward(self, x):
x = self.model(
input_ids=x[0],
bbox=x[2],
attention_mask=x[4],
token_type_ids=x[5],
position_ids=None,
output_hidden_states=False)
return x
class LayoutXLMForRe(NLPBaseModel):
def __init__(self,
pretrained_model='layoutxlm-base-uncased',
checkpoints=None,
**kwargs):
super(LayoutXLMForRe, self).__init__(
LayoutXLMModel, LayoutXLMForRelationExtraction, 're',
pretrained_model, checkpoints)
def forward(self, x):
x = self.model(
input_ids=x[0],
bbox=x[1],
labels=None,
image=x[2],
attention_mask=x[3],
token_type_ids=x[4],
position_ids=None,
head_mask=None,
entities=x[5],
relations=x[6])
return x
......@@ -42,7 +42,9 @@ def build_optimizer(config, epochs, step_each_epoch, parameters):
# step2 build regularization
if 'regularizer' in config and config['regularizer'] is not None:
reg_config = config.pop('regularizer')
reg_name = reg_config.pop('name') + 'Decay'
reg_name = reg_config.pop('name')
if not hasattr(regularizer, reg_name):
reg_name += 'Decay'
reg = getattr(regularizer, reg_name)(**reg_config)()
else:
reg = None
......
......@@ -158,3 +158,38 @@ class Adadelta(object):
name=self.name,
parameters=parameters)
return opt
class AdamW(object):
def __init__(self,
learning_rate=0.001,
beta1=0.9,
beta2=0.999,
epsilon=1e-08,
weight_decay=0.01,
grad_clip=None,
name=None,
lazy_mode=False,
**kwargs):
self.learning_rate = learning_rate
self.beta1 = beta1
self.beta2 = beta2
self.epsilon = epsilon
self.weight_decay = 0.01 if weight_decay is None else weight_decay
self.grad_clip = grad_clip
self.name = name
self.lazy_mode = lazy_mode
def __call__(self, parameters):
opt = optim.AdamW(
learning_rate=self.learning_rate,
beta1=self.beta1,
beta2=self.beta2,
epsilon=self.epsilon,
weight_decay=self.weight_decay,
grad_clip=self.grad_clip,
name=self.name,
lazy_mode=self.lazy_mode,
parameters=parameters)
return opt
......@@ -50,3 +50,18 @@ class L2Decay(object):
def __call__(self):
reg = paddle.regularizer.L2Decay(self.regularization_coeff)
return reg
class ConstDecay(object):
"""
Constant weight decay: __call__ returns the factor itself instead of a paddle regularizer object, so it can be passed directly as a weight_decay value.
Args:
factor(float): regularization coeff. Default: 0.0.
"""
def __init__(self, factor=0.0):
super(ConstDecay, self).__init__()
self.regularization_coeff = factor
def __call__(self):
return self.regularization_coeff
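Taken together with the `build_optimizer` change above (the `'Decay'` suffix is now only appended when the configured name is not already an attribute of the regularizer module), a config can reference `ConstDecay` directly; unlike `L2Decay`, it yields a plain float that can be fed straight into `AdamW`'s `weight_decay`. A self-contained sketch of that name resolution, using stand-in classes:

```python
import types

# Stand-ins mirroring ppocr/optimizer/regularizer.py, for illustration only.
class L2Decay:
    def __init__(self, factor=0.0):
        self.factor = factor
    def __call__(self):
        return "paddle.regularizer.L2Decay(%g)" % self.factor  # placeholder for the paddle object

class ConstDecay:
    def __init__(self, factor=0.0):
        self.factor = factor
    def __call__(self):
        return self.factor  # plain float, usable as AdamW's weight_decay

regularizer = types.SimpleNamespace(L2Decay=L2Decay, ConstDecay=ConstDecay)

def resolve(reg_config):
    reg_name = reg_config.pop('name')
    if not hasattr(regularizer, reg_name):
        reg_name += 'Decay'                        # e.g. 'L2' -> 'L2Decay'
    return getattr(regularizer, reg_name)(**reg_config)()

print(resolve({'name': 'L2', 'factor': 1e-5}))          # resolves to L2Decay
print(resolve({'name': 'ConstDecay', 'factor': 0.01}))  # 0.01
```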
......@@ -28,6 +28,8 @@ from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, Di
TableLabelDecode, NRTRLabelDecode, SARLabelDecode, SEEDLabelDecode
from .cls_postprocess import ClsPostProcess
from .pg_postprocess import PGPostProcess
from .vqa_token_ser_layoutlm_postprocess import VQASerTokenLayoutLMPostProcess
from .vqa_token_re_layoutlm_postprocess import VQAReTokenLayoutLMPostProcess
def build_post_process(config, global_config=None):
......@@ -36,7 +38,8 @@ def build_post_process(config, global_config=None):
'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode', 'PGPostProcess',
'DistillationCTCLabelDecode', 'TableLabelDecode',
'DistillationDBPostProcess', 'NRTRLabelDecode', 'SARLabelDecode',
'SEEDLabelDecode'
'SEEDLabelDecode', 'VQASerTokenLayoutLMPostProcess',
'VQAReTokenLayoutLMPostProcess'
]
if config['name'] == 'PSEPostProcess':
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
class VQAReTokenLayoutLMPostProcess(object):
""" Convert between text-label and text-index """
def __init__(self, **kwargs):
super(VQAReTokenLayoutLMPostProcess, self).__init__()
def __call__(self, preds, label=None, *args, **kwargs):
if label is not None:
return self._metric(preds, label)
else:
return self._infer(preds, *args, **kwargs)
def _metric(self, preds, label):
return preds['pred_relations'], label[6], label[5]
def _infer(self, preds, *args, **kwargs):
ser_results = kwargs['ser_results']
entity_idx_dict_batch = kwargs['entity_idx_dict_batch']
pred_relations = preds['pred_relations']
# convert the predicted relations back to OCR info
results = []
for pred_relation, ser_result, entity_idx_dict in zip(
pred_relations, ser_results, entity_idx_dict_batch):
result = []
used_tail_id = []
for relation in pred_relation:
if relation['tail_id'] in used_tail_id:
continue
used_tail_id.append(relation['tail_id'])
ocr_info_head = ser_result[entity_idx_dict[relation['head_id']]]
ocr_info_tail = ser_result[entity_idx_dict[relation['tail_id']]]
result.append((ocr_info_head, ocr_info_tail))
results.append(result)
return results
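In inference mode, `_infer` turns each predicted relation into a (head, tail) pair of SER/OCR entries: `head_id`/`tail_id` index into `entity_idx_dict`, whose values in turn index the SER results of the same image. A tiny made-up example of that indexing (field names follow the code above; the data is illustrative only):

```python
# Made-up SER results and relation predictions for one image.
ser_result = [
    {"text": "Name", "pred": "QUESTION"},
    {"text": "John", "pred": "ANSWER"},
]
entity_idx_dict = {0: 0, 1: 1}                    # entity id -> index into ser_result
pred_relations = [{"head_id": 0, "tail_id": 1}]

pairs = [(ser_result[entity_idx_dict[r["head_id"]]],
          ser_result[entity_idx_dict[r["tail_id"]]]) for r in pred_relations]
print(pairs)   # [({'text': 'Name', ...}, {'text': 'John', ...})]
```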
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle
from ppocr.utils.utility import load_vqa_bio_label_maps
class VQASerTokenLayoutLMPostProcess(object):
""" Convert between text-label and text-index """
def __init__(self, class_path, **kwargs):
super(VQASerTokenLayoutLMPostProcess, self).__init__()
label2id_map, self.id2label_map = load_vqa_bio_label_maps(class_path)
self.label2id_map_for_draw = dict()
for key in label2id_map:
if key.startswith("I-"):
self.label2id_map_for_draw[key] = label2id_map["B" + key[1:]]
else:
self.label2id_map_for_draw[key] = label2id_map[key]
self.id2label_map_for_show = dict()
for key in self.label2id_map_for_draw:
val = self.label2id_map_for_draw[key]
if key == "O":
self.id2label_map_for_show[val] = key
if key.startswith("B-") or key.startswith("I-"):
self.id2label_map_for_show[val] = key[2:]
else:
self.id2label_map_for_show[val] = key
def __call__(self, preds, batch=None, *args, **kwargs):
if isinstance(preds, paddle.Tensor):
preds = preds.numpy()
if batch is not None:
return self._metric(preds, batch[1])
else:
return self._infer(preds, **kwargs)
def _metric(self, preds, label):
pred_idxs = preds.argmax(axis=2)
decode_out_list = [[] for _ in range(pred_idxs.shape[0])]
label_decode_out_list = [[] for _ in range(pred_idxs.shape[0])]
for i in range(pred_idxs.shape[0]):
for j in range(pred_idxs.shape[1]):
if label[i, j] != -100:
label_decode_out_list[i].append(self.id2label_map[label[i,
j]])
decode_out_list[i].append(self.id2label_map[pred_idxs[i,
j]])
return decode_out_list, label_decode_out_list
def _infer(self, preds, attention_masks, segment_offset_ids, ocr_infos):
results = []
for pred, attention_mask, segment_offset_id, ocr_info in zip(
preds, attention_masks, segment_offset_ids, ocr_infos):
pred = np.argmax(pred, axis=1)
pred = [self.id2label_map[idx] for idx in pred]
for idx in range(len(segment_offset_id)):
if idx == 0:
start_id = 0
else:
start_id = segment_offset_id[idx - 1]
end_id = segment_offset_id[idx]
curr_pred = pred[start_id:end_id]
curr_pred = [self.label2id_map_for_draw[p] for p in curr_pred]
if len(curr_pred) <= 0:
pred_id = 0
else:
counts = np.bincount(curr_pred)
pred_id = np.argmax(counts)
ocr_info[idx]["pred_id"] = int(pred_id)
ocr_info[idx]["pred"] = self.id2label_map_for_show[int(pred_id)]
results.append(ocr_info)
return results
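`_infer` assigns one label per OCR segment by taking a majority vote over the token-level predictions inside that segment; `segment_offset_id` stores the (exclusive) end index of each segment's tokens. A standalone illustration of that voting step, with made-up label ids:

```python
import numpy as np

# Made-up token-level label ids for one image, already remapped so that
# B-/I- variants of the same class share a single id (label2id_map_for_draw).
pred = [1, 1, 2, 2, 2, 0, 0]
segment_offset_id = [2, 5, 7]   # tokens [0:2], [2:5], [5:7] belong to segments 0, 1, 2

start_id = 0
for end_id in segment_offset_id:
    curr_pred = pred[start_id:end_id]
    pred_id = int(np.argmax(np.bincount(curr_pred))) if curr_pred else 0
    print("segment tokens [{}:{}] -> label id {}".format(start_id, end_id, pred_id))
    start_id = end_id
```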
......@@ -44,7 +44,7 @@ def _mkdir_if_not_exist(path, logger):
raise OSError('Failed to mkdir {}'.format(path))
def load_model(config, model, optimizer=None):
def load_model(config, model, optimizer=None, model_type='det'):
"""
load model from checkpoint or pretrained_model
"""
......@@ -53,6 +53,33 @@ def load_model(config, model, optimizer=None):
checkpoints = global_config.get('checkpoints')
pretrained_model = global_config.get('pretrained_model')
best_model_dict = {}
if model_type == 'vqa':
checkpoints = config['Architecture']['Backbone']['checkpoints']
# load vqa method metric
if checkpoints:
if os.path.exists(os.path.join(checkpoints, 'metric.states')):
with open(os.path.join(checkpoints, 'metric.states'),
'rb') as f:
states_dict = pickle.load(f) if six.PY2 else pickle.load(
f, encoding='latin1')
best_model_dict = states_dict.get('best_model_dict', {})
if 'epoch' in states_dict:
best_model_dict['start_epoch'] = states_dict['epoch'] + 1
logger.info("resume from {}".format(checkpoints))
if optimizer is not None:
if checkpoints[-1] in ['/', '\\']:
checkpoints = checkpoints[:-1]
if os.path.exists(checkpoints + '.pdopt'):
optim_dict = paddle.load(checkpoints + '.pdopt')
optimizer.set_state_dict(optim_dict)
else:
logger.warning(
"{}.pdopt is not exists, params of optimizer is not loaded".
format(checkpoints))
return best_model_dict
if checkpoints:
if checkpoints.endswith('.pdparams'):
checkpoints = checkpoints.replace('.pdparams', '')
......@@ -127,6 +154,7 @@ def save_model(model,
optimizer,
model_path,
logger,
config,
is_best=False,
prefix='ppocr',
**kwargs):
......@@ -135,13 +163,20 @@ def save_model(model,
"""
_mkdir_if_not_exist(model_path, logger)
model_prefix = os.path.join(model_path, prefix)
paddle.save(model.state_dict(), model_prefix + '.pdparams')
paddle.save(optimizer.state_dict(), model_prefix + '.pdopt')
if config['Architecture']["model_type"] != 'vqa':
paddle.save(model.state_dict(), model_prefix + '.pdparams')
metric_prefix = model_prefix
else:
if config['Global']['distributed']:
model._layers.backbone.model.save_pretrained(model_prefix)
else:
model.backbone.model.save_pretrained(model_prefix)
metric_prefix = os.path.join(model_prefix, 'metric')
# save metric and config
with open(model_prefix + '.states', 'wb') as f:
pickle.dump(kwargs, f, protocol=2)
if is_best:
with open(metric_prefix + '.states', 'wb') as f:
pickle.dump(kwargs, f, protocol=2)
logger.info('save best model to {}'.format(model_prefix))
else:
logger.info("save model in {}".format(model_prefix))
......@@ -78,3 +78,21 @@ def check_and_read_gif(img_path):
imgvalue = frame[:, :, ::-1]
return imgvalue, True
return None, False
def load_vqa_bio_label_maps(label_map_path):
with open(label_map_path, "r", encoding='utf-8') as fin:
lines = fin.readlines()
lines = [line.strip() for line in lines]
if "O" not in lines:
lines.insert(0, "O")
labels = []
for line in lines:
if line == "O":
labels.append("O")
else:
labels.append("B-" + line)
labels.append("I-" + line)
label2id_map = {label: idx for idx, label in enumerate(labels)}
id2label_map = {idx: label for idx, label in enumerate(labels)}
return label2id_map, id2label_map
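For reference, a quick sketch of what this helper produces for a hypothetical class file containing the two lines `QUESTION` and `ANSWER`: `O` is prepended automatically and every class is expanded into `B-`/`I-` tags:

```python
# Equivalent of load_vqa_bio_label_maps for a hypothetical two-class file.
lines = ["O", "QUESTION", "ANSWER"]              # "O" is inserted if missing
labels = ["O"]
for line in lines[1:]:
    labels += ["B-" + line, "I-" + line]
label2id_map = {label: idx for idx, label in enumerate(labels)}
id2label_map = {idx: label for idx, label in enumerate(labels)}
# label2id_map == {'O': 0, 'B-QUESTION': 1, 'I-QUESTION': 2, 'B-ANSWER': 3, 'I-ANSWER': 4}
```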
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import numpy as np
from PIL import Image, ImageDraw, ImageFont
def draw_ser_results(image,
ocr_results,
font_path="doc/fonts/simfang.ttf",
font_size=18):
np.random.seed(2021)
color = (np.random.permutation(range(255)),
np.random.permutation(range(255)),
np.random.permutation(range(255)))
color_map = {
idx: (color[0][idx], color[1][idx], color[2][idx])
for idx in range(1, 255)
}
if isinstance(image, np.ndarray):
image = Image.fromarray(image)
elif isinstance(image, str) and os.path.isfile(image):
image = Image.open(image).convert('RGB')
img_new = image.copy()
draw = ImageDraw.Draw(img_new)
font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
for ocr_info in ocr_results:
if ocr_info["pred_id"] not in color_map:
continue
color = color_map[ocr_info["pred_id"]]
text = "{}: {}".format(ocr_info["pred"], ocr_info["text"])
draw_box_txt(ocr_info["bbox"], text, draw, font, font_size, color)
img_new = Image.blend(image, img_new, 0.5)
return np.array(img_new)
def draw_box_txt(bbox, text, draw, font, font_size, color):
# draw ocr results outline
bbox = ((bbox[0], bbox[1]), (bbox[2], bbox[3]))
draw.rectangle(bbox, fill=color)
# draw ocr results
start_y = max(0, bbox[0][1] - font_size)
tw = font.getsize(text)[0]
draw.rectangle(
[(bbox[0][0] + 1, start_y), (bbox[0][0] + tw + 1, start_y + font_size)],
fill=(0, 0, 255))
draw.text((bbox[0][0] + 1, start_y), text, fill=(255, 255, 255), font=font)
def draw_re_results(image,
result,
font_path="doc/fonts/simfang.ttf",
font_size=18):
np.random.seed(0)
if isinstance(image, np.ndarray):
image = Image.fromarray(image)
elif isinstance(image, str) and os.path.isfile(image):
image = Image.open(image).convert('RGB')
img_new = image.copy()
draw = ImageDraw.Draw(img_new)
font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
color_head = (0, 0, 255)
color_tail = (255, 0, 0)
color_line = (0, 255, 0)
for ocr_info_head, ocr_info_tail in result:
draw_box_txt(ocr_info_head["bbox"], ocr_info_head["text"], draw, font,
font_size, color_head)
draw_box_txt(ocr_info_tail["bbox"], ocr_info_tail["text"], draw, font,
font_size, color_tail)
center_head = (
(ocr_info_head['bbox'][0] + ocr_info_head['bbox'][2]) // 2,
(ocr_info_head['bbox'][1] + ocr_info_head['bbox'][3]) // 2)
center_tail = (
(ocr_info_tail['bbox'][0] + ocr_info_tail['bbox'][2]) // 2,
(ocr_info_tail['bbox'][1] + ocr_info_tail['bbox'][3]) // 2)
draw.line([center_head, center_tail], fill=color_line, width=5)
img_new = Image.blend(image, img_new, 0.5)
return np.array(img_new)
......@@ -24,8 +24,8 @@
|Model name|Description|Inference model size|Download|
| --- | --- | --- | --- |
|PP-Layout_v1.0_ser_pretrained|SER model trained on the XFUN Chinese dataset based on LayoutXLM|1.4G|[Inference model coming soon]() / [Trained model](https://paddleocr.bj.bcebos.com/pplayout/PP-Layout_v1.0_ser_pretrained.tar) |
|PP-Layout_v1.0_re_pretrained|RE model trained on the XFUN Chinese dataset based on LayoutXLM|1.4G|[Inference model coming soon]() / [Trained model](https://paddleocr.bj.bcebos.com/pplayout/PP-Layout_v1.0_re_pretrained.tar) |
|PP-Layout_v1.0_ser_pretrained|SER model trained on the XFUN Chinese dataset based on LayoutXLM|1.4G|[Inference model coming soon]() / [Trained model](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh.tar) |
|PP-Layout_v1.0_re_pretrained|RE model trained on the XFUN Chinese dataset based on LayoutXLM|1.4G|[Inference model coming soon]() / [Trained model](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh.tar) |
## 3. KIE models
......
......@@ -20,11 +20,11 @@ The DOC-VQA algorithms in PP-Structure are built on the PaddleNLP natural language processing library
We evaluated the algorithms on the Chinese subset of [XFUN](https://github.com/doc-analysis/XFUND); performance is as follows
| Model | Task | f1 | Download link |
| Model | Task | hmean | Download link |
|:---:|:---:|:---:| :---:|
| LayoutXLM | RE | 0.7113 | [Link](https://paddleocr.bj.bcebos.com/pplayout/PP-Layout_v1.0_re_pretrained.tar) |
| LayoutXLM | SER | 0.9056 | [Link](https://paddleocr.bj.bcebos.com/pplayout/PP-Layout_v1.0_ser_pretrained.tar) |
| LayoutLM | SER | 0.78 | [Link](https://paddleocr.bj.bcebos.com/pplayout/LayoutLM_ser_pretrained.tar) |
| LayoutXLM | RE | 0.7483 | [Link](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh.tar) |
| LayoutXLM | SER | 0.9038 | [Link](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh.tar) |
| LayoutLM | SER | 0.7731 | [Link](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLM_xfun_zh.tar) |
......@@ -65,10 +65,10 @@ The DOC-VQA algorithms in PP-Structure are built on the PaddleNLP natural language processing library
pip3 install --upgrade pip
# install for GPU
python3 -m pip install paddlepaddle-gpu==2.2 -i https://mirror.baidu.com/pypi/simple
python3 -m pip install "paddlepaddle-gpu>=2.2" -i https://mirror.baidu.com/pypi/simple
# install for CPU
python3 -m pip install paddlepaddle==2.2 -i https://mirror.baidu.com/pypi/simple
python3 -m pip install "paddlepaddle>=2.2" -i https://mirror.baidu.com/pypi/simple
```
For other requirements, please follow the instructions in the [installation documentation](https://www.paddlepaddle.org.cn/install/quick).
......@@ -79,7 +79,7 @@ python3 -m pip install paddlepaddle==2.2 -i https://mirror.baidu.com/pypi/simple
- **(1) Quick install of the PaddleOCR whl package via pip (inference only)**
```bash
pip install paddleocr
python3 -m pip install paddleocr
```
- **(2) Download the VQA source code (inference + training)**
......@@ -93,18 +93,10 @@ git clone https://gitee.com/paddlepaddle/PaddleOCR
# Note: the code hosted on Gitee may not sync with this GitHub project in real time; there is a delay of 3-5 days, so please prefer the recommended method.
```
- **(3) Install PaddleNLP**
- **(3) Install the VQA `requirements`**
```bash
pip3 install "paddlenlp>=2.2.1"
```
- **(4)安装VQA的`requirements`**
```bash
cd ppstructure/vqa
pip install -r requirements.txt
python3 -m pip install -r ppstructure/vqa/requirements.txt
```
## 4. Usage
......@@ -112,6 +104,10 @@ pip install -r requirements.txt
### 4.1 Data and pretrained model preparation
If you just want to try the prediction process, you can download the pretrained models we provide, skip training, and run prediction directly.
* Download the processed dataset
Download link for the processed XFUN Chinese dataset: [https://paddleocr.bj.bcebos.com/dataset/XFUND.tar](https://paddleocr.bj.bcebos.com/dataset/XFUND.tar)
......@@ -121,98 +117,62 @@ pip install -r requirements.txt
wget https://paddleocr.bj.bcebos.com/dataset/XFUND.tar
```
To convert XFUN datasets in other languages, refer to the [XFUN data conversion script](helper/trans_xfun_data.py)
* Convert the dataset
If you just want to try the prediction process, you can download the pretrained models we provide, skip training, and run prediction directly.
To train on other XFUN datasets, use the following command to convert the dataset
```bash
python3 ppstructure/vqa/helper/trans_xfun_data.py --ori_gt_path=path/to/json_path --output_path=path/to/save_path
```
### 4.2 SER task
* Start training
Before starting training, the following four fields need to be modified
1. `Train.dataset.data_dir`: directory where the training images are stored
2. `Train.dataset.label_file_list`: path to the training label file
3. `Eval.dataset.data_dir`: directory where the validation images are stored
4. `Eval.dataset.label_file_list`: path to the validation label file
* Start training
```shell
python3.7 train_ser.py \
--model_name_or_path "layoutxlm-base-uncased" \
--ser_model_type "LayoutXLM" \
--train_data_dir "XFUND/zh_train/image" \
--train_label_path "XFUND/zh_train/xfun_normalize_train.json" \
--eval_data_dir "XFUND/zh_val/image" \
--eval_label_path "XFUND/zh_val/xfun_normalize_val.json" \
--num_train_epochs 200 \
--eval_steps 10 \
--output_dir "./output/ser/" \
--learning_rate 5e-5 \
--warmup_steps 50 \
--evaluate_during_training \
--seed 2048
CUDA_VISIBLE_DEVICES=0 python3 tools/train.py -c configs/vqa/ser/layoutxlm.yml
```
Finally, metrics such as `precision`, `recall`, and `f1` will be printed, and the model and training log will be saved in the `./output/ser/` folder.
Finally, metrics such as `precision`, `recall`, and `hmean` will be printed.
The training log, the best model, and the model of the latest epoch will be saved in the `./output/ser_layoutxlm/` folder.
* Resume training
To resume training, assign the folder path of the previously trained model to the `Architecture.Backbone.checkpoints` field.
```shell
python3.7 train_ser.py \
--model_name_or_path "model_path" \
--ser_model_type "LayoutXLM" \
--train_data_dir "XFUND/zh_train/image" \
--train_label_path "XFUND/zh_train/xfun_normalize_train.json" \
--eval_data_dir "XFUND/zh_val/image" \
--eval_label_path "XFUND/zh_val/xfun_normalize_val.json" \
--num_train_epochs 200 \
--eval_steps 10 \
--output_dir "./output/ser/" \
--learning_rate 5e-5 \
--warmup_steps 50 \
--evaluate_during_training \
--num_workers 8 \
--seed 2048 \
--resume
CUDA_VISIBLE_DEVICES=0 python3 tools/train.py -c configs/vqa/ser/layoutxlm.yml -o Architecture.Backbone.checkpoints=path/to/model_dir
```
* Evaluation
```shell
export CUDA_VISIBLE_DEVICES=0
python3 eval_ser.py \
--model_name_or_path "PP-Layout_v1.0_ser_pretrained/" \
--ser_model_type "LayoutXLM" \
--eval_data_dir "XFUND/zh_val/image" \
--eval_label_path "XFUND/zh_val/xfun_normalize_val.json" \
--per_gpu_eval_batch_size 8 \
--num_workers 8 \
--output_dir "output/ser/" \
--seed 2048
```
Finally, metrics such as `precision`, `recall`, and `f1` will be printed
* Run prediction using the OCR results provided with the evaluation set
For evaluation, assign the folder path of the model to be evaluated to the `Architecture.Backbone.checkpoints` field.
```shell
export CUDA_VISIBLE_DEVICES=0
python3.7 infer_ser.py \
--model_name_or_path "PP-Layout_v1.0_ser_pretrained/" \
--ser_model_type "LayoutXLM" \
--output_dir "output/ser/" \
--infer_imgs "XFUND/zh_val/image/" \
--ocr_json_path "XFUND/zh_val/xfun_normalize_val.json"
CUDA_VISIBLE_DEVICES=0 python3 tools/eval.py -c configs/vqa/ser/layoutxlm.yml -o Architecture.Backbone.checkpoints=path/to/model_dir
```
Finally, metrics such as `precision`, `recall`, and `hmean` will be printed
Finally, the visualized prediction images and a prediction text file named `infer_results.txt` will be saved in the `output_res` directory
* Cascaded prediction with `OCR engine + SER`
* Cascaded results with `OCR engine + SER`
Use the following command to run cascaded `OCR engine + SER` prediction
```shell
export CUDA_VISIBLE_DEVICES=0
python3.7 infer_ser_e2e.py \
--model_name_or_path "PP-Layout_v1.0_ser_pretrained/" \
--ser_model_type "LayoutXLM" \
--max_seq_length 512 \
--output_dir "output/ser_e2e/" \
--infer_imgs "images/input/zh_val_0.jpg"
CUDA_VISIBLE_DEVICES=0 python3 tools/infer_vqa_token_ser.py -c configs/vqa/ser/layoutxlm.yml -o Architecture.Backbone.checkpoints=PP-Layout_v1.0_ser_pretrained/ Global.infer_img=ppstructure/vqa/images/input/zh_val_42.jpg
```
Finally, the visualized prediction images and a prediction text file named `infer_results.txt` will be saved in the directory configured by the `config.Global.save_res_path` field
* End-to-end evaluation of the `OCR engine + SER` prediction system
First run prediction on the dataset with the `tools/infer_vqa_token_ser.py` script, then evaluate with the following command.
```shell
export CUDA_VISIBLE_DEVICES=0
python3.7 helper/eval_with_label_end2end.py --gt_json_path XFUND/zh_val/xfun_normalize_val.json --pred_json_path output_res/infer_results.txt
......@@ -223,102 +183,48 @@ python3.7 helper/eval_with_label_end2end.py --gt_json_path XFUND/zh_val/xfun_nor
* Start training
```shell
export CUDA_VISIBLE_DEVICES=0
python3 train_re.py \
--model_name_or_path "layoutxlm-base-uncased" \
--train_data_dir "XFUND/zh_train/image" \
--train_label_path "XFUND/zh_train/xfun_normalize_train.json" \
--eval_data_dir "XFUND/zh_val/image" \
--eval_label_path "XFUND/zh_val/xfun_normalize_val.json" \
--label_map_path "labels/labels_ser.txt" \
--num_train_epochs 200 \
--eval_steps 10 \
--output_dir "output/re/" \
--learning_rate 5e-5 \
--warmup_steps 50 \
--per_gpu_train_batch_size 8 \
--per_gpu_eval_batch_size 8 \
--num_workers 8 \
--evaluate_during_training \
--seed 2048
```
Before starting training, the following four fields need to be modified
* Resume training
1. `Train.dataset.data_dir`: directory where the training images are stored
2. `Train.dataset.label_file_list`: path to the training label file
3. `Eval.dataset.data_dir`: directory where the validation images are stored
4. `Eval.dataset.label_file_list`: path to the validation label file
```shell
export CUDA_VISIBLE_DEVICES=0
python3 train_re.py \
--model_name_or_path "model_path" \
--train_data_dir "XFUND/zh_train/image" \
--train_label_path "XFUND/zh_train/xfun_normalize_train.json" \
--eval_data_dir "XFUND/zh_val/image" \
--eval_label_path "XFUND/zh_val/xfun_normalize_val.json" \
--label_map_path "labels/labels_ser.txt" \
--num_train_epochs 2 \
--eval_steps 10 \
--output_dir "output/re/" \
--learning_rate 5e-5 \
--warmup_steps 50 \
--per_gpu_train_batch_size 8 \
--per_gpu_eval_batch_size 8 \
--num_workers 8 \
--evaluate_during_training \
--seed 2048 \
--resume
CUDA_VISIBLE_DEVICES=0 python3 tools/train.py -c configs/vqa/re/layoutxlm.yml
```
Finally, metrics such as `precision`, `recall`, and `f1` will be printed, and the model and training log will be saved in the `./output/re/` folder.
Finally, metrics such as `precision`, `recall`, and `hmean` will be printed.
The training log, the best model, and the model of the latest epoch will be saved in the `./output/re_layoutxlm/` folder.
* Resume training
To resume training, assign the folder path of the previously trained model to the `Architecture.Backbone.checkpoints` field.
* Evaluation
```shell
export CUDA_VISIBLE_DEVICES=0
python3 eval_re.py \
--model_name_or_path "PP-Layout_v1.0_re_pretrained/" \
--max_seq_length 512 \
--eval_data_dir "XFUND/zh_val/image" \
--eval_label_path "XFUND/zh_val/xfun_normalize_val.json" \
--label_map_path "labels/labels_ser.txt" \
--output_dir "output/re/" \
--per_gpu_eval_batch_size 8 \
--num_workers 8 \
--seed 2048
CUDA_VISIBLE_DEVICES=0 python3 tools/train.py -c configs/vqa/re/layoutxlm.yml -o Architecture.Backbone.checkpoints=path/to/model_dir
```
Finally, metrics such as `precision`, `recall`, and `f1` will be printed
* Evaluation
* Run prediction using the OCR results provided with the evaluation set
For evaluation, assign the folder path of the model to be evaluated to the `Architecture.Backbone.checkpoints` field.
```shell
export CUDA_VISIBLE_DEVICES=0
python3 infer_re.py \
--model_name_or_path "PP-Layout_v1.0_re_pretrained/" \
--max_seq_length 512 \
--eval_data_dir "XFUND/zh_val/image" \
--eval_label_path "XFUND/zh_val/xfun_normalize_val.json" \
--label_map_path "labels/labels_ser.txt" \
--output_dir "output/re/" \
--per_gpu_eval_batch_size 1 \
--seed 2048
CUDA_VISIBLE_DEVICES=0 python3 tools/eval.py -c configs/vqa/re/layoutxlm.yml -o Architecture.Backbone.checkpoints=path/to/model_dir
```
Finally, metrics such as `precision`, `recall`, and `hmean` will be printed
Finally, the visualized prediction images and a prediction text file named `infer_results.txt` will be saved in the `output_res` directory
* Cascaded results with `OCR engine + SER + RE`
* Cascaded prediction with `OCR engine + SER + RE`
Use the following command to run cascaded `OCR engine + SER + RE` prediction
```shell
export CUDA_VISIBLE_DEVICES=0
python3.7 infer_ser_re_e2e.py \
--model_name_or_path "PP-Layout_v1.0_ser_pretrained/" \
--re_model_name_or_path "PP-Layout_v1.0_re_pretrained/" \
--ser_model_type "LayoutXLM" \
--max_seq_length 512 \
--output_dir "output/ser_re_e2e/" \
--infer_imgs "images/input/zh_val_21.jpg"
python3 tools/infer_vqa_token_ser_re.py -c configs/vqa/re/layoutxlm.yml -o Architecture.Backbone.checkpoints=PP-Layout_v1.0_re_pretrained/ Global.infer_img=ppstructure/vqa/images/input/zh_val_21.jpg -c_ser configs/vqa/ser/layoutxlm.yml -o_ser Architecture.Backbone.checkpoints=PP-Layout_v1.0_ser_pretrained/
```
Finally, the visualized prediction images and a prediction text file named `infer_results.txt` will be saved in the directory configured by the `config.Global.save_res_path` field
## References
- LayoutXLM: Multimodal Pre-training for Multilingual Visually-rich Document Understanding, https://arxiv.org/pdf/2104.08836.pdf
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
import paddle
from paddlenlp.transformers import LayoutXLMTokenizer, LayoutXLMModel, LayoutXLMForRelationExtraction
from xfun import XFUNDataset
from vqa_utils import parse_args, get_bio_label_maps, print_arguments
from data_collator import DataCollator
from metric import re_score
from ppocr.utils.logging import get_logger
def cal_metric(re_preds, re_labels, entities):
gt_relations = []
for b in range(len(re_labels)):
rel_sent = []
for head, tail in zip(re_labels[b]["head"], re_labels[b]["tail"]):
rel = {}
rel["head_id"] = head
rel["head"] = (entities[b]["start"][rel["head_id"]],
entities[b]["end"][rel["head_id"]])
rel["head_type"] = entities[b]["label"][rel["head_id"]]
rel["tail_id"] = tail
rel["tail"] = (entities[b]["start"][rel["tail_id"]],
entities[b]["end"][rel["tail_id"]])
rel["tail_type"] = entities[b]["label"][rel["tail_id"]]
rel["type"] = 1
rel_sent.append(rel)
gt_relations.append(rel_sent)
re_metrics = re_score(re_preds, gt_relations, mode="boundaries")
return re_metrics
def evaluate(model, eval_dataloader, logger, prefix=""):
# Eval!
logger.info("***** Running evaluation {} *****".format(prefix))
logger.info(" Num examples = {}".format(len(eval_dataloader.dataset)))
re_preds = []
re_labels = []
entities = []
eval_loss = 0.0
model.eval()
for idx, batch in enumerate(eval_dataloader):
with paddle.no_grad():
outputs = model(**batch)
loss = outputs['loss'].mean().item()
if paddle.distributed.get_rank() == 0:
logger.info("[Eval] process: {}/{}, loss: {:.5f}".format(
idx, len(eval_dataloader), loss))
eval_loss += loss
re_preds.extend(outputs['pred_relations'])
re_labels.extend(batch['relations'])
entities.extend(batch['entities'])
re_metrics = cal_metric(re_preds, re_labels, entities)
re_metrics = {
"precision": re_metrics["ALL"]["p"],
"recall": re_metrics["ALL"]["r"],
"f1": re_metrics["ALL"]["f1"],
}
model.train()
return re_metrics
def eval(args):
logger = get_logger()
label2id_map, id2label_map = get_bio_label_maps(args.label_map_path)
pad_token_label_id = paddle.nn.CrossEntropyLoss().ignore_index
tokenizer = LayoutXLMTokenizer.from_pretrained(args.model_name_or_path)
model = LayoutXLMForRelationExtraction.from_pretrained(
args.model_name_or_path)
eval_dataset = XFUNDataset(
tokenizer,
data_dir=args.eval_data_dir,
label_path=args.eval_label_path,
label2id_map=label2id_map,
img_size=(224, 224),
max_seq_len=args.max_seq_length,
pad_token_label_id=pad_token_label_id,
contains_re=True,
add_special_ids=False,
return_attention_mask=True,
load_mode='all')
eval_dataloader = paddle.io.DataLoader(
eval_dataset,
batch_size=args.per_gpu_eval_batch_size,
num_workers=args.num_workers,
shuffle=False,
collate_fn=DataCollator())
results = evaluate(model, eval_dataloader, logger)
logger.info("eval results: {}".format(results))
if __name__ == "__main__":
args = parse_args()
eval(args)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
import random
import time
import copy
import logging
import argparse
import paddle
import numpy as np
from seqeval.metrics import classification_report, f1_score, precision_score, recall_score
from paddlenlp.transformers import LayoutXLMModel, LayoutXLMTokenizer, LayoutXLMForTokenClassification
from paddlenlp.transformers import LayoutLMModel, LayoutLMTokenizer, LayoutLMForTokenClassification
from xfun import XFUNDataset
from losses import SERLoss
from vqa_utils import parse_args, get_bio_label_maps, print_arguments
from ppocr.utils.logging import get_logger
MODELS = {
'LayoutXLM':
(LayoutXLMTokenizer, LayoutXLMModel, LayoutXLMForTokenClassification),
'LayoutLM':
(LayoutLMTokenizer, LayoutLMModel, LayoutLMForTokenClassification)
}
def eval(args):
logger = get_logger()
print_arguments(args, logger)
label2id_map, id2label_map = get_bio_label_maps(args.label_map_path)
pad_token_label_id = paddle.nn.CrossEntropyLoss().ignore_index
tokenizer_class, base_model_class, model_class = MODELS[args.ser_model_type]
tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
model = model_class.from_pretrained(args.model_name_or_path)
eval_dataset = XFUNDataset(
tokenizer,
data_dir=args.eval_data_dir,
label_path=args.eval_label_path,
label2id_map=label2id_map,
img_size=(224, 224),
pad_token_label_id=pad_token_label_id,
contains_re=False,
add_special_ids=False,
return_attention_mask=True,
load_mode='all')
eval_dataloader = paddle.io.DataLoader(
eval_dataset,
batch_size=args.per_gpu_eval_batch_size,
num_workers=args.num_workers,
use_shared_memory=True,
collate_fn=None, )
loss_class = SERLoss(len(label2id_map))
results, _ = evaluate(args, model, tokenizer, loss_class, eval_dataloader,
label2id_map, id2label_map, pad_token_label_id,
logger)
logger.info(results)
def evaluate(args,
model,
tokenizer,
loss_class,
eval_dataloader,
label2id_map,
id2label_map,
pad_token_label_id,
logger,
prefix=""):
eval_loss = 0.0
nb_eval_steps = 0
preds = None
out_label_ids = None
model.eval()
for idx, batch in enumerate(eval_dataloader):
with paddle.no_grad():
if args.ser_model_type == 'LayoutLM':
if 'image' in batch:
batch.pop('image')
labels = batch.pop('labels')
outputs = model(**batch)
if args.ser_model_type == 'LayoutXLM':
outputs = outputs[0]
loss = loss_class(labels, outputs, batch['attention_mask'])
loss = loss.mean()
if paddle.distributed.get_rank() == 0:
logger.info("[Eval]process: {}/{}, loss: {:.5f}".format(
idx, len(eval_dataloader), loss.numpy()[0]))
eval_loss += loss.item()
nb_eval_steps += 1
if preds is None:
preds = outputs.numpy()
out_label_ids = labels.numpy()
else:
preds = np.append(preds, outputs.numpy(), axis=0)
out_label_ids = np.append(out_label_ids, labels.numpy(), axis=0)
eval_loss = eval_loss / nb_eval_steps
preds = np.argmax(preds, axis=2)
# label_map = {i: label.upper() for i, label in enumerate(labels)}
out_label_list = [[] for _ in range(out_label_ids.shape[0])]
preds_list = [[] for _ in range(out_label_ids.shape[0])]
for i in range(out_label_ids.shape[0]):
for j in range(out_label_ids.shape[1]):
if out_label_ids[i, j] != pad_token_label_id:
out_label_list[i].append(id2label_map[out_label_ids[i][j]])
preds_list[i].append(id2label_map[preds[i][j]])
results = {
"loss": eval_loss,
"precision": precision_score(out_label_list, preds_list),
"recall": recall_score(out_label_list, preds_list),
"f1": f1_score(out_label_list, preds_list),
}
with open(
os.path.join(args.output_dir, "test_gt.txt"), "w",
encoding='utf-8') as fout:
for lbl in out_label_list:
for l in lbl:
fout.write(l + "\t")
fout.write("\n")
with open(
os.path.join(args.output_dir, "test_pred.txt"), "w",
encoding='utf-8') as fout:
for lbl in preds_list:
for l in lbl:
fout.write(l + "\t")
fout.write("\n")
report = classification_report(out_label_list, preds_list)
logger.info("\n" + report)
logger.info("***** Eval results %s *****", prefix)
for key in sorted(results.keys()):
logger.info(" %s = %s", key, str(results[key]))
model.train()
return results, preds_list
if __name__ == "__main__":
args = parse_args()
eval(args)
......@@ -49,4 +49,16 @@ def transfer_xfun_data(json_path=None, output_file=None):
print("===ok====")
transfer_xfun_data("./xfun/zh.val.json", "./xfun_normalize_val.json")
def parser_args():
import argparse
parser = argparse.ArgumentParser(description="args for paddleserving")
parser.add_argument(
"--ori_gt_path", type=str, required=True, help='origin xfun gt path')
parser.add_argument(
"--output_path", type=str, required=True, help='path to save')
args = parser.parse_args()
return args
args = parser_args()
transfer_xfun_data(args.ori_gt_path, args.output_path)
export CUDA_VISIBLE_DEVICES=6
# python3.7 infer_ser_e2e.py \
# --model_name_or_path "output/ser_distributed/best_model" \
# --max_seq_length 512 \
# --output_dir "output_res_e2e/" \
# --infer_imgs "/ssd1/zhoujun20/VQA/data/XFUN_v1.0_data/zh.val/zh_val_0.jpg"
# python3.7 infer_ser_re_e2e.py \
# --model_name_or_path "output/ser_distributed/best_model" \
# --re_model_name_or_path "output/re_test/best_model" \
# --max_seq_length 512 \
# --output_dir "output_ser_re_e2e_train/" \
# --infer_imgs "images/input/zh_val_21.jpg"
# python3.7 infer_ser.py \
# --model_name_or_path "output/ser_LayoutLM/best_model" \
# --ser_model_type "LayoutLM" \
# --output_dir "ser_LayoutLM/" \
# --infer_imgs "images/input/zh_val_21.jpg" \
# --ocr_json_path "/ssd1/zhoujun20/VQA/data/XFUN_v1.0_data/xfun_normalize_val.json"
python3.7 infer_ser.py \
--model_name_or_path "output/ser_new/best_model" \
--ser_model_type "LayoutXLM" \
--output_dir "ser_new/" \
--infer_imgs "images/input/zh_val_21.jpg" \
--ocr_json_path "/ssd1/zhoujun20/VQA/data/XFUN_v1.0_data/xfun_normalize_val.json"
# python3.7 infer_ser_e2e.py \
# --model_name_or_path "output/ser_new/best_model" \
# --ser_model_type "LayoutXLM" \
# --max_seq_length 512 \
# --output_dir "output/ser_new/" \
# --infer_imgs "images/input/zh_val_0.jpg"
# python3.7 infer_ser_e2e.py \
# --model_name_or_path "output/ser_LayoutLM/best_model" \
# --ser_model_type "LayoutLM" \
# --max_seq_length 512 \
# --output_dir "output/ser_LayoutLM/" \
# --infer_imgs "images/input/zh_val_0.jpg"
# python3 infer_re.py \
# --model_name_or_path "/ssd1/zhoujun20/VQA/PaddleOCR/ppstructure/vqa/output/re_test/best_model/" \
# --max_seq_length 512 \
# --eval_data_dir "/ssd1/zhoujun20/VQA/data/XFUN_v1.0_data/zh.val" \
# --eval_label_path "/ssd1/zhoujun20/VQA/data/XFUN_v1.0_data/xfun_normalize_val.json" \
# --label_map_path 'labels/labels_ser.txt' \
# --output_dir "output_res" \
# --per_gpu_eval_batch_size 1 \
# --seed 2048
# python3.7 infer_ser_re_e2e.py \
# --model_name_or_path "output/ser_LayoutLM/best_model" \
# --ser_model_type "LayoutLM" \
# --re_model_name_or_path "output/re_new/best_model" \
# --max_seq_length 512 \
# --output_dir "output_ser_re_e2e/" \
# --infer_imgs "images/input/zh_val_21.jpg"
\ No newline at end of file
import os
import sys
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
import random
import cv2
import matplotlib.pyplot as plt
import numpy as np
import paddle
from paddlenlp.transformers import LayoutXLMTokenizer, LayoutXLMModel, LayoutXLMForRelationExtraction
from xfun import XFUNDataset
from vqa_utils import parse_args, get_bio_label_maps, draw_re_results
from data_collator import DataCollator
from ppocr.utils.logging import get_logger
def infer(args):
os.makedirs(args.output_dir, exist_ok=True)
logger = get_logger()
label2id_map, id2label_map = get_bio_label_maps(args.label_map_path)
pad_token_label_id = paddle.nn.CrossEntropyLoss().ignore_index
tokenizer = LayoutXLMTokenizer.from_pretrained(args.model_name_or_path)
model = LayoutXLMForRelationExtraction.from_pretrained(
args.model_name_or_path)
eval_dataset = XFUNDataset(
tokenizer,
data_dir=args.eval_data_dir,
label_path=args.eval_label_path,
label2id_map=label2id_map,
img_size=(224, 224),
max_seq_len=args.max_seq_length,
pad_token_label_id=pad_token_label_id,
contains_re=True,
add_special_ids=False,
return_attention_mask=True,
load_mode='all')
eval_dataloader = paddle.io.DataLoader(
eval_dataset,
batch_size=args.per_gpu_eval_batch_size,
num_workers=8,
shuffle=False,
collate_fn=DataCollator())
# load the ground-truth OCR data
ocr_info_list = load_ocr(args.eval_data_dir, args.eval_label_path)
for idx, batch in enumerate(eval_dataloader):
ocr_info = ocr_info_list[idx]
image_path = ocr_info['image_path']
ocr_info = ocr_info['ocr_info']
save_img_path = os.path.join(
args.output_dir,
os.path.splitext(os.path.basename(image_path))[0] + "_re.jpg")
logger.info("[Infer] process: {}/{}, save result to {}".format(
idx, len(eval_dataloader), save_img_path))
with paddle.no_grad():
outputs = model(**batch)
pred_relations = outputs['pred_relations']
# decode the tokens referenced by each entity and filter out unneeded ocr_info
ocr_info = filter_bg_by_txt(ocr_info, batch, tokenizer)
# convert the predicted relations back to OCR info
result = []
used_tail_id = []
for relations in pred_relations:
for relation in relations:
if relation['tail_id'] in used_tail_id:
continue
if relation['head_id'] not in ocr_info or relation[
'tail_id'] not in ocr_info:
continue
used_tail_id.append(relation['tail_id'])
ocr_info_head = ocr_info[relation['head_id']]
ocr_info_tail = ocr_info[relation['tail_id']]
result.append((ocr_info_head, ocr_info_tail))
img = cv2.imread(image_path)
img_show = draw_re_results(img, result)
cv2.imwrite(save_img_path, img_show)
def load_ocr(img_folder, json_path):
import json
d = []
with open(json_path, "r", encoding='utf-8') as fin:
lines = fin.readlines()
for line in lines:
image_name, info_str = line.split("\t")
info_dict = json.loads(info_str)
info_dict['image_path'] = os.path.join(img_folder, image_name)
d.append(info_dict)
return d
def filter_bg_by_txt(ocr_info, batch, tokenizer):
entities = batch['entities'][0]
input_ids = batch['input_ids'][0]
new_info_dict = {}
for i in range(len(entities['start'])):
entitie_head = entities['start'][i]
entitie_tail = entities['end'][i]
word_input_ids = input_ids[entitie_head:entitie_tail].numpy().tolist()
txt = tokenizer.convert_ids_to_tokens(word_input_ids)
txt = tokenizer.convert_tokens_to_string(txt)
for i, info in enumerate(ocr_info):
if info['text'] == txt:
new_info_dict[i] = info
return new_info_dict
def post_process(pred_relations, ocr_info, img):
result = []
for relations in pred_relations:
for relation in relations:
ocr_info_head = ocr_info[relation['head_id']]
ocr_info_tail = ocr_info[relation['tail_id']]
result.append((ocr_info_head, ocr_info_tail))
return result
def draw_re(result, image_path, output_folder):
img = cv2.imread(image_path)
from matplotlib import pyplot as plt
for ocr_info_head, ocr_info_tail in result:
cv2.rectangle(
img,
tuple(ocr_info_head['bbox'][:2]),
tuple(ocr_info_head['bbox'][2:]), (255, 0, 0),
thickness=2)
cv2.rectangle(
img,
tuple(ocr_info_tail['bbox'][:2]),
tuple(ocr_info_tail['bbox'][2:]), (0, 0, 255),
thickness=2)
center_p1 = [(ocr_info_head['bbox'][0] + ocr_info_head['bbox'][2]) // 2,
(ocr_info_head['bbox'][1] + ocr_info_head['bbox'][3]) // 2]
center_p2 = [(ocr_info_tail['bbox'][0] + ocr_info_tail['bbox'][2]) // 2,
(ocr_info_tail['bbox'][1] + ocr_info_tail['bbox'][3]) // 2]
cv2.line(
img, tuple(center_p1), tuple(center_p2), (0, 255, 0), thickness=2)
plt.imshow(img)
plt.savefig(
os.path.join(output_folder, os.path.basename(image_path)), dpi=600)
# plt.show()
if __name__ == "__main__":
args = parse_args()
infer(args)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
import json
import cv2
import numpy as np
from copy import deepcopy
import paddle
# relative reference
from vqa_utils import parse_args, get_image_file_list, draw_ser_results, get_bio_label_maps
from paddlenlp.transformers import LayoutXLMModel, LayoutXLMTokenizer, LayoutXLMForTokenClassification
from paddlenlp.transformers import LayoutLMModel, LayoutLMTokenizer, LayoutLMForTokenClassification
MODELS = {
'LayoutXLM':
(LayoutXLMTokenizer, LayoutXLMModel, LayoutXLMForTokenClassification),
'LayoutLM':
(LayoutLMTokenizer, LayoutLMModel, LayoutLMForTokenClassification)
}
def pad_sentences(tokenizer,
encoded_inputs,
max_seq_len=512,
pad_to_max_seq_len=True,
return_attention_mask=True,
return_token_type_ids=True,
return_overflowing_tokens=False,
return_special_tokens_mask=False):
# Padding with larger size, reshape is carried out
max_seq_len = (
len(encoded_inputs["input_ids"]) // max_seq_len + 1) * max_seq_len
needs_to_be_padded = pad_to_max_seq_len and \
max_seq_len and len(encoded_inputs["input_ids"]) < max_seq_len
if needs_to_be_padded:
difference = max_seq_len - len(encoded_inputs["input_ids"])
if tokenizer.padding_side == 'right':
if return_attention_mask:
encoded_inputs["attention_mask"] = [1] * len(encoded_inputs[
"input_ids"]) + [0] * difference
if return_token_type_ids:
encoded_inputs["token_type_ids"] = (
encoded_inputs["token_type_ids"] +
[tokenizer.pad_token_type_id] * difference)
if return_special_tokens_mask:
encoded_inputs["special_tokens_mask"] = encoded_inputs[
"special_tokens_mask"] + [1] * difference
encoded_inputs["input_ids"] = encoded_inputs[
"input_ids"] + [tokenizer.pad_token_id] * difference
encoded_inputs["bbox"] = encoded_inputs["bbox"] + [[0, 0, 0, 0]
] * difference
else:
assert False, "padding_side of tokenizer just supports [\"right\"] but got {}".format(
tokenizer.padding_side)
else:
if return_attention_mask:
encoded_inputs["attention_mask"] = [1] * len(encoded_inputs[
"input_ids"])
return encoded_inputs
def split_page(encoded_inputs, max_seq_len=512):
"""
truncate is often used in training process
"""
for key in encoded_inputs:
encoded_inputs[key] = paddle.to_tensor(encoded_inputs[key])
if encoded_inputs[key].ndim <= 1: # for input_ids, att_mask and so on
encoded_inputs[key] = encoded_inputs[key].reshape([-1, max_seq_len])
else: # for bbox
encoded_inputs[key] = encoded_inputs[key].reshape(
[-1, max_seq_len, 4])
return encoded_inputs
def preprocess(
tokenizer,
ori_img,
ocr_info,
img_size=(224, 224),
pad_token_label_id=-100,
max_seq_len=512,
add_special_ids=False,
return_attention_mask=True, ):
ocr_info = deepcopy(ocr_info)
height = ori_img.shape[0]
width = ori_img.shape[1]
img = cv2.resize(ori_img,
(224, 224)).transpose([2, 0, 1]).astype(np.float32)
segment_offset_id = []
words_list = []
bbox_list = []
input_ids_list = []
token_type_ids_list = []
for info in ocr_info:
# x1, y1, x2, y2
bbox = info["bbox"]
bbox[0] = int(bbox[0] * 1000.0 / width)
bbox[2] = int(bbox[2] * 1000.0 / width)
bbox[1] = int(bbox[1] * 1000.0 / height)
bbox[3] = int(bbox[3] * 1000.0 / height)
text = info["text"]
encode_res = tokenizer.encode(
text, pad_to_max_seq_len=False, return_attention_mask=True)
if not add_special_ids:
# TODO: use tok.all_special_ids to remove
encode_res["input_ids"] = encode_res["input_ids"][1:-1]
encode_res["token_type_ids"] = encode_res["token_type_ids"][1:-1]
encode_res["attention_mask"] = encode_res["attention_mask"][1:-1]
input_ids_list.extend(encode_res["input_ids"])
token_type_ids_list.extend(encode_res["token_type_ids"])
bbox_list.extend([bbox] * len(encode_res["input_ids"]))
words_list.append(text)
segment_offset_id.append(len(input_ids_list))
encoded_inputs = {
"input_ids": input_ids_list,
"token_type_ids": token_type_ids_list,
"bbox": bbox_list,
"attention_mask": [1] * len(input_ids_list),
}
encoded_inputs = pad_sentences(
tokenizer,
encoded_inputs,
max_seq_len=max_seq_len,
return_attention_mask=return_attention_mask)
encoded_inputs = split_page(encoded_inputs)
fake_bs = encoded_inputs["input_ids"].shape[0]
encoded_inputs["image"] = paddle.to_tensor(img).unsqueeze(0).expand(
[fake_bs] + list(img.shape))
encoded_inputs["segment_offset_id"] = segment_offset_id
return encoded_inputs
def postprocess(attention_mask, preds, label_map_path):
if isinstance(preds, paddle.Tensor):
preds = preds.numpy()
preds = np.argmax(preds, axis=2)
_, label_map = get_bio_label_maps(label_map_path)
preds_list = [[] for _ in range(preds.shape[0])]
# keep batch info
for i in range(preds.shape[0]):
for j in range(preds.shape[1]):
if attention_mask[i][j] == 1:
preds_list[i].append(label_map[preds[i][j]])
return preds_list
def merge_preds_list_with_ocr_info(label_map_path, ocr_info, segment_offset_id,
preds_list):
# must ensure the preds_list is generated from the same image
preds = [p for pred in preds_list for p in pred]
label2id_map, _ = get_bio_label_maps(label_map_path)
for key in label2id_map:
if key.startswith("I-"):
label2id_map[key] = label2id_map["B" + key[1:]]
id2label_map = dict()
for key in label2id_map:
val = label2id_map[key]
if key == "O":
id2label_map[val] = key
if key.startswith("B-") or key.startswith("I-"):
id2label_map[val] = key[2:]
else:
id2label_map[val] = key
for idx in range(len(segment_offset_id)):
if idx == 0:
start_id = 0
else:
start_id = segment_offset_id[idx - 1]
end_id = segment_offset_id[idx]
curr_pred = preds[start_id:end_id]
curr_pred = [label2id_map[p] for p in curr_pred]
if len(curr_pred) <= 0:
pred_id = 0
else:
counts = np.bincount(curr_pred)
pred_id = np.argmax(counts)
ocr_info[idx]["pred_id"] = int(pred_id)
ocr_info[idx]["pred"] = id2label_map[pred_id]
return ocr_info
@paddle.no_grad()
def infer(args):
os.makedirs(args.output_dir, exist_ok=True)
# init token and model
tokenizer_class, base_model_class, model_class = MODELS[args.ser_model_type]
tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
model = model_class.from_pretrained(args.model_name_or_path)
model.eval()
# load ocr results json
ocr_results = dict()
with open(args.ocr_json_path, "r", encoding='utf-8') as fin:
lines = fin.readlines()
for line in lines:
img_name, json_info = line.split("\t")
ocr_results[os.path.basename(img_name)] = json.loads(json_info)
# get infer img list
infer_imgs = get_image_file_list(args.infer_imgs)
# loop for infer
with open(
os.path.join(args.output_dir, "infer_results.txt"),
"w",
encoding='utf-8') as fout:
for idx, img_path in enumerate(infer_imgs):
save_img_path = os.path.join(args.output_dir,
os.path.basename(img_path))
print("process: [{}/{}], save result to {}".format(
idx, len(infer_imgs), save_img_path))
img = cv2.imread(img_path)
ocr_info = ocr_results[os.path.basename(img_path)]["ocr_info"]
inputs = preprocess(
tokenizer=tokenizer,
ori_img=img,
ocr_info=ocr_info,
max_seq_len=args.max_seq_length)
if args.ser_model_type == 'LayoutLM':
preds = model(
input_ids=inputs["input_ids"],
bbox=inputs["bbox"],
token_type_ids=inputs["token_type_ids"],
attention_mask=inputs["attention_mask"])
elif args.ser_model_type == 'LayoutXLM':
preds = model(
input_ids=inputs["input_ids"],
bbox=inputs["bbox"],
image=inputs["image"],
token_type_ids=inputs["token_type_ids"],
attention_mask=inputs["attention_mask"])
preds = preds[0]
preds = postprocess(inputs["attention_mask"], preds,
args.label_map_path)
ocr_info = merge_preds_list_with_ocr_info(
args.label_map_path, ocr_info, inputs["segment_offset_id"],
preds)
fout.write(img_path + "\t" + json.dumps(
{
"ocr_info": ocr_info,
}, ensure_ascii=False) + "\n")
img_res = draw_ser_results(img, ocr_info)
cv2.imwrite(save_img_path, img_res)
return
if __name__ == "__main__":
args = parse_args()
infer(args)