Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
wangsen
paddle_dbnet
Commits
a48dac50
Unverified
Commit
a48dac50
authored
Apr 16, 2021
by
zhoujun
Committed by
GitHub
Apr 16, 2021
Browse files
Merge branch 'dygraph' into lite
parents
6abb1382
713ceb4e
Changes
29
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
38 additions
and
105 deletions
+38
-105
ppocr/data/imaug/pg_process.py
ppocr/data/imaug/pg_process.py
+4
-4
ppocr/data/pgnet_dataset.py
ppocr/data/pgnet_dataset.py
+24
-92
ppocr/metrics/e2e_metric.py
ppocr/metrics/e2e_metric.py
+3
-3
ppocr/postprocess/rec_postprocess.py
ppocr/postprocess/rec_postprocess.py
+1
-1
ppocr/utils/e2e_metric/Deteval.py
ppocr/utils/e2e_metric/Deteval.py
+1
-1
ppocr/utils/e2e_utils/extract_textpoint_fast.py
ppocr/utils/e2e_utils/extract_textpoint_fast.py
+1
-0
ppocr/utils/e2e_utils/pgnet_pp_utils.py
ppocr/utils/e2e_utils/pgnet_pp_utils.py
+2
-2
tools/infer/predict_e2e.py
tools/infer/predict_e2e.py
+1
-1
tools/infer_e2e.py
tools/infer_e2e.py
+1
-1
No files found.
ppocr/data/imaug/pg_process.py
View file @
a48dac50
...
...
@@ -88,7 +88,7 @@ class PGProcessTrain(object):
return
min_area_quad
def
check_and_validate_polys
(
self
,
polys
,
tags
,
xxx_todo_changem
e
):
def
check_and_validate_polys
(
self
,
polys
,
tags
,
im_siz
e
):
"""
check so that the text poly is in the same direction,
and also filter some invalid polygons
...
...
@@ -96,7 +96,7 @@ class PGProcessTrain(object):
:param tags:
:return:
"""
(
h
,
w
)
=
xxx_todo_changem
e
(
h
,
w
)
=
im_siz
e
if
polys
.
shape
[
0
]
==
0
:
return
polys
,
np
.
array
([]),
np
.
array
([])
polys
[:,
:,
0
]
=
np
.
clip
(
polys
[:,
:,
0
],
0
,
w
-
1
)
...
...
@@ -750,8 +750,8 @@ class PGProcessTrain(object):
input_size
=
512
im
=
data
[
'image'
]
text_polys
=
data
[
'polys'
]
text_tags
=
data
[
'tags'
]
text_strs
=
data
[
'
str
s'
]
text_tags
=
data
[
'
ignore_
tags'
]
text_strs
=
data
[
'
text
s'
]
h
,
w
,
_
=
im
.
shape
text_polys
,
text_tags
,
hv_tags
=
self
.
check_and_validate_polys
(
text_polys
,
text_tags
,
(
h
,
w
))
...
...
ppocr/data/pgnet_dataset.py
View file @
a48dac50
...
...
@@ -29,20 +29,20 @@ class PGDataSet(Dataset):
dataset_config
=
config
[
mode
][
'dataset'
]
loader_config
=
config
[
mode
][
'loader'
]
self
.
delimiter
=
dataset_config
.
get
(
'delimiter'
,
'
\t
'
)
label_file_list
=
dataset_config
.
pop
(
'label_file_list'
)
data_source_num
=
len
(
label_file_list
)
ratio_list
=
dataset_config
.
get
(
"ratio_list"
,
[
1.0
])
if
isinstance
(
ratio_list
,
(
float
,
int
)):
ratio_list
=
[
float
(
ratio_list
)]
*
int
(
data_source_num
)
self
.
data_format
=
dataset_config
.
get
(
'data_format'
,
'icdar'
)
assert
len
(
ratio_list
)
==
data_source_num
,
"The length of ratio_list should be the same as the file_list."
self
.
data_dir
=
dataset_config
[
'data_dir'
]
self
.
do_shuffle
=
loader_config
[
'shuffle'
]
logger
.
info
(
"Initialize indexs of datasets:%s"
%
label_file_list
)
self
.
data_lines
=
self
.
get_image_info_list
(
label_file_list
,
ratio_list
,
self
.
data_format
)
self
.
data_lines
=
self
.
get_image_info_list
(
label_file_list
,
ratio_list
)
self
.
data_idx_order_list
=
list
(
range
(
len
(
self
.
data_lines
)))
if
mode
.
lower
()
==
"train"
:
self
.
shuffle_data_random
()
...
...
@@ -55,108 +55,40 @@ class PGDataSet(Dataset):
random
.
shuffle
(
self
.
data_lines
)
return
def
extract_polys
(
self
,
poly_txt_path
):
"""
Read text_polys, txt_tags, txts from give txt file.
"""
text_polys
,
txt_tags
,
txts
=
[],
[],
[]
with
open
(
poly_txt_path
)
as
f
:
for
line
in
f
.
readlines
():
poly_str
,
txt
=
line
.
strip
().
split
(
'
\t
'
)
poly
=
list
(
map
(
float
,
poly_str
.
split
(
','
)))
text_polys
.
append
(
np
.
array
(
poly
,
dtype
=
np
.
float32
).
reshape
(
-
1
,
2
))
txts
.
append
(
txt
)
txt_tags
.
append
(
txt
==
'###'
)
return
np
.
array
(
list
(
map
(
np
.
array
,
text_polys
))),
\
np
.
array
(
txt_tags
,
dtype
=
np
.
bool
),
txts
def
extract_info_textnet
(
self
,
im_fn
,
img_dir
=
''
):
"""
Extract information from line in textnet format.
"""
info_list
=
im_fn
.
split
(
'
\t
'
)
img_path
=
''
for
ext
in
[
'jpg'
,
'bmp'
,
'png'
,
'jpeg'
,
'rgb'
,
'tif'
,
'tiff'
,
'gif'
,
'JPG'
]:
if
os
.
path
.
exists
(
os
.
path
.
join
(
img_dir
,
info_list
[
0
]
+
"."
+
ext
)):
img_path
=
os
.
path
.
join
(
img_dir
,
info_list
[
0
]
+
"."
+
ext
)
break
if
img_path
==
''
:
print
(
'Image {0} NOT found in {1}, and it will be ignored.'
.
format
(
info_list
[
0
],
img_dir
))
nBox
=
(
len
(
info_list
)
-
1
)
//
9
wordBBs
,
txts
,
txt_tags
=
[],
[],
[]
for
n
in
range
(
0
,
nBox
):
wordBB
=
list
(
map
(
float
,
info_list
[
n
*
9
+
1
:(
n
+
1
)
*
9
]))
txt
=
info_list
[(
n
+
1
)
*
9
]
wordBBs
.
append
([[
wordBB
[
0
],
wordBB
[
1
]],
[
wordBB
[
2
],
wordBB
[
3
]],
[
wordBB
[
4
],
wordBB
[
5
]],
[
wordBB
[
6
],
wordBB
[
7
]]])
txts
.
append
(
txt
)
if
txt
==
'###'
:
txt_tags
.
append
(
True
)
else
:
txt_tags
.
append
(
False
)
return
img_path
,
np
.
array
(
wordBBs
,
dtype
=
np
.
float32
),
txt_tags
,
txts
def
get_image_info_list
(
self
,
file_list
,
ratio_list
,
data_format
=
'textnet'
):
def
get_image_info_list
(
self
,
file_list
,
ratio_list
):
if
isinstance
(
file_list
,
str
):
file_list
=
[
file_list
]
data_lines
=
[]
for
idx
,
data_source
in
enumerate
(
file_list
):
image_files
=
[]
if
data_format
==
'icdar'
:
image_files
=
[(
data_source
,
x
)
for
x
in
os
.
listdir
(
os
.
path
.
join
(
data_source
,
'rgb'
))
if
x
.
split
(
'.'
)[
-
1
]
in
[
'jpg'
,
'bmp'
,
'png'
,
'jpeg'
,
'rgb'
,
'tif'
,
'tiff'
,
'gif'
,
'JPG'
]]
elif
data_format
==
'textnet'
:
with
open
(
data_source
)
as
f
:
image_files
=
[(
data_source
,
x
.
strip
())
for
x
in
f
.
readlines
()]
else
:
print
(
"Unrecognized data format..."
)
exit
(
-
1
)
random
.
seed
(
self
.
seed
)
image_files
=
random
.
sample
(
image_files
,
round
(
len
(
image_files
)
*
ratio_list
[
idx
]))
data_lines
.
extend
(
image_files
)
for
idx
,
file
in
enumerate
(
file_list
):
with
open
(
file
,
"rb"
)
as
f
:
lines
=
f
.
readlines
()
if
self
.
mode
==
"train"
or
ratio_list
[
idx
]
<
1.0
:
random
.
seed
(
self
.
seed
)
lines
=
random
.
sample
(
lines
,
round
(
len
(
lines
)
*
ratio_list
[
idx
]))
data_lines
.
extend
(
lines
)
return
data_lines
def
__getitem__
(
self
,
idx
):
file_idx
=
self
.
data_idx_order_list
[
idx
]
data_path
,
data_line
=
self
.
data_lines
[
file_idx
]
data_line
=
self
.
data_lines
[
file_idx
]
try
:
if
self
.
data_format
==
'icdar'
:
im_path
=
os
.
path
.
join
(
data_path
,
'rgb'
,
data_line
)
poly_path
=
os
.
path
.
join
(
data_path
,
'poly'
,
data_line
.
split
(
'.'
)[
0
]
+
'.txt'
)
text_polys
,
text_tags
,
text_strs
=
self
.
extract_polys
(
poly_path
)
data_line
=
data_line
.
decode
(
'utf-8'
)
substr
=
data_line
.
strip
(
"
\n
"
).
split
(
self
.
delimiter
)
file_name
=
substr
[
0
]
label
=
substr
[
1
]
img_path
=
os
.
path
.
join
(
self
.
data_dir
,
file_name
)
if
self
.
mode
.
lower
()
==
'eval'
:
img_id
=
int
(
data_line
.
split
(
"."
)[
0
][
7
:])
else
:
image_dir
=
os
.
path
.
join
(
os
.
path
.
dirname
(
data_path
),
'image'
)
im_path
,
text_polys
,
text_tags
,
text_strs
=
self
.
extract_info_textnet
(
data_line
,
image_dir
)
img_id
=
int
(
data_line
.
split
(
"."
)[
0
][
3
:])
data
=
{
'img_path'
:
im_path
,
'polys'
:
text_polys
,
'tags'
:
text_tags
,
'strs'
:
text_strs
,
'img_id'
:
img_id
}
img_id
=
0
data
=
{
'img_path'
:
img_path
,
'label'
:
label
,
'img_id'
:
img_id
}
if
not
os
.
path
.
exists
(
img_path
):
raise
Exception
(
"{} does not exist!"
.
format
(
img_path
))
with
open
(
data
[
'img_path'
],
'rb'
)
as
f
:
img
=
f
.
read
()
data
[
'image'
]
=
img
outs
=
transform
(
data
,
self
.
ops
)
except
Exception
as
e
:
self
.
logger
.
error
(
"When parsing line {}, error happened with msg: {}"
.
format
(
...
...
ppocr/metrics/e2e_metric.py
View file @
a48dac50
...
...
@@ -35,11 +35,11 @@ class E2EMetric(object):
self
.
reset
()
def
__call__
(
self
,
preds
,
batch
,
**
kwargs
):
img_id
=
batch
[
5
][
0
]
img_id
=
batch
[
2
][
0
]
e2e_info_list
=
[{
'points'
:
det_polyon
,
'text'
:
pred_str
}
for
det_polyon
,
pred_str
in
zip
(
preds
[
'points'
],
preds
[
'
str
s'
])]
'text
s
'
:
pred_str
}
for
det_polyon
,
pred_str
in
zip
(
preds
[
'points'
],
preds
[
'
text
s'
])]
result
=
get_socre
(
self
.
gt_mat_dir
,
img_id
,
e2e_info_list
)
self
.
results
.
append
(
result
)
...
...
ppocr/postprocess/rec_postprocess.py
View file @
a48dac50
...
...
@@ -28,7 +28,7 @@ class BaseRecLabelDecode(object):
'ch'
,
'en'
,
'EN_symbol'
,
'french'
,
'german'
,
'japan'
,
'korean'
,
'it'
,
'xi'
,
'pu'
,
'ru'
,
'ar'
,
'ta'
,
'ug'
,
'fa'
,
'ur'
,
'rs'
,
'oc'
,
'rsc'
,
'bg'
,
'uk'
,
'be'
,
'te'
,
'ka'
,
'chinese_cht'
,
'hi'
,
'mr'
,
'ne'
,
'EN'
'ne'
,
'EN'
,
'latin'
,
'arabic'
,
'cyrillic'
,
'devanagari'
]
assert
character_type
in
support_character_type
,
"Only {} are supported now but get {}"
.
format
(
support_character_type
,
character_type
)
...
...
ppocr/utils/e2e_metric/Deteval.py
View file @
a48dac50
...
...
@@ -26,7 +26,7 @@ def get_socre(gt_dir, img_id, pred_dict):
n
=
len
(
pred_dict
)
for
i
in
range
(
n
):
points
=
pred_dict
[
i
][
'points'
]
text
=
pred_dict
[
i
][
'text'
]
text
=
pred_dict
[
i
][
'text
s
'
]
point
=
","
.
join
(
map
(
str
,
points
.
reshape
(
-
1
,
)))
det
.
append
([
point
,
text
])
return
det
...
...
ppocr/utils/e2e_utils/extract_textpoint_fast.py
View file @
a48dac50
...
...
@@ -21,6 +21,7 @@ import math
import
numpy
as
np
from
itertools
import
groupby
from
cv2.ximgproc
import
thinning
as
thin
from
skimage.morphology._skeletonize
import
thin
...
...
ppocr/utils/e2e_utils/pgnet_pp_utils.py
View file @
a48dac50
...
...
@@ -64,7 +64,7 @@ class PGNet_PostProcess(object):
src_w
,
src_h
,
self
.
valid_set
)
data
=
{
'points'
:
poly_list
,
'
str
s'
:
keep_str_list
,
'
text
s'
:
keep_str_list
,
}
return
data
...
...
@@ -176,6 +176,6 @@ class PGNet_PostProcess(object):
exit
(
-
1
)
data
=
{
'points'
:
poly_list
,
'
str
s'
:
keep_str_list
,
'
text
s'
:
keep_str_list
,
}
return
data
tools/infer/predict_e2e.py
View file @
a48dac50
...
...
@@ -122,7 +122,7 @@ class TextE2E(object):
else
:
raise
NotImplementedError
post_result
=
self
.
postprocess_op
(
preds
,
shape_list
)
points
,
strs
=
post_result
[
'points'
],
post_result
[
'
str
s'
]
points
,
strs
=
post_result
[
'points'
],
post_result
[
'
text
s'
]
dt_boxes
=
self
.
filter_tag_det_res_only_clip
(
points
,
ori_im
.
shape
)
elapse
=
time
.
time
()
-
starttime
return
dt_boxes
,
strs
,
elapse
...
...
tools/infer_e2e.py
View file @
a48dac50
...
...
@@ -103,7 +103,7 @@ def main():
images
=
paddle
.
to_tensor
(
images
)
preds
=
model
(
images
)
post_result
=
post_process_class
(
preds
,
shape_list
)
points
,
strs
=
post_result
[
'points'
],
post_result
[
'
str
s'
]
points
,
strs
=
post_result
[
'points'
],
post_result
[
'
text
s'
]
# write resule
dt_boxes_json
=
[]
for
poly
,
str
in
zip
(
points
,
strs
):
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment