wangsen / paddle_dbnet · Commits

Commit 1f76f449, authored Mar 08, 2021 by Jethong

    Add PGNet

parent 1a087990

Changes: 30 (showing 10 changed files with 3268 additions and 73 deletions, +3268 −73)
ppocr/postprocess/pg_postprocess.py         +194  −0
ppocr/postprocess/sast_postprocess.py       +127  −72
ppocr/utils/e2e_metric/Deteval.py           +877  −0
ppocr/utils/e2e_metric/polygon_fast.py      +71   −0
ppocr/utils/e2e_metric/tttt.py              +881  −0
ppocr/utils/e2e_utils/extract_textpoint.py  +532  −0
ppocr/utils/e2e_utils/ski_thin.py           +126  −0
ppocr/utils/e2e_utils/visual.py             +343  −0
tools/infer_e2e.py                          +114  −0
tools/program.py                            +3    −1
ppocr/postprocess/pg_postprocess.py (new file, mode 100644)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import sys

__dir__ = os.path.dirname(__file__)
sys.path.append(__dir__)
sys.path.append(os.path.join(__dir__, '..'))

import numpy as np
from .locality_aware_nms import nms_locality
from ppocr.utils.e2e_utils.extract_textpoint import *
from ppocr.utils.e2e_utils.ski_thin import *
from ppocr.utils.e2e_utils.visual import *
import paddle
import cv2
import time
class PGPostProcess(object):
    """
    The post process for PGNet.
    """

    def __init__(self,
                 score_thresh=0.5,
                 nms_thresh=0.2,
                 sample_pts_num=2,
                 shrink_ratio_of_width=0.3,
                 expand_scale=1.0,
                 tcl_map_thresh=0.5,
                 **kwargs):
        self.result_path = ""
        self.valid_set = 'totaltext'
        self.Lexicon_Table = [
            '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C',
            'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P',
            'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z'
        ]
        self.score_thresh = score_thresh
        self.nms_thresh = nms_thresh
        self.sample_pts_num = sample_pts_num
        self.shrink_ratio_of_width = shrink_ratio_of_width
        self.expand_scale = expand_scale
        self.tcl_map_thresh = tcl_map_thresh

        # The C++ la-nms is faster, but it only supports Python 3.5.
        self.is_python35 = False
        if sys.version_info.major == 3 and sys.version_info.minor == 5:
            self.is_python35 = True
    def __call__(self, outs_dict, shape_list):
        p_score, p_border, p_direction, p_char = outs_dict[:4]
        p_score = p_score[0].numpy()
        p_border = p_border[0].numpy()
        p_direction = p_direction[0].numpy()
        p_char = p_char[0].numpy()
        src_h, src_w, ratio_h, ratio_w = shape_list[0]
        if self.valid_set != 'totaltext':
            is_curved = False
        else:
            is_curved = True
        instance_yxs_list = generate_pivot_list(
            p_score,
            p_char,
            p_direction,
            score_thresh=self.score_thresh,
            is_backbone=True,
            is_curved=is_curved)
        p_char = np.expand_dims(p_char, axis=0)
        p_char = paddle.to_tensor(p_char)
        char_seq_idx_set = []
        for i in range(len(instance_yxs_list)):
            gather_info_lod = paddle.to_tensor(instance_yxs_list[i])
            f_char_map = paddle.transpose(p_char, [0, 2, 3, 1])
            featyre_seq = paddle.gather_nd(f_char_map, gather_info_lod)
            featyre_seq = np.expand_dims(featyre_seq.numpy(), axis=0)
            t = len(featyre_seq[0])
            featyre_seq = paddle.to_tensor(featyre_seq)
            l = np.array([[t]]).astype(np.int64)
            length = paddle.to_tensor(l)
            seq_pred = paddle.fluid.layers.ctc_greedy_decoder(
                input=featyre_seq, blank=36, input_length=length)
            seq_pred1 = seq_pred[0].numpy().tolist()[0]
            seq_len = seq_pred[1].numpy()[0][0]
            temp_t = []
            for x in seq_pred1[:seq_len]:
                temp_t.append(x)
            char_seq_idx_set.append(temp_t)
        seq_strs = []
        for char_idx_set in char_seq_idx_set:
            pr_str = ''.join([self.Lexicon_Table[pos] for pos in char_idx_set])
            seq_strs.append(pr_str)

        poly_list = []
        keep_str_list = []
        all_point_list = []
        all_point_pair_list = []
        for yx_center_line, keep_str in zip(instance_yxs_list, seq_strs):
            if len(yx_center_line) == 1:
                print('the length of tcl point is less than 2, repeat')
                yx_center_line.append(yx_center_line[-1])

            # expand corresponding offset for total-text.
            offset_expand = 1.0
            if self.valid_set == 'totaltext':
                offset_expand = 1.2

            point_pair_list = []
            for batch_id, y, x in yx_center_line:
                offset = p_border[:, y, x].reshape(2, 2)
                if offset_expand != 1.0:
                    offset_length = np.linalg.norm(
                        offset, axis=1, keepdims=True)
                    expand_length = np.clip(
                        offset_length * (offset_expand - 1),
                        a_min=0.5,
                        a_max=3.0)
                    offset_detal = offset / offset_length * expand_length
                    offset = offset + offset_detal
                ori_yx = np.array([y, x], dtype=np.float32)
                point_pair = (ori_yx + offset)[:, ::-1] * 4.0 / np.array(
                    [ratio_w, ratio_h]).reshape(-1, 2)
                point_pair_list.append(point_pair)

                # for visualization
                all_point_list.append([
                    int(round(x * 4.0 / ratio_w)),
                    int(round(y * 4.0 / ratio_h))
                ])
                all_point_pair_list.append(point_pair.round().astype(np.int32)
                                           .tolist())

            # ndarray: (x, 2)
            detected_poly, pair_length_info = point_pair2poly(point_pair_list)
            print('expand along width. {}'.format(detected_poly.shape))
            detected_poly = expand_poly_along_width(
                detected_poly, shrink_ratio_of_width=0.2)
            detected_poly[:, 0] = np.clip(
                detected_poly[:, 0], a_min=0, a_max=src_w)
            detected_poly[:, 1] = np.clip(
                detected_poly[:, 1], a_min=0, a_max=src_h)

            if len(keep_str) < 2:
                print('--> too short, {}'.format(keep_str))
                continue

            keep_str_list.append(keep_str)
            if self.valid_set == 'partvgg':
                middle_point = len(detected_poly) // 2
                detected_poly = detected_poly[
                    [0, middle_point - 1, middle_point, -1], :]
                poly_list.append(detected_poly)
            elif self.valid_set == 'totaltext':
                poly_list.append(detected_poly)
            else:
                print('--> Not supported format.')
                exit(-1)

        data = {
            'points': poly_list,
            'strs': keep_str_list,
        }
        # visualization
        # if self.save_visualization:
        #     visualize_e2e_result(im_fn, poly_list, keep_str_list, src_im)
        #     visualize_point_result(im_fn, all_point_list, all_point_pair_list, src_im)
        # save detected boxes
        # txt_dir = (result_path[:-1] if result_path.endswith('/') else result_path) + '_txt_anno'
        # if not os.path.exists(txt_dir):
        #     os.makedirs(txt_dir)
        # res_file = os.path.join(txt_dir, '{}.txt'.format(im_prefix))
        # with open(res_file, 'w') as f:
        #     for i_box, box in enumerate(poly_list):
        #         seq_str = keep_str_list[i_box]
        #         box = np.round(box).astype('int32')
        #         box_str = ','.join(str(s) for s in (box.flatten().tolist()))
        #         f.write('{}\t{}\r\n'.format(box_str, seq_str))
        return data
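A note on the recognition step above: blank=36 makes the blank label the index one past the 36-entry Lexicon_Table, and ctc_greedy_decoder collapses repeats and drops blanks before the Lexicon_Table lookup builds pr_str. Below is a minimal pure-NumPy sketch of that greedy collapse with the same lexicon convention; the toy logits and helper name are illustrative, not part of this commit.

import numpy as np

LEXICON = [str(d) for d in range(10)] + [chr(c) for c in range(ord('A'), ord('Z') + 1)]
BLANK = len(LEXICON)  # 36, matching blank=36 in the decoder call above

def greedy_ctc_decode(logits):
    """logits: (T, 37) per-step class scores -> decoded string."""
    best = np.argmax(logits, axis=1)
    out = []
    prev = -1
    for idx in best:
        if idx != prev and idx != BLANK:  # merge repeats, drop blanks
            out.append(LEXICON[idx])
        prev = idx
    return ''.join(out)

# Toy sequence that should decode to "A1": 'A', 'A', blank, '1'.
logits = np.full((4, 37), -1.0)
logits[0, 10] = logits[1, 10] = 5.0  # index 10 is 'A'
logits[2, BLANK] = 5.0
logits[3, 1] = 5.0                   # index 1 is '1'
print(greedy_ctc_decode(logits))     # -> A1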
ppocr/postprocess/sast_postprocess.py
...
@@ -18,6 +18,7 @@ from __future__ import print_function
import os
import sys

__dir__ = os.path.dirname(__file__)
sys.path.append(__dir__)
sys.path.append(os.path.join(__dir__, '..'))
...
@@ -49,12 +50,12 @@ class SASTPostProcess(object):
        self.shrink_ratio_of_width = shrink_ratio_of_width
        self.expand_scale = expand_scale
        self.tcl_map_thresh = tcl_map_thresh

        # The C++ la-nms is faster, but it only supports Python 3.5.
        self.is_python35 = False
        if sys.version_info.major == 3 and sys.version_info.minor == 5:
            self.is_python35 = True

    def point_pair2poly(self, point_pair_list):
        """
        Transfer vertical point_pairs into poly point in clockwise.
...
@@ -66,31 +67,42 @@ class SASTPostProcess(object):
            point_list[idx] = point_pair[0]
            point_list[point_num - 1 - idx] = point_pair[1]
        return np.array(point_list).reshape(-1, 2)

    def shrink_quad_along_width(self, quad, begin_width_ratio=0., end_width_ratio=1.):
        """
        Generate shrink_quad_along_width.
        """
        ratio_pair = np.array(
            [[begin_width_ratio], [end_width_ratio]], dtype=np.float32)
        p0_1 = quad[0] + (quad[1] - quad[0]) * ratio_pair
        p3_2 = quad[3] + (quad[2] - quad[3]) * ratio_pair
        return np.array([p0_1[0], p0_1[1], p3_2[1], p3_2[0]])

    def expand_poly_along_width(self, poly, shrink_ratio_of_width=0.3):
        """
        expand poly along width.
        """
        point_num = poly.shape[0]
        left_quad = np.array(
            [poly[0], poly[1], poly[-2], poly[-1]], dtype=np.float32)
        left_ratio = -shrink_ratio_of_width * np.linalg.norm(left_quad[0] - left_quad[3]) / \
                     (np.linalg.norm(left_quad[0] - left_quad[1]) + 1e-6)
        left_quad_expand = self.shrink_quad_along_width(left_quad, left_ratio,
                                                        1.0)
        right_quad = np.array(
            [
                poly[point_num // 2 - 2], poly[point_num // 2 - 1],
                poly[point_num // 2], poly[point_num // 2 + 1]
            ],
            dtype=np.float32)
        right_ratio = 1.0 + \
                      shrink_ratio_of_width * np.linalg.norm(right_quad[0] - right_quad[3]) / \
                      (np.linalg.norm(right_quad[0] - right_quad[1]) + 1e-6)
        right_quad_expand = self.shrink_quad_along_width(right_quad, 0.0,
                                                         right_ratio)
        poly[0] = left_quad_expand[0]
        poly[-1] = left_quad_expand[-1]
        poly[point_num // 2 - 1] = right_quad_expand[1]
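shrink_quad_along_width is a straight linear interpolation along the top and bottom edges, and expand_poly_along_width reuses it with a negative begin ratio (or an end ratio above 1.0) so the quad grows outward instead of shrinking. A standalone NumPy check of that interpolation on a rectangle with values picked purely for illustration:

import numpy as np

def shrink_quad_along_width(quad, begin_width_ratio=0., end_width_ratio=1.):
    # Same interpolation as in SASTPostProcess: ratio 0 is the left edge, 1 the right.
    ratio_pair = np.array([[begin_width_ratio], [end_width_ratio]], dtype=np.float32)
    p0_1 = quad[0] + (quad[1] - quad[0]) * ratio_pair
    p3_2 = quad[3] + (quad[2] - quad[3]) * ratio_pair
    return np.array([p0_1[0], p0_1[1], p3_2[1], p3_2[0]])

# Axis-aligned quad (0,0)-(10,0)-(10,2)-(0,2), vertices ordered from top-left.
quad = np.array([[0., 0.], [10., 0.], [10., 2.], [0., 2.]], dtype=np.float32)
print(shrink_quad_along_width(quad, 0.1, 0.9))   # trims 10% off each side: x in [1, 9]
print(shrink_quad_along_width(quad, -0.1, 1.0))  # negative begin ratio expands left: x in [-1, 10]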
...
@@ -100,7 +112,7 @@ class SASTPostProcess(object):
    def restore_quad(self, tcl_map, tcl_map_thresh, tvo_map):
        """Restore quad."""
        xy_text = np.argwhere(tcl_map[:, :, 0] > tcl_map_thresh)
        xy_text = xy_text[:, ::-1]  # (n, 2)

        # Sort the text boxes via the y axis
        xy_text = xy_text[np.argsort(xy_text[:, 1])]
...
@@ -112,7 +124,7 @@ class SASTPostProcess(object):
        point_num = int(tvo_map.shape[-1] / 2)
        assert point_num == 4
        tvo_map = tvo_map[xy_text[:, 1], xy_text[:, 0], :]
        xy_text_tile = np.tile(xy_text, (1, point_num))  # (n, point_num * 2)
        quads = xy_text_tile - tvo_map

        return scores, quads, xy_text
...
@@ -121,14 +133,12 @@ class SASTPostProcess(object):
        """
        compute area of a quad.
        """
        edge = [(quad[1][0] - quad[0][0]) * (quad[1][1] + quad[0][1]),
                (quad[2][0] - quad[1][0]) * (quad[2][1] + quad[1][1]),
                (quad[3][0] - quad[2][0]) * (quad[3][1] + quad[2][1]),
                (quad[0][0] - quad[3][0]) * (quad[0][1] + quad[3][1])]
        return np.sum(edge) / 2.

    def nms(self, dets):
        if self.is_python35:
            import lanms
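The edge list above is the shoelace (trapezoid) form of polygon area: summing (x_{i+1} - x_i)(y_{i+1} + y_i) over the four edges gives twice the signed area, with the sign determined by vertex orientation. A quick standalone check on a 3x2 rectangle:

import numpy as np

def quad_area(quad):
    # Shoelace/trapezoid form used above; the sign encodes vertex orientation.
    edge = [(quad[1][0] - quad[0][0]) * (quad[1][1] + quad[0][1]),
            (quad[2][0] - quad[1][0]) * (quad[2][1] + quad[1][1]),
            (quad[3][0] - quad[2][0]) * (quad[3][1] + quad[2][1]),
            (quad[0][0] - quad[3][0]) * (quad[0][1] + quad[3][1])]
    return np.sum(edge) / 2.

rect = np.array([[0., 0.], [3., 0.], [3., 2.], [0., 2.]])
print(quad_area(rect))        # -6.0: this vertex order gives a negative signed area
print(quad_area(rect[::-1]))  # 6.0 for the opposite orientation; |area| is 3*2 either way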
...
@@ -141,7 +151,7 @@ class SASTPostProcess(object):
        """
        Cluster pixels in tcl_map based on quads.
        """
        instance_count = quads.shape[0] + 1  # contain background
        instance_label_map = np.zeros(tcl_map.shape[:2], dtype=np.int32)
        if instance_count == 1:
            return instance_count, instance_label_map
...
@@ -149,18 +159,19 @@ class SASTPostProcess(object):
        # predict text center
        xy_text = np.argwhere(tcl_map[:, :, 0] > tcl_map_thresh)
        n = xy_text.shape[0]
        xy_text = xy_text[:, ::-1]  # (n, 2)
        tco = tco_map[xy_text[:, 1], xy_text[:, 0], :]  # (n, 2)
        pred_tc = xy_text - tco

        # get gt text center
        m = quads.shape[0]
        gt_tc = np.mean(quads, axis=1)  # (m, 2)

        pred_tc_tile = np.tile(pred_tc[:, np.newaxis, :],
                               (1, m, 1))  # (n, m, 2)
        gt_tc_tile = np.tile(gt_tc[np.newaxis, :, :], (n, 1, 1))  # (n, m, 2)
        dist_mat = np.linalg.norm(pred_tc_tile - gt_tc_tile, axis=2)  # (n, m)
        xy_text_assign = np.argmin(dist_mat, axis=1) + 1  # (n,)

        instance_label_map[xy_text[:, 1], xy_text[:, 0]] = xy_text_assign
        return instance_count, instance_label_map
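The clustering step is a plain nearest-centroid assignment: each pixel's predicted text center is tiled against every quad center and labelled by the argmin of the pairwise distances, plus 1 so that label 0 stays background. A minimal NumPy sketch of that tiling trick with made-up points:

import numpy as np

pred_tc = np.array([[1., 1.], [9., 9.], [8., 8.]])   # n=3 predicted text centers
gt_tc = np.array([[0., 0.], [10., 10.]])             # m=2 quad centers
n, m = pred_tc.shape[0], gt_tc.shape[0]

pred_tile = np.tile(pred_tc[:, np.newaxis, :], (1, m, 1))  # (n, m, 2)
gt_tile = np.tile(gt_tc[np.newaxis, :, :], (n, 1, 1))      # (n, m, 2)
dist_mat = np.linalg.norm(pred_tile - gt_tile, axis=2)     # (n, m)
labels = np.argmin(dist_mat, axis=1) + 1                   # 0 is reserved for background
print(labels)  # [1 2 2]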
...
@@ -169,26 +180,47 @@ class SASTPostProcess(object):
        """
        Estimate sample points number.
        """
        eh = (np.linalg.norm(quad[0] - quad[3]) +
              np.linalg.norm(quad[1] - quad[2])) / 2.0
        ew = (np.linalg.norm(quad[0] - quad[1]) +
              np.linalg.norm(quad[2] - quad[3])) / 2.0

        dense_sample_pts_num = max(2, int(ew))
        dense_xy_center_line = xy_text[np.linspace(
            0,
            xy_text.shape[0] - 1,
            dense_sample_pts_num,
            endpoint=True,
            dtype=np.float32).astype(np.int32)]

        dense_xy_center_line_diff = dense_xy_center_line[
            1:] - dense_xy_center_line[:-1]
        estimate_arc_len = np.sum(
            np.linalg.norm(dense_xy_center_line_diff, axis=1))

        sample_pts_num = max(2, int(estimate_arc_len / eh))
        return sample_pts_num

    def detect_sast(self,
                    tcl_map,
                    tvo_map,
                    tbo_map,
                    tco_map,
                    ratio_w,
                    ratio_h,
                    src_w,
                    src_h,
                    shrink_ratio_of_width=0.3,
                    tcl_map_thresh=0.5,
                    offset_expand=1.0,
                    out_strid=4.0):
        """
        First resize the tcl_map, tvo_map and tbo_map to the input size, then restore the polys.
        """
        # restore quad
        scores, quads, xy_text = self.restore_quad(tcl_map, tcl_map_thresh,
                                                   tvo_map)
        dets = np.hstack((quads, scores)).astype(np.float32, copy=False)
        dets = self.nms(dets)
        if dets.shape[0] == 0:
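Both here and later in detect_sast, a fixed number of points is sampled from the sorted center-line pixels with np.linspace over the index range, and arc length is estimated from the distances between consecutive samples. A standalone illustration with an assumed straight diagonal center line:

import numpy as np

# Assumed toy center line: 50 pixels along a straight diagonal.
xy_text = np.stack([np.arange(50), np.arange(50)], axis=1).astype(np.float32)

sample_pts_num = 5
idx = np.linspace(0, xy_text.shape[0] - 1, sample_pts_num,
                  endpoint=True, dtype=np.float32).astype(np.int32)
center_line = xy_text[idx]
diff = center_line[1:] - center_line[:-1]
arc_len = np.sum(np.linalg.norm(diff, axis=1))
print(idx)                        # [ 0 12 24 36 49]
print(round(float(arc_len), 2))   # ~69.3, i.e. 49*sqrt(2)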
...
@@ -202,7 +234,8 @@ class SASTPostProcess(object):
        # instance segmentation
        # instance_count, instance_label_map = cv2.connectedComponents(tcl_map.astype(np.uint8), connectivity=8)
        instance_count, instance_label_map = self.cluster_by_quads_tco(
            tcl_map, tcl_map_thresh, quads, tco_map)

        # restore single poly with tcl instance.
        poly_list = []
...
@@ -212,10 +245,10 @@ class SASTPostProcess(object):
            q_area = quad_areas[instance_idx - 1]
            if q_area < 5:
                continue

            #
            len1 = float(np.linalg.norm(quad[0] - quad[1]))
            len2 = float(np.linalg.norm(quad[1] - quad[2]))
            min_len = min(len1, len2)
            if min_len < 3:
                continue
...
@@ -225,16 +258,18 @@ class SASTPostProcess(object):
                continue

            # filter low confidence instance
            xy_text_scores = tcl_map[xy_text[:, 1], xy_text[:, 0], 0]
            if np.sum(xy_text_scores) / quad_areas[instance_idx - 1] < 0.1:
                # if np.sum(xy_text_scores) / quad_areas[instance_idx - 1] < 0.05:
                continue

            # sort xy_text
            left_center_pt = np.array(
                [[(quad[0, 0] + quad[-1, 0]) / 2.0,
                  (quad[0, 1] + quad[-1, 1]) / 2.0]])  # (1, 2)
            right_center_pt = np.array(
                [[(quad[1, 0] + quad[2, 0]) / 2.0,
                  (quad[1, 1] + quad[2, 1]) / 2.0]])  # (1, 2)
            proj_unit_vec = (right_center_pt - left_center_pt) / \
                            (np.linalg.norm(right_center_pt - left_center_pt) + 1e-6)
            proj_value = np.sum(xy_text * proj_unit_vec, axis=1)
...
@@ -245,33 +280,45 @@ class SASTPostProcess(object):
                sample_pts_num = self.estimate_sample_pts_num(quad, xy_text)
            else:
                sample_pts_num = self.sample_pts_num
            xy_center_line = xy_text[np.linspace(
                0,
                xy_text.shape[0] - 1,
                sample_pts_num,
                endpoint=True,
                dtype=np.float32).astype(np.int32)]

            point_pair_list = []
            for x, y in xy_center_line:
                # get corresponding offset
                offset = tbo_map[y, x, :].reshape(2, 2)
                if offset_expand != 1.0:
                    offset_length = np.linalg.norm(
                        offset, axis=1, keepdims=True)
                    expand_length = np.clip(
                        offset_length * (offset_expand - 1),
                        a_min=0.5,
                        a_max=3.0)
                    offset_detal = offset / offset_length * expand_length
                    offset = offset + offset_detal

                # original point
                ori_yx = np.array([y, x], dtype=np.float32)
                point_pair = (ori_yx + offset)[:, ::-1] * out_strid / np.array(
                    [ratio_w, ratio_h]).reshape(-1, 2)
                point_pair_list.append(point_pair)

            # ndarray: (x, 2), expand poly along width
            detected_poly = self.point_pair2poly(point_pair_list)
            detected_poly = self.expand_poly_along_width(detected_poly,
                                                         shrink_ratio_of_width)
            detected_poly[:, 0] = np.clip(
                detected_poly[:, 0], a_min=0, a_max=src_w)
            detected_poly[:, 1] = np.clip(
                detected_poly[:, 1], a_min=0, a_max=src_h)
            poly_list.append(detected_poly)

        return poly_list

    def __call__(self, outs_dict, shape_list):
        score_list = outs_dict['f_score']
        border_list = outs_dict['f_border']
        tvo_list = outs_dict['f_tvo']
...
@@ -281,20 +328,28 @@ class SASTPostProcess(object):
        border_list = border_list.numpy()
        tvo_list = tvo_list.numpy()
        tco_list = tco_list.numpy()

        img_num = len(shape_list)
        poly_lists = []
        for ino in range(img_num):
            p_score = score_list[ino].transpose((1, 2, 0))
            p_border = border_list[ino].transpose((1, 2, 0))
            p_tvo = tvo_list[ino].transpose((1, 2, 0))
            p_tco = tco_list[ino].transpose((1, 2, 0))
            src_h, src_w, ratio_h, ratio_w = shape_list[ino]

            poly_list = self.detect_sast(
                p_score,
                p_tvo,
                p_border,
                p_tco,
                ratio_w,
                ratio_h,
                src_w,
                src_h,
                shrink_ratio_of_width=self.shrink_ratio_of_width,
                tcl_map_thresh=self.tcl_map_thresh,
                offset_expand=self.expand_scale)
            poly_lists.append({'points': np.array(poly_list)})

        return poly_lists
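For context, __call__ expects outs_dict with the four SAST head maps ('f_score', 'f_border', 'f_tvo', 'f_tco') as NCHW tensors and a shape_list of (src_h, src_w, ratio_h, ratio_w) per image. A hedged usage sketch; the channel counts, map size, and default constructor arguments here are assumptions for illustration, not values fixed by this diff:

import numpy as np
import paddle
from ppocr.postprocess.sast_postprocess import SASTPostProcess

post = SASTPostProcess()  # defaults as declared in __init__
h, w = 128, 128           # assumed output-map size (input size / 4)
outs_dict = {
    'f_score': paddle.to_tensor(np.random.rand(1, 1, h, w).astype('float32')),
    'f_border': paddle.to_tensor(np.random.rand(1, 4, h, w).astype('float32')),
    'f_tvo': paddle.to_tensor(np.random.rand(1, 8, h, w).astype('float32')),
    'f_tco': paddle.to_tensor(np.random.rand(1, 2, h, w).astype('float32')),
}
shape_list = [(512, 512, 0.25, 0.25)]   # src_h, src_w, ratio_h, ratio_w
results = post(outs_dict, shape_list)   # [{'points': array of polygons}, ...]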
ppocr/utils/e2e_metric/Deteval.py (new file, mode 100755)

from os import listdir
import os, sys
from scipy import io
import numpy as np

from ppocr.utils.e2e_metric.polygon_fast import iod, area_of_intersection, area
from tqdm import tqdm

try:
    # python2
    range = xrange
except Exception:
    # python3
    range = range

"""
Input format: y0,x0, ..... yn,xn. Each detection is separated by the end of line token ('\n')
"""

# if len(sys.argv) != 4:
#     print('\n usage: test.py pred_dir gt_dir savefile')
#     sys.exit()


def get_socre(gt_dict, pred_dict):
    # allInputs = listdir(input_dir)
    allInputs = 1

    def input_reading_mod(pred_dict, input):
        """This helper reads input from txt files"""
        det = []
        n = len(pred_dict)
        for i in range(n):
            points = pred_dict[i]['points']
            text = pred_dict[i]['text']
            # for i in range(len(points)):
            point = ",".join(map(str, points.reshape(-1, )))
            det.append([point, text])
        return det

    def gt_reading_mod(gt_dict, gt_id):
        """This helper reads groundtruths from mat files"""
        # gt_id = gt_id.split('.')[0]
        gt = []
        n = len(gt_dict)
        for i in range(n):
            points = gt_dict[i]['points'].tolist()
            h = len(points)
            text = gt_dict[i]['text']
            xx = [
                np.array(['x:'], dtype='<U2'), 0,
                np.array(['y:'], dtype='<U2'), 0,
                np.array(['#'], dtype='<U1'),
                np.array(['#'], dtype='<U1')
            ]
            t_x, t_y = [], []
            for j in range(h):
                t_x.append(points[j][0])
                t_y.append(points[j][1])
            xx[1] = np.array([t_x], dtype='int16')
            xx[3] = np.array([t_y], dtype='int16')
            if text != "":
                xx[4] = np.array([text], dtype='U{}'.format(len(text)))
                xx[5] = np.array(['c'], dtype='<U1')
            gt.append(xx)
        return gt

    def detection_filtering(detections, groundtruths, threshold=0.5):
        for gt_id, gt in enumerate(groundtruths):
            print("liushanshan gt[1] = {}".format(gt[1]))
            print("liushanshan gt[2] = {}".format(gt[2]))
            print("liushanshan gt[3] = {}".format(gt[3]))
            print("liushanshan gt[4] = {}".format(gt[4]))
            print("liushanshan gt[5] = {}".format(gt[5]))
            if (gt[5] == '#') and (gt[1].shape[1] > 1):
                gt_x = list(map(int, np.squeeze(gt[1])))
                gt_y = list(map(int, np.squeeze(gt[3])))
                for det_id, detection in enumerate(detections):
                    detection_orig = detection
                    detection = [float(x) for x in detection[0].split(',')]
                    # detection = detection.split(',')
                    detection = list(map(int, detection))
                    det_x = detection[0::2]
                    det_y = detection[1::2]
                    det_gt_iou = iod(det_x, det_y, gt_x, gt_y)
                    if det_gt_iou > threshold:
                        detections[det_id] = []

                detections[:] = [item for item in detections if item != []]
        return detections

    def sigma_calculation(det_x, det_y, gt_x, gt_y):
        """
        sigma = inter_area / gt_area
        """
        # print(area_of_intersection(det_x, det_y, gt_x, gt_y))
        return np.round(
            (area_of_intersection(det_x, det_y, gt_x, gt_y) /
             area(gt_x, gt_y)), 2)

    def tau_calculation(det_x, det_y, gt_x, gt_y):
        """
        tau = inter_area / det_area
        """
        # print("liushanshan det_x {}".format(det_x))
        # print("liushanshan det_y {}".format(det_y))
        # print("liushanshan area {}".format(area(det_x, det_y)))
        # print("liushanshan tau = {}".format(np.round((area_of_intersection(det_x, det_y, gt_x, gt_y) / area(det_x, det_y)), 2)))
        if area(det_x, det_y) == 0.0:
            return 0
        return np.round(
            (area_of_intersection(det_x, det_y, gt_x, gt_y) /
             area(det_x, det_y)), 2)
    ##############################Initialization###################################
    global_tp = 0
    global_fp = 0
    global_fn = 0
    global_sigma = []
    global_tau = []
    tr = 0.7
    tp = 0.6
    fsc_k = 0.8
    k = 2
    global_pred_str = []
    global_gt_str = []
    ###############################################################################

    for input_id in range(allInputs):
        if (input_id != '.DS_Store') and (input_id != 'Pascal_result.txt') and (
                input_id != 'Pascal_result_curved.txt') and (
                    input_id != 'Pascal_result_non_curved.txt') and (
                        input_id != 'Deteval_result.txt') and (
                            input_id != 'Deteval_result_curved.txt') \
                and (input_id != 'Deteval_result_non_curved.txt'):
            print(input_id)
            detections = input_reading_mod(pred_dict, input_id)
            # print("liushanshan detections = {}".format(detections))
            groundtruths = gt_reading_mod(gt_dict, input_id)
            detections = detection_filtering(
                detections,
                groundtruths)  # filters detections overlapping with DC area
            dc_id = []
            for i in range(len(groundtruths)):
                if groundtruths[i][5] == '#':
                    dc_id.append(i)
            cnt = 0
            for a in dc_id:
                num = a - cnt
                del groundtruths[num]
                cnt += 1

            local_sigma_table = np.zeros((len(groundtruths), len(detections)))
            local_tau_table = np.zeros((len(groundtruths), len(detections)))
            local_pred_str = {}
            local_gt_str = {}

            for gt_id, gt in enumerate(groundtruths):
                if len(detections) > 0:
                    for det_id, detection in enumerate(detections):
                        detection_orig = detection
                        detection = [float(x) for x in detection[0].split(',')]
                        detection = list(map(int, detection))
                        pred_seq_str = detection_orig[1].strip()
                        det_x = detection[0::2]
                        det_y = detection[1::2]
                        gt_x = list(map(int, np.squeeze(gt[1])))
                        gt_y = list(map(int, np.squeeze(gt[3])))
                        gt_seq_str = str(gt[4].tolist()[0])

                        local_sigma_table[gt_id, det_id] = sigma_calculation(
                            det_x, det_y, gt_x, gt_y)
                        local_tau_table[gt_id, det_id] = tau_calculation(
                            det_x, det_y, gt_x, gt_y)
                        local_pred_str[det_id] = pred_seq_str
                        local_gt_str[gt_id] = gt_seq_str

            global_sigma.append(local_sigma_table)
            global_tau.append(local_tau_table)
            global_pred_str.append(local_pred_str)
            global_gt_str.append(local_gt_str)
            print("liushanshan global_pred_str = {}".format(global_pred_str))
            print("liushanshan global_gt_str = {}".format(global_gt_str))

    global_accumulative_recall = 0
    global_accumulative_precision = 0
    total_num_gt = 0
    total_num_det = 0
    hit_str_count = 0
    hit_count = 0

    def one_to_one(local_sigma_table, local_tau_table,
                   local_accumulative_recall, local_accumulative_precision,
                   global_accumulative_recall, global_accumulative_precision,
                   gt_flag, det_flag, idy):
        hit_str_num = 0
        for gt_id in range(num_gt):
            gt_matching_qualified_sigma_candidates = np.where(
                local_sigma_table[gt_id, :] > tr)
            gt_matching_num_qualified_sigma_candidates = gt_matching_qualified_sigma_candidates[
                0].shape[0]
            gt_matching_qualified_tau_candidates = np.where(
                local_tau_table[gt_id, :] > tp)
            gt_matching_num_qualified_tau_candidates = gt_matching_qualified_tau_candidates[
                0].shape[0]

            det_matching_qualified_sigma_candidates = np.where(
                local_sigma_table[:, gt_matching_qualified_sigma_candidates[
                    0]] > tr)
            det_matching_num_qualified_sigma_candidates = det_matching_qualified_sigma_candidates[
                0].shape[0]
            det_matching_qualified_tau_candidates = np.where(
                local_tau_table[:, gt_matching_qualified_tau_candidates[
                    0]] > tp)
            det_matching_num_qualified_tau_candidates = det_matching_qualified_tau_candidates[
                0].shape[0]

            if (gt_matching_num_qualified_sigma_candidates == 1) and (
                    gt_matching_num_qualified_tau_candidates == 1) and \
                    (det_matching_num_qualified_sigma_candidates == 1) and (
                        det_matching_num_qualified_tau_candidates == 1):
                global_accumulative_recall = global_accumulative_recall + 1.0
                global_accumulative_precision = global_accumulative_precision + 1.0
                local_accumulative_recall = local_accumulative_recall + 1.0
                local_accumulative_precision = local_accumulative_precision + 1.0

                gt_flag[0, gt_id] = 1
                matched_det_id = np.where(local_sigma_table[gt_id, :] > tr)
                # recg start
                print("liushanshan one to one det_id = {}".format(
                    matched_det_id))
                print("liushanshan one to one gt_id = {}".format(gt_id))
                gt_str_cur = global_gt_str[idy][gt_id]
                pred_str_cur = global_pred_str[idy][matched_det_id[0].tolist()[
                    0]]
                print("liushanshan one to one gt_str_cur = {}".format(
                    gt_str_cur))
                print("liushanshan one to one pred_str_cur = {}".format(
                    pred_str_cur))
                if pred_str_cur == gt_str_cur:
                    hit_str_num += 1
                else:
                    if pred_str_cur.lower() == gt_str_cur.lower():
                        hit_str_num += 1
                # recg end
                det_flag[0, matched_det_id] = 1
        return local_accumulative_recall, local_accumulative_precision, \
            global_accumulative_recall, global_accumulative_precision, \
            gt_flag, det_flag, hit_str_num

    def one_to_many(local_sigma_table, local_tau_table,
                    local_accumulative_recall, local_accumulative_precision,
                    global_accumulative_recall, global_accumulative_precision,
                    gt_flag, det_flag, idy):
        hit_str_num = 0
        for gt_id in range(num_gt):
            # skip the following if the groundtruth was matched
            if gt_flag[0, gt_id] > 0:
                continue

            non_zero_in_sigma = np.where(local_sigma_table[gt_id, :] > 0)
            num_non_zero_in_sigma = non_zero_in_sigma[0].shape[0]

            if num_non_zero_in_sigma >= k:
                #### search for all detections that overlap with this groundtruth
                qualified_tau_candidates = np.where((local_tau_table[
                    gt_id, :] >= tp) & (det_flag[0, :] == 0))
                num_qualified_tau_candidates = qualified_tau_candidates[
                    0].shape[0]

                if num_qualified_tau_candidates == 1:
                    if ((local_tau_table[gt_id, qualified_tau_candidates] >= tp)
                            and
                        (local_sigma_table[gt_id, qualified_tau_candidates] >=
                         tr)):
                        # became an one-to-one case
                        global_accumulative_recall = global_accumulative_recall + 1.0
                        global_accumulative_precision = global_accumulative_precision + 1.0
                        local_accumulative_recall = local_accumulative_recall + 1.0
                        local_accumulative_precision = local_accumulative_precision + 1.0

                        gt_flag[0, gt_id] = 1
                        det_flag[0, qualified_tau_candidates] = 1
                        # recg start
                        print("liushanshan one to many det_id = {}".format(
                            qualified_tau_candidates))
                        print("liushanshan one to many gt_id = {}".format(
                            gt_id))
                        gt_str_cur = global_gt_str[idy][gt_id]
                        pred_str_cur = global_pred_str[idy][
                            qualified_tau_candidates[0].tolist()[0]]
                        print("liushanshan one to many gt_str_cur = {}".format(
                            gt_str_cur))
                        print("liushanshan one to many pred_str_cur = {}".
                              format(pred_str_cur))
                        if pred_str_cur == gt_str_cur:
                            hit_str_num += 1
                        else:
                            if pred_str_cur.lower() == gt_str_cur.lower():
                                hit_str_num += 1
                        # recg end
                elif (np.sum(local_sigma_table[gt_id, qualified_tau_candidates])
                      >= tr):
                    gt_flag[0, gt_id] = 1
                    det_flag[0, qualified_tau_candidates] = 1
                    # recg start
                    print("liushanshan one to many det_id = {}".format(
                        qualified_tau_candidates))
                    print("liushanshan one to many gt_id = {}".format(gt_id))
                    gt_str_cur = global_gt_str[idy][gt_id]
                    pred_str_cur = global_pred_str[idy][
                        qualified_tau_candidates[0].tolist()[0]]
                    print("liushanshan one to many gt_str_cur = {}".format(
                        gt_str_cur))
                    print("liushanshan one to many pred_str_cur = {}".format(
                        pred_str_cur))
                    if pred_str_cur == gt_str_cur:
                        hit_str_num += 1
                    else:
                        if pred_str_cur.lower() == gt_str_cur.lower():
                            hit_str_num += 1
                    # recg end

                    global_accumulative_recall = global_accumulative_recall + fsc_k
                    global_accumulative_precision = global_accumulative_precision + num_qualified_tau_candidates * fsc_k

                    local_accumulative_recall = local_accumulative_recall + fsc_k
                    local_accumulative_precision = local_accumulative_precision + num_qualified_tau_candidates * fsc_k

        return local_accumulative_recall, local_accumulative_precision, \
            global_accumulative_recall, global_accumulative_precision, \
            gt_flag, det_flag, hit_str_num

    def many_to_one(local_sigma_table, local_tau_table,
                    local_accumulative_recall, local_accumulative_precision,
                    global_accumulative_recall, global_accumulative_precision,
                    gt_flag, det_flag, idy):
        hit_str_num = 0
        for det_id in range(num_det):
            # skip the following if the detection was matched
            if det_flag[0, det_id] > 0:
                continue

            non_zero_in_tau = np.where(local_tau_table[:, det_id] > 0)
            num_non_zero_in_tau = non_zero_in_tau[0].shape[0]

            if num_non_zero_in_tau >= k:
                #### search for all groundtruths that overlap with this detection
                qualified_sigma_candidates = np.where((
                    local_sigma_table[:, det_id] >= tp) & (gt_flag[0, :] == 0))
                num_qualified_sigma_candidates = qualified_sigma_candidates[
                    0].shape[0]

                if num_qualified_sigma_candidates == 1:
                    if ((local_tau_table[qualified_sigma_candidates, det_id] >=
                         tp) and
                        (local_sigma_table[qualified_sigma_candidates, det_id]
                         >= tr)):
                        # became an one-to-one case
                        global_accumulative_recall = global_accumulative_recall + 1.0
                        global_accumulative_precision = global_accumulative_precision + 1.0
                        local_accumulative_recall = local_accumulative_recall + 1.0
                        local_accumulative_precision = local_accumulative_precision + 1.0

                        gt_flag[0, qualified_sigma_candidates] = 1
                        det_flag[0, det_id] = 1
                        # recg start
                        print("liushanshan many to one det_id = {}".format(
                            det_id))
                        print("liushanshan many to one gt_id = {}".format(
                            qualified_sigma_candidates))
                        pred_str_cur = global_pred_str[idy][det_id]
                        gt_len = len(qualified_sigma_candidates[0])
                        for idx in range(gt_len):
                            ele_gt_id = qualified_sigma_candidates[0].tolist()[
                                idx]
                            if ele_gt_id not in global_gt_str[idy]:
                                continue
                            gt_str_cur = global_gt_str[idy][ele_gt_id]
                            print("liushanshan many to one gt_str_cur = {}".
                                  format(gt_str_cur))
                            print("liushanshan many to one pred_str_cur = {}".
                                  format(pred_str_cur))
                            if pred_str_cur == gt_str_cur:
                                hit_str_num += 1
                                break
                            else:
                                if pred_str_cur.lower() == gt_str_cur.lower():
                                    hit_str_num += 1
                                break
                        # recg end
                elif (np.sum(local_tau_table[qualified_sigma_candidates,
                                             det_id]) >= tp):
                    det_flag[0, det_id] = 1
                    gt_flag[0, qualified_sigma_candidates] = 1
                    # recg start
                    print("liushanshan many to one det_id = {}".format(det_id))
                    print("liushanshan many to one gt_id = {}".format(
                        qualified_sigma_candidates))
                    pred_str_cur = global_pred_str[idy][det_id]
                    gt_len = len(qualified_sigma_candidates[0])
                    for idx in range(gt_len):
                        ele_gt_id = qualified_sigma_candidates[0].tolist()[idx]
                        if ele_gt_id not in global_gt_str[idy]:
                            continue
                        gt_str_cur = global_gt_str[idy][ele_gt_id]
                        print("liushanshan many to one gt_str_cur = {}".format(
                            gt_str_cur))
                        print("liushanshan many to one pred_str_cur = {}".
                              format(pred_str_cur))
                        if pred_str_cur == gt_str_cur:
                            hit_str_num += 1
                            break
                        else:
                            if pred_str_cur.lower() == gt_str_cur.lower():
                                hit_str_num += 1
                                break
                            else:
                                print('no match')
                    # recg end

                    global_accumulative_recall = global_accumulative_recall + num_qualified_sigma_candidates * fsc_k
                    global_accumulative_precision = global_accumulative_precision + fsc_k

                    local_accumulative_recall = local_accumulative_recall + num_qualified_sigma_candidates * fsc_k
                    local_accumulative_precision = local_accumulative_precision + fsc_k

        return local_accumulative_recall, local_accumulative_precision, \
            global_accumulative_recall, global_accumulative_precision, \
            gt_flag, det_flag, hit_str_num

    single_data = {}
    for idx in range(len(global_sigma)):
        # print(allInputs[idx])
        local_sigma_table = global_sigma[idx]
        local_tau_table = global_tau[idx]

        num_gt = local_sigma_table.shape[0]
        num_det = local_sigma_table.shape[1]

        total_num_gt = total_num_gt + num_gt
        total_num_det = total_num_det + num_det

        local_accumulative_recall = 0
        local_accumulative_precision = 0
        gt_flag = np.zeros((1, num_gt))
        det_flag = np.zeros((1, num_det))

        #######first check for one-to-one case##########
        local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, \
            gt_flag, det_flag, hit_str_num = one_to_one(
                local_sigma_table, local_tau_table, local_accumulative_recall,
                local_accumulative_precision, global_accumulative_recall,
                global_accumulative_precision, gt_flag, det_flag, idx)
        hit_str_count += hit_str_num

        #######then check for one-to-many case##########
        local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, \
            gt_flag, det_flag, hit_str_num = one_to_many(
                local_sigma_table, local_tau_table, local_accumulative_recall,
                local_accumulative_precision, global_accumulative_recall,
                global_accumulative_precision, gt_flag, det_flag, idx)
        hit_str_count += hit_str_num

        #######then check for many-to-one case##########
        local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, \
            gt_flag, det_flag, hit_str_num = many_to_one(
                local_sigma_table, local_tau_table, local_accumulative_recall,
                local_accumulative_precision, global_accumulative_recall,
                global_accumulative_precision, gt_flag, det_flag, idx)
        hit_str_count += hit_str_num

        # fid = open(fid_path, 'a+')
        try:
            local_precision = local_accumulative_precision / num_det
        except ZeroDivisionError:
            local_precision = 0

        try:
            local_recall = local_accumulative_recall / num_gt
        except ZeroDivisionError:
            local_recall = 0

        try:
            local_f_score = 2 * local_precision * local_recall / (
                local_precision + local_recall)
        except ZeroDivisionError:
            local_f_score = 0

        # temp = ('%s: Recall=%.4f, Precision=%.4f, f_score=%.4f\n' % (
        #     allInputs[idx], local_recall, local_precision, local_f_score))
        single_data['sigma'] = global_sigma
        single_data['global_tau'] = global_tau
        single_data['global_pred_str'] = global_pred_str
        single_data['global_gt_str'] = global_gt_str
        single_data["recall"] = local_recall
        single_data['precision'] = local_precision
        single_data['f_score'] = local_f_score
    return single_data
def combine_results(all_data):
    tr = 0.7
    tp = 0.6
    fsc_k = 0.8
    k = 2

    global_sigma = []
    global_tau = []
    global_pred_str = []
    global_gt_str = []
    for data in all_data:
        global_sigma.append(data['sigma'][0])
        global_tau.append(data['global_tau'][0])
        global_pred_str.append(data['global_pred_str'][0])
        global_gt_str.append(data['global_gt_str'][0])

    global_accumulative_recall = 0
    global_accumulative_precision = 0
    total_num_gt = 0
    total_num_det = 0
    hit_str_count = 0
    hit_count = 0

    def one_to_one(local_sigma_table, local_tau_table,
                   local_accumulative_recall, local_accumulative_precision,
                   global_accumulative_recall, global_accumulative_precision,
                   gt_flag, det_flag, idy):
        hit_str_num = 0
        for gt_id in range(num_gt):
            gt_matching_qualified_sigma_candidates = np.where(
                local_sigma_table[gt_id, :] > tr)
            gt_matching_num_qualified_sigma_candidates = gt_matching_qualified_sigma_candidates[
                0].shape[0]
            gt_matching_qualified_tau_candidates = np.where(
                local_tau_table[gt_id, :] > tp)
            gt_matching_num_qualified_tau_candidates = gt_matching_qualified_tau_candidates[
                0].shape[0]

            det_matching_qualified_sigma_candidates = np.where(
                local_sigma_table[:, gt_matching_qualified_sigma_candidates[
                    0]] > tr)
            det_matching_num_qualified_sigma_candidates = det_matching_qualified_sigma_candidates[
                0].shape[0]
            det_matching_qualified_tau_candidates = np.where(
                local_tau_table[:, gt_matching_qualified_tau_candidates[
                    0]] > tp)
            det_matching_num_qualified_tau_candidates = det_matching_qualified_tau_candidates[
                0].shape[0]

            if (gt_matching_num_qualified_sigma_candidates == 1) and (
                    gt_matching_num_qualified_tau_candidates == 1) and \
                    (det_matching_num_qualified_sigma_candidates == 1) and (
                        det_matching_num_qualified_tau_candidates == 1):
                global_accumulative_recall = global_accumulative_recall + 1.0
                global_accumulative_precision = global_accumulative_precision + 1.0
                local_accumulative_recall = local_accumulative_recall + 1.0
                local_accumulative_precision = local_accumulative_precision + 1.0

                gt_flag[0, gt_id] = 1
                matched_det_id = np.where(local_sigma_table[gt_id, :] > tr)
                # recg start
                print("liushanshan one to one det_id = {}".format(
                    matched_det_id))
                print("liushanshan one to one gt_id = {}".format(gt_id))
                gt_str_cur = global_gt_str[idy][gt_id]
                pred_str_cur = global_pred_str[idy][matched_det_id[0].tolist()[
                    0]]
                print("liushanshan one to one gt_str_cur = {}".format(
                    gt_str_cur))
                print("liushanshan one to one pred_str_cur = {}".format(
                    pred_str_cur))
                if pred_str_cur == gt_str_cur:
                    hit_str_num += 1
                else:
                    if pred_str_cur.lower() == gt_str_cur.lower():
                        hit_str_num += 1
                # recg end
                det_flag[0, matched_det_id] = 1
        return local_accumulative_recall, local_accumulative_precision, \
            global_accumulative_recall, global_accumulative_precision, \
            gt_flag, det_flag, hit_str_num

    def one_to_many(local_sigma_table, local_tau_table,
                    local_accumulative_recall, local_accumulative_precision,
                    global_accumulative_recall, global_accumulative_precision,
                    gt_flag, det_flag, idy):
        hit_str_num = 0
        for gt_id in range(num_gt):
            # skip the following if the groundtruth was matched
            if gt_flag[0, gt_id] > 0:
                continue

            non_zero_in_sigma = np.where(local_sigma_table[gt_id, :] > 0)
            num_non_zero_in_sigma = non_zero_in_sigma[0].shape[0]

            if num_non_zero_in_sigma >= k:
                #### search for all detections that overlap with this groundtruth
                qualified_tau_candidates = np.where((local_tau_table[
                    gt_id, :] >= tp) & (det_flag[0, :] == 0))
                num_qualified_tau_candidates = qualified_tau_candidates[
                    0].shape[0]

                if num_qualified_tau_candidates == 1:
                    if ((local_tau_table[gt_id, qualified_tau_candidates] >= tp)
                            and
                        (local_sigma_table[gt_id, qualified_tau_candidates] >=
                         tr)):
                        # became an one-to-one case
                        global_accumulative_recall = global_accumulative_recall + 1.0
                        global_accumulative_precision = global_accumulative_precision + 1.0
                        local_accumulative_recall = local_accumulative_recall + 1.0
                        local_accumulative_precision = local_accumulative_precision + 1.0

                        gt_flag[0, gt_id] = 1
                        det_flag[0, qualified_tau_candidates] = 1
                        # recg start
                        print("liushanshan one to many det_id = {}".format(
                            qualified_tau_candidates))
                        print("liushanshan one to many gt_id = {}".format(
                            gt_id))
                        gt_str_cur = global_gt_str[idy][gt_id]
                        pred_str_cur = global_pred_str[idy][
                            qualified_tau_candidates[0].tolist()[0]]
                        print("liushanshan one to many gt_str_cur = {}".format(
                            gt_str_cur))
                        print("liushanshan one to many pred_str_cur = {}".
                              format(pred_str_cur))
                        if pred_str_cur == gt_str_cur:
                            hit_str_num += 1
                        else:
                            if pred_str_cur.lower() == gt_str_cur.lower():
                                hit_str_num += 1
                        # recg end
                elif (np.sum(local_sigma_table[gt_id, qualified_tau_candidates])
                      >= tr):
                    gt_flag[0, gt_id] = 1
                    det_flag[0, qualified_tau_candidates] = 1
                    # recg start
                    print("liushanshan one to many det_id = {}".format(
                        qualified_tau_candidates))
                    print("liushanshan one to many gt_id = {}".format(gt_id))
                    gt_str_cur = global_gt_str[idy][gt_id]
                    pred_str_cur = global_pred_str[idy][
                        qualified_tau_candidates[0].tolist()[0]]
                    print("liushanshan one to many gt_str_cur = {}".format(
                        gt_str_cur))
                    print("liushanshan one to many pred_str_cur = {}".format(
                        pred_str_cur))
                    if pred_str_cur == gt_str_cur:
                        hit_str_num += 1
                    else:
                        if pred_str_cur.lower() == gt_str_cur.lower():
                            hit_str_num += 1
                    # recg end

                    global_accumulative_recall = global_accumulative_recall + fsc_k
                    global_accumulative_precision = global_accumulative_precision + num_qualified_tau_candidates * fsc_k

                    local_accumulative_recall = local_accumulative_recall + fsc_k
                    local_accumulative_precision = local_accumulative_precision + num_qualified_tau_candidates * fsc_k

        return local_accumulative_recall, local_accumulative_precision, \
            global_accumulative_recall, global_accumulative_precision, \
            gt_flag, det_flag, hit_str_num

    def many_to_one(local_sigma_table, local_tau_table,
                    local_accumulative_recall, local_accumulative_precision,
                    global_accumulative_recall, global_accumulative_precision,
                    gt_flag, det_flag, idy):
        hit_str_num = 0
        for det_id in range(num_det):
            # skip the following if the detection was matched
            if det_flag[0, det_id] > 0:
                continue

            non_zero_in_tau = np.where(local_tau_table[:, det_id] > 0)
            num_non_zero_in_tau = non_zero_in_tau[0].shape[0]

            if num_non_zero_in_tau >= k:
                #### search for all groundtruths that overlap with this detection
                qualified_sigma_candidates = np.where((
                    local_sigma_table[:, det_id] >= tp) & (gt_flag[0, :] == 0))
                num_qualified_sigma_candidates = qualified_sigma_candidates[
                    0].shape[0]

                if num_qualified_sigma_candidates == 1:
                    if ((local_tau_table[qualified_sigma_candidates, det_id] >=
                         tp) and
                        (local_sigma_table[qualified_sigma_candidates, det_id]
                         >= tr)):
                        # became an one-to-one case
                        global_accumulative_recall = global_accumulative_recall + 1.0
                        global_accumulative_precision = global_accumulative_precision + 1.0
                        local_accumulative_recall = local_accumulative_recall + 1.0
                        local_accumulative_precision = local_accumulative_precision + 1.0

                        gt_flag[0, qualified_sigma_candidates] = 1
                        det_flag[0, det_id] = 1
                        # recg start
                        print("liushanshan many to one det_id = {}".format(
                            det_id))
                        print("liushanshan many to one gt_id = {}".format(
                            qualified_sigma_candidates))
                        pred_str_cur = global_pred_str[idy][det_id]
                        gt_len = len(qualified_sigma_candidates[0])
                        for idx in range(gt_len):
                            ele_gt_id = qualified_sigma_candidates[0].tolist()[
                                idx]
                            if ele_gt_id not in global_gt_str[idy]:
                                continue
                            gt_str_cur = global_gt_str[idy][ele_gt_id]
                            print("liushanshan many to one gt_str_cur = {}".
                                  format(gt_str_cur))
                            print("liushanshan many to one pred_str_cur = {}".
                                  format(pred_str_cur))
                            if pred_str_cur == gt_str_cur:
                                hit_str_num += 1
                                break
                            else:
                                if pred_str_cur.lower() == gt_str_cur.lower():
                                    hit_str_num += 1
                                break
                        # recg end
                elif (np.sum(local_tau_table[qualified_sigma_candidates,
                                             det_id]) >= tp):
                    det_flag[0, det_id] = 1
                    gt_flag[0, qualified_sigma_candidates] = 1
                    # recg start
                    print("liushanshan many to one det_id = {}".format(det_id))
                    print("liushanshan many to one gt_id = {}".format(
                        qualified_sigma_candidates))
                    pred_str_cur = global_pred_str[idy][det_id]
                    gt_len = len(qualified_sigma_candidates[0])
                    for idx in range(gt_len):
                        ele_gt_id = qualified_sigma_candidates[0].tolist()[idx]
                        if ele_gt_id not in global_gt_str[idy]:
                            continue
                        gt_str_cur = global_gt_str[idy][ele_gt_id]
                        print("liushanshan many to one gt_str_cur = {}".format(
                            gt_str_cur))
                        print("liushanshan many to one pred_str_cur = {}".
                              format(pred_str_cur))
                        if pred_str_cur == gt_str_cur:
                            hit_str_num += 1
                            break
                        else:
                            if pred_str_cur.lower() == gt_str_cur.lower():
                                hit_str_num += 1
                                break
                            else:
                                print('no match')
                    # recg end

                    global_accumulative_recall = global_accumulative_recall + num_qualified_sigma_candidates * fsc_k
                    global_accumulative_precision = global_accumulative_precision + fsc_k

                    local_accumulative_recall = local_accumulative_recall + num_qualified_sigma_candidates * fsc_k
                    local_accumulative_precision = local_accumulative_precision + fsc_k

        return local_accumulative_recall, local_accumulative_precision, \
            global_accumulative_recall, global_accumulative_precision, \
            gt_flag, det_flag, hit_str_num

    for idx in range(len(global_sigma)):
        local_sigma_table = np.array(global_sigma[idx])
        local_tau_table = global_tau[idx]

        num_gt = local_sigma_table.shape[0]
        num_det = local_sigma_table.shape[1]

        total_num_gt = total_num_gt + num_gt
        total_num_det = total_num_det + num_det

        local_accumulative_recall = 0
        local_accumulative_precision = 0
        gt_flag = np.zeros((1, num_gt))
        det_flag = np.zeros((1, num_det))

        #######first check for one-to-one case##########
        local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, \
            gt_flag, det_flag, hit_str_num = one_to_one(
                local_sigma_table, local_tau_table, local_accumulative_recall,
                local_accumulative_precision, global_accumulative_recall,
                global_accumulative_precision, gt_flag, det_flag, idx)
        hit_str_count += hit_str_num

        #######then check for one-to-many case##########
        local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, \
            gt_flag, det_flag, hit_str_num = one_to_many(
                local_sigma_table, local_tau_table, local_accumulative_recall,
                local_accumulative_precision, global_accumulative_recall,
                global_accumulative_precision, gt_flag, det_flag, idx)
        hit_str_count += hit_str_num

        #######then check for many-to-one case##########
        local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, \
            gt_flag, det_flag, hit_str_num = many_to_one(
                local_sigma_table, local_tau_table, local_accumulative_recall,
                local_accumulative_precision, global_accumulative_recall,
                global_accumulative_precision, gt_flag, det_flag, idx)

    try:
        recall = global_accumulative_recall / total_num_gt
    except ZeroDivisionError:
        recall = 0

    try:
        precision = global_accumulative_precision / total_num_det
    except ZeroDivisionError:
        precision = 0

    try:
        f_score = 2 * precision * recall / (precision + recall)
    except ZeroDivisionError:
        f_score = 0

    try:
        seqerr = 1 - float(hit_str_count) / global_accumulative_recall
    except ZeroDivisionError:
        seqerr = 1

    try:
        recall_e2e = float(hit_str_count) / total_num_gt
    except ZeroDivisionError:
        recall_e2e = 0

    try:
        precision_e2e = float(hit_str_count) / total_num_det
    except ZeroDivisionError:
        precision_e2e = 0

    try:
        f_score_e2e = 2 * precision_e2e * recall_e2e / (
            precision_e2e + recall_e2e)
    except ZeroDivisionError:
        f_score_e2e = 0

    final = {
        'total_num_gt': total_num_gt,
        'total_num_det': total_num_det,
        'global_accumulative_recall': global_accumulative_recall,
        'hit_str_count': hit_str_count,
        'recall': recall,
        'precision': precision,
        'f_score': f_score,
        'seqerr': seqerr,
        'recall_e2e': recall_e2e,
        'precision_e2e': precision_e2e,
        'f_score_e2e': f_score_e2e
    }
    return final
# a = [1526, 642, 1565, 629, 1579, 627, 1593, 625, 1607, 623, 1620, 622, 1634, 620, 1659, 620, 1654, 681, 1631, 680, 1618,
# 681, 1606, 681, 1594, 681, 1584, 682, 1573, 685, 1542, 694]
# gt_dict = [{'points': np.array(a).reshape(-1, 2), 'text': 'MILK'}]
# pred_dict = [{'points': np.array(a), 'text': 'ccc'},
# {'points': np.array(a), 'text': 'ccf'}]
# result = []
# for i in range(2):
# result.append(get_socre(gt_dict, pred_dict))
# print(111)
# a = combine_results(result)
# print(a)
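# A hedged usage sketch building on the commented example above (the exact
# numbers below are assumptions, not part of this commit): get_socre() turns
# one image's gt/pred dicts into sigma/tau tables plus the gt and predicted
# strings, and combine_results() folds a list of such results into detection
# and end-to-end metrics.
#
#   quad = np.array([0, 0, 100, 0, 100, 30, 0, 30])
#   gt_dict = [{'points': quad.reshape(-1, 2), 'text': 'MILK'}]
#   pred_dict = [{'points': quad, 'text': 'MILK'}]
#   metrics = combine_results([get_socre(gt_dict, pred_dict)])
#   print(metrics['f_score'], metrics['f_score_e2e'])  # both 1.0 for a perfect match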
ppocr/utils/e2e_metric/polygon_fast.py
0 → 100755
View file @
1f76f449
import numpy as np
from shapely.geometry import Polygon
#import Polygon
"""
:param det_x: [1, N] Xs of detection's vertices
:param det_y: [1, N] Ys of detection's vertices
:param gt_x: [1, N] Xs of groundtruth's vertices
:param gt_y: [1, N] Ys of groundtruth's vertices
##############
All the calculation of 'AREA' in this script is handled by:
1) First generating a binary mask with the polygon area filled up with 1's
2) Summing up all the 1's
"""


def area(x, y):
    polygon = Polygon(np.stack([x, y], axis=1))
    return float(polygon.area)


def approx_area_of_intersection(det_x, det_y, gt_x, gt_y):
    """
    This helper determines if both polygons are intersecting with each other with an approximation method.
    Area of intersection is represented by the minimum bounding rectangle [xmin, ymin, xmax, ymax].
    """
    det_ymax = np.max(det_y)
    det_xmax = np.max(det_x)
    det_ymin = np.min(det_y)
    det_xmin = np.min(det_x)

    gt_ymax = np.max(gt_y)
    gt_xmax = np.max(gt_x)
    gt_ymin = np.min(gt_y)
    gt_xmin = np.min(gt_x)

    all_min_ymax = np.minimum(det_ymax, gt_ymax)
    all_max_ymin = np.maximum(det_ymin, gt_ymin)
    intersect_heights = np.maximum(0.0, (all_min_ymax - all_max_ymin))

    all_min_xmax = np.minimum(det_xmax, gt_xmax)
    all_max_xmin = np.maximum(det_xmin, gt_xmin)
    intersect_widths = np.maximum(0.0, (all_min_xmax - all_max_xmin))

    return intersect_heights * intersect_widths


def area_of_intersection(det_x, det_y, gt_x, gt_y):
    # buffer(0) repairs self-intersecting polygons before the set operation
    p1 = Polygon(np.stack([det_x, det_y], axis=1)).buffer(0)
    p2 = Polygon(np.stack([gt_x, gt_y], axis=1)).buffer(0)
    return float(p1.intersection(p2).area)


def area_of_union(det_x, det_y, gt_x, gt_y):
    p1 = Polygon(np.stack([det_x, det_y], axis=1)).buffer(0)
    p2 = Polygon(np.stack([gt_x, gt_y], axis=1)).buffer(0)
    return float(p1.union(p2).area)


def iou(det_x, det_y, gt_x, gt_y):
    return area_of_intersection(det_x, det_y, gt_x, gt_y) / (area_of_union(det_x, det_y, gt_x, gt_y) + 1.0)


def iod(det_x, det_y, gt_x, gt_y):
    """
    This helper determines the fraction of intersection area over the detection area.
    """
    return area_of_intersection(det_x, det_y, gt_x, gt_y) / (area(det_x, det_y) + 1.0)
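# A quick sketch of how these helpers compose (assumed numbers, not part of
# the commit).  Two 10x10 squares offset by 5px share a 5x10 = 50 strip, so
# with the +1.0 smoothing term in the denominators:
#
#   det_x, det_y = [0, 10, 10, 0], [0, 0, 10, 10]
#   gt_x, gt_y = [5, 15, 15, 5], [0, 0, 10, 10]
#   iou(det_x, det_y, gt_x, gt_y)  # 50 / (150 + 1.0) ~= 0.331
#   iod(det_x, det_y, gt_x, gt_y)  # 50 / (100 + 1.0) ~= 0.495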
ppocr/utils/e2e_metric/tttt.py
0 → 100644
View file @
1f76f449
from os import listdir
import os, sys
from scipy import io
import numpy as np

from ppocr.utils.e2e_metric.polygon_fast import iod, area_of_intersection, area
from tqdm import tqdm

try:
    # python2
    range = xrange
except Exception:
    # python3
    range = range
"""
Input format: y0,x0, ..... yn,xn. Each detection is separated by the end of line token ('\n')
"""

# if len(sys.argv) != 4:
#     print('\n usage: test.py pred_dir gt_dir savefile')
#     sys.exit()

global_tp = 0
global_fp = 0
global_fn = 0
tr = 0.7     # sigma (recall-side) threshold of the Deteval protocol
tp = 0.6     # tau (precision-side) threshold
fsc_k = 0.8  # penalty applied to one-to-many / many-to-one matches
k = 2        # minimum number of candidates for a split / merge match
def get_socre(gt_dict, pred_dict):
    # allInputs = listdir(input_dir)
    allInputs = 1
    global_pred_str = []
    global_gt_str = []
    global_sigma = []
    global_tau = []

    def input_reading_mod(pred_dict, input):
        """This helper reads input from txt files"""
        det = []
        n = len(pred_dict)
        for i in range(n):
            points = pred_dict[i]['points']
            text = pred_dict[i]['text']
            # for i in range(len(points)):
            point = ",".join(map(str, points.reshape(-1, )))
            det.append([point, text])
        return det

    def gt_reading_mod(gt_dict, gt_id):
        """This helper reads groundtruths from mat files"""
        # gt_id = gt_id.split('.')[0]
        gt = []
        n = len(gt_dict)
        for i in range(n):
            points = gt_dict[i]['points'].tolist()
            h = len(points)
            text = gt_dict[i]['text']
            xx = [
                np.array(['x:'], dtype='<U2'), 0,
                np.array(['y:'], dtype='<U2'), 0,
                np.array(['#'], dtype='<U1'),
                np.array(['#'], dtype='<U1')
            ]
            t_x, t_y = [], []
            for j in range(h):
                t_x.append(points[j][0])
                t_y.append(points[j][1])
            xx[1] = np.array([t_x], dtype='int16')
            xx[3] = np.array([t_y], dtype='int16')
            if text != "":
                xx[4] = np.array([text], dtype='U{}'.format(len(text)))
                xx[5] = np.array(['c'], dtype='<U1')
            gt.append(xx)
        return gt

    def detection_filtering(detections, groundtruths, threshold=0.5):
        for gt_id, gt in enumerate(groundtruths):
            print("liushanshan gt[1] = {}".format(gt[1]))
            print("liushanshan gt[2] = {}".format(gt[2]))
            print("liushanshan gt[3] = {}".format(gt[3]))
            print("liushanshan gt[4] = {}".format(gt[4]))
            print("liushanshan gt[5] = {}".format(gt[5]))
            if (gt[5] == '#') and (gt[1].shape[1] > 1):
                gt_x = list(map(int, np.squeeze(gt[1])))
                gt_y = list(map(int, np.squeeze(gt[3])))
                for det_id, detection in enumerate(detections):
                    detection_orig = detection
                    detection = [float(x) for x in detection[0].split(',')]
                    # detection = detection.split(',')
                    detection = list(map(int, detection))
                    det_x = detection[0::2]
                    det_y = detection[1::2]
                    det_gt_iou = iod(det_x, det_y, gt_x, gt_y)
                    if det_gt_iou > threshold:
                        detections[det_id] = []

                detections[:] = [item for item in detections if item != []]
        return detections

    def sigma_calculation(det_x, det_y, gt_x, gt_y):
        """
        sigma = inter_area / gt_area
        """
        # print(area_of_intersection(det_x, det_y, gt_x, gt_y))
        return np.round((area_of_intersection(det_x, det_y, gt_x, gt_y) / area(gt_x, gt_y)), 2)

    def tau_calculation(det_x, det_y, gt_x, gt_y):
        """
        tau = inter_area / det_area
        """
        # print("liushanshan det_x {}".format(det_x))
        # print("liushanshan det_y {}".format(det_y))
        # print("liushanshan area {}".format(area(det_x, det_y)))
        if area(det_x, det_y) == 0.0:
            return 0
        return np.round((area_of_intersection(det_x, det_y, gt_x, gt_y) / area(det_x, det_y)), 2)
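    # A toy check of the two ratios above (assumed numbers, not from this
    # commit): a detection covering the left half of a 10x10 gt square gives
    # sigma = 50/100 = 0.5 and tau = 50/50 = 1.0:
    #
    #   gt_x, gt_y = [0, 10, 10, 0], [0, 0, 10, 10]
    #   det_x, det_y = [0, 5, 5, 0], [0, 0, 10, 10]
    #   sigma_calculation(det_x, det_y, gt_x, gt_y)  # -> 0.5
    #   tau_calculation(det_x, det_y, gt_x, gt_y)    # -> 1.0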
    ##############################Initialization###################################
    ###############################################################################
    single_data = {}
    for input_id in range(allInputs):
        # leftover filename filters from the directory-based version of this
        # script; harmless for the integer input_id used here
        if (input_id != '.DS_Store') and (input_id != 'Pascal_result.txt') and (
                input_id != 'Pascal_result_curved.txt') and (input_id != 'Pascal_result_non_curved.txt') and (
                input_id != 'Deteval_result.txt') and (input_id != 'Deteval_result_curved.txt') \
                and (input_id != 'Deteval_result_non_curved.txt'):
            print(input_id)
            detections = input_reading_mod(pred_dict, input_id)
            # print("liushanshan detections = {}".format(detections))
            groundtruths = gt_reading_mod(gt_dict, input_id)
            detections = detection_filtering(
                detections, groundtruths)  # filters detections overlapping with DC area
            dc_id = []
            for i in range(len(groundtruths)):
                if groundtruths[i][5] == '#':
                    dc_id.append(i)
            cnt = 0
            for a in dc_id:
                num = a - cnt
                del groundtruths[num]
                cnt += 1

            local_sigma_table = np.zeros((len(groundtruths), len(detections)))
            local_tau_table = np.zeros((len(groundtruths), len(detections)))
            local_pred_str = {}
            local_gt_str = {}

            for gt_id, gt in enumerate(groundtruths):
                if len(detections) > 0:
                    for det_id, detection in enumerate(detections):
                        detection_orig = detection
                        detection = [float(x) for x in detection[0].split(',')]
                        detection = list(map(int, detection))
                        pred_seq_str = detection_orig[1].strip()
                        det_x = detection[0::2]
                        det_y = detection[1::2]
                        gt_x = list(map(int, np.squeeze(gt[1])))
                        gt_y = list(map(int, np.squeeze(gt[3])))
                        gt_seq_str = str(gt[4].tolist()[0])

                        local_sigma_table[gt_id, det_id] = sigma_calculation(det_x, det_y, gt_x, gt_y)
                        local_tau_table[gt_id, det_id] = tau_calculation(det_x, det_y, gt_x, gt_y)
                        local_pred_str[det_id] = pred_seq_str
                        local_gt_str[gt_id] = gt_seq_str

            global_sigma.append(local_sigma_table)
            global_tau.append(local_tau_table)
            global_pred_str.append(local_pred_str)
            global_gt_str.append(local_gt_str)
            print("liushanshan global_pred_str = {}".format(global_pred_str))
            print("liushanshan global_gt_str = {}".format(global_gt_str))

    single_data['sigma'] = global_sigma
    single_data['global_tau'] = global_tau
    single_data['global_pred_str'] = global_pred_str
    single_data['global_gt_str'] = global_gt_str
    return single_data
def combine_results(all_data):
    global_sigma, global_tau, global_pred_str, global_gt_str = [], [], [], []
    for data in all_data:
        global_sigma.append(data['sigma'])
        global_tau.append(data['global_tau'])
        global_pred_str.append(data['global_pred_str'])
        global_gt_str.append(data['global_gt_str'])

    global_accumulative_recall = 0
    global_accumulative_precision = 0
    total_num_gt = 0
    total_num_det = 0
    hit_str_count = 0
    hit_count = 0
    def one_to_one(local_sigma_table, local_tau_table, local_accumulative_recall,
                   local_accumulative_precision, global_accumulative_recall, global_accumulative_precision,
                   gt_flag, det_flag, idy):
        # num_gt is taken from the enclosing loop below (set before each call)
        hit_str_num = 0
        for gt_id in range(num_gt):
            gt_matching_qualified_sigma_candidates = np.where(local_sigma_table[gt_id, :] > tr)
            gt_matching_num_qualified_sigma_candidates = gt_matching_qualified_sigma_candidates[0].shape[0]
            gt_matching_qualified_tau_candidates = np.where(local_tau_table[gt_id, :] > tp)
            gt_matching_num_qualified_tau_candidates = gt_matching_qualified_tau_candidates[0].shape[0]

            det_matching_qualified_sigma_candidates = np.where(
                local_sigma_table[:, gt_matching_qualified_sigma_candidates[0]] > tr)
            det_matching_num_qualified_sigma_candidates = det_matching_qualified_sigma_candidates[0].shape[0]
            det_matching_qualified_tau_candidates = np.where(
                local_tau_table[:, gt_matching_qualified_tau_candidates[0]] > tp)
            det_matching_num_qualified_tau_candidates = det_matching_qualified_tau_candidates[0].shape[0]

            if (gt_matching_num_qualified_sigma_candidates == 1) and (gt_matching_num_qualified_tau_candidates == 1) and \
                    (det_matching_num_qualified_sigma_candidates == 1) and (det_matching_num_qualified_tau_candidates == 1):
                global_accumulative_recall = global_accumulative_recall + 1.0
                global_accumulative_precision = global_accumulative_precision + 1.0
                local_accumulative_recall = local_accumulative_recall + 1.0
                local_accumulative_precision = local_accumulative_precision + 1.0

                gt_flag[0, gt_id] = 1
                matched_det_id = np.where(local_sigma_table[gt_id, :] > tr)
                # recg start
                print("liushanshan one to one det_id = {}".format(matched_det_id))
                print("liushanshan one to one gt_id = {}".format(gt_id))
                gt_str_cur = global_gt_str[idy][gt_id]
                pred_str_cur = global_pred_str[idy][matched_det_id[0].tolist()[0]]
                print("liushanshan one to one gt_str_cur = {}".format(gt_str_cur))
                print("liushanshan one to one pred_str_cur = {}".format(pred_str_cur))
                if pred_str_cur == gt_str_cur:
                    hit_str_num += 1
                else:
                    if pred_str_cur.lower() == gt_str_cur.lower():
                        hit_str_num += 1
                # recg end
                det_flag[0, matched_det_id] = 1
        return local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, \
            global_accumulative_precision, gt_flag, det_flag, hit_str_num
    def one_to_many(local_sigma_table, local_tau_table, local_accumulative_recall,
                    local_accumulative_precision, global_accumulative_recall, global_accumulative_precision,
                    gt_flag, det_flag, idy):
        hit_str_num = 0
        for gt_id in range(num_gt):
            # skip the following if the groundtruth was matched
            if gt_flag[0, gt_id] > 0:
                continue

            non_zero_in_sigma = np.where(local_sigma_table[gt_id, :] > 0)
            num_non_zero_in_sigma = non_zero_in_sigma[0].shape[0]

            if num_non_zero_in_sigma >= k:
                #### search for all detections that overlap with this groundtruth
                qualified_tau_candidates = np.where((local_tau_table[gt_id, :] >= tp) & (det_flag[0, :] == 0))
                num_qualified_tau_candidates = qualified_tau_candidates[0].shape[0]

                if num_qualified_tau_candidates == 1:
                    if ((local_tau_table[gt_id, qualified_tau_candidates] >= tp) and
                            (local_sigma_table[gt_id, qualified_tau_candidates] >= tr)):
                        # became an one-to-one case
                        global_accumulative_recall = global_accumulative_recall + 1.0
                        global_accumulative_precision = global_accumulative_precision + 1.0
                        local_accumulative_recall = local_accumulative_recall + 1.0
                        local_accumulative_precision = local_accumulative_precision + 1.0

                        gt_flag[0, gt_id] = 1
                        det_flag[0, qualified_tau_candidates] = 1
                        # recg start
                        print("liushanshan one to many det_id = {}".format(qualified_tau_candidates))
                        print("liushanshan one to many gt_id = {}".format(gt_id))
                        gt_str_cur = global_gt_str[idy][gt_id]
                        pred_str_cur = global_pred_str[idy][qualified_tau_candidates[0].tolist()[0]]
                        print("liushanshan one to many gt_str_cur = {}".format(gt_str_cur))
                        print("liushanshan one to many pred_str_cur = {}".format(pred_str_cur))
                        if pred_str_cur == gt_str_cur:
                            hit_str_num += 1
                        else:
                            if pred_str_cur.lower() == gt_str_cur.lower():
                                hit_str_num += 1
                        # recg end
                elif (np.sum(local_sigma_table[gt_id, qualified_tau_candidates]) >= tr):
                    gt_flag[0, gt_id] = 1
                    det_flag[0, qualified_tau_candidates] = 1
                    # recg start
                    print("liushanshan one to many det_id = {}".format(qualified_tau_candidates))
                    print("liushanshan one to many gt_id = {}".format(gt_id))
                    gt_str_cur = global_gt_str[idy][gt_id]
                    pred_str_cur = global_pred_str[idy][qualified_tau_candidates[0].tolist()[0]]
                    print("liushanshan one to many gt_str_cur = {}".format(gt_str_cur))
                    print("liushanshan one to many pred_str_cur = {}".format(pred_str_cur))
                    if pred_str_cur == gt_str_cur:
                        hit_str_num += 1
                    else:
                        if pred_str_cur.lower() == gt_str_cur.lower():
                            hit_str_num += 1
                    # recg end

                    global_accumulative_recall = global_accumulative_recall + fsc_k
                    global_accumulative_precision = global_accumulative_precision + num_qualified_tau_candidates * fsc_k

                    local_accumulative_recall = local_accumulative_recall + fsc_k
                    local_accumulative_precision = local_accumulative_precision + num_qualified_tau_candidates * fsc_k
        return local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, \
            global_accumulative_precision, gt_flag, det_flag, hit_str_num
    def many_to_one(local_sigma_table, local_tau_table, local_accumulative_recall,
                    local_accumulative_precision, global_accumulative_recall, global_accumulative_precision,
                    gt_flag, det_flag, idy):
        hit_str_num = 0
        for det_id in range(num_det):
            # skip the following if the detection was matched
            if det_flag[0, det_id] > 0:
                continue

            non_zero_in_tau = np.where(local_tau_table[:, det_id] > 0)
            num_non_zero_in_tau = non_zero_in_tau[0].shape[0]

            if num_non_zero_in_tau >= k:
                #### search for all groundtruths that overlap with this detection
                qualified_sigma_candidates = np.where((local_sigma_table[:, det_id] >= tp) & (gt_flag[0, :] == 0))
                num_qualified_sigma_candidates = qualified_sigma_candidates[0].shape[0]

                if num_qualified_sigma_candidates == 1:
                    if ((local_tau_table[qualified_sigma_candidates, det_id] >= tp) and
                            (local_sigma_table[qualified_sigma_candidates, det_id] >= tr)):
                        # became an one-to-one case
                        global_accumulative_recall = global_accumulative_recall + 1.0
                        global_accumulative_precision = global_accumulative_precision + 1.0
                        local_accumulative_recall = local_accumulative_recall + 1.0
                        local_accumulative_precision = local_accumulative_precision + 1.0

                        gt_flag[0, qualified_sigma_candidates] = 1
                        det_flag[0, det_id] = 1
                        # recg start
                        print("liushanshan many to one det_id = {}".format(det_id))
                        print("liushanshan many to one gt_id = {}".format(qualified_sigma_candidates))
                        pred_str_cur = global_pred_str[idy][det_id]
                        gt_len = len(qualified_sigma_candidates[0])
                        for idx in range(gt_len):
                            ele_gt_id = qualified_sigma_candidates[0].tolist()[idx]
                            # `dict.has_key()` is Python-2-only; use `in` instead
                            if ele_gt_id not in global_gt_str[idy]:
                                continue
                            gt_str_cur = global_gt_str[idy][ele_gt_id]
                            print("liushanshan many to one gt_str_cur = {}".format(gt_str_cur))
                            print("liushanshan many to one pred_str_cur = {}".format(pred_str_cur))
                            if pred_str_cur == gt_str_cur:
                                hit_str_num += 1
                                break
                            else:
                                if pred_str_cur.lower() == gt_str_cur.lower():
                                    hit_str_num += 1
                                    break
                        # recg end
                elif (np.sum(local_tau_table[qualified_sigma_candidates, det_id]) >= tp):
                    det_flag[0, det_id] = 1
                    gt_flag[0, qualified_sigma_candidates] = 1
                    # recg start
                    print("liushanshan many to one det_id = {}".format(det_id))
                    print("liushanshan many to one gt_id = {}".format(qualified_sigma_candidates))
                    pred_str_cur = global_pred_str[idy][det_id]
                    gt_len = len(qualified_sigma_candidates[0])
                    for idx in range(gt_len):
                        ele_gt_id = qualified_sigma_candidates[0].tolist()[idx]
                        if ele_gt_id not in global_gt_str[idy]:
                            continue
                        gt_str_cur = global_gt_str[idy][ele_gt_id]
                        print("liushanshan many to one gt_str_cur = {}".format(gt_str_cur))
                        print("liushanshan many to one pred_str_cur = {}".format(pred_str_cur))
                        if pred_str_cur == gt_str_cur:
                            hit_str_num += 1
                            break
                        else:
                            if pred_str_cur.lower() == gt_str_cur.lower():
                                hit_str_num += 1
                                break
                            else:
                                print('no match')
                    # recg end

                    global_accumulative_recall = global_accumulative_recall + num_qualified_sigma_candidates * fsc_k
                    global_accumulative_precision = global_accumulative_precision + fsc_k

                    local_accumulative_recall = local_accumulative_recall + num_qualified_sigma_candidates * fsc_k
                    local_accumulative_precision = local_accumulative_precision + fsc_k
        return local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, \
            global_accumulative_precision, gt_flag, det_flag, hit_str_num
    for idx in range(len(global_sigma)):
        # print(allInputs[idx])
        local_sigma_table = np.array(global_sigma[idx])
        local_tau_table = global_tau[idx]

        num_gt = local_sigma_table.shape[0]
        num_det = local_sigma_table.shape[1]

        total_num_gt = total_num_gt + num_gt
        total_num_det = total_num_det + num_det

        local_accumulative_recall = 0
        local_accumulative_precision = 0
        gt_flag = np.zeros((1, num_gt))
        det_flag = np.zeros((1, num_det))

        #######first check for one-to-one case##########
        local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, \
            gt_flag, det_flag, hit_str_num = one_to_one(local_sigma_table, local_tau_table,
                                                        local_accumulative_recall, local_accumulative_precision,
                                                        global_accumulative_recall, global_accumulative_precision,
                                                        gt_flag, det_flag, idx)
        hit_str_count += hit_str_num
        #######then check for one-to-many case##########
        local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, \
            gt_flag, det_flag, hit_str_num = one_to_many(local_sigma_table, local_tau_table,
                                                         local_accumulative_recall, local_accumulative_precision,
                                                         global_accumulative_recall, global_accumulative_precision,
                                                         gt_flag, det_flag, idx)
        hit_str_count += hit_str_num
        #######then check for many-to-one case##########
        local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, \
            gt_flag, det_flag, hit_str_num = many_to_one(local_sigma_table, local_tau_table,
                                                         local_accumulative_recall, local_accumulative_precision,
                                                         global_accumulative_recall, global_accumulative_precision,
                                                         gt_flag, det_flag, idx)
        # the original never accumulated the many-to-one hits, unlike the two
        # cases above; counting them here keeps the three cases consistent
        hit_str_count += hit_str_num
    try:
        recall = global_accumulative_recall / total_num_gt
    except ZeroDivisionError:
        recall = 0

    try:
        precision = global_accumulative_precision / total_num_det
    except ZeroDivisionError:
        precision = 0

    try:
        f_score = 2 * precision * recall / (precision + recall)
    except ZeroDivisionError:
        f_score = 0

    try:
        seqerr = 1 - float(hit_str_count) / global_accumulative_recall
    except ZeroDivisionError:
        seqerr = 1

    try:
        recall_e2e = float(hit_str_count) / total_num_gt
    except ZeroDivisionError:
        recall_e2e = 0

    try:
        precision_e2e = float(hit_str_count) / total_num_det
    except ZeroDivisionError:
        precision_e2e = 0

    try:
        f_score_e2e = 2 * precision_e2e * recall_e2e / (precision_e2e + recall_e2e)
    except ZeroDivisionError:
        f_score_e2e = 0

    final = {
        'total_num_gt': total_num_gt,
        'total_num_det': total_num_det,
        'global_accumulative_recall': global_accumulative_recall,
        'hit_str_count': hit_str_count,
        'recall': recall,
        'precision': precision,
        'f_score': f_score,
        'seqerr': seqerr,
        'recall_e2e': recall_e2e,
        'precision_e2e': precision_e2e,
        'f_score_e2e': f_score_e2e
    }
    return final
a = [1526, 642, 1565, 629, 1579, 627, 1593, 625, 1607, 623, 1620, 622, 1634, 620, 1659, 620, 1654, 681, 1631, 680,
     1618, 681, 1606, 681, 1594, 681, 1584, 682, 1573, 685, 1542, 694]
gt_dict = [{'points': np.array(a).reshape(-1, 2), 'text': 'MILK'}]
pred_dict = [{'points': np.array(a), 'text': 'ccc'}, {'points': np.array(a), 'text': 'ccf'}]
result = []
# the original called get_socre(gt_dict, gt_dict), which leaves pred_dict
# unused; pred_dict is almost certainly what was meant (cf. the commented
# example in Deteval.py)
result.append(get_socre(gt_dict, pred_dict))
a = combine_results(result)
print(a)
ppocr/utils/e2e_utils/extract_textpoint.py
0 → 100644
View file @
1f76f449
"""Contains various CTC decoders."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
os
import
cv2
import
time
import
math
import
numpy
as
np
from
itertools
import
groupby
from
ppocr.utils.e2e_utils.ski_thin
import
thin
def
softmax
(
logits
):
"""
logits: N x d
"""
max_value
=
np
.
max
(
logits
,
axis
=
1
,
keepdims
=
True
)
exp
=
np
.
exp
(
logits
-
max_value
)
exp_sum
=
np
.
sum
(
exp
,
axis
=
1
,
keepdims
=
True
)
dist
=
exp
/
exp_sum
return
dist
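# sanity-check sketch (not in the commit): each output row sums to 1, e.g.
#   softmax(np.array([[1.0, 2.0, 3.0]]))  # -> [[0.090, 0.245, 0.665]]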
def get_keep_pos_idxs(labels, remove_blank=None):
    """
    Remove duplicates and get the position indices of kept items.
    remove_blank should be None or the blank index (e.g. 95).
    """
    duplicate_len_list = []
    keep_pos_idx_list = []
    keep_char_idx_list = []
    for k, v_ in groupby(labels):
        current_len = len(list(v_))
        if k != remove_blank:
            # keep the middle position of each run of identical labels
            current_idx = int(sum(duplicate_len_list) + current_len // 2)
            keep_pos_idx_list.append(current_idx)
            keep_char_idx_list.append(k)
        duplicate_len_list.append(current_len)
    return keep_char_idx_list, keep_pos_idx_list


def remove_blank(labels, blank=0):
    new_labels = [x for x in labels if x != blank]
    return new_labels


def insert_blank(labels, blank=0):
    new_labels = [blank]
    for l in labels:
        new_labels += [l, blank]
    return new_labels


def ctc_greedy_decoder(probs_seq, blank=95, keep_blank_in_idxs=True):
    """
    CTC greedy (best path) decoder.
    """
    raw_str = np.argmax(np.array(probs_seq), axis=1)
    remove_blank_in_pos = None if keep_blank_in_idxs else blank
    dedup_str, keep_idx_list = get_keep_pos_idxs(raw_str, remove_blank=remove_blank_in_pos)
    dst_str = remove_blank(dedup_str, blank=blank)
    return dst_str, keep_idx_list
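# A worked sketch of the greedy decode above (toy inputs, not from the
# commit).  With 3 classes where label 2 is the blank, the per-step argmax
# [2, 0, 0, 2, 1, 2] collapses the repeated 0s, keeps one index per run
# (blanks included, since keep_blank_in_idxs=True), and strips the blanks
# from the label string:
#
#   probs_seq = np.eye(3)[[2, 0, 0, 2, 1, 2]]
#   dst_str, keep_idx_list = ctc_greedy_decoder(probs_seq, blank=2)
#   # dst_str == [0, 1], keep_idx_list == [0, 2, 3, 4, 5]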
def instance_ctc_greedy_decoder(gather_info, logits_map, keep_blank_in_idxs=True):
    """
    gather_info: [[x, y], [x, y] ...]
    logits_map: H x W x (n_chars + 1)
    """
    _, _, C = logits_map.shape
    ys, xs = zip(*gather_info)
    logits_seq = logits_map[list(ys), list(xs)]  # n x 96
    probs_seq = softmax(logits_seq)
    dst_str, keep_idx_list = ctc_greedy_decoder(
        probs_seq, blank=C - 1, keep_blank_in_idxs=keep_blank_in_idxs)
    keep_gather_list = [gather_info[idx] for idx in keep_idx_list]
    return dst_str, keep_gather_list


def ctc_decoder_for_image(gather_info_list, logits_map, keep_blank_in_idxs=True):
    """
    CTC decoder applied to each text instance in turn.
    """
    decoder_results = []
    for gather_info in gather_info_list:
        res = instance_ctc_greedy_decoder(
            gather_info, logits_map, keep_blank_in_idxs=keep_blank_in_idxs)
        decoder_results.append(res)
    return decoder_results
def sort_with_direction(pos_list, f_direction):
    """
    f_direction: h x w x 2
    pos_list: [[y, x], [y, x], [y, x] ...]
    """

    def sort_part_with_direction(pos_list, point_direction):
        pos_list = np.array(pos_list).reshape(-1, 2)
        point_direction = np.array(point_direction).reshape(-1, 2)
        average_direction = np.mean(point_direction, axis=0, keepdims=True)
        pos_proj_leng = np.sum(pos_list * average_direction, axis=1)
        sorted_list = pos_list[np.argsort(pos_proj_leng)].tolist()
        sorted_direction = point_direction[np.argsort(pos_proj_leng)].tolist()
        return sorted_list, sorted_direction

    pos_list = np.array(pos_list).reshape(-1, 2)
    point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]]  # x, y
    point_direction = point_direction[:, ::-1]  # x, y -> y, x
    sorted_point, sorted_direction = sort_part_with_direction(pos_list, point_direction)

    point_num = len(sorted_point)
    if point_num >= 16:
        # long instances are re-sorted half by half for stability
        middle_num = point_num // 2
        first_part_point = sorted_point[:middle_num]
        first_point_direction = sorted_direction[:middle_num]
        sorted_first_part_point, sorted_first_part_direction = sort_part_with_direction(
            first_part_point, first_point_direction)

        last_part_point = sorted_point[middle_num:]
        last_point_direction = sorted_direction[middle_num:]
        sorted_last_part_point, sorted_last_part_direction = sort_part_with_direction(
            last_part_point, last_point_direction)
        sorted_point = sorted_first_part_point + sorted_last_part_point
        sorted_direction = sorted_first_part_direction + sorted_last_part_direction

    return sorted_point, np.array(sorted_direction)
def add_id(pos_list, image_id=0):
    """
    Add id for gather feature, for inference.
    """
    new_list = []
    for item in pos_list:
        new_list.append((image_id, item[0], item[1]))
    return new_list
def sort_and_expand_with_direction(pos_list, f_direction):
    """
    f_direction: h x w x 2
    pos_list: [[y, x], [y, x], [y, x] ...]
    """
    h, w, _ = f_direction.shape
    sorted_list, point_direction = sort_with_direction(pos_list, f_direction)

    # expand along
    point_num = len(sorted_list)
    sub_direction_len = max(point_num // 3, 2)
    left_direction = point_direction[:sub_direction_len, :]
    right_direction = point_direction[point_num - sub_direction_len:, :]

    left_average_direction = -np.mean(left_direction, axis=0, keepdims=True)
    left_average_len = np.linalg.norm(left_average_direction)
    left_start = np.array(sorted_list[0])
    left_step = left_average_direction / (left_average_len + 1e-6)

    right_average_direction = np.mean(right_direction, axis=0, keepdims=True)
    right_average_len = np.linalg.norm(right_average_direction)
    right_step = right_average_direction / (right_average_len + 1e-6)
    right_start = np.array(sorted_list[-1])

    append_num = max(int((left_average_len + right_average_len) / 2.0 * 0.15), 1)
    left_list = []
    right_list = []
    for i in range(append_num):
        ly, lx = np.round(left_start + left_step * (i + 1)).flatten().astype('int32').tolist()
        if ly < h and lx < w and (ly, lx) not in left_list:
            left_list.append((ly, lx))
        ry, rx = np.round(right_start + right_step * (i + 1)).flatten().astype('int32').tolist()
        if ry < h and rx < w and (ry, rx) not in right_list:
            right_list.append((ry, rx))

    all_list = left_list[::-1] + sorted_list + right_list
    return all_list
def sort_and_expand_with_direction_v2(pos_list, f_direction, binary_tcl_map):
    """
    f_direction: h x w x 2
    pos_list: [[y, x], [y, x], [y, x] ...]
    binary_tcl_map: h x w
    """
    h, w, _ = f_direction.shape
    sorted_list, point_direction = sort_with_direction(pos_list, f_direction)

    # expand along
    point_num = len(sorted_list)
    sub_direction_len = max(point_num // 3, 2)
    left_direction = point_direction[:sub_direction_len, :]
    right_direction = point_direction[point_num - sub_direction_len:, :]

    left_average_direction = -np.mean(left_direction, axis=0, keepdims=True)
    left_average_len = np.linalg.norm(left_average_direction)
    left_start = np.array(sorted_list[0])
    left_step = left_average_direction / (left_average_len + 1e-6)

    right_average_direction = np.mean(right_direction, axis=0, keepdims=True)
    right_average_len = np.linalg.norm(right_average_direction)
    right_step = right_average_direction / (right_average_len + 1e-6)
    right_start = np.array(sorted_list[-1])

    append_num = max(int((left_average_len + right_average_len) / 2.0 * 0.15), 1)
    max_append_num = 2 * append_num

    left_list = []
    right_list = []
    for i in range(max_append_num):
        ly, lx = np.round(left_start + left_step * (i + 1)).flatten().astype('int32').tolist()
        if ly < h and lx < w and (ly, lx) not in left_list:
            if binary_tcl_map[ly, lx] > 0.5:
                left_list.append((ly, lx))
            else:
                break

    for i in range(max_append_num):
        ry, rx = np.round(right_start + right_step * (i + 1)).flatten().astype('int32').tolist()
        if ry < h and rx < w and (ry, rx) not in right_list:
            if binary_tcl_map[ry, rx] > 0.5:
                right_list.append((ry, rx))
            else:
                break

    all_list = left_list[::-1] + sorted_list + right_list
    return all_list
def generate_pivot_list_curved(p_score,
                               p_char_maps,
                               f_direction,
                               score_thresh=0.5,
                               is_expand=True,
                               is_backbone=False,
                               image_id=0):
    """
    return center point and end point of TCL instance; filter with the char maps;
    """
    p_score = p_score[0]
    f_direction = f_direction.transpose(1, 2, 0)
    p_tcl_map = (p_score > score_thresh) * 1.0
    skeleton_map = thin(p_tcl_map)
    instance_count, instance_label_map = cv2.connectedComponents(
        skeleton_map.astype(np.uint8), connectivity=8)

    # get TCL Instance
    all_pos_yxs = []
    center_pos_yxs = []
    end_points_yxs = []
    instance_center_pos_yxs = []
    if instance_count > 0:
        for instance_id in range(1, instance_count):
            pos_list = []
            ys, xs = np.where(instance_label_map == instance_id)
            pos_list = list(zip(ys, xs))

            ### FIX-ME, eliminate outlier
            if len(pos_list) < 3:
                continue

            if is_expand:
                pos_list_sorted = sort_and_expand_with_direction_v2(
                    pos_list, f_direction, p_tcl_map)
            else:
                pos_list_sorted, _ = sort_with_direction(pos_list, f_direction)
            all_pos_yxs.append(pos_list_sorted)

    # use decoder to filter background points.
    p_char_maps = p_char_maps.transpose([1, 2, 0])
    decode_res = ctc_decoder_for_image(
        all_pos_yxs, logits_map=p_char_maps, keep_blank_in_idxs=True)
    for decoded_str, keep_yxs_list in decode_res:
        if is_backbone:
            keep_yxs_list_with_id = add_id(keep_yxs_list, image_id=image_id)
            instance_center_pos_yxs.append(keep_yxs_list_with_id)
        else:
            end_points_yxs.extend((keep_yxs_list[0], keep_yxs_list[-1]))
            center_pos_yxs.extend(keep_yxs_list)

    if is_backbone:
        return instance_center_pos_yxs
    else:
        return center_pos_yxs, end_points_yxs
def generate_pivot_list_horizontal(p_score,
                                   p_char_maps,
                                   f_direction,
                                   score_thresh=0.5,
                                   is_backbone=False,
                                   image_id=0):
    """
    return center point and end point of TCL instance; filter with the char maps;
    """
    p_score = p_score[0]
    f_direction = f_direction.transpose(1, 2, 0)
    p_tcl_map_bi = (p_score > score_thresh) * 1.0
    instance_count, instance_label_map = cv2.connectedComponents(
        p_tcl_map_bi.astype(np.uint8), connectivity=8)

    # get TCL Instance
    all_pos_yxs = []
    center_pos_yxs = []
    end_points_yxs = []
    instance_center_pos_yxs = []

    if instance_count > 0:
        for instance_id in range(1, instance_count):
            pos_list = []
            ys, xs = np.where(instance_label_map == instance_id)
            pos_list = list(zip(ys, xs))

            ### FIX-ME, eliminate outlier
            if len(pos_list) < 5:
                continue

            # add rule here
            main_direction = extract_main_direction(pos_list, f_direction)  # y x
            reference_direction = np.array([0, 1]).reshape([-1, 2])  # y x
            # computed but currently unused below
            is_h_angle = abs(np.sum(main_direction * reference_direction)) < math.cos(math.pi / 180 * 70)

            point_yxs = np.array(pos_list)
            max_y, max_x = np.max(point_yxs, axis=0)
            min_y, min_x = np.min(point_yxs, axis=0)
            is_h_len = (max_y - min_y) < 1.5 * (max_x - min_x)

            pos_list_final = []
            if is_h_len:
                xs = np.unique(xs)
                for x in xs:
                    ys = instance_label_map[:, x].copy().reshape((-1, ))
                    y = int(np.where(ys == instance_id)[0].mean())
                    pos_list_final.append((y, x))
            else:
                ys = np.unique(ys)
                for y in ys:
                    xs = instance_label_map[y, :].copy().reshape((-1, ))
                    x = int(np.where(xs == instance_id)[0].mean())
                    pos_list_final.append((y, x))

            pos_list_sorted, _ = sort_with_direction(pos_list_final, f_direction)
            all_pos_yxs.append(pos_list_sorted)

    # use decoder to filter background points.
    p_char_maps = p_char_maps.transpose([1, 2, 0])
    decode_res = ctc_decoder_for_image(
        all_pos_yxs, logits_map=p_char_maps, keep_blank_in_idxs=True)
    for decoded_str, keep_yxs_list in decode_res:
        if is_backbone:
            keep_yxs_list_with_id = add_id(keep_yxs_list, image_id=image_id)
            instance_center_pos_yxs.append(keep_yxs_list_with_id)
        else:
            end_points_yxs.extend((keep_yxs_list[0], keep_yxs_list[-1]))
            center_pos_yxs.extend(keep_yxs_list)

    if is_backbone:
        return instance_center_pos_yxs
    else:
        return center_pos_yxs, end_points_yxs
def generate_pivot_list(p_score,
                        p_char_maps,
                        f_direction,
                        score_thresh=0.5,
                        is_backbone=False,
                        is_curved=True,
                        image_id=0):
    """
    Wrap all the functions together.
    """
    if is_curved:
        return generate_pivot_list_curved(
            p_score,
            p_char_maps,
            f_direction,
            score_thresh=score_thresh,
            is_expand=True,
            is_backbone=is_backbone,
            image_id=image_id)
    else:
        return generate_pivot_list_horizontal(
            p_score,
            p_char_maps,
            f_direction,
            score_thresh=score_thresh,
            is_backbone=is_backbone,
            image_id=image_id)
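# Rough usage sketch for the wrapper above (shapes assumed from the function
# bodies, not from the commit): p_score is the 1 x H x W TCL score map,
# p_char_maps the C x H x W character logits (blank = C - 1), and
# f_direction the 2 x H x W direction field, all numpy arrays:
#
#   center_pos_yxs, end_points_yxs = generate_pivot_list(
#       p_score, p_char_maps, f_direction, score_thresh=0.5, is_curved=True)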
# for refine module
def extract_main_direction(pos_list, f_direction):
    """
    f_direction: h x w x 2
    pos_list: [[y, x], [y, x], [y, x] ...]
    """
    pos_list = np.array(pos_list)
    point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]]
    point_direction = point_direction[:, ::-1]  # x, y -> y, x
    average_direction = np.mean(point_direction, axis=0, keepdims=True)
    average_direction = average_direction / (np.linalg.norm(average_direction) + 1e-6)
    return average_direction
def sort_by_direction_with_image_id_deprecated(pos_list, f_direction):
    """
    f_direction: h x w x 2
    pos_list: [[id, y, x], [id, y, x], [id, y, x] ...]
    """
    pos_list_full = np.array(pos_list).reshape(-1, 3)
    pos_list = pos_list_full[:, 1:]
    point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]]  # x, y
    point_direction = point_direction[:, ::-1]  # x, y -> y, x
    average_direction = np.mean(point_direction, axis=0, keepdims=True)
    pos_proj_leng = np.sum(pos_list * average_direction, axis=1)
    sorted_list = pos_list_full[np.argsort(pos_proj_leng)].tolist()
    return sorted_list
def sort_by_direction_with_image_id(pos_list, f_direction):
    """
    f_direction: h x w x 2
    pos_list: [[id, y, x], [id, y, x], [id, y, x] ...]
    """

    def sort_part_with_direction(pos_list_full, point_direction):
        pos_list_full = np.array(pos_list_full).reshape(-1, 3)
        pos_list = pos_list_full[:, 1:]
        point_direction = np.array(point_direction).reshape(-1, 2)
        average_direction = np.mean(point_direction, axis=0, keepdims=True)
        pos_proj_leng = np.sum(pos_list * average_direction, axis=1)
        sorted_list = pos_list_full[np.argsort(pos_proj_leng)].tolist()
        sorted_direction = point_direction[np.argsort(pos_proj_leng)].tolist()
        return sorted_list, sorted_direction

    pos_list = np.array(pos_list).reshape(-1, 3)
    point_direction = f_direction[pos_list[:, 1], pos_list[:, 2]]  # x, y
    point_direction = point_direction[:, ::-1]  # x, y -> y, x
    sorted_point, sorted_direction = sort_part_with_direction(pos_list, point_direction)

    point_num = len(sorted_point)
    if point_num >= 16:
        middle_num = point_num // 2
        first_part_point = sorted_point[:middle_num]
        first_point_direction = sorted_direction[:middle_num]
        sorted_first_part_point, sorted_first_part_direction = sort_part_with_direction(
            first_part_point, first_point_direction)

        last_part_point = sorted_point[middle_num:]
        last_point_direction = sorted_direction[middle_num:]
        sorted_last_part_point, sorted_last_part_direction = sort_part_with_direction(
            last_part_point, last_point_direction)
        sorted_point = sorted_first_part_point + sorted_last_part_point
        sorted_direction = sorted_first_part_direction + sorted_last_part_direction

    return sorted_point
def generate_pivot_list_tt_inference(p_score,
                                     p_char_maps,
                                     f_direction,
                                     score_thresh=0.5,
                                     is_backbone=False,
                                     is_curved=True,
                                     image_id=0):
    """
    return center point and end point of TCL instance; filter with the char maps;
    """
    p_score = p_score[0]
    f_direction = f_direction.transpose(1, 2, 0)
    p_tcl_map = (p_score > score_thresh) * 1.0
    skeleton_map = thin(p_tcl_map)
    instance_count, instance_label_map = cv2.connectedComponents(
        skeleton_map.astype(np.uint8), connectivity=8)

    # get TCL Instance
    all_pos_yxs = []
    if instance_count > 0:
        for instance_id in range(1, instance_count):
            pos_list = []
            ys, xs = np.where(instance_label_map == instance_id)
            pos_list = list(zip(ys, xs))
            ### FIX-ME, eliminate outlier
            if len(pos_list) < 3:
                continue
            pos_list_sorted = sort_and_expand_with_direction_v2(pos_list, f_direction, p_tcl_map)
            # pos_list_sorted, _ = sort_with_direction(pos_list, f_direction)
            pos_list_sorted_with_id = add_id(pos_list_sorted, image_id=image_id)
            all_pos_yxs.append(pos_list_sorted_with_id)
    return all_pos_yxs
if __name__ == '__main__':
    np.random.seed(0)
    import time
    logits_map = np.random.random([10, 20, 33])
    # a list of [x, y]
    instance_gather_info_1 = [(2, 3), (2, 4), (3, 5)]
    instance_gather_info_2 = [(15, 6), (15, 7), (18, 8)]
    instance_gather_info_3 = [(8, 8), (8, 8), (8, 8)]
    gather_info_list = [instance_gather_info_1, instance_gather_info_2, instance_gather_info_3]
    time0 = time.time()
    res = ctc_decoder_for_image(gather_info_list, logits_map, keep_blank_in_idxs=True)
    print(res)
    print('cost {}'.format(time.time() - time0))
    print('--' * 20)
ppocr/utils/e2e_utils/ski_thin.py
0 → 100644
View file @
1f76f449
"""
Algorithms for computing the skeleton of a binary image
"""
import
numpy
as
np
from
scipy
import
ndimage
as
ndi
G123_LUT
=
np
.
array
(
[
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
1
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
1
,
0
,
1
,
0
,
0
,
0
,
1
,
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
,
1
,
0
,
1
,
0
,
1
,
0
,
0
,
0
,
1
,
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
,
1
,
0
,
1
,
0
,
1
,
0
,
0
,
0
,
1
,
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
1
,
0
,
0
,
1
,
0
,
0
,
0
,
1
,
1
,
0
,
0
,
1
,
0
,
0
,
0
,
1
,
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
1
,
0
,
0
,
1
,
0
,
0
,
0
,
1
,
1
,
0
,
0
,
1
,
0
,
0
,
0
],
dtype
=
np
.
bool
)
G123P_LUT
=
np
.
array
(
[
0
,
0
,
0
,
0
,
0
,
1
,
0
,
1
,
0
,
0
,
0
,
0
,
0
,
1
,
1
,
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
1
,
1
,
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
1
,
0
,
1
,
0
,
1
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
1
,
0
,
1
,
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
1
,
0
,
1
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
1
,
0
,
1
,
0
,
1
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
1
,
0
,
1
,
0
,
1
,
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
],
dtype
=
np
.
bool
)
def thin(image, max_iter=None):
    """
    Perform morphological thinning of a binary image.
    Parameters
    ----------
    image : binary (M, N) ndarray
        The image to be thinned.
    max_iter : int, number of iterations, optional
        Regardless of the value of this parameter, the thinned image
        is returned immediately if an iteration produces no change.
        If this parameter is specified it thus sets an upper bound on
        the number of iterations performed.
    Returns
    -------
    out : ndarray of bool
        Thinned image.
    See also
    --------
    skeletonize, medial_axis
    Notes
    -----
    This algorithm [1]_ works by making multiple passes over the image,
    removing pixels matching a set of criteria designed to thin
    connected regions while preserving eight-connected components and
    2 x 2 squares [2]_. In each of the two sub-iterations the algorithm
    correlates the intermediate skeleton image with a neighborhood mask,
    then looks up each neighborhood in a lookup table indicating whether
    the central pixel should be deleted in that sub-iteration.
    References
    ----------
    .. [1] Z. Guo and R. W. Hall, "Parallel thinning with
           two-subiteration algorithms," Comm. ACM, vol. 32, no. 3,
           pp. 359-373, 1989. :DOI:`10.1145/62065.62074`
    .. [2] Lam, L., Seong-Whan Lee, and Ching Y. Suen, "Thinning
           Methodologies-A Comprehensive Survey," IEEE Transactions on
           Pattern Analysis and Machine Intelligence, Vol 14, No. 9,
           p. 879, 1992. :DOI:`10.1109/34.161346`
    Examples
    --------
    >>> square = np.zeros((7, 7), dtype=np.uint8)
    >>> square[1:-1, 2:-2] = 1
    >>> square[0, 1] = 1
    >>> square
    array([[0, 1, 0, 0, 0, 0, 0],
           [0, 0, 1, 1, 1, 0, 0],
           [0, 0, 1, 1, 1, 0, 0],
           [0, 0, 1, 1, 1, 0, 0],
           [0, 0, 1, 1, 1, 0, 0],
           [0, 0, 1, 1, 1, 0, 0],
           [0, 0, 0, 0, 0, 0, 0]], dtype=uint8)
    >>> skel = thin(square)
    >>> skel.astype(np.uint8)
    array([[0, 1, 0, 0, 0, 0, 0],
           [0, 0, 1, 0, 0, 0, 0],
           [0, 0, 0, 1, 0, 0, 0],
           [0, 0, 0, 1, 0, 0, 0],
           [0, 0, 0, 1, 0, 0, 0],
           [0, 0, 0, 0, 0, 0, 0],
           [0, 0, 0, 0, 0, 0, 0]], dtype=uint8)
    """
    # convert image to uint8 with values in {0, 1}
    skel = np.asanyarray(image, dtype=bool).astype(np.uint8)

    # neighborhood mask
    mask = np.array([[8, 4, 2],
                     [16, 0, 1],
                     [32, 64, 128]], dtype=np.uint8)

    # iterate until convergence, up to the iteration limit
    max_iter = max_iter or np.inf
    n_iter = 0
    n_pts_old, n_pts_new = np.inf, np.sum(skel)
    while n_pts_old != n_pts_new and n_iter < max_iter:
        n_pts_old = n_pts_new

        # perform the two "subiterations" described in the paper
        for lut in [G123_LUT, G123P_LUT]:
            # correlate image with neighborhood mask
            N = ndi.correlate(skel, mask, mode='constant')
            # take deletion decision from this subiteration's LUT
            D = np.take(lut, N)
            # perform deletion
            skel[D] = 0

        n_pts_new = np.sum(skel)  # count points after thinning
        n_iter += 1

    return skel.astype(bool)
ppocr/utils/e2e_utils/visual.py
0 → 100644
View file @
1f76f449
import os
import numpy as np
import cv2
import time


def visualize_e2e_result(im_fn, poly_list, seq_strs, src_im):
    """
    Draw gt polygons (blue) and detected polygons (red) with their
    recognized strings on a copy of the source image.
    """
    result_path = './out'
    im_basename = os.path.basename(im_fn)
    im_prefix = im_basename[:im_basename.rfind('.')]
    vis_det_img = src_im.copy()

    valid_set = 'partvgg'
    gt_dir = "/Users/hongyongjie/Downloads/part_vgg_synth/train"  # NOTE: hard-coded local path
    text_path = os.path.join(gt_dir, im_prefix + '.txt')
    with open(text_path, 'r') as fid:
        lines = [line.strip() for line in fid.readlines()]
    for line in lines:
        if valid_set == 'partvgg':
            tokens = line.strip().split('\t')[0].split(',')
            # tokens = line.strip().split(',')
            coords = tokens[:]
            coords = list(map(float, coords))
            gt_poly = np.array(coords).reshape(1, 4, 2)
        elif valid_set == 'totaltext':
            tokens = line.strip().split('\t')[0].split(',')
            coords = tokens[:]
            coords_len = len(coords) // 2  # integer division; '/' would break reshape in Python 3
            coords = list(map(float, coords))
            gt_poly = np.array(coords).reshape(1, coords_len, 2)
        cv2.polylines(
            vis_det_img,
            np.array(gt_poly).astype(np.int32),
            isClosed=True,
            color=(255, 0, 0),
            thickness=2)

    for detected_poly, recognized_str in zip(poly_list, seq_strs):
        cv2.polylines(
            vis_det_img,
            np.array(detected_poly[np.newaxis, ...]).astype(np.int32),
            isClosed=True,
            color=(0, 0, 255),
            thickness=2)
        cv2.putText(
            vis_det_img,
            recognized_str,
            org=(int(detected_poly[0, 0]), int(detected_poly[0, 1])),
            fontFace=cv2.FONT_HERSHEY_COMPLEX,
            fontScale=0.7,
            color=(0, 255, 0),
            thickness=1)

    if not os.path.exists(result_path):
        os.makedirs(result_path)
    cv2.imwrite("{}/{}_detection.jpg".format(result_path, im_prefix), vis_det_img)
def visualization_output(src_image, f_tcl, f_chars, output_dir,
                         image_prefix=None):
    """
    Overlay the predicted TCL map and the character classification map on
    the de-normalized source image and write both visualizations to disk.
    """
    # restore BGR image, CHW -> HWC
    im_mean = [0.485, 0.456, 0.406]
    im_std = [0.229, 0.224, 0.225]
    im_mean = np.array(im_mean).reshape((3, 1, 1))
    im_std = np.array(im_std).reshape((3, 1, 1))
    src_image *= im_std
    src_image += im_mean
    src_image = src_image.transpose([1, 2, 0])
    src_image = src_image[:, :, ::-1] * 255  # BGR -> RGB
    H, W, _ = src_image.shape

    file_prefix = image_prefix if image_prefix is not None else str(
        int(time.time() * 1000))
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    # visualization f_tcl
    tcl_file_name = os.path.join(output_dir, file_prefix + '_0_tcl.jpg')
    vis_tcl_img = src_image.copy()
    f_tcl_resized = cv2.resize(f_tcl, dsize=(W, H))
    vis_tcl_img[:, :, 1] = f_tcl_resized * 255
    cv2.imwrite(tcl_file_name, vis_tcl_img)

    # visualization char maps
    vis_char_img = src_image.copy()
    # CHW -> HWC
    char_file_name = os.path.join(output_dir, file_prefix + '_1_chars.jpg')
    f_chars = np.argmax(f_chars, axis=2)[:, :, np.newaxis].astype('float32')
    f_chars[f_chars < 95] = 1.0
    f_chars[f_chars == 95] = 0.0
    f_chars_resized = cv2.resize(f_chars, dsize=(W, H))
    vis_char_img[:, :, 1] = f_chars_resized * 255
    cv2.imwrite(char_file_name, vis_char_img)
def visualize_point_result(im_fn, point_list, point_pair_list, src_im, gt_dir,
                           result_path):
    """
    Draw ground-truth polygons together with predicted center points and
    their border point pairs, then save the visualization.
    """
    im_basename = os.path.basename(im_fn)
    im_prefix = im_basename[:im_basename.rfind('.')]
    vis_det_img = src_im.copy()

    # draw gt bbox on the image.
    text_path = os.path.join(gt_dir, im_prefix + '.txt')
    fid = open(text_path, 'r')
    lines = [line.strip() for line in fid.readlines()]
    for line in lines:
        tokens = line.strip().split('\t')
        coords = tokens[0].split(',')
        coords_len = len(coords)
        coords = list(map(float, coords))
        # integer division: reshape needs an int
        gt_poly = np.array(coords).reshape(1, coords_len // 2, 2)
        cv2.polylines(
            vis_det_img,
            np.array(gt_poly).astype(np.int32),
            isClosed=True,
            color=(255, 255, 255),
            thickness=1)

    for point, point_pair in zip(point_list, point_pair_list):
        cv2.line(
            vis_det_img,
            tuple(point_pair[0]),
            tuple(point_pair[1]), (0, 255, 255),
            thickness=1)
        cv2.circle(vis_det_img, tuple(point), 2, (0, 0, 255))
        cv2.circle(vis_det_img, tuple(point_pair[0]), 2, (255, 0, 0))
        cv2.circle(vis_det_img, tuple(point_pair[1]), 2, (0, 255, 0))

    if not os.path.exists(result_path):
        os.makedirs(result_path)
    cv2.imwrite("{}/{}_border_points.jpg".format(result_path, im_prefix),
                vis_det_img)
def resize_image(im, max_side_len=512):
    """
    Resize the image so that both sides are multiples of max_stride, as
    required by the network.
    :param im: the image to resize
    :param max_side_len: limit on the max image size to avoid running out of GPU memory
    :return: the resized image and the resize ratios (ratio_h, ratio_w)
    """
    h, w, _ = im.shape

    resize_w = w
    resize_h = h

    # Fix the longer side
    if resize_h > resize_w:
        ratio = float(max_side_len) / resize_h
    else:
        ratio = float(max_side_len) / resize_w

    resize_h = int(resize_h * ratio)
    resize_w = int(resize_w * ratio)

    max_stride = 128
    resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
    resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
    im = cv2.resize(im, (int(resize_w), int(resize_h)))
    ratio_h = resize_h / float(h)
    ratio_w = resize_w / float(w)

    return im, (ratio_h, ratio_w)
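
A quick worked example of the rounding behavior (illustrative, not part of the file): for a 720x1280 input with max_side_len=512, the longer side sets ratio = 512/1280 = 0.4, which first gives 288x512; rounding each side up to a multiple of 128 then yields 384x512.

import numpy as np
from ppocr.utils.e2e_utils.visual import resize_image

img = np.zeros((720, 1280, 3), dtype=np.uint8)  # dummy 720x1280 image
resized, (ratio_h, ratio_w) = resize_image(img, max_side_len=512)
print(resized.shape)     # (384, 512, 3)
print(ratio_h, ratio_w)  # 384/720 ~= 0.533, 512/1280 = 0.4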
def resize_image_min(im, max_side_len=512):
    """
    Resize the image so that its shorter side is scaled to max_side_len,
    then round both sides up to multiples of max_stride.
    """
    print('--> Using resize_image_min')
    h, w, _ = im.shape

    resize_w = w
    resize_h = h

    # Fix the shorter side
    if resize_h < resize_w:
        ratio = float(max_side_len) / resize_h
    else:
        ratio = float(max_side_len) / resize_w

    resize_h = int(resize_h * ratio)
    resize_w = int(resize_w * ratio)

    max_stride = 128
    resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
    resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
    im = cv2.resize(im, (int(resize_w), int(resize_h)))
    ratio_h = resize_h / float(h)
    ratio_w = resize_w / float(w)
    return im, (ratio_h, ratio_w)
def resize_image_for_totaltext(im, max_side_len=512):
    """
    Upscale the image by a fixed 1.25x ratio, capped so that the height
    does not exceed max_side_len; then round both sides up to multiples
    of max_stride.
    """
    h, w, _ = im.shape

    resize_w = w
    resize_h = h
    ratio = 1.25
    if h * ratio > max_side_len:
        ratio = float(max_side_len) / resize_h
    # Fix the longer side
    # if resize_h > resize_w:
    #     ratio = float(max_side_len) / resize_h
    # else:
    #     ratio = float(max_side_len) / resize_w

    resize_h = int(resize_h * ratio)
    resize_w = int(resize_w * ratio)

    max_stride = 128
    resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
    resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
    im = cv2.resize(im, (int(resize_w), int(resize_h)))
    ratio_h = resize_h / float(h)
    ratio_w = resize_w / float(w)
    return im, (ratio_h, ratio_w)
def point_pair2poly(point_pair_list):
    """
    Transfer vertical point pairs into a polygon with points in clockwise order.
    """
    pair_length_list = []
    for point_pair in point_pair_list:
        pair_length = np.linalg.norm(point_pair[0] - point_pair[1])
        pair_length_list.append(pair_length)
    pair_length_list = np.array(pair_length_list)
    pair_info = (pair_length_list.max(), pair_length_list.min(),
                 pair_length_list.mean())

    # construct poly
    point_num = len(point_pair_list) * 2
    point_list = [0] * point_num
    for idx, point_pair in enumerate(point_pair_list):
        point_list[idx] = point_pair[0]
        point_list[point_num - 1 - idx] = point_pair[1]
    return np.array(point_list).reshape(-1, 2), pair_info
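
To make the ordering concrete, a small hypothetical example: three vertical point pairs (top point first, bottom point second), sampled left to right along a text line, become a 6-point polygon that runs along the top edge and back along the bottom edge.

import numpy as np
from ppocr.utils.e2e_utils.visual import point_pair2poly

point_pair_list = [
    np.array([[0., 0.], [0., 10.]]),
    np.array([[5., 0.], [5., 10.]]),
    np.array([[10., 0.], [10., 10.]]),
]
poly, pair_info = point_pair2poly(point_pair_list)
# poly: [[0, 0], [5, 0], [10, 0],     <- top edge, left to right
#        [10, 10], [5, 10], [0, 10]]  <- bottom edge, right to left
# pair_info: (10.0, 10.0, 10.0), i.e. max/min/mean pair height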
def shrink_quad_along_width(quad, begin_width_ratio=0., end_width_ratio=1.):
    """
    Shrink (or, with ratios outside [0, 1], extend) a quad along its width
    direction, keeping the given begin/end width ratios.
    """
    ratio_pair = np.array(
        [[begin_width_ratio], [end_width_ratio]], dtype=np.float32)
    p0_1 = quad[0] + (quad[1] - quad[0]) * ratio_pair
    p3_2 = quad[3] + (quad[2] - quad[3]) * ratio_pair
    return np.array([p0_1[0], p0_1[1], p3_2[1], p3_2[0]])
def expand_poly_along_width(poly, shrink_ratio_of_width=0.3):
    """
    Expand poly along its width direction at both ends.
    """
    point_num = poly.shape[0]
    left_quad = np.array(
        [poly[0], poly[1], poly[-2], poly[-1]], dtype=np.float32)
    left_ratio = -shrink_ratio_of_width * np.linalg.norm(left_quad[0] - left_quad[3]) / \
                 (np.linalg.norm(left_quad[0] - left_quad[1]) + 1e-6)
    left_quad_expand = shrink_quad_along_width(left_quad, left_ratio, 1.0)
    right_quad = np.array(
        [
            poly[point_num // 2 - 2], poly[point_num // 2 - 1],
            poly[point_num // 2], poly[point_num // 2 + 1]
        ],
        dtype=np.float32)
    right_ratio = 1.0 + \
                  shrink_ratio_of_width * np.linalg.norm(right_quad[0] - right_quad[3]) / \
                  (np.linalg.norm(right_quad[0] - right_quad[1]) + 1e-6)
    right_quad_expand = shrink_quad_along_width(right_quad, 0.0, right_ratio)
    poly[0] = left_quad_expand[0]
    poly[-1] = left_quad_expand[-1]
    poly[point_num // 2 - 1] = right_quad_expand[1]
    poly[point_num // 2] = right_quad_expand[2]
    return poly
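
Continuing the hypothetical 6-point polygon from the sketch above, expand_poly_along_width pushes each end of the polygon outward by shrink_ratio_of_width times the local text height (here 0.3 * 10 = 3 pixels per end); note that it also modifies poly in place.

import numpy as np
from ppocr.utils.e2e_utils.visual import expand_poly_along_width

poly = np.array([[0., 0.], [5., 0.], [10., 0.],
                 [10., 10.], [5., 10.], [0., 10.]])
expanded = expand_poly_along_width(poly, shrink_ratio_of_width=0.3)
# expanded: [[-3, 0], [5, 0], [13, 0],
#            [13, 10], [5, 10], [-3, 10]]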
def norm2(x, axis=None):
    if axis is not None:  # `if axis:` would wrongly ignore axis=0
        return np.sqrt(np.sum(x**2, axis=axis))
    return np.sqrt(np.sum(x**2))
def cos(p1, p2):
    return (p1 * p2).sum() / (norm2(p1) * norm2(p2))
def generate_direction_info(image_fn,
                            H,
                            W,
                            ratio_h,
                            ratio_w,
                            max_length=640,
                            out_scale=4,
                            gt_dir=None):
    """
    Rasterize a per-pixel text-direction map from the ground-truth polygons:
    the first two channels hold each instance's direction vector (one
    character step), the third an average per-character distance.
    """
    im_basename = os.path.basename(image_fn)
    im_prefix = im_basename[:im_basename.rfind('.')]
    instance_direction_map = np.zeros(
        shape=[H // out_scale, W // out_scale, 3])
    if gt_dir is None:
        gt_dir = '/home/vis/huangzuming/data/SYNTH_DATA/part_vgg_synth_icdar/processed/val/poly'

    # get gt label map
    text_path = os.path.join(gt_dir, im_prefix + '.txt')
    fid = open(text_path, 'r')
    lines = [line.strip() for line in fid.readlines()]
    for label_idx, line in enumerate(lines, start=1):
        coords, txt = line.strip().split('\t')
        if txt == '###':
            continue
        tokens = coords.strip().split(',')
        coords = list(map(float, tokens))
        poly = np.array(coords).reshape(4, 2) * np.array(
            [ratio_w, ratio_h]).reshape(1, 2) / out_scale
        mid_idx = poly.shape[0] // 2
        direct_vector = (
            (poly[mid_idx] + poly[mid_idx - 1]) - (poly[0] + poly[-1])) / 2.0
        direct_vector /= len(txt)
        # l2_distance = norm2(direct_vector)
        # avg_char_distance = l2_distance / len(txt)
        avg_char_distance = 1.0
        direct_label = (direct_vector[0], direct_vector[1], avg_char_distance)
        cv2.fillPoly(instance_direction_map,
                     poly.round().astype(np.int32)[np.newaxis, :, :],
                     direct_label)
    instance_direction_map = instance_direction_map.transpose([2, 0, 1])
    return instance_direction_map[:2, ...]
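
For intuition about the direction label (illustrative values only): the vector is the offset from the midpoint of the quad's left edge to the midpoint of its right edge, divided by the transcript length, i.e. roughly one character step. Replicating the core computation for an axis-aligned 100x20 quad holding a hypothetical 5-character word:

import numpy as np

poly = np.array([[0., 0.], [100., 0.], [100., 20.], [0., 20.]])
txt = 'HELLO'  # hypothetical transcript
mid_idx = poly.shape[0] // 2
direct_vector = ((poly[mid_idx] + poly[mid_idx - 1]) -
                 (poly[0] + poly[-1])) / 2.0
direct_vector /= len(txt)
# direct_vector == [20., 0.]: one 20-px character step to the right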
tools/infer_e2e.py
0 → 100755
View file @
1f76f449
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

import os
import sys

__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))

os.environ["FLAGS_allocator_strategy"] = 'auto_growth'

import cv2
import json
import paddle

from ppocr.data import create_operators, transform
from ppocr.modeling.architectures import build_model
from ppocr.postprocess import build_post_process
from ppocr.utils.save_load import init_model
from ppocr.utils.utility import get_image_file_list
import tools.program as program
def draw_e2e_res(dt_boxes, strs, config, img, img_name):
    if len(dt_boxes) > 0:
        src_im = img
        # use `txt`, not `str`, to avoid shadowing the builtin
        for box, txt in zip(dt_boxes, strs):
            box = box.astype(np.int32).reshape((-1, 1, 2))
            cv2.polylines(src_im, [box], True, color=(255, 255, 0), thickness=2)
            cv2.putText(
                src_im,
                txt,
                org=(int(box[0, 0, 0]), int(box[0, 0, 1])),
                fontFace=cv2.FONT_HERSHEY_COMPLEX,
                fontScale=0.7,
                color=(0, 255, 0),
                thickness=1)
        save_det_path = os.path.dirname(config['Global'][
            'save_res_path']) + "/e2e_results/"
        if not os.path.exists(save_det_path):
            os.makedirs(save_det_path)
        save_path = os.path.join(save_det_path, os.path.basename(img_name))
        cv2.imwrite(save_path, src_im)
        logger.info("The e2e Image saved in {}".format(save_path))
def main():
    global_config = config['Global']

    # build model
    model = build_model(config['Architecture'])

    init_model(config, model, logger)

    # build post process
    post_process_class = build_post_process(config['PostProcess'])

    # create data ops
    transforms = []
    for op in config['Eval']['dataset']['transforms']:
        op_name = list(op)[0]
        if 'Label' in op_name:
            continue
        elif op_name == 'KeepKeys':
            op[op_name]['keep_keys'] = ['image', 'shape']
        transforms.append(op)

    ops = create_operators(transforms, global_config)

    save_res_path = config['Global']['save_res_path']
    if not os.path.exists(os.path.dirname(save_res_path)):
        os.makedirs(os.path.dirname(save_res_path))

    model.eval()
    with open(save_res_path, "wb") as fout:
        for file in get_image_file_list(config['Global']['infer_img']):
            logger.info("infer_img: {}".format(file))
            with open(file, 'rb') as f:
                img = f.read()
                data = {'image': img}
            batch = transform(data, ops)
            images = np.expand_dims(batch[0], axis=0)
            shape_list = np.expand_dims(batch[1], axis=0)
            images = paddle.to_tensor(images)
            preds = model(images)
            post_result = post_process_class(preds, shape_list)
            points, strs = post_result['points'], post_result['strs']
            # write result
            dt_boxes_json = []
            # use `txt`, not `str`, to avoid shadowing the builtin
            for poly, txt in zip(points, strs):
                tmp_json = {"transcription": txt}
                tmp_json['points'] = poly.tolist()
                dt_boxes_json.append(tmp_json)
            otstr = file + "\t" + json.dumps(dt_boxes_json) + "\n"
            fout.write(otstr.encode())
            src_img = cv2.imread(file)
            draw_e2e_res(points, strs, config, src_img, file)
    logger.info("success!")


if __name__ == '__main__':
    config, device, logger, vdl_writer = program.preprocess()
    main()
\ No newline at end of file
tools/program.py
View file @
1f76f449
...
@@ -44,6 +44,7 @@ class ArgsParser(ArgumentParser):
     def parse_args(self, argv=None):
         args = super(ArgsParser, self).parse_args(argv)
+        args.config = '/Users/hongyongjie/project/PaddleOCR/configs/e2e/e2e_r50_vd_pg.yml'
         assert args.config is not None, \
             "Please specify --config=configure_file_path."
         args.opt = self._parse_opt(args.opt)
...
@@ -374,7 +375,8 @@ def preprocess(is_train=False):
     alg = config['Architecture']['algorithm']
     assert alg in [
-        'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN', 'CLS'
+        'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
+        'CLS', 'PG'
     ]
     device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu'
...