Commit 253b8453 authored by Leif's avatar Leif
Browse files

Merge remote-tracking branch 'origin/dygraph' into dygraph

parents 7cad4817 bc999986
...@@ -110,9 +110,9 @@ In ppocr, the network is divided into four stages: Transform, Backbone, Neck and ...@@ -110,9 +110,9 @@ In ppocr, the network is divided into four stages: Transform, Backbone, Neck and
| Parameter | Use | Defaults | Note | | Parameter | Use | Defaults | Note |
| :---------------------: | :---------------------: | :--------------: | :--------------------: | | :---------------------: | :---------------------: | :--------------: | :--------------------: |
| **dataset** | Return one sample per iteration | - | - | | **dataset** | Return one sample per iteration | - | - |
| name | dataset class name | SimpleDataSet | Currently support`SimpleDataSet`,`LMDBDateSet` | | name | dataset class name | SimpleDataSet | Currently support`SimpleDataSet`,`LMDBDataSet` |
| data_dir | Image folder path | ./train_data | \ | | data_dir | Image folder path | ./train_data | \ |
| label_file_list | Groundtruth file path | ["./train_data/train_list.txt"] | This parameter is not required when dataset is LMDBDateSet | | label_file_list | Groundtruth file path | ["./train_data/train_list.txt"] | This parameter is not required when dataset is LMDBDataSet |
| ratio_list | Ratio of data set | [1.0] | If there are two train_lists in label_file_list and ratio_list is [0.4,0.6], 40% will be sampled from train_list1, and 60% will be sampled from train_list2 to combine the entire dataset | | ratio_list | Ratio of data set | [1.0] | If there are two train_lists in label_file_list and ratio_list is [0.4,0.6], 40% will be sampled from train_list1, and 60% will be sampled from train_list2 to combine the entire dataset |
| transforms | List of methods to transform images and labels | [DecodeImage,CTCLabelEncode,RecResizeImg,KeepKeys] | see[ppocr/data/imaug](../../ppocr/data/imaug) | | transforms | List of methods to transform images and labels | [DecodeImage,CTCLabelEncode,RecResizeImg,KeepKeys] | see[ppocr/data/imaug](../../ppocr/data/imaug) |
| **loader** | dataloader related | - | | | **loader** | dataloader related | - | |
......
# Distributed training
## Introduction
The high performance of distributed training is one of the core advantages of PaddlePaddle. In the classification task, distributed training can achieve almost linear speedup ratio. Generally, OCR training task need massive training data. Such as recognition, ppocrv2.0 model is trained based on 1800W dataset, which is very time-consuming if using single machine. Therefore, the distributed training is used in paddleocr to speedup the training task. For more information about distributed training, please refer to [distributed training quick start tutorial](https://fleet-x.readthedocs.io/en/latest/paddle_fleet_rst/parameter_server/ps_quick_start.html).
## Quick Start
### Training with single machine
Take recognition as an example. After the data is prepared locally, start the training task with the interface of `paddle.distributed.launch`. The start command as follows:
```shell
python3 -m paddle.distributed.launch \
--log_dir=./log/ \
--gpus '0,1,2,3,4,5,6,7' \
tools/train.py \
-c configs/rec/rec_mv3_none_bilstm_ctc.yml
```
### Training with multi machine
Compared with single machine, training with multi machine only needs to add the parameter `--ips` to start command, which represents the IP list of machines used for distributed training, and the IP of different machines are separated by commas. The start command as follows:
```shell
ip_list="192.168.0.1,192.168.0.2"
python3 -m paddle.distributed.launch \
--log_dir=./log/ \
--ips="${ip_list}" \
--gpus="0,1,2,3,4,5,6,7" \
tools/train.py \
-c configs/rec/rec_mv3_none_bilstm_ctc.yml
```
**Notice:**
* The IP addresses of different machines need to be separated by commas, which can be queried through `ifconfig` or `ipconfig`.
* Different machines need to be set to be secret free and can `ping` success with others directly, otherwise communication cannot establish between them.
* The code, data and start command betweent different machines must be completely consistent and then all machines need to run start command. The first machine in the `ip_list` is set to `trainer0`, and so on.
## Performance comparison
* Based on 26W public recognition dataset (LSVT, rctw, mtwi), training on single 8-card P40 and dual 8-card P40, the final time consumption is as follows.
| Model | Config file | Number of machines | Number of GPUs per machine | Training time | Recognition acc | Speedup ratio |
| :-------: | :------------: | :----------------: | :----------------------------: | :------------------: | :--------------: | :-----------: |
| CRNN | configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml | 1 | 8 | 60h | 66.7% | - |
| CRNN | configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml | 2 | 8 | 40h | 67.0% | 150% |
It can be seen that the training time is shortened from 60h to 40h, the speedup ratio can reach 150% (60h / 40h), and the efficiency is 75% (60h / (40h * 2)).
...@@ -237,7 +237,7 @@ Optimizer: ...@@ -237,7 +237,7 @@ Optimizer:
Train: Train:
dataset: dataset:
# Type of dataset,we support LMDBDateSet and SimpleDataSet # Type of dataset,we support LMDBDataSet and SimpleDataSet
name: SimpleDataSet name: SimpleDataSet
# Path of dataset # Path of dataset
data_dir: ./train_data/ data_dir: ./train_data/
...@@ -257,7 +257,7 @@ Train: ...@@ -257,7 +257,7 @@ Train:
Eval: Eval:
dataset: dataset:
# Type of dataset,we support LMDBDateSet and SimpleDataSet # Type of dataset,we support LMDBDataSet and SimpleDataSet
name: SimpleDataSet name: SimpleDataSet
# Path of dataset # Path of dataset
data_dir: ./train_data data_dir: ./train_data
...@@ -394,7 +394,7 @@ Global: ...@@ -394,7 +394,7 @@ Global:
Train: Train:
dataset: dataset:
# Type of dataset,we support LMDBDateSet and SimpleDataSet # Type of dataset,we support LMDBDataSet and SimpleDataSet
name: SimpleDataSet name: SimpleDataSet
# Path of dataset # Path of dataset
data_dir: ./train_data/ data_dir: ./train_data/
...@@ -404,7 +404,7 @@ Train: ...@@ -404,7 +404,7 @@ Train:
Eval: Eval:
dataset: dataset:
# Type of dataset,we support LMDBDateSet and SimpleDataSet # Type of dataset,we support LMDBDataSet and SimpleDataSet
name: SimpleDataSet name: SimpleDataSet
# Path of dataset # Path of dataset
data_dir: ./train_data data_dir: ./train_data
......
...@@ -59,7 +59,7 @@ Visualization of results ...@@ -59,7 +59,7 @@ Visualization of results
from paddleocr import PaddleOCR,draw_ocr from paddleocr import PaddleOCR,draw_ocr
ocr = PaddleOCR(lang='en') # need to run only once to download and load model into memory ocr = PaddleOCR(lang='en') # need to run only once to download and load model into memory
img_path = 'PaddleOCR/doc/imgs_en/img_12.jpg' img_path = 'PaddleOCR/doc/imgs_en/img_12.jpg'
result = ocr.ocr(img_path) result = ocr.ocr(img_path, cls=False)
for line in result: for line in result:
print(line) print(line)
...@@ -362,3 +362,5 @@ im_show.save('result.jpg') ...@@ -362,3 +362,5 @@ im_show.save('result.jpg')
| det | Enable detction when `ppocr.ocr` func exec | TRUE | | det | Enable detction when `ppocr.ocr` func exec | TRUE |
| rec | Enable recognition when `ppocr.ocr` func exec | TRUE | | rec | Enable recognition when `ppocr.ocr` func exec | TRUE |
| cls | Enable classification when `ppocr.ocr` func exec((Use use_angle_cls in command line mode to control whether to start classification in the forward direction) | FALSE | | cls | Enable classification when `ppocr.ocr` func exec((Use use_angle_cls in command line mode to control whether to start classification in the forward direction) | FALSE |
| show_log | Whether to print log in det and rec
| FALSE |
\ No newline at end of file
...@@ -19,18 +19,17 @@ __dir__ = os.path.dirname(__file__) ...@@ -19,18 +19,17 @@ __dir__ = os.path.dirname(__file__)
sys.path.append(os.path.join(__dir__, '')) sys.path.append(os.path.join(__dir__, ''))
import cv2 import cv2
import logging
import numpy as np import numpy as np
from pathlib import Path from pathlib import Path
import tarfile
import requests
from tqdm import tqdm
from tools.infer import predict_system from tools.infer import predict_system
from ppocr.utils.logging import get_logger from ppocr.utils.logging import get_logger
logger = get_logger() logger = get_logger()
from ppocr.utils.utility import check_and_read_gif, get_image_file_list from ppocr.utils.utility import check_and_read_gif, get_image_file_list
from tools.infer.utility import draw_ocr from ppocr.utils.network import maybe_download, download_with_progressbar, is_link, confirm_model_dir_url
from tools.infer.utility import draw_ocr, init_args, str2bool
__all__ = ['PaddleOCR'] __all__ = ['PaddleOCR']
...@@ -123,150 +122,24 @@ SUPPORT_REC_MODEL = ['CRNN'] ...@@ -123,150 +122,24 @@ SUPPORT_REC_MODEL = ['CRNN']
BASE_DIR = os.path.expanduser("~/.paddleocr/") BASE_DIR = os.path.expanduser("~/.paddleocr/")
def download_with_progressbar(url, save_path): def parse_args(mMain=True):
response = requests.get(url, stream=True)
total_size_in_bytes = int(response.headers.get('content-length', 0))
block_size = 1024 # 1 Kibibyte
progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True)
with open(save_path, 'wb') as file:
for data in response.iter_content(block_size):
progress_bar.update(len(data))
file.write(data)
progress_bar.close()
if total_size_in_bytes == 0 or progress_bar.n != total_size_in_bytes:
logger.error("Something went wrong while downloading models")
sys.exit(0)
def maybe_download(model_storage_directory, url):
# using custom model
tar_file_name_list = [
'inference.pdiparams', 'inference.pdiparams.info', 'inference.pdmodel'
]
if not os.path.exists(
os.path.join(model_storage_directory, 'inference.pdiparams')
) or not os.path.exists(
os.path.join(model_storage_directory, 'inference.pdmodel')):
tmp_path = os.path.join(model_storage_directory, url.split('/')[-1])
print('download {} to {}'.format(url, tmp_path))
os.makedirs(model_storage_directory, exist_ok=True)
download_with_progressbar(url, tmp_path)
with tarfile.open(tmp_path, 'r') as tarObj:
for member in tarObj.getmembers():
filename = None
for tar_file_name in tar_file_name_list:
if tar_file_name in member.name:
filename = tar_file_name
if filename is None:
continue
file = tarObj.extractfile(member)
with open(
os.path.join(model_storage_directory, filename),
'wb') as f:
f.write(file.read())
os.remove(tmp_path)
def parse_args(mMain=True, add_help=True):
import argparse import argparse
parser = init_args()
def str2bool(v): parser.add_help = mMain
return v.lower() in ("true", "t", "1")
if mMain:
parser = argparse.ArgumentParser(add_help=add_help)
# params for prediction engine
parser.add_argument("--use_gpu", type=str2bool, default=True)
parser.add_argument("--ir_optim", type=str2bool, default=True)
parser.add_argument("--use_tensorrt", type=str2bool, default=False)
parser.add_argument("--gpu_mem", type=int, default=8000)
# params for text detector
parser.add_argument("--image_dir", type=str)
parser.add_argument("--det_algorithm", type=str, default='DB')
parser.add_argument("--det_model_dir", type=str, default=None)
parser.add_argument("--det_limit_side_len", type=float, default=960)
parser.add_argument("--det_limit_type", type=str, default='max')
# DB parmas
parser.add_argument("--det_db_thresh", type=float, default=0.3)
parser.add_argument("--det_db_box_thresh", type=float, default=0.5)
parser.add_argument("--det_db_unclip_ratio", type=float, default=1.6)
parser.add_argument("--use_dilation", type=bool, default=False)
parser.add_argument("--det_db_score_mode", type=str, default="fast")
# EAST parmas
parser.add_argument("--det_east_score_thresh", type=float, default=0.8)
parser.add_argument("--det_east_cover_thresh", type=float, default=0.1)
parser.add_argument("--det_east_nms_thresh", type=float, default=0.2)
# params for text recognizer
parser.add_argument("--rec_algorithm", type=str, default='CRNN')
parser.add_argument("--rec_model_dir", type=str, default=None)
parser.add_argument("--rec_image_shape", type=str, default="3, 32, 320")
parser.add_argument("--rec_char_type", type=str, default='ch')
parser.add_argument("--rec_batch_num", type=int, default=6)
parser.add_argument("--max_text_length", type=int, default=25)
parser.add_argument("--rec_char_dict_path", type=str, default=None)
parser.add_argument("--use_space_char", type=bool, default=True)
parser.add_argument("--drop_score", type=float, default=0.5)
# params for text classifier
parser.add_argument("--cls_model_dir", type=str, default=None)
parser.add_argument("--cls_image_shape", type=str, default="3, 48, 192")
parser.add_argument("--label_list", type=list, default=['0', '180'])
parser.add_argument("--cls_batch_num", type=int, default=6)
parser.add_argument("--cls_thresh", type=float, default=0.9)
parser.add_argument("--enable_mkldnn", type=bool, default=False)
parser.add_argument("--use_zero_copy_run", type=bool, default=False)
parser.add_argument("--use_pdserving", type=str2bool, default=False)
parser.add_argument("--lang", type=str, default='ch') parser.add_argument("--lang", type=str, default='ch')
parser.add_argument("--det", type=str2bool, default=True) parser.add_argument("--det", type=str2bool, default=True)
parser.add_argument("--rec", type=str2bool, default=True) parser.add_argument("--rec", type=str2bool, default=True)
parser.add_argument("--use_angle_cls", type=str2bool, default=False)
for action in parser._actions:
if action.dest == 'rec_char_dict_path':
action.default = None
if mMain:
return parser.parse_args() return parser.parse_args()
else: else:
return argparse.Namespace( inference_args_dict = {}
use_gpu=True, for action in parser._actions:
ir_optim=True, inference_args_dict[action.dest] = action.default
use_tensorrt=False, return argparse.Namespace(**inference_args_dict)
gpu_mem=8000,
image_dir='',
det_algorithm='DB',
det_model_dir=None,
det_limit_side_len=960,
det_limit_type='max',
det_db_thresh=0.3,
det_db_box_thresh=0.5,
det_db_unclip_ratio=1.6,
use_dilation=False,
det_db_score_mode="fast",
det_east_score_thresh=0.8,
det_east_cover_thresh=0.1,
det_east_nms_thresh=0.2,
rec_algorithm='CRNN',
rec_model_dir=None,
rec_image_shape="3, 32, 320",
rec_char_type='ch',
rec_batch_num=6,
max_text_length=25,
rec_char_dict_path=None,
use_space_char=True,
drop_score=0.5,
cls_model_dir=None,
cls_image_shape="3, 48, 192",
label_list=['0', '180'],
cls_batch_num=6,
cls_thresh=0.9,
enable_mkldnn=False,
use_zero_copy_run=False,
use_pdserving=False,
lang='ch',
det=True,
rec=True,
use_angle_cls=False)
class PaddleOCR(predict_system.TextSystem): class PaddleOCR(predict_system.TextSystem):
...@@ -276,10 +149,12 @@ class PaddleOCR(predict_system.TextSystem): ...@@ -276,10 +149,12 @@ class PaddleOCR(predict_system.TextSystem):
args: args:
**kwargs: other params show in paddleocr --help **kwargs: other params show in paddleocr --help
""" """
postprocess_params = parse_args(mMain=False, add_help=False) params = parse_args(mMain=False)
postprocess_params.__dict__.update(**kwargs) params.__dict__.update(**kwargs)
self.use_angle_cls = postprocess_params.use_angle_cls if not params.show_log:
lang = postprocess_params.lang logger.setLevel(logging.INFO)
self.use_angle_cls = params.use_angle_cls
lang = params.lang
latin_lang = [ latin_lang = [
'af', 'az', 'bs', 'cs', 'cy', 'da', 'de', 'es', 'et', 'fr', 'ga', 'af', 'az', 'bs', 'cs', 'cy', 'da', 'de', 'es', 'et', 'fr', 'ga',
'hr', 'hu', 'id', 'is', 'it', 'ku', 'la', 'lt', 'lv', 'mi', 'ms', 'hr', 'hu', 'id', 'is', 'it', 'ku', 'la', 'lt', 'lv', 'mi', 'ms',
...@@ -311,42 +186,41 @@ class PaddleOCR(predict_system.TextSystem): ...@@ -311,42 +186,41 @@ class PaddleOCR(predict_system.TextSystem):
else: else:
det_lang = "en" det_lang = "en"
use_inner_dict = False use_inner_dict = False
if postprocess_params.rec_char_dict_path is None: if params.rec_char_dict_path is None:
use_inner_dict = True use_inner_dict = True
postprocess_params.rec_char_dict_path = model_urls['rec'][lang][ params.rec_char_dict_path = model_urls['rec'][lang][
'dict_path'] 'dict_path']
# init model dir # init model dir
if postprocess_params.det_model_dir is None: params.det_model_dir, det_url = confirm_model_dir_url(params.det_model_dir,
postprocess_params.det_model_dir = os.path.join(BASE_DIR, VERSION, os.path.join(BASE_DIR, VERSION, 'det', det_lang),
'det', det_lang)
if postprocess_params.rec_model_dir is None:
postprocess_params.rec_model_dir = os.path.join(BASE_DIR, VERSION,
'rec', lang)
if postprocess_params.cls_model_dir is None:
postprocess_params.cls_model_dir = os.path.join(BASE_DIR, 'cls')
print(postprocess_params)
# download model
maybe_download(postprocess_params.det_model_dir,
model_urls['det'][det_lang]) model_urls['det'][det_lang])
maybe_download(postprocess_params.rec_model_dir, params.rec_model_dir, rec_url = confirm_model_dir_url(params.rec_model_dir,
os.path.join(BASE_DIR, VERSION, 'rec', lang),
model_urls['rec'][lang]['url']) model_urls['rec'][lang]['url'])
maybe_download(postprocess_params.cls_model_dir, model_urls['cls']) params.cls_model_dir, cls_url = confirm_model_dir_url(params.cls_model_dir,
os.path.join(BASE_DIR, VERSION, 'cls'),
model_urls['cls'])
# download model
maybe_download(params.det_model_dir, det_url)
maybe_download(params.rec_model_dir, rec_url)
maybe_download(params.cls_model_dir, cls_url)
if postprocess_params.det_algorithm not in SUPPORT_DET_MODEL: if params.det_algorithm not in SUPPORT_DET_MODEL:
logger.error('det_algorithm must in {}'.format(SUPPORT_DET_MODEL)) logger.error('det_algorithm must in {}'.format(SUPPORT_DET_MODEL))
sys.exit(0) sys.exit(0)
if postprocess_params.rec_algorithm not in SUPPORT_REC_MODEL: if params.rec_algorithm not in SUPPORT_REC_MODEL:
logger.error('rec_algorithm must in {}'.format(SUPPORT_REC_MODEL)) logger.error('rec_algorithm must in {}'.format(SUPPORT_REC_MODEL))
sys.exit(0) sys.exit(0)
if use_inner_dict: if use_inner_dict:
postprocess_params.rec_char_dict_path = str( params.rec_char_dict_path = str(
Path(__file__).parent / postprocess_params.rec_char_dict_path) Path(__file__).parent / params.rec_char_dict_path)
print(params)
# init det_model and rec_model # init det_model and rec_model
super().__init__(postprocess_params) super().__init__(params)
def ocr(self, img, det=True, rec=True, cls=False): def ocr(self, img, det=True, rec=True, cls=True):
""" """
ocr with paddleocr ocr with paddleocr
args: args:
...@@ -358,9 +232,7 @@ class PaddleOCR(predict_system.TextSystem): ...@@ -358,9 +232,7 @@ class PaddleOCR(predict_system.TextSystem):
if isinstance(img, list) and det == True: if isinstance(img, list) and det == True:
logger.error('When input a list of images, det must be false') logger.error('When input a list of images, det must be false')
exit(0) exit(0)
if cls == False: if cls == True and self.use_angle_cls == False:
self.use_angle_cls = False
elif cls == True and self.use_angle_cls == False:
logger.warning( logger.warning(
'Since the angle classifier is not initialized, the angle classifier will not be uesd during the forward process' 'Since the angle classifier is not initialized, the angle classifier will not be uesd during the forward process'
) )
...@@ -382,7 +254,7 @@ class PaddleOCR(predict_system.TextSystem): ...@@ -382,7 +254,7 @@ class PaddleOCR(predict_system.TextSystem):
if isinstance(img, np.ndarray) and len(img.shape) == 2: if isinstance(img, np.ndarray) and len(img.shape) == 2:
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
if det and rec: if det and rec:
dt_boxes, rec_res = self.__call__(img) dt_boxes, rec_res = self.__call__(img, cls)
return [[box.tolist(), res] for box, res in zip(dt_boxes, rec_res)] return [[box.tolist(), res] for box, res in zip(dt_boxes, rec_res)]
elif det and not rec: elif det and not rec:
dt_boxes, elapse = self.text_detector(img) dt_boxes, elapse = self.text_detector(img)
...@@ -392,7 +264,7 @@ class PaddleOCR(predict_system.TextSystem): ...@@ -392,7 +264,7 @@ class PaddleOCR(predict_system.TextSystem):
else: else:
if not isinstance(img, list): if not isinstance(img, list):
img = [img] img = [img]
if self.use_angle_cls: if self.use_angle_cls and cls:
img, cls_res, elapse = self.text_classifier(img) img, cls_res, elapse = self.text_classifier(img)
if not rec: if not rec:
return cls_res return cls_res
...@@ -404,7 +276,7 @@ def main(): ...@@ -404,7 +276,7 @@ def main():
# for cmd # for cmd
args = parse_args(mMain=True) args = parse_args(mMain=True)
image_dir = args.image_dir image_dir = args.image_dir
if image_dir.startswith('http'): if is_link(image_dir):
download_with_progressbar(image_dir, 'tmp.jpg') download_with_progressbar(image_dir, 'tmp.jpg')
image_file_list = ['tmp.jpg'] image_file_list = ['tmp.jpg']
else: else:
......
...@@ -35,6 +35,7 @@ from ppocr.data.imaug import transform, create_operators ...@@ -35,6 +35,7 @@ from ppocr.data.imaug import transform, create_operators
from ppocr.data.simple_dataset import SimpleDataSet from ppocr.data.simple_dataset import SimpleDataSet
from ppocr.data.lmdb_dataset import LMDBDataSet from ppocr.data.lmdb_dataset import LMDBDataSet
from ppocr.data.pgnet_dataset import PGDataSet from ppocr.data.pgnet_dataset import PGDataSet
from ppocr.data.pubtab_dataset import PubTabDataSet
__all__ = ['build_dataloader', 'transform', 'create_operators'] __all__ = ['build_dataloader', 'transform', 'create_operators']
...@@ -55,7 +56,7 @@ signal.signal(signal.SIGTERM, term_mp) ...@@ -55,7 +56,7 @@ signal.signal(signal.SIGTERM, term_mp)
def build_dataloader(config, mode, device, logger, seed=None): def build_dataloader(config, mode, device, logger, seed=None):
config = copy.deepcopy(config) config = copy.deepcopy(config)
support_dict = ['SimpleDataSet', 'LMDBDataSet', 'PGDataSet'] support_dict = ['SimpleDataSet', 'LMDBDataSet', 'PGDataSet', 'PubTabDataSet']
module_name = config[mode]['dataset']['name'] module_name = config[mode]['dataset']['name']
assert module_name in support_dict, Exception( assert module_name in support_dict, Exception(
'DataSet only support {}'.format(support_dict)) 'DataSet only support {}'.format(support_dict))
......
...@@ -29,6 +29,7 @@ from .label_ops import * ...@@ -29,6 +29,7 @@ from .label_ops import *
from .east_process import * from .east_process import *
from .sast_process import * from .sast_process import *
from .pg_process import * from .pg_process import *
from .gen_table_mask import *
def transform(data, ops=None): def transform(data, ops=None):
......
"""
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import sys
import six
import cv2
import numpy as np
class GenTableMask(object):
""" gen table mask """
def __init__(self, shrink_h_max, shrink_w_max, mask_type=0, **kwargs):
self.shrink_h_max = 5
self.shrink_w_max = 5
self.mask_type = mask_type
def projection(self, erosion, h, w, spilt_threshold=0):
# 水平投影
projection_map = np.ones_like(erosion)
project_val_array = [0 for _ in range(0, h)]
for j in range(0, h):
for i in range(0, w):
if erosion[j, i] == 255:
project_val_array[j] += 1
# 根据数组,获取切割点
start_idx = 0 # 记录进入字符区的索引
end_idx = 0 # 记录进入空白区域的索引
in_text = False # 是否遍历到了字符区内
box_list = []
for i in range(len(project_val_array)):
if in_text == False and project_val_array[i] > spilt_threshold: # 进入字符区了
in_text = True
start_idx = i
elif project_val_array[i] <= spilt_threshold and in_text == True: # 进入空白区了
end_idx = i
in_text = False
if end_idx - start_idx <= 2:
continue
box_list.append((start_idx, end_idx + 1))
if in_text:
box_list.append((start_idx, h - 1))
# 绘制投影直方图
for j in range(0, h):
for i in range(0, project_val_array[j]):
projection_map[j, i] = 0
return box_list, projection_map
def projection_cx(self, box_img):
box_gray_img = cv2.cvtColor(box_img, cv2.COLOR_BGR2GRAY)
h, w = box_gray_img.shape
# 灰度图片进行二值化处理
ret, thresh1 = cv2.threshold(box_gray_img, 200, 255, cv2.THRESH_BINARY_INV)
# 纵向腐蚀
if h < w:
kernel = np.ones((2, 1), np.uint8)
erode = cv2.erode(thresh1, kernel, iterations=1)
else:
erode = thresh1
# 水平膨胀
kernel = np.ones((1, 5), np.uint8)
erosion = cv2.dilate(erode, kernel, iterations=1)
# 水平投影
projection_map = np.ones_like(erosion)
project_val_array = [0 for _ in range(0, h)]
for j in range(0, h):
for i in range(0, w):
if erosion[j, i] == 255:
project_val_array[j] += 1
# 根据数组,获取切割点
start_idx = 0 # 记录进入字符区的索引
end_idx = 0 # 记录进入空白区域的索引
in_text = False # 是否遍历到了字符区内
box_list = []
spilt_threshold = 0
for i in range(len(project_val_array)):
if in_text == False and project_val_array[i] > spilt_threshold: # 进入字符区了
in_text = True
start_idx = i
elif project_val_array[i] <= spilt_threshold and in_text == True: # 进入空白区了
end_idx = i
in_text = False
if end_idx - start_idx <= 2:
continue
box_list.append((start_idx, end_idx + 1))
if in_text:
box_list.append((start_idx, h - 1))
# 绘制投影直方图
for j in range(0, h):
for i in range(0, project_val_array[j]):
projection_map[j, i] = 0
split_bbox_list = []
if len(box_list) > 1:
for i, (h_start, h_end) in enumerate(box_list):
if i == 0:
h_start = 0
if i == len(box_list):
h_end = h
word_img = erosion[h_start:h_end + 1, :]
word_h, word_w = word_img.shape
w_split_list, w_projection_map = self.projection(word_img.T, word_w, word_h)
w_start, w_end = w_split_list[0][0], w_split_list[-1][1]
if h_start > 0:
h_start -= 1
h_end += 1
word_img = box_img[h_start:h_end + 1:, w_start:w_end + 1, :]
split_bbox_list.append([w_start, h_start, w_end, h_end])
else:
split_bbox_list.append([0, 0, w, h])
return split_bbox_list
def shrink_bbox(self, bbox):
left, top, right, bottom = bbox
sh_h = min(max(int((bottom - top) * 0.1), 1), self.shrink_h_max)
sh_w = min(max(int((right - left) * 0.1), 1), self.shrink_w_max)
left_new = left + sh_w
right_new = right - sh_w
top_new = top + sh_h
bottom_new = bottom - sh_h
if left_new >= right_new:
left_new = left
right_new = right
if top_new >= bottom_new:
top_new = top
bottom_new = bottom
return [left_new, top_new, right_new, bottom_new]
def __call__(self, data):
img = data['image']
cells = data['cells']
height, width = img.shape[0:2]
if self.mask_type == 1:
mask_img = np.zeros((height, width), dtype=np.float32)
else:
mask_img = np.zeros((height, width, 3), dtype=np.float32)
cell_num = len(cells)
for cno in range(cell_num):
if "bbox" in cells[cno]:
bbox = cells[cno]['bbox']
left, top, right, bottom = bbox
box_img = img[top:bottom, left:right, :].copy()
split_bbox_list = self.projection_cx(box_img)
for sno in range(len(split_bbox_list)):
split_bbox_list[sno][0] += left
split_bbox_list[sno][1] += top
split_bbox_list[sno][2] += left
split_bbox_list[sno][3] += top
for sno in range(len(split_bbox_list)):
left, top, right, bottom = split_bbox_list[sno]
left, top, right, bottom = self.shrink_bbox([left, top, right, bottom])
if self.mask_type == 1:
mask_img[top:bottom, left:right] = 1.0
data['mask_img'] = mask_img
else:
mask_img[top:bottom, left:right, :] = (255, 255, 255)
data['image'] = mask_img
return data
class ResizeTableImage(object):
def __init__(self, max_len, **kwargs):
super(ResizeTableImage, self).__init__()
self.max_len = max_len
def get_img_bbox(self, cells):
bbox_list = []
if len(cells) == 0:
return bbox_list
cell_num = len(cells)
for cno in range(cell_num):
if "bbox" in cells[cno]:
bbox = cells[cno]['bbox']
bbox_list.append(bbox)
return bbox_list
def resize_img_table(self, img, bbox_list, max_len):
height, width = img.shape[0:2]
ratio = max_len / (max(height, width) * 1.0)
resize_h = int(height * ratio)
resize_w = int(width * ratio)
img_new = cv2.resize(img, (resize_w, resize_h))
bbox_list_new = []
for bno in range(len(bbox_list)):
left, top, right, bottom = bbox_list[bno].copy()
left = int(left * ratio)
top = int(top * ratio)
right = int(right * ratio)
bottom = int(bottom * ratio)
bbox_list_new.append([left, top, right, bottom])
return img_new, bbox_list_new
def __call__(self, data):
img = data['image']
if 'cells' not in data:
cells = []
else:
cells = data['cells']
bbox_list = self.get_img_bbox(cells)
img_new, bbox_list_new = self.resize_img_table(img, bbox_list, self.max_len)
data['image'] = img_new
cell_num = len(cells)
bno = 0
for cno in range(cell_num):
if "bbox" in data['cells'][cno]:
data['cells'][cno]['bbox'] = bbox_list_new[bno]
bno += 1
data['max_len'] = self.max_len
return data
class PaddingTableImage(object):
def __init__(self, **kwargs):
super(PaddingTableImage, self).__init__()
def __call__(self, data):
img = data['image']
max_len = data['max_len']
padding_img = np.zeros((max_len, max_len, 3), dtype=np.float32)
height, width = img.shape[0:2]
padding_img[0:height, 0:width, :] = img.copy()
data['image'] = padding_img
return data
\ No newline at end of file
...@@ -351,3 +351,162 @@ class SRNLabelEncode(BaseRecLabelEncode): ...@@ -351,3 +351,162 @@ class SRNLabelEncode(BaseRecLabelEncode):
assert False, "Unsupport type %s in get_beg_end_flag_idx" \ assert False, "Unsupport type %s in get_beg_end_flag_idx" \
% beg_or_end % beg_or_end
return idx return idx
class TableLabelEncode(object):
""" Convert between text-label and text-index """
def __init__(self,
max_text_length,
max_elem_length,
max_cell_num,
character_dict_path,
span_weight = 1.0,
**kwargs):
self.max_text_length = max_text_length
self.max_elem_length = max_elem_length
self.max_cell_num = max_cell_num
list_character, list_elem = self.load_char_elem_dict(character_dict_path)
list_character = self.add_special_char(list_character)
list_elem = self.add_special_char(list_elem)
self.dict_character = {}
for i, char in enumerate(list_character):
self.dict_character[char] = i
self.dict_elem = {}
for i, elem in enumerate(list_elem):
self.dict_elem[elem] = i
self.span_weight = span_weight
def load_char_elem_dict(self, character_dict_path):
list_character = []
list_elem = []
with open(character_dict_path, "rb") as fin:
lines = fin.readlines()
substr = lines[0].decode('utf-8').strip("\n").split("\t")
character_num = int(substr[0])
elem_num = int(substr[1])
for cno in range(1, 1+character_num):
character = lines[cno].decode('utf-8').strip("\n")
list_character.append(character)
for eno in range(1+character_num, 1+character_num+elem_num):
elem = lines[eno].decode('utf-8').strip("\n")
list_elem.append(elem)
return list_character, list_elem
def add_special_char(self, list_character):
self.beg_str = "sos"
self.end_str = "eos"
list_character = [self.beg_str] + list_character + [self.end_str]
return list_character
def get_span_idx_list(self):
span_idx_list = []
for elem in self.dict_elem:
if 'span' in elem:
span_idx_list.append(self.dict_elem[elem])
return span_idx_list
def __call__(self, data):
cells = data['cells']
structure = data['structure']['tokens']
structure = self.encode(structure, 'elem')
if structure is None:
return None
elem_num = len(structure)
structure = [0] + structure + [len(self.dict_elem) - 1]
structure = structure + [0] * (self.max_elem_length + 2 - len(structure))
structure = np.array(structure)
data['structure'] = structure
elem_char_idx1 = self.dict_elem['<td>']
elem_char_idx2 = self.dict_elem['<td']
span_idx_list = self.get_span_idx_list()
td_idx_list = np.logical_or(structure == elem_char_idx1, structure == elem_char_idx2)
td_idx_list = np.where(td_idx_list)[0]
structure_mask = np.ones((self.max_elem_length + 2, 1), dtype=np.float32)
bbox_list = np.zeros((self.max_elem_length + 2, 4), dtype=np.float32)
bbox_list_mask = np.zeros((self.max_elem_length + 2, 1), dtype=np.float32)
img_height, img_width, img_ch = data['image'].shape
if len(span_idx_list) > 0:
span_weight = len(td_idx_list) * 1.0 / len(span_idx_list)
span_weight = min(max(span_weight, 1.0), self.span_weight)
for cno in range(len(cells)):
if 'bbox' in cells[cno]:
bbox = cells[cno]['bbox'].copy()
bbox[0] = bbox[0] * 1.0 / img_width
bbox[1] = bbox[1] * 1.0 / img_height
bbox[2] = bbox[2] * 1.0 / img_width
bbox[3] = bbox[3] * 1.0 / img_height
td_idx = td_idx_list[cno]
bbox_list[td_idx] = bbox
bbox_list_mask[td_idx] = 1.0
cand_span_idx = td_idx + 1
if cand_span_idx < (self.max_elem_length + 2):
if structure[cand_span_idx] in span_idx_list:
structure_mask[cand_span_idx] = span_weight
data['bbox_list'] = bbox_list
data['bbox_list_mask'] = bbox_list_mask
data['structure_mask'] = structure_mask
char_beg_idx = self.get_beg_end_flag_idx('beg', 'char')
char_end_idx = self.get_beg_end_flag_idx('end', 'char')
elem_beg_idx = self.get_beg_end_flag_idx('beg', 'elem')
elem_end_idx = self.get_beg_end_flag_idx('end', 'elem')
data['sp_tokens'] = np.array([char_beg_idx, char_end_idx, elem_beg_idx,
elem_end_idx, elem_char_idx1, elem_char_idx2, self.max_text_length,
self.max_elem_length, self.max_cell_num, elem_num])
return data
def encode(self, text, char_or_elem):
"""convert text-label into text-index.
"""
if char_or_elem == "char":
max_len = self.max_text_length
current_dict = self.dict_character
else:
max_len = self.max_elem_length
current_dict = self.dict_elem
if len(text) > max_len:
return None
if len(text) == 0:
if char_or_elem == "char":
return [self.dict_character['space']]
else:
return None
text_list = []
for char in text:
if char not in current_dict:
return None
text_list.append(current_dict[char])
if len(text_list) == 0:
if char_or_elem == "char":
return [self.dict_character['space']]
else:
return None
return text_list
def get_ignored_tokens(self, char_or_elem):
beg_idx = self.get_beg_end_flag_idx("beg", char_or_elem)
end_idx = self.get_beg_end_flag_idx("end", char_or_elem)
return [beg_idx, end_idx]
def get_beg_end_flag_idx(self, beg_or_end, char_or_elem):
if char_or_elem == "char":
if beg_or_end == "beg":
idx = np.array(self.dict_character[self.beg_str])
elif beg_or_end == "end":
idx = np.array(self.dict_character[self.end_str])
else:
assert False, "Unsupport type %s in get_beg_end_flag_idx of char" \
% beg_or_end
elif char_or_elem == "elem":
if beg_or_end == "beg":
idx = np.array(self.dict_elem[self.beg_str])
elif beg_or_end == "end":
idx = np.array(self.dict_elem[self.end_str])
else:
assert False, "Unsupport type %s in get_beg_end_flag_idx of elem" \
% beg_or_end
else:
assert False, "Unsupport type %s in char_or_elem" \
% char_or_elem
return idx
\ No newline at end of file
...@@ -163,7 +163,7 @@ class DetResizeForTest(object): ...@@ -163,7 +163,7 @@ class DetResizeForTest(object):
img, (ratio_h, ratio_w) img, (ratio_h, ratio_w)
""" """
limit_side_len = self.limit_side_len limit_side_len = self.limit_side_len
h, w, _ = img.shape h, w, c = img.shape
# limit the max side # limit the max side
if self.limit_type == 'max': if self.limit_type == 'max':
...@@ -174,7 +174,7 @@ class DetResizeForTest(object): ...@@ -174,7 +174,7 @@ class DetResizeForTest(object):
ratio = float(limit_side_len) / w ratio = float(limit_side_len) / w
else: else:
ratio = 1. ratio = 1.
else: elif self.limit_type == 'min':
if min(h, w) < limit_side_len: if min(h, w) < limit_side_len:
if h < w: if h < w:
ratio = float(limit_side_len) / h ratio = float(limit_side_len) / h
...@@ -182,6 +182,10 @@ class DetResizeForTest(object): ...@@ -182,6 +182,10 @@ class DetResizeForTest(object):
ratio = float(limit_side_len) / w ratio = float(limit_side_len) / w
else: else:
ratio = 1. ratio = 1.
elif self.limit_type == 'resize_long':
ratio = float(limit_side_len) / max(h,w)
else:
raise Exception('not support limit type, image ')
resize_h = int(h * ratio) resize_h = int(h * ratio)
resize_w = int(w * ratio) resize_w = int(w * ratio)
......
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import os
import random
from paddle.io import Dataset
import json
from .imaug import transform, create_operators
class PubTabDataSet(Dataset):
def __init__(self, config, mode, logger, seed=None):
super(PubTabDataSet, self).__init__()
self.logger = logger
global_config = config['Global']
dataset_config = config[mode]['dataset']
loader_config = config[mode]['loader']
label_file_path = dataset_config.pop('label_file_path')
self.data_dir = dataset_config['data_dir']
self.do_shuffle = loader_config['shuffle']
self.do_hard_select = False
if 'hard_select' in loader_config:
self.do_hard_select = loader_config['hard_select']
self.hard_prob = loader_config['hard_prob']
if self.do_hard_select:
self.img_select_prob = self.load_hard_select_prob()
self.table_select_type = None
if 'table_select_type' in loader_config:
self.table_select_type = loader_config['table_select_type']
self.table_select_prob = loader_config['table_select_prob']
self.seed = seed
logger.info("Initialize indexs of datasets:%s" % label_file_path)
with open(label_file_path, "rb") as f:
self.data_lines = f.readlines()
self.data_idx_order_list = list(range(len(self.data_lines)))
if mode.lower() == "train":
self.shuffle_data_random()
self.ops = create_operators(dataset_config['transforms'], global_config)
def shuffle_data_random(self):
if self.do_shuffle:
random.seed(self.seed)
random.shuffle(self.data_lines)
return
def __getitem__(self, idx):
try:
data_line = self.data_lines[idx]
data_line = data_line.decode('utf-8').strip("\n")
info = json.loads(data_line)
file_name = info['filename']
select_flag = True
if self.do_hard_select:
prob = self.img_select_prob[file_name]
if prob < random.uniform(0, 1):
select_flag = False
if self.table_select_type:
structure = info['html']['structure']['tokens'].copy()
structure_str = ''.join(structure)
table_type = "simple"
if 'colspan' in structure_str or 'rowspan' in structure_str:
table_type = "complex"
if table_type == "complex":
if self.table_select_prob < random.uniform(0, 1):
select_flag = False
if select_flag:
cells = info['html']['cells'].copy()
structure = info['html']['structure'].copy()
img_path = os.path.join(self.data_dir, file_name)
data = {'img_path': img_path, 'cells': cells, 'structure':structure}
if not os.path.exists(img_path):
raise Exception("{} does not exist!".format(img_path))
with open(data['img_path'], 'rb') as f:
img = f.read()
data['image'] = img
outs = transform(data, self.ops)
else:
outs = None
except Exception as e:
self.logger.error(
"When parsing line {}, error happened with msg: {}".format(
data_line, e))
outs = None
if outs is None:
return self.__getitem__(np.random.randint(self.__len__()))
return outs
def __len__(self):
return len(self.data_idx_order_list)
...@@ -13,28 +13,39 @@ ...@@ -13,28 +13,39 @@
# limitations under the License. # limitations under the License.
import copy import copy
import paddle
import paddle.nn as nn
# det loss
from .det_db_loss import DBLoss
from .det_east_loss import EASTLoss
from .det_sast_loss import SASTLoss
def build_loss(config): # rec loss
# det loss from .rec_ctc_loss import CTCLoss
from .det_db_loss import DBLoss from .rec_att_loss import AttentionLoss
from .det_east_loss import EASTLoss from .rec_srn_loss import SRNLoss
from .det_sast_loss import SASTLoss
# cls loss
from .cls_loss import ClsLoss
# e2e loss
from .e2e_pg_loss import PGLoss
# rec loss # basic loss function
from .rec_ctc_loss import CTCLoss from .basic_loss import DistanceLoss
from .rec_att_loss import AttentionLoss
from .rec_srn_loss import SRNLoss
# cls loss # combined loss function
from .cls_loss import ClsLoss from .combined_loss import CombinedLoss
# e2e loss # table loss
from .e2e_pg_loss import PGLoss from .table_att_loss import TableAttentionLoss
def build_loss(config):
support_dict = [ support_dict = [
'DBLoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss', 'AttentionLoss', 'DBLoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss', 'AttentionLoss',
'SRNLoss', 'PGLoss'] 'SRNLoss', 'PGLoss', 'CombinedLoss', 'TableAttentionLoss'
]
config = copy.deepcopy(config) config = copy.deepcopy(config)
module_name = config.pop('name') module_name = config.pop('name')
assert module_name in support_dict, Exception('loss only support {}'.format( assert module_name in support_dict, Exception('loss only support {}'.format(
......
#copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.nn import L1Loss
from paddle.nn import MSELoss as L2Loss
from paddle.nn import SmoothL1Loss
class CELoss(nn.Layer):
def __init__(self, epsilon=None):
super().__init__()
if epsilon is not None and (epsilon <= 0 or epsilon >= 1):
epsilon = None
self.epsilon = epsilon
def _labelsmoothing(self, target, class_num):
if target.shape[-1] != class_num:
one_hot_target = F.one_hot(target, class_num)
else:
one_hot_target = target
soft_target = F.label_smooth(one_hot_target, epsilon=self.epsilon)
soft_target = paddle.reshape(soft_target, shape=[-1, class_num])
return soft_target
def forward(self, x, label):
loss_dict = {}
if self.epsilon is not None:
class_num = x.shape[-1]
label = self._labelsmoothing(label, class_num)
x = -F.log_softmax(x, axis=-1)
loss = paddle.sum(x * label, axis=-1)
else:
if label.shape[-1] == x.shape[-1]:
label = F.softmax(label, axis=-1)
soft_label = True
else:
soft_label = False
loss = F.cross_entropy(x, label=label, soft_label=soft_label)
return loss
class DMLLoss(nn.Layer):
"""
DMLLoss
"""
def __init__(self, act=None):
super().__init__()
if act is not None:
assert act in ["softmax", "sigmoid"]
if act == "softmax":
self.act = nn.Softmax(axis=-1)
elif act == "sigmoid":
self.act = nn.Sigmoid()
else:
self.act = None
def forward(self, out1, out2):
if self.act is not None:
out1 = self.act(out1)
out2 = self.act(out2)
log_out1 = paddle.log(out1)
log_out2 = paddle.log(out2)
loss = (F.kl_div(
log_out1, out2, reduction='batchmean') + F.kl_div(
log_out2, out1, reduction='batchmean')) / 2.0
return loss
class DistanceLoss(nn.Layer):
"""
DistanceLoss:
mode: loss mode
"""
def __init__(self, mode="l2", **kargs):
super().__init__()
assert mode in ["l1", "l2", "smooth_l1"]
if mode == "l1":
self.loss_func = nn.L1Loss(**kargs)
elif mode == "l2":
self.loss_func = nn.MSELoss(**kargs)
elif mode == "smooth_l1":
self.loss_func = nn.SmoothL1Loss(**kargs)
def forward(self, x, y):
return self.loss_func(x, y)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment