Unverified Commit 0458f0cc authored by zhoujun's avatar zhoujun Committed by GitHub
Browse files

Merge pull request #13 from PaddlePaddle/dygraph

Dygraph
parents 04b0318b 836839bb
0
1
2
3
4
5
6
7
8
9
a
b
c
d
e
f
g
h
i
j
k
l
m
n
o
p
q
r
s
t
u
v
w
x
y
z
A
B
C
D
E
F
G
H
I
J
K
L
M
N
O
P
Q
R
S
T
U
V
W
X
Y
Z
f
e
n
c
h
_
i
m
g
/
r
v
a
l
t
w
o
d
6
1
.
p
B
u
2
à
3
R
y
4
U
E
A
5
P
O
S
T
D
7
Z
8
I
N
L
G
M
H
0
J
K
-
9
F
C
V
é
X
'
s
Q
:
è
x
b
Y
Œ
É
z
W
Ç
È
k
Ô
ô
À
Ê
q
ù
°
ê
î
*
Â
j
"
,
â
%
û
ç
ü
?
!
;
ö
(
)
ï
º
ó
ø
å
+
á
Ë
<
²
Á
Î
&
@
œ
ε
Ü
ë
[
]
í
ò
Ö
ä
ß
«
»
ú
ñ
æ
µ
³
Å
$
#
!
"
$
%
&
'
(
)
*
+
,
-
.
/
0
1
2
3
4
5
6
7
8
9
:
;
=
>
?
@
A
B
C
D
E
F
G
H
I
J
K
L
M
N
O
P
Q
R
S
T
U
V
W
X
Y
Z
[
]
_
a
b
c
d
e
f
g
h
i
j
k
l
m
n
o
p
q
r
s
t
u
v
w
x
y
z
¡
¢
£
¤
¥
¦
§
¨
©
ª
«
¬
­
®
¯
°
±
²
³
´
µ
·
¸
¹
º
»
¼
½
¿
Â
Ã
Å
Ê
Î
Ð
á
â
å
æ
é
!
"
#
$
%
&
'
(
)
*
+
,
-
.
/
0
1
2
3
4
5
6
7
8
9
:
;
<
=
>
?
A
B
C
D
E
F
G
H
I
J
K
L
M
N
O
P
Q
R
S
T
U
V
W
X
Y
Z
[
]
_
`
a
b
c
d
e
f
g
h
i
j
k
l
m
n
o
p
q
r
s
t
u
v
w
x
y
z
©
°
²
´
½
Á
Ä
Å
Ç
È
É
Í
Ó
Ö
×
Ü
ß
à
á
â
ã
ä
å
æ
ç
è
é
ê
ë
í
ð
ñ
ò
ó
ô
õ
ö
ø
ú
û
ü
ý
ā
ă
ą
ć
Č
č
đ
ē
ė
ę
ğ
ī
ı
Ł
ł
ń
ň
ō
ř
Ş
ş
Š
š
ţ
ū
ż
Ž
ž
Ș
ș
ț
Δ
α
λ
μ
φ
Г
О
а
в
л
о
р
с
т
я
 
丿
使
便
姿
婿
宿
寿
尿
廿
忿
椿
槿
橿
殿
沿
湿
滿
漿
禿
稿
竿
簿
綿
耀
西
調
谿
貿
輿
退
駿
鹿
麿
!
"
#
$
%
&
'
*
+
-
/
0
1
2
3
4
5
6
7
8
9
:
;
<
=
>
?
A
B
C
D
E
F
G
H
I
J
K
L
M
N
O
P
Q
R
S
T
U
V
W
X
Y
Z
[
\
]
^
_
`
a
b
c
d
e
f
g
h
i
j
k
l
m
n
o
p
q
r
s
t
u
v
w
x
y
z
{
|
}
~
©
°
²
½
Á
Ä
Å
Ç
É
Í
Î
Ó
Ö
×
Ü
ß
à
á
â
ã
ä
å
æ
ç
è
é
ê
ë
ì
í
î
ï
ð
ñ
ò
ó
ô
õ
ö
ø
ú
û
ü
ý
ā
ă
ą
ć
Č
č
đ
ē
ė
ę
ě
ğ
ī
İ
ı
Ł
ł
ń
ň
ō
ř
Ş
ş
Š
š
ţ
ū
ź
ż
Ž
ž
Ș
ș
Α
Δ
α
λ
φ
Г
О
а
в
л
о
р
с
т
я
使
便
尿
彿
殿
沿
滿
西
調
輿
鹿
굿
꼿
릿
믿
퀀
...@@ -55,8 +55,8 @@ def load_dygraph_pretrain(model, logger, path=None, load_static_weights=False): ...@@ -55,8 +55,8 @@ def load_dygraph_pretrain(model, logger, path=None, load_static_weights=False):
weight_name = weight_name.replace('binarize', '').replace( weight_name = weight_name.replace('binarize', '').replace(
'thresh', '') # for DB 'thresh', '') # for DB
if weight_name in pre_state_dict.keys(): if weight_name in pre_state_dict.keys():
logger.info('Load weight: {}, shape: {}'.format( # logger.info('Load weight: {}, shape: {}'.format(
weight_name, pre_state_dict[weight_name].shape)) # weight_name, pre_state_dict[weight_name].shape))
if 'encoder_rnn' in key: if 'encoder_rnn' in key:
# delete axis which is 1 # delete axis which is 1
pre_state_dict[weight_name] = pre_state_dict[ pre_state_dict[weight_name] = pre_state_dict[
......
...@@ -32,7 +32,7 @@ setup( ...@@ -32,7 +32,7 @@ setup(
package_dir={'paddleocr': ''}, package_dir={'paddleocr': ''},
include_package_data=True, include_package_data=True,
entry_points={"console_scripts": ["paddleocr= paddleocr.paddleocr:main"]}, entry_points={"console_scripts": ["paddleocr= paddleocr.paddleocr:main"]},
version='0.0.3', version='2.0',
install_requires=requirements, install_requires=requirements,
license='Apache License 2.0', license='Apache License 2.0',
description='Awesome OCR toolkits based on PaddlePaddle (8.6M ultra-lightweight pre-trained model, support training and deployment among server, mobile, embeded and IoT devices', description='Awesome OCR toolkits based on PaddlePaddle (8.6M ultra-lightweight pre-trained model, support training and deployment among server, mobile, embeded and IoT devices',
......
...@@ -39,26 +39,12 @@ def parse_args(): ...@@ -39,26 +39,12 @@ def parse_args():
return parser.parse_args() return parser.parse_args()
class Model(paddle.nn.Layer):
def __init__(self, model):
super(Model, self).__init__()
self.pre_model = model
# Please modify the 'shape' according to actual needs
@to_static(input_spec=[
paddle.static.InputSpec(
shape=[None, 3, 640, 640], dtype='float32')
])
def forward(self, inputs):
x = self.pre_model(inputs)
return x
def main(): def main():
FLAGS = parse_args() FLAGS = parse_args()
config = load_config(FLAGS.config) config = load_config(FLAGS.config)
logger = get_logger() logger = get_logger()
# build post process # build post process
post_process_class = build_post_process(config['PostProcess'], post_process_class = build_post_process(config['PostProcess'],
config['Global']) config['Global'])
...@@ -71,9 +57,16 @@ def main(): ...@@ -71,9 +57,16 @@ def main():
init_model(config, model, logger) init_model(config, model, logger)
model.eval() model.eval()
model = Model(model) save_path = '{}/{}/inference'.format(FLAGS.output_path,
save_path = '{}/{}'.format(FLAGS.output_path, config['Architecture']['model_type'])
config['Architecture']['model_type']) infer_shape = [3, 32, 100] if config['Architecture'][
'model_type'] != "det" else [3, 640, 640]
model = to_static(
model,
input_spec=[
paddle.static.InputSpec(
shape=[None] + infer_shape, dtype='float32')
])
paddle.jit.save(model, save_path) paddle.jit.save(model, save_path)
logger.info('inference model is saved to {}'.format(save_path)) logger.info('inference model is saved to {}'.format(save_path))
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
import os import os
import sys import sys
__dir__ = os.path.dirname(os.path.abspath(__file__)) __dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__) sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..'))) sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
...@@ -30,12 +31,15 @@ from ppocr.utils.utility import get_image_file_list, check_and_read_gif ...@@ -30,12 +31,15 @@ from ppocr.utils.utility import get_image_file_list, check_and_read_gif
from ppocr.utils.logging import get_logger from ppocr.utils.logging import get_logger
from tools.infer.utility import draw_ocr_box_txt from tools.infer.utility import draw_ocr_box_txt
logger = get_logger()
class TextSystem(object): class TextSystem(object):
def __init__(self, args): def __init__(self, args):
self.text_detector = predict_det.TextDetector(args) self.text_detector = predict_det.TextDetector(args)
self.text_recognizer = predict_rec.TextRecognizer(args) self.text_recognizer = predict_rec.TextRecognizer(args)
self.use_angle_cls = args.use_angle_cls self.use_angle_cls = args.use_angle_cls
self.drop_score = args.drop_score
if self.use_angle_cls: if self.use_angle_cls:
self.text_classifier = predict_cls.TextClassifier(args) self.text_classifier = predict_cls.TextClassifier(args)
...@@ -103,7 +107,13 @@ class TextSystem(object): ...@@ -103,7 +107,13 @@ class TextSystem(object):
logger.info("rec_res num : {}, elapse : {}".format( logger.info("rec_res num : {}, elapse : {}".format(
len(rec_res), elapse)) len(rec_res), elapse))
# self.print_draw_crop_rec_res(img_crop_list, rec_res) # self.print_draw_crop_rec_res(img_crop_list, rec_res)
return dt_boxes, rec_res filter_boxes, filter_rec_res = [], []
for box, rec_reuslt in zip(dt_boxes, rec_res):
text, score = rec_reuslt
if score >= self.drop_score:
filter_boxes.append(box)
filter_rec_res.append(rec_reuslt)
return filter_boxes, filter_rec_res
def sorted_boxes(dt_boxes): def sorted_boxes(dt_boxes):
...@@ -119,8 +129,8 @@ def sorted_boxes(dt_boxes): ...@@ -119,8 +129,8 @@ def sorted_boxes(dt_boxes):
_boxes = list(sorted_boxes) _boxes = list(sorted_boxes)
for i in range(num_boxes - 1): for i in range(num_boxes - 1):
if abs(_boxes[i+1][0][1] - _boxes[i][0][1]) < 10 and \ if abs(_boxes[i + 1][0][1] - _boxes[i][0][1]) < 10 and \
(_boxes[i + 1][0][0] < _boxes[i][0][0]): (_boxes[i + 1][0][0] < _boxes[i][0][0]):
tmp = _boxes[i] tmp = _boxes[i]
_boxes[i] = _boxes[i + 1] _boxes[i] = _boxes[i + 1]
_boxes[i + 1] = tmp _boxes[i + 1] = tmp
...@@ -145,12 +155,8 @@ def main(args): ...@@ -145,12 +155,8 @@ def main(args):
elapse = time.time() - starttime elapse = time.time() - starttime
logger.info("Predict time of %s: %.3fs" % (image_file, elapse)) logger.info("Predict time of %s: %.3fs" % (image_file, elapse))
dt_num = len(dt_boxes) for text, score in rec_res:
for dno in range(dt_num): logger.info("{}, {:.3f}".format(text, score))
text, score = rec_res[dno]
if score >= drop_score:
text_str = "%s, %.3f" % (text, score)
logger.info(text_str)
if is_visualize: if is_visualize:
image = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)) image = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
...@@ -176,5 +182,4 @@ def main(args): ...@@ -176,5 +182,4 @@ def main(args):
if __name__ == "__main__": if __name__ == "__main__":
logger = get_logger() main(utility.parse_args())
main(utility.parse_args()) \ No newline at end of file
...@@ -100,8 +100,8 @@ def create_predictor(args, mode, logger): ...@@ -100,8 +100,8 @@ def create_predictor(args, mode, logger):
if model_dir is None: if model_dir is None:
logger.info("not find {} model file path {}".format(mode, model_dir)) logger.info("not find {} model file path {}".format(mode, model_dir))
sys.exit(0) sys.exit(0)
model_file_path = model_dir + "/model" model_file_path = model_dir + ".pdmodel"
params_file_path = model_dir + "/params" params_file_path = model_dir + ".pdiparams"
if not os.path.exists(model_file_path): if not os.path.exists(model_file_path):
logger.info("not find model file path {}".format(model_file_path)) logger.info("not find model file path {}".format(model_file_path))
sys.exit(0) sys.exit(0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment