Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
wangsen
paddle_dbnet
Commits
465ef3bf
Unverified
Commit
465ef3bf
authored
Jun 28, 2021
by
Double_V
Committed by
GitHub
Jun 28, 2021
Browse files
Merge branch 'dygraph' into bm_dyg
parents
bf9f93f7
bc999986
Changes
94
Show whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
298 additions
and
146 deletions
+298
-146
test1/utility.py
test1/utility.py
+5
-10
tools/eval.py
tools/eval.py
+12
-4
tools/export_model.py
tools/export_model.py
+55
-32
tools/infer/predict_det.py
tools/infer/predict_det.py
+12
-60
tools/infer/predict_rec.py
tools/infer/predict_rec.py
+7
-5
tools/infer/predict_system.py
tools/infer/predict_system.py
+30
-2
tools/infer/utility.py
tools/infer/utility.py
+7
-5
tools/infer_cls.py
tools/infer_cls.py
+1
-1
tools/infer_det.py
tools/infer_det.py
+2
-2
tools/infer_e2e.py
tools/infer_e2e.py
+1
-1
tools/infer_rec.py
tools/infer_rec.py
+27
-8
tools/infer_table.py
tools/infer_table.py
+107
-0
tools/program.py
tools/program.py
+23
-14
tools/train.py
tools/train.py
+9
-2
No files found.
ppstructure
/utility.py
→
test1
/utility.py
View file @
465ef3bf
...
@@ -23,16 +23,11 @@ def init_args():
...
@@ -23,16 +23,11 @@ def init_args():
# params for output
# params for output
parser
.
add_argument
(
"--output"
,
type
=
str
,
default
=
'./output/table'
)
parser
.
add_argument
(
"--output"
,
type
=
str
,
default
=
'./output/table'
)
# params for table structure
# params for table structure
parser
.
add_argument
(
"--structure_max_len"
,
type
=
int
,
default
=
488
)
parser
.
add_argument
(
"--table_max_len"
,
type
=
int
,
default
=
488
)
parser
.
add_argument
(
"--structure_max_text_length"
,
type
=
int
,
default
=
100
)
parser
.
add_argument
(
"--table_model_dir"
,
type
=
str
)
parser
.
add_argument
(
"--structure_max_elem_length"
,
type
=
int
,
default
=
800
)
parser
.
add_argument
(
"--table_char_type"
,
type
=
str
,
default
=
'en'
)
parser
.
add_argument
(
"--structure_max_cell_num"
,
type
=
int
,
default
=
500
)
parser
.
add_argument
(
"--table_char_dict_path"
,
type
=
str
,
default
=
"../ppocr/utils/dict/table_structure_dict.txt"
)
parser
.
add_argument
(
"--structure_model_dir"
,
type
=
str
)
parser
.
add_argument
(
"--structure_char_type"
,
type
=
str
,
default
=
'en'
)
parser
.
add_argument
(
"--structure_char_dict_path"
,
type
=
str
,
default
=
"../ppocr/utils/dict/table_structure_dict.txt"
)
# params for layout detector
parser
.
add_argument
(
"--layout_model_dir"
,
type
=
str
)
return
parser
return
parser
...
...
tools/eval.py
View file @
465ef3bf
...
@@ -44,12 +44,20 @@ def main():
...
@@ -44,12 +44,20 @@ def main():
# build model
# build model
# for rec algorithm
# for rec algorithm
if
hasattr
(
post_process_class
,
'character'
):
if
hasattr
(
post_process_class
,
'character'
):
config
[
'Architecture'
][
"Head"
][
'out_channels'
]
=
len
(
char_num
=
len
(
getattr
(
post_process_class
,
'character'
))
getattr
(
post_process_class
,
'character'
))
if
config
[
'Architecture'
][
"algorithm"
]
in
[
"Distillation"
,
]:
# distillation model
for
key
in
config
[
'Architecture'
][
"Models"
]:
config
[
'Architecture'
][
"Models"
][
key
][
"Head"
][
'out_channels'
]
=
char_num
else
:
# base rec model
config
[
'Architecture'
][
"Head"
][
'out_channels'
]
=
char_num
model
=
build_model
(
config
[
'Architecture'
])
model
=
build_model
(
config
[
'Architecture'
])
use_srn
=
config
[
'Architecture'
][
'algorithm'
]
==
"SRN"
use_srn
=
config
[
'Architecture'
][
'algorithm'
]
==
"SRN"
model_type
=
config
[
'Architecture'
][
'model_type'
]
best_model_dict
=
init_model
(
config
,
model
,
logger
)
best_model_dict
=
init_model
(
config
,
model
)
if
len
(
best_model_dict
):
if
len
(
best_model_dict
):
logger
.
info
(
'metric in ckpt ***************'
)
logger
.
info
(
'metric in ckpt ***************'
)
for
k
,
v
in
best_model_dict
.
items
():
for
k
,
v
in
best_model_dict
.
items
():
...
@@ -60,7 +68,7 @@ def main():
...
@@ -60,7 +68,7 @@ def main():
# start eval
# start eval
metric
=
program
.
eval
(
model
,
valid_dataloader
,
post_process_class
,
metric
=
program
.
eval
(
model
,
valid_dataloader
,
post_process_class
,
eval_class
,
use_srn
)
eval_class
,
model_type
,
use_srn
)
logger
.
info
(
'metric eval ***************'
)
logger
.
info
(
'metric eval ***************'
)
for
k
,
v
in
metric
.
items
():
for
k
,
v
in
metric
.
items
():
logger
.
info
(
'{}:{}'
.
format
(
k
,
v
))
logger
.
info
(
'{}:{}'
.
format
(
k
,
v
))
...
...
tools/export_model.py
View file @
465ef3bf
...
@@ -17,7 +17,7 @@ import sys
...
@@ -17,7 +17,7 @@ import sys
__dir__
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
__dir__
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
sys
.
path
.
append
(
__dir__
)
sys
.
path
.
append
(
__dir__
)
sys
.
path
.
append
(
os
.
path
.
abspath
(
os
.
path
.
join
(
__dir__
,
'
..
'
)))
sys
.
path
.
append
(
os
.
path
.
abspath
(
os
.
path
.
join
(
__dir__
,
"
..
"
)))
import
argparse
import
argparse
...
@@ -31,32 +31,12 @@ from ppocr.utils.logging import get_logger
...
@@ -31,32 +31,12 @@ from ppocr.utils.logging import get_logger
from
tools.program
import
load_config
,
merge_config
,
ArgsParser
from
tools.program
import
load_config
,
merge_config
,
ArgsParser
def
main
():
def
export_single_model
(
model
,
arch_config
,
save_path
,
logger
):
FLAGS
=
ArgsParser
().
parse_args
()
if
arch_config
[
"algorithm"
]
==
"SRN"
:
config
=
load_config
(
FLAGS
.
config
)
max_text_length
=
arch_config
[
"Head"
][
"max_text_length"
]
merge_config
(
FLAGS
.
opt
)
logger
=
get_logger
()
# build post process
post_process_class
=
build_post_process
(
config
[
'PostProcess'
],
config
[
'Global'
])
# build model
# for rec algorithm
if
hasattr
(
post_process_class
,
'character'
):
char_num
=
len
(
getattr
(
post_process_class
,
'character'
))
config
[
'Architecture'
][
"Head"
][
'out_channels'
]
=
char_num
model
=
build_model
(
config
[
'Architecture'
])
init_model
(
config
,
model
,
logger
)
model
.
eval
()
save_path
=
'{}/inference'
.
format
(
config
[
'Global'
][
'save_inference_dir'
])
if
config
[
'Architecture'
][
'algorithm'
]
==
"SRN"
:
max_text_length
=
config
[
'Architecture'
][
'Head'
][
'max_text_length'
]
other_shape
=
[
other_shape
=
[
paddle
.
static
.
InputSpec
(
paddle
.
static
.
InputSpec
(
shape
=
[
None
,
1
,
64
,
256
],
dtype
=
'
float32
'
),
[
shape
=
[
None
,
1
,
64
,
256
],
dtype
=
"
float32
"
),
[
paddle
.
static
.
InputSpec
(
paddle
.
static
.
InputSpec
(
shape
=
[
None
,
256
,
1
],
shape
=
[
None
,
256
,
1
],
dtype
=
"int64"
),
paddle
.
static
.
InputSpec
(
dtype
=
"int64"
),
paddle
.
static
.
InputSpec
(
...
@@ -71,24 +51,67 @@ def main():
...
@@ -71,24 +51,67 @@ def main():
model
=
to_static
(
model
,
input_spec
=
other_shape
)
model
=
to_static
(
model
,
input_spec
=
other_shape
)
else
:
else
:
infer_shape
=
[
3
,
-
1
,
-
1
]
infer_shape
=
[
3
,
-
1
,
-
1
]
if
config
[
'Architecture'
][
'
model_type
'
]
==
"rec"
:
if
arch_
config
[
"
model_type
"
]
==
"rec"
:
infer_shape
=
[
3
,
32
,
-
1
]
# for rec model, H must be 32
infer_shape
=
[
3
,
32
,
-
1
]
# for rec model, H must be 32
if
'
Transform
'
in
config
[
'Architecture'
]
and
config
[
'Architecture'
]
[
if
"
Transform
"
in
arch_
config
and
arch_config
[
'
Transform
'
]
is
not
None
and
config
[
'Architecture'
][
"
Transform
"
]
is
not
None
and
arch_
config
[
"Transform"
][
'Transform'
][
'
name
'
]
==
'
TPS
'
:
"
name
"
]
==
"
TPS
"
:
logger
.
info
(
logger
.
info
(
'
When there is tps in the network, variable length input is not supported, and the input size needs to be the same as during training
'
"
When there is tps in the network, variable length input is not supported, and the input size needs to be the same as during training
"
)
)
infer_shape
[
-
1
]
=
100
infer_shape
[
-
1
]
=
100
elif
arch_config
[
"model_type"
]
==
"table"
:
infer_shape
=
[
3
,
488
,
488
]
model
=
to_static
(
model
=
to_static
(
model
,
model
,
input_spec
=
[
input_spec
=
[
paddle
.
static
.
InputSpec
(
paddle
.
static
.
InputSpec
(
shape
=
[
None
]
+
infer_shape
,
dtype
=
'
float32
'
)
shape
=
[
None
]
+
infer_shape
,
dtype
=
"
float32
"
)
])
])
paddle
.
jit
.
save
(
model
,
save_path
)
paddle
.
jit
.
save
(
model
,
save_path
)
logger
.
info
(
'inference model is saved to {}'
.
format
(
save_path
))
logger
.
info
(
"inference model is saved to {}"
.
format
(
save_path
))
return
def
main
():
FLAGS
=
ArgsParser
().
parse_args
()
config
=
load_config
(
FLAGS
.
config
)
merge_config
(
FLAGS
.
opt
)
logger
=
get_logger
()
# build post process
post_process_class
=
build_post_process
(
config
[
"PostProcess"
],
config
[
"Global"
])
# build model
# for rec algorithm
if
hasattr
(
post_process_class
,
"character"
):
char_num
=
len
(
getattr
(
post_process_class
,
"character"
))
if
config
[
"Architecture"
][
"algorithm"
]
in
[
"Distillation"
,
]:
# distillation model
for
key
in
config
[
"Architecture"
][
"Models"
]:
config
[
"Architecture"
][
"Models"
][
key
][
"Head"
][
"out_channels"
]
=
char_num
else
:
# base rec model
config
[
"Architecture"
][
"Head"
][
"out_channels"
]
=
char_num
model
=
build_model
(
config
[
"Architecture"
])
init_model
(
config
,
model
)
model
.
eval
()
save_path
=
config
[
"Global"
][
"save_inference_dir"
]
arch_config
=
config
[
"Architecture"
]
if
arch_config
[
"algorithm"
]
in
[
"Distillation"
,
]:
# distillation model
archs
=
list
(
arch_config
[
"Models"
].
values
())
for
idx
,
name
in
enumerate
(
model
.
model_name_list
):
sub_model_save_path
=
os
.
path
.
join
(
save_path
,
name
,
"inference"
)
export_single_model
(
model
.
model_list
[
idx
],
archs
[
idx
],
sub_model_save_path
,
logger
)
else
:
save_path
=
os
.
path
.
join
(
save_path
,
"inference"
)
export_single_model
(
model
,
arch_config
,
save_path
,
logger
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
...
...
tools/infer/predict_det.py
View file @
465ef3bf
...
@@ -31,7 +31,7 @@ from ppocr.utils.utility import get_image_file_list, check_and_read_gif
...
@@ -31,7 +31,7 @@ from ppocr.utils.utility import get_image_file_list, check_and_read_gif
from
ppocr.data
import
create_operators
,
transform
from
ppocr.data
import
create_operators
,
transform
from
ppocr.postprocess
import
build_post_process
from
ppocr.postprocess
import
build_post_process
import
tools.infer.benchmark_utils
as
benchmark_utils
#
import tools.infer.benchmark_utils as benchmark_utils
logger
=
get_logger
()
logger
=
get_logger
()
...
@@ -100,8 +100,6 @@ class TextDetector(object):
...
@@ -100,8 +100,6 @@ class TextDetector(object):
self
.
predictor
,
self
.
input_tensor
,
self
.
output_tensors
,
self
.
config
=
utility
.
create_predictor
(
self
.
predictor
,
self
.
input_tensor
,
self
.
output_tensors
,
self
.
config
=
utility
.
create_predictor
(
args
,
'det'
,
logger
)
args
,
'det'
,
logger
)
self
.
det_times
=
utility
.
Timer
()
def
order_points_clockwise
(
self
,
pts
):
def
order_points_clockwise
(
self
,
pts
):
"""
"""
reference from: https://github.com/jrosebr1/imutils/blob/master/imutils/perspective.py
reference from: https://github.com/jrosebr1/imutils/blob/master/imutils/perspective.py
...
@@ -158,8 +156,8 @@ class TextDetector(object):
...
@@ -158,8 +156,8 @@ class TextDetector(object):
def
__call__
(
self
,
img
):
def
__call__
(
self
,
img
):
ori_im
=
img
.
copy
()
ori_im
=
img
.
copy
()
data
=
{
'image'
:
img
}
data
=
{
'image'
:
img
}
self
.
det_times
.
total_time
.
start
()
s
elf
.
det_times
.
preprocess_time
.
start
()
s
t
=
time
.
time
()
data
=
transform
(
data
,
self
.
preprocess_op
)
data
=
transform
(
data
,
self
.
preprocess_op
)
img
,
shape_list
=
data
img
,
shape_list
=
data
if
img
is
None
:
if
img
is
None
:
...
@@ -168,16 +166,12 @@ class TextDetector(object):
...
@@ -168,16 +166,12 @@ class TextDetector(object):
shape_list
=
np
.
expand_dims
(
shape_list
,
axis
=
0
)
shape_list
=
np
.
expand_dims
(
shape_list
,
axis
=
0
)
img
=
img
.
copy
()
img
=
img
.
copy
()
self
.
det_times
.
preprocess_time
.
end
()
self
.
det_times
.
inference_time
.
start
()
self
.
input_tensor
.
copy_from_cpu
(
img
)
self
.
input_tensor
.
copy_from_cpu
(
img
)
self
.
predictor
.
run
()
self
.
predictor
.
run
()
outputs
=
[]
outputs
=
[]
for
output_tensor
in
self
.
output_tensors
:
for
output_tensor
in
self
.
output_tensors
:
output
=
output_tensor
.
copy_to_cpu
()
output
=
output_tensor
.
copy_to_cpu
()
outputs
.
append
(
output
)
outputs
.
append
(
output
)
self
.
det_times
.
inference_time
.
end
()
preds
=
{}
preds
=
{}
if
self
.
det_algorithm
==
"EAST"
:
if
self
.
det_algorithm
==
"EAST"
:
...
@@ -193,8 +187,6 @@ class TextDetector(object):
...
@@ -193,8 +187,6 @@ class TextDetector(object):
else
:
else
:
raise
NotImplementedError
raise
NotImplementedError
self
.
det_times
.
postprocess_time
.
start
()
self
.
predictor
.
try_shrink_memory
()
self
.
predictor
.
try_shrink_memory
()
post_result
=
self
.
postprocess_op
(
preds
,
shape_list
)
post_result
=
self
.
postprocess_op
(
preds
,
shape_list
)
dt_boxes
=
post_result
[
0
][
'points'
]
dt_boxes
=
post_result
[
0
][
'points'
]
...
@@ -203,10 +195,8 @@ class TextDetector(object):
...
@@ -203,10 +195,8 @@ class TextDetector(object):
else
:
else
:
dt_boxes
=
self
.
filter_tag_det_res
(
dt_boxes
,
ori_im
.
shape
)
dt_boxes
=
self
.
filter_tag_det_res
(
dt_boxes
,
ori_im
.
shape
)
self
.
det_times
.
postprocess_time
.
end
()
et
=
time
.
time
()
self
.
det_times
.
total_time
.
end
()
return
dt_boxes
,
et
-
st
self
.
det_times
.
img_num
+=
1
return
dt_boxes
,
self
.
det_times
.
total_time
.
value
()
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
...
@@ -216,12 +206,13 @@ if __name__ == "__main__":
...
@@ -216,12 +206,13 @@ if __name__ == "__main__":
count
=
0
count
=
0
total_time
=
0
total_time
=
0
draw_img_save
=
"./inference_results"
draw_img_save
=
"./inference_results"
cpu_mem
,
gpu_mem
,
gpu_util
=
0
,
0
,
0
# warmup 10 times
if
args
.
warmup
:
fake_
img
=
np
.
random
.
uniform
(
-
1
,
1
,
[
640
,
640
,
3
]).
astype
(
np
.
float32
)
img
=
np
.
random
.
uniform
(
0
,
255
,
[
640
,
640
,
3
]).
astype
(
np
.
uint8
)
for
i
in
range
(
10
):
for
i
in
range
(
10
):
dt_boxes
,
_
=
text_detector
(
fake_img
)
res
=
text_detector
(
img
)
cpu_mem
,
gpu_mem
,
gpu_util
=
0
,
0
,
0
if
not
os
.
path
.
exists
(
draw_img_save
):
if
not
os
.
path
.
exists
(
draw_img_save
):
os
.
makedirs
(
draw_img_save
)
os
.
makedirs
(
draw_img_save
)
...
@@ -239,50 +230,11 @@ if __name__ == "__main__":
...
@@ -239,50 +230,11 @@ if __name__ == "__main__":
total_time
+=
elapse
total_time
+=
elapse
count
+=
1
count
+=
1
if
args
.
benchmark
:
cm
,
gm
,
gu
=
utility
.
get_current_memory_mb
(
0
)
cpu_mem
+=
cm
gpu_mem
+=
gm
gpu_util
+=
gu
logger
.
info
(
"Predict time of {}: {}"
.
format
(
image_file
,
elapse
))
logger
.
info
(
"Predict time of {}: {}"
.
format
(
image_file
,
elapse
))
src_im
=
utility
.
draw_text_det_res
(
dt_boxes
,
image_file
)
src_im
=
utility
.
draw_text_det_res
(
dt_boxes
,
image_file
)
img_name_pure
=
os
.
path
.
split
(
image_file
)[
-
1
]
img_name_pure
=
os
.
path
.
split
(
image_file
)[
-
1
]
img_path
=
os
.
path
.
join
(
draw_img_save
,
img_path
=
os
.
path
.
join
(
draw_img_save
,
"det_res_{}"
.
format
(
img_name_pure
))
"det_res_{}"
.
format
(
img_name_pure
))
cv2
.
imwrite
(
img_path
,
src_im
)
logger
.
info
(
"The visualized image saved in {}"
.
format
(
img_path
))
logger
.
info
(
"The visualized image saved in {}"
.
format
(
img_path
))
# print the information about memory and time-spent
if
args
.
benchmark
:
mems
=
{
'cpu_rss_mb'
:
cpu_mem
/
count
,
'gpu_rss_mb'
:
gpu_mem
/
count
,
'gpu_util'
:
gpu_util
*
100
/
count
}
else
:
mems
=
None
logger
.
info
(
"The predict time about detection module is as follows: "
)
det_time_dict
=
text_detector
.
det_times
.
report
(
average
=
True
)
det_model_name
=
args
.
det_model_dir
if
args
.
benchmark
:
# construct log information
model_info
=
{
'model_name'
:
args
.
det_model_dir
.
split
(
'/'
)[
-
1
],
'precision'
:
args
.
precision
}
data_info
=
{
'batch_size'
:
1
,
'shape'
:
'dynamic_shape'
,
'data_num'
:
det_time_dict
[
'img_num'
]
}
perf_info
=
{
'preprocess_time_s'
:
det_time_dict
[
'preprocess_time'
],
'inference_time_s'
:
det_time_dict
[
'inference_time'
],
'postprocess_time_s'
:
det_time_dict
[
'postprocess_time'
],
'total_time_s'
:
det_time_dict
[
'total_time'
]
}
benchmark_log
=
benchmark_utils
.
PaddleInferBenchmark
(
text_detector
.
config
,
model_info
,
data_info
,
perf_info
,
mems
,
args
.
save_log_path
)
benchmark_log
(
"Det"
)
tools/infer/predict_rec.py
View file @
465ef3bf
...
@@ -257,13 +257,15 @@ def main(args):
...
@@ -257,13 +257,15 @@ def main(args):
text_recognizer
=
TextRecognizer
(
args
)
text_recognizer
=
TextRecognizer
(
args
)
valid_image_file_list
=
[]
valid_image_file_list
=
[]
img_list
=
[]
img_list
=
[]
cpu_mem
,
gpu_mem
,
gpu_util
=
0
,
0
,
0
count
=
0
# warmup 10 times
# warmup 10 times
fake_img
=
np
.
random
.
uniform
(
-
1
,
1
,
[
1
,
32
,
320
,
3
]).
astype
(
np
.
float32
)
if
args
.
warmup
:
img
=
np
.
random
.
uniform
(
0
,
255
,
[
32
,
320
,
3
]).
astype
(
np
.
uint8
)
for
i
in
range
(
10
):
for
i
in
range
(
10
):
dt_boxes
,
_
=
text_recognizer
(
fake_img
)
res
=
text_recognizer
([
img
])
cpu_mem
,
gpu_mem
,
gpu_util
=
0
,
0
,
0
count
=
0
for
image_file
in
image_file_list
:
for
image_file
in
image_file_list
:
img
,
flag
=
check_and_read_gif
(
image_file
)
img
,
flag
=
check_and_read_gif
(
image_file
)
...
...
tools/infer/predict_system.py
View file @
465ef3bf
...
@@ -13,6 +13,7 @@
...
@@ -13,6 +13,7 @@
# limitations under the License.
# limitations under the License.
import
os
import
os
import
sys
import
sys
import
subprocess
__dir__
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
__dir__
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
sys
.
path
.
append
(
__dir__
)
sys
.
path
.
append
(
__dir__
)
...
@@ -24,6 +25,7 @@ import cv2
...
@@ -24,6 +25,7 @@ import cv2
import
copy
import
copy
import
numpy
as
np
import
numpy
as
np
import
time
import
time
import
logging
from
PIL
import
Image
from
PIL
import
Image
import
tools.infer.utility
as
utility
import
tools.infer.utility
as
utility
import
tools.infer.predict_rec
as
predict_rec
import
tools.infer.predict_rec
as
predict_rec
...
@@ -38,6 +40,9 @@ logger = get_logger()
...
@@ -38,6 +40,9 @@ logger = get_logger()
class
TextSystem
(
object
):
class
TextSystem
(
object
):
def
__init__
(
self
,
args
):
def
__init__
(
self
,
args
):
if
not
args
.
show_log
:
logger
.
setLevel
(
logging
.
INFO
)
self
.
text_detector
=
predict_det
.
TextDetector
(
args
)
self
.
text_detector
=
predict_det
.
TextDetector
(
args
)
self
.
text_recognizer
=
predict_rec
.
TextRecognizer
(
args
)
self
.
text_recognizer
=
predict_rec
.
TextRecognizer
(
args
)
self
.
use_angle_cls
=
args
.
use_angle_cls
self
.
use_angle_cls
=
args
.
use_angle_cls
...
@@ -142,20 +147,29 @@ def sorted_boxes(dt_boxes):
...
@@ -142,20 +147,29 @@ def sorted_boxes(dt_boxes):
def
main
(
args
):
def
main
(
args
):
image_file_list
=
get_image_file_list
(
args
.
image_dir
)
image_file_list
=
get_image_file_list
(
args
.
image_dir
)
image_file_list
=
image_file_list
[
args
.
process_id
::
args
.
total_process_num
]
text_sys
=
TextSystem
(
args
)
text_sys
=
TextSystem
(
args
)
is_visualize
=
True
is_visualize
=
True
font_path
=
args
.
vis_font_path
font_path
=
args
.
vis_font_path
drop_score
=
args
.
drop_score
drop_score
=
args
.
drop_score
# warm up 10 times
if
args
.
warmup
:
img
=
np
.
random
.
uniform
(
0
,
255
,
[
640
,
640
,
3
]).
astype
(
np
.
uint8
)
for
i
in
range
(
10
):
res
=
text_sys
(
img
)
total_time
=
0
total_time
=
0
cpu_mem
,
gpu_mem
,
gpu_util
=
0
,
0
,
0
cpu_mem
,
gpu_mem
,
gpu_util
=
0
,
0
,
0
_st
=
time
.
time
()
_st
=
time
.
time
()
count
=
0
count
=
0
for
idx
,
image_file
in
enumerate
(
image_file_list
):
for
idx
,
image_file
in
enumerate
(
image_file_list
):
img
,
flag
=
check_and_read_gif
(
image_file
)
img
,
flag
=
check_and_read_gif
(
image_file
)
if
not
flag
:
if
not
flag
:
img
=
cv2
.
imread
(
image_file
)
img
=
cv2
.
imread
(
image_file
)
if
img
is
None
:
if
img
is
None
:
logger
.
error
(
"error in loading image:{}"
.
format
(
image_file
))
logger
.
info
(
"error in loading image:{}"
.
format
(
image_file
))
continue
continue
starttime
=
time
.
time
()
starttime
=
time
.
time
()
dt_boxes
,
rec_res
=
text_sys
(
img
)
dt_boxes
,
rec_res
=
text_sys
(
img
)
...
@@ -259,4 +273,18 @@ def main(args):
...
@@ -259,4 +273,18 @@ def main(args):
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
main
(
utility
.
parse_args
())
args
=
utility
.
parse_args
()
if
args
.
use_mp
:
p_list
=
[]
total_process_num
=
args
.
total_process_num
for
process_id
in
range
(
total_process_num
):
cmd
=
[
sys
.
executable
,
"-u"
]
+
sys
.
argv
+
[
"--process_id={}"
.
format
(
process_id
),
"--use_mp={}"
.
format
(
False
)
]
p
=
subprocess
.
Popen
(
cmd
,
stdout
=
sys
.
stdout
,
stderr
=
sys
.
stdout
)
p_list
.
append
(
p
)
for
p
in
p_list
:
p
.
wait
()
else
:
main
(
args
)
tools/infer/utility.py
View file @
465ef3bf
...
@@ -106,7 +106,9 @@ def init_args():
...
@@ -106,7 +106,9 @@ def init_args():
parser
.
add_argument
(
"--enable_mkldnn"
,
type
=
str2bool
,
default
=
False
)
parser
.
add_argument
(
"--enable_mkldnn"
,
type
=
str2bool
,
default
=
False
)
parser
.
add_argument
(
"--cpu_threads"
,
type
=
int
,
default
=
10
)
parser
.
add_argument
(
"--cpu_threads"
,
type
=
int
,
default
=
10
)
parser
.
add_argument
(
"--use_pdserving"
,
type
=
str2bool
,
default
=
False
)
parser
.
add_argument
(
"--use_pdserving"
,
type
=
str2bool
,
default
=
False
)
parser
.
add_argument
(
"--warmup"
,
type
=
str2bool
,
default
=
True
)
# multi-process
parser
.
add_argument
(
"--use_mp"
,
type
=
str2bool
,
default
=
False
)
parser
.
add_argument
(
"--use_mp"
,
type
=
str2bool
,
default
=
False
)
parser
.
add_argument
(
"--total_process_num"
,
type
=
int
,
default
=
1
)
parser
.
add_argument
(
"--total_process_num"
,
type
=
int
,
default
=
1
)
parser
.
add_argument
(
"--process_id"
,
type
=
int
,
default
=
0
)
parser
.
add_argument
(
"--process_id"
,
type
=
int
,
default
=
0
)
...
@@ -115,7 +117,6 @@ def init_args():
...
@@ -115,7 +117,6 @@ def init_args():
parser
.
add_argument
(
"--save_log_path"
,
type
=
str
,
default
=
"./log_output/"
)
parser
.
add_argument
(
"--save_log_path"
,
type
=
str
,
default
=
"./log_output/"
)
parser
.
add_argument
(
"--show_log"
,
type
=
str2bool
,
default
=
True
)
parser
.
add_argument
(
"--show_log"
,
type
=
str2bool
,
default
=
True
)
return
parser
return
parser
...
@@ -201,8 +202,8 @@ def create_predictor(args, mode, logger):
...
@@ -201,8 +202,8 @@ def create_predictor(args, mode, logger):
model_dir
=
args
.
cls_model_dir
model_dir
=
args
.
cls_model_dir
elif
mode
==
'rec'
:
elif
mode
==
'rec'
:
model_dir
=
args
.
rec_model_dir
model_dir
=
args
.
rec_model_dir
elif
mode
==
'
structur
e'
:
elif
mode
==
'
tabl
e'
:
model_dir
=
args
.
structur
e_model_dir
model_dir
=
args
.
tabl
e_model_dir
else
:
else
:
model_dir
=
args
.
e2e_model_dir
model_dir
=
args
.
e2e_model_dir
...
@@ -310,10 +311,11 @@ def create_predictor(args, mode, logger):
...
@@ -310,10 +311,11 @@ def create_predictor(args, mode, logger):
config
.
disable_glog_info
()
config
.
disable_glog_info
()
config
.
delete_pass
(
"conv_transpose_eltwiseadd_bn_fuse_pass"
)
config
.
delete_pass
(
"conv_transpose_eltwiseadd_bn_fuse_pass"
)
if
mode
==
'table'
:
config
.
delete_pass
(
"fc_fuse_pass"
)
# not supported for table
config
.
switch_use_feed_fetch_ops
(
False
)
config
.
switch_use_feed_fetch_ops
(
False
)
config
.
switch_ir_optim
(
True
)
config
.
switch_ir_optim
(
True
)
if
mode
==
'structure'
:
config
.
switch_ir_optim
(
False
)
# create predictor
# create predictor
predictor
=
inference
.
create_predictor
(
config
)
predictor
=
inference
.
create_predictor
(
config
)
input_names
=
predictor
.
get_input_names
()
input_names
=
predictor
.
get_input_names
()
...
...
tools/infer_cls.py
View file @
465ef3bf
...
@@ -47,7 +47,7 @@ def main():
...
@@ -47,7 +47,7 @@ def main():
# build model
# build model
model
=
build_model
(
config
[
'Architecture'
])
model
=
build_model
(
config
[
'Architecture'
])
init_model
(
config
,
model
,
logger
)
init_model
(
config
,
model
)
# create data ops
# create data ops
transforms
=
[]
transforms
=
[]
...
...
tools/infer_det.py
View file @
465ef3bf
...
@@ -61,7 +61,7 @@ def main():
...
@@ -61,7 +61,7 @@ def main():
# build model
# build model
model
=
build_model
(
config
[
'Architecture'
])
model
=
build_model
(
config
[
'Architecture'
])
init_model
(
config
,
model
,
logger
)
init_model
(
config
,
model
)
# build post process
# build post process
post_process_class
=
build_post_process
(
config
[
'PostProcess'
])
post_process_class
=
build_post_process
(
config
[
'PostProcess'
])
...
...
tools/infer_e2e.py
View file @
465ef3bf
...
@@ -68,7 +68,7 @@ def main():
...
@@ -68,7 +68,7 @@ def main():
# build model
# build model
model
=
build_model
(
config
[
'Architecture'
])
model
=
build_model
(
config
[
'Architecture'
])
init_model
(
config
,
model
,
logger
)
init_model
(
config
,
model
)
# build post process
# build post process
post_process_class
=
build_post_process
(
config
[
'PostProcess'
],
post_process_class
=
build_post_process
(
config
[
'PostProcess'
],
...
...
tools/infer_rec.py
View file @
465ef3bf
...
@@ -20,6 +20,7 @@ import numpy as np
...
@@ -20,6 +20,7 @@ import numpy as np
import
os
import
os
import
sys
import
sys
import
json
__dir__
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
__dir__
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
sys
.
path
.
append
(
__dir__
)
sys
.
path
.
append
(
__dir__
)
...
@@ -46,12 +47,18 @@ def main():
...
@@ -46,12 +47,18 @@ def main():
# build model
# build model
if
hasattr
(
post_process_class
,
'character'
):
if
hasattr
(
post_process_class
,
'character'
):
config
[
'Architecture'
][
"Head"
][
'out_channels'
]
=
len
(
char_num
=
len
(
getattr
(
post_process_class
,
'character'
))
getattr
(
post_process_class
,
'character'
))
if
config
[
'Architecture'
][
"algorithm"
]
in
[
"Distillation"
,
]:
# distillation model
for
key
in
config
[
'Architecture'
][
"Models"
]:
config
[
'Architecture'
][
"Models"
][
key
][
"Head"
][
'out_channels'
]
=
char_num
else
:
# base rec model
config
[
'Architecture'
][
"Head"
][
'out_channels'
]
=
char_num
model
=
build_model
(
config
[
'Architecture'
])
model
=
build_model
(
config
[
'Architecture'
])
init_model
(
config
,
model
,
logger
)
init_model
(
config
,
model
)
# create data ops
# create data ops
transforms
=
[]
transforms
=
[]
...
@@ -107,11 +114,23 @@ def main():
...
@@ -107,11 +114,23 @@ def main():
else
:
else
:
preds
=
model
(
images
)
preds
=
model
(
images
)
post_result
=
post_process_class
(
preds
)
post_result
=
post_process_class
(
preds
)
for
rec_reuslt
in
post_result
:
info
=
None
logger
.
info
(
'
\t
result: {}'
.
format
(
rec_reuslt
))
if
isinstance
(
post_result
,
dict
):
if
len
(
rec_reuslt
)
>=
2
:
rec_info
=
dict
()
fout
.
write
(
file
+
"
\t
"
+
rec_reuslt
[
0
]
+
"
\t
"
+
str
(
for
key
in
post_result
:
rec_reuslt
[
1
])
+
"
\n
"
)
if
len
(
post_result
[
key
][
0
])
>=
2
:
rec_info
[
key
]
=
{
"label"
:
post_result
[
key
][
0
][
0
],
"score"
:
post_result
[
key
][
0
][
1
],
}
info
=
json
.
dumps
(
rec_info
)
else
:
if
len
(
post_result
[
0
])
>=
2
:
info
=
post_result
[
0
][
0
]
+
"
\t
"
+
str
(
post_result
[
0
][
1
])
if
info
is
not
None
:
logger
.
info
(
"
\t
result: {}"
.
format
(
info
))
fout
.
write
(
file
+
"
\t
"
+
info
)
logger
.
info
(
"success!"
)
logger
.
info
(
"success!"
)
...
...
tools/infer_table.py
0 → 100644
View file @
465ef3bf
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
numpy
as
np
import
os
import
sys
import
json
__dir__
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
sys
.
path
.
append
(
__dir__
)
sys
.
path
.
append
(
os
.
path
.
abspath
(
os
.
path
.
join
(
__dir__
,
'..'
)))
os
.
environ
[
"FLAGS_allocator_strategy"
]
=
'auto_growth'
import
paddle
from
paddle.jit
import
to_static
from
ppocr.data
import
create_operators
,
transform
from
ppocr.modeling.architectures
import
build_model
from
ppocr.postprocess
import
build_post_process
from
ppocr.utils.save_load
import
init_model
from
ppocr.utils.utility
import
get_image_file_list
import
tools.program
as
program
import
cv2
def
main
(
config
,
device
,
logger
,
vdl_writer
):
global_config
=
config
[
'Global'
]
# build post process
post_process_class
=
build_post_process
(
config
[
'PostProcess'
],
global_config
)
# build model
if
hasattr
(
post_process_class
,
'character'
):
config
[
'Architecture'
][
"Head"
][
'out_channels'
]
=
len
(
getattr
(
post_process_class
,
'character'
))
model
=
build_model
(
config
[
'Architecture'
])
init_model
(
config
,
model
,
logger
)
# create data ops
transforms
=
[]
use_padding
=
False
for
op
in
config
[
'Eval'
][
'dataset'
][
'transforms'
]:
op_name
=
list
(
op
)[
0
]
if
'Label'
in
op_name
:
continue
if
op_name
==
'KeepKeys'
:
op
[
op_name
][
'keep_keys'
]
=
[
'image'
]
if
op_name
==
"ResizeTableImage"
:
use_padding
=
True
padding_max_len
=
op
[
'ResizeTableImage'
][
'max_len'
]
transforms
.
append
(
op
)
global_config
[
'infer_mode'
]
=
True
ops
=
create_operators
(
transforms
,
global_config
)
model
.
eval
()
for
file
in
get_image_file_list
(
config
[
'Global'
][
'infer_img'
]):
logger
.
info
(
"infer_img: {}"
.
format
(
file
))
with
open
(
file
,
'rb'
)
as
f
:
img
=
f
.
read
()
data
=
{
'image'
:
img
}
batch
=
transform
(
data
,
ops
)
images
=
np
.
expand_dims
(
batch
[
0
],
axis
=
0
)
images
=
paddle
.
to_tensor
(
images
)
preds
=
model
(
images
)
post_result
=
post_process_class
(
preds
)
res_html_code
=
post_result
[
'res_html_code'
]
res_loc
=
post_result
[
'res_loc'
]
img
=
cv2
.
imread
(
file
)
imgh
,
imgw
=
img
.
shape
[
0
:
2
]
res_loc_final
=
[]
for
rno
in
range
(
len
(
res_loc
[
0
])):
x0
,
y0
,
x1
,
y1
=
res_loc
[
0
][
rno
]
left
=
max
(
int
(
imgw
*
x0
),
0
)
top
=
max
(
int
(
imgh
*
y0
),
0
)
right
=
min
(
int
(
imgw
*
x1
),
imgw
-
1
)
bottom
=
min
(
int
(
imgh
*
y1
),
imgh
-
1
)
cv2
.
rectangle
(
img
,
(
left
,
top
),
(
right
,
bottom
),
(
0
,
0
,
255
),
2
)
res_loc_final
.
append
([
left
,
top
,
right
,
bottom
])
res_loc_str
=
json
.
dumps
(
res_loc_final
)
logger
.
info
(
"result: {}, {}"
.
format
(
res_html_code
,
res_loc_final
))
logger
.
info
(
"success!"
)
if
__name__
==
'__main__'
:
config
,
device
,
logger
,
vdl_writer
=
program
.
preprocess
()
main
(
config
,
device
,
logger
,
vdl_writer
)
tools/program.py
View file @
465ef3bf
# Copyright (c) 202
0
PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 202
1
PaddlePaddle Authors. All Rights Reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
@@ -186,6 +186,7 @@ def train(config,
...
@@ -186,6 +186,7 @@ def train(config,
model
.
train
()
model
.
train
()
use_srn
=
config
[
'Architecture'
][
'algorithm'
]
==
"SRN"
use_srn
=
config
[
'Architecture'
][
'algorithm'
]
==
"SRN"
model_type
=
config
[
'Architecture'
][
'model_type'
]
if
'start_epoch'
in
best_model_dict
:
if
'start_epoch'
in
best_model_dict
:
start_epoch
=
best_model_dict
[
'start_epoch'
]
start_epoch
=
best_model_dict
[
'start_epoch'
]
...
@@ -208,9 +209,9 @@ def train(config,
...
@@ -208,9 +209,9 @@ def train(config,
lr
=
optimizer
.
get_lr
()
lr
=
optimizer
.
get_lr
()
images
=
batch
[
0
]
images
=
batch
[
0
]
if
use_srn
:
if
use_srn
:
others
=
batch
[
-
4
:]
preds
=
model
(
images
,
others
)
model_average
=
True
model_average
=
True
if
use_srn
or
model_type
==
'table'
:
preds
=
model
(
images
,
data
=
batch
[
1
:])
else
:
else
:
preds
=
model
(
images
)
preds
=
model
(
images
)
loss
=
loss_class
(
preds
,
batch
)
loss
=
loss_class
(
preds
,
batch
)
...
@@ -232,6 +233,9 @@ def train(config,
...
@@ -232,6 +233,9 @@ def train(config,
if
cal_metric_during_train
:
# only rec and cls need
if
cal_metric_during_train
:
# only rec and cls need
batch
=
[
item
.
numpy
()
for
item
in
batch
]
batch
=
[
item
.
numpy
()
for
item
in
batch
]
if
model_type
==
'table'
:
eval_class
(
preds
,
batch
)
else
:
post_result
=
post_process_class
(
preds
,
batch
[
1
])
post_result
=
post_process_class
(
preds
,
batch
[
1
])
eval_class
(
post_result
,
batch
)
eval_class
(
post_result
,
batch
)
metric
=
eval_class
.
get_metric
()
metric
=
eval_class
.
get_metric
()
...
@@ -269,6 +273,7 @@ def train(config,
...
@@ -269,6 +273,7 @@ def train(config,
valid_dataloader
,
valid_dataloader
,
post_process_class
,
post_process_class
,
eval_class
,
eval_class
,
model_type
,
use_srn
=
use_srn
)
use_srn
=
use_srn
)
cur_metric_str
=
'cur metric, {}'
.
format
(
', '
.
join
(
cur_metric_str
=
'cur metric, {}'
.
format
(
', '
.
join
(
[
'{}: {}'
.
format
(
k
,
v
)
for
k
,
v
in
cur_metric
.
items
()]))
[
'{}: {}'
.
format
(
k
,
v
)
for
k
,
v
in
cur_metric
.
items
()]))
...
@@ -336,7 +341,11 @@ def train(config,
...
@@ -336,7 +341,11 @@ def train(config,
return
return
def
eval
(
model
,
valid_dataloader
,
post_process_class
,
eval_class
,
def
eval
(
model
,
valid_dataloader
,
post_process_class
,
eval_class
,
model_type
,
use_srn
=
False
):
use_srn
=
False
):
model
.
eval
()
model
.
eval
()
with
paddle
.
no_grad
():
with
paddle
.
no_grad
():
...
@@ -350,18 +359,18 @@ def eval(model, valid_dataloader, post_process_class, eval_class,
...
@@ -350,18 +359,18 @@ def eval(model, valid_dataloader, post_process_class, eval_class,
break
break
images
=
batch
[
0
]
images
=
batch
[
0
]
start
=
time
.
time
()
start
=
time
.
time
()
if
use_srn
or
model_type
==
'table'
:
if
use_srn
:
preds
=
model
(
images
,
data
=
batch
[
1
:])
others
=
batch
[
-
4
:]
preds
=
model
(
images
,
others
)
else
:
else
:
preds
=
model
(
images
)
preds
=
model
(
images
)
batch
=
[
item
.
numpy
()
for
item
in
batch
]
batch
=
[
item
.
numpy
()
for
item
in
batch
]
# Obtain usable results from post-processing methods
# Obtain usable results from post-processing methods
post_result
=
post_process_class
(
preds
,
batch
[
1
])
total_time
+=
time
.
time
()
-
start
total_time
+=
time
.
time
()
-
start
# Evaluate the results of the current batch
# Evaluate the results of the current batch
if
model_type
==
'table'
:
eval_class
(
preds
,
batch
)
else
:
post_result
=
post_process_class
(
preds
,
batch
[
1
])
eval_class
(
post_result
,
batch
)
eval_class
(
post_result
,
batch
)
pbar
.
update
(
1
)
pbar
.
update
(
1
)
total_frame
+=
len
(
images
)
total_frame
+=
len
(
images
)
...
@@ -386,7 +395,7 @@ def preprocess(is_train=False):
...
@@ -386,7 +395,7 @@ def preprocess(is_train=False):
alg
=
config
[
'Architecture'
][
'algorithm'
]
alg
=
config
[
'Architecture'
][
'algorithm'
]
assert
alg
in
[
assert
alg
in
[
'EAST'
,
'DB'
,
'SAST'
,
'Rosetta'
,
'CRNN'
,
'STARNet'
,
'RARE'
,
'SRN'
,
'EAST'
,
'DB'
,
'SAST'
,
'Rosetta'
,
'CRNN'
,
'STARNet'
,
'RARE'
,
'SRN'
,
'CLS'
,
'PGNet'
'CLS'
,
'PGNet'
,
'Distillation'
,
'TableAttn'
]
]
device
=
'gpu:{}'
.
format
(
dist
.
ParallelEnv
().
dev_id
)
if
use_gpu
else
'cpu'
device
=
'gpu:{}'
.
format
(
dist
.
ParallelEnv
().
dev_id
)
if
use_gpu
else
'cpu'
...
...
tools/train.py
View file @
465ef3bf
...
@@ -72,7 +72,14 @@ def main(config, device, logger, vdl_writer):
...
@@ -72,7 +72,14 @@ def main(config, device, logger, vdl_writer):
# for rec algorithm
# for rec algorithm
if
hasattr
(
post_process_class
,
'character'
):
if
hasattr
(
post_process_class
,
'character'
):
char_num
=
len
(
getattr
(
post_process_class
,
'character'
))
char_num
=
len
(
getattr
(
post_process_class
,
'character'
))
if
config
[
'Architecture'
][
"algorithm"
]
in
[
"Distillation"
,
]:
# distillation model
for
key
in
config
[
'Architecture'
][
"Models"
]:
config
[
'Architecture'
][
"Models"
][
key
][
"Head"
][
'out_channels'
]
=
char_num
else
:
# base rec model
config
[
'Architecture'
][
"Head"
][
'out_channels'
]
=
char_num
config
[
'Architecture'
][
"Head"
][
'out_channels'
]
=
char_num
model
=
build_model
(
config
[
'Architecture'
])
model
=
build_model
(
config
[
'Architecture'
])
if
config
[
'Global'
][
'distributed'
]:
if
config
[
'Global'
][
'distributed'
]:
model
=
paddle
.
DataParallel
(
model
)
model
=
paddle
.
DataParallel
(
model
)
...
@@ -90,7 +97,7 @@ def main(config, device, logger, vdl_writer):
...
@@ -90,7 +97,7 @@ def main(config, device, logger, vdl_writer):
# build metric
# build metric
eval_class
=
build_metric
(
config
[
'Metric'
])
eval_class
=
build_metric
(
config
[
'Metric'
])
# load pretrain model
# load pretrain model
pre_best_model_dict
=
init_model
(
config
,
model
,
logger
,
optimizer
)
pre_best_model_dict
=
init_model
(
config
,
model
,
optimizer
)
logger
.
info
(
'train dataloader has {} iters'
.
format
(
len
(
train_dataloader
)))
logger
.
info
(
'train dataloader has {} iters'
.
format
(
len
(
train_dataloader
)))
if
valid_dataloader
is
not
None
:
if
valid_dataloader
is
not
None
:
...
...
Prev
1
2
3
4
5
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment