Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
wangsen
paddle_dbnet
Commits
631fd9fd
Unverified
Commit
631fd9fd
authored
Dec 10, 2020
by
xiaoting
Committed by
GitHub
Dec 10, 2020
Browse files
Merge branch 'dygraph' into dygraph_doc
parents
8520dd1e
90b968d5
Changes
98
Hide whitespace changes
Inline
Side-by-side
Showing
18 changed files
with
766 additions
and
308 deletions
+766
-308
ppocr/postprocess/__init__.py
ppocr/postprocess/__init__.py
+3
-1
ppocr/postprocess/db_postprocess.py
ppocr/postprocess/db_postprocess.py
+12
-4
ppocr/postprocess/east_postprocess.py
ppocr/postprocess/east_postprocess.py
+141
-0
ppocr/postprocess/locality_aware_nms.py
ppocr/postprocess/locality_aware_nms.py
+199
-0
ppocr/postprocess/rec_postprocess.py
ppocr/postprocess/rec_postprocess.py
+1
-1
ppocr/postprocess/sast_postprocess.py
ppocr/postprocess/sast_postprocess.py
+295
-0
ppocr/utils/character.py
ppocr/utils/character.py
+0
-214
ppocr/utils/check.py
ppocr/utils/check.py
+0
-31
ppocr/utils/dict/en_dict.txt
ppocr/utils/dict/en_dict.txt
+63
-0
ppocr/utils/dict/french_dict.txt
ppocr/utils/dict/french_dict.txt
+2
-1
ppocr/utils/dict/german_dict.txt
ppocr/utils/dict/german_dict.txt
+2
-1
ppocr/utils/dict/japan_dict.txt
ppocr/utils/dict/japan_dict.txt
+2
-1
ppocr/utils/dict/korean_dict.txt
ppocr/utils/dict/korean_dict.txt
+3
-2
ppocr/utils/save_load.py
ppocr/utils/save_load.py
+2
-2
setup.py
setup.py
+1
-1
tools/export_model.py
tools/export_model.py
+14
-28
tools/infer/predict_system.py
tools/infer/predict_system.py
+20
-13
tools/infer/utility.py
tools/infer/utility.py
+6
-8
No files found.
ppocr/postprocess/__init__.py
View file @
631fd9fd
...
...
@@ -24,11 +24,13 @@ __all__ = ['build_post_process']
def
build_post_process
(
config
,
global_config
=
None
):
from
.db_postprocess
import
DBPostProcess
from
.east_postprocess
import
EASTPostProcess
from
.sast_postprocess
import
SASTPostProcess
from
.rec_postprocess
import
CTCLabelDecode
,
AttnLabelDecode
from
.cls_postprocess
import
ClsPostProcess
support_dict
=
[
'DBPostProcess'
,
'CTCLabelDecode'
,
'AttnLabelDecode'
,
'ClsPostProcess'
'DBPostProcess'
,
'EASTPostProcess'
,
'SASTPostProcess'
,
'CTCLabelDecode'
,
'AttnLabelDecode'
,
'ClsPostProcess'
]
config
=
copy
.
deepcopy
(
config
)
...
...
ppocr/postprocess/db_postprocess.py
View file @
631fd9fd
...
...
@@ -33,12 +33,14 @@ class DBPostProcess(object):
box_thresh
=
0.7
,
max_candidates
=
1000
,
unclip_ratio
=
2.0
,
use_dilation
=
False
,
**
kwargs
):
self
.
thresh
=
thresh
self
.
box_thresh
=
box_thresh
self
.
max_candidates
=
max_candidates
self
.
unclip_ratio
=
unclip_ratio
self
.
min_size
=
3
self
.
dilation_kernel
=
None
if
not
use_dilation
else
[[
1
,
1
],
[
1
,
1
]]
def
boxes_from_bitmap
(
self
,
pred
,
_bitmap
,
dest_width
,
dest_height
):
'''
...
...
@@ -138,9 +140,15 @@ class DBPostProcess(object):
boxes_batch
=
[]
for
batch_index
in
range
(
pred
.
shape
[
0
]):
height
,
width
=
shape_list
[
batch_index
]
boxes
,
scores
=
self
.
boxes_from_bitmap
(
pred
[
batch_index
],
segmentation
[
batch_index
],
width
,
height
)
src_h
,
src_w
,
ratio_h
,
ratio_w
=
shape_list
[
batch_index
]
if
self
.
dilation_kernel
is
not
None
:
mask
=
cv2
.
dilate
(
np
.
array
(
segmentation
[
batch_index
]).
astype
(
np
.
uint8
),
self
.
dilation_kernel
)
else
:
mask
=
segmentation
[
batch_index
]
boxes
,
scores
=
self
.
boxes_from_bitmap
(
pred
[
batch_index
],
mask
,
src_w
,
src_h
)
boxes_batch
.
append
({
'points'
:
boxes
})
return
boxes_batch
\ No newline at end of file
return
boxes_batch
ppocr/postprocess/east_postprocess.py
0 → 100644
View file @
631fd9fd
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
numpy
as
np
from
.locality_aware_nms
import
nms_locality
import
cv2
import
os
import
sys
# __dir__ = os.path.dirname(os.path.abspath(__file__))
# sys.path.append(__dir__)
# sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
class
EASTPostProcess
(
object
):
"""
The post process for EAST.
"""
def
__init__
(
self
,
score_thresh
=
0.8
,
cover_thresh
=
0.1
,
nms_thresh
=
0.2
,
**
kwargs
):
self
.
score_thresh
=
score_thresh
self
.
cover_thresh
=
cover_thresh
self
.
nms_thresh
=
nms_thresh
# c++ la-nms is faster, but only support python 3.5
self
.
is_python35
=
False
if
sys
.
version_info
.
major
==
3
and
sys
.
version_info
.
minor
==
5
:
self
.
is_python35
=
True
def
restore_rectangle_quad
(
self
,
origin
,
geometry
):
"""
Restore rectangle from quadrangle.
"""
# quad
origin_concat
=
np
.
concatenate
(
(
origin
,
origin
,
origin
,
origin
),
axis
=
1
)
# (n, 8)
pred_quads
=
origin_concat
-
geometry
pred_quads
=
pred_quads
.
reshape
((
-
1
,
4
,
2
))
# (n, 4, 2)
return
pred_quads
def
detect
(
self
,
score_map
,
geo_map
,
score_thresh
=
0.8
,
cover_thresh
=
0.1
,
nms_thresh
=
0.2
):
"""
restore text boxes from score map and geo map
"""
score_map
=
score_map
[
0
]
geo_map
=
np
.
swapaxes
(
geo_map
,
1
,
0
)
geo_map
=
np
.
swapaxes
(
geo_map
,
1
,
2
)
# filter the score map
xy_text
=
np
.
argwhere
(
score_map
>
score_thresh
)
if
len
(
xy_text
)
==
0
:
return
[]
# sort the text boxes via the y axis
xy_text
=
xy_text
[
np
.
argsort
(
xy_text
[:,
0
])]
#restore quad proposals
text_box_restored
=
self
.
restore_rectangle_quad
(
xy_text
[:,
::
-
1
]
*
4
,
geo_map
[
xy_text
[:,
0
],
xy_text
[:,
1
],
:])
boxes
=
np
.
zeros
((
text_box_restored
.
shape
[
0
],
9
),
dtype
=
np
.
float32
)
boxes
[:,
:
8
]
=
text_box_restored
.
reshape
((
-
1
,
8
))
boxes
[:,
8
]
=
score_map
[
xy_text
[:,
0
],
xy_text
[:,
1
]]
if
self
.
is_python35
:
import
lanms
boxes
=
lanms
.
merge_quadrangle_n9
(
boxes
,
nms_thresh
)
else
:
boxes
=
nms_locality
(
boxes
.
astype
(
np
.
float64
),
nms_thresh
)
if
boxes
.
shape
[
0
]
==
0
:
return
[]
# Here we filter some low score boxes by the average score map,
# this is different from the orginal paper.
for
i
,
box
in
enumerate
(
boxes
):
mask
=
np
.
zeros_like
(
score_map
,
dtype
=
np
.
uint8
)
cv2
.
fillPoly
(
mask
,
box
[:
8
].
reshape
(
(
-
1
,
4
,
2
)).
astype
(
np
.
int32
)
//
4
,
1
)
boxes
[
i
,
8
]
=
cv2
.
mean
(
score_map
,
mask
)[
0
]
boxes
=
boxes
[
boxes
[:,
8
]
>
cover_thresh
]
return
boxes
def
sort_poly
(
self
,
p
):
"""
Sort polygons.
"""
min_axis
=
np
.
argmin
(
np
.
sum
(
p
,
axis
=
1
))
p
=
p
[[
min_axis
,
(
min_axis
+
1
)
%
4
,
\
(
min_axis
+
2
)
%
4
,
(
min_axis
+
3
)
%
4
]]
if
abs
(
p
[
0
,
0
]
-
p
[
1
,
0
])
>
abs
(
p
[
0
,
1
]
-
p
[
1
,
1
]):
return
p
else
:
return
p
[[
0
,
3
,
2
,
1
]]
def
__call__
(
self
,
outs_dict
,
shape_list
):
score_list
=
outs_dict
[
'f_score'
]
geo_list
=
outs_dict
[
'f_geo'
]
img_num
=
len
(
shape_list
)
dt_boxes_list
=
[]
for
ino
in
range
(
img_num
):
score
=
score_list
[
ino
].
numpy
()
geo
=
geo_list
[
ino
].
numpy
()
boxes
=
self
.
detect
(
score_map
=
score
,
geo_map
=
geo
,
score_thresh
=
self
.
score_thresh
,
cover_thresh
=
self
.
cover_thresh
,
nms_thresh
=
self
.
nms_thresh
)
boxes_norm
=
[]
if
len
(
boxes
)
>
0
:
h
,
w
=
score
.
shape
[
1
:]
src_h
,
src_w
,
ratio_h
,
ratio_w
=
shape_list
[
ino
]
boxes
=
boxes
[:,
:
8
].
reshape
((
-
1
,
4
,
2
))
boxes
[:,
:,
0
]
/=
ratio_w
boxes
[:,
:,
1
]
/=
ratio_h
for
i_box
,
box
in
enumerate
(
boxes
):
box
=
self
.
sort_poly
(
box
.
astype
(
np
.
int32
))
if
np
.
linalg
.
norm
(
box
[
0
]
-
box
[
1
])
<
5
\
or
np
.
linalg
.
norm
(
box
[
3
]
-
box
[
0
])
<
5
:
continue
boxes_norm
.
append
(
box
)
dt_boxes_list
.
append
({
'points'
:
np
.
array
(
boxes_norm
)})
return
dt_boxes_list
\ No newline at end of file
ppocr/postprocess/locality_aware_nms.py
0 → 100644
View file @
631fd9fd
"""
Locality aware nms.
"""
import
numpy
as
np
from
shapely.geometry
import
Polygon
def
intersection
(
g
,
p
):
"""
Intersection.
"""
g
=
Polygon
(
g
[:
8
].
reshape
((
4
,
2
)))
p
=
Polygon
(
p
[:
8
].
reshape
((
4
,
2
)))
g
=
g
.
buffer
(
0
)
p
=
p
.
buffer
(
0
)
if
not
g
.
is_valid
or
not
p
.
is_valid
:
return
0
inter
=
Polygon
(
g
).
intersection
(
Polygon
(
p
)).
area
union
=
g
.
area
+
p
.
area
-
inter
if
union
==
0
:
return
0
else
:
return
inter
/
union
def
intersection_iog
(
g
,
p
):
"""
Intersection_iog.
"""
g
=
Polygon
(
g
[:
8
].
reshape
((
4
,
2
)))
p
=
Polygon
(
p
[:
8
].
reshape
((
4
,
2
)))
if
not
g
.
is_valid
or
not
p
.
is_valid
:
return
0
inter
=
Polygon
(
g
).
intersection
(
Polygon
(
p
)).
area
#union = g.area + p.area - inter
union
=
p
.
area
if
union
==
0
:
print
(
"p_area is very small"
)
return
0
else
:
return
inter
/
union
def
weighted_merge
(
g
,
p
):
"""
Weighted merge.
"""
g
[:
8
]
=
(
g
[
8
]
*
g
[:
8
]
+
p
[
8
]
*
p
[:
8
])
/
(
g
[
8
]
+
p
[
8
])
g
[
8
]
=
(
g
[
8
]
+
p
[
8
])
return
g
def
standard_nms
(
S
,
thres
):
"""
Standard nms.
"""
order
=
np
.
argsort
(
S
[:,
8
])[::
-
1
]
keep
=
[]
while
order
.
size
>
0
:
i
=
order
[
0
]
keep
.
append
(
i
)
ovr
=
np
.
array
([
intersection
(
S
[
i
],
S
[
t
])
for
t
in
order
[
1
:]])
inds
=
np
.
where
(
ovr
<=
thres
)[
0
]
order
=
order
[
inds
+
1
]
return
S
[
keep
]
def
standard_nms_inds
(
S
,
thres
):
"""
Standard nms, retun inds.
"""
order
=
np
.
argsort
(
S
[:,
8
])[::
-
1
]
keep
=
[]
while
order
.
size
>
0
:
i
=
order
[
0
]
keep
.
append
(
i
)
ovr
=
np
.
array
([
intersection
(
S
[
i
],
S
[
t
])
for
t
in
order
[
1
:]])
inds
=
np
.
where
(
ovr
<=
thres
)[
0
]
order
=
order
[
inds
+
1
]
return
keep
def
nms
(
S
,
thres
):
"""
nms.
"""
order
=
np
.
argsort
(
S
[:,
8
])[::
-
1
]
keep
=
[]
while
order
.
size
>
0
:
i
=
order
[
0
]
keep
.
append
(
i
)
ovr
=
np
.
array
([
intersection
(
S
[
i
],
S
[
t
])
for
t
in
order
[
1
:]])
inds
=
np
.
where
(
ovr
<=
thres
)[
0
]
order
=
order
[
inds
+
1
]
return
keep
def
soft_nms
(
boxes_in
,
Nt_thres
=
0.3
,
threshold
=
0.8
,
sigma
=
0.5
,
method
=
2
):
"""
soft_nms
:para boxes_in, N x 9 (coords + score)
:para threshould, eliminate cases min score(0.001)
:para Nt_thres, iou_threshi
:para sigma, gaussian weght
:method, linear or gaussian
"""
boxes
=
boxes_in
.
copy
()
N
=
boxes
.
shape
[
0
]
if
N
is
None
or
N
<
1
:
return
np
.
array
([])
pos
,
maxpos
=
0
,
0
weight
=
0.0
inds
=
np
.
arange
(
N
)
tbox
,
sbox
=
boxes
[
0
].
copy
(),
boxes
[
0
].
copy
()
for
i
in
range
(
N
):
maxscore
=
boxes
[
i
,
8
]
maxpos
=
i
tbox
=
boxes
[
i
].
copy
()
ti
=
inds
[
i
]
pos
=
i
+
1
#get max box
while
pos
<
N
:
if
maxscore
<
boxes
[
pos
,
8
]:
maxscore
=
boxes
[
pos
,
8
]
maxpos
=
pos
pos
=
pos
+
1
#add max box as a detection
boxes
[
i
,
:]
=
boxes
[
maxpos
,
:]
inds
[
i
]
=
inds
[
maxpos
]
#swap
boxes
[
maxpos
,
:]
=
tbox
inds
[
maxpos
]
=
ti
tbox
=
boxes
[
i
].
copy
()
pos
=
i
+
1
#NMS iteration
while
pos
<
N
:
sbox
=
boxes
[
pos
].
copy
()
ts_iou_val
=
intersection
(
tbox
,
sbox
)
if
ts_iou_val
>
0
:
if
method
==
1
:
if
ts_iou_val
>
Nt_thres
:
weight
=
1
-
ts_iou_val
else
:
weight
=
1
elif
method
==
2
:
weight
=
np
.
exp
(
-
1.0
*
ts_iou_val
**
2
/
sigma
)
else
:
if
ts_iou_val
>
Nt_thres
:
weight
=
0
else
:
weight
=
1
boxes
[
pos
,
8
]
=
weight
*
boxes
[
pos
,
8
]
#if box score falls below thresold, discard the box by
#swaping last box update N
if
boxes
[
pos
,
8
]
<
threshold
:
boxes
[
pos
,
:]
=
boxes
[
N
-
1
,
:]
inds
[
pos
]
=
inds
[
N
-
1
]
N
=
N
-
1
pos
=
pos
-
1
pos
=
pos
+
1
return
boxes
[:
N
]
def
nms_locality
(
polys
,
thres
=
0.3
):
"""
locality aware nms of EAST
:param polys: a N*9 numpy array. first 8 coordinates, then prob
:return: boxes after nms
"""
S
=
[]
p
=
None
for
g
in
polys
:
if
p
is
not
None
and
intersection
(
g
,
p
)
>
thres
:
p
=
weighted_merge
(
g
,
p
)
else
:
if
p
is
not
None
:
S
.
append
(
p
)
p
=
g
if
p
is
not
None
:
S
.
append
(
p
)
if
len
(
S
)
==
0
:
return
np
.
array
([])
return
standard_nms
(
np
.
array
(
S
),
thres
)
if
__name__
==
'__main__'
:
# 343,350,448,135,474,143,369,359
print
(
Polygon
(
np
.
array
([[
343
,
350
],
[
448
,
135
],
[
474
,
143
],
[
369
,
359
]]))
.
area
)
\ No newline at end of file
ppocr/postprocess/rec_postprocess.py
View file @
631fd9fd
...
...
@@ -27,7 +27,7 @@ class BaseRecLabelDecode(object):
'ch'
,
'en'
,
'en_sensitive'
,
'french'
,
'german'
,
'japan'
,
'korean'
]
assert
character_type
in
support_character_type
,
"Only {} are supported now but get {}"
.
format
(
support_character_type
,
self
.
character_
str
)
support_character_type
,
character_
type
)
if
character_type
==
"en"
:
self
.
character_str
=
"0123456789abcdefghijklmnopqrstuvwxyz"
...
...
ppocr/postprocess/sast_postprocess.py
0 → 100644
View file @
631fd9fd
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
os
import
sys
__dir__
=
os
.
path
.
dirname
(
__file__
)
sys
.
path
.
append
(
__dir__
)
sys
.
path
.
append
(
os
.
path
.
join
(
__dir__
,
'..'
))
import
numpy
as
np
from
.locality_aware_nms
import
nms_locality
# import lanms
import
cv2
import
time
class
SASTPostProcess
(
object
):
"""
The post process for SAST.
"""
def
__init__
(
self
,
score_thresh
=
0.5
,
nms_thresh
=
0.2
,
sample_pts_num
=
2
,
shrink_ratio_of_width
=
0.3
,
expand_scale
=
1.0
,
tcl_map_thresh
=
0.5
,
**
kwargs
):
self
.
score_thresh
=
score_thresh
self
.
nms_thresh
=
nms_thresh
self
.
sample_pts_num
=
sample_pts_num
self
.
shrink_ratio_of_width
=
shrink_ratio_of_width
self
.
expand_scale
=
expand_scale
self
.
tcl_map_thresh
=
tcl_map_thresh
# c++ la-nms is faster, but only support python 3.5
self
.
is_python35
=
False
if
sys
.
version_info
.
major
==
3
and
sys
.
version_info
.
minor
==
5
:
self
.
is_python35
=
True
def
point_pair2poly
(
self
,
point_pair_list
):
"""
Transfer vertical point_pairs into poly point in clockwise.
"""
# constract poly
point_num
=
len
(
point_pair_list
)
*
2
point_list
=
[
0
]
*
point_num
for
idx
,
point_pair
in
enumerate
(
point_pair_list
):
point_list
[
idx
]
=
point_pair
[
0
]
point_list
[
point_num
-
1
-
idx
]
=
point_pair
[
1
]
return
np
.
array
(
point_list
).
reshape
(
-
1
,
2
)
def
shrink_quad_along_width
(
self
,
quad
,
begin_width_ratio
=
0.
,
end_width_ratio
=
1.
):
"""
Generate shrink_quad_along_width.
"""
ratio_pair
=
np
.
array
([[
begin_width_ratio
],
[
end_width_ratio
]],
dtype
=
np
.
float32
)
p0_1
=
quad
[
0
]
+
(
quad
[
1
]
-
quad
[
0
])
*
ratio_pair
p3_2
=
quad
[
3
]
+
(
quad
[
2
]
-
quad
[
3
])
*
ratio_pair
return
np
.
array
([
p0_1
[
0
],
p0_1
[
1
],
p3_2
[
1
],
p3_2
[
0
]])
def
expand_poly_along_width
(
self
,
poly
,
shrink_ratio_of_width
=
0.3
):
"""
expand poly along width.
"""
point_num
=
poly
.
shape
[
0
]
left_quad
=
np
.
array
([
poly
[
0
],
poly
[
1
],
poly
[
-
2
],
poly
[
-
1
]],
dtype
=
np
.
float32
)
left_ratio
=
-
shrink_ratio_of_width
*
np
.
linalg
.
norm
(
left_quad
[
0
]
-
left_quad
[
3
])
/
\
(
np
.
linalg
.
norm
(
left_quad
[
0
]
-
left_quad
[
1
])
+
1e-6
)
left_quad_expand
=
self
.
shrink_quad_along_width
(
left_quad
,
left_ratio
,
1.0
)
right_quad
=
np
.
array
([
poly
[
point_num
//
2
-
2
],
poly
[
point_num
//
2
-
1
],
poly
[
point_num
//
2
],
poly
[
point_num
//
2
+
1
]],
dtype
=
np
.
float32
)
right_ratio
=
1.0
+
\
shrink_ratio_of_width
*
np
.
linalg
.
norm
(
right_quad
[
0
]
-
right_quad
[
3
])
/
\
(
np
.
linalg
.
norm
(
right_quad
[
0
]
-
right_quad
[
1
])
+
1e-6
)
right_quad_expand
=
self
.
shrink_quad_along_width
(
right_quad
,
0.0
,
right_ratio
)
poly
[
0
]
=
left_quad_expand
[
0
]
poly
[
-
1
]
=
left_quad_expand
[
-
1
]
poly
[
point_num
//
2
-
1
]
=
right_quad_expand
[
1
]
poly
[
point_num
//
2
]
=
right_quad_expand
[
2
]
return
poly
def
restore_quad
(
self
,
tcl_map
,
tcl_map_thresh
,
tvo_map
):
"""Restore quad."""
xy_text
=
np
.
argwhere
(
tcl_map
[:,
:,
0
]
>
tcl_map_thresh
)
xy_text
=
xy_text
[:,
::
-
1
]
# (n, 2)
# Sort the text boxes via the y axis
xy_text
=
xy_text
[
np
.
argsort
(
xy_text
[:,
1
])]
scores
=
tcl_map
[
xy_text
[:,
1
],
xy_text
[:,
0
],
0
]
scores
=
scores
[:,
np
.
newaxis
]
# Restore
point_num
=
int
(
tvo_map
.
shape
[
-
1
]
/
2
)
assert
point_num
==
4
tvo_map
=
tvo_map
[
xy_text
[:,
1
],
xy_text
[:,
0
],
:]
xy_text_tile
=
np
.
tile
(
xy_text
,
(
1
,
point_num
))
# (n, point_num * 2)
quads
=
xy_text_tile
-
tvo_map
return
scores
,
quads
,
xy_text
def
quad_area
(
self
,
quad
):
"""
compute area of a quad.
"""
edge
=
[
(
quad
[
1
][
0
]
-
quad
[
0
][
0
])
*
(
quad
[
1
][
1
]
+
quad
[
0
][
1
]),
(
quad
[
2
][
0
]
-
quad
[
1
][
0
])
*
(
quad
[
2
][
1
]
+
quad
[
1
][
1
]),
(
quad
[
3
][
0
]
-
quad
[
2
][
0
])
*
(
quad
[
3
][
1
]
+
quad
[
2
][
1
]),
(
quad
[
0
][
0
]
-
quad
[
3
][
0
])
*
(
quad
[
0
][
1
]
+
quad
[
3
][
1
])
]
return
np
.
sum
(
edge
)
/
2.
def
nms
(
self
,
dets
):
if
self
.
is_python35
:
import
lanms
dets
=
lanms
.
merge_quadrangle_n9
(
dets
,
self
.
nms_thresh
)
else
:
dets
=
nms_locality
(
dets
,
self
.
nms_thresh
)
return
dets
def
cluster_by_quads_tco
(
self
,
tcl_map
,
tcl_map_thresh
,
quads
,
tco_map
):
"""
Cluster pixels in tcl_map based on quads.
"""
instance_count
=
quads
.
shape
[
0
]
+
1
# contain background
instance_label_map
=
np
.
zeros
(
tcl_map
.
shape
[:
2
],
dtype
=
np
.
int32
)
if
instance_count
==
1
:
return
instance_count
,
instance_label_map
# predict text center
xy_text
=
np
.
argwhere
(
tcl_map
[:,
:,
0
]
>
tcl_map_thresh
)
n
=
xy_text
.
shape
[
0
]
xy_text
=
xy_text
[:,
::
-
1
]
# (n, 2)
tco
=
tco_map
[
xy_text
[:,
1
],
xy_text
[:,
0
],
:]
# (n, 2)
pred_tc
=
xy_text
-
tco
# get gt text center
m
=
quads
.
shape
[
0
]
gt_tc
=
np
.
mean
(
quads
,
axis
=
1
)
# (m, 2)
pred_tc_tile
=
np
.
tile
(
pred_tc
[:,
np
.
newaxis
,
:],
(
1
,
m
,
1
))
# (n, m, 2)
gt_tc_tile
=
np
.
tile
(
gt_tc
[
np
.
newaxis
,
:,
:],
(
n
,
1
,
1
))
# (n, m, 2)
dist_mat
=
np
.
linalg
.
norm
(
pred_tc_tile
-
gt_tc_tile
,
axis
=
2
)
# (n, m)
xy_text_assign
=
np
.
argmin
(
dist_mat
,
axis
=
1
)
+
1
# (n,)
instance_label_map
[
xy_text
[:,
1
],
xy_text
[:,
0
]]
=
xy_text_assign
return
instance_count
,
instance_label_map
def
estimate_sample_pts_num
(
self
,
quad
,
xy_text
):
"""
Estimate sample points number.
"""
eh
=
(
np
.
linalg
.
norm
(
quad
[
0
]
-
quad
[
3
])
+
np
.
linalg
.
norm
(
quad
[
1
]
-
quad
[
2
]))
/
2.0
ew
=
(
np
.
linalg
.
norm
(
quad
[
0
]
-
quad
[
1
])
+
np
.
linalg
.
norm
(
quad
[
2
]
-
quad
[
3
]))
/
2.0
dense_sample_pts_num
=
max
(
2
,
int
(
ew
))
dense_xy_center_line
=
xy_text
[
np
.
linspace
(
0
,
xy_text
.
shape
[
0
]
-
1
,
dense_sample_pts_num
,
endpoint
=
True
,
dtype
=
np
.
float32
).
astype
(
np
.
int32
)]
dense_xy_center_line_diff
=
dense_xy_center_line
[
1
:]
-
dense_xy_center_line
[:
-
1
]
estimate_arc_len
=
np
.
sum
(
np
.
linalg
.
norm
(
dense_xy_center_line_diff
,
axis
=
1
))
sample_pts_num
=
max
(
2
,
int
(
estimate_arc_len
/
eh
))
return
sample_pts_num
def
detect_sast
(
self
,
tcl_map
,
tvo_map
,
tbo_map
,
tco_map
,
ratio_w
,
ratio_h
,
src_w
,
src_h
,
shrink_ratio_of_width
=
0.3
,
tcl_map_thresh
=
0.5
,
offset_expand
=
1.0
,
out_strid
=
4.0
):
"""
first resize the tcl_map, tvo_map and tbo_map to the input_size, then restore the polys
"""
# restore quad
scores
,
quads
,
xy_text
=
self
.
restore_quad
(
tcl_map
,
tcl_map_thresh
,
tvo_map
)
dets
=
np
.
hstack
((
quads
,
scores
)).
astype
(
np
.
float32
,
copy
=
False
)
dets
=
self
.
nms
(
dets
)
if
dets
.
shape
[
0
]
==
0
:
return
[]
quads
=
dets
[:,
:
-
1
].
reshape
(
-
1
,
4
,
2
)
# Compute quad area
quad_areas
=
[]
for
quad
in
quads
:
quad_areas
.
append
(
-
self
.
quad_area
(
quad
))
# instance segmentation
# instance_count, instance_label_map = cv2.connectedComponents(tcl_map.astype(np.uint8), connectivity=8)
instance_count
,
instance_label_map
=
self
.
cluster_by_quads_tco
(
tcl_map
,
tcl_map_thresh
,
quads
,
tco_map
)
# restore single poly with tcl instance.
poly_list
=
[]
for
instance_idx
in
range
(
1
,
instance_count
):
xy_text
=
np
.
argwhere
(
instance_label_map
==
instance_idx
)[:,
::
-
1
]
quad
=
quads
[
instance_idx
-
1
]
q_area
=
quad_areas
[
instance_idx
-
1
]
if
q_area
<
5
:
continue
#
len1
=
float
(
np
.
linalg
.
norm
(
quad
[
0
]
-
quad
[
1
]))
len2
=
float
(
np
.
linalg
.
norm
(
quad
[
1
]
-
quad
[
2
]))
min_len
=
min
(
len1
,
len2
)
if
min_len
<
3
:
continue
# filter small CC
if
xy_text
.
shape
[
0
]
<=
0
:
continue
# filter low confidence instance
xy_text_scores
=
tcl_map
[
xy_text
[:,
1
],
xy_text
[:,
0
],
0
]
if
np
.
sum
(
xy_text_scores
)
/
quad_areas
[
instance_idx
-
1
]
<
0.1
:
# if np.sum(xy_text_scores) / quad_areas[instance_idx - 1] < 0.05:
continue
# sort xy_text
left_center_pt
=
np
.
array
([[(
quad
[
0
,
0
]
+
quad
[
-
1
,
0
])
/
2.0
,
(
quad
[
0
,
1
]
+
quad
[
-
1
,
1
])
/
2.0
]])
# (1, 2)
right_center_pt
=
np
.
array
([[(
quad
[
1
,
0
]
+
quad
[
2
,
0
])
/
2.0
,
(
quad
[
1
,
1
]
+
quad
[
2
,
1
])
/
2.0
]])
# (1, 2)
proj_unit_vec
=
(
right_center_pt
-
left_center_pt
)
/
\
(
np
.
linalg
.
norm
(
right_center_pt
-
left_center_pt
)
+
1e-6
)
proj_value
=
np
.
sum
(
xy_text
*
proj_unit_vec
,
axis
=
1
)
xy_text
=
xy_text
[
np
.
argsort
(
proj_value
)]
# Sample pts in tcl map
if
self
.
sample_pts_num
==
0
:
sample_pts_num
=
self
.
estimate_sample_pts_num
(
quad
,
xy_text
)
else
:
sample_pts_num
=
self
.
sample_pts_num
xy_center_line
=
xy_text
[
np
.
linspace
(
0
,
xy_text
.
shape
[
0
]
-
1
,
sample_pts_num
,
endpoint
=
True
,
dtype
=
np
.
float32
).
astype
(
np
.
int32
)]
point_pair_list
=
[]
for
x
,
y
in
xy_center_line
:
# get corresponding offset
offset
=
tbo_map
[
y
,
x
,
:].
reshape
(
2
,
2
)
if
offset_expand
!=
1.0
:
offset_length
=
np
.
linalg
.
norm
(
offset
,
axis
=
1
,
keepdims
=
True
)
expand_length
=
np
.
clip
(
offset_length
*
(
offset_expand
-
1
),
a_min
=
0.5
,
a_max
=
3.0
)
offset_detal
=
offset
/
offset_length
*
expand_length
offset
=
offset
+
offset_detal
# original point
ori_yx
=
np
.
array
([
y
,
x
],
dtype
=
np
.
float32
)
point_pair
=
(
ori_yx
+
offset
)[:,
::
-
1
]
*
out_strid
/
np
.
array
([
ratio_w
,
ratio_h
]).
reshape
(
-
1
,
2
)
point_pair_list
.
append
(
point_pair
)
# ndarry: (x, 2), expand poly along width
detected_poly
=
self
.
point_pair2poly
(
point_pair_list
)
detected_poly
=
self
.
expand_poly_along_width
(
detected_poly
,
shrink_ratio_of_width
)
detected_poly
[:,
0
]
=
np
.
clip
(
detected_poly
[:,
0
],
a_min
=
0
,
a_max
=
src_w
)
detected_poly
[:,
1
]
=
np
.
clip
(
detected_poly
[:,
1
],
a_min
=
0
,
a_max
=
src_h
)
poly_list
.
append
(
detected_poly
)
return
poly_list
def
__call__
(
self
,
outs_dict
,
shape_list
):
score_list
=
outs_dict
[
'f_score'
]
border_list
=
outs_dict
[
'f_border'
]
tvo_list
=
outs_dict
[
'f_tvo'
]
tco_list
=
outs_dict
[
'f_tco'
]
img_num
=
len
(
shape_list
)
poly_lists
=
[]
for
ino
in
range
(
img_num
):
p_score
=
score_list
[
ino
].
transpose
((
1
,
2
,
0
)).
numpy
()
p_border
=
border_list
[
ino
].
transpose
((
1
,
2
,
0
)).
numpy
()
p_tvo
=
tvo_list
[
ino
].
transpose
((
1
,
2
,
0
)).
numpy
()
p_tco
=
tco_list
[
ino
].
transpose
((
1
,
2
,
0
)).
numpy
()
src_h
,
src_w
,
ratio_h
,
ratio_w
=
shape_list
[
ino
]
poly_list
=
self
.
detect_sast
(
p_score
,
p_tvo
,
p_border
,
p_tco
,
ratio_w
,
ratio_h
,
src_w
,
src_h
,
shrink_ratio_of_width
=
self
.
shrink_ratio_of_width
,
tcl_map_thresh
=
self
.
tcl_map_thresh
,
offset_expand
=
self
.
expand_scale
)
poly_lists
.
append
({
'points'
:
np
.
array
(
poly_list
)})
return
poly_lists
ppocr/utils/character.py
deleted
100755 → 0
View file @
8520dd1e
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
numpy
as
np
import
string
import
re
from
.check
import
check_config_params
import
sys
class
CharacterOps
(
object
):
""" Convert between text-label and text-index """
def
__init__
(
self
,
config
):
self
.
character_type
=
config
[
'character_type'
]
self
.
loss_type
=
config
[
'loss_type'
]
self
.
max_text_len
=
config
[
'max_text_length'
]
if
self
.
character_type
==
"en"
:
self
.
character_str
=
"0123456789abcdefghijklmnopqrstuvwxyz"
dict_character
=
list
(
self
.
character_str
)
elif
self
.
character_type
==
"ch"
:
character_dict_path
=
config
[
'character_dict_path'
]
add_space
=
False
if
'use_space_char'
in
config
:
add_space
=
config
[
'use_space_char'
]
self
.
character_str
=
""
with
open
(
character_dict_path
,
"rb"
)
as
fin
:
lines
=
fin
.
readlines
()
for
line
in
lines
:
line
=
line
.
decode
(
'utf-8'
).
strip
(
"
\n
"
).
strip
(
"
\r\n
"
)
self
.
character_str
+=
line
if
add_space
:
self
.
character_str
+=
" "
dict_character
=
list
(
self
.
character_str
)
elif
self
.
character_type
==
"en_sensitive"
:
# same with ASTER setting (use 94 char).
self
.
character_str
=
string
.
printable
[:
-
6
]
dict_character
=
list
(
self
.
character_str
)
else
:
self
.
character_str
=
None
assert
self
.
character_str
is
not
None
,
\
"Nonsupport type of the character: {}"
.
format
(
self
.
character_str
)
self
.
beg_str
=
"sos"
self
.
end_str
=
"eos"
if
self
.
loss_type
==
"attention"
:
dict_character
=
[
self
.
beg_str
,
self
.
end_str
]
+
dict_character
elif
self
.
loss_type
==
"srn"
:
dict_character
=
dict_character
+
[
self
.
beg_str
,
self
.
end_str
]
self
.
dict
=
{}
for
i
,
char
in
enumerate
(
dict_character
):
self
.
dict
[
char
]
=
i
self
.
character
=
dict_character
def
encode
(
self
,
text
):
"""convert text-label into text-index.
input:
text: text labels of each image. [batch_size]
output:
text: concatenated text index for CTCLoss.
[sum(text_lengths)] = [text_index_0 + text_index_1 + ... + text_index_(n - 1)]
length: length of each text. [batch_size]
"""
if
self
.
character_type
==
"en"
:
text
=
text
.
lower
()
text_list
=
[]
for
char
in
text
:
if
char
not
in
self
.
dict
:
continue
text_list
.
append
(
self
.
dict
[
char
])
text
=
np
.
array
(
text_list
)
return
text
def
decode
(
self
,
text_index
,
is_remove_duplicate
=
False
):
""" convert text-index into text-label. """
char_list
=
[]
char_num
=
self
.
get_char_num
()
if
self
.
loss_type
==
"attention"
:
beg_idx
=
self
.
get_beg_end_flag_idx
(
"beg"
)
end_idx
=
self
.
get_beg_end_flag_idx
(
"end"
)
ignored_tokens
=
[
beg_idx
,
end_idx
]
else
:
ignored_tokens
=
[
char_num
]
for
idx
in
range
(
len
(
text_index
)):
if
text_index
[
idx
]
in
ignored_tokens
:
continue
if
is_remove_duplicate
:
if
idx
>
0
and
text_index
[
idx
-
1
]
==
text_index
[
idx
]:
continue
char_list
.
append
(
self
.
character
[
int
(
text_index
[
idx
])])
text
=
''
.
join
(
char_list
)
return
text
def
get_char_num
(
self
):
return
len
(
self
.
character
)
def
get_beg_end_flag_idx
(
self
,
beg_or_end
):
if
self
.
loss_type
==
"attention"
:
if
beg_or_end
==
"beg"
:
idx
=
np
.
array
(
self
.
dict
[
self
.
beg_str
])
elif
beg_or_end
==
"end"
:
idx
=
np
.
array
(
self
.
dict
[
self
.
end_str
])
else
:
assert
False
,
"Unsupport type %s in get_beg_end_flag_idx"
\
%
beg_or_end
return
idx
else
:
err
=
"error in get_beg_end_flag_idx when using the loss %s"
\
%
(
self
.
loss_type
)
assert
False
,
err
def
cal_predicts_accuracy
(
char_ops
,
preds
,
preds_lod
,
labels
,
labels_lod
,
is_remove_duplicate
=
False
):
acc_num
=
0
img_num
=
0
for
ino
in
range
(
len
(
labels_lod
)
-
1
):
beg_no
=
preds_lod
[
ino
]
end_no
=
preds_lod
[
ino
+
1
]
preds_text
=
preds
[
beg_no
:
end_no
].
reshape
(
-
1
)
preds_text
=
char_ops
.
decode
(
preds_text
,
is_remove_duplicate
)
beg_no
=
labels_lod
[
ino
]
end_no
=
labels_lod
[
ino
+
1
]
labels_text
=
labels
[
beg_no
:
end_no
].
reshape
(
-
1
)
labels_text
=
char_ops
.
decode
(
labels_text
,
is_remove_duplicate
)
img_num
+=
1
if
preds_text
==
labels_text
:
acc_num
+=
1
acc
=
acc_num
*
1.0
/
img_num
return
acc
,
acc_num
,
img_num
def
cal_predicts_accuracy_srn
(
char_ops
,
preds
,
labels
,
max_text_len
,
is_debug
=
False
):
acc_num
=
0
img_num
=
0
char_num
=
char_ops
.
get_char_num
()
total_len
=
preds
.
shape
[
0
]
img_num
=
int
(
total_len
/
max_text_len
)
for
i
in
range
(
img_num
):
cur_label
=
[]
cur_pred
=
[]
for
j
in
range
(
max_text_len
):
if
labels
[
j
+
i
*
max_text_len
]
!=
int
(
char_num
-
1
):
#0
cur_label
.
append
(
labels
[
j
+
i
*
max_text_len
][
0
])
else
:
break
for
j
in
range
(
max_text_len
+
1
):
if
j
<
len
(
cur_label
)
and
preds
[
j
+
i
*
max_text_len
][
0
]
!=
cur_label
[
j
]:
break
elif
j
==
len
(
cur_label
)
and
j
==
max_text_len
:
acc_num
+=
1
break
elif
j
==
len
(
cur_label
)
and
preds
[
j
+
i
*
max_text_len
][
0
]
==
int
(
char_num
-
1
):
acc_num
+=
1
break
acc
=
acc_num
*
1.0
/
img_num
return
acc
,
acc_num
,
img_num
def
convert_rec_attention_infer_res
(
preds
):
img_num
=
preds
.
shape
[
0
]
target_lod
=
[
0
]
convert_ids
=
[]
for
ino
in
range
(
img_num
):
end_pos
=
np
.
where
(
preds
[
ino
,
:]
==
1
)[
0
]
if
len
(
end_pos
)
<=
1
:
text_list
=
preds
[
ino
,
1
:]
else
:
text_list
=
preds
[
ino
,
1
:
end_pos
[
1
]]
target_lod
.
append
(
target_lod
[
ino
]
+
len
(
text_list
))
convert_ids
=
convert_ids
+
list
(
text_list
)
convert_ids
=
np
.
array
(
convert_ids
)
convert_ids
=
convert_ids
.
reshape
((
-
1
,
1
))
return
convert_ids
,
target_lod
def
convert_rec_label_to_lod
(
ori_labels
):
img_num
=
len
(
ori_labels
)
target_lod
=
[
0
]
convert_ids
=
[]
for
ino
in
range
(
img_num
):
target_lod
.
append
(
target_lod
[
ino
]
+
len
(
ori_labels
[
ino
]))
convert_ids
=
convert_ids
+
list
(
ori_labels
[
ino
])
convert_ids
=
np
.
array
(
convert_ids
)
convert_ids
=
convert_ids
.
reshape
((
-
1
,
1
))
return
convert_ids
,
target_lod
ppocr/utils/check.py
deleted
100755 → 0
View file @
8520dd1e
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
from
__future__
import
unicode_literals
import
sys
import
logging
logger
=
logging
.
getLogger
(
__name__
)
def
check_config_params
(
config
,
config_name
,
params
):
for
param
in
params
:
if
param
not
in
config
:
err
=
"param %s didn't find in %s!"
%
(
param
,
config_name
)
assert
False
,
err
return
ppocr/utils/dict/en_dict.txt
0 → 100644
View file @
631fd9fd
0
1
2
3
4
5
6
7
8
9
a
b
c
d
e
f
g
h
i
j
k
l
m
n
o
p
q
r
s
t
u
v
w
x
y
z
A
B
C
D
E
F
G
H
I
J
K
L
M
N
O
P
Q
R
S
T
U
V
W
X
Y
Z
ppocr/utils/dict/french_dict.txt
View file @
631fd9fd
...
...
@@ -132,4 +132,5 @@ j
³
Å
$
#
\ No newline at end of file
#
ppocr/utils/dict/german_dict.txt
View file @
631fd9fd
...
...
@@ -123,4 +123,5 @@ z
â
å
æ
é
\ No newline at end of file
é
ppocr/utils/dict/japan_dict.txt
View file @
631fd9fd
...
...
@@ -4395,4 +4395,5 @@ z
y
z
~
・
\ No newline at end of file
・
ppocr/utils/dict/korean_dict.txt
View file @
631fd9fd
...
...
@@ -179,7 +179,7 @@ z
с
т
я
’
“
”
...
...
@@ -3684,4 +3684,5 @@ z
立
茶
切
宅
\ No newline at end of file
宅
ppocr/utils/save_load.py
View file @
631fd9fd
...
...
@@ -55,8 +55,8 @@ def load_dygraph_pretrain(model, logger, path=None, load_static_weights=False):
weight_name
=
weight_name
.
replace
(
'binarize'
,
''
).
replace
(
'thresh'
,
''
)
# for DB
if
weight_name
in
pre_state_dict
.
keys
():
logger
.
info
(
'Load weight: {}, shape: {}'
.
format
(
weight_name
,
pre_state_dict
[
weight_name
].
shape
))
#
logger.info('Load weight: {}, shape: {}'.format(
#
weight_name, pre_state_dict[weight_name].shape))
if
'encoder_rnn'
in
key
:
# delete axis which is 1
pre_state_dict
[
weight_name
]
=
pre_state_dict
[
...
...
setup.py
View file @
631fd9fd
...
...
@@ -32,7 +32,7 @@ setup(
package_dir
=
{
'paddleocr'
:
''
},
include_package_data
=
True
,
entry_points
=
{
"console_scripts"
:
[
"paddleocr= paddleocr.paddleocr:main"
]},
version
=
'
0
.0
.3
'
,
version
=
'
2
.0'
,
install_requires
=
requirements
,
license
=
'Apache License 2.0'
,
description
=
'Awesome OCR toolkits based on PaddlePaddle (8.6M ultra-lightweight pre-trained model, support training and deployment among server, mobile, embeded and IoT devices'
,
...
...
tools/export_model.py
View file @
631fd9fd
...
...
@@ -28,37 +28,17 @@ from ppocr.modeling.architectures import build_model
from
ppocr.postprocess
import
build_post_process
from
ppocr.utils.save_load
import
init_model
from
ppocr.utils.logging
import
get_logger
from
tools.program
import
load_config
def parse_args():
    """Parse the export-model command line: config file path and output dir."""
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        "-c", "--config", help="configuration file to use")
    arg_parser.add_argument(
        "-o", "--output_path", type=str, default='./output/infer/')
    parsed = arg_parser.parse_args()
    return parsed
class Model(paddle.nn.Layer):
    """Wrapper that pins a static input spec onto a trained dygraph model so
    it can be converted and exported with paddle.jit."""

    def __init__(self, model):
        super(Model, self).__init__()
        # Trained model whose forward pass is re-exposed below.
        self.pre_model = model

    # Please modify the 'shape' according to actual needs
    @to_static(input_spec=[
        paddle.static.InputSpec(
            shape=[None, 3, 640, 640], dtype='float32')
    ])
    def forward(self, inputs):
        return self.pre_model(inputs)
from
tools.program
import
load_config
,
merge_config
,
ArgsParser
def
main
():
FLAGS
=
parse_args
()
FLAGS
=
ArgsParser
().
parse_args
()
config
=
load_config
(
FLAGS
.
config
)
merge_config
(
FLAGS
.
opt
)
logger
=
get_logger
()
print
(
config
)
# build post process
post_process_class
=
build_post_process
(
config
[
'PostProcess'
],
config
[
'Global'
])
...
...
@@ -71,9 +51,15 @@ def main():
init_model
(
config
,
model
,
logger
)
model
.
eval
()
model
=
Model
(
model
)
save_path
=
'{}/{}'
.
format
(
FLAGS
.
output_path
,
config
[
'Architecture'
][
'model_type'
])
save_path
=
'{}/inference'
.
format
(
config
[
'Global'
][
'save_inference_dir'
])
infer_shape
=
[
3
,
32
,
100
]
if
config
[
'Architecture'
][
'model_type'
]
!=
"det"
else
[
3
,
640
,
640
]
model
=
to_static
(
model
,
input_spec
=
[
paddle
.
static
.
InputSpec
(
shape
=
[
None
]
+
infer_shape
,
dtype
=
'float32'
)
])
paddle
.
jit
.
save
(
model
,
save_path
)
logger
.
info
(
'inference model is saved to {}'
.
format
(
save_path
))
...
...
tools/infer/predict_system.py
View file @
631fd9fd
...
...
@@ -13,6 +13,7 @@
# limitations under the License.
import
os
import
sys
__dir__
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
sys
.
path
.
append
(
__dir__
)
sys
.
path
.
append
(
os
.
path
.
abspath
(
os
.
path
.
join
(
__dir__
,
'../..'
)))
...
...
@@ -30,12 +31,15 @@ from ppocr.utils.utility import get_image_file_list, check_and_read_gif
from
ppocr.utils.logging
import
get_logger
from
tools.infer.utility
import
draw_ocr_box_txt
logger
=
get_logger
()
class
TextSystem
(
object
):
def
__init__
(
self
,
args
):
self
.
text_detector
=
predict_det
.
TextDetector
(
args
)
self
.
text_recognizer
=
predict_rec
.
TextRecognizer
(
args
)
self
.
use_angle_cls
=
args
.
use_angle_cls
self
.
drop_score
=
args
.
drop_score
if
self
.
use_angle_cls
:
self
.
text_classifier
=
predict_cls
.
TextClassifier
(
args
)
...
...
@@ -81,7 +85,8 @@ class TextSystem(object):
def
__call__
(
self
,
img
):
ori_im
=
img
.
copy
()
dt_boxes
,
elapse
=
self
.
text_detector
(
img
)
logger
.
info
(
"dt_boxes num : {}, elapse : {}"
.
format
(
len
(
dt_boxes
),
elapse
))
logger
.
info
(
"dt_boxes num : {}, elapse : {}"
.
format
(
len
(
dt_boxes
),
elapse
))
if
dt_boxes
is
None
:
return
None
,
None
img_crop_list
=
[]
...
...
@@ -99,9 +104,16 @@ class TextSystem(object):
len
(
img_crop_list
),
elapse
))
rec_res
,
elapse
=
self
.
text_recognizer
(
img_crop_list
)
logger
.
info
(
"rec_res num : {}, elapse : {}"
.
format
(
len
(
rec_res
),
elapse
))
logger
.
info
(
"rec_res num : {}, elapse : {}"
.
format
(
len
(
rec_res
),
elapse
))
# self.print_draw_crop_rec_res(img_crop_list, rec_res)
return
dt_boxes
,
rec_res
filter_boxes
,
filter_rec_res
=
[],
[]
for
box
,
rec_reuslt
in
zip
(
dt_boxes
,
rec_res
):
text
,
score
=
rec_reuslt
if
score
>=
self
.
drop_score
:
filter_boxes
.
append
(
box
)
filter_rec_res
.
append
(
rec_reuslt
)
return
filter_boxes
,
filter_rec_res
def
sorted_boxes
(
dt_boxes
):
...
...
@@ -117,8 +129,8 @@ def sorted_boxes(dt_boxes):
_boxes
=
list
(
sorted_boxes
)
for
i
in
range
(
num_boxes
-
1
):
if
abs
(
_boxes
[
i
+
1
][
0
][
1
]
-
_boxes
[
i
][
0
][
1
])
<
10
and
\
(
_boxes
[
i
+
1
][
0
][
0
]
<
_boxes
[
i
][
0
][
0
]):
if
abs
(
_boxes
[
i
+
1
][
0
][
1
]
-
_boxes
[
i
][
0
][
1
])
<
10
and
\
(
_boxes
[
i
+
1
][
0
][
0
]
<
_boxes
[
i
][
0
][
0
]):
tmp
=
_boxes
[
i
]
_boxes
[
i
]
=
_boxes
[
i
+
1
]
_boxes
[
i
+
1
]
=
tmp
...
...
@@ -143,12 +155,8 @@ def main(args):
elapse
=
time
.
time
()
-
starttime
logger
.
info
(
"Predict time of %s: %.3fs"
%
(
image_file
,
elapse
))
dt_num
=
len
(
dt_boxes
)
for
dno
in
range
(
dt_num
):
text
,
score
=
rec_res
[
dno
]
if
score
>=
drop_score
:
text_str
=
"%s, %.3f"
%
(
text
,
score
)
logger
.
info
(
text_str
)
for
text
,
score
in
rec_res
:
logger
.
info
(
"{}, {:.3f}"
.
format
(
text
,
score
))
if
is_visualize
:
image
=
Image
.
fromarray
(
cv2
.
cvtColor
(
img
,
cv2
.
COLOR_BGR2RGB
))
...
...
@@ -174,5 +182,4 @@ def main(args):
if
__name__
==
"__main__"
:
logger
=
get_logger
()
main
(
utility
.
parse_args
())
main
(
utility
.
parse_args
())
\ No newline at end of file
tools/infer/utility.py
View file @
631fd9fd
...
...
@@ -100,8 +100,8 @@ def create_predictor(args, mode, logger):
if
model_dir
is
None
:
logger
.
info
(
"not find {} model file path {}"
.
format
(
mode
,
model_dir
))
sys
.
exit
(
0
)
model_file_path
=
model_dir
+
"/model"
params_file_path
=
model_dir
+
"/params"
model_file_path
=
model_dir
+
"/
inference.pd
model"
params_file_path
=
model_dir
+
"/
inference.pdi
params"
if
not
os
.
path
.
exists
(
model_file_path
):
logger
.
info
(
"not find model file path {}"
.
format
(
model_file_path
))
sys
.
exit
(
0
)
...
...
@@ -230,10 +230,10 @@ def draw_ocr_box_txt(image,
box
[
2
][
1
],
box
[
3
][
0
],
box
[
3
][
1
]
],
outline
=
color
)
box_height
=
math
.
sqrt
((
box
[
0
][
0
]
-
box
[
3
][
0
])
**
2
+
(
box
[
0
][
1
]
-
box
[
3
][
1
])
**
2
)
box_width
=
math
.
sqrt
((
box
[
0
][
0
]
-
box
[
1
][
0
])
**
2
+
(
box
[
0
][
1
]
-
box
[
1
][
1
])
**
2
)
box_height
=
math
.
sqrt
((
box
[
0
][
0
]
-
box
[
3
][
0
])
**
2
+
(
box
[
0
][
1
]
-
box
[
3
][
1
])
**
2
)
box_width
=
math
.
sqrt
((
box
[
0
][
0
]
-
box
[
1
][
0
])
**
2
+
(
box
[
0
][
1
]
-
box
[
1
][
1
])
**
2
)
if
box_height
>
2
*
box_width
:
font_size
=
max
(
int
(
box_width
*
0.9
),
10
)
font
=
ImageFont
.
truetype
(
font_path
,
font_size
,
encoding
=
"utf-8"
)
...
...
@@ -260,7 +260,6 @@ def str_count(s):
Count the number of Chinese characters,
a single English character and a single number
equal to half the length of Chinese characters.
args:
s(string): the input of string
return(int):
...
...
@@ -295,7 +294,6 @@ def text_visual(texts,
img_w(int): the width of blank img
font_path: the path of font which is used to draw text
return(array):
"""
if
scores
is
not
None
:
assert
len
(
texts
)
==
len
(
...
...
Prev
1
2
3
4
5
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment