Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
wangsen
paddle_dbnet
Commits
b40ffdd4
Commit
b40ffdd4
authored
Sep 29, 2021
by
Topdu
Browse files
fix sar export inference model
parent
5613e21d
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
76 additions
and
1 deletion
+76
-1
tools/export_model.py
tools/export_model.py
+6
-0
tools/infer/predict_rec.py
tools/infer/predict_rec.py
+70
-1
No files found.
tools/export_model.py
View file @
b40ffdd4
...
...
@@ -49,6 +49,12 @@ def export_single_model(model, arch_config, save_path, logger):
]
]
model
=
to_static
(
model
,
input_spec
=
other_shape
)
elif
arch_config
[
"algorithm"
]
==
"SAR"
:
other_shape
=
[
paddle
.
static
.
InputSpec
(
shape
=
[
None
,
3
,
48
,
160
],
dtype
=
"float32"
),
]
model
=
to_static
(
model
,
input_spec
=
other_shape
)
else
:
infer_shape
=
[
3
,
-
1
,
-
1
]
if
arch_config
[
"model_type"
]
==
"rec"
:
...
...
tools/infer/predict_rec.py
View file @
b40ffdd4
...
...
@@ -68,6 +68,13 @@ class TextRecognizer(object):
"character_dict_path"
:
args
.
rec_char_dict_path
,
"use_space_char"
:
args
.
use_space_char
}
elif
self
.
rec_algorithm
==
"SAR"
:
postprocess_params
=
{
'name'
:
'SARLabelDecode'
,
"character_type"
:
args
.
rec_char_type
,
"character_dict_path"
:
args
.
rec_char_dict_path
,
"use_space_char"
:
args
.
use_space_char
}
self
.
postprocess_op
=
build_post_process
(
postprocess_params
)
self
.
predictor
,
self
.
input_tensor
,
self
.
output_tensors
,
self
.
config
=
\
utility
.
create_predictor
(
args
,
'rec'
,
logger
)
...
...
@@ -194,6 +201,41 @@ class TextRecognizer(object):
return
(
norm_img
,
encoder_word_pos
,
gsrm_word_pos
,
gsrm_slf_attn_bias1
,
gsrm_slf_attn_bias2
)
def
resize_norm_img_sar
(
self
,
img
,
image_shape
,
width_downsample_ratio
=
0.25
):
imgC
,
imgH
,
imgW_min
,
imgW_max
=
image_shape
h
=
img
.
shape
[
0
]
w
=
img
.
shape
[
1
]
valid_ratio
=
1.0
# make sure new_width is an integral multiple of width_divisor.
width_divisor
=
int
(
1
/
width_downsample_ratio
)
# resize
ratio
=
w
/
float
(
h
)
resize_w
=
math
.
ceil
(
imgH
*
ratio
)
if
resize_w
%
width_divisor
!=
0
:
resize_w
=
round
(
resize_w
/
width_divisor
)
*
width_divisor
if
imgW_min
is
not
None
:
resize_w
=
max
(
imgW_min
,
resize_w
)
if
imgW_max
is
not
None
:
valid_ratio
=
min
(
1.0
,
1.0
*
resize_w
/
imgW_max
)
resize_w
=
min
(
imgW_max
,
resize_w
)
resized_image
=
cv2
.
resize
(
img
,
(
resize_w
,
imgH
))
resized_image
=
resized_image
.
astype
(
'float32'
)
# norm
if
image_shape
[
0
]
==
1
:
resized_image
=
resized_image
/
255
resized_image
=
resized_image
[
np
.
newaxis
,
:]
else
:
resized_image
=
resized_image
.
transpose
((
2
,
0
,
1
))
/
255
resized_image
-=
0.5
resized_image
/=
0.5
resize_shape
=
resized_image
.
shape
padding_im
=
-
1.0
*
np
.
ones
((
imgC
,
imgH
,
imgW_max
),
dtype
=
np
.
float32
)
padding_im
[:,
:,
0
:
resize_w
]
=
resized_image
pad_shape
=
padding_im
.
shape
return
padding_im
,
resize_shape
,
pad_shape
,
valid_ratio
def
__call__
(
self
,
img_list
):
img_num
=
len
(
img_list
)
# Calculate the aspect ratio of all text bars
...
...
@@ -216,11 +258,19 @@ class TextRecognizer(object):
wh_ratio
=
w
*
1.0
/
h
max_wh_ratio
=
max
(
max_wh_ratio
,
wh_ratio
)
for
ino
in
range
(
beg_img_no
,
end_img_no
):
if
self
.
rec_algorithm
!=
"SRN"
:
if
self
.
rec_algorithm
!=
"SRN"
and
self
.
rec_algorithm
!=
"SAR"
:
norm_img
=
self
.
resize_norm_img
(
img_list
[
indices
[
ino
]],
max_wh_ratio
)
norm_img
=
norm_img
[
np
.
newaxis
,
:]
norm_img_batch
.
append
(
norm_img
)
elif
self
.
rec_algorithm
==
"SAR"
:
norm_img
,
_
,
_
,
valid_ratio
=
self
.
resize_norm_img_sar
(
img_list
[
indices
[
ino
]],
self
.
rec_image_shape
)
norm_img
=
norm_img
[
np
.
newaxis
,
:]
valid_ratio
=
np
.
expand_dims
(
valid_ratio
,
axis
=
0
)
valid_ratios
=
[]
valid_ratios
.
append
(
valid_ratio
)
norm_img_batch
.
append
(
norm_img
)
else
:
norm_img
=
self
.
process_image_srn
(
img_list
[
indices
[
ino
]],
self
.
rec_image_shape
,
8
,
25
)
...
...
@@ -266,6 +316,25 @@ class TextRecognizer(object):
if
self
.
benchmark
:
self
.
autolog
.
times
.
stamp
()
preds
=
{
"predict"
:
outputs
[
2
]}
elif
self
.
rec_algorithm
==
"SAR"
:
valid_ratios
=
np
.
concatenate
(
valid_ratios
)
inputs
=
[
norm_img_batch
,
valid_ratios
,
]
input_names
=
self
.
predictor
.
get_input_names
()
for
i
in
range
(
len
(
input_names
)):
input_tensor
=
self
.
predictor
.
get_input_handle
(
input_names
[
i
])
input_tensor
.
copy_from_cpu
(
inputs
[
i
])
self
.
predictor
.
run
()
outputs
=
[]
for
output_tensor
in
self
.
output_tensors
:
output
=
output_tensor
.
copy_to_cpu
()
outputs
.
append
(
output
)
if
self
.
benchmark
:
self
.
autolog
.
times
.
stamp
()
preds
=
outputs
[
0
]
else
:
self
.
input_tensor
.
copy_from_cpu
(
norm_img_batch
)
self
.
predictor
.
run
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment