ModelZoo / CRNN_Paddle · Commits

Commit 7a39c41a
Authored Jul 07, 2025 by wanglch

Update infer_rec_prof.py

Parent: 9994486a
Showing 1 changed file with 238 additions and 0 deletions: tools/infer_rec_prof.py (+238, -0)
tools/infer_rec_prof.py (view file @ 7a39c41a)
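The new file is a profiling variant of the recognition inference script: it builds the recognition model from the YAML config, runs inference over Global.infer_img, and wraps the loop in a paddle.profiler.Profiler, writing a Chrome trace to ./profiler_demo and a text summary to ./profiler_summary.txt. In the usual PaddleOCR tools/ style it would be driven by a config file plus -o overrides; a hypothetical invocation (the config and checkpoint paths below are placeholders, not part of this commit) is: python tools/infer_rec_prof.py -c configs/rec/<your_rec_config>.yml -o Global.pretrained_model=<checkpoint> Global.infer_img=<image_or_dir>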
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import os
import sys
import json

__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "..")))

os.environ["FLAGS_allocator_strategy"] = "auto_growth"

import paddle
import paddle.profiler as profiler  # import the performance profiler module

from ppocr.data import create_operators, transform
from ppocr.modeling.architectures import build_model
from ppocr.postprocess import build_post_process
from ppocr.utils.save_load import load_model
from ppocr.utils.utility import get_image_file_list
import tools.program as program


def main():
    global_config = config["Global"]

    # build post process
    post_process_class = build_post_process(config["PostProcess"], global_config)

    # build model
    if hasattr(post_process_class, "character"):
        char_num = len(getattr(post_process_class, "character"))
        if config["Architecture"]["algorithm"] in [
            "Distillation",
        ]:  # distillation model
            for key in config["Architecture"]["Models"]:
                if (
                    config["Architecture"]["Models"][key]["Head"]["name"]
                    == "MultiHead"
                ):  # multi head
                    out_channels_list = {}
                    if config["PostProcess"]["name"] == "DistillationSARLabelDecode":
                        char_num = char_num - 2
                    if config["PostProcess"]["name"] == "DistillationNRTRLabelDecode":
                        char_num = char_num - 3
                    out_channels_list["CTCLabelDecode"] = char_num
                    out_channels_list["SARLabelDecode"] = char_num + 2
                    out_channels_list["NRTRLabelDecode"] = char_num + 3
                    config["Architecture"]["Models"][key]["Head"][
                        "out_channels_list"
                    ] = out_channels_list
                else:
                    config["Architecture"]["Models"][key]["Head"][
                        "out_channels"
                    ] = char_num
        elif config["Architecture"]["Head"]["name"] == "MultiHead":  # multi head
            out_channels_list = {}
            char_num = len(getattr(post_process_class, "character"))
            if config["PostProcess"]["name"] == "SARLabelDecode":
                char_num = char_num - 2
            if config["PostProcess"]["name"] == "NRTRLabelDecode":
                char_num = char_num - 3
            out_channels_list["CTCLabelDecode"] = char_num
            out_channels_list["SARLabelDecode"] = char_num + 2
            out_channels_list["NRTRLabelDecode"] = char_num + 3
            config["Architecture"]["Head"]["out_channels_list"] = out_channels_list
        else:  # base rec model
            config["Architecture"]["Head"]["out_channels"] = char_num

    if config["Architecture"].get("algorithm") in ["LaTeXOCR"]:
        config["Architecture"]["Backbone"]["is_predict"] = True
        config["Architecture"]["Backbone"]["is_export"] = True
        config["Architecture"]["Head"]["is_export"] = True

    model = build_model(config["Architecture"])

    load_model(config, model)

    # create data ops
    transforms = []
    for op in config["Eval"]["dataset"]["transforms"]:
        op_name = list(op)[0]
        if "Label" in op_name:
            continue
        elif op_name in ["RecResizeImg"]:
            op[op_name]["infer_mode"] = True
        elif op_name == "KeepKeys":
            if config["Architecture"]["algorithm"] == "SRN":
                op[op_name]["keep_keys"] = [
                    "image",
                    "encoder_word_pos",
                    "gsrm_word_pos",
                    "gsrm_slf_attn_bias1",
                    "gsrm_slf_attn_bias2",
                ]
            elif config["Architecture"]["algorithm"] == "SAR":
                op[op_name]["keep_keys"] = ["image", "valid_ratio"]
            elif config["Architecture"]["algorithm"] == "RobustScanner":
                op[op_name]["keep_keys"] = ["image", "valid_ratio", "word_positons"]
            else:
                op[op_name]["keep_keys"] = ["image"]
        transforms.append(op)
    global_config["infer_mode"] = True
    ops = create_operators(transforms, global_config)

    save_res_path = config["Global"].get(
        "save_res_path", "./output/rec/predicts_rec.txt"
    )
    if not os.path.exists(os.path.dirname(save_res_path)):
        os.makedirs(os.path.dirname(save_res_path))

    model.eval()

    # profiler-related code
    def my_on_trace_ready(prof):
        callback = profiler.export_chrome_tracing('./profiler_demo')
        callback(prof)

        # save the Overview Summary and Operator Summary to a file
        with open('./profiler_summary.txt', 'w') as f:
            f.write("Overview Summary:\n")
            summary_overview = prof.summary(
                sorted_by=profiler.SortedKeys.GPUTotal,
                op_detail=True,
                thread_sep=True,
                time_unit='ms',
            )
            if summary_overview is not None:
                f.write(summary_overview)
            else:
                f.write("No summary available for Overview.\n")

            f.write("\n\nOperator Summary:\n")
            summary_operator = prof.summary(
                sorted_by=profiler.SortedKeys.GPUTotal,
                op_detail=True,
                thread_sep=True,
                time_unit='ms',
            )
            if summary_operator is not None:
                f.write(summary_operator)
            else:
                f.write("No summary available for Operator.\n")

    # initialize the Profiler object; set timer_only=False to collect detailed information
    p = profiler.Profiler(on_trace_ready=my_on_trace_ready, timer_only=False)
    p.start()

    infer_imgs = config["Global"]["infer_img"]
    infer_list = config["Global"].get("infer_list", None)
    with open(save_res_path, "w") as fout:
        for file in get_image_file_list(infer_imgs, infer_list=infer_list):
            logger.info("infer_img: {}".format(file))
            with open(file, "rb") as f:
                img = f.read()
                if config["Architecture"]["algorithm"] in [
                    "UniMERNet",
                    "PP-FormulaNet-S",
                    "PP-FormulaNet-L",
                ]:
                    data = {"image": img, "filename": file}
                else:
                    data = {"image": img}
            batch = transform(data, ops)
            if config["Architecture"]["algorithm"] == "SRN":
                encoder_word_pos_list = np.expand_dims(batch[1], axis=0)
                gsrm_word_pos_list = np.expand_dims(batch[2], axis=0)
                gsrm_slf_attn_bias1_list = np.expand_dims(batch[3], axis=0)
                gsrm_slf_attn_bias2_list = np.expand_dims(batch[4], axis=0)

                others = [
                    paddle.to_tensor(encoder_word_pos_list),
                    paddle.to_tensor(gsrm_word_pos_list),
                    paddle.to_tensor(gsrm_slf_attn_bias1_list),
                    paddle.to_tensor(gsrm_slf_attn_bias2_list),
                ]
            if config["Architecture"]["algorithm"] == "SAR":
                valid_ratio = np.expand_dims(batch[-1], axis=0)
                img_metas = [paddle.to_tensor(valid_ratio)]
            if config["Architecture"]["algorithm"] == "RobustScanner":
                valid_ratio = np.expand_dims(batch[1], axis=0)
                word_positons = np.expand_dims(batch[2], axis=0)
                img_metas = [
                    paddle.to_tensor(valid_ratio),
                    paddle.to_tensor(word_positons),
                ]
            if config["Architecture"]["algorithm"] == "CAN":
                image_mask = paddle.ones(
                    (np.expand_dims(batch[0], axis=0).shape), dtype="float32"
                )
                label = paddle.ones((1, 36), dtype="int64")
            images = np.expand_dims(batch[0], axis=0)
            images = paddle.to_tensor(images)
            if config["Architecture"]["algorithm"] == "SRN":
                preds = model(images, others)
            elif config["Architecture"]["algorithm"] == "SAR":
                preds = model(images, img_metas)
            elif config["Architecture"]["algorithm"] == "RobustScanner":
                preds = model(images, img_metas)
            elif config["Architecture"]["algorithm"] == "CAN":
                preds = model([images, image_mask, label])
            else:
                preds = model(images)
            post_result = post_process_class(preds)
            info = None
            if isinstance(post_result, dict):
                rec_info = dict()
                for key in post_result:
                    if len(post_result[key][0]) >= 2:
                        rec_info[key] = {
                            "label": post_result[key][0][0],
                            "score": float(post_result[key][0][1]),
                        }
                info = json.dumps(rec_info, ensure_ascii=False)
            elif isinstance(post_result, list) and isinstance(post_result[0], int):
                # for RFLearning CNT branch
                info = str(post_result[0])
            elif config["Architecture"]["algorithm"] in [
                "LaTeXOCR",
                "UniMERNet",
                "PP-FormulaNet-S",
                "PP-FormulaNet-L",
            ]:
                info = str(post_result[0])
            else:
                if len(post_result[0]) >= 2:
                    info = post_result[0][0] + "\t" + str(post_result[0][1])

            if info is not None:
                logger.info("\t result: {}".format(info))
                fout.write(file + "\t" + info + "\n")
            p.step()  # call the profiler's step method after each inference
    p.stop()  # stop the profiler
    logger.info("success!")


if __name__ == "__main__":
    config, device, logger, vdl_writer = program.preprocess()
    main()
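For reference, the sketch below isolates the paddle.profiler pattern used above, outside of the OCR pipeline. It is a minimal sketch assuming Paddle 2.3 or later; the matmul loop is a stand-in workload, not the recognition model from this commit.

import paddle
import paddle.profiler as profiler


def on_trace_ready(prof):
    # Export a Chrome trace that can be inspected in chrome://tracing.
    profiler.export_chrome_tracing("./profiler_demo")(prof)
    # Print a summary table sorted by total GPU time, with per-operator
    # detail and per-thread separation, in milliseconds.
    prof.summary(
        sorted_by=profiler.SortedKeys.GPUTotal,
        op_detail=True,
        thread_sep=True,
        time_unit="ms",
    )


p = profiler.Profiler(on_trace_ready=on_trace_ready, timer_only=False)
p.start()
x = paddle.randn([64, 64])
for _ in range(10):
    y = paddle.matmul(x, x)  # stand-in for one inference step
    p.step()                 # mark the end of each profiled step
p.stop()                     # with the default scheduler, stop() triggers on_trace_ready

In infer_rec_prof.py the same three calls (start before the loop, step once per image, stop after the loop) bracket the recognition loop, so the exported trace covers every profiled inference step.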