Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
wangsen
paddle_dbnet
Commits
a8a9b2e5
Unverified
Commit
a8a9b2e5
authored
Sep 07, 2021
by
xiaoting
Committed by
GitHub
Sep 07, 2021
Browse files
Merge branch 'dygraph' into fix_prepare
parents
49e42a16
0f5a5d96
Changes
124
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
65 additions
and
23 deletions
+65
-23
tools/infer/utility.py
tools/infer/utility.py
+20
-1
tools/infer_det.py
tools/infer_det.py
+38
-17
tools/infer_rec.py
tools/infer_rec.py
+1
-1
tools/program.py
tools/program.py
+6
-4
No files found.
tools/infer/utility.py
View file @
a8a9b2e5
...
@@ -35,7 +35,7 @@ def init_args():
...
@@ -35,7 +35,7 @@ def init_args():
parser
.
add_argument
(
"--use_gpu"
,
type
=
str2bool
,
default
=
True
)
parser
.
add_argument
(
"--use_gpu"
,
type
=
str2bool
,
default
=
True
)
parser
.
add_argument
(
"--ir_optim"
,
type
=
str2bool
,
default
=
True
)
parser
.
add_argument
(
"--ir_optim"
,
type
=
str2bool
,
default
=
True
)
parser
.
add_argument
(
"--use_tensorrt"
,
type
=
str2bool
,
default
=
False
)
parser
.
add_argument
(
"--use_tensorrt"
,
type
=
str2bool
,
default
=
False
)
parser
.
add_argument
(
"--min_subgraph_size"
,
type
=
int
,
default
=
1
0
)
parser
.
add_argument
(
"--min_subgraph_size"
,
type
=
int
,
default
=
1
5
)
parser
.
add_argument
(
"--precision"
,
type
=
str
,
default
=
"fp32"
)
parser
.
add_argument
(
"--precision"
,
type
=
str
,
default
=
"fp32"
)
parser
.
add_argument
(
"--gpu_mem"
,
type
=
int
,
default
=
500
)
parser
.
add_argument
(
"--gpu_mem"
,
type
=
int
,
default
=
500
)
...
@@ -159,6 +159,11 @@ def create_predictor(args, mode, logger):
...
@@ -159,6 +159,11 @@ def create_predictor(args, mode, logger):
precision
=
inference
.
PrecisionType
.
Float32
precision
=
inference
.
PrecisionType
.
Float32
if
args
.
use_gpu
:
if
args
.
use_gpu
:
gpu_id
=
get_infer_gpuid
()
if
gpu_id
is
None
:
raise
ValueError
(
"Not found GPU in current device. Please check your device or set args.use_gpu as False"
)
config
.
enable_use_gpu
(
args
.
gpu_mem
,
0
)
config
.
enable_use_gpu
(
args
.
gpu_mem
,
0
)
if
args
.
use_tensorrt
:
if
args
.
use_tensorrt
:
config
.
enable_tensorrt_engine
(
config
.
enable_tensorrt_engine
(
...
@@ -280,6 +285,20 @@ def create_predictor(args, mode, logger):
...
@@ -280,6 +285,20 @@ def create_predictor(args, mode, logger):
return
predictor
,
input_tensor
,
output_tensors
,
config
return
predictor
,
input_tensor
,
output_tensors
,
config
def
get_infer_gpuid
():
cmd
=
"nvidia-smi"
res
=
os
.
popen
(
cmd
).
readlines
()
if
len
(
res
)
==
0
:
return
None
cmd
=
"env | grep CUDA_VISIBLE_DEVICES"
env_cuda
=
os
.
popen
(
cmd
).
readlines
()
if
len
(
env_cuda
)
==
0
:
return
0
else
:
gpu_id
=
env_cuda
[
0
].
strip
().
split
(
"="
)[
1
]
return
int
(
gpu_id
[
0
])
def
draw_e2e_res
(
dt_boxes
,
strs
,
img_path
):
def
draw_e2e_res
(
dt_boxes
,
strs
,
img_path
):
src_im
=
cv2
.
imread
(
img_path
)
src_im
=
cv2
.
imread
(
img_path
)
for
box
,
str
in
zip
(
dt_boxes
,
strs
):
for
box
,
str
in
zip
(
dt_boxes
,
strs
):
...
...
tools/infer_det.py
View file @
a8a9b2e5
...
@@ -34,23 +34,21 @@ import paddle
...
@@ -34,23 +34,21 @@ import paddle
from
ppocr.data
import
create_operators
,
transform
from
ppocr.data
import
create_operators
,
transform
from
ppocr.modeling.architectures
import
build_model
from
ppocr.modeling.architectures
import
build_model
from
ppocr.postprocess
import
build_post_process
from
ppocr.postprocess
import
build_post_process
from
ppocr.utils.save_load
import
init_model
from
ppocr.utils.save_load
import
init_model
,
load_dygraph_params
from
ppocr.utils.utility
import
get_image_file_list
from
ppocr.utils.utility
import
get_image_file_list
import
tools.program
as
program
import
tools.program
as
program
def
draw_det_res
(
dt_boxes
,
config
,
img
,
img_name
):
def
draw_det_res
(
dt_boxes
,
config
,
img
,
img_name
,
save_path
):
if
len
(
dt_boxes
)
>
0
:
if
len
(
dt_boxes
)
>
0
:
import
cv2
import
cv2
src_im
=
img
src_im
=
img
for
box
in
dt_boxes
:
for
box
in
dt_boxes
:
box
=
box
.
astype
(
np
.
int32
).
reshape
((
-
1
,
1
,
2
))
box
=
box
.
astype
(
np
.
int32
).
reshape
((
-
1
,
1
,
2
))
cv2
.
polylines
(
src_im
,
[
box
],
True
,
color
=
(
255
,
255
,
0
),
thickness
=
2
)
cv2
.
polylines
(
src_im
,
[
box
],
True
,
color
=
(
255
,
255
,
0
),
thickness
=
2
)
save_det_path
=
os
.
path
.
dirname
(
config
[
'Global'
][
if
not
os
.
path
.
exists
(
save_path
):
'save_res_path'
])
+
"/det_results/"
os
.
makedirs
(
save_path
)
if
not
os
.
path
.
exists
(
save_det_path
):
save_path
=
os
.
path
.
join
(
save_path
,
os
.
path
.
basename
(
img_name
))
os
.
makedirs
(
save_det_path
)
save_path
=
os
.
path
.
join
(
save_det_path
,
os
.
path
.
basename
(
img_name
))
cv2
.
imwrite
(
save_path
,
src_im
)
cv2
.
imwrite
(
save_path
,
src_im
)
logger
.
info
(
"The detected Image saved in {}"
.
format
(
save_path
))
logger
.
info
(
"The detected Image saved in {}"
.
format
(
save_path
))
...
@@ -61,8 +59,7 @@ def main():
...
@@ -61,8 +59,7 @@ def main():
# build model
# build model
model
=
build_model
(
config
[
'Architecture'
])
model
=
build_model
(
config
[
'Architecture'
])
init_model
(
config
,
model
)
_
=
load_dygraph_params
(
config
,
model
,
logger
,
None
)
# build post process
# build post process
post_process_class
=
build_post_process
(
config
[
'PostProcess'
])
post_process_class
=
build_post_process
(
config
[
'PostProcess'
])
...
@@ -96,17 +93,41 @@ def main():
...
@@ -96,17 +93,41 @@ def main():
images
=
paddle
.
to_tensor
(
images
)
images
=
paddle
.
to_tensor
(
images
)
preds
=
model
(
images
)
preds
=
model
(
images
)
post_result
=
post_process_class
(
preds
,
shape_list
)
post_result
=
post_process_class
(
preds
,
shape_list
)
boxes
=
post_result
[
0
][
'points'
]
# write result
src_img
=
cv2
.
imread
(
file
)
dt_boxes_json
=
[]
dt_boxes_json
=
[]
for
box
in
boxes
:
# parser boxes if post_result is dict
tmp_json
=
{
"transcription"
:
""
}
if
isinstance
(
post_result
,
dict
):
tmp_json
[
'points'
]
=
box
.
tolist
()
det_box_json
=
{}
dt_boxes_json
.
append
(
tmp_json
)
for
k
in
post_result
.
keys
():
boxes
=
post_result
[
k
][
0
][
'points'
]
dt_boxes_list
=
[]
for
box
in
boxes
:
tmp_json
=
{
"transcription"
:
""
}
tmp_json
[
'points'
]
=
box
.
tolist
()
dt_boxes_list
.
append
(
tmp_json
)
det_box_json
[
k
]
=
dt_boxes_list
save_det_path
=
os
.
path
.
dirname
(
config
[
'Global'
][
'save_res_path'
])
+
"/det_results_{}/"
.
format
(
k
)
draw_det_res
(
boxes
,
config
,
src_img
,
file
,
save_det_path
)
else
:
boxes
=
post_result
[
0
][
'points'
]
dt_boxes_json
=
[]
# write result
for
box
in
boxes
:
tmp_json
=
{
"transcription"
:
""
}
tmp_json
[
'points'
]
=
box
.
tolist
()
dt_boxes_json
.
append
(
tmp_json
)
save_det_path
=
os
.
path
.
dirname
(
config
[
'Global'
][
'save_res_path'
])
+
"/det_results/"
draw_det_res
(
boxes
,
config
,
src_img
,
file
,
save_det_path
)
otstr
=
file
+
"
\t
"
+
json
.
dumps
(
dt_boxes_json
)
+
"
\n
"
otstr
=
file
+
"
\t
"
+
json
.
dumps
(
dt_boxes_json
)
+
"
\n
"
fout
.
write
(
otstr
.
encode
())
fout
.
write
(
otstr
.
encode
())
src_img
=
cv2
.
imread
(
file
)
draw_det_res
(
boxes
,
config
,
src_img
,
file
)
save_det_path
=
os
.
path
.
dirname
(
config
[
'Global'
][
'save_res_path'
])
+
"/det_results/"
draw_det_res
(
boxes
,
config
,
src_img
,
file
,
save_det_path
)
logger
.
info
(
"success!"
)
logger
.
info
(
"success!"
)
...
...
tools/infer_rec.py
View file @
a8a9b2e5
...
@@ -121,7 +121,7 @@ def main():
...
@@ -121,7 +121,7 @@ def main():
if
len
(
post_result
[
key
][
0
])
>=
2
:
if
len
(
post_result
[
key
][
0
])
>=
2
:
rec_info
[
key
]
=
{
rec_info
[
key
]
=
{
"label"
:
post_result
[
key
][
0
][
0
],
"label"
:
post_result
[
key
][
0
][
0
],
"score"
:
post_result
[
key
][
0
][
1
],
"score"
:
float
(
post_result
[
key
][
0
][
1
]
)
,
}
}
info
=
json
.
dumps
(
rec_info
)
info
=
json
.
dumps
(
rec_info
)
else
:
else
:
...
...
tools/program.py
View file @
a8a9b2e5
...
@@ -186,9 +186,11 @@ def train(config,
...
@@ -186,9 +186,11 @@ def train(config,
model
.
train
()
model
.
train
()
use_srn
=
config
[
'Architecture'
][
'algorithm'
]
==
"SRN"
use_srn
=
config
[
'Architecture'
][
'algorithm'
]
==
"SRN"
try
:
use_nrtr
=
config
[
'Architecture'
][
'algorithm'
]
==
"NRTR"
try
:
model_type
=
config
[
'Architecture'
][
'model_type'
]
model_type
=
config
[
'Architecture'
][
'model_type'
]
except
:
except
:
model_type
=
None
model_type
=
None
if
'start_epoch'
in
best_model_dict
:
if
'start_epoch'
in
best_model_dict
:
...
@@ -213,7 +215,7 @@ def train(config,
...
@@ -213,7 +215,7 @@ def train(config,
images
=
batch
[
0
]
images
=
batch
[
0
]
if
use_srn
:
if
use_srn
:
model_average
=
True
model_average
=
True
if
use_srn
or
model_type
==
'table'
:
if
use_srn
or
model_type
==
'table'
or
use_nrtr
:
preds
=
model
(
images
,
data
=
batch
[
1
:])
preds
=
model
(
images
,
data
=
batch
[
1
:])
else
:
else
:
preds
=
model
(
images
)
preds
=
model
(
images
)
...
@@ -398,7 +400,7 @@ def preprocess(is_train=False):
...
@@ -398,7 +400,7 @@ def preprocess(is_train=False):
alg
=
config
[
'Architecture'
][
'algorithm'
]
alg
=
config
[
'Architecture'
][
'algorithm'
]
assert
alg
in
[
assert
alg
in
[
'EAST'
,
'DB'
,
'SAST'
,
'Rosetta'
,
'CRNN'
,
'STARNet'
,
'RARE'
,
'SRN'
,
'EAST'
,
'DB'
,
'SAST'
,
'Rosetta'
,
'CRNN'
,
'STARNet'
,
'RARE'
,
'SRN'
,
'CLS'
,
'PGNet'
,
'Distillation'
,
'TableAttn'
'CLS'
,
'PGNet'
,
'Distillation'
,
'NRTR'
,
'TableAttn'
]
]
device
=
'gpu:{}'
.
format
(
dist
.
ParallelEnv
().
dev_id
)
if
use_gpu
else
'cpu'
device
=
'gpu:{}'
.
format
(
dist
.
ParallelEnv
().
dev_id
)
if
use_gpu
else
'cpu'
...
...
Prev
1
…
3
4
5
6
7
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment