wangsen / paddle_dbnet · Commits

Commit 5c417939 (Unverified)
Merge branch 'dygraph' into android_demo
Authored Apr 14, 2021 by zhoujun; committed via GitHub on Apr 14, 2021.
Parents: 224ed7ca, f2f52cac

Changes: 35 files in total; this page shows 15 changed files with 862 additions and 220 deletions (+862 / -220):
  doc/imgs_results/multi_lang/french_0.jpg              +0     -0
  doc/imgs_results/multi_lang/japan_2.jpg               +0     -0
  paddleocr.py                                         +23    -14
  ppocr/data/imaug/label_ops.py                         +7     -9
  ppocr/data/pgnet_dataset.py                           +5    -10
  ppocr/metrics/e2e_metric.py                          +10    -39
  ppocr/postprocess/pg_postprocess.py                  +10   -113
  ppocr/utils/e2e_metric/Deteval.py                     +6    -27
  ppocr/utils/e2e_utils/extract_textpoint_fast.py     +457     -0
  ppocr/utils/e2e_utils/extract_textpoint_slow.py      +65     -7
  ppocr/utils/e2e_utils/pgnet_pp_utils.py             +181     -0
  ppocr/utils/en_dict.txt                              +95     -0
  setup.py                                              +1     -1
  tools/infer/predict_e2e.py                            +1     -0
  tools/infer/utility.py                                +1     -0
doc/imgs_results/multi_lang/french_0.jpg (new file, mode 100644): binary image, 249 KB, not shown.

doc/imgs_results/multi_lang/japan_2.jpg (new file, mode 100644): binary image, 461 KB, not shown.
paddleocr.py:

@@ -34,8 +34,12 @@ from ppocr.utils.utility import check_and_read_gif, get_image_file_list
 __all__ = ['PaddleOCR']
 
 model_urls = {
-    'det':
-    'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar',
+    'det': {
+        'ch':
+        'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar',
+        'en':
+        'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/en_ppocr_mobile_v2.0_det_infer.tar'
+    },
     'rec': {
         'ch': {
             'url':
@@ -45,7 +49,7 @@ model_urls = {
         'en': {
             'url':
             'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/en_number_mobile_v2.0_rec_infer.tar',
-            'dict_path': './ppocr/utils/dict/en_dict.txt'
+            'dict_path': './ppocr/utils/en_dict.txt'
         },
         'french': {
             'url':
@@ -113,7 +117,7 @@ model_urls = {
 }
 SUPPORT_DET_MODEL = ['DB']
-VERSION = 2.0
+VERSION = 2.1
 SUPPORT_REC_MODEL = ['CRNN']
 BASE_DIR = os.path.expanduser("~/.paddleocr/")
@@ -199,7 +203,7 @@ def parse_args(mMain=True, add_help=True):
         parser.add_argument("--rec_model_dir", type=str, default=None)
         parser.add_argument("--rec_image_shape", type=str, default="3, 32, 320")
         parser.add_argument("--rec_char_type", type=str, default='ch')
-        parser.add_argument("--rec_batch_num", type=int, default=30)
+        parser.add_argument("--rec_batch_num", type=int, default=6)
         parser.add_argument("--max_text_length", type=int, default=25)
         parser.add_argument("--rec_char_dict_path", type=str, default=None)
         parser.add_argument("--use_space_char", type=bool, default=True)
@@ -209,7 +213,7 @@ def parse_args(mMain=True, add_help=True):
         parser.add_argument("--cls_model_dir", type=str, default=None)
         parser.add_argument("--cls_image_shape", type=str, default="3, 48, 192")
         parser.add_argument("--label_list", type=list, default=['0', '180'])
-        parser.add_argument("--cls_batch_num", type=int, default=30)
+        parser.add_argument("--cls_batch_num", type=int, default=6)
         parser.add_argument("--cls_thresh", type=float, default=0.9)
         parser.add_argument("--enable_mkldnn", type=bool, default=False)
@@ -243,7 +247,7 @@ def parse_args(mMain=True, add_help=True):
             rec_model_dir=None,
             rec_image_shape="3, 32, 320",
             rec_char_type='ch',
-            rec_batch_num=30,
+            rec_batch_num=6,
             max_text_length=25,
             rec_char_dict_path=None,
             use_space_char=True,
@@ -251,7 +255,7 @@ def parse_args(mMain=True, add_help=True):
             cls_model_dir=None,
             cls_image_shape="3, 48, 192",
             label_list=['0', '180'],
-            cls_batch_num=30,
+            cls_batch_num=6,
             cls_thresh=0.9,
             enable_mkldnn=False,
             use_zero_copy_run=False,
@@ -274,10 +278,10 @@ class PaddleOCR(predict_system.TextSystem):
         self.use_angle_cls = postprocess_params.use_angle_cls
         lang = postprocess_params.lang
         latin_lang = [
-            'af', 'az', 'bs', 'cs', 'cy', 'da', 'de', 'en', 'es', 'et', 'fr',
-            'ga', 'hr', 'hu', 'id', 'is', 'it', 'ku', 'la', 'lt', 'lv', 'mi',
-            'ms', 'mt', 'nl', 'no', 'oc', 'pi', 'pl', 'pt', 'ro', 'rs_latin',
-            'sk', 'sl', 'sq', 'sv', 'sw', 'tl', 'tr', 'uz', 'vi'
+            'af', 'az', 'bs', 'cs', 'cy', 'da', 'de', 'es', 'et', 'fr', 'ga',
+            'hr', 'hu', 'id', 'is', 'it', 'ku', 'la', 'lt', 'lv', 'mi', 'ms',
+            'mt', 'nl', 'no', 'oc', 'pi', 'pl', 'pt', 'ro', 'rs_latin', 'sk',
+            'sl', 'sq', 'sv', 'sw', 'tl', 'tr', 'uz', 'vi'
         ]
         arabic_lang = ['ar', 'fa', 'ug', 'ur']
         cyrillic_lang = [
@@ -299,6 +303,10 @@ class PaddleOCR(predict_system.TextSystem):
         assert lang in model_urls[
             'rec'], 'param lang must in {}, but got {}'.format(
                 model_urls['rec'].keys(), lang)
+        if lang == "ch":
+            det_lang = "ch"
+        else:
+            det_lang = "en"
         use_inner_dict = False
         if postprocess_params.rec_char_dict_path is None:
             use_inner_dict = True
@@ -308,7 +316,7 @@ class PaddleOCR(predict_system.TextSystem):
         # init model dir
         if postprocess_params.det_model_dir is None:
             postprocess_params.det_model_dir = os.path.join(
-                BASE_DIR, '{}/det'.format(VERSION))
+                BASE_DIR, '{}/det/{}'.format(VERSION, det_lang))
         if postprocess_params.rec_model_dir is None:
             postprocess_params.rec_model_dir = os.path.join(
                 BASE_DIR, '{}/rec/{}'.format(VERSION, lang))
@@ -317,7 +325,8 @@ class PaddleOCR(predict_system.TextSystem):
                 BASE_DIR, '{}/cls'.format(VERSION))
         print(postprocess_params)
         # download model
-        maybe_download(postprocess_params.det_model_dir, model_urls['det'])
+        maybe_download(postprocess_params.det_model_dir,
+                       model_urls['det'][det_lang])
         maybe_download(postprocess_params.rec_model_dir,
                        model_urls['rec'][lang]['url'])
         maybe_download(postprocess_params.cls_model_dir, model_urls['cls'])
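For context, a minimal sketch (not part of the commit) of how the reworked lookup behaves: `model_urls['det']` is now keyed by language, and every non-Chinese language shares the English detector, so cached detection models land under a per-language directory.

```python
# Sketch only: mirrors the det_lang branch and path format from the diff above.
import os

BASE_DIR = os.path.expanduser("~/.paddleocr/")
VERSION = 2.1

def det_cache_dir(lang):
    # Chinese keeps its own detection model; all other languages
    # fall back to the shared English detector.
    det_lang = "ch" if lang == "ch" else "en"
    return os.path.join(BASE_DIR, '{}/det/{}'.format(VERSION, det_lang))

print(det_cache_dir("french"))  # .../.paddleocr/2.1/det/en
print(det_cache_dir("ch"))      # .../.paddleocr/2.1/det/ch
```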
ppocr/data/imaug/label_ops.py:

@@ -200,18 +200,16 @@ class E2ELabelEncode(BaseRecLabelEncode):
         self.pad_num = len(self.dict)  # the length to pad
 
     def __call__(self, data):
-        text_label_index_list, temp_text = [], []
         texts = data['strs']
+        temp_texts = []
         for text in texts:
             text = text.lower()
-            temp_text = []
-            for c_ in text:
-                if c_ in self.dict:
-                    temp_text.append(self.dict[c_])
-            temp_text = temp_text + [self.pad_num] * (self.max_text_len -
-                                                      len(temp_text))
-            text_label_index_list.append(temp_text)
-        data['strs'] = np.array(text_label_index_list)
+            text = self.encode(text)
+            if text is None:
+                return None
+            text = text + [self.pad_num] * (self.max_text_len - len(text))
+            temp_texts.append(text)
+        data['strs'] = np.array(temp_texts)
         return data
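To illustrate the new padding path, here is a standalone sketch with a toy dictionary (not the class itself): `encode` maps characters to indices and drops unknown ones, and the result is right-padded with `pad_num` (the dictionary length) up to `max_text_len`.

```python
# Toy stand-ins for self.dict / self.max_text_len from the diff above.
char_dict = {c: i for i, c in enumerate("abcdefghijklmnopqrstuvwxyz")}
pad_num = len(char_dict)      # 26: one past the largest valid index
max_text_len = 8

def encode(text):
    # Drop characters missing from the dictionary, like the base encoder does.
    ids = [char_dict[c] for c in text if c in char_dict]
    return ids if ids else None

label = encode("hi!")                                  # [7, 8] ('!' dropped)
label = label + [pad_num] * (max_text_len - len(label))
print(label)                                           # [7, 8, 26, 26, 26, 26, 26, 26]
```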
ppocr/data/pgnet_dataset.py:

@@ -64,9 +64,6 @@ class PGDataSet(Dataset):
             for line in f.readlines():
                 poly_str, txt = line.strip().split('\t')
                 poly = list(map(float, poly_str.split(',')))
-                if self.mode.lower() == "eval":
-                    while len(poly) < 100:
-                        poly.append(-1)
                 text_polys.append(
                     np.array(poly, dtype=np.float32).reshape(-1, 2))
@@ -139,23 +136,21 @@
         try:
             if self.data_format == 'icdar':
                 im_path = os.path.join(data_path, 'rgb', data_line)
-                if self.mode.lower() == "eval":
-                    poly_path = os.path.join(data_path, 'poly_gt',
-                                             data_line.split('.')[0] + '.txt')
-                else:
-                    poly_path = os.path.join(data_path, 'poly',
-                                             data_line.split('.')[0] + '.txt')
+                poly_path = os.path.join(data_path, 'poly',
+                                         data_line.split('.')[0] + '.txt')
                 text_polys, text_tags, text_strs = self.extract_polys(
                     poly_path)
             else:
                 image_dir = os.path.join(os.path.dirname(data_path), 'image')
                 im_path, text_polys, text_tags, text_strs = self.extract_info_textnet(
                     data_line, image_dir)
+            img_id = int(data_line.split(".")[0][3:])
             data = {
                 'img_path': im_path,
                 'polys': text_polys,
                 'tags': text_tags,
-                'strs': text_strs
+                'strs': text_strs,
+                'img_id': img_id
             }
             with open(data['img_path'], 'rb') as f:
                 img = f.read()
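The new `img_id` line assumes ICDAR-style file names of the form `img<ID>.<ext>` (as in the Total-Text set this loader targets); a quick illustration:

```python
# The 'img' prefix is stripped by the [3:] slice, the extension by split('.').
data_line = "img61.jpg"
img_id = int(data_line.split(".")[0][3:])
print(img_id)  # 61
```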
ppocr/metrics/e2e_metric.py:

@@ -19,58 +19,29 @@ from __future__ import print_function
 __all__ = ['E2EMetric']
 
 from ppocr.utils.e2e_metric.Deteval import get_socre, combine_results
-from ppocr.utils.e2e_utils.extract_textpoint import get_dict
+from ppocr.utils.e2e_utils.extract_textpoint_slow import get_dict
 
 
 class E2EMetric(object):
     def __init__(self,
+                 gt_mat_dir,
                  character_dict_path,
                  main_indicator='f_score_e2e',
                  **kwargs):
+        self.gt_mat_dir = gt_mat_dir
         self.label_list = get_dict(character_dict_path)
         self.max_index = len(self.label_list)
         self.main_indicator = main_indicator
         self.reset()
 
     def __call__(self, preds, batch, **kwargs):
-        temp_gt_polyons_batch = batch[2]
-        temp_gt_strs_batch = batch[3]
-        ignore_tags_batch = batch[4]
-        gt_polyons_batch = []
-        gt_strs_batch = []
-
-        temp_gt_polyons_batch = temp_gt_polyons_batch[0].tolist()
-        for temp_list in temp_gt_polyons_batch:
-            t = []
-            for index in temp_list:
-                if index[0] != -1 and index[1] != -1:
-                    t.append(index)
-            gt_polyons_batch.append(t)
-
-        temp_gt_strs_batch = temp_gt_strs_batch[0].tolist()
-        for temp_list in temp_gt_strs_batch:
-            t = ""
-            for index in temp_list:
-                if index < self.max_index:
-                    t += self.label_list[index]
-            gt_strs_batch.append(t)
-
-        for pred, gt_polyons, gt_strs, ignore_tags in zip(
-            [preds], [gt_polyons_batch], [gt_strs_batch], ignore_tags_batch):
-            # prepare gt
-            gt_info_list = [{
-                'points': gt_polyon,
-                'text': gt_str,
-                'ignore': ignore_tag
-            } for gt_polyon, gt_str, ignore_tag in zip(gt_polyons, gt_strs,
-                                                       ignore_tags)]
-            # prepare det
-            e2e_info_list = [{
-                'points': det_polyon,
-                'text': pred_str
-            } for det_polyon, pred_str in zip(pred['points'], pred['strs'])]
-            result = get_socre(gt_info_list, e2e_info_list)
-            self.results.append(result)
+        img_id = batch[5][0]
+        e2e_info_list = [{
+            'points': det_polyon,
+            'text': pred_str
+        } for det_polyon, pred_str in zip(preds['points'], preds['strs'])]
+        result = get_socre(self.gt_mat_dir, img_id, e2e_info_list)
+        self.results.append(result)
 
     def get_metric(self):
         metircs = combine_results(self.results)
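The metric now builds one record per predicted polygon and defers ground-truth loading to `get_socre`; a sketch of the record shape, with made-up predictions:

```python
# Made-up predictions; keys match what E2EMetric.__call__ now produces.
preds = {
    'points': [[[10, 20], [90, 18], [92, 40], [11, 44]]],
    'strs': ['hello'],
}
e2e_info_list = [{
    'points': det_polyon,
    'text': pred_str
} for det_polyon, pred_str in zip(preds['points'], preds['strs'])]
print(e2e_info_list)  # [{'points': [...], 'text': 'hello'}]
```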
ppocr/postprocess/pg_postprocess.py:

@@ -22,10 +22,7 @@ import sys
 __dir__ = os.path.dirname(__file__)
 sys.path.append(__dir__)
 sys.path.append(os.path.join(__dir__, '..'))
-from ppocr.utils.e2e_utils.extract_textpoint import *
-from ppocr.utils.e2e_utils.visual import *
-import paddle
+from ppocr.utils.e2e_utils.pgnet_pp_utils import PGNet_PostProcess
 
 
 class PGPostProcess(object):
@@ -33,10 +30,12 @@ class PGPostProcess(object):
     The post process for PGNet.
     """
 
-    def __init__(self, character_dict_path, valid_set, score_thresh, **kwargs):
-        self.Lexicon_Table = get_dict(character_dict_path)
+    def __init__(self, character_dict_path, valid_set, score_thresh, mode,
+                 **kwargs):
+        self.character_dict_path = character_dict_path
         self.valid_set = valid_set
         self.score_thresh = score_thresh
+        self.mode = mode
 
         # c++ la-nms is faster, but only support python 3.5
         self.is_python35 = False
@@ -44,112 +43,10 @@
             self.is_python35 = True
 
     def __call__(self, outs_dict, shape_list):
-        p_score = outs_dict['f_score']
-        p_border = outs_dict['f_border']
-        p_char = outs_dict['f_char']
-        p_direction = outs_dict['f_direction']
-        if isinstance(p_score, paddle.Tensor):
-            p_score = p_score[0].numpy()
-            p_border = p_border[0].numpy()
-            p_direction = p_direction[0].numpy()
-            p_char = p_char[0].numpy()
-        else:
-            p_score = p_score[0]
-            p_border = p_border[0]
-            p_direction = p_direction[0]
-            p_char = p_char[0]
-        src_h, src_w, ratio_h, ratio_w = shape_list[0]
-        is_curved = self.valid_set == "totaltext"
-        instance_yxs_list = generate_pivot_list(
-            p_score,
-            p_char,
-            p_direction,
-            score_thresh=self.score_thresh,
-            is_backbone=True,
-            is_curved=is_curved)
-        p_char = paddle.to_tensor(np.expand_dims(p_char, axis=0))
-        char_seq_idx_set = []
-        for i in range(len(instance_yxs_list)):
-            gather_info_lod = paddle.to_tensor(instance_yxs_list[i])
-            f_char_map = paddle.transpose(p_char, [0, 2, 3, 1])
-            feature_seq = paddle.gather_nd(f_char_map, gather_info_lod)
-            feature_seq = np.expand_dims(feature_seq.numpy(), axis=0)
-            feature_len = [len(feature_seq[0])]
-            featyre_seq = paddle.to_tensor(feature_seq)
-            feature_len = np.array([feature_len]).astype(np.int64)
-            length = paddle.to_tensor(feature_len)
-            seq_pred = paddle.fluid.layers.ctc_greedy_decoder(
-                input=featyre_seq, blank=36, input_length=length)
-            seq_pred_str = seq_pred[0].numpy().tolist()[0]
-            seq_len = seq_pred[1].numpy()[0][0]
-            temp_t = []
-            for c in seq_pred_str[:seq_len]:
-                temp_t.append(c)
-            char_seq_idx_set.append(temp_t)
-        seq_strs = []
-        for char_idx_set in char_seq_idx_set:
-            pr_str = ''.join([self.Lexicon_Table[pos] for pos in char_idx_set])
-            seq_strs.append(pr_str)
-        poly_list = []
-        keep_str_list = []
-        all_point_list = []
-        all_point_pair_list = []
-        for yx_center_line, keep_str in zip(instance_yxs_list, seq_strs):
-            if len(yx_center_line) == 1:
-                yx_center_line.append(yx_center_line[-1])
-            offset_expand = 1.0
-            if self.valid_set == 'totaltext':
-                offset_expand = 1.2
-            point_pair_list = []
-            for batch_id, y, x in yx_center_line:
-                offset = p_border[:, y, x].reshape(2, 2)
-                if offset_expand != 1.0:
-                    offset_length = np.linalg.norm(
-                        offset, axis=1, keepdims=True)
-                    expand_length = np.clip(
-                        offset_length * (offset_expand - 1),
-                        a_min=0.5,
-                        a_max=3.0)
-                    offset_detal = offset / offset_length * expand_length
-                    offset = offset + offset_detal
-                ori_yx = np.array([y, x], dtype=np.float32)
-                point_pair = (ori_yx + offset)[:, ::-1] * 4.0 / np.array(
-                    [ratio_w, ratio_h]).reshape(-1, 2)
-                point_pair_list.append(point_pair)
-                all_point_list.append([
-                    int(round(x * 4.0 / ratio_w)),
-                    int(round(y * 4.0 / ratio_h))
-                ])
-                all_point_pair_list.append(point_pair.round().astype(np.int32)
-                                           .tolist())
-            detected_poly, pair_length_info = point_pair2poly(point_pair_list)
-            detected_poly = expand_poly_along_width(
-                detected_poly, shrink_ratio_of_width=0.2)
-            detected_poly[:, 0] = np.clip(
-                detected_poly[:, 0], a_min=0, a_max=src_w)
-            detected_poly[:, 1] = np.clip(
-                detected_poly[:, 1], a_min=0, a_max=src_h)
-            if len(keep_str) < 2:
-                continue
-            keep_str_list.append(keep_str)
-            if self.valid_set == 'partvgg':
-                middle_point = len(detected_poly) // 2
-                detected_poly = detected_poly[
-                    [0, middle_point - 1, middle_point, -1], :]
-                poly_list.append(detected_poly)
-            elif self.valid_set == 'totaltext':
-                poly_list.append(detected_poly)
-            else:
-                print('--> Not supported format.')
-                exit(-1)
-        data = {
-            'points': poly_list,
-            'strs': keep_str_list,
-        }
+        post = PGNet_PostProcess(self.character_dict_path, self.valid_set,
+                                 self.score_thresh, outs_dict, shape_list)
+        if self.mode == 'fast':
+            data = post.pg_postprocess_fast()
+        else:
+            data = post.pg_postprocess_slow()
         return data
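A sketch of the resulting call pattern (argument values are placeholders): the new `mode` argument selects between the two delegated implementations.

```python
# Placeholder values; PGPostProcess is normally built from the YAML config.
post_process = PGPostProcess(
    character_dict_path='./ppocr/utils/ic15_dict.txt',
    valid_set='totaltext',
    score_thresh=0.5,
    mode='fast')   # new argument: 'fast' or anything else for the slow path
# data = post_process(outs_dict, shape_list)
# 'fast' -> PGNet_PostProcess.pg_postprocess_fast(), else -> pg_postprocess_slow()
```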
ppocr/utils/e2e_metric/Deteval.py:

@@ -13,10 +13,11 @@
 # limitations under the License.
 
 import numpy as np
+import scipy.io as io
 from ppocr.utils.e2e_metric.polygon_fast import iod, area_of_intersection, area
 
 
-def get_socre(gt_dict, pred_dict):
+def get_socre(gt_dir, img_id, pred_dict):
     allInputs = 1
 
     def input_reading_mod(pred_dict):
@@ -30,31 +31,9 @@ def get_socre(gt_dict, pred_dict):
             det.append([point, text])
         return det
 
-    def gt_reading_mod(gt_dict):
-        """This helper reads groundtruths from mat files"""
-        gt = []
-        n = len(gt_dict)
-        for i in range(n):
-            points = gt_dict[i]['points']
-            h = len(points)
-            text = gt_dict[i]['text']
-            xx = [
-                np.array(['x:'], dtype='<U2'), 0,
-                np.array(['y:'], dtype='<U2'), 0,
-                np.array(['#'], dtype='<U1'),
-                np.array(['#'], dtype='<U1')
-            ]
-            t_x, t_y = [], []
-            for j in range(h):
-                t_x.append(points[j][0])
-                t_y.append(points[j][1])
-            xx[1] = np.array([t_x], dtype='int16')
-            xx[3] = np.array([t_y], dtype='int16')
-            if text != "" and "#" not in text:
-                xx[4] = np.array([text], dtype='U{}'.format(len(text)))
-                xx[5] = np.array(['c'], dtype='<U1')
-            gt.append(xx)
+    def gt_reading_mod(gt_dir, gt_id):
+        gt = io.loadmat('%s/poly_gt_img%s.mat' % (gt_dir, gt_id))
+        gt = gt['polygt']
         return gt
 
     def detection_filtering(detections, groundtruths, threshold=0.5):
@@ -101,7 +80,7 @@ def get_socre(gt_dict, pred_dict):
                 input_id != 'Deteval_result.txt') and (
                     input_id != 'Deteval_result_curved.txt') \
                 and (input_id != 'Deteval_result_non_curved.txt'):
             detections = input_reading_mod(pred_dict)
-            groundtruths = gt_reading_mod(gt_dict)
+            groundtruths = gt_reading_mod(gt_dir, img_id).tolist()
             detections = detection_filtering(
                 detections,
                 groundtruths)  # filters detections overlapping with DC area
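For reference, a small sketch of what the rewritten `gt_reading_mod` loads (path and image id are placeholders): Total-Text ground truth ships as one MATLAB file per image, and its `polygt` array holds the x coordinates, y coordinates, transcription, and ignore flag that the removed in-Python builder used to assemble by hand.

```python
import scipy.io as io

gt_dir, img_id = '/path/to/total_text/gt', 61   # placeholders
gt = io.loadmat('%s/poly_gt_img%s.mat' % (gt_dir, img_id))
groundtruths = gt['polygt'].tolist()
# each row: [array(['x:']), xs, array(['y:']), ys, transcription, flag]
```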
ppocr/utils/e2e_utils/extract_textpoint_fast.py (new file, mode 100644):

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains various CTC decoders."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import cv2
import math
import numpy as np
from itertools import groupby
from skimage.morphology._skeletonize import thin


def get_dict(character_dict_path):
    character_str = ""
    with open(character_dict_path, "rb") as fin:
        lines = fin.readlines()
        for line in lines:
            line = line.decode('utf-8').strip("\n").strip("\r\n")
            character_str += line
        dict_character = list(character_str)
    return dict_character


def softmax(logits):
    """
    logits: N x d
    """
    max_value = np.max(logits, axis=1, keepdims=True)
    exp = np.exp(logits - max_value)
    exp_sum = np.sum(exp, axis=1, keepdims=True)
    dist = exp / exp_sum
    return dist


def get_keep_pos_idxs(labels, remove_blank=None):
    """
    Remove duplicate and get pos idxs of keep items.
    The value of keep_blank should be [None, 95].
    """
    duplicate_len_list = []
    keep_pos_idx_list = []
    keep_char_idx_list = []
    for k, v_ in groupby(labels):
        current_len = len(list(v_))
        if k != remove_blank:
            current_idx = int(sum(duplicate_len_list) + current_len // 2)
            keep_pos_idx_list.append(current_idx)
            keep_char_idx_list.append(k)
        duplicate_len_list.append(current_len)
    return keep_char_idx_list, keep_pos_idx_list


def remove_blank(labels, blank=0):
    new_labels = [x for x in labels if x != blank]
    return new_labels


def insert_blank(labels, blank=0):
    new_labels = [blank]
    for l in labels:
        new_labels += [l, blank]
    return new_labels


def ctc_greedy_decoder(probs_seq, blank=95, keep_blank_in_idxs=True):
    """
    CTC greedy (best path) decoder.
    """
    raw_str = np.argmax(np.array(probs_seq), axis=1)
    remove_blank_in_pos = None if keep_blank_in_idxs else blank
    dedup_str, keep_idx_list = get_keep_pos_idxs(
        raw_str, remove_blank=remove_blank_in_pos)
    dst_str = remove_blank(dedup_str, blank=blank)
    return dst_str, keep_idx_list


def instance_ctc_greedy_decoder(gather_info, logits_map, pts_num=4):
    _, _, C = logits_map.shape
    ys, xs = zip(*gather_info)
    logits_seq = logits_map[list(ys), list(xs)]
    probs_seq = logits_seq
    labels = np.argmax(probs_seq, axis=1)
    dst_str = [k for k, v_ in groupby(labels) if k != C - 1]
    detal = len(gather_info) // (pts_num - 1)
    keep_idx_list = [0] + [detal * (i + 1) for i in range(pts_num - 2)] + [-1]
    keep_gather_list = [gather_info[idx] for idx in keep_idx_list]
    return dst_str, keep_gather_list


def ctc_decoder_for_image(gather_info_list, logits_map, Lexicon_Table,
                          pts_num=6):
    """
    CTC decoder using multiple processes.
    """
    decoder_str = []
    decoder_xys = []
    for gather_info in gather_info_list:
        if len(gather_info) < pts_num:
            continue
        dst_str, xys_list = instance_ctc_greedy_decoder(
            gather_info, logits_map, pts_num=pts_num)
        dst_str_readable = ''.join([Lexicon_Table[idx] for idx in dst_str])
        if len(dst_str_readable) < 2:
            continue
        decoder_str.append(dst_str_readable)
        decoder_xys.append(xys_list)
    return decoder_str, decoder_xys


def sort_with_direction(pos_list, f_direction):
    """
    f_direction: h x w x 2
    pos_list: [[y, x], [y, x], [y, x] ...]
    """

    def sort_part_with_direction(pos_list, point_direction):
        pos_list = np.array(pos_list).reshape(-1, 2)
        point_direction = np.array(point_direction).reshape(-1, 2)
        average_direction = np.mean(point_direction, axis=0, keepdims=True)
        pos_proj_leng = np.sum(pos_list * average_direction, axis=1)
        sorted_list = pos_list[np.argsort(pos_proj_leng)].tolist()
        sorted_direction = point_direction[np.argsort(pos_proj_leng)].tolist()
        return sorted_list, sorted_direction

    pos_list = np.array(pos_list).reshape(-1, 2)
    point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]]  # x, y
    point_direction = point_direction[:, ::-1]  # x, y -> y, x
    sorted_point, sorted_direction = sort_part_with_direction(pos_list,
                                                              point_direction)

    point_num = len(sorted_point)
    if point_num >= 16:
        middle_num = point_num // 2
        first_part_point = sorted_point[:middle_num]
        first_point_direction = sorted_direction[:middle_num]
        sorted_fist_part_point, sorted_fist_part_direction = sort_part_with_direction(
            first_part_point, first_point_direction)

        last_part_point = sorted_point[middle_num:]
        last_point_direction = sorted_direction[middle_num:]
        sorted_last_part_point, sorted_last_part_direction = sort_part_with_direction(
            last_part_point, last_point_direction)
        sorted_point = sorted_fist_part_point + sorted_last_part_point
        sorted_direction = sorted_fist_part_direction + sorted_last_part_direction

    return sorted_point, np.array(sorted_direction)


def add_id(pos_list, image_id=0):
    """
    Add id for gather feature, for inference.
    """
    new_list = []
    for item in pos_list:
        new_list.append((image_id, item[0], item[1]))
    return new_list


def sort_and_expand_with_direction(pos_list, f_direction):
    """
    f_direction: h x w x 2
    pos_list: [[y, x], [y, x], [y, x] ...]
    """
    h, w, _ = f_direction.shape
    sorted_list, point_direction = sort_with_direction(pos_list, f_direction)

    point_num = len(sorted_list)
    sub_direction_len = max(point_num // 3, 2)
    left_direction = point_direction[:sub_direction_len, :]
    right_dirction = point_direction[point_num - sub_direction_len:, :]

    left_average_direction = -np.mean(left_direction, axis=0, keepdims=True)
    left_average_len = np.linalg.norm(left_average_direction)
    left_start = np.array(sorted_list[0])
    left_step = left_average_direction / (left_average_len + 1e-6)

    right_average_direction = np.mean(right_dirction, axis=0, keepdims=True)
    right_average_len = np.linalg.norm(right_average_direction)
    right_step = right_average_direction / (right_average_len + 1e-6)
    right_start = np.array(sorted_list[-1])

    append_num = max(
        int((left_average_len + right_average_len) / 2.0 * 0.15), 1)
    left_list = []
    right_list = []
    for i in range(append_num):
        ly, lx = np.round(left_start + left_step * (i + 1)).flatten().astype(
            'int32').tolist()
        if ly < h and lx < w and (ly, lx) not in left_list:
            left_list.append((ly, lx))
        ry, rx = np.round(right_start + right_step * (i + 1)).flatten().astype(
            'int32').tolist()
        if ry < h and rx < w and (ry, rx) not in right_list:
            right_list.append((ry, rx))

    all_list = left_list[::-1] + sorted_list + right_list
    return all_list


def sort_and_expand_with_direction_v2(pos_list, f_direction, binary_tcl_map):
    """
    f_direction: h x w x 2
    pos_list: [[y, x], [y, x], [y, x] ...]
    binary_tcl_map: h x w
    """
    h, w, _ = f_direction.shape
    sorted_list, point_direction = sort_with_direction(pos_list, f_direction)

    point_num = len(sorted_list)
    sub_direction_len = max(point_num // 3, 2)
    left_direction = point_direction[:sub_direction_len, :]
    right_dirction = point_direction[point_num - sub_direction_len:, :]

    left_average_direction = -np.mean(left_direction, axis=0, keepdims=True)
    left_average_len = np.linalg.norm(left_average_direction)
    left_start = np.array(sorted_list[0])
    left_step = left_average_direction / (left_average_len + 1e-6)

    right_average_direction = np.mean(right_dirction, axis=0, keepdims=True)
    right_average_len = np.linalg.norm(right_average_direction)
    right_step = right_average_direction / (right_average_len + 1e-6)
    right_start = np.array(sorted_list[-1])

    append_num = max(
        int((left_average_len + right_average_len) / 2.0 * 0.15), 1)
    max_append_num = 2 * append_num

    left_list = []
    right_list = []
    for i in range(max_append_num):
        ly, lx = np.round(left_start + left_step * (i + 1)).flatten().astype(
            'int32').tolist()
        if ly < h and lx < w and (ly, lx) not in left_list:
            if binary_tcl_map[ly, lx] > 0.5:
                left_list.append((ly, lx))
            else:
                break

    for i in range(max_append_num):
        ry, rx = np.round(right_start + right_step * (i + 1)).flatten().astype(
            'int32').tolist()
        if ry < h and rx < w and (ry, rx) not in right_list:
            if binary_tcl_map[ry, rx] > 0.5:
                right_list.append((ry, rx))
            else:
                break

    all_list = left_list[::-1] + sorted_list + right_list
    return all_list


def point_pair2poly(point_pair_list):
    """
    Transfer vertical point_pairs into poly point in clockwise.
    """
    point_num = len(point_pair_list) * 2
    point_list = [0] * point_num
    for idx, point_pair in enumerate(point_pair_list):
        point_list[idx] = point_pair[0]
        point_list[point_num - 1 - idx] = point_pair[1]
    return np.array(point_list).reshape(-1, 2)


def shrink_quad_along_width(quad, begin_width_ratio=0., end_width_ratio=1.):
    ratio_pair = np.array(
        [[begin_width_ratio], [end_width_ratio]], dtype=np.float32)
    p0_1 = quad[0] + (quad[1] - quad[0]) * ratio_pair
    p3_2 = quad[3] + (quad[2] - quad[3]) * ratio_pair
    return np.array([p0_1[0], p0_1[1], p3_2[1], p3_2[0]])


def expand_poly_along_width(poly, shrink_ratio_of_width=0.3):
    """
    expand poly along width.
    """
    point_num = poly.shape[0]
    left_quad = np.array(
        [poly[0], poly[1], poly[-2], poly[-1]], dtype=np.float32)
    left_ratio = -shrink_ratio_of_width * np.linalg.norm(left_quad[0] - left_quad[3]) / \
                 (np.linalg.norm(left_quad[0] - left_quad[1]) + 1e-6)
    left_quad_expand = shrink_quad_along_width(left_quad, left_ratio, 1.0)
    right_quad = np.array(
        [
            poly[point_num // 2 - 2], poly[point_num // 2 - 1],
            poly[point_num // 2], poly[point_num // 2 + 1]
        ],
        dtype=np.float32)
    right_ratio = 1.0 + shrink_ratio_of_width * np.linalg.norm(right_quad[0] - right_quad[3]) / \
                  (np.linalg.norm(right_quad[0] - right_quad[1]) + 1e-6)
    right_quad_expand = shrink_quad_along_width(right_quad, 0.0, right_ratio)
    poly[0] = left_quad_expand[0]
    poly[-1] = left_quad_expand[-1]
    poly[point_num // 2 - 1] = right_quad_expand[1]
    poly[point_num // 2] = right_quad_expand[2]
    return poly


def restore_poly(instance_yxs_list, seq_strs, p_border, ratio_w, ratio_h,
                 src_w, src_h, valid_set):
    poly_list = []
    keep_str_list = []
    for yx_center_line, keep_str in zip(instance_yxs_list, seq_strs):
        if len(keep_str) < 2:
            print('--> too short, {}'.format(keep_str))
            continue

        offset_expand = 1.0
        if valid_set == 'totaltext':
            offset_expand = 1.2

        point_pair_list = []
        for y, x in yx_center_line:
            offset = p_border[:, y, x].reshape(2, 2) * offset_expand
            ori_yx = np.array([y, x], dtype=np.float32)
            point_pair = (ori_yx + offset)[:, ::-1] * 4.0 / np.array(
                [ratio_w, ratio_h]).reshape(-1, 2)
            point_pair_list.append(point_pair)

        detected_poly = point_pair2poly(point_pair_list)
        detected_poly = expand_poly_along_width(
            detected_poly, shrink_ratio_of_width=0.2)
        detected_poly[:, 0] = np.clip(detected_poly[:, 0], a_min=0, a_max=src_w)
        detected_poly[:, 1] = np.clip(detected_poly[:, 1], a_min=0, a_max=src_h)

        keep_str_list.append(keep_str)
        if valid_set == 'partvgg':
            middle_point = len(detected_poly) // 2
            detected_poly = detected_poly[
                [0, middle_point - 1, middle_point, -1], :]
            poly_list.append(detected_poly)
        elif valid_set == 'totaltext':
            poly_list.append(detected_poly)
        else:
            print('--> Not supported format.')
            exit(-1)
    return poly_list, keep_str_list


def generate_pivot_list_fast(p_score,
                             p_char_maps,
                             f_direction,
                             Lexicon_Table,
                             score_thresh=0.5):
    """
    return center point and end point of TCL instance; filter with the char maps;
    """
    p_score = p_score[0]
    f_direction = f_direction.transpose(1, 2, 0)
    p_tcl_map = (p_score > score_thresh) * 1.0
    skeleton_map = thin(p_tcl_map.astype(np.uint8))
    instance_count, instance_label_map = cv2.connectedComponents(
        skeleton_map.astype(np.uint8), connectivity=8)

    # get TCL Instance
    all_pos_yxs = []
    if instance_count > 0:
        for instance_id in range(1, instance_count):
            pos_list = []
            ys, xs = np.where(instance_label_map == instance_id)
            pos_list = list(zip(ys, xs))
            if len(pos_list) < 3:
                continue
            pos_list_sorted = sort_and_expand_with_direction_v2(
                pos_list, f_direction, p_tcl_map)
            all_pos_yxs.append(pos_list_sorted)

    p_char_maps = p_char_maps.transpose([1, 2, 0])
    decoded_str, keep_yxs_list = ctc_decoder_for_image(
        all_pos_yxs, logits_map=p_char_maps, Lexicon_Table=Lexicon_Table)
    return keep_yxs_list, decoded_str


def extract_main_direction(pos_list, f_direction):
    """
    f_direction: h x w x 2
    pos_list: [[y, x], [y, x], [y, x] ...]
    """
    pos_list = np.array(pos_list)
    point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]]
    point_direction = point_direction[:, ::-1]  # x, y -> y, x
    average_direction = np.mean(point_direction, axis=0, keepdims=True)
    average_direction = average_direction / (
        np.linalg.norm(average_direction) + 1e-6)
    return average_direction


def sort_by_direction_with_image_id_deprecated(pos_list, f_direction):
    """
    f_direction: h x w x 2
    pos_list: [[id, y, x], [id, y, x], [id, y, x] ...]
    """
    pos_list_full = np.array(pos_list).reshape(-1, 3)
    pos_list = pos_list_full[:, 1:]
    point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]]  # x, y
    point_direction = point_direction[:, ::-1]  # x, y -> y, x
    average_direction = np.mean(point_direction, axis=0, keepdims=True)
    pos_proj_leng = np.sum(pos_list * average_direction, axis=1)
    sorted_list = pos_list_full[np.argsort(pos_proj_leng)].tolist()
    return sorted_list


def sort_by_direction_with_image_id(pos_list, f_direction):
    """
    f_direction: h x w x 2
    pos_list: [[y, x], [y, x], [y, x] ...]
    """

    def sort_part_with_direction(pos_list_full, point_direction):
        pos_list_full = np.array(pos_list_full).reshape(-1, 3)
        pos_list = pos_list_full[:, 1:]
        point_direction = np.array(point_direction).reshape(-1, 2)
        average_direction = np.mean(point_direction, axis=0, keepdims=True)
        pos_proj_leng = np.sum(pos_list * average_direction, axis=1)
        sorted_list = pos_list_full[np.argsort(pos_proj_leng)].tolist()
        sorted_direction = point_direction[np.argsort(pos_proj_leng)].tolist()
        return sorted_list, sorted_direction

    pos_list = np.array(pos_list).reshape(-1, 3)
    point_direction = f_direction[pos_list[:, 1], pos_list[:, 2]]  # x, y
    point_direction = point_direction[:, ::-1]  # x, y -> y, x
    sorted_point, sorted_direction = sort_part_with_direction(pos_list,
                                                              point_direction)

    point_num = len(sorted_point)
    if point_num >= 16:
        middle_num = point_num // 2
        first_part_point = sorted_point[:middle_num]
        first_point_direction = sorted_direction[:middle_num]
        sorted_fist_part_point, sorted_fist_part_direction = sort_part_with_direction(
            first_part_point, first_point_direction)

        last_part_point = sorted_point[middle_num:]
        last_point_direction = sorted_direction[middle_num:]
        sorted_last_part_point, sorted_last_part_direction = sort_part_with_direction(
            last_part_point, last_point_direction)
        sorted_point = sorted_fist_part_point + sorted_last_part_point
        sorted_direction = sorted_fist_part_direction + sorted_last_part_direction

    return sorted_point
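A tiny usage sketch of the numpy CTC path above, with made-up logits (the blank is the last class index, as `instance_ctc_greedy_decoder` assumes):

```python
import numpy as np

logits = np.array([[5.0, 0.1, 0.2],   # argmax 0
                   [4.0, 0.3, 0.1],   # argmax 0 (duplicate, merged)
                   [0.1, 0.1, 6.0],   # argmax 2 = blank
                   [0.2, 7.0, 0.3]])  # argmax 1
probs = softmax(logits)               # rows now sum to 1
dst_str, keep_idxs = ctc_greedy_decoder(probs, blank=2)
print(dst_str)    # [0, 1]: duplicates collapsed, blanks dropped
print(keep_idxs)  # [1, 2, 3]: kept positions (blank position included)
```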
ppocr/utils/e2e_utils/extract_textpoint.py → ppocr/utils/e2e_utils/extract_textpoint_slow.py (renamed):

@@ -35,6 +35,64 @@ def get_dict(character_dict_path):
     return dict_character
 
 
+def point_pair2poly(point_pair_list):
+    """
+    Transfer vertical point_pairs into poly point in clockwise.
+    """
+    pair_length_list = []
+    for point_pair in point_pair_list:
+        pair_length = np.linalg.norm(point_pair[0] - point_pair[1])
+        pair_length_list.append(pair_length)
+    pair_length_list = np.array(pair_length_list)
+    pair_info = (pair_length_list.max(), pair_length_list.min(),
+                 pair_length_list.mean())
+
+    point_num = len(point_pair_list) * 2
+    point_list = [0] * point_num
+    for idx, point_pair in enumerate(point_pair_list):
+        point_list[idx] = point_pair[0]
+        point_list[point_num - 1 - idx] = point_pair[1]
+    return np.array(point_list).reshape(-1, 2), pair_info
+
+
+def shrink_quad_along_width(quad, begin_width_ratio=0., end_width_ratio=1.):
+    """
+    Generate shrink_quad_along_width.
+    """
+    ratio_pair = np.array(
+        [[begin_width_ratio], [end_width_ratio]], dtype=np.float32)
+    p0_1 = quad[0] + (quad[1] - quad[0]) * ratio_pair
+    p3_2 = quad[3] + (quad[2] - quad[3]) * ratio_pair
+    return np.array([p0_1[0], p0_1[1], p3_2[1], p3_2[0]])
+
+
+def expand_poly_along_width(poly, shrink_ratio_of_width=0.3):
+    """
+    expand poly along width.
+    """
+    point_num = poly.shape[0]
+    left_quad = np.array(
+        [poly[0], poly[1], poly[-2], poly[-1]], dtype=np.float32)
+    left_ratio = -shrink_ratio_of_width * np.linalg.norm(left_quad[0] - left_quad[3]) / \
+                 (np.linalg.norm(left_quad[0] - left_quad[1]) + 1e-6)
+    left_quad_expand = shrink_quad_along_width(left_quad, left_ratio, 1.0)
+    right_quad = np.array(
+        [
+            poly[point_num // 2 - 2], poly[point_num // 2 - 1],
+            poly[point_num // 2], poly[point_num // 2 + 1]
+        ],
+        dtype=np.float32)
+    right_ratio = 1.0 + \
+        shrink_ratio_of_width * np.linalg.norm(right_quad[0] - right_quad[3]) / \
+        (np.linalg.norm(right_quad[0] - right_quad[1]) + 1e-6)
+    right_quad_expand = shrink_quad_along_width(right_quad, 0.0, right_ratio)
+    poly[0] = left_quad_expand[0]
+    poly[-1] = left_quad_expand[-1]
+    poly[point_num // 2 - 1] = right_quad_expand[1]
+    poly[point_num // 2] = right_quad_expand[2]
+    return poly
+
+
 def softmax(logits):
     """
     logits: N x d
@@ -399,13 +457,13 @@ def generate_pivot_list_horizontal(p_score,
     return center_pos_yxs, end_points_yxs
 
 
-def generate_pivot_list(p_score,
-                        p_char_maps,
-                        f_direction,
-                        score_thresh=0.5,
-                        is_backbone=False,
-                        is_curved=True,
-                        image_id=0):
+def generate_pivot_list_slow(p_score,
+                             p_char_maps,
+                             f_direction,
+                             score_thresh=0.5,
+                             is_backbone=False,
+                             is_curved=True,
+                             image_id=0):
     """
     Warp all the function together.
    """
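A worked example of the slow variant's `point_pair2poly` (made-up coordinates): two (top, bottom) point pairs become a clockwise 4-point polygon, and this variant additionally reports the (max, min, mean) pair lengths.

```python
import numpy as np

pairs = [np.array([[0., 0.], [0., 2.]]),   # leftmost top/bottom pair
         np.array([[3., 0.], [3., 2.]])]   # rightmost top/bottom pair
poly, pair_info = point_pair2poly(pairs)
print(poly.tolist())  # [[0.0, 0.0], [3.0, 0.0], [3.0, 2.0], [0.0, 2.0]]
print(pair_info)      # (2.0, 2.0, 2.0) -> max, min, mean pair length
```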
ppocr/utils/e2e_utils/pgnet_pp_utils.py (new file, mode 100644):

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import paddle
import os
import sys

__dir__ = os.path.dirname(__file__)
sys.path.append(__dir__)
sys.path.append(os.path.join(__dir__, '..'))

from extract_textpoint_slow import *
from extract_textpoint_fast import generate_pivot_list_fast, restore_poly


class PGNet_PostProcess(object):
    # two different post-process
    def __init__(self, character_dict_path, valid_set, score_thresh, outs_dict,
                 shape_list):
        self.Lexicon_Table = get_dict(character_dict_path)
        self.valid_set = valid_set
        self.score_thresh = score_thresh
        self.outs_dict = outs_dict
        self.shape_list = shape_list

    def pg_postprocess_fast(self):
        p_score = self.outs_dict['f_score']
        p_border = self.outs_dict['f_border']
        p_char = self.outs_dict['f_char']
        p_direction = self.outs_dict['f_direction']
        if isinstance(p_score, paddle.Tensor):
            p_score = p_score[0].numpy()
            p_border = p_border[0].numpy()
            p_direction = p_direction[0].numpy()
            p_char = p_char[0].numpy()
        else:
            p_score = p_score[0]
            p_border = p_border[0]
            p_direction = p_direction[0]
            p_char = p_char[0]

        src_h, src_w, ratio_h, ratio_w = self.shape_list[0]
        instance_yxs_list, seq_strs = generate_pivot_list_fast(
            p_score,
            p_char,
            p_direction,
            self.Lexicon_Table,
            score_thresh=self.score_thresh)
        poly_list, keep_str_list = restore_poly(instance_yxs_list, seq_strs,
                                                p_border, ratio_w, ratio_h,
                                                src_w, src_h, self.valid_set)
        data = {
            'points': poly_list,
            'strs': keep_str_list,
        }
        return data

    def pg_postprocess_slow(self):
        p_score = self.outs_dict['f_score']
        p_border = self.outs_dict['f_border']
        p_char = self.outs_dict['f_char']
        p_direction = self.outs_dict['f_direction']
        if isinstance(p_score, paddle.Tensor):
            p_score = p_score[0].numpy()
            p_border = p_border[0].numpy()
            p_direction = p_direction[0].numpy()
            p_char = p_char[0].numpy()
        else:
            p_score = p_score[0]
            p_border = p_border[0]
            p_direction = p_direction[0]
            p_char = p_char[0]
        src_h, src_w, ratio_h, ratio_w = self.shape_list[0]
        is_curved = self.valid_set == "totaltext"
        instance_yxs_list = generate_pivot_list_slow(
            p_score,
            p_char,
            p_direction,
            score_thresh=self.score_thresh,
            is_backbone=True,
            is_curved=is_curved)
        p_char = paddle.to_tensor(np.expand_dims(p_char, axis=0))
        char_seq_idx_set = []
        for i in range(len(instance_yxs_list)):
            gather_info_lod = paddle.to_tensor(instance_yxs_list[i])
            f_char_map = paddle.transpose(p_char, [0, 2, 3, 1])
            feature_seq = paddle.gather_nd(f_char_map, gather_info_lod)
            feature_seq = np.expand_dims(feature_seq.numpy(), axis=0)
            feature_len = [len(feature_seq[0])]
            featyre_seq = paddle.to_tensor(feature_seq)
            feature_len = np.array([feature_len]).astype(np.int64)
            length = paddle.to_tensor(feature_len)
            seq_pred = paddle.fluid.layers.ctc_greedy_decoder(
                input=featyre_seq, blank=36, input_length=length)
            seq_pred_str = seq_pred[0].numpy().tolist()[0]
            seq_len = seq_pred[1].numpy()[0][0]
            temp_t = []
            for c in seq_pred_str[:seq_len]:
                temp_t.append(c)
            char_seq_idx_set.append(temp_t)
        seq_strs = []
        for char_idx_set in char_seq_idx_set:
            pr_str = ''.join([self.Lexicon_Table[pos] for pos in char_idx_set])
            seq_strs.append(pr_str)
        poly_list = []
        keep_str_list = []
        all_point_list = []
        all_point_pair_list = []
        for yx_center_line, keep_str in zip(instance_yxs_list, seq_strs):
            if len(yx_center_line) == 1:
                yx_center_line.append(yx_center_line[-1])
            offset_expand = 1.0
            if self.valid_set == 'totaltext':
                offset_expand = 1.2
            point_pair_list = []
            for batch_id, y, x in yx_center_line:
                offset = p_border[:, y, x].reshape(2, 2)
                if offset_expand != 1.0:
                    offset_length = np.linalg.norm(
                        offset, axis=1, keepdims=True)
                    expand_length = np.clip(
                        offset_length * (offset_expand - 1),
                        a_min=0.5,
                        a_max=3.0)
                    offset_detal = offset / offset_length * expand_length
                    offset = offset + offset_detal
                ori_yx = np.array([y, x], dtype=np.float32)
                point_pair = (ori_yx + offset)[:, ::-1] * 4.0 / np.array(
                    [ratio_w, ratio_h]).reshape(-1, 2)
                point_pair_list.append(point_pair)
                all_point_list.append([
                    int(round(x * 4.0 / ratio_w)),
                    int(round(y * 4.0 / ratio_h))
                ])
                all_point_pair_list.append(point_pair.round().astype(np.int32)
                                           .tolist())
            detected_poly, pair_length_info = point_pair2poly(point_pair_list)
            detected_poly = expand_poly_along_width(
                detected_poly, shrink_ratio_of_width=0.2)
            detected_poly[:, 0] = np.clip(
                detected_poly[:, 0], a_min=0, a_max=src_w)
            detected_poly[:, 1] = np.clip(
                detected_poly[:, 1], a_min=0, a_max=src_h)
            if len(keep_str) < 2:
                continue
            keep_str_list.append(keep_str)
            detected_poly = np.round(detected_poly).astype('int32')
            if self.valid_set == 'partvgg':
                middle_point = len(detected_poly) // 2
                detected_poly = detected_poly[
                    [0, middle_point - 1, middle_point, -1], :]
                poly_list.append(detected_poly)
            elif self.valid_set == 'totaltext':
                poly_list.append(detected_poly)
            else:
                print('--> Not supported format.')
                exit(-1)
        data = {
            'points': poly_list,
            'strs': keep_str_list,
        }
        return data
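A runnable smoke-test sketch of the fast path (all inputs are zero-filled placeholders, so it returns empty results; the map shapes, the class count, and the dict path are assumptions, and a real character dict file must exist at that path):

```python
import numpy as np

h, w, C = 128, 128, 37   # feature-map size and number of char classes (assumed)
outs_dict = {
    'f_score': np.zeros((1, 1, h, w), dtype=np.float32),
    'f_border': np.zeros((1, 4, h, w), dtype=np.float32),
    'f_char': np.zeros((1, C, h, w), dtype=np.float32),
    'f_direction': np.zeros((1, 2, h, w), dtype=np.float32),
}
shape_list = np.array([[512, 512, 0.25, 0.25]])  # src_h, src_w, ratio_h, ratio_w

post = PGNet_PostProcess('./ppocr/utils/ic15_dict.txt', 'totaltext', 0.5,
                         outs_dict, shape_list)
data = post.pg_postprocess_fast()
print(data)  # {'points': [], 'strs': []} on all-zero maps
```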
ppocr/utils/en_dict.txt (new file, mode 100644), one dictionary character per line:
0
1
2
3
4
5
6
7
8
9
:
;
<
=
>
?
@
A
B
C
D
E
F
G
H
I
J
K
L
M
N
O
P
Q
R
S
T
U
V
W
X
Y
Z
[
\
]
^
_
`
a
b
c
d
e
f
g
h
i
j
k
l
m
n
o
p
q
r
s
t
u
v
w
x
y
z
{
|
}
~
!
"
#
$
%
&
'
(
)
*
+
,
-
.
/
setup.py:

@@ -32,7 +32,7 @@ setup(
     package_dir={'paddleocr': ''},
     include_package_data=True,
     entry_points={"console_scripts": ["paddleocr= paddleocr.paddleocr:main"]},
-    version='2.0.4',
+    version='2.0.6',
     install_requires=requirements,
     license='Apache License 2.0',
     description='Awesome OCR toolkits based on PaddlePaddle (8.6M ultra-lightweight pre-trained model, support training and deployment among server, mobile, embeded and IoT devices',
tools/infer/predict_e2e.py:

@@ -66,6 +66,7 @@ class TextE2E(object):
             postprocess_params["score_thresh"] = args.e2e_pgnet_score_thresh
             postprocess_params["character_dict_path"] = args.e2e_char_dict_path
             postprocess_params["valid_set"] = args.e2e_pgnet_valid_set
+            postprocess_params["mode"] = args.e2e_pgnet_mode
             self.e2e_pgnet_polygon = args.e2e_pgnet_polygon
         else:
             logger.info("unknown e2e_algorithm:{}".format(self.e2e_algorithm))
tools/infer/utility.py:

@@ -86,6 +86,7 @@ def parse_args():
     parser.add_argument(
         "--e2e_char_dict_path", type=str, default="./ppocr/utils/ic15_dict.txt")
     parser.add_argument("--e2e_pgnet_valid_set", type=str, default='totaltext')
     parser.add_argument("--e2e_pgnet_polygon", type=bool, default=True)
+    parser.add_argument("--e2e_pgnet_mode", type=str, default='fast')
 
     # params for text classifier
     parser.add_argument("--use_angle_cls", type=str2bool, default=False)