Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
wangsen
paddle_dbnet
Commits
adc62fcd
"vscode:/vscode.git/clone" did not exist on "97e13de2b40a95ea8510766e8073a9ca6ba6d945"
Unverified
Commit
adc62fcd
authored
Aug 17, 2021
by
topduke
Committed by
GitHub
Aug 17, 2021
Browse files
Merge branch 'dygraph' into dygraph
parents
8227ad1b
a81b88a0
Changes
152
Hide whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
376 additions
and
646 deletions
+376
-646
tools/eval.py
tools/eval.py
+15
-4
tools/export_model.py
tools/export_model.py
+2
-1
tools/infer/benchmark_utils.py
tools/infer/benchmark_utils.py
+0
-232
tools/infer/predict_cls.py
tools/infer/predict_cls.py
+1
-14
tools/infer/predict_det.py
tools/infer/predict_det.py
+40
-59
tools/infer/predict_rec.py
tools/infer/predict_rec.py
+37
-59
tools/infer/predict_system.py
tools/infer/predict_system.py
+34
-104
tools/infer/utility.py
tools/infer/utility.py
+105
-157
tools/infer_det.py
tools/infer_det.py
+1
-1
tools/infer_table.py
tools/infer_table.py
+107
-0
tools/program.py
tools/program.py
+32
-12
tools/train.py
tools/train.py
+2
-3
No files found.
tools/eval.py
View file @
adc62fcd
...
...
@@ -27,7 +27,7 @@ from ppocr.data import build_dataloader
from
ppocr.modeling.architectures
import
build_model
from
ppocr.postprocess
import
build_post_process
from
ppocr.metrics
import
build_metric
from
ppocr.utils.save_load
import
init_model
from
ppocr.utils.save_load
import
init_model
,
load_pretrained_params
from
ppocr.utils.utility
import
print_dict
import
tools.program
as
program
...
...
@@ -44,10 +44,21 @@ def main():
# build model
# for rec algorithm
if
hasattr
(
post_process_class
,
'character'
):
config
[
'Architecture'
][
"Head"
][
'out_channels'
]
=
len
(
getattr
(
post_process_class
,
'character'
))
char_num
=
len
(
getattr
(
post_process_class
,
'character'
))
if
config
[
'Architecture'
][
"algorithm"
]
in
[
"Distillation"
,
]:
# distillation model
for
key
in
config
[
'Architecture'
][
"Models"
]:
config
[
'Architecture'
][
"Models"
][
key
][
"Head"
][
'out_channels'
]
=
char_num
else
:
# base rec model
config
[
'Architecture'
][
"Head"
][
'out_channels'
]
=
char_num
model
=
build_model
(
config
[
'Architecture'
])
use_srn
=
config
[
'Architecture'
][
'algorithm'
]
==
"SRN"
if
"model_type"
in
config
[
'Architecture'
].
keys
():
model_type
=
config
[
'Architecture'
][
'model_type'
]
else
:
model_type
=
None
best_model_dict
=
init_model
(
config
,
model
)
if
len
(
best_model_dict
):
...
...
@@ -60,7 +71,7 @@ def main():
# start eval
metric
=
program
.
eval
(
model
,
valid_dataloader
,
post_process_class
,
eval_class
,
use_srn
)
eval_class
,
model_type
,
use_srn
)
logger
.
info
(
'metric eval ***************'
)
for
k
,
v
in
metric
.
items
():
logger
.
info
(
'{}:{}'
.
format
(
k
,
v
))
...
...
tools/export_model.py
View file @
adc62fcd
...
...
@@ -60,7 +60,8 @@ def export_single_model(model, arch_config, save_path, logger):
"When there is tps in the network, variable length input is not supported, and the input size needs to be the same as during training"
)
infer_shape
[
-
1
]
=
100
elif
arch_config
[
"model_type"
]
==
"table"
:
infer_shape
=
[
3
,
488
,
488
]
model
=
to_static
(
model
,
input_spec
=
[
...
...
tools/infer/benchmark_utils.py
deleted
100644 → 0
View file @
8227ad1b
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
argparse
import
os
import
time
import
logging
import
paddle
import
paddle.inference
as
paddle_infer
from
pathlib
import
Path
CUR_DIR
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
class
PaddleInferBenchmark
(
object
):
def
__init__
(
self
,
config
,
model_info
:
dict
=
{},
data_info
:
dict
=
{},
perf_info
:
dict
=
{},
resource_info
:
dict
=
{},
save_log_path
:
str
=
""
,
**
kwargs
):
"""
Construct PaddleInferBenchmark Class to format logs.
args:
config(paddle.inference.Config): paddle inference config
model_info(dict): basic model info
{'model_name': 'resnet50'
'precision': 'fp32'}
data_info(dict): input data info
{'batch_size': 1
'shape': '3,224,224'
'data_num': 1000}
perf_info(dict): performance result
{'preprocess_time_s': 1.0
'inference_time_s': 2.0
'postprocess_time_s': 1.0
'total_time_s': 4.0}
resource_info(dict):
cpu and gpu resources
{'cpu_rss': 100
'gpu_rss': 100
'gpu_util': 60}
"""
# PaddleInferBenchmark Log Version
self
.
log_version
=
1.0
# Paddle Version
self
.
paddle_version
=
paddle
.
__version__
self
.
paddle_commit
=
paddle
.
__git_commit__
paddle_infer_info
=
paddle_infer
.
get_version
()
self
.
paddle_branch
=
paddle_infer_info
.
strip
().
split
(
': '
)[
-
1
]
# model info
self
.
model_info
=
model_info
# data info
self
.
data_info
=
data_info
# perf info
self
.
perf_info
=
perf_info
try
:
self
.
model_name
=
model_info
[
'model_name'
]
self
.
precision
=
model_info
[
'precision'
]
self
.
batch_size
=
data_info
[
'batch_size'
]
self
.
shape
=
data_info
[
'shape'
]
self
.
data_num
=
data_info
[
'data_num'
]
self
.
preprocess_time_s
=
round
(
perf_info
[
'preprocess_time_s'
],
4
)
self
.
inference_time_s
=
round
(
perf_info
[
'inference_time_s'
],
4
)
self
.
postprocess_time_s
=
round
(
perf_info
[
'postprocess_time_s'
],
4
)
self
.
total_time_s
=
round
(
perf_info
[
'total_time_s'
],
4
)
except
:
self
.
print_help
()
raise
ValueError
(
"Set argument wrong, please check input argument and its type"
)
# conf info
self
.
config_status
=
self
.
parse_config
(
config
)
self
.
save_log_path
=
save_log_path
# mem info
if
isinstance
(
resource_info
,
dict
):
self
.
cpu_rss_mb
=
int
(
resource_info
.
get
(
'cpu_rss_mb'
,
0
))
self
.
gpu_rss_mb
=
int
(
resource_info
.
get
(
'gpu_rss_mb'
,
0
))
self
.
gpu_util
=
round
(
resource_info
.
get
(
'gpu_util'
,
0
),
2
)
else
:
self
.
cpu_rss_mb
=
0
self
.
gpu_rss_mb
=
0
self
.
gpu_util
=
0
# init benchmark logger
self
.
benchmark_logger
()
def
benchmark_logger
(
self
):
"""
benchmark logger
"""
# Init logger
FORMAT
=
'%(asctime)s - %(name)s - %(levelname)s - %(message)s'
log_output
=
f
"
{
self
.
save_log_path
}
/
{
self
.
model_name
}
.log"
Path
(
f
"
{
self
.
save_log_path
}
"
).
mkdir
(
parents
=
True
,
exist_ok
=
True
)
logging
.
basicConfig
(
level
=
logging
.
INFO
,
format
=
FORMAT
,
handlers
=
[
logging
.
FileHandler
(
filename
=
log_output
,
mode
=
'w'
),
logging
.
StreamHandler
(),
])
self
.
logger
=
logging
.
getLogger
(
__name__
)
self
.
logger
.
info
(
f
"Paddle Inference benchmark log will be saved to
{
log_output
}
"
)
def
parse_config
(
self
,
config
)
->
dict
:
"""
parse paddle predictor config
args:
config(paddle.inference.Config): paddle inference config
return:
config_status(dict): dict style config info
"""
config_status
=
{}
config_status
[
'runtime_device'
]
=
"gpu"
if
config
.
use_gpu
()
else
"cpu"
config_status
[
'ir_optim'
]
=
config
.
ir_optim
()
config_status
[
'enable_tensorrt'
]
=
config
.
tensorrt_engine_enabled
()
config_status
[
'precision'
]
=
self
.
precision
config_status
[
'enable_mkldnn'
]
=
config
.
mkldnn_enabled
()
config_status
[
'cpu_math_library_num_threads'
]
=
config
.
cpu_math_library_num_threads
(
)
return
config_status
def
report
(
self
,
identifier
=
None
):
"""
print log report
args:
identifier(string): identify log
"""
if
identifier
:
identifier
=
f
"[
{
identifier
}
]"
else
:
identifier
=
""
self
.
logger
.
info
(
"
\n
"
)
self
.
logger
.
info
(
"---------------------- Paddle info ----------------------"
)
self
.
logger
.
info
(
f
"
{
identifier
}
paddle_version:
{
self
.
paddle_version
}
"
)
self
.
logger
.
info
(
f
"
{
identifier
}
paddle_commit:
{
self
.
paddle_commit
}
"
)
self
.
logger
.
info
(
f
"
{
identifier
}
paddle_branch:
{
self
.
paddle_branch
}
"
)
self
.
logger
.
info
(
f
"
{
identifier
}
log_api_version:
{
self
.
log_version
}
"
)
self
.
logger
.
info
(
"----------------------- Conf info -----------------------"
)
self
.
logger
.
info
(
f
"
{
identifier
}
runtime_device:
{
self
.
config_status
[
'runtime_device'
]
}
"
)
self
.
logger
.
info
(
f
"
{
identifier
}
ir_optim:
{
self
.
config_status
[
'ir_optim'
]
}
"
)
self
.
logger
.
info
(
f
"
{
identifier
}
enable_memory_optim:
{
True
}
"
)
self
.
logger
.
info
(
f
"
{
identifier
}
enable_tensorrt:
{
self
.
config_status
[
'enable_tensorrt'
]
}
"
)
self
.
logger
.
info
(
f
"
{
identifier
}
enable_mkldnn:
{
self
.
config_status
[
'enable_mkldnn'
]
}
"
)
self
.
logger
.
info
(
f
"
{
identifier
}
cpu_math_library_num_threads:
{
self
.
config_status
[
'cpu_math_library_num_threads'
]
}
"
)
self
.
logger
.
info
(
"----------------------- Model info ----------------------"
)
self
.
logger
.
info
(
f
"
{
identifier
}
model_name:
{
self
.
model_name
}
"
)
self
.
logger
.
info
(
f
"
{
identifier
}
precision:
{
self
.
precision
}
"
)
self
.
logger
.
info
(
"----------------------- Data info -----------------------"
)
self
.
logger
.
info
(
f
"
{
identifier
}
batch_size:
{
self
.
batch_size
}
"
)
self
.
logger
.
info
(
f
"
{
identifier
}
input_shape:
{
self
.
shape
}
"
)
self
.
logger
.
info
(
f
"
{
identifier
}
data_num:
{
self
.
data_num
}
"
)
self
.
logger
.
info
(
"----------------------- Perf info -----------------------"
)
self
.
logger
.
info
(
f
"
{
identifier
}
cpu_rss(MB):
{
self
.
cpu_rss_mb
}
, gpu_rss(MB):
{
self
.
gpu_rss_mb
}
, gpu_util:
{
self
.
gpu_util
}
%"
)
self
.
logger
.
info
(
f
"
{
identifier
}
total time spent(s):
{
self
.
total_time_s
}
"
)
self
.
logger
.
info
(
f
"
{
identifier
}
preprocess_time(ms):
{
round
(
self
.
preprocess_time_s
*
1000
,
1
)
}
, inference_time(ms):
{
round
(
self
.
inference_time_s
*
1000
,
1
)
}
, postprocess_time(ms):
{
round
(
self
.
postprocess_time_s
*
1000
,
1
)
}
"
)
def
print_help
(
self
):
"""
print function help
"""
print
(
"""Usage:
==== Print inference benchmark logs. ====
config = paddle.inference.Config()
model_info = {'model_name': 'resnet50'
'precision': 'fp32'}
data_info = {'batch_size': 1
'shape': '3,224,224'
'data_num': 1000}
perf_info = {'preprocess_time_s': 1.0
'inference_time_s': 2.0
'postprocess_time_s': 1.0
'total_time_s': 4.0}
resource_info = {'cpu_rss_mb': 100
'gpu_rss_mb': 100
'gpu_util': 60}
log = PaddleInferBenchmark(config, model_info, data_info, perf_info, resource_info)
log('Test')
"""
)
def
__call__
(
self
,
identifier
=
None
):
"""
__call__
args:
identifier(string): identify log
"""
self
.
report
(
identifier
)
tools/infer/predict_cls.py
View file @
adc62fcd
...
...
@@ -48,8 +48,6 @@ class TextClassifier(object):
self
.
predictor
,
self
.
input_tensor
,
self
.
output_tensors
,
_
=
\
utility
.
create_predictor
(
args
,
'cls'
,
logger
)
self
.
cls_times
=
utility
.
Timer
()
def
resize_norm_img
(
self
,
img
):
imgC
,
imgH
,
imgW
=
self
.
cls_image_shape
h
=
img
.
shape
[
0
]
...
...
@@ -85,35 +83,28 @@ class TextClassifier(object):
cls_res
=
[[
''
,
0.0
]]
*
img_num
batch_num
=
self
.
cls_batch_num
elapse
=
0
self
.
cls_times
.
total_time
.
start
()
for
beg_img_no
in
range
(
0
,
img_num
,
batch_num
):
end_img_no
=
min
(
img_num
,
beg_img_no
+
batch_num
)
norm_img_batch
=
[]
max_wh_ratio
=
0
starttime
=
time
.
time
()
for
ino
in
range
(
beg_img_no
,
end_img_no
):
h
,
w
=
img_list
[
indices
[
ino
]].
shape
[
0
:
2
]
wh_ratio
=
w
*
1.0
/
h
max_wh_ratio
=
max
(
max_wh_ratio
,
wh_ratio
)
self
.
cls_times
.
preprocess_time
.
start
()
for
ino
in
range
(
beg_img_no
,
end_img_no
):
norm_img
=
self
.
resize_norm_img
(
img_list
[
indices
[
ino
]])
norm_img
=
norm_img
[
np
.
newaxis
,
:]
norm_img_batch
.
append
(
norm_img
)
norm_img_batch
=
np
.
concatenate
(
norm_img_batch
)
norm_img_batch
=
norm_img_batch
.
copy
()
starttime
=
time
.
time
()
self
.
cls_times
.
preprocess_time
.
end
()
self
.
cls_times
.
inference_time
.
start
()
self
.
input_tensor
.
copy_from_cpu
(
norm_img_batch
)
self
.
predictor
.
run
()
prob_out
=
self
.
output_tensors
[
0
].
copy_to_cpu
()
self
.
cls_times
.
inference_time
.
end
()
self
.
cls_times
.
postprocess_time
.
start
()
self
.
predictor
.
try_shrink_memory
()
cls_result
=
self
.
postprocess_op
(
prob_out
)
self
.
cls_times
.
postprocess_time
.
end
()
elapse
+=
time
.
time
()
-
starttime
for
rno
in
range
(
len
(
cls_result
)):
label
,
score
=
cls_result
[
rno
]
...
...
@@ -121,9 +112,6 @@ class TextClassifier(object):
if
'180'
in
label
and
score
>
self
.
cls_thresh
:
img_list
[
indices
[
beg_img_no
+
rno
]]
=
cv2
.
rotate
(
img_list
[
indices
[
beg_img_no
+
rno
]],
1
)
self
.
cls_times
.
total_time
.
end
()
self
.
cls_times
.
img_num
+=
img_num
elapse
=
self
.
cls_times
.
total_time
.
value
()
return
img_list
,
cls_res
,
elapse
...
...
@@ -157,7 +145,6 @@ def main(args):
cls_res
[
ino
]))
logger
.
info
(
"The predict time about text angle classify module is as follows: "
)
text_classifier
.
cls_times
.
info
(
average
=
False
)
if
__name__
==
"__main__"
:
...
...
tools/infer/predict_det.py
View file @
adc62fcd
...
...
@@ -31,8 +31,6 @@ from ppocr.utils.utility import get_image_file_list, check_and_read_gif
from
ppocr.data
import
create_operators
,
transform
from
ppocr.postprocess
import
build_post_process
import
tools.infer.benchmark_utils
as
benchmark_utils
logger
=
get_logger
()
...
...
@@ -43,7 +41,7 @@ class TextDetector(object):
pre_process_list
=
[{
'DetResizeForTest'
:
{
'limit_side_len'
:
args
.
det_limit_side_len
,
'limit_type'
:
args
.
det_limit_type
'limit_type'
:
args
.
det_limit_type
,
}
},
{
'NormalizeImage'
:
{
...
...
@@ -100,7 +98,24 @@ class TextDetector(object):
self
.
predictor
,
self
.
input_tensor
,
self
.
output_tensors
,
self
.
config
=
utility
.
create_predictor
(
args
,
'det'
,
logger
)
self
.
det_times
=
utility
.
Timer
()
if
args
.
benchmark
:
import
auto_log
pid
=
os
.
getpid
()
self
.
autolog
=
auto_log
.
AutoLogger
(
model_name
=
"det"
,
model_precision
=
args
.
precision
,
batch_size
=
1
,
data_shape
=
"dynamic"
,
save_path
=
None
,
inference_config
=
self
.
config
,
pids
=
pid
,
process_name
=
None
,
gpu_ids
=
0
,
time_keys
=
[
'preprocess_time'
,
'inference_time'
,
'postprocess_time'
],
warmup
=
2
,
logger
=
logger
)
def
order_points_clockwise
(
self
,
pts
):
"""
...
...
@@ -158,8 +173,12 @@ class TextDetector(object):
def
__call__
(
self
,
img
):
ori_im
=
img
.
copy
()
data
=
{
'image'
:
img
}
self
.
det_times
.
total_time
.
start
()
self
.
det_times
.
preprocess_time
.
start
()
st
=
time
.
time
()
if
self
.
args
.
benchmark
:
self
.
autolog
.
times
.
start
()
data
=
transform
(
data
,
self
.
preprocess_op
)
img
,
shape_list
=
data
if
img
is
None
:
...
...
@@ -168,8 +187,8 @@ class TextDetector(object):
shape_list
=
np
.
expand_dims
(
shape_list
,
axis
=
0
)
img
=
img
.
copy
()
self
.
det_times
.
preprocess_time
.
end
()
self
.
det_times
.
inference_time
.
sta
rt
()
if
self
.
args
.
benchmark
:
self
.
autolog
.
times
.
sta
mp
()
self
.
input_tensor
.
copy_from_cpu
(
img
)
self
.
predictor
.
run
()
...
...
@@ -177,7 +196,8 @@ class TextDetector(object):
for
output_tensor
in
self
.
output_tensors
:
output
=
output_tensor
.
copy_to_cpu
()
outputs
.
append
(
output
)
self
.
det_times
.
inference_time
.
end
()
if
self
.
args
.
benchmark
:
self
.
autolog
.
times
.
stamp
()
preds
=
{}
if
self
.
det_algorithm
==
"EAST"
:
...
...
@@ -193,9 +213,7 @@ class TextDetector(object):
else
:
raise
NotImplementedError
self
.
det_times
.
postprocess_time
.
start
()
self
.
predictor
.
try_shrink_memory
()
#self.predictor.try_shrink_memory()
post_result
=
self
.
postprocess_op
(
preds
,
shape_list
)
dt_boxes
=
post_result
[
0
][
'points'
]
if
self
.
det_algorithm
==
"SAST"
and
self
.
det_sast_polygon
:
...
...
@@ -203,10 +221,10 @@ class TextDetector(object):
else
:
dt_boxes
=
self
.
filter_tag_det_res
(
dt_boxes
,
ori_im
.
shape
)
self
.
det_times
.
postprocess_time
.
end
()
self
.
det_times
.
total_
time
.
end
()
self
.
det_times
.
img_num
+=
1
return
dt_boxes
,
self
.
det_times
.
total_time
.
value
()
if
self
.
args
.
benchmark
:
self
.
autolog
.
time
s
.
end
(
stamp
=
True
)
et
=
time
.
time
()
return
dt_boxes
,
et
-
st
if
__name__
==
"__main__"
:
...
...
@@ -216,12 +234,11 @@ if __name__ == "__main__":
count
=
0
total_time
=
0
draw_img_save
=
"./inference_results"
cpu_mem
,
gpu_mem
,
gpu_util
=
0
,
0
,
0
# warmup 10 times
fake_
img
=
np
.
random
.
uniform
(
-
1
,
1
,
[
640
,
640
,
3
]).
astype
(
np
.
float32
)
for
i
in
range
(
10
):
dt_boxes
,
_
=
text_detector
(
fake_
img
)
if
args
.
warmup
:
img
=
np
.
random
.
uniform
(
0
,
255
,
[
640
,
640
,
3
]).
astype
(
np
.
uint8
)
for
i
in
range
(
2
):
res
=
text_detector
(
img
)
if
not
os
.
path
.
exists
(
draw_img_save
):
os
.
makedirs
(
draw_img_save
)
...
...
@@ -239,49 +256,13 @@ if __name__ == "__main__":
total_time
+=
elapse
count
+=
1
if
args
.
benchmark
:
cm
,
gm
,
gu
=
utility
.
get_current_memory_mb
(
0
)
cpu_mem
+=
cm
gpu_mem
+=
gm
gpu_util
+=
gu
logger
.
info
(
"Predict time of {}: {}"
.
format
(
image_file
,
elapse
))
src_im
=
utility
.
draw_text_det_res
(
dt_boxes
,
image_file
)
img_name_pure
=
os
.
path
.
split
(
image_file
)[
-
1
]
img_path
=
os
.
path
.
join
(
draw_img_save
,
"det_res_{}"
.
format
(
img_name_pure
))
cv2
.
imwrite
(
img_path
,
src_im
)
logger
.
info
(
"The visualized image saved in {}"
.
format
(
img_path
))
# print the information about memory and time-spent
if
args
.
benchmark
:
mems
=
{
'cpu_rss_mb'
:
cpu_mem
/
count
,
'gpu_rss_mb'
:
gpu_mem
/
count
,
'gpu_util'
:
gpu_util
*
100
/
count
}
else
:
mems
=
None
logger
.
info
(
"The predict time about detection module is as follows: "
)
det_time_dict
=
text_detector
.
det_times
.
report
(
average
=
True
)
det_model_name
=
args
.
det_model_dir
if
args
.
benchmark
:
# construct log information
model_info
=
{
'model_name'
:
args
.
det_model_dir
.
split
(
'/'
)[
-
1
],
'precision'
:
args
.
precision
}
data_info
=
{
'batch_size'
:
1
,
'shape'
:
'dynamic_shape'
,
'data_num'
:
det_time_dict
[
'img_num'
]
}
perf_info
=
{
'preprocess_time_s'
:
det_time_dict
[
'preprocess_time'
],
'inference_time_s'
:
det_time_dict
[
'inference_time'
],
'postprocess_time_s'
:
det_time_dict
[
'postprocess_time'
],
'total_time_s'
:
det_time_dict
[
'total_time'
]
}
benchmark_log
=
benchmark_utils
.
PaddleInferBenchmark
(
text_detector
.
config
,
model_info
,
data_info
,
perf_info
,
mems
)
benchmark_log
(
"Det"
)
text_detector
.
autolog
.
report
()
tools/infer/predict_rec.py
View file @
adc62fcd
...
...
@@ -28,7 +28,6 @@ import traceback
import
paddle
import
tools.infer.utility
as
utility
import
tools.infer.benchmark_utils
as
benchmark_utils
from
ppocr.postprocess
import
build_post_process
from
ppocr.utils.logging
import
get_logger
from
ppocr.utils.utility
import
get_image_file_list
,
check_and_read_gif
...
...
@@ -65,8 +64,25 @@ class TextRecognizer(object):
self
.
postprocess_op
=
build_post_process
(
postprocess_params
)
self
.
predictor
,
self
.
input_tensor
,
self
.
output_tensors
,
self
.
config
=
\
utility
.
create_predictor
(
args
,
'rec'
,
logger
)
self
.
rec_times
=
utility
.
Timer
()
self
.
benchmark
=
args
.
benchmark
if
args
.
benchmark
:
import
auto_log
pid
=
os
.
getpid
()
self
.
autolog
=
auto_log
.
AutoLogger
(
model_name
=
"rec"
,
model_precision
=
args
.
precision
,
batch_size
=
args
.
rec_batch_num
,
data_shape
=
"dynamic"
,
save_path
=
None
,
#args.save_log_path,
inference_config
=
self
.
config
,
pids
=
pid
,
process_name
=
None
,
gpu_ids
=
0
if
args
.
use_gpu
else
None
,
time_keys
=
[
'preprocess_time'
,
'inference_time'
,
'postprocess_time'
],
warmup
=
2
,
logger
=
logger
)
def
resize_norm_img
(
self
,
img
,
max_wh_ratio
):
imgC
,
imgH
,
imgW
=
self
.
rec_image_shape
...
...
@@ -168,14 +184,15 @@ class TextRecognizer(object):
width_list
.
append
(
img
.
shape
[
1
]
/
float
(
img
.
shape
[
0
]))
# Sorting can speed up the recognition process
indices
=
np
.
argsort
(
np
.
array
(
width_list
))
self
.
rec_times
.
total_time
.
start
()
rec_res
=
[[
''
,
0.0
]]
*
img_num
batch_num
=
self
.
rec_batch_num
st
=
time
.
time
()
if
self
.
benchmark
:
self
.
autolog
.
times
.
start
()
for
beg_img_no
in
range
(
0
,
img_num
,
batch_num
):
end_img_no
=
min
(
img_num
,
beg_img_no
+
batch_num
)
norm_img_batch
=
[]
max_wh_ratio
=
0
self
.
rec_times
.
preprocess_time
.
start
()
for
ino
in
range
(
beg_img_no
,
end_img_no
):
h
,
w
=
img_list
[
indices
[
ino
]].
shape
[
0
:
2
]
wh_ratio
=
w
*
1.0
/
h
...
...
@@ -200,6 +217,8 @@ class TextRecognizer(object):
norm_img_batch
.
append
(
norm_img
[
0
])
norm_img_batch
=
np
.
concatenate
(
norm_img_batch
)
norm_img_batch
=
norm_img_batch
.
copy
()
if
self
.
benchmark
:
self
.
autolog
.
times
.
stamp
()
if
self
.
rec_algorithm
==
"SRN"
:
encoder_word_pos_list
=
np
.
concatenate
(
encoder_word_pos_list
)
...
...
@@ -216,23 +235,20 @@ class TextRecognizer(object):
gsrm_slf_attn_bias1_list
,
gsrm_slf_attn_bias2_list
,
]
self
.
rec_times
.
preprocess_time
.
end
()
self
.
rec_times
.
inference_time
.
start
()
input_names
=
self
.
predictor
.
get_input_names
()
for
i
in
range
(
len
(
input_names
)):
input_tensor
=
self
.
predictor
.
get_input_handle
(
input_names
[
i
])
input_tensor
.
copy_from_cpu
(
inputs
[
i
])
self
.
predictor
.
run
()
self
.
rec_times
.
inference_time
.
end
()
outputs
=
[]
for
output_tensor
in
self
.
output_tensors
:
output
=
output_tensor
.
copy_to_cpu
()
outputs
.
append
(
output
)
if
self
.
benchmark
:
self
.
autolog
.
times
.
stamp
()
preds
=
{
"predict"
:
outputs
[
2
]}
else
:
self
.
rec_times
.
preprocess_time
.
end
()
self
.
rec_times
.
inference_time
.
start
()
self
.
input_tensor
.
copy_from_cpu
(
norm_img_batch
)
self
.
predictor
.
run
()
...
...
@@ -240,16 +256,15 @@ class TextRecognizer(object):
for
output_tensor
in
self
.
output_tensors
:
output
=
output_tensor
.
copy_to_cpu
()
outputs
.
append
(
output
)
if
self
.
benchmark
:
self
.
autolog
.
times
.
stamp
()
preds
=
outputs
[
0
]
self
.
rec_times
.
inference_time
.
end
()
self
.
rec_times
.
postprocess_time
.
start
()
rec_result
=
self
.
postprocess_op
(
preds
)
for
rno
in
range
(
len
(
rec_result
)):
rec_res
[
indices
[
beg_img_no
+
rno
]]
=
rec_result
[
rno
]
self
.
rec_times
.
postprocess_time
.
end
()
self
.
rec_times
.
img_num
+=
int
(
norm_img_batch
.
shape
[
0
])
self
.
rec_times
.
total_time
.
end
()
return
rec_res
,
self
.
rec_times
.
total_time
.
value
()
if
self
.
benchmark
:
self
.
autolog
.
times
.
end
(
stamp
=
True
)
return
rec_res
,
time
.
time
()
-
st
def
main
(
args
):
...
...
@@ -257,13 +272,12 @@ def main(args):
text_recognizer
=
TextRecognizer
(
args
)
valid_image_file_list
=
[]
img_list
=
[]
cpu_mem
,
gpu_mem
,
gpu_util
=
0
,
0
,
0
count
=
0
# warmup 10 times
fake_img
=
np
.
random
.
uniform
(
-
1
,
1
,
[
1
,
32
,
320
,
3
]).
astype
(
np
.
float32
)
for
i
in
range
(
10
):
dt_boxes
,
_
=
text_recognizer
(
fake_img
)
# warmup 2 times
if
args
.
warmup
:
img
=
np
.
random
.
uniform
(
0
,
255
,
[
32
,
320
,
3
]).
astype
(
np
.
uint8
)
for
i
in
range
(
2
):
res
=
text_recognizer
([
img
])
for
image_file
in
image_file_list
:
img
,
flag
=
check_and_read_gif
(
image_file
)
...
...
@@ -276,12 +290,6 @@ def main(args):
img_list
.
append
(
img
)
try
:
rec_res
,
_
=
text_recognizer
(
img_list
)
if
args
.
benchmark
:
cm
,
gm
,
gu
=
utility
.
get_current_memory_mb
(
0
)
cpu_mem
+=
cm
gpu_mem
+=
gm
gpu_util
+=
gu
count
+=
1
except
Exception
as
E
:
logger
.
info
(
traceback
.
format_exc
())
...
...
@@ -291,37 +299,7 @@ def main(args):
logger
.
info
(
"Predicts of {}:{}"
.
format
(
valid_image_file_list
[
ino
],
rec_res
[
ino
]))
if
args
.
benchmark
:
mems
=
{
'cpu_rss_mb'
:
cpu_mem
/
count
,
'gpu_rss_mb'
:
gpu_mem
/
count
,
'gpu_util'
:
gpu_util
*
100
/
count
}
else
:
mems
=
None
logger
.
info
(
"The predict time about recognizer module is as follows: "
)
rec_time_dict
=
text_recognizer
.
rec_times
.
report
(
average
=
True
)
rec_model_name
=
args
.
rec_model_dir
if
args
.
benchmark
:
# construct log information
model_info
=
{
'model_name'
:
args
.
rec_model_dir
.
split
(
'/'
)[
-
1
],
'precision'
:
args
.
precision
}
data_info
=
{
'batch_size'
:
args
.
rec_batch_num
,
'shape'
:
'dynamic_shape'
,
'data_num'
:
rec_time_dict
[
'img_num'
]
}
perf_info
=
{
'preprocess_time_s'
:
rec_time_dict
[
'preprocess_time'
],
'inference_time_s'
:
rec_time_dict
[
'inference_time'
],
'postprocess_time_s'
:
rec_time_dict
[
'postprocess_time'
],
'total_time_s'
:
rec_time_dict
[
'total_time'
]
}
benchmark_log
=
benchmark_utils
.
PaddleInferBenchmark
(
text_recognizer
.
config
,
model_info
,
data_info
,
perf_info
,
mems
)
benchmark_log
(
"Rec"
)
text_recognizer
.
autolog
.
report
()
if
__name__
==
"__main__"
:
...
...
tools/infer/predict_system.py
View file @
adc62fcd
...
...
@@ -13,6 +13,7 @@
# limitations under the License.
import
os
import
sys
import
subprocess
__dir__
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
sys
.
path
.
append
(
__dir__
)
...
...
@@ -24,6 +25,7 @@ import cv2
import
copy
import
numpy
as
np
import
time
import
logging
from
PIL
import
Image
import
tools.infer.utility
as
utility
import
tools.infer.predict_rec
as
predict_rec
...
...
@@ -31,13 +33,15 @@ import tools.infer.predict_det as predict_det
import
tools.infer.predict_cls
as
predict_cls
from
ppocr.utils.utility
import
get_image_file_list
,
check_and_read_gif
from
ppocr.utils.logging
import
get_logger
from
tools.infer.utility
import
draw_ocr_box_txt
,
get_current_memory_mb
import
tools.infer.benchmark_utils
as
benchmark_utils
from
tools.infer.utility
import
draw_ocr_box_txt
,
get_rotate_crop_image
logger
=
get_logger
()
class
TextSystem
(
object
):
def
__init__
(
self
,
args
):
if
not
args
.
show_log
:
logger
.
setLevel
(
logging
.
INFO
)
self
.
text_detector
=
predict_det
.
TextDetector
(
args
)
self
.
text_recognizer
=
predict_rec
.
TextRecognizer
(
args
)
self
.
use_angle_cls
=
args
.
use_angle_cls
...
...
@@ -45,39 +49,6 @@ class TextSystem(object):
if
self
.
use_angle_cls
:
self
.
text_classifier
=
predict_cls
.
TextClassifier
(
args
)
def
get_rotate_crop_image
(
self
,
img
,
points
):
'''
img_height, img_width = img.shape[0:2]
left = int(np.min(points[:, 0]))
right = int(np.max(points[:, 0]))
top = int(np.min(points[:, 1]))
bottom = int(np.max(points[:, 1]))
img_crop = img[top:bottom, left:right, :].copy()
points[:, 0] = points[:, 0] - left
points[:, 1] = points[:, 1] - top
'''
img_crop_width
=
int
(
max
(
np
.
linalg
.
norm
(
points
[
0
]
-
points
[
1
]),
np
.
linalg
.
norm
(
points
[
2
]
-
points
[
3
])))
img_crop_height
=
int
(
max
(
np
.
linalg
.
norm
(
points
[
0
]
-
points
[
3
]),
np
.
linalg
.
norm
(
points
[
1
]
-
points
[
2
])))
pts_std
=
np
.
float32
([[
0
,
0
],
[
img_crop_width
,
0
],
[
img_crop_width
,
img_crop_height
],
[
0
,
img_crop_height
]])
M
=
cv2
.
getPerspectiveTransform
(
points
,
pts_std
)
dst_img
=
cv2
.
warpPerspective
(
img
,
M
,
(
img_crop_width
,
img_crop_height
),
borderMode
=
cv2
.
BORDER_REPLICATE
,
flags
=
cv2
.
INTER_CUBIC
)
dst_img_height
,
dst_img_width
=
dst_img
.
shape
[
0
:
2
]
if
dst_img_height
*
1.0
/
dst_img_width
>=
1.5
:
dst_img
=
np
.
rot90
(
dst_img
)
return
dst_img
def
print_draw_crop_rec_res
(
self
,
img_crop_list
,
rec_res
):
bbox_num
=
len
(
img_crop_list
)
for
bno
in
range
(
bbox_num
):
...
...
@@ -88,8 +59,7 @@ class TextSystem(object):
ori_im
=
img
.
copy
()
dt_boxes
,
elapse
=
self
.
text_detector
(
img
)
logger
.
info
(
"dt_boxes num : {}, elapse : {}"
.
format
(
logger
.
debug
(
"dt_boxes num : {}, elapse : {}"
.
format
(
len
(
dt_boxes
),
elapse
))
if
dt_boxes
is
None
:
return
None
,
None
...
...
@@ -99,16 +69,16 @@ class TextSystem(object):
for
bno
in
range
(
len
(
dt_boxes
)):
tmp_box
=
copy
.
deepcopy
(
dt_boxes
[
bno
])
img_crop
=
self
.
get_rotate_crop_image
(
ori_im
,
tmp_box
)
img_crop
=
get_rotate_crop_image
(
ori_im
,
tmp_box
)
img_crop_list
.
append
(
img_crop
)
if
self
.
use_angle_cls
and
cls
:
img_crop_list
,
angle_list
,
elapse
=
self
.
text_classifier
(
img_crop_list
)
logger
.
info
(
"cls num : {}, elapse : {}"
.
format
(
logger
.
debug
(
"cls num : {}, elapse : {}"
.
format
(
len
(
img_crop_list
),
elapse
))
rec_res
,
elapse
=
self
.
text_recognizer
(
img_crop_list
)
logger
.
info
(
"rec_res num : {}, elapse : {}"
.
format
(
logger
.
debug
(
"rec_res num : {}, elapse : {}"
.
format
(
len
(
rec_res
),
elapse
))
# self.print_draw_crop_rec_res(img_crop_list, rec_res)
filter_boxes
,
filter_rec_res
=
[],
[]
...
...
@@ -143,15 +113,24 @@ def sorted_boxes(dt_boxes):
def
main
(
args
):
image_file_list
=
get_image_file_list
(
args
.
image_dir
)
image_file_list
=
image_file_list
[
args
.
process_id
::
args
.
total_process_num
]
text_sys
=
TextSystem
(
args
)
is_visualize
=
True
font_path
=
args
.
vis_font_path
drop_score
=
args
.
drop_score
# warm up 10 times
if
args
.
warmup
:
img
=
np
.
random
.
uniform
(
0
,
255
,
[
640
,
640
,
3
]).
astype
(
np
.
uint8
)
for
i
in
range
(
10
):
res
=
text_sys
(
img
)
total_time
=
0
cpu_mem
,
gpu_mem
,
gpu_util
=
0
,
0
,
0
_st
=
time
.
time
()
count
=
0
for
idx
,
image_file
in
enumerate
(
image_file_list
):
img
,
flag
=
check_and_read_gif
(
image_file
)
if
not
flag
:
img
=
cv2
.
imread
(
image_file
)
...
...
@@ -162,12 +141,6 @@ def main(args):
dt_boxes
,
rec_res
=
text_sys
(
img
)
elapse
=
time
.
time
()
-
starttime
total_time
+=
elapse
if
args
.
benchmark
and
idx
%
20
==
0
:
cm
,
gm
,
gu
=
get_current_memory_mb
(
0
)
cpu_mem
+=
cm
gpu_mem
+=
gm
gpu_util
+=
gu
count
+=
1
logger
.
info
(
str
(
idx
)
+
" Predict time of %s: %.3fs"
%
(
image_file
,
elapse
))
...
...
@@ -201,63 +174,20 @@ def main(args):
logger
.
info
(
"The predict total time is {}"
.
format
(
time
.
time
()
-
_st
))
logger
.
info
(
"
\n
The predict total time is {}"
.
format
(
total_time
))
img_num
=
text_sys
.
text_detector
.
det_times
.
img_num
if
args
.
benchmark
:
mems
=
{
'cpu_rss_mb'
:
cpu_mem
/
count
,
'gpu_rss_mb'
:
gpu_mem
/
count
,
'gpu_util'
:
gpu_util
*
100
/
count
}
else
:
mems
=
None
det_time_dict
=
text_sys
.
text_detector
.
det_times
.
report
(
average
=
True
)
rec_time_dict
=
text_sys
.
text_recognizer
.
rec_times
.
report
(
average
=
True
)
det_model_name
=
args
.
det_model_dir
rec_model_name
=
args
.
rec_model_dir
# construct det log information
model_info
=
{
'model_name'
:
args
.
det_model_dir
.
split
(
'/'
)[
-
1
],
'precision'
:
args
.
precision
}
data_info
=
{
'batch_size'
:
1
,
'shape'
:
'dynamic_shape'
,
'data_num'
:
det_time_dict
[
'img_num'
]
}
perf_info
=
{
'preprocess_time_s'
:
det_time_dict
[
'preprocess_time'
],
'inference_time_s'
:
det_time_dict
[
'inference_time'
],
'postprocess_time_s'
:
det_time_dict
[
'postprocess_time'
],
'total_time_s'
:
det_time_dict
[
'total_time'
]
}
benchmark_log
=
benchmark_utils
.
PaddleInferBenchmark
(
text_sys
.
text_detector
.
config
,
model_info
,
data_info
,
perf_info
,
mems
,
args
.
save_log_path
)
benchmark_log
(
"Det"
)
# construct rec log information
model_info
=
{
'model_name'
:
args
.
rec_model_dir
.
split
(
'/'
)[
-
1
],
'precision'
:
args
.
precision
}
data_info
=
{
'batch_size'
:
args
.
rec_batch_num
,
'shape'
:
'dynamic_shape'
,
'data_num'
:
rec_time_dict
[
'img_num'
]
}
perf_info
=
{
'preprocess_time_s'
:
rec_time_dict
[
'preprocess_time'
],
'inference_time_s'
:
rec_time_dict
[
'inference_time'
],
'postprocess_time_s'
:
rec_time_dict
[
'postprocess_time'
],
'total_time_s'
:
rec_time_dict
[
'total_time'
]
}
benchmark_log
=
benchmark_utils
.
PaddleInferBenchmark
(
text_sys
.
text_recognizer
.
config
,
model_info
,
data_info
,
perf_info
,
mems
,
args
.
save_log_path
)
benchmark_log
(
"Rec"
)
if
__name__
==
"__main__"
:
main
(
utility
.
parse_args
())
args
=
utility
.
parse_args
()
if
args
.
use_mp
:
p_list
=
[]
total_process_num
=
args
.
total_process_num
for
process_id
in
range
(
total_process_num
):
cmd
=
[
sys
.
executable
,
"-u"
]
+
sys
.
argv
+
[
"--process_id={}"
.
format
(
process_id
),
"--use_mp={}"
.
format
(
False
)
]
p
=
subprocess
.
Popen
(
cmd
,
stdout
=
sys
.
stdout
,
stderr
=
sys
.
stdout
)
p_list
.
append
(
p
)
for
p
in
p_list
:
p
.
wait
()
else
:
main
(
args
)
tools/infer/utility.py
View file @
adc62fcd
...
...
@@ -24,8 +24,6 @@ from paddle import inference
import
time
from
ppocr.utils.logging
import
get_logger
logger
=
get_logger
()
def
str2bool
(
v
):
return
v
.
lower
()
in
(
"true"
,
"t"
,
"1"
)
...
...
@@ -37,6 +35,7 @@ def init_args():
parser
.
add_argument
(
"--use_gpu"
,
type
=
str2bool
,
default
=
True
)
parser
.
add_argument
(
"--ir_optim"
,
type
=
str2bool
,
default
=
True
)
parser
.
add_argument
(
"--use_tensorrt"
,
type
=
str2bool
,
default
=
False
)
parser
.
add_argument
(
"--min_subgraph_size"
,
type
=
int
,
default
=
10
)
parser
.
add_argument
(
"--precision"
,
type
=
str
,
default
=
"fp32"
)
parser
.
add_argument
(
"--gpu_mem"
,
type
=
int
,
default
=
500
)
...
...
@@ -49,10 +48,10 @@ def init_args():
# DB parmas
parser
.
add_argument
(
"--det_db_thresh"
,
type
=
float
,
default
=
0.3
)
parser
.
add_argument
(
"--det_db_box_thresh"
,
type
=
float
,
default
=
0.
5
)
parser
.
add_argument
(
"--det_db_unclip_ratio"
,
type
=
float
,
default
=
1.
6
)
parser
.
add_argument
(
"--det_db_box_thresh"
,
type
=
float
,
default
=
0.
6
)
parser
.
add_argument
(
"--det_db_unclip_ratio"
,
type
=
float
,
default
=
1.
5
)
parser
.
add_argument
(
"--max_batch_size"
,
type
=
int
,
default
=
10
)
parser
.
add_argument
(
"--use_dilation"
,
type
=
bool
,
default
=
False
)
parser
.
add_argument
(
"--use_dilation"
,
type
=
str2
bool
,
default
=
False
)
parser
.
add_argument
(
"--det_db_score_mode"
,
type
=
str
,
default
=
"fast"
)
# EAST parmas
parser
.
add_argument
(
"--det_east_score_thresh"
,
type
=
float
,
default
=
0.8
)
...
...
@@ -62,7 +61,7 @@ def init_args():
# SAST parmas
parser
.
add_argument
(
"--det_sast_score_thresh"
,
type
=
float
,
default
=
0.5
)
parser
.
add_argument
(
"--det_sast_nms_thresh"
,
type
=
float
,
default
=
0.2
)
parser
.
add_argument
(
"--det_sast_polygon"
,
type
=
bool
,
default
=
False
)
parser
.
add_argument
(
"--det_sast_polygon"
,
type
=
str2
bool
,
default
=
False
)
# params for text recognizer
parser
.
add_argument
(
"--rec_algorithm"
,
type
=
str
,
default
=
'CRNN'
)
...
...
@@ -91,7 +90,7 @@ def init_args():
parser
.
add_argument
(
"--e2e_char_dict_path"
,
type
=
str
,
default
=
"./ppocr/utils/ic15_dict.txt"
)
parser
.
add_argument
(
"--e2e_pgnet_valid_set"
,
type
=
str
,
default
=
'totaltext'
)
parser
.
add_argument
(
"--e2e_pgnet_polygon"
,
type
=
bool
,
default
=
True
)
parser
.
add_argument
(
"--e2e_pgnet_polygon"
,
type
=
str2
bool
,
default
=
True
)
parser
.
add_argument
(
"--e2e_pgnet_mode"
,
type
=
str
,
default
=
'fast'
)
# params for text classifier
...
...
@@ -105,15 +104,17 @@ def init_args():
parser
.
add_argument
(
"--enable_mkldnn"
,
type
=
str2bool
,
default
=
False
)
parser
.
add_argument
(
"--cpu_threads"
,
type
=
int
,
default
=
10
)
parser
.
add_argument
(
"--use_pdserving"
,
type
=
str2bool
,
default
=
False
)
parser
.
add_argument
(
"--warmup"
,
type
=
str2bool
,
default
=
True
)
# multi-process
parser
.
add_argument
(
"--use_mp"
,
type
=
str2bool
,
default
=
False
)
parser
.
add_argument
(
"--total_process_num"
,
type
=
int
,
default
=
1
)
parser
.
add_argument
(
"--process_id"
,
type
=
int
,
default
=
0
)
parser
.
add_argument
(
"--benchmark"
,
type
=
bool
,
default
=
False
)
parser
.
add_argument
(
"--save_log_path"
,
type
=
str
,
default
=
"./log_output/"
)
parser
.
add_argument
(
"--benchmark"
,
type
=
str2bool
,
default
=
False
)
parser
.
add_argument
(
"--save_log_path"
,
type
=
str
,
default
=
"./log_output/"
)
parser
.
add_argument
(
"--show_log"
,
type
=
str2bool
,
default
=
True
)
return
parser
...
...
@@ -122,76 +123,6 @@ def parse_args():
return
parser
.
parse_args
()
class
Times
(
object
):
def
__init__
(
self
):
self
.
time
=
0.
self
.
st
=
0.
self
.
et
=
0.
def
start
(
self
):
self
.
st
=
time
.
time
()
def
end
(
self
,
accumulative
=
True
):
self
.
et
=
time
.
time
()
if
accumulative
:
self
.
time
+=
self
.
et
-
self
.
st
else
:
self
.
time
=
self
.
et
-
self
.
st
def
reset
(
self
):
self
.
time
=
0.
self
.
st
=
0.
self
.
et
=
0.
def
value
(
self
):
return
round
(
self
.
time
,
4
)
class
Timer
(
Times
):
def
__init__
(
self
):
super
(
Timer
,
self
).
__init__
()
self
.
total_time
=
Times
()
self
.
preprocess_time
=
Times
()
self
.
inference_time
=
Times
()
self
.
postprocess_time
=
Times
()
self
.
img_num
=
0
def
info
(
self
,
average
=
False
):
logger
.
info
(
"----------------------- Perf info -----------------------"
)
logger
.
info
(
"total_time: {}, img_num: {}"
.
format
(
self
.
total_time
.
value
(
),
self
.
img_num
))
preprocess_time
=
round
(
self
.
preprocess_time
.
value
()
/
self
.
img_num
,
4
)
if
average
else
self
.
preprocess_time
.
value
()
postprocess_time
=
round
(
self
.
postprocess_time
.
value
()
/
self
.
img_num
,
4
)
if
average
else
self
.
postprocess_time
.
value
()
inference_time
=
round
(
self
.
inference_time
.
value
()
/
self
.
img_num
,
4
)
if
average
else
self
.
inference_time
.
value
()
average_latency
=
self
.
total_time
.
value
()
/
self
.
img_num
logger
.
info
(
"average_latency(ms): {:.2f}, QPS: {:2f}"
.
format
(
average_latency
*
1000
,
1
/
average_latency
))
logger
.
info
(
"preprocess_latency(ms): {:.2f}, inference_latency(ms): {:.2f}, postprocess_latency(ms): {:.2f}"
.
format
(
preprocess_time
*
1000
,
inference_time
*
1000
,
postprocess_time
*
1000
))
def
report
(
self
,
average
=
False
):
dic
=
{}
dic
[
'preprocess_time'
]
=
round
(
self
.
preprocess_time
.
value
()
/
self
.
img_num
,
4
)
if
average
else
self
.
preprocess_time
.
value
()
dic
[
'postprocess_time'
]
=
round
(
self
.
postprocess_time
.
value
()
/
self
.
img_num
,
4
)
if
average
else
self
.
postprocess_time
.
value
()
dic
[
'inference_time'
]
=
round
(
self
.
inference_time
.
value
()
/
self
.
img_num
,
4
)
if
average
else
self
.
inference_time
.
value
()
dic
[
'img_num'
]
=
self
.
img_num
dic
[
'total_time'
]
=
round
(
self
.
total_time
.
value
(),
4
)
return
dic
def
create_predictor
(
args
,
mode
,
logger
):
if
mode
==
"det"
:
model_dir
=
args
.
det_model_dir
...
...
@@ -199,6 +130,8 @@ def create_predictor(args, mode, logger):
model_dir
=
args
.
cls_model_dir
elif
mode
==
'rec'
:
model_dir
=
args
.
rec_model_dir
elif
mode
==
'table'
:
model_dir
=
args
.
table_model_dir
else
:
model_dir
=
args
.
e2e_model_dir
...
...
@@ -208,11 +141,10 @@ def create_predictor(args, mode, logger):
model_file_path
=
model_dir
+
"/inference.pdmodel"
params_file_path
=
model_dir
+
"/inference.pdiparams"
if
not
os
.
path
.
exists
(
model_file_path
):
logger
.
info
(
"not find model file path {}"
.
format
(
model_file_path
))
sys
.
exit
(
0
)
raise
ValueError
(
"not find model file path {}"
.
format
(
model_file_path
))
if
not
os
.
path
.
exists
(
params_file_path
):
logger
.
info
(
"not find params file path {}"
.
format
(
params_file_path
))
sys
.
exit
(
0
)
raise
ValueError
(
"not find params file path {}"
.
format
(
params_file_path
)
)
config
=
inference
.
Config
(
model_file_path
,
params_file_path
)
...
...
@@ -230,71 +162,74 @@ def create_predictor(args, mode, logger):
config
.
enable_use_gpu
(
args
.
gpu_mem
,
0
)
if
args
.
use_tensorrt
:
config
.
enable_tensorrt_engine
(
precision_mode
=
inference
.
PrecisionType
.
Float32
,
precision_mode
=
precision
,
max_batch_size
=
args
.
max_batch_size
,
min_subgraph_size
=
3
)
# skip the minmum trt subgraph
if
mode
==
"det"
and
"mobile"
in
model_file_path
:
min_subgraph_size
=
args
.
min_subgraph_size
)
# skip the minmum trt subgraph
if
mode
==
"det"
:
min_input_shape
=
{
"x"
:
[
1
,
3
,
50
,
50
],
"conv2d_92.tmp_0"
:
[
1
,
96
,
20
,
20
],
"conv2d_91.tmp_0"
:
[
1
,
96
,
10
,
10
],
"nearest_interp_v2_1.tmp_0"
:
[
1
,
96
,
10
,
10
],
"nearest_interp_v2_2.tmp_0"
:
[
1
,
96
,
20
,
20
],
"nearest_interp_v2_3.tmp_0"
:
[
1
,
24
,
20
,
20
],
"nearest_interp_v2_4.tmp_0"
:
[
1
,
24
,
20
,
20
],
"nearest_interp_v2_5.tmp_0"
:
[
1
,
24
,
20
,
20
],
"conv2d_92.tmp_0"
:
[
1
,
120
,
20
,
20
],
"conv2d_91.tmp_0"
:
[
1
,
24
,
10
,
10
],
"conv2d_59.tmp_0"
:
[
1
,
96
,
20
,
20
],
"nearest_interp_v2_1.tmp_0"
:
[
1
,
256
,
10
,
10
],
"nearest_interp_v2_2.tmp_0"
:
[
1
,
256
,
20
,
20
],
"conv2d_124.tmp_0"
:
[
1
,
256
,
20
,
20
],
"nearest_interp_v2_3.tmp_0"
:
[
1
,
64
,
20
,
20
],
"nearest_interp_v2_4.tmp_0"
:
[
1
,
64
,
20
,
20
],
"nearest_interp_v2_5.tmp_0"
:
[
1
,
64
,
20
,
20
],
"elementwise_add_7"
:
[
1
,
56
,
2
,
2
],
"nearest_interp_v2_0.tmp_0"
:
[
1
,
9
6
,
2
,
2
]
"nearest_interp_v2_0.tmp_0"
:
[
1
,
25
6
,
2
,
2
]
}
max_input_shape
=
{
"x"
:
[
1
,
3
,
2000
,
2000
],
"conv2d_92.tmp_0"
:
[
1
,
96
,
400
,
400
],
"conv2d_91.tmp_0"
:
[
1
,
96
,
200
,
200
],
"nearest_interp_v2_1.tmp_0"
:
[
1
,
96
,
200
,
200
],
"nearest_interp_v2_2.tmp_0"
:
[
1
,
96
,
400
,
400
],
"nearest_interp_v2_3.tmp_0"
:
[
1
,
24
,
400
,
400
],
"nearest_interp_v2_4.tmp_0"
:
[
1
,
24
,
400
,
400
],
"nearest_interp_v2_5.tmp_0"
:
[
1
,
24
,
400
,
400
],
"conv2d_92.tmp_0"
:
[
1
,
120
,
400
,
400
],
"conv2d_91.tmp_0"
:
[
1
,
24
,
200
,
200
],
"conv2d_59.tmp_0"
:
[
1
,
96
,
400
,
400
],
"nearest_interp_v2_1.tmp_0"
:
[
1
,
256
,
200
,
200
],
"conv2d_124.tmp_0"
:
[
1
,
256
,
400
,
400
],
"nearest_interp_v2_2.tmp_0"
:
[
1
,
256
,
400
,
400
],
"nearest_interp_v2_3.tmp_0"
:
[
1
,
64
,
400
,
400
],
"nearest_interp_v2_4.tmp_0"
:
[
1
,
64
,
400
,
400
],
"nearest_interp_v2_5.tmp_0"
:
[
1
,
64
,
400
,
400
],
"elementwise_add_7"
:
[
1
,
56
,
400
,
400
],
"nearest_interp_v2_0.tmp_0"
:
[
1
,
9
6
,
400
,
400
]
"nearest_interp_v2_0.tmp_0"
:
[
1
,
25
6
,
400
,
400
]
}
opt_input_shape
=
{
"x"
:
[
1
,
3
,
640
,
640
],
"conv2d_92.tmp_0"
:
[
1
,
96
,
160
,
160
],
"conv2d_91.tmp_0"
:
[
1
,
96
,
80
,
80
],
"nearest_interp_v2_1.tmp_0"
:
[
1
,
96
,
80
,
80
],
"nearest_interp_v2_2.tmp_0"
:
[
1
,
96
,
160
,
160
],
"nearest_interp_v2_3.tmp_0"
:
[
1
,
24
,
160
,
160
],
"nearest_interp_v2_4.tmp_0"
:
[
1
,
24
,
160
,
160
],
"nearest_interp_v2_5.tmp_0"
:
[
1
,
24
,
160
,
160
],
"conv2d_92.tmp_0"
:
[
1
,
120
,
160
,
160
],
"conv2d_91.tmp_0"
:
[
1
,
24
,
80
,
80
],
"conv2d_59.tmp_0"
:
[
1
,
96
,
160
,
160
],
"nearest_interp_v2_1.tmp_0"
:
[
1
,
256
,
80
,
80
],
"nearest_interp_v2_2.tmp_0"
:
[
1
,
256
,
160
,
160
],
"conv2d_124.tmp_0"
:
[
1
,
256
,
160
,
160
],
"nearest_interp_v2_3.tmp_0"
:
[
1
,
64
,
160
,
160
],
"nearest_interp_v2_4.tmp_0"
:
[
1
,
64
,
160
,
160
],
"nearest_interp_v2_5.tmp_0"
:
[
1
,
64
,
160
,
160
],
"elementwise_add_7"
:
[
1
,
56
,
40
,
40
],
"nearest_interp_v2_0.tmp_0"
:
[
1
,
9
6
,
40
,
40
]
"nearest_interp_v2_0.tmp_0"
:
[
1
,
25
6
,
40
,
40
]
}
if
mode
==
"det"
and
"server"
in
model_file_path
:
min_input_shape
=
{
"x"
:
[
1
,
3
,
50
,
50
],
"conv2d_59.tmp_0"
:
[
1
,
96
,
20
,
20
],
"nearest_interp_v2_2.tmp_0"
:
[
1
,
96
,
20
,
20
],
"nearest_interp_v2_3.tmp_0"
:
[
1
,
24
,
20
,
20
],
"nearest_interp_v2_4.tmp_0"
:
[
1
,
24
,
20
,
20
],
"nearest_interp_v2_5.tmp_0"
:
[
1
,
24
,
20
,
20
]
min_pact_shape
=
{
"nearest_interp_v2_26.tmp_0"
:
[
1
,
256
,
20
,
20
],
"nearest_interp_v2_27.tmp_0"
:
[
1
,
64
,
20
,
20
],
"nearest_interp_v2_28.tmp_0"
:
[
1
,
64
,
20
,
20
],
"nearest_interp_v2_29.tmp_0"
:
[
1
,
64
,
20
,
20
]
}
max_input_shape
=
{
"x"
:
[
1
,
3
,
2000
,
2000
],
"conv2d_59.tmp_0"
:
[
1
,
96
,
400
,
400
],
"nearest_interp_v2_2.tmp_0"
:
[
1
,
96
,
400
,
400
],
"nearest_interp_v2_3.tmp_0"
:
[
1
,
24
,
400
,
400
],
"nearest_interp_v2_4.tmp_0"
:
[
1
,
24
,
400
,
400
],
"nearest_interp_v2_5.tmp_0"
:
[
1
,
24
,
400
,
400
]
max_pact_shape
=
{
"nearest_interp_v2_26.tmp_0"
:
[
1
,
256
,
400
,
400
],
"nearest_interp_v2_27.tmp_0"
:
[
1
,
64
,
400
,
400
],
"nearest_interp_v2_28.tmp_0"
:
[
1
,
64
,
400
,
400
],
"nearest_interp_v2_29.tmp_0"
:
[
1
,
64
,
400
,
400
]
}
opt_input_shape
=
{
"x"
:
[
1
,
3
,
640
,
640
],
"conv2d_59.tmp_0"
:
[
1
,
96
,
160
,
160
],
"nearest_interp_v2_2.tmp_0"
:
[
1
,
96
,
160
,
160
],
"nearest_interp_v2_3.tmp_0"
:
[
1
,
24
,
160
,
160
],
"nearest_interp_v2_4.tmp_0"
:
[
1
,
24
,
160
,
160
],
"nearest_interp_v2_5.tmp_0"
:
[
1
,
24
,
160
,
160
]
opt_pact_shape
=
{
"nearest_interp_v2_26.tmp_0"
:
[
1
,
256
,
160
,
160
],
"nearest_interp_v2_27.tmp_0"
:
[
1
,
64
,
160
,
160
],
"nearest_interp_v2_28.tmp_0"
:
[
1
,
64
,
160
,
160
],
"nearest_interp_v2_29.tmp_0"
:
[
1
,
64
,
160
,
160
]
}
min_input_shape
.
update
(
min_pact_shape
)
max_input_shape
.
update
(
max_pact_shape
)
opt_input_shape
.
update
(
opt_pact_shape
)
elif
mode
==
"rec"
:
min_input_shape
=
{
"x"
:
[
args
.
rec_batch_num
,
3
,
32
,
10
]}
max_input_shape
=
{
"x"
:
[
args
.
rec_batch_num
,
3
,
32
,
2000
]}
...
...
@@ -324,10 +259,13 @@ def create_predictor(args, mode, logger):
# enable memory optim
config
.
enable_memory_optim
()
config
.
disable_glog_info
()
#
config.disable_glog_info()
config
.
delete_pass
(
"conv_transpose_eltwiseadd_bn_fuse_pass"
)
if
mode
==
'table'
:
config
.
delete_pass
(
"fc_fuse_pass"
)
# not supported for table
config
.
switch_use_feed_fetch_ops
(
False
)
config
.
switch_ir_optim
(
True
)
# create predictor
predictor
=
inference
.
create_predictor
(
config
)
...
...
@@ -590,29 +528,39 @@ def draw_boxes(image, boxes, scores=None, drop_score=0.5):
return
image
def
get_current_memory_mb
(
gpu_id
=
None
):
"""
It is used to Obtain the memory usage of the CPU and GPU during the running of the program.
And this function Current program is time-consuming.
"""
import
pynvml
import
psutil
import
GPUtil
pid
=
os
.
getpid
()
p
=
psutil
.
Process
(
pid
)
info
=
p
.
memory_full_info
()
cpu_mem
=
info
.
uss
/
1024.
/
1024.
gpu_mem
=
0
gpu_percent
=
0
if
gpu_id
is
not
None
:
GPUs
=
GPUtil
.
getGPUs
()
gpu_load
=
GPUs
[
gpu_id
].
load
gpu_percent
=
gpu_load
pynvml
.
nvmlInit
()
handle
=
pynvml
.
nvmlDeviceGetHandleByIndex
(
0
)
meminfo
=
pynvml
.
nvmlDeviceGetMemoryInfo
(
handle
)
gpu_mem
=
meminfo
.
used
/
1024.
/
1024.
return
round
(
cpu_mem
,
4
),
round
(
gpu_mem
,
4
),
round
(
gpu_percent
,
4
)
def
get_rotate_crop_image
(
img
,
points
):
'''
img_height, img_width = img.shape[0:2]
left = int(np.min(points[:, 0]))
right = int(np.max(points[:, 0]))
top = int(np.min(points[:, 1]))
bottom = int(np.max(points[:, 1]))
img_crop = img[top:bottom, left:right, :].copy()
points[:, 0] = points[:, 0] - left
points[:, 1] = points[:, 1] - top
'''
assert
len
(
points
)
==
4
,
"shape of points must be 4*2"
img_crop_width
=
int
(
max
(
np
.
linalg
.
norm
(
points
[
0
]
-
points
[
1
]),
np
.
linalg
.
norm
(
points
[
2
]
-
points
[
3
])))
img_crop_height
=
int
(
max
(
np
.
linalg
.
norm
(
points
[
0
]
-
points
[
3
]),
np
.
linalg
.
norm
(
points
[
1
]
-
points
[
2
])))
pts_std
=
np
.
float32
([[
0
,
0
],
[
img_crop_width
,
0
],
[
img_crop_width
,
img_crop_height
],
[
0
,
img_crop_height
]])
M
=
cv2
.
getPerspectiveTransform
(
points
,
pts_std
)
dst_img
=
cv2
.
warpPerspective
(
img
,
M
,
(
img_crop_width
,
img_crop_height
),
borderMode
=
cv2
.
BORDER_REPLICATE
,
flags
=
cv2
.
INTER_CUBIC
)
dst_img_height
,
dst_img_width
=
dst_img
.
shape
[
0
:
2
]
if
dst_img_height
*
1.0
/
dst_img_width
>=
1.5
:
dst_img
=
np
.
rot90
(
dst_img
)
return
dst_img
if
__name__
==
'__main__'
:
...
...
tools/infer_det.py
View file @
adc62fcd
...
...
@@ -112,4 +112,4 @@ def main():
if
__name__
==
'__main__'
:
config
,
device
,
logger
,
vdl_writer
=
program
.
preprocess
()
main
()
\ No newline at end of file
main
()
tools/infer_table.py
0 → 100644
View file @
adc62fcd
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
numpy
as
np
import
os
import
sys
import
json
__dir__
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
sys
.
path
.
append
(
__dir__
)
sys
.
path
.
append
(
os
.
path
.
abspath
(
os
.
path
.
join
(
__dir__
,
'..'
)))
os
.
environ
[
"FLAGS_allocator_strategy"
]
=
'auto_growth'
import
paddle
from
paddle.jit
import
to_static
from
ppocr.data
import
create_operators
,
transform
from
ppocr.modeling.architectures
import
build_model
from
ppocr.postprocess
import
build_post_process
from
ppocr.utils.save_load
import
init_model
from
ppocr.utils.utility
import
get_image_file_list
import
tools.program
as
program
import
cv2
def
main
(
config
,
device
,
logger
,
vdl_writer
):
global_config
=
config
[
'Global'
]
# build post process
post_process_class
=
build_post_process
(
config
[
'PostProcess'
],
global_config
)
# build model
if
hasattr
(
post_process_class
,
'character'
):
config
[
'Architecture'
][
"Head"
][
'out_channels'
]
=
len
(
getattr
(
post_process_class
,
'character'
))
model
=
build_model
(
config
[
'Architecture'
])
init_model
(
config
,
model
,
logger
)
# create data ops
transforms
=
[]
use_padding
=
False
for
op
in
config
[
'Eval'
][
'dataset'
][
'transforms'
]:
op_name
=
list
(
op
)[
0
]
if
'Label'
in
op_name
:
continue
if
op_name
==
'KeepKeys'
:
op
[
op_name
][
'keep_keys'
]
=
[
'image'
]
if
op_name
==
"ResizeTableImage"
:
use_padding
=
True
padding_max_len
=
op
[
'ResizeTableImage'
][
'max_len'
]
transforms
.
append
(
op
)
global_config
[
'infer_mode'
]
=
True
ops
=
create_operators
(
transforms
,
global_config
)
model
.
eval
()
for
file
in
get_image_file_list
(
config
[
'Global'
][
'infer_img'
]):
logger
.
info
(
"infer_img: {}"
.
format
(
file
))
with
open
(
file
,
'rb'
)
as
f
:
img
=
f
.
read
()
data
=
{
'image'
:
img
}
batch
=
transform
(
data
,
ops
)
images
=
np
.
expand_dims
(
batch
[
0
],
axis
=
0
)
images
=
paddle
.
to_tensor
(
images
)
preds
=
model
(
images
)
post_result
=
post_process_class
(
preds
)
res_html_code
=
post_result
[
'res_html_code'
]
res_loc
=
post_result
[
'res_loc'
]
img
=
cv2
.
imread
(
file
)
imgh
,
imgw
=
img
.
shape
[
0
:
2
]
res_loc_final
=
[]
for
rno
in
range
(
len
(
res_loc
[
0
])):
x0
,
y0
,
x1
,
y1
=
res_loc
[
0
][
rno
]
left
=
max
(
int
(
imgw
*
x0
),
0
)
top
=
max
(
int
(
imgh
*
y0
),
0
)
right
=
min
(
int
(
imgw
*
x1
),
imgw
-
1
)
bottom
=
min
(
int
(
imgh
*
y1
),
imgh
-
1
)
cv2
.
rectangle
(
img
,
(
left
,
top
),
(
right
,
bottom
),
(
0
,
0
,
255
),
2
)
res_loc_final
.
append
([
left
,
top
,
right
,
bottom
])
res_loc_str
=
json
.
dumps
(
res_loc_final
)
logger
.
info
(
"result: {}, {}"
.
format
(
res_html_code
,
res_loc_final
))
logger
.
info
(
"success!"
)
if
__name__
==
'__main__'
:
config
,
device
,
logger
,
vdl_writer
=
program
.
preprocess
()
main
(
config
,
device
,
logger
,
vdl_writer
)
tools/program.py
View file @
adc62fcd
# Copyright (c) 202
0
PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 202
1
PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -186,7 +186,14 @@ def train(config,
model
.
train
()
use_srn
=
config
[
'Architecture'
][
'algorithm'
]
==
"SRN"
use_nrtr
=
config
[
'Architecture'
][
'algorithm'
]
==
"NRTR"
try
:
model_type
=
config
[
'Architecture'
][
'model_type'
]
except
:
model_type
=
None
if
'start_epoch'
in
best_model_dict
:
start_epoch
=
best_model_dict
[
'start_epoch'
]
else
:
...
...
@@ -208,9 +215,9 @@ def train(config,
lr
=
optimizer
.
get_lr
()
images
=
batch
[
0
]
if
use_srn
:
others
=
batch
[
-
4
:]
preds
=
model
(
images
,
others
)
model_average
=
True
if
use_srn
or
model_type
==
'table'
:
preds
=
model
(
images
,
data
=
batch
[
1
:])
elif
use_nrtr
:
max_len
=
batch
[
2
].
max
()
preds
=
model
(
images
,
batch
[
1
][:,:
2
+
max_len
])
...
...
@@ -235,8 +242,11 @@ def train(config,
if
cal_metric_during_train
:
# only rec and cls need
batch
=
[
item
.
numpy
()
for
item
in
batch
]
post_result
=
post_process_class
(
preds
,
batch
[
1
])
eval_class
(
post_result
,
batch
)
if
model_type
==
'table'
:
eval_class
(
preds
,
batch
)
else
:
post_result
=
post_process_class
(
preds
,
batch
[
1
])
eval_class
(
post_result
,
batch
)
metric
=
eval_class
.
get_metric
()
train_stats
.
update
(
metric
)
...
...
@@ -272,6 +282,7 @@ def train(config,
valid_dataloader
,
post_process_class
,
eval_class
,
model_type
,
use_srn
=
use_srn
)
cur_metric_str
=
'cur metric, {}'
.
format
(
', '
.
join
(
[
'{}: {}'
.
format
(
k
,
v
)
for
k
,
v
in
cur_metric
.
items
()]))
...
...
@@ -339,7 +350,11 @@ def train(config,
return
def
eval
(
model
,
valid_dataloader
,
post_process_class
,
eval_class
,
def
eval
(
model
,
valid_dataloader
,
post_process_class
,
eval_class
,
model_type
,
use_srn
=
False
):
model
.
eval
()
with
paddle
.
no_grad
():
...
...
@@ -353,17 +368,20 @@ def eval(model, valid_dataloader, post_process_class, eval_class,
break
images
=
batch
[
0
]
start
=
time
.
time
()
if
use_srn
:
others
=
batch
[
-
4
:]
preds
=
model
(
images
,
others
)
if
use_srn
or
model_type
==
'table'
:
preds
=
model
(
images
,
data
=
batch
[
1
:]
)
else
:
preds
=
model
(
images
)
batch
=
[
item
.
numpy
()
for
item
in
batch
]
# Obtain usable results from post-processing methods
post_result
=
post_process_class
(
preds
,
batch
[
1
])
total_time
+=
time
.
time
()
-
start
# Evaluate the results of the current batch
eval_class
(
post_result
,
batch
)
if
model_type
==
'table'
:
eval_class
(
preds
,
batch
)
else
:
post_result
=
post_process_class
(
preds
,
batch
[
1
])
eval_class
(
post_result
,
batch
)
pbar
.
update
(
1
)
total_frame
+=
len
(
images
)
# Get final metric,eg. acc or hmean
...
...
@@ -387,7 +405,9 @@ def preprocess(is_train=False):
alg
=
config
[
'Architecture'
][
'algorithm'
]
assert
alg
in
[
'EAST'
,
'DB'
,
'SAST'
,
'Rosetta'
,
'CRNN'
,
'STARNet'
,
'RARE'
,
'SRN'
,
'CLS'
,
'PGNet'
,
'Distillation'
,
'NRTR'
'CLS'
,
'PGNet'
,
'Distillation'
,
'NRTR'
,
'TableAttn'
]
device
=
'gpu:{}'
.
format
(
dist
.
ParallelEnv
().
dev_id
)
if
use_gpu
else
'cpu'
...
...
tools/train.py
View file @
adc62fcd
...
...
@@ -35,7 +35,7 @@ from ppocr.losses import build_loss
from
ppocr.optimizer
import
build_optimizer
from
ppocr.postprocess
import
build_post_process
from
ppocr.metrics
import
build_metric
from
ppocr.utils.save_load
import
init_model
from
ppocr.utils.save_load
import
init_model
,
load_dygraph_params
import
tools.program
as
program
dist
.
get_world_size
()
...
...
@@ -97,8 +97,7 @@ def main(config, device, logger, vdl_writer):
# build metric
eval_class
=
build_metric
(
config
[
'Metric'
])
# load pretrain model
pre_best_model_dict
=
init_model
(
config
,
model
,
optimizer
)
pre_best_model_dict
=
load_dygraph_params
(
config
,
model
,
logger
,
optimizer
)
logger
.
info
(
'train dataloader has {} iters'
.
format
(
len
(
train_dataloader
)))
if
valid_dataloader
is
not
None
:
logger
.
info
(
'valid dataloader has {} iters'
.
format
(
...
...
Prev
1
…
4
5
6
7
8
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment