Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
wangsen
paddle_dbnet
Commits
4402e629
"git@developer.sourcefind.cn:dcuai/dlexamples.git" did not exist on "0016b0a737735ecfe6ea10919f2af0bdc3277639"
Commit
4402e629
authored
Nov 09, 2020
by
WenmuZhou
Browse files
修正export_model里的bug,添加predict_det
parent
89e031f0
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
34 additions
and
18 deletions
+34
-18
tools/export_model.py
tools/export_model.py
+15
-9
tools/infer/predict_det.py
tools/infer/predict_det.py
+19
-9
No files found.
tools/export_model.py
View file @
4402e629
...
@@ -12,6 +12,13 @@
...
@@ -12,6 +12,13 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
import
os
import
sys
__dir__
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
sys
.
path
.
append
(
__dir__
)
sys
.
path
.
append
(
os
.
path
.
abspath
(
os
.
path
.
join
(
__dir__
,
'..'
)))
import
argparse
import
argparse
import
paddle
import
paddle
...
@@ -20,14 +27,11 @@ from paddle.jit import to_static
...
@@ -20,14 +27,11 @@ from paddle.jit import to_static
from
ppocr.modeling.architectures
import
build_model
from
ppocr.modeling.architectures
import
build_model
from
ppocr.postprocess
import
build_post_process
from
ppocr.postprocess
import
build_post_process
from
ppocr.utils.save_load
import
init_model
from
ppocr.utils.save_load
import
init_model
from
ppocr.utils.logging
import
get_logger
from
tools.program
import
load_config
from
tools.program
import
load_config
from
tools.program
import
merge_config
def
parse_args
():
def
parse_args
():
def
str2bool
(
v
):
return
v
.
lower
()
in
(
"true"
,
"t"
,
"1"
)
parser
=
argparse
.
ArgumentParser
()
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
"-c"
,
"--config"
,
help
=
"configuration file to use"
)
parser
.
add_argument
(
"-c"
,
"--config"
,
help
=
"configuration file to use"
)
parser
.
add_argument
(
parser
.
add_argument
(
...
@@ -43,7 +47,7 @@ class Model(paddle.nn.Layer):
...
@@ -43,7 +47,7 @@ class Model(paddle.nn.Layer):
# Please modify the 'shape' according to actual needs
# Please modify the 'shape' according to actual needs
@
to_static
(
input_spec
=
[
@
to_static
(
input_spec
=
[
paddle
.
static
.
InputSpec
(
paddle
.
static
.
InputSpec
(
shape
=
[
None
,
3
,
32
,
None
],
dtype
=
'float32'
)
shape
=
[
None
,
3
,
640
,
640
],
dtype
=
'float32'
)
])
])
def
forward
(
self
,
inputs
):
def
forward
(
self
,
inputs
):
x
=
self
.
pre_model
(
inputs
)
x
=
self
.
pre_model
(
inputs
)
...
@@ -53,14 +57,13 @@ class Model(paddle.nn.Layer):
...
@@ -53,14 +57,13 @@ class Model(paddle.nn.Layer):
def
main
():
def
main
():
FLAGS
=
parse_args
()
FLAGS
=
parse_args
()
config
=
load_config
(
FLAGS
.
config
)
config
=
load_config
(
FLAGS
.
config
)
merge_config
(
FLAGS
.
opt
)
logger
=
get_logger
()
# build post process
# build post process
post_process_class
=
build_post_process
(
config
[
'PostProcess'
],
post_process_class
=
build_post_process
(
config
[
'PostProcess'
],
config
[
'Global'
])
config
[
'Global'
])
# build model
# build model
#for rec algorithm
#
for rec algorithm
if
hasattr
(
post_process_class
,
'character'
):
if
hasattr
(
post_process_class
,
'character'
):
char_num
=
len
(
getattr
(
post_process_class
,
'character'
))
char_num
=
len
(
getattr
(
post_process_class
,
'character'
))
config
[
'Architecture'
][
"Head"
][
'out_channels'
]
=
char_num
config
[
'Architecture'
][
"Head"
][
'out_channels'
]
=
char_num
...
@@ -69,7 +72,10 @@ def main():
...
@@ -69,7 +72,10 @@ def main():
model
.
eval
()
model
.
eval
()
model
=
Model
(
model
)
model
=
Model
(
model
)
paddle
.
jit
.
save
(
model
,
FLAGS
.
output_path
)
save_path
=
'{}/{}'
.
format
(
FLAGS
.
output_path
,
config
[
'Architecture'
][
'model_type'
])
paddle
.
jit
.
save
(
model
,
save_path
)
logger
.
info
(
'inference model is saved to {}'
.
format
(
save_path
))
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
...
...
tools/infer/predict_det.py
View file @
4402e629
...
@@ -22,7 +22,6 @@ import cv2
...
@@ -22,7 +22,6 @@ import cv2
import
numpy
as
np
import
numpy
as
np
import
time
import
time
import
sys
import
sys
import
paddle
import
paddle
import
tools.infer.utility
as
utility
import
tools.infer.utility
as
utility
...
@@ -39,7 +38,7 @@ class TextDetector(object):
...
@@ -39,7 +38,7 @@ class TextDetector(object):
postprocess_params
=
{}
postprocess_params
=
{}
if
self
.
det_algorithm
==
"DB"
:
if
self
.
det_algorithm
==
"DB"
:
pre_process_list
=
[{
pre_process_list
=
[{
'ResizeForTest'
:
{
'
Det
ResizeForTest'
:
{
'limit_side_len'
:
args
.
det_limit_side_len
,
'limit_side_len'
:
args
.
det_limit_side_len
,
'limit_type'
:
args
.
det_limit_type
'limit_type'
:
args
.
det_limit_type
}
}
...
@@ -53,7 +52,7 @@ class TextDetector(object):
...
@@ -53,7 +52,7 @@ class TextDetector(object):
},
{
},
{
'ToCHWImage'
:
None
'ToCHWImage'
:
None
},
{
},
{
'
k
eepKeys'
:
{
'
K
eepKeys'
:
{
'keep_keys'
:
[
'image'
,
'shape'
]
'keep_keys'
:
[
'image'
,
'shape'
]
}
}
}]
}]
...
@@ -68,8 +67,9 @@ class TextDetector(object):
...
@@ -68,8 +67,9 @@ class TextDetector(object):
self
.
preprocess_op
=
create_operators
(
pre_process_list
)
self
.
preprocess_op
=
create_operators
(
pre_process_list
)
self
.
postprocess_op
=
build_post_process
(
postprocess_params
)
self
.
postprocess_op
=
build_post_process
(
postprocess_params
)
self
.
predictor
=
paddle
.
jit
.
load
(
args
.
det_model_dir
)
self
.
predictor
,
self
.
input_tensor
,
self
.
output_tensors
=
utility
.
create_predictor
(
self
.
predictor
.
eval
()
args
,
'det'
,
logger
)
# paddle.jit.load(args.det_model_dir)
# self.predictor.eval()
def
order_points_clockwise
(
self
,
pts
):
def
order_points_clockwise
(
self
,
pts
):
"""
"""
...
@@ -133,11 +133,23 @@ class TextDetector(object):
...
@@ -133,11 +133,23 @@ class TextDetector(object):
return
None
,
0
return
None
,
0
img
=
np
.
expand_dims
(
img
,
axis
=
0
)
img
=
np
.
expand_dims
(
img
,
axis
=
0
)
shape_list
=
np
.
expand_dims
(
shape_list
,
axis
=
0
)
shape_list
=
np
.
expand_dims
(
shape_list
,
axis
=
0
)
img
=
img
.
copy
()
starttime
=
time
.
time
()
starttime
=
time
.
time
()
preds
=
self
.
predictor
(
img
)
if
self
.
use_zero_copy_run
:
self
.
input_tensor
.
copy_from_cpu
(
img
)
self
.
predictor
.
zero_copy_run
()
else
:
im
=
paddle
.
fluid
.
core
.
PaddleTensor
(
img
)
self
.
predictor
.
run
([
im
])
outputs
=
[]
for
output_tensor
in
self
.
output_tensors
:
output
=
output_tensor
.
copy_to_cpu
()
outputs
.
append
(
output
)
preds
=
outputs
[
0
]
# preds = self.predictor(img)
post_result
=
self
.
postprocess_op
(
preds
,
shape_list
)
post_result
=
self
.
postprocess_op
(
preds
,
shape_list
)
dt_boxes
=
post_result
[
0
][
'points'
]
dt_boxes
=
post_result
[
0
][
'points'
]
dt_boxes
=
self
.
filter_tag_det_res
(
dt_boxes
,
ori_im
.
shape
)
dt_boxes
=
self
.
filter_tag_det_res
(
dt_boxes
,
ori_im
.
shape
)
elapse
=
time
.
time
()
-
starttime
elapse
=
time
.
time
()
-
starttime
...
@@ -146,8 +158,6 @@ class TextDetector(object):
...
@@ -146,8 +158,6 @@ class TextDetector(object):
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
args
=
utility
.
parse_args
()
args
=
utility
.
parse_args
()
place
=
paddle
.
CPUPlace
()
paddle
.
disable_static
(
place
)
image_file_list
=
get_image_file_list
(
args
.
image_dir
)
image_file_list
=
get_image_file_list
(
args
.
image_dir
)
logger
=
get_logger
()
logger
=
get_logger
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment