Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
wangsen
paddle_dbnet
Commits
bd7f8f72
"csrc/vscode:/vscode.git/clone" did not exist on "142fb7ce6c2739eda197624155374a711a88c01e"
Unverified
Commit
bd7f8f72
authored
Dec 09, 2020
by
MissPenguin
Committed by
GitHub
Dec 09, 2020
Browse files
Merge pull request #1363 from MissPenguin/dygraph
add east & sast
parents
3c9d3f6b
d42bf7a0
Changes
26
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
641 additions
and
4 deletions
+641
-4
ppocr/postprocess/__init__.py
ppocr/postprocess/__init__.py
+3
-1
ppocr/postprocess/db_postprocess.py
ppocr/postprocess/db_postprocess.py
+2
-2
ppocr/postprocess/east_postprocess.py
ppocr/postprocess/east_postprocess.py
+141
-0
ppocr/postprocess/locality_aware_nms.py
ppocr/postprocess/locality_aware_nms.py
+199
-0
ppocr/postprocess/rec_postprocess.py
ppocr/postprocess/rec_postprocess.py
+1
-1
ppocr/postprocess/sast_postprocess.py
ppocr/postprocess/sast_postprocess.py
+295
-0
No files found.
ppocr/postprocess/__init__.py
View file @
bd7f8f72
...
...
@@ -24,11 +24,13 @@ __all__ = ['build_post_process']
def
build_post_process
(
config
,
global_config
=
None
):
from
.db_postprocess
import
DBPostProcess
from
.east_postprocess
import
EASTPostProcess
from
.sast_postprocess
import
SASTPostProcess
from
.rec_postprocess
import
CTCLabelDecode
,
AttnLabelDecode
from
.cls_postprocess
import
ClsPostProcess
support_dict
=
[
'DBPostProcess'
,
'CTCLabelDecode'
,
'AttnLabelDecode'
,
'ClsPostProcess'
'DBPostProcess'
,
'EASTPostProcess'
,
'SASTPostProcess'
,
'CTCLabelDecode'
,
'AttnLabelDecode'
,
'ClsPostProcess'
]
config
=
copy
.
deepcopy
(
config
)
...
...
ppocr/postprocess/db_postprocess.py
View file @
bd7f8f72
...
...
@@ -138,9 +138,9 @@ class DBPostProcess(object):
boxes_batch
=
[]
for
batch_index
in
range
(
pred
.
shape
[
0
]):
height
,
width
=
shape_list
[
batch_index
]
src_h
,
src_w
,
ratio_h
,
ratio_w
=
shape_list
[
batch_index
]
boxes
,
scores
=
self
.
boxes_from_bitmap
(
pred
[
batch_index
],
segmentation
[
batch_index
],
width
,
height
)
pred
[
batch_index
],
segmentation
[
batch_index
],
src_w
,
src_h
)
boxes_batch
.
append
({
'points'
:
boxes
})
return
boxes_batch
\ No newline at end of file
ppocr/postprocess/east_postprocess.py
0 → 100644
View file @
bd7f8f72
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
numpy
as
np
from
.locality_aware_nms
import
nms_locality
import
cv2
import
os
import
sys
# __dir__ = os.path.dirname(os.path.abspath(__file__))
# sys.path.append(__dir__)
# sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
class EASTPostProcess(object):
    """
    The post process for EAST.

    Turns the ``f_score`` / ``f_geo`` network outputs into quadrilateral
    text boxes, one result dict per image.
    """

    def __init__(self,
                 score_thresh=0.8,
                 cover_thresh=0.1,
                 nms_thresh=0.2,
                 **kwargs):
        """
        Store the thresholds used by :meth:`detect`.

        Args:
            score_thresh: pixels of the score map above this value seed a box.
            cover_thresh: boxes whose mean score under their mask is not
                above this value are dropped after NMS.
            nms_thresh: overlap threshold forwarded to locality-aware NMS.
            **kwargs: accepted and ignored.
        """
        self.score_thresh = score_thresh
        self.cover_thresh = cover_thresh
        self.nms_thresh = nms_thresh

        # c++ la-nms is faster, but only support python 3.5
        self.is_python35 = (sys.version_info.major == 3 and
                            sys.version_info.minor == 5)

    def restore_rectangle_quad(self, origin, geometry):
        """
        Restore quads from per-pixel origins and predicted offsets.

        Args:
            origin: (n, 2) array of origin coordinates.
            geometry: (n, 8) array of offsets, subtracted from the origin
                repeated four times.

        Returns:
            (n, 4, 2) array of quad vertices.
        """
        # quad
        origin_concat = np.concatenate(
            (origin, origin, origin, origin), axis=1)  # (n, 8)
        pred_quads = origin_concat - geometry
        return pred_quads.reshape((-1, 4, 2))  # (n, 4, 2)

    def _locality_nms(self, boxes, nms_thresh):
        """Run the C++ ``lanms`` when usable, else the Python ``nms_locality``."""
        if self.is_python35:
            try:
                import lanms
            except ImportError:
                # The optional extension is not installed; use the slower
                # pure-Python implementation instead of crashing.
                pass
            else:
                return lanms.merge_quadrangle_n9(boxes, nms_thresh)
        return nms_locality(boxes.astype(np.float64), nms_thresh)

    def detect(self,
               score_map,
               geo_map,
               score_thresh=0.8,
               cover_thresh=0.1,
               nms_thresh=0.2):
        """
        restore text boxes from score map and geo map

        Args:
            score_map: score output; only ``score_map[0]`` is used.
            geo_map: geometry output; axis 0 is moved to the last position
                before indexing (assumes a 3-D channel-first map -- TODO confirm).
            score_thresh: minimum pixel score to generate a proposal.
            cover_thresh: minimum mean score under a box mask to keep it.
            nms_thresh: threshold passed to the NMS step.

        Returns:
            ``[]`` when nothing is detected, otherwise an (k, 9) array whose
            first 8 columns are quad coordinates and last column the score.
        """
        score_map = score_map[0]
        geo_map = np.swapaxes(geo_map, 1, 0)
        geo_map = np.swapaxes(geo_map, 1, 2)
        # filter the score map
        xy_text = np.argwhere(score_map > score_thresh)
        if len(xy_text) == 0:
            return []
        # sort the text boxes via the y axis
        xy_text = xy_text[np.argsort(xy_text[:, 0])]
        # restore quad proposals; (row, col) is flipped to (x, y) and scaled
        # by 4 (presumably the output stride of the network -- TODO confirm)
        text_box_restored = self.restore_rectangle_quad(
            xy_text[:, ::-1] * 4, geo_map[xy_text[:, 0], xy_text[:, 1], :])
        boxes = np.zeros((text_box_restored.shape[0], 9), dtype=np.float32)
        boxes[:, :8] = text_box_restored.reshape((-1, 8))
        boxes[:, 8] = score_map[xy_text[:, 0], xy_text[:, 1]]
        boxes = self._locality_nms(boxes, nms_thresh)
        if boxes.shape[0] == 0:
            return []
        # Here we filter some low score boxes by the average score map,
        # this is different from the original paper.
        for i, box in enumerate(boxes):
            mask = np.zeros_like(score_map, dtype=np.uint8)
            # The box is scaled back down by 4 to index the score map.
            cv2.fillPoly(mask, box[:8].reshape(
                (-1, 4, 2)).astype(np.int32) // 4, 1)
            boxes[i, 8] = cv2.mean(score_map, mask)[0]
        boxes = boxes[boxes[:, 8] > cover_thresh]
        return boxes

    def sort_poly(self, p):
        """
        Sort polygons.

        Rotates the 4 vertices so the one with the smallest ``x + y`` comes
        first, then reverses the traversal direction when the first edge is
        not longer along x than along y.
        """
        min_axis = np.argmin(np.sum(p, axis=1))
        p = p[[min_axis, (min_axis + 1) % 4,
               (min_axis + 2) % 4, (min_axis + 3) % 4]]
        if abs(p[0, 0] - p[1, 0]) > abs(p[0, 1] - p[1, 1]):
            return p
        else:
            return p[[0, 3, 2, 1]]

    def __call__(self, outs_dict, shape_list):
        """
        Post-process a batch.

        Args:
            outs_dict: mapping with ``'f_score'`` and ``'f_geo'`` entries
                whose items expose ``.numpy()``.
            shape_list: per image ``(src_h, src_w, ratio_h, ratio_w)``;
                only the two ratios are used here.

        Returns:
            list with one ``{'points': ndarray}`` dict per image.
        """
        score_list = outs_dict['f_score']
        geo_list = outs_dict['f_geo']
        img_num = len(shape_list)
        dt_boxes_list = []
        for ino in range(img_num):
            score = score_list[ino].numpy()
            geo = geo_list[ino].numpy()
            boxes = self.detect(
                score_map=score,
                geo_map=geo,
                score_thresh=self.score_thresh,
                cover_thresh=self.cover_thresh,
                nms_thresh=self.nms_thresh)
            boxes_norm = []
            if len(boxes) > 0:
                _, _, ratio_h, ratio_w = shape_list[ino]
                boxes = boxes[:, :8].reshape((-1, 4, 2))
                # Undo the resize ratios applied to the image.
                boxes[:, :, 0] /= ratio_w
                boxes[:, :, 1] /= ratio_h
                for i_box, box in enumerate(boxes):
                    box = self.sort_poly(box.astype(np.int32))
                    # Drop boxes with a side shorter than 5 pixels.
                    if np.linalg.norm(box[0] - box[1]) < 5 \
                            or np.linalg.norm(box[3] - box[0]) < 5:
                        continue
                    boxes_norm.append(box)
            dt_boxes_list.append({'points': np.array(boxes_norm)})
        return dt_boxes_list
\ No newline at end of file
ppocr/postprocess/locality_aware_nms.py
0 → 100644
View file @
bd7f8f72
"""
Locality aware nms.
"""
import
numpy
as
np
from
shapely.geometry
import
Polygon
def intersection(g, p):
    """
    Intersection.

    Compute intersection-over-union of two quads.

    Args:
        g: array whose first 8 values are the 4 (x, y) vertices of a quad.
        p: same layout as ``g``.

    Returns:
        ``inter / union`` of the two polygons, or 0 when either polygon is
        invalid or the union area is zero.
    """
    g = Polygon(g[:8].reshape((4, 2)))
    p = Polygon(p[:8].reshape((4, 2)))
    # buffer(0) repairs self-intersecting quads.
    g = g.buffer(0)
    p = p.buffer(0)
    if not g.is_valid or not p.is_valid:
        return 0
    # Intersect the repaired geometries directly: buffer(0) may return a
    # multi-part geometry, which cannot be re-wrapped in Polygon(...).
    inter = g.intersection(p).area
    union = g.area + p.area - inter
    if union == 0:
        return 0
    return inter / union
def intersection_iog(g, p):
    """
    Intersection_iog.

    Ratio of the overlap area of two quads to the area of ``p``.

    Args:
        g: array whose first 8 values are the 4 (x, y) vertices of a quad.
        p: same layout as ``g``; its area is the denominator.

    Returns:
        ``inter / p.area``, or 0 when either polygon is invalid or ``p``
        has zero area (a message is printed in the latter case).
    """
    poly_g = Polygon(g[:8].reshape((4, 2)))
    poly_p = Polygon(p[:8].reshape((4, 2)))
    if not (poly_g.is_valid and poly_p.is_valid):
        return 0
    overlap = poly_g.intersection(poly_p).area
    # Only the area of p is used as denominator (not the true union).
    denom = poly_p.area
    if denom == 0:
        print("p_area is very small")
        return 0
    return overlap / denom
def weighted_merge(g, p):
    """
    Weighted merge.

    Merge two boxes into one whose coordinates are the score-weighted
    average of both and whose score is the sum of both scores.

    Args:
        g: array of 9 values, 8 coordinates followed by a score.
        p: same layout as ``g``.

    Returns:
        A new array with the layout and dtype of ``g``. Neither input is
        modified, so a row view of the caller's array stays intact.
    """
    merged = g.copy()
    score_sum = g[8] + p[8]
    merged[:8] = (g[8] * g[:8] + p[8] * p[:8]) / score_sum
    merged[8] = score_sum
    return merged
def standard_nms(S, thres):
    """
    Standard nms.

    Args:
        S: (N, 9) array; columns 0-7 are quad coordinates, column 8 the score.
        thres: a box is suppressed when its overlap (as returned by
            ``intersection``) with a kept box exceeds this value.

    Returns:
        The rows of ``S`` that survive, highest score first.
    """
    remaining = np.argsort(S[:, 8])[::-1]
    kept = []
    while remaining.size > 0:
        best = remaining[0]
        others = remaining[1:]
        kept.append(best)
        overlaps = np.array([intersection(S[best], S[idx]) for idx in others])
        survivors = np.where(overlaps <= thres)[0]
        remaining = others[survivors]
    return S[kept]
def standard_nms_inds(S, thres):
    """
    Standard nms, return inds.

    Args:
        S: (N, 9) array; columns 0-7 are quad coordinates, column 8 the score.
        thres: a box is suppressed when its overlap (as returned by
            ``intersection``) with a kept box exceeds this value.

    Returns:
        List of row indices into ``S`` that survive, highest score first.
    """
    queue = np.argsort(S[:, 8])[::-1]
    keep = []
    while queue.size > 0:
        top, rest = queue[0], queue[1:]
        keep.append(top)
        ious = np.array([intersection(S[top], S[j]) for j in rest])
        # Positions in ``rest`` whose overlap is small enough to stay.
        stay = np.where(ious <= thres)[0]
        queue = rest[stay]
    return keep
def nms(S, thres):
    """
    nms.

    Greedy non-maximum suppression over quads, returning indices.

    Args:
        S: (N, 9) array; columns 0-7 are quad coordinates, column 8 the score.
        thres: overlap above which a lower-scored box is suppressed.

    Returns:
        List of kept row indices into ``S``, highest score first.
    """
    candidates = np.argsort(S[:, 8])[::-1]
    selected = []
    while candidates.size > 0:
        current = candidates[0]
        selected.append(current)
        tail = candidates[1:]
        overlap = np.array(
            [intersection(S[current], S[other]) for other in tail])
        low_overlap = np.where(overlap <= thres)[0]
        candidates = tail[low_overlap]
    return selected
def soft_nms(boxes_in, Nt_thres=0.3, threshold=0.8, sigma=0.5, method=2):
    """
    soft_nms

    Args:
        boxes_in: N x 9 array (8 coordinates + score); it is copied, not
            modified.
        Nt_thres: IoU threshold used by the linear (``method == 1``) and
            hard (any other value) weighting.
        threshold: a box whose decayed score drops below this is discarded.
        sigma: divisor in the gaussian weight ``exp(-iou ** 2 / sigma)``.
        method: 1 = linear decay, 2 = gaussian decay, anything else = hard
            suppression.

    Returns:
        The surviving rows of the working copy, or an empty array when
        there is no input box.
    """
    boxes = boxes_in.copy()
    N = boxes.shape[0]
    if N is None or N < 1:
        return np.array([])

    inds = np.arange(N)
    # NOTE: range(N) is fixed up front; N itself shrinks as boxes are
    # discarded, which turns the trailing iterations into no-ops.
    for i in range(N):
        # get max box among positions i .. N-1
        maxpos = i
        maxscore = boxes[i, 8]
        for pos in range(i + 1, N):
            if maxscore < boxes[pos, 8]:
                maxscore = boxes[pos, 8]
                maxpos = pos

        # swap the max box into position i
        boxes[[i, maxpos]] = boxes[[maxpos, i]]
        inds[[i, maxpos]] = inds[[maxpos, i]]

        tbox = boxes[i].copy()
        pos = i + 1
        # NMS iteration
        while pos < N:
            ts_iou_val = intersection(tbox, boxes[pos].copy())
            if ts_iou_val > 0:
                if method == 1:
                    weight = 1 - ts_iou_val if ts_iou_val > Nt_thres else 1
                elif method == 2:
                    weight = np.exp(-1.0 * ts_iou_val**2 / sigma)
                else:
                    weight = 0 if ts_iou_val > Nt_thres else 1
                boxes[pos, 8] = weight * boxes[pos, 8]
                # if box score falls below threshold, discard the box by
                # swapping in the last box and shrinking N
                if boxes[pos, 8] < threshold:
                    boxes[pos, :] = boxes[N - 1, :]
                    inds[pos] = inds[N - 1]
                    N -= 1
                    # re-examine the box that was just moved into ``pos``
                    pos -= 1
            pos += 1

    return boxes[:N]
def nms_locality(polys, thres=0.3):
    """
    locality aware nms of EAST
    :param polys: a N*9 numpy array. first 8 coordinates, then prob
    :param thres: overlap threshold, used both for merging consecutive
        boxes and for the final standard NMS
    :return: boxes after nms (an empty array when nothing is left)
    """
    merged = []
    current = None
    for candidate in polys:
        if current is None:
            # first box: nothing to compare against yet
            current = candidate
        elif intersection(candidate, current) > thres:
            # consecutive overlapping boxes are fused into one
            current = weighted_merge(candidate, current)
        else:
            merged.append(current)
            current = candidate
    if current is not None:
        merged.append(current)

    if len(merged) == 0:
        return np.array([])
    return standard_nms(np.array(merged), thres)
if __name__ == '__main__':
    # Manual check: print the area of a sample quad.
    # 343,350,448,135,474,143,369,359
    sample_quad = np.array([[343, 350], [448, 135], [474, 143], [369, 359]])
    print(Polygon(sample_quad).area)
\ No newline at end of file
ppocr/postprocess/rec_postprocess.py
View file @
bd7f8f72
...
...
@@ -27,7 +27,7 @@ class BaseRecLabelDecode(object):
'ch'
,
'en'
,
'en_sensitive'
,
'french'
,
'german'
,
'japan'
,
'korean'
]
assert
character_type
in
support_character_type
,
"Only {} are supported now but get {}"
.
format
(
support_character_type
,
self
.
character_
str
)
support_character_type
,
character_
type
)
if
character_type
==
"en"
:
self
.
character_str
=
"0123456789abcdefghijklmnopqrstuvwxyz"
...
...
ppocr/postprocess/sast_postprocess.py
0 → 100644
View file @
bd7f8f72
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
os
import
sys
# Make this directory and its parent importable.
# A private name is used on purpose: a module-level ``__dir__`` is the
# PEP 562 hook called by ``dir(module)``, so binding it to a string makes
# ``dir()`` on this module raise TypeError.
_CURRENT_DIR = os.path.dirname(__file__)
sys.path.append(_CURRENT_DIR)
sys.path.append(os.path.join(_CURRENT_DIR, '..'))
import
numpy
as
np
from
.locality_aware_nms
import
nms_locality
# import lanms
import
cv2
import
time
class SASTPostProcess(object):
    """
    The post process for SAST.

    Turns the ``f_score`` / ``f_border`` / ``f_tvo`` / ``f_tco`` network
    outputs into text polygons, one result dict per image.
    """

    def __init__(self,
                 score_thresh=0.5,
                 nms_thresh=0.2,
                 sample_pts_num=2,
                 shrink_ratio_of_width=0.3,
                 expand_scale=1.0,
                 tcl_map_thresh=0.5,
                 **kwargs):
        """
        Store the post-processing parameters.

        Args:
            score_thresh: stored on the instance (not read by the methods
                visible in this class).
            nms_thresh: threshold forwarded to locality-aware NMS.
            sample_pts_num: number of points sampled along each text center
                line; 0 means estimate it per instance.
            shrink_ratio_of_width: ratio used to expand both polygon ends.
            expand_scale: forwarded to ``detect_sast`` as ``offset_expand``.
            tcl_map_thresh: threshold applied to the text center line map.
            **kwargs: accepted and ignored.
        """
        self.score_thresh = score_thresh
        self.nms_thresh = nms_thresh
        self.sample_pts_num = sample_pts_num
        self.shrink_ratio_of_width = shrink_ratio_of_width
        self.expand_scale = expand_scale
        self.tcl_map_thresh = tcl_map_thresh

        # c++ la-nms is faster, but only support python 3.5
        self.is_python35 = (sys.version_info.major == 3 and
                            sys.version_info.minor == 5)

    def point_pair2poly(self, point_pair_list):
        """
        Transfer vertical point_pairs into poly point in clockwise.

        The first point of every pair fills the polygon front to back and
        the second point fills it back to front.
        """
        # construct poly
        point_num = len(point_pair_list) * 2
        point_list = [0] * point_num
        for idx, point_pair in enumerate(point_pair_list):
            point_list[idx] = point_pair[0]
            point_list[point_num - 1 - idx] = point_pair[1]
        return np.array(point_list).reshape(-1, 2)

    def shrink_quad_along_width(self,
                                quad,
                                begin_width_ratio=0.,
                                end_width_ratio=1.):
        """
        Generate shrink_quad_along_width.

        Interpolates (or extrapolates, for ratios outside [0, 1]) along
        edge 0->1 and edge 3->2 of ``quad`` at the two given ratios.
        """
        ratio_pair = np.array(
            [[begin_width_ratio], [end_width_ratio]], dtype=np.float32)
        p0_1 = quad[0] + (quad[1] - quad[0]) * ratio_pair
        p3_2 = quad[3] + (quad[2] - quad[3]) * ratio_pair
        return np.array([p0_1[0], p0_1[1], p3_2[1], p3_2[0]])

    def expand_poly_along_width(self, poly, shrink_ratio_of_width=0.3):
        """
        expand poly along width.

        Moves the four end vertices of ``poly`` outwards. ``poly`` is
        modified in place and also returned.
        """
        point_num = poly.shape[0]
        left_quad = np.array(
            [poly[0], poly[1], poly[-2], poly[-1]], dtype=np.float32)
        # 1e-6 avoids a division by zero for a degenerate first edge.
        left_ratio = -shrink_ratio_of_width * \
            np.linalg.norm(left_quad[0] - left_quad[3]) / \
            (np.linalg.norm(left_quad[0] - left_quad[1]) + 1e-6)
        left_quad_expand = self.shrink_quad_along_width(left_quad, left_ratio,
                                                        1.0)
        right_quad = np.array(
            [
                poly[point_num // 2 - 2], poly[point_num // 2 - 1],
                poly[point_num // 2], poly[point_num // 2 + 1]
            ],
            dtype=np.float32)
        right_ratio = 1.0 + \
            shrink_ratio_of_width * \
            np.linalg.norm(right_quad[0] - right_quad[3]) / \
            (np.linalg.norm(right_quad[0] - right_quad[1]) + 1e-6)
        right_quad_expand = self.shrink_quad_along_width(right_quad, 0.0,
                                                         right_ratio)
        poly[0] = left_quad_expand[0]
        poly[-1] = left_quad_expand[-1]
        poly[point_num // 2 - 1] = right_quad_expand[1]
        poly[point_num // 2] = right_quad_expand[2]
        return poly

    def restore_quad(self, tcl_map, tcl_map_thresh, tvo_map):
        """
        Restore quad.

        Returns:
            ``(scores, quads, xy_text)``: (n, 1) scores, (n, 8) quads and
            the (n, 2) ``(x, y)`` positions they were restored from.
        """
        xy_text = np.argwhere(tcl_map[:, :, 0] > tcl_map_thresh)
        xy_text = xy_text[:, ::-1]  # (n, 2)

        # Sort the text boxes via the y axis
        xy_text = xy_text[np.argsort(xy_text[:, 1])]

        scores = tcl_map[xy_text[:, 1], xy_text[:, 0], 0]
        scores = scores[:, np.newaxis]

        # Restore
        point_num = int(tvo_map.shape[-1] / 2)
        assert point_num == 4
        tvo_map = tvo_map[xy_text[:, 1], xy_text[:, 0], :]
        xy_text_tile = np.tile(xy_text, (1, point_num))  # (n, point_num * 2)
        quads = xy_text_tile - tvo_map

        return scores, quads, xy_text

    def quad_area(self, quad):
        """
        compute area of a quad.

        The result is signed; ``detect_sast`` negates it before use.
        """
        edge = [(quad[1][0] - quad[0][0]) * (quad[1][1] + quad[0][1]),
                (quad[2][0] - quad[1][0]) * (quad[2][1] + quad[1][1]),
                (quad[3][0] - quad[2][0]) * (quad[3][1] + quad[2][1]),
                (quad[0][0] - quad[3][0]) * (quad[0][1] + quad[3][1])]
        return np.sum(edge) / 2.

    def nms(self, dets):
        """Run locality-aware NMS, preferring the C++ ``lanms`` when usable."""
        if self.is_python35:
            try:
                import lanms
            except ImportError:
                # The optional extension is not installed; use the slower
                # pure-Python implementation instead of crashing.
                pass
            else:
                return lanms.merge_quadrangle_n9(dets, self.nms_thresh)
        return nms_locality(dets, self.nms_thresh)

    def cluster_by_quads_tco(self, tcl_map, tcl_map_thresh, quads, tco_map):
        """
        Cluster pixels in tcl_map based on quads.

        Returns:
            ``(instance_count, instance_label_map)``; label 0 is the
            background and label ``k`` belongs to ``quads[k - 1]``.
        """
        instance_count = quads.shape[0] + 1  # contain background
        instance_label_map = np.zeros(tcl_map.shape[:2], dtype=np.int32)
        if instance_count == 1:
            return instance_count, instance_label_map

        # predict text center
        xy_text = np.argwhere(tcl_map[:, :, 0] > tcl_map_thresh)
        xy_text = xy_text[:, ::-1]  # (n, 2)
        tco = tco_map[xy_text[:, 1], xy_text[:, 0], :]  # (n, 2)
        pred_tc = xy_text - tco

        # get gt text center
        gt_tc = np.mean(quads, axis=1)  # (m, 2)

        # Broadcasting yields the same (n, m, 2) differences as tiling.
        dist_mat = np.linalg.norm(
            pred_tc[:, np.newaxis, :] - gt_tc[np.newaxis, :, :],
            axis=2)  # (n, m)
        xy_text_assign = np.argmin(dist_mat, axis=1) + 1  # (n,)

        instance_label_map[xy_text[:, 1], xy_text[:, 0]] = xy_text_assign
        return instance_count, instance_label_map

    def estimate_sample_pts_num(self, quad, xy_text):
        """
        Estimate sample points number.

        Divides an estimate of the center-line arc length by the mean
        length of edges 0-3 and 1-2 of ``quad``; never returns less than 2.
        """
        eh = (np.linalg.norm(quad[0] - quad[3]) +
              np.linalg.norm(quad[1] - quad[2])) / 2.0
        ew = (np.linalg.norm(quad[0] - quad[1]) +
              np.linalg.norm(quad[2] - quad[3])) / 2.0

        dense_sample_pts_num = max(2, int(ew))
        dense_xy_center_line = xy_text[np.linspace(
            0,
            xy_text.shape[0] - 1,
            dense_sample_pts_num,
            endpoint=True,
            dtype=np.float32).astype(np.int32)]

        dense_xy_center_line_diff = dense_xy_center_line[
            1:] - dense_xy_center_line[:-1]
        estimate_arc_len = np.sum(
            np.linalg.norm(
                dense_xy_center_line_diff, axis=1))

        sample_pts_num = max(2, int(estimate_arc_len / eh))
        return sample_pts_num

    def detect_sast(self,
                    tcl_map,
                    tvo_map,
                    tbo_map,
                    tco_map,
                    ratio_w,
                    ratio_h,
                    src_w,
                    src_h,
                    shrink_ratio_of_width=0.3,
                    tcl_map_thresh=0.5,
                    offset_expand=1.0,
                    out_strid=4.0):
        """
        Restore text polygons from the SAST output maps.

        Args:
            tcl_map: text center line map; channel 0 is used.
            tvo_map: vertex offset map used to restore quads.
            tbo_map: border offset map; 4 values per pixel, read as a 2x2.
            tco_map: center offset map used to cluster pixels by quad.
            ratio_w, ratio_h: resize ratios divided out of the result.
            src_w, src_h: upper bounds used to clip the polygons.
            shrink_ratio_of_width: ratio used to expand the polygon ends.
            tcl_map_thresh: threshold applied to ``tcl_map``.
            offset_expand: when not 1.0, border offsets are lengthened.
            out_strid: scale applied to map coordinates.

        Returns:
            list of (k, 2) polygon arrays; empty when nothing is detected.
        """
        # restore quad
        scores, quads, xy_text = self.restore_quad(tcl_map, tcl_map_thresh,
                                                   tvo_map)
        dets = np.hstack((quads, scores)).astype(np.float32, copy=False)
        dets = self.nms(dets)
        if dets.shape[0] == 0:
            return []
        quads = dets[:, :-1].reshape(-1, 4, 2)

        # Compute quad area (negated signed area)
        quad_areas = [-self.quad_area(quad) for quad in quads]

        # instance segmentation
        instance_count, instance_label_map = self.cluster_by_quads_tco(
            tcl_map, tcl_map_thresh, quads, tco_map)

        # restore single poly with tcl instance.
        poly_list = []
        for instance_idx in range(1, instance_count):
            xy_text = np.argwhere(instance_label_map == instance_idx)[:, ::-1]
            quad = quads[instance_idx - 1]
            q_area = quad_areas[instance_idx - 1]
            if q_area < 5:
                continue

            len1 = float(np.linalg.norm(quad[0] - quad[1]))
            len2 = float(np.linalg.norm(quad[1] - quad[2]))
            min_len = min(len1, len2)
            if min_len < 3:
                continue

            # filter small CC
            if xy_text.shape[0] <= 0:
                continue

            # filter low confidence instance
            xy_text_scores = tcl_map[xy_text[:, 1], xy_text[:, 0], 0]
            if np.sum(xy_text_scores) / q_area < 0.1:
                continue

            # sort xy_text by projection onto the left-to-right center axis
            left_center_pt = np.array(
                [[(quad[0, 0] + quad[-1, 0]) / 2.0,
                  (quad[0, 1] + quad[-1, 1]) / 2.0]])  # (1, 2)
            right_center_pt = np.array(
                [[(quad[1, 0] + quad[2, 0]) / 2.0,
                  (quad[1, 1] + quad[2, 1]) / 2.0]])  # (1, 2)
            proj_unit_vec = (right_center_pt - left_center_pt) / \
                (np.linalg.norm(right_center_pt - left_center_pt) + 1e-6)
            proj_value = np.sum(xy_text * proj_unit_vec, axis=1)
            xy_text = xy_text[np.argsort(proj_value)]

            # Sample pts in tcl map
            if self.sample_pts_num == 0:
                sample_pts_num = self.estimate_sample_pts_num(quad, xy_text)
            else:
                sample_pts_num = self.sample_pts_num
            xy_center_line = xy_text[np.linspace(
                0,
                xy_text.shape[0] - 1,
                sample_pts_num,
                endpoint=True,
                dtype=np.float32).astype(np.int32)]

            point_pair_list = []
            for x, y in xy_center_line:
                # get corresponding offset
                offset = tbo_map[y, x, :].reshape(2, 2)
                if offset_expand != 1.0:
                    offset_length = np.linalg.norm(
                        offset, axis=1, keepdims=True)
                    expand_length = np.clip(
                        offset_length * (offset_expand - 1),
                        a_min=0.5,
                        a_max=3.0)
                    # 1e-6 keeps a zero-length offset from producing NaN,
                    # matching the guards used for the other norms here.
                    offset_detal = offset / (offset_length + 1e-6) * \
                        expand_length
                    offset = offset + offset_detal
                # original point
                ori_yx = np.array([y, x], dtype=np.float32)
                point_pair = (ori_yx + offset)[:, ::-1] * out_strid / \
                    np.array([ratio_w, ratio_h]).reshape(-1, 2)
                point_pair_list.append(point_pair)

            # ndarray: (x, 2), expand poly along width
            detected_poly = self.point_pair2poly(point_pair_list)
            detected_poly = self.expand_poly_along_width(
                detected_poly, shrink_ratio_of_width)
            detected_poly[:, 0] = np.clip(
                detected_poly[:, 0], a_min=0, a_max=src_w)
            detected_poly[:, 1] = np.clip(
                detected_poly[:, 1], a_min=0, a_max=src_h)
            poly_list.append(detected_poly)

        return poly_list

    def __call__(self, outs_dict, shape_list):
        """
        Post-process a batch.

        Args:
            outs_dict: mapping with ``'f_score'``, ``'f_border'``,
                ``'f_tvo'`` and ``'f_tco'`` entries whose items expose
                ``.transpose(...)`` and ``.numpy()``.
            shape_list: per image ``(src_h, src_w, ratio_h, ratio_w)``.

        Returns:
            list with one ``{'points': ndarray}`` dict per image.
        """
        score_list = outs_dict['f_score']
        border_list = outs_dict['f_border']
        tvo_list = outs_dict['f_tvo']
        tco_list = outs_dict['f_tco']

        img_num = len(shape_list)
        poly_lists = []
        for ino in range(img_num):
            # Move axis 0 last so the maps can be indexed as [y, x, channel].
            p_score = score_list[ino].transpose((1, 2, 0)).numpy()
            p_border = border_list[ino].transpose((1, 2, 0)).numpy()
            p_tvo = tvo_list[ino].transpose((1, 2, 0)).numpy()
            p_tco = tco_list[ino].transpose((1, 2, 0)).numpy()
            src_h, src_w, ratio_h, ratio_w = shape_list[ino]

            poly_list = self.detect_sast(
                p_score,
                p_tvo,
                p_border,
                p_tco,
                ratio_w,
                ratio_h,
                src_w,
                src_h,
                shrink_ratio_of_width=self.shrink_ratio_of_width,
                tcl_map_thresh=self.tcl_map_thresh,
                offset_expand=self.expand_scale)
            poly_lists.append({'points': np.array(poly_list)})

        return poly_lists
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment