Commit 2735e9e3 authored by LDOUBLEV's avatar LDOUBLEV
Browse files

Merge branch 'dygraph' of https://github.com/PaddlePaddle/PaddleOCR into dyg_db

parents 493a7171 52671b7d
doc/imgs_results/det_res_2.jpg

79.5 KB | W: | H:

doc/imgs_results/det_res_2.jpg

77.3 KB | W: | H:

doc/imgs_results/det_res_2.jpg
doc/imgs_results/det_res_2.jpg
doc/imgs_results/det_res_2.jpg
doc/imgs_results/det_res_2.jpg
  • 2-up
  • Swipe
  • Onion skin
doc/joinus.PNG

15.7 KB | W: | H:

doc/joinus.PNG

408 KB | W: | H:

doc/joinus.PNG
doc/joinus.PNG
doc/joinus.PNG
doc/joinus.PNG
  • 2-up
  • Swipe
  • Onion skin
...@@ -35,44 +35,45 @@ __all__ = ['PaddleOCR'] ...@@ -35,44 +35,45 @@ __all__ = ['PaddleOCR']
model_urls = { model_urls = {
'det': 'det':
'https://paddleocr.bj.bcebos.com/20-09-22/mobile/det/ch_ppocr_mobile_v1.1_det_infer.tar', 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar',
'rec': { 'rec': {
'ch': { 'ch': {
'url': 'url':
'https://paddleocr.bj.bcebos.com/20-09-22/mobile/rec/ch_ppocr_mobile_v1.1_rec_infer.tar', 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/ppocr_keys_v1.txt' 'dict_path': './ppocr/utils/ppocr_keys_v1.txt'
}, },
'en': { 'en': {
'url': 'url':
'https://paddleocr.bj.bcebos.com/20-09-22/mobile/en/en_ppocr_mobile_v1.1_rec_infer.tar', 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/en_number_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/ic15_dict.txt' 'dict_path': './ppocr/utils/dict/en_dict.txt'
}, },
'french': { 'french': {
'url': 'url':
'https://paddleocr.bj.bcebos.com/20-09-22/mobile/fr/french_ppocr_mobile_v1.1_rec_infer.tar', 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/french_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/french_dict.txt' 'dict_path': './ppocr/utils/dict/french_dict.txt'
}, },
'german': { 'german': {
'url': 'url':
'https://paddleocr.bj.bcebos.com/20-09-22/mobile/ge/german_ppocr_mobile_v1.1_rec_infer.tar', 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/german_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/german_dict.txt' 'dict_path': './ppocr/utils/dict/german_dict.txt'
}, },
'korean': { 'korean': {
'url': 'url':
'https://paddleocr.bj.bcebos.com/20-09-22/mobile/kr/korean_ppocr_mobile_v1.1_rec_infer.tar', 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/korean_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/korean_dict.txt' 'dict_path': './ppocr/utils/dict/korean_dict.txt'
}, },
'japan': { 'japan': {
'url': 'url':
'https://paddleocr.bj.bcebos.com/20-09-22/mobile/jp/japan_ppocr_mobile_v1.1_rec_infer.tar', 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/japan_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/japan_dict.txt' 'dict_path': './ppocr/utils/dict/japan_dict.txt'
} }
}, },
'cls': 'cls':
'https://paddleocr.bj.bcebos.com/20-09-22/cls/ch_ppocr_mobile_v1.1_cls_infer.tar' 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_infer.tar'
} }
SUPPORT_DET_MODEL = ['DB'] SUPPORT_DET_MODEL = ['DB']
VERSION = 2.0
SUPPORT_REC_MODEL = ['CRNN'] SUPPORT_REC_MODEL = ['CRNN']
BASE_DIR = os.path.expanduser("~/.paddleocr/") BASE_DIR = os.path.expanduser("~/.paddleocr/")
...@@ -94,20 +95,24 @@ def download_with_progressbar(url, save_path): ...@@ -94,20 +95,24 @@ def download_with_progressbar(url, save_path):
def maybe_download(model_storage_directory, url): def maybe_download(model_storage_directory, url):
# using custom model # using custom model
if not os.path.exists(os.path.join( tar_file_name_list = [
model_storage_directory, 'model')) or not os.path.exists( 'inference.pdiparams', 'inference.pdiparams.info', 'inference.pdmodel'
os.path.join(model_storage_directory, 'params')): ]
if not os.path.exists(
os.path.join(model_storage_directory, 'inference.pdiparams')
) or not os.path.exists(
os.path.join(model_storage_directory, 'inference.pdmodel')):
tmp_path = os.path.join(model_storage_directory, url.split('/')[-1]) tmp_path = os.path.join(model_storage_directory, url.split('/')[-1])
print('download {} to {}'.format(url, tmp_path)) print('download {} to {}'.format(url, tmp_path))
os.makedirs(model_storage_directory, exist_ok=True) os.makedirs(model_storage_directory, exist_ok=True)
download_with_progressbar(url, tmp_path) download_with_progressbar(url, tmp_path)
with tarfile.open(tmp_path, 'r') as tarObj: with tarfile.open(tmp_path, 'r') as tarObj:
for member in tarObj.getmembers(): for member in tarObj.getmembers():
if "model" in member.name: filename = None
filename = 'model' for tar_file_name in tar_file_name_list:
elif "params" in member.name: if tar_file_name in member.name:
filename = 'params' filename = tar_file_name
else: if filename is None:
continue continue
file = tarObj.extractfile(member) file = tarObj.extractfile(member)
with open( with open(
...@@ -176,43 +181,43 @@ def parse_args(mMain=True, add_help=True): ...@@ -176,43 +181,43 @@ def parse_args(mMain=True, add_help=True):
parser.add_argument("--use_angle_cls", type=str2bool, default=False) parser.add_argument("--use_angle_cls", type=str2bool, default=False)
return parser.parse_args() return parser.parse_args()
else: else:
return argparse.Namespace(use_gpu=True, return argparse.Namespace(
ir_optim=True, use_gpu=True,
use_tensorrt=False, ir_optim=True,
gpu_mem=8000, use_tensorrt=False,
image_dir='', gpu_mem=8000,
det_algorithm='DB', image_dir='',
det_model_dir=None, det_algorithm='DB',
det_limit_side_len=960, det_model_dir=None,
det_limit_type='max', det_limit_side_len=960,
det_db_thresh=0.3, det_limit_type='max',
det_db_box_thresh=0.5, det_db_thresh=0.3,
det_db_unclip_ratio=2.0, det_db_box_thresh=0.5,
det_east_score_thresh=0.8, det_db_unclip_ratio=2.0,
det_east_cover_thresh=0.1, det_east_score_thresh=0.8,
det_east_nms_thresh=0.2, det_east_cover_thresh=0.1,
rec_algorithm='CRNN', det_east_nms_thresh=0.2,
rec_model_dir=None, rec_algorithm='CRNN',
rec_image_shape="3, 32, 320", rec_model_dir=None,
rec_char_type='ch', rec_image_shape="3, 32, 320",
rec_batch_num=30, rec_char_type='ch',
max_text_length=25, rec_batch_num=30,
rec_char_dict_path=None, max_text_length=25,
use_space_char=True, rec_char_dict_path=None,
drop_score=0.5, use_space_char=True,
cls_model_dir=None, drop_score=0.5,
cls_image_shape="3, 48, 192", cls_model_dir=None,
label_list=['0', '180'], cls_image_shape="3, 48, 192",
cls_batch_num=30, label_list=['0', '180'],
cls_thresh=0.9, cls_batch_num=30,
enable_mkldnn=False, cls_thresh=0.9,
use_zero_copy_run=False, enable_mkldnn=False,
use_pdserving=False, use_zero_copy_run=False,
lang='ch', use_pdserving=False,
det=True, lang='ch',
rec=True, det=True,
use_angle_cls=False rec=True,
) use_angle_cls=False)
class PaddleOCR(predict_system.TextSystem): class PaddleOCR(predict_system.TextSystem):
...@@ -228,19 +233,21 @@ class PaddleOCR(predict_system.TextSystem): ...@@ -228,19 +233,21 @@ class PaddleOCR(predict_system.TextSystem):
lang = postprocess_params.lang lang = postprocess_params.lang
assert lang in model_urls[ assert lang in model_urls[
'rec'], 'param lang must in {}, but got {}'.format( 'rec'], 'param lang must in {}, but got {}'.format(
model_urls['rec'].keys(), lang) model_urls['rec'].keys(), lang)
if postprocess_params.rec_char_dict_path is None: if postprocess_params.rec_char_dict_path is None:
postprocess_params.rec_char_dict_path = model_urls['rec'][lang][ postprocess_params.rec_char_dict_path = model_urls['rec'][lang][
'dict_path'] 'dict_path']
# init model dir # init model dir
if postprocess_params.det_model_dir is None: if postprocess_params.det_model_dir is None:
postprocess_params.det_model_dir = os.path.join(BASE_DIR, 'det') postprocess_params.det_model_dir = os.path.join(
BASE_DIR, '{}/det'.format(VERSION))
if postprocess_params.rec_model_dir is None: if postprocess_params.rec_model_dir is None:
postprocess_params.rec_model_dir = os.path.join( postprocess_params.rec_model_dir = os.path.join(
BASE_DIR, 'rec/{}'.format(lang)) BASE_DIR, '{}/rec/{}'.format(VERSION, lang))
if postprocess_params.cls_model_dir is None: if postprocess_params.cls_model_dir is None:
postprocess_params.cls_model_dir = os.path.join(BASE_DIR, 'cls') postprocess_params.cls_model_dir = os.path.join(
BASE_DIR, '{}/cls'.format(VERSION))
print(postprocess_params) print(postprocess_params)
# download model # download model
maybe_download(postprocess_params.det_model_dir, model_urls['det']) maybe_download(postprocess_params.det_model_dir, model_urls['det'])
......
...@@ -32,9 +32,8 @@ class ClsMetric(object): ...@@ -32,9 +32,8 @@ class ClsMetric(object):
def get_metric(self): def get_metric(self):
""" """
return metircs { return metrics {
'acc': 0, 'acc': 0
'norm_edit_dis': 0,
} }
""" """
acc = self.correct_num / self.all_num acc = self.correct_num / self.all_num
......
...@@ -57,7 +57,7 @@ class DetMetric(object): ...@@ -57,7 +57,7 @@ class DetMetric(object):
def get_metric(self): def get_metric(self):
""" """
return metircs { return metrics {
'precision': 0, 'precision': 0,
'recall': 0, 'recall': 0,
'hmean': 0 'hmean': 0
......
...@@ -43,7 +43,7 @@ class RecMetric(object): ...@@ -43,7 +43,7 @@ class RecMetric(object):
def get_metric(self): def get_metric(self):
""" """
return metircs { return metrics {
'acc': 0, 'acc': 0,
'norm_edit_dis': 0, 'norm_edit_dis': 0,
} }
......
...@@ -40,7 +40,7 @@ class DBPostProcess(object): ...@@ -40,7 +40,7 @@ class DBPostProcess(object):
self.max_candidates = max_candidates self.max_candidates = max_candidates
self.unclip_ratio = unclip_ratio self.unclip_ratio = unclip_ratio
self.min_size = 3 self.min_size = 3
self.dilation_kernel = None if not use_dilation else [[1, 1], [1, 1]] self.dilation_kernel = None if not use_dilation else np.array([[1, 1], [1, 1]])
def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height): def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
''' '''
......
0
1
2
3
4
5
6
7
8
9
a
b
c
d
e
f
g
h
i
j
k
l
m
n
o
p
q
r
s
t
u
v
w
x
y
z
A
B
C
D
E
F
G
H
I
J
K
L
M
N
O
P
Q
R
S
T
U
V
W
X
Y
Z
...@@ -132,4 +132,5 @@ j ...@@ -132,4 +132,5 @@ j
³ ³
Å Å
$ $
# #
\ No newline at end of file
...@@ -123,4 +123,5 @@ z ...@@ -123,4 +123,5 @@ z
â â
å å
æ æ
é é
\ No newline at end of file
...@@ -4395,4 +4395,5 @@ z ...@@ -4395,4 +4395,5 @@ z
\ No newline at end of file
...@@ -179,7 +179,7 @@ z ...@@ -179,7 +179,7 @@ z
с с
т т
я я
...@@ -3684,4 +3684,5 @@ z ...@@ -3684,4 +3684,5 @@ z
\ No newline at end of file
...@@ -33,4 +33,4 @@ v ...@@ -33,4 +33,4 @@ v
w w
x x
y y
z z
\ No newline at end of file
...@@ -28,37 +28,16 @@ from ppocr.modeling.architectures import build_model ...@@ -28,37 +28,16 @@ from ppocr.modeling.architectures import build_model
from ppocr.postprocess import build_post_process from ppocr.postprocess import build_post_process
from ppocr.utils.save_load import init_model from ppocr.utils.save_load import init_model
from ppocr.utils.logging import get_logger from ppocr.utils.logging import get_logger
from tools.program import load_config from tools.program import load_config, merge_config, ArgsParser
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("-c", "--config", help="configuration file to use")
parser.add_argument(
"-o", "--output_path", type=str, default='./output/infer/')
return parser.parse_args()
class Model(paddle.nn.Layer):
def __init__(self, model):
super(Model, self).__init__()
self.pre_model = model
# Please modify the 'shape' according to actual needs
@to_static(input_spec=[
paddle.static.InputSpec(
shape=[None, 3, 640, 640], dtype='float32')
])
def forward(self, inputs):
x = self.pre_model(inputs)
return x
def main(): def main():
FLAGS = parse_args() FLAGS = ArgsParser().parse_args()
config = load_config(FLAGS.config) config = load_config(FLAGS.config)
merge_config(FLAGS.opt)
logger = get_logger() logger = get_logger()
# build post process # build post process
post_process_class = build_post_process(config['PostProcess'], post_process_class = build_post_process(config['PostProcess'],
config['Global']) config['Global'])
...@@ -71,9 +50,15 @@ def main(): ...@@ -71,9 +50,15 @@ def main():
init_model(config, model, logger) init_model(config, model, logger)
model.eval() model.eval()
model = Model(model) save_path = '{}/inference'.format(config['Global']['save_inference_dir'])
save_path = '{}/{}'.format(FLAGS.output_path, infer_shape = [3, 32, 100] if config['Architecture'][
config['Architecture']['model_type']) 'model_type'] != "det" else [3, 640, 640]
model = to_static(
model,
input_spec=[
paddle.static.InputSpec(
shape=[None] + infer_shape, dtype='float32')
])
paddle.jit.save(model, save_path) paddle.jit.save(model, save_path)
logger.info('inference model is saved to {}'.format(save_path)) logger.info('inference model is saved to {}'.format(save_path))
......
...@@ -63,6 +63,7 @@ class TextDetector(object): ...@@ -63,6 +63,7 @@ class TextDetector(object):
postprocess_params["box_thresh"] = args.det_db_box_thresh postprocess_params["box_thresh"] = args.det_db_box_thresh
postprocess_params["max_candidates"] = 1000 postprocess_params["max_candidates"] = 1000
postprocess_params["unclip_ratio"] = args.det_db_unclip_ratio postprocess_params["unclip_ratio"] = args.det_db_unclip_ratio
postprocess_params["use_dilation"] = True
else: else:
logger.info("unknown det_algorithm:{}".format(self.det_algorithm)) logger.info("unknown det_algorithm:{}".format(self.det_algorithm))
sys.exit(0) sys.exit(0)
...@@ -111,7 +112,7 @@ class TextDetector(object): ...@@ -111,7 +112,7 @@ class TextDetector(object):
box = self.clip_det_res(box, img_height, img_width) box = self.clip_det_res(box, img_height, img_width)
rect_width = int(np.linalg.norm(box[0] - box[1])) rect_width = int(np.linalg.norm(box[0] - box[1]))
rect_height = int(np.linalg.norm(box[0] - box[3])) rect_height = int(np.linalg.norm(box[0] - box[3]))
if rect_width <= 10 or rect_height <= 10: if rect_width <= 3 or rect_height <= 3:
continue continue
dt_boxes_new.append(box) dt_boxes_new.append(box)
dt_boxes = np.array(dt_boxes_new) dt_boxes = np.array(dt_boxes_new)
...@@ -186,4 +187,4 @@ if __name__ == "__main__": ...@@ -186,4 +187,4 @@ if __name__ == "__main__":
cv2.imwrite(img_path, src_im) cv2.imwrite(img_path, src_im)
logger.info("The visualized image saved in {}".format(img_path)) logger.info("The visualized image saved in {}".format(img_path))
if count > 1: if count > 1:
logger.info("Avg Time:", total_time / (count - 1)) logger.info("Avg Time: {}".format(total_time / (count - 1)))
...@@ -100,8 +100,8 @@ def create_predictor(args, mode, logger): ...@@ -100,8 +100,8 @@ def create_predictor(args, mode, logger):
if model_dir is None: if model_dir is None:
logger.info("not find {} model file path {}".format(mode, model_dir)) logger.info("not find {} model file path {}".format(mode, model_dir))
sys.exit(0) sys.exit(0)
model_file_path = model_dir + "/model" model_file_path = model_dir + "/inference.pdmodel"
params_file_path = model_dir + "/params" params_file_path = model_dir + "/inference.pdiparams"
if not os.path.exists(model_file_path): if not os.path.exists(model_file_path):
logger.info("not find model file path {}".format(model_file_path)) logger.info("not find model file path {}".format(model_file_path))
sys.exit(0) sys.exit(0)
......
...@@ -113,7 +113,6 @@ def merge_config(config): ...@@ -113,7 +113,6 @@ def merge_config(config):
global_config.keys(), sub_keys[0]) global_config.keys(), sub_keys[0])
cur = global_config[sub_keys[0]] cur = global_config[sub_keys[0]]
for idx, sub_key in enumerate(sub_keys[1:]): for idx, sub_key in enumerate(sub_keys[1:]):
assert (sub_key in cur)
if idx == len(sub_keys) - 2: if idx == len(sub_keys) - 2:
cur[sub_key] = value cur[sub_key] = value
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment