Commit 253b8453 authored by Leif's avatar Leif
Browse files

Merge remote-tracking branch 'origin/dygraph' into dygraph

parents 7cad4817 bc999986
# 训练版面分析
* [1. 安装](#安装)
* [1.1 环境要求](#环境要求)
* [1.2 安装PaddleDetection](#安装PaddleDetection)
* [2. 准备数据](#准备数据)
* [3. 配置文件改动和说明](#配置文件改动和说明)
* [4. PaddleDetection训练](#训练)
* [5. PaddleDetection预测](#预测)
* [6. 预测部署](#预测部署)
* [6.1 模型导出](#模型导出)
* [6.2 layout parser预测](#layout_parser预测)
<a name="安装"></a>
## 1. 安装
<a name="环境要求"></a>
### 1.1 环境要求
- PaddlePaddle 2.1
- OS 64 bit
- Python 3(3.5.1+/3.6/3.7/3.8/3.9),64 bit
- pip/pip3(9.0.1+), 64 bit
- CUDA >= 10.1
- cuDNN >= 7.6
<a name="安装PaddleDetection"></a>
### 1.2 安装PaddleDetection
```bash
# 克隆PaddleDetection仓库
cd <path/to/clone/PaddleDetection>
git clone https://github.com/PaddlePaddle/PaddleDetection.git
cd PaddleDetection
# 安装其他依赖
pip install -r requirements.txt
```
更多安装教程,请参考: [Install doc](https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.1/docs/tutorials/INSTALL_cn.md)
<a name="数据准备"></a>
## 2. 准备数据
下载 [PubLayNet](https://github.com/ibm-aur-nlp/PubLayNet) 数据集:
```bash
cd PaddleDetection/dataset/
mkdir publaynet
# 执行命令,下载
wget -O publaynet.tar.gz https://dax-cdn.cdn.appdomain.cloud/dax-publaynet/1.0.0/publaynet.tar.gz?_ga=2.104193024.1076900768.1622560733-649911202.1622560733
# 解压
tar -xvf publaynet.tar.gz
```
解压之后PubLayNet目录结构:
| File or Folder | Description | num |
| :------------- | :----------------------------------------------- | ------- |
| `train/` | Images in the training subset | 335,703 |
| `val/` | Images in the validation subset | 11,245 |
| `test/` | Images in the testing subset | 11,405 |
| `train.json` | Annotations for training images | |
| `val.json` | Annotations for validation images | |
| `LICENSE.txt` | Plaintext version of the CDLA-Permissive license | |
| `README.txt` | Text file with the file names and description | |
如果使用其它数据集,请参考[准备训练数据](https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.1/docs/tutorials/PrepareDataSet.md)
<a name="配置文件改动和说明"></a>
## 3. 配置文件改动和说明
我们使用 `configs/ppyolo/ppyolov2_r50vd_dcn_365e_coco.yml`配置进行训练,配置文件摘要如下:
<div align='center'>
<img src='../../doc/table/PaddleDetection_config.png' width='600px'/>
</div>
从上图看到 `ppyolov2_r50vd_dcn_365e_coco.yml` 配置需要依赖其他的配置文件,在该例子中需要依赖:
```
coco_detection.yml:主要说明了训练数据和验证数据的路径
runtime.yml:主要说明了公共的运行参数,比如是否使用GPU、每多少个epoch存储checkpoint等
optimizer_365e.yml:主要说明了学习率和优化器的配置
ppyolov2_r50vd_dcn.yml:主要说明模型和主干网络的情况
ppyolov2_reader.yml:主要说明数据读取器配置,如batch size,并发加载子进程数等,同时包含读取后预处理操作,如resize、数据增强等等
```
根据实际情况,修改上述文件,比如数据集路径、batch size等。
<a name="训练"></a>
## 4. PaddleDetection训练
PaddleDetection提供了单卡/多卡训练模式,满足用户多种训练需求
* GPU 单卡训练
```bash
export CUDA_VISIBLE_DEVICES=0 #windows和Mac下不需要执行该命令
python tools/train.py -c configs/ppyolo/ppyolov2_r50vd_dcn_365e_coco.yml
```
* GPU多卡训练
```bash
export CUDA_VISIBLE_DEVICES=0,1,2,3
python -m paddle.distributed.launch --gpus 0,1,2,3 tools/train.py -c configs/ppyolo/ppyolov2_r50vd_dcn_365e_coco.yml --eval
```
--eval:表示边训练边验证
* 模型恢复训练
在日常训练过程中,有的用户由于一些原因导致训练中断,用户可以使用-r的命令恢复训练:
```bash
export CUDA_VISIBLE_DEVICES=0,1,2,3
python -m paddle.distributed.launch --gpus 0,1,2,3 tools/train.py -c configs/ppyolo/ppyolov2_r50vd_dcn_365e_coco.yml --eval -r output/ppyolov2_r50vd_dcn_365e_coco/10000
```
注意:如果遇到 "`Out of memory error`" 问题, 尝试在 `ppyolov2_reader.yml` 文件中调小`batch_size`
<a name="预测"></a>
## 5. PaddleDetection预测
设置参数,使用PaddleDetection预测:
```bash
export CUDA_VISIBLE_DEVICES=0
python tools/infer.py -c configs/ppyolo/ppyolov2_r50vd_dcn_365e_coco.yml --infer_img=images/paper-image.jpg --output_dir=infer_output/ --draw_threshold=0.5 -o weights=output/ppyolov2_r50vd_dcn_365e_coco/model_final --use_vdl=Ture
```
`--draw_threshold` 是个可选参数. 根据 [NMS](https://ieeexplore.ieee.org/document/1699659) 的计算,不同阈值会产生不同的结果 `keep_top_k`表示设置输出目标的最大数量,默认值为100,用户可以根据自己的实际情况进行设定。
<a name="预测部署"></a>
## 6. 预测部署
在layout parser中使用自己训练好的模型,
<a name="模型导出"></a>
### 6.1 模型导出
在模型训练过程中保存的模型文件是包含前向预测和反向传播的过程,在实际的工业部署则不需要反向传播,因此需要将模型进行导成部署需要的模型格式。 在PaddleDetection中提供了 `tools/export_model.py`脚本来导出模型。
导出模型名称默认是`model.*`,layout parser代码模型名称是`inference.*`, 所以修改[PaddleDetection/ppdet/engine/trainer.py ](https://github.com/PaddlePaddle/PaddleDetection/blob/b87a1ea86fa18ce69e44a17ad1b49c1326f19ff9/ppdet/engine/trainer.py#L512) (点开链接查看详细代码行),将`model`改为`inference`即可。
执行导出模型脚本:
```bash
python tools/export_model.py -c configs/ppyolo/ppyolov2_r50vd_dcn_365e_coco.yml --output_dir=./inference -o weights=output/ppyolov2_r50vd_dcn_365e_coco/model_final.pdparams
```
预测模型会导出到`inference/ppyolov2_r50vd_dcn_365e_coco`目录下,分别为`infer_cfg.yml`(预测不需要), `inference.pdiparams`, `inference.pdiparams.info`,`inference.pdmodel`
更多模型导出教程,请参考:[EXPORT_MODEL](https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.1/deploy/EXPORT_MODEL.md)
<a name="layout parser预测"></a>
### 6.2 layout_parser预测
`model_path`指定训练好的模型路径,使用layout parser进行预测:
```bash
import layoutparser as lp
model = lp.PaddleDetectionLayoutModel(model_path="inference/ppyolov2_r50vd_dcn_365e_coco", threshold=0.5,label_map={0: "Text", 1: "Title", 2: "List", 3:"Table", 4:"Figure"},enforce_cpu=True,enable_mkldnn=True)
```
***
更多PaddleDetection训练教程,请参考:[PaddleDetection训练](https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.1/docs/tutorials/GETTING_STARTED_cn.md)
***
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import sys
__dir__ = os.path.dirname(__file__)
sys.path.append(__dir__)
sys.path.append(os.path.join(__dir__, '..'))
import cv2
import numpy as np
from pathlib import Path
from ppocr.utils.logging import get_logger
from test1.predict_system import OCRSystem, save_res
from test1.table.predict_table import to_excel
from test1.utility import init_args, draw_result
logger = get_logger()
from ppocr.utils.utility import check_and_read_gif, get_image_file_list
from ppocr.utils.network import maybe_download, download_with_progressbar, confirm_model_dir_url, is_link
__all__ = ['PaddleStructure', 'draw_result', 'save_res']
VERSION = '2.1'
BASE_DIR = os.path.expanduser("~/.paddlestructure/")
model_urls = {
'det': 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_det_infer.tar',
'rec': 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_rec_infer.tar',
'table': 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar'
}
def parse_args(mMain=True):
import argparse
parser = init_args()
parser.add_help = mMain
for action in parser._actions:
if action.dest in ['rec_char_dict_path', 'table_char_dict_path']:
action.default = None
if mMain:
return parser.parse_args()
else:
inference_args_dict = {}
for action in parser._actions:
inference_args_dict[action.dest] = action.default
return argparse.Namespace(**inference_args_dict)
class PaddleStructure(OCRSystem):
def __init__(self, **kwargs):
params = parse_args(mMain=False)
params.__dict__.update(**kwargs)
if not params.show_log:
logger.setLevel(logging.INFO)
params.use_angle_cls = False
# init model dir
params.det_model_dir, det_url = confirm_model_dir_url(params.det_model_dir,
os.path.join(BASE_DIR, VERSION, 'det'),
model_urls['det'])
params.rec_model_dir, rec_url = confirm_model_dir_url(params.rec_model_dir,
os.path.join(BASE_DIR, VERSION, 'rec'),
model_urls['rec'])
params.table_model_dir, table_url = confirm_model_dir_url(params.table_model_dir,
os.path.join(BASE_DIR, VERSION, 'table'),
model_urls['table'])
# download model
maybe_download(params.det_model_dir, det_url)
maybe_download(params.rec_model_dir, rec_url)
maybe_download(params.table_model_dir, table_url)
if params.rec_char_dict_path is None:
params.rec_char_type = 'EN'
if os.path.exists(str(Path(__file__).parent / 'ppocr/utils/dict/table_dict.txt')):
params.rec_char_dict_path = str(Path(__file__).parent / 'ppocr/utils/dict/table_dict.txt')
else:
params.rec_char_dict_path = str(Path(__file__).parent.parent / 'ppocr/utils/dict/table_dict.txt')
if params.table_char_dict_path is None:
if os.path.exists(str(Path(__file__).parent / 'ppocr/utils/dict/table_structure_dict.txt')):
params.table_char_dict_path = str(
Path(__file__).parent / 'ppocr/utils/dict/table_structure_dict.txt')
else:
params.table_char_dict_path = str(
Path(__file__).parent.parent / 'ppocr/utils/dict/table_structure_dict.txt')
print(params)
super().__init__(params)
def __call__(self, img):
if isinstance(img, str):
# download net image
if img.startswith('http'):
download_with_progressbar(img, 'tmp.jpg')
img = 'tmp.jpg'
image_file = img
img, flag = check_and_read_gif(image_file)
if not flag:
with open(image_file, 'rb') as f:
np_arr = np.frombuffer(f.read(), dtype=np.uint8)
img = cv2.imdecode(np_arr, cv2.IMREAD_COLOR)
if img is None:
logger.error("error in loading image:{}".format(image_file))
return None
if isinstance(img, np.ndarray) and len(img.shape) == 2:
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
res = super().__call__(img)
return res
def main():
# for cmd
args = parse_args(mMain=True)
image_dir = args.image_dir
save_folder = args.output
if image_dir.startswith('http'):
download_with_progressbar(image_dir, 'tmp.jpg')
image_file_list = ['tmp.jpg']
else:
image_file_list = get_image_file_list(args.image_dir)
if len(image_file_list) == 0:
logger.error('no images find in {}'.format(args.image_dir))
return
structure_engine = PaddleStructure(**(args.__dict__))
for img_path in image_file_list:
img_name = os.path.basename(img_path).split('.')[0]
logger.info('{}{}{}'.format('*' * 10, img_path, '*' * 10))
result = structure_engine(img_path)
for item in result:
logger.info(item['res'])
save_res(result, save_folder, img_name)
logger.info('result save to {}'.format(os.path.join(save_folder, img_name)))
if __name__ == '__main__':
table_engine = PaddleStructure(show_log=True)
img_path = '../test/test_imgs/PMC1173095_006_00.png'
img = cv2.imread(img_path)
result = table_engine(img)
save_res(result, '/Users/zhoujun20/Desktop/工作相关/table/table_pr/PaddleOCR/output/table',
os.path.basename(img_path).split('.')[0])
for line in result:
print(line)
from PIL import Image
font_path = '../doc/fonts/simfang.ttf'
image = Image.open(img_path).convert('RGB')
im_show = draw_result(image, result, font_path=font_path)
im_show = Image.fromarray(im_show)
im_show.save('result.jpg')
\ No newline at end of file
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import subprocess
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
import cv2
import numpy as np
import time
import logging
import layoutparser as lp
from ppocr.utils.utility import get_image_file_list, check_and_read_gif
from ppocr.utils.logging import get_logger
from tools.infer.predict_system import TextSystem
from test1.table.predict_table import TableSystem, to_excel
from test1.utility import parse_args, draw_result
logger = get_logger()
class OCRSystem(object):
def __init__(self, args):
args.det_limit_type = 'resize_long'
args.drop_score = 0
if not args.show_log:
logger.setLevel(logging.INFO)
self.text_system = TextSystem(args)
self.table_system = TableSystem(args, self.text_system.text_detector, self.text_system.text_recognizer)
self.table_layout = lp.PaddleDetectionLayoutModel("lp://PubLayNet/ppyolov2_r50vd_dcn_365e_publaynet/config",
threshold=0.5, enable_mkldnn=args.enable_mkldnn,
enforce_cpu=not args.use_gpu, thread_num=args.cpu_threads)
self.use_angle_cls = args.use_angle_cls
self.drop_score = args.drop_score
def __call__(self, img):
ori_im = img.copy()
layout_res = self.table_layout.detect(img[..., ::-1])
res_list = []
for region in layout_res:
x1, y1, x2, y2 = region.coordinates
x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
roi_img = ori_im[y1:y2, x1:x2, :]
if region.type == 'Table':
res = self.table_system(roi_img)
else:
filter_boxes, filter_rec_res = self.text_system(roi_img)
filter_boxes = [x + [x1, y1] for x in filter_boxes]
filter_boxes = [x.reshape(-1).tolist() for x in filter_boxes]
res = (filter_boxes, filter_rec_res)
res_list.append({'type': region.type, 'bbox': [x1, y1, x2, y2], 'res': res})
return res_list
def save_res(res, save_folder, img_name):
excel_save_folder = os.path.join(save_folder, img_name)
os.makedirs(excel_save_folder, exist_ok=True)
# save res
for region in res:
if region['type'] == 'Table':
excel_path = os.path.join(excel_save_folder, '{}.xlsx'.format(region['bbox']))
to_excel(region['res'], excel_path)
elif region['type'] == 'Figure':
pass
else:
with open(os.path.join(excel_save_folder, 'res.txt'), 'a', encoding='utf8') as f:
for box, rec_res in zip(region['res'][0], region['res'][1]):
f.write('{}\t{}\n'.format(np.array(box).reshape(-1).tolist(), rec_res))
def main(args):
image_file_list = get_image_file_list(args.image_dir)
image_file_list = image_file_list
image_file_list = image_file_list[args.process_id::args.total_process_num]
save_folder = args.output
os.makedirs(save_folder, exist_ok=True)
structure_sys = OCRSystem(args)
img_num = len(image_file_list)
for i, image_file in enumerate(image_file_list):
logger.info("[{}/{}] {}".format(i, img_num, image_file))
img, flag = check_and_read_gif(image_file)
img_name = os.path.basename(image_file).split('.')[0]
if not flag:
img = cv2.imread(image_file)
if img is None:
logger.error("error in loading image:{}".format(image_file))
continue
starttime = time.time()
res = structure_sys(img)
save_res(res, save_folder, img_name)
draw_img = draw_result(img, res, args.vis_font_path)
cv2.imwrite(os.path.join(save_folder, img_name, 'show.jpg'), draw_img)
logger.info('result save to {}'.format(os.path.join(save_folder, img_name)))
elapse = time.time() - starttime
logger.info("Predict time : {:.3f}s".format(elapse))
if __name__ == "__main__":
args = parse_args()
if args.use_mp:
p_list = []
total_process_num = args.total_process_num
for process_id in range(total_process_num):
cmd = [sys.executable, "-u"] + sys.argv + [
"--process_id={}".format(process_id),
"--use_mp={}".format(False)
]
p = subprocess.Popen(cmd, stdout=sys.stdout, stderr=sys.stdout)
p_list.append(p)
for p in p_list:
p.wait()
else:
main(args)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from setuptools import setup
from io import open
import shutil
with open('../requirements.txt', encoding="utf-8-sig") as f:
requirements = f.readlines()
requirements.append('tqdm')
requirements.append('layoutparser')
requirements.append('iopath')
def readme():
with open('api_ch.md', encoding="utf-8-sig") as f:
README = f.read()
return README
shutil.copytree('./table', './test1/table')
shutil.copyfile('./predict_system.py', './test1/predict_system.py')
shutil.copyfile('./utility.py', './test1/utility.py')
shutil.copytree('../ppocr', './ppocr')
shutil.copytree('../tools', './tools')
shutil.copyfile('../LICENSE', './LICENSE')
setup(
name='paddlestructure',
packages=['paddlestructure'],
package_dir={'paddlestructure': ''},
include_package_data=True,
entry_points={"console_scripts": ["paddlestructure= paddlestructure.paddlestructure:main"]},
version='1.0',
install_requires=requirements,
license='Apache License 2.0',
description='Awesome OCR toolkits based on PaddlePaddle (8.6M ultra-lightweight pre-trained model, support training and deployment among server, mobile, embeded and IoT devices',
long_description=readme(),
long_description_content_type='text/markdown',
url='https://github.com/PaddlePaddle/PaddleOCR',
download_url='https://github.com/PaddlePaddle/PaddleOCR.git',
keywords=[
'ocr textdetection textrecognition paddleocr crnn east star-net rosetta ocrlite db chineseocr chinesetextdetection chinesetextrecognition'
],
classifiers=[
'Intended Audience :: Developers', 'Operating System :: OS Independent',
'Natural Language :: Chinese (Simplified)',
'Programming Language :: Python :: 3',
'Programming Language :: Python :: 3.2',
'Programming Language :: Python :: 3.3',
'Programming Language :: Python :: 3.4',
'Programming Language :: Python :: 3.5',
'Programming Language :: Python :: 3.6',
'Programming Language :: Python :: 3.7', 'Topic :: Utilities'
], )
shutil.rmtree('ppocr')
shutil.rmtree('tools')
shutil.rmtree('test1')
os.remove('LICENSE')
# Table structure and content prediction
## 1. pipeline
The ocr of the table mainly contains three models
1. Single line text detection-DB
2. Single line text recognition-CRNN
3. Table structure and cell coordinate prediction-RARE
The table ocr flow chart is as follows
![tableocr_pipeline](../../doc/table/tableocr_pipeline.png)
1. The coordinates of single-line text is detected by DB model, and then sends it to the recognition model to get the recognition result.
2. The table structure and cell coordinates is predicted by RARE model.
3. The recognition result of the cell is combined by the coordinates, recognition result of the single line and the coordinates of the cell.
4. The cell recognition result and the table structure together construct the html string of the table.
## 2. How to use
### 2.1 Train
TBD
### 2.2 Eval
First cd to the PaddleOCR/ppstructure directory
The table uses TEDS (Tree-Edit-Distance-based Similarity) as the evaluation metric of the model. Before the model evaluation, the three models in the pipeline need to be exported as inference models (we have provided them), and the gt for evaluation needs to be prepared. Examples of gt are as follows:
```json
{"PMC4289340_004_00.png": [["<html>", "<body>", "<table>", "<thead>", "<tr>", "<td>", "</td>", "<td>", "</td>", "<td>", "</td>", "</tr>", "</thead>", "<tbody>", "<tr>", "<td>", "</td>", "<td>", "</td>", "<td>", "</td>", "</tr>", "</tbody>", "</table>", "</body>", "</html>"], [[1, 4, 29, 13], [137, 4, 161, 13], [215, 4, 236, 13], [1, 17, 30, 27], [137, 17, 147, 27], [215, 17, 225, 27]], [["<b>", "F", "e", "a", "t", "u", "r", "e", "</b>"], ["<b>", "G", "b", "3", " ", "+", "</b>"], ["<b>", "G", "b", "3", " ", "-", "</b>"], ["<b>", "P", "a", "t", "i", "e", "n", "t", "s", "</b>"], ["6", "2"], ["4", "5"]]]}
```
In gt json, the key is the image name, the value is the corresponding gt, and gt is a list composed of four items, and each item is
1. HTML string list of table structure
2. The coordinates of each cell (not including the empty text in the cell)
3. The text information in each cell (not including the empty text in the cell)
4. The text information in each cell (including the empty text in the cell)
Use the following command to evaluate. After the evaluation is completed, the teds indicator will be output.
```python
python3 table/eval_table.py --det_model_dir=path/to/det_model_dir --rec_model_dir=path/to/rec_model_dir --table_model_dir=path/to/table_model_dir --image_dir=../doc/table/1.png --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --rec_char_type=EN --det_limit_side_len=736 --det_limit_type=min --gt_path=path/to/gt.json
```
### 2.3 Inference
First cd to the PaddleOCR/ppstructure directory
```python
python3 table/predict_table.py --det_model_dir=path/to/det_model_dir --rec_model_dir=path/to/rec_model_dir --table_model_dir=path/to/table_model_dir --image_dir=../doc/table/1.png --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --rec_char_type=EN --det_limit_side_len=736 --det_limit_type=min --output ../output/table
```
After running, the excel sheet of each picture will be saved in the directory specified by the output field
\ No newline at end of file
# 表格结构和内容预测
## 1. pipeline
表格的ocr主要包含三个模型
1. 单行文本检测-DB
2. 单行文本识别-CRNN
3. 表格结构和cell坐标预测-RARE
具体流程图如下
![tableocr_pipeline](../../doc/table/tableocr_pipeline.png)
1. 图片由单行文字检测检测模型到单行文字的坐标,然后送入识别模型拿到识别结果。
2. 图片由表格结构和cell坐标预测模型拿到表格的结构信息和单元格的坐标信息。
3. 由单行文字的坐标、识别结果和单元格的坐标一起组合出单元格的识别结果。
4. 单元格的识别结果和表格结构一起构造表格的html字符串。
## 2. 使用
### 2.1 训练
TBD
### 2.2 评估
先cd到PaddleOCR/ppstructure目录下
表格使用 TEDS(Tree-Edit-Distance-based Similarity) 作为模型的评估指标。在进行模型评估之前,需要将pipeline中的三个模型分别导出为inference模型(我们已经提供好),还需要准备评估的gt, gt示例如下:
```json
{"PMC4289340_004_00.png": [["<html>", "<body>", "<table>", "<thead>", "<tr>", "<td>", "</td>", "<td>", "</td>", "<td>", "</td>", "</tr>", "</thead>", "<tbody>", "<tr>", "<td>", "</td>", "<td>", "</td>", "<td>", "</td>", "</tr>", "</tbody>", "</table>", "</body>", "</html>"], [[1, 4, 29, 13], [137, 4, 161, 13], [215, 4, 236, 13], [1, 17, 30, 27], [137, 17, 147, 27], [215, 17, 225, 27]], [["<b>", "F", "e", "a", "t", "u", "r", "e", "</b>"], ["<b>", "G", "b", "3", " ", "+", "</b>"], ["<b>", "G", "b", "3", " ", "-", "</b>"], ["<b>", "P", "a", "t", "i", "e", "n", "t", "s", "</b>"], ["6", "2"], ["4", "5"]]]}
```
json 中,key为图片名,value为对于的gt,gt是一个由四个item组成的list,每个item分别为
1. 表格结构的html字符串list
2. 每个cell的坐标 (不包括cell里文字为空的)
3. 每个cell里的文字信息 (不包括cell里文字为空的)
4. 每个cell里的文字信息 (包括cell里文字为空的)
准备完成后使用如下命令进行评估,评估完成后会输出teds指标。
```python
python3 table/eval_table.py --det_model_dir=path/to/det_model_dir --rec_model_dir=path/to/rec_model_dir --table_model_dir=path/to/table_model_dir --image_dir=../doc/table/1.png --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --rec_char_type=EN --det_limit_side_len=736 --det_limit_type=min --gt_path=path/to/gt.json
```
### 2.3 预测
先cd到PaddleOCR/ppstructure目录下
```python
python3 table/predict_table.py --det_model_dir=path/to/det_model_dir --rec_model_dir=path/to/rec_model_dir --table_model_dir=path/to/table_model_dir --image_dir=../doc/table/1.png --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --rec_char_type=EN --det_limit_side_len=736 --det_limit_type=min --output ../output/table
```
运行完成后,每张图片的excel表格会保存到output字段指定的目录下
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
import cv2
import json
from tqdm import tqdm
from test1.table.table_metric import TEDS
from test1.table.predict_table import TableSystem
from test1.utility import init_args
from ppocr.utils.logging import get_logger
logger = get_logger()
def parse_args():
parser = init_args()
parser.add_argument("--gt_path", type=str)
return parser.parse_args()
def main(gt_path, img_root, args):
teds = TEDS(n_jobs=16)
text_sys = TableSystem(args)
jsons_gt = json.load(open(gt_path)) # gt
pred_htmls = []
gt_htmls = []
for img_name in tqdm(jsons_gt):
# read image
img = cv2.imread(os.path.join(img_root,img_name))
pred_html = text_sys(img)
pred_htmls.append(pred_html)
gt_structures, gt_bboxes, gt_contents, contents_with_block = jsons_gt[img_name]
gt_html, gt = get_gt_html(gt_structures, contents_with_block)
gt_htmls.append(gt_html)
scores = teds.batch_evaluate_html(gt_htmls, pred_htmls)
logger.info('teds:', sum(scores) / len(scores))
def get_gt_html(gt_structures, contents_with_block):
end_html = []
td_index = 0
for tag in gt_structures:
if '</td>' in tag:
if contents_with_block[td_index] != []:
end_html.extend(contents_with_block[td_index])
end_html.append(tag)
td_index += 1
else:
end_html.append(tag)
return ''.join(end_html), end_html
if __name__ == '__main__':
args = parse_args()
main(args.gt_path,args.image_dir, args)
import json
def distance(box_1, box_2):
x1, y1, x2, y2 = box_1
x3, y3, x4, y4 = box_2
dis = abs(x3 - x1) + abs(y3 - y1) + abs(x4- x2) + abs(y4 - y2)
dis_2 = abs(x3 - x1) + abs(y3 - y1)
dis_3 = abs(x4- x2) + abs(y4 - y2)
return dis + min(dis_2, dis_3)
def compute_iou(rec1, rec2):
"""
computing IoU
:param rec1: (y0, x0, y1, x1), which reflects
(top, left, bottom, right)
:param rec2: (y0, x0, y1, x1)
:return: scala value of IoU
"""
# computing area of each rectangles
S_rec1 = (rec1[2] - rec1[0]) * (rec1[3] - rec1[1])
S_rec2 = (rec2[2] - rec2[0]) * (rec2[3] - rec2[1])
# computing the sum_area
sum_area = S_rec1 + S_rec2
# find the each edge of intersect rectangle
left_line = max(rec1[1], rec2[1])
right_line = min(rec1[3], rec2[3])
top_line = max(rec1[0], rec2[0])
bottom_line = min(rec1[2], rec2[2])
# judge if there is an intersect
if left_line >= right_line or top_line >= bottom_line:
return 0.0
else:
intersect = (right_line - left_line) * (bottom_line - top_line)
return (intersect / (sum_area - intersect))*1.0
def matcher_merge(ocr_bboxes, pred_bboxes):
all_dis = []
ious = []
matched = {}
for i, gt_box in enumerate(ocr_bboxes):
distances = []
for j, pred_box in enumerate(pred_bboxes):
# compute l1 distence and IOU between two boxes
distances.append((distance(gt_box, pred_box), 1. - compute_iou(gt_box, pred_box)))
sorted_distances = distances.copy()
# select nearest cell
sorted_distances = sorted(sorted_distances, key = lambda item: (item[1], item[0]))
if distances.index(sorted_distances[0]) not in matched.keys():
matched[distances.index(sorted_distances[0])] = [i]
else:
matched[distances.index(sorted_distances[0])].append(i)
return matched#, sum(ious) / len(ious)
def complex_num(pred_bboxes):
complex_nums = []
for bbox in pred_bboxes:
distances = []
temp_ious = []
for pred_bbox in pred_bboxes:
if bbox != pred_bbox:
distances.append(distance(bbox, pred_bbox))
temp_ious.append(compute_iou(bbox, pred_bbox))
complex_nums.append(temp_ious[distances.index(min(distances))])
return sum(complex_nums) / len(complex_nums)
def get_rows(pred_bboxes):
pre_bbox = pred_bboxes[0]
res = []
step = 0
for i in range(len(pred_bboxes)):
bbox = pred_bboxes[i]
if bbox[1] - pre_bbox[1] > 2 or bbox[0] - pre_bbox[0] < 0:
break
else:
res.append(bbox)
step += 1
for i in range(step):
pred_bboxes.pop(0)
return res, pred_bboxes
def refine_rows(pred_bboxes): # 微调整行的框,使在一条水平线上
ys_1 = []
ys_2 = []
for box in pred_bboxes:
ys_1.append(box[1])
ys_2.append(box[3])
min_y_1 = sum(ys_1) / len(ys_1)
min_y_2 = sum(ys_2) / len(ys_2)
re_boxes = []
for box in pred_bboxes:
box[1] = min_y_1
box[3] = min_y_2
re_boxes.append(box)
return re_boxes
def matcher_refine_row(gt_bboxes, pred_bboxes):
before_refine_pred_bboxes = pred_bboxes.copy()
pred_bboxes = []
while(len(before_refine_pred_bboxes) != 0):
row_bboxes, before_refine_pred_bboxes = get_rows(before_refine_pred_bboxes)
print(row_bboxes)
pred_bboxes.extend(refine_rows(row_bboxes))
all_dis = []
ious = []
matched = {}
for i, gt_box in enumerate(gt_bboxes):
distances = []
#temp_ious = []
for j, pred_box in enumerate(pred_bboxes):
distances.append(distance(gt_box, pred_box))
#temp_ious.append(compute_iou(gt_box, pred_box))
#all_dis.append(min(distances))
#ious.append(temp_ious[distances.index(min(distances))])
if distances.index(min(distances)) not in matched.keys():
matched[distances.index(min(distances))] = [i]
else:
matched[distances.index(min(distances))].append(i)
return matched#, sum(ious) / len(ious)
#先挑选出一行,再进行匹配
def matcher_structure_1(gt_bboxes, pred_bboxes_rows, pred_bboxes):
gt_box_index = 0
delete_gt_bboxes = gt_bboxes.copy()
match_bboxes_ready = []
matched = {}
while(len(delete_gt_bboxes) != 0):
row_bboxes, delete_gt_bboxes = get_rows(delete_gt_bboxes)
row_bboxes = sorted(row_bboxes, key = lambda key: key[0])
if len(pred_bboxes_rows) > 0:
match_bboxes_ready.extend(pred_bboxes_rows.pop(0))
print(row_bboxes)
for i, gt_box in enumerate(row_bboxes):
#print(gt_box)
pred_distances = []
distances = []
for pred_bbox in pred_bboxes:
pred_distances.append(distance(gt_box, pred_bbox))
for j, pred_box in enumerate(match_bboxes_ready):
distances.append(distance(gt_box, pred_box))
index = pred_distances.index(min(distances))
#print('index', index)
if index not in matched.keys():
matched[index] = [gt_box_index]
else:
matched[index].append(gt_box_index)
gt_box_index += 1
return matched
def matcher_structure(gt_bboxes, pred_bboxes_rows, pred_bboxes):
'''
gt_bboxes: 排序后
pred_bboxes:
'''
pre_bbox = gt_bboxes[0]
matched = {}
match_bboxes_ready = []
match_bboxes_ready.extend(pred_bboxes_rows.pop(0))
for i, gt_box in enumerate(gt_bboxes):
pred_distances = []
for pred_bbox in pred_bboxes:
pred_distances.append(distance(gt_box, pred_bbox))
distances = []
gap_pre = gt_box[1] - pre_bbox[1]
gap_pre_1 = gt_box[0] - pre_bbox[2]
#print(gap_pre, len(pred_bboxes_rows))
if (gap_pre_1 < 0 and len(pred_bboxes_rows) > 0):
match_bboxes_ready.extend(pred_bboxes_rows.pop(0))
if len(pred_bboxes_rows) == 1:
match_bboxes_ready.extend(pred_bboxes_rows.pop(0))
if len(match_bboxes_ready) == 0 and len(pred_bboxes_rows) > 0:
match_bboxes_ready.extend(pred_bboxes_rows.pop(0))
if len(match_bboxes_ready) == 0 and len(pred_bboxes_rows) == 0:
break
#print(match_bboxes_ready)
for j, pred_box in enumerate(match_bboxes_ready):
distances.append(distance(gt_box, pred_box))
index = pred_distances.index(min(distances))
#print(gt_box, index)
#match_bboxes_ready.pop(distances.index(min(distances)))
print(gt_box, match_bboxes_ready[distances.index(min(distances))])
if index not in matched.keys():
matched[index] = [i]
else:
matched[index].append(i)
pre_bbox = gt_box
return matched
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
import cv2
import numpy as np
import math
import time
import traceback
import paddle
import tools.infer.utility as utility
from ppocr.data import create_operators, transform
from ppocr.postprocess import build_post_process
from ppocr.utils.logging import get_logger
from ppocr.utils.utility import get_image_file_list, check_and_read_gif
from test1.utility import parse_args
logger = get_logger()
class TableStructurer(object):
def __init__(self, args):
pre_process_list = [{
'ResizeTableImage': {
'max_len': args.table_max_len
}
}, {
'NormalizeImage': {
'std': [0.229, 0.224, 0.225],
'mean': [0.485, 0.456, 0.406],
'scale': '1./255.',
'order': 'hwc'
}
}, {
'PaddingTableImage': None
}, {
'ToCHWImage': None
}, {
'KeepKeys': {
'keep_keys': ['image']
}
}]
postprocess_params = {
'name': 'TableLabelDecode',
"character_type": args.table_char_type,
"character_dict_path": args.table_char_dict_path,
}
self.preprocess_op = create_operators(pre_process_list)
self.postprocess_op = build_post_process(postprocess_params)
self.predictor, self.input_tensor, self.output_tensors, self.config = \
utility.create_predictor(args, 'table', logger)
def __call__(self, img):
ori_im = img.copy()
data = {'image': img}
data = transform(data, self.preprocess_op)
img = data[0]
if img is None:
return None, 0
img = np.expand_dims(img, axis=0)
img = img.copy()
starttime = time.time()
self.input_tensor.copy_from_cpu(img)
self.predictor.run()
outputs = []
for output_tensor in self.output_tensors:
output = output_tensor.copy_to_cpu()
outputs.append(output)
preds = {}
preds['structure_probs'] = outputs[1]
preds['loc_preds'] = outputs[0]
post_result = self.postprocess_op(preds)
structure_str_list = post_result['structure_str_list']
res_loc = post_result['res_loc']
imgh, imgw = ori_im.shape[0:2]
res_loc_final = []
for rno in range(len(res_loc[0])):
x0, y0, x1, y1 = res_loc[0][rno]
left = max(int(imgw * x0), 0)
top = max(int(imgh * y0), 0)
right = min(int(imgw * x1), imgw - 1)
bottom = min(int(imgh * y1), imgh - 1)
res_loc_final.append([left, top, right, bottom])
structure_str_list = structure_str_list[0][:-1]
structure_str_list = ['<html>', '<body>', '<table>'] + structure_str_list + ['</table>', '</body>', '</html>']
elapse = time.time() - starttime
return (structure_str_list, res_loc_final), elapse
def main(args):
image_file_list = get_image_file_list(args.image_dir)
table_structurer = TableStructurer(args)
count = 0
total_time = 0
for image_file in image_file_list:
img, flag = check_and_read_gif(image_file)
if not flag:
img = cv2.imread(image_file)
if img is None:
logger.info("error in loading image:{}".format(image_file))
continue
structure_res, elapse = table_structurer(img)
logger.info("result: {}".format(structure_res))
if count > 0:
total_time += elapse
count += 1
logger.info("Predict time of {}: {}".format(image_file, elapse))
if __name__ == "__main__":
main(parse_args())
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import subprocess
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
import cv2
import copy
import numpy as np
import time
import tools.infer.predict_rec as predict_rec
import tools.infer.predict_det as predict_det
from ppocr.utils.utility import get_image_file_list, check_and_read_gif
from ppocr.utils.logging import get_logger
from test1.table.matcher import distance, compute_iou
from test1.utility import parse_args
import test1.table.predict_structure as predict_strture
logger = get_logger()
def expand(pix, det_box, shape):
x0, y0, x1, y1 = det_box
# print(shape)
h, w, c = shape
tmp_x0 = x0 - pix
tmp_x1 = x1 + pix
tmp_y0 = y0 - pix
tmp_y1 = y1 + pix
x0_ = tmp_x0 if tmp_x0 >= 0 else 0
x1_ = tmp_x1 if tmp_x1 <= w else w
y0_ = tmp_y0 if tmp_y0 >= 0 else 0
y1_ = tmp_y1 if tmp_y1 <= h else h
return x0_, y0_, x1_, y1_
class TableSystem(object):
def __init__(self, args, text_detector=None, text_recognizer=None):
self.text_detector = predict_det.TextDetector(args) if text_detector is None else text_detector
self.text_recognizer = predict_rec.TextRecognizer(args) if text_recognizer is None else text_recognizer
self.table_structurer = predict_strture.TableStructurer(args)
def __call__(self, img):
ori_im = img.copy()
structure_res, elapse = self.table_structurer(copy.deepcopy(img))
dt_boxes, elapse = self.text_detector(copy.deepcopy(img))
dt_boxes = sorted_boxes(dt_boxes)
r_boxes = []
for box in dt_boxes:
x_min = box[:, 0].min() - 1
x_max = box[:, 0].max() + 1
y_min = box[:, 1].min() - 1
y_max = box[:, 1].max() + 1
box = [x_min, y_min, x_max, y_max]
r_boxes.append(box)
dt_boxes = np.array(r_boxes)
logger.debug("dt_boxes num : {}, elapse : {}".format(
len(dt_boxes), elapse))
if dt_boxes is None:
return None, None
img_crop_list = []
for i in range(len(dt_boxes)):
det_box = dt_boxes[i]
x0, y0, x1, y1 = expand(2, det_box, ori_im.shape)
text_rect = ori_im[int(y0):int(y1), int(x0):int(x1), :]
img_crop_list.append(text_rect)
rec_res, elapse = self.text_recognizer(img_crop_list)
logger.debug("rec_res num : {}, elapse : {}".format(
len(rec_res), elapse))
pred_html, pred = self.rebuild_table(structure_res, dt_boxes, rec_res)
return pred_html
def rebuild_table(self, structure_res, dt_boxes, rec_res):
pred_structures, pred_bboxes = structure_res
matched_index = self.match_result(dt_boxes, pred_bboxes)
pred_html, pred = self.get_pred_html(pred_structures, matched_index, rec_res)
return pred_html, pred
def match_result(self, dt_boxes, pred_bboxes):
matched = {}
for i, gt_box in enumerate(dt_boxes):
# gt_box = [np.min(gt_box[:, 0]), np.min(gt_box[:, 1]), np.max(gt_box[:, 0]), np.max(gt_box[:, 1])]
distances = []
for j, pred_box in enumerate(pred_bboxes):
distances.append(
(distance(gt_box, pred_box), 1. - compute_iou(gt_box, pred_box))) # 获取两两cell之间的L1距离和 1- IOU
sorted_distances = distances.copy()
# 根据距离和IOU挑选最"近"的cell
sorted_distances = sorted(sorted_distances, key=lambda item: (item[1], item[0]))
if distances.index(sorted_distances[0]) not in matched.keys():
matched[distances.index(sorted_distances[0])] = [i]
else:
matched[distances.index(sorted_distances[0])].append(i)
return matched
def get_pred_html(self, pred_structures, matched_index, ocr_contents):
end_html = []
td_index = 0
for tag in pred_structures:
if '</td>' in tag:
if td_index in matched_index.keys():
b_with = False
if '<b>' in ocr_contents[matched_index[td_index][0]] and len(matched_index[td_index]) > 1:
b_with = True
end_html.extend('<b>')
for i, td_index_index in enumerate(matched_index[td_index]):
content = ocr_contents[td_index_index][0]
if len(matched_index[td_index]) > 1:
if len(content) == 0:
continue
if content[0] == ' ':
content = content[1:]
if '<b>' in content:
content = content[3:]
if '</b>' in content:
content = content[:-4]
if len(content) == 0:
continue
if i != len(matched_index[td_index]) - 1 and ' ' != content[-1]:
content += ' '
end_html.extend(content)
if b_with:
end_html.extend('</b>')
end_html.append(tag)
td_index += 1
else:
end_html.append(tag)
return ''.join(end_html), end_html
def sorted_boxes(dt_boxes):
"""
Sort text boxes in order from top to bottom, left to right
args:
dt_boxes(array):detected text boxes with shape [4, 2]
return:
sorted boxes(array) with shape [4, 2]
"""
num_boxes = dt_boxes.shape[0]
sorted_boxes = sorted(dt_boxes, key=lambda x: (x[0][1], x[0][0]))
_boxes = list(sorted_boxes)
for i in range(num_boxes - 1):
if abs(_boxes[i + 1][0][1] - _boxes[i][0][1]) < 10 and \
(_boxes[i + 1][0][0] < _boxes[i][0][0]):
tmp = _boxes[i]
_boxes[i] = _boxes[i + 1]
_boxes[i + 1] = tmp
return _boxes
def to_excel(html_table, excel_path):
from tablepyxl import tablepyxl
tablepyxl.document_to_xl(html_table, excel_path)
def main(args):
image_file_list = get_image_file_list(args.image_dir)
image_file_list = image_file_list[args.process_id::args.total_process_num]
os.makedirs(args.output, exist_ok=True)
text_sys = TableSystem(args)
img_num = len(image_file_list)
for i, image_file in enumerate(image_file_list):
logger.info("[{}/{}] {}".format(i, img_num, image_file))
img, flag = check_and_read_gif(image_file)
excel_path = os.path.join(args.output, os.path.basename(image_file).split('.')[0] + '.xlsx')
if not flag:
img = cv2.imread(image_file)
if img is None:
logger.error("error in loading image:{}".format(image_file))
continue
starttime = time.time()
pred_html = text_sys(img)
to_excel(pred_html, excel_path)
logger.info('excel saved to {}'.format(excel_path))
logger.info(pred_html)
elapse = time.time() - starttime
logger.info("Predict time : {:.3f}s".format(elapse))
if __name__ == "__main__":
args = parse_args()
if args.use_mp:
p_list = []
total_process_num = args.total_process_num
for process_id in range(total_process_num):
cmd = [sys.executable, "-u"] + sys.argv + [
"--process_id={}".format(process_id),
"--use_mp={}".format(False)
]
p = subprocess.Popen(cmd, stdout=sys.stdout, stderr=sys.stdout)
p_list.append(p)
for p in p_list:
p.wait()
else:
main(args)
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__all__ = ['TEDS']
from .table_metric import TEDS
\ No newline at end of file
from tqdm import tqdm
from concurrent.futures import ProcessPoolExecutor, as_completed
def parallel_process(array, function, n_jobs=16, use_kwargs=False, front_num=0):
"""
A parallel version of the map function with a progress bar.
Args:
array (array-like): An array to iterate over.
function (function): A python function to apply to the elements of array
n_jobs (int, default=16): The number of cores to use
use_kwargs (boolean, default=False): Whether to consider the elements of array as dictionaries of
keyword arguments to function
front_num (int, default=3): The number of iterations to run serially before kicking off the parallel job.
Useful for catching bugs
Returns:
[function(array[0]), function(array[1]), ...]
"""
# We run the first few iterations serially to catch bugs
if front_num > 0:
front = [function(**a) if use_kwargs else function(a)
for a in array[:front_num]]
else:
front = []
# If we set n_jobs to 1, just run a list comprehension. This is useful for benchmarking and debugging.
if n_jobs == 1:
return front + [function(**a) if use_kwargs else function(a) for a in tqdm(array[front_num:])]
# Assemble the workers
with ProcessPoolExecutor(max_workers=n_jobs) as pool:
# Pass the elements of array into function
if use_kwargs:
futures = [pool.submit(function, **a) for a in array[front_num:]]
else:
futures = [pool.submit(function, a) for a in array[front_num:]]
kwargs = {
'total': len(futures),
'unit': 'it',
'unit_scale': True,
'leave': True
}
# Print out the progress as tasks complete
for f in tqdm(as_completed(futures), **kwargs):
pass
out = []
# Get the results from the futures.
for i, future in tqdm(enumerate(futures)):
try:
out.append(future.result())
except Exception as e:
out.append(e)
return front + out
# Copyright 2020 IBM
# Author: peter.zhong@au1.ibm.com
#
# This is free software; you can redistribute it and/or modify
# it under the terms of the Apache 2.0 License.
#
# This software is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# Apache 2.0 License for more details.
import distance
from apted import APTED, Config
from apted.helpers import Tree
from lxml import etree, html
from collections import deque
from .parallel import parallel_process
from tqdm import tqdm
class TableTree(Tree):
def __init__(self, tag, colspan=None, rowspan=None, content=None, *children):
self.tag = tag
self.colspan = colspan
self.rowspan = rowspan
self.content = content
self.children = list(children)
def bracket(self):
"""Show tree using brackets notation"""
if self.tag == 'td':
result = '"tag": %s, "colspan": %d, "rowspan": %d, "text": %s' % \
(self.tag, self.colspan, self.rowspan, self.content)
else:
result = '"tag": %s' % self.tag
for child in self.children:
result += child.bracket()
return "{{{}}}".format(result)
class CustomConfig(Config):
@staticmethod
def maximum(*sequences):
"""Get maximum possible value
"""
return max(map(len, sequences))
def normalized_distance(self, *sequences):
"""Get distance from 0 to 1
"""
return float(distance.levenshtein(*sequences)) / self.maximum(*sequences)
def rename(self, node1, node2):
"""Compares attributes of trees"""
#print(node1.tag)
if (node1.tag != node2.tag) or (node1.colspan != node2.colspan) or (node1.rowspan != node2.rowspan):
return 1.
if node1.tag == 'td':
if node1.content or node2.content:
#print(node1.content, )
return self.normalized_distance(node1.content, node2.content)
return 0.
class CustomConfig_del_short(Config):
@staticmethod
def maximum(*sequences):
"""Get maximum possible value
"""
return max(map(len, sequences))
def normalized_distance(self, *sequences):
"""Get distance from 0 to 1
"""
return float(distance.levenshtein(*sequences)) / self.maximum(*sequences)
def rename(self, node1, node2):
"""Compares attributes of trees"""
if (node1.tag != node2.tag) or (node1.colspan != node2.colspan) or (node1.rowspan != node2.rowspan):
return 1.
if node1.tag == 'td':
if node1.content or node2.content:
#print('before')
#print(node1.content, node2.content)
#print('after')
node1_content = node1.content
node2_content = node2.content
if len(node1_content) < 3:
node1_content = ['####']
if len(node2_content) < 3:
node2_content = ['####']
return self.normalized_distance(node1_content, node2_content)
return 0.
class CustomConfig_del_block(Config):
@staticmethod
def maximum(*sequences):
"""Get maximum possible value
"""
return max(map(len, sequences))
def normalized_distance(self, *sequences):
"""Get distance from 0 to 1
"""
return float(distance.levenshtein(*sequences)) / self.maximum(*sequences)
def rename(self, node1, node2):
"""Compares attributes of trees"""
if (node1.tag != node2.tag) or (node1.colspan != node2.colspan) or (node1.rowspan != node2.rowspan):
return 1.
if node1.tag == 'td':
if node1.content or node2.content:
node1_content = node1.content
node2_content = node2.content
while ' ' in node1_content:
print(node1_content.index(' '))
node1_content.pop(node1_content.index(' '))
while ' ' in node2_content:
print(node2_content.index(' '))
node2_content.pop(node2_content.index(' '))
return self.normalized_distance(node1_content, node2_content)
return 0.
class TEDS(object):
''' Tree Edit Distance basead Similarity
'''
def __init__(self, structure_only=False, n_jobs=1, ignore_nodes=None):
assert isinstance(n_jobs, int) and (
n_jobs >= 1), 'n_jobs must be an integer greather than 1'
self.structure_only = structure_only
self.n_jobs = n_jobs
self.ignore_nodes = ignore_nodes
self.__tokens__ = []
def tokenize(self, node):
''' Tokenizes table cells
'''
self.__tokens__.append('<%s>' % node.tag)
if node.text is not None:
self.__tokens__ += list(node.text)
for n in node.getchildren():
self.tokenize(n)
if node.tag != 'unk':
self.__tokens__.append('</%s>' % node.tag)
if node.tag != 'td' and node.tail is not None:
self.__tokens__ += list(node.tail)
def load_html_tree(self, node, parent=None):
''' Converts HTML tree to the format required by apted
'''
global __tokens__
if node.tag == 'td':
if self.structure_only:
cell = []
else:
self.__tokens__ = []
self.tokenize(node)
cell = self.__tokens__[1:-1].copy()
new_node = TableTree(node.tag,
int(node.attrib.get('colspan', '1')),
int(node.attrib.get('rowspan', '1')),
cell, *deque())
else:
new_node = TableTree(node.tag, None, None, None, *deque())
if parent is not None:
parent.children.append(new_node)
if node.tag != 'td':
for n in node.getchildren():
self.load_html_tree(n, new_node)
if parent is None:
return new_node
def evaluate(self, pred, true):
''' Computes TEDS score between the prediction and the ground truth of a
given sample
'''
if (not pred) or (not true):
return 0.0
parser = html.HTMLParser(remove_comments=True, encoding='utf-8')
pred = html.fromstring(pred, parser=parser)
true = html.fromstring(true, parser=parser)
if pred.xpath('body/table') and true.xpath('body/table'):
pred = pred.xpath('body/table')[0]
true = true.xpath('body/table')[0]
if self.ignore_nodes:
etree.strip_tags(pred, *self.ignore_nodes)
etree.strip_tags(true, *self.ignore_nodes)
n_nodes_pred = len(pred.xpath(".//*"))
n_nodes_true = len(true.xpath(".//*"))
n_nodes = max(n_nodes_pred, n_nodes_true)
tree_pred = self.load_html_tree(pred)
tree_true = self.load_html_tree(true)
distance = APTED(tree_pred, tree_true,
CustomConfig()).compute_edit_distance()
return 1.0 - (float(distance) / n_nodes)
else:
return 0.0
def batch_evaluate(self, pred_json, true_json):
''' Computes TEDS score between the prediction and the ground truth of
a batch of samples
@params pred_json: {'FILENAME': 'HTML CODE', ...}
@params true_json: {'FILENAME': {'html': 'HTML CODE'}, ...}
@output: {'FILENAME': 'TEDS SCORE', ...}
'''
samples = true_json.keys()
if self.n_jobs == 1:
scores = [self.evaluate(pred_json.get(
filename, ''), true_json[filename]['html']) for filename in tqdm(samples)]
else:
inputs = [{'pred': pred_json.get(
filename, ''), 'true': true_json[filename]['html']} for filename in samples]
scores = parallel_process(
inputs, self.evaluate, use_kwargs=True, n_jobs=self.n_jobs, front_num=1)
scores = dict(zip(samples, scores))
return scores
def batch_evaluate_html(self, pred_htmls, true_htmls):
''' Computes TEDS score between the prediction and the ground truth of
a batch of samples
'''
if self.n_jobs == 1:
scores = [self.evaluate(pred_html, true_html) for (
pred_html, true_html) in zip(pred_htmls, true_htmls)]
else:
inputs = [{"pred": pred_html, "true": true_html} for(
pred_html, true_html) in zip(pred_htmls, true_htmls)]
scores = parallel_process(
inputs, self.evaluate, use_kwargs=True, n_jobs=self.n_jobs, front_num=1)
return scores
if __name__ == '__main__':
import json
import pprint
with open('sample_pred.json') as fp:
pred_json = json.load(fp)
with open('sample_gt.json') as fp:
true_json = json.load(fp)
teds = TEDS(n_jobs=4)
scores = teds.batch_evaluate(pred_json, true_json)
pp = pprint.PrettyPrinter()
pp.pprint(scores)
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
\ No newline at end of file
# This is where we handle translating css styles into openpyxl styles
# and cascading those from parent to child in the dom.
from openpyxl.cell import cell
from openpyxl.styles import Font, Alignment, PatternFill, NamedStyle, Border, Side, Color
from openpyxl.styles.fills import FILL_SOLID
from openpyxl.styles.numbers import FORMAT_CURRENCY_USD_SIMPLE, FORMAT_PERCENTAGE
from openpyxl.styles.colors import BLACK
FORMAT_DATE_MMDDYYYY = 'mm/dd/yyyy'
def colormap(color):
"""
Convenience for looking up known colors
"""
cmap = {'black': BLACK}
return cmap.get(color, color)
def style_string_to_dict(style):
"""
Convert css style string to a python dictionary
"""
def clean_split(string, delim):
return (s.strip() for s in string.split(delim))
styles = [clean_split(s, ":") for s in style.split(";") if ":" in s]
return dict(styles)
def get_side(style, name):
return {'border_style': style.get('border-{}-style'.format(name)),
'color': colormap(style.get('border-{}-color'.format(name)))}
known_styles = {}
def style_dict_to_named_style(style_dict, number_format=None):
"""
Change css style (stored in a python dictionary) to openpyxl NamedStyle
"""
style_and_format_string = str({
'style_dict': style_dict,
'parent': style_dict.parent,
'number_format': number_format,
})
if style_and_format_string not in known_styles:
# Font
font = Font(bold=style_dict.get('font-weight') == 'bold',
color=style_dict.get_color('color', None),
size=style_dict.get('font-size'))
# Alignment
alignment = Alignment(horizontal=style_dict.get('text-align', 'general'),
vertical=style_dict.get('vertical-align'),
wrap_text=style_dict.get('white-space', 'nowrap') == 'normal')
# Fill
bg_color = style_dict.get_color('background-color')
fg_color = style_dict.get_color('foreground-color', Color())
fill_type = style_dict.get('fill-type')
if bg_color and bg_color != 'transparent':
fill = PatternFill(fill_type=fill_type or FILL_SOLID,
start_color=bg_color,
end_color=fg_color)
else:
fill = PatternFill()
# Border
border = Border(left=Side(**get_side(style_dict, 'left')),
right=Side(**get_side(style_dict, 'right')),
top=Side(**get_side(style_dict, 'top')),
bottom=Side(**get_side(style_dict, 'bottom')),
diagonal=Side(**get_side(style_dict, 'diagonal')),
diagonal_direction=None,
outline=Side(**get_side(style_dict, 'outline')),
vertical=None,
horizontal=None)
name = 'Style {}'.format(len(known_styles) + 1)
pyxl_style = NamedStyle(name=name, font=font, fill=fill, alignment=alignment, border=border,
number_format=number_format)
known_styles[style_and_format_string] = pyxl_style
return known_styles[style_and_format_string]
class StyleDict(dict):
"""
It's like a dictionary, but it looks for items in the parent dictionary
"""
def __init__(self, *args, **kwargs):
self.parent = kwargs.pop('parent', None)
super(StyleDict, self).__init__(*args, **kwargs)
def __getitem__(self, item):
if item in self:
return super(StyleDict, self).__getitem__(item)
elif self.parent:
return self.parent[item]
else:
raise KeyError('{} not found'.format(item))
def __hash__(self):
return hash(tuple([(k, self.get(k)) for k in self._keys()]))
# Yielding the keys avoids creating unnecessary data structures
# and happily works with both python2 and python3 where the
# .keys() method is a dictionary_view in python3 and a list in python2.
def _keys(self):
yielded = set()
for k in self.keys():
yielded.add(k)
yield k
if self.parent:
for k in self.parent._keys():
if k not in yielded:
yielded.add(k)
yield k
def get(self, k, d=None):
try:
return self[k]
except KeyError:
return d
def get_color(self, k, d=None):
"""
Strip leading # off colors if necessary
"""
color = self.get(k, d)
if hasattr(color, 'startswith') and color.startswith('#'):
color = color[1:]
if len(color) == 3: # Premailers reduces colors like #00ff00 to #0f0, openpyxl doesn't like that
color = ''.join(2 * c for c in color)
return color
class Element(object):
"""
Our base class for representing an html element along with a cascading style.
The element is created along with a parent so that the StyleDict that we store
can point to the parent's StyleDict.
"""
def __init__(self, element, parent=None):
self.element = element
self.number_format = None
parent_style = parent.style_dict if parent else None
self.style_dict = StyleDict(style_string_to_dict(element.get('style', '')), parent=parent_style)
self._style_cache = None
def style(self):
"""
Turn the css styles for this element into an openpyxl NamedStyle.
"""
if not self._style_cache:
self._style_cache = style_dict_to_named_style(self.style_dict, number_format=self.number_format)
return self._style_cache
def get_dimension(self, dimension_key):
"""
Extracts the dimension from the style dict of the Element and returns it as a float.
"""
dimension = self.style_dict.get(dimension_key)
if dimension:
if dimension[-2:] in ['px', 'em', 'pt', 'in', 'cm']:
dimension = dimension[:-2]
dimension = float(dimension)
return dimension
class Table(Element):
"""
The concrete implementations of Elements are semantically named for the types of elements we are interested in.
This defines a very concrete tree structure for html tables that we expect to deal with. I prefer this compared to
allowing Element to have an arbitrary number of children and dealing with an abstract element tree.
"""
def __init__(self, table):
"""
takes an html table object (from lxml)
"""
super(Table, self).__init__(table)
table_head = table.find('thead')
self.head = TableHead(table_head, parent=self) if table_head is not None else None
table_body = table.find('tbody')
self.body = TableBody(table_body if table_body is not None else table, parent=self)
class TableHead(Element):
"""
This class maps to the `<th>` element of the html table.
"""
def __init__(self, head, parent=None):
super(TableHead, self).__init__(head, parent=parent)
self.rows = [TableRow(tr, parent=self) for tr in head.findall('tr')]
class TableBody(Element):
"""
This class maps to the `<tbody>` element of the html table.
"""
def __init__(self, body, parent=None):
super(TableBody, self).__init__(body, parent=parent)
self.rows = [TableRow(tr, parent=self) for tr in body.findall('tr')]
class TableRow(Element):
"""
This class maps to the `<tr>` element of the html table.
"""
def __init__(self, tr, parent=None):
super(TableRow, self).__init__(tr, parent=parent)
self.cells = [TableCell(cell, parent=self) for cell in tr.findall('th') + tr.findall('td')]
def element_to_string(el):
return _element_to_string(el).strip()
def _element_to_string(el):
string = ''
for x in el.iterchildren():
string += '\n' + _element_to_string(x)
text = el.text.strip() if el.text else ''
tail = el.tail.strip() if el.tail else ''
return text + string + '\n' + tail
class TableCell(Element):
"""
This class maps to the `<td>` element of the html table.
"""
CELL_TYPES = {'TYPE_STRING', 'TYPE_FORMULA', 'TYPE_NUMERIC', 'TYPE_BOOL', 'TYPE_CURRENCY', 'TYPE_PERCENTAGE',
'TYPE_NULL', 'TYPE_INLINE', 'TYPE_ERROR', 'TYPE_FORMULA_CACHE_STRING', 'TYPE_INTEGER'}
def __init__(self, cell, parent=None):
super(TableCell, self).__init__(cell, parent=parent)
self.value = element_to_string(cell)
self.number_format = self.get_number_format()
def data_type(self):
cell_types = self.CELL_TYPES & set(self.element.get('class', '').split())
if cell_types:
if 'TYPE_FORMULA' in cell_types:
# Make sure TYPE_FORMULA takes precedence over the other classes in the set.
cell_type = 'TYPE_FORMULA'
elif cell_types & {'TYPE_CURRENCY', 'TYPE_INTEGER', 'TYPE_PERCENTAGE'}:
cell_type = 'TYPE_NUMERIC'
else:
cell_type = cell_types.pop()
else:
cell_type = 'TYPE_STRING'
return getattr(cell, cell_type)
def get_number_format(self):
if 'TYPE_CURRENCY' in self.element.get('class', '').split():
return FORMAT_CURRENCY_USD_SIMPLE
if 'TYPE_INTEGER' in self.element.get('class', '').split():
return '#,##0'
if 'TYPE_PERCENTAGE' in self.element.get('class', '').split():
return FORMAT_PERCENTAGE
if 'TYPE_DATE' in self.element.get('class', '').split():
return FORMAT_DATE_MMDDYYYY
if self.data_type() == cell.TYPE_NUMERIC:
try:
int(self.value)
except ValueError:
return '#,##0.##'
else:
return '#,##0'
def format(self, cell):
cell.style = self.style()
data_type = self.data_type()
if data_type:
cell.data_type = data_type
\ No newline at end of file
# Do imports like python3 so our package works for 2 and 3
from __future__ import absolute_import
from lxml import html
from openpyxl import Workbook
from openpyxl.utils import get_column_letter
from premailer import Premailer
from tablepyxl.style import Table
def string_to_int(s):
if s.isdigit():
return int(s)
return 0
def get_Tables(doc):
tree = html.fromstring(doc)
comments = tree.xpath('//comment()')
for comment in comments:
comment.drop_tag()
return [Table(table) for table in tree.xpath('//table')]
def write_rows(worksheet, elem, row, column=1):
"""
Writes every tr child element of elem to a row in the worksheet
returns the next row after all rows are written
"""
from openpyxl.cell.cell import MergedCell
initial_column = column
for table_row in elem.rows:
for table_cell in table_row.cells:
cell = worksheet.cell(row=row, column=column)
while isinstance(cell, MergedCell):
column += 1
cell = worksheet.cell(row=row, column=column)
colspan = string_to_int(table_cell.element.get("colspan", "1"))
rowspan = string_to_int(table_cell.element.get("rowspan", "1"))
if rowspan > 1 or colspan > 1:
worksheet.merge_cells(start_row=row, start_column=column,
end_row=row + rowspan - 1, end_column=column + colspan - 1)
cell.value = table_cell.value
table_cell.format(cell)
min_width = table_cell.get_dimension('min-width')
max_width = table_cell.get_dimension('max-width')
if colspan == 1:
# Initially, when iterating for the first time through the loop, the width of all the cells is None.
# As we start filling in contents, the initial width of the cell (which can be retrieved by:
# worksheet.column_dimensions[get_column_letter(column)].width) is equal to the width of the previous
# cell in the same column (i.e. width of A2 = width of A1)
width = max(worksheet.column_dimensions[get_column_letter(column)].width or 0, len(table_cell.value) + 2)
if max_width and width > max_width:
width = max_width
elif min_width and width < min_width:
width = min_width
worksheet.column_dimensions[get_column_letter(column)].width = width
column += colspan
row += 1
column = initial_column
return row
def table_to_sheet(table, wb):
"""
Takes a table and workbook and writes the table to a new sheet.
The sheet title will be the same as the table attribute name.
"""
ws = wb.create_sheet(title=table.element.get('name'))
insert_table(table, ws, 1, 1)
def document_to_workbook(doc, wb=None, base_url=None):
"""
Takes a string representation of an html document and writes one sheet for
every table in the document.
The workbook is returned
"""
if not wb:
wb = Workbook()
wb.remove(wb.active)
inline_styles_doc = Premailer(doc, base_url=base_url, remove_classes=False).transform()
tables = get_Tables(inline_styles_doc)
for table in tables:
table_to_sheet(table, wb)
return wb
def document_to_xl(doc, filename, base_url=None):
"""
Takes a string representation of an html document and writes one sheet for
every table in the document. The workbook is written out to a file called filename
"""
wb = document_to_workbook(doc, base_url=base_url)
wb.save(filename)
def insert_table(table, worksheet, column, row):
if table.head:
row = write_rows(worksheet, table.head, row, column)
if table.body:
row = write_rows(worksheet, table.body, row, column)
def insert_table_at_cell(table, cell):
"""
Inserts a table at the location of an openpyxl Cell object.
"""
ws = cell.parent
column, row = cell.column, cell.row
insert_table(table, ws, column, row)
\ No newline at end of file
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from PIL import Image
import numpy as np
from tools.infer.utility import draw_ocr_box_txt, init_args as infer_args
def init_args():
parser = infer_args()
# params for output
parser.add_argument("--output", type=str, default='./output/table')
# params for table structure
parser.add_argument("--table_max_len", type=int, default=488)
parser.add_argument("--table_model_dir", type=str)
parser.add_argument("--table_char_type", type=str, default='en')
parser.add_argument("--table_char_dict_path", type=str, default="../ppocr/utils/dict/table_structure_dict.txt")
return parser
def parse_args():
parser = init_args()
return parser.parse_args()
def draw_result(image, result, font_path):
if isinstance(image, np.ndarray):
image = Image.fromarray(image)
boxes, txts, scores = [], [], []
for region in result:
if region['type'] == 'Table':
pass
elif region['type'] == 'Figure':
pass
else:
for box, rec_res in zip(region['res'][0], region['res'][1]):
boxes.append(np.array(box).reshape(-1, 2))
txts.append(rec_res[0])
scores.append(rec_res[1])
im_show = draw_ocr_box_txt(image, boxes, txts, scores, font_path=font_path,drop_score=0)
return im_show
\ No newline at end of file
......@@ -44,12 +44,20 @@ def main():
# build model
# for rec algorithm
if hasattr(post_process_class, 'character'):
config['Architecture']["Head"]['out_channels'] = len(
getattr(post_process_class, 'character'))
char_num = len(getattr(post_process_class, 'character'))
if config['Architecture']["algorithm"] in ["Distillation",
]: # distillation model
for key in config['Architecture']["Models"]:
config['Architecture']["Models"][key]["Head"][
'out_channels'] = char_num
else: # base rec model
config['Architecture']["Head"]['out_channels'] = char_num
model = build_model(config['Architecture'])
use_srn = config['Architecture']['algorithm'] == "SRN"
model_type = config['Architecture']['model_type']
best_model_dict = init_model(config, model, logger)
best_model_dict = init_model(config, model)
if len(best_model_dict):
logger.info('metric in ckpt ***************')
for k, v in best_model_dict.items():
......@@ -60,7 +68,7 @@ def main():
# start eval
metric = program.eval(model, valid_dataloader, post_process_class,
eval_class, use_srn)
eval_class, model_type, use_srn)
logger.info('metric eval ***************')
for k, v in metric.items():
logger.info('{}:{}'.format(k, v))
......
......@@ -17,7 +17,7 @@ import sys
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
sys.path.append(os.path.abspath(os.path.join(__dir__, "..")))
import argparse
......@@ -31,32 +31,12 @@ from ppocr.utils.logging import get_logger
from tools.program import load_config, merge_config, ArgsParser
def main():
FLAGS = ArgsParser().parse_args()
config = load_config(FLAGS.config)
merge_config(FLAGS.opt)
logger = get_logger()
# build post process
post_process_class = build_post_process(config['PostProcess'],
config['Global'])
# build model
# for rec algorithm
if hasattr(post_process_class, 'character'):
char_num = len(getattr(post_process_class, 'character'))
config['Architecture']["Head"]['out_channels'] = char_num
model = build_model(config['Architecture'])
init_model(config, model, logger)
model.eval()
save_path = '{}/inference'.format(config['Global']['save_inference_dir'])
if config['Architecture']['algorithm'] == "SRN":
max_text_length = config['Architecture']['Head']['max_text_length']
def export_single_model(model, arch_config, save_path, logger):
if arch_config["algorithm"] == "SRN":
max_text_length = arch_config["Head"]["max_text_length"]
other_shape = [
paddle.static.InputSpec(
shape=[None, 1, 64, 256], dtype='float32'), [
shape=[None, 1, 64, 256], dtype="float32"), [
paddle.static.InputSpec(
shape=[None, 256, 1],
dtype="int64"), paddle.static.InputSpec(
......@@ -71,24 +51,67 @@ def main():
model = to_static(model, input_spec=other_shape)
else:
infer_shape = [3, -1, -1]
if config['Architecture']['model_type'] == "rec":
if arch_config["model_type"] == "rec":
infer_shape = [3, 32, -1] # for rec model, H must be 32
if 'Transform' in config['Architecture'] and config['Architecture'][
'Transform'] is not None and config['Architecture'][
'Transform']['name'] == 'TPS':
if "Transform" in arch_config and arch_config[
"Transform"] is not None and arch_config["Transform"][
"name"] == "TPS":
logger.info(
'When there is tps in the network, variable length input is not supported, and the input size needs to be the same as during training'
"When there is tps in the network, variable length input is not supported, and the input size needs to be the same as during training"
)
infer_shape[-1] = 100
elif arch_config["model_type"] == "table":
infer_shape = [3, 488, 488]
model = to_static(
model,
input_spec=[
paddle.static.InputSpec(
shape=[None] + infer_shape, dtype='float32')
shape=[None] + infer_shape, dtype="float32")
])
paddle.jit.save(model, save_path)
logger.info('inference model is saved to {}'.format(save_path))
logger.info("inference model is saved to {}".format(save_path))
return
def main():
FLAGS = ArgsParser().parse_args()
config = load_config(FLAGS.config)
merge_config(FLAGS.opt)
logger = get_logger()
# build post process
post_process_class = build_post_process(config["PostProcess"],
config["Global"])
# build model
# for rec algorithm
if hasattr(post_process_class, "character"):
char_num = len(getattr(post_process_class, "character"))
if config["Architecture"]["algorithm"] in ["Distillation",
]: # distillation model
for key in config["Architecture"]["Models"]:
config["Architecture"]["Models"][key]["Head"][
"out_channels"] = char_num
else: # base rec model
config["Architecture"]["Head"]["out_channels"] = char_num
model = build_model(config["Architecture"])
init_model(config, model)
model.eval()
save_path = config["Global"]["save_inference_dir"]
arch_config = config["Architecture"]
if arch_config["algorithm"] in ["Distillation", ]: # distillation model
archs = list(arch_config["Models"].values())
for idx, name in enumerate(model.model_name_list):
sub_model_save_path = os.path.join(save_path, name, "inference")
export_single_model(model.model_list[idx], archs[idx],
sub_model_save_path, logger)
else:
save_path = os.path.join(save_path, "inference")
export_single_model(model, arch_config, save_path, logger)
if __name__ == "__main__":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment