Commit 037e17fc authored by WenmuZhou's avatar WenmuZhou
Browse files

merge dygraph

parent 6127aad9
include LICENSE.txt
include LICENSE
include README.md
recursive-include ppocr/utils *.txt utility.py logging.py
recursive-include ppocr/utils *.txt utility.py logging.py network.py
recursive-include ppocr/data/ *.py
recursive-include ppocr/postprocess *.py
recursive-include tools/infer *.py
......
......@@ -19,17 +19,16 @@ __dir__ = os.path.dirname(__file__)
sys.path.append(os.path.join(__dir__, ''))
import cv2
import logging
import numpy as np
from pathlib import Path
import tarfile
import requests
from tqdm import tqdm
from tools.infer import predict_system
from ppocr.utils.logging import get_logger
logger = get_logger()
from ppocr.utils.utility import check_and_read_gif, get_image_file_list
from ppocr.utils.network import maybe_download, download_with_progressbar
from tools.infer.utility import draw_ocr, init_args, str2bool
__all__ = ['PaddleOCR']
......@@ -37,84 +36,84 @@ __all__ = ['PaddleOCR']
model_urls = {
'det': {
'ch':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar',
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar',
'en':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/en_ppocr_mobile_v2.0_det_infer.tar'
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/en_ppocr_mobile_v2.0_det_infer.tar'
},
'rec': {
'ch': {
'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_infer.tar',
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/ppocr_keys_v1.txt'
},
'en': {
'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/en_number_mobile_v2.0_rec_infer.tar',
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/en_number_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/en_dict.txt'
},
'french': {
'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/french_mobile_v2.0_rec_infer.tar',
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/french_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/french_dict.txt'
},
'german': {
'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/german_mobile_v2.0_rec_infer.tar',
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/german_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/german_dict.txt'
},
'korean': {
'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/korean_mobile_v2.0_rec_infer.tar',
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/korean_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/korean_dict.txt'
},
'japan': {
'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/japan_mobile_v2.0_rec_infer.tar',
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/japan_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/japan_dict.txt'
},
'chinese_cht': {
'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/chinese_cht_mobile_v2.0_rec_infer.tar',
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/chinese_cht_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/chinese_cht_dict.txt'
},
'ta': {
'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ta_mobile_v2.0_rec_infer.tar',
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ta_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/ta_dict.txt'
},
'te': {
'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/te_mobile_v2.0_rec_infer.tar',
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/te_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/te_dict.txt'
},
'ka': {
'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ka_mobile_v2.0_rec_infer.tar',
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ka_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/ka_dict.txt'
},
'latin': {
'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/latin_ppocr_mobile_v2.0_rec_infer.tar',
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/latin_ppocr_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/latin_dict.txt'
},
'arabic': {
'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/arabic_ppocr_mobile_v2.0_rec_infer.tar',
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/arabic_ppocr_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/arabic_dict.txt'
},
'cyrillic': {
'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/cyrillic_ppocr_mobile_v2.0_rec_infer.tar',
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/cyrillic_ppocr_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/cyrillic_dict.txt'
},
'devanagari': {
'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/devanagari_ppocr_mobile_v2.0_rec_infer.tar',
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/devanagari_ppocr_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/devanagari_dict.txt'
}
},
'cls':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_infer.tar'
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_infer.tar'
}
SUPPORT_DET_MODEL = ['DB']
......@@ -123,50 +122,6 @@ SUPPORT_REC_MODEL = ['CRNN']
BASE_DIR = os.path.expanduser("~/.paddleocr/")
def download_with_progressbar(url, save_path):
response = requests.get(url, stream=True)
total_size_in_bytes = int(response.headers.get('content-length', 0))
block_size = 1024 # 1 Kibibyte
progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True)
with open(save_path, 'wb') as file:
for data in response.iter_content(block_size):
progress_bar.update(len(data))
file.write(data)
progress_bar.close()
if total_size_in_bytes == 0 or progress_bar.n != total_size_in_bytes:
logger.error("Something went wrong while downloading models")
sys.exit(0)
def maybe_download(model_storage_directory, url):
# using custom model
tar_file_name_list = [
'inference.pdiparams', 'inference.pdiparams.info', 'inference.pdmodel'
]
if not os.path.exists(
os.path.join(model_storage_directory, 'inference.pdiparams')
) or not os.path.exists(
os.path.join(model_storage_directory, 'inference.pdmodel')):
tmp_path = os.path.join(model_storage_directory, url.split('/')[-1])
print('download {} to {}'.format(url, tmp_path))
os.makedirs(model_storage_directory, exist_ok=True)
download_with_progressbar(url, tmp_path)
with tarfile.open(tmp_path, 'r') as tarObj:
for member in tarObj.getmembers():
filename = None
for tar_file_name in tar_file_name_list:
if tar_file_name in member.name:
filename = tar_file_name
if filename is None:
continue
file = tarObj.extractfile(member)
with open(
os.path.join(model_storage_directory, filename),
'wb') as f:
f.write(file.read())
os.remove(tmp_path)
def parse_args(mMain=True):
import argparse
parser = init_args()
......@@ -194,10 +149,12 @@ class PaddleOCR(predict_system.TextSystem):
args:
**kwargs: other params show in paddleocr --help
"""
postprocess_params = parse_args(mMain=False)
postprocess_params.__dict__.update(**kwargs)
self.use_angle_cls = postprocess_params.use_angle_cls
lang = postprocess_params.lang
params = parse_args(mMain=False)
params.__dict__.update(**kwargs)
if params.show_log:
logger.setLevel(logging.DEBUG)
self.use_angle_cls = params.use_angle_cls
lang = params.lang
latin_lang = [
'af', 'az', 'bs', 'cs', 'cy', 'da', 'de', 'es', 'et', 'fr', 'ga',
'hr', 'hu', 'id', 'is', 'it', 'ku', 'la', 'lt', 'lv', 'mi', 'ms',
......@@ -223,46 +180,46 @@ class PaddleOCR(predict_system.TextSystem):
lang = "devanagari"
assert lang in model_urls[
'rec'], 'param lang must in {}, but got {}'.format(
model_urls['rec'].keys(), lang)
model_urls['rec'].keys(), lang)
if lang == "ch":
det_lang = "ch"
else:
det_lang = "en"
use_inner_dict = False
if postprocess_params.rec_char_dict_path is None:
if params.rec_char_dict_path is None:
use_inner_dict = True
postprocess_params.rec_char_dict_path = model_urls['rec'][lang][
params.rec_char_dict_path = model_urls['rec'][lang][
'dict_path']
# init model dir
if postprocess_params.det_model_dir is None:
postprocess_params.det_model_dir = os.path.join(BASE_DIR, VERSION,
if params.det_model_dir is None:
params.det_model_dir = os.path.join(BASE_DIR, VERSION,
'det', det_lang)
if postprocess_params.rec_model_dir is None:
postprocess_params.rec_model_dir = os.path.join(BASE_DIR, VERSION,
if params.rec_model_dir is None:
params.rec_model_dir = os.path.join(BASE_DIR, VERSION,
'rec', lang)
if postprocess_params.cls_model_dir is None:
postprocess_params.cls_model_dir = os.path.join(BASE_DIR, 'cls')
print(postprocess_params)
if params.cls_model_dir is None:
params.cls_model_dir = os.path.join(BASE_DIR, 'cls')
# download model
maybe_download(postprocess_params.det_model_dir,
maybe_download(params.det_model_dir,
model_urls['det'][det_lang])
maybe_download(postprocess_params.rec_model_dir,
maybe_download(params.rec_model_dir,
model_urls['rec'][lang]['url'])
maybe_download(postprocess_params.cls_model_dir, model_urls['cls'])
maybe_download(params.cls_model_dir, model_urls['cls'])
if postprocess_params.det_algorithm not in SUPPORT_DET_MODEL:
if params.det_algorithm not in SUPPORT_DET_MODEL:
logger.error('det_algorithm must in {}'.format(SUPPORT_DET_MODEL))
sys.exit(0)
if postprocess_params.rec_algorithm not in SUPPORT_REC_MODEL:
if params.rec_algorithm not in SUPPORT_REC_MODEL:
logger.error('rec_algorithm must in {}'.format(SUPPORT_REC_MODEL))
sys.exit(0)
if use_inner_dict:
postprocess_params.rec_char_dict_path = str(
Path(__file__).parent / postprocess_params.rec_char_dict_path)
params.rec_char_dict_path = str(
Path(__file__).parent / params.rec_char_dict_path)
print(params)
# init det_model and rec_model
super().__init__(postprocess_params)
super().__init__(params)
def ocr(self, img, det=True, rec=True, cls=True):
"""
......
......@@ -81,7 +81,7 @@ class NormalizeImage(object):
assert isinstance(img,
np.ndarray), "invalid input 'img' in NormalizeImage"
data['image'] = (
img.astype('float32') * self.scale - self.mean) / self.std
img.astype('float32') * self.scale - self.mean) / self.std
return data
......@@ -163,7 +163,7 @@ class DetResizeForTest(object):
img, (ratio_h, ratio_w)
"""
limit_side_len = self.limit_side_len
h, w, _ = img.shape
h, w, c = img.shape
# limit the max side
if self.limit_type == 'max':
......@@ -174,7 +174,7 @@ class DetResizeForTest(object):
ratio = float(limit_side_len) / w
else:
ratio = 1.
else:
elif self.limit_type == 'min':
if min(h, w) < limit_side_len:
if h < w:
ratio = float(limit_side_len) / h
......@@ -182,6 +182,10 @@ class DetResizeForTest(object):
ratio = float(limit_side_len) / w
else:
ratio = 1.
elif self.limit_type == 'resize_long':
ratio = float(limit_side_len) / max(h,w)
else:
raise Exception('not support limit type, image ')
resize_h = int(h * ratio)
resize_w = int(w * ratio)
......
......@@ -44,16 +44,16 @@ class BaseRecLabelDecode(object):
self.character_str = string.printable[:-6]
dict_character = list(self.character_str)
elif character_type in support_character_type:
self.character_str = ""
self.character_str = []
assert character_dict_path is not None, "character_dict_path should not be None when character_type is {}".format(
character_type)
with open(character_dict_path, "rb") as fin:
lines = fin.readlines()
for line in lines:
line = line.decode('utf-8').strip("\n").strip("\r\n")
self.character_str += line
self.character_str.append(line)
if use_space_char:
self.character_str += " "
self.character_str.append(" ")
dict_character = list(self.character_str)
else:
......@@ -319,3 +319,156 @@ class SRNLabelDecode(BaseRecLabelDecode):
assert False, "unsupport type %s in get_beg_end_flag_idx" \
% beg_or_end
return idx
class TableLabelDecode(object):
""" """
def __init__(self,
max_text_length,
max_elem_length,
max_cell_num,
character_dict_path,
**kwargs):
self.max_text_length = max_text_length
self.max_elem_length = max_elem_length
self.max_cell_num = max_cell_num
list_character, list_elem = self.load_char_elem_dict(character_dict_path)
list_character = self.add_special_char(list_character)
list_elem = self.add_special_char(list_elem)
self.dict_character = {}
self.dict_idx_character = {}
for i, char in enumerate(list_character):
self.dict_idx_character[i] = char
self.dict_character[char] = i
self.dict_elem = {}
self.dict_idx_elem = {}
for i, elem in enumerate(list_elem):
self.dict_idx_elem[i] = elem
self.dict_elem[elem] = i
def load_char_elem_dict(self, character_dict_path):
list_character = []
list_elem = []
with open(character_dict_path, "rb") as fin:
lines = fin.readlines()
substr = lines[0].decode('utf-8').strip("\n").split("\t")
character_num = int(substr[0])
elem_num = int(substr[1])
for cno in range(1, 1 + character_num):
character = lines[cno].decode('utf-8').strip("\n")
list_character.append(character)
for eno in range(1 + character_num, 1 + character_num + elem_num):
elem = lines[eno].decode('utf-8').strip("\n")
list_elem.append(elem)
return list_character, list_elem
def add_special_char(self, list_character):
self.beg_str = "sos"
self.end_str = "eos"
list_character = [self.beg_str] + list_character + [self.end_str]
return list_character
def get_sp_tokens(self):
char_beg_idx = self.get_beg_end_flag_idx('beg', 'char')
char_end_idx = self.get_beg_end_flag_idx('end', 'char')
elem_beg_idx = self.get_beg_end_flag_idx('beg', 'elem')
elem_end_idx = self.get_beg_end_flag_idx('end', 'elem')
elem_char_idx1 = self.dict_elem['<td>']
elem_char_idx2 = self.dict_elem['<td']
sp_tokens = np.array([char_beg_idx, char_end_idx, elem_beg_idx,
elem_end_idx, elem_char_idx1, elem_char_idx2, self.max_text_length,
self.max_elem_length, self.max_cell_num])
return sp_tokens
def __call__(self, preds):
structure_probs = preds['structure_probs']
loc_preds = preds['loc_preds']
if isinstance(structure_probs,paddle.Tensor):
structure_probs = structure_probs.numpy()
if isinstance(loc_preds,paddle.Tensor):
loc_preds = loc_preds.numpy()
structure_idx = structure_probs.argmax(axis=2)
structure_probs = structure_probs.max(axis=2)
structure_str, structure_pos, result_score_list, result_elem_idx_list = self.decode(structure_idx,
structure_probs, 'elem')
res_html_code_list = []
res_loc_list = []
batch_num = len(structure_str)
for bno in range(batch_num):
res_loc = []
for sno in range(len(structure_str[bno])):
text = structure_str[bno][sno]
if text in ['<td>', '<td']:
pos = structure_pos[bno][sno]
res_loc.append(loc_preds[bno, pos])
res_html_code = ''.join(structure_str[bno])
res_loc = np.array(res_loc)
res_html_code_list.append(res_html_code)
res_loc_list.append(res_loc)
return {'res_html_code': res_html_code_list, 'res_loc': res_loc_list, 'res_score_list': result_score_list,
'res_elem_idx_list': result_elem_idx_list,'structure_str_list':structure_str}
def decode(self, text_index, structure_probs, char_or_elem):
"""convert text-label into text-index.
"""
if char_or_elem == "char":
current_dict = self.dict_idx_character
else:
current_dict = self.dict_idx_elem
ignored_tokens = self.get_ignored_tokens('elem')
beg_idx, end_idx = ignored_tokens
result_list = []
result_pos_list = []
result_score_list = []
result_elem_idx_list = []
batch_size = len(text_index)
for batch_idx in range(batch_size):
char_list = []
elem_pos_list = []
elem_idx_list = []
score_list = []
for idx in range(len(text_index[batch_idx])):
tmp_elem_idx = int(text_index[batch_idx][idx])
if idx > 0 and tmp_elem_idx == end_idx:
break
if tmp_elem_idx in ignored_tokens:
continue
char_list.append(current_dict[tmp_elem_idx])
elem_pos_list.append(idx)
score_list.append(structure_probs[batch_idx, idx])
elem_idx_list.append(tmp_elem_idx)
result_list.append(char_list)
result_pos_list.append(elem_pos_list)
result_score_list.append(score_list)
result_elem_idx_list.append(elem_idx_list)
return result_list, result_pos_list, result_score_list, result_elem_idx_list
def get_ignored_tokens(self, char_or_elem):
beg_idx = self.get_beg_end_flag_idx("beg", char_or_elem)
end_idx = self.get_beg_end_flag_idx("end", char_or_elem)
return [beg_idx, end_idx]
def get_beg_end_flag_idx(self, beg_or_end, char_or_elem):
if char_or_elem == "char":
if beg_or_end == "beg":
idx = self.dict_character[self.beg_str]
elif beg_or_end == "end":
idx = self.dict_character[self.end_str]
else:
assert False, "Unsupport type %s in get_beg_end_flag_idx of char" \
% beg_or_end
elif char_or_elem == "elem":
if beg_or_end == "beg":
idx = self.dict_elem[self.beg_str]
elif beg_or_end == "end":
idx = self.dict_elem[self.end_str]
else:
assert False, "Unsupport type %s in get_beg_end_flag_idx of elem" \
% beg_or_end
else:
assert False, "Unsupport type %s in char_or_elem" \
% char_or_elem
return idx
</overline>
α

$
ω
ψ
χ
(
υ
σ
,
ρ
ε
0
4
8
b
<
Ψ
Ω
D
3
Π
H
</strike>
L
Φ
Χ
θ
P
κ
λ
μ
T
ξ
X
β
γ
δ
\
ζ
η
`
d
<strike>
h
f
l
Θ
p
t
</sub>
x
Β
Γ
Δ
|
ǂ
ɛ
j
̧
̌
«
#
</b>
'
Ι
+
/
·
7
;
?
C
÷
G
K
<sup>
O
S
С
W
Α
[
_
c
z
g
<i>
o
<sub>
s
w
φ
ʹ
{
»
̆
e
ˆ
τ
ι
Ø
ß
×
˃
˂
"
i
&
π
*
æ
.
ø
Q
6
:
>
a
B
F
J
̄
N
R
V
<overline>
Z
^
¤
¥
§
<underline>
¢
£
­
Λ
©
n
r
°
±
v
<b>
k
~
̇
@
ł
®
!
</sup>
%
)
-
1
5
9
=
А
A
Σ
E
I
M
m
̨
</i>
U
Y
]
̸
2
̂
̀
́
̊
̈
q
u
ı
y
</underline>
̃
}
ν
This diff is collapsed.
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import tarfile
import requests
from tqdm import tqdm
from ppocr.utils.logging import get_logger
def download_with_progressbar(url, save_path):
logger = get_logger()
response = requests.get(url, stream=True)
total_size_in_bytes = int(response.headers.get('content-length', 0))
block_size = 1024 # 1 Kibibyte
progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True)
with open(save_path, 'wb') as file:
for data in response.iter_content(block_size):
progress_bar.update(len(data))
file.write(data)
progress_bar.close()
if total_size_in_bytes == 0 or progress_bar.n != total_size_in_bytes:
logger.error("Something went wrong while downloading models")
sys.exit(0)
def maybe_download(model_storage_directory, url):
# using custom model
tar_file_name_list = [
'inference.pdiparams', 'inference.pdiparams.info', 'inference.pdmodel'
]
if not os.path.exists(
os.path.join(model_storage_directory, 'inference.pdiparams')
) or not os.path.exists(
os.path.join(model_storage_directory, 'inference.pdmodel')):
assert url.endswith('.tar'), 'Only supports tar compressed package'
tmp_path = os.path.join(model_storage_directory, url.split('/')[-1])
print('download {} to {}'.format(url, tmp_path))
os.makedirs(model_storage_directory, exist_ok=True)
download_with_progressbar(url, tmp_path)
with tarfile.open(tmp_path, 'r') as tarObj:
for member in tarObj.getmembers():
filename = None
for tar_file_name in tar_file_name_list:
if tar_file_name in member.name:
filename = tar_file_name
if filename is None:
continue
file = tarObj.extractfile(member)
with open(
os.path.join(model_storage_directory, filename),
'wb') as f:
f.write(file.read())
os.remove(tmp_path)
def is_link(s):
return s is not None and s.startswith('http')
def confirm_model_dir_url(model_dir, default_model_dir, default_url):
url = default_url
if model_dir is None or is_link(model_dir):
if is_link(model_dir):
url = model_dir
file_name = url.split('/')[-1][:-4]
model_dir = default_model_dir
model_dir = os.path.join(model_dir, file_name)
return model_dir, url
include LICENSE
include README.md
recursive-include ppocr/utils *.txt utility.py logging.py network.py
recursive-include ppocr/data/ *.py
recursive-include ppocr/postprocess *.py
recursive-include tools/infer *.py
recursive-include ppstructure *.py
# PaddleStructure
## pipeline介绍
PaddleStructure 是一个用于复杂板式文字OCR的工具包,流程如下
![pipeline](../doc/table/pipeline.png)
在PaddleStructure中,图片会先经由layoutparser进行版面分析,在版面分析中,会对图片里的区域进行分类,根据根据类别进行对于的ocr流程。
目前layoutparser会输出五个类别:
1. Text
2. Title
3. Figure
4. List
5. Table
1-4类走传统的OCR流程,5走表格的OCR流程。
## LayoutParser
## Table OCR
[文档](table/README_ch.md)
## PaddleStructure whl包介绍
### 使用
1. 代码使用
```python
import cv2
from paddlestructure import PaddleStructure,draw_result
table_engine = PaddleStructure(
output='./output/table',
show_log=True)
img_path = '../doc/table/1.png'
img = cv2.imread(img_path)
result = table_engine(img)
for line in result:
print(line)
from PIL import Image
font_path = 'path/tp/PaddleOCR/doc/fonts/simfang.ttf'
image = Image.open(img_path).convert('RGB')
im_show = draw_result(image, result,font_path=font_path)
im_show = Image.fromarray(im_show)
im_show.save('result.jpg')
```
2. 命令行使用
```bash
paddlestructure --image_dir=../doc/table/1.png
```
### 参数说明
\ No newline at end of file
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .paddlestructure import PaddleStructure, draw_result, to_excel
__all__ = ['PaddleStructure', 'draw_result', 'to_excel']
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import sys
__dir__ = os.path.dirname(__file__)
sys.path.append(__dir__)
sys.path.append(os.path.join(__dir__, '..'))
import cv2
import numpy as np
from pathlib import Path
from ppocr.utils.logging import get_logger
from ppstructure.predict_system import OCRSystem, save_res
from ppstructure.table.predict_table import to_excel
from ppstructure.utility import init_args, draw_result
logger = get_logger()
from ppocr.utils.utility import check_and_read_gif, get_image_file_list
from ppocr.utils.network import maybe_download, download_with_progressbar, confirm_model_dir_url, is_link
__all__ = ['PaddleStructure', 'draw_result', 'to_excel']
VERSION = '2.1'
BASE_DIR = os.path.expanduser("~/.paddlestructure/")
model_urls = {
'det': 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_det_infer.tar',
'rec': 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_rec_infer.tar',
'structure': 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar'
}
def parse_args(mMain=True):
import argparse
parser = init_args()
parser.add_help = mMain
for action in parser._actions:
if action.dest in ['rec_char_dict_path', 'structure_char_dict_path']:
action.default = None
if mMain:
return parser.parse_args()
else:
inference_args_dict = {}
for action in parser._actions:
inference_args_dict[action.dest] = action.default
return argparse.Namespace(**inference_args_dict)
class PaddleStructure(OCRSystem):
def __init__(self, **kwargs):
params = parse_args(mMain=False)
params.__dict__.update(**kwargs)
if params.show_log:
logger.setLevel(logging.DEBUG)
params.use_angle_cls = False
# init model dir
params.det_model_dir, det_url = confirm_model_dir_url(params.det_model_dir,
os.path.join(BASE_DIR, VERSION, 'det'),
model_urls['det'])
params.rec_model_dir, rec_url = confirm_model_dir_url(params.rec_model_dir,
os.path.join(BASE_DIR, VERSION, 'rec'),
model_urls['rec'])
params.structure_model_dir, structure_url = confirm_model_dir_url(params.structure_model_dir,
os.path.join(BASE_DIR, VERSION, 'structure'),
model_urls['structure'])
# download model
maybe_download(params.det_model_dir, det_url)
maybe_download(params.rec_model_dir, rec_url)
maybe_download(params.structure_model_dir, structure_url)
if params.rec_char_dict_path is None:
params.rec_char_type = 'EN'
if os.path.exists(str(Path(__file__).parent / 'ppocr/utils/dict/table_dict.txt')):
params.rec_char_dict_path = str(Path(__file__).parent / 'ppocr/utils/dict/table_dict.txt')
else:
params.rec_char_dict_path = str(Path(__file__).parent.parent / 'ppocr/utils/dict/table_dict.txt')
if params.structure_char_dict_path is None:
if os.path.exists(str(Path(__file__).parent / 'ppocr/utils/dict/table_structure_dict.txt')):
params.structure_char_dict_path = str(
Path(__file__).parent / 'ppocr/utils/dict/table_structure_dict.txt')
else:
params.structure_char_dict_path = str(
Path(__file__).parent.parent / 'ppocr/utils/dict/table_structure_dict.txt')
print(params)
super().__init__(params)
def __call__(self, img):
if isinstance(img, str):
# download net image
if img.startswith('http'):
download_with_progressbar(img, 'tmp.jpg')
img = 'tmp.jpg'
image_file = img
img, flag = check_and_read_gif(image_file)
if not flag:
with open(image_file, 'rb') as f:
np_arr = np.frombuffer(f.read(), dtype=np.uint8)
img = cv2.imdecode(np_arr, cv2.IMREAD_COLOR)
if img is None:
logger.error("error in loading image:{}".format(image_file))
return None
if isinstance(img, np.ndarray) and len(img.shape) == 2:
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
res = super().__call__(img)
return res
def main():
# for cmd
args = parse_args(mMain=True)
image_dir = args.image_dir
save_folder = args.output
if image_dir.startswith('http'):
download_with_progressbar(image_dir, 'tmp.jpg')
image_file_list = ['tmp.jpg']
else:
image_file_list = get_image_file_list(args.image_dir)
if len(image_file_list) == 0:
logger.error('no images find in {}'.format(args.image_dir))
return
structure_engine = PaddleStructure(**(args.__dict__))
for img_path in image_file_list:
img_name = os.path.basename(img_path).split('.')[0]
logger.info('{}{}{}'.format('*' * 10, img_path, '*' * 10))
result = structure_engine(img_path)
for item in result:
logger.info(item['res'])
save_res(result, save_folder, img_name)
logger.info('result save to {}'.format(os.path.join(save_folder, img_name)))
\ No newline at end of file
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import subprocess
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
import cv2
import numpy as np
import time
import layoutparser as lp
from ppocr.utils.utility import get_image_file_list, check_and_read_gif
from ppocr.utils.logging import get_logger
from tools.infer.predict_system import TextSystem
from ppstructure.table.predict_table import TableSystem, to_excel
from ppstructure.utility import parse_args,draw_result
logger = get_logger()
class OCRSystem(object):
def __init__(self, args):
args.det_limit_type = 'resize_long'
args.drop_score = 0
self.text_system = TextSystem(args)
self.table_system = TableSystem(args, self.text_system.text_detector, self.text_system.text_recognizer)
self.table_layout = lp.PaddleDetectionLayoutModel("lp://PubLayNet/ppyolov2_r50vd_dcn_365e_publaynet/config",
threshold=0.5, enable_mkldnn=args.enable_mkldnn,
enforce_cpu=not args.use_gpu, thread_num=args.cpu_threads)
self.use_angle_cls = args.use_angle_cls
self.drop_score = args.drop_score
def __call__(self, img):
ori_im = img.copy()
layout_res = self.table_layout.detect(img[..., ::-1])
res_list = []
for region in layout_res:
x1, y1, x2, y2 = region.coordinates
x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
roi_img = ori_im[y1:y2, x1:x2, :]
if region.type == 'Table':
res = self.table_system(roi_img)
else:
filter_boxes, filter_rec_res = self.text_system(roi_img)
filter_boxes = [x + [x1, y1] for x in filter_boxes]
filter_boxes = [x.reshape(-1).tolist() for x in filter_boxes]
res = (filter_boxes, filter_rec_res)
res_list.append({'type': region.type, 'bbox': [x1, y1, x2, y2], 'res': res})
return res_list
def save_res(res, save_folder, img_name):
excel_save_folder = os.path.join(save_folder, img_name)
os.makedirs(excel_save_folder, exist_ok=True)
# save res
for region in res:
if region['type'] == 'Table':
excel_path = os.path.join(excel_save_folder, '{}.xlsx'.format(region['bbox']))
to_excel(region['res'], excel_path)
elif region['type'] == 'Figure':
pass
else:
with open(os.path.join(excel_save_folder, 'res.txt'), 'a', encoding='utf8') as f:
for box, rec_res in zip(region['res'][0], region['res'][1]):
f.write('{}\t{}\n'.format(np.array(box).reshape(-1).tolist(), rec_res))
def main(args):
image_file_list = get_image_file_list(args.image_dir)
image_file_list = image_file_list
image_file_list = image_file_list[args.process_id::args.total_process_num]
save_folder = args.output
os.makedirs(save_folder, exist_ok=True)
structure_sys = OCRSystem(args)
img_num = len(image_file_list)
for i, image_file in enumerate(image_file_list):
logger.info("[{}/{}] {}".format(i, img_num, image_file))
img, flag = check_and_read_gif(image_file)
img_name = os.path.basename(image_file).split('.')[0]
if not flag:
img = cv2.imread(image_file)
if img is None:
logger.error("error in loading image:{}".format(image_file))
continue
starttime = time.time()
res = structure_sys(img)
save_res(res, save_folder, img_name)
draw_img = draw_result(img,res, args.vis_font_path)
cv2.imwrite(os.path.join(save_folder, img_name, 'show.jpg'), draw_img)
logger.info('result save to {}'.format(os.path.join(save_folder, img_name)))
elapse = time.time() - starttime
logger.info("Predict time : {:.3f}s".format(elapse))
if __name__ == "__main__":
args = parse_args()
if args.use_mp:
p_list = []
total_process_num = args.total_process_num
for process_id in range(total_process_num):
cmd = [sys.executable, "-u"] + sys.argv + [
"--process_id={}".format(process_id),
"--use_mp={}".format(False)
]
p = subprocess.Popen(cmd, stdout=sys.stdout, stderr=sys.stdout)
p_list.append(p)
for p in p_list:
p.wait()
else:
main(args)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from setuptools import setup
from io import open
import shutil
with open('../requirements.txt', encoding="utf-8-sig") as f:
requirements = f.readlines()
requirements.append('tqdm')
requirements.append('layoutparser')
requirements.append('iopath')
def readme():
with open('README_ch.md', encoding="utf-8-sig") as f:
README = f.read()
return README
shutil.copytree('../ppstructure/table', './ppstructure/table')
shutil.copyfile('../ppstructure/predict_system.py', './ppstructure/predict_system.py')
shutil.copyfile('../ppstructure/utility.py', './ppstructure/utility.py')
shutil.copytree('../ppocr', './ppocr')
shutil.copytree('../tools', './tools')
shutil.copyfile('../LICENSE', './LICENSE')
setup(
name='paddlestructure',
packages=['paddlestructure'],
package_dir={'paddlestructure': ''},
include_package_data=True,
entry_points={"console_scripts": ["paddlestructure= paddlestructure.paddlestructure:main"]},
version='1.0',
install_requires=requirements,
license='Apache License 2.0',
description='Awesome OCR toolkits based on PaddlePaddle (8.6M ultra-lightweight pre-trained model, support training and deployment among server, mobile, embeded and IoT devices',
long_description=readme(),
long_description_content_type='text/markdown',
url='https://github.com/PaddlePaddle/PaddleOCR',
download_url='https://github.com/PaddlePaddle/PaddleOCR.git',
keywords=[
'ocr textdetection textrecognition paddleocr crnn east star-net rosetta ocrlite db chineseocr chinesetextdetection chinesetextrecognition'
],
classifiers=[
'Intended Audience :: Developers', 'Operating System :: OS Independent',
'Natural Language :: Chinese (Simplified)',
'Programming Language :: Python :: 3',
'Programming Language :: Python :: 3.2',
'Programming Language :: Python :: 3.3',
'Programming Language :: Python :: 3.4',
'Programming Language :: Python :: 3.5',
'Programming Language :: Python :: 3.6',
'Programming Language :: Python :: 3.7', 'Topic :: Utilities'
], )
shutil.rmtree('ppocr')
shutil.rmtree('tools')
shutil.rmtree('ppstructure')
os.remove('LICENSE')
# 表格结构和内容预测
## pipeline
表格的ocr主要包含三个模型
1. 单行文本检测-DB
2. 单行文本识别-CRNN
3. 表格结构和cell坐标预测-RARE
具体流程图如下
![tableocr_pipeline](../../doc/table/tableocr_pipeline.png)
1. 图片由单行文字检测检测到单行文字的坐标,然后送入识别模型拿到识别结果。
2. 图片由表格结构和cell坐标预测拿到表格的结构信息和单元格的坐标信息。
3. 由单行文字的坐标、识别结果和单元格的坐标一起组合出单元格的识别结果。
4. 单元格的识别结果和表格结构一起构造表格的html字符串。
## 使用
### 训练
TBD
### 评估
先cd到PaddleOCR/ppstructure目录下
表格使用 TEDS(Tree-Edit-Distance-based Similarity) 作为模型的评估指标。在进行模型评估之前,需要将pipeline中的三个模型分别导出为inference模型(我们已经提供好),还需要准备评估的gt, gt示例如下:
```json
{"PMC4289340_004_00.png": [["<html>", "<body>", "<table>", "<thead>", "<tr>", "<td>", "</td>", "<td>", "</td>", "<td>", "</td>", "</tr>", "</thead>", "<tbody>", "<tr>", "<td>", "</td>", "<td>", "</td>", "<td>", "</td>", "</tr>", "</tbody>", "</table>", "</body>", "</html>"], [[1, 4, 29, 13], [137, 4, 161, 13], [215, 4, 236, 13], [1, 17, 30, 27], [137, 17, 147, 27], [215, 17, 225, 27]], [["<b>", "F", "e", "a", "t", "u", "r", "e", "</b>"], ["<b>", "G", "b", "3", " ", "+", "</b>"], ["<b>", "G", "b", "3", " ", "-", "</b>"], ["<b>", "P", "a", "t", "i", "e", "n", "t", "s", "</b>"], ["6", "2"], ["4", "5"]]]}
```
示例对应的表格如下
![tableocr_pipeline](../../doc/table/table_example.png)
准备完成后使用如下命令进行评估,评估完成后会输出teds指标。
```python
python3 table/eval_table.py --det_model_dir=path/to/det_model_dir --rec_model_dir=path/to/rec_model_dir --structure_model_dir=path/to/structure_model_dir --image_dir=../doc/table/1.png --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --rec_char_type=EN --det_limit_side_len=736 --det_limit_type=min --gt_path=path/to/gt.json
```
### 预测
先cd到PaddleOCR/ppstructure目录下
```python
python3 table/predict_table.py --det_model_dir=path/to/det_model_dir --rec_model_dir=path/to/rec_model_dir --structure_model_dir=path/to/structure_model_dir --image_dir=../doc/table/1.png --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --rec_char_type=EN --det_limit_side_len=736 --det_limit_type=min --output ../output/table
```
运行完成后,每张图片的excel表格会保存到table_output字段指定的目录下
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment