Unverified Commit 0458f0cc authored by zhoujun's avatar zhoujun Committed by GitHub
Browse files

Merge pull request #13 from PaddlePaddle/dygraph

Dygraph
parents 04b0318b 836839bb
0
1
2
3
4
5
6
7
8
9
a
b
c
d
e
f
g
h
i
j
k
l
m
n
o
p
q
r
s
t
u
v
w
x
y
z
A
B
C
D
E
F
G
H
I
J
K
L
M
N
O
P
Q
R
S
T
U
V
W
X
Y
Z
f
e
n
c
h
_
i
m
g
/
r
v
a
l
t
w
o
d
6
1
.
p
B
u
2
à
3
R
y
4
U
E
A
5
P
O
S
T
D
7
Z
8
I
N
L
G
M
H
0
J
K
-
9
F
C
V
é
X
'
s
Q
:
è
x
b
Y
Œ
É
z
W
Ç
È
k
Ô
ô
À
Ê
q
ù
°
ê
î
*
Â
j
"
,
â
%
û
ç
ü
?
!
;
ö
(
)
ï
º
ó
ø
å
+
á
Ë
<
²
Á
Î
&
@
œ
ε
Ü
ë
[
]
í
ò
Ö
ä
ß
«
»
ú
ñ
æ
µ
³
Å
$
#
!
"
$
%
&
'
(
)
*
+
,
-
.
/
0
1
2
3
4
5
6
7
8
9
:
;
=
>
?
@
A
B
C
D
E
F
G
H
I
J
K
L
M
N
O
P
Q
R
S
T
U
V
W
X
Y
Z
[
]
_
a
b
c
d
e
f
g
h
i
j
k
l
m
n
o
p
q
r
s
t
u
v
w
x
y
z
¡
¢
£
¤
¥
¦
§
¨
©
ª
«
¬
­
®
¯
°
±
²
³
´
µ
·
¸
¹
º
»
¼
½
¿
Â
Ã
Å
Ê
Î
Ð
á
â
å
æ
é
!
"
#
$
%
&
'
(
)
*
+
,
-
.
/
0
1
2
3
4
5
6
7
8
9
:
;
<
=
>
?
A
B
C
D
E
F
G
H
I
J
K
L
M
N
O
P
Q
R
S
T
U
V
W
X
Y
Z
[
]
_
`
a
b
c
d
e
f
g
h
i
j
k
l
m
n
o
p
q
r
s
t
u
v
w
x
y
z
©
°
²
´
½
Á
Ä
Å
Ç
È
É
Í
Ó
Ö
×
Ü
ß
à
á
â
ã
ä
å
æ
ç
è
é
ê
ë
í
ð
ñ
ò
ó
ô
õ
ö
ø
ú
û
ü
ý
ā
ă
ą
ć
Č
č
đ
ē
ė
ę
ğ
ī
ı
Ł
ł
ń
ň
ō
ř
Ş
ş
Š
š
ţ
ū
ż
Ž
ž
Ș
ș
ț
Δ
α
λ
μ
φ
Г
О
а
в
л
о
р
с
т
я
 
丿
使
便
姿
婿
宿
寿
尿
廿
忿
椿
槿
橿
殿
沿
湿
滿
漿
禿
稿
竿
簿
綿
耀
西
調
谿
貿
輿
退
駿
鹿
麿
!
"
#
$
%
&
'
*
+
-
/
0
1
2
3
4
5
6
7
8
9
:
;
<
=
>
?
A
B
C
D
E
F
G
H
I
J
K
L
M
N
O
P
Q
R
S
T
U
V
W
X
Y
Z
[
\
]
^
_
`
a
b
c
d
e
f
g
h
i
j
k
l
m
n
o
p
q
r
s
t
u
v
w
x
y
z
{
|
}
~
©
°
²
½
Á
Ä
Å
Ç
É
Í
Î
Ó
Ö
×
Ü
ß
à
á
â
ã
ä
å
æ
ç
è
é
ê
ë
ì
í
î
ï
ð
ñ
ò
ó
ô
õ
ö
ø
ú
û
ü
ý
ā
ă
ą
ć
Č
č
đ
ē
ė
ę
ě
ğ
ī
İ
ı
Ł
ł
ń
ň
ō
ř
Ş
ş
Š
š
ţ
ū
ź
ż
Ž
ž
Ș
ș
Α
Δ
α
λ
φ
Г
О
а
в
л
о
р
с
т
я
使
便
尿
彿
殿
沿
滿
西
調
輿
鹿
굿
꼿
릿
믿
퀀
......@@ -55,8 +55,8 @@ def load_dygraph_pretrain(model, logger, path=None, load_static_weights=False):
weight_name = weight_name.replace('binarize', '').replace(
'thresh', '') # for DB
if weight_name in pre_state_dict.keys():
logger.info('Load weight: {}, shape: {}'.format(
weight_name, pre_state_dict[weight_name].shape))
# logger.info('Load weight: {}, shape: {}'.format(
# weight_name, pre_state_dict[weight_name].shape))
if 'encoder_rnn' in key:
# delete axis which is 1
pre_state_dict[weight_name] = pre_state_dict[
......
......@@ -32,7 +32,7 @@ setup(
package_dir={'paddleocr': ''},
include_package_data=True,
entry_points={"console_scripts": ["paddleocr= paddleocr.paddleocr:main"]},
version='0.0.3',
version='2.0',
install_requires=requirements,
license='Apache License 2.0',
description='Awesome OCR toolkits based on PaddlePaddle (8.6M ultra-lightweight pre-trained model, support training and deployment among server, mobile, embeded and IoT devices',
......
......@@ -39,26 +39,12 @@ def parse_args():
return parser.parse_args()
class Model(paddle.nn.Layer):
def __init__(self, model):
super(Model, self).__init__()
self.pre_model = model
# Please modify the 'shape' according to actual needs
@to_static(input_spec=[
paddle.static.InputSpec(
shape=[None, 3, 640, 640], dtype='float32')
])
def forward(self, inputs):
x = self.pre_model(inputs)
return x
def main():
FLAGS = parse_args()
config = load_config(FLAGS.config)
logger = get_logger()
# build post process
post_process_class = build_post_process(config['PostProcess'],
config['Global'])
......@@ -71,9 +57,16 @@ def main():
init_model(config, model, logger)
model.eval()
model = Model(model)
save_path = '{}/{}'.format(FLAGS.output_path,
save_path = '{}/{}/inference'.format(FLAGS.output_path,
config['Architecture']['model_type'])
infer_shape = [3, 32, 100] if config['Architecture'][
'model_type'] != "det" else [3, 640, 640]
model = to_static(
model,
input_spec=[
paddle.static.InputSpec(
shape=[None] + infer_shape, dtype='float32')
])
paddle.jit.save(model, save_path)
logger.info('inference model is saved to {}'.format(save_path))
......
......@@ -13,6 +13,7 @@
# limitations under the License.
import os
import sys
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
......@@ -30,12 +31,15 @@ from ppocr.utils.utility import get_image_file_list, check_and_read_gif
from ppocr.utils.logging import get_logger
from tools.infer.utility import draw_ocr_box_txt
logger = get_logger()
class TextSystem(object):
def __init__(self, args):
self.text_detector = predict_det.TextDetector(args)
self.text_recognizer = predict_rec.TextRecognizer(args)
self.use_angle_cls = args.use_angle_cls
self.drop_score = args.drop_score
if self.use_angle_cls:
self.text_classifier = predict_cls.TextClassifier(args)
......@@ -103,7 +107,13 @@ class TextSystem(object):
logger.info("rec_res num : {}, elapse : {}".format(
len(rec_res), elapse))
# self.print_draw_crop_rec_res(img_crop_list, rec_res)
return dt_boxes, rec_res
filter_boxes, filter_rec_res = [], []
for box, rec_reuslt in zip(dt_boxes, rec_res):
text, score = rec_reuslt
if score >= self.drop_score:
filter_boxes.append(box)
filter_rec_res.append(rec_reuslt)
return filter_boxes, filter_rec_res
def sorted_boxes(dt_boxes):
......@@ -119,7 +129,7 @@ def sorted_boxes(dt_boxes):
_boxes = list(sorted_boxes)
for i in range(num_boxes - 1):
if abs(_boxes[i+1][0][1] - _boxes[i][0][1]) < 10 and \
if abs(_boxes[i + 1][0][1] - _boxes[i][0][1]) < 10 and \
(_boxes[i + 1][0][0] < _boxes[i][0][0]):
tmp = _boxes[i]
_boxes[i] = _boxes[i + 1]
......@@ -145,12 +155,8 @@ def main(args):
elapse = time.time() - starttime
logger.info("Predict time of %s: %.3fs" % (image_file, elapse))
dt_num = len(dt_boxes)
for dno in range(dt_num):
text, score = rec_res[dno]
if score >= drop_score:
text_str = "%s, %.3f" % (text, score)
logger.info(text_str)
for text, score in rec_res:
logger.info("{}, {:.3f}".format(text, score))
if is_visualize:
image = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
......@@ -176,5 +182,4 @@ def main(args):
if __name__ == "__main__":
logger = get_logger()
main(utility.parse_args())
\ No newline at end of file
......@@ -100,8 +100,8 @@ def create_predictor(args, mode, logger):
if model_dir is None:
logger.info("not find {} model file path {}".format(mode, model_dir))
sys.exit(0)
model_file_path = model_dir + "/model"
params_file_path = model_dir + "/params"
model_file_path = model_dir + ".pdmodel"
params_file_path = model_dir + ".pdiparams"
if not os.path.exists(model_file_path):
logger.info("not find model file path {}".format(model_file_path))
sys.exit(0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment