Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
wangsen
paddle_dbnet
Commits
ac8c2a89
"include/ck/utility/amd_inline_asm.hpp" did not exist on "f2ac7832c65969f5b3ecf7972518d55ee099c03b"
Commit
ac8c2a89
authored
Jun 30, 2021
by
WenmuZhou
Browse files
merge dygraph
parents
88a8be12
e174e9ed
Changes
88
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
202 additions
and
114 deletions
+202
-114
tools/infer/predict_det.py
tools/infer/predict_det.py
+13
-60
tools/infer/predict_rec.py
tools/infer/predict_rec.py
+9
-6
tools/infer/predict_system.py
tools/infer/predict_system.py
+32
-4
tools/infer/utility.py
tools/infer/utility.py
+15
-27
tools/infer_det.py
tools/infer_det.py
+1
-1
tools/infer_table.py
tools/infer_table.py
+107
-0
tools/program.py
tools/program.py
+23
-14
tools/train.py
tools/train.py
+2
-2
No files found.
tools/infer/predict_det.py
View file @
ac8c2a89
...
@@ -31,7 +31,7 @@ from ppocr.utils.utility import get_image_file_list, check_and_read_gif
...
@@ -31,7 +31,7 @@ from ppocr.utils.utility import get_image_file_list, check_and_read_gif
from
ppocr.data
import
create_operators
,
transform
from
ppocr.data
import
create_operators
,
transform
from
ppocr.postprocess
import
build_post_process
from
ppocr.postprocess
import
build_post_process
import
tools.infer.benchmark_utils
as
benchmark_utils
#
import tools.infer.benchmark_utils as benchmark_utils
logger
=
get_logger
()
logger
=
get_logger
()
...
@@ -43,7 +43,7 @@ class TextDetector(object):
...
@@ -43,7 +43,7 @@ class TextDetector(object):
pre_process_list
=
[{
pre_process_list
=
[{
'DetResizeForTest'
:
{
'DetResizeForTest'
:
{
'limit_side_len'
:
args
.
det_limit_side_len
,
'limit_side_len'
:
args
.
det_limit_side_len
,
'limit_type'
:
args
.
det_limit_type
'limit_type'
:
args
.
det_limit_type
,
}
}
},
{
},
{
'NormalizeImage'
:
{
'NormalizeImage'
:
{
...
@@ -100,8 +100,6 @@ class TextDetector(object):
...
@@ -100,8 +100,6 @@ class TextDetector(object):
self
.
predictor
,
self
.
input_tensor
,
self
.
output_tensors
,
self
.
config
=
utility
.
create_predictor
(
self
.
predictor
,
self
.
input_tensor
,
self
.
output_tensors
,
self
.
config
=
utility
.
create_predictor
(
args
,
'det'
,
logger
)
args
,
'det'
,
logger
)
self
.
det_times
=
utility
.
Timer
()
def
order_points_clockwise
(
self
,
pts
):
def
order_points_clockwise
(
self
,
pts
):
"""
"""
reference from: https://github.com/jrosebr1/imutils/blob/master/imutils/perspective.py
reference from: https://github.com/jrosebr1/imutils/blob/master/imutils/perspective.py
...
@@ -158,8 +156,8 @@ class TextDetector(object):
...
@@ -158,8 +156,8 @@ class TextDetector(object):
def
__call__
(
self
,
img
):
def
__call__
(
self
,
img
):
ori_im
=
img
.
copy
()
ori_im
=
img
.
copy
()
data
=
{
'image'
:
img
}
data
=
{
'image'
:
img
}
self
.
det_times
.
total_time
.
start
()
s
elf
.
det_times
.
preprocess_time
.
start
()
s
t
=
time
.
time
()
data
=
transform
(
data
,
self
.
preprocess_op
)
data
=
transform
(
data
,
self
.
preprocess_op
)
img
,
shape_list
=
data
img
,
shape_list
=
data
if
img
is
None
:
if
img
is
None
:
...
@@ -168,16 +166,12 @@ class TextDetector(object):
...
@@ -168,16 +166,12 @@ class TextDetector(object):
shape_list
=
np
.
expand_dims
(
shape_list
,
axis
=
0
)
shape_list
=
np
.
expand_dims
(
shape_list
,
axis
=
0
)
img
=
img
.
copy
()
img
=
img
.
copy
()
self
.
det_times
.
preprocess_time
.
end
()
self
.
det_times
.
inference_time
.
start
()
self
.
input_tensor
.
copy_from_cpu
(
img
)
self
.
input_tensor
.
copy_from_cpu
(
img
)
self
.
predictor
.
run
()
self
.
predictor
.
run
()
outputs
=
[]
outputs
=
[]
for
output_tensor
in
self
.
output_tensors
:
for
output_tensor
in
self
.
output_tensors
:
output
=
output_tensor
.
copy_to_cpu
()
output
=
output_tensor
.
copy_to_cpu
()
outputs
.
append
(
output
)
outputs
.
append
(
output
)
self
.
det_times
.
inference_time
.
end
()
preds
=
{}
preds
=
{}
if
self
.
det_algorithm
==
"EAST"
:
if
self
.
det_algorithm
==
"EAST"
:
...
@@ -193,8 +187,6 @@ class TextDetector(object):
...
@@ -193,8 +187,6 @@ class TextDetector(object):
else
:
else
:
raise
NotImplementedError
raise
NotImplementedError
self
.
det_times
.
postprocess_time
.
start
()
self
.
predictor
.
try_shrink_memory
()
self
.
predictor
.
try_shrink_memory
()
post_result
=
self
.
postprocess_op
(
preds
,
shape_list
)
post_result
=
self
.
postprocess_op
(
preds
,
shape_list
)
dt_boxes
=
post_result
[
0
][
'points'
]
dt_boxes
=
post_result
[
0
][
'points'
]
...
@@ -203,10 +195,8 @@ class TextDetector(object):
...
@@ -203,10 +195,8 @@ class TextDetector(object):
else
:
else
:
dt_boxes
=
self
.
filter_tag_det_res
(
dt_boxes
,
ori_im
.
shape
)
dt_boxes
=
self
.
filter_tag_det_res
(
dt_boxes
,
ori_im
.
shape
)
self
.
det_times
.
postprocess_time
.
end
()
et
=
time
.
time
()
self
.
det_times
.
total_time
.
end
()
return
dt_boxes
,
et
-
st
self
.
det_times
.
img_num
+=
1
return
dt_boxes
,
self
.
det_times
.
total_time
.
value
()
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
...
@@ -216,12 +206,13 @@ if __name__ == "__main__":
...
@@ -216,12 +206,13 @@ if __name__ == "__main__":
count
=
0
count
=
0
total_time
=
0
total_time
=
0
draw_img_save
=
"./inference_results"
draw_img_save
=
"./inference_results"
cpu_mem
,
gpu_mem
,
gpu_util
=
0
,
0
,
0
# warmup 10 times
if
args
.
warmup
:
fake_img
=
np
.
random
.
uniform
(
-
1
,
1
,
[
640
,
640
,
3
]).
astype
(
np
.
float32
)
img
=
np
.
random
.
uniform
(
0
,
255
,
[
640
,
640
,
3
]).
astype
(
np
.
uint8
)
for
i
in
range
(
10
):
for
i
in
range
(
10
):
dt_boxes
,
_
=
text_detector
(
fake_img
)
res
=
text_detector
(
img
)
cpu_mem
,
gpu_mem
,
gpu_util
=
0
,
0
,
0
if
not
os
.
path
.
exists
(
draw_img_save
):
if
not
os
.
path
.
exists
(
draw_img_save
):
os
.
makedirs
(
draw_img_save
)
os
.
makedirs
(
draw_img_save
)
...
@@ -239,49 +230,11 @@ if __name__ == "__main__":
...
@@ -239,49 +230,11 @@ if __name__ == "__main__":
total_time
+=
elapse
total_time
+=
elapse
count
+=
1
count
+=
1
if
args
.
benchmark
:
cm
,
gm
,
gu
=
utility
.
get_current_memory_mb
(
0
)
cpu_mem
+=
cm
gpu_mem
+=
gm
gpu_util
+=
gu
logger
.
info
(
"Predict time of {}: {}"
.
format
(
image_file
,
elapse
))
logger
.
info
(
"Predict time of {}: {}"
.
format
(
image_file
,
elapse
))
src_im
=
utility
.
draw_text_det_res
(
dt_boxes
,
image_file
)
src_im
=
utility
.
draw_text_det_res
(
dt_boxes
,
image_file
)
img_name_pure
=
os
.
path
.
split
(
image_file
)[
-
1
]
img_name_pure
=
os
.
path
.
split
(
image_file
)[
-
1
]
img_path
=
os
.
path
.
join
(
draw_img_save
,
img_path
=
os
.
path
.
join
(
draw_img_save
,
"det_res_{}"
.
format
(
img_name_pure
))
"det_res_{}"
.
format
(
img_name_pure
))
cv2
.
imwrite
(
img_path
,
src_im
)
logger
.
info
(
"The visualized image saved in {}"
.
format
(
img_path
))
logger
.
info
(
"The visualized image saved in {}"
.
format
(
img_path
))
# print the information about memory and time-spent
if
args
.
benchmark
:
mems
=
{
'cpu_rss_mb'
:
cpu_mem
/
count
,
'gpu_rss_mb'
:
gpu_mem
/
count
,
'gpu_util'
:
gpu_util
*
100
/
count
}
else
:
mems
=
None
logger
.
info
(
"The predict time about detection module is as follows: "
)
det_time_dict
=
text_detector
.
det_times
.
report
(
average
=
True
)
det_model_name
=
args
.
det_model_dir
if
args
.
benchmark
:
# construct log information
model_info
=
{
'model_name'
:
args
.
det_model_dir
.
split
(
'/'
)[
-
1
],
'precision'
:
args
.
precision
}
data_info
=
{
'batch_size'
:
1
,
'shape'
:
'dynamic_shape'
,
'data_num'
:
det_time_dict
[
'img_num'
]
}
perf_info
=
{
'preprocess_time_s'
:
det_time_dict
[
'preprocess_time'
],
'inference_time_s'
:
det_time_dict
[
'inference_time'
],
'postprocess_time_s'
:
det_time_dict
[
'postprocess_time'
],
'total_time_s'
:
det_time_dict
[
'total_time'
]
}
benchmark_log
=
benchmark_utils
.
PaddleInferBenchmark
(
text_detector
.
config
,
model_info
,
data_info
,
perf_info
,
mems
)
benchmark_log
(
"Det"
)
tools/infer/predict_rec.py
View file @
ac8c2a89
...
@@ -257,13 +257,15 @@ def main(args):
...
@@ -257,13 +257,15 @@ def main(args):
text_recognizer
=
TextRecognizer
(
args
)
text_recognizer
=
TextRecognizer
(
args
)
valid_image_file_list
=
[]
valid_image_file_list
=
[]
img_list
=
[]
img_list
=
[]
cpu_mem
,
gpu_mem
,
gpu_util
=
0
,
0
,
0
count
=
0
# warmup 10 times
# warmup 10 times
fake_img
=
np
.
random
.
uniform
(
-
1
,
1
,
[
1
,
32
,
320
,
3
]).
astype
(
np
.
float32
)
if
args
.
warmup
:
for
i
in
range
(
10
):
img
=
np
.
random
.
uniform
(
0
,
255
,
[
32
,
320
,
3
]).
astype
(
np
.
uint8
)
dt_boxes
,
_
=
text_recognizer
(
fake_img
)
for
i
in
range
(
10
):
res
=
text_recognizer
([
img
])
cpu_mem
,
gpu_mem
,
gpu_util
=
0
,
0
,
0
count
=
0
for
image_file
in
image_file_list
:
for
image_file
in
image_file_list
:
img
,
flag
=
check_and_read_gif
(
image_file
)
img
,
flag
=
check_and_read_gif
(
image_file
)
...
@@ -320,7 +322,8 @@ def main(args):
...
@@ -320,7 +322,8 @@ def main(args):
'total_time_s'
:
rec_time_dict
[
'total_time'
]
'total_time_s'
:
rec_time_dict
[
'total_time'
]
}
}
benchmark_log
=
benchmark_utils
.
PaddleInferBenchmark
(
benchmark_log
=
benchmark_utils
.
PaddleInferBenchmark
(
text_recognizer
.
config
,
model_info
,
data_info
,
perf_info
,
mems
)
text_recognizer
.
config
,
model_info
,
data_info
,
perf_info
,
mems
,
args
.
save_log_path
)
benchmark_log
(
"Rec"
)
benchmark_log
(
"Rec"
)
...
...
tools/infer/predict_system.py
View file @
ac8c2a89
...
@@ -13,6 +13,7 @@
...
@@ -13,6 +13,7 @@
# limitations under the License.
# limitations under the License.
import
os
import
os
import
sys
import
sys
import
subprocess
__dir__
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
__dir__
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
sys
.
path
.
append
(
__dir__
)
sys
.
path
.
append
(
__dir__
)
...
@@ -24,6 +25,7 @@ import cv2
...
@@ -24,6 +25,7 @@ import cv2
import
copy
import
copy
import
numpy
as
np
import
numpy
as
np
import
time
import
time
import
logging
from
PIL
import
Image
from
PIL
import
Image
import
tools.infer.utility
as
utility
import
tools.infer.utility
as
utility
import
tools.infer.predict_rec
as
predict_rec
import
tools.infer.predict_rec
as
predict_rec
...
@@ -38,6 +40,9 @@ logger = get_logger()
...
@@ -38,6 +40,9 @@ logger = get_logger()
class
TextSystem
(
object
):
class
TextSystem
(
object
):
def
__init__
(
self
,
args
):
def
__init__
(
self
,
args
):
if
not
args
.
show_log
:
logger
.
setLevel
(
logging
.
INFO
)
self
.
text_detector
=
predict_det
.
TextDetector
(
args
)
self
.
text_detector
=
predict_det
.
TextDetector
(
args
)
self
.
text_recognizer
=
predict_rec
.
TextRecognizer
(
args
)
self
.
text_recognizer
=
predict_rec
.
TextRecognizer
(
args
)
self
.
use_angle_cls
=
args
.
use_angle_cls
self
.
use_angle_cls
=
args
.
use_angle_cls
...
@@ -55,7 +60,7 @@ class TextSystem(object):
...
@@ -55,7 +60,7 @@ class TextSystem(object):
ori_im
=
img
.
copy
()
ori_im
=
img
.
copy
()
dt_boxes
,
elapse
=
self
.
text_detector
(
img
)
dt_boxes
,
elapse
=
self
.
text_detector
(
img
)
logger
.
info
(
"dt_boxes num : {}, elapse : {}"
.
format
(
logger
.
debug
(
"dt_boxes num : {}, elapse : {}"
.
format
(
len
(
dt_boxes
),
elapse
))
len
(
dt_boxes
),
elapse
))
if
dt_boxes
is
None
:
if
dt_boxes
is
None
:
return
None
,
None
return
None
,
None
...
@@ -70,11 +75,11 @@ class TextSystem(object):
...
@@ -70,11 +75,11 @@ class TextSystem(object):
if
self
.
use_angle_cls
and
cls
:
if
self
.
use_angle_cls
and
cls
:
img_crop_list
,
angle_list
,
elapse
=
self
.
text_classifier
(
img_crop_list
,
angle_list
,
elapse
=
self
.
text_classifier
(
img_crop_list
)
img_crop_list
)
logger
.
info
(
"cls num : {}, elapse : {}"
.
format
(
logger
.
debug
(
"cls num : {}, elapse : {}"
.
format
(
len
(
img_crop_list
),
elapse
))
len
(
img_crop_list
),
elapse
))
rec_res
,
elapse
=
self
.
text_recognizer
(
img_crop_list
)
rec_res
,
elapse
=
self
.
text_recognizer
(
img_crop_list
)
logger
.
info
(
"rec_res num : {}, elapse : {}"
.
format
(
logger
.
debug
(
"rec_res num : {}, elapse : {}"
.
format
(
len
(
rec_res
),
elapse
))
len
(
rec_res
),
elapse
))
# self.print_draw_crop_rec_res(img_crop_list, rec_res)
# self.print_draw_crop_rec_res(img_crop_list, rec_res)
filter_boxes
,
filter_rec_res
=
[],
[]
filter_boxes
,
filter_rec_res
=
[],
[]
...
@@ -109,15 +114,24 @@ def sorted_boxes(dt_boxes):
...
@@ -109,15 +114,24 @@ def sorted_boxes(dt_boxes):
def
main
(
args
):
def
main
(
args
):
image_file_list
=
get_image_file_list
(
args
.
image_dir
)
image_file_list
=
get_image_file_list
(
args
.
image_dir
)
image_file_list
=
image_file_list
[
args
.
process_id
::
args
.
total_process_num
]
text_sys
=
TextSystem
(
args
)
text_sys
=
TextSystem
(
args
)
is_visualize
=
True
is_visualize
=
True
font_path
=
args
.
vis_font_path
font_path
=
args
.
vis_font_path
drop_score
=
args
.
drop_score
drop_score
=
args
.
drop_score
# warm up 10 times
if
args
.
warmup
:
img
=
np
.
random
.
uniform
(
0
,
255
,
[
640
,
640
,
3
]).
astype
(
np
.
uint8
)
for
i
in
range
(
10
):
res
=
text_sys
(
img
)
total_time
=
0
total_time
=
0
cpu_mem
,
gpu_mem
,
gpu_util
=
0
,
0
,
0
cpu_mem
,
gpu_mem
,
gpu_util
=
0
,
0
,
0
_st
=
time
.
time
()
_st
=
time
.
time
()
count
=
0
count
=
0
for
idx
,
image_file
in
enumerate
(
image_file_list
):
for
idx
,
image_file
in
enumerate
(
image_file_list
):
img
,
flag
=
check_and_read_gif
(
image_file
)
img
,
flag
=
check_and_read_gif
(
image_file
)
if
not
flag
:
if
not
flag
:
img
=
cv2
.
imread
(
image_file
)
img
=
cv2
.
imread
(
image_file
)
...
@@ -226,4 +240,18 @@ def main(args):
...
@@ -226,4 +240,18 @@ def main(args):
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
main
(
utility
.
parse_args
())
args
=
utility
.
parse_args
()
if
args
.
use_mp
:
p_list
=
[]
total_process_num
=
args
.
total_process_num
for
process_id
in
range
(
total_process_num
):
cmd
=
[
sys
.
executable
,
"-u"
]
+
sys
.
argv
+
[
"--process_id={}"
.
format
(
process_id
),
"--use_mp={}"
.
format
(
False
)
]
p
=
subprocess
.
Popen
(
cmd
,
stdout
=
sys
.
stdout
,
stderr
=
sys
.
stdout
)
p_list
.
append
(
p
)
for
p
in
p_list
:
p
.
wait
()
else
:
main
(
args
)
tools/infer/utility.py
View file @
ac8c2a89
...
@@ -37,6 +37,7 @@ def init_args():
...
@@ -37,6 +37,7 @@ def init_args():
parser
.
add_argument
(
"--use_gpu"
,
type
=
str2bool
,
default
=
True
)
parser
.
add_argument
(
"--use_gpu"
,
type
=
str2bool
,
default
=
True
)
parser
.
add_argument
(
"--ir_optim"
,
type
=
str2bool
,
default
=
True
)
parser
.
add_argument
(
"--ir_optim"
,
type
=
str2bool
,
default
=
True
)
parser
.
add_argument
(
"--use_tensorrt"
,
type
=
str2bool
,
default
=
False
)
parser
.
add_argument
(
"--use_tensorrt"
,
type
=
str2bool
,
default
=
False
)
parser
.
add_argument
(
"--min_subgraph_size"
,
type
=
int
,
default
=
3
)
parser
.
add_argument
(
"--precision"
,
type
=
str
,
default
=
"fp32"
)
parser
.
add_argument
(
"--precision"
,
type
=
str
,
default
=
"fp32"
)
parser
.
add_argument
(
"--gpu_mem"
,
type
=
int
,
default
=
500
)
parser
.
add_argument
(
"--gpu_mem"
,
type
=
int
,
default
=
500
)
...
@@ -105,7 +106,9 @@ def init_args():
...
@@ -105,7 +106,9 @@ def init_args():
parser
.
add_argument
(
"--enable_mkldnn"
,
type
=
str2bool
,
default
=
False
)
parser
.
add_argument
(
"--enable_mkldnn"
,
type
=
str2bool
,
default
=
False
)
parser
.
add_argument
(
"--cpu_threads"
,
type
=
int
,
default
=
10
)
parser
.
add_argument
(
"--cpu_threads"
,
type
=
int
,
default
=
10
)
parser
.
add_argument
(
"--use_pdserving"
,
type
=
str2bool
,
default
=
False
)
parser
.
add_argument
(
"--use_pdserving"
,
type
=
str2bool
,
default
=
False
)
parser
.
add_argument
(
"--warmup"
,
type
=
str2bool
,
default
=
True
)
# multi-process
parser
.
add_argument
(
"--use_mp"
,
type
=
str2bool
,
default
=
False
)
parser
.
add_argument
(
"--use_mp"
,
type
=
str2bool
,
default
=
False
)
parser
.
add_argument
(
"--total_process_num"
,
type
=
int
,
default
=
1
)
parser
.
add_argument
(
"--total_process_num"
,
type
=
int
,
default
=
1
)
parser
.
add_argument
(
"--process_id"
,
type
=
int
,
default
=
0
)
parser
.
add_argument
(
"--process_id"
,
type
=
int
,
default
=
0
)
...
@@ -113,6 +116,7 @@ def init_args():
...
@@ -113,6 +116,7 @@ def init_args():
parser
.
add_argument
(
"--benchmark"
,
type
=
bool
,
default
=
False
)
parser
.
add_argument
(
"--benchmark"
,
type
=
bool
,
default
=
False
)
parser
.
add_argument
(
"--save_log_path"
,
type
=
str
,
default
=
"./log_output/"
)
parser
.
add_argument
(
"--save_log_path"
,
type
=
str
,
default
=
"./log_output/"
)
parser
.
add_argument
(
"--show_log"
,
type
=
str2bool
,
default
=
True
)
return
parser
return
parser
...
@@ -198,6 +202,8 @@ def create_predictor(args, mode, logger):
...
@@ -198,6 +202,8 @@ def create_predictor(args, mode, logger):
model_dir
=
args
.
cls_model_dir
model_dir
=
args
.
cls_model_dir
elif
mode
==
'rec'
:
elif
mode
==
'rec'
:
model_dir
=
args
.
rec_model_dir
model_dir
=
args
.
rec_model_dir
elif
mode
==
'table'
:
model_dir
=
args
.
table_model_dir
else
:
else
:
model_dir
=
args
.
e2e_model_dir
model_dir
=
args
.
e2e_model_dir
...
@@ -231,12 +237,14 @@ def create_predictor(args, mode, logger):
...
@@ -231,12 +237,14 @@ def create_predictor(args, mode, logger):
config
.
enable_tensorrt_engine
(
config
.
enable_tensorrt_engine
(
precision_mode
=
inference
.
PrecisionType
.
Float32
,
precision_mode
=
inference
.
PrecisionType
.
Float32
,
max_batch_size
=
args
.
max_batch_size
,
max_batch_size
=
args
.
max_batch_size
,
min_subgraph_size
=
3
)
# skip the minmum trt subgraph
min_subgraph_size
=
args
.
min_subgraph_size
)
if
mode
==
"det"
and
"mobile"
in
model_file_path
:
# skip the minmum trt subgraph
if
mode
==
"det"
:
min_input_shape
=
{
min_input_shape
=
{
"x"
:
[
1
,
3
,
50
,
50
],
"x"
:
[
1
,
3
,
50
,
50
],
"conv2d_92.tmp_0"
:
[
1
,
96
,
20
,
20
],
"conv2d_92.tmp_0"
:
[
1
,
96
,
20
,
20
],
"conv2d_91.tmp_0"
:
[
1
,
96
,
10
,
10
],
"conv2d_91.tmp_0"
:
[
1
,
96
,
10
,
10
],
"conv2d_59.tmp_0"
:
[
1
,
96
,
20
,
20
],
"nearest_interp_v2_1.tmp_0"
:
[
1
,
96
,
10
,
10
],
"nearest_interp_v2_1.tmp_0"
:
[
1
,
96
,
10
,
10
],
"nearest_interp_v2_2.tmp_0"
:
[
1
,
96
,
20
,
20
],
"nearest_interp_v2_2.tmp_0"
:
[
1
,
96
,
20
,
20
],
"nearest_interp_v2_3.tmp_0"
:
[
1
,
24
,
20
,
20
],
"nearest_interp_v2_3.tmp_0"
:
[
1
,
24
,
20
,
20
],
...
@@ -249,6 +257,7 @@ def create_predictor(args, mode, logger):
...
@@ -249,6 +257,7 @@ def create_predictor(args, mode, logger):
"x"
:
[
1
,
3
,
2000
,
2000
],
"x"
:
[
1
,
3
,
2000
,
2000
],
"conv2d_92.tmp_0"
:
[
1
,
96
,
400
,
400
],
"conv2d_92.tmp_0"
:
[
1
,
96
,
400
,
400
],
"conv2d_91.tmp_0"
:
[
1
,
96
,
200
,
200
],
"conv2d_91.tmp_0"
:
[
1
,
96
,
200
,
200
],
"conv2d_59.tmp_0"
:
[
1
,
96
,
400
,
400
],
"nearest_interp_v2_1.tmp_0"
:
[
1
,
96
,
200
,
200
],
"nearest_interp_v2_1.tmp_0"
:
[
1
,
96
,
200
,
200
],
"nearest_interp_v2_2.tmp_0"
:
[
1
,
96
,
400
,
400
],
"nearest_interp_v2_2.tmp_0"
:
[
1
,
96
,
400
,
400
],
"nearest_interp_v2_3.tmp_0"
:
[
1
,
24
,
400
,
400
],
"nearest_interp_v2_3.tmp_0"
:
[
1
,
24
,
400
,
400
],
...
@@ -261,6 +270,7 @@ def create_predictor(args, mode, logger):
...
@@ -261,6 +270,7 @@ def create_predictor(args, mode, logger):
"x"
:
[
1
,
3
,
640
,
640
],
"x"
:
[
1
,
3
,
640
,
640
],
"conv2d_92.tmp_0"
:
[
1
,
96
,
160
,
160
],
"conv2d_92.tmp_0"
:
[
1
,
96
,
160
,
160
],
"conv2d_91.tmp_0"
:
[
1
,
96
,
80
,
80
],
"conv2d_91.tmp_0"
:
[
1
,
96
,
80
,
80
],
"conv2d_59.tmp_0"
:
[
1
,
96
,
160
,
160
],
"nearest_interp_v2_1.tmp_0"
:
[
1
,
96
,
80
,
80
],
"nearest_interp_v2_1.tmp_0"
:
[
1
,
96
,
80
,
80
],
"nearest_interp_v2_2.tmp_0"
:
[
1
,
96
,
160
,
160
],
"nearest_interp_v2_2.tmp_0"
:
[
1
,
96
,
160
,
160
],
"nearest_interp_v2_3.tmp_0"
:
[
1
,
24
,
160
,
160
],
"nearest_interp_v2_3.tmp_0"
:
[
1
,
24
,
160
,
160
],
...
@@ -269,31 +279,6 @@ def create_predictor(args, mode, logger):
...
@@ -269,31 +279,6 @@ def create_predictor(args, mode, logger):
"elementwise_add_7"
:
[
1
,
56
,
40
,
40
],
"elementwise_add_7"
:
[
1
,
56
,
40
,
40
],
"nearest_interp_v2_0.tmp_0"
:
[
1
,
96
,
40
,
40
]
"nearest_interp_v2_0.tmp_0"
:
[
1
,
96
,
40
,
40
]
}
}
if
mode
==
"det"
and
"server"
in
model_file_path
:
min_input_shape
=
{
"x"
:
[
1
,
3
,
50
,
50
],
"conv2d_59.tmp_0"
:
[
1
,
96
,
20
,
20
],
"nearest_interp_v2_2.tmp_0"
:
[
1
,
96
,
20
,
20
],
"nearest_interp_v2_3.tmp_0"
:
[
1
,
24
,
20
,
20
],
"nearest_interp_v2_4.tmp_0"
:
[
1
,
24
,
20
,
20
],
"nearest_interp_v2_5.tmp_0"
:
[
1
,
24
,
20
,
20
]
}
max_input_shape
=
{
"x"
:
[
1
,
3
,
2000
,
2000
],
"conv2d_59.tmp_0"
:
[
1
,
96
,
400
,
400
],
"nearest_interp_v2_2.tmp_0"
:
[
1
,
96
,
400
,
400
],
"nearest_interp_v2_3.tmp_0"
:
[
1
,
24
,
400
,
400
],
"nearest_interp_v2_4.tmp_0"
:
[
1
,
24
,
400
,
400
],
"nearest_interp_v2_5.tmp_0"
:
[
1
,
24
,
400
,
400
]
}
opt_input_shape
=
{
"x"
:
[
1
,
3
,
640
,
640
],
"conv2d_59.tmp_0"
:
[
1
,
96
,
160
,
160
],
"nearest_interp_v2_2.tmp_0"
:
[
1
,
96
,
160
,
160
],
"nearest_interp_v2_3.tmp_0"
:
[
1
,
24
,
160
,
160
],
"nearest_interp_v2_4.tmp_0"
:
[
1
,
24
,
160
,
160
],
"nearest_interp_v2_5.tmp_0"
:
[
1
,
24
,
160
,
160
]
}
elif
mode
==
"rec"
:
elif
mode
==
"rec"
:
min_input_shape
=
{
"x"
:
[
args
.
rec_batch_num
,
3
,
32
,
10
]}
min_input_shape
=
{
"x"
:
[
args
.
rec_batch_num
,
3
,
32
,
10
]}
max_input_shape
=
{
"x"
:
[
args
.
rec_batch_num
,
3
,
32
,
2000
]}
max_input_shape
=
{
"x"
:
[
args
.
rec_batch_num
,
3
,
32
,
2000
]}
...
@@ -326,7 +311,10 @@ def create_predictor(args, mode, logger):
...
@@ -326,7 +311,10 @@ def create_predictor(args, mode, logger):
config
.
disable_glog_info
()
config
.
disable_glog_info
()
config
.
delete_pass
(
"conv_transpose_eltwiseadd_bn_fuse_pass"
)
config
.
delete_pass
(
"conv_transpose_eltwiseadd_bn_fuse_pass"
)
if
mode
==
'table'
:
config
.
delete_pass
(
"fc_fuse_pass"
)
# not supported for table
config
.
switch_use_feed_fetch_ops
(
False
)
config
.
switch_use_feed_fetch_ops
(
False
)
config
.
switch_ir_optim
(
True
)
# create predictor
# create predictor
predictor
=
inference
.
create_predictor
(
config
)
predictor
=
inference
.
create_predictor
(
config
)
...
...
tools/infer_det.py
View file @
ac8c2a89
...
@@ -112,4 +112,4 @@ def main():
...
@@ -112,4 +112,4 @@ def main():
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
config
,
device
,
logger
,
vdl_writer
=
program
.
preprocess
()
config
,
device
,
logger
,
vdl_writer
=
program
.
preprocess
()
main
()
main
()
\ No newline at end of file
tools/infer_table.py
0 → 100644
View file @
ac8c2a89
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
numpy
as
np
import
os
import
sys
import
json
__dir__
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
sys
.
path
.
append
(
__dir__
)
sys
.
path
.
append
(
os
.
path
.
abspath
(
os
.
path
.
join
(
__dir__
,
'..'
)))
os
.
environ
[
"FLAGS_allocator_strategy"
]
=
'auto_growth'
import
paddle
from
paddle.jit
import
to_static
from
ppocr.data
import
create_operators
,
transform
from
ppocr.modeling.architectures
import
build_model
from
ppocr.postprocess
import
build_post_process
from
ppocr.utils.save_load
import
init_model
from
ppocr.utils.utility
import
get_image_file_list
import
tools.program
as
program
import
cv2
def
main
(
config
,
device
,
logger
,
vdl_writer
):
global_config
=
config
[
'Global'
]
# build post process
post_process_class
=
build_post_process
(
config
[
'PostProcess'
],
global_config
)
# build model
if
hasattr
(
post_process_class
,
'character'
):
config
[
'Architecture'
][
"Head"
][
'out_channels'
]
=
len
(
getattr
(
post_process_class
,
'character'
))
model
=
build_model
(
config
[
'Architecture'
])
init_model
(
config
,
model
,
logger
)
# create data ops
transforms
=
[]
use_padding
=
False
for
op
in
config
[
'Eval'
][
'dataset'
][
'transforms'
]:
op_name
=
list
(
op
)[
0
]
if
'Label'
in
op_name
:
continue
if
op_name
==
'KeepKeys'
:
op
[
op_name
][
'keep_keys'
]
=
[
'image'
]
if
op_name
==
"ResizeTableImage"
:
use_padding
=
True
padding_max_len
=
op
[
'ResizeTableImage'
][
'max_len'
]
transforms
.
append
(
op
)
global_config
[
'infer_mode'
]
=
True
ops
=
create_operators
(
transforms
,
global_config
)
model
.
eval
()
for
file
in
get_image_file_list
(
config
[
'Global'
][
'infer_img'
]):
logger
.
info
(
"infer_img: {}"
.
format
(
file
))
with
open
(
file
,
'rb'
)
as
f
:
img
=
f
.
read
()
data
=
{
'image'
:
img
}
batch
=
transform
(
data
,
ops
)
images
=
np
.
expand_dims
(
batch
[
0
],
axis
=
0
)
images
=
paddle
.
to_tensor
(
images
)
preds
=
model
(
images
)
post_result
=
post_process_class
(
preds
)
res_html_code
=
post_result
[
'res_html_code'
]
res_loc
=
post_result
[
'res_loc'
]
img
=
cv2
.
imread
(
file
)
imgh
,
imgw
=
img
.
shape
[
0
:
2
]
res_loc_final
=
[]
for
rno
in
range
(
len
(
res_loc
[
0
])):
x0
,
y0
,
x1
,
y1
=
res_loc
[
0
][
rno
]
left
=
max
(
int
(
imgw
*
x0
),
0
)
top
=
max
(
int
(
imgh
*
y0
),
0
)
right
=
min
(
int
(
imgw
*
x1
),
imgw
-
1
)
bottom
=
min
(
int
(
imgh
*
y1
),
imgh
-
1
)
cv2
.
rectangle
(
img
,
(
left
,
top
),
(
right
,
bottom
),
(
0
,
0
,
255
),
2
)
res_loc_final
.
append
([
left
,
top
,
right
,
bottom
])
res_loc_str
=
json
.
dumps
(
res_loc_final
)
logger
.
info
(
"result: {}, {}"
.
format
(
res_html_code
,
res_loc_final
))
logger
.
info
(
"success!"
)
if
__name__
==
'__main__'
:
config
,
device
,
logger
,
vdl_writer
=
program
.
preprocess
()
main
(
config
,
device
,
logger
,
vdl_writer
)
tools/program.py
View file @
ac8c2a89
# Copyright (c) 202
0
PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 202
1
PaddlePaddle Authors. All Rights Reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
@@ -186,6 +186,7 @@ def train(config,
...
@@ -186,6 +186,7 @@ def train(config,
model
.
train
()
model
.
train
()
use_srn
=
config
[
'Architecture'
][
'algorithm'
]
==
"SRN"
use_srn
=
config
[
'Architecture'
][
'algorithm'
]
==
"SRN"
model_type
=
config
[
'Architecture'
][
'model_type'
]
if
'start_epoch'
in
best_model_dict
:
if
'start_epoch'
in
best_model_dict
:
start_epoch
=
best_model_dict
[
'start_epoch'
]
start_epoch
=
best_model_dict
[
'start_epoch'
]
...
@@ -208,9 +209,9 @@ def train(config,
...
@@ -208,9 +209,9 @@ def train(config,
lr
=
optimizer
.
get_lr
()
lr
=
optimizer
.
get_lr
()
images
=
batch
[
0
]
images
=
batch
[
0
]
if
use_srn
:
if
use_srn
:
others
=
batch
[
-
4
:]
preds
=
model
(
images
,
others
)
model_average
=
True
model_average
=
True
if
use_srn
or
model_type
==
'table'
:
preds
=
model
(
images
,
data
=
batch
[
1
:])
else
:
else
:
preds
=
model
(
images
)
preds
=
model
(
images
)
loss
=
loss_class
(
preds
,
batch
)
loss
=
loss_class
(
preds
,
batch
)
...
@@ -232,8 +233,11 @@ def train(config,
...
@@ -232,8 +233,11 @@ def train(config,
if
cal_metric_during_train
:
# only rec and cls need
if
cal_metric_during_train
:
# only rec and cls need
batch
=
[
item
.
numpy
()
for
item
in
batch
]
batch
=
[
item
.
numpy
()
for
item
in
batch
]
post_result
=
post_process_class
(
preds
,
batch
[
1
])
if
model_type
==
'table'
:
eval_class
(
post_result
,
batch
)
eval_class
(
preds
,
batch
)
else
:
post_result
=
post_process_class
(
preds
,
batch
[
1
])
eval_class
(
post_result
,
batch
)
metric
=
eval_class
.
get_metric
()
metric
=
eval_class
.
get_metric
()
train_stats
.
update
(
metric
)
train_stats
.
update
(
metric
)
...
@@ -269,6 +273,7 @@ def train(config,
...
@@ -269,6 +273,7 @@ def train(config,
valid_dataloader
,
valid_dataloader
,
post_process_class
,
post_process_class
,
eval_class
,
eval_class
,
model_type
,
use_srn
=
use_srn
)
use_srn
=
use_srn
)
cur_metric_str
=
'cur metric, {}'
.
format
(
', '
.
join
(
cur_metric_str
=
'cur metric, {}'
.
format
(
', '
.
join
(
[
'{}: {}'
.
format
(
k
,
v
)
for
k
,
v
in
cur_metric
.
items
()]))
[
'{}: {}'
.
format
(
k
,
v
)
for
k
,
v
in
cur_metric
.
items
()]))
...
@@ -336,7 +341,11 @@ def train(config,
...
@@ -336,7 +341,11 @@ def train(config,
return
return
def
eval
(
model
,
valid_dataloader
,
post_process_class
,
eval_class
,
def
eval
(
model
,
valid_dataloader
,
post_process_class
,
eval_class
,
model_type
,
use_srn
=
False
):
use_srn
=
False
):
model
.
eval
()
model
.
eval
()
with
paddle
.
no_grad
():
with
paddle
.
no_grad
():
...
@@ -350,19 +359,19 @@ def eval(model, valid_dataloader, post_process_class, eval_class,
...
@@ -350,19 +359,19 @@ def eval(model, valid_dataloader, post_process_class, eval_class,
break
break
images
=
batch
[
0
]
images
=
batch
[
0
]
start
=
time
.
time
()
start
=
time
.
time
()
if
use_srn
or
model_type
==
'table'
:
if
use_srn
:
preds
=
model
(
images
,
data
=
batch
[
1
:])
others
=
batch
[
-
4
:]
preds
=
model
(
images
,
others
)
else
:
else
:
preds
=
model
(
images
)
preds
=
model
(
images
)
batch
=
[
item
.
numpy
()
for
item
in
batch
]
batch
=
[
item
.
numpy
()
for
item
in
batch
]
# Obtain usable results from post-processing methods
# Obtain usable results from post-processing methods
post_result
=
post_process_class
(
preds
,
batch
[
1
])
total_time
+=
time
.
time
()
-
start
total_time
+=
time
.
time
()
-
start
# Evaluate the results of the current batch
# Evaluate the results of the current batch
eval_class
(
post_result
,
batch
)
if
model_type
==
'table'
:
eval_class
(
preds
,
batch
)
else
:
post_result
=
post_process_class
(
preds
,
batch
[
1
])
eval_class
(
post_result
,
batch
)
pbar
.
update
(
1
)
pbar
.
update
(
1
)
total_frame
+=
len
(
images
)
total_frame
+=
len
(
images
)
# Get final metric,eg. acc or hmean
# Get final metric,eg. acc or hmean
...
@@ -386,7 +395,7 @@ def preprocess(is_train=False):
...
@@ -386,7 +395,7 @@ def preprocess(is_train=False):
alg
=
config
[
'Architecture'
][
'algorithm'
]
alg
=
config
[
'Architecture'
][
'algorithm'
]
assert
alg
in
[
assert
alg
in
[
'EAST'
,
'DB'
,
'SAST'
,
'Rosetta'
,
'CRNN'
,
'STARNet'
,
'RARE'
,
'SRN'
,
'EAST'
,
'DB'
,
'SAST'
,
'Rosetta'
,
'CRNN'
,
'STARNet'
,
'RARE'
,
'SRN'
,
'CLS'
,
'PGNet'
,
'Distillation'
'CLS'
,
'PGNet'
,
'Distillation'
,
'TableAttn'
]
]
device
=
'gpu:{}'
.
format
(
dist
.
ParallelEnv
().
dev_id
)
if
use_gpu
else
'cpu'
device
=
'gpu:{}'
.
format
(
dist
.
ParallelEnv
().
dev_id
)
if
use_gpu
else
'cpu'
...
...
tools/train.py
View file @
ac8c2a89
...
@@ -35,7 +35,7 @@ from ppocr.losses import build_loss
...
@@ -35,7 +35,7 @@ from ppocr.losses import build_loss
from
ppocr.optimizer
import
build_optimizer
from
ppocr.optimizer
import
build_optimizer
from
ppocr.postprocess
import
build_post_process
from
ppocr.postprocess
import
build_post_process
from
ppocr.metrics
import
build_metric
from
ppocr.metrics
import
build_metric
from
ppocr.utils.save_load
import
init_model
from
ppocr.utils.save_load
import
init_model
,
load_dygraph_params
import
tools.program
as
program
import
tools.program
as
program
dist
.
get_world_size
()
dist
.
get_world_size
()
...
@@ -97,7 +97,7 @@ def main(config, device, logger, vdl_writer):
...
@@ -97,7 +97,7 @@ def main(config, device, logger, vdl_writer):
# build metric
# build metric
eval_class
=
build_metric
(
config
[
'Metric'
])
eval_class
=
build_metric
(
config
[
'Metric'
])
# load pretrain model
# load pretrain model
pre_best_model_dict
=
init_model
(
config
,
model
,
optimizer
)
pre_best_model_dict
=
load_dygraph_params
(
config
,
model
,
logger
,
optimizer
)
logger
.
info
(
'train dataloader has {} iters'
.
format
(
len
(
train_dataloader
)))
logger
.
info
(
'train dataloader has {} iters'
.
format
(
len
(
train_dataloader
)))
if
valid_dataloader
is
not
None
:
if
valid_dataloader
is
not
None
:
...
...
Prev
1
2
3
4
5
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment