Commit ac8c2a89 authored by WenmuZhou's avatar WenmuZhou
Browse files

merge dygraph

parents 88a8be12 e174e9ed
include LICENSE.txt include LICENSE
include README.md include README.md
recursive-include ppocr/utils *.txt utility.py logging.py recursive-include ppocr/utils *.txt utility.py logging.py network.py
recursive-include ppocr/data/ *.py recursive-include ppocr/data/ *.py
recursive-include ppocr/postprocess *.py recursive-include ppocr/postprocess *.py
recursive-include tools/infer *.py recursive-include tools/infer *.py
......
...@@ -52,9 +52,10 @@ Architecture: ...@@ -52,9 +52,10 @@ Architecture:
Neck: Neck:
name: SequenceEncoder name: SequenceEncoder
encoder_type: rnn encoder_type: rnn
hidden_size: 48 hidden_size: 64
Head: Head:
name: CTCHead name: CTCHead
mid_channels: 96
fc_decay: 0.00001 fc_decay: 0.00001
Teacher: Teacher:
pretrained: pretrained:
...@@ -71,9 +72,10 @@ Architecture: ...@@ -71,9 +72,10 @@ Architecture:
Neck: Neck:
name: SequenceEncoder name: SequenceEncoder
encoder_type: rnn encoder_type: rnn
hidden_size: 48 hidden_size: 64
Head: Head:
name: CTCHead name: CTCHead
mid_channels: 96
fc_decay: 0.00001 fc_decay: 0.00001
......
Global:
use_gpu: true
epoch_num: 50
log_smooth_window: 20
print_batch_step: 5
save_model_dir: ./output/table_mv3/
save_epoch_step: 5
# evaluation is run every 400 iterations after the 0th iteration
eval_batch_step: [0, 400]
cal_metric_during_train: True
pretrained_model:
checkpoints:
save_inference_dir:
use_visualdl: False
infer_img: doc/imgs_words/ch/word_1.jpg
# for data or label process
character_dict_path: ppocr/utils/dict/table_structure_dict.txt
character_type: en
max_text_length: 100
max_elem_length: 500
max_cell_num: 500
infer_mode: False
process_total_num: 0
process_cut_num: 0
Optimizer:
name: Adam
beta1: 0.9
beta2: 0.999
clip_norm: 5.0
lr:
learning_rate: 0.001
regularizer:
name: 'L2'
factor: 0.00000
Architecture:
model_type: table
algorithm: TableAttn
Backbone:
name: MobileNetV3
scale: 1.0
model_name: small
disable_se: True
Head:
name: TableAttentionHead
hidden_size: 256
l2_decay: 0.00001
loc_type: 2
Loss:
name: TableAttentionLoss
structure_weight: 100.0
loc_weight: 10000.0
PostProcess:
name: TableLabelDecode
Metric:
name: TableMetric
main_indicator: acc
Train:
dataset:
name: PubTabDataSet
data_dir: train_data/table/pubtabnet/train/
label_file_path: train_data/table/pubtabnet/PubTabNet_2.0.0_train.jsonl
transforms:
- DecodeImage: # load image
img_mode: BGR
channel_first: False
- ResizeTableImage:
max_len: 488
- TableLabelEncode:
- NormalizeImage:
scale: 1./255.
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: 'hwc'
- PaddingTableImage:
- ToCHWImage:
- KeepKeys:
keep_keys: ['image', 'structure', 'bbox_list', 'sp_tokens', 'bbox_list_mask']
loader:
shuffle: True
batch_size_per_card: 32
drop_last: True
num_workers: 1
Eval:
dataset:
name: PubTabDataSet
data_dir: train_data/table/pubtabnet/val/
label_file_path: train_data/table/pubtabnet/PubTabNet_2.0.0_val.jsonl
transforms:
- DecodeImage: # load image
img_mode: BGR
channel_first: False
- ResizeTableImage:
max_len: 488
- TableLabelEncode:
- NormalizeImage:
scale: 1./255.
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: 'hwc'
- PaddingTableImage:
- ToCHWImage:
- KeepKeys:
keep_keys: ['image', 'structure', 'bbox_list', 'sp_tokens', 'bbox_list_mask']
loader:
shuffle: False
drop_last: False
batch_size_per_card: 16
num_workers: 1
...@@ -47,16 +47,13 @@ void Normalize::Run(cv::Mat *im, const std::vector<float> &mean, ...@@ -47,16 +47,13 @@ void Normalize::Run(cv::Mat *im, const std::vector<float> &mean,
e /= 255.0; e /= 255.0;
} }
(*im).convertTo(*im, CV_32FC3, e); (*im).convertTo(*im, CV_32FC3, e);
for (int h = 0; h < im->rows; h++) { std::vector<cv::Mat> bgr_channels(3);
for (int w = 0; w < im->cols; w++) { cv::split(*im, bgr_channels);
im->at<cv::Vec3f>(h, w)[0] = for (auto i = 0; i < bgr_channels.size(); i++) {
(im->at<cv::Vec3f>(h, w)[0] - mean[0]) * scale[0]; bgr_channels[i].convertTo(bgr_channels[i], CV_32FC1, 1.0 * scale[i],
im->at<cv::Vec3f>(h, w)[1] = (0.0 - mean[i]) * scale[i]);
(im->at<cv::Vec3f>(h, w)[1] - mean[1]) * scale[1];
im->at<cv::Vec3f>(h, w)[2] =
(im->at<cv::Vec3f>(h, w)[2] - mean[2]) * scale[2];
}
} }
cv::merge(bgr_channels, *im);
} }
void ResizeImgType0::Run(const cv::Mat &img, cv::Mat &resize_img, void ResizeImgType0::Run(const cv::Mat &img, cv::Mat &resize_img,
......
...@@ -355,3 +355,4 @@ im_show.save('result.jpg') ...@@ -355,3 +355,4 @@ im_show.save('result.jpg')
| det | 前向时使用启动检测 | TRUE | | det | 前向时使用启动检测 | TRUE |
| rec | 前向时是否启动识别 | TRUE | | rec | 前向时是否启动识别 | TRUE |
| cls | 前向时是否启动分类 (命令行模式下使用use_angle_cls控制前向是否启动分类) | FALSE | | cls | 前向时是否启动分类 (命令行模式下使用use_angle_cls控制前向是否启动分类) | FALSE |
| show_log | 是否打印det和rec等信息 | FALSE |
...@@ -362,3 +362,5 @@ im_show.save('result.jpg') ...@@ -362,3 +362,5 @@ im_show.save('result.jpg')
| det | Enable detction when `ppocr.ocr` func exec | TRUE | | det | Enable detction when `ppocr.ocr` func exec | TRUE |
| rec | Enable recognition when `ppocr.ocr` func exec | TRUE | | rec | Enable recognition when `ppocr.ocr` func exec | TRUE |
| cls | Enable classification when `ppocr.ocr` func exec((Use use_angle_cls in command line mode to control whether to start classification in the forward direction) | FALSE | | cls | Enable classification when `ppocr.ocr` func exec((Use use_angle_cls in command line mode to control whether to start classification in the forward direction) | FALSE |
| show_log | Whether to print log in det and rec
| FALSE |
\ No newline at end of file
doc/joinus.PNG

78.1 KB | W: | H:

doc/joinus.PNG

205 KB | W: | H:

doc/joinus.PNG
doc/joinus.PNG
doc/joinus.PNG
doc/joinus.PNG
  • 2-up
  • Swipe
  • Onion skin
...@@ -19,17 +19,16 @@ __dir__ = os.path.dirname(__file__) ...@@ -19,17 +19,16 @@ __dir__ = os.path.dirname(__file__)
sys.path.append(os.path.join(__dir__, '')) sys.path.append(os.path.join(__dir__, ''))
import cv2 import cv2
import logging
import numpy as np import numpy as np
from pathlib import Path from pathlib import Path
import tarfile
import requests
from tqdm import tqdm
from tools.infer import predict_system from tools.infer import predict_system
from ppocr.utils.logging import get_logger from ppocr.utils.logging import get_logger
logger = get_logger() logger = get_logger()
from ppocr.utils.utility import check_and_read_gif, get_image_file_list from ppocr.utils.utility import check_and_read_gif, get_image_file_list
from ppocr.utils.network import maybe_download, download_with_progressbar, is_link, confirm_model_dir_url
from tools.infer.utility import draw_ocr, init_args, str2bool from tools.infer.utility import draw_ocr, init_args, str2bool
__all__ = ['PaddleOCR'] __all__ = ['PaddleOCR']
...@@ -37,84 +36,84 @@ __all__ = ['PaddleOCR'] ...@@ -37,84 +36,84 @@ __all__ = ['PaddleOCR']
model_urls = { model_urls = {
'det': { 'det': {
'ch': 'ch':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar', 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar',
'en': 'en':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/en_ppocr_mobile_v2.0_det_infer.tar' 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/en_ppocr_mobile_v2.0_det_infer.tar'
}, },
'rec': { 'rec': {
'ch': { 'ch': {
'url': 'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_infer.tar', 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/ppocr_keys_v1.txt' 'dict_path': './ppocr/utils/ppocr_keys_v1.txt'
}, },
'en': { 'en': {
'url': 'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/en_number_mobile_v2.0_rec_infer.tar', 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/en_number_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/en_dict.txt' 'dict_path': './ppocr/utils/en_dict.txt'
}, },
'french': { 'french': {
'url': 'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/french_mobile_v2.0_rec_infer.tar', 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/french_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/french_dict.txt' 'dict_path': './ppocr/utils/dict/french_dict.txt'
}, },
'german': { 'german': {
'url': 'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/german_mobile_v2.0_rec_infer.tar', 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/german_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/german_dict.txt' 'dict_path': './ppocr/utils/dict/german_dict.txt'
}, },
'korean': { 'korean': {
'url': 'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/korean_mobile_v2.0_rec_infer.tar', 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/korean_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/korean_dict.txt' 'dict_path': './ppocr/utils/dict/korean_dict.txt'
}, },
'japan': { 'japan': {
'url': 'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/japan_mobile_v2.0_rec_infer.tar', 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/japan_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/japan_dict.txt' 'dict_path': './ppocr/utils/dict/japan_dict.txt'
}, },
'chinese_cht': { 'chinese_cht': {
'url': 'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/chinese_cht_mobile_v2.0_rec_infer.tar', 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/chinese_cht_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/chinese_cht_dict.txt' 'dict_path': './ppocr/utils/dict/chinese_cht_dict.txt'
}, },
'ta': { 'ta': {
'url': 'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ta_mobile_v2.0_rec_infer.tar', 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ta_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/ta_dict.txt' 'dict_path': './ppocr/utils/dict/ta_dict.txt'
}, },
'te': { 'te': {
'url': 'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/te_mobile_v2.0_rec_infer.tar', 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/te_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/te_dict.txt' 'dict_path': './ppocr/utils/dict/te_dict.txt'
}, },
'ka': { 'ka': {
'url': 'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ka_mobile_v2.0_rec_infer.tar', 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ka_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/ka_dict.txt' 'dict_path': './ppocr/utils/dict/ka_dict.txt'
}, },
'latin': { 'latin': {
'url': 'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/latin_ppocr_mobile_v2.0_rec_infer.tar', 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/latin_ppocr_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/latin_dict.txt' 'dict_path': './ppocr/utils/dict/latin_dict.txt'
}, },
'arabic': { 'arabic': {
'url': 'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/arabic_ppocr_mobile_v2.0_rec_infer.tar', 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/arabic_ppocr_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/arabic_dict.txt' 'dict_path': './ppocr/utils/dict/arabic_dict.txt'
}, },
'cyrillic': { 'cyrillic': {
'url': 'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/cyrillic_ppocr_mobile_v2.0_rec_infer.tar', 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/cyrillic_ppocr_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/cyrillic_dict.txt' 'dict_path': './ppocr/utils/dict/cyrillic_dict.txt'
}, },
'devanagari': { 'devanagari': {
'url': 'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/devanagari_ppocr_mobile_v2.0_rec_infer.tar', 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/devanagari_ppocr_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/devanagari_dict.txt' 'dict_path': './ppocr/utils/dict/devanagari_dict.txt'
} }
}, },
'cls': 'cls':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_infer.tar' 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_infer.tar'
} }
SUPPORT_DET_MODEL = ['DB'] SUPPORT_DET_MODEL = ['DB']
...@@ -123,50 +122,6 @@ SUPPORT_REC_MODEL = ['CRNN'] ...@@ -123,50 +122,6 @@ SUPPORT_REC_MODEL = ['CRNN']
BASE_DIR = os.path.expanduser("~/.paddleocr/") BASE_DIR = os.path.expanduser("~/.paddleocr/")
def download_with_progressbar(url, save_path):
response = requests.get(url, stream=True)
total_size_in_bytes = int(response.headers.get('content-length', 0))
block_size = 1024 # 1 Kibibyte
progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True)
with open(save_path, 'wb') as file:
for data in response.iter_content(block_size):
progress_bar.update(len(data))
file.write(data)
progress_bar.close()
if total_size_in_bytes == 0 or progress_bar.n != total_size_in_bytes:
logger.error("Something went wrong while downloading models")
sys.exit(0)
def maybe_download(model_storage_directory, url):
# using custom model
tar_file_name_list = [
'inference.pdiparams', 'inference.pdiparams.info', 'inference.pdmodel'
]
if not os.path.exists(
os.path.join(model_storage_directory, 'inference.pdiparams')
) or not os.path.exists(
os.path.join(model_storage_directory, 'inference.pdmodel')):
tmp_path = os.path.join(model_storage_directory, url.split('/')[-1])
print('download {} to {}'.format(url, tmp_path))
os.makedirs(model_storage_directory, exist_ok=True)
download_with_progressbar(url, tmp_path)
with tarfile.open(tmp_path, 'r') as tarObj:
for member in tarObj.getmembers():
filename = None
for tar_file_name in tar_file_name_list:
if tar_file_name in member.name:
filename = tar_file_name
if filename is None:
continue
file = tarObj.extractfile(member)
with open(
os.path.join(model_storage_directory, filename),
'wb') as f:
f.write(file.read())
os.remove(tmp_path)
def parse_args(mMain=True): def parse_args(mMain=True):
import argparse import argparse
parser = init_args() parser = init_args()
...@@ -194,10 +149,12 @@ class PaddleOCR(predict_system.TextSystem): ...@@ -194,10 +149,12 @@ class PaddleOCR(predict_system.TextSystem):
args: args:
**kwargs: other params show in paddleocr --help **kwargs: other params show in paddleocr --help
""" """
postprocess_params = parse_args(mMain=False) params = parse_args(mMain=False)
postprocess_params.__dict__.update(**kwargs) params.__dict__.update(**kwargs)
self.use_angle_cls = postprocess_params.use_angle_cls if not params.show_log:
lang = postprocess_params.lang logger.setLevel(logging.INFO)
self.use_angle_cls = params.use_angle_cls
lang = params.lang
latin_lang = [ latin_lang = [
'af', 'az', 'bs', 'cs', 'cy', 'da', 'de', 'es', 'et', 'fr', 'ga', 'af', 'az', 'bs', 'cs', 'cy', 'da', 'de', 'es', 'et', 'fr', 'ga',
'hr', 'hu', 'id', 'is', 'it', 'ku', 'la', 'lt', 'lv', 'mi', 'ms', 'hr', 'hu', 'id', 'is', 'it', 'ku', 'la', 'lt', 'lv', 'mi', 'ms',
...@@ -223,46 +180,45 @@ class PaddleOCR(predict_system.TextSystem): ...@@ -223,46 +180,45 @@ class PaddleOCR(predict_system.TextSystem):
lang = "devanagari" lang = "devanagari"
assert lang in model_urls[ assert lang in model_urls[
'rec'], 'param lang must in {}, but got {}'.format( 'rec'], 'param lang must in {}, but got {}'.format(
model_urls['rec'].keys(), lang) model_urls['rec'].keys(), lang)
if lang == "ch": if lang == "ch":
det_lang = "ch" det_lang = "ch"
else: else:
det_lang = "en" det_lang = "en"
use_inner_dict = False use_inner_dict = False
if postprocess_params.rec_char_dict_path is None: if params.rec_char_dict_path is None:
use_inner_dict = True use_inner_dict = True
postprocess_params.rec_char_dict_path = model_urls['rec'][lang][ params.rec_char_dict_path = model_urls['rec'][lang][
'dict_path'] 'dict_path']
# init model dir # init model dir
if postprocess_params.det_model_dir is None: params.det_model_dir, det_url = confirm_model_dir_url(params.det_model_dir,
postprocess_params.det_model_dir = os.path.join(BASE_DIR, VERSION, os.path.join(BASE_DIR, VERSION, 'det', det_lang),
'det', det_lang) model_urls['det'][det_lang])
if postprocess_params.rec_model_dir is None: params.rec_model_dir, rec_url = confirm_model_dir_url(params.rec_model_dir,
postprocess_params.rec_model_dir = os.path.join(BASE_DIR, VERSION, os.path.join(BASE_DIR, VERSION, 'rec', lang),
'rec', lang) model_urls['rec'][lang]['url'])
if postprocess_params.cls_model_dir is None: params.cls_model_dir, cls_url = confirm_model_dir_url(params.cls_model_dir,
postprocess_params.cls_model_dir = os.path.join(BASE_DIR, 'cls') os.path.join(BASE_DIR, VERSION, 'cls'),
print(postprocess_params) model_urls['cls'])
# download model # download model
maybe_download(postprocess_params.det_model_dir, maybe_download(params.det_model_dir, det_url)
model_urls['det'][det_lang]) maybe_download(params.rec_model_dir, rec_url)
maybe_download(postprocess_params.rec_model_dir, maybe_download(params.cls_model_dir, cls_url)
model_urls['rec'][lang]['url'])
maybe_download(postprocess_params.cls_model_dir, model_urls['cls'])
if postprocess_params.det_algorithm not in SUPPORT_DET_MODEL: if params.det_algorithm not in SUPPORT_DET_MODEL:
logger.error('det_algorithm must in {}'.format(SUPPORT_DET_MODEL)) logger.error('det_algorithm must in {}'.format(SUPPORT_DET_MODEL))
sys.exit(0) sys.exit(0)
if postprocess_params.rec_algorithm not in SUPPORT_REC_MODEL: if params.rec_algorithm not in SUPPORT_REC_MODEL:
logger.error('rec_algorithm must in {}'.format(SUPPORT_REC_MODEL)) logger.error('rec_algorithm must in {}'.format(SUPPORT_REC_MODEL))
sys.exit(0) sys.exit(0)
if use_inner_dict: if use_inner_dict:
postprocess_params.rec_char_dict_path = str( params.rec_char_dict_path = str(
Path(__file__).parent / postprocess_params.rec_char_dict_path) Path(__file__).parent / params.rec_char_dict_path)
print(params)
# init det_model and rec_model # init det_model and rec_model
super().__init__(postprocess_params) super().__init__(params)
def ocr(self, img, det=True, rec=True, cls=True): def ocr(self, img, det=True, rec=True, cls=True):
""" """
...@@ -320,7 +276,7 @@ def main(): ...@@ -320,7 +276,7 @@ def main():
# for cmd # for cmd
args = parse_args(mMain=True) args = parse_args(mMain=True)
image_dir = args.image_dir image_dir = args.image_dir
if image_dir.startswith('http'): if is_link(image_dir):
download_with_progressbar(image_dir, 'tmp.jpg') download_with_progressbar(image_dir, 'tmp.jpg')
image_file_list = ['tmp.jpg'] image_file_list = ['tmp.jpg']
else: else:
......
...@@ -35,6 +35,7 @@ from ppocr.data.imaug import transform, create_operators ...@@ -35,6 +35,7 @@ from ppocr.data.imaug import transform, create_operators
from ppocr.data.simple_dataset import SimpleDataSet from ppocr.data.simple_dataset import SimpleDataSet
from ppocr.data.lmdb_dataset import LMDBDataSet from ppocr.data.lmdb_dataset import LMDBDataSet
from ppocr.data.pgnet_dataset import PGDataSet from ppocr.data.pgnet_dataset import PGDataSet
from ppocr.data.pubtab_dataset import PubTabDataSet
__all__ = ['build_dataloader', 'transform', 'create_operators'] __all__ = ['build_dataloader', 'transform', 'create_operators']
...@@ -55,7 +56,7 @@ signal.signal(signal.SIGTERM, term_mp) ...@@ -55,7 +56,7 @@ signal.signal(signal.SIGTERM, term_mp)
def build_dataloader(config, mode, device, logger, seed=None): def build_dataloader(config, mode, device, logger, seed=None):
config = copy.deepcopy(config) config = copy.deepcopy(config)
support_dict = ['SimpleDataSet', 'LMDBDataSet', 'PGDataSet'] support_dict = ['SimpleDataSet', 'LMDBDataSet', 'PGDataSet', 'PubTabDataSet']
module_name = config[mode]['dataset']['name'] module_name = config[mode]['dataset']['name']
assert module_name in support_dict, Exception( assert module_name in support_dict, Exception(
'DataSet only support {}'.format(support_dict)) 'DataSet only support {}'.format(support_dict))
......
...@@ -30,6 +30,7 @@ from .label_ops import * ...@@ -30,6 +30,7 @@ from .label_ops import *
from .east_process import * from .east_process import *
from .sast_process import * from .sast_process import *
from .pg_process import * from .pg_process import *
from .gen_table_mask import *
def transform(data, ops=None): def transform(data, ops=None):
......
"""
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import sys
import six
import cv2
import numpy as np
class GenTableMask(object):
""" gen table mask """
def __init__(self, shrink_h_max, shrink_w_max, mask_type=0, **kwargs):
self.shrink_h_max = 5
self.shrink_w_max = 5
self.mask_type = mask_type
def projection(self, erosion, h, w, spilt_threshold=0):
# 水平投影
projection_map = np.ones_like(erosion)
project_val_array = [0 for _ in range(0, h)]
for j in range(0, h):
for i in range(0, w):
if erosion[j, i] == 255:
project_val_array[j] += 1
# 根据数组,获取切割点
start_idx = 0 # 记录进入字符区的索引
end_idx = 0 # 记录进入空白区域的索引
in_text = False # 是否遍历到了字符区内
box_list = []
for i in range(len(project_val_array)):
if in_text == False and project_val_array[i] > spilt_threshold: # 进入字符区了
in_text = True
start_idx = i
elif project_val_array[i] <= spilt_threshold and in_text == True: # 进入空白区了
end_idx = i
in_text = False
if end_idx - start_idx <= 2:
continue
box_list.append((start_idx, end_idx + 1))
if in_text:
box_list.append((start_idx, h - 1))
# 绘制投影直方图
for j in range(0, h):
for i in range(0, project_val_array[j]):
projection_map[j, i] = 0
return box_list, projection_map
def projection_cx(self, box_img):
box_gray_img = cv2.cvtColor(box_img, cv2.COLOR_BGR2GRAY)
h, w = box_gray_img.shape
# 灰度图片进行二值化处理
ret, thresh1 = cv2.threshold(box_gray_img, 200, 255, cv2.THRESH_BINARY_INV)
# 纵向腐蚀
if h < w:
kernel = np.ones((2, 1), np.uint8)
erode = cv2.erode(thresh1, kernel, iterations=1)
else:
erode = thresh1
# 水平膨胀
kernel = np.ones((1, 5), np.uint8)
erosion = cv2.dilate(erode, kernel, iterations=1)
# 水平投影
projection_map = np.ones_like(erosion)
project_val_array = [0 for _ in range(0, h)]
for j in range(0, h):
for i in range(0, w):
if erosion[j, i] == 255:
project_val_array[j] += 1
# 根据数组,获取切割点
start_idx = 0 # 记录进入字符区的索引
end_idx = 0 # 记录进入空白区域的索引
in_text = False # 是否遍历到了字符区内
box_list = []
spilt_threshold = 0
for i in range(len(project_val_array)):
if in_text == False and project_val_array[i] > spilt_threshold: # 进入字符区了
in_text = True
start_idx = i
elif project_val_array[i] <= spilt_threshold and in_text == True: # 进入空白区了
end_idx = i
in_text = False
if end_idx - start_idx <= 2:
continue
box_list.append((start_idx, end_idx + 1))
if in_text:
box_list.append((start_idx, h - 1))
# 绘制投影直方图
for j in range(0, h):
for i in range(0, project_val_array[j]):
projection_map[j, i] = 0
split_bbox_list = []
if len(box_list) > 1:
for i, (h_start, h_end) in enumerate(box_list):
if i == 0:
h_start = 0
if i == len(box_list):
h_end = h
word_img = erosion[h_start:h_end + 1, :]
word_h, word_w = word_img.shape
w_split_list, w_projection_map = self.projection(word_img.T, word_w, word_h)
w_start, w_end = w_split_list[0][0], w_split_list[-1][1]
if h_start > 0:
h_start -= 1
h_end += 1
word_img = box_img[h_start:h_end + 1:, w_start:w_end + 1, :]
split_bbox_list.append([w_start, h_start, w_end, h_end])
else:
split_bbox_list.append([0, 0, w, h])
return split_bbox_list
def shrink_bbox(self, bbox):
left, top, right, bottom = bbox
sh_h = min(max(int((bottom - top) * 0.1), 1), self.shrink_h_max)
sh_w = min(max(int((right - left) * 0.1), 1), self.shrink_w_max)
left_new = left + sh_w
right_new = right - sh_w
top_new = top + sh_h
bottom_new = bottom - sh_h
if left_new >= right_new:
left_new = left
right_new = right
if top_new >= bottom_new:
top_new = top
bottom_new = bottom
return [left_new, top_new, right_new, bottom_new]
def __call__(self, data):
img = data['image']
cells = data['cells']
height, width = img.shape[0:2]
if self.mask_type == 1:
mask_img = np.zeros((height, width), dtype=np.float32)
else:
mask_img = np.zeros((height, width, 3), dtype=np.float32)
cell_num = len(cells)
for cno in range(cell_num):
if "bbox" in cells[cno]:
bbox = cells[cno]['bbox']
left, top, right, bottom = bbox
box_img = img[top:bottom, left:right, :].copy()
split_bbox_list = self.projection_cx(box_img)
for sno in range(len(split_bbox_list)):
split_bbox_list[sno][0] += left
split_bbox_list[sno][1] += top
split_bbox_list[sno][2] += left
split_bbox_list[sno][3] += top
for sno in range(len(split_bbox_list)):
left, top, right, bottom = split_bbox_list[sno]
left, top, right, bottom = self.shrink_bbox([left, top, right, bottom])
if self.mask_type == 1:
mask_img[top:bottom, left:right] = 1.0
data['mask_img'] = mask_img
else:
mask_img[top:bottom, left:right, :] = (255, 255, 255)
data['image'] = mask_img
return data
class ResizeTableImage(object):
def __init__(self, max_len, **kwargs):
super(ResizeTableImage, self).__init__()
self.max_len = max_len
def get_img_bbox(self, cells):
bbox_list = []
if len(cells) == 0:
return bbox_list
cell_num = len(cells)
for cno in range(cell_num):
if "bbox" in cells[cno]:
bbox = cells[cno]['bbox']
bbox_list.append(bbox)
return bbox_list
def resize_img_table(self, img, bbox_list, max_len):
height, width = img.shape[0:2]
ratio = max_len / (max(height, width) * 1.0)
resize_h = int(height * ratio)
resize_w = int(width * ratio)
img_new = cv2.resize(img, (resize_w, resize_h))
bbox_list_new = []
for bno in range(len(bbox_list)):
left, top, right, bottom = bbox_list[bno].copy()
left = int(left * ratio)
top = int(top * ratio)
right = int(right * ratio)
bottom = int(bottom * ratio)
bbox_list_new.append([left, top, right, bottom])
return img_new, bbox_list_new
def __call__(self, data):
img = data['image']
if 'cells' not in data:
cells = []
else:
cells = data['cells']
bbox_list = self.get_img_bbox(cells)
img_new, bbox_list_new = self.resize_img_table(img, bbox_list, self.max_len)
data['image'] = img_new
cell_num = len(cells)
bno = 0
for cno in range(cell_num):
if "bbox" in data['cells'][cno]:
data['cells'][cno]['bbox'] = bbox_list_new[bno]
bno += 1
data['max_len'] = self.max_len
return data
class PaddingTableImage(object):
def __init__(self, **kwargs):
super(PaddingTableImage, self).__init__()
def __call__(self, data):
img = data['image']
max_len = data['max_len']
padding_img = np.zeros((max_len, max_len, 3), dtype=np.float32)
height, width = img.shape[0:2]
padding_img[0:height, 0:width, :] = img.copy()
data['image'] = padding_img
return data
\ No newline at end of file
...@@ -351,3 +351,162 @@ class SRNLabelEncode(BaseRecLabelEncode): ...@@ -351,3 +351,162 @@ class SRNLabelEncode(BaseRecLabelEncode):
assert False, "Unsupport type %s in get_beg_end_flag_idx" \ assert False, "Unsupport type %s in get_beg_end_flag_idx" \
% beg_or_end % beg_or_end
return idx return idx
class TableLabelEncode(object):
""" Convert between text-label and text-index """
def __init__(self,
max_text_length,
max_elem_length,
max_cell_num,
character_dict_path,
span_weight = 1.0,
**kwargs):
self.max_text_length = max_text_length
self.max_elem_length = max_elem_length
self.max_cell_num = max_cell_num
list_character, list_elem = self.load_char_elem_dict(character_dict_path)
list_character = self.add_special_char(list_character)
list_elem = self.add_special_char(list_elem)
self.dict_character = {}
for i, char in enumerate(list_character):
self.dict_character[char] = i
self.dict_elem = {}
for i, elem in enumerate(list_elem):
self.dict_elem[elem] = i
self.span_weight = span_weight
def load_char_elem_dict(self, character_dict_path):
list_character = []
list_elem = []
with open(character_dict_path, "rb") as fin:
lines = fin.readlines()
substr = lines[0].decode('utf-8').strip("\n").split("\t")
character_num = int(substr[0])
elem_num = int(substr[1])
for cno in range(1, 1+character_num):
character = lines[cno].decode('utf-8').strip("\n")
list_character.append(character)
for eno in range(1+character_num, 1+character_num+elem_num):
elem = lines[eno].decode('utf-8').strip("\n")
list_elem.append(elem)
return list_character, list_elem
def add_special_char(self, list_character):
self.beg_str = "sos"
self.end_str = "eos"
list_character = [self.beg_str] + list_character + [self.end_str]
return list_character
def get_span_idx_list(self):
span_idx_list = []
for elem in self.dict_elem:
if 'span' in elem:
span_idx_list.append(self.dict_elem[elem])
return span_idx_list
def __call__(self, data):
cells = data['cells']
structure = data['structure']['tokens']
structure = self.encode(structure, 'elem')
if structure is None:
return None
elem_num = len(structure)
structure = [0] + structure + [len(self.dict_elem) - 1]
structure = structure + [0] * (self.max_elem_length + 2 - len(structure))
structure = np.array(structure)
data['structure'] = structure
elem_char_idx1 = self.dict_elem['<td>']
elem_char_idx2 = self.dict_elem['<td']
span_idx_list = self.get_span_idx_list()
td_idx_list = np.logical_or(structure == elem_char_idx1, structure == elem_char_idx2)
td_idx_list = np.where(td_idx_list)[0]
structure_mask = np.ones((self.max_elem_length + 2, 1), dtype=np.float32)
bbox_list = np.zeros((self.max_elem_length + 2, 4), dtype=np.float32)
bbox_list_mask = np.zeros((self.max_elem_length + 2, 1), dtype=np.float32)
img_height, img_width, img_ch = data['image'].shape
if len(span_idx_list) > 0:
span_weight = len(td_idx_list) * 1.0 / len(span_idx_list)
span_weight = min(max(span_weight, 1.0), self.span_weight)
for cno in range(len(cells)):
if 'bbox' in cells[cno]:
bbox = cells[cno]['bbox'].copy()
bbox[0] = bbox[0] * 1.0 / img_width
bbox[1] = bbox[1] * 1.0 / img_height
bbox[2] = bbox[2] * 1.0 / img_width
bbox[3] = bbox[3] * 1.0 / img_height
td_idx = td_idx_list[cno]
bbox_list[td_idx] = bbox
bbox_list_mask[td_idx] = 1.0
cand_span_idx = td_idx + 1
if cand_span_idx < (self.max_elem_length + 2):
if structure[cand_span_idx] in span_idx_list:
structure_mask[cand_span_idx] = span_weight
data['bbox_list'] = bbox_list
data['bbox_list_mask'] = bbox_list_mask
data['structure_mask'] = structure_mask
char_beg_idx = self.get_beg_end_flag_idx('beg', 'char')
char_end_idx = self.get_beg_end_flag_idx('end', 'char')
elem_beg_idx = self.get_beg_end_flag_idx('beg', 'elem')
elem_end_idx = self.get_beg_end_flag_idx('end', 'elem')
data['sp_tokens'] = np.array([char_beg_idx, char_end_idx, elem_beg_idx,
elem_end_idx, elem_char_idx1, elem_char_idx2, self.max_text_length,
self.max_elem_length, self.max_cell_num, elem_num])
return data
def encode(self, text, char_or_elem):
"""convert text-label into text-index.
"""
if char_or_elem == "char":
max_len = self.max_text_length
current_dict = self.dict_character
else:
max_len = self.max_elem_length
current_dict = self.dict_elem
if len(text) > max_len:
return None
if len(text) == 0:
if char_or_elem == "char":
return [self.dict_character['space']]
else:
return None
text_list = []
for char in text:
if char not in current_dict:
return None
text_list.append(current_dict[char])
if len(text_list) == 0:
if char_or_elem == "char":
return [self.dict_character['space']]
else:
return None
return text_list
def get_ignored_tokens(self, char_or_elem):
beg_idx = self.get_beg_end_flag_idx("beg", char_or_elem)
end_idx = self.get_beg_end_flag_idx("end", char_or_elem)
return [beg_idx, end_idx]
def get_beg_end_flag_idx(self, beg_or_end, char_or_elem):
if char_or_elem == "char":
if beg_or_end == "beg":
idx = np.array(self.dict_character[self.beg_str])
elif beg_or_end == "end":
idx = np.array(self.dict_character[self.end_str])
else:
assert False, "Unsupport type %s in get_beg_end_flag_idx of char" \
% beg_or_end
elif char_or_elem == "elem":
if beg_or_end == "beg":
idx = np.array(self.dict_elem[self.beg_str])
elif beg_or_end == "end":
idx = np.array(self.dict_elem[self.end_str])
else:
assert False, "Unsupport type %s in get_beg_end_flag_idx of elem" \
% beg_or_end
else:
assert False, "Unsupport type %s in char_or_elem" \
% char_or_elem
return idx
\ No newline at end of file
...@@ -81,7 +81,7 @@ class NormalizeImage(object): ...@@ -81,7 +81,7 @@ class NormalizeImage(object):
assert isinstance(img, assert isinstance(img,
np.ndarray), "invalid input 'img' in NormalizeImage" np.ndarray), "invalid input 'img' in NormalizeImage"
data['image'] = ( data['image'] = (
img.astype('float32') * self.scale - self.mean) / self.std img.astype('float32') * self.scale - self.mean) / self.std
return data return data
...@@ -163,7 +163,7 @@ class DetResizeForTest(object): ...@@ -163,7 +163,7 @@ class DetResizeForTest(object):
img, (ratio_h, ratio_w) img, (ratio_h, ratio_w)
""" """
limit_side_len = self.limit_side_len limit_side_len = self.limit_side_len
h, w, _ = img.shape h, w, c = img.shape
# limit the max side # limit the max side
if self.limit_type == 'max': if self.limit_type == 'max':
...@@ -174,7 +174,7 @@ class DetResizeForTest(object): ...@@ -174,7 +174,7 @@ class DetResizeForTest(object):
ratio = float(limit_side_len) / w ratio = float(limit_side_len) / w
else: else:
ratio = 1. ratio = 1.
else: elif self.limit_type == 'min':
if min(h, w) < limit_side_len: if min(h, w) < limit_side_len:
if h < w: if h < w:
ratio = float(limit_side_len) / h ratio = float(limit_side_len) / h
...@@ -182,6 +182,10 @@ class DetResizeForTest(object): ...@@ -182,6 +182,10 @@ class DetResizeForTest(object):
ratio = float(limit_side_len) / w ratio = float(limit_side_len) / w
else: else:
ratio = 1. ratio = 1.
elif self.limit_type == 'resize_long':
ratio = float(limit_side_len) / max(h,w)
else:
raise Exception('not support limit type, image ')
resize_h = int(h * ratio) resize_h = int(h * ratio)
resize_w = int(w * ratio) resize_w = int(w * ratio)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment