Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
wangsen
paddle_dbnet
Commits
fee2c17b
Unverified
Commit
fee2c17b
authored
Aug 16, 2020
by
MissPenguin
Committed by
GitHub
Aug 16, 2020
Browse files
Merge branch 'develop' into develop
parents
da75ef8b
bad9f6cd
Changes
67
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
70 additions
and
73 deletions
+70
-73
ppocr/utils/utility.py
ppocr/utils/utility.py
+20
-5
tools/export_model.py
tools/export_model.py
+2
-20
tools/infer/predict_det.py
tools/infer/predict_det.py
+9
-3
tools/infer/predict_rec.py
tools/infer/predict_rec.py
+5
-3
tools/infer/predict_system.py
tools/infer/predict_system.py
+29
-17
tools/infer/utility.py
tools/infer/utility.py
+2
-2
tools/train.py
tools/train.py
+3
-23
No files found.
ppocr/utils/utility.py
View file @
fee2c17b
...
...
@@ -14,6 +14,9 @@
import
logging
import
os
import
imghdr
import
cv2
from
paddle
import
fluid
def
initial_logger
():
...
...
@@ -61,19 +64,31 @@ def get_image_file_list(img_file):
if
img_file
is
None
or
not
os
.
path
.
exists
(
img_file
):
raise
Exception
(
"not found any img file in {}"
.
format
(
img_file
))
img_end
=
[
'jpg'
,
'png'
,
'jpeg'
,
'
JPEG
'
,
'
JPG
'
,
'
bmp'
]
if
os
.
path
.
isfile
(
img_file
)
and
img
_file
.
split
(
'.'
)[
-
1
]
in
img_end
:
img_end
=
{
'jpg'
,
'bmp'
,
'png'
,
'jpeg'
,
'
rgb'
,
'tif'
,
'tiff
'
,
'
gif
'
,
'
GIF'
}
if
os
.
path
.
isfile
(
img_file
)
and
img
hdr
.
what
(
img_file
)
in
img_end
:
imgs_lists
.
append
(
img_file
)
elif
os
.
path
.
isdir
(
img_file
):
for
single_file
in
os
.
listdir
(
img_file
):
if
single_file
.
split
(
'.'
)[
-
1
]
in
img_end
:
imgs_lists
.
append
(
os
.
path
.
join
(
img_file
,
single_file
))
file_path
=
os
.
path
.
join
(
img_file
,
single_file
)
if
imghdr
.
what
(
file_path
)
in
img_end
:
imgs_lists
.
append
(
file_path
)
if
len
(
imgs_lists
)
==
0
:
raise
Exception
(
"not found any img file in {}"
.
format
(
img_file
))
return
imgs_lists
from
paddle
import
fluid
def
check_and_read_gif
(
img_path
):
if
os
.
path
.
basename
(
img_path
)[
-
3
:]
in
[
'gif'
,
'GIF'
]:
gif
=
cv2
.
VideoCapture
(
img_path
)
ret
,
frame
=
gif
.
read
()
if
not
ret
:
logging
.
info
(
"Cannot read {}. This gif image maybe corrupted."
)
return
None
,
False
if
len
(
frame
.
shape
)
==
2
or
frame
.
shape
[
-
1
]
==
1
:
frame
=
cv2
.
cvtColor
(
frame
,
cv2
.
COLOR_GRAY2RGB
)
imgvalue
=
frame
[:,
:,
::
-
1
]
return
imgvalue
,
True
return
None
,
False
def
create_multi_devices_program
(
program
,
loss_var_name
):
...
...
tools/export_model.py
View file @
fee2c17b
...
...
@@ -41,27 +41,11 @@ from paddle import fluid
from
ppocr.utils.utility
import
initial_logger
logger
=
initial_logger
()
from
ppocr.utils.save_load
import
init_model
from
ppocr.utils.character
import
CharacterOps
from
ppocr.utils.utility
import
create_module
def
main
():
config
=
program
.
load_config
(
FLAGS
.
config
)
program
.
merge_config
(
FLAGS
.
opt
)
logger
.
info
(
config
)
# check if set use_gpu=True in paddlepaddle cpu version
use_gpu
=
config
[
'Global'
][
'use_gpu'
]
program
.
check_gpu
(
use_gpu
)
alg
=
config
[
'Global'
][
'algorithm'
]
assert
alg
in
[
'EAST'
,
'DB'
,
'Rosetta'
,
'CRNN'
,
'STARNet'
,
'RARE'
]
if
alg
in
[
'Rosetta'
,
'CRNN'
,
'STARNet'
,
'RARE'
]:
config
[
'Global'
][
'char_ops'
]
=
CharacterOps
(
config
[
'Global'
])
place
=
fluid
.
CUDAPlace
(
0
)
if
use_gpu
else
fluid
.
CPUPlace
()
startup_prog
=
fluid
.
Program
()
eval_program
=
fluid
.
Program
()
def
main
():
startup_prog
,
eval_program
,
place
,
config
,
_
=
program
.
preprocess
()
feeded_var_names
,
target_vars
,
fetches_var_name
=
program
.
build_export
(
config
,
eval_program
,
startup_prog
)
...
...
@@ -88,6 +72,4 @@ def main():
if
__name__
==
'__main__'
:
parser
=
program
.
ArgsParser
()
FLAGS
=
parser
.
parse_args
()
main
()
tools/infer/predict_det.py
View file @
fee2c17b
...
...
@@ -20,7 +20,7 @@ sys.path.append(os.path.join(__dir__, '../..'))
import
tools.infer.utility
as
utility
from
ppocr.utils.utility
import
initial_logger
logger
=
initial_logger
()
from
ppocr.utils.utility
import
get_image_file_list
from
ppocr.utils.utility
import
get_image_file_list
,
check_and_read_gif
import
cv2
from
ppocr.data.det.east_process
import
EASTProcessTest
from
ppocr.data.det.db_process
import
DBProcessTest
...
...
@@ -135,8 +135,13 @@ if __name__ == "__main__":
text_detector
=
TextDetector
(
args
)
count
=
0
total_time
=
0
draw_img_save
=
"./inference_results"
if
not
os
.
path
.
exists
(
draw_img_save
):
os
.
makedirs
(
draw_img_save
)
for
image_file
in
image_file_list
:
img
=
cv2
.
imread
(
image_file
)
img
,
flag
=
check_and_read_gif
(
image_file
)
if
not
flag
:
img
=
cv2
.
imread
(
image_file
)
if
img
is
None
:
logger
.
info
(
"error in loading image:{}"
.
format
(
image_file
))
continue
...
...
@@ -147,6 +152,7 @@ if __name__ == "__main__":
print
(
"Predict time of %s:"
%
image_file
,
elapse
)
src_im
=
utility
.
draw_text_det_res
(
dt_boxes
,
image_file
)
img_name_pure
=
image_file
.
split
(
"/"
)[
-
1
]
cv2
.
imwrite
(
"./inference_results/det_res_%s"
%
img_name_pure
,
src_im
)
cv2
.
imwrite
(
os
.
path
.
join
(
draw_img_save
,
"det_res_%s"
%
img_name_pure
),
src_im
)
if
count
>
1
:
print
(
"Avg Time:"
,
total_time
/
(
count
-
1
))
tools/infer/predict_rec.py
View file @
fee2c17b
...
...
@@ -20,7 +20,7 @@ sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
import
tools.infer.utility
as
utility
from
ppocr.utils.utility
import
initial_logger
logger
=
initial_logger
()
from
ppocr.utils.utility
import
get_image_file_list
from
ppocr.utils.utility
import
get_image_file_list
,
check_and_read_gif
import
cv2
import
copy
import
numpy
as
np
...
...
@@ -122,9 +122,9 @@ class TextRecognizer(object):
ind
=
np
.
argmax
(
probs
,
axis
=
1
)
blank
=
probs
.
shape
[
1
]
valid_ind
=
np
.
where
(
ind
!=
(
blank
-
1
))[
0
]
score
=
np
.
mean
(
probs
[
valid_ind
,
ind
[
valid_ind
]])
if
len
(
valid_ind
)
==
0
:
continue
score
=
np
.
mean
(
probs
[
valid_ind
,
ind
[
valid_ind
]])
# rec_res.append([preds_text, score])
rec_res
[
indices
[
beg_img_no
+
rno
]]
=
[
preds_text
,
score
]
else
:
...
...
@@ -153,7 +153,9 @@ def main(args):
valid_image_file_list
=
[]
img_list
=
[]
for
image_file
in
image_file_list
:
img
=
cv2
.
imread
(
image_file
,
cv2
.
IMREAD_COLOR
)
img
,
flag
=
check_and_read_gif
(
image_file
)
if
not
flag
:
img
=
cv2
.
imread
(
image_file
)
if
img
is
None
:
logger
.
info
(
"error in loading image:{}"
.
format
(
image_file
))
continue
...
...
tools/infer/predict_system.py
View file @
fee2c17b
...
...
@@ -27,7 +27,7 @@ import copy
import
numpy
as
np
import
math
import
time
from
ppocr.utils.utility
import
get_image_file_list
from
ppocr.utils.utility
import
get_image_file_list
,
check_and_read_gif
from
PIL
import
Image
from
tools.infer.utility
import
draw_ocr
from
tools.infer.utility
import
draw_ocr_box_txt
...
...
@@ -49,18 +49,23 @@ class TextSystem(object):
points[:, 0] = points[:, 0] - left
points[:, 1] = points[:, 1] - top
'''
img_crop_width
=
int
(
max
(
np
.
linalg
.
norm
(
points
[
0
]
-
points
[
1
]),
np
.
linalg
.
norm
(
points
[
2
]
-
points
[
3
])))
img_crop_height
=
int
(
max
(
np
.
linalg
.
norm
(
points
[
0
]
-
points
[
3
]),
np
.
linalg
.
norm
(
points
[
1
]
-
points
[
2
])))
pts_std
=
np
.
float32
([[
0
,
0
],
[
img_crop_width
,
0
],
img_crop_width
=
int
(
max
(
np
.
linalg
.
norm
(
points
[
0
]
-
points
[
1
]),
np
.
linalg
.
norm
(
points
[
2
]
-
points
[
3
])))
img_crop_height
=
int
(
max
(
np
.
linalg
.
norm
(
points
[
0
]
-
points
[
3
]),
np
.
linalg
.
norm
(
points
[
1
]
-
points
[
2
])))
pts_std
=
np
.
float32
([[
0
,
0
],
[
img_crop_width
,
0
],
[
img_crop_width
,
img_crop_height
],
[
0
,
img_crop_height
]])
M
=
cv2
.
getPerspectiveTransform
(
points
,
pts_std
)
dst_img
=
cv2
.
warpPerspective
(
img
,
M
,
(
img_crop_width
,
img_crop_height
),
borderMode
=
cv2
.
BORDER_REPLICATE
,
flags
=
cv2
.
INTER_CUBIC
)
dst_img
=
cv2
.
warpPerspective
(
img
,
M
,
(
img_crop_width
,
img_crop_height
),
borderMode
=
cv2
.
BORDER_REPLICATE
,
flags
=
cv2
.
INTER_CUBIC
)
dst_img_height
,
dst_img_width
=
dst_img
.
shape
[
0
:
2
]
if
dst_img_height
*
1.0
/
dst_img_width
>=
1.5
:
dst_img
=
np
.
rot90
(
dst_img
)
...
...
@@ -119,25 +124,27 @@ def main(args):
is_visualize
=
True
tackle_img_num
=
0
for
image_file
in
image_file_list
:
img
=
cv2
.
imread
(
image_file
)
img
,
flag
=
check_and_read_gif
(
image_file
)
if
not
flag
:
img
=
cv2
.
imread
(
image_file
)
if
img
is
None
:
logger
.
info
(
"error in loading image:{}"
.
format
(
image_file
))
continue
starttime
=
time
.
time
()
tackle_img_num
+=
1
if
not
args
.
use_gpu
and
args
.
enable_mkldnn
and
tackle_img_num
%
30
==
0
:
tackle_img_num
+=
1
if
not
args
.
use_gpu
and
args
.
enable_mkldnn
and
tackle_img_num
%
30
==
0
:
text_sys
=
TextSystem
(
args
)
dt_boxes
,
rec_res
=
text_sys
(
img
)
elapse
=
time
.
time
()
-
starttime
print
(
"Predict time of %s: %.3fs"
%
(
image_file
,
elapse
))
drop_score
=
0.5
dt_num
=
len
(
dt_boxes
)
dt_boxes_final
=
[]
for
dno
in
range
(
dt_num
):
text
,
score
=
rec_res
[
dno
]
if
score
>=
0.5
:
if
score
>=
drop_score
:
text_str
=
"%s, %.3f"
%
(
text
,
score
)
print
(
text_str
)
dt_boxes_final
.
append
(
dt_boxes
[
dno
])
if
is_visualize
:
image
=
Image
.
fromarray
(
cv2
.
cvtColor
(
img
,
cv2
.
COLOR_BGR2RGB
))
...
...
@@ -146,7 +153,12 @@ def main(args):
scores
=
[
rec_res
[
i
][
1
]
for
i
in
range
(
len
(
rec_res
))]
draw_img
=
draw_ocr
(
image
,
boxes
,
txts
,
scores
,
draw_txt
=
True
,
drop_score
=
0.5
)
image
,
boxes
,
txts
,
scores
,
draw_txt
=
True
,
drop_score
=
drop_score
)
draw_img_save
=
"./inference_results/"
if
not
os
.
path
.
exists
(
draw_img_save
):
os
.
makedirs
(
draw_img_save
)
...
...
tools/infer/utility.py
View file @
fee2c17b
...
...
@@ -95,7 +95,7 @@ def create_predictor(args, mode):
config
.
set_cpu_math_library_num_threads
(
6
)
if
args
.
enable_mkldnn
:
config
.
enable_mkldnn
()
#config.enable_memory_optim()
config
.
disable_glog_info
()
...
...
@@ -169,7 +169,7 @@ def draw_ocr_box_txt(image, boxes, txts):
img_right
=
Image
.
new
(
'RGB'
,
(
w
,
h
),
(
255
,
255
,
255
))
import
random
# 每次使用相同的随机种子 ,可以保证两次颜色一致
random
.
seed
(
0
)
draw_left
=
ImageDraw
.
Draw
(
img_left
)
draw_right
=
ImageDraw
.
Draw
(
img_right
)
...
...
tools/train.py
View file @
fee2c17b
...
...
@@ -42,27 +42,10 @@ from ppocr.utils.utility import initial_logger
logger
=
initial_logger
()
from
ppocr.data.reader_main
import
reader_main
from
ppocr.utils.save_load
import
init_model
from
ppocr.utils.character
import
CharacterOps
from
paddle.fluid.contrib.model_stat
import
summary
def
main
():
config
=
program
.
load_config
(
FLAGS
.
config
)
program
.
merge_config
(
FLAGS
.
opt
)
logger
.
info
(
config
)
# check if set use_gpu=True in paddlepaddle cpu version
use_gpu
=
config
[
'Global'
][
'use_gpu'
]
program
.
check_gpu
(
use_gpu
)
alg
=
config
[
'Global'
][
'algorithm'
]
assert
alg
in
[
'EAST'
,
'DB'
,
'Rosetta'
,
'CRNN'
,
'STARNet'
,
'RARE'
]
if
alg
in
[
'Rosetta'
,
'CRNN'
,
'STARNet'
,
'RARE'
]:
config
[
'Global'
][
'char_ops'
]
=
CharacterOps
(
config
[
'Global'
])
place
=
fluid
.
CUDAPlace
(
0
)
if
use_gpu
else
fluid
.
CPUPlace
()
startup_program
=
fluid
.
Program
()
train_program
=
fluid
.
Program
()
train_build_outputs
=
program
.
build
(
config
,
train_program
,
startup_program
,
mode
=
'train'
)
train_loader
=
train_build_outputs
[
0
]
...
...
@@ -91,7 +74,7 @@ def main():
# dump mode structure
if
config
[
'Global'
][
'debug'
]:
if
'attention'
in
config
[
'Global'
][
'loss_type'
]:
if
train_alg_type
==
'rec'
and
'attention'
in
config
[
'Global'
][
'loss_type'
]:
logger
.
warning
(
'Does not suport dump attention...'
)
else
:
summary
(
train_program
)
...
...
@@ -109,15 +92,13 @@ def main():
'fetch_name_list'
:
eval_fetch_name_list
,
\
'fetch_varname_list'
:
eval_fetch_varname_list
}
if
alg
in
[
'EAST'
,
'DB'
]
:
if
train_alg_type
==
'det'
:
program
.
train_eval_det_run
(
config
,
exe
,
train_info_dict
,
eval_info_dict
)
else
:
program
.
train_eval_rec_run
(
config
,
exe
,
train_info_dict
,
eval_info_dict
)
def
test_reader
():
config
=
program
.
load_config
(
FLAGS
.
config
)
program
.
merge_config
(
FLAGS
.
opt
)
logger
.
info
(
config
)
train_reader
=
reader_main
(
config
=
config
,
mode
=
"train"
)
import
time
...
...
@@ -136,7 +117,6 @@ def test_reader():
if
__name__
==
'__main__'
:
parser
=
program
.
ArgsParser
()
FLAGS
=
parser
.
parse_args
()
startup_program
,
train_program
,
place
,
config
,
train_alg_type
=
program
.
preprocess
()
main
()
# test_reader()
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment