Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
wangsen
paddle_dbnet
Commits
ee05c913
Unverified
Commit
ee05c913
authored
Aug 27, 2020
by
zhoujun
Committed by
GitHub
Aug 27, 2020
Browse files
Merge pull request #5 from PaddlePaddle/develop
merge paddleocr
parents
7c09c97d
2bdaea56
Changes
68
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
122 additions
and
45 deletions
+122
-45
tools/export_model.py
tools/export_model.py
+2
-2
tools/infer/predict_det.py
tools/infer/predict_det.py
+54
-12
tools/infer/predict_rec.py
tools/infer/predict_rec.py
+14
-6
tools/infer/predict_system.py
tools/infer/predict_system.py
+0
-1
tools/infer/utility.py
tools/infer/utility.py
+40
-14
tools/infer_det.py
tools/infer_det.py
+6
-4
tools/infer_rec.py
tools/infer_rec.py
+4
-4
tools/train.py
tools/train.py
+2
-2
No files found.
tools/export_model.py
View file @
ee05c913
...
@@ -18,9 +18,9 @@ from __future__ import print_function
...
@@ -18,9 +18,9 @@ from __future__ import print_function
import
os
import
os
import
sys
import
sys
__dir__
=
os
.
path
.
dirname
(
__file__
)
__dir__
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
)
)
sys
.
path
.
append
(
__dir__
)
sys
.
path
.
append
(
__dir__
)
sys
.
path
.
append
(
os
.
path
.
join
(
__dir__
,
'..'
))
sys
.
path
.
append
(
os
.
path
.
abspath
(
os
.
path
.
join
(
__dir__
,
'..'
))
)
def
set_paddle_flags
(
**
kwargs
):
def
set_paddle_flags
(
**
kwargs
):
...
...
tools/infer/predict_det.py
View file @
ee05c913
...
@@ -13,30 +13,36 @@
...
@@ -13,30 +13,36 @@
# limitations under the License.
# limitations under the License.
import
os
import
os
import
sys
import
sys
__dir__
=
os
.
path
.
dirname
(
__file__
)
__dir__
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
)
)
sys
.
path
.
append
(
__dir__
)
sys
.
path
.
append
(
__dir__
)
sys
.
path
.
append
(
os
.
path
.
join
(
__dir__
,
'../..'
))
sys
.
path
.
append
(
os
.
path
.
abspath
(
os
.
path
.
join
(
__dir__
,
'../..'
)))
import
cv2
import
copy
import
numpy
as
np
import
math
import
time
import
sys
import
paddle.fluid
as
fluid
import
tools.infer.utility
as
utility
import
tools.infer.utility
as
utility
from
ppocr.utils.utility
import
initial_logger
from
ppocr.utils.utility
import
initial_logger
logger
=
initial_logger
()
logger
=
initial_logger
()
from
ppocr.utils.utility
import
get_image_file_list
,
check_and_read_gif
from
ppocr.utils.utility
import
get_image_file_list
,
check_and_read_gif
import
cv2
from
ppocr.data.det.sast_process
import
SASTProcessTest
from
ppocr.data.det.east_process
import
EASTProcessTest
from
ppocr.data.det.east_process
import
EASTProcessTest
from
ppocr.data.det.db_process
import
DBProcessTest
from
ppocr.data.det.db_process
import
DBProcessTest
from
ppocr.postprocess.db_postprocess
import
DBPostProcess
from
ppocr.postprocess.db_postprocess
import
DBPostProcess
from
ppocr.postprocess.east_postprocess
import
EASTPostPocess
from
ppocr.postprocess.east_postprocess
import
EASTPostPocess
import
copy
from
ppocr.postprocess.sast_postprocess
import
SASTPostProcess
import
numpy
as
np
import
math
import
time
import
sys
class
TextDetector
(
object
):
class
TextDetector
(
object
):
def
__init__
(
self
,
args
):
def
__init__
(
self
,
args
):
max_side_len
=
args
.
det_max_side_len
max_side_len
=
args
.
det_max_side_len
self
.
det_algorithm
=
args
.
det_algorithm
self
.
det_algorithm
=
args
.
det_algorithm
self
.
use_zero_copy_run
=
args
.
use_zero_copy_run
preprocess_params
=
{
'max_side_len'
:
max_side_len
}
preprocess_params
=
{
'max_side_len'
:
max_side_len
}
postprocess_params
=
{}
postprocess_params
=
{}
if
self
.
det_algorithm
==
"DB"
:
if
self
.
det_algorithm
==
"DB"
:
...
@@ -52,6 +58,20 @@ class TextDetector(object):
...
@@ -52,6 +58,20 @@ class TextDetector(object):
postprocess_params
[
"cover_thresh"
]
=
args
.
det_east_cover_thresh
postprocess_params
[
"cover_thresh"
]
=
args
.
det_east_cover_thresh
postprocess_params
[
"nms_thresh"
]
=
args
.
det_east_nms_thresh
postprocess_params
[
"nms_thresh"
]
=
args
.
det_east_nms_thresh
self
.
postprocess_op
=
EASTPostPocess
(
postprocess_params
)
self
.
postprocess_op
=
EASTPostPocess
(
postprocess_params
)
elif
self
.
det_algorithm
==
"SAST"
:
self
.
preprocess_op
=
SASTProcessTest
(
preprocess_params
)
postprocess_params
[
"score_thresh"
]
=
args
.
det_sast_score_thresh
postprocess_params
[
"nms_thresh"
]
=
args
.
det_sast_nms_thresh
self
.
det_sast_polygon
=
args
.
det_sast_polygon
if
self
.
det_sast_polygon
:
postprocess_params
[
"sample_pts_num"
]
=
6
postprocess_params
[
"expand_scale"
]
=
1.2
postprocess_params
[
"shrink_ratio_of_width"
]
=
0.2
else
:
postprocess_params
[
"sample_pts_num"
]
=
2
postprocess_params
[
"expand_scale"
]
=
1.0
postprocess_params
[
"shrink_ratio_of_width"
]
=
0.3
self
.
postprocess_op
=
SASTPostProcess
(
postprocess_params
)
else
:
else
:
logger
.
info
(
"unknown det_algorithm:{}"
.
format
(
self
.
det_algorithm
))
logger
.
info
(
"unknown det_algorithm:{}"
.
format
(
self
.
det_algorithm
))
sys
.
exit
(
0
)
sys
.
exit
(
0
)
...
@@ -84,7 +104,7 @@ class TextDetector(object):
...
@@ -84,7 +104,7 @@ class TextDetector(object):
return
rect
return
rect
def
clip_det_res
(
self
,
points
,
img_height
,
img_width
):
def
clip_det_res
(
self
,
points
,
img_height
,
img_width
):
for
pno
in
range
(
4
):
for
pno
in
range
(
points
.
shape
[
0
]
):
points
[
pno
,
0
]
=
int
(
min
(
max
(
points
[
pno
,
0
],
0
),
img_width
-
1
))
points
[
pno
,
0
]
=
int
(
min
(
max
(
points
[
pno
,
0
],
0
),
img_width
-
1
))
points
[
pno
,
1
]
=
int
(
min
(
max
(
points
[
pno
,
1
],
0
),
img_height
-
1
))
points
[
pno
,
1
]
=
int
(
min
(
max
(
points
[
pno
,
1
],
0
),
img_height
-
1
))
return
points
return
points
...
@@ -103,6 +123,15 @@ class TextDetector(object):
...
@@ -103,6 +123,15 @@ class TextDetector(object):
dt_boxes
=
np
.
array
(
dt_boxes_new
)
dt_boxes
=
np
.
array
(
dt_boxes_new
)
return
dt_boxes
return
dt_boxes
def
filter_tag_det_res_only_clip
(
self
,
dt_boxes
,
image_shape
):
img_height
,
img_width
=
image_shape
[
0
:
2
]
dt_boxes_new
=
[]
for
box
in
dt_boxes
:
box
=
self
.
clip_det_res
(
box
,
img_height
,
img_width
)
dt_boxes_new
.
append
(
box
)
dt_boxes
=
np
.
array
(
dt_boxes_new
)
return
dt_boxes
def
__call__
(
self
,
img
):
def
__call__
(
self
,
img
):
ori_im
=
img
.
copy
()
ori_im
=
img
.
copy
()
im
,
ratio_list
=
self
.
preprocess_op
(
img
)
im
,
ratio_list
=
self
.
preprocess_op
(
img
)
...
@@ -110,8 +139,12 @@ class TextDetector(object):
...
@@ -110,8 +139,12 @@ class TextDetector(object):
return
None
,
0
return
None
,
0
im
=
im
.
copy
()
im
=
im
.
copy
()
starttime
=
time
.
time
()
starttime
=
time
.
time
()
if
self
.
use_zero_copy_run
:
self
.
input_tensor
.
copy_from_cpu
(
im
)
self
.
input_tensor
.
copy_from_cpu
(
im
)
self
.
predictor
.
zero_copy_run
()
self
.
predictor
.
zero_copy_run
()
else
:
im
=
fluid
.
core
.
PaddleTensor
(
im
)
self
.
predictor
.
run
([
im
])
outputs
=
[]
outputs
=
[]
for
output_tensor
in
self
.
output_tensors
:
for
output_tensor
in
self
.
output_tensors
:
output
=
output_tensor
.
copy_to_cpu
()
output
=
output_tensor
.
copy_to_cpu
()
...
@@ -120,10 +153,19 @@ class TextDetector(object):
...
@@ -120,10 +153,19 @@ class TextDetector(object):
if
self
.
det_algorithm
==
"EAST"
:
if
self
.
det_algorithm
==
"EAST"
:
outs_dict
[
'f_geo'
]
=
outputs
[
0
]
outs_dict
[
'f_geo'
]
=
outputs
[
0
]
outs_dict
[
'f_score'
]
=
outputs
[
1
]
outs_dict
[
'f_score'
]
=
outputs
[
1
]
elif
self
.
det_algorithm
==
'SAST'
:
outs_dict
[
'f_border'
]
=
outputs
[
0
]
outs_dict
[
'f_score'
]
=
outputs
[
1
]
outs_dict
[
'f_tco'
]
=
outputs
[
2
]
outs_dict
[
'f_tvo'
]
=
outputs
[
3
]
else
:
else
:
outs_dict
[
'maps'
]
=
outputs
[
0
]
outs_dict
[
'maps'
]
=
outputs
[
0
]
dt_boxes_list
=
self
.
postprocess_op
(
outs_dict
,
[
ratio_list
])
dt_boxes_list
=
self
.
postprocess_op
(
outs_dict
,
[
ratio_list
])
dt_boxes
=
dt_boxes_list
[
0
]
dt_boxes
=
dt_boxes_list
[
0
]
if
self
.
det_algorithm
==
"SAST"
and
self
.
det_sast_polygon
:
dt_boxes
=
self
.
filter_tag_det_res_only_clip
(
dt_boxes
,
ori_im
.
shape
)
else
:
dt_boxes
=
self
.
filter_tag_det_res
(
dt_boxes
,
ori_im
.
shape
)
dt_boxes
=
self
.
filter_tag_det_res
(
dt_boxes
,
ori_im
.
shape
)
elapse
=
time
.
time
()
-
starttime
elapse
=
time
.
time
()
-
starttime
return
dt_boxes
,
elapse
return
dt_boxes
,
elapse
...
...
tools/infer/predict_rec.py
View file @
ee05c913
...
@@ -17,15 +17,18 @@ __dir__ = os.path.dirname(os.path.abspath(__file__))
...
@@ -17,15 +17,18 @@ __dir__ = os.path.dirname(os.path.abspath(__file__))
sys
.
path
.
append
(
__dir__
)
sys
.
path
.
append
(
__dir__
)
sys
.
path
.
append
(
os
.
path
.
abspath
(
os
.
path
.
join
(
__dir__
,
'../..'
)))
sys
.
path
.
append
(
os
.
path
.
abspath
(
os
.
path
.
join
(
__dir__
,
'../..'
)))
import
tools.infer.utility
as
utility
from
ppocr.utils.utility
import
initial_logger
logger
=
initial_logger
()
from
ppocr.utils.utility
import
get_image_file_list
,
check_and_read_gif
import
cv2
import
cv2
import
copy
import
copy
import
numpy
as
np
import
numpy
as
np
import
math
import
math
import
time
import
time
import
paddle.fluid
as
fluid
import
tools.infer.utility
as
utility
from
ppocr.utils.utility
import
initial_logger
logger
=
initial_logger
()
from
ppocr.utils.utility
import
get_image_file_list
,
check_and_read_gif
from
ppocr.utils.character
import
CharacterOps
from
ppocr.utils.character
import
CharacterOps
...
@@ -37,6 +40,7 @@ class TextRecognizer(object):
...
@@ -37,6 +40,7 @@ class TextRecognizer(object):
self
.
character_type
=
args
.
rec_char_type
self
.
character_type
=
args
.
rec_char_type
self
.
rec_batch_num
=
args
.
rec_batch_num
self
.
rec_batch_num
=
args
.
rec_batch_num
self
.
rec_algorithm
=
args
.
rec_algorithm
self
.
rec_algorithm
=
args
.
rec_algorithm
self
.
use_zero_copy_run
=
args
.
use_zero_copy_run
char_ops_params
=
{
char_ops_params
=
{
"character_type"
:
args
.
rec_char_type
,
"character_type"
:
args
.
rec_char_type
,
"character_dict_path"
:
args
.
rec_char_dict_path
,
"character_dict_path"
:
args
.
rec_char_dict_path
,
...
@@ -102,8 +106,12 @@ class TextRecognizer(object):
...
@@ -102,8 +106,12 @@ class TextRecognizer(object):
norm_img_batch
=
np
.
concatenate
(
norm_img_batch
)
norm_img_batch
=
np
.
concatenate
(
norm_img_batch
)
norm_img_batch
=
norm_img_batch
.
copy
()
norm_img_batch
=
norm_img_batch
.
copy
()
starttime
=
time
.
time
()
starttime
=
time
.
time
()
if
self
.
use_zero_copy_run
:
self
.
input_tensor
.
copy_from_cpu
(
norm_img_batch
)
self
.
input_tensor
.
copy_from_cpu
(
norm_img_batch
)
self
.
predictor
.
zero_copy_run
()
self
.
predictor
.
zero_copy_run
()
else
:
norm_img_batch
=
fluid
.
core
.
PaddleTensor
(
norm_img_batch
)
self
.
predictor
.
run
([
norm_img_batch
])
if
self
.
loss_type
==
"ctc"
:
if
self
.
loss_type
==
"ctc"
:
rec_idx_batch
=
self
.
output_tensors
[
0
].
copy_to_cpu
()
rec_idx_batch
=
self
.
output_tensors
[
0
].
copy_to_cpu
()
...
...
tools/infer/predict_system.py
View file @
ee05c913
...
@@ -157,7 +157,6 @@ def main(args):
...
@@ -157,7 +157,6 @@ def main(args):
boxes
,
boxes
,
txts
,
txts
,
scores
,
scores
,
draw_txt
=
True
,
drop_score
=
drop_score
)
drop_score
=
drop_score
)
draw_img_save
=
"./inference_results/"
draw_img_save
=
"./inference_results/"
if
not
os
.
path
.
exists
(
draw_img_save
):
if
not
os
.
path
.
exists
(
draw_img_save
):
...
...
tools/infer/utility.py
View file @
ee05c913
...
@@ -53,6 +53,11 @@ def parse_args():
...
@@ -53,6 +53,11 @@ def parse_args():
parser
.
add_argument
(
"--det_east_cover_thresh"
,
type
=
float
,
default
=
0.1
)
parser
.
add_argument
(
"--det_east_cover_thresh"
,
type
=
float
,
default
=
0.1
)
parser
.
add_argument
(
"--det_east_nms_thresh"
,
type
=
float
,
default
=
0.2
)
parser
.
add_argument
(
"--det_east_nms_thresh"
,
type
=
float
,
default
=
0.2
)
#SAST parmas
parser
.
add_argument
(
"--det_sast_score_thresh"
,
type
=
float
,
default
=
0.5
)
parser
.
add_argument
(
"--det_sast_nms_thresh"
,
type
=
float
,
default
=
0.2
)
parser
.
add_argument
(
"--det_sast_polygon"
,
type
=
bool
,
default
=
False
)
#params for text recognizer
#params for text recognizer
parser
.
add_argument
(
"--rec_algorithm"
,
type
=
str
,
default
=
'CRNN'
)
parser
.
add_argument
(
"--rec_algorithm"
,
type
=
str
,
default
=
'CRNN'
)
parser
.
add_argument
(
"--rec_model_dir"
,
type
=
str
)
parser
.
add_argument
(
"--rec_model_dir"
,
type
=
str
)
...
@@ -66,6 +71,7 @@ def parse_args():
...
@@ -66,6 +71,7 @@ def parse_args():
default
=
"./ppocr/utils/ppocr_keys_v1.txt"
)
default
=
"./ppocr/utils/ppocr_keys_v1.txt"
)
parser
.
add_argument
(
"--use_space_char"
,
type
=
bool
,
default
=
True
)
parser
.
add_argument
(
"--use_space_char"
,
type
=
bool
,
default
=
True
)
parser
.
add_argument
(
"--enable_mkldnn"
,
type
=
bool
,
default
=
False
)
parser
.
add_argument
(
"--enable_mkldnn"
,
type
=
bool
,
default
=
False
)
parser
.
add_argument
(
"--use_zero_copy_run"
,
type
=
bool
,
default
=
False
)
return
parser
.
parse_args
()
return
parser
.
parse_args
()
...
@@ -100,9 +106,12 @@ def create_predictor(args, mode):
...
@@ -100,9 +106,12 @@ def create_predictor(args, mode):
#config.enable_memory_optim()
#config.enable_memory_optim()
config
.
disable_glog_info
()
config
.
disable_glog_info
()
#
use
zero
copy
if
args
.
use
_
zero
_
copy
_run
:
config
.
delete_pass
(
"conv_transpose_eltwiseadd_bn_fuse_pass"
)
config
.
delete_pass
(
"conv_transpose_eltwiseadd_bn_fuse_pass"
)
config
.
switch_use_feed_fetch_ops
(
False
)
config
.
switch_use_feed_fetch_ops
(
False
)
else
:
config
.
switch_use_feed_fetch_ops
(
True
)
predictor
=
create_paddle_predictor
(
config
)
predictor
=
create_paddle_predictor
(
config
)
input_names
=
predictor
.
get_input_names
()
input_names
=
predictor
.
get_input_names
()
input_tensor
=
predictor
.
get_input_tensor
(
input_names
[
0
])
input_tensor
=
predictor
.
get_input_tensor
(
input_names
[
0
])
...
@@ -134,7 +143,12 @@ def resize_img(img, input_size=600):
...
@@ -134,7 +143,12 @@ def resize_img(img, input_size=600):
return
im
return
im
def
draw_ocr
(
image
,
boxes
,
txts
,
scores
,
draw_txt
=
True
,
drop_score
=
0.5
):
def
draw_ocr
(
image
,
boxes
,
txts
=
None
,
scores
=
None
,
drop_score
=
0.5
,
font_path
=
"./doc/simfang.ttf"
):
"""
"""
Visualize the results of OCR detection and recognition
Visualize the results of OCR detection and recognition
args:
args:
...
@@ -142,23 +156,29 @@ def draw_ocr(image, boxes, txts, scores, draw_txt=True, drop_score=0.5):
...
@@ -142,23 +156,29 @@ def draw_ocr(image, boxes, txts, scores, draw_txt=True, drop_score=0.5):
boxes(list): boxes with shape(N, 4, 2)
boxes(list): boxes with shape(N, 4, 2)
txts(list): the texts
txts(list): the texts
scores(list): txxs corresponding scores
scores(list): txxs corresponding scores
draw_txt(bool): whether draw text or not
drop_score(float): only scores greater than drop_threshold will be visualized
drop_score(float): only scores greater than drop_threshold will be visualized
font_path: the path of font which is used to draw text
return(array):
return(array):
the visualized img
the visualized img
"""
"""
if
scores
is
None
:
if
scores
is
None
:
scores
=
[
1
]
*
len
(
boxes
)
scores
=
[
1
]
*
len
(
boxes
)
for
(
box
,
score
)
in
zip
(
boxes
,
scores
):
box_num
=
len
(
boxes
)
if
score
<
drop_score
or
math
.
isnan
(
score
):
for
i
in
range
(
box_num
):
if
scores
is
not
None
and
(
scores
[
i
]
<
drop_score
or
math
.
isnan
(
scores
[
i
])):
continue
continue
box
=
np
.
reshape
(
np
.
array
(
box
),
[
-
1
,
1
,
2
]).
astype
(
np
.
int64
)
box
=
np
.
reshape
(
np
.
array
(
box
es
[
i
]
),
[
-
1
,
1
,
2
]).
astype
(
np
.
int64
)
image
=
cv2
.
polylines
(
np
.
array
(
image
),
[
box
],
True
,
(
255
,
0
,
0
),
2
)
image
=
cv2
.
polylines
(
np
.
array
(
image
),
[
box
],
True
,
(
255
,
0
,
0
),
2
)
if
txts
is
not
None
:
if
draw_txt
:
img
=
np
.
array
(
resize_img
(
image
,
input_size
=
600
))
img
=
np
.
array
(
resize_img
(
image
,
input_size
=
600
))
txt_img
=
text_visual
(
txt_img
=
text_visual
(
txts
,
scores
,
img_h
=
img
.
shape
[
0
],
img_w
=
600
,
threshold
=
drop_score
)
txts
,
scores
,
img_h
=
img
.
shape
[
0
],
img_w
=
600
,
threshold
=
drop_score
,
font_path
=
font_path
)
img
=
np
.
concatenate
([
np
.
array
(
img
),
np
.
array
(
txt_img
)],
axis
=
1
)
img
=
np
.
concatenate
([
np
.
array
(
img
),
np
.
array
(
txt_img
)],
axis
=
1
)
return
img
return
img
return
image
return
image
...
@@ -236,7 +256,12 @@ def str_count(s):
...
@@ -236,7 +256,12 @@ def str_count(s):
return
s_len
-
math
.
ceil
(
en_dg_count
/
2
)
return
s_len
-
math
.
ceil
(
en_dg_count
/
2
)
def
text_visual
(
texts
,
scores
,
img_h
=
400
,
img_w
=
600
,
threshold
=
0.
):
def
text_visual
(
texts
,
scores
,
img_h
=
400
,
img_w
=
600
,
threshold
=
0.
,
font_path
=
"./doc/simfang.ttf"
):
"""
"""
create new blank img and draw txt on it
create new blank img and draw txt on it
args:
args:
...
@@ -244,6 +269,7 @@ def text_visual(texts, scores, img_h=400, img_w=600, threshold=0.):
...
@@ -244,6 +269,7 @@ def text_visual(texts, scores, img_h=400, img_w=600, threshold=0.):
scores(list|None): corresponding score of each txt
scores(list|None): corresponding score of each txt
img_h(int): the height of blank img
img_h(int): the height of blank img
img_w(int): the width of blank img
img_w(int): the width of blank img
font_path: the path of font which is used to draw text
return(array):
return(array):
"""
"""
...
@@ -262,7 +288,7 @@ def text_visual(texts, scores, img_h=400, img_w=600, threshold=0.):
...
@@ -262,7 +288,7 @@ def text_visual(texts, scores, img_h=400, img_w=600, threshold=0.):
font_size
=
20
font_size
=
20
txt_color
=
(
0
,
0
,
0
)
txt_color
=
(
0
,
0
,
0
)
font
=
ImageFont
.
truetype
(
"./doc/simfang.ttf"
,
font_size
,
encoding
=
"utf-8"
)
font
=
ImageFont
.
truetype
(
font_path
,
font_size
,
encoding
=
"utf-8"
)
gap
=
font_size
+
5
gap
=
font_size
+
5
txt_img_list
=
[]
txt_img_list
=
[]
...
@@ -343,6 +369,6 @@ if __name__ == '__main__':
...
@@ -343,6 +369,6 @@ if __name__ == '__main__':
txts
.
append
(
dic
[
'transcription'
])
txts
.
append
(
dic
[
'transcription'
])
scores
.
append
(
round
(
dic
[
'scores'
],
3
))
scores
.
append
(
round
(
dic
[
'scores'
],
3
))
new_img
=
draw_ocr
(
image
,
boxes
,
txts
,
scores
,
draw_txt
=
True
)
new_img
=
draw_ocr
(
image
,
boxes
,
txts
,
scores
)
cv2
.
imwrite
(
img_name
,
new_img
)
cv2
.
imwrite
(
img_name
,
new_img
)
tools/infer_det.py
View file @
ee05c913
...
@@ -22,9 +22,9 @@ import json
...
@@ -22,9 +22,9 @@ import json
import
os
import
os
import
sys
import
sys
__dir__
=
os
.
path
.
dirname
(
__file__
)
__dir__
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
)
)
sys
.
path
.
append
(
__dir__
)
sys
.
path
.
append
(
__dir__
)
sys
.
path
.
append
(
os
.
path
.
join
(
__dir__
,
'..'
))
sys
.
path
.
append
(
os
.
path
.
abspath
(
os
.
path
.
join
(
__dir__
,
'..'
))
)
def
set_paddle_flags
(
**
kwargs
):
def
set_paddle_flags
(
**
kwargs
):
...
@@ -134,8 +134,10 @@ def main():
...
@@ -134,8 +134,10 @@ def main():
dic
=
{
'f_score'
:
outs
[
0
],
'f_geo'
:
outs
[
1
]}
dic
=
{
'f_score'
:
outs
[
0
],
'f_geo'
:
outs
[
1
]}
elif
config
[
'Global'
][
'algorithm'
]
==
'DB'
:
elif
config
[
'Global'
][
'algorithm'
]
==
'DB'
:
dic
=
{
'maps'
:
outs
[
0
]}
dic
=
{
'maps'
:
outs
[
0
]}
elif
config
[
'Global'
][
'algorithm'
]
==
'SAST'
:
dic
=
{
'f_score'
:
outs
[
0
],
'f_border'
:
outs
[
1
],
'f_tvo'
:
outs
[
2
],
'f_tco'
:
outs
[
3
]}
else
:
else
:
raise
Exception
(
"only support algorithm: ['EAST', 'DB']"
)
raise
Exception
(
"only support algorithm: ['EAST', 'DB'
, 'SAST'
]"
)
dt_boxes_list
=
postprocess
(
dic
,
ratio_list
)
dt_boxes_list
=
postprocess
(
dic
,
ratio_list
)
for
ino
in
range
(
img_num
):
for
ino
in
range
(
img_num
):
dt_boxes
=
dt_boxes_list
[
ino
]
dt_boxes
=
dt_boxes_list
[
ino
]
...
...
tools/infer_rec.py
View file @
ee05c913
...
@@ -19,9 +19,9 @@ from __future__ import print_function
...
@@ -19,9 +19,9 @@ from __future__ import print_function
import
numpy
as
np
import
numpy
as
np
import
os
import
os
import
sys
import
sys
__dir__
=
os
.
path
.
dirname
(
__file__
)
__dir__
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
)
)
sys
.
path
.
append
(
__dir__
)
sys
.
path
.
append
(
__dir__
)
sys
.
path
.
append
(
os
.
path
.
join
(
__dir__
,
'..'
))
sys
.
path
.
append
(
os
.
path
.
abspath
(
os
.
path
.
join
(
__dir__
,
'..'
))
)
def
set_paddle_flags
(
**
kwargs
):
def
set_paddle_flags
(
**
kwargs
):
...
@@ -140,12 +140,12 @@ def main():
...
@@ -140,12 +140,12 @@ def main():
preds
=
preds
.
reshape
(
-
1
)
preds
=
preds
.
reshape
(
-
1
)
preds_text
=
char_ops
.
decode
(
preds
)
preds_text
=
char_ops
.
decode
(
preds
)
elif
loss_type
==
"srn"
:
elif
loss_type
==
"srn"
:
c
ur_pred
=
[]
c
har_num
=
char_ops
.
get_char_num
()
preds
=
np
.
array
(
predict
[
0
])
preds
=
np
.
array
(
predict
[
0
])
preds
=
preds
.
reshape
(
-
1
)
preds
=
preds
.
reshape
(
-
1
)
probs
=
np
.
array
(
predict
[
1
])
probs
=
np
.
array
(
predict
[
1
])
ind
=
np
.
argmax
(
probs
,
axis
=
1
)
ind
=
np
.
argmax
(
probs
,
axis
=
1
)
valid_ind
=
np
.
where
(
preds
!=
37
)[
0
]
valid_ind
=
np
.
where
(
preds
!=
int
(
char_num
-
1
)
)[
0
]
if
len
(
valid_ind
)
==
0
:
if
len
(
valid_ind
)
==
0
:
continue
continue
score
=
np
.
mean
(
probs
[
valid_ind
,
ind
[
valid_ind
]])
score
=
np
.
mean
(
probs
[
valid_ind
,
ind
[
valid_ind
]])
...
...
tools/train.py
View file @
ee05c913
...
@@ -18,9 +18,9 @@ from __future__ import print_function
...
@@ -18,9 +18,9 @@ from __future__ import print_function
import
os
import
os
import
sys
import
sys
__dir__
=
os
.
path
.
dirname
(
__file__
)
__dir__
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
)
)
sys
.
path
.
append
(
__dir__
)
sys
.
path
.
append
(
__dir__
)
sys
.
path
.
append
(
os
.
path
.
join
(
__dir__
,
'..'
))
sys
.
path
.
append
(
os
.
path
.
abspath
(
os
.
path
.
join
(
__dir__
,
'..'
))
)
def
set_paddle_flags
(
**
kwargs
):
def
set_paddle_flags
(
**
kwargs
):
...
...
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment