Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
wangsen
paddle_dbnet
Commits
a48dac50
Unverified
Commit
a48dac50
authored
Apr 16, 2021
by
zhoujun
Committed by
GitHub
Apr 16, 2021
Browse files
Merge branch 'dygraph' into lite
parents
6abb1382
713ceb4e
Changes
29
Show whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
38 additions
and
105 deletions
+38
-105
ppocr/data/imaug/pg_process.py
ppocr/data/imaug/pg_process.py
+4
-4
ppocr/data/pgnet_dataset.py
ppocr/data/pgnet_dataset.py
+24
-92
ppocr/metrics/e2e_metric.py
ppocr/metrics/e2e_metric.py
+3
-3
ppocr/postprocess/rec_postprocess.py
ppocr/postprocess/rec_postprocess.py
+1
-1
ppocr/utils/e2e_metric/Deteval.py
ppocr/utils/e2e_metric/Deteval.py
+1
-1
ppocr/utils/e2e_utils/extract_textpoint_fast.py
ppocr/utils/e2e_utils/extract_textpoint_fast.py
+1
-0
ppocr/utils/e2e_utils/pgnet_pp_utils.py
ppocr/utils/e2e_utils/pgnet_pp_utils.py
+2
-2
tools/infer/predict_e2e.py
tools/infer/predict_e2e.py
+1
-1
tools/infer_e2e.py
tools/infer_e2e.py
+1
-1
No files found.
ppocr/data/imaug/pg_process.py
View file @
a48dac50
...
@@ -88,7 +88,7 @@ class PGProcessTrain(object):
...
@@ -88,7 +88,7 @@ class PGProcessTrain(object):
return
min_area_quad
return
min_area_quad
def
check_and_validate_polys
(
self
,
polys
,
tags
,
xxx_todo_changem
e
):
def
check_and_validate_polys
(
self
,
polys
,
tags
,
im_siz
e
):
"""
"""
check so that the text poly is in the same direction,
check so that the text poly is in the same direction,
and also filter some invalid polygons
and also filter some invalid polygons
...
@@ -96,7 +96,7 @@ class PGProcessTrain(object):
...
@@ -96,7 +96,7 @@ class PGProcessTrain(object):
:param tags:
:param tags:
:return:
:return:
"""
"""
(
h
,
w
)
=
xxx_todo_changem
e
(
h
,
w
)
=
im_siz
e
if
polys
.
shape
[
0
]
==
0
:
if
polys
.
shape
[
0
]
==
0
:
return
polys
,
np
.
array
([]),
np
.
array
([])
return
polys
,
np
.
array
([]),
np
.
array
([])
polys
[:,
:,
0
]
=
np
.
clip
(
polys
[:,
:,
0
],
0
,
w
-
1
)
polys
[:,
:,
0
]
=
np
.
clip
(
polys
[:,
:,
0
],
0
,
w
-
1
)
...
@@ -750,8 +750,8 @@ class PGProcessTrain(object):
...
@@ -750,8 +750,8 @@ class PGProcessTrain(object):
input_size
=
512
input_size
=
512
im
=
data
[
'image'
]
im
=
data
[
'image'
]
text_polys
=
data
[
'polys'
]
text_polys
=
data
[
'polys'
]
text_tags
=
data
[
'tags'
]
text_tags
=
data
[
'
ignore_
tags'
]
text_strs
=
data
[
'
str
s'
]
text_strs
=
data
[
'
text
s'
]
h
,
w
,
_
=
im
.
shape
h
,
w
,
_
=
im
.
shape
text_polys
,
text_tags
,
hv_tags
=
self
.
check_and_validate_polys
(
text_polys
,
text_tags
,
hv_tags
=
self
.
check_and_validate_polys
(
text_polys
,
text_tags
,
(
h
,
w
))
text_polys
,
text_tags
,
(
h
,
w
))
...
...
ppocr/data/pgnet_dataset.py
View file @
a48dac50
...
@@ -29,20 +29,20 @@ class PGDataSet(Dataset):
...
@@ -29,20 +29,20 @@ class PGDataSet(Dataset):
dataset_config
=
config
[
mode
][
'dataset'
]
dataset_config
=
config
[
mode
][
'dataset'
]
loader_config
=
config
[
mode
][
'loader'
]
loader_config
=
config
[
mode
][
'loader'
]
self
.
delimiter
=
dataset_config
.
get
(
'delimiter'
,
'
\t
'
)
label_file_list
=
dataset_config
.
pop
(
'label_file_list'
)
label_file_list
=
dataset_config
.
pop
(
'label_file_list'
)
data_source_num
=
len
(
label_file_list
)
data_source_num
=
len
(
label_file_list
)
ratio_list
=
dataset_config
.
get
(
"ratio_list"
,
[
1.0
])
ratio_list
=
dataset_config
.
get
(
"ratio_list"
,
[
1.0
])
if
isinstance
(
ratio_list
,
(
float
,
int
)):
if
isinstance
(
ratio_list
,
(
float
,
int
)):
ratio_list
=
[
float
(
ratio_list
)]
*
int
(
data_source_num
)
ratio_list
=
[
float
(
ratio_list
)]
*
int
(
data_source_num
)
self
.
data_format
=
dataset_config
.
get
(
'data_format'
,
'icdar'
)
assert
len
(
assert
len
(
ratio_list
ratio_list
)
==
data_source_num
,
"The length of ratio_list should be the same as the file_list."
)
==
data_source_num
,
"The length of ratio_list should be the same as the file_list."
self
.
data_dir
=
dataset_config
[
'data_dir'
]
self
.
do_shuffle
=
loader_config
[
'shuffle'
]
self
.
do_shuffle
=
loader_config
[
'shuffle'
]
logger
.
info
(
"Initialize indexs of datasets:%s"
%
label_file_list
)
logger
.
info
(
"Initialize indexs of datasets:%s"
%
label_file_list
)
self
.
data_lines
=
self
.
get_image_info_list
(
label_file_list
,
ratio_list
,
self
.
data_lines
=
self
.
get_image_info_list
(
label_file_list
,
ratio_list
)
self
.
data_format
)
self
.
data_idx_order_list
=
list
(
range
(
len
(
self
.
data_lines
)))
self
.
data_idx_order_list
=
list
(
range
(
len
(
self
.
data_lines
)))
if
mode
.
lower
()
==
"train"
:
if
mode
.
lower
()
==
"train"
:
self
.
shuffle_data_random
()
self
.
shuffle_data_random
()
...
@@ -55,108 +55,40 @@ class PGDataSet(Dataset):
...
@@ -55,108 +55,40 @@ class PGDataSet(Dataset):
random
.
shuffle
(
self
.
data_lines
)
random
.
shuffle
(
self
.
data_lines
)
return
return
def
extract_polys
(
self
,
poly_txt_path
):
def
get_image_info_list
(
self
,
file_list
,
ratio_list
):
"""
Read text_polys, txt_tags, txts from give txt file.
"""
text_polys
,
txt_tags
,
txts
=
[],
[],
[]
with
open
(
poly_txt_path
)
as
f
:
for
line
in
f
.
readlines
():
poly_str
,
txt
=
line
.
strip
().
split
(
'
\t
'
)
poly
=
list
(
map
(
float
,
poly_str
.
split
(
','
)))
text_polys
.
append
(
np
.
array
(
poly
,
dtype
=
np
.
float32
).
reshape
(
-
1
,
2
))
txts
.
append
(
txt
)
txt_tags
.
append
(
txt
==
'###'
)
return
np
.
array
(
list
(
map
(
np
.
array
,
text_polys
))),
\
np
.
array
(
txt_tags
,
dtype
=
np
.
bool
),
txts
def
extract_info_textnet
(
self
,
im_fn
,
img_dir
=
''
):
"""
Extract information from line in textnet format.
"""
info_list
=
im_fn
.
split
(
'
\t
'
)
img_path
=
''
for
ext
in
[
'jpg'
,
'bmp'
,
'png'
,
'jpeg'
,
'rgb'
,
'tif'
,
'tiff'
,
'gif'
,
'JPG'
]:
if
os
.
path
.
exists
(
os
.
path
.
join
(
img_dir
,
info_list
[
0
]
+
"."
+
ext
)):
img_path
=
os
.
path
.
join
(
img_dir
,
info_list
[
0
]
+
"."
+
ext
)
break
if
img_path
==
''
:
print
(
'Image {0} NOT found in {1}, and it will be ignored.'
.
format
(
info_list
[
0
],
img_dir
))
nBox
=
(
len
(
info_list
)
-
1
)
//
9
wordBBs
,
txts
,
txt_tags
=
[],
[],
[]
for
n
in
range
(
0
,
nBox
):
wordBB
=
list
(
map
(
float
,
info_list
[
n
*
9
+
1
:(
n
+
1
)
*
9
]))
txt
=
info_list
[(
n
+
1
)
*
9
]
wordBBs
.
append
([[
wordBB
[
0
],
wordBB
[
1
]],
[
wordBB
[
2
],
wordBB
[
3
]],
[
wordBB
[
4
],
wordBB
[
5
]],
[
wordBB
[
6
],
wordBB
[
7
]]])
txts
.
append
(
txt
)
if
txt
==
'###'
:
txt_tags
.
append
(
True
)
else
:
txt_tags
.
append
(
False
)
return
img_path
,
np
.
array
(
wordBBs
,
dtype
=
np
.
float32
),
txt_tags
,
txts
def
get_image_info_list
(
self
,
file_list
,
ratio_list
,
data_format
=
'textnet'
):
if
isinstance
(
file_list
,
str
):
if
isinstance
(
file_list
,
str
):
file_list
=
[
file_list
]
file_list
=
[
file_list
]
data_lines
=
[]
data_lines
=
[]
for
idx
,
data_source
in
enumerate
(
file_list
):
for
idx
,
file
in
enumerate
(
file_list
):
image_files
=
[]
with
open
(
file
,
"rb"
)
as
f
:
if
data_format
==
'icdar'
:
lines
=
f
.
readlines
()
image_files
=
[(
data_source
,
x
)
for
x
in
if
self
.
mode
==
"train"
or
ratio_list
[
idx
]
<
1.0
:
os
.
listdir
(
os
.
path
.
join
(
data_source
,
'rgb'
))
if
x
.
split
(
'.'
)[
-
1
]
in
[
'jpg'
,
'bmp'
,
'png'
,
'jpeg'
,
'rgb'
,
'tif'
,
'tiff'
,
'gif'
,
'JPG'
]]
elif
data_format
==
'textnet'
:
with
open
(
data_source
)
as
f
:
image_files
=
[(
data_source
,
x
.
strip
())
for
x
in
f
.
readlines
()]
else
:
print
(
"Unrecognized data format..."
)
exit
(
-
1
)
random
.
seed
(
self
.
seed
)
random
.
seed
(
self
.
seed
)
image_fil
es
=
random
.
sample
(
lin
es
=
random
.
sample
(
lines
,
image_files
,
round
(
len
(
image_fil
es
)
*
ratio_list
[
idx
]))
round
(
len
(
lin
es
)
*
ratio_list
[
idx
]))
data_lines
.
extend
(
image_fil
es
)
data_lines
.
extend
(
lin
es
)
return
data_lines
return
data_lines
def
__getitem__
(
self
,
idx
):
def
__getitem__
(
self
,
idx
):
file_idx
=
self
.
data_idx_order_list
[
idx
]
file_idx
=
self
.
data_idx_order_list
[
idx
]
data_path
,
data_line
=
self
.
data_lines
[
file_idx
]
data_line
=
self
.
data_lines
[
file_idx
]
try
:
try
:
if
self
.
data_format
==
'icdar'
:
data_line
=
data_line
.
decode
(
'utf-8'
)
im_path
=
os
.
path
.
join
(
data_path
,
'rgb'
,
data_line
)
substr
=
data_line
.
strip
(
"
\n
"
).
split
(
self
.
delimiter
)
poly_path
=
os
.
path
.
join
(
data_path
,
'poly'
,
file_name
=
substr
[
0
]
data_line
.
split
(
'.'
)[
0
]
+
'.txt'
)
label
=
substr
[
1
]
text_polys
,
text_tags
,
text_strs
=
self
.
extract_polys
(
poly_path
)
img_path
=
os
.
path
.
join
(
self
.
data_dir
,
file_name
)
if
self
.
mode
.
lower
()
==
'eval'
:
img_id
=
int
(
data_line
.
split
(
"."
)[
0
][
7
:])
else
:
else
:
image_dir
=
os
.
path
.
join
(
os
.
path
.
dirname
(
data_path
),
'image'
)
img_id
=
0
im_path
,
text_polys
,
text_tags
,
text_strs
=
self
.
extract_info_textnet
(
data
=
{
'img_path'
:
img_path
,
'label'
:
label
,
'img_id'
:
img_id
}
data_line
,
image_dir
)
if
not
os
.
path
.
exists
(
img_path
):
img_id
=
int
(
data_line
.
split
(
"."
)[
0
][
3
:])
raise
Exception
(
"{} does not exist!"
.
format
(
img_path
))
data
=
{
'img_path'
:
im_path
,
'polys'
:
text_polys
,
'tags'
:
text_tags
,
'strs'
:
text_strs
,
'img_id'
:
img_id
}
with
open
(
data
[
'img_path'
],
'rb'
)
as
f
:
with
open
(
data
[
'img_path'
],
'rb'
)
as
f
:
img
=
f
.
read
()
img
=
f
.
read
()
data
[
'image'
]
=
img
data
[
'image'
]
=
img
outs
=
transform
(
data
,
self
.
ops
)
outs
=
transform
(
data
,
self
.
ops
)
except
Exception
as
e
:
except
Exception
as
e
:
self
.
logger
.
error
(
self
.
logger
.
error
(
"When parsing line {}, error happened with msg: {}"
.
format
(
"When parsing line {}, error happened with msg: {}"
.
format
(
...
...
ppocr/metrics/e2e_metric.py
View file @
a48dac50
...
@@ -35,11 +35,11 @@ class E2EMetric(object):
...
@@ -35,11 +35,11 @@ class E2EMetric(object):
self
.
reset
()
self
.
reset
()
def
__call__
(
self
,
preds
,
batch
,
**
kwargs
):
def
__call__
(
self
,
preds
,
batch
,
**
kwargs
):
img_id
=
batch
[
5
][
0
]
img_id
=
batch
[
2
][
0
]
e2e_info_list
=
[{
e2e_info_list
=
[{
'points'
:
det_polyon
,
'points'
:
det_polyon
,
'text'
:
pred_str
'text
s
'
:
pred_str
}
for
det_polyon
,
pred_str
in
zip
(
preds
[
'points'
],
preds
[
'
str
s'
])]
}
for
det_polyon
,
pred_str
in
zip
(
preds
[
'points'
],
preds
[
'
text
s'
])]
result
=
get_socre
(
self
.
gt_mat_dir
,
img_id
,
e2e_info_list
)
result
=
get_socre
(
self
.
gt_mat_dir
,
img_id
,
e2e_info_list
)
self
.
results
.
append
(
result
)
self
.
results
.
append
(
result
)
...
...
ppocr/postprocess/rec_postprocess.py
View file @
a48dac50
...
@@ -28,7 +28,7 @@ class BaseRecLabelDecode(object):
...
@@ -28,7 +28,7 @@ class BaseRecLabelDecode(object):
'ch'
,
'en'
,
'EN_symbol'
,
'french'
,
'german'
,
'japan'
,
'korean'
,
'ch'
,
'en'
,
'EN_symbol'
,
'french'
,
'german'
,
'japan'
,
'korean'
,
'it'
,
'xi'
,
'pu'
,
'ru'
,
'ar'
,
'ta'
,
'ug'
,
'fa'
,
'ur'
,
'rs'
,
'oc'
,
'it'
,
'xi'
,
'pu'
,
'ru'
,
'ar'
,
'ta'
,
'ug'
,
'fa'
,
'ur'
,
'rs'
,
'oc'
,
'rsc'
,
'bg'
,
'uk'
,
'be'
,
'te'
,
'ka'
,
'chinese_cht'
,
'hi'
,
'mr'
,
'rsc'
,
'bg'
,
'uk'
,
'be'
,
'te'
,
'ka'
,
'chinese_cht'
,
'hi'
,
'mr'
,
'ne'
,
'EN'
'ne'
,
'EN'
,
'latin'
,
'arabic'
,
'cyrillic'
,
'devanagari'
]
]
assert
character_type
in
support_character_type
,
"Only {} are supported now but get {}"
.
format
(
assert
character_type
in
support_character_type
,
"Only {} are supported now but get {}"
.
format
(
support_character_type
,
character_type
)
support_character_type
,
character_type
)
...
...
ppocr/utils/e2e_metric/Deteval.py
View file @
a48dac50
...
@@ -26,7 +26,7 @@ def get_socre(gt_dir, img_id, pred_dict):
...
@@ -26,7 +26,7 @@ def get_socre(gt_dir, img_id, pred_dict):
n
=
len
(
pred_dict
)
n
=
len
(
pred_dict
)
for
i
in
range
(
n
):
for
i
in
range
(
n
):
points
=
pred_dict
[
i
][
'points'
]
points
=
pred_dict
[
i
][
'points'
]
text
=
pred_dict
[
i
][
'text'
]
text
=
pred_dict
[
i
][
'text
s
'
]
point
=
","
.
join
(
map
(
str
,
points
.
reshape
(
-
1
,
)))
point
=
","
.
join
(
map
(
str
,
points
.
reshape
(
-
1
,
)))
det
.
append
([
point
,
text
])
det
.
append
([
point
,
text
])
return
det
return
det
...
...
ppocr/utils/e2e_utils/extract_textpoint_fast.py
View file @
a48dac50
...
@@ -21,6 +21,7 @@ import math
...
@@ -21,6 +21,7 @@ import math
import
numpy
as
np
import
numpy
as
np
from
itertools
import
groupby
from
itertools
import
groupby
from
cv2.ximgproc
import
thinning
as
thin
from
skimage.morphology._skeletonize
import
thin
from
skimage.morphology._skeletonize
import
thin
...
...
ppocr/utils/e2e_utils/pgnet_pp_utils.py
View file @
a48dac50
...
@@ -64,7 +64,7 @@ class PGNet_PostProcess(object):
...
@@ -64,7 +64,7 @@ class PGNet_PostProcess(object):
src_w
,
src_h
,
self
.
valid_set
)
src_w
,
src_h
,
self
.
valid_set
)
data
=
{
data
=
{
'points'
:
poly_list
,
'points'
:
poly_list
,
'
str
s'
:
keep_str_list
,
'
text
s'
:
keep_str_list
,
}
}
return
data
return
data
...
@@ -176,6 +176,6 @@ class PGNet_PostProcess(object):
...
@@ -176,6 +176,6 @@ class PGNet_PostProcess(object):
exit
(
-
1
)
exit
(
-
1
)
data
=
{
data
=
{
'points'
:
poly_list
,
'points'
:
poly_list
,
'
str
s'
:
keep_str_list
,
'
text
s'
:
keep_str_list
,
}
}
return
data
return
data
tools/infer/predict_e2e.py
View file @
a48dac50
...
@@ -122,7 +122,7 @@ class TextE2E(object):
...
@@ -122,7 +122,7 @@ class TextE2E(object):
else
:
else
:
raise
NotImplementedError
raise
NotImplementedError
post_result
=
self
.
postprocess_op
(
preds
,
shape_list
)
post_result
=
self
.
postprocess_op
(
preds
,
shape_list
)
points
,
strs
=
post_result
[
'points'
],
post_result
[
'
str
s'
]
points
,
strs
=
post_result
[
'points'
],
post_result
[
'
text
s'
]
dt_boxes
=
self
.
filter_tag_det_res_only_clip
(
points
,
ori_im
.
shape
)
dt_boxes
=
self
.
filter_tag_det_res_only_clip
(
points
,
ori_im
.
shape
)
elapse
=
time
.
time
()
-
starttime
elapse
=
time
.
time
()
-
starttime
return
dt_boxes
,
strs
,
elapse
return
dt_boxes
,
strs
,
elapse
...
...
tools/infer_e2e.py
View file @
a48dac50
...
@@ -103,7 +103,7 @@ def main():
...
@@ -103,7 +103,7 @@ def main():
images
=
paddle
.
to_tensor
(
images
)
images
=
paddle
.
to_tensor
(
images
)
preds
=
model
(
images
)
preds
=
model
(
images
)
post_result
=
post_process_class
(
preds
,
shape_list
)
post_result
=
post_process_class
(
preds
,
shape_list
)
points
,
strs
=
post_result
[
'points'
],
post_result
[
'
str
s'
]
points
,
strs
=
post_result
[
'points'
],
post_result
[
'
text
s'
]
# write resule
# write resule
dt_boxes_json
=
[]
dt_boxes_json
=
[]
for
poly
,
str
in
zip
(
points
,
strs
):
for
poly
,
str
in
zip
(
points
,
strs
):
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment